import torch_xla.core.xla_env_vars as xenv
from torch_xla import runtime as xr
import torch_xla.debug.profiler as xp
+ from torch_xla._dynamo import dynamo_backend2
import torch.optim as optim
import torch.nn as nn
import torch._dynamo as dynamo
@@ -38,31 +39,33 @@ def _is_on_neuron():
skipOnNeuron = unittest.skipIf(_is_on_neuron(), 'Not supported on NEURON')


- class DynamoInPlaceTest(unittest.TestCase):
+ class DynamoInPlaceTest(parameterized.TestCase):

  def inplace_update(self, a):
    a += 1
    return a

-   def test_inplace_update_correctness(self):
+   @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+   def test_inplace_update_correctness(self, backend):
    dynamo_inplace = torch.compile(
-         self.inplace_update, backend="openxla", fullgraph=True)
+         self.inplace_update, backend=backend, fullgraph=True)
    t = torch.tensor([0, 1, 2], device=xm.xla_device())
    for i in range(10):
      t = dynamo_inplace(t)
    self.assertTrue(torch.all(torch.eq(t.cpu(), torch.tensor([10, 11, 12]))))


- class DynamRandomOpTest(unittest.TestCase):
+ class DynamRandomOpTest(parameterized.TestCase):

  def random_op(self, a):
    return torch.randn(5, 5, device=a.device) + a

-   def test_random_op_different_result_each_run(self):
+   @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+   def test_random_op_different_result_each_run(self, backend):
    xm.wait_device_ops()
    met.clear_all()
    dynamo_random_op = torch.compile(
-         self.random_op, backend="openxla", fullgraph=True)
+         self.random_op, backend=backend, fullgraph=True)
    t = torch.randn(5, 5).to(xm.xla_device())
    dynamo_res_1 = dynamo_random_op(t)
    dynamo_res_2 = dynamo_random_op(t)
@@ -75,7 +78,7 @@ def test_random_op_different_result_each_run(self):
    self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))


- class DynamoLTCInteractionTest(unittest.TestCase):
+ class DynamoLTCInteractionTest(parameterized.TestCase):

  def index_copy_inplace(self, cache, update_indices, xk):
    cache.index_copy_(0, update_indices, xk)
@@ -104,21 +107,22 @@ def test_mark_step_after_dynamo(self):
    xm.wait_device_ops()
    self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])

-   def test_copy_op(self):
+   @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+   def test_copy_op(self, backend):

    def copy_a_to_b(a):
      res = a.cos()
-       copy = torch.ops.aten.copy.default(a, res)
+       copy = torch.ops.aten.copy_.default(a, res)
      return copy

    device = torch_xla.device()
-     compiled_copy = torch.compile(copy_a_to_b, backend="openxla")
+     compiled_copy = torch.compile(copy_a_to_b, backend=backend)
    a = torch.randn(2, 9).to(device)
    res = compiled_copy(a)
    self.assertTrue(torch.allclose(res, a))


- class DynamoProfilerTest(unittest.TestCase):
+ class DynamoProfilerTest(parameterized.TestCase):

  def dummy_fn(self, a):
    return torch.sin(a) + a
@@ -253,11 +257,10 @@ def fn_without_input(device):
    res_xla_dynamo = compiled_fn(device)
    self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

-   @parameterized.parameters(
-       True,
-       False,
-   )
-   def test_simple_model_with_in_place_ops(self, initialize_on_cuda):
+   @parameterized.product(
+       initialize_on_cuda=[True, False],
+       backend=['openxla', dynamo_backend2.dynamo_backend])
+   def test_simple_model_with_in_place_ops(self, initialize_on_cuda, backend):

    class TestModel(nn.Module):

@@ -286,7 +289,7 @@ def forward(self, index, copy_tensor, input_tensor, op_name):

    cpu_model = TestModel()
    device_model = TestModel(device).to(device)
-     compiled_model = torch.compile(device_model, backend='openxla')
+     compiled_model = torch.compile(device_model, backend=backend)

    input_tensor = torch.ones(3)
    copy_tensor = torch.rand(5, 3)
@@ -306,11 +309,10 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
          op_name=in_place_op)
      self.assertTrue(torch.allclose(res_cpu, res_device_dynamo.cpu()))

-   @parameterized.parameters(
-       True,
-       False,
-   )
-   def test_einsum(self, initialize_on_cuda):
+   @parameterized.product(
+       initialize_on_cuda=[True, False],
+       backend=['openxla', dynamo_backend2.dynamo_backend])
+   def test_einsum(self, initialize_on_cuda, backend):
    # einsum currently does not have meta function to compute the shape hence
    # will fallback to XLA with FakeTensor as input to infer the output shape.
    def einsum_mm(a, b):
@@ -321,7 +323,7 @@ def einsum_mm(a, b):
    b = torch.randn(4, 4, 4, 4).to(device)
    xm.mark_step()

