Skip to content

Commit e768a54

Browse files
committed
[XLA] Implement XLAShardedTensor.to_local()
1 parent 91e49ee commit e768a54

File tree

4 files changed

+78
-3
lines changed

4 files changed

+78
-3
lines changed

test/neuron/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ function run_xla_op_tests3 {
257257
run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
258258
run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
259259
run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
260-
run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
260+
run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
261261
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
262262
#run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
263263
run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ function run_xla_op_tests3 {
257257
run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
258258
run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
259259
run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
260-
run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
260+
run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
261261
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
262262
run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
263263
run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_xla_dtensor_to_local.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import sys
2+
import unittest
3+
import torch
4+
import numpy as np
5+
6+
from torch.distributed.tensor import DeviceMesh
7+
from torch.distributed._tensor import DTensor
8+
from torch.distributed.tensor.placement_types import Replicate, Shard
9+
import torch_xla
10+
import torch_xla.runtime as xr
11+
import torch_xla.core.xla_model as xm
12+
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
13+
import test_xla_sharding_base
14+
15+
16+
class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
    """Test suite for ``XLAShardedTensor.to_local()``.

    The tests wrap a CPU-created tensor in an ``XLAShardedTensor`` sharded
    along dim 0 over a 1-D "xla" device mesh, then check what ``to_local()``
    hands back (shape, values, and gradient state).

    NOTE: the class name still says "FromLocal"; it is kept unchanged so that
    any existing test selectors referring to it keep working.
    """

    @staticmethod
    def _make_mesh():
        """Build a 1-D "xla" DeviceMesh spanning every runtime device."""
        world_size = xr.global_runtime_device_count()
        return DeviceMesh("xla", list(range(world_size)))

    def test_to_local(self):
        """to_local() returns a tensor matching the original shape and values."""
        mesh = self._make_mesh()

        big_tensor = torch.randn(100000, 88)
        sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])

        local_tensor = sharded_tensor.to_local()

        # Verify the shapes are the same
        self.assertEqual(local_tensor.shape, big_tensor.shape)

        # Check the value of the tensor; the device is deliberately ignored
        # because the source tensor was created with plain torch.randn().
        torch.testing.assert_close(local_tensor, big_tensor, check_device=False)

    def test_to_local_requires_grad(self):
        """Test that gradients flow correctly through to_local()."""
        mesh = self._make_mesh()

        # Create a tensor with requires_grad=True
        tensor = torch.randn(100_000, 88, requires_grad=True)

        # Create XLAShardedTensor
        sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])

        # Verify requires_grad is set
        self.assertTrue(sharded_tensor.requires_grad)

        res = sharded_tensor.sum()
        res.backward()

        # Verify grad are calculated
        self.assertIsNotNone(sharded_tensor.grad)

        # Call to local function
        local_tensor = sharded_tensor.to_local()

        # Verify requires_grad is preserved
        self.assertTrue(local_tensor.requires_grad)

        # All gradients should be 1.0 since we did a sum()
        self.assertTrue(
            torch.allclose(local_tensor.grad, torch.ones_like(tensor)))

        print("Gradient flow test successful")
72+
73+
if __name__ == "__main__":
    # Run the suite with exit=False so unittest does not call sys.exit()
    # itself; the process exit status is then set explicitly below.
    test_program = unittest.main(exit=False)
    exit_code = 0 if test_program.result.wasSuccessful() else 1
    sys.exit(exit_code)

test/tpu/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
6363
run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
6464
run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
6565
run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
66-
run_test "$_TEST_DIR/spmd/test_xla_dtensor_from_local.py"
66+
run_test "$_TEST_DIR/spmd/test_xla_dtensor_to_local.py"
6767
run_test "$_TEST_DIR/test_gradient_accumulation.py"
6868
XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
6969
run_test "$_TEST_DIR/test_autocast.py"

0 commit comments

Comments
 (0)