Skip to content

Commit

Permalink
Merge pull request #1574 from 0xPolygonMiden/greenhat/i1547-mastfores…
Browse files Browse the repository at this point in the history
…t-add-advicemap

Add `MastForest::advice_map` for the data required in the advice provider before execution
  • Loading branch information
plafer authored Nov 21, 2024
2 parents 8b2545e + 93869ef commit 9f56203
Show file tree
Hide file tree
Showing 25 changed files with 321 additions and 44 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
- [BREAKING] `Process` no longer takes ownership of the `Host` (#1571)
- [BREAKING] `ProcessState` was converted from a trait to a struct (#1571)

#### Enhancements
- Added `miden_core::mast::MastForest::advice_map` to load it into the advice provider before the `MastForest` execution (#1574).

## 0.11.0 (2024-11-04)

#### Enhancements
Expand Down
53 changes: 48 additions & 5 deletions processor/src/host/advice/map.rs → core/src/advice/map.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use alloc::{
boxed::Box,
collections::{btree_map::IntoIter, BTreeMap},
vec::Vec,
};

use vm_core::{
use miden_crypto::{utils::collections::KvMap, Felt};

use crate::{
crypto::hash::RpoDigest,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};

use super::Felt;

// ADVICE MAP
// ================================================================================================

Expand Down Expand Up @@ -38,8 +39,18 @@ impl AdviceMap {
}

/// Removes the value associated with the key and returns the removed element.
pub fn remove(&mut self, key: RpoDigest) -> Option<Vec<Felt>> {
self.0.remove(&key)
pub fn remove(&mut self, key: &RpoDigest) -> Option<Vec<Felt>> {
self.0.remove(key)
}

/// Returns the number of key value pairs in the advice map.
pub fn len(&self) -> usize {
self.0.len()
}

/// Returns true if the advice map is empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

Expand All @@ -58,6 +69,38 @@ impl IntoIterator for AdviceMap {
}
}

impl FromIterator<(RpoDigest, Vec<Felt>)> for AdviceMap {
fn from_iter<T: IntoIterator<Item = (RpoDigest, Vec<Felt>)>>(iter: T) -> Self {
iter.into_iter().collect::<BTreeMap<RpoDigest, Vec<Felt>>>().into()
}
}

impl KvMap<RpoDigest, Vec<Felt>> for AdviceMap {
fn get(&self, key: &RpoDigest) -> Option<&Vec<Felt>> {
self.0.get(key)
}

fn contains_key(&self, key: &RpoDigest) -> bool {
self.0.contains_key(key)
}

fn len(&self) -> usize {
self.len()
}

fn insert(&mut self, key: RpoDigest, value: Vec<Felt>) -> Option<Vec<Felt>> {
self.insert(key, value)
}

fn remove(&mut self, key: &RpoDigest) -> Option<Vec<Felt>> {
self.remove(key)
}

fn iter(&self) -> Box<dyn Iterator<Item = (&RpoDigest, &Vec<Felt>)> + '_> {
Box::new(self.0.iter())
}
}

impl Extend<(RpoDigest, Vec<Felt>)> for AdviceMap {
fn extend<T: IntoIterator<Item = (RpoDigest, Vec<Felt>)>>(&mut self, iter: T) {
self.0.extend(iter)
Expand Down
1 change: 1 addition & 0 deletions core/src/advice/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(super) mod map;
3 changes: 3 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,7 @@ pub use operations::{
pub mod stack;
pub use stack::{StackInputs, StackOutputs};

mod advice;
pub use advice::map::AdviceMap;

pub mod utils;
25 changes: 21 additions & 4 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{collections::BTreeMap, vec::Vec};

use miden_crypto::hash::blake::Blake3Digest;
use miden_crypto::{hash::blake::Blake3Digest, utils::collections::KvMap};

use crate::mast::{
DecoratorId, MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId,
Expand Down Expand Up @@ -65,10 +65,11 @@ impl MastForestMerger {
///
/// It does this in three steps:
///
/// 1. Merge all decorators, which is a case of deduplication and creating a decorator id
/// 1. Merge all advice maps, checking for key collisions.
/// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
/// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
/// merged forest.
/// 2. Merge all nodes of forests.
/// 3. Merge all nodes of forests.
/// - Similar to decorators, node indices might move during merging, so the merger keeps a
/// node id mapping as it merges nodes.
/// - This is a depth-first traversal over all forests to ensure all children are processed
Expand All @@ -90,10 +91,13 @@ impl MastForestMerger {
/// `replacement` node. Now we can simply add a mapping from the external node to the
/// `replacement` node in our node id mapping which means all nodes that referenced the
/// external node will point to the `replacement` instead.
/// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to
/// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
/// their potentially new indices in the merged forest and add them to the forest,
/// deduplicating in the process, too.
fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
for other_forest in forests.iter() {
self.merge_advice_map(other_forest)?;
}
for other_forest in forests.iter() {
self.merge_decorators(other_forest)?;
}
Expand Down Expand Up @@ -163,6 +167,19 @@ impl MastForestMerger {
Ok(())
}

fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
for (digest, values) in other_forest.advice_map.iter() {
if let Some(stored_values) = self.mast_forest.advice_map().get(digest) {
if stored_values != values {
return Err(MastForestError::AdviceMapKeyCollisionOnMerge(*digest));
}
} else {
self.mast_forest.advice_map_mut().insert(*digest, values.clone());
}
}
Ok(())
}

fn merge_node(
&mut self,
forest_idx: usize,
Expand Down
53 changes: 52 additions & 1 deletion core/src/mast/merger/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use miden_crypto::{hash::rpo::RpoDigest, ONE};
use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE};

use super::*;
use crate::{Decorator, Operation};
Expand Down Expand Up @@ -794,3 +794,54 @@ fn mast_forest_merge_invalid_decorator_index() {
let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _));
}

/// Tests that forest's advice maps are merged correctly.
#[test]
fn mast_forest_merge_advice_maps_merged() {
let mut forest_a = MastForest::new();
let id_foo = forest_a.add_node(block_foo()).unwrap();
let id_call_a = forest_a.add_call(id_foo).unwrap();
forest_a.make_root(id_call_a);
let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
let value_a = vec![ONE, ONE];
forest_a.advice_map_mut().insert(key_a, value_a.clone());

let mut forest_b = MastForest::new();
let id_bar = forest_b.add_node(block_bar()).unwrap();
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);
let key_b = RpoDigest::new([Felt::new(1), Felt::new(3), Felt::new(2), Felt::new(1)]);
let value_b = vec![Felt::new(2), Felt::new(2)];
forest_b.advice_map_mut().insert(key_b, value_b.clone());

let (merged, _root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap();

let merged_advice_map = merged.advice_map();
assert_eq!(merged_advice_map.len(), 2);
assert_eq!(merged_advice_map.get(&key_a).unwrap(), &value_a);
assert_eq!(merged_advice_map.get(&key_b).unwrap(), &value_b);
}

/// Tests that an error is returned when advice maps have a key collision.
#[test]
fn mast_forest_merge_advice_maps_collision() {
let mut forest_a = MastForest::new();
let id_foo = forest_a.add_node(block_foo()).unwrap();
let id_call_a = forest_a.add_call(id_foo).unwrap();
forest_a.make_root(id_call_a);
let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
let value_a = vec![ONE, ONE];
forest_a.advice_map_mut().insert(key_a, value_a.clone());

let mut forest_b = MastForest::new();
let id_bar = forest_b.add_node(block_bar()).unwrap();
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);
// The key collides with key_a in the forest_a.
let key_b = key_a;
let value_b = vec![Felt::new(2), Felt::new(2)];
forest_b.advice_map_mut().insert(key_b, value_b.clone());

let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
assert_matches!(err, MastForestError::AdviceMapKeyCollisionOnMerge(_));
}
15 changes: 14 additions & 1 deletion core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use node::{
};
use winter_utils::{ByteWriter, DeserializationError, Serializable};

use crate::{Decorator, DecoratorList, Operation};
use crate::{AdviceMap, Decorator, DecoratorList, Operation};

mod serialization;

Expand Down Expand Up @@ -50,6 +50,9 @@ pub struct MastForest {

/// All the decorators included in the MAST forest.
decorators: Vec<Decorator>,

/// Advice map to be loaded into the VM prior to executing procedures from this MAST forest.
advice_map: AdviceMap,
}

// ------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -463,6 +466,14 @@ impl MastForest {
pub fn nodes(&self) -> &[MastNode] {
&self.nodes
}

pub fn advice_map(&self) -> &AdviceMap {
&self.advice_map
}

pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
&mut self.advice_map
}
}

impl Index<MastNodeId> for MastForest {
Expand Down Expand Up @@ -689,4 +700,6 @@ pub enum MastForestError {
EmptyBasicBlock,
#[error("decorator root of child with node id {0} is missing but required for fingerprint computation")]
ChildFingerprintMissing(MastNodeId),
#[error("advice map key already exists when merging forests: {0}")]
AdviceMapKeyCollisionOnMerge(RpoDigest),
}
6 changes: 6 additions & 0 deletions core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use string_table::{StringTable, StringTableBuilder};
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{DecoratorId, MastForest, MastNode, MastNodeId};
use crate::AdviceMap;

mod decorator;

Expand Down Expand Up @@ -149,6 +150,8 @@ impl Serializable for MastForest {
node_data.write_into(target);
string_table.write_into(target);

self.advice_map.write_into(target);

// Write decorator and node infos
for decorator_info in decorator_infos {
decorator_info.write_into(target);
Expand Down Expand Up @@ -187,6 +190,7 @@ impl Deserializable for MastForest {
let decorator_data: Vec<u8> = Deserializable::read_from(source)?;
let node_data: Vec<u8> = Deserializable::read_from(source)?;
let string_table: StringTable = Deserializable::read_from(source)?;
let advice_map = AdviceMap::read_from(source)?;

let mut mast_forest = {
let mut mast_forest = MastForest::new();
Expand Down Expand Up @@ -229,6 +233,8 @@ impl Deserializable for MastForest {
mast_forest.make_root(root);
}

mast_forest.advice_map = advice_map;

mast_forest
};

Expand Down
21 changes: 20 additions & 1 deletion core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{string::ToString, sync::Arc};

use miden_crypto::{hash::rpo::RpoDigest, Felt};
use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE};

use super::*;
use crate::{
Expand Down Expand Up @@ -435,3 +435,22 @@ fn mast_forest_invalid_node_id() {
// Validate normal operations
forest.add_join(first, second).unwrap();
}

/// Test `MastForest::advice_map` serialization and deserialization.
#[test]
fn mast_forest_serialize_deserialize_advice_map() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

let key = RpoDigest::new([ONE, ONE, ONE, ONE]);
let value = vec![ONE, ONE];

forest.advice_map_mut().insert(key, value);

let parsed = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
assert_eq!(forest.advice_map, parsed.advice_map);
}
2 changes: 1 addition & 1 deletion miden/benches/program_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn program_execution(c: &mut Criterion) {

let stdlib = StdLibrary::default();
let mut host = DefaultHost::default();
host.load_mast_forest(stdlib.as_ref().mast_forest().clone());
host.load_mast_forest(stdlib.as_ref().mast_forest().clone()).unwrap();

group.bench_function("sha256", |bench| {
let source = "
Expand Down
2 changes: 1 addition & 1 deletion miden/src/examples/blake3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub fn get_example(n: usize) -> Example<DefaultHost<MemAdviceProvider>> {
);

let mut host = DefaultHost::default();
host.load_mast_forest(StdLibrary::default().mast_forest().clone());
host.load_mast_forest(StdLibrary::default().mast_forest().clone()).unwrap();

let stack_inputs =
StackInputs::try_from_ints(INITIAL_HASH_VALUE.iter().map(|&v| v as u64)).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion miden/src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ fn execute(
let stack_inputs = StackInputs::default();
let mut host = DefaultHost::default();
for library in provided_libraries {
host.load_mast_forest(library.mast_forest().clone());
host.load_mast_forest(library.mast_forest().clone())
.map_err(|err| format!("{err}"))?;
}

let state_iter = processor::execute_iter(&program, stack_inputs, &mut host);
Expand Down
3 changes: 2 additions & 1 deletion miden/src/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ impl Analyze {
// fetch the stack and program inputs from the arguments
let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?;
let mut host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?);
host.load_mast_forest(StdLibrary::default().mast_forest().clone());
host.load_mast_forest(StdLibrary::default().mast_forest().clone())
.into_diagnostic()?;

let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host)
.expect("Could not retrieve execution details");
Expand Down
Loading

0 comments on commit 9f56203

Please sign in to comment.