23
23
24
24
NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
25
25
P = ParamSpec ("P" )
26
+ else :
27
+ # Sphinx hacks
28
+ NumPyObject = Any
29
+
30
+ class P : # pylint: disable=missing-class-docstring
31
+ args : tuple
32
+ kwargs : dict
26
33
27
34
28
35
@overload
@@ -47,7 +54,7 @@ def apply_numpy_func( # type: ignore[valid-type]
47
54
) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
48
55
49
56
50
- def apply_numpy_func ( # type: ignore[valid-type]
57
+ def apply_numpy_func ( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
51
58
func : Callable [P , NumPyObject | Sequence [NumPyObject ]],
52
59
* args : Array ,
53
60
shape : tuple [int , ...] | Sequence [tuple [int , ...]] | None = None ,
@@ -69,7 +76,7 @@ def apply_numpy_func( # type: ignore[valid-type]
69
76
as depending on the backend it may be executed more than once.
70
77
*args : Array
71
78
One or more Array API compliant arrays. You need to be able to apply
72
- ``np .asarray()` ` to them to convert them to numpy; read notes below about
79
+ :func:`numpy .asarray` to them to convert them to numpy; read notes below about
73
80
specific backends.
74
81
shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
75
82
Output shape or sequence of output shapes, one for each output of `func`.
@@ -97,25 +104,23 @@ def apply_numpy_func( # type: ignore[valid-type]
97
104
This allows applying eager functions to jitted JAX arrays, which are lazy.
98
105
The function won't be applied until the JAX array is materialized.
99
106
100
- The `JAX transfer guard
101
- <https://jax.readthedocs.io/en/latest/transfer_guard.html>`_
102
- may prevent arrays on a GPU device from being transferred back to CPU.
103
- This is treated as an implicit transfer.
107
+ The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
108
+ transferred back to CPU. This is treated as an implicit transfer.
104
109
105
110
PyTorch, CuPy
106
111
These backends raise by default if you attempt to convert arrays on a GPU device
107
112
to NumPy.
108
113
109
114
Sparse
110
- By default, sparse prevents implicit densification through ``np.asarray`.
111
- `This safety mechanism can be disabled
115
+ By default, sparse prevents implicit densification through
116
+ :func:`numpy.asarray`. `This safety mechanism can be disabled
112
117
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
113
118
114
119
Dask
115
120
This allows applying eager functions to dask arrays.
116
121
The dask graph won't be computed.
117
122
118
- `apply_numpy_func` doesn't know if `func` reduces along any axes and shape
123
+ `apply_numpy_func` doesn't know if `func` reduces along any axes; also, shape
119
124
changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
120
125
will be rechunked into a single chunk.
121
126
@@ -125,9 +130,19 @@ def apply_numpy_func( # type: ignore[valid-type]
125
130
126
131
The outputs will also be returned as a single chunk and you should consider
127
132
rechunking them into smaller chunks afterwards.
133
+
128
134
If you want to distribute the calculation across multiple workers, you
129
- should use `dask.array.map_blocks`, `dask.array.blockwise`,
130
- `dask.array.map_overlap`, or a native Dask wrapper instead of this function.
135
+ should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
136
+ :func:`dask.array.blockwise`, or a native Dask wrapper instead of
137
+ `apply_numpy_func`.
138
+
139
+ See Also
140
+ --------
141
+ jax.transfer_guard
142
+ jax.pure_callback
143
+ dask.array.map_blocks
144
+ dask.array.map_overlap
145
+ dask.array.blockwise
131
146
"""
132
147
if xp is None :
133
148
xp = array_namespace (* args )
@@ -239,8 +254,8 @@ def _npfunc_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT
239
254
240
255
Any keyword arguments are passed through verbatim to the wrapped function.
241
256
242
- Raise if np.asarray() raises on any input. This typically happens if the input is
243
- lazy and has a guard against being implicitly turned into a NumPy array (e.g.
257
+ Raise if np.asarray raises on any input. This typically happens if the input is lazy
258
+ and has a guard against being implicitly turned into a NumPy array (e.g.
244
259
densification for sparse arrays, device->host transfer for cupy and torch arrays).
245
260
"""
246
261
0 commit comments