Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
143 commits
Select commit Hold shift + click to select a range
768e458
First draft
tehrengruber Mar 2, 2025
ac7db53
Remove debugging leftovers
tehrengruber Mar 2, 2025
0b2ba1d
Merge origin_tehrengruber/get_domain_builtin
SF-N Jul 16, 2025
61b4a09
Add transformation for get_domain to named_range
SF-N Jul 16, 2025
ff6e23b
Merge branch 'main' into get_domain_builtin
tehrengruber Jul 17, 2025
4a65e7a
Add tuple suppoert
SF-N Jul 17, 2025
28804a3
Move tests to new file
SF-N Jul 17, 2025
5455f34
Update TransformGetDomain to return a tuple, introduce named_range in…
SF-N Jul 23, 2025
0bfbfde
Merge branch 'main' into transform_get_domain
SF-N Jul 23, 2025
1b002f7
Compute actual tempory sizes in unstructured case
SF-N Jul 24, 2025
0a84e7e
Fix tests
SF-N Jul 24, 2025
81f9de6
Only compute where values are used
SF-N Jul 24, 2025
c291c26
Fix some tests
SF-N Jul 25, 2025
56dd9f3
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Jul 25, 2025
afc9c77
Merge branch 'main' into transform_get_domain
SF-N Jul 25, 2025
649363d
Update tests
SF-N Jul 25, 2025
06b2d0c
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Jul 25, 2025
453bc0c
Minor
SF-N Jul 25, 2025
3731810
Get domain from tuple element
tehrengruber Jul 30, 2025
4bd8a51
Merge origin/main
tehrengruber Jul 30, 2025
8a6b636
Merge branch 'main' into transform_get_domain
SF-N Aug 6, 2025
001fab2
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Aug 6, 2025
2c486ed
Refactor tests
SF-N Aug 6, 2025
abb93ef
Cleanup and restrict domains in tests
SF-N Aug 7, 2025
7b5b9b4
Refactor and clean up tests
SF-N Aug 7, 2025
aa861a9
Merge branch 'transform_get_domain' of github.com:SF-N/gt4py into tra…
SF-N Aug 7, 2025
1a93a4b
Merge branch 'main' into transform_get_domain
SF-N Aug 7, 2025
e12d9ec
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Aug 7, 2025
891a97a
Add check on mesh order
SF-N Aug 7, 2025
958183b
Reformat
SF-N Aug 7, 2025
466ec0f
Reformat
SF-N Aug 7, 2025
0e4eb57
Rename get_domain to get_domain_range
tehrengruber Aug 18, 2025
9a878c7
Merge branch 'main' into transform_get_domain
SF-N Aug 18, 2025
08efc12
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Aug 18, 2025
d916337
Remove compile time args
tehrengruber Aug 18, 2025
25e24e9
Fix format
tehrengruber Aug 18, 2025
e856a18
Fix failing tests
tehrengruber Aug 18, 2025
00a11a6
Fix failing tests
tehrengruber Aug 18, 2025
575c7bf
Merge origin_tehrengruber/transform_get_domain
tehrengruber Aug 20, 2025
fd509c8
Merge remote-tracking branch 'origin/main' into get_domain_builtin
tehrengruber Aug 20, 2025
a3e722d
Merge branch 'get_domain_builtin' into transform_get_domain
tehrengruber Aug 20, 2025
97e47a2
Cleanup
tehrengruber Aug 20, 2025
564ca16
Cleanup tests
tehrengruber Aug 23, 2025
f784171
Merge transform_get_domain (test_trivial_shift_warning fails)
tehrengruber Aug 23, 2025
0aae1fb
Fix failing test
tehrengruber Aug 23, 2025
a35bfb5
Cleanup
tehrengruber Aug 23, 2025
7c9e270
Merge branch 'transform_get_domain' into infer_unstructured_domains_o…
tehrengruber Aug 23, 2025
325db75
Cleanup
tehrengruber Aug 23, 2025
f246582
Address review comments
tehrengruber Sep 2, 2025
53a5ac8
Remove implicit domain
tehrengruber Sep 2, 2025
9c7d9e2
Fix format
tehrengruber Sep 2, 2025
43e6af6
Merge remote-tracking branch 'origin/main' into get_domain_builtin
tehrengruber Sep 5, 2025
30f4b7b
Add ArgumentDescriptor mechanism, reworking static args and enabling …
tehrengruber Sep 11, 2025
c274a58
Merge origin_tehrengruber/get_domain_builtin
tehrengruber Sep 11, 2025
30cc046
Merge origin_tehrengruber/transform_get_domain
tehrengruber Sep 11, 2025
dcca182
Merge origin/main
tehrengruber Sep 11, 2025
8949042
Merge origin_sf_n/transform_get_domain
tehrengruber Sep 11, 2025
19f8f7a
Add argument descriptors and rework static args mechanism
tehrengruber Sep 11, 2025
1702aec
Cleanup
tehrengruber Sep 11, 2025
4fb4e81
Merge branch origin_tehrengruber/argument_descriptors
tehrengruber Sep 11, 2025
e5aac50
Cleanup
tehrengruber Sep 11, 2025
f83925d
Cleanup
tehrengruber Sep 11, 2025
2d00cb6
Cleanup
tehrengruber Sep 11, 2025
44773ea
Cleanup
tehrengruber Sep 11, 2025
4f45573
Cleanup
tehrengruber Sep 11, 2025
b01d6ee
Cleanup
tehrengruber Sep 11, 2025
fa9fa57
Cleanup
tehrengruber Sep 11, 2025
60e3b79
Cleanup
tehrengruber Sep 11, 2025
b694623
Cleanup
tehrengruber Sep 11, 2025
f7e131c
Cleanup
tehrengruber Sep 12, 2025
d7a09d8
Cleanup
tehrengruber Sep 12, 2025
2277bc1
Cleanup
tehrengruber Sep 12, 2025
5465a55
Merge branch 'transform_get_domain' into infer_unstructured_domains_o…
tehrengruber Sep 12, 2025
878bc82
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 12, 2025
86d19a7
Cleanup
tehrengruber Sep 12, 2025
ebca236
Improve docs add ADR
tehrengruber Sep 12, 2025
0c9bb89
Improve docs
tehrengruber Sep 12, 2025
a284ea7
Cleanup
tehrengruber Sep 12, 2025
c7c7361
Cleanup
tehrengruber Sep 12, 2025
efd331b
Fix failing test
tehrengruber Sep 12, 2025
dde023c
Merge remote-tracking branch 'origin/main' into argument_descriptors
tehrengruber Sep 12, 2025
07b968c
Fix failing test, abstract classmethod
tehrengruber Sep 12, 2025
1692f5c
Cleanup
tehrengruber Sep 12, 2025
adf89f2
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 12, 2025
7e3ec65
Cleanup
tehrengruber Sep 12, 2025
8b8b69f
Merge branch 'transform_get_domain' into infer_unstructured_domains_o…
tehrengruber Sep 12, 2025
6dfb037
Address review comments
tehrengruber Sep 21, 2025
de7a8bd
Address review comments
tehrengruber Sep 21, 2025
bfbab63
Merge branch 'main' into transform_get_domain
tehrengruber Sep 22, 2025
493daa1
Merge remote-tracking branch 'origin_sf_n/transform_get_domain' into …
tehrengruber Sep 22, 2025
907a7e1
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 22, 2025
1e90442
Cleanup
tehrengruber Sep 22, 2025
a99b06a
Cleanup
tehrengruber Sep 22, 2025
356f4c2
Merge remote-tracking branch 'origin/main' into transform_get_domain
tehrengruber Sep 22, 2025
e95ddc9
Merge branch 'transform_get_domain' into infer_unstructured_domains_o…
tehrengruber Sep 22, 2025
fe90bff
Cleanup
tehrengruber Sep 23, 2025
6395086
Fix fieldview transforms
tehrengruber Sep 23, 2025
d71ae22
Fix fieldview transforms
tehrengruber Sep 23, 2025
b38d3b1
Cleanup
tehrengruber Sep 23, 2025
ccdf312
Cleanup
tehrengruber Sep 23, 2025
bba5db2
Cleanup
tehrengruber Sep 24, 2025
d0297a6
Cleanup
tehrengruber Sep 24, 2025
67199ae
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 24, 2025
03bb958
Cleanup
tehrengruber Sep 25, 2025
ff9a48b
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 25, 2025
886df4d
Static domains for args of inhomogenous types
tehrengruber Sep 25, 2025
2fd1c2b
Fix non scan projector
tehrengruber Sep 25, 2025
1e91b65
Merge branch 'fix_non_scan_projector' into infer_unstructured_domains…
tehrengruber Sep 25, 2025
bc29ac3
Small fix
tehrengruber Sep 25, 2025
b8619a7
Fix tests
tehrengruber Sep 25, 2025
807a1fb
Fix format
tehrengruber Sep 25, 2025
c04d84d
Merge remote-tracking branch 'origin/main' into argument_descriptors
tehrengruber Sep 25, 2025
f5cf8a7
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 25, 2025
f9648b3
Make static domain translation configurable
tehrengruber Sep 26, 2025
cac49f6
Fix
tehrengruber Sep 26, 2025
1dc4198
Merge argument_descriptors
tehrengruber Sep 29, 2025
e3e012d
Merge remote-tracking branch 'origin/main' into argument_descriptors
tehrengruber Sep 29, 2025
1375c53
Merge branch 'argument_descriptors' into infer_unstructured_domains_o…
tehrengruber Sep 29, 2025
3cb901f
Improve oob error msg
tehrengruber Sep 30, 2025
44e02c7
Fix dace tests
tehrengruber Oct 7, 2025
d75a092
Fix failing tests
tehrengruber Oct 7, 2025
34691e4
Merge remote-tracking branch 'origin/main' into infer_unstructured_do…
tehrengruber Oct 8, 2025
7db7d24
Fix dace tests
tehrengruber Oct 8, 2025
0af5e95
Merge remote-tracking branch 'origin/main' into infer_unstructured_do…
tehrengruber Oct 8, 2025
7098ccf
Small fixes
tehrengruber Oct 8, 2025
1cf559c
Small fixes
tehrengruber Oct 8, 2025
8090168
Merge branch 'main' into infer_unstructured_domains_of_temporaries
tehrengruber Oct 8, 2025
38068a0
Merge remote-tracking branch 'origin/main' into infer_unstructured_do…
tehrengruber Oct 14, 2025
9a44302
Merge remote-tracking branch 'origin_sf_n/infer_unstructured_domains_…
tehrengruber Oct 14, 2025
1a65fb6
Merge remote-tracking branch 'origin/main' into infer_unstructured_do…
tehrengruber Oct 17, 2025
0e05fa9
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Oct 20, 2025
0348c1a
Merge branch 'main' into infer_unstructured_domains_of_temporaries
SF-N Oct 21, 2025
381435f
Merge remote-tracking branch 'origin/main' into infer_unstructured_do…
tehrengruber Nov 12, 2025
e9171d3
Merge remote-tracking branch 'origin/main' into infer_unstructured_do…
tehrengruber Nov 13, 2025
6b26323
Fix format
tehrengruber Nov 13, 2025
39312c2
Fix broken merge
tehrengruber Nov 13, 2025
f1efb3f
Fix format
tehrengruber Nov 13, 2025
888a2a9
Fix failing test in dace
tehrengruber Nov 19, 2025
86556af
Fix format
tehrengruber Nov 19, 2025
52e67d9
Address review comments
tehrengruber Nov 26, 2025
d7bf7a7
Address review comments
tehrengruber Nov 26, 2025
cdeff43
Address review comments
tehrengruber Nov 26, 2025
83b0e88
Address review comments
tehrengruber Nov 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def reset_sequence(self, start: int = 1, *, warn_unsafe: Optional[bool] = None)
if warn_unsafe is None:
warn_unsafe = self.warn_unsafe
if warn_unsafe and start < next(self._counter):
warnings.warn("Unsafe reset of UIDGenerator ({self})", stacklevel=2)
warnings.warn(f"Unsafe reset of UIDGenerator ({self})", stacklevel=2)
self._counter = itertools.count(start)

