@@ -237,6 +237,7 @@ def __init__(
237237 ast_info = None
238238 ast_dumped = None
239239 src = None
240+ transpiled = None
240241 from pyspark .sql import SparkSession
241242
242243 session = SparkSession ._instantiatedSession
@@ -256,12 +257,77 @@ def __init__(
256257 except Exception :
257258 src = inspect .getsource (func .__call__ )
258259 ast_info = ast .parse (src )
260+ transpiled = _transpile (src , ast_info )
259261 ast_dumped = _dump_to_tree (ast_info )
260262 except Exception as e :
261263 warnings .warn (f"Error building AST for UDF: { e } -- will not transpile" )
262264 self .src = src
263265 self .ast_dumped = ast_dumped
266+ self .transpiled = transpiled
264267
268+ # Transpiling tools
269+ @staticmethod
270+ def _transpile (src : str , ast_info : ast .AST ) -> Optional [Column ]:
271+ # Short circuit on nothing to transpile.
272+ if src == "" or ast_info is None :
273+ return None
274+ lambda_ast = _get_lambda_from_ast (ast_info )
275+ if lambda_ast is None :
276+ return None
277+ lambda_body = lambda_ast .body
278+ params = _get_parameter_list (lambda_ast )
279+ return _convert_function (params , lambda_body )
280+
281+ def _convert_function (params : List [str ], body : ast .AST ) -> Optional [Column ]:
282+ match body :
283+ case ast .BinOp (left = left , op = op , right = right ):
284+ match op :
285+ case ast .Add ():
286+ left_col = _convert_function (params , left )
287+ if left_col is None :
288+ return
289+ right_col = _convert_function (params , right )
290+ if right_col is None :
291+ return
292+ return left_col .add (right_col )
293+ case _:
294+ return
295+ case ast .Constant (value = value ):
296+ return Column ._literal (value )
297+ case ast .Name (id = name , ctx = ast .Load ()):
298+ # Note: the Python UDF parameter name might not match the column
299+ # And at this point we don't know who are children are going to be.
300+ if name in params :
301+ param_index = params .index (name )
302+ # TODO: Add a special node here that indicates we want child number param_index
303+ return ParamIndexNode (param_index )
304+ case _:
305+ return
306+
307+
308+
309+ @staticmethod
310+ def _get_parameter_list (lambdaAst : ast .Lambda ) -> List [str ]:
311+ params = []
312+ for arg in lambdaAst .args .args :
313+ params .append (arg .arg )
314+ return params
315+
316+ @staticmethod
317+ def _get_lambda_from_ast (ast : ast .AST ) -> Optional [ast .Lambda ]:
318+ module = ast .Module
319+ module_body = module .body
320+ assigned = module_body .Assign .value
321+ if isinstance (assigned , ast .Lambda ):
322+ return assigned
323+ else :
324+ return assigned .Call .func .args .Lambda
325+
326+ @staticmethod
327+ def _convert_function (params : List [str ], body : ast .AST ) -> Optional [Column ]
328+
329+
330+ # Everything else
265331 @staticmethod
266332 def _check_return_type (returnType : DataType , evalType : int ) -> None :
267333 if evalType == PythonEvalType .SQL_ARROW_BATCHED_UDF :
0 commit comments