diff --git a/CHANGELOG.md b/CHANGELOG.md index a33b178741..72392ea618 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ - Updated CI and Makefile to standardise it accross Miden repositories (#1342). - Add serialization/deserialization for `MastForest` (#1370) - Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406) +- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394) #### Changed diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index dfbdbd84c2..2545d4eea8 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -4,7 +4,7 @@ use super::{ }; use alloc::{borrow::Borrow, string::ToString, vec::Vec}; use vm_core::{ - mast::{MastNode, MastNodeId}, + mast::{MastForestError, MastNode, MastNodeId}, AdviceInjector, AssemblyOp, Operation, }; @@ -129,22 +129,22 @@ impl BasicBlockBuilder { pub fn make_basic_block( &mut self, mast_forest_builder: &mut MastForestBuilder, - ) -> Option { + ) -> Result, MastForestError> { if !self.ops.is_empty() { let ops = self.ops.drain(..).collect(); let decorators = self.decorators.drain(..).collect(); let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators); - let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node); + let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?; - Some(basic_block_node_id) + Ok(Some(basic_block_node_id)) } else if !self.decorators.is_empty() { // this is a bug in the assembler. we shouldn't have decorators added without their // associated operations // TODO: change this to an error or allow decorators in empty span blocks unreachable!("decorators in an empty SPAN block") } else { - None + Ok(None) } } @@ -155,10 +155,10 @@ impl BasicBlockBuilder { /// - Operations contained in the epilogue of the builder are appended to the list of ops which /// go into the new BASIC BLOCK node. /// - The builder is consumed in the process. - pub fn into_basic_block( + pub fn try_into_basic_block( mut self, mast_forest_builder: &mut MastForestBuilder, - ) -> Option { + ) -> Result, MastForestError> { self.ops.append(&mut self.epilogue); self.make_basic_block(mast_forest_builder) } diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 9894a82092..d6a96f0f1b 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -91,37 +91,44 @@ impl Assembler { // procedures, such that when we assemble a procedure, all // procedures that it calls will have been assembled, and // hence be present in the `MastForest`. - mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { - // If the MAST root called isn't known to us, make it an external - // reference. - let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node) - }) + match mast_forest_builder.find_procedure_root(mast_root) { + Some(root) => root, + None => { + // If the MAST root called isn't known to us, make it an external + // reference. + let external_node = MastNode::new_external(mast_root); + mast_forest_builder.ensure_node(external_node)? + } + } } InvokeKind::Call => { - let callee_id = - mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { + let callee_id = match mast_forest_builder.find_procedure_root(mast_root) { + Some(callee_id) => callee_id, + None => { // If the MAST root called isn't known to us, make it an external // reference. let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node) - }); + mast_forest_builder.ensure_node(external_node)? + } + }; let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(call_node) + mast_forest_builder.ensure_node(call_node)? } InvokeKind::SysCall => { - let callee_id = - mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| { + let callee_id = match mast_forest_builder.find_procedure_root(mast_root) { + Some(callee_id) => callee_id, + None => { // If the MAST root called isn't known to us, make it an external // reference. let external_node = MastNode::new_external(mast_root); - mast_forest_builder.ensure_node(external_node) - }); + mast_forest_builder.ensure_node(external_node)? + } + }; let syscall_node = MastNode::new_syscall(callee_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(syscall_node) + mast_forest_builder.ensure_node(syscall_node)? } } }; @@ -134,7 +141,7 @@ impl Assembler { &self, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); + let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?; Ok(Some(dyn_node_id)) } @@ -145,10 +152,10 @@ impl Assembler { mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { let dyn_call_node_id = { - let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn); + let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?; let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(dyn_call_node) + mast_forest_builder.ensure_node(dyn_call_node)? }; Ok(Some(dyn_call_node_id)) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 39c42c388b..a5854c8b01 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -3,7 +3,7 @@ use core::ops::Index; use alloc::collections::BTreeMap; use vm_core::{ crypto::hash::RpoDigest, - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode}, }; /// Builder for a [`MastForest`]. @@ -44,17 +44,17 @@ impl MastForestBuilder { /// If a [`MastNode`] which is equal to the current node was previously added, the previously /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal /// [`MastNode`]s have equal [`MastNodeId`]s. - pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId { + pub fn ensure_node(&mut self, node: MastNode) -> Result { let node_digest = node.digest(); if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { // node already exists in the forest; return previously assigned id - *node_id + Ok(*node_id) } else { - let new_node_id = self.mast_forest.add_node(node); + let new_node_id = self.mast_forest.add_node(node)?; self.node_id_by_hash.insert(node_digest, new_node_id); - new_node_id + Ok(new_node_id) } } diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index 26cdaa9dfe..b6bb051d30 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -766,8 +766,9 @@ impl Assembler { context, mast_forest_builder, )? { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder) + if let Some(basic_block_id) = basic_block_builder + .make_basic_block(mast_forest_builder) + .map_err(AssemblyError::from)? { mast_node_ids.push(basic_block_id); } @@ -779,8 +780,9 @@ impl Assembler { Op::If { then_blk, else_blk, .. } => { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder) + if let Some(basic_block_id) = basic_block_builder + .make_basic_block(mast_forest_builder) + .map_err(AssemblyError::from)? { mast_node_ids.push(basic_block_id); } @@ -794,14 +796,15 @@ impl Assembler { let split_node = MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(split_node) + mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)? }; mast_node_ids.push(split_node_id); } Op::Repeat { count, body, .. } => { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder) + if let Some(basic_block_id) = basic_block_builder + .make_basic_block(mast_forest_builder) + .map_err(AssemblyError::from)? { mast_node_ids.push(basic_block_id); } @@ -815,8 +818,9 @@ impl Assembler { } Op::While { body, .. } => { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder) + if let Some(basic_block_id) = basic_block_builder + .make_basic_block(mast_forest_builder) + .map_err(AssemblyError::from)? { mast_node_ids.push(basic_block_id); } @@ -827,22 +831,25 @@ impl Assembler { let loop_node_id = { let loop_node = MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest()); - mast_forest_builder.ensure_node(loop_node) + mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)? }; mast_node_ids.push(loop_node_id); } } } - if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest_builder) { + if let Some(basic_block_id) = basic_block_builder + .try_into_basic_block(mast_forest_builder) + .map_err(AssemblyError::from)? + { mast_node_ids.push(basic_block_id); } Ok(if mast_node_ids.is_empty() { let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]); - mast_forest_builder.ensure_node(basic_block_node) + mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)? } else { - combine_mast_node_ids(mast_node_ids, mast_forest_builder) + combine_mast_node_ids(mast_node_ids, mast_forest_builder)? }) } @@ -882,7 +889,7 @@ struct BodyWrapper { fn combine_mast_node_ids( mut mast_node_ids: Vec, mast_forest_builder: &mut MastForestBuilder, -) -> MastNodeId { +) -> Result { debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list"); // build a binary tree of blocks joining them using JOIN blocks @@ -901,7 +908,7 @@ fn combine_mast_node_ids( (source_mast_node_iter.next(), source_mast_node_iter.next()) { let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest()); - let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node); + let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?; mast_node_ids.push(join_mast_node_id); } @@ -910,5 +917,5 @@ fn combine_mast_node_ids( } } - mast_node_ids.remove(0) + Ok(mast_node_ids.remove(0)) } diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index ee68d98b27..e627d985ab 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -82,11 +82,11 @@ fn nested_blocks() { // `Assembler::with_kernel_from_module()`. let syscall_foo_node_id = { let kernel_foo_node = MastNode::new_basic_block(vec![Operation::Add]); - let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node); + let kernel_foo_node_id = expected_mast_forest_builder.ensure_node(kernel_foo_node).unwrap(); let syscall_node = MastNode::new_syscall(kernel_foo_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(syscall_node) + expected_mast_forest_builder.ensure_node(syscall_node).unwrap() }; let program = r#" @@ -130,63 +130,63 @@ fn nested_blocks() { let exec_bar_node_id = { // bar procedure let basic_block_1 = MastNode::new_basic_block(vec![Operation::Push(17_u32.into())]); - let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1); + let basic_block_1_id = expected_mast_forest_builder.ensure_node(basic_block_1).unwrap(); // Basic block representing the `foo` procedure let basic_block_2 = MastNode::new_basic_block(vec![Operation::Push(19_u32.into())]); - let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2); + let basic_block_2_id = expected_mast_forest_builder.ensure_node(basic_block_2).unwrap(); let join_node = MastNode::new_join( basic_block_1_id, basic_block_2_id, expected_mast_forest_builder.forest(), ); - expected_mast_forest_builder.ensure_node(join_node) + expected_mast_forest_builder.ensure_node(join_node).unwrap() }; let exec_foo_bar_baz_node_id = { // basic block representing foo::bar.baz procedure let basic_block = MastNode::new_basic_block(vec![Operation::Push(29_u32.into())]); - expected_mast_forest_builder.ensure_node(basic_block) + expected_mast_forest_builder.ensure_node(basic_block).unwrap() }; let before = { let before_node = MastNode::new_basic_block(vec![Operation::Push(2u32.into())]); - expected_mast_forest_builder.ensure_node(before_node) + expected_mast_forest_builder.ensure_node(before_node).unwrap() }; let r#true1 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(3u32.into())]); - expected_mast_forest_builder.ensure_node(r#true_node) + expected_mast_forest_builder.ensure_node(r#true_node).unwrap() }; let r#false1 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(5u32.into())]); - expected_mast_forest_builder.ensure_node(r#false_node) + expected_mast_forest_builder.ensure_node(r#false_node).unwrap() }; let r#if1 = { let r#if_node = MastNode::new_split(r#true1, r#false1, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node) + expected_mast_forest_builder.ensure_node(r#if_node).unwrap() }; let r#true3 = { let r#true_node = MastNode::new_basic_block(vec![Operation::Push(7u32.into())]); - expected_mast_forest_builder.ensure_node(r#true_node) + expected_mast_forest_builder.ensure_node(r#true_node).unwrap() }; let r#false3 = { let r#false_node = MastNode::new_basic_block(vec![Operation::Push(11u32.into())]); - expected_mast_forest_builder.ensure_node(r#false_node) + expected_mast_forest_builder.ensure_node(r#false_node).unwrap() }; let r#true2 = { let r#if_node = MastNode::new_split(r#true3, r#false3, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(r#if_node) + expected_mast_forest_builder.ensure_node(r#if_node).unwrap() }; let r#while = { let push_basic_block_id = { let push_basic_block = MastNode::new_basic_block(vec![Operation::Push(23u32.into())]); - expected_mast_forest_builder.ensure_node(push_basic_block) + expected_mast_forest_builder.ensure_node(push_basic_block).unwrap() }; let body_node_id = { let body_node = MastNode::new_join( @@ -195,15 +195,15 @@ fn nested_blocks() { expected_mast_forest_builder.forest(), ); - expected_mast_forest_builder.ensure_node(body_node) + expected_mast_forest_builder.ensure_node(body_node).unwrap() }; let loop_node = MastNode::new_loop(body_node_id, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(loop_node) + expected_mast_forest_builder.ensure_node(loop_node).unwrap() }; let push_13_basic_block_id = { let node = MastNode::new_basic_block(vec![Operation::Push(13u32.into())]); - expected_mast_forest_builder.ensure_node(node) + expected_mast_forest_builder.ensure_node(node).unwrap() }; let r#false2 = { @@ -212,17 +212,18 @@ fn nested_blocks() { r#while, expected_mast_forest_builder.forest(), ); - expected_mast_forest_builder.ensure_node(node) + expected_mast_forest_builder.ensure_node(node).unwrap() }; let nested = { let node = MastNode::new_split(r#true2, r#false2, expected_mast_forest_builder.forest()); - expected_mast_forest_builder.ensure_node(node) + expected_mast_forest_builder.ensure_node(node).unwrap() }; let combined_node_id = combine_mast_node_ids( vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id], &mut expected_mast_forest_builder, - ); + ) + .unwrap(); let expected_program = Program::new(expected_mast_forest_builder.build(), combined_node_id); assert_eq!(expected_program.hash(), program.hash()); @@ -281,26 +282,26 @@ fn duplicate_nodes() { // basic block: mul let mul_basic_block_id = { let node = MastNode::new_basic_block(vec![Operation::Mul]); - expected_mast_forest.add_node(node) + expected_mast_forest.add_node(node).unwrap() }; // basic block: add let add_basic_block_id = { let node = MastNode::new_basic_block(vec![Operation::Add]); - expected_mast_forest.add_node(node) + expected_mast_forest.add_node(node).unwrap() }; // inner split: `if.true add else mul end` let inner_split_id = { let node = MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest); - expected_mast_forest.add_node(node) + expected_mast_forest.add_node(node).unwrap() }; // root: outer split let root_id = { let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest); - expected_mast_forest.add_node(node) + expected_mast_forest.add_node(node).unwrap() }; expected_mast_forest.make_root(root_id); diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 5b51048292..648fd851ef 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -1,4 +1,5 @@ use alloc::{string::String, sync::Arc, vec::Vec}; +use vm_core::mast::MastForestError; use crate::{ ast::{FullyQualifiedProcedureName, ProcedureName}, @@ -135,7 +136,10 @@ pub enum AssemblyError { #[error(transparent)] #[diagnostic(transparent)] Other(#[from] RelatedError), + #[error(transparent)] + Forest(#[from] MastForestError), } + impl From for AssemblyError { fn from(report: Report) -> Self { Self::Other(RelatedError::new(report)) diff --git a/core/Cargo.toml b/core/Cargo.toml index d7b97938a4..d35bcf0ca7 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -19,7 +19,13 @@ doctest = false [features] default = ["std"] -std = ["miden-crypto/std", "miden-formatting/std", "math/std", "winter-utils/std", "thiserror/std"] +std = [ + "miden-crypto/std", + "miden-formatting/std", + "math/std", + "winter-utils/std", + "thiserror/std", +] [dependencies] math = { package = "winter-math", version = "0.9", default-features = false } diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 70a87c58c1..b997c9f051 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -77,6 +77,16 @@ impl Deserializable for MastNodeId { // MAST FOREST // ================================================================================================ +/// Represents the types of errors that can occur when dealing with MAST forest. +#[derive(Debug, thiserror::Error)] +pub enum MastForestError { + #[error( + "invalid node count: MAST forest exceeds the maximum of {} nodes", + MastForest::MAX_NODES + )] + TooManyNodes, +} + /// Represents one or more procedures, represented as a collection of [`MastNode`]s. /// /// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] @@ -100,20 +110,21 @@ impl MastForest { /// Mutators impl MastForest { + /// The maximum number of nodes that can be stored in a single MAST forest. + const MAX_NODES: usize = (1 << 30) - 1; + /// Adds a node to the forest, and returns the associated [`MastNodeId`]. /// /// Adding two duplicate nodes will result in two distinct returned [`MastNodeId`]s. - pub fn add_node(&mut self, node: MastNode) -> MastNodeId { - let new_node_id = MastNodeId( - self.nodes - .len() - .try_into() - .expect("invalid node id: exceeded maximum number of nodes in a single forest"), - ); + pub fn add_node(&mut self, node: MastNode) -> Result { + if self.nodes.len() == Self::MAX_NODES { + return Err(MastForestError::TooManyNodes); + } + let new_node_id = MastNodeId(self.nodes.len() as u32); self.nodes.push(node); - new_node_id + Ok(new_node_id) } /// Marks the given [`MastNodeId`] as being the root of a procedure. diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index a71c79c89a..e23cdbd2fb 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -120,7 +120,11 @@ impl Deserializable for MastForest { let node = mast_node_info.try_into_mast_node(&mast_forest, &basic_block_data_decoder)?; - mast_forest.add_node(node); + mast_forest.add_node(node).map_err(|e| { + DeserializationError::InvalidValue(format!( + "failed to add node to MAST forest while deserializing: {e}", + )) + })?; } for root in roots { diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 4f0c56bcb8..69d746ec38 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -296,39 +296,39 @@ fn serialize_deserialize_all_nodes() { ]; let basic_block_node = MastNode::new_basic_block_with_decorators(operations, decorators); - mast_forest.add_node(basic_block_node) + mast_forest.add_node(basic_block_node).unwrap() }; let call_node_id = { let node = MastNode::new_call(basic_block_id, &mast_forest); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; let syscall_node_id = { let node = MastNode::new_syscall(basic_block_id, &mast_forest); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; let loop_node_id = { let node = MastNode::new_loop(basic_block_id, &mast_forest); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; let join_node_id = { let node = MastNode::new_join(basic_block_id, call_node_id, &mast_forest); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; let split_node_id = { let node = MastNode::new_split(basic_block_id, call_node_id, &mast_forest); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; let dyn_node_id = { let node = MastNode::new_dynexec(); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; let external_node_id = { let node = MastNode::new_external(RpoDigest::default()); - mast_forest.add_node(node) + mast_forest.add_node(node).unwrap() }; mast_forest.make_root(join_node_id); diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index b0736ea74e..379a974652 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -161,7 +161,7 @@ fn build_bar_hash() -> [u64; 4] { let mut mast_forest = MastForest::new(); let foo_root = MastNode::new_basic_block(vec![Operation::Caller]); - let foo_root_id = mast_forest.add_node(foo_root); + let foo_root_id = mast_forest.add_node(foo_root).unwrap(); let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest); let bar_hash: Word = bar_root.digest().into(); diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 912aeadc28..44f814ac58 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -249,19 +249,19 @@ fn hash_memoization_control_blocks() { let mut mast_forest = MastForest::new(); let t_branch = MastNode::new_basic_block(vec![Operation::Push(ZERO)]); - let t_branch_id = mast_forest.add_node(t_branch.clone()); + let t_branch_id = mast_forest.add_node(t_branch.clone()).unwrap(); let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]); - let f_branch_id = mast_forest.add_node(f_branch.clone()); + let f_branch_id = mast_forest.add_node(f_branch.clone()).unwrap(); let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split1_id = mast_forest.add_node(split1.clone()); + let split1_id = mast_forest.add_node(split1.clone()).unwrap(); let split2 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split2_id = mast_forest.add_node(split2.clone()); + let split2_id = mast_forest.add_node(split2.clone()).unwrap(); let join_node = MastNode::new_join(split1_id, split2_id, &mast_forest); - let _join_node_id = mast_forest.add_node(join_node.clone()); + let _join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); let mut hasher = Hasher::default(); let h1: [Felt; DIGEST_LEN] = split1 @@ -414,19 +414,19 @@ fn hash_memoization_basic_blocks_check(basic_block: MastNode) { let mut mast_forest = MastForest::new(); let basic_block_1 = basic_block.clone(); - let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Eq, Operation::Not]); - let loop_body_id = mast_forest.add_node(loop_body); + let loop_body_id = mast_forest.add_node(loop_body).unwrap(); let loop_block = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_block_id = mast_forest.add_node(loop_block.clone()); + let loop_block_id = mast_forest.add_node(loop_block.clone()).unwrap(); let join2_block = MastNode::new_join(basic_block_1_id, loop_block_id, &mast_forest); - let join2_block_id = mast_forest.add_node(join2_block.clone()); + let join2_block_id = mast_forest.add_node(join2_block.clone()).unwrap(); let basic_block_2 = basic_block; - let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let join1_block = MastNode::new_join(join2_block_id, basic_block_2_id, &mast_forest); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index 8353376989..f45cf25e27 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -120,7 +120,7 @@ fn build_trace( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); Program::new(mast_forest, basic_block_id) }; diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 598db1abdf..0cbdbd8f82 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -50,7 +50,7 @@ fn basic_block_one_group() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.add_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -96,7 +96,7 @@ fn basic_block_small() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.add_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -159,7 +159,7 @@ fn basic_block() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.add_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -251,7 +251,7 @@ fn span_block_with_respan() { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); - let basic_block_id = mast_forest.add_node(basic_block_node); + let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -324,11 +324,11 @@ fn join_node() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.add_node(basic_block1.clone()); - let basic_block2_id = mast_forest.add_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap(); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); - let join_node_id = mast_forest.add_node(join_node); + let join_node_id = mast_forest.add_node(join_node).unwrap(); Program::new(mast_forest, join_node_id) }; @@ -390,11 +390,11 @@ fn split_node_true() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.add_node(basic_block1.clone()); - let basic_block2_id = mast_forest.add_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap(); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.add_node(split_node); + let split_node_id = mast_forest.add_node(split_node).unwrap(); Program::new(mast_forest, split_node_id) }; @@ -443,11 +443,11 @@ fn split_node_false() { let program = { let mut mast_forest = MastForest::new(); - let basic_block1_id = mast_forest.add_node(basic_block1.clone()); - let basic_block2_id = mast_forest.add_node(basic_block2.clone()); + let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap(); + let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.add_node(split_node); + let split_node_id = mast_forest.add_node(split_node).unwrap(); Program::new(mast_forest, split_node_id) }; @@ -498,10 +498,10 @@ fn loop_node() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.add_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -551,10 +551,10 @@ fn loop_node_skip() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.add_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -594,10 +594,10 @@ fn loop_node_repeat() { let program = { let mut mast_forest = MastForest::new(); - let loop_body_id = mast_forest.add_node(loop_body.clone()); + let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -683,24 +683,24 @@ fn call_block() { Operation::FmpUpdate, Operation::Pad, ]); - let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()); + let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap(); let foo_root_node = MastNode::new_basic_block(vec![ Operation::Push(ONE), Operation::FmpUpdate ]); - let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); + let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap(); let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); - let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()); + let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest); - let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()); + let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()).unwrap(); let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest); - let join1_node_id = mast_forest.add_node(join1_node.clone()); + let join1_node_id = mast_forest.add_node(join1_node.clone()).unwrap(); let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); - let program_root_id = mast_forest.add_node(program_root); + let program_root_id = mast_forest.add_node(program_root).unwrap(); let program = Program::new(mast_forest, program_root_id); @@ -893,19 +893,19 @@ fn syscall_block() { // build foo procedure body let foo_root = MastNode::new_basic_block(vec![Operation::Push(THREE), Operation::FmpUpdate]); - let foo_root_id = mast_forest.add_node(foo_root.clone()); + let foo_root_id = mast_forest.add_node(foo_root.clone()).unwrap(); mast_forest.make_root(foo_root_id); let kernel = Kernel::new(&[foo_root.digest()]).unwrap(); // build bar procedure body let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); - let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()); + let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()).unwrap(); let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest); - let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()); + let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()).unwrap(); let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest); - let bar_root_node_id = mast_forest.add_node(bar_root_node.clone()); + let bar_root_node_id = mast_forest.add_node(bar_root_node.clone()).unwrap(); mast_forest.make_root(bar_root_node_id); // build the program @@ -914,19 +914,19 @@ fn syscall_block() { Operation::FmpUpdate, Operation::Pad, ]); - let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()); + let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap(); let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); - let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()); + let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest); - let bar_call_node_id = mast_forest.add_node(bar_call_node.clone()); + let bar_call_node_id = mast_forest.add_node(bar_call_node.clone()).unwrap(); let inner_join_node = MastNode::new_join(first_basic_block_id, bar_call_node_id, &mast_forest); - let inner_join_node_id = mast_forest.add_node(inner_join_node.clone()); + let inner_join_node_id = mast_forest.add_node(inner_join_node.clone()).unwrap(); let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); - let program_root_node_id = mast_forest.add_node(program_root_node.clone()); + let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); let program = Program::with_kernel(mast_forest, program_root_node_id, kernel.clone()); @@ -1181,24 +1181,24 @@ fn dyn_block() { let mut mast_forest = MastForest::new(); let foo_root_node = MastNode::new_basic_block(vec![Operation::Push(ONE), Operation::Add]); - let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()); + let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap(); mast_forest.make_root(foo_root_node_id); let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul]); - let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()); + let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()).unwrap(); let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]); - let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()); + let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()).unwrap(); let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest); - let join_node_id = mast_forest.add_node(join_node.clone()); + let join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); // This dyn will point to foo. let dyn_node = MastNode::new_dynexec(); - let dyn_node_id = mast_forest.add_node(dyn_node.clone()); + let dyn_node_id = mast_forest.add_node(dyn_node.clone()).unwrap(); let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); - let program_root_node_id = mast_forest.add_node(program_root_node.clone()); + let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); let program = Program::new(mast_forest, program_root_node_id); @@ -1306,7 +1306,7 @@ fn set_user_op_helpers_many() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); Program::new(mast_forest, basic_block_id) }; diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 992cee60d4..84e29fe0c2 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -51,7 +51,7 @@ pub fn b_chip_span() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -124,7 +124,7 @@ pub fn b_chip_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -216,13 +216,13 @@ pub fn b_chip_merge() { let mut mast_forest = MastForest::new(); let t_branch = MastNode::new_basic_block(vec![Operation::Add]); - let t_branch_id = mast_forest.add_node(t_branch); + let t_branch_id = mast_forest.add_node(t_branch).unwrap(); let f_branch = MastNode::new_basic_block(vec![Operation::Mul]); - let f_branch_id = mast_forest.add_node(f_branch); + let f_branch_id = mast_forest.add_node(f_branch).unwrap(); let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split_id = mast_forest.add_node(split); + let split_id = mast_forest.add_node(split).unwrap(); Program::new(mast_forest, split_id) }; @@ -335,7 +335,7 @@ pub fn b_chip_permutation() { let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); Program::new(mast_forest, basic_block_id) }; diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 3b20428865..83246bdf5d 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -73,13 +73,13 @@ fn decoder_p1_join() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.add_node(join); + let join_id = mast_forest.add_node(join).unwrap(); Program::new(mast_forest, join_id) }; @@ -144,13 +144,13 @@ fn decoder_p1_split() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.add_node(split); + let split_id = mast_forest.add_node(split).unwrap(); Program::new(mast_forest, split_id) }; @@ -202,16 +202,16 @@ fn decoder_p1_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.add_node(basic_block_1); + let basic_block_1_id = mast_forest.add_node(basic_block_1).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.add_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.add_node(join); + let join_id = mast_forest.add_node(join).unwrap(); let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -333,7 +333,7 @@ fn decoder_p2_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -369,13 +369,13 @@ fn decoder_p2_join() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.add_node(join.clone()); + let join_id = mast_forest.add_node(join.clone()).unwrap(); let program = Program::new(mast_forest, join_id); @@ -434,13 +434,13 @@ fn decoder_p2_split_true() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2); + let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.add_node(split); + let split_id = mast_forest.add_node(split).unwrap(); let program = Program::new(mast_forest, split_id); @@ -490,13 +490,13 @@ fn decoder_p2_split_false() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.add_node(split); + let split_id = mast_forest.add_node(split).unwrap(); let program = Program::new(mast_forest, split_id); @@ -546,16 +546,16 @@ fn decoder_p2_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()); + let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()); + let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.add_node(join.clone()); + let join_id = mast_forest.add_node(join.clone()).unwrap(); let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node); + let loop_node_id = mast_forest.add_node(loop_node).unwrap(); let program = Program::new(mast_forest, loop_node_id); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 19c11defbc..3d9e4c1f85 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -35,7 +35,7 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); let program = Program::new(mast_forest, basic_block_id); @@ -57,7 +57,7 @@ pub fn build_trace_from_ops_with_inputs( let mut mast_forest = MastForest::new(); let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.add_node(basic_block); + let basic_block_id = mast_forest.add_node(basic_block).unwrap(); let program = Program::new(mast_forest, basic_block_id);