Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
Final,
Generic,
Literal,
MaybeNestedInTuple,
NamedTuple,
NestedTuple,
Never,
Optional,
ParamSpec,
Expand Down Expand Up @@ -455,7 +457,9 @@ def __getitem__(self, index: slice) -> Self: ...
@overload
def __getitem__(self, index: Dimension) -> NamedRange: ...

def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain:
def __getitem__(
self, index: int | slice | Dimension | Sequence[Dimension]
) -> NamedRange | Domain:
if isinstance(index, Dimension):
try:
index = self.dims.index(index)
Expand All @@ -467,6 +471,12 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain:
dims_slice = self.dims[index]
ranges_slice = self.ranges[index]
return Domain(dims=dims_slice, ranges=ranges_slice)
if isinstance(index, Sequence) and all(isinstance(d, Dimension) for d in index):
indices = sorted(self.dims.index(d) for d in index)
return Domain(
dims=tuple(self.dims[i] for i in indices),
ranges=tuple(self.ranges[i] for i in indices),
)

raise KeyError("Invalid index type, must be either int, slice, or Dimension.")

Expand Down Expand Up @@ -582,13 +592,19 @@ def __getstate__(self) -> dict[str, Any]:

FiniteDomain: TypeAlias = Domain[FiniteUnitRange]

DomainLikeEl: TypeAlias = Domain | Mapping[Dimension, RangeLike]
DomainLike: TypeAlias = MaybeNestedInTuple[DomainLikeEl]


@overload
def domain(domain_like: DomainLikeEl) -> Domain: ...


DomainLike: TypeAlias = (
Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike]
) # `Domain` is `Sequence[NamedRange]` and therefore a subset
@overload
def domain(domain_like: NestedTuple[DomainLikeEl]) -> NestedTuple[Domain]: ...


def domain(domain_like: DomainLike) -> Domain:
def domain(domain_like: DomainLike) -> MaybeNestedInTuple[Domain]:
"""
Construct `Domain` from `DomainLike` object.

Expand All @@ -610,8 +626,8 @@ def domain(domain_like: DomainLike) -> Domain:
"""
if isinstance(domain_like, Domain):
return domain_like
if isinstance(domain_like, Sequence):
return Domain(*tuple(named_range(d) for d in domain_like))
if isinstance(domain_like, tuple):
return tuple((domain(el) for el in domain_like))
if isinstance(domain_like, Mapping):
if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()):
return Domain(
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ def as_field(
else:
origin = {}
actual_domain = common.domain(
[
(d, (-(start_offset := origin.get(d, 0)), s - start_offset))
{
d: (-(start_offset := origin.get(d, 0)), s - start_offset)
for d, s in zip(domain, data.shape)
]
}
)
else:
if origin:
Expand Down Expand Up @@ -332,7 +332,7 @@ def as_connectivity(
raise ValueError(
f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'."
)
actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)])
actual_domain = common.domain({d: (0, s) for d, s in zip(domain, data.shape)})
else:
actual_domain = common.domain(cast(common.DomainLike, domain))

Expand Down
2 changes: 1 addition & 1 deletion tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def _allocate_from_type(
case ts.FieldType(dims=dims, dtype=arg_dtype):
return strategy.field(
allocator=case.allocator,
domain=common.domain(tuple(domain[dim] for dim in dims)),
domain=domain[dims],
dtype=dtype or arg_dtype.kind.name.lower(),
)
case ts.ScalarType(kind=kind):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def test_domain_pickle_after_slice():
domain = common.domain(((I, (2, 4)), (J, (3, 5))))
domain = common.domain({I: (2, 4), J: (3, 5)})
# use slice_at to populate cached property
domain.slice_at[2:5, 5:7]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ def test_connectivity_field_inverse_image():

e2v_conn = common._connectivity(
np.roll(np.arange(E_START, E_STOP), 1),
domain=common.domain([common.named_range((E, (E_START, E_STOP)))]),
domain=common.domain({E: (E_START, E_STOP)}),
codomain=V,
)

Expand Down Expand Up @@ -835,10 +835,10 @@ def test_connectivity_field_inverse_image_2d_domain():
c2v_conn = common._connectivity(
np.asarray([[0, 0, 2], [1, 1, 2], [2, 2, 2]]),
domain=common.domain(
[
common.named_range((C, (C_START, C_STOP))),
common.named_range((C2V, (C2V_START, C2V_STOP))),
]
{
C: (C_START, C_STOP),
C2V: (C2V_START, C2V_STOP),
}
),
codomain=V,
)
Expand Down Expand Up @@ -890,7 +890,7 @@ def test_connectivity_field_inverse_image_non_contiguous():

e2v_conn = common._connectivity(
np.asarray([0, 1, 2, 3, 4, 9, 7, 5, 8, 6]),
domain=common.domain([common.named_range((E, (E_START, E_STOP)))]),
domain=common.domain({E: (E_START, E_STOP)}),
codomain=V,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/next_tests/unit_tests/test_allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_allocate(self):
)
I = common.Dimension("I")
J = common.Dimension("J")
domain = common.domain(((I, (2, 4)), (J, (3, 5))))
domain = common.domain({I: (2, 4), J: (3, 5)})
dtype = float
with pytest.raises(ValueError, match="test error"):
allocator.__gt_allocate__(domain, dtype)
Expand All @@ -166,7 +166,7 @@ def test_allocate():

I = common.Dimension("I")
J = common.Dimension("J")
domain = common.domain(((I, (0, 2)), (J, (0, 3))))
domain = common.domain({I: (0, 2), J: (0, 3)})
dtype = core_defs.dtype(float)

# Test with a explicit field allocator
Expand Down
7 changes: 3 additions & 4 deletions tests/next_tests/unit_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,13 @@ def test_empty_domain(empty_domain, expected):
"domain_like",
[
(Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))),
((IDim, (2, 4)), (JDim, (3, 5))),
({IDim: (2, 4), JDim: (3, 5)}),
],
)
def test_domain_like(domain_like):
assert domain(domain_like) == Domain(
dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5))
)
expected = Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))
assert domain(domain_like) == expected
assert domain((domain_like, (domain_like, domain_like))) == (expected, (expected, expected))


def test_domain_iteration(a_domain):
Expand Down
Loading