Conversation

@andrei-stoian-zama
Contributor

@andrei-stoian-zama andrei-stoian-zama commented Oct 16, 2025

  • Uses GEMM-based KS in AES when large batches of LWEs need to be keyswitched
  • Implements GEMM KS with non-trivial indexes
  • Improves KS GPU bench and KS GPU test
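
For intuition, batched keyswitching can be phrased as a single matrix multiply: stacking the decomposed mask digits of a batch of LWEs row-wise gives a (num_lwes) x (n_in * levels) matrix D, and the keyswitching key forms a (n_in * levels) x (n_out + 1) matrix K, so the whole batch is the product D * K. The sketch below is a toy illustration of that idea only, not the tfhe-rs CUDA implementation; the decomposition step is assumed to have already produced the digit matrix, and plain i64 stands in for the Torus type.

```rust
// Toy GEMM over i64 matrices, illustrating why a batch keyswitch maps to one
// matrix multiply. `a` is (m x k) decomposed digits, `b` is (k x n) key material.
fn gemm(a: &[Vec<i64>], b: &[Vec<i64>]) -> Vec<Vec<i64>> {
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut c = vec![vec![0i64; n]; m];
    for i in 0..m {
        for p in 0..k {
            for j in 0..n {
                // Accumulate the contribution of digit p of LWE i to output column j.
                c[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    c
}

fn main() {
    // 2 LWEs, n_in * levels = 3 digits each, n_out + 1 = 2 output columns.
    let decomposed = vec![vec![1, 2, 3], vec![4, 5, 6]];
    let ksk = vec![vec![1, 0], vec![0, 1], vec![1, 1]];
    let out = gemm(&decomposed, &ksk);
    assert_eq!(out, vec![vec![4, 5], vec![10, 11]]);
    println!("{out:?}");
}
```

On GPU, the win is that the whole batch hits one GEMM kernel (e.g. a cuBLAS-style call) instead of one keyswitch launch per ciphertext.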


@cla-bot cla-bot bot added the cla-signed label Oct 16, 2025
@andrei-stoian-zama andrei-stoian-zama force-pushed the as/gemm_ks branch 4 times, most recently from 6dc060b to 60145aa Compare November 10, 2025 10:32
@andrei-stoian-zama andrei-stoian-zama changed the title feat(gpu): use gemm ks for trivial indexes feat(gpu): use gemm ks in HL ops Nov 13, 2025
@andrei-stoian-zama andrei-stoian-zama force-pushed the as/gemm_ks branch 4 times, most recently from 3fa7b3f to 3b28051 Compare November 17, 2025 20:33
@andrei-stoian-zama andrei-stoian-zama marked this pull request as ready for review November 18, 2025 07:44
Member

@IceTDrinker IceTDrinker left a comment

Nitpick comments on the core, non-CUDA part (I did not read the CUDA one), thanks!

@IceTDrinker reviewed 7 of 22 files at r1, all commit messages.
Reviewable status: 7 of 22 files reviewed, 13 unresolved discussions (waiting on @agnesLeroy and @soonum)


tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 79 at r1 (raw file):

    let mut ks_tmp_buffer: *mut ffi::c_void = std::ptr::null_mut();

    let num_lwes_to_ks = min(

Is it possible to only partially keyswitch an input?


tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 84 at r1 (raw file):

    );

    assert_eq!(input_indexes.len, output_indexes.len);

An error message could be welcome here.
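A hedged sketch of what such an assert with a message could look like; `check_index_lengths` is a hypothetical helper, and the real index containers in tfhe-rs may expose their length differently (e.g. as a `len` field on a CUDA vector type):

```rust
// Hypothetical helper: fail with a descriptive message instead of the bare
// `assertion failed: left == right` output of a plain assert_eq!.
fn check_index_lengths(input_indexes: &[u64], output_indexes: &[u64]) {
    assert_eq!(
        input_indexes.len(),
        output_indexes.len(),
        "keyswitch index mismatch: {} input indexes but {} output indexes",
        input_indexes.len(),
        output_indexes.len()
    );
}

fn main() {
    check_index_lengths(&[0, 1, 2], &[4, 5, 6]);
    println!("index lengths match");
}
```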


tfhe-benchmark/benches/core_crypto/ks_bench.rs line 458 at r1 (raw file):

                            let input_ks_list = LweCiphertextList::from_container(
                                input_ct_list.into_container(),
                                big_lwe_sk.lwe_dimension().to_lwe_size(),

The same nits as in the tests will apply here, about key dimensions, ciphertext counts, and the places where things get collected/transformed into Vecs.


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 118 at r1 (raw file):

        msg = msg.wrapping_sub(Scalar::ONE);
        for test_idx in 0..NB_TESTS {
            let num_blocks = test_idx * test_idx * 3 + 1;

Are those magic numbers? Or could they be randomly chosen?


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 139 at r1 (raw file):

                &mut rsc.encryption_random_generator,
            );
            let input_ks_list = LweCiphertextList::from_container(

Why is this required? It looks like it's just recreating the input_ct_list?


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 149 at r1 (raw file):

            let output_ct_list = LweCiphertextList::new(
                Scalar::ZERO,
                lwe_sk.lwe_dimension().to_lwe_size(),

nit: prefer using the dimension of the compute key that will be used (here the ksk), it tends to help with local reasoning


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 150 at r1 (raw file):

                Scalar::ZERO,
                lwe_sk.lwe_dimension().to_lwe_size(),
                LweCiphertextCount(num_blocks),

nit: use input.lwe_ciphertext_count(), again for local reasoning


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 151 at r1 (raw file):

                lwe_sk.lwe_dimension().to_lwe_size(),
                LweCiphertextCount(num_blocks),
                ciphertext_modulus,

nit: again use the output modulus of the compute key (ksk)


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 173 at r1 (raw file):

            };
            let lwe_indexes_usize = (0..num_blocks).collect_vec();
            let mut lwe_indexes = lwe_indexes_usize.iter().collect_vec();

let mut lwe_indexes = lwe_indexes_usize.clone();

?


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 182 at r1 (raw file):

            }

            if num_blocks_to_ks < num_blocks {

No need for the check, you can always do

lwe_indexes = &lwe_indexes[..num_blocks_to_ks];

since I believe num_blocks_to_ks should always be <= num_blocks?


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 184 at r1 (raw file):

            if num_blocks_to_ks < num_blocks {
                lwe_indexes = lwe_indexes[0..num_blocks_to_ks].to_vec();
                lwe_indexes_out = lwe_indexes_out[0..num_blocks_to_ks].to_vec();

I don't think you need the to_vec; you can take a slice like

lwe_indexes = &lwe_indexes[..num_blocks_to_ks];

In fact the whole block above can be folded into the iterator below, I'll give an example.
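As a standalone illustration of the slicing suggestion (names mirror the test, values are made up), borrowing a prefix avoids both the branch and the extra allocation, since slicing to the full length is a no-op:

```rust
fn main() {
    let lwe_indexes_usize: Vec<usize> = (0..8).collect();
    let num_blocks_to_ks = 5; // assumed <= lwe_indexes_usize.len()
    // Borrow a prefix instead of cloning into a new Vec; no `if` needed,
    // because `..len()` simply yields the whole slice.
    let lwe_indexes: &[usize] = &lwe_indexes_usize[..num_blocks_to_ks];
    assert_eq!(lwe_indexes, &[0, 1, 2, 3, 4]);
    println!("{lwe_indexes:?}");
}
```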


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 187 at r1 (raw file):

            }

            let h_lwe_indexes: Vec<Scalar> = lwe_indexes

lwe_indexes.iter().take(num_blocks_to_ks).map(...)
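The iterator version hinted at above could look like the following sketch (values made up, and u64 stands in for the generic Scalar type); it truncates and converts in one pass, without materializing the truncated Vec first:

```rust
fn main() {
    let lwe_indexes: Vec<usize> = (0..8).collect();
    let num_blocks_to_ks = 5;
    // Take only the first num_blocks_to_ks indexes and convert them to the
    // scalar type expected on the device, all in a single collect.
    let h_lwe_indexes: Vec<u64> = lwe_indexes
        .iter()
        .take(num_blocks_to_ks)
        .map(|&idx| idx as u64)
        .collect();
    assert_eq!(h_lwe_indexes, vec![0, 1, 2, 3, 4]);
    println!("{h_lwe_indexes:?}");
}
```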


tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs line 216 at r1 (raw file):

            for i in 0..num_blocks_to_ks {
                ref_vec[*lwe_indexes_out[i]] =
                    round_decode(*plaintext_list.get(*lwe_indexes[i]).0, delta); // % msg_modulus;

The comment can be removed, I think?

@zama-bot zama-bot removed the approved label Nov 18, 2025
@zama-bot zama-bot removed the approved label Nov 19, 2025
@andrei-stoian-zama andrei-stoian-zama changed the title feat(gpu): use gemm ks in HL ops feat(gpu): use gemm ks in AES Nov 21, 2025
@andrei-stoian-zama andrei-stoian-zama force-pushed the as/gemm_ks branch 3 times, most recently from 242b718 to 22e52e6 Compare November 21, 2025 16:31
@IceTDrinker
Member

@andrei-stoian-zama should we review, i.e. is the PR ready? Otherwise I'll wait for you to say it's OK to review.

@IceTDrinker
Member

@andrei-stoian-zama is it ready? I see conflicts with main.

Contributor

@agnesLeroy agnesLeroy left a comment

Thanks a lot for the PR @andrei-stoian-zama. Here comes my review 🙂


cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));

streams.synchronize();
Contributor

We had a sync on stream 0 at the line just above, maybe it's enough?

Contributor Author

You're right, this was something I forgot to remove.

uint64_t scratch_cuda_keyswitch_size(uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out,
uint32_t num_lwes) {
return (uint64_t)num_lwes * std::max(lwe_dimension_in, lwe_dimension_out) *
Contributor

Here we could keep lwe_dimension_in, which is the larger one, instead of taking the max.

@@ -1,4 +1,5 @@
#include "device.h"

Contributor

No need for the new line.

keyswitch_negate<Torus><<<grid_negate, threads_negate, 0, stream>>>(
lwe_array_out, lwe_dimension_out + 1, num_samples);
} else {
keyswitch_negate_with_output_indices<Torus>
Contributor

Here as well most probably

});
}

for uses_simple_indices in [false, true] {
Contributor

Maybe this additional bench should be just the one with non-trivial indices, since the case with trivial indexes is handled just above?

Contributor

Also, I see you only added the non-trivial indexes case for latency, but what probably makes more sense is to add it for throughput for the GEMM keyswitch?

Contributor

For latency, it seems you still run the trivial indexes case twice.

@andrei-stoian-zama
Contributor Author

@andrei-stoian-zama is it ready? I see conflicts with main.

Yeah, sorry, it was green on Friday night, now it's not anymore :)

@IceTDrinker
Member

tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 79 at r1 (raw file):

Previously, IceTDrinker wrote…

Is it possible to only partially keyswitch an input?

about this ?

@andrei-stoian-zama
Contributor Author

tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 79 at r1 (raw file):

Previously, IceTDrinker wrote…

Is it possible to only partially keyswitch an input?

about this ?

Sorry, yes; it's possible to call KS with an N-sized LWE batch and only M <= N indices. The keyswitched LWEs can then be written at output indices into an output LWE array of size K >= M.
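The index-driven semantics described above can be sketched as a gather/scatter loop; this is purely illustrative (keyswitching of one ciphertext is replaced by a copy, and one u64 word stands in for a whole LWE), not the tfhe-rs GPU code:

```rust
// Only the M selected inputs are processed, and each result lands at the
// corresponding output position; the other K - M output slots are untouched.
fn keyswitch_with_indices(
    input: &[u64],          // N input "LWEs" (one word each, for illustration)
    output: &mut [u64],     // K output slots, K >= M
    input_indexes: &[usize],
    output_indexes: &[usize],
) {
    assert_eq!(input_indexes.len(), output_indexes.len());
    for (&i, &o) in input_indexes.iter().zip(output_indexes) {
        // Placeholder for the actual keyswitch of one ciphertext.
        output[o] = input[i];
    }
}

fn main() {
    let input = vec![10u64, 20, 30, 40]; // N = 4
    let mut output = vec![0u64; 5];      // K = 5
    keyswitch_with_indices(&input, &mut output, &[1, 3], &[0, 4]); // M = 2
    assert_eq!(output, vec![20, 0, 0, 0, 40]);
    println!("{output:?}");
}
```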

@IceTDrinker
Member

tfhe/src/core_crypto/gpu/algorithms/lwe_keyswitch.rs line 79 at r1 (raw file):

Previously, andrei-stoian-zama (Andrei Stoian) wrote…

Sorry, yes; it's possible to call KS with a N-sized LWE batch and only M<=N indices. The keyswitched LWEs can be written at output indices in an output LWE array of size K>=M

understood 👍
