Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions newton/_src/sim/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def __init__(self, up_axis: AxisType = Axis.Z, gravity: float = -9.81):
self.joint_X_c = [] # frame of child com (in child coordinates) (constant)
self.joint_q = []
self.joint_qd = []
self.joint_act = []
self.joint_f = []

self.joint_type = []
Expand Down Expand Up @@ -1101,6 +1102,7 @@ def transform_mul(a, b):
"joint_dof_mode",
"joint_key",
"joint_qd",
"joint_act",
"joint_f",
"joint_target",
"joint_limit_lower",
Expand Down Expand Up @@ -1336,6 +1338,7 @@ def add_axis_dim(dim: ModelBuilder.JointDofConfig):
self.joint_q.append(0.0)
for _ in range(dof_count):
self.joint_qd.append(0.0)
self.joint_act.append(0.0)
self.joint_f.append(0.0)

if joint_type == JointType.FREE or joint_type == JointType.DISTANCE or joint_type == JointType.BALL:
Expand Down Expand Up @@ -2084,6 +2087,7 @@ def collapse_fixed_joints(self, verbose=wp.config.verbose):
"type": self.joint_type[i],
"q": self.joint_q[q_start : q_start + q_dim],
"qd": self.joint_qd[qd_start : qd_start + qd_dim],
"act": self.joint_act[qd_start : qd_start + qd_dim],
"armature": self.joint_armature[qd_start : qd_start + qd_dim],
"q_start": q_start,
"qd_start": qd_start,
Expand Down Expand Up @@ -2277,6 +2281,7 @@ def dfs(parent_body: int, child_body: int, incoming_xform: wp.transform, last_dy
self.joint_child.clear()
self.joint_q.clear()
self.joint_qd.clear()
self.joint_act.clear()
self.joint_q_start.clear()
self.joint_qd_start.clear()
self.joint_enabled.clear()
Expand Down Expand Up @@ -4331,6 +4336,7 @@ def finalize(self, device: Devicelike | None = None, requires_grad: bool = False
m.joint_target_kd = wp.array(self.joint_target_kd, dtype=wp.float32, requires_grad=requires_grad)
m.joint_dof_mode = wp.array(self.joint_dof_mode, dtype=wp.int32)
m.joint_target = wp.array(self.joint_target, dtype=wp.float32, requires_grad=requires_grad)
m.joint_act = wp.array(self.joint_act, dtype=wp.float32, requires_grad=requires_grad)
m.joint_f = wp.array(self.joint_f, dtype=wp.float32, requires_grad=requires_grad)
m.joint_effort_limit = wp.array(self.joint_effort_limit, dtype=wp.float32, requires_grad=requires_grad)
m.joint_velocity_limit = wp.array(self.joint_velocity_limit, dtype=wp.float32, requires_grad=requires_grad)
Expand Down
4 changes: 4 additions & 0 deletions newton/_src/sim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def __init__(self, device: Devicelike | None = None):
"""Generalized joint positions for state initialization, shape [joint_coord_count], float."""
self.joint_qd = None
"""Generalized joint velocities for state initialization, shape [joint_dof_count], float."""
self.joint_act = None
"""Generalized joint actuation for state initialization, shape [joint_dof_count], float."""
self.joint_f = None
"""Generalized joint forces for state initialization, shape [joint_dof_count], float."""
self.joint_target = None
Expand Down Expand Up @@ -405,6 +407,7 @@ def __init__(self, device: Devicelike | None = None):

# attributes per joint dof
self.attribute_frequency["joint_qd"] = "joint_dof"
self.attribute_frequency["joint_act"] = "joint_dof"
self.attribute_frequency["joint_f"] = "joint_dof"
self.attribute_frequency["joint_armature"] = "joint_dof"
self.attribute_frequency["joint_target"] = "joint_dof"
Expand Down Expand Up @@ -470,6 +473,7 @@ def state(self, requires_grad: bool | None = None) -> State:
if self.joint_count:
s.joint_q = wp.clone(self.joint_q, requires_grad=requires_grad)
s.joint_qd = wp.clone(self.joint_qd, requires_grad=requires_grad)
s.joint_act = wp.clone(self.joint_act, requires_grad=requires_grad)

return s

Expand Down
3 changes: 3 additions & 0 deletions newton/_src/sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __init__(self) -> None:
self.joint_qd: wp.array | None = None
"""Generalized joint velocity coordinates, shape (joint_dof_count,), dtype float."""

self.joint_act: wp.array | None = None
"""Generalized joint actuation, shape (joint_dof_count,), dtype float."""

def clear_forces(self) -> None:
"""
Clear all force arrays (for particles and bodies) in the state object.
Expand Down
40 changes: 38 additions & 2 deletions newton/_src/solvers/mujoco/solver_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def convert_newton_contacts_to_mjwarp_kernel(
def convert_mj_coords_to_warp_kernel(
qpos: wp.array2d(dtype=wp.float32),
qvel: wp.array2d(dtype=wp.float32),
qact: wp.array2d(dtype=wp.float32),
joints_per_env: int,
up_axis: int,
joint_type: wp.array(dtype=wp.int32),
Expand All @@ -345,6 +346,7 @@ def convert_mj_coords_to_warp_kernel(
# outputs
joint_q: wp.array(dtype=wp.float32),
joint_qd: wp.array(dtype=wp.float32),
joint_act: wp.array(dtype=wp.float32),
):
worldid, jntid = wp.tid()

Expand Down Expand Up @@ -383,6 +385,14 @@ def convert_mj_coords_to_warp_kernel(
joint_qd[wqd_i + 3] = w[0]
joint_qd[wqd_i + 4] = w[1]
joint_qd[wqd_i + 5] = w[2]

joint_act[wqd_i + 0] = qact[worldid, q_i + 0]
joint_act[wqd_i + 1] = qact[worldid, q_i + 1]
joint_act[wqd_i + 2] = qact[worldid, q_i + 2]
joint_act[wqd_i + 3] = qact[worldid, q_i + 3]
joint_act[wqd_i + 4] = qact[worldid, q_i + 4]
joint_act[wqd_i + 5] = qact[worldid, q_i + 5]

Comment on lines +389 to +395
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix qact indexing for free joints.

qact/qfrc_actuator follow the velocity layout (stride = nv), but here we index them with q_i (qpos stride). On the second and later free joints this walks past the end of the actuator array, corrupting data. Use qd_i just like we already do for velocities.

Apply this diff:

-        joint_act[wqd_i + 0] = qact[worldid, q_i + 0]
-        joint_act[wqd_i + 1] = qact[worldid, q_i + 1]
-        joint_act[wqd_i + 2] = qact[worldid, q_i + 2]
-        joint_act[wqd_i + 3] = qact[worldid, q_i + 3]
-        joint_act[wqd_i + 4] = qact[worldid, q_i + 4]
-        joint_act[wqd_i + 5] = qact[worldid, q_i + 5]
+        joint_act[wqd_i + 0] = qact[worldid, qd_i + 0]
+        joint_act[wqd_i + 1] = qact[worldid, qd_i + 1]
+        joint_act[wqd_i + 2] = qact[worldid, qd_i + 2]
+        joint_act[wqd_i + 3] = qact[worldid, qd_i + 3]
+        joint_act[wqd_i + 4] = qact[worldid, qd_i + 4]
+        joint_act[wqd_i + 5] = qact[worldid, qd_i + 5]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
joint_act[wqd_i + 0] = qact[worldid, q_i + 0]
joint_act[wqd_i + 1] = qact[worldid, q_i + 1]
joint_act[wqd_i + 2] = qact[worldid, q_i + 2]
joint_act[wqd_i + 3] = qact[worldid, q_i + 3]
joint_act[wqd_i + 4] = qact[worldid, q_i + 4]
joint_act[wqd_i + 5] = qact[worldid, q_i + 5]
joint_act[wqd_i + 0] = qact[worldid, qd_i + 0]
joint_act[wqd_i + 1] = qact[worldid, qd_i + 1]
joint_act[wqd_i + 2] = qact[worldid, qd_i + 2]
joint_act[wqd_i + 3] = qact[worldid, qd_i + 3]
joint_act[wqd_i + 4] = qact[worldid, qd_i + 4]
joint_act[wqd_i + 5] = qact[worldid, qd_i + 5]

elif type == JointType.BALL:
# change quaternion order from wxyz to xyzw
rot = wp.quat(
Expand All @@ -398,6 +408,8 @@ def convert_mj_coords_to_warp_kernel(
for i in range(3):
# convert velocity components
joint_qd[wqd_i + i] = qvel[worldid, qd_i + i]
# convert act components
joint_act[wqd_i + i] = qact[worldid, qd_i + i]
else:
axis_count = joint_dof_dim[jntid, 0] + joint_dof_dim[jntid, 1]
for i in range(axis_count):
Expand All @@ -406,12 +418,15 @@ def convert_mj_coords_to_warp_kernel(
for i in range(axis_count):
# convert velocity components
joint_qd[wqd_i + i] = qvel[worldid, qd_i + i]
# convert act components
joint_act[wqd_i + i] = qact[worldid, qd_i + i]


@wp.kernel
def convert_warp_coords_to_mj_kernel(
joint_q: wp.array(dtype=wp.float32),
joint_qd: wp.array(dtype=wp.float32),
joint_act: wp.array(dtype=wp.float32),
joints_per_env: int,
up_axis: int,
joint_type: wp.array(dtype=wp.int32),
Expand All @@ -421,6 +436,7 @@ def convert_warp_coords_to_mj_kernel(
# outputs
qpos: wp.array2d(dtype=wp.float32),
qvel: wp.array2d(dtype=wp.float32),
qact: wp.array2d(dtype=wp.float32),
):
worldid, jntid = wp.tid()

Expand Down Expand Up @@ -460,6 +476,13 @@ def convert_warp_coords_to_mj_kernel(
qvel[worldid, qd_i + 4] = w[1]
qvel[worldid, qd_i + 5] = w[2]

qact[worldid, qd_i + 0] = joint_act[wqd_i + 0]
qact[worldid, qd_i + 1] = joint_act[wqd_i + 1]
qact[worldid, qd_i + 2] = joint_act[wqd_i + 2]
qact[worldid, qd_i + 3] = joint_act[wqd_i + 3]
qact[worldid, qd_i + 4] = joint_act[wqd_i + 4]
qact[worldid, qd_i + 5] = joint_act[wqd_i + 5]

elif type == JointType.BALL:
# change quaternion order from xyzw to wxyz
qpos[worldid, q_i + 0] = joint_q[wq_i + 1]
Expand All @@ -469,6 +492,8 @@ def convert_warp_coords_to_mj_kernel(
for i in range(3):
# convert velocity components
qvel[worldid, qd_i + i] = joint_qd[wqd_i + i]
# convert act components
qact[worldid, qd_i + i] = joint_act[wqd_i + i]
else:
axis_count = joint_dof_dim[jntid, 0] + joint_dof_dim[jntid, 1]
for i in range(axis_count):
Expand All @@ -477,6 +502,8 @@ def convert_warp_coords_to_mj_kernel(
for i in range(axis_count):
# convert velocity components
qvel[worldid, qd_i + i] = joint_qd[wqd_i + i]
# convert act components
qact[worldid, qd_i + i] = joint_act[wqd_i + i]


@wp.kernel
Expand Down Expand Up @@ -1499,38 +1526,44 @@ def update_mjc_data(self, mj_data: MjWarpData | MjData, model: Model, state: Sta
# we have an MjWarp Data object
qpos = mj_data.qpos
qvel = mj_data.qvel
qact = mj_data.qact
nworld = mj_data.nworld
else:
# we have an MjData object from Mujoco
qpos = wp.empty((1, model.joint_coord_count), dtype=wp.float32, device=model.device)
qvel = wp.empty((1, model.joint_dof_count), dtype=wp.float32, device=model.device)
qact = wp.empty((1, model.joint_dof_count), dtype=wp.float32, device=model.device)
nworld = 1
if state is None:
joint_q = model.joint_q
joint_qd = model.joint_qd
joint_act = model.joint_act
else:
joint_q = state.joint_q
joint_qd = state.joint_qd
joint_act = state.joint_act
joints_per_env = model.joint_count // nworld
wp.launch(
convert_warp_coords_to_mj_kernel,
dim=(nworld, joints_per_env),
inputs=[
joint_q,
joint_qd,
joint_act,
joints_per_env,
model.up_axis,
model.joint_type,
model.joint_q_start,
model.joint_qd_start,
model.joint_dof_dim,
],
outputs=[qpos, qvel],
outputs=[qpos, qvel, qact],
device=model.device,
)
if not is_mjwarp:
mj_data.qpos[:] = qpos.numpy().flatten()[: len(mj_data.qpos)]
mj_data.qvel[:] = qvel.numpy().flatten()[: len(mj_data.qvel)]
mj_data.qfrc_actuator[:] = qact.numpy().flatten()[: len(mj_data.qfrc_actuator)]

def update_newton_state(
self,
Expand All @@ -1544,6 +1577,7 @@ def update_newton_state(
# we have an MjWarp Data object
qpos = mj_data.qpos
qvel = mj_data.qvel
qact = mj_data.qfrc_actuator
nworld = mj_data.nworld

xpos = mj_data.xpos
Expand All @@ -1552,6 +1586,7 @@ def update_newton_state(
# we have an MjData object from Mujoco
qpos = wp.array([mj_data.qpos], dtype=wp.float32, device=model.device)
qvel = wp.array([mj_data.qvel], dtype=wp.float32, device=model.device)
qact = wp.array([mj_data.qfrc_actuator], dtype=wp.float32, device=model.device)
nworld = 1

xpos = wp.array([mj_data.xpos], dtype=wp.vec3, device=model.device)
Expand All @@ -1563,14 +1598,15 @@ def update_newton_state(
inputs=[
qpos,
qvel,
qact,
joints_per_env,
int(model.up_axis),
model.joint_type,
model.joint_q_start,
model.joint_qd_start,
model.joint_dof_dim,
],
outputs=[state.joint_q, state.joint_qd],
outputs=[state.joint_q, state.joint_qd, state.joint_act],
device=model.device,
)

Expand Down
Loading