Skip to content

Commit c67b6fb

Browse files
authored
Site based constraints (#1291)
1 parent ab82975 commit c67b6fb

File tree

3 files changed

+114
-16
lines changed

3 files changed

+114
-16
lines changed

newton/_src/solvers/mujoco/solver_mujoco.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,14 +1961,20 @@ def add_geoms(newton_body_id: int):
19611961

19621962
add_geoms(child)
19631963

1964+
def get_body_name(body_idx: int) -> str:
1965+
"""Get body name, handling world body (-1) correctly."""
1966+
if body_idx == -1:
1967+
return "world"
1968+
return model.body_key[body_idx]
1969+
19641970
for i in selected_constraints:
19651971
constraint_type = eq_constraint_type[i]
19661972
if constraint_type == EqType.CONNECT:
19671973
eq = spec.add_equality(objtype=mujoco.mjtObj.mjOBJ_BODY)
19681974
eq.type = mujoco.mjtEq.mjEQ_CONNECT
19691975
eq.active = eq_constraint_enabled[i]
1970-
eq.name1 = model.body_key[eq_constraint_body1[i]]
1971-
eq.name2 = model.body_key[eq_constraint_body2[i]]
1976+
eq.name1 = get_body_name(eq_constraint_body1[i])
1977+
eq.name2 = get_body_name(eq_constraint_body2[i])
19721978
eq.data[0:3] = eq_constraint_anchor[i]
19731979
if eq_constraint_solref is not None:
19741980
eq.solref = eq_constraint_solref[i]
@@ -1987,8 +1993,8 @@ def add_geoms(newton_body_id: int):
19871993
eq = spec.add_equality(objtype=mujoco.mjtObj.mjOBJ_BODY)
19881994
eq.type = mujoco.mjtEq.mjEQ_WELD
19891995
eq.active = eq_constraint_enabled[i]
1990-
eq.name1 = model.body_key[eq_constraint_body1[i]]
1991-
eq.name2 = model.body_key[eq_constraint_body2[i]]
1996+
eq.name1 = get_body_name(eq_constraint_body1[i])
1997+
eq.name2 = get_body_name(eq_constraint_body2[i])
19921998
cns_relpose = wp.transform(*eq_constraint_relpose[i])
19931999
eq.data[0:3] = eq_constraint_anchor[i]
19942000
eq.data[3:6] = wp.transform_get_translation(cns_relpose)

newton/_src/utils/import_mjcf.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from ..core import quat_between_axes, quat_from_euler
2828
from ..core.types import Axis, AxisType, Sequence, Transform
29-
from ..geometry import MESH_MAXHULLVERT, Mesh
29+
from ..geometry import MESH_MAXHULLVERT, Mesh, ShapeFlags
3030
from ..sim import JointType, ModelBuilder
3131
from ..sim.model import ModelAttributeFrequency
3232
from ..usd.schemas import solref_to_stiffness_damping
@@ -968,6 +968,26 @@ def parse_common_attributes(element):
968968
"active": element.attrib.get("active", "true").lower() == "true",
969969
}
970970

971+
def get_site_body_and_anchor(site_name: str) -> tuple[int, wp.vec3] | None:
972+
"""Look up a site by name and return its body index and position (anchor).
973+
974+
Returns:
975+
Tuple of (body_idx, anchor_position) or None if site not found or not a site.
976+
"""
977+
if site_name not in builder.shape_key:
978+
if verbose:
979+
print(f"Warning: Site '{site_name}' not found")
980+
return None
981+
site_idx = builder.shape_key.index(site_name)
982+
if not (builder.shape_flags[site_idx] & ShapeFlags.SITE):
983+
if verbose:
984+
print(f"Warning: Shape '{site_name}' is not a site")
985+
return None
986+
body_idx = builder.shape_body[site_idx]
987+
site_xform = builder.shape_transform[site_idx]
988+
anchor = wp.vec3(site_xform[0], site_xform[1], site_xform[2])
989+
return (body_idx, anchor)
990+
971991
for connect in equality.findall("connect"):
972992
common = parse_common_attributes(connect)
973993
custom_attrs = parse_custom_attributes(connect.attrib, builder_custom_attr_eq, parsing_mode="mjcf")
@@ -976,8 +996,8 @@ def parse_common_attributes(element):
976996
connect.attrib.get("body2", "worldbody").replace("-", "_") if connect.attrib.get("body2") else None
977997
)
978998
anchor = connect.attrib.get("anchor")
979-
980999
site1 = connect.attrib.get("site1")
1000+
site2 = connect.attrib.get("site2")
9811001

