Commit 080e7d5 (parent: e768a54)

run git_fix for yapf

File tree: 2 files changed (+54, -202 lines)

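The commit message indicates these changes came from running yapf over the test files. As a rough illustration only, the snippet below sketches how such a pass could be reproduced with yapf's Python API on the surviving test file; the use of FormatFile and the ".style.yapf" path are assumptions about the local setup, not something this commit specifies.

# Sketch: re-run yapf on the remaining SPMD test file (assumes yapf is installed).
# ".style.yapf" is assumed to be the repository's style config; omit style_config
# to fall back to yapf's default style.
from yapf.yapflib.yapf_api import FormatFile

FormatFile(
    "test/spmd/test_xla_dtensor_to_local.py",
    in_place=True,
    style_config=".style.yapf")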

test/spmd/test_xla_dtensor_from_local.py

Lines changed: 0 additions & 149 deletions
This file was deleted.

test/spmd/test_xla_dtensor_to_local.py

Lines changed: 54 additions & 53 deletions
@@ -14,62 +14,63 @@

Both sides of this hunk contain the same statements; the difference is the whitespace reformatting from the yapf run named in the commit message. The statements covered by the hunk (new lines 14-76):

class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
  """
  Test suite for the automatic conversion of regular tensors to XLAShardedTensor
  in DTensor.from_local() when using XLA device mesh.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  def test_to_local(self):
    from torch.distributed.tensor import distribute_tensor
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))

    big_tensor = torch.randn(100000, 88)
    sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])

    local_tensor = sharded_tensor.to_local()

    # Verify the shapes are the same
    self.assertEqual(local_tensor.shape, big_tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(local_tensor, big_tensor, check_device=False)

  def test_to_local_requires_grad(self):
    """Test that gradients flow correctly through to_local()."""
    # Create a tensor with requires_grad=True
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))

    tensor = torch.randn(100_000, 88, requires_grad=True)

    # Create XLAShardedTensor
    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])

    # Verify requires_grad is set
    self.assertTrue(sharded_tensor.requires_grad)

    res = sharded_tensor.sum()
    res.backward()

    # Verify grad are calculated
    self.assertTrue(sharded_tensor.grad is not None)

    # Call to local function
    local_tensor = sharded_tensor.to_local()

    # Verify requires_grad is preserved
    self.assertTrue(local_tensor.requires_grad)

    # All gradients should be 1.0 since we did a sum()
    self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))

    print("Gradient flow test successful")


if __name__ == "__main__":
  result = unittest.main(exit=False)
  sys.exit(0 if result.result.wasSuccessful() else 1)
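As a usage note, the test file is self-executing through its __main__ block. A minimal sketch of driving it programmatically on a single host follows; the PJRT_DEVICE=CPU setting and the relative path are assumptions about a typical local PyTorch/XLA checkout, not something this commit establishes.

# Sketch: run the reformatted test in a subprocess (assumes PyTorch/XLA is installed
# and the working directory is the repository root).
import os
import subprocess
import sys

env = dict(os.environ, PJRT_DEVICE="CPU")  # assumed backend selection for a local run
subprocess.run(
    [sys.executable, "test/spmd/test_xla_dtensor_to_local.py"], env=env, check=True)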
