diff --git a/src/error.rs b/src/error.rs index 499712a..1dc0288 100644 --- a/src/error.rs +++ b/src/error.rs @@ -34,6 +34,7 @@ pub enum Error { IntraRebaseZeroHash, IntraRebaseZeroDepth, IntraRebaseRepeatVisit, + PackedLeavesNoArc, } impl Display for Error { diff --git a/src/interface.rs b/src/interface.rs index f056c2e..ad15ff1 100644 --- a/src/interface.rs +++ b/src/interface.rs @@ -1,3 +1,4 @@ +use crate::ArcIter; use crate::level_iter::LevelIter; use crate::update_map::UpdateMap; use crate::utils::{Length, updated_length}; @@ -22,6 +23,8 @@ pub trait ImmList { fn iter_from(&self, index: usize) -> Iter<'_, T>; fn level_iter_from(&self, index: usize) -> LevelIter<'_, T>; + + fn iter_arc(&self, index: usize) -> Result, Error>; } pub trait MutList: ImmList { @@ -100,6 +103,10 @@ where self.iter_from(0) } + pub fn iter_arc(&self) -> Result, Error> { + self.backing.iter_arc(0) + } + pub fn iter_from(&self, index: usize) -> InterfaceIter<'_, T, U> { InterfaceIter { tree_iter: self.backing.iter_from(index), diff --git a/src/iter_arc.rs b/src/iter_arc.rs new file mode 100644 index 0000000..f055505 --- /dev/null +++ b/src/iter_arc.rs @@ -0,0 +1,157 @@ +use tree_hash::{TreeHash, TreeHashType}; +use triomphe::Arc; + +use crate::{ + Error, Leaf, Tree, UpdateMap, Value, + utils::{Length, opt_packing_depth}, +}; + +#[derive(Debug)] +pub struct ArcIter<'a, T: Value> { + /// Stack of tree nodes corresponding to the current position. + stack: Vec<&'a Tree>, + /// The list index corresponding to the current position (next element to be yielded). + index: usize, + /// The `depth` of the root tree. + full_depth: usize, + /// Cached packing depth to avoid re-calculating `opt_packing_depth`. + packing_depth: usize, + /// Number of items that will be yielded by the iterator. + length: Length, +} + +impl<'a, T: Value> ArcIter<'a, T> { + pub fn from_index( + index: usize, + root: &'a Tree, + depth: usize, + length: Length, + ) -> Result { + if ::tree_hash_type() == TreeHashType::Basic { + return Err(Error::PackedLeavesNoArc); + } + let mut stack = Vec::with_capacity(depth); + stack.push(root); + + Ok(ArcIter { + stack, + index, + full_depth: depth, + packing_depth: opt_packing_depth::().unwrap_or(0), + length, + }) + } +} + +impl<'a, T: Value> ArcIter<'a, T> { + pub fn new(root: &'a Tree, depth: usize, length: Length) -> Self { + let mut stack = Vec::with_capacity(depth); + stack.push(root); + + ArcIter { + stack, + index: 0, + full_depth: depth, + packing_depth: opt_packing_depth::().unwrap_or(0), + length, + } + } +} + +impl<'a, T: Value> Iterator for ArcIter<'a, T> { + type Item = &'a Arc; + + fn next(&mut self) -> Option { + if self.index >= self.length.as_usize() { + return None; + } + + match self.stack.last() { + None | Some(Tree::Zero(_)) => None, + Some(Tree::Leaf(Leaf { value, .. })) => { + let result = Some(value); + + self.index += 1; + + // Backtrack to the parent node of the next subtree + for _ in 0..=self.index.trailing_zeros() { + self.stack.pop(); + } + + result + } + Some(Tree::PackedLeaf(_)) => { + // Return None case of PackedLeaf + None + } + Some(Tree::Node { left, right, .. }) => { + let depth = self.full_depth - self.stack.len(); + + // Go left + if (self.index >> (depth + self.packing_depth)) & 1 == 0 { + self.stack.push(left); + self.next() + } + // Go right + else { + self.stack.push(right); + self.next() + } + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let remaining = self.length.as_usize().saturating_sub(self.index); + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for ArcIter<'_, T> {} +#[derive(Debug)] +pub struct ArcInterfaceIter<'a, T: Value, U: UpdateMap> { + tree_iter: ArcIter<'a, T>, + updates: &'a U, + index: usize, + length: usize, +} + +impl<'a, T: Value, U: UpdateMap> ArcInterfaceIter<'a, T, U> { + pub fn new(root: &'a Tree, depth: usize, length: Length, updates: &'a U) -> Self { + ArcInterfaceIter { + tree_iter: ArcIter::new(root, depth, length), + updates, + index: 0, + length: length.as_usize(), + } + } +} + +impl<'a, T: Value, U: UpdateMap> Iterator for ArcInterfaceIter<'a, T, U> { + type Item = Arc; + + fn next(&mut self) -> Option { + if self.index >= self.length { + return None; + } + let idx = self.index; + self.index += 1; + + let backing = self.tree_iter.next(); + if let Some(new_val) = self.updates.get(idx) { + Some( + self.updates + .get_arc(idx) + .unwrap_or_else(|| Arc::new(new_val.clone())), + ) + } else { + backing.cloned() + } + } + + fn size_hint(&self) -> (usize, Option) { + let rem = self.length.saturating_sub(self.index); + (rem, Some(rem)) + } +} +impl> ExactSizeIterator for ArcInterfaceIter<'_, T, U> {} diff --git a/src/lib.rs b/src/lib.rs index cec7288..4576e29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ pub mod error; pub mod interface; pub mod interface_iter; pub mod iter; +pub mod iter_arc; pub mod leaf; pub mod level_iter; pub mod list; @@ -23,6 +24,7 @@ pub mod vector; pub use cow::Cow; pub use error::Error; pub use interface::ImmList; +pub use iter_arc::ArcIter; pub use leaf::Leaf; pub use list::List; pub use packed_leaf::PackedLeaf; diff --git a/src/list.rs b/src/list.rs index ad96b01..9f92f0a 100644 --- a/src/list.rs +++ b/src/list.rs @@ -2,6 +2,7 @@ use crate::builder::Builder; use crate::interface::{ImmList, Interface, MutList}; use crate::interface_iter::{InterfaceIter, InterfaceIterCow}; use crate::iter::Iter; +use crate::iter_arc::{ArcInterfaceIter, ArcIter}; use crate::level_iter::{LevelIter, LevelNode}; use crate::serde::ListVisitor; use crate::tree::{IntraRebaseAction, RebaseAction}; @@ -16,7 +17,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::SerializeSeq} use ssz::{BYTES_PER_LENGTH_OFFSET, Decode, Encode, SszEncoder, TryFromIter}; use std::collections::{BTreeMap, HashMap}; use std::marker::PhantomData; -use tree_hash::{Hash256, PackedEncoding, TreeHash}; +use tree_hash::{Hash256, PackedEncoding, TreeHash, TreeHashType}; use typenum::Unsigned; use vec_map::VecMap; #[derive(Debug, Clone, Educe)] @@ -136,6 +137,15 @@ impl> List { Ok(self.interface.iter_from(index)) } + pub fn iter_arc(&self) -> Result>, Error> { + Ok(ArcInterfaceIter::new( + &self.interface.backing.tree, + self.interface.backing.depth, + Length(self.len()), + &self.interface.updates, + )) + } + /// Iterate all internal nodes on the same level as `index`. pub fn level_iter_from(&self, index: usize) -> Result, Error> { // Return an empty iterator at index == length, just like slicing. @@ -267,6 +277,10 @@ impl ImmList for ListInner { fn level_iter_from(&self, index: usize) -> LevelIter<'_, T> { LevelIter::from_index(index, &self.tree, self.depth, self.length) } + + fn iter_arc(&self, index: usize) -> Result, Error> { + ArcIter::from_index(index, &self.tree, self.depth, self.length) + } } impl MutList for ListInner @@ -370,8 +384,8 @@ impl> Default for List { } impl> TreeHash for List { - fn tree_hash_type() -> tree_hash::TreeHashType { - tree_hash::TreeHashType::List + fn tree_hash_type() -> TreeHashType { + TreeHashType::List } fn tree_hash_packed_encoding(&self) -> PackedEncoding { diff --git a/src/tests/proptest/operations.rs b/src/tests/proptest/operations.rs index bb3aaa6..186360d 100644 --- a/src/tests/proptest/operations.rs +++ b/src/tests/proptest/operations.rs @@ -4,7 +4,7 @@ use proptest::prelude::*; use ssz::{Decode, Encode}; use std::fmt::Debug; use std::marker::PhantomData; -use tree_hash::{Hash256, TreeHash}; +use tree_hash::{Hash256, TreeHash, TreeHashType}; use typenum::{U1, U2, U3, U4, U7, U8, U9, U32, U33, U1024, Unsigned}; const OP_LIMIT: usize = 128; @@ -107,6 +107,8 @@ pub enum Op { Push(T), /// Check the `iter` method. Iter, + /// Cheeck the `iter_arc` method for non basic types. + IterArc, /// Check the `iter_from` method. IterFrom(usize), /// Check the `pop_front` method. @@ -154,10 +156,11 @@ where Just(Op::Debase), Just(Op::FromIntoRoundtrip), Just(Op::IntraRebase), + Just(Op::IterArc), ]; prop_oneof![ 10 => a_block, - 6 => b_block + 7 => b_block ] } @@ -210,6 +213,16 @@ where Op::Iter => { assert!(list.iter().eq(spec.iter())); } + Op::IterArc => { + if ::tree_hash_type() != TreeHashType::Basic { + let actual: Vec = + list.iter_arc().unwrap().map(|arc| (*arc).clone()).collect(); + let expected: Vec = spec.iter().cloned().collect(); + + assert_eq!(actual, expected); + } + } + Op::IterFrom(index) => match (list.iter_from(index), spec.iter_from(index)) { (Ok(iter1), Ok(iter2)) => assert!(iter1.eq(iter2)), (Err(e1), Err(e2)) => assert_eq!(e1, e2), @@ -306,6 +319,16 @@ where Op::Iter => { assert!(vect.iter().eq(spec.iter())); } + Op::IterArc => { + if ::tree_hash_type() != TreeHashType::Basic { + let actual: Vec = + vect.iter_arc().unwrap().map(|arc| (*arc).clone()).collect(); + let expected: Vec = spec.iter().cloned().collect(); + + assert!(actual.eq(&expected)); + } + } + Op::IterFrom(index) => match (vect.iter_from(index), spec.iter_from(index)) { (Ok(iter1), Ok(iter2)) => assert!(iter1.eq(iter2)), (Err(e1), Err(e2)) => assert_eq!(e1, e2), diff --git a/src/update_map.rs b/src/update_map.rs index a58056f..6661fc5 100644 --- a/src/update_map.rs +++ b/src/update_map.rs @@ -2,6 +2,7 @@ use crate::cow::{BTreeCow, Cow, VecCow}; use crate::utils::max_btree_index; use std::collections::{BTreeMap, btree_map::Entry}; use std::ops::ControlFlow; +use triomphe::Arc; use vec_map::VecMap; /// Trait for map types which can be used to store intermediate updates before application @@ -18,6 +19,8 @@ pub trait UpdateMap: Default + Clone { F: FnOnce(usize) -> Option<&'a T>, T: Clone + 'a; + fn get_arc(&self, k: usize) -> Option>; + fn insert(&mut self, k: usize, value: T) -> Option; fn for_each_range(&self, start: usize, end: usize, f: F) -> Result<(), E> @@ -72,6 +75,10 @@ impl UpdateMap for BTreeMap { Some(Cow::BTree(cow)) } + fn get_arc(&self, k: usize) -> Option> { + self.get(&k).cloned().map(Arc::new) + } + fn insert(&mut self, idx: usize, value: T) -> Option { BTreeMap::insert(self, idx, value) } @@ -136,6 +143,10 @@ impl UpdateMap for VecMap { Some(Cow::Vec(cow)) } + fn get_arc(&self, k: usize) -> Option> { + self.get(k).cloned().map(Arc::new) + } + fn insert(&mut self, idx: usize, value: T) -> Option { VecMap::insert(self, idx, value) } @@ -202,6 +213,10 @@ where self.inner.get_cow_with(k, f) } + fn get_arc(&self, k: usize) -> Option> { + self.inner.get_arc(k) + } + fn insert(&mut self, k: usize, value: T) -> Option { if k > self.max_key { self.max_key = k; @@ -224,3 +239,62 @@ where Some(self.max_key).filter(|_| !self.inner.is_empty()) } } + +#[derive(Debug, Default, Clone, PartialEq)] +pub struct ArcMap(pub M); + +impl UpdateMap for ArcMap +where + M: UpdateMap>, + T: Clone + 'static, +{ + fn get(&self, k: usize) -> Option<&T> { + self.0.get(k).map(|arc| &**arc) + } + + fn get_mut_with(&mut self, k: usize, f: F) -> Option<&mut T> + where + F: FnOnce(usize) -> Option, + { + let value = self.0.get_mut_with(k, |idx| f(idx).map(Arc::new))?; + Arc::get_mut(value) + } + + fn get_cow_with<'a, F>(&'a mut self, idx: usize, f: F) -> Option> + where + F: FnOnce(usize) -> Option<&'a T>, + T: Clone + 'a, + { + let arc = self + .0 + .get_mut_with(idx, |_| Some(Arc::new(f(idx)?.clone())))?; + let value_mut = Arc::get_mut(arc)?; + + Some(Cow::BTree(BTreeCow::Mutable { value: value_mut })) + } + + fn get_arc(&self, k: usize) -> Option> { + self.0.get(k).cloned() + } + + fn insert(&mut self, k: usize, value: T) -> Option { + self.0 + .insert(k, Arc::new(value)) + .and_then(|arc| Arc::try_unwrap(arc).ok()) + } + + fn for_each_range(&self, start: usize, end: usize, mut f: F) -> Result<(), E> + where + F: FnMut(usize, &T) -> ControlFlow<(), Result<(), E>>, + { + self.0.for_each_range(start, end, |k, v| f(k, &**v)) + } + + fn max_index(&self) -> Option { + self.0.max_index() + } + + fn len(&self) -> usize { + self.0.len() + } +} diff --git a/src/vector.rs b/src/vector.rs index 2cb91e5..597c9cd 100644 --- a/src/vector.rs +++ b/src/vector.rs @@ -1,11 +1,12 @@ use crate::interface::{ImmList, Interface, MutList}; use crate::interface_iter::InterfaceIter; use crate::iter::Iter; +use crate::iter_arc::ArcInterfaceIter; use crate::level_iter::LevelIter; use crate::tree::{IntraRebaseAction, RebaseAction}; use crate::update_map::MaxMap; use crate::utils::Length; -use crate::{Arc, Cow, Error, List, Tree, UpdateMap, Value}; +use crate::{Arc, ArcIter, Cow, Error, List, Tree, UpdateMap, Value}; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; use educe::Educe; @@ -78,6 +79,15 @@ impl> Vector { self.interface.iter() } + pub fn iter_arc(&self) -> Result>, Error> { + Ok(ArcInterfaceIter::new( + &self.interface.backing.tree, + self.interface.backing.depth, + Length(self.len()), + &self.interface.updates, + )) + } + pub fn iter_from(&self, index: usize) -> Result, Error> { if index > self.len() { return Err(Error::OutOfBoundsIterFrom { @@ -224,6 +234,10 @@ impl ImmList for VectorInner { fn level_iter_from(&self, index: usize) -> LevelIter<'_, T> { LevelIter::from_index(index, &self.tree, self.depth, Length(N::to_usize())) } + + fn iter_arc(&self, index: usize) -> Result, Error> { + ArcIter::from_index(index, &self.tree, self.depth, Length(N::to_usize())) + } } impl MutList for VectorInner