Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better custom autodiff #407

Merged
merged 4 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ _In other words, why should you care? Because Equinox is really simple to learn,
pip install equinox
```

Requires Python 3.9+ and JAX 0.4.11+.
Requires Python 3.9+ and JAX 0.4.13+.

## Documentation

Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ _In other words, why should you care? Because Equinox is really simple to learn,
pip install equinox
```

Requires Python 3.9+ and JAX 0.4.11+.
Requires Python 3.9+ and JAX 0.4.13+.

## Quick example

Expand Down
215 changes: 184 additions & 31 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools as ft
import types
import typing
import warnings
from collections.abc import Callable, Sequence
from typing import (
Any,
Expand Down Expand Up @@ -502,9 +503,18 @@ def f(x, y):
return closure_converted


# Work around JAX issue #16000
def _drop_ints(tangent, primal):
if jnp.issubdtype(jnp.result_type(primal), jnp.inexact):
def _materialise_symbolic_zero(x, grad_x):
if grad_x is None and is_inexact_array(x):
return jnp.zeros_like(x)
else:
return grad_x


def _drop_nondiff(tangent, primal):
if isinstance(tangent, jax.custom_derivatives.SymbolicZero):
return None
elif jnp.issubdtype(jnp.result_type(primal), jnp.inexact):
# Work around JAX issue #16000
return tangent
else:
return None
Expand All @@ -515,7 +525,9 @@ class filter_custom_jvp:

Works in the same way as `jax.custom_jvp`, except that you do not need to specify
`nondiff_argnums`. Instead, arguments are automatically split into differentiable
and nondifferentiable based on whether or not they are a floating-point JAX array.
and nondifferentiable. (Everything that is not a floating-point array is necessarily
nondifferentiable. In addition, some floating-point arrays may happen not to have
been differentiated.)

The tangents of the nondifferentiable arguments will be passed as `None`.

Expand All @@ -530,7 +542,7 @@ class filter_custom_jvp:
def call(x, y, *, fn):
return fn(x, y)

@call.defjvp
@call.def_jvp
def call_jvp(primals, tangents, *, fn):
x, y = primals
tx, ty = tangents
Expand All @@ -548,21 +560,40 @@ def fn_wrapper(static, dynamic):
self.fn = jax.custom_jvp(fn_wrapper, nondiff_argnums=(0,))

def defjvp(self, fn_jvp):
warnings.warn(
"As of Equinox 0.10.7, `equinox.filter_custom_jvp.defjvp` is deprecated in "
"favour of `.def_jvp`. This new API supports symbolic zeros, which allow "
"for more efficient autodifferentiation rules. In particular, `None` was "
"previously passed to indicate a symbolic zero tangent for all objects "
"that weren't inexact arrays, but all inexact arrays always had an "
"array-valued tangent. Now, `None` may also be passed to indicate that an "
"inexact array has a symbolic zero tangent."
)

def _fn_jvp(args, t_args, **kwargs):
t_args = jtu.tree_map(_materialise_symbolic_zero, args, t_args)
return fn_jvp(args, t_args, **kwargs)

self.def_jvp(_fn_jvp)

def def_jvp(self, fn_jvp):
def fn_jvp_wrapper(static, dynamic, tangents):
(dynamic,) = dynamic
(tangents,) = tangents
d_args, _ = dynamic
t_args, t_kwargs = tangents
if any(x is not None for x in jtu.tree_leaves(t_kwargs)):
if len(jtu.tree_leaves(t_kwargs)) > 0:
raise ValueError("Received keyword tangent")
t_args = jtu.tree_map(_drop_ints, t_args, d_args)
t_args = jtu.tree_map(_drop_nondiff, t_args, d_args)
args, kwargs = combine(dynamic, static)
return fn_jvp(args, t_args, **kwargs)

self.fn.defjvp(fn_jvp_wrapper)
self.fn.defjvp(fn_jvp_wrapper, symbolic_zeros=True)

def defjvps(self, *a, **kw):
raise NotImplementedError("filter_custom_jvp().defjvps is not implemented")
raise NotImplementedError(
"`equinox.filter_custom_jvp.defjvps` is not implemented"
)

def __call__(self, *args, **kwargs):
dynamic, static = partition((args, kwargs), is_array)
Expand Down Expand Up @@ -594,44 +625,86 @@ def nondifferentiable(
return combine(dynamic, static)


def _get_perturbed(x):
assert type(x) is jax.custom_derivatives.CustomVJPPrimal
return x.perturbed


def _get_value(x):
assert type(x) is jax.custom_derivatives.CustomVJPPrimal
return x.value


def _get_value_assert_unperturbed(x):
assert type(x) is jax.custom_derivatives.CustomVJPPrimal
assert x.perturbed is False
return x.value


def _zero_to_none(ct):
if isinstance(ct, jax.custom_derivatives.SymbolicZero):
return None
else:
return ct


def _none_to_zero(ct, x):
if ct is None:
if x is None:
return None
else:
aval = jax.core.get_aval(x).at_least_vspace()
return jax.custom_derivatives.SymbolicZero(aval)
else:
return ct


class filter_custom_vjp:
"""As `jax.custom_vjp`, but with a nicer interface.

Usage is:
```python
@equinox.filter_custom_vjp
def fn(vjp_arg, *args, **kwargs):
# vjp_arg is some PyTree of arbitrary Python objects.
# args, kwargs contain arbitrary Python objects.
# `vjp_arg` is some PyTree of arbitrary Python objects.
# `args`, `kwargs` contain arbitrary Python objects.
...
return obj # some PyTree of arbitrary Python objects.

def fn_fwd(vjp_arg, *args, **kwargs):
return out # some PyTree of arbitrary Python objects.

@fn.def_fwd
def fn_fwd(perturbed, vjp_arg, *args, **kwargs):
# `perturbed` is a pytree with the same structure as `vjp_arg`. Every leaf is
# either `True` or `False`, indicating whether that leaf is being
# differentiated. (All leaves that are not floating-point arrays will
# necessarily have `False`. Some floating-point arrays might happen not to be
# differentiated either.)
...
# Should return `obj` as before. `residuals` can be any collection of JAX
# Should return `out` as before. `residuals` can be any collection of JAX
# arrays you want to keep around for the backward pass.
return obj, residuals
return out, residuals

def fn_bwd(residuals, grad_obj, vjp_arg, *args, **kwargs):
# grad_obj will have `None` as the gradient for any leaves of `obj` that were
# not JAX arrays
@fn.def_bwd
def fn_bwd(residuals, grad_obj, perturbed, vjp_arg, *args, **kwargs):
# `grad_obj` will have `None` as the gradient for any leaves of `out` that were
# not differentiated.
...
# grad_vjp_arg should have `None` as the gradient for any leaves of `vjp_arg`
# that were not JAX arrays.
# `grad_vjp_arg` should be a pytree with the same structure as `vjp_arg`.
# It can have `None` leaves to indicate that that argument has zero gradient.
# (E.g. if the leaf was not a JAX array.)
return grad_vjp_arg

fn.defvjp(fn_fwd, fn_bwd)
```

The key differences to `jax.custom_vjp` are that:

- Only the gradient of the first argument, `vjp_arg`, should be computed on the
backward pass. Everything else will automatically have zero gradient.
- You do not need to distinguish differentiable from nondifferentiable manually.
Instead you should return gradients for all inexact JAX arrays in the first
Instead you should return gradients for all perturbed arrays in the first
argument. (And just put `None` on every other leaf of the PyTree.)
- As a convenience, all of the inputs from the forward pass are additionally made
available to you on the backward pass.
- As a convenience, you can declare forward and backward passes using `def_fwd` and
`def_bwd`, rather than a single `defvjp` as in core JAX.

!!! tip

Expand All @@ -642,9 +715,53 @@ def fn_bwd(residuals, grad_obj, vjp_arg, *args, **kwargs):

def __init__(self, fn):
self.fn = fn
self.fn_fwd: Optional[Callable] = None
self.fn_bwd: Optional[Callable] = None
self.fn_wrapped = None

def def_fwd(self, fn_fwd):
self.fn_fwd = fn_fwd
if self.fn_bwd is not None:
self._defvjp()

def def_bwd(self, fn_bwd):
self.fn_bwd = fn_bwd
if self.fn_fwd is not None:
self._defvjp()

def defvjp(self, fn_fwd, fn_bwd):
warnings.warn(
"As of Equinox 0.10.7, `equinox.filter_custom_vjp.defvjp` is deprecated in "
"favour of `.def_fwd` and `.def_bwd`. This new API supports symbolic "
"zeros, which allow for more efficient autodifferentiation rules. In "
"particular:\n"
"- the fwd and bwd functions take an extra `perturbed` argument, which "
" indicates which primals actually need a gradient. You can use this "
" to skip computing the gradient for any unperturbed value. (You can "
" also safely just ignore this if you wish.)\n"
"- `None` was previously passed to indicate a symbolic zero gradient for "
" all objects that weren't inexact arrays, but all inexact arrays "
" always had an array-valued gradient. Now, `None` may also be passed "
" to indicate that an inexact array has a symbolic zero gradient."
)

def _fn_fwd(perturbed, vjp_arg, *args, **kwargs):
del perturbed
return fn_fwd(vjp_arg, *args, **kwargs)

def _fn_bwd(
residuals, grad_diff_array_out, perturbed, vjp_arg, *args, **kwargs
):
del perturbed
grad_diff_array_out = jtu.tree_map(
_materialise_symbolic_zero, vjp_arg, grad_diff_array_out
)
return fn_bwd(residuals, grad_diff_array_out, vjp_arg, *args, **kwargs)

self.def_fwd(_fn_fwd)
self.def_bwd(_fn_bwd)

def _defvjp(self):
def fn_wrapped(
nonarray_vjp_arg,
nonarray_args_kwargs,
Expand All @@ -668,11 +785,27 @@ def fn_fwd_wrapped(
nondiff_array_vjp_arg,
array_args_kwargs,
):
assert self.fn_fwd is not None
nonarray_perturbed = jtu.tree_map(lambda _: False, nonarray_vjp_arg)
nondiff_array_perturbed = jtu.tree_map(
lambda _: False, nondiff_array_vjp_arg
)
diff_array_perturbed = jtu.tree_map(_get_perturbed, diff_array_vjp_arg)
perturbed = combine(
nonarray_perturbed, nondiff_array_perturbed, diff_array_perturbed
)
diff_array_vjp_arg = jtu.tree_map(_get_value, diff_array_vjp_arg)
nondiff_array_vjp_arg = jtu.tree_map(
_get_value_assert_unperturbed, nondiff_array_vjp_arg
)
array_args_kwargs = jtu.tree_map(
_get_value_assert_unperturbed, array_args_kwargs
)
vjp_arg = combine(
nonarray_vjp_arg, diff_array_vjp_arg, nondiff_array_vjp_arg
)
args, kwargs = combine(nonarray_args_kwargs, array_args_kwargs)
out, residuals = fn_fwd(vjp_arg, *args, **kwargs)
out, residuals = self.fn_fwd(perturbed, vjp_arg, *args, **kwargs)
array_out, nonarray_out = partition(out, is_array)
diff_array_out, nondiff_array_out = partition(array_out, is_inexact_array)
out = diff_array_out, nondiff_array_out, Static(nonarray_out)
Expand All @@ -681,37 +814,48 @@ def fn_fwd_wrapped(
diff_array_vjp_arg,
nondiff_array_vjp_arg,
array_args_kwargs,
perturbed,
)

def fn_bwd_wrapped(nonarray_vjp_arg, nonarray_args_kwargs, residuals, grad_out):
assert self.fn_bwd is not None
(
residuals,
diff_array_vjp_arg,
nondiff_array_vjp_arg,
array_args_kwargs,
perturbed,
) = residuals
vjp_arg = combine(
nonarray_vjp_arg, diff_array_vjp_arg, nondiff_array_vjp_arg
)
args, kwargs = combine(nonarray_args_kwargs, array_args_kwargs)
grad_diff_array_out, _, _ = grad_out
out = fn_bwd(residuals, grad_diff_array_out, vjp_arg, *args, **kwargs)
if jtu.tree_structure(out) != jtu.tree_structure(diff_array_vjp_arg):
grad_diff_array_out = jtu.tree_map(_zero_to_none, grad_diff_array_out)
out = self.fn_bwd(
residuals, grad_diff_array_out, perturbed, vjp_arg, *args, **kwargs
)
if jtu.tree_structure(out, is_leaf=_is_none) != jtu.tree_structure(
diff_array_vjp_arg, is_leaf=_is_none
):
raise RuntimeError(
"custom_vjp gradients must have the same structure as "
"`equinox.filter(vjp_arg, equinox.is_inexact_array)`, where "
"`vjp_arg` is the first argument used in the forward pass."
)
out = jtu.tree_map(_none_to_zero, out, diff_array_vjp_arg, is_leaf=_is_none)
# None is the gradient through nondiff_array_vjp_arg and array_args_kwargs
return out, None, None

fn_wrapped = jax.custom_vjp(fn_wrapped, nondiff_argnums=(0, 1))
fn_wrapped.defvjp(fn_fwd_wrapped, fn_bwd_wrapped)
fn_wrapped.defvjp(fn_fwd_wrapped, fn_bwd_wrapped, symbolic_zeros=True)
self.fn_wrapped = fn_wrapped

def __call__(self, vjp_arg, /, *args, **kwargs):
if self.fn_wrapped is None:
raise RuntimeError(f"defvjp not yet called for {self.fn.__name__}")
raise RuntimeError(
f"`def_fwd` or `def_bwd` not yet called for {self.fn.__name__}"
)
array_vjp_arg, nonarray_vjp_arg = partition(vjp_arg, is_array)
diff_array_vjp_arg, nondiff_array_vjp_arg = partition(
array_vjp_arg, is_inexact_array
Expand All @@ -735,17 +879,26 @@ def __call__(self, vjp_arg, /, *args, **kwargs):
_filter_custom_jvp_doc = filter_custom_jvp.__doc__
_filter_custom_vjp_doc = filter_custom_vjp.__doc__

def def_jvp(fn_jvp):
pass

def defjvp(fn_jvp):
pass

def filter_custom_jvp(fn):
return types.SimpleNamespace(defjvp=defjvp)
return types.SimpleNamespace(def_jvp=def_jvp, defjvp=defjvp)

def def_fwd(fn_fwd):
pass

def def_bwd(fn_bwd):
pass

def defvjp(fn_fwd, fn_bwd):
pass

def filter_custom_vjp(fn):
return types.SimpleNamespace(defvjp=defvjp)
return types.SimpleNamespace(def_fwd=def_fwd, def_bwd=def_bwd, defvjp=defvjp)

filter_custom_jvp.__doc__ = _filter_custom_jvp_doc
filter_custom_vjp.__doc__ = _filter_custom_vjp_doc
7 changes: 6 additions & 1 deletion equinox/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@
primitive_finalisations as primitive_finalisations,
register_impl_finalisation as register_impl_finalisation,
)
from ._loop import scan as scan, while_loop as while_loop
from ._loop import (
maybe_set_p as maybe_set_p,
scan as scan,
select_if_vmap_p as select_if_vmap_p,
while_loop as while_loop,
)
from ._misc import (
ContainerMeta as ContainerMeta,
eval_empty as eval_empty,
Expand Down
Loading