
Commit a021bf0

Implement collective gather op (#9435)
1 parent e7dcc7b commit a021bf0

3 files changed: +93 −4 lines changed


test/pjrt/test_collective_ops_tpu.py

Lines changed: 53 additions & 0 deletions
@@ -389,6 +389,59 @@ def test_scatter(self):
     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [ordinal])
 
+  @staticmethod
+  def _gather(scalar: bool = False):
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+
+    # If scalar, tensors are tensor(i). Otherwise they are tensor([i]).
+    # The two cases follow different code paths and should be tested separately.
+    if scalar:
+      item = xr.global_ordinal()
+      dummy = -1.0
+    else:
+      item = [xr.global_ordinal()]
+      dummy = [-1.0]
+
+    tensor = torch.tensor(item, device=device, dtype=torch.float)
+
+    # Instantiate tensors on device 0 to receive the results
+    output_tensors = None
+    if xr.global_ordinal() == 0:
+      output_tensors = [
+          torch.tensor(dummy, device=device, dtype=torch.float)
+          for _ in range(world_size)
+      ]
+
+    dist.gather(tensor, output_tensors, dst=0)
+    if not output_tensors:
+      return None
+    else:
+      return [t.cpu() for t in output_tensors]
+
+  @parameterized.named_parameters(('scalar', True), ('tensor', False))
+  def test_gather(self, scalar):
+    # self._gather instantiates tensor i or [i], depending on the value of
+    # `scalar`, on device i. The results are gathered on device 0.
+    # All other devices get None.
+    results = pjrt.run_multiprocess(self._gather, scalar)
+    if scalar:
+      expected = [
+          torch.tensor(i, dtype=torch.float)
+          for i in range(tpu.num_expected_global_devices())
+      ]
+    else:
+      expected = [
+          torch.tensor([i], dtype=torch.float)
+          for i in range(tpu.num_expected_global_devices())
+      ]
+    for ordinal, value in results.items():
+      if ordinal == 0:
+        torch.testing.assert_close(value, expected)
+      else:
+        assert value is None
+
   @staticmethod
   def _reduce():
     dist.init_process_group("xla", init_method='xla://')
test/test_torch_distributed_xla_backend.py

Lines changed: 0 additions & 1 deletion
@@ -357,7 +357,6 @@ def test_barrier(self):
 
   @parameterized.parameters(
       'allreduce_coalesced',
-      'gather',
       'recv_anysource',
       'monitored_barrier',
   )

torch_xla/distributed/xla_backend.py

Lines changed: 40 additions & 3 deletions
@@ -6,7 +6,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions, GatherOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -280,8 +280,45 @@ def alltoall_base(self, output, input, output_split_sizes, input_split_sizes,
     output.copy_(result)
     return _ret_work(output)
 
-  def gather(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.gather. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4043
+  # Input tensors are gathered into a list of output tensors on the dst device.
+  # The output tensors list is None for all non-dst devices.
+  # This is an inefficient operation. In order to avoid XLA deadlocks it
+  # performs redundant gathers on non-dst devices and materializes the result.
+  def gather(self, output_tensors_list: list[list[torch.Tensor]],
+             input_tensor_list: list[torch.Tensor], opts: GatherOptions):
+    rank = xr.global_ordinal()
+
+    for i, input_tensor in enumerate(input_tensor_list):
+      is_scalar = input_tensor.dim() == 0
+      input_for_all_gather = (
+          input_tensor.clone().reshape(1) if is_scalar else input_tensor)
+
+      gathered_tensor = xm.all_gather(
+          input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)
+
+      # Syncing is required to keep the heterogeneous copying below at the
+      # Python layer, avoiding deadlocks due to mismatched HLO.
+      torch_xla.sync()
+
+      if rank == opts.rootRank:
+        output_tensors = output_tensors_list[i]
+        if is_scalar:
+          for j in range(xr.world_size()):
+            output_tensors[j].copy_(gathered_tensor[j])
+        else:
+          chunk_size = input_tensor.shape[0]
+          gathered_chunks = torch.split(gathered_tensor, chunk_size, dim=0)
+          for j, chunk in enumerate(gathered_chunks):
+            if chunk.shape != output_tensors[j].shape:
+              chunk = chunk.reshape(output_tensors[j].shape)
+            output_tensors[j].copy_(chunk)
+
+    if rank == opts.rootRank:
+      return _ret_work(output_tensors_list)
+    else:
+      return _ret_work([[]])
 
   # Called by torch.distributed.scatter. Call site example:
   # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4146
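
The chunk-copy step on the root rank can be illustrated without an XLA device. The following is a CPU-only sketch of the non-scalar branch, with a made-up world size and per-rank shape; torch.cat merely stands in for the device-side concatenation that xm.all_gather performs in the real op:

import torch

# Hypothetical setup: 4 ranks, each contributing a tensor of shape [2].
world_size, chunk_size = 4, 2
per_rank = [torch.tensor([float(r), float(r) + 0.5]) for r in range(world_size)]

# all_gather along dim 0 yields, on every device, a single tensor of shape
# [world_size * chunk_size]; here we simulate that result with torch.cat.
gathered = torch.cat(per_rank, dim=0)

# The root rank splits the gathered tensor back into per-rank chunks and
# copies each chunk into the corresponding pre-allocated output tensor.
output_tensors = [torch.full((chunk_size,), -1.0) for _ in range(world_size)]
for j, chunk in enumerate(torch.split(gathered, chunk_size, dim=0)):
  output_tensors[j].copy_(chunk)

# Each output tensor now holds the contribution of the matching rank.
assert all(torch.equal(o, p) for o, p in zip(output_tensors, per_rank))
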
