Skip to content

Commit 1ba5adb

Browse files
committed
feat[next]: Compiled variant for field operators GridTools#2368 4c490e4
1 parent 968a8e8 commit 1ba5adb

File tree

6 files changed

+210
-117
lines changed

6 files changed

+210
-117
lines changed

src/gt4py/next/ffront/decorator.py

Lines changed: 92 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from gt4py import eve
2525
from gt4py._core import definitions as core_defs
2626
from gt4py.eve import extended_typing as xtyping
27-
from gt4py.eve.extended_typing import Self, override
27+
from gt4py.eve.extended_typing import Self, Unpack, override
2828
from gt4py.next import (
2929
allocators as next_allocators,
3030
backend as next_backend,
@@ -47,17 +47,85 @@
4747
)
4848
from gt4py.next.ffront.gtcallable import GTCallable
4949
from gt4py.next.iterator import ir as itir
50-
from gt4py.next.otf import arguments, compiled_program, stages, toolchain
50+
from gt4py.next.otf import arguments, compiled_program, options, stages, toolchain
5151
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
5252

5353

5454
DEFAULT_BACKEND: next_backend.Backend | None = None
5555

5656

57+
class CompiledProgramMixin:
58+
@functools.cached_property
59+
def _compiled_programs(self) -> compiled_program.CompiledProgramsPool:
60+
if self.backend is None or self.backend == eve.NOTHING:
61+
raise RuntimeError("Cannot compile a program without backend.")
62+
63+
if self.compilation_options.static_params is None:
64+
object.__setattr__(self.compilation_options, "static_params", ())
65+
66+
argument_descriptor_mapping = {
67+
arguments.StaticArg: self.compilation_options.static_params,
68+
}
69+
70+
return compiled_program.CompiledProgramsPool(
71+
backend=self.backend,
72+
definition_stage=self.definition_stage,
73+
callable_type=self.__gt_type__(),
74+
argument_descriptor_mapping=argument_descriptor_mapping,
75+
# type: ignore[arg-type] # covariant `type[T]` not possible
76+
)
77+
78+
def compile(
79+
self,
80+
offset_provider: common.OffsetProviderType
81+
| common.OffsetProvider
82+
| list[common.OffsetProviderType | common.OffsetProvider]
83+
| None = None,
84+
enable_jit: bool | None = None,
85+
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
86+
) -> Self:
87+
"""
88+
Compiles the program for the given combination of static arguments and offset provider type.
89+
90+
Note: Unlike `with_...` methods, this method does not return a new instance of the program,
91+
but adds the compiled variants to the current program instance.
92+
"""
93+
# TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
94+
# rename to `with_static_args` or similar) once we have a better understanding of the
95+
# use-cases.
96+
97+
if enable_jit is not None:
98+
object.__setattr__(self.compilation_options, "enable_jit", enable_jit)
99+
if self.compilation_options.static_params is None:
100+
object.__setattr__(self.compilation_options, "static_params", tuple(static_args.keys()))
101+
if self.compilation_options.connectivities is None and offset_provider is None:
102+
raise ValueError(
103+
"Cannot compile a program without connectivities / OffsetProviderType."
104+
)
105+
if not all(isinstance(v, list) for v in static_args.values()):
106+
raise TypeError(
107+
"Please provide the static arguments as lists."
108+
) # To avoid confusion with tuple args
109+
110+
offset_provider = (
111+
self.compilation_options.connectivities if offset_provider is None else offset_provider
112+
)
113+
if not isinstance(offset_provider, list):
114+
offset_provider = [offset_provider] # type: ignore[list-item] # cleanup offset_provider vs offset_provider_type
115+
116+
assert all(
117+
common.is_offset_provider(op) or common.is_offset_provider_type(op)
118+
for op in offset_provider
119+
)
120+
121+
self._compiled_programs.compile(offset_providers=offset_provider, **static_args)
122+
return self
123+
124+
57125
# TODO(tehrengruber): Decide if and how programs can call other programs. As a
58126
# result Program could become a GTCallable.
59127
@dataclasses.dataclass(frozen=True)
60-
class Program:
128+
class Program(CompiledProgramMixin):
61129
"""
62130
Construct a program object from a PAST node.
63131
@@ -79,35 +147,26 @@ class Program:
79147

80148
definition_stage: ffront_stages.ProgramDefinition
81149
backend: Optional[next_backend.Backend]
82-
connectivities: Optional[
83-
common.OffsetProvider
84-
] # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
85-
enable_jit: bool | None
86-
static_params: (
87-
Sequence[str] | None
88-
) # if the user requests static params, they will be used later to initialize CompiledPrograms
150+
compilation_options: options.CompilationOptions
89151

90152
@classmethod
91153
def from_function(
92154
cls,
93155
definition: types.FunctionType,
94156
backend: next_backend.Backend | None,
95157
grid_type: common.GridType | None = None,
96-
enable_jit: bool | None = None,
97-
static_params: Sequence[str] | None = None,
98-
connectivities: Optional[
99-
common.OffsetProvider
100-
] = None, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
158+
**compilation_options: Unpack[options.CompilationOptionsArgs],
101159
) -> Program:
102160
program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type)
103161
return cls(
104162
definition_stage=program_def,
105163
backend=backend,
106-
connectivities=connectivities,
107-
enable_jit=enable_jit,
108-
static_params=static_params,
164+
compilation_options=options.CompilationOptions(**compilation_options),
109165
)
110166

167+
def __gt_type__(self):
168+
return self.past_stage.past_node.type
169+
111170
# needed in testing
112171
@property
113172
def definition(self) -> types.FunctionType:
@@ -229,27 +288,6 @@ def gtir(self) -> itir.Program:
229288
)
230289
return self._frontend_transforms.past_to_itir(no_args_past).data
231290

232-
@functools.cached_property
233-
def _compiled_programs(self) -> compiled_program.CompiledProgramsPool:
234-
if self.backend is None or self.backend == eve.NOTHING:
235-
raise RuntimeError("Cannot compile a program without backend.")
236-
237-
if self.static_params is None:
238-
object.__setattr__(self, "static_params", ())
239-
240-
argument_descriptor_mapping = {
241-
arguments.StaticArg: self.static_params,
242-
}
243-
244-
program_type = self.past_stage.past_node.type
245-
assert isinstance(program_type, ts_ffront.ProgramType)
246-
return compiled_program.CompiledProgramsPool(
247-
backend=self.backend,
248-
definition_stage=self.definition_stage,
249-
program_type=program_type,
250-
argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible
251-
)
252-
253291
def __call__(
254292
self,
255293
*args: Any,
@@ -259,10 +297,7 @@ def __call__(
259297
) -> None:
260298
if offset_provider is None:
261299
offset_provider = {}
262-
if enable_jit is None:
263-
enable_jit = (
264-
self.enable_jit if self.enable_jit is not None else config.ENABLE_JIT_DEFAULT
265-
)
300+
enable_jit = self.compilation_options.enable_jit if enable_jit is None else enable_jit
266301

267302
with metrics.collect() as metrics_source:
268303
if collect_info_metrics := (config.COLLECT_METRICS_LEVEL >= metrics.INFO):
@@ -302,50 +337,6 @@ def __call__(
302337
assert metrics_source is not None
303338
metrics_source.metrics[metrics.TOTAL_METRIC].add_sample(time.time() - start)
304339

305-
def compile(
306-
self,
307-
offset_provider: common.OffsetProviderType
308-
| common.OffsetProvider
309-
| list[common.OffsetProviderType | common.OffsetProvider]
310-
| None = None,
311-
enable_jit: bool | None = None,
312-
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
313-
) -> Self:
314-
"""
315-
Compiles the program for the given combination of static arguments and offset provider type.
316-
317-
Note: Unlike `with_...` methods, this method does not return a new instance of the program,
318-
but adds the compiled variants to the current program instance.
319-
"""
320-
# TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
321-
# rename to `with_static_args` or similar) once we have a better understanding of the
322-
# use-cases.
323-
324-
if enable_jit is not None:
325-
object.__setattr__(self, "enable_jit", enable_jit)
326-
if self.static_params is None:
327-
object.__setattr__(self, "static_params", tuple(static_args.keys()))
328-
if self.connectivities is None and offset_provider is None:
329-
raise ValueError(
330-
"Cannot compile a program without connectivities / OffsetProviderType."
331-
)
332-
if not all(isinstance(v, list) for v in static_args.values()):
333-
raise TypeError(
334-
"Please provide the static arguments as lists."
335-
) # To avoid confusion with tuple args
336-
337-
offset_provider = self.connectivities if offset_provider is None else offset_provider
338-
if not isinstance(offset_provider, list):
339-
offset_provider = [offset_provider] # type: ignore[list-item] # cleanup offset_provider vs offset_provider_type
340-
341-
assert all(
342-
common.is_offset_provider(op) or common.is_offset_provider_type(op)
343-
for op in offset_provider
344-
)
345-
346-
self._compiled_programs.compile(offset_providers=offset_provider, **static_args)
347-
return self
348-
349340
def freeze(self) -> FrozenProgram:
350341
if self.backend is None:
351342
raise ValueError("Can not freeze a program without backend (embedded execution).")
@@ -540,9 +531,8 @@ def program(
540531
# `NOTHING` -> default backend, `None` -> no backend (embedded execution)
541532
backend: next_backend.Backend | eve.NothingType | None = eve.NOTHING,
542533
grid_type: common.GridType | None = None,
543-
enable_jit: bool | None = None, # only relevant if static_params are set
544-
static_params: Sequence[str] | None = None,
545534
frozen: bool = False,
535+
**compilation_options: Unpack[options.CompilationOptionsArgs],
546536
) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]:
547537
"""
548538
Generate an implementation of a program from a Python function object.
@@ -569,8 +559,7 @@ def program_inner(definition: types.FunctionType) -> Program:
569559
next_backend.Backend | None, DEFAULT_BACKEND if backend is eve.NOTHING else backend
570560
),
571561
grid_type=grid_type,
572-
enable_jit=enable_jit,
573-
static_params=static_params,
562+
**compilation_options,
574563
)
575564
if frozen:
576565
return program.freeze() # type: ignore[return-value] # TODO(havogt): Should `FrozenProgram` be a `Program`?
@@ -583,7 +572,7 @@ def program_inner(definition: types.FunctionType) -> Program:
583572

