Skip to content

Commit cd38f0a

Browse files
committed
Fix segmentation fault.
- Reset the session data only after logging the session name (the log statement previously read `current_session_->name_` after `ResetSession()` had already destroyed it)
- Add the missing `return did_split;` statement in `TrieNode::MaybeSplitAt`
1 parent 689be57 commit cd38f0a

File tree

3 files changed

+36
-19
lines changed

3 files changed

+36
-19
lines changed

test/test_dynamic_shapes_detector.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import test_utils
66
import unittest
77

8-
98
# Processes a string, so that it can be used as the expected error regex.
109
# Specifically, it does 3 things:
1110
#
@@ -23,8 +22,9 @@
2322
# regex wildcards, such as "*".
2423
ESCAPE_RE = re.compile(r"([\[\](){}])")
2524

25+
2626
def escape(s):
27-
return ESCAPE_RE.sub(r"\\\1", textwrap.dedent(s[1:]))
27+
return ESCAPE_RE.sub(r"\\\1", textwrap.dedent(s[1:]))
2828

2929

3030
class TestDynamicShapeDetector(test_utils.XlaTestCase):
@@ -85,7 +85,8 @@ def foo(x):
8585
return x + x
8686

8787
inp1 = torch.rand(10, device=torch_xla.device())
88-
self._run_and_compare(foo, args=(inp1,), max_different_graphs=max_different_graphs)
88+
self._run_and_compare(
89+
foo, args=(inp1,), max_different_graphs=max_different_graphs)
8990

