Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more functions for ensuring nodes via MastForestBuilder #1404

Merged
merged 3 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

- Added error codes support for the `mtree_verify` instruction (#1328).
- Added support for immediate values for `lt`, `lte`, `gt`, `gte` comparison instructions (#1346).
- Change MAST to a table-based representation (#1349)
- Introduce `MastForestStore` (#1359)
- Changed MAST to a table-based representation (#1349)
- Introduced `MastForestStore` (#1359)
- Adjusted prover's metal acceleration code to work with 0.9 versions of the crates (#1357)
- Added support for immediate values for `u32lt`, `u32lte`, `u32gt`, `u32gte`, `u32min` and `u32max` comparison instructions (#1358).
- Added support for the `nop` instruction, which corresponds to the VM opcode of the same name, and has the same semantics. This is implemented for use by compilers primarily.
Expand All @@ -16,9 +16,10 @@
- Added support for immediate values for `u32and`, `u32or`, `u32xor` and `u32not` bitwise instructions (#1362).
- Optimized `std::sys::truncate_stuck` procedure (#1384).
- Updated CI and Makefile to standardise it accross Miden repositories (#1342).
- Add serialization/deserialization for `MastForest` (#1370)
- Added serialization/deserialization for `MastForest` (#1370)
- Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406)
- Introduce `MastForestError` to enforce `MastForest` node count invariant (#1394)
- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394)
- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404)

#### Changed

Expand Down
5 changes: 2 additions & 3 deletions assembly/src/assembler/basic_block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use alloc::{borrow::Borrow, string::ToString, vec::Vec};
use vm_core::{
mast::{MastForestError, MastNode, MastNodeId},
mast::{MastForestError, MastNodeId},
AdviceInjector, AssemblyOp, Operation,
};

Expand Down Expand Up @@ -134,8 +134,7 @@ impl BasicBlockBuilder {
let ops = self.ops.drain(..).collect();
let decorators = self.decorators.drain(..).collect();

let basic_block_node = MastNode::new_basic_block_with_decorators(ops, decorators);
let basic_block_node_id = mast_forest_builder.ensure_node(basic_block_node)?;
let basic_block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators))?;

Ok(Some(basic_block_node_id))
} else if !self.decorators.is_empty() {
Expand Down
26 changes: 9 additions & 17 deletions assembly/src/assembler/instruction/procedures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
};

use smallvec::SmallVec;
use vm_core::mast::{MastForest, MastNode, MastNodeId};
use vm_core::mast::{MastForest, MastNodeId};

/// Procedure Invocation
impl Assembler {
Expand Down Expand Up @@ -96,8 +96,7 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
mast_forest_builder.ensure_external(mast_root)?
}
}
}
Expand All @@ -107,28 +106,23 @@ impl Assembler {
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
mast_forest_builder.ensure_external(mast_root)?
}
};

let call_node = MastNode::new_call(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(call_node)?
mast_forest_builder.ensure_call(callee_id)?
}
InvokeKind::SysCall => {
let callee_id = match mast_forest_builder.find_procedure_root(mast_root) {
Some(callee_id) => callee_id,
None => {
// If the MAST root called isn't known to us, make it an external
// reference.
let external_node = MastNode::new_external(mast_root);
mast_forest_builder.ensure_node(external_node)?
mast_forest_builder.ensure_external(mast_root)?
}
};

let syscall_node =
MastNode::new_syscall(callee_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(syscall_node)?
mast_forest_builder.ensure_syscall(callee_id)?
}
}
};
Expand All @@ -141,7 +135,7 @@ impl Assembler {
&self,
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_node_id = mast_forest_builder.ensure_dyn()?;

Ok(Some(dyn_node_id))
}
Expand All @@ -152,10 +146,8 @@ impl Assembler {
mast_forest_builder: &mut MastForestBuilder,
) -> Result<Option<MastNodeId>, AssemblyError> {
let dyn_call_node_id = {
let dyn_node_id = mast_forest_builder.ensure_node(MastNode::Dyn)?;
let dyn_call_node = MastNode::new_call(dyn_node_id, mast_forest_builder.forest());

mast_forest_builder.ensure_node(dyn_call_node)?
let dyn_node_id = mast_forest_builder.ensure_dyn()?;
mast_forest_builder.ensure_call(dyn_node_id)?
};

Ok(Some(dyn_call_node_id))
Expand Down
62 changes: 60 additions & 2 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use core::ops::Index;

use alloc::collections::BTreeMap;
use alloc::{collections::BTreeMap, vec::Vec};
use vm_core::{
crypto::hash::RpoDigest,
mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode},
DecoratorList, Operation,
};

/// Builder for a [`MastForest`].
Expand Down Expand Up @@ -44,7 +45,7 @@ impl MastForestBuilder {
/// If a [`MastNode`] which is equal to the current node was previously added, the previously
/// returned [`MastNodeId`] will be returned. This enforces this invariant that equal
/// [`MastNode`]s have equal [`MastNodeId`]s.
pub fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
fn ensure_node(&mut self, node: MastNode) -> Result<MastNodeId, MastForestError> {
let node_digest = node.digest();

if let Some(node_id) = self.node_id_by_hash.get(&node_digest) {
Expand All @@ -58,6 +59,63 @@ impl MastForestBuilder {
}
}

/// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_block(
&mut self,
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.ensure_node(MastNode::new_basic_block_with_decorators(operations, decorators))
}
None => self.ensure_node(MastNode::new_basic_block(operations)),
}
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_join(
&mut self,
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest))
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_split(
&mut self,
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest))
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_loop(body, &self.mast_forest))
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_call(callee, &self.mast_forest))
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest))
}

