Skip to content

Commit 77190d9

Browse files
committed
Fix tests.
1 parent 4184a1b commit 77190d9

File tree

1 file changed

+108
-80
lines changed

1 file changed

+108
-80
lines changed

test/test_dynamic_shapes_detector.py

Lines changed: 108 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,42 @@
1+
import re
2+
import textwrap
13
import torch
24
import torch_xla
35
import test_utils
46
import unittest
57

68

9+
# Processes a string, so that it can be used as the expected error regex.
10+
# Specifically, it does 3 things:
11+
#
12+
# 1. s[1:]: assumes the first character of the string is a new-line, and
13+
# removes it.
14+
#
15+
# 2. textwrap.dedent(): strips the leading space in the string, allowing us
16+
# to write more readable multi-line strings.
17+
#
18+
# 3. ESCAPE_RE.sub(): escapes special characters, such as parenthesis,
19+
# brackets, and braces, so as to allow us to write more
20+
# readable strings.
21+
#
22+
# Note that because of (3), we lose the "regex" part, not being able to use
23+
# regex wildcards, such as "*".
24+
ESCAPE_RE = re.compile(r"([\[\](){}])")
25+
26+
def escape(s):
27+
return ESCAPE_RE.sub(r"\\\1", textwrap.dedent(s[1:]))
28+
29+
730
class TestDynamicShapeDetector(test_utils.XlaTestCase):
831

9-
def _run_and_compare(self, f, args=None, allowed_traces=None):
32+
def _run_and_compare(self, f, args=None, max_different_graphs=None):
1033
"""Run f and its torch_xla.compile wrapped version, comparing the equality
1134
of their results.
1235
1336
If no optf is provided, we create a new one by wrapping it with
1437
torch_xla.compile ourselves.
1538
"""
16-
optf = torch_xla.compile(f, allowed_traces=allowed_traces)
39+
optf = torch_xla.compile(f, max_different_graphs=max_different_graphs)
1740
args = args or []
1841

1942
out = f(*args)
@@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None):
2245
self.assertEqual(out, optout)
2346

2447
def test_single(self):
25-
# Test: trace a function once, when only one trace is allowed.
48+
# Test: trace a function once, when only one graph is allowed.
2649

2750
def foo(x):
2851
return x + x
2952

3053
inp = torch.rand(10, device=torch_xla.device())
31-
self._run_and_compare(foo, args=(inp,), allowed_traces=1)
54+
self._run_and_compare(foo, args=(inp,), max_different_graphs=1)
3255

33-
def test_many_traces(self):
34-
# Test: multiple traces of a function.
56+
def test_many_graphs(self):
57+
# Test: multiple graphs of a function.
3558
#
36-
# Steps 0~2 and 5: create new traces.
59+
# Steps 0~2 and 5: create new graphs.
3760
# Steps 3 and 4: ensure we have already traced these paths.
3861

3962
def foo(x, step):
@@ -50,41 +73,42 @@ def foo(x, step):
5073
inp = torch.rand(10, device=torch_xla.device())
5174

5275
for i in range(6):
53-
self._run_and_compare(foo, args=(inp, i), allowed_traces=4)
76+
self._run_and_compare(foo, args=(inp, i), max_different_graphs=4)
5477

55-
def test_trace_limit_exceeded_different_input_shape(self):
56-
# Test: catch trace limit exceeded error when running the function with a
78+
def test_graph_limit_exceeded_different_input_shape(self):
79+
# Test: catch graph limit exceeded error when running the function with a
5780
# function with different shape.
5881

59-
allowed_traces = 1
82+
max_different_graphs = 1
6083

6184
def foo(x):
6285
return x + x
6386

6487
inp1 = torch.rand(10, device=torch_xla.device())
65-
self._run_and_compare(foo, args=(inp1,), allowed_traces=allowed_traces)
88+
self._run_and_compare(foo, args=(inp1,), max_different_graphs=max_different_graphs)
6689

