Skip to content

Commit 4674372

Browse files
authored
Allow use on 1d arrays (#84)
1 parent a1643f9 commit 4674372

File tree

7 files changed

+83
-40
lines changed

7 files changed

+83
-40
lines changed

src/fast_array_utils/_validation.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,19 @@
22
from __future__ import annotations
33

44
import numpy as np
5+
from numpy.exceptions import AxisError
56

67

7-
def validate_axis(axis: int | None) -> None:
8+
def validate_axis(ndim: int, axis: int | None) -> None:
89
if axis is None:
910
return
1011
if not isinstance(axis, int | np.integer): # pragma: no cover
1112
msg = "axis must be integer or None."
1213
raise TypeError(msg)
14+
if axis == 0 and ndim == 1:
15+
raise AxisError(axis, ndim, "use axis=None for 1D arrays")
16+
if axis not in range(ndim):
17+
raise AxisError(axis, ndim)
1318
if axis not in (0, 1): # pragma: no cover
1419
msg = "We only support axis 0 and 1 at the moment"
1520
raise NotImplementedError(msg)

src/fast_array_utils/stats/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,7 @@ def is_constant(
7575
"""
7676
from ._is_constant import is_constant_
7777

78-
validate_axis(axis)
78+
validate_axis(x.ndim, axis)
7979
return is_constant_(x, axis=axis)
8080

8181

@@ -144,7 +144,7 @@ def mean(
144144
"""
145145
from ._mean import mean_
146146

147-
validate_axis(axis)
147+
validate_axis(x.ndim, axis)
148148
return mean_(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] # literally the same type, wtf mypy
149149

150150

@@ -219,6 +219,7 @@ def mean_var(
219219
"""
220220
from ._mean_var import mean_var_
221221

222+
validate_axis(x.ndim, axis)
222223
return mean_var_(x, axis=axis, correction=correction) # type: ignore[no-any-return]
223224

224225

@@ -284,5 +285,5 @@ def sum(
284285
"""
285286
from ._sum import sum_
286287

287-
validate_axis(axis)
288+
validate_axis(x.ndim, axis)
288289
return sum_(x, axis=axis, dtype=dtype)

src/fast_array_utils/stats/_sum.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,9 @@ def _sum_cs(
5252
if isinstance(x, types.CSMatrix):
5353
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
5454

55-
return cast("NDArray[Any] | np.number[Any]", np.sum(x, axis=axis, dtype=dtype)) # type: ignore[call-overload]
55+
if axis is None:
56+
return cast("np.number[Any]", x.data.sum(dtype=dtype))
57+
return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype))
5658

5759

5860
@sum_.register(types.DaskArray)
@@ -76,17 +78,20 @@ def sum_drop_keepdims(
7678
keepdims: bool = False,
7779
) -> NDArray[Any] | types.CupyArray:
7880
del keepdims
79-
match axis:
80-
case (0 | 1 as n,):
81-
axis = n
82-
case (0, 1) | (1, 0):
83-
axis = None
84-
case tuple(): # pragma: no cover
85-
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
86-
raise ValueError(msg)
81+
if a.ndim == 1:
82+
axis = None
83+
else:
84+
match axis:
85+
case (0, 1) | (1, 0):
86+
axis = None
87+
case (0 | 1 as n,):
88+
axis = n
89+
case tuple(): # pragma: no cover
90+
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
91+
raise ValueError(msg)
8792
rv = sum(a, axis=axis, dtype=dtype)
88-
# make sure rv is 2D
89-
return np.reshape(rv, (1, 1 if rv.shape == () else len(rv))) # type: ignore[arg-type]
93+
shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type]
94+
return np.reshape(rv, shape)
9095

9196
if dtype is None:
9297
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)

tests/test_stats.py

Lines changed: 54 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import pytest
9+
from numpy.exceptions import AxisError
910

1011
from fast_array_utils import stats, types
1112
from testing.fast_array_utils import SUPPORTED_TYPES, Flags
@@ -26,12 +27,12 @@
2627
DTypeIn = np.float32 | np.float64 | np.int32 | np.bool
2728
DTypeOut = np.float32 | np.float64 | np.int64
2829

