Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions genesis/engine/solvers/rigid/abd/forward_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,21 +861,22 @@ def func_factor_mass(
# FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them.
rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0

# Cache original pivot row values before modification
# Cache original pivot row values before modification.
# wave64: all threads in lockstep, no explicit sync needed.
if tid < i_d_:
sh_pivot[tid] = mass_mat[i_d_, tid]
# wave64: all threads in lockstep, no explicit sync needed

# Balanced rank-1 update: flatten (row, col) pairs across all threads
_n_updates = i_d_ * (i_d_ + 1) // 2
_upd = tid
while _upd < _n_updates:
_r = qd.cast((qd.sqrt(8.0 * qd.cast(_upd, gs.qd_float) + 1.0) - 1.0) * 0.5, qd.i32)
if _r * (_r + 1) // 2 > _upd:
_r = _r - 1
_c = _upd - _r * (_r + 1) // 2
mass_mat[_r, _c] = mass_mat[_r, _c] - sh_pivot[_r] * D_inv * sh_pivot[_c]
_upd = _upd + BLOCK_DIM

# Row-major rank-1 update: each thread owns rows [tid, tid+BLOCK_DIM, ...].
# Eliminates the per-update sqrt needed by the flat-index decode, and hoists the
# scaled pivot `sh_pivot[_r] * D_inv` out of the inner column loop (computed once per row).
_r = tid
while _r < i_d_:
piv_r = sh_pivot[_r] * D_inv
_c = 0
while _c <= _r:
mass_mat[_r, _c] = mass_mat[_r, _c] - piv_r * sh_pivot[_c]
_c = _c + 1
_r = _r + BLOCK_DIM

# Write L factors to pivot row
if tid < i_d_:
Expand Down
145 changes: 74 additions & 71 deletions genesis/engine/solvers/rigid/abd/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,75 +448,77 @@ def func_COM_links_entity(
rigid_global_info.EPS[None],
)

for i_l_ in (
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
):
i_l = i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e])

if func_check_index_range(i_l, entities_info.link_start[i_e], entities_info.link_end[i_e], BW):
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
# j_pos/j_quat are only read by the backward adjoint cache; skip in forward-only mode.
if qd.static(static_rigid_sim_config.requires_grad):
for i_l_ in (
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
):
i_l = i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e])

if links_info.n_dofs[I_l] > 0:
i_p = links_info.parent_idx[I_l]
if func_check_index_range(i_l, entities_info.link_start[i_e], entities_info.link_end[i_e], BW):
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l

_i_j = links_info.joint_start[I_l]
_I_j = [_i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else _i_j
joint_type = joints_info.type[_I_j]
if links_info.n_dofs[I_l] > 0:
i_p = links_info.parent_idx[I_l]

p_pos = qd.Vector.zero(gs.qd_float, 3)
p_quat = gu.qd_identity_quat()
if i_p != -1:
p_pos = links_state.pos[i_p, i_b]
p_quat = links_state.quat[i_p, i_b]
_i_j = links_info.joint_start[I_l]
_I_j = [_i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else _i_j
joint_type = joints_info.type[_I_j]

if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1):
links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b]
links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b]
else:
acc_pos, acc_quat = gu.qd_transform_pos_quat_by_trans_quat(
links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat,
)
if qd.static(BW):
links_state.j_pos_bw[i_l, 0, i_b] = acc_pos
links_state.j_quat_bw[i_l, 0, i_b] = acc_quat

n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]

for i_j_ in (
range(n_joints)
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_joints_per_link))
):
i_j = i_j_ + links_info.joint_start[I_l]

if func_check_index_range(
i_j,
links_info.joint_start[I_l],
links_info.joint_end[I_l],
BW,
):
I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j

