Changes from all commits · 25 commits
4a78e24
Start working on getting the AST & code over into the JVM so we can d…
sfc-gh-hkarau Dec 1, 2025
51d53c4
Fix some python UDF parsing stuff
sfc-gh-hkarau Dec 1, 2025
cc5201b
Try and debug some of the JVM side UDF magic.
sfc-gh-hkarau Dec 1, 2025
039e99f
Do a bit of work to try and manipulate the AST scala side
sfc-gh-hkarau Dec 1, 2025
94bebb6
Beginning of transpilation but by doing it in the optimizer we make the…
sfc-gh-hkarau Dec 2, 2025
60fd06b
Don't call repr on the nodes, then all of our built in types become s…
sfc-gh-hkarau Dec 4, 2025
c2c84f2
Some hacks, we should explore using resolve in Expression resolver to…
sfc-gh-hkarau Dec 4, 2025
47246b5
Add type upgrading.
sfc-gh-hkarau Dec 4, 2025
d811152
Compiles! I _think_ mapExpressions will recurse down everything which…
sfc-gh-hkarau Dec 4, 2025
82a475c
Try and fix UDF chain detection, and don't use snake case inside of S…
sfc-gh-hkarau Dec 5, 2025
755d1b6
Add UDFTypeCoercesExpressionTypes, this is just a concrete implementa…
sfc-gh-hkarau Dec 5, 2025
efd9bd2
Add a flag to enable and disable transpilation
sfc-gh-hkarau Dec 5, 2025
cff48a0
Fix test compilation errors, re-add default None
sfc-gh-hkarau Dec 5, 2025
a25e5fd
ConvertToCatalystSuite generated by Claude Opus 4.5
sfc-gh-hkarau Dec 5, 2025
4a5d7cb
Did the children UDF check in reverse (oops)
sfc-gh-hkarau Dec 5, 2025
9e7633c
Fix the tree expression checks in the convert to catalyst suite
sfc-gh-hkarau Dec 5, 2025
f7fca5a
Add failing python side test for now
sfc-gh-hkarau Dec 5, 2025
0310b29
Run reformat python
sfc-gh-hkarau Dec 5, 2025
7ade88d
Merge branch 'master' into explore-transpilation-yet-again-SPARK-1408…
sfc-gh-hkarau Dec 5, 2025
e270433
Use assertIn to make sure we grab the core of the lambda.
sfc-gh-hkarau Dec 5, 2025
f9a84bb
Move the AST conversion magic into the base class, start adding a way…
sfc-gh-hkarau Dec 5, 2025
07c498a
Cleanup snake case in Scala
sfc-gh-hkarau Dec 5, 2025
caff1eb
Add pureCatalystExpression as an option for constructing UserDefinedP…
sfc-gh-hkarau Dec 5, 2025
ee7a50f
Fix style issues, add required overrides
sfc-gh-hkarau Dec 5, 2025
facade3
Add py UDF starting to hack on
sfc-gh-hkarau Dec 5, 2025
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/test_udf.py
@@ -72,6 +72,7 @@ def __call__(self, col):
        pudf = UserDefinedFunction(call, LongType())
        res = data.select(pudf(data["number"]).alias("plus_four"))
        self.assertEqual(res.agg({"plus_four": "sum"}).collect()[0][0], 85)
        # The JVM-side UDF now carries the captured Python source; the body of
        # __call__ ("col + 4") should appear in it.
        self.assertIn("col + 4", pudf._judf.src())

    def test_udf_with_partial_function(self):
        data = self.spark.createDataFrame([(i, i**2) for i in range(10)], ["number", "squared"])
151 changes: 144 additions & 7 deletions python/pyspark/sql/udf.py
@@ -17,6 +17,7 @@
"""
User-defined function related classes and functions
"""
import ast
import functools
import inspect
import sys
@@ -69,6 +70,40 @@ def _wrap_function(
    )


def _dump_to_tree(node):
    """
    Return a formatted dump of the tree in node. This is based on
    Lib/ast.py from the standard library, but modified to return
    basic types for sending over to the JVM side.
    """

    def _format(node, level=0):
        if isinstance(node, ast.AST):
            cls = type(node)
            args = []
            for name in node._fields:
                try:
                    value = getattr(node, name)
                except AttributeError:
                    continue
                # Skip fields whose value is None by default (as ast.dump does).
                if value is None and getattr(cls, name, ...) is None:
                    continue
                value, _ = _format(value, level)
                args.append((name, value))
            return (node.__class__.__name__, args), False
        elif isinstance(node, list):
            if not node:
                return [], True
            return [_format(x, level)[0] for x in node], False
        return node, True

    if not isinstance(node, ast.AST):
        raise TypeError("expected AST, got %r" % node.__class__.__name__)
    return _format(node)[0]
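# Illustrative sketch of the output shape (not part of the patch; the exact
# nesting varies across Python versions):
#
#   >>> _dump_to_tree(ast.parse("f = lambda x: x + 4"))
#   ('Module', [('body', [('Assign', [('targets', [...]),
#                                     ('value', ('Lambda', [...]))])]),
#               ('type_ignores', [])])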


def _create_udf(
    f: Callable[..., Any],
    returnType: "DataTypeOrString",
@@ -78,6 +113,7 @@ def _create_udf(
) -> "UserDefinedFunctionLike":
    """Create a regular (non-Arrow-optimized) Python UDF."""
    # Set the name of the UserDefinedFunction object to be the name of function f
    # Possible TODO: add a heuristic for whether we even bother transpiling.
    udf_obj = UserDefinedFunction(
        f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
    )
@@ -165,12 +201,6 @@ def __init__(
        evalType: int = PythonEvalType.SQL_BATCHED_UDF,
        deterministic: bool = True,
    ):
        if not isinstance(returnType, (DataType, str)):
            raise PySparkTypeError(
                errorClass="NOT_DATATYPE_OR_STR",
@@ -196,7 +226,108 @@
            )
        self.evalType = evalType
        self.deterministic = deterministic
        # Make sure the function is callable first.
        if not callable(func):
            raise PySparkTypeError(
                errorClass="NOT_CALLABLE",
                messageParameters={"arg_name": "func", "arg_type": type(func).__name__},
            )

        # Extract Python UDF details if transpilation is enabled.
        ast_info = None
        ast_dumped = None
        src = None
        transpiled = None
        from pyspark.sql import SparkSession

        session = SparkSession._instantiatedSession
        transpile_enabled = (
            False
            if session is None
            else session.conf.get("spark.sql.optimizer.transpilePyUDFS") == "true"
        )
        if transpile_enabled:
            try:
                # Note: consider maybe dill? (see the JYTHON PR)
                # inspect.getsource does not work for functions defined in the
                # vanilla REPL, but does for those in files or in IPython.
                # It also fails when given an instance of a callable class, so
                # fall back to the source of __call__.
                try:
                    src = inspect.getsource(func)
                except Exception:
                    src = inspect.getsource(func.__call__)
                ast_info = ast.parse(src)
                transpiled = self._transpile(src, ast_info)
                ast_dumped = _dump_to_tree(ast_info)
            except Exception as e:
                warnings.warn(f"Error building AST for UDF: {e} -- will not transpile")
        self.src = src
        self.ast_dumped = ast_dumped
        self.transpiled = transpiled
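
        # Illustrative usage sketch (the config name comes from this patch; a
        # SparkSession must already exist when the UDF is constructed):
        #
        #   spark.conf.set("spark.sql.optimizer.transpilePyUDFS", "true")
        #   plus_four = UserDefinedFunction(lambda x: x + 4, LongType())
        #   plus_four.src         # captured source text, or None on failure
        #   plus_four.transpiled  # Column tree, or None if not transpilable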

    # Transpiling tools
    @staticmethod
    def _transpile(src: str, ast_info: ast.AST) -> Optional[Column]:
        # Short circuit on nothing to transpile.
        if src == "" or ast_info is None:
            return None
        lambda_ast = UserDefinedFunction._get_lambda_from_ast(ast_info)
        if lambda_ast is None:
            return None
        lambda_body = lambda_ast.body
        params = UserDefinedFunction._get_parameter_list(lambda_ast)
        return UserDefinedFunction._convert_function(params, lambda_body)
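
    # Expected behavior sketch (hypothetical until ParamIndexNode lands): for
    # src = 'f = lambda x: x + 4' and ast_info = ast.parse(src),
    # _transpile(src, ast_info) should yield the equivalent of
    # ParamIndexNode(0).add(Column._literal(4)).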

    @staticmethod
    def _convert_function(params: List[str], body: ast.AST) -> Optional[Column]:
        match body:
            case ast.BinOp(left=left, op=op, right=right):
                match op:
                    case ast.Add():
                        left_col = UserDefinedFunction._convert_function(params, left)
                        if left_col is None:
                            return None
                        right_col = UserDefinedFunction._convert_function(params, right)
                        if right_col is None:
                            return None
                        return left_col.add(right_col)
                    case _:
                        return None
            case ast.Constant(value=value):
                return Column._literal(value)
            case ast.Name(id=name, ctx=ast.Load()):
                # Note: the Python UDF parameter name might not match the column,
                # and at this point we don't know what our children will be.
                if name in params:
                    param_index = params.index(name)
                    # TODO: add a special node here that indicates we want child
                    # number param_index; ParamIndexNode is still a placeholder.
                    return ParamIndexNode(param_index)
            case _:
                return None
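        # Sketch of the recursion: a BinOp body such as ``x + 4`` converts its
        # operands first (Name -> ParamIndexNode, Constant -> Column._literal)
        # and combines them via Column.add; any unsupported node returns None,
        # so the UDF falls back to the regular (non-transpiled) path.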



    @staticmethod
    def _get_parameter_list(lambdaAst: ast.Lambda) -> List[str]:
        # e.g. for ``lambda x, y: x + y`` this returns ["x", "y"].
        params = []
        for arg in lambdaAst.args.args:
            params.append(arg.arg)
        return params

    @staticmethod
    def _get_lambda_from_ast(tree: ast.AST) -> Optional[ast.Lambda]:
        # Expect a module whose first statement assigns a lambda directly
        # (f = lambda x: ...) or passes one to a call (f = udf(lambda x: ...)).
        if not isinstance(tree, ast.Module) or not tree.body:
            return None
        stmt = tree.body[0]
        if not isinstance(stmt, ast.Assign):
            return None
        assigned = stmt.value
        if isinstance(assigned, ast.Lambda):
            return assigned
        if isinstance(assigned, ast.Call):
            for arg in assigned.args:
                if isinstance(arg, ast.Lambda):
                    return arg
        return None

    # Everything else
    @staticmethod
    def _check_return_type(returnType: DataType, evalType: int) -> None:
        if evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
@@ -413,7 +544,13 @@ def _create_judf(self, func: Callable[..., Any]) -> "JavaObject":
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        assert sc._jvm is not None
        judf = getattr(sc._jvm, "org.apache.spark.sql.execution.python.UserDefinedPythonFunction")(
            self._name,
            wrapped_func,
            jdt,
            self.evalType,
            self.deterministic,
            # Forward the captured source and dumped AST so the JVM side can
            # attempt transpilation; both may be None.
            self.src,
            self.ast_dumped,
        )
        return judf

@@ -124,7 +124,8 @@ trait CoercesExpressionTypes extends SQLConfHelper {
    withTypeCoercion.withNewChildren(newChildren)
  }

  protected[spark] def runCoercionTransformations(
      expression: Expression,
      ansiMode: Boolean): Expression = {
    val transformations = if (ansiMode) {
      ansiTransformations
    } else {
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis.resolver

/**
 * A concrete implementation of [[CoercesExpressionTypes]], exposing the
 * trait's coercion machinery for the Python UDF transpilation path.
 */
class UDFTypeCoercesExpressionTypes extends CoercesExpressionTypes {
}
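
A minimal sketch of how the new class and the now protected[spark] method might
be exercised together (hypothetical usage: childExpr stands in for a transpiled
UDF child Expression; only runCoercionTransformations comes from this patch):

    import org.apache.spark.sql.internal.SQLConf

    // Run the trait's coercion transformations on a child expression,
    // honoring the session's ANSI mode.
    val coercer = new UDFTypeCoercesExpressionTypes
    val coerced = coercer.runCoercionTransformations(childExpr, SQLConf.get.ansiEnabled)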