ENH: lazy_apply
#86
base: main
Conversation
src/array_api_extra/_apply.py
Outdated
``core_indices`` is a safety measure to prevent incorrect results on
Dask along chunked axes. Consider this::
This design was informed by https://docs.xarray.dev/en/latest/generated/xarray.apply_ufunc.html
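For context, a minimal sketch of the `xarray.apply_ufunc` pattern that informed this (the kernel and dimension names are invented for illustration):

```python
import numpy as np
import xarray as xr

def last_minus_first(a):
    # NumPy-only kernel acting on the trailing ("core") axis
    return a[..., -1] - a[..., 0]

arr = xr.DataArray(np.random.rand(4, 10), dims=["x", "time"])
# Declaring "time" as a core dimension lets apply_ufunc transpose it to the
# end before calling the kernel and drop it from the output dims.
out = xr.apply_ufunc(last_minus_first, arr, input_core_dims=[["time"]])
```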
src/array_api_extra/_apply.py
Outdated
The dask graph won't be computed. As a special limitation, `func` must return
exactly one output.
This limitation is straightforward to fix in Dask (at the cost of API duplication).
Until then, however, I suspect it will be a major roadblock for Dask support in scipy.
It can also be hacked outside of Dask, but I'm hesitant to do that for the sake of robustness, as it would rely on deliberately triggering key collisions between diverging graph branches.
src/array_api_extra/_apply.py
Outdated
`input_indices`, `output_indices`, and `core_indices`, but you may also need
`adjust_chunks` and `new_axes` depending on the function.

Read `dask.array.blockwise`:
src/array_api_extra/_apply.py
Outdated
- ``output_indices[0]`` maps to the ``out_ind`` parameter
- ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
- ``new_axes[0]`` maps to the ``new_axes`` parameter
These are all lists for forward-compatibility with a yet-to-be-written `da.blockwise` variant that supports multiple outputs.
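For readers who have not used `dask.array.blockwise`, a small sketch of what `out_ind`, `adjust_chunks`, and `new_axes` control (examples adapted from typical Dask usage, not taken from this PR):

```python
import dask.array as da
import numpy as np

x = da.random.random(10, chunks=5)
y = da.random.random(8, chunks=4)

# out_ind="ij", with x indexed by "i" and y by "j": every output block (i, j)
# is built from block i of x and block j of y (an outer product).
outer = da.blockwise(np.outer, "ij", x, "i", y, "j", dtype="f8")

# adjust_chunks: the kernel changes the size of axis "i" (here it doubles it),
# so Dask must be told how each chunk size changes.
doubled = da.blockwise(
    lambda b: np.concatenate([b, b]), "i", x, "i",
    adjust_chunks={"i": lambda n: 2 * n}, dtype=x.dtype,
)

# new_axes: the kernel adds an axis "j" of fixed size 5 that no input has.
tiled = da.blockwise(
    lambda b: np.tile(b[:, None], (1, 5)), "ij", x, "i",
    new_axes={"j": 5}, dtype=x.dtype,
)
```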
src/array_api_extra/_apply.py
Outdated
If `func` returns a single (non-sequence) output, this must be a sequence
with a single element.
I tried overloading but it was a big headache when validating inputs. I found this approach much simpler.
I think I am missing some context here. Why are we wrapping arbitrary NumPy functions, instead of, e.g., considering the individual functions we need one by one?
There are many points in scipy that look like this:

```python
x = np.asarray(x)
y = np.asarray(y)
z = some_cython_kernel(x, y)
z = xp.asarray(z)
```

None of them will ever work with arrays on GPU devices, of course, and they'll either need a pure array-API alternative or a dispatch to cupy.scipy etc. None of them work with jitted, CPU-based JAX either, because calling `np.asarray` on a traced array raises `jax.errors.TracerArrayConversionError`. With Dask, they technically work because there is no materialization guard, but most times you would prefer it if there was one. However, it is possible to make these pieces of Cython code work thanks to this PR. There are two competing functions in Dask to achieve this, with different APIs:
`map_blocks` is a variant of `blockwise` with a simplified API, which can only work on broadcastable inputs. This problem has already been dealt with by xarray, with https://docs.xarray.dev/en/latest/generated/xarray.apply_ufunc.html. Note that the xarray API is more user-friendly thanks to each dimension being labelled at all times, so apply_ufunc can do fancy tricks like auto-transposing the inputs and pushing the dimensions that func doesn't know about to the left. What I tried to implement here is equivalent to
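To make the comparison concrete, here is a hedged sketch of what routing such a snippet through `lazy_apply` could look like, using the `shape=`/`dtype=` keywords shown in the example further down this thread (the kernel is a stand-in, and the exact spelling of `shape`/`dtype` for a single output is assumed):

```python
import numpy as np
import array_api_extra as xpx

def _eager_kernel(x, y):
    # Stand-in for some_cython_kernel: by the time this runs, lazy_apply has
    # materialized x and y into real NumPy arrays.
    return np.asarray(x) + np.asarray(y)

def public_scipy_func(x, y):
    # x and y may be Dask or (jitted) JAX arrays; the NumPy-only kernel runs
    # inside the task graph / pure callback instead of at graph-construction time.
    return xpx.lazy_apply(_eager_kernel, x, y, shape=x.shape, dtype=x.dtype)
```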
Okay, thanks. When you say

in the docstring, that isn't strictly true, right? It relies on

I understand the utility of this PR now.
I hadn't envisioned tackling this yet. In my mind, getting Cython kernels working with dask/jax jit has been in the same "for later" basket as handling device transfers or writing new implementations to delegate to. But if the implementation works, it makes sense to tackle it.
Correct. Nominally it will fail when densifying sparse arrays and moving data from GPU to CPU. A final user can however force their way through, if they want to, by deliberately suppressing transfer/densification guards for the time necessary to run the scipy function. Either that, or do an explicit device-to-CPU transfer / densification beforehand.

There is nothing, however, that a JAX or Dask user can do today, short of completely getting out of the graph generation phase.
Thanks @crusaderky!
> The enormous amount of extra complication needed to make it work with Dask makes me uncomfortable.
Yes indeed, it does. That looks like it's way too much. I don't think we would like to use all those extra keywords and Dask-specific functions in SciPy. If you'd drop those, does it make Dask completely non-working, or is there a subset of functionality that would still work? I'd say that JAX shows that it can be straightforward, and a similar callback mechanism could be used for PyTorch/MLX/ndonnx as well - if that were to exist in those libraries.
src/array_api_extra/_apply.py
Outdated
Sparse
By default, sparse prevents implicit densification through ``np.asarray``.
`This safety mechanism can be disabled
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
Fine to leave as is for now, I'd say. Once `sparse` adds a better API for this (an env var doesn't work), it seems reasonable to add a `force=False` option to this function. There are various reasons why one may want to force an expensive conversion; that kind of thing should always be opt-in on a case-by-case basis.
Shouldn't this be tackled by a more general design pattern?

```python
with disable_guards():
    y = scipy.somefunc(x)
```

where `disable_guards` is backend-specific. Applies to torch/cupy/jax arrays on GPU, sparse arrays, etc.
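A purely hypothetical sketch of the pattern being proposed here; none of these helpers exist in any library today, and the registry and names are invented for illustration. One way to make it backend-specific is to pass the arrays in so the namespace can be inferred (a deviation from the zero-argument call above):

```python
import contextlib
from array_api_compat import array_namespace

# Hypothetical registry mapping a namespace name to a context manager factory
# that relaxes that backend's transfer/densification guards.
_GUARD_OVERRIDES = {}

@contextlib.contextmanager
def disable_guards(*arrays):
    xp = array_namespace(*arrays)
    factory = _GUARD_OVERRIDES.get(xp.__name__, contextlib.nullcontext)
    with factory():
        yield

# usage, mirroring the comment above:
# with disable_guards(x):
#     y = scipy.somefunc(x)
```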
Do you have a SciPy branch with this function being used, @crusaderky? I'd be interested in playing with it.
I could make it work by rechunking all the inputs to a single chunk. In other words, the whole calculation would need to fit in memory at once on a single worker.
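A minimal sketch of that fallback, assuming a single array argument and the `shape=`/`dtype=` metadata that appears later in this thread (helper name invented):

```python
import dask.array as da

def fallback_apply(func, x, shape, dtype):
    # Rechunk to a single chunk: the whole computation must fit in memory on
    # one worker, but arbitrary NumPy-only `func` then runs lazily instead of
    # being materialized at graph-construction time.
    x = x.rechunk({axis: -1 for axis in range(x.ndim)})
    # Assumes func keeps the number of dimensions; otherwise new_axis/drop_axis
    # would also be needed.
    return x.map_blocks(func, chunks=shape, dtype=dtype)
```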
Not yet.
@rgommers I rewrote it to do exactly this and now it's a lot cleaner. I'll keep my eyes open for patterns in scipy we can leverage to improve Dask support (e.g. if there are frequent elementwise functions that could be trivially served by map_blocks).
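For instance, an elementwise kernel maps directly onto `map_blocks` with no extra metadata (`scipy.special.erf` used purely as an illustration):

```python
import dask.array as da
from scipy import special

x = da.random.random((1000, 1000), chunks=250)
# Each chunk is processed independently and keeps its shape and dtype,
# so no chunk/shape bookkeeping is needed.
y = x.map_blocks(special.erf, dtype=x.dtype)
```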
@allcontributors, please add @crusaderky for bug

Let me just try this once from this PR...
I haven't tried to test this, but did go through it in a bit more detail now - overall looks good, a few comments. Looking forward to trying it out!
I've reworked the design a bit.
FYI, when I move
src/array_api_extra/_lib/_apply.py
Outdated
if any(s is None for shape in shapes for s in shape):
    # Unknown output shape. Won't work with jax.jit, but it
    # can work with eager jax.
    # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit.
Offline conversation:
How do you see scipy functions with unknown output size working with jax.jit, e.g. scipy.cluster.leaders? Should we do as in jnp.unique_all and add a size=None optional parameter, which becomes mandatory when jit is on?
I'm not sure that we should add support for those functions. My assumption is that it's only a few functions, and that those are inherently pretty clunky with JAX. I don't really want to think about extending the signatures (yet at least), because the current jax.jit support is experimental and behind a flag, and adding keywords is public and not reversible.
Perhaps making a note on the tracking issue about this being an option, but not done because of the reason above (could be done in the future, if JAX usage takes off)?
If in the future we want to support these functions, we'll have to modify this point to catch `jax.errors.TracerArrayConversionError` and reraise a backend-agnostic exception, so that `scipy.cluster.leaders` and similar can then catch it and reraise an informative error message about `size=` being mandatory.
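For reference, this is the `size=` pattern as it already exists in `jnp.unique`: a static size makes the output shape known, so the call traces under `jax.jit`, with padded entries taking `fill_value` (the catch-and-reraise part described above is a possible future change and is not sketched here):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Without size=, the output shape would depend on the values in x and
    # tracing would fail; with a static size the result is padded/truncated.
    return jnp.unique(x, size=4, fill_value=-1)

print(f(jnp.array([3, 1, 3, 2])))  # [ 1  2  3 -1]
```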
`scipy.cluster.leaders` is a function which in the near future will work in eager JAX but can't work in `jax.jit` short of a public API change, because its output arrays' shape is `xp.unique_values(input).shape`.
@pearu @vfdev-5 a while ago you asked offline how we can run a function such as this inside jax.jit.
It will be possible for an end user to do so, on the condition that they consume its output and revert to a known shape, for example:
```python
import jax
import array_api_extra as xpx
from array_api_compat import array_namespace
from scipy.cluster import leaders

def _eager(x):
    a, b = leaders(x)  # shapes = (None, ), (None, )
    xp = array_namespace(a, b)
    # silly example; probably won't make sense functionally
    return xp.max(a), xp.max(b)

# This is just an example that makes little sense;
# in practice @jax.jit will be much higher in the call stack
@jax.jit
def f(x):
    return xpx.lazy_apply(
        _eager, x, shape=((), ()), dtype=(x.dtype, x.dtype)
    )
```
src/array_api_extra/_lib/_apply.py
Outdated
jax.errors.TracerArrayConversionError
    When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes)
    and this function was called inside `jax.jit`.
See comment below
I think I want to have some evidence that the whole thing works in practice before I finalize this PR.
794a35a to 2d007e0
report.exclude_also = [
    '\.\.\.',
    'if TYPE_CHECKING:',
]
This was causing coverage to skip the whole `_lazy.py` for some reason.
# jax.pure_callback calls jax.jit under the hood, but without the chance of
# passing static_argnames / static_argnums.
This was a very unpleasant discovery.
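A small sketch of the consequence: since `jax.pure_callback` exposes no `static_argnames`/`static_argnums`, non-array keyword arguments have to be bound into the callback up front (e.g. with `functools.partial`) rather than passed through the traced call. The names below are illustrative, not the PR's actual code:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

def eager_func(x, *, scale=1.0):
    # NumPy-only kernel; `scale` is plain Python data, not an array.
    return np.asarray(x) * scale

def wrapped(x, **eager_kwargs):
    # Bind non-array kwargs eagerly, since they cannot be marked static.
    cb = functools.partial(eager_func, **eager_kwargs)
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(cb, out_spec, x)

y = jax.jit(lambda x: wrapped(x, scale=2.0))(jnp.arange(3.0))
```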
    lazy_kwargs[k] = v
else:
    eager_kwargs[k] = v
The original design much more simply stated that there cannot be arrays in the kwargs.
However, I later found out that there is severely obfuscated code in scipy that breaks this assumption.
metas = [arg._meta for arg in args if hasattr(arg, "_meta")]  # pylint: disable=protected-access
meta_xp = array_namespace(*metas)
FYI, not being able to infer the meta for argument-less generator functions was the reason that made me ditch support for them. Alternatively, I could have added a `meta_xp` parameter, but I really disliked it as it would only be meaningful for Dask.
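For readers unfamiliar with Dask's `_meta`: it is a zero-sized array of the same type as the chunks, which is what makes the namespace inference above possible. A quick illustration:

```python
import dask.array as da

x = da.ones((4, 4), chunks=2)
print(type(x._meta), x._meta.shape)  # <class 'numpy.ndarray'> (0, 0)
# For a Dask array wrapping CuPy chunks, _meta would be a zero-sized
# cupy.ndarray instead, so array_namespace(*metas) resolves to the
# namespace of the wrapped chunk type.
```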
# Block until the graph materializes and reraise exceptions. This allows
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
# not work on scheduler='distributed', as it would not block.
return dask.persist(out, scheduler="threads")[0]  # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index]  # pyright: ignore[reportPrivateImportUsage]
This is the important change in this module; the rest is cosmetic.
This is ready for final review and merge.
First draft of a wrapper around jax.pure_callback.
Untested.
The enormous amount of extra complication needed to make it work with Dask makes me uncomfortable.

@lucascolley you've been working on Dask support for SciPy; what's your vision for it?
CC @rgommers