Skip to content
This repository was archived by the owner on Sep 9, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion crates/compute/src/cpu/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{iter, marker::PhantomData};
use binius_field::{BinaryField, ExtensionField, Field, TowerField, util::inner_product_unchecked};
use binius_math::{ArithCircuit, TowerTop, extrapolate_line_scalar};
use binius_ntt::AdditiveNTT;
use binius_utils::checked_arithmetics::checked_log_2;
use binius_utils::checked_arithmetics::{checked_log_2, strict_log_2};
use bytemuck::zeroed_vec;
use itertools::izip;

Expand Down Expand Up @@ -433,6 +433,55 @@ impl<F: TowerTop> ComputeLayerExecutor<F> for CpuLayerExecutor<F> {

Ok(())
}

/// Computes a binary-tree product reduction of `input`, storing every intermediate layer.
///
/// `round_outputs[0]` receives the products of adjacent pairs of `input`; each following
/// entry receives the products of adjacent pairs of the previous entry.
///
/// Returns `Error::InputValidation` when:
/// * `input.len()` is not a power of two,
/// * `input.len()` is 1, so no reduction round can be performed,
/// * `round_outputs.len()` differs from log₂(`input.len()`),
/// * `round_outputs[i].len()` differs from `input.len() / 2^(i + 1)` for any `i`.
fn pairwise_product_reduce(
    &mut self,
    input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
    round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
) -> Result<(), Error> {
    // Guard: the input must form a complete binary tree.
    let Some(log_num_inputs) = strict_log_2(input.len()) else {
        return Err(Error::InputValidation(format!(
            "input length must be a power of 2: {}",
            input.len()
        )));
    };
    // Guard: a single element leaves nothing to multiply.
    if log_num_inputs == 0 {
        return Err(Error::InputValidation(format!(
            "input length must be greater than or equal to 2 in order to perform at least one reduction: {}",
            input.len()
        )));
    }

    // Guard: exactly one output buffer per reduction round.
    let expected_round_outputs_len = log_num_inputs;
    if round_outputs.len() != expected_round_outputs_len as usize {
        return Err(Error::InputValidation(format!(
            "round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}",
            round_outputs.len()
        )));
    }

    // Guard: every round halves the element count of the previous layer. Since the input
    // length is a power of two, shifting it right yields the exact layer size.
    for (round_idx, round_output_data) in round_outputs.iter().enumerate() {
        let expected_output_size = input.len() >> (round_idx + 1);
        if round_output_data.len() != expected_output_size {
            return Err(Error::InputValidation(format!(
                "round_outputs[{}].len() = {}, expected {expected_output_size}",
                round_idx,
                round_output_data.len()
            )));
        }
    }

    // Each round reads the previous layer (initially `input`) two elements at a time.
    let mut previous_layer = input;
    for current_layer in round_outputs.iter_mut() {
        for (product, pair) in current_layer
            .iter_mut()
            .zip(previous_layer.chunks_exact(2))
        {
            *product = pair[0] * pair[1];
        }
        previous_layer = current_layer;
    }

    Ok(())
}
}

#[derive(Debug)]
Expand Down
45 changes: 45 additions & 0 deletions crates/compute/src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,51 @@ pub trait ComputeLayerExecutor<F: Field> {
output: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
composition: &Self::ExprEval,
) -> Result<(), Error>;

/// Reduces a slice of elements to a single value by recursively applying pairwise
/// multiplication.
///
/// Given an input slice `x` of length `n = 2^k` for some integer `k >= 1`,
/// this function computes the result:
///
/// $$
/// y = \prod_{i=0}^{n-1} x_i
/// $$
///
/// However, instead of a flat left-to-right reduction, the computation proceeds
/// in `k = log₂(n)` rounds, halving the number of elements each time:
///
/// - Round 0: $$ x_{i,0} = x_{2i} \cdot x_{2i+1} \quad \text{for } i = 0 \ldots \frac{n}{2} - 1
///   $$
///
/// - Round 1: $$ x_{i,1} = x_{2i,0} \cdot x_{2i+1,0} $$
///
/// - ...
///
/// - Final round (round `k - 1`): $$ y = x_{0,k-1} = \prod_{i=0}^{n-1} x_i $$
///
/// This binary tree-style reduction is mathematically equivalent to the full product,
/// but structured for efficient parallelization.
///
/// ## Arguments
///
/// * `input` - A slice of input field elements provided to the first reduction round
/// * `round_outputs` - A mutable slice of preallocated output field elements for each reduction
///   round. `round_outputs.len()` must equal log₂(input.len()). The length of the FSlice at
///   index i must equal input.len() / 2**(i + 1) for i in 0..round_outputs.len(), so the last
///   entry holds the single element `y`.
///
/// ## Throws
///
/// * Returns an error if the length of `input` is not a power of 2.
/// * Returns an error if the length of `input` is less than 2 (no reductions are possible).
/// * Returns an error if `round_outputs.len()` != log₂(input.len())
/// * Returns an error if any element in `round_outputs` does not satisfy
///   `round_outputs[i].len() == input.len() / 2**(i + 1)` for i in 0..round_outputs.len()
fn pairwise_product_reduce(
&mut self,
input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
) -> Result<(), Error>;
}

