diff --git a/crates/fast_compute/src/layer.rs b/crates/fast_compute/src/layer.rs index bac9a444e..94aa99bd3 100644 --- a/crates/fast_compute/src/layer.rs +++ b/crates/fast_compute/src/layer.rs @@ -46,12 +46,12 @@ use binius_utils::{ strided_array::StridedArray2DViewMut, }; use bytemuck::{Pod, zeroed_vec}; -use itertools::izip; +use itertools::{Itertools, izip}; use thread_local::ThreadLocal; use crate::{ arith_circuit::ArithCircuitPoly, - memory::{PackedMemory, PackedMemorySliceMut}, + memory::{PackedMemory, PackedMemorySlice, PackedMemorySliceMut}, }; /// Optimized CPU implementation of the compute layer. @@ -176,14 +176,11 @@ impl> ComputeLayer for FastCpuLayer>::FSliceMut<'_>, value: T::B128, ) -> Result<(), Error> { - match slice { - PackedMemorySliceMut::Slice(items) => { - items.fill(P::broadcast(value)); - } - PackedMemorySliceMut::SingleElement { owned, .. } => { - owned.fill(value); - } - }; + let value = P::broadcast(value); + + for element in slice.as_slice_mut() { + *element = value; + } Ok(()) } } @@ -622,11 +619,68 @@ impl<'a, T: TowerFamily, P: PackedTop> ComputeLayerExecutor fn pairwise_product_reduce( &mut self, - _input: >::FSlice<'_>, - _round_outputs: &mut [>::FSliceMut<'_>], + input: >::FSlice<'_>, + round_outputs: &mut [>::FSliceMut<'_>], ) -> Result<(), Error> { - // TODO(CRY-490) - todo!() + let log_num_inputs = match strict_log_2(input.len()) { + None => { + return Err(Error::InputValidation(format!( + "input length must be a power of 2: {}", + input.len() + ))); + } + Some(0) => { + return Err(Error::InputValidation(format!( + "input length must be greater than or equal to 2 in order to perform at least one reduction: {}", + input.len() + ))); + } + Some(log_num_inputs) => log_num_inputs, + }; + let expected_round_outputs_len = log_num_inputs; + if round_outputs.len() != expected_round_outputs_len as usize { + return Err(Error::InputValidation(format!( + "round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}", + round_outputs.len() + ))); + } + for (round_idx, round_output_data) in round_outputs.iter().enumerate() { + let expected_output_size = 1usize << (log_num_inputs as usize - round_idx - 1); + if round_output_data.len() != expected_output_size { + return Err(Error::InputValidation(format!( + "round_outputs[{}].len() = {}, expected {expected_output_size}", + round_idx, + round_output_data.len() + ))); + } + } + + let mut round_data_source = input; + for round_output_data in round_outputs.iter_mut() { + match round_data_source { + PackedMemorySlice::Slice(input) => { + input + .par_chunks(2) + .zip(round_output_data.as_slice_mut().par_iter_mut()) + .for_each(|(chunk, output)| { + let scalar_iter = P::iter_slice(chunk) + .tuples() + .map(|(left, right)| left * right); + *output = P::from_scalars(scalar_iter); + }); + } + PackedMemorySlice::Owned(..) => { + let scalar_iter = P::iter_slice(round_data_source.as_slice()) + .tuples() + .map(|(left, right)| left * right); + + round_output_data.as_slice_mut()[0] = P::from_scalars(scalar_iter); + } + } + round_data_source = round_output_data.as_const(); + } + + Ok(()) } } diff --git a/crates/fast_compute/src/memory.rs b/crates/fast_compute/src/memory.rs index 7810daec8..36c536169 100644 --- a/crates/fast_compute/src/memory.rs +++ b/crates/fast_compute/src/memory.rs @@ -286,10 +286,6 @@ impl SmallOwnedChunk

{ fn iter_scalars(&self) -> impl Iterator { self.data.iter().take(self.len) } - - pub fn fill(&mut self, value: P::Scalar) { - self.data = P::broadcast(value) - } } /// Memory slice that can be either a borrowed slice or an owned small chunk (with length < diff --git a/crates/fast_compute/tests/layer.rs b/crates/fast_compute/tests/layer.rs index 9c3e10b31..3ff0f3eae 100644 --- a/crates/fast_compute/tests/layer.rs +++ b/crates/fast_compute/tests/layer.rs @@ -167,3 +167,23 @@ fn test_map_kernels() { log_len, ); } + +#[test] +fn test_pairwise_product_reduce_single_round() { + type P = PackedBinaryField4x128b; + let log_len = 1; + binius_compute_test_utils::layer::test_generic_pairwise_product_reduce( + FastCpuLayerHolder::::new(1 << (log_len + 4), 1 << (log_len + 3)), + log_len, + ); +} + +#[test] +fn test_pairwise_product_reduce() { + type P = PackedBinaryField4x128b; + let log_len = 8; + binius_compute_test_utils::layer::test_generic_pairwise_product_reduce( + FastCpuLayerHolder::::new(1 << (log_len + 4), 1 << (log_len + 3)), + log_len, + ); +}