diff --git a/newton/_src/solvers/mujoco/solver_mujoco.py b/newton/_src/solvers/mujoco/solver_mujoco.py index c72ee186ff..151339bfef 100644 --- a/newton/_src/solvers/mujoco/solver_mujoco.py +++ b/newton/_src/solvers/mujoco/solver_mujoco.py @@ -1685,14 +1685,20 @@ def add_geoms(newton_body_id: int): add_geoms(child) + def get_body_name(body_idx: int) -> str: + """Get body name, handling world body (-1) correctly.""" + if body_idx == -1: + return "world" + return model.body_key[body_idx] + for i in selected_constraints: constraint_type = eq_constraint_type[i] if constraint_type == EqType.CONNECT: eq = spec.add_equality(objtype=mujoco.mjtObj.mjOBJ_BODY) eq.type = mujoco.mjtEq.mjEQ_CONNECT eq.active = eq_constraint_enabled[i] - eq.name1 = model.body_key[eq_constraint_body1[i]] - eq.name2 = model.body_key[eq_constraint_body2[i]] + eq.name1 = get_body_name(eq_constraint_body1[i]) + eq.name2 = get_body_name(eq_constraint_body2[i]) eq.data[0:3] = eq_constraint_anchor[i] if eq_constraint_solref is not None: eq.solref = eq_constraint_solref[i] @@ -1711,8 +1717,8 @@ def add_geoms(newton_body_id: int): eq = spec.add_equality(objtype=mujoco.mjtObj.mjOBJ_BODY) eq.type = mujoco.mjtEq.mjEQ_WELD eq.active = eq_constraint_enabled[i] - eq.name1 = model.body_key[eq_constraint_body1[i]] - eq.name2 = model.body_key[eq_constraint_body2[i]] + eq.name1 = get_body_name(eq_constraint_body1[i]) + eq.name2 = get_body_name(eq_constraint_body2[i]) cns_relpose = wp.transform(*eq_constraint_relpose[i]) eq.data[0:3] = eq_constraint_anchor[i] eq.data[3:6] = wp.transform_get_translation(cns_relpose) diff --git a/newton/_src/utils/import_mjcf.py b/newton/_src/utils/import_mjcf.py index 2adec0ac37..2198283dca 100644 --- a/newton/_src/utils/import_mjcf.py +++ b/newton/_src/utils/import_mjcf.py @@ -26,7 +26,7 @@ from ..core import quat_between_axes, quat_from_euler from ..core.types import Axis, AxisType, Sequence, Transform -from ..geometry import MESH_MAXHULLVERT, Mesh +from ..geometry import MESH_MAXHULLVERT, Mesh, ShapeFlags from ..sim import JointType, ModelBuilder from ..sim.model import ModelAttributeFrequency from ..usd.schemas import solref_to_stiffness_damping @@ -968,6 +968,26 @@ def parse_common_attributes(element): "active": element.attrib.get("active", "true").lower() == "true", } + def get_site_body_and_anchor(site_name: str) -> tuple[int, wp.vec3] | None: + """Look up a site by name and return its body index and position (anchor). + + Returns: + Tuple of (body_idx, anchor_position) or None if site not found or not a site. + """ + if site_name not in builder.shape_key: + if verbose: + print(f"Warning: Site '{site_name}' not found") + return None + site_idx = builder.shape_key.index(site_name) + if not (builder.shape_flags[site_idx] & ShapeFlags.SITE): + if verbose: + print(f"Warning: Shape '{site_name}' is not a site") + return None + body_idx = builder.shape_body[site_idx] + site_xform = builder.shape_transform[site_idx] + anchor = wp.vec3(site_xform[0], site_xform[1], site_xform[2]) + return (body_idx, anchor) + for connect in equality.findall("connect"): common = parse_common_attributes(connect) custom_attrs = parse_custom_attributes(connect.attrib, builder_custom_attr_eq, parsing_mode="mjcf") @@ -976,8 +996,8 @@ def parse_common_attributes(element): connect.attrib.get("body2", "worldbody").replace("-", "_") if connect.attrib.get("body2") else None ) anchor = connect.attrib.get("anchor") - site1 = connect.attrib.get("site1") + site2 = connect.attrib.get("site2") if body1_name and anchor: if verbose: @@ -996,9 +1016,35 @@ def parse_common_attributes(element): enabled=common["active"], custom_attributes=custom_attrs, ) - - if site1: # Implement site-based connect after Newton supports sites - print("Warning: MuJoCo sites are not yet supported in Newton.") + elif site1: + if site2: + # Site-based connect: both site1 and site2 must be specified + site1_info = get_site_body_and_anchor(site1) + site2_info = get_site_body_and_anchor(site2) + if site1_info is None or site2_info is None: + if verbose: + print(f"Warning: Connect constraint '{common['name']}' failed.") + continue + body1_idx, anchor_vec = site1_info + body2_idx, _ = site2_info + if verbose: + print( + f"Connect constraint (site-based): site '{site1}' on body {body1_idx} to body {body2_idx}" + ) + builder.add_equality_constraint_connect( + body1=body1_idx, + body2=body2_idx, + anchor=anchor_vec, + key=common["name"], + enabled=common["active"], + custom_attributes=custom_attrs, + ) + else: + if verbose: + print( + f"Warning: Connect constraint '{common['name']}' has site1 but no site2. " + "When using sites, both site1 and site2 must be specified. Skipping." + ) for weld in equality.findall("weld"): common = parse_common_attributes(weld) @@ -1008,8 +1054,8 @@ def parse_common_attributes(element): anchor = weld.attrib.get("anchor", "0 0 0") relpose = weld.attrib.get("relpose", "0 1 0 0 0 0 0") torquescale = weld.attrib.get("torquescale") - site1 = weld.attrib.get("site1") + site2 = weld.attrib.get("site2") if body1_name: if verbose: @@ -1036,9 +1082,40 @@ def parse_common_attributes(element): enabled=common["active"], custom_attributes=custom_attrs, ) - - if site1: # Implement site-based weld after Newton supports sites - print("Warning: MuJoCo sites are not yet supported in Newton.") + elif site1: + if site2: + # Site-based weld: both site1 and site2 must be specified + site1_info = get_site_body_and_anchor(site1) + site2_info = get_site_body_and_anchor(site2) + if site1_info is None or site2_info is None: + if verbose: + print(f"Warning: Weld constraint '{common['name']}' failed.") + continue + body1_idx, _ = site1_info + body2_idx, anchor_vec = site2_info + relpose_list = [float(x) for x in relpose.split()] + relpose_transform = wp.transform( + wp.vec3(relpose_list[0], relpose_list[1], relpose_list[2]), + wp.quat(relpose_list[4], relpose_list[5], relpose_list[6], relpose_list[3]), + ) + if verbose: + print(f"Weld constraint (site-based): body {body1_idx} to body {body2_idx}") + builder.add_equality_constraint_weld( + body1=body1_idx, + body2=body2_idx, + anchor=anchor_vec, + relpose=relpose_transform, + torquescale=torquescale, + key=common["name"], + enabled=common["active"], + custom_attributes=custom_attrs, + ) + else: + if verbose: + print( + f"Warning: Weld constraint '{common['name']}' has site1 but no site2. " + "When using sites, both site1 and site2 must be specified. Skipping." + ) for joint in equality.findall("joint"): common = parse_common_attributes(joint) diff --git a/newton/tests/test_equality_constraints.py b/newton/tests/test_equality_constraints.py index ca929dcf05..1d77788849 100644 --- a/newton/tests/test_equality_constraints.py +++ b/newton/tests/test_equality_constraints.py @@ -39,6 +39,23 @@ def test_multiple_constraints(self): self.model = builder.finalize() + eq_keys = self.model.equality_constraint_key + eq_body1 = self.model.equality_constraint_body1.numpy() + eq_body2 = self.model.equality_constraint_body2.numpy() + eq_anchors = self.model.equality_constraint_anchor.numpy() + eq_torquescale = self.model.equality_constraint_torquescale.numpy() + + c_site_idx = eq_keys.index("c_site") + self.assertEqual(eq_body1[c_site_idx], -1) + self.assertEqual(eq_body2[c_site_idx], 0) + np.testing.assert_allclose(eq_anchors[c_site_idx], [0.0, 0.0, 1.0], rtol=1e-5) + + w_site_idx = eq_keys.index("w_site") + self.assertEqual(eq_body1[w_site_idx], -1) + self.assertEqual(eq_body2[w_site_idx], 1) + np.testing.assert_allclose(eq_anchors[w_site_idx], [0.0, 0.0, 0.0], rtol=1e-5) + self.assertAlmostEqual(eq_torquescale[w_site_idx], 0.1, places=5) + self.solver = newton.solvers.SolverMuJoCo( self.model, use_mujoco_cpu=True, @@ -62,9 +79,7 @@ def test_multiple_constraints(self): self.sim_time += self.frame_dt - self.assertGreater( - self.solver.mj_model.eq_type.shape[0], 0 - ) # check if number of equality constraints in mjModel > 0 + self.assertEqual(self.solver.mj_model.eq_type.shape[0], 5) # Check constraint violations nefc = self.solver.mj_data.nefc # number of active constraints