Skip to content

Commit 60145aa

Browse files
chore(gpu): bench KS latency batches
1 parent 4eb4fa9 commit 60145aa

File tree

20 files changed

+772
-101
lines changed

20 files changed

+772
-101
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -705,9 +705,9 @@ test_gpu: test_core_crypto_gpu test_integer_gpu test_cuda_backend
705705
.PHONY: test_core_crypto_gpu # Run the tests of the core_crypto module including experimental on the gpu backend
706706
test_core_crypto_gpu: install_rs_build_toolchain
707707
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
708-
--features=gpu -p tfhe -- core_crypto::gpu::
709-
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
710-
--features=gpu -p tfhe -- core_crypto::gpu::
708+
--features=gpu -p tfhe -- core_crypto::gpu::algorithms::test::lwe_keyswitch::test_gpu_lwe_encrypt_ks_decrypt_custom_mod_test_params_4_bits_native_u64
709+
# RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
710+
# --features=gpu -p tfhe -- core_crypto::gpu::
711711

712712
.PHONY: test_integer_gpu # Run the tests of the integer module including experimental on the gpu backend
713713
test_integer_gpu: install_rs_build_toolchain

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include <stdio.h>
1515

16+
#include "crypto/keyswitch.cuh"
17+
1618
class NoiseLevel {
1719
public:
1820
// Constants equivalent to the Rust code
@@ -336,6 +338,8 @@ struct int_radix_lut_custom_input_output {
336338
std::vector<InputTorus *> lwe_after_ks_vec;
337339
std::vector<OutputTorus *> lwe_after_pbs_vec;
338340
std::vector<InputTorus *> lwe_trivial_indexes_vec;
341+
std::vector<InputTorus *>
342+
ks_tmp_buf_vec; // buffers on each GPU to store keyswitch temporary data
339343
std::vector<InputTorus *> lwe_aligned_vec;
340344

341345
bool gpu_memory_allocated;
@@ -439,6 +443,21 @@ struct int_radix_lut_custom_input_output {
439443
multi_gpu_copy_array_async(active_streams, lwe_trivial_indexes_vec,
440444
lwe_trivial_indexes, num_radix_blocks,
441445
allocate_gpu_memory);
446+
447+
for (auto i = 0; i < active_streams.count(); ++i) {
448+
uint64_t sub_size_tracker = 0;
449+
uint64_t buffer_size = scratch_cuda_keyswitch_size<InputTorus>(
450+
params.small_lwe_dimension, params.big_lwe_dimension,
451+
num_radix_blocks);
452+
auto *gpu_ks_buffer = (InputTorus *)cuda_malloc_with_size_tracking_async(
453+
buffer_size, active_streams.stream(i), active_streams.gpu_index(i),
454+
sub_size_tracker, allocate_gpu_memory);
455+
456+
if (i == 0) {
457+
size_tracker += sub_size_tracker;
458+
}
459+
ks_tmp_buf_vec.push_back(gpu_ks_buffer);
460+
}
442461
}
443462

444463
void setup_mem_reuse(uint32_t num_radix_blocks,
@@ -459,6 +478,8 @@ struct int_radix_lut_custom_input_output {
459478
lwe_after_pbs_vec = base_lut_object->lwe_after_pbs_vec;
460479
lwe_trivial_indexes_vec = base_lut_object->lwe_trivial_indexes_vec;
461480

481+
ks_tmp_buf_vec = base_lut_object->ks_tmp_buf_vec;
482+
462483
mem_reuse = true;
463484
}
464485

@@ -861,6 +882,12 @@ struct int_radix_lut_custom_input_output {
861882
}
862883
lwe_aligned_vec.clear();
863884
}
885+
886+
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
887+
cuda_drop_with_size_tracking_async(
888+
ks_tmp_buf_vec[i], active_streams.stream(i),
889+
active_streams.gpu_index(i), gpu_memory_allocated);
890+
}
864891
}
865892
free(h_lut_indexes);
866893
free(degrees);

