Skip to content
This repository was archived by the owner on Sep 9, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions crates/compute_test_utils/src/piop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ pub fn commit_prove_verify<FDomain, FEncode, F, P, MTScheme, HAL, ComputeHolderT
.iter()
.map(|mle| {
let mut buffer = dev_alloc.alloc(1 << mle.n_vars()).unwrap();
let host_buffer = host_alloc.alloc(1 << mle.n_vars()).unwrap();
for (input, out) in P::iter_slice(mle.evals()).zip(host_buffer.iter_mut()) {
*out = input;
}

let evals = P::iter_slice(mle.evals()).collect::<Vec<_>>();

hal.copy_h2d(&evals[0..1 << mle.n_vars()], &mut buffer)
.unwrap();
hal.copy_h2d(host_buffer, &mut buffer).unwrap();
HAL::DevMem::to_const(buffer)
})
.collect::<Vec<_>>();
Expand Down
2 changes: 1 addition & 1 deletion crates/compute_test_utils/src/ring_switch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ pub fn commit_prove_verify_piop<U, F, MTScheme, MTProver, Hal, HalHolder>(
let mut rng = StdRng::seed_from_u64(0);
let merkle_scheme = merkle_prover.scheme();

let mut compute_holder = create_hal_holder(1 << 7, 1 << 15);
let mut compute_holder = create_hal_holder(1 << 12, 1 << 15);
let mut compute_data = compute_holder.to_data();

let compute_data_ref = &mut compute_data;
Expand Down
6 changes: 4 additions & 2 deletions crates/core/src/piop/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ where

prove_interleaved_fri_sumcheck(
hal,
host_alloc,
dev_alloc,
commit_meta.total_vars(),
fri_params,
Expand All @@ -305,6 +306,7 @@ where
#[allow(clippy::too_many_arguments)]
fn prove_interleaved_fri_sumcheck<Hal, F, FEncode, P, NTT, MTScheme, MTProver, Challenger_>(
hal: &Hal,
host_alloc: &impl ComputeAllocator<F, CpuMemory>,
dev_alloc: &impl ComputeAllocator<F, Hal::DevMem>,
n_rounds: usize,
fri_params: &FRIParams<F, FEncode>,
Expand Down Expand Up @@ -381,7 +383,7 @@ where
?dimensions_data,
)
.entered();
match fri_prover.execute_fold_round(dev_alloc, challenge)? {
match fri_prover.execute_fold_round(host_alloc, dev_alloc, challenge)? {
FoldRoundOutput::NoCommitment => {}
FoldRoundOutput::Commitment(round_commitment) => {
transcript.message().write(&round_commitment);
Expand All @@ -391,7 +393,7 @@ where
}

sumcheck_batch_prover.finish(&mut transcript.message())?;
fri_prover.finish_proof(transcript)?;
fri_prover.finish_proof(transcript, host_alloc)?;
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/protocols/fri/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ where
}

/// The type of the termination round codeword in the FRI protocol.
pub type TerminateCodeword<F> = Vec<F>;
pub type TerminateCodeword<'a, F> = &'a [F];

/// Calculates the number of test queries required to achieve a target security level.
///
Expand Down
71 changes: 51 additions & 20 deletions crates/core/src/protocols/fri/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use binius_compute::{
ComputeLayerExecutor,
alloc::ComputeAllocator,
cpu::CpuMemory,
layer::ComputeLayer,
memory::{ComputeMemory, SizedSlice},
};
Expand Down Expand Up @@ -209,7 +210,7 @@ where
MerkleProver: MerkleTreeProver<F>,
{
/// The folded codeword on the host
pub host_codeword: Vec<F>,
pub host_codeword: &'b [F],
/// The folded codeword on the device
pub device_codeword: <Hal::DevMem as ComputeMemory<F>>::FSliceMut<'b>,
/// The Merkle tree commitment
Expand Down Expand Up @@ -308,7 +309,8 @@ where
/// intermediate folded codewords.
pub fn execute_fold_round(
&mut self,
allocator: &'b impl ComputeAllocator<F, Hal::DevMem>,
host_alloc: &'b impl ComputeAllocator<F, CpuMemory>,
dev_alloc: &'b impl ComputeAllocator<F, Hal::DevMem>,
challenge: F,
) -> Result<FoldRoundOutput<VCS::Digest>, Error> {
self.unprocessed_challenges.push(challenge);
Expand Down Expand Up @@ -343,7 +345,7 @@ where
Some(prev_round) => {
// Fold a full codeword committed in the previous FRI round into a codeword with
// reduced dimension and rate.
let mut folded_codeword = allocator.alloc(
let mut folded_codeword = dev_alloc.alloc(
prev_round.device_codeword.len() / (1 << self.unprocessed_challenges.len()),
)?;

Expand All @@ -364,13 +366,19 @@ where
}
None => {
let codeword_len = len_packed_slice(self.codeword);
let mut original_codeword = allocator.alloc(codeword_len)?;
let mut original_codeword = dev_alloc.alloc(codeword_len)?;
unpack_if_possible(
self.codeword,
|scalars| self.cl.copy_h2d(scalars, &mut original_codeword),
|scalars| {
// TODO: this extra copy possibly can be eliminated at least for the CPU
// layers.
let host_buffer = host_alloc.alloc(codeword_len)?;
host_buffer.copy_from_slice(scalars);
self.cl.copy_h2d(scalars, &mut original_codeword)
},
|_packed| unimplemented!("non-dense packed fields not supported"),
)?;
let mut folded_codeword = allocator.alloc(
let mut folded_codeword = dev_alloc.alloc(
1 << (self.params.rs_code().log_len()
- (self.unprocessed_challenges.len() - self.params.log_batch_size())),
)?;
Expand All @@ -393,9 +401,9 @@ where
drop(fri_fold_span);
self.unprocessed_challenges.clear();

let mut folded_codeword_host = zeroed_vec(folded_codeword.len());
let folded_codeword_host = host_alloc.alloc(folded_codeword.len())?;
self.cl
.copy_d2h(Hal::DevMem::as_const(&folded_codeword), &mut folded_codeword_host)?;
.copy_d2h(Hal::DevMem::as_const(&folded_codeword), folded_codeword_host)?;

// take the first arity as coset_log_len, or use inv_rate if arities are empty
let coset_size = self
Expand All @@ -415,7 +423,7 @@ where
.entered();
let (commitment, committed) = self
.merkle_prover
.commit(&folded_codeword_host, coset_size)
.commit(folded_codeword_host, coset_size)
.map_err(|err| Error::VectorCommit(Box::new(err)))?;
drop(merkle_tree_span);

Expand Down Expand Up @@ -443,16 +451,38 @@ where
#[allow(clippy::type_complexity)]
pub fn finalize(
mut self,
) -> Result<(TerminateCodeword<F>, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> {
host_alloc: &'b impl ComputeAllocator<F, CpuMemory>,
) -> Result<
(TerminateCodeword<'b, F>, FRIQueryProver<'a, 'b, F, FA, P, MerkleProver, VCS>),
Error,
> {
if self.curr_round != self.n_rounds() {
bail!(Error::EarlyProverFinish);
}

let terminate_codeword = self
let terminate_codeword: &'b [F] = self
.round_committed
.last()
.map(|round| round.host_codeword.clone())
.unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect());
.map(|round| -> Result<&[F], Error> { Ok(round.host_codeword) })
.unwrap_or_else(|| {
unpack_if_possible(
self.codeword,
|scalars| {
let buffer = host_alloc.alloc(len_packed_slice(self.codeword))?;
buffer.copy_from_slice(scalars);

Ok(&buffer[..])
},
|packed| {
let buffer = host_alloc.alloc(len_packed_slice(self.codeword))?;
for (src, dst) in PackedField::iter_slice(packed).zip(buffer.iter_mut()) {
*dst = src;
}

Ok(buffer)
},
)
})?;

self.unprocessed_challenges.clear();

Expand Down Expand Up @@ -483,17 +513,18 @@ where
pub fn finish_proof<Challenger_>(
self,
transcript: &mut ProverTranscript<Challenger_>,
host_alloc: &'b impl ComputeAllocator<F, CpuMemory>,
) -> Result<(), Error>
where
Challenger_: Challenger,
{
let (terminate_codeword, query_prover) = self.finalize()?;
let (terminate_codeword, query_prover) = self.finalize(host_alloc)?;
let mut advice = transcript.decommitment();
advice.write_scalar_slice(&terminate_codeword);
advice.write_scalar_slice(terminate_codeword);

let layers = query_prover.vcs_optimal_layers()?;
for layer in layers {
advice.write_slice(&layer);
for layer in &layers {
advice.write_slice(layer);
}

let params = query_prover.params;
Expand All @@ -508,7 +539,7 @@ where
}

/// A prover for the FRI query phase.
pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>
pub struct FRIQueryProver<'a, 'b, F, FA, P, MerkleProver, VCS>
where
F: BinaryField,
FA: BinaryField,
Expand All @@ -519,11 +550,11 @@ where
params: &'a FRIParams<F, FA>,
codeword: &'a [P],
codeword_committed: &'a MerkleProver::Committed,
round_committed: Vec<(Vec<F>, MerkleProver::Committed)>,
round_committed: Vec<(&'b [F], MerkleProver::Committed)>,
merkle_prover: &'a MerkleProver,
}

impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, F, FA, P, MerkleProver, VCS>
impl<F, FA, P, MerkleProver, VCS> FRIQueryProver<'_, '_, F, FA, P, MerkleProver, VCS>
where
F: TowerField + ExtensionField<FA>,
FA: BinaryField,
Expand Down
15 changes: 11 additions & 4 deletions crates/core/src/protocols/fri/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,13 @@ fn test_commit_prove_verify_success<U, F, FA>(
codeword,
} = fri::commit_interleaved(&committed_rs_code, &params, &ntt, &merkle_prover, &msg).unwrap();

let mut compute_holder = CpuLayerHolder::<F>::new(1 << 10, 1 << 20);
let ComputeData { hal, dev_alloc, .. } = compute_holder.to_data();
let mut compute_holder = CpuLayerHolder::<F>::new(1 << 11, 1 << 20);
let ComputeData {
hal,
dev_alloc,
host_alloc,
..
} = compute_holder.to_data();

// Run the prover to generate the proximity proof
let mut round_prover =
Expand All @@ -149,7 +154,7 @@ fn test_commit_prove_verify_success<U, F, FA>(
for _i in 0..params.n_fold_rounds() {
let challenge = prover_challenger.sample();
let fold_round_output = round_prover
.execute_fold_round(&dev_alloc, challenge)
.execute_fold_round(&host_alloc, &dev_alloc, challenge)
.unwrap();
match fold_round_output {
FoldRoundOutput::NoCommitment => {}
Expand All @@ -160,7 +165,9 @@ fn test_commit_prove_verify_success<U, F, FA>(
}
}

round_prover.finish_proof(&mut prover_challenger).unwrap();
round_prover
.finish_proof(&mut prover_challenger, &host_alloc)
.unwrap();
// Now run the verifier
let mut verifier_challenger = prover_challenger.into_verifier();
codeword_commitment = verifier_challenger.message().read().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/core/tests/piop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn test_commit_prove_verify_extreme_rate() {
let merkle_prover = BinaryMerkleTreeProver::<_, Groestl256, _>::new(Groestl256ByteCompression);
let n_transparents = 2;
let log_inv_rate = 8;
let compute_holder = CpuLayerHolder::<B128>::new(1 << 14, 1 << 22);
let compute_holder = CpuLayerHolder::<B128>::new(1 << 16, 1 << 22);

commit_prove_verify::<B8, B16, B128, PackedBinaryField2x128b, _, _, _>(
compute_holder,
Expand Down
Loading