class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
  """
  Test suite for the automatic conversion of regular tensors to XLAShardedTensor
  in DTensor.from_local() when using an XLA device mesh.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  def test_to_local(self):
    from torch.distributed.tensor import distribute_tensor
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))

    big_tensor = torch.randn(100000, 88)
    sharded_tensor = XLAShardedTensor(big_tensor, mesh, [Shard(0)])

    local_tensor = sharded_tensor.to_local()

    # Verify the shapes are the same
    self.assertEqual(local_tensor.shape, big_tensor.shape)

    # Check the values of the tensor
    torch.testing.assert_close(local_tensor, big_tensor, check_device=False)

  def test_to_local_requires_grad(self):
    """Test that gradients flow correctly through to_local()."""
    # Create a tensor with requires_grad=True
    world_size = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(world_size)))

    tensor = torch.randn(100_000, 88, requires_grad=True)

    # Create XLAShardedTensor
    sharded_tensor = XLAShardedTensor(tensor, mesh, [Shard(0)])

    # Verify requires_grad is set
    self.assertTrue(sharded_tensor.requires_grad)

    res = sharded_tensor.sum()
    res.backward()

    # Verify gradients were computed
    self.assertIsNotNone(sharded_tensor.grad)

    # Convert back to a local tensor
    local_tensor = sharded_tensor.to_local()

    # Verify requires_grad is preserved
    self.assertTrue(local_tensor.requires_grad)

    # All gradients should be 1.0 since we did a sum()
    self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))

    print("Gradient flow test successful")


if __name__ == "__main__":
  result = unittest.main(exit=False)
  sys.exit(0 if result.result.wasSuccessful() else 1)
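For context, here is a minimal standalone sketch of the from_local() conversion that the class docstring describes. It is not part of this test file: it assumes the public DTensor.from_local API and torch_xla's SPMD runtime helpers, and the expected return type is taken from the docstring rather than verified behavior.

import torch
import torch_xla.runtime as xr
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard

xr.use_spmd()  # enable SPMD execution mode, typically required for XLA sharding
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))

local = torch.randn(16, 8)
# With an "xla" mesh, from_local() is expected to wrap the plain tensor
# (per the docstring, as an XLAShardedTensor) instead of returning it unchanged.
converted = DTensor.from_local(local, mesh, [Shard(0)])
print(type(converted), converted.shape)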