Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions newton/_src/solvers/mujoco/solver_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,14 +1685,20 @@ def add_geoms(newton_body_id: int):

add_geoms(child)

def get_body_name(body_idx: int) -> str:
"""Get body name, handling world body (-1) correctly."""
if body_idx == -1:
return "world"
return model.body_key[body_idx]

for i in selected_constraints:
constraint_type = eq_constraint_type[i]
if constraint_type == EqType.CONNECT:
eq = spec.add_equality(objtype=mujoco.mjtObj.mjOBJ_BODY)
eq.type = mujoco.mjtEq.mjEQ_CONNECT
eq.active = eq_constraint_enabled[i]
eq.name1 = model.body_key[eq_constraint_body1[i]]
eq.name2 = model.body_key[eq_constraint_body2[i]]
eq.name1 = get_body_name(eq_constraint_body1[i])
eq.name2 = get_body_name(eq_constraint_body2[i])
eq.data[0:3] = eq_constraint_anchor[i]
if eq_constraint_solref is not None:
eq.solref = eq_constraint_solref[i]
Expand All @@ -1711,8 +1717,8 @@ def add_geoms(newton_body_id: int):
eq = spec.add_equality(objtype=mujoco.mjtObj.mjOBJ_BODY)
eq.type = mujoco.mjtEq.mjEQ_WELD
eq.active = eq_constraint_enabled[i]
eq.name1 = model.body_key[eq_constraint_body1[i]]
eq.name2 = model.body_key[eq_constraint_body2[i]]
eq.name1 = get_body_name(eq_constraint_body1[i])
eq.name2 = get_body_name(eq_constraint_body2[i])
cns_relpose = wp.transform(*eq_constraint_relpose[i])
eq.data[0:3] = eq_constraint_anchor[i]
eq.data[3:6] = wp.transform_get_translation(cns_relpose)
Expand Down
95 changes: 86 additions & 9 deletions newton/_src/utils/import_mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ..core import quat_between_axes, quat_from_euler
from ..core.types import Axis, AxisType, Sequence, Transform
from ..geometry import MESH_MAXHULLVERT, Mesh
from ..geometry import MESH_MAXHULLVERT, Mesh, ShapeFlags
from ..sim import JointType, ModelBuilder
from ..sim.model import ModelAttributeFrequency
from ..usd.schemas import solref_to_stiffness_damping
Expand Down Expand Up @@ -968,6 +968,26 @@ def parse_common_attributes(element):
"active": element.attrib.get("active", "true").lower() == "true",
}

def get_site_body_and_anchor(site_name: str) -> tuple[int, wp.vec3] | None:
"""Look up a site by name and return its body index and position (anchor).

Returns:
Tuple of (body_idx, anchor_position) or None if site not found or not a site.
"""
if site_name not in builder.shape_key:
if verbose:
print(f"Warning: Site '{site_name}' not found")
return None
site_idx = builder.shape_key.index(site_name)
if not (builder.shape_flags[site_idx] & ShapeFlags.SITE):
if verbose:
print(f"Warning: Shape '{site_name}' is not a site")
return None
body_idx = builder.shape_body[site_idx]
site_xform = builder.shape_transform[site_idx]
anchor = wp.vec3(site_xform[0], site_xform[1], site_xform[2])
return (body_idx, anchor)

for connect in equality.findall("connect"):
common = parse_common_attributes(connect)
custom_attrs = parse_custom_attributes(connect.attrib, builder_custom_attr_eq, parsing_mode="mjcf")
Expand All @@ -976,8 +996,8 @@ def parse_common_attributes(element):
connect.attrib.get("body2", "worldbody").replace("-", "_") if connect.attrib.get("body2") else None
)
anchor = connect.attrib.get("anchor")

site1 = connect.attrib.get("site1")
site2 = connect.attrib.get("site2")

if body1_name and anchor:
if verbose:
Expand All @@ -996,9 +1016,35 @@ def parse_common_attributes(element):
enabled=common["active"],
custom_attributes=custom_attrs,
)