29-
NdAndAx: TypeAlias = tuple[Literal[2], Literal[0, 1, None]]
30+
NdAndAx: TypeAlias = tuple[Literal[1], Literal[None]] | tuple[Literal[2], Literal[0, 1, None]]
3031

31-
class BenchFun(Protocol): # noqa: D101
32+
class StatFun(Protocol): # noqa: D101
3233
def __call__( # noqa: D102
3334
self,
34-
arr: CpuArray,
35+
arr: Array,
3536
*,
3637
axis: Literal[0, 1, None] = None,
3738
dtype: type[DTypeOut] | None = None,
@@ -41,6 +42,8 @@ def __call__( # noqa: D102
4142
pytestmark = [pytest.mark.skipif(not find_spec("numba"), reason="numba not installed")]
4243

4344

45+
STAT_FUNCS = [stats.sum, stats.mean, stats.mean_var, stats.is_constant]
46+
4447
# can’t select these using a category filter
4548
ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at.mod == "anndata.abc"}
4649
ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)}
@@ -49,6 +52,7 @@ def __call__( # noqa: D102
4952
@pytest.fixture(
5053
scope="session",
5154
params=[
55+
pytest.param((1, None), id="1d-all"),
5256
pytest.param((2, None), id="2d-all"),
5357
pytest.param((2, 0), id="2d-ax0"),
5458
pytest.param((2, 1), id="2d-ax1"),
@@ -59,18 +63,31 @@ def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx:
5963

6064

6165
@pytest.fixture
62-
def ndim(ndim_and_axis: NdAndAx) -> Literal[2]:
63-
return ndim_and_axis[0]
66+
def ndim(ndim_and_axis: NdAndAx, array_type: ArrayType) -> Literal[1, 2]:
67+
return check_ndim(array_type, ndim_and_axis[0])
68+
69+
70+
def check_ndim(array_type: ArrayType, ndim: Literal[1, 2]) -> Literal[1, 2]:
71+
inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
72+
if ndim != 2 and issubclass(inner_cls, types.CSMatrix | types.CupyCSMatrix):
73+
pytest.skip("CSMatrix only supports 2D")
74+
if ndim != 2 and inner_cls is types.csc_array:
75+
pytest.skip("csc_array only supports 2D")
76+
return ndim
6477

6578

6679
@pytest.fixture(scope="session")
6780
def axis(ndim_and_axis: NdAndAx) -> Literal[0, 1, None]:
6881
return ndim_and_axis[1]
6982

7083

71-
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
72-
def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]:
73-
return cast("type[DTypeIn]", request.param)
84+
@pytest.fixture(params=[np.float32, np.float64, np.int32, np.bool])
85+
def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTypeIn]:
86+
dtype = cast("type[DTypeIn]", request.param)
87+
inner_cls = array_type.inner.cls if array_type.inner else array_type.cls
88+
if np.dtype(dtype).kind not in "fdFD" and issubclass(inner_cls, types.CupyCSMatrix):
89+
pytest.skip("Cupy sparse matrices don’t support non-floating dtypes")
90+
return dtype
7491

7592

7693
@pytest.fixture(scope="session", params=[np.float32, np.float64, None])
@@ -79,12 +96,33 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
7996

8097

8198
@pytest.fixture
82-
def np_arr(dtype_in: type[DTypeIn]) -> NDArray[DTypeIn]:
99+
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
83100
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
84101
np_arr.flags.writeable = False
102+
if ndim == 1:
103+
np_arr = np_arr.flatten()
85104
return np_arr
86105

87106

107+
@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix})
108+
@pytest.mark.parametrize("func", STAT_FUNCS)
109+
@pytest.mark.parametrize(
110+
("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"]
111+
)
112+
def test_ndim_error(
113+
array_type: ArrayType[Array], func: StatFun, ndim: Literal[1, 2], axis: Literal[0, 1, None]
114+
) -> None:
115+
check_ndim(array_type, ndim)
116+
# not using the fixture because we don’t need to test multiple dtypes
117+
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
118+
if ndim == 1:
119+
np_arr = np_arr.flatten()
120+
arr = array_type(np_arr)
121+
122+
with pytest.raises(AxisError):
123+
func(arr, axis=axis)
124+
125+
88126
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
89127
def test_sum(
90128
array_type: ArrayType[Array],
@@ -93,8 +131,6 @@ def test_sum(
93131
axis: Literal[0, 1, None],
94132
np_arr: NDArray[DTypeIn],
95133
) -> None:
96-
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
97-
pytest.skip("CuPy sparse matrices only support floats")
98134
arr = array_type(np_arr.copy())
99135
assert arr.dtype == dtype_in
100136

@@ -133,8 +169,6 @@ def test_sum(
133169
def test_mean(
134170
array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn]
135171
) -> None:
136-
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
137-
pytest.skip("CuPy sparse matrices only support floats")
138172
arr = array_type(np_arr)
139173

140174
result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
@@ -148,26 +182,21 @@ def test_mean(
148182

149183

150184
@pytest.mark.array_type(skip=Flags.Disk)
151-
@pytest.mark.parametrize(
152-
("axis", "mean_expected", "var_expected"),
153-
[(None, 3.5, 3.5), (0, [2.5, 3.5, 4.5], [4.5, 4.5, 4.5]), (1, [2.0, 5.0], [1.0, 1.0])],
154-
)
155185
def test_mean_var(
156186
array_type: ArrayType[CpuArray | GpuArray | types.DaskArray],
157187
axis: Literal[0, 1, None],
158-
mean_expected: float | list[float],
159-
var_expected: float | list[float],
188+
np_arr: NDArray[DTypeIn],
160189
) -> None:
161-
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
162-
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected)
163-
np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected)
164-
165190
arr = array_type(np_arr)
191+
166192
mean, var = stats.mean_var(arr, axis=axis, correction=1)
167193
if isinstance(mean, types.DaskArray) and isinstance(var, types.DaskArray):
168194
mean, var = mean.compute(), var.compute() # type: ignore[assignment]
169195
if isinstance(mean, types.CupyArray) and isinstance(var, types.CupyArray):
170196
mean, var = mean.get(), var.get()
197+
198+
mean_expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type]
199+
var_expected = np.var(np_arr, axis=axis, correction=1) # type: ignore[arg-type]
171200
np.testing.assert_array_equal(mean, mean_expected)
172201
np.testing.assert_array_almost_equal(var, var_expected) # type: ignore[arg-type]
173202

