From f1c1282b3dcab1e6cd83ff928d46edb24ea49d16 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 20 Jul 2024 15:43:56 -0700 Subject: [PATCH 1/4] refactor: remove MerkleTreeNode trait --- assembly/src/assembler/mast_forest_builder.rs | 2 +- assembly/src/assembler/mod.rs | 2 +- assembly/src/assembler/procedure.rs | 2 +- core/src/mast/mod.rs | 8 +-- core/src/mast/node/basic_block_node/mod.rs | 50 ++++++++++--------- core/src/mast/node/call_node.rs | 8 +-- core/src/mast/node/dyn_node.rs | 11 ++-- core/src/mast/node/external.rs | 8 +-- core/src/mast/node/join_node.rs | 30 ++++++----- core/src/mast/node/loop_node.rs | 36 +++++++------ core/src/mast/node/mod.rs | 11 ++-- core/src/mast/node/split_node.rs | 28 ++++++----- core/src/mast/serialization/info.rs | 2 +- core/src/mast/tests.rs | 6 +-- core/src/program.rs | 2 +- .../integration/operations/io_ops/env_ops.rs | 2 +- processor/src/chiplets/hasher/tests.rs | 2 +- processor/src/decoder/mod.rs | 2 +- processor/src/decoder/tests.rs | 2 +- processor/src/lib.rs | 2 +- processor/src/trace/tests/decoder.rs | 2 +- 21 files changed, 109 insertions(+), 109 deletions(-) diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index 0a1a44d037..2d0bacbb3e 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -3,7 +3,7 @@ use core::ops::Index; use alloc::{collections::BTreeMap, vec::Vec}; use vm_core::{ crypto::hash::RpoDigest, - mast::{MastForest, MastForestError, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastForestError, MastNode, MastNodeId}, DecoratorList, Operation, }; diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index cfdfbfe592..2ef43cdb03 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -11,7 +11,7 @@ use crate::{ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use mast_forest_builder::MastForestBuilder; use vm_core::{ - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, Decorator, DecoratorList, Kernel, Operation, Program, }; diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index 88224396da..15675f5b35 100644 --- a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -5,7 +5,7 @@ use crate::{ diagnostics::SourceFile, LibraryPath, RpoDigest, SourceSpan, Spanned, }; -use vm_core::mast::{MastForest, MastNodeId, MerkleTreeNode}; +use vm_core::mast::{MastForest, MastNodeId}; pub type CallSet = BTreeSet; diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 31ad5f1f0d..77ebad3249 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -15,12 +15,6 @@ mod serialization; #[cfg(test)] mod tests; -/// Encapsulates the behavior that a [`MastNode`] (and all its variants) is expected to have. -pub trait MerkleTreeNode { - fn digest(&self) -> RpoDigest; - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a; -} - // MAST FOREST // ================================================================================================ @@ -171,6 +165,8 @@ impl Deserializable for MastNodeId { fn read_from(source: &mut R) -> Result { let inner = source.read_u32()?; + // TODO: fix + Ok(Self(inner)) } } diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index a8d07ab2de..f71f1d0ec1 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -6,9 +6,7 @@ use miden_formatting::prettier::PrettyPrint; use winter_utils::flatten_slice_elements; use crate::{ - chiplets::hasher, - mast::{MastForest, MerkleTreeNode}, - Decorator, DecoratorIterator, DecoratorList, Operation, + chiplets::hasher, mast::MastForest, Decorator, DecoratorIterator, DecoratorList, Operation, }; #[cfg(test)] @@ -67,12 +65,14 @@ pub struct BasicBlockNode { decorators: DecoratorList, } +// ------------------------------------------------------------------------------------------------ /// Constants impl BasicBlockNode { /// The domain of the basic block node (used for control block hashing). pub const DOMAIN: Felt = ZERO; } +// ------------------------------------------------------------------------------------------------ /// Constructors impl BasicBlockNode { /// Returns a new [`BasicBlockNode`] instantiated with the specified operations. @@ -108,6 +108,7 @@ impl BasicBlockNode { } } +// ------------------------------------------------------------------------------------------------ /// Public accessors impl BasicBlockNode { pub fn num_operations_and_decorators(&self) -> u32 { @@ -139,35 +140,18 @@ impl BasicBlockNode { pub fn decorators(&self) -> &DecoratorList { &self.decorators } -} -impl MerkleTreeNode for BasicBlockNode { - fn digest(&self) -> RpoDigest { + pub fn digest(&self) -> RpoDigest { self.digest } - fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { self } } -/// Checks if a given decorators list is valid (only checked in debug mode) -/// - Assert the decorator list is in ascending order. -/// - Assert the last op index in decorator list is less than or equal to the number of operations. -#[cfg(debug_assertions)] -fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) { - if !decorators.is_empty() { - // check if decorator list is sorted - for i in 0..(decorators.len() - 1) { - debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list"); - } - // assert the last index in decorator list is less than operations vector length - debug_assert!( - operations.len() >= decorators.last().expect("empty decorators list").0, - "last op index in decorator list should be less than or equal to the number of ops" - ); - } -} +// PRETTY PRINTING +// ================================================================================================ impl PrettyPrint for BasicBlockNode { #[rustfmt::skip] @@ -515,3 +499,21 @@ pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize { let last_batch_num_groups = op_batches.last().expect("no last group").num_groups(); (op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() } + +/// Checks if a given decorators list is valid (only checked in debug mode) +/// - Assert the decorator list is in ascending order. +/// - Assert the last op index in decorator list is less than or equal to the number of operations. +#[cfg(debug_assertions)] +fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) { + if !decorators.is_empty() { + // check if decorator list is sorted + for i in 0..(decorators.len() - 1) { + debug_assert!(decorators[i + 1].0 >= decorators[i].0, "unsorted decorators list"); + } + // assert the last index in decorator list is less than operations vector length + debug_assert!( + operations.len() >= decorators.last().expect("empty decorators list").0, + "last op index in decorator list should be less than or equal to the number of ops" + ); + } +} diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index 1f00d74c00..2b12ec92e6 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -5,7 +5,7 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, OPCODE_CALL, OPCODE_SYSCALL, }; @@ -87,12 +87,12 @@ impl CallNode { } } -impl MerkleTreeNode for CallNode { - fn digest(&self) -> RpoDigest { +impl CallNode { + pub fn digest(&self) -> RpoDigest { self.digest } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { CallNodePrettyPrint { call_node: self, mast_forest, diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index 83c46f68fe..3593a9c619 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -2,10 +2,7 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; -use crate::{ - mast::{MastForest, MerkleTreeNode}, - OPCODE_DYN, -}; +use crate::{mast::MastForest, OPCODE_DYN}; #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct DynNode; @@ -16,8 +13,8 @@ impl DynNode { pub const DOMAIN: Felt = Felt::new(OPCODE_DYN as u64); } -impl MerkleTreeNode for DynNode { - fn digest(&self) -> RpoDigest { +impl DynNode { + pub fn digest(&self) -> RpoDigest { // The Dyn node is represented by a constant, which is set to be the hash of two empty // words ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`, i.e. // hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN) @@ -29,7 +26,7 @@ impl MerkleTreeNode for DynNode { ]) } - fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { self } } diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index c0b8ff10a3..3c2f84ba2a 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -1,4 +1,4 @@ -use crate::mast::{MastForest, MerkleTreeNode}; +use crate::mast::MastForest; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; @@ -24,11 +24,11 @@ impl ExternalNode { } } -impl MerkleTreeNode for ExternalNode { - fn digest(&self) -> RpoDigest { +impl ExternalNode { + pub fn digest(&self) -> RpoDigest { self.digest } - fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { self } } diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 1cd5322c0f..9cae20fbc3 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -4,11 +4,14 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt}; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, prettier::PrettyPrint, OPCODE_JOIN, }; +// JOIN NODE +// ================================================================================================ + #[derive(Debug, Clone, PartialEq, Eq)] pub struct JoinNode { children: [MastNodeId; 2], @@ -41,7 +44,7 @@ impl JoinNode { } } -/// Accessors +/// Public accessors impl JoinNode { pub fn first(&self) -> MastNodeId { self.children[0] @@ -50,13 +53,12 @@ impl JoinNode { pub fn second(&self) -> MastNodeId { self.children[1] } -} -impl JoinNode { - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { + pub fn digest(&self) -> RpoDigest { + self.digest + } + + pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { JoinNodePrettyPrint { join_node: self, mast_forest, @@ -64,12 +66,14 @@ impl JoinNode { } } -impl MerkleTreeNode for JoinNode { - fn digest(&self) -> RpoDigest { - self.digest - } +// PRETTY PRINTING +// ================================================================================================ - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { +impl JoinNode { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { JoinNodePrettyPrint { join_node: self, mast_forest, diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index 1554181b46..57cbe4f2d5 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -5,10 +5,13 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, OPCODE_LOOP, }; +// LOOP NODE +// ================================================================================================ + #[derive(Debug, Clone, PartialEq, Eq)] pub struct LoopNode { body: MastNodeId, @@ -32,30 +35,33 @@ impl LoopNode { Self { body, digest } } - - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { - LoopNodePrettyPrint { - loop_node: self, - mast_forest, - } - } } impl LoopNode { pub fn body(&self) -> MastNodeId { self.body } -} -impl MerkleTreeNode for LoopNode { - fn digest(&self) -> RpoDigest { + pub fn digest(&self) -> RpoDigest { self.digest } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + LoopNodePrettyPrint { + loop_node: self, + mast_forest, + } + } +} + +// PRETTY PRINTING +// ================================================================================================ + +impl LoopNode { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { LoopNodePrettyPrint { loop_node: self, mast_forest, diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index a1c52bb211..0a0efabd23 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -28,7 +28,7 @@ mod loop_node; pub use loop_node::LoopNode; use crate::{ - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, DecoratorList, Operation, }; @@ -140,13 +140,8 @@ impl MastNode { MastNode::External(_) => panic!("Can't fetch domain for an `External` node."), } } -} - -// ------------------------------------------------------------------------------------------------ -// MerkleTreeNode impl -impl MerkleTreeNode for MastNode { - fn digest(&self) -> RpoDigest { + pub fn digest(&self) -> RpoDigest { match self { MastNode::Block(node) => node.digest(), MastNode::Join(node) => node.digest(), @@ -158,7 +153,7 @@ impl MerkleTreeNode for MastNode { } } - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { match self { MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)), diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 600186a9e7..8525bfd769 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -5,10 +5,13 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNodeId}, OPCODE_SPLIT, }; +// SPLIT NODE +// ================================================================================================ + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SplitNode { branches: [MastNodeId; 2], @@ -49,13 +52,12 @@ impl SplitNode { pub fn on_false(&self) -> MastNodeId { self.branches[1] } -} -impl SplitNode { - pub(super) fn to_pretty_print<'a>( - &'a self, - mast_forest: &'a MastForest, - ) -> impl PrettyPrint + 'a { + pub fn digest(&self) -> RpoDigest { + self.digest + } + + pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl core::fmt::Display + 'a { SplitNodePrettyPrint { split_node: self, mast_forest, @@ -63,12 +65,14 @@ impl SplitNode { } } -impl MerkleTreeNode for SplitNode { - fn digest(&self) -> RpoDigest { - self.digest - } +// PRETTY PRINTING +// ================================================================================================ - fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl core::fmt::Display + 'a { +impl SplitNode { + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { SplitNodePrettyPrint { split_node: self, mast_forest, diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index 5b5a6aabdb..4646e9a20b 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -1,7 +1,7 @@ use miden_crypto::hash::rpo::RpoDigest; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use crate::mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}; +use crate::mast::{MastForest, MastNode, MastNodeId}; use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; diff --git a/core/src/mast/tests.rs b/core/src/mast/tests.rs index da43d1b5b4..1ca87f0d86 100644 --- a/core/src/mast/tests.rs +++ b/core/src/mast/tests.rs @@ -1,8 +1,4 @@ -use crate::{ - chiplets::hasher, - mast::{DynNode, MerkleTreeNode}, - Kernel, ProgramInfo, Word, -}; +use crate::{chiplets::hasher, mast::DynNode, Kernel, ProgramInfo, Word}; use alloc::vec::Vec; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use proptest::prelude::*; diff --git a/core/src/program.rs b/core/src/program.rs index b055bf3313..66419229b1 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -5,7 +5,7 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::{ - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNode, MastNodeId}, utils::ToElements, }; diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 379a974652..5582aa13ae 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -1,7 +1,7 @@ use processor::FMP_MIN; use test_utils::{build_op_test, build_test, StackInputs, Test, Word, STACK_TOP_SIZE}; use vm_core::{ - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, Operation, }; diff --git a/processor/src/chiplets/hasher/tests.rs b/processor/src/chiplets/hasher/tests.rs index 44f814ac58..69e70d581f 100644 --- a/processor/src/chiplets/hasher/tests.rs +++ b/processor/src/chiplets/hasher/tests.rs @@ -12,7 +12,7 @@ use test_utils::rand::rand_array; use vm_core::{ chiplets::hasher, crypto::merkle::{MerkleTree, NodeIndex}, - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, Operation, ONE, ZERO, }; diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index 5d66138df4..146192225e 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -13,7 +13,7 @@ use miden_air::trace::{ use vm_core::{ mast::{ get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, - MerkleTreeNode, SplitNode, OP_BATCH_SIZE, + SplitNode, OP_BATCH_SIZE, }, stack::STACK_TOP_SIZE, AssemblyOp, diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index 390a162e9e..8e84ecc5be 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -18,7 +18,7 @@ use miden_air::trace::{ }; use test_utils::rand::rand_value; use vm_core::{ - mast::{BasicBlockNode, MastForest, MastNode, MerkleTreeNode, OP_BATCH_SIZE}, + mast::{BasicBlockNode, MastForest, MastNode, OP_BATCH_SIZE}, Program, EMPTY_WORD, ONE, ZERO, }; diff --git a/processor/src/lib.rs b/processor/src/lib.rs index a254a1b1a6..95ab09e7ed 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -18,7 +18,7 @@ pub use vm_core::{ chiplets::hasher::Digest, crypto::merkle::SMT_DEPTH, errors::InputError, - mast::{MastForest, MastNode, MastNodeId, MerkleTreeNode}, + mast::{MastForest, MastNode, MastNodeId}, utils::DeserializationError, AdviceInjector, AssemblyOp, Felt, Kernel, Operation, Program, ProgramInfo, QuadExtension, StackInputs, StackOutputs, Word, EMPTY_WORD, ONE, ZERO, diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 83246bdf5d..c5acde3241 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -16,7 +16,7 @@ use miden_air::trace::{ }; use test_utils::rand::rand_array; use vm_core::{ - mast::{MastForest, MastNode, MerkleTreeNode}, + mast::{MastForest, MastNode}, FieldElement, Operation, Program, Word, ONE, ZERO, }; From ba78af3b361c4904ea011eb72f8b3427fbd8aea5 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 20 Jul 2024 17:11:56 -0700 Subject: [PATCH 2/4] chore: added comments and section separators --- core/src/mast/mod.rs | 42 +-- core/src/mast/node/basic_block_node/mod.rs | 246 +++--------------- .../mast/node/basic_block_node/op_batch.rs | 175 +++++++++++++ core/src/mast/node/call_node.rs | 45 +++- core/src/mast/node/dyn_node.rs | 26 +- core/src/mast/node/external.rs | 12 +- core/src/mast/node/join_node.rs | 29 ++- core/src/mast/node/loop_node.rs | 31 ++- core/src/mast/node/mod.rs | 8 +- core/src/mast/node/split_node.rs | 32 ++- processor/src/decoder/mod.rs | 5 +- 11 files changed, 373 insertions(+), 278 deletions(-) create mode 100644 core/src/mast/node/basic_block_node/op_batch.rs diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 77ebad3249..feb1b841b3 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -5,8 +5,8 @@ use miden_crypto::hash::rpo::RpoDigest; mod node; pub use node::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, - MastNode, OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, + BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, OpBatch, + OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -18,14 +18,17 @@ mod tests; // MAST FOREST // ================================================================================================ -/// Represents the types of errors that can occur when dealing with MAST forest. -#[derive(Debug, thiserror::Error)] -pub enum MastForestError { - #[error( - "invalid node count: MAST forest exceeds the maximum of {} nodes", - MastForest::MAX_NODES - )] - TooManyNodes, +/// Represents one or more procedures, represented as a collection of [`MastNode`]s. +/// +/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] +/// can be built from a [`MastForest`] to specify an entrypoint. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct MastForest { + /// All of the nodes local to the trees comprising the MAST forest. + nodes: Vec, + + /// Roots of procedures defined within this MAST forest. + roots: Vec, } // ------------------------------------------------------------------------------------------------ @@ -174,15 +177,12 @@ impl Deserializable for MastNodeId { // MAST FOREST ERROR // ================================================================================================ -/// Represents one or more procedures, represented as a collection of [`MastNode`]s. -/// -/// A [`MastForest`] does not have an entrypoint, and hence is not executable. A [`crate::Program`] -/// can be built from a [`MastForest`] to specify an entrypoint. -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct MastForest { - /// All of the nodes local to the trees comprising the MAST forest. - nodes: Vec, - - /// Roots of procedures defined within this MAST forest. - roots: Vec, +/// Represents the types of errors that can occur when dealing with MAST forest. +#[derive(Debug, thiserror::Error)] +pub enum MastForestError { + #[error( + "invalid node count: MAST forest exceeds the maximum of {} nodes", + MastForest::MAX_NODES + )] + TooManyNodes, } diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index f71f1d0ec1..a98c5f7785 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -5,9 +5,11 @@ use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO}; use miden_formatting::prettier::PrettyPrint; use winter_utils::flatten_slice_elements; -use crate::{ - chiplets::hasher, mast::MastForest, Decorator, DecoratorIterator, DecoratorList, Operation, -}; +use crate::{chiplets::hasher, Decorator, DecoratorIterator, DecoratorList, Operation}; + +mod op_batch; +pub use op_batch::OpBatch; +use op_batch::OpBatchAccumulator; #[cfg(test)] mod tests; @@ -111,23 +113,36 @@ impl BasicBlockNode { // ------------------------------------------------------------------------------------------------ /// Public accessors impl BasicBlockNode { - pub fn num_operations_and_decorators(&self) -> u32 { - let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); - let num_decorators = self.decorators.len(); - - (num_ops + num_decorators) - .try_into() - .expect("basic block contains more than 2^32 operations and decorators") + /// Returns a commitment to this basic block. + pub fn digest(&self) -> RpoDigest { + self.digest } + /// Returns a reference to the operation batches in this basic block. pub fn op_batches(&self) -> &[OpBatch] { &self.op_batches } - /// Returns an iterator over all operations and decorator, in the order in which they appear in - /// the program. - pub fn iter(&self) -> impl Iterator { - OperationOrDecoratorIterator::new(self) + /// Returns the total number of operation groups in this basic block. + /// + /// Then number of operation groups is computed as follows: + /// - For all batches but the last one we set the number of groups to 8, regardless of the + /// actual number of groups in the batch. The reason for this is that when operation batches + /// are concatenated together each batch contributes 8 elements to the hash. + /// - For the last batch, we take the number of actual groups and round it up to the next power + /// of two. The reason for rounding is that the VM always executes a number of operation + /// groups which is a power of two. + pub fn num_op_groups(&self) -> usize { + let last_batch_num_groups = self.op_batches.last().expect("no last group").num_groups(); + (self.op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() + } + + /// Returns a list of decorators in this basic block node. + /// + /// Each decorator is accompanied by the operation index specifying the operation prior to + /// which the decorator should be executed. + pub fn decorators(&self) -> &DecoratorList { + &self.decorators } /// Returns a [`DecoratorIterator`] which allows us to iterate through the decorator list of @@ -136,17 +151,20 @@ impl BasicBlockNode { DecoratorIterator::new(&self.decorators) } - /// Returns a list of decorators in this basic block node. - pub fn decorators(&self) -> &DecoratorList { - &self.decorators - } + /// Returns the total number of operations and decorators in this basic block. + pub fn num_operations_and_decorators(&self) -> u32 { + let num_ops: usize = self.op_batches.iter().map(|batch| batch.ops().len()).sum(); + let num_decorators = self.decorators.len(); - pub fn digest(&self) -> RpoDigest { - self.digest + (num_ops + num_decorators) + .try_into() + .expect("basic block contains more than 2^32 operations and decorators") } - pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - self + /// Returns an iterator over all operations and decorator, in the order in which they appear in + /// the program. + pub fn iter(&self) -> impl Iterator { + OperationOrDecoratorIterator::new(self) } } @@ -275,175 +293,6 @@ impl<'a> Iterator for OperationOrDecoratorIterator<'a> { } } -// OPERATION BATCH -// ================================================================================================ - -/// A batch of operations in a span block. -/// -/// An operation batch consists of up to 8 operation groups, with each group containing up to 9 -/// operations or a single immediate value. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct OpBatch { - ops: Vec, - groups: [Felt; BATCH_SIZE], - op_counts: [usize; BATCH_SIZE], - num_groups: usize, -} - -impl OpBatch { - /// Returns a list of operations contained in this batch. - pub fn ops(&self) -> &[Operation] { - &self.ops - } - - /// Returns a list of operation groups contained in this batch. - /// - /// Each group is represented by a single field element. - pub fn groups(&self) -> &[Felt; BATCH_SIZE] { - &self.groups - } - - /// Returns the number of non-decorator operations for each operation group. - /// - /// Number of operations for groups containing immediate values is set to 0. - pub fn op_counts(&self) -> &[usize; BATCH_SIZE] { - &self.op_counts - } - - /// Returns the number of groups in this batch. - pub fn num_groups(&self) -> usize { - self.num_groups - } -} - -/// An accumulator used in construction of operation batches. -struct OpBatchAccumulator { - /// A list of operations in this batch, including decorators. - ops: Vec, - /// Values of operation groups, including immediate values. - groups: [Felt; BATCH_SIZE], - /// Number of non-decorator operations in each operation group. Operation count for groups - /// with immediate values is set to 0. - op_counts: [usize; BATCH_SIZE], - /// Value of the currently active op group. - group: u64, - /// Index of the next opcode in the current group. - op_idx: usize, - /// index of the current group in the batch. - group_idx: usize, - // Index of the next free group in the batch. - next_group_idx: usize, -} - -impl OpBatchAccumulator { - /// Returns a blank [OpBatchAccumulator]. - pub fn new() -> Self { - Self { - ops: Vec::new(), - groups: [ZERO; BATCH_SIZE], - op_counts: [0; BATCH_SIZE], - group: 0, - op_idx: 0, - group_idx: 0, - next_group_idx: 1, - } - } - - /// Returns true if this accumulator does not contain any operations. - pub fn is_empty(&self) -> bool { - self.ops.is_empty() - } - - /// Returns true if this accumulator can accept the specified operation. - /// - /// An accumulator may not be able accept an operation for the following reasons: - /// - There is no more space in the underlying batch (e.g., the 8th group of the batch already - /// contains 9 operations). - /// - There is no space for the immediate value carried by the operation (e.g., the 8th group is - /// only partially full, but we are trying to add a PUSH operation). - /// - The alignment rules require that the operation overflows into the next group, and if this - /// happens, there will be no space for the operation or its immediate value. - pub fn can_accept_op(&self, op: Operation) -> bool { - if op.imm_value().is_some() { - // an operation carrying an immediate value cannot be the last one in a group; so, we - // check if we need to move the operation to the next group. in either case, we need - // to make sure there is enough space for the immediate value as well. - if self.op_idx < GROUP_SIZE - 1 { - self.next_group_idx < BATCH_SIZE - } else { - self.next_group_idx + 1 < BATCH_SIZE - } - } else { - // check if there is space for the operation in the current group, or if there isn't, - // whether we can add another group - self.op_idx < GROUP_SIZE || self.next_group_idx < BATCH_SIZE - } - } - - /// Adds the specified operation to this accumulator. It is expected that the specified - /// operation is not a decorator and that (can_accept_op())[OpBatchAccumulator::can_accept_op] - /// is called before this function to make sure that the specified operation can be added to - /// the accumulator. - pub fn add_op(&mut self, op: Operation) { - // if the group is full, finalize it and start a new group - if self.op_idx == GROUP_SIZE { - self.finalize_op_group(); - } - - // for operations with immediate values, we need to do a few more things - if let Some(imm) = op.imm_value() { - // since an operation with an immediate value cannot be the last one in a group, if - // the operation would be the last one in the group, we need to start a new group - if self.op_idx == GROUP_SIZE - 1 { - self.finalize_op_group(); - } - - // save the immediate value at the next group index and advance the next group pointer - self.groups[self.next_group_idx] = imm; - self.next_group_idx += 1; - } - - // add the opcode to the group and increment the op index pointer - let opcode = op.op_code() as u64; - self.group |= opcode << (Operation::OP_BITS * self.op_idx); - self.ops.push(op); - self.op_idx += 1; - } - - /// Convert the accumulator into an [OpBatch]. - pub fn into_batch(mut self) -> OpBatch { - // make sure the last group gets added to the group array; we also check the op_idx to - // handle the case when a group contains a single NOOP operation. - if self.group != 0 || self.op_idx != 0 { - self.groups[self.group_idx] = Felt::new(self.group); - self.op_counts[self.group_idx] = self.op_idx; - } - - OpBatch { - ops: self.ops, - groups: self.groups, - op_counts: self.op_counts, - num_groups: self.next_group_idx, - } - } - - // HELPER METHODS - // -------------------------------------------------------------------------------------------- - - /// Saves the current group into the group array, advances current and next group pointers, - /// and resets group content. - fn finalize_op_group(&mut self) { - self.groups[self.group_idx] = Felt::new(self.group); - self.op_counts[self.group_idx] = self.op_idx; - - self.group_idx = self.next_group_idx; - self.next_group_idx = self.group_idx + 1; - - self.op_idx = 0; - self.group = 0; - } -} - // HELPER FUNCTIONS // ================================================================================================ @@ -485,21 +334,6 @@ fn batch_ops(ops: Vec) -> (Vec, RpoDigest) { (batches, hash) } -/// Returns the total number of operation groups in a span defined by the provides list of -/// operation batches. -/// -/// Then number of operation groups is computed as follows: -/// - For all batches but the last one we set the number of groups to 8, regardless of the actual -/// number of groups in the batch. The reason for this is that when operation batches are -/// concatenated together each batch contributes 8 elements to the hash. -/// - For the last batch, we take the number of actual batches and round it up to the next power of -/// two. The reason for rounding is that the VM always executes a number of operation groups which -/// is a power of two. -pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize { - let last_batch_num_groups = op_batches.last().expect("no last group").num_groups(); - (op_batches.len() - 1) * BATCH_SIZE + last_batch_num_groups.next_power_of_two() -} - /// Checks if a given decorators list is valid (only checked in debug mode) /// - Assert the decorator list is in ascending order. /// - Assert the last op index in decorator list is less than or equal to the number of operations. diff --git a/core/src/mast/node/basic_block_node/op_batch.rs b/core/src/mast/node/basic_block_node/op_batch.rs new file mode 100644 index 0000000000..f9e064f15a --- /dev/null +++ b/core/src/mast/node/basic_block_node/op_batch.rs @@ -0,0 +1,175 @@ +use super::{Felt, Operation, BATCH_SIZE, GROUP_SIZE, ZERO}; + +use alloc::vec::Vec; + +// OPERATION BATCH +// ================================================================================================ + +/// A batch of operations in a span block. +/// +/// An operation batch consists of up to 8 operation groups, with each group containing up to 9 +/// operations or a single immediate value. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OpBatch { + pub(super) ops: Vec, + pub(super) groups: [Felt; BATCH_SIZE], + pub(super) op_counts: [usize; BATCH_SIZE], + pub(super) num_groups: usize, +} + +impl OpBatch { + /// Returns a list of operations contained in this batch. + pub fn ops(&self) -> &[Operation] { + &self.ops + } + + /// Returns a list of operation groups contained in this batch. + /// + /// Each group is represented by a single field element. + pub fn groups(&self) -> &[Felt; BATCH_SIZE] { + &self.groups + } + + /// Returns the number of non-decorator operations for each operation group. + /// + /// Number of operations for groups containing immediate values is set to 0. + pub fn op_counts(&self) -> &[usize; BATCH_SIZE] { + &self.op_counts + } + + /// Returns the number of groups in this batch. + pub fn num_groups(&self) -> usize { + self.num_groups + } +} + +// OPERATION BATCH ACCUMULATOR +// ================================================================================================ + +/// An accumulator used in construction of operation batches. +pub(super) struct OpBatchAccumulator { + /// A list of operations in this batch, including decorators. + ops: Vec, + /// Values of operation groups, including immediate values. + groups: [Felt; BATCH_SIZE], + /// Number of non-decorator operations in each operation group. Operation count for groups + /// with immediate values is set to 0. + op_counts: [usize; BATCH_SIZE], + /// Value of the currently active op group. + group: u64, + /// Index of the next opcode in the current group. + op_idx: usize, + /// index of the current group in the batch. + group_idx: usize, + // Index of the next free group in the batch. + next_group_idx: usize, +} + +impl OpBatchAccumulator { + /// Returns a blank [OpBatchAccumulator]. + pub fn new() -> Self { + Self { + ops: Vec::new(), + groups: [ZERO; BATCH_SIZE], + op_counts: [0; BATCH_SIZE], + group: 0, + op_idx: 0, + group_idx: 0, + next_group_idx: 1, + } + } + + /// Returns true if this accumulator does not contain any operations. + pub fn is_empty(&self) -> bool { + self.ops.is_empty() + } + + /// Returns true if this accumulator can accept the specified operation. + /// + /// An accumulator may not be able accept an operation for the following reasons: + /// - There is no more space in the underlying batch (e.g., the 8th group of the batch already + /// contains 9 operations). + /// - There is no space for the immediate value carried by the operation (e.g., the 8th group is + /// only partially full, but we are trying to add a PUSH operation). + /// - The alignment rules require that the operation overflows into the next group, and if this + /// happens, there will be no space for the operation or its immediate value. + pub fn can_accept_op(&self, op: Operation) -> bool { + if op.imm_value().is_some() { + // an operation carrying an immediate value cannot be the last one in a group; so, we + // check if we need to move the operation to the next group. in either case, we need + // to make sure there is enough space for the immediate value as well. + if self.op_idx < GROUP_SIZE - 1 { + self.next_group_idx < BATCH_SIZE + } else { + self.next_group_idx + 1 < BATCH_SIZE + } + } else { + // check if there is space for the operation in the current group, or if there isn't, + // whether we can add another group + self.op_idx < GROUP_SIZE || self.next_group_idx < BATCH_SIZE + } + } + + /// Adds the specified operation to this accumulator. It is expected that the specified + /// operation is not a decorator and that (can_accept_op())[OpBatchAccumulator::can_accept_op] + /// is called before this function to make sure that the specified operation can be added to + /// the accumulator. + pub fn add_op(&mut self, op: Operation) { + // if the group is full, finalize it and start a new group + if self.op_idx == GROUP_SIZE { + self.finalize_op_group(); + } + + // for operations with immediate values, we need to do a few more things + if let Some(imm) = op.imm_value() { + // since an operation with an immediate value cannot be the last one in a group, if + // the operation would be the last one in the group, we need to start a new group + if self.op_idx == GROUP_SIZE - 1 { + self.finalize_op_group(); + } + + // save the immediate value at the next group index and advance the next group pointer + self.groups[self.next_group_idx] = imm; + self.next_group_idx += 1; + } + + // add the opcode to the group and increment the op index pointer + let opcode = op.op_code() as u64; + self.group |= opcode << (Operation::OP_BITS * self.op_idx); + self.ops.push(op); + self.op_idx += 1; + } + + /// Convert the accumulator into an [OpBatch]. + pub fn into_batch(mut self) -> OpBatch { + // make sure the last group gets added to the group array; we also check the op_idx to + // handle the case when a group contains a single NOOP operation. + if self.group != 0 || self.op_idx != 0 { + self.groups[self.group_idx] = Felt::new(self.group); + self.op_counts[self.group_idx] = self.op_idx; + } + + OpBatch { + ops: self.ops, + groups: self.groups, + op_counts: self.op_counts, + num_groups: self.next_group_idx, + } + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Saves the current group into the group array, advances current and next group pointers, + /// and resets group content. + pub(super) fn finalize_op_group(&mut self) { + self.groups[self.group_idx] = Felt::new(self.group); + self.op_counts[self.group_idx] = self.op_idx; + + self.group_idx = self.next_group_idx; + self.next_group_idx = self.group_idx + 1; + + self.op_idx = 0; + self.group = 0; + } +} diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index 2b12ec92e6..2cb116aeba 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -9,6 +9,15 @@ use crate::{ OPCODE_CALL, OPCODE_SYSCALL, }; +// CALL NODE +// ================================================================================================ + +/// A Call node describes a function call such that the callee is executed in a different execution +/// context from the currently executing code. +/// +/// A call node can be of two types: +/// - A simple call: the callee is executed in the new user context. +/// - A syscall: the callee is executed in the root context. #[derive(Debug, Clone, PartialEq, Eq)] pub struct CallNode { callee: MastNodeId, @@ -16,6 +25,7 @@ pub struct CallNode { digest: RpoDigest, } +//------------------------------------------------------------------------------------------------- /// Constants impl CallNode { /// The domain of the call block (used for control block hashing). @@ -24,6 +34,7 @@ impl CallNode { pub const SYSCALL_DOMAIN: Felt = Felt::new(OPCODE_SYSCALL as u64); } +//------------------------------------------------------------------------------------------------- /// Constructors impl CallNode { /// Returns a new [`CallNode`] instantiated with the specified callee. @@ -58,16 +69,35 @@ impl CallNode { } } +//------------------------------------------------------------------------------------------------- +/// Public accessors impl CallNode { + /// Returns a commitment to this Call node. + /// + /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the + /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on + /// whether the node represents a simple call or a syscall - i.e.,: + /// + /// hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN) + /// + /// or + /// + /// hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN) + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node to be invoked by this call node. pub fn callee(&self) -> MastNodeId { self.callee } + /// Returns true if this call node represents a syscall. pub fn is_syscall(&self) -> bool { self.is_syscall } - /// Returns the domain of the call node. + /// Returns the domain of this call node. pub fn domain(&self) -> Felt { if self.is_syscall() { Self::SYSCALL_DOMAIN @@ -75,7 +105,12 @@ impl CallNode { Self::CALL_DOMAIN } } +} + +// PRETTY PRINTING +// ================================================================================================ +impl CallNode { pub(super) fn to_pretty_print<'a>( &'a self, mast_forest: &'a MastForest, @@ -85,14 +120,8 @@ impl CallNode { mast_forest, } } -} - -impl CallNode { - pub fn digest(&self) -> RpoDigest { - self.digest - } - pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { CallNodePrettyPrint { call_node: self, mast_forest, diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index 3593a9c619..47e5186b59 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -2,8 +2,12 @@ use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; -use crate::{mast::MastForest, OPCODE_DYN}; +use crate::OPCODE_DYN; +// DYN NODE +// ================================================================================================ + +/// A Dyn node specifies that the node to be executed next is defined dynamically via the stack. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct DynNode; @@ -13,11 +17,15 @@ impl DynNode { pub const DOMAIN: Felt = Felt::new(OPCODE_DYN as u64); } +/// Public accessors impl DynNode { + /// Returns a commitment to a Dyn node. + /// + /// The commitment is computed by hashing two empty words ([ZERO; 4]) in the domain defined + /// by [Self::DOMAIN], i.e.: + /// + /// hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN) pub fn digest(&self) -> RpoDigest { - // The Dyn node is represented by a constant, which is set to be the hash of two empty - // words ([ZERO, ZERO, ZERO, ZERO]) with a domain value of `DYN_DOMAIN`, i.e. - // hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN) RpoDigest::new([ Felt::new(8115106948140260551), Felt::new(13491227816952616836), @@ -25,12 +33,11 @@ impl DynNode { Felt::new(16575543461540527115), ]) } - - pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - self - } } +// PRETTY PRINTING +// ================================================================================================ + impl crate::prettier::PrettyPrint for DynNode { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; @@ -45,6 +52,9 @@ impl fmt::Display for DynNode { } } +// TESTS +// ================================================================================================ + #[cfg(test)] mod tests { use miden_crypto::hash::rpo::Rpo256; diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index 3c2f84ba2a..e0d15922ff 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -2,6 +2,9 @@ use crate::mast::MastForest; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; +// EXTERNAL NODE +// ================================================================================================ + /// Node for referencing procedures not present in a given [`MastForest`] (hence "external"). /// /// External nodes can be used to verify the integrity of a program's hash while keeping parts of @@ -25,10 +28,17 @@ impl ExternalNode { } impl ExternalNode { + /// Returns the commitment to the MAST node referenced by this external node. pub fn digest(&self) -> RpoDigest { self.digest } - pub fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { +} + +// PRETTY PRINTING +// ================================================================================================ + +impl ExternalNode { + pub(super) fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { self } } diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 9cae20fbc3..797b030fa7 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -12,6 +12,8 @@ use crate::{ // JOIN NODE // ================================================================================================ +/// A Join node describe sequential execution. When the VM encounters a Join node, it executes the +/// first child first and the second child second. #[derive(Debug, Clone, PartialEq, Eq)] pub struct JoinNode { children: [MastNodeId; 2], @@ -46,30 +48,39 @@ impl JoinNode { /// Public accessors impl JoinNode { + /// Returns a commitment to this Join node. + /// + /// The commitment is computed as a hash of the `first` and `second` child node in the domain + /// defined by [Self::DOMAIN] - i.e.,: + /// + /// hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN) + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node that is to be executed first. pub fn first(&self) -> MastNodeId { self.children[0] } + /// Returns the ID of the node that is to be executed after the execution of the program + /// defined by the first node completes. pub fn second(&self) -> MastNodeId { self.children[1] } +} - pub fn digest(&self) -> RpoDigest { - self.digest - } +// PRETTY PRINTING +// ================================================================================================ - pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { +impl JoinNode { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { JoinNodePrettyPrint { join_node: self, mast_forest, } } -} -// PRETTY PRINTING -// ================================================================================================ - -impl JoinNode { pub(super) fn to_pretty_print<'a>( &'a self, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index 57cbe4f2d5..125a6c2c02 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -12,6 +12,12 @@ use crate::{ // LOOP NODE // ================================================================================================ +/// A Loop node defines condition-controlled iterative execution. When the VM encounters a Loop +/// node, it will keep executing the body of the loop as long as the top of the stack is `1``. +/// +/// The loop is exited when at the end of executing the loop body the top of the stack is `0``. +/// If the top of the stack is neither `0` nor `1` when the condition is checked, the execution +/// fails. #[derive(Debug, Clone, PartialEq, Eq)] pub struct LoopNode { body: MastNodeId, @@ -38,19 +44,19 @@ impl LoopNode { } impl LoopNode { - pub fn body(&self) -> MastNodeId { - self.body - } - + /// Returns a commitment to this Loop node. + /// + /// The commitment is computed as a hash of the loop body and an empty word ([ZERO; 4]) in + /// the domain defined by [Self::DOMAIN] - i..e,: + /// + /// hasher::merge_in_domain(&[on_true_digest, Digest::default()], LoopNode::DOMAIN) pub fn digest(&self) -> RpoDigest { self.digest } - pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - LoopNodePrettyPrint { - loop_node: self, - mast_forest, - } + /// Returns the ID of the node presenting the body of the loop. + pub fn body(&self) -> MastNodeId { + self.body } } @@ -58,6 +64,13 @@ impl LoopNode { // ================================================================================================ impl LoopNode { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + LoopNodePrettyPrint { + loop_node: self, + mast_forest, + } + } + pub(super) fn to_pretty_print<'a>( &'a self, mast_forest: &'a MastForest, diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 0a0efabd23..d84dcf2916 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -3,8 +3,8 @@ use core::fmt; use alloc::{boxed::Box, vec::Vec}; pub use basic_block_node::{ - get_span_op_group_count, BasicBlockNode, OpBatch, OperationOrDecorator, - BATCH_SIZE as OP_BATCH_SIZE, GROUP_SIZE as OP_GROUP_SIZE, + BasicBlockNode, OpBatch, OperationOrDecorator, BATCH_SIZE as OP_BATCH_SIZE, + GROUP_SIZE as OP_GROUP_SIZE, }; mod call_node; @@ -155,12 +155,12 @@ impl MastNode { pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { match self { - MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)), + MastNode::Block(node) => MastNodeDisplay::new(node), MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)), - MastNode::Dyn => MastNodeDisplay::new(DynNode.to_display(mast_forest)), + MastNode::Dyn => MastNodeDisplay::new(DynNode), MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)), } } diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 8525bfd769..017e12171f 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -12,6 +12,12 @@ use crate::{ // SPLIT NODE // ================================================================================================ +/// A Split node defines conditional execution. When the VM encounters a Split node it executes +/// either the `on_true` child or `on_false` child. +/// +/// Which child is executed is determined based on the top of the stack. If the value is `1`, then +/// the `on_true` child is executed. If the value is `0`, then the `on_false` child is executed. If +/// the value is neither `0` nor `1`, the execution fails. #[derive(Debug, Clone, PartialEq, Eq)] pub struct SplitNode { branches: [MastNodeId; 2], @@ -45,30 +51,38 @@ impl SplitNode { /// Public accessors impl SplitNode { + /// Returns a commitment to this Split node. + /// + /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the + /// domain defined by [Self::DOMAIN] - i..e,: + /// + /// hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN) + pub fn digest(&self) -> RpoDigest { + self.digest + } + + /// Returns the ID of the node which is to be executed if the top of the stack is `1`. pub fn on_true(&self) -> MastNodeId { self.branches[0] } + /// Returns the ID of the node which is to be executed if the top of the stack is `0`. pub fn on_false(&self) -> MastNodeId { self.branches[1] } +} - pub fn digest(&self) -> RpoDigest { - self.digest - } +// PRETTY PRINTING +// ================================================================================================ - pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl core::fmt::Display + 'a { +impl SplitNode { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { SplitNodePrettyPrint { split_node: self, mast_forest, } } -} -// PRETTY PRINTING -// ================================================================================================ - -impl SplitNode { pub(super) fn to_pretty_print<'a>( &'a self, mast_forest: &'a MastForest, diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index 146192225e..36e4a6b19b 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -12,8 +12,7 @@ use miden_air::trace::{ }; use vm_core::{ mast::{ - get_span_op_group_count, BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, - SplitNode, OP_BATCH_SIZE, + BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, SplitNode, OP_BATCH_SIZE, }, stack::STACK_TOP_SIZE, AssemblyOp, @@ -341,7 +340,7 @@ where // start decoding the first operation batch; this also appends a row with SPAN operation // to the decoder trace. we also need the total number of operation groups so that we can // set the value of the group_count register at the beginning of the SPAN. - let num_op_groups = get_span_op_group_count(op_batches); + let num_op_groups = basic_block.num_op_groups(); self.decoder .start_basic_block(&op_batches[0], Felt::new(num_op_groups as u64), addr); self.execute_op(Operation::Noop) From 9e59aece8cdfda3528e9f7f2f4ebe5a70e87d589 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 20 Jul 2024 17:35:32 -0700 Subject: [PATCH 3/4] fix: remove implicit serialization of MastNodeId --- core/src/mast/mod.rs | 24 ++++++++++-------------- core/src/mast/serialization/mod.rs | 10 +++++----- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index feb1b841b3..20d1ed9309 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -8,7 +8,7 @@ pub use node::{ BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; -use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use winter_utils::DeserializationError; mod serialization; @@ -152,25 +152,21 @@ impl MastNodeId { } } -impl fmt::Display for MastNodeId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MastNodeId({})", self.0) +impl From for u32 { + fn from(value: MastNodeId) -> Self { + value.0 } } -impl Serializable for MastNodeId { - fn write_into(&self, target: &mut W) { - self.0.write_into(target) +impl From<&MastNodeId> for u32 { + fn from(value: &MastNodeId) -> Self { + value.0 } } -impl Deserializable for MastNodeId { - fn read_from(source: &mut R) -> Result { - let inner = source.read_u32()?; - - // TODO: fix - - Ok(Self(inner)) +impl fmt::Display for MastNodeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "MastNodeId({})", self.0) } } diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index e23cdbd2fb..40b32c0923 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -54,7 +54,8 @@ impl Serializable for MastForest { target.write_usize(self.nodes.len()); // roots - self.roots.write_into(target); + let roots: Vec = self.roots.iter().map(u32::from).collect(); + roots.write_into(target); // Prepare MAST node infos, but don't store them yet. We store them at the end to make // deserialization more efficient. @@ -102,11 +103,8 @@ impl Deserializable for MastForest { } let node_count = source.read_usize()?; - - let roots: Vec = Deserializable::read_from(source)?; - + let roots: Vec = Deserializable::read_from(source)?; let strings: Vec = Deserializable::read_from(source)?; - let data: Vec = Deserializable::read_from(source)?; let basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); @@ -128,6 +126,8 @@ impl Deserializable for MastForest { } for root in roots { + // make sure the root is valid in the context of the MAST forest + let root = MastNodeId::from_u32_safe(root, &mast_forest)?; mast_forest.make_root(root); } From ab1effc04bd8eb2c73629ad298f6001581101633 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Mon, 22 Jul 2024 14:22:46 -0700 Subject: [PATCH 4/4] test: add doctests for node hash computations --- core/src/mast/node/call_node.rs | 17 ++++++++++++----- core/src/mast/node/dyn_node.rs | 6 +++++- core/src/mast/node/join_node.rs | 9 +++++++-- core/src/mast/node/loop_node.rs | 8 ++++++-- core/src/mast/node/split_node.rs | 9 +++++++-- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index 2cb116aeba..ca9c720195 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -77,12 +77,19 @@ impl CallNode { /// The commitment is computed as a hash of the callee and an empty word ([ZERO; 4]) in the /// domain defined by either [Self::CALL_DOMAIN] or [Self::SYSCALL_DOMAIN], depending on /// whether the node represents a simple call or a syscall - i.e.,: - /// - /// hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN) - /// + /// ``` + /// # use miden_core::mast::CallNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let callee_digest = Digest::default(); + /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::CALL_DOMAIN); + /// ``` /// or - /// - /// hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN) + /// ``` + /// # use miden_core::mast::CallNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let callee_digest = Digest::default(); + /// Hasher::merge_in_domain(&[callee_digest, Digest::default()], CallNode::SYSCALL_DOMAIN); + /// ``` pub fn digest(&self) -> RpoDigest { self.digest } diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index 47e5186b59..934a8fec2d 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -24,7 +24,11 @@ impl DynNode { /// The commitment is computed by hashing two empty words ([ZERO; 4]) in the domain defined /// by [Self::DOMAIN], i.e.: /// - /// hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN) + /// ``` + /// # use miden_core::mast::DynNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// Hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN); + /// ``` pub fn digest(&self) -> RpoDigest { RpoDigest::new([ Felt::new(8115106948140260551), diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index 797b030fa7..5f802873dd 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -52,8 +52,13 @@ impl JoinNode { /// /// The commitment is computed as a hash of the `first` and `second` child node in the domain /// defined by [Self::DOMAIN] - i.e.,: - /// - /// hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN) + /// ``` + /// # use miden_core::mast::JoinNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let first_child_digest = Digest::default(); + /// # let second_child_digest = Digest::default(); + /// Hasher::merge_in_domain(&[first_child_digest, second_child_digest], JoinNode::DOMAIN); + /// ``` pub fn digest(&self) -> RpoDigest { self.digest } diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index 125a6c2c02..aec1b0b451 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -48,8 +48,12 @@ impl LoopNode { /// /// The commitment is computed as a hash of the loop body and an empty word ([ZERO; 4]) in /// the domain defined by [Self::DOMAIN] - i..e,: - /// - /// hasher::merge_in_domain(&[on_true_digest, Digest::default()], LoopNode::DOMAIN) + /// ``` + /// # use miden_core::mast::LoopNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let body_digest = Digest::default(); + /// Hasher::merge_in_domain(&[body_digest, Digest::default()], LoopNode::DOMAIN); + /// ``` pub fn digest(&self) -> RpoDigest { self.digest } diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 017e12171f..f754735e35 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -55,8 +55,13 @@ impl SplitNode { /// /// The commitment is computed as a hash of the `on_true` and `on_false` child nodes in the /// domain defined by [Self::DOMAIN] - i..e,: - /// - /// hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN) + /// ``` + /// # use miden_core::mast::SplitNode; + /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; + /// # let on_true_digest = Digest::default(); + /// # let on_false_digest = Digest::default(); + /// Hasher::merge_in_domain(&[on_true_digest, on_false_digest], SplitNode::DOMAIN); + /// ``` pub fn digest(&self) -> RpoDigest { self.digest }