Skip to content

Commit 8ebf5c6

Browse files
gpinkertyaoliu13
authored andcommitted
perf(broadphase): integrate commits 1+2+3 into LDS path (no cooperative phase)
Adapts perf/broadphase-stack commits 1+2+3 to live on top of PR #56's LDS infrastructure. Cooperative parallel re-fill (the original commit 4) is intentionally excluded for this experiment because adding qd.simt.block.sync() to a loop that doesn't fill block_dim=64 (e.g. n_envs=1 -> 4 active lanes) triggers a workgroup-barrier UB and a GPU memory-access fault. Changes inside func_broad_phase_lds: * Replace the now-removed func_collision_clear() kernel call with the per-env helper func_collision_clear_per_env(i_b, ...). * Hoist n_eq_static / n_eq_dyn at the top of the per-env block and pass them to all 3 func_check_collision_valid call sites in this path. * Switch lds_sort_packed dtype from gs.qd_int to qd.u32 to match the StructSortBuffer.i_g_packed encoding contract. * Replace open-coded (i_g << 1) / >> 1 / & 1 patterns with func_pack_event / func_unpack_i_g / func_unpack_is_max throughout. * Fix the post-cherry-pick broken sort_buffer.{is_max, i_g} refs (8 sites) to use the renamed sort_buffer.i_g_packed field via helpers. * Collapse the write-back loop from 2 stores per event (separate i_g + is_max) to 1 store (i_g_packed).
1 parent ee419a4 commit 8ebf5c6

1 file changed

Lines changed: 71 additions & 93 deletions

File tree

genesis/engine/solvers/rigid/collider/broadphase.py