if site1: # Implement site-based connect after Newton supports sites
print("Warning: MuJoCo sites are not yet supported in Newton.")
elif site1:
if site2:
# Site-based connect: both site1 and site2 must be specified
site1_info = get_site_body_and_anchor(site1)
site2_info = get_site_body_and_anchor(site2)
if site1_info is None or site2_info is None:
if verbose:
print(f"Warning: Connect constraint '{common['name']}' failed.")
continue
body1_idx, anchor_vec = site1_info
body2_idx, _ = site2_info
if verbose:
print(
f"Connect constraint (site-based): site '{site1}' on body {body1_idx} to body {body2_idx}"
)
builder.add_equality_constraint_connect(
body1=body1_idx,
body2=body2_idx,
anchor=anchor_vec,
key=common["name"],
enabled=common["active"],
custom_attributes=custom_attrs,
)
else:
if verbose:
print(
f"Warning: Connect constraint '{common['name']}' has site1 but no site2. "
"When using sites, both site1 and site2 must be specified. Skipping."
)

for weld in equality.findall("weld"):
common = parse_common_attributes(weld)
Expand All @@ -1008,8 +1054,8 @@ def parse_common_attributes(element):
anchor = weld.attrib.get("anchor", "0 0 0")
relpose = weld.attrib.get("relpose", "0 1 0 0 0 0 0")
torquescale = weld.attrib.get("torquescale")

site1 = weld.attrib.get("site1")
site2 = weld.attrib.get("site2")

if body1_name:
if verbose:
Expand All @@ -1036,9 +1082,40 @@ def parse_common_attributes(element):
enabled=common["active"],
custom_attributes=custom_attrs,
)

if site1: # Implement site-based weld after Newton supports sites
print("Warning: MuJoCo sites are not yet supported in Newton.")
elif site1:
if site2:
# Site-based weld: both site1 and site2 must be specified
site1_info = get_site_body_and_anchor(site1)
site2_info = get_site_body_and_anchor(site2)
if site1_info is None or site2_info is None:
if verbose:
print(f"Warning: Weld constraint '{common['name']}' failed.")
continue
body1_idx, _ = site1_info
body2_idx, anchor_vec = site2_info
relpose_list = [float(x) for x in relpose.split()]
relpose_transform = wp.transform(
wp.vec3(relpose_list[0], relpose_list[1], relpose_list[2]),
wp.quat(relpose_list[4], relpose_list[5], relpose_list[6], relpose_list[3]),
)
if verbose:
print(f"Weld constraint (site-based): body {body1_idx} to body {body2_idx}")
builder.add_equality_constraint_weld(
body1=body1_idx,
body2=body2_idx,
anchor=anchor_vec,
relpose=relpose_transform,
torquescale=torquescale,
key=common["name"],
enabled=common["active"],
custom_attributes=custom_attrs,
)
else:
if verbose:
print(
f"Warning: Weld constraint '{common['name']}' has site1 but no site2. "
"When using sites, both site1 and site2 must be specified. Skipping."
)

for joint in equality.findall("joint"):
common = parse_common_attributes(joint)
Expand Down
21 changes: 18 additions & 3 deletions newton/tests/test_equality_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ def test_multiple_constraints(self):

self.model = builder.finalize()

eq_keys = self.model.equality_constraint_key
eq_body1 = self.model.equality_constraint_body1.numpy()
eq_body2 = self.model.equality_constraint_body2.numpy()
eq_anchors = self.model.equality_constraint_anchor.numpy()
eq_torquescale = self.model.equality_constraint_torquescale.numpy()

c_site_idx = eq_keys.index("c_site")
self.assertEqual(eq_body1[c_site_idx], -1)
self.assertEqual(eq_body2[c_site_idx], 0)
np.testing.assert_allclose(eq_anchors[c_site_idx], [0.0, 0.0, 1.0], rtol=1e-5)

w_site_idx = eq_keys.index("w_site")
self.assertEqual(eq_body1[w_site_idx], -1)
self.assertEqual(eq_body2[w_site_idx], 1)
np.testing.assert_allclose(eq_anchors[w_site_idx], [0.0, 0.0, 0.0], rtol=1e-5)
self.assertAlmostEqual(eq_torquescale[w_site_idx], 0.1, places=5)

self.solver = newton.solvers.SolverMuJoCo(
self.model,
use_mujoco_cpu=True,
Expand All @@ -62,9 +79,7 @@ def test_multiple_constraints(self):

self.sim_time += self.frame_dt

self.assertGreater(
self.solver.mj_model.eq_type.shape[0], 0
) # check if number of equality constraints in mjModel > 0
self.assertEqual(self.solver.mj_model.eq_type.shape[0], 5)

# Check constraint violations
nefc = self.solver.mj_data.nefc # number of active constraints
Expand Down
Loading