diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 3bc042fbe0..619425ef8e 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -1263,11 +1263,21 @@ class GridType(StrEnum): UNSTRUCTURED = "unstructured" +def check_staggered(dim: Dimension) -> bool: + return dim.value.startswith(_STAGGERED_PREFIX) + + def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]: """Find the canonical ordering of the dimensions in `dims`.""" if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") - return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value)) + return sorted( + dims, + key=lambda dim: ( + _DIM_KIND_ORDER[dim.kind], + flip_staggered(dim).value if check_staggered(dim) else dim.value, + ), + ) def check_dims(dims: Sequence[Dimension]) -> None: @@ -1355,3 +1365,22 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call #: Equivalent to the `_FillValue` attribute in the UGRID Conventions #: (see: http://ugrid-conventions.github.io/ugrid-conventions/). _DEFAULT_SKIP_VALUE: Final[int] = -1 +_STAGGERED_PREFIX = "_Staggered" + + +def flip_staggered(dim: Dimension) -> Dimension: + if dim.value.startswith(_STAGGERED_PREFIX): + return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind) + else: + return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind) + + +def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity: + if isinstance(offset, float): + integral_offset, half = divmod(offset, 1) + assert half == 0.5 + if dim.value.startswith(_STAGGERED_PREFIX): + integral_offset += 1 + return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim)) + else: + return CartesianConnectivity(dim, offset, codomain=dim) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index cf1d5faec1..fedffb2620 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -64,7 +64,7 @@ def env_flag_to_int(name: str, default: int) -> int: #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=True) #: Verbose flag for DSL compilation errors diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 68bf108a0a..5a72d60131 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -10,7 +10,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next import errors, utils +from gt4py.next import common, errors, utils from gt4py.next.common import DimensionKind, promote_dims from gt4py.next.ffront import ( # noqa dialect_ast_enums, @@ -655,13 +655,20 @@ def _deduce_compare_type( def _deduce_binop_type( self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: - # e.g. `IDim+1` + # e.g. `IDim+1` or `IDim+0.5` if ( isinstance(left.type, ts.DimensionType) and isinstance(right.type, ts.ScalarType) - and type_info.is_integral(right.type) + and type_info.is_arithmetic(right.type) ): - return ts.OffsetType(source=left.type.dim, target=(left.type.dim,)) + if not isinstance(right, foast.Constant): + raise NotImplementedError() + offset_index = right.value + if node.op == dialect_ast_enums.BinaryOperator.SUB: + offset_index *= -1 + conn = common.connectivity_for_cartesian_shift(left.type.dim, offset_index) + return ts.OffsetType(source=conn.codomain, target=(conn.domain_dim,)) + if isinstance(left.type, ts.OffsetType): raise errors.DSLError( node.location, f"Type '{left.type}' can not be used in operator '{node.op}'." diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4b29f02b41..36ec3892dc 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -305,17 +305,19 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: # `field(Dim + idx)` case foast.BinOp( op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs + left=foast.Name(), # TODO(tehrengruber): use type of lhs right=foast.Constant(value=offset_index), ): if arg.op == dialect_ast_enums.BinaryOperator.SUB: offset_index *= -1 - # TODO(havogt): we rely on the naming-convention for implicit offsets, see `dimension_to_implicit_offset` + conn = common.connectivity_for_cartesian_shift( + node.args[0].left.type.dim, offset_index + ) current_expr = im.as_fieldop( im.lambda_("__it")( im.deref( im.shift( - common.dimension_to_implicit_offset(dimension), offset_index + im.cartesian_offset(conn.domain_dim, conn.codomain), conn.offset )("__it") ) ) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 79ccc83cd2..d29b5370e7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -81,10 +81,6 @@ def __str__(self): InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE") -class OffsetLiteral(Expr): - value: Union[int, str] - - class AxisLiteral(Expr): # TODO(havogt): Refactor to use declare Axis/Dimension at the Program level. # Now every use of the literal has to provide the kind, where usually we only care of the name. @@ -92,6 +88,16 @@ class AxisLiteral(Expr): kind: common.DimensionKind = common.DimensionKind.HORIZONTAL +class CartesianOffset(Expr): + domain: AxisLiteral + codomain: AxisLiteral + + +# TODO(tehrengruber): allow int only and create OffsetRef for str instead +class OffsetLiteral(Expr): + value: Union[int, str] + + class SymRef(Expr): id: Coerced[SymbolRef] @@ -150,8 +156,9 @@ class Program(Node, ValidatedSymbolTableTrait): Expr.__hash__ = Node.__hash__ # type: ignore[method-assign] Literal.__hash__ = Node.__hash__ # type: ignore[method-assign] NoneLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] -OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] AxisLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] +OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign] +CartesianOffset.__hash__ = Node.__hash__ # type: ignore[method-assign] SymRef.__hash__ = Node.__hash__ # type: ignore[method-assign] Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 3fa088d785..ea41e03550 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -124,54 +124,62 @@ def translate( return self if len(shift) == 2: off, val = shift - assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str) - connectivity_type = common.get_offset_type(offset_provider_type, off.value) - - if isinstance(connectivity_type, common.Dimension): - if val is trace_shifts.Sentinel.VALUE: - raise NotImplementedError("Dynamic offsets not supported.") - assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) - current_dim = connectivity_type - # cartesian offset - new_ranges[current_dim] = SymbolicRange.translate( - self.ranges[current_dim], val.value - ) - elif isinstance(connectivity_type, common.NeighborConnectivityType): - # unstructured shift - assert ( - isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) - ) or val in [ - trace_shifts.Sentinel.ALL_NEIGHBORS, - trace_shifts.Sentinel.VALUE, - ] - horizontal_sizes: dict[str, itir.Expr] - if symbolic_domain_sizes is not None: - horizontal_sizes = { - k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items() - } - else: - # note: ugly but cheap re-computation, but should disappear - assert common.is_offset_provider(offset_provider) - horizontal_sizes = { - k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) - for k, v in _max_domain_sizes_by_location_type(offset_provider).items() - } - - old_dim = connectivity_type.source_dim - new_dim = connectivity_type.codomain - - assert new_dim not in new_ranges or old_dim == new_dim - - new_range = SymbolicRange( - im.literal("0", builtins.INTEGER_INDEX_BUILTIN), - horizontal_sizes[new_dim.value], - ) + if isinstance(off, itir.CartesianOffset): + old_dim = common.Dimension(value=off.domain.value, kind=off.domain.kind) + new_dim = common.Dimension(value=off.codomain.value, kind=off.codomain.kind) + new_range = SymbolicRange.translate(self.ranges[old_dim], val.value) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) for dim, range_ in new_ranges.items() ) else: - raise AssertionError() + assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str) + connectivity_type = common.get_offset_type(offset_provider_type, off.value) + if isinstance(connectivity_type, common.Dimension): + if val is trace_shifts.Sentinel.VALUE: + raise NotImplementedError("Dynamic offsets not supported.") + assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) + current_dim = connectivity_type + # cartesian offset + new_ranges[current_dim] = SymbolicRange.translate( + self.ranges[current_dim], val.value + ) + elif isinstance(connectivity_type, common.NeighborConnectivityType): + # unstructured shift + assert ( + isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) + ) or val in [ + trace_shifts.Sentinel.ALL_NEIGHBORS, + trace_shifts.Sentinel.VALUE, + ] + horizontal_sizes: dict[str, itir.Expr] + if symbolic_domain_sizes is not None: + horizontal_sizes = { + k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items() + } + else: + # note: ugly but cheap re-computation, but should disappear + assert common.is_offset_provider(offset_provider) + horizontal_sizes = { + k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) + for k, v in _max_domain_sizes_by_location_type(offset_provider).items() + } + + old_dim = connectivity_type.source_dim + new_dim = connectivity_type.codomain + + assert new_dim not in new_ranges or old_dim == new_dim + + new_range = SymbolicRange( + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), + horizontal_sizes[new_dim.value], + ) + new_ranges = dict( + (dim, range_) if dim != old_dim else (new_dim, new_range) + for dim, range_ in new_ranges.items() + ) + else: + raise AssertionError() return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index fefca65a62..33f11c4d54 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -578,6 +578,10 @@ def axis_literal(dim: common.Dimension) -> itir.AxisLiteral: return itir.AxisLiteral(value=dim.value, kind=dim.kind) +def cartesian_offset(domain: common.Dimension, codomain: common.Dimension): + return itir.CartesianOffset(domain=axis_literal(domain), codomain=axis_literal(codomain)) + + def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): """ Promotes the function `cast_` to a field_operator. diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 5063e26392..c827641560 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -143,6 +143,11 @@ def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[ def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] + def visit_CartesianOffset(self, node: ir.CartesianOffset, *, prec: int) -> list[str]: + (domain,) = self.visit(node.domain, prec=0) + (codomain,) = self.visit(node.codomain, prec=0) + return [f"{domain}₂{codomain}"] + def visit_AxisLiteral(self, node: ir.AxisLiteral, *, prec: int) -> list[str]: kind = "" if node.kind == ir.DimensionKind.HORIZONTAL: diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 8173ceebbb..b198bd36a4 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -137,7 +137,8 @@ def _can_deref(x): def _shift(*offsets): assert all( - isinstance(offset, ir.OffsetLiteral) or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE] + isinstance(offset, (ir.OffsetLiteral, ir.CartesianOffset)) + or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE] for offset in offsets ) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 2bcb991849..5c7ef84e75 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -473,6 +473,15 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.Offse assert isinstance(node.value, str) return it_ts.OffsetLiteralType(value=node.value) + def visit_CartesianOffset( + self, node: itir.CartesianOffset, *, ctx + ) -> it_ts.CartesianOffsetType: + self.visit(node.domain, ctx=ctx) + self.visit(node.codomain, ctx=ctx) + return it_ts.CartesianOffsetType( + domain=node.domain.type.dim, codomain=node.codomain.type.dim + ) + def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType: assert isinstance(node.type, ts.ScalarType) return node.type diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 39e9e607ce..40c6dde6d7 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -20,6 +20,11 @@ class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | str +class CartesianOffsetType(ts.TypeSpec): + domain: common.Dimension + codomain: common.Dimension + + class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6b9c4341e4..2de0d838fe 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -16,7 +16,6 @@ from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common, utils -from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts @@ -458,7 +457,7 @@ def _canonicalize_nb_fields( def _resolve_dimensions( input_dims: list[common.Dimension], - shift_tuple: tuple[itir.OffsetLiteral, ...], + shift_tuple: tuple[itir.OffsetLiteral | itir.CartesianOffset, ...], offset_provider_type: common.OffsetProviderType, ) -> list[common.Dimension]: """ @@ -486,14 +485,25 @@ def _resolve_dimensions( >>> Edge = common.Dimension(value="Edge") >>> Vertex = common.Dimension(value="Vertex") + >>> Cell = common.Dimension(value="Cell") >>> K = common.Dimension(value="K", kind=common.DimensionKind.VERTICAL) >>> V2E = common.Dimension(value="V2E") + >>> C2V = common.Dimension(value="C2V") >>> input_dims = [Edge, K] >>> shift_tuple = ( + ... itir.OffsetLiteral(value="C2V"), + ... itir.OffsetLiteral(value=0), ... itir.OffsetLiteral(value="V2E"), ... itir.OffsetLiteral(value=0), ... ) >>> offset_provider_type = { + ... "C2V": common.NeighborConnectivityType( + ... domain=(Cell, C2V), + ... codomain=Vertex, + ... skip_value=None, + ... dtype=None, + ... max_neighbors=3, + ... ), ... "V2E": common.NeighborConnectivityType( ... domain=(Vertex, V2E), ... codomain=Edge, @@ -504,21 +514,53 @@ def _resolve_dimensions( ... "KOff": K, ... } >>> _resolve_dimensions(input_dims, shift_tuple, offset_provider_type) - [Dimension(value='Vertex', kind=), Dimension(value='K', kind=)] + [Dimension(value='Cell', kind=), Dimension(value='K', kind=)] + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = common.Dimension(value="IDim") + >>> IHalfDim = common.flip_staggered(IDim) + >>> JDim = common.Dimension(value="JDim") + >>> JHalfDim = common.flip_staggered(JDim) + >>> input_dims = [IDim, JDim] + >>> shift_tuple = ( + ... itir.CartesianOffset( + ... domain=im.axis_literal(IDim), codomain=im.axis_literal(IHalfDim) + ... ), + ... itir.OffsetLiteral(value=0), + ... itir.CartesianOffset(domain=im.axis_literal(JDim), codomain=im.axis_literal(IDim)), + ... itir.OffsetLiteral(value=0), + ... itir.CartesianOffset( + ... domain=im.axis_literal(IHalfDim), codomain=im.axis_literal(JDim) + ... ), + ... itir.OffsetLiteral(value=0), + ... ) + >>> _resolve_dimensions(input_dims, shift_tuple, offset_provider_type) + [Dimension(value='JDim', kind=), Dimension(value='IDim', kind=)] + """ resolved_dims = [] for input_dim in input_dims: + resolved_dim = input_dim for off_literal in reversed( shift_tuple[::2] - ): # Only OffsetLiterals are processed, located at even indices in shift_tuple. Shifts are applied in reverse order: the last shift in the tuple is applied first. - assert isinstance(off_literal.value, str) - offset_type = common.get_offset_type(offset_provider_type, off_literal.value) - if isinstance(offset_type, common.Dimension) and input_dim == offset_type: - continue # No shift applied - if isinstance(offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType)): - if input_dim == offset_type.codomain: # Check if input fits to offset - input_dim = offset_type.domain[0] # Update input_dim for next iteration - resolved_dims.append(input_dim) + ): # Only OffsetLiterals/CartesianOffsets are processed, located at even indices in shift_tuple. Shifts are applied in reverse order: the last shift in the tuple is applied first. + if isinstance(off_literal, itir.CartesianOffset): + if resolved_dim == common.Dimension( + value=off_literal.codomain.value, kind=off_literal.codomain.kind + ): + resolved_dim = common.Dimension( + value=off_literal.domain.value, kind=off_literal.domain.kind + ) + else: + assert isinstance(off_literal, itir.OffsetLiteral) and isinstance( + off_literal.value, str + ) + offset_type = common.get_offset_type(offset_provider_type, off_literal.value) + if isinstance(offset_type, common.Dimension) and resolved_dim == offset_type: + continue # No shift applied + if isinstance(offset_type, common.NeighborConnectivityType): + if resolved_dim == offset_type.codomain: # Check if input fits to offset + resolved_dim = offset_type.domain[0] # Update input_dim for next iteration + resolved_dims.append(resolved_dim) return resolved_dims @@ -661,20 +703,27 @@ def apply_shift( new_position_dims = [*it.position_dims] assert len(offset_literals) % 2 == 0 for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): - assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( - offset_axis.value, str - ) - type_ = common.get_offset_type(offset_provider_type, offset_axis.value) - if isinstance(type_, common.Dimension): - pass - elif isinstance(type_, common.NeighborConnectivityType): + if isinstance(offset_axis, it_ts.CartesianOffsetType): found = False for i, dim in enumerate(new_position_dims): - if dim.value == type_.source_dim.value: + if dim == offset_axis.domain: assert not found - new_position_dims[i] = type_.codomain + new_position_dims[i] = offset_axis.codomain found = True assert found + elif isinstance(offset_axis, it_ts.OffsetLiteralType): + assert isinstance(offset_axis.value, str) + type_ = common.get_offset_type(offset_provider_type, offset_axis.value) + if isinstance(type_, common.Dimension): + pass + elif isinstance(type_, common.NeighborConnectivityType): + found = False + for i, dim in enumerate(new_position_dims): + if dim.value == type_.source_dim.value: + assert not found + new_position_dims[i] = type_.codomain + found = True + assert found else: raise NotImplementedError(f"{type_} is not a supported Connectivity type.") else: diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 041868b00e..f08ef9067d 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -206,7 +206,9 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | Tuple: source_buffer=name, dimensions=[ DimensionSpec( - name=dim.value, + name=dim.value + if not common.check_staggered(dim) + else common.flip_staggered(dim).value, static_stride=1 if ( config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index a7062f2e1c..9ad77d357f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -92,10 +92,31 @@ def _name_from_named_range(named_range_call: itir.FunCall) -> str: return named_range_call.args[0].value +class FlipStaggeredDims(eve.NodeTranslator): + def flip_to_nonstaggered(self, axis_literal: itir.AxisLiteral) -> itir.AxisLiteral: + dim = common.Dimension(value=axis_literal.value, kind=axis_literal.kind) + return im.axis_literal( + common.flip_staggered(dim) if dim.value.startswith(common._STAGGERED_PREFIX) else dim + ) + + def visit_CartesianOffset(self, node: itir.CartesianOffset) -> itir.CartesianOffset: + return itir.CartesianOffset( + domain=self.flip_to_nonstaggered(node.domain), + codomain=self.flip_to_nonstaggered(node.codomain), + ) + + def visit_AxisLiteral(self, node: itir.AxisLiteral) -> itir.AxisLiteral: + return self.flip_to_nonstaggered(node) + + +def flip_staggered_offsets(node: itir.Node) -> itir.Node: + return FlipStaggeredDims().visit(node) + + def _collect_dimensions_from_domain( body: Iterable[itir.Stmt], ) -> dict[str, TagDefinition]: - domains = _get_domains(body) + domains = flip_staggered_offsets(_get_domains(body)) offset_definitions = {} for domain in domains: if domain.fun == itir.SymRef(id="cartesian_domain"): @@ -132,6 +153,17 @@ def _collect_offset_definitions( grid_type: common.GridType, offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: + offset_definitions = {} + + expr = flip_staggered_offsets(node.body[0].expr) + cartesian_offset_tags: set[str] = set() + for v in expr.walk_values(): + if isinstance(v, itir.CartesianOffset): + cartesian_offset_tags.add(v.domain.value) + cartesian_offset_tags.add(v.codomain.value) + for offset_name in cartesian_offset_tags: + offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) + used_offset_tags: set[str] = ( node.walk_values() .if_isinstance(itir.OffsetLiteral) @@ -143,7 +175,6 @@ def _collect_offset_definitions( offset_name: common.get_offset_type(offset_provider_type, offset_name) for offset_name in used_offset_tags } | {**offset_provider_type} - offset_definitions = {} for offset_name, dim_or_connectivity_type in offset_provider_type.items(): if isinstance(dim_or_connectivity_type, common.Dimension): @@ -373,8 +404,14 @@ def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> Literal: def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> OffsetLiteral: return OffsetLiteral(value=node.value) + def visit_CartesianOffset(self, node: itir.CartesianOffset, **kwargs: Any) -> Literal: + return self.visit(node.codomain, **kwargs) + def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs: Any) -> Literal: - return Literal(value=node.value, type="axis_literal") + dim = node.type.dim + if common.check_staggered(dim): + dim = common.flip_staggered(dim) + return Literal(value=dim.value, type="axis_literal") def _make_domain(self, node: itir.FunCall) -> tuple[TaggedValues, TaggedValues]: tags = [] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 5d9ef9b397..b65d81e8bf 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -48,8 +48,10 @@ E2VDim, Edge, IDim, + IHalfDim, Ioff, JDim, + JHalfDim, Joff, KDim, KHalfDim, @@ -67,6 +69,7 @@ # mypy does not accept [IDim, ...] as a type IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] +IHalfField: TypeAlias = gtx.Field[[IHalfDim], np.int32] # type: ignore [valid-type] JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] @@ -494,6 +497,7 @@ def verify_with_default_data( case: Case, fieldop: decorator.FieldOperator, ref: Callable, + offset_provider: Optional[OffsetProvider] = None, comparison: Callable[[Any, Any], bool] = tree_mapped_np_allclose, ) -> None: """ @@ -508,6 +512,8 @@ def verify_with_default_data( fieldview_prog: The field operator or program to be verified. ref: A callable which will be called with all the input arguments of the fieldview code, after applying ``.ndarray`` on the fields. + offset_provider: An override for the test case's offset_provider. + Use with care! comparison: A comparison function, which will be called as ``comparison(ref, )`` and should return a boolean. """ @@ -521,7 +527,7 @@ def verify_with_default_data( *inps, **kwfields, ref=ref(*ref_args), - offset_provider=case.offset_provider, + offset_provider=offset_provider, comparison=comparison, ) @@ -724,7 +730,9 @@ def from_cartesian_grid_descriptor( IDim: grid_descriptor.sizes[0], JDim: grid_descriptor.sizes[1], KDim: grid_descriptor.sizes[2], - KHalfDim: grid_descriptor.sizes[3], + IHalfDim: grid_descriptor.sizes[0]-1, + JHalfDim: grid_descriptor.sizes[1]-1, + KHalfDim: grid_descriptor.sizes[2]-1, }, grid_type=common.GridType.CARTESIAN, allocator=allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 7640553e6a..02181594c5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -133,9 +133,11 @@ def debug_itir(tree): DType = TypeVar("DType") IDim = gtx.Dimension("IDim") +IHalfDim = common.flip_staggered(IDim) JDim = gtx.Dimension("JDim") +JHalfDim = common.flip_staggered(JDim) KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) -KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) +KHalfDim = common.flip_staggered(KDim) Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) Joff = gtx.FieldOffset("Joff", source=JDim, target=(JDim,)) Koff = gtx.FieldOffset("Koff", source=KDim, target=(KDim,)) @@ -172,17 +174,16 @@ def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_cartesian_grid( - sizes: int | tuple[int, int, int, int] = (5, 7, 9, 11), + sizes: int | tuple[int, int, int, int] = (5, 7, 9), ) -> CartesianGridDescriptor: if isinstance(sizes, int): - sizes = (sizes,) * 4 - assert len(sizes) == 4, "sizes must be a tuple of four integers" + sizes = (sizes,) * 3 + assert len(sizes) == 3, "sizes must be a tuple of three integers" offset_provider = { "Ioff": IDim, "Joff": JDim, "Koff": KDim, - "KHalfoff": KHalfDim, } return types.SimpleNamespace( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py new file mode 100644 index 0000000000..759779f037 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_staggered.py @@ -0,0 +1,144 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import functools +import math +from functools import reduce +from typing import TypeAlias + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next import ( + astype, + broadcast, + common, + errors, + float32, + float64, + int32, + int64, + minimum, + neighbor_sum, + utils as gt_utils, +) +from gt4py.next.ffront.experimental import as_offset + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import ( + C2E, + E2V, + V2E, + E2VDim, + Edge, + IDim, + IHalfDim, + Ioff, + JDim, + KDim, + Koff, + V2EDim, + Vertex, + cartesian_case, + unstructured_case, + unstructured_case_3d, +) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, + mesh_descriptor, +) + + +@pytest.mark.uses_cartesian_shift +def test_copy_half_field(cartesian_case): + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_shift_plus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IField: + return a(IDim + 1) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a[1:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_plus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IHalfField: + return a(IHalfDim + 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (-1, 0)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_back(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + return a(IDim + 0.5)(IHalfDim - 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a, offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_plus1(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IHalfField: + return a(IHalfDim + 1) # always pass an IHalf-index to an IHalfField + + a = cases.allocate(cartesian_case, testee, "a")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out[:-1], ref=a[1:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_minus(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IField) -> cases.IHalfField: + return a(IHalfDim - 0.5) # always pass an I-index to an IField + + a = cases.allocate(cartesian_case, testee, "a").extend({IDim: (0, -1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=a[:], offset_provider={}) + + +@pytest.mark.uses_cartesian_shift +def test_cartesian_half_shift_half2center(cartesian_case): + # TODO: center inlining probably doesn't work + @gtx.field_operator + def testee(a: cases.IHalfField) -> cases.IField: + return 2 * a(IDim + 0.5) # always pass an IHalf-index to an IHalfField + + a = cases.allocate(cartesian_case, testee, "a").extend({IHalfDim: (0, 1)})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + cases.verify(cartesian_case, testee, a, out=out, ref=2 * a[:], offset_provider={}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py index dd30caa726..e5a4fb3be7 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_multiple_output_domains.py @@ -17,6 +17,7 @@ IDim, JDim, KDim, + KHalfDim, C2E, E2V, V2E, @@ -37,7 +38,6 @@ mesh_descriptor, ) -KHalfDim = gtx.Dimension("KHalf", kind=gtx.DimensionKind.VERTICAL) pytestmark = pytest.mark.uses_cartesian_shift diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 21177d0aea..97ef301e38 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -131,9 +131,7 @@ def foo(inp: gtx.Field[[TDim], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.as_fieldop( - im.lambda_("__it")( - im.deref(im.shift(common.dimension_to_implicit_offset(TDim.value), 1)("__it")) - ) + im.lambda_("__it")(im.deref(im.shift(im.cartesian_offset(TDim, TDim), 1)("__it"))) )("inp") assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py index ff7a761c5a..c0bea0d592 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -14,7 +14,6 @@ from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") -field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) def test_inline_dynamic_shift_as_fieldop_arg():