Skip to content

Commit a8b7d55

Browse files
havogttehrengruber
authored andcommitted
feat[next]: GTIR concat_where frontend (remove chained comparison) (GridTools#1998)
GTIR concat_where frontend parts extracted from GridTools#1713. Credits to @SF-N and @tehrengruber . This PR removes the support for chained comparisons like `field_a < field_b < field_c` which cannot be supported in embedded (same as `0 < Dim < 42`) because it translates to a boolean scalar comparison as it is evaluated as `(field_a < field_b) and (field_b < field_c)` with `and` instead of `&`. --------- Co-authored-by: tehrengruber <[email protected]>
1 parent ef51631 commit a8b7d55

File tree

19 files changed

+344
-77
lines changed

19 files changed

+344
-77
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ markers = [
302302
'uses_unstructured_shift: tests that use a unstructured connectivity',
303303
'uses_max_over: tests that use the max_over builtin',
304304
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
305+
'uses_concat_where: tests that use the concat_where builtin',
305306
'uses_program_metrics: tests that require backend support for program metrics',
306307
'checks_specific_error: tests that rely on the backend to produce a specific error message'
307308
]

src/gt4py/next/embedded/nd_array_field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ def _concat_where(
974974
return cls_.from_array(result_array, domain=result_domain)
975975

976976

977-
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type]
977+
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR
978978

979979

980980
def _make_reduction(

src/gt4py/next/ffront/ast_passes/unchain_compares.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def visit_Compare(self, node: ast.Compare) -> ast.Compare | ast.BinOp:
4646

4747
# the remainder of the chain -> right branch of the new tree
4848
# example: ``b > c > d``
49-
remaining_chain = copy.copy(node)
49+
remaining_chain = copy.deepcopy(node)
5050
remaining_chain.left = remaining_chain.comparators.pop(0)
5151
remaining_chain.ops.pop(0)
5252

src/gt4py/next/ffront/experimental.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi
2020

2121
@WhereBuiltinFunction
2222
def concat_where(
23-
mask: common.Field,
23+
cond: common.Domain,
2424
true_field: common.Field | core_defs.ScalarT | Tuple,
2525
false_field: common.Field | core_defs.ScalarT | Tuple,
2626
/,

src/gt4py/next/ffront/fbuiltins.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
6666
return ts.OffsetType
6767
elif t is core_defs.ScalarT:
6868
return ts.ScalarType
69+
elif t is common.Domain:
70+
return ts.DomainType
6971
elif t is type:
7072
return (
7173
ts.FunctionType
@@ -135,14 +137,14 @@ def __gt_type__(self) -> ts.FunctionType:
135137
)
136138

137139

138-
MaskT = TypeVar("MaskT", bound=common.Field)
140+
CondT = TypeVar("CondT", bound=Union[common.Field, common.Domain])
139141
FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple])
140142

141143

142144
class WhereBuiltinFunction(
143-
BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT]
145+
BuiltInFunction[_R, [CondT, FieldT, FieldT]], Generic[_R, CondT, FieldT]
144146
):
145-
def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R:
147+
def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R:
146148
if isinstance(true_field, tuple) or isinstance(false_field, tuple):
147149
if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)):
148150
raise ValueError(
@@ -153,8 +155,8 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R:
153155
raise ValueError(
154156
"Tuple of different size not allowed."
155157
) # TODO(havogt) find a strategy to unify parsing and embedded error messages
156-
return tuple(self(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R`
157-
return super().__call__(mask, true_field, false_field)
158+
return tuple(self(cond, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R`
159+
return super().__call__(cond, true_field, false_field)
158160

159161

160162
@BuiltInFunction

src/gt4py/next/ffront/foast_passes/type_deduction.py

Lines changed: 126 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
import gt4py.next.ffront.field_operator_ast as foast
1212
from gt4py.eve import NodeTranslator, NodeVisitor, traits
13-
from gt4py.next import errors
14-
from gt4py.next.common import DimensionKind
13+
from gt4py.next import errors, utils
14+
from gt4py.next.common import DimensionKind, promote_dims
1515
from gt4py.next.ffront import ( # noqa
1616
dialect_ast_enums,
1717
experimental,
@@ -20,6 +20,7 @@
2020
type_specifications as ts_ffront,
2121
)
2222
from gt4py.next.ffront.foast_passes.utils import compute_assign_indices
23+
from gt4py.next.iterator import builtins
2324
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
2425

2526

@@ -428,7 +429,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt:
428429
if not isinstance(new_node.condition.type, ts.ScalarType):
429430
raise errors.DSLError(
430431
node.location,
431-
"Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.",
432+
f"Condition for 'if' must be scalar, got '{new_node.condition.type}' instead.",
432433
)
433434

434435
if new_node.condition.type.kind != ts.ScalarKind.BOOL:
@@ -566,16 +567,10 @@ def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> foast.Compare:
566567
op=node.op, left=new_left, right=new_right, location=node.location, type=new_type
567568
)
568569

569-
def _deduce_compare_type(
570+
def _deduce_arithmetic_compare_type(
570571
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
571572
) -> Optional[ts.TypeSpec]:
572-
# check both types compatible
573-
for arg in (left, right):
574-
if not type_info.is_arithmetic(arg.type):
575-
raise errors.DSLError(
576-
arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'."
577-
)
578-
573+
# e.g. `1 < 2`
579574
self._check_operand_dtypes_match(node, left=left, right=right)
580575

581576
try:
@@ -592,6 +587,51 @@ def _deduce_compare_type(
592587
f" in call to '{node.op}'.",
593588
) from ex
594589

590+
def _deduce_dimension_compare_type(
591+
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
592+
) -> Optional[ts.TypeSpec]:
593+
# e.g. `IDim > 1`
594+
index_type = ts.ScalarType(
595+
kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
596+
)
597+
598+
def error_msg(left: ts.TypeSpec, right: ts.TypeSpec) -> str:
599+
return f"Dimension comparison needs to be between a 'Dimension' and index of type '{index_type}', got '{left}' and '{right}'."
600+
601+
if isinstance(left.type, ts.DimensionType):
602+
if not right.type == index_type:
603+
raise errors.DSLError(
604+
right.location,
605+
error_msg(left.type, right.type),
606+
)
607+
return ts.DomainType(dims=[left.type.dim])
608+
elif isinstance(right.type, ts.DimensionType):
609+
if not left.type == index_type:
610+
raise errors.DSLError(
611+
left.location,
612+
error_msg(left.type, right.type),
613+
)
614+
return ts.DomainType(dims=[right.type.dim])
615+
else:
616+
raise AssertionError()
617+
618+
def _deduce_compare_type(
619+
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
620+
) -> Optional[ts.TypeSpec]:
621+
# e.g. `1 < 1`
622+
if all(type_info.is_arithmetic(arg) for arg in (left.type, right.type)):
623+
return self._deduce_arithmetic_compare_type(node, left=left, right=right)
624+
# e.g. `IDim > 1`
625+
if any(isinstance(arg, ts.DimensionType) for arg in (left.type, right.type)):
626+
return self._deduce_dimension_compare_type(node, left=left, right=right)
627+
628+
raise errors.DSLError(
629+
left.location,
630+
"Comparison operators can only be used between arithmetic types "
631+
"(scalars, fields) or between a dimension and an index type "
632+
"({builtins.INTEGER_INDEX_BUILTIN}).",
633+
)
634+
595635
def _deduce_binop_type(
596636
self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
597637
) -> Optional[ts.TypeSpec]:
@@ -612,37 +652,48 @@ def _deduce_binop_type(
612652
dialect_ast_enums.BinaryOperator.BIT_OR,
613653
dialect_ast_enums.BinaryOperator.BIT_XOR,
614654
}
615-
is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic
616-
617-
# check both types compatible
618-
for arg in (left, right):
619-
if not is_compatible(arg.type):
620-
raise errors.DSLError(
621-
arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'."
622-
)
623-
624-
left_type = cast(ts.FieldType | ts.ScalarType, left.type)
625-
right_type = cast(ts.FieldType | ts.ScalarType, right.type)
626655

627-
if node.op == dialect_ast_enums.BinaryOperator.POW:
628-
return left_type
656+
err_msg = f"Unsupported operand type(s) for {node.op}: '{left.type}' and '{right.type}'."
629657

630-
if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
631-
right_type
658+
if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance(
659+
right.type, (ts.ScalarType, ts.FieldType)
632660
):
633-
raise errors.DSLError(
634-
arg.location,
635-
f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.",
661+
is_compatible = (
662+
type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic
636663
)
664+
for arg in (left, right):
665+
if not is_compatible(arg.type):
666+
raise errors.DSLError(arg.location, err_msg)
637667

638-
try:
639-
return type_info.promote(left_type, right_type)
640-
except ValueError as ex:
641-
raise errors.DSLError(
642-
node.location,
643-
f"Could not promote '{left_type}' and '{right_type}' to common type"
644-
f" in call to '{node.op}'.",
645-
) from ex
668+
if node.op == dialect_ast_enums.BinaryOperator.POW:
669+
return left.type
670+
671+
if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
672+
right.type
673+
):
674+
raise errors.DSLError(
675+
arg.location,
676+
f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.",
677+
)
678+
679+
try:
680+
return type_info.promote(left.type, right.type)
681+
except ValueError as ex:
682+
raise errors.DSLError(
683+
node.location,
684+
f"Could not promote '{left.type}' and '{right.type}' to common type"
685+
f" in call to '{node.op}'.",
686+
) from ex
687+
elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType):
688+
if node.op not in logical_ops:
689+
raise errors.DSLError(
690+
node.location,
691+
f"{err_msg} Operator "
692+
f"must be one of {', '.join((str(op) for op in logical_ops))}.",
693+
)
694+
return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims))
695+
else:
696+
raise errors.DSLError(node.location, err_msg)
646697

647698
def _check_operand_dtypes_match(
648699
self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr
@@ -908,6 +959,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
908959
)
909960

910961
try:
962+
# TODO(tehrengruber): the construct_tuple_type function doesn't look correct
911963
if isinstance(true_branch_type, ts.TupleType) and isinstance(
912964
false_branch_type, ts.TupleType
913965
):
@@ -943,7 +995,43 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
943995
location=node.location,
944996
)
945997

946-
_visit_concat_where = _visit_where
998+
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
999+
cond_type, true_branch_type, false_branch_type = (arg.type for arg in node.args)
1000+
1001+
assert isinstance(cond_type, ts.DomainType)
1002+
assert all(
1003+
isinstance(el, (ts.FieldType, ts.ScalarType))
1004+
for arg in (true_branch_type, false_branch_type)
1005+
for el in type_info.primitive_constituents(arg)
1006+
)
1007+
1008+
@utils.tree_map(
1009+
collection_type=ts.TupleType,
1010+
result_collection_constructor=lambda el: ts.TupleType(types=list(el)),
1011+
)
1012+
def deduce_return_type(
1013+
tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType
1014+
) -> ts.FieldType:
1015+
if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)):
1016+
raise errors.DSLError(
1017+
node.location,
1018+
f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.",
1019+
)
1020+
return_dims = promote_dims(
1021+
cond_type.dims, type_info.extract_dims(type_info.promote(tb, fb))
1022+
)
1023+
return_type = ts.FieldType(dims=return_dims, dtype=t_dtype)
1024+
return return_type
1025+
1026+
return_type = deduce_return_type(true_branch_type, false_branch_type)
1027+
1028+
return foast.Call(
1029+
func=node.func,
1030+
args=node.args,
1031+
kwargs=node.kwargs,
1032+
type=return_type,
1033+
location=node.location,
1034+
)
9471035

