diff --git a/circuits/src/generation/poseidon2_output_bytes.rs b/circuits/src/generation/poseidon2_output_bytes.rs index 6487a3b80..7a1b4970b 100644 --- a/circuits/src/generation/poseidon2_output_bytes.rs +++ b/circuits/src/generation/poseidon2_output_bytes.rs @@ -4,7 +4,8 @@ use super::MIN_TRACE_LENGTH; use crate::poseidon2_output_bytes::columns::Poseidon2OutputBytes; use crate::poseidon2_sponge::columns::Poseidon2Sponge; -fn pad_trace( +#[must_use] +pub fn pad_trace( mut trace: Vec>, ) -> Vec> { trace.resize( diff --git a/circuits/src/poseidon2_output_bytes/columns.rs b/circuits/src/poseidon2_output_bytes/columns.rs index f9db46a10..91f381524 100644 --- a/circuits/src/poseidon2_output_bytes/columns.rs +++ b/circuits/src/poseidon2_output_bytes/columns.rs @@ -17,6 +17,7 @@ pub struct Poseidon2OutputBytes { pub output_addr: F, pub output_fields: [F; FIELDS_COUNT], pub output_bytes: [F; BYTES_COUNT], + pub gap_invs: [F; FIELDS_COUNT], } columns_view_impl!(Poseidon2OutputBytes); @@ -31,6 +32,13 @@ impl From<&Poseidon2Sponge> for Vec> { .try_into() .expect("Must have at least 4 Fields"); let hash_bytes = HashOut::from(output_fields).to_bytes(); + let hash_high_limbs = output_fields.map(|limb| (limb.to_canonical_u64() >> 32) as u32); + let gap_invs = hash_high_limbs.map(|limb| { + F::from_canonical_u32(u32::MAX - limb) + .try_inverse() + .unwrap_or_default() + }); + let output_bytes = hash_bytes .iter() .map(|x| F::from_canonical_u8(*x)) @@ -43,6 +51,7 @@ impl From<&Poseidon2Sponge> for Vec> { output_addr: value.output_addr, output_fields, output_bytes, + gap_invs, }]; } vec![] diff --git a/circuits/src/poseidon2_output_bytes/stark.rs b/circuits/src/poseidon2_output_bytes/stark.rs index 3b6c6033d..2b5499189 100644 --- a/circuits/src/poseidon2_output_bytes/stark.rs +++ b/circuits/src/poseidon2_output_bytes/stark.rs @@ -44,18 +44,28 @@ impl, const D: usize> Stark for Poseidon2Outp ) where FE: FieldExtension, P: PackedField, { + let two_to_eight = P::Scalar::from_canonical_u16(256); let lv: &Poseidon2OutputBytes<_> = vars.get_local_values().into(); is_binary(yield_constr, lv.is_executed); for i in 0..FIELDS_COUNT { let start_index = i * 8; let end_index = i * 8 + 8; yield_constr.constraint( - reduce_with_powers( - &lv.output_bytes[start_index..end_index], - P::Scalar::from_canonical_u16(256), - ) - lv.output_fields[i], + reduce_with_powers(&lv.output_bytes[start_index..end_index], two_to_eight) + - lv.output_fields[i], ); } + + let u32_max: P = P::Scalar::from_canonical_u32(u32::MAX).into(); + let one = P::ONES; + + (0..4).for_each(|i| { + let low_limb = reduce_with_powers(&lv.output_bytes[8 * i..8 * i + 4], two_to_eight); + let high_limb = + reduce_with_powers(&lv.output_bytes[8 * i + 4..8 * i + 8], two_to_eight); + let gap_inv = lv.gap_invs[i]; + yield_constr.constraint(((u32_max - high_limb) * gap_inv - one) * low_limb); + }); } fn eval_ext_circuit( @@ -65,8 +75,8 @@ impl, const D: usize> Stark for Poseidon2Outp yield_constr: &mut RecursiveConstraintConsumer, ) { let lv: &Poseidon2OutputBytes> = vars.get_local_values().into(); - is_binary_ext_circuit(builder, lv.is_executed, yield_constr); let two_to_eight = builder.constant(F::from_canonical_u16(256)); + is_binary_ext_circuit(builder, lv.is_executed, yield_constr); for i in 0..FIELDS_COUNT { let start_index = i * 8; let end_index = i * 8 + 8; @@ -78,6 +88,29 @@ impl, const D: usize> Stark for Poseidon2Outp let x_sub_of = builder.sub_extension(x, lv.output_fields[i]); yield_constr.constraint(builder, x_sub_of); } + + let u32_max = builder.constant_extension(F::from_canonical_u32(u32::MAX).into()); + let one = builder.constant_extension(F::ONE.into()); + + (0..4).for_each(|i| { + let low_limb = reduce_with_powers_ext_circuit( + builder, + &lv.output_bytes[8 * i..8 * i + 4], + two_to_eight, + ); + let high_limb = reduce_with_powers_ext_circuit( + builder, + &lv.output_bytes[8 * i + 4..8 * i + 8], + two_to_eight, + ); + let gap_inv = lv.gap_invs[i]; + let u32_max_sub_high_limb = builder.sub_extension(u32_max, high_limb); + let u32_max_sub_high_limb_times_gap_inv_minus_one = + builder.mul_sub_extension(u32_max_sub_high_limb, gap_inv, one); + let zero = + builder.mul_extension(u32_max_sub_high_limb_times_gap_inv_minus_one, low_limb); + yield_constr.constraint(builder, zero); + }); } fn constraint_degree(&self) -> usize { 3 } @@ -95,13 +128,17 @@ mod tests { use proptest::prelude::ProptestConfig; use proptest::{prop_assert_eq, proptest}; use starky::config::StarkConfig; - use starky::prover::prove; + use starky::prover::{prove, prove as prove_table}; use starky::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use starky::verifier::verify_stark_proof; use super::Poseidon2OutputBytesStark; - use crate::generation::poseidon2_output_bytes::generate_poseidon2_output_bytes_trace; + use crate::generation::poseidon2_output_bytes::{ + generate_poseidon2_output_bytes_trace, pad_trace, + }; use crate::generation::poseidon2_sponge::generate_poseidon2_sponge_trace; + use crate::poseidon2_output_bytes::columns::Poseidon2OutputBytes; + use crate::poseidon2_sponge::columns::Poseidon2Sponge; use crate::stark::utils::trace_rows_to_poly_values; use crate::test_utils::{create_poseidon2_test, Poseidon2Test}; @@ -186,4 +223,50 @@ mod tests { Ok(()) } + + proptest! { + /// Poseidon2OutputBytes stark with output bytes corresponding to + /// non canonical form of hash (with a limb >= goldilocks prime) + /// should fail + #[test] + #[cfg_attr(debug_assertions, should_panic = "Constraint failed in")] + fn non_canonical_hash(value in 0..u32::MAX) { + fn malicious_trace(value: u32) -> Vec> { + let output = [F::from_canonical_u32(value); 12]; + let sponge = Poseidon2Sponge:: { + output, + gen_output: F::ONE, + ..Default::default() + }; + let mut malicious_trace: Vec> = (&sponge).into(); + // add goldilocks prime to first limb + let u8_max = F::from_canonical_u8(u8::MAX); + (4..8).for_each(|i| malicious_trace[0].output_bytes[i] += u8_max); + malicious_trace[0].output_bytes[0] += F::ONE; + + // test that field elements still correspond to malicious bytes + let two_to_eight = F::from_canonical_u16(256); + let output_fields = [0, 1, 2, 3].map(|i| { + reduce_with_powers( + &malicious_trace[0].output_bytes[8 * i..8 * i + 8], + two_to_eight, + ) + }); + assert_eq!(output_fields, malicious_trace[0].output_fields); + pad_trace(malicious_trace) + } + + let trace = malicious_trace(value); + let config = StarkConfig::standard_fast_config(); + let stark = S::default(); + let trace_poly_values = trace_rows_to_poly_values(trace); + + let _proof = prove_table::( + stark, + &config, + trace_poly_values, + &[], + &mut TimingTree::default(), + ); + }} }