/// An interface for defining execution kernels.
Expand Down
18 changes: 18 additions & 0 deletions crates/compute/tests/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,21 @@ fn test_map_kernels() {
log_len,
);
}

/// Smallest valid case: a two-element input, which needs exactly one reduction round.
#[test]
fn test_pairwise_product_reduce_single_round() {
    const LOG_LEN: usize = 1;
    let compute_holder = CpuLayerHolder::<B128>::new(1 << (LOG_LEN + 4), 1 << (LOG_LEN + 3));
    binius_compute_test_utils::layer::test_generic_pairwise_product_reduce(
        compute_holder,
        LOG_LEN,
    );
}

/// Multi-round case: a 256-element input reduced over eight rounds.
#[test]
fn test_pairwise_product_reduce() {
    const LOG_LEN: usize = 8;
    let compute_holder = CpuLayerHolder::<B128>::new(1 << (LOG_LEN + 4), 1 << (LOG_LEN + 3));
    binius_compute_test_utils::layer::test_generic_pairwise_product_reduce(
        compute_holder,
        LOG_LEN,
    );
}
1 change: 1 addition & 0 deletions crates/compute_test_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ binius_math = { path = "../math", default-features = false }
binius_ntt = { path = "../ntt", default-features = false }
binius_utils = { path = "../utils", default-features = false }
bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
itertools.workspace = true
rand.workspace = true
54 changes: 54 additions & 0 deletions crates/compute_test_utils/src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use binius_math::{
};
use binius_ntt::fri::fold_interleaved;
use binius_utils::checked_arithmetics::checked_log_2;
use itertools::Itertools;
use rand::{Rng, SeedableRng, prelude::StdRng};

pub fn test_generic_single_tensor_expand<
Expand Down Expand Up @@ -904,3 +905,56 @@ pub fn test_map_kernels<F, Hal, ComputeHolderType>(
assert_eq!(*output, input_1_host[i] + input_2_host[i]);
}
}

/// Checks a compute layer's `pairwise_product_reduce` against a host-side reference.
///
/// A random input of `2^log_len` field elements (fixed RNG seed) is copied to the device and
/// reduced there. Every round's device output is then copied back and compared with the
/// pairwise products computed on the host from the previous layer.
///
/// Panics if any allocation, copy, or the reduction itself fails, or if any round's output
/// differs from the reference.
pub fn test_generic_pairwise_product_reduce<F, Hal, ComputeHolderType>(
    mut compute_holder: ComputeHolderType,
    log_len: usize,
) where
    F: Field,
    Hal: ComputeLayer<F>,
    ComputeHolderType: ComputeHolder<F, Hal>,
{
    let mut rng = StdRng::seed_from_u64(0);

    let ComputeData {
        hal,
        host_alloc,
        dev_alloc,
        ..
    } = compute_holder.to_data();

    // Random host-side input.
    let input = host_alloc.alloc(1 << log_len).unwrap();
    input.fill_with(|| F::random(&mut rng));

    // One device buffer per round; round `i` holds `input.len() / 2^(i + 1)` elements.
    let mut round_outputs = Vec::new();
    for round_idx in 0..log_len {
        round_outputs.push(dev_alloc.alloc(input.len() >> (round_idx + 1)).unwrap());
    }

    // Upload the input and run the reduction on the device.
    let mut dev_input = dev_alloc.alloc(input.len()).unwrap();
    hal.copy_h2d(input, &mut dev_input).unwrap();
    hal.execute(|exec| {
        exec.pairwise_product_reduce(Hal::DevMem::as_const(&dev_input), round_outputs.as_mut())
            .unwrap();
        Ok(vec![])
    })
    .unwrap();

    // Recompute each layer on the host and compare it with the device result.
    let mut expected_layer = input.to_vec();
    for round_output in &round_outputs {
        expected_layer = expected_layer
            .chunks_exact(2)
            .map(|pair| pair[0] * pair[1])
            .collect_vec();

        let actual_layer = host_alloc.alloc(expected_layer.len()).unwrap();
        hal.copy_d2h(Hal::DevMem::as_const(round_output), actual_layer)
            .unwrap();
        assert_eq!(expected_layer, actual_layer);
    }
}
9 changes: 9 additions & 0 deletions crates/fast_compute/src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,15 @@ impl<'a, T: TowerFamily, P: PackedTop<T>> ComputeLayerExecutor<T::B128>

result.into_iter().map(|(_, out)| out).collect()
}

/// Pairwise product reduction is not implemented for this executor yet.
///
/// Both arguments are ignored and every call panics via `todo!()`; the missing
/// implementation is tracked as CRY-490.
fn pairwise_product_reduce(
&mut self,
_input: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
_round_outputs: &mut [<Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>],
) -> Result<(), Error> {
// TODO(CRY-490)
todo!()
}
}

/// In case when `P1` and `P2` are the same type, this function performs the extrapolation
Expand Down
Loading