@@ -39,7 +39,8 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
**kwargs: Any,
) -> tuple[Array, ...]:
"""
- Apply a function that operates on NumPY arrays to any Array API compliant arrays.
+ Apply a function that operates on NumPy arrays to any Array API compliant arrays,
+ as long as you can apply ``np.asarray`` to them.
Parameters
----------
@@ -103,7 +104,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
Notes
-----
- JAX:
+ JAX
This allows applying eager functions to jitted JAX arrays, which are lazy.
The function won't be applied until the JAX array is materialized.
@@ -112,16 +113,21 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
may prevent arrays on a GPU device from being transferred back to CPU.
This is treated as an implicit transfer.
- PyTorch, CuPy:
+ PyTorch, CuPy
These backends raise by default if you attempt to convert arrays on a GPU device
to NumPy.
- Dask:
- This function allows applying func to the chunks of dask arrays.
+ Sparse
+ By default, sparse prevents implicit densification through ``np.asarray``.
+ `This safety mechanism can be disabled
+ <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
+
+ Dask
+ This allows applying eager functions to the individual chunks of dask arrays.
The dask graph won't be computed. As a special limitation, `func` must return
exactly one output.
- In order to allow Dask you need to specify at least
+ In order to enable running on Dask you need to specify at least
`input_indices`, `output_indices`, and `core_indices`, but you may also need
`adjust_chunks` and `new_axes` depending on the function.
@@ -147,7 +153,13 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
... core_indices='i')
This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
- along multiple chunks.
+ along multiple chunks, thus forcing the end user to rechunk ahead of time:
+
+ >>> x = x.rechunk({0: -1})
+
+ This always needs to be a conscious decision on the part of the end user, as the
+ new chunks will be larger than the old ones and may cause memory issues, unless
+ chunk size is reduced along a different, non-core axis.
"""
if xp is None:
xp = array_namespace(*args)
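
As an illustration of the Dask behaviour described in the docstring above, here is a minimal usage sketch. It assumes the package exposes ``apply_numpy_func`` at the top level, that ``shape`` and ``dtype`` keywords describe the output, and that a tuple of arrays is returned, per the ``tuple[Array, ...]`` annotation; only ``input_indices``, ``output_indices``, and ``core_indices`` are taken from the docstring text, and the exact signature may differ:

import dask.array as da
import numpy as np

from array_api_extra import apply_numpy_func  # top-level import path is an assumption


def cumsum_np(x: np.ndarray) -> np.ndarray:
    # Eager, NumPy-only function: np.cumsum needs the whole axis at once.
    return np.cumsum(x, axis=0)


x = da.arange(12, chunks=4)

# The core index 'i' must not be broken across chunks, so rechunk first;
# this is the conscious decision the docstring asks of the end user.
x = x.rechunk({0: -1})

(y,) = apply_numpy_func(
    cumsum_np,
    x,
    shape=(x.shape,),     # assumed keyword: one shape per output
    dtype=(x.dtype,),     # assumed keyword: one dtype per output
    input_indices=["i"],  # one index string per input (named in the Notes)
    output_indices=["i"], # one index string per output (named in the Notes)
    core_indices="i",     # raises if 'i' is chunked, as documented
)

print(y.compute())  # the wrapped NumPy call runs only when the dask graph is computed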