From ed8627481b47461e1b82666797bc18ad40c0520e Mon Sep 17 00:00:00 2001
From: Claire Huang
Date: Tue, 15 Jul 2025 22:26:45 +0000
Subject: [PATCH 01/13] Implement XLAShardedTensor._spec and test

---
 test/spmd/test_xla_dtensor_spec_conversion.py | 148 ++++--------------
 .../distributed/spmd/xla_sharded_tensor.py    |  25 ++-
 torch_xla/distributed/spmd/xla_sharding.py    |  25 +--
 3 files changed, 61 insertions(+), 137 deletions(-)

diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py
index 81cb8a4aa2e..12102f555c2 100644
--- a/test/spmd/test_xla_dtensor_spec_conversion.py
+++ b/test/spmd/test_xla_dtensor_spec_conversion.py
@@ -3,12 +3,9 @@
 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
-from torch.distributed.tensor.placement_types import Replicate

 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.distributed.spmd import XLAShardedTensor
-from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

 import unittest
 import test_xla_sharding_base
@@ -34,6 +31,7 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))

     # Test different sharding patterns
+    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),
@@ -66,20 +64,30 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape

   def test_spec_caching(self):
-    """Test that _spec property caches results
-    """
+    """Test that _spec property caches results for better performance"""
+    import time
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(100, 100)
+    tensor = torch.randn(1000,
+                         1000)  # Large tensor to make spec creation noticeable
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])

+    # first access should create and cache the spec
+    start_time = time.time()
     spec1 = xla_tensor._spec
+    first_access_time = time.time() - start_time

-    assert xla_tensor._cached_spec is not None
-    assert xla_tensor._cached_spec is spec1
-
+    # should be much faster due to caching
+    start_time = time.time()
     spec2 = xla_tensor._spec
+    second_access_time = time.time() - start_time
+
     assert spec1 is spec2
+    print(
+        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
+    )
+    assert second_access_time * 10 < first_access_time, \
+        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""
@@ -106,8 +114,22 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2

+  def test_tensor_operations_preserve_spec(self):
+    """Test that tensor operations preserve sharding metadata"""
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
+                                                         [Shard(0)])
+
+    result_add = xla_tensor + 1
+    result_mul = xla_tensor * 2
+    result_relu = torch.relu(xla_tensor)
+
+    for result in [result_add, result_mul, result_relu]:
+      assert hasattr(result, '_spec')
+      assert result._spec.mesh.device_type == "xla"
+
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
+    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D 
mesh") @@ -121,114 +143,6 @@ def test_mixed_placement_spec(self): assert isinstance(spec.placements[0], Shard) assert isinstance(spec.placements[1], Replicate) - def test_sharding_info_acquisition(self): - """Test that non-XLAShardedTensor can acquire sharding information - - Tests case of 'elem is not an XLAShardedTensor but there exists - sharding information we want to acquire' - """ - - device_count = xr.global_runtime_device_count() - mesh_shape = (device_count,) - partition_spec = (0, None) - - regular_tensor = torch.randn(100, 50).to('xla') - - sharded_tensor = wrap_as_sharded_tensor( - regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) - - # Verify the tensor acquired the sharding information - assert isinstance(sharded_tensor, XLAShardedTensor) - assert sharded_tensor.mesh_shape == mesh_shape - assert sharded_tensor.partition_spec == partition_spec - - def test_resharding_logic(self): - """ - Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. - """ - - device_count = xr.global_runtime_device_count() - if device_count < 4: - self.skipTest("Need at least 4 devices for resharding test") - - # Initial sharding - initial_mesh_shape = (device_count,) - initial_partition_spec = (0, None) - new_mesh_shape = (2, device_count // 2) - new_partition_spec = (0, 1) - - # Create tensor and verify resharding - tensor = torch.randn(100, 50).to('xla') - sharded_tensor = wrap_as_sharded_tensor( - tensor, - mesh_shape=initial_mesh_shape, - partition_spec=initial_partition_spec) - initial_spec = sharded_tensor._spec - - resharded_tensor = wrap_as_sharded_tensor( - sharded_tensor, - mesh_shape=new_mesh_shape, - partition_spec=new_partition_spec) - - # Verify resharding worked and cache was invalidated - assert resharded_tensor.mesh_shape == new_mesh_shape - assert resharded_tensor.partition_spec == new_partition_spec - assert resharded_tensor._spec is not initial_spec - - def test_spec_invalidation_on_resharding(self): - """Tests cases where the cached spec may become outdated. - """ - - device_count = xr.global_runtime_device_count() - if device_count < 4: - self.skipTest("Need at least 4 devices for resharding test") - - tensor = torch.randn(100, 50).to('xla') - initial_mesh_shape = (device_count,) - initial_partition_spec = (0, None) - new_mesh_shape = (2, device_count // 2) - new_partition_spec = (0, 1) - - sharded_tensor = wrap_as_sharded_tensor( - tensor, - mesh_shape=initial_mesh_shape, - partition_spec=initial_partition_spec) - initial_spec = sharded_tensor._spec - assert sharded_tensor._cached_spec is not None - - # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache - resharded_tensor = wrap_as_sharded_tensor( - sharded_tensor, - mesh_shape=new_mesh_shape, - partition_spec=initial_partition_spec) - assert resharded_tensor._spec is not initial_spec - assert resharded_tensor._spec.mesh.shape == new_mesh_shape - - initial_spec = resharded_tensor._spec - resharded_tensor = wrap_as_sharded_tensor( - resharded_tensor, - mesh_shape=new_mesh_shape, - partition_spec=new_partition_spec) - assert resharded_tensor._spec is not initial_spec - assert resharded_tensor._spec.placements[1].dim == 1 - - def test_auto_wrapped_tensor_spec_failure(self): - """Test that auto-wrapped tensors fail when accessing _spec property. - - Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ - but don't yet have access to the sharding propagation done through open xla, - causing ._spec to fail. 
- """ - device_count = xr.global_runtime_device_count() - mesh = DeviceMesh("xla", torch.arange(device_count)) - tensor = torch.randn(4, 4) - sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - - auto_wrapped = sharded_tensor + sharded_tensor - - with self.assertRaises(ValueError): - _ = auto_wrapped._spec - if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 652a2011cbd..f314e4de0a2 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -115,7 +115,7 @@ def __new__(cls, dtype=elem.dtype, layout=elem.layout, device=elem.device, - requires_grad=kwargs.get("requires_grad", False)) + requires_grad=kwargs.get("requires_grad", elem.requires_grad)) r.global_tensor = elem.detach() if r.requires_grad else elem # Initialize mesh, partition, and spec information @@ -151,6 +151,29 @@ def load_local_shards_(self, shards: List[XLAShard]): # Invalidate cached spec since the global_tensor data has changed self.invalidate_spec_cache() + def to_local(self): + """ + Returns the local representation of the XLAShardedTensor. + + This method returns the global tensor representation, which contains + the combined data across all devices. The returned tensor is on the + same device as the original XLAShardedTensor. The returned tensor + will have the same requires_grad value as the XLAShardedTensor. + If the original tensor has gradients, those will be preserved. + + Returns: + torch.Tensor: The global tensor representation with appropriate requires_grad setting. + """ + + # Create a new tensor with the same values of global_tensor + result = self.global_tensor.clone() + # Since global tensor is detached, add requires_grad and grad values back to the local tensor + if self.requires_grad: + result.requires_grad = self.requires_grad + result.grad = self.grad + + return result + @property def sharding_spec(self): return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index c010fd4c352..1d6cd0248fe 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -767,27 +767,14 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], partition_spec=None) -> XLAShardedTensor: # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): - # Create a new XLAShardedTensor return XLAShardedTensor( t, mesh_shape=mesh_shape, partition_spec=partition_spec) - - # Update existing XLAShardedTensor if needed - needs_invalidate = False - - # Always set mesh_shape and partition_spec if provided - if mesh_shape is not None: - t.mesh_shape = mesh_shape - needs_invalidate = True - - if partition_spec is not None: - t.partition_spec = partition_spec - needs_invalidate = True - - # Invalidate cached spec if resharding occurred - if needs_invalidate: - t.invalidate_spec_cache() - - return t + else: + if mesh_shape is not None: + t.mesh_shape = mesh_shape + if partition_spec is not None: + t.partition_spec = partition_spec + return t def unwrap_sharded_tensor( From 858863e2ec15d05d69addc8c2335a98d27fc29a5 Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 18:35:54 +0000 Subject: [PATCH 02/13] Removed auto wrapping sharding propagation, added cached spec invalidation --- test/spmd/test_xla_dtensor_spec_conversion.py | 138 
++++++++++++++---- .../distributed/spmd/xla_sharded_tensor.py | 4 + torch_xla/distributed/spmd/xla_sharding.py | 25 +++- 3 files changed, 130 insertions(+), 37 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index 12102f555c2..2fe53613e94 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -3,9 +3,12 @@ import torch from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate import torch_xla import torch_xla.runtime as xr +from torch_xla.distributed.spmd import XLAShardedTensor +from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor import unittest import test_xla_sharding_base @@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self): mesh = DeviceMesh("xla", list(range(device_count))) # Test different sharding patterns - from torch.distributed.tensor.placement_types import Replicate test_cases = [ (torch.randn(100, 50), [Shard(0)]), (torch.randn(100, 50), [Shard(1)]), @@ -64,30 +66,27 @@ def test_mesh_conversion(self): assert converted_spec.mesh.shape == original_mesh.shape def test_spec_caching(self): - """Test that _spec property caches results for better performance""" - import time + """Test that _spec property caches results + + Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to + annoying flakes in my experience. I think it's sufficient to just test that + self._cached_spec has a permanent value after the first call." + """ device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) - tensor = torch.randn(1000, - 1000) # Large tensor to make spec creation noticeable + tensor = torch.randn(100, 100) xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - # first access should create and cache the spec - start_time = time.time() + # First access should create and cache the spec spec1 = xla_tensor._spec - first_access_time = time.time() - start_time - # should be much faster due to caching - start_time = time.time() - spec2 = xla_tensor._spec - second_access_time = time.time() - start_time + # Verify the spec is cached + assert xla_tensor._cached_spec is not None + assert xla_tensor._cached_spec is spec1 + # Second access should return the cached spec + spec2 = xla_tensor._spec assert spec1 is spec2 - print( - f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s" - ) - assert second_access_time * 10 < first_access_time, \ - f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s" def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): """Helper to create tensor and mesh for testing""" @@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self): assert len(spec.placements) == 2 assert spec.mesh.ndim == 2 - def test_tensor_operations_preserve_spec(self): - """Test that tensor operations preserve sharding metadata""" - xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), - [Shard(0)]) - - result_add = xla_tensor + 1 - result_mul = xla_tensor * 2 - result_relu = torch.relu(xla_tensor) - - for result in [result_add, result_mul, result_relu]: - assert hasattr(result, '_spec') - assert result._spec.mesh.device_type == "xla" - def test_mixed_placement_spec(self): """Test _spec for tensors with mixed shard/replicate placements""" - from torch.distributed.tensor.placement_types 
import Replicate device_count = xr.global_runtime_device_count() if device_count < 4: self.skipTest("Need at least 4 devices for 2D mesh") @@ -143,6 +128,97 @@ def test_mixed_placement_spec(self): assert isinstance(spec.placements[0], Shard) assert isinstance(spec.placements[1], Replicate) + def test_sharding_info_acquisition(self): + """Test that non-XLAShardedTensor can acquire sharding information + + Tests case of 'elem is not an XLAShardedTensor but there exists + sharding information we want to acquire' + """ + + device_count = xr.global_runtime_device_count() + mesh_shape = (device_count,) + partition_spec = (0, None) + + regular_tensor = torch.randn(100, 50).to('xla') + + sharded_tensor = wrap_as_sharded_tensor( + regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Verify the tensor acquired the sharding information + assert isinstance(sharded_tensor, XLAShardedTensor) + assert sharded_tensor.mesh_shape == mesh_shape + assert sharded_tensor.partition_spec == partition_spec + + def test_resharding_logic(self): + """ + Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + # Initial sharding + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + # Create tensor and verify resharding + tensor = torch.randn(100, 50).to('xla') + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + + # Verify resharding worked and cache was invalidated + assert resharded_tensor.mesh_shape == new_mesh_shape + assert resharded_tensor.partition_spec == new_partition_spec + assert resharded_tensor._spec is not initial_spec + + def test_spec_invalidation_on_resharding(self): + """Tests cases where the cached spec may become outdated. 
+ """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + tensor = torch.randn(100, 50).to('xla') + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + assert sharded_tensor._cached_spec is not None + + # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=initial_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.mesh.shape == new_mesh_shape + + initial_spec = resharded_tensor._spec + resharded_tensor = wrap_as_sharded_tensor( + resharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.placements[1].dim == 1 + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index f314e4de0a2..4bbb4a0fd96 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -9,7 +9,11 @@ import torch_xla.runtime as xr from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh +<<<<<<< HEAD from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial +======= +from torch.distributed.tensor.placement_types import Shard, Replicate +>>>>>>> 566959e10 (Removed auto wrapping sharding propagation, added cached spec invalidation) from torch.utils._pytree import tree_map_only from torch.distributed.tensor import DTensor diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 1d6cd0248fe..c010fd4c352 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -767,14 +767,27 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], partition_spec=None) -> XLAShardedTensor: # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): + # Create a new XLAShardedTensor return XLAShardedTensor( t, mesh_shape=mesh_shape, partition_spec=partition_spec) - else: - if mesh_shape is not None: - t.mesh_shape = mesh_shape - if partition_spec is not None: - t.partition_spec = partition_spec - return t + + # Update existing XLAShardedTensor if needed + needs_invalidate = False + + # Always set mesh_shape and partition_spec if provided + if mesh_shape is not None: + t.mesh_shape = mesh_shape + needs_invalidate = True + + if partition_spec is not None: + t.partition_spec = partition_spec + needs_invalidate = True + + # Invalidate cached spec if resharding occurred + if needs_invalidate: + t.invalidate_spec_cache() + + return t def unwrap_sharded_tensor( From dd79690e7d08b531d742d0afa889e147c2e7870c Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 18:40:12 +0000 Subject: [PATCH 03/13] Removing lazy import --- test/spmd/test_xla_dtensor_spec_conversion.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py 
b/test/spmd/test_xla_dtensor_spec_conversion.py index 2fe53613e94..e48a323bd15 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -67,24 +67,17 @@ def test_mesh_conversion(self): def test_spec_caching(self): """Test that _spec property caches results - - Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to - annoying flakes in my experience. I think it's sufficient to just test that - self._cached_spec has a permanent value after the first call." """ device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) tensor = torch.randn(100, 100) xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - # First access should create and cache the spec spec1 = xla_tensor._spec - # Verify the spec is cached assert xla_tensor._cached_spec is not None assert xla_tensor._cached_spec is spec1 - # Second access should return the cached spec spec2 = xla_tensor._spec assert spec1 is spec2 From 8c434788c88aaf71fdb05b34406c35946e091beb Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 20:43:48 +0000 Subject: [PATCH 04/13] Added test for catching thrown error in spec --- test/spmd/test_xla_dtensor_spec_conversion.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index e48a323bd15..81cb8a4aa2e 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self): assert resharded_tensor._spec is not initial_spec assert resharded_tensor._spec.placements[1].dim == 1 + def test_auto_wrapped_tensor_spec_failure(self): + """Test that auto-wrapped tensors fail when accessing _spec property. + + Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ + but don't yet have access to the sharding propagation done through open xla, + causing ._spec to fail. + """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(device_count)) + tensor = torch.randn(4, 4) + sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + auto_wrapped = sharded_tensor + sharded_tensor + + with self.assertRaises(ValueError): + _ = auto_wrapped._spec + if __name__ == '__main__': test = unittest.main() From 4b5af4bf3cec9278e28f72754cb8ed556f13dffe Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Wed, 23 Jul 2025 23:29:03 +0000 Subject: [PATCH 05/13] Test for Routing XLA device handling through distribute_tensor to ensure proper XLA support and maintain consistency with PyTorch/XLA SPMD integration. 
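
For reference, a condensed sketch of the flow the new test exercises, assuming
an "xla" DeviceMesh spanning all runtime devices (the tensor shape and the
Shard(0) placement below are illustrative; the actual assertions live in
test_xla_dtensor_from_local.py added by this patch):

    import torch
    import torch_xla.runtime as xr
    from torch.distributed.tensor import DeviceMesh
    from torch.distributed._tensor import DTensor
    from torch.distributed.tensor.placement_types import Shard

    # Build an "xla" mesh over all available devices and wrap a plain CPU tensor.
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))
    local = torch.randn(128, 8)
    dt = DTensor.from_local(local, device_mesh=mesh, placements=[Shard(0)])
    # The new test asserts that the global view keeps the input shape and values.
    assert dt.shape == local.shape
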
--- test/neuron/run_tests.sh | 1 + test/run_tests.sh | 1 + test/spmd/test_xla_dtensor_from_local.py | 149 +++++++++++++++++++++++ test/tpu/run_tests.sh | 1 + 4 files changed, 152 insertions(+) create mode 100644 test/spmd/test_xla_dtensor_from_local.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index a68e0671a3b..076867b8e9a 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -257,6 +257,7 @@ function run_xla_op_tests3 { run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index bb03d7abe16..cf55f4a3606 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -238,6 +238,7 @@ function run_xla_op_tests3 { run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" diff --git a/test/spmd/test_xla_dtensor_from_local.py b/test/spmd/test_xla_dtensor_from_local.py new file mode 100644 index 00000000000..40647e20590 --- /dev/null +++ b/test/spmd/test_xla_dtensor_from_local.py @@ -0,0 +1,149 @@ +import sys +import unittest +import torch +import numpy as np + +from torch.distributed.tensor import DeviceMesh +from torch.distributed._tensor import DTensor +from torch.distributed.tensor.placement_types import Replicate, Shard +import torch_xla +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor +import test_xla_sharding_base + + +class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): + """ + Test suite for the automatic conversion of regular tensors to XLAShardedTensor + in DTensor.from_local() when using XLA device mesh. 
+ """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_basic_conversion(self): + """Test basic conversion of regular tensor to XLAShardedTensor.""" + world_size = xr.global_runtime_device_count() + + # Create a regular tensor (not on XLA device) + tensor = torch.randn(100_000, 88) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Create a DeviceMesh + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Use DTensor.from_local with the regular tensor + dt = DTensor.from_local(tensor, device_mesh=device_mesh) + + # Verify the tensor was converted correctly + self.assertEqual(dt.shape, tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + result = dt + 1.0 + self.assertEqual(result.shape, tensor.shape) + + print("Basic conversion successful") + + + def test_conversion_with_placements(self): + """Test conversion with explicit placements.""" + world_size = xr.global_runtime_device_count() + + # Create a regular tensor (not on XLA device) + tensor = torch.randn(100_000, 88) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Create a DeviceMesh + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Use DTensor.from_local with explicit placements + dt = DTensor.from_local( + tensor, + device_mesh=device_mesh, + placements=[Replicate()] + ) + + # Verify the tensor was converted correctly + self.assertEqual(dt.shape, tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + result = dt + 1.0 + self.assertEqual(result.shape, tensor.shape) + + print("Conversion with placements successful") + + def test_conversion_with_sharding(self): + """Test conversion with sharding placement.""" + world_size = xr.global_runtime_device_count() + if world_size < 2: + self.skipTest("Need at least 2 devices for sharding test") + + # Create a tensor divisible by world_size + tensor = torch.randn(100_000, 88) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Create a DeviceMesh + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Use DTensor.from_local with sharding placement + dt = DTensor.from_local( + tensor, + device_mesh=device_mesh, + placements=[Shard(0)] + ) + + # Verify the tensor was converted correctly + self.assertEqual(dt.shape, tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + result = dt + 1.0 + self.assertEqual(result.shape, tensor.shape) + + print("Conversion with sharding successful") + + def test_conversion_with_different_dtypes(self): + """Test conversion with different dtypes.""" + world_size = xr.global_runtime_device_count() + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Test with different dtypes + for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]: + # Create a tensor with specific dtype + tensor = torch.ones(100_000, 88, dtype=dtype) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Use DTensor.from_local with the tensor + dt = DTensor.from_local(tensor, device_mesh=device_mesh) + + # Verify dtype is preserved + self.assertEqual(dt.dtype, dtype) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + if dtype.is_floating_point: + 
result = dt + 1.0 + else: + result = dt + 1 + + self.assertEqual(result.shape, tensor.shape) + self.assertEqual(result.dtype, dtype) + + print(f"Conversion with {dtype} successful") + + +if __name__ == "__main__": + result = unittest.main(exit=False) + sys.exit(0 if result.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index e1ad7c0023a..ae8b2b10c5c 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -63,6 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v From df9279b19badaacf096ebacb3a5e11d7865fe447 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Thu, 24 Jul 2025 00:12:59 +0000 Subject: [PATCH 06/13] [XLA] Implement XLAShardedTensor.to_local() --- test/neuron/run_tests.sh | 2 +- test/run_tests.sh | 2 +- test/spmd/test_xla_dtensor_to_local.py | 75 ++++++++++++++++++++++++++ test/tpu/run_tests.sh | 2 +- 4 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 test/spmd/test_xla_dtensor_to_local.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index 076867b8e9a..b052b754fb1 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -257,7 +257,7 @@ function run_xla_op_tests3 { run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py" - run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index cf55f4a3606..10715855aa6 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -238,7 +238,7 @@ function run_xla_op_tests3 { run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py" - run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py new file mode 100644 index 00000000000..cdba1b3ca73 --- /dev/null +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -0,0 +1,75 @@ +import sys +import unittest +import torch +import numpy as np + +from torch.distributed.tensor import DeviceMesh +from torch.distributed._tensor import DTensor +from torch.distributed.tensor.placement_types import Replicate, Shard +import torch_xla +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor +import 
test_xla_sharding_base + + +class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): + """ + Test suite for the automatic conversion of regular tensors to XLAShardedTensor + in DTensor.from_local() when using XLA device mesh. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_to_local(self): + from torch.distributed.tensor import distribute_tensor + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + big_tensor = torch.randn(100000, 88) + sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)]) + + local_tensor = sharded_tensor.to_local() + + # Verify the shapes are the same + self.assertEqual(local_tensor.shape, big_tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(local_tensor, big_tensor, check_device=False) + + def test_to_local_requires_grad(self): + """Test that gradients flow correctly through to_local().""" + # Create a tensor with requires_grad=True + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + + # Create XLAShardedTensor + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + + # Verify requires_grad is set + self.assertTrue(sharded_tensor.requires_grad) + + res = sharded_tensor.sum() + res.backward() + + # Verify grad are calculated + self.assertTrue(sharded_tensor.grad is not None) + + # Call to local function + local_tensor = sharded_tensor.to_local() + + # Verify requires_grad is preserved + self.assertTrue(local_tensor.requires_grad) + + # All gradients should be 1.0 since we did a sum() + self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) + + print("Gradient flow test successful") + +if __name__ == "__main__": + result = unittest.main(exit=False) + sys.exit(0 if result.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index ae8b2b10c5c..0712e131a81 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -63,7 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py" -run_test "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py" run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v From bf331da5df8cb3f9003ee576674b07dc9bef1339 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Fri, 25 Jul 2025 22:43:09 +0000 Subject: [PATCH 07/13] run git_fix for yapf --- test/spmd/test_xla_dtensor_from_local.py | 149 ----------------------- test/spmd/test_xla_dtensor_to_local.py | 107 ++++++++-------- 2 files changed, 54 insertions(+), 202 deletions(-) delete mode 100644 test/spmd/test_xla_dtensor_from_local.py diff --git a/test/spmd/test_xla_dtensor_from_local.py b/test/spmd/test_xla_dtensor_from_local.py deleted file mode 100644 index 40647e20590..00000000000 --- a/test/spmd/test_xla_dtensor_from_local.py +++ /dev/null @@ -1,149 +0,0 @@ -import sys -import unittest -import torch -import numpy as np - -from torch.distributed.tensor import DeviceMesh -from torch.distributed._tensor import DTensor -from torch.distributed.tensor.placement_types import Replicate, 
Shard -import torch_xla -import torch_xla.runtime as xr -import torch_xla.core.xla_model as xm -from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor -import test_xla_sharding_base - - -class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): - """ - Test suite for the automatic conversion of regular tensors to XLAShardedTensor - in DTensor.from_local() when using XLA device mesh. - """ - - @classmethod - def setUpClass(cls): - super().setUpClass() - - def test_basic_conversion(self): - """Test basic conversion of regular tensor to XLAShardedTensor.""" - world_size = xr.global_runtime_device_count() - - # Create a regular tensor (not on XLA device) - tensor = torch.randn(100_000, 88) - tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison - - # Create a DeviceMesh - device_mesh = DeviceMesh("xla", list(range(world_size))) - - # Use DTensor.from_local with the regular tensor - dt = DTensor.from_local(tensor, device_mesh=device_mesh) - - # Verify the tensor was converted correctly - self.assertEqual(dt.shape, tensor.shape) - - # Check the value of the tensor - torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) - - # Verify operations work - result = dt + 1.0 - self.assertEqual(result.shape, tensor.shape) - - print("Basic conversion successful") - - - def test_conversion_with_placements(self): - """Test conversion with explicit placements.""" - world_size = xr.global_runtime_device_count() - - # Create a regular tensor (not on XLA device) - tensor = torch.randn(100_000, 88) - tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison - - # Create a DeviceMesh - device_mesh = DeviceMesh("xla", list(range(world_size))) - - # Use DTensor.from_local with explicit placements - dt = DTensor.from_local( - tensor, - device_mesh=device_mesh, - placements=[Replicate()] - ) - - # Verify the tensor was converted correctly - self.assertEqual(dt.shape, tensor.shape) - - # Check the value of the tensor - torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) - - # Verify operations work - result = dt + 1.0 - self.assertEqual(result.shape, tensor.shape) - - print("Conversion with placements successful") - - def test_conversion_with_sharding(self): - """Test conversion with sharding placement.""" - world_size = xr.global_runtime_device_count() - if world_size < 2: - self.skipTest("Need at least 2 devices for sharding test") - - # Create a tensor divisible by world_size - tensor = torch.randn(100_000, 88) - tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison - - # Create a DeviceMesh - device_mesh = DeviceMesh("xla", list(range(world_size))) - - # Use DTensor.from_local with sharding placement - dt = DTensor.from_local( - tensor, - device_mesh=device_mesh, - placements=[Shard(0)] - ) - - # Verify the tensor was converted correctly - self.assertEqual(dt.shape, tensor.shape) - - # Check the value of the tensor - torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) - - # Verify operations work - result = dt + 1.0 - self.assertEqual(result.shape, tensor.shape) - - print("Conversion with sharding successful") - - def test_conversion_with_different_dtypes(self): - """Test conversion with different dtypes.""" - world_size = xr.global_runtime_device_count() - device_mesh = DeviceMesh("xla", list(range(world_size))) - - # Test with different dtypes - for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]: - # Create a tensor with specific dtype - tensor = torch.ones(100_000, 
88, dtype=dtype) - tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison - - # Use DTensor.from_local with the tensor - dt = DTensor.from_local(tensor, device_mesh=device_mesh) - - # Verify dtype is preserved - self.assertEqual(dt.dtype, dtype) - - # Check the value of the tensor - torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) - - # Verify operations work - if dtype.is_floating_point: - result = dt + 1.0 - else: - result = dt + 1 - - self.assertEqual(result.shape, tensor.shape) - self.assertEqual(result.dtype, dtype) - - print(f"Conversion with {dtype} successful") - - -if __name__ == "__main__": - result = unittest.main(exit=False) - sys.exit(0 if result.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py index cdba1b3ca73..0c0f2fad588 100644 --- a/test/spmd/test_xla_dtensor_to_local.py +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -14,62 +14,63 @@ class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): - """ + """ Test suite for the automatic conversion of regular tensors to XLAShardedTensor in DTensor.from_local() when using XLA device mesh. """ - @classmethod - def setUpClass(cls): - super().setUpClass() - - def test_to_local(self): - from torch.distributed.tensor import distribute_tensor - world_size = xr.global_runtime_device_count() - mesh = DeviceMesh("xla", list(range(world_size))) - - big_tensor = torch.randn(100000, 88) - sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)]) - - local_tensor = sharded_tensor.to_local() - - # Verify the shapes are the same - self.assertEqual(local_tensor.shape, big_tensor.shape) - - # Check the value of the tensor - torch.testing.assert_close(local_tensor, big_tensor, check_device=False) - - def test_to_local_requires_grad(self): - """Test that gradients flow correctly through to_local().""" - # Create a tensor with requires_grad=True - world_size = xr.global_runtime_device_count() - mesh = DeviceMesh("xla", list(range(world_size))) - - tensor = torch.randn(100_000, 88, requires_grad=True) - - # Create XLAShardedTensor - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) - - # Verify requires_grad is set - self.assertTrue(sharded_tensor.requires_grad) - - res = sharded_tensor.sum() - res.backward() - - # Verify grad are calculated - self.assertTrue(sharded_tensor.grad is not None) - - # Call to local function - local_tensor = sharded_tensor.to_local() - - # Verify requires_grad is preserved - self.assertTrue(local_tensor.requires_grad) - - # All gradients should be 1.0 since we did a sum() - self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) - - print("Gradient flow test successful") + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_to_local(self): + from torch.distributed.tensor import distribute_tensor + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + big_tensor = torch.randn(100000, 88) + sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)]) + + local_tensor = sharded_tensor.to_local() + + # Verify the shapes are the same + self.assertEqual(local_tensor.shape, big_tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(local_tensor, big_tensor, check_device=False) + + def test_to_local_requires_grad(self): + """Test that gradients flow correctly through to_local().""" + # Create a tensor with requires_grad=True + world_size = 
xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + + # Create XLAShardedTensor + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + + # Verify requires_grad is set + self.assertTrue(sharded_tensor.requires_grad) + + res = sharded_tensor.sum() + res.backward() + + # Verify grad are calculated + self.assertTrue(sharded_tensor.grad is not None) + + # Call to local function + local_tensor = sharded_tensor.to_local() + + # Verify requires_grad is preserved + self.assertTrue(local_tensor.requires_grad) + + # All gradients should be 1.0 since we did a sum() + self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) + + print("Gradient flow test successful") + if __name__ == "__main__": - result = unittest.main(exit=False) - sys.exit(0 if result.result.wasSuccessful() else 1) \ No newline at end of file + result = unittest.main(exit=False) + sys.exit(0 if result.result.wasSuccessful() else 1) From 04c5cd369dacf9cedcb21282645590a7400d666d Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Wed, 13 Aug 2025 00:24:15 +0000 Subject: [PATCH 08/13] Remove print statement --- test/spmd/test_xla_dtensor_to_local.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py index 0c0f2fad588..ad2197b9906 100644 --- a/test/spmd/test_xla_dtensor_to_local.py +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -68,8 +68,6 @@ def test_to_local_requires_grad(self): # All gradients should be 1.0 since we did a sum() self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) - print("Gradient flow test successful") - if __name__ == "__main__": result = unittest.main(exit=False) From 52a5e705287167be6157a5daefdbcd41cec5240f Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Wed, 13 Aug 2025 00:28:14 +0000 Subject: [PATCH 09/13] code clean up --- torch_xla/distributed/spmd/xla_sharded_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 4bbb4a0fd96..f314e4de0a2 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -9,11 +9,7 @@ import torch_xla.runtime as xr from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh -<<<<<<< HEAD from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial -======= -from torch.distributed.tensor.placement_types import Shard, Replicate ->>>>>>> 566959e10 (Removed auto wrapping sharding propagation, added cached spec invalidation) from torch.utils._pytree import tree_map_only from torch.distributed.tensor import DTensor From 4c7ffc285b9162abe1a3042c3bc4d630306c6312 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Thu, 14 Aug 2025 19:27:13 +0000 Subject: [PATCH 10/13] Remove redundant setUpClass constructor --- test/spmd/test_xla_dtensor_to_local.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py index ad2197b9906..b68fde4afb2 100644 --- a/test/spmd/test_xla_dtensor_to_local.py +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -19,10 +19,6 @@ class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): in DTensor.from_local() when using XLA device mesh. 
""" - @classmethod - def setUpClass(cls): - super().setUpClass() - def test_to_local(self): from torch.distributed.tensor import distribute_tensor world_size = xr.global_runtime_device_count() From bd9c9f37534c58b8e78db7ee17c4dd9bc421d7a4 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Fri, 15 Aug 2025 18:06:07 +0000 Subject: [PATCH 11/13] Clone the grads and use inplace method for requires_grad --- test/spmd/test_xla_dtensor_to_local.py | 43 +++++++++++++++++++ .../distributed/spmd/xla_sharded_tensor.py | 4 +- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py index b68fde4afb2..62e22485557 100644 --- a/test/spmd/test_xla_dtensor_to_local.py +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -64,6 +64,49 @@ def test_to_local_requires_grad(self): # All gradients should be 1.0 since we did a sum() self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor))) + def test_to_local_grad_independence(self): + """Test that gradients are independent between original and local tensor.""" + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + + # Create gradients + res = sharded_tensor.sum() + res.backward() + + # Get local tensor + local_tensor = sharded_tensor.to_local() + + # Verify gradients are initially the same + self.assertTrue(torch.allclose(local_tensor.grad, sharded_tensor.grad)) + + # Modify local tensor's gradient + local_tensor.grad[0, 0] = 999.0 + + # Verify gradients are now independent (not the same object) + self.assertFalse(local_tensor.grad is sharded_tensor.grad) + self.assertFalse(torch.allclose(local_tensor.grad, sharded_tensor.grad)) + + def test_to_local_grad_none_handling(self): + """Test that to_local() handles None gradients correctly.""" + world_size = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", list(range(world_size))) + + tensor = torch.randn(100_000, 88, requires_grad=True) + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + + # Don't do backward pass, so grad remains None + self.assertIsNone(sharded_tensor.grad) + + # Get local tensor + local_tensor = sharded_tensor.to_local() + + # Verify local tensor has correct properties + self.assertTrue(local_tensor.requires_grad) + self.assertIsNone(local_tensor.grad) + if __name__ == "__main__": result = unittest.main(exit=False) diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index f314e4de0a2..6c82d94af92 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -169,8 +169,8 @@ def to_local(self): result = self.global_tensor.clone() # Since global tensor is detached, add requires_grad and grad values back to the local tensor if self.requires_grad: - result.requires_grad = self.requires_grad - result.grad = self.grad + result.requires_grad_(self.requires_grad) + result.grad = self.grad.clone() if self.grad is not None else None return result From be7ab62eadece8844bf2b7bc7b8585adcb482dd9 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Mon, 18 Aug 2025 07:12:59 +0000 Subject: [PATCH 12/13] fix the failing CI by reverting to default requires_grad --- test/spmd/test_xla_dtensor_to_local.py | 6 ++--- .../distributed/spmd/xla_sharded_tensor.py | 22 ++++++++++++------- 2 files changed, 17 insertions(+), 11 
deletions(-) diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py index 62e22485557..f1741a92980 100644 --- a/test/spmd/test_xla_dtensor_to_local.py +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -44,7 +44,7 @@ def test_to_local_requires_grad(self): tensor = torch.randn(100_000, 88, requires_grad=True) # Create XLAShardedTensor - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) # Verify requires_grad is set self.assertTrue(sharded_tensor.requires_grad) @@ -70,7 +70,7 @@ def test_to_local_grad_independence(self): mesh = DeviceMesh("xla", list(range(world_size))) tensor = torch.randn(100_000, 88, requires_grad=True) - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) # Create gradients res = sharded_tensor.sum() @@ -95,7 +95,7 @@ def test_to_local_grad_none_handling(self): mesh = DeviceMesh("xla", list(range(world_size))) tensor = torch.randn(100_000, 88, requires_grad=True) - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)]) + sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) # Don't do backward pass, so grad remains None self.assertIsNone(sharded_tensor.grad) diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 6c82d94af92..20010862850 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -115,7 +115,7 @@ def __new__(cls, dtype=elem.dtype, layout=elem.layout, device=elem.device, - requires_grad=kwargs.get("requires_grad", elem.requires_grad)) + requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem # Initialize mesh, partition, and spec information @@ -165,14 +165,20 @@ def to_local(self): torch.Tensor: The global tensor representation with appropriate requires_grad setting. 
""" - # Create a new tensor with the same values of global_tensor - result = self.global_tensor.clone() - # Since global tensor is detached, add requires_grad and grad values back to the local tensor - if self.requires_grad: - result.requires_grad_(self.requires_grad) - result.grad = self.grad.clone() if self.grad is not None else None - return result + if not self.requires_grad: + # When requires_grad is False, global_tensor is the original tensor + return self.global_tensor + else: + # When requires_grad is True, global_tensor is detached + # Create a new tensor with the same values of global_tensor + result = self.global_tensor.clone() + # Since global tensor is detached, add requires_grad and grad values back to the local tensor + if self.requires_grad: + result.requires_grad_(self.requires_grad) + result.grad = self.grad.clone() if self.grad is not None else None + + return result @property def sharding_spec(self): From b378a462ac014302054e85c642abb1e16e32f1c6 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Mon, 18 Aug 2025 16:32:18 +0000 Subject: [PATCH 13/13] run yapf --- test/spmd/test_xla_dtensor_to_local.py | 9 ++++++--- torch_xla/distributed/spmd/xla_sharded_tensor.py | 1 - 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/spmd/test_xla_dtensor_to_local.py b/test/spmd/test_xla_dtensor_to_local.py index f1741a92980..2335720a027 100644 --- a/test/spmd/test_xla_dtensor_to_local.py +++ b/test/spmd/test_xla_dtensor_to_local.py @@ -44,7 +44,8 @@ def test_to_local_requires_grad(self): tensor = torch.randn(100_000, 88, requires_grad=True) # Create XLAShardedTensor - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) + sharded_tensor = XLAShardedTensor( + tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) # Verify requires_grad is set self.assertTrue(sharded_tensor.requires_grad) @@ -70,7 +71,8 @@ def test_to_local_grad_independence(self): mesh = DeviceMesh("xla", list(range(world_size))) tensor = torch.randn(100_000, 88, requires_grad=True) - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) + sharded_tensor = XLAShardedTensor( + tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) # Create gradients res = sharded_tensor.sum() @@ -95,7 +97,8 @@ def test_to_local_grad_none_handling(self): mesh = DeviceMesh("xla", list(range(world_size))) tensor = torch.randn(100_000, 88, requires_grad=True) - sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) + sharded_tensor = XLAShardedTensor( + tensor, mesh, [Shard(0)], requires_grad=tensor.requires_grad) # Don't do backward pass, so grad remains None self.assertIsNone(sharded_tensor.grad) diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 20010862850..1355111eeb6 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -165,7 +165,6 @@ def to_local(self): torch.Tensor: The global tensor representation with appropriate requires_grad setting. """ - if not self.requires_grad: # When requires_grad is False, global_tensor is the original tensor return self.global_tensor