Skip to content

Commit

Permalink
feat: merkle multiproof
Browse files Browse the repository at this point in the history
  • Loading branch information
cchudant committed Sep 27, 2024
1 parent ca72fc6 commit f79dc23
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 52 deletions.
30 changes: 21 additions & 9 deletions src/trie/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,24 @@ use starknet_types_core::{felt::Felt, hash::StarkHash};

/// This trait's function will be called on every node visited during a seek operation.
pub trait NodeVisitor<H: StarkHash> {
fn visit_node(&mut self, tree: &mut MerkleTree<H>, node_id: NodeKey, prev_height: usize);
fn visit_node<DB: BonsaiDatabase>(
&mut self,
tree: &mut MerkleTree<H>,
node_id: NodeKey,
prev_height: usize,
) -> Result<(), BonsaiStorageError<DB::DatabaseError>>;
}

pub struct NoopVisitor<H>(PhantomData<H>);
impl<H: StarkHash> NodeVisitor<H> for NoopVisitor<H> {
fn visit_node(&mut self, _tree: &mut MerkleTree<H>, _node_id: NodeKey, _prev_height: usize) {}
fn visit_node<DB: BonsaiDatabase>(
&mut self,
_tree: &mut MerkleTree<H>,
_node_id: NodeKey,
_prev_height: usize,
) -> Result<(), BonsaiStorageError<DB::DatabaseError>> {
Ok(())
}
}

pub struct MerkleTreeIterator<'a, H: StarkHash, DB: BonsaiDatabase, ID: Id> {
Expand Down Expand Up @@ -153,10 +165,9 @@ impl<'a, H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id> MerkleTreeItera
// partition point is a binary search under the hood
// TODO(perf): measure whether binary search is actually better than reverse iteration - the happy path may be that
// only the last few bits are different.
let nodes_new_len = self
.current_nodes_heights
.partition_point(|(_node, height)| *height < shared_prefix_len);
nodes_new_len

self.current_nodes_heights
.partition_point(|(_node, height)| *height < shared_prefix_len)
};
log::trace!(
"Truncate pre node id cache shared_prefix_len={:?}, nodes_new_len={:?}, cur_path_nodes_heights={:?}, current_path={:?}",
Expand Down Expand Up @@ -202,7 +213,7 @@ impl<'a, H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id> MerkleTreeItera
return Ok(());
};

visitor.visit_node(&mut self.tree, node_id, self.current_path.len());
visitor.visit_node::<DB>(self.tree, node_id, self.current_path.len())?;
next_to_visit = self.traverse_one(node_id, self.current_path.len(), key)?;

