From 287e55efc23271a781e17b830f1e9c8f19f509fd Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 20 Oct 2025 17:33:01 -0400 Subject: [PATCH 1/6] Allow registration of `IntEnum` with @gtscript.enum --- src/gt4py/cartesian/gtscript.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 56bf8874a8..fd011c9b51 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -16,6 +16,7 @@ import inspect import numbers import types +from enum import IntEnum from typing import Callable, Dict, Type, Union import numpy as np @@ -159,6 +160,26 @@ def _parse_annotation(arg, annotation): return original_annotations +_ENUM_REGISTER: dict[str, object] = {} +"""Register of IntEnum that will be available to parsing in stencils. Register +with @gtscript.enum()""" + + +def enum(class_: type[IntEnum]): + class_name = class_.__name__ + if class_name in _ENUM_REGISTER: + raise ValueError( + f"Cannot register @gtscript.enum {class_name} as a class" + "with the same name is already registered." + ) + + if not issubclass(class_, IntEnum): + raise ValueError(f"Enum {class_name} needs to derive from `enum.IntEnum`.") + + _ENUM_REGISTER[class_name] = class_ + return class_ + + def function(func): """Mark a GTScript function.""" from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend From 5d40789edff3c1e46ef2f839dbab9ad390c0e611 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 20 Oct 2025 17:33:55 -0400 Subject: [PATCH 2/6] IntEnum are treated throughout the system as Integer: - Value replaced in frontend - Normalize at `call` time of the stencil --- .../cartesian/frontend/gtscript_frontend.py | 18 ++++++++++++++++++ src/gt4py/cartesian/stencil_object.py | 8 +++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f4dd4dcb3b..fda8a4c342 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -309,6 +309,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom): return node def visit_Attribute(self, node: ast.Attribute): + # An enum MyEnum.A would come has + # > ast.Attribute("A") + # - value: ast.Name("MyEnum") + # We want to replace the entire thing - so we capture the top level + # attribute. We don't use the `self.context` because of thise + # two-step AST structure which doesn't fit the generic `replace_node` + + if isinstance(node.value, ast.Name) and node.value.id in gtscript._ENUM_REGISTER.keys(): + int_value = getattr(gtscript._ENUM_REGISTER[node.value.id], node.attr) + return ast.Constant(value=int_value) + + # Common replace for all other node in context return self._replace_node(node) def visit_Name(self, node: ast.Name): @@ -1899,6 +1911,8 @@ def annotate_definition( and param.annotation in gtscript._VALID_DATA_TYPES ): dtype_annotation = np.dtype(param.annotation) + elif param.annotation in gtscript._ENUM_REGISTER.values(): + dtype_annotation = int # We will replace all enums with `int` elif param.annotation is inspect.Signature.empty: dtype_annotation = None else: @@ -2024,6 +2038,10 @@ def collect_external_symbols(definition): local_symbols = CollectLocalSymbolsAstVisitor.apply(gtscript_ast) nonlocal_symbols = {} + # De-pop from `context in Enum registered with `gtscript` + for enum_ in gtscript._ENUM_REGISTER.keys(): + context.pop(enum_, "") + name_nodes = gt_meta.collect_names(gtscript_ast, skip_annotations=False) for collected_name in name_nodes.keys(): if collected_name not in gtscript.builtins: diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 3e988149bc..fb49349cd8 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -24,6 +24,7 @@ from gt4py.cartesian.gtc import utils as gtc_utils from gt4py.cartesian.gtc.definitions import Index, Shape from gt4py.storage.cartesian import utils as storage_utils +from gt4py.cartesian import gtscript try: @@ -558,8 +559,13 @@ def _call_run( exec_info["call_run_start_time"] = time.perf_counter() backend_cls = gt_backend.from_name(self.backend) device = backend_cls.storage_info["device"] - array_infos = _extract_array_infos(field_args, device) + # Normalize `gtscript.enum` to integers + for name, value in parameter_args.items(): + if type(value) in gtscript._ENUM_REGISTER.values(): + parameter_args[name] = value.value + + array_infos = _extract_array_infos(field_args, device) cache_key = _compute_domain_origin_cache_key(array_infos, parameter_args, domain, origin) if cache_key not in self._domain_origin_cache: origin = self._normalize_origins(array_infos, self.field_info, origin) From 6267f0fad750e31a4fc6d7bae0be11a24f97656d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 20 Oct 2025 17:33:59 -0400 Subject: [PATCH 3/6] Unit tests --- .../test_code_generation.py | 37 +++++++++++++++++++ .../frontend_tests/test_gtscript_frontend.py | 31 ++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 9ffcb3f16c..94c2d06ab3 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from enum import IntEnum import numpy as np import pytest @@ -22,6 +23,7 @@ J, K, IJ, + IJK, computation, horizontal, interval, @@ -1296,3 +1298,38 @@ def test_lower_dim_field( out_arr[:, :, :] = 0 test_lower_dim_field(k_arr, out_arr) assert (out_arr[:, :, :] == 42.42).all() + + +@gtscript.enum +class MyEnum(IntEnum): + Zero = 0 + A = 10 + B = 20 + C = 30 + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_enum_runtime(backend): + @gtscript.stencil(backend=backend) + def the_stencil(out_field: Field[IJK, int], order: MyEnum): + with computation(PARALLEL), interval(0, 1): + out_field = 32 + if order < MyEnum.A: + out_field = MyEnum.A + + with computation(PARALLEL), interval(1, 2): + out_field = 23 + out_field = MyEnum.B + + with computation(PARALLEL), interval(2, None): + out_field = 56 + out_field = MyEnum.C + + domain = (5, 5, 5) + out_arr = gt_storage.zeros(backend=backend, shape=domain, dtype=int) + + the_stencil(out_arr, MyEnum.Zero) + + assert out_arr[0, 0, 0] == MyEnum.A.value + assert out_arr[0, 0, 1] == MyEnum.B.value + assert (out_arr[0, 0, 2:] == MyEnum.C.value).all() diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 00338ec790..2adca00867 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from enum import IntEnum import inspect import textwrap import types @@ -2029,3 +2030,33 @@ def test_assign_constant_numpy_typed(self): constant: nodes.ScalarLiteral = def_ir.computations[0].body.stmts[0].value assert isinstance(constant, nodes.ScalarLiteral) assert constant.data_type == nodes.DataType.FLOAT32 + + +@gtscript.enum +class LocalEnum(IntEnum): + A = 42 + B = 1000 + + +class TestEnum: + def setup_method(self): + def enum(field: gtscript.Field[float], order: LocalEnum): # type: ignore + with computation(PARALLEL), interval(0, 1): + if order > LocalEnum.A: + field[0, 0, 0] = LocalEnum.B + + self.stencil = enum + + def test_enum_in_stencil(self): + def_ir = parse_definition( + self.stencil, + name=inspect.stack()[0][3], + module=self.__class__.__name__, + ) + + assert isinstance(def_ir.computations[0].body.stmts[0].condition.rhs, nodes.ScalarLiteral) + assert def_ir.computations[0].body.stmts[0].condition.rhs.value == LocalEnum.A + assert isinstance( + def_ir.computations[0].body.stmts[0].main_body.stmts[0].value, nodes.ScalarLiteral + ) + assert def_ir.computations[0].body.stmts[0].main_body.stmts[0].value.value == LocalEnum.B From 9c24bc243e4b6af6300777b072443461cba56dce Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 09:10:33 -0400 Subject: [PATCH 4/6] Lint & better inline docs and error --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 8 ++++---- src/gt4py/cartesian/gtscript.py | 4 ++-- src/gt4py/cartesian/stencil_object.py | 3 +-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index fda8a4c342..fe690135a8 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -313,14 +313,14 @@ def visit_Attribute(self, node: ast.Attribute): # > ast.Attribute("A") # - value: ast.Name("MyEnum") # We want to replace the entire thing - so we capture the top level - # attribute. We don't use the `self.context` because of thise - # two-step AST structure which doesn't fit the generic `replace_node` + # attribute. We don't use the `self.context` because of this + # two-step AST structure which doesn't fit the generic `replace_node`. if isinstance(node.value, ast.Name) and node.value.id in gtscript._ENUM_REGISTER.keys(): int_value = getattr(gtscript._ENUM_REGISTER[node.value.id], node.attr) return ast.Constant(value=int_value) - # Common replace for all other node in context + # Common replace for all other nodes in context. return self._replace_node(node) def visit_Name(self, node: ast.Name): @@ -2038,7 +2038,7 @@ def collect_external_symbols(definition): local_symbols = CollectLocalSymbolsAstVisitor.apply(gtscript_ast) nonlocal_symbols = {} - # De-pop from `context in Enum registered with `gtscript` + # Remove enums from `context`, they will be turned into integers in the ValueReplacer for enum_ in gtscript._ENUM_REGISTER.keys(): context.pop(enum_, "") diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index fd011c9b51..155dc67f35 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -166,11 +166,11 @@ def _parse_annotation(arg, annotation): def enum(class_: type[IntEnum]): + """Mark an IntEnum derived class as readable for GT4Py.""" class_name = class_.__name__ if class_name in _ENUM_REGISTER: raise ValueError( - f"Cannot register @gtscript.enum {class_name} as a class" - "with the same name is already registered." + f"Enum names must be unique. @gtscript.enum {class_name} is already taken." ) if not issubclass(class_, IntEnum): diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index fb49349cd8..47a561f736 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -19,12 +19,11 @@ import numpy as np -from gt4py.cartesian import backend as gt_backend +from gt4py.cartesian import backend as gt_backend, gtscript from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo, ParameterInfo from gt4py.cartesian.gtc import utils as gtc_utils from gt4py.cartesian.gtc.definitions import Index, Shape from gt4py.storage.cartesian import utils as storage_utils -from gt4py.cartesian import gtscript try: From 44009be05663b49b30bdb235a4a55f97ff19ed73 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 09:39:51 -0400 Subject: [PATCH 5/6] Mixed precision --- src/gt4py/cartesian/definitions.py | 15 +++++++++++++++ src/gt4py/cartesian/frontend/gtscript_frontend.py | 4 +++- src/gt4py/cartesian/stencil_object.py | 10 ++++++++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/gt4py/cartesian/definitions.py b/src/gt4py/cartesian/definitions.py index 4781faca22..62b87488e3 100644 --- a/src/gt4py/cartesian/definitions.py +++ b/src/gt4py/cartesian/definitions.py @@ -42,6 +42,21 @@ """Default literal precision used for unspecific `float` types and casts.""" +def get_integer_default_type(): + """Return the integer numpy type corresponding to the LITERAL_INT_PRECISION set.""" + # I'd love to return `numpy.signedinteger[LITERAL_INT_PRECISION]` but that won't work + if LITERAL_INT_PRECISION == 8: + return numpy.int8 + if LITERAL_INT_PRECISION == 32: + return numpy.int32 + if LITERAL_INT_PRECISION == 64: + return numpy.int64 + if LITERAL_INT_PRECISION == 128: + return numpy.int128 + + raise NotImplementedError("Unknown integer precision type") + + @enum.unique class AccessKind(enum.IntFlag): NONE = 0 diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index fe690135a8..efc61f0b8f 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1912,7 +1912,9 @@ def annotate_definition( ): dtype_annotation = np.dtype(param.annotation) elif param.annotation in gtscript._ENUM_REGISTER.values(): - dtype_annotation = int # We will replace all enums with `int` + dtype_annotation = ( + gt_definitions.get_integer_default_type() + ) # We will replace all enums with `int` elif param.annotation is inspect.Signature.empty: dtype_annotation = None else: diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 47a561f736..7e3c5e75c9 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -20,7 +20,13 @@ import numpy as np from gt4py.cartesian import backend as gt_backend, gtscript -from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo, ParameterInfo +from gt4py.cartesian.definitions import ( + AccessKind, + DomainInfo, + FieldInfo, + ParameterInfo, + get_integer_default_type, +) from gt4py.cartesian.gtc import utils as gtc_utils from gt4py.cartesian.gtc.definitions import Index, Shape from gt4py.storage.cartesian import utils as storage_utils @@ -562,7 +568,7 @@ def _call_run( # Normalize `gtscript.enum` to integers for name, value in parameter_args.items(): if type(value) in gtscript._ENUM_REGISTER.values(): - parameter_args[name] = value.value + parameter_args[name] = get_integer_default_type()(value.value) array_infos = _extract_array_infos(field_args, device) cache_key = _compute_domain_origin_cache_key(array_infos, parameter_args, domain, origin) From 71c27c25ad7883ba87a7342159ca7cd4e8b11c04 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 21 Oct 2025 11:46:51 -0400 Subject: [PATCH 6/6] Move the ENUM_REGISTER to a frontend registry (slightly better) --- .../cartesian/frontend/gtscript_frontend.py | 27 ++++++++++++++++--- src/gt4py/cartesian/gtscript.py | 16 ++--------- src/gt4py/cartesian/stencil_object.py | 5 ++-- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index efc61f0b8f..62c71b9747 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -316,8 +316,8 @@ def visit_Attribute(self, node: ast.Attribute): # attribute. We don't use the `self.context` because of this # two-step AST structure which doesn't fit the generic `replace_node`. - if isinstance(node.value, ast.Name) and node.value.id in gtscript._ENUM_REGISTER.keys(): - int_value = getattr(gtscript._ENUM_REGISTER[node.value.id], node.attr) + if isinstance(node.value, ast.Name) and node.value.id in _ENUM_REGISTER.keys(): + int_value = getattr(_ENUM_REGISTER[node.value.id], node.attr) return ast.Constant(value=int_value) # Common replace for all other nodes in context. @@ -1820,6 +1820,11 @@ def visit_Assign(self, node: ast.Assign): raise invalid_target +_ENUM_REGISTER: dict[str, object] = {} +"""Register of IntEnum that will be available to parsing in stencils. Register +with @gtscript.enum()""" + + class GTScriptParser(ast.NodeVisitor): CONST_VALUE_TYPES = ( *gtscript._VALID_DATA_TYPES, @@ -1911,7 +1916,7 @@ def annotate_definition( and param.annotation in gtscript._VALID_DATA_TYPES ): dtype_annotation = np.dtype(param.annotation) - elif param.annotation in gtscript._ENUM_REGISTER.values(): + elif param.annotation in _ENUM_REGISTER.values(): dtype_annotation = ( gt_definitions.get_integer_default_type() ) # We will replace all enums with `int` @@ -2041,7 +2046,7 @@ def collect_external_symbols(definition): nonlocal_symbols = {} # Remove enums from `context`, they will be turned into integers in the ValueReplacer - for enum_ in gtscript._ENUM_REGISTER.keys(): + for enum_ in _ENUM_REGISTER.keys(): context.pop(enum_, "") name_nodes = gt_meta.collect_names(gtscript_ast, skip_annotations=False) @@ -2167,6 +2172,20 @@ def resolve_external_symbols( return result + @staticmethod + def register_enum(class_: type[enum.IntEnum]): + class_name = class_.__name__ + if class_name in _ENUM_REGISTER: + raise ValueError( + f"Enum names must be unique. @gtscript.enum {class_name} is already taken." + ) + + if not issubclass(class_, enum.IntEnum): + raise ValueError(f"Enum {class_name} needs to derive from `enum.IntEnum`.") + + _ENUM_REGISTER[class_name] = class_ + return class_ + def extract_arg_descriptors(self): api_signature = self.definition._gtscript_["api_signature"] api_annotations = self.definition._gtscript_["api_annotations"] diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 155dc67f35..6debce00de 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -160,23 +160,11 @@ def _parse_annotation(arg, annotation): return original_annotations -_ENUM_REGISTER: dict[str, object] = {} -"""Register of IntEnum that will be available to parsing in stencils. Register -with @gtscript.enum()""" - - def enum(class_: type[IntEnum]): """Mark an IntEnum derived class as readable for GT4Py.""" - class_name = class_.__name__ - if class_name in _ENUM_REGISTER: - raise ValueError( - f"Enum names must be unique. @gtscript.enum {class_name} is already taken." - ) - - if not issubclass(class_, IntEnum): - raise ValueError(f"Enum {class_name} needs to derive from `enum.IntEnum`.") + from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend - _ENUM_REGISTER[class_name] = class_ + gt_frontend.GTScriptParser.register_enum(class_) return class_ diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 7e3c5e75c9..97c46d2ba9 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -19,7 +19,7 @@ import numpy as np -from gt4py.cartesian import backend as gt_backend, gtscript +from gt4py.cartesian import backend as gt_backend from gt4py.cartesian.definitions import ( AccessKind, DomainInfo, @@ -27,6 +27,7 @@ ParameterInfo, get_integer_default_type, ) +from gt4py.cartesian.frontend import gtscript_frontend from gt4py.cartesian.gtc import utils as gtc_utils from gt4py.cartesian.gtc.definitions import Index, Shape from gt4py.storage.cartesian import utils as storage_utils @@ -567,7 +568,7 @@ def _call_run( # Normalize `gtscript.enum` to integers for name, value in parameter_args.items(): - if type(value) in gtscript._ENUM_REGISTER.values(): + if type(value) in gtscript_frontend._ENUM_REGISTER.values(): parameter_args[name] = get_integer_default_type()(value.value) array_infos = _extract_array_infos(field_args, device)