67-
msg = """\
68-
.* Maximum number of different traces allowed per function exceeded: 1
69-
Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
70-
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
90+
expected_error_msg = escape(r"""
91+
Maximum number of different graphs allowed per function exceeded: 1
92+
Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: ()
93+
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
94+
""")
7195

72-
with self.assertRaises(RuntimeError, msg=msg):
96+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
7397
inp2 = torch.rand(5, device=torch_xla.device())
74-
self._run_and_compare(foo, args=(inp2,), allowed_traces=allowed_traces)
98+
self._run_and_compare(foo, args=(inp2,), max_different_graphs=max_different_graphs)
7599

76-
def test_trace_limit_exceeded_common_sequence_mismatch(self):
77-
# Test: catch trace limit exceeded error when the common sequence (i.e. compressed
100+
def test_graph_limit_exceeded_common_sequence_mismatch(self):
101+
# Test: catch graph limit exceeded error when the common sequence (i.e. compressed
78102
# path) of the trie node mismatches.
79103
#
80-
# Step 0: creates a trace with one node containing the add operation
104+
# Step 0: creates a graph with one node containing the add operation
81105
#
82106
# Step 1: tries to create 2 child nodes with:
83-
# (i) add operation (previous trace); and
107+
# (i) add operation (previous graph); and
84108
# (ii) mul operation.
85109
# However, it fails since we have reached the limit.
86110

87-
allowed_traces = 1
111+
max_different_graphs = 1
88112

89113
def foo(x, step):
90114
if step == 0:
@@ -93,32 +117,33 @@ def foo(x, step):
93117
return x * 5
94118

95119
inp = torch.rand(10, device=torch_xla.device())
96-
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
120+
self._run_and_compare(foo, args=(inp, 0), max_different_graphs=max_different_graphs)
97121

98-
msg = """\
99-
.* Maximum number of different traces allowed per function exceeded: 1
100-
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
101-
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
122+
expected_error_msg = escape(r"""
123+
Maximum number of different graphs allowed per function exceeded: 1
124+
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
125+
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
126+
""")
102127

103-
with self.assertRaises(RuntimeError, msg=msg):
104-
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
128+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
129+
self._run_and_compare(foo, args=(inp, 2), max_different_graphs=max_different_graphs)
105130

106-
def test_trace_limit_exceeded_children_mismatch(self):
107-
# Test: catch trace limit exceeded error when the expected child of the trie
131+
def test_graph_limit_exceeded_children_mismatch(self):
132+
# Test: catch graph limit exceeded error when the expected child of the trie
108133
# node mismatches.
109134
#
110-
# Step 0: creates a trace with one node containing 3 operations, the last
135+
# Step 0: creates a graph with one node containing 3 operations, the last
111136
# being a mul operation.
112137
#
113-
# Step 1: creates another trace by splitting the node, creating 2 other child
138+
# Step 1: creates another graph by splitting the node, creating 2 other child
114139
# nodes containing the different operations in the end:
115140
# (i) mul operation; and
116141
# (ii) add operation.
117142
#
118143
# Step 2: tries to create a 3rd child node: div operation. However, we can't
119144
# do it, since we have reached the limit.
120145

121-
allowed_traces = 2
146+
max_different_graphs = 2
122147

123148
def foo(x, step):
124149
r = x + x
@@ -129,30 +154,31 @@ def foo(x, step):
129154
return r / 3
130155

