From 6b3ad5123525aa10551a7d41f407551bc56f8555 Mon Sep 17 00:00:00 2001 From: "Perciaccante, Giovambattista" Date: Wed, 28 Oct 2020 19:14:34 +0100 Subject: [PATCH 1/8] fixed default test_size which didn't work with pandas dataframe and changed the async def _fit function to allow Hyperband to work with non dask arrays --- dask_ml/model_selection/_incremental.py | 57 ++++++++++++++----------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/dask_ml/model_selection/_incremental.py b/dask_ml/model_selection/_incremental.py index 4697cbec6..d5a0f9252 100644 --- a/dask_ml/model_selection/_incremental.py +++ b/dask_ml/model_selection/_incremental.py @@ -198,34 +198,40 @@ async def _fit( else: y_test = await client.scatter(y_test) - # Convert to batches of delayed objects of numpy arrays - X_train = sorted(futures_of(X_train), key=lambda f: f.key) - y_train = sorted(futures_of(y_train), key=lambda f: f.key) - assert len(X_train) == len(y_train) - - train_eg = await client.gather(client.map(len, y_train)) - msg = "[CV%s] For training there are between %d and %d examples in each chunk" - logger.info(msg, prefix, min(train_eg), max(train_eg)) + if hasattr(X_train, 'npartitions'): + # Convert to batches of delayed objects of numpy arrays + X_train = sorted(futures_of(X_train), key=lambda f: f.key) + y_train = sorted(futures_of(y_train), key=lambda f: f.key) + assert len(X_train) == len(y_train) + + train_eg = await client.gather(client.map(len, y_train)) + msg = "[CV%s] For training there are between %d and %d examples in each chunk" + logger.info(msg, prefix, min(train_eg), max(train_eg)) + + def get_futures(partial_fit_calls): + """Policy to get training data futures + + Currently we compute once, and then keep in memory. + Presumably in the future we'll want to let data drop and recompute. + This function handles that policy internally, and also controls random + access to training data. 
+ """ + # Shuffle blocks going forward to get uniform-but-random access + while partial_fit_calls >= len(order): + L = list(range(len(X_train))) + rng.shuffle(L) + order.extend(L) + j = order[partial_fit_calls] + return X_train[j], y_train[j] + # __addition__ start + else: + def get_futures(partial_fit_calls): + return X_train, y_train + # __addition__ end # Order by which we process training data futures order = [] - def get_futures(partial_fit_calls): - """Policy to get training data futures - - Currently we compute once, and then keep in memory. - Presumably in the future we'll want to let data drop and recompute. - This function handles that policy internally, and also controls random - access to training data. - """ - # Shuffle blocks going forward to get uniform-but-random access - while partial_fit_calls >= len(order): - L = list(range(len(X_train))) - rng.shuffle(L) - order.extend(L) - j = order[partial_fit_calls] - return X_train[j], y_train[j] - # Submit initial partial_fit and score computations on first batch of data X_future, y_future = get_futures(0) X_future_2, y_future_2 = get_futures(1) @@ -566,7 +572,8 @@ def _get_train_test_split(self, X, y, **kwargs): X, y : dask.array.Array """ if self.test_size is None: - test_size = min(0.2, 1 / X.npartitions) + npartitions = getattr(X, 'npartitions', 1) + test_size = min(0.2, 1 / npartitions) else: test_size = self.test_size X_train, X_test, y_train, y_test = train_test_split( From 101e6a9272602d42ec63f1e792f23ec51ac204db Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Fri, 20 Nov 2020 14:12:47 +0000 Subject: [PATCH 2/8] - change in get future function - change in checking array using dask.is_dask_collection - added test in tests.model_selection.test_hyperband_non_daskarray.py --- dask_ml/model_selection/_incremental.py | 61 +++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/dask_ml/model_selection/_incremental.py b/dask_ml/model_selection/_incremental.py index 
d5a0f9252..4dbbe0dd2 100644 --- a/dask_ml/model_selection/_incremental.py +++ b/dask_ml/model_selection/_incremental.py @@ -198,36 +198,39 @@ async def _fit( else: y_test = await client.scatter(y_test) - if hasattr(X_train, 'npartitions'): - # Convert to batches of delayed objects of numpy arrays - X_train = sorted(futures_of(X_train), key=lambda f: f.key) - y_train = sorted(futures_of(y_train), key=lambda f: f.key) - assert len(X_train) == len(y_train) - - train_eg = await client.gather(client.map(len, y_train)) - msg = "[CV%s] For training there are between %d and %d examples in each chunk" - logger.info(msg, prefix, min(train_eg), max(train_eg)) - - def get_futures(partial_fit_calls): - """Policy to get training data futures - - Currently we compute once, and then keep in memory. - Presumably in the future we'll want to let data drop and recompute. - This function handles that policy internally, and also controls random - access to training data. - """ - # Shuffle blocks going forward to get uniform-but-random access - while partial_fit_calls >= len(order): - L = list(range(len(X_train))) - rng.shuffle(L) - order.extend(L) - j = order[partial_fit_calls] - return X_train[j], y_train[j] - # __addition__ start - else: - def get_futures(partial_fit_calls): + # Convert to batches of delayed objects of numpy arrays + X_train = sorted(futures_of(X_train), key=lambda f: f.key) + y_train = sorted(futures_of(y_train), key=lambda f: f.key) + assert len(X_train) == len(y_train) + + train_eg = await client.gather(client.map(len, y_train)) + + ### start addition ### + min_samples = min(train_eg) if len(train_eg) else len(y_train) + max_samples = max(train_eg) if len(train_eg) else len(y_train) + + msg = "[CV%s] For training there are between %d and %d examples in each chunk" + logger.info(msg, prefix, min_samples, max_samples) + + def get_futures(partial_fit_calls): + """Policy to get training data futures + + Currently we compute once, and then keep in memory. 
+ Presumably in the future we'll want to let data drop and recompute. + This function handles that policy internally, and also controls random + access to training data. + """ + if dask.is_dask_collection(y_train): return X_train, y_train - # __addition__ end + + # Shuffle blocks going forward to get uniform-but-random access + while partial_fit_calls >= len(order): + L = list(range(len(X_train))) + rng.shuffle(L) + order.extend(L) + j = order[partial_fit_calls] + return X_train[j], y_train[j] + ### end addition ### # Order by which we process training data futures order = [] From fe154c63d54550c570ae03760dc7cecaafcc2e81 Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Fri, 20 Nov 2020 14:45:43 +0000 Subject: [PATCH 3/8] fix in test_hyperband_non_daskarray --- .../test_hyperband_non_daskarray.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/model_selection/test_hyperband_non_daskarray.py diff --git a/tests/model_selection/test_hyperband_non_daskarray.py b/tests/model_selection/test_hyperband_non_daskarray.py new file mode 100644 index 000000000..63e50289e --- /dev/null +++ b/tests/model_selection/test_hyperband_non_daskarray.py @@ -0,0 +1,22 @@ +import numpy as np +import pandas as pd + +from dask_ml.model_selection import HyperbandSearchCV +from dask_ml.datasets import make_classification +from distributed.utils_test import gen_cluster +from sklearn.linear_model import SGDClassifier + + +@gen_cluster(client=True) +def test_pandas(): + X, y = make_classification(chunks=100) + X, y = pd.DataFrame(X.compute()), pd.Series(y.compute()) + + est = SGDClassifier(tol=1e-3) + param_dist = {'alpha': np.logspace(-4, 0, num=1000), + 'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge'], + 'average': [True, False]} + + search = HyperbandSearchCV(est, param_dist) + search.fit(X, y, classes=y.unique()) + assert search.best_params_ From a73e6ec467eec1539626df81d01c2b143d3934e3 Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Fri, 20 Nov 2020 
15:09:47 +0000 Subject: [PATCH 4/8] fix in test_hyperband_non_daskarray --- tests/model_selection/test_hyperband_non_daskarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/model_selection/test_hyperband_non_daskarray.py b/tests/model_selection/test_hyperband_non_daskarray.py index 63e50289e..7c1a5b260 100644 --- a/tests/model_selection/test_hyperband_non_daskarray.py +++ b/tests/model_selection/test_hyperband_non_daskarray.py @@ -8,7 +8,7 @@ @gen_cluster(client=True) -def test_pandas(): +def test_pandas(c, s, a, b): X, y = make_classification(chunks=100) X, y = pd.DataFrame(X.compute()), pd.Series(y.compute()) From cd5645fbbf2b8030184c39c01b663e7c13595e49 Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Sun, 22 Nov 2020 19:24:01 +0000 Subject: [PATCH 5/8] comment removed and length checked on X_train --- dask_ml/model_selection/_incremental.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/dask_ml/model_selection/_incremental.py b/dask_ml/model_selection/_incremental.py index 4dbbe0dd2..4056f35ec 100644 --- a/dask_ml/model_selection/_incremental.py +++ b/dask_ml/model_selection/_incremental.py @@ -203,11 +203,10 @@ async def _fit( y_train = sorted(futures_of(y_train), key=lambda f: f.key) assert len(X_train) == len(y_train) - train_eg = await client.gather(client.map(len, y_train)) + train_eg = await client.gather(client.map(len, X_train)) - ### start addition ### - min_samples = min(train_eg) if len(train_eg) else len(y_train) - max_samples = max(train_eg) if len(train_eg) else len(y_train) + min_samples = min(train_eg) if len(train_eg) else len(X_train) + max_samples = max(train_eg) if len(train_eg) else len(X_train) msg = "[CV%s] For training there are between %d and %d examples in each chunk" logger.info(msg, prefix, min_samples, max_samples) @@ -220,7 +219,7 @@ def get_futures(partial_fit_calls): This function handles that policy internally, and also controls random access to training data. 
""" - if dask.is_dask_collection(y_train): + if dask.is_dask_collection(X_train): return X_train, y_train # Shuffle blocks going forward to get uniform-but-random access @@ -230,7 +229,6 @@ def get_futures(partial_fit_calls): order.extend(L) j = order[partial_fit_calls] return X_train[j], y_train[j] - ### end addition ### # Order by which we process training data futures order = [] From 0017e09cd3cac3f5dcb9962781a13e5977b18eed Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Mon, 23 Nov 2020 00:45:20 +0000 Subject: [PATCH 6/8] moved test_hyperband_non_daskarray.py to test_hyperband.py and changed dask make_classification to sklearn make_classification --- tests/model_selection/test_hyperband.py | 17 ++++++++++++++ .../test_hyperband_non_daskarray.py | 22 ------------------- 2 files changed, 17 insertions(+), 22 deletions(-) delete mode 100644 tests/model_selection/test_hyperband_non_daskarray.py diff --git a/tests/model_selection/test_hyperband.py b/tests/model_selection/test_hyperband.py index 9bcf131aa..e8602064c 100644 --- a/tests/model_selection/test_hyperband.py +++ b/tests/model_selection/test_hyperband.py @@ -14,6 +14,7 @@ loop, ) from sklearn.linear_model import SGDClassifier +from sklearn.datasets import make_classification as sk_make_classification from dask_ml._compat import DISTRIBUTED_2_5_0 from dask_ml.datasets import make_classification @@ -478,3 +479,19 @@ async def test_dataframe_inputs(c, s, a, b): params = {"value": scipy.stats.uniform(0, 1)} alg = HyperbandSearchCV(model, params, max_iter=9, random_state=42) await alg.fit(X, y) + + +@gen_cluster(client=True) +def test_pandas(c, s, a, b): + + X, y = sk_make_classification(chunks=100) + X, y = pd.DataFrame(X), pd.Series(y) + + est = SGDClassifier(tol=1e-3) + param_dist = {'alpha': np.logspace(-4, 0, num=1000), + 'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge'], + 'average': [True, False]} + + search = HyperbandSearchCV(est, param_dist) + search.fit(X, y, classes=y.unique()) + assert 
search.best_params_ \ No newline at end of file diff --git a/tests/model_selection/test_hyperband_non_daskarray.py b/tests/model_selection/test_hyperband_non_daskarray.py deleted file mode 100644 index 7c1a5b260..000000000 --- a/tests/model_selection/test_hyperband_non_daskarray.py +++ /dev/null @@ -1,22 +0,0 @@ -import numpy as np -import pandas as pd - -from dask_ml.model_selection import HyperbandSearchCV -from dask_ml.datasets import make_classification -from distributed.utils_test import gen_cluster -from sklearn.linear_model import SGDClassifier - - -@gen_cluster(client=True) -def test_pandas(c, s, a, b): - X, y = make_classification(chunks=100) - X, y = pd.DataFrame(X.compute()), pd.Series(y.compute()) - - est = SGDClassifier(tol=1e-3) - param_dist = {'alpha': np.logspace(-4, 0, num=1000), - 'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge'], - 'average': [True, False]} - - search = HyperbandSearchCV(est, param_dist) - search.fit(X, y, classes=y.unique()) - assert search.best_params_ From f7c01c7488dc0281a31b7cb49b403b108b1929b8 Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Mon, 23 Nov 2020 00:57:26 +0000 Subject: [PATCH 7/8] fix make_classification --- tests/model_selection/test_hyperband.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/model_selection/test_hyperband.py b/tests/model_selection/test_hyperband.py index e8602064c..d1c257a85 100644 --- a/tests/model_selection/test_hyperband.py +++ b/tests/model_selection/test_hyperband.py @@ -484,7 +484,7 @@ async def test_dataframe_inputs(c, s, a, b): @gen_cluster(client=True) def test_pandas(c, s, a, b): - X, y = sk_make_classification(chunks=100) + X, y = sk_make_classification() X, y = pd.DataFrame(X), pd.Series(y) est = SGDClassifier(tol=1e-3) From ef2956a20e2286d64cbed6c6e5b1115d95d0927c Mon Sep 17 00:00:00 2001 From: gioxc88 Date: Mon, 23 Nov 2020 01:16:00 +0000 Subject: [PATCH 8/8] put yield in test --- tests/model_selection/test_hyperband.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/tests/model_selection/test_hyperband.py b/tests/model_selection/test_hyperband.py index d1c257a85..6e31fddc4 100644 --- a/tests/model_selection/test_hyperband.py +++ b/tests/model_selection/test_hyperband.py @@ -493,5 +493,5 @@ def test_pandas(c, s, a, b): 'average': [True, False]} search = HyperbandSearchCV(est, param_dist) - search.fit(X, y, classes=y.unique()) - assert search.best_params_ \ No newline at end of file + yield search.fit(X, y, classes=y.unique()) + assert search.best_params_