9091
expected_error_msg = escape(r"""
9192
Maximum number of different graphs allowed per function exceeded: 1
@@ -95,7 +96,8 @@ def foo(x):
9596

9697
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
9798
inp2 = torch.rand(5, device=torch_xla.device())
98-
self._run_and_compare(foo, args=(inp2,), max_different_graphs=max_different_graphs)
99+
self._run_and_compare(
100+
foo, args=(inp2,), max_different_graphs=max_different_graphs)
99101

100102
def test_graph_limit_exceeded_common_sequence_mismatch(self):
101103
# Test: catch graph limit exceeded error when the common sequence (i.e. compressed
@@ -117,7 +119,8 @@ def foo(x, step):
117119
return x * 5
118120

119121
inp = torch.rand(10, device=torch_xla.device())
120-
self._run_and_compare(foo, args=(inp, 0), max_different_graphs=max_different_graphs)
122+
self._run_and_compare(
123+
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
121124

122125
expected_error_msg = escape(r"""
123126
Maximum number of different graphs allowed per function exceeded: 1
@@ -126,7 +129,8 @@ def foo(x, step):
126129
""")
127130

128131
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
129-
self._run_and_compare(foo, args=(inp, 2), max_different_graphs=max_different_graphs)
132+
self._run_and_compare(
133+
foo, args=(inp, 2), max_different_graphs=max_different_graphs)
130134

131135
def test_graph_limit_exceeded_children_mismatch(self):
132136
# Test: catch graph limit exceeded error when the expected child of the trie
@@ -154,8 +158,10 @@ def foo(x, step):
154158
return r / 3
155159

156160
inp = torch.rand(10, device=torch_xla.device())
157-
self._run_and_compare(foo, args=(inp, 0), max_different_graphs=max_different_graphs)
158-
self._run_and_compare(foo, args=(inp, 1), max_different_graphs=max_different_graphs)
161+
self._run_and_compare(
162+
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
163+
self._run_and_compare(
164+
foo, args=(inp, 1), max_different_graphs=max_different_graphs)
159165

160166
expected_error_msg = escape(r"""
161167
Maximum number of different graphs allowed per function exceeded: 2
@@ -166,7 +172,8 @@ def foo(x, step):
166172
""")
167173

168174
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
169-
self._run_and_compare(foo, args=(inp, 2), max_different_graphs=max_different_graphs)
175+
self._run_and_compare(
176+
foo, args=(inp, 2), max_different_graphs=max_different_graphs)
170177

171178
def test_graph_limit_exceeded_common_sequence_early_stop(self):
172179
# Test: catch graph limit exceeded error when the graph ends unexpectedly in
@@ -188,7 +195,8 @@ def foo(x, mul=False):
188195
return r
189196

190197
inp = torch.rand(10, device=torch_xla.device())
191-
self._run_and_compare(foo, args=(inp, True), max_different_graphs=max_different_graphs)
198+
self._run_and_compare(
199+
foo, args=(inp, True), max_different_graphs=max_different_graphs)
192200

193201
expected_error_msg = escape(r"""
194202
Maximum number of different graphs allowed per function exceeded: 1
@@ -224,8 +232,10 @@ def foo(x, step):
224232
return r
225233

226234
inp = torch.rand(10, device=torch_xla.device())
227-
self._run_and_compare(foo, args=(inp, 0), max_different_graphs=max_different_graphs)
228-
self._run_and_compare(foo, args=(inp, 1), max_different_graphs=max_different_graphs)
235+
self._run_and_compare(
236+
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
237+
self._run_and_compare(
238+
foo, args=(inp, 1), max_different_graphs=max_different_graphs)
229239

230240
expected_error_msg = escape(r"""
231241
Maximum number of different graphs allowed per function exceeded: 2
@@ -236,7 +246,8 @@ def foo(x, step):
236246
""")
237247

238248
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
239-
self._run_and_compare(foo, args=(inp, 2), max_different_graphs=max_different_graphs)
249+
self._run_and_compare(
250+
foo, args=(inp, 2), max_different_graphs=max_different_graphs)
240251

241252

242253
if __name__ == "__main__":

torch_xla/csrc/dynamic_shape_detector.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ namespace torch_xla {
99
// Maximum number of allowed graphs per function (i.e. session).
1010
static std::size_t max_different_graphs = 1;
1111

12-
TrieNode::TrieNode(const TrieValue& value, bool is_graph_boundary) : TrieNode() {
12+
TrieNode::TrieNode(const TrieValue& value, bool is_graph_boundary)
13+
: TrieNode() {
1314
common_sequence_.push_back(value);
1415
is_graph_boundary_ = is_graph_boundary;
1516
}
@@ -115,6 +116,7 @@ TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched,
115116
// Maybe split the current node into: prefix (before matched) and suffix
116117
// (after matched).
117118
bool did_split = MaybeSplitAt(matched);
119+
TF_VLOG(5) << "MaybeSplitAt(" << matched << "): " << did_split;
118120

119121
// Create a new node that contains only the given value.
120122
std::unique_ptr<TrieNode> node = std::make_unique<TrieNode>(value);
@@ -132,6 +134,7 @@ TrieBuilder TrieNode::AddValue(TrieValue value, std::size_t matched,
132134
is_graph_boundary_ = false;
133135
}
134136

137+
TF_VLOG(5) << "Added value: " << value.str << " (" << value.hash << ")";
135138
return {children_[value.hash].get(), 1};
136139
}
137140

@@ -144,6 +147,8 @@ bool TrieNode::MaybeSplitAt(std::size_t matched) {
144147
common_sequence.subspan(0, /*len=*/matched);
145148
absl::Span<const TrieValue> suffix = common_sequence.subspan(matched);
146149

150+
bool did_split = false;
151+
147152
// A split only occurs if suffix is not empty.
148153
if (!suffix.empty()) {
149154
std::unique_ptr<TrieNode> suffix_node =
@@ -159,10 +164,13 @@ bool TrieNode::MaybeSplitAt(std::size_t matched) {
159164
TF_VLOG(5) << "Split node " << children_[suffix.front().hash].get()
160165
<< " at position " << matched << ": " << suffix.front().str
161166
<< " (" << suffix.front().hash << ")";
167+
168+
did_split = true;
162169
}
163170

164171
// This node's common_sequence_ will be whatever the prefix was.
165-
common_sequence_ = std::vector<TrieValue>{prefix.begin(), prefix.end()};
172+
common_sequence_.erase(common_sequence_.begin() + matched, common_sequence_.end());
173+
return did_split;
166174
}
167175

168176
DynamicShapeDetector* DynamicShapeDetector::Get() {
@@ -209,8 +217,8 @@ void DynamicShapeDetector::EndSession() {
209217
TF_VLOG(5) << "Created new graph.";
210218
}
211219

212-
ResetSession();
213220
TF_VLOG(5) << "Ended session: " << current_session_->name_;
221+
ResetSession();
214222
} catch (const std::exception& e) {
215223
// MarkGraphBoundary might raise an exception if AllowNewGraph() is false.
216224
// Catch it here, so that we can correctly end the session.

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,9 +2619,7 @@ void InitXlaModuleBindings(py::module m) {
26192619
DynamicShapeDetector::SetMaxDifferentGraphs(max_different_graphs);
26202620
});
26212621
m.def("_dynamic_shape_detector_get_max_different_graphs",
2622-
[]() {
2623-
return DynamicShapeDetector::GetMaxDifferentGraphs();
2624-
});
2622+
[]() { return DynamicShapeDetector::GetMaxDifferentGraphs(); });
26252623
m.def("_replace_xla_tensor",
26262624
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
26272625
return XLANativeFunctions::set_(self, source);

0 commit comments

Comments (0)