-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtest_resharding_ext.py
More file actions
137 lines (115 loc) · 3.91 KB
/
test_resharding_ext.py
File metadata and controls
137 lines (115 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from logging import getLogger
import pytest
from torch.distributed._tensor import Shard
from .test_resharding_basic import _test_resharding
from .utils import main, transport_plus_strategy_params
# Module-level logger, named after this module per the standard convention.
logger = getLogger(__name__)
def slow_tests_enabled():
    """Return True when slow tests are opted in via TORCHSTORE_ENABLE_SLOW_TESTS=1."""
    flag = os.environ.get("TORCHSTORE_ENABLE_SLOW_TESTS", "0")
    return flag == "1"
# Reusable marker that skips a test unless slow tests are explicitly enabled.
# NOTE: the env var is evaluated once at module import time, not per test run.
requires_slow_tests_enabled = pytest.mark.skipif(
    not slow_tests_enabled(),
    reason="Slow tests are disabled by default, use TORCHSTORE_ENABLE_SLOW_TESTS=1 to enable them",
)
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.parametrize(
    "put_mesh_shape,get_mesh_shape,put_sharding_dim,get_sharding_dim",
    [
        # shrink: 4-rank mesh down to 2 ranks, every tensor-dim combination
        ((4,), (2,), 0, 0),
        ((4,), (2,), 0, 1),
        ((4,), (2,), 1, 0),
        ((4,), (2,), 1, 1),
        # grow: 2-rank mesh up to 4 ranks, every tensor-dim combination
        ((2,), (4,), 0, 0),
        ((2,), (4,), 0, 1),
        ((2,), (4,), 1, 0),
        ((2,), (4,), 1, 1),
    ],
)
@pytest.mark.asyncio
async def test_1d_resharding(
    strategy_params,
    transport_type,
    put_mesh_shape,
    get_mesh_shape,
    put_sharding_dim,
    get_sharding_dim,
):
    """Round-trip a sharded tensor between two 1-D meshes of different sizes."""
    strategy = strategy_params[1]
    # TODO: test Replicate as well, which is likely not working
    put_placements = [Shard(put_sharding_dim)]
    get_placements = [Shard(get_sharding_dim)]
    await _test_resharding(
        put_mesh_shape=put_mesh_shape,
        put_placements=put_placements,
        get_mesh_shape=get_mesh_shape,
        get_placements=get_placements,
        strategy=strategy,
        transport_type=transport_type,
    )
@requires_slow_tests_enabled
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_2d_resharding(strategy_params, transport_type):
    """Reshard between two 2-D meshes of the same shape, varying placements."""
    strategy = strategy_params[1]
    mesh_shape = (2, 2)
    # (put tensor dims, get tensor dims) per mesh dimension
    dim_pairs = [
        ((1, 1), (0, 1)),
        ((1, 0), (1, 0)),
        ((0, 0), (0, 1)),
        ((1, 1), (0, 0)),
    ]
    for put_dims, get_dims in dim_pairs:
        await _test_resharding(
            put_mesh_shape=mesh_shape,
            put_placements=[Shard(d) for d in put_dims],
            get_mesh_shape=mesh_shape,
            get_placements=[Shard(d) for d in get_dims],
            strategy=strategy,
            transport_type=transport_type,
        )
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_1d_to_2d_resharding(strategy_params, transport_type):
    """Reshard from a 1-D mesh of 4 ranks onto a 2x2 mesh, varying placements."""
    strategy = strategy_params[1]
    source_mesh = (4,)
    target_mesh = (2, 2)
    # (put tensor dims, get tensor dims) per mesh dimension
    dim_pairs = [
        ((0,), (0, 1)),
        ((1,), (1, 0)),
        ((0,), (0, 0)),
        ((1,), (1, 1)),
    ]
    for put_dims, get_dims in dim_pairs:
        await _test_resharding(
            put_mesh_shape=source_mesh,
            put_placements=[Shard(d) for d in put_dims],
            get_mesh_shape=target_mesh,
            get_placements=[Shard(d) for d in get_dims],
            strategy=strategy,
            transport_type=transport_type,
        )
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_2d_to_1d_resharding(strategy_params, transport_type):
    """Reshard from a 2x2 mesh onto a 1-D mesh of 4 ranks, varying placements."""
    strategy = strategy_params[1]
    source_mesh = (2, 2)
    target_mesh = (4,)
    # (put tensor dims, get tensor dims) per mesh dimension
    dim_pairs = [
        ((0, 0), (0,)),
        ((1, 0), (1,)),
        ((0, 1), (0,)),
        ((1, 1), (1,)),
    ]
    for put_dims, get_dims in dim_pairs:
        await _test_resharding(
            put_mesh_shape=source_mesh,
            put_placements=[Shard(d) for d in put_dims],
            get_mesh_shape=target_mesh,
            get_placements=[Shard(d) for d in get_dims],
            strategy=strategy,
            transport_type=transport_type,
        )
# Allow running this module directly; `main` comes from .utils and presumably
# invokes the shared test runner on this file — verify against .utils.main.
if __name__ == "__main__":
    main(__file__)