Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,11 +1263,21 @@ class GridType(StrEnum):
UNSTRUCTURED = "unstructured"


def check_staggered(dim: Dimension) -> bool:
return dim.value.startswith(_STAGGERED_PREFIX)


def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]:
"""Find the canonical ordering of the dimensions in `dims`."""
if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1:
raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")
return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value))
return sorted(
dims,
key=lambda dim: (
_DIM_KIND_ORDER[dim.kind],
flip_staggered(dim).value if check_staggered(dim) else dim.value,
),
)


def check_dims(dims: Sequence[Dimension]) -> None:
Expand Down Expand Up @@ -1355,3 +1365,22 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
_DEFAULT_SKIP_VALUE: Final[int] = -1
_STAGGERED_PREFIX = "_Staggered"


def flip_staggered(dim: Dimension) -> Dimension:
if dim.value.startswith(_STAGGERED_PREFIX):
return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind)
else:
return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind)


def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity:
if isinstance(offset, float):
integral_offset, half = divmod(offset, 1)
assert half == 0.5
if dim.value.startswith(_STAGGERED_PREFIX):
integral_offset += 1
return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim))
else:
return CartesianConnectivity(dim, offset, codomain=dim)
2 changes: 1 addition & 1 deletion src/gt4py/next/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def env_flag_to_int(name: str, default: int) -> int:
#: Master debug flag
#: Changes defaults for all the other options to be as helpful for debugging as possible.
#: Does not override values set in environment variables.
DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False)
DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=True)


#: Verbose flag for DSL compilation errors
Expand Down
15 changes: 11 additions & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
from gt4py.next import errors, utils
from gt4py.next import common, errors, utils
from gt4py.next.common import DimensionKind, promote_dims
from gt4py.next.ffront import ( # noqa
dialect_ast_enums,
Expand Down Expand Up @@ -655,13 +655,20 @@ def _deduce_compare_type(
def _deduce_binop_type(
self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# e.g. `IDim+1`
# e.g. `IDim+1` or `IDim+0.5`
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and type_info.is_integral(right.type)
and type_info.is_arithmetic(right.type)
):
return ts.OffsetType(source=left.type.dim, target=(left.type.dim,))
if not isinstance(right, foast.Constant):
raise NotImplementedError()
offset_index = right.value
if node.op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
conn = common.connectivity_for_cartesian_shift(left.type.dim, offset_index)
return ts.OffsetType(source=conn.codomain, target=(conn.domain_dim,))

if isinstance(left.type, ts.OffsetType):
raise errors.DSLError(
node.location, f"Type '{left.type}' can not be used in operator '{node.op}'."
Expand Down
8 changes: 5 additions & 3 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,19 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
# `field(Dim + idx)`
case foast.BinOp(
op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB,
left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs
left=foast.Name(), # TODO(tehrengruber): use type of lhs
right=foast.Constant(value=offset_index),
):
if arg.op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
# TODO(havogt): we rely on the naming-convention for implicit offsets, see `dimension_to_implicit_offset`
conn = common.connectivity_for_cartesian_shift(
node.args[0].left.type.dim, offset_index
)
current_expr = im.as_fieldop(
im.lambda_("__it")(
im.deref(
im.shift(
common.dimension_to_implicit_offset(dimension), offset_index
im.cartesian_offset(conn.domain_dim, conn.codomain), conn.offset
)("__it")
)
)
Expand Down
17 changes: 12 additions & 5 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,23 @@ def __str__(self):
InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE")


class OffsetLiteral(Expr):
value: Union[int, str]


class AxisLiteral(Expr):
# TODO(havogt): Refactor to use declare Axis/Dimension at the Program level.
# Now every use of the literal has to provide the kind, where usually we only care of the name.
value: str
kind: common.DimensionKind = common.DimensionKind.HORIZONTAL


class CartesianOffset(Expr):
domain: AxisLiteral
codomain: AxisLiteral


# TODO(tehrengruber): allow int only and create OffsetRef for str instead
class OffsetLiteral(Expr):
value: Union[int, str]


class SymRef(Expr):
id: Coerced[SymbolRef]

Expand Down Expand Up @@ -150,8 +156,9 @@ class Program(Node, ValidatedSymbolTableTrait):
Expr.__hash__ = Node.__hash__ # type: ignore[method-assign]
Literal.__hash__ = Node.__hash__ # type: ignore[method-assign]
NoneLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
AxisLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
CartesianOffset.__hash__ = Node.__hash__ # type: ignore[method-assign]
SymRef.__hash__ = Node.__hash__ # type: ignore[method-assign]
Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign]
FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign]
Expand Down
94 changes: 51 additions & 43 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,54 +124,62 @@ def translate(
return self
if len(shift) == 2:
off, val = shift
assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str)
connectivity_type = common.get_offset_type(offset_provider_type, off.value)

if isinstance(connectivity_type, common.Dimension):
if val is trace_shifts.Sentinel.VALUE:
raise NotImplementedError("Dynamic offsets not supported.")
assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
current_dim = connectivity_type
# cartesian offset
new_ranges[current_dim] = SymbolicRange.translate(
self.ranges[current_dim], val.value
)
elif isinstance(connectivity_type, common.NeighborConnectivityType):
# unstructured shift
assert (
isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
) or val in [
trace_shifts.Sentinel.ALL_NEIGHBORS,
trace_shifts.Sentinel.VALUE,
]
horizontal_sizes: dict[str, itir.Expr]
if symbolic_domain_sizes is not None:
horizontal_sizes = {
k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items()
}
else:
# note: ugly but cheap re-computation, but should disappear
assert common.is_offset_provider(offset_provider)
horizontal_sizes = {
k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN)
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
}

old_dim = connectivity_type.source_dim
new_dim = connectivity_type.codomain

assert new_dim not in new_ranges or old_dim == new_dim

new_range = SymbolicRange(
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
horizontal_sizes[new_dim.value],
)
if isinstance(off, itir.CartesianOffset):
old_dim = common.Dimension(value=off.domain.value, kind=off.domain.kind)
new_dim = common.Dimension(value=off.codomain.value, kind=off.codomain.kind)
new_range = SymbolicRange.translate(self.ranges[old_dim], val.value)
new_ranges = dict(
(dim, range_) if dim != old_dim else (new_dim, new_range)
for dim, range_ in new_ranges.items()
)
else:
raise AssertionError()
assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str)
connectivity_type = common.get_offset_type(offset_provider_type, off.value)
if isinstance(connectivity_type, common.Dimension):
if val is trace_shifts.Sentinel.VALUE:
raise NotImplementedError("Dynamic offsets not supported.")
assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
current_dim = connectivity_type
# cartesian offset
new_ranges[current_dim] = SymbolicRange.translate(
self.ranges[current_dim], val.value
)
elif isinstance(connectivity_type, common.NeighborConnectivityType):
# unstructured shift
assert (
isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
) or val in [
trace_shifts.Sentinel.ALL_NEIGHBORS,
trace_shifts.Sentinel.VALUE,
]
horizontal_sizes: dict[str, itir.Expr]
if symbolic_domain_sizes is not None:
horizontal_sizes = {
k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items()
}
else:
# note: ugly but cheap re-computation, but should disappear
assert common.is_offset_provider(offset_provider)
horizontal_sizes = {
k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN)
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
}

