Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Iterable
from typing import Any, Generic, List, TypeAlias, TypeGuard, TypeVar
from typing import Any, Callable, Generic, List, TypeAlias, TypeGuard, TypeVar

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import FunCall as GTFNIRFunCall
from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import (
Expr as GTFNIRExpr,
SymRef as GTFNIRSymRef,
)


_Fun = TypeVar("_Fun", bound=itir.Expr)
Expand Down Expand Up @@ -44,8 +49,8 @@ def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[_FunCallToSymRe
assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument
if isinstance(fun, str):
return (
isinstance(node, itir.FunCall)
and isinstance(node.fun, itir.SymRef)
isinstance(node, itir.FunCall | GTFNIRFunCall)
and isinstance(node.fun, itir.SymRef | GTFNIRSymRef)
and node.fun.id == fun
)
else:
Expand Down Expand Up @@ -135,3 +140,14 @@ def is_identity_as_fieldop(node: itir.Expr) -> TypeGuard[_FunCallToFunCallToRef]
):
return True
return False


def is_tuple_expr_of(
pred: Callable[[Any], bool],
expr: itir.Expr | GTFNIRExpr,
) -> bool:
if is_call_to(expr, "make_tuple"):
return all(is_tuple_expr_of(pred, arg) for arg in expr.args)
if is_call_to(expr, "tuple_get"):
return is_tuple_expr_of(pred, expr.args[1])
return pred(expr)
132 changes: 132 additions & 0 deletions src/gt4py/next/iterator/transforms/check_inout_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next import common
from gt4py.next.common import OffsetProvider
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import trace_shifts


@dataclasses.dataclass(frozen=True)
class CheckInOutField(PreserveLocationVisitor, NodeTranslator):
"""
Checks within a SetAt if any fields which are written to are also read with an offset and raises a ValueError in this case.

Example:
>>> from gt4py.next.iterator.transforms import infer_domain
>>> from gt4py.next.type_system import type_specifications as ts
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
>>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL)
>>> i_field_type = ts.FieldType(dims=[IDim], dtype=float_type)
>>> offset_provider = {"IOff": IDim}
>>> cartesian_domain = im.call("cartesian_domain")(
... im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 5)
... )
>>> ir = itir.Program(
... id="test",
... function_definitions=[],
... params=[im.sym("inout", i_field_type), im.sym("in", i_field_type)],
... declarations=[],
... body=[
... itir.SetAt(
... expr=im.as_fieldop(im.lambda_("x")(im.deref(im.shift("IOff", 1)("x"))))(
... im.ref("inout")
... ),
... domain=cartesian_domain,
... target=im.ref("inout"),
... ),
... ],
... )
>>> CheckInOutField.apply(ir, offset_provider=offset_provider)
Traceback (most recent call last):
...
ValueError: The target inout is also read with an offset.
"""

@classmethod
def apply(
cls,
program: itir.Program,
offset_provider: common.OffsetProvider | common.OffsetProviderType,
):
return cls().visit(program, offset_provider=offset_provider)

def visit_SetAt(self, node: itir.SetAt, **kwargs) -> itir.SetAt:
offset_provider = kwargs["offset_provider"]

def extract_subexprs(expr):
"""Return a list of all subexpressions in expr.args, including expr itself."""
subexprs = [expr]
if isinstance(expr, itir.FunCall):
for arg in expr.args:
subexprs.extend(extract_subexprs(arg))
return subexprs

def visit_nested_make_tuple_tuple_get(expr):
"""Recursively visit make_tuple and tuple_get expr and check all as_fieldop subexpressions."""
if cpm.is_applied_as_fieldop(expr):
check_expr(expr.fun, expr.args, offset_provider)
elif cpm.is_call_to(expr, ("make_tuple", "tuple_get")):
for arg in expr.args:
visit_nested_make_tuple_tuple_get(arg)

