Skip to content

Commit dbda5ba

Browse files
committed
Add dtensor mesh conversion test
1 parent cc15111 commit dbda5ba

File tree

4 files changed

+117
-0
lines changed

4 files changed

+117
-0
lines changed

test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ function run_xla_op_tests3 {
246246
run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
247247
#run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
248248
#run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
249+
run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
249250
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
250251
#run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
251252
run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ function run_xla_op_tests3 {
253253
run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
254254
run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
255255
run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
256+
run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
256257
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
257258
run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
258259
run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import sys
2+
import unittest
3+
import torch
4+
from torch.distributed.tensor import DeviceMesh, init_device_mesh
5+
import torch_xla.runtime as xr
6+
from torch_xla.distributed.spmd import Mesh
7+
from torch_xla.distributed.spmd.api import convert_to_xla_mesh
8+
import test_xla_sharding_base
9+
10+
11+
class ConvertToXlaMeshIntegrationTest(test_xla_sharding_base.XlaShardingTest):
  """Tests for converting a DTensor ``DeviceMesh`` to a torch_xla SPMD ``Mesh``.

  Each test builds a ``torch.distributed`` ``DeviceMesh`` on the "xla" device
  type, converts it via ``convert_to_xla_mesh``, and verifies that the device
  ids, mesh shape, and axis names survive the conversion unchanged.

  NOTE(review): the 2D-mesh tests reshape all devices into ``(2, n // 2)``,
  which only works for an even device count; their skip conditions check
  for evenness explicitly (the original guards only checked the count,
  so an odd device count raised in ``reshape`` instead of skipping).
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  @unittest.skipIf(xr.global_runtime_device_count() == 1,
                   "Multiple devices needed for 1D mesh test")
  def test_convert_1d_device_mesh(self):
    """A 1D mesh from init_device_mesh converts with shape/ids/names intact."""
    device_count = xr.global_runtime_device_count()
    dt_mesh = init_device_mesh("xla", mesh_shape=(device_count,))

    xla_mesh = convert_to_xla_mesh(dt_mesh)

    self.assertIsInstance(xla_mesh, Mesh)
    self.assertEqual(len(xla_mesh.device_ids), device_count)
    self.assertEqual(xla_mesh.mesh_shape, (device_count,))
    self.assertEqual(xla_mesh.axis_names, dt_mesh.mesh_dim_names)

  # The (2, device_count // 2) reshape requires an even device count; skip
  # (rather than error) when the count is odd.
  @unittest.skipIf(
      xr.global_runtime_device_count() < 2 or
      xr.global_runtime_device_count() % 2 != 0,
      "An even number (>= 2) of devices is needed for the 2D mesh test")
  def test_convert_2d_device_mesh(self):
    """A 2D DeviceMesh converts with shape/ids/names intact."""
    device_count = xr.global_runtime_device_count()
    mesh_shape = (2, device_count // 2)

    dt_mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape))

    xla_mesh = convert_to_xla_mesh(dt_mesh)

    self.assertIsInstance(xla_mesh, Mesh)
    self.assertEqual(len(xla_mesh.device_ids), device_count)
    self.assertEqual(xla_mesh.mesh_shape, mesh_shape)
    self.assertEqual(xla_mesh.axis_names, dt_mesh.mesh_dim_names)

  @unittest.skipIf(xr.global_runtime_device_count() == 1,
                   "Multiple devices needed for custom dim names test")
  def test_convert_with_custom_dim_names(self):
    """User-supplied mesh_dim_names become the XLA mesh axis names."""
    device_count = xr.global_runtime_device_count()
    dt_mesh = DeviceMesh(
        "xla", list(range(device_count)), mesh_dim_names=["data_parallel"])

    xla_mesh = convert_to_xla_mesh(dt_mesh)

    self.assertIsInstance(xla_mesh, Mesh)
    self.assertEqual(len(xla_mesh.device_ids), device_count)
    self.assertEqual(xla_mesh.mesh_shape, (device_count,))
    self.assertEqual(xla_mesh.axis_names, ("data_parallel",))

  # Same evenness requirement as test_convert_2d_device_mesh: the mesh is
  # reshaped to (2, device_count // 2).
  @unittest.skipIf(
      xr.global_runtime_device_count() < 2 or
      xr.global_runtime_device_count() % 2 != 0,
      "An even number (>= 2) of devices is needed for device IDs order test")
  def test_convert_mesh_device_ids_order(self):
    """Converted device_ids preserve the row-major order of the 2D mesh."""
    device_count = xr.global_runtime_device_count()
    device_ids = list(range(device_count))

    mesh_shape = (2, device_count // 2)
    mesh_2d = torch.tensor(device_ids).reshape(mesh_shape)
    dt_mesh = DeviceMesh("xla", mesh_2d)

    xla_mesh = convert_to_xla_mesh(dt_mesh)

    expected_flattened = mesh_2d.flatten().tolist()
    self.assertEqual(list(xla_mesh.device_ids), expected_flattened)

  @unittest.skipIf(xr.global_runtime_device_count() == 1,
                   "Multiple devices needed for mismatch test")
  def test_device_count_mismatch_assertion(self):
    """Converting a mesh that does not cover all devices must assert."""
    device_count = xr.global_runtime_device_count()
    with self.assertRaises(AssertionError):
      dt_mesh = DeviceMesh("xla", list(range(device_count - 1)))
      convert_to_xla_mesh(dt_mesh)

  # The (2, device_count // 2) configuration below additionally requires an
  # even device count; without the evenness guard an odd count (e.g. 5)
  # raised in reshape instead of skipping.
  @unittest.skipIf(
      xr.global_runtime_device_count() < 4 or
      xr.global_runtime_device_count() % 2 != 0,
      "At least 4 (even) devices needed for mesh configuration tests")
  def test_mesh_configurations(self):
    """Several mesh shapes round-trip through the conversion correctly."""
    device_count = xr.global_runtime_device_count()
    test_configs = [((1, device_count), "flat"),
                    ((2, device_count // 2), "2d_balanced")]

    for mesh_shape, config_name in test_configs:
      with self.subTest(configuration=config_name):
        dt_mesh = DeviceMesh("xla",
                             torch.arange(device_count).reshape(mesh_shape))
        xla_mesh = convert_to_xla_mesh(dt_mesh)

        self.assertEqual(xla_mesh.mesh_shape, mesh_shape)
        self.assertEqual(len(xla_mesh.device_ids), device_count)
        self.assertEqual(list(xla_mesh.device_ids), list(range(device_count)))

  def test_mesh_property_consistency(self):
    """size/shape/ids of the source DeviceMesh match the converted Mesh."""
    device_count = xr.global_runtime_device_count()
    dt_mesh = init_device_mesh("xla", mesh_shape=(device_count,))

    xla_mesh = convert_to_xla_mesh(dt_mesh)

    self.assertEqual(dt_mesh.size(), len(xla_mesh.device_ids))
    self.assertEqual(tuple(dt_mesh.mesh.size()), xla_mesh.mesh_shape)

    expected_device_ids = dt_mesh.mesh.flatten().tolist()
    self.assertEqual(list(xla_mesh.device_ids), expected_device_ids)
110+
111+
112+
if __name__ == '__main__':
  # exit=False is required: plain unittest.main() raises SystemExit itself,
  # so `test` would never be assigned and the explicit sys.exit below
  # (which maps the test result onto the process exit status) would be
  # dead code. This also matches the pattern used by the other SPMD tests.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"
6060
run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
6161
run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
6262
run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
63+
run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
6364
run_test "$_TEST_DIR/test_gradient_accumulation.py"
6465
XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
6566
run_test "$_TEST_DIR/test_autocast.py"

0 commit comments

Comments
 (0)