[SPARK-51473][ML][CONNECT] ML transformed dataframe keep a reference to the model #50199

Closed · wants to merge 7 commits
73 changes: 53 additions & 20 deletions python/pyspark/ml/classification.py
@@ -889,7 +889,10 @@ def summary(self) -> "LinearSVCTrainingSummary": # type: ignore[override]
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
s = LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
Comment from the PR author: the training summary holds a reference to the model.
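Below is a minimal, runnable sketch of the scenario this reference guards against, mirroring the `test_model_training_summary_gc` test added in this PR (the Connect URL `sc://localhost` is a placeholder assumption):

```python
from pyspark.sql import Row, SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression

# Placeholder Connect endpoint; the keep-alive only matters in remote mode.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame(
    [
        Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
        Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
        Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
    ]
)

def fit_predictions(df):
    model = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight").fit(df)
    # `model` is the only client-side handle on the remote model. Because the
    # summary now carries __source_transformer__ = model, the server-side
    # model is not released when this function returns.
    return model.summary.predictions

predictions = fit_predictions(df)
print(predictions.count())  # 3; would fail in remote mode without the pinned reference
```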

return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -909,7 +912,10 @@ def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_lsvc_summary = self._call_java("evaluate", dataset)
return LinearSVCSummary(java_lsvc_summary)
s = LinearSVCSummary(java_lsvc_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
Comment from the PR author: the testing summary returned by model.evaluate(df) holds a reference to the model.
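The same idea applies to the test-time summary; a short sketch, reusing `spark` and `df` from the snippet above:

```python
from pyspark.ml.classification import LogisticRegression

def evaluate_predictions(df):
    model = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight").fit(df)
    # The summary returned by evaluate() now pins the model through
    # __source_transformer__, so its predictions DataFrame stays usable
    # after `model` goes out of scope here.
    return model.evaluate(df).predictions

print(evaluate_predictions(df).count())  # 3
```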

return s


class LinearSVCSummary(_BinaryClassificationSummary):
@@ -1578,14 +1584,16 @@ def summary(self) -> "LogisticRegressionTrainingSummary":
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
s: LogisticRegressionTrainingSummary
if self.numClasses <= 2:
return BinaryLogisticRegressionTrainingSummary(
s = BinaryLogisticRegressionTrainingSummary(
super(LogisticRegressionModel, self).summary
)
else:
return LogisticRegressionTrainingSummary(
super(LogisticRegressionModel, self).summary
)
s = LogisticRegressionTrainingSummary(super(LogisticRegressionModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -1605,10 +1613,14 @@ def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_blr_summary = self._call_java("evaluate", dataset)
s: LogisticRegressionSummary
if self.numClasses <= 2:
return BinaryLogisticRegressionSummary(java_blr_summary)
s = BinaryLogisticRegressionSummary(java_blr_summary)
else:
return LogisticRegressionSummary(java_blr_summary)
s = LogisticRegressionSummary(java_blr_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s


class LogisticRegressionSummary(_ClassificationSummary):
@@ -2304,22 +2316,24 @@ def summary(self) -> "RandomForestClassificationTrainingSummary":
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
s: RandomForestClassificationTrainingSummary
if self.numClasses <= 2:
return BinaryRandomForestClassificationTrainingSummary(
s = BinaryRandomForestClassificationTrainingSummary(
super(RandomForestClassificationModel, self).summary
)
else:
return RandomForestClassificationTrainingSummary(
s = RandomForestClassificationTrainingSummary(
super(RandomForestClassificationModel, self).summary
)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
)

def evaluate(
self, dataset: DataFrame
) -> Union["BinaryRandomForestClassificationSummary", "RandomForestClassificationSummary"]:
def evaluate(self, dataset: DataFrame) -> "RandomForestClassificationSummary":
"""
Evaluates the model on a test dataset.

@@ -2333,10 +2347,14 @@ def evaluate(
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_rf_summary = self._call_java("evaluate", dataset)
s: RandomForestClassificationSummary
if self.numClasses <= 2:
return BinaryRandomForestClassificationSummary(java_rf_summary)
s = BinaryRandomForestClassificationSummary(java_rf_summary)
else:
return RandomForestClassificationSummary(java_rf_summary)
s = RandomForestClassificationSummary(java_rf_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s


class RandomForestClassificationSummary(_ClassificationSummary):
@@ -2363,7 +2381,10 @@ class RandomForestClassificationTrainingSummary(


@inherit_doc
class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
class BinaryRandomForestClassificationSummary(
_BinaryClassificationSummary,
RandomForestClassificationSummary,
):
"""
BinaryRandomForestClassification results for a given model.

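(Note: making BinaryRandomForestClassificationSummary extend RandomForestClassificationSummary appears to be what lets evaluate() above return the single base type instead of a Union.)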
@@ -3341,9 +3362,12 @@ def summary( # type: ignore[override]
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
return MultilayerPerceptronClassificationTrainingSummary(
s = MultilayerPerceptronClassificationTrainingSummary(
super(MultilayerPerceptronClassificationModel, self).summary
)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -3363,7 +3387,10 @@ def evaluate(self, dataset: DataFrame) -> "MultilayerPerceptronClassificationSum
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_mlp_summary = self._call_java("evaluate", dataset)
return MultilayerPerceptronClassificationSummary(java_mlp_summary)
s = MultilayerPerceptronClassificationSummary(java_mlp_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s


class MultilayerPerceptronClassificationSummary(_ClassificationSummary):
@@ -4290,7 +4317,10 @@ def summary(self) -> "FMClassificationTrainingSummary":
trained on the training set. An exception is thrown if `trainingSummary is None`.
"""
if self.hasSummary:
return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
s = FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -4310,7 +4340,10 @@ def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_fm_summary = self._call_java("evaluate", dataset)
return FMClassificationSummary(java_fm_summary)
s = FMClassificationSummary(java_fm_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s


class FMClassificationSummary(_BinaryClassificationSummary):
15 changes: 12 additions & 3 deletions python/pyspark/ml/clustering.py
@@ -263,7 +263,10 @@ def summary(self) -> "GaussianMixtureSummary":
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
s = GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -710,7 +713,10 @@ def summary(self) -> KMeansSummary:
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return KMeansSummary(super(KMeansModel, self).summary)
s = KMeansSummary(super(KMeansModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -1057,7 +1063,10 @@ def summary(self) -> "BisectingKMeansSummary":
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
s = BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
20 changes: 16 additions & 4 deletions python/pyspark/ml/regression.py
@@ -487,7 +487,10 @@ def summary(self) -> "LinearRegressionTrainingSummary":
`trainingSummary is None`.
"""
if self.hasSummary:
return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
s = LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -508,7 +511,10 @@ def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_lr_summary = self._call_java("evaluate", dataset)
return LinearRegressionSummary(java_lr_summary)
s = LinearRegressionSummary(java_lr_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s


class LinearRegressionSummary(JavaWrapper):
@@ -2766,9 +2772,12 @@ def summary(self) -> "GeneralizedLinearRegressionTrainingSummary":
`trainingSummary is None`.
"""
if self.hasSummary:
return GeneralizedLinearRegressionTrainingSummary(
s = GeneralizedLinearRegressionTrainingSummary(
super(GeneralizedLinearRegressionModel, self).summary
)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s
else:
raise RuntimeError(
"No training summary available for this %s" % self.__class__.__name__
@@ -2789,7 +2798,10 @@ def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary":
if not isinstance(dataset, DataFrame):
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
java_glr_summary = self._call_java("evaluate", dataset)
return GeneralizedLinearRegressionSummary(java_glr_summary)
s = GeneralizedLinearRegressionSummary(java_glr_summary)
if is_remote():
s.__source_transformer__ = self # type: ignore[attr-defined]
return s


class GeneralizedLinearRegressionSummary(JavaWrapper):
107 changes: 103 additions & 4 deletions python/pyspark/ml/tests/test_pipeline.py
@@ -29,7 +29,7 @@
)
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.ml.clustering import KMeans, KMeansModel
from pyspark.ml.clustering import KMeans, KMeansModel, GaussianMixture
from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
from pyspark.testing.sqlutils import ReusedSQLTestCase

@@ -176,7 +176,7 @@ def test_clustering_pipeline(self):

def test_model_gc(self):
spark = self.spark
df = spark.createDataFrame(
df1 = spark.createDataFrame(
[
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
@@ -189,8 +189,107 @@ def fit_transform(df):
model = lr.fit(df)
return model.transform(df)

output = fit_transform(df)
self.assertEqual(output.count(), 3)
output1 = fit_transform(df1)
self.assertEqual(output1.count(), 3)

df2 = spark.range(10)

def fit_transform_and_union(df1, df2):
output1 = fit_transform(df1)
return output1.unionByName(df2, True)

output2 = fit_transform_and_union(df1, df2)
self.assertEqual(output2.count(), 13)

def test_model_training_summary_gc(self):
spark = self.spark
df1 = spark.createDataFrame(
[
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
]
)

def fit_predictions(df):
lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
model = lr.fit(df)
return model.summary.predictions

output1 = fit_predictions(df1)
self.assertEqual(output1.count(), 3)

df2 = spark.range(10)

def fit_predictions_and_union(df1, df2):
output1 = fit_predictions(df1)
return output1.unionByName(df2, True)

output2 = fit_predictions_and_union(df1, df2)
self.assertEqual(output2.count(), 13)

def test_model_testing_summary_gc(self):
spark = self.spark
df1 = spark.createDataFrame(
[
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
]
)

def fit_predictions(df):
lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
model = lr.fit(df)
return model.evaluate(df).predictions

output1 = fit_predictions(df1)
self.assertEqual(output1.count(), 3)

df2 = spark.range(10)

def fit_predictions_and_union(df1, df2):
output1 = fit_predictions(df1)
return output1.unionByName(df2, True)

output2 = fit_predictions_and_union(df1, df2)
self.assertEqual(output2.count(), 13)

def test_model_attr_df_gc(self):
spark = self.spark
df1 = (
spark.createDataFrame(
[
(1, 1.0, Vectors.dense([-0.1, -0.05])),
(2, 2.0, Vectors.dense([-0.01, -0.1])),
(3, 3.0, Vectors.dense([0.9, 0.8])),
(4, 1.0, Vectors.dense([0.75, 0.935])),
(5, 1.0, Vectors.dense([-0.83, -0.68])),
(6, 1.0, Vectors.dense([-0.91, -0.76])),
],
["index", "weight", "features"],
)
.coalesce(1)
.sortWithinPartitions("index")
.select("weight", "features")
)

def fit_attr_df(df):
gmm = GaussianMixture(k=2, maxIter=2, weightCol="weight", seed=1)
model = gmm.fit(df)
return model.gaussiansDF

output1 = fit_attr_df(df1)
self.assertEqual(output1.count(), 2)

df2 = spark.range(10)

def fit_attr_df_and_union(df1, df2):
output1 = fit_attr_df(df1)
return output1.unionByName(df2, True)

output2 = fit_attr_df_and_union(df1, df2)
self.assertEqual(output2.count(), 12)


class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):