Skip to content

Commit b532f37

Browse files
committed
chore(gpu): add vectorized function for bitand and plug it in the Array API
1 parent e5b39a6 commit b532f37

File tree

12 files changed

+455
-25
lines changed

12 files changed

+455
-25
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ typedef struct {
5656
uint32_t num_radix_blocks;
5757
uint32_t max_num_radix_blocks;
5858
uint32_t lwe_dimension;
59+
uint32_t num_radix_ciphertexts;
5960
} CudaRadixCiphertextFFI;
6061

6162
typedef struct {

backends/tfhe-cuda-backend/cuda/src/integer/bitwise_ops.cuh

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,43 @@ __host__ void host_integer_radix_bitop_kb(
2424
lwe_array_out->num_radix_blocks == lwe_array_2->num_radix_blocks,
2525
"Cuda error: input and output num radix blocks must be equal");
2626

27+
PANIC_IF_FALSE(
28+
lwe_array_out->num_radix_ciphertexts ==
29+
lwe_array_1->num_radix_ciphertexts &&
30+
lwe_array_out->num_radix_ciphertexts ==
31+
lwe_array_2->num_radix_ciphertexts,
32+
"Cuda error: input and output num radix ciphertexts must be equal");
33+
2734
PANIC_IF_FALSE(lwe_array_out->lwe_dimension == lwe_array_1->lwe_dimension &&
2835
lwe_array_out->lwe_dimension == lwe_array_2->lwe_dimension,
2936
"Cuda error: input and output lwe dimension must be equal");
3037

3138
auto lut = mem_ptr->lut;
32-
uint64_t degrees[lwe_array_1->num_radix_blocks];
39+
uint64_t degrees[lwe_array_1->num_radix_blocks *
40+
lwe_array_1->num_radix_ciphertexts];
3341
if (mem_ptr->op == BITOP_TYPE::BITAND) {
34-
update_degrees_after_bitand(degrees, lwe_array_1->degrees,
35-
lwe_array_2->degrees,
36-
lwe_array_1->num_radix_blocks);
42+
update_degrees_after_bitand(
43+
degrees, lwe_array_1->degrees, lwe_array_2->degrees,
44+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
3745
} else if (mem_ptr->op == BITOP_TYPE::BITOR) {
38-
update_degrees_after_bitor(degrees, lwe_array_1->degrees,
39-
lwe_array_2->degrees,
40-
lwe_array_1->num_radix_blocks);
46+
update_degrees_after_bitor(
47+
degrees, lwe_array_1->degrees, lwe_array_2->degrees,
48+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
4149
} else if (mem_ptr->op == BITOP_TYPE::BITXOR) {
42-
update_degrees_after_bitxor(degrees, lwe_array_1->degrees,
43-
lwe_array_2->degrees,
44-
lwe_array_1->num_radix_blocks);
50+
update_degrees_after_bitxor(
51+
degrees, lwe_array_1->degrees, lwe_array_2->degrees,
52+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
4553
}
4654

4755
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
48-
streams, lwe_array_out, lwe_array_1, lwe_array_2, bsks, ksks,
49-
ms_noise_reduction_key, lut, lwe_array_out->num_radix_blocks,
56+
streams, lwe_array_out, lwe_array_1, lwe_array_2,
57+
bsks, ksks, ms_noise_reduction_key, lut,
58+
lwe_array_out->num_radix_blocks * lwe_array_out->num_radix_ciphertexts,
5059
lut->params.message_modulus);
5160

5261
memcpy(lwe_array_out->degrees, degrees,
53-
lwe_array_out->num_radix_blocks * sizeof(uint64_t));
62+
lwe_array_out->num_radix_blocks *
63+
lwe_array_out->num_radix_ciphertexts * sizeof(uint64_t));
5464
}
5565

5666
template <typename Torus>

backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,12 @@ __host__ void host_pack_bivariate_blocks(
417417
lwe_array_out->lwe_dimension != lwe_array_2->lwe_dimension)
418418
PANIC("Cuda error: input and output radix ciphertexts should have the same "
419419
"lwe dimension")
420-
if (num_radix_blocks > lwe_array_out->num_radix_blocks ||
421-
num_radix_blocks > lwe_array_1->num_radix_blocks ||
422-
num_radix_blocks > lwe_array_2->num_radix_blocks)
420+
if (num_radix_blocks > lwe_array_out->num_radix_blocks *
421+
lwe_array_out->num_radix_ciphertexts ||
422+
num_radix_blocks >
423+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts ||
424+
num_radix_blocks >
425+
lwe_array_2->num_radix_blocks * lwe_array_2->num_radix_ciphertexts)
423426
PANIC("Cuda error: num radix blocks on which packing is applied should be "
424427
"smaller or equal to the number of input & output radix blocks")
425428

@@ -530,7 +533,8 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
530533
if (num_radix_blocks > lut->num_blocks)
531534
PANIC("Cuda error: num radix blocks on which lut is applied should be "
532535
"smaller or equal to the number of lut radix blocks")
533-
if (num_radix_blocks > lwe_array_out->num_radix_blocks)
536+
if (num_radix_blocks >
537+
lwe_array_out->num_radix_blocks * lwe_array_out->num_radix_ciphertexts)
534538
PANIC("Cuda error: num radix blocks on which lut is applied should be "
535539
"smaller or equal to the number of input & output radix blocks")
536540

@@ -756,11 +760,14 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
756760
if (num_radix_blocks > lut->num_blocks)
757761
PANIC("Cuda error: num radix blocks on which lut is applied should be "
758762
"smaller or equal to the number of lut radix blocks")
759-
if (num_radix_blocks > lwe_array_out->num_radix_blocks ||
760-
num_radix_blocks > lwe_array_1->num_radix_blocks ||
761-
num_radix_blocks > lwe_array_2->num_radix_blocks)
763+
if (num_radix_blocks > lwe_array_out->num_radix_blocks *
764+
lwe_array_out->num_radix_ciphertexts ||
765+
num_radix_blocks >
766+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts ||
767+
num_radix_blocks >
768+
lwe_array_2->num_radix_blocks * lwe_array_2->num_radix_ciphertexts)
762769
PANIC("Cuda error: num radix blocks on which lut is applied should be "
763-
"smaller or equal to the number of input & output radix blocks")
770+
"smaller or equal to the number of total input & output radix blocks")
764771

