diff --git a/src/neutronnova_zk.rs b/src/neutronnova_zk.rs index 02f8e9f0..f39a782b 100644 --- a/src/neutronnova_zk.rs +++ b/src/neutronnova_zk.rs @@ -234,7 +234,8 @@ where // Squeeze tau and rhos fresh inside this function (like ZK sum-check APIs) let (ell_cons, left, right) = compute_tensor_decomp(S.num_cons); let tau = transcript.squeeze(b"tau")?; - let E_eq = PowPolynomial::new(&tau, ell_cons).split_evals(left, right); + + let E_eq = PowPolynomial::split_evals(tau, ell_cons, left, right); let mut rhos = Vec::with_capacity(ell_b); for _ in 0..ell_b { @@ -738,19 +739,13 @@ where info!(elapsed_ms = %nifs_t.elapsed().as_millis(), "NIFS"); let (_tensor_span, tensor_t) = start_span!("compute_tensor_and_poly_tau"); - let (_ell, left, right) = compute_tensor_decomp(pk.S_step.num_cons); - let (E1, E2) = E_eq.split_at(left); - let mut full_E = vec![E::Scalar::ONE; left * right]; - full_E - .par_chunks_mut(left) - .enumerate() - .for_each(|(i, row)| { - let e2 = E2[i]; - row.iter_mut().zip(E1.iter()).for_each(|(val, e1)| { - *val = e2 * *e1; - }); - }); - let mut poly_tau = MultilinearPolynomial::new(full_E); + let (_ell, left, _right) = compute_tensor_decomp(pk.S_step.num_cons); + let mut E1 = E_eq; + let E2 = E1.split_off(left); + + let mut poly_tau_left = MultilinearPolynomial::new(E1); + let poly_tau_right = MultilinearPolynomial::new(E2); + info!(elapsed_ms = %tensor_t.elapsed().as_millis(), "compute_tensor_and_poly_tau"); // outer sum-check preparation @@ -786,7 +781,8 @@ where let (_sc_span, sc_t) = start_span!("outer_sumcheck_batched"); let r_x = SumcheckProof::::prove_cubic_with_additive_term_batched_zk( num_rounds_x, - &mut poly_tau, + &mut poly_tau_left, + &poly_tau_right, &mut poly_Az_step, &mut poly_Az_core, &mut poly_Bz_step, @@ -808,7 +804,7 @@ where vc.claim_Az_core = poly_Az_core[0]; vc.claim_Bz_core = poly_Bz_core[0]; vc.claim_Cz_core = poly_Cz_core[0]; - vc.tau_at_rx = poly_tau[0]; + vc.tau_at_rx = poly_tau_left[0]; let chals = SatisfyingAssignment::::process_round( &mut vc_state, diff --git a/src/polys/power.rs b/src/polys/power.rs index 40bd1035..f56d55c1 100644 --- a/src/polys/power.rs +++ b/src/polys/power.rs @@ -54,14 +54,12 @@ impl PowPolynomial { } /// Computes two vectors such that their outer product equals the output of the `evals` function. - /// This code ensures - pub fn split_evals(&self, len_left: usize, len_right: usize) -> Vec { + /// The left vector is 1, t, t^2, ..., t^{2^{ell/2}-1} + /// and the right vector is 1, t^{2^{ell/2}}, ..., t^{(2^{ell/2}-1) * 2^{ell/2}}. + pub fn split_evals(t: Scalar, ell: usize, len_left: usize, len_right: usize) -> Vec { // Compute the number of elements in the left and right halves - let ell = self.t_pow.len(); assert_eq!(len_left * len_right, 1 << ell); - let t = self.t_pow[0]; - // Compute the left and right halves of the evaluations // left = [1, t, t^2, ..., t^{2^{ell/2} - 1}] let left = successors(Some(Scalar::ONE), |p| Some(*p * t)) @@ -128,7 +126,8 @@ mod tests { assert_eq!(evals.len(), 1 << ell); // now compute split evals - let split_evals = pow.split_evals(1 << (ell / 2), 1 << (ell - ell / 2)); + let split_evals = + PowPolynomial::split_evals(t, pow.t_pow.len(), 1 << (ell / 2), 1 << (ell - ell / 2)); let (left, right) = split_evals.split_at(1 << (ell / 2)); // check that the outer product of left and right equals evals diff --git a/src/sumcheck.rs b/src/sumcheck.rs index ec337b51..ca446c13 100644 --- a/src/sumcheck.rs +++ b/src/sumcheck.rs @@ -359,7 +359,7 @@ impl SumcheckProof { where F: Fn(&E::Scalar, &E::Scalar, &E::Scalar, &E::Scalar) -> E::Scalar + Sync, { - let len = poly_A.Z.len() / 2; + let len = poly_B.Z.len() / 2; par_for( len, |i| { @@ -410,6 +410,117 @@ impl SumcheckProof { ) } + #[inline] + /// Computes evaluation points for a cubic polynomial with additive term. + /// The outer polynomial is the power of tau, which is an outer product of two polynomials left and right. + /// + /// This function computes three evaluation points (at 0, 2, and 3) for a univariate + /// polynomial that represents the sum over a hypercube edge in the sum-check protocol + /// for a cubic combination of three multilinear polynomials. + /// + /// # Arguments + /// * `pow_tau_left` - The left part of the power of tau + /// * `pow_tau_right` - The right part of the power of tau + /// * `poly_A` - First multilinear polynomial + /// * `poly_B` - Second multilinear polynomial + /// * `poly_C` - Third multilinear polynomial + /// * `comb_func` - Function that combines evaluations of the four polynomials + /// + /// # Returns + /// A tuple containing the evaluations at points 0, 2, and 3. + fn compute_eval_points_cubic_with_additive_term_with_outer_pow( + pow_tau_left: &MultilinearPolynomial, + pow_tau_right: &MultilinearPolynomial, + poly_A: &MultilinearPolynomial, + poly_B: &MultilinearPolynomial, + poly_C: &MultilinearPolynomial, + comb_func: &F, + ) -> (E::Scalar, E::Scalar, E::Scalar) + where + F: Fn(&E::Scalar, &E::Scalar, &E::Scalar, &E::Scalar) -> E::Scalar + Sync, + { + let len = poly_A.Z.len() / 2; + let left = pow_tau_left.Z.len(); + + if len < left { + return Self::compute_eval_points_cubic_with_additive_term( + pow_tau_left, + poly_A, + poly_B, + poly_C, + comb_func, + ); + } + + let right = len / left; + + par_for( + left, + |i| { + let pow_left = pow_tau_left[i]; + + let mut acc_0 = E::Scalar::ZERO; + let mut acc_2 = E::Scalar::ZERO; + let mut acc_3 = E::Scalar::ZERO; + + for j in 0..right { + let low = i + j * left; + let high = low + len; + + let tau_low_right = pow_tau_right[j]; + let tau_high_right = pow_tau_right[j + right]; + + let a_low = poly_A[low]; + let a_high = poly_A[high]; + let b_low = poly_B[low]; + let b_high = poly_B[high]; + let c_low = poly_C[low]; + let c_high = poly_C[high]; + + // eval 0: bound_func is A(low) + let eval_point_0 = comb_func(&tau_low_right, &a_low, &b_low, &c_low); + + // eval 2: bound_func is -A(low) + 2*A(high) + let poly_tau_bound_point = tau_high_right + tau_high_right - tau_low_right; + let poly_A_bound_point = a_high + a_high - a_low; + let poly_B_bound_point = b_high + b_high - b_low; + let poly_C_bound_point = c_high + c_high - c_low; + let eval_point_2 = comb_func( + &poly_tau_bound_point, + &poly_A_bound_point, + &poly_B_bound_point, + &poly_C_bound_point, + ); + + // eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2) + let poly_tau_bound_point = poly_tau_bound_point + tau_high_right - tau_low_right; + let poly_A_bound_point = poly_A_bound_point + a_high - a_low; + let poly_B_bound_point = poly_B_bound_point + b_high - b_low; + let poly_C_bound_point = poly_C_bound_point + c_high - c_low; + let eval_point_3 = comb_func( + &poly_tau_bound_point, + &poly_A_bound_point, + &poly_B_bound_point, + &poly_C_bound_point, + ); + + acc_0 += eval_point_0; + acc_2 += eval_point_2; + acc_3 += eval_point_3; + } + + (acc_0 * pow_left, acc_2 * pow_left, acc_3 * pow_left) + }, + |mut acc, val| { + acc.0 += val.0; + acc.1 += val.1; + acc.2 += val.2; + acc + }, + || (E::Scalar::ZERO, E::Scalar::ZERO, E::Scalar::ZERO), + ) + } + /// Generates a sum-check proof for a cubic combination with additive term of four multilinear polynomials. /// /// # Arguments @@ -823,13 +934,14 @@ impl SumcheckProof { /// and returns the sequence of verifier challenges. pub fn prove_cubic_with_additive_term_batched_zk( num_rounds: usize, - poly_A: &mut MultilinearPolynomial, + pow_tau_left: &mut MultilinearPolynomial, + pow_tau_right: &MultilinearPolynomial, + poly_A_step: &mut MultilinearPolynomial, + poly_A_core: &mut MultilinearPolynomial, poly_B_step: &mut MultilinearPolynomial, poly_B_core: &mut MultilinearPolynomial, poly_C_step: &mut MultilinearPolynomial, poly_C_core: &mut MultilinearPolynomial, - poly_D_step: &mut MultilinearPolynomial, - poly_D_core: &mut MultilinearPolynomial, verifier_circuit: &mut NeutronNovaVerifierCircuit, state: &mut MultiRoundState, vc_shape: &SplitMultiRoundR1CSShape, @@ -837,6 +949,9 @@ impl SumcheckProof { transcript: &mut E::TE, start_round: usize, ) -> Result, SpartanError> { + let mut base_tau = E::Scalar::ONE; + let mut len_pow_tau = pow_tau_left.Z.len() * pow_tau_right.Z.len(); + let mut r_x: Vec = Vec::with_capacity(num_rounds); let mut claim_step = verifier_circuit.t_out_step; @@ -848,26 +963,36 @@ impl SumcheckProof { }; // step branch - let ((eval0_s, eval2_s, eval3_s), (eval0_c, eval2_c, eval3_c)) = rayon::join( - || { - Self::compute_eval_points_cubic_with_additive_term( - poly_A, - poly_B_step, - poly_C_step, - poly_D_step, - &comb, - ) - }, - || { - Self::compute_eval_points_cubic_with_additive_term( - poly_A, - poly_B_core, - poly_C_core, - poly_D_core, - &comb, - ) - }, - ); + let ((mut eval0_s, mut eval2_s, mut eval3_s), (mut eval0_c, mut eval2_c, mut eval3_c)) = + rayon::join( + || { + Self::compute_eval_points_cubic_with_additive_term_with_outer_pow( + pow_tau_left, + pow_tau_right, + poly_A_step, + poly_B_step, + poly_C_step, + &comb, + ) + }, + || { + Self::compute_eval_points_cubic_with_additive_term_with_outer_pow( + pow_tau_left, + pow_tau_right, + poly_A_core, + poly_B_core, + poly_C_core, + &comb, + ) + }, + ); + + eval0_s *= base_tau; + eval2_s *= base_tau; + eval3_s *= base_tau; + eval0_c *= base_tau; + eval2_c *= base_tau; + eval3_c *= base_tau; let evals_s = vec![eval0_s, claim_step - eval0_s, eval2_s, eval3_s]; let poly_s = UniPoly::from_evals(&evals_s)?; @@ -908,7 +1033,12 @@ impl SumcheckProof { // bind polynomials to the verifier's challenge rayon::join( - || poly_A.bind_poly_var_top(&r_i), + || { + rayon::join( + || poly_A_step.bind_poly_var_top(&r_i), + || poly_A_core.bind_poly_var_top(&r_i), + ); + }, || { rayon::join( || { @@ -919,25 +1049,25 @@ impl SumcheckProof { }, || { rayon::join( - || { - rayon::join( - || poly_C_step.bind_poly_var_top(&r_i), - || poly_C_core.bind_poly_var_top(&r_i), - ); - }, - || { - rayon::join( - || poly_D_step.bind_poly_var_top(&r_i), - || poly_D_core.bind_poly_var_top(&r_i), - ); - }, + || poly_C_step.bind_poly_var_top(&r_i), + || poly_C_core.bind_poly_var_top(&r_i), ); }, ); }, ); + + // bind polynomial power of tau + // list power of tau (pow_tau) halves effectively + len_pow_tau >>= 1; + let one = E::Scalar::ONE; + let left = pow_tau_left.Z.len(); + let pow = pow_tau_left.Z[len_pow_tau % left] * pow_tau_right.Z[len_pow_tau / left]; + base_tau *= (pow - one) * r_i + one; } + pow_tau_left.Z[0] = base_tau; + Ok(r_x) } }