9821002
if body1_name and anchor:
9831003
if verbose:
@@ -996,9 +1016,35 @@ def parse_common_attributes(element):
9961016
enabled=common["active"],
9971017
custom_attributes=custom_attrs,
9981018
)
999-
1000-
if site1: # Implement site-based connect after Newton supports sites
1001-
print("Warning: MuJoCo sites are not yet supported in Newton.")
1019+
elif site1:
1020+
if site2:
1021+
# Site-based connect: both site1 and site2 must be specified
1022+
site1_info = get_site_body_and_anchor(site1)
1023+
site2_info = get_site_body_and_anchor(site2)
1024+
if site1_info is None or site2_info is None:
1025+
if verbose:
1026+
print(f"Warning: Connect constraint '{common['name']}' failed.")
1027+
continue
1028+
body1_idx, anchor_vec = site1_info
1029+
body2_idx, _ = site2_info
1030+
if verbose:
1031+
print(
1032+
f"Connect constraint (site-based): site '{site1}' on body {body1_idx} to body {body2_idx}"
1033+
)
1034+
builder.add_equality_constraint_connect(
1035+
body1=body1_idx,
1036+
body2=body2_idx,
1037+
anchor=anchor_vec,
1038+
key=common["name"],
1039+
enabled=common["active"],
1040+
custom_attributes=custom_attrs,
1041+
)
1042+
else:
1043+
if verbose:
1044+
print(
1045+
f"Warning: Connect constraint '{common['name']}' has site1 but no site2. "
1046+
"When using sites, both site1 and site2 must be specified. Skipping."
1047+
)
10021048

10031049
for weld in equality.findall("weld"):
10041050
common = parse_common_attributes(weld)
@@ -1008,8 +1054,8 @@ def parse_common_attributes(element):
10081054
anchor = weld.attrib.get("anchor", "0 0 0")
10091055
relpose = weld.attrib.get("relpose", "0 1 0 0 0 0 0")
10101056
torquescale = weld.attrib.get("torquescale")
1011-
10121057
site1 = weld.attrib.get("site1")
1058+
site2 = weld.attrib.get("site2")
10131059

10141060
if body1_name:
10151061
if verbose:
@@ -1036,9 +1082,40 @@ def parse_common_attributes(element):
10361082
enabled=common["active"],
10371083
custom_attributes=custom_attrs,
10381084
)
1039-
1040-
if site1: # Implement site-based weld after Newton supports sites
1041-
print("Warning: MuJoCo sites are not yet supported in Newton.")
1085+
elif site1:
1086+
if site2:
1087+
# Site-based weld: both site1 and site2 must be specified
1088+
site1_info = get_site_body_and_anchor(site1)
1089+
site2_info = get_site_body_and_anchor(site2)
1090+
if site1_info is None or site2_info is None:
1091+
if verbose:
1092+
print(f"Warning: Weld constraint '{common['name']}' failed.")
1093+
continue
1094+
body1_idx, _ = site1_info
1095+
body2_idx, anchor_vec = site2_info
1096+
relpose_list = [float(x) for x in relpose.split()]
1097+
relpose_transform = wp.transform(
1098+
wp.vec3(relpose_list[0], relpose_list[1], relpose_list[2]),
1099+
wp.quat(relpose_list[4], relpose_list[5], relpose_list[6], relpose_list[3]),
1100+
)
1101+
if verbose:
1102+
print(f"Weld constraint (site-based): body {body1_idx} to body {body2_idx}")
1103+
builder.add_equality_constraint_weld(
1104+
body1=body1_idx,
1105+
body2=body2_idx,
1106+
anchor=anchor_vec,
1107+
relpose=relpose_transform,
1108+
torquescale=torquescale,
1109+
key=common["name"],
1110+
enabled=common["active"],
1111+
custom_attributes=custom_attrs,
1112+
)
1113+
else:
1114+
if verbose:
1115+
print(
1116+
f"Warning: Weld constraint '{common['name']}' has site1 but no site2. "
1117+
"When using sites, both site1 and site2 must be specified. Skipping."
1118+
)
10421119

10431120
for joint in equality.findall("joint"):
10441121
common = parse_common_attributes(joint)

newton/tests/test_equality_constraints.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ def test_multiple_constraints(self):
3939

4040
self.model = builder.finalize()
4141

42+
eq_keys = self.model.equality_constraint_key
43+
eq_body1 = self.model.equality_constraint_body1.numpy()
44+
eq_body2 = self.model.equality_constraint_body2.numpy()
45+
eq_anchors = self.model.equality_constraint_anchor.numpy()
46+
eq_torquescale = self.model.equality_constraint_torquescale.numpy()
47+
48+
c_site_idx = eq_keys.index("c_site")
49+
self.assertEqual(eq_body1[c_site_idx], -1)
50+
self.assertEqual(eq_body2[c_site_idx], 0)
51+
np.testing.assert_allclose(eq_anchors[c_site_idx], [0.0, 0.0, 1.0], rtol=1e-5)
52+
53+
w_site_idx = eq_keys.index("w_site")
54+
self.assertEqual(eq_body1[w_site_idx], -1)
55+
self.assertEqual(eq_body2[w_site_idx], 1)
56+
np.testing.assert_allclose(eq_anchors[w_site_idx], [0.0, 0.0, 0.0], rtol=1e-5)
57+
self.assertAlmostEqual(eq_torquescale[w_site_idx], 0.1, places=5)
58+
4259
self.solver = newton.solvers.SolverMuJoCo(
4360
self.model,
4461
use_mujoco_cpu=True,
@@ -62,9 +79,7 @@ def test_multiple_constraints(self):
6279

6380
self.sim_time += self.frame_dt
6481

65-
self.assertGreater(
66-
self.solver.mj_model.eq_type.shape[0], 0
67-
) # check if number of equality constraints in mjModel > 0
82+
self.assertEqual(self.solver.mj_model.eq_type.shape[0], 5)
6883

6984
# Check constraint violations
7085
nefc = self.solver.mj_data.nefc # number of active constraints

0 commit comments

Comments
 (0)