Skip to content

Commit 0fbb809

Browse files
committed
feat(core): bring extended PBS to core_crypto::experimental module
- for now no dedicated types have been created for the the extended bootstrap, meaning an extended BSK is merely seen as a BSK/ExtensionFactor couple
1 parent f8fbc30 commit 0fbb809

11 files changed

Lines changed: 1173 additions & 3 deletions

File tree

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,10 @@ clippy_bench: install_rs_check_toolchain
15451545
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
15461546
--features=shortint,internal-keycache \
15471547
-p tfhe-benchmark -- --no-deps -D warnings
1548+
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
1549+
--features=experimental \
1550+
-p tfhe-benchmark -- --no-deps -D warnings
1551+
15481552

15491553
.PHONY: clippy_bench_gpu # Run clippy lints on tfhe-benchmark
15501554
clippy_bench_gpu: install_rs_check_toolchain

tfhe-benchmark/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ avx512 = ["tfhe/avx512"]
4747
pbs-stats = ["tfhe/pbs-stats"]
4848
zk-pok = ["tfhe/zk-pok"]
4949

50+
# experimental section
51+
experimental = ["tfhe/experimental"]
52+
5053
[[bench]]
5154
name = "boolean"
5255
path = "benches/boolean/bench.rs"
@@ -197,6 +200,12 @@ path = "benches/core_crypto/pbs128_bench.rs"
197200
harness = false
198201
required-features = ["shortint", "internal-keycache"]
199202

