@@ -5,7 +5,7 @@ use std::{iter, marker::PhantomData};
55use binius_field:: { BinaryField , ExtensionField , Field , TowerField , util:: inner_product_unchecked} ;
66use binius_math:: { ArithCircuit , TowerTop , extrapolate_line_scalar} ;
77use binius_ntt:: AdditiveNTT ;
8- use binius_utils:: checked_arithmetics:: checked_log_2;
8+ use binius_utils:: checked_arithmetics:: { checked_log_2, strict_log_2 } ;
99use bytemuck:: zeroed_vec;
1010use itertools:: izip;
1111
@@ -433,6 +433,55 @@ impl<F: TowerTop> ComputeLayerExecutor<F> for CpuLayerExecutor<F> {
433433
434434 Ok ( ( ) )
435435 }
436+
437+ fn pairwise_product_reduce (
438+ & mut self ,
439+ input : <Self :: DevMem as ComputeMemory < F > >:: FSlice < ' _ > ,
440+ round_outputs : & mut [ <Self :: DevMem as ComputeMemory < F > >:: FSliceMut < ' _ > ] ,
441+ ) -> Result < ( ) , Error > {
442+ let log_num_inputs = match strict_log_2 ( input. len ( ) ) {
443+ None => {
444+ return Err ( Error :: InputValidation ( format ! (
445+ "input length must be a power of 2: {}" ,
446+ input. len( )
447+ ) ) ) ;
448+ }
449+ Some ( 0 ) => {
450+ return Err ( Error :: InputValidation ( format ! (
451+ "input length must be greater than or equal to 2 in order to perform at least one reduction: {}" ,
452+ input. len( )
453+ ) ) ) ;
454+ }
455+ Some ( log_num_inputs) => log_num_inputs,
456+ } ;
457+ let expected_round_outputs_len = log_num_inputs;
458+ if round_outputs. len ( ) != expected_round_outputs_len as usize {
459+ return Err ( Error :: InputValidation ( format ! (
460+ "round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}" ,
461+ round_outputs. len( )
462+ ) ) ) ;
463+ }
464+ for ( round_idx, round_output_data) in round_outputs. iter ( ) . enumerate ( ) {
465+ let expected_output_size = 1usize << ( log_num_inputs as usize - round_idx - 1 ) ;
466+ if round_output_data. len ( ) != expected_output_size {
467+ return Err ( Error :: InputValidation ( format ! (
468+ "round_outputs[{}].len() = {}, expected {expected_output_size}" ,
469+ round_idx,
470+ round_output_data. len( )
471+ ) ) ) ;
472+ }
473+ }
474+ let mut round_data_source = input;
475+ for round_output_data in round_outputs. iter_mut ( ) {
476+ for idx in 0 ..round_output_data. len ( ) {
477+ round_output_data[ idx] =
478+ round_data_source[ idx * 2 ] * round_data_source[ idx * 2 + 1 ] ;
479+ }
480+ round_data_source = round_output_data
481+ }
482+
483+ Ok ( ( ) )
484+ }
436485}
437486
438487#[ derive( Debug ) ]
0 commit comments