Skip to content

implement collective all_to_all op #9442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,35 @@ def test_all_to_all_single(self, use_dynamo):
f"Got {val}, expected {expected}")

@staticmethod
def _all_to_all():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
rank = xr.global_ordinal()

input_tensors = list(
torch.full([world_size * 2],
fill_value=rank,
dtype=torch.float,
device=device).chunk(world_size))
output_tensors = list(
torch.empty([world_size * 2], dtype=torch.float,
device=device).chunk(world_size))
dist.all_to_all(output_tensors, input_tensors)

return [t.cpu() for t in output_tensors]

def test_all_to_all(self):
# Input on device i is ([i, i], [i, i], ...). After all_to_all,
# output on every device is ([0, 0], [1, 1], ...).
results = pjrt.run_multiprocess(self._all_to_all)
expected = [
torch.tensor([i, i], dtype=torch.float)
for i in range(tpu.num_expected_global_devices())
]
for _, value in results.items():
torch.testing.assert_close(value, expected)

def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
Expand Down
1 change: 0 additions & 1 deletion test/test_torch_distributed_xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ def test_barrier(self):

@parameterized.parameters(
'allreduce_coalesced',
'alltoall',
'gather',
'recv_anysource',
'monitored_barrier',
Expand Down
22 changes: 19 additions & 3 deletions torch_xla/distributed/xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch_xla._internal import rendezvous
import logging
import os
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions


def _create_xla_process_group(prefix_store, rank, size, timeout):
Expand Down Expand Up @@ -247,8 +247,24 @@ def reduce(self, tensors: list[torch.Tensor], opts: ReduceOptions):
def allreduce_coalesced(self, *args):
raise NotImplementedError

def alltoall(self, *args):
raise NotImplementedError
# Called by torch.distributed.all_to_all. Call site example:
# https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4577
# The difference between this and all_to_all_single is that this works
# on a list of tensors while all_to_all_single works on a single tensor
# and splits/concats along dimension 0.
def alltoall(self, output_tensor_list: list[torch.Tensor],
input_tensor_list: list[torch.Tensor], opts: AllToAllOptions):
stacked_inputs = torch.stack(input_tensor_list, dim=0)
split_count = len(input_tensor_list)
stacked_results = xm.all_to_all(
stacked_inputs,
split_dimension=0,
concat_dimension=0,
split_count=split_count)
results = torch.chunk(stacked_results, split_count, dim=0)
for result, output_tensor in zip(results, output_tensor_list):
output_tensor.copy_(result.squeeze(dim=0))
return _ret_work(output_tensor_list)

# handle the nondynamo path when call torch.distributed.all_to_all_single
# call from https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996
Expand Down