765772
auto params = lut->params;
766773
auto pbs_type = params.pbs_type;

backends/tfhe-cuda-backend/cuda/src/integer/radix_ciphertext.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void into_radix_ciphertext(CudaRadixCiphertextFFI *radix, void *lwe_array,
2525
radix->num_radix_blocks = num_radix_blocks;
2626
radix->max_num_radix_blocks = num_radix_blocks;
2727
radix->ptr = lwe_array;
28+
radix->num_radix_ciphertexts = 1;
2829

2930
radix->degrees = (uint64_t *)(calloc(num_radix_blocks, sizeof(uint64_t)));
3031
radix->noise_levels =

backends/tfhe-cuda-backend/cuda/src/integer/radix_ciphertext.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ void create_zero_radix_ciphertext_async(cudaStream_t const stream,
1919
radix->lwe_dimension = lwe_dimension;
2020
radix->num_radix_blocks = num_radix_blocks;
2121
radix->max_num_radix_blocks = num_radix_blocks;
22+
radix->num_radix_ciphertexts = 1;
2223
uint64_t size = (lwe_dimension + 1) * num_radix_blocks * sizeof(Torus);
2324
radix->ptr = (void *)cuda_malloc_with_size_tracking_async(
2425
size, stream, gpu_index, size_tracker, allocate_gpu_memory);
@@ -63,6 +64,7 @@ void as_radix_ciphertext_slice(CudaRadixCiphertextFFI *output_radix,
6364

6465
auto lwe_size = input_radix->lwe_dimension + 1;
6566
output_radix->num_radix_blocks = end_input_lwe_index - start_input_lwe_index;
67+
output_radix->num_radix_ciphertexts = input_radix->num_radix_ciphertexts;
6668
output_radix->max_num_radix_blocks = input_radix->max_num_radix_blocks;
6769
output_radix->lwe_dimension = input_radix->lwe_dimension;
6870
Torus *in_ptr = (Torus *)input_radix->ptr;

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ pub struct CudaRadixCiphertextFFI {
183183
pub num_radix_blocks: u32,
184184
pub max_num_radix_blocks: u32,
185185
pub lwe_dimension: u32,
186+
pub num_radix_ciphertexts: u32,
186187
}
187188
#[allow(clippy::unnecessary_operation, clippy::identity_op)]
188189
const _: () = {
@@ -201,6 +202,8 @@ const _: () = {
201202
[::std::mem::offset_of!(CudaRadixCiphertextFFI, max_num_radix_blocks) - 28usize];
202203
["Offset of field: CudaRadixCiphertextFFI::lwe_dimension"]
203204
[::std::mem::offset_of!(CudaRadixCiphertextFFI, lwe_dimension) - 32usize];
205+
["Offset of field: CudaRadixCiphertextFFI::num_radix_ciphertexts"]
206+
[::std::mem::offset_of!(CudaRadixCiphertextFFI, num_radix_ciphertexts) - 36usize];
204207
};
205208
#[repr(C)]
206209
#[derive(Debug, Copy, Clone)]

tfhe/src/core_crypto/gpu/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,7 @@ pub unsafe fn add_lwe_ciphertext_vector_async<T: UnsignedInteger>(
840840
num_radix_blocks: num_samples,
841841
max_num_radix_blocks: num_samples,
842842
lwe_dimension: lwe_dimension.0 as u32,
843+
num_radix_ciphertexts: 1u32,
843844
};
844845
let lwe_array_in_1_data = CudaRadixCiphertextFFI {
845846
ptr: lwe_array_in_1.get_mut_c_ptr(0),
@@ -848,6 +849,7 @@ pub unsafe fn add_lwe_ciphertext_vector_async<T: UnsignedInteger>(
848849
num_radix_blocks: num_samples,
849850
max_num_radix_blocks: num_samples,
850851
lwe_dimension: lwe_dimension.0 as u32,
852+
num_radix_ciphertexts: 1u32,
851853
};
852854
let lwe_array_in_2_data = CudaRadixCiphertextFFI {
853855
ptr: lwe_array_in_2.get_mut_c_ptr(0),
@@ -856,6 +858,7 @@ pub unsafe fn add_lwe_ciphertext_vector_async<T: UnsignedInteger>(
856858
num_radix_blocks: num_samples,
857859
max_num_radix_blocks: num_samples,
858860
lwe_dimension: lwe_dimension.0 as u32,
861+
num_radix_ciphertexts: 1u32,
859862
};
860863
cuda_add_lwe_ciphertext_vector_64(
861864
streams.ptr[0],
@@ -890,6 +893,7 @@ pub unsafe fn add_lwe_ciphertext_vector_assign_async<T: UnsignedInteger>(
890893
num_radix_blocks: num_samples,
891894
max_num_radix_blocks: num_samples,
892895
lwe_dimension: lwe_dimension.0 as u32,
896+
num_radix_ciphertexts: 1u32,
893897
};
894898
let lwe_array_in_data = CudaRadixCiphertextFFI {
895899
ptr: lwe_array_in.get_mut_c_ptr(0),
@@ -898,6 +902,7 @@ pub unsafe fn add_lwe_ciphertext_vector_assign_async<T: UnsignedInteger>(
898902
num_radix_blocks: num_samples,
899903
max_num_radix_blocks: num_samples,
900904
lwe_dimension: lwe_dimension.0 as u32,
905+
num_radix_ciphertexts: 1u32,
901906
};
902907
cuda_add_lwe_ciphertext_vector_64(
903908
streams.ptr[0],

tfhe/src/high_level_api/array/gpu/integers.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ use crate::integer::block_decomposition::{
1919
DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
2020
};
2121
use crate::integer::gpu::ciphertext::{
22-
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
22+
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
23+
CudaUnsignedRadixCiphertext,
2324
};
2425
use crate::integer::server_key::radix_parallel::scalar_div_mod::SignedReciprocable;
2526
use crate::integer::server_key::{Reciprocable, ScalarMultiplier};
@@ -83,6 +84,12 @@ impl<'a, T> TensorSlice<'a, GpuSlice<'a, T>> {
8384
pub fn par_iter(self) -> ParStridedIter<'a, T> {
8485
ParStridedIter::new(self.slice.0, self.dims.clone())
8586
}
87+
/// Returns `self.dims.flattened_len()` for this tensor slice.
// NOTE(review): presumably this is the total element count across all dimensions — confirm
// against the definition of `flattened_len`.
pub fn len(&self) -> usize {
88+
self.dims.flattened_len()
89+
}
90+
/// Returns the `&'a [T]` held in field `0` of the wrapped `GpuSlice`, without copying.
pub fn as_slice(&self) -> &'a [T] {
91+
self.slice.0
92+
}
8693
}
8794

8895
impl<'a, T> TensorSlice<'a, GpuSliceMut<'a, T>> {
@@ -316,7 +323,25 @@ where
316323
lhs: TensorSlice<'_, Self::Slice<'a>>,
317324
rhs: TensorSlice<'_, Self::Slice<'a>>,
318325
) -> Self::Owned {
319-
par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::gpu::CudaServerKey::bitand)
326+
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
327+
let streams = &cuda_key.streams;
328+
let num_ciphertexts = lhs.len() as u32;
329+
let lhs_slice: &[T] = lhs.as_slice();
330+
let rhs_slice: &[T] = rhs.as_slice();
331+
let mut lhs_aligned = T::from(CudaRadixCiphertext::from_radix_ciphertext_vec(
332+
lhs_slice, streams,
333+
));
334+
let rhs_aligned = T::from(CudaRadixCiphertext::from_radix_ciphertext_vec(
335+
rhs_slice, streams,
336+
));
337+
crate::integer::gpu::CudaServerKey::bitand_vec(
338+
cuda_key.pbs_key(),
339+
&mut lhs_aligned,
340+
&rhs_aligned,
341+
num_ciphertexts,
342+
streams,
343+
)
344+
}))
320345
}
321346

322347
fn bitor<'a>(

tfhe/src/high_level_api/array/traits.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ impl<'a, T> TensorSlice<'a, &'a [T]> {
2828
pub fn par_iter(self) -> ParStridedIter<'a, T> {
2929
ParStridedIter::new(self.slice, self.dims.clone())
3030
}
31+
/// Returns `self.dims.flattened_len()` for this tensor slice.
// NOTE(review): presumably this is the total element count across all dimensions — confirm
// against the definition of `flattened_len`.
pub fn len(&self) -> usize {
32+
self.dims.flattened_len()
33+
}
34+
/// Returns the underlying `&'a [T]` this tensor slice wraps, without copying.
pub fn as_slice(&self) -> &'a [T] {
35+
self.slice
36+
}
3137
}
3238

