diff --git a/Cargo.toml b/Cargo.toml
index 182f98a..22454f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,6 +36,7 @@ once_cell = "1.18.0"
 tracing = "0.1"
 tracing-subscriber = { version = "0.3.18", features = ["env-filter", "time"] }
 tikv-jemallocator = { version = "0.6.0", optional = true }
+clap = { version = "4", features = ["derive"] }
 
 [target.wasm32-unknown-unknown.dependencies]
 # see https://github.com/rust-random/rand/pull/948
@@ -56,6 +57,6 @@ jem = ["tikv-jemallocator"]
 [profile.release]
 debug = 1
 lto = "fat"
-
+
 [target.'cfg(target_arch = "x86_64")'.dependencies]
 halo2curves = { version = "0.9.0", features = ["derive_serde", "std", "asm"] }
diff --git a/examples/accum_bench.rs b/examples/accum_bench.rs
new file mode 100644
index 0000000..ff83ceb
--- /dev/null
+++ b/examples/accum_bench.rs
@@ -0,0 +1,520 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: MIT
+// This file is part of the Spartan2 project.
+// See the LICENSE file in the project root for full license information.
+// Source repository: https://github.com/Microsoft/Spartan2
+
+//! examples/accum_bench.rs
+//!
+//! Benchmark for Lagrange accumulator building and l0 sumcheck rounds.
+//!
+//! Measures:
+//! - Accumulator generation time (`build_accumulators_spartan`)
+//! - L0 sumcheck rounds time (using precomputed accumulators)
+//! - Total time (accum + l0)
+//!
+//! Compares DelayedModularReductionEnabled vs DelayedModularReductionDisabled.
+//!
+//! Run with:
+//! cargo run --release --example accum_bench -- single 22
+//! cargo run --release --example accum_bench -- --delayed-modular-reduction both --l0 5 single 22
+//! cargo run --release --example accum_bench -- --show accum-only range-sweep --min 16 --max 24
+//! cargo run --release --example accum_bench -- --show l0-only range-sweep --min 16 --max 24
+//!
cargo run --release --example accum_bench -- --show total-only range-sweep --min 16 --max 24 + +#[cfg(feature = "jem")] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +use clap::{Parser, Subcommand, ValueEnum}; +use ff::Field; +use spartan2::{ + lagrange_accumulator::{ + DelayedModularReductionDisabled, DelayedModularReductionEnabled, DelayedModularReductionMode, + MatVecMLE, SPARTAN_T_DEGREE, build_accumulators_spartan, derive_t1, + }, + polys::{multilinear::MultilinearPolynomial, univariate::UniPoly}, + provider::{Bn254Engine, PallasHyraxEngine, VestaHyraxEngine}, + small_field::{DelayedReduction, SmallValueField}, + sumcheck::lagrange_sumcheck::SmallValueSumCheck, + traits::Engine, +}; +use std::{io::Write, time::Instant}; +use tracing::info; +use tracing_subscriber::EnvFilter; + +/// Field choice for benchmarks +#[derive(ValueEnum, Clone, Default, Debug)] +enum FieldChoice { + /// Pallas curve scalar field (Fq) + #[default] + PallasFq, + /// Vesta curve scalar field (Fp) + VestaFp, + /// BN254 curve scalar field (Fr) + Bn254Fr, +} + +/// Witness type for benchmarks +#[derive(ValueEnum, Clone, Default, Debug)] +enum WitnessType { + /// i32 witness coefficients + I32, + /// i64 witness coefficients + #[default] + I64, + /// Full field element coefficients + Field, +} + +/// Delayed modular reduction mode selection for benchmarks +#[derive(ValueEnum, Clone, Default, Debug)] +enum DelayedModularReductionChoice { + /// Delayed modular reduction enabled + Enabled, + /// Delayed modular reduction disabled (immediate reduction) + Disabled, + /// Run both and compare + #[default] + Both, +} + +/// Which timings to show in output +#[derive(ValueEnum, Clone, Default, Debug)] +enum ShowTimings { + /// Show only accumulator build time + AccumOnly, + /// Show only l0 sumcheck rounds time + L0Only, + /// Show only total time (accum + l0) + TotalOnly, + /// Show all timings (accum, l0, and total) + #[default] + All, +} + +#[derive(Parser)] +#[command(about = "Accumulator benchmark with delayed modular reduction toggle")] +struct Args { + /// Field to use for benchmarks + #[arg(long, value_enum, default_value = "pallas-fq")] + field: FieldChoice, + + /// Witness type (i32, i64, or field) + #[arg(long, value_enum, default_value = "i64")] + witness: WitnessType, + + /// Number of small-value rounds (default: num_vars / 4) + #[arg(long)] + l0: Option, + + /// Delayed modular reduction mode (enabled, disabled, or both) + #[arg(long, value_enum, default_value = "both")] + delayed_modular_reduction: DelayedModularReductionChoice, + + /// Which timings to show (accum-only, l0-only, total-only, or all) + #[arg(long, value_enum, default_value = "all")] + show: ShowTimings, + + /// Number of trials per num_vars + #[arg(long, default_value = "1")] + trials: usize, + + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand)] +enum Command { + /// Run a single level (for profiling) + Single { vars: usize }, + /// Run a range sweep + RangeSweep { + #[arg(long, default_value = "16")] + min: usize, + #[arg(long, default_value = "24")] + max: usize, + }, +} + +/// Benchmark result for one configuration +#[derive(Clone, Copy, Default)] +struct BenchResult { + accum_us: u128, + l0_rounds_us: u128, +} + +impl BenchResult { + fn total_us(&self) -> u128 { + self.accum_us + self.l0_rounds_us + } +} + +/// Result of a single trial +struct TrialResult { + num_vars: usize, + n: usize, + l0: usize, + trial: usize, + delayed_modular_reduction_enabled: Option, + 
delayed_modular_reduction_disabled: Option, +} + +fn build_csv_header( + delayed_modular_reduction: &DelayedModularReductionChoice, + show: &ShowTimings, +) -> String { + let mut cols = vec!["num_vars", "n", "l0", "trial"]; + + match delayed_modular_reduction { + DelayedModularReductionChoice::Enabled => match show { + ShowTimings::AccumOnly => cols.push("dmr_accum_us"), + ShowTimings::L0Only => cols.push("dmr_l0_us"), + ShowTimings::TotalOnly => cols.push("dmr_total_us"), + ShowTimings::All => cols.extend(["dmr_accum_us", "dmr_l0_us", "dmr_total_us"]), + }, + DelayedModularReductionChoice::Disabled => match show { + ShowTimings::AccumOnly => cols.push("no_dmr_accum_us"), + ShowTimings::L0Only => cols.push("no_dmr_l0_us"), + ShowTimings::TotalOnly => cols.push("no_dmr_total_us"), + ShowTimings::All => cols.extend(["no_dmr_accum_us", "no_dmr_l0_us", "no_dmr_total_us"]), + }, + DelayedModularReductionChoice::Both => match show { + ShowTimings::AccumOnly => { + cols.extend(["dmr_accum_us", "no_dmr_accum_us", "accum_speedup"]); + } + ShowTimings::L0Only => { + cols.extend(["dmr_l0_us", "no_dmr_l0_us", "l0_speedup"]); + } + ShowTimings::TotalOnly => { + cols.extend(["dmr_total_us", "no_dmr_total_us", "total_speedup"]); + } + ShowTimings::All => { + cols.extend([ + "dmr_accum_us", + "dmr_l0_us", + "dmr_total_us", + "no_dmr_accum_us", + "no_dmr_l0_us", + "no_dmr_total_us", + "accum_speedup", + "l0_speedup", + "total_speedup", + ]); + } + }, + } + cols.join(",") +} + +fn format_csv_row( + result: &TrialResult, + delayed_modular_reduction: &DelayedModularReductionChoice, + show: &ShowTimings, +) -> String { + let mut row = format!( + "{},{},{},{}", + result.num_vars, result.n, result.l0, result.trial + ); + + match delayed_modular_reduction { + DelayedModularReductionChoice::Enabled => { + if let Some(r) = result.delayed_modular_reduction_enabled { + match show { + ShowTimings::AccumOnly => row.push_str(&format!(",{}", r.accum_us)), + ShowTimings::L0Only => row.push_str(&format!(",{}", r.l0_rounds_us)), + ShowTimings::TotalOnly => row.push_str(&format!(",{}", r.total_us())), + ShowTimings::All => row.push_str(&format!( + ",{},{},{}", + r.accum_us, + r.l0_rounds_us, + r.total_us() + )), + } + } + } + DelayedModularReductionChoice::Disabled => { + if let Some(r) = result.delayed_modular_reduction_disabled { + match show { + ShowTimings::AccumOnly => row.push_str(&format!(",{}", r.accum_us)), + ShowTimings::L0Only => row.push_str(&format!(",{}", r.l0_rounds_us)), + ShowTimings::TotalOnly => row.push_str(&format!(",{}", r.total_us())), + ShowTimings::All => row.push_str(&format!( + ",{},{},{}", + r.accum_us, + r.l0_rounds_us, + r.total_us() + )), + } + } + } + DelayedModularReductionChoice::Both => { + if let (Some(enabled_r), Some(disabled_r)) = ( + result.delayed_modular_reduction_enabled, + result.delayed_modular_reduction_disabled, + ) { + match show { + ShowTimings::AccumOnly => { + let speedup = disabled_r.accum_us as f64 / enabled_r.accum_us as f64; + row.push_str(&format!( + ",{},{},{:.2}", + enabled_r.accum_us, disabled_r.accum_us, speedup + )); + } + ShowTimings::L0Only => { + let speedup = disabled_r.l0_rounds_us as f64 / enabled_r.l0_rounds_us as f64; + row.push_str(&format!( + ",{},{},{:.2}", + enabled_r.l0_rounds_us, disabled_r.l0_rounds_us, speedup + )); + } + ShowTimings::TotalOnly => { + let speedup = disabled_r.total_us() as f64 / enabled_r.total_us() as f64; + row.push_str(&format!( + ",{},{},{:.2}", + enabled_r.total_us(), + disabled_r.total_us(), + speedup + )); + } + 
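+          // Illustrative row shape for `--delayed-modular-reduction both --show total-only`
+          // (the numbers below are made up, not measured):
+          //   num_vars,n,l0,trial,dmr_total_us,no_dmr_total_us,total_speedup
+          //   22,4194304,5,0,184231,301502,1.64
+          // total_speedup is no_dmr_total_us / dmr_total_us, so values above 1.0
+          // mean the delayed-reduction path is faster.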
ShowTimings::All => { + let accum_speedup = disabled_r.accum_us as f64 / enabled_r.accum_us as f64; + let l0_speedup = disabled_r.l0_rounds_us as f64 / enabled_r.l0_rounds_us as f64; + let total_speedup = disabled_r.total_us() as f64 / enabled_r.total_us() as f64; + row.push_str(&format!( + ",{},{},{},{},{},{},{:.2},{:.2},{:.2}", + enabled_r.accum_us, + enabled_r.l0_rounds_us, + enabled_r.total_us(), + disabled_r.accum_us, + disabled_r.l0_rounds_us, + disabled_r.total_us(), + accum_speedup, + l0_speedup, + total_speedup + )); + } + } + } + } + } + row +} + +/// Generate satisfying witness polynomials (Az * Bz = Cz on boolean hypercube) +fn generate_witness_i64>( + num_vars: usize, +) -> ( + MultilinearPolynomial, + MultilinearPolynomial, + Vec, +) { + let n = 1usize << num_vars; + let az_vals: Vec = (0..n).map(|i| ((i % 1000) + 1) as i64).collect(); + let bz_vals: Vec = (0..n).map(|i| (((i * 7) % 1000) + 1) as i64).collect(); + let taus: Vec = (0..num_vars).map(|i| F::from((i * 7 + 3) as u64)).collect(); + + ( + MultilinearPolynomial::new(az_vals), + MultilinearPolynomial::new(bz_vals), + taus, + ) +} + +/// Run accumulator benchmark with specified delayed modular reduction mode +fn run_accumulator_bench(az: &P, bz: &P, taus: &[E::Scalar], l0: usize) -> BenchResult +where + E: Engine, + E::Scalar: SmallValueField + DelayedReduction, + P: MatVecMLE, + Mode: DelayedModularReductionMode, +{ + // Phase 1: Measure accumulator build time + let start = Instant::now(); + let accumulators = build_accumulators_spartan::<_, _, Mode>(az, bz, taus, l0); + let accum_us = start.elapsed().as_micros(); + + // Phase 2: Measure l0 sumcheck rounds time + let mut small_value = + SmallValueSumCheck::::from_accumulators(accumulators); + + // Initial claim = 0 for satisfying witness (Az * Bz - Cz = 0) + let mut claim_per_round = E::Scalar::ZERO; + + let start = Instant::now(); + + for round in 0..l0 { + // 1. Get t evaluations from accumulators + let t_all = small_value.eval_t_all_u(round); + let t_inf = t_all.at_infinity(); + let t0 = t_all.at_zero(); + + // 2. Get eq round values + let li = small_value.eq_round_values(taus[round]); + + // 3. Derive t1 from sumcheck constraint + let t1 = derive_t1(li.at_zero(), li.at_one(), claim_per_round, t0) + .expect("l1 non-zero for valid witness"); + + // 4. Build univariate polynomial + // s_i(X) = ℓ_i(X) * t_i(X) where ℓ_i(X) = ℓ_∞·X + ℓ_0 and t_i(X) is degree-2 + // We compute evaluations at 0, 1, 2, 3 and interpolate + let l0_val = li.at_zero(); + let linf = li.at_infinity(); + + // t_i(X) = t_inf * X^2 + b*X + t0 where b = t1 - t_inf - t0 + let b = t1 - t_inf - t0; + + // Evaluate s(X) = ℓ(X) * t(X) at X = 0, 1, 2, 3 + // ℓ(X) = linf * X + l0_val + // t(X) = t_inf * X^2 + b * X + t0 + let eval_s = |x: u64| -> E::Scalar { + let x_f = E::Scalar::from(x); + let l_x = linf * x_f + l0_val; + let t_x = t_inf * x_f * x_f + b * x_f + t0; + l_x * t_x + }; + + let evals = vec![eval_s(0), eval_s(1), eval_s(2), eval_s(3)]; + let poly = UniPoly::from_evals(&evals).expect("valid polynomial"); + + // 5. Simulate verifier challenge (deterministic for benchmark) + let r_i = E::Scalar::from((round + 7) as u64); + + // 6. 
Advance to next round + small_value.advance(&li, r_i); + claim_per_round = poly.evaluate(&r_i); + } + + let l0_rounds_us = start.elapsed().as_micros(); + + BenchResult { + accum_us, + l0_rounds_us, + } +} + +/// Run benchmark for a single num_vars configuration +fn run_single( + num_vars: usize, + l0: usize, + delayed_modular_reduction: &DelayedModularReductionChoice, +) -> TrialResult +where + E: Engine, + E::Scalar: SmallValueField + DelayedReduction, +{ + let (az, bz, taus) = generate_witness_i64::(num_vars); + let n = 1usize << num_vars; + + let delayed_modular_reduction_enabled = match delayed_modular_reduction { + DelayedModularReductionChoice::Enabled | DelayedModularReductionChoice::Both => { + Some(run_accumulator_bench::< + E, + _, + DelayedModularReductionEnabled, + >(&az, &bz, &taus, l0)) + } + DelayedModularReductionChoice::Disabled => None, + }; + + let delayed_modular_reduction_disabled = match delayed_modular_reduction { + DelayedModularReductionChoice::Disabled | DelayedModularReductionChoice::Both => { + Some(run_accumulator_bench::(&az, &bz, &taus, l0)) + } + DelayedModularReductionChoice::Enabled => None, + }; + + TrialResult { + num_vars, + n, + l0, + trial: 0, + delayed_modular_reduction_enabled, + delayed_modular_reduction_disabled, + } +} + +fn run_benchmark(args: &Args) +where + E: Engine, + E::Scalar: SmallValueField + DelayedReduction, +{ + let (min_vars, max_vars) = match &args.command { + Some(Command::Single { vars }) => (*vars, *vars), + Some(Command::RangeSweep { min, max }) => (*min, *max), + None => (16, 24), + }; + + // Print CSV header + println!( + "{}", + build_csv_header(&args.delayed_modular_reduction, &args.show) + ); + + for num_vars in min_vars..=max_vars { + let l0 = args.l0.unwrap_or(num_vars / 4).max(1).min(num_vars - 1); + + // Warmup + if num_vars == min_vars { + info!("Warmup run for num_vars={}", num_vars); + let _ = run_single::(num_vars, l0, &args.delayed_modular_reduction); + } + + for trial in 0..args.trials { + let mut result = run_single::(num_vars, l0, &args.delayed_modular_reduction); + result.trial = trial; + println!( + "{}", + format_csv_row(&result, &args.delayed_modular_reduction, &args.show) + ); + std::io::stdout().flush().unwrap(); + + // Log progress + if let (Some(enabled), Some(disabled)) = ( + result.delayed_modular_reduction_enabled, + result.delayed_modular_reduction_disabled, + ) { + let accum_speedup = disabled.accum_us as f64 / enabled.accum_us as f64; + let l0_speedup = disabled.l0_rounds_us as f64 / enabled.l0_rounds_us as f64; + let total_speedup = disabled.total_us() as f64 / enabled.total_us() as f64; + info!( + "num_vars={} l0={}: accum({:.0}μs vs {:.0}μs, {:.2}x), l0({:.0}μs vs {:.0}μs, {:.2}x), total({:.0}μs vs {:.0}μs, {:.2}x)", + num_vars, + l0, + enabled.accum_us, + disabled.accum_us, + accum_speedup, + enabled.l0_rounds_us, + disabled.l0_rounds_us, + l0_speedup, + enabled.total_us(), + disabled.total_us(), + total_speedup + ); + } + } + } +} + +fn main() { + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_target(false) + .init(); + + let args = Args::parse(); + + info!("Accumulator benchmark"); + info!( + "Field: {:?}, Witness: {:?}, delayed_modular_reduction: {:?}, show: {:?}", + args.field, args.witness, args.delayed_modular_reduction, args.show + ); + + match args.field { + FieldChoice::PallasFq => run_benchmark::(&args), + FieldChoice::VestaFp => run_benchmark::(&args), + FieldChoice::Bn254Fr => run_benchmark::(&args), + } +} diff --git 
a/examples/circuits/mod.rs b/examples/circuits/mod.rs new file mode 100644 index 0000000..ab27afd --- /dev/null +++ b/examples/circuits/mod.rs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Reusable circuits for examples and benchmarks. + +pub mod sha256; +pub use sha256::*; diff --git a/examples/circuits/sha256/bellpepper.rs b/examples/circuits/sha256/bellpepper.rs new file mode 100644 index 0000000..310920a --- /dev/null +++ b/examples/circuits/sha256/bellpepper.rs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! SHA-256 circuit using bellpepper's implementation (baseline, not small-value compatible). + +use super::{ + alloc_preimage_bits, assert_hash_matches, expose_hash_bits_as_public, hash_to_public_scalars, +}; +use bellpepper::gadgets::sha256::sha256 as bellpepper_sha256; +use bellpepper_core::{Circuit, ConstraintSystem, SynthesisError, num::AllocatedNum}; +use ff::{PrimeField, PrimeFieldBits}; +use spartan2::traits::{Engine, circuit::SpartanCircuit}; +use std::marker::PhantomData; + +/// SHA-256 circuit using bellpepper's implementation. +/// +/// This is the baseline circuit for comparison. It produces coefficients ~2^237 +/// which breaks small-value optimization. +#[derive(Clone, Debug)] +pub struct Sha256Circuit { + pub preimage: Vec, + _p: PhantomData, +} + +impl Sha256Circuit { + pub fn new(preimage: Vec) -> Self { + Self { + preimage, + _p: PhantomData, + } + } +} + +impl SpartanCircuit for Sha256Circuit +where + E::Scalar: PrimeFieldBits, +{ + fn public_values(&self) -> Result, SynthesisError> { + Ok(hash_to_public_scalars(&self.preimage)) + } + + fn shared>( + &self, + _: &mut CS, + ) -> Result>, SynthesisError> { + Ok(vec![]) + } + + fn precommitted>( + &self, + cs: &mut CS, + _: &[AllocatedNum], + ) -> Result>, SynthesisError> { + // Allocate preimage bits (little-endian for bellpepper) + let preimage_bits = alloc_preimage_bits::(cs, &self.preimage, false)?; + + // SHA-256 gadget + let hash_bits = bellpepper_sha256(cs.namespace(|| "sha256"), &preimage_bits)?; + + // Verify against native SHA-256 + assert_hash_matches(&hash_bits, &self.preimage); + + // Expose as public inputs + expose_hash_bits_as_public::(cs, &hash_bits)?; + + Ok(vec![]) + } + + fn num_challenges(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + _: &mut CS, + _: &[AllocatedNum], + _: &[AllocatedNum], + _: Option<&[E::Scalar]>, + ) -> Result<(), SynthesisError> { + Ok(()) + } +} + +impl Circuit for Sha256Circuit { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + let preimage_bits = alloc_preimage_bits(cs, &self.preimage, false)?; + let _ = bellpepper_sha256(cs.namespace(|| "sha256"), &preimage_bits)?; + Ok(()) + } +} diff --git a/examples/circuits/sha256/chain.rs b/examples/circuits/sha256/chain.rs new file mode 100644 index 0000000..3b89837 --- /dev/null +++ b/examples/circuits/sha256/chain.rs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. 
+// Source repository: https://github.com/Microsoft/Spartan2 + +//! SHA-256 chain circuit using small_sha256 (small-value compatible). + +use super::{ + alloc_preimage_bits, assert_bits_match_bytes, expose_hash_bits_as_public, hash_to_public_scalars, +}; +use bellpepper_core::{Circuit, ConstraintSystem, SynthesisError, num::AllocatedNum}; +use ff::{PrimeField, PrimeFieldBits}; +use sha2::{Digest, Sha256}; +use spartan2::{ + gadgets::small_sha256_with_prefix, + traits::{Engine, circuit::SpartanCircuit}, +}; +use std::marker::PhantomData; + +/// SHA-256 chain circuit using small_sha256 (small-value compatible). +/// +/// Chains `chain_length` SHA-256 hashes starting from a 256-bit input. +/// Hash[0] = SHA-256(input), Hash[i] = SHA-256(Hash[i-1]) +#[derive(Debug, Clone)] +pub struct SmallSha256ChainCircuit { + /// 32-byte (256-bit) input to start the chain + pub input: [u8; 32], + /// Number of SHA-256 hashes in the chain + pub chain_length: usize, + _p: PhantomData, +} + +impl SmallSha256ChainCircuit { + pub fn new(input: [u8; 32], chain_length: usize) -> Self { + Self { + input, + chain_length, + _p: PhantomData, + } + } + + /// Compute the expected final hash by applying SHA-256 chain_length times + pub fn expected_output(&self) -> [u8; 32] { + let mut current = self.input; + for _ in 0..self.chain_length { + let mut hasher = Sha256::new(); + hasher.update(current); + current = hasher.finalize().into(); + } + current + } +} + +impl SpartanCircuit for SmallSha256ChainCircuit +where + E::Scalar: PrimeFieldBits, +{ + fn public_values(&self) -> Result, SynthesisError> { + Ok(hash_to_public_scalars(&self.expected_output())) + } + + fn shared>( + &self, + _: &mut CS, + ) -> Result>, SynthesisError> { + Ok(vec![]) + } + + fn precommitted>( + &self, + cs: &mut CS, + _: &[AllocatedNum], + ) -> Result>, SynthesisError> { + // Allocate input bits (big-endian for small_sha256) + let mut current_bits = alloc_preimage_bits::(cs, &self.input, true)?; + + // Chain SHA-256 hashes + for chain_idx in 0..self.chain_length { + let prefix = format!("c{}_", chain_idx); + let hash_bits = small_sha256_with_prefix(cs, ¤t_bits, &prefix)?; + current_bits = hash_bits; + } + + // Verify against expected output (already a hash, don't re-hash) + assert_bits_match_bytes(¤t_bits, &self.expected_output()); + + // Expose as public inputs + expose_hash_bits_as_public::(cs, ¤t_bits)?; + + Ok(vec![]) + } + + fn num_challenges(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + _: &mut CS, + _: &[AllocatedNum], + _: &[AllocatedNum], + _: Option<&[E::Scalar]>, + ) -> Result<(), SynthesisError> { + Ok(()) + } +} + +impl Circuit for SmallSha256ChainCircuit { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + // Allocate input bits (big-endian for small_sha256) + let mut current_bits = alloc_preimage_bits(cs, &self.input, true)?; + + // Chain SHA-256 hashes + for chain_idx in 0..self.chain_length { + let prefix = format!("c{}_", chain_idx); + let hash_bits = small_sha256_with_prefix(cs, ¤t_bits, &prefix)?; + current_bits = hash_bits; + } + + Ok(()) + } +} diff --git a/examples/circuits/sha256/mod.rs b/examples/circuits/sha256/mod.rs new file mode 100644 index 0000000..4fa2155 --- /dev/null +++ b/examples/circuits/sha256/mod.rs @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! 
SHA-256 circuits for examples and benchmarks. +//! +//! This module provides reusable SHA-256 circuit implementations: +//! - [`Sha256Circuit`]: Uses bellpepper's SHA-256 (baseline, not small-value compatible) +//! - [`SmallSha256Circuit`]: Uses small_sha256 gadget (small-value compatible) +//! - [`SmallSha256ChainCircuit`]: Chains multiple small_sha256 calls + +mod bellpepper; +mod chain; +mod small; + +pub use bellpepper::Sha256Circuit; +pub use chain::SmallSha256ChainCircuit; +pub use small::SmallSha256Circuit; + +use bellpepper_core::{ + ConstraintSystem, SynthesisError, + boolean::{AllocatedBit, Boolean}, + num::AllocatedNum, +}; +use ff::{Field, PrimeField}; +use sha2::{Digest, Sha256}; +use spartan2::traits::Engine; + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/// Compute SHA-256 hash and convert to public value scalars. +/// +/// Each bit of the 256-bit hash becomes a field element (0 or 1). +pub fn hash_to_public_scalars(data: &[u8]) -> Vec { + let mut hasher = Sha256::new(); + hasher.update(data); + let hash = hasher.finalize(); + + hash + .iter() + .flat_map(|&byte| { + (0..8).rev().map(move |i| { + if (byte >> i) & 1 == 1 { + F::ONE + } else { + F::ZERO + } + }) + }) + .collect() +} + +/// Allocate preimage bytes as witness bits. +/// +/// If `big_endian` is true, bits are allocated in big-endian order per byte +/// (required for small_sha256). If false, little-endian (bellpepper's convention). +pub fn alloc_preimage_bits( + cs: &mut CS, + preimage: &[u8], + big_endian: bool, +) -> Result, SynthesisError> +where + Scalar: PrimeField, + CS: ConstraintSystem, +{ + preimage + .iter() + .enumerate() + .flat_map(|(byte_idx, &byte)| { + let bit_iter: Box> = if big_endian { + Box::new( + (0..8) + .rev() + .enumerate() + .map(move |(bit_idx, i)| (bit_idx, (byte >> i) & 1 == 1)), + ) + } else { + Box::new((0..8).map(move |i| (i, (byte >> i) & 1 == 1))) + }; + bit_iter.map(move |(bit_idx, bit_val)| (byte_idx, bit_idx, bit_val)) + }) + .map(|(byte_idx, bit_idx, bit_val)| { + AllocatedBit::alloc( + cs.namespace(|| format!("preimage_{}_{}", byte_idx, bit_idx)), + Some(bit_val), + ) + .map(Boolean::from) + }) + .collect() +} + +/// Assert that computed hash bits match native SHA-256 output. +/// +/// Panics if there's a mismatch (debug assertion). +pub fn assert_hash_matches(hash_bits: &[Boolean], preimage: &[u8]) { + let mut hasher = Sha256::new(); + hasher.update(preimage); + let expected = hasher.finalize(); + assert_bits_match_bytes(hash_bits, &expected); +} + +/// Assert that computed bits match expected bytes (big-endian per byte). +/// +/// Panics if there's a mismatch (debug assertion). +pub fn assert_bits_match_bytes(bits: &[Boolean], expected_bytes: &[u8]) { + let expected_bits: Vec = expected_bytes + .iter() + .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1)) + .collect(); + + for (i, (computed, expected_bit)) in bits.iter().zip(expected_bits.iter()).enumerate() { + let computed_val = match computed { + Boolean::Is(bit) => bit.get_value().unwrap(), + Boolean::Not(bit) => !bit.get_value().unwrap(), + Boolean::Constant(b) => *b, + }; + assert_eq!( + computed_val, *expected_bit, + "Hash bit {} mismatch: computed={}, expected={}", + i, computed_val, expected_bit + ); + } +} + +/// Expose hash bits as public inputs with equality constraints. 
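+///
+/// Each bit gets its own public `AllocatedNum` plus one equality constraint of the
+/// form `bit * 1 = num`. A minimal usage sketch (illustrative only; `cs` is any
+/// constraint system over `E::Scalar` and `preimage_bits` were allocated earlier):
+///
+/// ```ignore
+/// let hash_bits = small_sha256(cs, &preimage_bits)?;
+/// expose_hash_bits_as_public::<E, _>(cs, &hash_bits)?;
+/// ```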
+pub fn expose_hash_bits_as_public( + cs: &mut CS, + hash_bits: &[Boolean], +) -> Result<(), SynthesisError> +where + E: Engine, + CS: ConstraintSystem, +{ + for (i, bit) in hash_bits.iter().enumerate() { + let n = AllocatedNum::alloc_input(cs.namespace(|| format!("public num {i}")), || { + Ok( + if bit.get_value().ok_or(SynthesisError::AssignmentMissing)? { + E::Scalar::ONE + } else { + E::Scalar::ZERO + }, + ) + })?; + + cs.enforce( + || format!("bit == num {i}"), + |_| bit.lc(CS::one(), E::Scalar::ONE), + |lc| lc + CS::one(), + |lc| lc + n.get_variable(), + ); + } + Ok(()) +} diff --git a/examples/circuits/sha256/small.rs b/examples/circuits/sha256/small.rs new file mode 100644 index 0000000..f49f598 --- /dev/null +++ b/examples/circuits/sha256/small.rs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! SHA-256 circuit using small_sha256 gadget (small-value compatible). + +use super::{ + alloc_preimage_bits, assert_hash_matches, expose_hash_bits_as_public, hash_to_public_scalars, +}; +use bellpepper_core::{Circuit, ConstraintSystem, SynthesisError, num::AllocatedNum}; +use ff::{PrimeField, PrimeFieldBits}; +use spartan2::{ + gadgets::{NoBatchEq, small_sha256, small_sha256_with_small_multi_eq}, + traits::{Engine, circuit::SpartanCircuit}, +}; +use std::marker::PhantomData; + +/// SHA-256 circuit using small_sha256 gadget (small-value compatible). +/// +/// Uses `SmallMultiEq` to keep coefficients bounded for small-value sumcheck. +#[derive(Debug, Clone)] +pub struct SmallSha256Circuit { + pub preimage: Vec, + /// If true, use BatchingEq<21> (i64 path); if false, use NoBatchEq (i32 path) + pub use_batching: bool, + _p: PhantomData, +} + +impl SmallSha256Circuit { + pub fn new(preimage: Vec, use_batching: bool) -> Self { + Self { + preimage, + use_batching, + _p: PhantomData, + } + } +} + +impl SpartanCircuit for SmallSha256Circuit +where + E::Scalar: PrimeFieldBits, +{ + fn public_values(&self) -> Result, SynthesisError> { + Ok(hash_to_public_scalars(&self.preimage)) + } + + fn shared>( + &self, + _: &mut CS, + ) -> Result>, SynthesisError> { + Ok(vec![]) + } + + fn precommitted>( + &self, + cs: &mut CS, + _: &[AllocatedNum], + ) -> Result>, SynthesisError> { + // Allocate preimage bits (big-endian for small_sha256) + let preimage_bits = alloc_preimage_bits::(cs, &self.preimage, true)?; + + // SmallSHA-256 gadget + let hash_bits = if self.use_batching { + small_sha256(cs, &preimage_bits)? + } else { + let mut eq = NoBatchEq::::new(cs); + let bits = small_sha256_with_small_multi_eq(&mut eq, &preimage_bits, "")?; + drop(eq); + bits + }; + + // Verify against native SHA-256 + assert_hash_matches(&hash_bits, &self.preimage); + + // Expose as public inputs + expose_hash_bits_as_public::(cs, &hash_bits)?; + + Ok(vec![]) + } + + fn num_challenges(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + _: &mut CS, + _: &[AllocatedNum], + _: &[AllocatedNum], + _: Option<&[E::Scalar]>, + ) -> Result<(), SynthesisError> { + Ok(()) + } +} + +impl Circuit for SmallSha256Circuit { + fn synthesize>(self, cs: &mut CS) -> Result<(), SynthesisError> { + let preimage_bits = alloc_preimage_bits(cs, &self.preimage, true)?; + let _ = if self.use_batching { + small_sha256(cs, &preimage_bits)? 
+ } else { + let mut eq = NoBatchEq::::new(cs); + let bits = small_sha256_with_small_multi_eq(&mut eq, &preimage_bits, "")?; + drop(eq); + bits + }; + Ok(()) + } +} diff --git a/examples/sha256.rs b/examples/sha256.rs index 96cc775..301cfea 100644 --- a/examples/sha256.rs +++ b/examples/sha256.rs @@ -9,159 +9,30 @@ //! circuit with varying message lengths //! //! Run with: `RUST_LOG=info cargo run --release --example sha256` + #[cfg(feature = "jem")] #[global_allocator] -static GLOBAL: Jemalloc = tikv_jemallocator::Jemalloc; -use bellpepper::gadgets::sha256::sha256; -use bellpepper_core::{ - ConstraintSystem, SynthesisError, - boolean::{AllocatedBit, Boolean}, - num::AllocatedNum, -}; -use ff::{Field, PrimeField, PrimeFieldBits}; -use sha2::{Digest, Sha256}; +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +#[path = "circuits/mod.rs"] +mod circuits; + +use circuits::Sha256Circuit; use spartan2::{ provider::T256HyraxEngine, spartan::SpartanSNARK, - traits::{Engine, circuit::SpartanCircuit, snark::R1CSSNARKTrait}, + traits::{Engine, snark::R1CSSNARKTrait}, }; -use std::{marker::PhantomData, time::Instant}; +use std::time::Instant; use tracing::{info, info_span}; use tracing_subscriber::EnvFilter; type E = T256HyraxEngine; -#[derive(Clone, Debug)] -struct Sha256Circuit { - preimage: Vec, - _p: PhantomData, -} - -impl Sha256Circuit { - fn new(preimage: Vec) -> Self { - Self { - preimage, - _p: PhantomData, - } - } -} - -impl SpartanCircuit for Sha256Circuit { - fn public_values(&self) -> Result::Scalar>, SynthesisError> { - // compute the SHA-256 hash of the preimage - let mut hasher = Sha256::new(); - hasher.update(&self.preimage); - let hash = hasher.finalize(); - // convert the hash to a vector of scalars - let hash_scalars: Vec<::Scalar> = hash - .iter() - .flat_map(|&byte| { - (0..8).rev().map(move |i| { - if (byte >> i) & 1 == 1 { - E::Scalar::ONE - } else { - E::Scalar::ZERO - } - }) - }) - .collect(); - Ok(hash_scalars) - } - - fn shared>( - &self, - _: &mut CS, - ) -> Result>, SynthesisError> { - // No shared variables in this circuit - Ok(vec![]) - } - - fn precommitted>( - &self, - cs: &mut CS, - _: &[AllocatedNum], // shared variables, if any - ) -> Result>, SynthesisError> { - // 1. Preimage bits - let bit_values: Vec<_> = self - .preimage - .clone() - .into_iter() - .flat_map(|byte| (0..8).map(move |i| (byte >> i) & 1 == 1)) - .map(Some) - .collect(); - assert_eq!(bit_values.len(), self.preimage.len() * 8); - - let preimage_bits = bit_values - .into_iter() - .enumerate() - .map(|(i, b)| AllocatedBit::alloc(cs.namespace(|| format!("preimage bit {i}")), b)) - .map(|b| b.map(Boolean::from)) - .collect::, _>>()?; - - // 2. SHA-256 gadget - let hash_bits = sha256(cs.namespace(|| "sha256"), &preimage_bits)?; - - // 3. 
Sanity-check against Rust SHA-256 - let mut hasher = Sha256::new(); - hasher.update(&self.preimage); - let expected = hasher.finalize(); - - let mut expected_bits = expected - .iter() - .flat_map(|&byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1)); - - for b in &hash_bits { - match b { - Boolean::Is(bit) => assert_eq!(expected_bits.next().unwrap(), bit.get_value().unwrap()), - Boolean::Not(bit) => assert_ne!(expected_bits.next().unwrap(), bit.get_value().unwrap()), - Boolean::Constant(_) => unreachable!(), - } - } - - for (i, bit) in hash_bits.iter().enumerate() { - // Allocate public input - let n = AllocatedNum::alloc_input(cs.namespace(|| format!("public num {i}")), || { - Ok( - if bit.get_value().ok_or(SynthesisError::AssignmentMissing)? { - E::Scalar::ONE - } else { - E::Scalar::ZERO - }, - ) - })?; - - // Single equality constraint is enough - cs.enforce( - || format!("bit == num {i}"), - |_| bit.lc(CS::one(), E::Scalar::ONE), - |lc| lc + CS::one(), - |lc| lc + n.get_variable(), - ); - } - - Ok(vec![]) - } - - fn num_challenges(&self) -> usize { - // SHA-256 circuit does not expect any challenges - 0 - } - - fn synthesize>( - &self, - _: &mut CS, - _: &[AllocatedNum], - _: &[AllocatedNum], - _: Option<&[E::Scalar]>, - ) -> Result<(), SynthesisError> { - Ok(()) - } -} - fn main() { tracing_subscriber::fmt() .with_target(false) - .with_ansi(true) // no bold colour codes + .with_ansi(true) .with_env_filter(EnvFilter::from_default_env()) .init(); diff --git a/examples/sha256_chain_benchmark.rs b/examples/sha256_chain_benchmark.rs new file mode 100644 index 0000000..1793a3c --- /dev/null +++ b/examples/sha256_chain_benchmark.rs @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! examples/sha256_chain_benchmark.rs +//! Benchmark SHA-256 hash chains comparing: +//! - Original sumcheck (baseline) +//! - Small-value sumcheck with BatchingEq<21> +//! +//! Run with: `RUST_LOG=info cargo run --release --no-default-features --example sha256_chain_benchmark` +//! Or for CSV only: `cargo run --release --no-default-features --example sha256_chain_benchmark 2>/dev/null` +//! +//! CLI modes: +//! single 26 - Run only 2^26 (for profiling) +//! range-sweep - Sweep 16-26 (default) +//! 
range-sweep --min 16 --max 20 - Custom range + +#[cfg(feature = "jem")] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +#[path = "circuits/mod.rs"] +mod circuits; + +use circuits::SmallSha256ChainCircuit; +use clap::{Parser, Subcommand}; +use ff::Field; +use spartan2::{ + polys::multilinear::MultilinearPolynomial, + provider::PallasHyraxEngine, + small_field::SmallValueField, + spartan::SpartanSNARK, + sumcheck::SumcheckProof, + traits::{Engine, snark::R1CSSNARKTrait, transcript::TranscriptEngineTrait}, +}; +use std::{io::Write, time::Instant}; +use tracing::info; +use tracing_subscriber::EnvFilter; + +// Use PallasHyraxEngine which has Barrett-optimized SmallLargeMul for Fq +type E = PallasHyraxEngine; +type F = ::Scalar; + +#[derive(Parser)] +#[command(about = "SHA-256 chain benchmark: original vs small-value sumcheck")] +struct Args { + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand)] +enum Command { + /// Run a single num_vars value (for profiling) + Single { num_vars: usize }, + /// Run a range sweep + RangeSweep { + #[arg(long, default_value = "16")] + min: usize, + #[arg(long, default_value = "26")] + max: usize, + }, +} + +/// Convert num_vars to chain_length. +/// num_vars=16 → chain=2, num_vars=18 → chain=8, etc. +/// Formula: chain_length = 2^(num_vars - 15) +fn num_vars_to_chain_length(num_vars: usize) -> usize { + 1 << (num_vars - 15) +} + +/// Benchmark result for a single chain length +struct BenchmarkResult { + chain_length: usize, + num_vars: usize, + num_constraints: usize, + witness_ms: u128, + extract_ms: u128, + orig_sumcheck_ms: u128, + small_sumcheck_ms: u128, +} + +fn run_chain_benchmark( + input: [u8; 32], + chain_length: usize, + expected_num_vars: usize, +) -> BenchmarkResult +where + F: SmallValueField, +{ + // Create the circuit + let small_circuit = SmallSha256ChainCircuit::::new(input, chain_length); + + // === SETUP (one-time, not included in proving time) === + let t0 = Instant::now(); + let (pk, _vk) = SpartanSNARK::::setup(small_circuit.clone()).expect("setup failed"); + let setup_ms = t0.elapsed().as_millis(); + let num_constraints = pk.sizes()[4]; + info!(setup_ms, num_constraints, "setup"); + + // === WITNESS SYNTHESIS === + let t0 = Instant::now(); + let prep_snark = + SpartanSNARK::::prep_prove(&pk, small_circuit.clone(), true).expect("prep_prove failed"); + let witness_ms = t0.elapsed().as_millis(); + info!(witness_ms, "witness synthesis"); + + // === EXTRACT SUMCHECK INPUTS === + let t0 = Instant::now(); + let (az, bz, cz, tau) = + SpartanSNARK::::extract_outer_sumcheck_inputs(&pk, small_circuit, &prep_snark) + .expect("extract_outer_sumcheck_inputs failed"); + let extract_ms = t0.elapsed().as_millis(); + info!(extract_ms, "extract inputs"); + + let num_vars = tau.len(); + + // Verify we got the expected num_vars + assert_eq!( + num_vars, expected_num_vars, + "Expected num_vars={} but got {}. 
Adjust chain_length.", + expected_num_vars, num_vars + ); + + let claim = F::ZERO; + + // ===== ORIGINAL SUMCHECK ===== + // Run in scope so memory is freed before next benchmark + let (proof1, r1, evals1, orig_sumcheck_ms) = { + let mut az1 = MultilinearPolynomial::new(az.clone()); + let mut bz1 = MultilinearPolynomial::new(bz.clone()); + let mut cz1 = MultilinearPolynomial::new(cz.clone()); + let mut transcript1 = ::TE::new(b"sha256_chain_bench"); + + let t0 = Instant::now(); + let (proof1, r1, evals1) = SumcheckProof::::prove_cubic_with_three_inputs( + &claim, + tau.clone(), + &mut az1, + &mut bz1, + &mut cz1, + &mut transcript1, + ) + .expect("prove_cubic_with_three_inputs failed"); + let elapsed = t0.elapsed().as_millis(); + info!(orig_sumcheck_ms = elapsed, "original sumcheck"); + (proof1, r1, evals1, elapsed) + }; // az1, bz1, cz1 dropped here + + // ===== SMALL-VALUE SUMCHECK ===== + // Run in scope so memory is freed after benchmark + let (proof2, r2, evals2, small_sumcheck_ms) = { + // Create small-value polynomials for the optimized method (using i64) + let az_poly = MultilinearPolynomial::new(az.clone()); + let bz_poly = MultilinearPolynomial::new(bz.clone()); + let az_small = + MultilinearPolynomial::::try_from_field(&az_poly).expect("Az values too large for i64"); + let bz_small = + MultilinearPolynomial::::try_from_field(&bz_poly).expect("Bz values too large for i64"); + + let mut az2 = MultilinearPolynomial::new(az); + let mut bz2 = MultilinearPolynomial::new(bz); + let mut cz2 = MultilinearPolynomial::new(cz); + let mut transcript2 = ::TE::new(b"sha256_chain_bench"); + + let t0 = Instant::now(); + let (proof2, r2, evals2) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + tau, + &az_small, + &bz_small, + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + .expect("prove_cubic_with_three_inputs_small_value failed"); + let elapsed = t0.elapsed().as_millis(); + info!(small_sumcheck_ms = elapsed, "small-value sumcheck"); + (proof2, r2, evals2, elapsed) + }; + + // Verify equivalence + assert_eq!(r1, r2, "Challenges must match!"); + assert_eq!(proof1, proof2, "Round polynomials must match!"); + assert_eq!(evals1, evals2, "Final evaluations must match!"); + + BenchmarkResult { + chain_length, + num_vars, + num_constraints, + witness_ms, + extract_ms, + orig_sumcheck_ms, + small_sumcheck_ms, + } +} + +fn print_csv_header() { + println!( + "chain_length,num_vars,log2_constraints,num_constraints,witness_ms,orig_sumcheck_ms,small_sumcheck_ms,total_proving_ms,speedup,witness_pct" + ); +} + +fn print_csv_row(result: &BenchmarkResult) { + let speedup = result.orig_sumcheck_ms as f64 / result.small_sumcheck_ms as f64; + + // Total proving time (excluding one-time setup) + let total_ms = result.witness_ms + result.extract_ms + result.small_sumcheck_ms; + let witness_pct = (result.witness_ms as f64 / total_ms as f64) * 100.0; + + // log2(num_constraints) = number of sumcheck rounds + let log2_constraints = (result.num_constraints as f64).log2(); + + println!( + "{},{},{:.3},{},{},{},{},{},{:.2},{:.1}", + result.chain_length, + result.num_vars, + log2_constraints, + result.num_constraints, + result.witness_ms, + result.orig_sumcheck_ms, + result.small_sumcheck_ms, + total_ms, + speedup, + witness_pct + ); + std::io::stdout().flush().ok(); +} + +fn main() { + // Initialize tracing to stderr so CSV output goes to stdout cleanly + tracing_subscriber::fmt() + .with_target(false) + .with_ansi(true) + .with_env_filter(EnvFilter::from_default_env()) + 
.with_writer(std::io::stderr) + .init(); + + let args = Args::parse(); + + // Determine which num_vars values to run + let num_vars_list: Vec = match args.command { + Some(Command::Single { num_vars }) => vec![num_vars], + Some(Command::RangeSweep { min, max }) => (min..=max).collect(), + None => vec![16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26], + }; + + // Use a deterministic input + let input: [u8; 32] = [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + ]; + + print_csv_header(); + + for num_vars in num_vars_list { + let chain_length = num_vars_to_chain_length(num_vars); + let result = run_chain_benchmark(input, chain_length, num_vars); + print_csv_row(&result); + } +} diff --git a/examples/sumcheck_sha256_equivalence.rs b/examples/sumcheck_sha256_equivalence.rs new file mode 100644 index 0000000..fb41a51 --- /dev/null +++ b/examples/sumcheck_sha256_equivalence.rs @@ -0,0 +1,313 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! examples/sumcheck_sha256_equivalence.rs +//! Verify that prove_cubic_with_three_inputs and prove_cubic_with_three_inputs_small_value +//! produce identical proofs when used with a SmallSha256 circuit. +//! +//! This tests Algorithm 6 (small-value sumcheck optimization) against the standard method. +//! Unlike bellpepper's SHA-256 which has coefficients ~2^237, our SmallSha256Circuit uses +//! SmallMultiEq to keep coefficients within bounded ranges. +//! +//! Run with: `RUST_LOG=info cargo run --release --example sumcheck_sha256_equivalence` + +#[cfg(feature = "jem")] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +#[path = "circuits/mod.rs"] +mod circuits; + +use bellpepper_core::{Circuit, test_cs::TestConstraintSystem}; +use circuits::{Sha256Circuit, SmallSha256Circuit}; +use ff::Field; +use spartan2::{ + polys::multilinear::MultilinearPolynomial, + provider::PallasHyraxEngine, + small_field::SmallValueField, + spartan::SpartanSNARK, + sumcheck::SumcheckProof, + traits::{Engine, snark::R1CSSNARKTrait, transcript::TranscriptEngineTrait}, +}; +use std::time::Instant; +use tracing::{info, info_span, instrument}; +use tracing_subscriber::{EnvFilter, fmt::time::uptime}; + +// Use PallasHyraxEngine which has Barrett-optimized SmallLargeMul for Fq +type E = PallasHyraxEngine; +type F = ::Scalar; + +#[instrument(skip_all)] +fn run_setup( + circuit: SmallSha256Circuit, +) -> ( + as R1CSSNARKTrait>::ProverKey, + as R1CSSNARKTrait>::VerifierKey, +) { + let t0 = Instant::now(); + let result = SpartanSNARK::::setup(circuit).expect("setup failed"); + info!(elapsed_ms = t0.elapsed().as_millis(), "completed"); + result +} + +#[instrument(skip_all)] +fn run_prep_prove( + pk: & as R1CSSNARKTrait>::ProverKey, + circuit: SmallSha256Circuit, +) -> as R1CSSNARKTrait>::PrepSNARK { + let t0 = Instant::now(); + let result = SpartanSNARK::::prep_prove(pk, circuit, true).expect("prep_prove failed"); + info!(elapsed_ms = t0.elapsed().as_millis(), "completed"); + result +} + +#[instrument(skip_all)] +fn extract_sumcheck_inputs( + pk: & as R1CSSNARKTrait>::ProverKey, + circuit: SmallSha256Circuit, + prep_snark: & as R1CSSNARKTrait>::PrepSNARK, +) -> (Vec, Vec, Vec, Vec) { + let t0 = 
Instant::now(); + let (az, bz, cz, tau) = SpartanSNARK::::extract_outer_sumcheck_inputs(pk, circuit, prep_snark) + .expect("extract_outer_sumcheck_inputs failed"); + info!(elapsed_ms = t0.elapsed().as_millis(), "completed"); + + // Check small-value compatibility: Az and Bz must fit in i64 + let mut az_failures = 0; + let mut bz_failures = 0; + for (i, val) in az.iter().enumerate() { + if >::try_field_to_small(val).is_none() { + if az_failures < 5 { + info!("Az[{}] doesn't fit in i64: {:?}", i, val); + } + az_failures += 1; + } + } + for (i, val) in bz.iter().enumerate() { + if >::try_field_to_small(val).is_none() { + if bz_failures < 5 { + info!("Bz[{}] doesn't fit in i64: {:?}", i, val); + } + bz_failures += 1; + } + } + if az_failures > 0 || bz_failures > 0 { + info!( + az_failures, + bz_failures, "Small-value compatibility check FAILED" + ); + } + + (az, bz, cz, tau) +} + +fn test_sumcheck_equivalence_for_message_len(preimage_len: usize, use_batching: bool) +where + F: SmallValueField, +{ + let config_name = if use_batching { + "BatchingEq<21>" + } else { + "NoBatchEq" + }; + let _span = info_span!("test", msg_len = preimage_len, config = config_name).entered(); + + // Create both circuits + let preimage = vec![0u8; preimage_len]; + let small_circuit = SmallSha256Circuit::::new(preimage.clone(), use_batching); + let bellpepper_circuit = Sha256Circuit::::new(preimage); + + // Synthesize and count constraints for SmallSha256 + let mut cs1 = TestConstraintSystem::::new(); + small_circuit + .clone() + .synthesize(&mut cs1) + .expect("small_sha256 synthesis failed"); + let small_sha256_constraints = cs1.num_constraints(); + + // Synthesize and count constraints for bellpepper SHA256 + let mut cs2 = TestConstraintSystem::::new(); + bellpepper_circuit + .synthesize(&mut cs2) + .expect("bellpepper synthesis failed"); + let bellpepper_sha256_constraints = cs2.num_constraints(); + + info!( + msg_len = preimage_len, + small_sha256_constraints, bellpepper_sha256_constraints, "Constraint comparison" + ); + + let (pk, _vk) = run_setup(small_circuit.clone()); + let prep_snark = run_prep_prove(&pk, small_circuit.clone()); + let (az, bz, cz, tau) = extract_sumcheck_inputs(&pk, small_circuit, &prep_snark); + + let num_vars = tau.len(); + info!( + num_vars, + poly_len = az.len(), + "Extracted sumcheck polynomials" + ); + + // Claim is zero for satisfying R1CS (Az * Bz = Cz) + let claim = F::ZERO; + + // ===== ORIGINAL METHOD ===== + // Run in scope so memory is freed before next benchmark + let (proof1, r1, evals1, original_us) = { + let mut az1 = MultilinearPolynomial::new(az.clone()); + let mut bz1 = MultilinearPolynomial::new(bz.clone()); + let mut cz1 = MultilinearPolynomial::new(cz.clone()); + + info!("Running prove_cubic_with_three_inputs (original method)..."); + let mut transcript1 = ::TE::new(b"test_equivalence"); + let t0 = Instant::now(); + let (proof1, r1, evals1) = SumcheckProof::::prove_cubic_with_three_inputs( + &claim, + tau.clone(), + &mut az1, + &mut bz1, + &mut cz1, + &mut transcript1, + ) + .expect("prove_cubic_with_three_inputs failed"); + let elapsed = t0.elapsed().as_micros(); + info!(elapsed_us = elapsed, "prove_cubic_with_three_inputs"); + (proof1, r1, evals1, elapsed) + }; // az1, bz1, cz1 dropped here + + // Try to create small-value polynomials for the optimized method + let az_poly = MultilinearPolynomial::new(az.clone()); + let bz_poly = MultilinearPolynomial::new(bz.clone()); + let az_small_opt = MultilinearPolynomial::::try_from_field(&az_poly); + let bz_small_opt = 
MultilinearPolynomial::::try_from_field(&bz_poly); + + // ===== SMALL-VALUE METHOD (Algorithm 6) ===== + match (az_small_opt, bz_small_opt) { + (Some(az_small), Some(bz_small)) => { + info!("Witness values fit in i64, running small-value optimization with i64/i128 config..."); + + // Verify i64 polynomials match field polynomials (debug check) + let az_back: MultilinearPolynomial = az_small.to_field(); + let bz_back: MultilinearPolynomial = bz_small.to_field(); + let mut mismatch_count = 0; + for i in 0..az.len() { + if az[i] != az_back[i] { + mismatch_count += 1; + if mismatch_count <= 5 { + info!( + "Az mismatch at index {}: field={:?}, i64_to_field={:?}", + i, az[i], az_back[i] + ); + } + } + if bz[i] != bz_back[i] { + mismatch_count += 1; + if mismatch_count <= 5 { + info!( + "Bz mismatch at index {}: field={:?}, i64_to_field={:?}", + i, bz[i], bz_back[i] + ); + } + } + } + if mismatch_count > 0 { + panic!( + "Polynomial mismatch: {} values differ. i64 conversion is lossy!", + mismatch_count + ); + } + info!("Verified: i64 polynomials match field polynomials"); + + // Run in scope so memory is freed after benchmark + let (proof2, r2, evals2, smallvalue_us) = { + let mut az2 = MultilinearPolynomial::new(az); + let mut bz2 = MultilinearPolynomial::new(bz); + let mut cz2 = MultilinearPolynomial::new(cz); + let mut transcript2 = ::TE::new(b"test_equivalence"); + + info!("Running prove_cubic_with_three_inputs_small_value (Algorithm 6)..."); + let t0 = Instant::now(); + let (proof2, r2, evals2) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + tau, + &az_small, + &bz_small, + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + .expect("prove_cubic_with_three_inputs_small_value failed"); + let elapsed = t0.elapsed().as_micros(); + info!( + elapsed_us = elapsed, + "prove_cubic_with_three_inputs_small_value" + ); + (proof2, r2, evals2, elapsed) + }; + + // ===== VERIFY EQUIVALENCE ===== + info!("Verifying equivalence..."); + + assert_eq!(r1, r2, "Challenges must match!"); + info!("Challenges match (len={})", r1.len()); + + assert_eq!(proof1, proof2, "Round polynomials must match!"); + info!("Round polynomials match"); + + assert_eq!(evals1, evals2, "Final evaluations must match!"); + info!("Final evaluations match"); + + let speedup = if smallvalue_us > 0 { + original_us as f64 / smallvalue_us as f64 + } else { + f64::INFINITY + }; + + info!( + msg_len = preimage_len, + small_sha256_constraints, + bellpepper_sha256_constraints, + original_sumcheck_us = original_us, + small_value_sumcheck_us = smallvalue_us, + speedup = format!("{:.2}x", speedup), + "PASSED: proofs are equivalent" + ); + } + _ => { + panic!( + "Az/Bz values too large for small-value optimization (don't fit in i64). \ + Config: {}. 
The small-value optimization requires all Az and Bz values to fit in i64.", + config_name + ); + } + } +} + +fn main() { + tracing_subscriber::fmt() + .with_target(false) + .with_ansi(true) + .with_timer(uptime()) + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + info!("Testing sumcheck method equivalence with SmallSha256Circuit"); + info!("SmallSha256Circuit uses SmallMultiEq to keep coefficients bounded"); + + // Message lengths: 1024 bytes produces num_vars=20 + let preimage_len = 1024; + + // Test with NoBatchEq (i32 path) - direct constraints + info!("===== Testing with NoBatchEq (i32 path, no batching) ====="); + test_sumcheck_equivalence_for_message_len(preimage_len, false); + + // Test with BatchingEq<21> (i64 path) - batched constraints + info!("===== Testing with BatchingEq<21> (i64 path, batch 21) ====="); + test_sumcheck_equivalence_for_message_len(preimage_len, true); + + info!("All tests passed for both NoBatchEq and BatchingEq<21>!"); +} diff --git a/examples/sumcheck_sweep.rs b/examples/sumcheck_sweep.rs new file mode 100644 index 0000000..8b012d7 --- /dev/null +++ b/examples/sumcheck_sweep.rs @@ -0,0 +1,597 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! examples/sumcheck_sweep.rs +//! Sweep benchmark for sumcheck methods across polynomial sizes 2^10 to 2^30 +//! Produces CSV output for chart generation +//! +//! Run with: RUST_LOG=info cargo run --release --example sumcheck_sweep +//! Or for CSV only: cargo run --release --example sumcheck_sweep 2>/dev/null > results.csv +//! +//! CLI modes: +//! single 26 - Run only 26 variables (for profiling) +//! range-sweep - Sweep 10-24 (default) +//! range-sweep --min 20 --max 26 - Custom range + +#[cfg(feature = "jem")] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + +use clap::{Parser, Subcommand, ValueEnum}; +use spartan2::{ + polys::multilinear::MultilinearPolynomial, + provider::{Bn254Engine, PallasHyraxEngine, VestaHyraxEngine}, + small_field::{DelayedReduction, SmallValueField}, + sumcheck::SumcheckProof, + traits::{Engine, transcript::TranscriptEngineTrait}, +}; +use std::{io::Write, time::Instant}; +use tracing::{info, info_span}; +use tracing_subscriber::EnvFilter; + +/// Field choice for benchmarks +#[derive(ValueEnum, Clone, Default, Debug)] +enum FieldChoice { + /// Pallas curve scalar field (Fq) + #[default] + PallasFq, + /// Vesta curve scalar field (Fp) + VestaFp, + /// BN254 curve scalar field (Fr) + Bn254Fr, +} + +/// Sumcheck method to benchmark +#[derive(ValueEnum, Clone, Debug, PartialEq, Eq)] +enum SumcheckMethod { + /// Baseline cubic sumcheck + Base, + /// Split-eq with delayed modular reduction + SplitEqDmr, + /// Small-value i32 optimization + I32, + /// Small-value i64 optimization + I64, +} + +#[derive(Parser)] +#[command(about = "Sumcheck benchmark sweep")] +struct Args { + /// Field to use for benchmarks + #[arg(long, value_enum, default_value = "pallas-fq")] + field: FieldChoice, + + /// Methods to benchmark (comma-separated). 
Default: all + #[arg( + long, + value_enum, + default_value = "base,i32,i64", + value_delimiter = ',' + )] + methods: Vec, + + /// Number of trials per num_vars (each recorded separately) + #[arg(long, default_value = "1")] + trials: usize, + + #[command(subcommand)] + command: Option, +} + +/// Tracks which benchmark methods to run +struct BenchMethods { + base: bool, + split_eq_delayed_modular_reduction: bool, + i32_small: bool, + i64_small: bool, +} + +impl BenchMethods { + fn from_args(methods: &[SumcheckMethod]) -> Self { + Self { + base: methods.contains(&SumcheckMethod::Base), + split_eq_delayed_modular_reduction: methods.contains(&SumcheckMethod::SplitEqDmr), + i32_small: methods.contains(&SumcheckMethod::I32), + i64_small: methods.contains(&SumcheckMethod::I64), + } + } + + /// Returns a copy with i32 disabled (for fields that don't support it) + fn without_i32(&self) -> Self { + Self { + base: self.base, + split_eq_delayed_modular_reduction: self.split_eq_delayed_modular_reduction, + i32_small: false, + i64_small: self.i64_small, + } + } +} + +/// Benchmark result with separate setup and prove times +#[derive(Clone, Copy, Default)] +struct BenchResult { + setup_us: u128, + prove_us: u128, +} + +/// Result of a single trial for a given num_vars +struct TrialResult { + num_vars: usize, + n: usize, + trial: usize, + base: Option, + split_eq_delayed_modular_reduction: Option, + i32_small: Option, + i64_small: Option, +} + +#[derive(Subcommand)] +enum Command { + /// Run a single level (for profiling) + Single { vars: usize }, + /// Run a range sweep + RangeSweep { + #[arg(long, default_value = "10")] + min: usize, + #[arg(long, default_value = "24")] + max: usize, + }, +} + +/// Helper to create field-element polynomials from i32 values +#[inline] +fn make_field_polys(az_i32: &[i32], bz_i32: &[i32]) -> (Vec, Vec, Vec) { + let az_vals: Vec = az_i32.iter().map(|&v| F::from(v as u64)).collect(); + let bz_vals: Vec = bz_i32.iter().map(|&v| F::from(v as u64)).collect(); + let cz_vals: Vec = az_vals.iter().zip(&bz_vals).map(|(a, b)| *a * *b).collect(); + (az_vals, bz_vals, cz_vals) +} + +/// Returns (base_result, split_eq_delayed_modular_reduction_result, i64_result) as Options based on selected methods. +/// Each BenchResult contains separate setup_us and prove_us timings. +/// +/// Memory-optimized: each benchmark runs in its own scope so memory is freed +/// before the next benchmark starts. This reduces peak memory from ~78GB to ~26GB +/// for num_vars=28. 
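+///
+/// A minimal calling sketch (illustrative; the engine choice and method flags are
+/// arbitrary, and `PallasHyraxEngine` is already imported above):
+///
+/// ```ignore
+/// let methods = BenchMethods::from_args(&[SumcheckMethod::Base, SumcheckMethod::I64]);
+/// let (base, _split_eq_dmr, i64_small) = run_single_benchmark::<PallasHyraxEngine>(20, &methods);
+/// ```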
+fn run_single_benchmark( + num_vars: usize, + methods: &BenchMethods, +) -> ( + Option, + Option, + Option, +) +where + E: Engine, + E::Scalar: SmallValueField + DelayedReduction, +{ + type F = ::Scalar; + + let _span = info_span!("benchmark", num_vars).entered(); + let n = 1usize << num_vars; + + // Setup: deterministic polynomials with satisfying witness (Cz = Az * Bz) + // build_accumulators_spartan assumes Az·Bz = Cz on binary points + // + // Create small i32 values - these are small and kept for all benchmarks + let az_i32: Vec = (0..n).map(|i| (i + 1) as i32).collect(); + let bz_i32: Vec = (0..n).map(|i| (i + 3) as i32).collect(); + let taus: Vec> = (0..num_vars) + .map(|i| F::::from((i + 2) as u64)) + .collect(); + + // Claim = 0 for satisfying witness + let claim: F = F::::from(0u64); + + // ===== ORIGINAL METHOD ===== + // Run in scope so memory is freed before next benchmark + let (result_base, r1, evals1) = if methods.base { + let t_setup = Instant::now(); + let (az_vals, bz_vals, cz_vals) = make_field_polys::>(&az_i32, &bz_i32); + let mut az1 = MultilinearPolynomial::new(az_vals); + let mut bz1 = MultilinearPolynomial::new(bz_vals); + let mut cz1 = MultilinearPolynomial::new(cz_vals); + let mut transcript1 = E::TE::new(b"bench"); + let setup_us = t_setup.elapsed().as_micros(); + + let t_prove = Instant::now(); + let (_proof1, r1, evals1) = SumcheckProof::::prove_cubic_with_three_inputs( + &claim, + taus.clone(), + &mut az1, + &mut bz1, + &mut cz1, + &mut transcript1, + ) + .unwrap(); + let prove_us = t_prove.elapsed().as_micros(); + info!(setup_us, prove_us, "prove_cubic_with_three_inputs"); + ( + Some(BenchResult { setup_us, prove_us }), + Some(r1), + Some(evals1), + ) + } else { + (None, None, None) + }; // az1, bz1, cz1, az_vals, bz_vals, cz_vals dropped here + + // ===== SPLIT-EQ DELAYED MODULAR REDUCTION METHOD ===== + let (result_split_eq_delayed_modular_reduction, r2, evals2) = + if methods.split_eq_delayed_modular_reduction { + let t_setup = Instant::now(); + let (az_vals, bz_vals, cz_vals) = make_field_polys::>(&az_i32, &bz_i32); + let mut az2 = MultilinearPolynomial::new(az_vals); + let mut bz2 = MultilinearPolynomial::new(bz_vals); + let mut cz2 = MultilinearPolynomial::new(cz_vals); + let mut transcript2 = E::TE::new(b"bench"); + let setup_us = t_setup.elapsed().as_micros(); + + let t_prove = Instant::now(); + let (_proof2, r2, evals2) = + SumcheckProof::::prove_cubic_with_three_inputs_split_eq_delayed::( + &claim, + taus.clone(), + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + .unwrap(); + let prove_us = t_prove.elapsed().as_micros(); + info!( + setup_us, + prove_us, "prove_cubic_with_three_inputs_split_eq_delayed" + ); + ( + Some(BenchResult { setup_us, prove_us }), + Some(r2), + Some(evals2), + ) + } else { + (None, None, None) + }; + + // Note: i32 benchmark is handled separately in run_i32_benchmark() for fields that support it + + // ===== SMALL-VALUE METHOD (i64/i128) ===== + let (result_i64, r3, evals3) = if methods.i64_small { + let t_setup = Instant::now(); + // Create i64 small-value polynomials + let az_i64: Vec = az_i32.iter().map(|&v| v as i64).collect(); + let bz_i64: Vec = bz_i32.iter().map(|&v| v as i64).collect(); + let az_small_i64 = MultilinearPolynomial::new(az_i64); + let bz_small_i64 = MultilinearPolynomial::new(bz_i64); + + // Need fresh field-element polynomials for binding in later rounds + let (az_vals, bz_vals, cz_vals) = make_field_polys::>(&az_i32, &bz_i32); + let mut az3 = MultilinearPolynomial::new(az_vals); + let mut 
bz3 = MultilinearPolynomial::new(bz_vals); + let mut cz3 = MultilinearPolynomial::new(cz_vals); + let mut transcript3 = E::TE::new(b"bench"); + let setup_us = t_setup.elapsed().as_micros(); + + let t_prove = Instant::now(); + let (_proof3, r3, evals3) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + taus, + &az_small_i64, + &bz_small_i64, + &mut az3, + &mut bz3, + &mut cz3, + &mut transcript3, + ) + .unwrap(); + let prove_us = t_prove.elapsed().as_micros(); + info!( + setup_us, + prove_us, "prove_cubic_with_three_inputs_small_value (i64/i128)" + ); + ( + Some(BenchResult { setup_us, prove_us }), + Some(r3), + Some(evals3), + ) + } else { + (None, None, None) + }; + + // Verify split_eq_delayed_modular_reduction matches base (only when both methods were run) + if let (Some(r1), Some(r2)) = (&r1, &r2) { + assert_eq!( + r1, r2, + "split_eq_delayed_modular_reduction challenges must match base" + ); + } + if let (Some(e1), Some(e2)) = (&evals1, &evals2) { + assert_eq!( + e1, e2, + "split_eq_delayed_modular_reduction final evaluations must match base" + ); + } + + // Verify i64 matches base (only when both methods were run) + if let (Some(r1), Some(r3)) = (&r1, &r3) { + assert_eq!(r1, r3, "i64 challenges must match"); + } + if let (Some(e1), Some(e3)) = (&evals1, &evals3) { + assert_eq!(e1, e3, "i64 final evaluations must match"); + } + + ( + result_base, + result_split_eq_delayed_modular_reduction, + result_i64, + ) +} + +/// Run i32 benchmark separately (only for fields that support SmallValueField) +fn run_i32_benchmark(num_vars: usize, az_i32: &[i32], bz_i32: &[i32]) -> Option +where + E: Engine, + E::Scalar: SmallValueField + DelayedReduction, +{ + type F = ::Scalar; + + let taus: Vec> = (0..num_vars) + .map(|i| F::::from((i + 2) as u64)) + .collect(); + let claim: F = F::::from(0u64); + + let t_setup = Instant::now(); + let az_small = MultilinearPolynomial::new(az_i32.to_vec()); + let bz_small = MultilinearPolynomial::new(bz_i32.to_vec()); + + let (az_vals, bz_vals, cz_vals) = make_field_polys::>(az_i32, bz_i32); + let mut az2 = MultilinearPolynomial::new(az_vals); + let mut bz2 = MultilinearPolynomial::new(bz_vals); + let mut cz2 = MultilinearPolynomial::new(cz_vals); + let mut transcript2 = E::TE::new(b"bench"); + let setup_us = t_setup.elapsed().as_micros(); + + let t_prove = Instant::now(); + let (_proof2, _r2, _evals2) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + taus, + &az_small, + &bz_small, + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + .unwrap(); + let prove_us = t_prove.elapsed().as_micros(); + info!( + setup_us, + prove_us, "prove_cubic_with_three_inputs_small_value (i32/i64)" + ); + Some(BenchResult { setup_us, prove_us }) +} + +/// Build dynamic CSV header based on selected methods +fn build_csv_header(methods: &BenchMethods) -> String { + let mut cols = vec!["num_vars", "n", "trial"]; + if methods.base { + cols.push("base_setup_us"); + cols.push("base_prove_us"); + } + if methods.split_eq_delayed_modular_reduction { + cols.push("split_eq_delayed_modular_reduction_setup_us"); + cols.push("split_eq_delayed_modular_reduction_prove_us"); + } + if methods.i32_small { + cols.push("i32_setup_us"); + cols.push("i32_prove_us"); + } + if methods.i64_small { + cols.push("i64_setup_us"); + cols.push("i64_prove_us"); + } + // Add speedup columns only when comparing base with others (based on prove time) + if methods.base && methods.split_eq_delayed_modular_reduction { + 
cols.push("prove_speedup_split_eq_delayed_modular_reduction"); + } + if methods.base && methods.i32_small { + cols.push("prove_speedup_i32"); + } + if methods.base && methods.i64_small { + cols.push("prove_speedup_i64"); + } + cols.join(",") +} + +/// Format a trial result as a CSV row +fn format_csv_row(result: &TrialResult, methods: &BenchMethods) -> String { + let mut vals: Vec = vec![ + result.num_vars.to_string(), + result.n.to_string(), + result.trial.to_string(), + ]; + + if methods.base { + let r = result.base.unwrap(); + vals.push(r.setup_us.to_string()); + vals.push(r.prove_us.to_string()); + } + if methods.split_eq_delayed_modular_reduction { + let r = result.split_eq_delayed_modular_reduction.unwrap(); + vals.push(r.setup_us.to_string()); + vals.push(r.prove_us.to_string()); + } + if methods.i32_small { + let r = result.i32_small.unwrap(); + vals.push(r.setup_us.to_string()); + vals.push(r.prove_us.to_string()); + } + if methods.i64_small { + let r = result.i64_small.unwrap(); + vals.push(r.setup_us.to_string()); + vals.push(r.prove_us.to_string()); + } + + // Add speedup columns only when comparing base with others (based on prove time) + if methods.base && methods.split_eq_delayed_modular_reduction { + let base_prove = result.base.unwrap().prove_us as f64; + let delayed_modular_reduction_prove = + result.split_eq_delayed_modular_reduction.unwrap().prove_us as f64; + let speedup = if delayed_modular_reduction_prove > 0.0 { + base_prove / delayed_modular_reduction_prove + } else { + f64::INFINITY + }; + vals.push(format!("{:.3}", speedup)); + } + if methods.base && methods.i32_small { + let base_prove = result.base.unwrap().prove_us as f64; + let i32_prove = result.i32_small.unwrap().prove_us as f64; + let speedup = if i32_prove > 0.0 { + base_prove / i32_prove + } else { + f64::INFINITY + }; + vals.push(format!("{:.3}", speedup)); + } + if methods.base && methods.i64_small { + let base_prove = result.base.unwrap().prove_us as f64; + let i64_prove = result.i64_small.unwrap().prove_us as f64; + let speedup = if i64_prove > 0.0 { + base_prove / i64_prove + } else { + f64::INFINITY + }; + vals.push(format!("{:.3}", speedup)); + } + + vals.join(",") +} + +/// Run sweep with full i32+i64 support (for Pallas/Vesta) +/// Prints each trial result immediately after completion +fn run_sumcheck_sweep_with_i32( + min_vars: usize, + max_vars: usize, + num_trials: usize, + methods: &BenchMethods, +) where + E: Engine, + E::Scalar: SmallValueField + + SmallValueField + + DelayedReduction + + DelayedReduction, +{ + println!("{}", build_csv_header(methods)); + std::io::stdout().flush().ok(); + + for num_vars in min_vars..=max_vars { + let n = 1usize << num_vars; + + // Create test data once per num_vars + let az_i32: Vec = (0..n).map(|i| (i + 1) as i32).collect(); + let bz_i32: Vec = (0..n).map(|i| (i + 3) as i32).collect(); + + for trial in 1..=num_trials { + let (base_r, split_eq_delayed_modular_reduction_r, i64_r) = + run_single_benchmark::(num_vars, methods); + let i32_r = if methods.i32_small { + run_i32_benchmark::(num_vars, &az_i32, &bz_i32) + } else { + None + }; + + let result = TrialResult { + num_vars, + n, + trial, + base: base_r, + split_eq_delayed_modular_reduction: split_eq_delayed_modular_reduction_r, + i32_small: i32_r, + i64_small: i64_r, + }; + println!("{}", format_csv_row(&result, methods)); + std::io::stdout().flush().ok(); + } + } +} + +/// Run sweep with i64 only (for BN254 which doesn't support i32) +/// Prints each trial result immediately after completion +fn 
run_sumcheck_sweep_i64_only( + min_vars: usize, + max_vars: usize, + num_trials: usize, + methods: &BenchMethods, +) where + E: Engine, + E::Scalar: SmallValueField + DelayedReduction, +{ + // Use methods without i32 (will be skipped) + let methods = methods.without_i32(); + + println!("{}", build_csv_header(&methods)); + std::io::stdout().flush().ok(); + + for num_vars in min_vars..=max_vars { + let n = 1usize << num_vars; + + for trial in 1..=num_trials { + let (base_r, split_eq_delayed_modular_reduction_r, i64_r) = + run_single_benchmark::(num_vars, &methods); + + let result = TrialResult { + num_vars, + n, + trial, + base: base_r, + split_eq_delayed_modular_reduction: split_eq_delayed_modular_reduction_r, + i32_small: None, + i64_small: i64_r, + }; + println!("{}", format_csv_row(&result, &methods)); + std::io::stdout().flush().ok(); + } + } +} + +fn main() { + // Initialize tracing (logs to stderr so CSV can go to stdout) + tracing_subscriber::fmt() + .with_target(false) + .with_ansi(true) + .with_env_filter(EnvFilter::from_default_env()) + .with_writer(std::io::stderr) + .init(); + + let args = Args::parse(); + let methods = BenchMethods::from_args(&args.methods); + + let (min_vars, max_vars) = match args.command { + Some(Command::Single { vars }) => (vars, vars), + Some(Command::RangeSweep { min, max }) => (min, max), + None => (10, 24), // Default: sweep 10-24 + }; + + eprintln!( + "Running sumcheck benchmark (field={:?}, min={}, max={}, trials={}, methods={:?})...", + args.field, min_vars, max_vars, args.trials, args.methods + ); + + // Run benchmarks, printing each trial as it completes + match args.field { + FieldChoice::PallasFq => { + run_sumcheck_sweep_with_i32::(min_vars, max_vars, args.trials, &methods); + } + FieldChoice::VestaFp => { + run_sumcheck_sweep_with_i32::(min_vars, max_vars, args.trials, &methods); + } + FieldChoice::Bn254Fr => { + if methods.i32_small { + eprintln!("Note: i32 benchmarks not supported for BN254, skipping i32 method"); + } + run_sumcheck_sweep_i64_only::(min_vars, max_vars, args.trials, &methods); + } + } +} diff --git a/src/bellpepper/shape_cs.rs b/src/bellpepper/shape_cs.rs index 8892043..9ba0de1 100644 --- a/src/bellpepper/shape_cs.rs +++ b/src/bellpepper/shape_cs.rs @@ -6,19 +6,12 @@ //! Support for generating R1CS shape using bellperson. 
-use std::{ - cmp::Ordering, - collections::{BTreeMap, HashMap}, -}; +use std::collections::HashMap; use crate::traits::Engine; use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; -use core::fmt::Write; use ff::{Field, PrimeField}; -#[derive(Clone, Copy)] -struct OrderedVariable(Variable); - #[derive(Debug)] enum NamedObject { Constraint, @@ -26,32 +19,6 @@ enum NamedObject { Namespace, } -impl Eq for OrderedVariable {} -impl PartialEq for OrderedVariable { - fn eq(&self, other: &OrderedVariable) -> bool { - match (self.0.get_unchecked(), other.0.get_unchecked()) { - (Index::Input(ref a), Index::Input(ref b)) => a == b, - (Index::Aux(ref a), Index::Aux(ref b)) => a == b, - _ => false, - } - } -} -impl PartialOrd for OrderedVariable { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl Ord for OrderedVariable { - fn cmp(&self, other: &Self) -> Ordering { - match (self.0.get_unchecked(), other.0.get_unchecked()) { - (Index::Input(ref a), Index::Input(ref b)) => a.cmp(b), - (Index::Aux(ref a), Index::Aux(ref b)) => a.cmp(b), - (Index::Input(_), Index::Aux(_)) => Ordering::Less, - (Index::Aux(_), Index::Input(_)) => Ordering::Greater, - } - } -} - /// `ShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. pub struct ShapeCS where @@ -70,32 +37,6 @@ where aux: Vec, } -fn proc_lc( - terms: &LinearCombination, -) -> BTreeMap { - let mut map = BTreeMap::new(); - for (var, &coeff) in terms.iter() { - map - .entry(OrderedVariable(var)) - .or_insert_with(|| Scalar::ZERO) - .add_assign(&coeff); - } - - // Remove terms that have a zero coefficient to normalize - let mut to_remove = vec![]; - for (var, coeff) in map.iter() { - if coeff.is_zero().into() { - to_remove.push(*var) - } - } - - for var in to_remove { - map.remove(&var); - } - - map -} - impl ShapeCS where E::Scalar: PrimeField, @@ -120,94 +61,6 @@ where self.aux.len() } - /// Print all public inputs, aux inputs, and constraint names. - #[allow(dead_code)] - pub fn pretty_print_list(&self) -> Vec { - let mut result = Vec::new(); - - for input in &self.inputs { - result.push(format!("INPUT {input}")); - } - for aux in &self.aux { - result.push(format!("AUX {aux}")); - } - - for (_a, _b, _c, name) in &self.constraints { - result.push(name.to_string()); - } - - result - } - - /// Print all iputs and a detailed representation of each constraint. - #[allow(dead_code)] - pub fn pretty_print(&self) -> String { - let mut s = String::new(); - - for input in &self.inputs { - writeln!(s, "INPUT {}", &input).unwrap() - } - - let negone = -::ONE; - - let powers_of_two = (0..E::Scalar::NUM_BITS) - .map(|i| E::Scalar::from(2u64).pow_vartime([u64::from(i)])) - .collect::>(); - - let pp = |s: &mut String, lc: &LinearCombination| { - s.push('('); - let mut is_first = true; - for (var, coeff) in proc_lc::(lc) { - if coeff == negone { - s.push_str(" - ") - } else if !is_first { - s.push_str(" + ") - } - is_first = false; - - if coeff != ::ONE && coeff != negone { - for (i, x) in powers_of_two.iter().enumerate() { - if x == &coeff { - write!(s, "2^{i} . ").unwrap(); - break; - } - } - - write!(s, "{coeff:?} . ").unwrap() - } - - match var.0.get_unchecked() { - Index::Input(i) => { - write!(s, "`I{}`", &self.inputs[i]).unwrap(); - } - Index::Aux(i) => { - write!(s, "`A{}`", &self.aux[i]).unwrap(); - } - } - } - if is_first { - // Nothing was visited, print 0. 
- s.push('0'); - } - s.push(')'); - }; - - for (a, b, c, name) in &self.constraints { - s.push('\n'); - - write!(s, "{name}: ").unwrap(); - pp(&mut s, a); - write!(s, " * ").unwrap(); - pp(&mut s, b); - s.push_str(" = "); - pp(&mut s, c); - } - - s.push('\n'); - - s - } - /// Associate `NamedObject` with `path`. /// `path` must not already have an associated object. fn set_named_obj(&mut self, path: String, to: NamedObject) { diff --git a/src/bellpepper/test_shape_cs.rs b/src/bellpepper/test_shape_cs.rs index ce5731a..329b920 100644 --- a/src/bellpepper/test_shape_cs.rs +++ b/src/bellpepper/test_shape_cs.rs @@ -9,15 +9,8 @@ use crate::traits::Engine; use bellpepper_core::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable}; -use core::fmt::Write; use ff::{Field, PrimeField}; -use std::{ - cmp::Ordering, - collections::{BTreeMap, HashMap}, -}; - -#[derive(Clone, Copy)] -struct OrderedVariable(Variable); +use std::collections::HashMap; #[derive(Debug)] enum NamedObject { @@ -26,32 +19,6 @@ enum NamedObject { Namespace, } -impl Eq for OrderedVariable {} -impl PartialEq for OrderedVariable { - fn eq(&self, other: &OrderedVariable) -> bool { - match (self.0.get_unchecked(), other.0.get_unchecked()) { - (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => a == b, - _ => false, - } - } -} -impl PartialOrd for OrderedVariable { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl Ord for OrderedVariable { - fn cmp(&self, other: &Self) -> Ordering { - match (self.0.get_unchecked(), other.0.get_unchecked()) { - (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => { - a.cmp(b) - } - (Index::Input(_), Index::Aux(_)) => Ordering::Less, - (Index::Aux(_), Index::Input(_)) => Ordering::Greater, - } - } -} - /// `TestShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit. pub struct TestShapeCS where @@ -70,32 +37,6 @@ where aux: Vec, } -fn proc_lc( - terms: &LinearCombination, -) -> BTreeMap { - let mut map = BTreeMap::new(); - for (var, &coeff) in terms.iter() { - map - .entry(OrderedVariable(var)) - .or_insert_with(|| Scalar::ZERO) - .add_assign(&coeff); - } - - // Remove terms that have a zero coefficient to normalize - let mut to_remove = vec![]; - for (var, coeff) in map.iter() { - if coeff.is_zero().into() { - to_remove.push(*var) - } - } - - for var in to_remove { - map.remove(&var); - } - - map -} - impl TestShapeCS where E::Scalar: PrimeField, @@ -120,94 +61,6 @@ where self.aux.len() } - /// Print all public inputs, aux inputs, and constraint names. - #[allow(dead_code)] - pub fn pretty_print_list(&self) -> Vec { - let mut result = Vec::new(); - - for input in &self.inputs { - result.push(format!("INPUT {input}")); - } - for aux in &self.aux { - result.push(format!("AUX {aux}")); - } - - for (_a, _b, _c, name) in &self.constraints { - result.push(name.to_string()); - } - - result - } - - /// Print all iputs and a detailed representation of each constraint. 
- #[allow(dead_code)] - pub fn pretty_print(&self) -> String { - let mut s = String::new(); - - for input in &self.inputs { - writeln!(s, "INPUT {}", &input).unwrap() - } - - let negone = -::ONE; - - let powers_of_two = (0..E::Scalar::NUM_BITS) - .map(|i| E::Scalar::from(2u64).pow_vartime([u64::from(i)])) - .collect::>(); - - let pp = |s: &mut String, lc: &LinearCombination| { - s.push('('); - let mut is_first = true; - for (var, coeff) in proc_lc::(lc) { - if coeff == negone { - s.push_str(" - ") - } else if !is_first { - s.push_str(" + ") - } - is_first = false; - - if coeff != ::ONE && coeff != negone { - for (i, x) in powers_of_two.iter().enumerate() { - if x == &coeff { - write!(s, "2^{i} . ").unwrap(); - break; - } - } - - write!(s, "{coeff:?} . ").unwrap() - } - - match var.0.get_unchecked() { - Index::Input(i) => { - write!(s, "`I{}`", &self.inputs[i]).unwrap(); - } - Index::Aux(i) => { - write!(s, "`A{}`", &self.aux[i]).unwrap(); - } - } - } - if is_first { - // Nothing was visited, print 0. - s.push('0'); - } - s.push(')'); - }; - - for (a, b, c, name) in &self.constraints { - s.push('\n'); - - write!(s, "{name}: ").unwrap(); - pp(&mut s, a); - write!(s, " * ").unwrap(); - pp(&mut s, b); - s.push_str(" = "); - pp(&mut s, c); - } - - s.push('\n'); - - s - } - /// Associate `NamedObject` with `path`. /// `path` must not already have an associated object. fn set_named_obj(&mut self, path: String, to: NamedObject) { diff --git a/src/csr.rs b/src/csr.rs new file mode 100644 index 0000000..60ab747 --- /dev/null +++ b/src/csr.rs @@ -0,0 +1,103 @@ +//! Compressed Sparse Row (CSR) storage for variable-length lists. +//! +//! This module provides a memory-efficient data structure for storing +//! N variable-length lists with only 2 allocations total, instead of N+1 +//! allocations required by `Vec>`. + +use std::ops::Index; + +/// Compressed Sparse Row storage for variable-length lists. +/// +/// Stores N variable-length lists in two contiguous arrays: +/// - `offsets[i]..offsets[i+1]` defines the slice for row i +/// - `data` contains all elements back-to-back +/// +/// # Benefits over `Vec>` +/// - 2 allocations total (vs N+1) +/// - Contiguous memory for cache-friendly iteration +/// - No pointer chasing +pub struct Csr { + offsets: Vec, + data: Vec, +} + +impl Csr { + /// Create an empty CSR with pre-allocated capacity. + /// + /// # Arguments + /// * `num_rows` - Expected number of rows + /// * `total_elements` - Expected total elements across all rows + pub fn with_capacity(num_rows: usize, total_elements: usize) -> Self { + let mut offsets = Vec::with_capacity(num_rows + 1); + offsets.push(0); + Self { + offsets, + data: Vec::with_capacity(total_elements), + } + } + + /// Append a new row with the given elements. + pub fn push(&mut self, elements: &[T]) + where + T: Clone, + { + self.data.extend_from_slice(elements); + self.offsets.push(self.data.len() as u32); + } +} + +/// Test-only helpers. +#[cfg(test)] +impl Csr { + /// Number of rows. + #[inline] + pub fn num_rows(&self) -> usize { + self.offsets.len() - 1 + } + + /// Iterate over all rows as (index, slice) pairs. 
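+  ///
+  /// A small illustrative example (it mirrors the unit tests below):
+  ///
+  /// ```ignore
+  /// let mut csr = Csr::with_capacity(2, 3);
+  /// csr.push(&[10, 20]);
+  /// csr.push(&[30]);
+  /// for (i, row) in csr.iter_rows() {
+  ///   assert_eq!(row, &csr[i]);
+  /// }
+  /// ```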
+ pub fn iter_rows(&self) -> impl Iterator { + (0..self.num_rows()).map(move |i| (i, &self[i])) + } +} + +impl Index for Csr { + type Output = [T]; + + #[inline] + fn index(&self, i: usize) -> &Self::Output { + let start = self.offsets[i] as usize; + let end = self.offsets[i + 1] as usize; + &self.data[start..end] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_with_capacity() { + let mut csr = Csr::with_capacity(3, 7); + csr.push(&[1, 2, 3]); + csr.push(&[4, 5]); + csr.push(&[6, 7]); + + assert_eq!(csr.num_rows(), 3); + assert_eq!(&csr[0], &[1, 2, 3]); + assert_eq!(&csr[1], &[4, 5]); + assert_eq!(&csr[2], &[6, 7]); + } + + #[test] + fn test_iter_rows() { + let mut csr = Csr::with_capacity(2, 3); + csr.push(&[10, 20]); + csr.push(&[30]); + + let rows: Vec<_> = csr.iter_rows().collect(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0], (0, &[10, 20][..])); + assert_eq!(rows[1], (1, &[30][..])); + } +} diff --git a/src/gadgets/addmany.rs b/src/gadgets/addmany.rs new file mode 100644 index 0000000..b91f26b --- /dev/null +++ b/src/gadgets/addmany.rs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Addition algorithms for SmallUInt32 values. +//! +//! This module provides two addition algorithms optimized for different +//! coefficient bounds in the small-value sumcheck optimization: +//! +//! - [`limbed`]: 16-bit limbed addition, max coefficient 2^18 (fits i32) +//! - [`full`]: Full 35-bit addition, max coefficient 2^34 (fits i64) + +use super::{small_multi_eq::SmallMultiEq, small_uint32::SmallUInt32}; +use bellpepper_core::{ + LinearCombination, SynthesisError, + boolean::{AllocatedBit, Boolean}, +}; +use ff::PrimeField; + +/// Full 35-bit addition for i64 path. +/// +/// Computes the sum of multiple SmallUInt32 operands, allocating enough bits +/// to represent the full sum before truncating to 32 bits. +/// +/// Max coefficient: 2^34 (for 5 operands producing 35-bit result). +pub(crate) fn full( + cs: &mut M, + operands: &[SmallUInt32], +) -> Result +where + Scalar: PrimeField, + M: SmallMultiEq, +{ + // Compute the maximum value of the sum + let max_value = (operands.len() as u64) * (u32::MAX as u64); + + // How many bits do we need to represent the result? 
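+  // Worked example: for 5 operands, max_value = 5 * (2^32 - 1) ≈ 2^34.3,
+  // so leading_zeros(max_value) = 29 and result_bits = 64 - 29 = 35,
+  // matching the "35-bit result / max coefficient 2^34" note in the doc comment above.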
+ let result_bits = 64 - max_value.leading_zeros() as usize; + + // Compute the value of the result + let result_value = operands + .iter() + .try_fold(0u64, |acc, op| op.get_value().map(|v| acc + (v as u64))); + + // Allocate each bit of the result + let mut result_bits_vec: Vec = Vec::with_capacity(result_bits); + let mut coeff = Scalar::ONE; + let mut lc = LinearCombination::zero(); + let mut all_operands_lc = LinearCombination::zero(); + + for i in 0..result_bits { + // Allocate the bit + let bit = AllocatedBit::alloc( + cs.namespace(|| format!("result bit {}", i)), + result_value.map(|v| (v >> i) & 1 == 1), + )?; + + // Add to linear combination + lc = lc + (coeff, bit.get_variable()); + + result_bits_vec.push(Boolean::from(bit)); + coeff = coeff.double(); + } + + // Compute linear combination of all operand bits + for op in operands.iter() { + let mut coeff = Scalar::ONE; + for bit in op.bits_le() { + all_operands_lc = all_operands_lc + &bit.lc(M::one(), coeff); + coeff = coeff.double(); + } + } + + // Enforce that the result equals the sum of operands + cs.enforce_equal(&lc, &all_operands_lc); + + // Truncate to 32 bits + let bits: [Boolean; 32] = result_bits_vec + .into_iter() + .take(32) + .collect::>() + .try_into() + .unwrap(); + + Ok(SmallUInt32::from_bits_le(&bits)) +} + +/// 16-bit limbed addition for i32 path. +/// +/// Splits each 32-bit value into two 16-bit limbs and adds them separately. +/// This keeps the maximum coefficient at 2^18, which fits in i32. +/// +/// Constraint 1 (low limb): +/// Σ(operand_lo) = result_lo + carry × 2^16 +/// +/// Constraint 2 (high limb): +/// Σ(operand_hi) + carry = result_hi + overflow × 2^16 +pub(crate) fn limbed( + cs: &mut M, + operands: &[SmallUInt32], +) -> Result +where + Scalar: PrimeField, + M: SmallMultiEq, +{ + // For N operands, each 16-bit limb sum can be up to N * (2^16 - 1) + // For 10 operands: 10 * 65535 = 655350, needs 20 bits + // We allocate 16 result bits + up to 4 carry/overflow bits + let num_carry_bits = + 64 - ((operands.len() as u64) * (u16::MAX as u64)).leading_zeros() as usize - 16; + let num_carry_bits = num_carry_bits.max(1); // At least 1 carry bit + + // Compute low limb sum (for witness generation) + // IMPORTANT: Carry must be computed from the LOW LIMB SUM, not from the full result + let lo_sum: Option = operands.iter().try_fold(0u64, |acc, op| { + op.get_value().map(|v| acc + ((v as u64) & 0xFFFF)) + }); + + // Compute high limb sum (for witness generation) + // hi_sum = Σ(operand_hi) + carry + let hi_sum: Option = operands.iter().try_fold(0u64, |acc, op| { + op.get_value().map(|v| acc + (((v as u64) >> 16) & 0xFFFF)) + }); + let hi_sum_with_carry: Option = hi_sum.and_then(|h| lo_sum.map(|l| h + (l >> 16))); + + // === LOW LIMB CONSTRAINT === + // Sum of low 16 bits of each operand = low 16 bits of result + carry × 2^16 + + // Build LHS: sum of all operand low limbs + let mut lo_operands_lc = LinearCombination::zero(); + for op in operands.iter() { + let mut coeff = Scalar::ONE; + for bit in &op.bits_le()[0..16] { + lo_operands_lc = lo_operands_lc + &bit.lc(M::one(), coeff); + coeff = coeff.double(); + } + } + + // Allocate result low bits (0..15) from lo_sum + let mut lo_result_lc = LinearCombination::zero(); + let mut bits = [const { Boolean::Constant(false) }; 32]; + let mut carry_bits: Vec = Vec::with_capacity(num_carry_bits); + + let mut coeff = Scalar::ONE; + for (i, slot) in bits.iter_mut().enumerate().take(16) { + let bit = AllocatedBit::alloc( + cs.namespace(|| format!("lo{i}")), + 
lo_sum.map(|v| (v >> i) & 1 == 1), + )?; + lo_result_lc = lo_result_lc + (coeff, bit.get_variable()); + *slot = Boolean::from(bit); + coeff = coeff.double(); + } + + // Allocate carry bits from lo_sum (bits 16+) + for i in 0..num_carry_bits { + let bit = AllocatedBit::alloc( + cs.namespace(|| format!("c{i}")), + lo_sum.map(|v| (v >> (16 + i)) & 1 == 1), + )?; + lo_result_lc = lo_result_lc + (coeff, bit.get_variable()); + carry_bits.push(bit); + coeff = coeff.double(); + } + + // Enforce: lo_operands_lc = lo_result_lc + cs.enforce_equal(&lo_operands_lc, &lo_result_lc); + + // === HIGH LIMB CONSTRAINT === + // Sum of high 16 bits of each operand + carry = high 16 bits of result + overflow × 2^16 + + // Build LHS: sum of all operand high limbs + carry + let mut hi_operands_lc = LinearCombination::zero(); + for op in operands.iter() { + let mut coeff = Scalar::ONE; + for bit in &op.bits_le()[16..32] { + hi_operands_lc = hi_operands_lc + &bit.lc(M::one(), coeff); + coeff = coeff.double(); + } + } + + // Add carry from low limb + let mut coeff = Scalar::ONE; + for carry_bit in &carry_bits { + hi_operands_lc = hi_operands_lc + (coeff, carry_bit.get_variable()); + coeff = coeff.double(); + } + + // Allocate result high bits (16..31) from hi_sum_with_carry + let mut hi_result_lc = LinearCombination::zero(); + let mut coeff = Scalar::ONE; + for i in 0..16 { + let bit = AllocatedBit::alloc( + cs.namespace(|| format!("hi{i}")), + hi_sum_with_carry.map(|v| (v >> i) & 1 == 1), + )?; + hi_result_lc = hi_result_lc + (coeff, bit.get_variable()); + bits[16 + i] = Boolean::from(bit); + coeff = coeff.double(); + } + + // Allocate overflow bits from hi_sum_with_carry (bits 16+, discarded) + for i in 0..num_carry_bits { + let bit = AllocatedBit::alloc( + cs.namespace(|| format!("o{i}")), + hi_sum_with_carry.map(|v| (v >> (16 + i)) & 1 == 1), + )?; + hi_result_lc = hi_result_lc + (coeff, bit.get_variable()); + coeff = coeff.double(); + } + + // Enforce: hi_operands_lc = hi_result_lc + cs.enforce_equal(&hi_operands_lc, &hi_result_lc); + + Ok(SmallUInt32::from_bits_le(&bits)) +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs new file mode 100644 index 0000000..969e656 --- /dev/null +++ b/src/gadgets/mod.rs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Circuit gadgets optimized for small-value sumcheck. +//! +//! This module provides circuit gadgets that are designed to work with the +//! small-value sumcheck optimization. The key difference from bellpepper's +//! gadgets is that constraint coefficients are bounded to fit in native integers. +//! +//! # Available Gadgets +//! +//! - [`SmallMultiEq`]: Trait for batched equality constraints with bounded coefficients +//! - [`NoBatchEq`]: Direct enforcement (for i32 path) +//! - [`BatchingEq`]: Batched enforcement (for i64 path) +//! - [`SmallUInt32`]: 32-bit unsigned integer with bit operations +//! 
- [`small_sha256`]: SHA-256 function using small-value compatible gadgets + +mod addmany; +mod small_multi_eq; +mod small_sha256; +mod small_uint32; + +pub use small_multi_eq::{BatchingEq, NoBatchEq, SmallMultiEq}; +pub use small_sha256::{small_sha256, small_sha256_with_prefix, small_sha256_with_small_multi_eq}; +pub use small_uint32::SmallUInt32; diff --git a/src/gadgets/small_multi_eq.rs b/src/gadgets/small_multi_eq.rs new file mode 100644 index 0000000..1973553 --- /dev/null +++ b/src/gadgets/small_multi_eq.rs @@ -0,0 +1,720 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! SmallMultiEq: Batched equality constraints with bounded coefficients. +//! +//! This module provides the `SmallMultiEq` trait and two implementations: +//! - [`NoBatchEq`]: Each equality constraint is enforced directly (for i32 path) +//! - [`BatchingEq`]: Batches up to K constraints before flushing (for i64 path) +//! +//! # Why SmallMultiEq? +//! +//! Bellpepper's `MultiEq` batches equality constraints using powers of 2 as +//! coefficients, accumulating up to 2^237 after ~237 batched equalities. +//! This breaks small-value optimization because: +//! +//! ```text +//! Az(x) = sum_y A(x,y) * z(y) +//! ``` +//! +//! Even if z(y) is small (bits), if A(x,y) = 2^237, then Az(x) is huge. +//! +//! # Design +//! +//! The `SmallMultiEq` trait extends `ConstraintSystem` with: +//! - `enforce_equal`: Batch-aware equality constraints +//! - `flush`: Force pending constraints to be emitted +//! - `addmany`: Add multiple SmallUInt32 values (algorithm determined by impl) +//! +//! Each implementation pairs a batching strategy with its compatible addmany algorithm: +//! - `NoBatchEq`: No batching + limbed addition (max coeff 2^18, fits i32) +//! - `BatchingEq`: Batch K constraints + full addition (max coeff 2^34, fits i64) + +use super::{addmany, small_uint32::SmallUInt32}; +use bellpepper_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable}; +use ff::PrimeField; + +// ============================================================================ +// SmallMultiEq Trait +// ============================================================================ + +/// Constraint system extension for batched equality constraints with bounded coefficients. +/// +/// This trait extends `ConstraintSystem` with methods for enforcing equality +/// constraints in a way that keeps coefficients within small-value bounds. +pub trait SmallMultiEq: ConstraintSystem { + /// Enforce that `lhs` equals `rhs`. + /// + /// The implementation determines whether this is enforced directly or batched. + fn enforce_equal(&mut self, lhs: &LinearCombination, rhs: &LinearCombination); + + /// Flush any pending batched constraints to the underlying constraint system. + fn flush(&mut self); + + /// Add multiple SmallUInt32 values together. 
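+  /// The sum is truncated to 32 bits, so it wraps modulo 2^32 (the same
+  /// semantics as `u32::wrapping_add`).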
+ /// + /// The implementation determines which addition algorithm is used: + /// - `NoBatchEq`: Uses limbed addition (max coeff 2^18) + /// - `BatchingEq`: Uses full addition (max coeff 2^34) + fn addmany(&mut self, operands: &[SmallUInt32]) -> Result; +} + +// ============================================================================ +// NoBatchEq - Direct enforcement, limbed addition +// ============================================================================ + +/// Constraint system wrapper that enforces equality constraints directly. +/// +/// Each call to `enforce_equal` immediately creates a constraint. This is used +/// with the limbed addition algorithm which keeps coefficients within i32 bounds. +/// +/// # Example +/// +/// ```ignore +/// let mut eq = NoBatchEq::::new(&mut cs); +/// eq.enforce_equal(&lhs, &rhs); // Immediately enforced +/// let sum = eq.addmany(&[a, b, c])?; // Uses limbed addition +/// ``` +pub struct NoBatchEq<'a, Scalar: PrimeField, CS: ConstraintSystem> { + cs: &'a mut CS, + ops: usize, + addmany_count: usize, + _marker: std::marker::PhantomData, +} + +impl<'a, Scalar: PrimeField, CS: ConstraintSystem> NoBatchEq<'a, Scalar, CS> { + /// Create a new NoBatchEq wrapper around a constraint system. + pub fn new(cs: &'a mut CS) -> Self { + NoBatchEq { + cs, + ops: 0, + addmany_count: 0, + _marker: std::marker::PhantomData, + } + } +} + +impl> SmallMultiEq + for NoBatchEq<'_, Scalar, CS> +{ + fn enforce_equal(&mut self, lhs: &LinearCombination, rhs: &LinearCombination) { + let ops = self.ops; + self.cs.enforce( + || format!("eq {ops}"), + |_| lhs.clone(), + |lc| lc + CS::one(), + |_| rhs.clone(), + ); + self.ops += 1; + } + + fn flush(&mut self) { + // No-op: NoBatchEq enforces constraints directly + } + + fn addmany(&mut self, operands: &[SmallUInt32]) -> Result { + assert!(Scalar::NUM_BITS >= 64); + assert!(operands.len() >= 2); + assert!(operands.len() <= 10); + + // Check for all-constant case + if let Some(sum) = try_constant_sum(operands) { + return Ok(SmallUInt32::constant(sum)); + } + + // Create a unique namespace for this addmany call + let count = self.addmany_count; + self.addmany_count += 1; + self.cs.push_namespace(|| format!("add{count}")); + + // Use limbed addition (max coeff 2^18, fits i32) + let result = addmany::limbed(self, operands); + + self.cs.pop_namespace(); + result + } +} + +// Delegate ConstraintSystem to inner cs +impl> ConstraintSystem + for NoBatchEq<'_, Scalar, CS> +{ + type Root = Self; + + fn one() -> Variable { + CS::one() + } + + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc(annotation, f) + } + + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc_input(annotation, f) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + self.cs.enforce(annotation, a, b, c); + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.push_namespace(name_fn); + } + + fn pop_namespace(&mut self) { + self.cs.pop_namespace(); + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } + + fn is_witness_generator(&self) -> bool { + self.cs.is_witness_generator() + } + + fn extend_inputs(&mut self, 
inputs: &[Scalar]) { + self.cs.extend_inputs(inputs); + } + + fn extend_aux(&mut self, aux: &[Scalar]) { + self.cs.extend_aux(aux); + } + + fn allocate_empty(&mut self, aux_n: usize, inputs_n: usize) -> (&mut [Scalar], &mut [Scalar]) { + self.cs.allocate_empty(aux_n, inputs_n) + } + + fn inputs_slice(&self) -> &[Scalar] { + self.cs.inputs_slice() + } + + fn aux_slice(&self) -> &[Scalar] { + self.cs.aux_slice() + } +} + +// ============================================================================ +// BatchingEq - Batched enforcement, full addition +// ============================================================================ + +/// Constraint system wrapper that batches equality constraints. +/// +/// Accumulates up to K equality constraints with coefficients 2^0, 2^1, ..., 2^(K-1) +/// before flushing as a single batched constraint. This is used with the full +/// addition algorithm which keeps coefficients within i64 bounds. +/// +/// # Type Parameter +/// +/// - `K`: Maximum number of constraints to batch before flushing +/// +/// # Why K=21? +/// +/// Batching packs multiple equality constraints into fewer constraints by forming: +/// ```text +/// B(z) = Σⱼ 2ʲ · Lⱼ(z) for j = 0..K-1 +/// ``` +/// and enforcing B(z) = 0. This is safe because powers of 2 give unique representation. +/// +/// The limit K=21 comes from not overflowing i64 before converting to field: +/// ```text +/// Az ≤ 200 terms × 2^34 (positional) × 2^20 (batching) × 1 (witness) +/// = 2^8 × 2^34 × 2^20 × 2^0 +/// = 2^62 < 2^63 (i64 signed max) +/// ``` +/// +/// # Example +/// +/// ```ignore +/// let mut eq = BatchingEq::::new(&mut cs); +/// eq.enforce_equal(&lhs1, &rhs1); // Batched with coeff 2^0 +/// eq.enforce_equal(&lhs2, &rhs2); // Batched with coeff 2^1 +/// // ... up to 21 constraints batched together +/// let sum = eq.addmany(&[a, b, c])?; // Uses full addition +/// drop(eq); // Flushes remaining batched constraints +/// ``` +pub struct BatchingEq<'a, Scalar: PrimeField, CS: ConstraintSystem, const K: usize> { + cs: &'a mut CS, + ops: usize, + addmany_count: usize, + bits_used: usize, + lhs: LinearCombination, + rhs: LinearCombination, +} + +impl<'a, Scalar: PrimeField, CS: ConstraintSystem, const K: usize> + BatchingEq<'a, Scalar, CS, K> +{ + /// Create a new BatchingEq wrapper around a constraint system. + pub fn new(cs: &'a mut CS) -> Self { + BatchingEq { + cs, + ops: 0, + addmany_count: 0, + bits_used: 0, + lhs: LinearCombination::zero(), + rhs: LinearCombination::zero(), + } + } + + /// Flush the pending batched constraint to the underlying constraint system. 
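+  ///
+  /// Concretely, the single emitted R1CS constraint is
+  /// `(Σⱼ 2ʲ · lhsⱼ) · 1 = (Σⱼ 2ʲ · rhsⱼ)` over the currently batched pairs;
+  /// nothing is emitted if no equalities have been accumulated.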
+ fn do_flush(&mut self) { + let ops = self.ops; + let lhs = std::mem::replace(&mut self.lhs, LinearCombination::zero()); + let rhs = std::mem::replace(&mut self.rhs, LinearCombination::zero()); + + // Only enforce if we have accumulated something + if self.bits_used > 0 { + self.cs.enforce( + || format!("multieq {ops}"), + |_| lhs, + |lc| lc + CS::one(), + |_| rhs, + ); + self.ops += 1; + } + + self.bits_used = 0; + } +} + +impl, const K: usize> SmallMultiEq + for BatchingEq<'_, Scalar, CS, K> +{ + fn enforce_equal(&mut self, lhs: &LinearCombination, rhs: &LinearCombination) { + if self.bits_used >= K { + self.do_flush(); + } + + // Compute the coefficient: 2^bits_used + let coeff = Scalar::from(1u64 << self.bits_used); + + // Scale and accumulate + self.lhs = self.lhs.clone() + (coeff, lhs); + self.rhs = self.rhs.clone() + (coeff, rhs); + + self.bits_used += 1; + } + + fn flush(&mut self) { + self.do_flush(); + } + + fn addmany(&mut self, operands: &[SmallUInt32]) -> Result { + assert!(Scalar::NUM_BITS >= 64); + assert!(operands.len() >= 2); + assert!(operands.len() <= 10); + + // Check for all-constant case + if let Some(sum) = try_constant_sum(operands) { + return Ok(SmallUInt32::constant(sum)); + } + + // Create a unique namespace for this addmany call + let count = self.addmany_count; + self.addmany_count += 1; + self.cs.push_namespace(|| format!("add{count}")); + + // Use full addition (max coeff 2^34, fits i64) + let result = addmany::full(self, operands); + + self.cs.pop_namespace(); + result + } +} + +impl, const K: usize> Drop + for BatchingEq<'_, Scalar, CS, K> +{ + fn drop(&mut self) { + self.do_flush(); + } +} + +// Delegate ConstraintSystem to inner cs +impl, const K: usize> ConstraintSystem + for BatchingEq<'_, Scalar, CS, K> +{ + type Root = Self; + + fn one() -> Variable { + CS::one() + } + + fn alloc(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc(annotation, f) + } + + fn alloc_input(&mut self, annotation: A, f: F) -> Result + where + F: FnOnce() -> Result, + A: FnOnce() -> AR, + AR: Into, + { + self.cs.alloc_input(annotation, f) + } + + fn enforce(&mut self, annotation: A, a: LA, b: LB, c: LC) + where + A: FnOnce() -> AR, + AR: Into, + LA: FnOnce(LinearCombination) -> LinearCombination, + LB: FnOnce(LinearCombination) -> LinearCombination, + LC: FnOnce(LinearCombination) -> LinearCombination, + { + self.cs.enforce(annotation, a, b, c); + } + + fn push_namespace(&mut self, name_fn: N) + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.push_namespace(name_fn); + } + + fn pop_namespace(&mut self) { + self.cs.pop_namespace(); + } + + fn get_root(&mut self) -> &mut Self::Root { + self + } + + fn is_witness_generator(&self) -> bool { + self.cs.is_witness_generator() + } + + fn extend_inputs(&mut self, inputs: &[Scalar]) { + self.cs.extend_inputs(inputs); + } + + fn extend_aux(&mut self, aux: &[Scalar]) { + self.cs.extend_aux(aux); + } + + fn allocate_empty(&mut self, aux_n: usize, inputs_n: usize) -> (&mut [Scalar], &mut [Scalar]) { + self.cs.allocate_empty(aux_n, inputs_n) + } + + fn inputs_slice(&self) -> &[Scalar] { + self.cs.inputs_slice() + } + + fn aux_slice(&self) -> &[Scalar] { + self.cs.aux_slice() + } +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/// Try to compute sum as a constant if all operands are constant. 
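+///
+/// Illustrative only (the wrapping semantics match `addmany`):
+///
+/// ```ignore
+/// let a = SmallUInt32::constant(0xFFFF_FFFF);
+/// let b = SmallUInt32::constant(1);
+/// assert_eq!(try_constant_sum(&[a, b]), Some(0)); // wraps modulo 2^32
+/// ```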
+fn try_constant_sum(operands: &[SmallUInt32]) -> Option { + // Check if all operands have known values + if !operands.iter().all(|op| op.get_value().is_some()) { + return None; + } + + // Check if all bits are constant + let all_constant = operands.iter().all(|op| { + op.bits_le() + .iter() + .all(|b| matches!(b, bellpepper_core::boolean::Boolean::Constant(_))) + }); + + if all_constant { + let sum: u32 = operands + .iter() + .map(|op| op.get_value().unwrap()) + .fold(0u32, |a, b| a.wrapping_add(b)); + Some(sum) + } else { + None + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use bellpepper_core::test_cs::TestConstraintSystem; + use halo2curves::pasta::Fq; + + #[test] + fn test_no_batch_eq_basic() { + let mut cs = TestConstraintSystem::::new(); + + let a = cs.alloc(|| "a", || Ok(Fq::from(5u64))).unwrap(); + let b = cs.alloc(|| "b", || Ok(Fq::from(5u64))).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let lhs = LinearCombination::zero() + a; + let rhs = LinearCombination::zero() + b; + eq.enforce_equal(&lhs, &rhs); + } + + assert!(cs.is_satisfied()); + assert_eq!(cs.num_constraints(), 1); + } + + #[test] + fn test_batching_eq_basic() { + let mut cs = TestConstraintSystem::::new(); + + let a = cs.alloc(|| "a", || Ok(Fq::from(5u64))).unwrap(); + let b = cs.alloc(|| "b", || Ok(Fq::from(5u64))).unwrap(); + + { + let mut eq = BatchingEq::::new(&mut cs); + let lhs = LinearCombination::zero() + a; + let rhs = LinearCombination::zero() + b; + eq.enforce_equal(&lhs, &rhs); + } // Drop flushes + + assert!(cs.is_satisfied()); + assert_eq!(cs.num_constraints(), 1); + } + + #[test] + fn test_no_batch_eq_multiple() { + let mut cs = TestConstraintSystem::::new(); + + let a = cs.alloc(|| "a", || Ok(Fq::from(10u64))).unwrap(); + let b = cs.alloc(|| "b", || Ok(Fq::from(10u64))).unwrap(); + let c = cs.alloc(|| "c", || Ok(Fq::from(20u64))).unwrap(); + let d = cs.alloc(|| "d", || Ok(Fq::from(20u64))).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + eq.enforce_equal( + &(LinearCombination::zero() + a), + &(LinearCombination::zero() + b), + ); + eq.enforce_equal( + &(LinearCombination::zero() + c), + &(LinearCombination::zero() + d), + ); + } + + assert!(cs.is_satisfied()); + // NoBatchEq: 2 direct constraints + assert_eq!(cs.num_constraints(), 2); + } + + #[test] + fn test_batching_eq_multiple() { + let mut cs = TestConstraintSystem::::new(); + + let a = cs.alloc(|| "a", || Ok(Fq::from(10u64))).unwrap(); + let b = cs.alloc(|| "b", || Ok(Fq::from(10u64))).unwrap(); + let c = cs.alloc(|| "c", || Ok(Fq::from(20u64))).unwrap(); + let d = cs.alloc(|| "d", || Ok(Fq::from(20u64))).unwrap(); + + { + let mut eq = BatchingEq::::new(&mut cs); + eq.enforce_equal( + &(LinearCombination::zero() + a), + &(LinearCombination::zero() + b), + ); + eq.enforce_equal( + &(LinearCombination::zero() + c), + &(LinearCombination::zero() + d), + ); + } + + assert!(cs.is_satisfied()); + // BatchingEq<21>: 2 constraints batched into 1 + assert_eq!(cs.num_constraints(), 1); + } + + #[test] + fn test_batching_eq_flush_at_capacity() { + let mut cs = TestConstraintSystem::::new(); + + // Batch size 5 for testing + let n = 12; // 12 constraints = 2 full batches (5+5) + 2 remaining + + let vars: Vec<_> = (0..n) + .map(|i| { + cs.alloc(|| format!("v{i}"), || Ok(Fq::from(42u64))) + .unwrap() + }) + .collect(); + + let expected = cs.alloc(|| 
"expected", || Ok(Fq::from(42u64))).unwrap(); + + { + let mut eq = BatchingEq::::new(&mut cs); + let expected_lc = LinearCombination::zero() + expected; + + for v in &vars { + let lhs = LinearCombination::zero() + *v; + eq.enforce_equal(&lhs, &expected_lc); + } + } + + assert!(cs.is_satisfied()); + // 12 constraints with batch size 5: 5+5+2 = 3 batched constraints + assert_eq!(cs.num_constraints(), 3); + } + + #[test] + fn test_no_batch_eq_many() { + let mut cs = TestConstraintSystem::::new(); + + let n = 10; + + let vars: Vec<_> = (0..n) + .map(|i| { + cs.alloc(|| format!("v{i}"), || Ok(Fq::from(42u64))) + .unwrap() + }) + .collect(); + + let expected = cs.alloc(|| "expected", || Ok(Fq::from(42u64))).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let expected_lc = LinearCombination::zero() + expected; + + for v in &vars { + let lhs = LinearCombination::zero() + *v; + eq.enforce_equal(&lhs, &expected_lc); + } + } + + assert!(cs.is_satisfied()); + // NoBatchEq: 10 direct constraints + assert_eq!(cs.num_constraints(), 10); + } + + #[test] + fn test_batching_eq_unsatisfied() { + let mut cs = TestConstraintSystem::::new(); + + let a = cs.alloc(|| "a", || Ok(Fq::from(5u64))).unwrap(); + let b = cs.alloc(|| "b", || Ok(Fq::from(10u64))).unwrap(); + + { + let mut eq = BatchingEq::::new(&mut cs); + eq.enforce_equal( + &(LinearCombination::zero() + a), + &(LinearCombination::zero() + b), + ); + } + + assert!(!cs.is_satisfied()); + } + + #[test] + fn test_no_batch_eq_addmany() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(100)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(200)).unwrap(); + let c = SmallUInt32::alloc(cs.namespace(|| "c"), Some(300)).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let result = eq.addmany(&[a, b, c]).unwrap(); + assert_eq!(result.get_value(), Some(600)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_batching_eq_addmany() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(100)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(200)).unwrap(); + let c = SmallUInt32::alloc(cs.namespace(|| "c"), Some(300)).unwrap(); + + { + let mut eq = BatchingEq::::new(&mut cs); + let result = eq.addmany(&[a, b, c]).unwrap(); + assert_eq!(result.get_value(), Some(600)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_addmany_overflow() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(0xFFFFFFFF)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(1)).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let result = eq.addmany(&[a, b]).unwrap(); + // Should wrap to 0 + assert_eq!(result.get_value(), Some(0)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_addmany_5_operands() { + // SHA-256 uses 5-operand addition + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(0x12345678)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(0x87654321)).unwrap(); + let c = SmallUInt32::alloc(cs.namespace(|| "c"), Some(0xDEADBEEF)).unwrap(); + let d = SmallUInt32::alloc(cs.namespace(|| "d"), Some(0xCAFEBABE)).unwrap(); + let e = SmallUInt32::alloc(cs.namespace(|| "e"), Some(0x01020304)).unwrap(); + + let expected = 0x12345678u32 + .wrapping_add(0x87654321) + .wrapping_add(0xDEADBEEF) + .wrapping_add(0xCAFEBABE) + .wrapping_add(0x01020304); + 
+ { + let mut eq = NoBatchEq::::new(&mut cs); + let result = eq.addmany(&[a, b, c, d, e]).unwrap(); + assert_eq!(result.get_value(), Some(expected)); + } + + assert!(cs.is_satisfied()); + } +} diff --git a/src/gadgets/small_sha256.rs b/src/gadgets/small_sha256.rs new file mode 100644 index 0000000..7bec5ac --- /dev/null +++ b/src/gadgets/small_sha256.rs @@ -0,0 +1,504 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! SHA-256 circuit using small-value compatible gadgets. +//! +//! This module provides a SHA-256 implementation that is compatible with the +//! small-value sumcheck optimization. Unlike bellpepper's SHA-256 which uses +//! `MultiEq` and can create coefficients up to 2^237, this implementation uses +//! `SmallMultiEq` which either enforces directly or batches with bounded coefficients. +//! +//! # Usage +//! +//! ```ignore +//! use spartan2::gadgets::small_sha256; +//! +//! // Simple API (recommended) - uses BatchingEq<21> for optimal performance +//! let hash_bits = small_sha256(&mut cs, &input_bits)?; +//! +//! // With prefix for hash chains +//! let hash_bits = small_sha256_with_prefix(&mut cs, &input_bits, "block0_")?; +//! +//! // Advanced API - bring your own SmallMultiEq +//! let mut eq = NoBatchEq::::new(&mut cs); +//! let hash_bits = small_sha256_with_small_multi_eq(&mut eq, &input_bits, "")?; +//! ``` + +use super::{BatchingEq, SmallMultiEq, SmallUInt32}; +use bellpepper_core::{ConstraintSystem, SynthesisError, boolean::Boolean}; +use ff::PrimeField; + +/// SHA-256 round constants K[0..63]. +const ROUND_CONSTANTS: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +/// SHA-256 initial hash values H[0..7]. 
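+/// (These are the first 32 bits of the fractional parts of the square roots of
+/// the first eight primes, as specified in FIPS 180-4.)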
+const IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +/// Σ0(x) = ROTR^2(x) ⊕ ROTR^13(x) ⊕ ROTR^22(x) +fn big_sigma_0>( + mut cs: CS, + x: &SmallUInt32, +) -> Result { + let r2 = x.rotr(2); + let r13 = x.rotr(13); + let r22 = x.rotr(22); + let tmp = r2.xor(cs.namespace(|| "sigma0_r2_xor_r13"), &r13)?; + tmp.xor(cs.namespace(|| "sigma0_xor_r22"), &r22) +} + +/// Σ1(x) = ROTR^6(x) ⊕ ROTR^11(x) ⊕ ROTR^25(x) +fn big_sigma_1>( + mut cs: CS, + x: &SmallUInt32, +) -> Result { + let r6 = x.rotr(6); + let r11 = x.rotr(11); + let r25 = x.rotr(25); + let tmp = r6.xor(cs.namespace(|| "sigma1_r6_xor_r11"), &r11)?; + tmp.xor(cs.namespace(|| "sigma1_xor_r25"), &r25) +} + +/// σ0(x) = ROTR^7(x) ⊕ ROTR^18(x) ⊕ SHR^3(x) +fn small_sigma_0>( + mut cs: CS, + x: &SmallUInt32, +) -> Result { + let r7 = x.rotr(7); + let r18 = x.rotr(18); + let s3 = x.shr(3); + let tmp = r7.xor(cs.namespace(|| "s0_r7_xor_r18"), &r18)?; + tmp.xor(cs.namespace(|| "s0_xor_s3"), &s3) +} + +/// σ1(x) = ROTR^17(x) ⊕ ROTR^19(x) ⊕ SHR^10(x) +fn small_sigma_1>( + mut cs: CS, + x: &SmallUInt32, +) -> Result { + let r17 = x.rotr(17); + let r19 = x.rotr(19); + let s10 = x.shr(10); + let tmp = r17.xor(cs.namespace(|| "s1_r17_xor_r19"), &r19)?; + tmp.xor(cs.namespace(|| "s1_xor_s10"), &s10) +} + +/// SHA-256 compression function. +/// +/// Takes the current hash state H and a 512-bit message block W, +/// returns the updated hash state. +/// +/// The `prefix` is prepended to all variable names to allow multiple SHA-256 +/// calls in the same constraint system (e.g., for hash chains). +fn sha256_compression( + cs: &mut M, + h: &mut [SmallUInt32; 8], + w: &[SmallUInt32; 16], + block_idx: usize, + prefix: &str, +) -> Result<(), SynthesisError> +where + Scalar: PrimeField, + M: SmallMultiEq, +{ + // Message schedule: expand 16 words to 64 words + let mut w_expanded: Vec = w.to_vec(); + w_expanded.reserve(48); + + for i in 16..64 { + // W[i] = σ1(W[i-2]) + W[i-7] + σ0(W[i-15]) + W[i-16] + let s1 = small_sigma_1( + cs.namespace(|| format!("{}b{}_w{}_s1", prefix, block_idx, i)), + &w_expanded[i - 2], + )?; + let s0 = small_sigma_0( + cs.namespace(|| format!("{}b{}_w{}_s0", prefix, block_idx, i)), + &w_expanded[i - 15], + )?; + + let wi = cs.addmany(&[ + s1, + w_expanded[i - 7].clone(), + s0, + w_expanded[i - 16].clone(), + ])?; + w_expanded.push(wi); + } + + // Initialize working variables + let (mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h_var) = ( + h[0].clone(), + h[1].clone(), + h[2].clone(), + h[3].clone(), + h[4].clone(), + h[5].clone(), + h[6].clone(), + h[7].clone(), + ); + + // 64 rounds + for i in 0..64 { + // T1 = h + Σ1(e) + Ch(e,f,g) + K[i] + W[i] + let sigma1 = big_sigma_1( + cs.namespace(|| format!("{}b{}_r{}_sigma1", prefix, block_idx, i)), + &e, + )?; + let ch = SmallUInt32::sha256_ch( + cs.namespace(|| format!("{}b{}_r{}_ch", prefix, block_idx, i)), + &e, + &f, + &g, + )?; + let k = SmallUInt32::constant(ROUND_CONSTANTS[i]); + + let t1 = cs.addmany(&[h_var.clone(), sigma1, ch, k, w_expanded[i].clone()])?; + + // T2 components: Σ0(a) and Maj(a,b,c) + // Instead of computing T2 = sigma0 + maj separately, we fuse it into 'a' below. 
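+    // In the standard round (FIPS 180-4): T2 = Σ0(a) + Maj(a, b, c),
+    // a_new = T1 + T2 and e_new = d + T1, so the `a = T1 + sigma0 + maj`
+    // update below is equivalent, just with one fewer addmany call.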
+ let sigma0 = big_sigma_0( + cs.namespace(|| format!("{}b{}_r{}_sigma0", prefix, block_idx, i)), + &a, + )?; + let maj = SmallUInt32::sha256_maj( + cs.namespace(|| format!("{}b{}_r{}_maj", prefix, block_idx, i)), + &a, + &b, + &c, + )?; + + // Update working variables + h_var = g; + g = f; + f = e; + e = cs.addmany(&[d, t1.clone()])?; + d = c; + c = b; + b = a; + // Fused: a = T1 + T2 = T1 + sigma0 + maj (saves one addmany call per round) + a = cs.addmany(&[t1, sigma0, maj])?; + } + + // Compute final hash values + h[0] = cs.addmany(&[h[0].clone(), a])?; + h[1] = cs.addmany(&[h[1].clone(), b])?; + h[2] = cs.addmany(&[h[2].clone(), c])?; + h[3] = cs.addmany(&[h[3].clone(), d])?; + h[4] = cs.addmany(&[h[4].clone(), e])?; + h[5] = cs.addmany(&[h[5].clone(), f])?; + h[6] = cs.addmany(&[h[6].clone(), g])?; + h[7] = cs.addmany(&[h[7].clone(), h_var])?; + + Ok(()) +} + +/// Compute SHA-256 hash of input bits. +/// +/// Uses `BatchingEq<21>` internally for optimal performance with i64 small values. +/// This batches up to 21 equality constraints and uses full 35-bit addition. +/// +/// Returns 256 bits of the hash in big-endian order. +/// +/// # Example +/// +/// ```ignore +/// let mut cs = TestConstraintSystem::::new(); +/// let hash_bits = small_sha256::(&mut cs, &input_bits)?; +/// ``` +pub fn small_sha256( + cs: &mut CS, + input: &[Boolean], +) -> Result, SynthesisError> +where + Scalar: PrimeField, + CS: ConstraintSystem, +{ + small_sha256_with_prefix::(cs, input, "") +} + +/// Compute SHA-256 hash with a prefix for variable names. +/// +/// Uses `BatchingEq<21>` internally for optimal performance. +/// This variant allows multiple SHA-256 computations in the same constraint +/// system (e.g., for hash chains) by prefixing all internal variable names. +/// +/// # Example +/// +/// ```ignore +/// // Hash chain: H(H(H(x))) +/// let h1 = small_sha256_with_prefix(&mut cs, &input, "hash1_")?; +/// let h2 = small_sha256_with_prefix(&mut cs, &h1, "hash2_")?; +/// let h3 = small_sha256_with_prefix(&mut cs, &h2, "hash3_")?; +/// ``` +pub fn small_sha256_with_prefix( + cs: &mut CS, + input: &[Boolean], + prefix: &str, +) -> Result, SynthesisError> +where + Scalar: PrimeField, + CS: ConstraintSystem, +{ + // Push namespace to scope constraints under the prefix + cs.push_namespace(|| format!("{}sha256", prefix)); + + // Create BatchingEq<21> for optimal performance + let mut eq = BatchingEq::::new(cs); + let result = small_sha256_with_small_multi_eq(&mut eq, input, prefix); + drop(eq); // Flush any pending constraints + + // Pop the namespace before returning + cs.pop_namespace(); + + result +} + +/// Compute SHA-256 hash using a custom `SmallMultiEq` implementation. +/// +/// This is the advanced API that gives full control over the batching strategy. 
+/// Use this when you need: +/// - `NoBatchEq` for i32 small value compatibility +/// - Custom `BatchingEq` with a different batch size +/// +/// # Example +/// +/// ```ignore +/// // Using NoBatchEq for i32 compatibility +/// let mut eq = NoBatchEq::::new(&mut cs); +/// let hash = small_sha256_with_small_multi_eq(&mut eq, &input, "")?; +/// +/// // Using custom BatchingEq<10> +/// let mut eq = BatchingEq::::new(&mut cs); +/// let hash = small_sha256_with_small_multi_eq(&mut eq, &input, "")?; +/// ``` +pub fn small_sha256_with_small_multi_eq( + eq: &mut M, + input: &[Boolean], + prefix: &str, +) -> Result, SynthesisError> +where + Scalar: PrimeField, + M: SmallMultiEq, +{ + // Pad the input according to SHA-256 spec + let padded = sha256_padding(input); + + // Process in 512-bit blocks + assert!(padded.len().is_multiple_of(512)); + let num_blocks = padded.len() / 512; + + // Initialize hash state + let mut h: [SmallUInt32; 8] = IV.map(SmallUInt32::constant); + + for block_idx in 0..num_blocks { + let block_start = block_idx * 512; + let block_bits = &padded[block_start..block_start + 512]; + + // Convert 512 bits to 16 32-bit words (big-endian) + let mut w: [SmallUInt32; 16] = std::array::from_fn(|_| SmallUInt32::constant(0)); + for (i, w_item) in w.iter_mut().enumerate() { + let word_bits: [Boolean; 32] = block_bits[i * 32..(i + 1) * 32] + .to_vec() + .try_into() + .unwrap(); + *w_item = SmallUInt32::from_bits_be(&word_bits); + } + + // Run compression + sha256_compression(eq, &mut h, &w, block_idx, prefix)?; + } + + // Collect output bits in big-endian order + let mut output = Vec::with_capacity(256); + for h_i in h { + output.extend(h_i.into_bits_be()); + } + + Ok(output) +} + +/// SHA-256 padding: append 1 bit, zeros, and 64-bit length. +fn sha256_padding(input: &[Boolean]) -> Vec { + let msg_len = input.len(); + + // Calculate padded length: message + 1 + zeros + 64-bit length + // Must be multiple of 512 + let mut padded_len = msg_len + 1 + 64; // message + '1' bit + length + if !padded_len.is_multiple_of(512) { + padded_len += 512 - (padded_len % 512); + } + + let mut padded = Vec::with_capacity(padded_len); + + // Copy message bits + padded.extend_from_slice(input); + + // Append '1' bit + padded.push(Boolean::constant(true)); + + // Append zeros + let zero_count = padded_len - msg_len - 1 - 64; + for _ in 0..zero_count { + padded.push(Boolean::constant(false)); + } + + // Append 64-bit length (big-endian) + let len_bits: u64 = msg_len as u64; + for i in (0..64).rev() { + padded.push(Boolean::constant((len_bits >> i) & 1 == 1)); + } + + assert_eq!(padded.len(), padded_len); + assert!(padded.len() % 512 == 0); + + padded +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::gadgets::NoBatchEq; + use bellpepper_core::test_cs::TestConstraintSystem; + use halo2curves::pasta::Fq; + use rand::{Rng, SeedableRng, rngs::StdRng}; + use sha2::{Digest, Sha256}; + + /// Convert bytes to Boolean bits (big-endian per byte). + fn bytes_to_bits(bytes: &[u8]) -> Vec { + bytes + .iter() + .flat_map(|byte| { + (0..8) + .rev() + .map(move |i| Boolean::constant((byte >> i) & 1 == 1)) + }) + .collect() + } + + /// Convert Boolean bits to bytes (big-endian per byte). 
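+  /// For example, the eight bits `[0,1,0,0,0,0,0,1]` (MSB first) become the byte `0x41`.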
+ fn bits_to_bytes(bits: &[Boolean]) -> Vec { + assert!(bits.len().is_multiple_of(8)); + bits + .chunks(8) + .map(|chunk| { + chunk.iter().fold(0u8, |acc, bit| { + let b = match bit { + Boolean::Constant(b) => *b, + Boolean::Is(ab) => ab.get_value().unwrap(), + Boolean::Not(ab) => !ab.get_value().unwrap(), + }; + (acc << 1) | (b as u8) + }) + }) + .collect() + } + + #[test] + fn test_small_sha256_empty() { + let mut cs = TestConstraintSystem::::new(); + + // SHA256("") = e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 + let input: Vec = vec![]; + let hash_bits = small_sha256::(&mut cs, &input).unwrap(); + + let hash_bytes = bits_to_bytes(&hash_bits); + let expected = Sha256::digest(b""); + + assert_eq!(&hash_bytes[..], &expected[..]); + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_sha256_abc() { + let mut cs = TestConstraintSystem::::new(); + + // SHA256("abc") = ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad + let input = bytes_to_bits(b"abc"); + let hash_bits = small_sha256::(&mut cs, &input).unwrap(); + + let hash_bytes = bits_to_bytes(&hash_bits); + let expected = Sha256::digest(b"abc"); + + assert_eq!(&hash_bytes[..], &expected[..]); + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_sha256_matches_native_32_times() { + // Use seeded RNG for reproducibility + let mut rng = StdRng::seed_from_u64(12345); + + for i in 0..32 { + let mut cs = TestConstraintSystem::::new(); + + // Random preimage length: 1 to 128 bytes + let len = rng.gen_range(1..=128); + let preimage: Vec = (0..len).map(|_| rng.r#gen()).collect(); + + // Native SHA-256 + let expected = Sha256::digest(&preimage); + + // Circuit SHA-256 (now uses BatchingEq<21> by default) + let input_bits = bytes_to_bits(&preimage); + let hash_bits = small_sha256::(&mut cs, &input_bits).unwrap(); + let hash_bytes = bits_to_bytes(&hash_bits); + + assert_eq!( + &hash_bytes[..], + &expected[..], + "Mismatch at iteration {}, preimage len {}", + i, + len + ); + assert!(cs.is_satisfied(), "CS not satisfied at iteration {}", i); + } + } + + #[test] + fn test_small_sha256_with_small_multi_eq_nobatch() { + // Test the advanced API with NoBatchEq + let mut rng = StdRng::seed_from_u64(54321); + + for i in 0..8 { + let mut cs = TestConstraintSystem::::new(); + + let len = rng.gen_range(1..=64); + let preimage: Vec = (0..len).map(|_| rng.r#gen()).collect(); + + let expected = Sha256::digest(&preimage); + + let input_bits = bytes_to_bits(&preimage); + + // Use NoBatchEq via the advanced API + let mut eq = NoBatchEq::::new(&mut cs); + let hash_bits = small_sha256_with_small_multi_eq(&mut eq, &input_bits, "").unwrap(); + #[allow(clippy::drop_non_drop)] + // Intentional: signals "done with eq" for consistency with BatchingEq + drop(eq); + + let hash_bytes = bits_to_bytes(&hash_bits); + + assert_eq!( + &hash_bytes[..], + &expected[..], + "Mismatch at iteration {}, preimage len {}", + i, + len + ); + assert!(cs.is_satisfied(), "CS not satisfied at iteration {}", i); + } + } +} diff --git a/src/gadgets/small_uint32.rs b/src/gadgets/small_uint32.rs new file mode 100644 index 0000000..f4fd5ce --- /dev/null +++ b/src/gadgets/small_uint32.rs @@ -0,0 +1,371 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! SmallUInt32: 32-bit unsigned integer gadget for small-value sumcheck. +//! +//! 
This is a port of bellpepper's `UInt32` with bit operations for SHA-256. +//! Addition is handled externally via the `SmallMultiEq` trait's `addmany` method. +//! +//! # SHA256 Operations +//! +//! | Operation | Constraints | +//! |-----------|-------------| +//! | `xor()` | Delegates to Boolean | +//! | `rotr()` | No constraints - just reorders bits | +//! | `shr()` | No constraints - inserts zero bits | +//! | `sha256_ch/maj()` | Uses AND/XOR | + +use bellpepper_core::{ + ConstraintSystem, SynthesisError, + boolean::{AllocatedBit, Boolean}, +}; +use ff::PrimeField; + +/// A 32-bit unsigned integer for circuits with small-value optimization. +#[derive(Clone, Debug)] +pub struct SmallUInt32 { + /// Little-endian bit representation + bits: [Boolean; 32], + /// Cached value (if known) + value: Option, +} + +impl SmallUInt32 { + /// Construct a `SmallUInt32` from a `Boolean` array. + /// Bits are in little-endian order. + pub fn from_bits_le(bits: &[Boolean; 32]) -> Self { + let value = bits.iter().rev().try_fold(0u32, |acc, bit| { + bit + .get_value() + .map(|b| if b { (acc << 1) | 1 } else { acc << 1 }) + }); + + SmallUInt32 { + bits: bits.clone(), + value, + } + } + + /// Construct a `SmallUInt32` from a `Boolean` array in big-endian order. + pub fn from_bits_be(bits: &[Boolean; 32]) -> Self { + let mut bits_le = bits.clone(); + bits_le.reverse(); + Self::from_bits_le(&bits_le) + } + + /// Get the bits in little-endian order. + pub fn bits_le(&self) -> &[Boolean; 32] { + &self.bits + } + + /// Get the bits in big-endian order. + pub fn into_bits_be(self) -> [Boolean; 32] { + let mut bits = self.bits; + bits.reverse(); + bits + } + + /// Get the value if known. + pub fn get_value(&self) -> Option { + self.value + } + + /// Create a constant `SmallUInt32`. + pub fn constant(value: u32) -> Self { + let bits: [Boolean; 32] = std::array::from_fn(|i| Boolean::constant((value >> i) & 1 == 1)); + + SmallUInt32 { + bits, + value: Some(value), + } + } + + /// Allocate a `SmallUInt32` in the constraint system. + pub fn alloc(mut cs: CS, value: Option) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut bits = [const { Boolean::Constant(false) }; 32]; + for (i, slot) in bits.iter_mut().enumerate() { + *slot = Boolean::from(AllocatedBit::alloc( + cs.namespace(|| format!("b{i}")), + value.map(|v| (v >> i) & 1 == 1), + )?); + } + Ok(SmallUInt32 { bits, value }) + } + + /// Right rotation. + pub fn rotr(&self, by: usize) -> Self { + let by = by % 32; + let bits: [Boolean; 32] = std::array::from_fn(|i| self.bits[(i + by) % 32].clone()); + + SmallUInt32 { + bits, + value: self.value.map(|v| v.rotate_right(by as u32)), + } + } + + /// Right shift. + pub fn shr(&self, by: usize) -> Self { + let bits: [Boolean; 32] = std::array::from_fn(|i| { + if i + by < 32 { + self.bits[i + by].clone() + } else { + Boolean::constant(false) + } + }); + + SmallUInt32 { + bits, + value: self.value.map(|v| v >> by), + } + } + + /// XOR with another `SmallUInt32`. 
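+  ///
+  /// # Example
+  ///
+  /// Sketch mirroring the unit test for `xor` (the constraint system `cs` is assumed
+  /// to be in scope):
+  ///
+  /// ```ignore
+  /// let a = SmallUInt32::constant(0xAAAA_AAAA);
+  /// let b = SmallUInt32::constant(0x5555_5555);
+  /// let c = a.xor(&mut cs, &b)?;
+  /// assert_eq!(c.get_value(), Some(0xFFFF_FFFF));
+  /// ```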
+ pub fn xor(&self, mut cs: CS, other: &Self) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut bits = [const { Boolean::Constant(false) }; 32]; + for (i, (slot, (a, b))) in bits + .iter_mut() + .zip(self.bits.iter().zip(other.bits.iter())) + .enumerate() + { + *slot = Boolean::xor(cs.namespace(|| format!("b{i}")), a, b)?; + } + + Ok(SmallUInt32 { + bits, + value: self.value.and_then(|a| other.value.map(|b| a ^ b)), + }) + } + + /// SHA-256 CH function: (a AND b) XOR ((NOT a) AND c) + pub fn sha256_ch( + mut cs: CS, + a: &Self, + b: &Self, + c: &Self, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut bits = [const { Boolean::Constant(false) }; 32]; + for (i, (slot, ((a_bit, b_bit), c_bit))) in bits + .iter_mut() + .zip(a.bits.iter().zip(b.bits.iter()).zip(c.bits.iter())) + .enumerate() + { + *slot = Boolean::sha256_ch(cs.namespace(|| format!("b{i}")), a_bit, b_bit, c_bit)?; + } + + Ok(SmallUInt32 { + bits, + value: a + .value + .and_then(|a| b.value.and_then(|b| c.value.map(|c| (a & b) ^ ((!a) & c)))), + }) + } + + /// SHA-256 MAJ function: (a AND b) XOR (a AND c) XOR (b AND c) + /// + /// Optimized identity: Maj(a,b,c) = (a & b) ^ (c & (a ^ b)) + /// This uses 2 AND + 2 XOR per bit instead of 3 AND + 2 XOR. + pub fn sha256_maj( + mut cs: CS, + a: &Self, + b: &Self, + c: &Self, + ) -> Result + where + Scalar: PrimeField, + CS: ConstraintSystem, + { + let mut bits = [const { Boolean::Constant(false) }; 32]; + for (i, (slot, ((a_bit, b_bit), c_bit))) in bits + .iter_mut() + .zip(a.bits.iter().zip(b.bits.iter()).zip(c.bits.iter())) + .enumerate() + { + let mut bit_cs = cs.namespace(|| format!("b{i}")); + // Optimized: Maj(a,b,c) = (a & b) ^ (c & (a ^ b)) + let t = Boolean::xor(bit_cs.namespace(|| "xor_ab"), a_bit, b_bit)?; + let u = Boolean::and(bit_cs.namespace(|| "and_c_t"), c_bit, &t)?; + let v = Boolean::and(bit_cs.namespace(|| "and_ab"), a_bit, b_bit)?; + *slot = Boolean::xor(bit_cs.namespace(|| "xor_vu"), &v, &u)?; + } + + Ok(SmallUInt32 { + bits, + value: a.value.and_then(|a| { + b.value + .and_then(|b| c.value.map(|c| (a & b) ^ (a & c) ^ (b & c))) + }), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::gadgets::{BatchingEq, NoBatchEq, SmallMultiEq}; + use bellpepper_core::test_cs::TestConstraintSystem; + use halo2curves::pasta::Fq; + + #[test] + fn test_small_uint32_constant() { + let u = SmallUInt32::constant(0x12345678); + assert_eq!(u.get_value(), Some(0x12345678)); + } + + #[test] + fn test_small_uint32_rotr() { + let u = SmallUInt32::constant(0x80000001); + let rotated = u.rotr(1); + assert_eq!(rotated.get_value(), Some(0xC0000000)); + } + + #[test] + fn test_small_uint32_shr() { + let u = SmallUInt32::constant(0x80000000); + let shifted = u.shr(1); + assert_eq!(shifted.get_value(), Some(0x40000000)); + } + + #[test] + fn test_small_uint32_xor() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::constant(0xAAAAAAAA); + let b = SmallUInt32::constant(0x55555555); + + let result = a.xor(&mut cs, &b).unwrap(); + assert_eq!(result.get_value(), Some(0xFFFFFFFF)); + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_uint32_addmany_via_trait() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(100)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(200)).unwrap(); + let c = SmallUInt32::alloc(cs.namespace(|| "c"), Some(300)).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let result = 
eq.addmany(&[a, b, c]).unwrap(); + assert_eq!(result.get_value(), Some(600)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_uint32_addmany_overflow() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(0xFFFFFFFF)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(1)).unwrap(); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let result = eq.addmany(&[a, b]).unwrap(); + // Should wrap to 0 + assert_eq!(result.get_value(), Some(0)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_uint32_addmany_batching() { + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(100)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(200)).unwrap(); + let c = SmallUInt32::alloc(cs.namespace(|| "c"), Some(300)).unwrap(); + + { + // Use BatchingEq<21> (full 35-bit addition path) + let mut eq = BatchingEq::::new(&mut cs); + let result = eq.addmany(&[a, b, c]).unwrap(); + assert_eq!(result.get_value(), Some(600)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_uint32_addmany_5_operands() { + // SHA-256 uses 5-operand addition + let mut cs = TestConstraintSystem::::new(); + + let a = SmallUInt32::alloc(cs.namespace(|| "a"), Some(0x12345678)).unwrap(); + let b = SmallUInt32::alloc(cs.namespace(|| "b"), Some(0x87654321)).unwrap(); + let c = SmallUInt32::alloc(cs.namespace(|| "c"), Some(0xDEADBEEF)).unwrap(); + let d = SmallUInt32::alloc(cs.namespace(|| "d"), Some(0xCAFEBABE)).unwrap(); + let e = SmallUInt32::alloc(cs.namespace(|| "e"), Some(0x01020304)).unwrap(); + + let expected = 0x12345678u32 + .wrapping_add(0x87654321) + .wrapping_add(0xDEADBEEF) + .wrapping_add(0xCAFEBABE) + .wrapping_add(0x01020304); + + { + let mut eq = NoBatchEq::::new(&mut cs); + let result = eq.addmany(&[a, b, c, d, e]).unwrap(); + assert_eq!(result.get_value(), Some(expected)); + } + + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_uint32_sha256_ch() { + let mut cs = TestConstraintSystem::::new(); + + // ch(a, b, c) = (a & b) ^ ((!a) & c) + let a = SmallUInt32::constant(0xFF00FF00); + let b = SmallUInt32::constant(0xF0F0F0F0); + let c = SmallUInt32::constant(0x0F0F0F0F); + + let result = SmallUInt32::sha256_ch(&mut cs, &a, &b, &c).unwrap(); + + // Expected: (0xFF00FF00 & 0xF0F0F0F0) ^ ((~0xFF00FF00) & 0x0F0F0F0F) + // = 0xF000F000 ^ (0x00FF00FF & 0x0F0F0F0F) + // = 0xF000F000 ^ 0x000F000F + // = 0xF00FF00F + assert_eq!(result.get_value(), Some(0xF00FF00F)); + assert!(cs.is_satisfied()); + } + + #[test] + fn test_small_uint32_sha256_maj() { + let mut cs = TestConstraintSystem::::new(); + + // maj(a, b, c) = (a & b) ^ (a & c) ^ (b & c) + let a = SmallUInt32::constant(0xFF00FF00); + let b = SmallUInt32::constant(0xF0F0F0F0); + let c = SmallUInt32::constant(0x0F0F0F0F); + + let result = SmallUInt32::sha256_maj(&mut cs, &a, &b, &c).unwrap(); + + // Expected: (a & b) ^ (a & c) ^ (b & c) + // = (0xFF00FF00 & 0xF0F0F0F0) ^ (0xFF00FF00 & 0x0F0F0F0F) ^ (0xF0F0F0F0 & 0x0F0F0F0F) + // = 0xF000F000 ^ 0x0F000F00 ^ 0x00000000 + // = 0xFF00FF00 + let expected = 0xFF00FF00u32; + assert_eq!(result.get_value(), Some(expected)); + assert!(cs.is_satisfied()); + } +} diff --git a/src/lagrange_accumulator/accumulator.rs b/src/lagrange_accumulator/accumulator.rs new file mode 100644 index 0000000..969e66b --- /dev/null +++ b/src/lagrange_accumulator/accumulator.rs @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Accumulator data structures for Algorithm 6 small-value sumcheck optimization. +//! +//! This module defines: +//! - [`RoundAccumulator`]: Single round accumulator A_i(v, u) with flat storage +//! - [`LagrangeAccumulator`]: Collection of accumulators for all ℓ₀ rounds + +use super::{basis::LagrangeCoeff, evals::LagrangeHatEvals}; +use ff::PrimeField; +use std::ops::AddAssign; + +#[cfg(test)] +use super::domain::{LagrangeHatPoint, LagrangeIndex}; + +/// A single round's accumulator A_i(v, u) with flat contiguous storage. +/// +/// For round i (0-indexed), this stores: +/// - (D+1)^i prefixes (one per v ∈ U_D^i) +/// - Each prefix has D values (one per u ∈ Û_D = {∞, 0, 2, ..., D-1}) +/// +/// Storage: `Vec<[T; D]>` — one allocation, contiguous +/// Access: data[v_idx][u_idx] +/// +/// This flat storage design provides: +/// - Cache-friendly memory access patterns +/// - Vectorizable merge operations +/// - No runtime bounds checks on inner dimension (compile-time D) +pub struct RoundAccumulator { + /// Flat storage: data[v_idx] = [A_i(v, ∞), A_i(v, 0), A_i(v, 2), ...] + data: Vec<[T; D]>, +} + +impl RoundAccumulator { + /// Base of the Lagrange domain U_D (compile-time constant) + const BASE: usize = D + 1; + + /// Create a new accumulator for the given round (0-indexed). + /// + /// Allocates (D+1)^round prefix entries, each with D values. + /// Uses `Default::default()` which is zero for field elements. + pub fn new(round: usize) -> Self { + let num_prefixes = Self::BASE.pow(round as u32); + Self { + data: vec![[T::default(); D]; num_prefixes], + } + } + + /// O(1) indexed accumulation into bucket (v_idx, u_idx). + #[inline] + pub fn accumulate(&mut self, v_idx: usize, u_idx: usize, value: T) + where + T: AddAssign, + { + self.data[v_idx][u_idx] += value; + } + + /// Element-wise merge (tight loop, compiler can vectorize). + /// + /// Used in the reduce phase of parallel fold-reduce. + pub fn merge(&mut self, other: &Self) + where + T: AddAssign, + { + for (a, b) in self.data.iter_mut().zip(&other.data) { + for i in 0..D { + a[i] += b[i]; + } + } + } + + /// Get direct access to data slice. + #[inline] + pub fn data(&self) -> &[[T; D]] { + &self.data + } + + /// Get mutable direct access to data slice. + #[inline] + pub fn data_mut(&mut self) -> &mut [[T; D]] { + &mut self.data + } +} + +/// Sumcheck-specific methods for RoundAccumulator with field elements. +impl RoundAccumulator { + /// Evaluate t_i(u) for all u ∈ Û_D in a single pass. + pub fn eval_t_all_u(&self, coeff: &LagrangeCoeff) -> LagrangeHatEvals { + debug_assert_eq!(self.data.len(), coeff.len()); + let mut acc = [Scalar::ZERO; D]; + for (c, row) in coeff.as_slice().iter().zip(self.data.iter()) { + let scaled = *c; + for i in 0..D { + acc[i] += scaled * row[i]; + } + } + LagrangeHatEvals::from_array(acc) + } +} + +/// Test-only helper methods for RoundAccumulator. +#[cfg(test)] +impl RoundAccumulator { + /// O(1) indexed read from bucket (v_idx, u_idx). + #[inline] + pub fn get(&self, v_idx: usize, u_idx: usize) -> Scalar { + self.data[v_idx][u_idx] + } + + /// Accumulate by domain types (type-safe path). 
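+  ///
+  /// Sketch mirroring `test_round_accumulator_domain_methods` (here `D = 3`, so a
+  /// round-1 accumulator has 4 prefixes and the flat order is ∞=0, 0=1, 1=2, 2=3):
+  ///
+  /// ```ignore
+  /// let v = LagrangeIndex::<3>(vec![LagrangePoint::Finite(1)]); // flat index 2
+  /// acc.accumulate_by_domain(&v, LagrangeHatPoint::<3>::Infinity, Scalar::from(42u64));
+  /// assert_eq!(acc.get(2, 0), Scalar::from(42u64)); // same bucket via raw indices
+  /// ```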
+ #[inline] + pub fn accumulate_by_domain( + &mut self, + v: &LagrangeIndex, + u: LagrangeHatPoint, + value: Scalar, + ) { + let v_idx = v.to_flat_index(); + let u_idx = u.to_index(); + self.data[v_idx][u_idx] += value; + } + + /// Read by domain types (type-safe path). + #[inline] + pub fn get_by_domain(&self, v: &LagrangeIndex, u: LagrangeHatPoint) -> Scalar { + let v_idx = v.to_flat_index(); + let u_idx = u.to_index(); + self.data[v_idx][u_idx] + } + + /// Number of prefix entries. + #[inline] + pub fn num_prefixes(&self) -> usize { + self.data.len() + } +} + +/// Generic collection of accumulators for all ℓ₀ rounds. +/// +/// This struct can hold any element type T, enabling both: +/// - Reduced field elements (for final sumcheck usage) +/// - Unreduced wide-limb elements (during DMR-enabled accumulation) +/// +/// Type parameter D is the degree bound for t_i(X) (D=2 for Spartan). +pub struct LagrangeAccumulators { + /// rounds[i] contains A_{i+1} (the accumulator for 1-indexed round i+1) + pub rounds: Vec>, +} + +impl LagrangeAccumulators { + /// Create a fresh accumulator (used per-thread in fold). + /// + /// # Arguments + /// * `l0` - Number of rounds using small-value optimization + pub fn new(l0: usize) -> Self { + let rounds = (0..l0).map(RoundAccumulator::new).collect(); + Self { rounds } + } + + /// O(1) accumulation into bucket (round, v_idx, u_idx). + #[inline] + pub fn accumulate(&mut self, round: usize, v_idx: usize, u_idx: usize, value: T) + where + T: AddAssign, + { + self.rounds[round].accumulate(v_idx, u_idx, value); + } + + /// Merge another accumulator into this one (for reduce phase). + pub fn merge(&mut self, other: &Self) + where + T: AddAssign, + { + for (self_round, other_round) in self.rounds.iter_mut().zip(&other.rounds) { + self_round.merge(other_round); + } + } + + /// Number of rounds. + pub fn num_rounds(&self) -> usize { + self.rounds.len() + } + + /// Check if all elements are zero (equal to default). + #[allow(dead_code)] + pub fn is_all_zero(&self) -> bool + where + T: PartialEq, + { + let zero = T::default(); + self.rounds.iter().all(|round| { + round + .data() + .iter() + .all(|row| row.iter().all(|elem| *elem == zero)) + }) + } +} + +/// Sumcheck-specific methods for LagrangeAccumulator with field elements. +impl LagrangeAccumulators { + /// Get read-only access to a specific round's accumulator. + pub fn round(&self, i: usize) -> &RoundAccumulator { + &self.rounds[i] + } +} + +/// Test-only helper methods for LagrangeAccumulator. +#[cfg(test)] +impl LagrangeAccumulators { + /// Read A_i(v, u). + #[inline] + pub fn get(&self, round: usize, v_idx: usize, u_idx: usize) -> Scalar { + self.rounds[round].get(v_idx, u_idx) + } + + /// Accumulate by domain types (type-safe path). + #[inline] + pub fn accumulate_by_domain( + &mut self, + round: usize, + v: &LagrangeIndex, + u: LagrangeHatPoint, + value: Scalar, + ) { + self.rounds[round].accumulate_by_domain(v, u, value); + } + + /// Read A_i(v, u) by domain types (type-safe path). 
+ #[inline] + pub fn get_by_domain( + &self, + round: usize, + v: &LagrangeIndex, + u: LagrangeHatPoint, + ) -> Scalar { + self.rounds[round].get_by_domain(v, u) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{lagrange_accumulator::domain::LagrangePoint, provider::pasta::pallas}; + use ff::Field; + + type Scalar = pallas::Scalar; + + // Use D=2 for Spartan + const D: usize = 2; + + // === RoundAccumulator tests === + + #[test] + fn test_round_accumulator_new() { + // D=3: round 0 → 4^0=1 prefix, round 1 → 4^1=4 prefixes, round 2 → 4^2=16 prefixes + let acc0: RoundAccumulator = RoundAccumulator::new(0); + assert_eq!(acc0.num_prefixes(), 1); + + let acc1: RoundAccumulator = RoundAccumulator::new(1); + assert_eq!(acc1.num_prefixes(), 4); + + let acc2: RoundAccumulator = RoundAccumulator::new(2); + assert_eq!(acc2.num_prefixes(), 16); + } + + #[test] + fn test_round_accumulator_accumulate_get() { + let mut acc: RoundAccumulator = RoundAccumulator::new(1); // 4 prefixes + + // Initially all zeros + for v_idx in 0..4 { + for u_idx in 0..3 { + assert_eq!(acc.get(v_idx, u_idx), Scalar::ZERO); + } + } + + // Accumulate some values + let val1 = Scalar::from(7u64); + let val2 = Scalar::from(13u64); + + acc.accumulate(0, 0, val1); + acc.accumulate(0, 0, val2); + assert_eq!(acc.get(0, 0), val1 + val2); + + acc.accumulate(2, 1, val1); + assert_eq!(acc.get(2, 1), val1); + + // Other entries unchanged + assert_eq!(acc.get(0, 1), Scalar::ZERO); + assert_eq!(acc.get(1, 0), Scalar::ZERO); + } + + #[test] + fn test_round_accumulator_merge() { + let mut acc1: RoundAccumulator = RoundAccumulator::new(1); // 4 prefixes + let mut acc2: RoundAccumulator = RoundAccumulator::new(1); + + let val1 = Scalar::from(5u64); + let val2 = Scalar::from(11u64); + let val3 = Scalar::from(17u64); + + acc1.accumulate(0, 0, val1); + acc1.accumulate(1, 2, val2); + + acc2.accumulate(0, 0, val3); + acc2.accumulate(2, 1, val1); + + acc1.merge(&acc2); + + // Check merged values + assert_eq!(acc1.get(0, 0), val1 + val3); + assert_eq!(acc1.get(1, 2), val2); + assert_eq!(acc1.get(2, 1), val1); + assert_eq!(acc1.get(3, 0), Scalar::ZERO); + } + + #[test] + fn test_round_accumulator_domain_methods() { + let mut acc: RoundAccumulator = RoundAccumulator::new(1); // 4 prefixes + + // v = (Finite(1),) -> flat index = 2 (base 4: ∞=0, 0=1, 1=2, 2=3) + let v = LagrangeIndex::<3>(vec![LagrangePoint::Finite(1)]); + let u = LagrangeHatPoint::<3>::Infinity; // index 0 + + let val = Scalar::from(42u64); + acc.accumulate_by_domain(&v, u, val); + + assert_eq!(acc.get_by_domain(&v, u), val); + // Verify same via raw indices + assert_eq!(acc.get(2, 0), val); + } + + // === LagrangeAccumulator tests === + + #[test] + fn test_lagrange_accumulators_new() { + // For D=2 (base=3), ℓ₀=3 + // Round 0: 3^0 = 1 prefix + // Round 1: 3^1 = 3 prefixes + // Round 2: 3^2 = 9 prefixes + let acc: LagrangeAccumulators = LagrangeAccumulators::new(3); + + assert_eq!(acc.num_rounds(), 3); + assert_eq!(acc.round(0).num_prefixes(), 1); + assert_eq!(acc.round(1).num_prefixes(), 3); + assert_eq!(acc.round(2).num_prefixes(), 9); + } + + #[test] + fn test_lagrange_accumulators_accumulate_get() { + let mut acc: LagrangeAccumulators = LagrangeAccumulators::new(3); + + let val1 = Scalar::from(19u64); + let val2 = Scalar::from(23u64); + + // Accumulate into different rounds + acc.accumulate(0, 0, 0, val1); + acc.accumulate(1, 2, 1, val2); + acc.accumulate(2, 6, 1, val1); + + assert_eq!(acc.get(0, 0, 0), val1); + assert_eq!(acc.get(1, 2, 1), val2); + 
assert_eq!(acc.get(2, 6, 1), val1); + assert_eq!(acc.get(2, 0, 0), Scalar::ZERO); + } + + #[test] + fn test_lagrange_accumulators_merge() { + let mut acc1: LagrangeAccumulators = LagrangeAccumulators::new(3); + let mut acc2: LagrangeAccumulators = LagrangeAccumulators::new(3); + + let val1 = Scalar::from(7u64); + let val2 = Scalar::from(11u64); + let val3 = Scalar::from(13u64); + + acc1.accumulate(0, 0, 0, val1); + acc1.accumulate(1, 1, 0, val2); + + acc2.accumulate(0, 0, 0, val3); + acc2.accumulate(2, 4, 1, val1); + + acc1.merge(&acc2); + + assert_eq!(acc1.get(0, 0, 0), val1 + val3); + assert_eq!(acc1.get(1, 1, 0), val2); + assert_eq!(acc1.get(2, 4, 1), val1); + } + + #[test] + fn test_lagrange_accumulators_domain_methods() { + let mut acc: LagrangeAccumulators = LagrangeAccumulators::new(2); + + // Round 1 has 3 prefixes (base^1) + // v = (Finite(0),) -> flat index = 1 (∞=0, 0=1, 1=2) + let v = LagrangeIndex::<2>(vec![LagrangePoint::Finite(0)]); + let u = LagrangeHatPoint::::Infinity; // index 0 + + let val = Scalar::from(99u64); + acc.accumulate_by_domain(1, &v, u, val); + + assert_eq!(acc.get_by_domain(1, &v, u), val); + // Verify same via raw indices + assert_eq!(acc.get(1, 1, 0), val); + } + + #[test] + fn test_accumulator_sizes_match_spec() { + // For D=2, ℓ₀=3 should have total 26 elements + // Round 0: 1 * 2 = 2 + // Round 1: 3 * 2 = 6 + // Round 2: 9 * 2 = 18 + // Total: 26 + let acc: LagrangeAccumulators = LagrangeAccumulators::new(3); + + let total_elements: usize = (0..3).map(|i| acc.round(i).num_prefixes() * 2).sum(); + assert_eq!(total_elements, 26); + } +} diff --git a/src/lagrange_accumulator/accumulator_builder.rs b/src/lagrange_accumulator/accumulator_builder.rs new file mode 100644 index 0000000..8ef0fa5 --- /dev/null +++ b/src/lagrange_accumulator/accumulator_builder.rs @@ -0,0 +1,1036 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Builder functions for constructing Lagrange accumulators (Procedure 9). +//! +//! This module provides: +//! - [`build_accumulators_spartan`]: Optimized builder for Spartan's cubic relation +//! - [`build_accumulators`]: Generic builder for arbitrary polynomial products + +use super::{ + accumulator::LagrangeAccumulators, + delay_modular_reduction_mode::DelayedModularReductionMode, + domain::LagrangeIndex, + extension::LagrangeEvaluatedMultilinearPolynomial, + index::CachedPrefixIndex, + mat_vec_mle::MatVecMLE, + thread_state::{GenericThreadState, SpartanThreadState}, +}; +use crate::{ + csr::Csr, + polys::{ + eq::{EqPolynomial, compute_suffix_eq_pyramid}, + multilinear::MultilinearPolynomial, + }, +}; +use ff::PrimeField; +use rayon::prelude::*; + +use super::index::compute_idx4; + +/// Polynomial degree D for Spartan's small-value sumcheck. +/// For Spartan's cubic relation (A·B - C), D=2 yields quadratic t_i. +pub const SPARTAN_T_DEGREE: usize = 2; + +struct EqSplitTables { + e_in: Vec, + e_xout: Vec, + e_y: Vec>, + e_y_sizes: Vec, +} + +struct BetaPrefixCache { + cache: Csr, + num_betas: usize, +} + +/// Procedure 9: Build accumulators A_i(v, u) for Spartan's first sum-check (Algorithm 6). +/// +/// Computes accumulators for: g(X) = eq(τ, X) · (Az(X) · Bz(X) - Cz(X)) +/// +/// D is the degree bound of t_i(X) (not s_i); for Spartan, D = 2. 
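+///
+/// # Usage
+///
+/// Call shape as exercised by the unit tests (`az`/`bz` are the A·z and B·z
+/// multilinear witnesses, `taus` has length ℓ, and `l0` is the number of
+/// small-value rounds):
+///
+/// ```ignore
+/// let acc = build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>(
+///   &az, &bz, &taus, l0,
+/// );
+/// // acc.round(i) holds A_{i+1}(v, u) for the i-th (0-indexed) small-value round.
+/// ```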
+/// +/// # Type Parameters +/// +/// - `S`: Field type +/// - `P`: Polynomial type implementing [`MatVecMLE`] +/// - `Mode`: Delayed modular reduction mode selection ([`super::DelayedModularReductionEnabled`] or [`super::DelayedModularReductionDisabled`]) +/// +/// # Parallelism strategy +/// +/// - Outer parallel loop over x_out values (using Rayon fold-reduce) +/// - Each thread maintains thread-local accumulators +/// - Final reduction merges all thread-local results via element-wise addition +/// +/// # Spartan-specific optimizations (D=2) +/// +/// - Skip cz_ext entirely: for binary betas, use cz_pref directly +/// - Skip binary betas: for satisfying witnesses, Az·Bz = Cz on {0,1}^n, so they contribute 0 +/// - Only process betas with ∞ (where Cz doesn't contribute anyway) +pub fn build_accumulators_spartan( + az: &P, + bz: &P, + taus: &[S], + l0: usize, +) -> LagrangeAccumulators +where + S: PrimeField + Send + Sync, + P: MatVecMLE, + Mode: DelayedModularReductionMode, +{ + let base: usize = 3; // D + 1 = 2 + 1 = 3 + let l = az.len().trailing_zeros() as usize; + debug_assert_eq!(az.len(), 1usize << l, "poly size must be power of 2"); + debug_assert_eq!(az.len(), bz.len()); + debug_assert_eq!(taus.len(), l, "taus must have length ℓ"); + debug_assert!(l0 < l, "l0 must be < ℓ"); + + let suffix_vars = l - l0; + let prefix_size = 1usize << l0; + + let (eq_tables, _in_vars, xout_vars) = precompute_eq_tables(taus, l0); + let num_x_out = 1usize << xout_vars; + let BetaPrefixCache { + cache: beta_prefix_cache, + num_betas, + } = build_beta_cache::<2>(l0); + + let beta_has_infinity: Vec = (0..num_betas) + .map(|mut t| { + for _ in 0..l0 { + if t % base == 0 { + return true; + } + t /= base; + } + false + }) + .collect(); + + // Precompute indices of betas with infinity to avoid filter in hot loop + let betas_with_infty: Vec = (0..num_betas).filter(|&i| beta_has_infinity[i]).collect(); + + let ext_size = base.pow(l0 as u32); // (D+1)^l0 + + // Parallel over x_out with thread-local state (zero per-iteration allocations) + // State type determined by Mode: DelayedModularReductionEnabled uses unreduced accumulators, DelayedModularReductionDisabled uses reduced + type State = SpartanThreadState>::Value, P, Mode, 2>; + + let fold_results: Vec> = (0..num_x_out) + .into_par_iter() + .fold( + || State::::new(l0, num_betas, prefix_size, ext_size), + |mut state: State, x_out_bits| { + // Reset partial sums for this x_out iteration + state.reset_partial_sums(); + + let ex = eq_tables.e_xout[x_out_bits]; + + // Inner loop over x_in - accumulate into UNREDUCED form + // Each beta_partial_sums[beta_idx] accumulates 2^(l/2) terms per x_out. + // Safety bound for UnreducedFieldInt (N limbs, 64 bits per limb): + // field_bits + product_bits + (l/2) < 64*N + // i32 path: N=6, product_bits<=62; i64 path: N=8, product_bits<=126. 
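+        // Worked instance of that bound, assuming a ~255-bit scalar field:
+        //   i32 path: 255 + 62  + l/2 < 64*6 = 384  =>  l/2 < 67
+        //   i64 path: 255 + 126 + l/2 < 64*8 = 512  =>  l/2 < 131
+        // so both paths comfortably cover the instance sizes used in this crate.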
+ for (x_in_bits, &e_in_eval) in eq_tables.e_in.iter().enumerate() { + let suffix = (x_in_bits << xout_vars) | x_out_bits; + + // Fill prefix buffers by index assignment (no allocation) + #[allow(clippy::needless_range_loop)] + for prefix in 0..prefix_size { + let idx = (prefix << suffix_vars) | suffix; + state.az_pref[prefix] = az.get(idx); + state.bz_pref[prefix] = bz.get(idx); + } + + // Extend Az and Bz to Lagrange domain in-place (zero allocation) + let az_size = LagrangeEvaluatedMultilinearPolynomial::::extend_in_place( + &state.az_pref, + &mut state.az_buf_curr, + &mut state.az_buf_scratch, + ); + let az_ext = &state.az_buf_curr[..az_size]; + + let bz_size = LagrangeEvaluatedMultilinearPolynomial::::extend_in_place( + &state.bz_pref, + &mut state.bz_buf_curr, + &mut state.bz_buf_scratch, + ); + let bz_ext = &state.bz_buf_curr[..bz_size]; + + // Only process betas with ∞ - binary betas contribute 0 for satisfying witnesses + // Accumulation strategy determined by Mode (unreduced for DelayedModularReductionEnabled, reduced for DelayedModularReductionDisabled) + // Use precomputed indices to avoid filter overhead in inner loop + for &beta_idx in &betas_with_infty { + let prod = P::multiply_witnesses(az_ext[beta_idx], bz_ext[beta_idx]); + Mode::accumulate_eq_product(&mut state.partial_sums[beta_idx], prod, &e_in_eval); + } + } + + // Pre-compute and filter: reduce all non-zero betas upfront + // This eliminates closure call overhead in the scatter loop + // Reuse pre-allocated buffer to avoid per-iteration allocations + for &beta_idx in &betas_with_infty { + if Mode::partial_sum_is_zero(&state.partial_sums[beta_idx]) { + continue; + } + let val = Mode::modular_reduction_partial_sum(&state.partial_sums[beta_idx]); + if ff::Field::is_zero(&val).into() { + continue; + } + state.beta_values.push((beta_idx, val)); + } + + // Distribute beta values → A_i(v,u) via idx4 + // Accumulation strategy determined by Mode: + // - DelayedModularReductionEnabled: Unreduced F×F accumulation, final reduction once after merge + // - DelayedModularReductionDisabled: Immediate F×F reduction on each accumulation + for &(beta_idx, val) in &state.beta_values { + let z_beta = ex * val; + for pref in &beta_prefix_cache[beta_idx] { + let ey = &eq_tables.e_y[pref.round_0][pref.y_idx]; + Mode::accumulate_scatter( + &mut state.scatter_acc.rounds[pref.round_0].data_mut()[pref.v_idx][pref.u_idx], + ey, + &z_beta, + ); + } + } + + state + }, + ) + .collect(); + + // Sequential merge: avoids parallel reduce tree overhead and identity allocations. + // Each fold task's State is merged one by one, spreading deallocation cost. + // Using std::iter::Iterator::reduce (not rayon's) - no extra state allocations. 
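+  // Cost intuition for the two modes at this point:
+  // - DelayedModularReductionEnabled: buckets hold wide unreduced integers, so this merge
+  //   is plain wide-integer addition and each bucket pays a single Montgomery reduction
+  //   in the finalize loop below.
+  // - DelayedModularReductionDisabled: buckets already hold reduced field elements, so the
+  //   merge is ordinary field addition and finalize is essentially a copy.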
+ let merged = fold_results + .into_iter() + .reduce(|mut a, b| { + a.scatter_acc.merge(&b.scatter_acc); + a + }) + .expect("num_x_out > 0 guarantees non-empty fold results"); + + // Finalize: convert scatter accumulator to LagrangeAccumulator + // For DelayedModularReductionEnabled: performs final Montgomery reductions on each element + // For DelayedModularReductionDisabled: elements are already reduced (identity operation) + let mut result: LagrangeAccumulators = LagrangeAccumulators::new(l0); + for (round_idx, round) in merged.scatter_acc.rounds.iter().enumerate() { + for (v_idx, row) in round.data().iter().enumerate() { + for (u_idx, elem) in row.iter().enumerate() { + if !Mode::scatter_element_is_zero(elem) { + result.rounds[round_idx].data_mut()[v_idx][u_idx] = Mode::modular_reduction_scatter(elem); + } + } + } + } + result +} + +/// Generic Procedure 9: Build accumulators A_i(v, u) for Algorithm 6. +/// +/// Computes accumulators for: g(X) = eq(τ, X) · ∏_{k=1}^d p_k(X) +/// +/// This is the general algorithm that works for any number of polynomials +/// and any degree bound D. +/// +/// # Arguments +/// * `polys` - Slice of multilinear polynomials to multiply +/// * `taus` - Random challenge points (length ℓ) +/// * `l0` - Number of small-value rounds +#[allow(dead_code)] +pub fn build_accumulators( + polys: &[&MultilinearPolynomial], + taus: &[S], + l0: usize, +) -> LagrangeAccumulators { + assert!(!polys.is_empty(), "must have at least one polynomial"); + let base: usize = D + 1; + let l = polys[0].Z.len().trailing_zeros() as usize; + debug_assert_eq!( + polys[0].Z.len(), + 1usize << l, + "poly size must be power of 2" + ); + for poly in polys.iter().skip(1) { + debug_assert_eq!( + poly.Z.len(), + polys[0].Z.len(), + "all polys must have same size" + ); + } + debug_assert_eq!(taus.len(), l, "taus must have length ℓ"); + debug_assert!(l0 < l, "l0 must be < ℓ"); + + let suffix_vars = l - l0; + let prefix_size = 1usize << l0; + let d = polys.len(); + + let (eq_tables, in_vars, xout_vars) = precompute_eq_tables(taus, l0); + let num_x_out = 1usize << xout_vars; + let _num_x_in = 1usize << in_vars; + let BetaPrefixCache { + cache: beta_prefix_cache, + num_betas, + } = build_beta_cache::(l0); + + let ext_size = base.pow(l0 as u32); + + // Parallel over x_out with thread-local state (zero per-iteration allocations) + (0..num_x_out) + .into_par_iter() + .fold( + || { + GenericThreadState::::new( + l0, + num_betas, + prefix_size, + ext_size, + d, + &eq_tables.e_y_sizes, + ) + }, + |mut state, x_out_bits| { + // Reset partial sums for this x_out iteration + state.reset_partial_sums(); + + // Compute eyx = ey * ex on-the-fly for this x_out_bits (tiny, stays hot in L1) + let ex = eq_tables.e_xout[x_out_bits]; + fill_eyx(ex, &eq_tables.e_y, &mut state.eyx); + + // Inner loop over x_in + for (x_in_bits, &e_in_eval) in eq_tables.e_in.iter().enumerate() { + let suffix = (x_in_bits << xout_vars) | x_out_bits; + + // Fill all d prefix buffers by index assignment + #[allow(clippy::needless_range_loop)] + for prefix in 0..prefix_size { + let idx = (prefix << suffix_vars) | suffix; + for (k, poly) in polys.iter().enumerate() { + state.poly_prefs[k][prefix] = poly.Z[idx]; + } + } + + // Extend all d polynomials in-place (zero allocation) + // Result is always in buf_curr (first element of each pair) + for (pref, (buf_curr, buf_scratch)) in + state.poly_prefs.iter().zip(state.buf_pairs.iter_mut()) + { + LagrangeEvaluatedMultilinearPolynomial::::extend_in_place( + pref, + buf_curr, + buf_scratch, + ); + 
} + + // Compute ∏ p_k(β) for each beta + for (beta_idx, sum) in state.beta_partial_sums.iter_mut().enumerate() { + let prod: S = state + .buf_pairs + .iter() + .map(|(buf_curr, _)| buf_curr[beta_idx]) + .product(); + *sum += e_in_eval * prod; + } + } + + // Distribute beta_partial_sums → A_i(v,u) via idx4 + scatter_beta_contributions( + 0..num_betas, + &beta_prefix_cache, + &state.eyx, + &mut state.acc, + |beta_idx| { + let val = state.beta_partial_sums[beta_idx]; + if ff::Field::is_zero(&val).into() { + None + } else { + Some(val) + } + }, + ); + + state + }, + ) + .map(|state| state.acc) + .reduce( + || LagrangeAccumulators::::new(l0), + |mut a, b| { + a.merge(&b); + a + }, + ) +} + +// ============================================================================= +// Helper functions +// ============================================================================= + +/// Precompute eq polynomial tables with balanced split for e_in and e_xout. +/// +/// Returns (tables, in_vars, xout_vars) where: +/// - in_vars = ceil((l - l0) / 2) - variables for inner loop (e_in) +/// - xout_vars = floor((l - l0) / 2) - variables for outer loop (e_xout) +/// +/// The balanced split reduces precomputation cost by ~33% compared to the +/// asymmetric l/2 split, and enables odd number of rounds. +fn precompute_eq_tables(taus: &[S], l0: usize) -> (EqSplitTables, usize, usize) { + let l = taus.len(); + let suffix_vars = l - l0; + let in_vars = suffix_vars.div_ceil(2); // ceiling: e_in larger (inner loop, sequential access) + let xout_vars = suffix_vars - in_vars; // floor: e_xout smaller (outer loop, reused) + + let e_in = EqPolynomial::evals_from_points(&taus[l0..l0 + in_vars]); // 2^in_vars entries + let e_xout = EqPolynomial::evals_from_points(&taus[l0 + in_vars..]); // 2^xout_vars entries + let e_y = compute_suffix_eq_pyramid(&taus[..l0], l0); // Vec per round, total 2^l0 - 1 + let e_y_sizes: Vec = e_y.iter().map(|v| v.len()).collect(); + + ( + EqSplitTables { + e_in, + e_xout, + e_y, + e_y_sizes, + }, + in_vars, + xout_vars, + ) +} + +fn build_beta_cache(l0: usize) -> BetaPrefixCache { + let base: usize = D + 1; + let num_betas = base.pow(l0 as u32); + let mut cache: Csr = Csr::with_capacity(num_betas, num_betas * l0); + for b in 0..num_betas { + let beta = LagrangeIndex::::from_flat_index(b, l0); + let entries: Vec<_> = compute_idx4(&beta) + .into_iter() + .map(|entry| CachedPrefixIndex { + round_0: entry.round_0idx(), + v_idx: entry.v_idx, + u_idx: entry.u.to_index(), + y_idx: entry.y_idx, + }) + .collect(); + cache.push(&entries); + } + + BetaPrefixCache { cache, num_betas } +} + +#[inline] +fn fill_eyx(ex: S, e_y: &[Vec], eyx: &mut [Vec]) { + debug_assert_eq!(e_y.len(), eyx.len()); + for (round, ey_round) in e_y.iter().enumerate() { + let dst = &mut eyx[round]; + debug_assert_eq!(dst.len(), ey_round.len()); + for (dst_i, &ey) in dst.iter_mut().zip(ey_round.iter()) { + *dst_i = ey * ex; + } + } +} + +/// Legacy scatter function using eyx precomputation with immediate F×F reduction. +/// Each contribution does F×F multiply with internal Montgomery reduction. 
+#[allow(dead_code)] // Kept for reference; new code uses scatter_beta_contributions_unreduced +#[inline] +fn scatter_beta_contributions( + beta_indices: I, + beta_prefix_cache: &Csr, + eyx: &[Vec], + acc: &mut LagrangeAccumulators, + mut value_for_beta: F, +) where + I: IntoIterator, + F: FnMut(usize) -> Option, +{ + for beta_idx in beta_indices { + let Some(val) = value_for_beta(beta_idx) else { + continue; + }; + for pref in &beta_prefix_cache[beta_idx] { + let eyx_val = eyx[pref.round_0][pref.y_idx]; + acc.accumulate(pref.round_0, pref.v_idx, pref.u_idx, eyx_val * val); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + lagrange_accumulator::{ + DelayedModularReductionDisabled, DelayedModularReductionEnabled, domain::LagrangeHatPoint, + }, + polys::eq::EqPolynomial, + provider::pasta::pallas, + }; + use ff::Field; + + type Scalar = pallas::Scalar; + + // Use the shared constant for polynomial degree in tests + const D: usize = SPARTAN_T_DEGREE; + + /// End-to-end correctness for build_accumulators_spartan on a tiny instance. + /// + /// ℓ = 4, ℓ0 = 2, D = 2. + /// Uses a satisfying witness (Az * Bz = Cz) to test the optimized Spartan path. + /// Verifies against a straightforward (non-parallel) implementation of Procedure 9. + #[test] + fn test_build_accumulators_spartan_matches_naive() { + let l0 = 2; + let l = 4; + + // Balanced split for eq tables (matches precompute_eq_tables) + let suffix_vars = l - l0; // 2 + let in_vars = (suffix_vars + 1) / 2; // 1 + let xout_vars = suffix_vars - in_vars; // 1 + + // Define deterministic Az, Bz, Cz over {0,1}^4 + // Use a SATISFYING witness: Cz = Az * Bz + let eval = |bits: usize| -> Scalar { + // Simple affine: a0 x0 + a1 x1 + a2 x2 + a3 x3 + const + let x0 = (bits >> 3) & 1; + let x1 = (bits >> 2) & 1; + let x2 = (bits >> 1) & 1; + let x3 = bits & 1; + Scalar::from((x0 + 2 * x1 + 3 * x2 + 4 * x3 + 5) as u64) + }; + let az_vals: Vec = (0..16).map(eval).collect(); + let bz_vals: Vec = (0..16).map(|b| eval(b) + Scalar::from(7u64)).collect(); + // Satisfying witness: cz = az * bz + let cz_vals: Vec = az_vals + .iter() + .zip(bz_vals.iter()) + .map(|(a, b)| *a * *b) + .collect(); + + let az = MultilinearPolynomial::new(az_vals.clone()); + let bz = MultilinearPolynomial::new(bz_vals.clone()); + let cz = MultilinearPolynomial::new(cz_vals.clone()); + + // Taus (length ℓ) + let taus: Vec = vec![ + Scalar::from(5u64), + Scalar::from(7u64), + Scalar::from(11u64), + Scalar::from(13u64), + ]; + + // Implementation under test + let acc_impl = + build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>(&az, &bz, &taus, l0); + + // Precompute eq tables for naive computation (balanced split) + let e_in = EqPolynomial::evals_from_points(&taus[l0..l0 + in_vars]); // τ[2..3] + let e_xout = EqPolynomial::evals_from_points(&taus[l0 + in_vars..]); // τ[3..4] + let e_y = compute_suffix_eq_pyramid(&taus[..l0], l0); + + let num_betas = (D + 1).pow(l0 as u32); + let idx4_cache: Vec> = (0..num_betas) + .map(|b| compute_idx4(&LagrangeIndex::::from_flat_index(b, l0))) + .collect(); + + // Naive accumulators + let mut acc_naive: LagrangeAccumulators = LagrangeAccumulators::new(l0); + + // Iterate over x_out and x_in with balanced split + #[allow(clippy::needless_range_loop)] + for x_out_bits in 0..(1 << xout_vars) { + let ex = e_xout[x_out_bits]; + + for x_in_bits in 0..(1 << in_vars) { + let suffix = (x_in_bits << xout_vars) | x_out_bits; + + let az_pref = az.gather_prefix_evals(l0, suffix); + let bz_pref = bz.gather_prefix_evals(l0, 
suffix); + let cz_pref = cz.gather_prefix_evals(l0, suffix); + + let az_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&az_pref); + let bz_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&bz_pref); + let cz_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&cz_pref); + + let e_in_eval = e_in[x_in_bits]; + + #[allow(clippy::needless_range_loop)] + for beta_idx in 0..num_betas { + let beta_tuple = az_ext.to_domain_tuple(beta_idx); + let ab = az_ext.get(beta_idx) * bz_ext.get(beta_idx); + let prod = if beta_tuple.has_infinity() { + ab + } else { + ab - cz_ext.get(beta_idx) + }; + let val = e_in_eval * prod; + + for pref in &idx4_cache[beta_idx] { + let ey = e_y[pref.round_0idx()][pref.y_idx]; + acc_naive.accumulate( + pref.round_0idx(), + pref.v_idx, + pref.u.to_index(), + ey * ex * val, + ); + } + } + } + } + + // Compare all buckets + for round in 0..l0 { + let num_v = (D + 1).pow(round as u32); + for v_idx in 0..num_v { + for u_idx in 0..D { + let got = acc_impl.get(round, v_idx, u_idx); + let expect = acc_naive.get(round, v_idx, u_idx); + assert_eq!( + got, expect, + "Mismatch at round {}, v_idx {}, u_idx {}", + round, v_idx, u_idx + ); + } + } + } + } + + /// Check the ∞ rule: constant polynomials have zero leading coefficient at ∞. + /// + /// ℓ = 2, ℓ0 = 1, D = 2. + /// Tests product of two constant polynomials (1 and -1) = -1 everywhere. + /// At ∞: leading coefficient of constant poly is 0. + /// At finite points: evaluates to the constant value. + /// + /// Uses generic `build_accumulators` (Procedure 9). + #[test] + fn test_infinity_drops_cz() { + let l0 = 1; + let l = 2; + + // Balanced split for eq tables + let suffix_vars = l - l0; // 1 + let in_vars = (suffix_vars + 1) / 2; // 1 + + // Two constant polynomials: 1 and -1, product = -1 + let ones = MultilinearPolynomial::new(vec![Scalar::ONE; 1 << l]); + let neg_ones = MultilinearPolynomial::new(vec![-Scalar::ONE; 1 << l]); + + let taus: Vec = vec![Scalar::from(5u64), Scalar::from(7u64)]; + + // Use generic build_accumulators with D=2 for product of two polynomials + let acc = build_accumulators::(&[&ones, &neg_ones], &taus, l0); + + // Compute e_in sum = Σ eq(τ[l0..l0+in_vars], xin), here in_vars=1, l0=1 -> slice τ[1..2] + let e_in = EqPolynomial::evals_from_points(&taus[l0..l0 + in_vars]); + let e_in_eval_sum: Scalar = e_in.iter().copied().sum(); + + // Round 0, v_idx=0 + let u_infinity_idx = LagrangeHatPoint::<2>::Infinity.to_index(); // 0 + let u_zero_idx = LagrangeHatPoint::<2>::Finite(0).to_index(); // 1 + + let acc_inf = acc.get(0, 0, u_infinity_idx); + let acc_zero = acc.get(0, 0, u_zero_idx); + + // At ∞: leading coefficient of constant poly is 0 + assert_eq!( + acc_inf, + Scalar::ZERO, + "Constant poly has zero leading coeff at ∞" + ); + // At 0: product = 1 * (-1) = -1 + assert_eq!( + acc_zero, + e_in_eval_sum * (-Scalar::ONE), + "Should equal sum * (-1)" + ); + } + + /// Binary-β zero shortcut: Az=Bz=Cz=first variable (x0), so Az·Bz−Cz=0 on binary β. + /// Non-binary β (∞) should yield non-zero in some bucket. + #[test] + fn test_binary_beta_zero_shortcut_behavior() { + // Use l0=1 so round 0 buckets are fed only by β of length 1 (easy to reason about). 
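+    // Reasoning sketch: for a satisfying witness, Az·Bz − Cz vanishes on {0,1}^n, and the
+    // Lagrange extension agrees with the multilinear values at binary points, so binary β
+    // could only ever contribute zeros; build_accumulators_spartan therefore skips them
+    // (it iterates over betas_with_infty), and only β containing ∞ can fill a bucket.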
+ let l0 = 1; + let l = 2; + + // Az = Bz = top bit x0 (most significant of 2 bits) + // For satisfying witness, Cz = Az * Bz = Az (since Az ∈ {0,1} and Az = Bz) + let az_vals: Vec = (0..(1 << l)) + .map(|bits| { + let x0 = (bits >> (l - 1)) & 1; + Scalar::from(x0 as u64) + }) + .collect(); + let bz_vals = az_vals.clone(); + + let az = MultilinearPolynomial::new(az_vals); + let bz = MultilinearPolynomial::new(bz_vals); + + let taus: Vec = vec![Scalar::from(3u64), Scalar::from(5u64)]; + + let acc = + build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>(&az, &bz, &taus, l0); + + // Only round 0 exists (v is empty). β ranges over U_d with binary {0,1} and non-binary {∞}. + // Buckets for u = 0 should be zero (binary β), bucket for u = ∞ should be non-zero. + let u_inf = LagrangeHatPoint::::Infinity.to_index(); // 0 + let u_zero = LagrangeHatPoint::::Finite(0).to_index(); // 1 + + assert!( + bool::from(acc.get(0, 0, u_zero).is_zero()), + "binary β should give zero for u=0" + ); + assert!( + !bool::from(acc.get(0, 0, u_inf).is_zero()), + "non-binary β (∞) should give non-zero" + ); + } + + /// Test generic build_accumulators (Procedure 9) with a product of 3 polynomials. + /// + /// ℓ = 10, ℓ0 = 3, D = 3 (degree bound for product of 3 polynomials). + /// Verifies that accumulators are computed correctly by comparing against naive computation. + #[test] + fn test_build_accumulators_product_of_three() { + use rand::{SeedableRng, rngs::StdRng}; + + const L: usize = 10; + const L0: usize = 3; + const D: usize = 3; // Degree bound for product of 3 linear polynomials + + let n = 1usize << L; + + // Balanced split for eq tables (matches precompute_eq_tables) + let suffix_vars = L - L0; // 7 + let in_vars = (suffix_vars + 1) / 2; // 4 + let xout_vars = suffix_vars - in_vars; // 3 + + let num_betas = (D + 1).pow(L0 as u32); + + let mut rng = StdRng::seed_from_u64(42); + + // Create 3 random multilinear polynomials + let p1_vals: Vec = (0..n).map(|_| Scalar::random(&mut rng)).collect(); + let p2_vals: Vec = (0..n).map(|_| Scalar::random(&mut rng)).collect(); + let p3_vals: Vec = (0..n).map(|_| Scalar::random(&mut rng)).collect(); + + let p1 = MultilinearPolynomial::new(p1_vals); + let p2 = MultilinearPolynomial::new(p2_vals); + let p3 = MultilinearPolynomial::new(p3_vals); + + // Random taus + let taus: Vec = (0..L).map(|_| Scalar::random(&mut rng)).collect(); + + // Build accumulators using generic Procedure 9 + let acc_impl = build_accumulators::(&[&p1, &p2, &p3], &taus, L0); + + // ===== Naive computation for comparison (balanced split) ===== + let e_in = EqPolynomial::evals_from_points(&taus[L0..L0 + in_vars]); + let e_xout = EqPolynomial::evals_from_points(&taus[L0 + in_vars..]); + let e_y = compute_suffix_eq_pyramid(&taus[..L0], L0); + + let idx4_cache: Vec> = (0..num_betas) + .map(|b| compute_idx4(&LagrangeIndex::::from_flat_index(b, L0))) + .collect(); + + let mut acc_naive: LagrangeAccumulators = LagrangeAccumulators::new(L0); + + #[allow(clippy::needless_range_loop)] + for x_out_bits in 0..(1 << xout_vars) { + let ex = e_xout[x_out_bits]; + + #[allow(clippy::needless_range_loop)] + for x_in_bits in 0..(1 << in_vars) { + let suffix = (x_in_bits << xout_vars) | x_out_bits; + + // Gather prefix evaluations and extend to Lagrange domain + let p1_pref = p1.gather_prefix_evals(L0, suffix); + let p2_pref = p2.gather_prefix_evals(L0, suffix); + let p3_pref = p3.gather_prefix_evals(L0, suffix); + + let p1_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&p1_pref); + let 
p2_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&p2_pref); + let p3_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&p3_pref); + + let e_in_eval = e_in[x_in_bits]; + + #[allow(clippy::needless_range_loop)] + for beta_idx in 0..num_betas { + // Compute product p1(β) * p2(β) * p3(β) + let prod = p1_ext.get(beta_idx) * p2_ext.get(beta_idx) * p3_ext.get(beta_idx); + let val = e_in_eval * prod; + + // Distribute to accumulators via idx4 + for pref in &idx4_cache[beta_idx] { + let ey = e_y[pref.round_0idx()][pref.y_idx]; + acc_naive.accumulate( + pref.round_0idx(), + pref.v_idx, + pref.u.to_index(), + ey * ex * val, + ); + } + } + } + } + + // ===== Compare all accumulator buckets ===== + for round in 0..L0 { + let num_v = (D + 1).pow(round as u32); + for v_idx in 0..num_v { + for u_idx in 0..D { + let got = acc_impl.get(round, v_idx, u_idx); + let expect = acc_naive.get(round, v_idx, u_idx); + assert_eq!( + got, expect, + "Mismatch at round {}, v_idx {}, u_idx {}", + round, v_idx, u_idx + ); + } + } + } + } + + /// Test that build_accumulators_spartan with i32 witnesses matches field witnesses. + /// + /// Uses the same setup as test_build_accumulators_spartan_matches_naive but + /// compares the small-value version against the original field-element version. + #[test] + fn test_build_accumulators_spartan_small_matches_field() { + let l0 = 2; + + // Define deterministic Az, Bz over {0,1}^4 using small values + // Values must fit in i32 for the small-value optimization + let eval = |bits: usize| -> i32 { + let x0 = (bits >> 3) & 1; + let x1 = (bits >> 2) & 1; + let x2 = (bits >> 1) & 1; + let x3 = bits & 1; + (x0 + 2 * x1 + 3 * x2 + 4 * x3 + 5) as i32 + }; + + // Create small-value polynomials + let az_small_vals: Vec = (0..16).map(&eval).collect(); + let bz_small_vals: Vec = (0..16).map(|b| eval(b) + 7).collect(); + + let az_small = MultilinearPolynomial::new(az_small_vals); + let bz_small = MultilinearPolynomial::new(bz_small_vals); + + // Create field-element polynomials (same values) + let az_field_vals: Vec = (0..16).map(|b| Scalar::from(eval(b) as u64)).collect(); + let bz_field_vals: Vec = (0..16) + .map(|b| Scalar::from((eval(b) + 7) as u64)) + .collect(); + + let az_field = MultilinearPolynomial::new(az_field_vals); + let bz_field = MultilinearPolynomial::new(bz_field_vals); + + // Taus (length ℓ) + let taus: Vec = vec![ + Scalar::from(5u64), + Scalar::from(7u64), + Scalar::from(11u64), + Scalar::from(13u64), + ]; + + // Build accumulators using both versions (unified function, different input types) + let acc_small = build_accumulators_spartan::<_, _, DelayedModularReductionEnabled>( + &az_small, &bz_small, &taus, l0, + ); + let acc_field = build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>( + &az_field, &bz_field, &taus, l0, + ); + + // Compare all buckets + for round in 0..l0 { + let num_v = (D + 1).pow(round as u32); + for v_idx in 0..num_v { + for u_idx in 0..D { + let got = acc_small.get(round, v_idx, u_idx); + let expect = acc_field.get(round, v_idx, u_idx); + assert_eq!( + got, expect, + "Mismatch at round {}, v_idx {}, u_idx {}", + round, v_idx, u_idx + ); + } + } + } + } + + /// Test build_accumulators_spartan with i32 witnesses using larger inputs to stress test. 
+ #[test] + fn test_build_accumulators_spartan_small_larger() { + use crate::small_field::SmallValueField; + + let l0 = 3; + let l = 10; + let n = 1 << l; + + // Create polynomials with varying small values + let az_vals: Vec = (0..n).map(|i| (i % 1000) + 1).collect(); + let bz_vals: Vec = (0..n).map(|i| ((i * 7) % 1000) + 1).collect(); + + let az_small = MultilinearPolynomial::new(az_vals.clone()); + let bz_small = MultilinearPolynomial::new(bz_vals.clone()); + + // Create field-element versions + let az_field = + MultilinearPolynomial::new(az_vals.iter().map(|&s| Scalar::small_to_field(s)).collect()); + let bz_field = + MultilinearPolynomial::new(bz_vals.iter().map(|&s| Scalar::small_to_field(s)).collect()); + + // Random-looking taus + let taus: Vec = (0..l).map(|i| Scalar::from((i * 7 + 3) as u64)).collect(); + + // Build and compare (unified function, different input types) + let acc_small = build_accumulators_spartan::<_, _, DelayedModularReductionEnabled>( + &az_small, &bz_small, &taus, l0, + ); + let acc_field = build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>( + &az_field, &bz_field, &taus, l0, + ); + + for round in 0..l0 { + let num_v = (D + 1).pow(round as u32); + for v_idx in 0..num_v { + for u_idx in 0..D { + let got = acc_small.get(round, v_idx, u_idx); + let expect = acc_field.get(round, v_idx, u_idx); + assert_eq!( + got, expect, + "Mismatch at round {}, v_idx {}, u_idx {}", + round, v_idx, u_idx + ); + } + } + } + } + + /// Test that odd number of rounds works correctly with balanced split. + /// + /// ℓ = 11 (odd), ℓ0 = 3, D = 2. + /// This tests the new balanced split which enables odd rounds. + /// With l=11, l0=3: suffix_vars=8, in_vars=4, xout_vars=4 (perfectly balanced!) + #[test] + fn test_build_accumulators_odd_rounds() { + use rand::{SeedableRng, rngs::StdRng}; + + const L: usize = 11; // Odd number of rounds + const L0: usize = 3; + + let n = 1usize << L; + + // Balanced split for eq tables + let suffix_vars = L - L0; // 8 + let in_vars = (suffix_vars + 1) / 2; // 4 + let xout_vars = suffix_vars - in_vars; // 4 + + let num_betas = (D + 1).pow(L0 as u32); + + let mut rng = StdRng::seed_from_u64(123); + + // Create satisfying witness: Cz = Az * Bz + let az_vals: Vec = (0..n).map(|_| Scalar::random(&mut rng)).collect(); + let bz_vals: Vec = (0..n).map(|_| Scalar::random(&mut rng)).collect(); + + let az = MultilinearPolynomial::new(az_vals.clone()); + let bz = MultilinearPolynomial::new(bz_vals.clone()); + let cz = MultilinearPolynomial::new( + az_vals + .iter() + .zip(bz_vals.iter()) + .map(|(a, b)| *a * *b) + .collect(), + ); + + // Random taus of length 11 (odd) + let taus: Vec = (0..L).map(|_| Scalar::random(&mut rng)).collect(); + + // Build accumulators using optimized Spartan path + let acc_impl = + build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>(&az, &bz, &taus, L0); + + // ===== Naive computation for comparison (balanced split) ===== + let e_in = EqPolynomial::evals_from_points(&taus[L0..L0 + in_vars]); + let e_xout = EqPolynomial::evals_from_points(&taus[L0 + in_vars..]); + let e_y = compute_suffix_eq_pyramid(&taus[..L0], L0); + + let idx4_cache: Vec> = (0..num_betas) + .map(|b| compute_idx4(&LagrangeIndex::::from_flat_index(b, L0))) + .collect(); + + let mut acc_naive: LagrangeAccumulators = LagrangeAccumulators::new(L0); + + #[allow(clippy::needless_range_loop)] + for x_out_bits in 0..(1 << xout_vars) { + let ex = e_xout[x_out_bits]; + + for x_in_bits in 0..(1 << in_vars) { + let suffix = (x_in_bits << xout_vars) | 
x_out_bits; + + let az_pref = az.gather_prefix_evals(L0, suffix); + let bz_pref = bz.gather_prefix_evals(L0, suffix); + let cz_pref = cz.gather_prefix_evals(L0, suffix); + + let az_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&az_pref); + let bz_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&bz_pref); + let cz_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&cz_pref); + + let e_in_eval = e_in[x_in_bits]; + + #[allow(clippy::needless_range_loop)] + for beta_idx in 0..num_betas { + let beta_tuple = az_ext.to_domain_tuple(beta_idx); + let ab = az_ext.get(beta_idx) * bz_ext.get(beta_idx); + let prod = if beta_tuple.has_infinity() { + ab + } else { + ab - cz_ext.get(beta_idx) + }; + let val = e_in_eval * prod; + + for pref in &idx4_cache[beta_idx] { + let ey = e_y[pref.round_0idx()][pref.y_idx]; + acc_naive.accumulate( + pref.round_0idx(), + pref.v_idx, + pref.u.to_index(), + ey * ex * val, + ); + } + } + } + } + + // ===== Compare all accumulator buckets ===== + for round in 0..L0 { + let num_v = (D + 1).pow(round as u32); + for v_idx in 0..num_v { + for u_idx in 0..D { + let got = acc_impl.get(round, v_idx, u_idx); + let expect = acc_naive.get(round, v_idx, u_idx); + assert_eq!( + got, expect, + "Mismatch at round {}, v_idx {}, u_idx {} for odd ℓ={}", + round, v_idx, u_idx, L + ); + } + } + } + } +} diff --git a/src/lagrange_accumulator/basis.rs b/src/lagrange_accumulator/basis.rs new file mode 100644 index 0000000..ff087d0 --- /dev/null +++ b/src/lagrange_accumulator/basis.rs @@ -0,0 +1,508 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Lagrange basis computation for U_d = {∞, 0, 1, ..., d-1}. + +use super::{domain::LagrangePoint, evals::LagrangeEvals}; +use ff::PrimeField; + +#[cfg(test)] +use super::extension::LagrangeEvaluatedMultilinearPolynomial; + +/// Evaluated Lagrange basis at a single r, stored in LagrangePoint order. +pub type LagrangeBasisEval = LagrangeEvals; + +/// Precomputes data for the 1-D Lagrange basis on U_d = {∞, 0, 1, ..., d-1}. +/// +/// `finite_points[k]` is the field element for LagrangePoint::Finite(k). +/// `weights[k]` is the barycentric weight: +/// w_k = 1 / ∏_{m≠k} (x_k - x_m) +/// +/// With these weights, basis evaluation at any r costs O(D) multiplies +/// and uses no per-round inversions. +pub struct LagrangeBasisFactory { + finite_points: [F; D], + weights: [F; D], +} + +/// R_i tensor coefficients used in Algorithm 6. +/// +/// Indexing matches LagrangeIndex::to_flat_index() over U_d^{i-1}. +pub struct LagrangeCoeff { + coeffs: Vec, +} + +impl Default for LagrangeCoeff { + fn default() -> Self { + Self::new() + } +} + +impl LagrangeCoeff { + /// Initialize R_1 = [1]. + pub fn new() -> Self { + Self { + coeffs: vec![F::ONE], + } + } + + /// Returns the number of coefficients. + pub fn len(&self) -> usize { + self.coeffs.len() + } + + /// Returns true if there are no coefficients (always false by construction). + pub fn is_empty(&self) -> bool { + self.coeffs.is_empty() + } + + /// Returns a slice of the coefficients. + pub fn as_slice(&self) -> &[F] { + &self.coeffs + } + + /// Extend: R_{i+1} = R_i ⊗ L(r_i). 
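+
+// Stand-alone sketch (illustration only, not part of this file): the extend step
+// below is a Kronecker product — every existing coefficient is multiplied by each
+// of the D + 1 basis values in U_d order. u64 values stand in for field elements.
+fn tensor_extend(coeffs: &[u64], basis: &[u64]) -> Vec<u64> {
+    let mut out = Vec::with_capacity(coeffs.len() * basis.len());
+    for &c in coeffs {
+        for &b in basis {
+            out.push(c * b); // out[i * basis.len() + k] = coeffs[i] * basis[k]
+        }
+    }
+    out
+}
+fn main() {
+    // Start from R_1 = [1]; two extensions give the outer product of the two bases,
+    // matching the indexing checked by test_lagrange_coeff_tensor_product below.
+    let r3 = tensor_extend(&tensor_extend(&[1], &[2, 3, 5, 7]), &[11, 13, 17, 19]);
+    assert_eq!(r3.len(), 16);
+    assert_eq!(r3[4 + 1], 3 * 13); // index i * base + j with i = 1, j = 1
+}
+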
+ pub fn extend(&mut self, basis: &LagrangeBasisEval) { + let base = D + 1; + let mut next = vec![F::ZERO; self.coeffs.len() * base]; + for (i, &c) in self.coeffs.iter().enumerate() { + for (k, b) in basis.iter_ud_order().enumerate() { + next[i * base + k] = c * b; + } + } + self.coeffs = next; + } +} + +/// Test-only helper methods for LagrangeCoeff. +#[cfg(test)] +#[allow(missing_docs)] +impl LagrangeCoeff { + pub fn get(&self, idx: usize) -> F { + self.coeffs[idx] + } +} + +impl LagrangeBasisFactory { + /// Construct the domain using an embedding from indices to field elements. + pub fn new(embed: impl Fn(usize) -> F) -> Self { + let finite_points = std::array::from_fn(embed); + let weights = Self::weights_general(&finite_points); + + Self { + finite_points, + weights, + } + } + + /// Evaluate the Lagrange basis at r. + /// + /// Returns values in LagrangePoint order: [L∞(r), L0(r), L1(r), ..., L_{d-1}(r)]. + pub fn basis_at(&self, r: F) -> LagrangeBasisEval { + // One-hot if r equals a finite domain point. + for (k, &xk) in self.finite_points.iter().enumerate() { + if r == xk { + let mut finite = [F::ZERO; D]; + finite[k] = F::ONE; + return LagrangeEvals::new(F::ZERO, finite); + } + } + + let diffs: [F; D] = std::array::from_fn(|i| r - self.finite_points[i]); + + // prefix[k] = ∏_{j < k} diffs[j] + let base = LagrangePoint::::BASE; + let mut prefix = vec![F::ONE; base]; + for i in 0..D { + prefix[i + 1] = prefix[i] * diffs[i]; + } + + // suffix[k] = ∏_{j > k} diffs[j] + let mut suffix = vec![F::ONE; base]; + for i in (0..D).rev() { + suffix[i] = suffix[i + 1] * diffs[i]; + } + + let prod = prefix[D]; // P(r) = ∏(r - x_k) + + let mut finite = [F::ZERO; D]; + for k in 0..D { + let numer = prefix[k] * suffix[k + 1]; // P(r)/(r - x_k) + finite[k] = numer * self.weights[k]; + } + + LagrangeEvals::new(prod, finite) + } + + fn weights_general(points: &[F; D]) -> [F; D] { + let denoms = std::array::from_fn(|k| { + let xk = points[k]; + let mut denom = F::ONE; + for (m, &xm) in points.iter().enumerate() { + if m == k { + continue; + } + denom *= xk - xm; + } + denom + }); + Self::batch_invert_array(denoms) + } + + fn batch_invert_array(values: [F; D]) -> [F; D] { + let mut prefix: Vec = vec![F::ONE; D + 1]; + for i in 0..D { + prefix[i + 1] = prefix[i] * values[i]; + } + + let inv_prod = prefix[D].invert().unwrap(); + + let mut out = [F::ZERO; D]; + let mut suffix = F::ONE; + for i in (0..D).rev() { + out[i] = prefix[i] * suffix * inv_prod; + suffix *= values[i]; + } + out + } +} + +/// Test-only helper methods for LagrangeBasisFactory. +#[cfg(test)] +impl LagrangeBasisFactory { + /// Evaluate an extended polynomial at r using the tensor-product Lagrange basis. 
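+
+// Stand-alone worked example (illustration only, not part of this file) of the
+// barycentric formula used by basis_at: w_k = 1 / ∏_{m≠k}(x_k − x_m) and
+// L_k(r) = P(r) / (r − x_k) · w_k, with L_∞(r) = P(r). f64 stands in for the field;
+// for D = 2 the finite points are {0, 1}, so L_0(r) = 1 − r and L_1(r) = r.
+fn main() {
+    let points = [0.0_f64, 1.0];
+    let r = 5.0_f64;
+    let w = [1.0 / (points[0] - points[1]), 1.0 / (points[1] - points[0])];
+    let p: f64 = points.iter().map(|x| r - x).product(); // P(r) = L_∞(r) = 20
+    let l0 = p / (r - points[0]) * w[0]; // = 1 − r = −4
+    let l1 = p / (r - points[1]) * w[1]; // = r = 5
+    assert!((l0 + l1 - 1.0).abs() < 1e-12); // partition of unity, as in the tests below
+    assert!((l0 - (1.0 - r)).abs() < 1e-12);
+    assert!((l1 - r).abs() < 1e-12);
+}
+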
+ pub fn eval_extended( + &self, + extended: &LagrangeEvaluatedMultilinearPolynomial, + r: &[F], + ) -> F { + assert_eq!(extended.num_vars(), r.len()); + + let mut coeff = LagrangeCoeff::::new(); + for &ri in r { + let basis = self.basis_at(ri); + coeff.extend(&basis); + } + + let mut acc = F::ZERO; + for idx in 0..extended.len() { + acc += coeff.get(idx) * extended.get(idx); + } + acc + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{polys::multilinear::MultilinearPolynomial, provider::pasta::pallas}; + use ff::Field; + + type Scalar = pallas::Scalar; + + fn evaluate_multilinear(evals: &[Scalar], point: &[Scalar]) -> Scalar { + let chis = crate::polys::eq::EqPolynomial::evals_from_points(point); + evals + .iter() + .zip(chis.iter()) + .fold(Scalar::ZERO, |acc, (z, chi)| acc + *z * *chi) + } + + // === Lagrange basis tests === + + // Property: basis is one-hot at a finite domain point and L∞(x_k)=0. + #[test] + fn test_basis_at_domain_points_one_hot() { + const D: usize = 3; + + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + + for k in 0..D { + let r = Scalar::from(k as u64); + let basis = factory.basis_at(r); + + assert_eq!(basis.infinity, Scalar::ZERO); + for j in 0..D { + let expected = if j == k { Scalar::ONE } else { Scalar::ZERO }; + assert_eq!(basis.finite[j], expected); + } + } + } + + // Property: L∞(r) equals ∏(r - x_k). + #[test] + fn test_basis_at_l_inf_product() { + const D: usize = 3; + + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + let r = Scalar::from(5u64); + let basis = factory.basis_at(r); + + let expected = (0..D).fold(Scalar::ONE, |acc, k| acc * (r - Scalar::from(k as u64))); + assert_eq!(basis.infinity, expected); + } + + // Property: Σ_k L_k(r) = 1 for any r (constant polynomial). + #[test] + fn test_basis_at_finite_sum_is_one() { + const D: usize = 3; + + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + let r = Scalar::from(7u64); + let basis = factory.basis_at(r); + + let sum = basis.finite.iter().fold(Scalar::ZERO, |acc, v| acc + v); + assert_eq!(sum, Scalar::ONE); + } + + // Property: LagrangePoint ordering and getters are consistent. + #[test] + fn test_basis_eval_order_and_get() { + const D: usize = 3; + + let eval = LagrangeEvals::::new( + Scalar::from(2u64), + [Scalar::from(3u64), Scalar::from(5u64), Scalar::from(7u64)], + ); + + let vals: Vec<_> = eval.iter_ud_order().collect(); + assert_eq!( + vals, + vec![ + Scalar::from(2u64), + Scalar::from(3u64), + Scalar::from(5u64), + Scalar::from(7u64), + ] + ); + + assert_eq!(eval.get(LagrangePoint::Infinity), Scalar::from(2u64)); + assert_eq!(eval.get(LagrangePoint::Finite(2)), Scalar::from(7u64)); + } + + // Property: degree-2 polynomial is reconstructed from {∞,0,1}. + #[test] + fn test_basis_reconstructs_deg2_poly() { + const D: usize = 2; + + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + + let eval = |x: Scalar| x * x + Scalar::from(2u64) * x + Scalar::ONE; + let s_inf = Scalar::ONE; // leading coeff of x^2 + 2x + 1 + let s0 = eval(Scalar::ZERO); + let s1 = eval(Scalar::ONE); + + let mut rng = rand_core::OsRng; + let mut rs = vec![Scalar::ZERO, Scalar::ONE]; + for _ in 0..3 { + rs.push(Scalar::random(&mut rng)); + } + for r in rs { + let basis = factory.basis_at(r); + let reconstructed = s_inf * basis.infinity + s0 * basis.finite[0] + s1 * basis.finite[1]; + assert_eq!(reconstructed, eval(r)); + } + } + + // Property: degree-3 polynomial is reconstructed from {∞,0,1,2}. 
+ #[test] + fn test_basis_reconstructs_deg3_poly() { + const D: usize = 3; + + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + + let eval = |x: Scalar| { + let x2 = x * x; + let x3 = x2 * x; + Scalar::from(2u64) * x3 - Scalar::from(4u64) * x2 + Scalar::from(5u64) * x - Scalar::ONE + }; + + let s_inf = Scalar::from(2u64); // leading coeff of 2x^3 - 4x^2 + 5x - 1 + let s0 = eval(Scalar::ZERO); + let s1 = eval(Scalar::ONE); + let s2 = eval(Scalar::from(2u64)); + + let mut rng = rand_core::OsRng; + let mut rs = vec![Scalar::ZERO, Scalar::ONE, Scalar::from(2u64)]; + for _ in 0..4 { + rs.push(Scalar::random(&mut rng)); + } + for r in rs { + let basis = factory.basis_at(r); + let reconstructed = + s_inf * basis.infinity + s0 * basis.finite[0] + s1 * basis.finite[1] + s2 * basis.finite[2]; + assert_eq!(reconstructed, eval(r)); + } + } + + // === LagrangeCoeff tests === + + // Property: R_1 starts at 1 and extend copies basis in Ud order. + #[test] + fn test_lagrange_coeff_new_and_extend() { + const D: usize = 3; + + let mut coeff = LagrangeCoeff::::new(); + assert_eq!(coeff.len(), 1); + assert_eq!(coeff.get(0), Scalar::ONE); + + let basis = LagrangeEvals::::new( + Scalar::from(2u64), + [Scalar::from(3u64), Scalar::from(5u64), Scalar::from(7u64)], + ); + + coeff.extend(&basis); + assert_eq!(coeff.len(), D + 1); + assert_eq!( + coeff.as_slice(), + &[ + Scalar::from(2u64), + Scalar::from(3u64), + Scalar::from(5u64), + Scalar::from(7u64), + ] + ); + } + + // Property: R_2 equals outer product of two basis vectors. + #[test] + #[allow(clippy::needless_range_loop)] + fn test_lagrange_coeff_tensor_product() { + const D: usize = 3; + let base = D + 1; + + let basis1 = LagrangeEvals::::new( + Scalar::from(2u64), + [Scalar::from(3u64), Scalar::from(5u64), Scalar::from(7u64)], + ); + let basis2 = LagrangeEvals::::new( + Scalar::from(11u64), + [ + Scalar::from(13u64), + Scalar::from(17u64), + Scalar::from(19u64), + ], + ); + + let mut coeff = LagrangeCoeff::::new(); + coeff.extend(&basis1); + coeff.extend(&basis2); + + let b1: Vec<_> = basis1.iter_ud_order().collect(); + let b2: Vec<_> = basis2.iter_ud_order().collect(); + for i in 0..base { + for j in 0..base { + assert_eq!(coeff.get(i * base + j), b1[i] * b2[j]); + } + } + } + + // Property: LagrangeCoeff + extended evals matches direct multilinear evaluation. + #[test] + fn test_lagrange_coeff_matches_direct_eval_multilinear() { + const D: usize = 1; + let num_vars = 4; + let mut rng = rand_core::OsRng; + + let evals: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rng)) + .collect(); + let poly = MultilinearPolynomial::new(evals.clone()); + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + + // Check the finite domain point in U_d^4 (only 0 when D=1). + let r = [Scalar::ZERO, Scalar::ZERO, Scalar::ZERO, Scalar::ZERO]; + let direct = evaluate_multilinear(&evals, &r); + let lagrange = factory.eval_extended(&extended, &r); + assert_eq!(direct, lagrange); + + // Check a few random points in F^4. + for _ in 0..3 { + let r = [ + Scalar::random(&mut rng), + Scalar::random(&mut rng), + Scalar::random(&mut rng), + Scalar::random(&mut rng), + ]; + let direct = evaluate_multilinear(&evals, &r); + let lagrange = factory.eval_extended(&extended, &r); + assert_eq!(direct, lagrange); + } + } + + // Property: LagrangeCoeff matches direct eval for product of three multilinear polynomials. 
+ #[test] + fn test_lagrange_coeff_matches_direct_eval_product_of_three() { + const D: usize = 3; + let num_vars = 4; + let mut rng = rand_core::OsRng; + + let evals1: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rng)) + .collect(); + let evals2: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rng)) + .collect(); + let evals3: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rng)) + .collect(); + + let p1 = MultilinearPolynomial::new(evals1.clone()); + let p2 = MultilinearPolynomial::new(evals2.clone()); + let p3 = MultilinearPolynomial::new(evals3.clone()); + + let ext1 = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&p1); + let ext2 = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&p2); + let ext3 = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&p3); + + let mut prod_evals = vec![Scalar::ZERO; ext1.len()]; + #[allow(clippy::needless_range_loop)] + for i in 0..ext1.len() { + prod_evals[i] = ext1.get(i) * ext2.get(i) * ext3.get(i); + } + + let prod_extended = + LagrangeEvaluatedMultilinearPolynomial::::from_evals(prod_evals, num_vars); + let factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + + // Check all finite domain points in U_d^4. + for a in 0..D { + for b in 0..D { + for c in 0..D { + for d in 0..D { + let r = [ + Scalar::from(a as u64), + Scalar::from(b as u64), + Scalar::from(c as u64), + Scalar::from(d as u64), + ]; + let direct = evaluate_multilinear(&evals1, &r) + * evaluate_multilinear(&evals2, &r) + * evaluate_multilinear(&evals3, &r); + let lagrange = factory.eval_extended(&prod_extended, &r); + assert_eq!(direct, lagrange); + } + } + } + } + + // Check a few random points in F^4. + for _ in 0..3 { + let r = [ + Scalar::random(&mut rng), + Scalar::random(&mut rng), + Scalar::random(&mut rng), + Scalar::random(&mut rng), + ]; + let direct = evaluate_multilinear(&evals1, &r) + * evaluate_multilinear(&evals2, &r) + * evaluate_multilinear(&evals3, &r); + let lagrange = factory.eval_extended(&prod_extended, &r); + assert_eq!(direct, lagrange); + } + } +} diff --git a/src/lagrange_accumulator/delay_modular_reduction_mode.rs b/src/lagrange_accumulator/delay_modular_reduction_mode.rs new file mode 100644 index 0000000..607057a --- /dev/null +++ b/src/lagrange_accumulator/delay_modular_reduction_mode.rs @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Delayed Modular Reduction (DMR) mode selection for accumulator building. +//! +//! This module provides compile-time selection between DMR-enabled and DMR-disabled +//! accumulation strategies in [`super::build_accumulators_spartan`]. +//! +//! # Design +//! +//! Uses marker types + trait with associated types to select behavior at compile time: +//! - [`DelayedModularReductionEnabled`]: Accumulates into unreduced wide-limb form, reduces once at end +//! - [`DelayedModularReductionDisabled`]: Immediate Montgomery reduction on each operation (baseline) +//! +//! This gives zero wasted allocations and zero runtime branching overhead. +//! +//! # Traits +//! +//! - [`AccumulateProduct`]: Bridges product types (i64, i128, F) to their unreduced accumulators +//! 
- [`DelayedModularReductionMode`]: Defines element-level accumulation operations for both F×int and F×F + +use super::mat_vec_mle::MatVecMLE; +use crate::small_field::DelayedReduction; +use ff::PrimeField; +use num_traits::Zero; +use std::{ + fmt::Debug, + marker::PhantomData, + ops::{Add, AddAssign, Neg, Sub, SubAssign}, +}; + +// ============================================================================= +// AccumulateProduct trait +// ============================================================================= + +/// Bridges product types to their unreduced accumulators. +/// +/// This trait connects `MLE::Product` (i64, i128, or F) to the appropriate +/// unreduced accumulator type. It enables `DelayedModularReductionEnabled` +/// to work generically with any product type. +/// +/// # Type Parameters +/// +/// - `F`: The field type to reduce to +/// +/// # Implementations +/// +/// - `i64: AccumulateProduct` where `Unreduced = F::UnreducedFieldInt` (for i32 witnesses) +/// - `i128: AccumulateProduct` where `Unreduced = F::UnreducedFieldInt` (for i64 witnesses) +pub trait AccumulateProduct: Copy { + /// The unreduced accumulator type for this product. + /// Has Zero bound so callers can use `.is_zero()` directly. + type Unreduced: Copy + Clone + Default + AddAssign + Send + Sync + Zero; + + /// Accumulate `e × prod` into `acc` (no modular reduction). + fn accumulate(acc: &mut Self::Unreduced, e: &F, prod: Self); + + /// Reduce the accumulator to a field element. + fn reduce(acc: &Self::Unreduced) -> F; +} + +// i64 products (from i32 witnesses) +impl AccumulateProduct for i64 +where + F: DelayedReduction, +{ + type Unreduced = F::UnreducedFieldInt; + + #[inline] + fn accumulate(acc: &mut Self::Unreduced, e: &F, prod: i64) { + F::unreduced_field_int_mul_add(acc, e, prod); + } + + #[inline] + fn reduce(acc: &Self::Unreduced) -> F { + F::reduce_field_int(acc) + } +} + +// i128 products (from i64 witnesses) +impl AccumulateProduct for i128 +where + F: DelayedReduction, +{ + type Unreduced = F::UnreducedFieldInt; + + #[inline] + fn accumulate(acc: &mut Self::Unreduced, e: &F, prod: i128) { + F::unreduced_field_int_mul_add(acc, e, prod); + } + + #[inline] + fn reduce(acc: &Self::Unreduced) -> F { + F::reduce_field_int(acc) + } +} + +// Note: We don't implement AccumulateProduct for F (field products) +// because field witnesses should use DelayedModularReductionDisabled, not DelayedModularReductionEnabled. +// DelayedModularReductionEnabled requires DelayedReduction which isn't available for field witnesses. + +// ============================================================================= +// DelayedModularReductionMode trait +// ============================================================================= + +/// Compile-time selection of delayed modular reduction strategy. +/// +/// This trait defines element-level accumulation operations for both +/// F×int (inner loop) and F×F (scatter phase) accumulation. +/// +/// # Type Parameters +/// +/// - `F`: Field type (e.g., `pallas::Scalar`) +/// - `MLE`: Polynomial type implementing [`MatVecMLE`] +/// - `D`: Polynomial degree bound (2 for Spartan) +/// +/// # Associated Types +/// +/// - `PartialSum`: Element type for F×int accumulation (per-beta) +/// - `ScatterElement`: Element type for F×F accumulation (per-bucket) +pub trait DelayedModularReductionMode: Sized +where + F: PrimeField + Send + Sync, + MLE: MatVecMLE, +{ + /// Element type for F×int partial sums (per-beta accumulation). 
+ /// + /// - `Enabled`: `>::Unreduced` (delayed reduction) + /// - `Disabled`: `F` (immediate reduction) + type PartialSum: Copy + Default + Send; + + /// Element type for F×F scatter buckets. + /// + /// - `Enabled`: `F::UnreducedFieldField` (delayed F×F) + /// - `Disabled`: `F` (immediate F×F) + type ScatterElement: Copy + Default + Send + AddAssign; + + // =========================================================================== + // F×int operations (inner loop) + // =========================================================================== + + /// Accumulate eq-weighted product into partial sum. + /// + /// Called in hot inner loop over x_in for each beta with infinity. + fn accumulate_eq_product(sum: &mut Self::PartialSum, prod: MLE::Product, e: &F); + + /// Reduce partial sum to field element. + fn modular_reduction_partial_sum(sum: &Self::PartialSum) -> F; + + /// Check if partial sum is zero. + fn partial_sum_is_zero(sum: &Self::PartialSum) -> bool; + + // =========================================================================== + // F×F operations (scatter phase) + // =========================================================================== + + /// Accumulate `ey × z_beta` into scatter element. + fn accumulate_scatter(elem: &mut Self::ScatterElement, ey: &F, z_beta: &F); + + /// Reduce scatter element to field element. + fn modular_reduction_scatter(elem: &Self::ScatterElement) -> F; + + /// Check if scatter element is zero. + fn scatter_element_is_zero(elem: &Self::ScatterElement) -> bool; +} + +// ============================================================================= +// DelayedModularReductionEnabled +// ============================================================================= + +/// Marker type for delayed modular reduction (optimized path). +/// +/// Accumulates F×int products into unreduced form during the inner loop, +/// and F×F products into unreduced form during scatter. Reduces once at the end. 
+/// +/// # Type Parameters +/// +/// - `SmallValue`: The small integer witness type (i32 or i64) +pub struct DelayedModularReductionEnabled(PhantomData); + +impl Default for DelayedModularReductionEnabled { + fn default() -> Self { + Self(PhantomData) + } +} + +impl DelayedModularReductionMode + for DelayedModularReductionEnabled +where + F: PrimeField + DelayedReduction + Send + Sync, + SmallValue: Copy + + Clone + + Default + + Debug + + PartialEq + + Eq + + Add + + Sub + + Neg + + AddAssign + + SubAssign + + Send + + Sync, + MLE: MatVecMLE, + MLE::Product: AccumulateProduct, +{ + type PartialSum = >::Unreduced; + type ScatterElement = F::UnreducedFieldField; + + #[inline] + fn accumulate_eq_product(sum: &mut Self::PartialSum, prod: MLE::Product, e: &F) { + MLE::Product::accumulate(sum, e, prod); + } + + #[inline] + fn modular_reduction_partial_sum(sum: &Self::PartialSum) -> F { + MLE::Product::reduce(sum) + } + + #[inline] + fn partial_sum_is_zero(sum: &Self::PartialSum) -> bool { + sum.is_zero() + } + + #[inline] + fn accumulate_scatter(elem: &mut Self::ScatterElement, ey: &F, z_beta: &F) { + F::unreduced_field_field_mul_add(elem, ey, z_beta); + } + + #[inline] + fn modular_reduction_scatter(elem: &Self::ScatterElement) -> F { + F::reduce_field_field(elem) + } + + #[inline] + fn scatter_element_is_zero(elem: &Self::ScatterElement) -> bool { + elem.is_zero() + } +} + +// ============================================================================= +// DelayedModularReductionDisabled +// ============================================================================= + +/// Marker type for immediate reduction (baseline path). +/// +/// Performs Montgomery reduction on each accumulation operation. +/// Used for benchmarking to measure DMR speedup. +pub struct DelayedModularReductionDisabled; + +impl DelayedModularReductionMode + for DelayedModularReductionDisabled +where + F: PrimeField + Send + Sync, + MLE: MatVecMLE, +{ + type PartialSum = F; + type ScatterElement = F; + + #[inline] + fn accumulate_eq_product(sum: &mut F, prod: MLE::Product, e: &F) { + // Immediate reduction: convert product to field element directly + let field_prod = MLE::product_to_field(prod); + *sum += *e * field_prod; + } + + #[inline] + fn modular_reduction_partial_sum(sum: &F) -> F { + *sum // Already reduced + } + + #[inline] + fn partial_sum_is_zero(sum: &F) -> bool { + ff::Field::is_zero(sum).into() + } + + #[inline] + fn accumulate_scatter(elem: &mut F, ey: &F, z_beta: &F) { + *elem += *ey * *z_beta; // Immediate F×F + } + + #[inline] + fn modular_reduction_scatter(elem: &F) -> F { + *elem // Already reduced + } + + #[inline] + fn scatter_element_is_zero(elem: &F) -> bool { + ff::Field::is_zero(elem).into() + } +} diff --git a/src/lagrange_accumulator/domain.rs b/src/lagrange_accumulator/domain.rs new file mode 100644 index 0000000..cd54a78 --- /dev/null +++ b/src/lagrange_accumulator/domain.rs @@ -0,0 +1,513 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Domain types for the Algorithm 6 small-value sum-check optimization. +//! +//! This module defines: +//! - [`LagrangePoint`]: Points in U_d = {∞, 0, 1, ..., d-1} +//! - [`LagrangeHatPoint`]: Points in Û_d = U_d \ {1} (reduced domain) +//! - [`LagrangeIndex`]: Tuples β ∈ U_d^k +//! - [`ValueOneExcluded`]: Error for invalid conversions +//! +//! 
All types are parameterized by `const D: usize` representing the degree bound. +//! This enables compile-time type safety and debug assertions for bounds checking. + +/// A point in the domain U_d = {∞, 0, 1, ..., d-1} +/// +/// The domain has d+1 points. The ∞ point represents evaluation of the +/// leading coefficient (see Lemma 2.2 in the paper). +/// +/// Type parameter `D` is the degree bound, so valid finite values are 0..D-1. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum LagrangePoint { + /// The point at infinity — represents leading coefficient + Infinity, + /// A finite field value 0, 1, ..., D-1 + Finite(usize), +} + +impl LagrangePoint { + /// Base of the domain U_D (= D + 1 points) + pub const BASE: usize = D + 1; + + /// Convert to flat index for array access. + /// Infinity → 0, Finite(v) → v + 1 + #[inline] + pub fn to_index(self) -> usize { + match self { + LagrangePoint::Infinity => 0, + LagrangePoint::Finite(v) => v + 1, + } + } + + /// Convert from flat index. + /// 0 → Infinity, k → Finite(k - 1) + /// + /// # Panics (debug builds only) + /// Panics if idx > D + #[inline] + pub fn from_index(idx: usize) -> Self { + debug_assert!( + idx <= D, + "LagrangePoint::from_index({idx}) out of bounds for D={D}" + ); + if idx == 0 { + LagrangePoint::Infinity + } else { + LagrangePoint::Finite(idx - 1) + } + } + + /// Is this a binary point (0 or 1)? + #[inline] + pub fn is_binary(self) -> bool { + matches!(self, LagrangePoint::Finite(0) | LagrangePoint::Finite(1)) + } + + /// Convert to Û_d point (the reduced domain excluding value 1). + /// + /// Returns `None` for Finite(1) since 1 ∉ Û_d. + #[inline] + pub fn to_ud_hat(self) -> Option> { + LagrangeHatPoint::try_from(self).ok() + } +} + +/// Test-only helper methods for LagrangePoint. +#[cfg(test)] +impl LagrangePoint { + /// Convert to field element. Returns `None` for Infinity. + #[inline] + pub fn to_field(self) -> Option { + match self { + LagrangePoint::Infinity => None, + LagrangePoint::Finite(v) => Some(F::from(v as u64)), + } + } +} + +/// Error returned when trying to convert `Finite(1)` to `LagrangeHatPoint`. +/// +/// The value 1 is excluded from Û_d because s(1) can be recovered +/// from the sum-check constraint s(0) + s(1) = claim. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ValueOneExcluded; + +impl std::fmt::Display for ValueOneExcluded { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "value 1 is not in Û_d (excluded from reduced domain)") + } +} + +impl std::error::Error for ValueOneExcluded {} + +/// A point in the reduced domain Û_d = U_d \ {1} = {∞, 0, 2, 3, ..., d-1} +/// +/// This domain has d elements (one less than U_d). +/// Value 1 is excluded because s(1) can be recovered from s(0) + s(1) = claim. +/// +/// Type parameter `D` is the degree bound (size of Û_d). +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum LagrangeHatPoint { + /// The point at infinity — represents leading coefficient + Infinity, + /// A finite field value: 0, 2, 3, ... (never 1) + Finite(usize), +} + +impl LagrangeHatPoint { + /// Convert to array index. + /// Mapping: ∞ → 0, 0 → 1, 2 → 2, 3 → 3, ... + #[inline] + pub fn to_index(self) -> usize { + match self { + LagrangeHatPoint::Infinity => 0, + LagrangeHatPoint::Finite(0) => 1, + LagrangeHatPoint::Finite(k) => k, // 2→2, 3→3, etc. 
+ } + } + + /// Convert to LagrangePoint (U_d point) + #[inline] + pub fn to_ud_point(self) -> LagrangePoint { + match self { + LagrangeHatPoint::Infinity => LagrangePoint::Infinity, + LagrangeHatPoint::Finite(v) => LagrangePoint::Finite(v), + } + } +} + +/// Test-only helper methods for LagrangeHatPoint. +#[cfg(test)] +impl LagrangeHatPoint { + /// Create a finite point. Returns None for v=1 (not in Û_d). + pub fn finite(v: usize) -> Option { + if v == 1 { + None + } else { + debug_assert!( + v < D, + "LagrangeHatPoint::finite({v}) out of bounds for D={D}" + ); + Some(LagrangeHatPoint::Finite(v)) + } + } + + /// Create from array index. + /// Mapping: 0 → ∞, 1 → 0, 2 → 2, 3 → 3, ... + #[inline] + pub fn from_index(idx: usize) -> Self { + debug_assert!( + idx < D, + "LagrangeHatPoint::from_index({idx}) out of bounds for D={D}" + ); + match idx { + 0 => LagrangeHatPoint::Infinity, + 1 => LagrangeHatPoint::Finite(0), + k => LagrangeHatPoint::Finite(k), + } + } + + /// Iterate over all points in Û_d. + /// Yields: ∞, 0, 2, 3, ..., D-1 (total of D elements) + pub fn iter() -> impl Iterator> { + (0..D).map(LagrangeHatPoint::from_index) + } +} + +// === Trait Implementations === + +impl From> for LagrangePoint { + fn from(p: LagrangeHatPoint) -> Self { + p.to_ud_point() + } +} + +impl TryFrom> for LagrangeHatPoint { + type Error = ValueOneExcluded; + + fn try_from(p: LagrangePoint) -> Result { + match p { + LagrangePoint::Infinity => Ok(LagrangeHatPoint::Infinity), + LagrangePoint::Finite(1) => Err(ValueOneExcluded), + LagrangePoint::Finite(v) => Ok(LagrangeHatPoint::Finite(v)), + } + } +} + +/// A tuple β ∈ U_d^k — an index into the extended domain. +/// +/// Used to index into LagrangeEvaluatedMultilinearPolynomial which stores evaluations over U_d^ℓ₀. +/// +/// Type parameter `D` is the degree bound (U_D has D+1 points). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LagrangeIndex(pub Vec>); + +impl LagrangeIndex { + /// Base of the domain U_D (= D + 1) + pub const BASE: usize = D + 1; + + /// Number of coordinates + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns true if there are no coordinates. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Convert from flat index (mixed-radix decoding) + /// + /// Uses compile-time BASE = D + 1 + pub fn from_flat_index(mut idx: usize, len: usize) -> Self { + let mut points = vec![LagrangePoint::Infinity; len]; + for i in (0..len).rev() { + points[i] = LagrangePoint::from_index(idx % Self::BASE); + idx /= Self::BASE; + } + LagrangeIndex(points) + } +} + +/// Test-only helper methods for LagrangeIndex. +#[cfg(test)] +impl LagrangeIndex { + /// Convert to flat index for array access (mixed-radix encoding) + /// + /// Uses compile-time BASE = D + 1 + pub fn to_flat_index(&self) -> usize { + self + .0 + .iter() + .fold(0, |acc, p| acc * Self::BASE + p.to_index()) + } + + /// Check if all coordinates are binary (0 or 1, no ∞) + pub fn is_all_binary(&self) -> bool { + self.0.iter().all(|p| p.is_binary()) + } + + /// Check if any coordinate is ∞ + pub fn has_infinity(&self) -> bool { + self.0.iter().any(|p| matches!(p, LagrangePoint::Infinity)) + } + + /// Create a LagrangeIndex from a binary index in {0,1}^num_bits. 
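+
+// Stand-alone sketch (illustration only, not part of this file): the flat index is
+// plain mixed-radix with base D + 1 over the per-coordinate indices (∞ → 0,
+// Finite(k) → k + 1). This mirrors the (∞, 0, 1) ↔ 6 example checked in
+// test_tuple_specific_encoding further below.
+fn encode(digits: &[usize], base: usize) -> usize {
+    digits.iter().fold(0, |acc, &d| acc * base + d)
+}
+fn decode(mut idx: usize, base: usize, len: usize) -> Vec<usize> {
+    let mut digits = vec![0; len];
+    for i in (0..len).rev() {
+        digits[i] = idx % base;
+        idx /= base;
+    }
+    digits
+}
+fn main() {
+    let base = 4; // D = 3
+    assert_eq!(encode(&[0, 1, 2], base), 6); // 0·16 + 1·4 + 2
+    assert_eq!(decode(6, base, 3), vec![0, 1, 2]);
+}
+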
+ #[inline] + pub fn from_binary(bits: usize, num_bits: usize) -> Self { + let mut points = Vec::with_capacity(num_bits); + for j in 0..num_bits { + let bit = (bits >> (num_bits - 1 - j)) & 1; + points.push(LagrangePoint::Finite(bit)); + } + LagrangeIndex(points) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::pasta::pallas; + use ff::Field; + + type Scalar = pallas::Scalar; + + // === LagrangePoint tests === + + #[test] + fn test_ud_point_index_roundtrip() { + // Test all points for D=4 (indices 0..5) + for idx in 0..5 { + let p = LagrangePoint::<4>::from_index(idx); + assert_eq!(p.to_index(), idx); + } + + // Test specific values + assert_eq!(LagrangePoint::<3>::Infinity.to_index(), 0); + assert_eq!(LagrangePoint::<3>::Finite(0).to_index(), 1); + assert_eq!(LagrangePoint::<3>::Finite(1).to_index(), 2); + assert_eq!(LagrangePoint::<3>::Finite(2).to_index(), 3); + } + + #[test] + fn test_ud_point_is_binary() { + assert!(!LagrangePoint::<3>::Infinity.is_binary()); + assert!(LagrangePoint::<3>::Finite(0).is_binary()); + assert!(LagrangePoint::<3>::Finite(1).is_binary()); + assert!(!LagrangePoint::<3>::Finite(2).is_binary()); + } + + #[test] + fn test_ud_point_to_field() { + assert_eq!(LagrangePoint::<3>::Infinity.to_field::(), None); + assert_eq!( + LagrangePoint::<3>::Finite(0).to_field::(), + Some(Scalar::ZERO) + ); + assert_eq!( + LagrangePoint::<3>::Finite(1).to_field::(), + Some(Scalar::ONE) + ); + assert_eq!( + LagrangePoint::<3>::Finite(2).to_field::(), + Some(Scalar::from(2u64)) + ); + } + + #[test] + fn test_ud_point_base_const() { + assert_eq!(LagrangePoint::<3>::BASE, 4); + assert_eq!(LagrangePoint::<4>::BASE, 5); + } + + // === LagrangeHatPoint tests === + + #[test] + fn test_ud_hat_index_roundtrip() { + // Test all points for D=4 (indices 0..4) + for idx in 0..4 { + let p = LagrangeHatPoint::<4>::from_index(idx); + assert_eq!(p.to_index(), idx); + } + } + + #[test] + fn test_ud_hat_index_mapping() { + // Verify exact mapping: ∞→0, 0→1, 2→2, 3→3 + assert_eq!(LagrangeHatPoint::<4>::Infinity.to_index(), 0); + assert_eq!(LagrangeHatPoint::<4>::Finite(0).to_index(), 1); + assert_eq!(LagrangeHatPoint::<4>::Finite(2).to_index(), 2); + assert_eq!(LagrangeHatPoint::<4>::Finite(3).to_index(), 3); + + // Reverse mapping + assert_eq!( + LagrangeHatPoint::<4>::from_index(0), + LagrangeHatPoint::Infinity + ); + assert_eq!( + LagrangeHatPoint::<4>::from_index(1), + LagrangeHatPoint::Finite(0) + ); + assert_eq!( + LagrangeHatPoint::<4>::from_index(2), + LagrangeHatPoint::Finite(2) + ); + assert_eq!( + LagrangeHatPoint::<4>::from_index(3), + LagrangeHatPoint::Finite(3) + ); + } + + #[test] + fn test_ud_hat_iter() { + // For D=3, Û_d = {∞, 0, 2} + let points: Vec<_> = LagrangeHatPoint::<3>::iter().collect(); + assert_eq!(points.len(), 3); + assert_eq!(points[0], LagrangeHatPoint::Infinity); + assert_eq!(points[1], LagrangeHatPoint::Finite(0)); + assert_eq!(points[2], LagrangeHatPoint::Finite(2)); + } + + #[test] + fn test_ud_hat_finite_one_rejected() { + assert!(LagrangeHatPoint::<3>::finite(0).is_some()); + assert!(LagrangeHatPoint::<3>::finite(1).is_none()); // 1 not in Û_d + assert!(LagrangeHatPoint::<3>::finite(2).is_some()); + } + + // === Conversion tests === + + #[test] + fn test_ud_to_ud_hat() { + assert_eq!( + LagrangeHatPoint::<3>::try_from(LagrangePoint::<3>::Infinity), + Ok(LagrangeHatPoint::Infinity) + ); + assert_eq!( + LagrangeHatPoint::<3>::try_from(LagrangePoint::<3>::Finite(0)), + Ok(LagrangeHatPoint::Finite(0)) + ); + assert_eq!( + 
LagrangeHatPoint::<3>::try_from(LagrangePoint::<3>::Finite(1)), + Err(ValueOneExcluded) + ); + assert_eq!( + LagrangeHatPoint::<3>::try_from(LagrangePoint::<3>::Finite(2)), + Ok(LagrangeHatPoint::Finite(2)) + ); + } + + #[test] + fn test_ud_hat_to_ud() { + // Via From trait + assert_eq!( + LagrangePoint::<3>::from(LagrangeHatPoint::<3>::Infinity), + LagrangePoint::Infinity + ); + assert_eq!( + LagrangePoint::<3>::from(LagrangeHatPoint::<3>::Finite(0)), + LagrangePoint::Finite(0) + ); + assert_eq!( + LagrangePoint::<3>::from(LagrangeHatPoint::<3>::Finite(2)), + LagrangePoint::Finite(2) + ); + + // Roundtrip for valid points + let valid_points = [ + LagrangePoint::<3>::Infinity, + LagrangePoint::<3>::Finite(0), + LagrangePoint::<3>::Finite(2), + ]; + for p in valid_points { + let hat = LagrangeHatPoint::try_from(p).unwrap(); + assert_eq!(LagrangePoint::from(hat), p); + } + } + + // === LagrangeIndex tests === + + #[test] + fn test_tuple_flat_index_roundtrip() { + let len: usize = 3; + + // Test all tuples in U_4^3 (D=3, BASE=4) + for idx in 0..LagrangeIndex::<3>::BASE.pow(len as u32) { + let tuple = LagrangeIndex::<3>::from_flat_index(idx, len); + assert_eq!(tuple.to_flat_index(), idx); + assert_eq!(tuple.len(), len); + } + } + + #[test] + fn test_tuple_base_const() { + assert_eq!(LagrangeIndex::<3>::BASE, 4); + assert_eq!(LagrangeIndex::<4>::BASE, 5); + } + + #[test] + fn test_tuple_is_all_binary() { + // [0, 1, 0] - all binary + let binary = LagrangeIndex::<3>(vec![ + LagrangePoint::Finite(0), + LagrangePoint::Finite(1), + LagrangePoint::Finite(0), + ]); + assert!(binary.is_all_binary()); + + // [0, ∞, 1] - has infinity + let has_inf = LagrangeIndex::<3>(vec![ + LagrangePoint::Finite(0), + LagrangePoint::Infinity, + LagrangePoint::Finite(1), + ]); + assert!(!has_inf.is_all_binary()); + + // [0, 2, 1] - has non-binary finite + let has_two = LagrangeIndex::<3>(vec![ + LagrangePoint::Finite(0), + LagrangePoint::Finite(2), + LagrangePoint::Finite(1), + ]); + assert!(!has_two.is_all_binary()); + } + + #[test] + fn test_tuple_has_infinity() { + // [0, ∞, 1] - has infinity + let has_inf = LagrangeIndex::<3>(vec![ + LagrangePoint::Finite(0), + LagrangePoint::Infinity, + LagrangePoint::Finite(1), + ]); + assert!(has_inf.has_infinity()); + + // [0, 1, 2] - no infinity + let no_inf = LagrangeIndex::<3>(vec![ + LagrangePoint::Finite(0), + LagrangePoint::Finite(1), + LagrangePoint::Finite(2), + ]); + assert!(!no_inf.has_infinity()); + } + + #[test] + fn test_tuple_specific_encoding() { + // For D=3 (BASE=4), test specific encodings + // Tuple (∞, 0, 1) = (idx 0, idx 1, idx 2) -> 0*16 + 1*4 + 2 = 6 + let tuple = LagrangeIndex::<3>(vec![ + LagrangePoint::Infinity, + LagrangePoint::Finite(0), + LagrangePoint::Finite(1), + ]); + assert_eq!(tuple.to_flat_index(), 6); + + // Reverse: 6 -> (0, 1, 2) -> (∞, 0, 1) + let decoded = LagrangeIndex::<3>::from_flat_index(6, 3); + assert_eq!(decoded, tuple); + } +} diff --git a/src/lagrange_accumulator/eq_round.rs b/src/lagrange_accumulator/eq_round.rs new file mode 100644 index 0000000..c1c91c1 --- /dev/null +++ b/src/lagrange_accumulator/eq_round.rs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Utilities for computing the per-round linear equality factor in sum-check. 
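+
+// Stand-alone worked example (illustration only, not part of this file) of the
+// relations implemented below, with f64 standing in for field elements:
+// ℓ_i(0) = α(1 − w), ℓ_i(1) = αw, ℓ_i(∞) = α(2w − 1), and t(1) is recovered from
+// claim = ℓ_i(0)·t(0) + ℓ_i(1)·t(1).
+fn main() {
+    let (alpha, w) = (1.0_f64, 0.25_f64);
+    let (l0, l1, linf) = (alpha * (1.0 - w), alpha * w, alpha * (2.0 * w - 1.0));
+    assert!((l0 + l1 - alpha).abs() < 1e-12); // ℓ(0) + ℓ(1) = α
+    assert!((linf - (l1 - l0)).abs() < 1e-12); // ℓ(∞) = ℓ(1) − ℓ(0)
+
+    // derive_t1: solve claim = ℓ(0)·t(0) + ℓ(1)·t(1) for t(1), assuming ℓ(1) ≠ 0.
+    let (t0, claim) = (8.0_f64, 10.0_f64);
+    let t1 = (claim - l0 * t0) / l1;
+    assert!((l0 * t0 + l1 * t1 - claim).abs() < 1e-12);
+}
+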
+ +use super::evals::LagrangeEvals; +use ff::PrimeField; + +/// Derives t_i(1) using the sumcheck relation: claim = ℓ_i(0)·t(0) + ℓ_i(1)·t(1). +/// +/// Returns `None` if `l1` is zero (non-invertible). +pub fn derive_t1(l0: F, l1: F, claim_prev: F, t0: F) -> Option { + let s0 = l0 * t0; + let s1 = claim_prev - s0; + l1.invert().into_option().map(|inv| s1 * inv) +} + +/// Tracks α_i = eqe(w_{ { + alpha: F, +} + +impl Default for EqRoundFactor { + fn default() -> Self { + Self::new() + } +} + +impl EqRoundFactor { + /// Creates a new tracker with α_0 = 1. + pub fn new() -> Self { + Self { alpha: F::ONE } + } + + /// Returns ℓ_i evaluated at U_2 = {∞, 0, 1} for the provided w_i. + /// + /// - `infinity` = ℓ_i(∞) = α_i · (2w_i − 1) + /// - `finite[0]` = ℓ_i(0) = α_i · (1 − w_i) + /// - `finite[1]` = ℓ_i(1) = α_i · w_i + pub fn values(&self, w_i: F) -> LagrangeEvals { + let l0 = self.alpha * (F::ONE - w_i); + let l1 = self.alpha * w_i; + let linf = self.alpha * (w_i.double() - F::ONE); + LagrangeEvals::new(linf, [l0, l1]) + } + + /// Advances α using ℓ_i(r_i) = linf * r_i + l0. + pub fn advance(&mut self, li: &LagrangeEvals, r_i: F) { + self.alpha = li.eval_linear_at(r_i); + } +} + +/// Test-only helper methods for EqRoundFactor. +#[cfg(test)] +impl EqRoundFactor { + /// Returns the current prefix product α_i. + pub(crate) fn alpha(&self) -> F { + self.alpha + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::pasta::pallas; + use ff::Field; + + type F = pallas::Scalar; + + fn eqe_bit(w: F, x: F) -> F { + F::ONE - w - x + (w * x).double() + } + + // Round 0 has alpha = 1; check the base formulas for l0, l1, linf. + #[test] + fn test_values_round0_basic() { + let w = F::from(3u64); + let tracker = EqRoundFactor::new(); + let v = tracker.values(w); + + assert_eq!(v.at_zero(), F::ONE - w); + assert_eq!(v.at_one(), w); + assert_eq!(v.at_infinity(), w.double() - F::ONE); + } + + // Invariants: l0 + l1 = alpha and linf = l1 - l0 after one advance. + #[test] + fn test_values_relations_hold() { + let w0 = F::from(7u64); + let w1 = F::from(5u64); + let r0 = F::from(9u64); + let mut tracker = EqRoundFactor::new(); + + let v0 = tracker.values(w0); + tracker.advance(&v0, r0); + + let v1 = tracker.values(w1); + assert_eq!(v1.at_zero() + v1.at_one(), tracker.alpha()); + assert_eq!(v1.at_infinity(), v1.at_one() - v1.at_zero()); + } + + // eval_linear_at(u) should agree with stored points and derived l(2). + #[test] + fn test_li_at_matches_values() { + let w = F::from(11u64); + let tracker = EqRoundFactor::new(); + let v = tracker.values(w); + + assert_eq!(v.eval_linear_at(F::ZERO), v.at_zero()); + assert_eq!(v.eval_linear_at(F::ONE), v.at_one()); + assert_eq!( + v.eval_linear_at(F::from(2u64)), + v.at_infinity().double() + v.at_zero() + ); + } + + // advance should update alpha by eqe(w, r). + #[test] + fn test_advance_updates_alpha() { + let w0 = F::from(4u64); + let r0 = F::from(6u64); + let mut tracker = EqRoundFactor::new(); + let alpha0 = tracker.alpha(); + + let v0 = tracker.values(w0); + tracker.advance(&v0, r0); + + let expected = alpha0 * eqe_bit(w0, r0); + assert_eq!(tracker.alpha(), expected); + } + + // Repeated updates should match the product of eqe(w_i, r_i). 
+ #[test] + fn test_alpha_matches_product() { + let taus = vec![F::from(2u64), F::from(5u64), F::from(8u64)]; + let rs = vec![F::from(3u64), F::from(4u64), F::from(7u64)]; + let mut tracker = EqRoundFactor::new(); + + let mut expected = F::ONE; + for (tau, r) in taus.into_iter().zip(rs.into_iter()) { + let v = tracker.values(tau); + tracker.advance(&v, r); + expected *= eqe_bit(tau, r); + assert_eq!(tracker.alpha(), expected); + } + } + + // Degenerate endpoints: w=0 and w=1 should yield expected l0/l1/linf. + #[test] + fn test_values_degenerate_case_for_w_zero_and_one() { + let tracker = EqRoundFactor::::new(); + + let v0 = tracker.values(F::ZERO); + assert_eq!(v0.at_zero(), F::ONE); + assert_eq!(v0.at_one(), F::ZERO); + assert_eq!(v0.at_infinity(), -F::ONE); + + let v1 = tracker.values(F::ONE); + assert_eq!(v1.at_zero(), F::ZERO); + assert_eq!(v1.at_one(), F::ONE); + assert_eq!(v1.at_infinity(), F::ONE); + } + + // For w = 1/2, slope should be zero (linf = 0). + #[test] + fn test_slope_zero_at_half() { + let half = F::from(2u64).invert().unwrap(); + let tracker = EqRoundFactor::new(); + let v = tracker.values(half); + + assert_eq!(v.at_infinity(), F::ZERO); + assert_eq!(v.at_zero() + v.at_one(), F::ONE); + } + + // derive_t1 should return s1 / l1 for non-zero l1. + #[test] + fn test_derive_t1_returns_value() { + let l0 = F::from(2u64); + let l1 = F::from(5u64); + let t0 = F::from(11u64); + let claim = F::from(97u64); + + let s0 = l0 * t0; + let s1 = claim - s0; + let expected = s1 * l1.invert().unwrap(); + + assert_eq!(derive_t1(l0, l1, claim, t0), Some(expected)); + } + + // derive_t1 should return None when l1 == 0. + #[test] + fn test_derive_t1_returns_none_on_zero_l1() { + let l0 = F::from(3u64); + let l1 = F::ZERO; + let t0 = F::from(4u64); + let claim = F::from(10u64); + + assert_eq!(derive_t1(l0, l1, claim, t0), None); + } +} diff --git a/src/lagrange_accumulator/evals.rs b/src/lagrange_accumulator/evals.rs new file mode 100644 index 0000000..e1f3d5c --- /dev/null +++ b/src/lagrange_accumulator/evals.rs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Evaluation containers for Lagrange domains U_d and Û_d. + +use ff::PrimeField; + +#[cfg(test)] +use super::domain::LagrangePoint; + +/// Evaluations at all D+1 points of U_d = {∞, 0, 1, ..., D-1}. +/// +/// This type stores values indexed by [`LagrangePoint`], with the infinity +/// point stored separately from the D finite points. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct LagrangeEvals { + /// Value at the infinity point + pub infinity: T, + /// Values at finite points 0, 1, ..., D-1 + pub finite: [T; D], +} + +impl LagrangeEvals { + /// Create new evaluations from infinity and finite values. + #[inline] + pub fn new(infinity: T, finite: [T; D]) -> Self { + Self { infinity, finite } + } + + /// Get value at infinity. + #[inline] + pub fn at_infinity(&self) -> T { + self.infinity + } + + /// Get value at zero (finite point 0). + #[inline] + pub fn at_zero(&self) -> T { + self.finite[0] + } + + /// Get value at one (finite point 1). + /// + /// # Panics (debug builds only) + /// Panics if D < 2. + #[inline] + pub fn at_one(&self) -> T { + debug_assert!(D >= 2, "at_one() requires D >= 2"); + self.finite[1] + } + + /// Iterate values in U_d order: [∞, 0, 1, ..., D-1]. 
+ pub fn iter_ud_order(&self) -> impl Iterator + '_ { + std::iter::once(self.infinity).chain(self.finite.iter().copied()) + } +} + +/// Test-only helper methods for LagrangeEvals. +#[cfg(test)] +impl LagrangeEvals { + /// Get value at a domain point. + #[inline] + pub fn get(&self, p: LagrangePoint) -> T { + match p { + LagrangePoint::Infinity => self.infinity, + LagrangePoint::Finite(k) => self.finite[k], + } + } +} + +impl LagrangeEvals { + /// Evaluate linear polynomial at u: L(u) = infinity * u + finite[0]. + /// + /// For evaluations of a degree-1 polynomial over U_2 = {∞, 0, 1}, + /// this computes L(u) = l_∞ · u + l_0. + #[inline] + pub fn eval_linear_at(&self, u: F) -> F { + self.infinity * u + self.finite[0] + } +} + +/// Evaluations at all D points of Û_d = U_d \ {1} = {∞, 0, 2, ..., D-1}. +/// +/// This reduced domain excludes point 1 because s(1) can be recovered +/// from the sum-check constraint s(0) + s(1) = claim. +/// +/// Indexing follows [`LagrangeHatPoint::to_index()`]: ∞→0, 0→1, 2→2, 3→3, ... +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct LagrangeHatEvals { + data: [T; D], +} + +impl LagrangeHatEvals { + /// Create from array indexed by `LagrangeHatPoint::to_index()`. + #[inline] + pub fn from_array(data: [T; D]) -> Self { + Self { data } + } + + /// Get value at infinity (index 0). + #[inline] + pub fn at_infinity(&self) -> T { + self.data[0] + } + + /// Get value at zero (index 1). + #[inline] + pub fn at_zero(&self) -> T { + self.data[1] + } +} diff --git a/src/lagrange_accumulator/extension.rs b/src/lagrange_accumulator/extension.rs new file mode 100644 index 0000000..24fe88a --- /dev/null +++ b/src/lagrange_accumulator/extension.rs @@ -0,0 +1,702 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Procedure 6: Extension of multilinear polynomial evaluations from +//! boolean hypercube {0,1}^ℓ to Lagrange domain U_D^ℓ. + +#[cfg(test)] +use super::domain::LagrangeIndex; +use std::ops::{Add, Sub}; + +// ============================================================================ +// Helper functions for Lagrange extension +// ============================================================================ + +/// Extend a single suffix element from boolean to Lagrange domain. +/// Original implementation, kept for clarity and remainder handling. +#[inline(always)] +fn extend_single( + src: &[T], + dst: &mut [T], + base_src: usize, + base_dst: usize, + suffix_count: usize, + suffix_idx: usize, +) where + T: Copy + Default + Add + Sub, +{ + let p0 = src[base_src + suffix_idx]; + let p1 = src[base_src + suffix_count + suffix_idx]; + let diff = p1 - p0; + + // γ = ∞ (index 0) + dst[base_dst + suffix_idx] = diff; + // γ = 0 (index 1) + dst[base_dst + suffix_count + suffix_idx] = p0; + + if D >= 2 { + // γ = 1 (index 2) + dst[base_dst + 2 * suffix_count + suffix_idx] = p1; + // γ = 2..D-1: extrapolate + let mut val = p1; + for k in 2..D { + val = val + diff; + dst[base_dst + (k + 1) * suffix_count + suffix_idx] = val; + } + } +} + +/// Extend 4 consecutive suffix elements with ILP optimization. +/// Interleaves operations across 4 independent elements for better +/// instruction-level parallelism on modern CPUs (especially AArch64). 
+#[inline(always)] +fn extend_batch4( + src: &[T], + dst: &mut [T], + base_src: usize, + base_dst: usize, + suffix_count: usize, + s: usize, +) where + T: Copy + Default + Add + Sub, +{ + // Load 4 pairs (contiguous reads) + let p0_0 = src[base_src + s]; + let p0_1 = src[base_src + s + 1]; + let p0_2 = src[base_src + s + 2]; + let p0_3 = src[base_src + s + 3]; + + let p1_0 = src[base_src + suffix_count + s]; + let p1_1 = src[base_src + suffix_count + s + 1]; + let p1_2 = src[base_src + suffix_count + s + 2]; + let p1_3 = src[base_src + suffix_count + s + 3]; + + // 4 independent diffs (ILP) + let d0 = p1_0 - p0_0; + let d1 = p1_1 - p0_1; + let d2 = p1_2 - p0_2; + let d3 = p1_3 - p0_3; + + // γ = ∞ (4 contiguous writes) + dst[base_dst + s] = d0; + dst[base_dst + s + 1] = d1; + dst[base_dst + s + 2] = d2; + dst[base_dst + s + 3] = d3; + + // γ = 0 (4 contiguous writes) + dst[base_dst + suffix_count + s] = p0_0; + dst[base_dst + suffix_count + s + 1] = p0_1; + dst[base_dst + suffix_count + s + 2] = p0_2; + dst[base_dst + suffix_count + s + 3] = p0_3; + + if D >= 2 { + // γ = 1 (4 contiguous writes) + dst[base_dst + 2 * suffix_count + s] = p1_0; + dst[base_dst + 2 * suffix_count + s + 1] = p1_1; + dst[base_dst + 2 * suffix_count + s + 2] = p1_2; + dst[base_dst + 2 * suffix_count + s + 3] = p1_3; + + // γ = 2..D-1: extrapolate (4 at a time) + if D > 2 { + let (mut v0, mut v1, mut v2, mut v3) = (p1_0, p1_1, p1_2, p1_3); + for k in 2..D { + v0 = v0 + d0; + v1 = v1 + d1; + v2 = v2 + d2; + v3 = v3 + d3; + let offset = (k + 1) * suffix_count + s; + dst[base_dst + offset] = v0; + dst[base_dst + offset + 1] = v1; + dst[base_dst + offset + 2] = v2; + dst[base_dst + offset + 3] = v3; + } + } + } +} + +#[cfg(test)] +use crate::polys::multilinear::MultilinearPolynomial; +#[cfg(test)] +use crate::small_field::SmallValueField; +#[cfg(test)] +use ff::PrimeField; + +/// Multilinear polynomial evaluations extended to the Lagrange domain U_D^ℓ. +/// +/// Stores evaluations at all (D+1)^num_vars points of the extended domain, +/// indexed by `LagrangeIndex`. +pub struct LagrangeEvaluatedMultilinearPolynomial +where + T: Copy + Default + Add + Sub, +{ + #[allow(dead_code)] // Used by test-only methods (get, get_by_domain, len) + evals: Vec, // size (D+1)^num_vars + #[allow(dead_code)] // Used by test-only num_vars() method + num_vars: usize, +} + +impl LagrangeEvaluatedMultilinearPolynomial +where + T: Copy + Default + Add + Sub, +{ + /// Base of the extended domain U_D (= D + 1) + const BASE: usize = D + 1; + + /// Extend boolean hypercube evaluations to Lagrange domain in-place. + /// + /// Returns the number of valid elements in `buf_curr` (= (D+1)^num_vars). + /// After this call, `buf_curr[..result]` contains the extended evaluations. + /// + /// This is the zero-allocation version - caller reads results directly from `buf_curr`. + /// + /// # Arguments + /// * `input` - Boolean hypercube evaluations (read-only slice, length must be power of 2) + /// * `buf_curr` - Result buffer, will contain extended evaluations after call + /// * `buf_scratch` - Scratch buffer used during iterative extension + /// + /// Both buffers will be resized if needed to (D+1)^num_vars. 
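+
+// Stand-alone sketch (illustration only, not part of this file): the extension
+// below grows the live region from 2^ℓ to (D + 1)^ℓ. After step j it occupies
+// (D + 1)^j · 2^(ℓ−j) slots, which never shrinks, so resizing both buffers to the
+// final (D + 1)^ℓ up front is sufficient.
+fn main() {
+    let (base, num_vars) = (4usize, 3usize); // D = 3, ℓ = 3
+    let sizes: Vec<usize> = (0..=num_vars)
+        .map(|j| base.pow(j as u32) * (1usize << (num_vars - j)))
+        .collect();
+    assert_eq!(sizes, vec![8, 16, 32, 64]); // ends at (D + 1)^ℓ = 4^3
+}
+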
+ pub fn extend_in_place(input: &[T], buf_curr: &mut Vec, buf_scratch: &mut Vec) -> usize { + let num_vars = input.len().trailing_zeros() as usize; + debug_assert_eq!(input.len(), 1 << num_vars, "Input size must be power of 2"); + + if num_vars == 0 { + // Single element: copy to buf_curr and return + if buf_curr.is_empty() { + buf_curr.push(T::default()); + } + buf_curr[0] = input[0]; + return 1; + } + + let final_size = Self::BASE.pow(num_vars as u32); + + // Ensure buffers are large enough + if buf_curr.len() < final_size { + buf_curr.resize(final_size, T::default()); + } + if buf_scratch.len() < final_size { + buf_scratch.resize(final_size, T::default()); + } + + // Copy input into buf_curr to start + buf_curr[..input.len()].copy_from_slice(input); + + for j in 1..=num_vars { + // At step j: + // - prefix_count = (D+1)^{j-1} extended prefix combinations + // - suffix_count = 2^{num_vars-j} remaining boolean suffix combinations + let prefix_count = Self::BASE.pow((j - 1) as u32); + let suffix_count = 1usize << (num_vars - j); + // Current layout: prefix_count rows × 2 boolean values × suffix_count elements + let current_stride = 2 * suffix_count; + // Next layout: prefix_count rows × (D+1) domain values × suffix_count elements + let next_stride = Self::BASE * suffix_count; + + // Alternate between buffers each iteration + let (src, dst) = if j % 2 == 1 { + (&buf_curr[..], &mut buf_scratch[..]) + } else { + (&buf_scratch[..], &mut buf_curr[..]) + }; + + for prefix_idx in 0..prefix_count { + let base_src = prefix_idx * current_stride; + let base_dst = prefix_idx * next_stride; + let mut s = 0; + + // Process 4 suffix elements at a time for ILP + while s + 4 <= suffix_count { + extend_batch4::(src, dst, base_src, base_dst, suffix_count, s); + s += 4; + } + + // Handle remainder (0-3 elements) + while s < suffix_count { + extend_single::(src, dst, base_src, base_dst, suffix_count, s); + s += 1; + } + } + } + + // Ensure result ends up in buf_curr (swap if result is currently in buf_scratch) + if num_vars % 2 == 1 { + std::mem::swap(buf_curr, buf_scratch); + } + final_size + } +} + +/// Test-only helper methods for LagrangeEvaluatedMultilinearPolynomial. +#[cfg(test)] +impl LagrangeEvaluatedMultilinearPolynomial +where + T: Copy + Default + Add + Sub, +{ + /// Procedure 6: Extend polynomial evaluations from {0,1}^ℓ₀ to U_D^ℓ₀. 
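+
+// Stand-alone sketch (illustration only, not part of this file) of one variable of
+// the extension rule used here and by extend_in_place above, over i64:
+// p(∞) = p1 − p0, p(0) = p0, p(1) = p1, and p(k) = p0 + k·(p1 − p0) by repeated
+// addition, stored in U_d index order [∞, 0, 1, ..., D−1].
+fn extend_one_var(p0: i64, p1: i64, d: usize) -> Vec<i64> {
+    let diff = p1 - p0;
+    let mut out = vec![diff, p0]; // indices 0 (∞) and 1 (γ = 0)
+    let mut val = p0;
+    for _ in 1..d {
+        val += diff; // γ = 1, 2, ..., d − 1
+        out.push(val);
+    }
+    out
+}
+fn main() {
+    // Matches test_extend_single_var below: [7, 19] extends to [12, 7, 19, 31] for D = 3.
+    assert_eq!(extend_one_var(7, 19, 3), vec![12, 7, 19, 31]);
+}
+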
+ pub fn from_boolean_evals(input: &[T]) -> Self { + let num_vars = input.len().trailing_zeros() as usize; + debug_assert_eq!(input.len(), 1 << num_vars, "Input size must be power of 2"); + + let mut current = input.to_vec(); + + for j in 1..=num_vars { + // At step j: + // - prefix_count = (D+1)^{j-1} (number of extended prefix combinations) + // - suffix_count = 2^{num_vars-j} (number of remaining boolean suffix combinations) + // - current has size = prefix_count × 2 × suffix_count + // - next will have size = prefix_count × (D+1) × suffix_count + + let prefix_count = Self::BASE.pow((j - 1) as u32); + let suffix_count = 1usize << (num_vars - j); + let current_stride = 2 * suffix_count; // stride between prefixes in current + let next_stride = Self::BASE * suffix_count; // stride between prefixes in next + + let next_size = prefix_count * next_stride; + let mut next = vec![T::default(); next_size]; + + for prefix_idx in 0..prefix_count { + for suffix_idx in 0..suffix_count { + // Read p(prefix, 0, suffix) and p(prefix, 1, suffix) + let base_current = prefix_idx * current_stride; + let p0 = current[base_current + suffix_idx]; + let p1 = current[base_current + suffix_count + suffix_idx]; + + // Extend using Procedure 5: compute p(prefix, γ, suffix) for γ ∈ U_D + let diff = p1 - p0; + let base_next = prefix_idx * next_stride; + + // γ = ∞ (index 0): leading coefficient + next[base_next + suffix_idx] = diff; + + // γ = 0 (index 1): p(prefix, 0, suffix) + next[base_next + suffix_count + suffix_idx] = p0; + + if D >= 2 { + // γ = 1 (index 2): p(prefix, 1, suffix) + next[base_next + 2 * suffix_count + suffix_idx] = p1; + + // γ = 2, 3, ..., D-1: extrapolate using accumulation (faster than multiplication) + // val starts at p1 = p0 + 1*diff, then we add diff each iteration + let mut val = p1; + for k in 2..D { + val = val + diff; // val = p0 + k*diff + next[base_next + (k + 1) * suffix_count + suffix_idx] = val; + } + } + } + } + + current = next; + } + + Self { + evals: current, + num_vars, + } + } + + /// Get evaluation by flat index (performance path) + #[inline] + pub fn get(&self, idx: usize) -> T { + self.evals[idx] + } + + /// Number of evaluations + #[inline] + pub fn len(&self) -> usize { + self.evals.len() + } + + /// Get evaluation by domain tuple (type-safe path) + #[inline] + pub fn get_by_domain(&self, tuple: &LagrangeIndex) -> T { + self.evals[tuple.to_flat_index()] + } + + /// Number of variables + pub fn num_vars(&self) -> usize { + self.num_vars + } + + /// Convert flat index to domain tuple + pub fn to_domain_tuple(&self, flat_idx: usize) -> LagrangeIndex { + LagrangeIndex::from_flat_index(flat_idx, self.num_vars) + } +} + +/// Test-only: Create from a MultilinearPolynomial. +#[cfg(test)] +#[allow(missing_docs)] +impl LagrangeEvaluatedMultilinearPolynomial { + pub fn from_multilinear(poly: &MultilinearPolynomial) -> Self { + Self::from_boolean_evals(&poly.Z) + } + + pub fn from_evals(evals: Vec, num_vars: usize) -> Self { + debug_assert_eq!(evals.len(), (D + 1).pow(num_vars as u32)); + Self { evals, num_vars } + } +} + +/// Test-only: Convert i32 evaluations to field elements. +#[cfg(test)] +#[allow(missing_docs)] +impl LagrangeEvaluatedMultilinearPolynomial { + pub fn to_field>(&self) -> LagrangeEvaluatedMultilinearPolynomial { + LagrangeEvaluatedMultilinearPolynomial { + evals: self.evals.iter().map(|&v| F::small_to_field(v)).collect(), + num_vars: self.num_vars, + } + } +} + +/// Test-only: Convert i64 evaluations to field elements. 
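+/// Mirrors the `i32` conversion above, but routes each value through `i64_to_field`.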
+#[cfg(test)] +#[allow(dead_code, missing_docs)] +impl LagrangeEvaluatedMultilinearPolynomial { + pub fn to_field>(&self) -> LagrangeEvaluatedMultilinearPolynomial { + LagrangeEvaluatedMultilinearPolynomial { + evals: self + .evals + .iter() + .map(|&v| crate::small_field::i64_to_field(v)) + .collect(), + num_vars: self.num_vars, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + polys::multilinear::MultilinearPolynomial, provider::pasta::pallas, + small_field::SmallValueField, + }; + use ff::Field; + + use super::super::domain::LagrangePoint; + + type Scalar = pallas::Scalar; + + #[test] + fn test_extend_output_size() { + for num_vars in 1..=4 { + let input_size = 1 << num_vars; + let input: Vec = (0..input_size).map(|i| Scalar::from(i as u64)).collect(); + let poly = MultilinearPolynomial::new(input); + + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + let expected_size = 4usize.pow(num_vars as u32); // (D+1)^num_vars = 4^num_vars + assert_eq!(extended.len(), expected_size); + assert_eq!(extended.num_vars(), num_vars); + } + } + + #[test] + fn test_extend_preserves_boolean() { + let num_vars = 3; + const D: usize = 3; + let base = D + 1; + + let input: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rand_core::OsRng)) + .collect(); + let poly = MultilinearPolynomial::new(input.clone()); + + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // In U_d indexing: 0 → index 1, 1 → index 2 + #[allow(clippy::needless_range_loop)] + for b in 0..(1 << num_vars) { + let mut ud_idx = 0; + for j in 0..num_vars { + let bit = (b >> (num_vars - 1 - j)) & 1; + let ud_val = bit + 1; // 0→1, 1→2 + ud_idx = ud_idx * base + ud_val; + } + + assert_eq!(extended.get(ud_idx), input[b]); + } + } + + #[test] + fn test_extend_single_var() { + let p0 = Scalar::from(7u64); + let p1 = Scalar::from(19u64); + + let poly = MultilinearPolynomial::new(vec![p0, p1]); + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // U_d = {∞, 0, 1, 2} with indices 0, 1, 2, 3 + assert_eq!(extended.get(0), p1 - p0, "p(∞) = leading coeff"); + assert_eq!(extended.get(1), p0, "p(0)"); + assert_eq!(extended.get(2), p1, "p(1)"); + assert_eq!(extended.get(3), p1.double() - p0, "p(2) = 2*p1 - p0"); + } + + #[test] + fn test_extend_matches_direct() { + let num_vars = 3; + const D: usize = 3; + let base = D + 1; + + let input: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rand_core::OsRng)) + .collect(); + let poly = MultilinearPolynomial::new(input.clone()); + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // Check all finite points via direct multilinear evaluation + for idx in 0..extended.len() { + let tuple = index_to_tuple(idx, base, num_vars); + + // Skip infinity points (index 0 in any coordinate) + if tuple.contains(&0) { + continue; + } + + // Convert U_d indices to field values: index k → value k-1 + let point: Vec = tuple + .iter() + .map(|&t| Scalar::from((t - 1) as u64)) + .collect(); + + let direct = evaluate_multilinear(&input, &point); + assert_eq!(extended.get(idx), direct); + } + } + + #[test] + #[allow(clippy::identity_op, clippy::erasing_op)] + fn test_extend_infinity_leading_coeff() { + let num_vars = 3; + const D: usize = 3; + let base = D + 1; + + let input: Vec = (0..(1 << num_vars)) + .map(|_| Scalar::random(&mut rand_core::OsRng)) + .collect(); + let poly = MultilinearPolynomial::new(input.clone()); + let extended = 
LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // p(∞, y₂, y₃) = p(1, y₂, y₃) - p(0, y₂, y₃) + for y2 in 0..2usize { + for y3 in 0..2usize { + let idx_0 = (0 << 2) | (y2 << 1) | y3; // p(0, y2, y3) + let idx_1 = (1 << 2) | (y2 << 1) | y3; // p(1, y2, y3) + + let expected = input[idx_1] - input[idx_0]; + let ext_idx = 0 * base * base + (y2 + 1) * base + (y3 + 1); + + assert_eq!(extended.get(ext_idx), expected); + } + } + } + + #[test] + #[allow(clippy::identity_op, clippy::erasing_op)] + fn test_extend_known_polynomial() { + // p(X, Y, Z) = X + 2Y + 4Z + const D: usize = 3; + let base = D + 1; + + let mut input = Vec::with_capacity(8); + for x in 0..2u64 { + for y in 0..2u64 { + for z in 0..2u64 { + input.push(Scalar::from(x + 2 * y + 4 * z)); + } + } + } + let poly = MultilinearPolynomial::new(input); + + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // Finite points: p(a,b,c) = a + 2b + 4c + for a in 0..D { + for b in 0..D { + for c in 0..D { + let idx = (a + 1) * base * base + (b + 1) * base + (c + 1); + let expected = Scalar::from(a as u64 + 2 * b as u64 + 4 * c as u64); + assert_eq!(extended.get(idx), expected); + } + } + } + + // Infinity points = variable coefficients + assert_eq!( + extended.get(0 * base * base + 1 * base + 1), + Scalar::ONE, + "p(∞,0,0) = coeff of X" + ); + assert_eq!( + extended.get(1 * base * base + 0 * base + 1), + Scalar::from(2u64), + "p(0,∞,0) = coeff of Y" + ); + assert_eq!( + extended.get(1 * base * base + 1 * base + 0), + Scalar::from(4u64), + "p(0,0,∞) = coeff of Z" + ); + assert_eq!(extended.get(0), Scalar::ZERO, "p(∞,∞,∞) = 0 (no XYZ term)"); + } + + #[test] + fn test_get_by_domain() { + let p0 = Scalar::from(7u64); + let p1 = Scalar::from(19u64); + + let poly = MultilinearPolynomial::new(vec![p0, p1]); + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // Test type-safe access + let tuple_inf = LagrangeIndex::<3>(vec![LagrangePoint::Infinity]); + let tuple_zero = LagrangeIndex::<3>(vec![LagrangePoint::Finite(0)]); + let tuple_one = LagrangeIndex::<3>(vec![LagrangePoint::Finite(1)]); + + assert_eq!(extended.get_by_domain(&tuple_inf), p1 - p0); + assert_eq!(extended.get_by_domain(&tuple_zero), p0); + assert_eq!(extended.get_by_domain(&tuple_one), p1); + } + + // === Test helpers === + + fn index_to_tuple(mut idx: usize, base: usize, len: usize) -> Vec { + let mut tuple = vec![0; len]; + for i in (0..len).rev() { + tuple[i] = idx % base; + idx /= base; + } + tuple + } + + /// Direct multilinear evaluation: p(r) = Σ_x p(x) · eq(x, r). + /// + /// This mirrors EqPolynomial::evals_from_points() so the bit ordering + /// matches the codebase's {0,1}^ℓ indexing. 
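+  /// Here eq(x, r) = ∏_j (x_j·r_j + (1 − x_j)·(1 − r_j)), with the first
+  /// coordinate taken from the most-significant bit of the index (MSB-first).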
+ fn evaluate_multilinear(evals: &[Scalar], point: &[Scalar]) -> Scalar { + let chis = crate::polys::eq::EqPolynomial::evals_from_points(point); + evals + .iter() + .zip(chis.iter()) + .fold(Scalar::ZERO, |acc, (z, chi)| acc + *z * *chi) + } + + // === SmallLagrangePolynomial tests === + + #[test] + fn test_small_lagrange_matches_field_version() { + const D: usize = 3; + let num_vars = 3; + + // Create input as small values (i32 is identity) + let input_small: Vec = (0..(1 << num_vars)).map(|i| (i + 1) as i32).collect(); + + // Create same input as field elements + let input_field: Vec = (0..(1 << num_vars)) + .map(|i| Scalar::from((i + 1) as u64)) + .collect(); + let poly = MultilinearPolynomial::new(input_field); + + // Extend using both methods + let small_ext = + LagrangeEvaluatedMultilinearPolynomial::::from_boolean_evals(&input_small); + let field_ext = LagrangeEvaluatedMultilinearPolynomial::::from_multilinear(&poly); + + // Verify they match + assert_eq!(small_ext.len(), field_ext.len()); + for i in 0..small_ext.len() { + let small_as_field: Scalar = Scalar::small_to_field(small_ext.get(i)); + assert_eq!(small_as_field, field_ext.get(i), "mismatch at index {i}"); + } + } + + #[test] + fn test_small_lagrange_single_var() { + let p0: i32 = 7; + let p1: i32 = 19; + + let input = vec![p0, p1]; + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_boolean_evals(&input); + + // U_d = {∞, 0, 1, 2} with indices 0, 1, 2, 3 + assert_eq!(extended.get(0), p1 - p0, "p(∞) = leading coeff"); + assert_eq!(extended.get(1), p0, "p(0)"); + assert_eq!(extended.get(2), p1, "p(1)"); + // p(2) = p0 + 2 * (p1 - p0) = 2*p1 - p0 = 2*19 - 7 = 31 + assert_eq!(extended.get(3), 31i32, "p(2) = 2*p1 - p0"); + } + + #[test] + fn test_small_lagrange_extend_in_place() { + const D: usize = 2; + let num_vars = 3; + + let input: Vec = (0..(1 << num_vars)).map(|i| (i * 2 + 1) as i32).collect(); + + // Extend using allocating version + let ext1 = LagrangeEvaluatedMultilinearPolynomial::::from_boolean_evals(&input); + + // Extend in-place (zero allocation after initial buffer setup) + let mut buf_curr = Vec::new(); + let mut buf_scratch = Vec::new(); + let final_size = LagrangeEvaluatedMultilinearPolynomial::::extend_in_place( + &input, + &mut buf_curr, + &mut buf_scratch, + ); + + // Result is always in buf_curr after extend_in_place + let ext2 = &buf_curr[..final_size]; + + // Verify they match + assert_eq!(ext1.len(), final_size); + for (i, &ext2_val) in ext2.iter().enumerate() { + assert_eq!(ext1.get(i), ext2_val, "mismatch at index {i}"); + } + } + + #[test] + fn test_small_lagrange_to_field() { + const D: usize = 2; + let num_vars = 2; + + let input: Vec = (0..(1 << num_vars)).map(|i| (i + 1) as i32).collect(); + + let small_ext = LagrangeEvaluatedMultilinearPolynomial::::from_boolean_evals(&input); + let field_ext: LagrangeEvaluatedMultilinearPolynomial = + small_ext.to_field::(); + + // Verify conversion + for i in 0..small_ext.len() { + let expected: Scalar = Scalar::small_to_field(small_ext.get(i)); + assert_eq!(field_ext.get(i), expected); + } + } + + #[test] + fn test_small_lagrange_negative_values() { + // Test with negative differences (p0 > p1) + let p0: i32 = 100; + let p1: i32 = 50; + + let input = vec![p0, p1]; + let extended = LagrangeEvaluatedMultilinearPolynomial::::from_boolean_evals(&input); + + // p(∞) = p1 - p0 = -50 + assert_eq!(extended.get(0), -50i32); + assert_eq!(extended.get(1), p0); + assert_eq!(extended.get(2), p1); + + // Verify field conversion handles negatives correctly + let 
field_ext: LagrangeEvaluatedMultilinearPolynomial = + extended.to_field::(); + assert_eq!(field_ext.get(0), -Scalar::from(50u64)); + } +} diff --git a/src/lagrange_accumulator/index.rs b/src/lagrange_accumulator/index.rs new file mode 100644 index 0000000..88014ea --- /dev/null +++ b/src/lagrange_accumulator/index.rs @@ -0,0 +1,573 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Index mapping for Algorithm 6 small-value sumcheck optimization (Definition A.5). +//! +//! This module defines: +//! - [`AccumulatorPrefixIndex`]: Describes how an evaluation prefix β contributes to accumulators +//! - [`compute_idx4`]: Maps evaluation prefixes β ∈ U_d^ℓ₀ to accumulator contributions + +use super::domain::{LagrangeHatPoint, LagrangeIndex, LagrangePoint}; + +/// A single contribution from β to an accumulator A_i(v, u). +/// +/// Represents the decomposition of β ∈ U_d^ℓ₀ into: +/// - Round i (which accumulator) +/// - Prefix v = (β₁, ..., β_{i-1}) ∈ U_d^{i-1} +/// - Coordinate u = βᵢ ∈ Û_d +/// - Binary suffix y = (β_{i+1}, ..., β_{ℓ₀}) ∈ {0,1}^{ℓ₀-i} +/// +/// This set identifies all accumulators A_i(v, u) to which the product term +/// computed using the prefix β contributes. The y component will be summed +/// over when computing the accumulators. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AccumulatorPrefixIndex { + /// Total number of small-value rounds (ℓ₀) + pub l0: usize, + + /// Round index i ∈ [1, ℓ₀] (1-indexed as in paper) + pub round: usize, + + /// Prefix v = (β₁, ..., β_{i-1}) as flat index in U_d^{i-1} + pub v_idx: usize, + + /// Coordinate u = βᵢ ∈ Û_d (type-safe, excludes value 1) + pub u: LagrangeHatPoint, + + /// Binary suffix y = (β_{i+1}, ..., β_{ℓ₀}) as flat index in {0,1}^{ℓ₀-i} + pub y_idx: usize, +} + +impl AccumulatorPrefixIndex { + /// Round as 0-indexed (for array access) + #[inline] + pub fn round_0idx(&self) -> usize { + self.round - 1 + } +} + +/// Test-only helper methods for verifying index computations. +#[cfg(test)] +impl AccumulatorPrefixIndex { + /// Length of prefix v: i - 1 + #[inline] + pub fn prefix_len(&self) -> usize { + self.round - 1 + } + + /// Length of binary suffix y: ℓ₀ - i + #[inline] + pub fn suffix_len(&self) -> usize { + self.l0 - self.round + } +} + +/// Pre-computed flat indices for O(1) accumulator access in inner loops. +/// +/// Derived from [`AccumulatorPrefixIndex`] with all type-safe conversions pre-applied. +/// This avoids repeated method calls and enum matching in hot loops. +#[derive(Clone, Copy)] +pub struct CachedPrefixIndex { + /// Round as 0-indexed (for array access) + pub round_0: usize, + /// Prefix v as flat index + pub v_idx: usize, + /// Coordinate u as flat index in Û_d + pub u_idx: usize, + /// Binary suffix y as flat index + pub y_idx: usize, +} + +impl From<&AccumulatorPrefixIndex> for CachedPrefixIndex { + fn from(idx: &AccumulatorPrefixIndex) -> Self { + Self { + round_0: idx.round_0idx(), + v_idx: idx.v_idx, + u_idx: idx.u.to_index(), + y_idx: idx.y_idx, + } + } +} + +/// Computes accumulator indices for β ∈ U_d^ℓ₀ (Definition A.5). +/// +/// For each round i where: +/// 1. The suffix β[i..] is binary (values in {0,1}) +/// 2. The coordinate u = β[i-1] is in Û_d (i.e., u ≠ 1) +/// +/// Returns contributions as `AccumulatorPrefixIndex`. 
+pub fn compute_idx4(beta: &LagrangeIndex) -> Vec> { + let l0 = beta.len(); + + let mut result = Vec::new(); + let base = LagrangePoint::::BASE; + + // Phase 1: Compute prefix indices (forward pass) + // prefix_idx[i] = flat index of β[0..i] in U_d^i (mixed-radix encoding) + let mut prefix_idx = vec![0usize; l0 + 1]; + for i in 0..l0 { + prefix_idx[i + 1] = prefix_idx[i] * base + beta.0[i].to_index(); + } + + // Phase 2: Compute suffix properties (backward pass) + // suffix_is_binary[i] = true iff β[i..] consists only of binary values (0 or 1) + // suffix_idx[i] = binary encoding of β[i..] if it's binary, undefined otherwise + let mut suffix_is_binary = vec![true; l0 + 1]; + let mut suffix_idx = vec![0usize; l0 + 1]; + for i in (0..l0).rev() { + let point = beta.0[i]; + if !point.is_binary() { + // Non-binary value breaks the suffix property for this position and all earlier + suffix_is_binary[i] = false; + continue; + } + // Propagate binary status from suffix + suffix_is_binary[i] = suffix_is_binary[i + 1]; + let bit = match point { + LagrangePoint::Finite(0) => 0, + LagrangePoint::Finite(1) => 1, + _ => unreachable!("binary points must be 0 or 1"), + }; + let shift = l0 - 1 - i; + suffix_idx[i] = suffix_idx[i + 1] | (bit << shift); + } + + // Phase 3: Generate contributions for valid rounds + for i in 1..=l0 { + if !suffix_is_binary[i] { + continue; + } + + let u = beta.0[i - 1]; + let Some(u_hat) = u.to_ud_hat() else { + continue; // u = Finite(1), not in Û_d + }; + + let v_idx = prefix_idx[i - 1]; + let y_idx = suffix_idx[i]; + + result.push(AccumulatorPrefixIndex { + l0, + round: i, // 1-indexed + v_idx, + u: u_hat, + y_idx, + }); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper: construct LagrangeIndex from flat indices + /// 0 → ∞, 1 → 0, 2 → 1, 3 → 2, ... 
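+  ///
+  /// For example, `tuple_from_indices::<3>(&[0, 1, 3])` yields the point (∞, 0, 2).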
+ fn tuple_from_indices(indices: &[usize]) -> LagrangeIndex { + LagrangeIndex( + indices + .iter() + .map(|&i| LagrangePoint::from_index(i)) + .collect(), + ) + } + + #[test] + fn test_accumulator_prefix_index_helpers() { + let idx: AccumulatorPrefixIndex<3> = AccumulatorPrefixIndex { + l0: 5, + round: 3, + v_idx: 10, + u: LagrangeHatPoint::Finite(2), + y_idx: 3, + }; + + assert_eq!(idx.round_0idx(), 2); // 3 - 1 + assert_eq!(idx.prefix_len(), 2); // 3 - 1 + assert_eq!(idx.suffix_len(), 2); // 5 - 3 + } + + #[test] + fn test_compute_idx4_example_mixed_binary() { + // β = (Finite(0), Finite(1), Finite(0)) → point (0, 1, 0) + // Round 2 filtered because u = value 1 ∉ Û_d + let beta = tuple_from_indices::<3>(&[1, 2, 1]); // indices → (0, 1, 0) + let contributions = compute_idx4(&beta); + + assert_eq!(contributions.len(), 2); // Round 2 filtered + + // Round 1: v=(), u=value 0, y=(Finite(1),Finite(0)) → y_idx = 2 + let round1 = contributions.iter().find(|c| c.round == 1).unwrap(); + assert_eq!(round1.v_idx, 0); + assert_eq!(round1.u, LagrangeHatPoint::Finite(0)); + assert_eq!(round1.y_idx, 2); // bits (1,0) = 2 + assert_eq!(round1.prefix_len(), 0); + assert_eq!(round1.suffix_len(), 2); + + // Round 2: FILTERED (u = Finite(1) ∉ Û_d) + assert!(contributions.iter().all(|c| c.round != 2)); + + // Round 3: v=(Finite(0),Finite(1)), u=value 0, y=() → y_idx = 0 + let round3 = contributions.iter().find(|c| c.round == 3).unwrap(); + assert_eq!(round3.v_idx, 6); // 1*4 + 2 = 6 + assert_eq!(round3.u, LagrangeHatPoint::Finite(0)); + assert_eq!(round3.y_idx, 0); + assert_eq!(round3.prefix_len(), 2); + assert_eq!(round3.suffix_len(), 0); + } + + #[test] + fn test_compute_idx4_example_with_infinity() { + // β = (∞, 0, 1) → point (∞, 0, 1) + // Round 3 filtered because u = value 1 ∉ Û_d + let beta = tuple_from_indices::<3>(&[0, 1, 2]); // indices → (∞, 0, 1) + let contributions = compute_idx4(&beta); + + assert_eq!(contributions.len(), 2); // Round 3 filtered + + // Round 1: v=(), u=∞, y=(Finite(0),Finite(1)) → y_idx = 1 + let round1 = contributions.iter().find(|c| c.round == 1).unwrap(); + assert_eq!(round1.v_idx, 0); + assert_eq!(round1.u, LagrangeHatPoint::Infinity); + assert_eq!(round1.y_idx, 1); // bits (0,1) = 1 + + // Round 2: v=(∞,), u=value 0, y=(Finite(1),) → y_idx = 1 + let round2 = contributions.iter().find(|c| c.round == 2).unwrap(); + assert_eq!(round2.v_idx, 0); + assert_eq!(round2.u, LagrangeHatPoint::Finite(0)); + assert_eq!(round2.y_idx, 1); // bits (1,) = 1 + + // Round 3: FILTERED (u = value 1 ∉ Û_d) + assert!(contributions.iter().all(|c| c.round != 3)); + } + + #[test] + fn test_compute_idx4_example_double_infinity() { + // β = (∞, ∞, 0) → point (∞, ∞, 0) + // Round 1 skipped because suffix contains ∞ + let beta = tuple_from_indices::<3>(&[0, 0, 1]); // indices → (∞, ∞, 0) + let contributions = compute_idx4(&beta); + + assert_eq!(contributions.len(), 2); + + // Round 1 should be missing (suffix has ∞) + assert!(contributions.iter().all(|c| c.round != 1)); + + // Round 2: v=(∞,), u=∞, y=(Finite(0),) → y_idx = 0 + let round2 = contributions.iter().find(|c| c.round == 2).unwrap(); + assert_eq!(round2.v_idx, 0); + assert_eq!(round2.u, LagrangeHatPoint::Infinity); + assert_eq!(round2.y_idx, 0); + + // Round 3: v=(∞,∞), u=Finite(0), y=() → y_idx = 0 + let round3 = contributions.iter().find(|c| c.round == 3).unwrap(); + assert_eq!(round3.v_idx, 0); // 0*4 + 0 = 0 + assert_eq!(round3.u, LagrangeHatPoint::Finite(0)); + assert_eq!(round3.y_idx, 0); + } + + #[test] + fn 
test_compute_idx4_example_all_extrapolated() { + // β = (Finite(2), Finite(2), Finite(2)) → point (2, 2, 2) + // Only last round (empty suffix is vacuously binary) + let beta = tuple_from_indices::<3>(&[3, 3, 3]); // indices → (2, 2, 2) + let contributions = compute_idx4(&beta); + + assert_eq!(contributions.len(), 1); + + let only = &contributions[0]; + assert_eq!(only.round, 3); + assert_eq!(only.v_idx, 15); // 3*4 + 3 = 15 + assert_eq!(only.u, LagrangeHatPoint::Finite(2)); + assert_eq!(only.y_idx, 0); + assert_eq!(only.suffix_len(), 0); + } + + #[test] + fn test_compute_idx4_all_ones_has_no_contributions() { + // β = (Finite(1), Finite(1), Finite(1)) → point (1, 1, 1) + // ALL rounds filtered because u = value 1 ∉ Û_d at every position + let beta = tuple_from_indices::<3>(&[2, 2, 2]); // indices → (1, 1, 1) + let contributions = compute_idx4(&beta); + + assert!( + contributions.is_empty(), + "β=(1,1,1) should have no contributions" + ); + } + + #[test] + fn test_compute_idx4_contribution_counts() { + // Not every β contributes! β where all elements are Finite(1) + // has NO contributions because u=1 ∉ Û_d for all rounds. + let l0 = 3; + const D: usize = 3; + let base = D + 1; + let prefix_ud_size = base.pow(l0 as u32); + + let mut zero_contribution_count = 0; + + for beta_idx in 0..prefix_ud_size { + let beta = LagrangeIndex::::from_flat_index(beta_idx, l0); + let contributions = compute_idx4(&beta); + + // All contributions should have correct l0 + for c in &contributions { + assert_eq!(c.l0, l0); + } + + if contributions.is_empty() { + zero_contribution_count += 1; + } + } + + // β = (Finite(1), Finite(1), Finite(1)) is the only β with zero contributions + assert_eq!( + zero_contribution_count, 1, + "Only β=(1,1,1) should have zero contributions" + ); + } + + #[test] + fn test_compute_idx4_binary_beta_filtered_by_u() { + // Binary β (Finite(0) or Finite(1)) DON'T all contribute to all rounds! 
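+    // For example, β = (0, 1, 0) contributes only to rounds 1 and 3, since round 2 has u = 1.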
+ // Round i is filtered when β[i-1] = Finite(1) (i.e., u = value 1 ∉ Û_d) + let l0 = 3; + const D: usize = 3; + + // Iterate over all binary β + for b in 0..(1 << l0) { + // Construct beta with Finite(0) or Finite(1) values + let indices: Vec = (0..l0) + .map(|j| ((b >> (l0 - 1 - j)) & 1) + 1) // 0→index 1 (Finite(0)), 1→index 2 (Finite(1)) + .collect(); + let beta = tuple_from_indices::(&indices); + + let contributions = compute_idx4(&beta); + + // Count how many positions have Finite(1) (value 1) + let num_ones = beta + .0 + .iter() + .filter(|&&p| p == LagrangePoint::Finite(1)) + .count(); + + // Expected contributions = l0 - num_ones (each position with value 1 filters that round) + let expected_len = l0 - num_ones; + assert_eq!( + contributions.len(), + expected_len, + "β={:?} should have {} contributions (filtering {} rounds with u=1)", + beta, + expected_len, + num_ones + ); + + // Verify correct rounds are present/missing + for round in 1..=l0 { + let has_round = contributions.iter().any(|c| c.round == round); + let u = beta.0[round - 1]; + + if u == LagrangePoint::Finite(1) { + // u = 1 ∉ Û_d → round should be filtered + assert!( + !has_round, + "β={:?} should NOT have round {} (u=1)", + beta, round + ); + } else { + // u ∈ Û_d → round should be present (suffix is always binary for binary β) + assert!(has_round, "β={:?} should have round {} (u≠1)", beta, round); + } + } + } + } + + #[test] + fn test_compute_idx4_index_bounds() { + let l0 = 4; + const D: usize = 3; + let base = D + 1; + let prefix_ud_size = base.pow(l0 as u32); + + for beta_idx in 0..prefix_ud_size { + let beta = LagrangeIndex::::from_flat_index(beta_idx, l0); + let contributions = compute_idx4(&beta); + + for c in contributions { + // l0 should match + assert_eq!(c.l0, l0); + + // Round bound: i ∈ [1, ℓ₀] + assert!( + c.round >= 1 && c.round <= l0, + "round {} out of [1, {}] for β={:?}", + c.round, + l0, + beta + ); + + // v_idx bound: v ∈ U_d^{i-1}, so v_idx < (d+1)^{i-1} + let max_v_idx = base.pow(c.prefix_len() as u32); + assert!( + c.v_idx < max_v_idx, + "v_idx {} >= {} for round {} β={:?}", + c.v_idx, + max_v_idx, + c.round, + beta + ); + + // u should be valid LagrangeHatPoint (this is enforced by type system) + let u_idx = c.u.to_index(); + assert!(u_idx < D, "u_idx {} >= {} for β={:?}", u_idx, D, beta); + + // y_idx bound: y ∈ {0,1}^{ℓ₀-i}, so y_idx < 2^{ℓ₀-i} + let max_y_idx = 1usize << c.suffix_len(); + assert!( + c.y_idx < max_y_idx, + "y_idx {} >= {} for round {} β={:?}", + c.y_idx, + max_y_idx, + c.round, + beta + ); + } + } + } + + #[test] + fn test_compute_idx4_v_u_decode() { + // Test cases that have at least some contributions + // Note: all Finite(1) has NO contributions so we exclude it + let test_cases: Vec> = vec![ + tuple_from_indices(&[1, 1, 1]), // all u=Finite(0), 3 contributions + tuple_from_indices(&[0, 1, 1]), // u=∞,Finite(0),Finite(0), 3 contributions + tuple_from_indices(&[3, 0, 1]), // u=Finite(2),∞,Finite(0), 3 contributions + tuple_from_indices(&[3, 3, 3]), // u=Finite(2)×3, only round 3 due to non-binary suffix + ]; + + for beta in test_cases { + let contributions = compute_idx4(&beta); + + for c in contributions { + // v should be β[0..round-1] + let expected_v = LagrangeIndex(beta.0[0..c.prefix_len()].to_vec()); + let expected_v_idx = expected_v.to_flat_index(); + assert_eq!( + c.v_idx, expected_v_idx, + "v_idx mismatch for round {} β={:?}", + c.round, beta + ); + + // u should be the Û_d point from β[round-1] + let u = beta.0[c.round - 1]; + let expected_u = 
u.to_ud_hat().unwrap(); + assert_eq!( + c.u, expected_u, + "u mismatch for round {} β={:?}: expected {:?} (from {:?})", + c.round, beta, expected_u, u + ); + } + } + } + + #[test] + fn test_compute_idx4_y_idx_encoding() { + // β = (Finite(0), Finite(1), Finite(0), Finite(1)) → point (0, 1, 0, 1) + // Rounds 2 and 4 filtered because u = value 1 ∉ Û_d + let beta = tuple_from_indices::<3>(&[1, 2, 1, 2]); + let contributions = compute_idx4(&beta); + + assert_eq!(contributions.len(), 2); // Only rounds 1 and 3 + + // Round 1: u=Finite(0) ∈ Û_d, y = suffix → bits (1,0,1) → y_idx = 5 + let c1 = contributions.iter().find(|c| c.round == 1).unwrap(); + assert_eq!(c1.y_idx, 0b101); + assert_eq!(c1.suffix_len(), 3); + assert_eq!(c1.u, LagrangeHatPoint::Finite(0)); + + // Round 2: FILTERED (u = Finite(1) ∉ Û_d) + assert!(contributions.iter().all(|c| c.round != 2)); + + // Round 3: u=Finite(0) ∈ Û_d, y = suffix → bits (1,) → y_idx = 1 + let c3 = contributions.iter().find(|c| c.round == 3).unwrap(); + assert_eq!(c3.y_idx, 0b1); + assert_eq!(c3.suffix_len(), 1); + assert_eq!(c3.u, LagrangeHatPoint::Finite(0)); + + // Round 4: FILTERED (u = Finite(1) ∉ Û_d) + assert!(contributions.iter().all(|c| c.round != 4)); + } + + #[test] + fn test_compute_idx4_y_idx_encoding_no_filtering() { + // Test y_idx encoding with a β that has no filtering + let l0 = 4; + const D: usize = 3; + + // β = (Finite(0), Finite(0), Finite(0), Finite(0)) → all zeros + // All u = Finite(0) ∈ Û_d, so all rounds present + let beta = tuple_from_indices::(&[1, 1, 1, 1]); + let contributions = compute_idx4(&beta); + + assert_eq!(contributions.len(), l0); + + // Round 1: y = suffix → bits (0,0,0) → y_idx = 0 + let c1 = contributions.iter().find(|c| c.round == 1).unwrap(); + assert_eq!(c1.y_idx, 0b000); + assert_eq!(c1.suffix_len(), 3); + + // Round 2: y = suffix → bits (0,0) → y_idx = 0 + let c2 = contributions.iter().find(|c| c.round == 2).unwrap(); + assert_eq!(c2.y_idx, 0b00); + assert_eq!(c2.suffix_len(), 2); + + // Round 3: y = suffix → bits (0,) → y_idx = 0 + let c3 = contributions.iter().find(|c| c.round == 3).unwrap(); + assert_eq!(c3.y_idx, 0b0); + assert_eq!(c3.suffix_len(), 1); + + // Round 4: y = suffix → () → y_idx = 0 + let c4 = contributions.iter().find(|c| c.round == 4).unwrap(); + assert_eq!(c4.y_idx, 0); + assert_eq!(c4.suffix_len(), 0); + } + + #[test] + fn test_compute_idx4_suffix_must_be_binary_values() { + // Two conditions filter rounds: + // 1. The suffix y must consist only of binary values (Finite(0) or Finite(1)) + // 2. 
The coordinate u must be in Û_d (i.e., u ≠ Finite(1)) + const D: usize = 3; + + // β with non-binary value (Finite(2)) in suffix position + // indices (1, 3, 2) → values (Finite(0), Finite(2), Finite(1)) + let beta = tuple_from_indices::(&[1, 3, 2]); + let contributions = compute_idx4(&beta); + + // Round 1: u=Finite(0) ∈ Û_d ✓, suffix has Finite(2) → SKIP (non-binary suffix) + // Round 2: u=Finite(2) ∈ Û_d ✓, suffix has Finite(1) binary ✓ → OK + // Round 3: u=Finite(1) ∉ Û_d ✗ → SKIP (u not in Û_d) + assert_eq!(contributions.len(), 1); // Only round 2 + assert!(contributions.iter().any(|c| c.round == 2)); + assert!(contributions.iter().all(|c| c.round != 1 && c.round != 3)); + + // β with ∞ in suffix position + // indices (2, 0, 1) → values (Finite(1), ∞, Finite(0)) + let beta = tuple_from_indices::(&[2, 0, 1]); + let contributions = compute_idx4(&beta); + + // Round 1: u=Finite(1) ∉ Û_d ✗ → SKIP (u not in Û_d) + // Round 2: u=∞ ∈ Û_d ✓, suffix has Finite(0) binary ✓ → OK + // Round 3: u=Finite(0) ∈ Û_d ✓, suffix=() binary ✓ → OK + assert_eq!(contributions.len(), 2); // Rounds 2 and 3 + assert!(contributions.iter().all(|c| c.round != 1)); + assert!(contributions.iter().any(|c| c.round == 2)); + assert!(contributions.iter().any(|c| c.round == 3)); + } +} diff --git a/src/lagrange_accumulator/mat_vec_mle.rs b/src/lagrange_accumulator/mat_vec_mle.rs new file mode 100644 index 0000000..34160a5 --- /dev/null +++ b/src/lagrange_accumulator/mat_vec_mle.rs @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Trait for computing matrix-vector MLE evaluations. +//! +//! This module provides [`MatVecMLE`], which computes multilinear extensions of +//! matrix-vector products Az, Bz, Cz. It abstracts over: +//! - Field-element witnesses (`MultilinearPolynomial`) +//! - Small-value witnesses (`MultilinearPolynomial`, `MultilinearPolynomial`) + +use crate::{polys::multilinear::MultilinearPolynomial, small_field::SmallValueField}; +use ff::PrimeField; +use std::ops::{Add, Sub}; + +/// Trait for multilinear extensions of matrix-vector products (Az, Bz, Cz). +/// +/// Abstracts over field-element witnesses vs small-value (i32/i64) witnesses, +/// enabling the same accumulator-building code to work with both representations. +/// +/// # Type Parameters +/// +/// - `S`: The field type used for accumulation (always a `PrimeField`) +/// +/// # Associated Types +/// +/// - `Value`: The witness coefficient type (`S` for field polynomials, `i32`/`i64` for small-value) +/// - `Product`: The product type (`S` for field, `i64`/`i128` for small-value to avoid overflow) +pub trait MatVecMLE: Sync { + /// The witness value type (S for field, i32/i64 for small) + type Value: Copy + Default + Add + Sub + Send + Sync; + + /// The product type (S for field, i64/i128 for small) + type Product: Copy; + + /// Get witness value at index + fn get(&self, idx: usize) -> Self::Value; + + /// Get polynomial length + fn len(&self) -> usize; + + /// Returns true if the polynomial is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Multiply two witness values: a × b + fn multiply_witnesses(a: Self::Value, b: Self::Value) -> Self::Product; + + /// Convert a product to a field element (for immediate reduction path). 
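+  /// For the small-value implementations this performs a single modular reduction
+  /// of the accumulated integer product; for field-element polynomials it is the identity.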
+ /// This is used by `DelayedModularReductionDisabled` to avoid going through unreduced form. + fn product_to_field(prod: Self::Product) -> S; +} + +/// Macro to implement MatVecMLE for field-element polynomials. +/// This avoids conflicting with the i32/i64 impls due to Rust's coherence rules. +macro_rules! impl_mat_vec_mle_for_field { + ($($field:ty),* $(,)?) => { + $( + impl MatVecMLE<$field> for MultilinearPolynomial<$field> { + type Value = $field; + type Product = $field; + + #[inline] + fn get(&self, idx: usize) -> $field { + self.Z[idx] + } + + #[inline] + fn len(&self) -> usize { + self.Z.len() + } + + #[inline] + fn multiply_witnesses(a: $field, b: $field) -> $field { + a * b + } + + #[inline] + fn product_to_field(prod: $field) -> $field { + prod // Already a field element + } + } + )* + }; +} + +// Implement for supported field types +use crate::provider::{ + bn254::bn254, + pasta::{pallas, vesta}, +}; + +impl_mat_vec_mle_for_field!(pallas::Scalar, vesta::Scalar, bn254::Scalar,); + +/// Implementation for i32-valued polynomials (i32 coefficients, i64 products). +impl + Sync> MatVecMLE + for MultilinearPolynomial +{ + type Value = i32; + type Product = i64; + + #[inline] + fn get(&self, idx: usize) -> i32 { + self.Z[idx] + } + + #[inline] + fn len(&self) -> usize { + self.Z.len() + } + + #[inline] + fn multiply_witnesses(a: i32, b: i32) -> i64 { + (a as i64) * (b as i64) + } + + #[inline] + fn product_to_field(prod: i64) -> S { + S::intermediate_to_field(prod) + } +} + +/// Implementation for i64-valued polynomials (i64 coefficients, i128 products). +impl + Sync> MatVecMLE + for MultilinearPolynomial +{ + type Value = i64; + type Product = i128; + + #[inline] + fn get(&self, idx: usize) -> i64 { + self.Z[idx] + } + + #[inline] + fn len(&self) -> usize { + self.Z.len() + } + + #[inline] + fn multiply_witnesses(a: i64, b: i64) -> i128 { + (a as i128) * (b as i128) + } + + #[inline] + fn product_to_field(prod: i128) -> S { + S::intermediate_to_field(prod) + } +} diff --git a/src/lagrange_accumulator/mod.rs b/src/lagrange_accumulator/mod.rs new file mode 100644 index 0000000..919459e --- /dev/null +++ b/src/lagrange_accumulator/mod.rs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Lagrange accumulator algorithm for small-value sumcheck optimization (Algorithm 6). +//! +//! This module implements the Lagrange domain extension technique from IACR 2025/1117, +//! which accelerates sumcheck provers when witness coefficients are small integers. +//! +//! # Module Structure +//! +//! - [`domain`]: Domain types (LagrangePoint, LagrangeHatPoint, LagrangeIndex) +//! - [`evals`]: Evaluation containers (LagrangeEvals, LagrangeHatEvals) +//! - [`basis`]: Lagrange basis computation (LagrangeBasisFactory, LagrangeCoeff) +//! - [`extension`]: Multilinear polynomial extension (Procedure 6) +//! - [`accumulator`]: Accumulator data structures +//! - [`accumulator_builder`]: Accumulator construction (Procedure 9) +//! - [`index`]: Index mapping (Definition A.5) +//! - [`thread_state`]: Thread-local buffers for parallel execution +//! - [`eq_round`]: Per-round equality factor tracking +//! 
- [`witness`]: Witness polynomial abstraction (MatVecMLE trait) + +mod accumulator; +mod accumulator_builder; +mod basis; +mod delay_modular_reduction_mode; +mod domain; +mod eq_round; +mod evals; +mod extension; +mod index; +mod mat_vec_mle; +mod thread_state; + +// Domain types +pub use domain::{LagrangeHatPoint, LagrangeIndex, LagrangePoint, ValueOneExcluded}; + +// Evaluation containers +pub use evals::{LagrangeEvals, LagrangeHatEvals}; + +// Basis computation +pub use basis::{LagrangeBasisFactory, LagrangeCoeff}; + +// Extension +pub use extension::LagrangeEvaluatedMultilinearPolynomial; + +// Accumulators +pub use accumulator::{LagrangeAccumulators, RoundAccumulator}; + +// Builder functions +pub use accumulator_builder::{SPARTAN_T_DEGREE, build_accumulators, build_accumulators_spartan}; + +// Delayed modular reduction mode selection +pub use delay_modular_reduction_mode::{ + AccumulateProduct, DelayedModularReductionDisabled, DelayedModularReductionEnabled, + DelayedModularReductionMode, +}; + +// Index computation +pub use index::{AccumulatorPrefixIndex, CachedPrefixIndex, compute_idx4}; + +// Eq round factor and derivation +pub use eq_round::{EqRoundFactor, derive_t1}; +pub use mat_vec_mle::MatVecMLE; diff --git a/src/lagrange_accumulator/thread_state.rs b/src/lagrange_accumulator/thread_state.rs new file mode 100644 index 0000000..2ddda08 --- /dev/null +++ b/src/lagrange_accumulator/thread_state.rs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Thread-local scratch buffers for accumulator building. +//! +//! These structs eliminate per-iteration heap allocations in the parallel fold loops +//! of `build_accumulators_spartan` and `build_accumulators`. By hoisting buffer +//! allocations to the fold identity closure (called once per Rayon thread subdivision), +//! we reduce allocations from O(num_x_out) to O(num_threads). + +use super::{ + accumulator::LagrangeAccumulators, delay_modular_reduction_mode::DelayedModularReductionMode, + mat_vec_mle::MatVecMLE, +}; +use ff::PrimeField; + +/// Thread-local scratch buffers for `build_accumulators_spartan`. +/// +/// # Motivation +/// +/// Without this optimization, the fold closure allocates 5 vectors on every x_out iteration: +/// ```ignore +/// |mut acc, x_out_bits| { +/// let mut beta_partial_sums = vec![S::ZERO; num_betas]; // ALLOC +/// let mut az_pref = vec![...]; // ALLOC +/// let mut bz_pref = vec![...]; // ALLOC +/// let mut buf_a = vec![...]; // ALLOC +/// let mut buf_b = vec![...]; // ALLOC +/// ... +/// } +/// ``` +/// +/// For typical workloads (l=20, l0=4), num_x_out = 2^6 = 64, causing 320 allocations +/// per parallel task. With Rayon's work-stealing, this leads to significant allocator +/// contention and cache pollution. +/// +/// # Solution +/// +/// By hoisting these buffers into a struct created once per Rayon thread subdivision +/// (in the fold identity closure), we reduce allocations from O(num_x_out) to O(num_threads). +/// The `reset_partial_sums()` method zeros the sums between iterations (cheap memset). +/// +/// # Buffer Layout +/// +/// - `az_buf_curr/scratch`, `bz_buf_curr/scratch`: Separate buffer pairs for Az and Bz +/// Lagrange extensions. After `extend_in_place`, the result is always in `*_buf_curr`. 
+/// We need 4 buffers (not 2) because both extension results must be available +/// simultaneously to compute Az(β) × Bz(β) for each β. +/// - `scatter_acc`: Bucket accumulators for scatter phase (type determined by Mode). +/// +/// Note: `eyx` precomputation has been removed. We now use `e_y` directly in the +/// scatter phase, computing `z_beta = ex * tA_red` once per beta instead of +/// precomputing `eyx = ey * ex` for all y indices. +/// +/// # Type Parameters +/// +/// - `S`: Field type for final accumulator values +/// - `V`: Witness value type (i32 for small-value, S for field) +/// - `P`: Polynomial type implementing [`MatVecMLE`] +/// - `Mode`: Delayed modular reduction mode selection ([`super::DelayedModularReductionEnabled`] or [`super::DelayedModularReductionDisabled`]) +/// - `D`: Polynomial degree bound +pub(crate) struct SpartanThreadState +where + S: PrimeField + Send + Sync, + V: Copy + Default, + P: MatVecMLE, + Mode: DelayedModularReductionMode, +{ + /// Partial sums indexed by β, accumulated over the x_in loop. + /// Type determined by Mode: unreduced for DelayedModularReductionEnabled, reduced for DelayedModularReductionDisabled. + /// Reset each x_out iteration. + pub partial_sums: Vec, + /// Bucket accumulators for scatter phase. + /// Type determined by Mode: unreduced F×F for DelayedModularReductionEnabled, LagrangeAccumulator for DelayedModularReductionDisabled. + pub scatter_acc: LagrangeAccumulators, + /// Prefix evaluations of Az for current suffix. Size: 2^l0 + pub az_pref: Vec, + /// Prefix evaluations of Bz for current suffix. Size: 2^l0 + pub bz_pref: Vec, + /// Result buffer for Az Lagrange extension. After `extend_in_place`, contains the extended values. + pub az_buf_curr: Vec, + /// Scratch buffer for Az Lagrange extension. Used during iterative extension. + pub az_buf_scratch: Vec, + /// Result buffer for Bz Lagrange extension. After `extend_in_place`, contains the extended values. + pub bz_buf_curr: Vec, + /// Scratch buffer for Bz Lagrange extension. Used during iterative extension. + pub bz_buf_scratch: Vec, + /// Reusable buffer for filtered (beta_idx, reduced_value) pairs in scatter phase. + /// Eliminates per-x_out allocation overhead. + pub beta_values: Vec<(usize, S)>, +} + +impl SpartanThreadState +where + S: PrimeField + Send + Sync, + V: Copy + Default, + P: MatVecMLE, + Mode: DelayedModularReductionMode, +{ + pub fn new(l0: usize, num_betas: usize, prefix_size: usize, ext_size: usize) -> Self { + Self { + partial_sums: vec![Mode::PartialSum::default(); num_betas], + scatter_acc: LagrangeAccumulators::new(l0), + az_pref: vec![V::default(); prefix_size], + bz_pref: vec![V::default(); prefix_size], + az_buf_curr: vec![V::default(); ext_size], + az_buf_scratch: vec![V::default(); ext_size], + bz_buf_curr: vec![V::default(); ext_size], + bz_buf_scratch: vec![V::default(); ext_size], + beta_values: Vec::with_capacity(num_betas), + } + } + + /// Zero out partial sums for the next x_out iteration. + /// This is O(num_betas) but much cheaper than reallocating. + #[inline] + pub fn reset_partial_sums(&mut self) { + for sum in &mut self.partial_sums { + *sum = Mode::PartialSum::default(); + } + self.beta_values.clear(); + } +} + +/// Thread-local scratch buffers for the generic `build_accumulators`. +/// +/// Similar to `SpartanThreadState`, but handles a variable number of polynomials (d). +/// Each polynomial needs its own buffer pair for Lagrange extension since all d +/// extension results must be available simultaneously to compute ∏ p_k(β). 
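+///
+/// All buffers are sized once in `new` (prefix buffers of 2^ℓ₀ entries per polynomial,
+/// extension buffer pairs of (D+1)^ℓ₀ entries), so between x_out iterations only
+/// `reset_partial_sums` runs, which is a cheap fill.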
+/// +/// See `SpartanThreadState` documentation for the full motivation. +pub(crate) struct GenericThreadState { + /// Accumulator being built (the actual output) + pub acc: LagrangeAccumulators, + /// Partial sums indexed by β. Reset each x_out iteration. + pub beta_partial_sums: Vec, + /// Prefix evaluations for each of the d polynomials. Size: d × 2^l0 + pub poly_prefs: Vec>, + /// Buffer pairs for each polynomial's Lagrange extension: (result, scratch). + /// After `extend_in_place`, the result is always in the first element of each pair. + /// Size: d × 2 × (D+1)^l0 + pub buf_pairs: Vec<(Vec, Vec)>, + /// On-the-fly computed ey*ex scratch buffer. Size per round: 2^{l0-1-round}. + /// Total size: 2^l0 - 1 (e.g., 7 for l0=3). Stays hot in L1 cache. + pub eyx: Vec>, +} + +impl GenericThreadState { + pub fn new( + l0: usize, + num_betas: usize, + prefix_size: usize, + ext_size: usize, + num_polys: usize, + e_y_sizes: &[usize], + ) -> Self { + Self { + acc: LagrangeAccumulators::new(l0), + beta_partial_sums: vec![S::ZERO; num_betas], + poly_prefs: (0..num_polys).map(|_| vec![S::ZERO; prefix_size]).collect(), + buf_pairs: (0..num_polys) + .map(|_| (vec![S::ZERO; ext_size], vec![S::ZERO; ext_size])) + .collect(), + eyx: e_y_sizes.iter().map(|&sz| vec![S::ZERO; sz]).collect(), + } + } + + /// Zero out partial sums for the next x_out iteration. + #[inline] + pub fn reset_partial_sums(&mut self) { + self.beta_partial_sums.fill(S::ZERO); + } +} diff --git a/src/lib.rs b/src/lib.rs index 5da0473..07f79d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,6 @@ //! that is generic over the polynomial commitment and evaluation argument (i.e., a PCS). #![deny( warnings, - unused, future_incompatible, nonstandard_style, rust_2018_idioms, @@ -22,27 +21,35 @@ #![forbid(unsafe_code)] // private modules +mod csr; mod digest; mod math; mod nifs; -mod polys; mod r1cs; -mod sumcheck; mod zk; +/// Lagrange accumulator algorithm for small-value sumcheck optimization (Algorithm 6). +pub mod lagrange_accumulator; + +// modules with some public items for benchmarking +pub mod polys; +pub mod small_field; +pub mod sumcheck; + #[macro_use] mod macros; // public modules pub mod bellpepper; pub mod errors; +pub mod gadgets; pub mod provider; pub mod traits; // public modules for different proof systems pub mod neutronnova_zk; // NeutronNova with zero-knowledge pub mod spartan; // Spartan without zero-knowledge -pub mod spartan_zk; // Spartan with zero-knowledge +pub mod spartan_zk; // Spartan with zero-knowledge /// Start a span + timer, return `(Span, Instant)`. macro_rules! start_span { diff --git a/src/polys/eq.rs b/src/polys/eq.rs index 9f990bf..62e4f51 100644 --- a/src/polys/eq.rs +++ b/src/polys/eq.rs @@ -86,10 +86,74 @@ impl FromIterator for EqPolynomial { } } +/// Computes suffix eq-polynomials: E_y[i] = eq(τ[i+1:ℓ₀], y) for all i ∈ [0, ℓ₀). +/// +/// This is used for Algorithm 6 (small-value sumcheck) in +/// . +/// +/// Uses a pyramid approach: build from the end (τ[ℓ₀-1]) backwards to τ[1]. +/// Each step extends the previous suffix by prepending one more τ value. +/// +/// # Arguments +/// * `taus` - τ values for the first ℓ₀ variables (τ[0:ℓ₀]) +/// * `l0` - number of small-value rounds +/// +/// # Returns +/// Vec of length ℓ₀, where `result[i]` has 2^{ℓ₀-i-1} elements. 
+/// `result[i][y]` = eq(τ[i+1:ℓ₀], y) for y ∈ {0,1}^{ℓ₀-i-1} +/// +/// # Complexity +/// O(2^ℓ₀) total field multiplications (vs O(ℓ₀ · 2^ℓ₀) naive) +// Allow dead code until later chunks use this function +pub fn compute_suffix_eq_pyramid(taus: &[S], l0: usize) -> Vec> { + // Handle l0 == 0: no suffix tables needed (small-value optimization disabled) + if l0 == 0 { + return Vec::new(); + } + + assert!(taus.len() >= l0, "taus must have at least l0 elements"); + + let mut result: Vec> = vec![vec![]; l0]; + + // Base case: E_y[l0-1] = eq([], ·) = [1] (empty suffix) + result[l0 - 1] = vec![S::ONE]; + + // Build backwards: each step prepends one τ value + // E_y[i] = eq(τ[i+1:l0], ·) is built from E_y[i+1] = eq(τ[i+2:l0], ·) + // by prepending τ[i+1] + for i in (0..l0 - 1).rev() { + let tau = taus[i + 1]; + let prev = &result[i + 1]; + let prev_len = prev.len(); + + // New table has 2× the entries (prepending a new variable) + // For multilinear indexing: first variable is high bit + // new_idx = new_bit * prev_len + old_idx + // + // new_bit = 0: eq factor is (1 - τ) + // new_bit = 1: eq factor is τ + // + // Optimized: use 1 multiplication per element instead of 2 + // hi = v * τ, lo = v - hi = v * (1 - τ) + let mut next = vec![S::ZERO; prev_len * 2]; + let (lo_half, hi_half) = next.split_at_mut(prev_len); + + for ((lo, hi), v) in lo_half.iter_mut().zip(hi_half.iter_mut()).zip(prev.iter()) { + *hi = *v * tau; + *lo = *v - *hi; + } + + result[i] = next; + } + + result +} + #[cfg(test)] mod tests { use super::*; use crate::provider::pasta::pallas; + use ff::Field; fn test_eq_polynomial_with() { let eq_poly = EqPolynomial::::new(vec![F::ONE, F::ZERO, F::ONE]); @@ -113,4 +177,173 @@ mod tests { fn test_eq_polynomial() { test_eq_polynomial_with::(); } + + // === Suffix Eq Pyramid tests === + + #[test] + fn test_suffix_pyramid_l0_zero() { + let taus: Vec = vec![pallas::Scalar::from(1), pallas::Scalar::from(2)]; + let pyramid = compute_suffix_eq_pyramid(&taus, 0); + assert!(pyramid.is_empty()); + } + + #[test] + fn test_suffix_pyramid_sizes() { + let l0 = 4; + let taus: Vec = (0..l0) + .map(|i| pallas::Scalar::from(i as u64 + 2)) + .collect(); + let pyramid = compute_suffix_eq_pyramid(&taus, l0); + + assert_eq!(pyramid.len(), l0); + #[allow(clippy::needless_range_loop)] + for i in 0..l0 { + let expected_size = 1 << (l0 - i - 1); + assert_eq!( + pyramid[i].len(), + expected_size, + "E_y[{}] should have size {}", + i, + expected_size + ); + } + } + + #[test] + fn test_suffix_pyramid_base_case() { + use ff::Field; + let l0 = 3; + let taus: Vec = (0..l0) + .map(|_| pallas::Scalar::random(&mut rand_core::OsRng)) + .collect(); + let pyramid = compute_suffix_eq_pyramid(&taus, l0); + + // E_y[l0-1] = [1] (empty suffix) + assert_eq!(pyramid[l0 - 1].len(), 1); + assert_eq!(pyramid[l0 - 1][0], pallas::Scalar::ONE); + } + + #[test] + fn test_suffix_pyramid_single_tau() { + use ff::Field; + let l0 = 3; + let taus: Vec = (0..l0) + .map(|_| pallas::Scalar::random(&mut rand_core::OsRng)) + .collect(); + let pyramid = compute_suffix_eq_pyramid(&taus, l0); + + // E_y[l0-2] = eq([τ_{l0-1}], ·) = [1-τ, τ] + let tau_last = taus[l0 - 1]; + assert_eq!(pyramid[l0 - 2].len(), 2); + assert_eq!( + pyramid[l0 - 2][0], + pallas::Scalar::ONE - tau_last, + "eq(τ, 0) = 1-τ" + ); + assert_eq!(pyramid[l0 - 2][1], tau_last, "eq(τ, 1) = τ"); + } + + #[test] + fn test_suffix_pyramid_matches_naive() { + use ff::Field; + // Verify pyramid matches independent computation + let l0 = 4; + let taus: Vec = (0..l0) + .map(|_| 
pallas::Scalar::random(&mut rand_core::OsRng)) + .collect(); + + let pyramid = compute_suffix_eq_pyramid(&taus, l0); + + for i in 0..l0 { + let naive = if i + 1 >= l0 { + vec![pallas::Scalar::ONE] + } else { + EqPolynomial::evals_from_points(&taus[i + 1..l0]) + }; + + assert_eq!(pyramid[i].len(), naive.len(), "Size mismatch at i={}", i); + + for (j, (&p, &n)) in pyramid[i].iter().zip(naive.iter()).enumerate() { + assert_eq!(p, n, "Value mismatch at E_y[{}][{}]", i, j); + } + } + } + + #[test] + fn test_suffix_pyramid_indexing() { + use ff::Field; + // Verify index semantics: pyramid[i][y] = eq(τ[i+1:l0], y) + let l0 = 3; + let tau1 = pallas::Scalar::from(5); + let tau2 = pallas::Scalar::from(7); + let tau0 = pallas::Scalar::from(3); // Not used in any E_y suffix + let taus = vec![tau0, tau1, tau2]; + + let pyramid = compute_suffix_eq_pyramid(&taus, l0); + + // E_y[0] = eq([τ₁, τ₂], y) for y ∈ {0,1}² + // Index 0 = (0,0): eq = (1-τ₁)(1-τ₂) + // Index 1 = (0,1): eq = (1-τ₁)(τ₂) + // Index 2 = (1,0): eq = (τ₁)(1-τ₂) + // Index 3 = (1,1): eq = (τ₁)(τ₂) + assert_eq!( + pyramid[0][0], + (pallas::Scalar::ONE - tau1) * (pallas::Scalar::ONE - tau2) + ); + assert_eq!(pyramid[0][1], (pallas::Scalar::ONE - tau1) * tau2); + assert_eq!(pyramid[0][2], tau1 * (pallas::Scalar::ONE - tau2)); + assert_eq!(pyramid[0][3], tau1 * tau2); + + // E_y[1] = eq([τ₂], y) for y ∈ {0,1} + assert_eq!(pyramid[1][0], pallas::Scalar::ONE - tau2); + assert_eq!(pyramid[1][1], tau2); + + // E_y[2] = eq([], ·) = [1] + assert_eq!(pyramid[2][0], pallas::Scalar::ONE); + } + + /// Ensure evals_from_points uses MSB-first indexing and matches direct product formula. + #[test] + #[allow(clippy::needless_range_loop)] + fn test_eq_table_index_convention() { + let r = vec![ + pallas::Scalar::from(2u64), + pallas::Scalar::from(3u64), + pallas::Scalar::from(5u64), + ]; + let m = r.len(); + let evals = EqPolynomial::evals_from_points(&r); + assert_eq!(evals.len(), 1 << m); + + for idx in 0..(1usize << m) { + let mut expected = pallas::Scalar::ONE; + for j in 0..m { + // MSB-first: bit j of idx (from left) corresponds to variable j + let bit = (idx >> (m - 1 - j)) & 1; + expected *= if bit == 1 { + r[j] + } else { + pallas::Scalar::ONE - r[j] + }; + } + assert_eq!( + evals[idx], expected, + "Mismatch at idx {}: got {:?}, expected {:?}", + idx, evals[idx], expected + ); + } + } + + /// Spot-check specific values to catch bit-order flips. + #[test] + fn test_eq_table_specific_values() { + // m=2, r = [2,3]; MSB-first convention + let r = vec![pallas::Scalar::from(2u64), pallas::Scalar::from(3u64)]; + let evals = EqPolynomial::evals_from_points(&r); + + assert_eq!(evals[0], pallas::Scalar::from(2u64)); // (1-2)(1-3) = 2 + assert_eq!(evals[1], -pallas::Scalar::from(3u64)); // (1-2)*3 = -3 + assert_eq!(evals[2], -pallas::Scalar::from(4u64)); // 2*(1-3) = -4 + assert_eq!(evals[3], pallas::Scalar::from(6u64)); // 2*3 = 6 + } } diff --git a/src/polys/multilinear.rs b/src/polys/multilinear.rs index bfc6a37..fde45d6 100644 --- a/src/polys/multilinear.rs +++ b/src/polys/multilinear.rs @@ -8,9 +8,9 @@ //! - `MultilinearPolynomial`: Dense representation of multilinear polynomials, represented by evaluations over all possible binary inputs. //! - `SparsePolynomial`: Efficient representation of sparse multilinear polynomials, storing only non-zero evaluations. 
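+//!
+//! A minimal sketch of the dense representation (illustrative only; `Scalar`
+//! stands for any supported scalar field):
+//!
+//! ```ignore
+//! // p(x0, x1) stored MSB-first: [p(0,0), p(0,1), p(1,0), p(1,1)].
+//! let mut p = MultilinearPolynomial::new(vec![
+//!   Scalar::from(1u64), Scalar::from(2u64), Scalar::from(3u64), Scalar::from(4u64),
+//! ]);
+//! // Bind the top (most-significant) variable: new[i] = old[i] + r * (old[i + 2] - old[i]).
+//! let r = Scalar::from(5u64);
+//! p.bind_poly_var_top(&r);
+//! assert_eq!(p[0], Scalar::from(11u64)); // (1 - 5) * 1 + 5 * 3
+//! assert_eq!(p[1], Scalar::from(12u64)); // (1 - 5) * 2 + 5 * 4
+//! ```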
-use crate::{math::Math, polys::eq::EqPolynomial, zip_with_for_each}; +use crate::{math::Math, polys::eq::EqPolynomial, small_field::SmallValueField, zip_with_for_each}; use core::ops::Index; -use ff::PrimeField; +use ff::{Field, PrimeField}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -30,24 +30,30 @@ use serde::{Deserialize, Serialize}; /// $$ /// /// Vector $Z$ indicates $Z(e)$ where $e$ ranges from $0$ to $2^m-1$. +/// +/// The type parameter `T` is the coefficient type. Typically this is a field element, +/// but can also be any type with ring operations (add, sub, mul, zero). #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct MultilinearPolynomial { - pub(crate) Z: Vec, // evaluations of the polynomial in all the 2^num_vars Boolean inputs +pub struct MultilinearPolynomial { + pub(crate) Z: Vec, // evaluations of the polynomial in all the 2^num_vars Boolean inputs } -impl MultilinearPolynomial { +impl MultilinearPolynomial { /// Creates a new `MultilinearPolynomial` from the given evaluations. /// /// # Panics /// The number of evaluations must be a power of two. - pub fn new(Z: Vec) -> Self { + pub fn new(Z: Vec) -> Self { MultilinearPolynomial { Z } } +} +impl MultilinearPolynomial { /// Binds the polynomial's top variable using the given scalar. /// /// This operation modifies the polynomial in-place. - pub fn bind_poly_var_top(&mut self, r: &Scalar) { + /// Formula: new[i] = old[i] + r * (old[i + n] - old[i]) + pub fn bind_poly_var_top(&mut self, r: &T) { assert!( self.Z.len() >= 2, "Vector Z must have at least two elements to bind the top variable." @@ -58,14 +64,15 @@ impl MultilinearPolynomial { let (left, right) = self.Z.split_at_mut(n); zip_with_for_each!((left.par_iter_mut(), right.par_iter()), |a, b| { + // Field types implement Copy, so no cloning needed *a += *r * (*b - *a); }); self.Z.truncate(n); } - /// binds the polynomial's top variables using the given scalars. - pub fn bind_with(poly: &[Scalar], L: &[Scalar], r_len: usize) -> Vec { + /// Binds the polynomial's top variables using the given scalars. + pub fn bind_with(poly: &[T], L: &[T], r_len: usize) -> Vec { assert_eq!( poly.len(), L.len() * r_len, @@ -79,7 +86,7 @@ impl MultilinearPolynomial { (0..r_len) .into_par_iter() .map(|i| { - let mut acc = Scalar::ZERO; + let mut acc = T::ZERO; for j in 0..L.len() { // row-major: index = j * r_len + i acc += L[j] * poly[j * r_len + i]; @@ -90,11 +97,108 @@ impl MultilinearPolynomial { } } -impl Index for MultilinearPolynomial { - type Output = Scalar; +impl MultilinearPolynomial { + /// Gathers prefix evaluations p(b, suffix) for all binary prefixes b ∈ {0,1}^ℓ₀. 
+ /// + /// For a polynomial with ℓ variables, this extracts a strided slice where: + /// - First ℓ₀ variables form the "prefix" (high bits) + /// - Remaining ℓ-ℓ₀ variables form the "suffix" (low bits) + /// + /// Index layout: `index = (prefix << suffix_vars) | suffix` + /// + /// # Arguments + /// * `l0` - Number of prefix variables + /// * `suffix` - Fixed suffix value in range [0, 2^{ℓ-ℓ₀}) + /// + /// # Returns + /// MultilinearPolynomial of size 2^ℓ₀ where result[prefix] = self[(prefix << suffix_vars) | suffix] + /// + /// # Example + /// For ℓ=4, ℓ₀=2, suffix=1: + /// - Returns polynomial with evals [self[1], self[5], self[9], self[13]] + /// - Indices: 0b0001, 0b0101, 0b1001, 0b1101 (prefix varies, suffix=01 fixed) + // Allow dead code until Chunk 7 (build_accumulators) uses this method + #[allow(dead_code)] + pub fn gather_prefix_evals(&self, l0: usize, suffix: usize) -> Self { + let l = self.Z.len().trailing_zeros() as usize; + debug_assert_eq!(self.Z.len(), 1 << l, "poly size must be power of 2"); + + let suffix_vars = l - l0; + let prefix_size = 1 << l0; + + debug_assert!(suffix < (1 << suffix_vars), "suffix out of range"); + + let mut Z = Vec::with_capacity(prefix_size); + for prefix in 0..prefix_size { + let idx = (prefix << suffix_vars) | suffix; + Z.push(self.Z[idx]); // Copy, no clone needed + } + + MultilinearPolynomial::new(Z) + } +} + +// ============================================================================ +// Small-value polynomial operations (MultilinearPolynomial) +// ============================================================================ + +impl MultilinearPolynomial { + /// Try to create from a field-element polynomial. + /// Returns None if any value doesn't fit in i32. + pub fn try_from_field>(poly: &MultilinearPolynomial) -> Option { + let evals: Option> = poly.Z.iter().map(|f| F::try_field_to_small(f)).collect(); + evals.map(Self::new) + } + + /// Get the number of variables. + pub fn num_vars(&self) -> usize { + self.Z.len().trailing_zeros() as usize + } + + /// Convert to field-element polynomial. + pub fn to_field>(&self) -> MultilinearPolynomial { + MultilinearPolynomial::new(self.Z.iter().map(|&s| F::small_to_field(s)).collect()) + } +} + +// ============================================================================ +// Small-value polynomial operations (MultilinearPolynomial) +// ============================================================================ + +impl MultilinearPolynomial { + /// Try to create from a field-element polynomial. + /// Returns None if any value doesn't fit in i64. + pub fn try_from_field>(poly: &MultilinearPolynomial) -> Option { + let evals: Option> = poly + .Z + .iter() + .map(|f| crate::small_field::try_field_to_i64(f)) + .collect(); + evals.map(Self::new) + } + + /// Get the number of variables. + pub fn num_vars(&self) -> usize { + self.Z.len().trailing_zeros() as usize + } + + /// Convert to field-element polynomial. 
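  ///
  /// Round-trip sketch (illustrative; `F` is a scalar field for which these i64
  /// conversions are implemented, e.g. one of the pasta fields):
  /// ```ignore
  /// let p64 = MultilinearPolynomial::new(vec![1i64, -2, 3, 4]);
  /// let pf: MultilinearPolynomial<F> = p64.to_field();
  /// assert_eq!(MultilinearPolynomial::<i64>::try_from_field(&pf), Some(p64));
  /// ```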
+ pub fn to_field>(&self) -> MultilinearPolynomial { + MultilinearPolynomial::new( + self + .Z + .iter() + .map(|&s| crate::small_field::i64_to_field(s)) + .collect(), + ) + } +} + +impl Index for MultilinearPolynomial { + type Output = T; #[inline(always)] - fn index(&self, _index: usize) -> &Scalar { + fn index(&self, _index: usize) -> &T { &(self.Z[_index]) } } @@ -137,6 +241,8 @@ impl SparsePolynomial { mod tests { use super::*; use crate::{provider::pasta::pallas, zip_with}; + use ff::Field; + use pallas::Scalar; use rand_core::{CryptoRng, OsRng, RngCore}; /// Evaluates the polynomial at the given point. @@ -302,4 +408,158 @@ mod tests { fn test_bind_and_evaluate() { bind_and_evaluate_with::(); } + + /// Explicit check that bind_poly_var_top matches manual linear interpolation on the MSB. + #[test] + fn test_bind_matches_direct_evaluation_explicit() { + // ℓ=3, poly[i] = i^2 + 1 + let l = 3; + let size = 1 << l; + let vals: Vec = (0..size) + .map(|i| Scalar::from((i * i + 1) as u64)) + .collect(); + let mut poly = MultilinearPolynomial::new(vals.clone()); + + let r = Scalar::from(7u64); + poly.bind_poly_var_top(&r); + assert_eq!(poly.Z.len(), size / 2); + + // Bound variable is the MSB: new[j] = (1-r)*vals[j] + r*vals[j+4] + for j in 0..(size / 2) { + let expected = (Scalar::ONE - r) * vals[j] + r * vals[j + size / 2]; + assert_eq!(poly.Z[j], expected, "Mismatch at j={}", j); + } + } + + /// Ensure "top" refers to the MSB (high-order variable), not the LSB. + #[test] + fn test_bind_top_is_msb_not_lsb() { + // ℓ=2, values encode (x0,x1) with x0 as MSB: [p(0,0), p(0,1), p(1,0), p(1,1)] + let vals = vec![ + Scalar::from(1u64), // (0,0) + Scalar::from(2u64), // (0,1) + Scalar::from(3u64), // (1,0) + Scalar::from(4u64), // (1,1) + ]; + let mut poly = MultilinearPolynomial::new(vals.clone()); + let r = Scalar::from(5u64); + + poly.bind_poly_var_top(&r); + assert_eq!(poly.Z.len(), 2); + + // Expected with MSB binding: + // new[0] = (1-r)*p(0,0) + r*p(1,0) = (1-5)*1 + 5*3 = 11 + // new[1] = (1-r)*p(0,1) + r*p(1,1) = (1-5)*2 + 5*4 = 12 + assert_eq!(poly.Z[0], Scalar::from(11u64)); + assert_eq!(poly.Z[1], Scalar::from(12u64)); + + // If LSB were bound, results would differ (6 and 8 respectively). 
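    // With LSB binding the pairs would be (p(0,0), p(0,1)) and (p(1,0), p(1,1)):
    // (1-5)*1 + 5*2 = 6 and (1-5)*3 + 5*4 = 8, hence the checks below.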
+ assert_ne!(poly.Z[0], Scalar::from(6u64)); + assert_ne!(poly.Z[1], Scalar::from(8u64)); + } + + // === gather_prefix_evals tests === + + #[test] + fn test_gather_prefix_evals_all_suffixes() { + // ℓ=4, ℓ₀=2: 4 variables, first 2 are prefix + // poly[i] = i, so value equals index for easy verification + let l = 4; + let l0 = 2; + let size = 1 << l; // 16 + + let evals: Vec = (0..size).map(|i| pallas::Scalar::from(i as u64)).collect(); + let poly = MultilinearPolynomial::new(evals); + + let num_prefix = 1 << l0; // 4 prefix combinations + let suffix_vars = l - l0; // 2 suffix variables + let num_suffix = 1 << suffix_vars; // 4 suffix combinations + + for suffix in 0..num_suffix { + let gathered = poly.gather_prefix_evals(l0, suffix); + assert_eq!(gathered.Z.len(), num_prefix); + + for prefix in 0..num_prefix { + let expected_idx = (prefix << suffix_vars) | suffix; + let expected = pallas::Scalar::from(expected_idx as u64); + assert_eq!( + gathered[prefix], expected, + "Mismatch at suffix={}, prefix={}", + suffix, prefix + ); + } + } + } + + #[test] + fn test_gather_prefix_l0_equals_l() { + // ℓ₀ = ℓ: no suffix variables, suffix must be 0 + let l0 = 3; + let evals: Vec = (0..8).map(|i| pallas::Scalar::from(i as u64)).collect(); + let poly = MultilinearPolynomial::new(evals); + + let gathered = poly.gather_prefix_evals(l0, 0); + + // Should return entire polynomial + assert_eq!(gathered.Z.len(), 8); + for i in 0..8 { + assert_eq!(gathered[i], pallas::Scalar::from(i as u64)); + } + } + + #[test] + fn test_gather_prefix_l0_equals_1() { + // ℓ₀ = 1: single prefix bit + let l0 = 1; + let evals: Vec = (0..16).map(|i| pallas::Scalar::from(i as u64)).collect(); + let poly = MultilinearPolynomial::new(evals); + + // suffix = 5 (binary 101): should get indices 5, 13 + let gathered = poly.gather_prefix_evals(l0, 5); + assert_eq!(gathered.Z.len(), 2); + assert_eq!(gathered[0], pallas::Scalar::from(5u64)); // prefix=0: 0*8 + 5 = 5 + assert_eq!(gathered[1], pallas::Scalar::from(13u64)); // prefix=1: 1*8 + 5 = 13 + } + + #[test] + fn test_gather_then_extend_preserves_binary_points() { + use crate::lagrange_accumulator::{LagrangeEvaluatedMultilinearPolynomial, LagrangeIndex}; + use ff::Field; + + let l = 4; + let l0 = 2; + + // Random polynomial + let evals: Vec = (0..(1 << l)) + .map(|_| pallas::Scalar::random(&mut OsRng)) + .collect(); + let poly = MultilinearPolynomial::new(evals); + + let suffix_vars = l - l0; + let num_suffix = 1 << suffix_vars; + + for suffix in 0..num_suffix { + let gathered = poly.gather_prefix_evals(l0, suffix); + + // Extend to Lagrange domain + let extended = LagrangeEvaluatedMultilinearPolynomial::<_, 3>::from_multilinear(&gathered); + + // Verify: at binary points, extended values match original poly values + for prefix_bits in 0..(1 << l0) { + let original_idx = (prefix_bits << suffix_vars) | suffix; + let original_val = poly[original_idx]; + + // Convert binary prefix to U_D^ℓ₀ tuple + let tuple = LagrangeIndex::<3>::from_binary(prefix_bits, l0); + + assert_eq!( + extended.get_by_domain(&tuple), + original_val, + "Mismatch at suffix={}, prefix_bits={}", + suffix, + prefix_bits + ); + } + } + } } diff --git a/src/polys/univariate.rs b/src/polys/univariate.rs index 49bfcee..c64cdf6 100644 --- a/src/polys/univariate.rs +++ b/src/polys/univariate.rs @@ -32,7 +32,7 @@ pub struct UniPoly { /// /// The linear term coefficient is omitted to save space. For a polynomial $ax^2 + bx + c$, /// coefficients are stored as `vec![c, a]`. 
For $ax^3 + bx^2 + cx + d$, stored as `vec![d, c, a]`. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct CompressedUniPoly { coeffs_except_linear_term: Vec, } diff --git a/src/provider/bn254.rs b/src/provider/bn254.rs new file mode 100644 index 0000000..82818aa --- /dev/null +++ b/src/provider/bn254.rs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! This module implements the Spartan traits for BN254 (also known as BN256 or alt_bn128). +use crate::{ + impl_traits, + provider::{ + msm::{msm, msm_small}, + traits::{DlogGroup, DlogGroupExt}, + }, + traits::{Group, PrimeFieldExt, transcript::TranscriptReprTrait}, +}; +use digest::{ExtendableOutput, Update}; +use ff::FromUniformBytes; +use halo2curves::{ + CurveAffine, CurveExt, + bn256::{G1 as Bn256G1, G1Affine as Bn256G1Affine}, + group::{Curve, Group as AnotherGroup, cofactor::CofactorCurveAffine}, +}; +use num_bigint::BigInt; +use num_integer::Integer; +use num_traits::{Num, ToPrimitive}; +use rayon::prelude::*; +use sha3::Shake256; +use std::io::Read; + +/// Re-exports that give access to the standard aliases used in the code base, for bn254 +#[allow(clippy::module_inception)] +pub mod bn254 { + pub use halo2curves::bn256::{Fq as Base, Fr as Scalar, G1 as Point, G1Affine as Affine}; +} + +impl_traits!( + bn254, + Bn256G1, + Bn256G1Affine, + // Fr (scalar field) modulus + "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + // Fq (base field) modulus + "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" +); + +// BN254 is not a cycle pair, so we need to manually implement TranscriptReprTrait for the Base field +impl TranscriptReprTrait for bn254::Base { + fn to_transcript_bytes(&self) -> Vec { + self.to_bytes().into_iter().rev().collect() + } +} diff --git a/src/provider/mod.rs b/src/provider/mod.rs index 7166bb4..2416036 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -7,6 +7,7 @@ //! 
This module implements Spartan's traits using the following several different combinations // public modules to be used as an commitment engine with Spartan +pub mod bn254; pub mod keccak; pub mod pasta; pub mod pcs; @@ -17,6 +18,7 @@ mod msm; use crate::{ provider::{ + bn254::bn254 as bn254_types, keccak::Keccak256Transcript, pasta::{pallas, vesta}, pcs::hyrax_pc::HyraxPCS, @@ -43,6 +45,10 @@ pub struct P256HyraxEngine; #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct T256HyraxEngine; +/// An implementation of the Spartan Engine trait with BN254 curve and Hyrax commitment scheme +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub struct Bn254Engine; + impl Engine for PallasHyraxEngine { type Base = pallas::Base; type Scalar = pallas::Scalar; @@ -74,3 +80,11 @@ impl Engine for T256HyraxEngine { type TE = Keccak256Transcript; type PCS = HyraxPCS; } + +impl Engine for Bn254Engine { + type Base = bn254_types::Base; + type Scalar = bn254_types::Scalar; + type GE = bn254_types::Point; + type TE = Keccak256Transcript; + type PCS = HyraxPCS; +} diff --git a/src/provider/msm.rs b/src/provider/msm.rs index 1a9aabb..0d741cc 100644 --- a/src/provider/msm.rs +++ b/src/provider/msm.rs @@ -481,8 +481,15 @@ mod tests { for bit_width in [1, 4, 8, 10, 16, 20, 32, 40, 64] { println!("bit_width: {bit_width}"); assert!(bit_width <= 64); // Ensure we don't overflow F::from + let _bound = 1u128 << bit_width; let coeffs: Vec = (0..n) - .map(|_| rand::random::() % (1 << bit_width)) + .map(|_| { + if bit_width == 64 { + rand::random::() + } else { + rand::random::() % (1 << bit_width) + } + }) .collect::>(); let coeffs_scalar: Vec = coeffs.iter().map(|b| F::from(*b)).collect::>(); let general = msm(&coeffs_scalar, &bases, true); diff --git a/src/small_field/barrett.rs b/src/small_field/barrett.rs new file mode 100644 index 0000000..3ac2ca3 --- /dev/null +++ b/src/small_field/barrett.rs @@ -0,0 +1,1076 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Barrett reduction for small × field multiplication. + +use super::limbs::{gte_4_4, gte_5_4, mul_4_by_1, mul_5_by_1, sub_4_4, sub_5_4, sub_5_5}; +use halo2curves::{ + bn256::Fr as Bn254Fr, + pasta::{Fp, Fq}, +}; +use std::ops::Neg; + +// ========================================================================== +// FieldReductionConstants - Trait for field-specific reduction constants +// ========================================================================== + +/// Trait providing precomputed constants for efficient modular reduction. +/// +/// # Overview +/// +/// When reducing a wide integer (more than 4 limbs = 256 bits) modulo a prime p, +/// we need to handle "overflow limbs" that represent values ≥ 2^256. Each overflow +/// limb at position i represents the value `limb[i] × 2^(64×i)`. 
+/// +/// # The R Constants +/// +/// For each bit position beyond 256 bits, we precompute `2^k mod p`: +/// +/// | Constant | Value | Used When | +/// |----------|-------|-----------| +/// | `R256_MOD` | 2^256 mod p | Reducing 5th limb (bits 256-319) | +/// | `R320_MOD` | 2^320 mod p | Reducing 6th limb (bits 320-383) | +/// | `R384_MOD` | 2^384 mod p | Reducing 7th limb (bits 384-447) | +/// | `R448_MOD` | 2^448 mod p | Reducing 8th limb (bits 448-511) | +/// | `R512_MOD` | 2^512 mod p | Reducing 9th limb (bits 512-575) | +/// +/// # Example: 6-limb Reduction +/// +/// For a 6-limb value `c = [c0, c1, c2, c3, c4, c5]` representing: +/// ```text +/// c = c0 + c1·2^64 + c2·2^128 + c3·2^192 + c4·2^256 + c5·2^320 +/// ``` +/// +/// We reduce by computing: +/// ```text +/// c mod p = (c0 + c1·2^64 + c2·2^128 + c3·2^192) +/// + c4·(2^256 mod p) +/// + c5·(2^320 mod p) +/// ``` +/// +/// Since `R256_MOD` and `R320_MOD` are 4-limb values (< 2^256), multiplying +/// by a single limb produces at most a 5-limb result, which can then be +/// reduced further if needed. +/// +/// # Why This Works +/// +/// By the properties of modular arithmetic: +/// `a ≡ b (mod p) ⟹ c·a ≡ c·b (mod p)` +/// +/// So `c5·2^320 ≡ c5·R320_MOD (mod p)`, and the right side is much smaller. +/// +/// # Performance +/// +/// Much faster than naive division: avoids division entirely, uses only +/// 4-5 64-bit multiplications with precomputed constants. +pub trait FieldReductionConstants { + /// The 4-limb prime modulus p (little-endian, 256 bits) + const MODULUS: [u64; 4]; + + /// 2×p as a 5-limb value (for Barrett reduction comparisons) + const MODULUS_2P: [u64; 5]; + + /// Barrett approximation constant μ = floor(2^128 / (p >> 191)) + /// Used to estimate the quotient in Barrett reduction + const MU: u64; + + /// 2^256 mod p - reduces the 5th limb (index 4) of a wide integer + const R256_MOD: [u64; 4]; + + /// 2^320 mod p - reduces the 6th limb (index 5) of a wide integer + const R320_MOD: [u64; 4]; + + /// 2^384 mod p - reduces the 7th limb (index 6) of a wide integer + const R384_MOD: [u64; 4]; + + /// 2^448 mod p - reduces the 8th limb (index 7) of a wide integer + const R448_MOD: [u64; 4]; + + /// 2^512 mod p - reduces the 9th limb (index 8) of a wide integer + const R512_MOD: [u64; 4]; + + /// Montgomery inverse: -p^(-1) mod 2^64 + /// Used in Montgomery REDC to eliminate low limbs + const MONT_INV: u64; +} + +// ========================================================================== +// FieldReductionConstants implementation for Fp (Pallas base field) +// ========================================================================== + +impl FieldReductionConstants for Fp { + // p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001 + const MODULUS: [u64; 4] = [ + 0x992d30ed00000001, + 0x224698fc094cf91b, + 0x0000000000000000, + 0x4000000000000000, + ]; + + const MODULUS_2P: [u64; 5] = double_limbs(Self::MODULUS); + + const MU: u64 = 0xffffffffffffffff; + + // 2^256 mod p = 0x3fffffffffffffff992c350be41914ad34786d38fffffffd + const R256_MOD: [u64; 4] = [ + 0x34786d38fffffffd, + 0x992c350be41914ad, + 0xffffffffffffffff, + 0x3fffffffffffffff, + ]; + + // 2^320 mod p = 0x3fffffffffffffff76e59c0fdacc1b91bd91d548094cf917992d30ed00000001 + const R320_MOD: [u64; 4] = [ + 0x992d30ed00000001, + 0xbd91d548094cf917, + 0x76e59c0fdacc1b91, + 0x3fffffffffffffff, + ]; + + // 2^384 mod p + const R384_MOD: [u64; 4] = [ + 0xcb8792c700000003, + 0x66d3caf41be6eb52, + 0x9b4b3c4bfffffffc, + 0x36e59c0fdacc1b91, + ]; + + // 
2^448 mod p + const R448_MOD: [u64; 4] = [ + 0x9b9858f294cf91ba, + 0x8635bd2c4252b065, + 0x496d41af7b9cb714, + 0x1b4b3c4bfffffffc, + ]; + + // 2^512 mod p + const R512_MOD: [u64; 4] = [ + 0x8c78ecb30000000f, + 0xd7d30dbd8b0de0e7, + 0x7797a99bc3c95d18, + 0x096d41af7b9cb714, + ]; + + // -p^(-1) mod 2^64 + const MONT_INV: u64 = 0x992d30ecffffffff; +} + +// ========================================================================== +// FieldReductionConstants implementation for Fq (Pallas scalar field) +// ========================================================================== + +impl FieldReductionConstants for Fq { + // q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001 + const MODULUS: [u64; 4] = [ + 0x8c46eb2100000001, + 0x224698fc0994a8dd, + 0x0000000000000000, + 0x4000000000000000, + ]; + + const MODULUS_2P: [u64; 5] = double_limbs(Self::MODULUS); + + const MU: u64 = 0xffffffffffffffff; + + // 2^256 mod q + const R256_MOD: [u64; 4] = [ + 0x5b2b3e9cfffffffd, + 0x992c350be3420567, + 0xffffffffffffffff, + 0x3fffffffffffffff, + ]; + + // 2^320 mod q + const R320_MOD: [u64; 4] = [ + 0x8c46eb2100000001, + 0xf12aec780994a8d9, + 0x76e59c0fd9ad5c89, + 0x3fffffffffffffff, + ]; + + // 2^384 mod q + const R384_MOD: [u64; 4] = [ + 0xa4d4c16300000003, + 0x66d3caf41cbdfa98, + 0xcee4537bfffffffc, + 0x36e59c0fd9ad5c89, + ]; + + // 2^448 mod q + const R448_MOD: [u64; 4] = [ + 0xcc920bb9994a8dd9, + 0x87a7dcbe1ff6e0d7, + 0x496d41af7ccfdaa9, + 0x0ee4537bfffffffc, + ]; + + // 2^512 mod q + const R512_MOD: [u64; 4] = [ + 0xfc9678ff0000000f, + 0x67bb433d891a16e3, + 0x7fae231004ccf590, + 0x096d41af7ccfdaa9, + ]; + + // -q^(-1) mod 2^64 + const MONT_INV: u64 = 0x8c46eb20ffffffff; +} + +// ========================================================================== +// FieldReductionConstants implementation for Bn254Fr (BN254 scalar field) +// ========================================================================== + +impl FieldReductionConstants for Bn254Fr { + // r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 + const MODULUS: [u64; 4] = [ + 0x43e1f593f0000001, + 0x2833e84879b97091, + 0xb85045b68181585d, + 0x30644e72e131a029, + ]; + + const MODULUS_2P: [u64; 5] = double_limbs(Self::MODULUS); + + const MU: u64 = 0xffffffffffffffff; + + // 2^256 mod r + const R256_MOD: [u64; 4] = [ + 0xac96341c4ffffffb, + 0x36fc76959f60cd29, + 0x666ea36f7879462e, + 0x0e0a77c19a07df2f, + ]; + + // 2^320 mod r + const R320_MOD: [u64; 4] = [ + 0xb4c6edf97c5fb586, + 0x708c8d50bfeb93be, + 0x9ffd1de404f7e0ef, + 0x215b02ac9a392866, + ]; + + // 2^384 mod r + const R384_MOD: [u64; 4] = [ + 0xb075da81ef8cfeb9, + 0xa7f12acca5b6cd8c, + 0x32c475047957bf7b, + 0x03d581d748ffa25e, + ]; + + // 2^448 mod r + const R448_MOD: [u64; 4] = [ + 0x5665c3b5c177f51a, + 0x00e7f02ade75c713, + 0xb09192e52f747168, + 0x0621c0bbcccdc65d, + ]; + + // 2^512 mod r + const R512_MOD: [u64; 4] = [ + 0x1bb8e645ae216da7, + 0x53fe3ab1e35c59e3, + 0x8c49833d53bb8085, + 0x0216d0b17f4e44a5, + ]; + + // -r^(-1) mod 2^64 + const MONT_INV: u64 = 0xc2e1f593efffffff; +} + +// ========================================================================== +// BarrettField - Trait for generic small × field multiplication +// ========================================================================== + +/// Trait for field types supporting optimized small × field multiplication. +/// +/// Extends `FieldReductionConstants` with methods for accessing the internal +/// limb representation and constructing field elements from limbs. 
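///
/// The `TWO_POW_64` constant supports splitting a `u128` multiplier into two
/// 64-bit halves (illustrative identity, used by the u128 path below):
/// ```text
/// large · small = large · lo + (large · hi) · 2^64,   where small = hi · 2^64 + lo
/// ```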
+pub(crate) trait BarrettField: FieldReductionConstants + Neg + Copy { + /// 2^64 as a field element, used for u128 multiplication decomposition. + const TWO_POW_64: Self; + + /// Construct a field element from 4 Montgomery-form limbs. + fn from_limbs(limbs: [u64; 4]) -> Self; + + /// Access the internal Montgomery-form limbs. + fn to_limbs(&self) -> &[u64; 4]; +} + +impl BarrettField for Fp { + const TWO_POW_64: Self = Fp::from_raw([0, 1, 0, 0]); + + #[inline] + fn from_limbs(limbs: [u64; 4]) -> Self { + Fp(limbs) + } + + #[inline] + fn to_limbs(&self) -> &[u64; 4] { + &self.0 + } +} + +impl BarrettField for Fq { + const TWO_POW_64: Self = Fq::from_raw([0, 1, 0, 0]); + + #[inline] + fn from_limbs(limbs: [u64; 4]) -> Self { + Fq(limbs) + } + + #[inline] + fn to_limbs(&self) -> &[u64; 4] { + &self.0 + } +} + +impl BarrettField for Bn254Fr { + const TWO_POW_64: Self = Bn254Fr::from_raw([0, 1, 0, 0]); + + #[inline] + fn from_limbs(limbs: [u64; 4]) -> Self { + Bn254Fr(limbs) + } + + #[inline] + fn to_limbs(&self) -> &[u64; 4] { + &self.0 + } +} + +/// Computes 2*p, returning a 5-limb result. +const fn double_limbs(p: [u64; 4]) -> [u64; 5] { + let (r0, c0) = p[0].overflowing_add(p[0]); + let (r1, c1) = { + let (sum, c1a) = p[1].overflowing_add(p[1]); + let (sum, c1b) = sum.overflowing_add(c0 as u64); + (sum, c1a || c1b) + }; + let (r2, c2) = { + let (sum, c2a) = p[2].overflowing_add(p[2]); + let (sum, c2b) = sum.overflowing_add(c1 as u64); + (sum, c2a || c2b) + }; + let (r3, c3) = { + let (sum, c3a) = p[3].overflowing_add(p[3]); + let (sum, c3b) = sum.overflowing_add(c2 as u64); + (sum, c3a || c3b) + }; + let r4 = c3 as u64; + [r0, r1, r2, r3, r4] +} + +// ========================================================================== +// Generic small × field multiplication +// ========================================================================== + +/// Multiply field element by i64 (signed). +#[inline] +pub(crate) fn mul_by_i64(large: &F, small: i64) -> F { + if small >= 0 { + mul_by_u64(large, small as u64) + } else { + mul_by_u64(large, small.wrapping_neg() as u64).neg() + } +} + +/// Multiply field element by i128 (signed). +#[inline] +#[allow(dead_code)] +pub(crate) fn mul_by_i128< + F: BarrettField + std::ops::Add + std::ops::Mul, +>( + large: &F, + small: i128, +) -> F { + if small >= 0 { + mul_by_u128(large, small as u128) + } else { + mul_by_u128(large, small.wrapping_neg() as u128).neg() + } +} + +/// Multiply field element by u128 (unsigned). +#[inline] +#[allow(dead_code)] +fn mul_by_u128 + std::ops::Mul>( + large: &F, + small: u128, +) -> F { + let low = small as u64; + let high = (small >> 64) as u64; + + if high == 0 { + mul_by_u64(large, low) + } else { + // result = large * low + large * high * 2^64 + let low_part = mul_by_u64(large, low); + let high_part = mul_by_u64(large, high); + low_part + high_part * F::TWO_POW_64 + } +} + +/// Multiply field element by u64 (unsigned). +#[inline] +fn mul_by_u64(large: &F, small: u64) -> F { + if small == 0 { + return F::from_limbs([0, 0, 0, 0]); + } + if small == 1 { + return *large; + } + let c = mul_4_by_1(large.to_limbs(), small); + F::from_limbs(barrett_reduce_5::(&c)) +} + +/// Generic 5-limb Barrett reduction using trait constants. +/// Reduces a 5-limb value (up to 320 bits) modulo p. 
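///
/// Structure of the estimate-and-correct step (illustrative sketch of the code below):
/// ```text
/// c_tilde = c >> 255                  (top bits folded into one 64-bit word)
/// m       = (c_tilde * MU) >> 64      (quotient estimate)
/// r       = c - m * 2p
/// while r >= p { r -= p }             (final correction)
/// ```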
+#[inline(always)] +fn barrett_reduce_5(c: &[u64; 5]) -> [u64; 4] { + let c_tilde = (c[3] >> 63) | (c[4] << 1); + let m = { + let product = (c_tilde as u128) * (F::MU as u128); + (product >> 64) as u64 + }; + let m_times_2p = mul_5_by_1(&F::MODULUS_2P, m); + let mut r = sub_5_5(c, &m_times_2p); + // At most 2 iterations needed: after Barrett approximation, 0 ≤ r < 2p + while gte_5_4(&r, &F::MODULUS) { + r = sub_5_4(&r, &F::MODULUS); + } + [r[0], r[1], r[2], r[3]] +} + +// ========================================================================== +// 6-limb Barrett reduction (for UnreducedFieldInt accumulator) +// ========================================================================== + +/// Generic 6-limb Barrett reduction using trait constants. +/// +/// Reduces a 6-limb value (up to 384 bits) modulo p using precomputed +/// R256 = 2^256 mod p and R320 = 2^320 mod p constants. +/// +/// Input is already in Montgomery form (R-scaled). This function reduces +/// the 6-limb value mod p while preserving the Montgomery scaling. +#[inline] +fn barrett_reduce_6(c: &[u64; 6]) -> [u64; 4] { + // Reduce c[5] * 2^320 ≡ c[5] * R320 (mod p), then reduce c[4] * 2^256, etc. + // R320 = 2^320 mod p, R256 = 2^256 mod p + // + // We do: result = c[0..4] + c[4] * R256 + c[5] * R320 (mod p) + + // c[4] * R256_MOD (4x1 -> 5 limbs) + let c4_contrib = mul_4_by_1(&F::R256_MOD, c[4]); + // c[5] * R320_MOD (4x1 -> 5 limbs) + let c5_contrib = mul_4_by_1(&F::R320_MOD, c[5]); + + // Sum: c[0..4] + c4_contrib + c5_contrib (could be up to 6 limbs) + let mut sum = [0u64; 6]; + let mut carry = 0u128; + for i in 0..4 { + let s = (c[i] as u128) + (c4_contrib[i] as u128) + (c5_contrib[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + // Limb 4: c4_contrib[4] + c5_contrib[4] + carry + let s = (c4_contrib[4] as u128) + (c5_contrib[4] as u128) + carry; + sum[4] = s as u64; + sum[5] = (s >> 64) as u64; + + // Now reduce the 6-limb sum. If sum[5] or sum[4] is non-zero, recurse (limited depth) + if sum[5] == 0 && sum[4] == 0 { + // Result fits in 4 limbs, just do final reduction + let mut r = [sum[0], sum[1], sum[2], sum[3]]; + while gte_4_4(&r, &F::MODULUS) { + r = sub_4_4(&r, &F::MODULUS); + } + return r; + } + + // Recurse (this will terminate because sum < c in most cases) + barrett_reduce_6::(&sum) +} + +#[inline] +#[allow(dead_code)] +pub(crate) fn barrett_reduce_6_fp(c: &[u64; 6]) -> [u64; 4] { + barrett_reduce_6::(c) +} + +#[inline] +#[allow(dead_code)] +pub(crate) fn barrett_reduce_6_fq(c: &[u64; 6]) -> [u64; 4] { + barrett_reduce_6::(c) +} + +// ========================================================================== +// 8-limb Barrett reduction (for i64/i128 UnreducedFieldInt accumulator) +// ========================================================================== + +/// Generic 8-limb Barrett reduction using trait constants. +/// +/// Reduces an 8-limb value (up to 512 bits) modulo p using precomputed +/// R384 = 2^384 mod p and R448 = 2^448 mod p constants, then delegates +/// to 6-limb reduction. 
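///
/// Reduction identity (illustrative):
/// ```text
/// c ≡ c[0..6] + c[6] · (2^384 mod p) + c[7] · (2^448 mod p)   (mod p)
/// ```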
+#[inline] +fn barrett_reduce_8(c: &[u64; 8]) -> [u64; 4] { + // Reduce high limbs: c[6] * 2^384 + c[7] * 2^448 + let c6_contrib = mul_4_by_1(&F::R384_MOD, c[6]); + let c7_contrib = mul_4_by_1(&F::R448_MOD, c[7]); + + // Sum: c[0..6] + c6_contrib + c7_contrib + let mut sum = [0u64; 6]; + let mut carry = 0u128; + for i in 0..4 { + let s = (c[i] as u128) + (c6_contrib[i] as u128) + (c7_contrib[i] as u128) + carry; + sum[i] = s as u64; + carry = s >> 64; + } + // Limbs 4-5: add c[4], c[5], and carry from high contributions + let s = (c[4] as u128) + (c6_contrib[4] as u128) + (c7_contrib[4] as u128) + carry; + sum[4] = s as u64; + carry = s >> 64; + let s = (c[5] as u128) + carry; + sum[5] = s as u64; + + // Now reduce the 6-limb result + barrett_reduce_6::(&sum) +} + +#[inline] +pub(crate) fn barrett_reduce_8_fp(c: &[u64; 8]) -> [u64; 4] { + barrett_reduce_8::(c) +} + +#[inline] +pub(crate) fn barrett_reduce_8_fq(c: &[u64; 8]) -> [u64; 4] { + barrett_reduce_8::(c) +} + +// ========================================================================== +// 9-limb Montgomery REDC (for UnreducedFieldField accumulator) +// ========================================================================== + +/// Generic Montgomery REDC for 9-limb input using trait constants. +/// +/// Reduces a 2R-scaled value (sum of field×field products) to 1R-scaled. +/// Input: T representing x·R² (up to 9 limbs) +/// Output: x·R mod p (4 limbs, standard Montgomery form) +#[inline] +fn montgomery_reduce_9(c: &[u64; 9]) -> [u64; 4] { + // Step 1: Reduce 9 limbs to 8 limbs using precomputed 2^512 mod p + let mut t = [0u64; 9]; + if c[8] == 0 { + t[..8].copy_from_slice(&c[..8]); + } else { + // t = c[0..8] + c[8] * R512_MOD + let high_contribution = mul_4_by_1(&F::R512_MOD, c[8]); + let mut carry = 0u128; + for i in 0..5 { + let sum = (c[i] as u128) + (high_contribution[i] as u128) + carry; + t[i] = sum as u64; + carry = sum >> 64; + } + for i in 5..8 { + let sum = (c[i] as u128) + carry; + t[i] = sum as u64; + carry = sum >> 64; + } + t[8] = carry as u64; + + // Recurse if still > 8 limbs + if t[8] > 0 { + return montgomery_reduce_9::(&t); + } + } + + // Step 2: Montgomery REDC on 8-limb value + montgomery_reduce_8::(&[t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]]) +} + +/// Generic Montgomery REDC for 8-limb input. 
+/// Standard Montgomery reduction: T × R⁻¹ mod p +#[inline] +fn montgomery_reduce_8(t: &[u64; 8]) -> [u64; 4] { + // Use 9 limbs to track overflow + let mut r = [t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], 0u64]; + + // Montgomery reduction: for each of the low 4 limbs, eliminate it + // by adding appropriate multiples of p + for i in 0..4 { + // q = r[i] * p' mod 2^64 + let q = r[i].wrapping_mul(F::MONT_INV); + + // r += q * p * 2^(64*i) + // qp = q * p, which is 5 limbs (since p is 4 limbs and q is 1 limb) + let qp = mul_4_by_1(&F::MODULUS, q); + + let mut carry = 0u128; + for j in 0..5 { + let sum = (r[i + j] as u128) + (qp[j] as u128) + carry; + r[i + j] = sum as u64; + carry = sum >> 64; + } + // Propagate remaining carry through the rest of the array + for item in r[(i + 5)..9].iter_mut() { + let sum = (*item as u128) + carry; + *item = sum as u64; + carry = sum >> 64; + if carry == 0 { + break; + } + } + } + + // Now r[0..4] should be zero (by construction), result is in r[4..9] + // We need to reduce this to [0, p) + let mut result = [r[4], r[5], r[6], r[7], r[8]]; + + // Reduce until result < p + while result[4] > 0 || gte_4_4(&[result[0], result[1], result[2], result[3]], &F::MODULUS) { + let sub = sub_5_4(&result, &F::MODULUS); + result = sub; + } + + [result[0], result[1], result[2], result[3]] +} + +#[inline] +#[allow(dead_code)] +pub(crate) fn montgomery_reduce_9_fp(c: &[u64; 9]) -> [u64; 4] { + montgomery_reduce_9::(c) +} + +#[inline] +#[allow(dead_code)] +pub(crate) fn montgomery_reduce_9_fq(c: &[u64; 9]) -> [u64; 4] { + montgomery_reduce_9::(c) +} + +// ========================================================================== +// BN254 Fr reduction functions +// ========================================================================== + +#[inline] +#[allow(dead_code)] +pub(crate) fn barrett_reduce_6_bn254_fr(c: &[u64; 6]) -> [u64; 4] { + barrett_reduce_6::(c) +} + +#[inline] +#[allow(dead_code)] +pub(crate) fn barrett_reduce_8_bn254_fr(c: &[u64; 8]) -> [u64; 4] { + barrett_reduce_8::(c) +} + +#[inline] +#[allow(dead_code)] +pub(crate) fn montgomery_reduce_9_bn254_fr(c: &[u64; 9]) -> [u64; 4] { + montgomery_reduce_9::(c) +} + +// ========================================================================== +// Tests +// ========================================================================== + +#[cfg(test)] +mod tests { + use super::*; + use ff::Field; + use rand_core::{OsRng, RngCore}; + + #[test] + fn test_barrett_fp_matches_naive() { + for small in [0u64, 1, 2, 42, 1000, u32::MAX as u64, u64::MAX] { + let large = Fp::random(&mut OsRng); + let naive = Fp::from(small) * large; + let barrett = mul_by_u64(&large, small); + assert_eq!(naive, barrett, "Fp mismatch for small = {}", small); + } + } + + #[test] + fn test_barrett_fp_random() { + let mut rng = OsRng; + for _ in 0..1000 { + let large = Fp::random(&mut rng); + let small: u64 = rng.next_u64(); + assert_eq!(Fp::from(small) * large, mul_by_u64(&large, small)); + } + } + + #[test] + fn test_barrett_fp_i64() { + let large = Fp::from(42u64); + assert_eq!(mul_by_i64(&large, 100i64), Fp::from(100u64) * large); + assert_eq!(mul_by_i64(&large, -100i64), -Fp::from(100u64) * large); + } + + #[test] + fn test_barrett_fq_matches_naive() { + for small in [0u64, 1, 2, 42, 1000, u32::MAX as u64, u64::MAX] { + let large = Fq::random(&mut OsRng); + let naive = Fq::from(small) * large; + let barrett = mul_by_u64(&large, small); + assert_eq!(naive, barrett, "Fq mismatch for small = {}", small); + } + } + + #[test] + fn 
test_barrett_fq_random() { + let mut rng = OsRng; + for _ in 0..1000 { + let large = Fq::random(&mut rng); + let small: u64 = rng.next_u64(); + assert_eq!(Fq::from(small) * large, mul_by_u64(&large, small)); + } + } + + #[test] + fn test_barrett_fq_i64() { + let large = Fq::from(42u64); + assert_eq!(mul_by_i64(&large, 100i64), Fq::from(100u64) * large); + assert_eq!(mul_by_i64(&large, -100i64), -Fq::from(100u64) * large); + } + + #[test] + fn test_barrett_fp_i128() { + let large = Fp::random(&mut OsRng); + + // Test small i128 values (fits in i64) + assert_eq!(mul_by_i128(&large, 100), Fp::from(100u64) * large); + assert_eq!(mul_by_i128(&large, -100), -Fp::from(100u64) * large); + + // Test large i128 values (requires 2^64 decomposition) + let big: i128 = (1i128 << 70) + 12345; + let expected = crate::small_field::i128_to_field::(big) * large; + assert_eq!(mul_by_i128(&large, big), expected); + + let neg_big: i128 = -((1i128 << 70) + 12345); + let expected_neg = crate::small_field::i128_to_field::(neg_big) * large; + assert_eq!(mul_by_i128(&large, neg_big), expected_neg); + } + + #[test] + fn test_barrett_fq_i128() { + let large = Fq::random(&mut OsRng); + + // Test small i128 values (fits in i64) + assert_eq!(mul_by_i128(&large, 100), Fq::from(100u64) * large); + assert_eq!(mul_by_i128(&large, -100), -Fq::from(100u64) * large); + + // Test large i128 values (requires 2^64 decomposition) + let big: i128 = (1i128 << 70) + 12345; + let expected = crate::small_field::i128_to_field::(big) * large; + assert_eq!(mul_by_i128(&large, big), expected); + + let neg_big: i128 = -((1i128 << 70) + 12345); + let expected_neg = crate::small_field::i128_to_field::(neg_big) * large; + assert_eq!(mul_by_i128(&large, neg_big), expected_neg); + } + + #[test] + fn test_barrett_i128_random() { + let mut rng = OsRng; + for _ in 0..100 { + let large = Fq::random(&mut rng); + // Generate random i128 in reasonable range + let small: i128 = ((rng.next_u64() as i128) << 32) | (rng.next_u64() as i128); + let small = if rng.next_u32().is_multiple_of(2) { + small + } else { + -small + }; + + let result = mul_by_i128(&large, small); + let expected = crate::small_field::i128_to_field::(small) * large; + assert_eq!(result, expected); + } + } + + #[test] + fn test_two_pow_64_constants() { + // Verify Fp::TWO_POW_64 is correct + let computed = Fp::from(1u64 << 32) * Fp::from(1u64 << 32); + assert_eq!(Fp::TWO_POW_64, computed); + + // Verify Fq::TWO_POW_64 is correct + let computed = Fq::from(1u64 << 32) * Fq::from(1u64 << 32); + assert_eq!(Fq::TWO_POW_64, computed); + } + + #[test] + fn test_constants_match_halo2curves() { + let p_minus_one = -Fp::ONE; + let expected = Fp::from_raw([ + Fp::MODULUS[0].wrapping_sub(1), + Fp::MODULUS[1], + Fp::MODULUS[2], + Fp::MODULUS[3], + ]); + assert_eq!(p_minus_one, expected); + + let q_minus_one = -Fq::ONE; + let expected = Fq::from_raw([ + Fq::MODULUS[0].wrapping_sub(1), + Fq::MODULUS[1], + Fq::MODULUS[2], + Fq::MODULUS[3], + ]); + assert_eq!(q_minus_one, expected); + } + + // ======================================================================== + // 6-limb Barrett reduction tests + // ======================================================================== + + #[test] + fn test_barrett_6_fp_zero() { + let c = [0u64; 6]; + let result = Fp(barrett_reduce_6_fp(&c)); + assert_eq!(result, Fp::ZERO); + } + + #[test] + fn test_barrett_6_fp_from_product() { + // Test with an actual field × integer product (the real use case) + let field_elem = Fp::from(12345u64); // Creates Montgomery form + 
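    // The limb product below stays in Montgomery form, so reducing the
    // zero-extended 6-limb value must agree with `field_elem * Fp::from(small)`.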
let small = 9999u64; + + // Compute field × small as 5 limbs (Montgomery form) + let product = mul_4_by_1(&field_elem.0, small); + + // Extend to 6 limbs + let c = [ + product[0], product[1], product[2], product[3], product[4], 0, + ]; + + // Reduce + let result = Fp(barrett_reduce_6_fp(&c)); + + // Expected: field_elem * small + let expected = field_elem * Fp::from(small); + assert_eq!(result, expected); + } + + #[test] + fn test_barrett_6_fp_sum_of_products() { + // Test summing multiple products (the accumulator use case) + let mut rng = OsRng; + let mut acc = [0u64; 6]; + let mut expected_sum = Fp::ZERO; + + // Sum 100 products + for _ in 0..100 { + let field_elem = Fp::random(&mut rng); + let small = rng.next_u64() >> 32; // Keep small to avoid overflow in test + + // Accumulate expected result + expected_sum += field_elem * Fp::from(small); + + // Compute field × small as 5 limbs + let product = mul_4_by_1(&field_elem.0, small); + + // Add to accumulator with carry propagation + let mut carry = 0u128; + for i in 0..5 { + let sum = (acc[i] as u128) + (product[i] as u128) + carry; + acc[i] = sum as u64; + carry = sum >> 64; + } + acc[5] = acc[5].wrapping_add(carry as u64); + } + + // Reduce and compare + let result = Fp(barrett_reduce_6_fp(&acc)); + assert_eq!(result, expected_sum); + } + + #[test] + fn test_barrett_6_fp_many_products() { + // Stress test: sum many products to exercise 6-limb reduction + // Note: In the real use case, we sum at most 2^(ℓ/2) products where ℓ ≤ 130. + // For ℓ = 20 (a typical size), that's 2^10 = 1024 products. + let mut rng = OsRng; + let mut acc = [0u64; 6]; + let mut expected_sum = Fp::ZERO; + + // Sum 2000 products (realistic bound for medium-sized polynomials) + for _ in 0..2000 { + let field_elem = Fp::random(&mut rng); + let small = rng.next_u64(); + + expected_sum += field_elem * Fp::from(small); + + let product = mul_4_by_1(&field_elem.0, small); + let mut carry = 0u128; + for i in 0..5 { + let sum = (acc[i] as u128) + (product[i] as u128) + carry; + acc[i] = sum as u64; + carry = sum >> 64; + } + acc[5] = acc[5].wrapping_add(carry as u64); + } + + let result = Fp(barrett_reduce_6_fp(&acc)); + assert_eq!(result, expected_sum); + } + + #[test] + fn test_barrett_6_fq_sum_of_products() { + let mut rng = OsRng; + let mut acc = [0u64; 6]; + let mut expected_sum = Fq::ZERO; + + for _ in 0..100 { + let field_elem = Fq::random(&mut rng); + let small = rng.next_u64() >> 32; + + expected_sum += field_elem * Fq::from(small); + + let product = mul_4_by_1(&field_elem.0, small); + let mut carry = 0u128; + for i in 0..5 { + let sum = (acc[i] as u128) + (product[i] as u128) + carry; + acc[i] = sum as u64; + carry = sum >> 64; + } + acc[5] = acc[5].wrapping_add(carry as u64); + } + + let result = Fq(barrett_reduce_6_fq(&acc)); + assert_eq!(result, expected_sum); + } + + // ======================================================================== + // 9-limb Montgomery REDC tests + // ======================================================================== + + /// Helper to multiply two 4-limb values, producing an 8-limb result + fn mul_4_by_4(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] { + let mut result = [0u64; 8]; + for i in 0..4 { + let mut carry = 0u128; + for j in 0..4 { + let prod = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry; + result[i + j] = prod as u64; + carry = prod >> 64; + } + result[i + 4] = carry as u64; + } + result + } + + #[test] + fn test_montgomery_9_fp_single_product() { + // Test with a single field × field product + let 
a = Fp::from(12345u64); + let b = Fp::from(67890u64); + + // Compute a_mont × b_mont (8 limbs, representing (a*b)*R² in unreduced form) + let product = mul_4_by_4(&a.0, &b.0); + + // Extend to 9 limbs + let c = [ + product[0], product[1], product[2], product[3], product[4], product[5], product[6], + product[7], 0, + ]; + + // Montgomery reduce: should give (a*b)*R mod p = (a*b) in Montgomery form + let result = Fp(montgomery_reduce_9_fp(&c)); + + // Expected: a * b in field + let expected = a * b; + assert_eq!(result, expected); + } + + #[test] + fn test_montgomery_9_fp_sum_of_products() { + // Test summing multiple field × field products + let mut rng = OsRng; + let mut acc = [0u64; 9]; + let mut expected_sum = Fp::ZERO; + + // Sum 100 products + for _ in 0..100 { + let a = Fp::random(&mut rng); + let b = Fp::random(&mut rng); + + // Accumulate expected result + expected_sum += a * b; + + // Compute a_mont × b_mont as 8 limbs + let product = mul_4_by_4(&a.0, &b.0); + + // Add to accumulator with carry propagation + let mut carry = 0u128; + for i in 0..8 { + let sum = (acc[i] as u128) + (product[i] as u128) + carry; + acc[i] = sum as u64; + carry = sum >> 64; + } + acc[8] = acc[8].wrapping_add(carry as u64); + } + + // Montgomery reduce and compare + let result = Fp(montgomery_reduce_9_fp(&acc)); + assert_eq!(result, expected_sum); + } + + #[test] + fn test_montgomery_9_fp_many_products() { + // Stress test with many products + let mut rng = OsRng; + let mut acc = [0u64; 9]; + let mut expected_sum = Fp::ZERO; + + // Sum 1000 products + for _ in 0..1000 { + let a = Fp::random(&mut rng); + let b = Fp::random(&mut rng); + + expected_sum += a * b; + + let product = mul_4_by_4(&a.0, &b.0); + let mut carry = 0u128; + for i in 0..8 { + let sum = (acc[i] as u128) + (product[i] as u128) + carry; + acc[i] = sum as u64; + carry = sum >> 64; + } + acc[8] = acc[8].wrapping_add(carry as u64); + } + + let result = Fp(montgomery_reduce_9_fp(&acc)); + assert_eq!(result, expected_sum); + } + + #[test] + fn test_montgomery_9_fq_sum_of_products() { + let mut rng = OsRng; + let mut acc = [0u64; 9]; + let mut expected_sum = Fq::ZERO; + + for _ in 0..100 { + let a = Fq::random(&mut rng); + let b = Fq::random(&mut rng); + + expected_sum += a * b; + + let product = mul_4_by_4(&a.0, &b.0); + let mut carry = 0u128; + for i in 0..8 { + let sum = (acc[i] as u128) + (product[i] as u128) + carry; + acc[i] = sum as u64; + carry = sum >> 64; + } + acc[8] = acc[8].wrapping_add(carry as u64); + } + + let result = Fq(montgomery_reduce_9_fq(&acc)); + assert_eq!(result, expected_sum); + } +} diff --git a/src/small_field/delayed_reduction.rs b/src/small_field/delayed_reduction.rs new file mode 100644 index 0000000..6f26bad --- /dev/null +++ b/src/small_field/delayed_reduction.rs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! DelayedReduction trait for accumulating unreduced products. + +use super::SmallValueField; +use std::{ + fmt::Debug, + ops::{Add, AddAssign, Neg, Sub, SubAssign}, +}; + +/// Extension trait for delayed modular reduction operations. +/// +/// This trait extends `SmallValueField` with operations that accumulate +/// unreduced products in wide integers, reducing only at the end. +/// Used in hot paths like matrix-vector multiplication where many products +/// are summed together. 
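///
/// # Usage sketch (illustrative)
///
/// ```ignore
/// // `F` implements this trait; `fields` and `smalls` are hypothetical inputs.
/// let mut acc = F::UnreducedFieldInt::default();
/// for (field, small) in fields.iter().zip(smalls.iter()) {
///   F::unreduced_field_int_mul_add(&mut acc, field, *small);
/// }
/// let sum: F = F::reduce_field_int(&acc); // single reduction at the end
/// ```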
+/// +/// # Performance +/// +/// Delaying reduction saves ~1 field multiplication per accumulation: +/// - Without delayed reduction: N additions + N reductions +/// - With delayed reduction: N additions + 1 reduction +pub trait DelayedReduction: SmallValueField +where + SmallValue: Copy + + Clone + + Default + + Debug + + PartialEq + + Eq + + Add + + Sub + + Neg + + AddAssign + + SubAssign + + Send + + Sync, +{ + /// Unreduced accumulator for field × integer products. + /// - For i32/i64: SignedWideLimbs<6> (384 bits) + /// - For i64/i128: SignedWideLimbs<8> (512 bits) + /// + /// Sized to safely sum 2^(l/2) terms without overflow, assuming: + /// `field_bits + product_bits + (l/2) < 64*N` + /// (N = limb count for this accumulator, 64 bits per limb). + type UnreducedFieldInt: Copy + + Clone + + Default + + Debug + + AddAssign + + Send + + Sync + + num_traits::Zero; + + /// Unreduced accumulator for field × field products (9 limbs, 576 bits). + /// Used to delay modular reduction when summing many F × F products. + /// The value is in 2R-scaled Montgomery form, reduced via Montgomery REDC. + type UnreducedFieldField: Copy + + Clone + + Default + + Debug + + AddAssign + + Send + + Sync + + num_traits::Zero; + + /// Multiply field element by signed integer and add to unreduced accumulator. + /// acc += field × intermediate (keeps result in unreduced form, handles sign internally) + fn unreduced_field_int_mul_add( + acc: &mut Self::UnreducedFieldInt, + field: &Self, + small: Self::IntermediateSmallValue, + ); + + /// Multiply two field elements and add to unreduced accumulator. + /// acc += field_a × field_b (keeps result in 2R-scaled unreduced form) + fn unreduced_field_field_mul_add( + acc: &mut Self::UnreducedFieldField, + field_a: &Self, + field_b: &Self, + ); + + /// Batch 4 independent field×int multiply-accumulates for ILP optimization. + /// Default implementation calls single version 4 times. + #[inline(always)] + fn unreduced_field_int_mul_add_batch4( + accs: [&mut Self::UnreducedFieldInt; 4], + field: &Self, + smalls: [Self::IntermediateSmallValue; 4], + ) { + let [acc0, acc1, acc2, acc3] = accs; + Self::unreduced_field_int_mul_add(acc0, field, smalls[0]); + Self::unreduced_field_int_mul_add(acc1, field, smalls[1]); + Self::unreduced_field_int_mul_add(acc2, field, smalls[2]); + Self::unreduced_field_int_mul_add(acc3, field, smalls[3]); + } + + /// Batch 4 independent field×field multiply-accumulates for ILP optimization. + /// Default implementation calls single version 4 times. + #[inline(always)] + fn unreduced_field_field_mul_add_batch4( + accs: [&mut Self::UnreducedFieldField; 4], + a: [&Self; 4], + b: [&Self; 4], + ) { + let [acc0, acc1, acc2, acc3] = accs; + Self::unreduced_field_field_mul_add(acc0, a[0], b[0]); + Self::unreduced_field_field_mul_add(acc1, a[1], b[1]); + Self::unreduced_field_field_mul_add(acc2, a[2], b[2]); + Self::unreduced_field_field_mul_add(acc3, a[3], b[3]); + } + + /// Reduce an unreduced field×integer accumulator to a field element. + fn reduce_field_int(acc: &Self::UnreducedFieldInt) -> Self; + + /// Reduce an unreduced field×field accumulator to a field element. + fn reduce_field_field(acc: &Self::UnreducedFieldField) -> Self; +} diff --git a/src/small_field/impls.rs b/src/small_field/impls.rs new file mode 100644 index 0000000..bca5aeb --- /dev/null +++ b/src/small_field/impls.rs @@ -0,0 +1,1436 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. 
+// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! SmallValueField and DelayedReduction implementations for Fp, Fq, and BN254 Fr. + +use super::{ + DelayedReduction, SmallValueField, barrett, i64_to_field, i128_to_field, + limbs::{SignedWideLimbs, SubMagResult, WideLimbs, mac, mul_4_by_2_ext, mul_4_by_4_ext, sub_mag}, +}; +use ff::PrimeField; +use halo2curves::bn256::Fr as Bn254Fr; + +// ============================================================================ +// Helper function for try_field_to_small +// ============================================================================ + +/// Helper for try_field_to_small: attempts to convert a field element to i32. +/// +/// Returns Some(v) if the field element represents a small integer in [-2^31, 2^31-1]. +fn try_field_to_small_impl(val: &F) -> Option { + let repr = val.to_repr(); + let bytes = repr.as_ref(); + + // Check if value fits in positive i32 + let high_zero = bytes[4..].iter().all(|&b| b == 0); + if high_zero { + let val_u32 = u32::from_le_bytes(bytes[..4].try_into().unwrap()); + if val_u32 <= i32::MAX as u32 { + return Some(val_u32 as i32); + } + } + + // Check if negation fits in i32 (value is negative) + let neg_val = val.neg(); + let neg_repr = neg_val.to_repr(); + let neg_bytes = neg_repr.as_ref(); + let neg_high_zero = neg_bytes[4..].iter().all(|&b| b == 0); + if neg_high_zero { + let neg_u32 = u32::from_le_bytes(neg_bytes[..4].try_into().unwrap()); + if neg_u32 > 0 && neg_u32 <= (i32::MAX as u32) + 1 { + return Some(-(neg_u32 as i64) as i32); + } + } + + None +} + +// ============================================================================ +// SmallValueField for Fp +// ============================================================================ + +impl SmallValueField for halo2curves::pasta::Fp { + type IntermediateSmallValue = i64; + + #[inline] + fn ss_mul(a: i32, b: i32) -> i64 { + (a as i64) * (b as i64) + } + + #[inline] + fn sl_mul(small: i32, large: &Self) -> Self { + barrett::mul_by_i64(large, small as i64) + } + + #[inline] + fn isl_mul(small: i64, large: &Self) -> Self { + barrett::mul_by_i64(large, small) + } + + #[inline] + fn small_to_field(val: i32) -> Self { + if val >= 0 { + Self::from(val as u64) + } else { + -Self::from((-val) as u64) + } + } + + #[inline] + fn intermediate_to_field(val: i64) -> Self { + i64_to_field(val) + } + + fn try_field_to_small(val: &Self) -> Option { + try_field_to_small_impl(val) + } +} + +// ============================================================================ +// DelayedReduction for Fp +// ============================================================================ + +impl DelayedReduction for halo2curves::pasta::Fp { + type UnreducedFieldInt = SignedWideLimbs<6>; + type UnreducedFieldField = WideLimbs<9>; + + #[inline(always)] + fn unreduced_field_int_mul_add(acc: &mut Self::UnreducedFieldInt, field: &Self, small: i64) { + // Handle sign: accumulate into pos or neg based on sign of small + let (target, mag) = if small >= 0 { + (&mut acc.pos, small as u64) + } else { + (&mut acc.neg, (-small) as u64) + }; + // Fused multiply-accumulate: no intermediate array + let a = &field.0; + let (r0, c) = mac(target.0[0], a[0], mag, 0); + let (r1, c) = mac(target.0[1], a[1], mag, c); + let (r2, c) = mac(target.0[2], a[2], mag, c); + let (r3, c) = mac(target.0[3], a[3], mag, c); + // Propagate carry without multiply (just add) + let (r4, of) = target.0[4].overflowing_add(c); + 
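    // Write back the updated limbs; the carry out of limb 4 is folded into limb 5.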
target.0[0] = r0; + target.0[1] = r1; + target.0[2] = r2; + target.0[3] = r3; + target.0[4] = r4; + target.0[5] = target.0[5].wrapping_add(of as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add( + acc: &mut Self::UnreducedFieldField, + field_a: &Self, + field_b: &Self, + ) { + // Compute field_a × field_b as 8 limbs and add to accumulator + let product = mul_4_by_4_ext(&field_a.0, &field_b.0); + let mut carry = 0u128; + for (acc_limb, &prod_limb) in acc.0.iter_mut().take(8).zip(product.iter()) { + let sum = (*acc_limb as u128) + (prod_limb as u128) + carry; + *acc_limb = sum as u64; + carry = sum >> 64; + } + acc.0[8] = acc.0[8].wrapping_add(carry as u64); + } + + #[inline(always)] + fn unreduced_field_int_mul_add_batch4( + accs: [&mut Self::UnreducedFieldInt; 4], + field: &Self, + smalls: [i64; 4], + ) { + // Batched ILP version: interleave 4 independent carry chains + batch_unreduced_field_int_mul_add_x4_fp(accs, field, smalls); + } + + #[inline(always)] + fn unreduced_field_field_mul_add_batch4( + accs: [&mut Self::UnreducedFieldField; 4], + a: [&Self; 4], + b: [&Self; 4], + ) { + batch_unreduced_field_field_mul_add_x4( + accs, + [&a[0].0, &a[1].0, &a[2].0, &a[3].0], + [&b[0].0, &b[1].0, &b[2].0, &b[3].0], + ); + } + + #[inline(always)] + fn reduce_field_int(acc: &Self::UnreducedFieldInt) -> Self { + // Subtract in limb space first, then reduce once (saves one Barrett reduction) + match sub_mag::<6>(&acc.pos.0, &acc.neg.0) { + SubMagResult::Positive(mag) => Self(barrett::barrett_reduce_6_fp(&mag)), + SubMagResult::Negative(mag) => -Self(barrett::barrett_reduce_6_fp(&mag)), + } + } + + #[inline(always)] + fn reduce_field_field(acc: &Self::UnreducedFieldField) -> Self { + Self(barrett::montgomery_reduce_9_fp(&acc.0)) + } +} + +// ============================================================================ +// SmallValueField for Fp +// ============================================================================ + +impl SmallValueField for halo2curves::pasta::Fp { + type IntermediateSmallValue = i128; + + #[inline] + fn ss_mul(a: i64, b: i64) -> i128 { + (a as i128) * (b as i128) + } + + #[inline] + fn sl_mul(small: i64, large: &Self) -> Self { + barrett::mul_by_i64(large, small) + } + + #[inline] + fn isl_mul(small: i128, large: &Self) -> Self { + if small == 0 { + return Self::zero(); + } + let (is_neg, mag) = if small >= 0 { + (false, small as u128) + } else { + (true, (-small) as u128) + }; + // mul_4_by_2_ext produces 6 limbs, use barrett_reduce_6 directly (no padding) + let product = mul_4_by_2_ext(&large.0, mag); + let result = Self(barrett::barrett_reduce_6_fp(&product)); + if is_neg { -result } else { result } + } + + #[inline] + fn small_to_field(val: i64) -> Self { + i64_to_field(val) + } + + #[inline] + fn intermediate_to_field(val: i128) -> Self { + i128_to_field(val) + } + + fn try_field_to_small(val: &Self) -> Option { + let repr = val.to_repr(); + let bytes = repr.as_ref(); + + // Check if value fits in positive i64 + let high_zero = bytes[8..].iter().all(|&b| b == 0); + if high_zero { + let val_u64 = u64::from_le_bytes(bytes[..8].try_into().unwrap()); + if val_u64 <= i64::MAX as u64 { + return Some(val_u64 as i64); + } + } + + // Check if negation fits in i64 + let neg_val = val.neg(); + let neg_repr = neg_val.to_repr(); + let neg_bytes = neg_repr.as_ref(); + let neg_high_zero = neg_bytes[8..].iter().all(|&b| b == 0); + if neg_high_zero { + let neg_u64 = u64::from_le_bytes(neg_bytes[..8].try_into().unwrap()); + if neg_u64 > 0 && neg_u64 <= (i64::MAX as 
u64) + 1 { + return Some(-(neg_u64 as i128) as i64); + } + } + + None + } +} + +// ============================================================================ +// DelayedReduction for Fp +// ============================================================================ + +impl DelayedReduction for halo2curves::pasta::Fp { + type UnreducedFieldInt = SignedWideLimbs<8>; + type UnreducedFieldField = WideLimbs<9>; + + #[inline(always)] + fn unreduced_field_int_mul_add(acc: &mut Self::UnreducedFieldInt, field: &Self, small: i128) { + let (target, mag) = if small >= 0 { + (&mut acc.pos, small as u128) + } else { + (&mut acc.neg, (-small) as u128) + }; + // Fused 4×2 multiply-accumulate: two passes at different offsets + let a = &field.0; + let b_lo = mag as u64; + let b_hi = (mag >> 64) as u64; + + // Pass 1: multiply by b_lo at offset 0 + let (r0, c) = mac(target.0[0], a[0], b_lo, 0); + let (r1, c) = mac(target.0[1], a[1], b_lo, c); + let (r2, c) = mac(target.0[2], a[2], b_lo, c); + let (r3, c) = mac(target.0[3], a[3], b_lo, c); + // Propagate carry without multiply (just add) + let (r4, of1) = target.0[4].overflowing_add(c); + let c1 = of1 as u64; + target.0[0] = r0; + + // Pass 2: multiply by b_hi at offset 1 (add to r1..r5) + let (r1, c) = mac(r1, a[0], b_hi, 0); + let (r2, c) = mac(r2, a[1], b_hi, c); + let (r3, c) = mac(r3, a[2], b_hi, c); + let (r4, c) = mac(r4, a[3], b_hi, c); + // Add both carries (c from pass 2, c1 from pass 1) into position 5 + let (r5, c) = mac(target.0[5], c1, 1, c); + target.0[1] = r1; + target.0[2] = r2; + target.0[3] = r3; + target.0[4] = r4; + target.0[5] = r5; + // Propagate final carry through remaining limbs (just add) + let (r6, of) = target.0[6].overflowing_add(c); + target.0[6] = r6; + target.0[7] = target.0[7].wrapping_add(of as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add( + acc: &mut Self::UnreducedFieldField, + field_a: &Self, + field_b: &Self, + ) { + let product = mul_4_by_4_ext(&field_a.0, &field_b.0); + let mut carry = 0u128; + for (acc_limb, &prod_limb) in acc.0.iter_mut().take(8).zip(product.iter()) { + let sum = (*acc_limb as u128) + (prod_limb as u128) + carry; + *acc_limb = sum as u64; + carry = sum >> 64; + } + acc.0[8] = acc.0[8].wrapping_add(carry as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add_batch4( + accs: [&mut Self::UnreducedFieldField; 4], + a: [&Self; 4], + b: [&Self; 4], + ) { + batch_unreduced_field_field_mul_add_x4( + accs, + [&a[0].0, &a[1].0, &a[2].0, &a[3].0], + [&b[0].0, &b[1].0, &b[2].0, &b[3].0], + ); + } + + #[inline(always)] + fn reduce_field_int(acc: &Self::UnreducedFieldInt) -> Self { + // Subtract in limb space first, then reduce once (saves one Barrett reduction) + match sub_mag::<8>(&acc.pos.0, &acc.neg.0) { + SubMagResult::Positive(mag) => Self(barrett::barrett_reduce_8_fp(&mag)), + SubMagResult::Negative(mag) => -Self(barrett::barrett_reduce_8_fp(&mag)), + } + } + + #[inline(always)] + fn reduce_field_field(acc: &Self::UnreducedFieldField) -> Self { + Self(barrett::montgomery_reduce_9_fp(&acc.0)) + } +} + +// ============================================================================ +// SmallValueField for Fq +// ============================================================================ + +impl SmallValueField for halo2curves::pasta::Fq { + type IntermediateSmallValue = i64; + + #[inline] + fn ss_mul(a: i32, b: i32) -> i64 { + (a as i64) * (b as i64) + } + + #[inline] + fn sl_mul(small: i32, large: &Self) -> Self { + barrett::mul_by_i64(large, small as i64) + } + + 
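  // `isl_mul`: multiply an i64 intermediate value (e.g. an `ss_mul` product of two
  // i32s) into the field element via Barrett reduction.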
#[inline] + fn isl_mul(small: i64, large: &Self) -> Self { + barrett::mul_by_i64(large, small) + } + + #[inline] + fn small_to_field(val: i32) -> Self { + if val >= 0 { + Self::from(val as u64) + } else { + -Self::from((-val) as u64) + } + } + + #[inline] + fn intermediate_to_field(val: i64) -> Self { + i64_to_field(val) + } + + fn try_field_to_small(val: &Self) -> Option { + try_field_to_small_impl(val) + } +} + +// ============================================================================ +// DelayedReduction for Fq +// ============================================================================ + +impl DelayedReduction for halo2curves::pasta::Fq { + type UnreducedFieldInt = SignedWideLimbs<6>; + type UnreducedFieldField = WideLimbs<9>; + + #[inline(always)] + fn unreduced_field_int_mul_add(acc: &mut Self::UnreducedFieldInt, field: &Self, small: i64) { + // Handle sign: accumulate into pos or neg based on sign of small + let (target, mag) = if small >= 0 { + (&mut acc.pos, small as u64) + } else { + (&mut acc.neg, (-small) as u64) + }; + // Fused multiply-accumulate: no intermediate array + let a = &field.0; + let (r0, c) = mac(target.0[0], a[0], mag, 0); + let (r1, c) = mac(target.0[1], a[1], mag, c); + let (r2, c) = mac(target.0[2], a[2], mag, c); + let (r3, c) = mac(target.0[3], a[3], mag, c); + // Propagate carry without multiply (just add) + let (r4, of) = target.0[4].overflowing_add(c); + target.0[0] = r0; + target.0[1] = r1; + target.0[2] = r2; + target.0[3] = r3; + target.0[4] = r4; + target.0[5] = target.0[5].wrapping_add(of as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add( + acc: &mut Self::UnreducedFieldField, + field_a: &Self, + field_b: &Self, + ) { + let product = mul_4_by_4_ext(&field_a.0, &field_b.0); + let mut carry = 0u128; + for (acc_limb, &prod_limb) in acc.0.iter_mut().take(8).zip(product.iter()) { + let sum = (*acc_limb as u128) + (prod_limb as u128) + carry; + *acc_limb = sum as u64; + carry = sum >> 64; + } + acc.0[8] = acc.0[8].wrapping_add(carry as u64); + } + + #[inline(always)] + fn unreduced_field_int_mul_add_batch4( + accs: [&mut Self::UnreducedFieldInt; 4], + field: &Self, + smalls: [i64; 4], + ) { + // Batched ILP version: interleave 4 independent carry chains + batch_unreduced_field_int_mul_add_x4_fq(accs, field, smalls); + } + + #[inline(always)] + fn unreduced_field_field_mul_add_batch4( + accs: [&mut Self::UnreducedFieldField; 4], + a: [&Self; 4], + b: [&Self; 4], + ) { + batch_unreduced_field_field_mul_add_x4( + accs, + [&a[0].0, &a[1].0, &a[2].0, &a[3].0], + [&b[0].0, &b[1].0, &b[2].0, &b[3].0], + ); + } + + #[inline(always)] + fn reduce_field_int(acc: &Self::UnreducedFieldInt) -> Self { + // Subtract in limb space first, then reduce once (saves one Barrett reduction) + match sub_mag::<6>(&acc.pos.0, &acc.neg.0) { + SubMagResult::Positive(mag) => Self(barrett::barrett_reduce_6_fq(&mag)), + SubMagResult::Negative(mag) => -Self(barrett::barrett_reduce_6_fq(&mag)), + } + } + + #[inline(always)] + fn reduce_field_field(acc: &Self::UnreducedFieldField) -> Self { + Self(barrett::montgomery_reduce_9_fq(&acc.0)) + } +} + +// ============================================================================ +// SmallValueField for Fq +// ============================================================================ + +impl SmallValueField for halo2curves::pasta::Fq { + type IntermediateSmallValue = i128; + + #[inline] + fn ss_mul(a: i64, b: i64) -> i128 { + (a as i128) * (b as i128) + } + + #[inline] + fn sl_mul(small: i64, large: &Self) -> 
Self { + barrett::mul_by_i64(large, small) + } + + #[inline] + fn isl_mul(small: i128, large: &Self) -> Self { + if small == 0 { + return Self::zero(); + } + let (is_neg, mag) = if small >= 0 { + (false, small as u128) + } else { + (true, (-small) as u128) + }; + // mul_4_by_2_ext produces 6 limbs, use barrett_reduce_6 directly (no padding) + let product = mul_4_by_2_ext(&large.0, mag); + let result = Self(barrett::barrett_reduce_6_fq(&product)); + if is_neg { -result } else { result } + } + + #[inline] + fn small_to_field(val: i64) -> Self { + i64_to_field(val) + } + + #[inline] + fn intermediate_to_field(val: i128) -> Self { + i128_to_field(val) + } + + fn try_field_to_small(val: &Self) -> Option { + let repr = val.to_repr(); + let bytes = repr.as_ref(); + + let high_zero = bytes[8..].iter().all(|&b| b == 0); + if high_zero { + let val_u64 = u64::from_le_bytes(bytes[..8].try_into().unwrap()); + if val_u64 <= i64::MAX as u64 { + return Some(val_u64 as i64); + } + } + + let neg_val = val.neg(); + let neg_repr = neg_val.to_repr(); + let neg_bytes = neg_repr.as_ref(); + let neg_high_zero = neg_bytes[8..].iter().all(|&b| b == 0); + if neg_high_zero { + let neg_u64 = u64::from_le_bytes(neg_bytes[..8].try_into().unwrap()); + if neg_u64 > 0 && neg_u64 <= (i64::MAX as u64) + 1 { + return Some(-(neg_u64 as i128) as i64); + } + } + + None + } +} + +// ============================================================================ +// DelayedReduction for Fq +// ============================================================================ + +impl DelayedReduction for halo2curves::pasta::Fq { + type UnreducedFieldInt = SignedWideLimbs<8>; + type UnreducedFieldField = WideLimbs<9>; + + #[inline(always)] + fn unreduced_field_int_mul_add(acc: &mut Self::UnreducedFieldInt, field: &Self, small: i128) { + let (target, mag) = if small >= 0 { + (&mut acc.pos, small as u128) + } else { + (&mut acc.neg, (-small) as u128) + }; + // Fused 4×2 multiply-accumulate: two passes at different offsets + let a = &field.0; + let b_lo = mag as u64; + let b_hi = (mag >> 64) as u64; + + // Pass 1: multiply by b_lo at offset 0 + let (r0, c) = mac(target.0[0], a[0], b_lo, 0); + let (r1, c) = mac(target.0[1], a[1], b_lo, c); + let (r2, c) = mac(target.0[2], a[2], b_lo, c); + let (r3, c) = mac(target.0[3], a[3], b_lo, c); + // Propagate carry without multiply (just add) + let (r4, of1) = target.0[4].overflowing_add(c); + let c1 = of1 as u64; + target.0[0] = r0; + + // Pass 2: multiply by b_hi at offset 1 (add to r1..r5) + let (r1, c) = mac(r1, a[0], b_hi, 0); + let (r2, c) = mac(r2, a[1], b_hi, c); + let (r3, c) = mac(r3, a[2], b_hi, c); + let (r4, c) = mac(r4, a[3], b_hi, c); + // Add both carries (c1 from pass 1, c from pass 2) into position 5 + let (r5, c) = mac(target.0[5], c1, 1, c); + target.0[1] = r1; + target.0[2] = r2; + target.0[3] = r3; + target.0[4] = r4; + target.0[5] = r5; + // Propagate final carry through remaining limbs (just add) + let (r6, of) = target.0[6].overflowing_add(c); + target.0[6] = r6; + target.0[7] = target.0[7].wrapping_add(of as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add( + acc: &mut Self::UnreducedFieldField, + field_a: &Self, + field_b: &Self, + ) { + let product = mul_4_by_4_ext(&field_a.0, &field_b.0); + let mut carry = 0u128; + for (acc_limb, &prod_limb) in acc.0.iter_mut().take(8).zip(product.iter()) { + let sum = (*acc_limb as u128) + (prod_limb as u128) + carry; + *acc_limb = sum as u64; + carry = sum >> 64; + } + acc.0[8] = acc.0[8].wrapping_add(carry as u64); + } + 
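+  /// Illustrative sketch of the batched accumulate-then-reduce pattern (crate
+  /// paths and the `i128` trait parameter are assumed from this impl): the four
+  /// accumulators advance in lock-step and each one reduces to the same value
+  /// that four separate `unreduced_field_field_mul_add` calls would produce.
+  ///
+  /// ```text
+  /// use ff::Field;
+  /// use rand_core::OsRng;
+  /// use spartan2::small_field::{DelayedReduction, WideLimbs};
+  /// type Fq = halo2curves::pasta::Fq;
+  ///
+  /// let a: [Fq; 4] = core::array::from_fn(|_| Fq::random(OsRng));
+  /// let b: [Fq; 4] = core::array::from_fn(|_| Fq::random(OsRng));
+  /// let mut accs: [WideLimbs<9>; 4] = [Default::default(); 4];
+  /// let [a0, a1, a2, a3] = &mut accs;
+  /// <Fq as DelayedReduction<i128>>::unreduced_field_field_mul_add_batch4(
+  ///   [a0, a1, a2, a3],
+  ///   [&a[0], &a[1], &a[2], &a[3]],
+  ///   [&b[0], &b[1], &b[2], &b[3]],
+  /// );
+  /// for i in 0..4 {
+  ///   // one Montgomery reduction per accumulator, after all products are added
+  ///   assert_eq!(<Fq as DelayedReduction<i128>>::reduce_field_field(&accs[i]), a[i] * b[i]);
+  /// }
+  /// ```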
+ #[inline(always)] + fn unreduced_field_field_mul_add_batch4( + accs: [&mut Self::UnreducedFieldField; 4], + a: [&Self; 4], + b: [&Self; 4], + ) { + batch_unreduced_field_field_mul_add_x4( + accs, + [&a[0].0, &a[1].0, &a[2].0, &a[3].0], + [&b[0].0, &b[1].0, &b[2].0, &b[3].0], + ); + } + + #[inline(always)] + fn reduce_field_int(acc: &Self::UnreducedFieldInt) -> Self { + // Subtract in limb space first, then reduce once (saves one Barrett reduction) + match sub_mag::<8>(&acc.pos.0, &acc.neg.0) { + SubMagResult::Positive(mag) => Self(barrett::barrett_reduce_8_fq(&mag)), + SubMagResult::Negative(mag) => -Self(barrett::barrett_reduce_8_fq(&mag)), + } + } + + #[inline(always)] + fn reduce_field_field(acc: &Self::UnreducedFieldField) -> Self { + Self(barrett::montgomery_reduce_9_fq(&acc.0)) + } +} + +// ============================================================================ +// SmallValueField for BN254 Fr +// ============================================================================ + +impl SmallValueField for Bn254Fr { + type IntermediateSmallValue = i128; + + #[inline] + fn ss_mul(a: i64, b: i64) -> i128 { + (a as i128) * (b as i128) + } + + #[inline] + fn sl_mul(small: i64, large: &Self) -> Self { + barrett::mul_by_i64(large, small) + } + + #[inline] + fn isl_mul(small: i128, large: &Self) -> Self { + if small == 0 { + return Self::zero(); + } + let (is_neg, mag) = if small >= 0 { + (false, small as u128) + } else { + (true, (-small) as u128) + }; + let product = mul_4_by_2_ext(&large.0, mag); + let result = Self(barrett::barrett_reduce_6_bn254_fr(&product)); + if is_neg { -result } else { result } + } + + #[inline] + fn small_to_field(val: i64) -> Self { + i64_to_field(val) + } + + #[inline] + fn intermediate_to_field(val: i128) -> Self { + i128_to_field(val) + } + + fn try_field_to_small(val: &Self) -> Option { + let repr = val.to_repr(); + let bytes = repr.as_ref(); + + let high_zero = bytes[8..].iter().all(|&b| b == 0); + if high_zero { + let val_u64 = u64::from_le_bytes(bytes[..8].try_into().unwrap()); + if val_u64 <= i64::MAX as u64 { + return Some(val_u64 as i64); + } + } + + let neg_val = val.neg(); + let neg_repr = neg_val.to_repr(); + let neg_bytes = neg_repr.as_ref(); + let neg_high_zero = neg_bytes[8..].iter().all(|&b| b == 0); + if neg_high_zero { + let neg_u64 = u64::from_le_bytes(neg_bytes[..8].try_into().unwrap()); + if neg_u64 > 0 && neg_u64 <= (i64::MAX as u64) + 1 { + return Some(-(neg_u64 as i128) as i64); + } + } + + None + } +} + +// ============================================================================ +// DelayedReduction for BN254 Fr +// ============================================================================ + +impl DelayedReduction for Bn254Fr { + type UnreducedFieldInt = SignedWideLimbs<8>; + type UnreducedFieldField = WideLimbs<9>; + + #[inline(always)] + fn unreduced_field_int_mul_add(acc: &mut Self::UnreducedFieldInt, field: &Self, small: i128) { + let (target, mag) = if small >= 0 { + (&mut acc.pos, small as u128) + } else { + (&mut acc.neg, (-small) as u128) + }; + let a = &field.0; + let b_lo = mag as u64; + let b_hi = (mag >> 64) as u64; + + // Pass 1: multiply by b_lo at offset 0 + let (r0, c) = mac(target.0[0], a[0], b_lo, 0); + let (r1, c) = mac(target.0[1], a[1], b_lo, c); + let (r2, c) = mac(target.0[2], a[2], b_lo, c); + let (r3, c) = mac(target.0[3], a[3], b_lo, c); + let (r4, of1) = target.0[4].overflowing_add(c); + let c1 = of1 as u64; + target.0[0] = r0; + + // Pass 2: multiply by b_hi at offset 1 + let (r1, c) = mac(r1, a[0], 
b_hi, 0); + let (r2, c) = mac(r2, a[1], b_hi, c); + let (r3, c) = mac(r3, a[2], b_hi, c); + let (r4, c) = mac(r4, a[3], b_hi, c); + let (r5, c) = mac(target.0[5], c1, 1, c); + target.0[1] = r1; + target.0[2] = r2; + target.0[3] = r3; + target.0[4] = r4; + target.0[5] = r5; + let (r6, of) = target.0[6].overflowing_add(c); + target.0[6] = r6; + target.0[7] = target.0[7].wrapping_add(of as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add( + acc: &mut Self::UnreducedFieldField, + field_a: &Self, + field_b: &Self, + ) { + let product = mul_4_by_4_ext(&field_a.0, &field_b.0); + let mut carry = 0u128; + for (acc_limb, &prod_limb) in acc.0.iter_mut().take(8).zip(product.iter()) { + let sum = (*acc_limb as u128) + (prod_limb as u128) + carry; + *acc_limb = sum as u64; + carry = sum >> 64; + } + acc.0[8] = acc.0[8].wrapping_add(carry as u64); + } + + #[inline(always)] + fn unreduced_field_field_mul_add_batch4( + accs: [&mut Self::UnreducedFieldField; 4], + a: [&Self; 4], + b: [&Self; 4], + ) { + batch_unreduced_field_field_mul_add_x4( + accs, + [&a[0].0, &a[1].0, &a[2].0, &a[3].0], + [&b[0].0, &b[1].0, &b[2].0, &b[3].0], + ); + } + + #[inline(always)] + fn reduce_field_int(acc: &Self::UnreducedFieldInt) -> Self { + match sub_mag::<8>(&acc.pos.0, &acc.neg.0) { + SubMagResult::Positive(mag) => Self(barrett::barrett_reduce_8_bn254_fr(&mag)), + SubMagResult::Negative(mag) => -Self(barrett::barrett_reduce_8_bn254_fr(&mag)), + } + } + + #[inline(always)] + fn reduce_field_field(acc: &Self::UnreducedFieldField) -> Self { + Self(barrett::montgomery_reduce_9_bn254_fr(&acc.0)) + } +} + +// ============================================================================ +// Batched ILP Operations +// ============================================================================ + +/// Internal helper for batched MAC operations (works for any 4-limb field). +/// +/// Interleaves 4 independent carry chains for better ILP on AArch64. 
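+///
+/// Rough scheduling picture (illustrative only; `tK` is the K-th accumulator,
+/// `mK` its 64-bit magnitude, and every row is one limb position):
+///
+/// ```text
+/// limb 0:  mac(t0[0], a[0], m0)  mac(t1[0], a[0], m1)  mac(t2[0], a[0], m2)  mac(t3[0], a[0], m3)
+/// limb 1:  mac(t0[1], a[1], m0)  mac(t1[1], a[1], m1)  mac(t2[1], a[1], m2)  mac(t3[1], a[1], m3)
+/// limb 2:  ...                   ...                   ...                   ...
+/// limb 3:  ...                   ...                   ...                   ...
+/// carry :  t0[4] += c0           t1[4] += c1           t2[4] += c2           t3[4] += c3
+/// ```
+///
+/// The four `mac`s in a row share no data, so their `mul`/`umulh` pairs can issue
+/// back to back instead of stalling on a single serial carry chain.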
+#[inline(always)] +fn batch_mac_4limb_x4(targets: [&mut WideLimbs<6>; 4], a: &[u64; 4], mags: [u64; 4]) { + // Limb 0: 4 independent macs (ILP - CPU can overlap these) + let (r0_0, c0) = mac(targets[0].0[0], a[0], mags[0], 0); + let (r1_0, c1) = mac(targets[1].0[0], a[0], mags[1], 0); + let (r2_0, c2) = mac(targets[2].0[0], a[0], mags[2], 0); + let (r3_0, c3) = mac(targets[3].0[0], a[0], mags[3], 0); + + // Limb 1: 4 independent macs (each uses its own carry) + let (r0_1, c0) = mac(targets[0].0[1], a[1], mags[0], c0); + let (r1_1, c1) = mac(targets[1].0[1], a[1], mags[1], c1); + let (r2_1, c2) = mac(targets[2].0[1], a[1], mags[2], c2); + let (r3_1, c3) = mac(targets[3].0[1], a[1], mags[3], c3); + + // Limb 2: 4 independent macs + let (r0_2, c0) = mac(targets[0].0[2], a[2], mags[0], c0); + let (r1_2, c1) = mac(targets[1].0[2], a[2], mags[1], c1); + let (r2_2, c2) = mac(targets[2].0[2], a[2], mags[2], c2); + let (r3_2, c3) = mac(targets[3].0[2], a[2], mags[3], c3); + + // Limb 3: 4 independent macs + let (r0_3, c0) = mac(targets[0].0[3], a[3], mags[0], c0); + let (r1_3, c1) = mac(targets[1].0[3], a[3], mags[1], c1); + let (r2_3, c2) = mac(targets[2].0[3], a[3], mags[2], c2); + let (r3_3, c3) = mac(targets[3].0[3], a[3], mags[3], c3); + + // Final carry propagation for all 4 (still independent) + let (r0_4, of0) = targets[0].0[4].overflowing_add(c0); + let (r1_4, of1) = targets[1].0[4].overflowing_add(c1); + let (r2_4, of2) = targets[2].0[4].overflowing_add(c2); + let (r3_4, of3) = targets[3].0[4].overflowing_add(c3); + + // Store results for accumulator 0 + targets[0].0[0] = r0_0; + targets[0].0[1] = r0_1; + targets[0].0[2] = r0_2; + targets[0].0[3] = r0_3; + targets[0].0[4] = r0_4; + targets[0].0[5] = targets[0].0[5].wrapping_add(of0 as u64); + + // Store results for accumulator 1 + targets[1].0[0] = r1_0; + targets[1].0[1] = r1_1; + targets[1].0[2] = r1_2; + targets[1].0[3] = r1_3; + targets[1].0[4] = r1_4; + targets[1].0[5] = targets[1].0[5].wrapping_add(of1 as u64); + + // Store results for accumulator 2 + targets[2].0[0] = r2_0; + targets[2].0[1] = r2_1; + targets[2].0[2] = r2_2; + targets[2].0[3] = r2_3; + targets[2].0[4] = r2_4; + targets[2].0[5] = targets[2].0[5].wrapping_add(of2 as u64); + + // Store results for accumulator 3 + targets[3].0[0] = r3_0; + targets[3].0[1] = r3_1; + targets[3].0[2] = r3_2; + targets[3].0[3] = r3_3; + targets[3].0[4] = r3_4; + targets[3].0[5] = targets[3].0[5].wrapping_add(of3 as u64); +} + +/// Batch 4 independent field×int multiply-accumulates for instruction-level parallelism. +/// +/// On AArch64 (M1/M2), this allows the CPU to overlap mul/umulh latencies across +/// 4 independent carry chains, significantly improving throughput compared to +/// processing one accumulation at a time. 
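+///
+/// Behaviourally this is a sketch-level equivalent of four sequential calls
+/// (the `i64` trait parameter is assumed from the surrounding impls):
+///
+/// ```text
+/// // accs: [&mut SignedWideLimbs<6>; 4], field: &Fq, smalls: [i64; 4]
+/// for k in 0..4 {
+///   <Fq as DelayedReduction<i64>>::unreduced_field_int_mul_add(&mut *accs[k], field, smalls[k]);
+/// }
+/// ```
+///
+/// except that the four carry chains are interleaved limb by limb.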
+/// +/// # Safety +/// - All 4 accumulators must be valid and non-overlapping +/// - This is for Fq (4 limbs, 256-bit field) +#[inline(always)] +pub fn batch_unreduced_field_int_mul_add_x4_fq( + accs: [&mut SignedWideLimbs<6>; 4], + field: &halo2curves::pasta::Fq, + smalls: [i64; 4], +) { + // Destructure to get 4 independent mutable references (satisfies borrow checker) + let [acc0, acc1, acc2, acc3] = accs; + + // Prepare targets and magnitudes for each of the 4 operations + let (target0, mag0) = if smalls[0] >= 0 { + (&mut acc0.pos, smalls[0] as u64) + } else { + (&mut acc0.neg, (-smalls[0]) as u64) + }; + let (target1, mag1) = if smalls[1] >= 0 { + (&mut acc1.pos, smalls[1] as u64) + } else { + (&mut acc1.neg, (-smalls[1]) as u64) + }; + let (target2, mag2) = if smalls[2] >= 0 { + (&mut acc2.pos, smalls[2] as u64) + } else { + (&mut acc2.neg, (-smalls[2]) as u64) + }; + let (target3, mag3) = if smalls[3] >= 0 { + (&mut acc3.pos, smalls[3] as u64) + } else { + (&mut acc3.neg, (-smalls[3]) as u64) + }; + + let a = &field.0; + + // Limb 0: 4 independent macs (ILP - CPU can overlap these) + let (r0_0, c0) = mac(target0.0[0], a[0], mag0, 0); + let (r1_0, c1) = mac(target1.0[0], a[0], mag1, 0); + let (r2_0, c2) = mac(target2.0[0], a[0], mag2, 0); + let (r3_0, c3) = mac(target3.0[0], a[0], mag3, 0); + + // Limb 1: 4 independent macs (each uses its own carry) + let (r0_1, c0) = mac(target0.0[1], a[1], mag0, c0); + let (r1_1, c1) = mac(target1.0[1], a[1], mag1, c1); + let (r2_1, c2) = mac(target2.0[1], a[1], mag2, c2); + let (r3_1, c3) = mac(target3.0[1], a[1], mag3, c3); + + // Limb 2: 4 independent macs + let (r0_2, c0) = mac(target0.0[2], a[2], mag0, c0); + let (r1_2, c1) = mac(target1.0[2], a[2], mag1, c1); + let (r2_2, c2) = mac(target2.0[2], a[2], mag2, c2); + let (r3_2, c3) = mac(target3.0[2], a[2], mag3, c3); + + // Limb 3: 4 independent macs + let (r0_3, c0) = mac(target0.0[3], a[3], mag0, c0); + let (r1_3, c1) = mac(target1.0[3], a[3], mag1, c1); + let (r2_3, c2) = mac(target2.0[3], a[3], mag2, c2); + let (r3_3, c3) = mac(target3.0[3], a[3], mag3, c3); + + // Final carry propagation for all 4 (still independent) + let (r0_4, of0) = target0.0[4].overflowing_add(c0); + let (r1_4, of1) = target1.0[4].overflowing_add(c1); + let (r2_4, of2) = target2.0[4].overflowing_add(c2); + let (r3_4, of3) = target3.0[4].overflowing_add(c3); + + // Store results for accumulator 0 + target0.0[0] = r0_0; + target0.0[1] = r0_1; + target0.0[2] = r0_2; + target0.0[3] = r0_3; + target0.0[4] = r0_4; + target0.0[5] = target0.0[5].wrapping_add(of0 as u64); + + // Store results for accumulator 1 + target1.0[0] = r1_0; + target1.0[1] = r1_1; + target1.0[2] = r1_2; + target1.0[3] = r1_3; + target1.0[4] = r1_4; + target1.0[5] = target1.0[5].wrapping_add(of1 as u64); + + // Store results for accumulator 2 + target2.0[0] = r2_0; + target2.0[1] = r2_1; + target2.0[2] = r2_2; + target2.0[3] = r2_3; + target2.0[4] = r2_4; + target2.0[5] = target2.0[5].wrapping_add(of2 as u64); + + // Store results for accumulator 3 + target3.0[0] = r3_0; + target3.0[1] = r3_1; + target3.0[2] = r3_2; + target3.0[3] = r3_3; + target3.0[4] = r3_4; + target3.0[5] = target3.0[5].wrapping_add(of3 as u64); +} + +/// Batch 4 independent field×int multiply-accumulates for Fp. 
+#[inline(always)] +pub fn batch_unreduced_field_int_mul_add_x4_fp( + accs: [&mut SignedWideLimbs<6>; 4], + field: &halo2curves::pasta::Fp, + smalls: [i64; 4], +) { + // Destructure to get 4 independent mutable references (satisfies borrow checker) + let [acc0, acc1, acc2, acc3] = accs; + + // Prepare targets and magnitudes for each of the 4 operations + let (target0, mag0) = if smalls[0] >= 0 { + (&mut acc0.pos, smalls[0] as u64) + } else { + (&mut acc0.neg, (-smalls[0]) as u64) + }; + let (target1, mag1) = if smalls[1] >= 0 { + (&mut acc1.pos, smalls[1] as u64) + } else { + (&mut acc1.neg, (-smalls[1]) as u64) + }; + let (target2, mag2) = if smalls[2] >= 0 { + (&mut acc2.pos, smalls[2] as u64) + } else { + (&mut acc2.neg, (-smalls[2]) as u64) + }; + let (target3, mag3) = if smalls[3] >= 0 { + (&mut acc3.pos, smalls[3] as u64) + } else { + (&mut acc3.neg, (-smalls[3]) as u64) + }; + + batch_mac_4limb_x4( + [target0, target1, target2, target3], + &field.0, + [mag0, mag1, mag2, mag3], + ); +} + +/// Batch 4 independent field×field multiply-accumulates for ILP optimization. +/// +/// Computes 4 products in parallel (mul_4_by_4_ext) and adds them to 4 separate +/// 9-limb accumulators with interleaved carry propagation for better ILP on AArch64. +#[inline(always)] +pub fn batch_unreduced_field_field_mul_add_x4( + accs: [&mut WideLimbs<9>; 4], + a: [&[u64; 4]; 4], + b: [&[u64; 4]; 4], +) { + let [acc0, acc1, acc2, acc3] = accs; + + // Compute 4 products (ILP: these can be computed independently) + let prod0 = mul_4_by_4_ext(a[0], b[0]); + let prod1 = mul_4_by_4_ext(a[1], b[1]); + let prod2 = mul_4_by_4_ext(a[2], b[2]); + let prod3 = mul_4_by_4_ext(a[3], b[3]); + + // Add products to accumulators with interleaved carry propagation + // Limb 0 + let sum0 = (acc0.0[0] as u128) + (prod0[0] as u128); + let sum1 = (acc1.0[0] as u128) + (prod1[0] as u128); + let sum2 = (acc2.0[0] as u128) + (prod2[0] as u128); + let sum3 = (acc3.0[0] as u128) + (prod3[0] as u128); + acc0.0[0] = sum0 as u64; + acc1.0[0] = sum1 as u64; + acc2.0[0] = sum2 as u64; + acc3.0[0] = sum3 as u64; + let (mut c0, mut c1, mut c2, mut c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 1 + let sum0 = (acc0.0[1] as u128) + (prod0[1] as u128) + c0; + let sum1 = (acc1.0[1] as u128) + (prod1[1] as u128) + c1; + let sum2 = (acc2.0[1] as u128) + (prod2[1] as u128) + c2; + let sum3 = (acc3.0[1] as u128) + (prod3[1] as u128) + c3; + acc0.0[1] = sum0 as u64; + acc1.0[1] = sum1 as u64; + acc2.0[1] = sum2 as u64; + acc3.0[1] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 2 + let sum0 = (acc0.0[2] as u128) + (prod0[2] as u128) + c0; + let sum1 = (acc1.0[2] as u128) + (prod1[2] as u128) + c1; + let sum2 = (acc2.0[2] as u128) + (prod2[2] as u128) + c2; + let sum3 = (acc3.0[2] as u128) + (prod3[2] as u128) + c3; + acc0.0[2] = sum0 as u64; + acc1.0[2] = sum1 as u64; + acc2.0[2] = sum2 as u64; + acc3.0[2] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 3 + let sum0 = (acc0.0[3] as u128) + (prod0[3] as u128) + c0; + let sum1 = (acc1.0[3] as u128) + (prod1[3] as u128) + c1; + let sum2 = (acc2.0[3] as u128) + (prod2[3] as u128) + c2; + let sum3 = (acc3.0[3] as u128) + (prod3[3] as u128) + c3; + acc0.0[3] = sum0 as u64; + acc1.0[3] = sum1 as u64; + acc2.0[3] = sum2 as u64; + acc3.0[3] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 4 + let sum0 = (acc0.0[4] as u128) + (prod0[4] as u128) 
+ c0; + let sum1 = (acc1.0[4] as u128) + (prod1[4] as u128) + c1; + let sum2 = (acc2.0[4] as u128) + (prod2[4] as u128) + c2; + let sum3 = (acc3.0[4] as u128) + (prod3[4] as u128) + c3; + acc0.0[4] = sum0 as u64; + acc1.0[4] = sum1 as u64; + acc2.0[4] = sum2 as u64; + acc3.0[4] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 5 + let sum0 = (acc0.0[5] as u128) + (prod0[5] as u128) + c0; + let sum1 = (acc1.0[5] as u128) + (prod1[5] as u128) + c1; + let sum2 = (acc2.0[5] as u128) + (prod2[5] as u128) + c2; + let sum3 = (acc3.0[5] as u128) + (prod3[5] as u128) + c3; + acc0.0[5] = sum0 as u64; + acc1.0[5] = sum1 as u64; + acc2.0[5] = sum2 as u64; + acc3.0[5] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 6 + let sum0 = (acc0.0[6] as u128) + (prod0[6] as u128) + c0; + let sum1 = (acc1.0[6] as u128) + (prod1[6] as u128) + c1; + let sum2 = (acc2.0[6] as u128) + (prod2[6] as u128) + c2; + let sum3 = (acc3.0[6] as u128) + (prod3[6] as u128) + c3; + acc0.0[6] = sum0 as u64; + acc1.0[6] = sum1 as u64; + acc2.0[6] = sum2 as u64; + acc3.0[6] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 7 + let sum0 = (acc0.0[7] as u128) + (prod0[7] as u128) + c0; + let sum1 = (acc1.0[7] as u128) + (prod1[7] as u128) + c1; + let sum2 = (acc2.0[7] as u128) + (prod2[7] as u128) + c2; + let sum3 = (acc3.0[7] as u128) + (prod3[7] as u128) + c3; + acc0.0[7] = sum0 as u64; + acc1.0[7] = sum1 as u64; + acc2.0[7] = sum2 as u64; + acc3.0[7] = sum3 as u64; + (c0, c1, c2, c3) = (sum0 >> 64, sum1 >> 64, sum2 >> 64, sum3 >> 64); + + // Limb 8 (final carry) + acc0.0[8] = acc0.0[8].wrapping_add(c0 as u64); + acc1.0[8] = acc1.0[8].wrapping_add(c1 as u64); + acc2.0[8] = acc2.0[8].wrapping_add(c2 as u64); + acc3.0[8] = acc3.0[8].wrapping_add(c3 as u64); +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use crate::{polys::multilinear::MultilinearPolynomial, provider::pasta::pallas}; + + type Scalar = pallas::Scalar; + + #[test] + fn test_small_value_field_arithmetic() { + let a: i32 = 10; + let b: i32 = 3; + + assert_eq!(a + b, 13); + assert_eq!(a - b, 7); + assert_eq!(-a, -10); + assert_eq!(>::ss_mul(a, b), 30i64); + assert_eq!(a * 5, 50); // ss_mul_const is just native multiplication + } + + #[test] + fn test_small_value_field_negative() { + let a: i32 = -5; + let b: i32 = 3; + + assert_eq!(a + b, -2); + assert_eq!(a - b, -8); + assert_eq!(>::ss_mul(a, b), -15i64); + + let field_a = >::small_to_field(a); + assert_eq!(field_a, -Scalar::from(5u64)); + } + + #[test] + fn test_i64_to_field() { + let pos: Scalar = i64_to_field(100); + assert_eq!(pos, Scalar::from(100u64)); + + let neg: Scalar = i64_to_field(-50); + assert_eq!(neg, -Scalar::from(50u64)); + } + + #[test] + fn test_try_field_to_i64_roundtrip() { + use super::super::try_field_to_i64; + + // Test positive values + for val in [0i64, 1, 100, 1_000_000, i64::MAX / 2, i64::MAX] { + let field: Scalar = i64_to_field(val); + let back = try_field_to_i64(&field).expect("should fit"); + assert_eq!(back, val, "roundtrip failed for {}", val); + } + + // Test negative values + for val in [-1i64, -100, -1_000_000, i64::MIN / 2, i64::MIN + 1] { + let field: Scalar = i64_to_field(val); + let back = try_field_to_i64(&field).expect("should fit"); + assert_eq!(back, val, 
"roundtrip failed for {}", val); + } + + // Test i64::MIN separately (edge case) + let field: Scalar = i64_to_field(i64::MIN); + let back = try_field_to_i64(&field).expect("should fit"); + assert_eq!(back, i64::MIN, "roundtrip failed for i64::MIN"); + + // Test values that don't fit in i64 + let too_large = Scalar::from(u64::MAX) + Scalar::from(1u64); + assert!(try_field_to_i64(&too_large).is_none()); + } + + #[test] + fn test_small_multilinear_polynomial() { + let poly = MultilinearPolynomial::new(vec![1i32, 2, 3, 4]); + + assert_eq!(poly.num_vars(), 2); + assert_eq!(poly.Z.len(), 4); + assert_eq!(poly[0], 1); + assert_eq!(poly[3], 4); + } + + #[test] + fn test_to_field_conversion() { + let evals: Vec = vec![1, -2, 3, -4]; + let small_poly = MultilinearPolynomial::new(evals); + let field_poly: MultilinearPolynomial = small_poly.to_field(); + + assert_eq!(field_poly.Z[0], Scalar::from(1u64)); + assert_eq!(field_poly.Z[1], -Scalar::from(2u64)); + assert_eq!(field_poly.Z[2], Scalar::from(3u64)); + assert_eq!(field_poly.Z[3], -Scalar::from(4u64)); + } + + #[test] + fn test_try_field_to_small_roundtrip() { + assert_eq!( + >::try_field_to_small(&Scalar::from(42u64)), + Some(42) + ); + assert_eq!( + >::try_field_to_small(&-Scalar::from(100u64)), + Some(-100) + ); + assert_eq!( + >::try_field_to_small(&Scalar::from(u64::MAX)), + None + ); + } + + #[test] + fn test_isl_mul() { + use ff::Field; + use rand_core::OsRng; + + let large = Scalar::random(&mut OsRng); + let small: i64 = 12345; + + let result = >::isl_mul(small, &large); + let expected = i64_to_field::(small) * large; + + assert_eq!(result, expected); + } + + #[test] + fn test_sl_mul() { + use ff::Field; + use rand_core::OsRng; + + let large = Scalar::random(&mut OsRng); + let small: i32 = -999; + + let result = >::sl_mul(small, &large); + let expected = >::small_to_field(small) * large; + + assert_eq!(result, expected); + } + + #[test] + fn test_overflow_bounds() { + let typical_witness = 1i32 << 20; + let extension_factor = 27i32; + let after_extension = typical_witness * extension_factor; + + let prod = >::ss_mul(after_extension, after_extension); + assert!(prod > 0); + assert!(prod < (1i64 << 55)); + } + + #[test] + fn test_ss_sign_combinations() { + assert_eq!(>::ss_mul(100, 200), 20000i64); + assert_eq!( + >::ss_mul(-100, -200), + 20000i64 + ); + assert_eq!( + >::ss_mul(100, -200), + -20000i64 + ); + assert_eq!( + >::ss_mul(-100, 200), + -20000i64 + ); + } + + #[test] + fn test_ss_zero_edge_cases() { + let zero = 0i32; + let val = 12345i32; + + assert_eq!(>::ss_mul(zero, val), 0i64); + assert_eq!(>::ss_mul(val, zero), 0i64); + } + + #[test] + fn test_isl_with_random() { + use ff::Field; + use rand_core::{OsRng, RngCore}; + + let mut rng = OsRng; + for _ in 0..100 { + let large = Scalar::random(&mut rng); + let small = (rng.next_u64() % (i64::MAX as u64)) as i64; + let small = if rng.next_u32().is_multiple_of(2) { + small + } else { + -small + }; + + let result = >::isl_mul(small, &large); + let expected = i64_to_field::(small) * large; + + assert_eq!(result, expected); + } + } + + #[test] + fn test_fp_small_value_field() { + use halo2curves::pasta::Fp; + + let a: i32 = 42; + let b: i32 = -10; + + assert_eq!(a + b, 32); + assert_eq!(>::ss_mul(a, b), -420i64); + assert_eq!( + >::small_to_field(a), + Fp::from(42u64) + ); + } + + #[test] + fn test_unreduced_field_int_mul_add() { + use crate::small_field::limbs::SignedWideLimbs; + use ff::Field; + use rand_core::{OsRng, RngCore}; + + let mut rng = OsRng; + let mut acc: SignedWideLimbs<6> 
= Default::default(); + let mut expected = Scalar::ZERO; + + // Sum 100 field × i64 products (mix of positive and negative) + for i in 0..100 { + let field = Scalar::random(&mut rng); + let small_u = rng.next_u64() >> 32; // Keep smaller to avoid extreme overflow + // Alternate signs for variety + let small: i64 = if i % 2 == 0 { + small_u as i64 + } else { + -(small_u as i64) + }; + + >::unreduced_field_int_mul_add(&mut acc, &field, small); + expected += field * i64_to_field::(small); + } + + let result = >::reduce_field_int(&acc); + assert_eq!(result, expected); + } + + #[test] + fn test_unreduced_field_field_mul_add() { + use crate::small_field::limbs::WideLimbs; + use ff::Field; + use rand_core::OsRng; + + let mut rng = OsRng; + let mut acc: WideLimbs<9> = Default::default(); + let mut expected = Scalar::ZERO; + + // Sum 100 field × field products + for _ in 0..100 { + let a = Scalar::random(&mut rng); + let b = Scalar::random(&mut rng); + + >::unreduced_field_field_mul_add(&mut acc, &a, &b); + expected += a * b; + } + + let result = >::reduce_field_field(&acc); + assert_eq!(result, expected); + } + + #[test] + fn test_unreduced_field_int_many_products() { + use crate::small_field::limbs::SignedWideLimbs; + use ff::Field; + use rand_core::{OsRng, RngCore}; + + let mut rng = OsRng; + let mut acc: SignedWideLimbs<6> = Default::default(); + let mut expected = Scalar::ZERO; + + // Stress test: sum 2000 products (mix of positive and negative) + for i in 0..2000 { + let field = Scalar::random(&mut rng); + let small_u = rng.next_u64(); + // Alternate signs for variety + let small: i64 = if i % 3 == 0 { + (small_u >> 1) as i64 // positive, shifted to fit in i64 + } else if i % 3 == 1 { + -((small_u >> 1) as i64) // negative + } else { + 0 // occasionally zero + }; + + >::unreduced_field_int_mul_add(&mut acc, &field, small); + expected += field * i64_to_field::(small); + } + + let result = >::reduce_field_int(&acc); + assert_eq!(result, expected); + } + + #[test] + fn test_unreduced_field_field_many_products() { + use crate::small_field::limbs::WideLimbs; + use ff::Field; + use rand_core::OsRng; + + let mut rng = OsRng; + let mut acc: WideLimbs<9> = Default::default(); + let mut expected = Scalar::ZERO; + + // Stress test: sum 1000 products + for _ in 0..1000 { + let a = Scalar::random(&mut rng); + let b = Scalar::random(&mut rng); + + >::unreduced_field_field_mul_add(&mut acc, &a, &b); + expected += a * b; + } + + let result = >::reduce_field_field(&acc); + assert_eq!(result, expected); + } +} diff --git a/src/small_field/limbs.rs b/src/small_field/limbs.rs new file mode 100644 index 0000000..789a31f --- /dev/null +++ b/src/small_field/limbs.rs @@ -0,0 +1,506 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Wide integers and limb operations for delayed modular reduction. +//! +//! This module provides: +//! - [`WideLimbs`]: Stack-allocated wide integers for accumulating unreduced products +//! - [`SignedWideLimbs`]: Signed variant for accumulating signed products +//! - Limb arithmetic operations (multiply, subtract, compare) +//! +//! # Why not `num-bigint`? +//! +//! We define our own types rather than using `num-bigint::BigInt` because: +//! - `BigInt` is heap-allocated (uses `Vec`) +//! - We need stack-allocated fixed-size arrays for hot-path performance +//! 
- We want the `Copy` trait for cheap pass-by-value in tight loops
+
+use num_traits::Zero;
+use std::ops::{Add, AddAssign};
+
+// ============================================================================
+// WideLimbs - Stack-allocated wide integer
+// ============================================================================
+
+/// Stack-allocated wide integer with N 64-bit limbs.
+///
+/// Limbs are stored in little-endian order: `limbs[0]` is the least significant.
+///
+/// # Type Parameters
+///
+/// - `N`: Number of 64-bit limbs. Common values:
+///   - `N=6` (384 bits): For `UnreducedFieldInt` (field × integer products)
+///   - `N=9` (576 bits): For `UnreducedFieldField` (field × field products)
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct WideLimbs<const N: usize>(pub [u64; N]);
+
+impl<const N: usize> Default for WideLimbs<N> {
+  fn default() -> Self {
+    Self([0u64; N])
+  }
+}
+
+impl<const N: usize> AddAssign for WideLimbs<N> {
+  /// Wide addition with carry propagation.
+  #[inline]
+  fn add_assign(&mut self, other: Self) {
+    let mut carry = 0u64;
+    for i in 0..N {
+      let (sum, c1) = self.0[i].overflowing_add(other.0[i]);
+      let (sum, c2) = sum.overflowing_add(carry);
+      self.0[i] = sum;
+      carry = (c1 as u64) + (c2 as u64);
+    }
+    // Note: We intentionally don't check for overflow here.
+    // The caller is responsible for ensuring the sum doesn't exceed N limbs.
+  }
+}
+
+impl<const N: usize> AddAssign<&Self> for WideLimbs<N> {
+  /// Wide addition with carry propagation (reference variant).
+  #[inline]
+  fn add_assign(&mut self, other: &Self) {
+    let mut carry = 0u64;
+    for i in 0..N {
+      let (sum, c1) = self.0[i].overflowing_add(other.0[i]);
+      let (sum, c2) = sum.overflowing_add(carry);
+      self.0[i] = sum;
+      carry = (c1 as u64) + (c2 as u64);
+    }
+  }
+}
+
+impl<const N: usize> Add for WideLimbs<N> {
+  type Output = Self;
+
+  #[inline]
+  fn add(mut self, other: Self) -> Self {
+    self += other;
+    self
+  }
+}
+
+impl<const N: usize> Add<&Self> for WideLimbs<N> {
+  type Output = Self;
+
+  #[inline]
+  fn add(mut self, other: &Self) -> Self {
+    self += other;
+    self
+  }
+}
+
+impl<const N: usize> Zero for WideLimbs<N> {
+  #[inline]
+  fn zero() -> Self {
+    Self([0u64; N])
+  }
+
+  #[inline]
+  fn is_zero(&self) -> bool {
+    self.0.iter().all(|&x| x == 0)
+  }
+}
+
+// ============================================================================
+// SignedWideLimbs - for accumulating signed products
+// ============================================================================
+
+/// Pair of wide integers for accumulating signed products.
+///
+/// Since `WideLimbs` only supports unsigned addition, we track positive and
+/// negative contributions separately, then subtract at the end.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct SignedWideLimbs<const N: usize> {
+  /// Accumulator for positive contributions
+  pub pos: WideLimbs<N>,
+  /// Accumulator for negative contributions (stored as positive magnitude)
+  pub neg: WideLimbs<N>,
+}
+
+impl<const N: usize> Default for SignedWideLimbs<N> {
+  fn default() -> Self {
+    Self {
+      pos: WideLimbs::zero(),
+      neg: WideLimbs::zero(),
+    }
+  }
+}
+
+impl<const N: usize> AddAssign for SignedWideLimbs<N> {
+  /// Merge two signed accumulators by adding their respective parts.
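+  ///
+  /// Both operands keep positive and negative magnitudes separate, so the merge
+  /// is just two unsigned wide additions; the signed value is only recovered
+  /// later via `sub_mag(&acc.pos.0, &acc.neg.0)` followed by a single reduction.
+  ///
+  /// ```text
+  /// // sketch: (pos = 5, neg = 2) += (pos = 1, neg = 4)  ==>  (pos = 6, neg = 6)
+  /// // both sides denote the same signed value: (5 - 2) + (1 - 4) = 0 = 6 - 6
+  /// ```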
+  #[inline]
+  fn add_assign(&mut self, other: Self) {
+    self.pos += other.pos;
+    self.neg += other.neg;
+  }
+}
+
+impl<const N: usize> AddAssign<&Self> for SignedWideLimbs<N> {
+  #[inline]
+  fn add_assign(&mut self, other: &Self) {
+    self.pos += &other.pos;
+    self.neg += &other.neg;
+  }
+}
+
+impl<const N: usize> Add for SignedWideLimbs<N> {
+  type Output = Self;
+
+  #[inline]
+  fn add(mut self, other: Self) -> Self {
+    self.pos += other.pos;
+    self.neg += other.neg;
+    self
+  }
+}
+
+impl<const N: usize> Zero for SignedWideLimbs<N> {
+  #[inline]
+  fn zero() -> Self {
+    Self {
+      pos: WideLimbs::zero(),
+      neg: WideLimbs::zero(),
+    }
+  }
+
+  #[inline]
+  fn is_zero(&self) -> bool {
+    self.pos.is_zero() && self.neg.is_zero()
+  }
+}
+
+// ============================================================================
+// SubMagResult - magnitude subtraction result
+// ============================================================================
+
+/// Result of magnitude subtraction |a - b|.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum SubMagResult<const N: usize> {
+  /// a >= b, contains a - b
+  Positive([u64; N]),
+  /// a < b, contains b - a
+  Negative([u64; N]),
+}
+
+/// Compute |a - b| and return the magnitude with sign information.
+///
+/// Used to reduce two wide integers to one before Barrett reduction,
+/// saving one expensive reduction operation.
+#[inline(always)]
+pub fn sub_mag<const N: usize>(a: &[u64; N], b: &[u64; N]) -> SubMagResult<N> {
+  let mut out = [0u64; N];
+  let mut borrow = 0u64;
+  for i in 0..N {
+    let (d1, b1) = a[i].overflowing_sub(b[i]);
+    let (d2, b2) = d1.overflowing_sub(borrow);
+    out[i] = d2;
+    borrow = (b1 as u64) + (b2 as u64);
+  }
+  if borrow == 0 {
+    SubMagResult::Positive(out)
+  } else {
+    // a < b, compute b - a instead
+    let mut out2 = [0u64; N];
+    borrow = 0;
+    for i in 0..N {
+      let (d1, b1) = b[i].overflowing_sub(a[i]);
+      let (d2, b2) = d1.overflowing_sub(borrow);
+      out2[i] = d2;
+      borrow = (b1 as u64) + (b2 as u64);
+    }
+    SubMagResult::Negative(out2)
+  }
+}
+
+// ============================================================================
+// Limb multiplication operations
+// ============================================================================
+
+/// Multiply-accumulate: acc + a * b + carry → (low, high)
+///
+/// Fused operation that computes one limb of a multiply-accumulate in a single step,
+/// avoiding materialization of intermediate arrays.
+#[inline(always)]
+pub fn mac(acc: u64, a: u64, b: u64, carry: u64) -> (u64, u64) {
+  let prod = (a as u128) * (b as u128) + (acc as u128) + (carry as u128);
+  (prod as u64, (prod >> 64) as u64)
+}
+
+/// Multiply two 4-limb values, producing an 8-limb result.
+#[inline(always)]
+pub fn mul_4_by_4_ext(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] {
+  let mut result = [0u64; 8];
+  for i in 0..4 {
+    let mut carry = 0u128;
+    for j in 0..4 {
+      let prod = (a[i] as u128) * (b[j] as u128) + (result[i + j] as u128) + carry;
+      result[i + j] = prod as u64;
+      carry = prod >> 64;
+    }
+    result[i + 4] = carry as u64;
+  }
+  result
+}
+
+/// Multiply 4-limb field by 2-limb integer (u128), producing a 6-limb result.
+/// Used for i64/i128 small-value optimization where IntermediateSmallValue is i128.
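+///
+/// Limb layout sketch (descriptive only; this is exactly what the two passes below do):
+///
+/// ```text
+/// result[0..5] += a * b_lo   // 4x1 multiply at offset 0
+/// result[1..6] += a * b_hi   // 4x1 multiply at offset 1, i.e. shifted by 64 bits
+/// ```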
+#[inline(always)] +pub fn mul_4_by_2_ext(a: &[u64; 4], b: u128) -> [u64; 6] { + let b_lo = b as u64; + let b_hi = (b >> 64) as u64; + + // Multiply a by b_lo (4x1 -> 5 limbs) + let mut result = [0u64; 6]; + let mut carry = 0u128; + for i in 0..4 { + let prod = (a[i] as u128) * (b_lo as u128) + carry; + result[i] = prod as u64; + carry = prod >> 64; + } + result[4] = carry as u64; + + // Multiply a by b_hi and add at offset 1 (4x1 -> 5 limbs, shifted) + carry = 0u128; + for i in 0..4 { + let prod = (a[i] as u128) * (b_hi as u128) + (result[i + 1] as u128) + carry; + result[i + 1] = prod as u64; + carry = prod >> 64; + } + result[5] = carry as u64; + + result +} + +/// Multiply 4-limb by 1-limb, producing a 5-limb result. +#[inline(always)] +pub(super) fn mul_4_by_1(a: &[u64; 4], b: u64) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut carry = 0u128; + for i in 0..4 { + let prod = (a[i] as u128) * (b as u128) + carry; + result[i] = prod as u64; + carry = prod >> 64; + } + result[4] = carry as u64; + result +} + +/// Multiply 5-limb by 1-limb, producing a 5-limb result (overflow ignored). +#[inline(always)] +pub(super) fn mul_5_by_1(a: &[u64; 5], b: u64) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut carry = 0u128; + for i in 0..5 { + let prod = (a[i] as u128) * (b as u128) + carry; + result[i] = prod as u64; + carry = prod >> 64; + } + result +} + +// ============================================================================ +// Limb subtraction operations +// ============================================================================ + +/// Subtract two 5-limb values: a - b. +#[inline(always)] +pub(super) fn sub_5_5(a: &[u64; 5], b: &[u64; 5]) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut borrow = 0u64; + for i in 0..5 { + let (diff, b1) = a[i].overflowing_sub(b[i]); + let (diff2, b2) = diff.overflowing_sub(borrow); + result[i] = diff2; + borrow = (b1 as u64) + (b2 as u64); + } + result +} + +/// Subtract 4-limb from 5-limb: a - b. +#[inline(always)] +pub(super) fn sub_5_4(a: &[u64; 5], b: &[u64; 4]) -> [u64; 5] { + let mut result = [0u64; 5]; + let mut borrow = 0u64; + for i in 0..4 { + let (diff, b1) = a[i].overflowing_sub(b[i]); + let (diff2, b2) = diff.overflowing_sub(borrow); + result[i] = diff2; + borrow = (b1 as u64) + (b2 as u64); + } + let (diff, _) = a[4].overflowing_sub(borrow); + result[4] = diff; + result +} + +/// Subtract two 4-limb values: a - b. +#[inline(always)] +pub(super) fn sub_4_4(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] { + let mut result = [0u64; 4]; + let mut borrow = 0u64; + for i in 0..4 { + let (diff, b1) = a[i].overflowing_sub(b[i]); + let (diff2, b2) = diff.overflowing_sub(borrow); + result[i] = diff2; + borrow = (b1 as u64) + (b2 as u64); + } + result +} + +// ============================================================================ +// Limb comparison operations +// ============================================================================ + +/// Check if 5-limb value >= 4-limb value. +#[inline(always)] +pub(super) fn gte_5_4(a: &[u64; 5], b: &[u64; 4]) -> bool { + if a[4] > 0 { + return true; + } + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true +} + +/// Check if 4-limb value a >= 4-limb value b. 
+#[inline(always)] +pub(super) fn gte_4_4(a: &[u64; 4], b: &[u64; 4]) -> bool { + for i in (0..4).rev() { + if a[i] > b[i] { + return true; + } + if a[i] < b[i] { + return false; + } + } + true // equal +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_zero() { + let z: WideLimbs<6> = WideLimbs::zero(); + assert!(z.is_zero()); + assert_eq!(z.0, [0u64; 6]); + } + + #[test] + fn test_default_is_zero() { + let z: WideLimbs<6> = WideLimbs::default(); + assert!(z.is_zero()); + } + + #[test] + fn test_add_no_carry() { + let a: WideLimbs<4> = WideLimbs([1, 2, 3, 4]); + let b: WideLimbs<4> = WideLimbs([10, 20, 30, 40]); + let c = a + b; + assert_eq!(c.0, [11, 22, 33, 44]); + } + + #[test] + fn test_add_with_carry() { + // Test carry propagation + let a: WideLimbs<4> = WideLimbs([u64::MAX, 0, 0, 0]); + let b: WideLimbs<4> = WideLimbs([1, 0, 0, 0]); + let c = a + b; + assert_eq!(c.0, [0, 1, 0, 0]); + } + + #[test] + fn test_add_with_multi_carry() { + // Test carry propagation across multiple limbs + let a: WideLimbs<4> = WideLimbs([u64::MAX, u64::MAX, u64::MAX, 0]); + let b: WideLimbs<4> = WideLimbs([1, 0, 0, 0]); + let c = a + b; + assert_eq!(c.0, [0, 0, 0, 1]); + } + + #[test] + fn test_add_assign() { + let mut a: WideLimbs<4> = WideLimbs([1, 2, 3, 4]); + let b: WideLimbs<4> = WideLimbs([10, 20, 30, 40]); + a += b; + assert_eq!(a.0, [11, 22, 33, 44]); + } + + #[test] + fn test_add_assign_ref() { + let mut a: WideLimbs<4> = WideLimbs([1, 2, 3, 4]); + let b: WideLimbs<4> = WideLimbs([10, 20, 30, 40]); + a += &b; + assert_eq!(a.0, [11, 22, 33, 44]); + } + + #[test] + fn test_is_zero() { + let z: WideLimbs<4> = WideLimbs::zero(); + assert!(z.is_zero()); + + let nz: WideLimbs<4> = WideLimbs([0, 0, 1, 0]); + assert!(!nz.is_zero()); + } + + #[test] + fn test_different_sizes() { + // Test that WideLimbs works with different N values + let _a: WideLimbs<6> = WideLimbs::zero(); + let _b: WideLimbs<9> = WideLimbs::zero(); + + let x: WideLimbs<6> = WideLimbs([1, 2, 3, 4, 5, 6]); + let y: WideLimbs<6> = WideLimbs([6, 5, 4, 3, 2, 1]); + let z = x + y; + assert_eq!(z.0, [7, 7, 7, 7, 7, 7]); + } + + #[test] + fn test_mul_4_by_1() { + let a = [1u64, 2, 3, 4]; + let b = 10u64; + let result = mul_4_by_1(&a, b); + assert_eq!(result, [10, 20, 30, 40, 0]); + } + + #[test] + fn test_mul_4_by_1_with_carry() { + let a = [u64::MAX, 0, 0, 0]; + let b = 2u64; + let result = mul_4_by_1(&a, b); + assert_eq!(result, [u64::MAX - 1, 1, 0, 0, 0]); + } + + #[test] + fn test_sub_4_4() { + let a = [10u64, 20, 30, 40]; + let b = [1u64, 2, 3, 4]; + let result = sub_4_4(&a, &b); + assert_eq!(result, [9, 18, 27, 36]); + } + + #[test] + fn test_gte_4_4() { + let a = [10u64, 20, 30, 40]; + let b = [1u64, 2, 3, 4]; + assert!(gte_4_4(&a, &b)); + assert!(!gte_4_4(&b, &a)); + + // Equal case + assert!(gte_4_4(&a, &a)); + } +} diff --git a/src/small_field/mod.rs b/src/small_field/mod.rs new file mode 100644 index 0000000..e228e03 --- /dev/null +++ b/src/small_field/mod.rs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: MIT +// This file is part of the Spartan2 project. +// See the LICENSE file in the project root for full license information. +// Source repository: https://github.com/Microsoft/Spartan2 + +//! Small-value field operations for optimized sumcheck. +//! +//! 
This module provides types and operations for the small-value optimization
+//! described in "Speeding Up Sum-Check Proving". The key insight is distinguishing
+//! between three multiplication types:
+//!
+//! - **ss** (small × small): Native i32/i64 multiplication
+//! - **sl** (small × large): Barrett-optimized multiplication (~3× faster)
+//! - **ll** (large × large): Standard field multiplication
+//!
+//! For polynomial evaluations on the boolean hypercube (typically i32 values),
+//! we can perform many operations in native integers before converting to field.
+//!
+//! # Architecture
+//!
+//! ```text
+//! ┌──────────────────────────────────────────────────┐
+//! │              SmallValueField Trait               │
+//! │ (ss_mul, sl_mul, isl_mul, small_to_field, etc.)  │
+//! └──────────────────────────────────────────────────┘
+//!                          │
+//!                          ▼
+//! ┌──────────────────────────────────────────────────┐
+//! │           Barrett Reduction (internal)           │
+//! │      mul_by_i64 - ~9 base muls vs ~32 naive      │
+//! └──────────────────────────────────────────────────┘
+//! ```
+
+pub(crate) mod barrett;
+mod delayed_reduction;
+mod impls;
+pub(crate) mod limbs;
+mod small_value_field;
+
+pub use delayed_reduction::DelayedReduction;
+pub use limbs::{SignedWideLimbs, SubMagResult, WideLimbs, sub_mag};
+pub use small_value_field::SmallValueField;
+
+use ff::PrimeField;
+
+// ============================================================================
+// Helper Functions
+// ============================================================================
+
+/// Convert i64 to field element (handles negative values correctly).
+#[inline]
+pub fn i64_to_field<F: PrimeField>(val: i64) -> F {
+  if val >= 0 {
+    F::from(val as u64)
+  } else {
+    // Use wrapping_neg to handle i64::MIN correctly
+    -F::from(val.wrapping_neg() as u64)
+  }
+}
+
+/// Convert i128 to field element (handles negative values correctly).
+#[inline]
+pub fn i128_to_field<F: PrimeField>(val: i128) -> F {
+  if val >= 0 {
+    // Split into high and low u64 parts
+    let low = val as u64;
+    let high = (val >> 64) as u64;
+    if high == 0 {
+      F::from(low)
+    } else {
+      // result = low + high * 2^64
+      F::from(low) + F::from(high) * two_pow_64::<F>()
+    }
+  } else {
+    // Use wrapping_neg to handle i128::MIN correctly
+    let pos = val.wrapping_neg() as u128;
+    let low = pos as u64;
+    let high = (pos >> 64) as u64;
+    if high == 0 {
+      -F::from(low)
+    } else {
+      -(F::from(low) + F::from(high) * two_pow_64::<F>())
+    }
+  }
+}
+
+/// Try to convert a field element to i64.
+/// Returns None if the value doesn't fit in the i64 range.
+#[inline]
+pub fn try_field_to_i64<F: PrimeField>(val: &F) -> Option<i64> {
+  let repr = val.to_repr();
+  let bytes = repr.as_ref();
+
+  // Check if value fits in positive i64 (high bytes all zero)
+  let high_zero = bytes[8..].iter().all(|&b| b == 0);
+  if high_zero {
+    let val_u64 = u64::from_le_bytes(bytes[..8].try_into().unwrap());
+    if val_u64 <= i64::MAX as u64 {
+      return Some(val_u64 as i64);
+    }
+  }
+
+  // Check if negation fits in i64 (value is negative)
+  let neg_val = val.neg();
+  let neg_repr = neg_val.to_repr();
+  let neg_bytes = neg_repr.as_ref();
+  let neg_high_zero = neg_bytes[8..].iter().all(|&b| b == 0);
+  if neg_high_zero {
+    let neg_u64 = u64::from_le_bytes(neg_bytes[..8].try_into().unwrap());
+    if neg_u64 > 0 && neg_u64 <= (i64::MAX as u64) + 1 {
+      return Some(-(neg_u64 as i128) as i64);
+    }
+  }
+
+  None
+}
+
+/// Returns 2^64 as a field element, computed as (2^32)^2.
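+///
+/// Used by `i128_to_field` to lift the high half: for `v = high * 2^64 + low`,
+/// the field image is `F::from(low) + F::from(high) * two_pow_64::<F>()`.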
+#[inline]
+fn two_pow_64<F: PrimeField>() -> F {
+  // 2^64 = (2^32)^2
+  let two_32 = F::from(1u64 << 32);
+  two_32 * two_32
+}
diff --git a/src/small_field/small_value_field.rs b/src/small_field/small_value_field.rs
new file mode 100644
index 0000000..445249b
--- /dev/null
+++ b/src/small_field/small_value_field.rs
@@ -0,0 +1,78 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: MIT
+// This file is part of the Spartan2 project.
+// See the LICENSE file in the project root for full license information.
+// Source repository: https://github.com/Microsoft/Spartan2
+
+//! SmallValueField trait for small-value optimization.
+
+use ff::PrimeField;
+use std::{
+  fmt::Debug,
+  ops::{Add, AddAssign, Neg, Sub, SubAssign},
+};
+
+/// Trait for fields that support small-value optimization.
+///
+/// This trait defines operations for efficient arithmetic when polynomial
+/// evaluations fit in native integers. The key optimization is avoiding
+/// expensive field operations until absolutely necessary.
+///
+/// # Type Parameters
+/// - `SmallValue`: Native type for witness values (i32 or i64)
+///
+/// # Implementations
+/// - `Fp: SmallValueField<i32>` with `IntermediateSmallValue = i64`
+/// - `Fp: SmallValueField<i64>` with `IntermediateSmallValue = i128`
+///
+/// # Overflow Bounds (for D=2 Spartan with typical witness values)
+///
+/// | Step                      | Bound | Bits | Container |
+/// |---------------------------|-------|------|-----------|
+/// | Original witness values   | 2²⁰   | 20   | i32       |
+/// | After 3 extensions (D=2)  | 2²³   | 23   | i32 ✓     |
+/// | Product of two extended   | 2⁴⁶   | 46   | i64 ✓     |
+pub trait SmallValueField<SmallValue>: PrimeField
+where
+  SmallValue: Copy
+    + Clone
+    + Default
+    + Debug
+    + PartialEq
+    + Eq
+    + Add
+    + Sub
+    + Neg
+    + AddAssign
+    + SubAssign
+    + Send
+    + Sync,
+{
+  /// Intermediate type for products (i64 for i32 inputs, i128 for i64 inputs).
+  /// Used when multiplying two SmallValues together.
+  type IntermediateSmallValue: Copy + Clone + Default + Debug + PartialEq + Eq + Send + Sync;
+
+  // ===== Core Multiplications =====
+
+  /// ss: small × small → intermediate (i32 × i32 → i64, or i64 × i64 → i128)
+  fn ss_mul(a: SmallValue, b: SmallValue) -> Self::IntermediateSmallValue;
+
+  /// sl: small × large → large (small × field → field)
+  fn sl_mul(small: SmallValue, large: &Self) -> Self;
+
+  /// isl: intermediate × large → large (intermediate × field → field)
+  /// This is the key operation for accumulator building.
+  fn isl_mul(small: Self::IntermediateSmallValue, large: &Self) -> Self;
+
+  // ===== Conversions =====
+
+  /// Convert SmallValue to field element.
+  fn small_to_field(val: SmallValue) -> Self;
+
+  /// Convert IntermediateSmallValue to field element.
+  fn intermediate_to_field(val: Self::IntermediateSmallValue) -> Self;
+
+  /// Try to convert a field element to SmallValue.
+  /// Returns None if the value doesn't fit.
+  fn try_field_to_small(val: &Self) -> Option<SmallValue>;
+}
diff --git a/src/spartan.rs b/src/spartan.rs
index 7bbeb93..670229b 100644
--- a/src/spartan.rs
+++ b/src/spartan.rs
@@ -443,6 +443,70 @@ impl R1CSSNARKTrait for SpartanSNARK {
   }
 }
+impl<E: Engine> SpartanSNARK<E> {
+  /// Extract the Az, Bz, Cz polynomials and tau challenges from a circuit.
+  ///
+  /// This is useful for testing sumcheck methods with real circuit-derived data.
+  /// Returns `(Az, Bz, Cz, tau)` where Az, Bz, Cz are the matrix-vector products
+  /// and tau are the random challenges for the outer sum-check.
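+  ///
+  /// Usage sketch (illustrative; `circuit`, `pk`, and `prep` stand for whatever
+  /// the caller's setup produced, they are not defined by this function):
+  ///
+  /// ```text
+  /// let (az, bz, cz, tau) =
+  ///   SpartanSNARK::<E>::extract_outer_sumcheck_inputs(&pk, circuit, &prep)?;
+  /// // az, bz, cz are A·z, B·z, C·z over the full assignment
+  /// // z = (W, 1, public values, challenges), and tau holds one challenge
+  /// // per outer sum-check round, i.e. tau.len() == log2(num_cons).
+  /// ```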
+ pub fn extract_outer_sumcheck_inputs>( + pk: &SpartanProverKey, + circuit: C, + prep_snark: &SpartanPrepSNARK, + ) -> Result< + ( + Vec, + Vec, + Vec, + Vec, + ), + SpartanError, + > { + let mut prep_snark = prep_snark.clone(); + + let mut transcript = E::TE::new(b"SpartanSNARK"); + transcript.absorb(b"vk", &pk.vk_digest); + + let public_values = circuit + .public_values() + .map_err(|e| SpartanError::SynthesisError { + reason: format!("Circuit does not provide public IO: {e}"), + })?; + + transcript.absorb(b"public_values", &public_values.as_slice()); + + let (U, W) = SatisfyingAssignment::r1cs_instance_and_witness( + &mut prep_snark.ps, + &pk.S, + &pk.ck, + &circuit, + true, // is_small + &mut transcript, + )?; + + // compute the full satisfying assignment by concatenating W.W, 1, and U.X + let z = [ + W.W.clone(), + vec![E::Scalar::ONE], + U.public_values.clone(), + U.challenges.clone(), + ] + .concat(); + + let num_rounds_x = usize::try_from(pk.S.num_cons.ilog2()).unwrap(); + + // Generate tau challenges + let tau = (0..num_rounds_x) + .map(|_i| transcript.squeeze(b"t")) + .collect::, SpartanError>>()?; + + // Compute Az, Bz, Cz + let (Az, Bz, Cz) = pk.S.multiply_vec(&z)?; + + Ok((Az, Bz, Cz, tau)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/sumcheck.rs b/src/sumcheck.rs index 3f3819b..ec7102b 100644 --- a/src/sumcheck.rs +++ b/src/sumcheck.rs @@ -17,11 +17,16 @@ use crate::{ solver::SatisfyingAssignment, }, errors::SpartanError, + lagrange_accumulator::{ + AccumulateProduct, DelayedModularReductionEnabled, MatVecMLE, SPARTAN_T_DEGREE, + build_accumulators_spartan, derive_t1, + }, polys::{ multilinear::MultilinearPolynomial, univariate::{CompressedUniPoly, UniPoly}, }, r1cs::SplitMultiRoundR1CSShape, + small_field::{DelayedReduction, SmallValueField}, start_span, traits::{Engine, transcript::TranscriptEngineTrait}, zk::{NeutronNovaVerifierCircuit, SpartanVerifierCircuit}, @@ -71,11 +76,56 @@ where } } +/// Bind three polynomials to the same challenge in one pass. +/// More efficient than three separate bind calls - reduces Rayon dispatches +/// and uses serial fallback for small polynomials. +fn bind_three_polys_top( + poly_a: &mut MultilinearPolynomial, + poly_b: &mut MultilinearPolynomial, + poly_c: &mut MultilinearPolynomial, + r: &F, +) { + let n = poly_a.Z.len() / 2; + debug_assert_eq!(poly_b.Z.len() / 2, n); + debug_assert_eq!(poly_c.Z.len() / 2, n); + + let (a_lo, a_hi) = poly_a.Z.split_at_mut(n); + let (b_lo, b_hi) = poly_b.Z.split_at_mut(n); + let (c_lo, c_hi) = poly_c.Z.split_at_mut(n); + + if n >= PAR_THRESHOLD { + // Parallel path - one Rayon dispatch for all three + a_lo + .par_iter_mut() + .zip(a_hi.par_iter()) + .zip(b_lo.par_iter_mut()) + .zip(b_hi.par_iter()) + .zip(c_lo.par_iter_mut()) + .zip(c_hi.par_iter()) + .for_each(|(((((a_l, a_h), b_l), b_h), c_l), c_h)| { + *a_l += *r * (*a_h - *a_l); + *b_l += *r * (*b_h - *b_l); + *c_l += *r * (*c_h - *c_l); + }); + } else { + // Serial path - no Rayon overhead + for i in 0..n { + a_lo[i] += *r * (a_hi[i] - a_lo[i]); + b_lo[i] += *r * (b_hi[i] - b_lo[i]); + c_lo[i] += *r * (c_hi[i] - c_lo[i]); + } + } + + poly_a.Z.truncate(n); + poly_b.Z.truncate(n); + poly_c.Z.truncate(n); +} + /// A proof generated by the sum-check protocol. /// /// This struct contains the compressed univariate polynomials that constitute /// the prover's messages in each round of the sum-check protocol. 
-#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(bound = "")] pub struct SumcheckProof { compressed_polys: Vec>, @@ -203,7 +253,7 @@ impl SumcheckProof { /// * `claim` - The claimed sum over the hypercube /// * `num_rounds` - The number of variables/rounds in the sum-check /// * `poly_A` - First multilinear polynomial (mutable, will be bound during protocol) - /// * `poly_B` - Second multilinear polynomial (mutable, will be bound during protocol) + /// * `poly_B` - Second multilinear polynomial (mutable, will be bound during protocol) /// * `comb_func` - Function that combines evaluations of the two polynomials /// * `transcript` - The transcript for generating randomness /// @@ -278,7 +328,7 @@ impl SumcheckProof { /// /// # Arguments /// * `poly_A` - First multilinear polynomial - /// * `poly_B` - Second multilinear polynomial + /// * `poly_B` - Second multilinear polynomial /// * `poly_C` - Third multilinear polynomial /// * `poly_D` - Fourth multilinear polynomial /// * `comb_func` - Function that combines evaluations of the four polynomials @@ -358,7 +408,7 @@ impl SumcheckProof { /// * `pow_tau_left` - The left part of the power of tau /// * `pow_tau_right` - The right part of the power of tau /// * `poly_A` - First multilinear polynomial - /// * `poly_B` - Second multilinear polynomial + /// * `poly_B` - Second multilinear polynomial /// * `poly_C` - Third multilinear polynomial /// * `comb_func` - Function that combines evaluations of the four polynomials /// @@ -529,6 +579,292 @@ impl SumcheckProof { )) } + /// Prove poly_A * poly_B - poly_C using split-eq with delayed modular reduction. + /// + /// Same as `prove_cubic_with_three_inputs` but uses delayed modular reduction + /// in the eq polynomial evaluation to reduce Montgomery reductions from O(2^k) + /// to O(2^{k/2}) per round. This isolates the delayed reduction optimization + /// without small-value accumulator precomputation. + #[allow(dead_code)] + pub fn prove_cubic_with_three_inputs_split_eq_delayed( + claim: &E::Scalar, + taus: Vec, + poly_A: &mut MultilinearPolynomial, + poly_B: &mut MultilinearPolynomial, + poly_C: &mut MultilinearPolynomial, + transcript: &mut E::TE, + ) -> Result<(Self, Vec, Vec), SpartanError> + where + SmallValue: Copy + + Clone + + Default + + std::fmt::Debug + + PartialEq + + Eq + + std::ops::Add + + std::ops::Sub + + std::ops::Neg + + std::ops::AddAssign + + std::ops::SubAssign + + Send + + Sync, + E::Scalar: DelayedReduction, + { + let mut r: Vec = Vec::new(); + let mut polys: Vec> = Vec::new(); + let mut claim_per_round = *claim; + + let num_rounds = taus.len(); + + let mut eq_instance = eq_sumcheck::EqSumCheckInstance::::new(taus); + + for round in 0..num_rounds { + let (_round_span, round_t) = start_span!("sumcheck_round", round = round); + + let poly = { + let (_eval_span, eval_t) = start_span!("compute_eval_points"); + // Use delayed modular reduction version + let (eval_point_0, eval_point_2, eval_point_3) = eq_instance + .evaluation_points_cubic_with_three_inputs_delayed::( + round, poly_A, poly_B, poly_C, + ); + if eval_t.elapsed().as_millis() > 0 { + info!(elapsed_ms = %eval_t.elapsed().as_millis(), "compute_eval_points"); + } + + let evals = vec![ + eval_point_0, + claim_per_round - eval_point_0, + eval_point_2, + eval_point_3, + ]; + UniPoly::from_evals(&evals)? 
+ }; + + // append the prover's message to the transcript + transcript.absorb(b"p", &poly); + + //derive the verifier's challenge for the next round + let r_i = transcript.squeeze(b"c")?; + r.push(r_i); + polys.push(poly.compress()); + + // Set up next round + claim_per_round = poly.evaluate(&r_i); + + // bound all tables to the verifier's challenge + let (_bind_span, bind_t) = start_span!("bind_poly_vars"); + bind_three_polys_top(poly_A, poly_B, poly_C, &r_i); + eq_instance.bound(&r_i); + info!(elapsed_ms = %bind_t.elapsed().as_millis(), "bind_poly_vars"); + info!(elapsed_ms = %round_t.elapsed().as_millis(), round = round, "sumcheck_round"); + } + + Ok(( + SumcheckProof { + compressed_polys: polys, + }, + r, + vec![poly_A[0], poly_B[0], poly_C[0]], + )) + } + + /// Default number of small-value rounds (ℓ₀) for Algorithm 6. + /// Optimal value from paper analysis for Spartan with D=2. + const DEFAULT_SMALL_VALUE_ROUNDS: usize = 3; + + /// Prove poly_A * poly_B - poly_C using Algorithm 6 (EqPoly-SmallValueSC). + /// + /// This method combines small-value optimization (Algorithm 4) for the first ℓ₀ rounds + /// with eq-poly optimization (Algorithm 5) for the remaining rounds. + /// + /// # Arguments + /// * `claim` - The claimed sum + /// * `taus` - Random challenges for the eq polynomial + /// * `poly_A_small` - Small-value polynomial (evaluations must fit in small type) + /// * `poly_B_small` - Small-value polynomial (evaluations must fit in small type) + /// * `poly_A` - Field-element polynomial (same values as poly_A_small, for binding) + /// * `poly_B` - Field-element polynomial (same values as poly_B_small, for binding) + /// * `poly_C` - Field-element polynomial for the subtractive term + /// * `transcript` - The transcript for Fiat-Shamir + /// + /// # Type Parameters + /// * `P` - The small-value polynomial type, must implement `SpartanAccumulatorInputPolynomial` + /// - `MultilinearPolynomial` for i32/i64 config (simple circuits) + /// - `MultilinearPolynomial` for i64/i128 config (SHA-256, large coefficients) + /// + /// The small-value polynomials are used for efficient ss/sl multiplications in + /// the accumulator building phase. The field-element polynomials are used for + /// binding after the small-value rounds. + /// + /// Generic over `SmallValue` to support both i32/i64 and i64/i128 configurations. 
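// Illustrative toy sketch (not the crate's `small_field::i64_to_field`, which
// may be implemented differently): the field-element polynomials passed
// alongside `poly_A_small` / `poly_B_small` must carry exactly the same values.
// A common embedding of a signed integer into a prime field maps v >= 0 to
// v mod p and v < 0 to p - (|v| mod p); the check below uses a small prime and
// i128 arithmetic to show such an embedding respects addition.
const TOY_Q: i128 = (1i128 << 61) - 1;

fn toy_i64_to_field(v: i64) -> i128 {
  (v as i128).rem_euclid(TOY_Q)
}

fn toy_embedding_check() {
  assert_eq!(toy_i64_to_field(7), 7);
  assert_eq!(toy_i64_to_field(-7), TOY_Q - 7);
  // Additive homomorphism: embed(a) + embed(b) == embed(a + b) (mod q).
  let (a, b) = (123_456_789i64, -987_654_321i64);
  assert_eq!(
    (toy_i64_to_field(a) + toy_i64_to_field(b)).rem_euclid(TOY_Q),
    toy_i64_to_field(a + b)
  );
}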
+ #[allow(dead_code)] + pub fn prove_cubic_with_three_inputs_small_value( + claim: &E::Scalar, + taus: Vec, + poly_A_small: &MultilinearPolynomial, + poly_B_small: &MultilinearPolynomial, + poly_A: &mut MultilinearPolynomial, + poly_B: &mut MultilinearPolynomial, + poly_C: &mut MultilinearPolynomial, + transcript: &mut E::TE, + ) -> Result<(Self, Vec, Vec), SpartanError> + where + SmallValue: Copy + + Clone + + Default + + std::fmt::Debug + + PartialEq + + Eq + + std::ops::Add + + std::ops::Sub + + std::ops::Neg + + std::ops::AddAssign + + std::ops::SubAssign + + Send + + Sync, + E::Scalar: SmallValueField + DelayedReduction, + MultilinearPolynomial: MatVecMLE, + as MatVecMLE>::Product: + AccumulateProduct, + { + let num_rounds = taus.len(); + let mut r: Vec = Vec::with_capacity(num_rounds); + let mut polys: Vec> = Vec::with_capacity(num_rounds); + let mut claim_per_round = *claim; + + // Determine ℓ₀: must satisfy l0 <= num_rounds / 2 + let l0 = std::cmp::min(Self::DEFAULT_SMALL_VALUE_ROUNDS, num_rounds / 2); + + // If l0 is 0, fall back to standard algorithm + if l0 == 0 { + return Self::prove_cubic_with_three_inputs(claim, taus, poly_A, poly_B, poly_C, transcript); + } + + // ===== Pre-computation Phase ===== + // Build accumulators A_i(v, u) for all i ∈ [ℓ₀] using small-value arithmetic + // ss: i32 × i32 → i64 for polynomial products + // isl: i64 × field for eq weighting + let accumulators = build_accumulators_spartan::<_, _, DelayedModularReductionEnabled>( + poly_A_small, + poly_B_small, + &taus, + l0, + ); + let mut small_value = + lagrange_sumcheck::SmallValueSumCheck::::from_accumulators( + accumulators, + ); + + // ===== Small-Value Rounds (0 to ℓ₀-1) ===== + // During these rounds, we use the precomputed accumulators. Polynomials are NOT bound + // during these rounds - that will happen in the transition phase. + #[allow(clippy::needless_range_loop)] + for round in 0..l0 { + let (_round_span, round_t) = start_span!("sumcheck_smallvalue_round", round = round); + + // 1. Get t_i evaluations from accumulators + let t_all = small_value.eval_t_all_u(round); + let t_inf = t_all.at_infinity(); + let t0 = t_all.at_zero(); + + // 2. Get eq factor values ℓ_i(0), ℓ_i(1), ℓ_i(∞) + let li = small_value.eq_round_values(taus[round]); + + // 3. Derive t(1) from sumcheck constraint: s(0) + s(1) = claim + let t1 = derive_t1(li.at_zero(), li.at_one(), claim_per_round, t0) + .ok_or(SpartanError::InvalidSumcheckProof)?; + + // 4. Build round polynomial s_i(X) = ℓ_i(X) · t_i(X) + let poly = lagrange_sumcheck::build_univariate_round_polynomial(&li, t0, t1, t_inf); + + // 5. Transcript interaction + transcript.absorb(b"p", &poly); + let r_i = transcript.squeeze(b"c")?; + r.push(r_i); + polys.push(poly.compress()); + + // 6. Update claim + claim_per_round = poly.evaluate(&r_i); + + // 7. Advance small-value state (updates R_{i+1} and eq_factor) + small_value.advance(&li, r_i); + + info!( + elapsed_ms = %round_t.elapsed().as_millis(), + round = round, + "sumcheck_smallvalue_round" + ); + } + + // ===== Transition Phase ===== + // Create EqSumCheckInstance with ALL taus and advance it by l0 rounds. + // This ensures the internal partition and state matches the standard method exactly. 
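// Illustrative toy sketch (hypothetical `toy_*` names; the real `derive_t1`
// lives in `lagrange_accumulator`): step 3 above uses the sum-check constraint
// s(0) + s(1) = claim with s(X) = l(X) * t(X), i.e.
//   l(0)*t(0) + l(1)*t(1) = claim  =>  t(1) = (claim - l(0)*t(0)) / l(1),
// which is presumably why `derive_t1` returns `None` when l(1) = 0. The check
// below works modulo a small prime with a Fermat inverse.
const TOY_R: u128 = (1u128 << 61) - 1;

fn toy_pow(mut base: u128, mut exp: u128) -> u128 {
  let mut acc = 1u128;
  base %= TOY_R;
  while exp > 0 {
    if exp & 1 == 1 {
      acc = acc * base % TOY_R;
    }
    base = base * base % TOY_R;
    exp >>= 1;
  }
  acc
}

fn toy_derive_t1(l0: u128, l1: u128, claim: u128, t0: u128) -> Option<u128> {
  if l1 % TOY_R == 0 {
    return None;
  }
  let inv_l1 = toy_pow(l1, TOY_R - 2); // Fermat: l1^(p-2) == l1^{-1} mod p
  Some((claim + TOY_R - l0 * t0 % TOY_R) % TOY_R * inv_l1 % TOY_R)
}

fn toy_derive_t1_check() {
  let (l0, l1, t0, t1) = (3u128, 5u128, 11u128, 13u128);
  let claim = (l0 * t0 + l1 * t1) % TOY_R; // 3*11 + 5*13 = 98
  assert_eq!(toy_derive_t1(l0, l1, claim, t0), Some(t1));
}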
+ let mut eq_instance = eq_sumcheck::EqSumCheckInstance::::new(taus.clone()); + + // Bind all three polynomials to challenges r[0..l0] and advance eq_instance + // This brings everything to the same state as if we had run the standard method for l0 rounds + let (_bind_span, bind_t) = start_span!("bind_poly_vars_transition"); + #[allow(clippy::needless_range_loop)] + for i in 0..l0 { + bind_three_polys_top(poly_A, poly_B, poly_C, &r[i]); + eq_instance.bound(&r[i]); + } + info!( + elapsed_ms = %bind_t.elapsed().as_millis(), + "bind_poly_vars_transition" + ); + + // ===== Remaining Rounds (ℓ₀ to ℓ-1) ===== + // Continue using the same eq_instance which is now in the correct state + for round in l0..num_rounds { + let (_round_span, round_t) = start_span!("sumcheck_round", round = round); + + let poly = { + let (_eval_span, eval_t) = start_span!("compute_eval_points"); + let (eval_point_0, eval_point_2, eval_point_3) = eq_instance + .evaluation_points_cubic_with_three_inputs_delayed::( + round, poly_A, poly_B, poly_C, + ); + if eval_t.elapsed().as_millis() > 0 { + info!(elapsed_ms = %eval_t.elapsed().as_millis(), "compute_eval_points"); + } + + let evals = [ + eval_point_0, + claim_per_round - eval_point_0, + eval_point_2, + eval_point_3, + ]; + UniPoly::from_evals(&evals)? + }; + + // Transcript interaction + transcript.absorb(b"p", &poly); + let r_i = transcript.squeeze(b"c")?; + r.push(r_i); + polys.push(poly.compress()); + + // Update claim + claim_per_round = poly.evaluate(&r_i); + + // Bind polynomials and advance eq instance + let (_bind_span, bind_t) = start_span!("bind_poly_vars"); + bind_three_polys_top(poly_A, poly_B, poly_C, &r_i); + eq_instance.bound(&r_i); + info!(elapsed_ms = %bind_t.elapsed().as_millis(), "bind_poly_vars"); + info!(elapsed_ms = %round_t.elapsed().as_millis(), round = round, "sumcheck_round"); + } + + Ok(( + SumcheckProof { + compressed_polys: polys, + }, + r, + vec![poly_A[0], poly_B[0], poly_C[0]], + )) + } + /// Executes the **outer** cubic-with-additive-term sum-check in /// Zero-knowledge outer sum-check for the cubic-with-additive-term case. pub fn prove_cubic_with_additive_term_zk( @@ -867,7 +1203,9 @@ impl SumcheckProof { pub(crate) mod eq_sumcheck { //! This module implements the sumcheck optimization for equality polynomials. //! The optimization is described in Section 5 of algorithm 5. - use crate::{polys::multilinear::MultilinearPolynomial, traits::Engine}; + use crate::{ + polys::multilinear::MultilinearPolynomial, small_field::DelayedReduction, traits::Engine, + }; use ff::{Field, PrimeField}; use rayon::{iter::ZipEq, prelude::*, slice::Iter}; @@ -1026,6 +1364,196 @@ pub(crate) mod eq_sumcheck { (eval_0, eval_2, eval_3) } + /// Evaluate poly_A * poly_B - poly_C using delayed reduction. + /// + /// Same as `evaluation_points_cubic_with_three_inputs` but uses delayed modular reduction + /// to reduce Montgomery reductions from O(2^k) to O(2^{k/2}) per round. 
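// Illustrative toy sketch (does not use the crate's `EqPolynomial` or
// `EqSumCheckInstance`): the split-eq factorization described in the
// `# Algorithm` notes below works because eq(tau, x) is a product over
// variables,
//   eq(tau, x) = prod_j ( tau_j * x_j + (1 - tau_j) * (1 - x_j) ),
// so splitting the index bits into (x_out, x_in) splits the table:
//   E[(x_out, x_in)] = E_out[x_out] * E_in[x_in].
const TOY_S: u128 = (1u128 << 61) - 1;

fn toy_eq_table(taus: &[u128]) -> Vec<u128> {
  // E[x] for x in {0,1}^k, with the first tau owning the most significant bit.
  let mut table = vec![1u128];
  for &tau in taus {
    let mut next = Vec::with_capacity(table.len() * 2);
    for &e in &table {
      next.push(e * ((1 + TOY_S - tau) % TOY_S) % TOY_S); // x_j = 0
      next.push(e * tau % TOY_S); // x_j = 1
    }
    table = next;
  }
  table
}

fn toy_split_eq_check() {
  let taus = [2u128, 3, 5, 7];
  let full = toy_eq_table(&taus);
  let (e_out, e_in) = (toy_eq_table(&taus[..2]), toy_eq_table(&taus[2..]));
  for x_out in 0..e_out.len() {
    for x_in in 0..e_in.len() {
      let id = (x_out << 2) | x_in;
      assert_eq!(full[id], e_out[x_out] * e_in[x_in] % TOY_S);
    }
  }
}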
+ /// + /// # Algorithm + /// + /// Uses two-phase accumulation with the split-eq factorization: + /// - E[id] = E_out[x_out] * E_in[x_in] where id = (x_out, x_in) + /// + /// Phase 1 (inner): For each x_out, accumulate ∑_{x_in} E_in[x_in] ⊗ q_k(g) in wide limbs + /// Phase 2 (outer): Reduce once, then accumulate E_out[x_out] ⊗ inner_reduced + /// Final: Reduce once at the end + #[inline] + pub fn evaluation_points_cubic_with_three_inputs_delayed( + &self, + round_idx: usize, + poly_A: &MultilinearPolynomial, + poly_B: &MultilinearPolynomial, + poly_C: &MultilinearPolynomial, + ) -> (E::Scalar, E::Scalar, E::Scalar) + where + Value: Copy + + Clone + + Default + + std::fmt::Debug + + PartialEq + + Eq + + std::ops::Add + + std::ops::Sub + + std::ops::Neg + + std::ops::AddAssign + + std::ops::SubAssign + + Send + + Sync, + E::Scalar: DelayedReduction, + { + debug_assert_eq!(poly_A.Z.len() % 2, 0); + + type UF = >::UnreducedFieldField; + + let in_first_half = self.round < self.first_half; + let half_p = poly_A.Z.len() / 2; + + let (mut eval_0, mut eval_2, mut eval_3) = if in_first_half { + // Two-phase accumulation: E[id] = E_out[x_out] * E_in[x_in] + let (poly_eq_left, poly_eq_right, second_half, _low_mask) = self.poly_eqs_first_half(); + let eq_out_len = poly_eq_left.len(); + + // Outer loop: iterate over E_out indices + // Dynamic chunk size: enough chunks for work-stealing, but not too many to cause overhead + let min_chunk = (eq_out_len / (rayon::current_num_threads() * 4)).max(1); + let (acc_0, acc_2, acc_3) = (0..eq_out_len) + .into_par_iter() + .with_min_len(min_chunk) + .fold( + || { + ( + UF::::default(), + UF::::default(), + UF::::default(), + ) + }, + |mut outer_acc, x_out| { + let e_out = &poly_eq_left[x_out]; + + // Phase 1: Inner loop - accumulate E_in[x_in] ⊗ q_k(g) in wide limbs + let mut inner_0 = UF::::default(); + let mut inner_2 = UF::::default(); + let mut inner_3 = UF::::default(); + + for (x_in, e_in) in poly_eq_right.iter().enumerate() { + let id = (x_out << second_half) | x_in; + + // Get polynomial values at index id + let (zero_a, one_a) = (&poly_A.Z[id], &poly_A.Z[id + half_p]); + let (zero_b, one_b) = (&poly_B.Z[id], &poly_B.Z[id + half_p]); + let (zero_c, one_c) = (&poly_C.Z[id], &poly_C.Z[id + half_p]); + + let (q0, q2, q3) = eval_one_case_cubic_three_inputs( + round_idx, zero_a, one_a, zero_b, one_b, zero_c, one_c, + ); + + // Accumulate E_in * q_k in wide limbs (NO REDUCTION) + E::Scalar::unreduced_field_field_mul_add(&mut inner_0, e_in, &q0); + E::Scalar::unreduced_field_field_mul_add(&mut inner_2, e_in, &q2); + E::Scalar::unreduced_field_field_mul_add(&mut inner_3, e_in, &q3); + } + + // Phase 2: Reduce inner sums ONCE, then multiply by E_out + let inner_0_red = E::Scalar::reduce_field_field(&inner_0); + let inner_2_red = E::Scalar::reduce_field_field(&inner_2); + let inner_3_red = E::Scalar::reduce_field_field(&inner_3); + + // Accumulate E_out * inner_reduced in wide limbs (NO REDUCTION) + E::Scalar::unreduced_field_field_mul_add(&mut outer_acc.0, e_out, &inner_0_red); + E::Scalar::unreduced_field_field_mul_add(&mut outer_acc.1, e_out, &inner_2_red); + E::Scalar::unreduced_field_field_mul_add(&mut outer_acc.2, e_out, &inner_3_red); + + outer_acc + }, + ) + .reduce( + || { + ( + UF::::default(), + UF::::default(), + UF::::default(), + ) + }, + |mut a, b| { + a.0 += b.0; + a.1 += b.1; + a.2 += b.2; + a + }, + ); + + // Final reduction + ( + E::Scalar::reduce_field_field(&acc_0), + E::Scalar::reduce_field_field(&acc_2), + 
E::Scalar::reduce_field_field(&acc_3), + ) + } else { + // Second half: only E_in (poly_eq_right), no E_out + // Still use delayed reduction but simpler (single phase) + let poly_eq_right = self.poly_eq_right_last_half(); + + // Dynamic chunk size for work-stealing balance + let min_chunk = (half_p / (rayon::current_num_threads() * 4)).max(1); + let (acc_0, acc_2, acc_3) = (0..half_p) + .into_par_iter() + .with_min_len(min_chunk) + .fold( + || { + ( + UF::::default(), + UF::::default(), + UF::::default(), + ) + }, + |mut acc, id| { + let e = &poly_eq_right[id]; + + let (zero_a, one_a) = (&poly_A.Z[id], &poly_A.Z[id + half_p]); + let (zero_b, one_b) = (&poly_B.Z[id], &poly_B.Z[id + half_p]); + let (zero_c, one_c) = (&poly_C.Z[id], &poly_C.Z[id + half_p]); + + let (q0, q2, q3) = eval_one_case_cubic_three_inputs( + round_idx, zero_a, one_a, zero_b, one_b, zero_c, one_c, + ); + + // Accumulate E * q_k in wide limbs (NO REDUCTION) + E::Scalar::unreduced_field_field_mul_add(&mut acc.0, e, &q0); + E::Scalar::unreduced_field_field_mul_add(&mut acc.1, e, &q2); + E::Scalar::unreduced_field_field_mul_add(&mut acc.2, e, &q3); + + acc + }, + ) + .reduce( + || { + ( + UF::::default(), + UF::::default(), + UF::::default(), + ) + }, + |mut a, b| { + a.0 += b.0; + a.1 += b.1; + a.2 += b.2; + a + }, + ); + + // Final reduction + ( + E::Scalar::reduce_field_field(&acc_0), + E::Scalar::reduce_field_field(&acc_2), + E::Scalar::reduce_field_field(&acc_3), + ) + }; + + self.update_evals(&mut eval_0, &mut eval_2, &mut eval_3); + + (eval_0, eval_2, eval_3) + } + #[inline] pub fn bound(&mut self, r: &E::Scalar) { // Invariant: self.round is always >= 1 when bound is called @@ -1119,7 +1647,7 @@ pub(crate) mod eq_sumcheck { /// A tuple `(eval_0, eval_2, eval_3)` containing the evaluation points. #[inline] fn eval_one_case_cubic_three_inputs( - round_idx: usize, + _round_idx: usize, zero_a: &Scalar, one_a: &Scalar, zero_b: &Scalar, @@ -1127,20 +1655,10 @@ pub(crate) mod eq_sumcheck { zero_c: &Scalar, one_c: &Scalar, ) -> (Scalar, Scalar, Scalar) { - // Optimization: In the first round (round == 0), eval_0 is always ZERO. - // This is mathematically correct because in round 0 of the sumcheck protocol with equality - // polynomials, when evaluating at point 0, the equality polynomial factor eq(tau, 0, ...) - // evaluates to (1 - tau_0) for the first variable. The sumcheck instance's update_evals - // method multiplies eval_0 by eq_tau_0_p, which for round 0 equals (1 - tau_0) * eval_eq_left. - // The contribution from eval_0 to the final sum is zero in this case due to the structure of - // the equality polynomial and how it combines with the cubic terms in the first round. - // This optimization avoids unnecessary computation of zero_a * zero_b - zero_c when the result - // will be zeroed out anyway by the equality polynomial evaluation at point 0 in round 0. - let eval_0 = if round_idx == 0 { - Scalar::ZERO - } else { - *zero_a * *zero_b - *zero_c - }; + // Compute the evaluation at point 0 + // Note: The small-value optimization (Algorithm 6) requires the full eval_0 computation + // to produce transcript-equivalent proofs. + let eval_0 = *zero_a * *zero_b - *zero_c; let double_one_a = one_a.double(); let double_one_b = one_b.double(); @@ -1164,3 +1682,520 @@ pub(crate) mod eq_sumcheck { (eval_0, eval_2, eval_3) } } + +/// Lagrange sumcheck implementation for small-value rounds. 
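// Illustrative toy sketch (u32 operands and a u128 accumulator stand in for the
// crate's multi-limb unreduced accumulator type): the accumulation loops above
// gather e * q products with `unreduced_field_field_mul_add` and call
// `reduce_field_field` once per inner block and once at the very end. This is
// sound because reduction commutes with addition,
//   sum_i (a_i * b_i mod p) == (sum_i a_i * b_i) mod p,
// so a single reduction of a wide accumulator replaces one reduction per term.
fn toy_delayed_reduction_check() {
  const P: u128 = (1u128 << 61) - 1;
  let a: [u32; 4] = [4_000_000_000, 3_000_000_000, 2_000_000_000, 987_654_321];
  let b: [u32; 4] = [3_999_999_999, 2_999_999_999, 1_999_999_999, 358_979_323];

  // Reduce after every term (what delayed reduction avoids).
  let mut eager = 0u128;
  for i in 0..4 {
    eager = (eager + (a[i] as u128) * (b[i] as u128) % P) % P;
  }

  // Accumulate wide, reduce once. With 32-bit operands each product fits in
  // 64 bits, so a u128 accumulator absorbs roughly 2^64 terms before overflow.
  let mut wide = 0u128;
  for i in 0..4 {
    wide += (a[i] as u128) * (b[i] as u128);
  }
  assert_eq!(eager, wide % P);
}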
+pub mod lagrange_sumcheck { + + use crate::{ + lagrange_accumulator::{ + EqRoundFactor, LagrangeAccumulators, LagrangeBasisFactory, LagrangeCoeff, LagrangeEvals, + LagrangeHatEvals, + }, + polys::univariate::UniPoly, + }; + use ff::PrimeField; + + // Re-export for tests + #[cfg(test)] + pub(crate) use crate::lagrange_accumulator::derive_t1; + + /// Tracks the small-value sum-check state for the first ℓ₀ rounds. + pub struct SmallValueSumCheck { + accumulators: LagrangeAccumulators, + coeff: LagrangeCoeff, + eq_factor: EqRoundFactor, + basis_factory: LagrangeBasisFactory, + } + + impl SmallValueSumCheck { + /// Create a new small-value round tracker with precomputed accumulators. + pub fn new( + accumulators: LagrangeAccumulators, + basis_factory: LagrangeBasisFactory, + ) -> Self { + Self { + accumulators, + coeff: LagrangeCoeff::new(), + eq_factor: EqRoundFactor::new(), + basis_factory, + } + } + + /// Create from accumulators with the standard Lagrange basis (0, 1, 2, ...). + pub fn from_accumulators(accumulators: LagrangeAccumulators) -> Self { + let basis_factory = LagrangeBasisFactory::::new(|i| Scalar::from(i as u64)); + Self::new(accumulators, basis_factory) + } + + /// Evaluate t_i(u) for all u ∈ Û_D in a single pass for round i. + pub fn eval_t_all_u(&self, round: usize) -> LagrangeHatEvals { + self.accumulators.round(round).eval_t_all_u(&self.coeff) + } + + /// Compute ℓ_i values for the provided w_i. + pub fn eq_round_values(&self, w_i: Scalar) -> LagrangeEvals { + self.eq_factor.values(w_i) + } + + /// Advance the round state with the verifier challenge r_i. + pub fn advance(&mut self, li: &LagrangeEvals, r_i: Scalar) { + self.eq_factor.advance(li, r_i); + self.coeff.extend(&self.basis_factory.basis_at(r_i)); + } + } + + /// Build the cubic round polynomial s_i(X) in coefficient form for Spartan. + pub(crate) fn build_univariate_round_polynomial( + li: &LagrangeEvals, + t0: F, + t1: F, + t_inf: F, + ) -> UniPoly { + // Reconstruct t_i(X) = aX^2 + bX + c using: + // - a = t_i(∞) (leading coefficient for degree-2 polynomials) + // - c = t_i(0) + // - t_i(1) = a + b + c ⇒ b = t_i(1) − a − c + let a = t_inf; + let c = t0; + let b = t1 - a - c; + + let linf = li.at_infinity(); + let l0 = li.at_zero(); + + // Multiply s_i(X) = ℓ_i(X)·t_i(X) with ℓ_i(X)=ℓ_∞X+ℓ_0 and collect coefficients. 
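// Illustrative toy sketch (plain i64 arithmetic; not the crate's `UniPoly`):
// the assignments above recover t(X) = a*X^2 + b*X + c from t(0), t(1), t(inf)
// (a is the leading coefficient, c = t(0), and b = t(1) - a - c), and the
// coefficients collected just below expand
//   (l_inf*X + l_0) * (a*X^2 + b*X + c)
//     = (l_inf*a)*X^3 + (l_inf*b + l_0*a)*X^2 + (l_inf*c + l_0*b)*X + l_0*c.
fn toy_round_poly_check() {
  let (a, b, c) = (2i64, -3i64, 7i64); // t(X) = 2X^2 - 3X + 7
  let (t0, t1, t_inf) = (c, a + b + c, a); // t(0), t(1), leading coefficient
  assert_eq!((t_inf, t0, t1 - t_inf - t0), (a, c, b));

  let (l0, linf) = (5i64, 4i64); // l(X) = 4X + 5
  let s = [l0 * c, linf * c + l0 * b, linf * b + l0 * a, linf * a];
  // Spot-check s(X) = l(X) * t(X) at a few points.
  for x in 0..4i64 {
    let t = a * x * x + b * x + c;
    let l = linf * x + l0;
    assert_eq!(s[0] + s[1] * x + s[2] * x * x + s[3] * x * x * x, l * t);
  }
}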
+ let s3 = linf * a; + let s2 = linf * b + l0 * a; + let s1 = linf * c + l0 * b; + let s0 = l0 * c; + + UniPoly { + coeffs: vec![s0, s1, s2, s3], + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::{ + lagrange_accumulator::{ + DelayedModularReductionDisabled, SPARTAN_T_DEGREE, build_accumulators_spartan, + }, + polys::{eq::EqPolynomial, multilinear::MultilinearPolynomial}, + provider::PallasHyraxEngine, + sumcheck::eq_sumcheck::EqSumCheckInstance, + traits::Engine, + }; + use ff::Field; + + type E = PallasHyraxEngine; + type F = ::Scalar; + + #[test] + #[allow(clippy::needless_range_loop)] + fn test_smallvalue_round_matches_eq_instance_evals() { + const NUM_VARS: usize = 6; + const SMALL_VALUE_ROUNDS: usize = 3; + + let n = 1usize << NUM_VARS; + let taus = (0..NUM_VARS) + .map(|i| F::from((i + 2) as u64)) + .collect::>(); + + let az_vals = (0..n).map(|i| F::from((i + 1) as u64)).collect::>(); + let bz_vals = (0..n).map(|i| F::from((i + 3) as u64)).collect::>(); + let cz_vals = az_vals + .iter() + .zip(bz_vals.iter()) + .map(|(a, b)| *a * *b) + .collect::>(); + + let az = MultilinearPolynomial::new(az_vals); + let bz = MultilinearPolynomial::new(bz_vals); + let cz = MultilinearPolynomial::new(cz_vals); + + let eq_evals = EqPolynomial::evals_from_points(&taus); + let mut claim = F::ZERO; + for i in 0..n { + claim += eq_evals[i] * (az.Z[i] * bz.Z[i] - cz.Z[i]); + } + + let accs = build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>( + &az, + &bz, + &taus, + SMALL_VALUE_ROUNDS, + ); + let mut small_value = SmallValueSumCheck::from_accumulators(accs); + + let mut eq_instance = EqSumCheckInstance::::new(taus.clone()); + let mut poly_A = az.clone(); + let mut poly_B = bz.clone(); + let mut poly_C = cz.clone(); + + for round in 0..SMALL_VALUE_ROUNDS { + let (expected_eval_0, expected_eval_2, expected_eval_3) = + eq_instance.evaluation_points_cubic_with_three_inputs(round, &poly_A, &poly_B, &poly_C); + + let li = small_value.eq_round_values(taus[round]); + let t_all = small_value.eval_t_all_u(round); + let t_inf = t_all.at_infinity(); + let t0 = t_all.at_zero(); + let t1 = derive_t1(li.at_zero(), li.at_one(), claim, t0) + .expect("l1 should be non-zero for chosen taus"); + + let r_i = F::from((round + 7) as u64); + let poly = build_univariate_round_polynomial(&li, t0, t1, t_inf); + assert_eq!(poly.evaluate(&F::ZERO), expected_eval_0); + assert_eq!(poly.evaluate(&F::from(2u64)), expected_eval_2); + assert_eq!(poly.evaluate(&F::from(3u64)), expected_eval_3); + claim = poly.evaluate(&r_i); + + poly_A.bind_poly_var_top(&r_i); + poly_B.bind_poly_var_top(&r_i); + poly_C.bind_poly_var_top(&r_i); + eq_instance.bound(&r_i); + small_value.advance(&li, r_i); + } + } + + /// Test that prove_cubic_with_three_inputs_small_value produces identical + /// output to prove_cubic_with_three_inputs for various sizes. + fn run_equivalence_test(num_vars: usize) { + use crate::{sumcheck::SumcheckProof, traits::transcript::TranscriptEngineTrait}; + + let n = 1usize << num_vars; + + // Deterministic polynomials for reproducibility + // Use satisfying witness: Cz = Az * Bz (so Az·Bz - Cz = 0 on boolean hypercube) + // This is required for build_accumulators_spartan which assumes satisfaction. 
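// Illustrative toy check (plain i64 values): with a satisfying witness,
// Cz = Az * Bz pointwise on the Boolean hypercube, every term of the outer
// claim sum_x eq(tau, x) * (Az(x) * Bz(x) - Cz(x)) is already zero, so the
// claim is 0 regardless of tau.
fn toy_zero_claim_check() {
  let az = [1i64, 2, 3, 4];
  let bz = [3i64, 4, 5, 6];
  let cz: Vec<i64> = az.iter().zip(&bz).map(|(a, b)| a * b).collect();
  // Arbitrary weights stand in for eq(tau, x); each summand is already zero.
  let weights = [7i64, 11, 13, 17];
  let claim: i64 = (0..4).map(|i| weights[i] * (az[i] * bz[i] - cz[i])).sum();
  assert_eq!(claim, 0);
}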
+ // + // Create small i32 values + let az_i32: Vec = (0..n).map(|i| (i + 1) as i32).collect(); + let bz_i32: Vec = (0..n).map(|i| (i + 3) as i32).collect(); + + // Create field-element polynomials + let az_vals: Vec = az_i32.iter().map(|&v| F::from(v as u64)).collect(); + let bz_vals: Vec = bz_i32.iter().map(|&v| F::from(v as u64)).collect(); + let cz_vals: Vec = az_vals.iter().zip(&bz_vals).map(|(a, b)| *a * *b).collect(); + + let taus: Vec = (0..num_vars).map(|i| F::from((i + 2) as u64)).collect(); + + // Claim = 0 for satisfying witness (Az·Bz = Cz on {0,1}^n) + let claim: F = F::ZERO; + + // Polynomials for standard method + let mut az1 = MultilinearPolynomial::new(az_vals.clone()); + let mut bz1 = MultilinearPolynomial::new(bz_vals.clone()); + let mut cz1 = MultilinearPolynomial::new(cz_vals.clone()); + + // Small-value polynomials (use native integer arithmetic) + let az_small = MultilinearPolynomial::new(az_i32); + let bz_small = MultilinearPolynomial::new(bz_i32); + + // Field-element polynomials for binding + let mut az2 = MultilinearPolynomial::new(az_vals); + let mut bz2 = MultilinearPolynomial::new(bz_vals); + let mut cz2 = MultilinearPolynomial::new(cz_vals); + + // Fresh transcripts with same seed + let mut transcript1 = ::TE::new(b"test"); + let mut transcript2 = ::TE::new(b"test"); + + // Run standard method + let (proof1, r1, evals1) = SumcheckProof::::prove_cubic_with_three_inputs( + &claim, + taus.clone(), + &mut az1, + &mut bz1, + &mut cz1, + &mut transcript1, + ) + .unwrap(); + + // Run small-value method + let (proof2, r2, evals2) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + taus, + &az_small, + &bz_small, + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + .unwrap(); + + // Verify all outputs match + assert_eq!(r1, r2, "challenges must match for num_vars={}", num_vars); + assert_eq!( + proof1.compressed_polys, proof2.compressed_polys, + "compressed_polys must match for num_vars={}", + num_vars + ); + assert_eq!( + evals1, evals2, + "final evals must match for num_vars={}", + num_vars + ); + } + + /// Debug test to compare first polynomial only + #[test] + fn test_small_value_polynomial_and_eq_instance_equivalence() { + const NUM_VARS: usize = 10; // Same as equivalence test + let n = 1usize << NUM_VARS; + + // Use satisfying witness: Cz = Az * Bz (required for build_accumulators_spartan) + let az_vals: Vec = (0..n).map(|i| F::from((i + 1) as u64)).collect(); + let bz_vals: Vec = (0..n).map(|i| F::from((i + 3) as u64)).collect(); + let cz_vals: Vec = az_vals.iter().zip(&bz_vals).map(|(a, b)| *a * *b).collect(); + + let taus: Vec = (0..NUM_VARS).map(|i| F::from((i + 2) as u64)).collect(); + + // Claim = 0 for satisfying witness + let claim: F = F::ZERO; + + // Create polynomials for standard method + let az1 = MultilinearPolynomial::new(az_vals.clone()); + let bz1 = MultilinearPolynomial::new(bz_vals.clone()); + let cz1 = MultilinearPolynomial::new(cz_vals); + + // Create polynomials for small-value method + let az2 = MultilinearPolynomial::new(az_vals); + let bz2 = MultilinearPolynomial::new(bz_vals); + + // Run standard method - just get first polynomial's evaluations + let eq_instance = EqSumCheckInstance::::new(taus.clone()); + let (eval_0_std, eval_2_std, eval_3_std) = + eq_instance.evaluation_points_cubic_with_three_inputs(0, &az1, &bz1, &cz1); + let evals_std = [eval_0_std, claim - eval_0_std, eval_2_std, eval_3_std]; + + // Run small-value method - get first polynomial's evaluations + let l0 = 3usize; + let accumulators = 
+ build_accumulators_spartan::<_, _, DelayedModularReductionDisabled>(&az2, &bz2, &taus, l0); + let small_value = SmallValueSumCheck::::from_accumulators(accumulators); + + let t_all = small_value.eval_t_all_u(0); + let t_inf = t_all.at_infinity(); + let t0 = t_all.at_zero(); + let li = small_value.eq_round_values(taus[0]); + let t1 = derive_t1(li.at_zero(), li.at_one(), claim, t0).expect("l1 non-zero"); + + let poly_sv = build_univariate_round_polynomial(&li, t0, t1, t_inf); + + // Compare evaluations at 0, 1, 2, 3 + let sv0 = poly_sv.evaluate(&F::ZERO); + let sv1 = poly_sv.evaluate(&F::ONE); + let sv2 = poly_sv.evaluate(&F::from(2u64)); + let sv3 = poly_sv.evaluate(&F::from(3u64)); + + assert_eq!( + sv0, evals_std[0], + "s(0) must match: sv={:?}, std={:?}", + sv0, evals_std[0] + ); + assert_eq!( + sv1, evals_std[1], + "s(1) must match: sv={:?}, std={:?}", + sv1, evals_std[1] + ); + assert_eq!( + sv2, evals_std[2], + "s(2) must match: sv={:?}, std={:?}", + sv2, evals_std[2] + ); + assert_eq!( + sv3, evals_std[3], + "s(3) must match: sv={:?}, std={:?}", + sv3, evals_std[3] + ); + } + + #[test] + fn test_small_value_equivalence_l10() { + run_equivalence_test(10); // 2^10 = 1024 elements, l0 = 3 + } + + #[test] + fn test_small_value_equivalence_l16() { + run_equivalence_test(16); // 2^16 = 65536 elements, l0 = 3 (must be even ℓ) + } + + /// Test that i64 small-value polynomials produce identical results to standard. + /// This tests the i64/i128 path which uses isl_mul_128 Barrett reduction. + #[test] + fn test_small_value_equivalence_i64() { + use crate::{sumcheck::SumcheckProof, traits::transcript::TranscriptEngineTrait}; + + let num_vars = 10; + let n = 1usize << num_vars; + + // Create i64 values (larger than i32 range to test the i64 path) + let az_i64: Vec = (0..n).map(|i| (i as i64 + 1) * 100_000).collect(); + let bz_i64: Vec = (0..n).map(|i| (i as i64 + 3) * 100_000).collect(); + + // Create field-element polynomials + let az_vals: Vec = az_i64 + .iter() + .map(|&v| crate::small_field::i64_to_field(v)) + .collect(); + let bz_vals: Vec = bz_i64 + .iter() + .map(|&v| crate::small_field::i64_to_field(v)) + .collect(); + let cz_vals: Vec = az_vals.iter().zip(&bz_vals).map(|(a, b)| *a * *b).collect(); + + let taus: Vec = (0..num_vars).map(|i| F::from((i + 2) as u64)).collect(); + + // Claim = 0 for satisfying witness (Az·Bz = Cz on {0,1}^n) + let claim: F = F::ZERO; + + // Polynomials for standard method + let mut az1 = MultilinearPolynomial::new(az_vals.clone()); + let mut bz1 = MultilinearPolynomial::new(bz_vals.clone()); + let mut cz1 = MultilinearPolynomial::new(cz_vals.clone()); + + // i64 small-value polynomials + let az_small = MultilinearPolynomial::new(az_i64); + let bz_small = MultilinearPolynomial::new(bz_i64); + + // Field-element polynomials for binding + let mut az2 = MultilinearPolynomial::new(az_vals); + let mut bz2 = MultilinearPolynomial::new(bz_vals); + let mut cz2 = MultilinearPolynomial::new(cz_vals); + + // Fresh transcripts with same seed + let mut transcript1 = ::TE::new(b"test"); + let mut transcript2 = ::TE::new(b"test"); + + // Run standard method + let (proof1, r1, evals1) = SumcheckProof::::prove_cubic_with_three_inputs( + &claim, + taus.clone(), + &mut az1, + &mut bz1, + &mut cz1, + &mut transcript1, + ) + .unwrap(); + + // Run small-value method with i64 polynomials + let (proof2, r2, evals2) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + taus, + &az_small, + &bz_small, + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + 
.unwrap(); + + // Verify all outputs match + assert_eq!(r1, r2, "challenges must match for i64 test"); + assert_eq!( + proof1.compressed_polys, proof2.compressed_polys, + "compressed_polys must match for i64 test" + ); + assert_eq!(evals1, evals2, "final evals must match for i64 test"); + } + + /// Test small-value sumcheck with very large i64 values (~2^50). + /// This simulates the SHA-256 case where coefficients can be large. + #[test] + fn test_small_value_equivalence_i64_large() { + use crate::{sumcheck::SumcheckProof, traits::transcript::TranscriptEngineTrait}; + + let num_vars = 10; + let n = 1usize << num_vars; + + // Create very large i64 values (similar to SHA-256 coefficients ~2^50) + let az_i64: Vec = (0..n) + .map(|i| { + let base = (i as i64 + 1) * (1i64 << 40); // ~2^50 + if i % 3 == 2 { -base } else { base } + }) + .collect(); + let bz_i64: Vec = (0..n) + .map(|i| { + let base = (i as i64 + 3) * (1i64 << 35); // ~2^45 + if i % 5 == 4 { -base } else { base } + }) + .collect(); + + // Create field-element polynomials + let az_vals: Vec = az_i64 + .iter() + .map(|&v| crate::small_field::i64_to_field(v)) + .collect(); + let bz_vals: Vec = bz_i64 + .iter() + .map(|&v| crate::small_field::i64_to_field(v)) + .collect(); + let cz_vals: Vec = az_vals.iter().zip(&bz_vals).map(|(a, b)| *a * *b).collect(); + + // Verify roundtrip + for i in 0..n { + let back = crate::small_field::try_field_to_i64(&az_vals[i]).expect("should fit"); + assert_eq!(back, az_i64[i], "roundtrip failed at {}", i); + } + + let taus: Vec = (0..num_vars).map(|i| F::from((i + 2) as u64)).collect(); + + // Claim = 0 for satisfying witness (Az·Bz = Cz on {0,1}^n) + let claim: F = F::ZERO; + + // Polynomials for standard method + let mut az1 = MultilinearPolynomial::new(az_vals.clone()); + let mut bz1 = MultilinearPolynomial::new(bz_vals.clone()); + let mut cz1 = MultilinearPolynomial::new(cz_vals.clone()); + + // i64 small-value polynomials + let az_small = MultilinearPolynomial::new(az_i64); + let bz_small = MultilinearPolynomial::new(bz_i64); + + // Field-element polynomials for binding + let mut az2 = MultilinearPolynomial::new(az_vals); + let mut bz2 = MultilinearPolynomial::new(bz_vals); + let mut cz2 = MultilinearPolynomial::new(cz_vals); + + // Fresh transcripts with same seed + let mut transcript1 = ::TE::new(b"test"); + let mut transcript2 = ::TE::new(b"test"); + + // Run standard method + let (proof1, r1, evals1) = SumcheckProof::::prove_cubic_with_three_inputs( + &claim, + taus.clone(), + &mut az1, + &mut bz1, + &mut cz1, + &mut transcript1, + ) + .unwrap(); + + // Run small-value method with i64 polynomials + let (proof2, r2, evals2) = SumcheckProof::::prove_cubic_with_three_inputs_small_value( + &claim, + taus, + &az_small, + &bz_small, + &mut az2, + &mut bz2, + &mut cz2, + &mut transcript2, + ) + .unwrap(); + + // Verify all outputs match + assert_eq!(r1, r2, "challenges must match for large i64 test"); + assert_eq!( + proof1.compressed_polys, proof2.compressed_polys, + "compressed_polys must match for large i64 test" + ); + assert_eq!(evals1, evals2, "final evals must match for large i64 test"); + } + } +}
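// Illustrative toy sketch (built on std's `DefaultHasher`; unrelated to the
// crate's transcript engine): the equivalence tests above create both
// transcripts with the same seed and require byte-identical prover messages,
// because a Fiat-Shamir transcript derives each challenge from everything
// absorbed so far. Identical absorbs give identical challenges, and any
// divergence in a round polynomial changes every later challenge.
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

struct ToyTranscript {
  state: u64,
}

impl ToyTranscript {
  fn new(seed: &[u8]) -> Self {
    let mut h = DefaultHasher::new();
    seed.hash(&mut h);
    Self { state: h.finish() }
  }

  fn absorb(&mut self, msg: &[u8]) {
    let mut h = DefaultHasher::new();
    self.state.hash(&mut h);
    msg.hash(&mut h);
    self.state = h.finish();
  }

  fn squeeze(&mut self) -> u64 {
    let mut h = DefaultHasher::new();
    self.state.hash(&mut h);
    b"challenge".hash(&mut h);
    self.state = h.finish();
    self.state
  }
}

fn toy_transcript_check() {
  let mut t1 = ToyTranscript::new(b"test");
  let mut t2 = ToyTranscript::new(b"test");
  let mut t3 = ToyTranscript::new(b"test");
  t1.absorb(b"round-0 polynomial");
  t2.absorb(b"round-0 polynomial");
  t3.absorb(b"a different round-0 polynomial");
  let (c1, c2, c3) = (t1.squeeze(), t2.squeeze(), t3.squeeze());
  assert_eq!(c1, c2); // same absorbs give the same challenge
  assert_ne!(c1, c3); // a diverging message changes this and all later challenges
}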