diff --git a/.gitignore b/.gitignore index 7e93290..107d879 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,9 @@ Cargo.lock bin obj + +# Simple Task Master - User tasks are git-ignored +.simple-task-master +.simple-task-master/lock +.claude +specs/ diff --git a/banderwagon/src/element.rs b/banderwagon/src/element.rs index 665f4d8..e43bb93 100644 --- a/banderwagon/src/element.rs +++ b/banderwagon/src/element.rs @@ -74,6 +74,14 @@ impl Element { Self(point) } + /// Try to deserialize from uncompressed bytes, returning an error on failure + pub fn try_from_bytes_uncompressed( + bytes: [u8; 64], + ) -> Result { + let point = EdwardsProjective::deserialize_uncompressed_unchecked(&bytes[..])?; + Ok(Self(point)) + } + pub fn from_bytes(bytes: &[u8]) -> Option { // Switch from big endian to little endian, as arkworks library uses little endian let mut bytes = bytes.to_vec(); diff --git a/banderwagon/src/trait_impls/serialize.rs b/banderwagon/src/trait_impls/serialize.rs index 8938dfe..432764a 100644 --- a/banderwagon/src/trait_impls/serialize.rs +++ b/banderwagon/src/trait_impls/serialize.rs @@ -1,5 +1,4 @@ use crate::Element; -use ark_ec::CurveGroup; use ark_ed_on_bls12_381_bandersnatch::EdwardsProjective; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError, Valid}; impl CanonicalSerialize for Element { @@ -13,7 +12,7 @@ impl CanonicalSerialize for Element { writer.write_all(&self.to_bytes())?; Ok(()) } - ark_serialize::Compress::No => self.0.into_affine().serialize_uncompressed(writer), + ark_serialize::Compress::No => self.0.serialize_uncompressed(writer), } } @@ -56,7 +55,7 @@ impl CanonicalDeserialize for Element { } } ark_serialize::Compress::No => { - let point = EdwardsProjective::deserialize_uncompressed(reader)?; + let point = EdwardsProjective::deserialize_uncompressed_unchecked(reader)?; Ok(Element(point)) } } diff --git a/bindings/csharp/csharp_code/Verkle.Bindings/native_methods.g.cs b/bindings/csharp/csharp_code/Verkle.Bindings/native_methods.g.cs index 5dd05e0..539151d 100644 --- a/bindings/csharp/csharp_code/Verkle.Bindings/native_methods.g.cs +++ b/bindings/csharp/csharp_code/Verkle.Bindings/native_methods.g.cs @@ -16,6 +16,8 @@ internal static unsafe partial class NativeMethods + + [DllImport(__DllName, EntryPoint = "context_new", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] internal static extern Context* context_new(); diff --git a/ffi_interface/Cargo.toml b/ffi_interface/Cargo.toml index 06696b1..448ad45 100644 --- a/ffi_interface/Cargo.toml +++ b/ffi_interface/Cargo.toml @@ -10,4 +10,5 @@ banderwagon = { path = "../banderwagon" } ipa-multipoint = { path = "../ipa-multipoint" } verkle-spec = { path = "../verkle-spec" } hex = "*" +rayon = "1.8.0" verkle-trie = { path = "../verkle-trie" } diff --git a/ffi_interface/src/lib.rs b/ffi_interface/src/lib.rs index fbecabf..8753b1b 100644 --- a/ffi_interface/src/lib.rs +++ b/ffi_interface/src/lib.rs @@ -14,6 +14,7 @@ use ipa_multipoint::crs::CRS; use ipa_multipoint::lagrange_basis::PrecomputedWeights; use ipa_multipoint::multiproof::{MultiPoint, MultiPointProof, ProverQuery, VerifierQuery}; use ipa_multipoint::transcript::Transcript; +use rayon::prelude::*; pub use serialization::{fr_from_le_bytes, fr_to_le_bytes}; use verkle_trie::proof::golang_proof_format::{bytes32_to_element, hex_to_bytes32, VerkleProofGo}; @@ -258,12 +259,50 @@ pub fn hash_commitment(commitment: CommitmentBytes) -> ScalarBytes { // TODO: this is actually a bottleneck for the average workflow before doing this. fr_to_le_bytes(Element::from_bytes_unchecked_uncompressed(commitment).map_to_scalar_field()) } + +/// Minimum number of commitments before using parallel processing. +/// Below this threshold, sequential processing is faster due to thread pool overhead. +const PARALLEL_HASH_THRESHOLD: usize = 100; + /// Hashes a vector of commitments. /// -/// This is more efficient than repeatedly calling `hash_commitment` +/// This is more efficient than repeatedly calling `hash_commitment`. +/// For batches of 100 or more commitments, parallel processing is automatically used. /// /// Returns a vector of `Scalar`s representing the hash of each commitment pub fn hash_commitments(commitments: &[CommitmentBytes]) -> Vec { + if commitments.len() < PARALLEL_HASH_THRESHOLD { + // Sequential for small batches + hash_commitments_sequential(commitments) + } else { + // Parallel for large batches + hash_commitments_impl_parallel(commitments) + } +} + +/// Hashes commitments with explicit parallelism control. +/// +/// Use this when you need to control whether parallel processing is used, +/// regardless of the batch size. +/// +/// # Arguments +/// * `commitments` - The commitments to hash +/// * `use_parallel` - If true, use parallel processing; if false, use sequential +/// +/// Returns a vector of `Scalar`s representing the hash of each commitment +pub fn hash_commitments_parallel( + commitments: &[CommitmentBytes], + use_parallel: bool, +) -> Vec { + if use_parallel { + hash_commitments_impl_parallel(commitments) + } else { + hash_commitments_sequential(commitments) + } +} + +/// Sequential implementation of commitment hashing using batch_map_to_scalar_field +fn hash_commitments_sequential(commitments: &[CommitmentBytes]) -> Vec { let elements = commitments .iter() .map(|commitment| Element::from_bytes_unchecked_uncompressed(*commitment)) @@ -275,6 +314,14 @@ pub fn hash_commitments(commitments: &[CommitmentBytes]) -> Vec { .collect() } +/// Parallel implementation of commitment hashing using rayon +fn hash_commitments_impl_parallel(commitments: &[CommitmentBytes]) -> Vec { + commitments + .par_iter() + .map(|commitment| hash_commitment(*commitment)) + .collect() +} + /// Receives a tuple (C_i, f_i(X), z_i, y_i) /// /// Where C_i is a commitment to f_i(X) serialized as 32 bytes @@ -746,3 +793,124 @@ mod prover_verifier_test { assert!(verified); } } + +#[cfg(test)] +mod parallel_hash_tests { + use super::*; + use ipa_multipoint::committer::Committer; + + fn create_test_commitments(count: usize) -> Vec { + let context = Context::new(); + (0..count) + .map(|i| { + // Create a unique commitment for each index + let scalar = banderwagon::Fr::from(i as u128 + 1); + let element = context.committer.scalar_mul(scalar, 0); + element.to_bytes_uncompressed() + }) + .collect() + } + + #[test] + fn test_parallel_hash_determinism() { + // Test that parallel hashing produces the same results as sequential + let commitments = create_test_commitments(200); + + let sequential = hash_commitments_parallel(&commitments, false); + let parallel = hash_commitments_parallel(&commitments, true); + + assert_eq!( + sequential, parallel, + "Parallel and sequential hashing should produce identical results" + ); + } + + #[test] + fn test_parallel_hash_threshold_behavior() { + // Below threshold (< 100): should use sequential + let small_batch = create_test_commitments(50); + let result1 = hash_commitments(&small_batch); + let result2 = hash_commitments_sequential(&small_batch); + assert_eq!( + result1, result2, + "Small batch should use sequential implementation" + ); + + // At or above threshold (>= 100): should use parallel + let large_batch = create_test_commitments(150); + let result3 = hash_commitments(&large_batch); + let result4 = hash_commitments_impl_parallel(&large_batch); + assert_eq!( + result3, result4, + "Large batch should use parallel implementation" + ); + } + + #[test] + fn test_parallel_hash_explicit_control() { + let commitments = create_test_commitments(50); + + // Force parallel even with small batch + let parallel_result = hash_commitments_parallel(&commitments, true); + let sequential_result = hash_commitments_parallel(&commitments, false); + + assert_eq!( + parallel_result, sequential_result, + "Explicit parallel control should still produce same results" + ); + } + + #[test] + fn test_parallel_hash_empty_input() { + let empty: Vec = vec![]; + + let sequential = hash_commitments_parallel(&empty, false); + let parallel = hash_commitments_parallel(&empty, true); + + assert!(sequential.is_empty()); + assert!(parallel.is_empty()); + } + + #[test] + fn test_parallel_hash_single_item() { + let single = create_test_commitments(1); + + let sequential = hash_commitments_parallel(&single, false); + let parallel = hash_commitments_parallel(&single, true); + + assert_eq!(sequential.len(), 1); + assert_eq!(parallel.len(), 1); + assert_eq!(sequential[0], parallel[0]); + } + + #[test] + fn test_parallel_hash_large_batch() { + // Test with a large batch to verify parallel processing works correctly + let large_batch = create_test_commitments(500); + + let sequential = hash_commitments_parallel(&large_batch, false); + let parallel = hash_commitments_parallel(&large_batch, true); + + assert_eq!(sequential.len(), 500); + assert_eq!(parallel.len(), 500); + assert_eq!( + sequential, parallel, + "Large batch parallel hashing should match sequential" + ); + } + + #[test] + fn test_hash_commitments_matches_individual() { + // Verify that batch hashing produces same results as individual hashing + let commitments = create_test_commitments(10); + + let batch_results = hash_commitments(&commitments); + let individual_results: Vec = + commitments.iter().map(|c| hash_commitment(*c)).collect(); + + assert_eq!( + batch_results, individual_results, + "Batch hashing should produce same results as individual hashing" + ); + } +} diff --git a/ipa-multipoint/Cargo.toml b/ipa-multipoint/Cargo.toml index c84a7cf..bfa686c 100644 --- a/ipa-multipoint/Cargo.toml +++ b/ipa-multipoint/Cargo.toml @@ -12,6 +12,7 @@ itertools = "0.10.1" sha2 = "0.9.8" rayon = "1.8.0" hex = "0.4.3" +thiserror = "1.0" [dev-dependencies] criterion = "0.5.1" diff --git a/ipa-multipoint/Readme.md b/ipa-multipoint/Readme.md index bb8a7c8..0cc2c41 100644 --- a/ipa-multipoint/Readme.md +++ b/ipa-multipoint/Readme.md @@ -15,7 +15,9 @@ This library uses the banderwagon prime group (https://hackmd.io/@6iQDuIePQjyYBq ## Efficiency -- Parallelism is not being used +- The `MultiPoint::open()` function uses Rayon for parallel processing of large batches (>100 grouped queries) +- Parallel paths are used for query aggregation, division mapping, and polynomial scaling +- For small batches, sequential processing is used to avoid thread pool overhead - We have not modified pippenger to take benefit of the GLV endomorphism ## API @@ -78,4 +80,8 @@ New benchmark on banderwagon subgroup: Apple M1 Pro 16GB RAM -*These benchmarks are tentative because on one hand, the machine being used may not be the what the average user uses, while on the other hand, we have not optimised the verifier algorithm to remove `bH` , the pippenger algorithm does not take into consideration GLV and we are not using rayon to parallelise.* +*These benchmarks are tentative because on one hand, the machine being used may not be the what the average user uses, while on the other hand, we have not optimised the verifier algorithm to remove `bH` and the pippenger algorithm does not take into consideration GLV.* + +### Parallel Performance + +The multiproof prover now uses Rayon parallelization for large batches (>100 grouped queries). For batches of 16K+ queries, expect 4-8x speedup on multi-core systems. The parallel threshold of 100 queries balances parallelization overhead against performance gains. You can control the number of threads using the `RAYON_NUM_THREADS` environment variable. diff --git a/ipa-multipoint/src/crs.rs b/ipa-multipoint/src/crs.rs index d11938c..5f9333e 100644 --- a/ipa-multipoint/src/crs.rs +++ b/ipa-multipoint/src/crs.rs @@ -1,5 +1,21 @@ use crate::{default_crs, ipa::slow_vartime_multiscalar_mul, lagrange_basis::LagrangeBasis}; use banderwagon::{try_reduce_to_element, Element}; +use thiserror::Error; + +/// Size of a single uncompressed point in bytes +pub const UNCOMPRESSED_POINT_SIZE: usize = 64; + +#[derive(Debug, Error)] +pub enum CRSError { + #[error("Invalid CRS byte length: expected {expected}, got {actual}")] + InvalidLength { expected: usize, actual: usize }, + + #[error("Duplicate points detected in CRS at indices {0} and {1}")] + DuplicatePoints(usize, usize), + + #[error("Failed to deserialize point at index {index}: {reason}")] + PointDeserializationError { index: usize, reason: String }, +} #[allow(non_snake_case)] #[derive(Debug, Clone)] @@ -36,18 +52,44 @@ impl CRS { #[allow(non_snake_case)] // The last element is implied to be `Q` - pub fn from_bytes(bytes: &[[u8; 64]]) -> CRS { - let (q_bytes, g_vec_bytes) = bytes - .split_last() - .expect("bytes vector should not be empty"); + pub fn from_bytes(bytes: &[[u8; 64]]) -> Result { + let (q_bytes, g_vec_bytes) = bytes.split_last().ok_or_else(|| CRSError::InvalidLength { + expected: 1, + actual: 0, + })?; + + let Q = Element::try_from_bytes_uncompressed(*q_bytes).map_err(|e| { + CRSError::PointDeserializationError { + index: g_vec_bytes.len(), + reason: e.to_string(), + } + })?; + + let mut G = Vec::with_capacity(g_vec_bytes.len()); + for (index, bytes) in g_vec_bytes.iter().enumerate() { + let point = Element::try_from_bytes_uncompressed(*bytes).map_err(|e| { + CRSError::PointDeserializationError { + index, + reason: e.to_string(), + } + })?; + G.push(point); + } + + // Check for duplicates + Self::validate_no_duplicates(&G)?; - let Q = Element::from_bytes_unchecked_uncompressed(*q_bytes); - let G: Vec<_> = g_vec_bytes - .iter() - .map(|bytes| Element::from_bytes_unchecked_uncompressed(*bytes)) - .collect(); let n = G.len(); - CRS { G, Q, n } + Ok(CRS { G, Q, n }) + } + + /// Deserialize CRS from bytes, panicking on error + /// + /// # Panics + /// Panics if bytes are invalid. Prefer `from_bytes()` which returns Result. + #[deprecated(since = "1.0.0", note = "Use from_bytes() which returns Result")] + pub fn from_bytes_unchecked(bytes: &[[u8; 64]]) -> Self { + Self::from_bytes(bytes).expect("Failed to deserialize CRS") } pub fn from_hex(hex_encoded_crs: &[&str]) -> CRS { let bytes: Vec<[u8; 64]> = hex_encoded_crs @@ -55,7 +97,7 @@ impl CRS { .map(|hex| hex::decode(hex).unwrap()) .map(|byte_vector| byte_vector.try_into().unwrap()) .collect(); - CRS::from_bytes(&bytes) + CRS::from_bytes(&bytes).expect("Failed to deserialize CRS from hex") } pub fn to_bytes(&self) -> Vec<[u8; 64]> { @@ -71,6 +113,18 @@ impl CRS { self.to_bytes().iter().map(hex::encode).collect() } + /// Check that no two points in the CRS are identical + fn validate_no_duplicates(points: &[Element]) -> Result<(), CRSError> { + for i in 0..points.len() { + for j in (i + 1)..points.len() { + if points[i] == points[j] { + return Err(CRSError::DuplicatePoints(i, j)); + } + } + } + Ok(()) + } + // Asserts that not of the points generated are the same fn assert_dedup(points: &[Element]) { use std::collections::HashSet; @@ -154,11 +208,64 @@ fn crs_consistency() { fn load_from_bytes_to_bytes() { let crs = CRS::new(256, b"eth_verkle_oct_2021"); let bytes = crs.to_bytes(); - let crs2 = CRS::from_bytes(&bytes); + let crs2 = CRS::from_bytes(&bytes).expect("should deserialize"); let bytes2 = crs2.to_bytes(); - let hex: Vec<_> = bytes.iter().map(hex::encode).collect(); - dbg!(hex); - assert_eq!(bytes, bytes2, "bytes should be the same"); } + +#[cfg(test)] +mod error_handling_tests { + use super::*; + + /// Test: Valid CRS bytes deserialize successfully + #[test] + fn test_from_bytes_valid() { + let crs = CRS::new(256, b"test_seed"); + let bytes = crs.to_bytes(); + + let restored = CRS::from_bytes(&bytes); + assert!(restored.is_ok()); + assert_eq!(restored.unwrap().G.len(), 256); + } + + /// Test: Empty bytes returns InvalidLength error + #[test] + fn test_from_bytes_empty() { + let result = CRS::from_bytes(&[]); + assert!(matches!( + result, + Err(CRSError::InvalidLength { + expected: 1, + actual: 0 + }) + )); + } + + /// Test: Corrupted point returns PointDeserializationError + #[test] + fn test_from_bytes_corrupted_point() { + // Create array of invalid point data (all 0xFF bytes) + let corrupted_bytes: Vec<[u8; 64]> = vec![[0xFF; 64]; 257]; // 256 G points + 1 Q point + + let result = CRS::from_bytes(&corrupted_bytes); + // The last point is Q, so it should fail at index 256 + assert!(matches!( + result, + Err(CRSError::PointDeserializationError { index: 256, .. }) + )); + } + + /// Test: Duplicate points returns DuplicatePoints error + #[test] + fn test_from_bytes_duplicate_points() { + let crs = CRS::new(256, b"test_seed"); + let mut bytes = crs.to_bytes(); + + // Copy first point to second position + bytes[1] = bytes[0]; + + let result = CRS::from_bytes(&bytes); + assert!(matches!(result, Err(CRSError::DuplicatePoints(0, 1)))); + } +} diff --git a/ipa-multipoint/src/multiproof.rs b/ipa-multipoint/src/multiproof.rs index 03d3720..05551ce 100644 --- a/ipa-multipoint/src/multiproof.rs +++ b/ipa-multipoint/src/multiproof.rs @@ -9,9 +9,42 @@ use crate::math_utils::powers_of; use crate::transcript::Transcript; use crate::transcript::TranscriptProtocol; +use rayon::prelude::*; use std::collections::HashMap; use banderwagon::{trait_defs::*, Element, Fr}; + +/// Minimum number of grouped queries before using parallel processing. +/// Below this threshold, sequential processing is faster due to thread pool overhead. +const PARALLEL_QUERY_THRESHOLD: usize = 100; + +/// Aggregate a group of queries evaluated at the same point into a single polynomial. +/// This is the core computation extracted to avoid duplication between parallel/sequential paths. +fn aggregate_query_group( + point: usize, + queries_challenges: Vec<(&ProverQuery, &Fr)>, + poly_size: usize, +) -> (usize, LagrangeBasis) { + let mut aggregated_polynomial = vec![Fr::zero(); poly_size]; + + for (query, challenge) in queries_challenges { + for (result, value) in aggregated_polynomial + .iter_mut() + .zip(query.poly.values().iter()) + { + *result += *value * challenge; + } + } + + (point, LagrangeBasis::new(aggregated_polynomial)) +} + +/// Scale polynomial coefficients by a scalar and wrap in LagrangeBasis. +fn scale_polynomial(poly: LagrangeBasis, scalar: Fr) -> LagrangeBasis { + let term: Vec<_> = poly.values().iter().map(|coeff| scalar * coeff).collect(); + LagrangeBasis::new(term) +} + pub struct MultiPoint; #[derive(Clone, Debug)] @@ -75,50 +108,40 @@ impl MultiPoint { let powers_of_r = powers_of(r, queries.len()); let grouped_queries = group_prover_queries(&queries, &powers_of_r); - - // aggregate all of the queries evaluated at the same point - let aggregated_queries: Vec<_> = grouped_queries - .into_iter() - .map(|(point, queries_challenges)| { - let mut aggregated_polynomial = vec![Fr::zero(); crs.n]; - - let scaled_lagrange_polynomials = - queries_challenges.into_iter().map(|(query, challenge)| { - // scale the polynomial by the challenge - query.poly.values().iter().map(move |x| *x * challenge) - }); - - for poly_mul_challenge in scaled_lagrange_polynomials { - for (result, scaled_poly) in - aggregated_polynomial.iter_mut().zip(poly_mul_challenge) - { - *result += scaled_poly; - } - } - - (point, LagrangeBasis::new(aggregated_polynomial)) - }) - .collect(); - - // Compute g(X) - // - let g_x: LagrangeBasis = aggregated_queries - .iter() - .map(|(point, agg_f_x)| (agg_f_x).divide_by_linear_vanishing(precomp, *point)) - .fold(LagrangeBasis::zero(), |mut res, val| { - res = res + val; - res - }); + let poly_size = crs.n; + let use_parallel = grouped_queries.len() >= PARALLEL_QUERY_THRESHOLD; + + // Aggregate all of the queries evaluated at the same point + let aggregated_queries: Vec<_> = if use_parallel { + grouped_queries + .into_par_iter() + .map(|(point, qc)| aggregate_query_group(point, qc, poly_size)) + .collect() + } else { + grouped_queries + .into_iter() + .map(|(point, qc)| aggregate_query_group(point, qc, poly_size)) + .collect() + }; + + // Compute g(X) = sum of (agg_f_x / (X - point)) for each aggregated query + let g_x: LagrangeBasis = if use_parallel { + aggregated_queries + .par_iter() + .map(|(point, agg_f_x)| agg_f_x.divide_by_linear_vanishing(precomp, *point)) + .reduce(LagrangeBasis::zero, |a, b| a + b) + } else { + aggregated_queries + .iter() + .map(|(point, agg_f_x)| agg_f_x.divide_by_linear_vanishing(precomp, *point)) + .fold(LagrangeBasis::zero(), |res, val| res + val) + }; let g_x_comm = crs.commit_lagrange_poly(&g_x); transcript.append_point(b"D", &g_x_comm); // 2. Compute g_1(t) - // - // let t = transcript.challenge_scalar(b"t"); - // - // let mut g1_den: Vec<_> = aggregated_queries .iter() @@ -126,22 +149,20 @@ impl MultiPoint { .collect(); batch_inversion(&mut g1_den); - let g1_x = aggregated_queries - .into_iter() - .zip(g1_den) - .map(|((_, agg_f_x), den_inv)| { - let term: Vec<_> = agg_f_x - .values() - .iter() - .map(|coeff| den_inv * coeff) - .collect(); - - LagrangeBasis::new(term) - }) - .fold(LagrangeBasis::zero(), |mut res, val| { - res = res + val; - res - }); + // Compute g1_x = sum of (agg_f_x * den_inv) for each aggregated query + let g1_x = if use_parallel { + aggregated_queries + .into_par_iter() + .zip(g1_den.into_par_iter()) + .map(|((_, agg_f_x), den_inv)| scale_polynomial(agg_f_x, den_inv)) + .reduce(LagrangeBasis::zero, |a, b| a + b) + } else { + aggregated_queries + .into_iter() + .zip(g1_den) + .map(|((_, agg_f_x), den_inv)| scale_polynomial(agg_f_x, den_inv)) + .fold(LagrangeBasis::zero(), |res, val| res + val) + }; let g1_comm = crs.commit_lagrange_poly(&g1_x); transcript.append_point(b"E", &g1_comm); @@ -532,3 +553,166 @@ fn multiproof_consistency() { let expected = "4f53588244efaf07a370ee3f9c467f933eed360d4fbf7a19dfc8bc49b67df4711bf1d0a720717cd6a8c75f1a668cb7cbdd63b48c676b89a7aee4298e71bd7f4013d7657146aa9736817da47051ed6a45fc7b5a61d00eb23e5df82a7f285cc10e67d444e91618465ca68d8ae4f2c916d1942201b7e2aae491ef0f809867d00e83468fb7f9af9b42ede76c1e90d89dd789ff22eb09e8b1d062d8a58b6f88b3cbe80136fc68331178cd45a1df9496ded092d976911b5244b85bc3de41e844ec194256b39aeee4ea55538a36139211e9910ad6b7a74e75d45b869d0a67aa4bf600930a5f760dfb8e4df9938d1f47b743d71c78ba8585e3b80aba26d24b1f50b36fa1458e79d54c05f58049245392bc3e2b5c5f9a1b99d43ed112ca82b201fb143d401741713188e47f1d6682b0bf496a5d4182836121efff0fd3b030fc6bfb5e21d6314a200963fe75cb856d444a813426b2084dfdc49dca2e649cb9da8bcb47859a4c629e97898e3547c591e39764110a224150d579c33fb74fa5eb96427036899c04154feab5344873d36a53a5baefd78c132be419f3f3a8dd8f60f72eb78dd5f43c53226f5ceb68947da3e19a750d760fb31fa8d4c7f53bfef11c4b89158aa56b1f4395430e16a3128f88e234ce1df7ef865f2d2c4975e8c82225f578310c31fd41d265fd530cbfa2b8895b228a510b806c31dff3b1fa5c08bffad443d567ed0e628febdd22775776e0cc9cebcaea9c6df9279a5d91dd0ee5e7a0434e989a160005321c97026cb559f71db23360105460d959bcdf74bee22c4ad8805a1d497507"; assert_eq!(got, expected) } + +#[test] +fn parallel_multiproof_determinism() { + // Test that parallel proof generation produces deterministic results + use ark_std::One; + + let n = 256; + let crs = CRS::new(n, b"eth_verkle_oct_2021"); + let precomp = PrecomputedWeights::new(n); + + // Create a batch of polynomials + let poly_a: Vec = (0..n).map(|i| Fr::from(((i % 32) + 1) as u128)).collect(); + let polynomial_a = LagrangeBasis::new(poly_a); + let poly_b: Vec = (0..n) + .rev() + .map(|i| Fr::from(((i % 32) + 1) as u128)) + .collect(); + let polynomial_b = LagrangeBasis::new(poly_b); + let poly_c: Vec = (0..n).map(|i| Fr::from((i * 2 + 1) as u128)).collect(); + let polynomial_c = LagrangeBasis::new(poly_c); + + let poly_comm_a = crs.commit_lagrange_poly(&polynomial_a); + let poly_comm_b = crs.commit_lagrange_poly(&polynomial_b); + let poly_comm_c = crs.commit_lagrange_poly(&polynomial_c); + + let queries = vec![ + ProverQuery { + commitment: poly_comm_a, + poly: polynomial_a.clone(), + point: 0, + result: Fr::one(), + }, + ProverQuery { + commitment: poly_comm_b, + poly: polynomial_b.clone(), + point: 1, + result: polynomial_b.evaluate_in_domain(1), + }, + ProverQuery { + commitment: poly_comm_c, + poly: polynomial_c.clone(), + point: 2, + result: polynomial_c.evaluate_in_domain(2), + }, + ]; + + // Generate proof multiple times and verify identical results + let mut first_proof_bytes: Option> = None; + for _ in 0..3 { + let mut transcript = Transcript::new(b"determinism_test"); + let proof = MultiPoint::open(crs.clone(), &precomp, &mut transcript, queries.clone()); + let proof_bytes = proof.to_bytes().unwrap(); + + match &first_proof_bytes { + Some(expected) => assert_eq!( + &proof_bytes, expected, + "Parallel proof generation should be deterministic" + ), + None => first_proof_bytes = Some(proof_bytes), + } + } +} + +#[test] +fn parallel_large_batch_verification() { + // Test parallel processing with a larger batch that would trigger parallel paths + + let n = 256; + let crs = CRS::new(n, b"eth_verkle_oct_2021"); + let precomp = PrecomputedWeights::new(n); + + // Create multiple polynomials with different evaluation points + let num_queries = 10; + let mut queries = Vec::with_capacity(num_queries); + let mut verifier_queries = Vec::with_capacity(num_queries); + + for i in 0..num_queries { + let poly: Vec = (0..n) + .map(|j| Fr::from(((j + i) % 256 + 1) as u128)) + .collect(); + let polynomial = LagrangeBasis::new(poly); + let point = i % n; + let result = polynomial.evaluate_in_domain(point); + let commitment = crs.commit_lagrange_poly(&polynomial); + + queries.push(ProverQuery { + commitment, + poly: polynomial, + point, + result, + }); + + verifier_queries.push(VerifierQuery { + commitment, + point: Fr::from(point as u128), + result, + }); + } + + let mut prover_transcript = Transcript::new(b"large_batch"); + let proof = MultiPoint::open(crs.clone(), &precomp, &mut prover_transcript, queries); + + let mut verifier_transcript = Transcript::new(b"large_batch"); + assert!( + proof.check(&crs, &precomp, &verifier_queries, &mut verifier_transcript), + "Large batch multiproof should verify correctly" + ); +} + +#[test] +fn single_thread_pool_safety() { + // Test that proof generation works correctly with a single-threaded pool + // This simulates RAYON_NUM_THREADS=1 environment + use rayon::ThreadPoolBuilder; + + let n = 256; + let crs = CRS::new(n, b"eth_verkle_oct_2021"); + let precomp = PrecomputedWeights::new(n); + + // Create test queries + let poly: Vec = (0..n).map(|i| Fr::from((i + 1) as u128)).collect(); + let polynomial = LagrangeBasis::new(poly); + let commitment = crs.commit_lagrange_poly(&polynomial); + + let queries: Vec = (0..5) + .map(|i| { + let point = i % n; + ProverQuery { + commitment, + poly: polynomial.clone(), + point, + result: polynomial.evaluate_in_domain(point), + } + }) + .collect(); + + let verifier_queries: Vec = queries + .iter() + .map(|q| VerifierQuery { + commitment: q.commitment, + point: Fr::from(q.point as u128), + result: q.result, + }) + .collect(); + + // Build a single-thread pool and run the proof generation inside it + let pool = ThreadPoolBuilder::new() + .num_threads(1) + .build() + .expect("Failed to build single-thread pool"); + + let proof = pool.install(|| { + let mut transcript = Transcript::new(b"single_thread_test"); + MultiPoint::open(crs.clone(), &precomp, &mut transcript, queries) + }); + + // Verify the proof works correctly + let mut verifier_transcript = Transcript::new(b"single_thread_test"); + assert!( + proof.check(&crs, &precomp, &verifier_queries, &mut verifier_transcript), + "Single-threaded proof should verify correctly" + ); +} diff --git a/verkle-trie/Cargo.toml b/verkle-trie/Cargo.toml index dacb9e4..b8a5cd2 100644 --- a/verkle-trie/Cargo.toml +++ b/verkle-trie/Cargo.toml @@ -30,6 +30,10 @@ serde_json = "1.0" criterion = "0.5.1" tempfile = "3.2.0" +[features] +default = [] +rocks_db = ["verkle-db/rocks_db"] + [[bench]] name = "benchmark_main" harness = false diff --git a/verkle-trie/src/database.rs b/verkle-trie/src/database.rs index fb927ba..7cf8747 100644 --- a/verkle-trie/src/database.rs +++ b/verkle-trie/src/database.rs @@ -44,6 +44,18 @@ pub trait WriteOnlyHigherDb { // TODO maybe we can return BranchChild, as the previous data could have been a stem or branch_meta // TODO then we can leave it upto the caller on how to deal with it fn insert_branch(&mut self, key: Vec, meta: BranchMeta, _depth: u8) -> Option; + + /// Remove a leaf, returning the old value if it existed + fn delete_leaf(&mut self, key: [u8; 32]) -> Option<[u8; 32]>; + + /// Remove a stem, returning the old metadata if it existed + fn delete_stem(&mut self, stem: [u8; 31]) -> Option; + + /// Remove a branch node, returning the old metadata if it existed + fn delete_branch(&mut self, path: &[u8]) -> Option; + + /// Remove a child reference from a branch node + fn remove_branch_child(&mut self, branch_path: &[u8], child_index: u8); } // Notice that these take self, which effectively forces the implementer diff --git a/verkle-trie/src/database/default.rs b/verkle-trie/src/database/default.rs index c25cf22..667e720 100644 --- a/verkle-trie/src/database/default.rs +++ b/verkle-trie/src/database/default.rs @@ -2,6 +2,7 @@ use super::{ generic::GenericBatchDB, memory_db::MemoryDb, BranchChild, BranchMeta, Flush, ReadOnlyHigherDb, StemMeta, WriteOnlyHigherDb, }; +use crate::constants::VERKLE_NODE_WIDTH; use crate::database::generic::GenericBatchWriter; use std::collections::HashMap; use verkle_db::{BareMetalDiskDb, BareMetalKVDb, BatchDB, BatchWriter}; @@ -26,14 +27,80 @@ pub struct VerkleDb { pub cache: MemoryDb, } -impl BareMetalDiskDb for VerkleDb { +impl VerkleDb { + /// Populate the in-memory cache from persistent storage + /// + /// This method performs a BFS traversal of the trie, loading nodes + /// from persistent storage into the in-memory cache. Only nodes at + /// depth <= CACHE_DEPTH are loaded. + pub fn populate_cache_from_storage(&mut self) { + // Check if root exists in storage + let root_path: Vec = vec![]; + let root_meta = match self.storage.get_branch_meta(&root_path) { + Some(meta) => meta, + None => return, // Empty trie, nothing to load + }; + + // Insert root into cache + self.cache.insert_branch(root_path.clone(), root_meta, 0); + + // BFS queue: (path, depth) + let mut queue: Vec<(Vec, u8)> = vec![(root_path, 0)]; + + while let Some((current_path, current_depth)) = queue.pop() { + // Stop if we've reached cache depth limit + if current_depth >= CACHE_DEPTH { + continue; + } + + // Load all children of this node from storage + let child_depth = current_depth + 1; + + for child_index in 0..VERKLE_NODE_WIDTH { + let mut child_path = current_path.clone(); + child_path.push(child_index as u8); + + // Try to load child from storage + if let Some(child) = self + .storage + .get_branch_child(¤t_path, child_index as u8) + { + match child { + BranchChild::Stem(stem_id) => { + // Load and cache stem metadata + if let Some(stem_meta) = self.storage.get_stem_meta(stem_id) { + self.cache.insert_stem(stem_id, stem_meta, child_depth); + } + self.cache + .add_stem_as_branch_child(child_path, stem_id, child_depth); + } + BranchChild::Branch(branch_meta) => { + // Cache branch and add to queue for further traversal + self.cache + .insert_branch(child_path.clone(), branch_meta, child_depth); + // Continue BFS if not at depth limit + if child_depth < CACHE_DEPTH { + queue.push((child_path, child_depth)); + } + } + } + } + } + } + } +} + +impl BareMetalDiskDb for VerkleDb { fn from_path>(path: P) -> Self { - VerkleDb { + let mut db = VerkleDb { storage: GenericBatchDB::from_path(path), - batch: MemoryDb::new(), cache: MemoryDb::new(), - } + }; + + db.populate_cache_from_storage(); + + db } const DEFAULT_PATH: &'static str = S::DEFAULT_PATH; @@ -187,7 +254,7 @@ impl ReadOnlyHigherDb for VerkleDb { } // Always save in the permanent storage and only save in the memorydb if the depth is <= cache depth -impl WriteOnlyHigherDb for VerkleDb { +impl WriteOnlyHigherDb for VerkleDb { fn insert_leaf(&mut self, key: [u8; 32], value: [u8; 32], depth: u8) -> Option> { if depth <= CACHE_DEPTH { self.cache.insert_leaf(key, value, depth); @@ -222,4 +289,442 @@ impl WriteOnlyHigherDb for VerkleDb { } self.batch.insert_branch(key, meta, depth) } + + fn delete_leaf(&mut self, key: [u8; 32]) -> Option<[u8; 32]> { + // Try cache first, then batch, then storage + // We need to check all layers and return the most recent value + let cache_val = self.cache.delete_leaf(key); + let batch_val = self.batch.delete_leaf(key); + + // Return the value from the most recent layer that had it + // Priority: cache > batch > storage + if cache_val.is_some() { + return cache_val; + } + if batch_val.is_some() { + return batch_val; + } + // For storage, we just get the value (actual deletion happens on flush) + // We mark for deletion by not having it in batch after this + self.storage.get_leaf(key) + } + + fn delete_stem(&mut self, stem: [u8; 31]) -> Option { + let cache_val = self.cache.delete_stem(stem); + let batch_val = self.batch.delete_stem(stem); + + if cache_val.is_some() { + return cache_val; + } + if batch_val.is_some() { + return batch_val; + } + self.storage.get_stem_meta(stem) + } + + fn delete_branch(&mut self, path: &[u8]) -> Option { + let cache_val = self.cache.delete_branch(path); + let batch_val = self.batch.delete_branch(path); + + if cache_val.is_some() { + return cache_val; + } + if batch_val.is_some() { + return batch_val; + } + self.storage.get_branch_meta(path) + } + + fn remove_branch_child(&mut self, branch_path: &[u8], child_index: u8) { + self.cache.remove_branch_child(branch_path, child_index); + self.batch.remove_branch_child(branch_path, child_index); + // Storage deletion happens implicitly when the data is not in batch on flush + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Test: Cache depth constant is correctly defined + #[test] + fn test_cache_depth_constant() { + assert_eq!(CACHE_DEPTH, 4); + } + + /// Test: VERKLE_NODE_WIDTH is 256 (needed for BFS traversal) + #[test] + fn test_verkle_node_width() { + assert_eq!(VERKLE_NODE_WIDTH, 256); + } + + /// Test: Empty storage results in empty cache + #[test] + fn test_empty_memory_db() { + let cache = MemoryDb::new(); + assert!(cache.branch_table.is_empty()); + assert!(cache.stem_table.is_empty()); + assert!(cache.leaf_table.is_empty()); + } +} + +#[cfg(all(test, feature = "rocks_db"))] +mod rocksdb_loading_tests { + use super::*; + use crate::{DefaultConfig, TrieTrait}; + use tempfile::TempDir; + use verkle_db::RocksDb; + + /// Test: Loading empty database results in empty trie + #[test] + fn test_load_empty_database() { + let temp_dir = TempDir::new().unwrap(); + let db = VerkleDb::::from_path(temp_dir.path()); + let config = DefaultConfig::new(db); + + let trie = crate::trie::Trie::new(config); + + // Empty database should have zero root + assert_eq!(trie.root_commitment(), banderwagon::Element::zero()); + } + + /// Test: Loading populated database restores correct root + #[test] + fn test_load_populated_database_root() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path(); + + // Phase 1: Create and populate trie + let expected_root = { + let db = VerkleDb::::from_path(path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + + let key = [1u8; 32]; + let value = [2u8; 32]; + trie.insert_single(key, value); + trie.root_commitment() + }; + + // Flush changes before reloading + // We need to create a new scope to drop the trie and flush + { + let db = VerkleDb::::from_path(path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + trie.insert_single([1u8; 32], [2u8; 32]); + trie.storage.flush(); + } + + // Phase 2: Reload from same database + let db = VerkleDb::::from_path(path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + + assert_eq!(trie.root_commitment(), expected_root); + } + + /// Test: Cache contains root after load + #[test] + fn test_cache_populated_to_depth() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + + for i in 0..10 { + let mut key = [0u8; 32]; + key[0] = i; + key[1] = i; + trie.insert_single(key, [i; 32]); + } + trie.storage.flush(); + drop(trie); + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + + let db = VerkleDb::::from_path(&path); + + // Root should be in cache after loading from disk + assert!(db.cache.get_branch_meta(&[]).is_some()); + } + + /// Test: Get operation works after reload + #[test] + fn test_get_after_reload() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + let key = [42u8; 32]; + let value = [99u8; 32]; + + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + trie.insert_single(key, value); + trie.storage.flush(); + drop(trie); + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + + let retrieved = trie.get(key); + assert_eq!(retrieved, Some(value)); + } + + /// Test: Proof generation works after reload + #[test] + fn test_proof_after_reload() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + let key = [42u8; 32]; + let value = [99u8; 32]; + + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + trie.insert_single(key, value); + trie.storage.flush(); + drop(trie); + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + + let proof = trie.create_verkle_proof(vec![key].into_iter()); + assert!(proof.is_ok()); + } + + /// Test: Multiple inserts, flush, and reopen - the original bug scenario + /// + /// This test reproduces the bug where BranchChild deserialization failed + /// after reopening a persisted database. The issue was: + /// 1. BranchMeta::from_bytes() had wrong validation: `!len == 96` instead of `len != 96` + /// 2. BranchChild serialization didn't include a type discriminator tag + #[test] + fn test_database_persistence_roundtrip() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + + // Store the root hash for comparison + let expected_root_hash: [u8; 32]; + + // Phase 1: Create trie, insert multiple keys, flush to disk + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + + // Insert multiple keys to create a non-trivial trie structure + // with both stems and branches + for i in 0..20u8 { + let mut key = [0u8; 32]; + key[0] = i; + key[31] = i; + let value = [i; 32]; + trie.insert_single(key, value); + } + + // Get the root hash before flush + let root = trie.root_hash(); + use banderwagon::trait_defs::*; + let mut root_bytes = [0u8; 32]; + root.serialize_compressed(&mut root_bytes[..]).unwrap(); + expected_root_hash = root_bytes; + + // Flush to disk + trie.storage.flush(); + } + + // Small delay to ensure RocksDB has flushed to disk + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Phase 2: Reopen database - this is where the bug used to trigger + // The populate_cache_from_storage() call would panic on BranchChild::from_bytes() + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + + // Verify root hash matches + let root = trie.root_hash(); + use banderwagon::trait_defs::*; + let mut root_bytes = [0u8; 32]; + root.serialize_compressed(&mut root_bytes[..]).unwrap(); + assert_eq!(root_bytes, expected_root_hash, "Root hash should match after reopen"); + + // Verify we can retrieve all inserted values + for i in 0..20u8 { + let mut key = [0u8; 32]; + key[0] = i; + key[31] = i; + let expected_value = [i; 32]; + + let retrieved = trie.get(key); + assert_eq!( + retrieved, + Some(expected_value), + "Value for key {} should be retrievable after reopen", + i + ); + } + } + + /// Test: Insert keys with same stem prefix to create stem nodes + #[test] + fn test_persistence_with_stem_nodes() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + + // Phase 1: Create trie with keys that share stem prefix + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + + // Keys with same first 31 bytes (same stem) but different last byte + let stem_prefix = [42u8; 31]; + for suffix in 0..5u8 { + let mut key = [0u8; 32]; + key[..31].copy_from_slice(&stem_prefix); + key[31] = suffix; + let value = [suffix; 32]; + trie.insert_single(key, value); + } + + trie.storage.flush(); + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Phase 2: Reopen and verify stem children + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + + let stem_prefix = [42u8; 31]; + for suffix in 0..5u8 { + let mut key = [0u8; 32]; + key[..31].copy_from_slice(&stem_prefix); + key[31] = suffix; + let expected_value = [suffix; 32]; + + let retrieved = trie.get(key); + assert_eq!( + retrieved, + Some(expected_value), + "Stem child {} should be retrievable after reopen", + suffix + ); + } + } + + /// Test: Multiple flush-reopen cycles + #[test] + fn test_multiple_reopen_cycles() { + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + + // Cycle 1: Initial insert + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + trie.insert_single([1u8; 32], [10u8; 32]); + trie.storage.flush(); + } + + std::thread::sleep(std::time::Duration::from_millis(50)); + + // Cycle 2: Reopen, add more data, flush + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + + // Verify first insert is still there + assert_eq!(trie.get([1u8; 32]), Some([10u8; 32])); + + // Add more data + trie.insert_single([2u8; 32], [20u8; 32]); + trie.storage.flush(); + } + + std::thread::sleep(std::time::Duration::from_millis(50)); + + // Cycle 3: Final verification + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + + assert_eq!(trie.get([1u8; 32]), Some([10u8; 32]), "First insert should persist"); + assert_eq!(trie.get([2u8; 32]), Some([20u8; 32]), "Second insert should persist"); + } + + /// Test that storage contents can be retrieved after reopening + #[test] + fn test_debug_storage_contents() { + use crate::database::generic::{BRANCH_TABLE_MARKER, LEAF_TABLE_MARKER, STEM_TABLE_MARKER}; + + let temp_dir = TempDir::new().unwrap(); + let path = temp_dir.path().to_path_buf(); + let key = [42u8; 32]; + let value = [99u8; 32]; + + // Phase 1: Insert and flush + { + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let mut trie = crate::trie::Trie::new(config); + trie.insert_single(key, value); + trie.storage.flush(); + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + + // Phase 2: Open raw RocksDB and verify contents exist + let raw_db = RocksDb::open_default(&path).unwrap(); + + // Check leaf entry exists + let mut leaf_key = vec![LEAF_TABLE_MARKER]; + leaf_key.extend_from_slice(&key); + let leaf_val = raw_db.get(&leaf_key).unwrap(); + assert!(leaf_val.is_some(), "Leaf entry should exist"); + assert_eq!(leaf_val.unwrap().len(), 32, "Leaf value should be 32 bytes"); + + // Check stem entry exists + let stem_key: [u8; 31] = key[..31].try_into().unwrap(); + let mut stem_db_key = vec![STEM_TABLE_MARKER]; + stem_db_key.extend_from_slice(&stem_key); + let stem_val = raw_db.get(&stem_db_key).unwrap(); + assert!(stem_val.is_some(), "Stem entry should exist"); + // 3 points * 64 bytes + 3 scalars * 32 bytes = 288 bytes + assert_eq!(stem_val.unwrap().len(), 288, "Stem meta should be 288 bytes"); + + // Check branch root entry exists + let root_key = vec![BRANCH_TABLE_MARKER]; + let root_val = raw_db.get(&root_key).unwrap(); + assert!(root_val.is_some(), "Branch root entry should exist"); + // 1 tag byte + 64 bytes point + 32 bytes scalar = 97 bytes + assert_eq!(root_val.unwrap().len(), 97, "Branch meta should be 97 bytes"); + + // Reopen and verify data can be retrieved + drop(raw_db); + let db = VerkleDb::::from_path(&path); + let config = DefaultConfig::new(db); + let trie = crate::trie::Trie::new(config); + let retrieved = trie.get(key); + assert_eq!(retrieved, Some(value), "Value should be retrievable after reopen"); + } } diff --git a/verkle-trie/src/database/generic.rs b/verkle-trie/src/database/generic.rs index 4940a34..393dc43 100644 --- a/verkle-trie/src/database/generic.rs +++ b/verkle-trie/src/database/generic.rs @@ -46,7 +46,9 @@ impl WriteOnlyHigherDb for GenericBatchWriter { labelled_key.push(BRANCH_TABLE_MARKER); labelled_key.extend(branch_child_id); - self.inner.batch_put(&labelled_key, &stem_id); + let branch_child = BranchChild::Stem(stem_id); + self.inner + .batch_put(&labelled_key, &branch_child.to_bytes().unwrap()); None } @@ -55,10 +57,31 @@ impl WriteOnlyHigherDb for GenericBatchWriter { labelled_key.push(BRANCH_TABLE_MARKER); labelled_key.extend_from_slice(&key); + let branch_child = BranchChild::Branch(meta); self.inner - .batch_put(&labelled_key, &meta.to_bytes().unwrap()); + .batch_put(&labelled_key, &branch_child.to_bytes().unwrap()); + None + } + + fn delete_leaf(&mut self, _key: [u8; 32]) -> Option<[u8; 32]> { + // BatchWriter only writes, cannot read previous values + // Deletion is handled by not writing the key in the batch + None + } + + fn delete_stem(&mut self, _stem: [u8; 31]) -> Option { + // BatchWriter only writes, cannot read previous values None } + + fn delete_branch(&mut self, _path: &[u8]) -> Option { + // BatchWriter only writes, cannot read previous values + None + } + + fn remove_branch_child(&mut self, _branch_path: &[u8], _child_index: u8) { + // BatchWriter only writes, deletion is handled by not writing + } } // This struct allows us to provide a default implementation of ReadOnlyHigherDB to @@ -140,9 +163,11 @@ impl ReadOnlyHigherDb for GenericBatchDB { labelled_key.push(BRANCH_TABLE_MARKER); labelled_key.extend_from_slice(key); - self.inner - .fetch(&labelled_key) - .map(|old_val_bytes| BranchMeta::from_bytes(old_val_bytes).unwrap()) + self.inner.fetch(&labelled_key).and_then(|old_val_bytes| { + BranchChild::from_bytes(old_val_bytes) + .ok() + .and_then(|bc| bc.branch()) + }) } fn get_branch_child(&self, branch_id: &[u8], index: u8) -> Option { diff --git a/verkle-trie/src/database/memory_db.rs b/verkle-trie/src/database/memory_db.rs index 3f10f7d..666d22a 100644 --- a/verkle-trie/src/database/memory_db.rs +++ b/verkle-trie/src/database/memory_db.rs @@ -139,6 +139,28 @@ impl WriteOnlyHigherDb for MemoryDb { self.branch_table .insert(branch_child_id, BranchChild::Stem(stem_id)) } + + fn delete_leaf(&mut self, key: [u8; 32]) -> Option<[u8; 32]> { + self.leaf_table.remove(&key) + } + + fn delete_stem(&mut self, stem: [u8; 31]) -> Option { + self.stem_table.remove(&stem) + } + + fn delete_branch(&mut self, path: &[u8]) -> Option { + match self.branch_table.remove(&path.to_vec()) { + Some(BranchChild::Branch(meta)) => Some(meta), + Some(BranchChild::Stem(_)) => None, + None => None, + } + } + + fn remove_branch_child(&mut self, branch_path: &[u8], child_index: u8) { + let mut child_key = branch_path.to_vec(); + child_key.push(child_index); + self.branch_table.remove(&child_key); + } } impl Flush for MemoryDb { diff --git a/verkle-trie/src/database/meta.rs b/verkle-trie/src/database/meta.rs index bb69a0f..f8512d7 100644 --- a/verkle-trie/src/database/meta.rs +++ b/verkle-trie/src/database/meta.rs @@ -69,9 +69,9 @@ impl FromBytes> for StemMeta { // not structured properly. We can guarantee this in verkle trie. fn from_bytes(bytes: Vec) -> Result { let len = bytes.len(); - // TODO: Explain where this number comes from + // 3 points * 64 bytes + 3 scalars * 32 bytes = 288 bytes total if len != 64 * 3 + 32 * 3 { - return Err(SerializationError::InvalidData); // TODO not the most accurate error msg for now + return Err(SerializationError::InvalidData); } let point_bytes = &bytes[0..64 * 3]; @@ -147,7 +147,7 @@ use crate::from_to_bytes::{FromBytes, ToBytes}; impl FromBytes> for BranchMeta { fn from_bytes(bytes: Vec) -> Result { let len = bytes.len(); - if !len == 32 + 64 { + if len != 32 + 64 { return Err(SerializationError::InvalidData); } @@ -222,22 +222,51 @@ pub enum BranchChild { Branch(BranchMeta), } +// Type discriminator bytes for BranchChild serialization +const BRANCH_CHILD_STEM_TAG: u8 = 0x00; +const BRANCH_CHILD_BRANCH_TAG: u8 = 0x01; + impl ToBytes> for BranchChild { fn to_bytes(&self) -> Result, SerializationError> { match self { - BranchChild::Stem(stem_id) => Ok(stem_id.to_vec()), - BranchChild::Branch(bm) => Ok(bm.to_bytes().unwrap().to_vec()), + BranchChild::Stem(stem_id) => { + let mut bytes = Vec::with_capacity(1 + 31); + bytes.push(BRANCH_CHILD_STEM_TAG); + bytes.extend_from_slice(stem_id); + Ok(bytes) + } + BranchChild::Branch(bm) => { + let mut bytes = Vec::with_capacity(1 + 96); + bytes.push(BRANCH_CHILD_BRANCH_TAG); + bytes.extend(bm.to_bytes()?); + Ok(bytes) + } } } } impl FromBytes> for BranchChild { fn from_bytes(bytes: Vec) -> Result { - if bytes.len() == 31 { - return Ok(BranchChild::Stem(bytes.try_into().unwrap())); + if bytes.is_empty() { + return Err(SerializationError::InvalidData); + } + + let tag = bytes[0]; + let data = &bytes[1..]; + + match tag { + BRANCH_CHILD_STEM_TAG => { + if data.len() != 31 { + return Err(SerializationError::InvalidData); + } + Ok(BranchChild::Stem(data.try_into().unwrap())) + } + BRANCH_CHILD_BRANCH_TAG => { + let branch_meta = BranchMeta::from_bytes(data.to_vec())?; + Ok(BranchChild::Branch(branch_meta)) + } + _ => Err(SerializationError::InvalidData), } - let branch_as_bytes = BranchMeta::from_bytes(bytes)?; - Ok(BranchChild::Branch(branch_as_bytes)) } } @@ -261,3 +290,209 @@ impl BranchChild { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_branch_meta_from_bytes_validates_length() { + // Empty bytes should fail + let result = BranchMeta::from_bytes(vec![]); + assert!(result.is_err()); + + // Wrong length (too short) should fail + let result = BranchMeta::from_bytes(vec![0u8; 50]); + assert!(result.is_err()); + + // Wrong length (too long) should fail + let result = BranchMeta::from_bytes(vec![0u8; 100]); + assert!(result.is_err()); + + // Exactly 31 bytes (stem size) should fail + let result = BranchMeta::from_bytes(vec![0u8; 31]); + assert!(result.is_err()); + } + + #[test] + fn test_branch_meta_serialization_roundtrip() { + let meta = BranchMeta::zero(); + let bytes = meta.to_bytes().unwrap(); + + // Verify correct size: 64 (point) + 32 (scalar) = 96 bytes + assert_eq!(bytes.len(), 96); + + let recovered = BranchMeta::from_bytes(bytes).unwrap(); + assert_eq!(meta, recovered); + } + + #[test] + fn test_branch_child_stem_serialization_roundtrip() { + let stem_id: [u8; 31] = [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, + ]; + let branch_child = BranchChild::Stem(stem_id); + + let bytes = branch_child.to_bytes().unwrap(); + + // Should be 1 byte tag + 31 bytes stem_id = 32 bytes + assert_eq!(bytes.len(), 32); + assert_eq!(bytes[0], BRANCH_CHILD_STEM_TAG); + + let recovered = BranchChild::from_bytes(bytes).unwrap(); + match recovered { + BranchChild::Stem(recovered_stem) => assert_eq!(recovered_stem, stem_id), + BranchChild::Branch(_) => panic!("Expected Stem variant"), + } + } + + #[test] + fn test_branch_child_branch_serialization_roundtrip() { + let meta = BranchMeta::zero(); + let branch_child = BranchChild::Branch(meta); + + let bytes = branch_child.to_bytes().unwrap(); + + // Should be 1 byte tag + 96 bytes BranchMeta = 97 bytes + assert_eq!(bytes.len(), 97); + assert_eq!(bytes[0], BRANCH_CHILD_BRANCH_TAG); + + let recovered = BranchChild::from_bytes(bytes).unwrap(); + match recovered { + BranchChild::Branch(recovered_meta) => assert_eq!(recovered_meta, meta), + BranchChild::Stem(_) => panic!("Expected Branch variant"), + } + } + + #[test] + fn test_branch_child_from_bytes_rejects_invalid_data() { + // Empty bytes should fail + let result = BranchChild::from_bytes(vec![]); + assert!(result.is_err()); + + // Invalid tag should fail + let result = BranchChild::from_bytes(vec![0xFF, 1, 2, 3]); + assert!(result.is_err()); + + // Stem tag with wrong data length should fail + let mut bad_stem = vec![BRANCH_CHILD_STEM_TAG]; + bad_stem.extend_from_slice(&[0u8; 30]); // Only 30 bytes, not 31 + let result = BranchChild::from_bytes(bad_stem); + assert!(result.is_err()); + + // Branch tag with wrong data length should fail + let mut bad_branch = vec![BRANCH_CHILD_BRANCH_TAG]; + bad_branch.extend_from_slice(&[0u8; 50]); // Only 50 bytes, not 96 + let result = BranchChild::from_bytes(bad_branch); + assert!(result.is_err()); + } + + #[test] + fn test_branch_child_type_discriminator_prevents_confusion() { + // Create a stem with 31 bytes + let stem_id: [u8; 31] = [0xAB; 31]; + let stem_child = BranchChild::Stem(stem_id); + let stem_bytes = stem_child.to_bytes().unwrap(); + + // Create a branch + let branch_child = BranchChild::Branch(BranchMeta::zero()); + let branch_bytes = branch_child.to_bytes().unwrap(); + + // Verify they have different tags + assert_ne!(stem_bytes[0], branch_bytes[0]); + + // Verify deserialization produces correct types + let recovered_stem = BranchChild::from_bytes(stem_bytes).unwrap(); + assert!(recovered_stem.stem().is_some()); + assert!(recovered_stem.branch().is_none()); + + let recovered_branch = BranchChild::from_bytes(branch_bytes).unwrap(); + assert!(recovered_branch.branch().is_some()); + assert!(recovered_branch.stem().is_none()); + } + + #[test] + fn test_stem_meta_serialization_roundtrip() { + use banderwagon::trait_defs::*; + + let meta = StemMeta { + c_1: Element::zero(), + hash_c1: Fr::zero(), + c_2: Element::zero(), + hash_c2: Fr::zero(), + stem_commitment: Element::zero(), + hash_stem_commitment: Fr::zero(), + }; + + let bytes = meta.to_bytes().unwrap(); + + // Verify correct size: 3 * 64 (points) + 3 * 32 (scalars) = 288 bytes + assert_eq!(bytes.len(), 288); + + let recovered = StemMeta::from_bytes(bytes).unwrap(); + assert_eq!(meta, recovered); + } + + #[test] + fn test_stem_meta_with_real_point() { + use banderwagon::trait_defs::*; + use banderwagon::Fr; + + // Create a non-trivial point by multiplying generator by a scalar + let generator = Element::prime_subgroup_generator(); + let scalar = Fr::from(12345u64); + let point = generator * scalar; + + let meta = StemMeta { + c_1: point, + hash_c1: scalar, + c_2: Element::zero(), + hash_c2: Fr::zero(), + stem_commitment: Element::zero(), + hash_stem_commitment: Fr::zero(), + }; + + let bytes = meta.to_bytes().unwrap(); + println!("First 64 bytes (c_1): {:?}", &bytes[0..64]); + + let recovered = StemMeta::from_bytes(bytes).unwrap(); + assert_eq!(meta.c_1, recovered.c_1); + assert_eq!(meta.hash_c1, recovered.hash_c1); + } + + #[test] + fn test_element_uncompressed_roundtrip() { + use banderwagon::trait_defs::*; + use banderwagon::Fr; + + // Create a non-trivial point + let generator = Element::prime_subgroup_generator(); + let scalar = Fr::from(12345u64); + let point = generator * scalar; + + // Serialize uncompressed + let mut bytes = [0u8; 64]; + point.serialize_uncompressed(&mut bytes[..]).unwrap(); + println!("Serialized bytes: {:?}", &bytes[..]); + + // Deserialize uncompressed + let recovered = Element::deserialize_uncompressed(&bytes[..]).unwrap(); + assert_eq!(point, recovered); + } + + #[test] + fn test_stem_meta_from_bytes_validates_length() { + // Empty bytes should fail + let result = StemMeta::from_bytes(vec![]); + assert!(result.is_err()); + + // Wrong length should fail + let result = StemMeta::from_bytes(vec![0u8; 100]); + assert!(result.is_err()); + + // Wrong length (close but not exact) should fail + let result = StemMeta::from_bytes(vec![0u8; 287]); + assert!(result.is_err()); + } +} diff --git a/verkle-trie/src/errors.rs b/verkle-trie/src/errors.rs index 8a57b9d..99649ae 100644 --- a/verkle-trie/src/errors.rs +++ b/verkle-trie/src/errors.rs @@ -45,3 +45,15 @@ pub enum ProofCreationError { #[error("Expected to have atleast one query, which will be against the root")] ExpectedOneQueryAgainstRoot, } + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum DeleteError { + #[error("Stem not found for key: {0:?}")] + StemNotFound([u8; 31]), + + #[error("Branch not found at path: {0:?}")] + BranchNotFound(Vec), + + #[error("Failed to update commitment: {0}")] + CommitmentUpdateFailed(String), +} diff --git a/verkle-trie/src/lib.rs b/verkle-trie/src/lib.rs index aeff98a..bd9ae54 100644 --- a/verkle-trie/src/lib.rs +++ b/verkle-trie/src/lib.rs @@ -10,7 +10,7 @@ pub mod trie; pub use config::*; use errors::ProofCreationError; -pub use trie::Trie; +pub use trie::{ParallelInsertConfig, Trie}; pub use banderwagon::{Element, Fr}; diff --git a/verkle-trie/src/proof/stateless_updater.rs b/verkle-trie/src/proof/stateless_updater.rs index 6234eb7..b12c69f 100644 --- a/verkle-trie/src/proof/stateless_updater.rs +++ b/verkle-trie/src/proof/stateless_updater.rs @@ -721,7 +721,6 @@ mod test { group_to_field(&new_root_comm.unwrap()) .serialize_uncompressed(&mut got_bytes[..]) .unwrap(); - dbg!(&got_bytes); for key in keys.into_iter().skip(2) { // skip two keys that are already in the trie @@ -733,7 +732,6 @@ mod test { expected_root .serialize_uncompressed(&mut expected_bytes[..]) .unwrap(); - dbg!(&expected_bytes); assert_eq!(got_bytes, expected_bytes) } } diff --git a/verkle-trie/src/trie.rs b/verkle-trie/src/trie.rs index b03e2d7..e87f2a4 100644 --- a/verkle-trie/src/trie.rs +++ b/verkle-trie/src/trie.rs @@ -725,6 +725,421 @@ impl Trie Trie { + /// Delete a key from the trie + /// + /// Returns `Ok(Some(old_value))` if the key existed and was deleted, + /// `Ok(None)` if the key did not exist, or an error if deletion failed. + /// + /// # Example + /// ```ignore + /// let old_value = trie.delete(key)?; + /// match old_value { + /// Some(v) => println!("Deleted value: {:?}", v), + /// None => println!("Key did not exist"), + /// } + /// ``` + pub fn delete(&mut self, key: [u8; 32]) -> Result, DeleteError> { + // 1. Check if key exists + let old_value = match self.storage.get_leaf(key) { + Some(v) => v, + None => return Ok(None), + }; + + let stem: [u8; 31] = key[..31].try_into().unwrap(); + let suffix = key[31]; + + // 2. Delete the leaf + self.storage.delete_leaf(key); + + // 3. Check if stem is now empty + let stem_children = self.storage.get_stem_children(stem); + let stem_is_empty = stem_children.is_empty(); + + if stem_is_empty { + // Delete stem and prune empty branches + self.delete_stem_and_prune(stem)?; + } else { + // Update stem commitment + self.update_stem_after_delete(stem, suffix, old_value)?; + } + + Ok(Some(old_value)) + } + + /// Update stem commitment after deleting a leaf + /// + /// Algorithm: + /// 1. Determine which commitment to update (C1 for suffix < 128, C2 otherwise) + /// 2. Compute delta = old_scalar * G[2*suffix] + old_scalar_high * G[2*suffix + 1] + /// 3. Subtract delta from C1 or C2 + /// 4. Recompute stem commitment + /// 5. Update parent branch + fn update_stem_after_delete( + &mut self, + stem: [u8; 31], + suffix: u8, + old_value: [u8; 32], + ) -> Result<(), DeleteError> { + let stem_meta = self + .storage + .get_stem_meta(stem) + .ok_or(DeleteError::StemNotFound(stem))?; + + // Split old value into low and high 16-byte chunks + let old_value_low = Fr::from_le_bytes_mod_order(&old_value[..16]) + TWO_POW_128; + let old_value_high = Fr::from_le_bytes_mod_order(&old_value[16..]); + + // Compute which generators to use + let pos_mod_128 = (suffix % 128) as usize; + let low_index = 2 * pos_mod_128; + let high_index = low_index + 1; + + // Compute the commitment delta to subtract + let delta = self.committer.scalar_mul(old_value_low, low_index) + + self.committer.scalar_mul(old_value_high, high_index); + + // Update C1 or C2 + let (new_c1, new_hash_c1, new_c2, new_hash_c2) = if suffix < 128 { + let new_c1 = stem_meta.c_1 - delta; + let new_hash_c1 = group_to_field(&new_c1); + (new_c1, new_hash_c1, stem_meta.c_2, stem_meta.hash_c2) + } else { + let new_c2 = stem_meta.c_2 - delta; + let new_hash_c2 = group_to_field(&new_c2); + (stem_meta.c_1, stem_meta.hash_c1, new_c2, new_hash_c2) + }; + + // Recompute stem commitment: update based on C1/C2 change + let new_stem_commitment = if suffix < 128 { + let c1_delta = new_hash_c1 - stem_meta.hash_c1; + stem_meta.stem_commitment + self.committer.scalar_mul(c1_delta, 2) + } else { + let c2_delta = new_hash_c2 - stem_meta.hash_c2; + stem_meta.stem_commitment + self.committer.scalar_mul(c2_delta, 3) + }; + + let new_hash_stem_commitment = group_to_field(&new_stem_commitment); + + // Get depth for this stem + let depth = self.get_stem_depth(&stem); + + // Update database with new stem metadata + let new_meta = StemMeta { + c_1: new_c1, + hash_c1: new_hash_c1, + c_2: new_c2, + hash_c2: new_hash_c2, + stem_commitment: new_stem_commitment, + hash_stem_commitment: new_hash_stem_commitment, + }; + + self.storage.insert_stem(stem, new_meta, depth); + + // Update parent branch commitment + self.update_parent_after_stem_change( + &stem, + stem_meta.hash_stem_commitment, + new_hash_stem_commitment, + depth, + )?; + + Ok(()) + } + + /// Delete a stem node and prune any resulting empty branches + /// + /// Algorithm: + /// 1. Find path from root to stem + /// 2. Delete stem + /// 3. Remove stem reference from parent branch + /// 4. Walk up the tree, pruning empty branches + fn delete_stem_and_prune(&mut self, stem: [u8; 31]) -> Result<(), DeleteError> { + let stem_meta = self + .storage + .get_stem_meta(stem) + .ok_or(DeleteError::StemNotFound(stem))?; + + // Delete stem from database + self.storage.delete_stem(stem); + + // Get path to stem and remove from parent branch + let depth = self.get_stem_depth(&stem); + let path = stem[..depth as usize].to_vec(); + let child_index = stem[depth as usize]; + + // Remove stem from parent branch and update commitment + self.remove_child_from_branch(&path, child_index, stem_meta.hash_stem_commitment)?; + + // Prune empty branches walking up to root + self.prune_empty_branches(&path)?; + + Ok(()) + } + + /// Remove a child from a branch and update its commitment + fn remove_child_from_branch( + &mut self, + branch_path: &[u8], + child_index: u8, + old_child_hash: Fr, + ) -> Result<(), DeleteError> { + let branch_meta = self + .storage + .get_branch_meta(branch_path) + .ok_or_else(|| DeleteError::BranchNotFound(branch_path.to_vec()))?; + + // Remove child reference + self.storage.remove_branch_child(branch_path, child_index); + + // Compute new commitment: subtract old_child_hash * G[child_index] + let delta = self + .committer + .scalar_mul(old_child_hash, child_index as usize); + let new_commitment = branch_meta.commitment - delta; + let new_hash_commitment = group_to_field(&new_commitment); + + // Update branch + let new_meta = BranchMeta { + commitment: new_commitment, + hash_commitment: new_hash_commitment, + }; + + let depth = branch_path.len() as u8; + self.storage + .insert_branch(branch_path.to_vec(), new_meta, depth); + + // Update parent if not root + if !branch_path.is_empty() { + let parent_path = &branch_path[..branch_path.len() - 1]; + let parent_child_index = branch_path[branch_path.len() - 1]; + self.update_parent_branch( + parent_path, + parent_child_index, + branch_meta.hash_commitment, + new_hash_commitment, + )?; + } + + Ok(()) + } + + /// Prune branches that become empty after deletion + fn prune_empty_branches(&mut self, starting_path: &[u8]) -> Result<(), DeleteError> { + let mut current_path = starting_path.to_vec(); + + while !current_path.is_empty() { + let children = self.storage.get_branch_children(¤t_path); + + if children.is_empty() { + // Branch is empty, delete it + let branch_meta = self + .storage + .get_branch_meta(¤t_path) + .ok_or_else(|| DeleteError::BranchNotFound(current_path.clone()))?; + + self.storage.delete_branch(¤t_path); + + // Remove from parent + let child_index = current_path.pop().unwrap(); + self.remove_child_from_branch( + ¤t_path, + child_index, + branch_meta.hash_commitment, + )?; + } else { + // Branch still has children, stop pruning + break; + } + } + + Ok(()) + } + + /// Get the depth at which a stem exists in the tree + fn get_stem_depth(&self, stem: &[u8; 31]) -> u8 { + // Walk down from root finding the stem + for depth in 0..31u8 { + let path = &stem[..depth as usize]; + let child_index = stem[depth as usize]; + + match self.storage.get_branch_child(path, child_index) { + Some(BranchChild::Stem(s)) if s == *stem => return depth, + Some(BranchChild::Branch(_)) => continue, + _ => return depth, + } + } + 31 + } + + /// Update parent branch after stem commitment changes + fn update_parent_after_stem_change( + &mut self, + stem: &[u8; 31], + old_hash: Fr, + new_hash: Fr, + depth: u8, + ) -> Result<(), DeleteError> { + let path = stem[..depth as usize].to_vec(); + let child_index = stem[depth as usize]; + + self.update_parent_branch(&path, child_index, old_hash, new_hash) + } + + /// Update a parent branch when a child's commitment changes + fn update_parent_branch( + &mut self, + parent_path: &[u8], + child_index: u8, + old_child_hash: Fr, + new_child_hash: Fr, + ) -> Result<(), DeleteError> { + let parent_meta = self + .storage + .get_branch_meta(parent_path) + .ok_or_else(|| DeleteError::BranchNotFound(parent_path.to_vec()))?; + + // Delta update: new = old + (new_hash - old_hash) * G[child_index] + let delta = new_child_hash - old_child_hash; + let delta_commitment = self.committer.scalar_mul(delta, child_index as usize); + let new_commitment = parent_meta.commitment + delta_commitment; + let new_hash_commitment = group_to_field(&new_commitment); + + let new_meta = BranchMeta { + commitment: new_commitment, + hash_commitment: new_hash_commitment, + }; + + let depth = parent_path.len() as u8; + self.storage + .insert_branch(parent_path.to_vec(), new_meta, depth); + + // Recursively update ancestors + if !parent_path.is_empty() { + let grandparent_path = &parent_path[..parent_path.len() - 1]; + let parent_index = parent_path[parent_path.len() - 1]; + self.update_parent_branch( + grandparent_path, + parent_index, + parent_meta.hash_commitment, + new_hash_commitment, + )?; + } + + Ok(()) + } +} +/// Threshold for using parallel processing in insert operations +const PARALLEL_INSERT_THRESHOLD: usize = 100; + +/// Configuration for parallel batch insert +#[derive(Debug, Clone, Default)] +pub struct ParallelInsertConfig { + /// Use parallel processing regardless of batch size + pub force_parallel: bool, +} + +/// Grouped entries for a stem that can be processed together +#[derive(Debug)] +struct StemGroupedEntries { + /// All key-value pairs belonging to this stem + leaves: Vec<([u8; 32], [u8; 32])>, +} + +impl Trie { + /// Insert multiple key-value pairs with optimized batch processing. + /// + /// This method groups entries by stem and processes them more efficiently + /// than inserting individually. For batches of 100+ entries, commitment + /// computations are parallelized using rayon. + /// + /// # Example + /// ```ignore + /// let entries = vec![ + /// ([1u8; 32], [2u8; 32]), + /// ([1u8; 32], [3u8; 32]), + /// ]; + /// trie.insert_parallel(entries); + /// ``` + pub fn insert_parallel(&mut self, entries: I) + where + I: IntoIterator, + { + self.insert_parallel_with_config(entries, ParallelInsertConfig::default()) + } + + /// Insert with explicit parallelism configuration + pub fn insert_parallel_with_config(&mut self, entries: I, config: ParallelInsertConfig) + where + I: IntoIterator, + { + let entries: Vec<_> = entries.into_iter().collect(); + + if entries.is_empty() { + return; + } + + // For small batches or when not forced, use standard insert + if entries.len() < PARALLEL_INSERT_THRESHOLD && !config.force_parallel { + self.insert(entries.into_iter()); + return; + } + + // Group entries by stem + let mut by_stem: HashMap<[u8; 31], Vec<(u8, [u8; 32])>> = HashMap::new(); + for (key, value) in entries { + let stem: [u8; 31] = key[..31].try_into().unwrap(); + let suffix = key[31]; + by_stem.entry(stem).or_default().push((suffix, value)); + } + + // Group entries in parallel (reconstructing full keys) + let stem_groups: Vec = by_stem + .into_par_iter() + .map(|(stem, leaf_entries)| self.group_stem_entries(stem, leaf_entries)) + .collect(); + + // Apply updates sequentially (database writes must be sequential) + for group in stem_groups { + self.apply_stem_group(group); + } + } + + /// Group entries for a stem into full key-value pairs (can run in parallel) + fn group_stem_entries( + &self, + stem: [u8; 31], + entries: Vec<(u8, [u8; 32])>, + ) -> StemGroupedEntries { + let leaves = entries + .into_iter() + .map(|(suffix, value)| { + let mut key = [0u8; 32]; + key[..31].copy_from_slice(&stem); + key[31] = suffix; + (key, value) + }) + .collect(); + + StemGroupedEntries { leaves } + } + + /// Apply a grouped set of entries (must be sequential for database safety) + fn apply_stem_group(&mut self, group: StemGroupedEntries) { + // Insert all leaves for this stem using the standard insert + for (key, value) in group.leaves { + let ins = self.create_insert_instructions(key, value); + self.process_instructions(ins); + } + } +} + // Returns a list of all of the path indices where the two stems // are the same and the next path index where they both differ for each // stem. @@ -1205,4 +1620,349 @@ mod tests { let _val = trie.get(tree_key_code_keccak).unwrap(); let _val = trie.get(tree_key_code_size).unwrap(); } + + // Delete operation tests + + #[test] + fn test_delete_existing_key() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let key = [1u8; 32]; + let value = [2u8; 32]; + + trie.insert_single(key, value); + let result = trie.delete(key); + + assert_eq!(result, Ok(Some(value))); + assert_eq!(trie.get(key), None); + } + + #[test] + fn test_delete_nonexistent_key() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let key = [1u8; 32]; + let result = trie.delete(key); + + assert_eq!(result, Ok(None)); + } + + #[test] + fn test_delete_updates_root() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let key = [1u8; 32]; + let value = [2u8; 32]; + + trie.insert_single(key, value); + let root_before = trie.root_commitment(); + + trie.delete(key).unwrap(); + let root_after = trie.root_commitment(); + + assert_ne!(root_before, root_after); + } + + #[test] + fn test_delete_last_key_removes_stem() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let key = [1u8; 32]; + let stem: [u8; 31] = key[..31].try_into().unwrap(); + + trie.insert_single(key, [2u8; 32]); + assert!(trie.storage.get_stem_meta(stem).is_some()); + + trie.delete(key).unwrap(); + assert!(trie.storage.get_stem_meta(stem).is_none()); + } + + #[test] + fn test_delete_partial_stem() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + // Two keys with same stem, different suffix + let mut key1 = [1u8; 32]; + let mut key2 = [1u8; 32]; + key1[31] = 0; + key2[31] = 1; + + let stem: [u8; 31] = key1[..31].try_into().unwrap(); + + trie.insert_single(key1, [2u8; 32]); + trie.insert_single(key2, [3u8; 32]); + + trie.delete(key1).unwrap(); + + // Stem should still exist with key2 + assert!(trie.storage.get_stem_meta(stem).is_some()); + assert_eq!(trie.get(key2), Some([3u8; 32])); + } + + #[test] + fn test_delete_reinsert_same_root() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let key = [1u8; 32]; + let value = [2u8; 32]; + + trie.insert_single(key, value); + let root_with_key = trie.root_commitment(); + + trie.delete(key).unwrap(); + trie.insert_single(key, value); + let root_after_reinsert = trie.root_commitment(); + + assert_eq!(root_with_key, root_after_reinsert); + } + + #[test] + fn test_delete_to_empty_trie() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let empty_root = trie.root_hash(); + + let key = [0u8; 32]; + let value = [1u8; 32]; + + trie.insert_single(key, value); + trie.delete(key).unwrap(); + + // After deleting the only key, root should be back to empty state + assert_eq!(trie.root_hash(), empty_root); + } + + #[test] + fn test_delete_c2_suffix() { + // Test deleting a key with suffix >= 128 (uses C2 commitment) + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let mut key = [1u8; 32]; + key[31] = 200; // suffix >= 128, uses C2 + let value = [2u8; 32]; + + trie.insert_single(key, value); + let result = trie.delete(key); + + assert_eq!(result, Ok(Some(value))); + assert_eq!(trie.get(key), None); + } + + #[test] + fn test_delete_multiple_stems() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let key1 = [0u8; 32]; + let key2 = [1u8; 32]; + + trie.insert_single(key1, [1u8; 32]); + trie.insert_single(key2, [2u8; 32]); + + // Delete first key + trie.delete(key1).unwrap(); + + // Second key should still exist + assert_eq!(trie.get(key2), Some([2u8; 32])); + assert_eq!(trie.get(key1), None); + + // Delete second key + trie.delete(key2).unwrap(); + assert_eq!(trie.get(key2), None); + } + + // Parallel insert tests + + #[test] + fn test_parallel_insert_produces_same_root_as_sequential() { + // Create two tries with the same entries using different insert methods + let entries: Vec<_> = (0..150) + .map(|i| { + let mut key = [0u8; 32]; + key[0..4].copy_from_slice(&(i as u32).to_le_bytes()); + (key, [i as u8; 32]) + }) + .collect(); + + // Sequential insert + let db1 = MemoryDb::new(); + let mut trie1 = Trie::new(DefaultConfig::new(db1)); + for (key, value) in entries.clone() { + trie1.insert_single(key, value); + } + let root1 = trie1.root_hash(); + + // Parallel insert + let db2 = MemoryDb::new(); + let mut trie2 = Trie::new(DefaultConfig::new(db2)); + trie2.insert_parallel(entries); + let root2 = trie2.root_hash(); + + assert_eq!( + root1, root2, + "Parallel and sequential insert should produce same root" + ); + } + + #[test] + fn test_parallel_insert_small_batch_uses_sequential() { + // Small batch should still work correctly + let entries: Vec<_> = (0..10) + .map(|i| { + let mut key = [0u8; 32]; + key[0] = i; + (key, [i; 32]) + }) + .collect(); + + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + trie.insert_parallel(entries.clone()); + + // Verify all entries exist + for (key, value) in entries { + assert_eq!(trie.get(key), Some(value)); + } + } + + #[test] + fn test_parallel_insert_same_stem_multiple_leaves() { + // Multiple leaves under the same stem + let stem = [1u8; 31]; + let entries: Vec<_> = (0..100) + .map(|i| { + let mut key = [0u8; 32]; + key[..31].copy_from_slice(&stem); + key[31] = i as u8; + (key, [i as u8; 32]) + }) + .collect(); + + // Sequential + let db1 = MemoryDb::new(); + let mut trie1 = Trie::new(DefaultConfig::new(db1)); + for (key, value) in entries.clone() { + trie1.insert_single(key, value); + } + + // Parallel (forced even though below threshold) + let db2 = MemoryDb::new(); + let mut trie2 = Trie::new(DefaultConfig::new(db2)); + trie2.insert_parallel_with_config( + entries.clone(), + super::ParallelInsertConfig { + force_parallel: true, + }, + ); + + assert_eq!( + trie1.root_hash(), + trie2.root_hash(), + "Same stem multiple leaves should produce same root" + ); + + // Verify all entries + for (key, value) in entries { + assert_eq!(trie2.get(key), Some(value)); + } + } + + #[test] + fn test_parallel_insert_different_stems() { + // Many different stems + let entries: Vec<_> = (0..200) + .map(|i| { + let mut key = [0u8; 32]; + // Different stem for each entry + key[0] = (i / 256) as u8; + key[1] = (i % 256) as u8; + key[31] = 0; // same suffix + (key, [(i % 256) as u8; 32]) + }) + .collect(); + + // Sequential + let db1 = MemoryDb::new(); + let mut trie1 = Trie::new(DefaultConfig::new(db1)); + for (key, value) in entries.clone() { + trie1.insert_single(key, value); + } + + // Parallel + let db2 = MemoryDb::new(); + let mut trie2 = Trie::new(DefaultConfig::new(db2)); + trie2.insert_parallel(entries.clone()); + + assert_eq!( + trie1.root_hash(), + trie2.root_hash(), + "Different stems should produce same root with parallel insert" + ); + } + + #[test] + fn test_parallel_insert_empty() { + let db = MemoryDb::new(); + let mut trie = Trie::new(DefaultConfig::new(db)); + + let empty_root = trie.root_hash(); + trie.insert_parallel(vec![]); + + assert_eq!( + trie.root_hash(), + empty_root, + "Empty insert should not change root" + ); + } + + #[test] + fn test_parallel_insert_c1_and_c2() { + // Test entries with suffix < 128 (C1) and suffix >= 128 (C2) + let stem = [5u8; 31]; + let entries: Vec<_> = vec![ + { + let mut key = [0u8; 32]; + key[..31].copy_from_slice(&stem); + key[31] = 50; // C1 + (key, [1u8; 32]) + }, + { + let mut key = [0u8; 32]; + key[..31].copy_from_slice(&stem); + key[31] = 200; // C2 + (key, [2u8; 32]) + }, + ]; + + // Sequential + let db1 = MemoryDb::new(); + let mut trie1 = Trie::new(DefaultConfig::new(db1)); + for (key, value) in entries.clone() { + trie1.insert_single(key, value); + } + + // Parallel (forced) + let db2 = MemoryDb::new(); + let mut trie2 = Trie::new(DefaultConfig::new(db2)); + trie2.insert_parallel_with_config( + entries.clone(), + super::ParallelInsertConfig { + force_parallel: true, + }, + ); + + assert_eq!( + trie1.root_hash(), + trie2.root_hash(), + "C1 and C2 entries should produce same root" + ); + } }