Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,10 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_load_global_tensor_to_local_shards",
[](const at::Tensor& input, xla::OpSharding sharding, const std::vector<int64_t>& local_shape) {
ShardingUtil::XlaGlobalTensorFromLocalProcessData(input, sharding, local_shape);
});
m.def("_mark_manual_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
XLA_CHECK(IsNonDeviceDataIR(input))
Expand Down
42 changes: 42 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,48 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
return WrapXlaData(handles);
}

std::vector<torch::lazy::BackendDataPtr> CreateGlobalTensorsData(
const std::vector<at::Tensor>& tensors,
const std::vector<XLATensor::ShardingSpecPtr>& shardings,
const std::vector<std::string>& devices,
const xla::Shape local_shape) {
TORCH_LAZY_TIMED("CreateGlobalTensorsData");
XLA_CHECK_EQ(tensors.size(), shardings.size());
XLA_CHECK_EQ(tensors.size(), devices.size());

std::vector<runtime::ComputationClient::DataPtr> handles;
for (size_t i = 0; i < tensors.size(); ++i) {
torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);

std::vector<std::shared_ptr<const runtime::TensorSource>>
source_tensors; // in
std::vector<runtime::ComputationClient::DataPtr> new_handles; // out
if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
// GetLocalDevices returns the list of local devices specified by their
// global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).

std::vector<std::string> local_devices =
runtime::GetComputationClient()->GetLocalDevices();
// Shards the input tensors with padding, to split evenly.
// The execution requires consistent shard sizes, and the zero-padded
// values should be ignored.
std::vector<at::Tensor> local_shards =
ShardingUtil::ShardTensor(tensors[i], shardings[i], local_devices,
/*padded=*/true);
new_handles.push_back(ShardingUtil::CreateGlobalShardedData(
local_shards, local_devices, shardings[i], local_shape));
} else {
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
tensors[i], std::move(shape), devices[i]));
new_handles =
runtime::GetComputationClient()->TransferToDevice(source_tensors);
}
handles.insert(handles.end(), new_handles.begin(), new_handles.end());
}
return WrapXlaData(handles);
}

xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
const torch::lazy::BackendDevice* device) {
torch::lazy::BackendDevice xla_device = bridge::GetDeviceOrCurrent(device);
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/tensor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices);


std::vector<torch::lazy::BackendDataPtr> CreateGlobalTensorsData(
const std::vector<at::Tensor>& tensors,
const std::vector<XLATensor::ShardingSpecPtr>& shardings,
const std::vector<std::string>& devices,
const xla::Shape local_shape);


// Shard and transfer tensors to devices using `PjRtComputationClient`.
// The client's data transfer to device is asynchronous.
std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
Expand Down
100 changes: 100 additions & 0 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,25 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
}


runtime::ComputationClient::DataPtr ShardingUtil::CreateGlobalShardedData(
const std::vector<at::Tensor>& local_shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec,
const xla::Shape local_shape) {

std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
for (int64_t j = 0; j < devices.size(); ++j) {
auto shard_device = ParseDeviceString(devices[j]);
auto shard_shape =
CreateComputationShapeFromTensor(local_shards[j], &shard_device);
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
local_shards[j], shard_shape, devices[j]));
}
return runtime::GetComputationClient()->TransferShardsToDevice(
source_tensors, GetVirtualDevice().toString(), local_shape, sharding_spec->sharding);
}

std::vector<int64_t> ShardingUtil::GetAutoShardingMesh() {
// Auto-sharding uses mesh_shape = {n_devices, 1} if XLA_AUTO_SPMD_MESH
// is not set. XLA_AUTO_SPMD_MESH takes a form of string, "2,2" which
Expand Down Expand Up @@ -833,6 +852,87 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void ShardingUtil::XlaGlobalTensorFromLocalProcessData(const at::Tensor& input,
xla::OpSharding sharding,
const std::vector<int64_t>& local_shape) {
TORCH_LAZY_COUNTER("XlaGlobalTensorFromLocalProcessData", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
<< "Can't explicilty annotate with UNKNOWN sharding type.";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLATensor::ShardingSpecPtr new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(), static_cast<XlaDeviceType>(
xtensor->GetDevice().type())));

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
const DeviceData* device_data_node = nullptr;
if (xtensor->CurrentIrValue()) {
device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (!device_data_node) {
tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
return;
}
}

// For data, we need to deal with the data transfers between
// host and device.
at::Tensor cpu_tensor;
if (xtensor->CurrentTensorData().has_value()) {
TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
// When virtual device is enabled for SPMD, we defer the initial
// data transfer to the device and retain the original data on the
// host, until the sharded data transfer.
cpu_tensor = xtensor->CurrentTensorData().value();
} else {
// A new input tensor is not expected to be sharded. But sometimes,
// the same input is called for sharding annotation over multiple steps,
// in which case we can skip if it's the same sharding; however, if it's
// the same input with a different sharding then we block & ask the user
// to clear the existing sharding first.
XLATensor::ShardingSpecPtr current_sharding_spec = xtensor->sharding_spec();
if (current_sharding_spec) {
if (ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
*current_sharding_spec)) {
return;
}
auto type = current_sharding_spec->sharding.type();
if (type != xla::OpSharding::REPLICATED &&
type != xla::OpSharding::UNKNOWN) {
XLA_CHECK(false) << "Existing annotation must be cleared first: "
<< current_sharding_spec->sharding.DebugString();
}
}

// If the at::Tensor data is not present, we need to re-download the
// tensor from the physical device to CPU. In that case, the value
// must be present on the backend device.
XLA_CHECK((xtensor->CurrentDataHandle() &&
xtensor->CurrentDataHandle()->HasValue()) ||
device_data_node != nullptr)
<< "Cannot shard tensor. Data does not present on any device.";
std::vector<XLATensorPtr> xla_tensors{xtensor};
cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
}

xla::PrimitiveType size_type = GetShapeDimensionType(/*device=*/nullptr);
auto xla_local_shape = xla::ShapeUtil::MakeShape(size_type, local_shape);

auto xla_data = CreateGlobalTensorsData(
std::vector<at::Tensor>{cpu_tensor},
std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
std::vector<std::string>{GetVirtualDevice().toString()},
xla_local_shape)[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void ShardingUtil::SetAutoSharding() {
// This stays on throughout the program.
use_auto_sharding = true;
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,18 @@ class ShardingUtil {
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

static runtime::ComputationClient::DataPtr CreateGlobalShardedData(
const std::vector<at::Tensor>& shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec,
xla::Shape local_shape);

static void XlaMarkSharding(const at::Tensor& input,
xla::OpSharding sharding);

static void XlaGlobalTensorFromLocalProcessData(const at::Tensor& input,
xla::OpSharding sharding,
const std::vector<int64_t>& local_shape);
//////////////////////////// Auto-Sharding ////////////////////////////

// Construct a device mesh for auto-sharding pass. Returns a tuple of mesh
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,16 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
annotate_func(unwrap_sharded_tensor(t), op_sharding)
return wrap_as_sharded_tensor(t)

def create_global_tensor_from_local_process_data(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: PartitionSpec, local_shape) -> XLAShardedTensor:
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

op_sharding = mesh.get_op_sharding(partition_spec)
annotate_func = torch_xla._XLAC._load_global_tensor_to_local_shards
annotate_func(unwrap_sharded_tensor(t), op_sharding, local_shape)
return wrap_as_sharded_tensor(t)


def mark_sharding_with_gradients(
t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
Expand Down
Loading