backends/tfhe-cuda-backend/cuda/include/integer/rerand_utilities.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ template <typename Torus> struct int_rerand_mem {
1515

1616
bool gpu_memory_allocated;
1717

18+
std::vector<Torus *>
19+
ks_tmp_buf_vec; // buffers on each GPU to store keyswitch temporary data
20+
1821
expand_job<Torus> *d_expand_jobs;
1922
expand_job<Torus> *h_expand_jobs;
2023

@@ -54,6 +57,20 @@ template <typename Torus> struct int_rerand_mem {
5457
num_lwes * sizeof(Torus), streams.stream(0),
5558
streams.gpu_index(0));
5659

60+
for (auto i = 0; i < streams.count(); ++i) {
61+
uint64_t sub_size_tracker = 0;
62+
uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
63+
params.small_lwe_dimension, params.big_lwe_dimension, num_lwes);
64+
auto *gpu_ks_buffer = (Torus *)cuda_malloc_with_size_tracking_async(
65+
buffer_size, streams.stream(i), streams.gpu_index(i),
66+
sub_size_tracker, allocate_gpu_memory);
67+
68+
if (i == 0) {
69+
size_tracker += sub_size_tracker;
70+
}
71+
ks_tmp_buf_vec.push_back(gpu_ks_buffer);
72+
}
73+
5774
streams.synchronize();
5875

5976
free(h_lwe_trivial_indexes);
@@ -69,6 +86,14 @@ template <typename Torus> struct int_rerand_mem {
6986
cuda_drop_with_size_tracking_async(d_expand_jobs, streams.stream(0),
7087
streams.gpu_index(0),
7188
gpu_memory_allocated);
89+
90+
for (auto i = 0; i < ks_tmp_buf_vec.size(); i++) {
91+
cuda_drop_with_size_tracking_async(ks_tmp_buf_vec[i], streams.stream(i),
92+
streams.gpu_index(i),
93+
gpu_memory_allocated);
94+
}
95+
ks_tmp_buf_vec.clear();
96+
7297
cuda_synchronize_stream(streams.stream(0), streams.gpu_index(0));
7398
free(h_expand_jobs);
7499
}

backends/tfhe-cuda-backend/cuda/include/keyswitch/keyswitch.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
1717
void const *lwe_output_indexes, void const *lwe_array_in,
1818
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
1919
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
20-
uint32_t num_samples);
20+
uint32_t num_samples, int8_t *ksk_tmp_buffer, bool uses_trivial_indexes);
2121

2222
uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
2323
void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
2424
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
2525
uint32_t num_lwes, bool allocate_gpu_memory);
2626

27+
/* Creates the scratch object used by the 64-bit LWE keyswitch.
 *
 * - `stream` / `gpu_index`: CUDA stream (passed as an opaque pointer) and the
 *   GPU it belongs to.
 * - `fp_ks_buffer`: out-parameter receiving an opaque handle. The
 *   implementation reinterprets it as a `ks_mem **`; the resulting pointer is
 *   what cuda_keyswitch_lwe_ciphertext_vector_64 expects as its
 *   `ksk_tmp_buffer` argument.
 * - `lwe_dimension_in`, `lwe_dimension_out`, `num_lwes`: forwarded unchanged
 *   to scratch_cuda_keyswitch<uint64_t>.
 * - `allocate_gpu_memory`: forwarded unchanged.
 *
 * Returns whatever scratch_cuda_keyswitch<uint64_t> returns.
 * NOTE(review): presumably the tracked allocation size in bytes - confirm.
 *
 * Release the handle with cleanup_cuda_keyswitch_64.
 */
uint64_t scratch_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
28+
int8_t **fp_ks_buffer,
29+
uint32_t lwe_dimension_in,
30+
uint32_t lwe_dimension_out,
31+
uint32_t num_lwes, bool allocate_gpu_memory);
32+
33+
/* Releases a handle previously produced by scratch_cuda_keyswitch_64.
 *
 * The implementation runs cleanup_cuda_keyswitch<uint64_t> on the object
 * behind `*fp_ks_buffer`, deletes that object, and then sets `*fp_ks_buffer`
 * to nullptr.
 * `allocate_gpu_memory` is forwarded unchanged.
 * NOTE(review): it should presumably match the value given at scratch time -
 * confirm.
 */
