Commit fb6038d

Add alternative dynamo backend (#8893)
1 parent 379ebd5 commit fb6038d

6 files changed: +140 -60 lines

test/dynamo/test_dynamo.py

+49 -44
@@ -10,6 +10,7 @@
 import torch_xla.core.xla_env_vars as xenv
 from torch_xla import runtime as xr
 import torch_xla.debug.profiler as xp
+from torch_xla._dynamo import dynamo_backend2
 import torch.optim as optim
 import torch.nn as nn
 import torch._dynamo as dynamo
@@ -38,31 +39,33 @@ def _is_on_neuron():
 skipOnNeuron = unittest.skipIf(_is_on_neuron(), 'Not supported on NEURON')


-class DynamoInPlaceTest(unittest.TestCase):
+class DynamoInPlaceTest(parameterized.TestCase):

   def inplace_update(self, a):
     a += 1
     return a

-  def test_inplace_update_correctness(self):
+  @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+  def test_inplace_update_correctness(self, backend):
     dynamo_inplace = torch.compile(
-        self.inplace_update, backend="openxla", fullgraph=True)
+        self.inplace_update, backend=backend, fullgraph=True)
     t = torch.tensor([0, 1, 2], device=xm.xla_device())
     for i in range(10):
       t = dynamo_inplace(t)
     self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12]))))


-class DynamRandomOpTest(unittest.TestCase):
+class DynamRandomOpTest(parameterized.TestCase):

   def random_op(self, a):
     return torch.randn(5, 5, device=a.device) + a

-  def test_random_op_different_result_each_run(self):
+  @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+  def test_random_op_different_result_each_run(self, backend):
     xm.wait_device_ops()
     met.clear_all()
     dynamo_random_op = torch.compile(
-        self.random_op, backend="openxla", fullgraph=True)
+        self.random_op, backend=backend, fullgraph=True)
     t = torch.randn(5, 5).to(xm.xla_device())
     dynamo_res_1 = dynamo_random_op(t)
     dynamo_res_2 = dynamo_random_op(t)
@@ -75,7 +78,7 @@ def test_random_op_different_result_each_run(self):
     self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))


-class DynamoLTCInteractionTest(unittest.TestCase):
+class DynamoLTCInteractionTest(parameterized.TestCase):

   def index_copy_inplace(self, cache, update_indices, xk):
     cache.index_copy_(0, update_indices, xk)
@@ -104,21 +107,22 @@ def test_mark_step_after_dynamo(self):
     xm.wait_device_ops()
     self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])

-  def test_copy_op(self):
+  @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+  def test_copy_op(self, backend):

     def copy_a_to_b(a):
       res = a.cos()
-      copy = torch.ops.aten.copy.default(a, res)
+      copy = torch.ops.aten.copy_.default(a, res)
       return copy

     device = torch_xla.device()
-    compiled_copy = torch.compile(copy_a_to_b, backend="openxla")
+    compiled_copy = torch.compile(copy_a_to_b, backend=backend)
     a = torch.randn(2, 9).to(device)
     res = compiled_copy(a)
     self.assertTrue(torch.allclose(res, a))


-class DynamoProfilerTest(unittest.TestCase):
+class DynamoProfilerTest(parameterized.TestCase):

   def dummy_fn(self, a):
     return torch.sin(a) + a
@@ -253,11 +257,10 @@ def fn_without_input(device):
     res_xla_dynamo = compiled_fn(device)
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

-  @parameterized.parameters(
-      True,
-      False,
-  )
-  def test_simple_model_with_in_place_ops(self, initialize_on_cuda):
+  @parameterized.product(
+      initialize_on_cuda=[True, False],
+      backend=['openxla', dynamo_backend2.dynamo_backend])
+  def test_simple_model_with_in_place_ops(self, initialize_on_cuda, backend):

     class TestModel(nn.Module):

@@ -286,7 +289,7 @@ def forward(self, index, copy_tensor, input_tensor, op_name):

     cpu_model = TestModel()
     device_model = TestModel(device).to(device)
-    compiled_model = torch.compile(device_model, backend='openxla')
+    compiled_model = torch.compile(device_model, backend=backend)

     input_tensor = torch.ones(3)
     copy_tensor = torch.rand(5, 3)
@@ -306,11 +309,10 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
           op_name=in_place_op)
       self.assertTrue(torch.allclose(res_cpu, res_device_dynamo.cpu()))

-  @parameterized.parameters(
-      True,
-      False,
-  )
-  def test_einsum(self, initialize_on_cuda):
+  @parameterized.product(
+      initialize_on_cuda=[True, False],
+      backend=['openxla', dynamo_backend2.dynamo_backend])
+  def test_einsum(self, initialize_on_cuda, backend):
     # einsum currently does not have meta function to compute the shape hence
     # will fallback to XLA with FakeTensor as input to infer the output shape.
     def einsum_mm(a, b):
@@ -321,7 +323,7 @@ def einsum_mm(a, b):
     b = torch.randn(4, 4, 4, 4).to(device)
     xm.mark_step()

-    dynamo_einsum_mm = torch.compile(einsum_mm, backend="openxla")
+    dynamo_einsum_mm = torch.compile(einsum_mm, backend=backend)
     res_device_dynamo = dynamo_einsum_mm(a, b)
     res_device_non_dynamo = einsum_mm(a, b)
     self.assertTrue(
@@ -368,11 +370,10 @@ def get_loader(self, device, sample_count, batch_size=4):

   @skipOnTpu
   @skipOnNeuron
-  @parameterized.parameters(
-      True,
-      False,
-  )
-  def test_resnet18(self, initialize_on_cuda):
+  @parameterized.product(
+      initialize_on_cuda=[True, False],
+      backend=['openxla', dynamo_backend2.dynamo_backend])
+  def test_resnet18(self, initialize_on_cuda, backend):
     device = self._choose_proper_device(initialize_on_cuda)
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
     loader = self.get_loader(device, sample_count, batch_size=4)
@@ -386,19 +387,21 @@ def test_resnet18(self, initialize_on_cuda):
     xm.mark_step()
     xm.wait_device_ops()
     met.clear_all()
-    dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
+    dynamo_resnet18 = torch.compile(device_resnet18, backend=backend)
     for data, _ in loader:
       output = dynamo_resnet18(data)
       output_cpu = resnet18(data.cpu())
       self.assertTrue(
           torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
     # We only expect one graph for the resnet18 inference.
-    self.assertEqual(met.metric_data('CompileTime')[0], 1)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
-    self.assertEqual(
-        met.metric_data('RunCachedGraphInputData')[0], sample_count)
-    self.assertEqual(
-        met.metric_data('RunCachedGraphOutputData')[0], sample_count)
+    if backend == 'openxla':
+      # backend2 doesnt populate metrics
+      self.assertEqual(met.metric_data('CompileTime')[0], 1)
+      self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
+      self.assertEqual(
+          met.metric_data('RunCachedGraphInputData')[0], sample_count)
+      self.assertEqual(
+          met.metric_data('RunCachedGraphOutputData')[0], sample_count)

   @skipOnNeuron
   def test_resnet18_lazy_vs_dynamo(self):
@@ -428,7 +431,7 @@ def test_resnet18_lazy_vs_dynamo(self):
     # mess up the counter check.


-class DynamoCpuFallbackTest(unittest.TestCase):
+class DynamoCpuFallbackTest(parameterized.TestCase):

   def test_operator_fallback(self):

@@ -509,7 +512,7 @@ def fn_fallback(t):
     self.assertEqual(met.metric_data('ExecuteTime')[0], 3)


-class DynamoTrainingBasicTest(unittest.TestCase):
+class DynamoTrainingBasicTest(parameterized.TestCase):

   @classmethod
   def setUpClass(self):
@@ -613,7 +616,7 @@ def test_resnet18(self):
         met.metric_data('RunCachedGraphOutputData')[0], sample_count * 2)


-class DynamoTrainingOptimizerTest(unittest.TestCase):
+class DynamoTrainingOptimizerTest(parameterized.TestCase):

   @classmethod
   def setUpClass(self):
@@ -719,7 +722,7 @@ def test_resnet18(self):
         met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


-class DynamoErrorMessageTest(unittest.TestCase):
+class DynamoErrorMessageTest(parameterized.TestCase):

   def test_mixed_cpu_tensor(self):
     device = xm.xla_device()
@@ -758,17 +761,18 @@ def test_all_cpu_tensor(self):
     self.assertLessEqual(len(met.counter_names()), 1)


-class DynamoOperationsTests(test_utils.XlaTestCase):
+class DynamoOperationsTest(test_utils.XlaTestCase, parameterized.TestCase):

-  def test_new_with_sizes(self):
+  @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+  def test_new_with_sizes(self, backend):

     # The addition operation is needed here, since the error only occurs when FakeTensorMode
     # checks the device of the arguments of some operation. If there's no operation using the
     # result of Tensor.new, this comparison never occurs.
     def foo(x):
       return x.new(*x.size()) + x

-    optfoo = torch.compile(backend="openxla")(foo)
+    optfoo = torch.compile(backend=backend)(foo)

     t = torch.arange(9)
     Xt = t.to(xm.xla_device())
@@ -782,12 +786,13 @@ def foo(x):
     self.assertEqual(expected.dtype, actual.dtype)
     self.assertEqual(expected.device, actual.device)

-  def test_return_expand(self):
+  @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+  def test_return_expand(self, backend):

     def foo(x):
       return x.expand(2, -1)

-    optfoo = torch.compile(backend="openxla")(foo)
+    optfoo = torch.compile(backend=backend)(foo)

     t = torch.arange(10)
     Xt = t.to(xm.xla_device())

torch_xla/_dynamo/dynamo_backend2.py

+64
@@ -0,0 +1,64 @@
+import functools
+from typing import Any
+import torch
+from torch.utils import _pytree as pytree
+from torch_xla.core import xla_builder as xb
+import torch_xla
+
+from torch._dynamo.backends.common import aot_autograd
+from functorch.compile import make_boxed_func
+
+
+def _dynamo_backend(model: torch.fx.GraphModule, sample_args: Any):
+  """A dynamo backend that compiles a FX graph to HLO using JAX and torchax.
+
+  It takes FX graph as input and returns a compiled PyTorch function. The FX graph
+  is traced into a JAX function using torchax, and the JAX function is lowered to HLO.
+
+  Args:
+    model: the graph to be compiled
+    sample_args: a tuple or list of sample inputs. I.e. model(*sample_args) produces
+      the model output
+
+  Returns:
+    Another callable f such that f(*sample_inputs) computes the same thing as model.
+  """
+
+  try:
+    import torchax.interop
+    from torchax.export import JaxInterpreter
+    import jax
+  except ImportError:
+    print('To use this dynamo backend, please install torchax')
+    raise
+
+  jax.config.update("jax_enable_x64", True)
+  env = torchax.default_env()
+  xla_device = torch_xla.device()
+
+  def run_jax(*args, initial_rng_key):
+    args_t = torchax.interop.torch_view(args)
+    env.manual_seed(initial_rng_key)
+    with env:
+      res = model(*args_t)
+    return torchax.interop.jax_view(res)
+
+  initial_rng_key = torch.tensor(0, device=xla_device, dtype=torch.uint32)
+  computation = xb.jax_func_to_xla_computation(
+      run_jax, sample_args, {'initial_rng_key': initial_rng_key}, 'dynamo_jax')
+
+  def equivalent(*args, **kwargs):
+    kwargs['initial_rng_key'] = torch.randint(
+        0, 2**32, (), dtype=torch.uint32, device=xla_device)
+    flattened, _ = pytree.tree_flatten((args, kwargs))
+    res = computation(flattened)
+    if not isinstance(res, (list, tuple)):
+      return (res,)
+    return res
+
+  return make_boxed_func(equivalent)
+
+
+def dynamo_backend(fx, args):
+  from functorch.compile import aot_function
+  return aot_function(fx, fw_compiler=_dynamo_backend)
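
The tests above exercise this backend by handing the function object straight to torch.compile, rather than a registered backend string like 'openxla'. A minimal usage sketch under the same assumptions (torchax and jax installed, an XLA device available); the toy function is illustrative, not part of the commit:

import torch
import torch_xla
from torch_xla._dynamo import dynamo_backend2

def f(x):
  return torch.sin(x) + x

device = torch_xla.device()
# Unlike 'openxla', which is a registered backend name, dynamo_backend2.dynamo_backend
# is a plain callable and is passed to torch.compile directly.
compiled_f = torch.compile(f, backend=dynamo_backend2.dynamo_backend, fullgraph=True)
out = compiled_f(torch.randn(3, 3).to(device))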

torch_xla/core/xla_builder.py

+2 -2
@@ -911,8 +911,8 @@ def get_hlo():
     import torch_xla.debug.profiler as xp
     # If we see this trace span in the profiler, we'll know that there's a cache miss.
     with xp.Trace('jax_to_hlo'):
-      hlo_ir = jax.jit(
-          fn, keep_unused=True).lower(*sample_tensor_args).compiler_ir('hlo')
+      lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args)
+      hlo_ir = lowered.compiler_ir('hlo')

     # Get a protobuf representation of the HLO. `as_serialized_hlo_module_proto` is
     # mentioned at https://github.com/jax-ml/jax/discussions/22266
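
For context, a minimal standalone sketch of the JAX lowering path this helper relies on, with a plain jax.numpy function standing in for the traced callable; the function and shapes are illustrative only:

import jax
import jax.numpy as jnp

def fn(x, y):
  return jnp.dot(x, y)

sample_tensor_args = (jnp.ones((2, 3)), jnp.ones((3, 4)))
# Lower the jitted function for concrete sample arguments, then request the HLO
# computation; its serialized HloModuleProto is what gets handed to XLA.
lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args)
hlo_ir = lowered.compiler_ir('hlo')
hlo_proto = hlo_ir.as_serialized_hlo_module_proto()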

torchax/test/test_context.py

+4 -4
@@ -39,23 +39,23 @@ def test_mode_decorator(self):

   def test_same_manual_seed(self):
     with xla_env:
-      torch.manual_seed(1234)
+      xla_env.manual_seed(1234)
       x = torch.randn((3, 3))
       self.assertIsInstance(x, tensor.Tensor)

-      torch.manual_seed(1234)
+      xla_env.manual_seed(1234)
       y = torch.randn((3, 3))
       self.assertIsInstance(y, tensor.Tensor)

       self.assertTrue(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))

   def test_different_manual_seed(self):
     with xla_env:
-      torch.manual_seed(1234)
+      xla_env.manual_seed(1234)
       x = torch.randn((3, 3))
       self.assertIsInstance(x, tensor.Tensor)

-      torch.manual_seed(12345)
+      xla_env.manual_seed(12345)
       y = torch.randn((3, 3))
       self.assertIsInstance(y, tensor.Tensor)

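These tests now seed through the torchax environment instead of torch.manual_seed, matching how run_jax in dynamo_backend2.py seeds via env.manual_seed. A minimal sketch of the pattern, assuming xla_env is the object returned by torchax.default_env() as in the surrounding test module:

import torch
import torchax

xla_env = torchax.default_env()  # assumed to match the tests' module-level xla_env

with xla_env:
  # Seeding the torchax env (not torch's global RNG) is what makes the
  # env-backed torch.randn reproducible here.
  xla_env.manual_seed(1234)
  x = torch.randn((3, 3))
  xla_env.manual_seed(1234)
  y = torch.randn((3, 3))
  # x and y should now be equal, mirroring test_same_manual_seed above.
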
torchax/torchax/ops/jaten.py

+7
@@ -1351,6 +1351,13 @@ def reduce_fn(a, b):

   return y, indices

+try:
+  @op(torch.ops.xla.max_pool2d_forward)
+  def _xla_max_pool2d_foward(*args, **kwargs):
+    return _aten_max_pool2d_with_indices(*args, **kwargs)[0]
+except AttributeError:
+  pass
+

 # TODO add more ops
