Skip to content
This repository was archived by the owner on Sep 9, 2025. It is now read-only.

Commit f1b3aa2

Browse files
SYS-348: add pairwise product reduction HAL op (#803)
This operation produces the partial products for each intermediate claim in the grand product argument.
1 parent 7672509 commit f1b3aa2

6 files changed

Lines changed: 177 additions & 1 deletion

File tree

crates/compute/src/cpu/layer.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::{iter, marker::PhantomData};
55
use binius_field::{BinaryField, ExtensionField, Field, TowerField, util::inner_product_unchecked};
66
use binius_math::{ArithCircuit, TowerTop, extrapolate_line_scalar};
77
use binius_ntt::AdditiveNTT;
8-
use binius_utils::checked_arithmetics::checked_log_2;
8+
use binius_utils::checked_arithmetics::{checked_log_2, strict_log_2};
99
use bytemuck::zeroed_vec;
1010
use itertools::izip;
1111

@@ -433,6 +433,55 @@ impl<F: TowerTop> ComputeLayerExecutor<F> for CpuLayerExecutor<F> {
433433

434434
Ok(())
435435
}
436+
437+
fn pairwise_product_reduce(
438+
&mut self,
439+
input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
440+
round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
441+
) -> Result<(), Error> {
442+
let log_num_inputs = match strict_log_2(input.len()) {
443+
None => {
444+
return Err(Error::InputValidation(format!(
445+
"input length must be a power of 2: {}",
446+
input.len()
447+
)));
448+
}
449+
Some(0) => {
450+
return Err(Error::InputValidation(format!(
451+
"input length must be greater than or equal to 2 in order to perform at least one reduction: {}",
452+
input.len()
453+
)));
454+
}
455+
Some(log_num_inputs) => log_num_inputs,
456+
};
457+
let expected_round_outputs_len = log_num_inputs;
458+
if round_outputs.len() != expected_round_outputs_len as usize {
459+
return Err(Error::InputValidation(format!(
460+
"round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}",
461+
round_outputs.len()
462+
)));
463+
}
464+
for (round_idx, round_output_data) in round_outputs.iter().enumerate() {
465+
let expected_output_size = 1usize << (log_num_inputs as usize - round_idx - 1);
466+
if round_output_data.len() != expected_output_size {
467+
return Err(Error::InputValidation(format!(
468+
"round_outputs[{}].len() = {}, expected {expected_output_size}",
469+
round_idx,
470+
round_output_data.len()
471+
)));
472+
}
473+
}
474+
let mut round_data_source = input;
475+
for round_output_data in round_outputs.iter_mut() {
476+
for idx in 0..round_output_data.len() {
477+
round_output_data[idx] =
478+
round_data_source[idx * 2] * round_data_source[idx * 2 + 1];
479+
}
480+
round_data_source = round_output_data
481+
}
482+
483+
Ok(())
484+
}
436485
}
437486

438487
#[derive(Debug)]

crates/compute/src/layer.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,51 @@ pub trait ComputeLayerExecutor<F: Field> {
462462
output: &mut <Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>,
463463
composition: &Self::ExprEval,
464464
) -> Result<(), Error>;
465+
466+
/// Reduces a slice of elements to a single value by recursively applying pairwise
467+
/// multiplication.
468+
///
469+
/// Given an input slice `x` of length `n = 2^k` for some integer `k`,
470+
/// this function computes the result:
471+
///
472+
/// $$
473+
/// y = \prod_{i=0}^{n-1} x_i
474+
/// $$
475+
///
476+
/// However, instead of a flat left-to-right reduction, the computation proceeds
477+
/// in k = log₂(n) rounds, halving the number of elements each time:
478+
///
479+
/// - Round 0: $$ x_{i,0} = x_{2i} \cdot x_{2i+1} \quad \text{for } i = 0 \ldots \frac{n}{2} - 1
480+
/// $$
481+
///
482+
/// - Round 1: $$ x_{i,1} = x_{2i,0} \cdot x_{2i+1,0} $$
483+
///
484+
/// - ...
485+
///
486+
/// - Final round (round k - 1): $$ y = x_{0,k-1} = \prod_{i=0}^{n-1} x_i $$
487+
///
488+
/// This binary tree-style reduction is mathematically equivalent to the full product,
489+
/// but structured for efficient parallelization.
490+
///
491+
/// ## Arguments
492+
///
493+
/// * `input` - A slice of input field elements provided to the first reduction round
494+
/// * `round_outputs` - A mutable slice of preallocated output field elements for each reduction
495+
/// round. `round_outputs.len()` must equal log₂(input.len()). The length of the FSlice at
496+
/// index i must equal input.len() / 2**(i + 1) for i in 0..round_outputs.len().
497+
///
498+
/// ## Errors
499+
///
500+
/// * Returns an error if the length of `input` is not a power of 2.
501+
/// * Returns an error if the length of `input` is less than 2 (no reductions are possible).
502+
/// * Returns an error if `round_outputs.len()` != log₂(input.len())
503+
/// * Returns an error if any element in `round_outputs` does not satisfy
504+
/// `round_outputs[i].len() == input.len() / 2**(i + 1)` for i in 0..round_outputs.len()
505+
fn pairwise_product_reduce(
506+
&mut self,
507+
input: <Self::DevMem as ComputeMemory<F>>::FSlice<'_>,
508+
round_outputs: &mut [<Self::DevMem as ComputeMemory<F>>::FSliceMut<'_>],
509+
) -> Result<(), Error>;
465510
}
466511