void cleanup_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
34+
int8_t **fp_ks_buffer, bool allocate_gpu_memory);
35+
2736
void cuda_packing_keyswitch_lwe_list_to_glwe_64(
2837
void *stream, uint32_t gpu_index, void *glwe_array_out,
2938
void const *lwe_array_in, void const *fp_ksk_array, int8_t *fp_ks_buffer,

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
99
void *stream, uint32_t gpu_index, void *lwe_array_out,
1010
void *lwe_output_indexes, void *lwe_array_in, void *lwe_input_indexes,
1111
void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
12-
uint32_t base_log, uint32_t level_count, uint32_t num_samples) {
13-
host_keyswitch_lwe_ciphertext_vector<uint32_t>(
12+
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
13+
void *ksk_tmp_buffer, bool uses_trivial_indices) {
14+
host_gemm_keyswitch_lwe_ciphertext_vector<uint32_t>(
1415
static_cast<cudaStream_t>(stream), gpu_index,
1516
static_cast<uint32_t *>(lwe_array_out),
1617
static_cast<uint32_t *>(lwe_output_indexes),
1718
static_cast<uint32_t *>(lwe_array_in),
1819
static_cast<uint32_t *>(lwe_input_indexes), static_cast<uint32_t *>(ksk),
19-
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples);
20+
lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples,
21+
static_cast<uint32_t *>(ksk_tmp_buffer), uses_trivial_indices);
2022
}
2123

2224
/* Perform keyswitch on a batch of 64 bits input LWE ciphertexts.
@@ -40,15 +42,16 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
4042
void const *lwe_output_indexes, void const *lwe_array_in,
4143
void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
4244
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
43-
uint32_t num_samples) {
44-
host_keyswitch_lwe_ciphertext_vector<uint64_t>(
45+
uint32_t num_samples, int8_t *ksk_tmp_buffer, bool uses_trivial_indices) {
46+
host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t>(
4547
static_cast<cudaStream_t>(stream), gpu_index,
4648
static_cast<uint64_t *>(lwe_array_out),
4749
static_cast<const uint64_t *>(lwe_output_indexes),
4850
static_cast<const uint64_t *>(lwe_array_in),
4951
static_cast<const uint64_t *>(lwe_input_indexes),
5052
static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
51-
base_log, level_count, num_samples);
53+
base_log, level_count, num_samples,
54+
(uint64_t *)((ks_mem *)ksk_tmp_buffer)->buffer, uses_trivial_indices);
5255
}
5356

5457
uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
@@ -60,6 +63,27 @@ uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
6063
glwe_dimension, polynomial_size, num_lwes, allocate_gpu_memory);
6164
}
6265

66+
uint64_t scratch_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
67+
int8_t **fp_ks_buffer,
68+
uint32_t lwe_dimension_in,
69+
uint32_t lwe_dimension_out,
70+
uint32_t num_lwes,
71+
bool allocate_gpu_memory) {
72+
return scratch_cuda_keyswitch<uint64_t>(
73+
static_cast<cudaStream_t>(stream), gpu_index, (ks_mem **)fp_ks_buffer,
74+
lwe_dimension_in, lwe_dimension_out, num_lwes, allocate_gpu_memory);
75+
}
76+
77+
void cleanup_cuda_keyswitch_64(void *stream, uint32_t gpu_index,
78+
int8_t **fp_ks_buffer,
79+
bool allocate_gpu_memory) {
80+
cleanup_cuda_keyswitch<uint64_t>(static_cast<cudaStream_t>(stream), gpu_index,
81+
(ks_mem *)*fp_ks_buffer,
82+
allocate_gpu_memory);
83+
delete (ks_mem *)*fp_ks_buffer;
84+
*fp_ks_buffer = nullptr;
85+
}
86+
6387
/* Perform functional packing keyswitch on a batch of 64 bits input LWE
6488
* ciphertexts.
6589
*/

0 commit comments

Comments
 (0)