Skip to content

Fix + Run DynamicShapeDetector tests on CI. #9075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/test_jax_interop.py"
run_test "$CDIR/test_assume_pure.py"
run_test "$CDIR/test_assume_pure_spmd.py"
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=dynamic_shape_detector=5 run_test "$CDIR/test_dynamic_shapes_detector.py" -v
}

# All the new xla op tests should go to run_xla_op_tests3
Expand Down
197 changes: 118 additions & 79 deletions test/test_dynamic_shapes_detector.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,42 @@
import re
import textwrap
import torch
import torch_xla
import test_utils
import unittest

# Processes a string, so that it can be used as the expected error regex.
# Specifically, it does 3 things:
#
# 1. s[1:]: assumes the first character of the string is a new-line, and
# removes it.
#
# 2. textwrap.dedent(): strips the leading space in the string, allowing us
# to write more readable multi-line strings.
#
# 3. ESCAPE_RE.sub(): escapes special characters, such as parenthesis,
# brackets, and braces, so as to allow us to write more
# readable strings.
#
# Note that because of (3), we lose the "regex" part, not being able to use
# regex wildcards, such as "*".
ESCAPE_RE = re.compile(r"([\[\](){}])")


def escape(s):
return ESCAPE_RE.sub(r"\\\1", textwrap.dedent(s[1:]))


class TestDynamicShapeDetector(test_utils.XlaTestCase):

def _run_and_compare(self, f, args=None, allowed_traces=None):
def _run_and_compare(self, f, args=None, max_different_graphs=None):
"""Run f and its torch_xla.compile wrapped version, comparing the equality
of their results.

If no optf is provided, we create a new one by wrapping it with
torch_xla.compile ourselves.
"""
optf = torch_xla.compile(f, allowed_traces=allowed_traces)
optf = torch_xla.compile(f, max_different_graphs=max_different_graphs)
args = args or []

out = f(*args)
Expand All @@ -22,18 +45,18 @@ def _run_and_compare(self, f, args=None, allowed_traces=None):
self.assertEqual(out, optout)

def test_single(self):
# Test: trace a function once, when only one trace is allowed.
# Test: trace a function once, when only one graph is allowed.

def foo(x):
return x + x

inp = torch.rand(10, device=torch_xla.device())
self._run_and_compare(foo, args=(inp,), allowed_traces=1)
self._run_and_compare(foo, args=(inp,), max_different_graphs=1)

def test_many_traces(self):
# Test: multiple traces of a function.
def test_many_graphs(self):
# Test: multiple graphs of a function.
#
# Steps 0~2 and 5: create new traces.
# Steps 0~2 and 5: create new graphs.
# Steps 3 and 4: ensure we have already traced these paths.

def foo(x, step):
Expand All @@ -50,41 +73,44 @@ def foo(x, step):
inp = torch.rand(10, device=torch_xla.device())

for i in range(6):
self._run_and_compare(foo, args=(inp, i), allowed_traces=4)
self._run_and_compare(foo, args=(inp, i), max_different_graphs=4)

def test_trace_limit_exceeded_different_input_shape(self):
# Test: catch trace limit exceeded error when running the function with a
def test_graph_limit_exceeded_different_input_shape(self):
# Test: catch graph limit exceeded error when running the function with a
# function with different shape.

allowed_traces = 1
max_different_graphs = 1

def foo(x):
return x + x

inp1 = torch.rand(10, device=torch_xla.device())
self._run_and_compare(foo, args=(inp1,), allowed_traces=allowed_traces)
self._run_and_compare(
foo, args=(inp1,), max_different_graphs=max_different_graphs)

msg = """\
.* Maximum number of different traces allowed per function exceeded: 1
Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
expected_error_msg = escape(r"""
Maximum number of different graphs allowed per function exceeded: 1
Got: [] aten::add, xla_shape=f32[5]{0}, dynamic_dims: ()
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
""")

with self.assertRaises(RuntimeError, msg=msg):
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
inp2 = torch.rand(5, device=torch_xla.device())
self._run_and_compare(foo, args=(inp2,), allowed_traces=allowed_traces)
self._run_and_compare(
foo, args=(inp2,), max_different_graphs=max_different_graphs)

def test_trace_limit_exceeded_common_sequence_mismatch(self):
# Test: catch trace limit exceeded error when the common sequence (i.e. compressed
def test_graph_limit_exceeded_common_sequence_mismatch(self):
# Test: catch graph limit exceeded error when the common sequence (i.e. compressed
# path) of the trie node mismatches.
#
# Step 0: creates a trace with one node containing the add operation
# Step 0: creates a graph with one node containing the add operation
#
# Step 1: tries to create 2 child nodes with:
# (i) add operation (previous trace); and
# (i) add operation (previous graph); and
# (ii) mul operation.
# However, it fails since we have reached the limit.

allowed_traces = 1
max_different_graphs = 1

def foo(x, step):
if step == 0:
Expand All @@ -93,32 +119,35 @@ def foo(x, step):
return x * 5

inp = torch.rand(10, device=torch_xla.device())
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
self._run_and_compare(
foo, args=(inp, 0), max_different_graphs=max_different_graphs)

msg = """\
.* Maximum number of different traces allowed per function exceeded: 1
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""
expected_error_msg = escape(r"""
Maximum number of different graphs allowed per function exceeded: 1
Got: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
""")

with self.assertRaises(RuntimeError, msg=msg):
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
self._run_and_compare(
foo, args=(inp, 2), max_different_graphs=max_different_graphs)

