diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 224ef34fd5edc..bc57dbc87356c 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -255,9 +255,23 @@ def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) ->
             params = dict()
         if isinstance(params, dict):
             if params:
-                return self.copy(params)._transform(dataset)
+                transformed = self.copy(params)._transform(dataset)
             else:
-                return self._transform(dataset)
+                transformed = self._transform(dataset)
+
+            from pyspark.sql.utils import is_remote
+
+            # Keep a reference to the source transformer in the client side, for this case:
+            #
+            # def fit_transform(df):
+            #     model = estimator.fit(df)
+            #     return model.transform(df)
+            #
+            # output = fit_transform(df)
+            if is_remote():
+                transformed.__source_transformer__ = self
+
+            return transformed
         else:
             raise TypeError("Params must be a param map but got %s." % type(params))
 
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6b3d6101c249f..19a1e9903a525 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -283,18 +283,8 @@ def wrapped(self: "JavaWrapper") -> Any:
 
         if in_remote:
             # Delete the model if possible
-            # model_id = self._java_obj
-            # del_remote_cache(model_id)
-            #
-            # Above codes delete the model from the ml cache eagerly, and may cause
-            # NPE in the server side in the case of 'fit_transform':
-            #
-            # def fit_transform(df):
-            #     model = estimator.fit(df)
-            #     return model.transform(df)
-            #
-            # output = fit_transform(df)
-            # output.show()
+            model_id = self._java_obj
+            del_remote_cache(model_id)
             return
         else:
             return f(self)