     SupportsIndex,
     TypeAlias,
     TypeGuard,
-    TypeVar,
     cast,
     overload,
 )

 from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace

 if TYPE_CHECKING:
-
+    import cupy as cp
     import dask.array as da
     import jax
     import ndonnx as ndx
     import numpy as np
     import numpy.typing as npt
-    import sparse  # pyright: ignore[reportMissingTypeStubs]
+    import sparse
     import torch

     # TODO: import from typing (requires Python >=3.13)
-    from typing_extensions import TypeIs, TypeVar
-
-    _SizeT = TypeVar("_SizeT", bound=int | None)
+    from typing_extensions import TypeIs

     _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
-    _CupyArray: TypeAlias = Any  # cupy has no py.typed

     _ArrayApiObj: TypeAlias = (
         npt.NDArray[Any]
+        | cp.ndarray
         | da.Array
         | jax.Array
         | ndx.Array
         | sparse.SparseArray
         | torch.Tensor
         | SupportsArrayNamespace[Any]
-        | _CupyArray
     )

 _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
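
The hunk above drops the `_CupyArray: TypeAlias = Any` stand-in (kept only because CuPy ships no `py.typed` marker) and lets `_ArrayApiObj` name `cp.ndarray` directly, with the `import cupy as cp` confined to `TYPE_CHECKING` so CuPy stays an optional runtime dependency. A minimal sketch of the same pattern, using a hypothetical `MyArray` alias and `total_size` helper rather than anything from this file:

from __future__ import annotations  # annotations are stored as strings, never evaluated eagerly

from typing import TYPE_CHECKING, Any, TypeAlias

if TYPE_CHECKING:
    # Resolved by the type checker only; neither package is imported at runtime.
    import cupy as cp
    import numpy.typing as npt

    MyArray: TypeAlias = cp.ndarray | npt.NDArray[Any]


def total_size(x: MyArray) -> int:
    # The annotation above is never evaluated at runtime, so this module
    # imports cleanly even on a machine without CuPy or NumPy installed.
    return int(x.size)
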
@@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
     return dtype == jax.float0


-def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
+def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
     """
     Return True if `x` is a NumPy array.

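
`TypeIs` (PEP 742) narrows both branches of a check, while `TypeGuard` only narrows the branch where the function returns True, so the new return type gives callers of `is_numpy_array` better inference in `else` branches as well. A rough illustration of the difference, with a bare `isinstance` check standing in for the library's own logic:

from typing import Any

import numpy as np
import numpy.typing as npt
from typing_extensions import TypeIs


def looks_like_ndarray(x: object) -> TypeIs[npt.NDArray[Any]]:
    # Simplified stand-in for is_numpy_array, for illustration only.
    return isinstance(x, np.ndarray)


def first_item(x: npt.NDArray[Any] | list[int]) -> Any:
    if looks_like_ndarray(x):
        return x.flat[0]  # narrowed to the ndarray type
    # Because this is TypeIs, the negative branch narrows too: x is list[int]
    # here; with TypeGuard it would still be the full union.
    return x[0]
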
@@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
     return _issubclass_fast(cls, "sparse", "SparseArray")


-def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]:  # pyright: ignore[reportUnknownParameterType]
+def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
     """
     Return True if `x` is an array API compatible array object.

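
Going the other way here, from `TypeIs` back to `TypeGuard`, is presumably deliberate: `_ArrayApiObj` includes the structural `SupportsArrayNamespace[Any]` protocol, so a False result from a duck-typed check cannot prove that `x` is outside the union, and only positive-branch narrowing is sound. A small sketch of that asymmetry, using a hypothetical `SupportsNamespace` protocol and `is_api_obj` checker:

from typing import Protocol, TypeGuard


class SupportsNamespace(Protocol):
    # Hypothetical stand-in for SupportsArrayNamespace.
    def __array_namespace__(self) -> object: ...


def is_api_obj(x: object) -> TypeGuard[SupportsNamespace]:
    # Heuristic attribute probe: a False answer proves nothing definite.
    return hasattr(x, "__array_namespace__")


def demo(x: SupportsNamespace | int) -> None:
    if is_api_obj(x):
        x.__array_namespace__()  # narrowed to the protocol
    else:
        # TypeGuard leaves x as the declared union here; TypeIs would narrow
        # it to int, which this kind of heuristic check cannot justify.
        print(x)
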
@@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device:
         return "cpu"
     elif is_dask_array(x):
         # Peek at the metadata of the Dask array to determine type
-        if is_numpy_array(x._meta):  # pyright: ignore
+        if is_numpy_array(x._meta):
             # Must be on CPU since backed by numpy
             return "cpu"
         return _DASK_DEVICE
@@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device:
             return "cpu"
         # Return the device of the constituent array
         return device(inner)  # pyright: ignore
-    return x.device  # pyright: ignore
+    return x.device  # type: ignore # pyright: ignore


 # Prevent shadowing, used below
@@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device:

 # Based on cupy.array_api.Array.to_device
 def _cupy_to_device(
-    x: _CupyArray,
+    x: cp.ndarray,
     device: Device,
     /,
     stream: int | Any | None = None,
-) -> _CupyArray:
+) -> cp.ndarray:
     import cupy as cp

     if device == "cpu":
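
A side note on the signature above: `cp.ndarray` in the annotations refers to the `TYPE_CHECKING`-only import at the top of the file, while the `import cupy as cp` in the body is the one that runs when the function is actually called. This relies on deferred annotation evaluation (e.g. `from __future__ import annotations`), which the surrounding module evidently uses. A self-contained sketch of the same deferred-import pattern, with a hypothetical `double_on_device` helper:

from __future__ import annotations  # keeps cp.ndarray in the signature unevaluated

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import cupy as cp


def double_on_device(x: cp.ndarray) -> cp.ndarray:
    # Runtime import: executed only when the function is called, i.e. when
    # the caller already holds a CuPy array, so CuPy must be present anyway.
    import cupy as cp

    return cp.multiply(x, 2)
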
@@ -819,7 +815,7 @@ def _torch_to_device(
     x: torch.Tensor,
     device: torch.device | str | int,
     /,
-    stream: None = None,
+    stream: int | Any | None = None,
 ) -> torch.Tensor:
     if stream is not None:
         raise NotImplementedError
@@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
         # cupy does not yet have to_device
         return _cupy_to_device(x, device, stream=stream)
     elif is_torch_array(x):
-        return _torch_to_device(x, device, stream=stream)  # pyright: ignore[reportArgumentType]
+        return _torch_to_device(x, device, stream=stream)
     elif is_dask_array(x):
         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
@@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
 @overload
 def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
 @overload
-def size(x: HasShape[Collection[None]]) -> None: ...
-@overload
 def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
 def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
     """
@@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
     return None


-def is_writeable_array(x: object) -> bool:
+def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
     """
     Return False if ``x.__setitem__`` is expected to raise; True otherwise.
     Return False if `x` is not an array API compatible object.
@@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
     return None


-def is_lazy_array(x: object) -> bool:
+def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
     """Return True if x is potentially a future or it may be otherwise impossible or
     expensive to eagerly read its contents, regardless of their size, e.g. by
     calling ``bool(x)`` or ``float(x)``.
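
Both `is_writeable_array` and `is_lazy_array` now return `TypeGuard[_ArrayApiObj]` rather than plain `bool`: a True result additionally narrows the argument to an array type for the type checker, while a False result (which may just mean "not an array at all") leaves the declared type untouched. A sketch of what a caller gains, with a hypothetical NumPy-only guard standing in for the library's duck-typed checks:

from typing import Any, TypeGuard

import numpy as np
import numpy.typing as npt


def is_writeable_ndarray(x: object) -> TypeGuard[npt.NDArray[Any]]:
    # Hypothetical, NumPy-only stand-in that honours the writeable flag.
    return isinstance(x, np.ndarray) and bool(x.flags.writeable)


def fill_first(x: object) -> None:
    if is_writeable_ndarray(x):
        # Narrowed: indexing assignment is known to the checker and expected
        # to succeed at runtime.
        x[0] = 0
    else:
        # Not narrowed: x may be a read-only array or not an array at all.
        print("skipping", type(x).__name__)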