From 8f327d9c98ad1f3120d355bc74d80056a8115edd Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 18 Sep 2024 17:38:55 +0200 Subject: [PATCH] Hopefully compatibility with JAX 0.4.32 --- equinox/_ad.py | 37 ++++++++++++++++++++++++++++++++++--- equinox/_pretty_print.py | 3 +++ tests/test_pformat.py | 2 +- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/equinox/_ad.py b/equinox/_ad.py index 43f2f17d..2aaac6af 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -538,18 +538,27 @@ def _unflatten(flat_pytree): _FlatPyTree = tuple[list[_T], PyTreeDef] +def _strip_weak_dtype( + tree: PyTree[jax.ShapeDtypeStruct], +) -> PyTree[jax.ShapeDtypeStruct]: + return jtu.tree_map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding), tree + ) + + def _check_closure_convert_input(self, args, kwargs): self_in_dynamic_struct = _unflatten(self.in_dynamic_struct) self_in_static = _unflatten(self.in_static) in_dynamic, in_static = partition((args, kwargs), is_array) - in_dynamic_struct = jax.eval_shape(lambda: in_dynamic) + in_dynamic_struct = _strip_weak_dtype(jax.eval_shape(lambda: in_dynamic)) # `is` because `tree_equal` may return a tracer if tree_equal(in_dynamic_struct, self_in_dynamic_struct) is not True: raise ValueError( "Closure-converted function called with different dynamic arguments to " "the example arguments provided:\n\n" f"Called with: {tree_pformat(in_dynamic)}\n\n" - f"Closure-converted with: {tree_pformat(self_in_dynamic_struct)}" + "Closure-converted with: " + f"{tree_pformat(self_in_dynamic_struct, struct_as_array=True)}" ) if tree_equal(in_static, self_in_static) is not True: raise ValueError( @@ -663,7 +672,29 @@ def f(x, y): ``` """ in_dynamic, in_static = partition((args, kwargs), _is_struct) - in_dynamic_struct = jax.eval_shape(lambda: in_dynamic) + # Strip `weak_dtype`. This didn't used to exist on `jax.ShapeDtypeStruct`, and then + # got added: https://github.com/patrick-kidger/equinox/issues/854 + # + # If we were writing from scratch then we'd keep this in, but for backward + # compatibility we instead strip it and treat every dtype as non-weak. + # + # Note that there are *two* kinds of backward compatibility we're thinking about + # here. The first more important kind of backward compatibility is when doing + # something like + # ```python + # g = filter_closure_convert(f, some_array) + # g(some_int) + # ``` + # (which indeed is the case that's exploding in the linked issue above). This worked + # before! We'd like it to keep working. + # + # The second, less important, is how we trace the current function into a jaxpr. + # Whether we trace with weak dtypes or not can give different results. + # In this case, we all survived for a long time without even noticing we were doing + # this... so probably we're actually happy with either choice. + # Regardless, stripping weak dtypes here again means that we obtain the same + # behaviour as before. + in_dynamic_struct = _strip_weak_dtype(jax.eval_shape(lambda: in_dynamic)) in_dynamic_struct = jtu.tree_flatten(in_dynamic_struct) in_static = jtu.tree_flatten(in_static) if isinstance(fn, types.FunctionType) and fn.__closure__ is None: diff --git a/equinox/_pretty_print.py b/equinox/_pretty_print.py index bb3c2d9a..3fcefa57 100644 --- a/equinox/_pretty_print.py +++ b/equinox/_pretty_print.py @@ -138,6 +138,9 @@ def _pformat_array( else: dtype = obj.dtype.name if isinstance(obj, (jax.Array, jax.ShapeDtypeStruct)): + # Added in JAX 0.4.32 to `ShapeDtypeStruct` + if getattr(obj, "weak_type", False): + dtype = f"weak_{dtype}" kind = None elif isinstance(obj, np.ndarray): kind = "numpy" diff --git a/tests/test_pformat.py b/tests/test_pformat.py index d8590251..f78dfe12 100644 --- a/tests/test_pformat.py +++ b/tests/test_pformat.py @@ -40,7 +40,7 @@ class M(typing.NamedTuple): def test_jax_array(): - assert eqx.tree_pformat(jnp.array(1)) == "i32[]" + assert eqx.tree_pformat(jnp.array(1)) == "weak_i32[]" assert eqx.tree_pformat(jnp.arange(12).reshape(3, 4)) == "i32[3,4]" array = "Array(1, dtype=int32, weak_type=True)" device_array = "DeviceArray(1, dtype=int32, weak_type=True)"