Skip to content

Commit

Permalink
propagate
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Mar 7, 2025
1 parent c197a44 commit dd3b334
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/pyspark/ml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,9 @@ def transform(self, dataset: DataFrame, params: Optional["ParamMap"] = None) ->
#
# output = fit_transform(df)
if is_remote():
transformed.__source_transformer__ = self
# attach the source transformer to the internal plan,
# so that all descendant plans also keep it.
transformed._plan.__source_transformer__ = self # type: ignore[attr-defined]

return transformed
else:
Expand Down
24 changes: 24 additions & 0 deletions python/pyspark/ml/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,30 @@ def fit_transform(df):
output = fit_transform(df)
self.assertEqual(output.count(), 3)

def test_model_gc_II(self):
# Fit a model inside a nested helper, union the transformed output with an
# unrelated DataFrame, and check that the combined result can still be counted.
# NOTE(review): judging by the test name, this presumably guards against the
# fitted model being garbage-collected while a DataFrame derived from its
# output is still in use -- confirm against the model cleanup logic.
spark = self.spark
# Three labelled, weighted training rows with 2-dimensional dense features.
df1 = spark.createDataFrame(
[
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
]
)

# Ten-row DataFrame that is not produced by the model.
df2 = spark.range(10)

# `lr` and `model` are locals of this helper, so once it returns the caller
# holds only the transformed DataFrame and no direct reference to the model.
def fit_transform(df):
lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
model = lr.fit(df)
return model.transform(df)

# Derives a new DataFrame from the transformed output; only the union is
# returned, so `output1` itself is also dropped when this helper returns.
def fit_transform_and_union(df1, df2):
output1 = fit_transform(df1)
# presumably the second positional argument is allowMissingColumns, needed
# because df2 has different columns from output1 -- TODO confirm
return output1.unionByName(df2, True)

output = fit_transform_and_union(df1, df2)
# 3 rows from df1 plus 10 rows from df2.
self.assertEqual(output.count(), 13)


class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):
pass
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def wrapped(self: "JavaWrapper") -> Any:
if in_remote:
# Delete the model if possible
model_id = self._java_obj
del_remote_cache(model_id)
del_remote_cache(cast(str, model_id))
return
else:
return f(self)
Expand Down

0 comments on commit dd3b334

Please sign in to comment.