Lines changed: 71 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -190,30 +190,29 @@ def func_broad_phase_lds(
190190
collider_info: array_class.ColliderInfo,
191191
errno: array_class.V_ANNOTATION,
192192
):
193-
"""
194-
Sweep and Prune (SAP) for broad-phase collision detection.
195-
196-
This function sorts the geometry axis-aligned bounding boxes (AABBs) along a specified axis and checks for
197-
potential collision pairs based on the AABB overlap.
198-
199-
The optimized LDS path primarily targets use_hibernation=False.
200-
The hibernation path keeps the original active_buffer_awake/hib logic.
193+
"""Sweep and Prune (SAP) for broad-phase collision detection -- LDS path.
194+
195+
Integrates broadphase-stack commits 1+2+3 on top of PR #56's LDS scaffolding:
196+
* Commit 1: replace func_collision_clear() kernel-call with the per-env
197+
helper func_collision_clear_per_env(i_b, ...) inlined here.
198+
* Commit 2: hoist n_eq_static / n_eq_dyn out of the per-pair check.
199+
* Commit 3: pack i_g + is_max into a u32 single-word; use func_pack_event /
200+
func_unpack_i_g / func_unpack_is_max helpers; collapse 2-store sort_buffer
201+
write-back to a single i_g_packed store.
202+
203+
Cooperative parallel warm-start (the original commit 4) is intentionally
204+
omitted from this experiment because adding qd.simt.block.sync() to a loop
205+
that doesn't fill block_dim=64 (e.g. n_envs=1 -> 4 active lanes) triggers
206+
a workgroup-barrier UB / GPU memory access fault. Lane-0-only execution
207+
is preserved here.
201208
"""
202209
n_geoms, _B = collider_state.active_buffer.shape
203210
n_links = links_info.geom_start.shape[0]
204211

205-
# Clear collider state
206-
func_collision_clear(links_state, links_info, collider_state, static_rigid_sim_config)
207-
208212
MAX_GEOMS_NUM = qd.static(MAX_GEOMS_IN_LDS)
209213
MAX_SORT_ELEM_NUM = qd.static(MAX_GEOMS_NUM * 2)
210-
211214
BLOCK_DIM = qd.static(64)
212215
ENVS_PER_BLOCK = qd.static(16)
213-
214-
# Only one lane out of THREADS_PER_ENV currently processes one env.
215-
# THREADS_PER_ENV is used to map 16 envs to one 64-thread workgroup and
216-
# reserve one LDS slot per env.
217216
THREADS_PER_ENV = qd.static(BLOCK_DIM // ENVS_PER_BLOCK)
218217

219218
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=BLOCK_DIM)
@@ -223,107 +222,97 @@ def func_broad_phase_lds(
223222
continue
224223

225224
lds_sort_value = qd.simt.block.SharedArray((ENVS_PER_BLOCK, MAX_SORT_ELEM_NUM), gs.qd_float)
226-
227-
# Packed format: lds_sort_i_g_packed = (i_g << 1) | is_max_bit
228-
lds_sort_packed = qd.simt.block.SharedArray((ENVS_PER_BLOCK, MAX_SORT_ELEM_NUM), gs.qd_int)
229-
230-
# Don't need to copy `collider_state.active_buffer` into `lds_active` before using it.
231-
# Because the sweep below starts with `n_active = 0` and rebuilds the set from scratch.
225+
# Packed format: bit 0 = is_max, bits 1..31 = i_g (matches StructSortBuffer.i_g_packed).
226+
lds_sort_packed = qd.simt.block.SharedArray((ENVS_PER_BLOCK, MAX_SORT_ELEM_NUM), qd.u32)
227+
# No need to copy collider_state.active_buffer; sweep starts with n_active=0.
232228
lds_active = qd.simt.block.SharedArray((ENVS_PER_BLOCK, MAX_GEOMS_NUM), gs.qd_int)
233-
234-
i_b_lds = i_b % ENVS_PER_BLOCK
235229

230+
i_b_lds = i_b % ENVS_PER_BLOCK
236231
axis = 0
237232

238-
# Calculate the number of active geoms for this environment
239-
# (for heterogeneous entities, different envs may have different geoms)
233+
# Commit 1: per-env contact-clear helper, inlined into this kernel.
234+
func_collision_clear_per_env(i_b, links_state, links_info, collider_state, static_rigid_sim_config)
235+
236+
# Commit 2: hoist equality bounds out of the per-pair check.
237+
n_eq_static = rigid_global_info.n_equalities[None]
238+
n_eq_dyn = constraint_state.qd_n_equalities[i_b]
239+
240+
# Calculate the number of active geoms for this environment.
240241
env_n_geoms = 0
241242
for i_l in range(n_links):
242243
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
243244
env_n_geoms = env_n_geoms + links_info.geom_end[I_l] - links_info.geom_start[I_l]
244245

245-
# copy updated geom aabbs to buffer for sorting
246+
# copy updated geom aabbs to LDS for sorting
246247
if collider_state.first_time[i_b]:
247248
i_buffer = 0
248249
for i_l in range(n_links):
249250
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
250251
for i_g in range(links_info.geom_start[I_l], links_info.geom_end[I_l]):
251252
lds_sort_value[i_b_lds, 2 * i_buffer] = geoms_state.aabb_min[i_g, i_b][axis]
252-
lds_sort_packed[i_b_lds, 2 * i_buffer] = i_g << 1 # is_max = 0
253-
253+
lds_sort_packed[i_b_lds, 2 * i_buffer] = func_pack_event(i_g, False)
254254
lds_sort_value[i_b_lds, 2 * i_buffer + 1] = geoms_state.aabb_max[i_g, i_b][axis]
255-
lds_sort_packed[i_b_lds, 2 * i_buffer + 1] = (i_g << 1) | 1 # is_max = 1
256-
255+
lds_sort_packed[i_b_lds, 2 * i_buffer + 1] = func_pack_event(i_g, True)
257256
geoms_state.min_buffer_idx[i_buffer, i_b] = 2 * i_g
258257
geoms_state.max_buffer_idx[i_buffer, i_b] = 2 * i_g + 1
259258
i_buffer = i_buffer + 1
260-
261259
collider_state.first_time[i_b] = False
262-
263260
else:
264-
if qd.static(not static_rigid_sim_config.use_hibernation):
265-
for i in range(env_n_geoms * 2):
266-
is_max = collider_state.sort_buffer.is_max[i, i_b]
267-
i_g = collider_state.sort_buffer.i_g[i, i_b]
261+
# Warm-start re-fill: read packed event from global sort_buffer, decode via
262+
# helpers, look up new aabb extent in LDS slot.
263+
for i in range(env_n_geoms * 2):
264+
packed = collider_state.sort_buffer.i_g_packed[i, i_b]
265+
is_max = func_unpack_is_max(packed)
266+
i_g = func_unpack_i_g(packed)
267+
lds_sort_packed[i_b_lds, i] = packed
268+
if qd.static(not static_rigid_sim_config.use_hibernation):
268269
if is_max:
269270
lds_sort_value[i_b_lds, i] = geoms_state.aabb_max[i_g, i_b][axis]
270271
else:
271272
lds_sort_value[i_b_lds, i] = geoms_state.aabb_min[i_g, i_b][axis]
273+
else:
274+
lds_sort_value[i_b_lds, i] = collider_state.sort_buffer.value[i, i_b]
272275

273-
lds_sort_packed[i_b_lds, i] = (i_g << 1) | qd.cast(is_max, gs.qd_int)
274-
else:
275-
for i in range(env_n_geoms * 2):
276-
is_max = collider_state.sort_buffer.is_max[i, i_b]
277-
i_g = collider_state.sort_buffer.i_g[i, i_b]
278-
value = collider_state.sort_buffer.value[i, i_b]
279-
lds_sort_packed[i_b_lds, i] = (i_g << 1) | qd.cast(is_max, gs.qd_int)
280-
lds_sort_value[i_b_lds, i] = value
281-
282-
283-
# insertion sort, which has complexity near O(n) for nearly sorted array
276+
# insertion sort, near O(n) for nearly sorted input
284277
for i in range(1, 2 * env_n_geoms):
285278
key_value = lds_sort_value[i_b_lds, i]
286-
key_packed_ig_ismax = lds_sort_packed[i_b_lds, i]
279+
key_packed = lds_sort_packed[i_b_lds, i]
287280

288281
j = i - 1
289282
while j >= 0 and key_value < lds_sort_value[i_b_lds, j]:
290-
packed_ig_ismax = lds_sort_packed[i_b_lds, j]
283+
slid_packed = lds_sort_packed[i_b_lds, j]
291284
lds_sort_value[i_b_lds, j + 1] = lds_sort_value[i_b_lds, j]
292-
lds_sort_packed[i_b_lds, j + 1] = packed_ig_ismax
285+
lds_sort_packed[i_b_lds, j + 1] = slid_packed
293286

294287
if qd.static(static_rigid_sim_config.use_hibernation):
295-
shifted_i_g = packed_ig_ismax >> 1
296-
if packed_ig_ismax & 1:
297-
geoms_state.max_buffer_idx[shifted_i_g, i_b] = j + 1
288+
slid_i_g = func_unpack_i_g(slid_packed)
289+
if func_unpack_is_max(slid_packed):
290+
geoms_state.max_buffer_idx[slid_i_g, i_b] = j + 1
298291
else:
299-
geoms_state.min_buffer_idx[shifted_i_g, i_b] = j + 1
292+
geoms_state.min_buffer_idx[slid_i_g, i_b] = j + 1
300293

301294
j -= 1
302295
lds_sort_value[i_b_lds, j + 1] = key_value
303-
lds_sort_packed[i_b_lds, j + 1] = key_packed_ig_ismax
296+
lds_sort_packed[i_b_lds, j + 1] = key_packed
304297

305298
if qd.static(static_rigid_sim_config.use_hibernation):
306-
key_i_g = key_packed_ig_ismax >> 1
307-
if key_packed_ig_ismax & 1:
299+
key_i_g = func_unpack_i_g(key_packed)
300+
if func_unpack_is_max(key_packed):
308301
geoms_state.max_buffer_idx[key_i_g, i_b] = j + 1
309302
else:
310303
geoms_state.min_buffer_idx[key_i_g, i_b] = j + 1
311304

312-
313305
n_broad = 0
314306
if qd.static(not static_rigid_sim_config.use_hibernation):
315307
n_active = 0
316-
317308
for i in range(2 * env_n_geoms):
318-
packed_ig_ismax = lds_sort_packed[i_b_lds, i]
319-
is_max = packed_ig_ismax & 1
320-
i_g = packed_ig_ismax >> 1
321-
309+
packed = lds_sort_packed[i_b_lds, i]
310+
is_max = func_unpack_is_max(packed)
311+
i_g = func_unpack_i_g(packed)
322312

323313
if not is_max:
324314
min_b0, min_b1, min_b2 = geoms_state.aabb_min[i_g, i_b]
325315
max_b0, max_b1, max_b2 = geoms_state.aabb_max[i_g, i_b]
326-
327316
for j in range(n_active):
328317
i_ga = lds_active[i_b_lds, j]
329318

@@ -338,7 +327,6 @@ def func_broad_phase_lds(
338327

339328
min_a0, min_a1, min_a2 = geoms_state.aabb_min[i_ga, i_b]
340329
max_a0, max_a1, max_a2 = geoms_state.aabb_max[i_ga, i_b]
341-
342330

343331
if (min_a0 > max_b0 or min_a1 > max_b1 or min_a2 > max_b2 or
344332
max_a0 < min_b0 or max_a1 < min_b1 or max_a2 < min_b2):
@@ -348,6 +336,8 @@ def func_broad_phase_lds(
348336
i_ga_c,
349337
i_gb_c,
350338
i_b,
339+
n_eq_static,
340+
n_eq_dyn,
351341
links_state,
352342
links_info,
353343
geoms_info,
@@ -369,29 +359,24 @@ def func_broad_phase_lds(
369359
lds_active[i_b_lds, n_active] = i_g
370360
geoms_state.active_buffer_idx[i_g, i_b] = n_active
371361
n_active += 1
372-
373362
else:
374363
j_remove = geoms_state.active_buffer_idx[i_g, i_b]
375364
if j_remove < n_active - 1:
376-
# Swap with last element
377365
i_g_last = lds_active[i_b_lds, n_active - 1]
378366
lds_active[i_b_lds, j_remove] = i_g_last
379367
geoms_state.active_buffer_idx[i_g_last, i_b] = j_remove
380368
n_active -= 1
381-
382-
collider_state.n_broad_pairs[i_b] = n_broad
383369
else:
384370
if rigid_global_info.n_awake_dofs[i_b] > 0:
385371
n_active_awake = 0
386372
n_active_hib = 0
387373
for i in range(2 * env_n_geoms):
388-
packed_ig_ismax = lds_sort_packed[i_b_lds, i]
389-
i_gb_origin = packed_ig_ismax >> 1
390-
is_max = packed_ig_ismax & 1
374+
packed = lds_sort_packed[i_b_lds, i]
375+
i_gb_origin = func_unpack_i_g(packed)
376+
is_max = func_unpack_is_max(packed)
391377
is_incoming_geom_hibernated = geoms_state.hibernated[i_gb_origin, i_b]
392378

393379
if not is_max:
394-
# both awake and hibernated geom check with active awake geoms
395380
for j in range(n_active_awake):
396381
i_ga = collider_state.active_buffer_awake[j, i_b]
397382
i_gb = i_gb_origin
@@ -402,6 +387,8 @@ def func_broad_phase_lds(
402387
i_ga,
403388
i_gb,
404389
i_b,
390+
n_eq_static,
391+
n_eq_dyn,
405392
links_state,
406393
links_info,
407394
geoms_info,
@@ -414,7 +401,6 @@ def func_broad_phase_lds(
414401
continue
415402

416403
if not func_is_geom_aabbs_overlap(geoms_state, i_ga, i_gb, i_b):
417-
# Clear collision normal cache if not in contact
418404
if qd.static(not static_rigid_sim_config.enable_mujoco_compatibility):
419405
i_pair = collider_info.collision_pair_idx[i_ga, i_gb]
420406
collider_state.contact_cache.normal[i_pair, i_b] = qd.Vector.zero(gs.qd_float, 3)
@@ -424,7 +410,6 @@ def func_broad_phase_lds(
424410
collider_state.broad_collision_pairs[n_broad, i_b][1] = i_gb
425411
n_broad = n_broad + 1
426412

427-
# if incoming geom is awake, also need to check with hibernated geoms
428413
if not is_incoming_geom_hibernated:
429414
for j in range(n_active_hib):
430415
i_ga = collider_state.active_buffer_hib[j, i_b]
@@ -436,6 +421,8 @@ def func_broad_phase_lds(
436421
i_ga,
437422
i_gb,
438423
i_b,
424+
n_eq_static,
425+
n_eq_dyn,
439426
links_state,
440427
links_info,
441428
geoms_info,
@@ -448,7 +435,6 @@ def func_broad_phase_lds(
448435
continue
449436

450437
if not func_is_geom_aabbs_overlap(geoms_state, i_ga, i_gb, i_b):
451-
# Clear collision normal cache if not in contact
452438
i_pair = collider_info.collision_pair_idx[i_ga, i_gb]
453439
collider_state.contact_cache.normal[i_pair, i_b] = qd.Vector.zero(gs.qd_float, 3)
454440
continue
@@ -470,42 +456,34 @@ def func_broad_phase_lds(
470456
if collider_state.active_buffer_hib[j, i_b] == i_g_to_remove:
471457
if j < n_active_hib - 1:
472458
for k in range(j, n_active_hib - 1):
473-
collider_state.active_buffer_hib[k, i_b] = collider_state.active_buffer_hib[
474-
k + 1, i_b
475-
]
459+
collider_state.active_buffer_hib[k, i_b] = collider_state.active_buffer_hib[k + 1, i_b]
476460
n_active_hib = n_active_hib - 1
477461
break
478462
else:
479463
for j in range(n_active_awake):
480464
if collider_state.active_buffer_awake[j, i_b] == i_g_to_remove:
481465
if j < n_active_awake - 1:
482466
for k in range(j, n_active_awake - 1):
483-
collider_state.active_buffer_awake[k, i_b] = (
484-
collider_state.active_buffer_awake[k + 1, i_b]
485-
)
467+
collider_state.active_buffer_awake[k, i_b] = collider_state.active_buffer_awake[k + 1, i_b]
486468
n_active_awake = n_active_awake - 1
487469
break
488470

471+
# Write-back to global sort_buffer for next step's warm-start.
472+
# Single i_g_packed store per event (commit 3 dtype change collapsed
473+
# what used to be two stores into one).
489474
for i in range(env_n_geoms):
490-
491475
if qd.static(static_rigid_sim_config.use_hibernation):
492476
collider_state.sort_buffer.value[2 * i, i_b] = lds_sort_value[i_b_lds, 2 * i]
493477
collider_state.sort_buffer.value[2 * i + 1, i_b] = lds_sort_value[i_b_lds, 2 * i + 1]
494-
495-
packed_ig_ismax = lds_sort_packed[i_b_lds, 2 * i]
496-
collider_state.sort_buffer.i_g[2 * i, i_b] = packed_ig_ismax >> 1
497-
collider_state.sort_buffer.is_max[2 * i, i_b] = qd.cast(packed_ig_ismax & 1, gs.qd_bool)
498-
499-
packed_ig_ismax = lds_sort_packed[i_b_lds, 2 * i + 1]
500-
collider_state.sort_buffer.i_g[2 * i + 1, i_b] = packed_ig_ismax >> 1
501-
collider_state.sort_buffer.is_max[2 * i + 1, i_b] = qd.cast(packed_ig_ismax & 1, gs.qd_bool)
502-
478+
collider_state.sort_buffer.i_g_packed[2 * i, i_b] = lds_sort_packed[i_b_lds, 2 * i]
479+
collider_state.sort_buffer.i_g_packed[2 * i + 1, i_b] = lds_sort_packed[i_b_lds, 2 * i + 1]
503480
if qd.static(not static_rigid_sim_config.use_hibernation):
504481
collider_state.active_buffer[i, i_b] = lds_active[i_b_lds, i]
505482

506483
collider_state.n_broad_pairs[i_b] = n_broad
507484

508485

486+
509487
@qd.func
510488
def func_broad_phase_global_mem(
511489
links_state: array_class.LinksState,

0 commit comments

Comments
 (0)