9481036
def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call:
9491037
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)

src/gt4py/next/ffront/func_to_foast.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import ast
1212
import builtins
13+
import textwrap
1314
import typing
1415
from typing import Any, Callable, Iterable, Mapping, Type
1516

@@ -144,12 +145,11 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]):
144145
"""
145146

146147
@classmethod
147-
def _preprocess_definition_ast(cls, definition_ast: ast.AST) -> ast.AST:
148-
sta = StringifyAnnotationsPass.apply(definition_ast)
149-
ssa = SingleStaticAssignPass.apply(sta)
150-
sat = SingleAssignTargetPass.apply(ssa)
151-
ucc = UnchainComparesPass.apply(sat)
152-
return ucc
148+
def _preprocess_definition_ast(cls, ast: ast.AST) -> ast.AST:
149+
ast = StringifyAnnotationsPass.apply(ast)
150+
ast = SingleStaticAssignPass.apply(ast)
151+
ast = SingleAssignTargetPass.apply(ast)
152+
return ast
153153

154154
@classmethod
155155
def _postprocess_dialect_ast(
@@ -474,10 +474,20 @@ def _visit_stmts(
474474

475475
def visit_Compare(self, node: ast.Compare, **kwargs: Any) -> foast.Compare:
476476
loc = self.get_location(node)
477+
477478
if len(node.ops) != 1 or len(node.comparators) != 1:
478-
# Remove comparison chains in a preprocessing pass
479-
# TODO: maybe add a note to the error about preprocessing passes?
480-
raise errors.UnsupportedPythonFeatureError(loc, "comparison chains")
479+
refactored = UnchainComparesPass.apply(node)
480+
raise errors.DSLError(
481+
loc,
482+
textwrap.dedent(
483+
f"""
484+
Comparison chains are not allowed. Please replace
485+
{ast.unparse(node)}
486+
by
487+
{ast.unparse(refactored)}
488+
""",
489+
),
490+
)
481491
return foast.Compare(
482492
op=self.visit(node.ops[0]),
483493
left=self.visit(node.left),

src/gt4py/next/iterator/type_system/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType:
455455

456456
def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType:
457457
domain = self.visit(node.domain, ctx=ctx)
458-
assert isinstance(domain, it_ts.DomainType)
458+
assert isinstance(domain, ts.DomainType)
459459
assert domain.dims != "unknown"
460460
assert node.dtype
461461
return type_info.apply_to_primitive_constituents(

src/gt4py/next/iterator/type_system/type_specifications.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ class NamedRangeType(ts.TypeSpec):
1616
dim: common.Dimension
1717

1818

19-
class DomainType(ts.DataType):
20-
dims: list[common.Dimension]
21-
22-
2319
class OffsetLiteralType(ts.TypeSpec):
2420
value: ts.ScalarType | str
2521

0 commit comments

Comments
 (0)