diff --git a/Cargo.toml b/Cargo.toml index a9e38f8..23dbcc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,6 @@ readme = "README.md" repository = "https://github.com/sigp/tree_hash" keywords = ["ethereum"] categories = ["cryptography::cryptocurrencies"] + +[patch.crates-io] +ethereum_ssz = { path = "../ethereum_ssz/ssz" } diff --git a/tree_hash/src/impls.rs b/tree_hash/src/impls.rs index bc9cd59..bdafd55 100644 --- a/tree_hash/src/impls.rs +++ b/tree_hash/src/impls.rs @@ -1,6 +1,6 @@ use super::*; use alloy_primitives::{Address, FixedBytes, U128, U256}; -use ssz::{Bitfield, Fixed, Variable}; +use ssz::{Bitfield, Fixed, Progressive, Variable}; use std::sync::Arc; use typenum::Unsigned; @@ -208,6 +208,38 @@ impl<N: Unsigned + Clone> TreeHash for Bitfield<Variable<N>> { } } +impl TreeHash for Bitfield<Progressive> { + fn tree_hash_type() -> TreeHashType { + TreeHashType::List + } + + fn tree_hash_packed_encoding(&self) -> PackedEncoding { + unreachable!("ProgressiveBitField should never be packed.") + } + + fn tree_hash_packing_factor() -> usize { + unreachable!("ProgressiveBitField should never be packed.") + } + + fn tree_hash_root(&self) -> Hash256 { + // FIXME(sproul): unclear if this is intended or a bug in the spec tests + // See: https://github.com/ethereum/consensus-specs/issues/4795 + if self.is_empty() { + return mix_in_length(&Hash256::ZERO, 0); + } + + let mut hasher = ProgressiveMerkleHasher::new(); + hasher + .write(self.as_slice()) + .expect("ProgressiveBitList should not exceed tree hash leaf limit"); + + let bitfield_root = hasher + .finish() + .expect("ProgressiveBitList tree hash buffer should not exceed leaf limit"); + mix_in_length(&bitfield_root, self.len()) + } +} + impl<N: Unsigned + Clone> TreeHash for Bitfield<Fixed<N>> { fn tree_hash_type() -> TreeHashType { TreeHashType::Vector diff --git a/tree_hash/src/lib.rs b/tree_hash/src/lib.rs index e5ac66f..f9e2278 100644 --- a/tree_hash/src/lib.rs +++ b/tree_hash/src/lib.rs @@ -2,10 +2,14 @@ pub mod impls; mod merkle_hasher; mod merkleize_padded; mod 
merkleize_standard; +mod progressive_merkle_hasher; pub use merkle_hasher::{Error, MerkleHasher}; pub use merkleize_padded::merkleize_padded; pub use merkleize_standard::merkleize_standard; +pub use progressive_merkle_hasher::{ + Error as ProgressiveMerkleHasherError, ProgressiveMerkleHasher, +}; use ethereum_hashing::{hash_fixed, ZERO_HASHES, ZERO_HASHES_MAX_INDEX}; use smallvec::SmallVec; @@ -89,6 +93,13 @@ pub fn mix_in_selector(root: &Hash256, selector: u8) -> Option { Some(Hash256::from_slice(&root)) } +pub fn mix_in_active_fields(root: Hash256, active_fields: [u8; BYTES_PER_CHUNK]) -> Hash256 { + Hash256::from(ethereum_hashing::hash32_concat( + root.as_slice(), + &active_fields, + )) +} + /// Returns a cached padding node for a given height. fn get_zero_hash(height: usize) -> &'static [u8] { if height <= ZERO_HASHES_MAX_INDEX { diff --git a/tree_hash/src/progressive_merkle_hasher.rs b/tree_hash/src/progressive_merkle_hasher.rs new file mode 100644 index 0000000..a6b5ed3 --- /dev/null +++ b/tree_hash/src/progressive_merkle_hasher.rs @@ -0,0 +1,415 @@ +use crate::{Hash256, MerkleHasher, BYTES_PER_CHUNK}; +use ethereum_hashing::hash32_concat; + +#[derive(Clone, Debug, PartialEq)] +pub enum Error { + MerkleHasher(crate::merkle_hasher::Error), +} + +/// A progressive Merkle hasher that implements the semantics of `merkleize_progressive` as +/// defined in EIP-7916. +/// +/// The progressive merkle tree has a unique structure where: +/// - At each level, the right child is a binary merkle tree with a specific number of leaves +/// - The left child recursively contains more progressive structure +/// - The number of leaves in each right subtree grows by 4x at each level (1, 4, 16, 64, ...) 
+/// +/// # Example Tree Structure +/// +/// ```text +/// root +/// /\ +/// / \ +/// /\ 1: chunks[0 ..< 1] +/// / \ +/// /\ 4: chunks[1 ..< 5] +/// / \ +/// /\ 16: chunks[5 ..< 21] +/// / \ +/// 0 64: chunks[21 ..< 85] +/// ``` +/// +/// This structure allows efficient appending and proof generation for growing lists. +/// +/// # Efficiency +/// +/// This implementation hashes chunks as they are streamed in, storing only the minimum +/// necessary state (completed subtree roots). When a level is filled, its binary merkle +/// root is computed and stored, avoiding the need to keep all chunks in memory. +pub struct ProgressiveMerkleHasher { + /// Completed subtree roots at each level, stored in order of completion. + /// Index 0 = first completed level (1 leaf), index 1 = second level (4 leaves), etc. + /// Level i contains 4^i leaves. + completed_roots: Vec, + /// MerkleHasher for computing the current level's binary tree root. + current_hasher: MerkleHasher, + /// The number of leaves expected at the current level (1, 4, 16, 64, ...). + current_level_size: usize, + /// Number of chunks written to the current hasher. + current_level_chunks: usize, + /// Buffer for bytes that haven't been completed into a chunk yet. + buffer: Vec, + /// Total number of chunks written so far. + total_chunks: usize, +} + +impl ProgressiveMerkleHasher { + /// Create a new progressive merkle hasher that can accept any number of chunks. + pub fn new() -> Self { + Self { + completed_roots: Vec::new(), + current_hasher: MerkleHasher::with_leaves(1), + current_level_size: 1, + current_level_chunks: 0, + buffer: Vec::new(), + total_chunks: 0, + } + } + + /// Write bytes to the hasher. + /// + /// The bytes will be split into 32-byte chunks. Bytes are buffered across multiple + /// write calls to ensure proper chunk boundaries. Complete subtrees are hashed + /// immediately as chunks are written. 
+ /// + /// # Errors + /// + /// Returns an error if writing these bytes would exceed the maximum number of leaves. + pub fn write(&mut self, bytes: &[u8]) -> Result<(), Error> { + // Add bytes to buffer + self.buffer.extend_from_slice(bytes); + + // Process complete chunks from buffer + while self.buffer.len() >= BYTES_PER_CHUNK { + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk.copy_from_slice(&self.buffer[..BYTES_PER_CHUNK]); + self.buffer.drain(..BYTES_PER_CHUNK); + + self.process_chunk(chunk)?; + } + + Ok(()) + } + + /// Process a single chunk by adding it to the current level and completing the level if full. + fn process_chunk(&mut self, chunk: [u8; BYTES_PER_CHUNK]) -> Result<(), Error> { + // Write the chunk to the current MerkleHasher + self.current_hasher + .write(&chunk) + .map_err(Error::MerkleHasher)?; + + self.current_level_chunks += 1; + self.total_chunks += 1; + + // Check if current level is complete + if self.current_level_chunks == self.current_level_size { + // Move to next level (4x larger) + let next_level_size = self.current_level_size * 4; + + // Replace the current hasher with a new one for the next level + let completed_hasher = std::mem::replace( + &mut self.current_hasher, + MerkleHasher::with_leaves(next_level_size), + ); + + // Finish the completed hasher to get the root + let root = completed_hasher.finish().map_err(Error::MerkleHasher)?; + + // Store this completed root + self.completed_roots.push(root); + + self.current_level_size = next_level_size; + self.current_level_chunks = 0; + } + + Ok(()) + } + + /// Finish the hasher and return the progressive merkle root. + /// + /// This completes any partial level and combines all completed subtree roots + /// according to the progressive merkleization algorithm. + /// + /// Any remaining bytes in the buffer will be padded to form a final chunk. 
+ pub fn finish(mut self) -> Result<Hash256, Error> { + // Process any remaining bytes in the buffer as a final chunk + if !self.buffer.is_empty() { + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk[..self.buffer.len()].copy_from_slice(&self.buffer); + self.process_chunk(chunk)?; + } + + // If we have no chunks at all, return zero hash + if self.total_chunks == 0 { + return Ok(Hash256::ZERO); + } + + // If there are chunks in current level (partial level), compute their root + let current_root = if self.current_level_chunks > 0 { + // Create a temporary hasher to replace the current one (since finish() takes ownership) + // FIXME(sproul): get rid of this by making build_progressive_root a static method. + let temp_hasher = std::mem::replace( + &mut self.current_hasher, + MerkleHasher::with_leaves(1), // dummy value, won't be used + ); + Some(temp_hasher.finish().map_err(Error::MerkleHasher)?) + } else { + None + }; + + // Build the progressive tree from completed roots and current root + // completed_roots are in order: [smallest level, ..., largest level] + // We need to build from right to left in the tree + Ok(self.build_progressive_root(current_root)) + } + + /// Build the final progressive merkle root by combining completed subtree roots. + /// + /// The progressive tree structure: at each node, hash(left=deeper_levels, right=this_level). + /// This builds the tree from the largest (leftmost) level backwards to the smallest (rightmost). 
 fn build_progressive_root(&self, current_root: Option<Hash256>) -> Hash256 { + // Start from the leftmost (largest/deepest) level + // Per EIP-7916 spec, even partial levels follow the progressive structure: + // merkleize_progressive(chunks, n) = hash(merkleize_progressive(chunks[n:], n*4), merkleize(chunks[:n], n)) + // So a partial level with k chunks becomes: hash(ZERO (no further chunks), merkleize(chunks, n)) + let mut result = if let Some(curr) = current_root { + Hash256::from_slice(&hash32_concat(Hash256::ZERO.as_slice(), curr.as_slice())) + } else { + Hash256::ZERO + }; + + // Process completed roots from largest to smallest (reverse order) + // At each step: result = hash(result, completed_root) + // - result accumulates the left subtree (deeper/larger levels) + // - completed_root is the right subtree at this level + for &completed_root in self.completed_roots.iter().rev() { + result = + Hash256::from_slice(&hash32_concat(result.as_slice(), completed_root.as_slice())); + } + + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::merkle_root; + + #[test] + fn test_empty_tree() { + let hasher = ProgressiveMerkleHasher::new(); + let root = hasher.finish().unwrap(); + assert_eq!(root, Hash256::ZERO); + } + + #[test] + fn test_single_chunk() { + let mut hasher = ProgressiveMerkleHasher::new(); + let chunk = [1u8; BYTES_PER_CHUNK]; + hasher.write(&chunk).unwrap(); + let root = hasher.finish().unwrap(); + + // For a single chunk, the progressive tree should be: + // hash(merkleize_progressive([], 4), merkleize([chunk], 1)) + // = hash(zero_hash, chunk) + let zero_left = Hash256::ZERO; + let right = Hash256::from_slice(&chunk); + let expected = Hash256::from_slice(&hash32_concat(zero_left.as_slice(), right.as_slice())); + + assert_eq!(root, expected); + } + + #[test] + fn test_two_chunks() { + let mut hasher = ProgressiveMerkleHasher::new(); + let chunk1 = [1u8; BYTES_PER_CHUNK]; + let chunk2 = [2u8; BYTES_PER_CHUNK]; + 
hasher.write(&chunk1).unwrap(); + hasher.write(&chunk2).unwrap(); + let root = hasher.finish().unwrap(); + + // First chunk goes to right (num_leaves=1) + // Second chunk goes to left recursive call (num_leaves=4) + + // Right: binary tree with 1 leaf = chunk1 + let right = Hash256::from_slice(&chunk1); + + // Left: progressive tree with chunk2 at num_leaves=4 + // At this level: hash(merkleize_progressive([], 16), merkleize([chunk2], 4)) + // = hash(zero_hash, merkle([chunk2], 4)) + let chunk2_padded = merkle_root(&chunk2, 4); + let zero_left_inner = Hash256::ZERO; + let left = Hash256::from_slice(&hash32_concat( + zero_left_inner.as_slice(), + chunk2_padded.as_slice(), + )); + + let expected = Hash256::from_slice(&hash32_concat(left.as_slice(), right.as_slice())); + assert_eq!(root, expected); + } + + #[test] + fn test_partial_chunk() { + let mut hasher = ProgressiveMerkleHasher::new(); + let partial = vec![1u8, 2u8, 3u8]; + hasher.write(&partial).unwrap(); + let root = hasher.finish().unwrap(); + + // Partial chunk should be padded with zeros + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk[0] = 1; + chunk[1] = 2; + chunk[2] = 3; + + let zero_left = Hash256::ZERO; + let right = Hash256::from_slice(&chunk); + let expected = Hash256::from_slice(&hash32_concat(zero_left.as_slice(), right.as_slice())); + + assert_eq!(root, expected); + } + + #[test] + fn test_multiple_writes() { + let mut hasher = ProgressiveMerkleHasher::new(); + hasher.write(&[1u8; 16]).unwrap(); + hasher.write(&[2u8; 16]).unwrap(); + hasher.write(&[3u8; 32]).unwrap(); + let root = hasher.finish().unwrap(); + + // Should handle multiple writes correctly + assert_ne!(root, Hash256::ZERO); + } + + #[test] + fn test_five_chunks() { + // Test with 5 chunks as per the problem statement structure: + // chunks[0] goes to right at level 1 (1 leaf) + // chunks[1..5] go to left recursive call (4 leaves at level 2) + let mut hasher = ProgressiveMerkleHasher::new(); + for i in 0..5 { + let mut chunk = [0u8; 
BYTES_PER_CHUNK]; + chunk[0] = i as u8; + hasher.write(&chunk).unwrap(); + } + let root = hasher.finish().unwrap(); + + // Manually compute expected root: + // Right: chunks[0] + let mut chunk0 = [0u8; BYTES_PER_CHUNK]; + chunk0[0] = 0; + let right = Hash256::from_slice(&chunk0); + + // Left: merkleize_progressive(chunks[1..5], 4) + // Which is: hash(merkleize_progressive([], 16), merkleize(chunks[1..5], 4)) + let chunks_1_to_4: Vec = (1..5) + .flat_map(|i| { + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk[0] = i; + chunk + }) + .collect(); + let right_inner = merkle_root(&chunks_1_to_4, 4); + let left_inner = Hash256::ZERO; + let left = Hash256::from_slice(&hash32_concat( + left_inner.as_slice(), + right_inner.as_slice(), + )); + + let expected = Hash256::from_slice(&hash32_concat(left.as_slice(), right.as_slice())); + assert_eq!(root, expected); + } + + #[test] + fn test_21_chunks() { + // Test with 21 chunks as per problem statement: + // chunks[0] goes to right at level 1 (1 leaf) + // chunks[1..5] go to right at level 2 (4 leaves) + // chunks[5..21] go to right at level 3 (16 leaves) + let mut hasher = ProgressiveMerkleHasher::new(); + for i in 0..21 { + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk[0] = i as u8; + hasher.write(&chunk).unwrap(); + } + let root = hasher.finish().unwrap(); + + // Root should not be zero + assert_ne!(root, Hash256::ZERO); + } + + #[test] + fn test_85_chunks() { + // Test with 85 chunks as per problem statement structure: + // chunks[0] at level 1 (1 leaf) + // chunks[1..5] at level 2 (4 leaves) + // chunks[5..21] at level 3 (16 leaves) + // chunks[21..85] at level 4 (64 leaves) + let mut hasher = ProgressiveMerkleHasher::new(); + for i in 0..85 { + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk[0] = (i % 256) as u8; + hasher.write(&chunk).unwrap(); + } + let root = hasher.finish().unwrap(); + + // Root should not be zero + assert_ne!(root, Hash256::ZERO); + } + + #[test] + fn test_consistency_across_write_patterns() { + // 
Test that different write patterns produce the same result + let chunks: Vec<[u8; BYTES_PER_CHUNK]> = (0..10) + .map(|i| { + let mut chunk = [0u8; BYTES_PER_CHUNK]; + chunk[0] = i; + chunk + }) + .collect(); + + // Write all chunks individually + let mut hasher1 = ProgressiveMerkleHasher::new(); + for chunk in &chunks { + hasher1.write(chunk).unwrap(); + } + let root1 = hasher1.finish().unwrap(); + + // Write all chunks at once + let mut hasher2 = ProgressiveMerkleHasher::new(); + let all_bytes: Vec = chunks.iter().flat_map(|c| c.iter().copied()).collect(); + hasher2.write(&all_bytes).unwrap(); + let root2 = hasher2.finish().unwrap(); + + // Write in groups + let mut hasher3 = ProgressiveMerkleHasher::new(); + hasher3.write(&all_bytes[..3 * BYTES_PER_CHUNK]).unwrap(); + hasher3 + .write(&all_bytes[3 * BYTES_PER_CHUNK..7 * BYTES_PER_CHUNK]) + .unwrap(); + hasher3.write(&all_bytes[7 * BYTES_PER_CHUNK..]).unwrap(); + let root3 = hasher3.finish().unwrap(); + + assert_eq!(root1, root2); + assert_eq!(root1, root3); + } + + #[test] + fn test_byte_streaming() { + // Test that we can write bytes in various chunk sizes + let data = vec![42u8; BYTES_PER_CHUNK * 3 + 10]; + + // Write all at once + let mut hasher1 = ProgressiveMerkleHasher::new(); + hasher1.write(&data).unwrap(); + let root1 = hasher1.finish().unwrap(); + + // Write in smaller chunks + let mut hasher2 = ProgressiveMerkleHasher::new(); + hasher2.write(&data[0..50]).unwrap(); + hasher2.write(&data[50..]).unwrap(); + let root2 = hasher2.finish().unwrap(); + + assert_eq!(root1, root2); + } +} diff --git a/tree_hash/tests/tests.rs b/tree_hash/tests/tests.rs index 8531548..f9f22e0 100644 --- a/tree_hash/tests/tests.rs +++ b/tree_hash/tests/tests.rs @@ -1,5 +1,6 @@ use alloy_primitives::{Address, U128, U160, U256}; use ssz_derive::Encode; +use std::str::FromStr; use tree_hash::{Hash256, MerkleHasher, PackedEncoding, TreeHash, BYTES_PER_CHUNK}; use tree_hash_derive::TreeHash; @@ -167,3 +168,19 @@ fn 
packed_encoding_example() { ); } } + +#[derive(TreeHash)] +#[tree_hash(struct_behaviour = "progressive_container", active_fields(1))] +struct ProgressiveContainerOneField { + x: u8, +} + +#[test] +fn progressive_container_one_field() { + let container = ProgressiveContainerOneField { x: 125 }; + assert_eq!( + container.tree_hash_root(), + Hash256::from_str("0xfacc8073916cbe1d3e400f69945fb5b6423d1e8f99be04713bcbe254fad2c94c") + .unwrap() + ); +} diff --git a/tree_hash_derive/src/attrs.rs b/tree_hash_derive/src/attrs.rs new file mode 100644 index 0000000..83426d2 --- /dev/null +++ b/tree_hash_derive/src/attrs.rs @@ -0,0 +1,127 @@ +use darling::{ast::NestedMeta, Error, FromDeriveInput, FromMeta}; +use quote::quote; + +pub const MAX_ACTIVE_FIELDS: usize = 256; +pub const ACTIVE_FIELDS_PACKED_BITS_LEN: usize = MAX_ACTIVE_FIELDS / 8; + +#[derive(Debug, FromDeriveInput)] +#[darling(attributes(tree_hash))] +pub struct StructOpts { + #[darling(default)] + pub enum_behaviour: Option, + #[darling(default)] + pub struct_behaviour: Option, + #[darling(default)] + pub active_fields: Option, +} + +#[derive(Debug, FromMeta)] +pub enum EnumBehaviour { + Transparent, + Union, +} + +#[derive(Debug, Default, FromMeta)] +pub enum StructBehaviour { + #[default] + Container, + ProgressiveContainer, +} + +#[derive(Debug)] +pub struct ActiveFields { + pub active_fields: Vec, +} + +impl FromMeta for ActiveFields { + fn from_list(items: &[NestedMeta]) -> Result { + let active_fields = items + .iter() + .map(|nested_meta| match u8::from_nested_meta(nested_meta) { + Ok(0) => Ok(false), + Ok(1) => Ok(true), + Ok(n) => Err(Error::custom(format!( + "invalid integer in active_fields: {n}" + ))), + Err(e) => Err(Error::custom(format!( + "unable to parse active_fields entry: {e:?}" + ))), + }) + .collect::>()?; + Self::new(active_fields) + } +} + +impl ActiveFields { + fn new(active_fields: Vec) -> Result { + if active_fields.is_empty() { + return Err(Error::custom(format!("active_fields must be 
non-empty"))); + } + if active_fields.len() > MAX_ACTIVE_FIELDS { + return Err(Error::custom(format!( + "active_fields cannot contain more than {MAX_ACTIVE_FIELDS} entries" + ))); + } + + if let Some(false) = active_fields.last() { + return Err(Error::custom(format!( + "the last entry of active_fields must not be 0" + ))); + } + + Ok(Self { active_fields }) + } + + pub fn packed(&self) -> [u8; ACTIVE_FIELDS_PACKED_BITS_LEN] { + let mut result = [0; ACTIVE_FIELDS_PACKED_BITS_LEN]; + for (i, bit) in self.active_fields.iter().enumerate() { + if *bit { + result[i / 8] |= 1 << (i % 8); + } + } + result + } + + /// Return tokens for the packed representation of these `active_fields`. + /// + /// We compute the packed representation at compile-time, and then inline it via the output + /// of this function. + pub fn packed_tokens(&self) -> proc_macro2::TokenStream { + let packed = self.packed().to_vec(); + quote! { + [ + #(#packed),* + ] + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn active_fields_packed_basic() { + let active_fields = ActiveFields { + active_fields: vec![true], + }; + assert_eq!( + active_fields.packed(), + [ + 0b0000001, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ] + ); + + let active_fields = ActiveFields { + active_fields: vec![true, false, true, false, false, true], + }; + assert_eq!( + active_fields.packed(), + [ + 0b0100101, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ] + ); + } +} diff --git a/tree_hash_derive/src/lib.rs b/tree_hash_derive/src/lib.rs index 6e6a727..454d0dc 100644 --- a/tree_hash_derive/src/lib.rs +++ b/tree_hash_derive/src/lib.rs @@ -1,4 +1,7 @@ #![recursion_limit = "256"] +mod attrs; + +use crate::attrs::{EnumBehaviour, StructBehaviour, StructOpts}; use darling::FromDeriveInput; use proc_macro::TokenStream; use quote::quote; @@ -9,37 +12,6 @@ use syn::{parse_macro_input, DataEnum, DataStruct, 
DeriveInput, Ident}; /// extensions). const MAX_UNION_SELECTOR: u8 = 127; -#[derive(Debug, FromDeriveInput)] -#[darling(attributes(tree_hash))] -struct StructOpts { - #[darling(default)] - enum_behaviour: Option, -} - -const ENUM_TRANSPARENT: &str = "transparent"; -const ENUM_UNION: &str = "union"; -const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION]; -const NO_ENUM_BEHAVIOUR_ERROR: &str = "enums require an \"enum_behaviour\" attribute, \ - e.g., #[tree_hash(enum_behaviour = \"transparent\")]"; - -enum EnumBehaviour { - Transparent, - Union, -} - -impl EnumBehaviour { - pub fn new(s: Option) -> Option { - s.map(|s| match s.as_ref() { - ENUM_TRANSPARENT => EnumBehaviour::Transparent, - ENUM_UNION => EnumBehaviour::Union, - other => panic!( - "{} is an invalid enum_behaviour, use either {:?}", - other, ENUM_VARIANTS - ), - }) - } -} - /// Return a Vec of `syn::Ident` for each named field in the struct, whilst filtering out fields /// that should not be hashed. /// @@ -82,40 +54,105 @@ fn should_skip_hashing(field: &syn::Field) -> bool { }) } -/// Implements `tree_hash::TreeHash` for some `struct`. +/// Implements `tree_hash::TreeHash` for a type. /// /// Fields are hashed in the order they are defined. 
#[proc_macro_derive(TreeHash, attributes(tree_hash))] pub fn tree_hash_derive(input: TokenStream) -> TokenStream { let item = parse_macro_input!(input as DeriveInput); let opts = StructOpts::from_derive_input(&item).unwrap(); - let enum_opt = EnumBehaviour::new(opts.enum_behaviour); + let enum_opt = opts.enum_behaviour; + let struct_opt = opts.struct_behaviour; - match &item.data { - syn::Data::Struct(s) => { + match (&item.data, enum_opt, struct_opt) { + (syn::Data::Struct(s), enum_opt, struct_opt) => { if enum_opt.is_some() { panic!("enum_behaviour is invalid for structs"); } - tree_hash_derive_struct(&item, s) + let struct_behaviour = struct_opt.unwrap_or_default(); + tree_hash_derive_struct(&item, s, struct_behaviour, opts.active_fields) + } + (syn::Data::Enum(s), Some(enum_behaviour), struct_opt) => { + if struct_opt.is_some() { + panic!("struct_behaviour is invalid for enums"); + } + match enum_behaviour { + EnumBehaviour::Transparent => tree_hash_derive_enum_transparent(&item, s), + EnumBehaviour::Union => tree_hash_derive_enum_union(&item, s), + } } - syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) { - EnumBehaviour::Transparent => tree_hash_derive_enum_transparent(&item, s), - EnumBehaviour::Union => tree_hash_derive_enum_union(&item, s), - }, - _ => panic!("tree_hash_derive only supports structs and enums."), + _ => panic!("tree_hash_derive only supports structs and enums"), } } -fn tree_hash_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream { +fn tree_hash_derive_struct( + item: &DeriveInput, + struct_data: &DataStruct, + struct_behaviour: StructBehaviour, + active_fields_opt: Option, +) -> TokenStream { let name = &item.ident; let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl(); let idents = get_hashable_fields(struct_data); - let num_leaves = idents.len(); + + let hasher_init = if let StructBehaviour::ProgressiveContainer = struct_behaviour { + quote! 
{ tree_hash::ProgressiveMerkleHasher::new() } + } else { + let num_leaves = idents.len(); + quote! { tree_hash::MerkleHasher::with_leaves(#num_leaves) } + }; + + // Compute the field hashes while accounting for inactive fields which hash as 0x0. + // + // The `mixin_logic` is the expression to mix in the `active_fields` in the case of a + // progressive container. + let (field_hashes, mixin_logic) = + if let StructBehaviour::ProgressiveContainer = struct_behaviour { + let Some(active_fields) = active_fields_opt else { + panic!("active_fields must be provided for progressive_container"); + }; + + let mut active_field_index = 0; + let mut field_hashes: Vec = vec![]; + for active in &active_fields.active_fields { + if *active { + let Some(ident) = idents.get(active_field_index) else { + panic!( + "active_fields is inconsistent with struct fields. \ + index: {active_field_index}, hashable fields: {}", + idents.len() + ) + }; + active_field_index += 1; + field_hashes.push(quote! { self.#ident.tree_hash_root() }); + } else { + field_hashes.push(quote! { tree_hash::Hash256::ZERO }); + } + } + + let packed_active_fields = active_fields.packed_tokens(); + + let mixin_logic = quote! { + const ACTIVE_FIELDS: [u8; 32] = #packed_active_fields; + tree_hash::mix_in_active_fields(container_root, ACTIVE_FIELDS) + }; + + (field_hashes, mixin_logic) + } else { + ( + idents + .into_iter() + .map(|ident| quote! { self.#ident.tree_hash_root() }) + .collect(), + quote! { container_root }, + ) + }; let output = quote! 
{ impl #impl_generics tree_hash::TreeHash for #name #ty_generics #where_clause { fn tree_hash_type() -> tree_hash::TreeHashType { + // FIXME(sproul): consider adjusting this with active_fields tree_hash::TreeHashType::Container } @@ -128,14 +165,16 @@ fn tree_hash_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> Toke } fn tree_hash_root(&self) -> tree_hash::Hash256 { - let mut hasher = tree_hash::MerkleHasher::with_leaves(#num_leaves); + let mut hasher = #hasher_init; #( - hasher.write(self.#idents.tree_hash_root().as_slice()) + hasher.write(#field_hashes.as_slice()) .expect("tree hash derive should not apply too many leaves"); )* - hasher.finish().expect("tree hash derive should not have a remaining buffer") + let container_root = hasher.finish().expect("tree hash derive should not have a remaining buffer"); + + #mixin_logic } } };