@@ -6,7 +6,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions, GatherOptions


 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -280,8 +280,45 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
     output.copy_(result)
     return _ret_work(output)

-  def gather(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.gather. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4043
+  # Input tensors are gathered into list of output tensors on the dst device.
+  # Output tensors list is None for all non-dst devices.
+  # This is an inefficient operation. In order to avoid XLA deadlocks it
+  # performs redundant gathers on non-dst devices and materializes the result.
+  def gather(self, output_tensors_list: list[list[torch.Tensor]],
+             input_tensor_list: list[torch.Tensor], opts: GatherOptions):
+    rank = xr.global_ordinal()
+
+    for i, input_tensor in enumerate(input_tensor_list):
+      is_scalar = input_tensor.dim() == 0
+      input_for_all_gather = (
+          input_tensor.clone().reshape(1) if is_scalar else input_tensor)
+
+      gathered_tensor = xm.all_gather(
+          input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)
+
+      # Syncing is required to keep the heterogeneous copying below at the
+      # Python layer, avoiding deadlocks due to mismatched HLO.
+      torch_xla.sync()
+
+      if rank == opts.rootRank:
+        output_tensors = output_tensors_list[i]
+        if is_scalar:
+          for j in range(xr.world_size()):
+            output_tensors[j].copy_(gathered_tensor[j])
+        else:
+          chunk_size = input_tensor.shape[0]
+          gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0)
+          for j, chunk in enumerate(gathered_chunks):
+            if chunk.shape != output_tensors[j].shape:
+              chunk = chunk.reshape(output_tensors[j].shape)
+            output_tensors[j].copy_(chunk)
+
+    if rank == opts.rootRank:
+      return _ret_work(output_tensors_list)
+    else:
+      return _ret_work([[]])

   # Called by torch.distributed.scatter. Call site example:
   # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
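For context, a minimal usage sketch of the path this diff enables: torch.distributed.gather dispatches to the new ProcessGroupXla.gather whenever the process group was created with the xla backend. Everything below is illustrative and assumed rather than taken from the patch (the _mp_fn helper, the tensor shapes, and the spawn-based launch).

# Illustrative sketch (assumed setup, not part of this diff): every replica
# contributes one 1-D tensor and rank 0 collects the gathered list.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


def _mp_fn(index):
  # 'xla://' wires torch.distributed up to ProcessGroupXla.
  dist.init_process_group("xla", init_method="xla://")
  device = xm.xla_device()
  rank = xr.global_ordinal()
  world_size = xr.world_size()

  # Each rank sends a small tensor; only the dst rank provides output buffers.
  t = torch.arange(2, dtype=torch.float32, device=device) + rank
  gather_list = ([torch.zeros(2, dtype=torch.float32, device=device)
                  for _ in range(world_size)] if rank == 0 else None)

  # Dispatches to ProcessGroupXla.gather above; non-dst ranks still execute
  # the redundant all_gather and discard the result.
  dist.gather(t, gather_list, dst=0)

  if rank == 0:
    print([g.cpu() for g in gather_list])


if __name__ == "__main__":
  xmp.spawn(_mp_fn)

As the in-code comment notes, every rank pays for the all_gather and the torch_xla.sync(), so this gather trades efficiency for keeping the issued HLO identical across replicas.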