From 2c237096e55d40445d8d97cd58f8033dc6cda89a Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Sun, 12 Oct 2025 07:43:29 +0200 Subject: [PATCH 1/2] Name agnostic common subexpression elimination --- src/gt4py/next/iterator/transforms/cse.py | 140 +++++++++++++++++- .../next/iterator/transforms/remap_symbols.py | 4 +- .../transforms_tests/test_cse.py | 67 +++++++++ 3 files changed, 207 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 2fcbd5df0d..4cafc798bb 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -25,6 +25,7 @@ from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import remap_symbols from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -46,9 +47,125 @@ def _is_trivial_tuple_expr(node: itir.Expr): @dataclasses.dataclass -class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): +class _CanonicalizeNamesPostProcessing(remap_symbols.RenameSymbols): + PRESERVED_ANNEX_ATTRS = ("type", "domain", "pre_canonicalization_name") + + @classmethod + def apply(cls, node, max_count: int) -> itir.Node: + name_map = {f"_{i}": f"_{max_count - i}" for i in range(max_count + 1)} + obj = cls() + new_node = obj.visit(node, name_map=name_map) + return new_node + + +# TODO(tehrengruber): Canonicalize all let vars, then order by their hash +@dataclasses.dataclass +class _CanonicalizeNames(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("type", "domain") + + allow_external_symbols: bool + + counter: int = 0 + max_count: int = 0 + + @classmethod + def apply(cls, node: itir.Node, allow_external_symbols=False) -> itir.Node: + obj = cls(allow_external_symbols=allow_external_symbols) + new_node = obj.visit(node, name_map={}) + return _CanonicalizeNamesPostProcessing.apply(new_node, obj.max_count) + + # TODO: extend to program + + def visit_Program( + self, node: itir.Program, *, name_map: collections.ChainMap | dict + ) -> itir.Program: + assert not name_map + # ignore all program params and builtins + return self.generic_visit(node, name_map={k: None for k in node.annex.symtable}) + + def visit_Lambda( + self, node: itir.Lambda, *, name_map: collections.ChainMap | dict + ) -> itir.Lambda: + initial_count = self.counter + + local_name_map: dict[str, itir.Sym] = {} + # go in reverse order so that after postprocessing we have forward order + for param in reversed(node.params): + new_sym = im.sym(f"_{self.counter}", param.type) + self.counter += 1 + new_sym.annex.pre_canonicalization_name = param.id + local_name_map[param.id] = new_sym + + new_node = im.lambda_(*reversed(local_name_map.values()))( + # TODO: check what happens if there is a collision between local_name_map and name_map + self.visit(node.expr, name_map=collections.ChainMap(local_name_map, name_map)) + ) + + self.max_count = max(0, self.counter - 1, self.max_count) + self.counter = initial_count + + return new_node + + def visit_SymRef( + self, node: itir.SymRef, *, name_map: collections.ChainMap | dict + ) -> itir.SymRef: + if self.allow_external_symbols and node.id not in name_map: + return node + + if name_map[node.id] is None: + return node + + return im.ref(name_map[node.id].id, node.type) + + +@dataclasses.dataclass +class _RestoreSymbolNames(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type", "domain") + allow_external_symbols: bool + + @classmethod + def apply(cls, node: itir.Node, allow_external_symbols=False) -> itir.Node: + return cls(allow_external_symbols=allow_external_symbols).visit(node, name_map={}) + + def visit_Program( + self, node: itir.Program, *, name_map: collections.ChainMap | dict + ) -> itir.Program: + # ignore all program params and builtins + assert not name_map + return self.generic_visit(node, name_map={k: None for k in node.annex.symtable}) + + def visit_Lambda( + self, node: itir.Lambda, *, name_map: collections.ChainMap | dict + ) -> itir.Lambda: + local_name_map: dict[str, itir.Sym] = {} + for param in node.params: + # assert hasattr(param.annex, "pre_canonicalization_name") + if not hasattr(param.annex, "pre_canonicalization_name"): + local_name_map[param.id] = param # TODO: check this is _cs... + else: + local_name_map[param.id] = im.sym(param.annex.pre_canonicalization_name, param.type) + + return im.lambda_(*local_name_map.values())( + self.visit(node.expr, name_map=collections.ChainMap(local_name_map, name_map)) + ) + + def visit_SymRef( + self, node: itir.SymRef, *, name_map: collections.ChainMap | dict + ) -> itir.SymRef: + if self.allow_external_symbols and node.id not in name_map: + return node + + if name_map[node.id] is None: + return node + + return im.ref(name_map[node.id].id, node.type) + + +@dataclasses.dataclass +class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("type", "domain", "pre_canonicalization_name") + expr_map: dict[int, itir.SymRef] def visit_Expr(self, node: itir.Node) -> itir.Node: @@ -416,6 +533,8 @@ def extract_subexpression( @dataclasses.dataclass(frozen=True) class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("pre_canonicalization_name",) + """ Perform common subexpression elimination. @@ -447,6 +566,7 @@ def apply( node: ProgramOrExpr, within_stencil: bool | None = None, offset_provider_type: common.OffsetProviderType | None = None, + canonicalize: bool = True, ) -> ProgramOrExpr: is_program = isinstance(node, itir.Program) if is_program: @@ -461,7 +581,21 @@ def apply( node = itir_type_inference.infer( node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) - return cls().visit(node, within_stencil=within_stencil) + + if canonicalize: + # TODO: double check if allow_external_symbols might be dangerous when not in testing and an expr, probably yes so warn + node = _CanonicalizeNames.apply( + node, allow_external_symbols=not isinstance(node, itir.Program) + ) + + new_node = cls().visit(node, within_stencil=within_stencil) + + if canonicalize: + new_node = _RestoreSymbolNames.apply( + new_node, allow_external_symbols=not isinstance(node, itir.Program) + ) + + return new_node def generic_visit(self, node, **kwargs): if cpm.is_call_to(node, "as_fieldop"): @@ -479,7 +613,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): def predicate(subexpr: itir.Expr, num_occurences: int): # note: be careful here with the syntatic context: the expression might be in local # view, even though the syntactic context of `node` is in field view. - # note: what is extracted is sketched in the docstring above. keep it updated. + # note: what is extracted is sketched in the docstring above. keep it up-to-date. if num_occurences > 1: if within_stencil: # TODO(tehrengruber): Lists must not be extracted to avoid errors in partial diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 5495f63ae1..fa689b781b 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -40,7 +40,9 @@ def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.Sym(id=name_map.get(node.id, node.id)) + sym = ir.Sym(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=sym, allow_untyped=True) + return sym return node def visit_SymRef( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index f618ba409d..90ecd23c94 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Sequence import pytest import textwrap @@ -16,6 +17,8 @@ from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.cse import ( CommonSubexpressionElimination as CSE, + _CanonicalizeNames, + _RestoreSymbolNames, extract_subexpression, ) @@ -360,3 +363,67 @@ def test_sym_ref_collection_from_lambda(opaque_fun): actual = CSE.apply(testee, within_stencil=True) assert actual == expected # no extraction should happen + + +@pytest.mark.parametrize( + "testees", + ( + [im.lambda_("a")("a"), im.lambda_("b")("b")], + [ + im.lambda_("a1")(im.lambda_("b1")(im.plus("a1", "b1"))), + im.lambda_("a1")(im.lambda_("b2")(im.plus("a1", "b2"))), + ], + [ + im.lambda_("a1")(im.lambda_("b1")(im.plus("a1", "b1"))), + im.lambda_("a2")(im.lambda_("b1")(im.plus("a2", "b1"))), + ], + [ + im.lambda_("a1")(im.lambda_("b1")(im.plus("a1", "b1"))), + im.lambda_("a2")(im.lambda_("b2")(im.plus("a2", "b2"))), + ], + [ + im.lambda_("a1")(im.make_tuple("a1", im.lambda_("a1")("a1"))), + im.lambda_("a2")(im.make_tuple("a2", im.lambda_("a2")("a2"))), + ], + ), +) +def test_equivalent_are_equal_after_canonicalization(testees: Sequence[ir.Expr]): + transformed = { + _CanonicalizeNames.apply(testee, allow_external_symbols=True) for testee in testees + } + + assert len(transformed) == 1 + + +def test_canonicalize_symbol_names(): + testee = im.let(("a", 1), ("b", 2))( + im.make_tuple(im.plus("a", "b"), im.let("c", 3)(im.plus("c", im.plus("a", "b")))) + ) + expected = im.let(("_1", 1), ("_2", 2))( + im.make_tuple(im.plus("_1", "_2"), im.let("_0", 3)(im.plus("_0", im.plus("_1", "_2")))) + ) + + actual = _CanonicalizeNames.apply(testee, allow_external_symbols=True) + assert actual == expected + + restored = _RestoreSymbolNames.apply(actual, allow_external_symbols=True) + assert restored == testee + + +# TODO: collisions are likely not occuring in cse. describe in the pass why +# def test_restore_symbol_names(): +# testee = im.let(("a", 1), ("b", 2))( +# im.make_tuple( +# im.plus("external")(im.plus("a", "b")), +# im.let("c", 3)(im.plus("c", im.plus("a", "b"))) +# ) +# ) +# +# _CanonicalizeNames.apply(testee, allow_external_symbols=True) + + +def test_cse_can_required(): + testee = im.plus(im.let("x1", 1)(im.plus("x1", "y")), im.let("x2", 1)(im.plus("x2", "y"))) + expected = im.let("_cs_1", im.let("x1", 1)(im.plus("x1", "y")))(im.plus("_cs_1", "_cs_1")) + actual = CSE.apply(testee, within_stencil=True) + assert actual == expected From 3ece7ca6ec089f235fd9c2a2f55464397b4296f5 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Mon, 13 Oct 2025 10:37:53 +0200 Subject: [PATCH 2/2] Cleanup --- src/gt4py/next/iterator/transforms/cse.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 4cafc798bb..8da300b1bf 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -31,6 +31,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts +T = TypeVar("T", bound=itir.Expr | itir.Program) + + def _is_trivial_tuple_expr(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if cpm.is_call_to(node, "make_tuple") and all( @@ -51,7 +54,7 @@ class _CanonicalizeNamesPostProcessing(remap_symbols.RenameSymbols): PRESERVED_ANNEX_ATTRS = ("type", "domain", "pre_canonicalization_name") @classmethod - def apply(cls, node, max_count: int) -> itir.Node: + def apply(cls, node: T, max_count: int) -> T: name_map = {f"_{i}": f"_{max_count - i}" for i in range(max_count + 1)} obj = cls() new_node = obj.visit(node, name_map=name_map) @@ -69,7 +72,7 @@ class _CanonicalizeNames(PreserveLocationVisitor, NodeTranslator): max_count: int = 0 @classmethod - def apply(cls, node: itir.Node, allow_external_symbols=False) -> itir.Node: + def apply(cls, node: T, allow_external_symbols=False) -> T: obj = cls(allow_external_symbols=allow_external_symbols) new_node = obj.visit(node, name_map={}) return _CanonicalizeNamesPostProcessing.apply(new_node, obj.max_count)