diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index bc57dbc87356c..87b61a8a7d52e 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -269,7 +269,9 @@ def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) ->
         #
         # output = fit_transform(df)
         if is_remote():
-            transformed.__source_transformer__ = self
+            # attach the source transformer to the internal plan,
+            # so that all descendant plans also keep it.
+            transformed._plan.__source_transformer__ = self  # type: ignore[attr-defined]
             return transformed
 
         else:
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index ced1cda1948a6..1df37bdfa7fa4 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -192,6 +192,30 @@ def fit_transform(df):
         output = fit_transform(df)
         self.assertEqual(output.count(), 3)
 
+    def test_model_gc_II(self):
+        spark = self.spark
+        df1 = spark.createDataFrame(
+            [
+                Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+                Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+                Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+            ]
+        )
+
+        df2 = spark.range(10)
+
+        def fit_transform(df):
+            lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
+            model = lr.fit(df)
+            return model.transform(df)
+
+        def fit_transform_and_union(df1, df2):
+            output1 = fit_transform(df1)
+            return output1.unionByName(df2, True)
+
+        output = fit_transform_and_union(df1, df2)
+        self.assertEqual(output.count(), 13)
+
 
 class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 19a1e9903a525..0d49a96320f4e 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -284,7 +284,7 @@ def wrapped(self: "JavaWrapper") -> Any:
         if in_remote:
             # Delete the model if possible
             model_id = self._java_obj
-            del_remote_cache(model_id)
+            del_remote_cache(cast(str, model_id))
             return
         else:
             return f(self)