if qd.static(not BW):
acc_pos = acc_pos + gu.qd_transform_by_quat(joints_info.pos[I_j], acc_quat)
else:
curr_i_j = i_j_
next_i_j = i_j_ + 1
prev_quat = links_state.j_quat_bw[i_l, curr_i_j, i_b]
links_state.j_pos_bw[i_l, next_i_j, i_b] = (
links_state.j_pos_bw[i_l, curr_i_j, i_b]
+ gu.qd_transform_by_quat(joints_info.pos[I_j], prev_quat)
)
links_state.j_quat_bw[i_l, next_i_j, i_b] = prev_quat
p_pos = qd.Vector.zero(gs.qd_float, 3)
p_quat = gu.qd_identity_quat()
if i_p != -1:
p_pos = links_state.pos[i_p, i_b]
p_quat = links_state.quat[i_p, i_b]

if qd.static(not BW):
links_state.j_pos[i_l, i_b] = acc_pos
links_state.j_quat[i_l, i_b] = acc_quat
if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1):
links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b]
links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b]
else:
links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, n_joints, i_b]
links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, n_joints, i_b]
acc_pos, acc_quat = gu.qd_transform_pos_quat_by_trans_quat(
links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat,
)
if qd.static(BW):
links_state.j_pos_bw[i_l, 0, i_b] = acc_pos
links_state.j_quat_bw[i_l, 0, i_b] = acc_quat

n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]

for i_j_ in (
range(n_joints)
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_joints_per_link))
):
i_j = i_j_ + links_info.joint_start[I_l]

if func_check_index_range(
i_j,
links_info.joint_start[I_l],
links_info.joint_end[I_l],
BW,
):
I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j

if qd.static(not BW):
acc_pos = acc_pos + gu.qd_transform_by_quat(joints_info.pos[I_j], acc_quat)
else:
curr_i_j = i_j_
next_i_j = i_j_ + 1
prev_quat = links_state.j_quat_bw[i_l, curr_i_j, i_b]
links_state.j_pos_bw[i_l, next_i_j, i_b] = (
links_state.j_pos_bw[i_l, curr_i_j, i_b]
+ gu.qd_transform_by_quat(joints_info.pos[I_j], prev_quat)
)
links_state.j_quat_bw[i_l, next_i_j, i_b] = prev_quat

if qd.static(not BW):
links_state.j_pos[i_l, i_b] = acc_pos
links_state.j_quat[i_l, i_b] = acc_quat
else:
links_state.j_pos[i_l, i_b] = links_state.j_pos_bw[i_l, n_joints, i_b]
links_state.j_quat[i_l, i_b] = links_state.j_quat_bw[i_l, n_joints, i_b]

for i_l_ in (
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
Expand Down Expand Up @@ -936,14 +938,15 @@ def func_com_links_split(
i_l, i_b, links_state, links_info, rigid_global_info, static_rigid_sim_config,
)

# Pass 6: per-link joint pose (`j_pos`/`j_quat`). Only reads FK outputs, so
# this pass has no data dependency on passes 1-5 and could in principle run
# earlier, but we keep it here to minimize structural churn.
qd.loop_config(serialize=serialize, block_dim=64)
for i_l, i_b in qd.ndrange(n_links, _B):
func_com_pass6_joint_pose_link(
i_l, i_b, links_state, links_info, joints_info, static_rigid_sim_config, is_backward,
)
# Pass 6: per-link joint pose (`j_pos`/`j_quat`). These values are only read
# by the backward adjoint cache (`func_copy_cartesian_space`), so the entire
# pass is dead work in pure forward simulation. Skip it unless `requires_grad` is set.
if qd.static(static_rigid_sim_config.requires_grad):
qd.loop_config(serialize=serialize, block_dim=64)
for i_l, i_b in qd.ndrange(n_links, _B):
func_com_pass6_joint_pose_link(
i_l, i_b, links_state, links_info, joints_info, static_rigid_sim_config, is_backward,
)

# Pass 7: per-link motion subspace (`cdof_*`/`cdofvel_*`); reads pass-4
# `root_COM` at the joint anchor.
Expand Down
Loading