Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#

import tempfile
from concurrent.futures import ThreadPoolExecutor

from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml import Pipeline
Expand Down Expand Up @@ -54,7 +55,7 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")

ova = OneVsRest(classifier=LogisticRegressionCls())
ova = OneVsRest(classifier=LogisticRegressionCls(), parallelism=2)
lr1 = LogisticRegressionCls().setMaxIter(5)
lr2 = LogisticRegressionCls().setMaxIter(10)

Expand All @@ -72,6 +73,7 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator(),
numFolds=2,
parallelism=4,
) # use 3+ folds in practice
cvPath = temp_path + "/cv"
crossval.save(cvPath)
Expand Down Expand Up @@ -100,6 +102,7 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
estimatorParamMaps=paramGrid,
evaluator=MulticlassClassificationEvaluator(),
numFolds=2,
parallelism=4,
) # use 3+ folds in practice
cv2Path = temp_path + "/cv2"
crossval2.save(cv2Path)
Expand All @@ -126,8 +129,13 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
self.assertEqual(loadedStage.uid, originalStage.uid)

def test_save_load_pipeline_estimator(self):
self._run_test_save_load_pipeline_estimator(LogisticRegression)
self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)
with ThreadPoolExecutor(max_workers=2) as executor:
list(
executor.map(
self._run_test_save_load_pipeline_estimator,
[LogisticRegression, DummyLogisticRegression],
)
)


if __name__ == "__main__":
Expand Down