Fix for incorrect gradients through equinox.internal.while_loop(..., kind='bounded')
patrick-kidger committed Jul 5, 2023
1 parent b7c6f63 commit aa31212
Showing 5 changed files with 44 additions and 38 deletions.
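For context: `equinox.internal.while_loop` is the API whose `kind='bounded'` gradients this commit fixes. A minimal usage sketch (the keyword arguments shown are inferred from the call sites in this diff, not quoted from the public docs):

```python
import jax
import equinox.internal as eqxi

def cond_fun(carry):
    step, x = carry
    return step < 5

def body_fun(carry):
    step, x = carry
    return step + 1, x * 2.0

def run(x0):
    # kind="bounded" lowers to a fixed number of iterations, so max_steps
    # is required; reverse-mode gradients through it are what this fixes.
    _, x_final = eqxi.while_loop(
        cond_fun, body_fun, (0, x0), max_steps=8, kind="bounded"
    )
    return x_final

print(jax.grad(run)(1.0))  # five doublings, so d(run)/dx0 == 2**5 == 32.0
```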
2 changes: 1 addition & 1 deletion equinox/internal/_loop/bounded.py
@@ -53,7 +53,7 @@ def bounded_while_loop(
         return init_val
 
     cond_fun_, body_fun_, init_val_, _ = common_rewrite(
-        cond_fun, body_fun, init_val, max_steps, buffers
+        cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=True
     )
     del cond_fun, body_fun, init_val
     rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
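The new `makes_false_steps=True` records that `bounded_while_loop` rounds `max_steps` up to a power of `base` (the `rounded_max_steps` line above), so the lowered loop can execute spurious "false" steps after the condition has become False. Updates from those steps must be masked out for every element, not only for vmap'd batch elements, or both the value and its gradient are corrupted. A toy plain-JAX illustration of the masking (not Equinox's implementation):

```python
import jax
from jax import lax

def masked_loop(x0, n_steps=8):
    # A fixed-trip-count loop standing in for a bounded while loop: it keeps
    # running after `pred` goes False, so each update is gated with a select.
    def step(x, _):
        pred = x < 10.0   # the loop condition
        new_x = x * 2.0   # the loop body
        # False steps must leave the carry (and hence its cotangent) alone.
        return lax.select(pred, new_x, x), None

    x_final, _ = lax.scan(step, x0, None, length=n_steps)
    return x_final

# Four live steps and four false steps: the gradient is 2**4 == 16.0.
# Without the select gate it would wrongly come out as 2**8 == 256.0.
print(jax.grad(masked_loop)(1.0))
```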
2 changes: 1 addition & 1 deletion equinox/internal/_loop/checkpointed.py
@@ -235,7 +235,7 @@ def checkpointed_while_loop(
     if max_steps == 0:
         return init_val
     cond_fun_, body_fun_, init_val_, buffers_ = common_rewrite(
-        cond_fun, body_fun, init_val, max_steps, buffers
+        cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
     )
     del cond_fun, body_fun, init_val, buffers
     body_fun_ = filter_closure_convert(body_fun_, init_val_)
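The checkpointed variant passes `makes_false_steps=False`: its forward pass stops exactly when the condition does, so it can keep the cheaper `_select_if_vmap` path. That helper's contract (per the comments in the common.py diff below) is that an un-vmap'd `pred` inside a while loop is always True, and only under vmap, where some batch elements may already have finished, is a real select needed. A small model of that contract (plain JAX, not the custom primitive):

```python
import jax
import jax.numpy as jnp
from jax import lax

def select_if_vmap_model(pred, x, y):
    # Under vmap this is a genuine per-element select; unbatched, the real
    # primitive skips the select entirely because `pred` must be True.
    return lax.select(pred, x, y)

# Element 0 is still looping (True) and takes the new value; element 1 has
# finished (False) and keeps its old value.
pred = jnp.array([True, False])
new = jnp.array([1.0, 2.0])
old = jnp.array([10.0, 20.0])
print(jax.vmap(select_if_vmap_model)(pred, new, old))  # [ 1. 20.]
```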
74 changes: 40 additions & 34 deletions equinox/internal/_loop/common.py
@@ -28,15 +28,15 @@ def _select_if_vmap_jvp(primals, tangents):
     assert x.dtype == tx.aval.dtype
     assert y.shape == ty.aval.shape
     assert y.dtype == ty.aval.dtype
-    out = _select_if_vmap(pred, x, y)
+    out = _select_if_vmap(pred, x, y, makes_false_steps=False)
     if type(tx) is ad.Zero and type(ty) is ad.Zero:
         t_out = tx
     else:
         if type(tx) is ad.Zero:
             tx = jnp.zeros(tx.aval.shape, tx.aval.dtype)  # pyright: ignore
         if type(ty) is ad.Zero:
             ty = jnp.zeros(ty.aval.shape, ty.aval.dtype)  # pyright: ignore
-        t_out = _select_if_vmap(pred, tx, ty)
+        t_out = _select_if_vmap(pred, tx, ty, makes_false_steps=False)
     return out, t_out


@@ -52,11 +52,11 @@ def _select_if_vmap_transpose(ct, pred, x, y):
     else:
         zero = jnp.zeros(ct.shape, ct.dtype)
     if ad.is_undefined_primal(x):
-        ct_x = _select_if_vmap(pred, ct, zero)
+        ct_x = _select_if_vmap(pred, ct, zero, makes_false_steps=False)
     else:
         ct_x = None
     if ad.is_undefined_primal(y):
-        ct_y = _select_if_vmap(pred, zero, ct)
+        ct_y = _select_if_vmap(pred, zero, ct, makes_false_steps=False)
     else:
         ct_y = None
     return [None, ct_x, ct_y]
@@ -75,7 +75,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
             y = jnp.broadcast_to(y, (axis_size,) + y.shape)
         else:
             y = jnp.moveaxis(y, by, 0)
-        out = _select_if_vmap(pred, x, y)
+        out = _select_if_vmap(pred, x, y, makes_false_steps=False)
     else:
         out = jax.vmap(lax.select, in_axes=(bp, bx, by))(pred, x, y)
     return out, 0
@@ -96,45 +96,48 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
 # have a False predicate. (But the loop is still going whilst other batch elements have
 # a True predicate). However, if we have no vmap at all, then we can be slightly more
 # efficient: don't introduce a select at all.
-def _select_if_vmap(pred, x, y):
-    """As `lax.select(pred, x, y)` if `pred` is vmap'd. Unvmap'd `pred` are assumed to
+def _select_if_vmap(pred, x, y, makes_false_steps):
+    """As `lax.select(pred, x, y)` if `pred` is vmap'd. Not-vmap'd `pred` are assumed to
     be `True`, so that in this case `x` is returned unconditionally.
     """
-    pred = fixed_asarray(pred)
-    assert pred.shape == ()
-    assert pred.dtype == jnp.bool_
-    x = fixed_asarray(x)
-    y = fixed_asarray(y)
-    assert x.shape == y.shape
-    assert x.dtype == y.dtype
-    return select_if_vmap_p.bind(pred, x, y)
+    if makes_false_steps:
+        return lax.select(pred, x, y)
+    else:
+        pred = fixed_asarray(pred)
+        assert pred.shape == ()
+        assert pred.dtype == jnp.bool_
+        x = fixed_asarray(x)
+        y = fixed_asarray(y)
+        assert x.shape == y.shape
+        assert x.dtype == y.dtype
+        return select_if_vmap_p.bind(pred, x, y)
 
 
-def _maybe_set_impl(pred, xs, i, x, *, kwargs):
-    x = _select_if_vmap(pred, x, xs.at[i].get(**kwargs))
+def _maybe_set_impl(pred, xs, i, x, *, kwargs, makes_false_steps):
+    x = _select_if_vmap(pred, x, xs.at[i].get(**kwargs), makes_false_steps)
     return [xs.at[i].set(x, **kwargs)]
 
 
-def _maybe_set_abstract(pred, xs, i, x, *, kwargs):
+def _maybe_set_abstract(pred, xs, i, x, *, kwargs, makes_false_steps):
     return [xs]
 
 
-def _maybe_set_jvp(primals, tangents, *, kwargs):
+def _maybe_set_jvp(primals, tangents, *, kwargs, makes_false_steps):
     pred, xs, i, x = primals
     _, t_xs, _, t_x = tangents
-    out = _maybe_set(pred, xs, i, x, kwargs)
+    out = _maybe_set(pred, xs, i, x, kwargs, makes_false_steps)
     if type(t_x) is ad.Zero and type(t_xs) is ad.Zero:
         t_out = t_xs
     else:
         if type(t_x) is ad.Zero:
             t_x = jnp.zeros(t_x.aval.shape, t_x.aval.dtype)  # pyright: ignore
         if type(t_xs) is ad.Zero:
             t_xs = jnp.zeros(t_xs.aval.shape, t_xs.aval.dtype)  # pyright: ignore
-        t_out = _maybe_set(pred, t_xs, i, t_x, kwargs)
+        t_out = _maybe_set(pred, t_xs, i, t_x, kwargs, makes_false_steps)
     return [out], [t_out]
 
 
-def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs):
+def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs, makes_false_steps):
     assert not ad.is_undefined_primal(pred)
     assert not ad.is_undefined_primal(i)
     [ct_out] = ct_out
@@ -154,7 +157,7 @@ def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs):
             ct_x = None
         else:
             ct_x = ct_out.at[i].get(**kwargs)
-            ct_x = _select_if_vmap(pred, ct_x, jnp.zeros_like(ct_x))
+            ct_x = _select_if_vmap(pred, ct_x, jnp.zeros_like(ct_x), makes_false_steps)
     else:
         ct_x = None
     return [None, ct_xs, None, ct_x]
@@ -179,24 +182,27 @@ def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs):
 # Second, the fact that unbatched `pred` are necessarily always True (due to being
 # used inside a while loop) means that we use `_select_if_vmap` over simply
 # `lax.select`.
-def _maybe_set(pred, xs, i, x, kwargs):
+def _maybe_set(pred, xs, i, x, kwargs, makes_false_steps):
     """As `lax.select(pred, xs.at[i].set(x, **kwargs), xs)`, under the assumption that
-    `xs.at[i]` is written to at most once, so that we can have a more efficient
-    transpose rule. Also assumes unvmap'd `pred` is unconditionally `True`.
+    every location `i` is written to at most once. (So that we can have a more efficient
+    transpose rule. Also assumes that non-vmap'd `pred` is always `True`.)
     """
     assert pred.shape == ()
     assert pred.dtype == jnp.bool_
     dtype = jnp.result_type(xs, x)
     xs = xs.astype(dtype)
     x = fixed_asarray(x).astype(dtype)
-    [out] = maybe_set_p.bind(pred, xs, i, x, kwargs=kwargs)
+    [out] = maybe_set_p.bind(
+        pred, xs, i, x, kwargs=kwargs, makes_false_steps=makes_false_steps
+    )
     return out
 
 
 class _Buffer(Module):
     _array: Union[Shaped[Array, "..."], "_Buffer"]
     _pred: Bool[Array, ""]
     _tag: object = field(static=True)
+    _makes_false_steps: bool = field(static=True)
 
     def __getitem__(self, item):
         return self._array[item]
@@ -206,8 +212,8 @@ def _op(self, pred, item, x, op, kwargs):
         if isinstance(self._array, _Buffer):
             array = self._array._op(pred, item, x, op, kwargs)
         else:
-            array = op(pred, self._array, item, x, kwargs)
-        return _Buffer(array, self._pred, self._tag)
+            array = op(pred, self._array, item, x, kwargs, self._makes_false_steps)
+        return _Buffer(array, self._pred, self._tag, self._makes_false_steps)
 
     @property
     def at(self):
@@ -264,7 +270,7 @@ def _fixed_asarray_jvp(x, tx):
     return fixed_asarray(x), fixed_asarray(tx)
 
 
-def common_rewrite(cond_fun, body_fun, init_val, max_steps, buffers):
+def common_rewrite(cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps):
     """Handles:
 
     - Efficient in-place updates;
@@ -314,7 +320,7 @@ def is_our_buffer(node):
     def wrap_buffer(leaf):
         if not is_array(leaf):
             raise ValueError("Only arrays can be treated as buffers.")
-        return _Buffer(leaf, pred, tag)
+        return _Buffer(leaf, pred, tag, makes_false_steps)
 
     def unwrap_and_select(leaf, leaf2):
         if is_our_buffer(leaf):
@@ -325,7 +331,7 @@ def unwrap_and_select(leaf, leaf2):
             assert is_array(leaf2._array)
             return leaf2._array
         else:
-            return _select_if_vmap(pred, leaf2, leaf)
+            return _select_if_vmap(pred, leaf2, leaf, makes_false_steps)
 
     step, pred, val = val
     _, _, buffer_val = tree_at(
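The gradient fix itself is carried by the rules above: `maybe_set_p` is the primitive behind gated buffer writes, and `_maybe_set_transpose` masks the cotangent read back out by `pred`, matching the primal masking of the write. A plain-JAX model of the semantics stated in `_maybe_set`'s docstring (an illustration, not the custom primitive):

```python
import jax
import jax.numpy as jnp
from jax import lax

def maybe_set(pred, xs, i, x):
    # "As `lax.select(pred, xs.at[i].set(x, **kwargs), xs)`", per the docstring.
    return lax.select(pred, xs.at[i].set(x), xs)

xs = jnp.zeros(3)
for pred in (True, False):
    # Mirrors `ct_x` in `_maybe_set_transpose`: the cotangent w.r.t. `x` is
    # `ct_out[i]` where `pred` is True and zero where it is False.
    g = jax.grad(lambda x: maybe_set(jnp.array(pred), xs, 1, x).sum())(5.0)
    print(pred, g)  # True -> 1.0, False -> 0.0
```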
2 changes: 1 addition & 1 deletion equinox/internal/_loop/loop.py
@@ -97,7 +97,7 @@ def while_loop(
     if kind == "lax":
         del kind, checkpoints, base
         cond_fun_, body_fun_, init_val_, _ = common_rewrite(
-            cond_fun, body_fun, init_val, max_steps, buffers
+            cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
        )
         del cond_fun, body_fun, init_val
         _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
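For `kind="lax"` no false steps occur either (the `lax.while_loop` primitive stops exactly when the condition does), hence `makes_false_steps=False`. The `(step, pred, val)` carry that `common_rewrite` threads through `lax.while_loop` is visible above (`_, _, final_val = lax.while_loop(...)` here, and `step, pred, val = val` in the common.py diff). A minimal sketch of that carry pattern, reconstructed from those two lines rather than from the actual rewrite (which also handles buffers, vmap'd predicates, and `max_steps=None`):

```python
import jax.numpy as jnp
from jax import lax

def rewritten_while_loop(cond_fun, body_fun, init_val, max_steps):
    # Carry a step counter and the predicate alongside the user's value.
    def cond_fun_(carry):
        step, pred, _ = carry
        return pred & (step < max_steps)

    def body_fun_(carry):
        step, pred, val = carry
        new_val = body_fun(val)
        return step + 1, cond_fun(new_val), new_val

    init = (jnp.array(0), jnp.asarray(cond_fun(init_val)), init_val)
    _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init)
    return final_val

print(rewritten_while_loop(lambda x: x < 10, lambda x: x + 1, jnp.array(0), 100))  # 10
```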
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "equinox"
-version = "0.10.7"
+version = "0.10.8"
 description = "Elegant easy-to-use neural networks in JAX."
 readme = "README.md"
 requires-python = "~=3.9"
