Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Mar 7, 2025
1 parent 33d3e5b commit 189017e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
18 changes: 16 additions & 2 deletions python/pyspark/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,23 @@ def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) ->
params = dict()
if isinstance(params, dict):
if params:
return self.copy(params)._transform(dataset)
transformed = self.copy(params)._transform(dataset)
else:
return self._transform(dataset)
transformed = self._transform(dataset)

from pyspark.sql.utils import is_remote

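# Keep a reference to the source transformer on the client side, so that the model
# is not deleted from the remote cache while the transformed DataFrame is still in
# use, for this case: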
#
# def fit_transform(df):
# model = estimator.fit(df)
# return model.transform(df)
#
# output = fit_transform(df)
if is_remote():
transformed.__source_transformer__ = self

return transformed
else:
raise TypeError("Params must be a param map but got %s." % type(params))

Expand Down
14 changes: 2 additions & 12 deletions python/pyspark/ml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,18 +283,8 @@ def wrapped(self: "JavaWrapper") -> Any:

if in_remote:
# Delete the model if possible
# model_id = self._java_obj
# del_remote_cache(model_id)
#
# The code above deletes the model from the ML cache eagerly, which may cause an
# NPE on the server side in the 'fit_transform' case:
#
# def fit_transform(df):
# model = estimator.fit(df)
# return model.transform(df)
#
# output = fit_transform(df)
# output.show()
model_id = self._java_obj
del_remote_cache(model_id)
return
else:
return f(self)
Expand Down

0 comments on commit 189017e

Please sign in to comment.