diff --git a/src/array_api_extra/_apply.py b/src/array_api_extra/_apply.py index 689ec3f..6d792d5 100644 --- a/src/array_api_extra/_apply.py +++ b/src/array_api_extra/_apply.py @@ -39,7 +39,8 @@ def apply_numpy_func( # type: ignore[no-any-explicit] **kwargs: Any, ) -> tuple[Array, ...]: """ - Apply a function that operates on NumPY arrays to any Array API compliant arrays. + Apply a function that operates on NumPy arrays to any Array API compliant arrays, + as long as you can apply ``np.asarray`` to them. Parameters ---------- @@ -103,7 +104,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit] Notes ----- - JAX: + JAX This allows applying eager functions to jitted JAX arrays, which are lazy. The function won't be applied until the JAX array is materialized. @@ -112,16 +113,21 @@ def apply_numpy_func( # type: ignore[no-any-explicit] may prevent arrays on a GPU device from being transferred back to CPU. This is treated as an implicit transfer. - PyTorch, CuPy: + PyTorch, CuPy These backends raise by default if you attempt to convert arrays on a GPU device to NumPy. - Dask: - This function allows applying func to the chunks of dask arrays. + Sparse + By default, sparse prevents implicit densification through ``np.asarray``. + `This safety mechanism can be disabled + <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_. + + Dask + This allows applying eager functions to the individual chunks of dask arrays. The dask graph won't be computed. As a special limitation, `func` must return exactly one output. - In order to allow Dask you need to specify at least + In order to enable running on Dask you need to specify at least `input_indices`, `output_indices`, and `core_indices`, but you may also need `adjust_chunks` and `new_axes` depending on the function. @@ -147,7 +153,13 @@ def apply_numpy_func( # type: ignore[no-any-explicit] ... core_indices='i') This will cause `apply_numpy_func` to raise if the first axis of `x` is broken - along multiple chunks.
+ along multiple chunks, thus forcing the final user to rechunk ahead of time: + + >>> x = x.rechunk({0: -1}) + + This needs to always be a conscious decision on behalf of the final user, as the + new chunks will be larger than the old and may cause memory issues, unless chunk + size is reduced along a different, non-core axis. """ if xp is None: xp = array_namespace(*args)