Skip to content

Commit 05b1c9a

Browse files
committed
feat(hlapi): bind CudaServerKey::contains
1 parent 8d2caa1 commit 05b1c9a

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

tfhe/src/high_level_api/array/mod.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub use traits::{IOwnedArray, Slicing, SlicingMut};
2424
use crate::array::stride::DynDimensions;
2525
use crate::core_crypto::prelude::{Numeric, OverflowingAdd, SignedNumeric, UnsignedNumeric};
2626
use crate::integer::block_decomposition::DecomposableInto;
27+
#[cfg(feature = "gpu")]
28+
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
2729
use crate::integer::RadixCiphertext;
2830
use crate::prelude::{CastFrom, CastInto};
2931
pub use cpu::{
@@ -472,8 +474,25 @@ where
472474
)
473475
}
474476
#[cfg(feature = "gpu")]
475-
InternalServerKey::Cuda(_) => {
476-
panic!("GPU does not support contains() on FheIntegerType yet")
477+
InternalServerKey::Cuda(gpu_key) => {
478+
use crate::high_level_api::details::MaybeCloned;
479+
480+
let streams = &gpu_key.streams;
481+
let tmp_data = data
482+
.iter()
483+
.map(|element| match element.on_gpu(streams) {
484+
MaybeCloned::Borrowed(ct) => ct.duplicate(streams),
485+
MaybeCloned::Cloned(ct) => ct,
486+
})
487+
.collect::<Vec<_>>();
488+
let tmp_value = value.on_gpu(streams);
489+
490+
let result = gpu_key.pbs_key().contains(&tmp_data, &*tmp_value, streams);
491+
FheBool::new(
492+
result,
493+
gpu_key.tag.clone(),
494+
ReRandomizationMetadata::default(),
495+
)
477496
}
478497
#[cfg(feature = "hpu")]
479498
InternalServerKey::Hpu(_) => {

tfhe/src/high_level_api/array/tests/unsigned.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ fn test_contains() {
107107
super::test_case_contains::<crate::FheUint8, u8>(&ck);
108108
}
109109

110+
#[test]
111+
#[cfg(feature = "gpu")]
112+
fn test_contains_gpu() {
113+
for setup_fn in crate::high_level_api::integers::unsigned::tests::gpu::GPU_SETUP_FN {
114+
let ck = setup_fn();
115+
super::test_case_contains::<crate::FheUint8, u8>(&ck);
116+
}
117+
}
118+
110119
#[test]
111120
fn test_single_dimension() {
112121
let config = ConfigBuilder::default().build();

0 commit comments

Comments
 (0)