1+ import re
2+ import textwrap
13import torch
24import torch_xla
35import test_utils
46import unittest
57
8+ # Processes a string, so that it can be used as the expected error regex.
9+ # Specifically, it does 3 things:
10+ #
11+ # 1. s[1:]: assumes the first character of the string is a new-line, and
12+ # removes it.
13+ #
14+ # 2. textwrap.dedent(): strips the leading space in the string, allowing us
15+ # to write more readable multi-line strings.
16+ #
17+ # 3. ESCAPE_RE.sub(): escapes special characters, such as parenthesis,
18+ # brackets, and braces, so as to allow us to write more
19+ # readable strings.
20+ #
21+ # Note that because of (3), we lose the "regex" part, not being able to use
22+ # regex wildcards, such as "*".
23+ ESCAPE_RE = re .compile (r"([\[\](){}])" )
24+
25+
26+ def escape (s ):
27+ return ESCAPE_RE .sub (r"\\\1" , textwrap .dedent (s [1 :]))
28+
629
730class TestDynamicShapeDetector (test_utils .XlaTestCase ):
831
9- def _run_and_compare (self , f , args = None , allowed_traces = None ):
32+ def _run_and_compare (self , f , args = None , max_different_graphs = None ):
1033 """Run f and its torch_xla.compile wrapped version, comparing the equality
1134 of their results.
1235
1336 If no optf is provided, we create a new one by wrapping it with
1437 torch_xla.compile ourselves.
1538 """
16- optf = torch_xla .compile (f , allowed_traces = allowed_traces )
39+ optf = torch_xla .compile (f , max_different_graphs = max_different_graphs )
1740 args = args or []
1841
1942 out = f (* args )
@@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None):
2245 self .assertEqual (out , optout )
2346
2447 def test_single (self ):
25- # Test: trace a function once, when only one trace is allowed.
48+ # Test: trace a function once, when only one graph is allowed.
2649
2750 def foo (x ):
2851 return x + x
2952
3053 inp = torch .rand (10 , device = torch_xla .device ())
31- self ._run_and_compare (foo , args = (inp ,), allowed_traces = 1 )
54+ self ._run_and_compare (foo , args = (inp ,), max_different_graphs = 1 )
3255
33- def test_many_traces (self ):
34- # Test: multiple traces of a function.
56+ def test_many_graphs (self ):
57+ # Test: multiple graphs of a function.
3558 #
36- # Steps 0~2 and 5: create new traces .
59+ # Steps 0~2 and 5: create new graphs .
3760 # Steps 3 and 4: ensure we have already traced these paths.
3861
3962 def foo (x , step ):
@@ -50,41 +73,44 @@ def foo(x, step):
5073 inp = torch .rand (10 , device = torch_xla .device ())
5174
5275 for i in range (6 ):
53- self ._run_and_compare (foo , args = (inp , i ), allowed_traces = 4 )
76+ self ._run_and_compare (foo , args = (inp , i ), max_different_graphs = 4 )
5477
55- def test_trace_limit_exceeded_different_input_shape (self ):
56- # Test: catch trace limit exceeded error when running the function with a
78+ def test_graph_limit_exceeded_different_input_shape (self ):
79+ # Test: catch graph limit exceeded error when running the function with a
5780 # function with different shape.
5881
59- allowed_traces = 1
82+ max_different_graphs = 1
6083
6184 def foo (x ):
6285 return x + x
6386
6487 inp1 = torch .rand (10 , device = torch_xla .device ())
65- self ._run_and_compare (foo , args = (inp1 ,), allowed_traces = allowed_traces )
88+ self ._run_and_compare (
89+ foo , args = (inp1 ,), max_different_graphs = max_different_graphs )
6690
67- msg = """\
68- .* Maximum number of different traces allowed per function exceeded: 1
69- Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
70- Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
91+ expected_error_msg = escape (r"""
92+ Maximum number of different graphs allowed per function exceeded: 1
93+ Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: ()
94+ Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
95+ """ )
7196
72- with self .assertRaises (RuntimeError , msg = msg ):
97+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
7398 inp2 = torch .rand (5 , device = torch_xla .device ())
74- self ._run_and_compare (foo , args = (inp2 ,), allowed_traces = allowed_traces )
99+ self ._run_and_compare (
100+ foo , args = (inp2 ,), max_different_graphs = max_different_graphs )
75101
76- def test_trace_limit_exceeded_common_sequence_mismatch (self ):
77- # Test: catch trace limit exceeded error when the common sequence (i.e. compressed
102+ def test_graph_limit_exceeded_common_sequence_mismatch (self ):
103+ # Test: catch graph limit exceeded error when the common sequence (i.e. compressed
78104 # path) of the trie node mismatches.
79105 #
80- # Step 0: creates a trace with one node containing the add operation
106+ # Step 0: creates a graph with one node containing the add operation
81107 #
82108 # Step 1: tries to create 2 child nodes with:
83- # (i) add operation (previous trace ); and
109+ # (i) add operation (previous graph ); and
84110 # (ii) mul operation.
85111 # However, it fails since we have reached the limit.
86112
87- allowed_traces = 1
113+ max_different_graphs = 1
88114
89115 def foo (x , step ):
90116 if step == 0 :
@@ -93,32 +119,35 @@ def foo(x, step):
93119 return x * 5
94120
95121 inp = torch .rand (10 , device = torch_xla .device ())
96- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
122+ self ._run_and_compare (
123+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
97124
98- msg = """\
99- .* Maximum number of different traces allowed per function exceeded: 1
100- Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
101- Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
125+ expected_error_msg = escape (r"""
126+ Maximum number of different graphs allowed per function exceeded: 1
127+ Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
128+ Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
129+ """ )
102130
103- with self .assertRaises (RuntimeError , msg = msg ):
104- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
131+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
132+ self ._run_and_compare (
133+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
105134
106- def test_trace_limit_exceeded_children_mismatch (self ):
107- # Test: catch trace limit exceeded error when the expected child of the trie
135+ def test_graph_limit_exceeded_children_mismatch (self ):
136+ # Test: catch graph limit exceeded error when the expected child of the trie
108137 # node mismatches.
109138 #
110- # Step 0: creates a trace with one node containing 3 operations, the last
139+ # Step 0: creates a graph with one node containing 3 operations, the last
111140 # being a mul operation.
112141 #
113- # Step 1: creates another trace by splitting the node, creating 2 other child
142+ # Step 1: creates another graph by splitting the node, creating 2 other child
114143 # nodes containing the different operations in the end:
115144 # (i) mul operation; and
116145 # (ii) add operation.
117146 #
118147 # Step 2: tries to create a 3rd child node: div operation. However, we can't
119148 # do it, since we have reached the limit.
120149
121- allowed_traces = 2
150+ max_different_graphs = 2
122151
123152 def foo (x , step ):
124153 r = x + x
@@ -129,30 +158,34 @@ def foo(x, step):
129158 return r / 3
130159
131160 inp = torch .rand (10 , device = torch_xla .device ())
132- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
133- self ._run_and_compare (foo , args = (inp , 1 ), allowed_traces = allowed_traces )
134-
135- msg = """\
136- .* Maximum number of different traces allowed per function exceeded: 2
137- Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
138- Expected either of:
139- - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
140- - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
141-
142- with self .assertRaises (RuntimeError , msg = msg ):
143- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
161+ self ._run_and_compare (
162+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
163+ self ._run_and_compare (
164+ foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
165+
166+ expected_error_msg = escape (r"""
167+ Maximum number of different graphs allowed per function exceeded: 2
168+ Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: ()
169+ Expected either of:
170+ - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
171+ - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
172+ """ )
173+
174+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
175+ self ._run_and_compare (
176+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
144177
145- def test_trace_limit_exceeded_common_sequence_early_stop (self ):
146- # Test: catch trace limit exceeded error when the trace ends unexpectedly in
178+ def test_graph_limit_exceeded_common_sequence_early_stop (self ):
179+ # Test: catch graph limit exceeded error when the graph ends unexpectedly in
147180 # the common sequence.
148181 #
149- # Step 0: creates a trace with one node containing 3 operations.
182+ # Step 0: creates a graph with one node containing 3 operations.
150183 #
151- # Step 1: at the end of this trace , it tries to create a new node containing
152- # the remaining operations of the previous trace , i.e. mul operation. However,
184+ # Step 1: at the end of this graph , it tries to create a new node containing
185+ # the remaining operations of the previous graph , i.e. mul operation. However,
153186 # it fails because we have reached the limit.
154187
155- allowed_traces = 1
188+ max_different_graphs = 1
156189
157190 def foo (x , mul = False ):
158191 r = x + x
@@ -162,31 +195,33 @@ def foo(x, mul=False):
162195 return r
163196
164197 inp = torch .rand (10 , device = torch_xla .device ())
165- self ._run_and_compare (foo , args = (inp , True ), allowed_traces = allowed_traces )
198+ self ._run_and_compare (
199+ foo , args = (inp , True ), max_different_graphs = max_different_graphs )
166200
167- msg = """\
168- .* Maximum number of different traces allowed per function exceeded: 1
169- Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
170- Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
201+ expected_error_msg = escape (r"""
202+ Maximum number of different graphs allowed per function exceeded: 1
203+ Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
204+ Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
205+ """ )
171206
172- with self .assertRaises (RuntimeError , msg = msg ):
207+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
173208 self ._run_and_compare (
174- foo , args = (inp , False ), allowed_traces = allowed_traces )
209+ foo , args = (inp , False ), max_different_graphs = max_different_graphs )
175210
176- def test_trace_limit_exceeded_children_early_stop (self ):
177- # Test: catch trace limit exceeded error when the trace ends unexpectedly at
211+ def test_graph_limit_exceeded_children_early_stop (self ):
212+ # Test: catch graph limit exceeded error when the graph ends unexpectedly at
178213 # a fork point (i.e. next operation would jump to anothe trie node).
179214 #
180- # Step 0: creates a trace with one node containing 3 operations.
215+ # Step 0: creates a graph with one node containing 3 operations.
181216 #
182217 # Step 1: splits the node, creating 2 child nodes containing:
183- # (i) the differring operations from the last trace , i.e. mul operation
218+ # (i) the differring operations from the last graph , i.e. mul operation
184219 # (ii) the current last operation, i.e. add operation
185220 #
186- # Step 3: at the end of this trace , it tries to turn the current trie node
187- # into a new trace . However, it fails since we have reached the limit.
221+ # Step 3: at the end of this graph , it tries to turn the current trie node
222+ # into a new graph . However, it fails since we have reached the limit.
188223
189- allowed_traces = 2
224+ max_different_graphs = 2
190225
191226 def foo (x , step ):
192227 r = x + x
@@ -197,18 +232,22 @@ def foo(x, step):
197232 return r
198233
199234 inp = torch .rand (10 , device = torch_xla .device ())
200- self ._run_and_compare (foo , args = (inp , 0 ), allowed_traces = allowed_traces )
201- self ._run_and_compare (foo , args = (inp , 1 ), allowed_traces = allowed_traces )
202-
203- msg = """\
204- .* Maximum number of different traces allowed per function exceeded: 2
205- Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
206- Expected either of:
207- - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
208- - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
209-
210- with self .assertRaises (RuntimeError , msg = msg ):
211- self ._run_and_compare (foo , args = (inp , 2 ), allowed_traces = allowed_traces )
235+ self ._run_and_compare (
236+ foo , args = (inp , 0 ), max_different_graphs = max_different_graphs )
237+ self ._run_and_compare (
238+ foo , args = (inp , 1 ), max_different_graphs = max_different_graphs )
239+
240+ expected_error_msg = escape (r"""
241+ Maximum number of different graphs allowed per function exceeded: 2
242+ Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
243+ Expected either of:
244+ - [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
245+ - [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
246+ """ )
247+
248+ with self .assertRaisesRegex (RuntimeError , expected_error_msg ):
249+ self ._run_and_compare (
250+ foo , args = (inp , 2 ), max_different_graphs = max_different_graphs )
212251
213252
214253if __name__ == "__main__" :
0 commit comments