5
5
import test_utils
6
6
import unittest
7
7
8
-
9
8
# Processes a string, so that it can be used as the expected error regex.
10
9
# Specifically, it does 3 things:
11
10
#
23
22
# regex wildcards, such as "*".
24
23
ESCAPE_RE = re .compile (r"([\[\](){}])" )
25
24
25
+
26
26
def escape (s ):
27
- return ESCAPE_RE .sub (r"\\\1" , textwrap .dedent (s [1 :]))
27
+ return ESCAPE_RE .sub (r"\\\1" , textwrap .dedent (s [1 :]))
28
28
29
29
30
30
class TestDynamicShapeDetector (test_utils .XlaTestCase ):
@@ -85,7 +85,8 @@ def foo(x):
85
85
return x + x
86
86
87
87
inp1 = torch .rand (10 , device = torch_xla .device ())
88
- self ._run_and_compare (foo , args = (inp1 ,), max_different_graphs = max_different_graphs )
88
+ self ._run_and_compare (
89
+ foo , args = (inp1 ,), max_different_graphs = max_different_graphs )
89
90
90
91
expected_error_msg = escape (r"""
91
92
Maximum number of different graphs allowed per function exceeded: 1
@@ -95,7 +96,8 @@ def foo(x):
95
96
96
97
with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
97
98
inp2 = torch .rand (5 , device = torch_xla .device ())
98
- self ._run_and_compare (foo , args = (inp2 ,), max_different_graphs = max_different_graphs )
99
+ self ._run_and_compare (
100
+ foo , args = (inp2 ,), max_different_graphs = max_different_graphs )
99
101
100
102
def test_graph_limit_exceeded_common_sequence_mismatch (self ):
101
103
# Test: catch graph limit exceeded error when the common sequence (i.e. compressed
@@ -117,7 +119,8 @@ def foo(x, step):
117
119
return x * 5
118
120
119
121
inp = torch .rand (10 , device = torch_xla .device ())
120
- self ._run_and_compare (foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
122
+ self ._run_and_compare (
123
+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
121
124
122
125
expected_error_msg = escape (r"""
123
126
Maximum number of different graphs allowed per function exceeded: 1
@@ -126,7 +129,8 @@ def foo(x, step):
126
129
""" )
127
130
128
131
with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
129
- self ._run_and_compare (foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
132
+ self ._run_and_compare (
133
+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
130
134
131
135
def test_graph_limit_exceeded_children_mismatch (self ):
132
136
# Test: catch graph limit exceeded error when the expected child of the trie
@@ -154,8 +158,10 @@ def foo(x, step):
154
158
return r / 3
155
159
156
160
inp = torch .rand (10 , device = torch_xla .device ())
157
- self ._run_and_compare (foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
158
- self ._run_and_compare (foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
161
+ self ._run_and_compare (
162
+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
163
+ self ._run_and_compare (
164
+ foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
159
165
160
166
expected_error_msg = escape (r"""
161
167
Maximum number of different graphs allowed per function exceeded: 2
@@ -166,7 +172,8 @@ def foo(x, step):
166
172
""" )
167
173
168
174
with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
169
- self ._run_and_compare (foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
175
+ self ._run_and_compare (
176
+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
170
177
171
178
def test_graph_limit_exceeded_common_sequence_early_stop (self ):
172
179
# Test: catch graph limit exceeded error when the graph ends unexpectedly in
@@ -188,7 +195,8 @@ def foo(x, mul=False):
188
195
return r
189
196
190
197
inp = torch .rand (10 , device = torch_xla .device ())
191
- self ._run_and_compare (foo , args = (inp , True ), max_different_graphs = max_different_graphs )
198
+ self ._run_and_compare (
199
+ foo , args = (inp , True ), max_different_graphs = max_different_graphs )
192
200
193
201
expected_error_msg = escape (r"""
194
202
Maximum number of different graphs allowed per function exceeded: 1
@@ -224,8 +232,10 @@ def foo(x, step):
224
232
return r
225
233
226
234
inp = torch .rand (10 , device = torch_xla .device ())
227
- self ._run_and_compare (foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
228
- self ._run_and_compare (foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
235
+ self ._run_and_compare (
236
+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
237
+ self ._run_and_compare (
238
+ foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
229
239
230
240
expected_error_msg = escape (r"""
231
241
Maximum number of different graphs allowed per function exceeded: 2
@@ -236,7 +246,8 @@ def foo(x, step):
236
246
""" )
237
247
238
248
with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
239
- self ._run_and_compare (foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
249
+ self ._run_and_compare (
250
+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
240
251
241
252
242
253
if __name__ == "__main__" :
0 commit comments