Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more functions for adding nodes to MastForest #1412

Merged
merged 7 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
- Added serialization/deserialization for `MastForest` (#1370)
- Updated CI to support `CHANGELOG.md` modification checking and `no changelog` label (#1406)
- Introduced `MastForestError` to enforce `MastForest` node count invariant (#1394)
- Added functions to `MastForestBuilder` to allow ensuring of nodes with fewer LOC (#1404)
- Make `Assembler` single-use (#1409)
- Remove `ProcedureCache` from the assembler (#1411).
- Add `Assembler::assemble_library()` (#1413)
- Added functions to `MastForest` and `MastForestBuilder` to add and ensure nodes with fewer LOC (#1404, #1412)
- Made `Assembler` single-use (#1409)
- Removed `ProcedureCache` from the assembler (#1411).
- Added `Assembler::assemble_library()` (#1413)

#### Changed

Expand Down
15 changes: 10 additions & 5 deletions assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ impl MastForestBuilder {
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_join(left_child, right_child, &self.mast_forest))
let join = MastNode::new_join(left_child, right_child, &self.mast_forest)?;
self.ensure_node(join)
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
Expand All @@ -177,22 +178,26 @@ impl MastForestBuilder {
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_split(if_branch, else_branch, &self.mast_forest))
let split = MastNode::new_split(if_branch, else_branch, &self.mast_forest)?;
self.ensure_node(split)
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_loop(body, &self.mast_forest))
let loop_node = MastNode::new_loop(body, &self.mast_forest)?;
self.ensure_node(loop_node)
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_call(callee, &self.mast_forest))
let call = MastNode::new_call(callee, &self.mast_forest)?;
self.ensure_node(call)
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_syscall(callee, &self.mast_forest))
let syscall = MastNode::new_syscall(callee, &self.mast_forest)?;
self.ensure_node(syscall)
}

/// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it.
Expand Down
29 changes: 7 additions & 22 deletions assembly/src/assembler/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use alloc::{boxed::Box, vec::Vec};
use pretty_assertions::assert_eq;
use vm_core::{
assert_matches,
mast::{MastForest, MastNode},
Program,
};
use vm_core::{assert_matches, mast::MastForest, Program};

use super::{Assembler, Library, Operation};
use crate::{
Expand Down Expand Up @@ -248,29 +244,18 @@ fn duplicate_nodes() {
let mut expected_mast_forest = MastForest::new();

// basic block: mul
let mul_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Mul]);
expected_mast_forest.add_node(node).unwrap()
};
let mul_basic_block_id = expected_mast_forest.add_block(vec![Operation::Mul], None).unwrap();

// basic block: add
let add_basic_block_id = {
let node = MastNode::new_basic_block(vec![Operation::Add]);
expected_mast_forest.add_node(node).unwrap()
};
let add_basic_block_id = expected_mast_forest.add_block(vec![Operation::Add], None).unwrap();

// inner split: `if.true add else mul end`
let inner_split_id = {
let node =
MastNode::new_split(add_basic_block_id, mul_basic_block_id, &expected_mast_forest);
expected_mast_forest.add_node(node).unwrap()
};
let inner_split_id =
expected_mast_forest.add_split(add_basic_block_id, mul_basic_block_id).unwrap();

// root: outer split
let root_id = {
let node = MastNode::new_split(mul_basic_block_id, inner_split_id, &expected_mast_forest);
expected_mast_forest.add_node(node).unwrap()
};
let root_id = expected_mast_forest.add_split(mul_basic_block_id, inner_split_id).unwrap();

expected_mast_forest.make_root(root_id);

let expected_program = Program::new(expected_mast_forest, root_id);
Expand Down
74 changes: 73 additions & 1 deletion core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub use node::{
};
use winter_utils::DeserializationError;

use crate::{DecoratorList, Operation};

mod serialization;

#[cfg(test)]
Expand Down Expand Up @@ -60,6 +62,68 @@ impl MastForest {
Ok(new_node_id)
}

/// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_block(
&mut self,
operations: Vec<Operation>,
decorators: Option<DecoratorList>,
) -> Result<MastNodeId, MastForestError> {
match decorators {
Some(decorators) => {
self.add_node(MastNode::new_basic_block_with_decorators(operations, decorators))
}
None => self.add_node(MastNode::new_basic_block(operations)),
}
}

/// Adds a join node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_join(
&mut self,
left_child: MastNodeId,
right_child: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
let join = MastNode::new_join(left_child, right_child, self)?;
self.add_node(join)
}

/// Adds a split node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_split(
&mut self,
if_branch: MastNodeId,
else_branch: MastNodeId,
) -> Result<MastNodeId, MastForestError> {
let split = MastNode::new_split(if_branch, else_branch, self)?;
self.add_node(split)
}

/// Adds a loop node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_loop(&mut self, body: MastNodeId) -> Result<MastNodeId, MastForestError> {
let loop_node = MastNode::new_loop(body, self)?;
self.add_node(loop_node)
}

/// Adds a call node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_call(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
let call = MastNode::new_call(callee, self)?;
self.add_node(call)
}

/// Adds a syscall node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_syscall(&mut self, callee: MastNodeId) -> Result<MastNodeId, MastForestError> {
let syscall = MastNode::new_syscall(callee, self)?;
self.add_node(syscall)
}

/// Adds a dyn node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_dyn(&mut self) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_dyn())
}

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn add_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, MastForestError> {
self.add_node(MastNode::new_external(mast_root))
}

/// Marks the given [`MastNodeId`] as being the root of a procedure.
///
/// # Panics
Expand Down Expand Up @@ -157,6 +221,12 @@ impl MastNodeId {
}
}

