Skip to content

Commit d5ed163

Browse files
committed
feat[next]: Support for staggered fields GridTools#2339 fea7480
1 parent 0736f52 commit d5ed163

File tree

20 files changed

+322
-92
lines changed

20 files changed

+322
-92
lines changed

src/gt4py/next/common.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,11 +1249,21 @@ class GridType(StrEnum):
12491249
UNSTRUCTURED = "unstructured"
12501250

12511251

1252+
def check_staggered(dim: Dimension) -> bool:
1253+
return dim.value.startswith(_STAGGERED_PREFIX)
1254+
1255+
12521256
def order_dimensions(dims: Iterable[Dimension]) -> list[Dimension]:
12531257
"""Find the canonical ordering of the dimensions in `dims`."""
12541258
if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1:
12551259
raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.")
1256-
return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value))
1260+
return sorted(
1261+
dims,
1262+
key=lambda dim: (
1263+
_DIM_KIND_ORDER[dim.kind],
1264+
flip_staggered(dim).value if check_staggered(dim) else dim.value,
1265+
),
1266+
)
12571267

12581268

12591269
def check_dims(dims: Sequence[Dimension]) -> None:
@@ -1341,3 +1351,22 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call
13411351
#: Equivalent to the `_FillValue` attribute in the UGRID Conventions
13421352
#: (see: http://ugrid-conventions.github.io/ugrid-conventions/).
13431353
_DEFAULT_SKIP_VALUE: Final[int] = -1
1354+
_STAGGERED_PREFIX = "_Staggered"
1355+
1356+
1357+
def flip_staggered(dim: Dimension) -> Dimension:
1358+
if dim.value.startswith(_STAGGERED_PREFIX):
1359+
return Dimension(dim.value[len(_STAGGERED_PREFIX) :], dim.kind)
1360+
else:
1361+
return Dimension(f"{_STAGGERED_PREFIX}{dim.value}", dim.kind)
1362+
1363+
1364+
def connectivity_for_cartesian_shift(dim: Dimension, offset: int | float) -> CartesianConnectivity:
1365+
if isinstance(offset, float):
1366+
integral_offset, half = divmod(offset, 1)
1367+
assert half == 0.5
1368+
if dim.value.startswith(_STAGGERED_PREFIX):
1369+
integral_offset += 1
1370+
return CartesianConnectivity(dim, int(integral_offset), codomain=flip_staggered(dim))
1371+
else:
1372+
return CartesianConnectivity(dim, offset, codomain=dim)

src/gt4py/next/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def env_flag_to_int(name: str, default: int) -> int:
6464
#: Master debug flag
6565
#: Changes defaults for all the other options to be as helpful for debugging as possible.
6666
#: Does not override values set in environment variables.
67-
DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False)
67+
DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=True)
6868

6969

7070
#: Verbose flag for DSL compilation errors

