Skip to content

Commit a83c21a

Browse files
authored
User function variable supports Enum argument (#1491)
1 parent 9b0867e commit a83c21a

File tree

3 files changed

+25
-0
lines changed

3 files changed

+25
-0
lines changed

test/test_misc.py

+20
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,26 @@ def f():
25792579

25802580
self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
25812581

2582+
def test_user_function_variable_supports_enum_argument(self):
2583+
class Foo(enum.Enum):
2584+
FOO = 0
2585+
BAR = 1
2586+
2587+
def gn(x, y=Foo.FOO):
2588+
if y is Foo.FOO:
2589+
return x
2590+
else:
2591+
return x + 1
2592+
2593+
def fn(x):
2594+
return gn(x)
2595+
2596+
x = torch.randn(2, 3)
2597+
ref = fn(x)
2598+
opt_fn = torchdynamo.optimize("eager", nopython=True)(fn)
2599+
res = opt_fn(x)
2600+
self.assertTrue(torch.allclose(ref, res))
2601+
25822602

25832603
class CustomFunc(torch.autograd.Function):
25842604
@staticmethod

torchdynamo/variables/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .base import VariableTracker
22
from .builtin import BuiltinVariable
33
from .constant import ConstantVariable
4+
from .constant import EnumVariable
45
from .dicts import ConstDictVariable
56
from .dicts import DataClassVariable
67
from .dicts import DefaultDictVariable
@@ -50,6 +51,7 @@
5051
"ContextWrappingVariable",
5152
"DataClassVariable",
5253
"DefaultDictVariable",
54+
"EnumVariable",
5355
"FakeItemVariable",
5456
"GetAttrVariable",
5557
"GradModeVariable",

torchdynamo/variables/functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import functools
23
import inspect
34
import itertools
@@ -27,6 +28,8 @@ def wrap_bound_arg(val, options):
2728
return cls([wrap_bound_arg(x, options) for x in val], **options)
2829
elif variables.ConstantVariable.is_literal(val):
2930
return variables.ConstantVariable(val, **options)
31+
elif isinstance(val, enum.Enum):
32+
return variables.EnumVariable(val, **options)
3033
else:
3134
assert isinstance(val, VariableTracker), typestr(val)
3235
return val

0 commit comments

Comments
 (0)