Skip to content

Commit 6665085

Browse files
authored
Fix + Run DynamicShapeDetector tests on CI. (#9075)
1 parent 22a9916 commit 6665085

File tree

6 files changed

+228
-173
lines changed

6 files changed

+228
-173
lines changed

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ function run_xla_op_tests2 {
258258
run_test "$_TEST_DIR/test_assume_pure.py"
259259
run_test "$_TEST_DIR/test_assume_pure_spmd.py"
260260
run_test "$_TEST_DIR/test_assume_pure_torch.py"
261+
run_test "$_TEST_DIR/test_dynamic_shapes_detector.py"
261262
}
262263

263264
# All the new xla op tests should go to run_xla_op_tests3

test/test_dynamic_shapes_detector.py

Lines changed: 118 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,42 @@
1+
import re
2+
import textwrap
13
import torch
24
import torch_xla
35
import test_utils
46
import unittest
57

8+
# Processes a string, so that it can be used as the expected error regex.
9+
# Specifically, it does 3 things:
10+
#
11+
# 1. s[1:]: assumes the first character of the string is a new-line, and
12+
# removes it.
13+
#
14+
# 2. textwrap.dedent(): strips the leading space in the string, allowing us
15+
# to write more readable multi-line strings.
16+
#
17+
# 3. ESCAPE_RE.sub(): escapes special characters, such as parentheses,
18+
# brackets, and braces, so as to allow us to write more
19+
# readable strings.
20+
#
21+
# Note that because of (3), we lose the "regex" part, not being able to use
22+
# regex wildcards, such as "*".
23+
ESCAPE_RE = re.compile(r"([\[\](){}])")
24+
25+
26+
def escape(s):
27+
return ESCAPE_RE.sub(r"\\\1", textwrap.dedent(s[1:]))
28+
629

730
class TestDynamicShapeDetector(test_utils.XlaTestCase):
831

9-
def _run_and_compare(self, f, args=None, allowed_traces=None):
32+
def _run_and_compare(self, f, args=None, max_different_graphs=None):
1033
"""Run f and its torch_xla.compile wrapped version, comparing the equality
1134
of their results.
1235
1336
If no optf is provided, we create a new one by wrapping it with
1437
torch_xla.compile ourselves.
1538
"""
16-
optf = torch_xla.compile(f, allowed_traces=allowed_traces)
39+
optf = torch_xla.compile(f, max_different_graphs=max_different_graphs)
1740
args = args or []
1841

1942
out = f(*args)
@@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None):
2245
self.assertEqual(out, optout)
2346

2447
def test_single(self):
25-
# Test: trace a function once, when only one trace is allowed.
48+
# Test: trace a function once, when only one graph is allowed.
2649

2750
def foo(x):
2851
return x + x
2952

3053
inp = torch.rand(10, device=torch_xla.device())
31-
self._run_and_compare(foo, args=(inp,), allowed_traces=1)
54+
self._run_and_compare(foo, args=(inp,), max_different_graphs=1)
3255

33-
def test_many_traces(self):
34-
# Test: multiple traces of a function.
56+
def test_many_graphs(self):
57+
# Test: multiple graphs of a function.
3558
#
36-
# Steps 0~2 and 5: create new traces.
59+
# Steps 0~2 and 5: create new graphs.
3760
# Steps 3 and 4: ensure we have already traced these paths.
3861

3962
def foo(x, step):
@@ -50,41 +73,44 @@ def foo(x, step):
5073
inp = torch.rand(10, device=torch_xla.device())
5174

5275
for i in range(6):
53-
self._run_and_compare(foo, args=(inp, i), allowed_traces=4)
76+
self._run_and_compare(foo, args=(inp, i), max_different_graphs=4)
5477

55-
def test_trace_limit_exceeded_different_input_shape(self):
56-
# Test: catch trace limit exceeded error when running the function with a
78+
def test_graph_limit_exceeded_different_input_shape(self):
79+
# Test: catch graph limit exceeded error when running the function with a
5780
# input of a different shape.
5881

59-
allowed_traces = 1
82+
max_different_graphs = 1
6083

6184
def foo(x):
6285
return x + x
6386

