Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
return node

def visit_Attribute(self, node: ast.Attribute):
# An enum MyEnum.A would come has
# > ast.Attribute("A")
# - value: ast.Name("MyEnum")
# We want to replace the entire thing - so we capture the top level
# attribute. We don't use the `self.context` because of thise
# two-step AST structure which doesn't fit the generic `replace_node`

if isinstance(node.value, ast.Name) and node.value.id in gtscript._ENUM_REGISTER.keys():
int_value = getattr(gtscript._ENUM_REGISTER[node.value.id], node.attr)
return ast.Constant(value=int_value)

# Common replace for all other node in context
return self._replace_node(node)

def visit_Name(self, node: ast.Name):
Expand Down Expand Up @@ -1899,6 +1911,8 @@ def annotate_definition(
and param.annotation in gtscript._VALID_DATA_TYPES
):
dtype_annotation = np.dtype(param.annotation)
elif param.annotation in gtscript._ENUM_REGISTER.values():
dtype_annotation = int # We will replace all enums with `int`
elif param.annotation is inspect.Signature.empty:
dtype_annotation = None
else:
Expand Down Expand Up @@ -2024,6 +2038,10 @@ def collect_external_symbols(definition):
local_symbols = CollectLocalSymbolsAstVisitor.apply(gtscript_ast)
nonlocal_symbols = {}

# De-pop from `context in Enum registered with `gtscript`
for enum_ in gtscript._ENUM_REGISTER.keys():
context.pop(enum_, "")

name_nodes = gt_meta.collect_names(gtscript_ast, skip_annotations=False)
for collected_name in name_nodes.keys():
if collected_name not in gtscript.builtins:
Expand Down
21 changes: 21 additions & 0 deletions src/gt4py/cartesian/gtscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inspect
import numbers
import types
from enum import IntEnum
from typing import Callable, Dict, Type, Union

import numpy as np
Expand Down Expand Up @@ -159,6 +160,26 @@ def _parse_annotation(arg, annotation):
return original_annotations


_ENUM_REGISTER: dict[str, object] = {}
"""Register of IntEnum that will be available to parsing in stencils. Register
with @gtscript.enum()"""


def enum(class_: type[IntEnum]):
class_name = class_.__name__
if class_name in _ENUM_REGISTER:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this works and makes for some centralization!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that ENUM_REGISTER is on gtscript, I will probably move it into it's own little space

raise ValueError(
f"Cannot register @gtscript.enum {class_name} as a class"
"with the same name is already registered."
)

if not issubclass(class_, IntEnum):
raise ValueError(f"Enum {class_name} needs to derive from `enum.IntEnum`.")

_ENUM_REGISTER[class_name] = class_
return class_


def function(func):
"""Mark a GTScript function."""
from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/cartesian/stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from gt4py.cartesian.gtc import utils as gtc_utils
from gt4py.cartesian.gtc.definitions import Index, Shape
from gt4py.storage.cartesian import utils as storage_utils
from gt4py.cartesian import gtscript


try:
Expand Down Expand Up @@ -558,8 +559,13 @@ def _call_run(
exec_info["call_run_start_time"] = time.perf_counter()
backend_cls = gt_backend.from_name(self.backend)
device = backend_cls.storage_info["device"]
array_infos = _extract_array_infos(field_args, device)

# Normalize `gtscript.enum` to integers
for name, value in parameter_args.items():
if type(value) in gtscript._ENUM_REGISTER.values():
parameter_args[name] = value.value

array_infos = _extract_array_infos(field_args, device)
cache_key = _compute_domain_origin_cache_key(array_infos, parameter_args, domain, origin)
if cache_key not in self._domain_origin_cache:
origin = self._normalize_origins(array_infos, self.field_info, origin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from enum import IntEnum
import numpy as np
import pytest

Expand All @@ -22,6 +23,7 @@
J,
K,
IJ,
IJK,
computation,
horizontal,
interval,
Expand Down Expand Up @@ -1296,3 +1298,38 @@ def test_lower_dim_field(
out_arr[:, :, :] = 0
test_lower_dim_field(k_arr, out_arr)
assert (out_arr[:, :, :] == 42.42).all()


@gtscript.enum
class MyEnum(IntEnum):
Zero = 0
A = 10
B = 20
C = 30


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_enum_runtime(backend):
@gtscript.stencil(backend=backend)
def the_stencil(out_field: Field[IJK, int], order: MyEnum):
with computation(PARALLEL), interval(0, 1):
out_field = 32
if order < MyEnum.A:
out_field = MyEnum.A

with computation(PARALLEL), interval(1, 2):
out_field = 23
out_field = MyEnum.B

with computation(PARALLEL), interval(2, None):
out_field = 56
out_field = MyEnum.C

domain = (5, 5, 5)
out_arr = gt_storage.zeros(backend=backend, shape=domain, dtype=int)

the_stencil(out_arr, MyEnum.Zero)

assert out_arr[0, 0, 0] == MyEnum.A.value
assert out_arr[0, 0, 1] == MyEnum.B.value
assert (out_arr[0, 0, 2:] == MyEnum.C.value).all()
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from enum import IntEnum
import inspect
import textwrap
import types
Expand Down Expand Up @@ -2029,3 +2030,33 @@ def test_assign_constant_numpy_typed(self):
constant: nodes.ScalarLiteral = def_ir.computations[0].body.stmts[0].value
assert isinstance(constant, nodes.ScalarLiteral)
assert constant.data_type == nodes.DataType.FLOAT32


@gtscript.enum
class LocalEnum(IntEnum):
A = 42
B = 1000


class TestEnum:
def setup_method(self):
def enum(field: gtscript.Field[float], order: LocalEnum): # type: ignore
with computation(PARALLEL), interval(0, 1):
if order > LocalEnum.A:
field[0, 0, 0] = LocalEnum.B

self.stencil = enum

def test_enum_in_stencil(self):
def_ir = parse_definition(
self.stencil,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)

assert isinstance(def_ir.computations[0].body.stmts[0].condition.rhs, nodes.ScalarLiteral)
assert def_ir.computations[0].body.stmts[0].condition.rhs.value == LocalEnum.A
assert isinstance(
def_ir.computations[0].body.stmts[0].main_body.stmts[0].value, nodes.ScalarLiteral
)
assert def_ir.computations[0].body.stmts[0].main_body.stmts[0].value.value == LocalEnum.B
Loading