log::trace!(
Expand All @@ -218,7 +229,7 @@ impl<'a, H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id> MerkleTreeItera

#[cfg(test)]
mod tests {
//! The tree used in this series of cases looks like this:
//! The tree used in this series of tests looks like this:
//! ```
//! │
//! ┌▼┐
Expand All @@ -239,7 +250,7 @@ mod tests {
//! │ │ └┬┘ └┬┘
//! 0x1 0x2 0x3 0x4
//! ```

use crate::{
databases::{create_rocks_db, RocksDB, RocksDBConfig},
id::{BasicId, Id},
Expand Down Expand Up @@ -304,6 +315,7 @@ mod tests {
}
}

#[allow(clippy::type_complexity)]
fn all_cases<H: StarkHash + Send + Sync, DB: BonsaiDatabase, ID: Id>(
) -> Vec<fn(&mut MerkleTreeIterator<H, DB, ID>)> {
vec![
Expand Down
20 changes: 18 additions & 2 deletions src/trie/merkle_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
//! [`MerkleTree`](super::merkle_tree::MerkleTree).

use crate::BitSlice;
use bitvec::view::BitView;
use core::fmt;
use parity_scale_codec::{Decode, Encode};
use starknet_types_core::felt::Felt;
use starknet_types_core::{felt::Felt, hash::StarkHash};

use super::{path::Path, tree::NodeKey};

Expand All @@ -34,7 +35,6 @@ impl NodeHandle {
}
}


impl fmt::Debug for NodeHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down Expand Up @@ -252,6 +252,22 @@ impl EdgeNode {
}
}

pub fn hash_binary_node<H: StarkHash>(left_hash: Felt, right_hash: Felt) -> Felt {
H::hash(&left_hash, &right_hash)
}
pub fn hash_edge_node<H: StarkHash>(path: &Path, child_hash: Felt) -> Felt {
let mut bytes = [0u8; 32];
bytes.view_bits_mut()[256 - path.len()..].copy_from_bitslice(path);

let felt_path = Felt::from_bytes_be(&bytes);
let mut length = [0; 32];
// Safe as len() is guaranteed to be <= 251
length[31] = path.len() as u8;

let length = Felt::from_bytes_be(&length);
H::hash(&child_hash, &felt_path) + length
}

#[test]
fn test_path_matches_basic() {
let path =
Expand Down
122 changes: 93 additions & 29 deletions src/trie/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ use super::{path::Path, tree::MerkleTree};
use crate::{
id::Id,
key_value_db::KeyValueDB,
trie::{iterator::NodeVisitor, tree::NodeKey},
BitSlice, BonsaiDatabase, BonsaiStorageError,
trie::{
iterator::NodeVisitor,
merkle_node::{Node, NodeHandle},
tree::NodeKey,
},
BitSlice, BonsaiDatabase, BonsaiStorageError, HashMap,
};
use bitvec::view::BitView;
use starknet_types_core::{felt::Felt, hash::StarkHash};
Expand Down Expand Up @@ -40,37 +44,38 @@ impl ProofNode {
}

impl<H: StarkHash + Send + Sync> MerkleTree<H> {
/// Returns the list of nodes along the path.
///
/// if it exists, or down to the node which proves that the key does not exist.
///
/// The nodes are returned in order, root first.
///
/// Verification is performed by confirming that:
/// 1. the chain follows the path of `key`, and
/// 2. the hashes are correct, and
/// 3. the root hash matches the known root
///
/// # Arguments
///
/// * `key` - The key to get the merkle proof of.
///
/// # Returns
///
/// The merkle proof and all the child nodes hashes.
pub fn get_multi_proof<DB: BonsaiDatabase, ID: Id>(
&mut self,
db: &KeyValueDB<DB, ID>,
keys: impl IntoIterator<Item = impl AsRef<BitSlice>>,
) -> Result<Vec<ProofNode>, BonsaiStorageError<DB::DatabaseError>> {
struct ProofVisitor<H>(Vec<ProofNode>, PhantomData<H>);
impl<H: StarkHash> NodeVisitor<H> for ProofVisitor<H> {
fn visit_node(&mut self, _tree: &mut MerkleTree<H>, node_id: NodeKey, prev_height: usize) {
log::trace!(
"Visiting {:?} prev height: {:?}",
node_id,
prev_height
);
) -> Result<HashMap<Felt, ProofNode>, BonsaiStorageError<DB::DatabaseError>> {
struct ProofVisitor<H>(HashMap<Felt, ProofNode>, PhantomData<H>);
impl<H: StarkHash + Send + Sync> NodeVisitor<H> for ProofVisitor<H> {
fn visit_node<DB: BonsaiDatabase>(
&mut self,
tree: &mut MerkleTree<H>,
node_id: NodeKey,
_prev_height: usize,
) -> Result<(), BonsaiStorageError<DB::DatabaseError>> {
let proof_node = match tree.node_storage.get_node_mut::<DB>(node_id)? {
Node::Binary(binary_node) => {
let (left, right) = (binary_node.left, binary_node.right);
ProofNode::Binary {
left: tree.get_or_compute_node_hash::<DB>(left)?,
right: tree.get_or_compute_node_hash::<DB>(right)?,
}
}
Node::Edge(edge_node) => {
let (child, path) = (edge_node.child, edge_node.path.clone());
ProofNode::Edge {
child: tree.get_or_compute_node_hash::<DB>(child)?,
path,
}
}
};
let hash = tree.get_or_compute_node_hash::<DB>(NodeHandle::InMemory(node_id))?;
self.0.insert(hash, proof_node);
Ok(())
}
}
let mut visitor = ProofVisitor::<H>(Default::default(), PhantomData);
Expand Down Expand Up @@ -164,3 +169,62 @@ impl<H: StarkHash + Send + Sync> MerkleTree<H> {
// }
}
}

#[cfg(test)]
mod tests {
use crate::{
databases::{create_rocks_db, RocksDB, RocksDBConfig},
id::BasicId,
BonsaiStorage, BonsaiStorageConfig,
};
use bitvec::{bits, order::Msb0};
use starknet_types_core::{felt::Felt, hash::Pedersen};

const ONE: Felt = Felt::ONE;
const TWO: Felt = Felt::TWO;
const THREE: Felt = Felt::THREE;
const FOUR: Felt = Felt::from_hex_unchecked("0x4");

#[test]
fn test_multiproof() {
let _ = env_logger::builder().is_test(true).try_init();
log::set_max_level(log::LevelFilter::Trace);
let tempdir = tempfile::tempdir().unwrap();
let db = create_rocks_db(tempdir.path()).unwrap();
let mut bonsai_storage: BonsaiStorage<BasicId, _, Pedersen> = BonsaiStorage::new(
RocksDB::<BasicId>::new(&db, RocksDBConfig::default()),
BonsaiStorageConfig::default(),
)
.unwrap();

bonsai_storage
.insert(&[], bits![u8, Msb0; 0,0,0,1,0,0,0,0], &ONE)
.unwrap();
bonsai_storage
.insert(&[], bits![u8, Msb0; 0,0,0,1,0,0,0,1], &TWO)
.unwrap();
bonsai_storage
.insert(&[], bits![u8, Msb0; 0,0,0,1,0,0,1,0], &THREE)
.unwrap();
bonsai_storage
.insert(&[], bits![u8, Msb0; 0,1,0,0,0,0,0,0], &FOUR)
.unwrap();

bonsai_storage.dump();

let tree = bonsai_storage
.tries
.trees
.get_mut(&smallvec::smallvec![])
.unwrap();

let proof = tree.get_multi_proof(&bonsai_storage.tries.db, [
bits![u8, Msb0; 0,0,0,1,0,0,0,1],
bits![u8, Msb0; 0,1,0,0,0,0,0,0],
])
.unwrap();

log::trace!("proof: {proof:?}");
todo!()
}
}
57 changes: 45 additions & 12 deletions src/trie/tree.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use bitvec::view::BitView;
use core::{fmt, marker::PhantomData};
use core::{iter, mem};
use parity_scale_codec::Decode;
use slotmap::SlotMap;
use starknet_types_core::{felt::Felt, hash::StarkHash};

use crate::trie::merkle_node::{hash_binary_node, hash_edge_node};
use crate::BitVec;
use crate::{
error::BonsaiStorageError, format, hash_map, id::Id, vec, BitSlice, BonsaiDatabase, ByteVec,
Expand Down Expand Up @@ -213,6 +213,46 @@ impl<H: StarkHash + Send + Sync> MerkleTree<H> {
}
}

/// Get or compute the hash of a node.
pub(crate) fn get_or_compute_node_hash<DB: BonsaiDatabase>(
&mut self,
node: NodeHandle,
) -> Result<Felt, BonsaiStorageError<DB::DatabaseError>> {
match node {
NodeHandle::Hash(felt) => Ok(felt),
NodeHandle::InMemory(node_key) => {
let computed_hash = match self.node_storage.get_node_mut::<DB>(node_key)? {
Node::Binary(binary_node) => {
if let Some(hash) = binary_node.hash {
return Ok(hash);
}
let (left, right) = (binary_node.left, binary_node.right);
let left_hash = self.get_or_compute_node_hash::<DB>(left)?;
let right_hash = self.get_or_compute_node_hash::<DB>(right)?;
hash_binary_node::<H>(left_hash, right_hash)
}
Node::Edge(edge_node) => {
if let Some(hash) = edge_node.hash {
return Ok(hash);
}
let (path, child) = (edge_node.path.clone(), edge_node.child);
// edge_node borrow ends here
let child_hash = self.get_or_compute_node_hash::<DB>(child)?;
hash_edge_node::<H>(&path, child_hash)
}
};

// reborrow, for lifetime reasons (can't go into children if a borrow is alive)
match self.node_storage.get_node_mut::<DB>(node_key)? {
Node::Binary(binary_node) => binary_node.hash = Some(computed_hash),
Node::Edge(edge_node) => edge_node.hash = Some(computed_hash),
}

Ok(computed_hash)
}
}
}

/// Note: as iterators load nodes from the database, this takes an &mut self. However,
/// note that it will not modify anything in the database - hence the &db.
pub fn iter<'a, DB: BonsaiDatabase, ID: Id>(
Expand Down Expand Up @@ -431,7 +471,8 @@ impl<H: StarkHash + Send + Sync> MerkleTree<H> {
}
};

let hash = H::hash(&left_hash, &right_hash);
let hash = hash_binary_node::<H>(left_hash, right_hash);

hashes.push(hash);
Ok(hash)
}
Expand All @@ -446,17 +487,9 @@ impl<H: StarkHash + Send + Sync> MerkleTree<H> {
}
};

let mut bytes = [0u8; 32];
bytes.view_bits_mut()[256 - edge.path.0.len()..].copy_from_bitslice(&edge.path.0);

let felt_path = Felt::from_bytes_be(&bytes);
let mut length = [0; 32];
// Safe as len() is guaranteed to be <= 251
length[31] = edge.path.0.len() as u8;

let length = Felt::from_bytes_be(&length);
let hash = H::hash(&child_hash, &felt_path) + length;
let hash = hash_edge_node::<H>(&edge.path, child_hash);
hashes.push(hash);

Ok(hash)
}
}
Expand Down

0 comments on commit f79dc23

Please sign in to comment.