6487
inp1 = torch.rand(10, device=torch_xla.device())
65-
self._run_and_compare(foo, args=(inp1,), allowed_traces=allowed_traces)
88+
self._run_and_compare(
89+
foo, args=(inp1,), max_different_graphs=max_different_graphs)
6690

67-
msg = """\
68-
.* Maximum number of different traces allowed per function exceeded: 1
69-
Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
70-
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
91+
expected_error_msg = escape(r"""
92+
Maximum number of different graphs allowed per function exceeded: 1
93+
Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: ()
94+
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
95+
""")
7196

72-
with self.assertRaises(RuntimeError, msg=msg):
97+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
7398
inp2 = torch.rand(5, device=torch_xla.device())
74-
self._run_and_compare(foo, args=(inp2,), allowed_traces=allowed_traces)
99+
self._run_and_compare(
100+
foo, args=(inp2,), max_different_graphs=max_different_graphs)
75101

76-
def test_trace_limit_exceeded_common_sequence_mismatch(self):
77-
# Test: catch trace limit exceeded error when the common sequence (i.e. compressed
102+
def test_graph_limit_exceeded_common_sequence_mismatch(self):
103+
# Test: catch graph limit exceeded error when the common sequence (i.e. compressed
78104
# path) of the trie node mismatches.
79105
#
80-
# Step 0: creates a trace with one node containing the add operation
106+
# Step 0: creates a graph with one node containing the add operation
81107
#
82108
# Step 1: tries to create 2 child nodes with:
83-
# (i) add operation (previous trace); and
109+
# (i) add operation (previous graph); and
84110
# (ii) mul operation.
85111
# However, it fails since we have reached the limit.
86112

87-
allowed_traces = 1
113+
max_different_graphs = 1
88114

89115
def foo(x, step):
90116
if step == 0:
@@ -93,32 +119,35 @@ def foo(x, step):
93119
return x * 5
94120

95121
inp = torch.rand(10, device=torch_xla.device())
96-
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
122+
self._run_and_compare(
123+
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
97124

98-
msg = """\
99-
.* Maximum number of different traces allowed per function exceeded: 1
100-
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
101-
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
125+
expected_error_msg = escape(r"""
126+
Maximum number of different graphs allowed per function exceeded: 1
127+
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
128+
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
129+
""")
102130

103-
with self.assertRaises(RuntimeError, msg=msg):
104-
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
131+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
132+
self._run_and_compare(
133+
foo, args=(inp, 2), max_different_graphs=max_different_graphs)
105134

106-
def test_trace_limit_exceeded_children_mismatch(self):
107-
# Test: catch trace limit exceeded error when the expected child of the trie
135+
def test_graph_limit_exceeded_children_mismatch(self):
136+
# Test: catch graph limit exceeded error when the expected child of the trie
108137
# node mismatches.
109138
#
110-
# Step 0: creates a trace with one node containing 3 operations, the last
139+
# Step 0: creates a graph with one node containing 3 operations, the last
111140
# being a mul operation.
112141
#
113-
# Step 1: creates another trace by splitting the node, creating 2 other child
142+
# Step 1: creates another graph by splitting the node, creating 2 other child
114143
# nodes containing the different operations in the end:
115144
# (i) mul operation; and
116145
# (ii) add operation.
117146
#
118147
# Step 2: tries to create a 3rd child node: div operation. However, we can't
119148
# do it, since we have reached the limit.
120149

121-
allowed_traces = 2
150+
max_different_graphs = 2
122151

123152
def foo(x, step):
124153
r = x + x
@@ -129,30 +158,34 @@ def foo(x, step):
129158
return r / 3
130159

