Skip to content

Commit ab82975

Browse files
authored
Parse springref and ref (#1288)
1 parent 58cbaf8 commit ab82975

File tree

5 files changed

+340
-6
lines changed

5 files changed

+340
-6
lines changed

newton/_src/solvers/mujoco/solver_mujoco.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,30 @@ def register_custom_attributes(cls, builder: ModelBuilder) -> None:
305305
mjcf_attribute_name="damping",
306306
)
307307
)
308+
builder.add_custom_attribute(
309+
ModelBuilder.CustomAttribute(
310+
name="dof_springref",
311+
frequency=ModelAttributeFrequency.JOINT_DOF,
312+
assignment=ModelAttributeAssignment.MODEL,
313+
dtype=wp.float32,
314+
default=0.0,
315+
namespace="mujoco",
316+
usd_attribute_name="mjc:springref",
317+
mjcf_attribute_name="springref",
318+
)
319+
)
320+
builder.add_custom_attribute(
321+
ModelBuilder.CustomAttribute(
322+
name="dof_ref",
323+
frequency=ModelAttributeFrequency.JOINT_DOF,
324+
assignment=ModelAttributeAssignment.MODEL,
325+
dtype=wp.float32,
326+
default=0.0,
327+
namespace="mujoco",
328+
usd_attribute_name="mjc:ref",
329+
mjcf_attribute_name="ref",
330+
)
331+
)
308332
builder.add_custom_attribute(
309333
ModelBuilder.CustomAttribute(
310334
name="jnt_actgravcomp",
@@ -672,6 +696,12 @@ def __init__(
672696
self.update_data_interval = update_data_interval
673697
self._step = 0
674698

699+
# Check if dof_ref is used - if so, we use MuJoCo's FK (eval_fk=False)
700+
# because ref is only handled by MuJoCo via qpos0
701+
mujoco_attrs = getattr(model, "mujoco", None)
702+
dof_ref = getattr(mujoco_attrs, "dof_ref", None) if mujoco_attrs is not None else None
703+
self._has_ref = dof_ref is not None and np.any(dof_ref.numpy() != 0.0)
704+
675705
if self.mjw_model is not None:
676706
self.mjw_model.opt.run_collision_detection = use_mujoco_contacts
677707

@@ -682,14 +712,17 @@ def mujoco_warp_step(self):
682712
@event_scope
683713
@override
684714
def step(self, state_in: State, state_out: State, control: Control, contacts: Contacts, dt: float):
715+
# When ref is used, we rely on MuJoCo's FK (eval_fk=False) because ref is handled by MuJoCo via qpos0
716+
eval_fk = not self._has_ref
717+
685718
if self.use_mujoco_cpu:
686719
self.apply_mjc_control(self.model, state_in, control, self.mj_data)
687720
if self.update_data_interval > 0 and self._step % self.update_data_interval == 0:
688721
# XXX updating the mujoco state at every step may introduce numerical instability
689722
self.update_mjc_data(self.mj_data, self.model, state_in)
690723
self.mj_model.opt.timestep = dt
691724
self._mujoco.mj_step(self.mj_model, self.mj_data)
692-
self.update_newton_state(self.model, state_out, self.mj_data)
725+
self.update_newton_state(self.model, state_out, self.mj_data, eval_fk=eval_fk)
693726
else:
694727
self.enable_rne_postconstraint(state_out)
695728
self.apply_mjc_control(self.model, state_in, control, self.mjw_data)
@@ -703,7 +736,7 @@ def step(self, state_in: State, state_out: State, control: Control, contacts: Co
703736
self.convert_contacts_to_mjwarp(self.model, state_in, contacts)
704737
self.mujoco_warp_step()
705738

706-
self.update_newton_state(self.model, state_out, self.mjw_data)
739+
self.update_newton_state(self.model, state_out, self.mjw_data, eval_fk=eval_fk)
707740
self._step += 1
708741
return state_out
709742

@@ -1370,6 +1403,8 @@ def get_custom_attribute(name: str) -> nparray | None:
13701403
joint_stiffness = get_custom_attribute("dof_passive_stiffness")
13711404
joint_damping = get_custom_attribute("dof_passive_damping")
13721405
joint_actgravcomp = get_custom_attribute("jnt_actgravcomp")
1406+
joint_springref = get_custom_attribute("dof_springref")
1407+
joint_ref = get_custom_attribute("dof_ref")
13731408

13741409
eq_constraint_type = model.equality_constraint_type.numpy()
13751410
eq_constraint_body1 = model.equality_constraint_body1.numpy()
@@ -1792,6 +1827,12 @@ def add_geoms(newton_body_id: int):
17921827
effort_limit = joint_effort_limit[ai]
17931828
joint_params["actfrclimited"] = True
17941829
joint_params["actfrcrange"] = (-effort_limit, effort_limit)
1830+
1831+
if joint_springref is not None:
1832+
joint_params["springref"] = joint_springref[ai]
1833+
if joint_ref is not None:
1834+
joint_params["ref"] = joint_ref[ai]
1835+
17951836
axname = name
17961837
if lin_axis_count > 1 or ang_axis_count > 1:
17971838
axname += "_lin"
@@ -1873,6 +1914,11 @@ def add_geoms(newton_body_id: int):
18731914
joint_params["actfrclimited"] = True
18741915
joint_params["actfrcrange"] = (-effort_limit, effort_limit)
18751916

1917+
if joint_springref is not None:
1918+
joint_params["springref"] = joint_springref[ai]
1919+
if joint_ref is not None:
1920+
joint_params["ref"] = joint_ref[ai]
1921+
18761922
axname = name
18771923
if lin_axis_count > 1 or ang_axis_count > 1:
18781924
axname += "_ang"
@@ -2500,6 +2546,7 @@ def update_joint_properties(self):
25002546
if self.mjc_jnt_to_newton_jnt is not None and self.mjc_jnt_to_newton_jnt.shape[1] > 0:
25012547
nworld = self.mjc_jnt_to_newton_jnt.shape[0]
25022548
njnt = self.mjc_jnt_to_newton_jnt.shape[1]
2549+
25032550
wp.launch(
25042551
update_joint_transforms_kernel,
25052552
dim=(nworld, njnt),

newton/_src/utils/import_mjcf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,6 @@ def parse_body(
774774
)
775775
)
776776
else:
777-
# TODO parse ref, springref values from joint_attrib
778777
# When parent is world (-1), use world_xform to respect the xform argument
779778
if parent == -1:
780779
parent_xform_for_joint = world_xform * wp.transform(joint_pos, wp.quat_identity())

newton/tests/test_import_mjcf.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,71 @@ def test_eq_solref_parsing(self):
16651665
for i, (a, e) in enumerate(zip(actual, expected, strict=False)):
16661666
self.assertAlmostEqual(a, e, places=4, msg=f"eq_solref[{eq_idx}][{i}] should be {e}, got {a}")
16671667

1668+
def test_ref_attribute_parsing(self):
1669+
"""Test that 'ref' attribute is parsed"""
1670+
mjcf_content = """<?xml version="1.0" encoding="utf-8"?>
1671+
<mujoco model="test">
1672+
<worldbody>
1673+
<body name="base">
1674+
<geom type="box" size="0.1 0.1 0.1"/>
1675+
<body name="child1" pos="0 0 1">
1676+
<joint name="hinge" type="hinge" axis="0 1 0" ref="90"/>
1677+
<geom type="box" size="0.1 0.1 0.1"/>
1678+
<body name="child2" pos="0 0 1">
1679+
<joint name="slide" type="slide" axis="0 0 1" ref="0.5"/>
1680+
<geom type="box" size="0.1 0.1 0.1"/>
1681+
</body>
1682+
</body>
1683+
</body>
1684+
</worldbody>
1685+
</mujoco>"""
1686+
1687+
builder = newton.ModelBuilder()
1688+
SolverMuJoCo.register_custom_attributes(builder)
1689+
builder.add_mjcf(mjcf_content)
1690+
model = builder.finalize()
1691+
1692+
# Verify custom attribute parsing
1693+
qd_start = model.joint_qd_start.numpy()
1694+
dof_ref = model.mujoco.dof_ref.numpy()
1695+
1696+
hinge_idx = model.joint_key.index("hinge")
1697+
self.assertAlmostEqual(dof_ref[qd_start[hinge_idx]], 90.0, places=4)
1698+
1699+
slide_idx = model.joint_key.index("slide")
1700+
self.assertAlmostEqual(dof_ref[qd_start[slide_idx]], 0.5, places=4)
1701+
1702+
def test_springref_attribute_parsing(self):
1703+
"""Test that 'springref' attribute is parsed for hinge and slide joints."""
1704+
mjcf_content = """<?xml version="1.0" encoding="utf-8"?>
1705+
<mujoco model="test">
1706+
<worldbody>
1707+
<body name="base">
1708+
<geom type="box" size="0.1 0.1 0.1"/>
1709+
<body name="child1" pos="0 0 1">
1710+
<joint name="hinge" type="hinge" axis="0 0 1" stiffness="100" springref="30"/>
1711+
<geom type="box" size="0.1 0.1 0.1"/>
1712+
<body name="child2" pos="0 0 1">
1713+
<joint name="slide" type="slide" axis="0 0 1" stiffness="50" springref="0.25"/>
1714+
<geom type="box" size="0.1 0.1 0.1"/>
1715+
</body>
1716+
</body>
1717+
</body>
1718+
</worldbody>
1719+
</mujoco>"""
1720+
1721+
builder = newton.ModelBuilder()
1722+
SolverMuJoCo.register_custom_attributes(builder)
1723+
builder.add_mjcf(mjcf_content)
1724+
model = builder.finalize()
1725+
springref = model.mujoco.dof_springref.numpy()
1726+
qd_start = model.joint_qd_start.numpy()
1727+
1728+
hinge_idx = model.joint_key.index("hinge")
1729+
self.assertAlmostEqual(springref[qd_start[hinge_idx]], 30.0, places=4)
1730+
slide_idx = model.joint_key.index("slide")
1731+
self.assertAlmostEqual(springref[qd_start[slide_idx]], 0.25, places=4)
1732+
16681733

16691734
if __name__ == "__main__":
16701735
unittest.main(verbosity=2)

newton/tests/test_import_usd.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,6 +2523,160 @@ def PhysicsRevoluteJoint "Joint2"
25232523
self.assertTrue(np.any(jnt_actgravcomp))
25242524
self.assertTrue(np.any(~jnt_actgravcomp))
25252525

2526+
@unittest.skipUnless(USD_AVAILABLE, "Requires usd-core")
2527+
def test_ref_attribute_parsing(self):
2528+
"""Test that 'mjc:ref' attribute is parsed."""
2529+
from pxr import Usd # noqa: PLC0415
2530+
2531+
usd_content = """#usda 1.0
2532+
(
2533+
metersPerUnit = 1.0
2534+
upAxis = "Z"
2535+
)
2536+
2537+
def Xform "Articulation" (
2538+
prepend apiSchemas = ["PhysicsArticulationRootAPI"]
2539+
)
2540+
{
2541+
def Cube "base" (
2542+
prepend apiSchemas = ["PhysicsRigidBodyAPI", "PhysicsCollisionAPI"]
2543+
)
2544+
{
2545+
double3 xformOp:translate = (0, 0, 0)
2546+
uniform token[] xformOpOrder = ["xformOp:translate"]
2547+
}
2548+
2549+
def Cube "child1" (
2550+
prepend apiSchemas = ["PhysicsRigidBodyAPI", "PhysicsCollisionAPI"]
2551+
)
2552+
{
2553+
double3 xformOp:translate = (0, 0, 1)
2554+
uniform token[] xformOpOrder = ["xformOp:translate"]
2555+
}
2556+
2557+
def PhysicsRevoluteJoint "revolute_joint"
2558+
{
2559+
token physics:axis = "Y"
2560+
rel physics:body0 = </Articulation/base>
2561+
rel physics:body1 = </Articulation/child1>
2562+
float mjc:ref = 90.0
2563+
}
2564+
}
2565+
"""
2566+
stage = Usd.Stage.CreateInMemory()
2567+
stage.GetRootLayer().ImportFromString(usd_content)
2568+
2569+
builder = newton.ModelBuilder()
2570+
SolverMuJoCo.register_custom_attributes(builder)
2571+
builder.add_usd(stage)
2572+
model = builder.finalize()
2573+
2574+
# Verify custom attribute parsing
2575+
self.assertTrue(hasattr(model, "mujoco"))
2576+
self.assertTrue(hasattr(model.mujoco, "dof_ref"))
2577+
dof_ref = model.mujoco.dof_ref.numpy()
2578+
qd_start = model.joint_qd_start.numpy()
2579+
2580+
revolute_joint_idx = model.joint_key.index("/Articulation/revolute_joint")
2581+
self.assertAlmostEqual(dof_ref[qd_start[revolute_joint_idx]], 90.0, places=4)
2582+
2583+
@unittest.skipUnless(USD_AVAILABLE, "Requires usd-core")
2584+
def test_springref_attribute_parsing(self):
2585+
"""Test that 'mjc:springref' attribute is parsed for revolute and prismatic joints."""
2586+
from pxr import Usd # noqa: PLC0415
2587+
2588+
usd_content = """#usda 1.0
2589+
(
2590+
upAxis = "Z"
2591+
)
2592+
2593+
def PhysicsScene "physicsScene"
2594+
{
2595+
}
2596+
2597+
def Xform "Articulation" (
2598+
prepend apiSchemas = ["PhysicsArticulationRootAPI"]
2599+
)
2600+
{
2601+
def Xform "Body0" (
2602+
prepend apiSchemas = ["PhysicsRigidBodyAPI"]
2603+
)
2604+
{
2605+
double3 xformOp:translate = (0, 0, 0)
2606+
uniform token[] xformOpOrder = ["xformOp:translate"]
2607+
def Cube "Collision0" (
2608+
prepend apiSchemas = ["PhysicsCollisionAPI"]
2609+
)
2610+
{
2611+
double size = 0.2
2612+
}
2613+
}
2614+
2615+
def Xform "Body1" (
2616+
prepend apiSchemas = ["PhysicsRigidBodyAPI"]
2617+
)
2618+
{
2619+
double3 xformOp:translate = (1, 0, 0)
2620+
uniform token[] xformOpOrder = ["xformOp:translate"]
2621+
def Cube "Collision1" (
2622+
prepend apiSchemas = ["PhysicsCollisionAPI"]
2623+
)
2624+
{
2625+
double size = 0.2
2626+
}
2627+
}
2628+
2629+
def Xform "Body2" (
2630+
prepend apiSchemas = ["PhysicsRigidBodyAPI"]
2631+
)
2632+
{
2633+
double3 xformOp:translate = (2, 0, 0)
2634+
uniform token[] xformOpOrder = ["xformOp:translate"]
2635+
def Cube "Collision2" (
2636+
prepend apiSchemas = ["PhysicsCollisionAPI"]
2637+
)
2638+
{
2639+
double size = 0.2
2640+
}
2641+
}
2642+
2643+
def PhysicsRevoluteJoint "revolute_joint" (
2644+
prepend apiSchemas = ["PhysicsDriveAPI:angular"]
2645+
)
2646+
{
2647+
rel physics:body0 = </Articulation/Body0>
2648+
rel physics:body1 = </Articulation/Body1>
2649+
float mjc:springref = 30.0
2650+
}
2651+
2652+
def PhysicsPrismaticJoint "prismatic_joint"
2653+
{
2654+
token physics:axis = "Z"
2655+
rel physics:body0 = </Articulation/Body1>
2656+
rel physics:body1 = </Articulation/Body2>
2657+
float mjc:springref = 0.25
2658+
}
2659+
}
2660+
"""
2661+
stage = Usd.Stage.CreateInMemory()
2662+
stage.GetRootLayer().ImportFromString(usd_content)
2663+
2664+
builder = newton.ModelBuilder()
2665+
SolverMuJoCo.register_custom_attributes(builder)
2666+
builder.add_usd(stage)
2667+
model = builder.finalize()
2668+
2669+
self.assertTrue(hasattr(model, "mujoco"))
2670+
self.assertTrue(hasattr(model.mujoco, "dof_springref"))
2671+
springref = model.mujoco.dof_springref.numpy()
2672+
qd_start = model.joint_qd_start.numpy()
2673+
2674+
revolute_joint_idx = model.joint_key.index("/Articulation/revolute_joint")
2675+
self.assertAlmostEqual(springref[qd_start[revolute_joint_idx]], 30.0, places=4)
2676+
2677+
prismatic_joint_idx = model.joint_key.index("/Articulation/prismatic_joint")
2678+
self.assertAlmostEqual(springref[qd_start[prismatic_joint_idx]], 0.25, places=4)
2679+
25262680

25272681
if __name__ == "__main__":
25282682
unittest.main(verbosity=2, failfast=True)

0 commit comments

Comments
 (0)