impl From<MastNodeId> for usize {
fn from(value: MastNodeId) -> Self {
value.0 as usize
}
}

impl From<MastNodeId> for u32 {
fn from(value: MastNodeId) -> Self {
value.0
Expand All @@ -179,11 +249,13 @@ impl fmt::Display for MastNodeId {
// ================================================================================================

/// Represents the types of errors that can occur when dealing with MAST forest.
#[derive(Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum MastForestError {
#[error(
"invalid node count: MAST forest exceeds the maximum of {} nodes",
MastForest::MAX_NODES
)]
TooManyNodes,
#[error("node id: {0} is greater than or equal to forest length: {1}")]
NodeIdOverflow(MastNodeId, usize),
}
23 changes: 16 additions & 7 deletions core/src/mast/node/call_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint;

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId},
mast::{MastForest, MastForestError, MastNodeId},
OPCODE_CALL, OPCODE_SYSCALL,
};

Expand Down Expand Up @@ -38,34 +38,43 @@ impl CallNode {
/// Constructors
impl CallNode {
/// Returns a new [`CallNode`] instantiated with the specified callee.
pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Self {
pub fn new(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
if usize::from(callee) >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
}
let digest = {
let callee_digest = mast_forest[callee].digest();

hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN)
};

Self {
Ok(Self {
callee,
is_syscall: false,
digest,
}
})
}

/// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel
/// call.
pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self {
pub fn new_syscall(
callee: MastNodeId,
mast_forest: &MastForest,
) -> Result<Self, MastForestError> {
if usize::from(callee) >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(callee, mast_forest.nodes.len()));
}
let digest = {
let callee_digest = mast_forest[callee].digest();

hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN)
};

Self {
Ok(Self {
callee,
is_syscall: true,
digest,
}
})
}
}

Expand Down
15 changes: 12 additions & 3 deletions core/src/mast/node/join_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt};

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId},
mast::{MastForest, MastForestError, MastNodeId},
prettier::PrettyPrint,
OPCODE_JOIN,
};
Expand All @@ -29,15 +29,24 @@ impl JoinNode {
/// Constructors
impl JoinNode {
/// Returns a new [`JoinNode`] instantiated with the specified children nodes.
pub fn new(children: [MastNodeId; 2], mast_forest: &MastForest) -> Self {
pub fn new(
children: [MastNodeId; 2],
mast_forest: &MastForest,
) -> Result<Self, MastForestError> {
let forest_len = mast_forest.nodes.len();
if usize::from(children[0]) >= forest_len {
return Err(MastForestError::NodeIdOverflow(children[0], forest_len));
} else if usize::from(children[1]) >= forest_len {
return Err(MastForestError::NodeIdOverflow(children[1], forest_len));
}
let digest = {
let left_child_hash = mast_forest[children[0]].digest();
let right_child_hash = mast_forest[children[1]].digest();

hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN)
};

Self { children, digest }
Ok(Self { children, digest })
}

#[cfg(test)]
Expand Down
9 changes: 6 additions & 3 deletions core/src/mast/node/loop_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint;

use crate::{
chiplets::hasher,
mast::{MastForest, MastNodeId},
mast::{MastForest, MastForestError, MastNodeId},
OPCODE_LOOP,
};

Expand All @@ -32,14 +32,17 @@ impl LoopNode {

/// Constructors
impl LoopNode {
pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Self {
pub fn new(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
if usize::from(body) >= mast_forest.nodes.len() {
return Err(MastForestError::NodeIdOverflow(body, mast_forest.nodes.len()));
}
let digest = {
let body_hash = mast_forest[body].digest();

hasher::merge_in_domain(&[body_hash, RpoDigest::default()], Self::DOMAIN)
};

Self { body, digest }
Ok(Self { body, digest })
}
}

Expand Down
30 changes: 20 additions & 10 deletions core/src/mast/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ use crate::{
DecoratorList, Operation,
};

use super::MastForestError;

// MAST NODE
// ================================================================================================

Expand Down Expand Up @@ -64,28 +66,36 @@ impl MastNode {
left_child: MastNodeId,
right_child: MastNodeId,
mast_forest: &MastForest,
) -> Self {
Self::Join(JoinNode::new([left_child, right_child], mast_forest))
) -> Result<Self, MastForestError> {
let join = JoinNode::new([left_child, right_child], mast_forest)?;
Ok(Self::Join(join))
}

pub fn new_split(
if_branch: MastNodeId,
else_branch: MastNodeId,
mast_forest: &MastForest,
) -> Self {
Self::Split(SplitNode::new([if_branch, else_branch], mast_forest))
) -> Result<Self, MastForestError> {
let split = SplitNode::new([if_branch, else_branch], mast_forest)?;
Ok(Self::Split(split))
}

pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Self {
Self::Loop(LoopNode::new(body, mast_forest))
pub fn new_loop(body: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
let loop_node = LoopNode::new(body, mast_forest)?;
Ok(Self::Loop(loop_node))
}

pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Self {
Self::Call(CallNode::new(callee, mast_forest))
pub fn new_call(callee: MastNodeId, mast_forest: &MastForest) -> Result<Self, MastForestError> {
let call = CallNode::new(callee, mast_forest)?;
Ok(Self::Call(call))
}

pub fn new_syscall(callee: MastNodeId, mast_forest: &MastForest) -> Self {
Self::Call(CallNode::new_syscall(callee, mast_forest))
pub fn new_syscall(
callee: MastNodeId,
mast_forest: &MastForest,
) -> Result<Self, MastForestError> {
let syscall = CallNode::new_syscall(callee, mast_forest)?;
Ok(Self::Call(syscall))
}

pub fn new_dyn() -> Self {
Expand Down
Loading
Loading