
Value error when used in gradient optimization with equinox latest version #502

Closed
ParticularlyPythonicBS opened this issue Sep 18, 2024 · 4 comments

Comments

@ParticularlyPythonicBS
Contributor

Using diffrax ODE integration within an equinox NN training loop throws the error:

ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided: ...

traceback.txt

The stderr is attached since it's too long to paste into the issue.

This happens on equinox 0.11.6, jax 0.4.33, and diffrax 0.6.0, while it works perfectly fine on equinox 0.10.6, jax 0.4.13, and diffrax 0.4.0.

Here is an MVE that replicates the traceback provided:

import diffrax
import jax.numpy as jnp
import jax
import equinox as eqx
import optax

# Factory: wrap diffrax's ODE solve so that `args` (here, the neural network)
# is passed straight through to the vector field.
def odeint(dynamics):
    def integrator(args, ts, y0):
        terms = diffrax.ODETerm(dynamics)
        t0 = ts[0]
        t1 = ts[-1]
        dt0 = ts[1] - ts[0]
        saveat = diffrax.SaveAt(ts=ts)
        sol = diffrax.diffeqsolve(
            terms,
            diffrax.Tsit5(),
            t0, t1, dt0,
            y0,
            args=args,
            saveat=saveat,
        )
        return sol.ys
    return integrator

def dynamics(t, y, args):
    # `args` is the neural network; it supplies the time-dependent coefficient.
    dy = args(t) * y
    return dy

# Small equinox network mapping time t to the coefficient used by `dynamics`.
class NN(eqx.Module):
    layer: eqx.nn.Linear

    def __init__(self, key):
        self.layer = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, t):
        t = jnp.array(t).flatten()
        u = self.layer(t)
        return u

# Loss: mean of the final solver state.
def compute_loss(args, integrator, x, ts):
    loss = integrator(args, ts, x)[-1].mean()
    return loss

# One optimization step: differentiate w.r.t. the network and apply the update.
def make_step(controller, integrator, x, ts, optim, opt_state):
    grads = eqx.filter_grad(compute_loss)(controller, integrator, x, ts)
    updates, opt_state = optim.update(grads, opt_state)
    controller = eqx.apply_updates(controller, updates)
    return controller, opt_state

ts = jnp.arange(0.0, 1, 0.1)
y0 = jnp.array([1])

model_key = jax.random.PRNGKey(0)
neural_net = NN(model_key)
integrator = odeint(dynamics)

optimizer = optax.sgd(learning_rate=3e-3)

opt_state = optimizer.init(eqx.filter(neural_net, eqx.is_inexact_array))

neural_net, opt_state = make_step(
    neural_net,
    integrator,
    y0,
    ts,
    optimizer,
    opt_state,
)

As far as I'm aware, there have been no deprecation warnings for any of this code.

Is there a better way to structure this, where an equinox neural network provides the argument for the function being integrated with diffrax?
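
(For reference, here is a rough sketch of the kind of alternative wiring I am asking about, loosely following the common diffrax neural-ODE pattern: the vector field is itself an eqx.Module, so the network lives inside the ODETerm instead of going through args. The class names are made up for the sketch, and it does not avoid the error above; it is only meant to show the alternative structure.)

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp

class VectorField(eqx.Module):
    # Hypothetical module: the learned, time-dependent coefficient f_theta(t).
    layer: eqx.nn.Linear

    def __init__(self, key):
        self.layer = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, t, y, args):
        # dy/dt = f_theta(t) * y
        control = self.layer(jnp.asarray(t).reshape(1))
        return control * y

class NeuralODE(eqx.Module):
    func: VectorField

    def __init__(self, key):
        self.func = VectorField(key)

    def __call__(self, ts, y0):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys

@eqx.filter_grad
def compute_loss(model, ts, y0):
    return model(ts, y0)[-1].mean()

ts = jnp.arange(0.0, 1.0, 0.1)
y0 = jnp.array([1.0])
model = NeuralODE(jax.random.PRNGKey(0))
grads = compute_loss(model, ts, y0)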

@lockwo
Contributor

lockwo commented Sep 18, 2024

This looks like the same issue in equinox that came from the new weak_type struct in jax 0.4.33 (see patrick-kidger/equinox#854, jax-ml/jax#23690).

With diffrax 0.6.0, equinox 0.11.6, and jax 0.4.31, it works.
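
For anyone pinning in the meantime, a possible requirements sketch based on that combination (jaxlib is assumed to need the matching version):

diffrax==0.6.0
equinox==0.11.6
jax==0.4.31
jaxlib==0.4.31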

@ParticularlyPythonicBS
Contributor Author

Thank you so much; freezing the jax version fixes it for now. I hope the upstream issue is fixed soon.

@patrick-kidger
Owner

Closing as fixed in Equinox v0.11.7 / patrick-kidger/equinox#856! Thanks for the report :)

@ParticularlyPythonicBS
Contributor Author

Closing, since this is fixed and the last comment was intended to close it.