def filter_shifted_args(
shifts: list[set[tuple[itir.OffsetLiteral, ...]]],
args: list[itir.Expr],
offset_provider: OffsetProvider,
) -> list[itir.Expr]:
"""
Filters out trivial shifts (empty or all horizontal/vertical with zero offset)
and returns filtered shifts and corresponding args.
"""
filtered = [
arg
for shift, arg in zip(shifts, args)
if shift not in (set(), {()})
and any(
offset_provider[off.value].kind # type: ignore[index] # mypy not smart enough
not in {common.DimensionKind.HORIZONTAL, common.DimensionKind.VERTICAL}
or val.value != 0
for off, val in (
(pair for pair in shift if len(pair) == 2) # set case: skip ()
if isinstance(shift, set)
else zip(shift[0::2], shift[1::2]) # tuple/list case
)
)
]
return filtered if filtered else []

def check_expr(
fun: itir.FunCall,
args: list[itir.Expr],
offset_provider: OffsetProvider,
) -> None:
shifts = trace_shifts.trace_stencil(fun.args[0], num_args=len(args))

shifted_args = filter_shifted_args(shifts, args, offset_provider)
target_subexprs = extract_subexprs(node.target)
for arg in shifted_args:
arg_subexprs = extract_subexprs(arg)
for subexpr in arg_subexprs:
if subexpr in target_subexprs:
raise ValueError(f"The target {node.target} is also read with an offset.")
if not cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.SymRef), arg):
raise ValueError(
f"Unexpected as_fieldop argument {arg}. Expected `make_tuple`, `tuple_get` or `SymRef`. Please run temporary extraction first."
)

if cpm.is_applied_as_fieldop(node.expr):
check_expr(node.expr.fun, node.expr.args, offset_provider)
else: # Account for nested im.make_tuple and im.tuple_get
visit_nested_make_tuple_tuple_get(node.expr)
return node
12 changes: 2 additions & 10 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ def _merge_arguments(
return new_args


def _is_tuple_expr_of_literals(expr: itir.Expr):
if cpm.is_call_to(expr, "make_tuple"):
return all(_is_tuple_expr_of_literals(arg) for arg in expr.args)
if cpm.is_call_to(expr, "tuple_get"):
return _is_tuple_expr_of_literals(expr.args[1])
return isinstance(expr, itir.Literal)


def _inline_as_fieldop_arg(
arg: itir.Expr, *, uids: eve_utils.UIDGenerator
) -> tuple[itir.Expr, dict[str, itir.Expr]]:
Expand Down Expand Up @@ -142,7 +134,7 @@ def fuse_as_fieldop(
# transform scalar `if` into per-grid-point `if`
# TODO(tehrengruber): revisit if we want to inline if_
arg = im.op_as_fieldop("if_")(*arg.args)
elif _is_tuple_expr_of_literals(arg):
elif cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.Literal), arg):
arg = im.op_as_fieldop(im.lambda_()(arg))()
else:
raise NotImplementedError()
Expand Down Expand Up @@ -189,7 +181,7 @@ def fuse_as_fieldop(


def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, ...]]) -> bool:
if _is_tuple_expr_of_literals(node):
if cpm.is_tuple_expr_of(lambda e: isinstance(e, itir.Literal), node):
return True

