From 491044075155794f20b45376a99b6d863362be55 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 9 Jul 2025 13:25:54 +0200 Subject: [PATCH 1/2] feat[next]: concat_where fieldview embedded --- src/gt4py/next/common.py | 53 ++++++++ src/gt4py/next/embedded/nd_array_field.py | 147 +++++++++++----------- 2 files changed, 130 insertions(+), 70 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index dc6f24e9dd..30726ee862 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -109,6 +109,41 @@ def __add__(self, offset: int) -> Connectivity: def __sub__(self, offset: int) -> Connectivity: return self + (-offset) + def __gt__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)) + + def __ge__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),)) + + def __lt__(self, value: core_defs.IntegralScalar) -> Domain: + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)) + + def __le__(self, value: core_defs.IntegralScalar) -> Domain: + # TODO add test + return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),)) + + def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain: + if isinstance(value, Dimension): + return self.value == value.value + elif isinstance(value, core_defs.INTEGRAL_TYPES): + # TODO probably only within valid embedded context? + return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),)) + else: + return False + + def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | tuple[Domain, Domain]: + # TODO add test + if isinstance(value, Dimension): + return self.value != value.value + elif isinstance(value, core_defs.INTEGRAL_TYPES): + # TODO probably only within valid embedded context?
+ return ( + Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),)), + Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),)), + ) + else: + return True + class Infinity(enum.Enum): """Describes an unbounded `UnitRange`.""" @@ -500,6 +535,24 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) + def __or__(self, other: Domain) -> Domain | tuple[Domain, Domain]: + # TODO support arbitrary union of domains + # TODO add tests + if self.ndim > 1 or other.ndim > 1: + raise NotImplementedError("Union of multidimensional domains is not supported.") + if self.ndim == 0: + return other + if other.ndim == 0: + return self + sorted_ = sorted((self, other), key=lambda x: x.ranges[0].start) + if sorted_[0].ranges[0].stop >= sorted_[1].ranges[0].start: + return Domain( + dims=(self.dims[0],), + ranges=(UnitRange(sorted_[0].ranges[0].start, max(d.ranges[0].stop for d in sorted_)),), + ) + else: + return (sorted_[0], sorted_[1]) + @functools.cached_property def slice_at(self) -> utils.IndexerCallable[slice, Domain]: """ diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 25ce060c7c..a2101e6c99 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -810,25 +810,6 @@ def _hyperslice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _compute_mask_slices( - mask: core_defs.NDArrayObject, -) -> list[tuple[bool, slice]]: - """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" - # TODO: does it make sense to upgrade this naive algorithm to numpy? - assert mask.ndim == 1 - cur = bool(mask[0].item()) - ind = 0 - res = [] - for i in range(1, mask.shape[0]): - # Use `.item()` to extract the scalar from a 0-d array in case of e.g.
cupy - if (mask_i := bool(mask[i].item())) != cur: - res.append((cur, slice(ind, i))) - cur = mask_i - ind = i - res.append((cur, slice(ind, mask.shape[0]))) - return res - - def _trim_empty_domains( lst: Iterable[tuple[bool, common.Domain]], ) -> list[tuple[bool, common.Domain]]: @@ -896,82 +877,108 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: # TODO(havogt): this function could be extended to a general concat - # currently only concatenate along the given dimension and requires the fields to be ordered + # currently only concatenate along the given dimension + sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start) if ( - len(fields) > 1 - and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() + len(sorted_fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty() ): raise ValueError("Fields to concatenate must not overlap.") - new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) + new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim) if new_domain is None: raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") - nd_array_class = _get_nd_array_class(*fields) + nd_array_class = _get_nd_array_class(*sorted_fields) return nd_array_class.from_array( nd_array_class.array_ns.concatenate( - [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], + [ + nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) + for f in sorted_fields + ], axis=new_domain.dim_index(dim, allow_missing=False), ), domain=new_domain, ) +def _invert_domain( + domains: common.Domain | tuple[common.Domain], +) -> common.Domain | tuple[common.Domain, ...]: + if not isinstance(domains, tuple): + domains = (domains,) + + assert all(d.ndim == 1 for d in domains) + dim = domains[0].dims[0] + 
assert all(d.dims[0] == dim for d in domains) + sorted_domains = sorted(domains, key=lambda d: d.ranges[0].start) + + result = [] + if sorted_domains[0].ranges[0].start is not common.Infinity.NEGATIVE: + result.append( + common.Domain( + dims=(dim,), + ranges=(common.UnitRange(common.Infinity.NEGATIVE, sorted_domains[0].ranges[0].start),), + ) + ) + for i in range(len(sorted_domains) - 1): + if sorted_domains[i].ranges[0].stop != sorted_domains[i + 1].ranges[0].start: + result.append( + common.Domain( + dims=(dim,), + ranges=( + common.UnitRange( + sorted_domains[i].ranges[0].stop, sorted_domains[i + 1].ranges[0].start + ), + ), + ) + ) + if sorted_domains[-1].ranges[0].stop is not common.Infinity.POSITIVE: + result.append( + common.Domain( + dims=(dim,), + ranges=(common.UnitRange(sorted_domains[-1].ranges[0].stop, common.Infinity.POSITIVE),), + ) + ) + return tuple(result) + + +def _intersect_multiple( + domain: common.Domain, domains: common.Domain | tuple[common.Domain, ...] +) -> tuple[common.Domain, ...]: + if not isinstance(domains, tuple): + domains = (domains,) + + return tuple( + intersection + for d in domains + if not (intersection := embedded_common.domain_intersection(domain, d)).is_empty() + ) + + def _concat_where( - mask_field: common.Field, true_field: common.Field, false_field: common.Field + masks: common.Domain | tuple[common.Domain, ...], + true_field: common.Field, + false_field: common.Field, ) -> common.Field: - cls_ = _get_nd_array_class(mask_field, true_field, false_field) - xp = cls_.array_ns - if mask_field.domain.ndim != 1: + if not isinstance(masks, tuple): + masks = (masks,) + if any(m.ndim != 1 for m in masks): raise NotImplementedError( "'concat_where': Can only concatenate fields with a 1-dimensional mask."
) - mask_dim = mask_field.domain.dims[0] + mask_dim = masks[0].dims[0] # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) - # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils - # compute the consecutive ranges (first relative, then domain) of true and false values - mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices( - mask_field.ndarray - ) - mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( - (mask, mask_field.domain.slice_at[domain_slice]) - for mask, domain_slice in mask_values_to_slices_mapping - ) - # mask domains intersected with the respective fields - mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( - ( - mask_value, - embedded_common.domain_intersection( - t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain - ), - ) - for mask_value, mask_domain in mask_values_to_domain_mapping - ) - - # remove the empty domains from the beginning and end - mask_values_to_intersected_domains_mapping = _trim_empty_domains( - mask_values_to_intersected_domains_mapping - ) - if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): - raise embedded_exceptions.NonContiguousDomain( - f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." 
- ) + true_domains = _intersect_multiple(t_broadcasted.domain, masks) + t_slices = tuple(t_broadcasted[d] for d in true_domains) - # slice the fields with the domain ranges - transformed = [ - t_broadcasted[d] if v else f_broadcasted[d] - for v, d in mask_values_to_intersected_domains_mapping - ] + inverted_masks = _invert_domain(masks) + false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks) + f_slices = tuple(f_broadcasted[d] for d in false_domains) - # stack the fields together - if transformed: - return _concat(*transformed, dim=mask_dim) - else: - result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0))) - result_array = xp.empty(result_domain.shape) - return cls_.from_array(result_array, domain=result_domain) + return _concat(*f_slices, *t_slices, dim=mask_dim) NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR From 13707c30920069a80643f3daa509de7e53055993 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 8 Dec 2025 09:13:08 +0100 Subject: [PATCH 2/2] a few tests, still incomplete --- pyproject.toml | 1 + src/gt4py/next/embedded/nd_array_field.py | 13 ++ tests/next_tests/definitions.py | 5 +- .../embedded_tests/test_nd_array_field.py | 168 ++++++++++-------- 4 files changed, 115 insertions(+), 72 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84ef79eb95..f6cb2bd71f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -299,6 +299,7 @@ markers = [ 'uses_max_over: tests that use the max_over builtin', 'uses_mesh_with_skip_values: tests that use a mesh with skip values', 'uses_concat_where: tests that use the concat_where builtin', + 'embedded_concat_where_infinite_domain: tests with concat_where resulting in an infinite domain', 'uses_program_metrics: tests that require backend support for program metrics', 'checks_specific_error: tests that rely on the backend 
to produce a specific error message' ] diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 90b6f273b1..e04575df34 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -973,6 +973,15 @@ def _intersect_multiple( ) +def _size0_field( + nd_array_class: type[NdArrayField], dims: tuple[common.Dimension, ...], dtype: core_defs.DType +) -> NdArrayField: + return nd_array_class.from_array( + nd_array_class.array_ns.empty((0,) * len(dims), dtype=dtype.scalar_type), + domain=common.Domain(dims=dims, ranges=(common.UnitRange(0, 0),) * len(dims)), + ) + + def _concat_where( masks: common.Domain | tuple[common.Domain, ...], true_field: common.Field, @@ -996,6 +1005,10 @@ def _concat_where( false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks) f_slices = tuple(f_broadcasted[d] for d in false_domains) + if len(t_slices) + len(f_slices) == 0: + # no data to concatenate, return an empty field + nd_array_class = _get_nd_array_class(true_field, false_field) + return _size0_field(nd_array_class, dims=t_broadcasted.domain.dims, dtype=true_field.dtype) return _concat(*f_slices, *t_slices, dim=mask_dim) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 0bff0b0aa7..a6dfd3f101 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -128,6 +128,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_PROGRAM_METRICS = "uses_program_metrics" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" USES_CONCAT_WHERE = "uses_concat_where" +EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN = "embedded_concat_where_infinite_domain" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') @@ -170,7 +171,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args 
- (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), + (EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE), ] ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_PROGRAM_METRICS, XFAIL, UNSUPPORTED_MESSAGE), @@ -179,7 +180,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), ] GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [ - (USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE), + (EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE) ] GTFN_SKIP_TEST_LIST = ( COMMON_SKIP_TEST_LIST diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index cddea76ba1..ccca5c6779 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -1015,91 +1015,119 @@ def test_hyperslice(index_array, expected): @pytest.mark.uses_concat_where @pytest.mark.parametrize( - "mask_data, true_data, false_data, expected", + "cond, true_data, false_data, expected", [ + (D0 == 0, ([0, 0], None), ([1, 1], None), ([0, 1], None)), + (D0 == -1, ([0, 0], {D0: (-1, 1)}), ([1, 1], {D0: (0, 2)}), ([0, 1, 1], {D0: (-1, 2)})), + (D0 < 0, ([0, 0], {D0: (-2, 0)}), ([1, 1], {D0: (0, 2)}), ([0, 0, 1, 1], {D0: (-2, 2)})), + (D0 == 1, ([0, 0, 0], None), ([1, 1, 1], None), ([1, 0, 1], None)), + # non-contiguous domain + (D0 <= 0, ([0, 0], {D0: (-2, 0)}), ([1, 1], {D0: (0, 2)}), None), + # empty result domain ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], None), - ([6, 7, 8, 9, 10], None), - ([1, 7, 3, 9, 5], None), - ), - ( - ([True, False, True, False], None), - ([1, 2, 3, 4, 5], {D0: (-2, 3)}), - ([6, 7, 8, 9], {D0: (1, 5)}), - ([3, 6, 5, 8], {D0: (0, 4)}), - ), - ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-2, 3)}), - ([6, 7, 8, 9, 10], {D0: (1, 6)}), - ([3, 6, 5, 8], {D0: (0, 4)}), - ), - 
( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-2, 3)}), - ([6, 7, 8, 9, 10], {D0: (2, 7)}), - None, - ), - ( - # empty result domain - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-5, 0)}), - ([6, 7, 8, 9, 10], {D0: (5, 10)}), + D0 < 0, + ([0, 0], {D0: (0, 2)}), + ([1, 1], {D0: (-2, 0)}), ([], {D0: (0, 0)}), ), + # broadcasting from scalar + # pytest.param( + # D0 == 0, + # ([0, 0], None), + # (1, None), + # ([0, 1], None), + # marks=pytest.mark.embedded_concat_where_infinite_domain, + # ), + # different dimensions + # ( + # D0 == 0, + # ([0, 0], {D0: (0, 2)}), + # ([1, 1], {D1: (0, 2)}), + # # TODO + # ), + # 2D ( - ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], {D0: (-4, 1)}), - ([6, 7, 8, 9, 10], {D0: (5, 10)}), - ([5], {D0: (0, 1)}), - ), - ( - # broadcasting true_field - ([True, False, True, False, True], {D0: 5}), - ([1, 2, 3, 4, 5], {D0: 5}), - ([[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], {D0: 5, D1: 2}), - ([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]], {D0: 5, D1: 2}), - ), - ( - ([True, False, True, False, True], None), - (42, None), - ([6, 7, 8, 9, 10], None), - ([42, 7, 42, 9, 42], None), - ), - ( - # parts of mask_ranges are concatenated - ([True, True, False, False], None), - ([1, 2], {D0: (1, 3)}), - ([3, 4], {D0: (1, 3)}), - ([1, 4], {D0: (1, 3)}), - ), - ( - # parts of mask_ranges are concatenated and yield non-contiguous domain - ([True, False, True, False], None), - ([1, 2], {D0: (0, 2)}), - ([3, 4], {D0: (2, 4)}), - None, + ((D0 == 0) & (D1 == 0)), + ([[0, 0], [0, 0]], None), + ([[1, 1], [1, 1]], None), + ([[0, 1], [1, 1]], None), ), + # ( + # ([True, False, True, False, True], None), + # ([1, 2, 3, 4, 5], None), + # ([6, 7, 8, 9, 10], None), + # ([1, 7, 3, 9, 5], None), + # ), + # ( + # ([True, False, True, False], None), + # ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + # ([6, 7, 8, 9], {D0: (1, 5)}), + # ([3, 6, 5, 8], {D0: (0, 4)}), + # ), + # ( + # ([True, False, True, False, True], 
None), + # ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + # ([6, 7, 8, 9, 10], {D0: (1, 6)}), + # ([3, 6, 5, 8], {D0: (0, 4)}), + # ), + # ( + # ([True, False, True, False, True], None), + # ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + # ([6, 7, 8, 9, 10], {D0: (2, 7)}), + # None, + # ), + # ( + # # empty result domain + # ([True, False, True, False, True], None), + # ([1, 2, 3, 4, 5], {D0: (-5, 0)}), + # ([6, 7, 8, 9, 10], {D0: (5, 10)}), + # ([], {D0: (0, 0)}), + # ), + # ( + # ([True, False, True, False, True], None), + # ([1, 2, 3, 4, 5], {D0: (-4, 1)}), + # ([6, 7, 8, 9, 10], {D0: (5, 10)}), + # ([5], {D0: (0, 1)}), + # ), + # ( + # # broadcasting true_field + # ([True, False, True, False, True], {D0: 5}), + # ([1, 2, 3, 4, 5], {D0: 5}), + # ([[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], {D0: 5, D1: 2}), + # ([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]], {D0: 5, D1: 2}), + # ), + # ( + # ([True, False, True, False, True], None), + # (42, None), + # ([6, 7, 8, 9, 10], None), + # ([42, 7, 42, 9, 42], None), + # ), + # ( + # # parts of mask_ranges are concatenated + # ([True, True, False, False], None), + # ([1, 2], {D0: (1, 3)}), + # ([3, 4], {D0: (1, 3)}), + # ([1, 4], {D0: (1, 3)}), + # ), + # ( + # # parts of mask_ranges are concatenated and yield non-contiguous domain + # ([True, False, True, False], None), + # ([1, 2], {D0: (0, 2)}), + # ([3, 4], {D0: (2, 4)}), + # None, + # ), ], ) def test_concat_where( nd_array_implementation, - mask_data: tuple[list[bool], Optional[common.DomainLike]], + cond: common.Domain, true_data: tuple[list[int], Optional[common.DomainLike]], false_data: tuple[list[int], Optional[common.DomainLike]], expected: Optional[tuple[list[int], Optional[common.DomainLike]]], ): - mask_lst, mask_domain = mask_data true_lst, true_domain = true_data false_lst, false_domain = false_data - mask_field = _make_field_or_scalar( - mask_lst, - nd_array_implementation=nd_array_implementation, - domain=common.domain(mask_domain) if mask_domain is not None else None, - 
dtype=bool, - ) true_field = _make_field_or_scalar( true_lst, nd_array_implementation=nd_array_implementation, @@ -1115,7 +1143,7 @@ def test_concat_where( if expected is None: with pytest.raises(embedded_exceptions.NonContiguousDomain): - nd_array_field._concat_where(mask_field, true_field, false_field) + nd_array_field._concat_where(cond, true_field, false_field) else: expected_lst, expected_domain_like = expected expected_array = np.asarray(expected_lst) @@ -1125,7 +1153,7 @@ def test_concat_where( else _make_default_domain(expected_array.shape) ) - result = nd_array_field._concat_where(mask_field, true_field, false_field) + result = nd_array_field._concat_where(cond, true_field, false_field) assert expected_domain == result.domain np.testing.assert_allclose(result.asnumpy(), expected_array)