Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions src/neutronnova_zk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ where
// Squeeze tau and rhos fresh inside this function (like ZK sum-check APIs)
let (ell_cons, left, right) = compute_tensor_decomp(S.num_cons);
let tau = transcript.squeeze(b"tau")?;
let E_eq = PowPolynomial::new(&tau, ell_cons).split_evals(left, right);

let E_eq = PowPolynomial::split_evals(tau, ell_cons, left, right);

let mut rhos = Vec::with_capacity(ell_b);
for _ in 0..ell_b {
Expand Down Expand Up @@ -738,19 +739,13 @@ where
info!(elapsed_ms = %nifs_t.elapsed().as_millis(), "NIFS");

let (_tensor_span, tensor_t) = start_span!("compute_tensor_and_poly_tau");
let (_ell, left, right) = compute_tensor_decomp(pk.S_step.num_cons);
let (E1, E2) = E_eq.split_at(left);
let mut full_E = vec![E::Scalar::ONE; left * right];
full_E
.par_chunks_mut(left)
.enumerate()
.for_each(|(i, row)| {
let e2 = E2[i];
row.iter_mut().zip(E1.iter()).for_each(|(val, e1)| {
*val = e2 * *e1;
});
});
let mut poly_tau = MultilinearPolynomial::new(full_E);
let (_ell, left, _right) = compute_tensor_decomp(pk.S_step.num_cons);
let mut E1 = E_eq;
let E2 = E1.split_off(left);

let mut poly_tau_left = MultilinearPolynomial::new(E1);
let poly_tau_right = MultilinearPolynomial::new(E2);

info!(elapsed_ms = %tensor_t.elapsed().as_millis(), "compute_tensor_and_poly_tau");

// outer sum-check preparation
Expand Down Expand Up @@ -786,7 +781,8 @@ where
let (_sc_span, sc_t) = start_span!("outer_sumcheck_batched");
let r_x = SumcheckProof::<E>::prove_cubic_with_additive_term_batched_zk(
num_rounds_x,
&mut poly_tau,
&mut poly_tau_left,
&poly_tau_right,
&mut poly_Az_step,
&mut poly_Az_core,
&mut poly_Bz_step,
Expand All @@ -808,7 +804,7 @@ where
vc.claim_Az_core = poly_Az_core[0];
vc.claim_Bz_core = poly_Bz_core[0];
vc.claim_Cz_core = poly_Cz_core[0];
vc.tau_at_rx = poly_tau[0];
vc.tau_at_rx = poly_tau_left[0];

let chals = SatisfyingAssignment::<E>::process_round(
&mut vc_state,
Expand Down
11 changes: 5 additions & 6 deletions src/polys/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,12 @@ impl<Scalar: PrimeField> PowPolynomial<Scalar> {
}

/// Computes two vectors such that their outer product equals the output of the `evals` function.
/// This code ensures
pub fn split_evals(&self, len_left: usize, len_right: usize) -> Vec<Scalar> {
/// The left vector is 1, t, t^2, ..., t^{2^{ell/2}-1}
/// and the right vector is 1, t^{2^{ell/2}}, ..., t^{(2^{ell/2}-1) * 2^{ell/2}}.
pub fn split_evals(t: Scalar, ell: usize, len_left: usize, len_right: usize) -> Vec<Scalar> {
// Compute the number of elements in the left and right halves
let ell = self.t_pow.len();
assert_eq!(len_left * len_right, 1 << ell);

let t = self.t_pow[0];

// Compute the left and right halves of the evaluations
// left = [1, t, t^2, ..., t^{2^{ell/2} - 1}]
let left = successors(Some(Scalar::ONE), |p| Some(*p * t))
Expand Down Expand Up @@ -128,7 +126,8 @@ mod tests {
assert_eq!(evals.len(), 1 << ell);

// now compute split evals
let split_evals = pow.split_evals(1 << (ell / 2), 1 << (ell - ell / 2));
let split_evals =
PowPolynomial::split_evals(t, pow.t_pow.len(), 1 << (ell / 2), 1 << (ell - ell / 2));
let (left, right) = split_evals.split_at(1 << (ell / 2));

// check that the outer product of left and right equals evals
Expand Down
204 changes: 167 additions & 37 deletions src/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl<E: Engine> SumcheckProof<E> {
where
F: Fn(&E::Scalar, &E::Scalar, &E::Scalar, &E::Scalar) -> E::Scalar + Sync,
{
let len = poly_A.Z.len() / 2;
let len = poly_B.Z.len() / 2;
par_for(
len,
|i| {
Expand Down Expand Up @@ -410,6 +410,117 @@ impl<E: Engine> SumcheckProof<E> {
)
}

#[inline]
/// Computes evaluation points for a cubic polynomial with additive term.
/// The outer polynomial is the power of tau, which is an outer product of two polynomials left and right.
///
/// This function computes three evaluation points (at 0, 2, and 3) for a univariate
/// polynomial that represents the sum over a hypercube edge in the sum-check protocol
/// for a cubic combination of three multilinear polynomials.
///
/// # Arguments
/// * `pow_tau_left` - The left part of the power of tau
/// * `pow_tau_right` - The right part of the power of tau
/// * `poly_A` - First multilinear polynomial
/// * `poly_B` - Second multilinear polynomial
/// * `poly_C` - Third multilinear polynomial
/// * `comb_func` - Function that combines evaluations of the four polynomials
///
/// # Returns
/// A tuple containing the evaluations at points 0, 2, and 3.
fn compute_eval_points_cubic_with_additive_term_with_outer_pow<F>(
pow_tau_left: &MultilinearPolynomial<E::Scalar>,
pow_tau_right: &MultilinearPolynomial<E::Scalar>,
poly_A: &MultilinearPolynomial<E::Scalar>,
poly_B: &MultilinearPolynomial<E::Scalar>,
poly_C: &MultilinearPolynomial<E::Scalar>,
comb_func: &F,
) -> (E::Scalar, E::Scalar, E::Scalar)
where
F: Fn(&E::Scalar, &E::Scalar, &E::Scalar, &E::Scalar) -> E::Scalar + Sync,
{
let len = poly_A.Z.len() / 2;
let left = pow_tau_left.Z.len();

if len < left {
return Self::compute_eval_points_cubic_with_additive_term(
pow_tau_left,
poly_A,
poly_B,
poly_C,
comb_func,
);
}

let right = len / left;

par_for(
left,
|i| {
let pow_left = pow_tau_left[i];

let mut acc_0 = E::Scalar::ZERO;
let mut acc_2 = E::Scalar::ZERO;
let mut acc_3 = E::Scalar::ZERO;

for j in 0..right {
let low = i + j * left;
let high = low + len;

let tau_low_right = pow_tau_right[j];
let tau_high_right = pow_tau_right[j + right];

let a_low = poly_A[low];
let a_high = poly_A[high];
let b_low = poly_B[low];
let b_high = poly_B[high];
let c_low = poly_C[low];
let c_high = poly_C[high];

// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&tau_low_right, &a_low, &b_low, &c_low);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_tau_bound_point = tau_high_right + tau_high_right - tau_low_right;
let poly_A_bound_point = a_high + a_high - a_low;
let poly_B_bound_point = b_high + b_high - b_low;
let poly_C_bound_point = c_high + c_high - c_low;
let eval_point_2 = comb_func(
&poly_tau_bound_point,
&poly_A_bound_point,
&poly_B_bound_point,
&poly_C_bound_point,
);

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_tau_bound_point = poly_tau_bound_point + tau_high_right - tau_low_right;
let poly_A_bound_point = poly_A_bound_point + a_high - a_low;
let poly_B_bound_point = poly_B_bound_point + b_high - b_low;
let poly_C_bound_point = poly_C_bound_point + c_high - c_low;
let eval_point_3 = comb_func(
&poly_tau_bound_point,
&poly_A_bound_point,
&poly_B_bound_point,
&poly_C_bound_point,
);

acc_0 += eval_point_0;
acc_2 += eval_point_2;
acc_3 += eval_point_3;
}

(acc_0 * pow_left, acc_2 * pow_left, acc_3 * pow_left)
},
|mut acc, val| {
acc.0 += val.0;
acc.1 += val.1;
acc.2 += val.2;
acc
},
|| (E::Scalar::ZERO, E::Scalar::ZERO, E::Scalar::ZERO),
)
}

/// Generates a sum-check proof for a cubic combination with additive term of four multilinear polynomials.
///
/// # Arguments
Expand Down Expand Up @@ -823,20 +934,24 @@ impl<E: Engine> SumcheckProof<E> {
/// and returns the sequence of verifier challenges.
pub fn prove_cubic_with_additive_term_batched_zk(
num_rounds: usize,
poly_A: &mut MultilinearPolynomial<E::Scalar>,
pow_tau_left: &mut MultilinearPolynomial<E::Scalar>,
pow_tau_right: &MultilinearPolynomial<E::Scalar>,
poly_A_step: &mut MultilinearPolynomial<E::Scalar>,
poly_A_core: &mut MultilinearPolynomial<E::Scalar>,
poly_B_step: &mut MultilinearPolynomial<E::Scalar>,
poly_B_core: &mut MultilinearPolynomial<E::Scalar>,
poly_C_step: &mut MultilinearPolynomial<E::Scalar>,
poly_C_core: &mut MultilinearPolynomial<E::Scalar>,
poly_D_step: &mut MultilinearPolynomial<E::Scalar>,
poly_D_core: &mut MultilinearPolynomial<E::Scalar>,
verifier_circuit: &mut NeutronNovaVerifierCircuit<E>,
state: &mut MultiRoundState<E>,
vc_shape: &SplitMultiRoundR1CSShape<E>,
vc_ck: &CommitmentKey<E>,
transcript: &mut E::TE,
start_round: usize,
) -> Result<Vec<E::Scalar>, SpartanError> {
let mut base_tau = E::Scalar::ONE;
let mut len_pow_tau = pow_tau_left.Z.len() * pow_tau_right.Z.len();

let mut r_x: Vec<E::Scalar> = Vec::with_capacity(num_rounds);

let mut claim_step = verifier_circuit.t_out_step;
Expand All @@ -848,26 +963,36 @@ impl<E: Engine> SumcheckProof<E> {
};

// step branch
let ((eval0_s, eval2_s, eval3_s), (eval0_c, eval2_c, eval3_c)) = rayon::join(
|| {
Self::compute_eval_points_cubic_with_additive_term(
poly_A,
poly_B_step,
poly_C_step,
poly_D_step,
&comb,
)
},
|| {
Self::compute_eval_points_cubic_with_additive_term(
poly_A,
poly_B_core,
poly_C_core,
poly_D_core,
&comb,
)
},
);
let ((mut eval0_s, mut eval2_s, mut eval3_s), (mut eval0_c, mut eval2_c, mut eval3_c)) =
rayon::join(
|| {
Self::compute_eval_points_cubic_with_additive_term_with_outer_pow(
pow_tau_left,
pow_tau_right,
poly_A_step,
poly_B_step,
poly_C_step,
&comb,
)
},
|| {
Self::compute_eval_points_cubic_with_additive_term_with_outer_pow(
pow_tau_left,
pow_tau_right,
poly_A_core,
poly_B_core,
poly_C_core,
&comb,
)
},
);

eval0_s *= base_tau;
eval2_s *= base_tau;
eval3_s *= base_tau;
eval0_c *= base_tau;
eval2_c *= base_tau;
eval3_c *= base_tau;

let evals_s = vec![eval0_s, claim_step - eval0_s, eval2_s, eval3_s];
let poly_s = UniPoly::from_evals(&evals_s)?;
Expand Down Expand Up @@ -908,7 +1033,12 @@ impl<E: Engine> SumcheckProof<E> {

// bind polynomials to the verifier's challenge
rayon::join(
|| poly_A.bind_poly_var_top(&r_i),
|| {
rayon::join(
|| poly_A_step.bind_poly_var_top(&r_i),
|| poly_A_core.bind_poly_var_top(&r_i),
);
},
|| {
rayon::join(
|| {
Expand All @@ -919,25 +1049,25 @@ impl<E: Engine> SumcheckProof<E> {
},
|| {
rayon::join(
|| {
rayon::join(
|| poly_C_step.bind_poly_var_top(&r_i),
|| poly_C_core.bind_poly_var_top(&r_i),
);
},
|| {
rayon::join(
|| poly_D_step.bind_poly_var_top(&r_i),
|| poly_D_core.bind_poly_var_top(&r_i),
);
},
|| poly_C_step.bind_poly_var_top(&r_i),
|| poly_C_core.bind_poly_var_top(&r_i),
);
},
);
},
);

// bind polynomial power of tau
// list power of tau (pow_tau) halves effectively
len_pow_tau >>= 1;
let one = E::Scalar::ONE;
let left = pow_tau_left.Z.len();
let pow = pow_tau_left.Z[len_pow_tau % left] * pow_tau_right.Z[len_pow_tau / left];
base_tau *= (pow - one) * r_i + one;
}

pow_tau_left.Z[0] = base_tau;

Ok(r_x)
}
}