1
+ import re
2
+ import textwrap
1
3
import torch
2
4
import torch_xla
3
5
import test_utils
4
6
import unittest
5
7
6
8
9
+ # Processes a string, so that it can be used as the expected error regex.
10
+ # Specifically, it does 3 things:
11
+ #
12
+ # 1. s[1:]: assumes the first character of the string is a new-line, and
13
+ # removes it.
14
+ #
15
+ # 2. textwrap.dedent(): strips the leading space in the string, allowing us
16
+ # to write more readable multi-line strings.
17
+ #
18
+ # 3. ESCAPE_RE.sub(): escapes special characters, such as parenthesis,
19
+ # brackets, and braces, so as to allow us to write more
20
+ # readable strings.
21
+ #
22
+ # Note that because of (3), we lose the "regex" part, not being able to use
23
+ # regex wildcards, such as "*".
24
+ ESCAPE_RE = re .compile (r"([\[\](){}])" )
25
+
26
+ def escape (s ):
27
+ return ESCAPE_RE .sub (r"\\\1" , textwrap .dedent (s [1 :]))
28
+
29
+
7
30
class TestDynamicShapeDetector (test_utils .XlaTestCase ):
8
31
9
- def _run_and_compare (self , f , args = None , allowed_traces = None ):
32
+ def _run_and_compare (self , f , args = None , max_different_graphs = None ):
10
33
"""Run f and its torch_xla.compile wrapped version, comparing the equality
11
34
of their results.
12
35
13
36
If no optf is provided, we create a new one by wrapping it with
14
37
torch_xla.compile ourselves.
15
38
"""
16
- optf = torch_xla .compile (f , allowed_traces = allowed_traces )
39
+ optf = torch_xla .compile (f , max_different_graphs = max_different_graphs )
17
40
args = args or []
18
41
19
42
out = f (* args )
@@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None):
22
45
self .assertEqual (out , optout )
23
46
24
47
def test_single (self ):
25
- # Test: trace a function once, when only one trace is allowed.
48
+ # Test: trace a function once, when only one graph is allowed.
26
49
27
50
def foo (x ):
28
51
return x + x
29
52
30
53
inp = torch .rand (10 , device = torch_xla .device ())
31
- self ._run_and_compare (foo , args = (inp ,), allowed_traces = 1 )
54
+ self ._run_and_compare (foo , args = (inp ,), max_different_graphs = 1 )
32
55
33
- def test_many_traces (self ):
34
- # Test: multiple traces of a function.
56
+ def test_many_graphs (self ):
57
+ # Test: multiple graphs of a function.
35
58
#
36
- # Steps 0~2 and 5: create new traces .
59
+ # Steps 0~2 and 5: create new graphs .
37
60
# Steps 3 and 4: ensure we have already traced these paths.
38
61
39
62
def foo (x , step ):
@@ -50,41 +73,42 @@ def foo(x, step):
50
73
inp = torch .rand (10 , device = torch_xla .device ())
51
74
52
75
for i in range (6 ):
53
- self ._run_and_compare (foo , args = (inp , i ), allowed_traces = 4 )
76
+ self ._run_and_compare (foo , args = (inp , i ), max_different_graphs = 4 )
54
77
55
- def test_trace_limit_exceeded_different_input_shape (self ):
56
- # Test: catch trace limit exceeded error when running the function with a
78
+ def test_graph_limit_exceeded_different_input_shape (self ):
79
+ # Test: catch graph limit exceeded error when running the function with a
57
80
# function with different shape.
58
81
59
- allowed_traces = 1
82
+ max_different_graphs = 1
60
83
61
84
def foo (x ):
62
85
return x + x
63
86
64
87
inp1 = torch .rand (10 , device = torch_xla .device ())
65
- self ._run_and_compare (foo , args = (inp1 ,), allowed_traces = allowed_traces )
88
+ self ._run_and_compare (foo , args = (inp1 ,), max_different_graphs = max_different_graphs )
66
89
67
- msg = """\
68
- .* Maximum number of different traces allowed per function exceeded: 1
69
- Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
70
- Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
90
+ expected_error_msg = escape (r"""
91
+ Maximum number of different graphs allowed per function exceeded: 1
92
+ Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: ()
93
+ Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
94
+ """ )
71
95
72
- with self .assertRaises (RuntimeError , msg = msg ):
96
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
73
97
inp2 = torch .rand (5 , device = torch_xla .device ())
74
- self ._run_and_compare (foo , args = (inp2 ,), allowed_traces = allowed_traces )
98
+ self ._run_and_compare (foo , args = (inp2 ,), max_different_graphs = max_different_graphs )
75
99
76
- def test_trace_limit_exceeded_common_sequence_mismatch (self ):
77
- # Test: catch trace limit exceeded error when the common sequence (i.e. compressed
100
+ def test_graph_limit_exceeded_common_sequence_mismatch (self ):
101
+ # Test: catch graph limit exceeded error when the common sequence (i.e. compressed
78
102
# path) of the trie node mismatches.
79
103
#
80
- # Step 0: creates a trace with one node containing the add operation
104
+ # Step 0: creates a graph with one node containing the add operation
81
105
#
82
106
# Step 1: tries to create 2 child nodes with:
83
- # (i) add operation (previous trace ); and
107
+ # (i) add operation (previous graph ); and
84
108
# (ii) mul operation.
85
109
# However, it fails since we have reached the limit.
86
110
87
- allowed_traces = 1
111
+ max_different_graphs = 1
88
112
89
113
def foo (x , step ):
90
114
if step == 0 :
@@ -93,32 +117,33 @@ def foo(x, step):
93
117
return x * 5
94
118
95
119
inp = torch .rand (10 , device = torch_xla .device ())
96
- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
120
+ self ._run_and_compare (foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
97
121
98
- msg = """\
99
- .* Maximum number of different traces allowed per function exceeded: 1
100
- Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
101
- Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
122
+ expected_error_msg = escape (r"""
123
+ Maximum number of different graphs allowed per function exceeded: 1
124
+ Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
125
+ Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
126
+ """ )
102
127
103
- with self .assertRaises (RuntimeError , msg = msg ):
104
- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
128
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
129
+ self ._run_and_compare (foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
105
130
106
- def test_trace_limit_exceeded_children_mismatch (self ):
107
- # Test: catch trace limit exceeded error when the expected child of the trie
131
+ def test_graph_limit_exceeded_children_mismatch (self ):
132
+ # Test: catch graph limit exceeded error when the expected child of the trie
108
133
# node mismatches.
109
134
#
110
- # Step 0: creates a trace with one node containing 3 operations, the last
135
+ # Step 0: creates a graph with one node containing 3 operations, the last
111
136
# being a mul operation.
112
137
#
113
- # Step 1: creates another trace by splitting the node, creating 2 other child
138
+ # Step 1: creates another graph by splitting the node, creating 2 other child
114
139
# nodes containing the different operations in the end:
115
140
# (i) mul operation; and
116
141
# (ii) add operation.
117
142
#
118
143
# Step 2: tries to create a 3rd child node: div operation. However, we can't
119
144
# do it, since we have reached the limit.
120
145
121
- allowed_traces = 2
146
+ max_different_graphs = 2
122
147
123
148
def foo (x , step ):
124
149
r = x + x
@@ -129,30 +154,31 @@ def foo(x, step):
129
154
return r / 3
130
155
131
156
inp = torch .rand (10 , device = torch_xla .device ())
132
- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
133
- self ._run_and_compare (foo , args = (inp , 1 ), allowed_traces = allowed_traces )
134
-
135
- msg = """\
136
- .* Maximum number of different traces allowed per function exceeded: 2
137
- Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
138
- Expected either of:
139
- - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
140
- - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
141
-
142
- with self .assertRaises (RuntimeError , msg = msg ):
143
- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
144
-
145
- def test_trace_limit_exceeded_common_sequence_early_stop (self ):
146
- # Test: catch trace limit exceeded error when the trace ends unexpectedly in
157
+ self ._run_and_compare (foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
158
+ self ._run_and_compare (foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
159
+
160
+ expected_error_msg = escape (r"""
161
+ Maximum number of different graphs allowed per function exceeded: 2
162
+ Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: ()
163
+ Expected either of:
164
+ - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
165
+ - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
166
+ """ )
167
+
168
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
169
+ self ._run_and_compare (foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
170
+
171
+ def test_graph_limit_exceeded_common_sequence_early_stop (self ):
172
+ # Test: catch graph limit exceeded error when the graph ends unexpectedly in
147
173
# the common sequence.
148
174
#
149
- # Step 0: creates a trace with one node containing 3 operations.
175
+ # Step 0: creates a graph with one node containing 3 operations.
150
176
#
151
- # Step 1: at the end of this trace , it tries to create a new node containing
152
- # the remaining operations of the previous trace , i.e. mul operation. However,
177
+ # Step 1: at the end of this graph , it tries to create a new node containing
178
+ # the remaining operations of the previous graph , i.e. mul operation. However,
153
179
# it fails because we have reached the limit.
154
180
155
- allowed_traces = 1
181
+ max_different_graphs = 1
156
182
157
183
def foo (x , mul = False ):
158
184
r = x + x
@@ -162,31 +188,32 @@ def foo(x, mul=False):
162
188
return r
163
189
164
190
inp = torch .rand (10 , device = torch_xla .device ())
165
- self ._run_and_compare (foo , args = (inp , True ), allowed_traces = allowed_traces )
191
+ self ._run_and_compare (foo , args = (inp , True ), max_different_graphs = max_different_graphs )
166
192
167
- msg = """\
168
- .* Maximum number of different traces allowed per function exceeded: 1
169
- Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
170
- Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
193
+ expected_error_msg = escape (r"""
194
+ Maximum number of different graphs allowed per function exceeded: 1
195
+ Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
196
+ Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
197
+ """ )
171
198
172
- with self .assertRaises (RuntimeError , msg = msg ):
199
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
173
200
self ._run_and_compare (
174
- foo , args = (inp , False ), allowed_traces = allowed_traces )
201
+ foo , args = (inp , False ), max_different_graphs = max_different_graphs )
175
202
176
- def test_trace_limit_exceeded_children_early_stop (self ):
177
- # Test: catch trace limit exceeded error when the trace ends unexpectedly at
203
+ def test_graph_limit_exceeded_children_early_stop (self ):
204
+ # Test: catch graph limit exceeded error when the graph ends unexpectedly at
178
205
# a fork point (i.e. next operation would jump to anothe trie node).
179
206
#
180
- # Step 0: creates a trace with one node containing 3 operations.
207
+ # Step 0: creates a graph with one node containing 3 operations.
181
208
#
182
209
# Step 1: splits the node, creating 2 child nodes containing:
183
- # (i) the differring operations from the last trace , i.e. mul operation
210
+ # (i) the differring operations from the last graph , i.e. mul operation
184
211
# (ii) the current last operation, i.e. add operation
185
212
#
186
- # Step 3: at the end of this trace , it tries to turn the current trie node
187
- # into a new trace . However, it fails since we have reached the limit.
213
+ # Step 3: at the end of this graph , it tries to turn the current trie node
214
+ # into a new graph . However, it fails since we have reached the limit.
188
215
189
- allowed_traces = 2
216
+ max_different_graphs = 2
190
217
191
218
def foo (x , step ):
192
219
r = x + x
@@ -197,18 +224,19 @@ def foo(x, step):
197
224
return r
198
225
199
226
inp = torch .rand (10 , device = torch_xla .device ())
200
- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
201
- self ._run_and_compare (foo , args = (inp , 1 ), allowed_traces = allowed_traces )
202
-
203
- msg = """\
204
- .* Maximum number of different traces allowed per function exceeded: 2
205
- Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
206
- Expected either of:
207
- - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
208
- - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
209
-
210
- with self .assertRaises (RuntimeError , msg = msg ):
211
- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
227
+ self ._run_and_compare (foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
228
+ self ._run_and_compare (foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
229
+
230
+ expected_error_msg = escape (r"""
231
+ Maximum number of different graphs allowed per function exceeded: 2
232
+ Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
233
+ Expected either of:
234
+ - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
235
+ - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
236
+ """ )
237
+
238
+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
239
+ self ._run_and_compare (foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
212
240
213
241
214
242
if __name__ == "__main__" :
0 commit comments