3339
impl<'a, T> TensorSlice<'a, &'a mut [T]> {

tfhe/src/integer/gpu/ciphertext/mod.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ pub mod squashed_noise;
77

88
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
99
use crate::core_crypto::gpu::vec::CudaVec;
10-
use crate::core_crypto::gpu::CudaStreams;
10+
use crate::core_crypto::gpu::{CudaLweList, CudaStreams};
1111
use crate::core_crypto::prelude::{LweCiphertextList, LweCiphertextOwned};
1212
use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
1313
use crate::integer::parameters::LweDimension;
1414
use crate::integer::{IntegerCiphertext, RadixCiphertext, SignedRadixCiphertext};
1515
use crate::shortint::{Ciphertext, EncryptionKeyChoice};
1616
use crate::GpuIndex;
1717

18+
use crate::shortint::parameters::LweCiphertextCount;
1819
pub use compressed_noise_squashed_ciphertext_list::*;
1920

2021
pub trait CudaIntegerRadixCiphertext: Sized {
@@ -70,8 +71,68 @@ pub trait CudaIntegerRadixCiphertext: Sized {
7071
fn gpu_indexes(&self) -> &[GpuIndex] {
7172
&self.as_ref().d_blocks.0.d_vec.gpu_indexes
7273
}
74+
75+
// Converts a CudaIntegerRadixCiphertext with num_blocks * num_ciphertexts LWEs into a
76+
// Vec<CudaIntegerRadixCiphertext> of length num_radix_ciphertexts, where each ciphertext has
77+
// num_blocks LWEs
78+
fn to_integer_radix_ciphertext_vec(
79+
&self,
80+
num_radix_ciphertexts: u32,
81+
streams: &CudaStreams,
82+
) -> Vec<Self> {
83+
let total_blocks = self.as_ref().d_blocks.0.lwe_ciphertext_count.0;
84+
assert_eq!(total_blocks % num_radix_ciphertexts as usize, 0, "Total number of blocks ({total_blocks}) is not divisible by number of radix ciphertexts ({num_radix_ciphertexts})");
85+
86+
let num_blocks = total_blocks / num_radix_ciphertexts as usize;
87+
88+
let mut result = Vec::with_capacity(num_radix_ciphertexts as usize);
89+
let lwe_dimension = self.as_ref().d_blocks.lwe_dimension();
90+
91+
for i in 0..num_radix_ciphertexts as usize {
92+
let block_start = i * num_blocks;
93+
let block_end = block_start + num_blocks;
94+
95+
let d_vec = unsafe {
96+
let mut d_vec =
97+
CudaVec::new_async(lwe_dimension.to_lwe_size().0 * num_blocks, streams, 0);
98+
99+
let copy_start = block_start * lwe_dimension.to_lwe_size().0;
100+
let copy_end = block_end * lwe_dimension.to_lwe_size().0;
101+
d_vec.copy_src_range_gpu_to_gpu_async(
102+
copy_start..copy_end,
103+
&self.as_ref().d_blocks.0.d_vec,
104+
streams,
105+
0,
106+
);
107+
108+
streams.synchronize();
109+
d_vec
110+
};
111+
let lwe_list = CudaLweList::<u64> {
112+
d_vec,
113+
lwe_ciphertext_count: LweCiphertextCount(num_blocks),
114+
lwe_dimension,
115+
ciphertext_modulus: self.as_ref().d_blocks.ciphertext_modulus(),
116+
};
117+
118+
// Copy the associated block metadata
119+
let block_info = self.as_ref().info.blocks[block_start..block_end].to_vec();
120+
121+
let info = CudaRadixCiphertextInfo { blocks: block_info };
122+
123+
result.push(Self::from(CudaRadixCiphertext::new(
124+
CudaLweCiphertextList(lwe_list),
125+
info,
126+
)));
127+
}
128+
129+
result
130+
}
73131
}
74132

133+
/// This struct corresponds to the pointers on GPU and
134+
/// metadata representing an array of LWEs corresponding
135+
/// to one or more RadixCiphertexts
75136
pub struct CudaRadixCiphertext {
76137
pub d_blocks: CudaLweCiphertextList<u64>,
77138
pub info: CudaRadixCiphertextInfo,

0 commit comments

Comments
 (0)