Skip to content

Commit a599544

Browse files
committed
docs
1 parent 72360ed commit a599544

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

src/array_api_extra/_apply.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
3939
**kwargs: Any,
4040
) -> tuple[Array, ...]:
4141
"""
42-
Apply a function that operates on NumPY arrays to any Array API compliant arrays.
42+
Apply a function that operates on NumPY arrays to any Array API compliant arrays,
43+
as long as you can apply ``np.asarray`` to them.
4344
4445
Parameters
4546
----------
@@ -103,7 +104,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
103104
104105
Notes
105106
-----
106-
JAX:
107+
JAX
107108
This allows applying eager functions to jitted JAX arrays, which are lazy.
108109
The function won't be applied until the JAX array is materialized.
109110
@@ -112,16 +113,21 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
112113
may prevent arrays on a GPU device from being transferred back to CPU.
113114
This is treated as an implicit transfer.
114115
115-
PyTorch, CuPy:
116+
PyTorch, CuPy
116117
These backends raise by default if you attempt to convert arrays on a GPU device
117118
to NumPy.
118119
119-
Dask:
120-
This function allows applying func to the chunks of dask arrays.
120+
Sparse
121+
By default, sparse prevents implicit densification through ``np.asarray`.
122+
`This safety mechanism can be disabled
123+
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
124+
125+
Dask
126+
This allows applying eager functions to the individual chunks of dask arrays.
121127
The dask graph won't be computed. As a special limitation, `func` must return
122128
exactly one output.
123129
124-
In order to allow Dask you need to specify at least
130+
In order to enable running on Dask you need to specify at least
125131
`input_indices`, `output_indices`, and `core_indices`, but you may also need
126132
`adjust_chunks` and `new_axes` depending on the function.
127133
@@ -147,7 +153,13 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
147153
... core_indices='i')
148154
149155
This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
150-
along multiple chunks.
156+
along multiple chunks, thus forcing the final user to rechunk ahead of time:
157+
158+
>>> x = x.chunk({0: -1})
159+
160+
This needs to always be a conscious decision on behalf of the final user, as the
161+
new chunks will be larger than the old and may cause memory issues, unless chunk
162+
size is reduced along a different, non-core axis.
151163
"""
152164
if xp is None:
153165
xp = array_namespace(*args)

0 commit comments

Comments
 (0)