From c8a7b85db24d2432a4e6086d2f71c86f26f32caf Mon Sep 17 00:00:00 2001 From: sergerad Date: Wed, 24 Jul 2024 20:18:48 +1200 Subject: [PATCH 1/7] add helper methods for adding nodes to MastForest --- assembly/src/assembler/tests.rs | 23 ++------ core/src/mast/mod.rs | 59 +++++++++++++++++++ core/src/mast/serialization/tests.rs | 38 +++--------- .../integration/operations/io_ops/env_ops.rs | 3 +- processor/src/chiplets/hasher/tests.rs | 5 +- processor/src/chiplets/tests.rs | 3 +- processor/src/decoder/tests.rs | 24 +++----- processor/src/trace/tests/chiplets/hasher.rs | 21 +++---- processor/src/trace/tests/decoder.rs | 45 +++++--------- processor/src/trace/tests/mod.rs | 6 +- 10 files changed, 111 insertions(+), 116 deletions(-) diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index b0c62bb446..9fbb410eb0 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -248,29 +248,18 @@ fn duplicate_nodes() { let mut expected_mast_forest = MastForest::new(); // basic block: mul - let mul_basic_block_id = { - let node = MastNode::new_basic_block(vec![Operation::Mul]); - expected_mast_forest.add_node(node).unwrap() - }; + let mul_basic_block_id = expected_mast_forest.add_block(vec![Operation::Mul], None).unwrap(); // basic block: add - let add_basic_block_id = { - let node = MastNode::new_basic_block(vec![Operation::Add]); - expected_mast_forest.add_node(node).unwrap() - }; + let add_basic_block_id = expected_mast_forest.add_block(vec![Operation::Add], None).unwrap(); // inner split: `if.true add else mul end` - let inner_split_id = { - let node = - MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest); - expected_mast_forest.add_node(node).unwrap() - }; + let inner_split_id = + expected_mast_forest.add_split(add_basic_block_id, mul_basic_block_id).unwrap(); // root: outer split - let root_id = { - let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest); - expected_mast_forest.add_node(node).unwrap() - }; + let root_id = expected_mast_forest.add_split(mul_basic_block_id, inner_split_id).unwrap(); + expected_mast_forest.make_root(root_id); let expected_program = Program::new(expected_mast_forest, root_id); diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index a6d73fe29a..82eb6f8e96 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -10,6 +10,8 @@ pub use node::{ }; use winter_utils::DeserializationError; +use crate::{DecoratorList, Operation}; + mod serialization; #[cfg(test)] @@ -60,6 +62,63 @@ impl MastForest { Ok(new_node_id) } + /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_block( + &mut self, + operations: Vec, + decorators: Option, + ) -> Result { + match decorators { + Some(decorators) => { + self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators)) + } + None => self.add_node(MastNode::new_basic_block(operations)), + } + } + + /// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_join( + &mut self, + left_child: MastNodeId, + right_child: MastNodeId, + ) -> Result { + self.add_node(MastNode::new_join(left_child, right_child, self)) + } + + /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_split( + &mut self, + if_branch: MastNodeId, + else_branch: MastNodeId, + ) -> Result { + self.add_node(MastNode::new_split(if_branch, else_branch, self)) + } + + /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_loop(&mut self, body: MastNodeId) -> Result { + self.add_node(MastNode::new_loop(body, self)) + } + + /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_call(&mut self, callee: MastNodeId) -> Result { + self.add_node(MastNode::new_call(callee, self)) + } + + /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_syscall(&mut self, callee: MastNodeId) -> Result { + self.add_node(MastNode::new_syscall(callee, self)) + } + + /// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_dyn(&mut self) -> Result { + self.add_node(MastNode::new_dyn()) + } + + /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_external(&mut self, mast_root: RpoDigest) -> Result { + self.add_node(MastNode::new_external(mast_root)) + } + /// Marks the given [`MastNodeId`] as being the root of a procedure. /// /// # Panics diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index e07e92296d..1d3e2d90be 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -295,41 +295,19 @@ fn serialize_deserialize_all_nodes() { (num_operations, Decorator::Trace(55)), ]; - let basic_block_node = MastNode::new_basic_block_with_decorators(operations, decorators); - mast_forest.add_node(basic_block_node).unwrap() + mast_forest.add_block(operations, Some(decorators)).unwrap() }; - let call_node_id = { - let node = MastNode::new_call(basic_block_id, &mast_forest); - mast_forest.add_node(node).unwrap() - }; + let call_node_id = mast_forest.add_call(basic_block_id).unwrap(); - let syscall_node_id = { - let node = MastNode::new_syscall(basic_block_id, &mast_forest); - mast_forest.add_node(node).unwrap() - }; + let syscall_node_id = mast_forest.add_syscall(basic_block_id).unwrap(); - let loop_node_id = { - let node = MastNode::new_loop(basic_block_id, &mast_forest); - mast_forest.add_node(node).unwrap() - }; - let join_node_id = { - let node = MastNode::new_join(basic_block_id, call_node_id, &mast_forest); - mast_forest.add_node(node).unwrap() - }; - let split_node_id = { - let node = MastNode::new_split(basic_block_id, call_node_id, &mast_forest); - mast_forest.add_node(node).unwrap() - }; - let dyn_node_id = { - let node = MastNode::new_dyn(); - mast_forest.add_node(node).unwrap() - }; + let loop_node_id = mast_forest.add_loop(basic_block_id).unwrap(); + let join_node_id = mast_forest.add_join(basic_block_id, call_node_id).unwrap(); + let split_node_id = mast_forest.add_split(basic_block_id, call_node_id).unwrap(); + let dyn_node_id = mast_forest.add_dyn().unwrap(); - let external_node_id = { - let node = MastNode::new_external(RpoDigest::default()); - mast_forest.add_node(node).unwrap() - }; + let external_node_id = mast_forest.add_external(RpoDigest::default()).unwrap(); mast_forest.make_root(join_node_id); mast_forest.make_root(syscall_node_id); diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 367a121df9..733be8cc2e 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -162,8 +162,7 @@ fn caller() { fn build_bar_hash() -> [u64; 4] { let mut mast_forest = MastForest::new(); - let foo_root = MastNode::new_basic_block(vec![Operation::Caller]); - let foo_root_id = mast_forest.add_node(foo_root).unwrap(); + let foo_root_id = mast_forest.add_block(vec![Operation::Caller], None).unwrap(); let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest); let bar_hash: Word = bar_root.digest().into(); diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 69e70d581f..ef47f9dfc9 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -416,8 +416,9 @@ fn hash_memoization_basic_blocks_check(basic_block: MastNode) { let basic_block_1 = basic_block.clone(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let loop_body = MastNode::new_basic_block(vec![Operation::Pad, Operation::Eq, Operation::Not]); - let loop_body_id = mast_forest.add_node(loop_body).unwrap(); + let loop_body_id = mast_forest + .add_block(vec![Operation::Pad, Operation::Eq, Operation::Not], None) + .unwrap(); let loop_block = MastNode::new_loop(loop_body_id, &mast_forest); let loop_block_id = mast_forest.add_node(loop_block.clone()).unwrap(); diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index f45cf25e27..075c2e42e7 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -119,8 +119,7 @@ fn build_trace( let program = { let mut mast_forest = MastForest::new(); - let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(operations, None).unwrap(); Program::new(mast_forest, basic_block_id) }; diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 8e84ecc5be..281e4d4f0f 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -327,8 +327,7 @@ fn join_node() { let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap(); let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); - let join_node = MastNode::new_join(basic_block1_id, basic_block2_id, &mast_forest); - let join_node_id = mast_forest.add_node(join_node).unwrap(); + let join_node_id = mast_forest.add_join(basic_block1_id, basic_block2_id).unwrap(); Program::new(mast_forest, join_node_id) }; @@ -393,8 +392,7 @@ fn split_node_true() { let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap(); let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); - let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.add_node(split_node).unwrap(); + let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap(); Program::new(mast_forest, split_node_id) }; @@ -446,8 +444,7 @@ fn split_node_false() { let basic_block1_id = mast_forest.add_node(basic_block1.clone()).unwrap(); let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); - let split_node = MastNode::new_split(basic_block1_id, basic_block2_id, &mast_forest); - let split_node_id = mast_forest.add_node(split_node).unwrap(); + let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap(); Program::new(mast_forest, split_node_id) }; @@ -500,8 +497,7 @@ fn loop_node() { let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node).unwrap(); + let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -553,8 +549,7 @@ fn loop_node_skip() { let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node).unwrap(); + let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -596,8 +591,7 @@ fn loop_node_repeat() { let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node = MastNode::new_loop(loop_body_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node).unwrap(); + let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -699,8 +693,7 @@ fn call_block() { let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest); let join1_node_id = mast_forest.add_node(join1_node.clone()).unwrap(); - let program_root = MastNode::new_join(join1_node_id, last_basic_block_id, &mast_forest); - let program_root_id = mast_forest.add_node(program_root).unwrap(); + let program_root_id = mast_forest.add_join(join1_node_id, last_basic_block_id).unwrap(); let program = Program::new(mast_forest, program_root_id); @@ -1305,8 +1298,7 @@ fn set_user_op_helpers_many() { let program = { let mut mast_forest = MastForest::new(); - let basic_block = MastNode::new_basic_block(vec![Operation::U32div]); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(vec![Operation::U32div], None).unwrap(); Program::new(mast_forest, basic_block_id) }; diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 84e29fe0c2..38be61b933 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -22,7 +22,7 @@ use miden_air::trace::{ use vm_core::{ chiplets::hasher::apply_permutation, crypto::merkle::{MerkleStore, MerkleTree, NodeIndex}, - mast::{MastForest, MastNode}, + mast::MastForest, utils::range, Program, Word, }; @@ -50,8 +50,8 @@ pub fn b_chip_span() { let program = { let mut mast_forest = MastForest::new(); - let basic_block = MastNode::new_basic_block(vec![Operation::Add, Operation::Mul]); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = + mast_forest.add_block(vec![Operation::Add, Operation::Mul], None).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -123,8 +123,7 @@ pub fn b_chip_span_with_respan() { let mut mast_forest = MastForest::new(); let (ops, _) = build_span_with_respan_ops(); - let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(ops, None).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -215,14 +214,11 @@ pub fn b_chip_merge() { let program = { let mut mast_forest = MastForest::new(); - let t_branch = MastNode::new_basic_block(vec![Operation::Add]); - let t_branch_id = mast_forest.add_node(t_branch).unwrap(); + let t_branch_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let f_branch = MastNode::new_basic_block(vec![Operation::Mul]); - let f_branch_id = mast_forest.add_node(f_branch).unwrap(); + let f_branch_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let split = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); - let split_id = mast_forest.add_node(split).unwrap(); + let split_id = mast_forest.add_split(t_branch_id, f_branch_id).unwrap(); Program::new(mast_forest, split_id) }; @@ -334,8 +330,7 @@ pub fn b_chip_permutation() { let program = { let mut mast_forest = MastForest::new(); - let basic_block = MastNode::new_basic_block(vec![Operation::HPerm]); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(vec![Operation::HPerm], None).unwrap(); Program::new(mast_forest, basic_block_id) }; diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index c5acde3241..8bf191471c 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -72,14 +72,11 @@ fn decoder_p1_join() { let program = { let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1).unwrap(); + let basic_block_1_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); + let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.add_node(join).unwrap(); + let join_id = mast_forest.add_join(basic_block_1_id, basic_block_2_id).unwrap(); Program::new(mast_forest, join_id) }; @@ -143,14 +140,11 @@ fn decoder_p1_split() { let program = { let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); - let basic_block_1_id = mast_forest.add_node(basic_block_1).unwrap(); + let basic_block_1_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); + let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.add_node(split).unwrap(); + let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); Program::new(mast_forest, split_id) }; @@ -201,17 +195,13 @@ fn decoder_p1_loop_with_repeat() { let program = { let mut mast_forest = MastForest::new(); - let basic_block_1 = MastNode::new_basic_block(vec![Operation::Pad]); - let basic_block_1_id = mast_forest.add_node(basic_block_1).unwrap(); + let basic_block_1_id = mast_forest.add_block(vec![Operation::Pad], None).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); - let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); + let basic_block_2_id = mast_forest.add_block(vec![Operation::Drop], None).unwrap(); - let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); - let join_id = mast_forest.add_node(join).unwrap(); + let join_id = mast_forest.add_join(basic_block_1_id, basic_block_2_id).unwrap(); - let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node).unwrap(); + let loop_node_id = mast_forest.add_loop(join_id).unwrap(); Program::new(mast_forest, loop_node_id) }; @@ -332,8 +322,7 @@ fn decoder_p2_span_with_respan() { let mut mast_forest = MastForest::new(); let (ops, _) = build_span_with_respan_ops(); - let basic_block = MastNode::new_basic_block(ops); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(ops, None).unwrap(); Program::new(mast_forest, basic_block_id) }; @@ -436,11 +425,9 @@ fn decoder_p2_split_true() { let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul]); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); - let basic_block_2_id = mast_forest.add_node(basic_block_2).unwrap(); + let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.add_node(split).unwrap(); + let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); let program = Program::new(mast_forest, split_id); @@ -495,8 +482,7 @@ fn decoder_p2_split_false() { let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); - let split = MastNode::new_split(basic_block_1_id, basic_block_2_id, &mast_forest); - let split_id = mast_forest.add_node(split).unwrap(); + let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); let program = Program::new(mast_forest, split_id); @@ -554,8 +540,7 @@ fn decoder_p2_loop_with_repeat() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); let join_id = mast_forest.add_node(join.clone()).unwrap(); - let loop_node = MastNode::new_loop(join_id, &mast_forest); - let loop_node_id = mast_forest.add_node(loop_node).unwrap(); + let loop_node_id = mast_forest.add_loop(join_id).unwrap(); let program = Program::new(mast_forest, loop_node_id); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 3d9e4c1f85..37dd2759da 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -34,8 +34,7 @@ pub fn build_trace_from_program(program: &Program, stack_inputs: &[u64]) -> Exec pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> ExecutionTrace { let mut mast_forest = MastForest::new(); - let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(operations, None).unwrap(); let program = Program::new(mast_forest, basic_block_id); @@ -56,8 +55,7 @@ pub fn build_trace_from_ops_with_inputs( Process::new(Kernel::default(), stack_inputs, host, ExecutionOptions::default()); let mut mast_forest = MastForest::new(); - let basic_block = MastNode::new_basic_block(operations); - let basic_block_id = mast_forest.add_node(basic_block).unwrap(); + let basic_block_id = mast_forest.add_block(operations, None).unwrap(); let program = Program::new(mast_forest, basic_block_id); From 674161011d29355984cd9415e32615281ef61db2 Mon Sep 17 00:00:00 2001 From: sergerad Date: Thu, 25 Jul 2024 05:28:18 +1200 Subject: [PATCH 2/7] changelog and clippy --- CHANGELOG.md | 8 ++++---- assembly/src/assembler/tests.rs | 6 +----- processor/src/chiplets/tests.rs | 5 +---- processor/src/trace/tests/mod.rs | 5 +---- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ecdb69af1..40e0a161c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,10 +19,10 @@ - Added serialization/deserialization for `MastForest` (#1370) - Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406) - Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394) -- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404) -- Make `Assembler` single-use (#1409) -- Remove `ProcedureCache` from the assembler (#1411). -- Add `Assembler::assemble_library()` (#1413) +- Added functions to `MastForest` and `MastForestBuilder` to add and ensure nodes with fewer LOC (#1404, #1412) +- Made `Assembler` single-use (#1409) +- Removed `ProcedureCache` from the assembler (#1411). +- Added `Assembler::assemble_library()` (#1413) #### Changed diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 9fbb410eb0..18a8d58ff2 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,10 +1,6 @@ use alloc::{boxed::Box, vec::Vec}; use pretty_assertions::assert_eq; -use vm_core::{ - assert_matches, - mast::{MastForest, MastNode}, - Program, -}; +use vm_core::{assert_matches, mast::MastForest, Program}; use super::{Assembler, Library, Operation}; use crate::{ diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index 075c2e42e7..65f57efab1 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -12,10 +12,7 @@ use miden_air::trace::{ }, CHIPLETS_RANGE, CHIPLETS_WIDTH, }; -use vm_core::{ - mast::{MastForest, MastNode}, - Felt, Program, ONE, ZERO, -}; +use vm_core::{mast::MastForest, Felt, Program, ONE, ZERO}; type ChipletsTrace = [Vec; CHIPLETS_WIDTH]; diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index 37dd2759da..d8017ecdd2 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -5,10 +5,7 @@ use super::{ use crate::{AdviceInputs, DefaultHost, ExecutionOptions, MemAdviceProvider, StackInputs}; use alloc::vec::Vec; use test_utils::rand::rand_array; -use vm_core::{ - mast::{MastForest, MastNode}, - Kernel, Operation, Program, StackOutputs, Word, ONE, ZERO, -}; +use vm_core::{mast::MastForest, Kernel, Operation, Program, StackOutputs, Word, ONE, ZERO}; mod chiplets; mod decoder; From eb551eb6adc310ed611aa004b011c653fce236e9 Mon Sep 17 00:00:00 2001 From: sergerad Date: Fri, 26 Jul 2024 05:35:29 +1200 Subject: [PATCH 3/7] add InvalidNodeId error --- core/src/mast/mod.rs | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 82eb6f8e96..1a723bdd15 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -82,7 +82,12 @@ impl MastForest { left_child: MastNodeId, right_child: MastNodeId, ) -> Result { - self.add_node(MastNode::new_join(left_child, right_child, self)) + match self.add_node(MastNode::new_join(left_child, right_child, self))? { + node_id if node_id <= left_child || node_id <= right_child => { + Err(MastForestError::InvalidNodeId(node_id)) + } + node_id => Ok(node_id), + } } /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. @@ -91,22 +96,36 @@ impl MastForest { if_branch: MastNodeId, else_branch: MastNodeId, ) -> Result { - self.add_node(MastNode::new_split(if_branch, else_branch, self)) + match self.add_node(MastNode::new_split(if_branch, else_branch, self))? { + node_id if node_id <= if_branch || node_id <= else_branch => { + Err(MastForestError::InvalidNodeId(node_id)) + } + node_id => Ok(node_id), + } } /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_loop(&mut self, body: MastNodeId) -> Result { - self.add_node(MastNode::new_loop(body, self)) + match self.add_node(MastNode::new_loop(body, self))? { + node_id if node_id <= body => Err(MastForestError::InvalidNodeId(node_id)), + node_id => Ok(node_id), + } } /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_call(&mut self, callee: MastNodeId) -> Result { - self.add_node(MastNode::new_call(callee, self)) + match self.add_node(MastNode::new_call(callee, self))? { + node_id if node_id <= callee => Err(MastForestError::InvalidNodeId(node_id)), + node_id => Ok(node_id), + } } /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_syscall(&mut self, callee: MastNodeId) -> Result { - self.add_node(MastNode::new_syscall(callee, self)) + match self.add_node(MastNode::new_syscall(callee, self))? { + node_id if node_id <= callee => Err(MastForestError::InvalidNodeId(node_id)), + node_id => Ok(node_id), + } } /// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it. @@ -245,4 +264,6 @@ pub enum MastForestError { MastForest::MAX_NODES )] TooManyNodes, + #[error("invalid node id: {0}")] + InvalidNodeId(MastNodeId), } From 1d319584e5c7a9a8fb8c00bddbcf14e4fb4aee70 Mon Sep 17 00:00:00 2001 From: sergerad Date: Fri, 26 Jul 2024 07:21:26 +1200 Subject: [PATCH 4/7] fix invalid node id --- core/src/mast/mod.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 1a723bdd15..4a3561172b 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -83,10 +83,9 @@ impl MastForest { right_child: MastNodeId, ) -> Result { match self.add_node(MastNode::new_join(left_child, right_child, self))? { - node_id if node_id <= left_child || node_id <= right_child => { - Err(MastForestError::InvalidNodeId(node_id)) - } - node_id => Ok(node_id), + new if new <= left_child => Err(MastForestError::InvalidNodeId(left_child)), + new if new <= right_child => Err(MastForestError::InvalidNodeId(right_child)), + new => Ok(new), } } @@ -97,34 +96,33 @@ impl MastForest { else_branch: MastNodeId, ) -> Result { match self.add_node(MastNode::new_split(if_branch, else_branch, self))? { - node_id if node_id <= if_branch || node_id <= else_branch => { - Err(MastForestError::InvalidNodeId(node_id)) - } - node_id => Ok(node_id), + new if new <= if_branch => Err(MastForestError::InvalidNodeId(if_branch)), + new if new <= else_branch => Err(MastForestError::InvalidNodeId(else_branch)), + new => Ok(new), } } /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_loop(&mut self, body: MastNodeId) -> Result { match self.add_node(MastNode::new_loop(body, self))? { - node_id if node_id <= body => Err(MastForestError::InvalidNodeId(node_id)), - node_id => Ok(node_id), + new if new <= body => Err(MastForestError::InvalidNodeId(body)), + new => Ok(new), } } /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_call(&mut self, callee: MastNodeId) -> Result { match self.add_node(MastNode::new_call(callee, self))? { - node_id if node_id <= callee => Err(MastForestError::InvalidNodeId(node_id)), - node_id => Ok(node_id), + new if new <= callee => Err(MastForestError::InvalidNodeId(callee)), + new => Ok(new), } } /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_syscall(&mut self, callee: MastNodeId) -> Result { match self.add_node(MastNode::new_syscall(callee, self))? { - node_id if node_id <= callee => Err(MastForestError::InvalidNodeId(node_id)), - node_id => Ok(node_id), + new if new <= callee => Err(MastForestError::InvalidNodeId(callee)), + new => Ok(new), } } From 8addc9fc484478862d66b481f424d3219dbd1c5a Mon Sep 17 00:00:00 2001 From: sergerad Date: Fri, 26 Jul 2024 07:46:37 +1200 Subject: [PATCH 5/7] add tests --- core/src/mast/mod.rs | 2 +- core/src/mast/serialization/tests.rs | 40 +++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 4a3561172b..fffc1956fb 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -255,7 +255,7 @@ impl fmt::Display for MastNodeId { // ================================================================================================ /// Represents the types of errors that can occur when dealing with MAST forest. -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq)] pub enum MastForestError { #[error( "invalid node count: MAST forest exceeds the maximum of {} nodes", diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index 1d3e2d90be..f2856848f5 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -3,7 +3,8 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt}; use super::*; use crate::{ - operations::Operation, AdviceInjector, AssemblyOp, DebugOptions, Decorator, SignatureKind, + mast::MastForestError, operations::Operation, AdviceInjector, AssemblyOp, DebugOptions, + Decorator, SignatureKind, }; /// If this test fails to compile, it means that `Operation` or `Decorator` was changed. Make sure @@ -321,3 +322,40 @@ fn serialize_deserialize_all_nodes() { assert_eq!(mast_forest, deserialized_mast_forest); } + +#[test] +fn mast_forest_invalid_node_id() { + // Hydrate a forest smaller than the second + let mut forest = MastForest::new(); + let first = forest.add_block(vec![Operation::U32div], None).unwrap(); + let second = forest.add_block(vec![Operation::U32div], None).unwrap(); + + // Hydrate a forest larger than the first to get an overflow MastNodeId + let mut overflow_forest = MastForest::new(); + let overflow = (0..=3) + .map(|_| overflow_forest.add_block(vec![Operation::U32div], None).unwrap()) + .last() + .unwrap(); + + // Attempt to join with invalid ids + let join = forest.add_join(overflow, second); + assert_eq!(join, Err(MastForestError::InvalidNodeId(overflow))); + let join = forest.add_join(first, overflow); + assert_eq!(join, Err(MastForestError::InvalidNodeId(overflow))); + + // Attempt to split with invalid ids + let split = forest.add_split(overflow, second); + assert_eq!(split, Err(MastForestError::InvalidNodeId(overflow))); + let split = forest.add_split(first, overflow); + assert_eq!(split, Err(MastForestError::InvalidNodeId(overflow))); + + // Attempt to loop with invalid ids + assert_eq!(forest.add_loop(overflow), Err(MastForestError::InvalidNodeId(overflow))); + + // Attempt to call with invalid ids + assert_eq!(forest.add_call(overflow), Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(forest.add_syscall(overflow), Err(MastForestError::InvalidNodeId(overflow))); + + // Validate normal operations + forest.add_join(first, second).unwrap(); +} From 21e62914f720cfed6ee200ce7c77eb6d0ad8e05a Mon Sep 17 00:00:00 2001 From: sergerad Date: Fri, 26 Jul 2024 21:05:22 +1200 Subject: [PATCH 6/7] return results from node constructors --- assembly/src/assembler/mast_forest_builder.rs | 15 ++++--- core/src/mast/mod.rs | 42 ++++++++----------- core/src/mast/node/call_node.rs | 23 ++++++---- core/src/mast/node/join_node.rs | 15 +++++-- core/src/mast/node/loop_node.rs | 9 ++-- core/src/mast/node/mod.rs | 30 ++++++++----- core/src/mast/node/split_node.rs | 15 +++++-- core/src/mast/serialization/info.rs | 31 +++++++++++--- core/src/mast/serialization/tests.rs | 14 +++---- .../integration/operations/io_ops/env_ops.rs | 2 +- processor/src/chiplets/hasher/tests.rs | 12 +++--- processor/src/decoder/tests.rs | 18 ++++---- processor/src/trace/tests/decoder.rs | 4 +- 13 files changed, 145 insertions(+), 85 deletions(-) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 035eccdf7b..7bd7863adf 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -168,7 +168,8 @@ impl MastForestBuilder { left_child: MastNodeId, right_child: MastNodeId, ) -> Result { - self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest)) + let join = MastNode::new_join(left_child, right_child, &self.mast_forest)?; + self.ensure_node(join) } /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. @@ -177,22 +178,26 @@ impl MastForestBuilder { if_branch: MastNodeId, else_branch: MastNodeId, ) -> Result { - self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest)) + let split = MastNode::new_split(if_branch, else_branch, &self.mast_forest)?; + self.ensure_node(split) } /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. pub fn ensure_loop(&mut self, body: MastNodeId) -> Result { - self.ensure_node(MastNode::new_loop(body, &self.mast_forest)) + let loop_node = MastNode::new_loop(body, &self.mast_forest)?; + self.ensure_node(loop_node) } /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. pub fn ensure_call(&mut self, callee: MastNodeId) -> Result { - self.ensure_node(MastNode::new_call(callee, &self.mast_forest)) + let call = MastNode::new_call(callee, &self.mast_forest)?; + self.ensure_node(call) } /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result { - self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest)) + let syscall = MastNode::new_syscall(callee, &self.mast_forest)?; + self.ensure_node(syscall) } /// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it. diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index fffc1956fb..8f3483aad0 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -82,11 +82,8 @@ impl MastForest { left_child: MastNodeId, right_child: MastNodeId, ) -> Result { - match self.add_node(MastNode::new_join(left_child, right_child, self))? { - new if new <= left_child => Err(MastForestError::InvalidNodeId(left_child)), - new if new <= right_child => Err(MastForestError::InvalidNodeId(right_child)), - new => Ok(new), - } + let join = MastNode::new_join(left_child, right_child, self)?; + self.add_node(join) } /// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it. @@ -95,35 +92,26 @@ impl MastForest { if_branch: MastNodeId, else_branch: MastNodeId, ) -> Result { - match self.add_node(MastNode::new_split(if_branch, else_branch, self))? { - new if new <= if_branch => Err(MastForestError::InvalidNodeId(if_branch)), - new if new <= else_branch => Err(MastForestError::InvalidNodeId(else_branch)), - new => Ok(new), - } + let split = MastNode::new_split(if_branch, else_branch, self)?; + self.add_node(split) } /// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_loop(&mut self, body: MastNodeId) -> Result { - match self.add_node(MastNode::new_loop(body, self))? { - new if new <= body => Err(MastForestError::InvalidNodeId(body)), - new => Ok(new), - } + let loop_node = MastNode::new_loop(body, self)?; + self.add_node(loop_node) } /// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_call(&mut self, callee: MastNodeId) -> Result { - match self.add_node(MastNode::new_call(callee, self))? { - new if new <= callee => Err(MastForestError::InvalidNodeId(callee)), - new => Ok(new), - } + let call = MastNode::new_call(callee, self)?; + self.add_node(call) } /// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_syscall(&mut self, callee: MastNodeId) -> Result { - match self.add_node(MastNode::new_syscall(callee, self))? { - new if new <= callee => Err(MastForestError::InvalidNodeId(callee)), - new => Ok(new), - } + let syscall = MastNode::new_syscall(callee, self)?; + self.add_node(syscall) } /// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it. @@ -233,6 +221,12 @@ impl MastNodeId { } } +impl From for usize { + fn from(value: MastNodeId) -> Self { + value.0 as usize + } +} + impl From for u32 { fn from(value: MastNodeId) -> Self { value.0 @@ -262,6 +256,6 @@ pub enum MastForestError { MastForest::MAX_NODES )] TooManyNodes, - #[error("invalid node id: {0}")] - InvalidNodeId(MastNodeId), + #[error("node id: {0} is greater than or equal to forest length: {1}")] + NodeIdOverflow(MastNodeId, usize), } diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index ca9c720195..fae63608c1 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId}, + mast::{MastForest, MastForestError, MastNodeId}, OPCODE_CALL, OPCODE_SYSCALL, }; @@ -38,34 +38,43 @@ impl CallNode { /// Constructors impl CallNode { /// Returns a new [`CallNode`] instantiated with the specified callee. - pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Self { + pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result { + if usize::from(callee) >= mast_forest.nodes.len() { + return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len())); + } let digest = { let callee_digest = mast_forest[callee].digest(); hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN) }; - Self { + Ok(Self { callee, is_syscall: false, digest, - } + }) } /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel /// call. - pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self { + pub fn new_syscall( + callee: MastNodeId, + mast_forest: &MastForest, + ) -> Result { + if usize::from(callee) >= mast_forest.nodes.len() { + return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len())); + } let digest = { let callee_digest = mast_forest[callee].digest(); hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN) }; - Self { + Ok(Self { callee, is_syscall: true, digest, - } + }) } } diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 5f802873dd..145e9c2e5b 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -4,7 +4,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt}; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId}, + mast::{MastForest, MastForestError, MastNodeId}, prettier::PrettyPrint, OPCODE_JOIN, }; @@ -29,7 +29,16 @@ impl JoinNode { /// Constructors impl JoinNode { /// Returns a new [`JoinNode`] instantiated with the specified children nodes. - pub fn new(children: [MastNodeId; 2], mast_forest: &MastForest) -> Self { + pub fn new( + children: [MastNodeId; 2], + mast_forest: &MastForest, + ) -> Result { + let forest_len = mast_forest.nodes.len(); + if usize::from(children[0]) >= forest_len { + return Err(MastForestError::NodeIdOverflow(children[0], forest_len)); + } else if usize::from(children[1]) >= forest_len { + return Err(MastForestError::NodeIdOverflow(children[1], forest_len)); + } let digest = { let left_child_hash = mast_forest[children[0]].digest(); let right_child_hash = mast_forest[children[1]].digest(); @@ -37,7 +46,7 @@ impl JoinNode { hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN) }; - Self { children, digest } + Ok(Self { children, digest }) } #[cfg(test)] diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index aec1b0b451..81309bdfb7 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId}, + mast::{MastForest, MastForestError, MastNodeId}, OPCODE_LOOP, }; @@ -32,14 +32,17 @@ impl LoopNode { /// Constructors impl LoopNode { - pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Self { + pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result { + if usize::from(body) >= mast_forest.nodes.len() { + return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len())); + } let digest = { let body_hash = mast_forest[body].digest(); hasher::merge_in_domain(&[body_hash, RpoDigest::default()], Self::DOMAIN) }; - Self { body, digest } + Ok(Self { body, digest }) } } diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index d84dcf2916..9a27ffc949 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -32,6 +32,8 @@ use crate::{ DecoratorList, Operation, }; +use super::MastForestError; + // MAST NODE // ================================================================================================ @@ -64,28 +66,36 @@ impl MastNode { left_child: MastNodeId, right_child: MastNodeId, mast_forest: &MastForest, - ) -> Self { - Self::Join(JoinNode::new([left_child, right_child], mast_forest)) + ) -> Result { + let join = JoinNode::new([left_child, right_child], mast_forest)?; + Ok(Self::Join(join)) } pub fn new_split( if_branch: MastNodeId, else_branch: MastNodeId, mast_forest: &MastForest, - ) -> Self { - Self::Split(SplitNode::new([if_branch, else_branch], mast_forest)) + ) -> Result { + let split = SplitNode::new([if_branch, else_branch], mast_forest)?; + Ok(Self::Split(split)) } - pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Self { - Self::Loop(LoopNode::new(body, mast_forest)) + pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result { + let loop_node = LoopNode::new(body, mast_forest)?; + Ok(Self::Loop(loop_node)) } - pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Self { - Self::Call(CallNode::new(callee, mast_forest)) + pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result { + let call = CallNode::new(callee, mast_forest)?; + Ok(Self::Call(call)) } - pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self { - Self::Call(CallNode::new_syscall(callee, mast_forest)) + pub fn new_syscall( + callee: MastNodeId, + mast_forest: &MastForest, + ) -> Result { + let syscall = CallNode::new_syscall(callee, mast_forest)?; + Ok(Self::Call(syscall)) } pub fn new_dyn() -> Self { diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index f754735e35..962dacae38 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId}, + mast::{MastForest, MastForestError, MastNodeId}, OPCODE_SPLIT, }; @@ -32,7 +32,16 @@ impl SplitNode { /// Constructors impl SplitNode { - pub fn new(branches: [MastNodeId; 2], mast_forest: &MastForest) -> Self { + pub fn new( + branches: [MastNodeId; 2], + mast_forest: &MastForest, + ) -> Result { + let forest_len = mast_forest.nodes.len(); + if usize::from(branches[0]) >= forest_len { + return Err(MastForestError::NodeIdOverflow(branches[0], forest_len)); + } else if usize::from(branches[1]) >= forest_len { + return Err(MastForestError::NodeIdOverflow(branches[1], forest_len)); + } let digest = { let if_branch_hash = mast_forest[branches[0]].digest(); let else_branch_hash = mast_forest[branches[1]].digest(); @@ -40,7 +49,7 @@ impl SplitNode { hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN) }; - Self { branches, digest } + Ok(Self { branches, digest }) } #[cfg(test)] diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 4646e9a20b..3729d67b8b 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -51,7 +51,12 @@ impl MastNodeInfo { let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; - Ok(MastNode::new_join(left_child, right_child, mast_forest)) + MastNode::new_join(left_child, right_child, mast_forest).map_err(|e| { + DeserializationError::InvalidValue(format!( + "Failed to deserialize Join node: {}", + e + )) + }) } MastNodeType::Split { if_branch_id, @@ -60,22 +65,38 @@ impl MastNodeInfo { let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; - Ok(MastNode::new_split(if_branch, else_branch, mast_forest)) + MastNode::new_split(if_branch, else_branch, mast_forest).map_err(|e| { + DeserializationError::InvalidValue(format!( + "Failed to deserialize Split node: {e}" + )) + }) } MastNodeType::Loop { body_id } => { let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; - Ok(MastNode::new_loop(body_id, mast_forest)) + MastNode::new_loop(body_id, mast_forest).map_err(|e| { + DeserializationError::InvalidValue(format!( + "Failed to deserialize Loop node: {e}" + )) + }) } MastNodeType::Call { callee_id } => { let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; - Ok(MastNode::new_call(callee_id, mast_forest)) + MastNode::new_call(callee_id, mast_forest).map_err(|e| { + DeserializationError::InvalidValue(format!( + "Failed to deserialize Call node: {e}" + )) + }) } MastNodeType::SysCall { callee_id } => { let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; - Ok(MastNode::new_syscall(callee_id, mast_forest)) + MastNode::new_syscall(callee_id, mast_forest).map_err(|e| { + DeserializationError::InvalidValue(format!( + "Failed to deserialize SysCall node: {e}" + )) + }) } MastNodeType::Dyn => Ok(MastNode::new_dyn()), MastNodeType::External => Ok(MastNode::new_external(self.digest)), diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index f2856848f5..6b0ad82ca5 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -339,22 +339,22 @@ fn mast_forest_invalid_node_id() { // Attempt to join with invalid ids let join = forest.add_join(overflow, second); - assert_eq!(join, Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(join, Err(MastForestError::NodeIdOverflow(overflow, 2))); let join = forest.add_join(first, overflow); - assert_eq!(join, Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(join, Err(MastForestError::NodeIdOverflow(overflow, 2))); // Attempt to split with invalid ids let split = forest.add_split(overflow, second); - assert_eq!(split, Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(split, Err(MastForestError::NodeIdOverflow(overflow, 2))); let split = forest.add_split(first, overflow); - assert_eq!(split, Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(split, Err(MastForestError::NodeIdOverflow(overflow, 2))); // Attempt to loop with invalid ids - assert_eq!(forest.add_loop(overflow), Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(forest.add_loop(overflow), Err(MastForestError::NodeIdOverflow(overflow, 2))); // Attempt to call with invalid ids - assert_eq!(forest.add_call(overflow), Err(MastForestError::InvalidNodeId(overflow))); - assert_eq!(forest.add_syscall(overflow), Err(MastForestError::InvalidNodeId(overflow))); + assert_eq!(forest.add_call(overflow), Err(MastForestError::NodeIdOverflow(overflow, 2))); + assert_eq!(forest.add_syscall(overflow), Err(MastForestError::NodeIdOverflow(overflow, 2))); // Validate normal operations forest.add_join(first, second).unwrap(); diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 733be8cc2e..58e2230954 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -164,7 +164,7 @@ fn build_bar_hash() -> [u64; 4] { let foo_root_id = mast_forest.add_block(vec![Operation::Caller], None).unwrap(); - let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest); + let bar_root = MastNode::new_syscall(foo_root_id, &mast_forest).unwrap(); let bar_hash: Word = bar_root.digest().into(); [ bar_hash[0].as_int(), diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index ef47f9dfc9..713816ebdb 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -254,13 +254,13 @@ fn hash_memoization_control_blocks() { let f_branch = MastNode::new_basic_block(vec![Operation::Push(ONE)]); let f_branch_id = mast_forest.add_node(f_branch.clone()).unwrap(); - let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); + let split1 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest).unwrap(); let split1_id = mast_forest.add_node(split1.clone()).unwrap(); - let split2 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest); + let split2 = MastNode::new_split(t_branch_id, f_branch_id, &mast_forest).unwrap(); let split2_id = mast_forest.add_node(split2.clone()).unwrap(); - let join_node = MastNode::new_join(split1_id, split2_id, &mast_forest); + let join_node = MastNode::new_join(split1_id, split2_id, &mast_forest).unwrap(); let _join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); let mut hasher = Hasher::default(); @@ -420,16 +420,16 @@ fn hash_memoization_basic_blocks_check(basic_block: MastNode) { .add_block(vec![Operation::Pad, Operation::Eq, Operation::Not], None) .unwrap(); - let loop_block = MastNode::new_loop(loop_body_id, &mast_forest); + let loop_block = MastNode::new_loop(loop_body_id, &mast_forest).unwrap(); let loop_block_id = mast_forest.add_node(loop_block.clone()).unwrap(); - let join2_block = MastNode::new_join(basic_block_1_id, loop_block_id, &mast_forest); + let join2_block = MastNode::new_join(basic_block_1_id, loop_block_id, &mast_forest).unwrap(); let join2_block_id = mast_forest.add_node(join2_block.clone()).unwrap(); let basic_block_2 = basic_block; let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); - let join1_block = MastNode::new_join(join2_block_id, basic_block_2_id, &mast_forest); + let join1_block = MastNode::new_join(join2_block_id, basic_block_2_id, &mast_forest).unwrap(); let mut hasher = Hasher::default(); let h1: [Felt; DIGEST_LEN] = join2_block diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 281e4d4f0f..ce16ac354b 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -687,10 +687,10 @@ fn call_block() { let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); - let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest); + let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest).unwrap(); let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()).unwrap(); - let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest); + let join1_node = MastNode::new_join(first_basic_block_id, foo_call_node_id, &mast_forest).unwrap(); let join1_node_id = mast_forest.add_node(join1_node.clone()).unwrap(); let program_root_id = mast_forest.add_join(join1_node_id, last_basic_block_id).unwrap(); @@ -894,10 +894,10 @@ fn syscall_block() { let bar_basic_block = MastNode::new_basic_block(vec![Operation::Push(TWO), Operation::FmpUpdate]); let bar_basic_block_id = mast_forest.add_node(bar_basic_block.clone()).unwrap(); - let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest); + let foo_call_node = MastNode::new_syscall(foo_root_id, &mast_forest).unwrap(); let foo_call_node_id = mast_forest.add_node(foo_call_node.clone()).unwrap(); - let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest); + let bar_root_node = MastNode::new_join(bar_basic_block_id, foo_call_node_id, &mast_forest).unwrap(); let bar_root_node_id = mast_forest.add_node(bar_root_node.clone()).unwrap(); mast_forest.make_root(bar_root_node_id); @@ -912,13 +912,13 @@ fn syscall_block() { let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd]); let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); - let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest); + let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest).unwrap(); let bar_call_node_id = mast_forest.add_node(bar_call_node.clone()).unwrap(); - let inner_join_node = MastNode::new_join(first_basic_block_id, bar_call_node_id, &mast_forest); + let inner_join_node = MastNode::new_join(first_basic_block_id, bar_call_node_id, &mast_forest).unwrap(); let inner_join_node_id = mast_forest.add_node(inner_join_node.clone()).unwrap(); - let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest); + let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest).unwrap(); let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); let program = Program::with_kernel(mast_forest, program_root_node_id, kernel.clone()); @@ -1183,14 +1183,14 @@ fn dyn_block() { let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4]); let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()).unwrap(); - let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest); + let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest).unwrap(); let join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); // This dyn will point to foo. let dyn_node = MastNode::new_dyn(); let dyn_node_id = mast_forest.add_node(dyn_node.clone()).unwrap(); - let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest); + let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest).unwrap(); let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); let program = Program::new(mast_forest, program_root_node_id); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 8bf191471c..96b117e841 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -363,7 +363,7 @@ fn decoder_p2_join() { let basic_block_2 = MastNode::new_basic_block(vec![Operation::Add]); let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); - let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); + let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest).unwrap(); let join_id = mast_forest.add_node(join.clone()).unwrap(); let program = Program::new(mast_forest, join_id); @@ -537,7 +537,7 @@ fn decoder_p2_loop_with_repeat() { let basic_block_2 = MastNode::new_basic_block(vec![Operation::Drop]); let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); - let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest); + let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest).unwrap(); let join_id = mast_forest.add_node(join.clone()).unwrap(); let loop_node_id = mast_forest.add_loop(join_id).unwrap(); From eeb8332fafde7e6f903bff6d27bf38153aefdbe9 Mon Sep 17 00:00:00 2001 From: sergerad Date: Sat, 27 Jul 2024 12:55:17 +1200 Subject: [PATCH 7/7] expect node id --- core/src/mast/serialization/info.rs | 33 ++++++----------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 3729d67b8b..675c3fd329 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -51,12 +51,8 @@ impl MastNodeInfo { let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; - MastNode::new_join(left_child, right_child, mast_forest).map_err(|e| { - DeserializationError::InvalidValue(format!( - "Failed to deserialize Join node: {}", - e - )) - }) + Ok(MastNode::new_join(left_child, right_child, mast_forest) + .expect("invalid node id")) } MastNodeType::Split { if_branch_id, @@ -65,38 +61,23 @@ impl MastNodeInfo { let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; - MastNode::new_split(if_branch, else_branch, mast_forest).map_err(|e| { - DeserializationError::InvalidValue(format!( - "Failed to deserialize Split node: {e}" - )) - }) + Ok(MastNode::new_split(if_branch, else_branch, mast_forest) + .expect("invalid node id")) } MastNodeType::Loop { body_id } => { let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; - MastNode::new_loop(body_id, mast_forest).map_err(|e| { - DeserializationError::InvalidValue(format!( - "Failed to deserialize Loop node: {e}" - )) - }) + Ok(MastNode::new_loop(body_id, mast_forest).expect("invalid node id")) } MastNodeType::Call { callee_id } => { let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; - MastNode::new_call(callee_id, mast_forest).map_err(|e| { - DeserializationError::InvalidValue(format!( - "Failed to deserialize Call node: {e}" - )) - }) + Ok(MastNode::new_call(callee_id, mast_forest).expect("invalid node id")) } MastNodeType::SysCall { callee_id } => { let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; - MastNode::new_syscall(callee_id, mast_forest).map_err(|e| { - DeserializationError::InvalidValue(format!( - "Failed to deserialize SysCall node: {e}" - )) - }) + Ok(MastNode::new_syscall(callee_id, mast_forest).expect("invalid node id")) } MastNodeType::Dyn => Ok(MastNode::new_dyn()), MastNodeType::External => Ok(MastNode::new_external(self.digest)),