2424from gt4py import eve
2525from gt4py ._core import definitions as core_defs
2626from gt4py .eve import extended_typing as xtyping
27- from gt4py .eve .extended_typing import Self , override
27+ from gt4py .eve .extended_typing import Self , Unpack , override
2828from gt4py .next import (
2929 allocators as next_allocators ,
3030 backend as next_backend ,
4747)
4848from gt4py .next .ffront .gtcallable import GTCallable
4949from gt4py .next .iterator import ir as itir
50- from gt4py .next .otf import arguments , compiled_program , stages , toolchain
50+ from gt4py .next .otf import arguments , compiled_program , options , stages , toolchain
5151from gt4py .next .type_system import type_info , type_specifications as ts , type_translation
5252
5353
5454DEFAULT_BACKEND : next_backend .Backend | None = None
5555
5656
57+ class CompiledProgramMixin :
58+ @functools .cached_property
59+ def _compiled_programs (self ) -> compiled_program .CompiledProgramsPool :
60+ if self .backend is None or self .backend == eve .NOTHING :
61+ raise RuntimeError ("Cannot compile a program without backend." )
62+
63+ if self .compilation_options .static_params is None :
64+ object .__setattr__ (self .compilation_options , "static_params" , ())
65+
66+ argument_descriptor_mapping = {
67+ arguments .StaticArg : self .compilation_options .static_params ,
68+ }
69+
70+ return compiled_program .CompiledProgramsPool (
71+ backend = self .backend ,
72+ definition_stage = self .definition_stage ,
73+ callable_type = self .__gt_type__ (),
74+ argument_descriptor_mapping = argument_descriptor_mapping ,
75+ # type: ignore[arg-type] # covariant `type[T]` not possible
76+ )
77+
78+ def compile (
79+ self ,
80+ offset_provider : common .OffsetProviderType
81+ | common .OffsetProvider
82+ | list [common .OffsetProviderType | common .OffsetProvider ]
83+ | None = None ,
84+ enable_jit : bool | None = None ,
85+ ** static_args : list [xtyping .MaybeNestedInTuple [core_defs .Scalar ]],
86+ ) -> Self :
87+ """
88+ Compiles the program for the given combination of static arguments and offset provider type.
89+
90+ Note: Unlike `with_...` methods, this method does not return a new instance of the program,
91+ but adds the compiled variants to the current program instance.
92+ """
93+ # TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
94+ # rename to `with_static_args` or similar) once we have a better understanding of the
95+ # use-cases.
96+
97+ if enable_jit is not None :
98+ object .__setattr__ (self .compilation_options , "enable_jit" , enable_jit )
99+ if self .compilation_options .static_params is None :
100+ object .__setattr__ (self .compilation_options , "static_params" , tuple (static_args .keys ()))
101+ if self .compilation_options .connectivities is None and offset_provider is None :
102+ raise ValueError (
103+ "Cannot compile a program without connectivities / OffsetProviderType."
104+ )
105+ if not all (isinstance (v , list ) for v in static_args .values ()):
106+ raise TypeError (
107+ "Please provide the static arguments as lists."
108+ ) # To avoid confusion with tuple args
109+
110+ offset_provider = (
111+ self .compilation_options .connectivities if offset_provider is None else offset_provider
112+ )
113+ if not isinstance (offset_provider , list ):
114+ offset_provider = [offset_provider ] # type: ignore[list-item] # cleanup offset_provider vs offset_provider_type
115+
116+ assert all (
117+ common .is_offset_provider (op ) or common .is_offset_provider_type (op )
118+ for op in offset_provider
119+ )
120+
121+ self ._compiled_programs .compile (offset_providers = offset_provider , ** static_args )
122+ return self
123+
124+
57125# TODO(tehrengruber): Decide if and how programs can call other programs. As a
58126# result Program could become a GTCallable.
59127@dataclasses .dataclass (frozen = True )
60- class Program :
128+ class Program ( CompiledProgramMixin ) :
61129 """
62130 Construct a program object from a PAST node.
63131
@@ -79,35 +147,26 @@ class Program:
79147
80148 definition_stage : ffront_stages .ProgramDefinition
81149 backend : Optional [next_backend .Backend ]
82- connectivities : Optional [
83- common .OffsetProvider
84- ] # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
85- enable_jit : bool | None
86- static_params : (
87- Sequence [str ] | None
88- ) # if the user requests static params, they will be used later to initialize CompiledPrograms
150+ compilation_options : options .CompilationOptions
89151
90152 @classmethod
91153 def from_function (
92154 cls ,
93155 definition : types .FunctionType ,
94156 backend : next_backend .Backend | None ,
95157 grid_type : common .GridType | None = None ,
96- enable_jit : bool | None = None ,
97- static_params : Sequence [str ] | None = None ,
98- connectivities : Optional [
99- common .OffsetProvider
100- ] = None , # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
158+ ** compilation_options : Unpack [options .CompilationOptionsArgs ],
101159 ) -> Program :
102160 program_def = ffront_stages .ProgramDefinition (definition = definition , grid_type = grid_type )
103161 return cls (
104162 definition_stage = program_def ,
105163 backend = backend ,
106- connectivities = connectivities ,
107- enable_jit = enable_jit ,
108- static_params = static_params ,
164+ compilation_options = options .CompilationOptions (** compilation_options ),
109165 )
110166
167+ def __gt_type__ (self ):
168+ return self .past_stage .past_node .type
169+
111170 # needed in testing
112171 @property
113172 def definition (self ) -> types .FunctionType :
@@ -229,27 +288,6 @@ def gtir(self) -> itir.Program:
229288 )
230289 return self ._frontend_transforms .past_to_itir (no_args_past ).data
231290
232- @functools .cached_property
233- def _compiled_programs (self ) -> compiled_program .CompiledProgramsPool :
234- if self .backend is None or self .backend == eve .NOTHING :
235- raise RuntimeError ("Cannot compile a program without backend." )
236-
237- if self .static_params is None :
238- object .__setattr__ (self , "static_params" , ())
239-
240- argument_descriptor_mapping = {
241- arguments .StaticArg : self .static_params ,
242- }
243-
244- program_type = self .past_stage .past_node .type
245- assert isinstance (program_type , ts_ffront .ProgramType )
246- return compiled_program .CompiledProgramsPool (
247- backend = self .backend ,
248- definition_stage = self .definition_stage ,
249- program_type = program_type ,
250- argument_descriptor_mapping = argument_descriptor_mapping , # type: ignore[arg-type] # covariant `type[T]` not possible
251- )
252-
253291 def __call__ (
254292 self ,
255293 * args : Any ,
@@ -259,10 +297,7 @@ def __call__(
259297 ) -> None :
260298 if offset_provider is None :
261299 offset_provider = {}
262- if enable_jit is None :
263- enable_jit = (
264- self .enable_jit if self .enable_jit is not None else config .ENABLE_JIT_DEFAULT
265- )
300+ enable_jit = self .compilation_options .enable_jit if enable_jit is None else enable_jit
266301
267302 with metrics .collect () as metrics_source :
268303 if collect_info_metrics := (config .COLLECT_METRICS_LEVEL >= metrics .INFO ):
@@ -302,50 +337,6 @@ def __call__(
302337 assert metrics_source is not None
303338 metrics_source .metrics [metrics .TOTAL_METRIC ].add_sample (time .time () - start )
304339
305- def compile (
306- self ,
307- offset_provider : common .OffsetProviderType
308- | common .OffsetProvider
309- | list [common .OffsetProviderType | common .OffsetProvider ]
310- | None = None ,
311- enable_jit : bool | None = None ,
312- ** static_args : list [xtyping .MaybeNestedInTuple [core_defs .Scalar ]],
313- ) -> Self :
314- """
315- Compiles the program for the given combination of static arguments and offset provider type.
316-
317- Note: Unlike `with_...` methods, this method does not return a new instance of the program,
318- but adds the compiled variants to the current program instance.
319- """
320- # TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
321- # rename to `with_static_args` or similar) once we have a better understanding of the
322- # use-cases.
323-
324- if enable_jit is not None :
325- object .__setattr__ (self , "enable_jit" , enable_jit )
326- if self .static_params is None :
327- object .__setattr__ (self , "static_params" , tuple (static_args .keys ()))
328- if self .connectivities is None and offset_provider is None :
329- raise ValueError (
330- "Cannot compile a program without connectivities / OffsetProviderType."
331- )
332- if not all (isinstance (v , list ) for v in static_args .values ()):
333- raise TypeError (
334- "Please provide the static arguments as lists."
335- ) # To avoid confusion with tuple args
336-
337- offset_provider = self .connectivities if offset_provider is None else offset_provider
338- if not isinstance (offset_provider , list ):
339- offset_provider = [offset_provider ] # type: ignore[list-item] # cleanup offset_provider vs offset_provider_type
340-
341- assert all (
342- common .is_offset_provider (op ) or common .is_offset_provider_type (op )
343- for op in offset_provider
344- )
345-
346- self ._compiled_programs .compile (offset_providers = offset_provider , ** static_args )
347- return self
348-
349340 def freeze (self ) -> FrozenProgram :
350341 if self .backend is None :
351342 raise ValueError ("Can not freeze a program without backend (embedded execution)." )
@@ -540,9 +531,8 @@ def program(
540531 # `NOTHING` -> default backend, `None` -> no backend (embedded execution)
541532 backend : next_backend .Backend | eve .NothingType | None = eve .NOTHING ,
542533 grid_type : common .GridType | None = None ,
543- enable_jit : bool | None = None , # only relevant if static_params are set
544- static_params : Sequence [str ] | None = None ,
545534 frozen : bool = False ,
535+ ** compilation_options : Unpack [options .CompilationOptionsArgs ],
546536) -> Program | FrozenProgram | Callable [[types .FunctionType ], Program | FrozenProgram ]:
547537 """
548538 Generate an implementation of a program from a Python function object.
@@ -569,8 +559,7 @@ def program_inner(definition: types.FunctionType) -> Program:
569559 next_backend .Backend | None , DEFAULT_BACKEND if backend is eve .NOTHING else backend
570560 ),
571561 grid_type = grid_type ,
572- enable_jit = enable_jit ,
573- static_params = static_params ,
562+ ** compilation_options ,
574563 )
575564 if frozen :
576565 return program .freeze () # type: ignore[return-value] # TODO(havogt): Should `FrozenProgram` be a `Program`?
@@ -583,7 +572,7 @@ def program_inner(definition: types.FunctionType) -> Program:
583572
584573
585574@dataclasses .dataclass (frozen = True )
586- class FieldOperator (GTCallable , Generic [OperatorNodeT ]):
575+ class FieldOperator (CompiledProgramMixin , GTCallable , Generic [OperatorNodeT ]):
587576 """
588577 Construct a field operator object from a FOAST node.
589578
@@ -606,6 +595,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
606595
607596 definition_stage : ffront_stages .FieldOperatorDefinition
608597 backend : Optional [next_backend .Backend ]
598+ compilation_options : options .CompilationOptions
609599 _program_cache : dict = dataclasses .field (
610600 init = False , default_factory = dict
611601 ) # init=False ensure the cache is not copied in calls to replace
@@ -619,6 +609,7 @@ def from_function(
619609 * ,
620610 operator_node_cls : type [OperatorNodeT ] = foast .FieldOperator , # type: ignore[assignment] # TODO(ricoh): understand why mypy complains
621611 operator_attributes : Optional [dict [str , Any ]] = None ,
612+ ** compilation_options : Unpack [options .CompilationOptionsArgs ],
622613 ) -> FieldOperator [OperatorNodeT ]:
623614 return cls (
624615 definition_stage = ffront_stages .FieldOperatorDefinition (
@@ -628,6 +619,7 @@ def from_function(
628619 attributes = operator_attributes or {},
629620 ),
630621 backend = backend ,
622+ compilation_options = options .CompilationOptions (** compilation_options ),
631623 )
632624
633625 # TODO(ricoh): linting should become optional, up to the backend.
@@ -701,12 +693,12 @@ def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program:
701693 definition_stage = None , # type: ignore[arg-type] # ProgramFromPast needs to be fixed
702694 past_stage = past_stage ,
703695 backend = self .backend ,
704- connectivities = None ,
705- enable_jit = False , # TODO(havogt): revisit ProgramFromPast
706- static_params = None , # TODO(havogt): revisit ProgramFromPast
696+ compilation_options = self .compilation_options ,
707697 )
708698
709- def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
699+ def __call__ (
700+ self , * args : Any , enable_jit : bool = options .CompilationOptions .enable_jit , ** kwargs : Any
701+ ) -> Any :
710702 if not next_embedded .context .within_valid_context () and self .backend is not None :
711703 # non embedded execution
712704 offset_provider = {** kwargs .pop ("offset_provider" , {})}
@@ -719,15 +711,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
719711 domain = utils .tree_map (lambda _ : domain )(out )
720712 out = utils .tree_map (lambda f , dom : f [dom ])(out , domain )
721713
722- args , kwargs = type_info .canonicalize_arguments (
723- self .foast_stage .foast_node .type , args , kwargs
724- )
725- return self .backend (
726- self .definition_stage ,
727- * args ,
728- out = out ,
729- offset_provider = offset_provider ,
730- ** kwargs ,
714+ return self ._compiled_programs (
715+ * args , ** kwargs , out = out , offset_provider = offset_provider , enable_jit = enable_jit
731716 )
732717 else :
733718 if not next_embedded .context .within_valid_context ():
@@ -793,6 +778,7 @@ def field_operator(
793778 * ,
794779 backend : next_backend .Backend | eve .NothingType | None = eve .NOTHING ,
795780 grid_type : common .GridType | None = None ,
781+ ** compilation_options : Unpack [options .CompilationOptionsArgs ],
796782) -> (
797783 FieldOperator [foast .FieldOperator ]
798784 | Callable [[types .FunctionType ], FieldOperator [foast .FieldOperator ]]
@@ -820,6 +806,7 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.
820806 next_backend .Backend | None , DEFAULT_BACKEND if backend is eve .NOTHING else backend
821807 ),
822808 grid_type ,
809+ ** compilation_options
823810 )
824811
825812 return field_operator_inner if definition is None else field_operator_inner (definition )
0 commit comments