def test_trace_limit_exceeded_children_mismatch(self):
# Test: catch trace limit exceeded error when the expected child of the trie
def test_graph_limit_exceeded_children_mismatch(self):
# Test: catch graph limit exceeded error when the expected child of the trie
# node mismatches.
#
# Step 0: creates a trace with one node containing 3 operations, the last
# Step 0: creates a graph with one node containing 3 operations, the last
# being a mul operation.
#
# Step 1: creates another trace by splitting the node, creating 2 other child
# Step 1: creates another graph by splitting the node, creating 2 other child
# nodes containing the different operations in the end:
# (i) mul operation; and
# (ii) add operation.
#
# Step 2: tries to create a 3rd child node: div operation. However, we can't
# do it, since we have reached the limit.

allowed_traces = 2
max_different_graphs = 2

def foo(x, step):
r = x + x
Expand All @@ -129,30 +158,34 @@ def foo(x, step):
return r / 3

inp = torch.rand(10, device=torch_xla.device())
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)

msg = """\
.* Maximum number of different traces allowed per function exceeded: 2
Got: [] aten::expand, xla_shape=f32[10]{0}, dynamic_dims: (), size=(10)
Expected either of:
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""

with self.assertRaises(RuntimeError, msg=msg):
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
self._run_and_compare(
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
self._run_and_compare(
foo, args=(inp, 1), max_different_graphs=max_different_graphs)

expected_error_msg = escape(r"""
Maximum number of different graphs allowed per function exceeded: 2
Got: [] aten::div, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected either of:
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
""")

with self.assertRaisesRegex(RuntimeError, expected_error_msg):
self._run_and_compare(
foo, args=(inp, 2), max_different_graphs=max_different_graphs)

def test_trace_limit_exceeded_common_sequence_early_stop(self):
# Test: catch trace limit exceeded error when the trace ends unexpectedly in
def test_graph_limit_exceeded_common_sequence_early_stop(self):
# Test: catch graph limit exceeded error when the graph ends unexpectedly in
# the common sequence.
#
# Step 0: creates a trace with one node containing 3 operations.
# Step 0: creates a graph with one node containing 3 operations.
#
# Step 1: at the end of this trace, it tries to create a new node containing
# the remaining operations of the previous trace, i.e. mul operation. However,
# Step 1: at the end of this graph, it tries to create a new node containing
# the remaining operations of the previous graph, i.e. mul operation. However,
# it fails because we have reached the limit.

allowed_traces = 1
max_different_graphs = 1

def foo(x, mul=False):
r = x + x
Expand All @@ -162,31 +195,33 @@ def foo(x, mul=False):
return r

inp = torch.rand(10, device=torch_xla.device())
self._run_and_compare(foo, args=(inp, True), allowed_traces=allowed_traces)
self._run_and_compare(
foo, args=(inp, True), max_different_graphs=max_different_graphs)

msg = """\
.* Maximum number of different traces allowed per function exceeded: 1
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()"""
expected_error_msg = escape(r"""
Maximum number of different graphs allowed per function exceeded: 1
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected: [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
""")

with self.assertRaises(RuntimeError, msg=msg):
with self.assertRaisesRegex(RuntimeError, expected_error_msg):
self._run_and_compare(
foo, args=(inp, False), allowed_traces=allowed_traces)
foo, args=(inp, False), max_different_graphs=max_different_graphs)

def test_trace_limit_exceeded_children_early_stop(self):
# Test: catch trace limit exceeded error when the trace ends unexpectedly at
def test_graph_limit_exceeded_children_early_stop(self):
# Test: catch graph limit exceeded error when the graph ends unexpectedly at
# a fork point (i.e. next operation would jump to anothe trie node).
#
# Step 0: creates a trace with one node containing 3 operations.
# Step 0: creates a graph with one node containing 3 operations.
#
# Step 1: splits the node, creating 2 child nodes containing:
# (i) the differring operations from the last trace, i.e. mul operation
# (i) the differring operations from the last graph, i.e. mul operation
# (ii) the current last operation, i.e. add operation
#
# Step 3: at the end of this trace, it tries to turn the current trie node
# into a new trace. However, it fails since we have reached the limit.
# Step 3: at the end of this graph, it tries to turn the current trie node
# into a new graph. However, it fails since we have reached the limit.

allowed_traces = 2
max_different_graphs = 2

def foo(x, step):
r = x + x
Expand All @@ -197,18 +232,22 @@ def foo(x, step):
return r

inp = torch.rand(10, device=torch_xla.device())
self._run_and_compare(foo, args=(inp, 0), allowed_traces=allowed_traces)
self._run_and_compare(foo, args=(inp, 1), allowed_traces=allowed_traces)

msg = """\
.* Maximum number of different traces allowed per function exceeded: 2
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected either of:
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()"""

with self.assertRaises(RuntimeError, msg=msg):
self._run_and_compare(foo, args=(inp, 2), allowed_traces=allowed_traces)
self._run_and_compare(
foo, args=(inp, 0), max_different_graphs=max_different_graphs)
self._run_and_compare(
foo, args=(inp, 1), max_different_graphs=max_different_graphs)

expected_error_msg = escape(r"""
Maximum number of different graphs allowed per function exceeded: 2
Reached the end of the function at: [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
Expected either of:
- [] aten::mul, xla_shape=f32[10]{0}, dynamic_dims: ()
- [] aten::add, xla_shape=f32[10]{0}, dynamic_dims: ()
""")

with self.assertRaisesRegex(RuntimeError, expected_error_msg):
self._run_and_compare(
foo, args=(inp, 2), max_different_graphs=max_different_graphs)


if __name__ == "__main__":
Expand Down
Loading
Loading