Skip to content

Commit cd614a8

Browse files
committed
chore(gpu): kreyvium -> init + next + drop
1 parent 25c9f46 commit cd614a8

8 files changed

Lines changed: 451 additions & 10 deletions

File tree

backends/tfhe-cuda-backend/cuda/include/kreyvium/kreyvium.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ void cuda_kreyvium_generate_keystream_64_async(
2020

2121
void cleanup_cuda_kreyvium_generate_keystream_64(CudaStreamsFFI streams,
2222
int8_t **mem_ptr_void);
23+
24+
void cuda_kreyvium_init_state_64_async(CudaStreamsFFI streams,
25+
const CudaRadixCiphertextFFI *key,
26+
const CudaRadixCiphertextFFI *iv,
27+
int8_t *mem_ptr, void *const *bsks,
28+
void *const *ksks);
29+
30+
void cuda_kreyvium_next_step_64_async(CudaStreamsFFI streams,
31+
CudaRadixCiphertextFFI *keystream_output,
32+
uint32_t num_inputs, uint32_t num_steps,
33+
int8_t *mem_ptr, void *const *bsks,
34+
void *const *ksks);
2335
}
2436

2537
#endif

backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,28 @@ void cleanup_cuda_kreyvium_generate_keystream_64(CudaStreamsFFI streams,
4444
delete mem_ptr;
4545
*mem_ptr_void = nullptr;
4646
}
47+
48+
void cuda_kreyvium_init_state_64_async(CudaStreamsFFI streams,
49+
const CudaRadixCiphertextFFI *key,
50+
const CudaRadixCiphertextFFI *iv,
51+
int8_t *mem_ptr, void *const *bsks,
52+
void *const *ksks) {
53+
54+
auto buffer = (int_kreyvium_buffer<uint64_t> *)mem_ptr;
55+
56+
host_kreyvium_init_state<uint64_t>(CudaStreams(streams), buffer, key, iv,
57+
bsks, (uint64_t *const *)ksks);
58+
}
59+
60+
void cuda_kreyvium_next_step_64_async(CudaStreamsFFI streams,
61+
CudaRadixCiphertextFFI *keystream_output,
62+
uint32_t num_inputs, uint32_t num_steps,
63+
int8_t *mem_ptr, void *const *bsks,
64+
void *const *ksks) {
65+
66+
auto buffer = (int_kreyvium_buffer<uint64_t> *)mem_ptr;
67+
68+
host_kreyvium_next_step<uint64_t>(CudaStreams(streams), buffer,
69+
keystream_output, num_inputs, num_steps,
70+
bsks, (uint64_t *const *)ksks);
71+
}

backends/tfhe-cuda-backend/cuda/src/kreyvium/kreyvium.cuh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,4 +355,35 @@ uint64_t scratch_cuda_kreyvium_encrypt(CudaStreams streams,
355355
return size_tracker;
356356
}
357357

