Skip to content

Commit d47a275

Browse files
committed
fix four issues: replace assert with Err; prevent overflow; prevent apnics; verifier validates proof mask counts
1 parent e0030fb commit d47a275

4 files changed

Lines changed: 97 additions & 29 deletions

File tree

crates/provers/gkr-logup/examples/range_check.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,16 @@ fn range_check(values: &[u64], n_bits: u32, label: &str) -> bool {
5656

5757
// --- Batch prove ---
5858
let mut prover_ch = DefaultTranscript::<F>::new(&[]);
59-
let (proof, _) = prove_batch(&mut prover_ch, vec![access_layer, table_layer]).unwrap();
59+
let (proof, artifact) = prove_batch(&mut prover_ch, vec![access_layer, table_layer]).unwrap();
6060

6161
// --- Batch verify ---
6262
let mut verifier_ch = DefaultTranscript::<F>::new(&[]);
63-
let gkr_result = verify_batch(&[Gate::LogUp, Gate::LogUp], &proof, &mut verifier_ch);
63+
let gkr_result = verify_batch(
64+
&[Gate::LogUp, Gate::LogUp],
65+
&artifact.n_variables_by_instance,
66+
&proof,
67+
&mut verifier_ch,
68+
);
6469

6570
if let Err(e) = &gkr_result {
6671
println!(" GKR verification: FAILED ({})", e);

crates/provers/gkr-logup/examples/read_only_memory.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ fn main() {
7777
// -------------------------------------------------------
7878
println!("--- Proving ---");
7979
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
80-
let (proof, _artifact) =
80+
let (proof, artifact) =
8181
prove_batch(&mut prover_transcript, vec![access_layer, table_layer]).unwrap();
8282

8383
println!(
@@ -94,17 +94,18 @@ fn main() {
9494
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
9595
let result = verify_batch(
9696
&[Gate::LogUp, Gate::LogUp],
97+
&artifact.n_variables_by_instance,
9798
&proof,
9899
&mut verifier_transcript,
99100
);
100101

101102
match &result {
102-
Ok(artifact) => {
103+
Ok(gkr_result) => {
103104
println!("GKR verification: PASSED");
104-
println!(" OOD point length: {}", artifact.ood_point.len());
105+
println!(" OOD point length: {}", gkr_result.ood_point.len());
105106
println!(
106107
" Variables by instance: {:?}",
107-
artifact.n_variables_by_instance
108+
gkr_result.n_variables_by_instance
108109
);
109110
}
110111
Err(e) => {
@@ -156,12 +157,13 @@ fn main() {
156157
};
157158

158159
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
159-
let (bad_proof, _) =
160+
let (bad_proof, bad_artifact) =
160161
prove_batch(&mut prover_transcript, vec![bad_access_layer, table_layer2]).unwrap();
161162

162163
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
163164
let bad_result = verify_batch(
164165
&[Gate::LogUp, Gate::LogUp],
166+
&bad_artifact.n_variables_by_instance,
165167
&bad_proof,
166168
&mut verifier_transcript,
167169
);

crates/provers/gkr-logup/src/prover.rs

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,17 @@ fn correct_sum_as_poly_in_first_variable<F: IsField>(
271271
y: &[FieldElement<F>],
272272
k: usize,
273273
) -> Result<Polynomial<FieldElement<F>>, ProverError> {
274-
assert!(k > 0);
274+
if k == 0 {
275+
return Err(ProverError::InvalidState(
276+
"correct_sum_as_poly_in_first_variable: k must be > 0".to_string(),
277+
));
278+
}
275279
let n = y.len();
276-
assert!(k <= n);
280+
if k > n {
281+
return Err(ProverError::InvalidState(format!(
282+
"correct_sum_as_poly_in_first_variable: k ({k}) > y.len() ({n})"
283+
)));
284+
}
277285

278286
// a_const = 1 / eq_eval((0^(n-k+1)), y[..n-k+1])
279287
// eq_eval(0, y_i) = 1 - y_i, so eq_eval(0^m, y) = prod(1 - y_i).
@@ -498,7 +506,9 @@ where
498506
for (claim, oracle) in claims.iter_mut().zip(oracles.iter()) {
499507
let n_unused = n_variables - oracle.n_variables();
500508
if n_unused > 0 {
501-
*claim = &*claim * &FieldElement::<F>::from(1u64 << n_unused);
509+
let doubling_factor =
510+
(0..n_unused).fold(FieldElement::<F>::one(), |acc, _| &acc + &acc);
511+
*claim = &*claim * &doubling_factor;
502512
}
503513
}
504514

@@ -1116,11 +1126,16 @@ mod tests {
11161126
input_layers: Vec<Layer<F>>,
11171127
) -> verifier::BatchVerificationResult<F> {
11181128
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1119-
let (proof, _artifact) = prove_batch(&mut prover_transcript, input_layers).unwrap();
1129+
let (proof, artifact) = prove_batch(&mut prover_transcript, input_layers).unwrap();
11201130

11211131
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
1122-
verifier::verify_batch(&gates, &proof, &mut verifier_transcript)
1123-
.expect("batch verification should succeed")
1132+
verifier::verify_batch(
1133+
&gates,
1134+
&artifact.n_variables_by_instance,
1135+
&proof,
1136+
&mut verifier_transcript,
1137+
)
1138+
.expect("batch verification should succeed")
11241139
}
11251140

11261141
#[test]
@@ -1245,7 +1260,7 @@ mod tests {
12451260
let values1: Vec<FE> = (11u64..=18).map(FE::from).collect();
12461261

12471262
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1248-
let (mut proof, _) = prove_batch(
1263+
let (mut proof, artifact) = prove_batch(
12491264
&mut prover_transcript,
12501265
vec![
12511266
Layer::GrandProduct(DenseMultilinearPolynomial::new(values0)),
@@ -1259,6 +1274,7 @@ mod tests {
12591274
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
12601275
let result = verifier::verify_batch(
12611276
&[Gate::GrandProduct, Gate::GrandProduct],
1277+
&artifact.n_variables_by_instance,
12621278
&proof,
12631279
&mut verifier_transcript,
12641280
);
@@ -1271,7 +1287,7 @@ mod tests {
12711287
let values1: Vec<FE> = (11u64..=18).map(FE::from).collect();
12721288

12731289
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1274-
let (mut proof, _) = prove_batch(
1290+
let (mut proof, artifact) = prove_batch(
12751291
&mut prover_transcript,
12761292
vec![
12771293
Layer::GrandProduct(DenseMultilinearPolynomial::new(values0)),
@@ -1290,6 +1306,7 @@ mod tests {
12901306
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
12911307
let result = verifier::verify_batch(
12921308
&[Gate::GrandProduct, Gate::GrandProduct],
1309+
&artifact.n_variables_by_instance,
12931310
&proof,
12941311
&mut verifier_transcript,
12951312
);
@@ -1302,7 +1319,7 @@ mod tests {
13021319
let values1: Vec<FE> = (11u64..=14).map(FE::from).collect();
13031320

13041321
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1305-
let (mut proof, _) = prove_batch(
1322+
let (mut proof, artifact) = prove_batch(
13061323
&mut prover_transcript,
13071324
vec![
13081325
Layer::GrandProduct(DenseMultilinearPolynomial::new(values0)),
@@ -1316,6 +1333,7 @@ mod tests {
13161333
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
13171334
let result = verifier::verify_batch(
13181335
&[Gate::GrandProduct, Gate::GrandProduct],
1336+
&artifact.n_variables_by_instance,
13191337
&proof,
13201338
&mut verifier_transcript,
13211339
);
@@ -1365,13 +1383,14 @@ mod tests {
13651383

13661384
// Batch prove both instances
13671385
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1368-
let (proof, _artifact) =
1386+
let (proof, artifact) =
13691387
prove_batch(&mut prover_transcript, vec![access_layer, table_layer]).unwrap();
13701388

13711389
// Batch verify
13721390
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
13731391
let result = verifier::verify_batch(
13741392
&[Gate::LogUp, Gate::LogUp],
1393+
&artifact.n_variables_by_instance,
13751394
&proof,
13761395
&mut verifier_transcript,
13771396
);
@@ -1413,13 +1432,14 @@ mod tests {
14131432
};
14141433

14151434
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1416-
let (proof, _artifact) =
1435+
let (proof, artifact) =
14171436
prove_batch(&mut prover_transcript, vec![access_layer, table_layer]).unwrap();
14181437

14191438
// GKR itself verifies fine (each instance is internally consistent)
14201439
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
14211440
let result = verifier::verify_batch(
14221441
&[Gate::LogUp, Gate::LogUp],
1442+
&artifact.n_variables_by_instance,
14231443
&proof,
14241444
&mut verifier_transcript,
14251445
);
@@ -1467,8 +1487,12 @@ mod tests {
14671487
assert_eq!(artifact.claims_to_verify_by_instance[0], vec![FE::from(42)]);
14681488

14691489
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
1470-
let result =
1471-
verifier::verify_batch(&[Gate::GrandProduct], &proof, &mut verifier_transcript);
1490+
let result = verifier::verify_batch(
1491+
&[Gate::GrandProduct],
1492+
&artifact.n_variables_by_instance,
1493+
&proof,
1494+
&mut verifier_transcript,
1495+
);
14721496
assert!(
14731497
result.is_ok(),
14741498
"batch verification of size-1 instance should succeed"
@@ -1527,10 +1551,15 @@ mod tests {
15271551
let single = Layer::GrandProduct(DenseMultilinearPolynomial::new(vec![FE::from(42)]));
15281552

15291553
let mut prover_transcript = DefaultTranscript::<F>::new(&[]);
1530-
let (proof, _) = prove_batch(&mut prover_transcript, vec![single]).unwrap();
1554+
let (proof, artifact) = prove_batch(&mut prover_transcript, vec![single]).unwrap();
15311555

15321556
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
1533-
let result = verifier::verify_batch(&[Gate::LogUp], &proof, &mut verifier_transcript);
1557+
let result = verifier::verify_batch(
1558+
&[Gate::LogUp],
1559+
&artifact.n_variables_by_instance,
1560+
&proof,
1561+
&mut verifier_transcript,
1562+
);
15341563
assert!(
15351564
result.is_err(),
15361565
"wrong gate on 0-layer batch instance should be rejected"
@@ -1552,7 +1581,12 @@ mod tests {
15521581
assert!(artifact.claims_to_verify_by_instance.is_empty());
15531582

15541583
let mut verifier_transcript = DefaultTranscript::<F>::new(&[]);
1555-
let result = verifier::verify_batch(&[], &proof, &mut verifier_transcript);
1584+
let result = verifier::verify_batch(
1585+
&[],
1586+
&artifact.n_variables_by_instance,
1587+
&proof,
1588+
&mut verifier_transcript,
1589+
);
15561590
assert!(result.is_ok(), "empty batch should verify successfully");
15571591

15581592
let vr = result.unwrap();

crates/provers/gkr-logup/src/verifier.rs

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ pub struct BatchVerificationResult<F: IsField> {
274274
/// layer MLE evaluations.
275275
pub fn verify_batch<F, T>(
276276
gates: &[Gate],
277+
n_variables_by_instance: &[usize],
277278
proof: &BatchProof<F>,
278279
transcript: &mut T,
279280
) -> Result<BatchVerificationResult<F>, VerifierError<F>>
@@ -297,11 +298,21 @@ where
297298
return Err(VerifierError::MalformedProof);
298299
}
299300

301+
// HIGH-003: validate trusted statement matches proof structure
302+
if n_variables_by_instance.len() != n_instances {
303+
return Err(VerifierError::MalformedProof);
304+
}
305+
for instance in 0..n_instances {
306+
if layer_masks_by_instance[instance].len() != n_variables_by_instance[instance] {
307+
return Err(VerifierError::MalformedProof);
308+
}
309+
}
310+
300311
// Domain separation: must match prover (see prove_batch).
301312
transcript.append_bytes(b"gkr_batch");
302313
transcript.append_bytes(&(n_instances as u64).to_le_bytes());
303314

304-
let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len();
315+
let instance_n_layers = |instance: usize| n_variables_by_instance[instance];
305316
let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap_or(0);
306317

307318
if n_layers != sumcheck_proofs.len() {
@@ -362,7 +373,8 @@ where
362373
continue;
363374
}
364375
let n_unused = n_layers - instance_n_layers(instance);
365-
let doubling_factor = FieldElement::<F>::from(1u64 << n_unused);
376+
let doubling_factor =
377+
(0..n_unused).fold(FieldElement::<F>::one(), |acc, _| &acc + &acc);
366378
let claim = &random_linear_combination(claims, &lambda) * &doubling_factor;
367379
sumcheck_claims.push(claim);
368380
sumcheck_instances.push(instance);
@@ -383,7 +395,12 @@ where
383395
let mut layer_evals = Vec::new();
384396
for &instance in &sumcheck_instances {
385397
let n_unused = n_layers - instance_n_layers(instance);
386-
let mask = &layer_masks_by_instance[instance][layer - n_unused];
398+
if layer < n_unused {
399+
return Err(VerifierError::MalformedProof);
400+
}
401+
let mask = layer_masks_by_instance[instance]
402+
.get(layer - n_unused)
403+
.ok_or(VerifierError::MalformedProof)?;
387404
let gate_output = gates[instance].eval(mask)?;
388405

389406
// eq evaluation uses the relevant suffix of the OOD point.
@@ -416,7 +433,12 @@ where
416433
// Seed transcript with masks (same order as prover).
417434
for &instance in &sumcheck_instances {
418435
let n_unused = n_layers - instance_n_layers(instance);
419-
let mask = &layer_masks_by_instance[instance][layer - n_unused];
436+
if layer < n_unused {
437+
return Err(VerifierError::MalformedProof);
438+
}
439+
let mask = layer_masks_by_instance[instance]
440+
.get(layer - n_unused)
441+
.ok_or(VerifierError::MalformedProof)?;
420442
for col in mask.columns() {
421443
transcript.append_field_element(&col[0]);
422444
transcript.append_field_element(&col[1]);
@@ -427,7 +449,12 @@ where
427449
let challenge: FieldElement<F> = transcript.sample_field_element();
428450
for instance in sumcheck_instances {
429451
let n_unused = n_layers - instance_n_layers(instance);
430-
let mask = &layer_masks_by_instance[instance][layer - n_unused];
452+
if layer < n_unused {
453+
return Err(VerifierError::MalformedProof);
454+
}
455+
let mask = layer_masks_by_instance[instance]
456+
.get(layer - n_unused)
457+
.ok_or(VerifierError::MalformedProof)?;
431458
claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(&challenge));
432459
}
433460

@@ -444,6 +471,6 @@ where
444471
Ok(BatchVerificationResult {
445472
ood_point,
446473
claims_to_verify_by_instance,
447-
n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(),
474+
n_variables_by_instance: n_variables_by_instance.to_vec(),
448475
})
449476
}

0 commit comments

Comments
 (0)