Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/gt4py/cartesian/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@
"""Default literal precision used for unspecific `float` types and casts."""


def get_integer_default_type():
"""Return the integer numpy type corresponding to the LITERAL_INT_PRECISION set."""
Comment on lines +45 to +46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That won't work with the per-stencil overrides of literal_integer_precision that we allow. In the GTScriptParser, you have access to the BuildOptions in self.options which contains literal_integer_precision. Likewise, in annotate_definition() of the StencilObject you get the BuildOptions as function argument, so you can get the precision from there.

# I'd love to return `numpy.signedinteger[LITERAL_INT_PRECISION]` but that won't work
if LITERAL_INT_PRECISION == 8:
return numpy.int8
if LITERAL_INT_PRECISION == 32:
return numpy.int32
if LITERAL_INT_PRECISION == 64:
return numpy.int64
if LITERAL_INT_PRECISION == 128:
return numpy.int128

raise NotImplementedError("Unknown integer precision type")


@enum.unique
class AccessKind(enum.IntFlag):
NONE = 0
Expand Down
39 changes: 39 additions & 0 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
return node

def visit_Attribute(self, node: ast.Attribute):
# An enum MyEnum.A would come has
# > ast.Attribute("A")
# - value: ast.Name("MyEnum")
# We want to replace the entire thing - so we capture the top level
# attribute. We don't use the `self.context` because of this
# two-step AST structure which doesn't fit the generic `replace_node`.

if isinstance(node.value, ast.Name) and node.value.id in _ENUM_REGISTER.keys():
int_value = getattr(_ENUM_REGISTER[node.value.id], node.attr)
return ast.Constant(value=int_value)

# Common replace for all other nodes in context.
return self._replace_node(node)

def visit_Name(self, node: ast.Name):
Expand Down Expand Up @@ -1808,6 +1820,11 @@ def visit_Assign(self, node: ast.Assign):
raise invalid_target


_ENUM_REGISTER: dict[str, object] = {}
"""Register of IntEnum that will be available to parsing in stencils. Register
with @gtscript.enum()"""