358+
template <typename Torus>
359+
__host__ void
360+
host_kreyvium_init_state(CudaStreams streams, int_kreyvium_buffer<Torus> *mem,
361+
CudaRadixCiphertextFFI const *key_bitsliced,
362+
CudaRadixCiphertextFFI const *iv_bitsliced,
363+
void *const *bsks, uint64_t *const *ksks) {
364+
365+
kreyvium_init(streams, mem, key_bitsliced, iv_bitsliced, bsks, ksks);
366+
}
367+
368+
template <typename Torus>
369+
__host__ void
370+
host_kreyvium_next_step(CudaStreams streams, int_kreyvium_buffer<Torus> *mem,
371+
CudaRadixCiphertextFFI *keystream_output,
372+
uint32_t num_inputs, uint32_t num_steps,
373+
void *const *bsks, uint64_t *const *ksks) {
374+
375+
uint32_t compute_batches = num_steps / KREYVIUM_BATCH_SIZE;
376+
for (uint32_t i = 0; i < compute_batches; i++) {
377+
if (keystream_output != nullptr) {
378+
CudaRadixCiphertextFFI batch_out_slice;
379+
slice_reg_batch_kreyvium<Torus>(&batch_out_slice, keystream_output,
380+
i * KREYVIUM_BATCH_SIZE,
381+
KREYVIUM_BATCH_SIZE, num_inputs);
382+
kreyvium_compute_64_steps(streams, mem, &batch_out_slice, bsks, ksks);
383+
} else {
384+
kreyvium_compute_64_steps(streams, mem, nullptr, bsks, ksks);
385+
}
386+
}
387+
}
388+
358389
#endif

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2706,6 +2706,27 @@ unsafe extern "C" {
27062706
mem_ptr_void: *mut *mut i8,
27072707
);
27082708
}
2709+
unsafe extern "C" {
2710+
pub fn cuda_kreyvium_init_state_64_async(
2711+
streams: CudaStreamsFFI,
2712+
key: *const CudaRadixCiphertextFFI,
2713+
iv: *const CudaRadixCiphertextFFI,
2714+
mem_ptr: *mut i8,
2715+
bsks: *const *mut ffi::c_void,
2716+
ksks: *const *mut ffi::c_void,
2717+
);
2718+
}
2719+
unsafe extern "C" {
2720+
pub fn cuda_kreyvium_next_step_64_async(
2721+
streams: CudaStreamsFFI,
2722+
keystream_output: *mut CudaRadixCiphertextFFI,
2723+
num_inputs: u32,
2724+
num_steps: u32,
2725+
mem_ptr: *mut i8,
2726+
bsks: *const *mut ffi::c_void,
2727+
ksks: *const *mut ffi::c_void,
2728+
);
2729+
}
27092730
pub const KS_TYPE_BIG_TO_SMALL: KS_TYPE = 0;
27102731
pub const KS_TYPE_SMALL_TO_BIG: KS_TYPE = 1;
27112732
pub type KS_TYPE = ffi::c_uint;

tfhe/src/integer/gpu/ffi.rs

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10756,3 +10756,127 @@ pub(crate) unsafe fn cuda_backend_kreyvium_generate_keystream<T: UnsignedInteger
1075610756

