Skip to content

Commit 54d3915

Browse files
committed
Add py UDF starting to hack on
1 parent 6f05ad5 commit 54d3915

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

python/pyspark/sql/udf.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def __init__(
237237
ast_info = None
238238
ast_dumped = None
239239
src = None
240+
transpiled = None
240241
from pyspark.sql import SparkSession
241242

242243
session = SparkSession._instantiatedSession
@@ -256,12 +257,77 @@ def __init__(
256257
except Exception:
257258
src = inspect.getsource(func.__call__)
258259
ast_info = ast.parse(src)
260+
transpiled = _transpile(src, ast_info)
259261
ast_dumped = _dump_to_tree(ast_info)
260262
except Exception as e:
261263
warnings.warn(f"Error building AST for UDF: {e} -- will not transpile")
262264
self.src = src
263265
self.ast_dumped = ast_dumped
266+
self.transpiled = transpiled
264267

268+
# Transpiling tools
269+
@staticmethod
def _transpile(src: str, ast_info: ast.AST) -> Optional[Column]:
    """Try to turn a UDF's Python source into an equivalent Column expression.

    Parameters
    ----------
    src
        Source text of the UDF. An empty string or ``None`` means there is
        nothing to transpile.
    ast_info
        Parsed AST of ``src``; ``None`` means parsing did not happen.

    Returns
    -------
    The value produced by ``_convert_function`` for the lambda found in
    ``ast_info``, or ``None`` when there is no source, no AST, or no lambda.
    """
    # Short circuit on nothing to transpile. A falsy check is used (rather
    # than ``src == ""``) so that a ``None`` source is rejected as well.
    if not src or ast_info is None:
        return None
    lambda_ast = _get_lambda_from_ast(ast_info)
    if lambda_ast is None:
        return None
    params = _get_parameter_list(lambda_ast)
    return _convert_function(params, lambda_ast.body)
280+
281+
def _convert_function(params: List[str], body: ast.AST) -> Optional[Column]:
282+
match body:
283+
case ast.BinOp(left=left, op=op, right=right):
284+
match op:
285+
case ast.Add():
286+
left_col = _convert_function(params, left)
287+
if left_col is None:
288+
return
289+
right_col = _convert_function(params, right)
290+
if right_col is None:
291+
return
292+
return left_col.add(right_col)
293+
case _:
294+
return
295+
case ast.Constant(value=value):
296+
return Column._literal(value)
297+
case ast.Name(id=name, ctx=ast.Load()):
298+
# Note: the Python UDF parameter name might not match the column
299+
# And at this point we don't know who are children are going to be.
300+
if name in params:
301+
param_index = params.index(name)
302+
# TODO: Add a special node here that indicates we want child number param_index
303+
return ParamIndexNode(param_index)
304+
case _:
305+
return
306+
307+
308+
309+
@staticmethod
def _get_parameter_list(lambdaAst: ast.Lambda) -> List[str]:
    """Return the names of the lambda's regular parameters, in order.

    Only ``lambdaAst.args.args`` is read, so positional-only parameters,
    ``*args``, keyword-only parameters and ``**kwargs`` are not included.
    """
    return [arg.arg for arg in lambdaAst.args.args]
315+
316+
@staticmethod
def _get_lambda_from_ast(ast: ast.AST) -> Optional[ast.Lambda]:
    """Find the lambda expression assigned in a parsed module.

    Two shapes are recognised for an assignment's right-hand side:

    * a bare lambda:            ``f = lambda x: x + 1``
    * a call wrapping a lambda: ``f = wrap(lambda x: x + 1)``

    Parameters
    ----------
    ast
        Parsed module to search.

    Returns
    -------
    The first matching ``ast.Lambda`` node, or ``None`` if ``ast`` is not a
    module or contains no such assignment.
    """
    # The parameter is named ``ast`` and therefore hides the ``ast`` module
    # inside this body; import the module under a different local name.
    import ast as ast_module

    if not isinstance(ast, ast_module.Module):
        return None
    for statement in ast.body:
        if not isinstance(statement, ast_module.Assign):
            continue
        assigned = statement.value
        if isinstance(assigned, ast_module.Lambda):
            return assigned
        if isinstance(assigned, ast_module.Call):
            # Look for a lambda passed as a positional argument of the call.
            for argument in assigned.args:
                if isinstance(argument, ast_module.Lambda):
                    return argument
    return None
325+
326+
# NOTE: a second, body-less `_convert_function` signature used to sit here.
# It was not valid Python (no trailing colon, no body) and reused the name of
# the implementation defined earlier in this file, so it has been dropped.
328+
329+
330+
# Everything else
265331
@staticmethod
266332
def _check_return_type(returnType: DataType, evalType: int) -> None:
267333
if evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:

0 commit comments

Comments
 (0)