Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 81 additions & 97 deletions src/neutronnova_zk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,26 @@ fn suffix_weight_full<F: Field>(t: usize, ell_b: usize, pair_idx: usize, rhos: &
w
}

#[inline]
fn mul_opt<F: Field>(a: &F, b: &F) -> F {
if a == &F::ZERO || b == &F::ZERO {
F::ZERO
} else if a == &F::ONE {
*b
} else if b == &F::ONE {
*a
} else {
*a * *b
}
}

impl<E: Engine> NeutronNovaNIFS<E>
where
E::PCS: FoldingEngineTrait<E>,
{
/// Computes the evaluations of the sum-check polynomial at 0, 2, and 3
#[inline]
fn prove_helper(
rho: &E::Scalar,
(left, right): (usize, usize),
e: &[E::Scalar],
Az1: &[E::Scalar],
Expand All @@ -91,7 +103,7 @@ where
Az2: &[E::Scalar],
Bz2: &[E::Scalar],
Cz2: &[E::Scalar],
) -> (E::Scalar, E::Scalar, E::Scalar) {
) -> (E::Scalar, E::Scalar) {
// sanity check sizes
assert_eq!(e.len(), left + right);
assert_eq!(Az1.len(), left * right);
Expand All @@ -104,10 +116,10 @@ where
let comb_func = |c1: &E::Scalar, c2: &E::Scalar, c3: &E::Scalar, c4: &E::Scalar| -> E::Scalar {
*c1 * (*c2 * *c3 - *c4)
};
let (eval_at_0, eval_at_2, eval_at_3) = (0..right)
let (eval_at_0, quad_coeff) = (0..right)
.into_par_iter()
.map(|i| {
let (i_eval_at_0, i_eval_at_2, i_eval_at_3) = (0..left)
let (mut i_eval_at_0, mut i_quad_coeff) = (0..left)
.into_par_iter()
.map(|j| {
// Turn the two dimensional (i, j) into a single dimension index
Expand All @@ -117,65 +129,39 @@ where
// eval 0: bound_func is A(low)
let eval_point_0 = comb_func(&poly_e_bound_point, &Az1[k], &Bz1[k], &Cz1[k]);

// eval 2: bound_func is -A(low) + 2*A(high)
let poly_Az_bound_point = Az2[k] + Az2[k] - Az1[k];
let poly_Bz_bound_point = Bz2[k] + Bz2[k] - Bz1[k];
let poly_Cz_bound_point = Cz2[k] + Cz2[k] - Cz1[k];
let eval_point_2 = comb_func(
// quad coeff
let poly_Az_bound_point = Az2[k] - Az1[k];
let poly_Bz_bound_point = Bz2[k] - Bz1[k];
let quad_coeff = mul_opt(
&mul_opt(&poly_Az_bound_point, &poly_Bz_bound_point),
&poly_e_bound_point,
&poly_Az_bound_point,
&poly_Bz_bound_point,
&poly_Cz_bound_point,
);

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_Az_bound_point = poly_Az_bound_point + Az2[k] - Az1[k];
let poly_Bz_bound_point = poly_Bz_bound_point + Bz2[k] - Bz1[k];
let poly_Cz_bound_point = poly_Cz_bound_point + Cz2[k] - Cz1[k];
let eval_point_3 = comb_func(
&poly_e_bound_point,
&poly_Az_bound_point,
&poly_Bz_bound_point,
&poly_Cz_bound_point,
);

(eval_point_0, eval_point_2, eval_point_3)
(eval_point_0, quad_coeff)
})
.reduce(
|| (E::Scalar::ZERO, E::Scalar::ZERO, E::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
|| (E::Scalar::ZERO, E::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
);

let f = &e[left..];

let poly_f_bound_point = f[i];

// eval 0: bound_func is A(low)
let eval_at_0 = poly_f_bound_point * i_eval_at_0;

// eval 2: bound_func is -A(low) + 2*A(high)
let eval_at_2 = poly_f_bound_point * i_eval_at_2;
i_eval_at_0 *= poly_f_bound_point;

// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let eval_at_3 = poly_f_bound_point * i_eval_at_3;
// quad coeff
i_quad_coeff *= poly_f_bound_point;

(eval_at_0, eval_at_2, eval_at_3)
(i_eval_at_0, i_quad_coeff)
})
.reduce(
|| (E::Scalar::ZERO, E::Scalar::ZERO, E::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
|| (E::Scalar::ZERO, E::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
);

// multiply by the common factors
let one_minus_rho = E::Scalar::ONE - rho;
let three_rho_minus_one = E::Scalar::from(3) * rho - E::Scalar::ONE;
let five_rho_minus_two = E::Scalar::from(5) * rho - E::Scalar::from(2);

(
eval_at_0 * one_minus_rho,
eval_at_2 * three_rho_minus_one,
eval_at_3 * five_rho_minus_two,
)
(eval_at_0, quad_coeff)
}

/// ZK version of NeutronNova NIFS prove. This function performs the NIFS folding
Expand Down Expand Up @@ -244,7 +230,6 @@ where
// Build Az, Bz, Cz tables for each (possibly padded) instance
let (_matrix_span, matrix_t) =
start_span!("matrix_vector_multiply_instances", instances = n_padded);
let chunk_len = left * right;
let triples = (0..n_padded)
.into_par_iter()
.map(|i| {
Expand Down Expand Up @@ -282,34 +267,45 @@ where
// Round polynomial: use rho_t inside prove_helper (this multiplies by eq(b_t; rho_t))
let pairs = m / 2;

let (e0, e2, e3) = (0..pairs)
.into_par_iter()
.map(|pair_idx| {
let lo = 2 * pair_idx;
let hi = lo + 1;
let (a0, a2, a3) = Self::prove_helper(
&rho_t,
let (e0, quad_coeff) = A_layers
.par_chunks(2)
.zip(B_layers.par_chunks(2))
.zip(C_layers.par_chunks(2))
.enumerate()
.map(|(pair_idx, ((pair_a, pair_b), pair_c))| {
let (e0, quad_coeff) = Self::prove_helper(
(left, right),
&E_eq,
&A_layers[lo],
&B_layers[lo],
&C_layers[lo],
&A_layers[hi],
&B_layers[hi],
&C_layers[hi],
&pair_a[0],
&pair_b[0],
&pair_c[0],
&pair_a[1],
&pair_b[1],
&pair_c[1],
);
let w = suffix_weight_full::<E::Scalar>(t, ell_b, pair_idx, &rhos);
(a0 * w, a2 * w, a3 * w)
(e0 * w, quad_coeff * w)
})
.reduce(
|| (E::Scalar::ZERO, E::Scalar::ZERO, E::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1, a.2 + b.2),
|| (E::Scalar::ZERO, E::Scalar::ZERO),
|a, b| (a.0 + b.0, a.1 + b.1),
);

let se0 = acc_eq * e0;
let se2 = acc_eq * e2;
let se3 = acc_eq * e3;
let poly_t = UniPoly::<E::Scalar>::from_evals(&[se0, T_cur - se0, se2, se3])?;
// recover cubic polynomial coefficients from eval_at_zero and cubic_term_coeff
let one_minus_rho = E::Scalar::ONE - rho_t;
let two_rho_minus_one = rho_t - one_minus_rho;
let c = e0 * acc_eq;
let a = quad_coeff * acc_eq;
let a_b_c = (T_cur - c * one_minus_rho) * rho_t.invert().unwrap();
let b = a_b_c - a - c;
let new_a = a * two_rho_minus_one;
let new_b = b * two_rho_minus_one + a * one_minus_rho;
let new_c = c * two_rho_minus_one + b * one_minus_rho;
let new_d = c * one_minus_rho;

let poly_t = UniPoly {
coeffs: vec![new_d, new_c, new_b, new_a],
};
polys.push(poly_t.clone());

// Expose polynomial coefficients to the verifier circuit and feed into the transcript/state
Expand All @@ -326,41 +322,29 @@ where
T_cur = poly_t.evaluate(&r_b);

// Fold A/B/C layers for next round (weights 1-r_b, r_b)
let mut next_A: Vec<Vec<E::Scalar>> = Vec::with_capacity(pairs);
let mut next_B: Vec<Vec<E::Scalar>> = Vec::with_capacity(pairs);
let mut next_C: Vec<Vec<E::Scalar>> = Vec::with_capacity(pairs);
next_A.par_extend((0..pairs).into_par_iter().map(|i| {
let lo = 2 * i;
let hi = lo + 1;
let mut v = vec![E::Scalar::ZERO; chunk_len];
v.iter_mut().enumerate().for_each(|(k, val)| {
*val = A_layers[lo][k] + (A_layers[hi][k] - A_layers[lo][k]) * r_b;
});
v
}));
next_B.par_extend((0..pairs).into_par_iter().map(|i| {
let lo = 2 * i;
let hi = lo + 1;
let mut v = vec![E::Scalar::ZERO; chunk_len];
v.iter_mut().enumerate().for_each(|(k, val)| {
*val = B_layers[lo][k] + (B_layers[hi][k] - B_layers[lo][k]) * r_b;
});
v
}));
next_C.par_extend((0..pairs).into_par_iter().map(|i| {
let lo = 2 * i;
let hi = lo + 1;
let mut v = vec![E::Scalar::ZERO; chunk_len];
v.iter_mut().enumerate().for_each(|(k, val)| {
*val = C_layers[lo][k] + (C_layers[hi][k] - C_layers[lo][k]) * r_b;
});
v
}));

let mut next_A = vec![vec![]; m];
let mut next_B = vec![vec![]; m];
let mut next_C = vec![vec![]; m];
for i in 0..m {
let t = if i & 1 == 0 { i >> 1 } else { (i >> 1) + pairs };
next_A[t] = std::mem::take(&mut A_layers[i]);
next_B[t] = std::mem::take(&mut B_layers[i]);
next_C[t] = std::mem::take(&mut C_layers[i]);
}
A_layers = next_A;
B_layers = next_B;
C_layers = next_C;

for matrix_layer in [&mut A_layers, &mut B_layers, &mut C_layers] {
let (low, high) = matrix_layer.split_at_mut(pairs);
low.iter_mut().zip(high.iter()).for_each(|(lo, hi)| {
lo.iter_mut().zip(hi.iter()).for_each(|(l, h)| {
*l += mul_opt(&(*h - *l), &r_b);
});
});
matrix_layer.truncate(pairs);
}

// m becomes ceil(m/2)
m = pairs;
}
Expand Down
Loading