diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index f4aee67332..bdcdf6d991 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -992,7 +992,9 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR +NdArrayField.register_builtin_func( + experimental.concat_where, _concat_where +) # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR def _make_reduction( diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index b30b25b309..7de0f8b1cb 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -6,11 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Tuple +from typing import Tuple, Union from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction +from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltInFunction @BuiltInFunction @@ -18,12 +18,21 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi raise NotImplementedError() -@WhereBuiltinFunction -def concat_where( - cond: common.Domain, - true_field: common.Field | core_defs.ScalarT | Tuple, - false_field: common.Field | core_defs.ScalarT | Tuple, - /, +@WhereBuiltInFunction +def concat_where( # TODO: support variable argument numbers + first_arg: Union[common.Domain, tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple]], + true_field: common.Domain + | common.Field + | core_defs.ScalarT + | Tuple + | tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple], + false_field: common.Domain + | common.Field + | core_defs.ScalarT + | Tuple + | tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple], + # further_field: common.Domain | common.Field | core_defs.ScalarT | Tuple | tuple[common.Domain, common.Field | core_defs.ScalarT | Tuple], + # *extra_pairs, # TODO: this doesn't seem to work ) -> common.Field | Tuple: """ Concatenates two field fields based on a 1D mask. diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b1611209e8..30b790161b 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -141,10 +141,14 @@ def __gt_type__(self) -> ts.FunctionType: FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) -class WhereBuiltinFunction( - BuiltInFunction[_R, [CondT, FieldT, FieldT]], Generic[_R, CondT, FieldT] -): - def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R: +class WhereBuiltInFunction(BuiltInFunction): + def __call__(self, *args) -> _R: + if len(args) == 3 and not isinstance(args[0], tuple): + return self._call_orig(*args) + else: + return self._call_extended(*args) + + def _call_orig(self, cond, true_field, false_field): if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( @@ -155,9 +159,18 @@ def __call__(self, cond: CondT, true_field: FieldT, false_field: FieldT) -> _R: raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(self(cond, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return tuple(self(cond, t, f) for t, f in zip(true_field, false_field)) return super().__call__(cond, true_field, false_field) + def _call_extended(self, *args): + pairs = args[:-1] + default = args[-1] + result = default + # TODO: are more checks needed here? + for cond, value in reversed(pairs): + result = self._call_orig(cond, value, result) + return result + @BuiltInFunction def neighbor_sum(field: common.Field, /, axis: common.Dimension) -> common.Field: @@ -185,7 +198,7 @@ def broadcast( return field # type: ignore[return-value] # see comment above -@WhereBuiltinFunction +@WhereBuiltInFunction def where( mask: common.Field, true_field: common.Field | core_defs.ScalarT | Tuple, diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 9761968acb..1f525c1355 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -1033,34 +1033,62 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: ) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: - cond_type, true_branch_type, false_branch_type = (arg.type for arg in node.args) + if len(node.args) == 3 and not isinstance(node.args[0].type, ts.TupleType): + # classic case: concat_where(condition, true_branch, false_branch) + conditions = [node.args[0].type] + values = [node.args[1].type, node.args[2].type] + else: + # extended case: concat_where((cond1, val1), (cond2, val2), ..., default) + conditions = [] + values = [] + + for pair_arg in node.args[:-1]: + pair_type = pair_arg.type + assert isinstance(pair_type, ts.TupleType) and len(pair_type.types) == 2, ( + f"Each condition-value pair must be a 2-tuple, got {pair_type}" + ) + + cond_type, value_type = pair_type.types + assert isinstance(cond_type, ts.DomainType), ( + f"Condition must be a DomainType, got {cond_type}" + ) + + conditions.append(cond_type) + values.append(value_type) + + values.append(node.args[-1].type) + + assert all(isinstance(c, ts.DomainType) for c in conditions) - assert isinstance(cond_type, ts.DomainType) assert all( isinstance(el, (ts.FieldType, ts.ScalarType)) - for arg in (true_branch_type, false_branch_type) - for el in type_info.primitive_constituents(arg) + for value_type in values + for el in type_info.primitive_constituents(value_type) ) @utils.tree_map( collection_type=ts.TupleType, result_collection_constructor=lambda _, elts: ts.TupleType(types=list(elts)), ) - def deduce_return_type( - tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType - ) -> ts.FieldType: - if (t_dtype := type_info.extract_dtype(tb)) != (f_dtype := type_info.extract_dtype(fb)): + def deduce_return_type(*value_types: ts.FieldType | ts.ScalarType) -> ts.FieldType: + dtypes = [type_info.extract_dtype(vt) for vt in value_types] + if not all(dtype == dtypes[0] for dtype in dtypes): raise errors.DSLError( node.location, - f"Field arguments must be of same dtype, got '{t_dtype}' != '{f_dtype}'.", + f"All field arguments must be of same dtype, got {dtypes}.", ) - return_dims = promote_dims( - cond_type.dims, type_info.extract_dims(type_info.promote(tb, fb)) - ) - return_type = ts.FieldType(dims=return_dims, dtype=t_dtype) + + all_dims = [c.dims for c in conditions] + all_dims += [type_info.extract_dims(vt) for vt in value_types] + + return_dims = all_dims[0] + for dims in all_dims[1:]: + return_dims = promote_dims(return_dims, dims) + + return_type = ts.FieldType(dims=return_dims, dtype=dtypes[0]) return return_type - return_type = deduce_return_type(true_branch_type, false_branch_type) + return_type = deduce_return_type(*values) return foast.Call( func=node.func, diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index ec5c43763b..8af1ff168e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -27,7 +27,7 @@ from gt4py.next.ffront.foast_passes import utils as foast_utils from gt4py.next.ffront.stages import AOT_FOP, FOP from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import constant_folding from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt @@ -419,8 +419,21 @@ def create_if( return im.let(cond_symref_name, cond_)(result) def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - domain, true_branch, false_branch = self.visit(node.args, **kwargs) - return im.concat_where(domain, true_branch, false_branch) + visited_args = [self.visit(arg, **kwargs) for arg in node.args] + if len(visited_args) == 3 and not isinstance(visited_args[0].type, ts.TupleType): + cond, true_branch, false_branch = visited_args + return im.concat_where(cond, true_branch, false_branch) + else: + *pair_args, default = visited_args + + result = default + for pair in reversed(pair_args): + assert cpm.is_call_to(pair, "make_tuple") + cond, value = pair.args + + result = im.concat_where(cond, value, result) + + return result def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return im.call("broadcast")(*self.visit(node.args, **kwargs)) diff --git a/src/gt4py/next/iterator/transforms/concat_where/__init__.py b/src/gt4py/next/iterator/transforms/concat_where/__init__.py index 31f6872aac..4d3ee5568a 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/__init__.py +++ b/src/gt4py/next/iterator/transforms/concat_where/__init__.py @@ -6,6 +6,9 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py.next.iterator.transforms.concat_where.canonicalize_concat_where import ( + canonicalize_concat_where, +) from gt4py.next.iterator.transforms.concat_where.canonicalize_domain_argument import ( canonicalize_domain_argument, ) @@ -15,4 +18,9 @@ ) -__all__ = ["canonicalize_domain_argument", "expand_tuple_args", "transform_to_as_fieldop"] +__all__ = [ + "canonicalize_concat_where", + "canonicalize_domain_argument", + "expand_tuple_args", + "transform_to_as_fieldop", +] diff --git a/src/gt4py/next/iterator/transforms/concat_where/canonicalize_concat_where.py b/src/gt4py/next/iterator/transforms/concat_where/canonicalize_concat_where.py new file mode 100644 index 0000000000..582367afdc --- /dev/null +++ b/src/gt4py/next/iterator/transforms/concat_where/canonicalize_concat_where.py @@ -0,0 +1,51 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import TypeVar + +from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + + +PRG = TypeVar("PRG", bound=itir.Program | itir.Expr) + + +class _CanonicalizeConcatWhere(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + @classmethod + def apply(cls, node: PRG) -> PRG: + return cls().visit(node) + + def visit_FunCall(self, node: itir.FunCall) -> itir.Expr: + node = self.generic_visit(node) + # `concat_where((c1, v1), (c2, v2), (c3, v3),..., default)` + # -> `{concat_where(c1, v1, concat_where(c2, v2, concat_where(c3, v3, default))}` + + if not cpm.is_call_to(node, "concat_where") or not cpm.is_call_to( + node.args[0], "make_tuple" + ): + return node + + *pairs, default = node.args + + if len(pairs) == 0: + return node + + result = default + for pair in reversed(pairs): + cond, value = pair.args if isinstance(pair, itir.FunCall) else pair + result = im.concat_where(cond, value, result) + + return result + + +canonicalize_concat_where = _CanonicalizeConcatWhere.apply diff --git a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py index 40d956fca0..d435028d36 100644 --- a/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py +++ b/src/gt4py/next/iterator/transforms/concat_where/expand_tuple_args.py @@ -39,9 +39,11 @@ def apply( def transform(self, node: itir.Node, **kwargs) -> Optional[itir.Node]: # `concat_where(cond, {a, b}, {c, d})` - # -> `{concat_where(cond, a, c), concat_where(cond, a, c)}` - if not cpm.is_call_to(node, "concat_where") or not isinstance( - type_inference.reinfer(node.args[1]).type, ts.TupleType + # -> `{concat_where(cond, a, c), concat_where(cond, b, d)}` + if ( + not cpm.is_call_to(node, "concat_where") + or cpm.is_call_to(node.args[0], "make_tuple") + or not isinstance(type_inference.reinfer(node.args[1]).type, ts.TupleType) ): return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b334ad796d..dda8b97d96 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -77,6 +77,7 @@ def apply_common_transforms( # test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere. ir = inline_lifts.InlineLifts().visit(ir) + ir = concat_where.canonicalize_concat_where(ir) ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, collapse_tuple_uids=collapse_tuple_uids, offset_provider_type=offset_provider_type @@ -176,6 +177,7 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = concat_where.canonicalize_concat_where(ir) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination(ir, offset_provider_type=offset_provider_type) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6b9c4341e4..c92ff57e22 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -264,8 +264,25 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer -def concat_where( - domain: ts.DomainType, +def concat_where(first_arg, *args): # TODO: fix annotations + """ + classic form: (domain, true_field, false_field) + extended form: (domain, (cond1, val1), ..., default) + """ + if isinstance(first_arg, (ts.DomainType, ts.DeferredType)) and len(args) == 2: + true_field, false_field = args + return _concat_where_type_classic(first_arg, true_field, false_field) + else: + *pairs, default = (first_arg, *args) + result = default + for pair in reversed(pairs): + cond, value = pair + result = _concat_where_type_classic(cond, value, result) + return result + + +def _concat_where_type_classic( + domain: ts.DomainType | ts.DeferredType, true_field: ts.FieldType | ts.TupleType | ts.DeferredType, false_field: ts.FieldType | ts.TupleType | ts.DeferredType, ) -> ts.FieldType | ts.TupleType | ts.DeferredType: @@ -277,7 +294,7 @@ def concat_where( result_collection_constructor=lambda _, elts: ts.TupleType(types=list(elts)), ) def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.ScalarType): - if any(isinstance(b, ts.DeferredType) for b in [tb, fb]): + if any(isinstance(b, ts.DeferredType) for b in [domain, tb, fb]): return ts.DeferredType(constraint=ts.FieldType) tb_dtype, fb_dtype = (type_info.extract_dtype(b) for b in [tb, fb]) @@ -290,8 +307,7 @@ def deduce_return_type(tb: ts.FieldType | ts.ScalarType, fb: ts.FieldType | ts.S return_dims = common.promote_dims( domain.dims, type_info.extract_dims(type_info.promote(tb, fb)) ) - return_type = ts.FieldType(dims=return_dims, dtype=dtype) - return return_type + return ts.FieldType(dims=return_dims, dtype=dtype) return deduce_return_type(true_field, false_field) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 8ce734ef22..9ff26a6619 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -226,6 +226,29 @@ def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) +def test_dimension_two_conditions_no_nesting(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IJKField, boundary: cases.IJKField) -> cases.IJKField: + return concat_where( + ((KDim < 2), boundary), + ((KDim >= 5), boundary), + # ((KDim >= 6), boundary), + interior, + ) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref = np.where( + (k[np.newaxis, np.newaxis, :] < 2) | (k[np.newaxis, np.newaxis, :] >= 5), + boundary.asnumpy(), + interior.asnumpy(), + ) + cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref) + + def test_dimension_two_conditions_and(cartesian_case): @gtx.field_operator def testee(interior: cases.KField, boundary: cases.KField, nlev: np.int32) -> cases.KField: @@ -294,6 +317,31 @@ def testee( cases.verify(cartesian_case, testee, inp, boundary, out.domain.shape, out=out, ref=ref) +def test_lap_like_no_nesting(cartesian_case): + @gtx.field_operator + def testee( + inp: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32] + ) -> cases.IJField: + # TODO add support for multi-dimensional concat_where masks + return concat_where( + ((IDim == 0) | (IDim == shape[0] - 1), boundary), + ((JDim == 0) | (JDim == shape[1] - 1), boundary), + inp, + ) + + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + inp = cases.allocate(cartesian_case, testee, "inp", domain=out.domain.slice_at[1:-1, 1:-1])() + boundary = 2 + + ref = np.full(out.domain.shape, np.nan) + ref[0, :] = boundary + ref[:, 0] = boundary + ref[-1, :] = boundary + ref[:, -1] = boundary + ref[1:-1, 1:-1] = inp.asnumpy() + cases.verify(cartesian_case, testee, inp, boundary, out.domain.shape, out=out, ref=ref) + + @pytest.mark.uses_tuple_returns def test_with_tuples(cartesian_case): @gtx.field_operator @@ -335,6 +383,65 @@ def testee( ) +@pytest.mark.uses_tuple_returns +def test_with_tuples_no_nesting(cartesian_case): + @gtx.field_operator + def testee( + interior0: cases.IJKField, + boundary_l0: cases.IJField, + boundary_u0: cases.IJField, + interior1: cases.IJKField, + boundary_l1: cases.IJField, + boundary_u1: cases.IJField, + ) -> tuple[cases.IJKField, cases.IJKField]: + return concat_where( + (KDim == 0, (boundary_l0, boundary_l1)), + (KDim == 9, (boundary_u0, boundary_u1)), + (interior0, interior1), + ) + + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary_l0 = cases.allocate(cartesian_case, testee, "boundary_l0")() + boundary_u0 = cases.allocate(cartesian_case, testee, "boundary_u0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary_l1 = cases.allocate(cartesian_case, testee, "boundary_l1")() + boundary_u1 = cases.allocate(cartesian_case, testee, "boundary_u1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + k = np.arange(0, cartesian_case.default_sizes[KDim]) + ref0 = np.where( + k[np.newaxis, np.newaxis, :] == 0, + boundary_l0.asnumpy()[:, :, np.newaxis], + np.where( + k[np.newaxis, np.newaxis, :] == 9, + boundary_u0.asnumpy()[:, :, np.newaxis], + interior0.asnumpy(), + ), + ) + ref1 = np.where( + k[np.newaxis, np.newaxis, :] == 0, + boundary_l1.asnumpy()[:, :, np.newaxis], + np.where( + k[np.newaxis, np.newaxis, :] == 9, + boundary_u1.asnumpy()[:, :, np.newaxis], + interior1.asnumpy(), + ), + ) + + cases.verify( + cartesian_case, + testee, + interior0, + boundary_l0, + boundary_u0, + interior1, + boundary_l1, + boundary_u1, + out=out, + ref=(ref0, ref1), + ) + + def test_nested_conditions_with_empty_branches(cartesian_case): @gtx.field_operator def testee(interior: cases.IField, boundary: cases.IField, N: gtx.int32) -> cases.IField: @@ -357,6 +464,30 @@ def testee(interior: cases.IField, boundary: cases.IField, N: gtx.int32) -> case cases.verify(cartesian_case, testee, interior, boundary, N, out=out, ref=ref) +def test_nested_conditions_with_empty_branches_no_nesting(cartesian_case): + @gtx.field_operator + def testee(interior: cases.IField, boundary: cases.IField, N: gtx.int32) -> cases.IField: + return concat_where( + (IDim == 0, boundary), + ((1 <= IDim) & (IDim < N - 1), interior * 2), + (IDim == N - 1, boundary), + interior, + ) + + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + N = cartesian_case.default_sizes[IDim] + + i = np.arange(0, cartesian_case.default_sizes[IDim]) + ref = np.where( + (i[:] == 0) | (i[:] == N - 1), + boundary.asnumpy(), + interior.asnumpy() * 2, + ) + cases.verify(cartesian_case, testee, interior, boundary, N, out=out, ref=ref) + + @pytest.mark.uses_tuple_returns def test_with_tuples_different_domain(cartesian_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 2f73019d7b..6d2f730520 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -235,7 +235,10 @@ def domain_comparison(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=re.escape("Field arguments must be of same dtype, got 'float64' != 'int32'."), + match=re.escape( + "All field arguments must be of same dtype, " + "got [ScalarType(kind=, shape=None), ScalarType(kind=, shape=None)]." + ), ): _ = FieldOperatorParser.apply_to_function(domain_comparison)