|
| 1 | +use benchmark::utilities::{get_bench_type, BenchmarkType}; |
| 2 | +use criterion::{black_box, Criterion, Throughput}; |
| 3 | +use rayon::prelude::*; |
| 4 | +use tfhe::core_crypto::experimental::prelude::*; |
| 5 | +use tfhe::core_crypto::prelude::*; |
| 6 | + |
| 7 | +pub struct ExtendedPBSBenchParameters { |
| 8 | + lwe_dimension: LweDimension, |
| 9 | + glwe_dimension: GlweDimension, |
| 10 | + polynomial_size: PolynomialSize, |
| 11 | + extension_factor: LweBootstrapExtensionFactor, |
| 12 | + lwe_noise_distribution: DynamicDistribution<u64>, |
| 13 | + glwe_noise_distribution: DynamicDistribution<u64>, |
| 14 | + pbs_base_log: DecompositionBaseLog, |
| 15 | + pbs_level: DecompositionLevelCount, |
| 16 | + ks_base_log: DecompositionBaseLog, |
| 17 | + ks_level: DecompositionLevelCount, |
| 18 | + message_modulus: CleartextModulus<MessageSpace>, |
| 19 | + carry_modulus: CleartextModulus<CarrySpace>, |
| 20 | + #[allow(dead_code)] |
| 21 | + max_norm2: MaxNorm2, |
| 22 | + #[allow(dead_code)] |
| 23 | + log2_p_fail: f64, |
| 24 | + ciphertext_modulus: CiphertextModulus<u64>, |
| 25 | + encryption_key_choice: EncryptionKeyChoice, |
| 26 | +} |
| 27 | + |
| 28 | +// p-fail = 2^-128.147, algorithmic cost ~ 67456140, 2-norm = 5, extension factor = 16, |
| 29 | +const BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128: ExtendedPBSBenchParameters = |
| 30 | + ExtendedPBSBenchParameters { |
| 31 | + lwe_dimension: LweDimension(884), |
| 32 | + glwe_dimension: GlweDimension(4), |
| 33 | + polynomial_size: PolynomialSize(512), |
| 34 | + extension_factor: LweBootstrapExtensionFactor(16), |
| 35 | + lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev( |
| 36 | + 1.4999005934396873e-06, |
| 37 | + )), |
| 38 | + glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev( |
| 39 | + 2.845267479601915e-15, |
| 40 | + )), |
| 41 | + pbs_base_log: DecompositionBaseLog(23), |
| 42 | + pbs_level: DecompositionLevelCount(1), |
| 43 | + ks_base_log: DecompositionBaseLog(5), |
| 44 | + ks_level: DecompositionLevelCount(3), |
| 45 | + message_modulus: CleartextModulus::new(4), |
| 46 | + carry_modulus: CleartextModulus::new(4), |
| 47 | + max_norm2: MaxNorm2(5f64), |
| 48 | + log2_p_fail: -128.0, |
| 49 | + ciphertext_modulus: CiphertextModulus::new_native(), |
| 50 | + encryption_key_choice: EncryptionKeyChoice::Big, |
| 51 | + }; |
| 52 | + |
| 53 | +const KS_EPBS_BENCH_PARAMS: [(&str, &ExtendedPBSBenchParameters); 1] = [( |
| 54 | + "BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128", |
| 55 | + &BENCH_PARAM_MESSAGE_2_CARRY_2_PARALLEL_PBS_EF_16_2M128, |
| 56 | +)]; |
| 57 | + |
| 58 | +fn get_encoding_with_padding<Scalar: UnsignedInteger>( |
| 59 | + ciphertext_modulus: CiphertextModulus<Scalar>, |
| 60 | +) -> Scalar { |
| 61 | + if ciphertext_modulus.is_native_modulus() { |
| 62 | + Scalar::ONE << (Scalar::BITS - 1) |
| 63 | + } else { |
| 64 | + Scalar::cast_from(ciphertext_modulus.get_custom_modulus() / 2) |
| 65 | + } |
| 66 | +} |
| 67 | + |
| 68 | +fn ks_extended_pbs(criterion: &mut Criterion) { |
| 69 | + let bench_name = "core_crypto::ks_extended_pbs"; |
| 70 | + let mut bench_group = criterion.benchmark_group(bench_name); |
| 71 | + |
| 72 | + // Create the PRNG |
| 73 | + let mut seeder = new_seeder(); |
| 74 | + let seeder = seeder.as_mut(); |
| 75 | + let mut encryption_generator = |
| 76 | + EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder); |
| 77 | + let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed()); |
| 78 | + |
| 79 | + for (name, params) in KS_EPBS_BENCH_PARAMS { |
| 80 | + let ExtendedPBSBenchParameters { |
| 81 | + lwe_dimension, |
| 82 | + glwe_dimension, |
| 83 | + polynomial_size, |
| 84 | + extension_factor, |
| 85 | + lwe_noise_distribution, |
| 86 | + glwe_noise_distribution, |
| 87 | + pbs_base_log, |
| 88 | + pbs_level, |
| 89 | + ks_base_log, |
| 90 | + ks_level, |
| 91 | + message_modulus, |
| 92 | + carry_modulus, |
| 93 | + max_norm2: _, |
| 94 | + log2_p_fail: _, |
| 95 | + ciphertext_modulus, |
| 96 | + encryption_key_choice, |
| 97 | + } = *params; |
| 98 | + |
| 99 | + let plaintext_modulus = message_modulus.0 * carry_modulus.0; |
| 100 | + let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus); |
| 101 | + let delta = encoding_with_padding / plaintext_modulus; |
| 102 | + |
| 103 | + assert!(matches!(encryption_key_choice, EncryptionKeyChoice::Big)); |
| 104 | + |
| 105 | + let lwe_sk = |
| 106 | + allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator); |
| 107 | + |
| 108 | + let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key( |
| 109 | + glwe_dimension, |
| 110 | + polynomial_size, |
| 111 | + &mut secret_generator, |
| 112 | + ); |
| 113 | + let big_lwe_sk = glwe_sk.as_lwe_secret_key(); |
| 114 | + let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key( |
| 115 | + &big_lwe_sk, |
| 116 | + &lwe_sk, |
| 117 | + ks_base_log, |
| 118 | + ks_level, |
| 119 | + lwe_noise_distribution, |
| 120 | + ciphertext_modulus, |
| 121 | + &mut encryption_generator, |
| 122 | + ); |
| 123 | + |
| 124 | + let bsk = allocate_and_generate_new_lwe_bootstrap_key( |
| 125 | + &lwe_sk, |
| 126 | + &glwe_sk, |
| 127 | + pbs_base_log, |
| 128 | + pbs_level, |
| 129 | + glwe_noise_distribution, |
| 130 | + ciphertext_modulus, |
| 131 | + &mut encryption_generator, |
| 132 | + ); |
| 133 | + |
| 134 | + let mut fourier_bsk = FourierLweBootstrapKey::new( |
| 135 | + bsk.input_lwe_dimension(), |
| 136 | + bsk.glwe_size(), |
| 137 | + bsk.polynomial_size(), |
| 138 | + bsk.decomposition_base_log(), |
| 139 | + bsk.decomposition_level_count(), |
| 140 | + ); |
| 141 | + par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fourier_bsk); |
| 142 | + |
| 143 | + let f = |x: u64| x; |
| 144 | + |
| 145 | + let accumulator = generate_programmable_bootstrap_glwe_lut( |
| 146 | + PolynomialSize(polynomial_size.0 * extension_factor.0), |
| 147 | + glwe_dimension.to_glwe_size(), |
| 148 | + plaintext_modulus.cast_into(), |
| 149 | + ciphertext_modulus, |
| 150 | + delta, |
| 151 | + f, |
| 152 | + ); |
| 153 | + |
| 154 | + let fft = Fft::new(fourier_bsk.polynomial_size()); |
| 155 | + let fft = fft.as_view(); |
| 156 | + |
| 157 | + let mut buffers = ComputationBuffers::new(); |
| 158 | + |
| 159 | + // TODO: have req for main thread and for workers ? |
| 160 | + use extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized_requirement as rq; |
| 161 | + |
| 162 | + let requirement = rq::<u64>( |
| 163 | + glwe_dimension.to_glwe_size(), |
| 164 | + polynomial_size, |
| 165 | + extension_factor, |
| 166 | + fft, |
| 167 | + ) |
| 168 | + .unaligned_bytes_required(); |
| 169 | + |
| 170 | + buffers.resize(requirement); |
| 171 | + |
| 172 | + let mut thread_buffers = Vec::with_capacity(extension_factor.0); |
| 173 | + for _ in 0..extension_factor.0 { |
| 174 | + let mut buffer = ComputationBuffers::new(); |
| 175 | + buffer.resize(requirement); |
| 176 | + thread_buffers.push(buffer); |
| 177 | + } |
| 178 | + |
| 179 | + let mut thread_stacks: Vec<_> = thread_buffers.iter_mut().map(|x| x.stack()).collect(); |
| 180 | + |
| 181 | + let bench_id; |
| 182 | + |
| 183 | + match get_bench_type() { |
| 184 | + BenchmarkType::Latency => { |
| 185 | + let ct = allocate_and_encrypt_new_lwe_ciphertext( |
| 186 | + &big_lwe_sk, |
| 187 | + Plaintext(0), |
| 188 | + lwe_noise_distribution, |
| 189 | + ciphertext_modulus, |
| 190 | + &mut encryption_generator, |
| 191 | + ); |
| 192 | + |
| 193 | + let mut ks_buffer = |
| 194 | + LweCiphertext::new(0, lwe_sk.lwe_dimension().to_lwe_size(), ciphertext_modulus); |
| 195 | + |
| 196 | + let mut output_ct = ct.clone(); |
| 197 | + output_ct.as_mut().fill(0); |
| 198 | + |
| 199 | + bench_id = format!("{bench_name}::{name}"); |
| 200 | + bench_group.bench_function(&bench_id, |b| { |
| 201 | + b.iter(|| { |
| 202 | + keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut ks_buffer); |
| 203 | + extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized( |
| 204 | + &fourier_bsk, |
| 205 | + &mut output_ct, |
| 206 | + &ct, |
| 207 | + &accumulator, |
| 208 | + extension_factor, |
| 209 | + fft, |
| 210 | + buffers.stack(), |
| 211 | + &mut thread_stacks, |
| 212 | + ); |
| 213 | + black_box(&mut output_ct); |
| 214 | + }) |
| 215 | + }); |
| 216 | + } |
| 217 | + BenchmarkType::Throughput => { |
| 218 | + bench_id = format!("{bench_name}::throughput::{name}"); |
| 219 | + let mut setup = |batch_size: usize| { |
| 220 | + let inputs = (0..batch_size) |
| 221 | + .map(|_| { |
| 222 | + let ct = allocate_and_encrypt_new_lwe_ciphertext( |
| 223 | + &big_lwe_sk, |
| 224 | + Plaintext(0), |
| 225 | + lwe_noise_distribution, |
| 226 | + ciphertext_modulus, |
| 227 | + &mut encryption_generator, |
| 228 | + ); |
| 229 | + |
| 230 | + let ks_buffer = LweCiphertext::new( |
| 231 | + 0, |
| 232 | + lwe_sk.lwe_dimension().to_lwe_size(), |
| 233 | + ciphertext_modulus, |
| 234 | + ); |
| 235 | + |
| 236 | + let mut output_ct = ct.clone(); |
| 237 | + output_ct.as_mut().fill(0); |
| 238 | + |
| 239 | + let accumulator = generate_programmable_bootstrap_glwe_lut( |
| 240 | + PolynomialSize(polynomial_size.0 * extension_factor.0), |
| 241 | + glwe_dimension.to_glwe_size(), |
| 242 | + plaintext_modulus.cast_into(), |
| 243 | + ciphertext_modulus, |
| 244 | + delta, |
| 245 | + f, |
| 246 | + ); |
| 247 | + |
| 248 | + let fft = Fft::new(fourier_bsk.polynomial_size()); |
| 249 | + let fft = fft.as_view(); |
| 250 | + |
| 251 | + let mut main_thread_buffer = ComputationBuffers::new(); |
| 252 | + |
| 253 | + let requirement = rq::<u64>( |
| 254 | + glwe_dimension.to_glwe_size(), |
| 255 | + polynomial_size, |
| 256 | + extension_factor, |
| 257 | + fft, |
| 258 | + ) |
| 259 | + .unaligned_bytes_required(); |
| 260 | + |
| 261 | + main_thread_buffer.resize(requirement); |
| 262 | + |
| 263 | + let mut thread_buffers = Vec::with_capacity(extension_factor.0); |
| 264 | + for _ in 0..extension_factor.0 { |
| 265 | + let mut buffer = ComputationBuffers::new(); |
| 266 | + buffer.resize(requirement); |
| 267 | + thread_buffers.push(buffer); |
| 268 | + } |
| 269 | + |
| 270 | + ( |
| 271 | + ct, |
| 272 | + ks_buffer, |
| 273 | + output_ct, |
| 274 | + accumulator, |
| 275 | + main_thread_buffer, |
| 276 | + thread_buffers, |
| 277 | + ) |
| 278 | + }) |
| 279 | + .collect::<Vec<_>>(); |
| 280 | + inputs |
| 281 | + }; |
| 282 | + type Res = Vec<( |
| 283 | + LweCiphertext<Vec<u64>>, // Input |
| 284 | + LweCiphertext<Vec<u64>>, // KS result |
| 285 | + LweCiphertext<Vec<u64>>, // PBS result |
| 286 | + GlweCiphertext<Vec<u64>>, // Accumulator |
| 287 | + ComputationBuffers, // Main thread buffer |
| 288 | + Vec<ComputationBuffers>, // Worker thread buffer |
| 289 | + )>; |
| 290 | + let run = |inputs: &mut Res| { |
| 291 | + inputs.par_iter_mut().for_each( |
| 292 | + |( |
| 293 | + ct, |
| 294 | + ks_buffer, |
| 295 | + output_ct, |
| 296 | + accumulator, |
| 297 | + main_thread_buffer, |
| 298 | + thread_buffers, |
| 299 | + )| { |
| 300 | + let mut thread_stacks: Vec<_> = |
| 301 | + thread_buffers.iter_mut().map(|x| x.stack()).collect(); |
| 302 | + keyswitch_lwe_ciphertext(&ksk_big_to_small, ct, ks_buffer); |
| 303 | + extended_programmable_bootstrap_lwe_ciphertext_mem_optimized_parallelized( |
| 304 | + &fourier_bsk, |
| 305 | + output_ct, |
| 306 | + ct, |
| 307 | + accumulator, |
| 308 | + extension_factor, |
| 309 | + fft, |
| 310 | + main_thread_buffer.stack(), |
| 311 | + &mut thread_stacks, |
| 312 | + ); |
| 313 | + black_box(output_ct); |
| 314 | + }, |
| 315 | + ) |
| 316 | + }; |
| 317 | + let elements = { |
| 318 | + use benchmark::find_optimal_batch::find_optimal_batch; |
| 319 | + find_optimal_batch(|inputs, _batch_size| run(inputs), &mut setup) as u64 |
| 320 | + }; |
| 321 | + bench_group.throughput(Throughput::Elements(elements)); |
| 322 | + bench_group.bench_function(&bench_id, |b| { |
| 323 | + b.iter_batched( |
| 324 | + || setup(elements as usize), |
| 325 | + |mut inputs| run(&mut inputs), |
| 326 | + criterion::BatchSize::SmallInput, |
| 327 | + ) |
| 328 | + }); |
| 329 | + } |
| 330 | + }; |
| 331 | + } |
| 332 | +} |
| 333 | + |
| 334 | +pub fn extended_pbs_group() { |
| 335 | + let mut criterion: Criterion<_> = (Criterion::default() |
| 336 | + .sample_size(15) |
| 337 | + .measurement_time(std::time::Duration::from_secs(60))) |
| 338 | + .configure_from_args(); |
| 339 | + ks_extended_pbs(&mut criterion); |
| 340 | +} |
| 341 | + |
| 342 | +fn go_through_cpu_bench_groups() { |
| 343 | + extended_pbs_group(); |
| 344 | +} |
| 345 | + |
| 346 | +fn main() { |
| 347 | + go_through_cpu_bench_groups(); |
| 348 | + |
| 349 | + Criterion::default().configure_from_args().final_summary(); |
| 350 | +} |
0 commit comments