diff --git a/crates/compute/src/cpu/layer.rs b/crates/compute/src/cpu/layer.rs index 4c297cb6a..52e399227 100644 --- a/crates/compute/src/cpu/layer.rs +++ b/crates/compute/src/cpu/layer.rs @@ -5,7 +5,7 @@ use std::{iter, marker::PhantomData}; use binius_field::{BinaryField, ExtensionField, Field, TowerField, util::inner_product_unchecked}; use binius_math::{ArithCircuit, TowerTop, extrapolate_line_scalar}; use binius_ntt::AdditiveNTT; -use binius_utils::checked_arithmetics::checked_log_2; +use binius_utils::checked_arithmetics::{checked_log_2, strict_log_2}; use bytemuck::zeroed_vec; use itertools::izip; @@ -433,6 +433,55 @@ impl ComputeLayerExecutor for CpuLayerExecutor { Ok(()) } + + fn pairwise_product_reduce( + &mut self, + input: >::FSlice<'_>, + round_outputs: &mut [>::FSliceMut<'_>], + ) -> Result<(), Error> { + let log_num_inputs = match strict_log_2(input.len()) { + None => { + return Err(Error::InputValidation(format!( + "input length must be a power of 2: {}", + input.len() + ))); + } + Some(0) => { + return Err(Error::InputValidation(format!( + "input length must be greater than or equal to 2 in order to perform at least one reduction: {}", + input.len() + ))); + } + Some(log_num_inputs) => log_num_inputs, + }; + let expected_round_outputs_len = log_num_inputs; + if round_outputs.len() != expected_round_outputs_len as usize { + return Err(Error::InputValidation(format!( + "round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}", + round_outputs.len() + ))); + } + for (round_idx, round_output_data) in round_outputs.iter().enumerate() { + let expected_output_size = 1usize << (log_num_inputs as usize - round_idx - 1); + if round_output_data.len() != expected_output_size { + return Err(Error::InputValidation(format!( + "round_outputs[{}].len() = {}, expected {expected_output_size}", + round_idx, + round_output_data.len() + ))); + } + } + let mut round_data_source = input; + for round_output_data in 
round_outputs.iter_mut() { + for idx in 0..round_output_data.len() { + round_output_data[idx] = + round_data_source[idx * 2] * round_data_source[idx * 2 + 1]; + } + round_data_source = round_output_data + } + + Ok(()) + } } #[derive(Debug)] diff --git a/crates/compute/src/layer.rs b/crates/compute/src/layer.rs index ff6f44b54..5c7c4f9b8 100644 --- a/crates/compute/src/layer.rs +++ b/crates/compute/src/layer.rs @@ -462,6 +462,51 @@ pub trait ComputeLayerExecutor { output: &mut >::FSliceMut<'_>, composition: &Self::ExprEval, ) -> Result<(), Error>; + + /// Reduces a slice of elements to a single value by recursively applying pairwise + /// multiplication. + /// + /// Given an input slice `x` of length `n = 2^k` for some integer `k` ≥ 1, + /// this function computes the result: + /// + /// $$ + /// y = \prod_{i=0}^{n-1} x_i + /// $$ + /// + /// However, instead of a flat left-to-right reduction, the computation proceeds + /// in k = log₂(n) rounds, halving the number of elements each time: + /// + /// - Round 0: $$ x_{i,0} = x_{2i} \cdot x_{2i+1} \quad \text{for } i = 0 \ldots \frac{n}{2} - 1 + /// $$ + /// + /// - Round 1: $$ x_{i,1} = x_{2i,0} \cdot x_{2i+1,0} $$ + /// + /// - ... + /// + /// - Final round: $$ y = x_{0,k-1} = \prod_{i=0}^{n-1} x_i $$ + /// + /// This binary tree-style reduction is mathematically equivalent to the full product, + /// but structured for efficient parallelization. + /// + /// ## Arguments + /// + /// * `input` - A slice of input field elements provided to the first reduction round + /// * `round_outputs` - A mutable slice of preallocated output field elements for each reduction + /// round. `round_outputs.len()` must equal log₂(input.len()). The length of the slice at + /// index i must equal input.len() / 2**(i + 1) for i in 0..round_outputs.len(). + /// + /// ## Errors + /// + /// * Returns an error if the length of `input` is not a power of 2. 
+ /// * Returns an error if the length of `input` is less than 2 (no reductions are possible). + /// * Returns an error if `round_outputs.len()` != log₂(input.len()) + /// * Returns an error if any element in `round_outputs` does not satisfy + /// `round_outputs[i].len() == input.len() / 2**(i + 1)` for i in 0..round_outputs.len() + fn pairwise_product_reduce( + &mut self, + input: >::FSlice<'_>, + round_outputs: &mut [>::FSliceMut<'_>], + ) -> Result<(), Error>; } /// An interface for defining execution kernels. diff --git a/crates/compute/tests/layer.rs b/crates/compute/tests/layer.rs index ad31cedb9..e5de22aa1 100644 --- a/crates/compute/tests/layer.rs +++ b/crates/compute/tests/layer.rs @@ -145,3 +145,21 @@ fn test_map_kernels() { log_len, ); } + +#[test] +fn test_pairwise_product_reduce_single_round() { + let log_len = 1; + binius_compute_test_utils::layer::test_generic_pairwise_product_reduce( + CpuLayerHolder::::new(1 << (log_len + 4), 1 << (log_len + 3)), + log_len, + ); +} + +#[test] +fn test_pairwise_product_reduce() { + let log_len = 8; + binius_compute_test_utils::layer::test_generic_pairwise_product_reduce( + CpuLayerHolder::::new(1 << (log_len + 4), 1 << (log_len + 3)), + log_len, + ); +} diff --git a/crates/compute_test_utils/Cargo.toml b/crates/compute_test_utils/Cargo.toml index 4252a0711..d590223b5 100644 --- a/crates/compute_test_utils/Cargo.toml +++ b/crates/compute_test_utils/Cargo.toml @@ -18,4 +18,5 @@ binius_math = { path = "../math", default-features = false } binius_ntt = { path = "../ntt", default-features = false } binius_utils = { path = "../utils", default-features = false } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } +itertools.workspace = true rand.workspace = true diff --git a/crates/compute_test_utils/src/layer.rs b/crates/compute_test_utils/src/layer.rs index 119a21de4..f4c2ae00b 100644 --- a/crates/compute_test_utils/src/layer.rs +++ b/crates/compute_test_utils/src/layer.rs @@ -16,6 +16,7 @@ use 
binius_math::{ }; use binius_ntt::fri::fold_interleaved; use binius_utils::checked_arithmetics::checked_log_2; +use itertools::Itertools; use rand::{Rng, SeedableRng, prelude::StdRng}; pub fn test_generic_single_tensor_expand< @@ -904,3 +905,56 @@ pub fn test_map_kernels( assert_eq!(*output, input_1_host[i] + input_2_host[i]); } } + +pub fn test_generic_pairwise_product_reduce( + mut compute_holder: ComputeHolderType, + log_len: usize, +) where + F: Field, + Hal: ComputeLayer, + ComputeHolderType: ComputeHolder, +{ + let mut rng = StdRng::seed_from_u64(0); /* fixed seed, so the generated input is the same on every run */ + + let ComputeData { + hal, + host_alloc, + dev_alloc, + .. + } = compute_holder.to_data(); + + let input = host_alloc.alloc(1 << log_len).unwrap(); + input.fill_with(|| F::random(&mut rng)); + + let mut round_outputs = Vec::new(); /* one device buffer per round, with lengths len/2, len/4, ..., 1 */ + let mut current_len = input.len() / 2; + while current_len >= 1 { + round_outputs.push(dev_alloc.alloc(current_len).unwrap()); + current_len /= 2; + } + + let mut dev_input = dev_alloc.alloc(input.len()).unwrap(); + hal.copy_h2d(input, &mut dev_input).unwrap(); + hal.execute(|exec| { + exec.pairwise_product_reduce(Hal::DevMem::as_const(&dev_input), round_outputs.as_mut()) + .unwrap(); + Ok(vec![]) + }) + .unwrap(); + let mut working_input = input.to_vec(); /* host-side reference: recompute each round and compare it with the device output */ + let mut round_idx = 0; + while working_input.len() >= 2 { + let mut round_results = (0..working_input.len() / 2).map(|_| F::ZERO).collect_vec(); + for idx in 0..round_results.len() { + round_results[idx] = working_input[idx * 2] * working_input[idx * 2 + 1]; + } + + let actual_round_result = host_alloc.alloc(working_input.len() / 2).unwrap(); + hal.copy_d2h(Hal::DevMem::as_const(&round_outputs[round_idx]), actual_round_result) + .unwrap(); + assert_eq!(round_results, actual_round_result); + + working_input = round_results; + round_idx += 1; + } +} diff --git a/crates/fast_compute/src/layer.rs b/crates/fast_compute/src/layer.rs index 79eda14ae..bac9a444e 100644 --- a/crates/fast_compute/src/layer.rs +++ 
b/crates/fast_compute/src/layer.rs @@ -619,6 +619,15 @@ impl<'a, T: TowerFamily, P: PackedTop> ComputeLayerExecutor result.into_iter().map(|(_, out)| out).collect() } + + fn pairwise_product_reduce( + &mut self, + _input: >::FSlice<'_>, + _round_outputs: &mut [>::FSliceMut<'_>], + ) -> Result<(), Error> { + // TODO(CRY-490): not implemented yet; calling this method currently panics via `todo!()` + todo!() + } } /// In case when `P1` and `P2` are the same type, this function performs the extrapolation