
Commit 07351ec

Fix for incorrect gradients through equinox.internal.while_loop(..., kind='bounded')
patrick-kidger committed Jul 5, 2023
1 parent b7c6f63 commit 07351ec
Showing 6 changed files with 83 additions and 50 deletions.
4 changes: 2 additions & 2 deletions equinox/_ad.py
@@ -438,8 +438,8 @@ def __call__(self, *args, **kwargs):
         )
         assert len(out_dynamic_flat) == len(out_dynamic_struct_flat)
         for o1, o2 in zip(out_dynamic_flat, out_dynamic_struct_flat):
-            assert o1.shape == o2.shape
-            assert o1.dtype == o2.dtype
+            assert jnp.shape(o1) == jnp.shape(o2)
+            assert jnp.result_type(o1) == jnp.result_type(o2)
         out = jtu.tree_unflatten(out_dynamic_treedef, out_dynamic_flat)
         out = combine(out, self_out_static)
         return out
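For context on the fix: `jnp.shape` and `jnp.result_type` accept non-array inputs (e.g. Python scalars), whereas `.shape`/`.dtype` attribute access requires an actual array object. A standalone illustration, not part of the commit:

```python
import jax.numpy as jnp

# The functional helpers also work on plain Python scalars:
assert jnp.shape(1.0) == ()
assert jnp.result_type(1.0) == jnp.float32  # under JAX's default x64-disabled config
```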
2 changes: 1 addition & 1 deletion equinox/internal/_loop/bounded.py
@@ -53,7 +53,7 @@ def bounded_while_loop(
         return init_val
 
     cond_fun_, body_fun_, init_val_, _ = common_rewrite(
-        cond_fun, body_fun, init_val, max_steps, buffers
+        cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=True
     )
     del cond_fun, body_fun, init_val
     rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
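`bounded_while_loop` is the loop kind that inherently takes "false" steps: its step budget is rounded up to a power of `base`, so the compiled loop can keep iterating after `cond_fun` has become False; hence `makes_false_steps=True`. A small illustration of the rounding above, with example values assumed:

```python
import math

max_steps, base = 10, 16
rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
print(rounded_max_steps)  # 16: up to 6 iterations may run with a False predicate
```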
2 changes: 1 addition & 1 deletion equinox/internal/_loop/checkpointed.py
@@ -235,7 +235,7 @@ def checkpointed_while_loop(
     if max_steps == 0:
         return init_val
     cond_fun_, body_fun_, init_val_, buffers_ = common_rewrite(
-        cond_fun, body_fun, init_val, max_steps, buffers
+        cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
     )
     del cond_fun, body_fun, init_val, buffers
     body_fun_ = filter_closure_convert(body_fun_, init_val_)
121 changes: 77 additions & 44 deletions equinox/internal/_loop/common.py
@@ -10,6 +10,7 @@
 import jax.tree_util as jtu
 from jaxtyping import Array, Bool, Shaped
 
+from ..._errors import error_if
 from ..._filters import is_array
 from ..._module import field, Module
 from ..._tree import tree_at, tree_equal
@@ -18,6 +19,15 @@
 
 
 def _select_if_vmap_impl(pred, x, y):
+    msg = (
+        "Internal error in Equinox. Please report a bug at "
+        "https://github.com/patrick-kidger/equinox."
+    )
+    x = error_if(x, jnp.invert(pred), msg)
+    return x
+
+
+def _select_if_vmap_abstract(pred, x, y):
     return x
 
 
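The `error_if` call makes an internal invariant checkable at runtime: the impl rule only runs when the primitive is evaluated without vmap, in which case `pred` must be True. A hedged sketch of how `error_if` behaves, assuming the public alias `equinox.error_if`:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# `error_if(x, pred, msg)` returns `x` unchanged, but raises `msg` at runtime
# if `pred` is True; unlike a Python assert, it works under `jit`.
@jax.jit
def checked_sqrt(x):
    x = eqx.error_if(x, x < 0, "x must be non-negative")
    return jnp.sqrt(x)

checked_sqrt(jnp.array(4.0))     # fine
# checked_sqrt(jnp.array(-1.0))  # raises at runtime
```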
@@ -28,15 +38,15 @@ def _select_if_vmap_jvp(primals, tangents):
     assert x.dtype == tx.aval.dtype
     assert y.shape == ty.aval.shape
     assert y.dtype == ty.aval.dtype
-    out = _select_if_vmap(pred, x, y)
+    out = _select_if_vmap(pred, x, y, makes_false_steps=False)
     if type(tx) is ad.Zero and type(ty) is ad.Zero:
         t_out = tx
     else:
         if type(tx) is ad.Zero:
             tx = jnp.zeros(tx.aval.shape, tx.aval.dtype)  # pyright: ignore
         if type(ty) is ad.Zero:
             ty = jnp.zeros(ty.aval.shape, ty.aval.dtype)  # pyright: ignore
-        t_out = _select_if_vmap(pred, tx, ty)
+        t_out = _select_if_vmap(pred, tx, ty, makes_false_steps=False)
     return out, t_out


@@ -52,11 +62,11 @@ def _select_if_vmap_transpose(ct, pred, x, y):
     else:
         zero = jnp.zeros(ct.shape, ct.dtype)
     if ad.is_undefined_primal(x):
-        ct_x = _select_if_vmap(pred, ct, zero)
+        ct_x = _select_if_vmap(pred, ct, zero, makes_false_steps=False)
     else:
         ct_x = None
     if ad.is_undefined_primal(y):
-        ct_y = _select_if_vmap(pred, zero, ct)
+        ct_y = _select_if_vmap(pred, zero, ct, makes_false_steps=False)
     else:
         ct_y = None
     return [None, ct_x, ct_y]
@@ -75,15 +85,15 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
             y = jnp.broadcast_to(y, (axis_size,) + y.shape)
         else:
             y = jnp.moveaxis(y, by, 0)
-        out = _select_if_vmap(pred, x, y)
+        out = _select_if_vmap(pred, x, y, makes_false_steps=False)
     else:
         out = jax.vmap(lax.select, in_axes=(bp, bx, by))(pred, x, y)
     return out, 0
 
 
 select_if_vmap_p = jax.core.Primitive("select_if_vmap")
 select_if_vmap_p.def_impl(_select_if_vmap_impl)
-select_if_vmap_p.def_abstract_eval(_select_if_vmap_impl)
+select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract)
 ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp
 ad.primitive_transposes[select_if_vmap_p] = _select_if_vmap_transpose
 batching.axis_primitive_batchers[select_if_vmap_p] = _select_if_vmap_batch
@@ -96,45 +106,48 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
 # have a False predicate. (But the loop is still going whilst other batch elements have
 # a True predicate). However, if we have no vmap at all, then we can be slightly more
 # efficient: don't introduce a select at all.
-def _select_if_vmap(pred, x, y):
-    """As `lax.select(pred, x, y)` if `pred` is vmap'd. Unvmap'd `pred` are assumed to
+def _select_if_vmap(pred, x, y, makes_false_steps):
+    """As `lax.select(pred, x, y)` if `pred` is vmap'd. Not-vmap'd `pred` are assumed to
     be `True`, so that in this case `x` is returned unconditionally.
     """
-    pred = fixed_asarray(pred)
-    assert pred.shape == ()
-    assert pred.dtype == jnp.bool_
-    x = fixed_asarray(x)
-    y = fixed_asarray(y)
-    assert x.shape == y.shape
-    assert x.dtype == y.dtype
-    return select_if_vmap_p.bind(pred, x, y)
-
-
-def _maybe_set_impl(pred, xs, i, x, *, kwargs):
-    x = _select_if_vmap(pred, x, xs.at[i].get(**kwargs))
+    if makes_false_steps:
+        return lax.select(pred, x, y)
+    else:
+        pred = fixed_asarray(pred)
+        assert pred.shape == ()
+        assert pred.dtype == jnp.bool_
+        x = fixed_asarray(x)
+        y = fixed_asarray(y)
+        assert x.shape == y.shape
+        assert x.dtype == y.dtype
+        return select_if_vmap_p.bind(pred, x, y)
+
+
+def _maybe_set_impl(pred, xs, i, x, *, kwargs, makes_false_steps):
+    x = _select_if_vmap(pred, x, xs.at[i].get(**kwargs), makes_false_steps)
     return [xs.at[i].set(x, **kwargs)]
 
 
-def _maybe_set_abstract(pred, xs, i, x, *, kwargs):
+def _maybe_set_abstract(pred, xs, i, x, *, kwargs, makes_false_steps):
     return [xs]
 
 
-def _maybe_set_jvp(primals, tangents, *, kwargs):
+def _maybe_set_jvp(primals, tangents, *, kwargs, makes_false_steps):
     pred, xs, i, x = primals
     _, t_xs, _, t_x = tangents
-    out = _maybe_set(pred, xs, i, x, kwargs)
+    out = _maybe_set(pred, xs, i, x, kwargs, makes_false_steps)
     if type(t_x) is ad.Zero and type(t_xs) is ad.Zero:
         t_out = t_xs
     else:
         if type(t_x) is ad.Zero:
             t_x = jnp.zeros(t_x.aval.shape, t_x.aval.dtype)  # pyright: ignore
         if type(t_xs) is ad.Zero:
             t_xs = jnp.zeros(t_xs.aval.shape, t_xs.aval.dtype)  # pyright: ignore
-        t_out = _maybe_set(pred, t_xs, i, t_x, kwargs)
+        t_out = _maybe_set(pred, t_xs, i, t_x, kwargs, makes_false_steps)
     return [out], [t_out]
 
 
-def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs):
+def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs, makes_false_steps):
     assert not ad.is_undefined_primal(pred)
     assert not ad.is_undefined_primal(i)
     [ct_out] = ct_out
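For reference, when `pred` genuinely is batched, the batching rule above falls back to an ordinary elementwise select; a standalone snippet, not from the commit:

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

pred = jnp.array([True, False, True])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.zeros(3)
print(jax.vmap(lax.select)(pred, x, y))  # [1. 0. 3.]
```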
@@ -154,7 +167,7 @@ def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs):
             ct_x = None
         else:
             ct_x = ct_out.at[i].get(**kwargs)
-            ct_x = _select_if_vmap(pred, ct_x, jnp.zeros_like(ct_x))
+            ct_x = _select_if_vmap(pred, ct_x, jnp.zeros_like(ct_x), makes_false_steps)
     else:
         ct_x = None
     return [None, ct_xs, None, ct_x]
@@ -179,39 +192,51 @@ def _maybe_set_transpose(ct_out, pred, xs, i, x, *, kwargs):
 # Second, the fact that unbatched `pred` are necessarily always True (due to being
 # used inside a while loop) means that we use `_select_if_vmap` over simply
 # `lax.select`.
-def _maybe_set(pred, xs, i, x, kwargs):
+def _maybe_set(pred, xs, i, x, kwargs, makes_false_steps):
     """As `lax.select(pred, xs.at[i].set(x, **kwargs), xs)`, under the assumption that
-    `xs.at[i]` is written to at most once, so that we can have a more efficient
-    transpose rule. Also assumes unvmap'd `pred` is unconditionally `True`.
+    every location `i` is written to at most once. (So that we can have a more efficient
+    transpose rule. Also assumes that non-vmap'd `pred` is always `True`.)
     """
-    assert pred.shape == ()
-    assert pred.dtype == jnp.bool_
-    dtype = jnp.result_type(xs, x)
-    xs = xs.astype(dtype)
+    if jnp.shape(pred) != () or jnp.result_type(pred) != jnp.bool_:
+        raise ValueError("predicate must be a boolean scalar.")
+    dtype = jnp.result_type(x, xs)
+    if dtype != jnp.result_type(xs):
+        raise ValueError(
+            "When doing `buffer.at[i].set(value)`, then `value` must have a dtype that "
+            "can be promoted to the same dtype as `buffer`."
+        )
     x = fixed_asarray(x).astype(dtype)
-    [out] = maybe_set_p.bind(pred, xs, i, x, kwargs=kwargs)
+    if jax.eval_shape(lambda: xs[i]) != jax.eval_shape(lambda: x):
+        raise ValueError(
+            "When doing `buffer.at[i].set(value)`, then `value` must have the same "
+            "shape as `buffer[i]`."
+        )
+    [out] = maybe_set_p.bind(
+        pred, xs, i, x, kwargs=kwargs, makes_false_steps=makes_false_steps
+    )
     return out
 
 
 class _Buffer(Module):
     _array: Union[Shaped[Array, "..."], "_Buffer"]
     _pred: Bool[Array, ""]
     _tag: object = field(static=True)
+    _makes_false_steps: bool = field(static=True)
 
     def __getitem__(self, item):
         return self._array[item]
 
-    def _op(self, pred, item, x, op, kwargs):
+    def _op(self, pred, item, x, op, kwargs, makes_false_steps):
         pred = pred & self._pred
         if isinstance(self._array, _Buffer):
-            array = self._array._op(pred, item, x, op, kwargs)
+            array = self._array._op(pred, item, x, op, kwargs, makes_false_steps)
         else:
-            array = op(pred, self._array, item, x, kwargs)
-        return _Buffer(array, self._pred, self._tag)
+            array = op(pred, self._array, item, x, kwargs, makes_false_steps)
+        return _Buffer(array, self._pred, self._tag, self._makes_false_steps)
 
     @property
     def at(self):
-        return _BufferAt(self)
+        return _BufferAt(self, self._makes_false_steps)
 
     @property
     def shape(self):
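The reference semantics named in `_maybe_set`'s docstring can be checked directly; a standalone snippet, not part of the commit:

```python
import jax.lax as lax
import jax.numpy as jnp

pred = jnp.array(True)  # scalar boolean predicate, as `_maybe_set` requires
xs = jnp.zeros(4)
print(lax.select(pred, xs.at[1].set(5.0), xs))  # [0. 5. 0. 0.]
```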
@@ -228,17 +253,25 @@ def size(self):
 
 class _BufferAt(Module):
     _buffer: _Buffer
+    _makes_false_steps: bool = field(static=True)
 
     def __getitem__(self, item):
-        return _BufferItem(self._buffer, item)
+        return _BufferItem(self._buffer, item, self._makes_false_steps)
 
 
 class _BufferItem(Module):
     _buffer: _Buffer
     _item: Any
+    _makes_false_steps: bool = field(static=True)
 
     def set(self, x, *, pred=True, **kwargs):
-        return self._buffer._op(pred, self._item, x, _maybe_set, kwargs)
+        if pred is True:
+            makes_false_steps = self._makes_false_steps
+        else:
+            makes_false_steps = True
+        return self._buffer._op(
+            pred, self._item, x, _maybe_set, kwargs, makes_false_steps
+        )
 
 
 def _is_buffer(x):
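In user-facing terms: arrays named in `buffers=...` arrive in the loop body wrapped as `_Buffer`s, writes go through `.at[i].set(value)`, and passing an explicit `pred=` forces the conservative masked path. A hypothetical body sketch, with invented names, just to show the shape of the API:

```python
# `ys` is assumed to be a buffer passed via `buffers=...`: each step writes one
# slot, and writes on "false" steps are masked rather than clobbering results.
def body_fun(carry):
    step, x, ys = carry
    x = 2.0 * x
    ys = ys.at[step].set(x)
    return step + 1, x, ys
```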
@@ -264,7 +297,7 @@ def _fixed_asarray_jvp(x, tx):
     return fixed_asarray(x), fixed_asarray(tx)
 
 
-def common_rewrite(cond_fun, body_fun, init_val, max_steps, buffers):
+def common_rewrite(cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps):
     """Handles:
 
     - Efficient in-place updates;
@@ -314,7 +347,7 @@ def is_our_buffer(node):
     def wrap_buffer(leaf):
         if not is_array(leaf):
             raise ValueError("Only arrays can be treated as buffers.")
-        return _Buffer(leaf, pred, tag)
+        return _Buffer(leaf, pred, tag, makes_false_steps)
 
     def unwrap_and_select(leaf, leaf2):
         if is_our_buffer(leaf):
@@ -325,7 +358,7 @@ def unwrap_and_select(leaf, leaf2):
             assert is_array(leaf2._array)
             return leaf2._array
         else:
-            return _select_if_vmap(pred, leaf2, leaf)
+            return _select_if_vmap(pred, leaf2, leaf, makes_false_steps)
 
     step, pred, val = val
     _, _, buffer_val = tree_at(
2 changes: 1 addition & 1 deletion equinox/internal/_loop/loop.py
@@ -97,7 +97,7 @@ def while_loop(
     if kind == "lax":
         del kind, checkpoints, base
         cond_fun_, body_fun_, init_val_, _ = common_rewrite(
-            cond_fun, body_fun, init_val, max_steps, buffers
+            cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps=False
         )
         del cond_fun, body_fun, init_val
         _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
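A minimal sketch of the case this commit fixes, differentiating through a bounded loop; illustrative only, with example values assumed:

```python
import jax
import equinox.internal as eqxi

def cond_fun(carry):
    step, _ = carry
    return step < 3

def body_fun(carry):
    step, x = carry
    return step + 1, 2.0 * x

def f(x):
    _, y = eqxi.while_loop(cond_fun, body_fun, (0, x), max_steps=8, kind="bounded")
    return y

print(jax.grad(f)(1.0))  # expect 8.0, i.e. d/dx of (2**3) * x
```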
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "equinox"
-version = "0.10.7"
+version = "0.10.8"
 description = "Elegant easy-to-use neural networks in JAX."
 readme = "README.md"
 requires-python ="~=3.9"
