143 changes: 140 additions & 3 deletions src/gt4py/next/iterator/transforms/cse.py
@@ -25,11 +25,15 @@
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import remap_symbols
from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


T = TypeVar("T", bound=itir.Expr | itir.Program)


def _is_trivial_tuple_expr(node: itir.Expr):
"""Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof."""
if cpm.is_call_to(node, "make_tuple") and all(
@@ -46,9 +50,125 @@ def _is_trivial_tuple_expr(node: itir.Expr):


@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
class _CanonicalizeNamesPostProcessing(remap_symbols.RenameSymbols):
PRESERVED_ANNEX_ATTRS = ("type", "domain", "pre_canonicalization_name")

@classmethod
def apply(cls, node: T, max_count: int) -> T:
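        # Names were handed out in reverse order during canonicalization (see
        # `_CanonicalizeNames.visit_Lambda`), so flip the numbering to obtain forward
        # order, e.g. for `max_count == 2`: {"_0": "_2", "_1": "_1", "_2": "_0"}.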
name_map = {f"_{i}": f"_{max_count - i}" for i in range(max_count + 1)}
obj = cls()
new_node = obj.visit(node, name_map=name_map)
return new_node


# TODO(tehrengruber): Canonicalize all let vars, then order by their hash
@dataclasses.dataclass
class _CanonicalizeNames(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type", "domain")

allow_external_symbols: bool

counter: int = 0
max_count: int = 0

@classmethod
    def apply(cls, node: T, allow_external_symbols: bool = False) -> T:
obj = cls(allow_external_symbols=allow_external_symbols)
new_node = obj.visit(node, name_map={})
return _CanonicalizeNamesPostProcessing.apply(new_node, obj.max_count)

# TODO: extend to program

def visit_Program(
self, node: itir.Program, *, name_map: collections.ChainMap | dict
) -> itir.Program:
assert not name_map
# ignore all program params and builtins
return self.generic_visit(node, name_map={k: None for k in node.annex.symtable})

def visit_Lambda(
self, node: itir.Lambda, *, name_map: collections.ChainMap | dict
) -> itir.Lambda:
initial_count = self.counter

local_name_map: dict[str, itir.Sym] = {}
# go in reverse order so that after postprocessing we have forward order
for param in reversed(node.params):
new_sym = im.sym(f"_{self.counter}", param.type)
self.counter += 1
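            # stash the original name on the annex so that `_RestoreSymbolNames` can undo the renaming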
new_sym.annex.pre_canonicalization_name = param.id
local_name_map[param.id] = new_sym

new_node = im.lambda_(*reversed(local_name_map.values()))(
# TODO: check what happens if there is a collision between local_name_map and name_map
self.visit(node.expr, name_map=collections.ChainMap(local_name_map, name_map))
)

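        # Only `max_count` accumulates over the whole tree; the counter is reset so sibling
        # lambdas reuse the same numbers and alpha-equivalent siblings end up syntactically identical.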
self.max_count = max(0, self.counter - 1, self.max_count)
self.counter = initial_count

return new_node

def visit_SymRef(
self, node: itir.SymRef, *, name_map: collections.ChainMap | dict
) -> itir.SymRef:
if self.allow_external_symbols and node.id not in name_map:
return node

if name_map[node.id] is None:
return node

return im.ref(name_map[node.id].id, node.type)


@dataclasses.dataclass
class _RestoreSymbolNames(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type", "domain")

allow_external_symbols: bool

@classmethod
    def apply(cls, node: itir.Node, allow_external_symbols: bool = False) -> itir.Node:
return cls(allow_external_symbols=allow_external_symbols).visit(node, name_map={})

def visit_Program(
self, node: itir.Program, *, name_map: collections.ChainMap | dict
) -> itir.Program:
# ignore all program params and builtins
assert not name_map
return self.generic_visit(node, name_map={k: None for k in node.annex.symtable})

def visit_Lambda(
self, node: itir.Lambda, *, name_map: collections.ChainMap | dict
) -> itir.Lambda:
local_name_map: dict[str, itir.Sym] = {}
for param in node.params:
# assert hasattr(param.annex, "pre_canonicalization_name")
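            # params introduced after canonicalization (e.g. by the CSE extraction itself)
            # carry no original name and are kept as-is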
if not hasattr(param.annex, "pre_canonicalization_name"):
local_name_map[param.id] = param # TODO: check this is _cs...
else:
local_name_map[param.id] = im.sym(param.annex.pre_canonicalization_name, param.type)

return im.lambda_(*local_name_map.values())(
self.visit(node.expr, name_map=collections.ChainMap(local_name_map, name_map))
)

def visit_SymRef(
self, node: itir.SymRef, *, name_map: collections.ChainMap | dict
) -> itir.SymRef:
if self.allow_external_symbols and node.id not in name_map:
return node

if name_map[node.id] is None:
return node

return im.ref(name_map[node.id].id, node.type)


@dataclasses.dataclass
class _NodeReplacer(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type", "domain", "pre_canonicalization_name")

expr_map: dict[int, itir.SymRef]

def visit_Expr(self, node: itir.Node) -> itir.Node:
@@ -416,6 +536,8 @@ def extract_subexpression(

@dataclasses.dataclass(frozen=True)
class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("pre_canonicalization_name",)

"""
Perform common subexpression elimination.

Expand Down Expand Up @@ -447,6 +569,7 @@ def apply(
node: ProgramOrExpr,
within_stencil: bool | None = None,
offset_provider_type: common.OffsetProviderType | None = None,
canonicalize: bool = True,
) -> ProgramOrExpr:
is_program = isinstance(node, itir.Program)
if is_program:
@@ -461,7 +584,21 @@
node = itir_type_inference.infer(
node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program
)
return cls().visit(node, within_stencil=within_stencil)

if canonicalize:
            # TODO: double-check whether `allow_external_symbols` is dangerous outside of testing
            #  when `node` is an expression; probably yes, so emit a warning.
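            # Renaming bound variables to canonical names makes alpha-equivalent
            # subexpressions syntactically equal, so they can be collapsed below
            # (see `test_cse_can_required`).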
node = _CanonicalizeNames.apply(
node, allow_external_symbols=not isinstance(node, itir.Program)
)

new_node = cls().visit(node, within_stencil=within_stencil)

if canonicalize:
new_node = _RestoreSymbolNames.apply(
new_node, allow_external_symbols=not isinstance(node, itir.Program)
)

return new_node

def generic_visit(self, node, **kwargs):
if cpm.is_call_to(node, "as_fieldop"):
Expand All @@ -479,7 +616,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
def predicate(subexpr: itir.Expr, num_occurences: int):
            # note: be careful here with the syntactic context: the expression might be in local
# view, even though the syntactic context of `node` is in field view.
# note: what is extracted is sketched in the docstring above. keep it updated.
# note: what is extracted is sketched in the docstring above. keep it up-to-date.
if num_occurences > 1:
if within_stencil:
# TODO(tehrengruber): Lists must not be extracted to avoid errors in partial
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/transforms/remap_symbols.py
@@ -40,7 +40,9 @@ def visit_Sym(
self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None
):
if active and node.id in active:
return ir.Sym(id=name_map.get(node.id, node.id))
sym = ir.Sym(id=name_map.get(node.id, node.id))
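            # carry over the inferred type (if any) to the renamed symbol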
type_inference.copy_type(from_=node, to=sym, allow_untyped=True)
return sym
return node

def visit_SymRef(
@@ -5,6 +5,7 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Sequence

import pytest
import textwrap
@@ -16,6 +17,8 @@
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator.transforms.cse import (
CommonSubexpressionElimination as CSE,
_CanonicalizeNames,
_RestoreSymbolNames,
extract_subexpression,
)

@@ -360,3 +363,67 @@ def test_sym_ref_collection_from_lambda(opaque_fun):

actual = CSE.apply(testee, within_stencil=True)
assert actual == expected # no extraction should happen


@pytest.mark.parametrize(
"testees",
(
[im.lambda_("a")("a"), im.lambda_("b")("b")],
[
im.lambda_("a1")(im.lambda_("b1")(im.plus("a1", "b1"))),
im.lambda_("a1")(im.lambda_("b2")(im.plus("a1", "b2"))),
],
[
im.lambda_("a1")(im.lambda_("b1")(im.plus("a1", "b1"))),
im.lambda_("a2")(im.lambda_("b1")(im.plus("a2", "b1"))),
],
[
im.lambda_("a1")(im.lambda_("b1")(im.plus("a1", "b1"))),
im.lambda_("a2")(im.lambda_("b2")(im.plus("a2", "b2"))),
],
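        # shadowing: the inner lambda re-binds the name bound by the enclosing lambda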
[
im.lambda_("a1")(im.make_tuple("a1", im.lambda_("a1")("a1"))),
im.lambda_("a2")(im.make_tuple("a2", im.lambda_("a2")("a2"))),
],
),
)
def test_equivalent_are_equal_after_canonicalization(testees: Sequence[ir.Expr]):
transformed = {
_CanonicalizeNames.apply(testee, allow_external_symbols=True) for testee in testees
}

assert len(transformed) == 1


def test_canonicalize_symbol_names():
testee = im.let(("a", 1), ("b", 2))(
im.make_tuple(im.plus("a", "b"), im.let("c", 3)(im.plus("c", im.plus("a", "b"))))
)
expected = im.let(("_1", 1), ("_2", 2))(
im.make_tuple(im.plus("_1", "_2"), im.let("_0", 3)(im.plus("_0", im.plus("_1", "_2"))))
)

actual = _CanonicalizeNames.apply(testee, allow_external_symbols=True)
assert actual == expected

restored = _RestoreSymbolNames.apply(actual, allow_external_symbols=True)
assert restored == testee


# TODO: collisions are likely not occurring in CSE; describe in the pass why.
# def test_restore_symbol_names():
# testee = im.let(("a", 1), ("b", 2))(
# im.make_tuple(
# im.plus("external")(im.plus("a", "b")),
# im.let("c", 3)(im.plus("c", im.plus("a", "b")))
# )
# )
#
# _CanonicalizeNames.apply(testee, allow_external_symbols=True)


def test_cse_can_required():
testee = im.plus(im.let("x1", 1)(im.plus("x1", "y")), im.let("x2", 1)(im.plus("x2", "y")))
expected = im.let("_cs_1", im.let("x1", 1)(im.plus("x1", "y")))(im.plus("_cs_1", "_cs_1"))
actual = CSE.apply(testee, within_stencil=True)
assert actual == expected
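

# Illustrative sketch, not part of this PR: assuming subexpression collection compares
# expressions purely syntactically, disabling canonicalization should leave the two
# alpha-equivalent let-expressions un-merged, since `x1` and `x2` make them differ textually.
def test_cse_without_canonicalization_sketch():
    testee = im.plus(im.let("x1", 1)(im.plus("x1", "y")), im.let("x2", 1)(im.plus("x2", "y")))
    actual = CSE.apply(testee, within_stencil=True, canonicalize=False)
    assert actual == testee  # no extraction expected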