-     dynamo_einsum_mm = torch.compile(einsum_mm, backend="openxla")
+     dynamo_einsum_mm = torch.compile(einsum_mm, backend=backend)
    res_device_dynamo = dynamo_einsum_mm(a, b)
    res_device_non_dynamo = einsum_mm(a, b)
    self.assertTrue(
@@ -368,11 +370,10 @@ def get_loader(self, device, sample_count, batch_size=4):

  @skipOnTpu
  @skipOnNeuron
-   @parameterized.parameters(
-       True,
-       False,
-   )
-   def test_resnet18(self, initialize_on_cuda):
+   @parameterized.product(
+       initialize_on_cuda=[True, False],
+       backend=['openxla', dynamo_backend2.dynamo_backend])
+   def test_resnet18(self, initialize_on_cuda, backend):
    device = self._choose_proper_device(initialize_on_cuda)
    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    loader = self.get_loader(device, sample_count, batch_size=4)
@@ -386,19 +387,21 @@ def test_resnet18(self, initialize_on_cuda):
    xm.mark_step()
    xm.wait_device_ops()
    met.clear_all()
-     dynamo_resnet18 = torch.compile(device_resnet18, backend='openxla')
+     dynamo_resnet18 = torch.compile(device_resnet18, backend=backend)
    for data, _ in loader:
      output = dynamo_resnet18(data)
      output_cpu = resnet18(data.cpu())
      self.assertTrue(
          torch.allclose(output_cpu, output.cpu(), rtol=1e-05, atol=1e-05))
    # We only expect one graph for the resnet18 inference.
-     self.assertEqual(met.metric_data('CompileTime')[0], 1)
-     self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
-     self.assertEqual(
-         met.metric_data('RunCachedGraphInputData')[0], sample_count)
-     self.assertEqual(
-         met.metric_data('RunCachedGraphOutputData')[0], sample_count)
+     if backend == 'openxla':
+       # backend2 doesn't populate these metrics
+       self.assertEqual(met.metric_data('CompileTime')[0], 1)
+       self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
+       self.assertEqual(
+           met.metric_data('RunCachedGraphInputData')[0], sample_count)
+       self.assertEqual(
+           met.metric_data('RunCachedGraphOutputData')[0], sample_count)

  @skipOnNeuron
  def test_resnet18_lazy_vs_dynamo(self):
@@ -428,7 +431,7 @@ def test_resnet18_lazy_vs_dynamo(self):
    # mess up the counter check.


- class DynamoCpuFallbackTest(unittest.TestCase):
+ class DynamoCpuFallbackTest(parameterized.TestCase):

  def test_operator_fallback(self):

@@ -509,7 +512,7 @@ def fn_fallback(t):
    self.assertEqual(met.metric_data('ExecuteTime')[0], 3)


- class DynamoTrainingBasicTest(unittest.TestCase):
+ class DynamoTrainingBasicTest(parameterized.TestCase):

  @classmethod
  def setUpClass(self):
@@ -613,7 +616,7 @@ def test_resnet18(self):
        met.metric_data('RunCachedGraphOutputData')[0], sample_count * 2)


- class DynamoTrainingOptimizerTest(unittest.TestCase):
+ class DynamoTrainingOptimizerTest(parameterized.TestCase):

  @classmethod
  def setUpClass(self):
@@ -719,7 +722,7 @@ def test_resnet18(self):
        met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)


- class DynamoErrorMessageTest(unittest.TestCase):
+ class DynamoErrorMessageTest(parameterized.TestCase):

  def test_mixed_cpu_tensor(self):
    device = xm.xla_device()
@@ -758,17 +761,18 @@ def test_all_cpu_tensor(self):
    self.assertLessEqual(len(met.counter_names()), 1)


- class DynamoOperationsTests(test_utils.XlaTestCase):
+ class DynamoOperationsTest(test_utils.XlaTestCase, parameterized.TestCase):

-   def test_new_with_sizes(self):
+   @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+   def test_new_with_sizes(self, backend):

    # The addition operation is needed here, since the error only occurs when FakeTensorMode
    # checks the device of the arguments of some operation. If there's no operation using the
    # result of Tensor.new, this comparison never occurs.
    def foo(x):
      return x.new(*x.size()) + x

-     optfoo = torch.compile(backend="openxla")(foo)
+     optfoo = torch.compile(backend=backend)(foo)

    t = torch.arange(9)
    Xt = t.to(xm.xla_device())
@@ -782,12 +786,13 @@ def foo(x):
    self.assertEqual(expected.dtype, actual.dtype)
    self.assertEqual(expected.device, actual.device)

-   def test_return_expand(self):
+   @parameterized.parameters(['openxla', dynamo_backend2.dynamo_backend])
+   def test_return_expand(self, backend):

    def foo(x):
      return x.expand(2, -1)

-     optfoo = torch.compile(backend="openxla")(foo)
+     optfoo = torch.compile(backend=backend)(foo)

    t = torch.arange(10)
    Xt = t.to(xm.xla_device())