Skip to content

Commit

Permalink
feat: MastForest maximum node length invariant (#1394)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergerad authored Jul 19, 2024
1 parent 323160d commit 14995b4
Show file tree
Hide file tree
Showing 18 changed files with 215 additions and 174 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- Updated CI and Makefile to standardise it accross Miden repositories (#1342).
- Add serialization/deserialization for `MastForest` (#1370)
- Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406)
- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394)

#### Changed

Expand Down
14 changes: 7 additions & 7 deletions assembly/src/assembler/basic_block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use alloc::{borrow::Borrow, string::ToString, vec::Vec};
use vm_core::{
mast::{MastNode, MastNodeId},
mast::{MastForestError, MastNode, MastNodeId},
AdviceInjector, AssemblyOp, Operation,
};

Expand Down Expand Up @@ -129,22 +129,22 @@ impl BasicBlockBuilder {
pub fn make_basic_block(
&mut self,
mast_forest_builder: &mut MastForestBuilder,
) -> Option<MastNodeId> {
) -> Result<Option<MastNodeId>, MastForestError> {
if !self.ops.is_empty() {
let ops = self.ops.drain(..).collect();
let decorators = self.decorators.drain(..).collect();

let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?;

Some(basic_block_node_id)
Ok(Some(basic_block_node_id))
} else if !self.decorators.is_empty() {
// this is a bug in the assembler. we shouldn't have decorators added without their
// associated operations
// TODO: change this to an error or allow decorators in empty span blocks
unreachable!("decorators in an empty SPAN block")
} else {
None
Ok(None)
}
}

Expand All @@ -155,10 +155,10 @@ impl BasicBlockBuilder {
/// - Operations contained in the epilogue of the builder are appended to the list of ops which
/// go into the new BASIC BLOCK node.
/// - The builder is consumed in the process.
pub fn into_basic_block(
pub fn try_into_basic_block(
mut self,
mast_forest_builder: &mut MastForestBuilder,
) -> Option<MastNodeId> {
) -> Result<Option<MastNodeId>, MastForestError> {
self.ops.append(&mut self.epilogue);
self.make_basic_block(mast_forest_builder)
}
Expand Down
45 changes: 26 additions & 19 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,37 +91,44 @@ impl Assembler {
// procedures, such that when we assemble a procedure, all
// procedures that it calls will have been assembled, and
// hence be present in the `MastForest`.
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
})
match mast_forest_builder.find_procedure_root(mast_root) {
Some(root) => root,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
}
}
}
InvokeKind::Call => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
Some(callee_id) => callee_id,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
});
mast_forest_builder.ensure_node(external_node)?
}
};

let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(call_node)
mast_forest_builder.ensure_node(call_node)?
}
InvokeKind::SysCall => {
let callee_id =
mast_forest_builder.find_procedure_root(mast_root).unwrap_or_else(|| {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
Some(callee_id) => callee_id,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)
});
mast_forest_builder.ensure_node(external_node)?
}
};

let syscall_node =
MastNode::new_syscall(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(syscall_node)
mast_forest_builder.ensure_node(syscall_node)?
}
}
};
Expand All @@ -134,7 +141,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;

Ok(Some(dyn_node_id))
}
Expand All @@ -145,10 +152,10 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_call_node_id = {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn);
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());

mast_forest_builder.ensure_node(dyn_call_node)
mast_forest_builder.ensure_node(dyn_call_node)?
};

Ok(Some(dyn_call_node_id))
Expand Down
10 changes: 5 additions & 5 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::ops::Index;
use alloc::collections::BTreeMap;
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
};

/// Builder for a [`MastForest`].
Expand Down Expand Up @@ -44,17 +44,17 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
pub fn ensure_node(&mut self, node: MastNode) -> MastNodeId {
pub fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
let node_digest = node.digest();

if let Some(node_id) = self.node_id_by_hash.get(&node_digest) {
// node already exists in the forest; return previously assigned id
*node_id
Ok(*node_id)
} else {
let new_node_id = self.mast_forest.add_node(node);
let new_node_id = self.mast_forest.add_node(node)?;
self.node_id_by_hash.insert(node_digest, new_node_id);

new_node_id
Ok(new_node_id)
}
}

Expand Down
39 changes: 23 additions & 16 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,8 +766,9 @@ impl Assembler {
context,
mast_forest_builder,
)? {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
if let Some(basic_block_id) = basic_block_builder
.make_basic_block(mast_forest_builder)
.map_err(AssemblyError::from)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -779,8 +780,9 @@ impl Assembler {
Op::If {
then_blk, else_blk, ..
} => {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
if let Some(basic_block_id) = basic_block_builder
.make_basic_block(mast_forest_builder)
.map_err(AssemblyError::from)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -794,14 +796,15 @@ impl Assembler {
let split_node =
MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());

mast_forest_builder.ensure_node(split_node)
mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)?
};
mast_node_ids.push(split_node_id);
}

Op::Repeat { count, body, .. } => {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
if let Some(basic_block_id) = basic_block_builder
.make_basic_block(mast_forest_builder)
.map_err(AssemblyError::from)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -815,8 +818,9 @@ impl Assembler {
}

Op::While { body, .. } => {
if let Some(basic_block_id) =
basic_block_builder.make_basic_block(mast_forest_builder)
if let Some(basic_block_id) = basic_block_builder
.make_basic_block(mast_forest_builder)
.map_err(AssemblyError::from)?
{
mast_node_ids.push(basic_block_id);
}
Expand All @@ -827,22 +831,25 @@ impl Assembler {
let loop_node_id = {
let loop_node =
MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(loop_node)
mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)?
};
mast_node_ids.push(loop_node_id);
}
}
}

if let Some(basic_block_id) = basic_block_builder.into_basic_block(mast_forest_builder) {
if let Some(basic_block_id) = basic_block_builder
.try_into_basic_block(mast_forest_builder)
.map_err(AssemblyError::from)?
{
mast_node_ids.push(basic_block_id);
}

Ok(if mast_node_ids.is_empty() {
let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]);
mast_forest_builder.ensure_node(basic_block_node)
mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)?
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)
combine_mast_node_ids(mast_node_ids, mast_forest_builder)?
})
}

Expand Down Expand Up @@ -882,7 +889,7 @@ struct BodyWrapper {
fn combine_mast_node_ids(
mut mast_node_ids: Vec<MastNodeId>,
mast_forest_builder: &mut MastForestBuilder,
) -> MastNodeId {
) -> Result<MastNodeId, AssemblyError> {
debug_assert!(!mast_node_ids.is_empty(), "cannot combine empty MAST node id list");

// build a binary tree of blocks joining them using JOIN blocks
Expand All @@ -901,7 +908,7 @@ fn combine_mast_node_ids(
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node);
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?;

mast_node_ids.push(join_mast_node_id);
}
Expand All @@ -910,5 +917,5 @@ fn combine_mast_node_ids(
}
}

mast_node_ids.remove(0)
Ok(mast_node_ids.remove(0))
}
Loading

0 comments on commit 14995b4

Please sign in to comment.