203+
[[bench]]
204+
name = "core_crypto-experimental_extended_pbs"
205+
path = "benches/core_crypto/experimental_extended_pbs.rs"
206+
harness = false
207+
required-features = ["experimental"]
208+
200209
[[bin]]
201210
name = "boolean_key_sizes"
202211
path = "src/bin/boolean_key_sizes.rs"
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
use benchmark::utilities::{get_bench_type, BenchmarkType};
2+
use criterion::{black_box, Criterion, Throughput};
3+
use rayon::prelude::*;
4+
use tfhe::core_crypto::experimental::prelude::*;
5+
use tfhe::core_crypto::prelude::*;
6+
7+
pub struct ExtendedPBSBenchParameters {
8+
lwe_dimension: LweDimension,
9+
glwe_dimension: GlweDimension,
10+
polynomial_size: PolynomialSize,
11+
extension_factor: LweBootstrapExtensionFactor,
12+
lwe_noise_distribution: DynamicDistribution<u64>,
13+
glwe_noise_distribution: DynamicDistribution<u64>,
14+
pbs_base_log: DecompositionBaseLog,
15+
pbs_level: DecompositionLevelCount,
16+
ks_base_log: DecompositionBaseLog,
17+
ks_level: DecompositionLevelCount,
18+
message_modulus: CleartextModulus<MessageSpace>,
19+
carry_modulus: CleartextModulus<CarrySpace>,
20+
#[allow(dead_code)]
21+
max_norm2: MaxNorm2,
22+
#[allow(dead_code)]
23+
log2_p_fail: f64,
24+
ciphertext_modulus: CiphertextModulus<u64>,
25+
encryption_key_choice: EncryptionKeyChoice,
26+
}
27+
28+
// p-fail = 2^-128.147, algorithmic cost ~ 67456140, 2-norm = 5, extension factor = 16,
29+
const BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128: ExtendedPBSBenchParameters =
30+
ExtendedPBSBenchParameters {
31+
lwe_dimension: LweDimension(884),
32+
glwe_dimension: GlweDimension(4),
33+
polynomial_size: PolynomialSize(512),
34+
extension_factor: LweBootstrapExtensionFactor(16),
35+
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
36+
1.4999005934396873e-06,
37+
)),
38+
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
39+
2.845267479601915e-15,
40+
)),
41+
pbs_base_log: DecompositionBaseLog(23),
42+
pbs_level: DecompositionLevelCount(1),
43+
ks_base_log: DecompositionBaseLog(5),
44+
ks_level: DecompositionLevelCount(3),
45+
message_modulus: CleartextModulus::new(4),
46+
carry_modulus: CleartextModulus::new(4),
47+
max_norm2: MaxNorm2(5f64),
48+
log2_p_fail: -128.0,
49+
ciphertext_modulus: CiphertextModulus::new_native(),
50+
encryption_key_choice: EncryptionKeyChoice::Big,
51+
};
52+
53+
const KS_EPBS_BENCH_PARAMS: [(&str, &ExtendedPBSBenchParameters); 1] = [(
54+
"BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128",
55+
&BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128,
56+
)];
57+
58+
fn get_encoding_with_padding<Scalar: UnsignedInteger>(
59+
ciphertext_modulus: CiphertextModulus<Scalar>,
60+
) -> Scalar {
61+
if ciphertext_modulus.is_native_modulus() {
62+
Scalar::ONE << (Scalar::BITS - 1)
63+
} else {
64+
Scalar::cast_from(ciphertext_modulus.get_custom_modulus() / 2)
65+
}
66+
}
67+
68+
fn ks_extended_pbs(criterion: &mut Criterion) {
69+
let bench_name = "core_crypto::ks_extended_pbs";
70+
let mut bench_group = criterion.benchmark_group(bench_name);
71+
72+
// Create the PRNG
73+
let mut seeder = new_seeder();
74+
let seeder = seeder.as_mut();
75+
let mut encryption_generator =
76+
EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
77+
let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
78+
79+
for (name, params) in KS_EPBS_BENCH_PARAMS {
80+
let ExtendedPBSBenchParameters {
81+
lwe_dimension,
82+
glwe_dimension,
83+
polynomial_size,
84+
extension_factor,
85+
lwe_noise_distribution,
86+
glwe_noise_distribution,
87+
pbs_base_log,
88+
pbs_level,
89+
ks_base_log,
90+
ks_level,
91+
message_modulus,
92+
carry_modulus,
93+
max_norm2: _,
94+
log2_p_fail: _,
95+
ciphertext_modulus,
96+
encryption_key_choice,
97+
} = *params;
98+
99+
let plaintext_modulus = message_modulus.0 * carry_modulus.0;
100+
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
101+
let delta = encoding_with_padding / plaintext_modulus;
102+
103+
assert!(matches!(encryption_key_choice, EncryptionKeyChoice::Big));
104+
105+
let lwe_sk =
106+
allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
107+
108+
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
109+
glwe_dimension,
110+
polynomial_size,
111+
&mut secret_generator,
112+
);
113+
let big_lwe_sk = glwe_sk.as_lwe_secret_key();
114+
let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
115+
&big_lwe_sk,
116+
&lwe_sk,
117+
ks_base_log,
118+
ks_level,
119+
lwe_noise_distribution,
120+
ciphertext_modulus,
121+
&mut encryption_generator,
122+
);
123+
124+
let bsk = allocate_and_generate_new_lwe_bootstrap_key(
125+
&lwe_sk,
126+
&glwe_sk,
127+
pbs_base_log,
128+
pbs_level,
129+
glwe_noise_distribution,
130+
ciphertext_modulus,
131+
&mut encryption_generator,
132+
);
133+
134+
let mut fourier_bsk = FourierLweBootstrapKey::new(
135+
bsk.input_lwe_dimension(),
136+
bsk.glwe_size(),
137+
bsk.polynomial_size(),
138+
bsk.decomposition_base_log(),
139+
bsk.decomposition_level_count(),
140+
);
141+
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fourier_bsk);
142+
143+
let f = |x: u64| x;
144+
145+
let accumulator = generate_programmable_bootstrap_glwe_lut(
146+
PolynomialSize(polynomial_size.0 * extension_factor.0),
147+
glwe_dimension.to_glwe_size(),
148+
plaintext_modulus.cast_into(),
149+
ciphertext_modulus,
150+
delta,
151+
f,
152+
);
153+
154+
let fft = Fft::new(fourier_bsk.polynomial_size());
155+
let fft = fft.as_view();
156+
157+
let mut buffers = ComputationBuffers::new();
158+
159+
// TODO: have req for main thread and for workers ?
160+
use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq;
161+
162+
let requirement = rq::<u64>(
163+
glwe_dimension.to_glwe_size(),
164+
polynomial_size,
165+
extension_factor,
166+
fft,
167+
)
168+
.unaligned_bytes_required();
169+
170+
buffers.resize(requirement);
171+
172+
let mut thread_buffers = Vec::with_capacity(extension_factor.0);
173+
for _ in 0..extension_factor.0 {
174+
let mut buffer = ComputationBuffers::new();
175+
buffer.resize(requirement);
176+
thread_buffers.push(buffer);
177+
}
178+
179+
let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect();
180+
181+
let bench_id;
182+
183+
match get_bench_type() {
184+
BenchmarkType::Latency => {
185+
let ct = allocate_and_encrypt_new_lwe_ciphertext(
186+
&big_lwe_sk,
187+
Plaintext(0),
188+
lwe_noise_distribution,
189+
ciphertext_modulus,
190+
&mut encryption_generator,
191+
);
192+
193+
let mut ks_buffer =
194+
LweCiphertext::new(0, lwe_sk.lwe_dimension().to_lwe_size(), ciphertext_modulus);
195+
196+
let mut output_ct = ct.clone();
197+
output_ct.as_mut().fill(0);
198+
199+
bench_id = format!("{bench_name}::{name}");
200+
bench_group.bench_function(&bench_id, |b| {
201+
b.iter(|| {
202+
keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut ks_buffer);
203+
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
204+
&fourier_bsk,
205+
&mut output_ct,
206+
&ct,
207+
&accumulator,
208+
extension_factor,
209+
fft,
210+
buffers.stack(),
211+
&mut thread_stacks,
212+
);
213+
black_box(&mut output_ct);
214+
})
215+
});
216+
}
217+
BenchmarkType::Throughput => {
218+
bench_id = format!("{bench_name}::throughput::{name}");
219+
let mut setup = |batch_size: usize| {
220+
let inputs = (0..batch_size)
221+
.map(|_| {
222+
let ct = allocate_and_encrypt_new_lwe_ciphertext(
223+
&big_lwe_sk,
224+
Plaintext(0),
225+
lwe_noise_distribution,
226+
ciphertext_modulus,
227+
&mut encryption_generator,
228+
);
229+
230+
let ks_buffer = LweCiphertext::new(
231+
0,
232+
lwe_sk.lwe_dimension().to_lwe_size(),
233+
ciphertext_modulus,
234+
);
235+
236+
let mut output_ct = ct.clone();
237+
output_ct.as_mut().fill(0);
238+
239+
let accumulator = generate_programmable_bootstrap_glwe_lut(
240+
PolynomialSize(polynomial_size.0 * extension_factor.0),
241+
glwe_dimension.to_glwe_size(),
242+
plaintext_modulus.cast_into(),
243+
ciphertext_modulus,
244+
delta,
245+
f,
246+
);
247+
248+
let fft = Fft::new(fourier_bsk.polynomial_size());
249+
let fft = fft.as_view();
250+
251+
let mut main_thread_buffer = ComputationBuffers::new();
252+
253+
let requirement = rq::<u64>(
254+
glwe_dimension.to_glwe_size(),
255+
polynomial_size,
256+
extension_factor,
257+
fft,
258+
)
259+
.unaligned_bytes_required();
260+
261+
main_thread_buffer.resize(requirement);
262+
263+
let mut thread_buffers = Vec::with_capacity(extension_factor.0);
264+
for _ in 0..extension_factor.0 {
265+
let mut buffer = ComputationBuffers::new();
266+
buffer.resize(requirement);
267+
thread_buffers.push(buffer);
268+
}
269+
270+
(
271+
ct,
272+
ks_buffer,
273+
output_ct,
274+
accumulator,
275+
main_thread_buffer,
276+
thread_buffers,
277+
)
278+
})
279+
.collect::<Vec<_>>();
280+
inputs
281+
};
282+
type Res = Vec<(
283+
LweCiphertext<Vec<u64>>, // Input
284+
LweCiphertext<Vec<u64>>, // KS result
285+
LweCiphertext<Vec<u64>>, // PBS result
286+
GlweCiphertext<Vec<u64>>, // Accumulator
287+
ComputationBuffers, // Main thread buffer
288+
Vec<ComputationBuffers>, // Worker thread buffer
289+
)>;
290+
let run = |inputs: &mut Res| {
291+
inputs.par_iter_mut().for_each(
292+
|(
293+
ct,
294+
ks_buffer,
295+
output_ct,
296+
accumulator,
297+
main_thread_buffer,
298+
thread_buffers,
299+
)| {
300+
let mut thread_stacks: Vec<_> =
301+
thread_buffers.iter_mut().map(|x| x.stack()).collect();
302+
keyswitch_lwe_ciphertext(&ksk_big_to_small, ct, ks_buffer);
303+
extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized(
304+
&fourier_bsk,
305+
output_ct,
306+
ct,
307+
accumulator,
308+
extension_factor,
309+
fft,
310+
main_thread_buffer.stack(),
311+
&mut thread_stacks,
312+
);
313+
black_box(output_ct);
314+
},
315+
)
316+
};
317+
let elements = {
318+
use benchmark::find_optimal_batch::find_optimal_batch;
319+
find_optimal_batch(|inputs, _batch_size| run(inputs), &mut setup) as u64
320+
};
321+
bench_group.throughput(Throughput::Elements(elements));
322+
bench_group.bench_function(&bench_id, |b| {
323+
b.iter_batched(
324+
|| setup(elements as usize),
325+
|mut inputs| run(&mut inputs),
326+
criterion::BatchSize::SmallInput,
327+
)
328+
});
329+
}
330+
};
331+
}
332+
}
333+
334+
pub fn extended_pbs_group() {
335+
let mut criterion: Criterion<_> = (Criterion::default()
336+
.sample_size(15)
337+
.measurement_time(std::time::Duration::from_secs(60)))
338+
.configure_from_args();
339+
ks_extended_pbs(&mut criterion);
340+
}
341+
342+
fn go_through_cpu_bench_groups() {
343+
extended_pbs_group();
344+
}
345+
346+
fn main() {
347+
go_through_cpu_bench_groups();
348+
349+
Criterion::default().configure_from_args().final_summary();
350+
}

tfhe-benchmark/src/utilities.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub use benchmark_spec::{get_bench_type, BenchmarkType};
12
use benchmark_spec::{Backend, BenchmarkSpec, OperandType};
23
use criterion::Criterion;
34
use serde::Serialize;

0 commit comments

Comments
 (0)