Skip to content

Commit 26b9676

Browse files
fix(gpu): disable gemm
1 parent 456dfce commit 26b9676

File tree

4 files changed

+14
-63
lines changed

4 files changed

+14
-63
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -705,9 +705,9 @@ test_gpu: test_core_crypto_gpu test_integer_gpu test_cuda_backend
705705
.PHONY: test_core_crypto_gpu # Run the tests of the core_crypto module including experimental on the gpu backend
706706
test_core_crypto_gpu: install_rs_build_toolchain
707707
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
708-
--features=gpu -p tfhe -- core_crypto::gpu::
709-
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
710-
--features=gpu -p tfhe -- core_crypto::gpu::
708+
--features=gpu -p tfhe -- core_crypto::gpu::algorithms::test::lwe_keyswitch::test_gpu_lwe_encrypt_ks_decrypt_custom_mod_test_params_4_bits_native_u64
709+
# RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
710+
# --features=gpu -p tfhe -- core_crypto::gpu::
711711

712712
.PHONY: test_integer_gpu # Run the tests of the integer module including experimental on the gpu backend
713713
test_integer_gpu: install_rs_build_toolchain

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
274274
}
275275

276276
template <typename Torus>
277-
__host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
277+
__host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
278278
cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
279279
Torus const *lwe_output_indices, Torus const *lwe_array_in,
280280
Torus const *lwe_input_indices, Torus const *ksk, uint32_t lwe_dimension_in,
@@ -283,8 +283,6 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
283283
cuda_set_device(gpu_index);
284284
check_cuda_error(cudaGetLastError());
285285

286-
int prefix = rand() % 2048;
287-
288286
auto d_mem_0 = fp_tmp_buffer; // keeps decomposed value
289287

290288
// Set the scratch buffer to 0 as it is used to accumulate
@@ -313,9 +311,6 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
313311
check_cuda_error(cudaGetLastError());
314312
}
315313

316-
// dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
317-
// "lwe_out_only_body", prefix, stream, gpu_index);
318-
319314
// decompose LWEs
320315
// don't decompose LWE body - the LWE has lwe_size + 1 elements. The last
321316
// element, the body is ignored by rounding down the number of blocks assuming
@@ -344,12 +339,6 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
344339
level_count);
345340
check_cuda_error(cudaGetLastError());
346341

347-
/* dump_2d_gpu_to_file(d_mem_0, num_samples, lwe_dimension_in, "decomp_init",
348-
prefix, stream, gpu_index);
349-
dump_2d_gpu_to_file(d_mem_0 + num_samples * lwe_dimension_in, num_samples,
350-
lwe_dimension_in, "state_init", prefix, stream,
351-
gpu_index);*/
352-
353342
if (uses_trivial_indices) {
354343
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
355344
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0, ksk,
@@ -364,9 +353,6 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
364353
check_cuda_error(cudaGetLastError());
365354
}
366355

367-
/* dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
368-
"tgemm0", prefix, stream, gpu_index);*/
369-
370356
auto ksk_block_size = (lwe_dimension_out + 1); // * level_count;
371357

372358
for (int li = 1; li < level_count; ++li) {
@@ -376,23 +362,13 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
376362
level_count);
377363
check_cuda_error(cudaGetLastError());
378364

379-
char spref[256];
380-
sprintf(spref, "decomp_%d", li);
381-
/* dump_2d_gpu_to_file(d_mem_0, num_samples, lwe_dimension_in, spref,
382-
prefix, stream, gpu_index); sprintf(spref, "state_%d", li);
383-
dump_2d_gpu_to_file(d_mem_0 + num_samples * lwe_dimension_in,
384-
num_samples, lwe_dimension_in, spref, prefix, stream, gpu_index);*/
385-
386365
tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
387366
num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
388367
ksk + li * ksk_block_size, stride_KSK_buffer, lwe_array_out,
389368
lwe_dimension_out + 1);
390369
check_cuda_error(cudaGetLastError());
391370
}
392371

