diff --git a/crates/compute_test_utils/src/piop.rs b/crates/compute_test_utils/src/piop.rs index 32b007635..81fa73878 100644 --- a/crates/compute_test_utils/src/piop.rs +++ b/crates/compute_test_utils/src/piop.rs @@ -159,11 +159,12 @@ pub fn commit_prove_verify>(); - - hal.copy_h2d(&evals[0..1 << mle.n_vars()], &mut buffer) - .unwrap(); + hal.copy_h2d(host_buffer, &mut buffer).unwrap(); HAL::DevMem::to_const(buffer) }) .collect::>(); diff --git a/crates/compute_test_utils/src/ring_switch.rs b/crates/compute_test_utils/src/ring_switch.rs index 44dea8b8d..e2d16f5ce 100644 --- a/crates/compute_test_utils/src/ring_switch.rs +++ b/crates/compute_test_utils/src/ring_switch.rs @@ -207,7 +207,7 @@ pub fn commit_prove_verify_piop( let mut rng = StdRng::seed_from_u64(0); let merkle_scheme = merkle_prover.scheme(); - let mut compute_holder = create_hal_holder(1 << 7, 1 << 15); + let mut compute_holder = create_hal_holder(1 << 12, 1 << 15); let mut compute_data = compute_holder.to_data(); let compute_data_ref = &mut compute_data; diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index ac53d32a2..83c698eba 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -288,6 +288,7 @@ where prove_interleaved_fri_sumcheck( hal, + host_alloc, dev_alloc, commit_meta.total_vars(), fri_params, @@ -305,6 +306,7 @@ where #[allow(clippy::too_many_arguments)] fn prove_interleaved_fri_sumcheck( hal: &Hal, + host_alloc: &impl ComputeAllocator, dev_alloc: &impl ComputeAllocator, n_rounds: usize, fri_params: &FRIParams, @@ -381,7 +383,7 @@ where ?dimensions_data, ) .entered(); - match fri_prover.execute_fold_round(dev_alloc, challenge)? { + match fri_prover.execute_fold_round(host_alloc, dev_alloc, challenge)? { FoldRoundOutput::NoCommitment => {} FoldRoundOutput::Commitment(round_commitment) => { transcript.message().write(&round_commitment); @@ -391,7 +393,7 @@ where } sumcheck_batch_prover.finish(&mut transcript.message())?; - fri_prover.finish_proof(transcript)?; + fri_prover.finish_proof(transcript, host_alloc)?; Ok(()) } diff --git a/crates/core/src/protocols/fri/common.rs b/crates/core/src/protocols/fri/common.rs index 7cb6b7870..af6bf30d6 100644 --- a/crates/core/src/protocols/fri/common.rs +++ b/crates/core/src/protocols/fri/common.rs @@ -190,7 +190,7 @@ where } /// The type of the termination round codeword in the FRI protocol. -pub type TerminateCodeword = Vec; +pub type TerminateCodeword<'a, F> = &'a [F]; /// Calculates the number of test queries required to achieve a target security level. /// diff --git a/crates/core/src/protocols/fri/prove.rs b/crates/core/src/protocols/fri/prove.rs index 749df6e75..81b00dfd1 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -2,6 +2,7 @@ use binius_compute::{ ComputeLayerExecutor, alloc::ComputeAllocator, + cpu::CpuMemory, layer::ComputeLayer, memory::{ComputeMemory, SizedSlice}, }; @@ -209,7 +210,7 @@ where MerkleProver: MerkleTreeProver, { /// The folded codeword on the host - pub host_codeword: Vec, + pub host_codeword: &'b [F], /// The folded codeword on the device pub device_codeword: >::FSliceMut<'b>, /// The Merkle tree commitment @@ -308,7 +309,8 @@ where /// intermediate folded codewords. pub fn execute_fold_round( &mut self, - allocator: &'b impl ComputeAllocator, + host_alloc: &'b impl ComputeAllocator, + dev_alloc: &'b impl ComputeAllocator, challenge: F, ) -> Result, Error> { self.unprocessed_challenges.push(challenge); @@ -343,7 +345,7 @@ where Some(prev_round) => { // Fold a full codeword committed in the previous FRI round into a codeword with // reduced dimension and rate. - let mut folded_codeword = allocator.alloc( + let mut folded_codeword = dev_alloc.alloc( prev_round.device_codeword.len() / (1 << self.unprocessed_challenges.len()), )?; @@ -364,13 +366,19 @@ where } None => { let codeword_len = len_packed_slice(self.codeword); - let mut original_codeword = allocator.alloc(codeword_len)?; + let mut original_codeword = dev_alloc.alloc(codeword_len)?; unpack_if_possible( self.codeword, - |scalars| self.cl.copy_h2d(scalars, &mut original_codeword), + |scalars| { + // TODO: this extra copy possibly can be eliminated at least for the CPU + // layers. + let host_buffer = host_alloc.alloc(codeword_len)?; + host_buffer.copy_from_slice(scalars); + self.cl.copy_h2d(scalars, &mut original_codeword) + }, |_packed| unimplemented!("non-dense packed fields not supported"), )?; - let mut folded_codeword = allocator.alloc( + let mut folded_codeword = dev_alloc.alloc( 1 << (self.params.rs_code().log_len() - (self.unprocessed_challenges.len() - self.params.log_batch_size())), )?; @@ -393,9 +401,9 @@ where drop(fri_fold_span); self.unprocessed_challenges.clear(); - let mut folded_codeword_host = zeroed_vec(folded_codeword.len()); + let folded_codeword_host = host_alloc.alloc(folded_codeword.len())?; self.cl - .copy_d2h(Hal::DevMem::as_const(&folded_codeword), &mut folded_codeword_host)?; + .copy_d2h(Hal::DevMem::as_const(&folded_codeword), folded_codeword_host)?; // take the first arity as coset_log_len, or use inv_rate if arities are empty let coset_size = self @@ -415,7 +423,7 @@ where .entered(); let (commitment, committed) = self .merkle_prover - .commit(&folded_codeword_host, coset_size) + .commit(folded_codeword_host, coset_size) .map_err(|err| Error::VectorCommit(Box::new(err)))?; drop(merkle_tree_span); @@ -443,16 +451,38 @@ where #[allow(clippy::type_complexity)] pub fn finalize( mut self, - ) -> Result<(TerminateCodeword, FRIQueryProver<'a, F, FA, P, MerkleProver, VCS>), Error> { + host_alloc: &'b impl ComputeAllocator, + ) -> Result< + (TerminateCodeword<'b, F>, FRIQueryProver<'a, 'b, F, FA, P, MerkleProver, VCS>), + Error, + > { if self.curr_round != self.n_rounds() { bail!(Error::EarlyProverFinish); } - let terminate_codeword = self + let terminate_codeword: &'b [F] = self .round_committed .last() - .map(|round| round.host_codeword.clone()) - .unwrap_or_else(|| PackedField::iter_slice(self.codeword).collect()); + .map(|round| -> Result<&[F], Error> { Ok(round.host_codeword) }) + .unwrap_or_else(|| { + unpack_if_possible( + self.codeword, + |scalars| { + let buffer = host_alloc.alloc(len_packed_slice(self.codeword))?; + buffer.copy_from_slice(scalars); + + Ok(&buffer[..]) + }, + |packed| { + let buffer = host_alloc.alloc(len_packed_slice(self.codeword))?; + for (src, dst) in PackedField::iter_slice(packed).zip(buffer.iter_mut()) { + *dst = src; + } + + Ok(buffer) + }, + ) + })?; self.unprocessed_challenges.clear(); @@ -483,17 +513,18 @@ where pub fn finish_proof( self, transcript: &mut ProverTranscript, + host_alloc: &'b impl ComputeAllocator, ) -> Result<(), Error> where Challenger_: Challenger, { - let (terminate_codeword, query_prover) = self.finalize()?; + let (terminate_codeword, query_prover) = self.finalize(host_alloc)?; let mut advice = transcript.decommitment(); - advice.write_scalar_slice(&terminate_codeword); + advice.write_scalar_slice(terminate_codeword); let layers = query_prover.vcs_optimal_layers()?; - for layer in layers { - advice.write_slice(&layer); + for layer in &layers { + advice.write_slice(layer); } let params = query_prover.params; @@ -508,7 +539,7 @@ where } /// A prover for the FRI query phase. -pub struct FRIQueryProver<'a, F, FA, P, MerkleProver, VCS> +pub struct FRIQueryProver<'a, 'b, F, FA, P, MerkleProver, VCS> where F: BinaryField, FA: BinaryField, @@ -519,11 +550,11 @@ where params: &'a FRIParams, codeword: &'a [P], codeword_committed: &'a MerkleProver::Committed, - round_committed: Vec<(Vec, MerkleProver::Committed)>, + round_committed: Vec<(&'b [F], MerkleProver::Committed)>, merkle_prover: &'a MerkleProver, } -impl FRIQueryProver<'_, F, FA, P, MerkleProver, VCS> +impl FRIQueryProver<'_, '_, F, FA, P, MerkleProver, VCS> where F: TowerField + ExtensionField, FA: BinaryField, diff --git a/crates/core/src/protocols/fri/tests.rs b/crates/core/src/protocols/fri/tests.rs index be847b9ce..149812a02 100644 --- a/crates/core/src/protocols/fri/tests.rs +++ b/crates/core/src/protocols/fri/tests.rs @@ -136,8 +136,13 @@ fn test_commit_prove_verify_success( codeword, } = fri::commit_interleaved(&committed_rs_code, ¶ms, &ntt, &merkle_prover, &msg).unwrap(); - let mut compute_holder = CpuLayerHolder::::new(1 << 10, 1 << 20); - let ComputeData { hal, dev_alloc, .. } = compute_holder.to_data(); + let mut compute_holder = CpuLayerHolder::::new(1 << 11, 1 << 20); + let ComputeData { + hal, + dev_alloc, + host_alloc, + .. + } = compute_holder.to_data(); // Run the prover to generate the proximity proof let mut round_prover = @@ -149,7 +154,7 @@ fn test_commit_prove_verify_success( for _i in 0..params.n_fold_rounds() { let challenge = prover_challenger.sample(); let fold_round_output = round_prover - .execute_fold_round(&dev_alloc, challenge) + .execute_fold_round(&host_alloc, &dev_alloc, challenge) .unwrap(); match fold_round_output { FoldRoundOutput::NoCommitment => {} @@ -160,7 +165,9 @@ fn test_commit_prove_verify_success( } } - round_prover.finish_proof(&mut prover_challenger).unwrap(); + round_prover + .finish_proof(&mut prover_challenger, &host_alloc) + .unwrap(); // Now run the verifier let mut verifier_challenger = prover_challenger.into_verifier(); codeword_commitment = verifier_challenger.message().read().unwrap(); diff --git a/crates/core/tests/piop.rs b/crates/core/tests/piop.rs index 3860de4a9..be0723e81 100644 --- a/crates/core/tests/piop.rs +++ b/crates/core/tests/piop.rs @@ -64,7 +64,7 @@ fn test_commit_prove_verify_extreme_rate() { let merkle_prover = BinaryMerkleTreeProver::<_, Groestl256, _>::new(Groestl256ByteCompression); let n_transparents = 2; let log_inv_rate = 8; - let compute_holder = CpuLayerHolder::::new(1 << 14, 1 << 22); + let compute_holder = CpuLayerHolder::::new(1 << 16, 1 << 22); commit_prove_verify::( compute_holder,