Skip to content

Commit 820043e

Browse files
committed
lint
1 parent fa47560 commit 820043e

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

src/array_api_extra/_apply.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,43 @@
1717

1818
if TYPE_CHECKING:
1919
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
20-
from typing import TypeAlias
20+
from typing import ParamSpec, TypeAlias
2121

2222
import numpy as np
2323

2424
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit]
25-
KwArg: TypeAlias = Any # type: ignore[no-any-explicit]
25+
P = ParamSpec("P")
2626

2727

2828
@overload
29-
def apply_numpy_func(
30-
func: Callable[..., NumPyObject],
29+
def apply_numpy_func( # type: ignore[valid-type]
30+
func: Callable[P, NumPyObject],
3131
*args: Array,
3232
shape: tuple[int, ...] | None = None,
3333
dtype: DType | None = None,
3434
xp: ModuleType | None = None,
35-
**kwargs: KwArg,
35+
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
3636
) -> Array: ... # numpydoc ignore=GL08
3737

3838

3939
@overload
40-
def apply_numpy_func( # type: ignore[no-any-decorated]
41-
func: Callable[..., Sequence[NumPyObject]],
40+
def apply_numpy_func( # type: ignore[valid-type]
41+
func: Callable[P, Sequence[NumPyObject]],
4242
*args: Array,
4343
shape: Sequence[tuple[int, ...]],
4444
dtype: Sequence[DType] | None = None,
4545
xp: ModuleType | None = None,
46-
**kwargs: Any,
46+
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
4747
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
4848

4949

50-
def apply_numpy_func( # type: ignore[no-any-explicit]
51-
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
50+
def apply_numpy_func( # type: ignore[valid-type]
51+
func: Callable[P, NumPyObject | Sequence[NumPyObject]],
5252
*args: Array,
5353
shape: tuple[int, ...] | Sequence[tuple[int, ...]] | None = None,
5454
dtype: DType | Sequence[DType] | None = None,
5555
xp: ModuleType | None = None,
56-
**kwargs: Any,
56+
**kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues]
5757
) -> Array | tuple[Array, ...]:
5858
"""
5959
Apply a function that operates on NumPy arrays to Array API compliant arrays.
@@ -139,7 +139,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
139139
elif isinstance(shape, tuple) and all(isinstance(s, int) for s in shape):
140140
shapes = [shape]
141141
else:
142-
shapes = shape
142+
shapes = list(shape)
143143
multi_output = True
144144

145145
if dtype is None:
@@ -148,7 +148,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
148148
if not isinstance(dtype, Sequence):
149149
msg = "Got sequence of shapes but only one dtype"
150150
raise TypeError(msg)
151-
dtypes = dtype
151+
dtypes = list(dtype) # pyright: ignore[reportUnknownArgumentType]
152152
else:
153153
if isinstance(dtype, Sequence):
154154
msg = "Got single shape but multiple dtypes"
@@ -254,13 +254,16 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
254254
args = tuple(np.asarray(arg) for arg in args)
255255
out = func(*args, **kwargs)
256256

257+
# Stay relaxed on output validation, e.g. in case func returns a
258+
# Python scalar instead of a np.generic
257259
if multi_output:
258260
if not isinstance(out, Sequence) or isinstance(out, np.ndarray):
259261
msg = "Expected multiple outputs, got a single one"
260262
raise ValueError(msg)
263+
outs = out
261264
else:
262-
out = (out,)
265+
outs = [cast("NumPyObject", out)]
263266

264-
return tuple(xp.asarray(o) for o in out)
267+
return tuple(xp.asarray(o) for o in outs)
265268

266269
return wrapper

0 commit comments

Comments
 (0)