Implement collective gather op #9435
Merged: 13 commits, Jul 18, 2025

53 changes: 53 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
@@ -389,6 +389,59 @@ def test_scatter(self):
    for ordinal, value in results.items():
      np.testing.assert_array_equal(value, [ordinal])

  @staticmethod
  def _gather(scalar: bool = False):
    dist.init_process_group("xla", init_method='xla://')
    device = torch_xla.device()
    world_size = xr.world_size()

    # If scalar, tensors are tensor(i). Otherwise they are tensor([i]).
    # The two cases follow different code paths and should be tested separately.
    if scalar:
      item = xr.global_ordinal()
      dummy = -1.0
    else:
      item = [xr.global_ordinal()]
      dummy = [-1.0]

    tensor = torch.tensor(item, device=device, dtype=torch.float)

    # Instantiate tensors on device 0 to receive the results
    output_tensors = None
    if xr.global_ordinal() == 0:
      output_tensors = [
          torch.tensor(dummy, device=device, dtype=torch.float)
          for _ in range(world_size)
      ]

    dist.gather(tensor, output_tensors, dst=0)
    if not output_tensors:
      return None
    else:
      return [t.cpu() for t in output_tensors]

  @parameterized.named_parameters(('scalar', True), ('tensor', False))
  def test_gather(self, scalar):
    # self._gather instantiates tensor i or [i], depending on the value of
    # `scalar`, on device i. The results are gathered on device 0.
    # All other devices get None.
    results = pjrt.run_multiprocess(self._gather, scalar)
    if scalar:
      expected = [
          torch.tensor(i, dtype=torch.float)
          for i in range(tpu.num_expected_global_devices())
      ]
    else:
      expected = [
          torch.tensor([i], dtype=torch.float)
          for i in range(tpu.num_expected_global_devices())
      ]
    for ordinal, value in results.items():
      if ordinal == 0:
        torch.testing.assert_close(value, expected)
      else:
        assert value is None

  @staticmethod
  def _reduce():
    dist.init_process_group("xla", init_method='xla://')
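For reference, the new test above corresponds to the following minimal end-user sketch of torch.distributed.gather on the xla backend. This is an illustrative assumption, not part of the diff: the launcher (torch_xla.launch), tensor values, and printout are placeholders, and it presumes one process per XLA device.

# Usage sketch (not part of the diff): gather per-rank tensors onto rank 0.
# Assumes one process per XLA device, e.g. started with torch_xla.launch.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def _mp_fn(index):
  dist.init_process_group("xla", init_method='xla://')
  device = torch_xla.device()
  rank = xr.global_ordinal()

  tensor = torch.tensor([rank], device=device, dtype=torch.float)
  # Only the destination rank supplies output tensors; all others pass None.
  outputs = None
  if rank == 0:
    outputs = [
        torch.empty(1, device=device, dtype=torch.float)
        for _ in range(xr.world_size())
    ]

  dist.gather(tensor, outputs, dst=0)
  if rank == 0:
    print([t.cpu() for t in outputs])  # [tensor([0.]), tensor([1.]), ...]


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)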
1 change: 0 additions & 1 deletion test/test_torch_distributed_xla_backend.py
@@ -357,7 +357,6 @@ def test_barrier(self):

  @parameterized.parameters(
      'allreduce_coalesced',
      'gather',
      'recv_anysource',
      'monitored_barrier',
  )
43 changes: 40 additions & 3 deletions torch_xla/distributed/xla_backend.py
@@ -6,7 +6,7 @@
from torch_xla._internal import rendezvous
import logging
import os
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions, GatherOptions


def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -280,8 +280,45 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
    output.copy_(result)
    return _ret_work(output)

  def gather(self, *args):
    raise NotImplementedError
  # Called by torch.distributed.gather. Call site example:
  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4043
  # Input tensors are gathered into a list of output tensors on the dst device.
  # The output tensor list is None on all non-dst devices.
  # This is an inefficient operation: to avoid XLA deadlocks, it performs
  # redundant gathers on non-dst devices and materializes the result.
  def gather(self, output_tensors_list: list[list[torch.Tensor]],
             input_tensor_list: list[torch.Tensor], opts: GatherOptions):
    rank = xr.global_ordinal()

    for i, input_tensor in enumerate(input_tensor_list):
      is_scalar = input_tensor.dim() == 0
      input_for_all_gather = (
          input_tensor.clone().reshape(1) if is_scalar else input_tensor)

      gathered_tensor = xm.all_gather(
          input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)

      # Syncing is required to keep the heterogeneous copying below at the
      # Python layer, avoiding deadlocks due to mismatched HLO.
      torch_xla.sync()

      if rank == opts.rootRank:
        output_tensors = output_tensors_list[i]
        if is_scalar:
          for j in range(xr.world_size()):
            output_tensors[j].copy_(gathered_tensor[j])
        else:
          chunk_size = input_tensor.shape[0]
          gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0)
          for j, chunk in enumerate(gathered_chunks):
            if chunk.shape != output_tensors[j].shape:
              chunk = chunk.reshape(output_tensors[j].shape)
            output_tensors[j].copy_(chunk)

    if rank == opts.rootRank:
      return _ret_work(output_tensors_list)
    else:
      return _ret_work([[]])

  # Called by torch.distributed.scatter. Call site example:
  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
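To make the strategy described in the gather comment concrete, here is a simplified, single-tensor sketch of the same gather-via-all_gather flow: every rank runs the identical all_gather and sync so the compiled HLO matches across processes, and only the destination rank unpacks the result on the Python side. The standalone function and its name are illustrative assumptions; the actual implementation is the ProcessGroupXla.gather method in the diff above.

# Simplified sketch (illustrative, not the actual backend code) of the
# gather-via-all_gather strategy used by ProcessGroupXla.gather above.
# Assumes a multi-process XLA launch with one device per process.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def gather_to_root(input_tensor: torch.Tensor, dst: int = 0):
  rank = xr.global_ordinal()

  # Every rank executes the same all_gather; skipping it on non-dst ranks
  # would produce mismatched HLO across processes and deadlock.
  flat = input_tensor.reshape(1) if input_tensor.dim() == 0 else input_tensor
  gathered = xm.all_gather(flat, dim=0, pin_layout=False)

  # Materialize the result before the rank-dependent Python-side unpacking.
  torch_xla.sync()

  if rank != dst:
    return None
  chunks = torch.split(gathered, flat.shape[0], dim=0)
  return [c.reshape(input_tensor.shape).clone() for c in chunks]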