diff --git a/test/test_mp_collective_permute.py b/test/test_mp_collective_permute.py
index 81a1eb771bcd..ecf3417edf7f 100644
--- a/test/test_mp_collective_permute.py
+++ b/test/test_mp_collective_permute.py
@@ -5,26 +5,68 @@
 import torch_xla.core.xla_model as xm
 
 
+def _test_single_tensor_collective_permute(device, world_size, ordinal, pairs):
+  value = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device)
+  result_tensor = xm.collective_permute(value, pairs)
+
+  result = result_tensor.cpu().tolist()
+  expected = [ordinal - 1] * 100 if ordinal != 0 else [world_size - 1] * 100
+
+  if result != expected:
+    print(f"Wrong result from core {ordinal}: {result}", file=sys.stderr)
+    return False
+  return True
+
+
+def _test_multi_tensor_collective_permute(device, world_size, ordinal, pairs):
+  tensor1 = torch.tensor([ordinal] * 50, dtype=torch.int32, device=device)
+  tensor2 = torch.tensor([ordinal + 100] * 75, dtype=torch.int32, device=device)
+  tensor3 = torch.tensor(
+      [ordinal + 200] * 25, dtype=torch.float32, device=device)
+
+  result_list = xm.collective_permute([tensor1, tensor2, tensor3], pairs)
+  expected_ordinal = ordinal - 1 if ordinal != 0 else world_size - 1
+
+  result1 = result_list[0].cpu().tolist()
+  expected1 = [expected_ordinal] * 50
+  if result1 != expected1:
+    print(f"Wrong result from core {ordinal}: {result1}", file=sys.stderr)
+    return False
+
+  result2 = result_list[1].cpu().tolist()
+  expected2 = [expected_ordinal + 100] * 75
+  if result2 != expected2:
+    print(f"Wrong result from core {ordinal}: {result2}", file=sys.stderr)
+    return False
+
+  result3 = result_list[2].cpu().tolist()
+  expected3 = [expected_ordinal + 200.0] * 25
+  if result3 != expected3:
+    print(f"Wrong result from core {ordinal}: {result3}", file=sys.stderr)
+    return False
+
+  return True
+
+
 def _mp_fn(index):
   device = torch_xla.device()
   if xm.xla_device_hw(device) in ['TPU', 'NEURON']:
     world_size = xr.world_size()
     ordinal = xr.global_ordinal()
-    value = torch.tensor([ordinal] * 100, dtype=torch.int32, device=device)
     pairs = []
     for i in range(1, world_size):
       pairs.append([i - 1, i])
     pairs.append([world_size - 1, 0])
-    result_tensor = xm.collective_permute(value, pairs)
-
-    result = result_tensor.cpu().tolist()
-    expected = [ordinal - 1] * 100 if ordinal != 0 else [world_size - 1] * 100
-
-    if result != expected:
-      print(f"Wrong result from core {ordinal}: {result}", file=sys.stderr)
+    if not _test_single_tensor_collective_permute(device, world_size, ordinal,
+                                                  pairs):
+      sys.exit(1)
+    if not _test_multi_tensor_collective_permute(device, world_size, ordinal,
+                                                 pairs):
       sys.exit(1)
   else:
-    print(f"Default device {device} is not a supported device", file=sys.stderr)
+    print(
+        f"Device {device} is not a supported device for this test",
+        file=sys.stderr)
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6b68e656d333..d972b5e50d2e 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -744,8 +744,9 @@ def all_to_all(value: torch.Tensor,
   return result[0]
 
 
-def collective_permute(value: torch.Tensor,
-                       pairs: List[List[int]]) -> torch.Tensor:
+def collective_permute(
+    tensors: Union[torch.Tensor, List[torch.Tensor]],
+    pairs: List[List[int]]) -> Union[torch.Tensor, List[torch.Tensor]]:
   """Performs a XLA `CollectivePermute()` operation on the input tensor.
 
   WARNING: This function is not very reliable, may produce wrong results under
@@ -754,7 +755,8 @@ def collective_permute(value: torch.Tensor,
   See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute
 
   Args:
-    value (torch.Tensor): The input tensor.
+    tensors: Either a single `torch.Tensor` or a list of `torch.Tensor` to
+      perform the collective permute over.
     pairs (list): A list of (source_replica_id, target_replica_id) pairs,
       representing the sender and receiver for the `collective_permute()`
       operation. Example: `[[0, 1], [1, 2], [2, 0]]` defines three pairs. The
@@ -762,12 +764,16 @@ def collective_permute(value: torch.Tensor,
       and replica 2 to replica 0.
 
   Returns:
-    The result `torch.Tensor` of the `collective_permute()` operation.
+    A single or list of `torch.Tensor` results of the `collective_permute()` operation.
   """
+  is_single_operand = isinstance(tensors, torch.Tensor)
+  assert is_single_operand or (isinstance(tensors, list) and all(
+      isinstance(v, torch.Tensor) for v in tensors))
+
   token, devctx = _get_all_reduce_token()
-  result = torch_xla._XLAC._xla_collective_permute(value, token, pairs)
-  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
-  return result[0]
+  result = torch_xla._XLAC._xla_collective_permute(tensors, token, pairs)
+  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+  return result[0] if is_single_operand else result[:-1]
 
 
 def collective_broadcast(tensors: List[torch.Tensor],
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 56eeac4e6a41..179752c4c8e9 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -378,6 +378,24 @@ CollectivePermuteResult BuildCollectivePermute(
   return {result, token_handler.GetNewToken(result)};
 }
 
+MultiCollectivePermuteResult BuildCollectivePermute(
+    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
+  TokenHandler token_handler(token);
+  std::vector<xla::XlaOp> result(inputs.size());
+  std::vector<xla::XlaOp> input_ops;
+  for (const auto& input : inputs) {
+    const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+    input_ops.push_back(token_handler.GetInput(input, &input_shape));
+  }
+  xla::XlaOp collective_result =
+      xla::MultiCollectivePermute(input_ops, source_target_pairs);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    result[i] = xla::GetTupleElement(collective_result, i);
+  }
+  return {result, token_handler.GetNewToken(result[0])};
+}
+
 SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
                               int64_t channel_id) {
   xla::ChannelHandle channel_handle;
diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h
index e6877a19aa72..a9abedfb2bd9 100644
--- a/torch_xla/csrc/cross_replica_reduces.h
+++ b/torch_xla/csrc/cross_replica_reduces.h
@@ -60,6 +60,11 @@ struct ReduceScatterResultCoalesced {
   xla::XlaOp token;
 };
 
+struct MultiCollectivePermuteResult {
+  std::vector<xla::XlaOp> results;
+  xla::XlaOp token;
+};
+
 std::vector<xla::XlaOp> BuildAllReduce(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
     xla::XlaOp token, double scale,
@@ -90,6 +95,10 @@ CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
 
+MultiCollectivePermuteResult BuildCollectivePermute(
+    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);
+
 SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
                               int64_t channel_id);
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index c52ca3c74911..5f19b374e2e7 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -601,9 +601,8 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> AllToAll(
   std::tie(result, new_token) = tensor_methods::all_to_all(
       bridge::GetXlaTensor(input), *token, split_dimension, concat_dimension,
       split_count, replica_groups, pin_layout);
-  return std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>>(
-      bridge::AtenFromXlaTensor(std::move(result)),
-      std::make_shared<torch::lazy::Value>(new_token));
+  return {bridge::AtenFromXlaTensor(std::move(result)),
+          std::make_shared<torch::lazy::Value>(new_token)};
 }
 
 std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
@@ -618,6 +617,24 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
       std::make_shared<torch::lazy::Value>(new_token));
 }
 
+std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
+CollectivePermute(
+    const std::vector<at::Tensor>& tensors,
+    const std::shared_ptr<torch::lazy::Value>& token,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
+  std::vector<XLATensorPtr> xtensors =
+      GetXlaTensors(tensors, /*want_all=*/true);
+  std::vector<XLATensorPtr> results;
+  torch::lazy::Value new_token;
+  std::tie(results, new_token) =
+      tensor_methods::collective_permute(xtensors, *token, source_target_pairs);
+  std::vector<at::Tensor> aten_results;
+  for (auto& xt : results) {
+    aten_results.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
+  }
+  return {aten_results, std::make_shared<torch::lazy::Value>(new_token)};
+}
+
 void OptimizationBarrier_(std::vector<at::Tensor>& tensors) {
   std::vector<XLATensorPtr> xtensors =
       GetXlaTensors(tensors, /*want_all=*/false);
@@ -1990,6 +2007,27 @@ void InitXlaModuleBindings(py::module m) {
              result_tuple[1] = new_token;
             return result_tuple;
           })
+      .def("_xla_collective_permute",
+           [](const std::vector<at::Tensor>& inputs,
+              const std::shared_ptr<torch::lazy::Value>& token,
+              const py::list& pairs) {
+             std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
+                 CreateSourceTargetPairs(pairs);
+             std::vector<at::Tensor> results;
+             std::shared_ptr<torch::lazy::Value> new_token;
+             {
+               NoGilSection nogil;
+               std::tie(results, new_token) =
+                   CollectivePermute(inputs, token, source_target_pairs);
+             }
+             auto result_list = py::list(results.size() + 1);
+             for (int i = 0; i < results.size(); ++i) {
+               result_list[i] = torch::autograd::make_variable(
+                   results[i], /*requires_grad=*/results[i].requires_grad());
+             }
+             result_list[results.size()] = new_token;
+             return result_list;
+           })
       .def("_xla_send",
            [](const at::Tensor& input,
              const std::shared_ptr<torch::lazy::Value>& token,
diff --git a/torch_xla/csrc/ops/collective_permute.cpp b/torch_xla/csrc/ops/collective_permute.cpp
index 048355012e0f..847f137fe06c 100644
--- a/torch_xla/csrc/ops/collective_permute.cpp
+++ b/torch_xla/csrc/ops/collective_permute.cpp
@@ -31,18 +31,71 @@ CollectivePermute::CollectivePermute(
               /*num_outputs=*/2, torch::lazy::MHash(source_target_pairs)),
       source_target_pairs_(std::move(source_target_pairs)) {}
 
+CollectivePermute::CollectivePermute(
+    c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
+    std::vector<std::pair<int64_t, int64_t>> source_target_pairs)
+    : XlaNode(
+          xla_collective_permute, GetOperandListWithToken(inputs, token),
+          [&]() {
+            std::vector<xla::Shape> input_shapes;
+            for (const auto& input : inputs) {
+              input_shapes.push_back(GetXlaShape(input));
+            }
+            input_shapes.push_back(GetXlaShape(token));
+            auto shape_fn =
+                [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+              std::vector<xla::XlaOp> input_ops(operands.begin(),
+                                                operands.end() - 1);
+              xla::XlaOp token_op = operands.back();
+              MultiCollectivePermuteResult result = BuildCollectivePermute(
+                  input_ops, token_op, source_target_pairs);
+              std::vector<xla::XlaOp> outputs = result.results;
+              outputs.push_back(result.token);
+              return xla::Tuple(operands[0].builder(), outputs);
+            };
+            return InferOutputShape(input_shapes, shape_fn);
+          },
+          /*num_outputs=*/inputs.size() + 1,
+          torch::lazy::MHash(source_target_pairs)),
+      source_target_pairs_(std::move(source_target_pairs)) {}
+
 torch::lazy::NodePtr CollectivePermute::Clone(
     torch::lazy::OpList operands) const {
-  return torch_xla::MakeNode<CollectivePermute>(operands.at(0), operands.at(1),
-                                                source_target_pairs_);
+  if (operands.size() > 2) {
+    std::vector<torch::lazy::Value> inputs(operands.begin(),
+                                           operands.end() - 1);
+    return torch_xla::MakeNode<CollectivePermute>(inputs, operands.back(),
+                                                  source_target_pairs_);
+  } else {
+    return torch_xla::MakeNode<CollectivePermute>(
+        operands.at(0), operands.at(1), source_target_pairs_);
+  }
 }
 
 XlaOpVector CollectivePermute::Lower(LoweringContext* loctx) const {
-  xla::XlaOp input = loctx->GetOutputOp(operand(0));
-  xla::XlaOp token = loctx->GetOutputOp(operand(1));
-  CollectivePermuteResult result =
-      BuildCollectivePermute(input, token, source_target_pairs_);
-  return ReturnOps({result.result, result.token}, loctx);
+  auto& operand_list = operands();
+  size_t operand_list_size = operand_list.size();
+  if (operand_list_size > 2) {
+    std::vector<xla::XlaOp> inputs;
+    inputs.reserve(operand_list_size);
+    for (size_t i = 0; i < operand_list_size - 1; ++i) {
+      inputs.push_back(loctx->GetOutputOp(operand(i)));
+    }
+    xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
+
+    MultiCollectivePermuteResult result =
+        BuildCollectivePermute(inputs, token, source_target_pairs_);
+
+    std::vector<xla::XlaOp> outputs = result.results;
+    outputs.push_back(result.token);
+    return ReturnOps(outputs, loctx);
+  } else {
+    xla::XlaOp input = loctx->GetOutputOp(operand(0));
+    xla::XlaOp token = loctx->GetOutputOp(operand(1));
+    CollectivePermuteResult result =
+        BuildCollectivePermute(input, token, source_target_pairs_);
+    return ReturnOps({result.result, result.token}, loctx);
+  }
 }
 
 std::string CollectivePermute::ToString() const {
diff --git a/torch_xla/csrc/ops/collective_permute.h b/torch_xla/csrc/ops/collective_permute.h
index 3a9fa83288ad..028eb39ab609 100644
--- a/torch_xla/csrc/ops/collective_permute.h
+++ b/torch_xla/csrc/ops/collective_permute.h
@@ -12,6 +12,10 @@ class CollectivePermute : public XlaNode {
       const torch::lazy::Value& input, const torch::lazy::Value& token,
       std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
 
+  CollectivePermute(
+      c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
+      std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
+
   std::string ToString() const override;
 
   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
@@ -28,4 +32,4 @@ class CollectivePermute : public XlaNode {
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_
\ No newline at end of file
+#endif  // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 80d799076048..826223e48ffc 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -564,6 +564,25 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
           torch::lazy::Value(node, 1)};
 }
 
+std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> collective_permute(
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    std::vector<std::pair<int64_t, int64_t>> source_target_pairs) {
+  std::vector<torch::lazy::Value> input_values;
+  input_values.reserve(inputs.size());
+  for (const auto& input : inputs) {
+    input_values.push_back(input->GetIrValue());
+  }
+  torch::lazy::NodePtr node = torch_xla::MakeNode<CollectivePermute>(
+      input_values, token, std::move(source_target_pairs));
+
+  std::vector<XLATensorPtr> result;
+  result.reserve(inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
+  }
+  return {result, torch::lazy::Value(node, inputs.size())};
+}
+
 std::vector<XLATensorPtr> custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 79f6acd8049d..63dfc9d35fe7 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -90,6 +90,10 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
 
+std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> collective_permute(
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    std::vector<std::pair<int64_t, int64_t>> source_target_pairs);
+
 std::vector<XLATensorPtr> custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& target,
     const std::vector<std::vector<int64_t>>& output_shapes,
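
Note (not part of the patch): below is a minimal usage sketch of the API this diff enables, mirroring the updated test. A single tensor input keeps the old single-tensor return, while a list input now returns a list of permuted tensors in the same order. It assumes a TPU or NEURON multiprocess run; the _demo_fn name is illustrative, and torch_xla.launch is assumed as the per-device launcher.

import sys

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm


def _demo_fn(index):
  device = torch_xla.device()
  if xm.xla_device_hw(device) not in ['TPU', 'NEURON']:
    print(f"Device {device} is not supported for this demo", file=sys.stderr)
    return
  world_size = xr.world_size()
  ordinal = xr.global_ordinal()

  # Ring permutation: replica i receives from replica i - 1, and replica 0
  # wraps around and receives from the last replica.
  pairs = [[i - 1, i] for i in range(1, world_size)] + [[world_size - 1, 0]]

  # Single-tensor form: unchanged behavior, returns one permuted tensor.
  single = xm.collective_permute(
      torch.tensor([ordinal] * 4, dtype=torch.int32, device=device), pairs)

  # New list form: all tensors are permuted in one op; results keep the
  # input order and dtypes.
  a = torch.tensor([ordinal] * 4, dtype=torch.int32, device=device)
  b = torch.tensor([ordinal + 100] * 4, dtype=torch.float32, device=device)
  permuted = xm.collective_permute([a, b], pairs)

  # Moving results to CPU forces execution of the lazy graph.
  print(ordinal, single.cpu().tolist(), [t.cpu().tolist() for t in permuted])


if __name__ == '__main__':
  # Launcher assumed here; any per-device multiprocess entry point works.
  torch_xla.launch(_demo_fn, args=())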