584573

585574
@dataclasses.dataclass(frozen=True)
586-
class FieldOperator(GTCallable, Generic[OperatorNodeT]):
575+
class FieldOperator(CompiledProgramMixin, GTCallable, Generic[OperatorNodeT]):
587576
"""
588577
Construct a field operator object from a FOAST node.
589578
@@ -606,6 +595,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
606595

607596
definition_stage: ffront_stages.FieldOperatorDefinition
608597
backend: Optional[next_backend.Backend]
598+
compilation_options: options.CompilationOptions
609599
_program_cache: dict = dataclasses.field(
610600
init=False, default_factory=dict
611601
) # init=False ensure the cache is not copied in calls to replace
@@ -619,6 +609,7 @@ def from_function(
619609
*,
620610
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, # type: ignore[assignment] # TODO(ricoh): understand why mypy complains
621611
operator_attributes: Optional[dict[str, Any]] = None,
612+
**compilation_options: Unpack[options.CompilationOptionsArgs],
622613
) -> FieldOperator[OperatorNodeT]:
623614
return cls(
624615
definition_stage=ffront_stages.FieldOperatorDefinition(
@@ -628,6 +619,7 @@ def from_function(
628619
attributes=operator_attributes or {},
629620
),
630621
backend=backend,
622+
compilation_options=options.CompilationOptions(**compilation_options),
631623
)
632624

633625
# TODO(ricoh): linting should become optional, up to the backend.
@@ -701,12 +693,12 @@ def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program:
701693
definition_stage=None, # type: ignore[arg-type] # ProgramFromPast needs to be fixed
702694
past_stage=past_stage,
703695
backend=self.backend,
704-
connectivities=None,
705-
enable_jit=False, # TODO(havogt): revisit ProgramFromPast
706-
static_params=None, # TODO(havogt): revisit ProgramFromPast
696+
compilation_options=self.compilation_options,
707697
)
708698

709-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
699+
def __call__(
700+
self, *args: Any, enable_jit: bool = options.CompilationOptions.enable_jit, **kwargs: Any
701+
) -> Any:
710702
if not next_embedded.context.within_valid_context() and self.backend is not None:
711703
# non embedded execution
712704
offset_provider = {**kwargs.pop("offset_provider", {})}
@@ -719,15 +711,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
719711
domain = utils.tree_map(lambda _: domain)(out)
720712
out = utils.tree_map(lambda f, dom: f[dom])(out, domain)
721713

722-
args, kwargs = type_info.canonicalize_arguments(
723-
self.foast_stage.foast_node.type, args, kwargs
724-
)
725-
return self.backend(
726-
self.definition_stage,
727-
*args,
728-
out=out,
729-
offset_provider=offset_provider,
730-
**kwargs,
714+
return self._compiled_programs(
715+
*args, **kwargs, out=out, offset_provider=offset_provider, enable_jit=enable_jit
731716
)
732717
else:
733718
if not next_embedded.context.within_valid_context():
@@ -793,6 +778,7 @@ def field_operator(
793778
*,
794779
backend: next_backend.Backend | eve.NothingType | None = eve.NOTHING,
795780
grid_type: common.GridType | None = None,
781+
**compilation_options: Unpack[options.CompilationOptionsArgs],
796782
) -> (
797783
FieldOperator[foast.FieldOperator]
798784
| Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]
@@ -820,6 +806,7 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.
820806
next_backend.Backend | None, DEFAULT_BACKEND if backend is eve.NOTHING else backend
821807
),
822808
grid_type,
809+
**compilation_options
823810
)
824811

825812
return field_operator_inner if definition is None else field_operator_inner(definition)

src/gt4py/next/ffront/foast_to_past.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,19 @@ def __call__(self, inp: AOT_FOP) -> AOT_PRG:
9696
# TODO(tehrengruber): check foast operator has no out argument that clashes
9797
# with the out argument of the program we generate here.
9898

99-
arg_types, kwarg_types = inp.args.args, inp.args.kwargs
99+
# TODO(tehrengruber): This function used to be wrong. The kwarg_types here are
100+
# just ignored silently and the out argument of the field operator used to be passed here.
101+
# With the CompiledProgramPool the out argument is just the third argument (consistent
102+
# with the program definition below). So we just drop it now. In general this function
103+
# should be reworked or replaced. Decide in review.
104+
arg_types, kwarg_types = inp.args.args[:-1], inp.args.kwargs
105+
assert not kwarg_types
100106

107+
type_ = inp.data.foast_node.type
101108
loc = inp.data.foast_node.location
102109
# use a new UID generator to allow caching
103110
param_sym_uids = eve_utils.UIDGenerator()
104111

105-
type_ = inp.data.foast_node.type
106112
params_decl: list[past.Symbol] = [
107113
past.DataSymbol(
108114
id=param_sym_uids.sequential_id(prefix="__sym"),

src/gt4py/next/ffront/signature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# Please, refer to the LICENSE file in the root directory.
77
# SPDX-License-Identifier: BSD-3-Clause
8-
8+
# TODO: delete this file
99
# TODO(ricoh): This overlaps with `canonicalize_arguments`, solutions:
1010
# - merge the two
1111
# - extract the signature gathering functionality from canonicalize_arguments

0 commit comments

Comments
 (0)