Skip to content

feat(gpu): implement shuffle#3472

Draft
enzodimaria wants to merge 1 commit into
mainfrom
edm/shuffle
Draft

feat(gpu): implement shuffle#3472
enzodimaria wants to merge 1 commit into
mainfrom
edm/shuffle

Conversation

@enzodimaria
Copy link
Copy Markdown
Contributor

@enzodimaria enzodimaria commented Apr 14, 2026

This PR contains...

the new feature FHE bitonic_shuffle:

Benchmarks:

  +----------+-----------+-------------------------------------+----------+---------+
  | Size (n) | Parameter | Operation                           | GPU      | CPU     |
  +----------+-----------+-------------------------------------+----------+---------+
  | 8        | MULTIBIT  | unchecked_bitonic_shuffle_with_keys | 266 ms   | -       |
  |          |           | bitonic_shuffle                     | 271 ms   | -       |
  |          |           | OPRF (estimated)                    | 5 ms     | -       |
  |          |-----------+-------------------------------------+----------+---------+
  |          | CLASSICAL | unchecked_bitonic_shuffle_with_keys | 301 ms   | 2000 ms |
  |          |           | bitonic_shuffle                     | 309 ms   | 2100 ms |
  |          |           | OPRF (estimated)                    | 8 ms     | 100 ms  |
  +----------+-----------+-------------------------------------+----------+---------+
  | 16       | MULTIBIT  | unchecked_bitonic_shuffle_with_keys | 826 ms   | -       |
  |          |           | bitonic_shuffle                     | 836 ms   | -       |
  |          |           | OPRF (estimated)                    | 10 ms    | -       |
  |          |-----------+-------------------------------------+----------+---------+
  |          | CLASSICAL | unchecked_bitonic_shuffle_with_keys | 824 ms   | 4800 ms |
  |          |           | bitonic_shuffle                     | 839 ms   | 5100 ms |
  |          |           | OPRF (estimated)                    | 15 ms    | 300 ms  |
  +----------+-----------+-------------------------------------+----------+---------+
  | 32       | MULTIBIT  | unchecked_bitonic_shuffle_with_keys | 2420 ms  | -       |
  |          |           | bitonic_shuffle                     | 2442 ms  | -       |
  |          |           | OPRF (estimated)                    | 22 ms    | -       |
  |          |-----------+-------------------------------------+----------+---------+
  |          | CLASSICAL | unchecked_bitonic_shuffle_with_keys | 2246 ms  | 12500 ms|
  |          |           | bitonic_shuffle                     | 2279 ms  | 13000 ms|
  |          |           | OPRF (estimated)                    | 33 ms    | 500 ms  |
  +----------+-----------+-------------------------------------+----------+---------+
  | 64       | MULTIBIT  | unchecked_bitonic_shuffle_with_keys | 6704 ms  | -       |
  |          |           | bitonic_shuffle                     | 6741 ms  | -       |
  |          |           | OPRF (estimated)                    | 37 ms    | -       |
  |          |-----------+-------------------------------------+----------+---------+
  |          | CLASSICAL | unchecked_bitonic_shuffle_with_keys | 5988 ms  | -       |
  |          |           | bitonic_shuffle                     | 6069 ms  | -       |
  |          |           | OPRF (estimated)                    | 81 ms    | -       |
  +----------+-----------+-------------------------------------+----------+---------+

1 x H100-SXM

@cla-bot cla-bot Bot added the cla-signed label Apr 14, 2026
@enzodimaria enzodimaria force-pushed the edm/shuffle branch 6 times, most recently from 0ae7754 to f6b8be2 Compare April 15, 2026 13:31
Copy link
Copy Markdown
Contributor

@andrei-stoian-zama andrei-stoian-zama left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the API work with GpuFheXYArray ?

@enzodimaria
Copy link
Copy Markdown
Contributor Author

Should the API work with GpuFheXYArray ?

Yes but I have to write the HL api before, I'll do it after the merge of the CPU shuffle 👍

@enzodimaria enzodimaria marked this pull request as ready for review April 24, 2026 07:44
@enzodimaria enzodimaria marked this pull request as draft April 24, 2026 07:44
@enzodimaria enzodimaria force-pushed the edm/shuffle branch 5 times, most recently from 9c4dd42 to 67dad3e Compare April 29, 2026 14:38
@github-actions
Copy link
Copy Markdown

Backward-compat snapshot: everything looks good! No backward-compatibility issues detected.

Copy link
Copy Markdown
Contributor

@andrei-stoian-zama andrei-stoian-zama left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good but please remove:

Two dead functions:

  • batched_tree_sign_reduction — only called from host_batched_unsigned_comparison
  • host_batched_unsigned_comparison — never called (explicitly flagged "Unused" in its own comment)

