From a3e7c54d5e67128178e71e64f7665d6ae719d453 Mon Sep 17 00:00:00 2001
From: bfolie
Date: Wed, 2 Jul 2025 18:51:47 +0000
Subject: [PATCH 1/2] implement all_to_all collective op

---
 test/pjrt/test_collective_ops_tpu.py       | 30 ++++++++++++++++++++++
 test/test_torch_distributed_xla_backend.py |  1 -
 torch_xla/distributed/xla_backend.py       | 22 +++++++++++++---
 3 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 7ee9e7d8a66f..7b2b1f01d8a2 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -359,6 +359,36 @@ def test_all_to_all_single(self, use_dynamo):
                          expected.sort().values),
           f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _all_to_all():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    rank = xr.global_ordinal()
+
+    input_tensors = list(
+        torch.full([world_size * 2],
+                   fill_value=rank,
+                   dtype=torch.float,
+                   device=device).chunk(world_size))
+    output_tensors = list(
+        torch.empty([world_size * 2], dtype=torch.float,
+                    device=device).chunk(world_size))
+    dist.all_to_all(output_tensors, input_tensors)
+
+    return [t.cpu() for t in output_tensors]
+
+  def test_all_to_all(self):
+    # Input on device i is ([i, i], [i, i], ...). After all_to_all,
+    # output on every device is ([0, 0], [1, 1], ...).
+    results = pjrt.run_multiprocess(self._all_to_all)
+    expected = [
+        torch.tensor([i, i], dtype=torch.float)
+        for i in range(tpu.num_expected_global_devices())
+    ]
+    for _, value in results.items():
+      torch.testing.assert_close(value, expected)
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 99b721a4fa16..7eeb1f4f17c5 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -358,7 +358,6 @@ def test_barrier(self):
   @parameterized.parameters(
       'reduce',
       'allreduce_coalesced',
-      'alltoall',
       'gather',
       'recv_anysource',
       'monitored_barrier',
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index daef50c243dc..6f782b3c9a97 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -233,8 +233,24 @@ def reduce(self, *args):
   def allreduce_coalesced(self, *args):
     raise NotImplementedError
 
-  def alltoall(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.all_to_all. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4577
+  # The difference between this and all_to_all_single is that this works
+  # on a list of tensors while all_to_all_single works on a single tensor
+  # and splits/concats along dimension 0.
+  def alltoall(self, output_tensor_list: list[torch.Tensor],
+               input_tensor_list: list[torch.Tensor], opts: AllToAllOptions):
+    stacked_inputs = torch.stack(input_tensor_list, dim=0)
+    split_count = len(input_tensor_list)
+    stacked_results = xm.all_to_all(
+        stacked_inputs,
+        split_dimension=0,
+        concat_dimension=0,
+        split_count=split_count)
+    results = torch.chunk(stacked_results, split_count, dim=0)
+    for result, output_tensor in zip(results, output_tensor_list):
+      output_tensor.copy_(result.squeeze(dim=0))
+    return _ret_work(output_tensor_list)
 
   # handle the nondynamo path when call torch.distributed.all_to_all_single
   # call from https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996

From 2fdf36ce414f2a2eb7850cb83c58efe21e3dd970 Mon Sep 17 00:00:00 2001
From: Brendan Folie
Date: Thu, 17 Jul 2025 16:02:46 -0700
Subject: [PATCH 2/2] re-annotate _scatter test as static

---
 test/pjrt/test_collective_ops_tpu.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index f01244a534cc..a9681fe5e06f 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -366,6 +366,7 @@ def test_all_to_all(self):
     for _, value in results.items():
       torch.testing.assert_close(value, expected)
 
+  @staticmethod
   def _scatter():
     dist.init_process_group("xla", init_method='xla://')
     device = torch_xla.device()
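
Usage sketch (not part of either patch): with patch 1 applied, torch.distributed.all_to_all
dispatches to the new alltoall method whenever the "xla" process group is active. The sketch
below mirrors the _all_to_all test above and is meant to be launched with one process per XLA
device (e.g. via pjrt.run_multiprocess, as in the tests). All chunks must share a shape, since
the backend stacks them into one tensor before calling xm.all_to_all. The function name
example_all_to_all is illustrative, not part of any API.

    import torch
    import torch.distributed as dist
    import torch_xla
    import torch_xla.runtime as xr

    def example_all_to_all():
      # One process per XLA device; the xla:// init method handles rendezvous.
      dist.init_process_group("xla", init_method="xla://")
      device = torch_xla.device()
      world_size = xr.world_size()
      rank = xr.global_ordinal()

      # Rank r sends inputs[j] to rank j; the fill value encodes (sender, slot)
      # so each received chunk can be traced back to its origin.
      inputs = [
          torch.full([2], float(rank * world_size + j), device=device)
          for j in range(world_size)
      ]
      outputs = [torch.empty([2], device=device) for _ in range(world_size)]

      # After the call, outputs[j] on rank r holds what rank j placed in its
      # inputs[r]: the collective transposes the (rank, chunk) grid.
      dist.all_to_all(outputs, inputs)
      return [t.cpu() for t in outputs]  # .cpu() forces the lazy graph to run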