Skip to content

Commit 43a2a4b

Browse files
authored
Allow multiple dims to be passed with min_count (pydata#4356)
* Allow multiple dims to be passed with min_count
* Add whatsnew
1 parent efabe74 commit 43a2a4b

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@ New Features
3030
- Support multiple outputs in :py:func:`xarray.apply_ufunc` when using ``dask='parallelized'``. (:issue:`1815`, :pull:`4060`)
3131
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
3232
- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling`
33-
now accept more than 1 dimension.(:pull:`4219`)
33+
now accept more than 1 dimension. (:pull:`4219`)
3434
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
35+
- ``min_count`` can be supplied to reductions such as ``.sum`` when specifying
36+
multiple dimensions to reduce over. (:pull:`4356`)
37+
By `Maximilian Roos <https://github.com/max-sixty>`_.
3538
- Build ``CFTimeIndex.__repr__`` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new
3639
property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in
3740
``CFTimeIndex.__repr__`` (:issue:`2416`, :pull:`4092`)

xarray/core/nanops.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,9 @@ def _maybe_null_out(result, axis, mask, min_count=1):
2626
"""
2727
xarray version of pandas.core.nanops._maybe_null_out
2828
"""
29-
if hasattr(axis, "__len__"): # if tuple or list
30-
raise ValueError(
31-
"min_count is not available for reduction with more than one dimensions."
32-
)
3329

3430
if axis is not None and getattr(result, "ndim", False):
35-
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
31+
null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
3632
if null_mask.any():
3733
dtype, fill_value = dtypes.maybe_promote(result.dtype)
3834
result = result.astype(dtype)

xarray/tests/test_duck_array_ops.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,24 @@ def test_min_count(dim_num, dtype, dask, func, aggdim):
595595
assert_dask_array(actual, dask)
596596

597597

598+
@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_])
599+
@pytest.mark.parametrize("dask", [False, True])
600+
@pytest.mark.parametrize("func", ["sum", "prod"])
601+
def test_min_count_nd(dtype, dask, func):
602+
if dask and not has_dask:
603+
pytest.skip("requires dask")
604+
605+
min_count = 3
606+
dim_num = 3
607+
da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask)
608+
actual = getattr(da, func)(dim=["x", "y", "z"], skipna=True, min_count=min_count)
609+
# Supplying all dims is equivalent to supplying `...` or `None`
610+
expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count)
611+
612+
assert_allclose(actual, expected)
613+
assert_dask_array(actual, dask)
614+
615+
598616
@pytest.mark.parametrize("func", ["sum", "prod"])
599617
def test_min_count_dataset(func):
600618
da = construct_dataarray(2, dtype=float, contains_nan=True, dask=False)
@@ -606,14 +624,15 @@ def test_min_count_dataset(func):
606624

607625
@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_])
608626
@pytest.mark.parametrize("dask", [False, True])
627+
@pytest.mark.parametrize("skipna", [False, True])
609628
@pytest.mark.parametrize("func", ["sum", "prod"])
610-
def test_multiple_dims(dtype, dask, func):
629+
def test_multiple_dims(dtype, dask, skipna, func):
611630
if dask and not has_dask:
612631
pytest.skip("requires dask")
613632
da = construct_dataarray(3, dtype, contains_nan=True, dask=dask)
614633

615-
actual = getattr(da, func)(("x", "y"))
616-
expected = getattr(getattr(da, func)("x"), func)("y")
634+
actual = getattr(da, func)(("x", "y"), skipna=skipna)
635+
expected = getattr(getattr(da, func)("x", skipna=skipna), func)("y", skipna=skipna)
617636
assert_allclose(actual, expected)
618637

619638

0 commit comments

Comments
 (0)