docs/source/user_guide/compound_types.md
Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ It's of course very subjective, but some guidelines you could consider:
- if you are trying to write a python class that runs on the GPU => use a `@qd.data_oriented`
- if you are trying to write typed dataclasses, for passing data around between the `@data_oriented` classes, and between methods of the same `@data_oriented` class => use `@dataclasses.dataclass`es
-- `@qd.dataclass` is used to create structured element types for field tensors. We also use it to create the Cholesky [tiles](tile16.md).
+- `@qd.dataclass` is used to create structured element types for field tensors. We also use it to create the Cholesky [tiles](tile.md).
docs/source/user_guide/subgroup.md
Lines changed: 68 additions & 10 deletions
@@ -84,6 +84,16 @@ Every op above has a paired `_tiled` form that takes an extra `log2_size` templa
The SPV-only no-arg reductions (`subgroup.reduce_mul` / `reduce_and` / `reduce_or` / `reduce_xor`, plus the original `reduce_add_tiled(value)` with no `log2_size`) have been removed in favour of the portable sized API. For reductions other than the ones listed above, build a sized helper on top of `shuffle_down` / `shuffle` following the same pattern as `reduce_add_tiled` / `reduce_all_add_tiled`.
+### Sorting
+
+In-register key/value sort across the subgroup, one `(key, value)` pair per lane. Pure `shuffle` - no shared memory, no barriers - fully unrolled at compile time.
+
+| Op | CUDA | AMDGPU | SPIR-V (Vulkan / Metal) | dtypes |
+Returns `(key, value)` - assign with `key, value = subgroup.bitonic_sort_kv(key, value)`. Sorts ascending on the `(key, value)` lex tuple; ties on `key` break on ascending `value` (not a textbook-stable sort - equal-keyed lanes come back in ascending-`value` order, not in original-lane order). Tiled variant: `bitonic_sort_kv_tiled(key, value, log2_size)` runs the same sort independently on each `2**log2_size`-aligned tile - see [Tiled variants](#tiled-variants). See [`bitonic_sort_kv`](#bitonic_sort_kvkey-value) for the short-input pattern (sentinel padding), the textbook-stability caveat, and the float NaN behaviour. See [Bitonic key/value sort example](#bitonic-keyvalue-sort-example) for an example.
+
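To make the tie-breaking rule concrete, here is a host-side reference model in plain Python. It is not Quadrants code: `reference_sort_kv` is a hypothetical helper, and it only encodes the documented result (ascending on the `(key, value)` lex tuple, one pair per lane), not how the in-register sort is implemented.

```python
def reference_sort_kv(keys, values):
    """Host-side model: lane i of the result holds the i-th smallest (key, value) pair."""
    pairs = sorted(zip(keys, values))  # Python tuples compare lexicographically
    sorted_keys = [k for k, _ in pairs]
    sorted_values = [v for _, v in pairs]
    return sorted_keys, sorted_values


# Lanes 0 and 2 share key 7. Their values come back in ascending order
# (3 before 9), not in original-lane order (9 was on lane 0).
keys, values = reference_sort_kv([7, 2, 7, 5], [9, 1, 3, 4])
```

A stable sort on `key` alone would have returned the two key-7 lanes as `9, 3`; the lex compare returns them as `3, 9`.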
## Semantics
All of these ops operate within a single subgroup: they do not move data through memory and do not synchronise across subgroups.
@@ -193,6 +203,7 @@ Why it composes exactly: the underlying `subgroup.shuffle` / `subgroup.shuffle_d
+| `subgroup.bitonic_sort_kv_tiled(key, value, log2_size)` | broadcast-to-tile (every lane in the tile holds its sorted-position pair) |

- **Broadcast-to-tile forms**: every lane in each tile holds that tile's result. Lanes in different tiles hold different results (their own tile's).
- **Tile-local lane-0 forms**: only the *tile-local* lane 0 holds the reduction. That's lane 0 alone with `log2_size=5` on wave32, lanes 0 and 32 with `log2_size=5` on wave64, lanes 0 / 16 / 32 / 48 with `log2_size=4` on wave64, etc. Other lanes hold partial reductions and should be treated as undefined. Use the `reduce_all_*_tiled` counterparts if you want every lane to see its tile's result.
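The lane numbers quoted above follow from simple arithmetic: tile-local lane 0 is every lane whose index is a multiple of `2**log2_size`. A small host-side sketch (plain Python, with a hypothetical helper name `tile_leader_lanes`) reproduces them:

```python
def tile_leader_lanes(group_size, log2_size):
    """Return the subgroup lane ids that are tile-local lane 0 for the given tile size."""
    tile = 1 << log2_size
    if tile > group_size:
        raise ValueError("2**log2_size must not exceed the subgroup size")
    return list(range(0, group_size, tile))


wave32_full = tile_leader_lanes(32, 5)   # one tile spanning the whole subgroup
wave64_half = tile_leader_lanes(64, 5)   # two tiles of 32
wave64_quarter = tile_leader_lanes(64, 4)  # four tiles of 16
```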
@@ -302,6 +313,22 @@ Per-lane exclusive scan across the entire subgroup, under the binary operator na
- The shared `_exclusive_scan_tiled` helper runs the inclusive scan, shifts the result up by one lane via `shuffle_up`, and substitutes the identity at lane 0. The lane-0 substitution is required because `shuffle_up` with offset 1 is implementation-defined at lane 0 (and `OpGroupNonUniformShuffleUp` calls it undefined outright).
- AMDGPU performance note (`*` in the table): same `ds_bpermute` cost as `shuffle_up`. Cost is one inclusive scan plus one extra `shuffle_up` and a select.
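The inclusive-scan, shift, and substitute recipe can be modelled on the host. The sketch below is plain Python, not the library's `_exclusive_scan_tiled`; the helper names are hypothetical, and the `fill` argument stands in for whatever implementation-defined value a real `shuffle_up` would deliver to the low lanes.

```python
def shuffle_up(values, offset, fill):
    """Model of shuffle_up: lane i reads lane i - offset; low lanes receive `fill`."""
    return [values[i - offset] if i >= offset else fill for i in range(len(values))]


def inclusive_add(values):
    """Hillis-Steele inclusive scan: one shuffle_up + add per doubling of the offset."""
    out = list(values)
    offset = 1
    while offset < len(out):
        shifted = shuffle_up(out, offset, fill=0)
        # Lanes below `offset` have no valid source and keep their current value.
        out = [out[i] + shifted[i] if i >= offset else out[i] for i in range(len(out))]
        offset *= 2
    return out


def exclusive_add(values, identity=0):
    """Inclusive scan, shift up by one lane, then substitute the identity at lane 0."""
    inclusive = inclusive_add(values)
    shifted = shuffle_up(inclusive, 1, fill=None)  # lane 0 is not trustworthy here
    return [identity if lane == 0 else shifted[lane] for lane in range(len(values))]


result = exclusive_add([3, 1, 4, 1, 5, 9, 2, 6])
```

The explicit lane-0 branch is the point of the sketch: without it, lane 0 would carry whatever the shuffle happened to return.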
+### `bitonic_sort_kv(key, value)`
+
+In-register ascending lex sort on `(key, value)` pairs across the subgroup, one `(key, value)` pair per lane. Returns `(key, value)` - the lex-smallest pair on lane 0, the next on lane 1, ..., the lex-largest on lane `group_size() - 1`. Tiled variant: `bitonic_sort_kv_tiled(key, value, log2_size)` runs the same sort independently on each `2**log2_size`-aligned tile - see [Tiled variants](#tiled-variants).
+
+- Sorts ascending on `key`; ties on `key` break on ascending `value` (the comparison is the lex compare `(key, value) < (key', value')`).
+- `key` and `value` should be scalar values. Supported dtypes are `i32`, `u32`, `f32`, `f64`, `i64`, `u64`. The dtypes of `key` and `value` do not have to match.
+Float NaN handling is implementation-defined: comparisons with NaN return false on most backends, so a NaN-keyed lane drifts to an arbitrary position within the sorted tile and the result loses its "sorted" guarantee. Bit-cast the key to a same-width integer dtype if you need a portable NaN-respecting order.
+
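One way to build such an integer key is sketched below in host-side Python. Note the assumption: a raw bit-cast keeps non-negative floats in order but reverses negative ones, so this sketch layers the commonly used sign-flip mapping on top of the bit-cast. That mapping is not taken from the source, and `f32_to_ordered_u32` is a hypothetical helper name.

```python
import math
import struct


def f32_to_ordered_u32(x):
    """Map an f32 value to a u32 whose unsigned order matches the float order."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # bit-cast f32 -> u32
    if bits & 0x80000000:
        return bits ^ 0xFFFFFFFF  # negative: flip every bit so larger magnitude sorts lower
    return bits | 0x80000000      # non-negative: set the top bit so it sorts above negatives


positive_nan = math.copysign(math.nan, 1.0)
raw_keys = [1.5, -2.0, positive_nan, 0.0, -0.5, math.inf]
ordered = sorted(raw_keys, key=f32_to_ordered_u32)
```

Under this mapping a positive NaN lands after `+inf` and a negative NaN before `-inf`, so every lane has a well-defined position.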
+#### Short-input pattern
+
+When sorting fewer than `2**log2_size` real elements, load real data into the low `n` lanes, initialise the high lanes with a sentinel `key` that compares greater than every real key (`+inf` for floats, `INT_MAX` / `UINT_MAX` for ints) and any safe `value`, then ignore the high lanes in the result.
+
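The padding pattern can be sketched on the host as follows. This is plain Python with a hypothetical helper `sort_short_input`; Python's `sorted` stands in for the in-register sort, and the padding value `0` is an arbitrary choice of "safe" value.

```python
import math


def sort_short_input(keys, values, log2_size=5):
    """Pad n real pairs up to 2**log2_size lanes with sentinels, sort, keep the low n lanes."""
    tile = 1 << log2_size
    n = len(keys)
    if n > tile:
        raise ValueError("more real elements than lanes in the tile")
    lane_keys = list(keys) + [math.inf] * (tile - n)  # sentinel greater than every real key
    lane_values = list(values) + [0] * (tile - n)     # any safe value; never read back
    pairs = sorted(zip(lane_keys, lane_values))       # stand-in for the subgroup sort
    real = pairs[:n]                                  # high lanes hold only sentinels
    return [k for k, _ in real], [v for _, v in real]


short_keys, short_values = sort_short_input([3.0, 1.0, 2.0], [10, 11, 12])
```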
### `ballot_first_n(predicate, n)`
Returns a `u32` bitmask whose bit `i` is set iff `i < n` AND lane `i`'s `predicate` is non-zero. Bits `>= n` are always zero.
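As a quick check of that definition, here is a host-side model in plain Python. It takes the per-lane predicates as a list and is only a restatement of the rule above, not the device implementation; `ballot_first_n_model` is a hypothetical name.

```python
def ballot_first_n_model(predicates, n):
    """Bit i is set iff i < n and lane i's predicate is non-zero."""
    mask = 0
    for lane, predicate in enumerate(predicates):
        if lane < n and predicate != 0:
            mask |= 1 << lane
    return mask & 0xFFFFFFFF  # result is a u32


# Lanes 0, 2, 3 and 5 vote true, but only lanes below n = 4 are counted.
mask = ballot_first_n_model([1, 0, 1, 1, 0, 1], 4)
```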
@@ -364,7 +391,7 @@ Closed-form `u32` lane-mask constants parametrised by a lane id. Bit `i` of the
Extend the pattern (offsets 16, 8, 4, 2, 1, ...) to reduce a full subgroup; only lane 0's final value is meaningful, because the lanes near the top read past the end of the subgroup.
-### Sum 32 lanes with `reduce_add_tiled`
+### Sum 32 lanes with `reduce_add_tiled` example
The same tree, packaged as a one-liner. Lane 0 of each group of 32 holds the total; other lanes hold partial sums:
`5` is `log2_size`; `2**5 == 32` matches the block dim. The body of `reduce_add_tiled` unrolls at compile time into five `shuffle_down + add` pairs, so the generated IR is identical to a hand-written tree reduction.
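The five unrolled steps can be modelled on the host to see why only the tile-local lane 0 ends up with the full total. This is a plain-Python sketch with hypothetical helper names; out-of-range reads return the lane's own value purely as a stand-in for the undefined result on hardware.

```python
def shuffle_down(values, offset):
    """Model of shuffle_down: lane i reads lane i + offset (own value if out of range)."""
    n = len(values)
    return [values[i + offset] if i + offset < n else values[i] for i in range(n)]


def reduce_add_tree(values, log2_size):
    """Tree reduction: offsets 2**(log2_size-1), ..., 2, 1, each one shuffle_down + add."""
    out = list(values)
    for step in reversed(range(log2_size)):
        shifted = shuffle_down(out, 1 << step)
        out = [a + b for a, b in zip(out, shifted)]
    return out


lanes = reduce_add_tree(list(range(32)), 5)  # five steps: offsets 16, 8, 4, 2, 1
total = lanes[0]                             # only lane 0 is meaningful
```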
-### Broadcast the sum to all lanes with `reduce_all_add_tiled`
+### Broadcast the sum to all lanes with `reduce_all_add_tiled` example
When every lane needs the reduction result - e.g. to normalise by the sum - use the butterfly variant. No follow-up broadcast needed:
Every lane in each group of 32 sees the same `total`.
-### Partial-subgroup reductions
+### Partial-subgroup reductions example
`log2_size` does not have to match the full subgroup. Sum groups of 8 with `reduce_add_tiled(v, 3)` or groups of 16 with `reduce_all_add_tiled(v, 4)`; the caller just ensures `2**log2_size <= group_size()` (so `log2_size <= 5` on CUDA / Metal / Vulkan-wave32, `<= 6` on AMDGPU wave64). Use the bare `reduce_add(v)` / `reduce_all_add(v)` form when you want "the whole subgroup" without hard-coding the limit.
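A host-side model shows how a smaller `log2_size` splits the subgroup into independent tiles. The sketch assumes the all-lanes variant is an XOR butterfly (lane `i` exchanges with lane `i ^ offset`); that exchange pattern is an assumption for illustration, and `butterfly_reduce_all_add` is a hypothetical name, not the library function.

```python
def butterfly_reduce_all_add(values, log2_size):
    """XOR-butterfly sum: after log2_size steps every lane holds its tile's total."""
    if len(values) % (1 << log2_size) != 0:
        raise ValueError("lane count must be a multiple of 2**log2_size")
    out = list(values)
    for step in range(log2_size):
        offset = 1 << step
        # lane ^ offset never leaves the 2**log2_size-aligned tile because offset < tile size
        out = [out[lane] + out[lane ^ offset] for lane in range(len(out))]
    return out


groups_of_8 = butterfly_reduce_all_add(list(range(32)), 3)
groups_of_16 = butterfly_reduce_all_add(list(range(32)), 4)
```

Each tile sees only its own lanes, so the four groups of 8 produce four different totals.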
-### Inclusive scan with `inclusive_add_tiled`
+### Bitonic key/value sort example
+
+Sort up to 32 `(key, value)` pairs in registers, one per lane, with `bitonic_sort_kv`. The pattern below is the contact-pruning sort used by Genesis: each lane carries a packed link-pair id (`key`) and a contact index (`value`); after the sort, lane `i` holds the `i`-th smallest pair under the lex order `(key, value)`. Lanes past the real data (`lane >= n_con`) carry a sentinel key (`+inf` here) that the sort moves to the tail:
+After the kernel, `keys[0..n_con]` is sorted ascending and `idxs` is the matching permutation. The body unrolls at compile time into 30 shuffles + lex compares (for `log2_size = 5`, the wave32 default); no shared memory, no barriers.
+
+Use `bitonic_sort_kv_tiled(k, v, log2_size)` directly to run multiple independent sorts per subgroup - e.g. `bitonic_sort_kv_tiled(k, v, 3)` runs `group_size() / 8` independent 8-element sorts in parallel. The tiles are `2**log2_size`-aligned within the subgroup and do not interact.
+
+### Inclusive scan with `inclusive_add_tiled` example
```python
@qd.kernel
@@ -533,6 +590,7 @@ One subtlety worth knowing about (mostly for anyone reading the generated IR): t
- Pick `reduce_all_add` over `reduce_add + broadcast` when you need the result in every lane - same cost, one fewer shuffle.
- 64-bit dtypes (`i64`, `u64`, `f64`) are emulated as two 32-bit shuffles on AMDGPU. Prefer 32-bit values when you have a choice.
- All seven `inclusive_*` ops are `@qd.func` Hillis-Steele scans; cost is exactly `log2_group_size()` shuffle+op pairs, the same as a hand-rolled CUDA warp scan, on every backend. Hardware-accelerated `OpGroupNonUniformInclusiveScan` on SPIR-V is no longer used - the cost difference vs. a portable shuffle tree is small in practice, and the uniform implementation makes performance predictable across CUDA, AMDGPU, and SPIR-V.
+- `bitonic_sort_kv` runs the standard 1-D bitonic schedule: `log2_size * (log2_size + 1) / 2` compare-exchange stages, each two `shuffle` ops + a lex compare + a predicated assignment. Total: `log2_size * (log2_size + 1)` shuffles - 30 for `log2_size = 5` (wave32), 42 for `log2_size = 6` (wave64). All compile-time unrolled into the calling kernel's IR; no shared memory, no barriers within the sort.
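The stage count can be checked with a host-side simulation of a bitonic network. This is a plain-Python sketch, not the library source: it assumes the usual XOR-partner schedule, treats each stage's exchange as two shuffles (one for the key, one for the value), and uses the hypothetical name `bitonic_sort_kv_model`.

```python
def bitonic_sort_kv_model(keys, values, log2_size):
    """Simulate a bitonic network over 2**log2_size lanes; return keys, values, stage count."""
    n = 1 << log2_size
    if len(keys) != n or len(values) != n:
        raise ValueError("need exactly one (key, value) pair per lane")
    pairs = list(zip(keys, values))
    stages = 0
    k = 2
    while k <= n:              # size of the bitonic sequences being merged
        j = k // 2
        while j >= 1:          # partner distance within the merge
            next_pairs = []
            for lane in range(n):
                partner = lane ^ j
                mine, theirs = pairs[lane], pairs[partner]  # models the two shuffles
                ascending = (lane & k) == 0
                keep_min = (lane < partner) == ascending
                # Lex compare on (key, value), then a predicated assignment.
                next_pairs.append(min(mine, theirs) if keep_min else max(mine, theirs))
            pairs = next_pairs
            stages += 1
            j //= 2
        k *= 2
    return [p[0] for p in pairs], [p[1] for p in pairs], stages


in_keys = [(i * 7) % 5 for i in range(32)]
in_values = [31 - i for i in range(32)]
out_keys, out_values, stage_count = bitonic_sort_kv_model(in_keys, in_values, 5)
shuffle_count = 2 * stage_count
```

For `log2_size = 5` the simulation runs 15 stages, which matches the 30-shuffle figure quoted above under the two-shuffles-per-stage assumption.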