@@ -223,11 +252,11 @@ def test_dask_constant_blocks(
223252

224253
@pytest.mark.benchmark
225254
@pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
226-
@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.mean_var, stats.is_constant])
255+
@pytest.mark.parametrize("func", STAT_FUNCS)
227256
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32])
228257
def test_stats_benchmark(
229258
benchmark: BenchmarkFixture,
230-
func: BenchFun,
259+
func: StatFun,
231260
array_type: ArrayType[CpuArray, None],
232261
axis: Literal[0, 1, None],
233262
dtype: type[np.float32 | np.float64],

typings/cupy/_core/core.pyi

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ from numpy.typing import NDArray
88
class ndarray:
99
dtype: np.dtype[Any]
1010
shape: tuple[int, ...]
11+
ndim: int
1112

1213
# cupy-specific
1314
def get(self) -> NDArray[Any]: ...

typings/cupyx/scipy/sparse/_base.pyi

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ from numpy.typing import NDArray
99
class spmatrix:
1010
dtype: np.dtype[Any]
1111
shape: tuple[int, int]
12+
ndim: int
1213
def toarray(self, order: Literal["C", "F", None] = None, out: None = None) -> cupy.ndarray: ...
1314
def __power__(self, other: int) -> Self: ...
1415
def __array__(self) -> NDArray[Any]: ...

typings/h5py.pyi

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ class HLObject: ...
1212
class Dataset(HLObject):
1313
dtype: np.dtype[Any]
1414
shape: tuple[int, ...]
15+
ndim: int
1516

1617
class Group(HLObject): ...
1718

0 commit comments

Comments
 (0)