1
+ import re
2
+ import textwrap
1
3
import torch
2
4
import torch_xla
3
5
import test_utils
4
6
import unittest
5
7
8
+ # Processes a string, so that it can be used as the expected error regex.
9
+ # Specifically, it does 3 things:
10
+ #
11
+ # 1. s[1:]: assumes the first character of the string is a new-line, and
12
+ # removes it.
13
+ #
14
+ # 2. textwrap.dedent(): strips the leading space in the string, allowing us
15
+ # to write more readable multi-line strings.
16
+ #
17
+ # 3. ESCAPE_RE.sub(): escapes special characters, such as parenthesis,
18
+ # brackets, and braces, so as to allow us to write more
19
+ # readable strings.
20
+ #
21
+ # Note that because of (3), we lose the "regex" part, not being able to use
22
+ # regex wildcards, such as "*".
23
+ ESCAPE_RE = re .compile (r"([\[\](){}])" )
24
+
25
+
26
+ def escape (s ):
27
+ return ESCAPE_RE .sub (r"\\\1" , textwrap .dedent (s [1 :]))
28
+
6
29
7
30
class TestDynamicShapeDetector (test_utils .XlaTestCase ):
8
31
9
- def _run_and_compare (self , f , args = None , allowed_traces = None ):
32
+ def _run_and_compare (self , f , args = None , max_different_graphs = None ):
10
33
"""Run f and its torch_xla.compile wrapped version, comparing the equality
11
34
of their results.
12
35
13
36
If no optf is provided, we create a new one by wrapping it with
14
37
torch_xla.compile ourselves.
15
38
"""
16
- optf = torch_xla .compile (f , allowed_traces = allowed_traces )
39
+ optf = torch_xla .compile (f , max_different_graphs = max_different_graphs )
17
40
args = args or []
18
41
19
42
out = f (* args )
@@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None):
22
45
self .assertEqual (out , optout )
23
46
24
47
def test_single (self ):
25
- # Test: trace a function once, when only one trace is allowed.
48
+ # Test: trace a function once, when only one graph is allowed.
26
49
27
50
def foo (x ):
28
51
return x + x
29
52
30
53
inp = torch .rand (10 , device = torch_xla .device ())
31
- self ._run_and_compare (foo , args = (inp ,), allowed_traces = 1 )
54
+ self ._run_and_compare (foo , args = (inp ,), max_different_graphs = 1 )
32
55
33
- def test_many_traces (self ):
34
- # Test: multiple traces of a function.
56
+ def test_many_graphs (self ):
57
+ # Test: multiple graphs of a function.
35
58
#
36
- # Steps 0~2 and 5: create new traces .
59
+ # Steps 0~2 and 5: create new graphs .
37
60
# Steps 3 and 4: ensure we have already traced these paths.
38
61
39
62
def foo (x , step ):
@@ -50,41 +73,44 @@ def foo(x, step):
50
73
inp = torch .rand (10 , device = torch_xla .device ())
51
74
52
75
for i in range (6 ):
53
- self ._run_and_compare (foo , args = (inp , i ), allowed_traces = 4 )
76
+ self ._run_and_compare (foo , args = (inp , i ), max_different_graphs = 4 )
54
77
55
- def test_trace_limit_exceeded_different_input_shape (self ):
56
- # Test: catch trace limit exceeded error when running the function with a
78
+ def test_graph_limit_exceeded_different_input_shape (self ):
79
+ # Test: catch graph limit exceeded error when running the function with a
57
80
# function with different shape.
58
81
59
- allowed_traces = 1
82
+ max_different_graphs = 1
60
83
61
84
def foo (x ):
62
85
return x + x
63
86
64
87
inp1 = torch .rand (10 , device = torch_xla .device ())
65
- self ._run_and_compare (foo , args = (inp1 ,), allowed_traces = allowed_traces )
88
+ self ._run_and_compare (
89
+ foo , args = (inp1 ,), max_different_graphs = max_different_graphs )
66
90
67
- msg = """\
68
- .* Maximum number of different traces allowed per function exceeded: 1
69
- Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
70
- Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
91
+ expected_error_msg = escape (r"""
92
+ Maximum number of different graphs allowed per function exceeded: 1
93
+ Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: ()
94
+ Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
95
+ """ )
71
96
72
- with self .assertRaises (RuntimeError , msg = msg ):
97
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
73
98
inp2 = torch .rand (5 , device = torch_xla .device ())
74
- self ._run_and_compare (foo , args = (inp2 ,), allowed_traces = allowed_traces )
99
+ self ._run_and_compare (
100
+ foo , args = (inp2 ,), max_different_graphs = max_different_graphs )
75
101
76
- def test_trace_limit_exceeded_common_sequence_mismatch (self ):
77
- # Test: catch trace limit exceeded error when the common sequence (i.e. compressed
102
+ def test_graph_limit_exceeded_common_sequence_mismatch (self ):
103
+ # Test: catch graph limit exceeded error when the common sequence (i.e. compressed
78
104
# path) of the trie node mismatches.
79
105
#
80
- # Step 0: creates a trace with one node containing the add operation
106
+ # Step 0: creates a graph with one node containing the add operation
81
107
#
82
108
# Step 1: tries to create 2 child nodes with:
83
- # (i) add operation (previous trace ); and
109
+ # (i) add operation (previous graph ); and
84
110
# (ii) mul operation.
85
111
# However, it fails since we have reached the limit.
86
112
87
- allowed_traces = 1
113
+ max_different_graphs = 1
88
114
89
115
def foo (x , step ):
90
116
if step == 0 :
@@ -93,32 +119,35 @@ def foo(x, step):
93
119
return x * 5
94
120
95
121
inp = torch .rand (10 , device = torch_xla .device ())
96
- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
122
+ self ._run_and_compare (
123
+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
97
124
98
- msg = """\
99
- .* Maximum number of different traces allowed per function exceeded: 1
100
- Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
101
- Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
125
+ expected_error_msg = escape (r"""
126
+ Maximum number of different graphs allowed per function exceeded: 1
127
+ Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
128
+ Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
129
+ """ )
102
130
103
- with self .assertRaises (RuntimeError , msg = msg ):
104
- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
131
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
132
+ self ._run_and_compare (
133
+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
105
134
106
- def test_trace_limit_exceeded_children_mismatch (self ):
107
- # Test: catch trace limit exceeded error when the expected child of the trie
135
+ def test_graph_limit_exceeded_children_mismatch (self ):
136
+ # Test: catch graph limit exceeded error when the expected child of the trie
108
137
# node mismatches.
109
138
#
110
- # Step 0: creates a trace with one node containing 3 operations, the last
139
+ # Step 0: creates a graph with one node containing 3 operations, the last
111
140
# being a mul operation.
112
141
#
113
- # Step 1: creates another trace by splitting the node, creating 2 other child
142
+ # Step 1: creates another graph by splitting the node, creating 2 other child
114
143
# nodes containing the different operations in the end:
115
144
# (i) mul operation; and
116
145
# (ii) add operation.
117
146
#
118
147
# Step 2: tries to create a 3rd child node: div operation. However, we can't
119
148
# do it, since we have reached the limit.
120
149
121
- allowed_traces = 2
150
+ max_different_graphs = 2
122
151
123
152
def foo (x , step ):
124
153
r = x + x
@@ -129,30 +158,34 @@ def foo(x, step):
129
158
return r / 3
130
159
131
160
inp = torch .rand (10 , device = torch_xla .device ())
132
- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
133
- self ._run_and_compare (foo , args = (inp , 1 ), allowed_traces = allowed_traces )
134
-
135
- msg = """\
136
- .* Maximum number of different traces allowed per function exceeded: 2
137
- Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
138
- Expected either of:
139
- - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
140
- - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
141
-
142
- with self .assertRaises (RuntimeError , msg = msg ):
143
- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
161
+ self ._run_and_compare (
162
+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
163
+ self ._run_and_compare (
164
+ foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
165
+
166
+ expected_error_msg = escape (r"""
167
+ Maximum number of different graphs allowed per function exceeded: 2
168
+ Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: ()
169
+ Expected either of:
170
+ - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
171
+ - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
172
+ """ )
173
+
174
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
175
+ self ._run_and_compare (
176
+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
144
177
145
- def test_trace_limit_exceeded_common_sequence_early_stop (self ):
146
- # Test: catch trace limit exceeded error when the trace ends unexpectedly in
178
+ def test_graph_limit_exceeded_common_sequence_early_stop (self ):
179
+ # Test: catch graph limit exceeded error when the graph ends unexpectedly in
147
180
# the common sequence.
148
181
#
149
- # Step 0: creates a trace with one node containing 3 operations.
182
+ # Step 0: creates a graph with one node containing 3 operations.
150
183
#
151
- # Step 1: at the end of this trace , it tries to create a new node containing
152
- # the remaining operations of the previous trace , i.e. mul operation. However,
184
+ # Step 1: at the end of this graph , it tries to create a new node containing
185
+ # the remaining operations of the previous graph , i.e. mul operation. However,
153
186
# it fails because we have reached the limit.
154
187
155
- allowed_traces = 1
188
+ max_different_graphs = 1
156
189
157
190
def foo (x , mul = False ):
158
191
r = x + x
@@ -162,31 +195,33 @@ def foo(x, mul=False):
162
195
return r
163
196
164
197
inp = torch .rand (10 , device = torch_xla .device ())
165
- self ._run_and_compare (foo , args = (inp , True ), allowed_traces = allowed_traces )
198
+ self ._run_and_compare (
199
+ foo , args = (inp , True ), max_different_graphs = max_different_graphs )
166
200
167
- msg = """\
168
- .* Maximum number of different traces allowed per function exceeded: 1
169
- Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
170
- Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
201
+ expected_error_msg = escape (r"""
202
+ Maximum number of different graphs allowed per function exceeded: 1
203
+ Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
204
+ Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
205
+ """ )
171
206
172
- with self .assertRaises (RuntimeError , msg = msg ):
207
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
173
208
self ._run_and_compare (
174
- foo , args = (inp , False ), allowed_traces = allowed_traces )
209
+ foo , args = (inp , False ), max_different_graphs = max_different_graphs )
175
210
176
- def test_trace_limit_exceeded_children_early_stop (self ):
177
- # Test: catch trace limit exceeded error when the trace ends unexpectedly at
211
+ def test_graph_limit_exceeded_children_early_stop (self ):
212
+ # Test: catch graph limit exceeded error when the graph ends unexpectedly at
178
213
# a fork point (i.e. next operation would jump to anothe trie node).
179
214
#
180
- # Step 0: creates a trace with one node containing 3 operations.
215
+ # Step 0: creates a graph with one node containing 3 operations.
181
216
#
182
217
# Step 1: splits the node, creating 2 child nodes containing:
183
- # (i) the differring operations from the last trace , i.e. mul operation
218
+ # (i) the differring operations from the last graph , i.e. mul operation
184
219
# (ii) the current last operation, i.e. add operation
185
220
#
186
- # Step 3: at the end of this trace , it tries to turn the current trie node
187
- # into a new trace . However, it fails since we have reached the limit.
221
+ # Step 3: at the end of this graph , it tries to turn the current trie node
222
+ # into a new graph . However, it fails since we have reached the limit.
188
223
189
- allowed_traces = 2
224
+ max_different_graphs = 2
190
225
191
226
def foo (x , step ):
192
227
r = x + x
@@ -197,18 +232,22 @@ def foo(x, step):
197
232
return r
198
233
199
234
inp = torch .rand (10 , device = torch_xla .device ())
200
- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
201
- self ._run_and_compare (foo , args = (inp , 1 ), allowed_traces = allowed_traces )
202
-
203
- msg = """\
204
- .* Maximum number of different traces allowed per function exceeded: 2
205
- Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
206
- Expected either of:
207
- - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
208
- - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
209
-
210
- with self .assertRaises (RuntimeError , msg = msg ):
211
- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
235
+ self ._run_and_compare (
236
+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
237
+ self ._run_and_compare (
238
+ foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
239
+
240
+ expected_error_msg = escape (r"""
241
+ Maximum number of different graphs allowed per function exceeded: 2
242
+ Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
243
+ Expected either of:
244
+ - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
245
+ - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
246
+ """ )
247
+
248
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
249
+ self ._run_and_compare (
250
+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
212
251
213
252
214
253
if __name__ == "__main__" :
0 commit comments