Skip to content

Commit d2fcd4c

Browse files
authored
[Perf] Add bitonic sort to subgroup ops (#713)
1 parent f618f05 commit d2fcd4c

5 files changed

Lines changed: 429 additions & 19 deletions

File tree

docs/source/user_guide/compound_types.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ It's of course very subjective, but some guidelines you could consider:
3232

3333
- if you are trying to write a python class that runs on the GPU => use a `@qd.data_oriented`
3434
- if you are trying to write typed dataclasses, for passing data around between the `@data_oriented` classes, and between methods of the same `@data_oriented` class => use `@dataclasses.dataclass`es
35-
- `@qd.dataclass` is used to create structured element types for field tensors. We also use it to create the Cholesky [tiles](tile16.md).
35+
- `@qd.dataclass` is used to create structured element types for field tensors. We also use it to create the Cholesky [tiles](tile.md).
3636

3737
## dataclasses.dataclass
3838

docs/source/user_guide/subgroup.md

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ Every op above has a paired `_tiled` form that takes an extra `log2_size` templa
8484

8585
The SPV-only no-arg reductions (`subgroup.reduce_mul` / `reduce_and` / `reduce_or` / `reduce_xor`, plus the original `reduce_add_tiled(value)` with no `log2_size`) have been removed in favour of the portable sized API. For reductions other than the ones listed above, build a sized helper on top of `shuffle_down` / `shuffle` following the same pattern as `reduce_add_tiled` / `reduce_all_add_tiled`.
8686

87+
### Sorting
88+
89+
In-register key/value sort across the subgroup, one `(key, value)` pair per lane. Pure `shuffle` -no shared memory, no barriers -fully unrolled at compile time.
90+
91+
| Op | CUDA | AMDGPU | SPIR-V (Vulkan / Metal) | dtypes |
92+
|----------------------------------------|------|--------|-------------------------|-----------------------------------------------------------------|
93+
| `subgroup.bitonic_sort_kv(key, value)` | yes | yes | yes | key & value: i32, u32, f32, f64, i64, u64 (independently typed) |
94+
95+
Returns `(key, value)` - assign with `key, value = subgroup.bitonic_sort_kv(key, value)`. Sorts ascending on the `(key, value)` lex tuple; ties on `key` break on ascending `value` (not a textbook-stable sort - equal-keyed lanes come back in ascending-`value` order, not in original-lane order). Tiled variant: `bitonic_sort_kv_tiled(key, value, log2_size)` runs the same sort independently on each `2**log2_size`-aligned tile - see [Tiled variants](#tiled-variants). See [`bitonic_sort_kv`](#bitonic_sort_kvkey-value) for the short-input pattern (sentinel padding), the textbook-stability caveat, and the float NaN behaviour. See [Bitonic key/value sort example](#bitonic-keyvalue-sort-example) for an example.
96+
8797
## Semantics
8898

8999
All of these ops operate within a single subgroup: they do not move data through memory and do not synchronise across subgroups.
@@ -193,6 +203,7 @@ Why it composes exactly: the underlying `subgroup.shuffle` / `subgroup.shuffle_d
193203
| `subgroup.segmented_reduce_{min,max}_tiled(v, head_flag, log2_size)` | broadcast-to-tile |
194204
| `subgroup.inclusive_{add,mul,min,max,and,or,xor}_tiled(v, log2_size)` | broadcast-to-tile |
195205
| `subgroup.exclusive_{add,mul,min,max,and,or,xor}_tiled(v, log2_size)` | broadcast-to-tile |
206+
| `subgroup.bitonic_sort_kv_tiled(key, value, log2_size)` | broadcast-to-tile (every lane in the tile holds its sorted-position pair) |
196207

197208
- **Broadcast-to-tile forms**: every lane in each tile holds that tile's result. Lanes in different tiles hold different results (their own tile's).
198209
- **Tile-local lane-0 forms**: only the *tile-local* lane 0 holds the reduction. That's lane 0 alone with `log2_size=5` on wave32, lanes 0 and 32 with `log2_size=5` on wave64, lanes 0 / 16 / 32 / 48 with `log2_size=4` on wave64, etc. Other lanes hold partial reductions and should be treated as undefined. Use the `reduce_all_*_tiled` counterparts if you want every lane to see its tile's result.
@@ -302,6 +313,22 @@ Per-lane exclusive scan across the entire subgroup, under the binary operator na
302313
- The shared `_exclusive_scan_tiled` helper runs the inclusive scan, shifts the result up by one lane via `shuffle_up`, and substitutes the identity at lane 0. The lane-0 substitution is required because `shuffle_up` with offset 1 is implementation-defined at lane 0 (and `OpGroupNonUniformShuffleUp` calls it undefined outright).
303314
- AMDGPU performance note (`*` in the table): same `ds_bpermute` cost as `shuffle_up`. Cost is one inclusive scan plus one extra `shuffle_up` and a select.
304315

316+
### `bitonic_sort_kv(key, value)`
317+
318+
In-register ascending lex sort on `(key, value)` pairs across the subgroup, one `(key, value)` pair per lane. Returns `(key, value)` - the lex-smallest pair on lane 0, the next on lane 1, ..., the lex-largest on lane `group_size() - 1`. Tiled variant: `bitonic_sort_kv_tiled(key, value, log2_size)` runs the same sort independently on each `2**log2_size`-aligned tile - see [Tiled variants](#tiled-variants).
319+
320+
- Sorts ascending on `key`; ties on `key` break on ascending `value` (the comparison is the lex compare `(key, value) < (key', value')`).
321+
- `key` and `value` should be scalar values. Supported dtypes are `i32`, `u32`, `f32`, `f64`, `i64`, `u64`. The dtypes of `key` and `value` do not have to match.
322+
- Implementation: classic 1-D bitonic sorting network.
323+
324+
#### Float NaN handling
325+
326+
Float NaN handling is implementation-defined: comparisons with NaN return false on most backends, so a NaN-keyed lane drifts to an arbitrary position within the sorted tile and the result loses its "sorted" guarantee. Bit-cast the key to a same-width integer dtype if you need a portable NaN-respecting order.
327+
328+
#### Short-input pattern
329+
330+
When sorting fewer than `2**log2_size` real elements, load real data into the low `n` lanes, initialise the high lanes with a sentinel `key` that compares greater than every real key (`+inf` for floats, `INT_MAX` / `UINT_MAX` for ints) and any safe `value`, then ignore the high lanes in the result.
331+
305332
### `ballot_first_n(predicate, n)`
306333

307334
Returns a `u32` bitmask whose bit `i` is set iff `i < n` AND lane `i`'s `predicate` is non-zero. Bits `>= n` are always zero.
@@ -364,7 +391,7 @@ Closed-form `u32` lane-mask constants parametrised by a lane id. Bit `i` of the
364391

365392
## Examples
366393

367-
### Broadcast lane 0 to all lanes
394+
### Broadcast lane 0 to all lanes example
368395

369396
```python
370397
import quadrants as qd
@@ -379,7 +406,7 @@ def broadcast(a: qd.types.ndarray(dtype=qd.f32, ndim=1)):
379406

380407
After the kernel, every lane in a subgroup holds the original value of its lane 0. `subgroup.broadcast(a[i], qd.u32(0))` is interchangeable here.
381408

382-
### Ballot: count how many lanes satisfy a condition
409+
### Ballot: count how many lanes satisfy a condition example
383410

384411
```python
385412
@qd.kernel
@@ -394,7 +421,7 @@ def count_positive(a: qd.types.ndarray(dtype=qd.f32, ndim=1),
394421

395422
After the kernel, `counts[g]` contains a bitmask of which lanes in group `g` had positive values. Use `popcount(mask)` on the host to get the count.
396423

397-
### Identity shuffle (each lane reads its own id)
424+
### Identity shuffle (each lane reads its own id) example
398425

399426
Useful as a sanity check:
400427

@@ -410,7 +437,7 @@ def identity(src: qd.types.ndarray(dtype=qd.f32, ndim=1),
410437

411438
`dst[i]` equals `src[i]` on every lane.
412439

413-
### Swap neighbours (xor pattern via explicit lane)
440+
### Swap neighbours (xor pattern via explicit lane) example
414441

415442
```python
416443
@qd.kernel
@@ -424,7 +451,7 @@ def swap_pairs(src: qd.types.ndarray(dtype=qd.f32, ndim=1),
424451

425452
Pairs `(0,1)`, `(2,3)`, ... swap their values.
426453

427-
### Arbitrary per-lane gather
454+
### Arbitrary per-lane gather example
428455

429456
```python
430457
@qd.kernel
@@ -440,7 +467,7 @@ def reverse4(src: qd.types.ndarray(dtype=qd.f32, ndim=1),
440467

441468
Within each group of 4 contiguous lanes the values are reversed.
442469

443-
### Tree reduction with `shuffle_down`
470+
### Tree reduction with `shuffle_down` example
444471

445472
Classic warp-level sum of 4 values - after the second step, lane 0 of each group of 4 holds the total:
446473

@@ -458,7 +485,7 @@ def reduce4(src: qd.types.ndarray(dtype=qd.f32, ndim=1),
458485

459486
Extend the pattern (offsets 16, 8, 4, 2, 1, ...) to reduce a full subgroup; only lane 0's final value is meaningful, because the lanes near the top read past the end of the subgroup.
460487

461-
### Sum 32 lanes with `reduce_add_tiled`
488+
### Sum 32 lanes with `reduce_add_tiled` example
462489

463490
The same tree, packaged as a one-liner. Lane 0 of each group of 32 holds the total; other lanes hold partial sums:
464491

@@ -475,7 +502,7 @@ def sum32(src: qd.types.ndarray(dtype=qd.f32, ndim=1),
475502

476503
`5` is `log2_size`; `2**5 == 32` matches the block dim. The body of `reduce_add_tiled` unrolls at compile time into five `shuffle_down + add` pairs, so the generated IR is identical to a hand-written tree reduction.
477504

478-
### Broadcast the sum to all lanes with `reduce_all_add_tiled`
505+
### Broadcast the sum to all lanes with `reduce_all_add_tiled` example
479506

480507
When every lane needs the reduction result - e.g. to normalise by the sum - use the butterfly variant. No follow-up broadcast needed:
481508

@@ -490,11 +517,41 @@ def normalize32(a: qd.types.ndarray(dtype=qd.f32, ndim=1)):
490517

491518
Every lane in each group of 32 sees the same `total`.
492519

493-
### Partial-subgroup reductions
520+
### Partial-subgroup reductions example
494521

495522
`log2_size` does not have to match the full subgroup. Sum groups of 8 with `reduce_add_tiled(v, 3)` or groups of 16 with `reduce_all_add_tiled(v, 4)`; the caller just ensures `2**log2_size <= group_size()` (so `log2_size <= 5` on CUDA / Metal / Vulkan-wave32, `<= 6` on AMDGPU wave64). Use the bare `reduce_add(v)` / `reduce_all_add(v)` form when you want "the whole subgroup" without hard-coding the limit.
496523

497-
### Inclusive scan with `inclusive_add_tiled`
524+
### Bitonic key/value sort example
525+
526+
Sort up to 32 `(key, value)` pairs in registers, one per lane, with `bitonic_sort_kv`. The pattern below is the contact-pruning sort used by Genesis: each lane carries a packed link-pair id (`key`) and a contact index (`value`); after the sort, lane `i` holds the `i`-th smallest pair under the lex order `(key, value)`. Lanes past the real data (`lane >= n_con`) carry a sentinel key (`+inf` here) that the sort moves to the tail:
527+
528+
```python
529+
@qd.kernel
530+
def sort_contacts(keys: qd.types.ndarray(dtype=qd.f32, ndim=1),
531+
idxs: qd.types.ndarray(dtype=qd.i32, ndim=1),
532+
n_con: qd.i32):
533+
qd.loop_config(block_dim=32)
534+
for tid in range(32):
535+
# Load real data into the low n_con lanes; sentinel-pad the rest. +inf compares greater than every real
536+
# key, so the sentinels drift to the high end of the sort.
537+
my_key = qd.f32(1.0e30)
538+
my_idx = qd.i32(-1)
539+
if tid < n_con:
540+
my_key = keys[tid]
541+
my_idx = idxs[tid]
542+
543+
my_key, my_idx = subgroup.bitonic_sort_kv(my_key, my_idx)
544+
545+
if tid < n_con:
546+
keys[tid] = my_key
547+
idxs[tid] = my_idx
548+
```
549+
550+
After the kernel, `keys[0..n_con]` is sorted ascending and `idxs` is the matching permutation. The body unrolls at compile time into 30 shuffles + lex compares (for `log2_size = 5`, the wave32 default); no shared memory, no barriers.
551+
552+
Use `bitonic_sort_kv_tiled(k, v, log2_size)` directly to run multiple independent sorts per subgroup - e.g. `bitonic_sort_kv_tiled(k, v, 3)` runs `group_size() / 8` independent 8-element sorts in parallel. The tiles are `2**log2_size`-aligned within the subgroup and do not interact.
553+
554+
### Inclusive scan with `inclusive_add_tiled` example
498555

499556
```python
500557
@qd.kernel
@@ -533,6 +590,7 @@ One subtlety worth knowing about (mostly for anyone reading the generated IR): t
533590
- Pick `reduce_all_add` over `reduce_add + broadcast` when you need the result in every lane - same cost, one fewer shuffle.
534591
- 64-bit dtypes (`i64`, `u64`, `f64`) are emulated as two 32-bit shuffles on AMDGPU. Prefer 32-bit values when you have a choice.
535592
- All seven `inclusive_*` ops are `@qd.func` Hillis-Steele scans; cost is exactly `log2_group_size()` shuffle+op pairs, the same as a hand-rolled CUDA warp scan, on every backend. Hardware-accelerated `OpGroupNonUniformInclusiveScan` on SPIR-V is no longer used - the cost difference vs. a portable shuffle tree is small in practice, and the uniform implementation makes performance predictable across CUDA, AMDGPU, and SPIR-V.
593+
- `bitonic_sort_kv` runs the standard 1-D bitonic schedule: `log2_size * (log2_size + 1) / 2` compare-exchange stages, each two `shuffle` ops + a lex compare + a predicated assignment. Total: `log2_size * (log2_size + 1)` shuffles - 30 for `log2_size = 5` (wave32), 42 for `log2_size = 6` (wave64). All compile-time unrolled into the calling kernel's IR; no shared memory, no barriers within the sort.
536594

537595
## Related
538596

0 commit comments

Comments
 (0)