From f79dc23b0cdb57287f84398441149b2354c4883b Mon Sep 17 00:00:00 2001 From: cchudant Date: Fri, 27 Sep 2024 16:48:44 +0000 Subject: [PATCH] feat: merkle multiproof --- src/trie/iterator.rs | 30 +++++++--- src/trie/merkle_node.rs | 20 ++++++- src/trie/proof.rs | 122 ++++++++++++++++++++++++++++++---------- src/trie/tree.rs | 57 +++++++++++++++---- 4 files changed, 177 insertions(+), 52 deletions(-) diff --git a/src/trie/iterator.rs b/src/trie/iterator.rs index 50aec93..2589011 100644 --- a/src/trie/iterator.rs +++ b/src/trie/iterator.rs @@ -9,12 +9,24 @@ use starknet_types_core::{felt::Felt, hash::StarkHash}; /// This trait's function will be called on every node visited during a seek operation. pub trait NodeVisitor { - fn visit_node(&mut self, tree: &mut MerkleTree, node_id: NodeKey, prev_height: usize); + fn visit_node( + &mut self, + tree: &mut MerkleTree, + node_id: NodeKey, + prev_height: usize, + ) -> Result<(), BonsaiStorageError>; } pub struct NoopVisitor(PhantomData); impl NodeVisitor for NoopVisitor { - fn visit_node(&mut self, _tree: &mut MerkleTree, _node_id: NodeKey, _prev_height: usize) {} + fn visit_node( + &mut self, + _tree: &mut MerkleTree, + _node_id: NodeKey, + _prev_height: usize, + ) -> Result<(), BonsaiStorageError> { + Ok(()) + } } pub struct MerkleTreeIterator<'a, H: StarkHash, DB: BonsaiDatabase, ID: Id> { @@ -153,10 +165,9 @@ impl<'a, H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id> MerkleTreeItera // partition point is a binary search under the hood // TODO(perf): measure whether binary search is actually better than reverse iteration - the happy path may be that // only the last few bits are different. - let nodes_new_len = self - .current_nodes_heights - .partition_point(|(_node, height)| *height < shared_prefix_len); - nodes_new_len + + self.current_nodes_heights + .partition_point(|(_node, height)| *height < shared_prefix_len) }; log::trace!( "Truncate pre node id cache shared_prefix_len={:?}, nodes_new_len={:?}, cur_path_nodes_heights={:?}, current_path={:?}", @@ -202,7 +213,7 @@ impl<'a, H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id> MerkleTreeItera return Ok(()); }; - visitor.visit_node(&mut self.tree, node_id, self.current_path.len()); + visitor.visit_node::(self.tree, node_id, self.current_path.len())?; next_to_visit = self.traverse_one(node_id, self.current_path.len(), key)?; log::trace!( @@ -218,7 +229,7 @@ impl<'a, H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id> MerkleTreeItera #[cfg(test)] mod tests { - //! The tree used in this series of cases looks like this: + //! The tree used in this series of tests looks like this: //! ``` //! │ //! ┌▼┐ @@ -239,7 +250,7 @@ mod tests { //! │ │ └┬┘ └┬┘ //! 0x1 0x2 0x3 0x4 //! ``` - + use crate::{ databases::{create_rocks_db, RocksDB, RocksDBConfig}, id::{BasicId, Id}, @@ -304,6 +315,7 @@ mod tests { } } + #[allow(clippy::type_complexity)] fn all_cases( ) -> Vec)> { vec![ diff --git a/src/trie/merkle_node.rs b/src/trie/merkle_node.rs index bd08b39..a98c8c4 100644 --- a/src/trie/merkle_node.rs +++ b/src/trie/merkle_node.rs @@ -5,9 +5,10 @@ //! [`MerkleTree`](super::merkle_tree::MerkleTree). use crate::BitSlice; +use bitvec::view::BitView; use core::fmt; use parity_scale_codec::{Decode, Encode}; -use starknet_types_core::felt::Felt; +use starknet_types_core::{felt::Felt, hash::StarkHash}; use super::{path::Path, tree::NodeKey}; @@ -34,7 +35,6 @@ impl NodeHandle { } } - impl fmt::Debug for NodeHandle { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -252,6 +252,22 @@ impl EdgeNode { } } +pub fn hash_binary_node(left_hash: Felt, right_hash: Felt) -> Felt { + H::hash(&left_hash, &right_hash) +} +pub fn hash_edge_node(path: &Path, child_hash: Felt) -> Felt { + let mut bytes = [0u8; 32]; + bytes.view_bits_mut()[256 - path.len()..].copy_from_bitslice(path); + + let felt_path = Felt::from_bytes_be(&bytes); + let mut length = [0; 32]; + // Safe as len() is guaranteed to be <= 251 + length[31] = path.len() as u8; + + let length = Felt::from_bytes_be(&length); + H::hash(&child_hash, &felt_path) + length +} + #[test] fn test_path_matches_basic() { let path = diff --git a/src/trie/proof.rs b/src/trie/proof.rs index 978f34a..e9bc91f 100644 --- a/src/trie/proof.rs +++ b/src/trie/proof.rs @@ -4,8 +4,12 @@ use super::{path::Path, tree::MerkleTree}; use crate::{ id::Id, key_value_db::KeyValueDB, - trie::{iterator::NodeVisitor, tree::NodeKey}, - BitSlice, BonsaiDatabase, BonsaiStorageError, + trie::{ + iterator::NodeVisitor, + merkle_node::{Node, NodeHandle}, + tree::NodeKey, + }, + BitSlice, BonsaiDatabase, BonsaiStorageError, HashMap, }; use bitvec::view::BitView; use starknet_types_core::{felt::Felt, hash::StarkHash}; @@ -40,37 +44,38 @@ impl ProofNode { } impl MerkleTree { - /// Returns the list of nodes along the path. - /// - /// if it exists, or down to the node which proves that the key does not exist. - /// - /// The nodes are returned in order, root first. - /// - /// Verification is performed by confirming that: - /// 1. the chain follows the path of `key`, and - /// 2. the hashes are correct, and - /// 3. the root hash matches the known root - /// - /// # Arguments - /// - /// * `key` - The key to get the merkle proof of. - /// - /// # Returns - /// - /// The merkle proof and all the child nodes hashes. pub fn get_multi_proof( &mut self, db: &KeyValueDB, keys: impl IntoIterator>, - ) -> Result, BonsaiStorageError> { - struct ProofVisitor(Vec, PhantomData); - impl NodeVisitor for ProofVisitor { - fn visit_node(&mut self, _tree: &mut MerkleTree, node_id: NodeKey, prev_height: usize) { - log::trace!( - "Visiting {:?} prev height: {:?}", - node_id, - prev_height - ); + ) -> Result, BonsaiStorageError> { + struct ProofVisitor(HashMap, PhantomData); + impl NodeVisitor for ProofVisitor { + fn visit_node( + &mut self, + tree: &mut MerkleTree, + node_id: NodeKey, + _prev_height: usize, + ) -> Result<(), BonsaiStorageError> { + let proof_node = match tree.node_storage.get_node_mut::(node_id)? { + Node::Binary(binary_node) => { + let (left, right) = (binary_node.left, binary_node.right); + ProofNode::Binary { + left: tree.get_or_compute_node_hash::(left)?, + right: tree.get_or_compute_node_hash::(right)?, + } + } + Node::Edge(edge_node) => { + let (child, path) = (edge_node.child, edge_node.path.clone()); + ProofNode::Edge { + child: tree.get_or_compute_node_hash::(child)?, + path, + } + } + }; + let hash = tree.get_or_compute_node_hash::(NodeHandle::InMemory(node_id))?; + self.0.insert(hash, proof_node); + Ok(()) } } let mut visitor = ProofVisitor::(Default::default(), PhantomData); @@ -164,3 +169,62 @@ impl MerkleTree { // } } } + +#[cfg(test)] +mod tests { + use crate::{ + databases::{create_rocks_db, RocksDB, RocksDBConfig}, + id::BasicId, + BonsaiStorage, BonsaiStorageConfig, + }; + use bitvec::{bits, order::Msb0}; + use starknet_types_core::{felt::Felt, hash::Pedersen}; + + const ONE: Felt = Felt::ONE; + const TWO: Felt = Felt::TWO; + const THREE: Felt = Felt::THREE; + const FOUR: Felt = Felt::from_hex_unchecked("0x4"); + + #[test] + fn test_multiproof() { + let _ = env_logger::builder().is_test(true).try_init(); + log::set_max_level(log::LevelFilter::Trace); + let tempdir = tempfile::tempdir().unwrap(); + let db = create_rocks_db(tempdir.path()).unwrap(); + let mut bonsai_storage: BonsaiStorage = BonsaiStorage::new( + RocksDB::::new(&db, RocksDBConfig::default()), + BonsaiStorageConfig::default(), + ) + .unwrap(); + + bonsai_storage + .insert(&[], bits![u8, Msb0; 0,0,0,1,0,0,0,0], &ONE) + .unwrap(); + bonsai_storage + .insert(&[], bits![u8, Msb0; 0,0,0,1,0,0,0,1], &TWO) + .unwrap(); + bonsai_storage + .insert(&[], bits![u8, Msb0; 0,0,0,1,0,0,1,0], &THREE) + .unwrap(); + bonsai_storage + .insert(&[], bits![u8, Msb0; 0,1,0,0,0,0,0,0], &FOUR) + .unwrap(); + + bonsai_storage.dump(); + + let tree = bonsai_storage + .tries + .trees + .get_mut(&smallvec::smallvec![]) + .unwrap(); + + let proof = tree.get_multi_proof(&bonsai_storage.tries.db, [ + bits![u8, Msb0; 0,0,0,1,0,0,0,1], + bits![u8, Msb0; 0,1,0,0,0,0,0,0], + ]) + .unwrap(); + + log::trace!("proof: {proof:?}"); + todo!() + } +} diff --git a/src/trie/tree.rs b/src/trie/tree.rs index 23634ce..f2c373a 100644 --- a/src/trie/tree.rs +++ b/src/trie/tree.rs @@ -1,10 +1,10 @@ -use bitvec::view::BitView; use core::{fmt, marker::PhantomData}; use core::{iter, mem}; use parity_scale_codec::Decode; use slotmap::SlotMap; use starknet_types_core::{felt::Felt, hash::StarkHash}; +use crate::trie::merkle_node::{hash_binary_node, hash_edge_node}; use crate::BitVec; use crate::{ error::BonsaiStorageError, format, hash_map, id::Id, vec, BitSlice, BonsaiDatabase, ByteVec, @@ -213,6 +213,46 @@ impl MerkleTree { } } + /// Get or compute the hash of a node. + pub(crate) fn get_or_compute_node_hash( + &mut self, + node: NodeHandle, + ) -> Result> { + match node { + NodeHandle::Hash(felt) => Ok(felt), + NodeHandle::InMemory(node_key) => { + let computed_hash = match self.node_storage.get_node_mut::(node_key)? { + Node::Binary(binary_node) => { + if let Some(hash) = binary_node.hash { + return Ok(hash); + } + let (left, right) = (binary_node.left, binary_node.right); + let left_hash = self.get_or_compute_node_hash::(left)?; + let right_hash = self.get_or_compute_node_hash::(right)?; + hash_binary_node::(left_hash, right_hash) + } + Node::Edge(edge_node) => { + if let Some(hash) = edge_node.hash { + return Ok(hash); + } + let (path, child) = (edge_node.path.clone(), edge_node.child); + // edge_node borrow ends here + let child_hash = self.get_or_compute_node_hash::(child)?; + hash_edge_node::(&path, child_hash) + } + }; + + // reborrow, for lifetime reasons (can't go into children if a borrow is alive) + match self.node_storage.get_node_mut::(node_key)? { + Node::Binary(binary_node) => binary_node.hash = Some(computed_hash), + Node::Edge(edge_node) => edge_node.hash = Some(computed_hash), + } + + Ok(computed_hash) + } + } + } + /// Note: as iterators load nodes from the database, this takes an &mut self. However, /// note that it will not modify anything in the database - hence the &db. pub fn iter<'a, DB: BonsaiDatabase, ID: Id>( @@ -431,7 +471,8 @@ impl MerkleTree { } }; - let hash = H::hash(&left_hash, &right_hash); + let hash = hash_binary_node::(left_hash, right_hash); + hashes.push(hash); Ok(hash) } @@ -446,17 +487,9 @@ impl MerkleTree { } }; - let mut bytes = [0u8; 32]; - bytes.view_bits_mut()[256 - edge.path.0.len()..].copy_from_bitslice(&edge.path.0); - - let felt_path = Felt::from_bytes_be(&bytes); - let mut length = [0; 32]; - // Safe as len() is guaranteed to be <= 251 - length[31] = edge.path.0.len() as u8; - - let length = Felt::from_bytes_be(&length); - let hash = H::hash(&child_hash, &felt_path) + length; + let hash = hash_edge_node::(&edge.path, child_hash); hashes.push(hash); + Ok(hash) } }