Implement check_consistent_length #900

Open
wants to merge 2 commits into base: main
Changes from all commits
33 changes: 26 additions & 7 deletions dask_ml/utils.py
@@ -241,8 +241,8 @@ def check_random_state(random_state):
raise TypeError("Unexpected type '{}'".format(type(random_state)))


def check_matching_blocks(*arrays):
"""Check that the partitioning structure for many arrays matches.
def _check_matching_blocks(*arrays, check_first_dim_only=False):

Author comment: I thought this was the best way to re-use code between check_consistent_length and check_matching_blocks without changing the API of either function. (A usage sketch follows this file's diff.)

"""Helper function to check blocks match across *arrays.

Parameters
----------
@@ -252,18 +252,22 @@ def check_matching_blocks(*arrays):
* Dask Array
* Dask DataFrame
* Dask Series
check_first_dim_only: bool, default False
Whether to check only the chunks along the first dimension. Only applies
if all the arrays are Dask Arrays.
"""
if len(arrays) <= 1:
return
slice_to_check = slice(0, 1, 1) if check_first_dim_only else slice(None, None)
if all(isinstance(x, da.Array) for x in arrays):
# TODO: unknown chunks, ensure blocks match, or just raise (configurable)
chunks = arrays[0].chunks
chunks = arrays[0].chunks[slice_to_check]
for array in arrays[1:]:
if array.chunks != chunks:
if array.chunks[slice_to_check] != chunks:
raise ValueError(
"Mismatched chunks. {} != {}".format(chunks, array.chunks)
)

# Divisions correspond to the index (first dimension), so there is no need to use slice_to_check
elif all(isinstance(x, (dd.Series, dd.DataFrame)) for x in arrays):
divisions = arrays[0].divisions
for array in arrays[1:]:
@@ -275,6 +279,21 @@
raise ValueError("Unexpected types {}.".format({type(x) for x in arrays}))


def check_matching_blocks(*arrays):
"""Check that the partitioning structure for many arrays matches.

Parameters
----------
*arrays : Sequence of array-likes
This includes

* Dask Array
* Dask DataFrame
* Dask Series
"""
_check_matching_blocks(*arrays, check_first_dim_only=False)


def check_X_y(
X,
y,
@@ -433,8 +452,8 @@ def _check_y(y, multi_output=False, y_numeric=False):


def check_consistent_length(*arrays):
# TODO: check divisions, chunks, etc.
pass
"""Check that blocks match for arrays and divisions match for dataframes."""
_check_matching_blocks(*arrays, check_first_dim_only=True)


def check_chunks(n_samples, n_features, chunks=None):
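
The review comment above describes the approach taken in this diff: both public helpers delegate to the private _check_matching_blocks, and check_consistent_length passes check_first_dim_only=True so that only the chunks along the first dimension (the row counts) are compared. Below is a minimal usage sketch, not part of the diff, assuming this branch of dask-ml is installed.

# Usage sketch (not part of the diff) of the intended behaviour after this change.
import dask.array as da

from dask_ml.utils import check_consistent_length, check_matching_blocks

# Same chunking along the first dimension (rows); X has a second dimension that y lacks.
X = da.random.uniform(size=(100, 10), chunks=(50, 10))
y = da.random.uniform(size=100, chunks=50)

# Passes: only the first-dimension chunks, (50, 50), are compared.
check_consistent_length(X, y)

# Raises ValueError: the full chunk structures differ.
try:
    check_matching_blocks(X, y)
except ValueError as err:
    print(err)  # Mismatched chunks. ((50, 50), (10,)) != ((50, 50),)

Comparing only the first dimension is analogous to scikit-learn's check_consistent_length, which looks only at sample counts, while check_matching_blocks keeps its stricter full-structure check.
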
60 changes: 60 additions & 0 deletions tests/test_utils.py
@@ -15,6 +15,7 @@
assert_estimator_equal,
check_array,
check_chunks,
check_consistent_length,
check_matching_blocks,
check_random_state,
handle_zeros_in_scale,
@@ -234,3 +235,62 @@ def test_matching_blocks_ok(arrays):
def test_matching_blocks_raises(arrays):
with pytest.raises(ValueError):
check_matching_blocks(*arrays)


@pytest.mark.parametrize(
"arrays",
[
(
da.random.uniform(size=(10, 10), chunks=(10, 10)),
da.random.uniform(size=10, chunks=10),
),
(
da.random.uniform(size=(50, 10), chunks=(50, 10)),
da.random.uniform(size=50, chunks=50),
),
(
dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), 2)
.reset_index()
.to_dask_array(),
dd.from_pandas(pd.Series([1, 2, 3]), 2).reset_index().to_dask_array(),
),
(
dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), 2),
dd.from_pandas(pd.Series([1, 2, 3]), 2),
),
# Allow known and unknown?
pytest.param(
(
dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), 2)
.reset_index()
.to_dask_array(),
dd.from_pandas(pd.Series([1, 2, 3]), 2).reset_index(),
),
marks=pytest.mark.xfail(reason="Known and unknown blocks."),
),
],
)
def test_check_consistent_length_ok(arrays):
check_consistent_length(*arrays)


@pytest.mark.parametrize(
"arrays",
[
(
da.random.uniform(size=(10, 10), chunks=(10, 10)),
da.random.uniform(size=8, chunks=8),
),
(
da.random.uniform(size=(100, 10), chunks=(100, 10)),
da.random.uniform(size=50, chunks=50),
),
(
dd.from_pandas(pd.DataFrame({"a": [1, 2, 3, 4]}), 4),
dd.from_pandas(pd.Series([1, 2, 3]), 2),
),
],
)
def test_check_consistent_length_raises(arrays):
with pytest.raises(ValueError):
check_consistent_length(*arrays)
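
The xfail case above pairs a Dask Array whose chunk sizes are unknown with a Dask DataFrame whose divisions are known. The sketch below, not part of the diff, shows where the unknown sizes come from in plain dask.dataframe usage and how lengths=True makes them known.

# Sketch (not part of the diff): reset_index().to_dask_array() without lengths
# produces unknown (NaN) row chunk sizes.
import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=2)

arr = ddf.reset_index().to_dask_array()
print(arr.chunks)  # ((nan, nan), (2,)) -- row chunk sizes are unknown

# Passing lengths=True computes the partition sizes, giving known chunks.
arr_known = ddf.reset_index().to_dask_array(lengths=True)
print(arr_known.chunks)  # ((2, 1), (2,))

Whether comparisons involving unknown chunk sizes should be allowed is the open question flagged by the "# Allow known and unknown?" comment in the test above.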