if (
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import (
check_inout_field,
concat_where,
dead_code_elimination,
fuse_as_fieldop,
Expand Down Expand Up @@ -92,6 +93,8 @@ def apply_common_transforms(
offset_provider=offset_provider,
symbolic_domain_sizes=symbolic_domain_sizes,
)
ir = check_inout_field.CheckInOutField.apply(ir, offset_provider=offset_provider)

ir = remove_broadcast.RemoveBroadcast.apply(ir)

ir = concat_where.transform_to_as_fieldop(ir)
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class CompiledProgramsPool:
definition_stage: ffront_stages.ProgramDefinition
program_type: ts_ffront.ProgramType
static_params: Sequence[str] | None = None # not ordered
static_domain_sizes: bool = False

_compiled_programs: eve_utils.CustomMapping = dataclasses.field(
default_factory=lambda: eve_utils.CustomMapping(_hash_compiled_program_unsafe),
Expand Down
32 changes: 7 additions & 25 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from __future__ import annotations

from typing import Callable, ClassVar, Optional, Union
from typing import ClassVar, Optional, Union

from gt4py.eve import Coerced, SymbolName, datamodels
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
from gt4py.next import common
from gt4py.next.iterator import builtins
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition
from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef

Expand Down Expand Up @@ -97,25 +98,6 @@ class Backend(Node):
domain: Union[SymRef, CartesianDomain, UnstructuredDomain]


def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool:
if (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "tuple_get"
and len(expr.args) == 2
and _is_tuple_expr_of(pred, expr.args[1])
):
return True
if (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
and expr.fun.id == "make_tuple"
and all(_is_tuple_expr_of(pred, arg) for arg in expr.args)
):
return True
return pred(expr)


class SidComposite(Expr):
values: list[Expr]

Expand All @@ -125,7 +107,7 @@ def _values_validator(
) -> None:
if not all(
isinstance(el, (SidFromScalar, SidComposite))
or _is_tuple_expr_of(
or cpm.is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, Literal))
or (isinstance(expr, FunCall) and expr.fun == SymRef(id="index")),
el,
Expand All @@ -139,9 +121,9 @@ def _values_validator(

def _might_be_scalar_expr(expr: Expr) -> bool:
if isinstance(expr, BinaryExpr):
return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs))
return all(cpm.is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs))
if isinstance(expr, UnaryExpr):
return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr)
return cpm.is_tuple_expr_of(_might_be_scalar_expr, expr.expr)
if (
isinstance(expr, FunCall)
and isinstance(expr.fun, SymRef)
Expand All @@ -150,7 +132,7 @@ def _might_be_scalar_expr(expr: Expr) -> bool:
return all(_might_be_scalar_expr(arg) for arg in expr.args)
if isinstance(expr, CastExpr):
return _might_be_scalar_expr(expr.obj_expr)
if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr):
if cpm.is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr):
return True
return False

Expand Down Expand Up @@ -183,7 +165,7 @@ def _arg_validator(
self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr]
) -> None:
for inp in inputs:
if not _is_tuple_expr_of(
if not cpm.is_tuple_expr_of(
lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar))
or (
isinstance(expr, FunCall)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,6 @@
_horizontal_dimension = "gtfn::unstructured::dim::horizontal"


def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool:
if (
isinstance(expr, itir.FunCall)
and isinstance(expr.fun, itir.SymRef)
and expr.fun.id == "tuple_get"
and len(expr.args) == 2
and _is_tuple_of_ref_or_literal(expr.args[1])
):
return True
if (
isinstance(expr, itir.FunCall)
and isinstance(expr.fun, itir.SymRef)
and expr.fun.id == "make_tuple"
and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args)
):
return True
if isinstance(expr, (itir.SymRef, itir.Literal)):
return True
return False


def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]:
result = set()
for node in nodes:
Expand Down Expand Up @@ -587,13 +566,12 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt:
def visit_SetAt(
self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any
) -> Union[StencilExecution, ScanExecution]:
if _is_tuple_of_ref_or_literal(node.expr):
if cpm.is_tuple_expr_of(lambda e: isinstance(e, (itir.SymRef, itir.Literal)), node.expr):
node.expr = im.as_fieldop("deref", node.domain)(node.expr)

itir_projector, extracted_expr = ir_utils_misc.extract_projector(node.expr)
projector = self.visit(itir_projector, **kwargs) if itir_projector is not None else None
node.expr = extracted_expr

assert cpm.is_applied_as_fieldop(node.expr), node.expr
stencil = node.expr.fun.args[0]
domain = node.domain
Expand Down
Loading
Loading