old_dim = connectivity_type.source_dim
new_dim = connectivity_type.codomain

assert new_dim not in new_ranges or old_dim == new_dim

new_range = SymbolicRange(
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
horizontal_sizes[new_dim.value],
)
new_ranges = dict(
(dim, range_) if dim != old_dim else (new_dim, new_range)
for dim, range_ in new_ranges.items()
)
else:
raise AssertionError()
return SymbolicDomain(self.grid_type, new_ranges)
elif len(shift) > 2:
return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate(
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,10 @@ def axis_literal(dim: common.Dimension) -> itir.AxisLiteral:
return itir.AxisLiteral(value=dim.value, kind=dim.kind)


def cartesian_offset(domain: common.Dimension, codomain: common.Dimension):
return itir.CartesianOffset(domain=axis_literal(domain), codomain=axis_literal(codomain))


def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None):
"""
Promotes the function `cast_` to a field_operator.
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[
def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]:
return [str(node.value) + "ₒ"]

def visit_CartesianOffset(self, node: ir.CartesianOffset, *, prec: int) -> list[str]:
(domain,) = self.visit(node.domain, prec=0)
(codomain,) = self.visit(node.codomain, prec=0)
return [f"{domain}₂{codomain}"]

def visit_AxisLiteral(self, node: ir.AxisLiteral, *, prec: int) -> list[str]:
kind = ""
if node.kind == ir.DimensionKind.HORIZONTAL:
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def _can_deref(x):

def _shift(*offsets):
assert all(
isinstance(offset, ir.OffsetLiteral) or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE]
isinstance(offset, (ir.OffsetLiteral, ir.CartesianOffset))
or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE]
for offset in offsets
)

Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,15 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.Offse
assert isinstance(node.value, str)
return it_ts.OffsetLiteralType(value=node.value)

def visit_CartesianOffset(
self, node: itir.CartesianOffset, *, ctx
) -> it_ts.CartesianOffsetType:
self.visit(node.domain, ctx=ctx)
self.visit(node.codomain, ctx=ctx)
return it_ts.CartesianOffsetType(
domain=node.domain.type.dim, codomain=node.codomain.type.dim
)

def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType:
assert isinstance(node.type, ts.ScalarType)
return node.type
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class OffsetLiteralType(ts.TypeSpec):
value: ts.ScalarType | str


class CartesianOffsetType(ts.TypeSpec):
domain: common.Dimension
codomain: common.Dimension


class IteratorType(ts.DataType, ts.CallableType):
position_dims: list[common.Dimension] | Literal["unknown"]
defined_dims: list[common.Dimension]
Expand Down
Loading
Loading