[WIP][ML][CONNECT] ML transformed dataframe keep a reference to the model #50199

Draft · wants to merge 2 commits into master
20 changes: 18 additions & 2 deletions python/pyspark/ml/base.py
@@ -255,9 +255,25 @@ def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) ->
            params = dict()
        if isinstance(params, dict):
            if params:
-                return self.copy(params)._transform(dataset)
+                transformed = self.copy(params)._transform(dataset)
            else:
-                return self._transform(dataset)
+                transformed = self._transform(dataset)
+
+            from pyspark.sql.utils import is_remote
+
+            # Keep a reference to the source transformer on the client side, for this case:
+            #
+            # def fit_transform(df):
+            #     model = estimator.fit(df)
+            #     return model.transform(df)
+            #
+            # output = fit_transform(df)
+            if is_remote():
+                # Attach the source transformer to the internal plan,
+                # so that all descendant plans also keep it.
+                transformed._plan.__source_transformer__ = self  # type: ignore[attr-defined]
+
+            return transformed
        else:
            raise TypeError("Params must be a param map but got %s." % type(params))

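Note on the hunk above: in Spark Connect, the fitted model lives in a server-side ML cache that is released once the client-side model object is garbage-collected (see the `util.py` hunk below). Because the transformed DataFrame is lazy, the model can go out of scope before the DataFrame is ever evaluated, leaving the plan pointing at an evicted model. Pinning `self` onto the plan ties the model's lifetime to every DataFrame derived from it. A minimal sketch of the lifetime problem in plain Python — `Model`, `Plan`, and the function names are illustrative stand-ins, not pyspark APIs:

```python
class Model:
    def __init__(self, model_id: str) -> None:
        self.model_id = model_id

    def __del__(self) -> None:
        # Stand-in for the server-side eviction in util.py below: the
        # cache entry vanishes as soon as the client object is collected.
        print(f"releasing server-side cache entry {self.model_id}")


class Plan:
    """A lazy plan that records only the id of the model to apply."""

    def __init__(self, model_id: str) -> None:
        self.model_id = model_id


def fit_transform_unpinned() -> Plan:
    model = Model("model-1")
    return Plan(model.model_id)  # `model` is collected on return (CPython)


def fit_transform_pinned() -> Plan:
    model = Model("model-2")
    plan = Plan(model.model_id)
    # The PR's fix in miniature: the plan keeps a strong reference, so
    # the model lives for as long as any plan derived from it.
    plan.__source_transformer__ = model  # type: ignore[attr-defined]
    return plan


plan1 = fit_transform_unpinned()  # prints: releasing ... model-1
plan2 = fit_transform_pinned()    # model-2 stays alive alongside plan2
```

On CPython the unpinned model is collected the moment `fit_transform_unpinned` returns, which is exactly the `fit_transform` pattern the code comment in the hunk calls out.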
24 changes: 24 additions & 0 deletions python/pyspark/ml/tests/test_pipeline.py
@@ -192,6 +192,30 @@ def fit_transform(df):
        output = fit_transform(df)
        self.assertEqual(output.count(), 3)

+    def test_model_gc_II(self):
+        spark = self.spark
+        df1 = spark.createDataFrame(
+            [
+                Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+                Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+                Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+            ]
+        )
+
+        df2 = spark.range(10)
+
+        def fit_transform(df):
+            lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+            model = lr.fit(df)
+            return model.transform(df)
+
+        def fit_transform_and_union(df1, df2):
+            output1 = fit_transform(df1)
+            return output1.unionByName(df2, True)
+
+        output = fit_transform_and_union(df1, df2)
+        self.assertEqual(output.count(), 13)
+

class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
    pass
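The union in `test_model_gc_II` is what exercises the "descendant plans" claim: `unionByName` builds a new DataFrame whose plan references the transformed plan as a child, so the model must stay reachable through that chain (3 rows from df1 plus 10 from df2 gives the expected count of 13). A toy sketch of the transitive reference chain — the `Plan` class here is illustrative, not the Connect client's plan type:

```python
class Plan:
    """Toy plan node: a parent plan holds its children, so the chain
    union_plan -> transformed_plan -> model stays reachable."""

    def __init__(self, *children: "Plan") -> None:
        self.children = children


transformed_plan = Plan()
transformed_plan.__source_transformer__ = object()  # stands in for the model
union_plan = Plan(transformed_plan)  # what unionByName builds, conceptually

# As long as the union result is alive, so is the pinned model:
assert union_plan.children[0].__source_transformer__ is not None
```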
14 changes: 2 additions & 12 deletions python/pyspark/ml/util.py
@@ -283,18 +283,8 @@ def wrapped(self: "JavaWrapper") -> Any:

        if in_remote:
            # Delete the model if possible
-            # model_id = self._java_obj
-            # del_remote_cache(model_id)
-            #
-            # Above codes delete the model from the ml cache eagerly, and may cause
-            # NPE in the server side in the case of 'fit_transform':
-            #
-            # def fit_transform(df):
-            #     model = estimator.fit(df)
-            #     return model.transform(df)
-            #
-            # output = fit_transform(df)
-            # output.show()
+            model_id = self._java_obj
+            del_remote_cache(cast(str, model_id))
            return
        else:
            return f(self)
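With the plan now pinning the model, the eager eviction that the deleted comment warned about is safe again, which is why this hunk restores it. For context, a condensed, self-contained sketch of the decorator that owns `wrapped` — `is_remote` and `del_remote_cache` are stubbed out here; the real helpers live in `pyspark.sql.utils` and `pyspark.ml.util`, and this is a simplified reconstruction, not the exact source:

```python
import functools
from typing import Any, Callable, cast


# Stubs standing in for the real pyspark helpers used in the hunk above;
# the real del_remote_cache asks the Connect server to evict the model
# from its ML cache.
def is_remote() -> bool:
    return True


def del_remote_cache(model_id: str) -> None:
    print(f"evicting {model_id} from the server-side ML cache")


def try_remote_del(f: Callable[..., Any]) -> Callable[..., Any]:
    """Condensed sketch of the decorator wrapping a model's __del__."""

    @functools.wraps(f)
    def wrapped(self: Any) -> Any:
        if is_remote():
            # Eager release: safe now that any live transformed DataFrame
            # pins the model via __source_transformer__ on its plan.
            model_id = self._java_obj
            del_remote_cache(cast(str, model_id))
            return
        else:
            return f(self)  # classic JVM-side cleanup path

    return wrapped


class FakeModel:
    _java_obj = "model-xyz"

    @try_remote_del
    def __del__(self) -> None:
        pass  # the JVM-side cleanup would run here


m = FakeModel()
del m  # prints: evicting model-xyz from the server-side ML cache
```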