Inconsistent parameter gradient behaviour across pipelines #540

Open
JoeMWatson opened this issue Oct 22, 2024 · 0 comments


Hi,

I'm using brax 0.11.0 and am confused about parameter gradients. In this issue from last year, the pipelines appeared to be consistent; however, I'm now looking at parameter gradients w.r.t. link mass, and they no longer appear consistent across pipelines.

If you run the following script,

```python
import jax
from jax import numpy as jp
from brax.positional import pipeline as positional_pipeline
from brax.generalized import pipeline as generalized_pipeline
from brax.mjx import pipeline as mjx_pipeline
from brax.io import mjcf
ball = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.002"/>
  <worldbody>
    <geom name="floor" pos="0 0 0" size="40 40 40" type="plane" friction="1.0"/>
    <body pos="0 0 0.2">
      <joint type="free"/>
      <geom size=".2" mass="1.0" friction="1.0"/>
    </body>
  </worldbody>
</mujoco>
"""
for name, pipeline in {
    "pos": positional_pipeline,
    "gen": generalized_pipeline,
    "mjx": mjx_pipeline,
}.items():
    print(name)
    try:
        def simulation(pipeline, sys, params):
            init_qd = jp.zeros(6).at[0].set(1.0)  # 1m/s in +x
            mass = params
            sys = sys.replace(
                link=sys.link.replace(inertia=sys.link.inertia.replace(mass=mass))
            )
            state = jax.jit(pipeline.init)(sys, sys.init_q, init_qd)
            for i in range(1):
                state = jax.jit(pipeline.step)(sys, state, None)
            return state.qd[0]

        sys = mjcf.loads(ball)
        mass = sys.link.inertia.mass
        params = jp.ones_like(mass)
        x = simulation(pipeline, sys, params)
        grad_x = jax.grad(simulation, argnums=2)(pipeline, sys, params)
        print(f"{x=}, {grad_x=}")
    except Exception as exc:
        print(exc)
```

the output is:

```
pos
x=Array(0.99439585, dtype=float32), grad_x=Array([-0.00160096], dtype=float32)
gen
reshape total size must be unchanged, got new_sizes (6, 3) (of total size 18) for shape (6,) (of total size 6).
mjx
x=Array(1., dtype=float32), grad_x=Array([0.], dtype=float32)
```
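The script's `except` branch prints only `str(exc)`, so the generalized-pipeline failure shows the reshape message but not which call raised it. Below is a minimal sketch of printing the full traceback instead. It uses a hypothetical `failing_step` function that reproduces the same kind of (6,) to (6, 3) shape mismatch in place of the brax pipeline, so it runs without brax installed.

```python
import traceback

import jax.numpy as jnp


def failing_step():
    # Hypothetical stand-in for a pipeline step: reshaping a (6,) array into
    # (6, 3) fails, mirroring the mismatch reported for the generalized pipeline.
    return jnp.zeros(6).reshape(6, 3)


try:
    failing_step()
except Exception:
    # format_exc() keeps the stack frames, so the failing call site is visible,
    # unlike print(exc), which shows only the message.
    report = traceback.format_exc()
    print(report)
```

In the script itself, this amounts to swapping `print(exc)` for `traceback.print_exc()`. JAX also hides its own internal frames by default; setting the environment variable `JAX_TRACEBACK_FILTERING=off` should show those too, which may help locate where the reshape happens inside brax.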

Should we expect to be able to compute parameter gradients for the generalized and MJX pipelines?
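A zero gradient like the MJX result could mean the output genuinely does not depend on mass over that horizon, or that the dependence is not being differentiated through. A central finite difference gives a reference that does not rely on autodiff. The sketch below uses a hypothetical `toy_rollout` function (a 1D body under a constant force, not a brax pipeline) in place of `simulation`; with brax, the same comparison would wrap `simulation(pipeline, sys, params)` instead.

```python
import jax
import jax.numpy as jnp

# float64 keeps the difference quotient from being dominated by float32 rounding
jax.config.update("jax_enable_x64", True)


def toy_rollout(mass):
    # Hypothetical stand-in for `simulation`: a 1D body pushed by a constant
    # force for three explicit-Euler steps; the final velocity depends on mass.
    dt, force = 0.002, 5.0
    v = 1.0
    for _ in range(3):
        v = v + dt * force / mass
    return v


def central_difference(f, x, eps=1e-6):
    # (f(x + eps) - f(x - eps)) / (2 * eps): a derivative estimate without autodiff
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


mass = jnp.array(1.0)
autodiff_grad = jax.grad(toy_rollout)(mass)
fd_grad = central_difference(toy_rollout, mass)
print(f"{autodiff_grad=}, {fd_grad=}")
```

If both estimates agree on zero, the output is insensitive to mass over that rollout; if the finite difference is clearly nonzero while `jax.grad` returns 0 (or fails), the gradient path is likely being cut somewhere. Finite differences through contact can be noisy, so a few step sizes are worth trying.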
