From 2d2f08ab95452fc0e36e12339db163652aed5050 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Sun, 9 Mar 2025 08:01:01 +0000
Subject: [PATCH 01/16] update mesh/xla_sharding python api for local spmd

---
 torch_xla/distributed/spmd/xla_sharding.py | 18 ++++++++++++++----
 torch_xla/runtime.py                       |  2 +-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index a1cd9540fd1c..de2714ad2492 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -68,7 +68,7 @@ def __init__(self,
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    assert all(d < self.size() for d in device_ids)
+    # assert all(d < self.size() for d in device_ids)

   def size(self):
     return np.prod(self.mesh_shape)
@@ -127,6 +127,10 @@ def get_op_sharding(self,
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
+    print(f"check tile_assignment: {tile_assignment}")
+    print(f"check group_assignment: {group_assignment}")
+    print(f"check replication_groups: {replication_groups}")
+    print(f"check sharding_type: {sharding_type}")
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)

@@ -377,6 +381,11 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
   return sharding_type


+def _normalize_logical_mesh(device_mesh: np.ndarray) -> np.ndarray:
+  device_id_min = np.min(device_mesh)
+  return device_mesh.copy() - device_id_min
+
+
 def _get_tile_assignment(
     mesh: Mesh,
     partition_spec: Tuple[Union[Tuple[int], int, None]]) -> np.ndarray:
@@ -393,8 +402,8 @@ def _get_tile_assignment(
   tiled_dims = [x for x in partition_spec if x is not None]
   permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
   missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
-  tile_assignment = mesh.get_logical_mesh().transpose(permutation +
-                                                      missing_axes)
+  tile_assignment = _normalize_logical_mesh(
+      mesh.get_logical_mesh()).transpose(permutation + missing_axes)

   # For any tuples in the partition_spec, the grouped axes will be adjacent
   # after the permutation. Combine these dimensions into a single axis.
@@ -548,8 +557,9 @@ def mark_sharding(
     >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
   num_devices = xr.global_runtime_device_count()
+  num_local_devices = xr.addressable_runtime_device_count()
   assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
+  assert mesh.size() == num_devices or mesh.size() == num_local_devices, \
     f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims.
Fully specified `partiion_spec` diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py index 1946ae05a52b..a17b1c57e3b6 100644 --- a/torch_xla/runtime.py +++ b/torch_xla/runtime.py @@ -212,7 +212,7 @@ def global_runtime_device_attributes() -> List[Dict[str, object]]: @functools.lru_cache() def global_runtime_device_count() -> int: """Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD.""" - return len(torch_xla._XLAC._xla_get_all_runtime_devices()) + return torch_xla._XLAC._xla_num_global_devices() def addressable_runtime_device_count() -> int: From c4aa8542eb00fda59a9281043d5ec2dce7517820 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sun, 9 Mar 2025 08:02:57 +0000 Subject: [PATCH 02/16] make local spmd working --- torch_xla/csrc/init_python_bindings.cpp | 7 ++- torch_xla/csrc/lowering_context.cpp | 42 +++++++++++++++++ torch_xla/csrc/lowering_context.h | 10 ++++ torch_xla/csrc/runtime/computation_client.h | 7 ++- .../csrc/runtime/ifrt_computation_client.cc | 6 ++- .../csrc/runtime/ifrt_computation_client.h | 4 +- .../csrc/runtime/pjrt_computation_client.cc | 47 ++++++++++++++++--- .../csrc/runtime/pjrt_computation_client.h | 4 +- torch_xla/csrc/tensor_impl.cpp | 2 +- torch_xla/csrc/xla_graph_executor.cpp | 20 ++++++-- torch_xla/csrc/xla_sharding_util.cpp | 34 ++++++++++++-- 11 files changed, 161 insertions(+), 22 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 98012ea2d359..020979a48802 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1482,7 +1482,7 @@ void InitXlaModuleBindings(py::module m) { if (UseVirtualDevice()) { return 1; } else { - return runtime::GetComputationClient()->GetNumDevices(); + return runtime::GetComputationClient()->GetNumLocalDevices(); } }); m.def("_xla_get_all_devices", []() { @@ -1500,13 +1500,16 @@ void InitXlaModuleBindings(py::module m) { m.def("_xla_get_runtime_devices", []() { return runtime::GetComputationClient()->GetLocalDevices(); }); m.def("_xla_num_runtime_devices", []() -> int64_t { - return runtime::GetComputationClient()->GetNumDevices(); + return runtime::GetComputationClient()->GetNumLocalDevices(); }); m.def("_xla_get_all_runtime_devices", []() { std::vector all_devices = runtime::GetComputationClient()->GetAllDevices(); return all_devices; }); + m.def("_xla_num_global_devices", []() -> int64_t { + return runtime::GetComputationClient()->GetNumGlobalDevices(); + }); m.def( "_xla_real_devices", [](const std::optional> devices) { diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 6c2906dc7247..a004be88c540 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -93,6 +93,7 @@ LoweringContext::LoweringContext(const std::string& name, torch::lazy::BackendDevice device) : torch::lazy::LoweringContext(name, device), builder_(name), + num_computation_partitions_(1), stack_frame_index_builder_(std::make_shared()) {} LoweringContext::LoweringContext( @@ -101,6 +102,7 @@ LoweringContext::LoweringContext( torch::lazy::Util::EmissionMap emit_status) : torch::lazy::LoweringContext(name, device, {}, emit_status), builder_(name), + num_computation_partitions_(1), stack_frame_index_builder_(std::make_shared()) { for (auto node : post_order) { LowerNode(node); @@ -131,6 +133,7 @@ xla::XlaOp LoweringContext::GetParameter( xla::OpSharding sharding = data->GetSharding(); xla::XlaScopedShardingAssignment 
scoped_sharding(builder(), sharding); param = xla::Parameter(builder(), param_index, shape, param_name); + UpdateNumPartitions(param); } else { param = xla::Parameter(builder(), param_index, shape, param_name); } @@ -254,6 +257,28 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { mutable_dims->Set(dim, kUnboundedSize); } } + std::for_each(result_ops.begin(), result_ops.end(), + [this](xla::XlaOp xla_op) { + UpdateNumPartitions(xla_op); // Calling the member function + }); + // for (auto xla_op : result_ops) { + // UpdateNumPartitions(xla_op); + // // std::optional op_sharding = + // // ConsumeValue(builder()->GetOpSharding(xla_op)); + // // if (op_sharding.has_value()) { + // // size_t curr_num_partitions = + // // op_sharding.value().tile_assignment_devices().size(); + // // if (num_computation_partitions_ != 1) { + // // XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_) + // << + // // "Number of partitions must be the same for all ops in a HLO + // graph."; + // // continue; + // // } + // // num_computation_partitions_ = + // op_sharding.value().tile_assignment_devices().size(); + // // } + // } } catch (const std::exception& ex) { ReportBuilderError(node, ex.what()); } @@ -324,4 +349,21 @@ torch::lazy::ComputationPtr LoweringContext::Build() { builder_.name(), std::move(xla_computation), device_); } +void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) { + std::optional op_sharding = + ConsumeValue(builder()->GetOpSharding(op)); + if (op_sharding.has_value()) { + size_t curr_num_partitions = + op_sharding.value().tile_assignment_devices().size(); + if (num_computation_partitions_ != 1) { + XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_) + << "Number of partitions must be the same for all ops in a HLO " + "graph."; + return; + } + std::cout << "curr_num_partitions: " << curr_num_partitions << std::endl; + num_computation_partitions_ = curr_num_partitions; + } +} + } // namespace torch_xla diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index cb4f0bc2d2fa..fdaabb2b14da 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -113,10 +113,18 @@ class LoweringContext : public torch::lazy::LoweringContext { return emitted_outputs_; } + size_t GetComputationNumPartitions() const { + return num_computation_partitions_; + } + // Return stack frame id int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source, int64_t parent_id); + protected: + // Update the number of partitions from a XlaOp. + void UpdateNumPartitions(const xla::XlaOp& op); + private: struct Parameter { xla::XlaOp param; @@ -133,6 +141,8 @@ class LoweringContext : public torch::lazy::LoweringContext { std::vector root_tuple_; OutputMap emitted_outputs_; std::string name_; + // Number of partitions of the lowered XLA computation. 
+ size_t num_computation_partitions_; std::shared_ptr stack_frame_index_builder_; }; // namespace torch_xla diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index b192d8d2e149..339d2a4f52c6 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -225,6 +225,7 @@ class ComputationClient { xla::XlaComputation computation, std::string compilation_device, std::vector devices, const xla::Shape* output_shape, bool parameter_is_tupled_arguments = false, bool is_sharded = false, + size_t computation_num_partitions = 1, bool allow_spmd_sharding_propagation_to_output = true, bool use_auto_spmd_partitioning = false, std::vector auto_spmd_mesh_shape = {}, @@ -235,6 +236,7 @@ class ComputationClient { output_shape(output_shape), parameter_is_tupled_arguments(parameter_is_tupled_arguments), is_sharded(is_sharded), + computation_num_partitions(computation_num_partitions), allow_spmd_sharding_propagation_to_output( allow_spmd_sharding_propagation_to_output), use_auto_spmd_partitioning(use_auto_spmd_partitioning), @@ -248,6 +250,7 @@ class ComputationClient { const xla::Shape* output_shape = nullptr; bool parameter_is_tupled_arguments; bool is_sharded; + size_t computation_num_partitions; bool allow_spmd_sharding_propagation_to_output; bool use_auto_spmd_partitioning; std::vector auto_spmd_mesh_shape; @@ -374,7 +377,9 @@ class ComputationClient { virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0; - virtual size_t GetNumDevices() const = 0; + virtual size_t GetNumLocalDevices() const = 0; + + virtual size_t GetNumGlobalDevices() const = 0; virtual std::vector GetLocalDevices() const = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index a197aec460e4..11aaa1a0b8d2 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated( return data_handles; } -size_t IfrtComputationClient::GetNumDevices() const { +size_t IfrtComputationClient::GetNumLocalDevices() const { return client_->addressable_device_count(); } +size_t IfrtComputationClient::GetNumGlobalDevices() const { + return client_->device_count(); +} + std::string IfrtComputationClient::GetDefaultDevice() const { return IfrtDeviceToString(client_->addressable_devices()[0]); } diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 73b8e21c9f06..26135f65ab55 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -79,7 +79,9 @@ class IfrtComputationClient : public ComputationClient { absl::Span devices, const ExecuteReplicatedOptions& options) override; - size_t GetNumDevices() const override; + size_t GetNumLocalDevices() const override; + + size_t GetNumGlobalDevices() const override; std::string GetDefaultDevice() const override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 749419f66cd4..6bf6217c0366 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -334,6 +334,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::shared_ptr PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { + std::cout << 
"PjRtComputationClient::ReplicateShardedData" << std::endl; if (auto unsharded_data = std::dynamic_pointer_cast(handle)) { return unsharded_data; } else if (auto sharded_data = @@ -347,7 +348,9 @@ PjRtComputationClient::ReplicateShardedData( } xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = sharded_data->shape(); - builder.SetSharding(sharded_data->GetSharding()); + xla::OpSharding sharding = sharded_data->GetSharding(); + builder.SetSharding(sharding); + size_t num_partitions = sharding.tile_assignment_devices().size(); // perform a simple identity calculation to reassemble the input as // replicated output. @@ -371,6 +374,7 @@ PjRtComputationClient::ReplicateShardedData( GetCompilationDevices(device, {}), &shape, /*should_wrap_parameter=*/false, /*is_sharded=*/true, + /*computation_num_partitions*/ num_partitions, /*allow_spmd_sharding_propagation_to_output=*/false}); std::vector< std::shared_ptr> @@ -537,6 +541,7 @@ std::vector PjRtComputationClient::TransferFromDevice( std::vector PjRtComputationClient::Compile( std::vector instances) { + std::cout << "in compile" << std::endl; auto metrics_fn = CompileMetric; if (instances[0].eager_mode) { metrics_fn = EagerCompileMetric; @@ -546,7 +551,9 @@ std::vector PjRtComputationClient::Compile( tsl::profiler::TraceMeLevel::kInfo); std::vector computations; + std::cout << "instances.size(): " << instances.size() << std::endl; for (auto& instance : instances) { + std::cout << "instance devices " << instance.devices << std::endl; xla::CompileOptions compile_options; if (instance.is_sharded) { // TODO(yeounoh) multi-host, multi-slice configurations @@ -560,6 +567,9 @@ std::vector PjRtComputationClient::Compile( {instance.allow_spmd_sharding_propagation_to_output}); int num_partitions = client_->device_count(); + // num_partitions = 4; + num_partitions = static_cast(instance.computation_num_partitions); + std::cout << "num_partitions: " << num_partitions << std::endl; compile_options.executable_build_options.set_num_partitions( num_partitions); compile_options.executable_build_options.set_num_replicas(1); @@ -589,11 +599,20 @@ std::vector PjRtComputationClient::Compile( } // TODO(244391366) verify this is correct for the collectives ops - xla::DeviceAssignment device_assignment(1, client_->device_count()); + // xla::DeviceAssignment device_assignment(1, client_->device_count()); + xla::DeviceAssignment device_assignment(1, num_partitions); + std::cout << "check client_->device_count(): " << client_->device_count() + << std::endl; // DeviceAssignment values must be the PjRtDevice ID, so we need to // unwind the global ordinal mapping. 
- for (const auto& [device_id, global_ordinal] : global_ordinals_) { - device_assignment(0, global_ordinal) = device_id; + // for (const auto& [device_id, global_ordinal] : global_ordinals_) { + // std::cout << "device_id: " << device_id + // << ", global_ordinal: " << global_ordinal << std::endl; + // device_assignment(0, global_ordinal) = device_id; + // } + auto local_pjrt_devices = client_->addressable_devices(); + for (int i = 0; i < local_pjrt_devices.size(); ++i) { + device_assignment(0, i) = local_pjrt_devices[i]->id(); } compile_options.executable_build_options.set_device_assignment( device_assignment); @@ -649,7 +668,7 @@ std::vector PjRtComputationClient::Compile( CreateCompileHandlesCounter()->AddValue(1); } - + std::cout << "finish compile" << std::endl; return computations; } @@ -701,6 +720,7 @@ PjRtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { + std::cout << "in execute" << std::endl; // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteComputation` and the async work in `ExecuteSharded` are // complete; a copy is held from the lambda that releases it when done. @@ -768,6 +788,7 @@ PjRtComputationClient::ExecuteComputation( CreateDataHandlesCounter()->AddValue(datas.size()); TF_VLOG(1) << "Returning " << datas.size() << " results"; + std::cout << "finish execute" << std::endl; return datas; } @@ -777,6 +798,10 @@ PjRtComputationClient::ExecuteReplicated( absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) { + std::cout << "in execute replicated" << std::endl; + for (auto d : devices) { + std::cout << "device: " << d << std::endl; + } // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteReplicated` and the async work in `Execute` are // complete; a copy is held from the lambda that releases it when done. @@ -914,13 +939,18 @@ PjRtComputationClient::ExecuteReplicated( } TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; + std::cout << "finish execute replicated" << std::endl; return data_handles; } -size_t PjRtComputationClient::GetNumDevices() const { +size_t PjRtComputationClient::GetNumLocalDevices() const { return client_->addressable_device_count(); } +size_t PjRtComputationClient::GetNumGlobalDevices() const { + return client_->device_count(); +} + std::string PjRtComputationClient::GetDefaultDevice() const { return PjRtDeviceToString(client_->addressable_devices()[0]); } @@ -972,12 +1002,17 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice( void PjRtComputationClient::WaitDeviceOps( absl::Span devices) { + std::cout << "in wait device ops" << std::endl; + for (auto d : devices) { + std::cout << "device: " << d << std::endl; + } TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", "); operation_manager_.WaitForDevices( devices.empty() ? (UseVirtualDevice() ? 
std::vector({spmd_device_str}) : GetLocalDevices()) : devices); + std::cout << "finish wait device ops" << std::endl; } std::map PjRtComputationClient::GetMetrics() const { diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 9791f32381b6..090ff952fdf2 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -86,7 +86,9 @@ class PjRtComputationClient : public ComputationClient { absl::Span devices, const ExecuteReplicatedOptions& options) override; - size_t GetNumDevices() const override; + size_t GetNumLocalDevices() const override; + + size_t GetNumGlobalDevices() const override; std::string GetDefaultDevice() const override; diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 4e69127ff816..fcf793ff5bc7 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface { return 0; } - return client->GetNumDevices(); + return client->GetNumLocalDevices(); } }; diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 0b8c5489798c..514266518dc9 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1391,12 +1391,16 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( // Always execute sharded when running in SPMD mode bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice(); // Annotate HLO sharding selectively in the compuation. - ShardingUtil::SetHloSharding(&lowering_ctx); + bool is_sharded_2 = ShardingUtil::SetHloSharding(&lowering_ctx); + + std::cout << "is_sharded_2: " << is_sharded_2 << std::endl; SetBufferDonors(&lowering_ctx, buffer_donor_indices); xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); + size_t computation_num_partitions = + lowering_ctx.GetComputationNumPartitions(); // TODO(yeounoh) enable wrapping with auto-sharding. 
bool should_wrap_parameter = @@ -1422,11 +1426,15 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( program_shape.result(), static_cast(coll.device.type())); std::vector instances; - instances.push_back({std::move(computation), coll.device.toString(), - runtime::GetComputationClient()->GetCompilationDevices( - coll.device.toString(), devices), - &shape, should_wrap_parameter, is_sharded}); + std::cout << "computation_num_partitions: " << computation_num_partitions + << std::endl; + instances.emplace_back(std::move(computation), coll.device.toString(), + runtime::GetComputationClient()->GetCompilationDevices( + coll.device.toString(), devices), + &shape, should_wrap_parameter, is_sharded, + computation_num_partitions); instances.front().eager_mode = UseEagerMode(); + instances.front().computation_num_partitions = computation_num_partitions; if (use_autosharding) { TF_VLOG(5) << "use_auto_spmd_partitioning is set."; TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode."; @@ -1455,6 +1463,8 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) << " on device " << coll.device << " ..."; + std::cout << "check instance num partitions" + << instances.front().computation_num_partitions << std::endl; std::vector> computations = runtime::GetComputationClient()->Compile(std::move(instances)); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index d58144d6844a..b2938f81dbe6 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -85,10 +85,11 @@ std::vector TileAssignmentDimensions( // order of the output corresponds to the order of the `devices`, which can be // arbitrarily set by the caller. 
std::unordered_map build_index_map( - const std::vector& devices) { + const std::vector& devices, size_t num_mesh_devices) { std::unordered_map device_index; for (int i = 0; i < devices.size(); ++i) { - int global_ordinal = ParseDeviceString(devices[i]).ordinal(); + int global_ordinal = + ParseDeviceString(devices[i]).ordinal() % num_mesh_devices; device_index[global_ordinal] = i; } return device_index; @@ -191,6 +192,9 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { XlaBuilderFriend::GetInstruction(elem.second); const std::shared_ptr sharding = xla_node->GetSharding(elem.first.index); + if (sharding != nullptr) { + std::cout << "check opsharding " << sharding->DebugString() << std::endl; + } if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) { *instruction->mutable_sharding() = *sharding; is_sharded = true; @@ -371,10 +375,25 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( shard_indices[i] = std::make_pair(global_ordinal, indices); } } else if (sharding.type() == xla::OpSharding::OTHER) { - auto device_index = build_index_map(devices); std::vector tile_assignment_devices( sharding.tile_assignment_devices().begin(), sharding.tile_assignment_devices().end()); + size_t num_local_devices = + runtime::GetComputationClient()->GetNumLocalDevices(); + size_t num_global_devices = + runtime::GetComputationClient()->GetNumGlobalDevices(); + XLA_CHECK(tile_assignment_devices.size() == num_global_devices || + tile_assignment_devices.size() == num_local_devices) + << "Number of tile_assignment_devices must be the number of global " + "devices or local devices"; + std::cout << "Num local devices " << num_local_devices << std::endl; + std::unordered_map device_index = + build_index_map(devices, tile_assignment_devices.size()); + std::cout << "Check device_index " << std::endl; + for (const auto& pair : device_index) { + std::cout << "Key: " << pair.first << ", Value: " << pair.second + << std::endl; + } if (!sharding.iota_reshape_dims().empty()) { auto tileAssignment = xla::TileAssignment( sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(), @@ -384,7 +403,10 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( } for (size_t i = 0; i < tile_assignment_devices.size(); i++) { int64_t core = tile_assignment_devices[i]; + std::cout << "Check core " << core << std::endl; if (device_index.find(core) == device_index.end()) { + std::cout << "current core " << core << " is not in device_index" + << std::endl; // Skip any shards whose device is not part of the `devices` list. continue; } @@ -434,6 +456,8 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector& devices, bool padded) { + std::cout << "ShardingUtil::ShardTensor check devices " << devices + << std::endl; xla::OpSharding sharding; bool minibatch = false; if (shardings != nullptr) { @@ -442,7 +466,7 @@ std::vector ShardingUtil::ShardTensor( } TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type() << ")... 
and minibatch = " << minibatch << std::endl; - auto device_index = build_index_map(devices); + // auto device_index = build_index_map(devices); std::vector shards(devices.size()); if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED || sharding.type() == xla::OpSharding::UNKNOWN) { @@ -464,6 +488,8 @@ std::vector ShardingUtil::ShardTensor( std::back_inserter(shard_indices), [](auto& pair) { return pair.second; }); } + std::cout << "ShardingUtil::ShardTensor check shard_indices: " + << shard_indices << std::endl; for (size_t i = 0; i < shard_indices.size(); i++) { at::Tensor shard = tensor.index( From 2f33433c5562e58a6ba58bb08338e50dc4a10d93 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sun, 9 Mar 2025 22:06:13 +0000 Subject: [PATCH 03/16] skip no tile assignment device case for num partition retrieving in lowering context --- torch_xla/csrc/lowering_context.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index a004be88c540..5a6621bb49a4 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -355,6 +355,9 @@ void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) { if (op_sharding.has_value()) { size_t curr_num_partitions = op_sharding.value().tile_assignment_devices().size(); + if (curr_num_partitions == 0) { + return; + } if (num_computation_partitions_ != 1) { XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_) << "Number of partitions must be the same for all ops in a HLO " From b633c76ef9a4ab945a41619b61f3e7175f28b628 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sun, 9 Mar 2025 22:06:54 +0000 Subject: [PATCH 04/16] use env var for local spmd --- torch_xla/csrc/runtime/computation_client.h | 2 +- torch_xla/csrc/runtime/pjrt_computation_client.cc | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 339d2a4f52c6..bc01a9af33d7 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -250,7 +250,7 @@ class ComputationClient { const xla::Shape* output_shape = nullptr; bool parameter_is_tupled_arguments; bool is_sharded; - size_t computation_num_partitions; + size_t computation_num_partitions = 1; bool allow_spmd_sharding_propagation_to_output; bool use_auto_spmd_partitioning; std::vector auto_spmd_mesh_shape; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 6bf6217c0366..a81a16c0fb72 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -566,9 +566,11 @@ std::vector PjRtComputationClient::Compile( .set_allow_spmd_sharding_propagation_to_output( {instance.allow_spmd_sharding_propagation_to_output}); - int num_partitions = client_->device_count(); - // num_partitions = 4; - num_partitions = static_cast(instance.computation_num_partitions); + int num_partitions = GetNumGlobalDevices(); + if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) { + num_partitions = GetNumLocalDevices(); + } + // num_partitions = static_cast(instance.computation_num_partitions); std::cout << "num_partitions: " << num_partitions << std::endl; compile_options.executable_build_options.set_num_partitions( num_partitions); From ba4b480ea94540ca4b8b849a8ef9c45446b4b5f5 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 00:36:47 
+0000 Subject: [PATCH 05/16] get num partitions from prod of tile dims --- torch_xla/csrc/xla_sharding_util.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index b2938f81dbe6..00727187e376 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -382,13 +382,20 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( runtime::GetComputationClient()->GetNumLocalDevices(); size_t num_global_devices = runtime::GetComputationClient()->GetNumGlobalDevices(); - XLA_CHECK(tile_assignment_devices.size() == num_global_devices || - tile_assignment_devices.size() == num_local_devices) - << "Number of tile_assignment_devices must be the number of global " - "devices or local devices"; + // XLA_CHECK(tile_assignment_devices.size() == 0 || + // tile_assignment_devices.size() == num_global_devices || + // tile_assignment_devices.size() == num_local_devices) + // << "Number of tile_assignment_devices must be the number of global " + // "devices or local devices, or 0, got unexpected size of " + // << tile_assignment_devices.size(); + size_t num_tiles = std::accumulate( + sharding.tile_assignment_dimensions().begin(), + sharding.tile_assignment_dimensions().end(), 1, + [](int a, int b) { return a * b; }); std::cout << "Num local devices " << num_local_devices << std::endl; + std::cout << "Num tile assignment size " << tile_assignment_devices.size() << std::endl; std::unordered_map device_index = - build_index_map(devices, tile_assignment_devices.size()); + build_index_map(devices, num_tiles); std::cout << "Check device_index " << std::endl; for (const auto& pair : device_index) { std::cout << "Key: " << pair.first << ", Value: " << pair.second From 7561786cb5c4acd1b1c98fe62c603bae390b8dc7 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 00:37:07 +0000 Subject: [PATCH 06/16] clang --- torch_xla/csrc/xla_sharding_util.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 00727187e376..5e168433a441 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -388,12 +388,13 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( // << "Number of tile_assignment_devices must be the number of global " // "devices or local devices, or 0, got unexpected size of " // << tile_assignment_devices.size(); - size_t num_tiles = std::accumulate( - sharding.tile_assignment_dimensions().begin(), - sharding.tile_assignment_dimensions().end(), 1, - [](int a, int b) { return a * b; }); + size_t num_tiles = + std::accumulate(sharding.tile_assignment_dimensions().begin(), + sharding.tile_assignment_dimensions().end(), 1, + [](int a, int b) { return a * b; }); std::cout << "Num local devices " << num_local_devices << std::endl; - std::cout << "Num tile assignment size " << tile_assignment_devices.size() << std::endl; + std::cout << "Num tile assignment size " << tile_assignment_devices.size() + << std::endl; std::unordered_map device_index = build_index_map(devices, num_tiles); std::cout << "Check device_index " << std::endl; From ff12d44a855d3a61b13aa55483ccc568694d3f36 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 00:58:26 +0000 Subject: [PATCH 07/16] add a test for shard tensor for local mesh --- test/cpp/test_xla_sharding.cpp | 107 +++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) 
diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp index e1f908b5c806..ed9b5b7677ce 100644 --- a/test/cpp/test_xla_sharding.cpp +++ b/test/cpp/test_xla_sharding.cpp @@ -222,6 +222,113 @@ TEST_F(XLAShardingTest, ShardTensor) { EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); } +TEST_F(XLAShardingTest, ShardTensorLocalMesh) { + // Test sharding with a local mesh. + std::vector devices = {"TPU:8", "TPU:9", "TPU:10", "TPU:11", + "TPU:12", "TPU:13", "TPU:14", "TPU:15"}; + + // 1D tiled + at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat)); + xla::Shape tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::OpSharding sharding = + xla::HloSharding::Tile1D( + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()), + devices.size()) + .ToProto(); + auto sharding_spec = + std::make_shared(sharding, tensor_shape); + auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); + EXPECT_EQ(shards.size(), 8); + for (auto shard : shards) { + EXPECT_EQ(shard.sizes(), c10::ArrayRef({1})); + } + + // 2D tiled, The first dim is halved and the last replicated. The last shard + // size should be smaller in dim=1 because it's not evenly divisible. + tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat)); + tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::Array2D mesh({ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }); + sharding = xla::HloSharding::Tile(mesh).ToProto(); + sharding_spec = + std::make_shared(sharding, tensor_shape); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({4, 2, 4})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({4, 1, 4})); + + // 3D tiled, the first dim is replicated and the last halved. The last shard + // size should be smaller in dim=1 because it's not evenly divisible. + xla::Array3D cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}); + sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto(); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 2, 2})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 1, 2})); + + // Replicated, all shards should be identical. + sharding_spec->sharding = xla::HloSharding::Replicate().ToProto(); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({8, 7, 4})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({8, 7, 4})); + + // 4D tiled, the first and second dims are replicated and the last halved. The + // last shard size should be smaller in dim=2 because it's not evenly + // divisible. + tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::Array4D tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}}); + sharding = xla::HloSharding::Tile(tesseract).ToProto(); + sharding_spec = + std::make_shared(sharding, tensor_shape); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 1, 2})); + + // 4D tiled and padded, all shard sizes should be idential. 
+ shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/true); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({1, 8, 2, 2})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({1, 8, 2, 2})); + + // 5D tiled, the first and second dims are replicated and the last halved. The + // last shard size should be smaller in dim=2 because it's not evenly + // divisible. + tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat)); + tensor_shape = + CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()); + xla::Array hypercube(std::vector{1, 1, 2, 2, 2}); + hypercube.FillIota(0); + sharding = xla::HloSharding::Tile(hypercube).ToProto(); + sharding_spec = + std::make_shared(sharding, tensor_shape); + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/false); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 3, 2})); + + // 5D tiled and padded, all shard sizes should be identical. + shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices, + /*padded=*/true); + EXPECT_EQ(shards.size(), 8); + EXPECT_EQ(shards[0].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); + EXPECT_EQ(shards[7].sizes(), c10::ArrayRef({10, 1, 4, 4, 2})); +} + TEST_F(XLAShardingTest, ShardTensorMultiHost) { std::vector devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"}; From 98731291ef9820262a7e54d1fd417154843b3a56 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 01:11:27 +0000 Subject: [PATCH 08/16] use env var for device assignment handling --- .../csrc/runtime/pjrt_computation_client.cc | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index a81a16c0fb72..94160ca4949d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -601,20 +601,18 @@ std::vector PjRtComputationClient::Compile( } // TODO(244391366) verify this is correct for the collectives ops - // xla::DeviceAssignment device_assignment(1, client_->device_count()); xla::DeviceAssignment device_assignment(1, num_partitions); - std::cout << "check client_->device_count(): " << client_->device_count() - << std::endl; // DeviceAssignment values must be the PjRtDevice ID, so we need to // unwind the global ordinal mapping. 
- // for (const auto& [device_id, global_ordinal] : global_ordinals_) { - // std::cout << "device_id: " << device_id - // << ", global_ordinal: " << global_ordinal << std::endl; - // device_assignment(0, global_ordinal) = device_id; - // } - auto local_pjrt_devices = client_->addressable_devices(); - for (int i = 0; i < local_pjrt_devices.size(); ++i) { - device_assignment(0, i) = local_pjrt_devices[i]->id(); + if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) { + auto local_pjrt_devices = client_->addressable_devices(); + for (int i = 0; i < local_pjrt_devices.size(); ++i) { + device_assignment(0, i) = local_pjrt_devices[i]->id(); + } + } else { + for (const auto& [device_id, global_ordinal] : global_ordinals_) { + device_assignment(0, global_ordinal) = device_id; + } } compile_options.executable_build_options.set_device_assignment( device_assignment); From e4183509ecf103b3b76f9a8518b5f14ab2a431f2 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 01:31:37 +0000 Subject: [PATCH 09/16] add assertion, comment for xla sharding python api --- torch_xla/distributed/spmd/xla_sharding.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index de2714ad2492..dc82af375aff 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -63,12 +63,15 @@ def __init__(self, device_ids = np.array(device_ids) assert (axis_names is None) or (len(mesh_shape) == len(axis_names)) assert axis_names is None or (len(set(axis_names)) == len(axis_names)) + # size of device_ids matches mesh_shape assert (len(device_ids) == np.prod(mesh_shape)) + # device ids are unique assert len(device_ids) == len(np.unique(device_ids)) + # device ids are continous + assert all(d < self.size() for d in device_ids - np.min(device_ids)) self.device_ids = device_ids self.mesh_shape = mesh_shape self.axis_names = axis_names - # assert all(d < self.size() for d in device_ids) def size(self): return np.prod(self.mesh_shape) @@ -382,6 +385,15 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]], def _normalize_logical_mesh(device_mesh: np.ndarray) -> np.ndarray: + """ + Normalize the device mesh to start from 0. + + This is needed when mesh doesn't include all global devices + (e.g. In multi-host setup, each host has a mesh containing local devices). + Because HLO graph always use logical device ids in the sharding annotation, + we need to normalize the physical device ids to generate the correct HLO + sharding annotation. 
+ """ device_id_min = np.min(device_mesh) return device_mesh.copy() - device_id_min From e2d157b3f73517ff9a75924c50ab3f9c69d2a922 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 01:42:13 +0000 Subject: [PATCH 10/16] fix assertion --- torch_xla/distributed/spmd/xla_sharding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index dc82af375aff..f16a99068871 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -67,11 +67,11 @@ def __init__(self, assert (len(device_ids) == np.prod(mesh_shape)) # device ids are unique assert len(device_ids) == len(np.unique(device_ids)) - # device ids are continous - assert all(d < self.size() for d in device_ids - np.min(device_ids)) self.device_ids = device_ids self.mesh_shape = mesh_shape self.axis_names = axis_names + # device ids are continous + assert all(d < self.size() for d in device_ids - np.min(device_ids)) def size(self): return np.prod(self.mesh_shape) From 3865f67e6847111899a324c475182bc82a7af5cf Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 02:57:42 +0000 Subject: [PATCH 11/16] remove debug print, attemp to derive num partitions from lowering --- torch_xla/csrc/lowering_context.cpp | 45 ------------------- torch_xla/csrc/lowering_context.h | 10 ----- torch_xla/csrc/runtime/computation_client.h | 3 -- .../csrc/runtime/pjrt_computation_client.cc | 24 +--------- torch_xla/csrc/xla_graph_executor.cpp | 14 +----- torch_xla/csrc/xla_sharding_util.cpp | 34 ++------------ torch_xla/distributed/spmd/xla_sharding.py | 4 -- 7 files changed, 6 insertions(+), 128 deletions(-) diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 5a6621bb49a4..6c2906dc7247 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -93,7 +93,6 @@ LoweringContext::LoweringContext(const std::string& name, torch::lazy::BackendDevice device) : torch::lazy::LoweringContext(name, device), builder_(name), - num_computation_partitions_(1), stack_frame_index_builder_(std::make_shared()) {} LoweringContext::LoweringContext( @@ -102,7 +101,6 @@ LoweringContext::LoweringContext( torch::lazy::Util::EmissionMap emit_status) : torch::lazy::LoweringContext(name, device, {}, emit_status), builder_(name), - num_computation_partitions_(1), stack_frame_index_builder_(std::make_shared()) { for (auto node : post_order) { LowerNode(node); @@ -133,7 +131,6 @@ xla::XlaOp LoweringContext::GetParameter( xla::OpSharding sharding = data->GetSharding(); xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding); param = xla::Parameter(builder(), param_index, shape, param_name); - UpdateNumPartitions(param); } else { param = xla::Parameter(builder(), param_index, shape, param_name); } @@ -257,28 +254,6 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { mutable_dims->Set(dim, kUnboundedSize); } } - std::for_each(result_ops.begin(), result_ops.end(), - [this](xla::XlaOp xla_op) { - UpdateNumPartitions(xla_op); // Calling the member function - }); - // for (auto xla_op : result_ops) { - // UpdateNumPartitions(xla_op); - // // std::optional op_sharding = - // // ConsumeValue(builder()->GetOpSharding(xla_op)); - // // if (op_sharding.has_value()) { - // // size_t curr_num_partitions = - // // op_sharding.value().tile_assignment_devices().size(); - // // if (num_computation_partitions_ != 1) { - // // 
XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_) - // << - // // "Number of partitions must be the same for all ops in a HLO - // graph."; - // // continue; - // // } - // // num_computation_partitions_ = - // op_sharding.value().tile_assignment_devices().size(); - // // } - // } } catch (const std::exception& ex) { ReportBuilderError(node, ex.what()); } @@ -349,24 +324,4 @@ torch::lazy::ComputationPtr LoweringContext::Build() { builder_.name(), std::move(xla_computation), device_); } -void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) { - std::optional op_sharding = - ConsumeValue(builder()->GetOpSharding(op)); - if (op_sharding.has_value()) { - size_t curr_num_partitions = - op_sharding.value().tile_assignment_devices().size(); - if (curr_num_partitions == 0) { - return; - } - if (num_computation_partitions_ != 1) { - XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_) - << "Number of partitions must be the same for all ops in a HLO " - "graph."; - return; - } - std::cout << "curr_num_partitions: " << curr_num_partitions << std::endl; - num_computation_partitions_ = curr_num_partitions; - } -} - } // namespace torch_xla diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index fdaabb2b14da..cb4f0bc2d2fa 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -113,18 +113,10 @@ class LoweringContext : public torch::lazy::LoweringContext { return emitted_outputs_; } - size_t GetComputationNumPartitions() const { - return num_computation_partitions_; - } - // Return stack frame id int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source, int64_t parent_id); - protected: - // Update the number of partitions from a XlaOp. - void UpdateNumPartitions(const xla::XlaOp& op); - private: struct Parameter { xla::XlaOp param; @@ -141,8 +133,6 @@ class LoweringContext : public torch::lazy::LoweringContext { std::vector root_tuple_; OutputMap emitted_outputs_; std::string name_; - // Number of partitions of the lowered XLA computation. 
- size_t num_computation_partitions_; std::shared_ptr stack_frame_index_builder_; }; // namespace torch_xla diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index bc01a9af33d7..20915de32e2b 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -225,7 +225,6 @@ class ComputationClient { xla::XlaComputation computation, std::string compilation_device, std::vector devices, const xla::Shape* output_shape, bool parameter_is_tupled_arguments = false, bool is_sharded = false, - size_t computation_num_partitions = 1, bool allow_spmd_sharding_propagation_to_output = true, bool use_auto_spmd_partitioning = false, std::vector auto_spmd_mesh_shape = {}, @@ -236,7 +235,6 @@ class ComputationClient { output_shape(output_shape), parameter_is_tupled_arguments(parameter_is_tupled_arguments), is_sharded(is_sharded), - computation_num_partitions(computation_num_partitions), allow_spmd_sharding_propagation_to_output( allow_spmd_sharding_propagation_to_output), use_auto_spmd_partitioning(use_auto_spmd_partitioning), @@ -250,7 +248,6 @@ class ComputationClient { const xla::Shape* output_shape = nullptr; bool parameter_is_tupled_arguments; bool is_sharded; - size_t computation_num_partitions = 1; bool allow_spmd_sharding_propagation_to_output; bool use_auto_spmd_partitioning; std::vector auto_spmd_mesh_shape; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 94160ca4949d..3783bb61b5da 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -334,7 +334,6 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::shared_ptr PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { - std::cout << "PjRtComputationClient::ReplicateShardedData" << std::endl; if (auto unsharded_data = std::dynamic_pointer_cast(handle)) { return unsharded_data; } else if (auto sharded_data = @@ -348,9 +347,7 @@ PjRtComputationClient::ReplicateShardedData( } xla::XlaBuilder builder("ReplicateShardedData"); xla::Shape shape = sharded_data->shape(); - xla::OpSharding sharding = sharded_data->GetSharding(); - builder.SetSharding(sharding); - size_t num_partitions = sharding.tile_assignment_devices().size(); + builder.SetSharding(sharded_data->GetSharding()); // perform a simple identity calculation to reassemble the input as // replicated output. 
@@ -374,7 +371,6 @@ PjRtComputationClient::ReplicateShardedData( GetCompilationDevices(device, {}), &shape, /*should_wrap_parameter=*/false, /*is_sharded=*/true, - /*computation_num_partitions*/ num_partitions, /*allow_spmd_sharding_propagation_to_output=*/false}); std::vector< std::shared_ptr> @@ -541,7 +537,6 @@ std::vector PjRtComputationClient::TransferFromDevice( std::vector PjRtComputationClient::Compile( std::vector instances) { - std::cout << "in compile" << std::endl; auto metrics_fn = CompileMetric; if (instances[0].eager_mode) { metrics_fn = EagerCompileMetric; @@ -551,9 +546,7 @@ std::vector PjRtComputationClient::Compile( tsl::profiler::TraceMeLevel::kInfo); std::vector computations; - std::cout << "instances.size(): " << instances.size() << std::endl; for (auto& instance : instances) { - std::cout << "instance devices " << instance.devices << std::endl; xla::CompileOptions compile_options; if (instance.is_sharded) { // TODO(yeounoh) multi-host, multi-slice configurations @@ -570,8 +563,6 @@ std::vector PjRtComputationClient::Compile( if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) { num_partitions = GetNumLocalDevices(); } - // num_partitions = static_cast(instance.computation_num_partitions); - std::cout << "num_partitions: " << num_partitions << std::endl; compile_options.executable_build_options.set_num_partitions( num_partitions); compile_options.executable_build_options.set_num_replicas(1); @@ -668,7 +659,6 @@ std::vector PjRtComputationClient::Compile( CreateCompileHandlesCounter()->AddValue(1); } - std::cout << "finish compile" << std::endl; return computations; } @@ -720,7 +710,6 @@ PjRtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { - std::cout << "in execute" << std::endl; // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteComputation` and the async work in `ExecuteSharded` are // complete; a copy is held from the lambda that releases it when done. @@ -788,7 +777,6 @@ PjRtComputationClient::ExecuteComputation( CreateDataHandlesCounter()->AddValue(datas.size()); TF_VLOG(1) << "Returning " << datas.size() << " results"; - std::cout << "finish execute" << std::endl; return datas; } @@ -798,10 +786,6 @@ PjRtComputationClient::ExecuteReplicated( absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) { - std::cout << "in execute replicated" << std::endl; - for (auto d : devices) { - std::cout << "device: " << d << std::endl; - } // Shared ownership of the timed section ensures that it will only get logged // once both `ExecuteReplicated` and the async work in `Execute` are // complete; a copy is held from the lambda that releases it when done. @@ -939,7 +923,6 @@ PjRtComputationClient::ExecuteReplicated( } TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs."; - std::cout << "finish execute replicated" << std::endl; return data_handles; } @@ -1002,17 +985,12 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice( void PjRtComputationClient::WaitDeviceOps( absl::Span devices) { - std::cout << "in wait device ops" << std::endl; - for (auto d : devices) { - std::cout << "device: " << d << std::endl; - } TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", "); operation_manager_.WaitForDevices( devices.empty() ? (UseVirtualDevice() ? 
std::vector({spmd_device_str}) : GetLocalDevices()) : devices); - std::cout << "finish wait device ops" << std::endl; } std::map PjRtComputationClient::GetMetrics() const { diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 514266518dc9..c33b5431455e 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1391,16 +1391,12 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( // Always execute sharded when running in SPMD mode bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice(); // Annotate HLO sharding selectively in the compuation. - bool is_sharded_2 = ShardingUtil::SetHloSharding(&lowering_ctx); - - std::cout << "is_sharded_2: " << is_sharded_2 << std::endl; + ShardingUtil::SetHloSharding(&lowering_ctx); SetBufferDonors(&lowering_ctx, buffer_donor_indices); xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla()); xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape()); - size_t computation_num_partitions = - lowering_ctx.GetComputationNumPartitions(); // TODO(yeounoh) enable wrapping with auto-sharding. bool should_wrap_parameter = @@ -1426,15 +1422,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( program_shape.result(), static_cast(coll.device.type())); std::vector instances; - std::cout << "computation_num_partitions: " << computation_num_partitions - << std::endl; instances.emplace_back(std::move(computation), coll.device.toString(), runtime::GetComputationClient()->GetCompilationDevices( coll.device.toString(), devices), - &shape, should_wrap_parameter, is_sharded, - computation_num_partitions); + &shape, should_wrap_parameter, is_sharded); instances.front().eager_mode = UseEagerMode(); - instances.front().computation_num_partitions = computation_num_partitions; if (use_autosharding) { TF_VLOG(5) << "use_auto_spmd_partitioning is set."; TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode."; @@ -1463,8 +1455,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) << " on device " << coll.device << " ..."; - std::cout << "check instance num partitions" - << instances.front().computation_num_partitions << std::endl; std::vector> computations = runtime::GetComputationClient()->Compile(std::move(instances)); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 5e168433a441..c4399c22be1f 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -192,9 +192,6 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) { XlaBuilderFriend::GetInstruction(elem.second); const std::shared_ptr sharding = xla_node->GetSharding(elem.first.index); - if (sharding != nullptr) { - std::cout << "check opsharding " << sharding->DebugString() << std::endl; - } if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) { *instruction->mutable_sharding() = *sharding; is_sharded = true; @@ -375,33 +372,15 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( shard_indices[i] = std::make_pair(global_ordinal, indices); } } else if (sharding.type() == xla::OpSharding::OTHER) { - std::vector tile_assignment_devices( - sharding.tile_assignment_devices().begin(), - sharding.tile_assignment_devices().end()); - size_t num_local_devices = - runtime::GetComputationClient()->GetNumLocalDevices(); - size_t num_global_devices = - 
runtime::GetComputationClient()->GetNumGlobalDevices(); - // XLA_CHECK(tile_assignment_devices.size() == 0 || - // tile_assignment_devices.size() == num_global_devices || - // tile_assignment_devices.size() == num_local_devices) - // << "Number of tile_assignment_devices must be the number of global " - // "devices or local devices, or 0, got unexpected size of " - // << tile_assignment_devices.size(); size_t num_tiles = std::accumulate(sharding.tile_assignment_dimensions().begin(), sharding.tile_assignment_dimensions().end(), 1, [](int a, int b) { return a * b; }); - std::cout << "Num local devices " << num_local_devices << std::endl; - std::cout << "Num tile assignment size " << tile_assignment_devices.size() - << std::endl; std::unordered_map device_index = build_index_map(devices, num_tiles); - std::cout << "Check device_index " << std::endl; - for (const auto& pair : device_index) { - std::cout << "Key: " << pair.first << ", Value: " << pair.second - << std::endl; - } + std::vector tile_assignment_devices( + sharding.tile_assignment_devices().begin(), + sharding.tile_assignment_devices().end()); if (!sharding.iota_reshape_dims().empty()) { auto tileAssignment = xla::TileAssignment( sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(), @@ -411,10 +390,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( } for (size_t i = 0; i < tile_assignment_devices.size(); i++) { int64_t core = tile_assignment_devices[i]; - std::cout << "Check core " << core << std::endl; if (device_index.find(core) == device_index.end()) { - std::cout << "current core " << core << " is not in device_index" - << std::endl; // Skip any shards whose device is not part of the `devices` list. continue; } @@ -464,8 +440,6 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices( std::vector ShardingUtil::ShardTensor( const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings, const std::vector& devices, bool padded) { - std::cout << "ShardingUtil::ShardTensor check devices " << devices - << std::endl; xla::OpSharding sharding; bool minibatch = false; if (shardings != nullptr) { @@ -496,8 +470,6 @@ std::vector ShardingUtil::ShardTensor( std::back_inserter(shard_indices), [](auto& pair) { return pair.second; }); } - std::cout << "ShardingUtil::ShardTensor check shard_indices: " - << shard_indices << std::endl; for (size_t i = 0; i < shard_indices.size(); i++) { at::Tensor shard = tensor.index( diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index f16a99068871..4bc0b71318ff 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -130,10 +130,6 @@ def get_op_sharding(self, tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args( partition_spec) - print(f"check tile_assignment: {tile_assignment}") - print(f"check group_assignment: {group_assignment}") - print(f"check replication_groups: {replication_groups}") - print(f"check sharding_type: {sharding_type}") return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment, replication_groups, sharding_type) From d3feb5f7097f082897086f0f7add7e06b13e488a Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 10 Mar 2025 06:00:21 +0000 Subject: [PATCH 12/16] remove unused var --- torch_xla/csrc/xla_sharding_util.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index c4399c22be1f..4289cf0e00c7 100644 --- 
a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -448,7 +448,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
   }
   TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type()
              << ")... and minibatch = " << minibatch << std::endl;
-  // auto device_index = build_index_map(devices);
   std::vector<at::Tensor> shards(devices.size());
   if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED ||
       sharding.type() == xla::OpSharding::UNKNOWN) {

From 2f1fc1a411cb7bb4abc9181aeb4197002278ca34 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 13 Mar 2025 06:13:21 +0000
Subject: [PATCH 13/16] add comment for the modulo used in the device index
 map util func

---
 torch_xla/csrc/xla_sharding_util.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 4289cf0e00c7..58bc0ac20535 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -88,6 +88,11 @@ std::unordered_map<int64_t, int> build_index_map(
     const std::vector<std::string>& devices, size_t num_mesh_devices) {
   std::unordered_map<int64_t, int> device_index;
   for (int i = 0; i < devices.size(); ++i) {
+    // The global ordinal here is the device's ordinal in the mesh, which
+    // can be different from the physical device index.
+    // We only support 2 cases here:
+    // 1. Mesh contains all global devices.
+    // 2. Mesh contains only local devices. (in multi-host scenario)
     int global_ordinal =
         ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;

From a1eaaebfd4c472cbc61ec4fc3937eb4f92aea55c Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 13 Mar 2025 06:17:15 +0000
Subject: [PATCH 14/16] update comment

---
 torch_xla/csrc/xla_sharding_util.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 58bc0ac20535..2058e7490a68 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -93,6 +93,10 @@ std::unordered_map<int64_t, int> build_index_map(
     // We only support 2 cases here:
     // 1. Mesh contains all global devices.
     // 2. Mesh contains only local devices. (in multi-host scenario)
+    // Example: In multi-host v6e-8, each host has a mesh of its local
+    // devices; host 1 has devices TPU:{4, 5, 6, 7}. In this case
+    // the global ordinal of TPU:4 is 0, TPU:5 is 1, and so on.
+
     int global_ordinal =
         ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;

From e97401d4a24f203acb64147e473f084f726522ab Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 13 Mar 2025 06:27:08 +0000
Subject: [PATCH 15/16] assert on local devices in mesh constructor

---
 torch_xla/distributed/spmd/xla_sharding.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 4bc0b71318ff..90a6df67388b 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -71,7 +71,15 @@ def __init__(self,
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
     # device ids are continous
-    assert all(d < self.size() for d in device_ids - np.min(device_ids))
+    if min(device_ids) != 0:
+      # Mesh doesn't contain all global devices. Only creating a mesh with local
+      # devices is supported.
+      min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
+      )
+      assert min_device_idx == min(
+          device_ids
+      ), "If not creating a mesh with all global devices, must use local devices."
+    assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)

From 4bfd7b59a0512e8707d459e04de1ae9862550e2d Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 13 Mar 2025 06:36:55 +0000
Subject: [PATCH 16/16] check local devices in mesh ctor

---
 torch_xla/distributed/spmd/xla_sharding.py | 40 ++++++++++++----------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 90a6df67388b..d1a1db4b6444 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -1,25 +1,25 @@
 import collections
-from collections.abc import Generator, MutableMapping
+import functools
+import itertools
 import math
+import os
 from collections import OrderedDict, defaultdict
+from collections.abc import Generator, MutableMapping
 from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+
+import numpy as np
 import torch
-from torch import Tensor
-from torch.library import custom_op
 import torch_xla
-import torch_xla.core.xla_model as xm
 import torch_xla._internal.utils as _utils
-from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
-import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
-
-import numpy as np
-import functools
-import itertools
-from typing import Tuple, Union, List, Sequence, Any, Optional, Set
-from enum import IntEnum
-
-from torch.amp import custom_fwd, custom_bwd
+import torch_xla.runtime as xr
+from torch import Tensor
+from torch.amp import custom_bwd, custom_fwd
+from torch.library import custom_op
+from torch_xla.distributed.spmd import XLAShard, XLAShardedTensor
 
 
 class Mesh:
@@ -71,15 +71,16 @@ def __init__(self,
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
     # device ids are continous
-    if min(device_ids) != 0:
-      # Mesh doesn't contain all global devices. Only creating a mesh with local
-      # devices is supported.
+    if os.environ.get('XLA_USE_LOCAL_SPMD', '0') == '1':
+      # In local SPMD, the mesh only contains local devices.
       min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
       )
       assert min_device_idx == np.min(
           device_ids
       ), "If not creating a mesh with all global devices, must use local devices."
-    assert all(d < self.size() for d in device_ids)
+      assert all(d < self.size() for d in device_ids - np.min(device_ids))
+    else:
+      assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -151,6 +152,7 @@ def __str__(self):
   def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
     """Create Mesh from string representation."""
     import ast
+
     import numpy as np
     try:
       dict_str = mesh_str.replace('Mesh', '')