131160
inp = torch.rand(10, device=torch_xla.device())
132-
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
133-
self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
134-
135-
msg = """\
136-
.* Maximum number of different traces allowed per function exceeded: 2
137-
Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
138-
Expected either of:
139-
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
140-
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
141-
142-
with self.assertRaises(RuntimeError, msg=msg):
143-
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
161+
self._run_and_compare(
162+
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
163+
self._run_and_compare(
164+
foo, args=(inp, 1), max_different_graphs=max_different_graphs)
165+
166+
expected_error_msg = escape(r"""
167+
Maximum number of different graphs allowed per function exceeded: 2
168+
Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: ()
169+
Expected either of:
170+
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
171+
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
172+
""")
173+
174+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
175+
self._run_and_compare(
176+
foo, args=(inp, 2), max_different_graphs=max_different_graphs)
144177

145-
def test_trace_limit_exceeded_common_sequence_early_stop(self):
146-
# Test: catch trace limit exceeded error when the trace ends unexpectedly in
178+
def test_graph_limit_exceeded_common_sequence_early_stop(self):
179+
# Test: catch graph limit exceeded error when the graph ends unexpectedly in
147180
# the common sequence.
148181
#
149-
# Step 0: creates a trace with one node containing 3 operations.
182+
# Step 0: creates a graph with one node containing 3 operations.
150183
#
151-
# Step 1: at the end of this trace, it tries to create a new node containing
152-
# the remaining operations of the previous trace, i.e. mul operation. However,
184+
# Step 1: at the end of this graph, it tries to create a new node containing
185+
# the remaining operations of the previous graph, i.e. mul operation. However,
153186
# it fails because we have reached the limit.
154187

155-
allowed_traces = 1
188+
max_different_graphs = 1
156189

157190
def foo(x, mul=False):
158191
r = x + x
@@ -162,31 +195,33 @@ def foo(x, mul=False):
162195
return r
163196

164197
inp = torch.rand(10, device=torch_xla.device())
165-
self._run_and_compare(foo, args=(inp, True), allowed_traces=allowed_traces)
198+
self._run_and_compare(
199+
foo, args=(inp, True), max_different_graphs=max_different_graphs)
166200

167-
msg = """\
168-
.* Maximum number of different traces allowed per function exceeded: 1
169-
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
170-
Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
201+
expected_error_msg = escape(r"""
202+
Maximum number of different graphs allowed per function exceeded: 1
203+
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
204+
Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
205+
""")
171206

172-
with self.assertRaises(RuntimeError, msg=msg):
207+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
173208
self._run_and_compare(
174-
foo, args=(inp, False), allowed_traces=allowed_traces)
209+
foo, args=(inp, False), max_different_graphs=max_different_graphs)
175210

176-
def test_trace_limit_exceeded_children_early_stop(self):
177-
# Test: catch trace limit exceeded error when the trace ends unexpectedly at
211+
def test_graph_limit_exceeded_children_early_stop(self):
212+
# Test: catch graph limit exceeded error when the graph ends unexpectedly at
178213
# a fork point (i.e. next operation would jump to another trie node).
179214
#
180-
# Step 0: creates a trace with one node containing 3 operations.
215+
# Step 0: creates a graph with one node containing 3 operations.
181216
#
182217
# Step 1: splits the node, creating 2 child nodes containing:
183-
# (i) the differing operations from the last trace, i.e. mul operation
218+
# (i) the differing operations from the last graph, i.e. mul operation
184219
# (ii) the current last operation, i.e. add operation
185220
#
186-
# Step 2: at the end of this trace, it tries to turn the current trie node
187-
# into a new trace. However, it fails since we have reached the limit.
221+
# Step 2: at the end of this graph, it tries to turn the current trie node
222+
# into a new graph. However, it fails since we have reached the limit.
188223

189-
allowed_traces = 2
224+
max_different_graphs = 2
190225

191226
def foo(x, step):
192227
r = x + x
@@ -197,18 +232,22 @@ def foo(x, step):
197232
return r
198233

199234
inp = torch.rand(10, device=torch_xla.device())
200-
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
201-
self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
202-
203-
msg = """\
204-
.* Maximum number of different traces allowed per function exceeded: 2
205-
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
206-
Expected either of:
207-
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
208-
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
209-
210-
with self.assertRaises(RuntimeError, msg=msg):
211-
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
235+
self._run_and_compare(
236+
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
237+
self._run_and_compare(
238+
foo, args=(inp, 1), max_different_graphs=max_different_graphs)
239+
240+
expected_error_msg = escape(r"""
241+
Maximum number of different graphs allowed per function exceeded: 2
242+
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
243+
Expected either of:
244+
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
245+
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
246+
""")
247+
248+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
249+
self._run_and_compare(
250+
foo, args=(inp, 2), max_different_graphs=max_different_graphs)
212251

213252

214253
if __name__ == "__main__":

0 commit comments

Comments
 (0)