diff --git a/docs/api-lazy.md b/docs/api-lazy.md new file mode 100644 index 00000000..150fd62e --- /dev/null +++ b/docs/api-lazy.md @@ -0,0 +1,15 @@ +# Tools for lazy backends + +These additional functions are meant to be used to support compatibility with +lazy backends, e.g. Dask or JAX: + +```{eval-rst} +.. currentmodule:: array_api_extra +.. autosummary:: + :nosignatures: + :toctree: generated + + lazy_apply + testing.lazy_xp_function + testing.patch_lazy_xp_functions +``` diff --git a/docs/conf.py b/docs/conf.py index 79000c96..4696e7a6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,6 +53,8 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "dask": ("https://docs.dask.org/en/stable", None), "jax": ("https://jax.readthedocs.io/en/latest", None), } diff --git a/docs/index.md b/docs/index.md index f7c51574..a5c6d7bf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,7 +5,7 @@ :hidden: self api-reference.md -testing-utils.md +api-lazy.md contributing.md contributors.md ``` diff --git a/docs/testing-utils.md b/docs/testing-utils.md deleted file mode 100644 index 49aeb306..00000000 --- a/docs/testing-utils.md +++ /dev/null @@ -1,14 +0,0 @@ -# Testing Utilities - -These additional functions are meant to be used while unit testing Array API -compliant packages: - -```{eval-rst} -.. currentmodule:: array_api_extra.testing -.. 
autosummary:: - :nosignatures: - :toctree: generated - - lazy_xp_function - patch_lazy_xp_functions -``` diff --git a/pixi.lock b/pixi.lock index 00c8feef..e8da46d4 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1620,16 +1620,22 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/icu-75.1-he02047a_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libblas-3.9.0-26_linux64_mkl.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcblas-3.9.0-26_linux64_mkl.conda - conda: https://prefix.dev/conda-forge/linux-64/libexpat-2.6.4-h5888daf_0.conda - conda: https://prefix.dev/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/libgcc-14.2.0-h77fa898_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libgcc-ng-14.2.0-h69a702a_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/libhwloc-2.11.2-default_h0d58e46_1001.conda + - conda: https://prefix.dev/conda-forge/linux-64/libiconv-1.17-hd590300_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/liblapack-3.9.0-26_linux64_mkl.conda - conda: https://prefix.dev/conda-forge/linux-64/liblzma-5.6.3-hb9d3cd8_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda - conda: 
https://prefix.dev/conda-forge/linux-64/libsqlite-3.48.0-hee588c1_1.conda @@ -1637,6 +1643,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/libstdcxx-ng-14.2.0-h4852527_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda - conda: https://prefix.dev/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/libxml2-2.13.5-h8d12d68_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/llvm-openmp-19.1.7-h024ca30_0.conda - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 @@ -1644,8 +1651,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/markupsafe-3.0.2-py312h178313f_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/mkl-2024.2.2-ha957f24_16.conda - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/ncurses-6.5-h2d0b736_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/numpy-2.0.2-py312h58c1407_1.conda - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.4.0-h7b32b05_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1672,6 +1681,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/tbb-2021.13.0-hceb3a55_1.conda - conda: 
https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda @@ -1711,12 +1721,19 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libblas-3.9.0-26_osxarm64_openblas.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libcblas-3.9.0-26_osxarm64_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-19.1.7-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libffi-3.4.2-h3422bc3_5.tar.bz2 + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-13_2_0_hd922786_3.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-13.2.0-hf226fd6_3.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/liblapack-3.9.0-26_osxarm64_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.6.3-h39f12f2_1.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libopenblas-0.3.28-openmp_hf332438_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libsqlite-3.48.0-h3f77e49_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/llvm-openmp-19.1.7-hdb05f8b_0.conda - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/markupsafe-3.0.2-py312h998013c_1.conda @@ -1724,6 +1741,7 @@ environments: - conda: 
https://prefix.dev/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_2.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/numpy-2.0.2-py312h94ee1e1_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/openssl-3.4.0-h81ee809_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1788,18 +1806,28 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/intel-openmp-2024.2.1-h57928b3_1083.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/win-64/libblas-3.9.0-26_win64_mkl.conda + - conda: https://prefix.dev/conda-forge/win-64/libcblas-3.9.0-26_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda - conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2 + - conda: https://prefix.dev/conda-forge/win-64/libhwloc-2.11.2-default_ha69328c_1001.conda + - conda: https://prefix.dev/conda-forge/win-64/libiconv-1.17-hcfcfb64_2.conda + - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-26_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.48.0-h67fdade_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_9.conda + - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: 
https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py312h31fea79_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/mkl-2024.2.2-h66d3029_15.conda - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/numpy-2.0.2-py312h49bc9c5_1.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-ha4e3fda_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1825,6 +1853,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/tbb-2021.13.0-h62715c5_1.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda @@ -3767,7 +3796,7 @@ packages: - pypi: . 
name: array-api-extra version: 0.6.1.dev0 - sha256: bb6cd89a7f100a73d3f853de571b2f4fff0e70de8df0d113f2f5c1559744e6b6 + sha256: 1e032f707df46a29e306ede97d65b2129e0944b361b96317e5653bd74e695ce2 requires_dist: - array-api-compat>=1.10.0,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index d15aba84..26a73fc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ sphinx-autodoc-typehints = "*" dask-core = "*" pytest = "*" typing-extensions = "*" +numpy = "*" [tool.pixi.feature.docs.tasks] docs = { cmd = "sphinx-build . build/", cwd = "docs" } @@ -311,10 +312,5 @@ checks = [ "ES01", # most docstrings do not need an extended summary ] exclude = [ # don't report on objects that match any of these regex - '.*test_at.*', - '.*test_funcs.*', - '.*test_testing.*', - '.*test_utils.*', - '.*test_version.*', - '.*test_vendor.*', + '.*test_*', ] diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 840dd8e7..aeedd9da 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -12,6 +12,7 @@ setdiff1d, sinc, ) +from ._lib._lazy import lazy_apply __version__ = "0.6.1.dev0" @@ -25,6 +26,7 @@ "expand_dims", "isclose", "kron", + "lazy_apply", "nunique", "pad", "setdiff1d", diff --git a/src/array_api_extra/_lib/_lazy.py b/src/array_api_extra/_lib/_lazy.py new file mode 100644 index 00000000..47a2cc83 --- /dev/null +++ b/src/array_api_extra/_lib/_lazy.py @@ -0,0 +1,368 @@ +"""Public API Functions.""" + +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + +import math +from collections.abc import Callable, Sequence +from functools import partial, wraps +from types import ModuleType +from typing import TYPE_CHECKING, Any, cast, overload + +from ._utils._compat import ( + array_namespace, + is_array_api_obj, + is_dask_namespace, + is_jax_array, + is_jax_namespace, +) +from ._utils._typing import Array, DType + +if 
TYPE_CHECKING: + # TODO move outside TYPE_CHECKING + # depends on scikit-learn abandoning Python 3.9 + # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 + from typing import ParamSpec, TypeAlias + + import numpy as np + from numpy.typing import ArrayLike + + NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit] + P = ParamSpec("P") +else: + # Sphinx hacks + NumPyObject = Any + + class P: # pylint: disable=missing-class-docstring + args: tuple + kwargs: dict + + +@overload +def lazy_apply( # type: ignore[valid-type] + func: Callable[P, ArrayLike], + *args: Array, + shape: tuple[int | None, ...] | None = None, + dtype: DType | None = None, + as_numpy: bool = False, + xp: ModuleType | None = None, + **kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues] +) -> Array: ... # numpydoc ignore=GL08 + + +@overload +def lazy_apply( # type: ignore[valid-type] + func: Callable[P, Sequence[ArrayLike]], + *args: Array, + shape: Sequence[tuple[int | None, ...]], + dtype: Sequence[DType] | None = None, + as_numpy: bool = False, + xp: ModuleType | None = None, + **kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues] +) -> tuple[Array, ...]: ... # numpydoc ignore=GL08 + + +def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04 + func: Callable[P, Array | Sequence[ArrayLike]], + *args: Array, + shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None, + dtype: DType | Sequence[DType] | None = None, + as_numpy: bool = False, + xp: ModuleType | None = None, + **kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues] +) -> Array | tuple[Array, ...]: + """ + Lazily apply an eager function. + + If the backend of the input arrays is lazy, e.g. Dask or jitted JAX, the execution + of the function is delayed until the graph is materialized; if it's eager, the + function is executed immediately. + + Parameters + ---------- + func : callable + The function to apply. 
+ + It must accept one or more array API compliant arrays as positional arguments. + If `as_numpy=True`, inputs are converted to NumPy before they are passed to + `func`. + It must return either a single array-like or a sequence of array-likes. + + `func` must be a pure function, i.e. without side effects, as depending on the + backend it may be executed more than once. + *args : Array + One or more Array API compliant arrays. + + If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them + to convert them to numpy; read notes below about specific backends. + shape : tuple[int | None, ...] | Sequence[tuple[int | None, ...]], optional + Output shape or sequence of output shapes, one for each output of `func`. + Default: assume single output and broadcast shapes of the input arrays. + dtype : DType | Sequence[DType], optional + Output dtype or sequence of output dtypes, one for each output of `func`. + dtype(s) must belong to the same array namespace as the input arrays. + Default: infer the result type(s) from the input arrays. + as_numpy : bool, optional + If True, convert the input arrays to NumPy before passing them to `func`. + This is particularly useful to make numpy-only functions, e.g. written in Cython + or Numba, work transparently with Array API arrays. + Default: False. + xp : array_namespace, optional + The standard-compatible namespace for `args`. Default: infer. + **kwargs : Any, optional + Additional keyword arguments to pass verbatim to `func`. + Any array objects in them will be converted to numpy when ``as_numpy=True``. + + Returns + ------- + Array | tuple[Array, ...] + The result(s) of `func` applied to the input arrays, wrapped in the same + array namespace as the inputs. + If shape is omitted or a `tuple[int | None, ...]`, this is a single array. + Otherwise, it's a tuple of arrays. + + Notes + ----- + JAX + This allows applying eager functions to jitted JAX arrays, which are lazy. 
+ The function won't be applied until the JAX array is materialized. + When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot + contain any `None` elements. + + Using this with `as_numpy=False` is particularly useful to apply non-jittable + JAX functions to arrays on GPU devices. + If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU + device from being transferred back to CPU. This is treated as an implicit + transfer. + + PyTorch, CuPy + If `as_numpy=True`, these backends raise by default if you attempt to convert + arrays on a GPU device to NumPy. + + Sparse + If `as_numpy=True`, by default sparse prevents implicit densification through + :func:`numpy.asarray`. `This safety mechanism can be disabled + <https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_. + + Dask + This allows applying eager functions to dask arrays. + The dask graph won't be computed. + + `lazy_apply` doesn't know if `func` reduces along any axes; also, shape + changes are non-trivial in chunked Dask arrays. For these reasons, all inputs + will be rechunked into a single chunk. + + .. warning:: + + The whole operation needs to fit in memory all at once on a single worker. + + The outputs will also be returned as a single chunk and you should consider + rechunking them into smaller chunks afterwards. + + If you want to distribute the calculation across multiple workers, you + should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`, + :func:`dask.array.blockwise`, or a native Dask wrapper instead of + `lazy_apply`. + + Dask wrapping around other backends + If `as_numpy=False`, `func` will receive as input eager arrays of the meta + namespace, as defined by the `._meta` attribute of the input Dask arrays. + The outputs of `func` will be wrapped by the meta namespace, and then wrapped + again by Dask. 
+ + Raises + ------ + jax.errors.TracerArrayConversionError + When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes) + and this function was called inside `jax.jit`. + RuntimeError + When `xp=sparse` and auto-densification is disabled. + Exception (backend-specific) + When the backend disallows implicit device to host transfers and the input + arrays are on a device, e.g. on GPU. + + See Also + -------- + jax.transfer_guard + jax.pure_callback + dask.array.map_blocks + dask.array.map_overlap + dask.array.blockwise + """ + if xp is None: + xp = array_namespace(*args) + + # Normalize and validate shape and dtype + shapes: list[tuple[int | None, ...]] + dtypes: list[DType] + multi_output = False + + if shape is None: + shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))] + elif isinstance(shape, tuple) and all(isinstance(s, int | None) for s in shape): + shapes = [shape] # pyright: ignore[reportAssignmentType] + else: + shapes = list(shape) # type: ignore[arg-type] # pyright: ignore[reportAssignmentType] + multi_output = True + + if dtype is None: + dtypes = [xp.result_type(*args)] * len(shapes) + elif multi_output: + if not isinstance(dtype, Sequence): + msg = "Got sequence of shapes but only one dtype" + raise TypeError(msg) + dtypes = list(dtype) # pyright: ignore[reportUnknownArgumentType] + else: + if isinstance(dtype, Sequence): + msg = "Got single shape but multiple dtypes" + raise TypeError(msg) + dtypes = [dtype] + + if len(shapes) != len(dtypes): + msg = f"Got {len(shapes)} shapes and {len(dtypes)} dtypes" + raise ValueError(msg) + if len(shapes) == 0: + msg = "func must return one or more output arrays" + raise ValueError(msg) + del shape + del dtype + + # Backend-specific branches + if is_dask_namespace(xp): + import dask + + metas = [arg._meta for arg in args if hasattr(arg, "_meta")] # pylint: disable=protected-access + meta_xp = array_namespace(*metas) + + wrapped = dask.delayed( # type: ignore[attr-defined] # pyright: 
ignore[reportPrivateImportUsage] + _lazy_apply_wrapper(func, as_numpy, multi_output, meta_xp), + pure=True, + ) + # This finalizes each arg, which is the same as arg.rechunk(-1). + # Please read docstring above for why we're not using + # dask.array.map_blocks or dask.array.blockwise! + delayed_out = wrapped(*args, **kwargs) + + out = tuple( + xp.from_delayed( + delayed_out[i], # pyright: ignore[reportIndexIssue] + # Dask's unknown shapes diverge from the Array API specification + shape=tuple(math.nan if s is None else s for s in shape), + dtype=dtype, + meta=metas[0], + ) + for i, (shape, dtype) in enumerate(zip(shapes, dtypes, strict=True)) + ) + + elif is_jax_namespace(xp): + # If we're inside jax.jit, we can't eagerly convert + # the JAX tracer objects to numpy. + # Instead, we delay calling wrapped, which will receive + # as arguments and will return JAX eager arrays. + + import jax + + # Shield eager kwargs from being coerced into JAX arrays. + # jax.pure_callback calls jax.jit under the hood, but without the chance of + # passing static_argnames / static_argnums. + lazy_kwargs = {} + eager_kwargs = {} + for k, v in kwargs.items(): + if _contains_jax_arrays(v): + lazy_kwargs[k] = v + else: + eager_kwargs[k] = v + + wrapped = _lazy_apply_wrapper( + partial(func, **eager_kwargs), as_numpy, multi_output, xp + ) + + if any(s is None for shape in shapes for s in shape): + # Unknown output shape. Won't work with jax.jit, but it + # can work with eager jax. + # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit. 
+ out = wrapped(*args, **lazy_kwargs) + + else: + # suppress unused-ignore to run mypy in -e lint as well as -e dev + out = cast( # type: ignore[bad-cast,unused-ignore] + tuple[Array, ...], + jax.pure_callback( + wrapped, + tuple( + jax.ShapeDtypeStruct(shape, dtype) # pyright: ignore[reportUnknownArgumentType] + for shape, dtype in zip(shapes, dtypes, strict=True) + ), + *args, + **lazy_kwargs, + ), + ) + + else: + # Eager backends + wrapped = _lazy_apply_wrapper(func, as_numpy, multi_output, xp) + out = wrapped(*args, **kwargs) + + return out if multi_output else out[0] + + +def _contains_jax_arrays(x: object) -> bool: # numpydoc ignore=PR01,RT01 + """ + Test if x is a JAX array or a nested collection with any JAX arrays in it. + """ + if is_jax_array(x): + return True + if isinstance(x, list | tuple): + return any(_contains_jax_arrays(i) for i in x) # pyright: ignore[reportUnknownArgumentType] + if isinstance(x, dict): + return any(_contains_jax_arrays(i) for i in x.values()) # pyright: ignore[reportUnknownArgumentType] + return False + + +def _as_numpy(x: object) -> Any: # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 + """Recursively convert Array API objects in x to NumPy.""" + import numpy as np # pylint: disable=import-outside-toplevel + + if is_array_api_obj(x): + return np.asarray(x) + if isinstance(x, list) or type(x) is tuple: # pylint: disable=unidiomatic-typecheck + return type(x)(_as_numpy(i) for i in x) # pyright: ignore[reportUnknownArgumentType] + if isinstance(x, tuple): # namedtuple + return type(x)(*(_as_numpy(i) for i in x)) # pyright: ignore[reportUnknownArgumentType] + if isinstance(x, dict): + return {k: _as_numpy(v) for k, v in x.items()} # pyright: ignore[reportUnknownArgumentType] + return x + + +def _lazy_apply_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 + func: Callable[..., ArrayLike | Sequence[ArrayLike]], + as_numpy: bool, + multi_output: bool, + xp: ModuleType, +) -> Callable[..., tuple[Array, 
...]]: + """ + Helper of `lazy_apply`. + + Given a function that accepts one or more arrays as positional arguments and returns + a single array-like or a sequence of array-likes, return a function that accepts the + same number of Array API arrays and always returns a tuple of Array API array. + + Any keyword arguments are passed through verbatim to the wrapped function. + """ + + # On Dask, @wraps causes the graph key to contain the wrapped function's name + @wraps(func) + def wrapper( # type: ignore[no-any-decorated,no-any-explicit] + *args: Array, **kwargs: Any + ) -> tuple[Array, ...]: # numpydoc ignore=GL08 + if as_numpy: + args = _as_numpy(args) + kwargs = _as_numpy(kwargs) + out = func(*args, **kwargs) + + if multi_output: + assert isinstance(out, Sequence) + return tuple(xp.asarray(o) for o in out) + return (xp.asarray(out),) + + return wrapper diff --git a/src/array_api_extra/_lib/_utils/_typing.py b/src/array_api_extra/_lib/_utils/_typing.py index 83b51d04..95f29f79 100644 --- a/src/array_api_extra/_lib/_utils/_typing.py +++ b/src/array_api_extra/_lib/_utils/_typing.py @@ -5,6 +5,7 @@ # To be changed to a Protocol later (see data-apis/array-api#589) Array = Any # type: ignore[no-any-explicit] Device = Any # type: ignore[no-any-explicit] +DType = Any # type: ignore[no-any-explicit] Index = Any # type: ignore[no-any-explicit] -__all__ = ["Array", "Device", "Index"] +__all__ = ["Array", "DType", "Device", "Index"] diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index cc3f01f8..f0e0d7c1 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -132,12 +132,12 @@ def test_myfunc(xp): a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c = mymodule.myfunc(a) # This is not """ - func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] - if jax_jit: - func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: 
ignore[reportFunctionMemberAccess] - "static_argnums": static_argnums, - "static_argnames": static_argnames, - } + func.lazy_xp_function = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] + "allow_dask_compute": allow_dask_compute, + "jax_jit": jax_jit, + "static_argnums": static_argnums, + "static_argnames": static_argnames, + } def patch_lazy_xp_functions( @@ -181,10 +181,13 @@ def xp(request, monkeypatch): if is_dask_namespace(xp): for name, func in globals_.items(): - n = getattr(func, "allow_dask_compute", None) - if n is not None: + kwargs = cast( # type: ignore[no-any-explicit] + "dict[str, Any] | None", getattr(func, "lazy_xp_function", None) + ) + if kwargs is not None: + n = kwargs["allow_dask_compute"] assert isinstance(n, int) - wrapped = _allow_dask_compute(func, n) + wrapped = _dask_wrap(func, n) monkeypatch.setitem(globals_, name, wrapped) elif is_jax_namespace(xp): @@ -192,12 +195,16 @@ def xp(request, monkeypatch): for name, func in globals_.items(): kwargs = cast( # type: ignore[no-any-explicit] - "dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None) + "dict[str, Any] | None", getattr(func, "lazy_xp_function", None) ) - if kwargs is not None: + if kwargs is not None and kwargs["jax_jit"]: # suppress unused-ignore to run mypy in -e lint as well as -e dev - wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore] - monkeypatch.setitem(globals_, name, wrapped) + wrapped = jax.jit( # type: ignore[no-untyped-call,unused-ignore] + func, + static_argnums=kwargs["static_argnums"], + static_argnames=kwargs["static_argnames"], + ) + monkeypatch.setitem(globals_, name, wrapped) # pyright: ignore[reportUnknownArgumentType] class CountingDaskScheduler(SchedulerGetCallable): @@ -236,13 +243,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] 
# pyright: ignore[reportPrivateImportUsage] -def _allow_dask_compute( +def _dask_wrap( func: Callable[P, T], n: int ) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 """ Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times. + + After the function returns, materialize the graph in order to re-raise exceptions. """ - import dask.config + import dask func_name = getattr(func, "__name__", str(func)) n_str = f"only up to {n}" if n else "no" @@ -256,7 +265,12 @@ def _allow_dask_compute( @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 scheduler = CountingDaskScheduler(n, msg) - with dask.config.set({"scheduler": scheduler}): - return func(*args, **kwargs) + with dask.config.set({"scheduler": scheduler}): # pyright: ignore[reportPrivateImportUsage] + out = func(*args, **kwargs) + + # Block until the graph materializes and reraise exceptions. This allows + # `pytest.raises` and `pytest.warns` to work as expected. Note that this would + # not work on scheduler='distributed', as it would not block. 
+ return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] return wrapper diff --git a/tests/test_lazy.py b/tests/test_lazy.py new file mode 100644 index 00000000..a01e31c7 --- /dev/null +++ b/tests/test_lazy.py @@ -0,0 +1,82 @@ +from types import ModuleType +from typing import NamedTuple + +import numpy as np +import pytest + +from array_api_extra import lazy_apply +from array_api_extra._lib import Backend +from array_api_extra._lib._testing import xp_assert_equal +from array_api_extra._lib._utils._typing import Array +from array_api_extra.testing import lazy_xp_function + +skip_as_numpy = [ + pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host transfer"), + pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"), +] + + +@pytest.mark.parametrize("as_numpy", [False, pytest.param(True, marks=skip_as_numpy)]) +def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool) -> None: + expect = np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0)) + + class NT(NamedTuple): + a: Array + + def f( + x: Array, + z: dict[str, list[Array] | tuple[Array, ...] 
| NT], + msg: str, + msgs: list[str], + ) -> Array: + assert isinstance(x, expect) + assert isinstance(z["foo"], NT) + assert isinstance(z["foo"].a, expect) + assert isinstance(z["bar"][0], expect) + assert isinstance(z["baz"][0], expect) + assert msg == "Hello World" + assert msgs[0] == "Hello World" + return x + + x = xp.asarray(0) + y = lazy_apply( # pyright: ignore[reportCallIssue] + f, + x, + z={"foo": NT(x), "bar": [x], "baz": (x,)}, + msg="Hello World", + msgs=["Hello World"], + shape=x.shape, + dtype=x.dtype, + as_numpy=as_numpy, + ) + xp_assert_equal(x, y) + + +class CustomError(Exception): + pass + + +def raises(x: Array) -> Array: + def eager(_: Array) -> Array: + msg = "Hello World" + raise CustomError(msg) + + return lazy_apply(eager, x, shape=x.shape, dtype=x.dtype) + + +lazy_xp_function(raises) + + +def test_lazy_apply_raises(xp: ModuleType, library: Backend) -> None: + x = xp.asarray(0) + + with pytest.raises( + # FIXME https://github.com/jax-ml/jax/issues/26102 + RuntimeError if library is Backend.JAX else CustomError, + match="Hello World", + ): + # Here we are disregarding the return value, which would + # normally cause the graph not to materialize and the + # exception not to be raised. + # However, lazy_xp_function will do it for us on function exit. + raises(x)