131156
inp = torch.rand(10, device=torch_xla.device())
132-
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
133-
self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
134-
135-
msg = """\
136-
.* Maximum number of different traces allowed per function exceeded: 2
137-
Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
138-
Expected either of:
139-
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
140-
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
141-
142-
with self.assertRaises(RuntimeError, msg=msg):
143-
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
144-
145-
def test_trace_limit_exceeded_common_sequence_early_stop(self):
146-
# Test: catch trace limit exceeded error when the trace ends unexpectedly in
157+
self._run_and_compare(foo, args=(inp, 0), max_different_graphs=max_different_graphs)
158+
self._run_and_compare(foo, args=(inp, 1), max_different_graphs=max_different_graphs)
159+
160+
expected_error_msg = escape(r"""
161+
Maximum number of different graphs allowed per function exceeded: 2
162+
Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: ()
163+
Expected either of:
164+
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
165+
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
166+
""")
167+
168+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
169+
self._run_and_compare(foo, args=(inp, 2), max_different_graphs=max_different_graphs)
170+
171+
def test_graph_limit_exceeded_common_sequence_early_stop(self):
172+
# Test: catch graph limit exceeded error when the graph ends unexpectedly in
147173
# the common sequence.
148174
#
149-
# Step 0: creates a trace with one node containing 3 operations.
175+
# Step 0: creates a graph with one node containing 3 operations.
150176
#
151-
# Step 1: at the end of this trace, it tries to create a new node containing
152-
# the remaining operations of the previous trace, i.e. mul operation. However,
177+
# Step 1: at the end of this graph, it tries to create a new node containing
178+
# the remaining operations of the previous graph, i.e. mul operation. However,
153179
# it fails because we have reached the limit.
154180

155-
allowed_traces = 1
181+
max_different_graphs = 1
156182

157183
def foo(x, mul=False):
158184
r = x + x
@@ -162,31 +188,32 @@ def foo(x, mul=False):
162188
return r
163189

164190
inp = torch.rand(10, device=torch_xla.device())
165-
self._run_and_compare(foo, args=(inp, True), allowed_traces=allowed_traces)
191+
self._run_and_compare(foo, args=(inp, True), max_different_graphs=max_different_graphs)
166192

167-
msg = """\
168-
.* Maximum number of different traces allowed per function exceeded: 1
169-
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
170-
Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
193+
expected_error_msg = escape(r"""
194+
Maximum number of different graphs allowed per function exceeded: 1
195+
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
196+
Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
197+
""")
171198

172-
with self.assertRaises(RuntimeError, msg=msg):
199+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
173200
self._run_and_compare(
174-
foo, args=(inp, False), allowed_traces=allowed_traces)
201+
foo, args=(inp, False), max_different_graphs=max_different_graphs)
175202

176-
def test_trace_limit_exceeded_children_early_stop(self):
177-
# Test: catch trace limit exceeded error when the trace ends unexpectedly at
203+
def test_graph_limit_exceeded_children_early_stop(self):
204+
# Test: catch graph limit exceeded error when the graph ends unexpectedly at
178205
# a fork point (i.e. next operation would jump to anothe trie node).
179206
#
180-
# Step 0: creates a trace with one node containing 3 operations.
207+
# Step 0: creates a graph with one node containing 3 operations.
181208
#
182209
# Step 1: splits the node, creating 2 child nodes containing:
183-
# (i) the differring operations from the last trace, i.e. mul operation
210+
# (i) the differring operations from the last graph, i.e. mul operation
184211
# (ii) the current last operation, i.e. add operation
185212
#
186-
# Step 3: at the end of this trace, it tries to turn the current trie node
187-
# into a new trace. However, it fails since we have reached the limit.
213+
# Step 3: at the end of this graph, it tries to turn the current trie node
214+
# into a new graph. However, it fails since we have reached the limit.
188215

189-
allowed_traces = 2
216+
max_different_graphs = 2
190217

191218
def foo(x, step):
192219
r = x + x
@@ -197,18 +224,19 @@ def foo(x, step):
197224
return r
198225

199226
inp = torch.rand(10, device=torch_xla.device())
200-
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
201-
self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)
202-
203-
msg = """\
204-
.* Maximum number of different traces allowed per function exceeded: 2
205-
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
206-
Expected either of:
207-
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
208-
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
209-
210-
with self.assertRaises(RuntimeError, msg=msg):
211-
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
227+
self._run_and_compare(foo, args=(inp, 0), max_different_graphs=max_different_graphs)
228+
self._run_and_compare(foo, args=(inp, 1), max_different_graphs=max_different_graphs)
229+
230+
expected_error_msg = escape(r"""
231+
Maximum number of different graphs allowed per function exceeded: 2
232+
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
233+
Expected either of:
234+
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
235+
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
236+
""")
237+
238+
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
239+
self._run_and_compare(foo, args=(inp, 2), max_different_graphs=max_different_graphs)
212240

213241

214242
if __name__ == "__main__":

0 commit comments

Comments
 (0)