Nine dead fields on int_bitonic_sort_buffer, each with alloc + release:

  • batch_cmp_packed — only in host_batched_unsigned_comparison
  • batch_cmp_comparisons — only in host_batched_unsigned_comparison
  • batch_identity_lut — only in host_batched_unsigned_comparison
  • batch_is_non_zero_lut — only in host_batched_unsigned_comparison
  • batch_cmp_tree_x — only in batched_tree_sign_reduction
  • batch_cmp_tree_y — only in batched_tree_sign_reduction
  • batch_inner_tree_leaf_lut — only in batched_tree_sign_reduction
  • batch_last_tree_leaf_lut — only in batched_tree_sign_reduction
  • preallocated_h_lut — only in batched_tree_sign_reduction

Comment thread backends/tfhe-cuda-backend/cuda/include/integer/shuffle_utilities.h Outdated
Comment thread backends/tfhe-cuda-backend/cuda/include/integer/shuffle_utilities.h Outdated
Comment thread tfhe/src/integer/gpu/server_key/radix/shuffle.rs
Comment thread backends/tfhe-cuda-backend/cuda/src/integer/shuffle.cuh Outdated
@andrei-stoian-zama
Copy link
Copy Markdown
Contributor

andrei-stoian-zama commented May 12, 2026

For a next PR:

  • batch the comparisons (based on initial work in host_batched_unsigned_comparison)
  • the cmux implemented here (shuffle.cuh:283) performs a "message extract" step to clean noise. but we might not need to clean the noise since the subsequent comparison first performs a subtraction, then cleans the noise (so it can pack) - only works for message_modulus==carry_modulus
  • explore if a single kernel can perform the various copy operations that are done on loops atm

Copy link
Copy Markdown
Contributor

@andrei-stoian-zama andrei-stoian-zama left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also like to have a better explanation of the algorithm in the code comments.

I asked Claude to produce pseudo code. Could you confirm this pseudo code is correct and then break it up and copy it into comments next to the code blocks that implement it ?


High-Level Pseudocode

  FUNCTION bitonic_shuffle_with_keys(keys[], data[], n):
      # Pad to next power of 2
      padded_n = next_power_of_2(n)
      FOR i IN [n, padded_n):
          keys[i] = MAX_VALUE   # sentinel: always sorts to end
          data[i] = 0

      # Bitonic network
      FOR k = 2, 4, 8, ... while k <= padded_n:
          FOR j = k/2, k/4, ... while j >= 1:
              bitonic_substep(keys, data, padded_n, k, j)

      RETURN keys[0..n], data[0..n]   # drop sentinels


  FUNCTION bitonic_substep(keys[], data[], n, k, j):
      # Step 1: Compare all pairs in parallel (one PBS per block)
      FOR each i where (i XOR j) > i:
          l = i XOR j
          ascending = ((i AND k) == 0)
          sign[i] = FHE_compare(keys[i], keys[l])   # → {INF, EQ, SUP}
     # Step 2: Conditional swap keys (batched CMUX)
      FOR each pair (i, l):
          should_swap = ascending ? (sign == SUP) : (sign == INF)
          (keys[i], keys[l]) = CMUX(should_swap,
                                     (keys[l], keys[i]),   # swapped
                                     (keys[i], keys[l]))   # unchanged

      # Step 3: Same CMUX for data (reuse comparison result)
      FOR each pair (i, l):
          (data[i], data[l]) = CMUX(should_swap,
                                     (data[l], data[i]),
                                     (data[i], data[l]))

  ---
  CMUX (Conditional Multiplexing) — the core primitive

  Since we can't branch on encrypted values, swaps are done via:

  FUNCTION CMUX(condition, true_val, false_val):
     # condition ∈ {INF=0, EQ=1, SUP=2}

      # Bivariate PBS: zero out the losing branch
      out_true  = bivariate_PBS(true_val,  condition,
                      LUT: (b, c) -> if c == SUP then b else 0)
      out_false = bivariate_PBS(false_val, condition,
                      LUT: (b, c) -> if c != SUP then b else 0)

      # Add: exactly one branch is nonzero
      result = HE_add(out_true, out_false)
      result = message_extract(result)    # clean up noise
      RETURN result

@andrei-stoian-zama
Copy link
Copy Markdown
Contributor

Could you please also look into refactoring CMUX : could the FheUintXY Cmux call into the batched version with a batch of 1 ? then a single function could be used in shuffle and the single value cmux operation

@enzodimaria enzodimaria force-pushed the edm/shuffle branch 2 times, most recently from 6617bbb to b096209 Compare May 19, 2026 08:31
@enzodimaria enzodimaria force-pushed the edm/shuffle branch 3 times, most recently from 16535f9 to acfa092 Compare May 19, 2026 14:18
@enzodimaria enzodimaria marked this pull request as ready for review May 20, 2026 08:04
@enzodimaria enzodimaria marked this pull request as draft May 20, 2026 08:19
@enzodimaria enzodimaria force-pushed the edm/shuffle branch 3 times, most recently from 9179e98 to d4d3465 Compare May 20, 2026 13:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants