Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,41 @@ def __add__(self, offset: int) -> Connectivity:
def __sub__(self, offset: int) -> Connectivity:
return self + (-offset)

def __gt__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(value + 1, Infinity.POSITIVE),))

def __ge__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(value, Infinity.POSITIVE),))

def __lt__(self, value: core_defs.IntegralScalar) -> Domain:
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value),))

def __le__(self, value: core_defs.IntegralScalar) -> Domain:
# TODO add test
return Domain(dims=(self,), ranges=(UnitRange(Infinity.NEGATIVE, value + 1),))

Comment on lines +122 to +123
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overloads

def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __eq__(self, value: Dimension | core_defs.IntegralScalar) -> bool | Domain:
def __eq__(self, value: Dimension | core_defs.IntegralScalar | Any) -> bool | Domain:

if isinstance(value, Dimension):
return self.value == value.value
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self.value == value.value
return self.value == value.value and self.kind == value.kind

???

elif isinstance(value, core_defs.INTEGRAL_TYPES):
# TODO probably only within valid embedded context?
return Domain(dims=(self,), ranges=(UnitRange(value, value + 1),))
else:
return False

def __ne__(self, value: Dimension | core_defs.IntegralScalar) -> bool | tuple[Domain, Domain]:
# TODO add test
if isinstance(value, Dimension):
return self.value != value.value
elif isinstance(value, core_defs.INTEGRAL_TYPES):
# TODO probably only within valid embedded context?
return (
Domain(self, UnitRange(Infinity.NEGATIVE, value)),
Domain(self, UnitRange(value + 1, Infinity.POSITIVE)),
)
else:
return True


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""
Expand Down Expand Up @@ -498,6 +533,24 @@ def __and__(self, other: Domain) -> Domain:
)
return Domain(dims=broadcast_dims, ranges=intersected_ranges)

def __or__(self, other: Domain) -> Domain:
# TODO support arbitrary union of domains
# TODO add tests
if self.ndim > 1 or other.ndim > 1:
raise NotImplementedError("Union of multidimensional domains is not supported.")
if self.ndim == 0:
return other
if other.ndim == 0:
return self
sorted_ = sorted((self, other), key=lambda x: x.ranges[0].start)
if sorted_[0].ranges[0].stop >= sorted_[1].ranges[0].start:
return Domain(
dims=(self.dims[0],),
ranges=(UnitRange(sorted_[0].ranges[0].start, sorted_[1].ranges[0].stop),),
)
else:
return (sorted_[0], sorted_[1])

@functools.cached_property
def slice_at(self) -> utils.IndexerCallable[slice, Domain]:
"""
Expand Down
147 changes: 77 additions & 70 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,25 +828,6 @@ def _hyperslice(
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _compute_mask_slices(
mask: core_defs.NDArrayObject,
) -> list[tuple[bool, slice]]:
"""Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices."""
# TODO: does it make sense to upgrade this naive algorithm to numpy?
assert mask.ndim == 1
cur = bool(mask[0].item())
ind = 0
res = []
for i in range(1, mask.shape[0]):
# Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy
if (mask_i := bool(mask[i].item())) != cur:
res.append((cur, slice(ind, i)))
cur = mask_i
ind = i
res.append((cur, slice(ind, mask.shape[0])))
return res


def _trim_empty_domains(
lst: Iterable[tuple[bool, common.Domain]],
) -> list[tuple[bool, common.Domain]]:
Expand Down Expand Up @@ -914,82 +895,108 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c

def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
# TODO(havogt): this function could be extended to a general concat
# currently only concatenate along the given dimension and requires the fields to be ordered
# currently only concatenate along the given dimension
sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start)

if (
len(fields) > 1
and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty()
len(sorted_fields) > 1
and not embedded_common.domain_intersection(*[f.domain for f in sorted_fields]).is_empty()
):
raise ValueError("Fields to concatenate must not overlap.")
new_domain = _stack_domains(*[f.domain for f in fields], dim=dim)
new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim)
if new_domain is None:
raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")
nd_array_class = _get_nd_array_class(*fields)
nd_array_class = _get_nd_array_class(*sorted_fields)
return nd_array_class.from_array(
nd_array_class.array_ns.concatenate(
[nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields],
[
nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape)
for f in sorted_fields
],
axis=new_domain.dim_index(dim, allow_missing=False),
),
domain=new_domain,
)


def _invert_domain(
domains: common.Domain | tuple[common.Domain],
) -> common.Domain | tuple[common.Domain, ...]:
if not isinstance(domains, tuple):
domains = (domains,)

assert all(d.ndim == 1 for d in domains)
dim = domains[0].dims[0]
assert all(d.dims[0] == dim for d in domains)
sorted_domains = sorted(domains, key=lambda d: d.ranges[0].start)

result = []
if domains[0].ranges[0].start is not common.Infinity.NEGATIVE:
result.append(
common.Domain(
dims=(dim,),
ranges=(common.UnitRange(common.Infinity.NEGATIVE, domains[0].ranges[0].start),),
)
)
for i in range(len(sorted_domains) - 1):
if sorted_domains[i].ranges[0].stop != sorted_domains[i + 1].ranges[0].start:
result.append(
common.Domain(
dims=(dim,),
ranges=(
common.UnitRange(
sorted_domains[i].ranges[0].stop, sorted_domains[i + 1].ranges[0].start
),
),
)
)
if domains[-1].ranges[0].stop is not common.Infinity.POSITIVE:
result.append(
common.Domain(
dims=(dim,),
ranges=(common.UnitRange(domains[-1].ranges[0].stop, common.Infinity.POSITIVE),),
)
)
return tuple(result)


def _intersect_multiple(
domain: common.Domain, domains: common.Domain | tuple[common.Domain]
) -> tuple[common.Domain, ...]:
if not isinstance(domains, tuple):
domains = (domains,)

return tuple(
intersection
for d in domains
if not (intersection := embedded_common.domain_intersection(domain, d)).is_empty()
)


def _concat_where(
mask_field: common.Field, true_field: common.Field, false_field: common.Field
masks: common.Domain | tuple[common.Domain, ...],
true_field: common.Field,
false_field: common.Field,
) -> common.Field:
cls_ = _get_nd_array_class(mask_field, true_field, false_field)
xp = cls_.array_ns
if mask_field.domain.ndim != 1:
if not isinstance(masks, tuple):
masks = (masks,)
if any(m.ndim for m in masks) != 1:
raise NotImplementedError(
"'concat_where': Can only concatenate fields with a 1-dimensional mask."
)
mask_dim = mask_field.domain.dims[0]
mask_dim = masks[0].dims[0]

# intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

# TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils
# compute the consecutive ranges (first relative, then domain) of true and false values
mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices(
mask_field.ndarray
)
mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = (
(mask, mask_field.domain.slice_at[domain_slice])
for mask, domain_slice in mask_values_to_slices_mapping
)
# mask domains intersected with the respective fields
mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = (
(
mask_value,
embedded_common.domain_intersection(
t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain
),
)
for mask_value, mask_domain in mask_values_to_domain_mapping
)

# remove the empty domains from the beginning and end
mask_values_to_intersected_domains_mapping = _trim_empty_domains(
mask_values_to_intersected_domains_mapping
)
if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping):
raise embedded_exceptions.NonContiguousDomain(
f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}."
)
true_domains = _intersect_multiple(t_broadcasted.domain, masks)
t_slices = tuple(t_broadcasted[d] for d in true_domains)

# slice the fields with the domain ranges
transformed = [
t_broadcasted[d] if v else f_broadcasted[d]
for v, d in mask_values_to_intersected_domains_mapping
]
inverted_masks = _invert_domain(masks)
false_domains = _intersect_multiple(f_broadcasted.domain, inverted_masks)
f_slices = tuple(f_broadcasted[d] for d in false_domains)

# stack the fields together
if transformed:
return _concat(*transformed, dim=mask_dim)
else:
result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0)))
result_array = xp.empty(result_domain.shape)
return cls_.from_array(result_array, domain=result_domain)
return _concat(*f_slices, *t_slices, dim=mask_dim)


NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR
Expand Down
Loading