src/gt4py/next/ffront/foast_passes/type_deduction.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import gt4py.next.ffront.field_operator_ast as foast
1212
from gt4py.eve import NodeTranslator, NodeVisitor, traits
13-
from gt4py.next import errors, utils
13+
from gt4py.next import common, errors, utils
1414
from gt4py.next.common import DimensionKind, promote_dims
1515
from gt4py.next.ffront import ( # noqa
1616
dialect_ast_enums,
@@ -655,13 +655,20 @@ def _deduce_compare_type(
655655
def _deduce_binop_type(
656656
self, node: foast.BinOp, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
657657
) -> Optional[ts.TypeSpec]:
658-
# e.g. `IDim+1`
658+
# e.g. `IDim+1` or `IDim+0.5`
659659
if (
660660
isinstance(left.type, ts.DimensionType)
661661
and isinstance(right.type, ts.ScalarType)
662-
and type_info.is_integral(right.type)
662+
and type_info.is_arithmetic(right.type)
663663
):
664-
return ts.OffsetType(source=left.type.dim, target=(left.type.dim,))
664+
if not isinstance(right, foast.Constant):
665+
raise NotImplementedError()
666+
offset_index = right.value
667+
if node.op == dialect_ast_enums.BinaryOperator.SUB:
668+
offset_index *= -1
669+
conn = common.connectivity_for_cartesian_shift(left.type.dim, offset_index)
670+
return ts.OffsetType(source=conn.codomain, target=(conn.domain_dim,))
671+
665672
if isinstance(left.type, ts.OffsetType):
666673
raise errors.DSLError(
667674
node.location, f"Type '{left.type}' can not be used in operator '{node.op}'."

src/gt4py/next/ffront/foast_to_gtir.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,17 +306,19 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
306306
# `field(Dim + idx)`
307307
case foast.BinOp(
308308
op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB,
309-
left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs
309+
left=foast.Name(), # TODO(tehrengruber): use type of lhs
310310
right=foast.Constant(value=offset_index),
311311
):
312312
if arg.op == dialect_ast_enums.BinaryOperator.SUB:
313313
offset_index *= -1
314-
# TODO(havogt): we rely on the naming-convention for implicit offsets, see `dimension_to_implicit_offset`
314+
conn = common.connectivity_for_cartesian_shift(
315+
node.args[0].left.type.dim, offset_index
316+
)
315317
current_expr = im.as_fieldop(
316318
im.lambda_("__it")(
317319
im.deref(
318320
im.shift(
319-
common.dimension_to_implicit_offset(dimension), offset_index
321+
im.cartesian_offset(conn.domain_dim, conn.codomain), conn.offset
320322
)("__it")
321323
)
322324
)

src/gt4py/next/iterator/ir.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,23 @@ def __str__(self):
8181
InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE")
8282

8383

84-
class OffsetLiteral(Expr):
85-
value: Union[int, str]
86-
87-
8884
class AxisLiteral(Expr):
8985
# TODO(havogt): Refactor to use declare Axis/Dimension at the Program level.
9086
# Now every use of the literal has to provide the kind, where usually we only care of the name.
9187
value: str
9288
kind: common.DimensionKind = common.DimensionKind.HORIZONTAL
9389

9490

91+
class CartesianOffset(Expr):
92+
domain: AxisLiteral
93+
codomain: AxisLiteral
94+
95+
96+
# TODO(tehrengruber): allow int only and create OffsetRef for str instead
97+
class OffsetLiteral(Expr):
98+
value: Union[int, str]
99+
100+
95101
class SymRef(Expr):
96102
id: Coerced[SymbolRef]
97103

@@ -150,8 +156,9 @@ class Program(Node, ValidatedSymbolTableTrait):
150156
Expr.__hash__ = Node.__hash__ # type: ignore[method-assign]
151157
Literal.__hash__ = Node.__hash__ # type: ignore[method-assign]
152158
NoneLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
153-
OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
154159
AxisLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
160+
OffsetLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
161+
CartesianOffset.__hash__ = Node.__hash__ # type: ignore[method-assign]
155162
SymRef.__hash__ = Node.__hash__ # type: ignore[method-assign]
156163
Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign]
157164
FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign]

src/gt4py/next/iterator/ir_utils/domain_utils.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -124,54 +124,62 @@ def translate(
124124
return self
125125
if len(shift) == 2:
126126
off, val = shift
127-
assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str)
128-
connectivity_type = common.get_offset_type(offset_provider_type, off.value)
129-
130-
if isinstance(connectivity_type, common.Dimension):
131-
if val is trace_shifts.Sentinel.VALUE:
132-
raise NotImplementedError("Dynamic offsets not supported.")
133-
assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
134-
current_dim = connectivity_type
135-
# cartesian offset
136-
new_ranges[current_dim] = SymbolicRange.translate(
137-
self.ranges[current_dim], val.value
138-
)
139-
elif isinstance(connectivity_type, common.NeighborConnectivityType):
140-
# unstructured shift
141-
assert (
142-
isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
143-
) or val in [
144-
trace_shifts.Sentinel.ALL_NEIGHBORS,
145-
trace_shifts.Sentinel.VALUE,
146-
]
147-
horizontal_sizes: dict[str, itir.Expr]
148-
if symbolic_domain_sizes is not None:
149-
horizontal_sizes = {
150-
k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items()
151-
}
152-
else:
153-
# note: ugly but cheap re-computation, but should disappear
154-
assert common.is_offset_provider(offset_provider)
155-
horizontal_sizes = {
156-
k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN)
157-
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
158-
}
159-
160-
old_dim = connectivity_type.source_dim
161-
new_dim = connectivity_type.codomain
162-
163-
assert new_dim not in new_ranges or old_dim == new_dim
164-
165-
new_range = SymbolicRange(
166-
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
167-
horizontal_sizes[new_dim.value],
168-
)
127+
if isinstance(off, itir.CartesianOffset):
128+
old_dim = common.Dimension(value=off.domain.value, kind=off.domain.kind)
129+
new_dim = common.Dimension(value=off.codomain.value, kind=off.codomain.kind)
130+
new_range = SymbolicRange.translate(self.ranges[old_dim], val.value)
169131
new_ranges = dict(
170132
(dim, range_) if dim != old_dim else (new_dim, new_range)
171133
for dim, range_ in new_ranges.items()
172134
)
173135
else:
174-
raise AssertionError()
136+
assert isinstance(off, itir.OffsetLiteral) and isinstance(off.value, str)
137+
connectivity_type = common.get_offset_type(offset_provider_type, off.value)
138+
if isinstance(connectivity_type, common.Dimension):
139+
if val is trace_shifts.Sentinel.VALUE:
140+
raise NotImplementedError("Dynamic offsets not supported.")
141+
assert isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
142+
current_dim = connectivity_type
143+
# cartesian offset
144+
new_ranges[current_dim] = SymbolicRange.translate(
145+
self.ranges[current_dim], val.value
146+
)
147+
elif isinstance(connectivity_type, common.NeighborConnectivityType):
148+
# unstructured shift
149+
assert (
150+
isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int)
151+
) or val in [
152+
trace_shifts.Sentinel.ALL_NEIGHBORS,
153+
trace_shifts.Sentinel.VALUE,
154+
]
155+
horizontal_sizes: dict[str, itir.Expr]
156+
if symbolic_domain_sizes is not None:
157+
horizontal_sizes = {
158+
k: im.ensure_expr(v) for k, v in symbolic_domain_sizes.items()
159+
}
160+
else:
161+
# note: ugly but cheap re-computation, but should disappear
162+
assert common.is_offset_provider(offset_provider)
163+
horizontal_sizes = {
164+
k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN)
165+
for k, v in _max_domain_sizes_by_location_type(offset_provider).items()
166+
}
167+
168+
old_dim = connectivity_type.source_dim
169+
new_dim = connectivity_type.codomain
170+
171+
assert new_dim not in new_ranges or old_dim == new_dim
172+
173+
new_range = SymbolicRange(
174+
im.literal("0", builtins.INTEGER_INDEX_BUILTIN),
175+
horizontal_sizes[new_dim.value],
176+
)
177+
new_ranges = dict(
178+
(dim, range_) if dim != old_dim else (new_dim, new_range)
179+
for dim, range_ in new_ranges.items()
180+
)
181+
else:
182+
raise AssertionError()
175183
return SymbolicDomain(self.grid_type, new_ranges)
176184
elif len(shift) > 2:
177185
return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate(

src/gt4py/next/iterator/ir_utils/ir_makers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,10 @@ def axis_literal(dim: common.Dimension) -> itir.AxisLiteral:
578578
return itir.AxisLiteral(value=dim.value, kind=dim.kind)
579579

580580

581+
def cartesian_offset(domain: common.Dimension, codomain: common.Dimension):
582+
return itir.CartesianOffset(domain=axis_literal(domain), codomain=axis_literal(codomain))
583+
584+
581585
def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None):
582586
"""
583587
Promotes the function `cast_` to a field_operator.

src/gt4py/next/iterator/pretty_printer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[
143143
def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]:
144144
return [str(node.value) + "ₒ"]
145145

146+
def visit_CartesianOffset(self, node: ir.CartesianOffset, *, prec: int) -> list[str]:
147+
(domain,) = self.visit(node.domain, prec=0)
148+
(codomain,) = self.visit(node.codomain, prec=0)
149+
return [f"{domain}{codomain}"]
150+
146151
def visit_AxisLiteral(self, node: ir.AxisLiteral, *, prec: int) -> list[str]:
147152
kind = ""
148153
if node.kind == ir.DimensionKind.HORIZONTAL:

src/gt4py/next/iterator/transforms/trace_shifts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def _can_deref(x):
137137

138138
def _shift(*offsets):
139139
assert all(
140-
isinstance(offset, ir.OffsetLiteral) or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE]
140+
isinstance(offset, (ir.OffsetLiteral, ir.CartesianOffset))
141+
or offset in [Sentinel.ALL_NEIGHBORS, Sentinel.VALUE]
141142
for offset in offsets
142143
)
143144

src/gt4py/next/iterator/type_system/inference.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,15 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.Offse
473473
assert isinstance(node.value, str)
474474
return it_ts.OffsetLiteralType(value=node.value)
475475

476+
def visit_CartesianOffset(
477+
self, node: itir.CartesianOffset, *, ctx
478+
) -> it_ts.CartesianOffsetType:
479+
self.visit(node.domain, ctx=ctx)
480+
self.visit(node.codomain, ctx=ctx)
481+
return it_ts.CartesianOffsetType(
482+
domain=node.domain.type.dim, codomain=node.codomain.type.dim
483+
)
484+
476485
def visit_Literal(self, node: itir.Literal, **kwargs) -> ts.ScalarType:
477486
assert isinstance(node.type, ts.ScalarType)
478487
return node.type

0 commit comments

Comments
 (0)