17
17
18
18
if TYPE_CHECKING :
19
19
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
20
- from typing import TypeAlias
20
+ from typing import ParamSpec , TypeAlias
21
21
22
22
import numpy as np
23
23
24
24
NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
25
- KwArg : TypeAlias = Any # type: ignore[no-any-explicit]
25
+ P = ParamSpec ( "P" )
26
26
27
27
28
28
@overload
29
- def apply_numpy_func (
30
- func : Callable [... , NumPyObject ],
29
+ def apply_numpy_func ( # type: ignore[valid-type]
30
+ func : Callable [P , NumPyObject ],
31
31
* args : Array ,
32
32
shape : tuple [int , ...] | None = None ,
33
33
dtype : DType | None = None ,
34
34
xp : ModuleType | None = None ,
35
- ** kwargs : KwArg ,
35
+ ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
36
36
) -> Array : ... # numpydoc ignore=GL08
37
37
38
38
39
39
@overload
40
- def apply_numpy_func ( # type: ignore[no-any-decorated ]
41
- func : Callable [... , Sequence [NumPyObject ]],
40
+ def apply_numpy_func ( # type: ignore[valid-type ]
41
+ func : Callable [P , Sequence [NumPyObject ]],
42
42
* args : Array ,
43
43
shape : Sequence [tuple [int , ...]],
44
44
dtype : Sequence [DType ] | None = None ,
45
45
xp : ModuleType | None = None ,
46
- ** kwargs : Any ,
46
+ ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
47
47
) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
48
48
49
49
50
- def apply_numpy_func ( # type: ignore[no-any-explicit ]
51
- func : Callable [... , NumPyObject | Sequence [NumPyObject ]],
50
+ def apply_numpy_func ( # type: ignore[valid-type ]
51
+ func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
52
52
* args : Array ,
53
53
shape : tuple [int , ...] | Sequence [tuple [int , ...]] | None = None ,
54
54
dtype : DType | Sequence [DType ] | None = None ,
55
55
xp : ModuleType | None = None ,
56
- ** kwargs : Any ,
56
+ ** kwargs : P . kwargs , # pyright: ignore[reportGeneralTypeIssues]
57
57
) -> Array | tuple [Array , ...]:
58
58
"""
59
59
Apply a function that operates on NumPy arrays to Array API compliant arrays.
@@ -139,7 +139,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
139
139
elif isinstance (shape , tuple ) and all (isinstance (s , int ) for s in shape ):
140
140
shapes = [shape ]
141
141
else :
142
- shapes = shape
142
+ shapes = list ( shape )
143
143
multi_output = True
144
144
145
145
if dtype is None :
@@ -148,7 +148,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
148
148
if not isinstance (dtype , Sequence ):
149
149
msg = "Got sequence of shapes but only one dtype"
150
150
raise TypeError (msg )
151
- dtypes = dtype
151
+ dtypes = list ( dtype ) # pyright: ignore[reportUnknownArgumentType]
152
152
else :
153
153
if isinstance (dtype , Sequence ):
154
154
msg = "Got single shape but multiple dtypes"
@@ -254,13 +254,16 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
254
254
args = tuple (np .asarray (arg ) for arg in args )
255
255
out = func (* args , ** kwargs )
256
256
257
+ # Stay relaxed on output validation, e.g. in case func returns a
258
+ # Python scalar instead of a np.generic
257
259
if multi_output :
258
260
if not isinstance (out , Sequence ) or isinstance (out , np .ndarray ):
259
261
msg = "Expected multiple outputs, got a single one"
260
262
raise ValueError (msg )
263
+ outs = out
261
264
else :
262
- out = ( out ,)
265
+ outs = [ cast ( "NumPyObject" , out )]
263
266
264
- return tuple (xp .asarray (o ) for o in out )
267
+ return tuple (xp .asarray (o ) for o in outs )
265
268
266
269
return wrapper
0 commit comments