@@ -190,30 +190,29 @@ def func_broad_phase_lds(
190190 collider_info : array_class .ColliderInfo ,
191191 errno : array_class .V_ANNOTATION ,
192192):
193- """
194- Sweep and Prune (SAP) for broad-phase collision detection.
195-
196- This function sorts the geometry axis-aligned bounding boxes (AABBs) along a specified axis and checks for
197- potential collision pairs based on the AABB overlap.
198-
199- The optimized LDS path primarily targets use_hibernation=False.
200- The hibernation path keeps the original active_buffer_awake/hib logic.
193+ """Sweep and Prune (SAP) for broad-phase collision detection -- LDS path.
194+
195+ Integrates broadphase-stack commits 1+2+3 on top of PR #56's LDS scaffolding:
196+ * Commit 1: replace func_collision_clear() kernel-call with the per-env
197+ helper func_collision_clear_per_env(i_b, ...) inlined here.
198+ * Commit 2: hoist n_eq_static / n_eq_dyn out of the per-pair check.
199+ * Commit 3: pack i_g + is_max into a u32 single-word; use func_pack_event /
200+ func_unpack_i_g / func_unpack_is_max helpers; collapse 2-store sort_buffer
201+ write-back to a single i_g_packed store.
202+
203+ Cooperative parallel warm-start (the original commit 4) is intentionally
204+ omitted from this experiment because adding qd.simt.block.sync() to a loop
205+ that doesn't fill block_dim=64 (e.g. n_envs=1 -> 4 active lanes) triggers
206+ a workgroup-barrier UB / GPU memory access fault. Lane-0-only execution
207+ is preserved here.
201208 """
202209 n_geoms , _B = collider_state .active_buffer .shape
203210 n_links = links_info .geom_start .shape [0 ]
204211
205- # Clear collider state
206- func_collision_clear (links_state , links_info , collider_state , static_rigid_sim_config )
207-
208212 MAX_GEOMS_NUM = qd .static (MAX_GEOMS_IN_LDS )
209213 MAX_SORT_ELEM_NUM = qd .static (MAX_GEOMS_NUM * 2 )
210-
211214 BLOCK_DIM = qd .static (64 )
212215 ENVS_PER_BLOCK = qd .static (16 )
213-
214- # Only one lane out of THREADS_PER_ENV currently processes one env.
215- # THREADS_PER_ENV is used to map 16 envs to one 64-thread workgroup and
216- # reserve one LDS slot per env.
217216 THREADS_PER_ENV = qd .static (BLOCK_DIM // ENVS_PER_BLOCK )
218217
219218 qd .loop_config (serialize = static_rigid_sim_config .para_level < gs .PARA_LEVEL .ALL , block_dim = BLOCK_DIM )
@@ -223,107 +222,97 @@ def func_broad_phase_lds(
223222 continue
224223
225224 lds_sort_value = qd .simt .block .SharedArray ((ENVS_PER_BLOCK , MAX_SORT_ELEM_NUM ), gs .qd_float )
226-
227- # Packed format: lds_sort_i_g_packed = (i_g << 1) | is_max_bit
228- lds_sort_packed = qd .simt .block .SharedArray ((ENVS_PER_BLOCK , MAX_SORT_ELEM_NUM ), gs .qd_int )
229-
230- # Don't need to copy `collider_state.active_buffer` into `lds_active` before using it.
231- # Because the sweep below starts with `n_active = 0` and rebuilds the set from scratch.
225+ # Packed format: bit 0 = is_max, bits 1..31 = i_g (matches StructSortBuffer.i_g_packed).
226+ lds_sort_packed = qd .simt .block .SharedArray ((ENVS_PER_BLOCK , MAX_SORT_ELEM_NUM ), qd .u32 )
227+ # No need to copy collider_state.active_buffer; sweep starts with n_active=0.
232228 lds_active = qd .simt .block .SharedArray ((ENVS_PER_BLOCK , MAX_GEOMS_NUM ), gs .qd_int )
233-
234- i_b_lds = i_b % ENVS_PER_BLOCK
235229
230+ i_b_lds = i_b % ENVS_PER_BLOCK
236231 axis = 0
237232
238- # Calculate the number of active geoms for this environment
239- # (for heterogeneous entities, different envs may have different geoms)
233+ # Commit 1: per-env contact-clear helper, inlined into this kernel.
234+ func_collision_clear_per_env (i_b , links_state , links_info , collider_state , static_rigid_sim_config )
235+
236+ # Commit 2: hoist equality bounds out of the per-pair check.
237+ n_eq_static = rigid_global_info .n_equalities [None ]
238+ n_eq_dyn = constraint_state .qd_n_equalities [i_b ]
239+
240+ # Calculate the number of active geoms for this environment.
240241 env_n_geoms = 0
241242 for i_l in range (n_links ):
242243 I_l = [i_l , i_b ] if qd .static (static_rigid_sim_config .batch_links_info ) else i_l
243244 env_n_geoms = env_n_geoms + links_info .geom_end [I_l ] - links_info .geom_start [I_l ]
244245
245- # copy updated geom aabbs to buffer for sorting
246+ # copy updated geom aabbs to LDS for sorting
246247 if collider_state .first_time [i_b ]:
247248 i_buffer = 0
248249 for i_l in range (n_links ):
249250 I_l = [i_l , i_b ] if qd .static (static_rigid_sim_config .batch_links_info ) else i_l
250251 for i_g in range (links_info .geom_start [I_l ], links_info .geom_end [I_l ]):
251252 lds_sort_value [i_b_lds , 2 * i_buffer ] = geoms_state .aabb_min [i_g , i_b ][axis ]
252- lds_sort_packed [i_b_lds , 2 * i_buffer ] = i_g << 1 # is_max = 0
253-
253+ lds_sort_packed [i_b_lds , 2 * i_buffer ] = func_pack_event (i_g , False )
254254 lds_sort_value [i_b_lds , 2 * i_buffer + 1 ] = geoms_state .aabb_max [i_g , i_b ][axis ]
255- lds_sort_packed [i_b_lds , 2 * i_buffer + 1 ] = (i_g << 1 ) | 1 # is_max = 1
256-
255+ lds_sort_packed [i_b_lds , 2 * i_buffer + 1 ] = func_pack_event (i_g , True )
257256 geoms_state .min_buffer_idx [i_buffer , i_b ] = 2 * i_g
258257 geoms_state .max_buffer_idx [i_buffer , i_b ] = 2 * i_g + 1
259258 i_buffer = i_buffer + 1
260-
261259 collider_state .first_time [i_b ] = False
262-
263260 else :
264- if qd .static (not static_rigid_sim_config .use_hibernation ):
265- for i in range (env_n_geoms * 2 ):
266- is_max = collider_state .sort_buffer .is_max [i , i_b ]
267- i_g = collider_state .sort_buffer .i_g [i , i_b ]
261+ # Warm-start re-fill: read packed event from global sort_buffer, decode via
262+ # helpers, look up new aabb extent in LDS slot.
263+ for i in range (env_n_geoms * 2 ):
264+ packed = collider_state .sort_buffer .i_g_packed [i , i_b ]
265+ is_max = func_unpack_is_max (packed )
266+ i_g = func_unpack_i_g (packed )
267+ lds_sort_packed [i_b_lds , i ] = packed
268+ if qd .static (not static_rigid_sim_config .use_hibernation ):
268269 if is_max :
269270 lds_sort_value [i_b_lds , i ] = geoms_state .aabb_max [i_g , i_b ][axis ]
270271 else :
271272 lds_sort_value [i_b_lds , i ] = geoms_state .aabb_min [i_g , i_b ][axis ]
273+ else :
274+ lds_sort_value [i_b_lds , i ] = collider_state .sort_buffer .value [i , i_b ]
272275
273- lds_sort_packed [i_b_lds , i ] = (i_g << 1 ) | qd .cast (is_max , gs .qd_int )
274- else :
275- for i in range (env_n_geoms * 2 ):
276- is_max = collider_state .sort_buffer .is_max [i , i_b ]
277- i_g = collider_state .sort_buffer .i_g [i , i_b ]
278- value = collider_state .sort_buffer .value [i , i_b ]
279- lds_sort_packed [i_b_lds , i ] = (i_g << 1 ) | qd .cast (is_max , gs .qd_int )
280- lds_sort_value [i_b_lds , i ] = value
281-
282-
283- # insertion sort, which has complexity near O(n) for nearly sorted array
276+ # insertion sort, near O(n) for nearly sorted input
284277 for i in range (1 , 2 * env_n_geoms ):
285278 key_value = lds_sort_value [i_b_lds , i ]
286- key_packed_ig_ismax = lds_sort_packed [i_b_lds , i ]
279+ key_packed = lds_sort_packed [i_b_lds , i ]
287280
288281 j = i - 1
289282 while j >= 0 and key_value < lds_sort_value [i_b_lds , j ]:
290- packed_ig_ismax = lds_sort_packed [i_b_lds , j ]
283+ slid_packed = lds_sort_packed [i_b_lds , j ]
291284 lds_sort_value [i_b_lds , j + 1 ] = lds_sort_value [i_b_lds , j ]
292- lds_sort_packed [i_b_lds , j + 1 ] = packed_ig_ismax
285+ lds_sort_packed [i_b_lds , j + 1 ] = slid_packed
293286
294287 if qd .static (static_rigid_sim_config .use_hibernation ):
295- shifted_i_g = packed_ig_ismax >> 1
296- if packed_ig_ismax & 1 :
297- geoms_state .max_buffer_idx [shifted_i_g , i_b ] = j + 1
288+ slid_i_g = func_unpack_i_g ( slid_packed )
289+ if func_unpack_is_max ( slid_packed ) :
290+ geoms_state .max_buffer_idx [slid_i_g , i_b ] = j + 1
298291 else :
299- geoms_state .min_buffer_idx [shifted_i_g , i_b ] = j + 1
292+ geoms_state .min_buffer_idx [slid_i_g , i_b ] = j + 1
300293
301294 j -= 1
302295 lds_sort_value [i_b_lds , j + 1 ] = key_value
303- lds_sort_packed [i_b_lds , j + 1 ] = key_packed_ig_ismax
296+ lds_sort_packed [i_b_lds , j + 1 ] = key_packed
304297
305298 if qd .static (static_rigid_sim_config .use_hibernation ):
306- key_i_g = key_packed_ig_ismax >> 1
307- if key_packed_ig_ismax & 1 :
299+ key_i_g = func_unpack_i_g ( key_packed )
300+ if func_unpack_is_max ( key_packed ) :
308301 geoms_state .max_buffer_idx [key_i_g , i_b ] = j + 1
309302 else :
310303 geoms_state .min_buffer_idx [key_i_g , i_b ] = j + 1
311304
312-
313305 n_broad = 0
314306 if qd .static (not static_rigid_sim_config .use_hibernation ):
315307 n_active = 0
316-
317308 for i in range (2 * env_n_geoms ):
318- packed_ig_ismax = lds_sort_packed [i_b_lds , i ]
319- is_max = packed_ig_ismax & 1
320- i_g = packed_ig_ismax >> 1
321-
309+ packed = lds_sort_packed [i_b_lds , i ]
310+ is_max = func_unpack_is_max (packed )
311+ i_g = func_unpack_i_g (packed )
322312
323313 if not is_max :
324314 min_b0 , min_b1 , min_b2 = geoms_state .aabb_min [i_g , i_b ]
325315 max_b0 , max_b1 , max_b2 = geoms_state .aabb_max [i_g , i_b ]
326-
327316 for j in range (n_active ):
328317 i_ga = lds_active [i_b_lds , j ]
329318
@@ -338,7 +327,6 @@ def func_broad_phase_lds(
338327
339328 min_a0 , min_a1 , min_a2 = geoms_state .aabb_min [i_ga , i_b ]
340329 max_a0 , max_a1 , max_a2 = geoms_state .aabb_max [i_ga , i_b ]
341-
342330
343331 if (min_a0 > max_b0 or min_a1 > max_b1 or min_a2 > max_b2 or
344332 max_a0 < min_b0 or max_a1 < min_b1 or max_a2 < min_b2 ):
@@ -348,6 +336,8 @@ def func_broad_phase_lds(
348336 i_ga_c ,
349337 i_gb_c ,
350338 i_b ,
339+ n_eq_static ,
340+ n_eq_dyn ,
351341 links_state ,
352342 links_info ,
353343 geoms_info ,
@@ -369,29 +359,24 @@ def func_broad_phase_lds(
369359 lds_active [i_b_lds , n_active ] = i_g
370360 geoms_state .active_buffer_idx [i_g , i_b ] = n_active
371361 n_active += 1
372-
373362 else :
374363 j_remove = geoms_state .active_buffer_idx [i_g , i_b ]
375364 if j_remove < n_active - 1 :
376- # Swap with last element
377365 i_g_last = lds_active [i_b_lds , n_active - 1 ]
378366 lds_active [i_b_lds , j_remove ] = i_g_last
379367 geoms_state .active_buffer_idx [i_g_last , i_b ] = j_remove
380368 n_active -= 1
381-
382- collider_state .n_broad_pairs [i_b ] = n_broad
383369 else :
384370 if rigid_global_info .n_awake_dofs [i_b ] > 0 :
385371 n_active_awake = 0
386372 n_active_hib = 0
387373 for i in range (2 * env_n_geoms ):
388- packed_ig_ismax = lds_sort_packed [i_b_lds , i ]
389- i_gb_origin = packed_ig_ismax >> 1
390- is_max = packed_ig_ismax & 1
374+ packed = lds_sort_packed [i_b_lds , i ]
375+ i_gb_origin = func_unpack_i_g ( packed )
376+ is_max = func_unpack_is_max ( packed )
391377 is_incoming_geom_hibernated = geoms_state .hibernated [i_gb_origin , i_b ]
392378
393379 if not is_max :
394- # both awake and hibernated geom check with active awake geoms
395380 for j in range (n_active_awake ):
396381 i_ga = collider_state .active_buffer_awake [j , i_b ]
397382 i_gb = i_gb_origin
@@ -402,6 +387,8 @@ def func_broad_phase_lds(
402387 i_ga ,
403388 i_gb ,
404389 i_b ,
390+ n_eq_static ,
391+ n_eq_dyn ,
405392 links_state ,
406393 links_info ,
407394 geoms_info ,
@@ -414,7 +401,6 @@ def func_broad_phase_lds(
414401 continue
415402
416403 if not func_is_geom_aabbs_overlap (geoms_state , i_ga , i_gb , i_b ):
417- # Clear collision normal cache if not in contact
418404 if qd .static (not static_rigid_sim_config .enable_mujoco_compatibility ):
419405 i_pair = collider_info .collision_pair_idx [i_ga , i_gb ]
420406 collider_state .contact_cache .normal [i_pair , i_b ] = qd .Vector .zero (gs .qd_float , 3 )
@@ -424,7 +410,6 @@ def func_broad_phase_lds(
424410 collider_state .broad_collision_pairs [n_broad , i_b ][1 ] = i_gb
425411 n_broad = n_broad + 1
426412
427- # if incoming geom is awake, also need to check with hibernated geoms
428413 if not is_incoming_geom_hibernated :
429414 for j in range (n_active_hib ):
430415 i_ga = collider_state .active_buffer_hib [j , i_b ]
@@ -436,6 +421,8 @@ def func_broad_phase_lds(
436421 i_ga ,
437422 i_gb ,
438423 i_b ,
424+ n_eq_static ,
425+ n_eq_dyn ,
439426 links_state ,
440427 links_info ,
441428 geoms_info ,
@@ -448,7 +435,6 @@ def func_broad_phase_lds(
448435 continue
449436
450437 if not func_is_geom_aabbs_overlap (geoms_state , i_ga , i_gb , i_b ):
451- # Clear collision normal cache if not in contact
452438 i_pair = collider_info .collision_pair_idx [i_ga , i_gb ]
453439 collider_state .contact_cache .normal [i_pair , i_b ] = qd .Vector .zero (gs .qd_float , 3 )
454440 continue
@@ -470,42 +456,34 @@ def func_broad_phase_lds(
470456 if collider_state .active_buffer_hib [j , i_b ] == i_g_to_remove :
471457 if j < n_active_hib - 1 :
472458 for k in range (j , n_active_hib - 1 ):
473- collider_state .active_buffer_hib [k , i_b ] = collider_state .active_buffer_hib [
474- k + 1 , i_b
475- ]
459+ collider_state .active_buffer_hib [k , i_b ] = collider_state .active_buffer_hib [k + 1 , i_b ]
476460 n_active_hib = n_active_hib - 1
477461 break
478462 else :
479463 for j in range (n_active_awake ):
480464 if collider_state .active_buffer_awake [j , i_b ] == i_g_to_remove :
481465 if j < n_active_awake - 1 :
482466 for k in range (j , n_active_awake - 1 ):
483- collider_state .active_buffer_awake [k , i_b ] = (
484- collider_state .active_buffer_awake [k + 1 , i_b ]
485- )
467+ collider_state .active_buffer_awake [k , i_b ] = collider_state .active_buffer_awake [k + 1 , i_b ]
486468 n_active_awake = n_active_awake - 1
487469 break
488470
471+ # Write-back to global sort_buffer for next step's warm-start.
472+ # Single i_g_packed store per event (commit 3 dtype change collapsed
473+ # what used to be two stores into one).
489474 for i in range (env_n_geoms ):
490-
491475 if qd .static (static_rigid_sim_config .use_hibernation ):
492476 collider_state .sort_buffer .value [2 * i , i_b ] = lds_sort_value [i_b_lds , 2 * i ]
493477 collider_state .sort_buffer .value [2 * i + 1 , i_b ] = lds_sort_value [i_b_lds , 2 * i + 1 ]
494-
495- packed_ig_ismax = lds_sort_packed [i_b_lds , 2 * i ]
496- collider_state .sort_buffer .i_g [2 * i , i_b ] = packed_ig_ismax >> 1
497- collider_state .sort_buffer .is_max [2 * i , i_b ] = qd .cast (packed_ig_ismax & 1 , gs .qd_bool )
498-
499- packed_ig_ismax = lds_sort_packed [i_b_lds , 2 * i + 1 ]
500- collider_state .sort_buffer .i_g [2 * i + 1 , i_b ] = packed_ig_ismax >> 1
501- collider_state .sort_buffer .is_max [2 * i + 1 , i_b ] = qd .cast (packed_ig_ismax & 1 , gs .qd_bool )
502-
478+ collider_state .sort_buffer .i_g_packed [2 * i , i_b ] = lds_sort_packed [i_b_lds , 2 * i ]
479+ collider_state .sort_buffer .i_g_packed [2 * i + 1 , i_b ] = lds_sort_packed [i_b_lds , 2 * i + 1 ]
503480 if qd .static (not static_rigid_sim_config .use_hibernation ):
504481 collider_state .active_buffer [i , i_b ] = lds_active [i_b_lds , i ]
505482
506483 collider_state .n_broad_pairs [i_b ] = n_broad
507484
508485
486+
509487@qd .func
510488def func_broad_phase_global_mem (
511489 links_state : array_class .LinksState ,
0 commit comments