I'm using brax 0.11.0 and am confused about parameter gradients. In this issue from last year, the pipelines appeared to be consistent. However, I'm now looking at parameter gradients w.r.t. link mass, and they no longer appear consistent across pipelines.
If you run the following script,
import jax
from jax import numpy as jp
from brax.positional import pipeline as positional_pipeline
from brax.generalized import pipeline as generalized_pipeline
from brax.mjx import pipeline as mjx_pipeline
from brax.io import mjcf

ball = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.002"/>
  <worldbody>
    <geom name="floor" pos="0 0 0" size="40 40 40" type="plane" friction="1.0"/>
    <body pos="0 0 0.2">
      <joint type="free"/>
      <geom size=".2" mass="1.0" friction="1.0"/>
    </body>
  </worldbody>
</mujoco>
"""

for name, pipeline in {
    "pos": positional_pipeline,
    "gen": generalized_pipeline,
    "mjx": mjx_pipeline,
}.items():
    print(name)
    try:

        def simulation(pipeline, sys, params):
            init_qd = jp.zeros(6).at[0].set(1.0)  # 1m/s in +x
            mass = params
            sys = sys.replace(
                link=sys.link.replace(inertia=sys.link.inertia.replace(mass=mass))
            )
            state = jax.jit(pipeline.init)(sys, sys.init_q, init_qd)
            for i in range(1):
                state = jax.jit(pipeline.step)(sys, state, None)
            return state.qd[0]

        sys = mjcf.loads(ball)
        mass = sys.link.inertia.mass
        params = jp.ones_like(mass)
        x = simulation(pipeline, sys, params)
        grad_x = jax.grad(simulation, argnums=2)(pipeline, sys, params)
        print(f"{x=}, {grad_x=}")
    except Exception as exc:
        print(exc)
the output is
pos
x=Array(0.99439585, dtype=float32), grad_x=Array([-0.00160096], dtype=float32)
gen
reshape total size must be unchanged, got new_sizes (6, 3) (of total size 18) for shape (6,) (of total size 6).
mjx
x=Array(1., dtype=float32), grad_x=Array([0.], dtype=float32)
Should we expect to be able to compute parameter gradients for generalized and MJX pipelines?
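One way to narrow this down might be to compare each `jax.grad` result against a central finite-difference estimate. Below is a minimal sketch of that check in plain Python. Note that `central_difference` and `toy_velocity_after_step` are hypothetical names, and the toy function is only a stand-in for `simulation` (a fixed braking impulse, not the physics of any brax pipeline).

```python
def central_difference(f, x, eps=1e-3):
    """Estimate df/dx at a scalar x with a central finite difference."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


def toy_velocity_after_step(mass):
    # Hypothetical stand-in for simulation(pipeline, sys, params):
    # a body moving at 1 m/s receives a fixed braking impulse J,
    # so its velocity after the step is v0 - J / mass.
    v0 = 1.0
    impulse = 0.005
    return v0 - impulse / mass


# Finite-difference estimate of d(velocity)/d(mass) at mass = 1.
fd_grad = central_difference(toy_velocity_after_step, 1.0)

# Closed-form derivative of the toy model: d/dm (v0 - J/m) = J / m**2.
analytic_grad = 0.005

print(f"{fd_grad=:.6f}, {analytic_grad=:.6f}")
```

For the script above, the same helper could presumably be applied by wrapping `simulation` in a function of a single scalar mass and comparing the result with `grad_x` for each pipeline. Since the outputs are float32, a larger `eps` is likely needed there to keep rounding error from dominating the estimate.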