Skip to content

Commit eb19848

Browse files
SF-Nhavogttehrengruberedopao
authored
feat[next]: Integration of concat_where (#1713)
In previous commits the following pieces where already added: - #1998 - #2065 In this PR, `concat_where` tests are added and transformations are integrated. Currently, the feature is disabled for embedded (#2127) and DaCe (PR will be opened once this is merged). Details: - constant folding for `InfinityLiteral`s is added - various tuning to passes for `index`, `concat_where` builtins and `InfinityLiteral` - cases: field constructors now accept `Domain`s (instead of `dict[Dimension, int]) --------- Co-authored-by: Hannes Vogt <[email protected]> Co-authored-by: Till Ehrengruber <[email protected]> Co-authored-by: Edoardo Paone <[email protected]>
1 parent 416e713 commit eb19848

File tree

20 files changed

+680
-96
lines changed

20 files changed

+680
-96
lines changed

src/gt4py/next/ffront/foast_to_gtir.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,9 @@ def create_if(
406406

407407
return im.let(cond_symref_name, cond_)(result)
408408

409-
_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where
409+
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
410+
domain, true_branch, false_branch = self.visit(node.args, **kwargs)
411+
return im.concat_where(domain, true_branch, false_branch)
410412

411413
def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
412414
return im.call("broadcast")(*self.visit(node.args, **kwargs))
@@ -488,7 +490,7 @@ def _map(
488490
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
489491
"""
490492
if all(
491-
isinstance(t, ts.ScalarType)
493+
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
492494
for arg_type in original_arg_types
493495
for t in type_info.primitive_constituents(arg_type)
494496
):

src/gt4py/next/iterator/builtins.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,11 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
407407
raise BackendNotSelectedError()
408408

409409

410+
@builtin_dispatch
411+
def concat_where(*args):
412+
raise BackendNotSelectedError()
413+
414+
410415
UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"}
411416
UNARY_LOGICAL_BUILTINS = {"not_"}
412417
UNARY_MATH_FP_BUILTINS = {
@@ -494,6 +499,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
494499
"scan",
495500
"tuple_get",
496501
"unstructured_domain",
502+
"concat_where",
497503
*ARITHMETIC_BUILTINS,
498504
*TYPE_BUILTINS,
499505
}

src/gt4py/next/iterator/embedded.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,11 @@ def index(axis: common.Dimension) -> common.Field:
18021802
return IndexField(axis)
18031803

18041804

1805+
@builtins.concat_where.register(EMBEDDED)
1806+
def concat_where(*args):
1807+
raise NotImplementedError("To be implemented in frontend embedded.")
1808+
1809+
18051810
def closure(
18061811
domain_: runtime.CartesianDomain | runtime.UnstructuredDomain,
18071812
sten: Callable[..., Any],

src/gt4py/next/iterator/transforms/collapse_tuple.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _with_altered_iterator_position_dims(
5050
)
5151

5252

53-
def _is_trivial_make_tuple_call(node: ir.Expr):
53+
def _is_trivial_make_tuple_call(node: itir.Expr):
5454
"""Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof."""
5555
if not cpm.is_call_to(node, "make_tuple"):
5656
return False
@@ -307,9 +307,10 @@ def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optiona
307307
self.fp_transform(im.tuple_get(idx.value, expr.fun.expr), **kwargs)
308308
)
309309
)(*expr.args)
310-
elif cpm.is_call_to(expr, "if_"):
310+
elif cpm.is_call_to(expr, ("if_", "concat_where")):
311+
fun = expr.fun
311312
cond, true_branch, false_branch = expr.args
312-
return im.if_(
313+
return im.call(fun)(
313314
cond,
314315
self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs),
315316
self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs),

src/gt4py/next/iterator/transforms/concat_where/transform_to_as_fieldop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def _in(pos: itir.Expr, domain: itir.Expr) -> itir.Expr:
2525
"""
2626
Given a position and a domain return an expression that evaluates to `True` if the position is inside the domain.
2727
28-
`in_({i, j, k}, u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩`
29-
-> `i0 <= i < i1 & j0 <= j < j1 & k0 <= k < k1`
28+
pos = `{i, j, k}`, domain = `u⟨ Iₕ: [i0, i1[, Iₕ: [j0, j1[, Iₕ: [k0, k1[ ⟩`
29+
-> `((i0 <= i) & (i < i1)) & ((j0 <= j) & (j < j1)) & ((k0 <= k)l & (k < k1))`
3030
"""
3131
ret = [
3232
im.and_(

src/gt4py/next/iterator/transforms/constant_folding.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class Transformation(enum.Flag):
9898
# `if_(True, true_branch, false_branch)` -> `true_branch`
9999
FOLD_IF = enum.auto()
100100

101+
FOLD_INFINITY_ARITHMETIC = enum.auto()
102+
101103
@classmethod
102104
def all(self) -> ConstantFolding.Transformation:
103105
return functools.reduce(operator.or_, self.__members__.values())
@@ -239,3 +241,60 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
239241
assert node.args[0].value == "False"
240242
return node.args[2]
241243
return None
244+
245+
def transform_fold_infinity_arithmetic(self, node: ir.FunCall) -> Optional[ir.Node]:
246+
if cpm.is_call_to(node, "plus"):
247+
# `a + +/-inf` -> `+/-inf`
248+
a, b = node.args
249+
assert not (isinstance(a, ir.InfinityLiteral) and isinstance(b, ir.InfinityLiteral))
250+
for arg in a, b:
251+
if isinstance(arg, ir.InfinityLiteral):
252+
return arg
253+
254+
if cpm.is_call_to(node, "minimum"):
255+
if ir.InfinityLiteral.NEGATIVE in node.args:
256+
# `minimum(-inf, a)` -> `-inf`
257+
return ir.InfinityLiteral.NEGATIVE
258+
if ir.InfinityLiteral.POSITIVE in node.args:
259+
# `minimum(inf, a)` -> `a`
260+
a, b = node.args
261+
return b if a == ir.InfinityLiteral.POSITIVE else a
262+
263+
if cpm.is_call_to(node, "maximum"):
264+
if ir.InfinityLiteral.POSITIVE in node.args:
265+
# `maximum(inf, a)` -> `inf`
266+
return ir.InfinityLiteral.POSITIVE
267+
if ir.InfinityLiteral.NEGATIVE in node.args:
268+
# `maximum(-inf, a)` -> `a`
269+
a, b = node.args
270+
return b if a == ir.InfinityLiteral.NEGATIVE else a
271+
272+
if cpm.is_call_to(node, ("less", "less_equal")):
273+
a, b = node.args
274+
# we don't handle `inf < inf` or `-inf < -inf`.args
275+
assert a != b or not isinstance(a, ir.InfinityLiteral)
276+
277+
# `-inf < v` -> `True`
278+
# `v < inf` -> `True`
279+
if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE:
280+
return im.literal_from_value(True)
281+
# `inf < v` -> `False`
282+
# `v < -inf ` -> `False`
283+
if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE:
284+
return im.literal_from_value(False)
285+
286+
if cpm.is_call_to(node, ("greater", "greater_equal")):
287+
a, b = node.args
288+
# we don't handle `inf > inf` or `-inf > -inf`.args
289+
assert a != b or not isinstance(a, ir.InfinityLiteral)
290+
291+
# `inf > v` -> `True`
292+
# `v > -inf ` -> `True`
293+
if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE:
294+
return im.literal_from_value(True)
295+
# `-inf > v` -> `False`
296+
# `v > inf` -> `False`
297+
if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE:
298+
return im.literal_from_value(False)
299+
300+
return None

src/gt4py/next/iterator/transforms/cse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,19 @@ def _is_collectable_expr(node: itir.Node) -> bool:
8787
# backend (single pass eager depth first visit approach)
8888
# do also not collect lifts or applied lifts as they become invisible to the lift inliner
8989
# otherwise
90-
if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node):
90+
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
91+
# instead of an as_fieldop
92+
if cpm.is_call_to(
93+
node, ("lift", "shift", "reduce", "map_", "index")
94+
) or cpm.is_applied_lift(node):
9195
return False
9296
return True
97+
# do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement
98+
# instead of an as_fieldop
99+
if cpm.is_call_to(node, "make_tuple") and all(
100+
cpm.is_call_to(arg, "index") for arg in node.args
101+
):
102+
return False
93103
elif isinstance(node, itir.Lambda):
94104
return True
95105

src/gt4py/next/iterator/transforms/fuse_as_fieldop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,4 @@ def visit(self, node, **kwargs):
450450

451451
node = super().visit(node, **kwargs)
452452

453-
if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"):
454-
node.annex.domain = node.annex.domain
455-
456453
return node

src/gt4py/next/iterator/transforms/global_tmps.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,16 @@ def create_global_tmps(
329329
This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its
330330
arguments into temporaries.
331331
"""
332-
offset_provider_type = common.offset_provider_to_type(offset_provider)
332+
# TODO(tehrengruber): document why to keep existing domains and add test
333333
program = infer_domain.infer_program(
334-
program, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes
334+
program,
335+
offset_provider=offset_provider,
336+
symbolic_domain_sizes=symbolic_domain_sizes,
337+
keep_existing_domains=True,
338+
)
339+
program = type_inference.infer(
340+
program, offset_provider_type=common.offset_provider_to_type(offset_provider)
335341
)
336-
program = type_inference.infer(program, offset_provider_type=offset_provider_type)
337342

338343
if not uids:
339344
uids = eve_utils.UIDGenerator(prefix="__tmp")

src/gt4py/next/iterator/transforms/pass_manager.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
from gt4py.next import common
1313
from gt4py.next.iterator import ir as itir
1414
from gt4py.next.iterator.transforms import (
15+
concat_where,
1516
dead_code_elimination,
1617
fuse_as_fieldop,
1718
global_tmps,
1819
infer_domain,
20+
infer_domain_ops,
1921
inline_dynamic_shifts,
2022
inline_fundefs,
2123
inline_lifts,
@@ -81,13 +83,19 @@ def apply_common_transforms(
8183
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
8284
ir
8385
) # domain inference does not support dynamic offsets yet
86+
ir = infer_domain_ops.InferDomainOps.apply(ir)
87+
ir = concat_where.canonicalize_domain_argument(ir)
88+
89+
ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
8490
ir = infer_domain.infer_program(
8591
ir,
8692
offset_provider=offset_provider,
8793
symbolic_domain_sizes=symbolic_domain_sizes,
8894
)
8995
ir = remove_broadcast.RemoveBroadcast.apply(ir)
9096

97+
ir = concat_where.transform_to_as_fieldop(ir)
98+
9199
for _ in range(10):
92100
inlined = ir
93101

@@ -183,6 +191,11 @@ def apply_fieldview_transforms(
183191
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
184192
ir
185193
) # domain inference does not support dynamic offsets yet
194+
195+
ir = infer_domain_ops.InferDomainOps.apply(ir)
196+
ir = concat_where.canonicalize_domain_argument(ir)
197+
ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program
198+
186199
ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
187200
ir = remove_broadcast.RemoveBroadcast.apply(ir)
188201
return ir

0 commit comments

Comments
 (0)