Skip to content

Commit 59b463b

Browse files
committed
docs
1 parent 820043e commit 59b463b

File tree

4 files changed

+58
-15
lines changed

4 files changed

+58
-15
lines changed

docs/conf.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -53,8 +53,9 @@
5353

5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
56-
"jax": ("https://jax.readthedocs.io/en/latest", None),
56+
"numpy": ("https://numpy.org/doc/stable", None),
5757
"dask": ("https://docs.dask.org/en/stable", None),
58+
"jax": ("https://jax.readthedocs.io/en/latest", None),
5859
}
5960

6061
nitpick_ignore = [

pixi.lock

+27-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ furo = ">=2023.08.17"
113113
myst-parser = ">=0.13"
114114
sphinx-copybutton = "*"
115115
sphinx-autodoc-typehints = "*"
116+
numpy = "*"
116117

117118
[tool.pixi.feature.docs.tasks]
118119
docs = { cmd = "sphinx-build . build/", cwd = "docs" }

src/array_api_extra/_apply.py

+28-13
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,13 @@
2323

2424
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit]
2525
P = ParamSpec("P")
26+
else:
27+
# Sphinx hacks
28+
NumPyObject = Any
29+
30+
class P: # pylint: disable=missing-class-docstring
31+
args: tuple
32+
kwargs: dict
2633

2734

2835
@overload
@@ -47,7 +54,7 @@ def apply_numpy_func( # type: ignore[valid-type]
4754
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
4855

4956

50-
def apply_numpy_func( # type: ignore[valid-type]
57+
def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
5158
func: Callable[P, NumPyObject | Sequence[NumPyObject]],
5259
*args: Array,
5360
shape: tuple[int, ...] | Sequence[tuple[int, ...]] | None = None,
@@ -69,7 +76,7 @@ def apply_numpy_func( # type: ignore[valid-type]
6976
as depending on the backend it may be executed more than once.
7077
*args : Array
7178
One or more Array API compliant arrays. You need to be able to apply
72-
``np.asarray()`` to them to convert them to numpy; read notes below about
79+
:func:`numpy.asarray` to them to convert them to numpy; read notes below about
7380
specific backends.
7481
shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
7582
Output shape or sequence of output shapes, one for each output of `func`.
@@ -97,25 +104,23 @@ def apply_numpy_func( # type: ignore[valid-type]
97104
This allows applying eager functions to jitted JAX arrays, which are lazy.
98105
The function won't be applied until the JAX array is materialized.
99106
100-
The `JAX transfer guard
101-
<https://jax.readthedocs.io/en/latest/transfer_guard.html>`_
102-
may prevent arrays on a GPU device from being transferred back to CPU.
103-
This is treated as an implicit transfer.
107+
The :doc:`jax:transfer_guard` may prevent arrays on a GPU device from being
108+
transferred back to CPU. This is treated as an implicit transfer.
104109
105110
PyTorch, CuPy
106111
These backends raise by default if you attempt to convert arrays on a GPU device
107112
to NumPy.
108113
109114
Sparse
110-
By default, sparse prevents implicit densification through ``np.asarray`.
111-
`This safety mechanism can be disabled
115+
By default, sparse prevents implicit densification through
116+
:func:`numpy.asarray`. `This safety mechanism can be disabled
112117
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
113118
114119
Dask
115120
This allows applying eager functions to dask arrays.
116121
The dask graph won't be computed.
117122
118-
`apply_numpy_func` doesn't know if `func` reduces along any axes and shape
123+
`apply_numpy_func` doesn't know if `func` reduces along any axes; also, shape
119124
changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
120125
will be rechunked into a single chunk.
121126
@@ -125,9 +130,19 @@ def apply_numpy_func( # type: ignore[valid-type]
125130
126131
The outputs will also be returned as a single chunk and you should consider
127132
rechunking them into smaller chunks afterwards.
133+
128134
If you want to distribute the calculation across multiple workers, you
129-
should use `dask.array.map_blocks`, `dask.array.blockwise`,
130-
`dask.array.map_overlap`, or a native Dask wrapper instead of this function.
135+
should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`,
136+
:func:`dask.array.blockwise`, or a native Dask wrapper instead of
137+
`apply_numpy_func`.
138+
139+
See Also
140+
--------
141+
jax.transfer_guard
142+
jax.pure_callback
143+
dask.array.map_blocks
144+
dask.array.map_overlap
145+
dask.array.blockwise
131146
"""
132147
if xp is None:
133148
xp = array_namespace(*args)
@@ -239,8 +254,8 @@ def _npfunc_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT
239254
240255
Any keyword arguments are passed through verbatim to the wrapped function.
241256
242-
Raise if np.asarray() raises on any input. This typically happens if the input is
243-
lazy and has a guard against being implicitly turned into a NumPy array (e.g.
257+
Raise if np.asarray raises on any input. This typically happens if the input is lazy
258+
and has a guard against being implicitly turned into a NumPy array (e.g.
244259
densification for sparse arrays, device->host transfer for cupy and torch arrays).
245260
"""
246261

0 commit comments

Comments (0)