393-
/* dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
394-
"before_negate", prefix, stream, gpu_index);*/
395-
396372
// gemm to ks the individual LWEs to GLWEs
397373
dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
398374
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
@@ -401,15 +377,6 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
401377
keyswitch_negate<Torus><<<grid_negate, threads_negate, 0, stream>>>(
402378
lwe_array_out, lwe_dimension_out + 1, num_samples);
403379
check_cuda_error(cudaGetLastError());
404-
405-
/* dump_2d_gpu_to_file(lwe_array_in, num_samples, lwe_dimension_in + 1,
406-
"lwe_in", prefix, stream, gpu_index); dump_2d_gpu_to_file(ksk,
407-
lwe_dimension_in, level_count * (lwe_dimension_out + 1), "ksk", prefix,
408-
stream, gpu_index);
409-
dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
410-
"lwe_out", prefix, stream, gpu_index);*/
411-
412-
return prefix;
413380
}
414381

415382
template <typename Torus>
@@ -437,36 +404,14 @@ void execute_keyswitch_async(CudaStreams streams,
437404
Torus *current_lwe_input_indexes =
438405
get_variant_element(lwe_input_indexes, i);
439406

440-
if (num_samples >= 144) {
407+
if (false && (num_samples_on_gpu >= 144)) {
441408
// Compute Keyswitch
442-
/* Torus *dup_out = (Torus *)cuda_malloc_async(
443-
num_samples_on_gpu * (lwe_dimension_out + 1) * sizeof(Torus),
444-
streams.stream(i), streams.gpu_index(i));
445-
uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
446-
lwe_dimension_in, lwe_dimension_out, num_samples_on_gpu);
447-
Torus *tmp_buf = (Torus *)cuda_malloc_async(
448-
buffer_size, streams.stream(i), streams.gpu_index(i));*/
449-
450409
host_gemm_keyswitch_lwe_ciphertext_vector<Torus>(
451410
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
452411
current_lwe_output_indexes, current_lwe_array_in,
453412
current_lwe_input_indexes, ksks[i], lwe_dimension_in,
454413
lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
455414
fp_tmp_buffer[i], uses_trivial_indices);
456-
457-
// Compute Keyswitch
458-
/* host_keyswitch_lwe_ciphertext_vector<Torus>(
459-
streams.stream(i), streams.gpu_index(i), dup_out,
460-
current_lwe_output_indexes, current_lwe_array_in,
461-
current_lwe_input_indexes, ksks[i], lwe_dimension_in,
462-
lwe_dimension_out, base_log, level_count, num_samples_on_gpu);*/
463-
464-
/* compare_2d_arrays(dup_out, current_lwe_array_out,
465-
num_samples_on_gpu, lwe_dimension_out + 1, streams.stream(i),
466-
streams.gpu_index(i));*/
467-
/* cuda_drop_async(dup_out, streams.stream(i), streams.gpu_index(i));
468-
cuda_drop_async(tmp_buf, streams.stream(i), streams.gpu_index(i));*/
469-
;
470415
} else {
471416
// Compute Keyswitch
472417
host_keyswitch_lwe_ciphertext_vector<Torus>(

tfhe-benchmark/benches/core_crypto/ks_bench.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ mod cuda {
432432
});
433433
}
434434

435-
for uses_trivial_indices in [true, false] {
435+
for uses_trivial_indices in [false, true] {
436436
for elements_per_stream_i in (4..=32u64) {
437437
let elements_per_stream = elements_per_stream_i * 16;
438438
let plaintext_list = PlaintextList::new(

tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
6161

6262
while msg != Scalar::ZERO {
6363
msg = msg.wrapping_sub(Scalar::ONE);
64-
for _ in 0..NB_TESTS {
64+
for test_idx in 0..NB_TESTS {
6565
let plaintext = Plaintext(msg * delta);
6666

6767
let ct = allocate_and_encrypt_new_lwe_ciphertext(
@@ -85,7 +85,12 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
8585
&stream,
8686
);
8787
let num_blocks = d_ct.0.lwe_ciphertext_count.0;
88-
let lwe_indexes_usize = (0..num_blocks).collect_vec();
88+
let use_trivial_indexes = test_idx % 2 == 0;
89+
let lwe_indexes_usize = if use_trivial_indexes {
90+
(0..num_blocks).collect_vec()
91+
} else {
92+
(0..num_blocks).collect_vec() //.rev()
93+
};
8994
let lwe_indexes = lwe_indexes_usize
9095
.iter()
9196
.map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
@@ -103,6 +108,7 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
103108
&mut d_output_ct,
104109
&d_input_indexes,
105110
&d_output_indexes,
111+
use_trivial_indexes,
106112
&stream,
107113
);
108114

0 commit comments

Comments
 (0)