File tree 3 files changed +25
-0
lines changed
3 files changed +25
-0
lines changed Original file line number Diff line number Diff line change @@ -2579,6 +2579,26 @@ def f():
2579
2579
2580
2580
self .assertTrue (torch .allclose (f (), torch .tensor ([2.0 ])))
2581
2581
2582
+ def test_user_function_variable_supports_enum_argument (self ):
2583
+ class Foo (enum .Enum ):
2584
+ FOO = 0
2585
+ BAR = 1
2586
+
2587
+ def gn (x , y = Foo .FOO ):
2588
+ if y is Foo .FOO :
2589
+ return x
2590
+ else :
2591
+ return x + 1
2592
+
2593
+ def fn (x ):
2594
+ return gn (x )
2595
+
2596
+ x = torch .randn (2 , 3 )
2597
+ ref = fn (x )
2598
+ opt_fn = torchdynamo .optimize ("eager" , nopython = True )(fn )
2599
+ res = opt_fn (x )
2600
+ self .assertTrue (torch .allclose (ref , res ))
2601
+
2582
2602
2583
2603
class CustomFunc (torch .autograd .Function ):
2584
2604
@staticmethod
Original file line number Diff line number Diff line change 1
1
from .base import VariableTracker
2
2
from .builtin import BuiltinVariable
3
3
from .constant import ConstantVariable
4
+ from .constant import EnumVariable
4
5
from .dicts import ConstDictVariable
5
6
from .dicts import DataClassVariable
6
7
from .dicts import DefaultDictVariable
50
51
"ContextWrappingVariable" ,
51
52
"DataClassVariable" ,
52
53
"DefaultDictVariable" ,
54
+ "EnumVariable" ,
53
55
"FakeItemVariable" ,
54
56
"GetAttrVariable" ,
55
57
"GradModeVariable" ,
Original file line number Diff line number Diff line change
1
+ import enum
1
2
import functools
2
3
import inspect
3
4
import itertools
@@ -27,6 +28,8 @@ def wrap_bound_arg(val, options):
27
28
return cls ([wrap_bound_arg (x , options ) for x in val ], ** options )
28
29
elif variables .ConstantVariable .is_literal (val ):
29
30
return variables .ConstantVariable (val , ** options )
31
+ elif isinstance (val , enum .Enum ):
32
+ return variables .EnumVariable (val , ** options )
30
33
else :
31
34
assert isinstance (val , VariableTracker ), typestr (val )
32
35
return val
You can’t perform that action at this time.
0 commit comments