From 5e4601c09b23ee3fb74bf0f58d49005f5b7217a0 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Jul 2025 18:55:59 +0200 Subject: [PATCH 01/10] Add pass which checks if a field that is written to is also read with an offset --- .../iterator/transforms/check_inout_field.py | 80 +++++++ .../next/iterator/transforms/pass_manager.py | 2 + .../test_check_inout_field.py | 224 ++++++++++++++++++ 3 files changed, 306 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/check_inout_field.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py new file mode 100644 index 0000000000..026d1d176d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -0,0 +1,80 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import collapse_tuple, trace_shifts + +@dataclasses.dataclass(frozen=True) +class CheckInOutField(PreserveLocationVisitor, NodeTranslator): + """ + Checks within a SetAt if any fields which are written to are also read with an offset and raises a ValueError in this case. + + Example: + >>> from gt4py.next.iterator.transforms import infer_domain + >>> from gt4py.next.type_system import type_specifications as ts + >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + >>> i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) + >>> offset_provider={"IOff": IDim} + >>> cartesian_domain = im.call("cartesian_domain")(im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5)) + >>> ir = itir.Program( + ... id="test", + ... function_definitions=[], + ... params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + ... declarations=[], + ... body=[ + ... itir.SetAt( + ... expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout")), + ... domain=cartesian_domain, + ... target=im.ref("inout"), + ... ), + ... ], + ... ) + >>> CheckInOutField.apply(ir, offset_provider=offset_provider) + Traceback (most recent call last): + ... + ValueError: The target inout is also read with an offset. + """ + + @classmethod + def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider | common.OffsetProviderType): + return cls().visit(program, offset_provider=offset_provider) + + def visit_SetAt(self, node: itir.SetAt, **kwargs) -> itir.SetAt: + offset_provider = kwargs["offset_provider"] + + def as_fieldop_subexprs(expr): + """Return a list of all subexpressions in expr.args, including expr itself.""" + subexprs = [expr] + if cpm.is_applied_as_fieldop(expr): + for arg in expr.args: + subexprs.extend(as_fieldop_subexprs(arg)) + return subexprs + + def check_expr(fun, args, offset_provider): + shifts = trace_shifts.trace_stencil(fun, num_args=len(args)) + for arg, shift in zip(args, shifts): + for subexpr in as_fieldop_subexprs(arg): + if subexpr == node.target: + if shift not in (set(), {()}): + # This condition is just to filter out the trivial offsets in the horizontal and vertical. + if any(offset_provider[off.value].kind not in {common.DimensionKind.HORIZONTAL, common.DimensionKind.VERTICAL} or val.value != 0 for off, val in shift): + raise ValueError(f"The target {node.target} is also read with an offset.") + if cpm.is_applied_as_fieldop(arg): + check_expr(arg.fun, arg.args, offset_provider) + + if cpm.is_applied_as_fieldop(node.expr): + check_expr(node.expr.fun, node.expr.args, offset_provider) + + return node + diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 08538788b6..f4eddb6e71 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -12,6 +12,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( + check_inout_field, concat_where, dead_code_elimination, fuse_as_fieldop, @@ -83,6 +84,7 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet + ir = check_inout_field.CheckInOutField.apply(ir, offset_provider=offset_provider) ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py new file mode 100644 index 0000000000..d6aa20c377 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py @@ -0,0 +1,224 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Optional + +import pytest +from next_tests.toy_connectivity import e2v_conn + +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.check_inout_field import CheckInOutField +from gt4py.next.type_system import type_specifications as ts + +float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) +offset_provider = {"IOff": IDim} +i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) +cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5), itir.AxisLiteral(value="JDim"), 0, 7) + + +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=params, + declarations=declarations or [], + body=body, + ) + + +def test_check_inout_no_offset(): + ir = program_factory( + params=[im.sym("inout", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"))(im.ref("inout")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + # Should not raise + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_zero_offset(): + ir = program_factory( + params=[im.sym("inout", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))(im.ref("inout")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + # Should not raise + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_e2v_zero_offset(): + offset_provider = {"E2V": e2v_conn} # override + ir = program_factory( + params=[im.sym("inout", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 0)("x"))))(im.ref("inout")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_offset(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_shift_different_field(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")))))(im.ref("inout"), im.ref("in")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.as_fieldop(im.ref("deref"))(im.ref("inout"))), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_two_fields(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_shift_different_field(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")))))(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_shifted(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout")), im.ref("in")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_nested_shifted(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout"))), im.ref("in")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_in_arg_nested_shift_different_arg(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("in"))), im.ref("inout")), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) From fedbcca52768517d5204d913b2393807c1378a13 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 25 Jul 2025 19:32:55 +0200 Subject: [PATCH 02/10] Account for tuples --- .../iterator/transforms/check_inout_field.py | 55 +++++-- .../test_check_inout_field.py | 136 ++++++++++++++++-- 2 files changed, 165 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index 026d1d176d..9ce21310f6 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -11,8 +11,9 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import collapse_tuple, trace_shifts +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import trace_shifts + @dataclasses.dataclass(frozen=True) class CheckInOutField(PreserveLocationVisitor, NodeTranslator): @@ -25,8 +26,10 @@ class CheckInOutField(PreserveLocationVisitor, NodeTranslator): >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) >>> i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) - >>> offset_provider={"IOff": IDim} - >>> cartesian_domain = im.call("cartesian_domain")(im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5)) + >>> offset_provider = {"IOff": IDim} + >>> cartesian_domain = im.call("cartesian_domain")( + ... im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5) + ... ) >>> ir = itir.Program( ... id="test", ... function_definitions=[], @@ -34,7 +37,9 @@ class CheckInOutField(PreserveLocationVisitor, NodeTranslator): ... declarations=[], ... body=[ ... itir.SetAt( - ... expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout")), + ... expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + ... im.ref("inout") + ... ), ... domain=cartesian_domain, ... target=im.ref("inout"), ... ), @@ -44,37 +49,57 @@ class CheckInOutField(PreserveLocationVisitor, NodeTranslator): Traceback (most recent call last): ... ValueError: The target inout is also read with an offset. - """ + """ @classmethod - def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider | common.OffsetProviderType): + def apply( + cls, + program: itir.Program, + offset_provider: common.OffsetProvider | common.OffsetProviderType, + ): return cls().visit(program, offset_provider=offset_provider) def visit_SetAt(self, node: itir.SetAt, **kwargs) -> itir.SetAt: offset_provider = kwargs["offset_provider"] - def as_fieldop_subexprs(expr): + def extract_subexprs(expr): """Return a list of all subexpressions in expr.args, including expr itself.""" subexprs = [expr] - if cpm.is_applied_as_fieldop(expr): + if hasattr(expr, "args"): for arg in expr.args: - subexprs.extend(as_fieldop_subexprs(arg)) + subexprs.extend(extract_subexprs(arg)) return subexprs def check_expr(fun, args, offset_provider): shifts = trace_shifts.trace_stencil(fun, num_args=len(args)) for arg, shift in zip(args, shifts): - for subexpr in as_fieldop_subexprs(arg): - if subexpr == node.target: + arg_subexprs = extract_subexprs(arg) + target_subexprs = extract_subexprs(node.target) + for subexpr in arg_subexprs: + if subexpr in target_subexprs: # Account for im.make_tuple if shift not in (set(), {()}): # This condition is just to filter out the trivial offsets in the horizontal and vertical. - if any(offset_provider[off.value].kind not in {common.DimensionKind.HORIZONTAL, common.DimensionKind.VERTICAL} or val.value != 0 for off, val in shift): - raise ValueError(f"The target {node.target} is also read with an offset.") + if any( + offset_provider[off.value].kind + not in { + common.DimensionKind.HORIZONTAL, + common.DimensionKind.VERTICAL, + } + or val.value != 0 + for off, val in shift + ): + raise ValueError( + f"The target {node.target} is also read with an offset." + ) if cpm.is_applied_as_fieldop(arg): check_expr(arg.fun, arg.args, offset_provider) if cpm.is_applied_as_fieldop(node.expr): check_expr(node.expr.fun, node.expr.args, offset_provider) + else: # Account for im.make_tuple + if hasattr(node.expr, "args"): + for expr in node.expr.args: + if cpm.is_applied_as_fieldop(expr): + check_expr(expr.fun, expr.args, offset_provider) return node - diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py index d6aa20c377..8a0f48900f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py @@ -22,7 +22,11 @@ offset_provider = {"IOff": IDim} i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5), itir.AxisLiteral(value="JDim"), 0, 7) + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5), + itir.AxisLiteral(value="JDim"), + 0, + 7, +) def program_factory( @@ -62,7 +66,9 @@ def test_check_inout_zero_offset(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))(im.ref("inout")), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( + im.ref("inout") + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -80,7 +86,9 @@ def test_check_inout_e2v_zero_offset(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 0)("x"))))(im.ref("inout")), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 0)("x"))))( + im.ref("inout") + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -97,7 +105,9 @@ def test_check_inout_offset(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout")), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -114,7 +124,13 @@ def test_check_inout_shift_different_field(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")))))(im.ref("inout"), im.ref("in")), + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")) + ) + ) + )(im.ref("inout"), im.ref("in")), domain=cartesian_domain, target=im.ref("inout"), ), @@ -130,7 +146,9 @@ def test_check_inout_in_arg(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.as_fieldop(im.ref("deref"))(im.ref("inout"))), + expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.as_fieldop(im.ref("deref"))(im.ref("inout")) + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -147,7 +165,13 @@ def test_check_inout_in_arg_two_fields(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("y")) + ) + ) + )(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), domain=cartesian_domain, target=im.ref("inout"), ), @@ -164,7 +188,13 @@ def test_check_inout_in_arg_shift_different_field(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")))))(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")) + ) + ) + )(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), domain=cartesian_domain, target=im.ref("inout"), ), @@ -180,7 +210,18 @@ def test_check_inout_in_arg_shifted(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout")), im.ref("in")), + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")) + ) + ) + )( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), + im.ref("in"), + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -197,7 +238,20 @@ def test_check_inout_in_arg_nested_shifted(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("inout"))), im.ref("in")), + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")) + ) + ) + )( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ) + ), + im.ref("in"), + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -214,7 +268,20 @@ def test_check_inout_in_arg_nested_shift_different_arg(): declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))(im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(im.ref("in"))), im.ref("inout")), + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.plus( + im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")) + ) + ) + )( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("in") + ) + ), + im.ref("inout"), + ), domain=cartesian_domain, target=im.ref("inout"), ), @@ -222,3 +289,50 @@ def test_check_inout_in_arg_nested_shift_different_arg(): ) assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple(): + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + declarations=[], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + domain=cartesian_domain, + target=im.make_tuple(im.ref("inout"), im.ref("in")), + ), + ], + ) + + with pytest.raises(ValueError, match="The target {inout, in} is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_get(): + ir = program_factory( + params=[ + im.sym("inout", ts.TupleType(types=[i_field_type] * 2)), + im.sym("in", i_field_type), + ], + declarations=[], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.tuple_get(0, im.ref("inout")) + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + domain=cartesian_domain, + target=im.make_tuple(im.ref("inout")), + ), + ], + ) + + with pytest.raises(ValueError, match="The target {inout} is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) From e1dadff8b4dd9e6661dc04c7d3c1fc4622243a21 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 6 Aug 2025 11:23:45 +0200 Subject: [PATCH 03/10] Fix Doctest --- src/gt4py/next/iterator/transforms/check_inout_field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index 9ce21310f6..21c1800864 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -23,6 +23,7 @@ class CheckInOutField(PreserveLocationVisitor, NodeTranslator): Example: >>> from gt4py.next.iterator.transforms import infer_domain >>> from gt4py.next.type_system import type_specifications as ts + >>> from gt4py.next.iterator.ir_utils import ir_makers as im >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) >>> i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) From 865a2a017a3ea4d180153542fd8518c3df566ed2 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 7 Aug 2025 17:20:54 +0200 Subject: [PATCH 04/10] Refactor tests and extend tuple testcases --- .../iterator/transforms/check_inout_field.py | 20 ++- .../test_check_inout_field.py | 150 +++++++++++++++--- 2 files changed, 137 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index 21c1800864..c9e1fe66dd 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -66,18 +66,26 @@ def visit_SetAt(self, node: itir.SetAt, **kwargs) -> itir.SetAt: def extract_subexprs(expr): """Return a list of all subexpressions in expr.args, including expr itself.""" subexprs = [expr] - if hasattr(expr, "args"): + if isinstance(expr, itir.FunCall): for arg in expr.args: subexprs.extend(extract_subexprs(arg)) return subexprs + def visit_nested_make_tuple_tuple_get(expr): + """Recursively visit make_tuple and tuple_get expr and check all as_fieldop subexpressions.""" + if cpm.is_applied_as_fieldop(expr): + check_expr(expr.fun, expr.args, offset_provider) + elif cpm.is_call_to(expr, ("make_tuple", "tuple_get")): + for arg in expr.args: + visit_nested_make_tuple_tuple_get(arg) + def check_expr(fun, args, offset_provider): shifts = trace_shifts.trace_stencil(fun, num_args=len(args)) for arg, shift in zip(args, shifts): arg_subexprs = extract_subexprs(arg) target_subexprs = extract_subexprs(node.target) for subexpr in arg_subexprs: - if subexpr in target_subexprs: # Account for im.make_tuple + if subexpr in target_subexprs: if shift not in (set(), {()}): # This condition is just to filter out the trivial offsets in the horizontal and vertical. if any( @@ -97,10 +105,6 @@ def check_expr(fun, args, offset_provider): if cpm.is_applied_as_fieldop(node.expr): check_expr(node.expr.fun, node.expr.args, offset_provider) - else: # Account for im.make_tuple - if hasattr(node.expr, "args"): - for expr in node.expr.args: - if cpm.is_applied_as_fieldop(expr): - check_expr(expr.fun, expr.args, offset_provider) - + else: # Account for nested im.make_tuple and im.tuple_get + visit_nested_make_tuple_tuple_get(node.expr) return node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py index 8a0f48900f..1fc7e869f5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py @@ -21,12 +21,7 @@ IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) offset_provider = {"IOff": IDim} i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) -cartesian_domain = im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5), - itir.AxisLiteral(value="JDim"), - 0, - 7, -) +cartesian_domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 5)}) def program_factory( @@ -44,9 +39,9 @@ def program_factory( def test_check_inout_no_offset(): + # inout ← (⇑deref)(inout) ir = program_factory( params=[im.sym("inout", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop(im.ref("deref"))(im.ref("inout")), @@ -61,9 +56,9 @@ def test_check_inout_no_offset(): def test_check_inout_zero_offset(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 0ₒ⟫(x)))(inout) ir = program_factory( params=[im.sym("inout", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( @@ -80,10 +75,10 @@ def test_check_inout_zero_offset(): def test_check_inout_e2v_zero_offset(): + # inout ← (⇑(λ(x) → ·⟪E2Vₒ, 0ₒ⟫(x)))(inout) offset_provider = {"E2V": e2v_conn} # override ir = program_factory( params=[im.sym("inout", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("E2V", 0)("x"))))( @@ -100,9 +95,9 @@ def test_check_inout_e2v_zero_offset(): def test_check_inout_offset(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( @@ -121,7 +116,6 @@ def test_check_inout_offset(): def test_check_inout_shift_different_field(): ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop( @@ -141,9 +135,9 @@ def test_check_inout_shift_different_field(): def test_check_inout_in_arg(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))((⇑deref)(inout)) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( @@ -160,9 +154,9 @@ def test_check_inout_in_arg(): def test_check_inout_in_arg_two_fields(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 1ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))((⇑deref)(inout), in) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop( @@ -183,9 +177,9 @@ def test_check_inout_in_arg_two_fields(): def test_check_inout_in_arg_shift_different_field(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 1ₒ⟫(y)))((⇑deref)(inout), in) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop( @@ -205,9 +199,9 @@ def test_check_inout_in_arg_shift_different_field(): def test_check_inout_in_arg_shifted(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))((⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout), in) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop( @@ -233,9 +227,11 @@ def test_check_inout_in_arg_shifted(): def test_check_inout_in_arg_nested_shifted(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))( + # (⇑(λ(x) → ·⟪IOffₒ, 0ₒ⟫(x)))((⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout)), in + # ) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop( @@ -263,9 +259,11 @@ def test_check_inout_in_arg_nested_shifted(): def test_check_inout_in_arg_nested_shift_different_arg(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))( + # (⇑(λ(x) → ·⟪IOffₒ, 0ₒ⟫(x)))((⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(in)), inout + # ) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], body=[ itir.SetAt( expr=im.as_fieldop( @@ -292,34 +290,84 @@ def test_check_inout_in_arg_nested_shift_different_arg(): def test_check_inout_tuple(): + # {inout, inout2} ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(inout2)} ir = program_factory( - params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - declarations=[], + params=[im.sym("inout", i_field_type), im.sym("inout2", i_field_type)], body=[ itir.SetAt( expr=im.make_tuple( im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( im.ref("inout") ), - im.as_fieldop(im.ref("deref"))(im.ref("in")), + im.as_fieldop(im.ref("deref"))(im.ref("inout2")), ), domain=cartesian_domain, - target=im.make_tuple(im.ref("inout"), im.ref("in")), + target=im.make_tuple(im.ref("inout"), im.ref("inout2")), + ), + ], + ) + + with pytest.raises(ValueError, match="The target {inout, inout2} is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_as_fieldop(): + # {inout, out} ← (⇑(λ(x, y) → {·⟪IOffₒ, 1ₒ⟫(x), ·y}))(inout, in) + ir = program_factory( + params=[ + im.sym("inout", i_field_type), + im.sym("in", i_field_type), + im.sym("out", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("x", "y")( + im.make_tuple(im.deref(im.shift("IOff", 1)("x")), im.deref("y")) + ) + )(im.ref("inout"), im.ref("in")), + domain=cartesian_domain, + target=im.make_tuple(im.ref("inout"), im.ref("out")), ), ], ) - with pytest.raises(ValueError, match="The target {inout, in} is also read with an offset."): + with pytest.raises(ValueError, match="The target {inout, out} is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_get_make_tuple(): + # inout ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), as_fieldop(...)}[0] + ir = program_factory( + params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], + body=[ + itir.SetAt( + expr=im.tuple_get( + 0, + im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.ref("inout") + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + ), + domain=cartesian_domain, + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): CheckInOutField.apply(ir, offset_provider=offset_provider) def test_check_inout_tuple_get(): + # inout ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(in)} ir = program_factory( params=[ im.sym("inout", ts.TupleType(types=[i_field_type] * 2)), im.sym("in", i_field_type), ], - declarations=[], body=[ itir.SetAt( expr=im.make_tuple( @@ -329,10 +377,62 @@ def test_check_inout_tuple_get(): im.as_fieldop(im.ref("deref"))(im.ref("in")), ), domain=cartesian_domain, - target=im.make_tuple(im.ref("inout")), + target=im.ref("inout"), + ), + ], + ) + + with pytest.raises(ValueError, match="The target inout is also read with an offset."): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_tuple_get(): + # {inout[0], inout2} ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(inout2)} + ir = program_factory( + params=[ + im.sym("inout", ts.TupleType(types=[i_field_type] * 2)), + im.sym("inout2", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.tuple_get(0, im.ref("inout")) + ), + im.as_fieldop(im.ref("deref"))(im.ref("inout2")), + ), + domain=cartesian_domain, + target=im.make_tuple(im.tuple_get(0, im.ref("inout")), im.ref("inout2")), + ), + ], + ) + + with pytest.raises( + ValueError, match="The target {inout\[0\], inout2} is also read with an offset." + ): + CheckInOutField.apply(ir, offset_provider=offset_provider) + + +def test_check_inout_tuple_get_tuple(): + # inout[0] ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0][0]), (⇑deref)(in)} + ir = program_factory( + params=[ + im.sym("inout", ts.TupleType(types=[ts.TupleType(types=[i_field_type] * 2)] * 2)), + im.sym("in", i_field_type), + ], + body=[ + itir.SetAt( + expr=im.make_tuple( + im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( + im.tuple_get(0, im.tuple_get(0, im.ref("inout"))) + ), + im.as_fieldop(im.ref("deref"))(im.ref("in")), + ), + domain=cartesian_domain, + target=im.tuple_get(0, im.ref("inout")), ), ], ) - with pytest.raises(ValueError, match="The target {inout} is also read with an offset."): + with pytest.raises(ValueError, match="The target inout\[0\] is also read with an offset."): CheckInOutField.apply(ir, offset_provider=offset_provider) From f94bc07da154c14b15390ad9a5c4a14f7a1fc82b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Aug 2025 13:16:46 +0200 Subject: [PATCH 05/10] Raise error for as_fielops in as_fielop args, update tests and refactor is_tuple_expr_of --- .../ir_utils/common_pattern_matcher.py | 20 +++- .../iterator/transforms/check_inout_field.py | 10 +- .../iterator/transforms/fuse_as_fieldop.py | 12 +- src/gt4py/next/otf/compiled_program.py | 1 + .../codegens/gtfn/gtfn_ir.py | 32 ++---- .../codegens/gtfn/itir_to_gtfn_ir.py | 4 +- .../test_check_inout_field.py | 106 ++++-------------- 7 files changed, 60 insertions(+), 125 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index da13d20bb6..6433ce9e26 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -7,10 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Generic, List, TypeAlias, TypeGuard, TypeVar +from typing import Any, Callable, Generic, List, TypeAlias, TypeGuard, TypeVar, overload from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr as GTFNIRExpr _Fun = TypeVar("_Fun", bound=itir.Expr) @@ -135,3 +136,20 @@ def is_identity_as_fieldop(node: itir.Expr) -> TypeGuard[_FunCallToFunCallToRef] ): return True return False + + +@overload +def is_tuple_expr_of(pred: Callable[[itir.Expr], bool], expr: itir.Expr) -> bool: ... +@overload +def is_tuple_expr_of(pred: Callable[[GTFNIRExpr], bool], expr: GTFNIRExpr) -> bool: ... + + +def is_tuple_expr_of( + pred: Callable[[Any], bool], + expr: itir.Expr | GTFNIRExpr, +) -> bool: + if is_call_to(expr, "make_tuple"): + return all(is_tuple_expr_of(pred, arg) for arg in expr.args) + if is_call_to(expr, "tuple_get"): + return is_tuple_expr_of(pred, expr.args[1]) + return pred(expr) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index c9e1fe66dd..0f56d58d69 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -81,13 +81,13 @@ def visit_nested_make_tuple_tuple_get(expr): def check_expr(fun, args, offset_provider): shifts = trace_shifts.trace_stencil(fun, num_args=len(args)) + target_subexprs = extract_subexprs(node.target) for arg, shift in zip(args, shifts): arg_subexprs = extract_subexprs(arg) - target_subexprs = extract_subexprs(node.target) for subexpr in arg_subexprs: if subexpr in target_subexprs: if shift not in (set(), {()}): - # This condition is just to filter out the trivial offsets in the horizontal and vertical. + # This condition is just to filter out the trivial offsets in the horizontal and vertical. # TODO: remove and add preprocessing of IOff(0) instead if any( offset_provider[off.value].kind not in { @@ -100,8 +100,10 @@ def check_expr(fun, args, offset_provider): raise ValueError( f"The target {node.target} is also read with an offset." ) - if cpm.is_applied_as_fieldop(arg): - check_expr(arg.fun, arg.args, offset_provider) + if not cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.SymRef), arg): + raise ValueError( + f"Unexpected as_fieldop argument {arg}. Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first." + ) if cpm.is_applied_as_fieldop(node.expr): check_expr(node.expr.fun, node.expr.args, offset_provider) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 4b3a258396..26679d9c80 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -46,14 +46,6 @@ def _merge_arguments( return new_args -def _is_tuple_expr_of_literals(expr: itir.Expr): - if cpm.is_call_to(expr, "make_tuple"): - return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) - if cpm.is_call_to(expr, "tuple_get"): - return _is_tuple_expr_of_literals(expr.args[1]) - return isinstance(expr, itir.Literal) - - def _inline_as_fieldop_arg( arg: itir.Expr, *, uids: eve_utils.UIDGenerator ) -> tuple[itir.Expr, dict[str, itir.Expr]]: @@ -142,7 +134,7 @@ def fuse_as_fieldop( # transform scalar `if` into per-grid-point `if` # TODO(tehrengruber): revisit if we want to inline if_ arg = im.op_as_fieldop("if_")(*arg.args) - elif _is_tuple_expr_of_literals(arg): + elif cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.Literal), arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: raise NotImplementedError() @@ -189,7 +181,7 @@ def fuse_as_fieldop( def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, ...]]) -> bool: - if _is_tuple_expr_of_literals(node): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.Literal), node): return True if ( diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 174ca8edb1..802dcf6a94 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -126,6 +126,7 @@ class CompiledProgramsPool: definition_stage: ffront_stages.ProgramDefinition program_type: ts_ffront.ProgramType static_params: Sequence[str] | None = None # not ordered + static_domain_sizes: bool = False _compiled_programs: eve_utils.CustomMapping = dataclasses.field( default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe), diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index f7445461c0..e2ffecc9b1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,12 +8,13 @@ from __future__ import annotations -from typing import Callable, ClassVar, Optional, Union +from typing import ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.next import common from gt4py.next.iterator import builtins +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef @@ -97,25 +98,6 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: - if ( - isinstance(expr, FunCall) - and isinstance(expr.fun, SymRef) - and expr.fun.id == "tuple_get" - and len(expr.args) == 2 - and _is_tuple_expr_of(pred, expr.args[1]) - ): - return True - if ( - isinstance(expr, FunCall) - and isinstance(expr.fun, SymRef) - and expr.fun.id == "make_tuple" - and all(_is_tuple_expr_of(pred, arg) for arg in expr.args) - ): - return True - return pred(expr) - - class SidComposite(Expr): values: list[Expr] @@ -125,7 +107,7 @@ def _values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_tuple_expr_of( + or cpm.is_tuple_expr_of( lambda expr: isinstance(expr, (SymRef, Literal)) or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")), el, @@ -139,9 +121,9 @@ def _values_validator( def _might_be_scalar_expr(expr: Expr) -> bool: if isinstance(expr, BinaryExpr): - return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) + return all(cpm.is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) if isinstance(expr, UnaryExpr): - return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr) + return cpm.is_tuple_expr_of(_might_be_scalar_expr, expr.expr) if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) @@ -150,7 +132,7 @@ def _might_be_scalar_expr(expr: Expr) -> bool: return all(_might_be_scalar_expr(arg) for arg in expr.args) if isinstance(expr, CastExpr): return _might_be_scalar_expr(expr.obj_expr) - if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): return True return False @@ -183,7 +165,7 @@ def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr] ) -> None: for inp in inputs: - if not _is_tuple_expr_of( + if not cpm.is_tuple_expr_of( lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) or ( isinstance(expr, FunCall) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index a445390583..04ce53cccc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -72,7 +72,7 @@ def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) ): return True - if isinstance(expr, (itir.SymRef, itir.Literal)): + if isinstance(expr, (itir.SymRef, itir.Literal)): # move to condition return True return False @@ -587,7 +587,7 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - if _is_tuple_of_ref_or_literal(node.expr): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), node.expr): node.expr = im.as_fieldop("deref", node.domain)(node.expr) itir_projector, extracted_expr = ir_utils_misc.extract_projector(node.expr) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py index 1fc7e869f5..c2bd0e48e3 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_check_inout_field.py @@ -114,6 +114,7 @@ def test_check_inout_offset(): def test_check_inout_shift_different_field(): + # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 1ₒ⟫(y)))(inout, in); ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], body=[ @@ -134,7 +135,7 @@ def test_check_inout_shift_different_field(): assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) -def test_check_inout_in_arg(): +def test_check_inout_in_as_fieldop_arg(): # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))((⇑deref)(inout)) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], @@ -149,7 +150,10 @@ def test_check_inout_in_arg(): ], ) - with pytest.raises(ValueError, match="The target inout is also read with an offset."): + with pytest.raises( + ValueError, + match=r"Unexpected as_fieldop argument \(⇑deref\)\(inout\). Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first.", + ): CheckInOutField.apply(ir, offset_provider=offset_provider) @@ -160,12 +164,12 @@ def test_check_inout_in_arg_two_fields(): body=[ itir.SetAt( expr=im.as_fieldop( - im.lambda_("x", "y")( + im.lambda_("x")( im.plus( - im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("y")) + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("x")) ) ) - )(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), + )(im.make_tuple(im.ref("inout"), im.ref("in"))), domain=cartesian_domain, target=im.ref("inout"), ), @@ -176,46 +180,19 @@ def test_check_inout_in_arg_two_fields(): CheckInOutField.apply(ir, offset_provider=offset_provider) -def test_check_inout_in_arg_shift_different_field(): - # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 1ₒ⟫(y)))((⇑deref)(inout), in) - ir = program_factory( - params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - body=[ - itir.SetAt( - expr=im.as_fieldop( - im.lambda_("x", "y")( - im.plus( - im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 1)("y")) - ) - ) - )(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in")), - domain=cartesian_domain, - target=im.ref("inout"), - ), - ], - ) - - assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) - - -def test_check_inout_in_arg_shifted(): - # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))((⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout), in) +def test_check_inout_in_arg_tuple(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(x)))({inout, in}) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], body=[ itir.SetAt( expr=im.as_fieldop( - im.lambda_("x", "y")( + im.lambda_("x")( im.plus( - im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")) + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("x")) ) ) - )( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( - im.ref("inout") - ), - im.ref("in"), - ), + )(im.make_tuple(im.ref("inout"), im.ref("in"))), domain=cartesian_domain, target=im.ref("inout"), ), @@ -226,69 +203,32 @@ def test_check_inout_in_arg_shifted(): CheckInOutField.apply(ir, offset_provider=offset_provider) -def test_check_inout_in_arg_nested_shifted(): - # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))( - # (⇑(λ(x) → ·⟪IOffₒ, 0ₒ⟫(x)))((⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout)), in - # ) +def test_check_inout_in_make_tuple_as_fieldop_in_arg(): + # inout ← (⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(x)))({(⇑deref)(inout), in}) ir = program_factory( params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], body=[ itir.SetAt( expr=im.as_fieldop( - im.lambda_("x", "y")( + im.lambda_("x")( im.plus( - im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")) + im.deref(im.shift("IOff", 1)("x")), im.deref(im.shift("IOff", 0)("x")) ) ) - )( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( - im.ref("inout") - ) - ), - im.ref("in"), - ), + )(im.make_tuple(im.as_fieldop(im.ref("deref"))(im.ref("inout")), im.ref("in"))), domain=cartesian_domain, target=im.ref("inout"), ), ], ) - with pytest.raises(ValueError, match="The target inout is also read with an offset."): + with pytest.raises( + ValueError, + match=r"Unexpected as_fieldop argument \{\(⇑deref\)\(inout\), in\}. Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first.", + ): CheckInOutField.apply(ir, offset_provider=offset_provider) -def test_check_inout_in_arg_nested_shift_different_arg(): - # inout ← (⇑(λ(x, y) → ·⟪IOffₒ, 0ₒ⟫(x) + ·⟪IOffₒ, 0ₒ⟫(y)))( - # (⇑(λ(x) → ·⟪IOffₒ, 0ₒ⟫(x)))((⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(in)), inout - # ) - ir = program_factory( - params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)], - body=[ - itir.SetAt( - expr=im.as_fieldop( - im.lambda_("x", "y")( - im.plus( - im.deref(im.shift("IOff", 0)("x")), im.deref(im.shift("IOff", 0)("y")) - ) - ) - )( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 0)("x"))))( - im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))( - im.ref("in") - ) - ), - im.ref("inout"), - ), - domain=cartesian_domain, - target=im.ref("inout"), - ), - ], - ) - - assert ir == CheckInOutField.apply(ir, offset_provider=offset_provider) - - def test_check_inout_tuple(): # {inout, inout2} ← {(⇑(λ(x) → ·⟪IOffₒ, 1ₒ⟫(x)))(inout[0]), (⇑deref)(inout2)} ir = program_factory( From 2585fce7a4f932aecf8dc7733588486fb707c585 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Aug 2025 13:55:06 +0200 Subject: [PATCH 06/10] Minor --- .../next/iterator/ir_utils/common_pattern_matcher.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 6433ce9e26..f249a10ea0 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Callable, Generic, List, TypeAlias, TypeGuard, TypeVar, overload +from typing import Any, Callable, Generic, List, TypeAlias, TypeGuard, TypeVar from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -138,12 +138,6 @@ def is_identity_as_fieldop(node: itir.Expr) -> TypeGuard[_FunCallToFunCallToRef] return False -@overload -def is_tuple_expr_of(pred: Callable[[itir.Expr], bool], expr: itir.Expr) -> bool: ... -@overload -def is_tuple_expr_of(pred: Callable[[GTFNIRExpr], bool], expr: GTFNIRExpr) -> bool: ... - - def is_tuple_expr_of( pred: Callable[[Any], bool], expr: itir.Expr | GTFNIRExpr, From b26f46287e898cd4914b098c32392605fbca4bd8 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Aug 2025 15:32:57 +0200 Subject: [PATCH 07/10] Fix some tests --- src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py | 5 +++-- src/gt4py/next/iterator/transforms/check_inout_field.py | 2 +- src/gt4py/next/iterator/transforms/pass_manager.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index f249a10ea0..aee8464c2e 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -12,6 +12,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr as GTFNIRExpr +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import FunCall as GTFNIRFunCall, SymRef as GTFNIRSymRef _Fun = TypeVar("_Fun", bound=itir.Expr) @@ -45,8 +46,8 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[_FunCallToSymRe assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument if isinstance(fun, str): return ( - isinstance(node, itir.FunCall) - and isinstance(node.fun, itir.SymRef) + isinstance(node, itir.FunCall | GTFNIRFunCall) + and isinstance(node.fun, itir.SymRef | GTFNIRSymRef) and node.fun.id == fun ) else: diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index 0f56d58d69..417ab60cc5 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -80,7 +80,7 @@ def visit_nested_make_tuple_tuple_get(expr): visit_nested_make_tuple_tuple_get(arg) def check_expr(fun, args, offset_provider): - shifts = trace_shifts.trace_stencil(fun, num_args=len(args)) + shifts = trace_shifts.trace_stencil(fun.args[0], num_args=len(args)) target_subexprs = extract_subexprs(node.target) for arg, shift in zip(args, shifts): arg_subexprs = extract_subexprs(arg) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f4eddb6e71..c7fec82a6b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -84,7 +84,6 @@ def apply_common_transforms( ir = inline_dynamic_shifts.InlineDynamicShifts.apply( ir ) # domain inference does not support dynamic offsets yet - ir = check_inout_field.CheckInOutField.apply(ir, offset_provider=offset_provider) ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) @@ -94,6 +93,8 @@ def apply_common_transforms( offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) + ir = check_inout_field.CheckInOutField.apply(ir, offset_provider=offset_provider) + ir = remove_broadcast.RemoveBroadcast.apply(ir) ir = concat_where.transform_to_as_fieldop(ir) From 6ee3f2d0f4595f1b7aa41c98477d8f8162051bcf Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Aug 2025 15:47:15 +0200 Subject: [PATCH 08/10] Fix import --- src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index aee8464c2e..21a99f7e8f 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -11,8 +11,11 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr as GTFNIRExpr -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import FunCall as GTFNIRFunCall, SymRef as GTFNIRSymRef +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import FunCall as GTFNIRFunCall +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import ( + Expr as GTFNIRExpr, + SymRef as GTFNIRSymRef, +) _Fun = TypeVar("_Fun", bound=itir.Expr) From 92da44d0a5ca2d2c0e3dbc971a196da1c713dc8b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Aug 2025 16:57:54 +0200 Subject: [PATCH 09/10] Fix tests and refactor shift filtering --- .../iterator/transforms/check_inout_field.py | 48 ++++++++++++------- .../codegens/gtfn/itir_to_gtfn_ir.py | 24 +--------- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index 417ab60cc5..18067ee98d 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next import common +from gt4py.next.common import OffsetProvider from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm from gt4py.next.iterator.transforms import trace_shifts @@ -79,27 +80,42 @@ def visit_nested_make_tuple_tuple_get(expr): for arg in expr.args: visit_nested_make_tuple_tuple_get(arg) - def check_expr(fun, args, offset_provider): + def filter_shifted_args( + shifts: list[set[tuple[itir.OffsetLiteral, ...]]], + args: list[itir.Expr], + offset_provider: OffsetProvider, + ) -> list[itir.Expr]: + """ + Filters out trivial shifts (empty or all horizontal/vertical with zero offset) + and returns filtered shifts and corresponding args. + """ + filtered = [ + arg + for shift, arg in zip(shifts, args) + if shift not in (set(), {()}) + and any( + offset_provider[off.value].kind # type: ignore[index] # mypy not smart enough + not in {common.DimensionKind.HORIZONTAL, common.DimensionKind.VERTICAL} + or val.value != 0 + for off, val in shift + ) + ] + return filtered if filtered else [] + + def check_expr( + fun: itir.FunCall, + args: list[itir.Expr], + offset_provider: OffsetProvider, + ) -> None: shifts = trace_shifts.trace_stencil(fun.args[0], num_args=len(args)) + + shifted_args = filter_shifted_args(shifts, args, offset_provider) target_subexprs = extract_subexprs(node.target) - for arg, shift in zip(args, shifts): + for arg in shifted_args: arg_subexprs = extract_subexprs(arg) for subexpr in arg_subexprs: if subexpr in target_subexprs: - if shift not in (set(), {()}): - # This condition is just to filter out the trivial offsets in the horizontal and vertical. # TODO: remove and add preprocessing of IOff(0) instead - if any( - offset_provider[off.value].kind - not in { - common.DimensionKind.HORIZONTAL, - common.DimensionKind.VERTICAL, - } - or val.value != 0 - for off, val in shift - ): - raise ValueError( - f"The target {node.target} is also read with an offset." - ) + raise ValueError(f"The target {node.target} is also read with an offset.") if not cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.SymRef), arg): raise ValueError( f"Unexpected as_fieldop argument {arg}. Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first." diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 04ce53cccc..c73869e73c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -56,27 +56,6 @@ _horizontal_dimension = "gtfn::unstructured::dim::horizontal" -def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: - if ( - isinstance(expr, itir.FunCall) - and isinstance(expr.fun, itir.SymRef) - and expr.fun.id == "tuple_get" - and len(expr.args) == 2 - and _is_tuple_of_ref_or_literal(expr.args[1]) - ): - return True - if ( - isinstance(expr, itir.FunCall) - and isinstance(expr.fun, itir.SymRef) - and expr.fun.id == "make_tuple" - and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) - ): - return True - if isinstance(expr, (itir.SymRef, itir.Literal)): # move to condition - return True - return False - - def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: result = set() for node in nodes: @@ -587,13 +566,12 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - if cpm.is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), node.expr): + if cpm.is_tuple_expr_of(lambda e: isinstance(e, (itir.SymRef, itir.Literal)), node.expr): node.expr = im.as_fieldop("deref", node.domain)(node.expr) itir_projector, extracted_expr = ir_utils_misc.extract_projector(node.expr) projector = self.visit(itir_projector, **kwargs) if itir_projector is not None else None node.expr = extracted_expr - assert cpm.is_applied_as_fieldop(node.expr), node.expr stencil = node.expr.fun.args[0] domain = node.domain From 1eed201348f3f89d90d12bfb2d63c7d726637a7c Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 8 Aug 2025 17:25:36 +0200 Subject: [PATCH 10/10] Fix filtering --- src/gt4py/next/iterator/transforms/check_inout_field.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/check_inout_field.py b/src/gt4py/next/iterator/transforms/check_inout_field.py index 18067ee98d..2918404c51 100644 --- a/src/gt4py/next/iterator/transforms/check_inout_field.py +++ b/src/gt4py/next/iterator/transforms/check_inout_field.py @@ -97,7 +97,11 @@ def filter_shifted_args( offset_provider[off.value].kind # type: ignore[index] # mypy not smart enough not in {common.DimensionKind.HORIZONTAL, common.DimensionKind.VERTICAL} or val.value != 0 - for off, val in shift + for off, val in ( + (pair for pair in shift if len(pair) == 2) # set case: skip () + if isinstance(shift, set) + else zip(shift[0::2], shift[1::2]) # tuple/list case + ) ) ] return filtered if filtered else []