diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index a8be1ba881..f48039d47c 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -914,7 +914,7 @@ def reset_sequence(self, start: int = 1, *, warn_unsafe: Optional[bool] = None) if warn_unsafe is None: warn_unsafe = self.warn_unsafe if warn_unsafe and start < next(self._counter): - warnings.warn("Unsafe reset of UIDGenerator ({self})", stacklevel=2) + warnings.warn(f"Unsafe reset of UIDGenerator ({self})", stacklevel=2) self._counter = itertools.count(start) return self diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c23fff9a9a..31dc3aa5f7 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -54,6 +54,17 @@ DEFAULT_BACKEND: next_backend.Backend | None = None +def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) -> list[str]: + static_domain_args = [] + param_types = func_type.pos_or_kw_args | func_type.kw_only_args + for name, type_ in param_types.items(): + for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): + if isinstance(el_type_, ts.FieldType): + path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) + static_domain_args.append(f"{name}{path_as_expr}") + return static_domain_args + + # TODO(tehrengruber): Decide if and how programs can call other programs. As a # result Program could become a GTCallable. @dataclasses.dataclass(frozen=True) @@ -86,6 +97,7 @@ class Program: static_params: ( Sequence[str] | None ) # if the user requests static params, they will be used later to initialize CompiledPrograms + static_domains: bool @classmethod def from_function( @@ -95,6 +107,7 @@ def from_function( grid_type: common.GridType | None = None, enable_jit: bool | None = None, static_params: Sequence[str] | None = None, + static_domains: bool = False, connectivities: Optional[ common.OffsetProvider ] = None, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information @@ -106,6 +119,7 @@ def from_function( connectivities=connectivities, enable_jit=enable_jit, static_params=static_params, + static_domains=static_domains, ) # TODO(ricoh): linting should become optional, up to the backend. @@ -170,12 +184,18 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: if self.backend is None or self.backend == eve.NOTHING: raise RuntimeError("Cannot compile a program without backend.") - if self.static_params is None: - object.__setattr__(self, "static_params", ()) + argument_descriptor_mapping: dict[type[arguments.ArgStaticDescriptor], Sequence[str]] = {} - argument_descriptor_mapping = { - arguments.StaticArg: self.static_params, - } + if self.static_params: + argument_descriptor_mapping[arguments.StaticArg] = self.static_params + + if self.static_domains: + assert isinstance(self.past_stage.past_node.type, ts_ffront.ProgramType) + argument_descriptor_mapping[arguments.FieldDomainDescriptor] = ( + _field_domain_descriptor_mapping_from_func_type( + self.past_stage.past_node.type.definition + ) + ) program_type = self.past_stage.past_node.type assert isinstance(program_type, ts_ffront.ProgramType) @@ -183,7 +203,7 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, - argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible + argument_descriptor_mapping=argument_descriptor_mapping, ) def with_backend(self, backend: next_backend.Backend) -> Program: @@ -529,6 +549,7 @@ def program( grid_type: common.GridType | None, enable_jit: bool | None, static_params: Sequence[str] | None, + static_domains: bool, frozen: bool, ) -> Callable[[types.FunctionType], Program]: ... @@ -541,6 +562,7 @@ def program( grid_type: common.GridType | None = None, enable_jit: bool | None = None, # only relevant if static_params are set static_params: Sequence[str] | None = None, + static_domains: bool = False, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -569,6 +591,7 @@ def program_inner(definition: types.FunctionType) -> Program: ), grid_type=grid_type, enable_jit=enable_jit, + static_domains=static_domains, static_params=static_params, ) if frozen: @@ -703,6 +726,7 @@ def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program: connectivities=None, enable_jit=False, # TODO(havogt): revisit ProgramFromPast static_params=None, # TODO(havogt): revisit ProgramFromPast + static_domains=False, # TODO(havogt): revisit ProgramFromPast ) def __call__(self, *args: Any, **kwargs: Any) -> Any: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4b76589ae3..bd3ee6c2b4 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -15,7 +15,7 @@ import devtools from gt4py.eve import NodeTranslator, traits -from gt4py.next import common, config, errors, utils as gtx_utils +from gt4py.next import common, config, errors, utils from gt4py.next.ffront import ( fbuiltins, gtcallable, @@ -28,7 +28,7 @@ from gt4py.next.ffront.stages import AOT_PRG from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import remap_symbols +from gt4py.next.iterator.transforms import remap_symbols, replace_get_domain_range_with_constants from gt4py.next.otf import arguments, stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts @@ -102,14 +102,14 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: static_arg_descriptors = inp.args.argument_descriptor_contexts[arguments.StaticArg] if not all( isinstance(arg_descriptor, arguments.StaticArg) - or all(el is None for el in gtx_utils.flatten_nested_tuple(arg_descriptor)) # type: ignore[arg-type] + or all(el is None for el in utils.flatten_nested_tuple(arg_descriptor)) # type: ignore[arg-type] for arg_descriptor in static_arg_descriptors.values() ): raise NotImplementedError("Only top-level arguments can be static.") static_args = { name: im.literal_from_tuple_value(descr.value) # type: ignore[union-attr] # type checked above for name, descr in static_arg_descriptors.items() - if not any(el is None for el in gtx_utils.flatten_nested_tuple(descr)) # type: ignore[arg-type] + if not any(el is None for el in utils.flatten_nested_tuple(descr)) # type: ignore[arg-type] } body = remap_symbols.RemapSymbolRefs().visit(itir_program.body, symbol_map=static_args) itir_program = itir.Program( @@ -120,6 +120,18 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: body=body, ) + # TODO(tehrengruber): Put this in a dedicated transformation step. + if context := inp.args.argument_descriptor_contexts.get(arguments.FieldDomainDescriptor, None): + field_domains = { + param: utils.tree_map(lambda x: x.domain if x is not None else x)(v) + for param, v in context.items() + } + itir_program = ( + replace_get_domain_range_with_constants.ReplaceGetDomainRangeWithConstants.apply( + itir_program, sizes=field_domains + ) + ) + # Translate NamedCollectionTypes to TupleTypes in compile-time args args = tuple(ffront_ti.named_collections_to_tuple_types(arg) for arg in inp.args.args) kwargs: dict[str, ts.TypeSpec] = { @@ -443,7 +455,7 @@ def _visit_stencil_call_out_arg( "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." ) - @gtx_utils.tree_map( + @utils.tree_map( collection_type=ts.COLLECTION_TYPE_SPECS, with_path_arg=True, unpack=True, diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 3fa088d785..835b6c2f2f 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,34 +10,24 @@ import dataclasses import functools -from typing import Any, Callable, Iterable, Literal, Mapping, Optional +import warnings +from typing import Callable, Iterable, Literal, Optional + +import numpy as np from gt4py.next import common from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms import collapse_tuple, trace_shifts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: - """ - Extract horizontal domain sizes from an `offset_provider`. +#: Threshold fraction of domain points which may be added to a domain on translation in order +#: to have a contiguous domain before a warning is raised. +_NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD: float = 1 / 4 - Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum - value inside the neighbor table to get the size of each `codomain`. - """ - sizes = dict[str, int]() - for provider in offset_provider.values(): - if common.is_neighbor_connectivity(provider): - conn_type = provider.__gt_type__() - sizes[conn_type.source_dim.value] = max( - sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] - ) - sizes[conn_type.codomain.value] = max( - sizes.get(conn_type.codomain.value, 0), - provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject - ) - return sizes +#: Offset tags for which a non-contiguous domain warning has already been printed +_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS: set[str] = set() @dataclasses.dataclass(frozen=True) @@ -68,6 +58,78 @@ def empty(self) -> bool | None: } +def _unstructured_translate_range_statically( + range_: SymbolicRange, + tag: str, + val: itir.OffsetLiteral + | Literal[trace_shifts.Sentinel.VALUE, trace_shifts.Sentinel.ALL_NEIGHBORS], + offset_provider: common.OffsetProvider, + expr: itir.Expr | None = None, +) -> SymbolicRange: + """ + Translate `range_` using static connectivity information from `offset_provider`. + """ + assert common.is_offset_provider(offset_provider) + connectivity = offset_provider[tag] + assert isinstance(connectivity, common.Connectivity) + skip_value = connectivity.skip_value + + # fold & convert expr into actual integers + start_expr, stop_expr = range_.start, range_.stop + start_expr, stop_expr = ( # type: ignore[assignment] # mypy not smart enough + collapse_tuple.CollapseTuple.apply( + expr, + within_stencil=False, + allow_undeclared_symbols=True, + ) + for expr in (start_expr, stop_expr) + ) + assert isinstance(start_expr, itir.Literal) and isinstance(stop_expr, itir.Literal) + start, stop = int(start_expr.value), int(stop_expr.value) + + nb_index: slice | int + if val in [trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE]: + nb_index = slice(None) + else: + nb_index = val.value # type: ignore[assignment] # assert above + + accessed = connectivity.ndarray[start:stop, nb_index] + + if isinstance(val, itir.OffsetLiteral) and np.any(accessed == skip_value): + # TODO(tehrengruber): Turn this into a configurable error. This is currently + # not possible since some test cases starting from ITIR containing + # `can_deref` might lead here. The frontend never emits such IR and domain + # inference runs after we transform reductions into stmts containing + # `can_deref`. + warnings.warn( + UserWarning(f"Translating '{expr}' using '{tag}' has an out-of-bounds access."), + stacklevel=2, + ) + + new_start, new_stop = accessed.min(), accessed.max() + 1 # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + + fraction_accessed = np.unique(accessed).size / (new_stop - new_start) # type: ignore[call-overload] # TODO(havogt): improve typing for NDArrayObject + + if fraction_accessed < _NON_CONTIGUOUS_DOMAIN_WARNING_THRESHOLD and ( + tag not in _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS + ): + _NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS.add(tag) + warnings.warn( + UserWarning( + f"Translating '{expr}' using '{tag}' requires " + f"computations on many additional points " + f"({round((1 - fraction_accessed) * 100)}%) in order to get a contiguous " + f"domain. Please consider reordering your mesh." + ), + stacklevel=2, + ) + + return SymbolicRange( + im.literal(str(new_start), builtins.INTEGER_INDEX_BUILTIN), + im.literal(str(new_stop), builtins.INTEGER_INDEX_BUILTIN), + ) + + @dataclasses.dataclass(frozen=True) class SymbolicDomain: grid_type: common.GridType @@ -114,7 +176,7 @@ def translate( offset_provider: common.OffsetProvider | common.OffsetProviderType, #: A dictionary mapping axes names to their length. See #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, ) -> SymbolicDomain: offset_provider_type = common.offset_provider_to_type(offset_provider) @@ -144,28 +206,20 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - horizontal_sizes: dict[str, itir.Expr] - if symbolic_domain_sizes is not None: - horizontal_sizes = { - k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items() - } - else: - # note: ugly but cheap re-computation, but should disappear - assert common.is_offset_provider(offset_provider) - horizontal_sizes = { - k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) - for k, v in _max_domain_sizes_by_location_type(offset_provider).items() - } - old_dim = connectivity_type.source_dim new_dim = connectivity_type.codomain - assert new_dim not in new_ranges or old_dim == new_dim + if symbolic_domain_sizes is not None and new_dim.value in symbolic_domain_sizes: + new_range = SymbolicRange( + im.literal(str(0), builtins.INTEGER_INDEX_BUILTIN), + im.ensure_expr(symbolic_domain_sizes[new_dim.value]), + ) + else: + assert common.is_offset_provider(offset_provider) + new_range = _unstructured_translate_range_statically( + new_ranges[old_dim], off.value, val, offset_provider, self.as_expr() + ) - new_range = SymbolicRange( - im.literal("0", builtins.INTEGER_INDEX_BUILTIN), - horizontal_sizes[new_dim.value], - ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) for dim, range_ in new_ranges.items() diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index fefca65a62..53474bbcff 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -453,7 +453,7 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]], + ranges_or_domain: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] | common.Domain, ) -> itir.FunCall: """ >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -463,6 +463,13 @@ def domain( >>> str(domain(common.GridType.UNSTRUCTURED, {IDim: (0, 10), JDim: (0, 20)})) 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ + if isinstance(ranges_or_domain, common.Domain): + domain = ranges_or_domain + ranges = {d: (r.start, r.stop) for d, r in zip(domain.dims, domain.ranges)} + else: + assert isinstance(ranges_or_domain, dict) + ranges = ranges_or_domain + if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" expr = call(grid_type)( diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index f9269314fb..0b9321ef33 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -77,7 +77,7 @@ class Transformation(enum.Flag): # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` FOLD_MIN_MAX = enum.auto() - # `maximum(a + 1), a)` -> `a + 1` + # `maximum(a + 1, a)` -> `a + 1` # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` FOLD_MIN_MAX_PLUS = enum.auto() diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d4a6543aa3..3bd734fa96 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -312,7 +312,7 @@ def create_global_tmps( offset_provider: common.OffsetProvider | common.OffsetProviderType, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, *, uids: Optional[eve_utils.UIDGenerator] = None, ) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index d77ef9f096..f4dfd4096e 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -55,7 +55,7 @@ class DomainAccessDescriptor(eve.StrEnum): class InferenceOptions(typing.TypedDict): offset_provider: common.OffsetProvider | common.OffsetProviderType - symbolic_domain_sizes: Optional[dict[str, str]] + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] allow_uninferred: bool keep_existing_domains: bool @@ -126,7 +126,7 @@ def _extract_accessed_domains( input_ids: list[str], target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]], + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], ) -> dict[str, NonTupleDomainAccess]: accessed_domains: dict[str, NonTupleDomainAccess] = {} @@ -182,7 +182,7 @@ def _infer_as_fieldop( target_domain: DomainAccess, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]], + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], allow_uninferred: bool, keep_existing_domains: bool, ) -> tuple[itir.FunCall, AccessedDomains]: @@ -441,7 +441,7 @@ def infer_expr( domain: DomainAccess, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, allow_uninferred: bool = False, keep_existing_domains: bool = False, ) -> tuple[itir.Expr, AccessedDomains]: @@ -457,7 +457,7 @@ def infer_expr( Keyword Arguments: - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol - name that evaluates to the length of that axis. + name or expression that evaluates to the length of that axis. - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. because of a dynamic shift) or never accessed. - keep_existing_domains: If `True`, keep existing domains in `as_fieldop` expressions and @@ -557,7 +557,7 @@ def infer_program( program: itir.Program, *, offset_provider: common.OffsetProvider | common.OffsetProviderType, - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, allow_uninferred: bool = False, keep_existing_domains: bool = False, ) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b334ad796d..22bf93d7f2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -5,12 +5,13 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +import warnings from typing import Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import ( concat_where, dead_code_elimination, @@ -23,6 +24,7 @@ inline_lifts, prune_empty_concat_where, remove_broadcast, + symbol_ref_utils, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -43,13 +45,97 @@ def __call__( ) -> itir.Program: ... +def _max_domain_range_sizes(offset_provider: common.OffsetProvider) -> dict[str, itir.Literal]: + """ + Extract horizontal domain sizes from an `offset_provider`. + + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. + """ + sizes: dict[str, int] = {} + for provider in offset_provider.values(): + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] + ) + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + int(provider.ndarray.max()) + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + ) + sizes_exprs = {k: im.literal_from_value(v) for k, v in sizes.items()} + return sizes_exprs + + +def _has_dynamic_domains(ir: itir.Program) -> bool: + # note: this function does not respect symbol collisions with builtins. As it is a temporary + # workaround we don't care about this corner case. + domains = set() + domains |= ir.walk_values().if_isinstance(itir.SetAt).getattr("domain").to_set() + for as_fop in ( + ir.walk_values() + .if_isinstance(itir.FunCall) + .filter(lambda node: cpm.is_call_to(node, "as_fieldop") and len(node.args) == 2) + ): + domains.add(as_fop.args[1]) + return len(symbol_ref_utils.collect_symbol_refs(domains)) > 0 + + +def _process_symbolic_domains_option( + ir: itir.Program, + offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]], + use_max_domain_range_on_unstructured_shift: Optional[bool], +) -> Optional[dict[str, str | itir.Expr]]: + """ + Given a program, offset_provider and some configuration options determine how domains are + inferred. + + The output of this function is used as `symbolic_domain_sizes` argument of domain inference, i.e. + :func:`infer_domain.infer_program`. + + Right now domains of `as_fieldop` expressions can be inferred either a) using static information + from the offset provider, or b) they are set to an expression controlled by + the user and configured in the backend, or c) they are set to the maximum possible domain / + everywhere (see :func:`_max_domain_range_sizes`) + + Option a) applies when the program is decorated with `static_domains = True` (unless option c) + is explicitly requested). Then all dynamic domains were replaced with static ones + which we recognize here. The domain inference then uses this static information which we + communicate by returning `None`, i.e. no symbolic domain sizes. + Option b) applies when the user explicitly configured `symbolic_domain_sizes` in the backend. + In that case we just forward the value. + Option c) applies when `static_domains = False` or when explicitly configured in the backend + with `use_max_domain_range_on_unstructured_shift = True`. In that case we determine the + maximum sizes using :func:`_max_domain_range_sizes` and return them. + """ + if symbolic_domain_sizes: + assert not use_max_domain_range_on_unstructured_shift, "Options are mutually exclusive." + return symbolic_domain_sizes + + has_dynamic_domains = _has_dynamic_domains(ir) + if has_dynamic_domains and use_max_domain_range_on_unstructured_shift is None: + use_max_domain_range_on_unstructured_shift = True + else: + use_max_domain_range_on_unstructured_shift = False + if use_max_domain_range_on_unstructured_shift: + if not has_dynamic_domains: + warnings.warn( + "You are using static domains together with " + "'use_max_domain_range_on_unstructured_shift'. This is" + "likely not what you wanted.", + stacklevel=2, + ) + assert not symbolic_domain_sizes, "Options are mutually exclusive." + symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] + return symbolic_domain_sizes + + # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( ir: itir.Program, *, - # TODO(havogt): should be replaced by `common.OffsetProviderType`, but global_tmps currently - # relies on runtime info or `symbolic_domain_sizes`. offset_provider: common.OffsetProvider | common.OffsetProviderType, extract_temporaries=False, unroll_reduce=False, @@ -57,12 +143,22 @@ def apply_common_transforms( force_inline_lambda_args=False, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. - symbolic_domain_sizes: Optional[dict[str, str]] = None, + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None, + # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins + # to work with / translate domains. + use_max_domain_range_on_unstructured_shift: Optional[bool] = None, ) -> itir.Program: assert isinstance(ir, itir.Program) + # TODO(tehrengruber): Allow `common.OffsetProviderType`, but domain inference currently + # relies on static information or `symbolic_domain_sizes`. + assert common.is_offset_provider(offset_provider) offset_provider_type = common.offset_provider_to_type(offset_provider) + symbolic_domain_sizes = _process_symbolic_domains_option( + ir, offset_provider, symbolic_domain_sizes, use_max_domain_range_on_unstructured_shift + ) + tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() collapse_tuple_uids = eve_utils.UIDGenerator() @@ -170,10 +266,19 @@ def apply_common_transforms( def apply_fieldview_transforms( - ir: itir.Program, *, offset_provider: common.OffsetProvider + ir: itir.Program, + *, + offset_provider: common.OffsetProvider, + # TODO(tehrengruber): Remove this option again as soon as we have the necessary builtins + # to work with / translate domains. + use_max_domain_range_on_unstructured_shift: Optional[bool] = None, ) -> itir.Program: offset_provider_type = common.offset_provider_to_type(offset_provider) + symbolic_domain_sizes = _process_symbolic_domains_option( + ir, offset_provider, None, use_max_domain_range_on_unstructured_shift + ) + ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) # required for dead-code-elimination and `prune_empty_concat_where` pass @@ -187,7 +292,11 @@ def apply_fieldview_transforms( ir = concat_where.canonicalize_domain_argument(ir) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program( + ir, + symbolic_domain_sizes=symbolic_domain_sizes, + offset_provider=offset_provider, + ) ir = prune_empty_concat_where.prune_empty_concat_where(ir) ir = remove_broadcast.RemoveBroadcast.apply(ir) return ir diff --git a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py b/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py similarity index 86% rename from src/gt4py/next/iterator/transforms/transform_get_domain_range.py rename to src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py index c34ba61a28..f228c24998 100644 --- a/src/gt4py/next/iterator/transforms/transform_get_domain_range.py +++ b/src/gt4py/next/iterator/transforms/replace_get_domain_range_with_constants.py @@ -7,10 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import Dict from gt4py._core import definitions as core_defs from gt4py.eve import NodeTranslator, PreserveLocationVisitor +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( @@ -28,7 +28,7 @@ def visit_Node(self, node: itir.Node, **kwargs): return None # means we could not deduce the domain def visit_SymRef( - self, node: itir.SymRef, *, sizes: Dict[str, common.Domain], **kwargs + self, node: itir.SymRef, *, sizes: dict[str, MaybeNestedInTuple[common.Domain]], **kwargs ) -> DomainOrTupleThereof | None: return sizes.get(node.id, None) @@ -48,9 +48,9 @@ def visit_FunCall(self, node, **kwargs): @dataclasses.dataclass(frozen=True) -class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): +class ReplaceGetDomainRangeWithConstants(PreserveLocationVisitor, NodeTranslator): """ - Transforms `get_domain` calls into a tuple containing start and stop. + Replace `get_domain` calls into a tuple containing start and stop. Example: >>> from gt4py import next as gtx @@ -86,7 +86,7 @@ class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): ... ), ... ], ... ) - >>> result = TransformGetDomainRange.apply(ir, sizes=sizes) + >>> result = ReplaceGetDomainRangeWithConstants.apply(ir, sizes=sizes) >>> print(result) test(inp, out) { out @ u⟨ Vertexₕ: [{0, 10}[0], {0, 10}[1][, KDimᵥ: [{0, 20}[0], {0, 20}[1][ ⟩ ← (⇑deref)(inp); @@ -94,7 +94,9 @@ class TransformGetDomainRange(PreserveLocationVisitor, NodeTranslator): """ @classmethod - def apply(cls, program: itir.Program, sizes: Dict[str, common.Domain]): + def apply( + cls, program: itir.Program, sizes: dict[str, MaybeNestedInTuple[common.Domain | None]] + ): return cls().visit(program, sizes=sizes) def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: @@ -115,7 +117,4 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.FunCall: index = next((i for i, d in enumerate(domain.dims) if d.value == dim.value), None) assert index is not None, f"Dimension {dim.value} not found in {domain.dims}" - start = domain.ranges[index].start - stop = domain.ranges[index].stop - node = im.make_tuple(start, stop) - return node + return im.make_tuple(domain.ranges[index].start, domain.ranges[index].stop) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index d7ba35eed0..e87009c564 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -10,7 +10,7 @@ from collections import Counter import gt4py.eve as eve -from gt4py.eve.extended_typing import Iterable, Literal, Optional, Sequence, cast, overload +from gt4py.eve.extended_typing import Iterable, Literal, Optional, cast, overload from gt4py.next.iterator import ir as itir @@ -22,7 +22,7 @@ class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): @classmethod def apply( cls, - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -33,7 +33,7 @@ def apply( @classmethod def apply( cls, - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -43,7 +43,7 @@ def apply( @classmethod def apply( cls, - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -101,7 +101,7 @@ def visit_Lambda(self, node: itir.Lambda, *, inactive_refs: set[str]): @overload def collect_symbol_refs( - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -111,7 +111,7 @@ def collect_symbol_refs( @overload def collect_symbol_refs( - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, @@ -120,7 +120,7 @@ def collect_symbol_refs( def collect_symbol_refs( - node: itir.Node | Sequence[itir.Node], + node: itir.Node | Iterable[itir.Node], symbol_names: Optional[Iterable[str]] = None, *, ignore_builtins: bool = True, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index ca54d12c88..dafba7ef03 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -111,6 +111,15 @@ def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: return {"value": arg_expr} +@dataclasses.dataclass(frozen=True) +class FieldDomainDescriptor(ArgStaticDescriptor): + domain: common.Domain + + @classmethod + def attribute_extractor_exprs(cls, arg_expr: str) -> dict[str, str]: + return {"domain": f"({arg_expr}).domain"} + + @dataclasses.dataclass(frozen=True) class JITArgs: """Concrete (runtime) arguments to a GTX program in a format that can be passed into the toolchain.""" diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index d873e7da17..7c3307b486 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -362,7 +362,7 @@ def _initialize_argument_descriptor_mapping( self._validate_argument_descriptor_mapping() else: for descr_cls, descriptor_expr_mapping in argument_descriptors.items(): - if (expected := set(self.argument_descriptor_mapping[descr_cls])) != ( + if (expected := set(self.argument_descriptor_mapping.get(descr_cls, {}))) != ( got := set(descriptor_expr_mapping.keys()) ): raise ValueError( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 0c76757d70..63d9416bc5 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -52,7 +52,8 @@ class GTFNTranslationStep( enable_itir_transforms: bool = True use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str]] = None + symbolic_domain_sizes: Optional[dict[str, str | itir.Expr]] = None + use_max_domain_range_on_unstructured_shift: Optional[bool] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -163,6 +164,7 @@ def _preprocess_program( extract_temporaries=True, offset_provider=offset_provider, symbolic_domain_sizes=self.symbolic_domain_sizes, + use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, ) new_program = apply_common_transforms( diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index e5a556eb4f..18d15e21f4 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -354,6 +354,7 @@ class DaCeTranslator( disable_itir_transforms: bool = False disable_field_origin_on_program_arguments: bool = False + use_max_domain_range_on_unstructured_shift: Optional[bool] = None def generate_sdfg( self, @@ -370,7 +371,11 @@ def _generate_sdfg_without_configuring_dace( column_axis: Optional[common.Dimension], ) -> dace.SDFG: if not self.disable_itir_transforms: - ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + ir = itir_transforms.apply_fieldview_transforms( + ir, + use_max_domain_range_on_unstructured_shift=self.use_max_domain_range_on_unstructured_shift, + offset_provider=offset_provider, + ) offset_provider_type = common.offset_provider_to_type(offset_provider) on_gpu = self.device_type != core_defs.DeviceType.CPU diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index a9374c6a24..2e1e5ee93a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +from typing import Optional from unittest import mock import numpy as np @@ -33,6 +33,7 @@ skip_value_mesh, ) +from gt4py.next.otf import arguments _raise_on_compile = mock.Mock() _raise_on_compile.compile.side_effect = AssertionError("This function should never be called.") @@ -849,3 +850,54 @@ def test_wait_for_compilation(cartesian_case, compile_testee, compile_testee_dom gtx.wait_for_compilation() # ... and afterwards compilation still works compile_testee_domain.compile(offset_provider=cartesian_case.offset_provider) + + +def test_compile_variants_decorator_static_domains(compile_variants_field_operator, cartesian_case): + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + captured_cargs: Optional[arguments.CompileTimeArgs] = None + + class CaptureCompileTimeArgsBackend: + def __getattr__(self, name): + return getattr(cartesian_case.backend, name) + + def compile(self, program, compile_time_args): + nonlocal captured_cargs + captured_cargs = compile_time_args + + return cartesian_case.backend.compile(program, compile_time_args) + + @gtx.field_operator + def identity_like(inp: tuple[cases.IField, cases.IField, float]): + return inp[0], inp[1] + + # the float argument here is merely to test that static domains work for tuple arguments + # of inhomogeneous types + @gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) + def testee( + inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField] + ): + identity_like(inp, out=out) + + inp = cases.allocate(cartesian_case, testee, "inp")() + out = cases.allocate(cartesian_case, testee, "out")() + + testee(inp, out, offset_provider={}) + assert np.allclose(inp[0].ndarray, out[0].ndarray) + assert np.allclose(inp[1].ndarray, out[1].ndarray) + + assert testee._compiled_programs.argument_descriptor_mapping[ + arguments.FieldDomainDescriptor + ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] + assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { + "inp": ( + arguments.FieldDomainDescriptor(inp[0].domain), + arguments.FieldDomainDescriptor(inp[1].domain), + None, + ), + "out": ( + arguments.FieldDomainDescriptor(out[0].domain), + arguments.FieldDomainDescriptor(out[1].domain), + ), + } diff --git a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py index 69a2ed772b..04e820cd27 100644 --- a/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/ir_utils_test.py/test_domain_utils.py @@ -7,13 +7,21 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest - +import numpy as np from gt4py.next import common +from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im +from gt4py.next import backend as next_backend, common, allocators as next_allocators, constructors I = common.Dimension("I") J = common.Dimension("J") +K = common.Dimension("J", kind=common.DimensionKind.VERTICAL) +Vertex = common.Dimension("Vertex") +Edge = common.Dimension("Edge") +V2EDim = common.Dimension("V2E", kind=common.DimensionKind.LOCAL) +E2VDim = common.Dimension("E2V", kind=common.DimensionKind.LOCAL) +V2VDim = common.Dimension("V2V", kind=common.DimensionKind.LOCAL) a_range = domain_utils.SymbolicRange(0, 10) another_range = domain_utils.SymbolicRange(5, 15) @@ -180,3 +188,99 @@ def test_is_finite_symbolic_domain(ranges, expected): ) == expected ) + + +@pytest.mark.parametrize( + "shift_chain, expected_end_domain", + [ + (("V2V", 0), {Vertex: (0, 4)}), + (("V2V", 1), {Vertex: (0, 4)}), + (("V2V", 2), {Vertex: (0, 1)}), + (("V2V", 3), {Vertex: (1, 4)}), + (("V2V", 0, "V2V", 3, "V2V", 0), {Vertex: (1, 4)}), + (("V2E", 0), {Edge: (0, 4)}), + (("V2E", 0, "E2V", 0), {Vertex: (0, 4)}), + (("V2V", 3, "V2E", 0), {Edge: (1, 4)}), + ], +) +def test_unstructured_translate(shift_chain, expected_end_domain): + offset_provider = { + "V2V": constructors.as_connectivity( + domain={Vertex: (0, 4), V2VDim: 5}, + codomain=Vertex, + data=np.asarray( + [[0, 3, 0, 1, -1], [1, 2, 0, 1, 1], [2, 1, 0, 3, 2], [3, 0, 0, 3, -1]], + dtype=fbuiltins.IndexType, + ), + ), + "V2E": constructors.as_connectivity( + domain={Vertex: (0, 4), V2EDim: 1}, + codomain=Edge, + data=np.asarray( + [ + [0, 1, 2, 3], + ], + dtype=fbuiltins.IndexType, + ).reshape((4, 1)), + ), + "E2V": constructors.as_connectivity( + domain={Edge: (0, 4), E2VDim: 1}, + codomain=Vertex, + data=np.asarray( + [ + [0, 1, 2, 3], + ], + dtype=fbuiltins.IndexType, + ).reshape((4, 1)), + ), + } + shift_chain = [im.ensure_offset(o) for o in shift_chain] + expected_end_domain = im.domain(common.GridType.UNSTRUCTURED, expected_end_domain) + + init_domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 4)}) + ) + end_domain = init_domain.translate(shift_chain, offset_provider).as_expr() + assert end_domain == expected_end_domain + + +def test_non_contiguous_domain_warning(monkeypatch): + monkeypatch.setattr(domain_utils, "_NON_CONTIGUOUS_DOMAIN_WARNING_SKIPPED_OFFSET_TAGS", set()) + + offset_provider = { + "V2V": constructors.as_connectivity( + domain={Vertex: (0, 100), V2VDim: 1}, + codomain=Vertex, + data=np.asarray([0] + [99] * 99, dtype=fbuiltins.IndexType).reshape((100, 1)), + ) + } + shift_chain = ("V2V", 0) + shift_chain = [im.ensure_offset(o) for o in shift_chain] + domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 2)}) + ) + with pytest.warns( + UserWarning, + match=r"98%.*Please consider reordering your mesh.", + ): + domain.translate(shift_chain, offset_provider).as_expr() + + +def test_oob_error(): + offset_provider = { + "V2V": constructors.as_connectivity( + domain={Vertex: (0, 3), V2VDim: 1}, + codomain=Vertex, + data=np.asarray([0, -1, 1], dtype=fbuiltins.IndexType).reshape((3, 1)), + ) + } + shift_chain = ("V2V", 0) + shift_chain = [im.ensure_offset(o) for o in shift_chain] + domain = domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.UNSTRUCTURED, {Vertex: (0, 3)}) + ) + with pytest.warns( + UserWarning, + match=r"out-of-bounds", + ): + domain.translate(shift_chain, offset_provider).as_expr() diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 0d9a55ceef..dd405ca7b0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -229,7 +229,8 @@ def test_multi_length_shift(offset_provider): def test_unstructured_shift(unstructured_offset_provider): stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) - expected_domains = {"in_field1": {Vertex: (0, 2)}} + accessed_vertex = unstructured_offset_provider["E2V"].ndarray[0, 1] + expected_domains = {"in_field1": {Vertex: (accessed_vertex, accessed_vertex + np.int32(1))}} testee, expected = setup_test_as_fieldop(stencil, domain, expected_domains=expected_domains) run_test_expr(testee, expected, domain, expected_domains, unstructured_offset_provider) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py similarity index 78% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py index a4864ff00e..322c04f5e5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_transform_get_domain_range.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_replace_get_domain_range_with_constants.py @@ -21,7 +21,9 @@ from gt4py.next import Domain, common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.transform_get_domain_range import TransformGetDomainRange +from gt4py.next.iterator.transforms.replace_get_domain_range_with_constants import ( + ReplaceGetDomainRangeWithConstants, +) from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -64,25 +66,24 @@ def run_test_program( params=params, body=[setat_factory(domain=domain, target=im.ref(target))], ) - actual = TransformGetDomainRange.apply(testee, sizes=sizes) + actual = ReplaceGetDomainRangeWithConstants.apply(testee, sizes=sizes) actual = CollapseTuple.apply( actual, enabled_transformations=CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE ) assert actual == expected -def domain_as_expr(domain: gtx.Domain) -> itir.Expr: - return im.domain( - common.GridType.UNSTRUCTURED, - {d: (r.start, r.stop) for d, r in zip(domain.dims, domain.ranges)}, - ) - - def test_get_domain(): sizes = {"out": gtx.domain({Vertex: (0, 10), KDim: (0, 20)})} get_domain_expr = im.get_field_domain(common.GridType.UNSTRUCTURED, "out", sizes["out"].dims) - run_test_program(["inp", "out"], sizes, "out", domain_as_expr(sizes["out"]), get_domain_expr) + run_test_program( + ["inp", "out"], + sizes, + "out", + im.domain(common.GridType.UNSTRUCTURED, sizes["out"]), + get_domain_expr, + ) def test_get_domain_tuples(): @@ -92,7 +93,13 @@ def test_get_domain_tuples(): common.GridType.UNSTRUCTURED, im.tuple_get(1, "out"), sizes["out"][1].dims ) - run_test_program(["inp", "out"], sizes, "out", domain_as_expr(sizes["out"][1]), get_domain_expr) + run_test_program( + ["inp", "out"], + sizes, + "out", + im.domain(common.GridType.UNSTRUCTURED, sizes["out"][1]), + get_domain_expr, + ) def test_get_domain_nested_tuples(): @@ -110,5 +117,9 @@ def test_get_domain_nested_tuples(): ) run_test_program( - ["inp", "a", "b", "c", "d"], sizes, "a", domain_as_expr(sizes["a"]), get_domain_expr + ["inp", "a", "b", "c", "d"], + sizes, + "a", + im.domain(common.GridType.UNSTRUCTURED, sizes["a"]), + get_domain_expr, ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 59b12b3f0a..f1f375af4a 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -9,9 +9,10 @@ import pytest from gt4py import eve, next as gtx -from gt4py.next import errors, backend -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.otf import compiled_program, toolchain, arguments +from gt4py.next import errors, backend, broadcast, common +from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import toolchain, arguments from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.runners import gtfn @@ -49,18 +50,16 @@ def test_sanitize_static_args_wrong_type(): @gtx.field_operator -def fop(cond: bool, a: gtx.Field[gtx.Dims[TDim], float], b: gtx.Field[gtx.Dims[TDim], float]): - return a if cond else b +def fop(cond: bool): + return broadcast(cond, (TDim,)) @gtx.program def prog( cond: bool, - a: gtx.Field[gtx.Dims[TDim], gtx.float64], - b: gtx.Field[gtx.Dims[TDim], gtx.float64], - out: gtx.Field[gtx.Dims[TDim], gtx.float64], + out: gtx.Field[gtx.Dims[TDim], bool], ): - fop(cond, a, b, out=out) + fop(cond, out=out) def _verify_program_has_expected_true_value(program: itir.Program): @@ -110,10 +109,35 @@ def pirate(program: toolchain.CompilableProgram): testee = prog.with_backend(hacked_gtfn_backend).compile(cond=[True], offset_provider={}) testee( cond=True, - a=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), - b=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), - out=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), + out=gtx.zeros(domain={TDim: 1}, dtype=bool), offset_provider={}, ) _verify_program_has_expected_true_value(hijacked_program.data) + + +def _verify_program_has_expected_domain(program: itir.Program, expected_domain: gtx.Domain): + assert isinstance(program.body[0], itir.SetAt) + assert isinstance(program.body[0].expr, itir.FunCall) + assert program.body[0].expr.fun == itir.SymRef(id="fop") + domain = CollapseTuple.apply(program.body[0].domain, within_stencil=False) + assert domain == im.domain(common.GridType.CARTESIAN, expected_domain) + + +def test_inlining_of_static_domain_works(): + domain = gtx.Domain(dims=(TDim,), ranges=(gtx.UnitRange(0, 1),)) + input_pair = toolchain.CompilableProgram( + data=prog.definition_stage, + args=arguments.CompileTimeArgs( + args=list(prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), + kwargs={}, + offset_provider={}, + column_axis=None, + argument_descriptor_contexts={ + arguments.FieldDomainDescriptor: {"out": arguments.FieldDomainDescriptor(domain)} + }, + ), + ) + + transformed = backend.DEFAULT_TRANSFORMS(input_pair).data + _verify_program_has_expected_domain(transformed, domain) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py index 9c3301645c..a31370baa1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_domain.py @@ -20,6 +20,7 @@ from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + Cell, KDim, Vertex, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index bd848a2154..a6f5e8486b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -22,6 +22,7 @@ from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import domain_utils, ir_makers as im from gt4py.next.iterator.transforms import infer_domain +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -128,7 +129,11 @@ def build_dace_sdfg( """ if not skip_domain_inference: # run domain inference in order to add the domain annex information to the IR nodes - ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + ir = infer_domain.infer_program( + ir, + offset_provider=offset_provider, + symbolic_domain_sizes=pass_manager._max_domain_range_sizes(offset_provider), + ) offset_provider_type = gtx_common.offset_provider_to_type(offset_provider) return dace_backend.build_sdfg_from_gtir(ir, offset_provider_type, column_axis=KDim) @@ -220,7 +225,7 @@ def test_gtir_copy_self(): body=[ gtir.SetAt( expr=gtir.SymRef(id="x"), - domain=im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, 2)}), + domain=im.domain(gtx_common.GridType.CARTESIAN, {IDim: (1, 2)}), target=gtir.SymRef(id="x"), ) ], @@ -414,7 +419,7 @@ def test_gtir_tuple_broadcast_scalar(): def test_gtir_zero_dim_fields(): domain = im.get_field_domain(gtx_common.GridType.CARTESIAN, "y", [IDim]) - empty_domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={}) + empty_domain = im.domain(gtx_common.GridType.CARTESIAN, {}) testee = gtir.Program( id="gtir_zero_dim_fields", function_definitions=[], @@ -1642,7 +1647,7 @@ def test_gtir_let_lambda(): def test_gtir_let_lambda_scalar_expression(): - domain_inner = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, "size_inner")}) + domain_inner = im.domain(gtx_common.GridType.CARTESIAN, {IDim: (1, "size_inner")}) domain_outer = im.get_field_domain( gtx_common.GridType.CARTESIAN, "y", @@ -2144,8 +2149,6 @@ def test_gtir_concat_where(): ], ) - # run domain inference in order to add the domain annex information to the concat_where node. - testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) c = np.empty_like(a) @@ -2225,8 +2228,6 @@ def test_gtir_concat_where_two_dimensions(): "__z_JDim_stride": d.strides[1] // d.itemsize, } - # run domain inference in order to add the domain annex information to the concat_where node. - testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, d, **field_symbols)