Skip to content

Introduce multi-operand collective permute #9450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions test/test_mp_collective_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,68 @@
import torch_xla.core.xla_model as xm


def _test_single_tensor_collective_permute(device, world_size, ordinal, pairs):
  """Checks `xm.collective_permute` with a single int32 tensor operand.

  A tensor holding 100 copies of `ordinal` is permuted according to `pairs`.
  The expected values assume `pairs` routes replica `i - 1` to replica `i`,
  with the last replica wrapping around to replica 0.

  Args:
    device: Device on which the input tensor is created.
    world_size: Number of participating replicas.
    ordinal: Ordinal of the calling replica.
    pairs: Source/target replica pairs handed to `xm.collective_permute`.

  Returns:
    True if the received values match the expectation; otherwise the
    mismatch is printed to stderr and False is returned.
  """
  num_elements = 100
  sent = torch.tensor(
      [ordinal] * num_elements, dtype=torch.int32, device=device)
  result = xm.collective_permute(sent, pairs).cpu().tolist()

  # Replica 0 receives from the last replica; everyone else from ordinal - 1.
  sender = world_size - 1 if ordinal == 0 else ordinal - 1
  if result == [sender] * num_elements:
    return True

  print(f"Wrong result from core {ordinal}: {result}", file=sys.stderr)
  return False


def _test_multi_tensor_collective_permute(device, world_size, ordinal, pairs):
  """Checks `xm.collective_permute` with a list of tensor operands.

  Three tensors of differing length and dtype are permuted in one call. Each
  tensor is filled with `ordinal` plus a per-tensor offset, so a result can be
  attributed both to its sending replica and to its position in the list. The
  expected values assume `pairs` routes replica `i - 1` to replica `i`, with
  the last replica wrapping around to replica 0.

  Args:
    device: Device on which the input tensors are created.
    world_size: Number of participating replicas.
    ordinal: Ordinal of the calling replica.
    pairs: Source/target replica pairs handed to `xm.collective_permute`.

  Returns:
    True if every result matches the expectation; otherwise the first
    mismatch is printed to stderr and False is returned.
  """
  # (offset added to the ordinal, element count, dtype) for each operand.
  operand_specs = [
      (0, 50, torch.int32),
      (100, 75, torch.int32),
      (200, 25, torch.float32),
  ]
  tensors = [
      torch.tensor([ordinal + offset] * count, dtype=dtype, device=device)
      for offset, count, dtype in operand_specs
  ]

  result_list = xm.collective_permute(tensors, pairs)
  expected_ordinal = ordinal - 1 if ordinal != 0 else world_size - 1

  # Fail the test (rather than raising IndexError) on a short result list.
  if len(result_list) != len(operand_specs):
    print(
        f"Wrong number of results from core {ordinal}: "
        f"got {len(result_list)}, expected {len(operand_specs)}",
        file=sys.stderr)
    return False

  for position, (offset, count, _) in enumerate(operand_specs):
    result = result_list[position].cpu().tolist()
    expected = [expected_ordinal + offset] * count
    if result != expected:
      # Name the operand so a failure can be traced to a specific tensor.
      print(
          f"Wrong result from core {ordinal} for tensor {position}: {result}",
          file=sys.stderr)
      return False

  return True


def _mp_fn(index):
  """Per-process entry point for the collective permute test.

  On TPU and NEURON devices, builds a ring of source/target pairs and runs the
  single-tensor and multi-tensor checks, exiting with status 1 as soon as one
  of them fails. On any other device the test is skipped with a message on
  stderr.

  Args:
    index: Process index supplied by the launcher; not used here.
  """
  device = torch_xla.device()
  if xm.xla_device_hw(device) in ['TPU', 'NEURON']:
    world_size = xr.world_size()
    ordinal = xr.global_ordinal()
    # Ring permutation: replica i - 1 sends to replica i, and the last replica
    # wraps around to replica 0.
    pairs = [[i - 1, i] for i in range(1, world_size)]
    pairs.append([world_size - 1, 0])
    if not _test_single_tensor_collective_permute(device, world_size, ordinal,
                                                  pairs):
      sys.exit(1)
    if not _test_multi_tensor_collective_permute(device, world_size, ordinal,
                                                 pairs):
      sys.exit(1)
  else:
    print(
        f"Device {device} is not a supported device for this test",
        file=sys.stderr)