return self
Expand Down
36 changes: 30 additions & 6 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@
DEFAULT_BACKEND: next_backend.Backend | None = None


def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) -> list[str]:
static_domain_args = []
param_types = func_type.pos_or_kw_args | func_type.kw_only_args
for name, type_ in param_types.items():
for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True):
if isinstance(el_type_, ts.FieldType):
path_as_expr = "".join(map(lambda idx: f"[{idx}]", path))
static_domain_args.append(f"{name}{path_as_expr}")
return static_domain_args


# TODO(tehrengruber): Decide if and how programs can call other programs. As a
# result Program could become a GTCallable.
@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -86,6 +97,7 @@ class Program:
static_params: (
Sequence[str] | None
) # if the user requests static params, they will be used later to initialize CompiledPrograms
static_domains: bool

@classmethod
def from_function(
Expand All @@ -95,6 +107,7 @@ def from_function(
grid_type: common.GridType | None = None,
enable_jit: bool | None = None,
static_params: Sequence[str] | None = None,
static_domains: bool = False,
connectivities: Optional[
common.OffsetProvider
] = None, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
Expand All @@ -106,6 +119,7 @@ def from_function(
connectivities=connectivities,
enable_jit=enable_jit,
static_params=static_params,
static_domains=static_domains,
)

# TODO(ricoh): linting should become optional, up to the backend.
Expand Down Expand Up @@ -170,20 +184,26 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool:
if self.backend is None or self.backend == eve.NOTHING:
raise RuntimeError("Cannot compile a program without backend.")

if self.static_params is None:
object.__setattr__(self, "static_params", ())
argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {}

argument_descriptor_mapping = {
arguments.StaticArg: self.static_params,
}
if self.static_params:
argument_descriptor_mapping[arguments.StaticArg] = self.static_params

if self.static_domains:
assert isinstance(self.past_stage.past_node.type, ts_ffront.ProgramType)
argument_descriptor_mapping[arguments.FieldDomainDescriptor] = (
_field_domain_descriptor_mapping_from_func_type(
self.past_stage.past_node.type.definition
)
)

program_type = self.past_stage.past_node.type
assert isinstance(program_type, ts_ffront.ProgramType)
return compiled_program.CompiledProgramsPool(
backend=self.backend,
definition_stage=self.definition_stage,
program_type=program_type,
argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible
argument_descriptor_mapping=argument_descriptor_mapping,
)

def with_backend(self, backend: next_backend.Backend) -> Program:
Expand Down Expand Up @@ -529,6 +549,7 @@ def program(
grid_type: common.GridType | None,
enable_jit: bool | None,
static_params: Sequence[str] | None,
static_domains: bool,
frozen: bool,
) -> Callable[[types.FunctionType], Program]: ...

Expand All @@ -541,6 +562,7 @@ def program(
grid_type: common.GridType | None = None,
enable_jit: bool | None = None, # only relevant if static_params are set
static_params: Sequence[str] | None = None,
static_domains: bool = False,
frozen: bool = False,
) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]:
"""
Expand Down Expand Up @@ -569,6 +591,7 @@ def program_inner(definition: types.FunctionType) -> Program:
),
grid_type=grid_type,
enable_jit=enable_jit,
static_domains=static_domains,
static_params=static_params,
)
if frozen:
Expand Down Expand Up @@ -703,6 +726,7 @@ def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program:
connectivities=None,
enable_jit=False, # TODO(havogt): revisit ProgramFromPast
static_params=None, # TODO(havogt): revisit ProgramFromPast
static_domains=False, # TODO(havogt): revisit ProgramFromPast
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand Down
22 changes: 17 additions & 5 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import devtools

from gt4py.eve import NodeTranslator, traits
from gt4py.next import common, config, errors, utils as gtx_utils
from gt4py.next import common, config, errors, utils
from gt4py.next.ffront import (
fbuiltins,
gtcallable,
Expand All @@ -28,7 +28,7 @@
from gt4py.next.ffront.stages import AOT_PRG
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import remap_symbols
from gt4py.next.iterator.transforms import remap_symbols, replace_get_domain_range_with_constants
from gt4py.next.otf import arguments, stages, workflow
from gt4py.next.type_system import type_info, type_specifications as ts

Expand Down Expand Up @@ -102,14 +102,14 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram:
static_arg_descriptors = inp.args.argument_descriptor_contexts[arguments.StaticArg]
if not all(
isinstance(arg_descriptor, arguments.StaticArg)
or all(el is None for el in gtx_utils.flatten_nested_tuple(arg_descriptor)) # type: ignore[arg-type]
or all(el is None for el in utils.flatten_nested_tuple(arg_descriptor)) # type: ignore[arg-type]
for arg_descriptor in static_arg_descriptors.values()
):
raise NotImplementedError("Only top-level arguments can be static.")
static_args = {
name: im.literal_from_tuple_value(descr.value) # type: ignore[union-attr] # type checked above
for name, descr in static_arg_descriptors.items()
if not any(el is None for el in gtx_utils.flatten_nested_tuple(descr)) # type: ignore[arg-type]
if not any(el is None for el in utils.flatten_nested_tuple(descr)) # type: ignore[arg-type]
}
body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args)
itir_program = itir.Program(
Expand All @@ -120,6 +120,18 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram:
body=body,
)

# TODO(tehrengruber): Put this in a dedicated transformation step.
if context := inp.args.argument_descriptor_contexts.get(arguments.FieldDomainDescriptor, None):
field_domains = {
param: utils.tree_map(lambda x: x.domain if x is not None else x)(v)
for param, v in context.items()
}
itir_program = (
replace_get_domain_range_with_constants.ReplaceGetDomainRangeWithConstants.apply(
itir_program, sizes=field_domains
)
)

# Translate NamedCollectionTypes to TupleTypes in compile-time args
args = tuple(ffront_ti.named_collections_to_tuple_types(arg) for arg in inp.args.args)
kwargs: dict[str, ts.TypeSpec] = {
Expand Down Expand Up @@ -443,7 +455,7 @@ def _visit_stencil_call_out_arg(
"Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node."
)

@gtx_utils.tree_map(
@utils.tree_map(
collection_type=ts.COLLECTION_TYPE_SPECS,
with_path_arg=True,
unpack=True,
Expand Down
132 changes: 93 additions & 39 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,24 @@

import dataclasses
import functools
from typing import Any, Callable, Iterable, Literal, Mapping, Optional
import warnings
from typing import Callable, Iterable, Literal, Optional

import numpy as np

from gt4py.next import common
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.transforms import trace_shifts
from gt4py.next.iterator.transforms import collapse_tuple, trace_shifts
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding


def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]:
"""
Extract horizontal domain sizes from an `offset_provider`.
#: Threshold fraction of domain points which may be added to a domain on translation in order
#: to have a contiguous domain before a warning is raised.
_NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD: float = 1 / 4

Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum
value inside the neighbor table to get the size of each `codomain`.
"""
sizes = dict[str, int]()
for provider in offset_provider.values():
if common.is_neighbor_connectivity(provider):
conn_type = provider.__gt_type__()
sizes[conn_type.source_dim.value] = max(
sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0]
)
sizes[conn_type.codomain.value] = max(
sizes.get(conn_type.codomain.value, 0),
provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject
)
return sizes
#: Offset tags for which a non-contiguous domain warning has already been printed
_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS: set[str] = set()


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -68,6 +58,78 @@ def empty(self) -> bool | None:
}


def _unstructured_translate_range_statically(
range_: SymbolicRange,
tag: str,
val: itir.OffsetLiteral
| Literal[trace_shifts.Sentinel.VALUE, trace_shifts.Sentinel.ALL_NEIGHBORS],
offset_provider: common.OffsetProvider,
expr: itir.Expr | None = None,
) -> SymbolicRange:
"""
Translate `range_` using static connectivity information from `offset_provider`.
"""
assert common.is_offset_provider(offset_provider)
connectivity = offset_provider[tag]
assert isinstance(connectivity, common.Connectivity)
skip_value = connectivity.skip_value

# fold & convert expr into actual integers
start_expr, stop_expr = range_.start, range_.stop
start_expr, stop_expr = ( # type: ignore[assignment] # mypy not smart enough
collapse_tuple.CollapseTuple.apply(
expr,
within_stencil=False,
allow_undeclared_symbols=True,
)
for expr in (start_expr, stop_expr)
)
assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal)
start, stop = int(start_expr.value), int(stop_expr.value)

nb_index: slice | int
if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]:
nb_index = slice(None)
else:
nb_index = val.value # type: ignore[assignment] # assert above

accessed = connectivity.ndarray[start:stop, nb_index]

if isinstance(val, itir.OffsetLiteral) and np.any(accessed == skip_value):
# TODO(tehrengruber): Turn this into a configurable error. This is currently
# not possible since some test cases starting from ITIR containing
# `can_deref` might lead here. The frontend never emits such IR and domain
# inference runs after we transform reductions into stmts containing
# `can_deref`.
warnings.warn(
UserWarning(f"Translating '{expr}' using '{tag}' has an out-of-bounds access."),
stacklevel=2,
)

new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject

fraction_accessed = np.unique(accessed).size / (new_stop - new_start) # type: ignore[call-overload] # TODO(havogt): improve typing for NDArrayObject

if fraction_accessed < _NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD and (
tag not in _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS
):
_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS.add(tag)
warnings.warn(
UserWarning(
f"Translating '{expr}' using '{tag}' requires "
f"computations on many additional points "
f"({round((1 - fraction_accessed) * 100)}%) in order to get a contiguous "
f"domain. Please consider reordering your mesh."
),
stacklevel=2,
)

return SymbolicRange(
im.literal(str(new_start), builtins.INTEGER_INDEX_BUILTIN),
im.literal(str(new_stop), builtins.INTEGER_INDEX_BUILTIN),
)


@dataclasses.dataclass(frozen=True)
class SymbolicDomain:
grid_type: common.GridType
Expand Down Expand Up @@ -114,7 +176,7 @@ def translate(
offset_provider: common.OffsetProvider | common.OffsetProviderType,
#: A dictionary mapping axes names to their length. See
#: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details.
symbolic_domain_sizes: Optional[dict[str, str]] = None,
symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None,
) -> SymbolicDomain:
offset_provider_type = common.offset_provider_to_type(offset_provider)

Expand Down Expand Up @@ -144,28 +206,20 @@ def translate(
trace_shifts.Sentinel.ALL_NEIGHBORS,
trace_shifts.Sentinel.VALUE,
]
horizontal_sizes: dict[str, itir.Expr]
if symbolic_domain_sizes is not None:
horizontal_sizes = {
k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items()
}
else:
# note: ugly but cheap re-computation, but should disappear
assert common.is_offset_provider(offset_provider)
horizontal_sizes = {
k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN)
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
}

old_dim = connectivity_type.source_dim
new_dim = connectivity_type.codomain

assert new_dim not in new_ranges or old_dim == new_dim
if symbolic_domain_sizes is not None and new_dim.value in symbolic_domain_sizes:
new_range = SymbolicRange(
im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN),
im.ensure_expr(symbolic_domain_sizes[new_dim.value]),
)
else:
assert common.is_offset_provider(offset_provider)
new_range = _unstructured_translate_range_statically(
new_ranges[old_dim], off.value, val, offset_provider, self.as_expr()
)

new_range = SymbolicRange(
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
horizontal_sizes[new_dim.value],
)
new_ranges = dict(
(dim, range_) if dim != old_dim else (new_dim, new_range)
for dim, range_ in new_ranges.items()
Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall:

def domain(
grid_type: Union[common.GridType, str],
ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]],
ranges_or_domain: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] | common.Domain,
) -> itir.FunCall:
"""
>>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL)
Expand All @@ -463,6 +463,13 @@ def domain(
>>> str(domain(common.GridType.UNSTRUCTURED, {IDim: (0, 10), JDim: (0, 20)}))
'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩'
"""
if isinstance(ranges_or_domain, common.Domain):
domain = ranges_or_domain
ranges = {d: (r.start, r.stop) for d, r in zip(domain.dims, domain.ranges)}
else:
assert isinstance(ranges_or_domain, dict)
ranges = ranges_or_domain

if isinstance(grid_type, common.GridType):
grid_type = f"{grid_type!s}_domain"
expr = call(grid_type)(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Transformation(enum.Flag):
# `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)`
FOLD_MIN_MAX = enum.auto()

# `maximum(a + 1), a)` -> `a + 1`
# `maximum(a + 1, a)` -> `a + 1`
# `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)`
FOLD_MIN_MAX_PLUS = enum.auto()

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def create_global_tmps(
offset_provider: common.OffsetProvider | common.OffsetProviderType,
#: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for
#: more details.
symbolic_domain_sizes: Optional[dict[str, str]] = None,
symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None,
*,
uids: Optional[eve_utils.UIDGenerator] = None,
) -> itir.Program:
Expand Down
Loading