Skip to content

Commit 78d5229

Browse files
gaogaotiantianzhengruifeng
authored andcommitted
[SPARK-54958][PYTHON][ML][TEST] Accelerate test_cv_io_pipeline.py
### What changes were proposed in this pull request? Parallelized some operations in the test. ### Why are the changes needed? The test takes ~80s which is the second longest test case we have. Having some parallelism will reduce about 50% of the time it takes. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Test passed locally. Time consumed 64s -> 28s. ### Was this patch authored or co-authored using generative AI tooling? No Closes #53725 from gaogaotiantian/optimize-cv-io-pipeline. Authored-by: Tian Gao <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent fe9ffd5 commit 78d5229

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

python/pyspark/ml/tests/tuning/test_cv_io_pipeline.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import tempfile
19+
from concurrent.futures import ThreadPoolExecutor
1920

2021
from pyspark.ml.feature import HashingTF, Tokenizer
2122
from pyspark.ml import Pipeline
@@ -54,7 +55,7 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
5455
tokenizer = Tokenizer(inputCol="text", outputCol="words")
5556
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
5657

57-
ova = OneVsRest(classifier=LogisticRegressionCls())
58+
ova = OneVsRest(classifier=LogisticRegressionCls(), parallelism=2)
5859
lr1 = LogisticRegressionCls().setMaxIter(5)
5960
lr2 = LogisticRegressionCls().setMaxIter(10)
6061

@@ -72,6 +73,7 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
7273
estimatorParamMaps=paramGrid,
7374
evaluator=MulticlassClassificationEvaluator(),
7475
numFolds=2,
76+
parallelism=4,
7577
) # use 3+ folds in practice
7678
cvPath = temp_path + "/cv"
7779
crossval.save(cvPath)
@@ -100,6 +102,7 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
100102
estimatorParamMaps=paramGrid,
101103
evaluator=MulticlassClassificationEvaluator(),
102104
numFolds=2,
105+
parallelism=4,
103106
) # use 3+ folds in practice
104107
cv2Path = temp_path + "/cv2"
105108
crossval2.save(cv2Path)
@@ -126,8 +129,13 @@ def _run_test_save_load_pipeline_estimator(self, LogisticRegressionCls):
126129
self.assertEqual(loadedStage.uid, originalStage.uid)
127130

128131
def test_save_load_pipeline_estimator(self):
129-
self._run_test_save_load_pipeline_estimator(LogisticRegression)
130-
self._run_test_save_load_pipeline_estimator(DummyLogisticRegression)
132+
with ThreadPoolExecutor(max_workers=2) as executor:
133+
list(
134+
executor.map(
135+
self._run_test_save_load_pipeline_estimator,
136+
[LogisticRegression, DummyLogisticRegression],
137+
)
138+
)
131139

132140

133141
if __name__ == "__main__":

0 commit comments

Comments
 (0)