From 9af66d899de68ebdf346b2de1192d00ba36e526b Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 15 Jul 2025 22:26:45 +0000 Subject: [PATCH 1/5] Implement XLAShardedTensor._spec and test --- test/spmd/test_xla_dtensor_spec_conversion.py | 148 ++++-------------- .../distributed/spmd/xla_sharded_tensor.py | 62 ++++---- torch_xla/distributed/spmd/xla_sharding.py | 25 +-- 3 files changed, 72 insertions(+), 163 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index 81cb8a4aa2e..12102f555c2 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -3,12 +3,9 @@ import torch from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor -from torch.distributed.tensor.placement_types import Replicate import torch_xla import torch_xla.runtime as xr -from torch_xla.distributed.spmd import XLAShardedTensor -from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor import unittest import test_xla_sharding_base @@ -34,6 +31,7 @@ def test_xla_to_dtensor_spec_conversion(self): mesh = DeviceMesh("xla", list(range(device_count))) # Test different sharding patterns + from torch.distributed.tensor.placement_types import Replicate test_cases = [ (torch.randn(100, 50), [Shard(0)]), (torch.randn(100, 50), [Shard(1)]), @@ -66,20 +64,30 @@ def test_mesh_conversion(self): assert converted_spec.mesh.shape == original_mesh.shape def test_spec_caching(self): - """Test that _spec property caches results - """ + """Test that _spec property caches results for better performance""" + import time device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) - tensor = torch.randn(100, 100) + tensor = torch.randn(1000, + 1000) # Large tensor to make spec creation noticeable xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + # first access should create and cache the spec + start_time = time.time() spec1 = xla_tensor._spec + first_access_time = time.time() - start_time - assert xla_tensor._cached_spec is not None - assert xla_tensor._cached_spec is spec1 - + # should be much faster due to caching + start_time = time.time() spec2 = xla_tensor._spec + second_access_time = time.time() - start_time + assert spec1 is spec2 + print( + f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s" + ) + assert second_access_time * 10 < first_access_time, \ + f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s" def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): """Helper to create tensor and mesh for testing""" @@ -106,8 +114,22 @@ def test_multi_dim_sharding_spec(self): assert len(spec.placements) == 2 assert spec.mesh.ndim == 2 + def test_tensor_operations_preserve_spec(self): + """Test that tensor operations preserve sharding metadata""" + xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), + [Shard(0)]) + + result_add = xla_tensor + 1 + result_mul = xla_tensor * 2 + result_relu = torch.relu(xla_tensor) + + for result in [result_add, result_mul, result_relu]: + assert hasattr(result, '_spec') + assert result._spec.mesh.device_type == "xla" + def test_mixed_placement_spec(self): """Test _spec for tensors with mixed shard/replicate placements""" + from torch.distributed.tensor.placement_types import Replicate device_count = xr.global_runtime_device_count() if device_count < 4: self.skipTest("Need at least 4 devices for 2D 
mesh") @@ -121,114 +143,6 @@ def test_mixed_placement_spec(self): assert isinstance(spec.placements[0], Shard) assert isinstance(spec.placements[1], Replicate) - def test_sharding_info_acquisition(self): - """Test that non-XLAShardedTensor can acquire sharding information - - Tests case of 'elem is not an XLAShardedTensor but there exists - sharding information we want to acquire' - """ - - device_count = xr.global_runtime_device_count() - mesh_shape = (device_count,) - partition_spec = (0, None) - - regular_tensor = torch.randn(100, 50).to('xla') - - sharded_tensor = wrap_as_sharded_tensor( - regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) - - # Verify the tensor acquired the sharding information - assert isinstance(sharded_tensor, XLAShardedTensor) - assert sharded_tensor.mesh_shape == mesh_shape - assert sharded_tensor.partition_spec == partition_spec - - def test_resharding_logic(self): - """ - Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. - """ - - device_count = xr.global_runtime_device_count() - if device_count < 4: - self.skipTest("Need at least 4 devices for resharding test") - - # Initial sharding - initial_mesh_shape = (device_count,) - initial_partition_spec = (0, None) - new_mesh_shape = (2, device_count // 2) - new_partition_spec = (0, 1) - - # Create tensor and verify resharding - tensor = torch.randn(100, 50).to('xla') - sharded_tensor = wrap_as_sharded_tensor( - tensor, - mesh_shape=initial_mesh_shape, - partition_spec=initial_partition_spec) - initial_spec = sharded_tensor._spec - - resharded_tensor = wrap_as_sharded_tensor( - sharded_tensor, - mesh_shape=new_mesh_shape, - partition_spec=new_partition_spec) - - # Verify resharding worked and cache was invalidated - assert resharded_tensor.mesh_shape == new_mesh_shape - assert resharded_tensor.partition_spec == new_partition_spec - assert resharded_tensor._spec is not initial_spec - - def test_spec_invalidation_on_resharding(self): - """Tests cases where the cached spec may become outdated. - """ - - device_count = xr.global_runtime_device_count() - if device_count < 4: - self.skipTest("Need at least 4 devices for resharding test") - - tensor = torch.randn(100, 50).to('xla') - initial_mesh_shape = (device_count,) - initial_partition_spec = (0, None) - new_mesh_shape = (2, device_count // 2) - new_partition_spec = (0, 1) - - sharded_tensor = wrap_as_sharded_tensor( - tensor, - mesh_shape=initial_mesh_shape, - partition_spec=initial_partition_spec) - initial_spec = sharded_tensor._spec - assert sharded_tensor._cached_spec is not None - - # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache - resharded_tensor = wrap_as_sharded_tensor( - sharded_tensor, - mesh_shape=new_mesh_shape, - partition_spec=initial_partition_spec) - assert resharded_tensor._spec is not initial_spec - assert resharded_tensor._spec.mesh.shape == new_mesh_shape - - initial_spec = resharded_tensor._spec - resharded_tensor = wrap_as_sharded_tensor( - resharded_tensor, - mesh_shape=new_mesh_shape, - partition_spec=new_partition_spec) - assert resharded_tensor._spec is not initial_spec - assert resharded_tensor._spec.placements[1].dim == 1 - - def test_auto_wrapped_tensor_spec_failure(self): - """Test that auto-wrapped tensors fail when accessing _spec property. - - Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ - but don't yet have access to the sharding propagation done through open xla, - causing ._spec to fail. 
- """ - device_count = xr.global_runtime_device_count() - mesh = DeviceMesh("xla", torch.arange(device_count)) - tensor = torch.randn(4, 4) - sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - - auto_wrapped = sharded_tensor + sharded_tensor - - with self.assertRaises(ValueError): - _ = auto_wrapped._spec - if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index a20d530f3fa..314dea09591 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -10,7 +10,6 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import Shard, Replicate -from torch.utils._pytree import tree_map_only @dataclass @@ -116,13 +115,11 @@ def __new__(cls, device=elem.device, requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem - - # Initialize mesh, partition, and spec information - r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance( - elem, XLAShardedTensor) else None) - r.partition_spec = partition_spec or (elem.partition_spec if isinstance( - elem, XLAShardedTensor) else None) - r._cached_spec = None + # Store mesh and partition information for DTensor compatibility + if mesh_shape is not None: + r.mesh_shape = mesh_shape + if partition_spec is not None: + r.partition_spec = partition_spec return r # Shards on the devices are materialized/available after the lazy @@ -179,7 +176,27 @@ def unwrap(elem): return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem def wrap(elem): - return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem + if isinstance(elem, + torch.Tensor) and not isinstance(elem, XLAShardedTensor): + # Try to get mesh/partition info from any XLAShardedTensor in args + mesh_shape = None + partition_spec = None + + def find_sharded_info(x): + nonlocal mesh_shape, partition_spec + if isinstance(x, XLAShardedTensor): + if hasattr(x, 'mesh_shape') and x.mesh_shape: + mesh_shape = x.mesh_shape + if hasattr(x, 'partition_spec') and x.partition_spec: + partition_spec = x.partition_spec + + tree_map(find_sharded_info, args) + if kwargs: + tree_map(find_sharded_info, kwargs) + + return XLAShardedTensor( + elem, mesh_shape=mesh_shape, partition_spec=partition_spec) + return elem # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. @@ -195,26 +212,25 @@ def _spec(self): Convert XLA sharding information to DTensorSpec for DTensor interface compatibility. """ # Return cached spec if available - if self._cached_spec is not None: + if hasattr(self, '_cached_spec'): return self._cached_spec # use existing mesh_shape - if self.mesh_shape is not None: + if hasattr(self, 'mesh_shape') and self.mesh_shape: + import torch_xla.runtime as xr device_count = xr.global_runtime_device_count() device_list = list(range(device_count)) mesh = DeviceMesh("xla", torch.tensor(device_list).reshape(self.mesh_shape)) else: - raise ValueError( - "mesh_shape must be specified to create DTensorSpec. " - "If this tensor was created through torch operations, it may be auto-wrapped. " - "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. 
" - ) + raise ValueError("mesh_shape must be specified to create DTensorSpec") # use existing partition_spec - if self.partition_spec is not None: + if hasattr(self, 'partition_spec') and self.partition_spec: placements = [] - for mesh_dim in range(len(self.mesh_shape)): + for mesh_dim in range( + len(self.mesh_shape + ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1): # find tensor dimension sharded on this mesh dimension tensor_dim = None for t_dim, m_dim in enumerate(self.partition_spec): @@ -224,11 +240,7 @@ def _spec(self): placements.append( Shard(tensor_dim) if tensor_dim is not None else Replicate()) else: - raise ValueError( - "partition_spec must be specified to create DTensorSpec. " - "If this tensor was created through torch operations, it may be auto-wrapped. " - "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. " - ) + raise ValueError("partition_spec must be specified to create DTensorSpec") # tensor metadata tensor_meta = TensorMeta( @@ -241,10 +253,6 @@ def _spec(self): mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) return self._cached_spec - def invalidate_spec_cache(self): - """Invalidate the cached DTensorSpec.""" - self._cached_spec = None - @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 5f4d4378e7d..751fe7e9a66 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -765,27 +765,14 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], partition_spec=None) -> XLAShardedTensor: # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): - # Create a new XLAShardedTensor return XLAShardedTensor( t, mesh_shape=mesh_shape, partition_spec=partition_spec) - - # Update existing XLAShardedTensor if needed - needs_invalidate = False - - # Always set mesh_shape and partition_spec if provided - if mesh_shape is not None: - t.mesh_shape = mesh_shape - needs_invalidate = True - - if partition_spec is not None: - t.partition_spec = partition_spec - needs_invalidate = True - - # Invalidate cached spec if resharding occurred - if needs_invalidate: - t.invalidate_spec_cache() - - return t + else: + if mesh_shape is not None: + t.mesh_shape = mesh_shape + if partition_spec is not None: + t.partition_spec = partition_spec + return t def unwrap_sharded_tensor( From 6c2de4d9b6fe13db145b442e820e0f055854f3a6 Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 18:35:54 +0000 Subject: [PATCH 2/5] Removed auto wrapping sharding propagation, added cached spec invalidation --- test/spmd/test_xla_dtensor_spec_conversion.py | 138 ++++++++++++++---- .../distributed/spmd/xla_sharded_tensor.py | 49 +++---- torch_xla/distributed/spmd/xla_sharding.py | 25 +++- 3 files changed, 143 insertions(+), 69 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index 12102f555c2..2fe53613e94 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -3,9 +3,12 @@ import torch from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor +from torch.distributed.tensor.placement_types import Replicate import torch_xla import torch_xla.runtime as xr +from torch_xla.distributed.spmd import XLAShardedTensor +from 
torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor import unittest import test_xla_sharding_base @@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self): mesh = DeviceMesh("xla", list(range(device_count))) # Test different sharding patterns - from torch.distributed.tensor.placement_types import Replicate test_cases = [ (torch.randn(100, 50), [Shard(0)]), (torch.randn(100, 50), [Shard(1)]), @@ -64,30 +66,27 @@ def test_mesh_conversion(self): assert converted_spec.mesh.shape == original_mesh.shape def test_spec_caching(self): - """Test that _spec property caches results for better performance""" - import time + """Test that _spec property caches results + + Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to + annoying flakes in my experience. I think it's sufficient to just test that + self._cached_spec has a permanent value after the first call." + """ device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) - tensor = torch.randn(1000, - 1000) # Large tensor to make spec creation noticeable + tensor = torch.randn(100, 100) xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - # first access should create and cache the spec - start_time = time.time() + # First access should create and cache the spec spec1 = xla_tensor._spec - first_access_time = time.time() - start_time - # should be much faster due to caching - start_time = time.time() - spec2 = xla_tensor._spec - second_access_time = time.time() - start_time + # Verify the spec is cached + assert xla_tensor._cached_spec is not None + assert xla_tensor._cached_spec is spec1 + # Second access should return the cached spec + spec2 = xla_tensor._spec assert spec1 is spec2 - print( - f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s" - ) - assert second_access_time * 10 < first_access_time, \ - f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s" def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): """Helper to create tensor and mesh for testing""" @@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self): assert len(spec.placements) == 2 assert spec.mesh.ndim == 2 - def test_tensor_operations_preserve_spec(self): - """Test that tensor operations preserve sharding metadata""" - xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), - [Shard(0)]) - - result_add = xla_tensor + 1 - result_mul = xla_tensor * 2 - result_relu = torch.relu(xla_tensor) - - for result in [result_add, result_mul, result_relu]: - assert hasattr(result, '_spec') - assert result._spec.mesh.device_type == "xla" - def test_mixed_placement_spec(self): """Test _spec for tensors with mixed shard/replicate placements""" - from torch.distributed.tensor.placement_types import Replicate device_count = xr.global_runtime_device_count() if device_count < 4: self.skipTest("Need at least 4 devices for 2D mesh") @@ -143,6 +128,97 @@ def test_mixed_placement_spec(self): assert isinstance(spec.placements[0], Shard) assert isinstance(spec.placements[1], Replicate) + def test_sharding_info_acquisition(self): + """Test that non-XLAShardedTensor can acquire sharding information + + Tests case of 'elem is not an XLAShardedTensor but there exists + sharding information we want to acquire' + """ + + device_count = xr.global_runtime_device_count() + mesh_shape = (device_count,) + partition_spec = (0, None) + + regular_tensor = torch.randn(100, 50).to('xla') + + 
sharded_tensor = wrap_as_sharded_tensor( + regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec) + + # Verify the tensor acquired the sharding information + assert isinstance(sharded_tensor, XLAShardedTensor) + assert sharded_tensor.mesh_shape == mesh_shape + assert sharded_tensor.partition_spec == partition_spec + + def test_resharding_logic(self): + """ + Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + # Initial sharding + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + # Create tensor and verify resharding + tensor = torch.randn(100, 50).to('xla') + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + + # Verify resharding worked and cache was invalidated + assert resharded_tensor.mesh_shape == new_mesh_shape + assert resharded_tensor.partition_spec == new_partition_spec + assert resharded_tensor._spec is not initial_spec + + def test_spec_invalidation_on_resharding(self): + """Tests cases where the cached spec may become outdated. + """ + + device_count = xr.global_runtime_device_count() + if device_count < 4: + self.skipTest("Need at least 4 devices for resharding test") + + tensor = torch.randn(100, 50).to('xla') + initial_mesh_shape = (device_count,) + initial_partition_spec = (0, None) + new_mesh_shape = (2, device_count // 2) + new_partition_spec = (0, 1) + + sharded_tensor = wrap_as_sharded_tensor( + tensor, + mesh_shape=initial_mesh_shape, + partition_spec=initial_partition_spec) + initial_spec = sharded_tensor._spec + assert sharded_tensor._cached_spec is not None + + # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache + resharded_tensor = wrap_as_sharded_tensor( + sharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=initial_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.mesh.shape == new_mesh_shape + + initial_spec = resharded_tensor._spec + resharded_tensor = wrap_as_sharded_tensor( + resharded_tensor, + mesh_shape=new_mesh_shape, + partition_spec=new_partition_spec) + assert resharded_tensor._spec is not initial_spec + assert resharded_tensor._spec.placements[1].dim == 1 + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 314dea09591..3e90c4467af 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -10,6 +10,7 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import Shard, Replicate +from torch.utils._pytree import tree_map_only @dataclass @@ -115,11 +116,13 @@ def __new__(cls, device=elem.device, requires_grad=kwargs.get("requires_grad", False)) r.global_tensor = elem.detach() if r.requires_grad else elem - # Store mesh and partition information for DTensor compatibility - if mesh_shape is not None: - r.mesh_shape = mesh_shape - if 
partition_spec is not None: - r.partition_spec = partition_spec + + # Initialize mesh, partition, and spec information + r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance( + elem, XLAShardedTensor) else None) + r.partition_spec = partition_spec or (elem.partition_spec if isinstance( + elem, XLAShardedTensor) else None) + r._cached_spec = None return r # Shards on the devices are materialized/available after the lazy @@ -176,27 +179,7 @@ def unwrap(elem): return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem def wrap(elem): - if isinstance(elem, - torch.Tensor) and not isinstance(elem, XLAShardedTensor): - # Try to get mesh/partition info from any XLAShardedTensor in args - mesh_shape = None - partition_spec = None - - def find_sharded_info(x): - nonlocal mesh_shape, partition_spec - if isinstance(x, XLAShardedTensor): - if hasattr(x, 'mesh_shape') and x.mesh_shape: - mesh_shape = x.mesh_shape - if hasattr(x, 'partition_spec') and x.partition_spec: - partition_spec = x.partition_spec - - tree_map(find_sharded_info, args) - if kwargs: - tree_map(find_sharded_info, kwargs) - - return XLAShardedTensor( - elem, mesh_shape=mesh_shape, partition_spec=partition_spec) - return elem + return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. @@ -212,11 +195,11 @@ def _spec(self): Convert XLA sharding information to DTensorSpec for DTensor interface compatibility. """ # Return cached spec if available - if hasattr(self, '_cached_spec'): + if self._cached_spec is not None: return self._cached_spec # use existing mesh_shape - if hasattr(self, 'mesh_shape') and self.mesh_shape: + if self.mesh_shape is not None: import torch_xla.runtime as xr device_count = xr.global_runtime_device_count() device_list = list(range(device_count)) @@ -226,11 +209,9 @@ def _spec(self): raise ValueError("mesh_shape must be specified to create DTensorSpec") # use existing partition_spec - if hasattr(self, 'partition_spec') and self.partition_spec: + if self.partition_spec is not None: placements = [] - for mesh_dim in range( - len(self.mesh_shape - ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1): + for mesh_dim in range(len(self.mesh_shape)): # find tensor dimension sharded on this mesh dimension tensor_dim = None for t_dim, m_dim in enumerate(self.partition_spec): @@ -253,6 +234,10 @@ def _spec(self): mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta) return self._cached_spec + def invalidate_spec_cache(self): + """Invalidate the cached DTensorSpec.""" + self._cached_spec = None + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): return super().__torch_function__(func, types, args, kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 751fe7e9a66..5f4d4378e7d 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -765,14 +765,27 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor], partition_spec=None) -> XLAShardedTensor: # pass along mesh and partition spec information if not isinstance(t, XLAShardedTensor): + # Create a new XLAShardedTensor return XLAShardedTensor( t, mesh_shape=mesh_shape, partition_spec=partition_spec) - else: - if mesh_shape is not None: - t.mesh_shape = mesh_shape - if partition_spec is not None: - t.partition_spec = partition_spec - return t + + # Update existing XLAShardedTensor 
if needed + needs_invalidate = False + + # Always set mesh_shape and partition_spec if provided + if mesh_shape is not None: + t.mesh_shape = mesh_shape + needs_invalidate = True + + if partition_spec is not None: + t.partition_spec = partition_spec + needs_invalidate = True + + # Invalidate cached spec if resharding occurred + if needs_invalidate: + t.invalidate_spec_cache() + + return t def unwrap_sharded_tensor( From ef5f9ad0ce7140faf3e7db9b7e5422a267e10db1 Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 18:40:12 +0000 Subject: [PATCH 3/5] Removing lazy import --- test/spmd/test_xla_dtensor_spec_conversion.py | 7 ------- torch_xla/distributed/spmd/xla_sharded_tensor.py | 1 - 2 files changed, 8 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index 2fe53613e94..e48a323bd15 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -67,24 +67,17 @@ def test_mesh_conversion(self): def test_spec_caching(self): """Test that _spec property caches results - - Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to - annoying flakes in my experience. I think it's sufficient to just test that - self._cached_spec has a permanent value after the first call." """ device_count = xr.global_runtime_device_count() mesh = DeviceMesh("xla", list(range(device_count))) tensor = torch.randn(100, 100) xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) - # First access should create and cache the spec spec1 = xla_tensor._spec - # Verify the spec is cached assert xla_tensor._cached_spec is not None assert xla_tensor._cached_spec is spec1 - # Second access should return the cached spec spec2 = xla_tensor._spec assert spec1 is spec2 diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 3e90c4467af..ed590e22fd8 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -200,7 +200,6 @@ def _spec(self): # use existing mesh_shape if self.mesh_shape is not None: - import torch_xla.runtime as xr device_count = xr.global_runtime_device_count() device_list = list(range(device_count)) mesh = DeviceMesh("xla", From b395d4cdcdf0c470cbbb3176715ce58429f89b1d Mon Sep 17 00:00:00 2001 From: Claire Huang Date: Tue, 22 Jul 2025 20:43:48 +0000 Subject: [PATCH 4/5] Added test for catching thrown error in spec --- test/spmd/test_xla_dtensor_spec_conversion.py | 17 +++++++++++++++++ .../distributed/spmd/xla_sharded_tensor.py | 12 ++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_dtensor_spec_conversion.py b/test/spmd/test_xla_dtensor_spec_conversion.py index e48a323bd15..81cb8a4aa2e 100644 --- a/test/spmd/test_xla_dtensor_spec_conversion.py +++ b/test/spmd/test_xla_dtensor_spec_conversion.py @@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self): assert resharded_tensor._spec is not initial_spec assert resharded_tensor._spec.placements[1].dim == 1 + def test_auto_wrapped_tensor_spec_failure(self): + """Test that auto-wrapped tensors fail when accessing _spec property. + + Auto-wrapped tensors are created through operations that trigger __torch_dispatch__ + but don't yet have access to the sharding propagation done through open xla, + causing ._spec to fail. 
+ """ + device_count = xr.global_runtime_device_count() + mesh = DeviceMesh("xla", torch.arange(device_count)) + tensor = torch.randn(4, 4) + sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) + + auto_wrapped = sharded_tensor + sharded_tensor + + with self.assertRaises(ValueError): + _ = auto_wrapped._spec + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index ed590e22fd8..a20d530f3fa 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -205,7 +205,11 @@ def _spec(self): mesh = DeviceMesh("xla", torch.tensor(device_list).reshape(self.mesh_shape)) else: - raise ValueError("mesh_shape must be specified to create DTensorSpec") + raise ValueError( + "mesh_shape must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. " + ) # use existing partition_spec if self.partition_spec is not None: @@ -220,7 +224,11 @@ def _spec(self): placements.append( Shard(tensor_dim) if tensor_dim is not None else Replicate()) else: - raise ValueError("partition_spec must be specified to create DTensorSpec") + raise ValueError( + "partition_spec must be specified to create DTensorSpec. " + "If this tensor was created through torch operations, it may be auto-wrapped. " + "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. " + ) # tensor metadata tensor_meta = TensorMeta( From 7e8003f108a222183ea1207da6df7c826b8bb903 Mon Sep 17 00:00:00 2001 From: Hooman Hashemi Date: Wed, 23 Jul 2025 23:29:03 +0000 Subject: [PATCH 5/5] Test for Routing XLA device handling through distribute_tensor to ensure proper XLA support and maintain consistency with PyTorch/XLA SPMD integration. 
--- test/neuron/run_tests.sh | 1 + test/run_tests.sh | 1 + test/spmd/test_xla_dtensor_from_local.py | 149 +++++++++++++++++++++++ test/tpu/run_tests.sh | 1 + 4 files changed, 152 insertions(+) create mode 100644 test/spmd/test_xla_dtensor_from_local.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index f7671cc3d82..3472a073ea3 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -256,6 +256,7 @@ function run_xla_op_tests3 { #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py" run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py" + run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py" diff --git a/test/run_tests.sh b/test/run_tests.sh index b2cc8f751d2..dabcdb83961 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -255,6 +255,7 @@ function run_xla_op_tests3 { run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py" run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" + run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py" run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py" diff --git a/test/spmd/test_xla_dtensor_from_local.py b/test/spmd/test_xla_dtensor_from_local.py new file mode 100644 index 00000000000..40647e20590 --- /dev/null +++ b/test/spmd/test_xla_dtensor_from_local.py @@ -0,0 +1,149 @@ +import sys +import unittest +import torch +import numpy as np + +from torch.distributed.tensor import DeviceMesh +from torch.distributed._tensor import DTensor +from torch.distributed.tensor.placement_types import Replicate, Shard +import torch_xla +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor +import test_xla_sharding_base + + +class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest): + """ + Test suite for the automatic conversion of regular tensors to XLAShardedTensor + in DTensor.from_local() when using XLA device mesh. 
+ """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + def test_basic_conversion(self): + """Test basic conversion of regular tensor to XLAShardedTensor.""" + world_size = xr.global_runtime_device_count() + + # Create a regular tensor (not on XLA device) + tensor = torch.randn(100_000, 88) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Create a DeviceMesh + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Use DTensor.from_local with the regular tensor + dt = DTensor.from_local(tensor, device_mesh=device_mesh) + + # Verify the tensor was converted correctly + self.assertEqual(dt.shape, tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + result = dt + 1.0 + self.assertEqual(result.shape, tensor.shape) + + print("Basic conversion successful") + + + def test_conversion_with_placements(self): + """Test conversion with explicit placements.""" + world_size = xr.global_runtime_device_count() + + # Create a regular tensor (not on XLA device) + tensor = torch.randn(100_000, 88) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Create a DeviceMesh + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Use DTensor.from_local with explicit placements + dt = DTensor.from_local( + tensor, + device_mesh=device_mesh, + placements=[Replicate()] + ) + + # Verify the tensor was converted correctly + self.assertEqual(dt.shape, tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + result = dt + 1.0 + self.assertEqual(result.shape, tensor.shape) + + print("Conversion with placements successful") + + def test_conversion_with_sharding(self): + """Test conversion with sharding placement.""" + world_size = xr.global_runtime_device_count() + if world_size < 2: + self.skipTest("Need at least 2 devices for sharding test") + + # Create a tensor divisible by world_size + tensor = torch.randn(100_000, 88) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Create a DeviceMesh + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Use DTensor.from_local with sharding placement + dt = DTensor.from_local( + tensor, + device_mesh=device_mesh, + placements=[Shard(0)] + ) + + # Verify the tensor was converted correctly + self.assertEqual(dt.shape, tensor.shape) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + result = dt + 1.0 + self.assertEqual(result.shape, tensor.shape) + + print("Conversion with sharding successful") + + def test_conversion_with_different_dtypes(self): + """Test conversion with different dtypes.""" + world_size = xr.global_runtime_device_count() + device_mesh = DeviceMesh("xla", list(range(world_size))) + + # Test with different dtypes + for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]: + # Create a tensor with specific dtype + tensor = torch.ones(100_000, 88, dtype=dtype) + tensor_cpu = tensor.cpu() # Keep a CPU copy for comparison + + # Use DTensor.from_local with the tensor + dt = DTensor.from_local(tensor, device_mesh=device_mesh) + + # Verify dtype is preserved + self.assertEqual(dt.dtype, dtype) + + # Check the value of the tensor + torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False) + + # Verify operations work + if dtype.is_floating_point: + 
result = dt + 1.0 + else: + result = dt + 1 + + self.assertEqual(result.shape, tensor.shape) + self.assertEqual(result.dtype, dtype) + + print(f"Conversion with {dtype} successful") + + +if __name__ == "__main__": + result = unittest.main(exit=False) + sys.exit(0 if result.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 24f18d3bdcd..ec585716cb6 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -62,6 +62,7 @@ run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py" run_test "$_TEST_DIR/spmd/test_fsdp_v2.py" run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py" run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py" +run_test "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py" run_test "$_TEST_DIR/test_gradient_accumulation.py" XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v run_test "$_TEST_DIR/test_autocast.py"
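
Usage note (illustrative only, not part of the patches): the sketch below shows, assuming the series applies as posted, how the `_spec` property and the `wrap_as_sharded_tensor` resharding path exercised by the new tests are expected to be used. The variable names (`xla_tensor`, `resharded`) and shapes are hypothetical; the mesh/partition values mirror those in test_resharding_logic and assume at least 4 devices.

    import torch
    import torch_xla.runtime as xr
    from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
    from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))

    # distribute_tensor returns an XLAShardedTensor; the first _spec access
    # builds a DTensorSpec from mesh_shape/partition_spec and caches it.
    xla_tensor = distribute_tensor(torch.randn(100, 50), mesh, [Shard(0)])
    spec = xla_tensor._spec             # stored in _cached_spec
    assert spec is xla_tensor._spec     # second access returns the cached object

    # Re-wrapping with a new mesh_shape / partition_spec invalidates the cached
    # spec, so the next _spec access reflects the new sharding (needs >= 4 devices).
    resharded = wrap_as_sharded_tensor(
        xla_tensor, mesh_shape=(2, device_count // 2), partition_spec=(0, 1))
    assert resharded._spec is not spec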