if __name__ == '__main__':
Expand Down
20 changes: 13 additions & 7 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,9 @@ def all_to_all(value: torch.Tensor,
return result[0]


def collective_permute(value: torch.Tensor,
pairs: List[List[int]]) -> torch.Tensor:
def collective_permute(
tensors: Union[torch.Tensor, List[torch.Tensor]],
pairs: List[List[int]]) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Performs a XLA `CollectivePermute()` operation on the input tensor.

WARNING: This function is not very reliable, may produce wrong results under
Expand All @@ -754,20 +755,25 @@ def collective_permute(value: torch.Tensor,
See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute

Args:
value (torch.Tensor): The input tensor.
tensors: Either a single `torch.Tensor` or a list of `torch.Tensor` to
perform the collective permute over.
pairs (list): A list of (source_replica_id, target_replica_id) pairs,
representing the sender and receiver for the `collective_permute()`
operation. Example: `[[0, 1], [1, 2], [2, 0]]` defines three pairs. The
tensor will be sent from replica 0 to replica 1, replica 1 to replica 2,
and replica 2 to replica 0.

Returns:
The result `torch.Tensor` of the `collective_permute()` operation.
A single or list of `torch.Tensor` results of the `collective_permute()` operation.
"""
is_single_operand = isinstance(tensors, torch.Tensor)
assert is_single_operand or (isinstance(tensors, list) and all(
isinstance(v, torch.Tensor) for v in tensors))

token, devctx = _get_all_reduce_token()
result = torch_xla._XLAC._xla_collective_permute(value, token, pairs)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]
result = torch_xla._XLAC._xla_collective_permute(tensors, token, pairs)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[0] if is_single_operand else result[:-1]


def collective_broadcast(tensors: List[torch.Tensor],
Expand Down
18 changes: 18 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,24 @@ CollectivePermuteResult BuildCollectivePermute(
return {result, token_handler.GetNewToken(result)};
}

// Builds one collective permute over several operands.
//
// Every input is passed through the TokenHandler before being handed to
// xla::MultiCollectivePermute, and the new token is derived from the first
// permuted output. The returned `results` has one entry per input, in input
// order.
MultiCollectivePermuteResult BuildCollectivePermute(
    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  // Nothing to permute: return the incoming token unchanged instead of
  // indexing into an empty result vector below.
  if (inputs.empty()) {
    return {std::vector<xla::XlaOp>{}, token};
  }

  TokenHandler token_handler(token);
  std::vector<xla::XlaOp> input_ops;
  input_ops.reserve(inputs.size());
  for (const xla::XlaOp& input : inputs) {
    const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
    input_ops.push_back(token_handler.GetInput(input, &input_shape));
  }

  xla::XlaOp collective_result =
      xla::MultiCollectivePermute(input_ops, source_target_pairs);

  std::vector<xla::XlaOp> result;
  result.reserve(inputs.size());
  // NOTE(review): GetTupleElement is only valid on a tuple-shaped op. XLA may
  // return the bare array rather than a 1-tuple for a single operand (confirm
  // against XLA's collective-permute shape inference), so use the op directly
  // when the result is not a tuple.
  if (ShapeHelper::ShapeOfXlaOp(collective_result).IsTuple()) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      result.push_back(xla::GetTupleElement(collective_result, i));
    }
  } else {
    result.push_back(collective_result);
  }
  return {result, token_handler.GetNewToken(result.front())};
}

SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
int64_t channel_id) {
xla::ChannelHandle channel_handle;
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/cross_replica_reduces.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ struct ReduceScatterResultCoalesced {
xla::XlaOp token;
};

// Result of building a collective permute over several operands.
struct MultiCollectivePermuteResult {
// Permuted outputs, one per input operand.
std::vector<xla::XlaOp> results;
// Token returned alongside the outputs for chaining the next collective.
xla::XlaOp token;
};

std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> operands,
xla::XlaOp token, double scale,
Expand Down Expand Up @@ -90,6 +95,10 @@ CollectivePermuteResult BuildCollectivePermute(
xla::XlaOp input, xla::XlaOp token,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);

// Multi-operand overload: permutes every op in `inputs` according to
// `source_target_pairs` (pairs of source and target replica ids) and returns
// the permuted ops together with a new token.
MultiCollectivePermuteResult BuildCollectivePermute(
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs);

SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
int64_t channel_id);

Expand Down
44 changes: 41 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,8 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> AllToAll(
std::tie(result, new_token) = tensor_methods::all_to_all(
bridge::GetXlaTensor(input), *token, split_dimension, concat_dimension,
split_count, replica_groups, pin_layout);
return std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>>(
bridge::AtenFromXlaTensor(std::move(result)),
std::make_shared<torch::lazy::Value>(new_token));
return {bridge::AtenFromXlaTensor(std::move(result)),
std::make_shared<torch::lazy::Value>(new_token)};
}

std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
Expand All @@ -618,6 +617,24 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
std::make_shared<torch::lazy::Value>(new_token));
}

// Multi-tensor overload of CollectivePermute used by the Python bindings.
//
// Converts `tensors` to XLA tensors (all of them must convert, per
// /*want_all=*/true), runs tensor_methods::collective_permute on them with
// `token`, and returns the permuted tensors as at::Tensor together with the
// new token.
std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
CollectivePermute(
    const std::vector<at::Tensor>& tensors,
    const std::shared_ptr<torch::lazy::Value>& token,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  std::vector<XLATensorPtr> xtensors =
      GetXlaTensors(tensors, /*want_all=*/true);
  std::vector<XLATensorPtr> results;
  torch::lazy::Value new_token;
  std::tie(results, new_token) =
      tensor_methods::collective_permute(xtensors, *token, source_target_pairs);
  std::vector<at::Tensor> aten_results;
  aten_results.reserve(results.size());
  for (XLATensorPtr& xt : results) {
    aten_results.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
  }
  // Move the vector into the pair; returning the named local inside braces
  // would copy every tensor handle.
  return {std::move(aten_results),
          std::make_shared<torch::lazy::Value>(new_token)};
}

void OptimizationBarrier_(std::vector<at::Tensor>& tensors) {
std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/false);
Expand Down Expand Up @@ -1990,6 +2007,27 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
})
.def("_xla_collective_permute",
     // List overload: takes several input tensors and returns a Python list
     // holding the permuted tensors followed by the new token as its last
     // element.
     [](const std::vector<at::Tensor>& inputs,
        const std::shared_ptr<torch::lazy::Value>& token,
        const py::list& pairs) {
       std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
           CreateSourceTargetPairs(pairs);
       std::vector<at::Tensor> results;
       std::shared_ptr<torch::lazy::Value> new_token;
       {
         // Only the op construction runs inside the NoGilSection scope; the
         // Python objects below are created after it ends.
         NoGilSection nogil;
         std::tie(results, new_token) =
             CollectivePermute(inputs, token, source_target_pairs);
       }
       auto result_list = py::list(results.size() + 1);
       // size_t index avoids the signed/unsigned comparison with size().
       for (size_t i = 0; i < results.size(); ++i) {
         result_list[i] = torch::autograd::make_variable(
             results[i], /*requires_grad=*/results[i].requires_grad());
       }
       result_list[results.size()] = new_token;
       return result_list;
     })
.def("_xla_send",
[](const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token,
Expand Down
67 changes: 60 additions & 7 deletions torch_xla/csrc/ops/collective_permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,71 @@ CollectivePermute::CollectivePermute(
/*num_outputs=*/2, torch::lazy::MHash(source_target_pairs)),
source_target_pairs_(std::move(source_target_pairs)) {}

// Multi-input constructor.
//
// The node's operands come from GetOperandListWithToken(inputs, token); the
// shape callback below treats the last operand as the token. The node has one
// output per input plus one for the token, and its shape is inferred by
// building the permute on the operand shapes and packing the outputs and the
// token into a tuple.
CollectivePermute::CollectivePermute(
    c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
    std::vector<std::pair<int64_t, int64_t>> source_target_pairs)
    : XlaNode(
          xla_collective_permute, GetOperandListWithToken(inputs, token),
          [&]() {
            // Operand shapes in node order: every input, then the token.
            std::vector<xla::Shape> operand_shapes;
            for (const torch::lazy::Value& value : inputs) {
              operand_shapes.push_back(GetXlaShape(value));
            }
            operand_shapes.push_back(GetXlaShape(token));

            auto lower_for_shape =
                [&](absl::Span<const xla::XlaOp> ops) -> xla::XlaOp {
              // Split the trailing token off from the data operands.
              std::vector<xla::XlaOp> data_ops(ops.begin(), ops.end() - 1);
              xla::XlaOp token_op = ops.back();
              MultiCollectivePermuteResult permuted = BuildCollectivePermute(
                  data_ops, token_op, source_target_pairs);
              std::vector<xla::XlaOp> tuple_elements = permuted.results;
              tuple_elements.push_back(permuted.token);
              return xla::Tuple(ops[0].builder(), tuple_elements);
            };
            return InferOutputShape(operand_shapes, lower_for_shape);
          },
          /*num_outputs=*/inputs.size() + 1,
          torch::lazy::MHash(source_target_pairs)),
      source_target_pairs_(std::move(source_target_pairs)) {}

// Recreates this node on top of `operands`, keeping the same source/target
// pairs.
//
// More than two operands are taken to be several inputs followed by the
// token and go through the multi-input constructor. Otherwise the operands
// are a single input and the token, in that order.
torch::lazy::NodePtr CollectivePermute::Clone(
    torch::lazy::OpList operands) const {
  if (operands.size() > 2) {
    // Everything except the last operand is an input; the last is the token.
    std::vector<torch::lazy::Value> inputs(operands.begin(),
                                           operands.end() - 1);
    return torch_xla::MakeNode<CollectivePermute>(inputs, operands.back(),
                                                  source_target_pairs_);
  }
  return torch_xla::MakeNode<CollectivePermute>(operands.at(0), operands.at(1),
                                                source_target_pairs_);
}

// Lowers this node to XLA ops.
//
// With more than two operands the node is lowered as a multi-operand permute:
// all operands but the last are inputs and the last is the token. The
// returned ops are the permuted outputs followed by the new token. Otherwise
// operand 0 is the single input and operand 1 the token.
XlaOpVector CollectivePermute::Lower(LoweringContext* loctx) const {
  const auto& operand_list = operands();
  const size_t operand_list_size = operand_list.size();
  if (operand_list_size > 2) {
    std::vector<xla::XlaOp> inputs;
    // The trailing token is not an input, hence size - 1.
    inputs.reserve(operand_list_size - 1);
    for (size_t i = 0; i + 1 < operand_list_size; ++i) {
      inputs.push_back(loctx->GetOutputOp(operand(i)));
    }
    xla::XlaOp token = loctx->GetOutputOp(operand_list.back());

    MultiCollectivePermuteResult result =
        BuildCollectivePermute(inputs, token, source_target_pairs_);

    std::vector<xla::XlaOp> outputs = std::move(result.results);
    outputs.push_back(result.token);
    return ReturnOps(outputs, loctx);
  }

  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp token = loctx->GetOutputOp(operand(1));
  CollectivePermuteResult result =
      BuildCollectivePermute(input, token, source_target_pairs_);
  return ReturnOps({result.result, result.token}, loctx);
}

std::string CollectivePermute::ToString() const {
Expand Down
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/collective_permute.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class CollectivePermute : public XlaNode {
const torch::lazy::Value& input, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

// Multi-input overload: permutes every value in `inputs` with one
// operation, using `token` and the given (source, target) replica pairs.
CollectivePermute(
c10::ArrayRef<torch::lazy::Value> inputs, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

std::string ToString() const override;

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
Expand All @@ -28,4 +32,4 @@ class CollectivePermute : public XlaNode {

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_
#endif // XLA_TORCH_XLA_CSRC_OPS_COLLECTIVE_PERMUTE_H_
19 changes: 19 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,25 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
torch::lazy::Value(node, 1)};
}

// Multi-tensor collective permute.
//
// Creates one CollectivePermute node over the IR values of all `inputs` plus
// `token`. Output i of the node becomes the result for inputs[i], created
// from that input via CreateFrom, and the output after the last input is
// returned as the new token.
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> collective_permute(
    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
    std::vector<std::pair<int64_t, int64_t>> source_target_pairs) {
  std::vector<torch::lazy::Value> input_values;
  input_values.reserve(inputs.size());
  for (const XLATensorPtr& input : inputs) {
    input_values.push_back(input->GetIrValue());
  }
  torch::lazy::NodePtr node = torch_xla::MakeNode<CollectivePermute>(
      input_values, token, std::move(source_target_pairs));

  std::vector<XLATensorPtr> result;
  result.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
  }
  // Move the vector into the pair; returning the named local inside braces
  // would copy every tensor pointer.
  return {std::move(result), torch::lazy::Value(node, inputs.size())};
}

std::vector<XLATensorPtr> custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
const XLATensorPtr& input, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

// Multi-tensor overload: permutes all `inputs` with one operation and
// returns the permuted tensors (one per input) together with the new token.
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> collective_permute(
const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
std::vector<std::pair<int64_t, int64_t>> source_target_pairs);

std::vector<XLATensorPtr> custom_call(
const std::vector<XLATensorPtr>& inputs, const std::string& target,
const std::vector<std::vector<int64_t>>& output_shapes,
Expand Down
Loading