33import zipfile
44from dataclasses import dataclass
55from enum import Enum
6- from typing import AbstractSet , Any , ClassVar , Dict , List , Mapping , Optional , Tuple , TypeVar , Union
6+ from typing import AbstractSet , Any , ClassVar , Dict , List , Mapping , Optional , Set , Tuple , TypeVar , Union
77
88import orjson
99
@@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation:
6868 raise ValueError (f'Hail requires either { hail_jar } or { hail_all_spark_jar } .' )
6969
7070
class IRFunction:
    """A named IR function whose body is rendered to a string on construction.

    The constructor stores the signature pieces as given and renders the
    body's IR once; ``to_dataclass`` converts the result into a
    ``SerializedIRFunction`` made only of strings and lists of strings.
    """

    def __init__(
        self,
        name: str,
        type_parameters: Union[Tuple[HailType, ...], List[HailType]],
        value_parameter_names: Union[Tuple[str, ...], List[str]],
        value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
        return_type: HailType,
        body: Expression,
    ):
        """Record the function signature and render its body.

        ``value_parameter_names`` and ``value_parameter_types`` must have the
        same length; this is checked with an ``assert``.
        """
        # Each value parameter name must be paired with exactly one type.
        assert len(value_parameter_names) == len(value_parameter_types)

        renderer = CSERenderer()

        self._name = name
        self._type_parameters = type_parameters
        self._value_parameter_names = value_parameter_names
        self._value_parameter_types = value_parameter_types
        self._return_type = return_type

        # The body is passed through finalize_randomness before rendering.
        finalized_ir = finalize_randomness(body._ir)
        self._rendered_body = renderer(finalized_ir)

    def to_dataclass(self):
        """Return a ``SerializedIRFunction`` describing this function.

        Types are converted with their ``_parsable_string()`` method and the
        parameter names are copied into a new list.
        """
        type_parameter_strs = [tp._parsable_string() for tp in self._type_parameters]
        parameter_names = list(self._value_parameter_names)
        parameter_type_strs = [vpt._parsable_string() for vpt in self._value_parameter_types]
        return_type_str = self._return_type._parsable_string()

        return SerializedIRFunction(
            name=self._name,
            type_parameters=type_parameter_strs,
            value_parameter_names=parameter_names,
            value_parameter_types=parameter_type_strs,
            return_type=return_type_str,
            rendered_body=self._rendered_body,
        )
99+
100+
class ActionTag(Enum):
    """Integer tags naming the actions a backend can request.

    A tag is passed to the backend's ``_rpc`` call together with a payload
    (``execute`` sends ``EXECUTE``).
    """

    # NOTE(review): the numeric values are presumably interpreted by the
    # receiving side of the RPC, so they are spelled out explicitly rather
    # than generated -- confirm before renumbering.
    VALUE_TYPE = 1
    TABLE_TYPE = 2
    MATRIX_TABLE_TYPE = 3
    BLOCK_MATRIX_TYPE = 4
    EXECUTE = 5
    PARSE_VCF_METADATA = 6
    IMPORT_FAM = 7
    LOAD_REFERENCES_FROM_DATASET = 8
    FROM_FASTA_FILE = 9
81111
82112
@@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload):
90120 ir : str
91121
92122
@dataclass
class SerializedIRFunction:
    """String-only description of an IR function.

    Instances are produced by ``IRFunction.to_dataclass`` and carried in
    ``ExecutePayload.fns``.
    """

    # Name of the function.
    name: str
    # One string per type parameter.
    type_parameters: List[str]
    # Names of the value parameters.
    value_parameter_names: List[str]
    # One type string per value parameter.
    value_parameter_types: List[str]
    # Return type as a string.
    return_type: str
    # Function body as a rendered string.
    rendered_body: str
132+
@dataclass
class ExecutePayload(ActionPayload):
    """Payload built by ``execute`` and sent with ``ActionTag.EXECUTE``."""

    # Rendered IR to execute.
    ir: str
    # Serialized IR functions sent along with the IR.
    fns: List[SerializedIRFunction]
    # Codec spec string; ``execute`` passes '{"name":"StreamBufferSpec"}'.
    stream_codec: str
98138
99139
100140@dataclass
@@ -164,17 +204,24 @@ def _valid_flags(self) -> AbstractSet[str]:
def __init__(self):
    """Start with no persisted locations, references or IR functions."""
    self._persisted_locations = {}
    self._references = {}
    # IR functions registered on this backend; ``execute`` serializes all of
    # them into every payload.
    self.functions: List[IRFunction] = []
    # Names of the registered functions, kept as a set for membership tests.
    self._registered_ir_function_names: Set[str] = set()
167209
@abc.abstractmethod
def validate_file(self, uri: str):
    """Abstract hook taking a file ``uri``; this base version always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
171213
@abc.abstractmethod
def stop(self):
    """Forget every registered IR function.

    Both attributes are rebound to fresh empty containers rather than
    cleared in place, so previously handed-out references are unaffected.

    NOTE(review): the method is abstract yet has a body -- presumably
    subclasses are expected to call it via ``super().stop()``; confirm.
    """
    self._registered_ir_function_names = set()
    self.functions = []
175218
176219 def execute (self , ir : BaseIR , timed : bool = False ) -> Any :
177- payload = ExecutePayload (self ._render_ir (ir ), '{"name":"StreamBufferSpec"}' , timed )
220+ payload = ExecutePayload (
221+ self ._render_ir (ir ),
222+ fns = [fn .to_dataclass () for fn in self .functions ],
223+ stream_codec = '{"name":"StreamBufferSpec"}' ,
224+ )
178225 try :
179226 result , timings = self ._rpc (ActionTag .EXECUTE , payload )
180227 except FatalError as e :
@@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset:
300347 tempfile .__exit__ (None , None , None )
301348 return unpersisted
302349
303- @abc .abstractmethod
304350 def register_ir_function (
305351 self ,
306352 name : str ,
@@ -310,11 +356,13 @@ def register_ir_function(
310356 return_type : HailType ,
311357 body : Expression ,
312358 ):
313- pass
359+ self ._registered_ir_function_names .add (name )
360+ self .functions .append (
361+ IRFunction (name , type_parameters , value_parameter_names , value_parameter_types , return_type , body )
362+ )
314363
315- @abc .abstractmethod
316364 def _is_registered_ir_function_name (self , name : str ) -> bool :
317- pass
365+ return name in self . _registered_ir_function_names
318366
319367 @abc .abstractmethod
320368 def persist_expression (self , expr : Expression ) -> Expression :
0 commit comments