Skip to content

Commit

Permalink
Hopefully compatibility with JAX 0.4.32
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 18, 2024
1 parent fab726b commit 8f327d9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 4 deletions.
37 changes: 34 additions & 3 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,18 +538,27 @@ def _unflatten(flat_pytree):
_FlatPyTree = tuple[list[_T], PyTreeDef]


def _strip_weak_dtype(
tree: PyTree[jax.ShapeDtypeStruct],
) -> PyTree[jax.ShapeDtypeStruct]:
return jtu.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding), tree
)


def _check_closure_convert_input(self, args, kwargs):
self_in_dynamic_struct = _unflatten(self.in_dynamic_struct)
self_in_static = _unflatten(self.in_static)
in_dynamic, in_static = partition((args, kwargs), is_array)
in_dynamic_struct = jax.eval_shape(lambda: in_dynamic)
in_dynamic_struct = _strip_weak_dtype(jax.eval_shape(lambda: in_dynamic))
# `is` because `tree_equal` may return a tracer
if tree_equal(in_dynamic_struct, self_in_dynamic_struct) is not True:
raise ValueError(
"Closure-converted function called with different dynamic arguments to "
"the example arguments provided:\n\n"
f"Called with: {tree_pformat(in_dynamic)}\n\n"
f"Closure-converted with: {tree_pformat(self_in_dynamic_struct)}"
"Closure-converted with: "
f"{tree_pformat(self_in_dynamic_struct, struct_as_array=True)}"
)
if tree_equal(in_static, self_in_static) is not True:
raise ValueError(
Expand Down Expand Up @@ -663,7 +672,29 @@ def f(x, y):
```
"""
in_dynamic, in_static = partition((args, kwargs), _is_struct)
in_dynamic_struct = jax.eval_shape(lambda: in_dynamic)
# Strip `weak_dtype`. This didn't used to exist on `jax.ShapeDtypeStruct`, and then
# got added: https://github.com/patrick-kidger/equinox/issues/854
#
# If we were writing from scratch then we'd keep this in, but for backward
# compatibility we instead strip it and treat every dtype as non-weak.
#
# Note that there are *two* kinds of backward compatibility we're thinking about
# here. The first more important kind of backward compatibility is when doing
# something like
# ```python
# g = filter_closure_convert(f, some_array)
# g(some_int)
# ```
# (which indeed is the case that's exploding in the linked issue above). This worked
# before! We'd like it to keep working.
#
# The second, less important, is how we trace the current function into a jaxpr.
# Whether we trace with weak dtypes or not can give different results.
# In this case, we all survived for a long time without even noticing we were doing
# this... so probably we're actually happy with either choice.
# Regardless, stripping weak dtypes here again means that we obtain the same
# behaviour as before.
in_dynamic_struct = _strip_weak_dtype(jax.eval_shape(lambda: in_dynamic))
in_dynamic_struct = jtu.tree_flatten(in_dynamic_struct)
in_static = jtu.tree_flatten(in_static)
if isinstance(fn, types.FunctionType) and fn.__closure__ is None:
Expand Down
3 changes: 3 additions & 0 deletions equinox/_pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def _pformat_array(
else:
dtype = obj.dtype.name
if isinstance(obj, (jax.Array, jax.ShapeDtypeStruct)):
# Added in JAX 0.4.32 to `ShapeDtypeStruct`
if getattr(obj, "weak_type", False):
dtype = f"weak_{dtype}"
kind = None
elif isinstance(obj, np.ndarray):
kind = "numpy"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class M(typing.NamedTuple):


def test_jax_array():
assert eqx.tree_pformat(jnp.array(1)) == "i32[]"
assert eqx.tree_pformat(jnp.array(1)) == "weak_i32[]"
assert eqx.tree_pformat(jnp.arange(12).reshape(3, 4)) == "i32[3,4]"
array = "Array(1, dtype=int32, weak_type=True)"
device_array = "DeviceArray(1, dtype=int32, weak_type=True)"
Expand Down

0 comments on commit 8f327d9

Please sign in to comment.