@@ -35,7 +35,8 @@ device_accumulate_all_blocks(Torus *output, Torus const *input_block,
3535
3636template <typename Torus>
3737__host__ void accumulate_all_blocks (cudaStream_t stream, uint32_t gpu_index,
38- Torus *output, Torus const *input,
38+ CudaRadixCiphertextFFI *output,
39+ CudaRadixCiphertextFFI const *input,
3940 uint32_t lwe_dimension,
4041 uint32_t num_radix_blocks) {
4142
@@ -45,7 +46,8 @@ __host__ void accumulate_all_blocks(cudaStream_t stream, uint32_t gpu_index,
4546 getNumBlocksAndThreads (num_entries, 512 , num_blocks, num_threads);
4647 // Add all blocks and store in sum
4748 device_accumulate_all_blocks<Torus><<<num_blocks, num_threads, 0 , stream>>> (
48- output, input, lwe_dimension, num_radix_blocks);
49+ (Torus *)output->ptr , (Torus const *)input->ptr , lwe_dimension,
50+ num_radix_blocks);
4951 check_cuda_error (cudaGetLastError ());
5052}
5153
@@ -102,23 +104,33 @@ __host__ void are_all_comparisons_block_true(
102104
103105 // Since all blocks encrypt either 0 or 1, we can sum max_value of them
104106 // as in the worst case we will be adding `max_value` ones
105- auto input_blocks = (Torus *)tmp_out->ptr ;
106- auto accumulator_ptr =
107- (Torus *)are_all_block_true_buffer->tmp_block_accumulated ->ptr ;
108107 auto is_max_value_lut = are_all_block_true_buffer->is_max_value ;
108+ GPU_ASSERT (are_all_block_true_buffer->tmp_block_accumulated ->lwe_dimension ==
109+ big_lwe_dimension,
110+ " lwe_dimension mismatch between tmp_block_accumulated and "
111+ " big_lwe_dimension" );
112+ GPU_ASSERT (tmp_out->lwe_dimension == big_lwe_dimension,
113+ " lwe_dimension mismatch between tmp_out and big_lwe_dimension" );
109114 uint32_t chunk_lengths[num_chunks];
110115 auto begin_remaining_blocks = remaining_blocks;
116+ uint32_t acc_offset = 0 , inp_offset = 0 ;
111117 for (int i = 0 ; i < num_chunks; i++) {
112118 uint32_t chunk_length =
113119 std::min (max_value, begin_remaining_blocks - i * max_value);
114120 chunk_lengths[i] = chunk_length;
121+ CudaRadixCiphertextFFI acc_slice, inp_slice;
122+ as_radix_ciphertext_slice<Torus>(
123+ &acc_slice, are_all_block_true_buffer->tmp_block_accumulated ,
124+ acc_offset, acc_offset + 1 );
125+ as_radix_ciphertext_slice<Torus>(&inp_slice, tmp_out, inp_offset,
126+ inp_offset + chunk_length);
115127 accumulate_all_blocks<Torus>(streams.stream (0 ), streams.gpu_index (0 ),
116- accumulator_ptr, input_blocks ,
117- big_lwe_dimension, chunk_length);
128+ &acc_slice, &inp_slice, big_lwe_dimension ,
129+ chunk_length);
118130
119- accumulator_ptr += (big_lwe_dimension + 1 ) ;
131+ acc_offset += 1 ;
120132 remaining_blocks -= (chunk_length - 1 );
121- input_blocks += (big_lwe_dimension + 1 ) * chunk_length;
133+ inp_offset += chunk_length;
122134 }
123135 auto accumulator = are_all_block_true_buffer->tmp_block_accumulated ;
124136
@@ -219,21 +231,31 @@ __host__ void is_at_least_one_comparisons_block_true(
219231
220232 // Since all blocks encrypt either 0 or 1, we can sum max_value of them
221233 // as in the worst case we will be adding `max_value` ones
222- auto input_blocks = (Torus *)mem_ptr->tmp_lwe_array_out ->ptr ;
223- auto accumulator = (Torus *)buffer->tmp_block_accumulated ->ptr ;
234+ GPU_ASSERT (buffer->tmp_block_accumulated ->lwe_dimension == big_lwe_dimension,
235+ " lwe_dimension mismatch between tmp_block_accumulated and "
236+ " big_lwe_dimension" );
237+ GPU_ASSERT (mem_ptr->tmp_lwe_array_out ->lwe_dimension == big_lwe_dimension,
238+ " lwe_dimension mismatch between tmp_lwe_array_out and "
239+ " big_lwe_dimension" );
224240 uint32_t chunk_lengths[num_chunks];
225241 auto begin_remaining_blocks = remaining_blocks;
242+ uint32_t acc_offset = 0 , inp_offset = 0 ;
226243 for (int i = 0 ; i < num_chunks; i++) {
227244 uint32_t chunk_length =
228245 std::min (max_value, begin_remaining_blocks - i * max_value);
229246 chunk_lengths[i] = chunk_length;
247+ CudaRadixCiphertextFFI acc_slice, inp_slice;
248+ as_radix_ciphertext_slice<Torus>(&acc_slice, buffer->tmp_block_accumulated ,
249+ acc_offset, acc_offset + 1 );
250+ as_radix_ciphertext_slice<Torus>(&inp_slice, mem_ptr->tmp_lwe_array_out ,
251+ inp_offset, inp_offset + chunk_length);
230252 accumulate_all_blocks<Torus>(streams.stream (0 ), streams.gpu_index (0 ),
231- accumulator, input_blocks , big_lwe_dimension,
253+ &acc_slice, &inp_slice , big_lwe_dimension,
232254 chunk_length);
233255
234- accumulator += (big_lwe_dimension + 1 ) ;
256+ acc_offset += 1 ;
235257 remaining_blocks -= (chunk_length - 1 );
236- input_blocks += (big_lwe_dimension + 1 ) * chunk_length;
258+ inp_offset += chunk_length;
237259 }
238260
239261 // Selects a LUT
@@ -296,22 +318,31 @@ __host__ void host_compare_blocks_with_zero(
296318 streams.stream (0 ), streams.gpu_index (0 ), sum, 0 , 1 , lwe_array_in, 0 , 1 );
297319 num_sum_blocks = 1 ;
298320 } else {
321+ GPU_ASSERT (sum->lwe_dimension == big_lwe_dimension,
322+ " lwe_dimension mismatch between sum and big_lwe_dimension" );
323+ GPU_ASSERT (lwe_array_in->lwe_dimension == big_lwe_dimension,
324+ " lwe_dimension mismatch between lwe_array_in and "
325+ " big_lwe_dimension" );
299326 uint32_t remainder_blocks = num_radix_blocks;
300- auto sum_i = (Torus *)sum->ptr ;
301- auto chunk = (Torus *)lwe_array_in->ptr ;
327+ uint32_t sum_offset = 0 , inp_offset = 0 ;
302328 while (remainder_blocks > 1 ) {
303329 uint32_t chunk_size =
304330 std::min (remainder_blocks, num_elements_to_fill_carry);
305-
331+ CudaRadixCiphertextFFI sum_slice, inp_slice;
332+ as_radix_ciphertext_slice<Torus>(&sum_slice, sum, sum_offset,
333+ sum_offset + 1 );
334+ as_radix_ciphertext_slice<Torus>(&inp_slice, lwe_array_in, inp_offset,
335+ inp_offset + chunk_size);
306336 accumulate_all_blocks<Torus>(streams.stream (0 ), streams.gpu_index (0 ),
307- sum_i, chunk, big_lwe_dimension, chunk_size);
337+ &sum_slice, &inp_slice, big_lwe_dimension,
338+ chunk_size);
308339
309340 num_sum_blocks++;
310341 remainder_blocks -= (chunk_size - 1 );
311342
312343 // Update operands
313- chunk += ( chunk_size - 1 ) * big_lwe_size ;
314- sum_i += big_lwe_size ;
344+ inp_offset += chunk_size - 1 ;
345+ sum_offset += 1 ;
315346 }
316347 }
317348
@@ -381,9 +412,8 @@ compare_radix_blocks(CudaStreams streams, CudaRadixCiphertextFFI *lwe_array_out,
381412
382413 // Subtract
383414 host_subtraction<Torus>(
384- streams.stream (0 ), streams.gpu_index (0 ), (Torus *)lwe_array_out->ptr ,
385- (Torus *)lwe_array_left->ptr , (Torus *)lwe_array_right->ptr ,
386- big_lwe_dimension, num_radix_blocks);
415+ streams.stream (0 ), streams.gpu_index (0 ), lwe_array_out, lwe_array_left,
416+ lwe_array_right, big_lwe_dimension, num_radix_blocks);
387417
388418 // Apply LUT to compare to 0
389419 auto is_non_zero_lut = mem_ptr->eq_buffer ->is_non_zero_lut ;
0 commit comments