diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ab9d7a288d..df9af53f83 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -29,7 +29,9 @@ Final, Generic, Literal, + MaybeNestedInTuple, NamedTuple, + NestedTuple, Never, Optional, ParamSpec, @@ -455,7 +457,9 @@ def __getitem__(self, index: slice) -> Self: ... @overload def __getitem__(self, index: Dimension) -> NamedRange: ... - def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: + def __getitem__( + self, index: int | slice | Dimension | Sequence[Dimension] + ) -> NamedRange | Domain: if isinstance(index, Dimension): try: index = self.dims.index(index) @@ -467,6 +471,12 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: dims_slice = self.dims[index] ranges_slice = self.ranges[index] return Domain(dims=dims_slice, ranges=ranges_slice) + if isinstance(index, Sequence) and all(isinstance(d, Dimension) for d in index): + indices = sorted(self.dims.index(d) for d in index) + return Domain( + dims=tuple(self.dims[i] for i in indices), + ranges=tuple(self.ranges[i] for i in indices), + ) raise KeyError("Invalid index type, must be either int, slice, or Dimension.") @@ -582,13 +592,19 @@ def __getstate__(self) -> dict[str, Any]: FiniteDomain: TypeAlias = Domain[FiniteUnitRange] +DomainLikeEl: TypeAlias = Domain | Mapping[Dimension, RangeLike] +DomainLike: TypeAlias = MaybeNestedInTuple[DomainLikeEl] + + +@overload +def domain(domain_like: DomainLikeEl) -> Domain: ... + -DomainLike: TypeAlias = ( - Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] -) # `Domain` is `Sequence[NamedRange]` and therefore a subset +@overload +def domain(domain_like: NestedTuple[DomainLikeEl]) -> NestedTuple[Domain]: ... -def domain(domain_like: DomainLike) -> Domain: +def domain(domain_like: DomainLike) -> MaybeNestedInTuple[Domain]: """ Construct `Domain` from `DomainLike` object. @@ -610,8 +626,8 @@ def domain(domain_like: DomainLike) -> Domain: """ if isinstance(domain_like, Domain): return domain_like - if isinstance(domain_like, Sequence): - return Domain(*tuple(named_range(d) for d in domain_like)) + if isinstance(domain_like, tuple): + return tuple((domain(el) for el in domain_like)) if isinstance(domain_like, Mapping): if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()): return Domain( diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 14adb85d0a..fac26a792b 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -247,10 +247,10 @@ def as_field( else: origin = {} actual_domain = common.domain( - [ - (d, (-(start_offset := origin.get(d, 0)), s - start_offset)) + { + d: (-(start_offset := origin.get(d, 0)), s - start_offset) for d, s in zip(domain, data.shape) - ] + } ) else: if origin: @@ -332,7 +332,7 @@ def as_connectivity( raise ValueError( f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) - actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)]) + actual_domain = common.domain({d: (0, s) for d, s in zip(domain, data.shape)}) else: actual_domain = common.domain(cast(common.DomainLike, domain)) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 967cf0ab11..b2bb3b5570 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -584,7 +584,7 @@ def _allocate_from_type( case ts.FieldType(dims=dims, dtype=arg_dtype): return strategy.field( allocator=case.allocator, - domain=common.domain(tuple(domain[dim] for dim in dims)), + domain=domain[dims], dtype=dtype or arg_dtype.kind.name.lower(), ) case ts.ScalarType(kind=kind): diff --git a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py index b69950928d..d85019e24f 100644 --- a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py +++ b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py @@ -15,7 +15,7 @@ def test_domain_pickle_after_slice(): - domain = common.domain(((I, (2, 4)), (J, (3, 5)))) + domain = common.domain({I: (2, 4), J: (3, 5)}) # use slice_at to populate cached property domain.slice_at[2:5, 5:7] diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 7c29faca92..f71770ad49 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -803,7 +803,7 @@ def test_connectivity_field_inverse_image(): e2v_conn = common._connectivity( np.roll(np.arange(E_START, E_STOP), 1), - domain=common.domain([common.named_range((E, (E_START, E_STOP)))]), + domain=common.domain({E: (E_START, E_STOP)}), codomain=V, ) @@ -835,10 +835,10 @@ def test_connectivity_field_inverse_image_2d_domain(): c2v_conn = common._connectivity( np.asarray([[0, 0, 2], [1, 1, 2], [2, 2, 2]]), domain=common.domain( - [ - common.named_range((C, (C_START, C_STOP))), - common.named_range((C2V, (C2V_START, C2V_STOP))), - ] + { + C: (C_START, C_STOP), + C2V: (C2V_START, C2V_STOP), + } ), codomain=V, ) @@ -890,7 +890,7 @@ def test_connectivity_field_inverse_image_non_contiguous(): e2v_conn = common._connectivity( np.asarray([0, 1, 2, 3, 4, 9, 7, 5, 8, 6]), - domain=common.domain([common.named_range((E, (E_START, E_STOP)))]), + domain=common.domain({E: (E_START, E_STOP)}), codomain=V, ) diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index b37063f4a5..2a3957f849 100644 --- a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -155,7 +155,7 @@ def test_allocate(self): ) I = common.Dimension("I") J = common.Dimension("J") - domain = common.domain(((I, (2, 4)), (J, (3, 5)))) + domain = common.domain({I: (2, 4), J: (3, 5)}) dtype = float with pytest.raises(ValueError, match="test error"): allocator.__gt_allocate__(domain, dtype) @@ -166,7 +166,7 @@ def test_allocate(): I = common.Dimension("I") J = common.Dimension("J") - domain = common.domain(((I, (0, 2)), (J, (0, 3)))) + domain = common.domain({I: (0, 2), J: (0, 3)}) dtype = core_defs.dtype(float) # Test with a explicit field allocator diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index deba03651a..ee0d6be79d 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -277,14 +277,13 @@ def test_empty_domain(empty_domain, expected): "domain_like", [ (Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))), - ((IDim, (2, 4)), (JDim, (3, 5))), ({IDim: (2, 4), JDim: (3, 5)}), ], ) def test_domain_like(domain_like): - assert domain(domain_like) == Domain( - dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)) - ) + expected = Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + assert domain(domain_like) == expected + assert domain((domain_like, (domain_like, domain_like))) == (expected, (expected, expected)) def test_domain_iteration(a_domain):