Skip to content
This repository was archived by the owner on Sep 9, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 68 additions & 14 deletions crates/fast_compute/src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ use binius_utils::{
strided_array::StridedArray2DViewMut,
};
use bytemuck::{Pod, zeroed_vec};
use itertools::izip;
use itertools::{Itertools, izip};
use thread_local::ThreadLocal;

use crate::{
arith_circuit::ArithCircuitPoly,
memory::{PackedMemory, PackedMemorySliceMut},
memory::{PackedMemory, PackedMemorySlice, PackedMemorySliceMut},
};

/// Optimized CPU implementation of the compute layer.
Expand Down Expand Up @@ -176,14 +176,11 @@ impl<T: TowerFamily, P: PackedTop<T>> ComputeLayer<T::B128> for FastCpuLayer<T,
slice: &mut <Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>,
value: T::B128,
) -> Result<(), Error> {
match slice {
PackedMemorySliceMut::Slice(items) => {
items.fill(P::broadcast(value));
}
PackedMemorySliceMut::SingleElement { owned, .. } => {
owned.fill(value);
}
};
let value = P::broadcast(value);

for element in slice.as_slice_mut() {
*element = value;
}
Comment on lines +179 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!
Can be simplified even further:

slice.as_slice_mut().fill(P::broadcast(value));

Ok(())
}
}
Expand Down Expand Up @@ -622,11 +619,68 @@ impl<'a, T: TowerFamily, P: PackedTop<T>> ComputeLayerExecutor<T::B128>

fn pairwise_product_reduce(
&mut self,
_input: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
_round_outputs: &mut [<Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>],
input: <Self::DevMem as ComputeMemory<T::B128>>::FSlice<'_>,
round_outputs: &mut [<Self::DevMem as ComputeMemory<T::B128>>::FSliceMut<'_>],
) -> Result<(), Error> {
// TODO(CRY-490)
todo!()
let log_num_inputs = match strict_log_2(input.len()) {
None => {
return Err(Error::InputValidation(format!(
"input length must be a power of 2: {}",
input.len()
)));
}
Some(0) => {
return Err(Error::InputValidation(format!(
"input length must be greater than or equal to 2 in order to perform at least one reduction: {}",
input.len()
)));
}
Some(log_num_inputs) => log_num_inputs,
};
let expected_round_outputs_len = log_num_inputs;
if round_outputs.len() != expected_round_outputs_len as usize {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe move the logic verifying the input and output dimensions to a separate helper function to share between the reference and fast implementations?

return Err(Error::InputValidation(format!(
"round_outputs.len() does not match the expected length: {} != {expected_round_outputs_len}",
round_outputs.len()
)));
}
for (round_idx, round_output_data) in round_outputs.iter().enumerate() {
let expected_output_size = 1usize << (log_num_inputs as usize - round_idx - 1);
if round_output_data.len() != expected_output_size {
return Err(Error::InputValidation(format!(
"round_outputs[{}].len() = {}, expected {expected_output_size}",
round_idx,
round_output_data.len()
)));
}
}

let mut round_data_source = input;
for round_output_data in round_outputs.iter_mut() {
match round_data_source {
PackedMemorySlice::Slice(input) => {
input
.par_chunks(2)
.zip(round_output_data.as_slice_mut().par_iter_mut())
.for_each(|(chunk, output)| {
let scalar_iter = P::iter_slice(chunk)
.tuples()
.map(|(left, right)| left * right);
*output = P::from_scalars(scalar_iter);
});
}
Comment on lines +665 to +671
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially it would be faster to re-use the packed multiplication by interleaving values:

let (lhs, rhs) = PackedField::interleave(chunk[0], chunk[1]);
let mults = lhs*rhs;

PackedField::from_scalars(mults.iter().step_by(2).copied(), mults.iter().skip(1).step_by(2).copied())

PackedMemorySlice::Owned(..) => {
let scalar_iter = P::iter_slice(round_data_source.as_slice())
.tuples()
.map(|(left, right)| left * right);

round_output_data.as_slice_mut()[0] = P::from_scalars(scalar_iter);
}
}
round_data_source = round_output_data.as_const();
}

Ok(())
}
}

Expand Down
4 changes: 0 additions & 4 deletions crates/fast_compute/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,6 @@ impl<P: PackedField> SmallOwnedChunk<P> {
/// Returns an iterator over the first `self.len` items produced by `self.data.iter()`.
fn iter_scalars(&self) -> impl Iterator<Item = P::Scalar> {
self.data.iter().take(self.len)
}

/// Replaces `self.data` with `P::broadcast(value)`; `self.len` is left unchanged.
pub fn fill(&mut self, value: P::Scalar) {
self.data = P::broadcast(value)
}
}

/// Memory slice that can be either a borrowed slice or an owned small chunk (with length <
Expand Down
20 changes: 20 additions & 0 deletions crates/fast_compute/tests/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,23 @@ fn test_map_kernels() {
log_len,
);
}

/// Runs the shared generic pairwise-product-reduce test against a
/// `FastCpuLayerHolder` with `log_len = 1` (the "single round" case).
#[test]
fn test_pairwise_product_reduce_single_round() {
    type Packed = PackedBinaryField4x128b;
    let log_len = 1;
    // Holder sizes are derived from `log_len`, exactly as the other layer tests do.
    let holder = FastCpuLayerHolder::<CanonicalTowerFamily, Packed>::new(
        1 << (log_len + 4),
        1 << (log_len + 3),
    );
    binius_compute_test_utils::layer::test_generic_pairwise_product_reduce(holder, log_len);
}

/// Runs the shared generic pairwise-product-reduce test against a
/// `FastCpuLayerHolder` with `log_len = 8`.
#[test]
fn test_pairwise_product_reduce() {
    type Packed = PackedBinaryField4x128b;
    let log_len = 8;
    // Holder sizes are derived from `log_len`, exactly as the other layer tests do.
    let holder = FastCpuLayerHolder::<CanonicalTowerFamily, Packed>::new(
        1 << (log_len + 4),
        1 << (log_len + 3),
    );
    binius_compute_test_utils::layer::test_generic_pairwise_product_reduce(holder, log_len);
}
Loading