Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 9, 2025
1 parent 72360ed commit a599544
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/array_api_extra/_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
**kwargs: Any,
) -> tuple[Array, ...]:
"""
Apply a function that operates on NumPY arrays to any Array API compliant arrays.
Apply a function that operates on NumPY arrays to any Array API compliant arrays,
as long as you can apply ``np.asarray`` to them.
Parameters
----------
Expand Down Expand Up @@ -103,7 +104,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
Notes
-----
JAX:
JAX
This allows applying eager functions to jitted JAX arrays, which are lazy.
The function won't be applied until the JAX array is materialized.
Expand All @@ -112,16 +113,21 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
may prevent arrays on a GPU device from being transferred back to CPU.
This is treated as an implicit transfer.
PyTorch, CuPy:
PyTorch, CuPy
These backends raise by default if you attempt to convert arrays on a GPU device
to NumPy.
Dask:
This function allows applying func to the chunks of dask arrays.
Sparse
By default, sparse prevents implicit densification through ``np.asarray`.
`This safety mechanism can be disabled
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
Dask
This allows applying eager functions to the individual chunks of dask arrays.
The dask graph won't be computed. As a special limitation, `func` must return
exactly one output.
In order to allow Dask you need to specify at least
In order to enable running on Dask you need to specify at least
`input_indices`, `output_indices`, and `core_indices`, but you may also need
`adjust_chunks` and `new_axes` depending on the function.
Expand All @@ -147,7 +153,13 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
... core_indices='i')
This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
along multiple chunks.
along multiple chunks, thus forcing the final user to rechunk ahead of time:
>>> x = x.chunk({0: -1})
This needs to always be a conscious decision on behalf of the final user, as the
new chunks will be larger than the old and may cause memory issues, unless chunk
size is reduced along a different, non-core axis.
"""
if xp is None:
xp = array_namespace(*args)
Expand Down

0 comments on commit a599544

Please sign in to comment.