Issue with jax and jaxlib versions 0.4.32 and 0.4.33 #854

Closed
dkweiss31 opened this issue Sep 17, 2024 · 4 comments


@dkweiss31

It might make more sense to post this as a JAX issue, but I thought I'd call your attention to it in case you weren't aware: equinox and diffrax (and probably other libraries) seem to be broken by the newest versions of jax and jaxlib, 0.4.32 and 0.4.33. For example, some of the tests in tests/test_scan.py fail:

```
============================================================================ short test summary info =============================================================================
FAILED tests/test_scan.py::test_scan[True-4-checkpointed] - ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided.
FAILED tests/test_scan.py::test_scan[True-None-checkpointed] - ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided.
```

These tests pass with jax 0.4.31. I get a similar error for multiple tests in diffrax. Happy to redirect this to the JAX folks if you think that's more appropriate!

@patrick-kidger
Owner

patrick-kidger commented Sep 17, 2024

Ach. When I get a bit of time I can maybe try reducing that to an MWE without Equinox, but I'd suggest raising it with the JAX folks ASAP. They might already have a hunch what this is, and this looks like a major break.

@lockwo
Contributor

lockwo commented Sep 17, 2024

As @/hawkinsp pointed out in the other thread (jax-ml/jax#23690), I think it does have to do with the new `weak_type` attribute on `ShapeDtypeStruct`. In the specific example you linked in the thread, `self_in_dynamic_struct` in `_check_closure_convert_input` has the same tree structure and dtypes as the other struct, but fails a `tree_equal` check because it has `weak_type` set to `True` for some of the ints. (I checked this by removing the weak types with `self_in_dynamic_struct = jtu.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), self_in_dynamic_struct)`, and it worked.) Maybe that information is helpful in finding the problem?

I am curious to see what the solution is, though, since I don't know much about the philosophy of weak types in equinox. For example, presumably it's correct for `eqx.tree_equal(jax.ShapeDtypeStruct((1,), jnp.int32, weak_type=True), jax.ShapeDtypeStruct((1,), jnp.int32))` to be false, since the two have different attributes. So is the solution just to find where the weak type is getting introduced and make sure it's not a weak type?

@hawkinsp
Contributor

I suspect you can simply strip off the `weak_type` attribute for comparisons to achieve the previous behavior, although it's possible that a difference in `weak_type` does point to a real problem.

@patrick-kidger
Owner

Closing -- fixed in #856! I've just done a hotfix release for this.
