6
6
7
7
import numpy as np
8
8
import pytest
9
+ from numpy .exceptions import AxisError
9
10
10
11
from fast_array_utils import stats , types
11
12
from testing .fast_array_utils import SUPPORTED_TYPES , Flags
26
27
DTypeIn = np .float32 | np .float64 | np .int32 | np .bool
27
28
DTypeOut = np .float32 | np .float64 | np .int64
28
29
29
- NdAndAx : TypeAlias = tuple [Literal [2 ], Literal [0 , 1 , None ]]
30
+ NdAndAx : TypeAlias = tuple [Literal [1 ], Literal [ None ]] | tuple [ Literal [ 2 ], Literal [0 , 1 , None ]]
30
31
31
- class BenchFun (Protocol ): # noqa: D101
32
+ class StatFun (Protocol ): # noqa: D101
32
33
def __call__ ( # noqa: D102
33
34
self ,
34
- arr : CpuArray ,
35
+ arr : Array ,
35
36
* ,
36
37
axis : Literal [0 , 1 , None ] = None ,
37
38
dtype : type [DTypeOut ] | None = None ,
@@ -41,6 +42,8 @@ def __call__( # noqa: D102
41
42
pytestmark = [pytest .mark .skipif (not find_spec ("numba" ), reason = "numba not installed" )]
42
43
43
44
45
+ STAT_FUNCS = [stats .sum , stats .mean , stats .mean_var , stats .is_constant ]
46
+
44
47
# can’t select these using a category filter
45
48
ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at .mod == "anndata.abc" }
46
49
ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str (at )}
@@ -49,6 +52,7 @@ def __call__( # noqa: D102
49
52
@pytest .fixture (
50
53
scope = "session" ,
51
54
params = [
55
+ pytest .param ((1 , None ), id = "1d-all" ),
52
56
pytest .param ((2 , None ), id = "2d-all" ),
53
57
pytest .param ((2 , 0 ), id = "2d-ax0" ),
54
58
pytest .param ((2 , 1 ), id = "2d-ax1" ),
@@ -59,18 +63,31 @@ def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx:
59
63
60
64
61
65
@pytest .fixture
62
- def ndim (ndim_and_axis : NdAndAx ) -> Literal [2 ]:
63
- return ndim_and_axis [0 ]
66
+ def ndim (ndim_and_axis : NdAndAx , array_type : ArrayType ) -> Literal [1 , 2 ]:
67
+ return check_ndim (array_type , ndim_and_axis [0 ])
68
+
69
+
70
+ def check_ndim (array_type : ArrayType , ndim : Literal [1 , 2 ]) -> Literal [1 , 2 ]:
71
+ inner_cls = array_type .inner .cls if array_type .inner else array_type .cls
72
+ if ndim != 2 and issubclass (inner_cls , types .CSMatrix | types .CupyCSMatrix ):
73
+ pytest .skip ("CSMatrix only supports 2D" )
74
+ if ndim != 2 and inner_cls is types .csc_array :
75
+ pytest .skip ("csc_array only supports 2D" )
76
+ return ndim
64
77
65
78
66
79
@pytest .fixture (scope = "session" )
67
80
def axis (ndim_and_axis : NdAndAx ) -> Literal [0 , 1 , None ]:
68
81
return ndim_and_axis [1 ]
69
82
70
83
71
- @pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , np .int32 , np .bool ])
72
- def dtype_in (request : pytest .FixtureRequest ) -> type [DTypeIn ]:
73
- return cast ("type[DTypeIn]" , request .param )
84
+ @pytest .fixture (params = [np .float32 , np .float64 , np .int32 , np .bool ])
85
+ def dtype_in (request : pytest .FixtureRequest , array_type : ArrayType ) -> type [DTypeIn ]:
86
+ dtype = cast ("type[DTypeIn]" , request .param )
87
+ inner_cls = array_type .inner .cls if array_type .inner else array_type .cls
88
+ if np .dtype (dtype ).kind not in "fdFD" and issubclass (inner_cls , types .CupyCSMatrix ):
89
+ pytest .skip ("Cupy sparse matrices don’t support non-floating dtypes" )
90
+ return dtype
74
91
75
92
76
93
@pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , None ])
@@ -79,12 +96,33 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
79
96
80
97
81
98
@pytest .fixture
82
- def np_arr (dtype_in : type [DTypeIn ]) -> NDArray [DTypeIn ]:
99
+ def np_arr (dtype_in : type [DTypeIn ], ndim : Literal [ 1 , 2 ] ) -> NDArray [DTypeIn ]:
83
100
np_arr = cast ("NDArray[DTypeIn]" , np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_in ))
84
101
np_arr .flags .writeable = False
102
+ if ndim == 1 :
103
+ np_arr = np_arr .flatten ()
85
104
return np_arr
86
105
87
106
107
+ @pytest .mark .array_type (skip = {* ATS_SPARSE_DS , Flags .Matrix })
108
+ @pytest .mark .parametrize ("func" , STAT_FUNCS )
109
+ @pytest .mark .parametrize (
110
+ ("ndim" , "axis" ), [(1 , 0 ), (2 , 3 ), (2 , - 1 )], ids = ["1d-ax0" , "2d-ax3" , "2d-axneg" ]
111
+ )
112
+ def test_ndim_error (
113
+ array_type : ArrayType [Array ], func : StatFun , ndim : Literal [1 , 2 ], axis : Literal [0 , 1 , None ]
114
+ ) -> None :
115
+ check_ndim (array_type , ndim )
116
+ # not using the fixture because we don’t need to test multiple dtypes
117
+ np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .float32 )
118
+ if ndim == 1 :
119
+ np_arr = np_arr .flatten ()
120
+ arr = array_type (np_arr )
121
+
122
+ with pytest .raises (AxisError ):
123
+ func (arr , axis = axis )
124
+
125
+
88
126
@pytest .mark .array_type (skip = ATS_SPARSE_DS )
89
127
def test_sum (
90
128
array_type : ArrayType [Array ],
@@ -93,8 +131,6 @@ def test_sum(
93
131
axis : Literal [0 , 1 , None ],
94
132
np_arr : NDArray [DTypeIn ],
95
133
) -> None :
96
- if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
97
- pytest .skip ("CuPy sparse matrices only support floats" )
98
134
arr = array_type (np_arr .copy ())
99
135
assert arr .dtype == dtype_in
100
136
@@ -133,8 +169,6 @@ def test_sum(
133
169
def test_mean (
134
170
array_type : ArrayType [Array ], axis : Literal [0 , 1 , None ], np_arr : NDArray [DTypeIn ]
135
171
) -> None :
136
- if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
137
- pytest .skip ("CuPy sparse matrices only support floats" )
138
172
arr = array_type (np_arr )
139
173
140
174
result = stats .mean (arr , axis = axis ) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
@@ -148,26 +182,21 @@ def test_mean(
148
182
149
183
150
184
@pytest .mark .array_type (skip = Flags .Disk )
151
- @pytest .mark .parametrize (
152
- ("axis" , "mean_expected" , "var_expected" ),
153
- [(None , 3.5 , 3.5 ), (0 , [2.5 , 3.5 , 4.5 ], [4.5 , 4.5 , 4.5 ]), (1 , [2.0 , 5.0 ], [1.0 , 1.0 ])],
154
- )
155
185
def test_mean_var (
156
186
array_type : ArrayType [CpuArray | GpuArray | types .DaskArray ],
157
187
axis : Literal [0 , 1 , None ],
158
- mean_expected : float | list [float ],
159
- var_expected : float | list [float ],
188
+ np_arr : NDArray [DTypeIn ],
160
189
) -> None :
161
- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .float64 )
162
- np .testing .assert_array_equal (np .mean (np_arr , axis = axis ), mean_expected )
163
- np .testing .assert_array_equal (np .var (np_arr , axis = axis , correction = 1 ), var_expected )
164
-
165
190
arr = array_type (np_arr )
191
+
166
192
mean , var = stats .mean_var (arr , axis = axis , correction = 1 )
167
193
if isinstance (mean , types .DaskArray ) and isinstance (var , types .DaskArray ):
168
194
mean , var = mean .compute (), var .compute () # type: ignore[assignment]
169
195
if isinstance (mean , types .CupyArray ) and isinstance (var , types .CupyArray ):
170
196
mean , var = mean .get (), var .get ()
197
+
198
+ mean_expected = np .mean (np_arr , axis = axis ) # type: ignore[arg-type]
199
+ var_expected = np .var (np_arr , axis = axis , correction = 1 ) # type: ignore[arg-type]
171
200
np .testing .assert_array_equal (mean , mean_expected )
172
201
np .testing .assert_array_almost_equal (var , var_expected ) # type: ignore[arg-type]
173
202
@@ -223,11 +252,11 @@ def test_dask_constant_blocks(
223
252
224
253
@pytest .mark .benchmark
225
254
@pytest .mark .array_type (skip = Flags .Matrix | Flags .Dask | Flags .Disk | Flags .Gpu )
226
- @pytest .mark .parametrize ("func" , [ stats . sum , stats . mean , stats . mean_var , stats . is_constant ] )
255
+ @pytest .mark .parametrize ("func" , STAT_FUNCS )
227
256
@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 , np .int32 ])
228
257
def test_stats_benchmark (
229
258
benchmark : BenchmarkFixture ,
230
- func : BenchFun ,
259
+ func : StatFun ,
231
260
array_type : ArrayType [CpuArray , None ],
232
261
axis : Literal [0 , 1 , None ],
233
262
dtype : type [np .float32 | np .float64 ],
0 commit comments