class GTScriptParser(ast.NodeVisitor):
CONST_VALUE_TYPES = (
*gtscript._VALID_DATA_TYPES,
Expand Down Expand Up @@ -1899,6 +1916,10 @@ def annotate_definition(
and param.annotation in gtscript._VALID_DATA_TYPES
):
dtype_annotation = np.dtype(param.annotation)
elif param.annotation in _ENUM_REGISTER.values():
dtype_annotation = (
gt_definitions.get_integer_default_type()
) # We will replace all enums with `int`
elif param.annotation is inspect.Signature.empty:
dtype_annotation = None
else:
Expand Down Expand Up @@ -2024,6 +2045,10 @@ def collect_external_symbols(definition):
local_symbols = CollectLocalSymbolsAstVisitor.apply(gtscript_ast)
nonlocal_symbols = {}

# Remove enums from `context`, they will be turned into integers in the ValueReplacer
for enum_ in _ENUM_REGISTER.keys():
context.pop(enum_, "")

name_nodes = gt_meta.collect_names(gtscript_ast, skip_annotations=False)
for collected_name in name_nodes.keys():
if collected_name not in gtscript.builtins:
Expand Down Expand Up @@ -2147,6 +2172,20 @@ def resolve_external_symbols(

return result

@staticmethod
def register_enum(class_: type[enum.IntEnum]):
class_name = class_.__name__
if class_name in _ENUM_REGISTER:
raise ValueError(
f"Enum names must be unique. @gtscript.enum {class_name} is already taken."
)

if not issubclass(class_, enum.IntEnum):
raise ValueError(f"Enum {class_name} needs to derive from `enum.IntEnum`.")

_ENUM_REGISTER[class_name] = class_
return class_

def extract_arg_descriptors(self):
api_signature = self.definition._gtscript_["api_signature"]
api_annotations = self.definition._gtscript_["api_annotations"]
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/cartesian/gtscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inspect
import numbers
import types
from enum import IntEnum
from typing import Callable, Dict, Type, Union

import numpy as np
Expand Down Expand Up @@ -159,6 +160,14 @@ def _parse_annotation(arg, annotation):
return original_annotations


def enum(class_: type[IntEnum]):
"""Mark an IntEnum derived class as readable for GT4Py."""
from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend

gt_frontend.GTScriptParser.register_enum(class_)
return class_


def function(func):
"""Mark a GTScript function."""
from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend
Expand Down
16 changes: 14 additions & 2 deletions src/gt4py/cartesian/stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
import numpy as np

from gt4py.cartesian import backend as gt_backend
from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo, ParameterInfo
from gt4py.cartesian.definitions import (
AccessKind,
DomainInfo,
FieldInfo,
ParameterInfo,
get_integer_default_type,
)
from gt4py.cartesian.frontend import gtscript_frontend
from gt4py.cartesian.gtc import utils as gtc_utils
from gt4py.cartesian.gtc.definitions import Index, Shape
from gt4py.storage.cartesian import utils as storage_utils
Expand Down Expand Up @@ -558,8 +565,13 @@ def _call_run(
exec_info["call_run_start_time"] = time.perf_counter()
backend_cls = gt_backend.from_name(self.backend)
device = backend_cls.storage_info["device"]
array_infos = _extract_array_infos(field_args, device)

# Normalize `gtscript.enum` to integers
for name, value in parameter_args.items():
if type(value) in gtscript_frontend._ENUM_REGISTER.values():
parameter_args[name] = get_integer_default_type()(value.value)

array_infos = _extract_array_infos(field_args, device)
cache_key = _compute_domain_origin_cache_key(array_infos, parameter_args, domain, origin)
if cache_key not in self._domain_origin_cache:
origin = self._normalize_origins(array_infos, self.field_info, origin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from enum import IntEnum
import numpy as np
import pytest

Expand All @@ -22,6 +23,7 @@
J,
K,
IJ,
IJK,
computation,
horizontal,
interval,
Expand Down Expand Up @@ -1296,3 +1298,38 @@ def test_lower_dim_field(
out_arr[:, :, :] = 0
test_lower_dim_field(k_arr, out_arr)
assert (out_arr[:, :, :] == 42.42).all()


@gtscript.enum
class MyEnum(IntEnum):
Zero = 0
A = 10
B = 20
C = 30


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_enum_runtime(backend):
@gtscript.stencil(backend=backend)
def the_stencil(out_field: Field[IJK, int], order: MyEnum):
with computation(PARALLEL), interval(0, 1):
out_field = 32
if order < MyEnum.A:
out_field = MyEnum.A

with computation(PARALLEL), interval(1, 2):
out_field = 23
out_field = MyEnum.B

with computation(PARALLEL), interval(2, None):
out_field = 56
out_field = MyEnum.C

domain = (5, 5, 5)
out_arr = gt_storage.zeros(backend=backend, shape=domain, dtype=int)

the_stencil(out_arr, MyEnum.Zero)

assert out_arr[0, 0, 0] == MyEnum.A.value
assert out_arr[0, 0, 1] == MyEnum.B.value
assert (out_arr[0, 0, 2:] == MyEnum.C.value).all()
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from enum import IntEnum
import inspect
import textwrap
import types
Expand Down Expand Up @@ -2029,3 +2030,33 @@ def test_assign_constant_numpy_typed(self):
constant: nodes.ScalarLiteral = def_ir.computations[0].body.stmts[0].value
assert isinstance(constant, nodes.ScalarLiteral)
assert constant.data_type == nodes.DataType.FLOAT32


@gtscript.enum
class LocalEnum(IntEnum):
A = 42
B = 1000


class TestEnum:
def setup_method(self):
def enum(field: gtscript.Field[float], order: LocalEnum): # type: ignore
with computation(PARALLEL), interval(0, 1):
if order > LocalEnum.A:
field[0, 0, 0] = LocalEnum.B

self.stencil = enum

def test_enum_in_stencil(self):
def_ir = parse_definition(
self.stencil,
name=inspect.stack()[0][3],
module=self.__class__.__name__,
)

assert isinstance(def_ir.computations[0].body.stmts[0].condition.rhs, nodes.ScalarLiteral)
assert def_ir.computations[0].body.stmts[0].condition.rhs.value == LocalEnum.A
assert isinstance(
def_ir.computations[0].body.stmts[0].main_body.stmts[0].value, nodes.ScalarLiteral
)
assert def_ir.computations[0].body.stmts[0].main_body.stmts[0].value.value == LocalEnum.B
Loading