
ENH: lazy_apply #86

Open · crusaderky wants to merge 15 commits into main from the apply branch
Conversation

@crusaderky (Contributor) commented Jan 9, 2025

First draft of a wrapper around jax.pure_callback.
Untested.
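
For reference, a minimal sketch of the underlying mechanism, assuming a single array input and a known output shape/dtype (apply_numpy_func and its parameters are illustrative names, not this PR's API):

import numpy as np
import jax
import jax.numpy as jnp

def apply_numpy_func(func, x, *, shape, dtype):
    # Illustrative sketch only: run a NumPy-only `func` on a JAX array,
    # even inside jax.jit, by delegating to jax.pure_callback.
    out_spec = jax.ShapeDtypeStruct(shape, dtype)
    # pure_callback materializes `x` as a NumPy array at run time, calls
    # `func` eagerly, and feeds the result back into the traced graph.
    return jax.pure_callback(lambda a: np.asarray(func(a)), out_spec, x)

@jax.jit
def f(x):
    return apply_numpy_func(np.sort, x, shape=x.shape, dtype=x.dtype)

f(jnp.asarray([3.0, 1.0, 2.0]))  # Array([1., 2., 3.])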

The enormous amount of extra complication needed to make it work with Dask makes me uncomfortable.
@lucascolley you've been working on Dask support for Scipy; what's your vision for it?

CC @rgommers

Comment on lines 121 to 122
The dask graph won't be computed. As a special limitation, `func` must return
exactly one output.
@crusaderky (Contributor Author) Jan 9, 2025

This limitation is straightforward to fix in Dask (at the cost of API duplication).
Until then, however, I suspect it will be a major roadblock for Dask support in scipy.

It can also be hacked outside of Dask, but I'm hesitant to do that for the sake of robustness, as it would rely on deliberately triggering key collisions between diverging graph branches.


Comment on lines 130 to 132
- ``output_indices[0]`` maps to the ``out_ind`` parameter
- ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
- ``new_axes[0]`` maps to the ``new_axes`` parameter
@crusaderky (Contributor Author)

These are all lists for forward compatibility with a yet-to-be-written da.blockwise variant that supports multiple outputs.

Comment on lines 57 to 58
If `func` returns a single (non-sequence) output, this must be a sequence
with a single element.
@crusaderky (Contributor Author)

I tried overloading but it was a big headache when validating inputs. I found this approach much simpler.

@lucascolley (Member) left a comment

I think I am missing some context here. Why are we wrapping arbitrary NumPy functions, instead of, e.g., considering the individual functions we need one by one?

@crusaderky (Contributor Author) commented Jan 9, 2025

I think I am missing some context here. Why are we wrapping arbitrary NumPy functions, instead of, e.g., considering the individual functions we need one by one?

There are many points in scipy that look like this:

x = np.asarray(x)
y = np.asarray(y)
z = some_cython_kernel(x, y)
z = xp.asarray(z)

None of them will ever work with arrays on GPU devices, of course, and they'll either need a pure array API alternative or a dispatch to cupyx.scipy etc.

None of them work with jitted, CPU-based JAX either, because jax.jit doesn't support abrupt materialization of the graph on np.asarray in the middle of jitting. Notably, this differs from torch.compile, which does exactly that.

With Dask, they technically work because there is no materialization guard, but most of the time you would prefer it if there were one.
In the best case, they will be exceptionally slow to run once you reach production-sized data: if there are multiple inputs, chances are that large parts of the graph will be computed multiple times, and they trigger massive transfers between the client and the workers, which are very likely to kill off the client and the scheduler too (at least in the default configuration direct_to_workers=False, where all client<->worker traffic transits through the scheduler).

However, it is possible to make these pieces of Cython code work thanks to this PR.
For JAX, this is straightforward - all you need to know is the output shape(s) and dtype(s).
For Dask, it is, well, the very opposite of straightforward. You can run arbitrary Cython kernels in Dask, on the workers, but with the very big caveat that any axis they reduce over can't be chunked, or you'll get incorrect results, as explained in the docstring. Additionally, in order to function, Dask needs to know how each axis of the input maps to each axis of the output and, if the size along an axis changes in any way, how that translates to chunk sizes.

There are two competing functions in Dask to achieve this, with different APIs (see the sketch below):

  • da.blockwise requires spelling out, via index annotations, how each axis of every input maps to each axis of the output.
  • map_blocks is a variant of blockwise with a simplified API, which can only work on broadcastable inputs.
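
As an illustration of the index bookkeeping involved (plain dask.array usage, not this PR's code; the kernel is made up):

import dask.array as da

x = da.random.random((8, 6), chunks=(4, 6))  # axis "j" is a single chunk

def kernel(block):
    # Stand-in for an arbitrary NumPy/Cython kernel reducing the last axis.
    return block.sum(axis=-1)

# blockwise must be told how input axes map to output axes: "ij" -> "i".
# The kernel consumes axis "j" whole, so "j" must reach it unchunked
# (here via a single chunk plus concatenate=True).
y = da.blockwise(kernel, "i", x, "ij", dtype=x.dtype, concatenate=True)
print(y.compute().shape)  # (8,)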

This problem has already been dealt with by xarray, with https://docs.xarray.dev/en/latest/generated/xarray.apply_ufunc.html. Note that xarray API is more user-friendly thanks to each dimension being labelled at all times, so apply_ufunc can do fancy tricks like auto-transposing the inputs and pushing the dimensions that func doesn't know about to the left.
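
For comparison, a minimal apply_ufunc sketch of the same kind of reduction (the dimension names and data here are made up):

import numpy as np
import xarray as xr

arr = xr.DataArray(
    np.random.rand(4, 6), dims=("sample", "feature")
).chunk({"sample": 2})  # chunked along "sample" only

out = xr.apply_ufunc(
    lambda a: a.sum(axis=-1),       # plain NumPy callable
    arr,
    input_core_dims=[["feature"]],  # core dim: handed whole to the callable, must not be chunked
    dask="parallelized",
    output_dtypes=[arr.dtype],
)
print(out.compute().shape)  # (4,)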

What I tried to implement here is equivalent to xarray.apply_ufunc(..., dask="parallelized"), which under the hood calls da.blockwise.
dask="allowed" makes no sense without an additional wrapper around the numpy-like API of dask.

@lucascolley (Member)

Okay, thanks. When you say

any Array API compliant arrays

in the docstring, that isn't strictly true, right? It relies on np.asarray working on xp arrays and xp.asarray working on np arrays, if I have read the code correctly.

I understand the utility of this PR now.

what's your vision for it?

I hadn't envisioned tackling this yet. In my mind, getting Cython kernels working with Dask / JAX jit has been in the same "for later" basket as handling device transfers or writing new implementations to delegate to. But if the implementation works, it makes sense to tackle it.

@crusaderky (Contributor Author)

Okay, thanks. When you say

any Array API compliant arrays

in the docstring, that isn't strictly true, right? It relies on np.asarray working on xp arrays and xp.asarray working on np arrays, if I have read the code correctly.

Correct. Nominally it will fail when densifying sparse arrays and when moving data from GPU to CPU. An end user can, however, force their way through if they want to, by deliberately suppressing the transfer/densification guards for the time necessary to run the scipy function. Either that, or perform an explicit device-to-CPU transfer / to_dense() call ahead of the function.

There is, however, nothing a JAX or Dask user can do today, short of completely exiting the graph-generation phase.

@rgommers (Member) left a comment

Thanks @crusaderky!

The enormous amount of extra complication needed to make it work with Dask makes me uncomfortable.

Yes indeed, it does. That looks like way too much. I don't think we would want to use all those extra keywords and Dask-specific functions in SciPy. If you dropped those, would it make Dask completely non-working, or is there a subset of functionality that would still work? I'd say that JAX shows that it can be straightforward, and a similar callback mechanism could be used for PyTorch/MLX/ndonnx as well - if it were to exist in those libraries.

Sparse
By default, sparse prevents implicit densification through ``np.asarray``.
`This safety mechanism can be disabled
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
@rgommers (Member)

Fine to leave as is for now I'd say. Once sparse adds a better API for this (an env var doesn't work), it seems reasonable to add a force=False option to this function. There are various reasons why one may want to force an expensive conversion; that kind of thing should always be opt-in on a case-by-case basis.

@crusaderky (Contributor Author) Jan 14, 2025

Shouldn't this be tackled by a more general design pattern?

with disable_guards():
    y = scipy.somefunc(x)

where disable_guards is backend-specific.
Applies to torch/cupy/jax arrays on GPU, sparse arrays, etc.

@rgommers (Member)

Do you have a SciPy branch with this function being used @crusaderky? I'd be interested in playing with it.

@crusaderky (Contributor Author)

If you dropped those, would it make Dask completely non-working, or is there a subset of functionality that would still work?

I could make it work by rechunking all the inputs to a single chunk. In other words, the whole calculation would need to fit in memory at once on a single worker.
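
A minimal sketch of what that fallback could look like for a single array input (apply_numpy_func and its keyword names are illustrative, not this PR's API):

import numpy as np
import dask.array as da

def apply_numpy_func(func, x, *, shape, dtype):
    # Illustrative sketch only: collapse the input to a single chunk so that
    # `func` sees the whole array at once, on a single worker.
    x = x.rechunk(x.shape)
    # One input block -> one output block spanning the whole output shape.
    return x.map_blocks(func, chunks=shape, dtype=dtype)

x = da.arange(10, chunks=3)
y = apply_numpy_func(np.sort, x, shape=(10,), dtype=x.dtype)
y.compute()  # array([0, 1, ..., 9])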

@crusaderky (Contributor Author)

Do you have a SciPy branch with this function being used @crusaderky? I'd be interested in playing with it.

Not yet

@crusaderky (Contributor Author)

I could make it work by rechunking all the inputs to a single chunk. In other words, the whole calculation would need to fit in memory at once on a single worker.

@rgommers I rewrote it to do exactly this and now it's a lot cleaner. I'll keep my eyes open for patterns in scipy that we can leverage to improve Dask support (e.g. frequent elementwise functions that could be trivially served by map_blocks).

@lucascolley marked this pull request as draft on January 13, 2025
@lucascolley added the enhancement and new function labels on Jan 13, 2025
@lucascolley (Member)

@allcontributors, please add @crusaderky for bug

let me just try this once from this PR...

@rgommers (Member) left a comment

I haven't tried to test this, but I did go through it in a bit more detail now - overall it looks good, a few comments. Looking forward to trying it out!

@crusaderky (Contributor Author)

I've reworked the design a bit.

  • Renamed the function to apply_lazy and added an as_numpy=False optional parameter. This allows
    • array API compliant eager functions to be wrapped by Dask and applied to Dask arrays with a non-NumPy _meta, and
    • eager-only JAX operations (e.g. with an output size that's not predictable) to be executed on lazy arrays, which is particularly beneficial on GPU.
  • Added support for unknown output size in Dask and eager JAX. This enables Dask support for functions such as scipy.cluster.leaders.

FYI, when I move _lazywhere from scipy, I intend to call it apply_where, which I think makes a good pair with apply_lazy, as it conveys that they're both about applying a callback.

if any(s is None for shape in shapes for s in shape):
    # Unknown output shape. Won't work with jax.jit, but it
    # can work with eager jax.
    # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
@crusaderky (Contributor Author) Jan 15, 2025

Offline conversation:

@crusaderky:

how do you see scipy functions with unknown output size working with jax.jit? e.g. scipy.cluster.leaders? Should we do as in jnp.unique_all and add a size=None optional parameter, which becomes mandatory when the jit is on?

@rgommers:

I'm not sure that we should add support for those functions. My assumption is that it's only a few functions, and that those are inherently pretty clunky with JAX. I don't really want to think about extending the signatures (yet at least), because the current jax.jit support is experimental and behind a flag, and adding keywords is public and not reversible.
Perhaps making a note on the tracking issue about this being an option, but not done because of the reason above (could be done in the future, if JAX usage takes off)?

If in the future we want to support these functions, we'll have to modify this point to catch jax.errors.TracerArrayConversionError and reraise a backend-agnostic exception, so that scipy.cluster.leaders and similar functions can then catch it and raise an informative error message about size= being mandatory.
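
A rough sketch of that reraise pattern, with a hypothetical helper name and error message (not part of this PR):

import jax
import array_api_extra as xpx

def checked_lazy_apply(func, x, *, shape, dtype):
    # Hypothetical wrapper: translate JAX's tracer error into a
    # backend-agnostic exception about a data-dependent output shape.
    try:
        return xpx.lazy_apply(func, x, shape=shape, dtype=dtype)
    except jax.errors.TracerArrayConversionError as exc:
        raise ValueError(
            "output shape is data-dependent; pass size= explicitly "
            "when calling this function inside jax.jit"
        ) from exc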

@crusaderky (Contributor Author) Jan 22, 2025

scipy.cluster.leaders is a function which, in the near future, will work in eager JAX but can't work in jax.jit short of a public API change, because its output arrays' shape is xp.unique_values(input).shape.

@pearu @vfdev-5 a while ago you asked offline how we can run a function such as this inside jax.jit.
It will be possible for an end user to do so, on the condition that they consume its output and reduce it back to a known shape, for example:

import jax
import array_api_extra as xpx
from array_api_compat import array_namespace
from scipy.cluster.hierarchy import leaders

def _eager(x):
    a, b = leaders(x)  # shapes = (None, ), (None, )
    xp = array_namespace(a, b)
    # silly example; probably won't make sense functionally
    return xp.max(a), xp.max(b)

# This is just an example that makes little sense;
# in practice @jax.jit will be much higher in the call stack
@jax.jit
def f(x):
    return xpx.lazy_apply(
        _eager, x, shape=((), ()), dtype=(x.dtype, x.dtype)
    )

Comment on lines 169 to 171
jax.errors.TracerArrayConversionError
When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
and this function was called inside `jax.jit`.
@crusaderky (Contributor Author)

See the comments above on unknown output shapes.

@lucascolley changed the title from WIP apply_numpy_func to WIP apply_lazy on Jan 15, 2025
@crusaderky (Contributor Author) commented Jan 16, 2025

I think I want to have some evidence that the whole thing works in practice before I finalize this PR.
See scipy/scipy#22342

@crusaderky force-pushed the apply branch 4 times, most recently from 794a35a to 2d007e0 on January 28, 2025 16:09
report.exclude_also = [
    '\.\.\.',
    'if TYPE_CHECKING:',
]
@crusaderky (Contributor Author)

This was causing coverage to skip the whole _lazy.py for some reason

Comment on lines +277 to +278
# jax.pure_callback calls jax.jit under the hood, but without the chance of
# passing static_argnames / static_argnums.
@crusaderky (Contributor Author)

This was a very unpleasant discovery.

Comment on lines +283 to +285
    lazy_kwargs[k] = v
else:
    eager_kwargs[k] = v
@crusaderky (Contributor Author) Feb 6, 2025

The original design simply stated that there cannot be arrays in the kwargs.
However, I later found out that there is severely obfuscated code in scipy that breaks this assumption.
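
For context, a minimal sketch of the idea behind the snippet above (the helper name and the use of array_api_compat.is_array_api_obj are my own assumptions, not necessarily what the PR does):

from array_api_compat import is_array_api_obj

def split_kwargs(kwargs):
    # Array-valued kwargs must flow through the lazy machinery alongside the
    # positional arrays; everything else is passed through to func eagerly.
    lazy_kwargs, eager_kwargs = {}, {}
    for k, v in kwargs.items():
        if is_array_api_obj(v):
            lazy_kwargs[k] = v
        else:
            eager_kwargs[k] = v
    return lazy_kwargs, eager_kwargs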

Comment on lines +246 to +247
metas = [arg._meta for arg in args if hasattr(arg, "_meta")] # pylint: disable=protected-access
meta_xp = array_namespace(*metas)
@crusaderky (Contributor Author)

FYI, not being able to infer the meta for argument-less generator functions is what made me ditch support for them. Alternatively, I could have added a meta_xp parameter, but I really disliked it as it would only be meaningful for Dask.

# Block until the graph materializes and reraise exceptions. This allows
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
# not work on scheduler='distributed', as it would not block.
return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
@crusaderky (Contributor Author)

This is the important change in this module; the rest is cosmetic.

@crusaderky changed the title from WIP lazy_apply to ENH: lazy_apply on Feb 6, 2025
@crusaderky marked this pull request as ready for review on February 6, 2025
@crusaderky (Contributor Author)

This is ready for final review and merge.
