From a94d52d45ee0302a2ee2f83c0d0077f6efd4b491 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 20 Jun 2024 14:53:15 -0400 Subject: [PATCH 001/172] Introduce `ExternalNode` --- core/src/mast/node/external.rs | 44 +++++++++++++++++++++++++++++++++ core/src/mast/node/mod.rs | 12 +++++++++ miden/src/examples/fibonacci.rs | 4 ++- miden/src/examples/mod.rs | 1 - processor/src/lib.rs | 3 +++ prover/src/gpu/metal/mod.rs | 11 ++++++--- 6 files changed, 70 insertions(+), 5 deletions(-) create mode 100644 core/src/mast/node/external.rs diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs new file mode 100644 index 0000000000..18eb8df17c --- /dev/null +++ b/core/src/mast/node/external.rs @@ -0,0 +1,44 @@ +use crate::mast::{MastForest, MerkleTreeNode}; +use core::fmt; +use miden_crypto::hash::rpo::RpoDigest; +/// Block for a unknown function call. +/// +/// Proxy blocks are used to verify the integrity of a program's hash while keeping parts +/// of the program secret. Fails if executed. +/// +/// Hash of a proxy block is not computed but is rather defined at instantiation time. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ExternalNode { + digest: RpoDigest, +} + +impl ExternalNode { + /// Returns a new [Proxy] block instantiated with the specified code hash. + pub fn new(code_hash: RpoDigest) -> Self { + Self { digest: code_hash } + } +} + +impl MerkleTreeNode for ExternalNode { + fn digest(&self) -> RpoDigest { + self.digest + } + fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + self + } +} + +impl crate::prettier::PrettyPrint for ExternalNode { + fn render(&self) -> crate::prettier::Document { + use crate::prettier::*; + use miden_formatting::hex::ToHex; + const_text("external") + const_text(".") + text(self.digest.as_bytes().to_hex_with_prefix()) + } +} + +impl fmt::Display for ExternalNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use crate::prettier::PrettyPrint; + self.pretty_print(f) + } +} diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 1fc8275194..2bf0836cf3 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -13,6 +13,9 @@ pub use call_node::CallNode; mod dyn_node; pub use dyn_node::DynNode; +mod external; +pub use external::ExternalNode; + mod join_node; pub use join_node::JoinNode; @@ -37,6 +40,7 @@ pub enum MastNode { Loop(LoopNode), Call(CallNode), Dyn, + External(ExternalNode), } /// Constructors @@ -87,6 +91,10 @@ impl MastNode { pub fn new_dyncall(dyn_node_id: MastNodeId, mast_forest: &MastForest) -> Self { Self::Call(CallNode::new(dyn_node_id, mast_forest)) } + + pub fn new_external(mast_root: RpoDigest) -> Self { + Self::External(ExternalNode::new(mast_root)) + } } /// Public accessors @@ -116,6 +124,7 @@ impl MastNode { MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest))) } MastNode::Dyn => MastNodePrettyPrint::new(Box::new(DynNode)), + MastNode::External(external_node) => MastNodePrettyPrint::new(Box::new(external_node)), } } @@ -127,6 +136,7 @@ impl MastNode { MastNode::Loop(_) => LoopNode::DOMAIN, MastNode::Call(call_node) => call_node.domain(), MastNode::Dyn => DynNode::DOMAIN, + MastNode::External(_) => panic!("Can't fetch domain for an `External` node."), } } } @@ -140,6 +150,7 @@ impl MerkleTreeNode for MastNode { MastNode::Loop(node) => node.digest(), MastNode::Call(node) => node.digest(), MastNode::Dyn => DynNode.digest(), + MastNode::External(node) => node.digest(), } } @@ -151,6 +162,7 @@ impl MerkleTreeNode for MastNode { MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Dyn => MastNodeDisplay::new(DynNode.to_display(mast_forest)), + MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)), } } } diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index 881fb87624..4629f960c2 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -1,5 +1,7 @@ use super::{Example, ONE, ZERO}; -use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs}; +use miden_vm::{ + math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs, +}; // EXAMPLE BUILDER // ================================================================================================ diff --git a/miden/src/examples/mod.rs b/miden/src/examples/mod.rs index 4b7b2ce37c..7ef8837c4a 100644 --- a/miden/src/examples/mod.rs +++ b/miden/src/examples/mod.rs @@ -194,7 +194,6 @@ where } } - #[cfg(test)] pub fn test_example(example: Example, fail: bool) where diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 244cb31b4a..6ca7974f81 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -255,6 +255,9 @@ where MastNode::Loop(node) => self.execute_loop_node(node, program), MastNode::Call(node) => self.execute_call_node(node, program), MastNode::Dyn => self.execute_dyn_node(program), + MastNode::External(_node) => { + todo!() + } } } diff --git a/prover/src/gpu/metal/mod.rs b/prover/src/gpu/metal/mod.rs index 1b233eae71..e664af320a 100644 --- a/prover/src/gpu/metal/mod.rs +++ b/prover/src/gpu/metal/mod.rs @@ -86,8 +86,9 @@ where // if we will fill the entire segment, we allocate uninitialized memory unsafe { page_aligned_uninit_vector(domain_size) } } else { - // but if some columns in the segment will remain unfilled, we allocate memory initialized - // to zeros to make sure we don't end up with memory with undefined values + // but if some columns in the segment will remain unfilled, we allocate memory + // initialized to zeros to make sure we don't end up with memory with + // undefined values group_vector_elements(Felt::zeroed_vector(N * domain_size)) }; @@ -197,7 +198,11 @@ where let blowup = domain.trace_to_lde_blowup(); let offsets = get_evaluation_offsets::(composition_poly.column_len(), blowup, domain.offset()); - let segments = Self::build_aligned_segements(composition_poly.data(), domain.trace_twiddles(), &offsets); + let segments = Self::build_aligned_segements( + composition_poly.data(), + domain.trace_twiddles(), + &offsets, + ); event!( Level::INFO, "Evaluated {} composition polynomial columns over LDE domain (2^{} elements) in {} ms", From f8ad339d227db3ef6d09c5ebdf9dfa760fe86498 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 20 Jun 2024 15:23:35 -0400 Subject: [PATCH 002/172] Replace `Assembler.node_id_by_digest` map --- assembly/src/assembler/basic_block_builder.rs | 2 +- .../src/assembler/instruction/procedures.rs | 12 +-- assembly/src/assembler/mod.rs | 18 ++-- assembly/src/assembler/tests.rs | 38 ++++----- core/src/mast/mod.rs | 63 +++++++------- core/src/program.rs | 2 +- .../integration/operations/io_ops/env_ops.rs | 2 +- processor/src/chiplets/hasher/tests.rs | 20 ++--- processor/src/chiplets/tests.rs | 2 +- processor/src/decoder/tests.rs | 85 ++++++++++--------- processor/src/trace/tests/chiplets/hasher.rs | 12 +-- processor/src/trace/tests/decoder.rs | 48 +++++------ processor/src/trace/tests/mod.rs | 4 +- 13 files changed, 158 insertions(+), 150 deletions(-) diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 24889ef3b7..3b6f0e3fd3 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -129,7 +129,7 @@ impl BasicBlockBuilder { let decorators = self.decorators.drain(..).collect(); let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest.ensure_node(basic_block_node); + let basic_block_node_id = mast_forest.add_node(basic_block_node); Some(basic_block_node_id) } else if !self.decorators.is_empty() { diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 9ac2bcfdd3..65fcd12162 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -88,7 +88,7 @@ impl Assembler { // `MastForest` contains all the procedures being called; "external procedures" only // known by digest are not currently supported. let callee_id = mast_forest - .get_node_id_by_digest(mast_root) + .find_root(mast_root) .unwrap_or_else(|| panic!("MAST root {} not in MAST forest", mast_root)); match kind { @@ -98,12 +98,12 @@ impl Assembler { // For `call`, we just use the corresponding CALL block InvokeKind::Call => { let node = MastNode::new_call(callee_id, mast_forest); - mast_forest.ensure_node(node) + mast_forest.add_node(node) } // For `syscall`, we just use the corresponding SYSCALL block InvokeKind::SysCall => { let node = MastNode::new_syscall(callee_id, mast_forest); - mast_forest.ensure_node(node) + mast_forest.add_node(node) } } }; @@ -116,7 +116,7 @@ impl Assembler { &self, mast_forest: &mut MastForest, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest.ensure_node(MastNode::Dyn); + let dyn_node_id = mast_forest.add_node(MastNode::Dyn); Ok(Some(dyn_node_id)) } @@ -127,10 +127,10 @@ impl Assembler { mast_forest: &mut MastForest, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest.ensure_node(MastNode::Dyn); + let dyn_node_id = mast_forest.add_node(MastNode::Dyn); let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest); - mast_forest.ensure_node(dyn_call_node) + mast_forest.add_node(dyn_call_node) }; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 4c7bf1dace..d7c6f2d989 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -778,7 +778,7 @@ impl Assembler { // by noop span let else_blk = if else_blk.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest.ensure_node(basic_block_node) + mast_forest.add_node(basic_block_node) } else { self.compile_body(else_blk.iter(), context, None, mast_forest)? }; @@ -786,7 +786,7 @@ impl Assembler { let split_node_id = { let split_node = MastNode::new_split(then_blk, else_blk, mast_forest); - mast_forest.ensure_node(split_node) + mast_forest.add_node(split_node) }; mast_node_ids.push(split_node_id); } @@ -816,7 +816,7 @@ impl Assembler { let loop_node_id = { let loop_node = MastNode::new_loop(loop_body_node_id, mast_forest); - mast_forest.ensure_node(loop_node) + mast_forest.add_node(loop_node) }; mast_node_ids.push(loop_node_id); } @@ -829,7 +829,10 @@ impl Assembler { Ok(if mast_node_ids.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest.ensure_node(basic_block_node) + let basic_block_node_id = mast_forest.add_node(basic_block_node); + mast_forest.ensure_root(basic_block_node_id); + + basic_block_node_id } else { combine_mast_node_ids(mast_node_ids, mast_forest) }) @@ -890,7 +893,7 @@ fn combine_mast_node_ids( (source_mast_node_iter.next(), source_mast_node_iter.next()) { let join_mast_node = MastNode::new_join(left, right, mast_forest); - let join_mast_node_id = mast_forest.ensure_node(join_mast_node); + let join_mast_node_id = mast_forest.add_node(join_mast_node); mast_node_ids.push(join_mast_node_id); } @@ -899,5 +902,8 @@ fn combine_mast_node_ids( } } - mast_node_ids.remove(0) + let root_id = mast_node_ids.remove(0); + mast_forest.ensure_root(root_id); + + root_id } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 5f99c93aa3..3c6ee31709 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -77,10 +77,10 @@ fn nested_blocks() { // `Assembler::with_kernel_from_module()`. let syscall_foo_node_id = { let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); - let kernel_foo_node_id = expected_mast_forest.ensure_node(kernel_foo_node); + let kernel_foo_node_id = expected_mast_forest.add_node(kernel_foo_node); let syscall_node = MastNode::new_syscall(kernel_foo_node_id, &expected_mast_forest); - expected_mast_forest.ensure_node(syscall_node) + expected_mast_forest.add_node(syscall_node) }; let program = r#" @@ -124,81 +124,81 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); - let basic_block_1_id = expected_mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = expected_mast_forest.add_node(basic_block_1); // Basic block representing the `foo` procedure let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); - let basic_block_2_id = expected_mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = expected_mast_forest.add_node(basic_block_2); let join_node = MastNode::new_join(basic_block_1_id, basic_block_2_id, &expected_mast_forest); - expected_mast_forest.ensure_node(join_node) + expected_mast_forest.add_node(join_node) }; let exec_foo_bar_baz_node_id = { // basic block representing foo::bar.baz procedure let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); - expected_mast_forest.ensure_node(basic_block) + expected_mast_forest.add_node(basic_block) }; let before = { let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); - expected_mast_forest.ensure_node(before_node) + expected_mast_forest.add_node(before_node) }; let r#true1 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); - expected_mast_forest.ensure_node(r#true_node) + expected_mast_forest.add_node(r#true_node) }; let r#false1 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); - expected_mast_forest.ensure_node(r#false_node) + expected_mast_forest.add_node(r#false_node) }; let r#if1 = { let r#if_node = MastNode::new_split(r#true1, r#false1, &expected_mast_forest); - expected_mast_forest.ensure_node(r#if_node) + expected_mast_forest.add_node(r#if_node) }; let r#true3 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); - expected_mast_forest.ensure_node(r#true_node) + expected_mast_forest.add_node(r#true_node) }; let r#false3 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); - expected_mast_forest.ensure_node(r#false_node) + expected_mast_forest.add_node(r#false_node) }; let r#true2 = { let r#if_node = MastNode::new_split(r#true3, r#false3, &expected_mast_forest); - expected_mast_forest.ensure_node(r#if_node) + expected_mast_forest.add_node(r#if_node) }; let r#while = { let push_basic_block_id = { let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); - expected_mast_forest.ensure_node(push_basic_block) + expected_mast_forest.add_node(push_basic_block) }; let body_node_id = { let body_node = MastNode::new_join(exec_bar_node_id, push_basic_block_id, &expected_mast_forest); - expected_mast_forest.ensure_node(body_node) + expected_mast_forest.add_node(body_node) }; let loop_node = MastNode::new_loop(body_node_id, &expected_mast_forest); - expected_mast_forest.ensure_node(loop_node) + expected_mast_forest.add_node(loop_node) }; let push_13_basic_block_id = { let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); - expected_mast_forest.ensure_node(node) + expected_mast_forest.add_node(node) }; let r#false2 = { let node = MastNode::new_join(push_13_basic_block_id, r#while, &expected_mast_forest); - expected_mast_forest.ensure_node(node) + expected_mast_forest.add_node(node) }; let nested = { let node = MastNode::new_split(r#true2, r#false2, &expected_mast_forest); - expected_mast_forest.ensure_node(node) + expected_mast_forest.add_node(node) }; let combined_node_id = combine_mast_node_ids( diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 070e36ef40..11bedb99ba 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -1,6 +1,6 @@ use core::{fmt, ops::Index}; -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::vec::Vec; use miden_crypto::hash::rpo::RpoDigest; mod node; @@ -20,6 +20,7 @@ pub trait MerkleTreeNode { fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; } +// TODOP: Remove `PartialEq/Eq` impls /// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user /// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. /// @@ -40,9 +41,11 @@ impl fmt::Display for MastNodeId { #[derive(Clone, Debug, Default)] pub struct MastForest { - /// All of the blocks local to the trees comprising the MAST forest + /// All of the blocks local to the trees comprising the MAST forest. nodes: Vec, - node_id_by_hash: BTreeMap, + + /// Roots of all procedures defined within this MAST forest. + roots: Vec, /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this /// node to all other roots in the forest. This corresponds to the executable entry point. @@ -63,26 +66,23 @@ impl MastForest { /// Mutators impl MastForest { /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. - /// - /// If a [`MastNode`] which is equal to the current node was previously added, the previously - /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal - /// [`MastNode`]s have equal [`MastNodeId`]s. - pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId { - let node_digest = node.digest(); - - if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { - // node already exists in the forest; return previously assigned id - *node_id - } else { - let new_node_id = - MastNodeId(self.nodes.len().try_into().expect( - "invalid node id: exceeded maximum number of nodes in a single forest", - )); - - self.node_id_by_hash.insert(node_digest, new_node_id); - self.nodes.push(node); - - new_node_id + pub fn add_node(&mut self, node: MastNode) -> MastNodeId { + let new_node_id = MastNodeId( + self.nodes + .len() + .try_into() + .expect("invalid node id: exceeded maximum number of nodes in a single forest"), + ); + + self.nodes.push(node); + + new_node_id + } + + // TODOP: Document + pub fn ensure_root(&mut self, new_root_id: MastNodeId) { + if !self.roots.contains(&new_root_id) { + self.roots.push(new_root_id); } } @@ -92,15 +92,17 @@ impl MastForest { /// must be present in this forest. pub fn set_kernel(&mut self, kernel: Kernel) { #[cfg(debug_assertions)] - for proc_hash in kernel.proc_hashes() { - assert!(self.node_id_by_hash.contains_key(proc_hash)); + for &proc_hash in kernel.proc_hashes() { + assert!(self.find_root(proc_hash).is_some()); } self.kernel = kernel; } - /// Sets the entrypoint for this forest. + /// Sets the entrypoint for this forest. This also ensures that the entrypoint is a root in the + /// forest. pub fn set_entrypoint(&mut self, entrypoint: MastNodeId) { + self.ensure_root(entrypoint); self.entrypoint = Some(entrypoint); } } @@ -133,13 +135,10 @@ impl MastForest { self.nodes.get(idx) } - /// Returns the [`MastNodeId`] associated with a given digest, if any. - /// - /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the - /// forest that hashes to this digest, then its id is returned. + /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any. #[inline(always)] - pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { - self.node_id_by_hash.get(&digest).copied() + pub fn find_root(&self, digest: RpoDigest) -> Option { + self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied() } } diff --git a/core/src/program.rs b/core/src/program.rs index 73e9fcf4be..6b7c8bdb90 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -68,7 +68,7 @@ impl Program { /// forest that hashes to this digest, then its id is returned. #[inline(always)] pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { - self.mast_forest.get_node_id_by_digest(digest) + self.mast_forest.find_root(digest) } } diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index c8acfbd41e..b0736ea74e 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -161,7 +161,7 @@ fn build_bar_hash() -> [u64; 4] { let mut mast_forest = MastForest::new(); let foo_root = MastNode::new_basic_block(vec![Operation::Caller]); - let foo_root_id = mast_forest.ensure_node(foo_root); + let foo_root_id = mast_forest.add_node(foo_root); let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest); let bar_hash: Word = bar_root.digest().into(); diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 4462ad6bef..912aeadc28 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -249,19 +249,19 @@ fn hash_memoization_control_blocks() { let mut mast_forest = MastForest::new(); let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)]); - let t_branch_id = mast_forest.ensure_node(t_branch.clone()); + let t_branch_id = mast_forest.add_node(t_branch.clone()); let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]); - let f_branch_id = mast_forest.ensure_node(f_branch.clone()); + let f_branch_id = mast_forest.add_node(f_branch.clone()); let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split1_id = mast_forest.ensure_node(split1.clone()); + let split1_id = mast_forest.add_node(split1.clone()); let split2 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split2_id = mast_forest.ensure_node(split2.clone()); + let split2_id = mast_forest.add_node(split2.clone()); let join_node = MastNode::new_join(split1_id, split2_id, &mast_forest); - let _join_node_id = mast_forest.ensure_node(join_node.clone()); + let _join_node_id = mast_forest.add_node(join_node.clone()); let mut hasher = Hasher::default(); let h1: [Felt; DIGEST_LEN] = split1 @@ -414,19 +414,19 @@ fn hash_memoization_basic_blocks_check(basic_block: MastNode) { let mut mast_forest = MastForest::new(); let basic_block_1 = basic_block.clone(); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Eq, Operation::Not]); - let loop_body_id = mast_forest.ensure_node(loop_body); + let loop_body_id = mast_forest.add_node(loop_body); let loop_block = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_block_id = mast_forest.ensure_node(loop_block.clone()); + let loop_block_id = mast_forest.add_node(loop_block.clone()); let join2_block = MastNode::new_join(basic_block_1_id, loop_block_id, &mast_forest); - let join2_block_id = mast_forest.ensure_node(join2_block.clone()); + let join2_block_id = mast_forest.add_node(join2_block.clone()); let basic_block_2 = basic_block; - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let join1_block = MastNode::new_join(join2_block_id, basic_block_2_id, &mast_forest); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index a7c6c526e7..4c902c84d0 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -120,7 +120,7 @@ fn build_trace( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); mast_forest.try_into().unwrap() diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index a0474cb398..840b959e43 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -50,7 +50,7 @@ fn basic_block_one_group() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -97,7 +97,7 @@ fn basic_block_small() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -161,7 +161,7 @@ fn basic_block() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -254,7 +254,7 @@ fn span_block_with_respan() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.ensure_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -328,11 +328,11 @@ fn join_node() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); - let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()); let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); - let join_node_id = mast_forest.ensure_node(join_node); + let join_node_id = mast_forest.add_node(join_node); mast_forest.set_entrypoint(join_node_id); Program::new(mast_forest).unwrap() @@ -395,11 +395,11 @@ fn split_node_true() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); - let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()); let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.ensure_node(split_node); + let split_node_id = mast_forest.add_node(split_node); mast_forest.set_entrypoint(split_node_id); Program::new(mast_forest).unwrap() @@ -449,11 +449,11 @@ fn split_node_false() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.ensure_node(basic_block1.clone()); - let basic_block2_id = mast_forest.ensure_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()); let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.ensure_node(split_node); + let split_node_id = mast_forest.add_node(split_node); mast_forest.set_entrypoint(split_node_id); Program::new(mast_forest).unwrap() @@ -505,10 +505,10 @@ fn loop_node() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node); mast_forest.set_entrypoint(loop_node_id); Program::new(mast_forest).unwrap() @@ -559,10 +559,10 @@ fn loop_node_skip() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node); mast_forest.set_entrypoint(loop_node_id); Program::new(mast_forest).unwrap() @@ -603,10 +603,10 @@ fn loop_node_repeat() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.ensure_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node); mast_forest.set_entrypoint(loop_node_id); Program::new(mast_forest).unwrap() @@ -693,24 +693,24 @@ fn call_block() { Operation::FmpUpdate, Operation::Pad, ]); - let first_basic_block_id = mast_forest.ensure_node(first_basic_block.clone()); + let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()); let foo_root_node = MastNode::new_basic_block(vec![ Operation::Push(ONE), Operation::FmpUpdate ]); - let foo_root_node_id = mast_forest.ensure_node(foo_root_node.clone()); + let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); - let last_basic_block_id = mast_forest.ensure_node(last_basic_block.clone()); + let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()); let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest); - let foo_call_node_id = mast_forest.ensure_node(foo_call_node.clone()); + let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()); let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest); - let join1_node_id = mast_forest.ensure_node(join1_node.clone()); + let join1_node_id = mast_forest.add_node(join1_node.clone()); let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); - let program_root_id = mast_forest.ensure_node(program_root); + let program_root_id = mast_forest.add_node(program_root); mast_forest.set_entrypoint(program_root_id); let program = Program::new(mast_forest).unwrap(); @@ -904,19 +904,21 @@ fn syscall_block() { // build foo procedure body let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); - let foo_root_id = mast_forest.ensure_node(foo_root.clone()); + let foo_root_id = mast_forest.add_node(foo_root.clone()); + mast_forest.ensure_root(foo_root_id); let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); mast_forest.set_kernel(kernel.clone()); // build bar procedure body let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); - let bar_basic_block_id = mast_forest.ensure_node(bar_basic_block.clone()); + let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()); let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest); - let foo_call_node_id = mast_forest.ensure_node(foo_call_node.clone()); + let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()); let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest); - let bar_root_node_id = mast_forest.ensure_node(bar_root_node.clone()); + let bar_root_node_id = mast_forest.add_node(bar_root_node.clone()); + mast_forest.ensure_root(bar_root_node_id); // build the program let first_basic_block = MastNode::new_basic_block(vec![ @@ -924,19 +926,19 @@ fn syscall_block() { Operation::FmpUpdate, Operation::Pad, ]); - let first_basic_block_id = mast_forest.ensure_node(first_basic_block.clone()); + let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()); let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); - let last_basic_block_id = mast_forest.ensure_node(last_basic_block.clone()); + let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()); let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest); - let bar_call_node_id = mast_forest.ensure_node(bar_call_node.clone()); + let bar_call_node_id = mast_forest.add_node(bar_call_node.clone()); let inner_join_node = MastNode::new_join(first_basic_block_id, bar_call_node_id, &mast_forest); - let inner_join_node_id = mast_forest.ensure_node(inner_join_node.clone()); + let inner_join_node_id = mast_forest.add_node(inner_join_node.clone()); let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); - let program_root_node_id = mast_forest.ensure_node(program_root_node.clone()); + let program_root_node_id = mast_forest.add_node(program_root_node.clone()); mast_forest.set_entrypoint(program_root_node_id); let program = Program::new(mast_forest).unwrap(); @@ -1192,23 +1194,24 @@ fn dyn_block() { let mut mast_forest = MastForest::new(); let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]); - let _foo_root_node_id = mast_forest.ensure_node(foo_root_node.clone()); + let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); + mast_forest.ensure_root(foo_root_node_id); let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]); - let mul_bb_node_id = mast_forest.ensure_node(mul_bb_node.clone()); + let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()); let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]); - let save_bb_node_id = mast_forest.ensure_node(save_bb_node.clone()); + let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()); let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest); - let join_node_id = mast_forest.ensure_node(join_node.clone()); + let join_node_id = mast_forest.add_node(join_node.clone()); // This dyn will point to foo. let dyn_node = MastNode::new_dynexec(); - let dyn_node_id = mast_forest.ensure_node(dyn_node.clone()); + let dyn_node_id = mast_forest.add_node(dyn_node.clone()); let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); - let program_root_node_id = mast_forest.ensure_node(program_root_node.clone()); + let program_root_node_id = mast_forest.add_node(program_root_node.clone()); mast_forest.set_entrypoint(program_root_node_id); let program = Program::new(mast_forest).unwrap(); @@ -1317,7 +1320,7 @@ fn set_user_op_helpers_many() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); mast_forest.try_into().unwrap() diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 19bfd34886..620397dbf5 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -51,7 +51,7 @@ pub fn b_chip_span() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -125,7 +125,7 @@ pub fn b_chip_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -218,13 +218,13 @@ pub fn b_chip_merge() { let mut mast_forest = MastForest::new(); let t_branch = MastNode::new_basic_block(vec![Operation::Add]); - let t_branch_id = mast_forest.ensure_node(t_branch); + let t_branch_id = mast_forest.add_node(t_branch); let f_branch = MastNode::new_basic_block(vec![Operation::Mul]); - let f_branch_id = mast_forest.ensure_node(f_branch); + let f_branch_id = mast_forest.add_node(f_branch); let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); + let split_id = mast_forest.add_node(split); mast_forest.set_entrypoint(split_id); @@ -339,7 +339,7 @@ pub fn b_chip_permutation() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index cbb4cf7e6c..32a448ddff 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -73,13 +73,13 @@ fn decoder_p1_join() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join); + let join_id = mast_forest.add_node(join); mast_forest.set_entrypoint(join_id); @@ -146,13 +146,13 @@ fn decoder_p1_split() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); + let split_id = mast_forest.add_node(split); mast_forest.set_entrypoint(split_id); @@ -206,16 +206,16 @@ fn decoder_p1_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join); + let join_id = mast_forest.add_node(join); let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node); mast_forest.set_entrypoint(loop_node_id); @@ -339,7 +339,7 @@ fn decoder_p2_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); Program::new(mast_forest).unwrap() @@ -376,13 +376,13 @@ fn decoder_p2_join() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join.clone()); + let join_id = mast_forest.add_node(join.clone()); mast_forest.set_entrypoint(join_id); let program = Program::new(mast_forest).unwrap(); @@ -442,13 +442,13 @@ fn decoder_p2_split_true() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); + let split_id = mast_forest.add_node(split); mast_forest.set_entrypoint(split_id); @@ -500,13 +500,13 @@ fn decoder_p2_split_false() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.ensure_node(split); + let split_id = mast_forest.add_node(split); mast_forest.set_entrypoint(split_id); @@ -558,16 +558,16 @@ fn decoder_p2_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.ensure_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.ensure_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.ensure_node(join.clone()); + let join_id = mast_forest.add_node(join.clone()); let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.ensure_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node); mast_forest.set_entrypoint(loop_node_id); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index adc83f5245..6a57960d47 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -35,7 +35,7 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); let program = Program::new(mast_forest).unwrap(); @@ -58,7 +58,7 @@ pub fn build_trace_from_ops_with_inputs( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.ensure_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block); mast_forest.set_entrypoint(basic_block_id); let program = Program::new(mast_forest).unwrap(); From 06f421ba6b0498e44e7fd44b956d1bf499bb94c3 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 20 Jun 2024 15:33:17 -0400 Subject: [PATCH 003/172] add TODOP --- core/src/mast/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 11bedb99ba..3f4b7476ff 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -47,6 +47,8 @@ pub struct MastForest { /// Roots of all procedures defined within this MAST forest. roots: Vec, + // TODOP: Move fields to `Program` + /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this /// node to all other roots in the forest. This corresponds to the executable entry point. /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a From bcf2a9b7998999852858903a6b8375867a7da212 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 20 Jun 2024 17:07:54 -0400 Subject: [PATCH 004/172] Add `Host::get_mast_forest` --- core/src/mast/mod.rs | 5 ++ processor/src/errors.rs | 9 ++++ processor/src/host/mast_forest_store.rs | 27 +++++++++++ processor/src/host/mod.rs | 64 +++++++++++++++++++------ processor/src/lib.rs | 12 ++++- 5 files changed, 100 insertions(+), 17 deletions(-) create mode 100644 processor/src/host/mast_forest_store.rs diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 3f4b7476ff..713bf52def 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -142,6 +142,11 @@ impl MastForest { pub fn find_root(&self, digest: RpoDigest) -> Option { self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied() } + + /// Returns an iterator over the digest of the procedures in this MAST forest. + pub fn roots(&self) -> impl Iterator + '_ { + self.roots.iter().map(|&root_id| self[root_id].digest()) + } } impl Index for MastForest { diff --git a/processor/src/errors.rs b/processor/src/errors.rs index fa5d9251ac..38b3be870f 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -51,6 +51,9 @@ pub enum ExecutionError { MastNodeNotFoundInForest { node_id: MastNodeId, }, + MastForestNotFound { + root_digest: Digest, + }, MemoryAddressOutOfBounds(u64), MerklePathVerificationFailed { value: Word, @@ -150,6 +153,12 @@ impl Display for ExecutionError { MastNodeNotFoundInForest { node_id } => { write!(f, "Malformed MAST forest, node id {node_id} doesn't exist") } + MastForestNotFound { root_digest } => { + write!( + f, + "No MAST forest contains the following procedure root digest: {root_digest}" + ) + } MemoryAddressOutOfBounds(addr) => { write!(f, "Memory address cannot exceed 2^32 but was {addr}") } diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs new file mode 100644 index 0000000000..b15a1785a2 --- /dev/null +++ b/processor/src/host/mast_forest_store.rs @@ -0,0 +1,27 @@ +use alloc::{collections::BTreeMap, sync::Arc}; +use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; + +pub trait MastForestStore { + fn get(&self, node_digest: &RpoDigest) -> Option>; +} + +#[derive(Debug, Default, Clone)] +pub struct MemMastForestStore { + mast_forests: BTreeMap>, +} + +impl MemMastForestStore { + pub fn insert(&mut self, mast_forest: MastForest) { + let mast_forest = Arc::new(mast_forest); + + for root in mast_forest.roots() { + self.mast_forests.insert(root, mast_forest.clone()); + } + } +} + +impl MastForestStore for MemMastForestStore { + fn get(&self, node_digest: &RpoDigest) -> Option> { + self.mast_forests.get(node_digest).cloned() + } +} diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index 9642b22e28..d61b372af3 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -1,6 +1,12 @@ use super::{ExecutionError, Felt, ProcessState}; use crate::MemAdviceProvider; -use vm_core::{crypto::merkle::MerklePath, AdviceInjector, DebugOptions, Word}; +use alloc::sync::Arc; +use mast_forest_store::MemMastForestStore; +use vm_core::{ + crypto::{hash::RpoDigest, merkle::MerklePath}, + mast::MastForest, + AdviceInjector, DebugOptions, Word, +}; pub(super) mod advice; use advice::{AdviceExtractor, AdviceProvider}; @@ -8,6 +14,9 @@ use advice::{AdviceExtractor, AdviceProvider}; #[cfg(feature = "std")] mod debug; +mod mast_forest_store; +pub use mast_forest_store::MastForestStore; + // HOST TRAIT // ================================================================================================ @@ -25,19 +34,23 @@ pub trait Host { // -------------------------------------------------------------------------------------------- /// Returns the requested advice, specified by [AdviceExtractor], from the host to the VM. - fn get_advice( + fn get_advice( &mut self, - process: &S, + process: &P, extractor: AdviceExtractor, ) -> Result; /// Sets the requested advice, specified by [AdviceInjector], on the host. - fn set_advice( + fn set_advice( &mut self, - process: &S, + process: &P, injector: AdviceInjector, ) -> Result; + /// Returns MAST forest corresponding to the specified digest, or None if the MAST forest for + /// this digest could not be found in this [Host]. + fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option>; + // PROVIDED METHODS // -------------------------------------------------------------------------------------------- @@ -182,6 +195,10 @@ where H::set_advice(self, process, injector) } + fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option> { + H::get_mast_forest(self, node_digest) + } + fn on_debug( &mut self, process: &S, @@ -264,21 +281,30 @@ impl From for Felt { // ================================================================================================ /// A default [Host] implementation that provides the essential functionality required by the VM. -pub struct DefaultHost { +pub struct DefaultHost { adv_provider: A, + store: S, } -impl Default for DefaultHost { +impl Default for DefaultHost { fn default() -> Self { Self { adv_provider: MemAdviceProvider::default(), + store: MemMastForestStore::default(), } } } -impl DefaultHost { - pub fn new(adv_provider: A) -> Self { - Self { adv_provider } +impl DefaultHost +where + A: AdviceProvider, + S: MastForestStore, +{ + pub fn new(adv_provider: A, store: S) -> Self { + Self { + adv_provider, + store, + } } #[cfg(any(test, feature = "internals"))] @@ -296,20 +322,28 @@ impl DefaultHost { } } -impl Host for DefaultHost { - fn get_advice( +impl Host for DefaultHost +where + A: AdviceProvider, + S: MastForestStore, +{ + fn get_advice( &mut self, - process: &S, + process: &P, extractor: AdviceExtractor, ) -> Result { self.adv_provider.get_advice(process, &extractor) } - fn set_advice( + fn set_advice( &mut self, - process: &S, + process: &P, injector: AdviceInjector, ) -> Result { self.adv_provider.set_advice(process, &injector) } + + fn get_mast_forest(&self, node_digest: &RpoDigest) -> Option> { + self.store.get(node_digest) + } } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 6ca7974f81..34e430b396 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -255,8 +255,16 @@ where MastNode::Loop(node) => self.execute_loop_node(node, program), MastNode::Call(node) => self.execute_call_node(node, program), MastNode::Dyn => self.execute_dyn_node(program), - MastNode::External(_node) => { - todo!() + MastNode::External(node) => { + let mast_forest = self + .host + .borrow() + .get_mast_forest(&node.digest()) + .ok_or_else(|| ExecutionError::MastForestNotFound { + root_digest: node.digest(), + })?; + let node_id = mast_forest.find_root(node.digest()).expect(format!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", node.digest()).as_str()); + self.execute_mast_node(node_id, mast_forest) } } } From 980713d53421fd096840cddcb73422638d234a17 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 08:52:16 -0400 Subject: [PATCH 005/172] Move kernel and entrypoint out of `MastForest` --- assembly/src/assembler/mod.rs | 45 ++++++------------ assembly/src/assembler/tests.rs | 11 +++-- assembly/src/testing.rs | 5 +- core/src/errors.rs | 1 + core/src/mast/mod.rs | 46 ------------------ core/src/mast/tests.rs | 4 +- core/src/program.rs | 49 ++++++++++++-------- miden/README.md | 6 +-- miden/src/cli/data.rs | 2 +- miden/src/examples/blake3.rs | 2 +- miden/src/examples/fibonacci.rs | 2 +- miden/src/repl/mod.rs | 2 +- miden/src/tools/mod.rs | 2 +- processor/src/chiplets/tests.rs | 5 +- processor/src/decoder/tests.rs | 43 ++++++----------- processor/src/lib.rs | 2 +- processor/src/trace/tests/chiplets/hasher.rs | 13 ++---- processor/src/trace/tests/decoder.rs | 30 ++++-------- processor/src/trace/tests/mod.rs | 6 +-- test-utils/src/lib.rs | 2 +- 20 files changed, 93 insertions(+), 185 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index d7c6f2d989..44dc696dc3 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -9,7 +9,6 @@ use crate::{ RpoDigest, Spanned, ONE, ZERO, }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; -use miette::miette; use vm_core::{ mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, @@ -145,7 +144,6 @@ impl Assembler { let (kernel_index, kernel) = assembler.assemble_kernel_module(module, &mut mast_forest)?; assembler.module_graph.set_kernel(Some(kernel_index), kernel); - mast_forest.set_kernel(assembler.module_graph.kernel().clone()); assembler.mast_forest = mast_forest; @@ -313,18 +311,6 @@ impl Assembler { /// Compilation/Assembly impl Assembler { - /// Compiles the provided module into a [`MastForest`]. - /// - /// # Errors - /// - /// Returns an error if parsing or compilation of the specified program fails. - pub fn assemble(self, source: impl Compile) -> Result { - let mut context = AssemblyContext::default(); - context.set_warnings_as_errors(self.warnings_as_errors); - - self.assemble_in_context(source, &mut context) - } - /// Compiles the provided module into a [`Program`]. The resulting program can be executed on /// Miden VM. /// @@ -332,10 +318,11 @@ impl Assembler { /// /// Returns an error if parsing or compilation of the specified program fails, or if the source /// doesn't have an entrypoint. - pub fn assemble_program(self, source: impl Compile) -> Result { - let mast_forest = self.assemble(source)?; + pub fn assemble(self, source: impl Compile) -> Result { + let mut context = AssemblyContext::default(); + context.set_warnings_as_errors(self.warnings_as_errors); - mast_forest.try_into().map_err(|program_err| miette!("{program_err}")) + self.assemble_in_context(source, &mut context) } /// Like [Assembler::compile], but also takes an [AssemblyContext] to configure the assembler. @@ -343,7 +330,7 @@ impl Assembler { self, source: impl Compile, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { let opts = CompileOptions { warnings_as_errors: context.warnings_as_errors(), ..CompileOptions::default() @@ -363,7 +350,7 @@ impl Assembler { self, source: impl Compile, options: CompileOptions, - ) -> Result { + ) -> Result { let mut context = AssemblyContext::default(); context.set_warnings_as_errors(options.warnings_as_errors); @@ -378,7 +365,7 @@ impl Assembler { source: impl Compile, options: CompileOptions, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { self.assemble_with_options_in_context_impl(source, options, context) } @@ -391,14 +378,14 @@ impl Assembler { source: impl Compile, options: CompileOptions, context: &mut AssemblyContext, - ) -> Result { + ) -> Result { if options.kind != ModuleKind::Executable { return Err(Report::msg( "invalid compile options: assemble_with_opts_in_context requires that the kind be 'executable'", )); } - let mut mast_forest = core::mem::take(&mut self.mast_forest); + let mast_forest = core::mem::take(&mut self.mast_forest); let program = source.compile_with_options(CompileOptions { // Override the module name so that we always compile the executable @@ -428,9 +415,7 @@ impl Assembler { }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; - self.compile_program(entrypoint, context, &mut mast_forest)?; - - Ok(mast_forest) + self.compile_program(entrypoint, context, mast_forest) } /// Compile and assembles all procedures in the specified module, adding them to the procedure @@ -585,17 +570,15 @@ impl Assembler { &mut self, entrypoint: GlobalProcedureIndex, context: &mut AssemblyContext, - mast_forest: &mut MastForest, - ) -> Result<(), Report> { + mut mast_forest: MastForest, + ) -> Result { // Raise an error if we are called with an invalid entrypoint assert!(self.module_graph[entrypoint].name().is_main()); // Compile the module graph rooted at the entrypoint - let entry_procedure = self.compile_subgraph(entrypoint, true, context, mast_forest)?; + let entry_procedure = self.compile_subgraph(entrypoint, true, context, &mut mast_forest)?; - mast_forest.set_entrypoint(entry_procedure.body_node_id()); - - Ok(()) + Ok(Program::new_with_kernel(mast_forest, entry_procedure.body_node_id(), self.module_graph.kernel().clone())) } /// Compile all of the uncompiled procedures in the module graph, placing them diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 3c6ee31709..c39b91113d 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,5 +1,8 @@ use alloc::{boxed::Box, vec::Vec}; -use vm_core::mast::{MastForest, MastNode, MerkleTreeNode}; +use vm_core::{ + mast::{MastForest, MastNode}, + Program, +}; use super::{Assembler, Library, Operation}; use crate::{ @@ -205,9 +208,7 @@ fn nested_blocks() { vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], &mut expected_mast_forest, ); - expected_mast_forest.set_entrypoint(combined_node_id); - - let combined_node = &expected_mast_forest[combined_node_id]; + let expected_program = Program::new(expected_mast_forest, combined_node_id); - assert_eq!(combined_node.digest(), program.entrypoint_digest().unwrap()); + assert_eq!(expected_program.hash(), program.hash()); } diff --git a/assembly/src/testing.rs b/assembly/src/testing.rs index 4bc3e20310..a9b99051a8 100644 --- a/assembly/src/testing.rs +++ b/assembly/src/testing.rs @@ -308,10 +308,7 @@ impl TestContext { /// module represented in `source`. #[track_caller] pub fn assemble(&mut self, source: impl Compile) -> Result { - self.assembler - .clone() - .assemble(source) - .map(|mast_forest| mast_forest.try_into().unwrap()) + self.assembler.clone().assemble(source) } /// Compile a module from `source`, with the fully-qualified name `path`, to MAST, returning diff --git a/core/src/errors.rs b/core/src/errors.rs index a3c01446c6..95294f522f 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -46,6 +46,7 @@ pub enum KernelError { #[derive(Clone, Debug, thiserror::Error)] pub enum ProgramError { + // TODOP: REMOVE #[error("tried to create a program from a MAST forest with no entrypoint")] NoEntrypoint, } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 3f4b7476ff..63c72df101 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -9,8 +9,6 @@ pub use node::{ OpBatch, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; -use crate::Kernel; - #[cfg(test)] mod tests; @@ -46,15 +44,6 @@ pub struct MastForest { /// Roots of all procedures defined within this MAST forest. roots: Vec, - - // TODOP: Move fields to `Program` - - /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this - /// node to all other roots in the forest. This corresponds to the executable entry point. - /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a - /// MAST which represents a library. - entrypoint: Option, - kernel: Kernel, } /// Constructors @@ -87,45 +76,10 @@ impl MastForest { self.roots.push(new_root_id); } } - - /// Sets the kernel for this forest. - /// - /// The kernel MUST have been compiled using this [`MastForest`]; that is, all kernel procedures - /// must be present in this forest. - pub fn set_kernel(&mut self, kernel: Kernel) { - #[cfg(debug_assertions)] - for &proc_hash in kernel.proc_hashes() { - assert!(self.find_root(proc_hash).is_some()); - } - - self.kernel = kernel; - } - - /// Sets the entrypoint for this forest. This also ensures that the entrypoint is a root in the - /// forest. - pub fn set_entrypoint(&mut self, entrypoint: MastNodeId) { - self.ensure_root(entrypoint); - self.entrypoint = Some(entrypoint); - } } /// Public accessors impl MastForest { - /// Returns the kernel associated with this forest. - pub fn kernel(&self) -> &Kernel { - &self.kernel - } - - /// Returns the entrypoint associated with this forest, if any. - pub fn entrypoint(&self) -> Option { - self.entrypoint - } - - /// A convenience method that provides the hash of the entrypoint, if any. - pub fn entrypoint_digest(&self) -> Option { - self.entrypoint.map(|entrypoint| self[entrypoint].digest()) - } - /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// diff --git a/core/src/mast/tests.rs b/core/src/mast/tests.rs index 5c4e54e738..c515cc98a6 100644 --- a/core/src/mast/tests.rs +++ b/core/src/mast/tests.rs @@ -1,7 +1,5 @@ use crate::{ - chiplets::hasher, - mast::{DynNode, Kernel, MerkleTreeNode}, - ProgramInfo, Word, + chiplets::hasher, mast::{DynNode, MerkleTreeNode}, Kernel, ProgramInfo, Word }; use alloc::vec::Vec; use miden_crypto::{hash::rpo::RpoDigest, Felt}; diff --git a/core/src/program.rs b/core/src/program.rs index 6b7c8bdb90..eee885a8d5 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -5,8 +5,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{ - errors::ProgramError, - mast::{MastForest, MastNode, MastNodeId}, + mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, utils::ToElements, }; @@ -18,15 +17,34 @@ use super::Kernel; #[derive(Clone, Debug)] pub struct Program { mast_forest: MastForest, + /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this + /// node to all other roots in the forest. This corresponds to the executable entry point. + /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a + /// MAST which represents a library. + entrypoint: MastNodeId, + kernel: Kernel, } /// Constructors impl Program { - pub fn new(mast_forest: MastForest) -> Result { - if mast_forest.entrypoint().is_some() { - Ok(Self { mast_forest }) - } else { - Err(ProgramError::NoEntrypoint) + // TODOP: Document + pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { + debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + + Self { + mast_forest, + entrypoint, + kernel: Kernel::default(), + } + } + + pub fn new_with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { + debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + + Self { + mast_forest, + entrypoint, + kernel, } } } @@ -40,17 +58,17 @@ impl Program { /// Returns the kernel associated with this program. pub fn kernel(&self) -> &Kernel { - self.mast_forest.kernel() + &self.kernel } /// Returns the entrypoint associated with this program. pub fn entrypoint(&self) -> MastNodeId { - self.mast_forest.entrypoint().unwrap() + self.entrypoint } /// A convenience method that provides the hash of the entrypoint. pub fn hash(&self) -> RpoDigest { - self.mast_forest.entrypoint_digest().unwrap() + self.mast_forest[self.entrypoint].digest() } /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else @@ -62,12 +80,13 @@ impl Program { self.mast_forest.get_node_by_id(node_id) } + // TODOP: fix docs /// Returns the [`MastNodeId`] associated with a given digest, if any. /// /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the /// forest that hashes to this digest, then its id is returned. #[inline(always)] - pub fn get_node_id_by_digest(&self, digest: RpoDigest) -> Option { + pub fn find_root(&self, digest: RpoDigest) -> Option { self.mast_forest.find_root(digest) } } @@ -96,14 +115,6 @@ impl fmt::Display for Program { } } -impl TryFrom for Program { - type Error = ProgramError; - - fn try_from(mast_forest: MastForest) -> Result { - Self::new(mast_forest) - } -} - impl From for MastForest { fn from(program: Program) -> Self { program.mast_forest diff --git a/miden/README.md b/miden/README.md index 80e87ad5e9..7176467b56 100644 --- a/miden/README.md +++ b/miden/README.md @@ -57,7 +57,7 @@ use processor::ExecutionOptions; let mut assembler = Assembler::default(); // compile Miden assembly source code into a program -let program = assembler.assemble_program("begin push.3 push.5 add end").unwrap(); +let program = assembler.assemble("begin push.3 push.5 add end").unwrap(); // use an empty list as initial stack let stack_inputs = StackInputs::default(); @@ -105,7 +105,7 @@ use miden_vm::{Assembler, DefaultHost, ProvingOptions, Program, prove, StackInpu let mut assembler = Assembler::default(); // this is our program, we compile it from assembly code -let program = assembler.assemble_program("begin push.3 push.5 add end").unwrap(); +let program = assembler.assemble("begin push.3 push.5 add end").unwrap(); // let's execute it and generate a STARK proof let (outputs, proof) = prove( @@ -193,7 +193,7 @@ let source = format!( n - 1 ); let mut assembler = Assembler::default(); -let program = assembler.assemble_program(&source).unwrap(); +let program = assembler.assemble(&source).unwrap(); // initialize a default host (with an empty advice provider) let host = DefaultHost::default(); diff --git a/miden/src/cli/data.rs b/miden/src/cli/data.rs index 1d797b0c7b..248e9579a9 100644 --- a/miden/src/cli/data.rs +++ b/miden/src/cli/data.rs @@ -420,7 +420,7 @@ impl ProgramFile { .wrap_err("Failed to load libraries")?; let program: Program = assembler - .assemble_program(self.ast.as_ref()) + .assemble(self.ast.as_ref()) .wrap_err("Failed to compile program")?; Ok(program) diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 7bae853923..2e87cadbcc 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -47,7 +47,7 @@ fn generate_blake3_program(n: usize) -> Program { Assembler::default() .with_library(&StdLibrary::default()) .unwrap() - .assemble_program(program) + .assemble(program) .unwrap() } diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index 4629f960c2..2dcd778747 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -41,7 +41,7 @@ fn generate_fibonacci_program(n: usize) -> Program { n - 1 ); - Assembler::default().assemble_program(program).unwrap() + Assembler::default().assemble(program).unwrap() } /// Computes the `n`-th term of Fibonacci sequence diff --git a/miden/src/repl/mod.rs b/miden/src/repl/mod.rs index aea32e62ea..a48bfdcfc7 100644 --- a/miden/src/repl/mod.rs +++ b/miden/src/repl/mod.rs @@ -293,7 +293,7 @@ fn execute( .with_libraries(provided_libraries.iter()) .map_err(|err| format!("{err}"))?; - let program = assembler.assemble_program(program).map_err(|err| format!("{err}"))?; + let program = assembler.assemble(program).map_err(|err| format!("{err}"))?; let stack_inputs = StackInputs::default(); let host = DefaultHost::default(); diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index fce3d34dc8..d4bb16c3b8 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -216,7 +216,7 @@ where let program = Assembler::default() .with_debug_mode(true) .with_library(&StdLibrary::default())? - .assemble_program(program)?; + .assemble(program)?; let mut execution_details = ExecutionDetails::default(); let vm_state_iterator = processor::execute_iter(&program, stack_inputs, host); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index 4c902c84d0..8353376989 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -14,7 +14,7 @@ use miden_air::trace::{ }; use vm_core::{ mast::{MastForest, MastNode}, - Felt, ONE, ZERO, + Felt, Program, ONE, ZERO, }; type ChipletsTrace = [Vec; CHIPLETS_WIDTH]; @@ -121,9 +121,8 @@ fn build_trace( let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - mast_forest.try_into().unwrap() + Program::new(mast_forest, basic_block_id) }; process.execute(&program).unwrap(); diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 840b959e43..96fd6e31a8 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -51,9 +51,8 @@ fn basic_block_one_group() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -98,9 +97,8 @@ fn basic_block_small() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -162,9 +160,8 @@ fn basic_block() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -255,9 +252,8 @@ fn span_block_with_respan() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -333,9 +329,8 @@ fn join_node() { let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); let join_node_id = mast_forest.add_node(join_node); - mast_forest.set_entrypoint(join_node_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, join_node_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -400,9 +395,8 @@ fn split_node_true() { let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); let split_node_id = mast_forest.add_node(split_node); - mast_forest.set_entrypoint(split_node_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_node_id) }; let (trace, trace_len) = build_trace(&[1], &program); @@ -454,9 +448,8 @@ fn split_node_false() { let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); let split_node_id = mast_forest.add_node(split_node); - mast_forest.set_entrypoint(split_node_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -509,9 +502,8 @@ fn loop_node() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -563,9 +555,8 @@ fn loop_node_skip() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -607,9 +598,8 @@ fn loop_node_repeat() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -711,9 +701,8 @@ fn call_block() { let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); let program_root_id = mast_forest.add_node(program_root); - mast_forest.set_entrypoint(program_root_id); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, program_root_id); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, Kernel::default()); @@ -907,7 +896,6 @@ fn syscall_block() { let foo_root_id = mast_forest.add_node(foo_root.clone()); mast_forest.ensure_root(foo_root_id); let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); - mast_forest.set_kernel(kernel.clone()); // build bar procedure body let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); @@ -939,9 +927,8 @@ fn syscall_block() { let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - mast_forest.set_entrypoint(program_root_node_id); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new_with_kernel(mast_forest, program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel); @@ -1212,9 +1199,8 @@ fn dyn_block() { let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - mast_forest.set_entrypoint(program_root_node_id); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, program_root_node_id); let (trace, trace_len) = build_dyn_trace( &[ @@ -1321,9 +1307,8 @@ fn set_user_op_helpers_many() { let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - mast_forest.try_into().unwrap() + Program::new(mast_forest, basic_block_id) }; let a = rand_value::(); let b = rand_value::(); diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 6ca7974f81..ebbee12773 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -375,7 +375,7 @@ where // get dynamic code from the code block table and execute it let callee_id = program - .get_node_id_by_digest(callee_hash.into()) + .find_root(callee_hash.into()) .ok_or_else(|| ExecutionError::DynamicNodeNotFound(callee_hash.into()))?; self.execute_mast_node(callee_id, program)?; diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 620397dbf5..992cee60d4 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -52,9 +52,8 @@ pub fn b_chip_span() { let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -126,9 +125,8 @@ pub fn b_chip_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -226,9 +224,7 @@ pub fn b_chip_merge() { let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); let split_id = mast_forest.add_node(split); - mast_forest.set_entrypoint(split_id); - - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_id) }; let trace = build_trace_from_program(&program, &[]); @@ -340,9 +336,8 @@ pub fn b_chip_permutation() { let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; let trace = build_trace_from_program(&program, &stack); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 32a448ddff..3b20428865 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -81,9 +81,7 @@ fn decoder_p1_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join); - mast_forest.set_entrypoint(join_id); - - Program::new(mast_forest).unwrap() + Program::new(mast_forest, join_id) }; let trace = build_trace_from_program(&program, &[]); @@ -154,9 +152,7 @@ fn decoder_p1_split() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - mast_forest.set_entrypoint(split_id); - - Program::new(mast_forest).unwrap() + Program::new(mast_forest, split_id) }; let trace = build_trace_from_program(&program, &[1]); @@ -217,9 +213,7 @@ fn decoder_p1_loop_with_repeat() { let loop_node = MastNode::new_loop(join_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); - - Program::new(mast_forest).unwrap() + Program::new(mast_forest, loop_node_id) }; let trace = build_trace_from_program(&program, &[0, 1, 1]); @@ -340,9 +334,8 @@ fn decoder_p2_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - Program::new(mast_forest).unwrap() + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -383,9 +376,8 @@ fn decoder_p2_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join.clone()); - mast_forest.set_entrypoint(join_id); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, join_id); let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -450,9 +442,7 @@ fn decoder_p2_split_true() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - mast_forest.set_entrypoint(split_id); - - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, split_id); // build trace from program let trace = build_trace_from_program(&program, &[1]); @@ -508,9 +498,7 @@ fn decoder_p2_split_false() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - mast_forest.set_entrypoint(split_id); - - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, split_id); // build trace from program let trace = build_trace_from_program(&program, &[0]); @@ -569,9 +557,7 @@ fn decoder_p2_loop_with_repeat() { let loop_node = MastNode::new_loop(join_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - mast_forest.set_entrypoint(loop_node_id); - - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, loop_node_id); // build trace from program let trace = build_trace_from_program(&program, &[0, 1, 1]); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 6a57960d47..19c11defbc 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -36,9 +36,8 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, basic_block_id); build_trace_from_program(&program, stack) } @@ -59,9 +58,8 @@ pub fn build_trace_from_ops_with_inputs( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - mast_forest.set_entrypoint(basic_block_id); - let program = Program::new(mast_forest).unwrap(); + let program = Program::new(mast_forest, basic_block_id); process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 95929813a7..ce8b28f172 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -297,7 +297,7 @@ impl Test { .with_libraries(self.libraries.iter()) .expect("failed to load stdlib"); - assembler.assemble_program(self.source.clone()) + assembler.assemble(self.source.clone()) } /// Compiles the test's source to a Program and executes it with the tests inputs. Returns a From 32e757e3a2ea4a04513433f870486ac310c8c75c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 09:48:59 -0400 Subject: [PATCH 006/172] Remove ProgramError --- core/src/errors.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index 95294f522f..5e4d0428e1 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -40,13 +40,3 @@ pub enum KernelError { #[error("kernel can have at most {0} procedures, received {1}")] TooManyProcedures(usize, usize), } - -// PROGRAM ERROR -// ================================================================================================ - -#[derive(Clone, Debug, thiserror::Error)] -pub enum ProgramError { - // TODOP: REMOVE - #[error("tried to create a program from a MAST forest with no entrypoint")] - NoEntrypoint, -} From f87217a8308f2dad7e5dada6b3138c44ea695fdc Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 10:08:20 -0400 Subject: [PATCH 007/172] docs --- assembly/src/assembler/mod.rs | 4 ++-- core/src/mast/mod.rs | 21 ++++++++++++--------- processor/src/decoder/tests.rs | 6 +++--- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 0dea527ce8..69502de589 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -817,7 +817,7 @@ impl Assembler { Ok(if mast_node_ids.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); let basic_block_node_id = mast_forest.add_node(basic_block_node); - mast_forest.ensure_root(basic_block_node_id); + mast_forest.make_root(basic_block_node_id); basic_block_node_id } else { @@ -890,7 +890,7 @@ fn combine_mast_node_ids( } let root_id = mast_node_ids.remove(0); - mast_forest.ensure_root(root_id); + mast_forest.make_root(root_id); root_id } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index ee7c90dad9..f98bb17cbf 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -18,13 +18,12 @@ pub trait MerkleTreeNode { fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; } -// TODOP: Remove `PartialEq/Eq` impls /// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user /// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. /// -/// Note that since a [`MastForest`] enforces the invariant that equal [`MastNode`]s MUST have equal -/// [`MastNodeId`]s, [`MastNodeId`] equality can be used to determine equality of the underlying -/// [`MastNode`]. +/// Note that the [`MastForest`] does *not* ensure that equal [`MastNode`]s have equal +/// [`MastNodeId`] handles. Hence, [`MastNodeId`] equality must not be used to test for equality of +/// the underlying [`MastNode`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct MastNodeId(u32); @@ -37,12 +36,16 @@ impl fmt::Display for MastNodeId { // MAST FOREST // =============================================================================================== +/// Represents one or more procedures, represented as a collection of [`MastNode`]s. +/// +/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] +/// can be built from a [`MastForest`] to specify an entrypoint. #[derive(Clone, Debug, Default)] pub struct MastForest { - /// All of the blocks local to the trees comprising the MAST forest. + /// All of the nodes local to the trees comprising the MAST forest. nodes: Vec, - /// Roots of all procedures defined within this MAST forest. + /// Roots of procedures defined within this MAST forest. roots: Vec, } @@ -56,7 +59,7 @@ impl MastForest { /// Mutators impl MastForest { - /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. + /// Adds a node to the forest, and returns the associated [`MastNodeId`]. pub fn add_node(&mut self, node: MastNode) -> MastNodeId { let new_node_id = MastNodeId( self.nodes @@ -70,8 +73,8 @@ impl MastForest { new_node_id } - // TODOP: Document - pub fn ensure_root(&mut self, new_root_id: MastNodeId) { + /// Marks the given [`MastNodeId`] as being the root of a procedure. + pub fn make_root(&mut self, new_root_id: MastNodeId) { if !self.roots.contains(&new_root_id) { self.roots.push(new_root_id); } diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 96fd6e31a8..133be1f7a4 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -894,7 +894,7 @@ fn syscall_block() { // build foo procedure body let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); let foo_root_id = mast_forest.add_node(foo_root.clone()); - mast_forest.ensure_root(foo_root_id); + mast_forest.make_root(foo_root_id); let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); // build bar procedure body @@ -906,7 +906,7 @@ fn syscall_block() { let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest); let bar_root_node_id = mast_forest.add_node(bar_root_node.clone()); - mast_forest.ensure_root(bar_root_node_id); + mast_forest.make_root(bar_root_node_id); // build the program let first_basic_block = MastNode::new_basic_block(vec![ @@ -1182,7 +1182,7 @@ fn dyn_block() { let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]); let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); - mast_forest.ensure_root(foo_root_node_id); + mast_forest.make_root(foo_root_node_id); let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]); let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()); From 71f35f1a667519d2e0bc381052c8d208eb035f42 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 10:13:31 -0400 Subject: [PATCH 008/172] cleanup Program constructors --- assembly/src/assembler/mod.rs | 2 +- assembly/src/assembler/tests.rs | 2 +- core/src/mast/mod.rs | 2 +- core/src/program.rs | 27 +------------------ processor/src/chiplets/tests.rs | 2 +- processor/src/decoder/tests.rs | 28 ++++++++++---------- processor/src/lib.rs | 2 +- processor/src/trace/tests/chiplets/hasher.rs | 8 +++--- processor/src/trace/tests/decoder.rs | 16 +++++------ processor/src/trace/tests/mod.rs | 4 +-- 10 files changed, 34 insertions(+), 59 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 69502de589..8f1a3ec79f 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -579,7 +579,7 @@ impl Assembler { let entry_procedure = self.compile_subgraph(entrypoint, true, context, &mut mast_forest)?; Ok(Program::new_with_kernel( - mast_forest, + mast_forest.into(), entry_procedure.body_node_id(), self.module_graph.kernel().clone(), )) diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index c39b91113d..83853525fc 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -208,7 +208,7 @@ fn nested_blocks() { vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], &mut expected_mast_forest, ); - let expected_program = Program::new(expected_mast_forest, combined_node_id); + let expected_program = Program::new(expected_mast_forest.into(), combined_node_id); assert_eq!(expected_program.hash(), program.hash()); } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index f98bb17cbf..c2d68d8f5e 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -37,7 +37,7 @@ impl fmt::Display for MastNodeId { // =============================================================================================== /// Represents one or more procedures, represented as a collection of [`MastNode`]s. -/// +/// /// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] /// can be built from a [`MastForest`] to specify an entrypoint. #[derive(Clone, Debug, Default)] diff --git a/core/src/program.rs b/core/src/program.rs index 3820a1b142..9f24da4e64 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -27,18 +27,7 @@ pub struct Program { /// Constructors impl Program { - // TODOP: Document - pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { - debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); - - Self { - mast_forest: Arc::new(mast_forest), - entrypoint, - kernel: Kernel::default(), - } - } - - pub fn new_shared(mast_forest: Arc, entrypoint: MastNodeId) -> Self { + pub fn new(mast_forest: Arc, entrypoint: MastNodeId) -> Self { debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); Self { @@ -49,20 +38,6 @@ impl Program { } pub fn new_with_kernel( - mast_forest: MastForest, - entrypoint: MastNodeId, - kernel: Kernel, - ) -> Self { - debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); - - Self { - mast_forest: Arc::new(mast_forest), - entrypoint, - kernel, - } - } - - pub fn new_shared_with_kernel( mast_forest: Arc, entrypoint: MastNodeId, kernel: Kernel, diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index 8353376989..0b137b6304 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -122,7 +122,7 @@ fn build_trace( let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; process.execute(&program).unwrap(); diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 133be1f7a4..39862e191c 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -52,7 +52,7 @@ fn basic_block_one_group() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -98,7 +98,7 @@ fn basic_block_small() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -161,7 +161,7 @@ fn basic_block() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -253,7 +253,7 @@ fn span_block_with_respan() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -330,7 +330,7 @@ fn join_node() { let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); let join_node_id = mast_forest.add_node(join_node); - Program::new(mast_forest, join_node_id) + Program::new(mast_forest.into(), join_node_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -396,7 +396,7 @@ fn split_node_true() { let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); let split_node_id = mast_forest.add_node(split_node); - Program::new(mast_forest, split_node_id) + Program::new(mast_forest.into(), split_node_id) }; let (trace, trace_len) = build_trace(&[1], &program); @@ -449,7 +449,7 @@ fn split_node_false() { let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); let split_node_id = mast_forest.add_node(split_node); - Program::new(mast_forest, split_node_id) + Program::new(mast_forest.into(), split_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -503,7 +503,7 @@ fn loop_node() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -556,7 +556,7 @@ fn loop_node_skip() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -599,7 +599,7 @@ fn loop_node_repeat() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -702,7 +702,7 @@ fn call_block() { let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); let program_root_id = mast_forest.add_node(program_root); - let program = Program::new(mast_forest, program_root_id); + let program = Program::new(mast_forest.into(), program_root_id); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, Kernel::default()); @@ -928,7 +928,7 @@ fn syscall_block() { let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::new_with_kernel(mast_forest, program_root_node_id, kernel.clone()); + let program = Program::new_with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel); @@ -1200,7 +1200,7 @@ fn dyn_block() { let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::new(mast_forest, program_root_node_id); + let program = Program::new(mast_forest.into(), program_root_node_id); let (trace, trace_len) = build_dyn_trace( &[ @@ -1308,7 +1308,7 @@ fn set_user_op_helpers_many() { let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let a = rand_value::(); let b = rand_value::(); diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 6595086bca..35651c163a 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -266,7 +266,7 @@ where // TODOP: Don't clone kernel let program = - Program::new_shared_with_kernel(mast_forest, root_id, program.kernel().clone()); + Program::new_with_kernel(mast_forest, root_id, program.kernel().clone()); self.execute_mast_node(root_id, &program) } } diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 992cee60d4..31944a6460 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -53,7 +53,7 @@ pub fn b_chip_span() { let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -126,7 +126,7 @@ pub fn b_chip_span_with_respan() { let basic_block = MastNode::new_basic_block(ops); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -224,7 +224,7 @@ pub fn b_chip_merge() { let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); let split_id = mast_forest.add_node(split); - Program::new(mast_forest, split_id) + Program::new(mast_forest.into(), split_id) }; let trace = build_trace_from_program(&program, &[]); @@ -337,7 +337,7 @@ pub fn b_chip_permutation() { let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; let trace = build_trace_from_program(&program, &stack); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 3b20428865..e1d6a3ea56 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -81,7 +81,7 @@ fn decoder_p1_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join); - Program::new(mast_forest, join_id) + Program::new(mast_forest.into(), join_id) }; let trace = build_trace_from_program(&program, &[]); @@ -152,7 +152,7 @@ fn decoder_p1_split() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - Program::new(mast_forest, split_id) + Program::new(mast_forest.into(), split_id) }; let trace = build_trace_from_program(&program, &[1]); @@ -213,7 +213,7 @@ fn decoder_p1_loop_with_repeat() { let loop_node = MastNode::new_loop(join_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let trace = build_trace_from_program(&program, &[0, 1, 1]); @@ -335,7 +335,7 @@ fn decoder_p2_span_with_respan() { let basic_block = MastNode::new_basic_block(ops); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -377,7 +377,7 @@ fn decoder_p2_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join.clone()); - let program = Program::new(mast_forest, join_id); + let program = Program::new(mast_forest.into(), join_id); let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -442,7 +442,7 @@ fn decoder_p2_split_true() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - let program = Program::new(mast_forest, split_id); + let program = Program::new(mast_forest.into(), split_id); // build trace from program let trace = build_trace_from_program(&program, &[1]); @@ -498,7 +498,7 @@ fn decoder_p2_split_false() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - let program = Program::new(mast_forest, split_id); + let program = Program::new(mast_forest.into(), split_id); // build trace from program let trace = build_trace_from_program(&program, &[0]); @@ -557,7 +557,7 @@ fn decoder_p2_loop_with_repeat() { let loop_node = MastNode::new_loop(join_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - let program = Program::new(mast_forest, loop_node_id); + let program = Program::new(mast_forest.into(), loop_node_id); // build trace from program let trace = build_trace_from_program(&program, &[0, 1, 1]); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index b324589013..5d8c0a69f0 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -40,7 +40,7 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - let program = Program::new(mast_forest, basic_block_id); + let program = Program::new(mast_forest.into(), basic_block_id); build_trace_from_program(&program, stack) } @@ -63,7 +63,7 @@ pub fn build_trace_from_ops_with_inputs( let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - let program = Program::new(mast_forest, basic_block_id); + let program = Program::new(mast_forest.into(), basic_block_id); process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) From a39f74fa3b52e32d5205feac63c338290b57ad6e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 10:23:49 -0400 Subject: [PATCH 009/172] fix docs --- core/src/program.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/program.rs b/core/src/program.rs index 9f24da4e64..60a1be576b 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -27,6 +27,8 @@ pub struct Program { /// Constructors impl Program { + /// Construct a new [`Program`] from the given MAST forest and entrypoint. The kernel is assumed + /// to be empty. pub fn new(mast_forest: Arc, entrypoint: MastNodeId) -> Self { debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); @@ -37,6 +39,7 @@ impl Program { } } + /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. pub fn new_with_kernel( mast_forest: Arc, entrypoint: MastNodeId, @@ -69,7 +72,9 @@ impl Program { self.entrypoint } - /// A convenience method that provides the hash of the entrypoint. + /// Returns the hash of the program's entrypoint. + /// + /// Equivalently, returns the hash of the root of the entrypoint procedure. pub fn hash(&self) -> RpoDigest { self.mast_forest[self.entrypoint].digest() } @@ -83,11 +88,7 @@ impl Program { self.mast_forest.get_node_by_id(node_id) } - // TODOP: fix docs - /// Returns the [`MastNodeId`] associated with a given digest, if any. - /// - /// That is, every [`MastNode`] hashes to some digest. If there exists a [`MastNode`] in the - /// forest that hashes to this digest, then its id is returned. + /// Returns the [`MastNodeId`] of the procedure root associated with a given digest, if any. #[inline(always)] pub fn find_root(&self, digest: RpoDigest) -> Option { self.mast_forest.find_root(digest) From f7e98afc15bfcdf51749b78ff2e52d6c3aae9546 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 10:58:36 -0400 Subject: [PATCH 010/172] Make `Program.kernel` an `Arc` --- assembly/src/assembler/mod.rs | 2 +- core/src/program.rs | 12 ++++++------ processor/src/decoder/tests.rs | 6 +++--- processor/src/lib.rs | 5 ++--- stdlib/tests/mem/mod.rs | 2 +- test-utils/src/lib.rs | 4 ++-- 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 8f1a3ec79f..fccb2c33c8 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -581,7 +581,7 @@ impl Assembler { Ok(Program::new_with_kernel( mast_forest.into(), entry_procedure.body_node_id(), - self.module_graph.kernel().clone(), + self.module_graph.kernel().clone().into(), )) } diff --git a/core/src/program.rs b/core/src/program.rs index 60a1be576b..9d16cc8f84 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -22,7 +22,7 @@ pub struct Program { /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a /// MAST which represents a library. entrypoint: MastNodeId, - kernel: Kernel, + kernel: Arc, } /// Constructors @@ -35,7 +35,7 @@ impl Program { Self { mast_forest, entrypoint, - kernel: Kernel::default(), + kernel: Arc::new(Kernel::default()), } } @@ -43,7 +43,7 @@ impl Program { pub fn new_with_kernel( mast_forest: Arc, entrypoint: MastNodeId, - kernel: Kernel, + kernel: Arc, ) -> Self { debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); @@ -63,8 +63,8 @@ impl Program { } /// Returns the kernel associated with this program. - pub fn kernel(&self) -> &Kernel { - &self.kernel + pub fn kernel(&self) -> Arc { + self.kernel.clone() } /// Returns the entrypoint associated with this program. @@ -176,7 +176,7 @@ impl ProgramInfo { impl From for ProgramInfo { fn from(program: Program) -> Self { let program_hash = program.hash(); - let kernel = program.kernel().clone(); + let kernel = program.kernel().as_ref().clone(); Self { program_hash, diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 39862e191c..15d2e2f8e0 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -5,7 +5,7 @@ use super::{ build_op_group, }; use crate::DefaultHost; -use alloc::vec::Vec; +use alloc::{sync::Arc, vec::Vec}; use miden_air::trace::{ decoder::{ ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_RANGE, IN_SPAN_COL_IDX, NUM_HASHER_COLUMNS, @@ -895,7 +895,7 @@ fn syscall_block() { let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); let foo_root_id = mast_forest.add_node(foo_root.clone()); mast_forest.make_root(foo_root_id); - let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); + let kernel: Arc = Kernel::new(&[foo_root.digest()]).unwrap().into(); // build bar procedure body let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); @@ -931,7 +931,7 @@ fn syscall_block() { let program = Program::new_with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = - build_call_trace(&program, kernel); + build_call_trace(&program, kernel.as_ref().clone()); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- check_op_decoding(&dec_trace, 0, ZERO, Operation::Join, 0, 0, 0); diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 35651c163a..f7ac6f4ab6 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -134,7 +134,7 @@ pub fn execute( where H: Host, { - let mut process = Process::new(program.kernel().clone(), stack_inputs, host, options); + let mut process = Process::new(program.kernel().as_ref().clone(), stack_inputs, host, options); let stack_outputs = process.execute(program)?; let trace = ExecutionTrace::new(process, stack_outputs); assert_eq!(&program.hash(), trace.program_hash(), "inconsistent program hash"); @@ -147,7 +147,7 @@ pub fn execute_iter(program: &Program, stack_inputs: StackInputs, host: H) -> where H: Host, { - let mut process = Process::new_debug(program.kernel().clone(), stack_inputs, host); + let mut process = Process::new_debug(program.kernel().as_ref().clone(), stack_inputs, host); let result = process.execute(program); if result.is_ok() { assert_eq!( @@ -264,7 +264,6 @@ where )?; let root_id = mast_forest.find_root(external_node.digest()).unwrap_or_else(|| panic!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", external_node.digest())); - // TODOP: Don't clone kernel let program = Program::new_with_kernel(mast_forest, root_id, program.kernel().clone()); self.execute_mast_node(root_id, &program) diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index 1dbef6927b..40163149cb 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -29,7 +29,7 @@ fn test_memcopy() { let program: Program = assembler.assemble(source).expect("Failed to compile test source."); let mut process = Process::new( - program.kernel().clone(), + program.kernel().as_ref().clone(), StackInputs::default(), DefaultHost::default(), ExecutionOptions::default(), diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 4db101e84c..235c4dad32 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -233,7 +233,7 @@ impl Test { // execute the test let mut process = Process::new( - program.kernel().clone(), + program.kernel().as_ref().clone(), self.stack_inputs.clone(), host, ExecutionOptions::default(), @@ -326,7 +326,7 @@ impl Test { MemMastForestStore::default(), ); let mut process = Process::new( - program.kernel().clone(), + program.kernel().as_ref().clone(), self.stack_inputs.clone(), host, ExecutionOptions::default(), From b5538c535c5682881243092b0a6c9274fafadfab Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 21 Jun 2024 11:03:39 -0400 Subject: [PATCH 011/172] fix executable --- miden/src/cli/debug/executor.rs | 8 +++++-- miden/src/cli/prove.rs | 9 +++++-- miden/src/cli/run.rs | 7 ++++-- miden/src/examples/blake3.rs | 3 ++- miden/src/examples/fibonacci.rs | 42 +++++++++++++++++++-------------- miden/src/repl/mod.rs | 2 +- miden/src/tools/mod.rs | 9 ++++--- 7 files changed, 51 insertions(+), 29 deletions(-) diff --git a/miden/src/cli/debug/executor.rs b/miden/src/cli/debug/executor.rs index ec77f492fe..5c2a2bae5b 100644 --- a/miden/src/cli/debug/executor.rs +++ b/miden/src/cli/debug/executor.rs @@ -2,6 +2,7 @@ use super::DebugCommand; use miden_vm::{ math::Felt, DefaultHost, MemAdviceProvider, Program, StackInputs, VmState, VmStateIterator, }; +use processor::MemMastForestStore; /// Holds debugger state and iterator used for debugging. pub struct DebugExecutor { @@ -21,8 +22,11 @@ impl DebugExecutor { stack_inputs: StackInputs, advice_provider: MemAdviceProvider, ) -> Result { - let mut vm_state_iter = - processor::execute_iter(&program, stack_inputs, DefaultHost::new(advice_provider)); + let mut vm_state_iter = processor::execute_iter( + &program, + stack_inputs, + DefaultHost::new(advice_provider, MemMastForestStore::default()), + ); let vm_state = vm_state_iter .next() .ok_or(format!( diff --git a/miden/src/cli/prove.rs b/miden/src/cli/prove.rs index f55913e2a2..687467c28f 100644 --- a/miden/src/cli/prove.rs +++ b/miden/src/cli/prove.rs @@ -2,7 +2,9 @@ use super::data::{instrument, Debug, InputFile, Libraries, OutputFile, ProgramFi use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; use miden_vm::ProvingOptions; -use processor::{DefaultHost, ExecutionOptions, ExecutionOptionsError, Program}; +use processor::{ + DefaultHost, ExecutionOptions, ExecutionOptionsError, MemMastForestStore, Program, +}; use std::{path::PathBuf, time::Instant}; @@ -95,7 +97,10 @@ impl ProveCmd { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; - let host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); + let host = DefaultHost::new( + input_data.parse_advice_provider().map_err(Report::msg)?, + MemMastForestStore::default(), + ); let proving_options = self.get_proof_options().map_err(|err| Report::msg(format!("{err}")))?; diff --git a/miden/src/cli/run.rs b/miden/src/cli/run.rs index 59c7bfbe79..484e875435 100644 --- a/miden/src/cli/run.rs +++ b/miden/src/cli/run.rs @@ -1,7 +1,7 @@ use super::data::{instrument, Debug, InputFile, Libraries, OutputFile, ProgramFile}; use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; -use processor::{DefaultHost, ExecutionOptions, ExecutionTrace}; +use processor::{DefaultHost, ExecutionOptions, ExecutionTrace, MemMastForestStore}; use std::{path::PathBuf, time::Instant}; #[derive(Debug, Clone, Parser)] @@ -117,7 +117,10 @@ fn run_program(params: &RunCmd) -> Result<(ExecutionTrace, [u8; 32]), Report> { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; - let host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); + let host = DefaultHost::new( + input_data.parse_advice_provider().map_err(Report::msg)?, + MemMastForestStore::default(), + ); let program_hash: [u8; 32] = program.hash().into(); diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 2e87cadbcc..0191afabcf 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -1,5 +1,6 @@ use super::Example; use miden_vm::{Assembler, DefaultHost, MemAdviceProvider, Program, StackInputs}; +use processor::MemMastForestStore; use stdlib::StdLibrary; use vm_core::{utils::group_slice_elements, Felt}; @@ -11,7 +12,7 @@ const INITIAL_HASH_VALUE: [u32; 8] = [u32::MAX; 8]; // EXAMPLE BUILDER // ================================================================================================ -pub fn get_example(n: usize) -> Example> { +pub fn get_example(n: usize) -> Example> { // generate the program and expected results let program = generate_blake3_program(n); let expected_result = compute_hash_chain(n); diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index 2dcd778747..f2864b4f5f 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -1,12 +1,11 @@ use super::{Example, ONE, ZERO}; -use miden_vm::{ - math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs, -}; +use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, StackInputs}; +use processor::MemMastForestStore; // EXAMPLE BUILDER // ================================================================================================ -pub fn get_example(n: usize) -> Example> { +pub fn get_example(n: usize) -> Example> { // generate the program and expected results let program = generate_fibonacci_program(n); let expected_result = vec![compute_fibonacci(n)]; @@ -59,20 +58,27 @@ fn compute_fibonacci(n: usize) -> Felt { // EXAMPLE TESTER // ================================================================================================ -#[test] -fn test_fib_example() { - let example = get_example(16); - super::test_example(example, false); -} +#[cfg(test)] +mod tests { + use super::*; + use crate::examples::{test_example, test_example_with_options}; + use prover::ProvingOptions; -#[test] -fn test_fib_example_fail() { - let example = get_example(16); - super::test_example(example, true); -} + #[test] + fn test_fib_example() { + let example = get_example(16); + test_example(example, false); + } -#[test] -fn test_fib_example_rpo() { - let example = get_example(16); - super::test_example_with_options(example, false, ProvingOptions::with_96_bit_security(true)); + #[test] + fn test_fib_example_fail() { + let example = get_example(16); + test_example(example, true); + } + + #[test] + fn test_fib_example_rpo() { + let example = get_example(16); + test_example_with_options(example, false, ProvingOptions::with_96_bit_security(true)); + } } diff --git a/miden/src/repl/mod.rs b/miden/src/repl/mod.rs index a48bfdcfc7..aa39436d98 100644 --- a/miden/src/repl/mod.rs +++ b/miden/src/repl/mod.rs @@ -1,5 +1,5 @@ use assembly::{Assembler, Library, MaslLibrary}; -use miden_vm::{math::Felt, DefaultHost, Program, StackInputs, Word}; +use miden_vm::{math::Felt, DefaultHost, StackInputs, Word}; use processor::ContextId; use rustyline::{error::ReadlineError, DefaultEditor}; use std::{collections::BTreeSet, path::PathBuf}; diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index d4bb16c3b8..349c9d5f32 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -2,8 +2,8 @@ use super::cli::InputFile; use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; use core::fmt; -use miden_vm::{Assembler, DefaultHost, Host, Operation, Program, StackInputs}; -use processor::{AsmOpInfo, TraceLenSummary}; +use miden_vm::{Assembler, DefaultHost, Host, Operation, StackInputs}; +use processor::{AsmOpInfo, MemMastForestStore, TraceLenSummary}; use std::{fs, path::PathBuf}; use stdlib::StdLibrary; @@ -35,7 +35,10 @@ impl Analyze { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; - let host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); + let host = DefaultHost::new( + input_data.parse_advice_provider().map_err(Report::msg)?, + MemMastForestStore::default(), + ); let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host) .expect("Could not retrieve execution details"); From 188651bb986db3694740466080ee0f05f5ed0657 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 07:37:51 -0400 Subject: [PATCH 012/172] invoke_mast_root: fix external node creation logic --- .../src/assembler/instruction/procedures.rs | 52 ++++++++++++------- assembly/src/errors.rs | 9 ++++ 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 65fcd12162..19c7dc60db 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -74,7 +74,7 @@ impl Assembler { None if matches!(kind, InvokeKind::SysCall) => { return Err(AssemblyError::UnknownSysCallTarget { span, - source_file: current_source_file, + source_file: current_source_file.clone(), callee: mast_root, }); } @@ -82,28 +82,42 @@ impl Assembler { } let mast_root_node_id = { - // Note that here we rely on the fact that we topologically sorted the procedures, such - // that when we assemble a procedure, all procedures that it calls will have been - // assembled, and hence be present in the `MastForest`. We currently assume that the - // `MastForest` contains all the procedures being called; "external procedures" only - // known by digest are not currently supported. - let callee_id = mast_forest - .find_root(mast_root) - .unwrap_or_else(|| panic!("MAST root {} not in MAST forest", mast_root)); - match kind { - // For `exec`, we return the root of the procedure being exec'd, which has the - // effect of inlining it - InvokeKind::Exec => callee_id, - // For `call`, we just use the corresponding CALL block + InvokeKind::Exec => { + // Note that here we rely on the fact that we topologically sorted the procedures, such + // that when we assemble a procedure, all procedures that it calls will have been + // assembled, and hence be present in the `MastForest`. We currently assume that the + // `MastForest` contains all the procedures being called; "external procedures" only + // known by digest are not currently supported. + mast_forest.find_root(mast_root).ok_or_else(|| { + AssemblyError::UnknownExecTarget { + span, + source_file: current_source_file, + callee: mast_root, + } + })? + } InvokeKind::Call => { - let node = MastNode::new_call(callee_id, mast_forest); - mast_forest.add_node(node) + let callee_id = mast_forest.find_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest.add_node(external_node) + }); + + let call_node = MastNode::new_call(callee_id, mast_forest); + mast_forest.add_node(call_node) } - // For `syscall`, we just use the corresponding SYSCALL block + // Syscall nodes always use external references, as the kernel should always be + // provided to the VM via the host. InvokeKind::SysCall => { - let node = MastNode::new_syscall(callee_id, mast_forest); - mast_forest.add_node(node) + let callee_id = { + let external_node = MastNode::new_external(mast_root); + mast_forest.add_node(external_node) + }; + + let syscall_node = MastNode::new_syscall(callee_id, mast_forest); + mast_forest.add_node(syscall_node) } } }; diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 5b51048292..43d977b715 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -104,6 +104,15 @@ pub enum AssemblyError { source_file: Option>, callee: RpoDigest, }, + #[error("invalid exec: exec'd procedures must be available during compilation, but '{callee}' is not")] + #[diagnostic()] + UnknownExecTarget { + #[label("call occurs here")] + span: SourceSpan, + #[source_code] + source_file: Option>, + callee: RpoDigest + }, #[error("invalid use of 'caller' instruction outside of kernel")] #[diagnostic(help( "the 'caller' instruction is only allowed in procedures defined in a kernel" From b15263c3c508c2c30f89b7d361e5267c72538dc2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 07:49:50 -0400 Subject: [PATCH 013/172] add failing test --- assembly/src/assembler/instruction/procedures.rs | 4 ++-- assembly/src/assembler/tests.rs | 5 ++++- core/src/mast/mod.rs | 12 ++++++++++-- core/src/program.rs | 9 +++++++-- processor/src/host/mast_forest_store.rs | 2 +- processor/src/lib.rs | 4 ++-- 6 files changed, 26 insertions(+), 10 deletions(-) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 19c7dc60db..799032b8b4 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -89,7 +89,7 @@ impl Assembler { // assembled, and hence be present in the `MastForest`. We currently assume that the // `MastForest` contains all the procedures being called; "external procedures" only // known by digest are not currently supported. - mast_forest.find_root(mast_root).ok_or_else(|| { + mast_forest.find_procedure_root(mast_root).ok_or_else(|| { AssemblyError::UnknownExecTarget { span, source_file: current_source_file, @@ -98,7 +98,7 @@ impl Assembler { })? } InvokeKind::Call => { - let callee_id = mast_forest.find_root(mast_root).unwrap_or_else(|| { + let callee_id = mast_forest.find_procedure_root(mast_root).unwrap_or_else(|| { // If the MAST root called isn't known to us, make it an external // reference. let external_node = MastNode::new_external(mast_root); diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 83853525fc..2f92262cfa 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -208,7 +208,10 @@ fn nested_blocks() { vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], &mut expected_mast_forest, ); - let expected_program = Program::new(expected_mast_forest.into(), combined_node_id); + let expected_program = Program::new(expected_mast_forest.into(), combined_node_id); assert_eq!(expected_program.hash(), program.hash()); + + // also check that the program has the right number of procedures + assert_eq!(program.num_procedures(), 3); } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index c2d68d8f5e..c4ca79b6be 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -96,14 +96,22 @@ impl MastForest { /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any. #[inline(always)] - pub fn find_root(&self, digest: RpoDigest) -> Option { + pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied() } /// Returns an iterator over the digest of the procedures in this MAST forest. - pub fn roots(&self) -> impl Iterator + '_ { + pub fn procedure_roots(&self) -> impl Iterator + '_ { self.roots.iter().map(|&root_id| self[root_id].digest()) } + + /// Returns the number of procedures in this MAST forest. + pub fn num_procedures(&self) -> u32 { + self.roots + .len() + .try_into() + .expect("MAST forest contains more than 2^32 procedures.") + } } impl Index for MastForest { diff --git a/core/src/program.rs b/core/src/program.rs index 9d16cc8f84..bb15021d79 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -90,8 +90,13 @@ impl Program { /// Returns the [`MastNodeId`] of the procedure root associated with a given digest, if any. #[inline(always)] - pub fn find_root(&self, digest: RpoDigest) -> Option { - self.mast_forest.find_root(digest) + pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { + self.mast_forest.find_procedure_root(digest) + } + + /// Returns the number of procedures in this program. + pub fn num_procedures(&self) -> u32 { + self.mast_forest.num_procedures() } } diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs index b15a1785a2..3fb8a6c0f0 100644 --- a/processor/src/host/mast_forest_store.rs +++ b/processor/src/host/mast_forest_store.rs @@ -14,7 +14,7 @@ impl MemMastForestStore { pub fn insert(&mut self, mast_forest: MastForest) { let mast_forest = Arc::new(mast_forest); - for root in mast_forest.roots() { + for root in mast_forest.procedure_roots() { self.mast_forests.insert(root, mast_forest.clone()); } } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index f7ac6f4ab6..080b2fd6bc 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -262,7 +262,7 @@ where root_digest: external_node.digest(), }, )?; - let root_id = mast_forest.find_root(external_node.digest()).unwrap_or_else(|| panic!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", external_node.digest())); + let root_id = mast_forest.find_procedure_root(external_node.digest()).unwrap_or_else(|| panic!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", external_node.digest())); let program = Program::new_with_kernel(mast_forest, root_id, program.kernel().clone()); @@ -385,7 +385,7 @@ where // get dynamic code from the code block table and execute it let callee_id = program - .find_root(callee_hash.into()) + .find_procedure_root(callee_hash.into()) .ok_or_else(|| ExecutionError::DynamicNodeNotFound(callee_hash.into()))?; self.execute_mast_node(callee_id, program)?; From a94a095860e093c7b599b000cfdc837f18deb1f1 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 07:52:16 -0400 Subject: [PATCH 014/172] don't make root in `combine_mast_node_ids` and `compile_body` --- assembly/src/assembler/mod.rs | 16 ++++++---------- assembly/src/assembler/tests.rs | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index fccb2c33c8..b64dc0919a 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -698,7 +698,7 @@ impl Assembler { context.set_current_procedure(procedure); let proc = self.module_graph[gid].unwrap_procedure(); - let code = if num_locals > 0 { + let proc_body_root = if num_locals > 0 { // for procedures with locals, we need to update fmp register before and after the // procedure body is executed. specifically: // - to allocate procedure locals we need to increment fmp by the number of locals @@ -713,8 +713,10 @@ impl Assembler { self.compile_body(proc.iter(), context, None, mast_forest)? }; + mast_forest.make_root(proc_body_root); + let pctx = context.take_current_procedure().unwrap(); - Ok(pctx.into_procedure(code)) + Ok(pctx.into_procedure(proc_body_root)) } fn compile_body<'a, I>( @@ -816,10 +818,7 @@ impl Assembler { Ok(if mast_node_ids.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - let basic_block_node_id = mast_forest.add_node(basic_block_node); - mast_forest.make_root(basic_block_node_id); - - basic_block_node_id + mast_forest.add_node(basic_block_node) } else { combine_mast_node_ids(mast_node_ids, mast_forest) }) @@ -889,8 +888,5 @@ fn combine_mast_node_ids( } } - let root_id = mast_node_ids.remove(0); - mast_forest.make_root(root_id); - - root_id + mast_node_ids.remove(0) } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 2f92262cfa..697ba028b7 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -213,5 +213,5 @@ fn nested_blocks() { assert_eq!(expected_program.hash(), program.hash()); // also check that the program has the right number of procedures - assert_eq!(program.num_procedures(), 3); + assert_eq!(program.num_procedures(), 5); } From 7af0bfaacb565f34dd019a989be41ae997d00e9a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 08:09:58 -0400 Subject: [PATCH 015/172] fix External docs --- core/src/mast/node/external.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index 18eb8df17c..c0b8ff10a3 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -1,21 +1,26 @@ use crate::mast::{MastForest, MerkleTreeNode}; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; -/// Block for a unknown function call. + +/// Node for referencing procedures not present in a given [`MastForest`] (hence "external"). /// -/// Proxy blocks are used to verify the integrity of a program's hash while keeping parts -/// of the program secret. Fails if executed. +/// External nodes can be used to verify the integrity of a program's hash while keeping parts of +/// the program secret. They also allow a program to refer to a well-known procedure that was not +/// compiled with the program (e.g. a procedure in the standard library). /// -/// Hash of a proxy block is not computed but is rather defined at instantiation time. +/// The hash of an external node is the hash of the procedure it represents, such that an external +/// node can be swapped with the actual subtree that it represents without changing the MAST root. #[derive(Clone, Debug, PartialEq, Eq)] pub struct ExternalNode { digest: RpoDigest, } impl ExternalNode { - /// Returns a new [Proxy] block instantiated with the specified code hash. - pub fn new(code_hash: RpoDigest) -> Self { - Self { digest: code_hash } + /// Returns a new [`ExternalNode`] instantiated with the specified procedure hash. + pub fn new(procedure_hash: RpoDigest) -> Self { + Self { + digest: procedure_hash, + } } } From a6fcf47958616131e8ca10a105723fafdaf3cfc0 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 08:10:02 -0400 Subject: [PATCH 016/172] fmt --- .../src/assembler/instruction/procedures.rs | 24 ++++++++++--------- assembly/src/errors.rs | 2 +- prover/src/gpu/metal/mod.rs | 5 ++-- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 799032b8b4..49aa6bdb39 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -84,11 +84,12 @@ impl Assembler { let mast_root_node_id = { match kind { InvokeKind::Exec => { - // Note that here we rely on the fact that we topologically sorted the procedures, such - // that when we assemble a procedure, all procedures that it calls will have been - // assembled, and hence be present in the `MastForest`. We currently assume that the - // `MastForest` contains all the procedures being called; "external procedures" only - // known by digest are not currently supported. + // Note that here we rely on the fact that we topologically sorted the + // procedures, such that when we assemble a procedure, all + // procedures that it calls will have been assembled, and + // hence be present in the `MastForest`. We currently assume that the + // `MastForest` contains all the procedures being called; "external procedures" + // only known by digest are not currently supported. mast_forest.find_procedure_root(mast_root).ok_or_else(|| { AssemblyError::UnknownExecTarget { span, @@ -98,12 +99,13 @@ impl Assembler { })? } InvokeKind::Call => { - let callee_id = mast_forest.find_procedure_root(mast_root).unwrap_or_else(|| { - // If the MAST root called isn't known to us, make it an external - // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest.add_node(external_node) - }); + let callee_id = + mast_forest.find_procedure_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest.add_node(external_node) + }); let call_node = MastNode::new_call(callee_id, mast_forest); mast_forest.add_node(call_node) diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 43d977b715..a85ccb7057 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -111,7 +111,7 @@ pub enum AssemblyError { span: SourceSpan, #[source_code] source_file: Option>, - callee: RpoDigest + callee: RpoDigest, }, #[error("invalid use of 'caller' instruction outside of kernel")] #[diagnostic(help( diff --git a/prover/src/gpu/metal/mod.rs b/prover/src/gpu/metal/mod.rs index 7547c18aa4..602f878378 100644 --- a/prover/src/gpu/metal/mod.rs +++ b/prover/src/gpu/metal/mod.rs @@ -91,8 +91,9 @@ where // if we will fill the entire segment, we allocate uninitialized memory unsafe { page_aligned_uninit_vector(domain_size) } } else { - // but if some columns in the segment will remain unfilled, we allocate memory initialized - // to zeros to make sure we don't end up with memory with undefined values + // but if some columns in the segment will remain unfilled, we allocate memory + // initialized to zeros to make sure we don't end up with memory with + // undefined values vec![[E::BaseField::ZERO; N]; domain_size] }; From c66db6f9f6b65c3d35969f73b1eb9e8a646b7ba2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 08:12:14 -0400 Subject: [PATCH 017/172] fix `entrypoint` doc --- core/src/program.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/program.rs b/core/src/program.rs index bb15021d79..4c3c4f62f5 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -17,10 +17,7 @@ use super::Kernel; #[derive(Clone, Debug)] pub struct Program { mast_forest: Arc, - /// The "entrypoint", when set, is the root of the entire forest, i.e. a path exists from this - /// node to all other roots in the forest. This corresponds to the executable entry point. - /// Whether or not the entrypoint is set distinguishes a MAST which is executable, versus a - /// MAST which represents a library. + /// The "entrypoint" is the node where execution of the program begins. entrypoint: MastNodeId, kernel: Arc, } From 572fc7ef012e66afde36c290f3011fd9d8b8831e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 08:13:17 -0400 Subject: [PATCH 018/172] Rename `Program::new_with_kernel()` --- assembly/src/assembler/mod.rs | 2 +- core/src/program.rs | 2 +- processor/src/decoder/tests.rs | 2 +- processor/src/lib.rs | 3 +-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b64dc0919a..2224123c25 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -578,7 +578,7 @@ impl Assembler { // Compile the module graph rooted at the entrypoint let entry_procedure = self.compile_subgraph(entrypoint, true, context, &mut mast_forest)?; - Ok(Program::new_with_kernel( + Ok(Program::with_kernel( mast_forest.into(), entry_procedure.body_node_id(), self.module_graph.kernel().clone().into(), diff --git a/core/src/program.rs b/core/src/program.rs index 4c3c4f62f5..280e0c477c 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -37,7 +37,7 @@ impl Program { } /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. - pub fn new_with_kernel( + pub fn with_kernel( mast_forest: Arc, entrypoint: MastNodeId, kernel: Arc, diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 15d2e2f8e0..0f964891dc 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -928,7 +928,7 @@ fn syscall_block() { let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::new_with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); + let program = Program::with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel.as_ref().clone()); diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 080b2fd6bc..f8e4e4e13b 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -264,8 +264,7 @@ where )?; let root_id = mast_forest.find_procedure_root(external_node.digest()).unwrap_or_else(|| panic!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", external_node.digest())); - let program = - Program::new_with_kernel(mast_forest, root_id, program.kernel().clone()); + let program = Program::with_kernel(mast_forest, root_id, program.kernel().clone()); self.execute_mast_node(root_id, &program) } } From 08ce2c7a15f124925db6c4073fa5beb406317644 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 08:28:01 -0400 Subject: [PATCH 019/172] Document `MastForestStore` and `MemMastForestStore` --- processor/src/host/mast_forest_store.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs index 3fb8a6c0f0..eb2ae42055 100644 --- a/processor/src/host/mast_forest_store.rs +++ b/processor/src/host/mast_forest_store.rs @@ -1,16 +1,27 @@ use alloc::{collections::BTreeMap, sync::Arc}; use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; +/// A set of [`MastForest`]s available to the prover that programs may refer to (by means of an +/// [`ExternalNode`]). +/// +/// For example, a program's kernel and standard library would most likely not be compiled directly +/// with the program, and instead be provided separately to the prover. This has the benefit of +/// reducing program binary size. The store could also be much more complex, such as accessing a +/// centralized registry of [`MastForest`]s when it doesn't find one locally. pub trait MastForestStore { - fn get(&self, node_digest: &RpoDigest) -> Option>; + /// Returns a [`MastForest`] which is guaranteed to contain a procedure with the provided + /// procedure hash as one of its procedure, if any. + fn get(&self, procedure_hash: &RpoDigest) -> Option>; } +/// A simple [`MastForestStore`] where all known [`MastForest`]s are held in memory. #[derive(Debug, Default, Clone)] pub struct MemMastForestStore { mast_forests: BTreeMap>, } impl MemMastForestStore { + /// Inserts all the procedures of the provided MAST forest in the store. pub fn insert(&mut self, mast_forest: MastForest) { let mast_forest = Arc::new(mast_forest); @@ -21,7 +32,7 @@ impl MemMastForestStore { } impl MastForestStore for MemMastForestStore { - fn get(&self, node_digest: &RpoDigest) -> Option> { - self.mast_forests.get(node_digest).cloned() + fn get(&self, procedure_hash: &RpoDigest) -> Option> { + self.mast_forests.get(procedure_hash).cloned() } } From 50e01e9064cc24cd7727b753220a73d58fa7fa44 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 08:57:34 -0400 Subject: [PATCH 020/172] fix syscall --- assembly/src/assembler/instruction/procedures.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 49aa6bdb39..e670fec358 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -110,13 +110,14 @@ impl Assembler { let call_node = MastNode::new_call(callee_id, mast_forest); mast_forest.add_node(call_node) } - // Syscall nodes always use external references, as the kernel should always be - // provided to the VM via the host. InvokeKind::SysCall => { - let callee_id = { - let external_node = MastNode::new_external(mast_root); - mast_forest.add_node(external_node) - }; + let callee_id = + mast_forest.find_procedure_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest.add_node(external_node) + }); let syscall_node = MastNode::new_syscall(callee_id, mast_forest); mast_forest.add_node(syscall_node) From 071ab54a9e25f09a771a570e97b9d5f9ad3d39e2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:02:45 -0400 Subject: [PATCH 021/172] execute_* functions: use `MastForest` --- processor/src/decoder/mod.rs | 12 ++++++------ processor/src/lib.rs | 17 ++++++++--------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index 72845ba549..5d66138df4 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -12,11 +12,11 @@ use miden_air::trace::{ }; use vm_core::{ mast::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, + get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, MerkleTreeNode, SplitNode, OP_BATCH_SIZE, }, stack::STACK_TOP_SIZE, - AssemblyOp, Program, + AssemblyOp, }; mod trace; @@ -56,7 +56,7 @@ where pub(super) fn start_join_node( &mut self, node: &JoinNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the JOIN block; the row address returned by the // hasher is used as the ID of the block; the result of the hash is expected to be in @@ -106,7 +106,7 @@ where pub(super) fn start_split_node( &mut self, node: &SplitNode, - program: &Program, + program: &MastForest, ) -> Result { let condition = self.stack.peek(); @@ -158,7 +158,7 @@ where pub(super) fn start_loop_node( &mut self, node: &LoopNode, - program: &Program, + program: &MastForest, ) -> Result { let condition = self.stack.peek(); @@ -222,7 +222,7 @@ where pub(super) fn start_call_node( &mut self, node: &CallNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // use the hasher to compute the hash of the CALL or SYSCALL block; the row address // returned by the hasher is used as the ID of the block; the result of the hash is diff --git a/processor/src/lib.rs b/processor/src/lib.rs index f8e4e4e13b..f4485afe79 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -231,7 +231,7 @@ where return Err(ExecutionError::ProgramAlreadyExecuted); } - self.execute_mast_node(program.entrypoint(), program)?; + self.execute_mast_node(program.entrypoint(), program.mast_forest())?; Ok(self.stack.build_stack_outputs()) } @@ -242,7 +242,7 @@ where fn execute_mast_node( &mut self, node_id: MastNodeId, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { let wrapper_node = &program .get_node_by_id(node_id) @@ -264,8 +264,7 @@ where )?; let root_id = mast_forest.find_procedure_root(external_node.digest()).unwrap_or_else(|| panic!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", external_node.digest())); - let program = Program::with_kernel(mast_forest, root_id, program.kernel().clone()); - self.execute_mast_node(root_id, &program) + self.execute_mast_node(root_id, &mast_forest) } } } @@ -274,7 +273,7 @@ where fn execute_join_node( &mut self, node: &JoinNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { self.start_join_node(node, program)?; @@ -289,7 +288,7 @@ where fn execute_split_node( &mut self, node: &SplitNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // start the SPLIT block; this also pops the stack and returns the popped element let condition = self.start_split_node(node, program)?; @@ -311,7 +310,7 @@ where fn execute_loop_node( &mut self, node: &LoopNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { // start the LOOP block; this also pops the stack and returns the popped element let condition = self.start_loop_node(node, program)?; @@ -346,7 +345,7 @@ where fn execute_call_node( &mut self, call_node: &CallNode, - program: &Program, + program: &MastForest, ) -> Result<(), ExecutionError> { let callee_digest = { let callee = program.get_node_by_id(call_node.callee()).ok_or_else(|| { @@ -377,7 +376,7 @@ where /// Executes the specified [DynNode] node. #[inline(always)] - fn execute_dyn_node(&mut self, program: &Program) -> Result<(), ExecutionError> { + fn execute_dyn_node(&mut self, program: &MastForest) -> Result<(), ExecutionError> { // get target hash from the stack let callee_hash = self.stack.get_word(0); self.start_dyn_node(callee_hash)?; From 49de40d61cd65854ecb746af15067df7a2df9528 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:06:46 -0400 Subject: [PATCH 022/172] `Program`: Remove `Arc` around kernel --- assembly/src/assembler/mod.rs | 2 +- core/src/program.rs | 12 ++++++------ processor/src/decoder/tests.rs | 6 +++--- processor/src/lib.rs | 4 ++-- stdlib/tests/mem/mod.rs | 2 +- test-utils/src/lib.rs | 4 ++-- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 2224123c25..58f14f38eb 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -581,7 +581,7 @@ impl Assembler { Ok(Program::with_kernel( mast_forest.into(), entry_procedure.body_node_id(), - self.module_graph.kernel().clone().into(), + self.module_graph.kernel().clone(), )) } diff --git a/core/src/program.rs b/core/src/program.rs index 280e0c477c..e38f24a74b 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -19,7 +19,7 @@ pub struct Program { mast_forest: Arc, /// The "entrypoint" is the node where execution of the program begins. entrypoint: MastNodeId, - kernel: Arc, + kernel: Kernel, } /// Constructors @@ -32,7 +32,7 @@ impl Program { Self { mast_forest, entrypoint, - kernel: Arc::new(Kernel::default()), + kernel: Kernel::default(), } } @@ -40,7 +40,7 @@ impl Program { pub fn with_kernel( mast_forest: Arc, entrypoint: MastNodeId, - kernel: Arc, + kernel: Kernel, ) -> Self { debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); @@ -60,8 +60,8 @@ impl Program { } /// Returns the kernel associated with this program. - pub fn kernel(&self) -> Arc { - self.kernel.clone() + pub fn kernel(&self) -> &Kernel { + &self.kernel } /// Returns the entrypoint associated with this program. @@ -178,7 +178,7 @@ impl ProgramInfo { impl From for ProgramInfo { fn from(program: Program) -> Self { let program_hash = program.hash(); - let kernel = program.kernel().as_ref().clone(); + let kernel = program.kernel().clone(); Self { program_hash, diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 0f964891dc..265276119f 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -5,7 +5,7 @@ use super::{ build_op_group, }; use crate::DefaultHost; -use alloc::{sync::Arc, vec::Vec}; +use alloc::vec::Vec; use miden_air::trace::{ decoder::{ ADDR_COL_IDX, GROUP_COUNT_COL_IDX, HASHER_STATE_RANGE, IN_SPAN_COL_IDX, NUM_HASHER_COLUMNS, @@ -895,7 +895,7 @@ fn syscall_block() { let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); let foo_root_id = mast_forest.add_node(foo_root.clone()); mast_forest.make_root(foo_root_id); - let kernel: Arc = Kernel::new(&[foo_root.digest()]).unwrap().into(); + let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); // build bar procedure body let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); @@ -931,7 +931,7 @@ fn syscall_block() { let program = Program::with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = - build_call_trace(&program, kernel.as_ref().clone()); + build_call_trace(&program, kernel); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- check_op_decoding(&dec_trace, 0, ZERO, Operation::Join, 0, 0, 0); diff --git a/processor/src/lib.rs b/processor/src/lib.rs index f4485afe79..83116eab39 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -134,7 +134,7 @@ pub fn execute( where H: Host, { - let mut process = Process::new(program.kernel().as_ref().clone(), stack_inputs, host, options); + let mut process = Process::new(program.kernel().clone(), stack_inputs, host, options); let stack_outputs = process.execute(program)?; let trace = ExecutionTrace::new(process, stack_outputs); assert_eq!(&program.hash(), trace.program_hash(), "inconsistent program hash"); @@ -147,7 +147,7 @@ pub fn execute_iter(program: &Program, stack_inputs: StackInputs, host: H) -> where H: Host, { - let mut process = Process::new_debug(program.kernel().as_ref().clone(), stack_inputs, host); + let mut process = Process::new_debug(program.kernel().clone(), stack_inputs, host); let result = process.execute(program); if result.is_ok() { assert_eq!( diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index 40163149cb..1dbef6927b 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -29,7 +29,7 @@ fn test_memcopy() { let program: Program = assembler.assemble(source).expect("Failed to compile test source."); let mut process = Process::new( - program.kernel().as_ref().clone(), + program.kernel().clone(), StackInputs::default(), DefaultHost::default(), ExecutionOptions::default(), diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 235c4dad32..4db101e84c 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -233,7 +233,7 @@ impl Test { // execute the test let mut process = Process::new( - program.kernel().as_ref().clone(), + program.kernel().clone(), self.stack_inputs.clone(), host, ExecutionOptions::default(), @@ -326,7 +326,7 @@ impl Test { MemMastForestStore::default(), ); let mut process = Process::new( - program.kernel().as_ref().clone(), + program.kernel().clone(), self.stack_inputs.clone(), host, ExecutionOptions::default(), From c28c8768077d3a970464b91b2a7044e971775883 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:08:43 -0400 Subject: [PATCH 023/172] remove `Arc` around `MastForest` in `Program` --- assembly/src/assembler/mod.rs | 2 +- assembly/src/assembler/tests.rs | 2 +- core/src/program.rs | 14 ++++------ processor/src/chiplets/tests.rs | 2 +- processor/src/decoder/tests.rs | 28 ++++++++++---------- processor/src/trace/tests/chiplets/hasher.rs | 8 +++--- processor/src/trace/tests/decoder.rs | 16 +++++------ processor/src/trace/tests/mod.rs | 4 +-- 8 files changed, 36 insertions(+), 40 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 58f14f38eb..b97394681e 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -579,7 +579,7 @@ impl Assembler { let entry_procedure = self.compile_subgraph(entrypoint, true, context, &mut mast_forest)?; Ok(Program::with_kernel( - mast_forest.into(), + mast_forest, entry_procedure.body_node_id(), self.module_graph.kernel().clone(), )) diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 697ba028b7..77a4742767 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -209,7 +209,7 @@ fn nested_blocks() { &mut expected_mast_forest, ); - let expected_program = Program::new(expected_mast_forest.into(), combined_node_id); + let expected_program = Program::new(expected_mast_forest, combined_node_id); assert_eq!(expected_program.hash(), program.hash()); // also check that the program has the right number of procedures diff --git a/core/src/program.rs b/core/src/program.rs index e38f24a74b..67d0ce4c5e 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -1,6 +1,6 @@ use core::{fmt, ops::Index}; -use alloc::{sync::Arc, vec::Vec}; +use alloc::vec::Vec; use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -16,7 +16,7 @@ use super::Kernel; #[derive(Clone, Debug)] pub struct Program { - mast_forest: Arc, + mast_forest: MastForest, /// The "entrypoint" is the node where execution of the program begins. entrypoint: MastNodeId, kernel: Kernel, @@ -26,7 +26,7 @@ pub struct Program { impl Program { /// Construct a new [`Program`] from the given MAST forest and entrypoint. The kernel is assumed /// to be empty. - pub fn new(mast_forest: Arc, entrypoint: MastNodeId) -> Self { + pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); Self { @@ -37,11 +37,7 @@ impl Program { } /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. - pub fn with_kernel( - mast_forest: Arc, - entrypoint: MastNodeId, - kernel: Kernel, - ) -> Self { + pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); Self { @@ -121,7 +117,7 @@ impl fmt::Display for Program { } } -impl From for Arc { +impl From for MastForest { fn from(program: Program) -> Self { program.mast_forest } diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index 0b137b6304..8353376989 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -122,7 +122,7 @@ fn build_trace( let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; process.execute(&program).unwrap(); diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 265276119f..598db1abdf 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -52,7 +52,7 @@ fn basic_block_one_group() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -98,7 +98,7 @@ fn basic_block_small() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -161,7 +161,7 @@ fn basic_block() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -253,7 +253,7 @@ fn span_block_with_respan() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -330,7 +330,7 @@ fn join_node() { let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); let join_node_id = mast_forest.add_node(join_node); - Program::new(mast_forest.into(), join_node_id) + Program::new(mast_forest, join_node_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -396,7 +396,7 @@ fn split_node_true() { let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); let split_node_id = mast_forest.add_node(split_node); - Program::new(mast_forest.into(), split_node_id) + Program::new(mast_forest, split_node_id) }; let (trace, trace_len) = build_trace(&[1], &program); @@ -449,7 +449,7 @@ fn split_node_false() { let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); let split_node_id = mast_forest.add_node(split_node); - Program::new(mast_forest.into(), split_node_id) + Program::new(mast_forest, split_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -503,7 +503,7 @@ fn loop_node() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest.into(), loop_node_id) + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -556,7 +556,7 @@ fn loop_node_skip() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest.into(), loop_node_id) + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -599,7 +599,7 @@ fn loop_node_repeat() { let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest.into(), loop_node_id) + Program::new(mast_forest, loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -702,7 +702,7 @@ fn call_block() { let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); let program_root_id = mast_forest.add_node(program_root); - let program = Program::new(mast_forest.into(), program_root_id); + let program = Program::new(mast_forest, program_root_id); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, Kernel::default()); @@ -928,7 +928,7 @@ fn syscall_block() { let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); + let program = Program::with_kernel(mast_forest, program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel); @@ -1200,7 +1200,7 @@ fn dyn_block() { let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); let program_root_node_id = mast_forest.add_node(program_root_node.clone()); - let program = Program::new(mast_forest.into(), program_root_node_id); + let program = Program::new(mast_forest, program_root_node_id); let (trace, trace_len) = build_dyn_trace( &[ @@ -1308,7 +1308,7 @@ fn set_user_op_helpers_many() { let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let a = rand_value::(); let b = rand_value::(); diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 31944a6460..992cee60d4 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -53,7 +53,7 @@ pub fn b_chip_span() { let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -126,7 +126,7 @@ pub fn b_chip_span_with_respan() { let basic_block = MastNode::new_basic_block(ops); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -224,7 +224,7 @@ pub fn b_chip_merge() { let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); let split_id = mast_forest.add_node(split); - Program::new(mast_forest.into(), split_id) + Program::new(mast_forest, split_id) }; let trace = build_trace_from_program(&program, &[]); @@ -337,7 +337,7 @@ pub fn b_chip_permutation() { let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; let trace = build_trace_from_program(&program, &stack); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index e1d6a3ea56..3b20428865 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -81,7 +81,7 @@ fn decoder_p1_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join); - Program::new(mast_forest.into(), join_id) + Program::new(mast_forest, join_id) }; let trace = build_trace_from_program(&program, &[]); @@ -152,7 +152,7 @@ fn decoder_p1_split() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - Program::new(mast_forest.into(), split_id) + Program::new(mast_forest, split_id) }; let trace = build_trace_from_program(&program, &[1]); @@ -213,7 +213,7 @@ fn decoder_p1_loop_with_repeat() { let loop_node = MastNode::new_loop(join_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - Program::new(mast_forest.into(), loop_node_id) + Program::new(mast_forest, loop_node_id) }; let trace = build_trace_from_program(&program, &[0, 1, 1]); @@ -335,7 +335,7 @@ fn decoder_p2_span_with_respan() { let basic_block = MastNode::new_basic_block(ops); let basic_block_id = mast_forest.add_node(basic_block); - Program::new(mast_forest.into(), basic_block_id) + Program::new(mast_forest, basic_block_id) }; let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -377,7 +377,7 @@ fn decoder_p2_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join.clone()); - let program = Program::new(mast_forest.into(), join_id); + let program = Program::new(mast_forest, join_id); let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -442,7 +442,7 @@ fn decoder_p2_split_true() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - let program = Program::new(mast_forest.into(), split_id); + let program = Program::new(mast_forest, split_id); // build trace from program let trace = build_trace_from_program(&program, &[1]); @@ -498,7 +498,7 @@ fn decoder_p2_split_false() { let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); let split_id = mast_forest.add_node(split); - let program = Program::new(mast_forest.into(), split_id); + let program = Program::new(mast_forest, split_id); // build trace from program let trace = build_trace_from_program(&program, &[0]); @@ -557,7 +557,7 @@ fn decoder_p2_loop_with_repeat() { let loop_node = MastNode::new_loop(join_id, &mast_forest); let loop_node_id = mast_forest.add_node(loop_node); - let program = Program::new(mast_forest.into(), loop_node_id); + let program = Program::new(mast_forest, loop_node_id); // build trace from program let trace = build_trace_from_program(&program, &[0, 1, 1]); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 5d8c0a69f0..b324589013 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -40,7 +40,7 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - let program = Program::new(mast_forest.into(), basic_block_id); + let program = Program::new(mast_forest, basic_block_id); build_trace_from_program(&program, stack) } @@ -63,7 +63,7 @@ pub fn build_trace_from_ops_with_inputs( let basic_block = MastNode::new_basic_block(operations); let basic_block_id = mast_forest.add_node(basic_block); - let program = Program::new(mast_forest.into(), basic_block_id); + let program = Program::new(mast_forest, basic_block_id); process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) From 78b2b16684366aabfc704c0a342034dfeebe99e9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:19:02 -0400 Subject: [PATCH 024/172] Return error on malformed host --- processor/src/errors.rs | 6 ++++++ processor/src/lib.rs | 6 +++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/processor/src/errors.rs b/processor/src/errors.rs index 38b3be870f..203384367f 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -48,6 +48,9 @@ pub enum ExecutionError { }, LogArgumentZero(u32), MalformedSignatureKey(&'static str), + MalformedMastForestInHost { + root_digest: Digest, + }, MastNodeNotFoundInForest { node_id: MastNodeId, }, @@ -150,6 +153,9 @@ impl Display for ExecutionError { ) } MalformedSignatureKey(signature) => write!(f, "Malformed signature key: {signature}"), + MalformedMastForestInHost { root_digest } => { + write!(f, "Malformed host: MAST forest indexed by procedure root {} doesn't contain that root", root_digest) + } MastNodeNotFoundInForest { node_id } => { write!(f, "Malformed MAST forest, node id {node_id} doesn't exist") } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 83116eab39..d825592967 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -262,7 +262,11 @@ where root_digest: external_node.digest(), }, )?; - let root_id = mast_forest.find_procedure_root(external_node.digest()).unwrap_or_else(|| panic!("Malformed host: MAST forest indexed by procedure root {} doesn't contain that root.", external_node.digest())); + let root_id = mast_forest.find_procedure_root(external_node.digest()).ok_or( + ExecutionError::MalformedMastForestInHost { + root_digest: external_node.digest(), + }, + )?; self.execute_mast_node(root_id, &mast_forest) } From 4883b4440544803c49ee28816e377d391bdad8cd Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:27:18 -0400 Subject: [PATCH 025/172] Simplify `DefaultHost` --- miden/src/cli/debug/executor.rs | 8 ++------ miden/src/cli/prove.rs | 9 ++------- miden/src/cli/run.rs | 7 ++----- miden/src/examples/blake3.rs | 3 +-- miden/src/examples/fibonacci.rs | 3 +-- miden/src/tools/mod.rs | 7 ++----- processor/src/host/mod.rs | 20 +++++++++++--------- processor/src/operations/mod.rs | 8 ++++---- processor/src/trace/tests/mod.rs | 8 ++------ stdlib/tests/crypto/falcon.rs | 4 ++-- stdlib/tests/crypto/stark/mod.rs | 4 ++-- test-utils/src/lib.rs | 29 +++++++---------------------- 12 files changed, 38 insertions(+), 72 deletions(-) diff --git a/miden/src/cli/debug/executor.rs b/miden/src/cli/debug/executor.rs index 5c2a2bae5b..ec77f492fe 100644 --- a/miden/src/cli/debug/executor.rs +++ b/miden/src/cli/debug/executor.rs @@ -2,7 +2,6 @@ use super::DebugCommand; use miden_vm::{ math::Felt, DefaultHost, MemAdviceProvider, Program, StackInputs, VmState, VmStateIterator, }; -use processor::MemMastForestStore; /// Holds debugger state and iterator used for debugging. pub struct DebugExecutor { @@ -22,11 +21,8 @@ impl DebugExecutor { stack_inputs: StackInputs, advice_provider: MemAdviceProvider, ) -> Result { - let mut vm_state_iter = processor::execute_iter( - &program, - stack_inputs, - DefaultHost::new(advice_provider, MemMastForestStore::default()), - ); + let mut vm_state_iter = + processor::execute_iter(&program, stack_inputs, DefaultHost::new(advice_provider)); let vm_state = vm_state_iter .next() .ok_or(format!( diff --git a/miden/src/cli/prove.rs b/miden/src/cli/prove.rs index 687467c28f..f55913e2a2 100644 --- a/miden/src/cli/prove.rs +++ b/miden/src/cli/prove.rs @@ -2,9 +2,7 @@ use super::data::{instrument, Debug, InputFile, Libraries, OutputFile, ProgramFi use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; use miden_vm::ProvingOptions; -use processor::{ - DefaultHost, ExecutionOptions, ExecutionOptionsError, MemMastForestStore, Program, -}; +use processor::{DefaultHost, ExecutionOptions, ExecutionOptionsError, Program}; use std::{path::PathBuf, time::Instant}; @@ -97,10 +95,7 @@ impl ProveCmd { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; - let host = DefaultHost::new( - input_data.parse_advice_provider().map_err(Report::msg)?, - MemMastForestStore::default(), - ); + let host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); let proving_options = self.get_proof_options().map_err(|err| Report::msg(format!("{err}")))?; diff --git a/miden/src/cli/run.rs b/miden/src/cli/run.rs index 484e875435..59c7bfbe79 100644 --- a/miden/src/cli/run.rs +++ b/miden/src/cli/run.rs @@ -1,7 +1,7 @@ use super::data::{instrument, Debug, InputFile, Libraries, OutputFile, ProgramFile}; use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; -use processor::{DefaultHost, ExecutionOptions, ExecutionTrace, MemMastForestStore}; +use processor::{DefaultHost, ExecutionOptions, ExecutionTrace}; use std::{path::PathBuf, time::Instant}; #[derive(Debug, Clone, Parser)] @@ -117,10 +117,7 @@ fn run_program(params: &RunCmd) -> Result<(ExecutionTrace, [u8; 32]), Report> { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; - let host = DefaultHost::new( - input_data.parse_advice_provider().map_err(Report::msg)?, - MemMastForestStore::default(), - ); + let host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); let program_hash: [u8; 32] = program.hash().into(); diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 0191afabcf..2e87cadbcc 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -1,6 +1,5 @@ use super::Example; use miden_vm::{Assembler, DefaultHost, MemAdviceProvider, Program, StackInputs}; -use processor::MemMastForestStore; use stdlib::StdLibrary; use vm_core::{utils::group_slice_elements, Felt}; @@ -12,7 +11,7 @@ const INITIAL_HASH_VALUE: [u32; 8] = [u32::MAX; 8]; // EXAMPLE BUILDER // ================================================================================================ -pub fn get_example(n: usize) -> Example> { +pub fn get_example(n: usize) -> Example> { // generate the program and expected results let program = generate_blake3_program(n); let expected_result = compute_hash_chain(n); diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index f2864b4f5f..7bd6555c52 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -1,11 +1,10 @@ use super::{Example, ONE, ZERO}; use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, StackInputs}; -use processor::MemMastForestStore; // EXAMPLE BUILDER // ================================================================================================ -pub fn get_example(n: usize) -> Example> { +pub fn get_example(n: usize) -> Example> { // generate the program and expected results let program = generate_fibonacci_program(n); let expected_result = vec![compute_fibonacci(n)]; diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index 349c9d5f32..0028b2c4e1 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -3,7 +3,7 @@ use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; use core::fmt; use miden_vm::{Assembler, DefaultHost, Host, Operation, StackInputs}; -use processor::{AsmOpInfo, MemMastForestStore, TraceLenSummary}; +use processor::{AsmOpInfo, TraceLenSummary}; use std::{fs, path::PathBuf}; use stdlib::StdLibrary; @@ -35,10 +35,7 @@ impl Analyze { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; - let host = DefaultHost::new( - input_data.parse_advice_provider().map_err(Report::msg)?, - MemMastForestStore::default(), - ); + let host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host) .expect("Could not retrieve execution details"); diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index acdf6e5b7f..d6bfe9a79d 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -280,12 +280,12 @@ impl From for Felt { // ================================================================================================ /// A default [Host] implementation that provides the essential functionality required by the VM. -pub struct DefaultHost { +pub struct DefaultHost { adv_provider: A, - store: S, + store: MemMastForestStore, } -impl Default for DefaultHost { +impl Default for DefaultHost { fn default() -> Self { Self { adv_provider: MemAdviceProvider::default(), @@ -294,18 +294,21 @@ impl Default for DefaultHost { } } -impl DefaultHost +impl DefaultHost where A: AdviceProvider, - S: MastForestStore, { - pub fn new(adv_provider: A, store: S) -> Self { + pub fn new(adv_provider: A) -> Self { Self { adv_provider, - store, + store: MemMastForestStore::default(), } } + pub fn load_mast_forest(&mut self, mast_forest: MastForest) { + self.store.insert(mast_forest) + } + #[cfg(any(test, feature = "internals"))] pub fn advice_provider(&self) -> &A { &self.adv_provider @@ -321,10 +324,9 @@ where } } -impl Host for DefaultHost +impl Host for DefaultHost where A: AdviceProvider, - S: MastForestStore, { fn get_advice( &mut self, diff --git a/processor/src/operations/mod.rs b/processor/src/operations/mod.rs index 93256df878..eb9456f154 100644 --- a/processor/src/operations/mod.rs +++ b/processor/src/operations/mod.rs @@ -178,9 +178,9 @@ pub mod tests { use miden_air::ExecutionOptions; use vm_core::StackInputs; - use crate::{AdviceInputs, DefaultHost, MemAdviceProvider, MemMastForestStore}; + use crate::{AdviceInputs, DefaultHost, MemAdviceProvider}; - impl Process> { + impl Process> { /// Instantiates a new blank process for testing purposes. The stack in the process is /// initialized with the provided values. pub fn new_dummy(stack_inputs: StackInputs) -> Self { @@ -203,7 +203,7 @@ pub mod tests { let advice_inputs = AdviceInputs::default().with_stack_values(advice_stack.iter().copied()).unwrap(); let advice_provider = MemAdviceProvider::from(advice_inputs); - let host = DefaultHost::new(advice_provider, MemMastForestStore::default()); + let host = DefaultHost::new(advice_provider); let mut process = Self::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); process.execute_op(Operation::Noop).unwrap(); @@ -233,7 +233,7 @@ pub mod tests { advice_inputs: AdviceInputs, ) -> Self { let advice_provider = MemAdviceProvider::from(advice_inputs); - let host = DefaultHost::new(advice_provider, MemMastForestStore::default()); + let host = DefaultHost::new(advice_provider); let mut process = Self::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); process.decoder.add_dummy_trace_row(); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index b324589013..19c11defbc 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -2,10 +2,7 @@ use super::{ super::chiplets::init_state_from_words, ExecutionTrace, Felt, FieldElement, Process, Trace, NUM_RAND_ROWS, }; -use crate::{ - host::MemMastForestStore, AdviceInputs, DefaultHost, ExecutionOptions, MemAdviceProvider, - StackInputs, -}; +use crate::{AdviceInputs, DefaultHost, ExecutionOptions, MemAdviceProvider, StackInputs}; use alloc::vec::Vec; use test_utils::rand::rand_array; use vm_core::{ @@ -54,8 +51,7 @@ pub fn build_trace_from_ops_with_inputs( advice_inputs: AdviceInputs, ) -> ExecutionTrace { let advice_provider = MemAdviceProvider::from(advice_inputs); - let mast_forest_store = MemMastForestStore::default(); - let host = DefaultHost::new(advice_provider, mast_forest_store); + let host = DefaultHost::new(advice_provider); let mut process = Process::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); diff --git a/stdlib/tests/crypto/falcon.rs b/stdlib/tests/crypto/falcon.rs index 118813230e..8a89c6ae83 100644 --- a/stdlib/tests/crypto/falcon.rs +++ b/stdlib/tests/crypto/falcon.rs @@ -1,4 +1,4 @@ -use processor::{MemMastForestStore, Program, ProgramInfo}; +use processor::{Program, ProgramInfo}; use rand::{thread_rng, Rng}; use assembly::{utils::Serializable, Assembler}; @@ -208,7 +208,7 @@ fn falcon_prove_verify() { let stack_inputs = StackInputs::try_from_ints(op_stack).expect("failed to create stack inputs"); let advice_inputs = AdviceInputs::default().with_map(advice_map); let advice_provider = MemAdviceProvider::from(advice_inputs); - let host = DefaultHost::new(advice_provider, MemMastForestStore::default()); + let host = DefaultHost::new(advice_provider); let options = ProvingOptions::with_96_bit_security(false); let (stack_outputs, proof) = test_utils::prove(&program, stack_inputs.clone(), host, options) diff --git a/stdlib/tests/crypto/stark/mod.rs b/stdlib/tests/crypto/stark/mod.rs index 7e7c8b76c8..ac29770e3b 100644 --- a/stdlib/tests/crypto/stark/mod.rs +++ b/stdlib/tests/crypto/stark/mod.rs @@ -3,7 +3,7 @@ use verifier_recursive::{generate_advice_inputs, VerifierData}; use assembly::Assembler; use miden_air::{FieldExtension, HashFunction, PublicInputs}; -use processor::{DefaultHost, MemMastForestStore, Program, ProgramInfo}; +use processor::{DefaultHost, Program, ProgramInfo}; use test_utils::{ prove, AdviceInputs, MemAdviceProvider, ProvingOptions, StackInputs, VerifierError, }; @@ -55,7 +55,7 @@ pub fn generate_recursive_verifier_data( let stack_inputs = StackInputs::try_from_ints(stack_inputs).unwrap(); let advice_inputs = AdviceInputs::default(); let advice_provider = MemAdviceProvider::from(advice_inputs); - let host = DefaultHost::new(advice_provider, MemMastForestStore::default()); + let host = DefaultHost::new(advice_provider); let options = ProvingOptions::new(43, 8, 12, FieldExtension::Quadratic, 4, 7, HashFunction::Rpo256); diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 4db101e84c..ce8b28f172 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -8,7 +8,7 @@ extern crate std; // IMPORTS // ================================================================================================ -use processor::{MemMastForestStore, Program}; +use processor::Program; #[cfg(not(target_family = "wasm"))] use proptest::prelude::{Arbitrary, Strategy}; @@ -226,10 +226,7 @@ impl Test { ) { // compile the program let program: Program = self.compile().expect("Failed to compile test source."); - let host = DefaultHost::new( - MemAdviceProvider::from(self.advice_inputs.clone()), - MemMastForestStore::default(), - ); + let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); // execute the test let mut process = Process::new( @@ -308,10 +305,7 @@ impl Test { #[track_caller] pub fn execute(&self) -> Result { let program: Program = self.compile().expect("Failed to compile test source."); - let host = DefaultHost::new( - MemAdviceProvider::from(self.advice_inputs.clone()), - MemMastForestStore::default(), - ); + let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); processor::execute(&program, self.stack_inputs.clone(), host, ExecutionOptions::default()) } @@ -319,12 +313,9 @@ impl Test { /// process once execution is finished. pub fn execute_process( &self, - ) -> Result>, ExecutionError> { + ) -> Result>, ExecutionError> { let program: Program = self.compile().expect("Failed to compile test source."); - let host = DefaultHost::new( - MemAdviceProvider::from(self.advice_inputs.clone()), - MemMastForestStore::default(), - ); + let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); let mut process = Process::new( program.kernel().clone(), self.stack_inputs.clone(), @@ -341,10 +332,7 @@ impl Test { pub fn prove_and_verify(&self, pub_inputs: Vec, test_fail: bool) { let stack_inputs = StackInputs::try_from_ints(pub_inputs).unwrap(); let program: Program = self.compile().expect("Failed to compile test source."); - let host = DefaultHost::new( - MemAdviceProvider::from(self.advice_inputs.clone()), - MemMastForestStore::default(), - ); + let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); let (mut stack_outputs, proof) = prover::prove(&program, stack_inputs.clone(), host, ProvingOptions::default()).unwrap(); @@ -363,10 +351,7 @@ impl Test { /// state. pub fn execute_iter(&self) -> VmStateIterator { let program: Program = self.compile().expect("Failed to compile test source."); - let host = DefaultHost::new( - MemAdviceProvider::from(self.advice_inputs.clone()), - MemMastForestStore::default(), - ); + let host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); processor::execute_iter(&program, self.stack_inputs.clone(), host) } From 155a798672842c695d976da80f29af894f86f2bb Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:31:40 -0400 Subject: [PATCH 026/172] `MastForest::add_node()`: add docs --- core/src/mast/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index c4ca79b6be..527c31d297 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -60,6 +60,8 @@ impl MastForest { /// Mutators impl MastForest { /// Adds a node to the forest, and returns the associated [`MastNodeId`]. + /// + /// Adding two duplicate nodes will result in two distinct returned [`MastNodeId`]s. pub fn add_node(&mut self, node: MastNode) -> MastNodeId { let new_node_id = MastNodeId( self.nodes From bc6d13e138690e73d3568d796a5181bed68927a8 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:32:06 -0400 Subject: [PATCH 027/172] fmt --- core/src/mast/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 527c31d297..1d66c66923 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -60,7 +60,7 @@ impl MastForest { /// Mutators impl MastForest { /// Adds a node to the forest, and returns the associated [`MastNodeId`]. - /// + /// /// Adding two duplicate nodes will result in two distinct returned [`MastNodeId`]s. pub fn add_node(&mut self, node: MastNode) -> MastNodeId { let new_node_id = MastNodeId( From be2432004d2c76c5e932579be4306d6c48658639 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 09:39:11 -0400 Subject: [PATCH 028/172] add failing `duplicate_procedure()` test --- assembly/src/assembler/tests.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 77a4742767..08b5f134c8 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -215,3 +215,30 @@ fn nested_blocks() { // also check that the program has the right number of procedures assert_eq!(program.num_procedures(), 5); } + +/// Ensures that a single copy of procedures with the same MAST root are added only once to the MAST forest. +#[test] +fn duplicate_procedure() { + let assembler = Assembler::new(); + + let program_source = r#" + proc.foo + add + mul + end + + proc.bar + add + mul + end + + begin + # specific impl irrelevant + exec.foo + exec.bar + end + "#; + + let program = assembler.assemble(program_source).unwrap(); + assert_eq!(program.num_procedures(), 2); +} From 32aedd67e28df838be7fd1004710f56b18595407 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Sun, 23 Jun 2024 10:08:50 -0400 Subject: [PATCH 029/172] Introduce `MastForestBuilder` --- assembly/src/assembler/basic_block_builder.rs | 21 ++- assembly/src/assembler/instruction/mod.rs | 28 ++-- .../src/assembler/instruction/procedures.rs | 46 ++++--- assembly/src/assembler/mast_forest_builder.rs | 74 ++++++++++ assembly/src/assembler/mod.rs | 129 ++++++++++-------- assembly/src/assembler/tests.rs | 89 ++++++------ 6 files changed, 249 insertions(+), 138 deletions(-) create mode 100644 assembly/src/assembler/mast_forest_builder.rs diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 3b6f0e3fd3..dfbdbd84c2 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -1,7 +1,10 @@ -use super::{AssemblyContext, BodyWrapper, Decorator, DecoratorList, Instruction}; +use super::{ + mast_forest_builder::MastForestBuilder, AssemblyContext, BodyWrapper, Decorator, DecoratorList, + Instruction, +}; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; use vm_core::{ - mast::{MastForest, MastNode, MastNodeId}, + mast::{MastNode, MastNodeId}, AdviceInjector, AssemblyOp, Operation, }; @@ -123,13 +126,16 @@ impl BasicBlockBuilder { /// /// This consumes all operations and decorators in the builder, but does not touch the /// operations in the epilogue of the builder. - pub fn make_basic_block(&mut self, mast_forest: &mut MastForest) -> Option { + pub fn make_basic_block( + &mut self, + mast_forest_builder: &mut MastForestBuilder, + ) -> Option { if !self.ops.is_empty() { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest.add_node(basic_block_node); + let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node); Some(basic_block_node_id) } else if !self.decorators.is_empty() { @@ -149,8 +155,11 @@ impl BasicBlockBuilder { /// - Operations contained in the epilogue of the builder are appended to the list of ops which /// go into the new BASIC BLOCK node. /// - The builder is consumed in the process. - pub fn into_basic_block(mut self, mast_forest: &mut MastForest) -> Option { + pub fn into_basic_block( + mut self, + mast_forest_builder: &mut MastForestBuilder, + ) -> Option { self.ops.append(&mut self.epilogue); - self.make_basic_block(mast_forest) + self.make_basic_block(mast_forest_builder) } } diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 461c774d9c..9b57b21c73 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -1,13 +1,10 @@ use super::{ - ast::InvokeKind, Assembler, AssemblyContext, BasicBlockBuilder, Felt, Instruction, Operation, - ONE, ZERO, + ast::InvokeKind, mast_forest_builder::MastForestBuilder, Assembler, AssemblyContext, + BasicBlockBuilder, Felt, Instruction, Operation, ONE, ZERO, }; use crate::{diagnostics::Report, utils::bound_into_included_u64, AssemblyError}; use core::ops::RangeBounds; -use vm_core::{ - mast::{MastForest, MastNodeId}, - Decorator, -}; +use vm_core::{mast::MastNodeId, Decorator}; mod adv_ops; mod crypto_ops; @@ -27,7 +24,7 @@ impl Assembler { instruction: &Instruction, span_builder: &mut BasicBlockBuilder, ctx: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // if the assembler is in debug mode, start tracking the instruction about to be executed; // this will allow us to map the instruction to the sequence of operations which were @@ -36,7 +33,8 @@ impl Assembler { span_builder.track_instruction(instruction, ctx); } - let result = self.compile_instruction_impl(instruction, span_builder, ctx, mast_forest)?; + let result = + self.compile_instruction_impl(instruction, span_builder, ctx, mast_forest_builder)?; // compute and update the cycle count of the instruction which just finished executing if self.in_debug_mode() { @@ -51,7 +49,7 @@ impl Assembler { instruction: &Instruction, span_builder: &mut BasicBlockBuilder, ctx: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { use Operation::*; @@ -369,18 +367,18 @@ impl Assembler { // ----- exec/call instructions ------------------------------------------------------- Instruction::Exec(ref callee) => { - return self.invoke(InvokeKind::Exec, callee, ctx, mast_forest) + return self.invoke(InvokeKind::Exec, callee, ctx, mast_forest_builder) } Instruction::Call(ref callee) => { - return self.invoke(InvokeKind::Call, callee, ctx, mast_forest) + return self.invoke(InvokeKind::Call, callee, ctx, mast_forest_builder) } Instruction::SysCall(ref callee) => { - return self.invoke(InvokeKind::SysCall, callee, ctx, mast_forest) + return self.invoke(InvokeKind::SysCall, callee, ctx, mast_forest_builder) } - Instruction::DynExec => return self.dynexec(mast_forest), - Instruction::DynCall => return self.dyncall(mast_forest), + Instruction::DynExec => return self.dynexec(mast_forest_builder), + Instruction::DynCall => return self.dyncall(mast_forest_builder), Instruction::ProcRef(ref callee) => { - self.procref(callee, ctx, span_builder, mast_forest)? + self.procref(callee, ctx, span_builder, mast_forest_builder.forest())? } // ----- debug decorators ------------------------------------------------------------- diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index e670fec358..3c2e2e7983 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -1,5 +1,6 @@ use super::{Assembler, AssemblyContext, BasicBlockBuilder, Operation}; use crate::{ + assembler::mast_forest_builder::MastForestBuilder, ast::{InvocationTarget, InvokeKind}, AssemblyError, RpoDigest, SourceSpan, Span, Spanned, }; @@ -14,11 +15,11 @@ impl Assembler { kind: InvokeKind, callee: &InvocationTarget, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let span = callee.span(); - let digest = self.resolve_target(kind, callee, context, mast_forest)?; - self.invoke_mast_root(kind, span, digest, context, mast_forest) + let digest = self.resolve_target(kind, callee, context, mast_forest_builder.forest())?; + self.invoke_mast_root(kind, span, digest, context, mast_forest_builder) } fn invoke_mast_root( @@ -27,7 +28,7 @@ impl Assembler { span: SourceSpan, mast_root: RpoDigest, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // Get the procedure from the assembler let cache = &self.procedure_cache; @@ -68,9 +69,11 @@ impl Assembler { }) } })?; - context.register_external_call(&proc, false, mast_forest)?; + context.register_external_call(&proc, false, mast_forest_builder.forest())?; + } + Some(proc) => { + context.register_external_call(&proc, false, mast_forest_builder.forest())? } - Some(proc) => context.register_external_call(&proc, false, mast_forest)?, None if matches!(kind, InvokeKind::SysCall) => { return Err(AssemblyError::UnknownSysCallTarget { span, @@ -90,7 +93,7 @@ impl Assembler { // hence be present in the `MastForest`. We currently assume that the // `MastForest` contains all the procedures being called; "external procedures" // only known by digest are not currently supported. - mast_forest.find_procedure_root(mast_root).ok_or_else(|| { + mast_forest_builder.find_procedure_root(mast_root).ok_or_else(|| { AssemblyError::UnknownExecTarget { span, source_file: current_source_file, @@ -100,27 +103,28 @@ impl Assembler { } InvokeKind::Call => { let callee_id = - mast_forest.find_procedure_root(mast_root).unwrap_or_else(|| { + mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { // If the MAST root called isn't known to us, make it an external // reference. let external_node = MastNode::new_external(mast_root); - mast_forest.add_node(external_node) + mast_forest_builder.ensure_node(external_node) }); - let call_node = MastNode::new_call(callee_id, mast_forest); - mast_forest.add_node(call_node) + let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest()); + mast_forest_builder.ensure_node(call_node) } InvokeKind::SysCall => { let callee_id = - mast_forest.find_procedure_root(mast_root).unwrap_or_else(|| { + mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { // If the MAST root called isn't known to us, make it an external // reference. let external_node = MastNode::new_external(mast_root); - mast_forest.add_node(external_node) + mast_forest_builder.ensure_node(external_node) }); - let syscall_node = MastNode::new_syscall(callee_id, mast_forest); - mast_forest.add_node(syscall_node) + let syscall_node = + MastNode::new_syscall(callee_id, mast_forest_builder.forest()); + mast_forest_builder.ensure_node(syscall_node) } } }; @@ -131,9 +135,9 @@ impl Assembler { /// Creates a new DYN block for the dynamic code execution and return. pub(super) fn dynexec( &self, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest.add_node(MastNode::Dyn); + let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); Ok(Some(dyn_node_id)) } @@ -141,13 +145,13 @@ impl Assembler { /// Creates a new CALL block whose target is DYN. pub(super) fn dyncall( &self, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest.add_node(MastNode::Dyn); - let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest); + let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); + let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest()); - mast_forest.add_node(dyn_call_node) + mast_forest_builder.ensure_node(dyn_call_node) }; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs new file mode 100644 index 0000000000..39c42c388b --- /dev/null +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -0,0 +1,74 @@ +use core::ops::Index; + +use alloc::collections::BTreeMap; +use vm_core::{ + crypto::hash::RpoDigest, + mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, +}; + +/// Builder for a [`MastForest`]. +#[derive(Clone, Debug, Default)] +pub struct MastForestBuilder { + mast_forest: MastForest, + node_id_by_hash: BTreeMap, +} + +impl MastForestBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn build(self) -> MastForest { + self.mast_forest + } +} + +/// Accessors +impl MastForestBuilder { + /// Returns the underlying [`MastForest`] being built + pub fn forest(&self) -> &MastForest { + &self.mast_forest + } + + /// Returns the [`MastNodeId`] of the procedure associated with a given digest, if any. + #[inline(always)] + pub fn find_procedure_root(&self, digest: RpoDigest) -> Option { + self.mast_forest.find_procedure_root(digest) + } +} + +/// Mutators +impl MastForestBuilder { + /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. + /// + /// If a [`MastNode`] which is equal to the current node was previously added, the previously + /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal + /// [`MastNode`]s have equal [`MastNodeId`]s. + pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId { + let node_digest = node.digest(); + + if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { + // node already exists in the forest; return previously assigned id + *node_id + } else { + let new_node_id = self.mast_forest.add_node(node); + self.node_id_by_hash.insert(node_digest, new_node_id); + + new_node_id + } + } + + /// Marks the given [`MastNodeId`] as being the root of a procedure. + pub fn make_root(&mut self, new_root_id: MastNodeId) { + self.mast_forest.make_root(new_root_id) + } +} + +impl Index for MastForestBuilder { + type Output = MastNode; + + #[inline(always)] + fn index(&self, node_id: MastNodeId) -> &Self::Output { + &self.mast_forest[node_id] + } +} diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b97394681e..4cf1a057ae 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -9,6 +9,7 @@ use crate::{ RpoDigest, Spanned, ONE, ZERO, }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; +use mast_forest_builder::MastForestBuilder; use vm_core::{ mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, @@ -18,6 +19,7 @@ mod basic_block_builder; mod context; mod id; mod instruction; +mod mast_forest_builder; mod module_graph; mod procedure; #[cfg(test)] @@ -86,7 +88,7 @@ pub enum ArtifactKind { /// [Assembler::compile] or [Assembler::compile_ast] to get your compiled program. #[derive(Clone)] pub struct Assembler { - mast_forest: MastForest, + mast_forest_builder: MastForestBuilder, /// The global [ModuleGraph] for this assembler. All new [AssemblyContext]s inherit this graph /// as a baseline. module_graph: Box, @@ -103,7 +105,7 @@ pub struct Assembler { impl Default for Assembler { fn default() -> Self { Self { - mast_forest: Default::default(), + mast_forest_builder: Default::default(), module_graph: Default::default(), procedure_cache: Default::default(), warnings_as_errors: false, @@ -120,13 +122,11 @@ impl Assembler { Self::default() } - /// Start building an [`Assembler`] with the given [`Kernel`] and the [`MastForest`] that was - /// used to compile the kernel. - pub fn with_kernel(kernel: Kernel, mast_forest: MastForest) -> Self { + /// Start building an [`Assembler`] with the given [`Kernel`]. + pub fn with_kernel(kernel: Kernel) -> Self { let mut assembler = Self::new(); assembler.module_graph.set_kernel(None, kernel); - assembler.mast_forest = mast_forest; assembler } @@ -140,12 +140,13 @@ impl Assembler { let opts = CompileOptions::for_kernel(); let module = module.compile_with_options(opts)?; - let mut mast_forest = MastForest::new(); + let mut mast_forest_builder = MastForestBuilder::new(); - let (kernel_index, kernel) = assembler.assemble_kernel_module(module, &mut mast_forest)?; + let (kernel_index, kernel) = + assembler.assemble_kernel_module(module, &mut mast_forest_builder)?; assembler.module_graph.set_kernel(Some(kernel_index), kernel); - assembler.mast_forest = mast_forest; + assembler.mast_forest_builder = mast_forest_builder; Ok(assembler) } @@ -385,7 +386,7 @@ impl Assembler { )); } - let mast_forest = core::mem::take(&mut self.mast_forest); + let mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); let program = source.compile_with_options(CompileOptions { // Override the module name so that we always compile the executable @@ -415,7 +416,7 @@ impl Assembler { }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; - self.compile_program(entrypoint, context, mast_forest) + self.compile_program(entrypoint, context, mast_forest_builder) } /// Compile and assembles all procedures in the specified module, adding them to the procedure @@ -462,13 +463,14 @@ impl Assembler { let module_id = self.module_graph.add_module(module)?; self.module_graph.recompute()?; - let mut mast_forest = core::mem::take(&mut self.mast_forest); + let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); - self.assemble_graph(context, &mut mast_forest)?; - let exported_procedure_digests = self.get_module_exports(module_id, &mast_forest); + self.assemble_graph(context, &mut mast_forest_builder)?; + let exported_procedure_digests = + self.get_module_exports(module_id, mast_forest_builder.forest()); // Reassign the mast_forest to the assembler for use is a future program assembly - self.mast_forest = mast_forest; + self.mast_forest_builder = mast_forest_builder; exported_procedure_digests } @@ -478,7 +480,7 @@ impl Assembler { fn assemble_kernel_module( &mut self, module: Box, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result<(ModuleIndex, Kernel), Report> { if !module.is_kernel() { return Err(Report::msg(format!("expected kernel module, got {}", module.kind()))); @@ -500,8 +502,8 @@ impl Assembler { module: kernel_index, index: ProcedureIndex::new(index), }; - let compiled = self.compile_subgraph(gid, false, &mut context, mast_forest)?; - kernel.push(compiled.mast_root(mast_forest)); + let compiled = self.compile_subgraph(gid, false, &mut context, mast_forest_builder)?; + kernel.push(compiled.mast_root(mast_forest_builder.forest())); } Kernel::new(&kernel) @@ -570,16 +572,17 @@ impl Assembler { &mut self, entrypoint: GlobalProcedureIndex, context: &mut AssemblyContext, - mut mast_forest: MastForest, + mut mast_forest_builder: MastForestBuilder, ) -> Result { // Raise an error if we are called with an invalid entrypoint assert!(self.module_graph[entrypoint].name().is_main()); // Compile the module graph rooted at the entrypoint - let entry_procedure = self.compile_subgraph(entrypoint, true, context, &mut mast_forest)?; + let entry_procedure = + self.compile_subgraph(entrypoint, true, context, &mut mast_forest_builder)?; Ok(Program::with_kernel( - mast_forest, + mast_forest_builder.build(), entry_procedure.body_node_id(), self.module_graph.kernel().clone(), )) @@ -592,11 +595,11 @@ impl Assembler { fn assemble_graph( &mut self, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result<(), Report> { let mut worklist = self.module_graph.topological_sort().to_vec(); assert!(!worklist.is_empty()); - self.process_graph_worklist(&mut worklist, context, None, mast_forest) + self.process_graph_worklist(&mut worklist, context, None, mast_forest_builder) .map(|_| ()) } @@ -609,7 +612,7 @@ impl Assembler { root: GlobalProcedureIndex, is_entrypoint: bool, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { let mut worklist = self.module_graph.topological_sort_from_root(root).map_err(|cycle| { let iter = cycle.into_node_ids(); @@ -625,9 +628,10 @@ impl Assembler { assert!(!worklist.is_empty()); let compiled = if is_entrypoint { - self.process_graph_worklist(&mut worklist, context, Some(root), mast_forest)? + self.process_graph_worklist(&mut worklist, context, Some(root), mast_forest_builder)? } else { - let _ = self.process_graph_worklist(&mut worklist, context, None, mast_forest)?; + let _ = + self.process_graph_worklist(&mut worklist, context, None, mast_forest_builder)?; self.procedure_cache.get(root) }; @@ -639,7 +643,7 @@ impl Assembler { worklist: &mut Vec, context: &mut AssemblyContext, entrypoint: Option, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result>, Report> { // Process the topological ordering in reverse order (bottom-up), so that // each procedure is compiled with all of its dependencies fully compiled @@ -647,8 +651,10 @@ impl Assembler { while let Some(procedure_gid) = worklist.pop() { // If we have already compiled this procedure, do not recompile if let Some(proc) = self.procedure_cache.get(procedure_gid) { - self.module_graph - .register_mast_root(procedure_gid, proc.mast_root(mast_forest))?; + self.module_graph.register_mast_root( + procedure_gid, + proc.mast_root(mast_forest_builder.forest()), + )?; continue; } let is_entry = entrypoint == Some(procedure_gid); @@ -668,17 +674,21 @@ impl Assembler { .with_source_file(ast.source_file()); // Compile this procedure - let procedure = self.compile_procedure(pctx, context, mast_forest)?; + let procedure = self.compile_procedure(pctx, context, mast_forest_builder)?; // Cache the compiled procedure, unless it's the program entrypoint if is_entry { compiled_entrypoint = Some(Arc::from(procedure)); } else { // Make the MAST root available to all dependents - let digest = procedure.mast_root(mast_forest); + let digest = procedure.mast_root(mast_forest_builder.forest()); self.module_graph.register_mast_root(procedure_gid, digest)?; - self.procedure_cache.insert(procedure_gid, Arc::from(procedure), mast_forest)?; + self.procedure_cache.insert( + procedure_gid, + Arc::from(procedure), + mast_forest_builder.forest(), + )?; } } @@ -690,7 +700,7 @@ impl Assembler { &self, procedure: ProcedureContext, context: &mut AssemblyContext, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { // Make sure the current procedure context is available during codegen let gid = procedure.id(); @@ -708,12 +718,12 @@ impl Assembler { prologue: vec![Operation::Push(num_locals), Operation::FmpUpdate], epilogue: vec![Operation::Push(-num_locals), Operation::FmpUpdate], }; - self.compile_body(proc.iter(), context, Some(wrapper), mast_forest)? + self.compile_body(proc.iter(), context, Some(wrapper), mast_forest_builder)? } else { - self.compile_body(proc.iter(), context, None, mast_forest)? + self.compile_body(proc.iter(), context, None, mast_forest_builder)? }; - mast_forest.make_root(proc_body_root); + mast_forest_builder.make_root(proc_body_root); let pctx = context.take_current_procedure().unwrap(); Ok(pctx.into_procedure(proc_body_root)) @@ -724,7 +734,7 @@ impl Assembler { body: I, context: &mut AssemblyContext, wrapper: Option, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> Result where I: Iterator, @@ -741,10 +751,10 @@ impl Assembler { inst, &mut basic_block_builder, context, - mast_forest, + mast_forest_builder, )? { if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest) + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } @@ -756,38 +766,41 @@ impl Assembler { Op::If { then_blk, else_blk, .. } => { - if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } let then_blk = - self.compile_body(then_blk.iter(), context, None, mast_forest)?; + self.compile_body(then_blk.iter(), context, None, mast_forest_builder)?; // else is an exception because it is optional; hence, will have to be replaced // by noop span let else_blk = if else_blk.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest.add_node(basic_block_node) + mast_forest_builder.ensure_node(basic_block_node) } else { - self.compile_body(else_blk.iter(), context, None, mast_forest)? + self.compile_body(else_blk.iter(), context, None, mast_forest_builder)? }; let split_node_id = { - let split_node = MastNode::new_split(then_blk, else_blk, mast_forest); + let split_node = + MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest()); - mast_forest.add_node(split_node) + mast_forest_builder.ensure_node(split_node) }; mast_node_ids.push(split_node_id); } Op::Repeat { count, body, .. } => { - if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } let repeat_node_id = - self.compile_body(body.iter(), context, None, mast_forest)?; + self.compile_body(body.iter(), context, None, mast_forest_builder)?; for _ in 0..*count { mast_node_ids.push(repeat_node_id); @@ -795,32 +808,34 @@ impl Assembler { } Op::While { body, .. } => { - if let Some(basic_block_id) = basic_block_builder.make_basic_block(mast_forest) + if let Some(basic_block_id) = + basic_block_builder.make_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } let loop_body_node_id = - self.compile_body(body.iter(), context, None, mast_forest)?; + self.compile_body(body.iter(), context, None, mast_forest_builder)?; let loop_node_id = { - let loop_node = MastNode::new_loop(loop_body_node_id, mast_forest); - mast_forest.add_node(loop_node) + let loop_node = + MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest()); + mast_forest_builder.ensure_node(loop_node) }; mast_node_ids.push(loop_node_id); } } } - if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest) { + if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest_builder) { mast_node_ids.push(basic_block_id); } Ok(if mast_node_ids.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest.add_node(basic_block_node) + mast_forest_builder.ensure_node(basic_block_node) } else { - combine_mast_node_ids(mast_node_ids, mast_forest) + combine_mast_node_ids(mast_node_ids, mast_forest_builder) }) } @@ -859,7 +874,7 @@ struct BodyWrapper { fn combine_mast_node_ids( mut mast_node_ids: Vec, - mast_forest: &mut MastForest, + mast_forest_builder: &mut MastForestBuilder, ) -> MastNodeId { debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list"); @@ -878,8 +893,8 @@ fn combine_mast_node_ids( while let (Some(left), Some(right)) = (source_mast_node_iter.next(), source_mast_node_iter.next()) { - let join_mast_node = MastNode::new_join(left, right, mast_forest); - let join_mast_node_id = mast_forest.add_node(join_mast_node); + let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest()); + let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node); mast_node_ids.push(join_mast_node_id); } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 08b5f134c8..3eeec5d4f2 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,12 +1,9 @@ use alloc::{boxed::Box, vec::Vec}; -use vm_core::{ - mast::{MastForest, MastNode}, - Program, -}; +use vm_core::{mast::MastNode, Program}; use super::{Assembler, Library, Operation}; use crate::{ - assembler::combine_mast_node_ids, + assembler::{combine_mast_node_ids, mast_forest_builder::MastForestBuilder}, ast::{Module, ModuleKind}, LibraryNamespace, Version, }; @@ -71,7 +68,7 @@ fn nested_blocks() { .unwrap(); // The expected `MastForest` for the program (that we will build by hand) - let mut expected_mast_forest = MastForest::new(); + let mut expected_mast_forest_builder = MastForestBuilder::new(); // fetch the kernel digest and store into a syscall block // @@ -80,10 +77,11 @@ fn nested_blocks() { // `Assembler::with_kernel_from_module()`. let syscall_foo_node_id = { let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); - let kernel_foo_node_id = expected_mast_forest.add_node(kernel_foo_node); + let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node); - let syscall_node = MastNode::new_syscall(kernel_foo_node_id, &expected_mast_forest); - expected_mast_forest.add_node(syscall_node) + let syscall_node = + MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(syscall_node) }; let program = r#" @@ -127,96 +125,109 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); - let basic_block_1_id = expected_mast_forest.add_node(basic_block_1); + let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1); // Basic block representing the `foo` procedure let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); - let basic_block_2_id = expected_mast_forest.add_node(basic_block_2); - - let join_node = - MastNode::new_join(basic_block_1_id, basic_block_2_id, &expected_mast_forest); - expected_mast_forest.add_node(join_node) + let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2); + + let join_node = MastNode::new_join( + basic_block_1_id, + basic_block_2_id, + expected_mast_forest_builder.forest(), + ); + expected_mast_forest_builder.ensure_node(join_node) }; let exec_foo_bar_baz_node_id = { // basic block representing foo::bar.baz procedure let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); - expected_mast_forest.add_node(basic_block) + expected_mast_forest_builder.ensure_node(basic_block) }; let before = { let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); - expected_mast_forest.add_node(before_node) + expected_mast_forest_builder.ensure_node(before_node) }; let r#true1 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); - expected_mast_forest.add_node(r#true_node) + expected_mast_forest_builder.ensure_node(r#true_node) }; let r#false1 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); - expected_mast_forest.add_node(r#false_node) + expected_mast_forest_builder.ensure_node(r#false_node) }; let r#if1 = { - let r#if_node = MastNode::new_split(r#true1, r#false1, &expected_mast_forest); - expected_mast_forest.add_node(r#if_node) + let r#if_node = + MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(r#if_node) }; let r#true3 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); - expected_mast_forest.add_node(r#true_node) + expected_mast_forest_builder.ensure_node(r#true_node) }; let r#false3 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); - expected_mast_forest.add_node(r#false_node) + expected_mast_forest_builder.ensure_node(r#false_node) }; let r#true2 = { - let r#if_node = MastNode::new_split(r#true3, r#false3, &expected_mast_forest); - expected_mast_forest.add_node(r#if_node) + let r#if_node = + MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(r#if_node) }; let r#while = { let push_basic_block_id = { let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); - expected_mast_forest.add_node(push_basic_block) + expected_mast_forest_builder.ensure_node(push_basic_block) }; let body_node_id = { - let body_node = - MastNode::new_join(exec_bar_node_id, push_basic_block_id, &expected_mast_forest); + let body_node = MastNode::new_join( + exec_bar_node_id, + push_basic_block_id, + expected_mast_forest_builder.forest(), + ); - expected_mast_forest.add_node(body_node) + expected_mast_forest_builder.ensure_node(body_node) }; - let loop_node = MastNode::new_loop(body_node_id, &expected_mast_forest); - expected_mast_forest.add_node(loop_node) + let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(loop_node) }; let push_13_basic_block_id = { let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); - expected_mast_forest.add_node(node) + expected_mast_forest_builder.ensure_node(node) }; let r#false2 = { - let node = MastNode::new_join(push_13_basic_block_id, r#while, &expected_mast_forest); - expected_mast_forest.add_node(node) + let node = MastNode::new_join( + push_13_basic_block_id, + r#while, + expected_mast_forest_builder.forest(), + ); + expected_mast_forest_builder.ensure_node(node) }; let nested = { - let node = MastNode::new_split(r#true2, r#false2, &expected_mast_forest); - expected_mast_forest.add_node(node) + let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest()); + expected_mast_forest_builder.ensure_node(node) }; let combined_node_id = combine_mast_node_ids( vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], - &mut expected_mast_forest, + &mut expected_mast_forest_builder, ); - let expected_program = Program::new(expected_mast_forest, combined_node_id); + let expected_program = Program::new(expected_mast_forest_builder.build(), combined_node_id); assert_eq!(expected_program.hash(), program.hash()); // also check that the program has the right number of procedures assert_eq!(program.num_procedures(), 5); } -/// Ensures that a single copy of procedures with the same MAST root are added only once to the MAST forest. +/// Ensures that a single copy of procedures with the same MAST root are added only once to the MAST +/// forest. #[test] fn duplicate_procedure() { let assembler = Assembler::new(); From 088de8297257a536d2b69ae4b61a1e1556e32ff2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 25 Jun 2024 13:45:46 -0400 Subject: [PATCH 030/172] Rename `mod tests` -> `testing` --- processor/src/operations/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/processor/src/operations/mod.rs b/processor/src/operations/mod.rs index eb9456f154..cb677f8ac3 100644 --- a/processor/src/operations/mod.rs +++ b/processor/src/operations/mod.rs @@ -173,7 +173,7 @@ where } #[cfg(test)] -pub mod tests { +pub mod testing { use super::*; use miden_air::ExecutionOptions; use vm_core::StackInputs; From 9d48fdaef52523861e10630f24b6e2ddd5233ab6 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 25 Jun 2024 16:40:23 -0400 Subject: [PATCH 031/172] add `duplicate_node()` test --- assembly/src/assembler/tests.rs | 56 ++++++++++++++++++++++++++++++++- core/src/mast/mod.rs | 2 +- core/src/program.rs | 2 +- 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 3eeec5d4f2..37ed89e5ca 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,5 +1,9 @@ use alloc::{boxed::Box, vec::Vec}; -use vm_core::{mast::MastNode, Program}; +use pretty_assertions::assert_eq; +use vm_core::{ + mast::{MastForest, MastNode}, + Program, +}; use super::{Assembler, Library, Operation}; use crate::{ @@ -253,3 +257,53 @@ fn duplicate_procedure() { let program = assembler.assemble(program_source).unwrap(); assert_eq!(program.num_procedures(), 2); } + +/// Ensures that equal MAST nodes don't get added twice to a MAST forest +#[test] +fn duplicate_nodes() { + let assembler = Assembler::new(); + + let program_source = r#" + begin + if.true + mul + else + if.true add else mul end + end + end + "#; + + let program = assembler.assemble(program_source).unwrap(); + + let mut expected_mast_forest = MastForest::new(); + + // basic block: mul + let mul_basic_block_id = { + let node = MastNode::new_basic_block(vec![Operation::Mul]); + expected_mast_forest.add_node(node) + }; + + // basic block: add + let add_basic_block_id = { + let node = MastNode::new_basic_block(vec![Operation::Add]); + expected_mast_forest.add_node(node) + }; + + // inner split: `if.true add else mul end` + let inner_split_id = { + let node = + MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest); + expected_mast_forest.add_node(node) + }; + + // root: outer split + let root_id = { + let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest); + expected_mast_forest.add_node(node) + }; + expected_mast_forest.make_root(root_id); + + let expected_program = Program::new(expected_mast_forest, root_id); + + assert_eq!(program, expected_program); +} diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 1d66c66923..ab03f7139f 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -40,7 +40,7 @@ impl fmt::Display for MastNodeId { /// /// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] /// can be built from a [`MastForest`] to specify an entrypoint. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct MastForest { /// All of the nodes local to the trees comprising the MAST forest. nodes: Vec, diff --git a/core/src/program.rs b/core/src/program.rs index 67d0ce4c5e..0c4b0ac098 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -14,7 +14,7 @@ use super::Kernel; // PROGRAM // =============================================================================================== -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Program { mast_forest: MastForest, /// The "entrypoint" is the node where execution of the program begins. From 6c62d9bdefec3d5052effd1c1c545fae00c3ec12 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 25 Jun 2024 16:46:42 -0400 Subject: [PATCH 032/172] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7049192ea8..6e84aa7a5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Added error codes support for the `mtree_verify` instruction (#1328). - Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346). - Change MAST to a table-based representation (#1349) +- Introduce `MastForestStore` (#1359) - Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357) ## 0.9.2 (2024-05-22) - `stdlib` crate only From 039bba02af459b9998bb58920c244ab7fb2743dc Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 26 Jun 2024 09:23:07 -0400 Subject: [PATCH 033/172] Program: use `assert!()` instead of `debug_assert!()` --- core/src/program.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/program.rs b/core/src/program.rs index 0c4b0ac098..536978aaa2 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -26,8 +26,11 @@ pub struct Program { impl Program { /// Construct a new [`Program`] from the given MAST forest and entrypoint. The kernel is assumed /// to be empty. + /// + /// Panics: + /// - if `mast_forest` doesn't have an entrypoint pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { - debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + assert!(mast_forest.get_node_by_id(entrypoint).is_some()); Self { mast_forest, @@ -37,8 +40,11 @@ impl Program { } /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. + /// + /// Panics: + /// - if `mast_forest` doesn't have an entrypoint pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { - debug_assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + assert!(mast_forest.get_node_by_id(entrypoint).is_some()); Self { mast_forest, From 9c9e1717feb34254f2cedcfbb3c402059fc123af Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 26 Jun 2024 09:32:43 -0400 Subject: [PATCH 034/172] `MastForest::make_root()`: add assert --- core/src/mast/mod.rs | 6 ++++++ core/src/program.rs | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index ab03f7139f..4efabf7c83 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -76,7 +76,13 @@ impl MastForest { } /// Marks the given [`MastNodeId`] as being the root of a procedure. + /// + /// # Panics + /// - if `new_root_id`'s internal index is larger than the number of nodes in this forest (i.e. + /// clearly doesn't belong to this MAST forest). pub fn make_root(&mut self, new_root_id: MastNodeId) { + assert!((new_root_id.0 as usize) < self.nodes.len()); + if !self.roots.contains(&new_root_id) { self.roots.push(new_root_id); } diff --git a/core/src/program.rs b/core/src/program.rs index 536978aaa2..c8bc5b7893 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -27,7 +27,7 @@ impl Program { /// Construct a new [`Program`] from the given MAST forest and entrypoint. The kernel is assumed /// to be empty. /// - /// Panics: + /// # Panics: /// - if `mast_forest` doesn't have an entrypoint pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { assert!(mast_forest.get_node_by_id(entrypoint).is_some()); @@ -41,7 +41,7 @@ impl Program { /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. /// - /// Panics: + /// # Panics: /// - if `mast_forest` doesn't have an entrypoint pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { assert!(mast_forest.get_node_by_id(entrypoint).is_some()); From c34e985b13bd5fd8e5ac7228ed155201db5ad964 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 26 Jun 2024 09:35:02 -0400 Subject: [PATCH 035/172] fmt --- core/src/mast/mod.rs | 2 +- core/src/program.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 4efabf7c83..5371d4c35f 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -76,7 +76,7 @@ impl MastForest { } /// Marks the given [`MastNodeId`] as being the root of a procedure. - /// + /// /// # Panics /// - if `new_root_id`'s internal index is larger than the number of nodes in this forest (i.e. /// clearly doesn't belong to this MAST forest). diff --git a/core/src/program.rs b/core/src/program.rs index c8bc5b7893..b055bf3313 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -26,7 +26,7 @@ pub struct Program { impl Program { /// Construct a new [`Program`] from the given MAST forest and entrypoint. The kernel is assumed /// to be empty. - /// + /// /// # Panics: /// - if `mast_forest` doesn't have an entrypoint pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { @@ -40,7 +40,7 @@ impl Program { } /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. - /// + /// /// # Panics: /// - if `mast_forest` doesn't have an entrypoint pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { From c1b269311cf7e1bbd641bffe07a10bb5794f03c1 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 26 Jun 2024 13:47:10 -0400 Subject: [PATCH 036/172] Serialization for `MastNodeId` --- core/src/mast/mod.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 5371d4c35f..3db9d5acb6 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -8,6 +8,7 @@ pub use node::{ get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastNode, OpBatch, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; #[cfg(test)] mod tests; @@ -33,6 +34,20 @@ impl fmt::Display for MastNodeId { } } +impl Serializable for MastNodeId { + fn write_into(&self, target: &mut W) { + self.0.write_into(target) + } +} + +impl Deserializable for MastNodeId { + fn read_from(source: &mut R) -> Result { + let inner = source.read_u32()?; + + Ok(Self(inner)) + } +} + // MAST FOREST // =============================================================================================== From efc24fdcc792baf25cac80bf0c06c7c2a477ab4a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 26 Jun 2024 14:01:58 -0400 Subject: [PATCH 037/172] serialization for MastNode variants except basic block --- core/src/mast/node/call_node.rs | 29 +++++++++++++++++++++++++++++ core/src/mast/node/dyn_node.rs | 13 +++++++++++++ core/src/mast/node/external.rs | 17 +++++++++++++++++ core/src/mast/node/join_node.rs | 19 +++++++++++++++++++ core/src/mast/node/loop_node.rs | 19 +++++++++++++++++++ core/src/mast/node/split_node.rs | 19 +++++++++++++++++++ 6 files changed, 116 insertions(+) diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index c2183ea84d..e9df3b45e5 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -2,6 +2,7 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use miden_formatting::prettier::PrettyPrint; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, Operation}; @@ -98,6 +99,34 @@ impl MerkleTreeNode for CallNode { } } +impl Serializable for CallNode { + fn write_into(&self, target: &mut W) { + let Self { + callee, + is_syscall, + digest, + } = self; + + callee.write_into(target); + target.write_bool(*is_syscall); + digest.write_into(target); + } +} + +impl Deserializable for CallNode { + fn read_from(source: &mut R) -> Result { + let callee = Deserializable::read_from(source)?; + let is_syscall = source.read_bool()?; + let digest = Deserializable::read_from(source)?; + + Ok(Self { + callee, + is_syscall, + digest, + }) + } +} + struct CallNodePrettyPrint<'a> { call_node: &'a CallNode, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index c298a03ade..716a602446 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -1,6 +1,7 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable}; use crate::{ mast::{MastForest, MerkleTreeNode}, @@ -34,6 +35,18 @@ impl MerkleTreeNode for DynNode { } } +impl Serializable for DynNode { + fn write_into(&self, _target: &mut W) { + // nothing + } +} + +impl Deserializable for DynNode { + fn read_from(_source: &mut R) -> Result { + Ok(Self) + } +} + impl crate::prettier::PrettyPrint for DynNode { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index c0b8ff10a3..a778f2b960 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -1,6 +1,7 @@ use crate::mast::{MastForest, MerkleTreeNode}; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; /// Node for referencing procedures not present in a given [`MastForest`] (hence "external"). /// @@ -33,6 +34,22 @@ impl MerkleTreeNode for ExternalNode { } } +impl Serializable for ExternalNode { + fn write_into(&self, target: &mut W) { + let Self { digest } = self; + + digest.write_into(target); + } +} + +impl Deserializable for ExternalNode { + fn read_from(source: &mut R) -> Result { + let digest = Deserializable::read_from(source)?; + + Ok(Self { digest }) + } +} + impl crate::prettier::PrettyPrint for ExternalNode { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 3c4a712655..af8cc70e24 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -1,6 +1,7 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, prettier::PrettyPrint, Operation}; @@ -67,6 +68,24 @@ impl MerkleTreeNode for JoinNode { } } +impl Serializable for JoinNode { + fn write_into(&self, target: &mut W) { + let Self { children, digest } = self; + + children.write_into(target); + digest.write_into(target); + } +} + +impl Deserializable for JoinNode { + fn read_from(source: &mut R) -> Result { + let children = Deserializable::read_from(source)?; + let digest = Deserializable::read_from(source)?; + + Ok(Self { children, digest }) + } +} + struct JoinNodePrettyPrint<'a> { join_node: &'a JoinNode, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index fc63b11367..d2150cde00 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -2,6 +2,7 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use miden_formatting::prettier::PrettyPrint; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, Operation}; @@ -61,6 +62,24 @@ impl MerkleTreeNode for LoopNode { } } +impl Serializable for LoopNode { + fn write_into(&self, target: &mut W) { + let Self { body, digest } = self; + + body.write_into(target); + digest.write_into(target); + } +} + +impl Deserializable for LoopNode { + fn read_from(source: &mut R) -> Result { + let body = Deserializable::read_from(source)?; + let digest = Deserializable::read_from(source)?; + + Ok(Self { body, digest }) + } +} + struct LoopNodePrettyPrint<'a> { loop_node: &'a LoopNode, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index ca87501fe3..381ca28ef4 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -2,6 +2,7 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use miden_formatting::prettier::PrettyPrint; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, Operation}; @@ -69,6 +70,24 @@ impl MerkleTreeNode for SplitNode { } } +impl Serializable for SplitNode { + fn write_into(&self, target: &mut W) { + let Self { branches, digest } = self; + + branches.write_into(target); + digest.write_into(target); + } +} + +impl Deserializable for SplitNode { + fn read_from(source: &mut R) -> Result { + let branches = Deserializable::read_from(source)?; + let digest = Deserializable::read_from(source)?; + + Ok(Self { branches, digest }) + } +} + struct SplitNodePrettyPrint<'a> { split_node: &'a SplitNode, mast_forest: &'a MastForest, From 861d0d5f32cb67e4f71feaaad0ac356ac77f0cc8 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 27 Jun 2024 13:08:34 -0400 Subject: [PATCH 038/172] MastForest serialization scaffolding --- core/Cargo.toml | 3 +- core/src/mast/mod.rs | 2 + core/src/mast/serialization/mod.rs | 130 +++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 core/src/mast/serialization/mod.rs diff --git a/core/Cargo.toml b/core/Cargo.toml index 80a1beadd0..91ca9d4d75 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -28,9 +28,10 @@ std = [ [dependencies] math = { package = "winter-math", version = "0.9", default-features = false } -#miden-crypto = { version = "0.9", default-features = false } miden-crypto = { git = "https://github.com/0xPolygonMiden/crypto", branch = "next", default-features = false } miden-formatting = { version = "0.1", default-features = false } +num-derive = "0.4.2" +num-traits = "0.2.19" thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } winter-utils = { package = "winter-utils", version = "0.9", default-features = false } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 3db9d5acb6..cea97d093d 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -10,6 +10,8 @@ pub use node::{ }; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +mod serialization; + #[cfg(test)] mod tests; diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs new file mode 100644 index 0000000000..0e761ec4fb --- /dev/null +++ b/core/src/mast/serialization/mod.rs @@ -0,0 +1,130 @@ +use alloc::{string::String, vec::Vec}; +use miden_crypto::hash::rpo::RpoDigest; +use num_derive::{FromPrimitive, ToPrimitive}; +use num_traits::{FromPrimitive, ToPrimitive}; +use thiserror::Error; +use winter_utils::{ByteWriter, Serializable}; + +use super::{MastForest, MastNode}; + +/// Specifies an offset into the `data` section of an encoded [`MastForest`]. +type DataOffset = u32; + +/// Magic string for detecting that a file is binary-encoded MAST. +const MAGIC: &[u8; 5] = b"MAST\0"; + +/// The format version. +/// +/// If future modifications are made to this format, the version should be incremented by 1. A +/// version of `[255, 255, 255]` is reserved for future extensions that require extending the +/// version field itself, but should be considered invalid for now. +const VERSION: [u8; 3] = [0, 0, 0]; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Invalid discriminant '{discriminant}' for type '{ty}'")] + InvalidDiscriminant { ty: String, discriminant: u8 }, +} + +/// An entry in the `strings` table of an encoded [`MastForest`]. +/// +/// Strings are UTF8-encoded. +pub struct StringRef { + /// Offset into the `data` section. + offset: DataOffset, + + /// Length of the utf-8 string. + len: u32, +} + +impl Serializable for StringRef { + fn write_into(&self, target: &mut W) { + self.offset.write_into(target); + self.len.write_into(target); + } +} + +pub struct MastNodeInfo { + ty: MastNodeType, + offset: DataOffset, + digest: RpoDigest, +} + +impl Serializable for MastNodeInfo { + fn write_into(&self, target: &mut W) { + self.ty.write_into(target); + self.offset.write_into(target); + self.digest.write_into(target); + } +} + +pub struct MastNodeType([u8; 8]); + +impl Serializable for MastNodeType { + fn write_into(&self, target: &mut W) { + self.0.write_into(target); + } +} + +#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)] +#[repr(u8)] +pub enum MastNodeTypeVariant { + Join, + Split, + Loop, + Call, + Dyn, + Block, + External, +} + +impl MastNodeTypeVariant { + pub fn discriminant(&self) -> u8 { + self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") + } + + pub fn try_from_discriminant(discriminant: u8) -> Result { + Self::from_u8(discriminant).ok_or_else(|| Error::InvalidDiscriminant { + ty: "MastNode".into(), + discriminant, + }) + } +} + +impl Serializable for MastForest { + fn write_into(&self, target: &mut W) { + let mut strings: Vec = Vec::new(); + let mut data: Vec = Vec::new(); + + // magic & version + target.write_bytes(MAGIC); + target.write_bytes(&VERSION); + + // node count + target.write_usize(self.nodes.len()); + + // roots + self.roots.write_into(target); + + // MAST node infos + for mast_node in &self.nodes { + let mast_node_info = convert_mast_node(mast_node, &mut data, &mut strings); + + mast_node_info.write_into(target); + } + + // strings table + strings.write_into(target); + + // data blob + data.write_into(target); + } +} + +fn convert_mast_node( + mast_node: &MastNode, + data: &mut Vec, + strings: &mut Vec, +) -> MastNodeInfo { + todo!() +} From 617379024466adb80f7d9a20c7ec3b627148e5ca Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 27 Jun 2024 16:35:51 -0400 Subject: [PATCH 039/172] define `MastNodeType` constructor from `MastNode` --- core/src/mast/node/basic_block_node.rs | 5 ++ core/src/mast/serialization/mod.rs | 109 ++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node.rs index 2d4b4a0313..b8c8795dd3 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node.rs @@ -107,6 +107,11 @@ impl BasicBlockNode { /// Public accessors impl BasicBlockNode { + + pub fn num_operations_and_decorators(&self) -> u32 { + todo!() + } + pub fn op_batches(&self) -> &[OpBatch] { &self.op_batches } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 0e761ec4fb..4f312bb5ea 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,11 +1,11 @@ use alloc::{string::String, vec::Vec}; use miden_crypto::hash::rpo::RpoDigest; use num_derive::{FromPrimitive, ToPrimitive}; -use num_traits::{FromPrimitive, ToPrimitive}; +use num_traits::{FromPrimitive, ToBytes, ToPrimitive}; use thiserror::Error; use winter_utils::{ByteWriter, Serializable}; -use super::{MastForest, MastNode}; +use super::{MastForest, MastNode, MastNodeId}; /// Specifies an offset into the `data` section of an encoded [`MastForest`]. type DataOffset = u32; @@ -58,8 +58,91 @@ impl Serializable for MastNodeInfo { } } +// TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant pub struct MastNodeType([u8; 8]); +impl MastNodeType { + pub fn new(mast_node: &MastNode) -> Self { + use MastNode::*; + + let discriminant = MastNodeTypeVariant::from_mast_node(mast_node).discriminant(); + assert!(discriminant < 2_u8.pow(4_u32)); + + match mast_node { + Block(block_node) => { + let num_ops = block_node.num_operations_and_decorators().to_be_bytes(); + + Self([discriminant << 4, num_ops[0], num_ops[1], num_ops[2], num_ops[3], 0, 0, 0]) + } + Join(join_node) => { + Self::encode_join_or_split(discriminant, join_node.first(), join_node.second()) + } + Split(split_node) => Self::encode_join_or_split( + discriminant, + split_node.on_true(), + split_node.on_false(), + ), + Loop(loop_node) => { + let [body_byte1, body_byte2, body_byte3, body_byte4] = + loop_node.body().0.to_be_bytes(); + + Self([discriminant << 4, body_byte1, body_byte2, body_byte3, body_byte4, 0, 0, 0]) + } + Call(_) | Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), + } + } + + // TODOP: Make a diagram of how the bits are split + fn encode_join_or_split( + discriminant: u8, + left_child_id: MastNodeId, + right_child_id: MastNodeId, + ) -> Self { + assert!(left_child_id.0 < 2_u32.pow(30)); + assert!(right_child_id.0 < 2_u32.pow(30)); + + let mut result: [u8; 8] = [0_u8; 8]; + + result[0] = discriminant << 4; + + // write left child into result + { + let [lsb, a, b, msb] = left_child_id.0.to_le_bytes(); + result[0] |= lsb >> 4; + result[1] |= lsb << 4; + result[1] |= a >> 4; + result[2] |= a << 4; + result[2] |= b >> 4; + result[3] |= b << 4; + + // msb is different from lsb, a and b since its 2 most significant bits are guaranteed + // to be 0, and hence not encoded. + // + // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in `result[3]`, + // and `ef` as the most significant bits of `result[4]`. + result[3] |= msb >> 2; + + result[4] |= msb << 6; + }; + + // write right child into result + { + // Recall that `result[4]` contains 2 bits from the left child id in the most + // significant bits. Also, the most significant byte of the right child is guaranteed to + // fit in 6 bits. Hence, we use big endian format for the right child id to simplify + // encoding and decoding. + let [msb, a, b, lsb] = right_child_id.0.to_be_bytes(); + + result[4] |= msb; + result[5] = a; + result[6] = b; + result[7] = lsb; + }; + + Self(result) + } +} + impl Serializable for MastNodeType { fn write_into(&self, target: &mut W) { self.0.write_into(target); @@ -73,6 +156,7 @@ pub enum MastNodeTypeVariant { Split, Loop, Call, + Syscall, Dyn, Block, External, @@ -89,6 +173,24 @@ impl MastNodeTypeVariant { discriminant, }) } + + pub fn from_mast_node(mast_node: &MastNode) -> Self { + match mast_node { + MastNode::Block(_) => Self::Block, + MastNode::Join(_) => Self::Join, + MastNode::Split(_) => Self::Split, + MastNode::Loop(_) => Self::Loop, + MastNode::Call(call_node) => { + if call_node.is_syscall() { + Self::Syscall + } else { + Self::Call + } + } + MastNode::Dyn => Self::Dyn, + MastNode::External(_) => Self::External, + } + } } impl Serializable for MastForest { @@ -126,5 +228,8 @@ fn convert_mast_node( data: &mut Vec, strings: &mut Vec, ) -> MastNodeInfo { + // mast node info + + // fill out encoded operations/decorators in data todo!() } From e018cc8c2431695bfb2b395610a65cd8dcfce80f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 27 Jun 2024 16:52:38 -0400 Subject: [PATCH 040/172] test join serialization of MastNodeType --- core/src/mast/node/basic_block_node.rs | 1 - core/src/mast/node/join_node.rs | 23 ++++++++++++------- core/src/mast/serialization/mod.rs | 31 +++++++++++++++++++++++--- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node.rs index b8c8795dd3..56ca756c17 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node.rs @@ -107,7 +107,6 @@ impl BasicBlockNode { /// Public accessors impl BasicBlockNode { - pub fn num_operations_and_decorators(&self) -> u32 { todo!() } diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index af8cc70e24..0d9acfeb7d 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -33,14 +33,9 @@ impl JoinNode { Self { children, digest } } - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { - JoinNodePrettyPrint { - join_node: self, - mast_forest, - } + #[cfg(test)] + pub fn new_test(children: [MastNodeId; 2], digest: RpoDigest) -> Self { + Self { children, digest } } } @@ -55,6 +50,18 @@ impl JoinNode { } } +impl JoinNode { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + JoinNodePrettyPrint { + join_node: self, + mast_forest, + } + } +} + impl MerkleTreeNode for JoinNode { fn digest(&self) -> RpoDigest { self.digest diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 4f312bb5ea..1da3f7a59f 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -118,10 +118,9 @@ impl MastNodeType { // msb is different from lsb, a and b since its 2 most significant bits are guaranteed // to be 0, and hence not encoded. // - // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in `result[3]`, - // and `ef` as the most significant bits of `result[4]`. + // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in + // `result[3]`, and `ef` as the most significant bits of `result[4]`. result[3] |= msb >> 2; - result[4] |= msb << 6; }; @@ -233,3 +232,29 @@ fn convert_mast_node( // fill out encoded operations/decorators in data todo!() } + +#[cfg(test)] +mod tests { + use super::*; + use crate::mast::JoinNode; + + #[test] + fn mast_node_type_serialization_join() { + let left_child_id = MastNodeId(0b00111001_11101011_01101100_11011000); + let right_child_id = MastNodeId(0b00100111_10101010_11111111_11001110); + let mast_node = MastNode::Join(JoinNode::new_test( + [left_child_id, right_child_id], + RpoDigest::default(), + )); + + let mast_node_type = MastNodeType::new(&mast_node); + + // Note: Join's discriminant is 0 + let expected_mast_node_type = [ + 0b00001101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, + 0b11001110, + ]; + + assert_eq!(expected_mast_node_type, mast_node_type.0); + } +} From 6671afedd4d4a10e0463a62d130b682e9e3dd5d8 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 27 Jun 2024 16:55:37 -0400 Subject: [PATCH 041/172] `MastNodeType` serialization of split --- core/src/mast/node/split_node.rs | 5 +++++ core/src/mast/serialization/mod.rs | 20 +++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 381ca28ef4..fe3a399bab 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -32,6 +32,11 @@ impl SplitNode { Self { branches, digest } } + + #[cfg(test)] + pub fn new_test(branches: [MastNodeId; 2], digest: RpoDigest) -> Self { + Self { branches, digest } + } } /// Public accessors diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 1da3f7a59f..7c70951fc2 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -236,7 +236,7 @@ fn convert_mast_node( #[cfg(test)] mod tests { use super::*; - use crate::mast::JoinNode; + use crate::mast::{JoinNode, SplitNode}; #[test] fn mast_node_type_serialization_join() { @@ -257,4 +257,22 @@ mod tests { assert_eq!(expected_mast_node_type, mast_node_type.0); } + + #[test] + fn mast_node_type_serialization_split() { + let on_true_id = MastNodeId(0b00111001_11101011_01101100_11011000); + let on_false_id = MastNodeId(0b00100111_10101010_11111111_11001110); + let mast_node = + MastNode::Split(SplitNode::new_test([on_true_id, on_false_id], RpoDigest::default())); + + let mast_node_type = MastNodeType::new(&mast_node); + + // Note: Split's discriminant is 0 + let expected_mast_node_type = [ + 0b00011101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, + 0b11001110, + ]; + + assert_eq!(expected_mast_node_type, mast_node_type.0); + } } From babdd0cc076aad516db8d67f28d08a230d74f426 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 27 Jun 2024 16:56:59 -0400 Subject: [PATCH 042/172] Revert "serialization for MastNode variants except basic block" This reverts commit efc24fdcc792baf25cac80bf0c06c7c2a477ab4a. --- core/src/mast/node/call_node.rs | 29 ----------------------------- core/src/mast/node/dyn_node.rs | 13 ------------- core/src/mast/node/external.rs | 17 ----------------- core/src/mast/node/join_node.rs | 19 ------------------- core/src/mast/node/loop_node.rs | 19 ------------------- core/src/mast/node/split_node.rs | 19 ------------------- 6 files changed, 116 deletions(-) diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index e9df3b45e5..c2183ea84d 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -2,7 +2,6 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use miden_formatting::prettier::PrettyPrint; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, Operation}; @@ -99,34 +98,6 @@ impl MerkleTreeNode for CallNode { } } -impl Serializable for CallNode { - fn write_into(&self, target: &mut W) { - let Self { - callee, - is_syscall, - digest, - } = self; - - callee.write_into(target); - target.write_bool(*is_syscall); - digest.write_into(target); - } -} - -impl Deserializable for CallNode { - fn read_from(source: &mut R) -> Result { - let callee = Deserializable::read_from(source)?; - let is_syscall = source.read_bool()?; - let digest = Deserializable::read_from(source)?; - - Ok(Self { - callee, - is_syscall, - digest, - }) - } -} - struct CallNodePrettyPrint<'a> { call_node: &'a CallNode, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index 716a602446..c298a03ade 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -1,7 +1,6 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; -use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable}; use crate::{ mast::{MastForest, MerkleTreeNode}, @@ -35,18 +34,6 @@ impl MerkleTreeNode for DynNode { } } -impl Serializable for DynNode { - fn write_into(&self, _target: &mut W) { - // nothing - } -} - -impl Deserializable for DynNode { - fn read_from(_source: &mut R) -> Result { - Ok(Self) - } -} - impl crate::prettier::PrettyPrint for DynNode { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index a778f2b960..c0b8ff10a3 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -1,7 +1,6 @@ use crate::mast::{MastForest, MerkleTreeNode}; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; /// Node for referencing procedures not present in a given [`MastForest`] (hence "external"). /// @@ -34,22 +33,6 @@ impl MerkleTreeNode for ExternalNode { } } -impl Serializable for ExternalNode { - fn write_into(&self, target: &mut W) { - let Self { digest } = self; - - digest.write_into(target); - } -} - -impl Deserializable for ExternalNode { - fn read_from(source: &mut R) -> Result { - let digest = Deserializable::read_from(source)?; - - Ok(Self { digest }) - } -} - impl crate::prettier::PrettyPrint for ExternalNode { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 0d9acfeb7d..1bd320c6e1 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -1,7 +1,6 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, prettier::PrettyPrint, Operation}; @@ -75,24 +74,6 @@ impl MerkleTreeNode for JoinNode { } } -impl Serializable for JoinNode { - fn write_into(&self, target: &mut W) { - let Self { children, digest } = self; - - children.write_into(target); - digest.write_into(target); - } -} - -impl Deserializable for JoinNode { - fn read_from(source: &mut R) -> Result { - let children = Deserializable::read_from(source)?; - let digest = Deserializable::read_from(source)?; - - Ok(Self { children, digest }) - } -} - struct JoinNodePrettyPrint<'a> { join_node: &'a JoinNode, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index d2150cde00..fc63b11367 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -2,7 +2,6 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use miden_formatting::prettier::PrettyPrint; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, Operation}; @@ -62,24 +61,6 @@ impl MerkleTreeNode for LoopNode { } } -impl Serializable for LoopNode { - fn write_into(&self, target: &mut W) { - let Self { body, digest } = self; - - body.write_into(target); - digest.write_into(target); - } -} - -impl Deserializable for LoopNode { - fn read_from(source: &mut R) -> Result { - let body = Deserializable::read_from(source)?; - let digest = Deserializable::read_from(source)?; - - Ok(Self { body, digest }) - } -} - struct LoopNodePrettyPrint<'a> { loop_node: &'a LoopNode, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index fe3a399bab..3418bba88c 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -2,7 +2,6 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use miden_formatting::prettier::PrettyPrint; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{chiplets::hasher, Operation}; @@ -75,24 +74,6 @@ impl MerkleTreeNode for SplitNode { } } -impl Serializable for SplitNode { - fn write_into(&self, target: &mut W) { - let Self { branches, digest } = self; - - branches.write_into(target); - digest.write_into(target); - } -} - -impl Deserializable for SplitNode { - fn read_from(source: &mut R) -> Result { - let branches = Deserializable::read_from(source)?; - let digest = Deserializable::read_from(source)?; - - Ok(Self { branches, digest }) - } -} - struct SplitNodePrettyPrint<'a> { split_node: &'a SplitNode, mast_forest: &'a MastForest, From cd065276153ce93b0c87e1792b7b06d05cea219b Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 27 Jun 2024 17:00:49 -0400 Subject: [PATCH 043/172] add TODOP --- core/src/mast/serialization/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 7c70951fc2..7b9331d500 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -275,4 +275,6 @@ mod tests { assert_eq!(expected_mast_node_type, mast_node_type.0); } + + // TODOP: Test all other variants } From 2dd482929fd0540126a0bc5af0ad8725a644ec72 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 11:24:05 -0400 Subject: [PATCH 044/172] impl Deserializable for `MastForest` (scaffold) --- core/src/mast/serialization/mod.rs | 93 ++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 7b9331d500..3c207dd213 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,9 +1,9 @@ use alloc::{string::String, vec::Vec}; use miden_crypto::hash::rpo::RpoDigest; use num_derive::{FromPrimitive, ToPrimitive}; -use num_traits::{FromPrimitive, ToBytes, ToPrimitive}; +use num_traits::{FromPrimitive, ToPrimitive}; use thiserror::Error; -use winter_utils::{ByteWriter, Serializable}; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use super::{MastForest, MastNode, MastNodeId}; @@ -44,6 +44,15 @@ impl Serializable for StringRef { } } +impl Deserializable for StringRef { + fn read_from(source: &mut R) -> Result { + let offset = DataOffset::read_from(source)?; + let len = source.read_u32()?; + + Ok(Self { offset, len }) + } +} + pub struct MastNodeInfo { ty: MastNodeType, offset: DataOffset, @@ -58,6 +67,16 @@ impl Serializable for MastNodeInfo { } } +impl Deserializable for MastNodeInfo { + fn read_from(source: &mut R) -> Result { + let ty = Deserializable::read_from(source)?; + let offset = DataOffset::read_from(source)?; + let digest = RpoDigest::read_from(source)?; + + Ok(Self { ty, offset, digest }) + } +} + // TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant pub struct MastNodeType([u8; 8]); @@ -148,6 +167,14 @@ impl Serializable for MastNodeType { } } +impl Deserializable for MastNodeType { + fn read_from(source: &mut R) -> Result { + let bytes = source.read_array()?; + + Ok(Self(bytes)) + } +} + #[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)] #[repr(u8)] pub enum MastNodeTypeVariant { @@ -194,6 +221,7 @@ impl MastNodeTypeVariant { impl Serializable for MastForest { fn write_into(&self, target: &mut W) { + // TODOP: make sure padding is in accordance with Paul's docs let mut strings: Vec = Vec::new(); let mut data: Vec = Vec::new(); @@ -209,7 +237,7 @@ impl Serializable for MastForest { // MAST node infos for mast_node in &self.nodes { - let mast_node_info = convert_mast_node(mast_node, &mut data, &mut strings); + let mast_node_info = mast_node_to_info(mast_node, &mut data, &mut strings); mast_node_info.write_into(target); } @@ -222,7 +250,56 @@ impl Serializable for MastForest { } } -fn convert_mast_node( +impl Deserializable for MastForest { + fn read_from(source: &mut R) -> Result { + let magic: [u8; 5] = source.read_array()?; + if magic != *MAGIC { + return Err(DeserializationError::InvalidValue(format!( + "Invalid magic bytes. Expected '{:?}', got '{:?}'", + *MAGIC, magic + ))); + } + + let version: [u8; 3] = source.read_array()?; + if version != VERSION { + return Err(DeserializationError::InvalidValue(format!( + "Unsupported version. Got '{version:?}', but only '{VERSION:?}' is supported", + ))); + } + + let node_count = source.read_usize()?; + + let roots: Vec = Deserializable::read_from(source)?; + + let mast_node_infos = { + let mut mast_node_infos = Vec::with_capacity(node_count); + for _ in 0..node_count { + let mast_node_info = MastNodeInfo::read_from(source)?; + mast_node_infos.push(mast_node_info); + } + + mast_node_infos + }; + + let strings: Vec = Deserializable::read_from(source)?; + + let data: Vec = Deserializable::read_from(source)?; + + let nodes = { + let mut nodes = Vec::with_capacity(node_count); + for mast_node_info in mast_node_infos { + let node = try_info_to_mast_node(mast_node_info, &data, &strings)?; + nodes.push(node); + } + + nodes + }; + + Ok(Self { nodes, roots }) + } +} + +fn mast_node_to_info( mast_node: &MastNode, data: &mut Vec, strings: &mut Vec, @@ -233,6 +310,14 @@ fn convert_mast_node( todo!() } +fn try_info_to_mast_node( + mast_node_info: MastNodeInfo, + data: &[u8], + strings: &[StringRef], +) -> Result { + todo!() +} + #[cfg(test)] mod tests { use super::*; From 4be1401740b8e7a98ac0c1dff37e323bc10651bc Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 12:44:06 -0400 Subject: [PATCH 045/172] mast_node_to_info() scaffold --- core/src/mast/node/call_node.rs | 1 + core/src/mast/serialization/mod.rs | 32 ++++++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index c2183ea84d..4c66dd0e06 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -7,6 +7,7 @@ use crate::{chiplets::hasher, Operation}; use crate::mast::{MastForest, MastNodeId, MerkleTreeNode}; +// TODOP: `callee` must be a digest, #[derive(Debug, Clone, PartialEq, Eq)] pub struct CallNode { callee: MastNodeId, diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 3c207dd213..597e350d9f 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -5,6 +5,8 @@ use num_traits::{FromPrimitive, ToPrimitive}; use thiserror::Error; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use crate::mast::MerkleTreeNode; + use super::{MastForest, MastNode, MastNodeId}; /// Specifies an offset into the `data` section of an encoded [`MastForest`]. @@ -107,7 +109,22 @@ impl MastNodeType { Self([discriminant << 4, body_byte1, body_byte2, body_byte3, body_byte4, 0, 0, 0]) } - Call(_) | Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), + Call(call_node) => { + let [callee_byte1, callee_byte2, callee_byte3, callee_byte4] = + call_node.callee().0.to_be_bytes(); + + Self([ + discriminant << 4, + callee_byte1, + callee_byte2, + callee_byte3, + callee_byte4, + 0, + 0, + 0, + ]) + } + Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), } } @@ -304,10 +321,17 @@ fn mast_node_to_info( data: &mut Vec, strings: &mut Vec, ) -> MastNodeInfo { - // mast node info + use MastNode::*; - // fill out encoded operations/decorators in data - todo!() + let ty = MastNodeType::new(mast_node); + let digest = mast_node.digest(); + + let offset = match mast_node { + Block(_) => todo!(), + Join(_) | Split(_) | Loop(_) | Call(_) | Dyn | External(_) => 0, + }; + + MastNodeInfo { ty, offset, digest } } fn try_info_to_mast_node( From 2695062ddd0f952b4384e264ab18b1497aabdbfa Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 15:08:14 -0400 Subject: [PATCH 046/172] try_info_to_mast_node scaffold --- core/src/mast/node/mod.rs | 1 + core/src/mast/serialization/mod.rs | 165 ++++++++++++++++++++++++----- 2 files changed, 138 insertions(+), 28 deletions(-) diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 2bf0836cf3..65ef108726 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -88,6 +88,7 @@ impl MastNode { Self::Dyn } + // TODOP: removed, since unused? pub fn new_dyncall(dyn_node_id: MastNodeId, mast_forest: &MastForest) -> Self { Self::Call(CallNode::new(dyn_node_id, mast_forest)) } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 597e350d9f..ee0d10526d 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,4 +1,7 @@ -use alloc::{string::String, vec::Vec}; +use alloc::{ + string::{String, ToString}, + vec::Vec, +}; use miden_crypto::hash::rpo::RpoDigest; use num_derive::{FromPrimitive, ToPrimitive}; use num_traits::{FromPrimitive, ToPrimitive}; @@ -82,6 +85,7 @@ impl Deserializable for MastNodeInfo { // TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant pub struct MastNodeType([u8; 8]); +/// Constructors impl MastNodeType { pub fn new(mast_node: &MastNode) -> Self { use MastNode::*; @@ -91,9 +95,9 @@ impl MastNodeType { match mast_node { Block(block_node) => { - let num_ops = block_node.num_operations_and_decorators().to_be_bytes(); + let num_ops = block_node.num_operations_and_decorators(); - Self([discriminant << 4, num_ops[0], num_ops[1], num_ops[2], num_ops[3], 0, 0, 0]) + Self::encode_u32_payload(discriminant, num_ops) } Join(join_node) => { Self::encode_join_or_split(discriminant, join_node.first(), join_node.second()) @@ -104,30 +108,31 @@ impl MastNodeType { split_node.on_false(), ), Loop(loop_node) => { - let [body_byte1, body_byte2, body_byte3, body_byte4] = - loop_node.body().0.to_be_bytes(); + let child_id = loop_node.body().0; - Self([discriminant << 4, body_byte1, body_byte2, body_byte3, body_byte4, 0, 0, 0]) + Self::encode_u32_payload(discriminant, child_id) } Call(call_node) => { - let [callee_byte1, callee_byte2, callee_byte3, callee_byte4] = - call_node.callee().0.to_be_bytes(); - - Self([ - discriminant << 4, - callee_byte1, - callee_byte2, - callee_byte3, - callee_byte4, - 0, - 0, - 0, - ]) + let child_id = call_node.callee().0; + + Self::encode_u32_payload(discriminant, child_id) } Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), } } +} + +/// Accessors +impl MastNodeType { + pub fn variant(&self) -> Result { + let discriminant = self.0[0] >> 4; + MastNodeTypeVariant::try_from_discriminant(discriminant) + } +} + +/// Helpers +impl MastNodeType { // TODOP: Make a diagram of how the bits are split fn encode_join_or_split( discriminant: u8, @@ -176,6 +181,60 @@ impl MastNodeType { Self(result) } + + fn decode_join_or_split(&self) -> (MastNodeId, MastNodeId) { + let first = { + let mut first_le_bytes = [0_u8; 4]; + + first_le_bytes[0] = self.0[0] << 4; + first_le_bytes[0] |= self.0[1] >> 4; + + first_le_bytes[1] = self.0[1] << 4; + first_le_bytes[1] |= self.0[2] >> 4; + + first_le_bytes[2] = self.0[2] << 4; + first_le_bytes[2] |= self.0[3] >> 4; + + first_le_bytes[3] = (self.0[3] & 0b1111) << 2; + first_le_bytes[3] |= self.0[4] >> 6; + + u32::from_le_bytes(first_le_bytes) + }; + + let second = { + let mut second_be_bytes = [0_u8; 4]; + + second_be_bytes[0] = self.0[4] & 0b0011_1111; + second_be_bytes[1] = self.0[5]; + second_be_bytes[2] = self.0[6]; + second_be_bytes[3] = self.0[7]; + + u32::from_be_bytes(second_be_bytes) + }; + + (MastNodeId(first), MastNodeId(second)) + } + + fn encode_u32_payload(discriminant: u8, payload: u32) -> Self { + let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); + + Self([ + discriminant << 4, + payload_byte1, + payload_byte2, + payload_byte3, + payload_byte4, + 0, + 0, + 0, + ]) + } + + fn decode_u32_payload(&self) -> u32 { + let payload_be_bytes = [self.0[1], self.0[2], self.0[3], self.0[4]]; + + u32::from_be_bytes(payload_be_bytes) + } } impl Serializable for MastNodeType { @@ -302,17 +361,22 @@ impl Deserializable for MastForest { let data: Vec = Deserializable::read_from(source)?; - let nodes = { - let mut nodes = Vec::with_capacity(node_count); + let mast_forest = { + let mut mast_forest = MastForest::new(); + for mast_node_info in mast_node_infos { - let node = try_info_to_mast_node(mast_node_info, &data, &strings)?; - nodes.push(node); + let node = try_info_to_mast_node(mast_node_info, &mast_forest, &data, &strings)?; + mast_forest.add_node(node); + } + + for root in roots { + mast_forest.make_root(root); } - nodes + mast_forest }; - Ok(Self { nodes, roots }) + Ok(mast_forest) } } @@ -336,10 +400,47 @@ fn mast_node_to_info( fn try_info_to_mast_node( mast_node_info: MastNodeInfo, + mast_forest: &MastForest, data: &[u8], strings: &[StringRef], ) -> Result { - todo!() + let mast_node_variant = mast_node_info + .ty + .variant() + .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; + + // TODOP: Make a faillible version of `MastNode` ctors + // TODOP: Check digest of resulting `MastNode` matches `MastNodeInfo.digest`? + match mast_node_variant { + MastNodeTypeVariant::Block => todo!(), + MastNodeTypeVariant::Join => { + let (left_child, right_child) = MastNodeType::decode_join_or_split(&mast_node_info.ty); + + Ok(MastNode::new_join(left_child, right_child, mast_forest)) + } + MastNodeTypeVariant::Split => { + let (if_branch, else_branch) = MastNodeType::decode_join_or_split(&mast_node_info.ty); + + Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) + } + MastNodeTypeVariant::Loop => { + let body_id = MastNodeType::decode_u32_payload(&mast_node_info.ty); + + Ok(MastNode::new_loop(MastNodeId(body_id), mast_forest)) + } + MastNodeTypeVariant::Call => { + let callee_id = MastNodeType::decode_u32_payload(&mast_node_info.ty); + + Ok(MastNode::new_call(MastNodeId(callee_id), mast_forest)) + } + MastNodeTypeVariant::Syscall => { + let callee_id = MastNodeType::decode_u32_payload(&mast_node_info.ty); + + Ok(MastNode::new_syscall(MastNodeId(callee_id), mast_forest)) + } + MastNodeTypeVariant::Dyn => Ok(MastNode::new_dynexec()), + MastNodeTypeVariant::External => Ok(MastNode::new_external(mast_node_info.digest)), + } } #[cfg(test)] @@ -348,7 +449,7 @@ mod tests { use crate::mast::{JoinNode, SplitNode}; #[test] - fn mast_node_type_serialization_join() { + fn mast_node_type_serde_join() { let left_child_id = MastNodeId(0b00111001_11101011_01101100_11011000); let right_child_id = MastNodeId(0b00100111_10101010_11111111_11001110); let mast_node = MastNode::Join(JoinNode::new_test( @@ -365,10 +466,14 @@ mod tests { ]; assert_eq!(expected_mast_node_type, mast_node_type.0); + + let (decoded_left, decoded_right) = mast_node_type.decode_join_or_split(); + assert_eq!(left_child_id, decoded_left); + assert_eq!(right_child_id, decoded_right); } #[test] - fn mast_node_type_serialization_split() { + fn mast_node_type_serde_split() { let on_true_id = MastNodeId(0b00111001_11101011_01101100_11011000); let on_false_id = MastNodeId(0b00100111_10101010_11111111_11001110); let mast_node = @@ -383,6 +488,10 @@ mod tests { ]; assert_eq!(expected_mast_node_type, mast_node_type.0); + + let (decoded_on_true, decoded_on_false) = mast_node_type.decode_join_or_split(); + assert_eq!(on_true_id, decoded_on_true); + assert_eq!(on_false_id, decoded_on_false); } // TODOP: Test all other variants From ef0a88185de2f2b866311ad9c40a9c94e3eecd10 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 15:39:54 -0400 Subject: [PATCH 047/172] Rename `EncodedMastNodeType` --- core/src/mast/serialization/mod.rs | 32 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index ee0d10526d..d10604e795 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -59,7 +59,7 @@ impl Deserializable for StringRef { } pub struct MastNodeInfo { - ty: MastNodeType, + ty: EncodedMastNodeType, offset: DataOffset, digest: RpoDigest, } @@ -83,10 +83,10 @@ impl Deserializable for MastNodeInfo { } // TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant -pub struct MastNodeType([u8; 8]); +pub struct EncodedMastNodeType([u8; 8]); /// Constructors -impl MastNodeType { +impl EncodedMastNodeType { pub fn new(mast_node: &MastNode) -> Self { use MastNode::*; @@ -123,7 +123,7 @@ impl MastNodeType { } /// Accessors -impl MastNodeType { +impl EncodedMastNodeType { pub fn variant(&self) -> Result { let discriminant = self.0[0] >> 4; @@ -132,7 +132,7 @@ impl MastNodeType { } /// Helpers -impl MastNodeType { +impl EncodedMastNodeType { // TODOP: Make a diagram of how the bits are split fn encode_join_or_split( discriminant: u8, @@ -237,13 +237,13 @@ impl MastNodeType { } } -impl Serializable for MastNodeType { +impl Serializable for EncodedMastNodeType { fn write_into(&self, target: &mut W) { self.0.write_into(target); } } -impl Deserializable for MastNodeType { +impl Deserializable for EncodedMastNodeType { fn read_from(source: &mut R) -> Result { let bytes = source.read_array()?; @@ -387,7 +387,7 @@ fn mast_node_to_info( ) -> MastNodeInfo { use MastNode::*; - let ty = MastNodeType::new(mast_node); + let ty = EncodedMastNodeType::new(mast_node); let digest = mast_node.digest(); let offset = match mast_node { @@ -414,27 +414,29 @@ fn try_info_to_mast_node( match mast_node_variant { MastNodeTypeVariant::Block => todo!(), MastNodeTypeVariant::Join => { - let (left_child, right_child) = MastNodeType::decode_join_or_split(&mast_node_info.ty); + let (left_child, right_child) = + EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty); Ok(MastNode::new_join(left_child, right_child, mast_forest)) } MastNodeTypeVariant::Split => { - let (if_branch, else_branch) = MastNodeType::decode_join_or_split(&mast_node_info.ty); + let (if_branch, else_branch) = + EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty); Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) } MastNodeTypeVariant::Loop => { - let body_id = MastNodeType::decode_u32_payload(&mast_node_info.ty); + let body_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); Ok(MastNode::new_loop(MastNodeId(body_id), mast_forest)) } MastNodeTypeVariant::Call => { - let callee_id = MastNodeType::decode_u32_payload(&mast_node_info.ty); + let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); Ok(MastNode::new_call(MastNodeId(callee_id), mast_forest)) } MastNodeTypeVariant::Syscall => { - let callee_id = MastNodeType::decode_u32_payload(&mast_node_info.ty); + let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); Ok(MastNode::new_syscall(MastNodeId(callee_id), mast_forest)) } @@ -457,7 +459,7 @@ mod tests { RpoDigest::default(), )); - let mast_node_type = MastNodeType::new(&mast_node); + let mast_node_type = EncodedMastNodeType::new(&mast_node); // Note: Join's discriminant is 0 let expected_mast_node_type = [ @@ -479,7 +481,7 @@ mod tests { let mast_node = MastNode::Split(SplitNode::new_test([on_true_id, on_false_id], RpoDigest::default())); - let mast_node_type = MastNodeType::new(&mast_node); + let mast_node_type = EncodedMastNodeType::new(&mast_node); // Note: Split's discriminant is 0 let expected_mast_node_type = [ From dd89461fc07e433787ceaafb0ef978a4b25fe3e7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 15:50:41 -0400 Subject: [PATCH 048/172] add info module --- core/src/mast/serialization/info.rs | 299 ++++++++++++++++++++++++++++ core/src/mast/serialization/mod.rs | 298 +-------------------------- 2 files changed, 303 insertions(+), 294 deletions(-) create mode 100644 core/src/mast/serialization/info.rs diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs new file mode 100644 index 0000000000..82c42f8dc7 --- /dev/null +++ b/core/src/mast/serialization/info.rs @@ -0,0 +1,299 @@ +use miden_crypto::hash::rpo::RpoDigest; +use num_derive::{FromPrimitive, ToPrimitive}; +use num_traits::{FromPrimitive, ToPrimitive}; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use crate::mast::{MastNode, MastNodeId}; + +use super::DataOffset; + +pub struct MastNodeInfo { + pub(super) ty: EncodedMastNodeType, + pub(super) offset: DataOffset, + pub(super) digest: RpoDigest, +} + +impl Serializable for MastNodeInfo { + fn write_into(&self, target: &mut W) { + self.ty.write_into(target); + self.offset.write_into(target); + self.digest.write_into(target); + } +} + +impl Deserializable for MastNodeInfo { + fn read_from(source: &mut R) -> Result { + let ty = Deserializable::read_from(source)?; + let offset = DataOffset::read_from(source)?; + let digest = RpoDigest::read_from(source)?; + + Ok(Self { ty, offset, digest }) + } +} + +// TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant +pub struct EncodedMastNodeType(pub(super) [u8; 8]); + +/// Constructors +impl EncodedMastNodeType { + pub fn new(mast_node: &MastNode) -> Self { + use MastNode::*; + + let discriminant = MastNodeTypeVariant::from_mast_node(mast_node).discriminant(); + assert!(discriminant < 2_u8.pow(4_u32)); + + match mast_node { + Block(block_node) => { + let num_ops = block_node.num_operations_and_decorators(); + + Self::encode_u32_payload(discriminant, num_ops) + } + Join(join_node) => { + Self::encode_join_or_split(discriminant, join_node.first(), join_node.second()) + } + Split(split_node) => Self::encode_join_or_split( + discriminant, + split_node.on_true(), + split_node.on_false(), + ), + Loop(loop_node) => { + let child_id = loop_node.body().0; + + Self::encode_u32_payload(discriminant, child_id) + } + Call(call_node) => { + let child_id = call_node.callee().0; + + Self::encode_u32_payload(discriminant, child_id) + } + Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), + } + } +} + +/// Accessors +impl EncodedMastNodeType { + pub fn variant(&self) -> Result { + let discriminant = self.0[0] >> 4; + + MastNodeTypeVariant::try_from_discriminant(discriminant) + } +} + +/// Helpers +impl EncodedMastNodeType { + // TODOP: Make a diagram of how the bits are split + pub fn encode_join_or_split( + discriminant: u8, + left_child_id: MastNodeId, + right_child_id: MastNodeId, + ) -> Self { + assert!(left_child_id.0 < 2_u32.pow(30)); + assert!(right_child_id.0 < 2_u32.pow(30)); + + let mut result: [u8; 8] = [0_u8; 8]; + + result[0] = discriminant << 4; + + // write left child into result + { + let [lsb, a, b, msb] = left_child_id.0.to_le_bytes(); + result[0] |= lsb >> 4; + result[1] |= lsb << 4; + result[1] |= a >> 4; + result[2] |= a << 4; + result[2] |= b >> 4; + result[3] |= b << 4; + + // msb is different from lsb, a and b since its 2 most significant bits are guaranteed + // to be 0, and hence not encoded. + // + // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in + // `result[3]`, and `ef` as the most significant bits of `result[4]`. + result[3] |= msb >> 2; + result[4] |= msb << 6; + }; + + // write right child into result + { + // Recall that `result[4]` contains 2 bits from the left child id in the most + // significant bits. Also, the most significant byte of the right child is guaranteed to + // fit in 6 bits. Hence, we use big endian format for the right child id to simplify + // encoding and decoding. + let [msb, a, b, lsb] = right_child_id.0.to_be_bytes(); + + result[4] |= msb; + result[5] = a; + result[6] = b; + result[7] = lsb; + }; + + Self(result) + } + + pub fn decode_join_or_split(&self) -> (MastNodeId, MastNodeId) { + let first = { + let mut first_le_bytes = [0_u8; 4]; + + first_le_bytes[0] = self.0[0] << 4; + first_le_bytes[0] |= self.0[1] >> 4; + + first_le_bytes[1] = self.0[1] << 4; + first_le_bytes[1] |= self.0[2] >> 4; + + first_le_bytes[2] = self.0[2] << 4; + first_le_bytes[2] |= self.0[3] >> 4; + + first_le_bytes[3] = (self.0[3] & 0b1111) << 2; + first_le_bytes[3] |= self.0[4] >> 6; + + u32::from_le_bytes(first_le_bytes) + }; + + let second = { + let mut second_be_bytes = [0_u8; 4]; + + second_be_bytes[0] = self.0[4] & 0b0011_1111; + second_be_bytes[1] = self.0[5]; + second_be_bytes[2] = self.0[6]; + second_be_bytes[3] = self.0[7]; + + u32::from_be_bytes(second_be_bytes) + }; + + (MastNodeId(first), MastNodeId(second)) + } + + pub fn encode_u32_payload(discriminant: u8, payload: u32) -> Self { + let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); + + Self([ + discriminant << 4, + payload_byte1, + payload_byte2, + payload_byte3, + payload_byte4, + 0, + 0, + 0, + ]) + } + + pub fn decode_u32_payload(&self) -> u32 { + let payload_be_bytes = [self.0[1], self.0[2], self.0[3], self.0[4]]; + + u32::from_be_bytes(payload_be_bytes) + } +} + +impl Serializable for EncodedMastNodeType { + fn write_into(&self, target: &mut W) { + self.0.write_into(target); + } +} + +impl Deserializable for EncodedMastNodeType { + fn read_from(source: &mut R) -> Result { + let bytes = source.read_array()?; + + Ok(Self(bytes)) + } +} + +#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)] +#[repr(u8)] +pub enum MastNodeTypeVariant { + Join, + Split, + Loop, + Call, + Syscall, + Dyn, + Block, + External, +} + +impl MastNodeTypeVariant { + pub fn discriminant(&self) -> u8 { + self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") + } + + pub fn try_from_discriminant(discriminant: u8) -> Result { + Self::from_u8(discriminant).ok_or_else(|| super::Error::InvalidDiscriminant { + ty: "MastNode".into(), + discriminant, + }) + } + + pub fn from_mast_node(mast_node: &MastNode) -> Self { + match mast_node { + MastNode::Block(_) => Self::Block, + MastNode::Join(_) => Self::Join, + MastNode::Split(_) => Self::Split, + MastNode::Loop(_) => Self::Loop, + MastNode::Call(call_node) => { + if call_node.is_syscall() { + Self::Syscall + } else { + Self::Call + } + } + MastNode::Dyn => Self::Dyn, + MastNode::External(_) => Self::External, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mast::{JoinNode, SplitNode}; + + #[test] + fn mast_node_type_serde_join() { + let left_child_id = MastNodeId(0b00111001_11101011_01101100_11011000); + let right_child_id = MastNodeId(0b00100111_10101010_11111111_11001110); + let mast_node = MastNode::Join(JoinNode::new_test( + [left_child_id, right_child_id], + RpoDigest::default(), + )); + + let mast_node_type = EncodedMastNodeType::new(&mast_node); + + // Note: Join's discriminant is 0 + let expected_mast_node_type = [ + 0b00001101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, + 0b11001110, + ]; + + assert_eq!(expected_mast_node_type, mast_node_type.0); + + let (decoded_left, decoded_right) = mast_node_type.decode_join_or_split(); + assert_eq!(left_child_id, decoded_left); + assert_eq!(right_child_id, decoded_right); + } + + #[test] + fn mast_node_type_serde_split() { + let on_true_id = MastNodeId(0b00111001_11101011_01101100_11011000); + let on_false_id = MastNodeId(0b00100111_10101010_11111111_11001110); + let mast_node = + MastNode::Split(SplitNode::new_test([on_true_id, on_false_id], RpoDigest::default())); + + let mast_node_type = EncodedMastNodeType::new(&mast_node); + + // Note: Split's discriminant is 0 + let expected_mast_node_type = [ + 0b00011101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, + 0b11001110, + ]; + + assert_eq!(expected_mast_node_type, mast_node_type.0); + + let (decoded_on_true, decoded_on_false) = mast_node_type.decode_join_or_split(); + assert_eq!(on_true_id, decoded_on_true); + assert_eq!(on_false_id, decoded_on_false); + } + + // TODOP: Test all other variants +} diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index d10604e795..b7bbd78eed 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -2,9 +2,6 @@ use alloc::{ string::{String, ToString}, vec::Vec, }; -use miden_crypto::hash::rpo::RpoDigest; -use num_derive::{FromPrimitive, ToPrimitive}; -use num_traits::{FromPrimitive, ToPrimitive}; use thiserror::Error; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -12,6 +9,9 @@ use crate::mast::MerkleTreeNode; use super::{MastForest, MastNode, MastNodeId}; +mod info; +use info::{EncodedMastNodeType, MastNodeInfo, MastNodeTypeVariant}; + /// Specifies an offset into the `data` section of an encoded [`MastForest`]. type DataOffset = u32; @@ -25,6 +25,7 @@ const MAGIC: &[u8; 5] = b"MAST\0"; /// version field itself, but should be considered invalid for now. const VERSION: [u8; 3] = [0, 0, 0]; +// TODOP: move into info.rs? Make public? #[derive(Debug, Error)] pub enum Error { #[error("Invalid discriminant '{discriminant}' for type '{ty}'")] @@ -58,243 +59,6 @@ impl Deserializable for StringRef { } } -pub struct MastNodeInfo { - ty: EncodedMastNodeType, - offset: DataOffset, - digest: RpoDigest, -} - -impl Serializable for MastNodeInfo { - fn write_into(&self, target: &mut W) { - self.ty.write_into(target); - self.offset.write_into(target); - self.digest.write_into(target); - } -} - -impl Deserializable for MastNodeInfo { - fn read_from(source: &mut R) -> Result { - let ty = Deserializable::read_from(source)?; - let offset = DataOffset::read_from(source)?; - let digest = RpoDigest::read_from(source)?; - - Ok(Self { ty, offset, digest }) - } -} - -// TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant -pub struct EncodedMastNodeType([u8; 8]); - -/// Constructors -impl EncodedMastNodeType { - pub fn new(mast_node: &MastNode) -> Self { - use MastNode::*; - - let discriminant = MastNodeTypeVariant::from_mast_node(mast_node).discriminant(); - assert!(discriminant < 2_u8.pow(4_u32)); - - match mast_node { - Block(block_node) => { - let num_ops = block_node.num_operations_and_decorators(); - - Self::encode_u32_payload(discriminant, num_ops) - } - Join(join_node) => { - Self::encode_join_or_split(discriminant, join_node.first(), join_node.second()) - } - Split(split_node) => Self::encode_join_or_split( - discriminant, - split_node.on_true(), - split_node.on_false(), - ), - Loop(loop_node) => { - let child_id = loop_node.body().0; - - Self::encode_u32_payload(discriminant, child_id) - } - Call(call_node) => { - let child_id = call_node.callee().0; - - Self::encode_u32_payload(discriminant, child_id) - } - Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), - } - } -} - -/// Accessors -impl EncodedMastNodeType { - pub fn variant(&self) -> Result { - let discriminant = self.0[0] >> 4; - - MastNodeTypeVariant::try_from_discriminant(discriminant) - } -} - -/// Helpers -impl EncodedMastNodeType { - // TODOP: Make a diagram of how the bits are split - fn encode_join_or_split( - discriminant: u8, - left_child_id: MastNodeId, - right_child_id: MastNodeId, - ) -> Self { - assert!(left_child_id.0 < 2_u32.pow(30)); - assert!(right_child_id.0 < 2_u32.pow(30)); - - let mut result: [u8; 8] = [0_u8; 8]; - - result[0] = discriminant << 4; - - // write left child into result - { - let [lsb, a, b, msb] = left_child_id.0.to_le_bytes(); - result[0] |= lsb >> 4; - result[1] |= lsb << 4; - result[1] |= a >> 4; - result[2] |= a << 4; - result[2] |= b >> 4; - result[3] |= b << 4; - - // msb is different from lsb, a and b since its 2 most significant bits are guaranteed - // to be 0, and hence not encoded. - // - // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in - // `result[3]`, and `ef` as the most significant bits of `result[4]`. - result[3] |= msb >> 2; - result[4] |= msb << 6; - }; - - // write right child into result - { - // Recall that `result[4]` contains 2 bits from the left child id in the most - // significant bits. Also, the most significant byte of the right child is guaranteed to - // fit in 6 bits. Hence, we use big endian format for the right child id to simplify - // encoding and decoding. - let [msb, a, b, lsb] = right_child_id.0.to_be_bytes(); - - result[4] |= msb; - result[5] = a; - result[6] = b; - result[7] = lsb; - }; - - Self(result) - } - - fn decode_join_or_split(&self) -> (MastNodeId, MastNodeId) { - let first = { - let mut first_le_bytes = [0_u8; 4]; - - first_le_bytes[0] = self.0[0] << 4; - first_le_bytes[0] |= self.0[1] >> 4; - - first_le_bytes[1] = self.0[1] << 4; - first_le_bytes[1] |= self.0[2] >> 4; - - first_le_bytes[2] = self.0[2] << 4; - first_le_bytes[2] |= self.0[3] >> 4; - - first_le_bytes[3] = (self.0[3] & 0b1111) << 2; - first_le_bytes[3] |= self.0[4] >> 6; - - u32::from_le_bytes(first_le_bytes) - }; - - let second = { - let mut second_be_bytes = [0_u8; 4]; - - second_be_bytes[0] = self.0[4] & 0b0011_1111; - second_be_bytes[1] = self.0[5]; - second_be_bytes[2] = self.0[6]; - second_be_bytes[3] = self.0[7]; - - u32::from_be_bytes(second_be_bytes) - }; - - (MastNodeId(first), MastNodeId(second)) - } - - fn encode_u32_payload(discriminant: u8, payload: u32) -> Self { - let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); - - Self([ - discriminant << 4, - payload_byte1, - payload_byte2, - payload_byte3, - payload_byte4, - 0, - 0, - 0, - ]) - } - - fn decode_u32_payload(&self) -> u32 { - let payload_be_bytes = [self.0[1], self.0[2], self.0[3], self.0[4]]; - - u32::from_be_bytes(payload_be_bytes) - } -} - -impl Serializable for EncodedMastNodeType { - fn write_into(&self, target: &mut W) { - self.0.write_into(target); - } -} - -impl Deserializable for EncodedMastNodeType { - fn read_from(source: &mut R) -> Result { - let bytes = source.read_array()?; - - Ok(Self(bytes)) - } -} - -#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)] -#[repr(u8)] -pub enum MastNodeTypeVariant { - Join, - Split, - Loop, - Call, - Syscall, - Dyn, - Block, - External, -} - -impl MastNodeTypeVariant { - pub fn discriminant(&self) -> u8 { - self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") - } - - pub fn try_from_discriminant(discriminant: u8) -> Result { - Self::from_u8(discriminant).ok_or_else(|| Error::InvalidDiscriminant { - ty: "MastNode".into(), - discriminant, - }) - } - - pub fn from_mast_node(mast_node: &MastNode) -> Self { - match mast_node { - MastNode::Block(_) => Self::Block, - MastNode::Join(_) => Self::Join, - MastNode::Split(_) => Self::Split, - MastNode::Loop(_) => Self::Loop, - MastNode::Call(call_node) => { - if call_node.is_syscall() { - Self::Syscall - } else { - Self::Call - } - } - MastNode::Dyn => Self::Dyn, - MastNode::External(_) => Self::External, - } - } -} - impl Serializable for MastForest { fn write_into(&self, target: &mut W) { // TODOP: make sure padding is in accordance with Paul's docs @@ -444,57 +208,3 @@ fn try_info_to_mast_node( MastNodeTypeVariant::External => Ok(MastNode::new_external(mast_node_info.digest)), } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::mast::{JoinNode, SplitNode}; - - #[test] - fn mast_node_type_serde_join() { - let left_child_id = MastNodeId(0b00111001_11101011_01101100_11011000); - let right_child_id = MastNodeId(0b00100111_10101010_11111111_11001110); - let mast_node = MastNode::Join(JoinNode::new_test( - [left_child_id, right_child_id], - RpoDigest::default(), - )); - - let mast_node_type = EncodedMastNodeType::new(&mast_node); - - // Note: Join's discriminant is 0 - let expected_mast_node_type = [ - 0b00001101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, - 0b11001110, - ]; - - assert_eq!(expected_mast_node_type, mast_node_type.0); - - let (decoded_left, decoded_right) = mast_node_type.decode_join_or_split(); - assert_eq!(left_child_id, decoded_left); - assert_eq!(right_child_id, decoded_right); - } - - #[test] - fn mast_node_type_serde_split() { - let on_true_id = MastNodeId(0b00111001_11101011_01101100_11011000); - let on_false_id = MastNodeId(0b00100111_10101010_11111111_11001110); - let mast_node = - MastNode::Split(SplitNode::new_test([on_true_id, on_false_id], RpoDigest::default())); - - let mast_node_type = EncodedMastNodeType::new(&mast_node); - - // Note: Split's discriminant is 0 - let expected_mast_node_type = [ - 0b00011101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, - 0b11001110, - ]; - - assert_eq!(expected_mast_node_type, mast_node_type.0); - - let (decoded_on_true, decoded_on_false) = mast_node_type.decode_join_or_split(); - assert_eq!(on_true_id, decoded_on_true); - assert_eq!(on_false_id, decoded_on_false); - } - - // TODOP: Test all other variants -} From 91009b0f161fe4f61ba43a51f0e937dbfb46a1be Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 16:43:45 -0400 Subject: [PATCH 049/172] encode operations into `data` field --- core/src/mast/mod.rs | 2 +- core/src/mast/node/basic_block_node.rs | 15 ++- core/src/mast/node/mod.rs | 4 +- core/src/mast/serialization/mod.rs | 127 ++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 6 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index cea97d093d..de99e707c7 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -6,7 +6,7 @@ use miden_crypto::hash::rpo::RpoDigest; mod node; pub use node::{ get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastNode, - OpBatch, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, + OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node.rs index 56ca756c17..49840c6784 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node.rs @@ -8,7 +8,7 @@ use winter_utils::flatten_slice_elements; use crate::{ chiplets::hasher, mast::{MastForest, MerkleTreeNode}, - DecoratorIterator, DecoratorList, Operation, + Decorator, DecoratorIterator, DecoratorList, Operation, }; // CONSTANTS @@ -23,6 +23,12 @@ pub const BATCH_SIZE: usize = 8; // BASIC BLOCK NODE // ================================================================================================ +// TODOP: Document +pub enum OperationOrDecorator { + Operation(Operation), + Decorator(Decorator), +} + /// Block for a linear sequence of operations (i.e., no branching or loops). /// /// Executes its operations in order. Fails if any of the operations fails. @@ -115,6 +121,13 @@ impl BasicBlockNode { &self.op_batches } + /// Returns an iterator over all operations and decorator, in the order in which they appear in + /// the program. + pub fn iter(&self) -> impl Iterator { + // TODOP: implement + core::iter::empty() + } + /// Returns a [`DecoratorIterator`] which allows us to iterate through the decorator list of /// this basic block node while executing operation batches of this basic block node. pub fn decorator_iter(&self) -> DecoratorIterator { diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 65ef108726..16f55452be 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -3,8 +3,8 @@ use core::fmt; use alloc::{boxed::Box, vec::Vec}; pub use basic_block_node::{ - get_span_op_group_count, BasicBlockNode, OpBatch, BATCH_SIZE as OP_BATCH_SIZE, - GROUP_SIZE as OP_GROUP_SIZE, + get_span_op_group_count, BasicBlockNode, OpBatch, OperationOrDecorator, + BATCH_SIZE as OP_BATCH_SIZE, GROUP_SIZE as OP_GROUP_SIZE, }; mod call_node; diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index b7bbd78eed..e2f997d5a1 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -2,10 +2,14 @@ use alloc::{ string::{String, ToString}, vec::Vec, }; +use num_traits::ToBytes; use thiserror::Error; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::MerkleTreeNode; +use crate::{ + mast::{MerkleTreeNode, OperationOrDecorator}, + Operation, +}; use super::{MastForest, MastNode, MastNodeId}; @@ -155,13 +159,132 @@ fn mast_node_to_info( let digest = mast_node.digest(); let offset = match mast_node { - Block(_) => todo!(), + Block(basic_block) => { + let offset: u32 = data + .len() + .try_into() + .expect("MastForest serialization: data field larger than 2^32 bytes"); + + for op_or_decorator in basic_block.iter() { + match op_or_decorator { + OperationOrDecorator::Operation(operation) => encode_operation(operation, data), + OperationOrDecorator::Decorator(_) => { + // TODOP: Remember: you need to set the most significant bit to 1 + todo!() + } + } + } + + offset + } Join(_) | Split(_) | Loop(_) | Call(_) | Dyn | External(_) => 0, }; MastNodeInfo { ty, offset, digest } } +fn encode_operation(operation: &Operation, data: &mut Vec) { + data.push(operation.op_code()); + + // For operations that have extra data, encode it in `data`. + match operation { + Operation::Assert(value) | Operation::MpVerify(value) => { + data.extend_from_slice(&value.to_le_bytes()) + } + Operation::U32assert2(value) | Operation::Push(value) => { + data.extend_from_slice(&value.as_int().to_le_bytes()) + } + // Note: we explicitly write out all the operations so that whenever we make a modification + // to the `Operation` enum, we get a compile error here. This should help us remember to + // properly encode/decode each operation variant. + Operation::Noop + | Operation::FmpAdd + | Operation::FmpUpdate + | Operation::SDepth + | Operation::Caller + | Operation::Clk + | Operation::Join + | Operation::Split + | Operation::Loop + | Operation::Call + | Operation::Dyn + | Operation::SysCall + | Operation::Span + | Operation::End + | Operation::Repeat + | Operation::Respan + | Operation::Halt + | Operation::Add + | Operation::Neg + | Operation::Mul + | Operation::Inv + | Operation::Incr + | Operation::And + | Operation::Or + | Operation::Not + | Operation::Eq + | Operation::Eqz + | Operation::Expacc + | Operation::Ext2Mul + | Operation::U32split + | Operation::U32add + | Operation::U32add3 + | Operation::U32sub + | Operation::U32mul + | Operation::U32madd + | Operation::U32div + | Operation::U32and + | Operation::U32xor + | Operation::Pad + | Operation::Drop + | Operation::Dup0 + | Operation::Dup1 + | Operation::Dup2 + | Operation::Dup3 + | Operation::Dup4 + | Operation::Dup5 + | Operation::Dup6 + | Operation::Dup7 + | Operation::Dup9 + | Operation::Dup11 + | Operation::Dup13 + | Operation::Dup15 + | Operation::Swap + | Operation::SwapW + | Operation::SwapW2 + | Operation::SwapW3 + | Operation::SwapDW + | Operation::MovUp2 + | Operation::MovUp3 + | Operation::MovUp4 + | Operation::MovUp5 + | Operation::MovUp6 + | Operation::MovUp7 + | Operation::MovUp8 + | Operation::MovDn2 + | Operation::MovDn3 + | Operation::MovDn4 + | Operation::MovDn5 + | Operation::MovDn6 + | Operation::MovDn7 + | Operation::MovDn8 + | Operation::CSwap + | Operation::CSwapW + | Operation::AdvPop + | Operation::AdvPopW + | Operation::MLoadW + | Operation::MStoreW + | Operation::MLoad + | Operation::MStore + | Operation::MStream + | Operation::Pipe + | Operation::HPerm + | Operation::MrUpdate + | Operation::FriE2F4 + | Operation::RCombBase => (), + } +} + fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, From d5ed1085dd5ef176db6f8fcae714fd30cbeda7df Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Mon, 1 Jul 2024 18:00:24 -0400 Subject: [PATCH 050/172] decode operations --- core/src/lib.rs | 2 +- core/src/mast/serialization/mod.rs | 86 ++++++++++++++++++++++++++++-- core/src/operations/mod.rs | 14 +++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/core/src/lib.rs b/core/src/lib.rs index 6422eb5725..a2d736c60a 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -102,7 +102,7 @@ pub mod prettier { mod operations; pub use operations::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorIterator, DecoratorList, - Operation, SignatureKind, + Operation, OperationData, SignatureKind, }; pub mod stack; diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index e2f997d5a1..16ee3ed86a 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -2,13 +2,16 @@ use alloc::{ string::{String, ToString}, vec::Vec, }; +use miden_crypto::{Felt, ZERO}; use num_traits::ToBytes; use thiserror::Error; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use winter_utils::{ + ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, +}; use crate::{ mast::{MerkleTreeNode, OperationOrDecorator}, - Operation, + DecoratorList, Operation, OperationData, }; use super::{MastForest, MastNode, MastNodeId}; @@ -128,12 +131,18 @@ impl Deserializable for MastForest { let strings: Vec = Deserializable::read_from(source)?; let data: Vec = Deserializable::read_from(source)?; + let mut data_reader = SliceReader::new(&data); let mast_forest = { let mut mast_forest = MastForest::new(); for mast_node_info in mast_node_infos { - let node = try_info_to_mast_node(mast_node_info, &mast_forest, &data, &strings)?; + let node = try_info_to_mast_node( + mast_node_info, + &mast_forest, + &mut data_reader, + &strings, + )?; mast_forest.add_node(node); } @@ -288,7 +297,7 @@ fn encode_operation(operation: &Operation, data: &mut Vec) { fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, - data: &[u8], + data_reader: &mut SliceReader, strings: &[StringRef], ) -> Result { let mast_node_variant = mast_node_info @@ -299,7 +308,18 @@ fn try_info_to_mast_node( // TODOP: Make a faillible version of `MastNode` ctors // TODOP: Check digest of resulting `MastNode` matches `MastNodeInfo.digest`? match mast_node_variant { - MastNodeTypeVariant::Block => todo!(), + MastNodeTypeVariant::Block => { + let num_operations_and_decorators = + EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + + let (operations, decorators) = decode_operations_and_decorators( + num_operations_and_decorators, + data_reader, + strings, + )?; + + Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) + } MastNodeTypeVariant::Join => { let (left_child, right_child) = EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty); @@ -331,3 +351,59 @@ fn try_info_to_mast_node( MastNodeTypeVariant::External => Ok(MastNode::new_external(mast_node_info.digest)), } } + +fn decode_operations_and_decorators( + num_to_decode: u32, + data_reader: &mut SliceReader, + strings: &[StringRef], +) -> Result<(Vec, DecoratorList), DeserializationError> { + let mut operations: Vec = Vec::new(); + let mut decorators: DecoratorList = Vec::new(); + + for _ in 0..num_to_decode { + let first_byte = data_reader.read_u8()?; + + if first_byte & 0b1000_0000 > 0 { + // operation. + let op_code = first_byte; + + let maybe_operation = if op_code == Operation::Assert(0_u32).op_code() + || op_code == Operation::MpVerify(0_u32).op_code() + { + let value_le_bytes: [u8; 4] = data_reader.read_array()?; + let value = u32::from_le_bytes(value_le_bytes); + + Operation::with_opcode_and_data(op_code, OperationData::U32(value)) + } else if op_code == Operation::U32assert2(ZERO).op_code() + || op_code == Operation::Push(ZERO).op_code() + { + // Felt operation data + let value_le_bytes: [u8; 8] = data_reader.read_array()?; + let value_u64 = u64::from_le_bytes(value_le_bytes); + let value_felt = Felt::try_from(value_u64).map_err(|_| { + DeserializationError::InvalidValue(format!( + "Operation associated data doesn't fit in a field element: {value_u64}" + )) + })?; + + Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt)) + } else { + // No operation data + Operation::with_opcode_and_data(op_code, OperationData::None) + }; + + let operation = maybe_operation.ok_or_else(|| { + DeserializationError::InvalidValue(format!("invalid op code: {op_code}")) + })?; + + operations.push(operation); + } else { + // decorator. + let discriminant = first_byte & 0b0111_1111; + + todo!() + } + } + + Ok((operations, decorators)) +} diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index a66bcec4db..d1cc95f273 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -441,6 +441,20 @@ pub enum Operation { RCombBase, } +pub enum OperationData { + Felt(Felt), + U32(u32), + None, +} + +/// Constructors +impl Operation { + // TODOP: document, and use `Result` instead? + pub fn with_opcode_and_data(opcode: u8, data: OperationData) -> Option { + todo!() + } +} + impl Operation { pub const OP_BITS: usize = 7; From 0cf49ffab8a780e7ed9ba4f095b90be7d1a39145 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 11:21:18 -0400 Subject: [PATCH 051/172] implement `BasicBlockNode::num_operations_and_decorators()` --- core/src/mast/node/basic_block_node.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node.rs index 49840c6784..a033a8379a 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node.rs @@ -114,7 +114,12 @@ impl BasicBlockNode { /// Public accessors impl BasicBlockNode { pub fn num_operations_and_decorators(&self) -> u32 { - todo!() + let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); + let num_decorators = self.decorators.len(); + + (num_ops + num_decorators) + .try_into() + .expect("basic block contains more than 2^32 operations and decorators") } pub fn op_batches(&self) -> &[OpBatch] { From 64c36ecf75d58b04048a1b94ce3a9f62262b44ef Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 13:37:20 -0400 Subject: [PATCH 052/172] OperationOrDecoratorIterator --- core/src/mast/node/basic_block_node.rs | 82 +++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 9 deletions(-) diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node.rs index a033a8379a..74e13b146e 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node.rs @@ -23,12 +23,6 @@ pub const BATCH_SIZE: usize = 8; // BASIC BLOCK NODE // ================================================================================================ -// TODOP: Document -pub enum OperationOrDecorator { - Operation(Operation), - Decorator(Decorator), -} - /// Block for a linear sequence of operations (i.e., no branching or loops). /// /// Executes its operations in order. Fails if any of the operations fails. @@ -128,9 +122,8 @@ impl BasicBlockNode { /// Returns an iterator over all operations and decorator, in the order in which they appear in /// the program. - pub fn iter(&self) -> impl Iterator { - // TODOP: implement - core::iter::empty() + pub fn iter(&self) -> impl Iterator { + OperationOrDecoratorIterator::new(self) } /// Returns a [`DecoratorIterator`] which allows us to iterate through the decorator list of @@ -224,6 +217,77 @@ impl fmt::Display for BasicBlockNode { } } +// OPERATION OR DECORATOR +// ================================================================================================ + +// TODOP: Document +pub enum OperationOrDecorator<'a> { + Operation(&'a Operation), + Decorator(&'a Decorator), +} + +struct OperationOrDecoratorIterator<'a> { + node: &'a BasicBlockNode, + + /// The index of the current batch + batch_index: usize, + + /// The index of the operation in the current batch + op_index_in_batch: usize, + + /// The index of the current operation across all batches + op_index: usize, + + /// The index of the next element in `node.decorator_list`. This list is assumed to be sorted. + decorator_list_next_index: usize, +} + +impl<'a> OperationOrDecoratorIterator<'a> { + fn new(node: &'a BasicBlockNode) -> Self { + Self { + node, + batch_index: 0, + op_index_in_batch: 0, + op_index: 0, + decorator_list_next_index: 0, + } + } +} + +impl<'a> Iterator for OperationOrDecoratorIterator<'a> { + type Item = OperationOrDecorator<'a>; + + fn next(&mut self) -> Option { + // check if there's a decorator to execute + if let Some((op_index, decorator)) = + self.node.decorators.get(self.decorator_list_next_index) + { + if *op_index == self.op_index { + self.decorator_list_next_index += 1; + return Some(OperationOrDecorator::Decorator(decorator)); + } + } + + // If no decorator needs to be executed, then execute the operation + if let Some(batch) = self.node.op_batches.get(self.batch_index) { + if let Some(operation) = batch.ops.get(self.op_index_in_batch) { + self.op_index_in_batch += 1; + self.op_index += 1; + + Some(OperationOrDecorator::Operation(operation)) + } else { + self.batch_index += 1; + self.op_index_in_batch = 0; + + self.next() + } + } else { + None + } + } +} + + // OPERATION BATCH // ================================================================================================ From 794ebbb2053957f43186d2115d8a3d7a59999330 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 13:39:49 -0400 Subject: [PATCH 053/172] basic block node: move tests in new file --- .../mod.rs} | 317 +----------------- core/src/mast/node/basic_block_node/tests.rs | 306 +++++++++++++++++ 2 files changed, 309 insertions(+), 314 deletions(-) rename core/src/mast/node/{basic_block_node.rs => basic_block_node/mod.rs} (64%) create mode 100644 core/src/mast/node/basic_block_node/tests.rs diff --git a/core/src/mast/node/basic_block_node.rs b/core/src/mast/node/basic_block_node/mod.rs similarity index 64% rename from core/src/mast/node/basic_block_node.rs rename to core/src/mast/node/basic_block_node/mod.rs index 74e13b146e..b0b8e47fa7 100644 --- a/core/src/mast/node/basic_block_node.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -11,6 +11,9 @@ use crate::{ Decorator, DecoratorIterator, DecoratorList, Operation, }; +#[cfg(test)] +mod tests; + // CONSTANTS // ================================================================================================ @@ -287,7 +290,6 @@ impl<'a> Iterator for OperationOrDecoratorIterator<'a> { } } - // OPERATION BATCH // ================================================================================================ @@ -512,316 +514,3 @@ pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize { let last_batch_num_groups = op_batches.last().expect("no last group").num_groups(); (op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() } - -// TESTS -// ================================================================================================ - -#[cfg(test)] -mod tests { - use super::{hasher, Felt, Operation, BATCH_SIZE, ZERO}; - use crate::ONE; - - #[test] - fn batch_ops() { - // --- one operation ---------------------------------------------------------------------- - let ops = vec![Operation::Add]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(1, batch.num_groups()); - - let mut batch_groups = [ZERO; BATCH_SIZE]; - batch_groups[0] = build_group(&ops); - - assert_eq!(batch_groups, batch.groups); - assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- two operations --------------------------------------------------------------------- - let ops = vec![Operation::Add, Operation::Mul]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(1, batch.num_groups()); - - let mut batch_groups = [ZERO; BATCH_SIZE]; - batch_groups[0] = build_group(&ops); - - assert_eq!(batch_groups, batch.groups); - assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- one group with one immediate value ------------------------------------------------- - let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(2, batch.num_groups()); - - let mut batch_groups = [ZERO; BATCH_SIZE]; - batch_groups[0] = build_group(&ops); - batch_groups[1] = Felt::new(12345678); - - assert_eq!(batch_groups, batch.groups); - assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- one group with 7 immediate values -------------------------------------------------- - let ops = vec![ - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - Operation::Push(Felt::new(3)), - Operation::Push(Felt::new(4)), - Operation::Push(Felt::new(5)), - Operation::Push(Felt::new(6)), - Operation::Push(Felt::new(7)), - Operation::Add, - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(8, batch.num_groups()); - - let batch_groups = [ - build_group(&ops), - ONE, - Felt::new(2), - Felt::new(3), - Felt::new(4), - Felt::new(5), - Felt::new(6), - Felt::new(7), - ]; - - assert_eq!(batch_groups, batch.groups); - assert_eq!([8_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- two groups with 7 immediate values; the last push overflows to the second batch ---- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - Operation::Push(Felt::new(3)), - Operation::Push(Felt::new(4)), - Operation::Push(Felt::new(5)), - Operation::Push(Felt::new(6)), - Operation::Add, - Operation::Push(Felt::new(7)), - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(2, batches.len()); - - let batch0 = &batches[0]; - assert_eq!(ops[..9], batch0.ops); - assert_eq!(7, batch0.num_groups()); - - let batch0_groups = [ - build_group(&ops[..9]), - ONE, - Felt::new(2), - Felt::new(3), - Felt::new(4), - Felt::new(5), - Felt::new(6), - ZERO, - ]; - - assert_eq!(batch0_groups, batch0.groups); - assert_eq!([9_usize, 0, 0, 0, 0, 0, 0, 0], batch0.op_counts); - - let batch1 = &batches[1]; - assert_eq!(vec![ops[9]], batch1.ops); - assert_eq!(2, batch1.num_groups()); - - let mut batch1_groups = [ZERO; BATCH_SIZE]; - batch1_groups[0] = build_group(&[ops[9]]); - batch1_groups[1] = Felt::new(7); - - assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); - assert_eq!(batch1_groups, batch1.groups); - - let all_groups = [batch0_groups, batch1_groups].concat(); - assert_eq!(hasher::hash_elements(&all_groups), hash); - - // --- immediate values in-between groups ------------------------------------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Push(Felt::new(7)), - Operation::Add, - Operation::Add, - Operation::Push(Felt::new(11)), - Operation::Mul, - Operation::Mul, - Operation::Add, - ]; - - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(4, batch.num_groups()); - - let batch_groups = [ - build_group(&ops[..9]), - Felt::new(7), - Felt::new(11), - build_group(&ops[9..]), - ZERO, - ZERO, - ZERO, - ZERO, - ]; - - assert_eq!([9_usize, 0, 0, 1, 0, 0, 0, 0], batch.op_counts); - assert_eq!(batch_groups, batch.groups); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- push at the end of a group is moved into the next group ---------------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Add, - Operation::Add, - Operation::Mul, - Operation::Mul, - Operation::Add, - Operation::Push(Felt::new(11)), - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(3, batch.num_groups()); - - let batch_groups = [ - build_group(&ops[..8]), - build_group(&[ops[8]]), - Felt::new(11), - ZERO, - ZERO, - ZERO, - ZERO, - ZERO, - ]; - - assert_eq!(batch_groups, batch.groups); - assert_eq!([8_usize, 1, 0, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- push at the end of a group is moved into the next group ---------------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Add, - Operation::Add, - Operation::Mul, - Operation::Mul, - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - ]; - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(1, batches.len()); - - let batch = &batches[0]; - assert_eq!(ops, batch.ops); - assert_eq!(4, batch.num_groups()); - - let batch_groups = [ - build_group(&ops[..8]), - ONE, - build_group(&[ops[8]]), - Felt::new(2), - ZERO, - ZERO, - ZERO, - ZERO, - ]; - - assert_eq!(batch_groups, batch.groups); - assert_eq!([8_usize, 0, 1, 0, 0, 0, 0, 0], batch.op_counts); - assert_eq!(hasher::hash_elements(&batch_groups), hash); - - // --- push at the end of the 7th group overflows to the next batch ----------------------- - let ops = vec![ - Operation::Add, - Operation::Mul, - Operation::Push(ONE), - Operation::Push(Felt::new(2)), - Operation::Push(Felt::new(3)), - Operation::Push(Felt::new(4)), - Operation::Push(Felt::new(5)), - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Add, - Operation::Mul, - Operation::Push(Felt::new(6)), - Operation::Pad, - ]; - - let (batches, hash) = super::batch_ops(ops.clone()); - assert_eq!(2, batches.len()); - - let batch0 = &batches[0]; - assert_eq!(ops[..17], batch0.ops); - assert_eq!(7, batch0.num_groups()); - - let batch0_groups = [ - build_group(&ops[..9]), - ONE, - Felt::new(2), - Felt::new(3), - Felt::new(4), - Felt::new(5), - build_group(&ops[9..17]), - ZERO, - ]; - - assert_eq!(batch0_groups, batch0.groups); - assert_eq!([9_usize, 0, 0, 0, 0, 0, 8, 0], batch0.op_counts); - - let batch1 = &batches[1]; - assert_eq!(ops[17..], batch1.ops); - assert_eq!(2, batch1.num_groups()); - - let batch1_groups = - [build_group(&ops[17..]), Felt::new(6), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO]; - assert_eq!(batch1_groups, batch1.groups); - assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); - - let all_groups = [batch0_groups, batch1_groups].concat(); - assert_eq!(hasher::hash_elements(&all_groups), hash); - } - - // TEST HELPERS - // -------------------------------------------------------------------------------------------- - - fn build_group(ops: &[Operation]) -> Felt { - let mut group = 0u64; - for (i, op) in ops.iter().enumerate() { - group |= (op.op_code() as u64) << (Operation::OP_BITS * i); - } - Felt::new(group) - } -} diff --git a/core/src/mast/node/basic_block_node/tests.rs b/core/src/mast/node/basic_block_node/tests.rs new file mode 100644 index 0000000000..3285e8f401 --- /dev/null +++ b/core/src/mast/node/basic_block_node/tests.rs @@ -0,0 +1,306 @@ + +use super::{hasher, Felt, Operation, BATCH_SIZE, ZERO}; +use crate::ONE; + +#[test] +fn batch_ops() { + // --- one operation ---------------------------------------------------------------------- + let ops = vec![Operation::Add]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(1, batch.num_groups()); + + let mut batch_groups = [ZERO; BATCH_SIZE]; + batch_groups[0] = build_group(&ops); + + assert_eq!(batch_groups, batch.groups); + assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- two operations --------------------------------------------------------------------- + let ops = vec![Operation::Add, Operation::Mul]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(1, batch.num_groups()); + + let mut batch_groups = [ZERO; BATCH_SIZE]; + batch_groups[0] = build_group(&ops); + + assert_eq!(batch_groups, batch.groups); + assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- one group with one immediate value ------------------------------------------------- + let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(2, batch.num_groups()); + + let mut batch_groups = [ZERO; BATCH_SIZE]; + batch_groups[0] = build_group(&ops); + batch_groups[1] = Felt::new(12345678); + + assert_eq!(batch_groups, batch.groups); + assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- one group with 7 immediate values -------------------------------------------------- + let ops = vec![ + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + Operation::Push(Felt::new(3)), + Operation::Push(Felt::new(4)), + Operation::Push(Felt::new(5)), + Operation::Push(Felt::new(6)), + Operation::Push(Felt::new(7)), + Operation::Add, + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(8, batch.num_groups()); + + let batch_groups = [ + build_group(&ops), + ONE, + Felt::new(2), + Felt::new(3), + Felt::new(4), + Felt::new(5), + Felt::new(6), + Felt::new(7), + ]; + + assert_eq!(batch_groups, batch.groups); + assert_eq!([8_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- two groups with 7 immediate values; the last push overflows to the second batch ---- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + Operation::Push(Felt::new(3)), + Operation::Push(Felt::new(4)), + Operation::Push(Felt::new(5)), + Operation::Push(Felt::new(6)), + Operation::Add, + Operation::Push(Felt::new(7)), + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(2, batches.len()); + + let batch0 = &batches[0]; + assert_eq!(ops[..9], batch0.ops); + assert_eq!(7, batch0.num_groups()); + + let batch0_groups = [ + build_group(&ops[..9]), + ONE, + Felt::new(2), + Felt::new(3), + Felt::new(4), + Felt::new(5), + Felt::new(6), + ZERO, + ]; + + assert_eq!(batch0_groups, batch0.groups); + assert_eq!([9_usize, 0, 0, 0, 0, 0, 0, 0], batch0.op_counts); + + let batch1 = &batches[1]; + assert_eq!(vec![ops[9]], batch1.ops); + assert_eq!(2, batch1.num_groups()); + + let mut batch1_groups = [ZERO; BATCH_SIZE]; + batch1_groups[0] = build_group(&[ops[9]]); + batch1_groups[1] = Felt::new(7); + + assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); + assert_eq!(batch1_groups, batch1.groups); + + let all_groups = [batch0_groups, batch1_groups].concat(); + assert_eq!(hasher::hash_elements(&all_groups), hash); + + // --- immediate values in-between groups ------------------------------------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Push(Felt::new(7)), + Operation::Add, + Operation::Add, + Operation::Push(Felt::new(11)), + Operation::Mul, + Operation::Mul, + Operation::Add, + ]; + + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(4, batch.num_groups()); + + let batch_groups = [ + build_group(&ops[..9]), + Felt::new(7), + Felt::new(11), + build_group(&ops[9..]), + ZERO, + ZERO, + ZERO, + ZERO, + ]; + + assert_eq!([9_usize, 0, 0, 1, 0, 0, 0, 0], batch.op_counts); + assert_eq!(batch_groups, batch.groups); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- push at the end of a group is moved into the next group ---------------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Add, + Operation::Add, + Operation::Mul, + Operation::Mul, + Operation::Add, + Operation::Push(Felt::new(11)), + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(3, batch.num_groups()); + + let batch_groups = [ + build_group(&ops[..8]), + build_group(&[ops[8]]), + Felt::new(11), + ZERO, + ZERO, + ZERO, + ZERO, + ZERO, + ]; + + assert_eq!(batch_groups, batch.groups); + assert_eq!([8_usize, 1, 0, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- push at the end of a group is moved into the next group ---------------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Add, + Operation::Add, + Operation::Mul, + Operation::Mul, + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + ]; + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(1, batches.len()); + + let batch = &batches[0]; + assert_eq!(ops, batch.ops); + assert_eq!(4, batch.num_groups()); + + let batch_groups = [ + build_group(&ops[..8]), + ONE, + build_group(&[ops[8]]), + Felt::new(2), + ZERO, + ZERO, + ZERO, + ZERO, + ]; + + assert_eq!(batch_groups, batch.groups); + assert_eq!([8_usize, 0, 1, 0, 0, 0, 0, 0], batch.op_counts); + assert_eq!(hasher::hash_elements(&batch_groups), hash); + + // --- push at the end of the 7th group overflows to the next batch ----------------------- + let ops = vec![ + Operation::Add, + Operation::Mul, + Operation::Push(ONE), + Operation::Push(Felt::new(2)), + Operation::Push(Felt::new(3)), + Operation::Push(Felt::new(4)), + Operation::Push(Felt::new(5)), + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Add, + Operation::Mul, + Operation::Push(Felt::new(6)), + Operation::Pad, + ]; + + let (batches, hash) = super::batch_ops(ops.clone()); + assert_eq!(2, batches.len()); + + let batch0 = &batches[0]; + assert_eq!(ops[..17], batch0.ops); + assert_eq!(7, batch0.num_groups()); + + let batch0_groups = [ + build_group(&ops[..9]), + ONE, + Felt::new(2), + Felt::new(3), + Felt::new(4), + Felt::new(5), + build_group(&ops[9..17]), + ZERO, + ]; + + assert_eq!(batch0_groups, batch0.groups); + assert_eq!([9_usize, 0, 0, 0, 0, 0, 8, 0], batch0.op_counts); + + let batch1 = &batches[1]; + assert_eq!(ops[17..], batch1.ops); + assert_eq!(2, batch1.num_groups()); + + let batch1_groups = [build_group(&ops[17..]), Felt::new(6), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO]; + assert_eq!(batch1_groups, batch1.groups); + assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts); + + let all_groups = [batch0_groups, batch1_groups].concat(); + assert_eq!(hasher::hash_elements(&all_groups), hash); +} + +// TEST HELPERS +// -------------------------------------------------------------------------------------------- + +fn build_group(ops: &[Operation]) -> Felt { + let mut group = 0u64; + for (i, op) in ops.iter().enumerate() { + group |= (op.op_code() as u64) << (Operation::OP_BITS * i); + } + Felt::new(group) +} From 49673aee5cb4accdcd59c8137d8c192475924976 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 13:55:04 -0400 Subject: [PATCH 054/172] operation_or_decorator_iterator test --- core/src/mast/node/basic_block_node/mod.rs | 1 + core/src/mast/node/basic_block_node/tests.rs | 41 ++++++++++++++++++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index b0b8e47fa7..66cc659998 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -224,6 +224,7 @@ impl fmt::Display for BasicBlockNode { // ================================================================================================ // TODOP: Document +#[derive(Clone, Debug, Eq, PartialEq)] pub enum OperationOrDecorator<'a> { Operation(&'a Operation), Decorator(&'a Decorator), diff --git a/core/src/mast/node/basic_block_node/tests.rs b/core/src/mast/node/basic_block_node/tests.rs index 3285e8f401..49a663ae33 100644 --- a/core/src/mast/node/basic_block_node/tests.rs +++ b/core/src/mast/node/basic_block_node/tests.rs @@ -1,6 +1,5 @@ - -use super::{hasher, Felt, Operation, BATCH_SIZE, ZERO}; -use crate::ONE; +use super::*; +use crate::{Decorator, ONE}; #[test] fn batch_ops() { @@ -294,6 +293,42 @@ fn batch_ops() { assert_eq!(hasher::hash_elements(&all_groups), hash); } +#[test] +fn operation_or_decorator_iterator() { + let operations = vec![Operation::Add, Operation::Mul, Operation::MovDn2, Operation::MovDn3]; + + // Note: there are 2 decorators after the last instruction + let decorators = vec![ + (0, Decorator::Event(0)), + (0, Decorator::Event(1)), + (3, Decorator::Event(2)), + (4, Decorator::Event(3)), + (4, Decorator::Event(4)), + ]; + + let node = BasicBlockNode::with_decorators(operations, decorators); + + let mut iterator = node.iter(); + + // operation index 0 + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(0)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(1)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::Add))); + + // operations indices 1, 2 + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::Mul))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::MovDn2))); + + // operation index 3 + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(2)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::MovDn3))); + + // after last operation + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(3)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(4)))); + assert_eq!(iterator.next(), None); +} + // TEST HELPERS // -------------------------------------------------------------------------------------------- From a5c324b2b15acfda5b983f3d8010e0784e324378 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 14:18:11 -0400 Subject: [PATCH 055/172] Implement `Operation::with_opcode_and_data()` --- core/src/operations/mod.rs | 119 ++++++++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index d1cc95f273..8321db1c43 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -451,7 +451,122 @@ pub enum OperationData { impl Operation { // TODOP: document, and use `Result` instead? pub fn with_opcode_and_data(opcode: u8, data: OperationData) -> Option { - todo!() + match opcode { + 0b0000_0000 => Some(Self::Noop), + 0b0000_0001 => Some(Self::Eqz), + 0b0000_0010 => Some(Self::Neg), + 0b0000_0011 => Some(Self::Inv), + 0b0000_0100 => Some(Self::Incr), + 0b0000_0101 => Some(Self::Not), + 0b0000_0110 => Some(Self::FmpAdd), + 0b0000_0111 => Some(Self::MLoad), + 0b0000_1000 => Some(Self::Swap), + 0b0000_1001 => Some(Self::Caller), + 0b0000_1010 => Some(Self::MovUp2), + 0b0000_1011 => Some(Self::MovDn2), + 0b0000_1100 => Some(Self::MovUp3), + 0b0000_1101 => Some(Self::MovDn3), + 0b0000_1110 => Some(Self::AdvPopW), + 0b0000_1111 => Some(Self::Expacc), + + 0b0001_0000 => Some(Self::MovUp4), + 0b0001_0001 => Some(Self::MovDn4), + 0b0001_0010 => Some(Self::MovUp5), + 0b0001_0011 => Some(Self::MovDn5), + 0b0001_0100 => Some(Self::MovUp6), + 0b0001_0101 => Some(Self::MovDn6), + 0b0001_0110 => Some(Self::MovUp7), + 0b0001_0111 => Some(Self::MovDn7), + 0b0001_1000 => Some(Self::SwapW), + 0b0001_1001 => Some(Self::Ext2Mul), + 0b0001_1010 => Some(Self::MovUp8), + 0b0001_1011 => Some(Self::MovDn8), + 0b0001_1100 => Some(Self::SwapW2), + 0b0001_1101 => Some(Self::SwapW3), + 0b0001_1110 => Some(Self::SwapDW), + // 0b0001_1111 => , + 0b0010_0000 => match data { + OperationData::U32(value) => Some(Self::Assert(value)), + _ => None, + }, + 0b0010_0001 => Some(Self::Eq), + 0b0010_0010 => Some(Self::Add), + 0b0010_0011 => Some(Self::Mul), + 0b0010_0100 => Some(Self::And), + 0b0010_0101 => Some(Self::Or), + 0b0010_0110 => Some(Self::U32and), + 0b0010_0111 => Some(Self::U32xor), + 0b0010_1000 => Some(Self::FriE2F4), + 0b0010_1001 => Some(Self::Drop), + 0b0010_1010 => Some(Self::CSwap), + 0b0010_1011 => Some(Self::CSwapW), + 0b0010_1100 => Some(Self::MLoadW), + 0b0010_1101 => Some(Self::MStore), + 0b0010_1110 => Some(Self::MStoreW), + 0b0010_1111 => Some(Self::FmpUpdate), + + 0b0011_0000 => Some(Self::Pad), + 0b0011_0001 => Some(Self::Dup0), + 0b0011_0010 => Some(Self::Dup1), + 0b0011_0011 => Some(Self::Dup2), + 0b0011_0100 => Some(Self::Dup3), + 0b0011_0101 => Some(Self::Dup4), + 0b0011_0110 => Some(Self::Dup5), + 0b0011_0111 => Some(Self::Dup6), + 0b0011_1000 => Some(Self::Dup7), + 0b0011_1001 => Some(Self::Dup9), + 0b0011_1010 => Some(Self::Dup11), + 0b0011_1011 => Some(Self::Dup13), + 0b0011_1100 => Some(Self::Dup15), + 0b0011_1101 => Some(Self::AdvPop), + 0b0011_1110 => Some(Self::SDepth), + 0b0011_1111 => Some(Self::Clk), + + 0b0100_0000 => Some(Self::U32add), + 0b0100_0010 => Some(Self::U32sub), + 0b0100_0100 => Some(Self::U32mul), + 0b0100_0110 => Some(Self::U32div), + 0b0100_1000 => Some(Self::U32split), + 0b0100_1010 => match data { + OperationData::Felt(value) => Some(Self::U32assert2(value)), + _ => None, + }, + 0b0100_1100 => Some(Self::U32add3), + 0b0100_1110 => Some(Self::U32madd), + + 0b0101_0000 => Some(Self::HPerm), + 0b0101_0001 => match data { + OperationData::U32(value) => Some(Self::MpVerify(value)), + _ => None, + }, + 0b0101_0010 => Some(Self::Pipe), + 0b0101_0011 => Some(Self::MStream), + 0b0101_0100 => Some(Self::Split), + 0b0101_0101 => Some(Self::Loop), + 0b0101_0110 => Some(Self::Span), + 0b0101_0111 => Some(Self::Join), + 0b0101_1000 => Some(Self::Dyn), + 0b0101_1001 => Some(Self::RCombBase), + // 0b0101_1010 => , + // 0b0101_1011 => , + // 0b0101_1100 => , + // 0b0101_1101 => , + // 0b0101_1110 => , + // 0b0101_1111 => , + 0b0110_0000 => Some(Self::MrUpdate), + 0b0110_0100 => match data { + OperationData::Felt(value) => Some(Self::Push(value)), + _ => None, + }, + 0b0110_1000 => Some(Self::SysCall), + 0b0110_1100 => Some(Self::Call), + 0b0111_0000 => Some(Self::End), + 0b0111_0100 => Some(Self::Repeat), + 0b0111_1000 => Some(Self::Respan), + 0b0111_1100 => Some(Self::Halt), + + _ => None, + } } } @@ -472,6 +587,8 @@ impl Operation { /// operations and some other operations requiring very high degree constraints. #[rustfmt::skip] pub const fn op_code(&self) -> u8 { + // REMEMBER: If you add/remove/modify an opcode here, you must also make the same change in + // `Operation::with_opcode_and_data()`. match self { Self::Noop => 0b0000_0000, Self::Eqz => 0b0000_0001, From c04bc90c85a999bdd67fe891d7db7fa95b97e4b2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 16:18:49 -0400 Subject: [PATCH 056/172] encode decorators --- core/src/mast/serialization/decorator.rs | 84 +++++++++++++++++ core/src/mast/serialization/mod.rs | 109 ++++++++++++++++++++++- 2 files changed, 189 insertions(+), 4 deletions(-) create mode 100644 core/src/mast/serialization/decorator.rs diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs new file mode 100644 index 0000000000..4a7a6af0d1 --- /dev/null +++ b/core/src/mast/serialization/decorator.rs @@ -0,0 +1,84 @@ +use num_derive::{FromPrimitive, ToPrimitive}; +use num_traits::ToPrimitive; + +use crate::{AdviceInjector, DebugOptions, Decorator}; + +/// TODOP: Document +#[derive(FromPrimitive, ToPrimitive)] +#[repr(u8)] +pub enum EncodedDecoratorVariant { + AdviceInjectorMerkleNodeMerge, + AdviceInjectorMerkleNodeToStack, + AdviceInjectorUpdateMerkleNode, + AdviceInjectorMapValueToStack, + AdviceInjectorU64Div, + AdviceInjectorExt2Inv, + AdviceInjectorExt2Intt, + AdviceInjectorSmtGet, + AdviceInjectorSmtSet, + AdviceInjectorSmtPeek, + AdviceInjectorU32Clz, + AdviceInjectorU32Ctz, + AdviceInjectorU32Clo, + AdviceInjectorU32Cto, + AdviceInjectorILog2, + AdviceInjectorMemToMap, + AdviceInjectorHdwordToMap, + AdviceInjectorHpermToMap, + AdviceInjectorSigToStack, + AssemblyOp, + DebugOptionsStackAll, + DebugOptionsStackTop, + DebugOptionsMemAll, + DebugOptionsMemInterval, + DebugOptionsLocalInterval, + Event, + Trace, +} + +impl EncodedDecoratorVariant { + pub fn discriminant(&self) -> u8 { + self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") + } +} + +impl From<&Decorator> for EncodedDecoratorVariant { + fn from(decorator: &Decorator) -> Self { + match decorator { + Decorator::Advice(advice_injector) => match advice_injector { + AdviceInjector::MerkleNodeMerge => Self::AdviceInjectorMerkleNodeMerge, + AdviceInjector::MerkleNodeToStack => Self::AdviceInjectorMerkleNodeToStack, + AdviceInjector::UpdateMerkleNode => todo!(), + AdviceInjector::MapValueToStack { + include_len: _, + key_offset: _, + } => Self::AdviceInjectorMapValueToStack, + AdviceInjector::U64Div => Self::AdviceInjectorU64Div, + AdviceInjector::Ext2Inv => Self::AdviceInjectorExt2Inv, + AdviceInjector::Ext2Intt => Self::AdviceInjectorExt2Intt, + AdviceInjector::SmtGet => Self::AdviceInjectorSmtGet, + AdviceInjector::SmtSet => Self::AdviceInjectorSmtSet, + AdviceInjector::SmtPeek => Self::AdviceInjectorSmtPeek, + AdviceInjector::U32Clz => Self::AdviceInjectorU32Clz, + AdviceInjector::U32Ctz => Self::AdviceInjectorU32Ctz, + AdviceInjector::U32Clo => Self::AdviceInjectorU32Clo, + AdviceInjector::U32Cto => Self::AdviceInjectorU32Cto, + AdviceInjector::ILog2 => Self::AdviceInjectorILog2, + AdviceInjector::MemToMap => Self::AdviceInjectorMemToMap, + AdviceInjector::HdwordToMap { domain: _ } => Self::AdviceInjectorHdwordToMap, + AdviceInjector::HpermToMap => Self::AdviceInjectorHpermToMap, + AdviceInjector::SigToStack { kind: _ } => Self::AdviceInjectorSigToStack, + }, + Decorator::AsmOp(_) => Self::AssemblyOp, + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackAll => Self::DebugOptionsStackAll, + DebugOptions::StackTop(_) => Self::DebugOptionsStackTop, + DebugOptions::MemAll => Self::DebugOptionsMemAll, + DebugOptions::MemInterval(_, _) => Self::DebugOptionsMemInterval, + DebugOptions::LocalInterval(_, _, _) => Self::DebugOptionsLocalInterval, + }, + Decorator::Event(_) => Self::Event, + Decorator::Trace(_) => Self::Trace, + } + } +} diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 16ee3ed86a..d3f1fac63a 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -11,17 +11,24 @@ use winter_utils::{ use crate::{ mast::{MerkleTreeNode, OperationOrDecorator}, - DecoratorList, Operation, OperationData, + AdviceInjector, DebugOptions, Decorator, DecoratorList, Operation, OperationData, + SignatureKind, }; use super::{MastForest, MastNode, MastNodeId}; +mod decorator; +use decorator::EncodedDecoratorVariant; + mod info; use info::{EncodedMastNodeType, MastNodeInfo, MastNodeTypeVariant}; /// Specifies an offset into the `data` section of an encoded [`MastForest`]. type DataOffset = u32; +/// Specifies an offset into the `strings` table of an encoded [`MastForest`] +type StringIndex = usize; + /// Magic string for detecting that a file is binary-encoded MAST. const MAGIC: &[u8; 5] = b"MAST\0"; @@ -177,9 +184,8 @@ fn mast_node_to_info( for op_or_decorator in basic_block.iter() { match op_or_decorator { OperationOrDecorator::Operation(operation) => encode_operation(operation, data), - OperationOrDecorator::Decorator(_) => { - // TODOP: Remember: you need to set the most significant bit to 1 - todo!() + OperationOrDecorator::Decorator(decorator) => { + encode_decorator(decorator, data, strings) } } } @@ -294,6 +300,101 @@ fn encode_operation(operation: &Operation, data: &mut Vec) { } } +fn encode_decorator(decorator: &Decorator, data: &mut Vec, strings: &mut Vec) { + // Set the first byte to the decorator discriminant. + // + // Note: the most significant bit is set to 1 (to differentiate decorators from operations). + { + let decorator_variant: EncodedDecoratorVariant = decorator.into(); + data.push(decorator_variant.discriminant() | 0b1000_0000); + } + + // For decorators that have extra data, encode it in `data` and `strings`. + match decorator { + Decorator::Advice(advice_injector) => match advice_injector { + AdviceInjector::MapValueToStack { + include_len, + key_offset, + } => { + data.push((*include_len).into()); + data.write_usize(*key_offset); + } + AdviceInjector::HdwordToMap { domain } => data.extend(domain.as_int().to_le_bytes()), + AdviceInjector::SigToStack { kind } => match kind { + SignatureKind::RpoFalcon512 => data.push(0_u8), + }, + AdviceInjector::MerkleNodeMerge + | AdviceInjector::MerkleNodeToStack + | AdviceInjector::UpdateMerkleNode + | AdviceInjector::U64Div + | AdviceInjector::Ext2Inv + | AdviceInjector::Ext2Intt + | AdviceInjector::SmtGet + | AdviceInjector::SmtSet + | AdviceInjector::SmtPeek + | AdviceInjector::U32Clz + | AdviceInjector::U32Ctz + | AdviceInjector::U32Clo + | AdviceInjector::U32Cto + | AdviceInjector::ILog2 + | AdviceInjector::MemToMap + | AdviceInjector::HpermToMap => (), + }, + Decorator::AsmOp(assembly_op) => { + data.push(assembly_op.num_cycles()); + data.push(assembly_op.should_break() as u8); + + // TODOP: Make a StringTable type + + // context name + { + let str_index_in_table = push_string(data, strings, assembly_op.context_name()); + + data.write_usize(str_index_in_table); + } + + // op + { + let str_index_in_table = push_string(data, strings, assembly_op.op()); + + data.write_usize(str_index_in_table); + } + } + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackTop(value) => data.push(*value), + DebugOptions::MemInterval(start, end) => { + data.extend(start.to_le_bytes()); + data.extend(end.to_le_bytes()); + } + DebugOptions::LocalInterval(start, second, end) => { + data.extend(start.to_le_bytes()); + data.extend(second.to_le_bytes()); + data.extend(end.to_le_bytes()); + } + DebugOptions::StackAll | DebugOptions::MemAll => (), + }, + Decorator::Event(value) | Decorator::Trace(value) => data.extend(value.to_le_bytes()), + } +} + +// TODOP: Make this a method of `StringTable` type +fn push_string(data: &mut Vec, strings: &mut Vec, value: &str) -> StringIndex { + let offset = data.len(); + data.extend(value.as_bytes()); + + let str_ref = StringRef { + offset: offset + .try_into() + .expect("MastForest serialization: data field larger than 2^32 bytes"), + len: value.len().try_into().expect("decorator string length exceeds 2^32 bytes"), + }; + + let str_index_in_table = strings.len(); + strings.push(str_ref); + + str_index_in_table +} + fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, From 27e27834381f6dd758cb073a4a7d865030fd22a3 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 17:37:54 -0400 Subject: [PATCH 057/172] implement `decode_decorator()` --- core/src/mast/serialization/decorator.rs | 6 +- core/src/mast/serialization/info.rs | 1 + core/src/mast/serialization/mod.rs | 182 ++++++++++++++++++++++- 3 files changed, 184 insertions(+), 5 deletions(-) diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs index 4a7a6af0d1..dc7d1a0ebf 100644 --- a/core/src/mast/serialization/decorator.rs +++ b/core/src/mast/serialization/decorator.rs @@ -1,5 +1,5 @@ use num_derive::{FromPrimitive, ToPrimitive}; -use num_traits::ToPrimitive; +use num_traits::{FromPrimitive, ToPrimitive}; use crate::{AdviceInjector, DebugOptions, Decorator}; @@ -40,6 +40,10 @@ impl EncodedDecoratorVariant { pub fn discriminant(&self) -> u8 { self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") } + + pub fn from_discriminant(discriminant: u8) -> Option { + Self::from_u8(discriminant) + } } impl From<&Decorator> for EncodedDecoratorVariant { diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 82c42f8dc7..f776708de2 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -218,6 +218,7 @@ impl MastNodeTypeVariant { self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") } + // TODOP: Just do `from_discriminant() -> Option`, and document what `None` means pub fn try_from_discriminant(discriminant: u8) -> Result { Self::from_u8(discriminant).ok_or_else(|| super::Error::InvalidDiscriminant { ty: "MastNode".into(), diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index d3f1fac63a..e6a768d8b9 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -11,7 +11,7 @@ use winter_utils::{ use crate::{ mast::{MerkleTreeNode, OperationOrDecorator}, - AdviceInjector, DebugOptions, Decorator, DecoratorList, Operation, OperationData, + AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, OperationData, SignatureKind, }; @@ -148,6 +148,7 @@ impl Deserializable for MastForest { mast_node_info, &mast_forest, &mut data_reader, + &data, &strings, )?; mast_forest.add_node(node); @@ -316,12 +317,14 @@ fn encode_decorator(decorator: &Decorator, data: &mut Vec, strings: &mut Vec include_len, key_offset, } => { - data.push((*include_len).into()); + data.write_bool(*include_len); data.write_usize(*key_offset); } AdviceInjector::HdwordToMap { domain } => data.extend(domain.as_int().to_le_bytes()), + + // Note: Since there is only 1 variant, we don't need to write any extra bytes. AdviceInjector::SigToStack { kind } => match kind { - SignatureKind::RpoFalcon512 => data.push(0_u8), + SignatureKind::RpoFalcon512 => (), }, AdviceInjector::MerkleNodeMerge | AdviceInjector::MerkleNodeToStack @@ -399,6 +402,7 @@ fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, data_reader: &mut SliceReader, + data: &[u8], strings: &[StringRef], ) -> Result { let mast_node_variant = mast_node_info @@ -416,6 +420,7 @@ fn try_info_to_mast_node( let (operations, decorators) = decode_operations_and_decorators( num_operations_and_decorators, data_reader, + data, strings, )?; @@ -456,6 +461,7 @@ fn try_info_to_mast_node( fn decode_operations_and_decorators( num_to_decode: u32, data_reader: &mut SliceReader, + data: &[u8], strings: &[StringRef], ) -> Result<(Vec, DecoratorList), DeserializationError> { let mut operations: Vec = Vec::new(); @@ -501,10 +507,178 @@ fn decode_operations_and_decorators( } else { // decorator. let discriminant = first_byte & 0b0111_1111; + let decorator = decode_decorator(discriminant, data_reader, data, strings)?; - todo!() + decorators.push((operations.len(), decorator)); } } Ok((operations, decorators)) } + +fn decode_decorator( + discriminant: u8, + data_reader: &mut SliceReader, + data: &[u8], + strings: &[StringRef], +) -> Result { + let decorator_variant = + EncodedDecoratorVariant::from_discriminant(discriminant).ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid decorator variant discriminant: {discriminant}" + )) + })?; + + match decorator_variant { + EncodedDecoratorVariant::AdviceInjectorMerkleNodeMerge => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeMerge)) + } + EncodedDecoratorVariant::AdviceInjectorMerkleNodeToStack => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeToStack)) + } + EncodedDecoratorVariant::AdviceInjectorUpdateMerkleNode => { + Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) + } + EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { + let include_len = data_reader.read_bool()?; + let key_offset = data_reader.read_usize()?; + + Ok(Decorator::Advice(AdviceInjector::MapValueToStack { + include_len, + key_offset, + })) + } + EncodedDecoratorVariant::AdviceInjectorU64Div => { + Ok(Decorator::Advice(AdviceInjector::U64Div)) + } + EncodedDecoratorVariant::AdviceInjectorExt2Inv => { + Ok(Decorator::Advice(AdviceInjector::Ext2Inv)) + } + EncodedDecoratorVariant::AdviceInjectorExt2Intt => { + Ok(Decorator::Advice(AdviceInjector::Ext2Intt)) + } + EncodedDecoratorVariant::AdviceInjectorSmtGet => { + Ok(Decorator::Advice(AdviceInjector::SmtGet)) + } + EncodedDecoratorVariant::AdviceInjectorSmtSet => { + Ok(Decorator::Advice(AdviceInjector::SmtSet)) + } + EncodedDecoratorVariant::AdviceInjectorSmtPeek => { + Ok(Decorator::Advice(AdviceInjector::SmtPeek)) + } + EncodedDecoratorVariant::AdviceInjectorU32Clz => { + Ok(Decorator::Advice(AdviceInjector::U32Clz)) + } + EncodedDecoratorVariant::AdviceInjectorU32Ctz => { + Ok(Decorator::Advice(AdviceInjector::U32Ctz)) + } + EncodedDecoratorVariant::AdviceInjectorU32Clo => { + Ok(Decorator::Advice(AdviceInjector::U32Clo)) + } + EncodedDecoratorVariant::AdviceInjectorU32Cto => { + Ok(Decorator::Advice(AdviceInjector::U32Cto)) + } + EncodedDecoratorVariant::AdviceInjectorILog2 => { + Ok(Decorator::Advice(AdviceInjector::ILog2)) + } + EncodedDecoratorVariant::AdviceInjectorMemToMap => { + Ok(Decorator::Advice(AdviceInjector::MemToMap)) + } + EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { + let domain = data_reader.read_u64()?; + let domain = Felt::try_from(domain).map_err(|err| { + DeserializationError::InvalidValue(format!( + "Error when deserializing HdwordToMap decorator domain: {err}" + )) + })?; + + Ok(Decorator::Advice(AdviceInjector::HdwordToMap { domain })) + } + EncodedDecoratorVariant::AdviceInjectorHpermToMap => { + Ok(Decorator::Advice(AdviceInjector::HpermToMap)) + } + EncodedDecoratorVariant::AdviceInjectorSigToStack => { + Ok(Decorator::Advice(AdviceInjector::SigToStack { + kind: SignatureKind::RpoFalcon512, + })) + } + EncodedDecoratorVariant::AssemblyOp => { + let num_cycles = data_reader.read_u8()?; + let should_break = data_reader.read_bool()?; + + let context_name = { + let str_index_in_table = data_reader.read_usize()?; + read_string(str_index_in_table, data, strings)? + }; + + let op = { + let str_index_in_table = data_reader.read_usize()?; + read_string(str_index_in_table, data, strings)? + }; + + Ok(Decorator::AsmOp(AssemblyOp::new(context_name, num_cycles, op, should_break))) + } + EncodedDecoratorVariant::DebugOptionsStackAll => { + Ok(Decorator::Debug(DebugOptions::StackAll)) + } + EncodedDecoratorVariant::DebugOptionsStackTop => { + let value = data_reader.read_u8()?; + + Ok(Decorator::Debug(DebugOptions::StackTop(value))) + } + EncodedDecoratorVariant::DebugOptionsMemAll => Ok(Decorator::Debug(DebugOptions::MemAll)), + EncodedDecoratorVariant::DebugOptionsMemInterval => { + let start = u32::from_le_bytes(data_reader.read_array::<4>()?); + let end = u32::from_le_bytes(data_reader.read_array::<4>()?); + + Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) + } + EncodedDecoratorVariant::DebugOptionsLocalInterval => { + let start = u16::from_le_bytes(data_reader.read_array::<2>()?); + let second = u16::from_le_bytes(data_reader.read_array::<2>()?); + let end = u16::from_le_bytes(data_reader.read_array::<2>()?); + + Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) + } + EncodedDecoratorVariant::Event => { + let value = u32::from_le_bytes(data_reader.read_array::<4>()?); + + Ok(Decorator::Event(value)) + } + EncodedDecoratorVariant::Trace => { + let value = u32::from_le_bytes(data_reader.read_array::<4>()?); + + Ok(Decorator::Trace(value)) + } + } +} + +fn read_string( + str_index_in_table: usize, + data: &[u8], + strings: &[StringRef], +) -> Result { + let str_ref = strings.get(str_index_in_table).ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid index in strings table: {str_index_in_table}" + )) + })?; + + let str_bytes = { + let start = str_ref.offset as usize; + let end = (str_ref.offset + str_ref.len) as usize; + + data.get(start..end).ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid string ref in strings table. Offset: {}, length: {}", + str_ref.offset, str_ref.len + )) + })? + }; + + String::from_utf8(str_bytes.to_vec()).map_err(|_| { + DeserializationError::InvalidValue(format!( + "Invalid UTF-8 string in strings table: {str_bytes:?}" + )) + }) +} From 25fe82ff8a7e8b80b89c87337f9352e11297f633 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 18:20:13 -0400 Subject: [PATCH 058/172] fix exec invocation --- .../src/assembler/instruction/procedures.rs | 17 +++++++---------- assembly/src/errors.rs | 9 --------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 3c2e2e7983..9894a82092 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -90,16 +90,13 @@ impl Assembler { // Note that here we rely on the fact that we topologically sorted the // procedures, such that when we assemble a procedure, all // procedures that it calls will have been assembled, and - // hence be present in the `MastForest`. We currently assume that the - // `MastForest` contains all the procedures being called; "external procedures" - // only known by digest are not currently supported. - mast_forest_builder.find_procedure_root(mast_root).ok_or_else(|| { - AssemblyError::UnknownExecTarget { - span, - source_file: current_source_file, - callee: mast_root, - } - })? + // hence be present in the `MastForest`. + mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest_builder.ensure_node(external_node) + }) } InvokeKind::Call => { let callee_id = diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index a85ccb7057..5b51048292 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -104,15 +104,6 @@ pub enum AssemblyError { source_file: Option>, callee: RpoDigest, }, - #[error("invalid exec: exec'd procedures must be available during compilation, but '{callee}' is not")] - #[diagnostic()] - UnknownExecTarget { - #[label("call occurs here")] - span: SourceSpan, - #[source_code] - source_file: Option>, - callee: RpoDigest, - }, #[error("invalid use of 'caller' instruction outside of kernel")] #[diagnostic(help( "the 'caller' instruction is only allowed in procedures defined in a kernel" From 3c26bd662ca26306608e7519f8875a002764a3bb Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 18:29:05 -0400 Subject: [PATCH 059/172] no else blk special case --- assembly/src/assembler/mod.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 890e7979c9..a63f7ffb25 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -789,14 +789,8 @@ impl Assembler { let then_blk = self.compile_body(then_blk.iter(), context, None, mast_forest_builder)?; - // else is an exception because it is optional; hence, will have to be replaced - // by noop span - let else_blk = if else_blk.is_empty() { - let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest_builder.ensure_node(basic_block_node) - } else { - self.compile_body(else_blk.iter(), context, None, mast_forest_builder)? - }; + let else_blk = + self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?; let split_node_id = { let split_node = From 34c2f7f85e379422636836ae133152669d0fe598 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 2 Jul 2024 18:43:24 -0400 Subject: [PATCH 060/172] add procedure roots comment --- processor/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/processor/src/lib.rs b/processor/src/lib.rs index d825592967..fc6b136a2b 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -262,6 +262,9 @@ where root_digest: external_node.digest(), }, )?; + + // We temporarily limit the parts of the program that can be called externally to + // procedure roots, even though MAST doesn't have that restriction. let root_id = mast_forest.find_procedure_root(external_node.digest()).ok_or( ExecutionError::MalformedMastForestInHost { root_digest: external_node.digest(), From ffc7c782de733639c008555a486d011e4eb0da14 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 06:14:37 -0400 Subject: [PATCH 061/172] implement forgotten `todo!()` --- core/src/mast/serialization/decorator.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs index dc7d1a0ebf..3cb509a039 100644 --- a/core/src/mast/serialization/decorator.rs +++ b/core/src/mast/serialization/decorator.rs @@ -52,7 +52,7 @@ impl From<&Decorator> for EncodedDecoratorVariant { Decorator::Advice(advice_injector) => match advice_injector { AdviceInjector::MerkleNodeMerge => Self::AdviceInjectorMerkleNodeMerge, AdviceInjector::MerkleNodeToStack => Self::AdviceInjectorMerkleNodeToStack, - AdviceInjector::UpdateMerkleNode => todo!(), + AdviceInjector::UpdateMerkleNode => Self::AdviceInjectorUpdateMerkleNode, AdviceInjector::MapValueToStack { include_len: _, key_offset: _, From 550415752c3ba531cef3925b5614e50b2c05ea67 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 06:15:14 -0400 Subject: [PATCH 062/172] `serialize_deserialize_all_nodes` test --- core/src/mast/serialization/mod.rs | 3 + core/src/mast/serialization/tests.rs | 170 +++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 core/src/mast/serialization/tests.rs diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index e6a768d8b9..0c82e1b6b6 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -23,6 +23,9 @@ use decorator::EncodedDecoratorVariant; mod info; use info::{EncodedMastNodeType, MastNodeInfo, MastNodeTypeVariant}; +#[cfg(test)] +mod tests; + /// Specifies an offset into the `data` section of an encoded [`MastForest`]. type DataOffset = u32; diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs new file mode 100644 index 0000000000..2f1276d2dd --- /dev/null +++ b/core/src/mast/serialization/tests.rs @@ -0,0 +1,170 @@ +use math::FieldElement; + +use super::*; +use crate::operations::Operation; + +#[test] +fn serialize_deserialize_all_nodes() { + let mut mast_forest = MastForest::new(); + + let basic_block_id = { + let operations = vec![ + Operation::Noop, + Operation::Assert(42), + Operation::FmpAdd, + Operation::FmpUpdate, + Operation::SDepth, + Operation::Caller, + Operation::Clk, + Operation::Join, + Operation::Split, + Operation::Loop, + Operation::Call, + Operation::Dyn, + Operation::SysCall, + Operation::Span, + Operation::End, + Operation::Repeat, + Operation::Respan, + Operation::Halt, + Operation::Add, + Operation::Neg, + Operation::Mul, + Operation::Inv, + Operation::Incr, + Operation::And, + Operation::Or, + Operation::Not, + Operation::Eq, + Operation::Eqz, + Operation::Expacc, + Operation::Ext2Mul, + Operation::U32split, + Operation::U32add, + Operation::U32assert2(Felt::ONE), + Operation::U32add3, + Operation::U32sub, + Operation::U32mul, + Operation::U32madd, + Operation::U32div, + Operation::U32and, + Operation::U32xor, + Operation::Pad, + Operation::Drop, + Operation::Dup0, + Operation::Dup1, + Operation::Dup2, + Operation::Dup3, + Operation::Dup4, + Operation::Dup5, + Operation::Dup6, + Operation::Dup7, + Operation::Dup9, + Operation::Dup11, + Operation::Dup13, + Operation::Dup15, + Operation::Swap, + Operation::SwapW, + Operation::SwapW2, + Operation::SwapW3, + Operation::SwapDW, + Operation::MovUp2, + Operation::MovUp3, + Operation::MovUp4, + Operation::MovUp5, + Operation::MovUp6, + Operation::MovUp7, + Operation::MovUp8, + Operation::MovDn2, + Operation::MovDn3, + Operation::MovDn4, + Operation::MovDn5, + Operation::MovDn6, + Operation::MovDn7, + Operation::MovDn8, + Operation::CSwap, + Operation::CSwapW, + Operation::Push(Felt::new(45)), + Operation::AdvPop, + Operation::AdvPopW, + Operation::MLoadW, + Operation::MStoreW, + Operation::MLoad, + Operation::MStore, + Operation::MStream, + Operation::Pipe, + Operation::HPerm, + Operation::MpVerify(1022), + Operation::MrUpdate, + Operation::FriE2F4, + Operation::RCombBase, + ]; + + let num_operations = operations.len(); + + let decorators = vec![ + (0, Decorator::Advice(AdviceInjector::MerkleNodeMerge)), + (0, Decorator::Advice(AdviceInjector::MerkleNodeToStack)), + (0, Decorator::Advice(AdviceInjector::UpdateMerkleNode)), + ( + 0, + Decorator::Advice(AdviceInjector::MapValueToStack { + include_len: true, + key_offset: 1023, + }), + ), + (1, Decorator::Advice(AdviceInjector::U64Div)), + (3, Decorator::Advice(AdviceInjector::Ext2Inv)), + (5, Decorator::Advice(AdviceInjector::Ext2Intt)), + (5, Decorator::Advice(AdviceInjector::SmtGet)), + (5, Decorator::Advice(AdviceInjector::SmtSet)), + (5, Decorator::Advice(AdviceInjector::SmtPeek)), + (5, Decorator::Advice(AdviceInjector::U32Clz)), + (10, Decorator::Advice(AdviceInjector::U32Ctz)), + (10, Decorator::Advice(AdviceInjector::U32Clo)), + (10, Decorator::Advice(AdviceInjector::U32Cto)), + (10, Decorator::Advice(AdviceInjector::ILog2)), + (10, Decorator::Advice(AdviceInjector::MemToMap)), + ( + 10, + Decorator::Advice(AdviceInjector::HdwordToMap { + domain: Felt::new(423), + }), + ), + (15, Decorator::Advice(AdviceInjector::HpermToMap)), + ( + 15, + Decorator::Advice(AdviceInjector::SigToStack { + kind: SignatureKind::RpoFalcon512, + }), + ), + ( + 15, + Decorator::AsmOp(AssemblyOp::new( + "context".to_string(), + 15, + "op".to_string(), + false, + )), + ), + (15, Decorator::Debug(DebugOptions::StackAll)), + (15, Decorator::Debug(DebugOptions::StackTop(255))), + (15, Decorator::Debug(DebugOptions::MemAll)), + (15, Decorator::Debug(DebugOptions::MemInterval(0, 16))), + (17, Decorator::Debug(DebugOptions::LocalInterval(1, 2, 3))), + (num_operations, Decorator::Event(45)), + (num_operations, Decorator::Trace(55)), + ]; + + let basic_block_node = MastNode::new_basic_block_with_decorators(operations, decorators); + mast_forest.add_node(basic_block_node) + }; + + // TODOP: REMOVE + mast_forest.make_root(basic_block_id); + + let serialized_mast_forest = mast_forest.to_bytes(); + let deserialized_mast_forest = MastForest::read_from_bytes(&serialized_mast_forest).unwrap(); + + assert_eq!(mast_forest, deserialized_mast_forest); +} From 93c4fcae5734c06cba19a9c6defac7dd530f54e5 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 06:24:14 -0400 Subject: [PATCH 063/172] `decode_operations_and_decorators`: fix bit check --- core/src/mast/serialization/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 0c82e1b6b6..82688a360c 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -473,7 +473,7 @@ fn decode_operations_and_decorators( for _ in 0..num_to_decode { let first_byte = data_reader.read_u8()?; - if first_byte & 0b1000_0000 > 0 { + if first_byte & 0b1000_0000 == 0 { // operation. let op_code = first_byte; From da149841007714fccb880be9b3437d537d9383f0 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 06:25:38 -0400 Subject: [PATCH 064/172] confirm_assumptions test scaffold --- core/src/mast/serialization/tests.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 2f1276d2dd..1ba99cb9c7 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -3,6 +3,11 @@ use math::FieldElement; use super::*; use crate::operations::Operation; +#[test] +fn confirm_assumptions() { + // TODOP: match against all `Operation` and `Decorator` +} + #[test] fn serialize_deserialize_all_nodes() { let mut mast_forest = MastForest::new(); From 4870b32d67c65374ab649af4a645a58565bcc650 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 08:53:52 -0400 Subject: [PATCH 065/172] minor adjustments --- core/src/mast/serialization/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 82688a360c..0e7544265c 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -52,6 +52,7 @@ pub enum Error { /// An entry in the `strings` table of an encoded [`MastForest`]. /// /// Strings are UTF8-encoded. +#[derive(Debug)] pub struct StringRef { /// Offset into the `data` section. offset: DataOffset, @@ -348,7 +349,7 @@ fn encode_decorator(decorator: &Decorator, data: &mut Vec, strings: &mut Vec }, Decorator::AsmOp(assembly_op) => { data.push(assembly_op.num_cycles()); - data.push(assembly_op.should_break() as u8); + data.write_bool(assembly_op.should_break()); // TODOP: Make a StringTable type From 072b0b9098f7cb611f94a725343d4e1c656db25a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:04:53 -0400 Subject: [PATCH 066/172] Introduce `StringTableBuilder` --- core/src/mast/serialization/mod.rs | 45 +++++--------- .../serialization/string_table_builder.rs | 59 +++++++++++++++++++ 2 files changed, 75 insertions(+), 29 deletions(-) create mode 100644 core/src/mast/serialization/string_table_builder.rs diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 0e7544265c..8ca9393ceb 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -23,6 +23,9 @@ use decorator::EncodedDecoratorVariant; mod info; use info::{EncodedMastNodeType, MastNodeInfo, MastNodeTypeVariant}; +mod string_table_builder; +use string_table_builder::StringTableBuilder; + #[cfg(test)] mod tests; @@ -80,7 +83,7 @@ impl Deserializable for StringRef { impl Serializable for MastForest { fn write_into(&self, target: &mut W) { // TODOP: make sure padding is in accordance with Paul's docs - let mut strings: Vec = Vec::new(); + let mut string_table_builder = StringTableBuilder::new(); let mut data: Vec = Vec::new(); // magic & version @@ -95,12 +98,13 @@ impl Serializable for MastForest { // MAST node infos for mast_node in &self.nodes { - let mast_node_info = mast_node_to_info(mast_node, &mut data, &mut strings); + let mast_node_info = mast_node_to_info(mast_node, &mut data, &mut string_table_builder); mast_node_info.write_into(target); } // strings table + let strings = string_table_builder.into_table(&mut data); strings.write_into(target); // data blob @@ -172,7 +176,7 @@ impl Deserializable for MastForest { fn mast_node_to_info( mast_node: &MastNode, data: &mut Vec, - strings: &mut Vec, + string_table_builder: &mut StringTableBuilder, ) -> MastNodeInfo { use MastNode::*; @@ -190,7 +194,7 @@ fn mast_node_to_info( match op_or_decorator { OperationOrDecorator::Operation(operation) => encode_operation(operation, data), OperationOrDecorator::Decorator(decorator) => { - encode_decorator(decorator, data, strings) + encode_decorator(decorator, data, string_table_builder) } } } @@ -305,7 +309,11 @@ fn encode_operation(operation: &Operation, data: &mut Vec) { } } -fn encode_decorator(decorator: &Decorator, data: &mut Vec, strings: &mut Vec) { +fn encode_decorator( + decorator: &Decorator, + data: &mut Vec, + string_table_builder: &mut StringTableBuilder, +) { // Set the first byte to the decorator discriminant. // // Note: the most significant bit is set to 1 (to differentiate decorators from operations). @@ -351,19 +359,16 @@ fn encode_decorator(decorator: &Decorator, data: &mut Vec, strings: &mut Vec data.push(assembly_op.num_cycles()); data.write_bool(assembly_op.should_break()); - // TODOP: Make a StringTable type - // context name { - let str_index_in_table = push_string(data, strings, assembly_op.context_name()); - + let str_index_in_table = + string_table_builder.add_string(assembly_op.context_name()); data.write_usize(str_index_in_table); } // op { - let str_index_in_table = push_string(data, strings, assembly_op.op()); - + let str_index_in_table = string_table_builder.add_string(assembly_op.op()); data.write_usize(str_index_in_table); } } @@ -384,24 +389,6 @@ fn encode_decorator(decorator: &Decorator, data: &mut Vec, strings: &mut Vec } } -// TODOP: Make this a method of `StringTable` type -fn push_string(data: &mut Vec, strings: &mut Vec, value: &str) -> StringIndex { - let offset = data.len(); - data.extend(value.as_bytes()); - - let str_ref = StringRef { - offset: offset - .try_into() - .expect("MastForest serialization: data field larger than 2^32 bytes"), - len: value.len().try_into().expect("decorator string length exceeds 2^32 bytes"), - }; - - let str_index_in_table = strings.len(); - strings.push(str_ref); - - str_index_in_table -} - fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, diff --git a/core/src/mast/serialization/string_table_builder.rs b/core/src/mast/serialization/string_table_builder.rs new file mode 100644 index 0000000000..ecc03ce150 --- /dev/null +++ b/core/src/mast/serialization/string_table_builder.rs @@ -0,0 +1,59 @@ +use alloc::{collections::BTreeMap, vec::Vec}; +use miden_crypto::hash::rpo::{Rpo256, RpoDigest}; + +use super::{StringIndex, StringRef}; + +#[derive(Debug, Default)] +pub struct StringTableBuilder { + table: Vec, + str_to_index: BTreeMap, + // current length of table + strings: Vec, +} + +impl StringTableBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn add_string(&mut self, string: &str) -> StringIndex { + if let Some(str_idx) = self.str_to_index.get(&Rpo256::hash(string.as_bytes())) { + // return already interned string + *str_idx + } else { + // add new string to table + // NOTE: these string refs' offset will need to be shifted again in `into_buffer()` + let str_ref = StringRef { + offset: self + .strings + .len() + .try_into() + .expect("strings table larger than 2^32 bytes"), + len: string.len().try_into().expect("string larger than 2^32 bytes"), + }; + let str_idx = self.table.len(); + + self.strings.extend(string.as_bytes()); + self.table.push(str_ref); + self.str_to_index.insert(Rpo256::hash(string.as_bytes()), str_idx); + + str_idx + } + } + + pub fn into_table(self, data: &mut Vec) -> Vec { + let table_offset: u32 = data + .len() + .try_into() + .expect("MAST forest serialization: data field longer than 2^32 bytes"); + data.extend(self.strings); + + self.table + .into_iter() + .map(|str_ref| StringRef { + offset: str_ref.offset + table_offset, + len: str_ref.len, + }) + .collect() + } +} From 48aec6a7a32b1e8f5e2a4dec3ddf7eb61600aec7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:16:07 -0400 Subject: [PATCH 067/172] naming --- core/src/mast/serialization/string_table_builder.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/mast/serialization/string_table_builder.rs b/core/src/mast/serialization/string_table_builder.rs index ecc03ce150..ea01174c8c 100644 --- a/core/src/mast/serialization/string_table_builder.rs +++ b/core/src/mast/serialization/string_table_builder.rs @@ -7,8 +7,7 @@ use super::{StringIndex, StringRef}; pub struct StringTableBuilder { table: Vec, str_to_index: BTreeMap, - // current length of table - strings: Vec, + strings_data: Vec, } impl StringTableBuilder { @@ -25,7 +24,7 @@ impl StringTableBuilder { // NOTE: these string refs' offset will need to be shifted again in `into_buffer()` let str_ref = StringRef { offset: self - .strings + .strings_data .len() .try_into() .expect("strings table larger than 2^32 bytes"), @@ -33,7 +32,7 @@ impl StringTableBuilder { }; let str_idx = self.table.len(); - self.strings.extend(string.as_bytes()); + self.strings_data.extend(string.as_bytes()); self.table.push(str_ref); self.str_to_index.insert(Rpo256::hash(string.as_bytes()), str_idx); @@ -46,7 +45,7 @@ impl StringTableBuilder { .len() .try_into() .expect("MAST forest serialization: data field longer than 2^32 bytes"); - data.extend(self.strings); + data.extend(self.strings_data); self.table .into_iter() From a4ef4b156a23984ee1055c580476fa6aae9e0474 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:21:30 -0400 Subject: [PATCH 068/172] test confirm_operation_and_decorator_structure --- core/src/mast/serialization/tests.rs | 135 ++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 2 deletions(-) diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 1ba99cb9c7..6122a75721 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -3,9 +3,140 @@ use math::FieldElement; use super::*; use crate::operations::Operation; +/// If this test fails to compile, it means that `Operation` or `Decorator` was changed. Make sure +/// that all tests in this file are updated accordingly. For example, if a new `Operation` variant +/// was added, make sure that you add it in the vector of operations in +/// [`serialize_deserialize_all_nodes`]. #[test] -fn confirm_assumptions() { - // TODOP: match against all `Operation` and `Decorator` +fn confirm_operation_and_decorator_structure() { + let _ = match Operation::Noop { + Operation::Noop => (), + Operation::Assert(_) => (), + Operation::FmpAdd => (), + Operation::FmpUpdate => (), + Operation::SDepth => (), + Operation::Caller => (), + Operation::Clk => (), + Operation::Join => (), + Operation::Split => (), + Operation::Loop => (), + Operation::Call => (), + Operation::Dyn => (), + Operation::SysCall => (), + Operation::Span => (), + Operation::End => (), + Operation::Repeat => (), + Operation::Respan => (), + Operation::Halt => (), + Operation::Add => (), + Operation::Neg => (), + Operation::Mul => (), + Operation::Inv => (), + Operation::Incr => (), + Operation::And => (), + Operation::Or => (), + Operation::Not => (), + Operation::Eq => (), + Operation::Eqz => (), + Operation::Expacc => (), + Operation::Ext2Mul => (), + Operation::U32split => (), + Operation::U32add => (), + Operation::U32assert2(_) => (), + Operation::U32add3 => (), + Operation::U32sub => (), + Operation::U32mul => (), + Operation::U32madd => (), + Operation::U32div => (), + Operation::U32and => (), + Operation::U32xor => (), + Operation::Pad => (), + Operation::Drop => (), + Operation::Dup0 => (), + Operation::Dup1 => (), + Operation::Dup2 => (), + Operation::Dup3 => (), + Operation::Dup4 => (), + Operation::Dup5 => (), + Operation::Dup6 => (), + Operation::Dup7 => (), + Operation::Dup9 => (), + Operation::Dup11 => (), + Operation::Dup13 => (), + Operation::Dup15 => (), + Operation::Swap => (), + Operation::SwapW => (), + Operation::SwapW2 => (), + Operation::SwapW3 => (), + Operation::SwapDW => (), + Operation::MovUp2 => (), + Operation::MovUp3 => (), + Operation::MovUp4 => (), + Operation::MovUp5 => (), + Operation::MovUp6 => (), + Operation::MovUp7 => (), + Operation::MovUp8 => (), + Operation::MovDn2 => (), + Operation::MovDn3 => (), + Operation::MovDn4 => (), + Operation::MovDn5 => (), + Operation::MovDn6 => (), + Operation::MovDn7 => (), + Operation::MovDn8 => (), + Operation::CSwap => (), + Operation::CSwapW => (), + Operation::Push(_) => (), + Operation::AdvPop => (), + Operation::AdvPopW => (), + Operation::MLoadW => (), + Operation::MStoreW => (), + Operation::MLoad => (), + Operation::MStore => (), + Operation::MStream => (), + Operation::Pipe => (), + Operation::HPerm => (), + Operation::MpVerify(_) => (), + Operation::MrUpdate => (), + Operation::FriE2F4 => (), + Operation::RCombBase => (), + }; + + let _ = match Decorator::Event(0) { + Decorator::Advice(advice) => match advice { + AdviceInjector::MerkleNodeMerge => (), + AdviceInjector::MerkleNodeToStack => (), + AdviceInjector::UpdateMerkleNode => (), + AdviceInjector::MapValueToStack { + include_len: _, + key_offset: _, + } => (), + AdviceInjector::U64Div => (), + AdviceInjector::Ext2Inv => (), + AdviceInjector::Ext2Intt => (), + AdviceInjector::SmtGet => (), + AdviceInjector::SmtSet => (), + AdviceInjector::SmtPeek => (), + AdviceInjector::U32Clz => (), + AdviceInjector::U32Ctz => (), + AdviceInjector::U32Clo => (), + AdviceInjector::U32Cto => (), + AdviceInjector::ILog2 => (), + AdviceInjector::MemToMap => (), + AdviceInjector::HdwordToMap { domain: _ } => (), + AdviceInjector::HpermToMap => (), + AdviceInjector::SigToStack { kind: _ } => (), + }, + Decorator::AsmOp(_) => (), + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackAll => (), + DebugOptions::StackTop(_) => (), + DebugOptions::MemAll => (), + DebugOptions::MemInterval(_, _) => (), + DebugOptions::LocalInterval(_, _, _) => (), + }, + Decorator::Event(_) => (), + Decorator::Trace(_) => (), + }; } #[test] From ca0e7feaf21c6f3d7afd2f7ebc4b8565a2d934f5 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:21:56 -0400 Subject: [PATCH 069/172] remove TODOP --- core/src/mast/node/call_node.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index 4c66dd0e06..c2183ea84d 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -7,7 +7,6 @@ use crate::{chiplets::hasher, Operation}; use crate::mast::{MastForest, MastNodeId, MerkleTreeNode}; -// TODOP: `callee` must be a digest, #[derive(Debug, Clone, PartialEq, Eq)] pub struct CallNode { callee: MastNodeId, From 78e35fcd07c6d95e701969291bd5cb947c10b26f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:22:56 -0400 Subject: [PATCH 070/172] remove unused `MastNode::new_dyncall()` --- core/src/mast/node/mod.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 16f55452be..31cb297309 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -88,11 +88,6 @@ impl MastNode { Self::Dyn } - // TODOP: removed, since unused? - pub fn new_dyncall(dyn_node_id: MastNodeId, mast_forest: &MastForest) -> Self { - Self::Call(CallNode::new(dyn_node_id, mast_forest)) - } - pub fn new_external(mast_root: RpoDigest) -> Self { Self::External(ExternalNode::new(mast_root)) } From 411f9f3ebf892afd9c8fcb859bc04a9665dee298 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:27:23 -0400 Subject: [PATCH 071/172] Remove `Error` type --- core/src/mast/serialization/info.rs | 15 ++++++++------- core/src/mast/serialization/mod.rs | 18 ++---------------- core/src/mast/serialization/tests.rs | 1 + 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index f776708de2..921f354334 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -73,10 +73,14 @@ impl EncodedMastNodeType { /// Accessors impl EncodedMastNodeType { - pub fn variant(&self) -> Result { + pub fn variant(&self) -> Result { let discriminant = self.0[0] >> 4; - MastNodeTypeVariant::try_from_discriminant(discriminant) + MastNodeTypeVariant::from_discriminant(discriminant).ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "Invalid discriminant {discriminant} for MastNode" + )) + }) } } @@ -219,11 +223,8 @@ impl MastNodeTypeVariant { } // TODOP: Just do `from_discriminant() -> Option`, and document what `None` means - pub fn try_from_discriminant(discriminant: u8) -> Result { - Self::from_u8(discriminant).ok_or_else(|| super::Error::InvalidDiscriminant { - ty: "MastNode".into(), - discriminant, - }) + pub fn from_discriminant(discriminant: u8) -> Option { + Self::from_u8(discriminant) } pub fn from_mast_node(mast_node: &MastNode) -> Self { diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 8ca9393ceb..772a44c356 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,10 +1,6 @@ -use alloc::{ - string::{String, ToString}, - vec::Vec, -}; +use alloc::{string::String, vec::Vec}; use miden_crypto::{Felt, ZERO}; use num_traits::ToBytes; -use thiserror::Error; use winter_utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, }; @@ -45,13 +41,6 @@ const MAGIC: &[u8; 5] = b"MAST\0"; /// version field itself, but should be considered invalid for now. const VERSION: [u8; 3] = [0, 0, 0]; -// TODOP: move into info.rs? Make public? -#[derive(Debug, Error)] -pub enum Error { - #[error("Invalid discriminant '{discriminant}' for type '{ty}'")] - InvalidDiscriminant { ty: String, discriminant: u8 }, -} - /// An entry in the `strings` table of an encoded [`MastForest`]. /// /// Strings are UTF8-encoded. @@ -396,10 +385,7 @@ fn try_info_to_mast_node( data: &[u8], strings: &[StringRef], ) -> Result { - let mast_node_variant = mast_node_info - .ty - .variant() - .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; + let mast_node_variant = mast_node_info.ty.variant()?; // TODOP: Make a faillible version of `MastNode` ctors // TODOP: Check digest of resulting `MastNode` matches `MastNodeInfo.digest`? diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 6122a75721..6f37920573 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -1,3 +1,4 @@ +use alloc::string::ToString; use math::FieldElement; use super::*; From 624984c7021f782c32152e42f525a54c87947e23 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:31:47 -0400 Subject: [PATCH 072/172] add TODOP --- core/src/mast/serialization/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 772a44c356..4c43965ee7 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -630,6 +630,7 @@ fn decode_decorator( } } +// TODOP: Rename and/or move to some struct? fn read_string( str_index_in_table: usize, data: &[u8], From 858582ad289f4c4db45eb3eab8bcccc50bd3e91f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:38:19 -0400 Subject: [PATCH 073/172] complete test `serialize_deserialize_all_nodes` --- core/src/mast/serialization/tests.rs | 41 ++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 6f37920573..2dd3f151ee 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -1,5 +1,6 @@ use alloc::string::ToString; use math::FieldElement; +use miden_crypto::hash::rpo::RpoDigest; use super::*; use crate::operations::Operation; @@ -297,8 +298,44 @@ fn serialize_deserialize_all_nodes() { mast_forest.add_node(basic_block_node) }; - // TODOP: REMOVE - mast_forest.make_root(basic_block_id); + let call_node_id = { + let node = MastNode::new_call(basic_block_id, &mast_forest); + mast_forest.add_node(node) + }; + + let syscall_node_id = { + let node = MastNode::new_syscall(basic_block_id, &mast_forest); + mast_forest.add_node(node) + }; + + let loop_node_id = { + let node = MastNode::new_loop(basic_block_id, &mast_forest); + mast_forest.add_node(node) + }; + let join_node_id = { + let node = MastNode::new_join(basic_block_id, call_node_id, &mast_forest); + mast_forest.add_node(node) + }; + let split_node_id = { + let node = MastNode::new_split(basic_block_id, call_node_id, &mast_forest); + mast_forest.add_node(node) + }; + let dyn_node_id = { + let node = MastNode::new_dynexec(); + mast_forest.add_node(node) + }; + + let external_node_id = { + let node = MastNode::new_external(RpoDigest::default()); + mast_forest.add_node(node) + }; + + mast_forest.make_root(join_node_id); + mast_forest.make_root(syscall_node_id); + mast_forest.make_root(loop_node_id); + mast_forest.make_root(split_node_id); + mast_forest.make_root(dyn_node_id); + mast_forest.make_root(external_node_id); let serialized_mast_forest = mast_forest.to_bytes(); let deserialized_mast_forest = MastForest::read_from_bytes(&serialized_mast_forest).unwrap(); From 4e5efd3a2fdb3dc65d174e268d9720b3af3fff4f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 09:42:14 -0400 Subject: [PATCH 074/172] check digest on deserialization --- core/src/mast/serialization/mod.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 4c43965ee7..ed943c140f 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -388,8 +388,7 @@ fn try_info_to_mast_node( let mast_node_variant = mast_node_info.ty.variant()?; // TODOP: Make a faillible version of `MastNode` ctors - // TODOP: Check digest of resulting `MastNode` matches `MastNodeInfo.digest`? - match mast_node_variant { + let mast_node = match mast_node_variant { MastNodeTypeVariant::Block => { let num_operations_and_decorators = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); @@ -432,6 +431,16 @@ fn try_info_to_mast_node( } MastNodeTypeVariant::Dyn => Ok(MastNode::new_dynexec()), MastNodeTypeVariant::External => Ok(MastNode::new_external(mast_node_info.digest)), + }?; + + if mast_node.digest() == mast_node_info.digest { + Ok(mast_node) + } else { + Err(DeserializationError::InvalidValue(format!( + "MastNodeInfo's digest '{}' doesn't match deserialized MastNode's digest '{}'", + mast_node_info.digest, + mast_node.digest() + ))) } } From 3f522b8c87631b36e32c34565d6f4cf4b85e3198 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 10:28:40 -0400 Subject: [PATCH 075/172] remove TODOP --- core/src/mast/serialization/info.rs | 2 +- core/src/mast/serialization/mod.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 921f354334..2eed699ad7 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -204,6 +204,7 @@ impl Deserializable for EncodedMastNodeType { } } +// TODOP: Document (and rename `Encoded*`?) #[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)] #[repr(u8)] pub enum MastNodeTypeVariant { @@ -222,7 +223,6 @@ impl MastNodeTypeVariant { self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") } - // TODOP: Just do `from_discriminant() -> Option`, and document what `None` means pub fn from_discriminant(discriminant: u8) -> Option { Self::from_u8(discriminant) } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index ed943c140f..3d2284e9ac 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -71,7 +71,6 @@ impl Deserializable for StringRef { impl Serializable for MastForest { fn write_into(&self, target: &mut W) { - // TODOP: make sure padding is in accordance with Paul's docs let mut string_table_builder = StringTableBuilder::new(); let mut data: Vec = Vec::new(); From 161578d555c56a51ada7adc81e7410e7b1638b7a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 11:05:56 -0400 Subject: [PATCH 076/172] safely decode mast node ids --- core/src/mast/mod.rs | 21 ++++++++ core/src/mast/serialization/info.rs | 75 +++++++++++++++++------------ core/src/mast/serialization/mod.rs | 26 ++++++---- 3 files changed, 83 insertions(+), 39 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index de99e707c7..1c0eb12380 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -30,6 +30,27 @@ pub trait MerkleTreeNode { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct MastNodeId(u32); +impl MastNodeId { + /// Returns a new `MastNodeId` with the provided inner value, or an error if the provided + /// `value` is greater than the number of nodes in the forest. + /// + /// For use in deserialization. + pub fn from_u32_safe( + value: u32, + mast_forest: &MastForest, + ) -> Result { + if (value as usize) < mast_forest.nodes.len() { + Ok(Self(value)) + } else { + Err(DeserializationError::InvalidValue(format!( + "Invalid deserialized MAST node ID '{}', but only {} nodes in the forest", + value, + mast_forest.nodes.len(), + ))) + } + } +} + impl fmt::Display for MastNodeId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "MastNodeId({})", self.0) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 2eed699ad7..34dd68378f 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -3,7 +3,7 @@ use num_derive::{FromPrimitive, ToPrimitive}; use num_traits::{FromPrimitive, ToPrimitive}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::{MastNode, MastNodeId}; +use crate::mast::{MastForest, MastNode, MastNodeId}; use super::DataOffset; @@ -135,7 +135,43 @@ impl EncodedMastNodeType { Self(result) } - pub fn decode_join_or_split(&self) -> (MastNodeId, MastNodeId) { + pub fn decode_join_or_split( + &self, + mast_forest: &MastForest, + ) -> Result<(MastNodeId, MastNodeId), DeserializationError> { + let (first, second) = self.decode_join_or_split_impl(); + + Ok(( + MastNodeId::from_u32_safe(first, mast_forest)?, + MastNodeId::from_u32_safe(second, mast_forest)?, + )) + } + + pub fn encode_u32_payload(discriminant: u8, payload: u32) -> Self { + let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); + + Self([ + discriminant << 4, + payload_byte1, + payload_byte2, + payload_byte3, + payload_byte4, + 0, + 0, + 0, + ]) + } + + pub fn decode_u32_payload(&self) -> u32 { + let payload_be_bytes = [self.0[1], self.0[2], self.0[3], self.0[4]]; + + u32::from_be_bytes(payload_be_bytes) + } +} + +/// Helpers +impl EncodedMastNodeType { + fn decode_join_or_split_impl(&self) -> (u32, u32) { let first = { let mut first_le_bytes = [0_u8; 4]; @@ -165,28 +201,7 @@ impl EncodedMastNodeType { u32::from_be_bytes(second_be_bytes) }; - (MastNodeId(first), MastNodeId(second)) - } - - pub fn encode_u32_payload(discriminant: u8, payload: u32) -> Self { - let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); - - Self([ - discriminant << 4, - payload_byte1, - payload_byte2, - payload_byte3, - payload_byte4, - 0, - 0, - 0, - ]) - } - - pub fn decode_u32_payload(&self) -> u32 { - let payload_be_bytes = [self.0[1], self.0[2], self.0[3], self.0[4]]; - - u32::from_be_bytes(payload_be_bytes) + (first, second) } } @@ -270,9 +285,9 @@ mod tests { assert_eq!(expected_mast_node_type, mast_node_type.0); - let (decoded_left, decoded_right) = mast_node_type.decode_join_or_split(); - assert_eq!(left_child_id, decoded_left); - assert_eq!(right_child_id, decoded_right); + let (decoded_left, decoded_right) = mast_node_type.decode_join_or_split_impl(); + assert_eq!(left_child_id.0, decoded_left); + assert_eq!(right_child_id.0, decoded_right); } #[test] @@ -292,9 +307,9 @@ mod tests { assert_eq!(expected_mast_node_type, mast_node_type.0); - let (decoded_on_true, decoded_on_false) = mast_node_type.decode_join_or_split(); - assert_eq!(on_true_id, decoded_on_true); - assert_eq!(on_false_id, decoded_on_false); + let (decoded_on_true, decoded_on_false) = mast_node_type.decode_join_or_split_impl(); + assert_eq!(on_true_id.0, decoded_on_true); + assert_eq!(on_false_id.0, decoded_on_false); } // TODOP: Test all other variants diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 3d2284e9ac..9ff6eda328 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -386,7 +386,6 @@ fn try_info_to_mast_node( ) -> Result { let mast_node_variant = mast_node_info.ty.variant()?; - // TODOP: Make a faillible version of `MastNode` ctors let mast_node = match mast_node_variant { MastNodeTypeVariant::Block => { let num_operations_and_decorators = @@ -403,30 +402,39 @@ fn try_info_to_mast_node( } MastNodeTypeVariant::Join => { let (left_child, right_child) = - EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty); + EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty, mast_forest)?; Ok(MastNode::new_join(left_child, right_child, mast_forest)) } MastNodeTypeVariant::Split => { let (if_branch, else_branch) = - EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty); + EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty, mast_forest)?; Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) } MastNodeTypeVariant::Loop => { - let body_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let body_id = { + let body_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + MastNodeId::from_u32_safe(body_id, mast_forest)? + }; - Ok(MastNode::new_loop(MastNodeId(body_id), mast_forest)) + Ok(MastNode::new_loop(body_id, mast_forest)) } MastNodeTypeVariant::Call => { - let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let callee_id = { + let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + MastNodeId::from_u32_safe(callee_id, mast_forest)? + }; - Ok(MastNode::new_call(MastNodeId(callee_id), mast_forest)) + Ok(MastNode::new_call(callee_id, mast_forest)) } MastNodeTypeVariant::Syscall => { - let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let callee_id = { + let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + MastNodeId::from_u32_safe(callee_id, mast_forest)? + }; - Ok(MastNode::new_syscall(MastNodeId(callee_id), mast_forest)) + Ok(MastNode::new_syscall(callee_id, mast_forest)) } MastNodeTypeVariant::Dyn => Ok(MastNode::new_dynexec()), MastNodeTypeVariant::External => Ok(MastNode::new_external(mast_node_info.digest)), From 2ab1cf9b17ef354753bafeca55ac814f131a27a1 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 11:08:45 -0400 Subject: [PATCH 077/172] use method syntax in `MastNodeType` decoding --- core/src/mast/serialization/mod.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 9ff6eda328..644ec4bf17 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -388,8 +388,7 @@ fn try_info_to_mast_node( let mast_node = match mast_node_variant { MastNodeTypeVariant::Block => { - let num_operations_and_decorators = - EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let num_operations_and_decorators = mast_node_info.ty.decode_u32_payload(); let (operations, decorators) = decode_operations_and_decorators( num_operations_and_decorators, @@ -401,20 +400,18 @@ fn try_info_to_mast_node( Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) } MastNodeTypeVariant::Join => { - let (left_child, right_child) = - EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty, mast_forest)?; + let (left_child, right_child) = mast_node_info.ty.decode_join_or_split(mast_forest)?; Ok(MastNode::new_join(left_child, right_child, mast_forest)) } MastNodeTypeVariant::Split => { - let (if_branch, else_branch) = - EncodedMastNodeType::decode_join_or_split(&mast_node_info.ty, mast_forest)?; + let (if_branch, else_branch) = mast_node_info.ty.decode_join_or_split(mast_forest)?; Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) } MastNodeTypeVariant::Loop => { let body_id = { - let body_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let body_id = mast_node_info.ty.decode_u32_payload(); MastNodeId::from_u32_safe(body_id, mast_forest)? }; @@ -422,7 +419,7 @@ fn try_info_to_mast_node( } MastNodeTypeVariant::Call => { let callee_id = { - let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let callee_id = mast_node_info.ty.decode_u32_payload(); MastNodeId::from_u32_safe(callee_id, mast_forest)? }; @@ -430,7 +427,7 @@ fn try_info_to_mast_node( } MastNodeTypeVariant::Syscall => { let callee_id = { - let callee_id = EncodedMastNodeType::decode_u32_payload(&mast_node_info.ty); + let callee_id = mast_node_info.ty.decode_u32_payload(); MastNodeId::from_u32_safe(callee_id, mast_forest)? }; From c8cfa8e619f73ab60a2e75f066622136772fc09a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 11:11:08 -0400 Subject: [PATCH 078/172] TODOPs --- core/src/mast/serialization/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 644ec4bf17..af7c58fb71 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -161,6 +161,7 @@ impl Deserializable for MastForest { } } +// TODOP: Make `MastNode` method (impl in this module)? fn mast_node_to_info( mast_node: &MastNode, data: &mut Vec, @@ -377,6 +378,8 @@ fn encode_decorator( } } +// TODOP: Make `MastNodeInfo` method +// TODOP: Can we not have both `data` and `data_reader`? fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, From a360959792a5be326d33905580e97a76a0df5550 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 13:24:56 -0400 Subject: [PATCH 079/172] rewrite <= expression --- core/src/mast/serialization/info.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 34dd68378f..768c4d59f1 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -40,7 +40,7 @@ impl EncodedMastNodeType { use MastNode::*; let discriminant = MastNodeTypeVariant::from_mast_node(mast_node).discriminant(); - assert!(discriminant < 2_u8.pow(4_u32)); + assert!(discriminant <= 0b1111); match mast_node { Block(block_node) => { From c8ba463f3ab5380cebd7f58c78c43b72b14dfee9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 15:39:42 -0400 Subject: [PATCH 080/172] new `MastNodeType` --- core/src/mast/serialization/info.rs | 125 ++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 768c4d59f1..5b8dc40383 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -31,6 +31,131 @@ impl Deserializable for MastNodeInfo { } } +/// TODOP: Document the fact that encoded representation is always 8 bytes +#[derive(Debug)] +#[repr(u8)] +pub enum MastNodeType { + Join { + left_child_id: u32, + right_child_id: u32, + }, + Split { + if_branch_id: u32, + else_branch_id: u32, + }, + Loop { + body: u32, + }, + Block { + /// The number of operations and decorators in the basic block + len: u32, + }, + Call { + callee_id: u32, + }, + SysCall { + callee_id: u32, + }, + Dyn, + External, +} + +impl MastNodeType { + fn tag(&self) -> u8 { + // SAFETY: This is safe because we have given this enum a primitive representation with + // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant. + // + // See the section on "accessing the numeric value of the discriminant" + // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html + unsafe { *<*const _>::from(self).cast::() } + } + + fn inline_data_to_bytes(&self) -> [u8; 8] { + match self { + MastNodeType::Join { + left_child_id: left, + right_child_id: right, + } => Self::encode_join_or_split(*left, *right), + MastNodeType::Split { + if_branch_id: if_branch, + else_branch_id: else_branch, + } => Self::encode_join_or_split(*if_branch, *else_branch), + MastNodeType::Loop { body } => Self::encode_u32_payload(*body), + MastNodeType::Block { len } => Self::encode_u32_payload(*len), + MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id), + MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id), + MastNodeType::Dyn => [0; 8], + MastNodeType::External => [0; 8], + } + } + + // TODOP: Make a diagram of how the bits are split + fn encode_join_or_split(left_child_id: u32, right_child_id: u32) -> [u8; 8] { + assert!(left_child_id < 2_u32.pow(30)); + assert!(right_child_id < 2_u32.pow(30)); + + let mut result: [u8; 8] = [0_u8; 8]; + + // write left child into result + { + let [lsb, a, b, msb] = left_child_id.to_le_bytes(); + result[0] |= lsb >> 4; + result[1] |= lsb << 4; + result[1] |= a >> 4; + result[2] |= a << 4; + result[2] |= b >> 4; + result[3] |= b << 4; + + // msb is different from lsb, a and b since its 2 most significant bits are guaranteed + // to be 0, and hence not encoded. + // + // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in + // `result[3]`, and `ef` as the most significant bits of `result[4]`. + result[3] |= msb >> 2; + result[4] |= msb << 6; + }; + + // write right child into result + { + // Recall that `result[4]` contains 2 bits from the left child id in the most + // significant bits. Also, the most significant byte of the right child is guaranteed to + // fit in 6 bits. Hence, we use big endian format for the right child id to simplify + // encoding and decoding. + let [msb, a, b, lsb] = right_child_id.to_be_bytes(); + + result[4] |= msb; + result[5] = a; + result[6] = b; + result[7] = lsb; + }; + + result + } + + fn encode_u32_payload(payload: u32) -> [u8; 8] { + let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); + + [0, payload_byte1, payload_byte2, payload_byte3, payload_byte4, 0, 0, 0] + } +} + +impl Serializable for MastNodeType { + fn write_into(&self, target: &mut W) { + let serialized_bytes = { + let mut serialized_bytes = self.inline_data_to_bytes(); + + // Tag is always placed in the first four bytes + let tag = self.tag(); + assert!(tag <= 0b1111); + serialized_bytes[0] |= tag << 4; + + serialized_bytes + }; + + serialized_bytes.write_into(target) + } +} + // TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant pub struct EncodedMastNodeType(pub(super) [u8; 8]); From ad1858081da579c4f73f859409d727c6f6b28e47 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 16:04:39 -0400 Subject: [PATCH 081/172] implement `Deserializable` for `MastNodeType` --- core/src/mast/serialization/info.rs | 138 ++++++++++++++++++++++++---- 1 file changed, 118 insertions(+), 20 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 5b8dc40383..9b74d9f2ba 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -31,6 +31,15 @@ impl Deserializable for MastNodeInfo { } } +const JOIN: u8 = 0; +const SPLIT: u8 = 1; +const LOOP: u8 = 2; +const BLOCK: u8 = 3; +const CALL: u8 = 4; +const SYSCALL: u8 = 5; +const DYN: u8 = 6; +const EXTERNAL: u8 = 7; + /// TODOP: Document the fact that encoded representation is always 8 bytes #[derive(Debug)] #[repr(u8)] @@ -38,28 +47,46 @@ pub enum MastNodeType { Join { left_child_id: u32, right_child_id: u32, - }, + } = JOIN, Split { if_branch_id: u32, else_branch_id: u32, - }, + } = SPLIT, Loop { - body: u32, - }, + body_id: u32, + } = LOOP, Block { /// The number of operations and decorators in the basic block len: u32, - }, + } = BLOCK, Call { callee_id: u32, - }, + } = CALL, SysCall { callee_id: u32, - }, - Dyn, - External, + } = SYSCALL, + Dyn = DYN, + External = EXTERNAL, +} + +impl Serializable for MastNodeType { + fn write_into(&self, target: &mut W) { + let serialized_bytes = { + let mut serialized_bytes = self.inline_data_to_bytes(); + + // Tag is always placed in the first four bytes + let tag = self.tag(); + assert!(tag <= 0b1111); + serialized_bytes[0] |= tag << 4; + + serialized_bytes + }; + + serialized_bytes.write_into(target) + } } +/// Serialization helpers impl MastNodeType { fn tag(&self) -> u8 { // SAFETY: This is safe because we have given this enum a primitive representation with @@ -80,7 +107,7 @@ impl MastNodeType { if_branch_id: if_branch, else_branch_id: else_branch, } => Self::encode_join_or_split(*if_branch, *else_branch), - MastNodeType::Loop { body } => Self::encode_u32_payload(*body), + MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body), MastNodeType::Block { len } => Self::encode_u32_payload(*len), MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id), MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id), @@ -139,20 +166,91 @@ impl MastNodeType { } } -impl Serializable for MastNodeType { - fn write_into(&self, target: &mut W) { - let serialized_bytes = { - let mut serialized_bytes = self.inline_data_to_bytes(); +impl Deserializable for MastNodeType { + fn read_from(source: &mut R) -> Result { + let bytes: [u8; 8] = source.read_array()?; - // Tag is always placed in the first four bytes - let tag = self.tag(); - assert!(tag <= 0b1111); - serialized_bytes[0] |= tag << 4; + let tag = bytes[0] >> 4; - serialized_bytes + match tag { + JOIN => { + let (left_child_id, right_child_id) = Self::decode_join_or_split(bytes); + Ok(Self::Join { + left_child_id, + right_child_id, + }) + } + SPLIT => { + let (if_branch_id, else_branch_id) = Self::decode_join_or_split(bytes); + Ok(Self::Split { + if_branch_id, + else_branch_id, + }) + } + LOOP => { + let body_id = Self::decode_u32_payload(bytes); + Ok(Self::Loop { body_id }) + } + BLOCK => { + let len = Self::decode_u32_payload(bytes); + Ok(Self::Block { len }) + } + CALL => { + let callee_id = Self::decode_u32_payload(bytes); + Ok(Self::Call { callee_id }) + } + SYSCALL => { + let callee_id = Self::decode_u32_payload(bytes); + Ok(Self::SysCall { callee_id }) + } + DYN => Ok(Self::Dyn), + EXTERNAL => Ok(Self::External), + _ => { + Err(DeserializationError::InvalidValue(format!("Invalid tag for MAST node: {tag}"))) + } + } + } +} + +/// Deserialization helpers +impl MastNodeType { + fn decode_join_or_split(buffer: [u8; 8]) -> (u32, u32) { + let first = { + let mut first_le_bytes = [0_u8; 4]; + + first_le_bytes[0] = buffer[0] << 4; + first_le_bytes[0] |= buffer[1] >> 4; + + first_le_bytes[1] = buffer[1] << 4; + first_le_bytes[1] |= buffer[2] >> 4; + + first_le_bytes[2] = buffer[2] << 4; + first_le_bytes[2] |= buffer[3] >> 4; + + first_le_bytes[3] = (buffer[3] & 0b1111) << 2; + first_le_bytes[3] |= buffer[4] >> 6; + + u32::from_le_bytes(first_le_bytes) }; - serialized_bytes.write_into(target) + let second = { + let mut second_be_bytes = [0_u8; 4]; + + second_be_bytes[0] = buffer[4] & 0b0011_1111; + second_be_bytes[1] = buffer[5]; + second_be_bytes[2] = buffer[6]; + second_be_bytes[3] = buffer[7]; + + u32::from_be_bytes(second_be_bytes) + }; + + (first, second) + } + + pub fn decode_u32_payload(payload: [u8; 8]) -> u32 { + let payload_be_bytes = [payload[1], payload[2], payload[3], payload[4]]; + + u32::from_be_bytes(payload_be_bytes) } } From 996498f4fd2a49d600d2aab62defdf649e4458dd Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 16:20:41 -0400 Subject: [PATCH 082/172] migrate tests to new --- core/src/mast/serialization/info.rs | 56 ++++++++++++++++------------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 9b74d9f2ba..c0e545830c 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -487,52 +487,60 @@ impl MastNodeTypeVariant { #[cfg(test)] mod tests { use super::*; - use crate::mast::{JoinNode, SplitNode}; + use alloc::vec::Vec; #[test] fn mast_node_type_serde_join() { - let left_child_id = MastNodeId(0b00111001_11101011_01101100_11011000); - let right_child_id = MastNodeId(0b00100111_10101010_11111111_11001110); - let mast_node = MastNode::Join(JoinNode::new_test( - [left_child_id, right_child_id], - RpoDigest::default(), - )); + let left_child_id = 0b00111001_11101011_01101100_11011000; + let right_child_id = 0b00100111_10101010_11111111_11001110; - let mast_node_type = EncodedMastNodeType::new(&mast_node); + let mast_node_type = MastNodeType::Join { + left_child_id, + right_child_id, + }; + + let mut encoded_mast_node_type: Vec = Vec::new(); + mast_node_type.write_into(&mut encoded_mast_node_type); // Note: Join's discriminant is 0 - let expected_mast_node_type = [ + let expected_encoded_mast_node_type = [ 0b00001101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, 0b11001110, ]; - assert_eq!(expected_mast_node_type, mast_node_type.0); + assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type); - let (decoded_left, decoded_right) = mast_node_type.decode_join_or_split_impl(); - assert_eq!(left_child_id.0, decoded_left); - assert_eq!(right_child_id.0, decoded_right); + let (decoded_left, decoded_right) = + MastNodeType::decode_join_or_split(expected_encoded_mast_node_type); + assert_eq!(left_child_id, decoded_left); + assert_eq!(right_child_id, decoded_right); } #[test] fn mast_node_type_serde_split() { - let on_true_id = MastNodeId(0b00111001_11101011_01101100_11011000); - let on_false_id = MastNodeId(0b00100111_10101010_11111111_11001110); - let mast_node = - MastNode::Split(SplitNode::new_test([on_true_id, on_false_id], RpoDigest::default())); + let if_branch_id = 0b00111001_11101011_01101100_11011000; + let else_branch_id = 0b00100111_10101010_11111111_11001110; + + let mast_node_type = MastNodeType::Split { + if_branch_id, + else_branch_id, + }; - let mast_node_type = EncodedMastNodeType::new(&mast_node); + let mut encoded_mast_node_type: Vec = Vec::new(); + mast_node_type.write_into(&mut encoded_mast_node_type); - // Note: Split's discriminant is 0 - let expected_mast_node_type = [ + // Note: Split's discriminant is 1 + let expected_encoded_mast_node_type = [ 0b00011101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, 0b11001110, ]; - assert_eq!(expected_mast_node_type, mast_node_type.0); + assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type); - let (decoded_on_true, decoded_on_false) = mast_node_type.decode_join_or_split_impl(); - assert_eq!(on_true_id.0, decoded_on_true); - assert_eq!(on_false_id.0, decoded_on_false); + let (decoded_if_branch, decoded_else_branch) = + MastNodeType::decode_join_or_split(expected_encoded_mast_node_type); + assert_eq!(if_branch_id, decoded_if_branch); + assert_eq!(else_branch_id, decoded_else_branch); } // TODOP: Test all other variants From e60bfc2f56ac2ab12863c12ee8f0f343cb0aeabb Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 16:38:56 -0400 Subject: [PATCH 083/172] Use new MastNodeType --- core/src/mast/serialization/info.rs | 276 +++++----------------------- core/src/mast/serialization/mod.rs | 55 +++--- 2 files changed, 68 insertions(+), 263 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index c0e545830c..7d5c5ff440 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -1,14 +1,13 @@ use miden_crypto::hash::rpo::RpoDigest; -use num_derive::{FromPrimitive, ToPrimitive}; -use num_traits::{FromPrimitive, ToPrimitive}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::{MastForest, MastNode, MastNodeId}; +use crate::mast::MastNode; use super::DataOffset; +#[derive(Debug)] pub struct MastNodeInfo { - pub(super) ty: EncodedMastNodeType, + pub(super) ty: MastNodeType, pub(super) offset: DataOffset, pub(super) digest: RpoDigest, } @@ -69,6 +68,45 @@ pub enum MastNodeType { External = EXTERNAL, } +/// Constructors +impl MastNodeType { + pub fn new(mast_node: &MastNode) -> Self { + use MastNode::*; + + match mast_node { + Block(block_node) => { + let len = block_node.num_operations_and_decorators(); + + Self::Block { len } + } + Join(join_node) => Self::Join { + left_child_id: join_node.first().0, + right_child_id: join_node.second().0, + }, + Split(split_node) => Self::Split { + if_branch_id: split_node.on_true().0, + else_branch_id: split_node.on_false().0, + }, + Loop(loop_node) => Self::Loop { + body_id: loop_node.body().0, + }, + Call(call_node) => { + if call_node.is_syscall() { + Self::SysCall { + callee_id: call_node.callee().0, + } + } else { + Self::Call { + callee_id: call_node.callee().0, + } + } + } + Dyn => Self::Dyn, + External(_) => Self::External, + } + } +} + impl Serializable for MastNodeType { fn write_into(&self, target: &mut W) { let serialized_bytes = { @@ -254,236 +292,6 @@ impl MastNodeType { } } -// TODOP: Describe how first 4 bits (i.e. high order bits of first byte) are the discriminant -pub struct EncodedMastNodeType(pub(super) [u8; 8]); - -/// Constructors -impl EncodedMastNodeType { - pub fn new(mast_node: &MastNode) -> Self { - use MastNode::*; - - let discriminant = MastNodeTypeVariant::from_mast_node(mast_node).discriminant(); - assert!(discriminant <= 0b1111); - - match mast_node { - Block(block_node) => { - let num_ops = block_node.num_operations_and_decorators(); - - Self::encode_u32_payload(discriminant, num_ops) - } - Join(join_node) => { - Self::encode_join_or_split(discriminant, join_node.first(), join_node.second()) - } - Split(split_node) => Self::encode_join_or_split( - discriminant, - split_node.on_true(), - split_node.on_false(), - ), - Loop(loop_node) => { - let child_id = loop_node.body().0; - - Self::encode_u32_payload(discriminant, child_id) - } - Call(call_node) => { - let child_id = call_node.callee().0; - - Self::encode_u32_payload(discriminant, child_id) - } - Dyn | External(_) => Self([discriminant << 4, 0, 0, 0, 0, 0, 0, 0]), - } - } -} - -/// Accessors -impl EncodedMastNodeType { - pub fn variant(&self) -> Result { - let discriminant = self.0[0] >> 4; - - MastNodeTypeVariant::from_discriminant(discriminant).ok_or_else(|| { - DeserializationError::InvalidValue(format!( - "Invalid discriminant {discriminant} for MastNode" - )) - }) - } -} - -/// Helpers -impl EncodedMastNodeType { - // TODOP: Make a diagram of how the bits are split - pub fn encode_join_or_split( - discriminant: u8, - left_child_id: MastNodeId, - right_child_id: MastNodeId, - ) -> Self { - assert!(left_child_id.0 < 2_u32.pow(30)); - assert!(right_child_id.0 < 2_u32.pow(30)); - - let mut result: [u8; 8] = [0_u8; 8]; - - result[0] = discriminant << 4; - - // write left child into result - { - let [lsb, a, b, msb] = left_child_id.0.to_le_bytes(); - result[0] |= lsb >> 4; - result[1] |= lsb << 4; - result[1] |= a >> 4; - result[2] |= a << 4; - result[2] |= b >> 4; - result[3] |= b << 4; - - // msb is different from lsb, a and b since its 2 most significant bits are guaranteed - // to be 0, and hence not encoded. - // - // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in - // `result[3]`, and `ef` as the most significant bits of `result[4]`. - result[3] |= msb >> 2; - result[4] |= msb << 6; - }; - - // write right child into result - { - // Recall that `result[4]` contains 2 bits from the left child id in the most - // significant bits. Also, the most significant byte of the right child is guaranteed to - // fit in 6 bits. Hence, we use big endian format for the right child id to simplify - // encoding and decoding. - let [msb, a, b, lsb] = right_child_id.0.to_be_bytes(); - - result[4] |= msb; - result[5] = a; - result[6] = b; - result[7] = lsb; - }; - - Self(result) - } - - pub fn decode_join_or_split( - &self, - mast_forest: &MastForest, - ) -> Result<(MastNodeId, MastNodeId), DeserializationError> { - let (first, second) = self.decode_join_or_split_impl(); - - Ok(( - MastNodeId::from_u32_safe(first, mast_forest)?, - MastNodeId::from_u32_safe(second, mast_forest)?, - )) - } - - pub fn encode_u32_payload(discriminant: u8, payload: u32) -> Self { - let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); - - Self([ - discriminant << 4, - payload_byte1, - payload_byte2, - payload_byte3, - payload_byte4, - 0, - 0, - 0, - ]) - } - - pub fn decode_u32_payload(&self) -> u32 { - let payload_be_bytes = [self.0[1], self.0[2], self.0[3], self.0[4]]; - - u32::from_be_bytes(payload_be_bytes) - } -} - -/// Helpers -impl EncodedMastNodeType { - fn decode_join_or_split_impl(&self) -> (u32, u32) { - let first = { - let mut first_le_bytes = [0_u8; 4]; - - first_le_bytes[0] = self.0[0] << 4; - first_le_bytes[0] |= self.0[1] >> 4; - - first_le_bytes[1] = self.0[1] << 4; - first_le_bytes[1] |= self.0[2] >> 4; - - first_le_bytes[2] = self.0[2] << 4; - first_le_bytes[2] |= self.0[3] >> 4; - - first_le_bytes[3] = (self.0[3] & 0b1111) << 2; - first_le_bytes[3] |= self.0[4] >> 6; - - u32::from_le_bytes(first_le_bytes) - }; - - let second = { - let mut second_be_bytes = [0_u8; 4]; - - second_be_bytes[0] = self.0[4] & 0b0011_1111; - second_be_bytes[1] = self.0[5]; - second_be_bytes[2] = self.0[6]; - second_be_bytes[3] = self.0[7]; - - u32::from_be_bytes(second_be_bytes) - }; - - (first, second) - } -} - -impl Serializable for EncodedMastNodeType { - fn write_into(&self, target: &mut W) { - self.0.write_into(target); - } -} - -impl Deserializable for EncodedMastNodeType { - fn read_from(source: &mut R) -> Result { - let bytes = source.read_array()?; - - Ok(Self(bytes)) - } -} - -// TODOP: Document (and rename `Encoded*`?) -#[derive(Clone, Copy, Debug, FromPrimitive, ToPrimitive)] -#[repr(u8)] -pub enum MastNodeTypeVariant { - Join, - Split, - Loop, - Call, - Syscall, - Dyn, - Block, - External, -} - -impl MastNodeTypeVariant { - pub fn discriminant(&self) -> u8 { - self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") - } - - pub fn from_discriminant(discriminant: u8) -> Option { - Self::from_u8(discriminant) - } - - pub fn from_mast_node(mast_node: &MastNode) -> Self { - match mast_node { - MastNode::Block(_) => Self::Block, - MastNode::Join(_) => Self::Join, - MastNode::Split(_) => Self::Split, - MastNode::Loop(_) => Self::Loop, - MastNode::Call(call_node) => { - if call_node.is_syscall() { - Self::Syscall - } else { - Self::Call - } - } - MastNode::Dyn => Self::Dyn, - MastNode::External(_) => Self::External, - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index af7c58fb71..7bfa846f99 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -17,7 +17,7 @@ mod decorator; use decorator::EncodedDecoratorVariant; mod info; -use info::{EncodedMastNodeType, MastNodeInfo, MastNodeTypeVariant}; +use info::{MastNodeInfo, MastNodeType}; mod string_table_builder; use string_table_builder::StringTableBuilder; @@ -169,7 +169,7 @@ fn mast_node_to_info( ) -> MastNodeInfo { use MastNode::*; - let ty = EncodedMastNodeType::new(mast_node); + let ty = MastNodeType::new(mast_node); let digest = mast_node.digest(); let offset = match mast_node { @@ -387,12 +387,10 @@ fn try_info_to_mast_node( data: &[u8], strings: &[StringRef], ) -> Result { - let mast_node_variant = mast_node_info.ty.variant()?; - - let mast_node = match mast_node_variant { - MastNodeTypeVariant::Block => { - let num_operations_and_decorators = mast_node_info.ty.decode_u32_payload(); - + let mast_node = match mast_node_info.ty { + MastNodeType::Block { + len: num_operations_and_decorators, + } => { let (operations, decorators) = decode_operations_and_decorators( num_operations_and_decorators, data_reader, @@ -402,42 +400,41 @@ fn try_info_to_mast_node( Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) } - MastNodeTypeVariant::Join => { - let (left_child, right_child) = mast_node_info.ty.decode_join_or_split(mast_forest)?; + MastNodeType::Join { + left_child_id, + right_child_id, + } => { + let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; + let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; Ok(MastNode::new_join(left_child, right_child, mast_forest)) } - MastNodeTypeVariant::Split => { - let (if_branch, else_branch) = mast_node_info.ty.decode_join_or_split(mast_forest)?; + MastNodeType::Split { + if_branch_id, + else_branch_id, + } => { + let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; + let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) } - MastNodeTypeVariant::Loop => { - let body_id = { - let body_id = mast_node_info.ty.decode_u32_payload(); - MastNodeId::from_u32_safe(body_id, mast_forest)? - }; + MastNodeType::Loop { body_id } => { + let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; Ok(MastNode::new_loop(body_id, mast_forest)) } - MastNodeTypeVariant::Call => { - let callee_id = { - let callee_id = mast_node_info.ty.decode_u32_payload(); - MastNodeId::from_u32_safe(callee_id, mast_forest)? - }; + MastNodeType::Call { callee_id } => { + let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; Ok(MastNode::new_call(callee_id, mast_forest)) } - MastNodeTypeVariant::Syscall => { - let callee_id = { - let callee_id = mast_node_info.ty.decode_u32_payload(); - MastNodeId::from_u32_safe(callee_id, mast_forest)? - }; + MastNodeType::SysCall { callee_id } => { + let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; Ok(MastNode::new_syscall(callee_id, mast_forest)) } - MastNodeTypeVariant::Dyn => Ok(MastNode::new_dynexec()), - MastNodeTypeVariant::External => Ok(MastNode::new_external(mast_node_info.digest)), + MastNodeType::Dyn => Ok(MastNode::new_dynexec()), + MastNodeType::External => Ok(MastNode::new_external(mast_node_info.digest)), }?; if mast_node.digest() == mast_node_info.digest { From 578bda9a8e48541f0319d6ee2726edbb8308eb05 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 17:10:12 -0400 Subject: [PATCH 084/172] rename string_table_builder_ module --- .../{string_table_builder.rs => basic_block_data_builder.rs} | 0 core/src/mast/serialization/mod.rs | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename core/src/mast/serialization/{string_table_builder.rs => basic_block_data_builder.rs} (100%) diff --git a/core/src/mast/serialization/string_table_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs similarity index 100% rename from core/src/mast/serialization/string_table_builder.rs rename to core/src/mast/serialization/basic_block_data_builder.rs diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 7bfa846f99..d122cc1fd3 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -19,8 +19,8 @@ use decorator::EncodedDecoratorVariant; mod info; use info::{MastNodeInfo, MastNodeType}; -mod string_table_builder; -use string_table_builder::StringTableBuilder; +mod basic_block_data_builder; +use basic_block_data_builder::StringTableBuilder; #[cfg(test)] mod tests; From d25671ce8017153cb847a8c9a45e2ea2fab434b8 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 17:40:46 -0400 Subject: [PATCH 085/172] implement `BasicBlockDataBuilder` --- .../serialization/basic_block_data_builder.rs | 240 ++++++++++++++++- core/src/mast/serialization/info.rs | 21 +- core/src/mast/serialization/mod.rs | 241 +----------------- 3 files changed, 266 insertions(+), 236 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index ea01174c8c..84e163916b 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -1,20 +1,248 @@ use alloc::{collections::BTreeMap, vec::Vec}; use miden_crypto::hash::rpo::{Rpo256, RpoDigest}; +use winter_utils::ByteWriter; -use super::{StringIndex, StringRef}; +use crate::{ + mast::{BasicBlockNode, OperationOrDecorator}, + AdviceInjector, DebugOptions, Decorator, Operation, SignatureKind, +}; + +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; #[derive(Debug, Default)] -pub struct StringTableBuilder { - table: Vec, - str_to_index: BTreeMap, - strings_data: Vec, +pub struct BasicBlockDataBuilder { + data: Vec, + string_table_builder: StringTableBuilder, } -impl StringTableBuilder { +/// Constructors +impl BasicBlockDataBuilder { pub fn new() -> Self { Self::default() } +} +/// Accessors +impl BasicBlockDataBuilder { + pub fn current_data_offset(&self) -> DataOffset { + self.data + .len() + .try_into() + .expect("MAST forest data segment larger than 2^32 bytes") + } +} + +/// Mutators +impl BasicBlockDataBuilder { + pub fn encode_basic_block(&mut self, basic_block: &BasicBlockNode) { + // 2nd part of `mast_node_to_info()` (inside the match) + for op_or_decorator in basic_block.iter() { + match op_or_decorator { + OperationOrDecorator::Operation(operation) => self.encode_operation(operation), + OperationOrDecorator::Decorator(decorator) => self.encode_decorator(decorator), + } + } + } + + pub fn into_parts(mut self) -> (Vec, Vec) { + let string_table = self.string_table_builder.into_table(&mut self.data); + (self.data, string_table) + } +} + +/// Helpers +impl BasicBlockDataBuilder { + fn encode_operation(&mut self, operation: &Operation) { + self.data.push(operation.op_code()); + + // For operations that have extra data, encode it in `data`. + match operation { + Operation::Assert(value) | Operation::MpVerify(value) => { + self.data.extend_from_slice(&value.to_le_bytes()) + } + Operation::U32assert2(value) | Operation::Push(value) => { + self.data.extend_from_slice(&value.as_int().to_le_bytes()) + } + // Note: we explicitly write out all the operations so that whenever we make a + // modification to the `Operation` enum, we get a compile error here. This + // should help us remember to properly encode/decode each operation variant. + Operation::Noop + | Operation::FmpAdd + | Operation::FmpUpdate + | Operation::SDepth + | Operation::Caller + | Operation::Clk + | Operation::Join + | Operation::Split + | Operation::Loop + | Operation::Call + | Operation::Dyn + | Operation::SysCall + | Operation::Span + | Operation::End + | Operation::Repeat + | Operation::Respan + | Operation::Halt + | Operation::Add + | Operation::Neg + | Operation::Mul + | Operation::Inv + | Operation::Incr + | Operation::And + | Operation::Or + | Operation::Not + | Operation::Eq + | Operation::Eqz + | Operation::Expacc + | Operation::Ext2Mul + | Operation::U32split + | Operation::U32add + | Operation::U32add3 + | Operation::U32sub + | Operation::U32mul + | Operation::U32madd + | Operation::U32div + | Operation::U32and + | Operation::U32xor + | Operation::Pad + | Operation::Drop + | Operation::Dup0 + | Operation::Dup1 + | Operation::Dup2 + | Operation::Dup3 + | Operation::Dup4 + | Operation::Dup5 + | Operation::Dup6 + | Operation::Dup7 + | Operation::Dup9 + | Operation::Dup11 + | Operation::Dup13 + | Operation::Dup15 + | Operation::Swap + | Operation::SwapW + | Operation::SwapW2 + | Operation::SwapW3 + | Operation::SwapDW + | Operation::MovUp2 + | Operation::MovUp3 + | Operation::MovUp4 + | Operation::MovUp5 + | Operation::MovUp6 + | Operation::MovUp7 + | Operation::MovUp8 + | Operation::MovDn2 + | Operation::MovDn3 + | Operation::MovDn4 + | Operation::MovDn5 + | Operation::MovDn6 + | Operation::MovDn7 + | Operation::MovDn8 + | Operation::CSwap + | Operation::CSwapW + | Operation::AdvPop + | Operation::AdvPopW + | Operation::MLoadW + | Operation::MStoreW + | Operation::MLoad + | Operation::MStore + | Operation::MStream + | Operation::Pipe + | Operation::HPerm + | Operation::MrUpdate + | Operation::FriE2F4 + | Operation::RCombBase => (), + } + } + + fn encode_decorator(&mut self, decorator: &Decorator) { + // Set the first byte to the decorator discriminant. + // + // Note: the most significant bit is set to 1 (to differentiate decorators from operations). + { + let decorator_variant: EncodedDecoratorVariant = decorator.into(); + self.data.push(decorator_variant.discriminant() | 0b1000_0000); + } + + // For decorators that have extra data, encode it in `data` and `strings`. + match decorator { + Decorator::Advice(advice_injector) => match advice_injector { + AdviceInjector::MapValueToStack { + include_len, + key_offset, + } => { + self.data.write_bool(*include_len); + self.data.write_usize(*key_offset); + } + AdviceInjector::HdwordToMap { domain } => { + self.data.extend(domain.as_int().to_le_bytes()) + } + + // Note: Since there is only 1 variant, we don't need to write any extra bytes. + AdviceInjector::SigToStack { kind } => match kind { + SignatureKind::RpoFalcon512 => (), + }, + AdviceInjector::MerkleNodeMerge + | AdviceInjector::MerkleNodeToStack + | AdviceInjector::UpdateMerkleNode + | AdviceInjector::U64Div + | AdviceInjector::Ext2Inv + | AdviceInjector::Ext2Intt + | AdviceInjector::SmtGet + | AdviceInjector::SmtSet + | AdviceInjector::SmtPeek + | AdviceInjector::U32Clz + | AdviceInjector::U32Ctz + | AdviceInjector::U32Clo + | AdviceInjector::U32Cto + | AdviceInjector::ILog2 + | AdviceInjector::MemToMap + | AdviceInjector::HpermToMap => (), + }, + Decorator::AsmOp(assembly_op) => { + self.data.push(assembly_op.num_cycles()); + self.data.write_bool(assembly_op.should_break()); + + // context name + { + let str_index_in_table = + self.string_table_builder.add_string(assembly_op.context_name()); + self.data.write_usize(str_index_in_table); + } + + // op + { + let str_index_in_table = self.string_table_builder.add_string(assembly_op.op()); + self.data.write_usize(str_index_in_table); + } + } + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackTop(value) => self.data.push(*value), + DebugOptions::MemInterval(start, end) => { + self.data.extend(start.to_le_bytes()); + self.data.extend(end.to_le_bytes()); + } + DebugOptions::LocalInterval(start, second, end) => { + self.data.extend(start.to_le_bytes()); + self.data.extend(second.to_le_bytes()); + self.data.extend(end.to_le_bytes()); + } + DebugOptions::StackAll | DebugOptions::MemAll => (), + }, + Decorator::Event(value) | Decorator::Trace(value) => { + self.data.extend(value.to_le_bytes()) + } + } + } +} + +#[derive(Debug, Default)] +struct StringTableBuilder { + table: Vec, + str_to_index: BTreeMap, + strings_data: Vec, +} + +impl StringTableBuilder { pub fn add_string(&mut self, string: &str) -> StringIndex { if let Some(str_idx) = self.str_to_index.get(&Rpo256::hash(string.as_bytes())) { // return already interned string diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 7d5c5ff440..955e1e9dc7 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -1,17 +1,36 @@ use miden_crypto::hash::rpo::RpoDigest; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::MastNode; +use crate::mast::{MastNode, MerkleTreeNode}; use super::DataOffset; #[derive(Debug)] pub struct MastNodeInfo { + // TODOP: Remove pub(super)? pub(super) ty: MastNodeType, pub(super) offset: DataOffset, pub(super) digest: RpoDigest, } +impl MastNodeInfo { + pub fn new(mast_node: &MastNode, basic_block_offset: DataOffset) -> Self { + let ty = MastNodeType::new(mast_node); + + let offset = if let MastNode::Block(_) = mast_node { + basic_block_offset + } else { + 0 + }; + + Self { + ty, + offset, + digest: mast_node.digest(), + } + } +} + impl Serializable for MastNodeInfo { fn write_into(&self, target: &mut W) { self.ty.write_into(target); diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index d122cc1fd3..357476e940 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,14 +1,12 @@ use alloc::{string::String, vec::Vec}; use miden_crypto::{Felt, ZERO}; -use num_traits::ToBytes; use winter_utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, }; use crate::{ - mast::{MerkleTreeNode, OperationOrDecorator}, - AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, OperationData, - SignatureKind, + mast::MerkleTreeNode, AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, + Operation, OperationData, SignatureKind, }; use super::{MastForest, MastNode, MastNodeId}; @@ -20,7 +18,7 @@ mod info; use info::{MastNodeInfo, MastNodeType}; mod basic_block_data_builder; -use basic_block_data_builder::StringTableBuilder; +use basic_block_data_builder::BasicBlockDataBuilder; #[cfg(test)] mod tests; @@ -71,8 +69,7 @@ impl Deserializable for StringRef { impl Serializable for MastForest { fn write_into(&self, target: &mut W) { - let mut string_table_builder = StringTableBuilder::new(); - let mut data: Vec = Vec::new(); + let mut basic_block_data_builder = BasicBlockDataBuilder::new(); // magic & version target.write_bytes(MAGIC); @@ -86,16 +83,19 @@ impl Serializable for MastForest { // MAST node infos for mast_node in &self.nodes { - let mast_node_info = mast_node_to_info(mast_node, &mut data, &mut string_table_builder); + let mast_node_info = + MastNodeInfo::new(mast_node, basic_block_data_builder.current_data_offset()); + + if let MastNode::Block(basic_block) = mast_node { + basic_block_data_builder.encode_basic_block(basic_block); + } mast_node_info.write_into(target); } - // strings table - let strings = string_table_builder.into_table(&mut data); - strings.write_into(target); + let (data, string_table) = basic_block_data_builder.into_parts(); - // data blob + string_table.write_into(target); data.write_into(target); } } @@ -161,223 +161,6 @@ impl Deserializable for MastForest { } } -// TODOP: Make `MastNode` method (impl in this module)? -fn mast_node_to_info( - mast_node: &MastNode, - data: &mut Vec, - string_table_builder: &mut StringTableBuilder, -) -> MastNodeInfo { - use MastNode::*; - - let ty = MastNodeType::new(mast_node); - let digest = mast_node.digest(); - - let offset = match mast_node { - Block(basic_block) => { - let offset: u32 = data - .len() - .try_into() - .expect("MastForest serialization: data field larger than 2^32 bytes"); - - for op_or_decorator in basic_block.iter() { - match op_or_decorator { - OperationOrDecorator::Operation(operation) => encode_operation(operation, data), - OperationOrDecorator::Decorator(decorator) => { - encode_decorator(decorator, data, string_table_builder) - } - } - } - - offset - } - Join(_) | Split(_) | Loop(_) | Call(_) | Dyn | External(_) => 0, - }; - - MastNodeInfo { ty, offset, digest } -} - -fn encode_operation(operation: &Operation, data: &mut Vec) { - data.push(operation.op_code()); - - // For operations that have extra data, encode it in `data`. - match operation { - Operation::Assert(value) | Operation::MpVerify(value) => { - data.extend_from_slice(&value.to_le_bytes()) - } - Operation::U32assert2(value) | Operation::Push(value) => { - data.extend_from_slice(&value.as_int().to_le_bytes()) - } - // Note: we explicitly write out all the operations so that whenever we make a modification - // to the `Operation` enum, we get a compile error here. This should help us remember to - // properly encode/decode each operation variant. - Operation::Noop - | Operation::FmpAdd - | Operation::FmpUpdate - | Operation::SDepth - | Operation::Caller - | Operation::Clk - | Operation::Join - | Operation::Split - | Operation::Loop - | Operation::Call - | Operation::Dyn - | Operation::SysCall - | Operation::Span - | Operation::End - | Operation::Repeat - | Operation::Respan - | Operation::Halt - | Operation::Add - | Operation::Neg - | Operation::Mul - | Operation::Inv - | Operation::Incr - | Operation::And - | Operation::Or - | Operation::Not - | Operation::Eq - | Operation::Eqz - | Operation::Expacc - | Operation::Ext2Mul - | Operation::U32split - | Operation::U32add - | Operation::U32add3 - | Operation::U32sub - | Operation::U32mul - | Operation::U32madd - | Operation::U32div - | Operation::U32and - | Operation::U32xor - | Operation::Pad - | Operation::Drop - | Operation::Dup0 - | Operation::Dup1 - | Operation::Dup2 - | Operation::Dup3 - | Operation::Dup4 - | Operation::Dup5 - | Operation::Dup6 - | Operation::Dup7 - | Operation::Dup9 - | Operation::Dup11 - | Operation::Dup13 - | Operation::Dup15 - | Operation::Swap - | Operation::SwapW - | Operation::SwapW2 - | Operation::SwapW3 - | Operation::SwapDW - | Operation::MovUp2 - | Operation::MovUp3 - | Operation::MovUp4 - | Operation::MovUp5 - | Operation::MovUp6 - | Operation::MovUp7 - | Operation::MovUp8 - | Operation::MovDn2 - | Operation::MovDn3 - | Operation::MovDn4 - | Operation::MovDn5 - | Operation::MovDn6 - | Operation::MovDn7 - | Operation::MovDn8 - | Operation::CSwap - | Operation::CSwapW - | Operation::AdvPop - | Operation::AdvPopW - | Operation::MLoadW - | Operation::MStoreW - | Operation::MLoad - | Operation::MStore - | Operation::MStream - | Operation::Pipe - | Operation::HPerm - | Operation::MrUpdate - | Operation::FriE2F4 - | Operation::RCombBase => (), - } -} - -fn encode_decorator( - decorator: &Decorator, - data: &mut Vec, - string_table_builder: &mut StringTableBuilder, -) { - // Set the first byte to the decorator discriminant. - // - // Note: the most significant bit is set to 1 (to differentiate decorators from operations). - { - let decorator_variant: EncodedDecoratorVariant = decorator.into(); - data.push(decorator_variant.discriminant() | 0b1000_0000); - } - - // For decorators that have extra data, encode it in `data` and `strings`. - match decorator { - Decorator::Advice(advice_injector) => match advice_injector { - AdviceInjector::MapValueToStack { - include_len, - key_offset, - } => { - data.write_bool(*include_len); - data.write_usize(*key_offset); - } - AdviceInjector::HdwordToMap { domain } => data.extend(domain.as_int().to_le_bytes()), - - // Note: Since there is only 1 variant, we don't need to write any extra bytes. - AdviceInjector::SigToStack { kind } => match kind { - SignatureKind::RpoFalcon512 => (), - }, - AdviceInjector::MerkleNodeMerge - | AdviceInjector::MerkleNodeToStack - | AdviceInjector::UpdateMerkleNode - | AdviceInjector::U64Div - | AdviceInjector::Ext2Inv - | AdviceInjector::Ext2Intt - | AdviceInjector::SmtGet - | AdviceInjector::SmtSet - | AdviceInjector::SmtPeek - | AdviceInjector::U32Clz - | AdviceInjector::U32Ctz - | AdviceInjector::U32Clo - | AdviceInjector::U32Cto - | AdviceInjector::ILog2 - | AdviceInjector::MemToMap - | AdviceInjector::HpermToMap => (), - }, - Decorator::AsmOp(assembly_op) => { - data.push(assembly_op.num_cycles()); - data.write_bool(assembly_op.should_break()); - - // context name - { - let str_index_in_table = - string_table_builder.add_string(assembly_op.context_name()); - data.write_usize(str_index_in_table); - } - - // op - { - let str_index_in_table = string_table_builder.add_string(assembly_op.op()); - data.write_usize(str_index_in_table); - } - } - Decorator::Debug(debug_options) => match debug_options { - DebugOptions::StackTop(value) => data.push(*value), - DebugOptions::MemInterval(start, end) => { - data.extend(start.to_le_bytes()); - data.extend(end.to_le_bytes()); - } - DebugOptions::LocalInterval(start, second, end) => { - data.extend(start.to_le_bytes()); - data.extend(second.to_le_bytes()); - data.extend(end.to_le_bytes()); - } - DebugOptions::StackAll | DebugOptions::MemAll => (), - }, - Decorator::Event(value) | Decorator::Trace(value) => data.extend(value.to_le_bytes()), - } -} - // TODOP: Make `MastNodeInfo` method // TODOP: Can we not have both `data` and `data_reader`? fn try_info_to_mast_node( From 99f45442ba9b733161d02cdfe06333d9d99b907c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 17:42:41 -0400 Subject: [PATCH 086/172] add TODOP --- core/src/mast/serialization/basic_block_data_builder.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 84e163916b..886f1a09d9 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -9,6 +9,7 @@ use crate::{ use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; +/// TODOP: Document #[derive(Debug, Default)] pub struct BasicBlockDataBuilder { data: Vec, From 0d206291a8d5cda5ef02b6eff1fed341d316bb7e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:04:34 -0400 Subject: [PATCH 087/172] BasicBlockDataDecoder --- .../serialization/basic_block_data_decoder.rs | 248 ++++++++++++++++++ core/src/mast/serialization/mod.rs | 3 + 2 files changed, 251 insertions(+) create mode 100644 core/src/mast/serialization/basic_block_data_decoder.rs diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs new file mode 100644 index 0000000000..75f2ab0bf4 --- /dev/null +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -0,0 +1,248 @@ +use crate::{ + AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, OperationData, + SignatureKind, +}; + +use super::{decorator::EncodedDecoratorVariant, StringIndex, StringRef}; +use alloc::{string::String, vec::Vec}; +use miden_crypto::{Felt, ZERO}; +use winter_utils::{ByteReader, DeserializationError, SliceReader}; + +pub struct BasicBlockDataDecoder<'a> { + data: &'a [u8], + data_reader: SliceReader<'a>, + strings: &'a [StringRef], +} + +/// Constructors +impl<'a> BasicBlockDataDecoder<'a> { + pub fn new(data: &'a [u8], strings: &'a [StringRef]) -> Self { + let data_reader = SliceReader::new(data); + + Self { + data, + data_reader, + strings, + } + } +} + +/// Mutators +impl<'a> BasicBlockDataDecoder<'a> { + pub fn decode_operations_and_decorators( + &mut self, + num_to_decode: u32, + ) -> Result<(Vec, DecoratorList), DeserializationError> { + let mut operations: Vec = Vec::new(); + let mut decorators: DecoratorList = Vec::new(); + + for _ in 0..num_to_decode { + let first_byte = self.data_reader.read_u8()?; + + if first_byte & 0b1000_0000 == 0 { + // operation. + let op_code = first_byte; + + let maybe_operation = if op_code == Operation::Assert(0_u32).op_code() + || op_code == Operation::MpVerify(0_u32).op_code() + { + let value_le_bytes: [u8; 4] = self.data_reader.read_array()?; + let value = u32::from_le_bytes(value_le_bytes); + + Operation::with_opcode_and_data(op_code, OperationData::U32(value)) + } else if op_code == Operation::U32assert2(ZERO).op_code() + || op_code == Operation::Push(ZERO).op_code() + { + // Felt operation data + let value_le_bytes: [u8; 8] = self.data_reader.read_array()?; + let value_u64 = u64::from_le_bytes(value_le_bytes); + let value_felt = Felt::try_from(value_u64).map_err(|_| { + DeserializationError::InvalidValue(format!( + "Operation associated data doesn't fit in a field element: {value_u64}" + )) + })?; + + Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt)) + } else { + // No operation data + Operation::with_opcode_and_data(op_code, OperationData::None) + }; + + let operation = maybe_operation.ok_or_else(|| { + DeserializationError::InvalidValue(format!("invalid op code: {op_code}")) + })?; + + operations.push(operation); + } else { + // decorator. + let discriminant = first_byte & 0b0111_1111; + let decorator = self.decode_decorator(discriminant)?; + + decorators.push((operations.len(), decorator)); + } + } + + Ok((operations, decorators)) + } +} + +/// Helpers +impl<'a> BasicBlockDataDecoder<'a> { + fn decode_decorator(&mut self, discriminant: u8) -> Result { + let decorator_variant = EncodedDecoratorVariant::from_discriminant(discriminant) + .ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid decorator variant discriminant: {discriminant}" + )) + })?; + + match decorator_variant { + EncodedDecoratorVariant::AdviceInjectorMerkleNodeMerge => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeMerge)) + } + EncodedDecoratorVariant::AdviceInjectorMerkleNodeToStack => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeToStack)) + } + EncodedDecoratorVariant::AdviceInjectorUpdateMerkleNode => { + Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) + } + EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { + let include_len = self.data_reader.read_bool()?; + let key_offset = self.data_reader.read_usize()?; + + Ok(Decorator::Advice(AdviceInjector::MapValueToStack { + include_len, + key_offset, + })) + } + EncodedDecoratorVariant::AdviceInjectorU64Div => { + Ok(Decorator::Advice(AdviceInjector::U64Div)) + } + EncodedDecoratorVariant::AdviceInjectorExt2Inv => { + Ok(Decorator::Advice(AdviceInjector::Ext2Inv)) + } + EncodedDecoratorVariant::AdviceInjectorExt2Intt => { + Ok(Decorator::Advice(AdviceInjector::Ext2Intt)) + } + EncodedDecoratorVariant::AdviceInjectorSmtGet => { + Ok(Decorator::Advice(AdviceInjector::SmtGet)) + } + EncodedDecoratorVariant::AdviceInjectorSmtSet => { + Ok(Decorator::Advice(AdviceInjector::SmtSet)) + } + EncodedDecoratorVariant::AdviceInjectorSmtPeek => { + Ok(Decorator::Advice(AdviceInjector::SmtPeek)) + } + EncodedDecoratorVariant::AdviceInjectorU32Clz => { + Ok(Decorator::Advice(AdviceInjector::U32Clz)) + } + EncodedDecoratorVariant::AdviceInjectorU32Ctz => { + Ok(Decorator::Advice(AdviceInjector::U32Ctz)) + } + EncodedDecoratorVariant::AdviceInjectorU32Clo => { + Ok(Decorator::Advice(AdviceInjector::U32Clo)) + } + EncodedDecoratorVariant::AdviceInjectorU32Cto => { + Ok(Decorator::Advice(AdviceInjector::U32Cto)) + } + EncodedDecoratorVariant::AdviceInjectorILog2 => { + Ok(Decorator::Advice(AdviceInjector::ILog2)) + } + EncodedDecoratorVariant::AdviceInjectorMemToMap => { + Ok(Decorator::Advice(AdviceInjector::MemToMap)) + } + EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { + let domain = self.data_reader.read_u64()?; + let domain = Felt::try_from(domain).map_err(|err| { + DeserializationError::InvalidValue(format!( + "Error when deserializing HdwordToMap decorator domain: {err}" + )) + })?; + + Ok(Decorator::Advice(AdviceInjector::HdwordToMap { domain })) + } + EncodedDecoratorVariant::AdviceInjectorHpermToMap => { + Ok(Decorator::Advice(AdviceInjector::HpermToMap)) + } + EncodedDecoratorVariant::AdviceInjectorSigToStack => { + Ok(Decorator::Advice(AdviceInjector::SigToStack { + kind: SignatureKind::RpoFalcon512, + })) + } + EncodedDecoratorVariant::AssemblyOp => { + let num_cycles = self.data_reader.read_u8()?; + let should_break = self.data_reader.read_bool()?; + + let context_name = { + let str_index_in_table = self.data_reader.read_usize()?; + self.read_string(str_index_in_table)? + }; + + let op = { + let str_index_in_table = self.data_reader.read_usize()?; + self.read_string(str_index_in_table)? + }; + + Ok(Decorator::AsmOp(AssemblyOp::new(context_name, num_cycles, op, should_break))) + } + EncodedDecoratorVariant::DebugOptionsStackAll => { + Ok(Decorator::Debug(DebugOptions::StackAll)) + } + EncodedDecoratorVariant::DebugOptionsStackTop => { + let value = self.data_reader.read_u8()?; + + Ok(Decorator::Debug(DebugOptions::StackTop(value))) + } + EncodedDecoratorVariant::DebugOptionsMemAll => { + Ok(Decorator::Debug(DebugOptions::MemAll)) + } + EncodedDecoratorVariant::DebugOptionsMemInterval => { + let start = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + let end = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + + Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) + } + EncodedDecoratorVariant::DebugOptionsLocalInterval => { + let start = u16::from_le_bytes(self.data_reader.read_array::<2>()?); + let second = u16::from_le_bytes(self.data_reader.read_array::<2>()?); + let end = u16::from_le_bytes(self.data_reader.read_array::<2>()?); + + Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) + } + EncodedDecoratorVariant::Event => { + let value = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + + Ok(Decorator::Event(value)) + } + EncodedDecoratorVariant::Trace => { + let value = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + + Ok(Decorator::Trace(value)) + } + } + } + + fn read_string(&self, str_idx: StringIndex) -> Result { + let str_ref = self.strings.get(str_idx).ok_or_else(|| { + DeserializationError::InvalidValue(format!("invalid index in strings table: {str_idx}")) + })?; + + let str_bytes = { + let start = str_ref.offset as usize; + let end = (str_ref.offset + str_ref.len) as usize; + + self.data.get(start..end).ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid string ref in strings table. Offset: {}, length: {}", + str_ref.offset, str_ref.len + )) + })? + }; + + String::from_utf8(str_bytes.to_vec()).map_err(|_| { + DeserializationError::InvalidValue(format!( + "Invalid UTF-8 string in strings table: {str_bytes:?}" + )) + }) + } +} diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 357476e940..298b751c49 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -20,6 +20,9 @@ use info::{MastNodeInfo, MastNodeType}; mod basic_block_data_builder; use basic_block_data_builder::BasicBlockDataBuilder; +mod basic_block_data_decoder; +use basic_block_data_decoder::BasicBlockDataDecoder; + #[cfg(test)] mod tests; From b66f81bd362225fac0acad63a806c258c472c75a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:08:23 -0400 Subject: [PATCH 088/172] use `BasicBlockDataDecoder` --- core/src/mast/serialization/mod.rs | 259 +-------------------------- core/src/mast/serialization/tests.rs | 6 +- 2 files changed, 13 insertions(+), 252 deletions(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 298b751c49..fc0564a3e9 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,18 +1,11 @@ -use alloc::{string::String, vec::Vec}; -use miden_crypto::{Felt, ZERO}; -use winter_utils::{ - ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, -}; +use alloc::vec::Vec; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::{ - mast::MerkleTreeNode, AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, - Operation, OperationData, SignatureKind, -}; +use crate::mast::MerkleTreeNode; use super::{MastForest, MastNode, MastNodeId}; mod decorator; -use decorator::EncodedDecoratorVariant; mod info; use info::{MastNodeInfo, MastNodeType}; @@ -137,7 +130,8 @@ impl Deserializable for MastForest { let strings: Vec = Deserializable::read_from(source)?; let data: Vec = Deserializable::read_from(source)?; - let mut data_reader = SliceReader::new(&data); + + let mut basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); let mast_forest = { let mut mast_forest = MastForest::new(); @@ -146,9 +140,7 @@ impl Deserializable for MastForest { let node = try_info_to_mast_node( mast_node_info, &mast_forest, - &mut data_reader, - &data, - &strings, + &mut basic_block_data_decoder, )?; mast_forest.add_node(node); } @@ -165,24 +157,17 @@ impl Deserializable for MastForest { } // TODOP: Make `MastNodeInfo` method -// TODOP: Can we not have both `data` and `data_reader`? fn try_info_to_mast_node( mast_node_info: MastNodeInfo, mast_forest: &MastForest, - data_reader: &mut SliceReader, - data: &[u8], - strings: &[StringRef], + basic_block_data_decoder: &mut BasicBlockDataDecoder, ) -> Result { let mast_node = match mast_node_info.ty { MastNodeType::Block { len: num_operations_and_decorators, } => { - let (operations, decorators) = decode_operations_and_decorators( - num_operations_and_decorators, - data_reader, - data, - strings, - )?; + let (operations, decorators) = basic_block_data_decoder + .decode_operations_and_decorators(num_operations_and_decorators)?; Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) } @@ -233,229 +218,3 @@ fn try_info_to_mast_node( ))) } } - -fn decode_operations_and_decorators( - num_to_decode: u32, - data_reader: &mut SliceReader, - data: &[u8], - strings: &[StringRef], -) -> Result<(Vec, DecoratorList), DeserializationError> { - let mut operations: Vec = Vec::new(); - let mut decorators: DecoratorList = Vec::new(); - - for _ in 0..num_to_decode { - let first_byte = data_reader.read_u8()?; - - if first_byte & 0b1000_0000 == 0 { - // operation. - let op_code = first_byte; - - let maybe_operation = if op_code == Operation::Assert(0_u32).op_code() - || op_code == Operation::MpVerify(0_u32).op_code() - { - let value_le_bytes: [u8; 4] = data_reader.read_array()?; - let value = u32::from_le_bytes(value_le_bytes); - - Operation::with_opcode_and_data(op_code, OperationData::U32(value)) - } else if op_code == Operation::U32assert2(ZERO).op_code() - || op_code == Operation::Push(ZERO).op_code() - { - // Felt operation data - let value_le_bytes: [u8; 8] = data_reader.read_array()?; - let value_u64 = u64::from_le_bytes(value_le_bytes); - let value_felt = Felt::try_from(value_u64).map_err(|_| { - DeserializationError::InvalidValue(format!( - "Operation associated data doesn't fit in a field element: {value_u64}" - )) - })?; - - Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt)) - } else { - // No operation data - Operation::with_opcode_and_data(op_code, OperationData::None) - }; - - let operation = maybe_operation.ok_or_else(|| { - DeserializationError::InvalidValue(format!("invalid op code: {op_code}")) - })?; - - operations.push(operation); - } else { - // decorator. - let discriminant = first_byte & 0b0111_1111; - let decorator = decode_decorator(discriminant, data_reader, data, strings)?; - - decorators.push((operations.len(), decorator)); - } - } - - Ok((operations, decorators)) -} - -fn decode_decorator( - discriminant: u8, - data_reader: &mut SliceReader, - data: &[u8], - strings: &[StringRef], -) -> Result { - let decorator_variant = - EncodedDecoratorVariant::from_discriminant(discriminant).ok_or_else(|| { - DeserializationError::InvalidValue(format!( - "invalid decorator variant discriminant: {discriminant}" - )) - })?; - - match decorator_variant { - EncodedDecoratorVariant::AdviceInjectorMerkleNodeMerge => { - Ok(Decorator::Advice(AdviceInjector::MerkleNodeMerge)) - } - EncodedDecoratorVariant::AdviceInjectorMerkleNodeToStack => { - Ok(Decorator::Advice(AdviceInjector::MerkleNodeToStack)) - } - EncodedDecoratorVariant::AdviceInjectorUpdateMerkleNode => { - Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) - } - EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { - let include_len = data_reader.read_bool()?; - let key_offset = data_reader.read_usize()?; - - Ok(Decorator::Advice(AdviceInjector::MapValueToStack { - include_len, - key_offset, - })) - } - EncodedDecoratorVariant::AdviceInjectorU64Div => { - Ok(Decorator::Advice(AdviceInjector::U64Div)) - } - EncodedDecoratorVariant::AdviceInjectorExt2Inv => { - Ok(Decorator::Advice(AdviceInjector::Ext2Inv)) - } - EncodedDecoratorVariant::AdviceInjectorExt2Intt => { - Ok(Decorator::Advice(AdviceInjector::Ext2Intt)) - } - EncodedDecoratorVariant::AdviceInjectorSmtGet => { - Ok(Decorator::Advice(AdviceInjector::SmtGet)) - } - EncodedDecoratorVariant::AdviceInjectorSmtSet => { - Ok(Decorator::Advice(AdviceInjector::SmtSet)) - } - EncodedDecoratorVariant::AdviceInjectorSmtPeek => { - Ok(Decorator::Advice(AdviceInjector::SmtPeek)) - } - EncodedDecoratorVariant::AdviceInjectorU32Clz => { - Ok(Decorator::Advice(AdviceInjector::U32Clz)) - } - EncodedDecoratorVariant::AdviceInjectorU32Ctz => { - Ok(Decorator::Advice(AdviceInjector::U32Ctz)) - } - EncodedDecoratorVariant::AdviceInjectorU32Clo => { - Ok(Decorator::Advice(AdviceInjector::U32Clo)) - } - EncodedDecoratorVariant::AdviceInjectorU32Cto => { - Ok(Decorator::Advice(AdviceInjector::U32Cto)) - } - EncodedDecoratorVariant::AdviceInjectorILog2 => { - Ok(Decorator::Advice(AdviceInjector::ILog2)) - } - EncodedDecoratorVariant::AdviceInjectorMemToMap => { - Ok(Decorator::Advice(AdviceInjector::MemToMap)) - } - EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { - let domain = data_reader.read_u64()?; - let domain = Felt::try_from(domain).map_err(|err| { - DeserializationError::InvalidValue(format!( - "Error when deserializing HdwordToMap decorator domain: {err}" - )) - })?; - - Ok(Decorator::Advice(AdviceInjector::HdwordToMap { domain })) - } - EncodedDecoratorVariant::AdviceInjectorHpermToMap => { - Ok(Decorator::Advice(AdviceInjector::HpermToMap)) - } - EncodedDecoratorVariant::AdviceInjectorSigToStack => { - Ok(Decorator::Advice(AdviceInjector::SigToStack { - kind: SignatureKind::RpoFalcon512, - })) - } - EncodedDecoratorVariant::AssemblyOp => { - let num_cycles = data_reader.read_u8()?; - let should_break = data_reader.read_bool()?; - - let context_name = { - let str_index_in_table = data_reader.read_usize()?; - read_string(str_index_in_table, data, strings)? - }; - - let op = { - let str_index_in_table = data_reader.read_usize()?; - read_string(str_index_in_table, data, strings)? - }; - - Ok(Decorator::AsmOp(AssemblyOp::new(context_name, num_cycles, op, should_break))) - } - EncodedDecoratorVariant::DebugOptionsStackAll => { - Ok(Decorator::Debug(DebugOptions::StackAll)) - } - EncodedDecoratorVariant::DebugOptionsStackTop => { - let value = data_reader.read_u8()?; - - Ok(Decorator::Debug(DebugOptions::StackTop(value))) - } - EncodedDecoratorVariant::DebugOptionsMemAll => Ok(Decorator::Debug(DebugOptions::MemAll)), - EncodedDecoratorVariant::DebugOptionsMemInterval => { - let start = u32::from_le_bytes(data_reader.read_array::<4>()?); - let end = u32::from_le_bytes(data_reader.read_array::<4>()?); - - Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) - } - EncodedDecoratorVariant::DebugOptionsLocalInterval => { - let start = u16::from_le_bytes(data_reader.read_array::<2>()?); - let second = u16::from_le_bytes(data_reader.read_array::<2>()?); - let end = u16::from_le_bytes(data_reader.read_array::<2>()?); - - Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) - } - EncodedDecoratorVariant::Event => { - let value = u32::from_le_bytes(data_reader.read_array::<4>()?); - - Ok(Decorator::Event(value)) - } - EncodedDecoratorVariant::Trace => { - let value = u32::from_le_bytes(data_reader.read_array::<4>()?); - - Ok(Decorator::Trace(value)) - } - } -} - -// TODOP: Rename and/or move to some struct? -fn read_string( - str_index_in_table: usize, - data: &[u8], - strings: &[StringRef], -) -> Result { - let str_ref = strings.get(str_index_in_table).ok_or_else(|| { - DeserializationError::InvalidValue(format!( - "invalid index in strings table: {str_index_in_table}" - )) - })?; - - let str_bytes = { - let start = str_ref.offset as usize; - let end = (str_ref.offset + str_ref.len) as usize; - - data.get(start..end).ok_or_else(|| { - DeserializationError::InvalidValue(format!( - "invalid string ref in strings table. Offset: {}, length: {}", - str_ref.offset, str_ref.len - )) - })? - }; - - String::from_utf8(str_bytes.to_vec()).map_err(|_| { - DeserializationError::InvalidValue(format!( - "Invalid UTF-8 string in strings table: {str_bytes:?}" - )) - }) -} diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 2dd3f151ee..8903fd5741 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -1,9 +1,11 @@ use alloc::string::ToString; use math::FieldElement; -use miden_crypto::hash::rpo::RpoDigest; +use miden_crypto::{hash::rpo::RpoDigest, Felt}; use super::*; -use crate::operations::Operation; +use crate::{ + operations::Operation, AdviceInjector, AssemblyOp, DebugOptions, Decorator, SignatureKind, +}; /// If this test fails to compile, it means that `Operation` or `Decorator` was changed. Make sure /// that all tests in this file are updated accordingly. For example, if a new `Operation` variant From f9a3a0bed49c555ec19609b2b4b2933d8d14dcc9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:10:41 -0400 Subject: [PATCH 089/172] add headers --- core/src/mast/serialization/info.rs | 9 +++++++++ core/src/mast/serialization/mod.rs | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 955e1e9dc7..7a5aee9771 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -5,6 +5,9 @@ use crate::mast::{MastNode, MerkleTreeNode}; use super::DataOffset; +// MAST NODE INFO +// =============================================================================================== + #[derive(Debug)] pub struct MastNodeInfo { // TODOP: Remove pub(super)? @@ -49,6 +52,9 @@ impl Deserializable for MastNodeInfo { } } +// MAST NODE TYPE +// =============================================================================================== + const JOIN: u8 = 0; const SPLIT: u8 = 1; const LOOP: u8 = 2; @@ -311,6 +317,9 @@ impl MastNodeType { } } +// TESTS +// =============================================================================================== + #[cfg(test)] mod tests { use super::*; diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index fc0564a3e9..c1a2ed2ee0 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -19,12 +19,18 @@ use basic_block_data_decoder::BasicBlockDataDecoder; #[cfg(test)] mod tests; +// TYPE ALIASES +// =============================================================================================== + /// Specifies an offset into the `data` section of an encoded [`MastForest`]. type DataOffset = u32; /// Specifies an offset into the `strings` table of an encoded [`MastForest`] type StringIndex = usize; +// CONSTANTS +// =============================================================================================== + /// Magic string for detecting that a file is binary-encoded MAST. const MAGIC: &[u8; 5] = b"MAST\0"; @@ -35,6 +41,9 @@ const MAGIC: &[u8; 5] = b"MAST\0"; /// version field itself, but should be considered invalid for now. const VERSION: [u8; 3] = [0, 0, 0]; +// STRING REF +// =============================================================================================== + /// An entry in the `strings` table of an encoded [`MastForest`]. /// /// Strings are UTF8-encoded. @@ -63,6 +72,9 @@ impl Deserializable for StringRef { } } +// MAST FOREST SERIALIZATION/DESERIALIZATION +// =============================================================================================== + impl Serializable for MastForest { fn write_into(&self, target: &mut W) { let mut basic_block_data_builder = BasicBlockDataBuilder::new(); From df621c5b6871e91d0785292667331c69536ad29d Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:15:56 -0400 Subject: [PATCH 090/172] add `MastNodeInfo` method --- core/src/mast/serialization/info.rs | 73 +++++++++++++++++++++++++--- core/src/mast/serialization/mod.rs | 75 ++--------------------------- 2 files changed, 71 insertions(+), 77 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 7a5aee9771..bf89539b8e 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -1,19 +1,18 @@ use miden_crypto::hash::rpo::RpoDigest; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::{MastNode, MerkleTreeNode}; +use crate::mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}; -use super::DataOffset; +use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; // MAST NODE INFO // =============================================================================================== #[derive(Debug)] pub struct MastNodeInfo { - // TODOP: Remove pub(super)? - pub(super) ty: MastNodeType, - pub(super) offset: DataOffset, - pub(super) digest: RpoDigest, + ty: MastNodeType, + offset: DataOffset, + digest: RpoDigest, } impl MastNodeInfo { @@ -32,6 +31,68 @@ impl MastNodeInfo { digest: mast_node.digest(), } } + + pub fn try_into_mast_node( + self, + mast_forest: &MastForest, + basic_block_data_decoder: &mut BasicBlockDataDecoder, + ) -> Result { + let mast_node = match self.ty { + MastNodeType::Block { + len: num_operations_and_decorators, + } => { + let (operations, decorators) = basic_block_data_decoder + .decode_operations_and_decorators(num_operations_and_decorators)?; + + Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) + } + MastNodeType::Join { + left_child_id, + right_child_id, + } => { + let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; + let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; + + Ok(MastNode::new_join(left_child, right_child, mast_forest)) + } + MastNodeType::Split { + if_branch_id, + else_branch_id, + } => { + let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; + let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; + + Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) + } + MastNodeType::Loop { body_id } => { + let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; + + Ok(MastNode::new_loop(body_id, mast_forest)) + } + MastNodeType::Call { callee_id } => { + let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + + Ok(MastNode::new_call(callee_id, mast_forest)) + } + MastNodeType::SysCall { callee_id } => { + let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + + Ok(MastNode::new_syscall(callee_id, mast_forest)) + } + MastNodeType::Dyn => Ok(MastNode::new_dynexec()), + MastNodeType::External => Ok(MastNode::new_external(self.digest)), + }?; + + if mast_node.digest() == self.digest { + Ok(mast_node) + } else { + Err(DeserializationError::InvalidValue(format!( + "MastNodeInfo's digest '{}' doesn't match deserialized MastNode's digest '{}'", + self.digest, + mast_node.digest() + ))) + } + } } impl Serializable for MastNodeInfo { diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index c1a2ed2ee0..0caaf2f96e 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,14 +1,12 @@ use alloc::vec::Vec; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::MerkleTreeNode; - use super::{MastForest, MastNode, MastNodeId}; mod decorator; mod info; -use info::{MastNodeInfo, MastNodeType}; +use info::MastNodeInfo; mod basic_block_data_builder; use basic_block_data_builder::BasicBlockDataBuilder; @@ -149,11 +147,9 @@ impl Deserializable for MastForest { let mut mast_forest = MastForest::new(); for mast_node_info in mast_node_infos { - let node = try_info_to_mast_node( - mast_node_info, - &mast_forest, - &mut basic_block_data_decoder, - )?; + let node = mast_node_info + .try_into_mast_node(&mast_forest, &mut basic_block_data_decoder)?; + mast_forest.add_node(node); } @@ -167,66 +163,3 @@ impl Deserializable for MastForest { Ok(mast_forest) } } - -// TODOP: Make `MastNodeInfo` method -fn try_info_to_mast_node( - mast_node_info: MastNodeInfo, - mast_forest: &MastForest, - basic_block_data_decoder: &mut BasicBlockDataDecoder, -) -> Result { - let mast_node = match mast_node_info.ty { - MastNodeType::Block { - len: num_operations_and_decorators, - } => { - let (operations, decorators) = basic_block_data_decoder - .decode_operations_and_decorators(num_operations_and_decorators)?; - - Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) - } - MastNodeType::Join { - left_child_id, - right_child_id, - } => { - let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; - let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; - - Ok(MastNode::new_join(left_child, right_child, mast_forest)) - } - MastNodeType::Split { - if_branch_id, - else_branch_id, - } => { - let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; - let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; - - Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) - } - MastNodeType::Loop { body_id } => { - let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; - - Ok(MastNode::new_loop(body_id, mast_forest)) - } - MastNodeType::Call { callee_id } => { - let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; - - Ok(MastNode::new_call(callee_id, mast_forest)) - } - MastNodeType::SysCall { callee_id } => { - let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; - - Ok(MastNode::new_syscall(callee_id, mast_forest)) - } - MastNodeType::Dyn => Ok(MastNode::new_dynexec()), - MastNodeType::External => Ok(MastNode::new_external(mast_node_info.digest)), - }?; - - if mast_node.digest() == mast_node_info.digest { - Ok(mast_node) - } else { - Err(DeserializationError::InvalidValue(format!( - "MastNodeInfo's digest '{}' doesn't match deserialized MastNode's digest '{}'", - mast_node_info.digest, - mast_node.digest() - ))) - } -} From 2afd588173f25cb4a879284a8b366c34ce7bd46c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:26:02 -0400 Subject: [PATCH 091/172] return `Result` instead of `Option` --- .../serialization/basic_block_data_decoder.rs | 12 +- core/src/operations/mod.rs | 215 ++++++++++-------- 2 files changed, 119 insertions(+), 108 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index 75f2ab0bf4..b002b553f8 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -43,13 +43,13 @@ impl<'a> BasicBlockDataDecoder<'a> { // operation. let op_code = first_byte; - let maybe_operation = if op_code == Operation::Assert(0_u32).op_code() + let operation = if op_code == Operation::Assert(0_u32).op_code() || op_code == Operation::MpVerify(0_u32).op_code() { let value_le_bytes: [u8; 4] = self.data_reader.read_array()?; let value = u32::from_le_bytes(value_le_bytes); - Operation::with_opcode_and_data(op_code, OperationData::U32(value)) + Operation::with_opcode_and_data(op_code, OperationData::U32(value))? } else if op_code == Operation::U32assert2(ZERO).op_code() || op_code == Operation::Push(ZERO).op_code() { @@ -62,16 +62,12 @@ impl<'a> BasicBlockDataDecoder<'a> { )) })?; - Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt)) + Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt))? } else { // No operation data - Operation::with_opcode_and_data(op_code, OperationData::None) + Operation::with_opcode_and_data(op_code, OperationData::None)? }; - let operation = maybe_operation.ok_or_else(|| { - DeserializationError::InvalidValue(format!("invalid op code: {op_code}")) - })?; - operations.push(operation); } else { // decorator. diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 8321db1c43..ab201933bd 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -1,10 +1,12 @@ use super::Felt; use core::fmt; mod decorators; +use alloc::string::ToString; pub use decorators::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorIterator, DecoratorList, SignatureKind, }; +use winter_utils::DeserializationError; // OPERATIONS // ================================================================================================ @@ -449,123 +451,136 @@ pub enum OperationData { /// Constructors impl Operation { - // TODOP: document, and use `Result` instead? - pub fn with_opcode_and_data(opcode: u8, data: OperationData) -> Option { + // TODOP: document + pub fn with_opcode_and_data( + opcode: u8, + data: OperationData, + ) -> Result { match opcode { - 0b0000_0000 => Some(Self::Noop), - 0b0000_0001 => Some(Self::Eqz), - 0b0000_0010 => Some(Self::Neg), - 0b0000_0011 => Some(Self::Inv), - 0b0000_0100 => Some(Self::Incr), - 0b0000_0101 => Some(Self::Not), - 0b0000_0110 => Some(Self::FmpAdd), - 0b0000_0111 => Some(Self::MLoad), - 0b0000_1000 => Some(Self::Swap), - 0b0000_1001 => Some(Self::Caller), - 0b0000_1010 => Some(Self::MovUp2), - 0b0000_1011 => Some(Self::MovDn2), - 0b0000_1100 => Some(Self::MovUp3), - 0b0000_1101 => Some(Self::MovDn3), - 0b0000_1110 => Some(Self::AdvPopW), - 0b0000_1111 => Some(Self::Expacc), - - 0b0001_0000 => Some(Self::MovUp4), - 0b0001_0001 => Some(Self::MovDn4), - 0b0001_0010 => Some(Self::MovUp5), - 0b0001_0011 => Some(Self::MovDn5), - 0b0001_0100 => Some(Self::MovUp6), - 0b0001_0101 => Some(Self::MovDn6), - 0b0001_0110 => Some(Self::MovUp7), - 0b0001_0111 => Some(Self::MovDn7), - 0b0001_1000 => Some(Self::SwapW), - 0b0001_1001 => Some(Self::Ext2Mul), - 0b0001_1010 => Some(Self::MovUp8), - 0b0001_1011 => Some(Self::MovDn8), - 0b0001_1100 => Some(Self::SwapW2), - 0b0001_1101 => Some(Self::SwapW3), - 0b0001_1110 => Some(Self::SwapDW), + 0b0000_0000 => Ok(Self::Noop), + 0b0000_0001 => Ok(Self::Eqz), + 0b0000_0010 => Ok(Self::Neg), + 0b0000_0011 => Ok(Self::Inv), + 0b0000_0100 => Ok(Self::Incr), + 0b0000_0101 => Ok(Self::Not), + 0b0000_0110 => Ok(Self::FmpAdd), + 0b0000_0111 => Ok(Self::MLoad), + 0b0000_1000 => Ok(Self::Swap), + 0b0000_1001 => Ok(Self::Caller), + 0b0000_1010 => Ok(Self::MovUp2), + 0b0000_1011 => Ok(Self::MovDn2), + 0b0000_1100 => Ok(Self::MovUp3), + 0b0000_1101 => Ok(Self::MovDn3), + 0b0000_1110 => Ok(Self::AdvPopW), + 0b0000_1111 => Ok(Self::Expacc), + + 0b0001_0000 => Ok(Self::MovUp4), + 0b0001_0001 => Ok(Self::MovDn4), + 0b0001_0010 => Ok(Self::MovUp5), + 0b0001_0011 => Ok(Self::MovDn5), + 0b0001_0100 => Ok(Self::MovUp6), + 0b0001_0101 => Ok(Self::MovDn6), + 0b0001_0110 => Ok(Self::MovUp7), + 0b0001_0111 => Ok(Self::MovDn7), + 0b0001_1000 => Ok(Self::SwapW), + 0b0001_1001 => Ok(Self::Ext2Mul), + 0b0001_1010 => Ok(Self::MovUp8), + 0b0001_1011 => Ok(Self::MovDn8), + 0b0001_1100 => Ok(Self::SwapW2), + 0b0001_1101 => Ok(Self::SwapW3), + 0b0001_1110 => Ok(Self::SwapDW), // 0b0001_1111 => , 0b0010_0000 => match data { - OperationData::U32(value) => Some(Self::Assert(value)), - _ => None, + OperationData::U32(value) => Ok(Self::Assert(value)), + _ => Err(DeserializationError::InvalidValue( + "Invalid opcode data. 'Assert' opcode provided, hence expected to receive u32 data.".to_string() + )), }, - 0b0010_0001 => Some(Self::Eq), - 0b0010_0010 => Some(Self::Add), - 0b0010_0011 => Some(Self::Mul), - 0b0010_0100 => Some(Self::And), - 0b0010_0101 => Some(Self::Or), - 0b0010_0110 => Some(Self::U32and), - 0b0010_0111 => Some(Self::U32xor), - 0b0010_1000 => Some(Self::FriE2F4), - 0b0010_1001 => Some(Self::Drop), - 0b0010_1010 => Some(Self::CSwap), - 0b0010_1011 => Some(Self::CSwapW), - 0b0010_1100 => Some(Self::MLoadW), - 0b0010_1101 => Some(Self::MStore), - 0b0010_1110 => Some(Self::MStoreW), - 0b0010_1111 => Some(Self::FmpUpdate), - - 0b0011_0000 => Some(Self::Pad), - 0b0011_0001 => Some(Self::Dup0), - 0b0011_0010 => Some(Self::Dup1), - 0b0011_0011 => Some(Self::Dup2), - 0b0011_0100 => Some(Self::Dup3), - 0b0011_0101 => Some(Self::Dup4), - 0b0011_0110 => Some(Self::Dup5), - 0b0011_0111 => Some(Self::Dup6), - 0b0011_1000 => Some(Self::Dup7), - 0b0011_1001 => Some(Self::Dup9), - 0b0011_1010 => Some(Self::Dup11), - 0b0011_1011 => Some(Self::Dup13), - 0b0011_1100 => Some(Self::Dup15), - 0b0011_1101 => Some(Self::AdvPop), - 0b0011_1110 => Some(Self::SDepth), - 0b0011_1111 => Some(Self::Clk), - - 0b0100_0000 => Some(Self::U32add), - 0b0100_0010 => Some(Self::U32sub), - 0b0100_0100 => Some(Self::U32mul), - 0b0100_0110 => Some(Self::U32div), - 0b0100_1000 => Some(Self::U32split), + 0b0010_0001 => Ok(Self::Eq), + 0b0010_0010 => Ok(Self::Add), + 0b0010_0011 => Ok(Self::Mul), + 0b0010_0100 => Ok(Self::And), + 0b0010_0101 => Ok(Self::Or), + 0b0010_0110 => Ok(Self::U32and), + 0b0010_0111 => Ok(Self::U32xor), + 0b0010_1000 => Ok(Self::FriE2F4), + 0b0010_1001 => Ok(Self::Drop), + 0b0010_1010 => Ok(Self::CSwap), + 0b0010_1011 => Ok(Self::CSwapW), + 0b0010_1100 => Ok(Self::MLoadW), + 0b0010_1101 => Ok(Self::MStore), + 0b0010_1110 => Ok(Self::MStoreW), + 0b0010_1111 => Ok(Self::FmpUpdate), + + 0b0011_0000 => Ok(Self::Pad), + 0b0011_0001 => Ok(Self::Dup0), + 0b0011_0010 => Ok(Self::Dup1), + 0b0011_0011 => Ok(Self::Dup2), + 0b0011_0100 => Ok(Self::Dup3), + 0b0011_0101 => Ok(Self::Dup4), + 0b0011_0110 => Ok(Self::Dup5), + 0b0011_0111 => Ok(Self::Dup6), + 0b0011_1000 => Ok(Self::Dup7), + 0b0011_1001 => Ok(Self::Dup9), + 0b0011_1010 => Ok(Self::Dup11), + 0b0011_1011 => Ok(Self::Dup13), + 0b0011_1100 => Ok(Self::Dup15), + 0b0011_1101 => Ok(Self::AdvPop), + 0b0011_1110 => Ok(Self::SDepth), + 0b0011_1111 => Ok(Self::Clk), + + 0b0100_0000 => Ok(Self::U32add), + 0b0100_0010 => Ok(Self::U32sub), + 0b0100_0100 => Ok(Self::U32mul), + 0b0100_0110 => Ok(Self::U32div), + 0b0100_1000 => Ok(Self::U32split), 0b0100_1010 => match data { - OperationData::Felt(value) => Some(Self::U32assert2(value)), - _ => None, + OperationData::Felt(value) => Ok(Self::U32assert2(value)), + _ => Err(DeserializationError::InvalidValue( + "Invalid opcode data. 'U32assert2' opcode provided, hence expected to receive Felt data.".to_string() + )), }, - 0b0100_1100 => Some(Self::U32add3), - 0b0100_1110 => Some(Self::U32madd), + 0b0100_1100 => Ok(Self::U32add3), + 0b0100_1110 => Ok(Self::U32madd), - 0b0101_0000 => Some(Self::HPerm), + 0b0101_0000 => Ok(Self::HPerm), 0b0101_0001 => match data { - OperationData::U32(value) => Some(Self::MpVerify(value)), - _ => None, + OperationData::U32(value) => Ok(Self::MpVerify(value)), + _ => Err(DeserializationError::InvalidValue( + "Invalid opcode data. 'MpVerify' opcode provided, hence expected to receive u32 data.".to_string() + )), }, - 0b0101_0010 => Some(Self::Pipe), - 0b0101_0011 => Some(Self::MStream), - 0b0101_0100 => Some(Self::Split), - 0b0101_0101 => Some(Self::Loop), - 0b0101_0110 => Some(Self::Span), - 0b0101_0111 => Some(Self::Join), - 0b0101_1000 => Some(Self::Dyn), - 0b0101_1001 => Some(Self::RCombBase), + 0b0101_0010 => Ok(Self::Pipe), + 0b0101_0011 => Ok(Self::MStream), + 0b0101_0100 => Ok(Self::Split), + 0b0101_0101 => Ok(Self::Loop), + 0b0101_0110 => Ok(Self::Span), + 0b0101_0111 => Ok(Self::Join), + 0b0101_1000 => Ok(Self::Dyn), + 0b0101_1001 => Ok(Self::RCombBase), // 0b0101_1010 => , // 0b0101_1011 => , // 0b0101_1100 => , // 0b0101_1101 => , // 0b0101_1110 => , // 0b0101_1111 => , - 0b0110_0000 => Some(Self::MrUpdate), + 0b0110_0000 => Ok(Self::MrUpdate), 0b0110_0100 => match data { - OperationData::Felt(value) => Some(Self::Push(value)), - _ => None, + OperationData::Felt(value) => Ok(Self::Push(value)), + _ => Err(DeserializationError::InvalidValue( + "Invalid opcode data. 'Push' opcode provided, hence expected to receive Felt data.".to_string() + )), }, - 0b0110_1000 => Some(Self::SysCall), - 0b0110_1100 => Some(Self::Call), - 0b0111_0000 => Some(Self::End), - 0b0111_0100 => Some(Self::Repeat), - 0b0111_1000 => Some(Self::Respan), - 0b0111_1100 => Some(Self::Halt), - - _ => None, + 0b0110_1000 => Ok(Self::SysCall), + 0b0110_1100 => Ok(Self::Call), + 0b0111_0000 => Ok(Self::End), + 0b0111_0100 => Ok(Self::Repeat), + 0b0111_1000 => Ok(Self::Respan), + 0b0111_1100 => Ok(Self::Halt), + + _ => Err(DeserializationError::InvalidValue(format!( + "Invalid opcode {opcode}" + ))), } } } From 9106a6f580b66347e3515daf081dcfe75d3e2f42 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:26:22 -0400 Subject: [PATCH 092/172] Remove TODOP --- core/src/mast/serialization/info.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index bf89539b8e..63ffb216c9 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -439,6 +439,4 @@ mod tests { assert_eq!(if_branch_id, decoded_if_branch); assert_eq!(else_branch_id, decoded_else_branch); } - - // TODOP: Test all other variants } From fe8b7a7914b9222b4325d1245d97bd4b31ca822e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 3 Jul 2024 18:31:19 -0400 Subject: [PATCH 093/172] docs --- core/src/mast/node/basic_block_node/mod.rs | 2 +- .../serialization/basic_block_data_builder.rs | 8 ++++-- core/src/mast/serialization/decorator.rs | 5 +++- core/src/mast/serialization/info.rs | 26 ++++++++++++++----- core/src/mast/serialization/mod.rs | 2 +- core/src/operations/mod.rs | 2 +- 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index 66cc659998..e9bec6d0c0 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -223,7 +223,7 @@ impl fmt::Display for BasicBlockNode { // OPERATION OR DECORATOR // ================================================================================================ -// TODOP: Document +/// Encodes either an [`Operation`] or a [`Decorator`]. #[derive(Clone, Debug, Eq, PartialEq)] pub enum OperationOrDecorator<'a> { Operation(&'a Operation), diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 886f1a09d9..3ca9ece182 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -9,7 +9,7 @@ use crate::{ use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; -/// TODOP: Document +/// Builds the `data` section of a serialized [`crate::mast::MastForest`]. #[derive(Debug, Default)] pub struct BasicBlockDataBuilder { data: Vec, @@ -25,7 +25,9 @@ impl BasicBlockDataBuilder { /// Accessors impl BasicBlockDataBuilder { - pub fn current_data_offset(&self) -> DataOffset { + /// Returns the offset in the serialized [`crate::mast::MastForest`] data field that the next + /// [`super::MastNodeInfo`] representing a [`BasicBlockNode`] will take. + pub fn next_data_offset(&self) -> DataOffset { self.data .len() .try_into() @@ -35,6 +37,7 @@ impl BasicBlockDataBuilder { /// Mutators impl BasicBlockDataBuilder { + /// Encodes a [`BasicBlockNode`] into the serialized [`crate::mast::MastForest`] data field. pub fn encode_basic_block(&mut self, basic_block: &BasicBlockNode) { // 2nd part of `mast_node_to_info()` (inside the match) for op_or_decorator in basic_block.iter() { @@ -45,6 +48,7 @@ impl BasicBlockDataBuilder { } } + /// Returns the serialized [`crate::mast::MastForest`] data field, as well as the string table. pub fn into_parts(mut self) -> (Vec, Vec) { let string_table = self.string_table_builder.into_table(&mut self.data); (self.data, string_table) diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs index 3cb509a039..22024397ca 100644 --- a/core/src/mast/serialization/decorator.rs +++ b/core/src/mast/serialization/decorator.rs @@ -3,7 +3,10 @@ use num_traits::{FromPrimitive, ToPrimitive}; use crate::{AdviceInjector, DebugOptions, Decorator}; -/// TODOP: Document +/// Stores all the possible [`Decorator`] variants, without any associated data. +/// +/// This is effectively equivalent to a set of constants, and designed to convert between variant +/// discriminant and enum variant conveniently. #[derive(FromPrimitive, ToPrimitive)] #[repr(u8)] pub enum EncodedDecoratorVariant { diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 63ffb216c9..1dec494aac 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -8,6 +8,15 @@ use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; // MAST NODE INFO // =============================================================================================== +/// Represents a serialized [`MastNode`], with some data inlined in its [`MastNodeType`]. +/// +/// In the case of [`crate::mast::BasicBlockNode`], all its operation- and decorator-related data is +/// stored in the serialized [`MastForest`]'s `data` field at offset represented by the `offset` +/// field. For all other variants of [`MastNode`], the `offset` field is guaranteed to be 0. +/// +/// The serialized representation of [`MastNodeInfo`] is guaranteed to be fixed width, so that the +/// nodes stored in the `nodes` table of the serialzied [`MastForest`] can be accessed quickly by +/// index. #[derive(Debug)] pub struct MastNodeInfo { ty: MastNodeType, @@ -125,7 +134,12 @@ const SYSCALL: u8 = 5; const DYN: u8 = 6; const EXTERNAL: u8 = 7; -/// TODOP: Document the fact that encoded representation is always 8 bytes +/// Represents the variant of a [`MastNode`], as well as any additional data. For example, for more +/// efficient decoding, and because of the frequency with which these node types appear, we directly +/// represent the child indices for `Join`, `Split`, and `Loop`, `Call` and `SysCall` inline. +/// +/// The serialized representation of the MAST node type is guaranteed to be 8 bytes, so that +/// [`MastNodeInfo`] (which contains it) can be of fixed width. #[derive(Debug)] #[repr(u8)] pub enum MastNodeType { @@ -156,6 +170,7 @@ pub enum MastNodeType { /// Constructors impl MastNodeType { + /// Constructs a new [`MastNodeType`] from a [`MastNode`]. pub fn new(mast_node: &MastNode) -> Self { use MastNode::*; @@ -199,9 +214,9 @@ impl Serializable for MastNodeType { let mut serialized_bytes = self.inline_data_to_bytes(); // Tag is always placed in the first four bytes - let tag = self.tag(); - assert!(tag <= 0b1111); - serialized_bytes[0] |= tag << 4; + let discriminant = self.discriminant(); + assert!(discriminant <= 0b1111); + serialized_bytes[0] |= discriminant << 4; serialized_bytes }; @@ -212,7 +227,7 @@ impl Serializable for MastNodeType { /// Serialization helpers impl MastNodeType { - fn tag(&self) -> u8 { + fn discriminant(&self) -> u8 { // SAFETY: This is safe because we have given this enum a primitive representation with // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant. // @@ -240,7 +255,6 @@ impl MastNodeType { } } - // TODOP: Make a diagram of how the bits are split fn encode_join_or_split(left_child_id: u32, right_child_id: u32) -> [u8; 8] { assert!(left_child_id < 2_u32.pow(30)); assert!(right_child_id < 2_u32.pow(30)); diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 0caaf2f96e..b08b58bdda 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -90,7 +90,7 @@ impl Serializable for MastForest { // MAST node infos for mast_node in &self.nodes { let mast_node_info = - MastNodeInfo::new(mast_node, basic_block_data_builder.current_data_offset()); + MastNodeInfo::new(mast_node, basic_block_data_builder.next_data_offset()); if let MastNode::Block(basic_block) = mast_node { basic_block_data_builder.encode_basic_block(basic_block); diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index ab201933bd..078abaf72c 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -451,7 +451,7 @@ pub enum OperationData { /// Constructors impl Operation { - // TODOP: document + /// Builds an operation from its opcode and inline data (if any). pub fn with_opcode_and_data( opcode: u8, data: OperationData, From 5869525d05da14ff3d13e4f88d1d8778a57a40af Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sun, 7 Jul 2024 09:30:39 -0700 Subject: [PATCH 094/172] chore: add section separators and fix typos --- core/src/mast/mod.rs | 7 +++++-- .../src/mast/serialization/basic_block_data_builder.rs | 10 ++++++++-- .../src/mast/serialization/basic_block_data_decoder.rs | 4 ++-- core/src/mast/serialization/info.rs | 8 ++++---- core/src/mast/serialization/mod.rs | 8 ++++---- core/src/operations/decorators/mod.rs | 2 +- 6 files changed, 24 insertions(+), 15 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 1c0eb12380..c8567f9ae8 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -21,6 +21,9 @@ pub trait MerkleTreeNode { fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; } +// MAST NODE ID +// ================================================================================================ + /// An opaque handle to a [`MastNode`] in some [`MastForest`]. It is the responsibility of the user /// to use a given [`MastNodeId`] with the corresponding [`MastForest`]. /// @@ -72,7 +75,7 @@ impl Deserializable for MastNodeId { } // MAST FOREST -// =============================================================================================== +// ================================================================================================ /// Represents one or more procedures, represented as a collection of [`MastNode`]s. /// @@ -132,7 +135,7 @@ impl MastForest { /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// - /// This is the faillible version of indexing (e.g. `mast_forest[node_id]`). + /// This is the failable version of indexing (e.g. `mast_forest[node_id]`). #[inline(always)] pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { let idx = node_id.0 as usize; diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 3ca9ece182..717ebf1578 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -9,6 +9,9 @@ use crate::{ use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; +// BASIC BLOCK DATA BUILDER +// ================================================================================================ + /// Builds the `data` section of a serialized [`crate::mast::MastForest`]. #[derive(Debug, Default)] pub struct BasicBlockDataBuilder { @@ -62,8 +65,8 @@ impl BasicBlockDataBuilder { // For operations that have extra data, encode it in `data`. match operation { - Operation::Assert(value) | Operation::MpVerify(value) => { - self.data.extend_from_slice(&value.to_le_bytes()) + Operation::Assert(err_code) | Operation::MpVerify(err_code) => { + self.data.extend_from_slice(&err_code.to_le_bytes()) } Operation::U32assert2(value) | Operation::Push(value) => { self.data.extend_from_slice(&value.as_int().to_le_bytes()) @@ -240,6 +243,9 @@ impl BasicBlockDataBuilder { } } +// STRING TABLE BUILDER +// ================================================================================================ + #[derive(Debug, Default)] struct StringTableBuilder { table: Vec, diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index b002b553f8..bb6ebae0e5 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -43,8 +43,8 @@ impl<'a> BasicBlockDataDecoder<'a> { // operation. let op_code = first_byte; - let operation = if op_code == Operation::Assert(0_u32).op_code() - || op_code == Operation::MpVerify(0_u32).op_code() + let operation = if op_code == Operation::Assert(0).op_code() + || op_code == Operation::MpVerify(0).op_code() { let value_le_bytes: [u8; 4] = self.data_reader.read_array()?; let value = u32::from_le_bytes(value_le_bytes); diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 1dec494aac..4b056a3ae3 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -6,7 +6,7 @@ use crate::mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}; use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; // MAST NODE INFO -// =============================================================================================== +// ================================================================================================ /// Represents a serialized [`MastNode`], with some data inlined in its [`MastNodeType`]. /// @@ -15,7 +15,7 @@ use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; /// field. For all other variants of [`MastNode`], the `offset` field is guaranteed to be 0. /// /// The serialized representation of [`MastNodeInfo`] is guaranteed to be fixed width, so that the -/// nodes stored in the `nodes` table of the serialzied [`MastForest`] can be accessed quickly by +/// nodes stored in the `nodes` table of the serialized [`MastForest`] can be accessed quickly by /// index. #[derive(Debug)] pub struct MastNodeInfo { @@ -123,7 +123,7 @@ impl Deserializable for MastNodeInfo { } // MAST NODE TYPE -// =============================================================================================== +// ================================================================================================ const JOIN: u8 = 0; const SPLIT: u8 = 1; @@ -393,7 +393,7 @@ impl MastNodeType { } // TESTS -// =============================================================================================== +// ================================================================================================ #[cfg(test)] mod tests { diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index b08b58bdda..b9f968ae34 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -18,7 +18,7 @@ use basic_block_data_decoder::BasicBlockDataDecoder; mod tests; // TYPE ALIASES -// =============================================================================================== +// ================================================================================================ /// Specifies an offset into the `data` section of an encoded [`MastForest`]. type DataOffset = u32; @@ -27,7 +27,7 @@ type DataOffset = u32; type StringIndex = usize; // CONSTANTS -// =============================================================================================== +// ================================================================================================ /// Magic string for detecting that a file is binary-encoded MAST. const MAGIC: &[u8; 5] = b"MAST\0"; @@ -40,7 +40,7 @@ const MAGIC: &[u8; 5] = b"MAST\0"; const VERSION: [u8; 3] = [0, 0, 0]; // STRING REF -// =============================================================================================== +// ================================================================================================ /// An entry in the `strings` table of an encoded [`MastForest`]. /// @@ -71,7 +71,7 @@ impl Deserializable for StringRef { } // MAST FOREST SERIALIZATION/DESERIALIZATION -// =============================================================================================== +// ================================================================================================ impl Serializable for MastForest { fn write_into(&self, target: &mut W) { diff --git a/core/src/operations/decorators/mod.rs b/core/src/operations/decorators/mod.rs index b049cdba86..d183ada7aa 100644 --- a/core/src/operations/decorators/mod.rs +++ b/core/src/operations/decorators/mod.rs @@ -32,7 +32,7 @@ pub enum Decorator { Debug(DebugOptions), /// Emits an event to the host. Event(u32), - /// Emmits a trace to the host. + /// Emits a trace to the host. Trace(u32), } From 8cb34626348e295562968fcb17802495a7b06173 Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Wed, 10 Jul 2024 15:32:06 +0300 Subject: [PATCH 095/172] refactor: change type of the error code of u32assert2 from Felt to u32 (#1382) --- air/src/constraints/stack/op_flags/mod.rs | 2 +- assembly/src/assembler/instruction/mod.rs | 12 ++++++------ assembly/src/assembler/instruction/u32_ops.rs | 5 ++--- .../mast/serialization/basic_block_data_builder.rs | 8 ++++---- .../mast/serialization/basic_block_data_decoder.rs | 5 ++--- core/src/mast/serialization/tests.rs | 7 +++---- core/src/operations/mod.rs | 6 +++--- processor/src/operations/u32_ops.rs | 8 ++++---- 8 files changed, 25 insertions(+), 28 deletions(-) diff --git a/air/src/constraints/stack/op_flags/mod.rs b/air/src/constraints/stack/op_flags/mod.rs index e8ea0b51d7..0341537447 100644 --- a/air/src/constraints/stack/op_flags/mod.rs +++ b/air/src/constraints/stack/op_flags/mod.rs @@ -840,7 +840,7 @@ impl OpFlags { /// Operation Flag of U32ASSERT2 operation. #[inline(always)] pub fn u32assert2(&self) -> E { - self.degree6_op_flags[get_op_index(Operation::U32assert2(ZERO).op_code())] + self.degree6_op_flags[get_op_index(Operation::U32assert2(0).op_code())] } /// Operation Flag of U32ADD3 operation. diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 48f01ce203..c79f406d6f 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -123,17 +123,17 @@ impl Assembler { // ----- u32 manipulation ------------------------------------------------------------- Instruction::U32Test => span_builder.push_ops([Dup0, U32split, Swap, Drop, Eqz]), Instruction::U32TestW => u32_ops::u32testw(span_builder), - Instruction::U32Assert => span_builder.push_ops([Pad, U32assert2(ZERO), Drop]), + Instruction::U32Assert => span_builder.push_ops([Pad, U32assert2(0), Drop]), Instruction::U32AssertWithError(err_code) => { - span_builder.push_ops([Pad, U32assert2(Felt::from(err_code.expect_value())), Drop]) + span_builder.push_ops([Pad, U32assert2(err_code.expect_value()), Drop]) } - Instruction::U32Assert2 => span_builder.push_op(U32assert2(ZERO)), + Instruction::U32Assert2 => span_builder.push_op(U32assert2(0)), Instruction::U32Assert2WithError(err_code) => { - span_builder.push_op(U32assert2(Felt::from(err_code.expect_value()))) + span_builder.push_op(U32assert2(err_code.expect_value())) } - Instruction::U32AssertW => u32_ops::u32assertw(span_builder, ZERO), + Instruction::U32AssertW => u32_ops::u32assertw(span_builder, 0), Instruction::U32AssertWWithError(err_code) => { - u32_ops::u32assertw(span_builder, Felt::from(err_code.expect_value())) + u32_ops::u32assertw(span_builder, err_code.expect_value()) } Instruction::U32Cast => span_builder.push_ops([U32split, Drop]), diff --git a/assembly/src/assembler/instruction/u32_ops.rs b/assembly/src/assembler/instruction/u32_ops.rs index e5b5575728..633044aa21 100644 --- a/assembly/src/assembler/instruction/u32_ops.rs +++ b/assembly/src/assembler/instruction/u32_ops.rs @@ -6,7 +6,6 @@ use crate::{ use vm_core::{ AdviceInjector, Felt, Operation::{self, *}, - ZERO, }; /// This enum is intended to determine the mode of operation passed to the parsing function @@ -45,7 +44,7 @@ pub fn u32testw(span_builder: &mut BasicBlockBuilder) { /// /// Implemented by executing `U32ASSERT2` on each pair of elements in the word. /// Total of 6 VM cycles. -pub fn u32assertw(span_builder: &mut BasicBlockBuilder, err_code: Felt) { +pub fn u32assertw(span_builder: &mut BasicBlockBuilder, err_code: u32) { #[rustfmt::skip] let ops = [ // Test the first and the second elements @@ -171,7 +170,7 @@ pub fn u32not(span_builder: &mut BasicBlockBuilder) { let ops = [ // Perform the operation Push(Felt::from(u32::MAX)), - U32assert2(ZERO), + U32assert2(0), Swap, U32sub, diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 717ebf1578..0f28587ce1 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -65,12 +65,12 @@ impl BasicBlockDataBuilder { // For operations that have extra data, encode it in `data`. match operation { - Operation::Assert(err_code) | Operation::MpVerify(err_code) => { + Operation::Assert(err_code) + | Operation::MpVerify(err_code) + | Operation::U32assert2(err_code) => { self.data.extend_from_slice(&err_code.to_le_bytes()) } - Operation::U32assert2(value) | Operation::Push(value) => { - self.data.extend_from_slice(&value.as_int().to_le_bytes()) - } + Operation::Push(value) => self.data.extend_from_slice(&value.as_int().to_le_bytes()), // Note: we explicitly write out all the operations so that whenever we make a // modification to the `Operation` enum, we get a compile error here. This // should help us remember to properly encode/decode each operation variant. diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index bb6ebae0e5..1c12befe69 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -45,14 +45,13 @@ impl<'a> BasicBlockDataDecoder<'a> { let operation = if op_code == Operation::Assert(0).op_code() || op_code == Operation::MpVerify(0).op_code() + || op_code == Operation::U32assert2(0).op_code() { let value_le_bytes: [u8; 4] = self.data_reader.read_array()?; let value = u32::from_le_bytes(value_le_bytes); Operation::with_opcode_and_data(op_code, OperationData::U32(value))? - } else if op_code == Operation::U32assert2(ZERO).op_code() - || op_code == Operation::Push(ZERO).op_code() - { + } else if op_code == Operation::Push(ZERO).op_code() { // Felt operation data let value_le_bytes: [u8; 8] = self.data_reader.read_array()?; let value_u64 = u64::from_le_bytes(value_le_bytes); diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 8903fd5741..4f0c56bcb8 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -1,5 +1,4 @@ use alloc::string::ToString; -use math::FieldElement; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use super::*; @@ -13,7 +12,7 @@ use crate::{ /// [`serialize_deserialize_all_nodes`]. #[test] fn confirm_operation_and_decorator_structure() { - let _ = match Operation::Noop { + match Operation::Noop { Operation::Noop => (), Operation::Assert(_) => (), Operation::FmpAdd => (), @@ -105,7 +104,7 @@ fn confirm_operation_and_decorator_structure() { Operation::RCombBase => (), }; - let _ = match Decorator::Event(0) { + match Decorator::Event(0) { Decorator::Advice(advice) => match advice { AdviceInjector::MerkleNodeMerge => (), AdviceInjector::MerkleNodeToStack => (), @@ -181,7 +180,7 @@ fn serialize_deserialize_all_nodes() { Operation::Ext2Mul, Operation::U32split, Operation::U32add, - Operation::U32assert2(Felt::ONE), + Operation::U32assert2(222), Operation::U32add3, Operation::U32sub, Operation::U32mul, diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 078abaf72c..6eff46e2d9 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -164,7 +164,7 @@ pub enum Operation { /// /// The internal value specifies an error code associated with the error in case when the /// assertion fails. - U32assert2(Felt), + U32assert2(u32), /// Pops three elements off the stack, adds them together, and splits the result into upper /// and lower 32-bit values. Then pushes the result back onto the stack. @@ -535,9 +535,9 @@ impl Operation { 0b0100_0110 => Ok(Self::U32div), 0b0100_1000 => Ok(Self::U32split), 0b0100_1010 => match data { - OperationData::Felt(value) => Ok(Self::U32assert2(value)), + OperationData::U32(value) => Ok(Self::U32assert2(value)), _ => Err(DeserializationError::InvalidValue( - "Invalid opcode data. 'U32assert2' opcode provided, hence expected to receive Felt data.".to_string() + "Invalid opcode data. 'U32assert2' opcode provided, hence expected to receive u32 data.".to_string() )), }, 0b0100_1100 => Ok(Self::U32add3), diff --git a/processor/src/operations/u32_ops.rs b/processor/src/operations/u32_ops.rs index 0b4f3a0198..d525ab4a9a 100644 --- a/processor/src/operations/u32_ops.rs +++ b/processor/src/operations/u32_ops.rs @@ -28,15 +28,15 @@ where /// Pops top two element off the stack, splits them into low and high 32-bit values, checks if /// the high values are equal to 0; if they are, puts the original elements back onto the /// stack; if they are not, returns an error. - pub(super) fn op_u32assert2(&mut self, err_code: Felt) -> Result<(), ExecutionError> { + pub(super) fn op_u32assert2(&mut self, err_code: u32) -> Result<(), ExecutionError> { let a = self.stack.get(0); let b = self.stack.get(1); if a.as_int() >> 32 != 0 { - return Err(ExecutionError::NotU32Value(a, err_code)); + return Err(ExecutionError::NotU32Value(a, Felt::from(err_code))); } if b.as_int() >> 32 != 0 { - return Err(ExecutionError::NotU32Value(b, err_code)); + return Err(ExecutionError::NotU32Value(b, Felt::from(err_code))); } self.add_range_checks(Operation::U32assert2(err_code), a, b, false); @@ -280,7 +280,7 @@ mod tests { let stack = StackInputs::try_from_ints([d as u64, c as u64, b as u64, a as u64]).unwrap(); let mut process = Process::new_dummy_with_decoder_helpers(stack); - process.execute_op(Operation::U32assert2(ZERO)).unwrap(); + process.execute_op(Operation::U32assert2(0)).unwrap(); let expected = build_expected(&[a, b, c, d]); assert_eq!(expected, process.stack.trace_state()); } From 5ec4826b5ffce218c8f1035a93ab7d41935be4fb Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 08:48:16 -0400 Subject: [PATCH 096/172] impl `Serializable` for `Operation` --- .../serialization/basic_block_data_builder.rs | 108 +----------------- core/src/operations/mod.rs | 107 ++++++++++++++++- 2 files changed, 109 insertions(+), 106 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 0f28587ce1..c86ed1511f 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -1,10 +1,10 @@ use alloc::{collections::BTreeMap, vec::Vec}; use miden_crypto::hash::rpo::{Rpo256, RpoDigest}; -use winter_utils::ByteWriter; +use winter_utils::{ByteWriter, Serializable}; use crate::{ mast::{BasicBlockNode, OperationOrDecorator}, - AdviceInjector, DebugOptions, Decorator, Operation, SignatureKind, + AdviceInjector, DebugOptions, Decorator, SignatureKind, }; use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; @@ -45,7 +45,7 @@ impl BasicBlockDataBuilder { // 2nd part of `mast_node_to_info()` (inside the match) for op_or_decorator in basic_block.iter() { match op_or_decorator { - OperationOrDecorator::Operation(operation) => self.encode_operation(operation), + OperationOrDecorator::Operation(operation) => operation.write_into(&mut self.data), OperationOrDecorator::Decorator(decorator) => self.encode_decorator(decorator), } } @@ -60,108 +60,6 @@ impl BasicBlockDataBuilder { /// Helpers impl BasicBlockDataBuilder { - fn encode_operation(&mut self, operation: &Operation) { - self.data.push(operation.op_code()); - - // For operations that have extra data, encode it in `data`. - match operation { - Operation::Assert(err_code) - | Operation::MpVerify(err_code) - | Operation::U32assert2(err_code) => { - self.data.extend_from_slice(&err_code.to_le_bytes()) - } - Operation::Push(value) => self.data.extend_from_slice(&value.as_int().to_le_bytes()), - // Note: we explicitly write out all the operations so that whenever we make a - // modification to the `Operation` enum, we get a compile error here. This - // should help us remember to properly encode/decode each operation variant. - Operation::Noop - | Operation::FmpAdd - | Operation::FmpUpdate - | Operation::SDepth - | Operation::Caller - | Operation::Clk - | Operation::Join - | Operation::Split - | Operation::Loop - | Operation::Call - | Operation::Dyn - | Operation::SysCall - | Operation::Span - | Operation::End - | Operation::Repeat - | Operation::Respan - | Operation::Halt - | Operation::Add - | Operation::Neg - | Operation::Mul - | Operation::Inv - | Operation::Incr - | Operation::And - | Operation::Or - | Operation::Not - | Operation::Eq - | Operation::Eqz - | Operation::Expacc - | Operation::Ext2Mul - | Operation::U32split - | Operation::U32add - | Operation::U32add3 - | Operation::U32sub - | Operation::U32mul - | Operation::U32madd - | Operation::U32div - | Operation::U32and - | Operation::U32xor - | Operation::Pad - | Operation::Drop - | Operation::Dup0 - | Operation::Dup1 - | Operation::Dup2 - | Operation::Dup3 - | Operation::Dup4 - | Operation::Dup5 - | Operation::Dup6 - | Operation::Dup7 - | Operation::Dup9 - | Operation::Dup11 - | Operation::Dup13 - | Operation::Dup15 - | Operation::Swap - | Operation::SwapW - | Operation::SwapW2 - | Operation::SwapW3 - | Operation::SwapDW - | Operation::MovUp2 - | Operation::MovUp3 - | Operation::MovUp4 - | Operation::MovUp5 - | Operation::MovUp6 - | Operation::MovUp7 - | Operation::MovUp8 - | Operation::MovDn2 - | Operation::MovDn3 - | Operation::MovDn4 - | Operation::MovDn5 - | Operation::MovDn6 - | Operation::MovDn7 - | Operation::MovDn8 - | Operation::CSwap - | Operation::CSwapW - | Operation::AdvPop - | Operation::AdvPopW - | Operation::MLoadW - | Operation::MStoreW - | Operation::MLoad - | Operation::MStore - | Operation::MStream - | Operation::Pipe - | Operation::HPerm - | Operation::MrUpdate - | Operation::FriE2F4 - | Operation::RCombBase => (), - } - } - fn encode_decorator(&mut self, decorator: &Decorator) { // Set the first byte to the decorator discriminant. // diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 6eff46e2d9..8d6a8696e0 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -6,7 +6,7 @@ pub use decorators::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorIterator, DecoratorList, SignatureKind, }; -use winter_utils::DeserializationError; +use winter_utils::{ByteWriter, DeserializationError, Serializable}; // OPERATIONS // ================================================================================================ @@ -869,3 +869,108 @@ impl fmt::Display for Operation { } } } + +impl Serializable for Operation { + fn write_into(&self, target: &mut W) { + target.write_u8(self.op_code()); + + // For operations that have extra data, encode it in `data`. + match self { + Operation::Assert(err_code) + | Operation::MpVerify(err_code) + | Operation::U32assert2(err_code) => { + err_code.to_le_bytes().write_into(target); + } + Operation::Push(value) => value.as_int().write_into(target), + + // Note: we explicitly write out all the operations so that whenever we make a + // modification to the `Operation` enum, we get a compile error here. This + // should help us remember to properly encode/decode each operation variant. + Operation::Noop + | Operation::FmpAdd + | Operation::FmpUpdate + | Operation::SDepth + | Operation::Caller + | Operation::Clk + | Operation::Join + | Operation::Split + | Operation::Loop + | Operation::Call + | Operation::Dyn + | Operation::SysCall + | Operation::Span + | Operation::End + | Operation::Repeat + | Operation::Respan + | Operation::Halt + | Operation::Add + | Operation::Neg + | Operation::Mul + | Operation::Inv + | Operation::Incr + | Operation::And + | Operation::Or + | Operation::Not + | Operation::Eq + | Operation::Eqz + | Operation::Expacc + | Operation::Ext2Mul + | Operation::U32split + | Operation::U32add + | Operation::U32add3 + | Operation::U32sub + | Operation::U32mul + | Operation::U32madd + | Operation::U32div + | Operation::U32and + | Operation::U32xor + | Operation::Pad + | Operation::Drop + | Operation::Dup0 + | Operation::Dup1 + | Operation::Dup2 + | Operation::Dup3 + | Operation::Dup4 + | Operation::Dup5 + | Operation::Dup6 + | Operation::Dup7 + | Operation::Dup9 + | Operation::Dup11 + | Operation::Dup13 + | Operation::Dup15 + | Operation::Swap + | Operation::SwapW + | Operation::SwapW2 + | Operation::SwapW3 + | Operation::SwapDW + | Operation::MovUp2 + | Operation::MovUp3 + | Operation::MovUp4 + | Operation::MovUp5 + | Operation::MovUp6 + | Operation::MovUp7 + | Operation::MovUp8 + | Operation::MovDn2 + | Operation::MovDn3 + | Operation::MovDn4 + | Operation::MovDn5 + | Operation::MovDn6 + | Operation::MovDn7 + | Operation::MovDn8 + | Operation::CSwap + | Operation::CSwapW + | Operation::AdvPop + | Operation::AdvPopW + | Operation::MLoadW + | Operation::MStoreW + | Operation::MLoad + | Operation::MStore + | Operation::MStream + | Operation::Pipe + | Operation::HPerm + | Operation::MrUpdate + | Operation::FriE2F4 + | Operation::RCombBase => (), + } + } +} From d4a50e95d27fdc4f69fa32728b2a6e500fc4cf0c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 08:57:12 -0400 Subject: [PATCH 097/172] impl Deserializable for `Operation` --- .../serialization/basic_block_data_decoder.rs | 49 +++++-------------- core/src/operations/mod.rs | 35 ++++++++++++- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index 1c12befe69..b25682822c 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -1,12 +1,11 @@ use crate::{ - AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, OperationData, - SignatureKind, + AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, SignatureKind, }; use super::{decorator::EncodedDecoratorVariant, StringIndex, StringRef}; use alloc::{string::String, vec::Vec}; -use miden_crypto::{Felt, ZERO}; -use winter_utils::{ByteReader, DeserializationError, SliceReader}; +use miden_crypto::Felt; +use winter_utils::{ByteReader, Deserializable, DeserializationError, SliceReader}; pub struct BasicBlockDataDecoder<'a> { data: &'a [u8], @@ -37,42 +36,14 @@ impl<'a> BasicBlockDataDecoder<'a> { let mut decorators: DecoratorList = Vec::new(); for _ in 0..num_to_decode { - let first_byte = self.data_reader.read_u8()?; + let first_byte = self.data_reader.peek_u8()?; if first_byte & 0b1000_0000 == 0 { // operation. - let op_code = first_byte; - - let operation = if op_code == Operation::Assert(0).op_code() - || op_code == Operation::MpVerify(0).op_code() - || op_code == Operation::U32assert2(0).op_code() - { - let value_le_bytes: [u8; 4] = self.data_reader.read_array()?; - let value = u32::from_le_bytes(value_le_bytes); - - Operation::with_opcode_and_data(op_code, OperationData::U32(value))? - } else if op_code == Operation::Push(ZERO).op_code() { - // Felt operation data - let value_le_bytes: [u8; 8] = self.data_reader.read_array()?; - let value_u64 = u64::from_le_bytes(value_le_bytes); - let value_felt = Felt::try_from(value_u64).map_err(|_| { - DeserializationError::InvalidValue(format!( - "Operation associated data doesn't fit in a field element: {value_u64}" - )) - })?; - - Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt))? - } else { - // No operation data - Operation::with_opcode_and_data(op_code, OperationData::None)? - }; - - operations.push(operation); + operations.push(Operation::read_from(&mut self.data_reader)?); } else { // decorator. - let discriminant = first_byte & 0b0111_1111; - let decorator = self.decode_decorator(discriminant)?; - + let decorator = self.decode_decorator()?; decorators.push((operations.len(), decorator)); } } @@ -83,7 +54,13 @@ impl<'a> BasicBlockDataDecoder<'a> { /// Helpers impl<'a> BasicBlockDataDecoder<'a> { - fn decode_decorator(&mut self, discriminant: u8) -> Result { + fn decode_decorator(&mut self) -> Result { + let discriminant = { + let first_byte = self.data_reader.read_u8()?; + + first_byte & 0b0111_1111 + }; + let decorator_variant = EncodedDecoratorVariant::from_discriminant(discriminant) .ok_or_else(|| { DeserializationError::InvalidValue(format!( diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 8d6a8696e0..f816f5e12b 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -6,7 +6,8 @@ pub use decorators::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorIterator, DecoratorList, SignatureKind, }; -use winter_utils::{ByteWriter, DeserializationError, Serializable}; +use miden_crypto::ZERO; +use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; // OPERATIONS // ================================================================================================ @@ -974,3 +975,35 @@ impl Serializable for Operation { } } } + +impl Deserializable for Operation { + fn read_from(source: &mut R) -> Result { + let op_code = source.read_u8()?; + + let operation = if op_code == Operation::Assert(0).op_code() + || op_code == Operation::MpVerify(0).op_code() + || op_code == Operation::U32assert2(0).op_code() + { + let value_le_bytes: [u8; 4] = source.read_array()?; + let value = u32::from_le_bytes(value_le_bytes); + + Operation::with_opcode_and_data(op_code, OperationData::U32(value))? + } else if op_code == Operation::Push(ZERO).op_code() { + // Felt operation data + let value_le_bytes: [u8; 8] = source.read_array()?; + let value_u64 = u64::from_le_bytes(value_le_bytes); + let value_felt = Felt::try_from(value_u64).map_err(|_| { + DeserializationError::InvalidValue(format!( + "Operation associated data doesn't fit in a field element: {value_u64}" + )) + })?; + + Operation::with_opcode_and_data(op_code, OperationData::Felt(value_felt))? + } else { + // No operation data + Operation::with_opcode_and_data(op_code, OperationData::None)? + }; + + Ok(operation) + } +} From 3ce71b978895734447d4f73839dbce7d3292d5c1 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 09:12:09 -0400 Subject: [PATCH 098/172] `StringTableBuilder`: switch to using blake 3 --- core/src/mast/serialization/basic_block_data_builder.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index c86ed1511f..242afa4b95 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -1,5 +1,5 @@ use alloc::{collections::BTreeMap, vec::Vec}; -use miden_crypto::hash::rpo::{Rpo256, RpoDigest}; +use miden_crypto::hash::blake::{Blake3Digest, Blake3_256}; use winter_utils::{ByteWriter, Serializable}; use crate::{ @@ -147,13 +147,13 @@ impl BasicBlockDataBuilder { #[derive(Debug, Default)] struct StringTableBuilder { table: Vec, - str_to_index: BTreeMap, + str_to_index: BTreeMap, StringIndex>, strings_data: Vec, } impl StringTableBuilder { pub fn add_string(&mut self, string: &str) -> StringIndex { - if let Some(str_idx) = self.str_to_index.get(&Rpo256::hash(string.as_bytes())) { + if let Some(str_idx) = self.str_to_index.get(&Blake3_256::hash(string.as_bytes())) { // return already interned string *str_idx } else { @@ -171,7 +171,7 @@ impl StringTableBuilder { self.strings_data.extend(string.as_bytes()); self.table.push(str_ref); - self.str_to_index.insert(Rpo256::hash(string.as_bytes()), str_idx); + self.str_to_index.insert(Blake3_256::hash(string.as_bytes()), str_idx); str_idx } From db33dc73e5599923024861a13512f6561f631532 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 09:34:01 -0400 Subject: [PATCH 099/172] `EncodedDecoratorVariant`: moved discriminant bit logic to `discriminant()` method --- .../mast/serialization/basic_block_data_builder.rs | 4 +--- .../mast/serialization/basic_block_data_decoder.rs | 6 +----- core/src/mast/serialization/decorator.rs | 11 +++++++++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 242afa4b95..b280f88065 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -62,11 +62,9 @@ impl BasicBlockDataBuilder { impl BasicBlockDataBuilder { fn encode_decorator(&mut self, decorator: &Decorator) { // Set the first byte to the decorator discriminant. - // - // Note: the most significant bit is set to 1 (to differentiate decorators from operations). { let decorator_variant: EncodedDecoratorVariant = decorator.into(); - self.data.push(decorator_variant.discriminant() | 0b1000_0000); + self.data.push(decorator_variant.discriminant()); } // For decorators that have extra data, encode it in `data` and `strings`. diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index b25682822c..83f85bca81 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -55,11 +55,7 @@ impl<'a> BasicBlockDataDecoder<'a> { /// Helpers impl<'a> BasicBlockDataDecoder<'a> { fn decode_decorator(&mut self) -> Result { - let discriminant = { - let first_byte = self.data_reader.read_u8()?; - - first_byte & 0b0111_1111 - }; + let discriminant = self.data_reader.read_u8()?; let decorator_variant = EncodedDecoratorVariant::from_discriminant(discriminant) .ok_or_else(|| { diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs index 22024397ca..c1d9b2f0f1 100644 --- a/core/src/mast/serialization/decorator.rs +++ b/core/src/mast/serialization/decorator.rs @@ -40,12 +40,19 @@ pub enum EncodedDecoratorVariant { } impl EncodedDecoratorVariant { + /// Returns the discriminant of the given decorator variant. + /// + /// To distinguish them from [`crate::Operation`] discriminants, the most significant bit of + /// decorator discriminant is always set to 1. pub fn discriminant(&self) -> u8 { - self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") + let discriminant = self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]"); + + discriminant | 0b1000_0000 } + /// The inverse operation of [`Self::discriminant`]. pub fn from_discriminant(discriminant: u8) -> Option { - Self::from_u8(discriminant) + Self::from_u8(discriminant & 0b0111_1111) } } From 10f02a6ecf7143c533594e4593b09e218a39009e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 10:01:06 -0400 Subject: [PATCH 100/172] Remove basic block offset --- .../serialization/basic_block_data_builder.rs | 14 +----------- core/src/mast/serialization/info.rs | 22 ++++++------------- core/src/mast/serialization/mod.rs | 3 +-- 3 files changed, 9 insertions(+), 30 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index b280f88065..1f675a8b5a 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -7,7 +7,7 @@ use crate::{ AdviceInjector, DebugOptions, Decorator, SignatureKind, }; -use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; +use super::{decorator::EncodedDecoratorVariant, StringIndex, StringRef}; // BASIC BLOCK DATA BUILDER // ================================================================================================ @@ -26,18 +26,6 @@ impl BasicBlockDataBuilder { } } -/// Accessors -impl BasicBlockDataBuilder { - /// Returns the offset in the serialized [`crate::mast::MastForest`] data field that the next - /// [`super::MastNodeInfo`] representing a [`BasicBlockNode`] will take. - pub fn next_data_offset(&self) -> DataOffset { - self.data - .len() - .try_into() - .expect("MAST forest data segment larger than 2^32 bytes") - } -} - /// Mutators impl BasicBlockDataBuilder { /// Encodes a [`BasicBlockNode`] into the serialized [`crate::mast::MastForest`] data field. diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 4b056a3ae3..af41c67a45 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -3,7 +3,7 @@ use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, use crate::mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}; -use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; +use super::basic_block_data_decoder::BasicBlockDataDecoder; // MAST NODE INFO // ================================================================================================ @@ -20,23 +20,15 @@ use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; #[derive(Debug)] pub struct MastNodeInfo { ty: MastNodeType, - offset: DataOffset, digest: RpoDigest, } impl MastNodeInfo { - pub fn new(mast_node: &MastNode, basic_block_offset: DataOffset) -> Self { + pub fn new(mast_node: &MastNode) -> Self { let ty = MastNodeType::new(mast_node); - let offset = if let MastNode::Block(_) = mast_node { - basic_block_offset - } else { - 0 - }; - Self { ty, - offset, digest: mast_node.digest(), } } @@ -106,19 +98,19 @@ impl MastNodeInfo { impl Serializable for MastNodeInfo { fn write_into(&self, target: &mut W) { - self.ty.write_into(target); - self.offset.write_into(target); - self.digest.write_into(target); + let Self { ty, digest } = self; + + ty.write_into(target); + digest.write_into(target); } } impl Deserializable for MastNodeInfo { fn read_from(source: &mut R) -> Result { let ty = Deserializable::read_from(source)?; - let offset = DataOffset::read_from(source)?; let digest = RpoDigest::read_from(source)?; - Ok(Self { ty, offset, digest }) + Ok(Self { ty, digest }) } } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index b9f968ae34..238cb33dca 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -89,8 +89,7 @@ impl Serializable for MastForest { // MAST node infos for mast_node in &self.nodes { - let mast_node_info = - MastNodeInfo::new(mast_node, basic_block_data_builder.next_data_offset()); + let mast_node_info = MastNodeInfo::new(mast_node); if let MastNode::Block(basic_block) = mast_node { basic_block_data_builder.encode_basic_block(basic_block); From 421518d5cf16e713bc473843430c2a92b040c1bf Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 10:06:11 -0400 Subject: [PATCH 101/172] Cargo: don't specify patch versions --- core/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index 91ca9d4d75..f87fa0c6df 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -30,8 +30,8 @@ std = [ math = { package = "winter-math", version = "0.9", default-features = false } miden-crypto = { git = "https://github.com/0xPolygonMiden/crypto", branch = "next", default-features = false } miden-formatting = { version = "0.1", default-features = false } -num-derive = "0.4.2" -num-traits = "0.2.19" +num-derive = "0.4" +num-traits = "0.2" thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } winter-utils = { package = "winter-utils", version = "0.9", default-features = false } From cf100c50c92e8a690f477a28c1ce8711eaf14fb7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 10:26:43 -0400 Subject: [PATCH 102/172] make deserialization more efficient --- core/src/mast/serialization/mod.rs | 43 +++++++++++++++--------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 238cb33dca..946793b275 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -87,21 +87,30 @@ impl Serializable for MastForest { // roots self.roots.write_into(target); - // MAST node infos - for mast_node in &self.nodes { - let mast_node_info = MastNodeInfo::new(mast_node); - - if let MastNode::Block(basic_block) = mast_node { - basic_block_data_builder.encode_basic_block(basic_block); - } - - mast_node_info.write_into(target); - } + // Prepare MAST node infos, but don't store them yet. We store them at the end to make + // deserialization more efficient. + let mast_node_infos: Vec = self + .nodes + .iter() + .map(|mast_node| { + let mast_node_info = MastNodeInfo::new(mast_node); + + if let MastNode::Block(basic_block) = mast_node { + basic_block_data_builder.encode_basic_block(basic_block); + } + + mast_node_info + }) + .collect(); let (data, string_table) = basic_block_data_builder.into_parts(); string_table.write_into(target); data.write_into(target); + + for mast_node_info in mast_node_infos { + mast_node_info.write_into(target); + } } } @@ -126,16 +135,6 @@ impl Deserializable for MastForest { let roots: Vec = Deserializable::read_from(source)?; - let mast_node_infos = { - let mut mast_node_infos = Vec::with_capacity(node_count); - for _ in 0..node_count { - let mast_node_info = MastNodeInfo::read_from(source)?; - mast_node_infos.push(mast_node_info); - } - - mast_node_infos - }; - let strings: Vec = Deserializable::read_from(source)?; let data: Vec = Deserializable::read_from(source)?; @@ -145,7 +144,9 @@ impl Deserializable for MastForest { let mast_forest = { let mut mast_forest = MastForest::new(); - for mast_node_info in mast_node_infos { + for _ in 0..node_count { + let mast_node_info = MastNodeInfo::read_from(source)?; + let node = mast_node_info .try_into_mast_node(&mast_forest, &mut basic_block_data_decoder)?; From f50073d29d1874543b2bd72e14592243fb903295 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 10 Jul 2024 11:40:40 -0400 Subject: [PATCH 103/172] num-traits and num-derive: set default-features false --- core/Cargo.toml | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index f87fa0c6df..f18191053d 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -18,20 +18,14 @@ doctest = false [features] default = ["std"] -std = [ - "miden-crypto/std", - "miden-formatting/std", - "math/std", - "winter-utils/std", - "thiserror/std", -] +std = ["miden-crypto/std", "miden-formatting/std", "math/std", "winter-utils/std", "thiserror/std"] [dependencies] math = { package = "winter-math", version = "0.9", default-features = false } miden-crypto = { git = "https://github.com/0xPolygonMiden/crypto", branch = "next", default-features = false } miden-formatting = { version = "0.1", default-features = false } -num-derive = "0.4" -num-traits = "0.2" +num-derive = { version = "0.4", default-features = false } +num-traits = { version = "0.2", default-features = false } thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } winter-utils = { package = "winter-utils", version = "0.9", default-features = false } From 956aac156b2c65e547fe9e7f9a68216e9936087f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 11 Jul 2024 09:47:51 -0400 Subject: [PATCH 104/172] Remove `OperationData` --- core/src/operations/mod.rs | 143 ------------------------------------- 1 file changed, 143 deletions(-) diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 22d2b9b4d3..b1829ddeef 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -1,7 +1,6 @@ use super::Felt; use core::fmt; mod decorators; -use alloc::string::ToString; pub use decorators::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorIterator, DecoratorList, SignatureKind, @@ -556,148 +555,6 @@ pub enum Operation { RCombBase = OPCODE_RCOMBBASE, } -pub enum OperationData { - Felt(Felt), - U32(u32), - None, -} - -/// Constructors -impl Operation { - /// Builds an operation from its opcode and inline data (if any). - pub fn with_opcode_and_data( - opcode: u8, - data: OperationData, - ) -> Result { - match opcode { - 0b0000_0000 => Ok(Self::Noop), - 0b0000_0001 => Ok(Self::Eqz), - 0b0000_0010 => Ok(Self::Neg), - 0b0000_0011 => Ok(Self::Inv), - 0b0000_0100 => Ok(Self::Incr), - 0b0000_0101 => Ok(Self::Not), - 0b0000_0110 => Ok(Self::FmpAdd), - 0b0000_0111 => Ok(Self::MLoad), - 0b0000_1000 => Ok(Self::Swap), - 0b0000_1001 => Ok(Self::Caller), - 0b0000_1010 => Ok(Self::MovUp2), - 0b0000_1011 => Ok(Self::MovDn2), - 0b0000_1100 => Ok(Self::MovUp3), - 0b0000_1101 => Ok(Self::MovDn3), - 0b0000_1110 => Ok(Self::AdvPopW), - 0b0000_1111 => Ok(Self::Expacc), - - 0b0001_0000 => Ok(Self::MovUp4), - 0b0001_0001 => Ok(Self::MovDn4), - 0b0001_0010 => Ok(Self::MovUp5), - 0b0001_0011 => Ok(Self::MovDn5), - 0b0001_0100 => Ok(Self::MovUp6), - 0b0001_0101 => Ok(Self::MovDn6), - 0b0001_0110 => Ok(Self::MovUp7), - 0b0001_0111 => Ok(Self::MovDn7), - 0b0001_1000 => Ok(Self::SwapW), - 0b0001_1001 => Ok(Self::Ext2Mul), - 0b0001_1010 => Ok(Self::MovUp8), - 0b0001_1011 => Ok(Self::MovDn8), - 0b0001_1100 => Ok(Self::SwapW2), - 0b0001_1101 => Ok(Self::SwapW3), - 0b0001_1110 => Ok(Self::SwapDW), - // 0b0001_1111 => , - 0b0010_0000 => match data { - OperationData::U32(value) => Ok(Self::Assert(value)), - _ => Err(DeserializationError::InvalidValue( - "Invalid opcode data. 'Assert' opcode provided, hence expected to receive u32 data.".to_string() - )), - }, - 0b0010_0001 => Ok(Self::Eq), - 0b0010_0010 => Ok(Self::Add), - 0b0010_0011 => Ok(Self::Mul), - 0b0010_0100 => Ok(Self::And), - 0b0010_0101 => Ok(Self::Or), - 0b0010_0110 => Ok(Self::U32and), - 0b0010_0111 => Ok(Self::U32xor), - 0b0010_1000 => Ok(Self::FriE2F4), - 0b0010_1001 => Ok(Self::Drop), - 0b0010_1010 => Ok(Self::CSwap), - 0b0010_1011 => Ok(Self::CSwapW), - 0b0010_1100 => Ok(Self::MLoadW), - 0b0010_1101 => Ok(Self::MStore), - 0b0010_1110 => Ok(Self::MStoreW), - 0b0010_1111 => Ok(Self::FmpUpdate), - - 0b0011_0000 => Ok(Self::Pad), - 0b0011_0001 => Ok(Self::Dup0), - 0b0011_0010 => Ok(Self::Dup1), - 0b0011_0011 => Ok(Self::Dup2), - 0b0011_0100 => Ok(Self::Dup3), - 0b0011_0101 => Ok(Self::Dup4), - 0b0011_0110 => Ok(Self::Dup5), - 0b0011_0111 => Ok(Self::Dup6), - 0b0011_1000 => Ok(Self::Dup7), - 0b0011_1001 => Ok(Self::Dup9), - 0b0011_1010 => Ok(Self::Dup11), - 0b0011_1011 => Ok(Self::Dup13), - 0b0011_1100 => Ok(Self::Dup15), - 0b0011_1101 => Ok(Self::AdvPop), - 0b0011_1110 => Ok(Self::SDepth), - 0b0011_1111 => Ok(Self::Clk), - - 0b0100_0000 => Ok(Self::U32add), - 0b0100_0010 => Ok(Self::U32sub), - 0b0100_0100 => Ok(Self::U32mul), - 0b0100_0110 => Ok(Self::U32div), - 0b0100_1000 => Ok(Self::U32split), - 0b0100_1010 => match data { - OperationData::U32(value) => Ok(Self::U32assert2(value)), - _ => Err(DeserializationError::InvalidValue( - "Invalid opcode data. 'U32assert2' opcode provided, hence expected to receive u32 data.".to_string() - )), - }, - 0b0100_1100 => Ok(Self::U32add3), - 0b0100_1110 => Ok(Self::U32madd), - - 0b0101_0000 => Ok(Self::HPerm), - 0b0101_0001 => match data { - OperationData::U32(value) => Ok(Self::MpVerify(value)), - _ => Err(DeserializationError::InvalidValue( - "Invalid opcode data. 'MpVerify' opcode provided, hence expected to receive u32 data.".to_string() - )), - }, - 0b0101_0010 => Ok(Self::Pipe), - 0b0101_0011 => Ok(Self::MStream), - 0b0101_0100 => Ok(Self::Split), - 0b0101_0101 => Ok(Self::Loop), - 0b0101_0110 => Ok(Self::Span), - 0b0101_0111 => Ok(Self::Join), - 0b0101_1000 => Ok(Self::Dyn), - 0b0101_1001 => Ok(Self::RCombBase), - // 0b0101_1010 => , - // 0b0101_1011 => , - // 0b0101_1100 => , - // 0b0101_1101 => , - // 0b0101_1110 => , - // 0b0101_1111 => , - 0b0110_0000 => Ok(Self::MrUpdate), - 0b0110_0100 => match data { - OperationData::Felt(value) => Ok(Self::Push(value)), - _ => Err(DeserializationError::InvalidValue( - "Invalid opcode data. 'Push' opcode provided, hence expected to receive Felt data.".to_string() - )), - }, - 0b0110_1000 => Ok(Self::SysCall), - 0b0110_1100 => Ok(Self::Call), - 0b0111_0000 => Ok(Self::End), - 0b0111_0100 => Ok(Self::Repeat), - 0b0111_1000 => Ok(Self::Respan), - 0b0111_1100 => Ok(Self::Halt), - - _ => Err(DeserializationError::InvalidValue(format!( - "Invalid opcode {opcode}" - ))), - } - } -} - impl Operation { pub const OP_BITS: usize = 7; From 9fe4e0e0f4a8196a360fc615ebe99bcaab9030c7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 11 Jul 2024 10:09:47 -0400 Subject: [PATCH 105/172] `StringRef`: move string length to data buffer --- .../serialization/basic_block_data_builder.rs | 4 +-- .../serialization/basic_block_data_decoder.rs | 25 ++++++------------- core/src/mast/serialization/mod.rs | 7 +----- 3 files changed, 10 insertions(+), 26 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 1f675a8b5a..3a8a2e06f1 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -151,11 +151,10 @@ impl StringTableBuilder { .len() .try_into() .expect("strings table larger than 2^32 bytes"), - len: string.len().try_into().expect("string larger than 2^32 bytes"), }; let str_idx = self.table.len(); - self.strings_data.extend(string.as_bytes()); + string.write_into(&mut self.strings_data); self.table.push(str_ref); self.str_to_index.insert(Blake3_256::hash(string.as_bytes()), str_idx); @@ -174,7 +173,6 @@ impl StringTableBuilder { .into_iter() .map(|str_ref| StringRef { offset: str_ref.offset + table_offset, - len: str_ref.len, }) .collect() } diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index 83f85bca81..7f72870208 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -191,26 +191,17 @@ impl<'a> BasicBlockDataDecoder<'a> { } fn read_string(&self, str_idx: StringIndex) -> Result { - let str_ref = self.strings.get(str_idx).ok_or_else(|| { - DeserializationError::InvalidValue(format!("invalid index in strings table: {str_idx}")) - })?; - - let str_bytes = { - let start = str_ref.offset as usize; - let end = (str_ref.offset + str_ref.len) as usize; - - self.data.get(start..end).ok_or_else(|| { + let str_offset = { + let str_ref = self.strings.get(str_idx).ok_or_else(|| { DeserializationError::InvalidValue(format!( - "invalid string ref in strings table. Offset: {}, length: {}", - str_ref.offset, str_ref.len + "invalid index in strings table: {str_idx}" )) - })? + })?; + + str_ref.offset as usize }; - String::from_utf8(str_bytes.to_vec()).map_err(|_| { - DeserializationError::InvalidValue(format!( - "Invalid UTF-8 string in strings table: {str_bytes:?}" - )) - }) + let mut reader = SliceReader::new(&self.data[str_offset..]); + reader.read() } } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 946793b275..547045a0cf 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -49,24 +49,19 @@ const VERSION: [u8; 3] = [0, 0, 0]; pub struct StringRef { /// Offset into the `data` section. offset: DataOffset, - - /// Length of the utf-8 string. - len: u32, } impl Serializable for StringRef { fn write_into(&self, target: &mut W) { self.offset.write_into(target); - self.len.write_into(target); } } impl Deserializable for StringRef { fn read_from(source: &mut R) -> Result { let offset = DataOffset::read_from(source)?; - let len = source.read_u32()?; - Ok(Self { offset, len }) + Ok(Self { offset }) } } From c5258285c170755e75f36eab674a5668500c4cd8 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 11 Jul 2024 10:43:36 -0400 Subject: [PATCH 106/172] store offset in block --- .../serialization/basic_block_data_builder.rs | 10 ++++- core/src/mast/serialization/info.rs | 38 +++++++++++-------- core/src/mast/serialization/mod.rs | 3 +- 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 3a8a2e06f1..3d6ac31257 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -7,7 +7,7 @@ use crate::{ AdviceInjector, DebugOptions, Decorator, SignatureKind, }; -use super::{decorator::EncodedDecoratorVariant, StringIndex, StringRef}; +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; // BASIC BLOCK DATA BUILDER // ================================================================================================ @@ -26,6 +26,14 @@ impl BasicBlockDataBuilder { } } +/// Accessors +impl BasicBlockDataBuilder { + /// Returns the current offset into the data buffer. + pub fn get_offset(&self) -> DataOffset { + self.data.len() as DataOffset + } +} + /// Mutators impl BasicBlockDataBuilder { /// Encodes a [`BasicBlockNode`] into the serialized [`crate::mast::MastForest`] data field. diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index af41c67a45..a6922515d1 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -3,7 +3,7 @@ use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, use crate::mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}; -use super::basic_block_data_decoder::BasicBlockDataDecoder; +use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; // MAST NODE INFO // ================================================================================================ @@ -24,8 +24,8 @@ pub struct MastNodeInfo { } impl MastNodeInfo { - pub fn new(mast_node: &MastNode) -> Self { - let ty = MastNodeType::new(mast_node); + pub fn new(mast_node: &MastNode, basic_block_offset: DataOffset) -> Self { + let ty = MastNodeType::new(mast_node, basic_block_offset); Self { ty, @@ -40,6 +40,7 @@ impl MastNodeInfo { ) -> Result { let mast_node = match self.ty { MastNodeType::Block { + offset: _, len: num_operations_and_decorators, } => { let (operations, decorators) = basic_block_data_decoder @@ -147,6 +148,8 @@ pub enum MastNodeType { body_id: u32, } = LOOP, Block { + /// Offset of the basic block in the data segment + offset: u32, /// The number of operations and decorators in the basic block len: u32, } = BLOCK, @@ -163,14 +166,17 @@ pub enum MastNodeType { /// Constructors impl MastNodeType { /// Constructs a new [`MastNodeType`] from a [`MastNode`]. - pub fn new(mast_node: &MastNode) -> Self { + pub fn new(mast_node: &MastNode, basic_block_offset: u32) -> Self { use MastNode::*; match mast_node { Block(block_node) => { let len = block_node.num_operations_and_decorators(); - Self::Block { len } + Self::Block { + len, + offset: basic_block_offset, + } } Join(join_node) => Self::Join { left_child_id: join_node.first().0, @@ -233,13 +239,13 @@ impl MastNodeType { MastNodeType::Join { left_child_id: left, right_child_id: right, - } => Self::encode_join_or_split(*left, *right), + } => Self::encode_u32_pair(*left, *right), MastNodeType::Split { if_branch_id: if_branch, else_branch_id: else_branch, - } => Self::encode_join_or_split(*if_branch, *else_branch), + } => Self::encode_u32_pair(*if_branch, *else_branch), MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body), - MastNodeType::Block { len } => Self::encode_u32_payload(*len), + MastNodeType::Block { offset, len } => Self::encode_u32_pair(*offset, *len), MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id), MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id), MastNodeType::Dyn => [0; 8], @@ -247,7 +253,7 @@ impl MastNodeType { } } - fn encode_join_or_split(left_child_id: u32, right_child_id: u32) -> [u8; 8] { + fn encode_u32_pair(left_child_id: u32, right_child_id: u32) -> [u8; 8] { assert!(left_child_id < 2_u32.pow(30)); assert!(right_child_id < 2_u32.pow(30)); @@ -304,14 +310,14 @@ impl Deserializable for MastNodeType { match tag { JOIN => { - let (left_child_id, right_child_id) = Self::decode_join_or_split(bytes); + let (left_child_id, right_child_id) = Self::decode_u32_pair(bytes); Ok(Self::Join { left_child_id, right_child_id, }) } SPLIT => { - let (if_branch_id, else_branch_id) = Self::decode_join_or_split(bytes); + let (if_branch_id, else_branch_id) = Self::decode_u32_pair(bytes); Ok(Self::Split { if_branch_id, else_branch_id, @@ -322,8 +328,8 @@ impl Deserializable for MastNodeType { Ok(Self::Loop { body_id }) } BLOCK => { - let len = Self::decode_u32_payload(bytes); - Ok(Self::Block { len }) + let (offset, len) = Self::decode_u32_pair(bytes); + Ok(Self::Block { offset, len }) } CALL => { let callee_id = Self::decode_u32_payload(bytes); @@ -344,7 +350,7 @@ impl Deserializable for MastNodeType { /// Deserialization helpers impl MastNodeType { - fn decode_join_or_split(buffer: [u8; 8]) -> (u32, u32) { + fn decode_u32_pair(buffer: [u8; 8]) -> (u32, u32) { let first = { let mut first_le_bytes = [0_u8; 4]; @@ -414,7 +420,7 @@ mod tests { assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type); let (decoded_left, decoded_right) = - MastNodeType::decode_join_or_split(expected_encoded_mast_node_type); + MastNodeType::decode_u32_pair(expected_encoded_mast_node_type); assert_eq!(left_child_id, decoded_left); assert_eq!(right_child_id, decoded_right); } @@ -441,7 +447,7 @@ mod tests { assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type); let (decoded_if_branch, decoded_else_branch) = - MastNodeType::decode_join_or_split(expected_encoded_mast_node_type); + MastNodeType::decode_u32_pair(expected_encoded_mast_node_type); assert_eq!(if_branch_id, decoded_if_branch); assert_eq!(else_branch_id, decoded_else_branch); } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 547045a0cf..d2484d1704 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -88,7 +88,8 @@ impl Serializable for MastForest { .nodes .iter() .map(|mast_node| { - let mast_node_info = MastNodeInfo::new(mast_node); + let mast_node_info = + MastNodeInfo::new(mast_node, basic_block_data_builder.get_offset()); if let MastNode::Block(basic_block) = mast_node { basic_block_data_builder.encode_basic_block(basic_block); From 2b332e1349f4a16344347e628c58102744652245 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 08:36:03 -0400 Subject: [PATCH 107/172] Use `source.read_u32/u64()` --- core/src/operations/mod.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index b1829ddeef..72fed3bae4 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -873,8 +873,7 @@ impl Deserializable for Operation { OPCODE_SWAPDW => Self::SwapDW, OPCODE_ASSERT => { - let err_code_le_bytes: [u8; 4] = source.read_array()?; - let err_code = u32::from_le_bytes(err_code_le_bytes); + let err_code = source.read_u32()?; Self::Assert(err_code) } OPCODE_EQ => Self::Eq, @@ -916,8 +915,7 @@ impl Deserializable for Operation { OPCODE_U32DIV => Self::U32div, OPCODE_U32SPLIT => Self::U32split, OPCODE_U32ASSERT2 => { - let err_code_le_bytes: [u8; 4] = source.read_array()?; - let err_code = u32::from_le_bytes(err_code_le_bytes); + let err_code = source.read_u32()?; Self::U32assert2(err_code) } @@ -926,8 +924,7 @@ impl Deserializable for Operation { OPCODE_HPERM => Self::HPerm, OPCODE_MPVERIFY => { - let err_code_le_bytes: [u8; 4] = source.read_array()?; - let err_code = u32::from_le_bytes(err_code_le_bytes); + let err_code = source.read_u32()?; Self::MpVerify(err_code) } @@ -942,8 +939,7 @@ impl Deserializable for Operation { OPCODE_MRUPDATE => Self::MrUpdate, OPCODE_PUSH => { - let value_le_bytes: [u8; 8] = source.read_array()?; - let value_u64 = u64::from_le_bytes(value_le_bytes); + let value_u64 = source.read_u64()?; let value_felt = Felt::try_from(value_u64).map_err(|_| { DeserializationError::InvalidValue(format!( "Operation associated data doesn't fit in a field element: {value_u64}" From 9ca910a1fe2840de12ff3e0f6ab6452ac254e67f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 08:46:36 -0400 Subject: [PATCH 108/172] Update `MastNodeInfo` docstring --- core/src/mast/serialization/info.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index a6922515d1..95b8ef357a 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -10,10 +10,6 @@ use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; /// Represents a serialized [`MastNode`], with some data inlined in its [`MastNodeType`]. /// -/// In the case of [`crate::mast::BasicBlockNode`], all its operation- and decorator-related data is -/// stored in the serialized [`MastForest`]'s `data` field at offset represented by the `offset` -/// field. For all other variants of [`MastNode`], the `offset` field is guaranteed to be 0. -/// /// The serialized representation of [`MastNodeInfo`] is guaranteed to be fixed width, so that the /// nodes stored in the `nodes` table of the serialized [`MastForest`] can be accessed quickly by /// index. From 5c6f28779907fc151b6deec7382f2f30b022a36d Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 08:54:37 -0400 Subject: [PATCH 109/172] rename arguments in `encode_u32_pair` --- core/src/mast/serialization/info.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 95b8ef357a..3fb2816bba 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -249,15 +249,15 @@ impl MastNodeType { } } - fn encode_u32_pair(left_child_id: u32, right_child_id: u32) -> [u8; 8] { - assert!(left_child_id < 2_u32.pow(30)); - assert!(right_child_id < 2_u32.pow(30)); + fn encode_u32_pair(left_value: u32, right_value: u32) -> [u8; 8] { + assert!(left_value < 2_u32.pow(30)); + assert!(right_value < 2_u32.pow(30)); let mut result: [u8; 8] = [0_u8; 8]; // write left child into result { - let [lsb, a, b, msb] = left_child_id.to_le_bytes(); + let [lsb, a, b, msb] = left_value.to_le_bytes(); result[0] |= lsb >> 4; result[1] |= lsb << 4; result[1] |= a >> 4; @@ -280,7 +280,7 @@ impl MastNodeType { // significant bits. Also, the most significant byte of the right child is guaranteed to // fit in 6 bits. Hence, we use big endian format for the right child id to simplify // encoding and decoding. - let [msb, a, b, lsb] = right_child_id.to_be_bytes(); + let [msb, a, b, lsb] = right_value.to_be_bytes(); result[4] |= msb; result[5] = a; From 58df61e085f76fb9eb9b0228fb7e53e99f8b0414 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 09:13:43 -0400 Subject: [PATCH 110/172] Use basic block offset in deserialization --- .../serialization/basic_block_data_decoder.rs | 58 +++++++++---------- core/src/mast/serialization/info.rs | 6 +- core/src/mast/serialization/mod.rs | 6 +- 3 files changed, 34 insertions(+), 36 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index 7f72870208..f603ce3f79 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -2,48 +2,43 @@ use crate::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, SignatureKind, }; -use super::{decorator::EncodedDecoratorVariant, StringIndex, StringRef}; +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; use alloc::{string::String, vec::Vec}; use miden_crypto::Felt; use winter_utils::{ByteReader, Deserializable, DeserializationError, SliceReader}; pub struct BasicBlockDataDecoder<'a> { data: &'a [u8], - data_reader: SliceReader<'a>, strings: &'a [StringRef], } /// Constructors impl<'a> BasicBlockDataDecoder<'a> { pub fn new(data: &'a [u8], strings: &'a [StringRef]) -> Self { - let data_reader = SliceReader::new(data); - - Self { - data, - data_reader, - strings, - } + Self { data, strings } } } /// Mutators impl<'a> BasicBlockDataDecoder<'a> { pub fn decode_operations_and_decorators( - &mut self, + &self, + offset: DataOffset, num_to_decode: u32, ) -> Result<(Vec, DecoratorList), DeserializationError> { let mut operations: Vec = Vec::new(); let mut decorators: DecoratorList = Vec::new(); + let mut data_reader = SliceReader::new(&self.data[offset as usize..]); for _ in 0..num_to_decode { - let first_byte = self.data_reader.peek_u8()?; + let first_byte = data_reader.peek_u8()?; if first_byte & 0b1000_0000 == 0 { // operation. - operations.push(Operation::read_from(&mut self.data_reader)?); + operations.push(Operation::read_from(&mut data_reader)?); } else { // decorator. - let decorator = self.decode_decorator()?; + let decorator = self.decode_decorator(&mut data_reader)?; decorators.push((operations.len(), decorator)); } } @@ -54,8 +49,11 @@ impl<'a> BasicBlockDataDecoder<'a> { /// Helpers impl<'a> BasicBlockDataDecoder<'a> { - fn decode_decorator(&mut self) -> Result { - let discriminant = self.data_reader.read_u8()?; + fn decode_decorator( + &self, + data_reader: &mut SliceReader, + ) -> Result { + let discriminant = data_reader.read_u8()?; let decorator_variant = EncodedDecoratorVariant::from_discriminant(discriminant) .ok_or_else(|| { @@ -75,8 +73,8 @@ impl<'a> BasicBlockDataDecoder<'a> { Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) } EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { - let include_len = self.data_reader.read_bool()?; - let key_offset = self.data_reader.read_usize()?; + let include_len = data_reader.read_bool()?; + let key_offset = data_reader.read_usize()?; Ok(Decorator::Advice(AdviceInjector::MapValueToStack { include_len, @@ -120,7 +118,7 @@ impl<'a> BasicBlockDataDecoder<'a> { Ok(Decorator::Advice(AdviceInjector::MemToMap)) } EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { - let domain = self.data_reader.read_u64()?; + let domain = data_reader.read_u64()?; let domain = Felt::try_from(domain).map_err(|err| { DeserializationError::InvalidValue(format!( "Error when deserializing HdwordToMap decorator domain: {err}" @@ -138,16 +136,16 @@ impl<'a> BasicBlockDataDecoder<'a> { })) } EncodedDecoratorVariant::AssemblyOp => { - let num_cycles = self.data_reader.read_u8()?; - let should_break = self.data_reader.read_bool()?; + let num_cycles = data_reader.read_u8()?; + let should_break = data_reader.read_bool()?; let context_name = { - let str_index_in_table = self.data_reader.read_usize()?; + let str_index_in_table = data_reader.read_usize()?; self.read_string(str_index_in_table)? }; let op = { - let str_index_in_table = self.data_reader.read_usize()?; + let str_index_in_table = data_reader.read_usize()?; self.read_string(str_index_in_table)? }; @@ -157,7 +155,7 @@ impl<'a> BasicBlockDataDecoder<'a> { Ok(Decorator::Debug(DebugOptions::StackAll)) } EncodedDecoratorVariant::DebugOptionsStackTop => { - let value = self.data_reader.read_u8()?; + let value = data_reader.read_u8()?; Ok(Decorator::Debug(DebugOptions::StackTop(value))) } @@ -165,25 +163,25 @@ impl<'a> BasicBlockDataDecoder<'a> { Ok(Decorator::Debug(DebugOptions::MemAll)) } EncodedDecoratorVariant::DebugOptionsMemInterval => { - let start = u32::from_le_bytes(self.data_reader.read_array::<4>()?); - let end = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + let start = u32::from_le_bytes(data_reader.read_array::<4>()?); + let end = u32::from_le_bytes(data_reader.read_array::<4>()?); Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) } EncodedDecoratorVariant::DebugOptionsLocalInterval => { - let start = u16::from_le_bytes(self.data_reader.read_array::<2>()?); - let second = u16::from_le_bytes(self.data_reader.read_array::<2>()?); - let end = u16::from_le_bytes(self.data_reader.read_array::<2>()?); + let start = u16::from_le_bytes(data_reader.read_array::<2>()?); + let second = u16::from_le_bytes(data_reader.read_array::<2>()?); + let end = u16::from_le_bytes(data_reader.read_array::<2>()?); Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) } EncodedDecoratorVariant::Event => { - let value = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + let value = u32::from_le_bytes(data_reader.read_array::<4>()?); Ok(Decorator::Event(value)) } EncodedDecoratorVariant::Trace => { - let value = u32::from_le_bytes(self.data_reader.read_array::<4>()?); + let value = u32::from_le_bytes(data_reader.read_array::<4>()?); Ok(Decorator::Trace(value)) } diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 3fb2816bba..6205a202c2 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -32,15 +32,15 @@ impl MastNodeInfo { pub fn try_into_mast_node( self, mast_forest: &MastForest, - basic_block_data_decoder: &mut BasicBlockDataDecoder, + basic_block_data_decoder: &BasicBlockDataDecoder, ) -> Result { let mast_node = match self.ty { MastNodeType::Block { - offset: _, + offset, len: num_operations_and_decorators, } => { let (operations, decorators) = basic_block_data_decoder - .decode_operations_and_decorators(num_operations_and_decorators)?; + .decode_operations_and_decorators(offset, num_operations_and_decorators)?; Ok(MastNode::new_basic_block_with_decorators(operations, decorators)) } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index d2484d1704..d8a6b0f455 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -135,7 +135,7 @@ impl Deserializable for MastForest { let data: Vec = Deserializable::read_from(source)?; - let mut basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); + let basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); let mast_forest = { let mut mast_forest = MastForest::new(); @@ -143,8 +143,8 @@ impl Deserializable for MastForest { for _ in 0..node_count { let mast_node_info = MastNodeInfo::read_from(source)?; - let node = mast_node_info - .try_into_mast_node(&mast_forest, &mut basic_block_data_decoder)?; + let node = + mast_node_info.try_into_mast_node(&mast_forest, &basic_block_data_decoder)?; mast_forest.add_node(node); } From 1cdc419679c2938a8fb3f185e2c4cb294ecbb764 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 09:18:58 -0400 Subject: [PATCH 111/172] `BasicBlockDataDecoder`: use `ByteReader::read_u16/32()` methods --- .../mast/serialization/basic_block_data_decoder.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index f603ce3f79..8a9d50e902 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -163,25 +163,25 @@ impl<'a> BasicBlockDataDecoder<'a> { Ok(Decorator::Debug(DebugOptions::MemAll)) } EncodedDecoratorVariant::DebugOptionsMemInterval => { - let start = u32::from_le_bytes(data_reader.read_array::<4>()?); - let end = u32::from_le_bytes(data_reader.read_array::<4>()?); + let start = data_reader.read_u32()?; + let end = data_reader.read_u32()?; Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) } EncodedDecoratorVariant::DebugOptionsLocalInterval => { - let start = u16::from_le_bytes(data_reader.read_array::<2>()?); - let second = u16::from_le_bytes(data_reader.read_array::<2>()?); - let end = u16::from_le_bytes(data_reader.read_array::<2>()?); + let start = data_reader.read_u16()?; + let second = data_reader.read_u16()?; + let end = data_reader.read_u16()?; Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) } EncodedDecoratorVariant::Event => { - let value = u32::from_le_bytes(data_reader.read_array::<4>()?); + let value = data_reader.read_u32()?; Ok(Decorator::Event(value)) } EncodedDecoratorVariant::Trace => { - let value = u32::from_le_bytes(data_reader.read_array::<4>()?); + let value = data_reader.read_u32()?; Ok(Decorator::Trace(value)) } From f45d0c54b78a5cd4bf7f89499efe93a351823b83 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 09:21:12 -0400 Subject: [PATCH 112/172] `StringTableBuilder`: fix comment --- core/src/mast/serialization/basic_block_data_builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 3d6ac31257..94f50603ff 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -152,7 +152,7 @@ impl StringTableBuilder { *str_idx } else { // add new string to table - // NOTE: these string refs' offset will need to be shifted again in `into_buffer()` + // NOTE: these string refs' offset will need to be shifted again in `into_table()` let str_ref = StringRef { offset: self .strings_data From 43cadeb733d9d40056bdc47dbb7d2bc4751206a0 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 09:27:16 -0400 Subject: [PATCH 113/172] Remove `StringRef` in favor of `DataOffset` --- .../serialization/basic_block_data_builder.rs | 30 ++++++++----------- .../serialization/basic_block_data_decoder.rs | 18 ++++------- core/src/mast/serialization/mod.rs | 28 +---------------- 3 files changed, 19 insertions(+), 57 deletions(-) diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs index 94f50603ff..78a045d6c4 100644 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ b/core/src/mast/serialization/basic_block_data_builder.rs @@ -7,7 +7,7 @@ use crate::{ AdviceInjector, DebugOptions, Decorator, SignatureKind, }; -use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex}; // BASIC BLOCK DATA BUILDER // ================================================================================================ @@ -48,7 +48,7 @@ impl BasicBlockDataBuilder { } /// Returns the serialized [`crate::mast::MastForest`] data field, as well as the string table. - pub fn into_parts(mut self) -> (Vec, Vec) { + pub fn into_parts(mut self) -> (Vec, Vec) { let string_table = self.string_table_builder.into_table(&mut self.data); (self.data, string_table) } @@ -140,7 +140,7 @@ impl BasicBlockDataBuilder { #[derive(Debug, Default)] struct StringTableBuilder { - table: Vec, + table: Vec, str_to_index: BTreeMap, StringIndex>, strings_data: Vec, } @@ -153,35 +153,29 @@ impl StringTableBuilder { } else { // add new string to table // NOTE: these string refs' offset will need to be shifted again in `into_table()` - let str_ref = StringRef { - offset: self - .strings_data - .len() - .try_into() - .expect("strings table larger than 2^32 bytes"), - }; + let str_offset = self + .strings_data + .len() + .try_into() + .expect("strings table larger than 2^32 bytes"); + let str_idx = self.table.len(); string.write_into(&mut self.strings_data); - self.table.push(str_ref); + self.table.push(str_offset); self.str_to_index.insert(Blake3_256::hash(string.as_bytes()), str_idx); str_idx } } - pub fn into_table(self, data: &mut Vec) -> Vec { + pub fn into_table(self, data: &mut Vec) -> Vec { let table_offset: u32 = data .len() .try_into() .expect("MAST forest serialization: data field longer than 2^32 bytes"); data.extend(self.strings_data); - self.table - .into_iter() - .map(|str_ref| StringRef { - offset: str_ref.offset + table_offset, - }) - .collect() + self.table.into_iter().map(|str_offset| str_offset + table_offset).collect() } } diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs index 8a9d50e902..78dd77215f 100644 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ b/core/src/mast/serialization/basic_block_data_decoder.rs @@ -2,19 +2,19 @@ use crate::{ AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, SignatureKind, }; -use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex, StringRef}; +use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex}; use alloc::{string::String, vec::Vec}; use miden_crypto::Felt; use winter_utils::{ByteReader, Deserializable, DeserializationError, SliceReader}; pub struct BasicBlockDataDecoder<'a> { data: &'a [u8], - strings: &'a [StringRef], + strings: &'a [DataOffset], } /// Constructors impl<'a> BasicBlockDataDecoder<'a> { - pub fn new(data: &'a [u8], strings: &'a [StringRef]) -> Self { + pub fn new(data: &'a [u8], strings: &'a [DataOffset]) -> Self { Self { data, strings } } } @@ -189,15 +189,9 @@ impl<'a> BasicBlockDataDecoder<'a> { } fn read_string(&self, str_idx: StringIndex) -> Result { - let str_offset = { - let str_ref = self.strings.get(str_idx).ok_or_else(|| { - DeserializationError::InvalidValue(format!( - "invalid index in strings table: {str_idx}" - )) - })?; - - str_ref.offset as usize - }; + let str_offset = self.strings.get(str_idx).copied().ok_or_else(|| { + DeserializationError::InvalidValue(format!("invalid index in strings table: {str_idx}")) + })? as usize; let mut reader = SliceReader::new(&self.data[str_offset..]); reader.read() diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index d8a6b0f455..a71c79c89a 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -39,32 +39,6 @@ const MAGIC: &[u8; 5] = b"MAST\0"; /// version field itself, but should be considered invalid for now. const VERSION: [u8; 3] = [0, 0, 0]; -// STRING REF -// ================================================================================================ - -/// An entry in the `strings` table of an encoded [`MastForest`]. -/// -/// Strings are UTF8-encoded. -#[derive(Debug)] -pub struct StringRef { - /// Offset into the `data` section. - offset: DataOffset, -} - -impl Serializable for StringRef { - fn write_into(&self, target: &mut W) { - self.offset.write_into(target); - } -} - -impl Deserializable for StringRef { - fn read_from(source: &mut R) -> Result { - let offset = DataOffset::read_from(source)?; - - Ok(Self { offset }) - } -} - // MAST FOREST SERIALIZATION/DESERIALIZATION // ================================================================================================ @@ -131,7 +105,7 @@ impl Deserializable for MastForest { let roots: Vec = Deserializable::read_from(source)?; - let strings: Vec = Deserializable::read_from(source)?; + let strings: Vec = Deserializable::read_from(source)?; let data: Vec = Deserializable::read_from(source)?; From 7dec4281baea3e86e8f21fde62fd4571e719b07a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 10:13:25 -0400 Subject: [PATCH 114/172] cleanup `MastNodeType` serialization --- core/src/mast/serialization/info.rs | 249 ++++++++-------------------- 1 file changed, 71 insertions(+), 178 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 6205a202c2..c2eab7d4cd 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -204,18 +204,28 @@ impl MastNodeType { impl Serializable for MastNodeType { fn write_into(&self, target: &mut W) { - let serialized_bytes = { - let mut serialized_bytes = self.inline_data_to_bytes(); + let discriminant = self.discriminant() as u64; + assert!(discriminant <= 0b1111); - // Tag is always placed in the first four bytes - let discriminant = self.discriminant(); - assert!(discriminant <= 0b1111); - serialized_bytes[0] |= discriminant << 4; - - serialized_bytes + let payload = match self { + MastNodeType::Join { + left_child_id: left, + right_child_id: right, + } => Self::encode_u32_pair(*left, *right), + MastNodeType::Split { + if_branch_id: if_branch, + else_branch_id: else_branch, + } => Self::encode_u32_pair(*if_branch, *else_branch), + MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body), + MastNodeType::Block { offset, len } => Self::encode_u32_pair(*offset, *len), + MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id), + MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id), + MastNodeType::Dyn => 0, + MastNodeType::External => 0, }; - serialized_bytes.write_into(target) + let value = (discriminant << 60) | payload; + target.write_u64(value); } } @@ -230,221 +240,104 @@ impl MastNodeType { unsafe { *<*const _>::from(self).cast::() } } - fn inline_data_to_bytes(&self) -> [u8; 8] { - match self { - MastNodeType::Join { - left_child_id: left, - right_child_id: right, - } => Self::encode_u32_pair(*left, *right), - MastNodeType::Split { - if_branch_id: if_branch, - else_branch_id: else_branch, - } => Self::encode_u32_pair(*if_branch, *else_branch), - MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body), - MastNodeType::Block { offset, len } => Self::encode_u32_pair(*offset, *len), - MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id), - MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id), - MastNodeType::Dyn => [0; 8], - MastNodeType::External => [0; 8], + /// Encodes two u32 numbers in the first 60 bits of a `u64`. + /// + /// # Panics + /// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits. + fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 { + if left_value.leading_zeros() < 2 { + panic!( + "MastNodeType::encode_u32_pair: left value doesn't fit in 30 bits: {}", + left_value + ); } - } - - fn encode_u32_pair(left_value: u32, right_value: u32) -> [u8; 8] { - assert!(left_value < 2_u32.pow(30)); - assert!(right_value < 2_u32.pow(30)); - - let mut result: [u8; 8] = [0_u8; 8]; - - // write left child into result - { - let [lsb, a, b, msb] = left_value.to_le_bytes(); - result[0] |= lsb >> 4; - result[1] |= lsb << 4; - result[1] |= a >> 4; - result[2] |= a << 4; - result[2] |= b >> 4; - result[3] |= b << 4; - - // msb is different from lsb, a and b since its 2 most significant bits are guaranteed - // to be 0, and hence not encoded. - // - // More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in - // `result[3]`, and `ef` as the most significant bits of `result[4]`. - result[3] |= msb >> 2; - result[4] |= msb << 6; - }; - // write right child into result - { - // Recall that `result[4]` contains 2 bits from the left child id in the most - // significant bits. Also, the most significant byte of the right child is guaranteed to - // fit in 6 bits. Hence, we use big endian format for the right child id to simplify - // encoding and decoding. - let [msb, a, b, lsb] = right_value.to_be_bytes(); - - result[4] |= msb; - result[5] = a; - result[6] = b; - result[7] = lsb; - }; + if right_value.leading_zeros() < 2 { + panic!( + "MastNodeType::encode_u32_pair: right value doesn't fit in 30 bits: {}", + left_value + ); + } - result + ((left_value as u64) << 30) | (right_value as u64) } - fn encode_u32_payload(payload: u32) -> [u8; 8] { - let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes(); - - [0, payload_byte1, payload_byte2, payload_byte3, payload_byte4, 0, 0, 0] + fn encode_u32_payload(payload: u32) -> u64 { + payload as u64 } } impl Deserializable for MastNodeType { fn read_from(source: &mut R) -> Result { - let bytes: [u8; 8] = source.read_array()?; + let (discriminant, payload) = { + let value = source.read_u64()?; - let tag = bytes[0] >> 4; + // 4 bits + let discriminant = (value >> 60) as u8; + // 60 bits + let payload = value & 0x0F_FF_FF_FF_FF_FF_FF_FF; + + (discriminant, payload) + }; - match tag { + match discriminant { JOIN => { - let (left_child_id, right_child_id) = Self::decode_u32_pair(bytes); + let (left_child_id, right_child_id) = Self::decode_u32_pair(payload); Ok(Self::Join { left_child_id, right_child_id, }) } SPLIT => { - let (if_branch_id, else_branch_id) = Self::decode_u32_pair(bytes); + let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload); Ok(Self::Split { if_branch_id, else_branch_id, }) } LOOP => { - let body_id = Self::decode_u32_payload(bytes); + let body_id = Self::decode_u32_payload(payload)?; Ok(Self::Loop { body_id }) } BLOCK => { - let (offset, len) = Self::decode_u32_pair(bytes); + let (offset, len) = Self::decode_u32_pair(payload); Ok(Self::Block { offset, len }) } CALL => { - let callee_id = Self::decode_u32_payload(bytes); + let callee_id = Self::decode_u32_payload(payload)?; Ok(Self::Call { callee_id }) } SYSCALL => { - let callee_id = Self::decode_u32_payload(bytes); + let callee_id = Self::decode_u32_payload(payload)?; Ok(Self::SysCall { callee_id }) } DYN => Ok(Self::Dyn), EXTERNAL => Ok(Self::External), - _ => { - Err(DeserializationError::InvalidValue(format!("Invalid tag for MAST node: {tag}"))) - } + _ => Err(DeserializationError::InvalidValue(format!( + "Invalid tag for MAST node: {discriminant}" + ))), } } } /// Deserialization helpers impl MastNodeType { - fn decode_u32_pair(buffer: [u8; 8]) -> (u32, u32) { - let first = { - let mut first_le_bytes = [0_u8; 4]; - - first_le_bytes[0] = buffer[0] << 4; - first_le_bytes[0] |= buffer[1] >> 4; - - first_le_bytes[1] = buffer[1] << 4; - first_le_bytes[1] |= buffer[2] >> 4; - - first_le_bytes[2] = buffer[2] << 4; - first_le_bytes[2] |= buffer[3] >> 4; - - first_le_bytes[3] = (buffer[3] & 0b1111) << 2; - first_le_bytes[3] |= buffer[4] >> 6; - - u32::from_le_bytes(first_le_bytes) - }; - - let second = { - let mut second_be_bytes = [0_u8; 4]; - - second_be_bytes[0] = buffer[4] & 0b0011_1111; - second_be_bytes[1] = buffer[5]; - second_be_bytes[2] = buffer[6]; - second_be_bytes[3] = buffer[7]; - - u32::from_be_bytes(second_be_bytes) - }; - - (first, second) - } - - pub fn decode_u32_payload(payload: [u8; 8]) -> u32 { - let payload_be_bytes = [payload[1], payload[2], payload[3], payload[4]]; - - u32::from_be_bytes(payload_be_bytes) - } -} - -// TESTS -// ================================================================================================ - -#[cfg(test)] -mod tests { - use super::*; - use alloc::vec::Vec; - - #[test] - fn mast_node_type_serde_join() { - let left_child_id = 0b00111001_11101011_01101100_11011000; - let right_child_id = 0b00100111_10101010_11111111_11001110; - - let mast_node_type = MastNodeType::Join { - left_child_id, - right_child_id, - }; - - let mut encoded_mast_node_type: Vec = Vec::new(); - mast_node_type.write_into(&mut encoded_mast_node_type); + /// Decodes two `u32` numbers from a 60-bit payload. + fn decode_u32_pair(payload: u64) -> (u32, u32) { + let left_value = (payload >> 30) as u32; + let right_value = (payload & 0x3F_FF_FF_FF) as u32; - // Note: Join's discriminant is 0 - let expected_encoded_mast_node_type = [ - 0b00001101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, - 0b11001110, - ]; - - assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type); - - let (decoded_left, decoded_right) = - MastNodeType::decode_u32_pair(expected_encoded_mast_node_type); - assert_eq!(left_child_id, decoded_left); - assert_eq!(right_child_id, decoded_right); + (left_value, right_value) } - #[test] - fn mast_node_type_serde_split() { - let if_branch_id = 0b00111001_11101011_01101100_11011000; - let else_branch_id = 0b00100111_10101010_11111111_11001110; - - let mast_node_type = MastNodeType::Split { - if_branch_id, - else_branch_id, - }; - - let mut encoded_mast_node_type: Vec = Vec::new(); - mast_node_type.write_into(&mut encoded_mast_node_type); - - // Note: Split's discriminant is 1 - let expected_encoded_mast_node_type = [ - 0b00011101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111, - 0b11001110, - ]; - - assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type); - - let (decoded_if_branch, decoded_else_branch) = - MastNodeType::decode_u32_pair(expected_encoded_mast_node_type); - assert_eq!(if_branch_id, decoded_if_branch); - assert_eq!(else_branch_id, decoded_else_branch); + /// Decodes one `u32` number from a 60-bit payload. + /// + /// Returns an error if the payload doesn't fit in a `u32`. + pub fn decode_u32_payload(payload: u64) -> Result { + payload.try_into().map_err(|_| { + DeserializationError::InvalidValue(format!( + "Invalid payload: expected to fit in u32, but was {payload}" + )) + }) } } From c081b00daf22b64eb16d0a4c802f3b86d6e9ca17 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 10:25:39 -0400 Subject: [PATCH 115/172] derive `Copy` for `MastNodeType` --- core/src/mast/serialization/info.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index c2eab7d4cd..2f4c035ef4 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -129,7 +129,7 @@ const EXTERNAL: u8 = 7; /// /// The serialized representation of the MAST node type is guaranteed to be 8 bytes, so that /// [`MastNodeInfo`] (which contains it) can be of fixed width. -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum MastNodeType { Join { @@ -207,19 +207,19 @@ impl Serializable for MastNodeType { let discriminant = self.discriminant() as u64; assert!(discriminant <= 0b1111); - let payload = match self { + let payload = match *self { MastNodeType::Join { left_child_id: left, right_child_id: right, - } => Self::encode_u32_pair(*left, *right), + } => Self::encode_u32_pair(left, right), MastNodeType::Split { if_branch_id: if_branch, else_branch_id: else_branch, - } => Self::encode_u32_pair(*if_branch, *else_branch), - MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body), - MastNodeType::Block { offset, len } => Self::encode_u32_pair(*offset, *len), - MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id), - MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id), + } => Self::encode_u32_pair(if_branch, else_branch), + MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(body), + MastNodeType::Block { offset, len } => Self::encode_u32_pair(offset, len), + MastNodeType::Call { callee_id } => Self::encode_u32_payload(callee_id), + MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(callee_id), MastNodeType::Dyn => 0, MastNodeType::External => 0, }; From c32ef2284b5a99a3f308d76976b221f07a94ee03 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 10:41:01 -0400 Subject: [PATCH 116/172] `MastNodeType` tests --- core/src/mast/serialization/info.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 2f4c035ef4..f856687d89 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -129,7 +129,7 @@ const EXTERNAL: u8 = 7; /// /// The serialized representation of the MAST node type is guaranteed to be 8 bytes, so that /// [`MastNodeInfo`] (which contains it) can be of fixed width. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum MastNodeType { Join { @@ -341,3 +341,22 @@ impl MastNodeType { }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialize_deserialize_60_bit_payload() { + // each child needs 30 bits + let mast_node_type = MastNodeType::Join { + left_child_id: 0x3F_FF_FF_FF, + right_child_id: 0x3F_FF_FF_FF, + }; + + let serialized = mast_node_type.to_bytes(); + let deserialized = MastNodeType::read_from_bytes(&serialized).unwrap(); + + assert_eq!(mast_node_type, deserialized); + } +} From c7ee9c8e80252d8799ff45c69c6449a38fc08296 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 10:52:40 -0400 Subject: [PATCH 117/172] add `MastNodeType` tests --- core/src/mast/serialization/info.rs | 45 +++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index f856687d89..eb24bf7088 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -344,6 +344,8 @@ impl MastNodeType { #[cfg(test)] mod tests { + use alloc::vec::Vec; + use super::*; #[test] @@ -359,4 +361,47 @@ mod tests { assert_eq!(mast_node_type, deserialized); } + + #[test] + #[should_panic] + fn serialize_large_payloads_fails_1() { + // left child needs 31 bits + let mast_node_type = MastNodeType::Join { + left_child_id: 0x4F_FF_FF_FF, + right_child_id: 0x0, + }; + + // must panic + let _serialized = mast_node_type.to_bytes(); + } + + #[test] + #[should_panic] + fn serialize_large_payloads_fails_2() { + // right child needs 31 bits + let mast_node_type = MastNodeType::Join { + left_child_id: 0x0, + right_child_id: 0x4F_FF_FF_FF, + }; + + // must panic + let _serialized = mast_node_type.to_bytes(); + } + + #[test] + fn deserialize_large_payloads_fails() { + // Serialized `CALL` with a 33-bit payload + let serialized = { + let serialized_value = ((CALL as u64) << 60) | (u32::MAX as u64 + 1_u64); + + let mut serialized_buffer: Vec = Vec::new(); + serialized_value.write_into(&mut serialized_buffer); + + serialized_buffer + }; + + let deserialized_result = MastNodeType::read_from_bytes(&serialized); + + assert_matches!(deserialized_result, Err(DeserializationError::InvalidValue(_))); + } } From f9d2e59d1e749d6032633a9b988fa7fc4d145578 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 13:39:12 -0400 Subject: [PATCH 118/172] use assert --- core/src/mast/serialization/info.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index eb24bf7088..3f628a4dbb 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -245,19 +245,16 @@ impl MastNodeType { /// # Panics /// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits. fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 { - if left_value.leading_zeros() < 2 { - panic!( - "MastNodeType::encode_u32_pair: left value doesn't fit in 30 bits: {}", - left_value - ); - } - - if right_value.leading_zeros() < 2 { - panic!( - "MastNodeType::encode_u32_pair: right value doesn't fit in 30 bits: {}", - left_value - ); - } + assert!( + left_value.leading_zeros() < 2, + "MastNodeType::encode_u32_pair: left value doesn't fit in 30 bits: {}", + left_value + ); + assert!( + right_value.leading_zeros() < 2, + "MastNodeType::encode_u32_pair: right value doesn't fit in 30 bits: {}", + right_value + ); ((left_value as u64) << 30) | (right_value as u64) } From 781fc7376b8f557b2123d1233ca14a9f914cfb4a Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 12 Jul 2024 15:46:10 -0400 Subject: [PATCH 119/172] fix asserts --- core/src/mast/serialization/info.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 3f628a4dbb..b72e1701e0 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -246,12 +246,12 @@ impl MastNodeType { /// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits. fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 { assert!( - left_value.leading_zeros() < 2, + left_value.leading_zeros() >= 2, "MastNodeType::encode_u32_pair: left value doesn't fit in 30 bits: {}", left_value ); assert!( - right_value.leading_zeros() < 2, + right_value.leading_zeros() >= 2, "MastNodeType::encode_u32_pair: right value doesn't fit in 30 bits: {}", right_value ); From fa91716b30abb1366cd89953404db46f1db3ed7e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 16 Jul 2024 10:39:01 -0400 Subject: [PATCH 120/172] `ModuleGraph::recompute()` reverse edge caller/callee --- assembly/src/assembler/module_graph/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 9ebb4c2157..0d215eb3d3 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -340,7 +340,7 @@ impl ModuleGraph { self.modules.append(&mut finished); edges .into_iter() - .for_each(|(callee, caller)| self.callgraph.add_edge(callee, caller)); + .for_each(|(caller, callee)| self.callgraph.add_edge(caller, callee)); // Visit all of the modules in the base module graph, and modify them if any of the // pending modules allow additional information to be inferred (such as the absolute path From 1bd84ec759c4ce1bf55eab1a2c9f7e3f10bcd3a1 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Tue, 16 Jul 2024 16:11:31 -0400 Subject: [PATCH 121/172] Implement `Assembler::assemble_library()` --- assembly/src/assembler/mod.rs | 60 +++++++++++++++++++++++++++----- assembly/src/compiled_library.rs | 46 ++++++++++++++++++++++++ assembly/src/lib.rs | 1 + 3 files changed, 98 insertions(+), 9 deletions(-) create mode 100644 assembly/src/compiled_library.rs diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index a63f7ffb25..1657f978ac 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -3,6 +3,7 @@ use crate::{ self, AliasTarget, Export, FullyQualifiedProcedureName, Instruction, InvocationTarget, InvokeKind, ModuleKind, ProcedureIndex, }, + compiled_library::{CompiledLibrary, CompiledLibraryMetadata}, diagnostics::{tracing::instrument, Report}, sema::SemanticAnalysisError, AssemblyError, Compile, CompileOptions, Felt, Library, LibraryNamespace, LibraryPath, @@ -312,6 +313,42 @@ impl Assembler { /// Compilation/Assembly impl Assembler { + // TODOP: Document + pub fn assemble_library>( + mut self, + modules: T, + metadata: CompiledLibraryMetadata, // name, version etc. + ) -> Result { + let module_ids: Vec = modules + .map(|module| { + let module = module.compile()?; + Ok(self.module_graph.add_module(module)?) + }) + .collect::>()?; + self.module_graph.recompute()?; + + let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); + let mut context = AssemblyContext::default(); + + self.assemble_graph(&mut context, &mut mast_forest_builder)?; + + let exports = { + let mut exports = Vec::new(); + for module_id in module_ids { + let exports_in_module: Vec = + self.get_module_exports(module_id).map(|procedures| { + procedures.into_iter().map(|proc| proc.path().clone()).collect() + })?; + + exports.extend(exports_in_module); + } + + exports + }; + + Ok(CompiledLibrary::new(mast_forest_builder.build(), exports, metadata)) + } + /// Compiles the provided module into a [`Program`]. The resulting program can be executed on /// Miden VM. /// @@ -466,8 +503,15 @@ impl Assembler { let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); self.assemble_graph(context, &mut mast_forest_builder)?; - let exported_procedure_digests = - self.get_module_exports(module_id, mast_forest_builder.forest()); + let exported_procedure_digests = self.get_module_exports(module_id).map(|procedures| { + procedures + .into_iter() + .map(|proc| { + let proc_code_node = &mast_forest_builder.forest()[proc.body_node_id()]; + proc_code_node.digest() + }) + .collect() + }); // Reassign the mast_forest to the assembler for use is a future program assembly self.mast_forest_builder = mast_forest_builder; @@ -511,20 +555,19 @@ impl Assembler { .map_err(|err| Report::new(AssemblyError::Kernel(err))) } - /// Get the set of procedure roots for all exports of the given module + /// Get the set of exported procedures of the given module. /// /// Returns an error if the provided Miden Assembly is invalid. fn get_module_exports( &mut self, module: ModuleIndex, - mast_forest: &MastForest, - ) -> Result, Report> { + // TODOP: Return iterator instead? + ) -> Result>, Report> { assert!(self.module_graph.contains_module(module), "invalid module index"); let mut exports = Vec::new(); for (index, procedure) in self.module_graph[module].procedures().enumerate() { - // Only add exports to the code block table, locals will - // be added if they are in the call graph rooted at those + // Only add exports; locals will be added if they are in the call graph rooted at those // procedures if !procedure.visibility().is_exported() { continue; @@ -571,8 +614,7 @@ impl Assembler { } }); - let proc_code_node = &mast_forest[proc.body_node_id()]; - exports.push(proc_code_node.digest()); + exports.push(proc); } Ok(exports) diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs new file mode 100644 index 0000000000..bd6085cf97 --- /dev/null +++ b/assembly/src/compiled_library.rs @@ -0,0 +1,46 @@ +use alloc::{string::String, vec::Vec}; +use vm_core::mast::MastForest; + +use crate::{LibraryPath, Version}; + +// TODOP: Move into `miden-core` along with `LibraryPath` +pub struct CompiledLibrary { + mast_forest: MastForest, + // a path for every `root` in the associated [MastForest] + exports: Vec, + metadata: CompiledLibraryMetadata, +} + +/// Constructors +impl CompiledLibrary { + pub fn new( + mast_forest: MastForest, + exports: Vec, + metadata: CompiledLibraryMetadata, + ) -> Self { + Self { + mast_forest, + exports, + metadata, + } + } +} + +impl CompiledLibrary { + pub fn mast_forest(&self) -> &MastForest { + &self.mast_forest + } + + pub fn exports(&self) -> &[LibraryPath] { + &self.exports + } + + pub fn metadata(&self) -> &CompiledLibraryMetadata { + &self.metadata + } +} + +pub struct CompiledLibraryMetadata { + pub name: String, + pub version: Version, +} diff --git a/assembly/src/lib.rs b/assembly/src/lib.rs index 000cf114de..fa52681016 100644 --- a/assembly/src/lib.rs +++ b/assembly/src/lib.rs @@ -21,6 +21,7 @@ use vm_core::{prettier, utils::DisplayHex}; mod assembler; pub mod ast; mod compile; +pub mod compiled_library; pub mod diagnostics; mod errors; pub mod library; From 0a9cfadad657d39914d8a6f0114a381c26d6b1c9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 09:45:49 -0400 Subject: [PATCH 122/172] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6504a8e6f5..89c5cd9821 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - Relaxed the parser to allow one branch of an `if.(true|false)` to be empty - Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362). - Optimized `std::sys::truncate_stuck` procedure (#1384). +- Add serialization/deserialization for `MastForest` (#1370) #### Changed From 40cbbb97e56c56b40ab06a97d3c36c7c32062731 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 09:52:23 -0400 Subject: [PATCH 123/172] fix docs --- core/src/mast/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index c8567f9ae8..70a87c58c1 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -5,8 +5,8 @@ use miden_crypto::hash::rpo::RpoDigest; mod node; pub use node::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastNode, - OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, + get_span_op_group_count, BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, + MastNode, OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; From f405a2b5d3aa954e06925433516d33ac57ee76b7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 12:13:17 -0400 Subject: [PATCH 124/172] Introduce `CompiledFQDN` --- assembly/src/assembler/mod.rs | 11 ++++++++--- assembly/src/compiled_library.rs | 25 +++++++++++++++++++++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 6a05780a49..a29d8250ed 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -3,7 +3,9 @@ use crate::{ self, AliasTarget, Export, FullyQualifiedProcedureName, Instruction, InvocationTarget, InvokeKind, ModuleKind, ProcedureIndex, }, - compiled_library::{CompiledLibrary, CompiledLibraryMetadata}, + compiled_library::{ + CompiledFullyQualifiedProcedureName, CompiledLibrary, CompiledLibraryMetadata, + }, diagnostics::{tracing::instrument, Report}, sema::SemanticAnalysisError, AssemblyError, Compile, CompileOptions, Felt, Library, LibraryNamespace, LibraryPath, @@ -333,9 +335,12 @@ impl Assembler { let exports = { let mut exports = Vec::new(); for module_id in module_ids { - let exports_in_module: Vec = + let exports_in_module: Vec = self.get_module_exports(module_id).map(|procedures| { - procedures.into_iter().map(|proc| proc.path().clone()).collect() + procedures + .into_iter() + .map(|proc| proc.fully_qualified_name().clone().into()) + .collect() })?; exports.extend(exports_in_module); diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index bd6085cf97..66aa48e68f 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -1,13 +1,30 @@ use alloc::{string::String, vec::Vec}; use vm_core::mast::MastForest; -use crate::{LibraryPath, Version}; +use crate::{ast::{FullyQualifiedProcedureName, ProcedureName}, LibraryPath, Version}; + +// TODOP: Refactor `FullyQualifiedProcedureName` instead, and use `Span` where needed? +pub struct CompiledFullyQualifiedProcedureName { + /// The module path for this procedure. + pub module: LibraryPath, + /// The name of the procedure. + pub name: ProcedureName, +} + +impl From for CompiledFullyQualifiedProcedureName { + fn from(fqdn: FullyQualifiedProcedureName) -> Self { + Self { + module: fqdn.module, + name: fqdn.name, + } + } +} // TODOP: Move into `miden-core` along with `LibraryPath` pub struct CompiledLibrary { mast_forest: MastForest, // a path for every `root` in the associated [MastForest] - exports: Vec, + exports: Vec, metadata: CompiledLibraryMetadata, } @@ -15,7 +32,7 @@ pub struct CompiledLibrary { impl CompiledLibrary { pub fn new( mast_forest: MastForest, - exports: Vec, + exports: Vec, metadata: CompiledLibraryMetadata, ) -> Self { Self { @@ -31,7 +48,7 @@ impl CompiledLibrary { &self.mast_forest } - pub fn exports(&self) -> &[LibraryPath] { + pub fn exports(&self) -> &[CompiledFullyQualifiedProcedureName] { &self.exports } From 4236a847d7d169d91a8798a666da2cccb4badeb9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 13:34:13 -0400 Subject: [PATCH 125/172] Introduce `WrapperModule` to module graph --- assembly/src/assembler/module_graph/mod.rs | 450 +++++++++++++-------- assembly/src/compiled_library.rs | 23 +- 2 files changed, 295 insertions(+), 178 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 0d215eb3d3..0619b53041 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -27,21 +27,88 @@ use self::{ rewrites::ModuleRewriter, }; use super::{GlobalProcedureIndex, ModuleIndex}; +use crate::compiled_library::CompiledProcedure; use crate::{ ast::{ - Export, FullyQualifiedProcedureName, InvocationTarget, Module, Procedure, ProcedureIndex, + Export, FullyQualifiedProcedureName, InvocationTarget, Module, ProcedureIndex, ProcedureName, ResolvedProcedure, }, diagnostics::{RelatedLabel, SourceFile}, AssemblyError, LibraryPath, RpoDigest, Spanned, }; +// TODOP: Better doc +pub enum WrapperProcedure<'a> { + Ast(&'a Export), + Compiled(&'a CompiledProcedure), +} + +impl<'a> WrapperProcedure<'a> { + pub fn name(&self) -> &ProcedureName { + match self { + WrapperProcedure::Ast(p) => p.name(), + WrapperProcedure::Compiled(p) => p.name(), + } + } +} + +// TODOP: Rename (?) +#[derive(Clone)] +pub struct ModuleExports { + path: LibraryPath, + procedures: Vec<(ProcedureIndex, CompiledProcedure)>, +} + +impl ModuleExports { + pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { + Self { + path, + procedures: procedures + .enumerate() + .map(|(idx, proc)| (ProcedureIndex::new(idx), proc)) + .collect(), + } + } +} + +// TODOP: Rename +#[derive(Clone)] +pub enum WrapperModule { + Ast(Arc), + Exports(ModuleExports), +} + +impl WrapperModule { + pub fn path(&self) -> &LibraryPath { + match self { + WrapperModule::Ast(m) => m.path(), + WrapperModule::Exports(m) => &m.path, + } + } +} + +// TODOP: Try to do without this `Pending*` version +#[derive(Clone)] +pub enum PendingWrapperModule { + Ast(Box), + Exports(ModuleExports), +} + +impl PendingWrapperModule { + pub fn path(&self) -> &LibraryPath { + match self { + PendingWrapperModule::Ast(m) => m.path(), + PendingWrapperModule::Exports(m) => &m.path, + } + } +} + // MODULE GRAPH // ================================================================================================ #[derive(Default, Clone)] pub struct ModuleGraph { - modules: Vec>, + modules: Vec, /// The set of modules pending additional processing before adding them to the graph. /// /// When adding a set of inter-dependent modules to the graph, we process them as a group, so @@ -50,8 +117,7 @@ pub struct ModuleGraph { /// /// Once added to the graph, modules become immutable, and any additional modules added after /// that must by definition only depend on modules in the graph, and not be depended upon. - #[allow(clippy::vec_box)] - pending: Vec>, + pending: Vec, /// The global call graph of calls, not counting those that are performed directly via MAST /// root. callgraph: CallGraph, @@ -92,7 +158,10 @@ impl ModuleGraph { /// /// This function will panic if the number of modules exceeds the maximum representable /// [ModuleIndex] value, `u16::MAX`. - pub fn add_module(&mut self, module: Box) -> Result { + pub fn add_module( + &mut self, + module: PendingWrapperModule, + ) -> Result { let is_duplicate = self.is_pending(module.path()) || self.find_module_index(module.path()).is_some(); if is_duplicate { @@ -191,21 +260,6 @@ impl ModuleGraph { pub fn is_kernel_procedure_root(&self, digest: &RpoDigest) -> bool { self.kernel.contains_proc(*digest) } - - #[allow(unused)] - pub fn is_kernel_procedure(&self, name: &ProcedureName) -> bool { - self.kernel_index - .map(|index| self[index].resolve(name).is_some()) - .unwrap_or(false) - } - - #[allow(unused)] - pub fn is_kernel_procedure_fully_qualified(&self, name: &FullyQualifiedProcedureName) -> bool { - self.find_module_index(&name.module) - .filter(|module_index| self.kernel_index == Some(*module_index)) - .map(|module_index| self[module_index].resolve(&name.name).is_some()) - .unwrap_or(false) - } } /// Analysis @@ -272,17 +326,31 @@ impl ModuleGraph { let module_id = ModuleIndex::new(high_water_mark + pending_index); // Apply module to call graph - for (index, procedure) in pending_module.procedures().enumerate() { - let procedure_id = ProcedureIndex::new(index); - let global_id = GlobalProcedureIndex { - module: module_id, - index: procedure_id, - }; - - // Ensure all entrypoints and exported symbols are represented in the call graph, - // even if they have no edges, we need them in the graph for the topological sort - if matches!(procedure, Export::Procedure(_)) { - self.callgraph.get_or_insert_node(global_id); + match pending_module { + PendingWrapperModule::Ast(pending_module) => { + for (index, procedure) in pending_module.procedures().enumerate() { + let procedure_id = ProcedureIndex::new(index); + let global_id = GlobalProcedureIndex { + module: module_id, + index: procedure_id, + }; + + // Ensure all entrypoints and exported symbols are represented in the call + // graph, even if they have no edges, we need them + // in the graph for the topological sort + if matches!(procedure, Export::Procedure(_)) { + self.callgraph.get_or_insert_node(global_id); + } + } + } + PendingWrapperModule::Exports(pending_module) => { + for (procedure_id, _procedure) in pending_module.procedures.iter() { + let global_id = GlobalProcedureIndex { + module: module_id, + index: *procedure_id, + }; + self.callgraph.get_or_insert_node(global_id); + } } } } @@ -291,45 +359,54 @@ impl ModuleGraph { // before they are added to the graph let mut resolver = NameResolver::new(self); for module in pending.iter() { - resolver.push_pending(module); + if let PendingWrapperModule::Ast(module) = module { + resolver.push_pending(module); + } } let mut phantoms = BTreeSet::default(); let mut edges = Vec::new(); - let mut finished = Vec::>::new(); + let mut finished: Vec = Vec::new(); - // Visit all of the newly-added modules and perform any rewrites - for (pending_index, mut module) in pending.into_iter().enumerate() { - let module_id = ModuleIndex::new(high_water_mark + pending_index); + // Visit all of the newly-added modules and perform any rewrites to AST modules. + for (pending_index, module) in pending.into_iter().enumerate() { + match module { + PendingWrapperModule::Ast(mut ast_module) => { + let module_id = ModuleIndex::new(high_water_mark + pending_index); - let mut rewriter = ModuleRewriter::new(&resolver); - rewriter.apply(module_id, &mut module)?; + let mut rewriter = ModuleRewriter::new(&resolver); + rewriter.apply(module_id, &mut ast_module)?; - // Gather the phantom calls found while rewriting the module - phantoms.extend(rewriter.phantoms()); - - for (index, procedure) in module.procedures().enumerate() { - let procedure_id = ProcedureIndex::new(index); - let gid = GlobalProcedureIndex { - module: module_id, - index: procedure_id, - }; - - for invoke in procedure.invoked() { - let caller = CallerInfo { - span: invoke.span(), - source_file: module.source_file(), - module: module_id, - kind: invoke.kind, - }; - if let Some(callee) = - resolver.resolve_target(&caller, &invoke.target)?.into_global_id() - { - edges.push((gid, callee)); + // Gather the phantom calls found while rewriting the module + phantoms.extend(rewriter.phantoms()); + + for (index, procedure) in ast_module.procedures().enumerate() { + let procedure_id = ProcedureIndex::new(index); + let gid = GlobalProcedureIndex { + module: module_id, + index: procedure_id, + }; + + for invoke in procedure.invoked() { + let caller = CallerInfo { + span: invoke.span(), + source_file: ast_module.source_file(), + module: module_id, + kind: invoke.kind, + }; + if let Some(callee) = + resolver.resolve_target(&caller, &invoke.target)?.into_global_id() + { + edges.push((gid, callee)); + } + } } + + finished.push(WrapperModule::Ast(Arc::new(*ast_module))) + } + PendingWrapperModule::Exports(module) => { + finished.push(WrapperModule::Exports(module)); } } - - finished.push(Arc::from(module)); } // Release the graph again @@ -342,18 +419,20 @@ impl ModuleGraph { .into_iter() .for_each(|(caller, callee)| self.callgraph.add_edge(caller, callee)); - // Visit all of the modules in the base module graph, and modify them if any of the - // pending modules allow additional information to be inferred (such as the absolute path - // of imports, etc) + // Visit all of the (AST) modules in the base module graph, and modify them if any of the + // pending modules allow additional information to be inferred (such as the absolute path of + // imports, etc) for module_index in 0..high_water_mark { let module_id = ModuleIndex::new(module_index); let module = self.modules[module_id.as_usize()].clone(); - // Re-analyze the module, and if we needed to clone-on-write, the new module will be - // returned. Otherwise, `Ok(None)` indicates that the module is unchanged, and `Err` - // indicates that re-analysis has found an issue with this module. - if let Some(new_module) = self.reanalyze_module(module_id, module)? { - self.modules[module_id.as_usize()] = new_module; + if let WrapperModule::Ast(module) = module { + // Re-analyze the module, and if we needed to clone-on-write, the new module will be + // returned. Otherwise, `Ok(None)` indicates that the module is unchanged, and `Err` + // indicates that re-analysis has found an issue with this module. + if let Some(new_module) = self.reanalyze_module(module_id, module)? { + self.modules[module_id.as_usize()] = WrapperModule::Ast(new_module); + } } } @@ -363,8 +442,8 @@ impl ModuleGraph { let mut nodes = Vec::with_capacity(iter.len()); for node in iter { let module = self[node.module].path(); - let proc = self[node].name(); - nodes.push(format!("{}::{}", module, proc)); + let proc = self.get_procedure_unsafe(node); + nodes.push(format!("{}::{}", module, proc.name())); } AssemblyError::Cycle { nodes } })?; @@ -418,7 +497,7 @@ impl ModuleGraph { /// Fetch a [Module] by [ModuleIndex] #[allow(unused)] - pub fn get_module(&self, id: ModuleIndex) -> Option> { + pub fn get_module(&self, id: ModuleIndex) -> Option { self.modules.get(id.as_usize()).cloned() } @@ -427,24 +506,28 @@ impl ModuleGraph { self.modules.get(id.as_usize()).is_some() } - /// Fetch a [Export] by [GlobalProcedureIndex] - #[allow(unused)] - pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option<&Export> { - self.modules.get(id.module.as_usize()).and_then(|m| m.get(id.index)) + /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex], or `None` if index is invalid. + pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { + match &self.modules[id.module.as_usize()] { + WrapperModule::Ast(m) => m.get(id.index).map(|export| WrapperProcedure::Ast(export)), + WrapperModule::Exports(m) => m + .procedures + .get(id.index.as_usize()) + .map(|(_idx, proc)| WrapperProcedure::Compiled(proc)), + } } - /// Fetches a [Procedure] by [RpoDigest]. + /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex]. /// - /// NOTE: This implicitly chooses the first definition for a procedure if the same digest is - /// shared for multiple definitions. - #[allow(unused)] - pub fn get_procedure_by_digest(&self, digest: &RpoDigest) -> Option<&Procedure> { - self.roots - .get(digest) - .and_then(|indices| match self.get_procedure(indices[0])? { - Export::Procedure(ref proc) => Some(proc), - Export::Alias(_) => None, - }) + /// # Panics + /// - Panics if index is invalid. + pub fn get_procedure_unsafe(&self, id: GlobalProcedureIndex) -> WrapperProcedure { + match &self.modules[id.module.as_usize()] { + WrapperModule::Ast(m) => WrapperProcedure::Ast(&m[id.index]), + WrapperModule::Exports(m) => { + WrapperProcedure::Compiled(&m.procedures[id.index.as_usize()].1) + } + } } pub fn get_procedure_index_by_digest( @@ -493,26 +576,41 @@ impl ModuleGraph { Entry::Occupied(ref mut entry) => { let prev_id = entry.get()[0]; if prev_id != id { - // Multiple procedures with the same root, but incompatible - let prev = &self.modules[prev_id.module.as_usize()][prev_id.index]; - let current = &self.modules[id.module.as_usize()][id.index]; - if prev.num_locals() != current.num_locals() { - let prev_module = self.modules[prev_id.module.as_usize()].path(); - let prev_name = FullyQualifiedProcedureName { - span: prev.span(), - module: prev_module.clone(), - name: prev.name().clone(), - }; - let current_module = self.modules[id.module.as_usize()].path(); - let current_name = FullyQualifiedProcedureName { - span: current.span(), - module: current_module.clone(), - name: current.name().clone(), - }; - return Err(AssemblyError::ConflictingDefinitions { - first: prev_name, - second: current_name, - }); + let prev_proc = { + match &self.modules[prev_id.module.as_usize()] { + WrapperModule::Ast(module) => Some(&module[prev_id.index]), + WrapperModule::Exports(_) => None, + } + }; + let current_proc = { + match &self.modules[id.module.as_usize()] { + WrapperModule::Ast(module) => Some(&module[id.index]), + WrapperModule::Exports(_) => None, + } + }; + + // Note: For compiled procedures, we can't check further if they're compatible, + // so we assume they are. + if let (Some(prev_proc), Some(current_proc)) = (prev_proc, current_proc) { + if prev_proc.num_locals() != current_proc.num_locals() { + // Multiple procedures with the same root, but incompatible + let prev_module = self.modules[prev_id.module.as_usize()].path(); + let prev_name = FullyQualifiedProcedureName { + span: prev_proc.span(), + module: prev_module.clone(), + name: prev_proc.name().clone(), + }; + let current_module = self.modules[id.module.as_usize()].path(); + let current_name = FullyQualifiedProcedureName { + span: current_proc.span(), + module: current_module.clone(), + name: current_proc.name().clone(), + }; + return Err(AssemblyError::ConflictingDefinitions { + first: prev_name, + second: current_name, + }); + } } // Multiple procedures with the same root, but compatible @@ -557,43 +655,72 @@ impl ModuleGraph { } })?; let module = &self.modules[module_index.as_usize()]; - match module.resolve(&next.name) { - Some(ResolvedProcedure::Local(index)) => { - let id = GlobalProcedureIndex { - module: module_index, - index: index.into_inner(), - }; - break Ok(id); - } - Some(ResolvedProcedure::External(fqn)) => { - // If we see that we're about to enter an infinite resolver loop because of a - // recursive alias, return an error - if name == &fqn { - break Err(AssemblyError::RecursiveAlias { - source_file: caller.clone(), - name: name.clone(), - }); - } - next = Cow::Owned(fqn); - caller = module.source_file(); - } - Some(ResolvedProcedure::MastRoot(ref digest)) => { - if let Some(id) = self.get_procedure_index_by_digest(digest) { - break Ok(id); + + match module { + WrapperModule::Ast(module) => { + match module.resolve(&next.name) { + Some(ResolvedProcedure::Local(index)) => { + let id = GlobalProcedureIndex { + module: module_index, + index: index.into_inner(), + }; + break Ok(id); + } + Some(ResolvedProcedure::External(fqn)) => { + // If we see that we're about to enter an infinite resolver loop because + // of a recursive alias, return an error + if name == &fqn { + break Err(AssemblyError::RecursiveAlias { + source_file: caller.clone(), + name: name.clone(), + }); + } + next = Cow::Owned(fqn); + caller = module.source_file(); + } + Some(ResolvedProcedure::MastRoot(ref digest)) => { + if let Some(id) = self.get_procedure_index_by_digest(digest) { + break Ok(id); + } + break Err(AssemblyError::Failed { + labels: vec![RelatedLabel::error("undefined procedure") + .with_source_file(source_file) + .with_labeled_span( + next.span(), + "unable to resolve this reference", + )], + }); + } + None => { + // No such procedure known to `module` + break Err(AssemblyError::Failed { + labels: vec![RelatedLabel::error("undefined procedure") + .with_source_file(source_file) + .with_labeled_span( + next.span(), + "unable to resolve this reference", + )], + }); + } } - break Err(AssemblyError::Failed { - labels: vec![RelatedLabel::error("undefined procedure") - .with_source_file(source_file) - .with_labeled_span(next.span(), "unable to resolve this reference")], - }); } - None => { - // No such procedure known to `module` - break Err(AssemblyError::Failed { - labels: vec![RelatedLabel::error("undefined procedure") - .with_source_file(source_file) - .with_labeled_span(next.span(), "unable to resolve this reference")], - }); + WrapperModule::Exports(module) => { + break module + .procedures + .iter() + .find(|(_index, procedure)| procedure.name() == &name.name) + .map(|(index, _)| GlobalProcedureIndex { + module: module_index, + index: *index, + }) + .ok_or(AssemblyError::Failed { + labels: vec![RelatedLabel::error("undefined procedure") + .with_source_file(source_file) + .with_labeled_span( + next.span(), + "unable to resolve this reference", + )], + }) } } } @@ -605,13 +732,13 @@ impl ModuleGraph { } /// Resolve a [LibraryPath] to a [Module] in this graph - pub fn find_module(&self, name: &LibraryPath) -> Option> { + pub fn find_module(&self, name: &LibraryPath) -> Option { self.modules.iter().find(|m| m.path() == name).cloned() } /// Returns an iterator over the set of [Module]s in this graph, and their indices #[allow(unused)] - pub fn modules(&self) -> impl Iterator)> + '_ { + pub fn modules(&self) -> impl Iterator + '_ { self.modules .iter() .enumerate() @@ -620,44 +747,15 @@ impl ModuleGraph { /// Like [modules], but returns a reference to the module, rather than an owned pointer #[allow(unused)] - pub fn modules_by_ref(&self) -> impl Iterator + '_ { - self.modules - .iter() - .enumerate() - .map(|(idx, m)| (ModuleIndex::new(idx), m.as_ref())) - } - - /// Returns an iterator over the set of [Procedure]s in this graph, and their indices - #[allow(unused)] - pub fn procedures(&self) -> impl Iterator + '_ { - self.modules_by_ref().flat_map(|(module_index, module)| { - module.procedures().enumerate().filter_map(move |(index, p)| { - let index = ProcedureIndex::new(index); - let id = GlobalProcedureIndex { - module: module_index, - index, - }; - match p { - Export::Procedure(ref p) => Some((id, p)), - Export::Alias(_) => None, - } - }) - }) + pub fn modules_by_ref(&self) -> impl Iterator + '_ { + self.modules.iter().enumerate().map(|(idx, m)| (ModuleIndex::new(idx), m)) } } impl Index for ModuleGraph { - type Output = Arc; + type Output = WrapperModule; fn index(&self, index: ModuleIndex) -> &Self::Output { self.modules.index(index.as_usize()) } } - -impl Index for ModuleGraph { - type Output = Export; - - fn index(&self, index: GlobalProcedureIndex) -> &Self::Output { - self.modules[index.module.as_usize()].index(index.index) - } -} diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 66aa48e68f..f2dcde68ff 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -1,7 +1,10 @@ use alloc::{string::String, vec::Vec}; -use vm_core::mast::MastForest; +use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; -use crate::{ast::{FullyQualifiedProcedureName, ProcedureName}, LibraryPath, Version}; +use crate::{ + ast::{FullyQualifiedProcedureName, ProcedureName}, + LibraryPath, Version, +}; // TODOP: Refactor `FullyQualifiedProcedureName` instead, and use `Span` where needed? pub struct CompiledFullyQualifiedProcedureName { @@ -20,6 +23,22 @@ impl From for CompiledFullyQualifiedProcedureName { } } +#[derive(Clone)] +pub struct CompiledProcedure { + name: ProcedureName, + digest: RpoDigest, +} + +impl CompiledProcedure { + pub fn name(&self) -> &ProcedureName { + &self.name + } + + pub fn digest(&self) -> &RpoDigest { + &self.digest + } +} + // TODOP: Move into `miden-core` along with `LibraryPath` pub struct CompiledLibrary { mast_forest: MastForest, From c1947c2ff0d43d997712bfe717e76f5a22dbeb79 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 15:19:26 -0400 Subject: [PATCH 126/172] split `ModuleGraph::add_module()` --- assembly/src/assembler/mod.rs | 11 ++++--- assembly/src/assembler/module_graph/mod.rs | 37 ++++++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index a29d8250ed..4cca1fd9f0 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -13,6 +13,7 @@ use crate::{ }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; +use module_graph::PendingWrapperModule; use vm_core::{ mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, @@ -222,7 +223,7 @@ impl Assembler { let module = module.compile_with_options(options)?; assert_eq!(module.kind(), kind, "expected module kind to match compilation options"); - self.module_graph.add_module(module)?; + self.module_graph.add_ast_module(module)?; Ok(()) } @@ -322,7 +323,7 @@ impl Assembler { let module_ids: Vec = modules .map(|module| { let module = module.compile()?; - Ok(self.module_graph.add_module(module)?) + Ok(self.module_graph.add_ast_module(module)?) }) .collect::>()?; self.module_graph.recompute()?; @@ -444,7 +445,7 @@ impl Assembler { } // Recompute graph with executable module, and start compiling - let module_index = self.module_graph.add_module(program)?; + let module_index = self.module_graph.add_ast_module(program)?; self.module_graph.recompute()?; // Find the executable entrypoint @@ -500,7 +501,7 @@ impl Assembler { let module = module.compile_with_options(options)?; // Recompute graph with the provided module, and start assembly - let module_id = self.module_graph.add_module(module)?; + let module_id = self.module_graph.add_ast_module(module)?; self.module_graph.recompute()?; let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); @@ -536,7 +537,7 @@ impl Assembler { let mut context = AssemblyContext::for_kernel(module.path()); context.set_warnings_as_errors(self.warnings_as_errors); - let kernel_index = self.module_graph.add_module(module)?; + let kernel_index = self.module_graph.add_ast_module(module)?; self.module_graph.recompute()?; let kernel_module = self.module_graph[kernel_index].clone(); let mut kernel = Vec::new(); diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 0619b53041..d67aa36b57 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -54,12 +54,12 @@ impl<'a> WrapperProcedure<'a> { // TODOP: Rename (?) #[derive(Clone)] -pub struct ModuleExports { +pub struct CompiledModule { path: LibraryPath, procedures: Vec<(ProcedureIndex, CompiledProcedure)>, } -impl ModuleExports { +impl CompiledModule { pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { Self { path, @@ -75,7 +75,7 @@ impl ModuleExports { #[derive(Clone)] pub enum WrapperModule { Ast(Arc), - Exports(ModuleExports), + Exports(CompiledModule), } impl WrapperModule { @@ -91,7 +91,7 @@ impl WrapperModule { #[derive(Clone)] pub enum PendingWrapperModule { Ast(Box), - Exports(ModuleExports), + Exports(CompiledModule), } impl PendingWrapperModule { @@ -158,10 +158,35 @@ impl ModuleGraph { /// /// This function will panic if the number of modules exceeds the maximum representable /// [ModuleIndex] value, `u16::MAX`. - pub fn add_module( + pub fn add_ast_module(&mut self, module: Box) -> Result { + self.add_module(PendingWrapperModule::Ast(module)) + } + + /// Add compiled `module` to the graph. + /// + /// NOTE: This operation only adds a module to the graph, but does not perform the + /// important analysis needed for compilation, you must call [recompute] once all modules + /// are added to ensure the analysis results reflect the current version of the graph. + /// + /// # Errors + /// + /// This operation can fail for the following reasons: + /// + /// * Module with same [LibraryPath] is in the graph already + /// * Too many modules in the graph + /// + /// # Panics + /// + /// This function will panic if the number of modules exceeds the maximum representable + /// [ModuleIndex] value, `u16::MAX`. + pub fn add_compiled_module( &mut self, - module: PendingWrapperModule, + module: CompiledModule, ) -> Result { + self.add_module(PendingWrapperModule::Exports(module)) + } + + fn add_module(&mut self, module: PendingWrapperModule) -> Result { let is_duplicate = self.is_pending(module.path()) || self.find_module_index(module.path()).is_some(); if is_duplicate { From 78780325cce2b00aa5549cd103c170ea27bf94d9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 16:48:48 -0400 Subject: [PATCH 127/172] fix compile errors from API changes --- assembly/src/assembler/mod.rs | 11 +++++++---- assembly/src/assembler/module_graph/mod.rs | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 4cca1fd9f0..3975895404 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -450,6 +450,7 @@ impl Assembler { // Find the executable entrypoint let entrypoint = self.module_graph[module_index] + .unwrap_ast() .index_of(|p| p.is_main()) .map(|index| GlobalProcedureIndex { module: module_index, @@ -542,6 +543,7 @@ impl Assembler { let kernel_module = self.module_graph[kernel_index].clone(); let mut kernel = Vec::new(); for (index, _syscall) in kernel_module + .unwrap_ast() .procedures() .enumerate() .filter(|(_, p)| p.visibility().is_syscall()) @@ -636,7 +638,7 @@ impl Assembler { mut mast_forest_builder: MastForestBuilder, ) -> Result { // Raise an error if we are called with an invalid entrypoint - assert!(self.module_graph[entrypoint].name().is_main()); + assert!(self.module_graph.get_procedure_unsafe(entrypoint).name().is_main()); // Compile the module graph rooted at the entrypoint let entry_procedure = @@ -680,8 +682,8 @@ impl Assembler { let mut nodes = Vec::with_capacity(iter.len()); for node in iter { let module = self.module_graph[node.module].path(); - let proc = self.module_graph[node].name(); - nodes.push(format!("{}::{}", module, proc)); + let proc = self.module_graph.get_procedure_unsafe(node); + nodes.push(format!("{}::{}", module, proc.name())); } AssemblyError::Cycle { nodes } })?; @@ -768,7 +770,8 @@ impl Assembler { let num_locals = procedure.num_locals(); context.set_current_procedure(procedure); - let proc = self.module_graph[gid].unwrap_procedure(); + let wrapper_proc = self.module_graph.get_procedure_unsafe(gid); + let proc = wrapper_proc.unwrap_ast().unwrap_procedure(); let proc_body_root = if num_locals > 0 { // for procedures with locals, we need to update fmp register before and after the // procedure body is executed. specifically: diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index d67aa36b57..2021e8963b 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -50,6 +50,13 @@ impl<'a> WrapperProcedure<'a> { WrapperProcedure::Compiled(p) => p.name(), } } + + pub fn unwrap_ast(&self) -> &Export { + match self { + WrapperProcedure::Ast(proc) => proc, + WrapperProcedure::Compiled(_) => panic!("expected AST procedure, but was compiled"), + } + } } // TODOP: Rename (?) @@ -85,6 +92,15 @@ impl WrapperModule { WrapperModule::Exports(m) => &m.path, } } + + pub fn unwrap_ast(&self) -> Arc { + match self { + WrapperModule::Ast(module) => module.clone(), + WrapperModule::Exports(_) => { + panic!("expected module to be in AST representation, but was compiled") + } + } + } } // TODOP: Try to do without this `Pending*` version From b3e337a0087d759cdae78588939e6e24c652ce88 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Wed, 17 Jul 2024 17:01:29 -0400 Subject: [PATCH 128/172] fix debug structs --- assembly/src/assembler/mod.rs | 1 - assembly/src/assembler/module_graph/debug.rs | 101 ++++++++++++++----- assembly/src/assembler/module_graph/mod.rs | 4 + 3 files changed, 78 insertions(+), 28 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 3975895404..b526691547 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -13,7 +13,6 @@ use crate::{ }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; -use module_graph::PendingWrapperModule; use vm_core::{ mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index 9858e8ebff..2cf640602b 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -17,54 +17,100 @@ struct DisplayModuleGraph<'a>(&'a ModuleGraph); impl<'a> fmt::Debug for DisplayModuleGraph<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_set() - .entries(self.0.modules.iter().enumerate().flat_map(|(index, m)| { - m.procedures().enumerate().filter_map(move |(i, export)| { - if matches!(export, Export::Alias(_)) { - None - } else { - let gid = GlobalProcedureIndex { - module: ModuleIndex::new(index), - index: ProcedureIndex::new(i), - }; - let out_edges = self.0.callgraph.out_edges(gid); - Some(DisplayModuleGraphNodeWithEdges { gid, out_edges }) - } - }) + .entries(self.0.modules.iter().enumerate().flat_map(|(module_index, m)| { + match m { + WrapperModule::Ast(m) => m + .procedures() + .enumerate() + .filter_map(move |(i, export)| { + if matches!(export, Export::Alias(_)) { + None + } else { + let gid = GlobalProcedureIndex { + module: ModuleIndex::new(module_index), + index: ProcedureIndex::new(i), + }; + let out_edges = self.0.callgraph.out_edges(gid); + Some(DisplayModuleGraphNodeWithEdges { gid, out_edges }) + } + }) + .collect::>(), + WrapperModule::Exports(m) => m + .procedures + .iter() + .map(|(proc_index, _proc)| { + let gid = GlobalProcedureIndex { + module: ModuleIndex::new(module_index), + index: *proc_index, + }; + + let out_edges = self.0.callgraph.out_edges(gid); + DisplayModuleGraphNodeWithEdges { gid, out_edges } + }) + .collect::>(), + } })) .finish() } } #[doc(hidden)] -struct DisplayModuleGraphNodes<'a>(&'a Vec>); +struct DisplayModuleGraphNodes<'a>(&'a Vec); impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_list() - .entries(self.0.iter().enumerate().flat_map(|(index, m)| { - m.procedures().enumerate().filter_map(move |(i, export)| { - if matches!(export, Export::Alias(_)) { - None - } else { - Some(DisplayModuleGraphNode { - module: ModuleIndex::new(index), - index: ProcedureIndex::new(i), + .entries(self.0.iter().enumerate().flat_map(|(module_index, m)| { + let module_index = ModuleIndex::new(module_index); + + match m { + WrapperModule::Ast(m) => m + .procedures() + .enumerate() + .filter_map(move |(proc_index, export)| { + if matches!(export, Export::Alias(_)) { + None + } else { + Some(DisplayModuleGraphNode { + module: module_index, + index: ProcedureIndex::new(proc_index), + path: m.path(), + proc_name: export.name(), + ty: GraphNodeType::Ast, + }) + } + }) + .collect::>(), + WrapperModule::Exports(m) => m + .procedures + .iter() + .map(|(proc_index, proc)| DisplayModuleGraphNode { + module: module_index, + index: *proc_index, path: m.path(), - proc: export, + proc_name: proc.name(), + ty: GraphNodeType::Compiled, }) - } - }) + .collect::>(), + } })) .finish() } } +#[derive(Debug)] +enum GraphNodeType { + Ast, + Compiled, +} + #[doc(hidden)] struct DisplayModuleGraphNode<'a> { module: ModuleIndex, index: ProcedureIndex, path: &'a LibraryPath, - proc: &'a Export, + proc_name: &'a ProcedureName, + ty: GraphNodeType, } impl<'a> fmt::Debug for DisplayModuleGraphNode<'a> { @@ -72,7 +118,8 @@ impl<'a> fmt::Debug for DisplayModuleGraphNode<'a> { f.debug_struct("Node") .field("id", &format_args!("{}:{}", &self.module.as_usize(), &self.index.as_usize())) .field("module", &self.path) - .field("name", &self.proc.name()) + .field("name", &self.proc_name) + .field("type", &self.ty) .finish() } } diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 2021e8963b..5e5efbe5ef 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -76,6 +76,10 @@ impl CompiledModule { .collect(), } } + + pub fn path(&self) -> &LibraryPath { + &self.path + } } // TODOP: Rename From a4a5e1751ef71f302c2aa3c0c3aabc81be3a2bd2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 11:16:25 -0400 Subject: [PATCH 129/172] fix `Assembler::get_module_exports()` --- assembly/src/assembler/mod.rs | 144 ++++++++++++--------- assembly/src/assembler/module_graph/mod.rs | 4 + assembly/src/compiled_library.rs | 15 ++- 3 files changed, 97 insertions(+), 66 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b526691547..fbc0638abc 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -5,6 +5,7 @@ use crate::{ }, compiled_library::{ CompiledFullyQualifiedProcedureName, CompiledLibrary, CompiledLibraryMetadata, + CompiledProcedure, }, diagnostics::{tracing::instrument, Report}, sema::SemanticAnalysisError, @@ -335,11 +336,20 @@ impl Assembler { let exports = { let mut exports = Vec::new(); for module_id in module_ids { - let exports_in_module: Vec = - self.get_module_exports(module_id).map(|procedures| { + let module = self.module_graph.get_module(module_id).unwrap(); + let module_path = module.path(); + + let exports_in_module: Vec = self + .get_module_exports(module_id, mast_forest_builder.forest()) + .map(|procedures| { procedures .into_iter() - .map(|proc| proc.fully_qualified_name().clone().into()) + .map(|proc| { + CompiledFullyQualifiedProcedureName::new( + module_path.clone(), + proc.name, + ) + }) .collect() })?; @@ -507,15 +517,9 @@ impl Assembler { let mut mast_forest_builder = core::mem::take(&mut self.mast_forest_builder); self.assemble_graph(context, &mut mast_forest_builder)?; - let exported_procedure_digests = self.get_module_exports(module_id).map(|procedures| { - procedures - .into_iter() - .map(|proc| { - let proc_code_node = &mast_forest_builder.forest()[proc.body_node_id()]; - proc_code_node.digest() - }) - .collect() - }); + let exported_procedure_digests = self + .get_module_exports(module_id, mast_forest_builder.forest()) + .map(|procedures| procedures.into_iter().map(|proc| proc.digest).collect()); // Reassign the mast_forest to the assembler for use is a future program assembly self.mast_forest_builder = mast_forest_builder; @@ -560,67 +564,83 @@ impl Assembler { .map_err(|err| Report::new(AssemblyError::Kernel(err))) } + // TODOP: Fix docs /// Get the set of exported procedures of the given module. /// /// Returns an error if the provided Miden Assembly is invalid. fn get_module_exports( &mut self, - module: ModuleIndex, + module_index: ModuleIndex, + mast_forest: &MastForest, // TODOP: Return iterator instead? - ) -> Result>, Report> { - assert!(self.module_graph.contains_module(module), "invalid module index"); - - let mut exports = Vec::new(); - for (index, procedure) in self.module_graph[module].procedures().enumerate() { - // Only add exports; locals will be added if they are in the call graph rooted at those - // procedures - if !procedure.visibility().is_exported() { - continue; - } - let gid = match procedure { - Export::Procedure(_) => GlobalProcedureIndex { - module, - index: ProcedureIndex::new(index), - }, - Export::Alias(ref alias) => { - match alias.target() { - AliasTarget::MastRoot(digest) => { - self.procedure_cache.contains_mast_root(digest) - .unwrap_or_else(|| { - panic!( - "compilation apparently succeeded, but did not find a \ - entry in the procedure cache for alias '{}', i.e. '{}'", - alias.name(), - digest - ); - }) - } - AliasTarget::Path(ref name)=> { - self.module_graph.find(alias.source_file(), name)? - } + ) -> Result, Report> { + assert!(self.module_graph.contains_module(module_index), "invalid module index"); + + let exports: Vec = match &self.module_graph[module_index] { + module_graph::WrapperModule::Ast(module) => { + let mut exports = Vec::new(); + for (index, procedure) in module.procedures().enumerate() { + // Only add exports; locals will be added if they are in the call graph rooted + // at those procedures + if !procedure.visibility().is_exported() { + continue; } - } - }; - let proc = self.procedure_cache.get(gid).unwrap_or_else(|| match procedure { - Export::Procedure(ref proc) => { - panic!( - "compilation apparently succeeded, but did not find a \ + let gid = match procedure { + Export::Procedure(_) => GlobalProcedureIndex { + module: module_index, + index: ProcedureIndex::new(index), + }, + Export::Alias(ref alias) => { + match alias.target() { + AliasTarget::MastRoot(digest) => { + self.procedure_cache.contains_mast_root(digest) + .unwrap_or_else(|| { + panic!( + "compilation apparently succeeded, but did not find a \ + entry in the procedure cache for alias '{}', i.e. '{}'", + alias.name(), + digest + ); + }) + } + AliasTarget::Path(ref name)=> { + self.module_graph.find(alias.source_file(), name)? + } + } + } + }; + let proc = self.procedure_cache.get(gid).unwrap_or_else(|| match procedure { + Export::Procedure(ref proc) => { + panic!( + "compilation apparently succeeded, but did not find a \ entry in the procedure cache for '{}'", - proc.name() - ) - } - Export::Alias(ref alias) => { - panic!( - "compilation apparently succeeded, but did not find a \ + proc.name() + ) + } + Export::Alias(ref alias) => { + panic!( + "compilation apparently succeeded, but did not find a \ entry in the procedure cache for alias '{}', i.e. '{}'", - alias.name(), - alias.target() - ); + alias.name(), + alias.target() + ); + } + }); + + let compiled_proc = CompiledProcedure { + name: proc.name().clone(), + digest: mast_forest[proc.body_node_id()].digest(), + }; + + exports.push(compiled_proc); } - }); - exports.push(proc); - } + exports + } + module_graph::WrapperModule::Exports(module) => { + module.procedures().iter().map(|(_idx, proc)| proc).cloned().collect() + } + }; Ok(exports) } diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 5e5efbe5ef..fa0b83636e 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -80,6 +80,10 @@ impl CompiledModule { pub fn path(&self) -> &LibraryPath { &self.path } + + pub fn procedures(&self) -> &[(ProcedureIndex, CompiledProcedure)] { + &self.procedures + } } // TODOP: Rename diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index f2dcde68ff..2b49ec3a2f 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -9,15 +9,21 @@ use crate::{ // TODOP: Refactor `FullyQualifiedProcedureName` instead, and use `Span` where needed? pub struct CompiledFullyQualifiedProcedureName { /// The module path for this procedure. - pub module: LibraryPath, + pub module_path: LibraryPath, /// The name of the procedure. pub name: ProcedureName, } +impl CompiledFullyQualifiedProcedureName { + pub fn new(module_path: LibraryPath, name: ProcedureName) -> Self { + Self { module_path, name } + } +} + impl From for CompiledFullyQualifiedProcedureName { fn from(fqdn: FullyQualifiedProcedureName) -> Self { Self { - module: fqdn.module, + module_path: fqdn.module, name: fqdn.name, } } @@ -25,10 +31,11 @@ impl From for CompiledFullyQualifiedProcedureName { #[derive(Clone)] pub struct CompiledProcedure { - name: ProcedureName, - digest: RpoDigest, + pub name: ProcedureName, + pub digest: RpoDigest, } +// TODOP: Remove methods in favor of pub fields? impl CompiledProcedure { pub fn name(&self) -> &ProcedureName { &self.name From b13016782a95c11ad22426154d2a4798f2b3458c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 11:23:14 -0400 Subject: [PATCH 130/172] fix `process_graph_worklist` --- assembly/src/assembler/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index fbc0638abc..d6c52e2b52 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -742,7 +742,7 @@ impl Assembler { let is_entry = entrypoint == Some(procedure_gid); // Fetch procedure metadata from the graph - let module = &self.module_graph[procedure_gid.module]; + let module = &self.module_graph[procedure_gid.module].unwrap_ast(); let ast = &module[procedure_gid.index]; let num_locals = ast.num_locals(); let name = FullyQualifiedProcedureName { From 1b5a1b5ba06a5b4e0f7a946dec188ed405f94a14 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 11:25:45 -0400 Subject: [PATCH 131/172] fix procedure --- assembly/src/assembler/instruction/procedures.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 9894a82092..8412094984 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -59,7 +59,7 @@ impl Assembler { callee: proc.fully_qualified_name().clone(), }) .and_then(|module| { - if module.is_kernel() { + if module.unwrap_ast().is_kernel() { Ok(()) } else { Err(AssemblyError::InvalidSysCallTarget { From 9bb9ed5c43f895d49b9ce41d9c467e28c75032b6 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 11:32:25 -0400 Subject: [PATCH 132/172] fix `NameResolver` --- assembly/src/assembler/module_graph/mod.rs | 6 +++--- assembly/src/assembler/module_graph/name_resolver.rs | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index fa0b83636e..1963ca8864 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -101,9 +101,9 @@ impl WrapperModule { } } - pub fn unwrap_ast(&self) -> Arc { + pub fn unwrap_ast(&self) -> &Arc { match self { - WrapperModule::Ast(module) => module.clone(), + WrapperModule::Ast(module) => module, WrapperModule::Exports(_) => { panic!("expected module to be in AST representation, but was compiled") } @@ -558,7 +558,7 @@ impl ModuleGraph { /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex], or `None` if index is invalid. pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { match &self.modules[id.module.as_usize()] { - WrapperModule::Ast(m) => m.get(id.index).map(|export| WrapperProcedure::Ast(export)), + WrapperModule::Ast(m) => m.get(id.index).map(WrapperProcedure::Ast), WrapperModule::Exports(m) => m .procedures .get(id.index.as_usize()) diff --git a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index 6070a52e18..8f4c93ee68 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -177,7 +177,7 @@ impl<'a> NameResolver<'a> { .get_name(gid.index) .clone() } else { - self.graph[gid].name().clone() + self.graph.get_procedure_unsafe(gid).name().clone() }; Ok(ResolvedTarget::Resolved { gid, @@ -210,6 +210,7 @@ impl<'a> NameResolver<'a> { .resolve_import(name) } else { self.graph[caller.module] + .unwrap_ast() .resolve_import(name) .map(|import| Span::new(import.span(), import.path())) } @@ -255,7 +256,7 @@ impl<'a> NameResolver<'a> { .get_name(gid.index) .clone() } else { - self.graph[gid].name().clone() + self.graph.get_procedure_unsafe(gid).name().clone() }; Ok(ResolvedTarget::Resolved { gid, @@ -315,7 +316,7 @@ impl<'a> NameResolver<'a> { if module_index >= pending_offset { self.pending[module_index - pending_offset].resolver.resolve(callee) } else { - self.graph[module].resolve(callee) + self.graph[module].unwrap_ast().resolve(callee) } } @@ -487,7 +488,7 @@ impl<'a> NameResolver<'a> { if module_index >= pending_offset { self.pending[module_index - pending_offset].source_file.clone() } else { - self.graph[module].source_file() + self.graph[module].unwrap_ast().source_file() } } From f5699f1386e9dd7acb7ae55691961b739dbb62f0 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 12:56:30 -0400 Subject: [PATCH 133/172] move `CompiledModule` --- assembly/src/assembler/module_graph/debug.rs | 4 +- assembly/src/assembler/module_graph/mod.rs | 41 ++++---------------- assembly/src/compiled_library.rs | 34 +++++++++++++++- 3 files changed, 42 insertions(+), 37 deletions(-) diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index 2cf640602b..4596cf0e3b 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -36,7 +36,7 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { }) .collect::>(), WrapperModule::Exports(m) => m - .procedures + .procedures() .iter() .map(|(proc_index, _proc)| { let gid = GlobalProcedureIndex { @@ -82,7 +82,7 @@ impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { }) .collect::>(), WrapperModule::Exports(m) => m - .procedures + .procedures() .iter() .map(|(proc_index, proc)| DisplayModuleGraphNode { module: module_index, diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 1963ca8864..661e98fd5a 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -27,7 +27,7 @@ use self::{ rewrites::ModuleRewriter, }; use super::{GlobalProcedureIndex, ModuleIndex}; -use crate::compiled_library::CompiledProcedure; +use crate::compiled_library::{CompiledModule, CompiledProcedure}; use crate::{ ast::{ Export, FullyQualifiedProcedureName, InvocationTarget, Module, ProcedureIndex, @@ -59,33 +59,6 @@ impl<'a> WrapperProcedure<'a> { } } -// TODOP: Rename (?) -#[derive(Clone)] -pub struct CompiledModule { - path: LibraryPath, - procedures: Vec<(ProcedureIndex, CompiledProcedure)>, -} - -impl CompiledModule { - pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { - Self { - path, - procedures: procedures - .enumerate() - .map(|(idx, proc)| (ProcedureIndex::new(idx), proc)) - .collect(), - } - } - - pub fn path(&self) -> &LibraryPath { - &self.path - } - - pub fn procedures(&self) -> &[(ProcedureIndex, CompiledProcedure)] { - &self.procedures - } -} - // TODOP: Rename #[derive(Clone)] pub enum WrapperModule { @@ -97,7 +70,7 @@ impl WrapperModule { pub fn path(&self) -> &LibraryPath { match self { WrapperModule::Ast(m) => m.path(), - WrapperModule::Exports(m) => &m.path, + WrapperModule::Exports(m) => m.path(), } } @@ -122,7 +95,7 @@ impl PendingWrapperModule { pub fn path(&self) -> &LibraryPath { match self { PendingWrapperModule::Ast(m) => m.path(), - PendingWrapperModule::Exports(m) => &m.path, + PendingWrapperModule::Exports(m) => m.path(), } } } @@ -393,7 +366,7 @@ impl ModuleGraph { } } PendingWrapperModule::Exports(pending_module) => { - for (procedure_id, _procedure) in pending_module.procedures.iter() { + for (procedure_id, _procedure) in pending_module.procedures().iter() { let global_id = GlobalProcedureIndex { module: module_id, index: *procedure_id, @@ -560,7 +533,7 @@ impl ModuleGraph { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => m.get(id.index).map(WrapperProcedure::Ast), WrapperModule::Exports(m) => m - .procedures + .procedures() .get(id.index.as_usize()) .map(|(_idx, proc)| WrapperProcedure::Compiled(proc)), } @@ -574,7 +547,7 @@ impl ModuleGraph { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => WrapperProcedure::Ast(&m[id.index]), WrapperModule::Exports(m) => { - WrapperProcedure::Compiled(&m.procedures[id.index.as_usize()].1) + WrapperProcedure::Compiled(&m.procedures()[id.index.as_usize()].1) } } } @@ -755,7 +728,7 @@ impl ModuleGraph { } WrapperModule::Exports(module) => { break module - .procedures + .procedures() .iter() .find(|(_index, procedure)| procedure.name() == &name.name) .map(|(index, _)| GlobalProcedureIndex { diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 2b49ec3a2f..6eba91cdd0 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -2,7 +2,7 @@ use alloc::{string::String, vec::Vec}; use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; use crate::{ - ast::{FullyQualifiedProcedureName, ProcedureName}, + ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}, LibraryPath, Version, }; @@ -81,9 +81,41 @@ impl CompiledLibrary { pub fn metadata(&self) -> &CompiledLibraryMetadata { &self.metadata } + + pub fn into_compiled_modules(self) -> Vec { + todo!() + } } pub struct CompiledLibraryMetadata { pub name: String, pub version: Version, } + +// TODOP: Rename (?) +#[derive(Clone)] +pub struct CompiledModule { + path: LibraryPath, + procedures: Vec<(ProcedureIndex, CompiledProcedure)>, +} + +impl CompiledModule { + pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { + Self { + path, + procedures: procedures + .enumerate() + .map(|(idx, proc)| (ProcedureIndex::new(idx), proc)) + .collect(), + } + } + + pub fn path(&self) -> &LibraryPath { + &self.path + } + + // TODOP: Store as `CompiledProcedure`, and add a method `iter()` that iterates with `ProcedureIndex` + pub fn procedures(&self) -> &[(ProcedureIndex, CompiledProcedure)] { + &self.procedures + } +} From 8d148ab8753e309b11b52b9f2dc3c8c88bd615b6 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 13:37:36 -0400 Subject: [PATCH 134/172] `CompiledLibrary::into_compiled_modules` --- assembly/src/compiled_library.rs | 44 ++++++++++++++++++++++--- core/src/mast/mod.rs | 6 ++-- processor/src/host/mast_forest_store.rs | 7 ++-- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 6eba91cdd0..29372431dd 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -1,5 +1,8 @@ -use alloc::{string::String, vec::Vec}; -use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; +use alloc::{collections::BTreeMap, string::String, vec::Vec}; +use vm_core::{ + crypto::hash::RpoDigest, + mast::{MastForest, MerkleTreeNode}, +}; use crate::{ ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}, @@ -56,6 +59,7 @@ pub struct CompiledLibrary { /// Constructors impl CompiledLibrary { + // TODOP: Add validation that num roots = num exports pub fn new( mast_forest: MastForest, exports: Vec, @@ -83,7 +87,33 @@ impl CompiledLibrary { } pub fn into_compiled_modules(self) -> Vec { - todo!() + let mut modules_by_path: BTreeMap = BTreeMap::new(); + + for (proc_index, proc_name) in self.exports.into_iter().enumerate() { + modules_by_path + .entry(proc_name.module_path.clone()) + .and_modify(|compiled_module| { + let proc_node_id = self.mast_forest.procedure_roots()[proc_index]; + let proc_digest = self.mast_forest[proc_node_id].digest(); + + compiled_module.add_procedure(CompiledProcedure { + name: proc_name.name.clone(), + digest: proc_digest, + }) + }) + .or_insert_with(|| { + let proc_node_id = self.mast_forest.procedure_roots()[proc_index]; + let proc_digest = self.mast_forest[proc_node_id].digest(); + let proc = CompiledProcedure { + name: proc_name.name, + digest: proc_digest, + }; + + CompiledModule::new(proc_name.module_path, core::iter::once(proc)) + }); + } + + modules_by_path.into_values().collect() } } @@ -110,11 +140,17 @@ impl CompiledModule { } } + pub fn add_procedure(&mut self, procedure: CompiledProcedure) { + let index = ProcedureIndex::new(self.procedures.len()); + self.procedures.push((index, procedure)); + } + pub fn path(&self) -> &LibraryPath { &self.path } - // TODOP: Store as `CompiledProcedure`, and add a method `iter()` that iterates with `ProcedureIndex` + // TODOP: Store as `CompiledProcedure`, and add a method `iter()` that iterates with + // `ProcedureIndex` pub fn procedures(&self) -> &[(ProcedureIndex, CompiledProcedure)] { &self.procedures } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 70a87c58c1..65b0a6be02 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -149,9 +149,9 @@ impl MastForest { self.roots.iter().find(|&&root_id| self[root_id].digest() == digest).copied() } - /// Returns an iterator over the digest of the procedures in this MAST forest. - pub fn procedure_roots(&self) -> impl Iterator + '_ { - self.roots.iter().map(|&root_id| self[root_id].digest()) + /// Returns an iterator over the IDs of the procedures in this MAST forest. + pub fn procedure_roots(&self) -> &[MastNodeId] { + &self.roots } /// Returns the number of procedures in this MAST forest. diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs index 06c3634250..f8e47eb01c 100644 --- a/processor/src/host/mast_forest_store.rs +++ b/processor/src/host/mast_forest_store.rs @@ -1,5 +1,8 @@ use alloc::{collections::BTreeMap, sync::Arc}; -use vm_core::{crypto::hash::RpoDigest, mast::MastForest}; +use vm_core::{ + crypto::hash::RpoDigest, + mast::{MastForest, MerkleTreeNode}, +}; /// A set of [`MastForest`]s available to the prover that programs may refer to (by means of an /// [`vm_core::mast::ExternalNode`]). @@ -25,7 +28,7 @@ impl MemMastForestStore { pub fn insert(&mut self, mast_forest: MastForest) { let mast_forest = Arc::new(mast_forest); - for root in mast_forest.procedure_roots() { + for root in mast_forest.procedure_roots().iter().map(|&id| mast_forest[id].digest()) { self.mast_forests.insert(root, mast_forest.clone()); } } From 2904e1ed4e3790c9972b8501a500db61880d6ff7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 13:41:29 -0400 Subject: [PATCH 135/172] `Assembler::add_compiled_library()` --- assembly/src/assembler/mod.rs | 9 +++++++++ assembly/src/assembler/module_graph/mod.rs | 1 + 2 files changed, 10 insertions(+) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index d6c52e2b52..871868dca5 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -228,6 +228,15 @@ impl Assembler { Ok(()) } + /// TODOP: Add `vendored` flag and document + pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { + for module in library.into_compiled_modules() { + self.module_graph.add_compiled_module(module)?; + } + + Ok(()) + } + /// Adds the library to provide modules for the compilation. pub fn with_library(mut self, library: &L) -> Result where diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 661e98fd5a..4196bb67df 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -529,6 +529,7 @@ impl ModuleGraph { } /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex], or `None` if index is invalid. + #[allow(unused)] pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => m.get(id.index).map(WrapperProcedure::Ast), From 2adf48dc9ca9ce8dc78f3f7f4755b1a548111c52 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 13:43:37 -0400 Subject: [PATCH 136/172] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89c5cd9821..8fd073d099 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362). - Optimized `std::sys::truncate_stuck` procedure (#1384). - Add serialization/deserialization for `MastForest` (#1370) +- Assembler: add the ability to compile MAST libraries, and to assemble a program using compiled libraries (#1401) #### Changed From f9d6c150b42da7839ae9ebf517f9d6e869b29b00 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 13:53:59 -0400 Subject: [PATCH 137/172] fix `assemble_library()` signature --- assembly/src/assembler/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 871868dca5..ab154b8289 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -324,9 +324,9 @@ impl Assembler { /// Compilation/Assembly impl Assembler { // TODOP: Document - pub fn assemble_library>( + pub fn assemble_library( mut self, - modules: T, + modules: impl Iterator, metadata: CompiledLibraryMetadata, // name, version etc. ) -> Result { let module_ids: Vec = modules From 506194146b83c4582e53ae7b959bacef6b5775c8 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 15:20:20 -0400 Subject: [PATCH 138/172] test `compiled_library()` --- assembly/src/assembler/mod.rs | 2 +- assembly/src/compiled_library.rs | 4 +- assembly/src/tests.rs | 80 +++++++++++++++++++++++++++++++- 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index ab154b8289..7be132cc69 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -331,7 +331,7 @@ impl Assembler { ) -> Result { let module_ids: Vec = modules .map(|module| { - let module = module.compile()?; + let module = module.compile_with_options(CompileOptions::for_library())?; Ok(self.module_graph.add_ast_module(module)?) }) .collect::>()?; diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 29372431dd..6a4931eae8 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -1,4 +1,4 @@ -use alloc::{collections::BTreeMap, string::String, vec::Vec}; +use alloc::{collections::BTreeMap, vec::Vec}; use vm_core::{ crypto::hash::RpoDigest, mast::{MastForest, MerkleTreeNode}, @@ -118,7 +118,7 @@ impl CompiledLibrary { } pub struct CompiledLibraryMetadata { - pub name: String, + pub path: LibraryPath, pub version: Version, } diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 07cdc00917..7ae11f8760 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -3,10 +3,12 @@ use alloc::{rc::Rc, string::ToString, vec::Vec}; use crate::{ assert_diagnostic_lines, ast::{Module, ModuleKind}, + compiled_library::CompiledLibraryMetadata, diagnostics::Report, regex, source_file, testing::{Pattern, TestContext}, - Assembler, AssemblyContext, Library, LibraryNamespace, LibraryPath, MaslLibrary, Version, + Assembler, AssemblyContext, Library, LibraryNamespace, LibraryPath, MaslLibrary, ModuleParser, + Version, }; type TestResult = Result<(), Report>; @@ -2383,6 +2385,82 @@ fn invalid_while() -> TestResult { Ok(()) } +// COMPILED LIBRARIES +// ================================================================================================ +#[test] +fn test_compiled_library() { + let mut mod_parser = ModuleParser::new(ModuleKind::Library); + let mod1 = { + let source = source_file!( + " + proc.internal + push.5 + end + + export.foo + push.1 + drop + end + + export.bar + exec.internal + drop + end + " + ); + mod_parser.parse(LibraryPath::new("mod1").unwrap(), source).unwrap() + }; + + let mod2 = { + let source = source_file!( + " + export.foo + push.7 + add.5 + end + + # Same definition as mod1::foo + export.bar + push.1 + drop + end + " + ); + mod_parser.parse(LibraryPath::new("mod2").unwrap(), source).unwrap() + }; + + let metadata = CompiledLibraryMetadata { + path: LibraryPath::new("mylib").unwrap(), + version: Version::min(), + }; + + let compiled_library = { + let assembler = Assembler::new(); + assembler.assemble_library(vec![mod1, mod2].into_iter(), metadata).unwrap() + }; + + assert_eq!(compiled_library.exports().len(), 4); + + // Compile program that uses compiled library + let mut assembler = Assembler::new(); + + assembler.add_compiled_library(compiled_library).unwrap(); + + let program_source = " + use.mylib::mod1 + use.mylib::mod2 + + begin + exec.mod1::foo + exec.mod1::bar + exec.mod2::foo + exec.mod2::bar + end + "; + + let _program = assembler.assemble(program_source).unwrap(); +} + // DUMMY LIBRARY // ================================================================================================ From 4fba058df0279e02879cdeae7e20a8eb849613fb Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 15:55:30 -0400 Subject: [PATCH 139/172] nits --- assembly/src/assembler/mod.rs | 1 + assembly/src/compiled_library.rs | 4 ++-- assembly/src/tests.rs | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 7be132cc69..97e456035b 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -324,6 +324,7 @@ impl Assembler { /// Compilation/Assembly impl Assembler { // TODOP: Document + // TODOP: Check that `CompiledLibraryMetadata` is consistent with `modules` (e.g. modules path indeed start with library name) pub fn assemble_library( mut self, modules: impl Iterator, diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 6a4931eae8..fed148083f 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -32,7 +32,7 @@ impl From for CompiledFullyQualifiedProcedureName { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct CompiledProcedure { pub name: ProcedureName, pub digest: RpoDigest, @@ -123,7 +123,7 @@ pub struct CompiledLibraryMetadata { } // TODOP: Rename (?) -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct CompiledModule { path: LibraryPath, procedures: Vec<(ProcedureIndex, CompiledProcedure)>, diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 7ae11f8760..61fda8c32e 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -2408,7 +2408,7 @@ fn test_compiled_library() { end " ); - mod_parser.parse(LibraryPath::new("mod1").unwrap(), source).unwrap() + mod_parser.parse(LibraryPath::new("mylib::mod1").unwrap(), source).unwrap() }; let mod2 = { @@ -2426,7 +2426,7 @@ fn test_compiled_library() { end " ); - mod_parser.parse(LibraryPath::new("mod2").unwrap(), source).unwrap() + mod_parser.parse(LibraryPath::new("mylib::mod2").unwrap(), source).unwrap() }; let metadata = CompiledLibraryMetadata { From c287f7382709c11748835be09b15a67bd57bff0c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 16:09:43 -0400 Subject: [PATCH 140/172] register mast roots in `Assembler::add_compiled_library()` --- assembly/src/assembler/mod.rs | 23 ++++++++++++++++++++-- assembly/src/assembler/module_graph/mod.rs | 9 +++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 97e456035b..f79ff823aa 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -230,8 +230,27 @@ impl Assembler { /// TODOP: Add `vendored` flag and document pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { - for module in library.into_compiled_modules() { - self.module_graph.add_compiled_module(module)?; + let module_indexes: Vec = library + .into_compiled_modules() + .into_iter() + .map(|module| self.module_graph.add_compiled_module(module)) + .collect::>()?; + + // TODOP: Try to remove this recompute() + self.module_graph.recompute()?; + + // Register all procedures as roots + for module_index in module_indexes { + for (proc_index, proc) in + self.module_graph[module_index].unwrap_compiled().clone().procedures() + { + let gid = GlobalProcedureIndex { + module: module_index, + index: *proc_index, + }; + + self.module_graph.register_mast_root(gid, proc.digest)?; + } } Ok(()) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 4196bb67df..3f1aee34cf 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -82,6 +82,15 @@ impl WrapperModule { } } } + + pub fn unwrap_compiled(&self) -> &CompiledModule { + match self { + WrapperModule::Ast(_) => { + panic!("expected module to be compiled, but was in AST representation") + } + WrapperModule::Exports(module) => module, + } + } } // TODOP: Try to do without this `Pending*` version From f5880e798aa4a395c6063455805952e70a94e224 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 16:25:29 -0400 Subject: [PATCH 141/172] fix resolve --- assembly/src/assembler/module_graph/mod.rs | 8 ++++++++ .../assembler/module_graph/name_resolver.rs | 2 +- assembly/src/compiled_library.rs | 18 ++++++++++++++++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 3f1aee34cf..ed7522b888 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -91,6 +91,14 @@ impl WrapperModule { WrapperModule::Exports(module) => module, } } + + /// Resolves `name` to a procedure within the local scope of this module + pub fn resolve(&self, name: &ProcedureName) -> Option { + match self { + WrapperModule::Ast(module) => module.resolve(name), + WrapperModule::Exports(module) => module.resolve(name), + } + } } // TODOP: Try to do without this `Pending*` version diff --git a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index 8f4c93ee68..c5f1ee1a9e 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -316,7 +316,7 @@ impl<'a> NameResolver<'a> { if module_index >= pending_offset { self.pending[module_index - pending_offset].resolver.resolve(callee) } else { - self.graph[module].unwrap_ast().resolve(callee) + self.graph[module].resolve(callee) } } diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index fed148083f..648184d5d0 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -5,8 +5,8 @@ use vm_core::{ }; use crate::{ - ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}, - LibraryPath, Version, + ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName, ResolvedProcedure}, + LibraryPath, SourceSpan, Span, Version, }; // TODOP: Refactor `FullyQualifiedProcedureName` instead, and use `Span` where needed? @@ -154,4 +154,18 @@ impl CompiledModule { pub fn procedures(&self) -> &[(ProcedureIndex, CompiledProcedure)] { &self.procedures } + + pub fn resolve(&self, name: &ProcedureName) -> Option { + self.procedures.iter().find_map(|(_, proc)| { + if proc.name() == name { + // TODOP: FIX SPAN + Some(ResolvedProcedure::MastRoot(Span::new( + SourceSpan::at(0), + proc.digest().clone(), + ))) + } else { + None + } + }) + } } From b65703143f4726408877b66eb036f3fd102fce35 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 18 Jul 2024 17:35:14 -0400 Subject: [PATCH 142/172] `ModuleGraph::topological_sort_from_root`: only include AST procedures --- assembly/src/assembler/module_graph/mod.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index ed7522b888..c17131643c 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -57,6 +57,14 @@ impl<'a> WrapperProcedure<'a> { WrapperProcedure::Compiled(_) => panic!("expected AST procedure, but was compiled"), } } + + pub fn is_compiled(&self) -> bool { + matches!(self, Self::Compiled(_)) + } + + pub fn is_ast(&self) -> bool { + matches!(self, Self::Ast(_)) + } } // TODOP: Rename @@ -531,7 +539,14 @@ impl ModuleGraph { &self, caller: GlobalProcedureIndex, ) -> Result, CycleError> { - self.callgraph.toposort_caller(caller) + // TODOP: Fix Vec -> into_iter() -> collect + // TODOP: Should we change name/args? + Ok(self + .callgraph + .toposort_caller(caller)? + .into_iter() + .filter(|&gid| self.get_procedure_unsafe(gid).is_ast()) + .collect()) } /// Fetch a [Module] by [ModuleIndex] From c0fb87ae837f6d501ca7da4fe22a535eff40ed03 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 10:07:33 -0400 Subject: [PATCH 143/172] `Assembler::resolve_target()`: look for digest in module graph first --- assembly/src/assembler/mod.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index f79ff823aa..8d29f28355 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -962,11 +962,15 @@ impl Assembler { let resolved = self.module_graph.resolve_target(&caller, target)?; match resolved { ResolvedTarget::Phantom(digest) | ResolvedTarget::Cached { digest, .. } => Ok(digest), - ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => Ok(self - .procedure_cache - .get(gid) - .map(|p| p.mast_root(mast_forest)) - .expect("expected callee to have been compiled already")), + ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => Ok( + // first look in the module graph, and fallback to the procedure cache + self.module_graph.get_mast_root(gid).copied().unwrap_or_else(|| { + self.procedure_cache + .get(gid) + .map(|p| p.mast_root(mast_forest)) + .expect("expected callee to have been compiled already") + }), + ), } } } From d3b856cd6fd8844cf3ad4ef970086da5fff3ef0f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 10:27:10 -0400 Subject: [PATCH 144/172] remove `AssemblyContext::allow_phantom_calls` flag --- assembly/src/assembler/context.rs | 47 +------------------ .../src/assembler/instruction/procedures.rs | 16 +++---- assembly/src/assembler/mod.rs | 3 +- assembly/src/assembler/module_graph/mod.rs | 4 -- assembly/src/compiled_library.rs | 2 +- assembly/src/errors.rs | 9 ---- assembly/src/tests.rs | 23 +-------- 7 files changed, 11 insertions(+), 93 deletions(-) diff --git a/assembly/src/assembler/context.rs b/assembly/src/assembler/context.rs index 1cc60d225d..8947fa8487 100644 --- a/assembly/src/assembler/context.rs +++ b/assembly/src/assembler/context.rs @@ -4,7 +4,7 @@ use super::{procedure::CallSet, ArtifactKind, GlobalProcedureIndex, Procedure}; use crate::{ ast::{FullyQualifiedProcedureName, Visibility}, diagnostics::SourceFile, - AssemblyError, LibraryPath, RpoDigest, SourceSpan, Span, Spanned, + AssemblyError, LibraryPath, RpoDigest, SourceSpan, Spanned, }; use vm_core::mast::{MastForest, MastNodeId}; @@ -26,11 +26,6 @@ pub struct AssemblyContext { kind: ArtifactKind, /// When true, promote warning diagnostics to errors warnings_as_errors: bool, - /// When true, this permits calls to refer to procedures which are not locally available, - /// as long as they are referenced by MAST root, and not by name. As long as the MAST for those - /// roots is present when the code is executed, this works fine. However, if the VM tries to - /// execute a program with such calls, and the MAST is not available, the program will trap. - allow_phantom_calls: bool, /// The current procedure being compiled current_procedure: Option, /// The fully-qualified module path which should be compiled. @@ -98,21 +93,6 @@ impl AssemblyContext { self.current_procedure.as_mut().expect("missing current procedure context") } - /// Enables phantom calls when compiling with this context. - /// - /// # Panics - /// - /// This function will panic if you attempt to enable phantom calls for a kernel-mode context, - /// as non-local procedure calls are not allowed in kernel modules. - pub fn with_phantom_calls(mut self, allow_phantom_calls: bool) -> Self { - assert!( - !self.is_kernel() || !allow_phantom_calls, - "kernel modules cannot have phantom calls enabled" - ); - self.allow_phantom_calls = allow_phantom_calls; - self - } - /// Returns true if this context is used for compiling a kernel. pub fn is_kernel(&self) -> bool { matches!(self.kind, ArtifactKind::Kernel) @@ -134,31 +114,6 @@ impl AssemblyContext { self.warnings_as_errors } - /// Registers a "phantom" call to the procedure with the specified MAST root. - /// - /// A phantom call indicates that code for the procedure is not available. Executing a phantom - /// call will result in a runtime error. However, the VM may be able to execute a program with - /// phantom calls as long as the branches containing them are not taken. - /// - /// # Errors - /// Returns an error if phantom calls are not allowed in this assembly context. - pub fn register_phantom_call( - &mut self, - mast_root: Span, - ) -> Result<(), AssemblyError> { - if !self.allow_phantom_calls { - let source_file = self.unwrap_current_procedure().source_file().clone(); - let (span, digest) = mast_root.into_parts(); - Err(AssemblyError::PhantomCallsNotAllowed { - span, - source_file, - digest, - }) - } else { - Ok(()) - } - } - /// Registers a call to an externally-defined procedure which we have previously compiled. /// /// The call set of the callee is added to the call set of the procedure we are currently diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 8412094984..1fd575f04e 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -2,7 +2,7 @@ use super::{Assembler, AssemblyContext, BasicBlockBuilder, Operation}; use crate::{ assembler::mast_forest_builder::MastForestBuilder, ast::{InvocationTarget, InvokeKind}, - AssemblyError, RpoDigest, SourceSpan, Span, Spanned, + AssemblyError, RpoDigest, SourceSpan, Spanned, }; use smallvec::SmallVec; @@ -35,7 +35,7 @@ impl Assembler { let current_source_file = context.unwrap_current_procedure().source_file(); // If the procedure is cached, register the call to ensure the callset - // is updated correctly. Otherwise, register a phantom call. + // is updated correctly. match cache.get_by_mast_root(&mast_root) { Some(proc) if matches!(kind, InvokeKind::SysCall) => { // Verify if this is a syscall, that the callee is a kernel procedure @@ -81,7 +81,7 @@ impl Assembler { callee: mast_root, }); } - None => context.register_phantom_call(Span::new(span, mast_root))?, + None => (), } let mast_root_node_id = { @@ -161,14 +161,12 @@ impl Assembler { span_builder: &mut BasicBlockBuilder, mast_forest: &MastForest, ) -> Result<(), AssemblyError> { - let span = callee.span(); let digest = self.resolve_target(InvokeKind::Exec, callee, context, mast_forest)?; - self.procref_mast_root(span, digest, context, span_builder, mast_forest) + self.procref_mast_root(digest, context, span_builder, mast_forest) } fn procref_mast_root( &self, - span: SourceSpan, mast_root: RpoDigest, context: &mut AssemblyContext, span_builder: &mut BasicBlockBuilder, @@ -176,10 +174,8 @@ impl Assembler { ) -> Result<(), AssemblyError> { // Add the root to the callset to be able to use dynamic instructions // with the referenced procedure later - let cache = &self.procedure_cache; - match cache.get_by_mast_root(&mast_root) { - Some(proc) => context.register_external_call(&proc, false, mast_forest)?, - None => context.register_phantom_call(Span::new(span, mast_root))?, + if let Some(proc) = self.procedure_cache.get_by_mast_root(&mast_root) { + context.register_external_call(&proc, false, mast_forest)?; } // Create an array with `Push` operations containing root elements diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 8d29f28355..e2a143be36 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -343,7 +343,8 @@ impl Assembler { /// Compilation/Assembly impl Assembler { // TODOP: Document - // TODOP: Check that `CompiledLibraryMetadata` is consistent with `modules` (e.g. modules path indeed start with library name) + // TODOP: Check that `CompiledLibraryMetadata` is consistent with `modules` (e.g. modules path + // indeed start with library name) pub fn assemble_library( mut self, modules: impl Iterator, diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index c17131643c..e76e6256a1 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -58,10 +58,6 @@ impl<'a> WrapperProcedure<'a> { } } - pub fn is_compiled(&self) -> bool { - matches!(self, Self::Compiled(_)) - } - pub fn is_ast(&self) -> bool { matches!(self, Self::Ast(_)) } diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 648184d5d0..4ac6407c62 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -161,7 +161,7 @@ impl CompiledModule { // TODOP: FIX SPAN Some(ResolvedProcedure::MastRoot(Span::new( SourceSpan::at(0), - proc.digest().clone(), + *proc.digest(), ))) } else { None diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 5b51048292..b4a13271a8 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -77,15 +77,6 @@ pub enum AssemblyError { #[source_code] source_file: Option>, }, - #[error("cannot call phantom procedure: phantom calls are disabled")] - #[diagnostic(help("mast root is {digest}"))] - PhantomCallsNotAllowed { - #[label("the procedure referenced here is not available")] - span: SourceSpan, - #[source_code] - source_file: Option>, - digest: RpoDigest, - }, #[error("invalid syscall: '{callee}' is not an exported kernel procedure")] #[diagnostic()] InvalidSysCallTarget { diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 61fda8c32e..f88dcaf6b7 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -1486,8 +1486,6 @@ fn program_with_invalid_rpo_digest_call() { ); } -/// Phantom calls are currently not implemented. Re-enable this test once they are implemented. -#[ignore] #[test] fn program_with_phantom_mast_call() -> TestResult { let mut context = TestContext::default(); @@ -1496,27 +1494,8 @@ fn program_with_phantom_mast_call() -> TestResult { ); let ast = context.parse_program(source)?; - // phantom calls not allowed - let assembler = Assembler::default().with_debug_mode(true); - - let mut context = AssemblyContext::for_program(ast.path()).with_phantom_calls(false); - let err = assembler - .assemble_in_context(ast.clone(), &mut context) - .expect_err("expected compilation to fail with phantom calls"); - assert_diagnostic_lines!( - err, - "cannot call phantom procedure: phantom calls are disabled", - regex!(r#",-\[test[\d]+:1:12\]"#), - "1 | begin call.0xc2545da99d3a1f3f38d957c7893c44d78998d8ea8b11aba7e22c8c2b2a213dae end", - " : ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^", - " : `-- the procedure referenced here is not available", - " `----", - " help: mast root is 0xc2545da99d3a1f3f38d957c7893c44d78998d8ea8b11aba7e22c8c2b2a213dae" - ); - - // phantom calls allowed let assembler = Assembler::default().with_debug_mode(true); - let mut context = AssemblyContext::for_program(ast.path()).with_phantom_calls(true); + let mut context = AssemblyContext::for_program(ast.path()); assembler.assemble_in_context(ast, &mut context)?; Ok(()) } From 096f2897e9aabce6a8a4ffd52104f53a82ed3bb9 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 11:17:17 -0400 Subject: [PATCH 145/172] remove TODOP --- assembly/src/assembler/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index e2a143be36..db078ec629 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -228,7 +228,7 @@ impl Assembler { Ok(()) } - /// TODOP: Add `vendored` flag and document + /// TODOP: documentation pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { let module_indexes: Vec = library .into_compiled_modules() @@ -236,7 +236,6 @@ impl Assembler { .map(|module| self.module_graph.add_compiled_module(module)) .collect::>()?; - // TODOP: Try to remove this recompute() self.module_graph.recompute()?; // Register all procedures as roots From f0efb06ccddcc9f76ff84b2b618d6acfd2e8edc7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 11:38:00 -0400 Subject: [PATCH 146/172] `ResolvedProcedure` is no longer `Spanned` --- .../src/assembler/module_graph/name_resolver.rs | 2 +- assembly/src/ast/module.rs | 12 +++++------- assembly/src/ast/procedure/resolver.rs | 14 ++------------ assembly/src/compiled_library.rs | 8 ++------ assembly/src/sema/errors.rs | 5 +++++ assembly/src/tests.rs | 8 +------- 6 files changed, 16 insertions(+), 33 deletions(-) diff --git a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index c5f1ee1a9e..f247ba47a3 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -189,7 +189,7 @@ impl<'a> NameResolver<'a> { Some(ResolvedProcedure::MastRoot(ref digest)) => { match self.graph.get_procedure_index_by_digest(digest) { Some(gid) => Ok(ResolvedTarget::Exact { gid }), - None => Ok(ResolvedTarget::Phantom(**digest)), + None => Ok(ResolvedTarget::Phantom(*digest)), } } None => Err(AssemblyError::Failed { diff --git a/assembly/src/ast/module.rs b/assembly/src/ast/module.rs index 31dc0886c3..4e6cb0826b 100644 --- a/assembly/src/ast/module.rs +++ b/assembly/src/ast/module.rs @@ -208,11 +208,9 @@ impl Module { span: export.span(), }); } - if let Some(prev) = self.resolve(export.name()) { - let prev_span = prev.span(); - Err(SemanticAnalysisError::SymbolConflict { - span: export.span(), - prev_span, + if let Some(_prev) = self.resolve(export.name()) { + Err(SemanticAnalysisError::ProcedureNameConflict { + name: export.name().clone(), }) } else { self.procedures.push(export); @@ -420,7 +418,7 @@ impl Module { Some(ResolvedProcedure::Local(Span::new(proc.name().span(), index))) } Export::Alias(ref alias) => match alias.target() { - AliasTarget::MastRoot(digest) => Some(ResolvedProcedure::MastRoot(*digest)), + AliasTarget::MastRoot(digest) => Some(ResolvedProcedure::MastRoot(**digest)), AliasTarget::Path(path) => Some(ResolvedProcedure::External(path.clone())), }, } @@ -435,7 +433,7 @@ impl Module { ), Export::Alias(ref p) => { let target = match p.target { - AliasTarget::MastRoot(ref digest) => ResolvedProcedure::MastRoot(*digest), + AliasTarget::MastRoot(ref digest) => ResolvedProcedure::MastRoot(**digest), AliasTarget::Path(ref path) => ResolvedProcedure::External(path.clone()), }; (p.name().clone(), target) diff --git a/assembly/src/ast/procedure/resolver.rs b/assembly/src/ast/procedure/resolver.rs index 05c71eff07..89c1978eac 100644 --- a/assembly/src/ast/procedure/resolver.rs +++ b/assembly/src/ast/procedure/resolver.rs @@ -1,5 +1,5 @@ use super::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}; -use crate::{ast::Ident, LibraryPath, RpoDigest, SourceSpan, Span, Spanned}; +use crate::{ast::Ident, LibraryPath, RpoDigest, Span}; use alloc::{collections::BTreeMap, vec::Vec}; // RESOLVED PROCEDURE @@ -13,17 +13,7 @@ pub enum ResolvedProcedure { /// The name was resolved to a procedure exported from another module External(FullyQualifiedProcedureName), /// The name was resolved to a procedure with a known MAST root - MastRoot(Span), -} - -impl Spanned for ResolvedProcedure { - fn span(&self) -> SourceSpan { - match self { - Self::Local(ref spanned) => spanned.span(), - Self::External(ref spanned) => spanned.span(), - Self::MastRoot(ref spanned) => spanned.span(), - } - } + MastRoot(RpoDigest), } // LOCAL NAME RESOLVER diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 4ac6407c62..08449c37d8 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -6,7 +6,7 @@ use vm_core::{ use crate::{ ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName, ResolvedProcedure}, - LibraryPath, SourceSpan, Span, Version, + LibraryPath, Version, }; // TODOP: Refactor `FullyQualifiedProcedureName` instead, and use `Span` where needed? @@ -158,11 +158,7 @@ impl CompiledModule { pub fn resolve(&self, name: &ProcedureName) -> Option { self.procedures.iter().find_map(|(_, proc)| { if proc.name() == name { - // TODOP: FIX SPAN - Some(ResolvedProcedure::MastRoot(Span::new( - SourceSpan::at(0), - *proc.digest(), - ))) + Some(ResolvedProcedure::MastRoot(*proc.digest())) } else { None } diff --git a/assembly/src/sema/errors.rs b/assembly/src/sema/errors.rs index 59b1852f4d..de6447abbf 100644 --- a/assembly/src/sema/errors.rs +++ b/assembly/src/sema/errors.rs @@ -5,6 +5,8 @@ use crate::{ use alloc::{sync::Arc, vec::Vec}; use core::fmt; +use super::ProcedureName; + /// The high-level error type for all semantic analysis errors. /// /// This rolls up multiple errors into a single one, and as such, can emit many @@ -76,6 +78,9 @@ pub enum SemanticAnalysisError { #[label("previously defined here")] prev_span: SourceSpan, }, + #[error("procedure name conflict: found duplicate definitions of '{name}'")] + #[diagnostic()] + ProcedureNameConflict { name: ProcedureName }, #[error("symbol undefined: no such name found in scope")] #[diagnostic(help("are you missing an import?"))] SymbolUndefined { diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index f88dcaf6b7..4b53bf8149 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -2199,13 +2199,7 @@ fn invalid_proc_duplicate_procedure_name() { source, "syntax error", "help: see emitted diagnostics for details", - "symbol conflict: found duplicate definitions of the same name", - regex!(r#",-\[test[\d]+:1:6\]"#), - "1 | proc.foo add mul end proc.foo push.3 end begin push.1 end", - " : ^|^ ^^^^^^^^^|^^^^^^^^^", - " : | `-- conflict occurs here", - " : `-- previously defined here", - " `----" + "procedure name conflict: found duplicate definitions of 'foo'" ); } From cca6fe2a2955d6b1ba135dc2c9f93f67d4cb2d7b Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 11:39:50 -0400 Subject: [PATCH 147/172] improve test --- assembly/src/tests.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 4b53bf8149..10f56dd97a 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -2423,11 +2423,17 @@ fn test_compiled_library() { use.mylib::mod1 use.mylib::mod2 + proc.foo + push.1 + drop + end + begin - exec.mod1::foo - exec.mod1::bar - exec.mod2::foo - exec.mod2::bar + exec.mod1::foo + exec.mod1::bar + exec.mod2::foo + exec.mod2::bar + exec.foo end "; From c83ba0bfc16d3714a83eccdab84f40dd9d38a71b Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 12:19:52 -0400 Subject: [PATCH 148/172] remove TODOP --- assembly/src/compiled_library.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 08449c37d8..36a847ba33 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -9,7 +9,10 @@ use crate::{ LibraryPath, Version, }; -// TODOP: Refactor `FullyQualifiedProcedureName` instead, and use `Span` where needed? +/// A procedure's name, along with its module path. +/// +/// The only difference between this type and [`FullyQualifiedProcedureName`] is that +/// [`CompiledFullyQualifiedProcedureName`] doesn't have a [`crate::SourceSpan`]. pub struct CompiledFullyQualifiedProcedureName { /// The module path for this procedure. pub module_path: LibraryPath, From db5a41597f3654efe7fb9ee8e72999c3c6d17cc7 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 13:05:21 -0400 Subject: [PATCH 149/172] `CompiledProcedure` -> `ProcedureInfo` --- assembly/src/assembler/mod.rs | 10 +++--- assembly/src/assembler/module_graph/debug.rs | 2 +- assembly/src/assembler/module_graph/mod.rs | 8 ++--- assembly/src/compiled_library.rs | 32 +++++++------------- 4 files changed, 21 insertions(+), 31 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index db078ec629..382b6f6b2b 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -5,7 +5,7 @@ use crate::{ }, compiled_library::{ CompiledFullyQualifiedProcedureName, CompiledLibrary, CompiledLibraryMetadata, - CompiledProcedure, + ProcedureInfo, }, diagnostics::{tracing::instrument, Report}, sema::SemanticAnalysisError, @@ -228,7 +228,7 @@ impl Assembler { Ok(()) } - /// TODOP: documentation + /// Adds the compiled library to provide modules for the compilation. pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { let module_indexes: Vec = library .into_compiled_modules() @@ -602,10 +602,10 @@ impl Assembler { module_index: ModuleIndex, mast_forest: &MastForest, // TODOP: Return iterator instead? - ) -> Result, Report> { + ) -> Result, Report> { assert!(self.module_graph.contains_module(module_index), "invalid module index"); - let exports: Vec = match &self.module_graph[module_index] { + let exports: Vec = match &self.module_graph[module_index] { module_graph::WrapperModule::Ast(module) => { let mut exports = Vec::new(); for (index, procedure) in module.procedures().enumerate() { @@ -656,7 +656,7 @@ impl Assembler { } }); - let compiled_proc = CompiledProcedure { + let compiled_proc = ProcedureInfo { name: proc.name().clone(), digest: mast_forest[proc.body_node_id()].digest(), }; diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index 4596cf0e3b..455b619ce0 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -88,7 +88,7 @@ impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { module: module_index, index: *proc_index, path: m.path(), - proc_name: proc.name(), + proc_name: &proc.name, ty: GraphNodeType::Compiled, }) .collect::>(), diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index e76e6256a1..01d64b24af 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -27,7 +27,7 @@ use self::{ rewrites::ModuleRewriter, }; use super::{GlobalProcedureIndex, ModuleIndex}; -use crate::compiled_library::{CompiledModule, CompiledProcedure}; +use crate::compiled_library::{CompiledModule, ProcedureInfo}; use crate::{ ast::{ Export, FullyQualifiedProcedureName, InvocationTarget, Module, ProcedureIndex, @@ -40,14 +40,14 @@ use crate::{ // TODOP: Better doc pub enum WrapperProcedure<'a> { Ast(&'a Export), - Compiled(&'a CompiledProcedure), + Compiled(&'a ProcedureInfo), } impl<'a> WrapperProcedure<'a> { pub fn name(&self) -> &ProcedureName { match self { WrapperProcedure::Ast(p) => p.name(), - WrapperProcedure::Compiled(p) => p.name(), + WrapperProcedure::Compiled(p) => &p.name, } } @@ -759,7 +759,7 @@ impl ModuleGraph { break module .procedures() .iter() - .find(|(_index, procedure)| procedure.name() == &name.name) + .find(|(_index, procedure)| procedure.name == name.name) .map(|(index, _)| GlobalProcedureIndex { module: module_index, index: *index, diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 36a847ba33..44139603a1 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -10,7 +10,7 @@ use crate::{ }; /// A procedure's name, along with its module path. -/// +/// /// The only difference between this type and [`FullyQualifiedProcedureName`] is that /// [`CompiledFullyQualifiedProcedureName`] doesn't have a [`crate::SourceSpan`]. pub struct CompiledFullyQualifiedProcedureName { @@ -35,23 +35,13 @@ impl From for CompiledFullyQualifiedProcedureName { } } +/// Stores the name and digest of a procedure. #[derive(Debug, Clone)] -pub struct CompiledProcedure { +pub struct ProcedureInfo { pub name: ProcedureName, pub digest: RpoDigest, } -// TODOP: Remove methods in favor of pub fields? -impl CompiledProcedure { - pub fn name(&self) -> &ProcedureName { - &self.name - } - - pub fn digest(&self) -> &RpoDigest { - &self.digest - } -} - // TODOP: Move into `miden-core` along with `LibraryPath` pub struct CompiledLibrary { mast_forest: MastForest, @@ -99,7 +89,7 @@ impl CompiledLibrary { let proc_node_id = self.mast_forest.procedure_roots()[proc_index]; let proc_digest = self.mast_forest[proc_node_id].digest(); - compiled_module.add_procedure(CompiledProcedure { + compiled_module.add_procedure(ProcedureInfo { name: proc_name.name.clone(), digest: proc_digest, }) @@ -107,7 +97,7 @@ impl CompiledLibrary { .or_insert_with(|| { let proc_node_id = self.mast_forest.procedure_roots()[proc_index]; let proc_digest = self.mast_forest[proc_node_id].digest(); - let proc = CompiledProcedure { + let proc = ProcedureInfo { name: proc_name.name, digest: proc_digest, }; @@ -129,11 +119,11 @@ pub struct CompiledLibraryMetadata { #[derive(Debug, Clone)] pub struct CompiledModule { path: LibraryPath, - procedures: Vec<(ProcedureIndex, CompiledProcedure)>, + procedures: Vec<(ProcedureIndex, ProcedureInfo)>, } impl CompiledModule { - pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { + pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { Self { path, procedures: procedures @@ -143,7 +133,7 @@ impl CompiledModule { } } - pub fn add_procedure(&mut self, procedure: CompiledProcedure) { + pub fn add_procedure(&mut self, procedure: ProcedureInfo) { let index = ProcedureIndex::new(self.procedures.len()); self.procedures.push((index, procedure)); } @@ -154,14 +144,14 @@ impl CompiledModule { // TODOP: Store as `CompiledProcedure`, and add a method `iter()` that iterates with // `ProcedureIndex` - pub fn procedures(&self) -> &[(ProcedureIndex, CompiledProcedure)] { + pub fn procedures(&self) -> &[(ProcedureIndex, ProcedureInfo)] { &self.procedures } pub fn resolve(&self, name: &ProcedureName) -> Option { self.procedures.iter().find_map(|(_, proc)| { - if proc.name() == name { - Some(ResolvedProcedure::MastRoot(*proc.digest())) + if &proc.name == name { + Some(ResolvedProcedure::MastRoot(proc.digest)) } else { None } From d0c17264572fdb0a47757509ee0b8d50bffef1d1 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 13:27:16 -0400 Subject: [PATCH 150/172] Document `CompiledLibrary` --- assembly/src/assembler/mod.rs | 3 +-- assembly/src/compiled_library.rs | 29 ++++++++++++++++++++--------- assembly/src/errors.rs | 10 ++++++++++ assembly/src/lib.rs | 2 +- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 382b6f6b2b..9d56c6f8ed 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -232,7 +232,6 @@ impl Assembler { pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { let module_indexes: Vec = library .into_compiled_modules() - .into_iter() .map(|module| self.module_graph.add_compiled_module(module)) .collect::>()?; @@ -388,7 +387,7 @@ impl Assembler { exports }; - Ok(CompiledLibrary::new(mast_forest_builder.build(), exports, metadata)) + Ok(CompiledLibrary::new(mast_forest_builder.build(), exports, metadata)?) } /// Compiles the provided module into a [`Program`]. The resulting program can be executed on diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 44139603a1..8f4ad40853 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -6,7 +6,7 @@ use vm_core::{ use crate::{ ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName, ResolvedProcedure}, - LibraryPath, Version, + CompiledLibraryError, LibraryPath, Version, }; /// A procedure's name, along with its module path. @@ -42,44 +42,55 @@ pub struct ProcedureInfo { pub digest: RpoDigest, } -// TODOP: Move into `miden-core` along with `LibraryPath` +/// Represents a library where all modules modules were compiled into a [`MastForest`]. pub struct CompiledLibrary { mast_forest: MastForest, - // a path for every `root` in the associated [MastForest] + // a path for every `root` in the associated MAST forest exports: Vec, metadata: CompiledLibraryMetadata, } /// Constructors impl CompiledLibrary { - // TODOP: Add validation that num roots = num exports + /// Constructs a new [`CompiledLibrary`]. pub fn new( mast_forest: MastForest, exports: Vec, metadata: CompiledLibraryMetadata, - ) -> Self { - Self { + ) -> Result { + if mast_forest.procedure_roots().len() != exports.len() { + return Err(CompiledLibraryError::InvalidExports { + exports_len: exports.len(), + roots_len: mast_forest.procedure_roots().len(), + }); + } + + Ok(Self { mast_forest, exports, metadata, - } + }) } } impl CompiledLibrary { + /// Returns the inner [`MastForest`]. pub fn mast_forest(&self) -> &MastForest { &self.mast_forest } + /// Returns the fully qualified name of all procedures exported by the library. pub fn exports(&self) -> &[CompiledFullyQualifiedProcedureName] { &self.exports } + /// Returns the library metadata. pub fn metadata(&self) -> &CompiledLibraryMetadata { &self.metadata } - pub fn into_compiled_modules(self) -> Vec { + /// Returns an iterator over the compiled modules of the library. + pub fn into_compiled_modules(self) -> impl Iterator { let mut modules_by_path: BTreeMap = BTreeMap::new(); for (proc_index, proc_name) in self.exports.into_iter().enumerate() { @@ -106,7 +117,7 @@ impl CompiledLibrary { }); } - modules_by_path.into_values().collect() + modules_by_path.into_values() } } diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index b4a13271a8..dc8134d4d6 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -132,3 +132,13 @@ impl From for AssemblyError { Self::Other(RelatedError::new(report)) } } + +#[derive(Debug, thiserror::Error, Diagnostic)] +pub enum CompiledLibraryError { + #[error("Invalid exports: MAST forest has {roots_len} procedure roots, but exports have {exports_len}")] + #[diagnostic()] + InvalidExports { + exports_len: usize, + roots_len: usize, + }, +} diff --git a/assembly/src/lib.rs b/assembly/src/lib.rs index f7ddc13751..7743c9be91 100644 --- a/assembly/src/lib.rs +++ b/assembly/src/lib.rs @@ -34,7 +34,7 @@ mod tests; pub use self::assembler::{ArtifactKind, Assembler, AssemblyContext}; pub use self::compile::{Compile, Options as CompileOptions}; -pub use self::errors::AssemblyError; +pub use self::errors::{AssemblyError, CompiledLibraryError}; pub use self::library::{ Library, LibraryError, LibraryNamespace, LibraryPath, MaslLibrary, PathError, Version, }; From f5c81541c42ef9690246a93e6f0835a373abcc11 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 13:35:18 -0400 Subject: [PATCH 151/172] Rename `CompiledModule` -> `ModuleInfo` --- assembly/src/assembler/mod.rs | 8 ++-- assembly/src/assembler/module_graph/debug.rs | 4 +- assembly/src/assembler/module_graph/mod.rs | 49 ++++++++++---------- assembly/src/compiled_library.rs | 14 +++--- 4 files changed, 38 insertions(+), 37 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 9d56c6f8ed..27dda82467 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -231,8 +231,8 @@ impl Assembler { /// Adds the compiled library to provide modules for the compilation. pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { let module_indexes: Vec = library - .into_compiled_modules() - .map(|module| self.module_graph.add_compiled_module(module)) + .into_module_infos() + .map(|module| self.module_graph.add_module_info(module)) .collect::>()?; self.module_graph.recompute()?; @@ -240,7 +240,7 @@ impl Assembler { // Register all procedures as roots for module_index in module_indexes { for (proc_index, proc) in - self.module_graph[module_index].unwrap_compiled().clone().procedures() + self.module_graph[module_index].unwrap_info().clone().procedures() { let gid = GlobalProcedureIndex { module: module_index, @@ -665,7 +665,7 @@ impl Assembler { exports } - module_graph::WrapperModule::Exports(module) => { + module_graph::WrapperModule::Info(module) => { module.procedures().iter().map(|(_idx, proc)| proc).cloned().collect() } }; diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index 455b619ce0..c69e450869 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -35,7 +35,7 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { } }) .collect::>(), - WrapperModule::Exports(m) => m + WrapperModule::Info(m) => m .procedures() .iter() .map(|(proc_index, _proc)| { @@ -81,7 +81,7 @@ impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { } }) .collect::>(), - WrapperModule::Exports(m) => m + WrapperModule::Info(m) => m .procedures() .iter() .map(|(proc_index, proc)| DisplayModuleGraphNode { diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 01d64b24af..5dc096fd31 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -27,7 +27,7 @@ use self::{ rewrites::ModuleRewriter, }; use super::{GlobalProcedureIndex, ModuleIndex}; -use crate::compiled_library::{CompiledModule, ProcedureInfo}; +use crate::compiled_library::{ModuleInfo, ProcedureInfo}; use crate::{ ast::{ Export, FullyQualifiedProcedureName, InvocationTarget, Module, ProcedureIndex, @@ -67,32 +67,32 @@ impl<'a> WrapperProcedure<'a> { #[derive(Clone)] pub enum WrapperModule { Ast(Arc), - Exports(CompiledModule), + Info(ModuleInfo), } impl WrapperModule { pub fn path(&self) -> &LibraryPath { match self { WrapperModule::Ast(m) => m.path(), - WrapperModule::Exports(m) => m.path(), + WrapperModule::Info(m) => m.path(), } } pub fn unwrap_ast(&self) -> &Arc { match self { WrapperModule::Ast(module) => module, - WrapperModule::Exports(_) => { + WrapperModule::Info(_) => { panic!("expected module to be in AST representation, but was compiled") } } } - pub fn unwrap_compiled(&self) -> &CompiledModule { + pub fn unwrap_info(&self) -> &ModuleInfo { match self { WrapperModule::Ast(_) => { panic!("expected module to be compiled, but was in AST representation") } - WrapperModule::Exports(module) => module, + WrapperModule::Info(module) => module, } } @@ -100,7 +100,7 @@ impl WrapperModule { pub fn resolve(&self, name: &ProcedureName) -> Option { match self { WrapperModule::Ast(module) => module.resolve(name), - WrapperModule::Exports(module) => module.resolve(name), + WrapperModule::Info(module) => module.resolve(name), } } } @@ -109,14 +109,14 @@ impl WrapperModule { #[derive(Clone)] pub enum PendingWrapperModule { Ast(Box), - Exports(CompiledModule), + Info(ModuleInfo), } impl PendingWrapperModule { pub fn path(&self) -> &LibraryPath { match self { PendingWrapperModule::Ast(m) => m.path(), - PendingWrapperModule::Exports(m) => m.path(), + PendingWrapperModule::Info(m) => m.path(), } } } @@ -180,11 +180,11 @@ impl ModuleGraph { self.add_module(PendingWrapperModule::Ast(module)) } - /// Add compiled `module` to the graph. + /// Add the [`ModuleInfo`] to the graph. /// - /// NOTE: This operation only adds a module to the graph, but does not perform the - /// important analysis needed for compilation, you must call [recompute] once all modules - /// are added to ensure the analysis results reflect the current version of the graph. + /// NOTE: This operation only adds a module to the graph, but does not perform the important + /// analysis needed for compilation, you must call [`Self::recompute`] once all modules are + /// added to ensure the analysis results reflect the current version of the graph. /// /// # Errors /// @@ -197,11 +197,11 @@ impl ModuleGraph { /// /// This function will panic if the number of modules exceeds the maximum representable /// [ModuleIndex] value, `u16::MAX`. - pub fn add_compiled_module( + pub fn add_module_info( &mut self, - module: CompiledModule, + module_info: ModuleInfo, ) -> Result { - self.add_module(PendingWrapperModule::Exports(module)) + self.add_module(PendingWrapperModule::Info(module_info)) } fn add_module(&mut self, module: PendingWrapperModule) -> Result { @@ -368,6 +368,7 @@ impl ModuleGraph { for (pending_index, pending_module) in pending.iter().enumerate() { let module_id = ModuleIndex::new(high_water_mark + pending_index); + // TODOP: Refactor everywhere that we added big `match` statements // Apply module to call graph match pending_module { PendingWrapperModule::Ast(pending_module) => { @@ -386,7 +387,7 @@ impl ModuleGraph { } } } - PendingWrapperModule::Exports(pending_module) => { + PendingWrapperModule::Info(pending_module) => { for (procedure_id, _procedure) in pending_module.procedures().iter() { let global_id = GlobalProcedureIndex { module: module_id, @@ -446,8 +447,8 @@ impl ModuleGraph { finished.push(WrapperModule::Ast(Arc::new(*ast_module))) } - PendingWrapperModule::Exports(module) => { - finished.push(WrapperModule::Exports(module)); + PendingWrapperModule::Info(module) => { + finished.push(WrapperModule::Info(module)); } } } @@ -561,7 +562,7 @@ impl ModuleGraph { pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => m.get(id.index).map(WrapperProcedure::Ast), - WrapperModule::Exports(m) => m + WrapperModule::Info(m) => m .procedures() .get(id.index.as_usize()) .map(|(_idx, proc)| WrapperProcedure::Compiled(proc)), @@ -575,7 +576,7 @@ impl ModuleGraph { pub fn get_procedure_unsafe(&self, id: GlobalProcedureIndex) -> WrapperProcedure { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => WrapperProcedure::Ast(&m[id.index]), - WrapperModule::Exports(m) => { + WrapperModule::Info(m) => { WrapperProcedure::Compiled(&m.procedures()[id.index.as_usize()].1) } } @@ -630,13 +631,13 @@ impl ModuleGraph { let prev_proc = { match &self.modules[prev_id.module.as_usize()] { WrapperModule::Ast(module) => Some(&module[prev_id.index]), - WrapperModule::Exports(_) => None, + WrapperModule::Info(_) => None, } }; let current_proc = { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(module) => Some(&module[id.index]), - WrapperModule::Exports(_) => None, + WrapperModule::Info(_) => None, } }; @@ -755,7 +756,7 @@ impl ModuleGraph { } } } - WrapperModule::Exports(module) => { + WrapperModule::Info(module) => { break module .procedures() .iter() diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 8f4ad40853..349f3850ae 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -89,9 +89,9 @@ impl CompiledLibrary { &self.metadata } - /// Returns an iterator over the compiled modules of the library. - pub fn into_compiled_modules(self) -> impl Iterator { - let mut modules_by_path: BTreeMap = BTreeMap::new(); + /// Returns an iterator over the module infos of the library. + pub fn into_module_infos(self) -> impl Iterator { + let mut modules_by_path: BTreeMap = BTreeMap::new(); for (proc_index, proc_name) in self.exports.into_iter().enumerate() { modules_by_path @@ -113,7 +113,7 @@ impl CompiledLibrary { digest: proc_digest, }; - CompiledModule::new(proc_name.module_path, core::iter::once(proc)) + ModuleInfo::new(proc_name.module_path, core::iter::once(proc)) }); } @@ -126,14 +126,14 @@ pub struct CompiledLibraryMetadata { pub version: Version, } -// TODOP: Rename (?) +/// Stores a module's path, as well as information about all exported procedures. #[derive(Debug, Clone)] -pub struct CompiledModule { +pub struct ModuleInfo { path: LibraryPath, procedures: Vec<(ProcedureIndex, ProcedureInfo)>, } -impl CompiledModule { +impl ModuleInfo { pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { Self { path, From 3456d9c034a5f00f08f054cada3250f811d73419 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 13:49:58 -0400 Subject: [PATCH 152/172] Refactor `ModuleInfo` --- assembly/src/assembler/mod.rs | 6 +-- assembly/src/assembler/module_graph/debug.rs | 10 ++-- assembly/src/assembler/module_graph/mod.rs | 33 ++++++------ assembly/src/compiled_library.rs | 53 ++++++++++++-------- 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 27dda82467..3877326164 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -240,11 +240,11 @@ impl Assembler { // Register all procedures as roots for module_index in module_indexes { for (proc_index, proc) in - self.module_graph[module_index].unwrap_info().clone().procedures() + self.module_graph[module_index].unwrap_info().clone().procedure_infos() { let gid = GlobalProcedureIndex { module: module_index, - index: *proc_index, + index: proc_index, }; self.module_graph.register_mast_root(gid, proc.digest)?; @@ -666,7 +666,7 @@ impl Assembler { exports } module_graph::WrapperModule::Info(module) => { - module.procedures().iter().map(|(_idx, proc)| proc).cloned().collect() + module.procedure_infos().map(|(_idx, proc)| proc).cloned().collect() } }; diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index c69e450869..1eab5de877 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -36,12 +36,11 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { }) .collect::>(), WrapperModule::Info(m) => m - .procedures() - .iter() + .procedure_infos() .map(|(proc_index, _proc)| { let gid = GlobalProcedureIndex { module: ModuleIndex::new(module_index), - index: *proc_index, + index: proc_index, }; let out_edges = self.0.callgraph.out_edges(gid); @@ -82,11 +81,10 @@ impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { }) .collect::>(), WrapperModule::Info(m) => m - .procedures() - .iter() + .procedure_infos() .map(|(proc_index, proc)| DisplayModuleGraphNode { module: module_index, - index: *proc_index, + index: proc_index, path: m.path(), proc_name: &proc.name, ty: GraphNodeType::Compiled, diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 5dc096fd31..b603a22b23 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -100,7 +100,9 @@ impl WrapperModule { pub fn resolve(&self, name: &ProcedureName) -> Option { match self { WrapperModule::Ast(module) => module.resolve(name), - WrapperModule::Info(module) => module.resolve(name), + WrapperModule::Info(module) => { + module.get_proc_digest_by_name(name).map(ResolvedProcedure::MastRoot) + } } } } @@ -388,10 +390,10 @@ impl ModuleGraph { } } PendingWrapperModule::Info(pending_module) => { - for (procedure_id, _procedure) in pending_module.procedures().iter() { + for (proc_index, _procedure) in pending_module.procedure_infos() { let global_id = GlobalProcedureIndex { module: module_id, - index: *procedure_id, + index: proc_index, }; self.callgraph.get_or_insert_node(global_id); } @@ -562,10 +564,9 @@ impl ModuleGraph { pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => m.get(id.index).map(WrapperProcedure::Ast), - WrapperModule::Info(m) => m - .procedures() - .get(id.index.as_usize()) - .map(|(_idx, proc)| WrapperProcedure::Compiled(proc)), + WrapperModule::Info(m) => { + m.get_proc_info_by_index(id.index).map(WrapperProcedure::Compiled) + } } } @@ -577,7 +578,7 @@ impl ModuleGraph { match &self.modules[id.module.as_usize()] { WrapperModule::Ast(m) => WrapperProcedure::Ast(&m[id.index]), WrapperModule::Info(m) => { - WrapperProcedure::Compiled(&m.procedures()[id.index.as_usize()].1) + WrapperProcedure::Compiled(m.get_proc_info_by_index(id.index).unwrap()) } } } @@ -758,12 +759,16 @@ impl ModuleGraph { } WrapperModule::Info(module) => { break module - .procedures() - .iter() - .find(|(_index, procedure)| procedure.name == name.name) - .map(|(index, _)| GlobalProcedureIndex { - module: module_index, - index: *index, + .procedure_infos() + .find_map(|(index, procedure)| { + if procedure.name == name.name { + Some(GlobalProcedureIndex { + module: module_index, + index, + }) + } else { + None + } }) .ok_or(AssemblyError::Failed { labels: vec![RelatedLabel::error("undefined procedure") diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 349f3850ae..6022e4f925 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -5,7 +5,7 @@ use vm_core::{ }; use crate::{ - ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName, ResolvedProcedure}, + ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}, CompiledLibraryError, LibraryPath, Version, }; @@ -100,7 +100,7 @@ impl CompiledLibrary { let proc_node_id = self.mast_forest.procedure_roots()[proc_index]; let proc_digest = self.mast_forest[proc_node_id].digest(); - compiled_module.add_procedure(ProcedureInfo { + compiled_module.add_procedure_info(ProcedureInfo { name: proc_name.name.clone(), digest: proc_digest, }) @@ -113,7 +113,7 @@ impl CompiledLibrary { digest: proc_digest, }; - ModuleInfo::new(proc_name.module_path, core::iter::once(proc)) + ModuleInfo::new(proc_name.module_path, vec![proc]) }); } @@ -130,39 +130,52 @@ pub struct CompiledLibraryMetadata { #[derive(Debug, Clone)] pub struct ModuleInfo { path: LibraryPath, - procedures: Vec<(ProcedureIndex, ProcedureInfo)>, + procedure_infos: Vec, } impl ModuleInfo { - pub fn new(path: LibraryPath, procedures: impl Iterator) -> Self { + /// Constructs a new [`ModuleInfo`]. + pub fn new(path: LibraryPath, procedures: Vec) -> Self { Self { path, - procedures: procedures - .enumerate() - .map(|(idx, proc)| (ProcedureIndex::new(idx), proc)) - .collect(), + procedure_infos: procedures, } } - pub fn add_procedure(&mut self, procedure: ProcedureInfo) { - let index = ProcedureIndex::new(self.procedures.len()); - self.procedures.push((index, procedure)); + /// Adds a [`ProcedureInfo`] to the module. + pub fn add_procedure_info(&mut self, procedure: ProcedureInfo) { + self.procedure_infos.push(procedure); } + /// Returns the module's library path. pub fn path(&self) -> &LibraryPath { &self.path } - // TODOP: Store as `CompiledProcedure`, and add a method `iter()` that iterates with - // `ProcedureIndex` - pub fn procedures(&self) -> &[(ProcedureIndex, ProcedureInfo)] { - &self.procedures + /// Returns the number of procedures in the module. + pub fn num_procedures(&self) -> usize { + self.procedure_infos.len() } - pub fn resolve(&self, name: &ProcedureName) -> Option { - self.procedures.iter().find_map(|(_, proc)| { - if &proc.name == name { - Some(ResolvedProcedure::MastRoot(proc.digest)) + /// Returns an iterator over the procedure infos in the module with their corresponding + /// procedure index in the module. + pub fn procedure_infos(&self) -> impl Iterator { + self.procedure_infos + .iter() + .enumerate() + .map(|(idx, proc)| (ProcedureIndex::new(idx), proc)) + } + + /// Returns the [`ProcedureInfo`] of the procedure at the provided index, if any. + pub fn get_proc_info_by_index(&self, index: ProcedureIndex) -> Option<&ProcedureInfo> { + self.procedure_infos.get(index.as_usize()) + } + + /// Returns the digest of the procedure with the provided name, if any. + pub fn get_proc_digest_by_name(&self, name: &ProcedureName) -> Option { + self.procedure_infos.iter().find_map(|proc_info| { + if &proc_info.name == name { + Some(proc_info.digest) } else { None } From 7d7af1772437e8c52cfffc62c9357594443f0673 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 14:05:35 -0400 Subject: [PATCH 153/172] `ModuleWrapper` -> `WrappedModule` --- assembly/src/assembler/mod.rs | 4 +- assembly/src/assembler/module_graph/debug.rs | 10 +- assembly/src/assembler/module_graph/mod.rs | 141 +++++++++++-------- 3 files changed, 90 insertions(+), 65 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 3877326164..af2dd807cf 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -605,7 +605,7 @@ impl Assembler { assert!(self.module_graph.contains_module(module_index), "invalid module index"); let exports: Vec = match &self.module_graph[module_index] { - module_graph::WrapperModule::Ast(module) => { + module_graph::WrappedModule::Ast(module) => { let mut exports = Vec::new(); for (index, procedure) in module.procedures().enumerate() { // Only add exports; locals will be added if they are in the call graph rooted @@ -665,7 +665,7 @@ impl Assembler { exports } - module_graph::WrapperModule::Info(module) => { + module_graph::WrappedModule::Info(module) => { module.procedure_infos().map(|(_idx, proc)| proc).cloned().collect() } }; diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index 1eab5de877..4bb5a691dc 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -19,7 +19,7 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { f.debug_set() .entries(self.0.modules.iter().enumerate().flat_map(|(module_index, m)| { match m { - WrapperModule::Ast(m) => m + WrappedModule::Ast(m) => m .procedures() .enumerate() .filter_map(move |(i, export)| { @@ -35,7 +35,7 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { } }) .collect::>(), - WrapperModule::Info(m) => m + WrappedModule::Info(m) => m .procedure_infos() .map(|(proc_index, _proc)| { let gid = GlobalProcedureIndex { @@ -54,7 +54,7 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { } #[doc(hidden)] -struct DisplayModuleGraphNodes<'a>(&'a Vec); +struct DisplayModuleGraphNodes<'a>(&'a Vec); impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -63,7 +63,7 @@ impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { let module_index = ModuleIndex::new(module_index); match m { - WrapperModule::Ast(m) => m + WrappedModule::Ast(m) => m .procedures() .enumerate() .filter_map(move |(proc_index, export)| { @@ -80,7 +80,7 @@ impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { } }) .collect::>(), - WrapperModule::Info(m) => m + WrappedModule::Info(m) => m .procedure_infos() .map(|(proc_index, proc)| DisplayModuleGraphNode { module: module_index, diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index b603a22b23..e656b7817f 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -37,70 +37,95 @@ use crate::{ AssemblyError, LibraryPath, RpoDigest, Spanned, }; -// TODOP: Better doc -pub enum WrapperProcedure<'a> { +/// Wraps all supported representations of a procedure in the module graph. +/// +/// Currently, there are two supported representations: +/// - `Ast`: wraps a procedure for which we have access to the entire AST, +/// - `Info`: stores the procedure's name and digest (resulting from previously compiled +/// procedures). +pub enum ProcedureWrapper<'a> { Ast(&'a Export), - Compiled(&'a ProcedureInfo), + Info(&'a ProcedureInfo), } -impl<'a> WrapperProcedure<'a> { +impl<'a> ProcedureWrapper<'a> { + /// Returns the name of the procedure. pub fn name(&self) -> &ProcedureName { match self { - WrapperProcedure::Ast(p) => p.name(), - WrapperProcedure::Compiled(p) => &p.name, + Self::Ast(p) => p.name(), + Self::Info(p) => &p.name, } } + /// Returns the wrapped procedure if in the `Ast` representation, or panics otherwise. + /// + /// # Panics + /// - Panics if the wrapped procedure is not in the `Ast` representation. pub fn unwrap_ast(&self) -> &Export { match self { - WrapperProcedure::Ast(proc) => proc, - WrapperProcedure::Compiled(_) => panic!("expected AST procedure, but was compiled"), + Self::Ast(proc) => proc, + Self::Info(_) => panic!("expected AST procedure, but was compiled"), } } + /// Returns true if the wrapped procedure is in the `Ast` representation. pub fn is_ast(&self) -> bool { matches!(self, Self::Ast(_)) } } -// TODOP: Rename +/// Wraps all supported representations of a module in the module graph. +/// +/// Currently, there are two supported representations: +/// - `Ast`: wraps a module for which we have access to the entire AST, +/// - `Info`: stores only the necessary information about a module (resulting from previously +/// compiled modules). #[derive(Clone)] -pub enum WrapperModule { +pub enum WrappedModule { Ast(Arc), Info(ModuleInfo), } -impl WrapperModule { +impl WrappedModule { + /// Returns the library path of the wrapped module. pub fn path(&self) -> &LibraryPath { match self { - WrapperModule::Ast(m) => m.path(), - WrapperModule::Info(m) => m.path(), + Self::Ast(m) => m.path(), + Self::Info(m) => m.path(), } } + /// Returns the wrapped module if in the `Ast` representation, or panics otherwise. + /// + /// # Panics + /// - Panics if the wrapped module is not in the `Ast` representation. pub fn unwrap_ast(&self) -> &Arc { match self { - WrapperModule::Ast(module) => module, - WrapperModule::Info(_) => { + Self::Ast(module) => module, + Self::Info(_) => { panic!("expected module to be in AST representation, but was compiled") } } } + /// Returns the wrapped module if in the `Info` representation, or panics otherwise. + /// + /// # Panics + /// - Panics if the wrapped module is not in the `Info` representation. pub fn unwrap_info(&self) -> &ModuleInfo { match self { - WrapperModule::Ast(_) => { + Self::Ast(_) => { panic!("expected module to be compiled, but was in AST representation") } - WrapperModule::Info(module) => module, + Self::Info(module) => module, } } - /// Resolves `name` to a procedure within the local scope of this module + /// Resolves `name` to a procedure within the local scope of this module. pub fn resolve(&self, name: &ProcedureName) -> Option { match self { - WrapperModule::Ast(module) => module.resolve(name), - WrapperModule::Info(module) => { + WrappedModule::Ast(module) => module.resolve(name), + WrappedModule::Info(module) => { module.get_proc_digest_by_name(name).map(ResolvedProcedure::MastRoot) } } @@ -109,16 +134,16 @@ impl WrapperModule { // TODOP: Try to do without this `Pending*` version #[derive(Clone)] -pub enum PendingWrapperModule { +pub enum PendingModuleWrapper { Ast(Box), Info(ModuleInfo), } -impl PendingWrapperModule { +impl PendingModuleWrapper { pub fn path(&self) -> &LibraryPath { match self { - PendingWrapperModule::Ast(m) => m.path(), - PendingWrapperModule::Info(m) => m.path(), + Self::Ast(m) => m.path(), + Self::Info(m) => m.path(), } } } @@ -128,7 +153,7 @@ impl PendingWrapperModule { #[derive(Default, Clone)] pub struct ModuleGraph { - modules: Vec, + modules: Vec, /// The set of modules pending additional processing before adding them to the graph. /// /// When adding a set of inter-dependent modules to the graph, we process them as a group, so @@ -137,7 +162,7 @@ pub struct ModuleGraph { /// /// Once added to the graph, modules become immutable, and any additional modules added after /// that must by definition only depend on modules in the graph, and not be depended upon. - pending: Vec, + pending: Vec, /// The global call graph of calls, not counting those that are performed directly via MAST /// root. callgraph: CallGraph, @@ -179,7 +204,7 @@ impl ModuleGraph { /// This function will panic if the number of modules exceeds the maximum representable /// [ModuleIndex] value, `u16::MAX`. pub fn add_ast_module(&mut self, module: Box) -> Result { - self.add_module(PendingWrapperModule::Ast(module)) + self.add_module(PendingModuleWrapper::Ast(module)) } /// Add the [`ModuleInfo`] to the graph. @@ -203,10 +228,10 @@ impl ModuleGraph { &mut self, module_info: ModuleInfo, ) -> Result { - self.add_module(PendingWrapperModule::Info(module_info)) + self.add_module(PendingModuleWrapper::Info(module_info)) } - fn add_module(&mut self, module: PendingWrapperModule) -> Result { + fn add_module(&mut self, module: PendingModuleWrapper) -> Result { let is_duplicate = self.is_pending(module.path()) || self.find_module_index(module.path()).is_some(); if is_duplicate { @@ -373,7 +398,7 @@ impl ModuleGraph { // TODOP: Refactor everywhere that we added big `match` statements // Apply module to call graph match pending_module { - PendingWrapperModule::Ast(pending_module) => { + PendingModuleWrapper::Ast(pending_module) => { for (index, procedure) in pending_module.procedures().enumerate() { let procedure_id = ProcedureIndex::new(index); let global_id = GlobalProcedureIndex { @@ -389,7 +414,7 @@ impl ModuleGraph { } } } - PendingWrapperModule::Info(pending_module) => { + PendingModuleWrapper::Info(pending_module) => { for (proc_index, _procedure) in pending_module.procedure_infos() { let global_id = GlobalProcedureIndex { module: module_id, @@ -405,18 +430,18 @@ impl ModuleGraph { // before they are added to the graph let mut resolver = NameResolver::new(self); for module in pending.iter() { - if let PendingWrapperModule::Ast(module) = module { + if let PendingModuleWrapper::Ast(module) = module { resolver.push_pending(module); } } let mut phantoms = BTreeSet::default(); let mut edges = Vec::new(); - let mut finished: Vec = Vec::new(); + let mut finished: Vec = Vec::new(); // Visit all of the newly-added modules and perform any rewrites to AST modules. for (pending_index, module) in pending.into_iter().enumerate() { match module { - PendingWrapperModule::Ast(mut ast_module) => { + PendingModuleWrapper::Ast(mut ast_module) => { let module_id = ModuleIndex::new(high_water_mark + pending_index); let mut rewriter = ModuleRewriter::new(&resolver); @@ -447,10 +472,10 @@ impl ModuleGraph { } } - finished.push(WrapperModule::Ast(Arc::new(*ast_module))) + finished.push(WrappedModule::Ast(Arc::new(*ast_module))) } - PendingWrapperModule::Info(module) => { - finished.push(WrapperModule::Info(module)); + PendingModuleWrapper::Info(module) => { + finished.push(WrappedModule::Info(module)); } } } @@ -472,12 +497,12 @@ impl ModuleGraph { let module_id = ModuleIndex::new(module_index); let module = self.modules[module_id.as_usize()].clone(); - if let WrapperModule::Ast(module) = module { + if let WrappedModule::Ast(module) = module { // Re-analyze the module, and if we needed to clone-on-write, the new module will be // returned. Otherwise, `Ok(None)` indicates that the module is unchanged, and `Err` // indicates that re-analysis has found an issue with this module. if let Some(new_module) = self.reanalyze_module(module_id, module)? { - self.modules[module_id.as_usize()] = WrapperModule::Ast(new_module); + self.modules[module_id.as_usize()] = WrappedModule::Ast(new_module); } } } @@ -550,7 +575,7 @@ impl ModuleGraph { /// Fetch a [Module] by [ModuleIndex] #[allow(unused)] - pub fn get_module(&self, id: ModuleIndex) -> Option { + pub fn get_module(&self, id: ModuleIndex) -> Option { self.modules.get(id.as_usize()).cloned() } @@ -561,11 +586,11 @@ impl ModuleGraph { /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex], or `None` if index is invalid. #[allow(unused)] - pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { + pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { match &self.modules[id.module.as_usize()] { - WrapperModule::Ast(m) => m.get(id.index).map(WrapperProcedure::Ast), - WrapperModule::Info(m) => { - m.get_proc_info_by_index(id.index).map(WrapperProcedure::Compiled) + WrappedModule::Ast(m) => m.get(id.index).map(ProcedureWrapper::Ast), + WrappedModule::Info(m) => { + m.get_proc_info_by_index(id.index).map(ProcedureWrapper::Info) } } } @@ -574,11 +599,11 @@ impl ModuleGraph { /// /// # Panics /// - Panics if index is invalid. - pub fn get_procedure_unsafe(&self, id: GlobalProcedureIndex) -> WrapperProcedure { + pub fn get_procedure_unsafe(&self, id: GlobalProcedureIndex) -> ProcedureWrapper { match &self.modules[id.module.as_usize()] { - WrapperModule::Ast(m) => WrapperProcedure::Ast(&m[id.index]), - WrapperModule::Info(m) => { - WrapperProcedure::Compiled(m.get_proc_info_by_index(id.index).unwrap()) + WrappedModule::Ast(m) => ProcedureWrapper::Ast(&m[id.index]), + WrappedModule::Info(m) => { + ProcedureWrapper::Info(m.get_proc_info_by_index(id.index).unwrap()) } } } @@ -631,14 +656,14 @@ impl ModuleGraph { if prev_id != id { let prev_proc = { match &self.modules[prev_id.module.as_usize()] { - WrapperModule::Ast(module) => Some(&module[prev_id.index]), - WrapperModule::Info(_) => None, + WrappedModule::Ast(module) => Some(&module[prev_id.index]), + WrappedModule::Info(_) => None, } }; let current_proc = { match &self.modules[id.module.as_usize()] { - WrapperModule::Ast(module) => Some(&module[id.index]), - WrapperModule::Info(_) => None, + WrappedModule::Ast(module) => Some(&module[id.index]), + WrappedModule::Info(_) => None, } }; @@ -710,7 +735,7 @@ impl ModuleGraph { let module = &self.modules[module_index.as_usize()]; match module { - WrapperModule::Ast(module) => { + WrappedModule::Ast(module) => { match module.resolve(&next.name) { Some(ResolvedProcedure::Local(index)) => { let id = GlobalProcedureIndex { @@ -757,7 +782,7 @@ impl ModuleGraph { } } } - WrapperModule::Info(module) => { + WrappedModule::Info(module) => { break module .procedure_infos() .find_map(|(index, procedure)| { @@ -789,13 +814,13 @@ impl ModuleGraph { } /// Resolve a [LibraryPath] to a [Module] in this graph - pub fn find_module(&self, name: &LibraryPath) -> Option { + pub fn find_module(&self, name: &LibraryPath) -> Option { self.modules.iter().find(|m| m.path() == name).cloned() } /// Returns an iterator over the set of [Module]s in this graph, and their indices #[allow(unused)] - pub fn modules(&self) -> impl Iterator + '_ { + pub fn modules(&self) -> impl Iterator + '_ { self.modules .iter() .enumerate() @@ -804,13 +829,13 @@ impl ModuleGraph { /// Like [modules], but returns a reference to the module, rather than an owned pointer #[allow(unused)] - pub fn modules_by_ref(&self) -> impl Iterator + '_ { + pub fn modules_by_ref(&self) -> impl Iterator + '_ { self.modules.iter().enumerate().map(|(idx, m)| (ModuleIndex::new(idx), m)) } } impl Index for ModuleGraph { - type Output = WrapperModule; + type Output = WrappedModule; fn index(&self, index: ModuleIndex) -> &Self::Output { self.modules.index(index.as_usize()) From 033b741eab6b73cc340a1f609a76b8d6e7672164 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 14:08:08 -0400 Subject: [PATCH 154/172] Document `PendingModuleWrapper` --- assembly/src/assembler/module_graph/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index e656b7817f..4770ffbb30 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -132,7 +132,7 @@ impl WrappedModule { } } -// TODOP: Try to do without this `Pending*` version +/// Wraps modules that are pending in the [`ModuleGraph`]. #[derive(Clone)] pub enum PendingModuleWrapper { Ast(Box), @@ -140,6 +140,7 @@ pub enum PendingModuleWrapper { } impl PendingModuleWrapper { + /// Returns the library path of the wrapped module. pub fn path(&self) -> &LibraryPath { match self { Self::Ast(m) => m.path(), From 7190323aded38710696cfb7c510bc405b4f87203 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 14:22:14 -0400 Subject: [PATCH 155/172] document `Assembler::assemble_library()` --- assembly/src/assembler/mod.rs | 18 +++++++++++++++--- assembly/src/compiled_library.rs | 4 ++-- assembly/src/tests.rs | 2 +- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index af2dd807cf..ba6888d23d 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -14,6 +14,7 @@ use crate::{ }; use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; +use miette::miette; use vm_core::{ mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, Decorator, DecoratorList, Kernel, Operation, Program, @@ -340,9 +341,10 @@ impl Assembler { /// Compilation/Assembly impl Assembler { - // TODOP: Document - // TODOP: Check that `CompiledLibraryMetadata` is consistent with `modules` (e.g. modules path - // indeed start with library name) + /// Assembles a set of modules into a library. + /// + /// The returned library can be added to the assembler assembling a program that depends on the + /// library using [`Self::add_compiled_library`]. pub fn assemble_library( mut self, modules: impl Iterator, @@ -351,6 +353,16 @@ impl Assembler { let module_ids: Vec = modules .map(|module| { let module = module.compile_with_options(CompileOptions::for_library())?; + + if module.path().namespace() != &metadata.name { + return Err(miette!( + "library namespace is {}, but module {} has namespace {}", + metadata.name, + module.name(), + module.path().namespace() + )); + } + Ok(self.module_graph.add_ast_module(module)?) }) .collect::>()?; diff --git a/assembly/src/compiled_library.rs b/assembly/src/compiled_library.rs index 6022e4f925..4dd0bc496e 100644 --- a/assembly/src/compiled_library.rs +++ b/assembly/src/compiled_library.rs @@ -6,7 +6,7 @@ use vm_core::{ use crate::{ ast::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}, - CompiledLibraryError, LibraryPath, Version, + CompiledLibraryError, LibraryNamespace, LibraryPath, Version, }; /// A procedure's name, along with its module path. @@ -122,7 +122,7 @@ impl CompiledLibrary { } pub struct CompiledLibraryMetadata { - pub path: LibraryPath, + pub name: LibraryNamespace, pub version: Version, } diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 10f56dd97a..58ea9146d7 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -2403,7 +2403,7 @@ fn test_compiled_library() { }; let metadata = CompiledLibraryMetadata { - path: LibraryPath::new("mylib").unwrap(), + name: LibraryNamespace::new("mylib").unwrap(), version: Version::min(), }; From 1ee70da8e958b724b30a63aa765c56f5f4c18bdc Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 14:39:35 -0400 Subject: [PATCH 156/172] fix TODOP --- assembly/src/assembler/mod.rs | 89 ++++++++++++---------- assembly/src/assembler/module_graph/mod.rs | 1 - 2 files changed, 49 insertions(+), 41 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index ba6888d23d..dc6b7d4f40 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -1,7 +1,7 @@ use crate::{ ast::{ self, AliasTarget, Export, FullyQualifiedProcedureName, Instruction, InvocationTarget, - InvokeKind, ModuleKind, ProcedureIndex, + InvokeKind, Module, ModuleKind, ProcedureIndex, }, compiled_library::{ CompiledFullyQualifiedProcedureName, CompiledLibrary, CompiledLibraryMetadata, @@ -604,28 +604,44 @@ impl Assembler { .map_err(|err| Report::new(AssemblyError::Kernel(err))) } - // TODOP: Fix docs - /// Get the set of exported procedures of the given module. + /// Get the set of exported procedure infos of the given module. /// /// Returns an error if the provided Miden Assembly is invalid. fn get_module_exports( &mut self, module_index: ModuleIndex, mast_forest: &MastForest, - // TODOP: Return iterator instead? ) -> Result, Report> { assert!(self.module_graph.contains_module(module_index), "invalid module index"); let exports: Vec = match &self.module_graph[module_index] { module_graph::WrappedModule::Ast(module) => { - let mut exports = Vec::new(); - for (index, procedure) in module.procedures().enumerate() { - // Only add exports; locals will be added if they are in the call graph rooted - // at those procedures - if !procedure.visibility().is_exported() { - continue; - } - let gid = match procedure { + self.get_module_exports_ast(module_index, module, mast_forest)? + } + module_graph::WrappedModule::Info(module) => { + module.procedure_infos().map(|(_idx, proc)| proc).cloned().collect() + } + }; + + Ok(exports) + } + + /// Helper function for [`Self::get_module_exports`], specifically for when the inner + /// [`module_graph::WrappedModule`] is in `Ast` representation. + fn get_module_exports_ast( + &self, + module_index: ModuleIndex, + module: &Arc, + mast_forest: &MastForest, + ) -> Result, Report> { + let mut exports = Vec::new(); + for (index, procedure) in module.procedures().enumerate() { + // Only add exports; locals will be added if they are in the call graph rooted + // at those procedures + if !procedure.visibility().is_exported() { + continue; + } + let gid = match procedure { Export::Procedure(_) => GlobalProcedureIndex { module: module_index, index: ProcedureIndex::new(index), @@ -649,38 +665,31 @@ impl Assembler { } } }; - let proc = self.procedure_cache.get(gid).unwrap_or_else(|| match procedure { - Export::Procedure(ref proc) => { - panic!( - "compilation apparently succeeded, but did not find a \ + let proc = self.procedure_cache.get(gid).unwrap_or_else(|| match procedure { + Export::Procedure(ref proc) => { + panic!( + "compilation apparently succeeded, but did not find a \ entry in the procedure cache for '{}'", - proc.name() - ) - } - Export::Alias(ref alias) => { - panic!( - "compilation apparently succeeded, but did not find a \ + proc.name() + ) + } + Export::Alias(ref alias) => { + panic!( + "compilation apparently succeeded, but did not find a \ entry in the procedure cache for alias '{}', i.e. '{}'", - alias.name(), - alias.target() - ); - } - }); - - let compiled_proc = ProcedureInfo { - name: proc.name().clone(), - digest: mast_forest[proc.body_node_id()].digest(), - }; - - exports.push(compiled_proc); + alias.name(), + alias.target() + ); } + }); - exports - } - module_graph::WrappedModule::Info(module) => { - module.procedure_infos().map(|(_idx, proc)| proc).cloned().collect() - } - }; + let compiled_proc = ProcedureInfo { + name: proc.name().clone(), + digest: mast_forest[proc.body_node_id()].digest(), + }; + + exports.push(compiled_proc); + } Ok(exports) } diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 4770ffbb30..813b11b4bf 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -396,7 +396,6 @@ impl ModuleGraph { for (pending_index, pending_module) in pending.iter().enumerate() { let module_id = ModuleIndex::new(high_water_mark + pending_index); - // TODOP: Refactor everywhere that we added big `match` statements // Apply module to call graph match pending_module { PendingModuleWrapper::Ast(pending_module) => { From 9263fbbdd8501744ce5d6bf96f7797d84b23ddc6 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 19 Jul 2024 14:40:59 -0400 Subject: [PATCH 157/172] rename --- assembly/src/assembler/mod.rs | 21 +++++++++++---------- assembly/src/assembler/module_graph/mod.rs | 7 +++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index dc6b7d4f40..f5a7b6c205 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -745,16 +745,17 @@ impl Assembler { context: &mut AssemblyContext, mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { - let mut worklist = self.module_graph.topological_sort_from_root(root).map_err(|cycle| { - let iter = cycle.into_node_ids(); - let mut nodes = Vec::with_capacity(iter.len()); - for node in iter { - let module = self.module_graph[node.module].path(); - let proc = self.module_graph.get_procedure_unsafe(node); - nodes.push(format!("{}::{}", module, proc.name())); - } - AssemblyError::Cycle { nodes } - })?; + let mut worklist = + self.module_graph.topological_sort_ast_procs_from_root(root).map_err(|cycle| { + let iter = cycle.into_node_ids(); + let mut nodes = Vec::with_capacity(iter.len()); + for node in iter { + let module = self.module_graph[node.module].path(); + let proc = self.module_graph.get_procedure_unsafe(node); + nodes.push(format!("{}::{}", module, proc.name())); + } + AssemblyError::Cycle { nodes } + })?; assert!(!worklist.is_empty()); diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 813b11b4bf..e0d85e072e 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -558,13 +558,12 @@ impl ModuleGraph { self.topo.as_slice() } - /// Compute the topological sort of the callgraph rooted at `caller` - pub fn topological_sort_from_root( + /// Compute the topological sort of the callgraph rooted at `caller`, only for procedures that + /// need to be assembled (i.e. in `Ast` representation). + pub fn topological_sort_ast_procs_from_root( &self, caller: GlobalProcedureIndex, ) -> Result, CycleError> { - // TODOP: Fix Vec -> into_iter() -> collect - // TODOP: Should we change name/args? Ok(self .callgraph .toposort_caller(caller)? From a784b00f825a413c0999deb7c70f54cf7a797c6e Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 14:13:18 -0400 Subject: [PATCH 158/172] fix test --- assembly/src/assembler/mod.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 12548a5d00..c56ff30e73 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -10,6 +10,7 @@ use crate::{ }; use alloc::{sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; +use module_graph::ProcedureWrapper; use vm_core::{mast::MastNodeId, Decorator, DecoratorList, Felt, Kernel, Operation, Program}; mod basic_block_builder; @@ -614,10 +615,13 @@ impl Assembler { match resolved { ResolvedTarget::Phantom(digest) => Ok(digest), ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => { - Ok(mast_forest_builder - .get_procedure(gid) - .map(|p| p.mast_root()) - .expect("expected callee to have been compiled already")) + match mast_forest_builder.get_procedure(gid) { + Some(p) => Ok(p.mast_root()), + None => match self.module_graph.get_procedure_unsafe(gid) { + ProcedureWrapper::Info(p) => Ok(p.digest), + ProcedureWrapper::Ast(_) => panic!("Did not find procedure {gid:?} neither in module graph nor procedure cache"), + }, + } } } } From 98ae7bc40a4aa29c63973426717488a96505acd5 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 15:12:27 -0400 Subject: [PATCH 159/172] cleanup `ModuleGraph::topological_sort_from_root` --- assembly/src/assembler/mod.rs | 26 +++++++++++++--------- assembly/src/assembler/module_graph/mod.rs | 8 +------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index c56ff30e73..2e516d8116 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -398,16 +398,22 @@ impl Assembler { is_entrypoint: bool, mast_forest_builder: &mut MastForestBuilder, ) -> Result, Report> { - let mut worklist = self.module_graph.topological_sort_from_root(root).map_err(|cycle| { - let iter = cycle.into_node_ids(); - let mut nodes = Vec::with_capacity(iter.len()); - for node in iter { - let module = self.module_graph[node.module].path(); - let proc = self.module_graph.get_procedure_unsafe(node); - nodes.push(format!("{}::{}", module, proc.name())); - } - AssemblyError::Cycle { nodes } - })?; + let mut worklist: Vec = self + .module_graph + .topological_sort_from_root(root) + .map_err(|cycle| { + let iter = cycle.into_node_ids(); + let mut nodes = Vec::with_capacity(iter.len()); + for node in iter { + let module = self.module_graph[node.module].path(); + let proc = self.module_graph.get_procedure_unsafe(node); + nodes.push(format!("{}::{}", module, proc.name())); + } + AssemblyError::Cycle { nodes } + })? + .into_iter() + .filter(|&gid| self.module_graph.get_procedure_unsafe(gid).is_ast()) + .collect(); assert!(!worklist.is_empty()); diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 1c74152cf2..21a4721cbf 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -471,13 +471,7 @@ impl ModuleGraph { &self, caller: GlobalProcedureIndex, ) -> Result, CycleError> { - Ok(self - .callgraph - .toposort_caller(caller)? - .into_iter() - // TODOP: do this outside the function - .filter(|&gid| self.get_procedure_unsafe(gid).is_ast()) - .collect()) + Ok(self.callgraph.toposort_caller(caller)?) } /// Fetch a [Module] by [ModuleIndex] From 250e86d186f3d60e21724cede5579a1380b9c91f Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 15:17:56 -0400 Subject: [PATCH 160/172] fix CI --- assembly/src/assembler/module_graph/mod.rs | 2 +- assembly/src/library/mod.rs | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 21a4721cbf..61bb5c04c2 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -471,7 +471,7 @@ impl ModuleGraph { &self, caller: GlobalProcedureIndex, ) -> Result, CycleError> { - Ok(self.callgraph.toposort_caller(caller)?) + self.callgraph.toposort_caller(caller) } /// Fetch a [Module] by [ModuleIndex] diff --git a/assembly/src/library/mod.rs b/assembly/src/library/mod.rs index 2ec1b72890..3bc7f5ae8e 100644 --- a/assembly/src/library/mod.rs +++ b/assembly/src/library/mod.rs @@ -1,6 +1,4 @@ -use std::collections::BTreeMap; - -use alloc::vec::Vec; +use alloc::{collections::BTreeMap, vec::Vec}; use vm_core::crypto::hash::RpoDigest; use vm_core::mast::MastForest; From 42a991a27ae5f6e4b6a54cf2c0da29673eda4c6c Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 15:21:12 -0400 Subject: [PATCH 161/172] re-implement `Spanned` for `ResolvedProcedure` --- assembly/src/ast/procedure/resolver.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/assembly/src/ast/procedure/resolver.rs b/assembly/src/ast/procedure/resolver.rs index 89c1978eac..8199b32d68 100644 --- a/assembly/src/ast/procedure/resolver.rs +++ b/assembly/src/ast/procedure/resolver.rs @@ -1,5 +1,5 @@ use super::{FullyQualifiedProcedureName, ProcedureIndex, ProcedureName}; -use crate::{ast::Ident, LibraryPath, RpoDigest, Span}; +use crate::{ast::Ident, LibraryPath, RpoDigest, SourceSpan, Span, Spanned}; use alloc::{collections::BTreeMap, vec::Vec}; // RESOLVED PROCEDURE @@ -16,6 +16,16 @@ pub enum ResolvedProcedure { MastRoot(RpoDigest), } +impl Spanned for ResolvedProcedure { + fn span(&self) -> SourceSpan { + match self { + ResolvedProcedure::Local(p) => p.span(), + ResolvedProcedure::External(p) => p.span(), + ResolvedProcedure::MastRoot(_) => SourceSpan::default(), + } + } +} + // LOCAL NAME RESOLVER // ================================================================================================ From 8669dfcb2e24144b85b905295eafae9cd8ea7dc0 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 16:23:32 -0400 Subject: [PATCH 162/172] reintroduce proper error message --- assembly/src/ast/module.rs | 8 +++++--- assembly/src/sema/errors.rs | 5 ----- assembly/src/tests.rs | 8 +++++++- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/assembly/src/ast/module.rs b/assembly/src/ast/module.rs index 4e6cb0826b..1c608fa04a 100644 --- a/assembly/src/ast/module.rs +++ b/assembly/src/ast/module.rs @@ -208,9 +208,11 @@ impl Module { span: export.span(), }); } - if let Some(_prev) = self.resolve(export.name()) { - Err(SemanticAnalysisError::ProcedureNameConflict { - name: export.name().clone(), + if let Some(prev) = self.resolve(export.name()) { + let prev_span = prev.span(); + Err(SemanticAnalysisError::SymbolConflict { + span: export.span(), + prev_span, }) } else { self.procedures.push(export); diff --git a/assembly/src/sema/errors.rs b/assembly/src/sema/errors.rs index de6447abbf..59b1852f4d 100644 --- a/assembly/src/sema/errors.rs +++ b/assembly/src/sema/errors.rs @@ -5,8 +5,6 @@ use crate::{ use alloc::{sync::Arc, vec::Vec}; use core::fmt; -use super::ProcedureName; - /// The high-level error type for all semantic analysis errors. /// /// This rolls up multiple errors into a single one, and as such, can emit many @@ -78,9 +76,6 @@ pub enum SemanticAnalysisError { #[label("previously defined here")] prev_span: SourceSpan, }, - #[error("procedure name conflict: found duplicate definitions of '{name}'")] - #[diagnostic()] - ProcedureNameConflict { name: ProcedureName }, #[error("symbol undefined: no such name found in scope")] #[diagnostic(help("are you missing an import?"))] SymbolUndefined { diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 7dbeb1ab6a..d50a6484e0 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -2198,7 +2198,13 @@ fn invalid_proc_duplicate_procedure_name() { source, "syntax error", "help: see emitted diagnostics for details", - "procedure name conflict: found duplicate definitions of 'foo'" + "symbol conflict: found duplicate definitions of the same name", + regex!(r#",-\[test[\d]+:1:6\]"#), + "1 | proc.foo add mul end proc.foo push.3 end begin push.1 end", + " : ^|^ ^^^^^^^^^|^^^^^^^^^", + " : | `-- conflict occurs here", + " : `-- previously defined here", + " `----" ); } From d444d671936733ce610ef631cf6c1af973d27568 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 16:26:02 -0400 Subject: [PATCH 163/172] remove unused methods --- assembly/src/assembler/module_graph/mod.rs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 61bb5c04c2..a9c84b4df7 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -474,23 +474,6 @@ impl ModuleGraph { self.callgraph.toposort_caller(caller) } - /// Fetch a [Module] by [ModuleIndex] - #[allow(unused)] - pub fn get_module(&self, id: ModuleIndex) -> Option { - self.modules.get(id.as_usize()).cloned() - } - - /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex], or `None` if index is invalid. - #[allow(unused)] - pub fn get_procedure(&self, id: GlobalProcedureIndex) -> Option { - match &self.modules[id.module.as_usize()] { - WrappedModule::Ast(m) => m.get(id.index).map(ProcedureWrapper::Ast), - WrappedModule::Info(m) => { - m.get_proc_info_by_index(id.index).map(ProcedureWrapper::Info) - } - } - } - /// Fetch a [WrapperProcedure] by [GlobalProcedureIndex]. /// /// # Panics From 48d59a9659d6241360471b84271ca072ed46dc12 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 25 Jul 2024 16:30:09 -0400 Subject: [PATCH 164/172] Remove all `allow(unused)` methods --- assembly/src/assembler/module_graph/mod.rs | 30 ---------------------- 1 file changed, 30 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index a9c84b4df7..d48464257e 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -245,19 +245,9 @@ impl ModuleGraph { &self.kernel } - #[allow(unused)] - pub fn kernel_index(&self) -> Option { - self.kernel_index - } - pub fn has_nonempty_kernel(&self) -> bool { self.kernel_index.is_some() || !self.kernel.is_empty() } - - #[allow(unused)] - pub fn is_kernel_procedure_root(&self, digest: &RpoDigest) -> bool { - self.kernel.contains_proc(*digest) - } } // ------------------------------------------------------------------------------------------------ @@ -494,11 +484,6 @@ impl ModuleGraph { self.roots.get(digest).map(|indices| indices[0]) } - #[allow(unused)] - pub fn callees(&self, gid: GlobalProcedureIndex) -> &[GlobalProcedureIndex] { - self.callgraph.out_edges(gid) - } - /// Resolves `target` from the perspective of `caller`. pub fn resolve_target( &self, @@ -585,21 +570,6 @@ impl ModuleGraph { pub fn find_module(&self, name: &LibraryPath) -> Option { self.modules.iter().find(|m| m.path() == name).cloned() } - - /// Returns an iterator over the set of [Module]s in this graph, and their indices - #[allow(unused)] - pub fn modules(&self) -> impl Iterator + '_ { - self.modules - .iter() - .enumerate() - .map(|(idx, m)| (ModuleIndex::new(idx), m.clone())) - } - - /// Like [modules], but returns a reference to the module, rather than an owned pointer - #[allow(unused)] - pub fn modules_by_ref(&self) -> impl Iterator + '_ { - self.modules.iter().enumerate().map(|(idx, m)| (ModuleIndex::new(idx), m)) - } } impl Index for ModuleGraph { From 770b25da5198581c808ef4c3c837f6537737ab08 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 26 Jul 2024 12:37:17 -0400 Subject: [PATCH 165/172] Document `unwrap_ast()` call --- assembly/src/assembler/instruction/procedures.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index a3b6e74b59..4863fb49f7 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -58,6 +58,9 @@ impl Assembler { callee: proc.fully_qualified_name().clone(), }) .and_then(|module| { + // Note: this module is guaranteed to be of AST variant, since we have the + // AST of a procedure contained in it (i.e. `proc`). Hence, it must be that + // the entire module is in AST representation as well. if module.unwrap_ast().is_kernel() { Ok(()) } else { From 53e65551a2bc3f4a5665e1affd453c17dbcbf2d4 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 26 Jul 2024 12:40:31 -0400 Subject: [PATCH 166/172] `NameResolver`: remove use of `unwrap_ast()` --- assembly/src/assembler/module_graph/name_resolver.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/assembly/src/assembler/module_graph/name_resolver.rs b/assembly/src/assembler/module_graph/name_resolver.rs index fb75245946..2db16b000c 100644 --- a/assembly/src/assembler/module_graph/name_resolver.rs +++ b/assembly/src/assembler/module_graph/name_resolver.rs @@ -449,7 +449,10 @@ impl<'a> NameResolver<'a> { if module_index >= pending_offset { self.pending[module_index - pending_offset].source_file.clone() } else { - self.graph[module].unwrap_ast().source_file() + match &self.graph[module] { + WrappedModule::Ast(module) => module.source_file(), + WrappedModule::Info(_) => None, + } } } From 30ef4c5356ff937c46c50065135766075a5b1d4b Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 26 Jul 2024 12:50:31 -0400 Subject: [PATCH 167/172] Document or remove all calls to `WrappedModule.unwrap_ast()` --- assembly/src/assembler/mod.rs | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 2e516d8116..0991385830 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -10,7 +10,7 @@ use crate::{ }; use alloc::{sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; -use module_graph::ProcedureWrapper; +use module_graph::{ProcedureWrapper, WrappedModule}; use vm_core::{mast::MastNodeId, Decorator, DecoratorList, Felt, Kernel, Operation, Program}; mod basic_block_builder; @@ -259,7 +259,7 @@ impl Assembler { mut self, modules: impl Iterator, ) -> Result { - let module_indices: Vec = modules + let ast_module_indices: Vec = modules .map(|module| { let module = module.compile_with_options(CompileOptions::for_library())?; @@ -273,10 +273,12 @@ impl Assembler { let exports = { let mut exports = Vec::new(); - for module_idx in module_indices { - let module = self.module_graph[module_idx].unwrap_ast().clone(); + for ast_module_idx in ast_module_indices { + // Note: it is safe to use `unwrap_ast()` here, since all modules looped over are + // AST (we just added them to the module graph) + let ast_module = self.module_graph[ast_module_idx].unwrap_ast().clone(); - for (proc_idx, procedure) in module.procedures().enumerate() { + for (proc_idx, procedure) in ast_module.procedures().enumerate() { // Only add exports; locals will be added if they are in the call graph rooted // at those procedures if !procedure.visibility().is_exported() { @@ -284,14 +286,14 @@ impl Assembler { } let gid = GlobalProcedureIndex { - module: module_idx, + module: ast_module_idx, index: ProcedureIndex::new(proc_idx), }; self.compile_subgraph(gid, false, &mut mast_forest_builder)?; exports.push(FullyQualifiedProcedureName::new( - module.path().clone(), + ast_module.path().clone(), procedure.name().clone(), )); } @@ -349,15 +351,16 @@ impl Assembler { assert!(program.is_executable()); // Recompute graph with executable module, and start compiling - let module_index = self.module_graph.add_ast_module(program)?; + let ast_module_index = self.module_graph.add_ast_module(program)?; self.module_graph.recompute()?; - // Find the executable entrypoint - let entrypoint = self.module_graph[module_index] + // Find the executable entrypoint Note: it is safe to use `unwrap_ast()` here, since this is + // the module we just added, which is in AST representation. + let entrypoint = self.module_graph[ast_module_index] .unwrap_ast() .index_of(|p| p.is_main()) .map(|index| GlobalProcedureIndex { - module: module_index, + module: ast_module_index, index, }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; @@ -427,6 +430,7 @@ impl Assembler { Ok(compiled.expect("compilation succeeded but root not found in cache")) } + /// Compiles all procedures in the `worklist`. fn process_graph_worklist( &mut self, worklist: &mut Vec, @@ -445,7 +449,12 @@ impl Assembler { let is_entry = entrypoint == Some(procedure_gid); // Fetch procedure metadata from the graph - let module = &self.module_graph[procedure_gid.module].unwrap_ast(); + let module = match &self.module_graph[procedure_gid.module] { + WrappedModule::Ast(ast_module) => ast_module, + // Note: if the containing module is in `Info` representation, there is nothing to + // compile. + WrappedModule::Info(_) => continue, + }; let ast = &module[procedure_gid.index]; let num_locals = ast.num_locals(); let name = FullyQualifiedProcedureName { From 2c43ed447631a2ced04d1604a2e3f4679a778307 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 26 Jul 2024 12:52:12 -0400 Subject: [PATCH 168/172] rename `PendingWrappedModule` --- assembly/src/assembler/module_graph/mod.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index d48464257e..1235f82c47 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -121,12 +121,12 @@ impl WrappedModule { /// Wraps modules that are pending in the [`ModuleGraph`]. #[derive(Clone)] -pub enum PendingModuleWrapper { +pub enum PendingWrappedModule { Ast(Box), Info(ModuleInfo), } -impl PendingModuleWrapper { +impl PendingWrappedModule { /// Returns the library path of the wrapped module. pub fn path(&self) -> &LibraryPath { match self { @@ -150,7 +150,7 @@ pub struct ModuleGraph { /// /// Once added to the graph, modules become immutable, and any additional modules added after /// that must by definition only depend on modules in the graph, and not be depended upon. - pending: Vec, + pending: Vec, /// The global call graph of calls, not counting those that are performed directly via MAST /// root. callgraph: CallGraph, @@ -182,7 +182,7 @@ impl ModuleGraph { /// This function will panic if the number of modules exceeds the maximum representable /// [ModuleIndex] value, `u16::MAX`. pub fn add_ast_module(&mut self, module: Box) -> Result { - self.add_module(PendingModuleWrapper::Ast(module)) + self.add_module(PendingWrappedModule::Ast(module)) } /// Add the [`ModuleInfo`] to the graph. @@ -206,10 +206,10 @@ impl ModuleGraph { &mut self, module_info: ModuleInfo, ) -> Result { - self.add_module(PendingModuleWrapper::Info(module_info)) + self.add_module(PendingWrappedModule::Info(module_info)) } - fn add_module(&mut self, module: PendingModuleWrapper) -> Result { + fn add_module(&mut self, module: PendingWrappedModule) -> Result { let is_duplicate = self.is_pending(module.path()) || self.find_module_index(module.path()).is_some(); if is_duplicate { @@ -313,7 +313,7 @@ impl ModuleGraph { // Apply module to call graph match pending_module { - PendingModuleWrapper::Ast(pending_module) => { + PendingWrappedModule::Ast(pending_module) => { for (index, procedure) in pending_module.procedures().enumerate() { let procedure_id = ProcedureIndex::new(index); let global_id = GlobalProcedureIndex { @@ -329,7 +329,7 @@ impl ModuleGraph { } } } - PendingModuleWrapper::Info(pending_module) => { + PendingWrappedModule::Info(pending_module) => { for (proc_index, _procedure) in pending_module.procedure_infos() { let global_id = GlobalProcedureIndex { module: module_id, @@ -345,7 +345,7 @@ impl ModuleGraph { // before they are added to the graph let mut resolver = NameResolver::new(self); for module in pending.iter() { - if let PendingModuleWrapper::Ast(module) = module { + if let PendingWrappedModule::Ast(module) = module { resolver.push_pending(module); } } @@ -355,7 +355,7 @@ impl ModuleGraph { // Visit all of the newly-added modules and perform any rewrites to AST modules. for (pending_index, module) in pending.into_iter().enumerate() { match module { - PendingModuleWrapper::Ast(mut ast_module) => { + PendingWrappedModule::Ast(mut ast_module) => { let module_id = ModuleIndex::new(high_water_mark + pending_index); let mut rewriter = ModuleRewriter::new(&resolver); @@ -385,7 +385,7 @@ impl ModuleGraph { finished.push(WrappedModule::Ast(Arc::new(*ast_module))) } - PendingModuleWrapper::Info(module) => { + PendingWrappedModule::Info(module) => { finished.push(WrappedModule::Info(module)); } } From 258aa999eedde5fd2bffae5be9dd8e4a2db63ab2 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 26 Jul 2024 13:04:07 -0400 Subject: [PATCH 169/172] Add `ModuleGraph::add_compiled_modules()` --- assembly/src/assembler/mod.rs | 25 +++------------------ assembly/src/assembler/module_graph/mod.rs | 26 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 0991385830..b1ecdf27cb 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -147,28 +147,9 @@ impl Assembler { /// Adds the compiled library to provide modules for the compilation. pub fn add_compiled_library(&mut self, library: CompiledLibrary) -> Result<(), Report> { - let module_indexes: Vec = library - .into_module_infos() - .map(|module| self.module_graph.add_module_info(module)) - .collect::>()?; - - self.module_graph.recompute()?; - - // Register all procedures as roots - for module_index in module_indexes { - for (proc_index, proc) in - self.module_graph[module_index].unwrap_info().clone().procedure_infos() - { - let gid = GlobalProcedureIndex { - module: module_index, - index: proc_index, - }; - - self.module_graph.register_mast_root(gid, proc.digest)?; - } - } - - Ok(()) + self.module_graph + .add_compiled_modules(library.into_module_infos()) + .map_err(Report::from) } /// Adds the library to provide modules for the compilation. diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 1235f82c47..96a6ed5fc4 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -164,6 +164,32 @@ pub struct ModuleGraph { // ------------------------------------------------------------------------------------------------ /// Constructors impl ModuleGraph { + /// Adds all module infos to the graph. + pub fn add_compiled_modules( + &mut self, + module_infos: impl Iterator, + ) -> Result<(), AssemblyError> { + let module_indices: Vec = module_infos + .map(|module| self.add_module_info(module)) + .collect::>()?; + + self.recompute()?; + + // Register all procedures as roots + for module_index in module_indices { + for (proc_index, proc) in self[module_index].unwrap_info().clone().procedure_infos() { + let gid = GlobalProcedureIndex { + module: module_index, + index: proc_index, + }; + + self.register_mast_root(gid, proc.digest)?; + } + } + + Ok(()) + } + /// Add `module` to the graph. /// /// NOTE: This operation only adds a module to the graph, but does not perform the From 49879de60c69c1962a4983df6ee3387ac9e31754 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Fri, 26 Jul 2024 14:18:17 -0400 Subject: [PATCH 170/172] Remove `ModuleGraph::add_module_info()` --- assembly/src/assembler/module_graph/mod.rs | 26 +--------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index 96a6ed5fc4..a5ea5a3e05 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -170,7 +170,7 @@ impl ModuleGraph { module_infos: impl Iterator, ) -> Result<(), AssemblyError> { let module_indices: Vec = module_infos - .map(|module| self.add_module_info(module)) + .map(|module| self.add_module(PendingWrappedModule::Info(module))) .collect::>()?; self.recompute()?; @@ -211,30 +211,6 @@ impl ModuleGraph { self.add_module(PendingWrappedModule::Ast(module)) } - /// Add the [`ModuleInfo`] to the graph. - /// - /// NOTE: This operation only adds a module to the graph, but does not perform the important - /// analysis needed for compilation, you must call [`Self::recompute`] once all modules are - /// added to ensure the analysis results reflect the current version of the graph. - /// - /// # Errors - /// - /// This operation can fail for the following reasons: - /// - /// * Module with same [LibraryPath] is in the graph already - /// * Too many modules in the graph - /// - /// # Panics - /// - /// This function will panic if the number of modules exceeds the maximum representable - /// [ModuleIndex] value, `u16::MAX`. - pub fn add_module_info( - &mut self, - module_info: ModuleInfo, - ) -> Result { - self.add_module(PendingWrappedModule::Info(module_info)) - } - fn add_module(&mut self, module: PendingWrappedModule) -> Result { let is_duplicate = self.is_pending(module.path()) || self.find_module_index(module.path()).is_some(); From 302c4cc09b139f95c324782dca5497ceb952fcec Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sun, 28 Jul 2024 14:50:23 -0700 Subject: [PATCH 171/172] refactor: remove Assembler::compile_program() internal method --- assembly/src/assembler/mod.rs | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index b1ecdf27cb..4b9c1ffbb6 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -321,8 +321,6 @@ impl Assembler { )); } - let mast_forest_builder = MastForestBuilder::default(); - let program = source.compile_with_options(CompileOptions { // Override the module name so that we always compile the executable // module as #exe @@ -346,23 +344,8 @@ impl Assembler { }) .ok_or(SemanticAnalysisError::MissingEntrypoint)?; - self.compile_program(entrypoint, mast_forest_builder) - } - - /// Compile the provided [Module] into a [Program]. - /// - /// Ensures that the [`MastForest`] entrypoint is set to the entrypoint of the program. - /// - /// Returns an error if the provided Miden Assembly is invalid. - fn compile_program( - mut self, - entrypoint: GlobalProcedureIndex, - mut mast_forest_builder: MastForestBuilder, - ) -> Result { - // Raise an error if we are called with an invalid entrypoint - assert!(self.module_graph.get_procedure_unsafe(entrypoint).name().is_main()); - // Compile the module graph rooted at the entrypoint + let mut mast_forest_builder = MastForestBuilder::default(); let entry_procedure = self.compile_subgraph(entrypoint, true, &mut mast_forest_builder)?; Ok(Program::with_kernel( From fa4b79e1305e883658e741db400e495f3e140edf Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sun, 28 Jul 2024 15:05:56 -0700 Subject: [PATCH 172/172] refactor: remove Assembler::assemble_with_options() internal method --- assembly/src/assembler/mod.rs | 36 +++++------------------------------ 1 file changed, 5 insertions(+), 31 deletions(-) diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 4b9c1ffbb6..88dc369577 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -293,40 +293,14 @@ impl Assembler { /// /// Returns an error if parsing or compilation of the specified program fails, or if the source /// doesn't have an entrypoint. - pub fn assemble_program(self, source: impl Compile) -> Result { - let opts = CompileOptions { + pub fn assemble_program(mut self, source: impl Compile) -> Result { + let options = CompileOptions { + kind: ModuleKind::Executable, warnings_as_errors: self.warnings_as_errors, - ..CompileOptions::default() + path: Some(LibraryPath::from(LibraryNamespace::Exec)), }; - self.assemble_with_options(source, opts) - } - - /// Compiles the provided module into a [Program] using the provided options. - /// - /// The resulting program can be executed on Miden VM. - /// - /// # Errors - /// - /// Returns an error if parsing or compilation of the specified program fails, or the options - /// are invalid. - fn assemble_with_options( - mut self, - source: impl Compile, - options: CompileOptions, - ) -> Result { - if options.kind != ModuleKind::Executable { - return Err(Report::msg( - "invalid compile options: assemble_with_opts_in_context requires that the kind be 'executable'", - )); - } - - let program = source.compile_with_options(CompileOptions { - // Override the module name so that we always compile the executable - // module as #exe - path: Some(LibraryPath::from(LibraryNamespace::Exec)), - ..options - })?; + let program = source.compile_with_options(options)?; assert!(program.is_executable()); // Recompute graph with executable module, and start compiling