/// Adds a dynexec node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_dyn())
}

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
self.ensure_node(MastNode::new_external(mast_root))
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
pub fn make_root(&mut self, new_root_id: MastNodeId) {
self.mast_forest.make_root(new_root_id)
Expand Down
27 changes: 11 additions & 16 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use mast_forest_builder::MastForestBuilder;
use vm_core::{
mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode},
mast::{MastForest, MastNodeId, MerkleTreeNode},
Decorator, DecoratorList, Kernel, Operation, Program,
};

Expand Down Expand Up @@ -792,12 +792,9 @@ impl Assembler {
let else_blk =
self.compile_body(else_blk.iter(), context, None, mast_forest_builder)?;

let split_node_id = {
let split_node =
MastNode::new_split(then_blk, else_blk, mast_forest_builder.forest());

mast_forest_builder.ensure_node(split_node).map_err(AssemblyError::from)?
};
let split_node_id = mast_forest_builder
.ensure_split(then_blk, else_blk)
.map_err(AssemblyError::from)?;
mast_node_ids.push(split_node_id);
}

Expand Down Expand Up @@ -828,11 +825,9 @@ impl Assembler {
let loop_body_node_id =
self.compile_body(body.iter(), context, None, mast_forest_builder)?;

let loop_node_id = {
let loop_node =
MastNode::new_loop(loop_body_node_id, mast_forest_builder.forest());
mast_forest_builder.ensure_node(loop_node).map_err(AssemblyError::from)?
};
let loop_node_id = mast_forest_builder
.ensure_loop(loop_body_node_id)
.map_err(AssemblyError::from)?;
mast_node_ids.push(loop_node_id);
}
}
Expand All @@ -846,8 +841,9 @@ impl Assembler {
}

Ok(if mast_node_ids.is_empty() {
let basic_block_node = MastNode::new_basic_block(vec![Operation::Noop]);
mast_forest_builder.ensure_node(basic_block_node).map_err(AssemblyError::from)?
mast_forest_builder
.ensure_block(vec![Operation::Noop], None)
.map_err(AssemblyError::from)?
} else {
combine_mast_node_ids(mast_node_ids, mast_forest_builder)?
})
Expand Down Expand Up @@ -907,8 +903,7 @@ fn combine_mast_node_ids(
while let (Some(left), Some(right)) =
(source_mast_node_iter.next(), source_mast_node_iter.next())
{
let join_mast_node = MastNode::new_join(left, right, mast_forest_builder.forest());
let join_mast_node_id = mast_forest_builder.ensure_node(join_mast_node)?;
let join_mast_node_id = mast_forest_builder.ensure_join(left, right)?;

mast_node_ids.push(join_mast_node_id);
}
Expand Down
Loading
Loading