Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,9 @@ def _concat_where(
return cls_.from_array(result_array, domain=result_domain)


NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR
NdArrayField.register_builtin_func(
experimental.concat_where, _concat_where
) # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR


def _make_reduction(
Expand Down
25 changes: 17 additions & 8 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,33 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Tuple
from typing import Tuple, Union

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltInFunction


@BuiltInFunction
def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity:
raise NotImplementedError()


@WhereBuiltinFunction
def concat_where(
cond: common.Domain,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
/,
@WhereBuiltInFunction
def concat_where( # TODO: support variable argument numbers
first_arg: Union[common.Domain, tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple]],
true_field: common.Domain
| common.Field
| core_defs.ScalarT
| Tuple
| tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple],
false_field: common.Domain
| common.Field
| core_defs.ScalarT
| Tuple
| tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple],
# further_field: common.Domain | common.Field | core_defs.ScalarT | Tuple | tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple],
# *extra_pairs, # TODO: this doesn't seem to work
) -> common.Field | Tuple:
"""
Concatenates two field fields based on a 1D mask.
Expand Down
25 changes: 19 additions & 6 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,14 @@ def __gt_type__(self) -> ts.FunctionType:
FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple])


class WhereBuiltinFunction(
BuiltInFunction[_R, [CondT, FieldT, FieldT]], Generic[_R, CondT, FieldT]
):
def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R:
class WhereBuiltInFunction(BuiltInFunction):
def __call__(self, *args) -> _R:
if len(args) == 3 and not isinstance(args[0], tuple):
return self._call_orig(*args)
else:
return self._call_extended(*args)

def _call_orig(self, cond, true_field, false_field):
if isinstance(true_field, tuple) or isinstance(false_field, tuple):
if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)):
raise ValueError(
Expand All @@ -155,9 +159,18 @@ def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R:
raise ValueError(
"Tuple of different size not allowed."
) # TODO(havogt) find a strategy to unify parsing and embedded error messages
return tuple(self(cond, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R`
return tuple(self(cond, t, f) for t, f in zip(true_field, false_field))
return super().__call__(cond, true_field, false_field)

def _call_extended(self, *args):
pairs = args[:-1]
default = args[-1]
result = default
# TODO: are more checks needed here?
for cond, value in reversed(pairs):
result = self._call_orig(cond, value, result)
return result


@BuiltInFunction
def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field:
Expand Down Expand Up @@ -185,7 +198,7 @@ def broadcast(
return field # type: ignore[return-value] # see comment above


@WhereBuiltinFunction
@WhereBuiltInFunction
def where(
mask: common.Field,
true_field: common.Field | core_defs.ScalarT | Tuple,
Expand Down
56 changes: 42 additions & 14 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,34 +1033,62 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
)

def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
cond_type, true_branch_type, false_branch_type = (arg.type for arg in node.args)
if len(node.args) == 3 and not isinstance(node.args[0].type, ts.TupleType):
# classic case: concat_where(condition, true_branch, false_branch)
conditions = [node.args[0].type]
values = [node.args[1].type, node.args[2].type]
else:
# extended case: concat_where((cond1, val1), (cond2, val2), ..., default)
conditions = []
values = []

for pair_arg in node.args[:-1]:
pair_type = pair_arg.type
assert isinstance(pair_type, ts.TupleType) and len(pair_type.types) == 2, (
f"Each condition-value pair must be a 2-tuple, got {pair_type}"
)

cond_type, value_type = pair_type.types
assert isinstance(cond_type, ts.DomainType), (
f"Condition must be a DomainType, got {cond_type}"
)

conditions.append(cond_type)
values.append(value_type)

values.append(node.args[-1].type)

assert all(isinstance(c, ts.DomainType) for c in conditions)

assert isinstance(cond_type, ts.DomainType)
assert all(
isinstance(el, (ts.FieldType, ts.ScalarType))
for arg in (true_branch_type, false_branch_type)
for el in type_info.primitive_constituents(arg)
for value_type in values
for el in type_info.primitive_constituents(value_type)
)

@utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: ts.TupleType(types=list(elts)),
)
def deduce_return_type(
tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType
) -> ts.FieldType:
if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)):
def deduce_return_type(*value_types: ts.FieldType | ts.ScalarType) -> ts.FieldType:
dtypes = [type_info.extract_dtype(vt) for vt in value_types]
if not all(dtype == dtypes[0] for dtype in dtypes):
raise errors.DSLError(
node.location,
f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.",
f"All field arguments must be of same dtype, got {dtypes}.",
)
return_dims = promote_dims(
cond_type.dims, type_info.extract_dims(type_info.promote(tb, fb))
)
return_type = ts.FieldType(dims=return_dims, dtype=t_dtype)

all_dims = [c.dims for c in conditions]
all_dims += [type_info.extract_dims(vt) for vt in value_types]

return_dims = all_dims[0]
for dims in all_dims[1:]:
return_dims = promote_dims(return_dims, dims)

return_type = ts.FieldType(dims=return_dims, dtype=dtypes[0])
return return_type

return_type = deduce_return_type(true_branch_type, false_branch_type)
return_type = deduce_return_type(*values)

return foast.Call(
func=node.func,
Expand Down
19 changes: 16 additions & 3 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from gt4py.next.ffront.foast_passes import utils as foast_utils
from gt4py.next.ffront.stages import AOT_FOP, FOP
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import constant_folding
from gt4py.next.otf import toolchain, workflow
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt
Expand Down Expand Up @@ -419,8 +419,21 @@ def create_if(
return im.let(cond_symref_name, cond_)(result)

def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
domain, true_branch, false_branch = self.visit(node.args, **kwargs)
return im.concat_where(domain, true_branch, false_branch)
visited_args = [self.visit(arg, **kwargs) for arg in node.args]
if len(visited_args) == 3 and not isinstance(visited_args[0].type, ts.TupleType):
cond, true_branch, false_branch = visited_args
return im.concat_where(cond, true_branch, false_branch)
else:
*pair_args, default = visited_args

result = default
for pair in reversed(pair_args):
assert cpm.is_call_to(pair, "make_tuple")
cond, value = pair.args

result = im.concat_where(cond, value, result)

return result

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return im.call("broadcast")(*self.visit(node.args, **kwargs))
Expand Down
10 changes: 9 additions & 1 deletion src/gt4py/next/iterator/transforms/concat_where/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.iterator.transforms.concat_where.canonicalize_concat_where import (
canonicalize_concat_where,
)
from gt4py.next.iterator.transforms.concat_where.canonicalize_domain_argument import (
canonicalize_domain_argument,
)
Expand All @@ -15,4 +18,9 @@
)


__all__ = ["canonicalize_domain_argument", "expand_tuple_args", "transform_to_as_fieldop"]
__all__ = [
"canonicalize_concat_where",
"canonicalize_domain_argument",
"expand_tuple_args",
"transform_to_as_fieldop",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
from typing import TypeVar

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


PRG = TypeVar("PRG", bound=itir.Program | itir.Expr)


class _CanonicalizeConcatWhere(PreserveLocationVisitor, NodeTranslator):
PRESERVED_ANNEX_ATTRS = (
"type",
"domain",
)

@classmethod
def apply(cls, node: PRG) -> PRG:
return cls().visit(node)

def visit_FunCall(self, node: itir.FunCall) -> itir.Expr:
node = self.generic_visit(node)
# `concat_where((c1, v1), (c2, v2), (c3, v3),..., default)`
# -> `{concat_where(c1, v1, concat_where(c2, v2, concat_where(c3, v3, default))}`

if not cpm.is_call_to(node, "concat_where") or not cpm.is_call_to(
node.args[0], "make_tuple"
):
return node

*pairs, default = node.args

if len(pairs) == 0:
return node

result = default
for pair in reversed(pairs):
cond, value = pair.args if isinstance(pair, itir.FunCall) else pair
result = im.concat_where(cond, value, result)

return result


canonicalize_concat_where = _CanonicalizeConcatWhere.apply
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def apply(

def transform(self, node: itir.Node, **kwargs) -> Optional[itir.Node]:
# `concat_where(cond, {a, b}, {c, d})`
# -> `{concat_where(cond, a, c), concat_where(cond, a, c)}`
if not cpm.is_call_to(node, "concat_where") or not isinstance(
type_inference.reinfer(node.args[1]).type, ts.TupleType
# -> `{concat_where(cond, a, c), concat_where(cond, b, d)}`
if (
not cpm.is_call_to(node, "concat_where")
or cpm.is_call_to(node.args[0], "make_tuple")
or not isinstance(type_inference.reinfer(node.args[1]).type, ts.TupleType)
):
return None

Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def apply_common_transforms(
# test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere.
ir = inline_lifts.InlineLifts().visit(ir)

ir = concat_where.canonicalize_concat_where(ir)
ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
ir = dead_code_elimination.dead_code_elimination(
ir, collapse_tuple_uids=collapse_tuple_uids, offset_provider_type=offset_provider_type
Expand Down Expand Up @@ -176,6 +177,7 @@ def apply_fieldview_transforms(

ir = inline_fundefs.InlineFundefs().visit(ir)
ir = inline_fundefs.prune_unreferenced_fundefs(ir)
ir = concat_where.canonicalize_concat_where(ir)
# required for dead-code-elimination and `prune_empty_concat_where` pass
ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
ir = dead_code_elimination.dead_code_elimination(ir, offset_provider_type=offset_provider_type)
Expand Down
26 changes: 21 additions & 5 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,25 @@ def index(arg: ts.DimensionType) -> ts.FieldType:


@_register_builtin_type_synthesizer
def concat_where(
domain: ts.DomainType,
def concat_where(first_arg, *args): # TODO: fix annotations
"""
classic form: (domain, true_field, false_field)
extended form: (domain, (cond1, val1), ..., default)
"""
if isinstance(first_arg, (ts.DomainType, ts.DeferredType)) and len(args) == 2:
true_field, false_field = args
return _concat_where_type_classic(first_arg, true_field, false_field)
else:
*pairs, default = (first_arg, *args)
result = default
for pair in reversed(pairs):
cond, value = pair
result = _concat_where_type_classic(cond, value, result)
return result


def _concat_where_type_classic(
domain: ts.DomainType | ts.DeferredType,
true_field: ts.FieldType | ts.TupleType | ts.DeferredType,
false_field: ts.FieldType | ts.TupleType | ts.DeferredType,
) -> ts.FieldType | ts.TupleType | ts.DeferredType:
Expand All @@ -277,7 +294,7 @@ def concat_where(
result_collection_constructor=lambda _, elts: ts.TupleType(types=list(elts)),
)
def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType):
if any(isinstance(b, ts.DeferredType) for b in [tb, fb]):
if any(isinstance(b, ts.DeferredType) for b in [domain, tb, fb]):
return ts.DeferredType(constraint=ts.FieldType)

tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb])
Expand All @@ -290,8 +307,7 @@ def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.S
return_dims = common.promote_dims(
domain.dims, type_info.extract_dims(type_info.promote(tb, fb))
)
return_type = ts.FieldType(dims=return_dims, dtype=dtype)
return return_type
return ts.FieldType(dims=return_dims, dtype=dtype)

return deduce_return_type(true_field, false_field)

Expand Down
Loading
Loading