1075710757
update_noise_degree(keystream_output, &cuda_ffi_keystream);
1075810758
}
10759+
10760+
#[allow(clippy::too_many_arguments)]
10761+
pub(crate) unsafe fn cuda_backend_kreyvium_init_state<B: Numeric, KST: UnsignedInteger>(
10762+
streams: &CudaStreams,
10763+
key: &CudaRadixCiphertext,
10764+
iv: &CudaRadixCiphertext,
10765+
bootstrapping_key: &CudaVec<B>,
10766+
keyswitch_key: &CudaVec<KST>,
10767+
message_modulus: MessageModulus,
10768+
carry_modulus: CarryModulus,
10769+
glwe_dimension: GlweDimension,
10770+
polynomial_size: PolynomialSize,
10771+
lwe_dimension: LweDimension,
10772+
ks_level: DecompositionLevelCount,
10773+
ks_base_log: DecompositionBaseLog,
10774+
pbs_level: DecompositionLevelCount,
10775+
pbs_base_log: DecompositionBaseLog,
10776+
grouping_factor: LweBskGroupingFactor,
10777+
pbs_type: PBSType,
10778+
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
10779+
) -> *mut i8 {
10780+
let mut key_degrees = key.info.blocks.iter().map(|b| b.degree.0).collect();
10781+
let mut key_noise_levels = key.info.blocks.iter().map(|b| b.noise_level.0).collect();
10782+
let cuda_ffi_key = prepare_cuda_radix_ffi(key, &mut key_degrees, &mut key_noise_levels);
10783+
10784+
let mut iv_degrees = iv.info.blocks.iter().map(|b| b.degree.0).collect();
10785+
let mut iv_noise_levels = iv.info.blocks.iter().map(|b| b.noise_level.0).collect();
10786+
let cuda_ffi_iv = prepare_cuda_radix_ffi(iv, &mut iv_degrees, &mut iv_noise_levels);
10787+
10788+
let num_inputs = u32::try_from(key.info.blocks.len() / 128).unwrap();
10789+
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
10790+
10791+
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
10792+
10793+
scratch_cuda_kreyvium_generate_keystream_64_async(
10794+
streams.ffi(),
10795+
std::ptr::addr_of_mut!(mem_ptr),
10796+
u32::try_from(glwe_dimension.0).unwrap(),
10797+
u32::try_from(polynomial_size.0).unwrap(),
10798+
u32::try_from(lwe_dimension.0).unwrap(),
10799+
u32::try_from(ks_level.0).unwrap(),
10800+
u32::try_from(ks_base_log.0).unwrap(),
10801+
u32::try_from(pbs_level.0).unwrap(),
10802+
u32::try_from(pbs_base_log.0).unwrap(),
10803+
u32::try_from(grouping_factor.0).unwrap(),
10804+
u32::try_from(message_modulus.0).unwrap(),
10805+
u32::try_from(carry_modulus.0).unwrap(),
10806+
pbs_type as u32,
10807+
true,
10808+
noise_reduction_type as u32,
10809+
num_inputs,
10810+
);
10811+
10812+
cuda_kreyvium_init_state_64_async(
10813+
streams.ffi(),
10814+
&raw const cuda_ffi_key,
10815+
&raw const cuda_ffi_iv,
10816+
mem_ptr,
10817+
bootstrapping_key.ptr.as_ptr(),
10818+
keyswitch_key.ptr.as_ptr(),
10819+
);
10820+
10821+
mem_ptr
10822+
}
10823+
10824+
#[allow(clippy::too_many_arguments)]
10825+
pub(crate) unsafe fn cuda_backend_kreyvium_next_step<B: Numeric, KST: UnsignedInteger>(
10826+
streams: &CudaStreams,
10827+
keystream_output: Option<&mut CudaRadixCiphertext>,
10828+
mem_ptr: *mut i8,
10829+
num_inputs: u32,
10830+
num_steps: u32,
10831+
bootstrapping_key: &CudaVec<B>,
10832+
keyswitch_key: &CudaVec<KST>,
10833+
) {
10834+
if let Some(keystream_output) = keystream_output {
10835+
let mut keystream_degrees = keystream_output
10836+
.info
10837+
.blocks
10838+
.iter()
10839+
.map(|b| b.degree.0)
10840+
.collect();
10841+
let mut keystream_noise_levels = keystream_output
10842+
.info
10843+
.blocks
10844+
.iter()
10845+
.map(|b| b.noise_level.0)
10846+
.collect();
10847+
let mut cuda_ffi_keystream = prepare_cuda_radix_ffi(
10848+
keystream_output,
10849+
&mut keystream_degrees,
10850+
&mut keystream_noise_levels,
10851+
);
10852+
10853+
cuda_kreyvium_next_step_64_async(
10854+
streams.ffi(),
10855+
&raw mut cuda_ffi_keystream,
10856+
num_inputs,
10857+
num_steps,
10858+
mem_ptr,
10859+
bootstrapping_key.ptr.as_ptr(),
10860+
keyswitch_key.ptr.as_ptr(),
10861+
);
10862+
10863+
update_noise_degree(keystream_output, &cuda_ffi_keystream);
10864+
} else {
10865+
cuda_kreyvium_next_step_64_async(
10866+
streams.ffi(),
10867+
std::ptr::null_mut(),
10868+
num_inputs,
10869+
num_steps,
10870+
mem_ptr,
10871+
bootstrapping_key.ptr.as_ptr(),
10872+
keyswitch_key.ptr.as_ptr(),
10873+
);
10874+
}
10875+
}
10876+
10877+
pub(crate) unsafe fn cuda_backend_kreyvium_cleanup_state(
10878+
streams: &CudaStreams,
10879+
mut mem_ptr: *mut i8,
10880+
) {
10881+
cleanup_cuda_kreyvium_generate_keystream_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
10882+
}

0 commit comments

Comments
 (0)