Skip to content

Commit

Permalink
Use sparse (uncompressed) actuator_moment in mj_transmission.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692179704
Change-Id: Ic30ac5a98dc13de2028e378df65dc88ba3912bf5
  • Loading branch information
thowell authored and copybara-github committed Nov 1, 2024
1 parent 1d58576 commit a51f346
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 81 deletions.
47 changes: 41 additions & 6 deletions mjx/mujoco/mjx/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def make_data(
'wrap_obj': (m.nwrap, 2, jp.int32),
'wrap_xpos': (m.nwrap, 6, float),
'actuator_length': (m.nu, float),
'moment_rownnz': (m.nu, jp.int32),
'moment_rowadr': (m.nu, jp.int32),
'moment_colind': (m.nu, m.nv, jp.int32),
'actuator_moment': (m.nu, m.nv, float),
'crb': (m.nbody, 10, float),
'qM': (m.nM, float) if support.is_sparse(m) else (m.nv, m.nv, float),
Expand Down Expand Up @@ -427,6 +430,25 @@ def get_data_into(
result_i.contact.efc_address[:] = efc_map[result_i.contact.efc_address]
continue

# MuJoCo actuator_moment is sparse, MJX uses a dense representation.
if field.name == 'actuator_moment' and m.nu:
moment_rownnz = np.zeros(m.nu, dtype=int)
moment_rowadr = np.zeros(m.nu, dtype=int)
moment_colind = np.zeros(m.nu * m.nv, dtype=int)
actuator_moment = np.zeros(m.nu * m.nv)
mujoco.mju_dense2sparse(
actuator_moment,
d.actuator_moment,
moment_rownnz,
moment_rowadr,
moment_colind,
)
result_i.moment_rownnz[:] = moment_rownnz
result_i.moment_rowadr[:] = moment_rowadr
result_i.moment_colind[:] = moment_colind.reshape((m.nu, m.nv))
result_i.actuator_moment[:] = actuator_moment.reshape((m.nu, m.nv))
continue

value = getattr(d_i, field.name)

if field.name in ('nefc', 'ncon'):
Expand Down Expand Up @@ -532,19 +554,32 @@ def put_data(
# MJX does not support islanding, so only transfer the first solver_niter
fields['solver_niter'] = fields['solver_niter'][0]

# convert sparse representation of actuator_moment to dense matrix
moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
moment,
d.actuator_moment.reshape(-1),
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind.reshape(-1),
)
fields['actuator_moment'] = moment

contact, contact_map = _make_contact(d.contact, dim, efc_address)

# pad efc fields: MuJoCo efc arrays are sparse for inactive constraints.
# efc_J is also optionally column-sparse (typically for large nv). MJX is
# neither: it contains zeros for inactive constraints, and efc_J is always
# (nefc, nv). this may change in the future.
if mujoco.mj_isSparse(m):
nr = d.efc_J_rownnz.shape[0]
efc_j = np.zeros((nr, m.nv))
for i in range(nr):
rowadr = d.efc_J_rowadr[i]
for j in range(d.efc_J_rownnz[i]):
efc_j[i, d.efc_J_colind[rowadr + j]] = fields['efc_J'][rowadr + j]
efc_j = np.zeros((d.efc_J_rownnz.shape[0], m.nv))
mujoco.mju_sparse2dense(
efc_j,
fields['efc_J'],
d.efc_J_rownnz,
d.efc_J_rowadr,
d.efc_J_colind,
)
fields['efc_J'] = efc_j
else:
fields['efc_J'] = fields['efc_J'].reshape((-1 if m.nv else 0, m.nv))
Expand Down
24 changes: 22 additions & 2 deletions mjx/mujoco/mjx/_src/smooth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,17 @@ def test_smooth(self):
# transmission
dx = jax.jit(mjx.transmission)(mx, dx)
_assert_attr_eq(d, dx, 'actuator_length')
_assert_attr_eq(d, dx, 'actuator_moment')

# convert sparse actuator_moment to dense representation
moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
moment,
d.actuator_moment.reshape(-1),
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind.reshape(-1),
)
_assert_eq(moment, dx.actuator_moment, 'actuator_moment')

def test_disable_gravity(self):
m = mujoco.MjModel.from_xml_string("""
Expand Down Expand Up @@ -178,7 +188,17 @@ def test_site_transmission(self):
mujoco.mj_transmission(m, d)
dx = jax.jit(mjx.transmission)(mx, dx)
_assert_attr_eq(d, dx, 'actuator_length')
_assert_attr_eq(d, dx, 'actuator_moment')

# convert sparse actuator_moment to dense representation
moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
moment,
d.actuator_moment.reshape(-1),
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind.reshape(-1),
)
_assert_eq(moment, dx.actuator_moment, 'actuator_moment')

def test_subtree_vel(self):
"""Tests MJX subtree_vel function matches MuJoCo mj_subtreeVel."""
Expand Down
6 changes: 6 additions & 0 deletions mjx/mujoco/mjx/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,9 @@ class Data(PyTreeNode):
wrap_obj: geom id; -1: site; -2: pulley (nwrap*2,)
wrap_xpos: Cartesian 3D points in all path (nwrap*2, 3)
actuator_length: actuator lengths (nu,)
moment_rownnz: number of non-zeros in actuator_moment row (nu,)
moment_rowadr: row start address in colind array (nu,)
moment_colind: column indices in sparse Jacobian (nu, nv)
actuator_moment: actuator moments (nu, nv)
crb: com-based composite inertia and mass (nbody, 10)
qM: total inertia if sparse: (nM,)
Expand Down Expand Up @@ -1350,6 +1353,9 @@ class Data(PyTreeNode):
wrap_obj: jax.Array
wrap_xpos: jax.Array
actuator_length: jax.Array
moment_rownnz: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
moment_rowadr: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
moment_colind: jax.Array = _restricted_to('mujoco') # pylint:disable=invalid-name
actuator_moment: jax.Array
crb: jax.Array
qM: jax.Array # pylint:disable=invalid-name
Expand Down
25 changes: 21 additions & 4 deletions mjx/mujoco/mjx/integration_test/smooth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,27 @@ def test_transmission(self, seed):
mujoco.mj_transmission(m, d)
dx = transmission_jit_fn(mx, dx)

for field in ['actuator_length', 'actuator_moment']:
_assert_attr_eq(
d, dx, field, seed, f'transmission{seed}', atol=1e-4
)
_assert_attr_eq(
d, dx, 'actuator_length', seed, f'transmission{seed}', atol=1e-4
)

# convert sparse actuator_moment to dense representation
moment = np.zeros((m.nu, m.nv))
mujoco.mju_sparse2dense(
moment,
d.actuator_moment.reshape(-1),
d.moment_rownnz,
d.moment_rowadr,
d.moment_colind.reshape(-1),
)
_assert_eq(
moment,
dx.actuator_moment,
'actuator_moment',
seed,
f'transmission{seed}',
atol=1e-4,
)


if __name__ == '__main__':
Expand Down
10 changes: 9 additions & 1 deletion python/LQR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,15 @@
},
"outputs": [],
"source": [
"ctrl0 = np.atleast_2d(qfrc0) @ np.linalg.pinv(data.actuator_moment)\n",
"actuator_moment = np.zeros((model.nu, model.nv))\n",
"mujoco.mju_sparse2dense(\n",
" actuator_moment,\n",
" data.actuator_moment,\n",
" data.moment_rownnz,\n",
" data.moment_rowadr,\n",
" data.moment_colind,\n",
")\n",
"ctrl0 = np.atleast_2d(qfrc0) @ np.linalg.pinv(actuator_moment)\n",
"ctrl0 = ctrl0.flatten() # Save the ctrl setpoint.\n",
"print('control setpoint:', ctrl0)"
]
Expand Down
Loading

0 comments on commit a51f346

Please sign in to comment.