feat(gpu): implement shuffle#3472
Conversation
0ae7754 to
f6b8be2
Compare
andrei-stoian-zama
left a comment
There was a problem hiding this comment.
Should the API work with GpuFheXYArray ?
Yes but I have to write the HL api before, I'll do it after the merge of the CPU shuffle 👍 |
9c4dd42 to
67dad3e
Compare
|
✅ Backward-compat snapshot: everything looks good! No backward-compatibility issues detected. |
andrei-stoian-zama
left a comment
There was a problem hiding this comment.
Looks good but please remove:
Two dead functions:
- batched_tree_sign_reduction — only called from host_batched_unsigned_comparison
- host_batched_unsigned_comparison — never called (explicitly flagged "Unused" in its own comment)
Nine dead fields on int_bitonic_sort_buffer, each with alloc + release:
- batch_cmp_packed — only in host_batched_unsigned_comparison
- batch_cmp_comparisons — only in host_batched_unsigned_comparison
- batch_identity_lut — only in host_batched_unsigned_comparison
- batch_is_non_zero_lut — only in host_batched_unsigned_comparison
- batch_cmp_tree_x — only in batched_tree_sign_reduction
- batch_cmp_tree_y — only in batched_tree_sign_reduction
- batch_inner_tree_leaf_lut — only in batched_tree_sign_reduction
- batch_last_tree_leaf_lut — only in batched_tree_sign_reduction
- preallocated_h_lut — only in batched_tree_sign_reduction
|
For a next PR:
|
There was a problem hiding this comment.
I'd also like to have a better explanation of the algorithm in the code comments.
I asked Claude to produce pseudo code. Could you confirm this pseudo code is correct and then break it up and copy it into comments next to the code blocks that implement it ?
High-Level Pseudocode
FUNCTION bitonic_shuffle_with_keys(keys[], data[], n):
# Pad to next power of 2
padded_n = next_power_of_2(n)
FOR i IN [n, padded_n):
keys[i] = MAX_VALUE # sentinel: always sorts to end
data[i] = 0
# Bitonic network
FOR k = 2, 4, 8, ... while k <= padded_n:
FOR j = k/2, k/4, ... while j >= 1:
bitonic_substep(keys, data, padded_n, k, j)
RETURN keys[0..n], data[0..n] # drop sentinels
FUNCTION bitonic_substep(keys[], data[], n, k, j):
# Step 1: Compare all pairs in parallel (one PBS per block)
FOR each i where (i XOR j) > i:
l = i XOR j
ascending = ((i AND k) == 0)
sign[i] = FHE_compare(keys[i], keys[l]) # → {INF, EQ, SUP}
# Step 2: Conditional swap keys (batched CMUX)
FOR each pair (i, l):
should_swap = ascending ? (sign == SUP) : (sign == INF)
(keys[i], keys[l]) = CMUX(should_swap,
(keys[l], keys[i]), # swapped
(keys[i], keys[l])) # unchanged
# Step 3: Same CMUX for data (reuse comparison result)
FOR each pair (i, l):
(data[i], data[l]) = CMUX(should_swap,
(data[l], data[i]),
(data[i], data[l]))
---
CMUX (Conditional Multiplexing) — the core primitive
Since we can't branch on encrypted values, swaps are done via:
FUNCTION CMUX(condition, true_val, false_val):
# condition ∈ {INF=0, EQ=1, SUP=2}
# Bivariate PBS: zero out the losing branch
out_true = bivariate_PBS(true_val, condition,
LUT: (b, c) -> if c == SUP then b else 0)
out_false = bivariate_PBS(false_val, condition,
LUT: (b, c) -> if c != SUP then b else 0)
# Add: exactly one branch is nonzero
result = HE_add(out_true, out_false)
result = message_extract(result) # clean up noise
RETURN result
|
Could you please also look into refactoring CMUX : could the FheUintXY Cmux call into the batched version with a batch of 1 ? then a single function could be used in shuffle and the single value cmux operation |
6617bbb to
b096209
Compare
16535f9 to
acfa092
Compare
9179e98 to
d4d3465
Compare
This PR contains...
the new feature FHE bitonic_shuffle:
Benchmarks: