@@ -113,6 +113,26 @@ __global__ void keyswitch_gemm_copy_message(const Torus *lwe_in, Torus *lwe_out,
113113 -lwe_in[lwe_id * (lwe_dimension_in + 1 ) + lwe_dimension_in];
114114}
115115
116+ template <typename Torus>
117+ __global__ void keyswitch_gemm_copy_message_with_indices(
118+     const Torus *__restrict__ lwe_in,
119+     const Torus *__restrict__ lwe_input_indices, Torus *__restrict__ lwe_out,
120+     const Torus *__restrict__ lwe_output_indices,
121+
122+     uint32_t lwe_dimension_in, uint32_t num_lwes, uint32_t lwe_dimension_out) {
123+
124+   uint32_t lwe_id = blockIdx.x * blockDim.x + threadIdx.x;
125+
126+   if (lwe_id >= num_lwes)
127+     return;
128+
129+   uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
130+   uint32_t lwe_out_idx = lwe_output_indices[lwe_id];
131+
132+   // Write the negated body of the input ciphertext (selected by the INPUT
133+   // index) into the body slot of the output ciphertext (selected by the
134+   // OUTPUT index). The original draft had these two indices swapped.
135+   lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
136+       -lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
137+ }
135+
116136// Continue decomposition of an array of Torus elements in place. Supposes
117137// that the array contains already decomposed elements and
118138// computes the new decomposed level in place.
@@ -256,10 +276,10 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
256276template <typename Torus>
257277__host__ int host_gemm_keyswitch_lwe_ciphertext_vector (
258278 cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
259- Torus const *lwe_output_indexes , Torus const *lwe_array_in,
260- Torus const *lwe_input_indexes , Torus const *ksk, uint32_t lwe_dimension_in,
279+ Torus const *lwe_output_indices , Torus const *lwe_array_in,
280+ Torus const *lwe_input_indices , Torus const *ksk, uint32_t lwe_dimension_in,
261281 uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
262- uint32_t num_samples, Torus *fp_tmp_buffer) {
282+ uint32_t num_samples, Torus *fp_tmp_buffer, bool uses_trivial_indices ) {
263283 cuda_set_device (gpu_index);
264284 check_cuda_error (cudaGetLastError ());
265285
@@ -280,10 +300,18 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
280300 // lwe_array_out is num_samples x (lwe_dimension_out + 1). copy the bodies
281301 // lwe_array_in[:,lwe_dimension_in] to lwe_array_out[:,lwe_dimension_out]
282302 // and negate
283- keyswitch_gemm_copy_message<Torus><<<grid_copy, threads_copy, 0 , stream>>> (
284- lwe_array_in, lwe_array_out, lwe_dimension_in, num_samples,
285- lwe_dimension_out);
286- check_cuda_error (cudaGetLastError ());
303+ if (uses_trivial_indices) {
304+ keyswitch_gemm_copy_message<Torus><<<grid_copy, threads_copy, 0 , stream>>> (
305+ lwe_array_in, lwe_array_out, lwe_dimension_in, num_samples,
306+ lwe_dimension_out);
307+ check_cuda_error (cudaGetLastError ());
308+ } else {
309+ keyswitch_gemm_copy_message_with_indices<Torus>
310+ <<<grid_copy, threads_copy, 0 , stream>>> (
311+ lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
312+ lwe_dimension_in, num_samples, lwe_dimension_out);
313+ check_cuda_error (cudaGetLastError ());
314+ }
287315
288316 // dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
289317 // "lwe_out_only_body", prefix, stream, gpu_index);
@@ -322,10 +350,19 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
322350 lwe_dimension_in, "state_init", prefix, stream,
323351 gpu_index);*/
324352
325- tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>> (
326- num_samples, (lwe_dimension_out + 1 ), lwe_dimension_in, d_mem_0, ksk,
327- stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1 );
328- check_cuda_error (cudaGetLastError ());
353+ if (uses_trivial_indices) {
354+ tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>> (
355+ num_samples, (lwe_dimension_out + 1 ), lwe_dimension_in, d_mem_0, ksk,
356+ stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1 );
357+ check_cuda_error (cudaGetLastError ());
358+ } else {
359+ tgemm_with_indices<Torus>
360+ <<<grid_gemm, threads_gemm, shared_mem_size, stream>>> (
361+ num_samples, (lwe_dimension_out + 1 ), lwe_dimension_in, d_mem_0,
362+ lwe_input_indices, ksk, stride_KSK_buffer, lwe_array_out,
363+ lwe_dimension_out + 1 , lwe_output_indices);
364+ check_cuda_error (cudaGetLastError ());
365+ }
329366
330367 /* dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
331368 "tgemm0", prefix, stream, gpu_index);*/
@@ -400,35 +437,35 @@ void execute_keyswitch_async(CudaStreams streams,
400437 Torus *current_lwe_input_indexes =
401438 get_variant_element (lwe_input_indexes, i);
402439
403- if (uses_trivial_indices && num_samples >= 19 ) {
440+ if (num_samples >= 144 ) {
404441 // Compute Keyswitch
405- /* Torus *dup_out = (Torus *)cuda_malloc_async(
406- num_samples_on_gpu * (lwe_dimension_out + 1) * sizeof(Torus),
407- streams.stream(i), streams.gpu_index(i));
408- uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
409- lwe_dimension_in, lwe_dimension_out, num_samples_on_gpu);
410- Torus *tmp_buf = (Torus *)cuda_malloc_async(
411- buffer_size, streams.stream(i), streams.gpu_index(i));*/
442+ /* Torus *dup_out = (Torus *)cuda_malloc_async(
443+ num_samples_on_gpu * (lwe_dimension_out + 1) * sizeof(Torus),
444+ streams.stream(i), streams.gpu_index(i));
445+ uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
446+ lwe_dimension_in, lwe_dimension_out, num_samples_on_gpu);
447+ Torus *tmp_buf = (Torus *)cuda_malloc_async(
448+ buffer_size, streams.stream(i), streams.gpu_index(i));*/
412449
413450 host_gemm_keyswitch_lwe_ciphertext_vector<Torus>(
414451 streams.stream (i), streams.gpu_index (i), current_lwe_array_out,
415452 current_lwe_output_indexes, current_lwe_array_in,
416453 current_lwe_input_indexes, ksks[i], lwe_dimension_in,
417454 lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
418- fp_tmp_buffer[i]);
455+ fp_tmp_buffer[i], uses_trivial_indices );
419456
420457 // Compute Keyswitch
421- /* host_keyswitch_lwe_ciphertext_vector<Torus>(
422- streams.stream(i), streams.gpu_index(i), dup_out,
423- current_lwe_output_indexes, current_lwe_array_in,
424- current_lwe_input_indexes, ksks[i], lwe_dimension_in,
425- lwe_dimension_out, base_log, level_count, num_samples_on_gpu);*/
426-
427- /* compare_2d_arrays(dup_out, current_lwe_array_out, num_samples_on_gpu ,
428- lwe_dimension_out + 1, streams.stream(i),
429- streams.gpu_index(i));*/
430- /* cuda_drop_async(dup_out, streams.stream(i), streams.gpu_index(i));
431- cuda_drop_async(tmp_buf, streams.stream(i), streams.gpu_index(i));*/
458+ /* host_keyswitch_lwe_ciphertext_vector<Torus>(
459+ streams.stream(i), streams.gpu_index(i), dup_out,
460+ current_lwe_output_indexes, current_lwe_array_in,
461+ current_lwe_input_indexes, ksks[i], lwe_dimension_in,
462+ lwe_dimension_out, base_log, level_count, num_samples_on_gpu);*/
463+
464+ /* compare_2d_arrays(dup_out, current_lwe_array_out,
465+ num_samples_on_gpu, lwe_dimension_out + 1, streams.stream(i),
466+ streams.gpu_index(i));*/
467+ /* cuda_drop_async(dup_out, streams.stream(i), streams.gpu_index(i));
468+ cuda_drop_async(tmp_buf, streams.stream(i), streams.gpu_index(i));*/
432469 ;
433470 } else {
434471 // Compute Keyswitch
0 commit comments