467512
/// An interface for defining execution kernels.

crates/compute/tests/layer.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,21 @@ fn test_map_kernels() {
145145
log_len,
146146
);
147147
}
148+
149+
#[test]
150+
fn test_pairwise_product_reduce_single_round() {
151+
let log_len = 1;
152+
binius_compute_test_utils::layer::test_generic_pairwise_product_reduce(
153+
CpuLayerHolder::<B128>::new(1 << (log_len + 4), 1 << (log_len + 3)),
154+
log_len,
155+
);
156+
}
157+
158+
#[test]
159+
fn test_pairwise_product_reduce() {
160+
let log_len = 8;
161+
binius_compute_test_utils::layer::test_generic_pairwise_product_reduce(
162+
CpuLayerHolder::<B128>::new(1 << (log_len + 4), 1 << (log_len + 3)),
163+
log_len,
164+
);
165+
}

crates/compute_test_utils/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ binius_math = { path = "../math", default-features = false }
1818
binius_ntt = { path = "../ntt", default-features = false }
1919
binius_utils = { path = "../utils", default-features = false }
2020
bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
21+
itertools.workspace = true
2122
rand.workspace = true

crates/compute_test_utils/src/layer.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use binius_math::{
1616
};
1717
use binius_ntt::fri::fold_interleaved;
1818
use binius_utils::checked_arithmetics::checked_log_2;
19+
use itertools::Itertools;
1920
use rand::{Rng, SeedableRng, prelude::StdRng};
2021

2122
pub fn test_generic_single_tensor_expand<
@@ -904,3 +905,56 @@ pub fn test_map_kernels<F, Hal, ComputeHolderType>(
904905
assert_eq!(*output, input_1_host[i] + input_2_host[i]);
905906
}
906907
}
908+
909+
pub fn test_generic_pairwise_product_reduce<F, Hal, ComputeHolderType>(
910+
mut compute_holder: ComputeHolderType,
911+
log_len: usize,
912+
) where
913+
F: Field,
914+
Hal: ComputeLayer<F>,
915+
ComputeHolderType: ComputeHolder<F, Hal>,
916+
{
917+
let mut rng = StdRng::seed_from_u64(0);
918+
919+
let ComputeData {
920+
hal,
921+
host_alloc,
922+
dev_alloc,
923+
..
924+
} = compute_holder.to_data();
925+
926+
let input = host_alloc.alloc(1 << log_len).unwrap();
927+
input.fill_with(|| F::random(&mut rng));
928+
929+
let mut round_outputs = Vec::new();
930+
let mut current_len = input.len() / 2;
931+
while current_len >= 1 {
932+
round_outputs.push(dev_alloc.alloc(current_len).unwrap());
933+
current_len /= 2;
934+
}
935+
936+
let mut dev_input = dev_alloc.alloc(input.len()).unwrap();
937+
hal.copy_h2d(input, &mut dev_input).unwrap();
938+
hal.execute(|exec| {
939+
exec.pairwise_product_reduce(Hal::DevMem::as_const(&dev_input), round_outputs.as_mut())
940+
.unwrap();
941+
Ok(vec![])
942+
})
943+
.unwrap();
944+
let mut working_input = input.to_vec();
945+
let mut round_idx = 0;
946+
while working_input.len() >= 2 {
947+
let mut round_results = (0..working_input.len() / 2).map(|_| F::ZERO).collect_vec();
948+
for idx in 0..round_results.len() {
949+
round_results[idx] = working_input[idx * 2] * working_input[idx * 2 + 1];
950+
}
951+
952+
let actual_round_result = host_alloc.alloc(working_input.len() / 2).unwrap();
953+
hal.copy_d2h(Hal::DevMem::as_const(&round_outputs[round_idx]), actual_round_result)
954+
.unwrap();
955+
assert_eq!(round_results, actual_round_result);
956+
957+
working_input = round_results;
958+
round_idx += 1;
959+
}
960+
}

crates/fast_compute/src/layer.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,15 @@ impl<'a, T: TowerFamily, P: PackedTop<T>> ComputeLayerExecutor<T::B128>
619619

620620
result.into_iter().map(|(_, out)| out).collect()
621621
}
622+
623+
fn pairwise_product_reduce(
624+
&mut self,
625+
_input: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
626+
_round_outputs: &mut [<Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>],
627+
) -> Result<(), Error> {
628+
// TODO(CRY-490)
629+
todo!()
630+
}
622631
}
623632

624633
/// In case when `P1` and `P2` are the same type, this function performs the extrapolation

0 commit comments

Comments
 (0)