
perf(abd): skip j_pos/j_quat forward pass, row-major rank-1 update, widen constraint grid launches #65

Open
peizhang56 wants to merge 5 commits into amd-integration from perf/task-c-twc-envs-per-block-16

Conversation

@peizhang56 peizhang56 commented May 20, 2026

Three performance changes:

  1. Skip j_pos/j_quat in forward-only mode (forward_kinematics.py): gate the j_pos/j_quat write loop in func_COM_links_entity on requires_grad, because those fields are only needed by the backward adjoint cache.
  2. Row-major rank-1 update in func_factor_mass (forward_dynamics.py): replace the flat-index decode (which required a sqrt per update) with a row-major loop where each thread owns rows [tid, tid+BLOCK_DIM, ...]. Same factorization, fewer instructions.
  3. Widen constraint kernel launches (constraint/solver.py):
    • Extract the dense qfrc_constraint = J^T @ efc_force gather from the per-env loop into a new func_update_qfrc_constraint_dense 2D (n_dofs, _B) kernel; add defer_dense_qfrc template flag so per-iter callers keep the gather inline.
    • In func_update_gradient_tiled, promote the inner entity loop into a 2D (n_entities_, _B) ndrange (non-hibernation path only), widening dispatch without changing total work.
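
A minimal sketch of the dense gather split in item 3, written in plain Python rather than Quadrants. The function names, the `jac[i_c][i_d][i_b]` layout, and the toy sizes are assumptions made for illustration and are not taken from solver.py. It only shows that moving the (dof, env) pair onto the parallel axis leaves the result unchanged, because each output element has exactly one writer.

```python
from itertools import product


def qfrc_per_env(jac, efc_force, n_constraints, n_dofs, n_envs):
    """Per-env shape: the only parallel axis is the env index i_b."""
    qfrc = [[0.0] * n_envs for _ in range(n_dofs)]
    for i_b in range(n_envs):  # n_envs work items
        for i_d in range(n_dofs):  # serial inside each work item
            acc = 0.0
            for i_c in range(n_constraints[i_b]):
                acc += jac[i_c][i_d][i_b] * efc_force[i_c][i_b]
            qfrc[i_d][i_b] = acc
    return qfrc


def qfrc_dense_2d(jac, efc_force, n_constraints, n_dofs, n_envs):
    """Widened shape: one work item per (i_d, i_b) pair."""
    qfrc = [[0.0] * n_envs for _ in range(n_dofs)]
    for i_d, i_b in product(range(n_dofs), range(n_envs)):  # n_dofs * n_envs items
        acc = 0.0
        for i_c in range(n_constraints[i_b]):
            acc += jac[i_c][i_d][i_b] * efc_force[i_c][i_b]
        qfrc[i_d][i_b] = acc  # this iteration is the sole writer of the element
    return qfrc


# Hypothetical toy problem: 3 DOFs, 2 envs, up to 4 constraint rows per env.
N_DOFS, N_ENVS, MAX_C = 3, 2, 4
N_CONSTRAINTS = [4, 2]  # ragged: env 1 only uses 2 rows
JAC = [
    [[0.5 * (i_c + 1) - i_d + 0.25 * i_b for i_b in range(N_ENVS)]
     for i_d in range(N_DOFS)]
    for i_c in range(MAX_C)
]
EFC_FORCE = [[float(i_c - i_b) for i_b in range(N_ENVS)] for i_c in range(MAX_C)]

REF = qfrc_per_env(JAC, EFC_FORCE, N_CONSTRAINTS, N_DOFS, N_ENVS)
WIDE = qfrc_dense_2d(JAC, EFC_FORCE, N_CONSTRAINTS, N_DOFS, N_ENVS)
print(REF == WIDE)
```

The sparse path does not fit this pattern: a scatter over constraint rows would have several work items adding into the same `qfrc_constraint[i_d, i_b]`, which is why the commit message keeps it in the 1D per-env loop.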

@peizhang56
Author

/run-ci

@peizhang56 peizhang56 force-pushed the perf/task-c-twc-envs-per-block-16 branch from 5016e34 to 628a6ce May 20, 2026 19:27
@peizhang56 peizhang56 changed the title from "Raise _TWC_ENVS_PER_BLOCK and _TWC_BLOCK_DIM on tiled-wc constraint solver" to "widen init kernel grids on tiled-wc solver" May 20, 2026
@peizhang56 peizhang56 changed the title from "widen init kernel grids on tiled-wc solver" to "perf(amdgpu): widen init kernel grids on tiled-wc solver" May 20, 2026
@peizhang56
Author

/run-ci

@peizhang56
Author

/run-ci

@peizhang56 peizhang56 force-pushed the perf/task-c-twc-envs-per-block-16 branch from 8569f3b to ae7b8de May 29, 2026 05:38
@npoulad1
Collaborator

npoulad1 commented Jun 2, 2026

/run-ci

peizhang56 and others added 2 commits June 3, 2026 13:32
Restructures four GRID-STARVED kernels in the AMDGPU constraint solver
init path. Each was launching with too few workgroups to fill the
MI300X CU array because the parallel axis was just the per-env batch.
This commit widens the launch geometry of each via a different lever;
kernel bodies are unchanged.

1. tiled-wc block-shape constants (solver_amdgpu.py)
   _TWC_BLOCK_DIM        64 -> 128
   _TWC_ENVS_PER_BLOCK    8 -> 16
   Doubles threads-per-block and envs-per-block for the tiled
   wave-cooperative variant, so each block does more work and the
   per-env constraint loops draw from a larger lane pool.

2. initialize_Jaref (solver.py)
   Was a 1D loop `for i_b in range(_B)` with an inner serial
   `for i_c in range(n_constraints[i_b])`. Rewritten as a 2D ndrange
   `for i_c, i_b in qd.ndrange(len_constraints, _B)` with an
   `if i_c >= n_constraints[i_b]: continue` guard for the ragged
   tail. Grid width grows from _B to len_constraints * _B.

3. CG mass-solve in func_update_gradient_tiled (kernel_8)
   When hibernation is disabled (compile-time-known via the
   `use_hibernation` template flag), the inner serial loop over
   entities is promoted into the parallel grid: the old
   `for i_b in range(_B)` calling `func_solve_mass_batch` is
   replaced by `for i_e, i_b in qd.ndrange(n_entities_, _B)` calling
   `func_solve_mass_entity` directly. The entity body is already
   guarded by `mass_mat_mask[i_e, i_b]`, so zero-DOF entities (e.g.
   Plane) become near-no-op threads -- same total work, wider
   dispatch. Hibernation path keeps the original 1D form because
   n_awake_entities is dynamic.

4. Dense qfrc gather kernel_5 (solver.py)
   The dense `qfrc_constraint = J^T @ efc_force` gather was nested
   inside `func_update_constraint_batch`'s per-env loop. Split into
   a new `func_update_qfrc_constraint_dense` 2D kernel over
   (n_dofs, _B). Routing controlled by a new `defer_dense_qfrc`
   template flag on `func_update_constraint_batch`:
     - True  (set by func_update_constraint init caller): skip the
             inline dense gather, follow up with the 2D kernel.
     - False (set by func_solve_iter and
             func_solve_iter_post_linesearch per-iter callers):
             keep the gather inline. Those callers run inside a
             per-env loop with no follow-up dispatch site, so
             deferring there would leave qfrc_constraint stale and
             NaN out the gradient.
   Sparse path is unchanged -- its scatter would race on
   `qfrc_constraint[i_d, i_b]` without atomic-add, so it stays
   inside the 1D per-env loop.
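
The ragged-tail guard from item 2 can be sketched in plain Python as follows. This is an illustration under assumptions, not the solver code: `itertools.product` stands in for `qd.ndrange`, and the helper names and per-env counts are hypothetical. It checks that the guarded 2D form visits exactly the same (i_c, i_b) pairs as the nested serial loop while exposing len_constraints * _B work items instead of _B.

```python
from itertools import product


def visit_serial(n_constraints):
    """1D form: one work item per env, serial loop over its constraints."""
    visited = []
    for i_b in range(len(n_constraints)):
        for i_c in range(n_constraints[i_b]):
            visited.append((i_c, i_b))
    return visited


def visit_2d(n_constraints, len_constraints):
    """2D form: one work item per (i_c, i_b), with a guard for the ragged tail."""
    visited = []
    for i_c, i_b in product(range(len_constraints), range(len(n_constraints))):
        if i_c >= n_constraints[i_b]:
            continue  # past this env's last active constraint: no-op item
        visited.append((i_c, i_b))
    return visited


# Hypothetical per-env active constraint counts; len_constraints is the capacity.
N_CONSTRAINTS = [3, 0, 5, 2]
LEN_CONSTRAINTS = 5

SERIAL = visit_serial(N_CONSTRAINTS)
WIDE = visit_2d(N_CONSTRAINTS, LEN_CONSTRAINTS)

GRID_1D = len(N_CONSTRAINTS)                    # _B work items
GRID_2D = LEN_CONSTRAINTS * len(N_CONSTRAINTS)  # len_constraints * _B work items
print(sorted(SERIAL) == sorted(WIDE), GRID_1D, GRID_2D, len(WIDE))
```

The cost of the wider grid is the idle items in the tail (here 10 of 20), which is the same trade-off item 3 makes for zero-DOF entities.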

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Bump _TWC_BLOCK_DIM 128->256 and _TWC_ENVS_PER_BLOCK 16->32 in
_kernel_solve_body_tiled_wc_amdgpu. Doubles L2 cache reuse per workgroup
with unchanged total launched waves.

BLOCK_DIM > 256 is gated by the Quadrants 100-iter unroll cap on the
qd.static(range(BLOCK_DIM)) act_red reduction.
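
The "unchanged total launched waves" claim can be checked with a short arithmetic sketch. It assumes one workgroup per `envs_per_block` environments and a 64-lane wavefront (the wave size on MI300X); the helper name and the batch size of 4096 are hypothetical and chosen only for illustration.

```python
import math

WAVE_SIZE = 64  # assumed lanes per wavefront (wave64 on MI300X)


def launch_shape(n_envs, block_dim, envs_per_block):
    """Return (workgroups, total waves) for a tiled launch over n_envs."""
    workgroups = math.ceil(n_envs / envs_per_block)
    waves_per_workgroup = block_dim // WAVE_SIZE
    return workgroups, workgroups * waves_per_workgroup


N_ENVS = 4096  # hypothetical batch size
SHAPES = {
    (block_dim, envs): launch_shape(N_ENVS, block_dim, envs)
    for block_dim, envs in [(64, 8), (128, 16), (256, 32)]
}
for (block_dim, envs), (wgs, waves) in SHAPES.items():
    print(f"BLOCK_DIM={block_dim:3d} ENVS_PER_BLOCK={envs:2d} "
          f"workgroups={wgs:4d} waves={waves}")
```

Because both constants double together, the lanes-per-env ratio stays at 8 and the wave count stays fixed; only the number of workgroups halves at each step, so each workgroup covers twice as many environments.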

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@npoulad1
Collaborator

npoulad1 commented Jun 3, 2026

/run-ci

@peizhang56 peizhang56 force-pushed the perf/task-c-twc-envs-per-block-16 branch from ae7b8de to 26932d0 June 3, 2026 21:33
@peizhang56 peizhang56 force-pushed the perf/task-c-twc-envs-per-block-16 branch from 26932d0 to 234f8b2 June 3, 2026 21:52
Skip the per-link joint pose pass (Pass 6 of func_com_links_split and the
equivalent block in func_COM_links_entity) when requires_grad=False. The
j_pos/j_quat values in links_state are only consumed by the backward adjoint
cache (func_copy_cartesian_space in diff.py) and are dead work in pure forward
simulation. The guard is a qd.static branch keyed on requires_grad, so the
skip is resolved at compile time with no runtime overhead.

Replace the flat-index rank-1 update in the tiled Cholesky (func_factor_mass)
with a row-major nested loop. The previous implementation decoded each flat
triangle index to (row, col) via a sqrt + integer correction on every iteration
of the update loop. The new loop strides threads across rows and iterates
columns sequentially within each row, eliminating all sqrt calls. The pivot
row value sh_pivot[_r] is loaded once into a VGPR and reused across the inner
column loop, reducing LDS reads for the pivot row.
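
The two loop shapes for the rank-1 update can be compared with a plain-Python sketch. The diff's actual decode is not quoted here, so the flat-index version below uses the standard triangular-number inverse (`row = floor((sqrt(8t + 1) - 1) / 2)` plus an integer correction) as an assumption; the function names, the lower-triangle-only update, and the toy sizes are likewise hypothetical.

```python
import math


def update_flat(a, pivot, block_dim):
    """Flat-index form: thread tid owns triangle indices tid, tid+block_dim, ..."""
    m = len(pivot)
    n_tri = m * (m + 1) // 2
    for tid in range(block_dim):
        for t in range(tid, n_tri, block_dim):
            row = int((math.sqrt(8.0 * t + 1.0) - 1.0) * 0.5)  # sqrt per update
            # Integer correction in case the float sqrt rounded the wrong way.
            while row * (row + 1) // 2 > t:
                row -= 1
            while (row + 1) * (row + 2) // 2 <= t:
                row += 1
            col = t - row * (row + 1) // 2
            a[row][col] -= pivot[row] * pivot[col]


def update_row_major(a, pivot, block_dim):
    """Row-major form: thread tid owns rows tid, tid+block_dim, ..."""
    m = len(pivot)
    for tid in range(block_dim):
        for row in range(tid, m, block_dim):
            p_row = pivot[row]  # loaded once, reused across the column loop
            for col in range(row + 1):
                a[row][col] -= p_row * pivot[col]


M, BLOCK_DIM = 6, 4  # hypothetical trailing-block size and thread count
PIVOT = [0.5 * (i + 1) for i in range(M)]
A_FLAT = [[float(r * M + c) for c in range(M)] for r in range(M)]
A_ROWS = [row[:] for row in A_FLAT]

update_flat(A_FLAT, PIVOT, BLOCK_DIM)
update_row_major(A_ROWS, PIVOT, BLOCK_DIM)
print(A_FLAT == A_ROWS)
```

Both forms touch each lower-triangle element exactly once with the same arithmetic, so the results match bit for bit. The row-major form gives up the even per-thread split of the flat index (rows near the bottom carry more columns) in exchange for removing the decode.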

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
@peizhang56
Author

/run-ci

@lohiaj

lohiaj commented Jun 4, 2026

A couple of things before this can merge:

  1. The comments bake in 8192 (the env count) in a few spots, which we can't have in the repo, e.g. solver.py:3213-3214:

      only 8192 threads (wgs=128, below the MI300X 304-CU floor); this
      2D form widens to n_dofs * _B threads (wgs=4480+ on g1_29dof),

     Same at :3262 ("the 8192-thread (128 wg) launch geometry") and :3322/:3326 ("8192/32=256 wgs", "2*8192 = 16384 threads"). Can you rephrase these in terms of _B / per-env so the count is not hardcoded?

  2. The PR body is out of sync with the diff. The title and item 1 describe solver_amdgpu.py tiled-wc constant changes that are not in here, while the two biggest changes (the forward-only j_pos/j_quat skip in func_COM_links, and the factor_mass rank-1 rewrite) are not mentioned at all. Since we squash merge, that body becomes the commit message, so please update it to match what actually landed.

Approach itself looks reasonable.

… GPU names)

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
@peizhang56 peizhang56 changed the title from "perf(amdgpu): widen init kernel grids on tiled-wc solver" to "perf(abd): skip j_pos/j_quat forward pass, row-major rank-1 update, widen constraint grid launches" Jun 4, 2026
@peizhang56
Author

/run-ci

3 participants