Skip to content

Commit

Permalink
first pass at fixing empty partition failures
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Mar 25, 2022
1 parent 1e811ce commit 28b97e0
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 31 deletions.
117 changes: 88 additions & 29 deletions dask_ml/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,23 +231,22 @@ def transform(self, X):
"""
self._check_method("transform")
X = self._check_array(X)
meta = self.transform_meta
output_meta = self.transform_meta

if isinstance(X, da.Array):
if meta is None:
meta = _get_output_dask_ar_meta_for_estimator(
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_transform, self._postfit_estimator, X
)
return X.map_blocks(
_transform, estimator=self._postfit_estimator, meta=meta
_transform, estimator=self._postfit_estimator, meta=output_meta
)
elif isinstance(X, dd._Frame):
if meta is None:
# dask-dataframe relies on dd.core.no_default
# for infering meta
meta = dd.core.no_default
return X.map_partitions(
_transform, estimator=self._postfit_estimator, meta=meta
return _get_output_df_for_estimator(
model_fn=_transform,
X=X,
output_meta=output_meta,
estimator=self._postfit_estimator,
)
else:
return _transform(X, estimator=self._postfit_estimator)
Expand Down Expand Up @@ -311,25 +310,30 @@ def predict(self, X):
"""
self._check_method("predict")
X = self._check_array(X)
meta = self.predict_meta
output_meta = self.predict_meta

if isinstance(X, da.Array):
if meta is None:
meta = _get_output_dask_ar_meta_for_estimator(
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_predict, self._postfit_estimator, X
)

result = X.map_blocks(
_predict, estimator=self._postfit_estimator, drop_axis=1, meta=meta
_predict,
estimator=self._postfit_estimator,
drop_axis=1,
meta=output_meta,
)
return result

elif isinstance(X, dd._Frame):
if meta is None:
meta = dd.core.no_default
return X.map_partitions(
_predict, estimator=self._postfit_estimator, meta=meta
return _get_output_df_for_estimator(
model_fn=_predict,
X=X,
output_meta=output_meta,
estimator=self._postfit_estimator,
)

else:
return _predict(X, estimator=self._postfit_estimator)

Expand All @@ -355,25 +359,26 @@ def predict_proba(self, X):

self._check_method("predict_proba")

meta = self.predict_proba_meta
output_meta = self.predict_proba_meta

if isinstance(X, da.Array):
if meta is None:
meta = _get_output_dask_ar_meta_for_estimator(
if output_meta is None:
output_meta = _get_output_dask_ar_meta_for_estimator(
_predict_proba, self._postfit_estimator, X
)
# XXX: multiclass
return X.map_blocks(
_predict_proba,
estimator=self._postfit_estimator,
meta=meta,
meta=output_meta,
chunks=(X.chunks[0], len(self._postfit_estimator.classes_)),
)
elif isinstance(X, dd._Frame):
if meta is None:
meta = dd.core.no_default
return X.map_partitions(
_predict_proba, estimator=self._postfit_estimator, meta=meta
return _get_output_df_for_estimator(
model_fn=_predict_proba,
X=X,
output_meta=output_meta,
estimator=self._postfit_estimator,
)
else:
return _predict_proba(X, estimator=self._postfit_estimator)
Expand Down Expand Up @@ -626,18 +631,63 @@ def _first_block(dask_object):
return dask_object


def _predict(part, estimator):
def _predict(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
if empty_output is not None:
return empty_output
return estimator.predict(part)


def _predict_proba(part, estimator):
def _predict_proba(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
if empty_output is not None:
return empty_output

return estimator.predict_proba(part)


def _transform(part, estimator):
def _transform(part, estimator, output_meta=None):
if part.shape[0] == 0 and output_meta is not None:
empty_output = handle_empty_partitions(output_meta)
if empty_output is not None:
return empty_output

return estimator.transform(part)


def handle_empty_partitions(output_meta):
if hasattr(output_meta, "__array_function__"):
if len(output_meta.shape) == 1:
shape = 0
else:
shape = list(output_meta.shape)
shape[0] = 0
ar = np.zeros(
shape=shape,
dtype=output_meta.dtype,
like=output_meta,
)
return ar
elif "scipy.sparse" in type(output_meta).__module__:
# sparse matrices dont support
# `like` due to non implimented __array_function__
# Refer https://github.com/scipy/scipy/issues/10362
# Note below works for both cupy and scipy sparse matrices
# TODO: REMOVE code duplication
if len(ar.shape) == 1:
shape = 0
else:
shape = list(ar.shape)
shape[0] = 0

ar = type(output_meta)(shape, dtype=output_meta.dtype)
return ar
elif hasattr(output_meta, "iloc"):
return output_meta.iloc[:0, :]


def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
"""
Returns the output metadata array
Expand Down Expand Up @@ -692,3 +742,12 @@ def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
warnings.warn(msg)
ar = np.zeros(shape=(1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype)
return model_fn(ar, estimator)


def _get_output_df_for_estimator(model_fn, X, output_meta, estimator):
if output_meta is None:
# dask-dataframe relies on dd.core.no_default
# for infering meta
output_meta = model_fn(X._meta_nonempty, estimator)

return X.map_partitions(model_fn, estimator, output_meta, meta=output_meta)
21 changes: 19 additions & 2 deletions tests/test_parallel_post_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def test_predict_meta_override():
# Failure when not proving predict_meta
# because of value dependent model
wrap = ParallelPostFit(base)
with pytest.raises(ValueError):
# TODO: Fix
with pytest.raises(IndexError):
wrap.predict(dd_X)

# Success when providing meta over-ride
Expand All @@ -89,7 +90,8 @@ def test_predict_proba_meta_override():
# Failure when not proving predict_proba_meta
# because of value dependent model
wrap = ParallelPostFit(base)
with pytest.raises(ValueError):
# TODO: Fix below
with pytest.raises(IndexError):
wrap.predict_proba(dd_X)

# Success when providing meta over-ride
Expand Down Expand Up @@ -289,3 +291,18 @@ def shape(self):
match="provide explicit `predict_proba_meta` to the dask_ml.wrapper",
):
clf.predict_proba(fake_dask_ar)


def test_predict_empty_partitions():
df = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8], "y": [True, False] * 4})
ddf = dd.from_pandas(df, npartitions=4)

clf = ParallelPostFit(LogisticRegression())
clf = clf.fit(df[["x"]], df["y"])

ddf_with_empty_part = ddf[ddf.x < 5][["x"]]
result = clf.predict(ddf_with_empty_part).compute()

expected = clf.estimator.predict(ddf_with_empty_part.compute())

assert_eq_ar(result, expected)

0 comments on commit 28b97e0

Please sign in to comment.