diff --git a/CHANGELOG.md b/CHANGELOG.md index 7741d89f28..081079c22d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,8 +3,9 @@ ## 0.8.0 (TBD) #### Assembly -- Expanded capabilities of the `debug` decorator. Added `debug.mem` and `debug.local` variations. +- Expanded capabilities of the `debug` decorator. Added `debug.mem` and `debug.local` variations (#1103). - Introduced the `emit.` assembly instruction (#1119). +- Introduced the `procref.` assembly instruction (#1113). #### Stdlib - Introduced `std::utils` module with `is_empty_word` procedure. Refactored `std::collections::smt` diff --git a/assembly/src/assembler/context.rs b/assembly/src/assembler/context.rs index 30917f896b..85711a44c1 100644 --- a/assembly/src/assembler/context.rs +++ b/assembly/src/assembler/context.rs @@ -81,13 +81,23 @@ impl AssemblyContext { /// Returns the name of the procedure by its ID from the procedure map. pub fn get_imported_procedure_name(&self, id: &ProcedureId) -> Option { - if let Some(module) = self.module_stack.first() { + if let Some(module) = self.module_stack.last() { module.proc_map.get(id).cloned() } else { None } } + /// Returns the [Procedure] by its index from the vector of local procedures. + pub fn get_local_procedure(&self, idx: u16) -> Result<&Procedure, AssemblyError> { + let module_context = self.module_stack.last().expect("no modules"); + module_context + .compiled_procs + .get(idx as usize) + .map(|named_proc| named_proc.inner()) + .ok_or_else(|| AssemblyError::local_proc_not_found(idx, &module_context.path)) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index c853e4ea84..dc19c370c3 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -327,6 +327,8 @@ impl Assembler { Instruction::SysCall(id) => self.syscall(id, ctx), Instruction::DynExec => self.dynexec(), Instruction::DynCall => self.dyncall(), + Instruction::ProcRefLocal(idx) => self.procref_local(*idx, ctx, span), + Instruction::ProcRefImported(id) => self.procref_imported(id, ctx, span), // ----- debug decorators ------------------------------------------------------------- Instruction::Breakpoint => { diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 43bcee08b2..345d0aed70 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -1,4 +1,7 @@ -use super::{Assembler, AssemblyContext, AssemblyError, CodeBlock, ProcedureId, RpoDigest}; +use super::{ + super::Vec, Assembler, AssemblyContext, AssemblyError, CodeBlock, Operation, ProcedureId, + RpoDigest, SpanBuilder, +}; // PROCEDURE INVOCATIONS // ================================================================================================ @@ -133,4 +136,37 @@ impl Assembler { // create a new CALL block whose target is DYN Ok(Some(CodeBlock::new_dyncall())) } + + pub(super) fn procref_local( + &self, + proc_idx: u16, + context: &mut AssemblyContext, + span: &mut SpanBuilder, + ) -> Result, AssemblyError> { + // get root of the compiled local procedure + let proc_root = context.get_local_procedure(proc_idx)?.mast_root(); + // create an array with `Push` operations containing root elements + let ops: Vec = proc_root.iter().map(|elem| Operation::Push(*elem)).collect(); + span.add_ops(ops) + } + + pub(super) fn procref_imported( + &self, + proc_id: &ProcedureId, + context: &mut AssemblyContext, + span: &mut SpanBuilder, + ) -> Result, AssemblyError> { + // make sure the procedure is in procedure cache + self.ensure_procedure_is_in_cache(proc_id, context)?; + + // get the procedure from the assembler + let proc_cache = self.proc_cache.borrow(); + let proc = proc_cache.get_by_id(proc_id).expect("procedure not in cache"); + + // get root of the cimported procedure + let proc_root = proc.mast_root(); + // create an array with `Push` operations containing root elements + let ops: Vec = proc_root.iter().map(|elem| Operation::Push(*elem)).collect(); + span.add_ops(ops) + } } diff --git a/assembly/src/ast/nodes/format.rs b/assembly/src/ast/nodes/format.rs index 0136820fc6..ebbbb68f0a 100644 --- a/assembly/src/ast/nodes/format.rs +++ b/assembly/src/ast/nodes/format.rs @@ -120,6 +120,14 @@ impl fmt::Display for FormattableInstruction<'_> { write!(f, "call.")?; display_hex_bytes(f, &root.as_bytes())?; } + Instruction::ProcRefLocal(index) => { + let proc_name = self.context.local_proc(*index as usize); + write!(f, "procref.{proc_name}")?; + } + Instruction::ProcRefImported(proc_id) => { + let (_, path) = self.context.imported_proc(proc_id); + write!(f, "procref.{path}")?; + } _ => { // Not a procedure call. Use the normal formatting write!(f, "{}", self.instruction)?; diff --git a/assembly/src/ast/nodes/mod.rs b/assembly/src/ast/nodes/mod.rs index 0973be4c5a..dc10fa25ba 100644 --- a/assembly/src/ast/nodes/mod.rs +++ b/assembly/src/ast/nodes/mod.rs @@ -309,6 +309,8 @@ pub enum Instruction { SysCall(ProcedureId), DynExec, DynCall, + ProcRefLocal(u16), + ProcRefImported(ProcedureId), // ----- debug decorators --------------------------------------------------------------------- Breakpoint, @@ -598,6 +600,8 @@ impl fmt::Display for Instruction { Self::SysCall(proc_id) => write!(f, "syscall.{proc_id}"), Self::DynExec => write!(f, "dynexec"), Self::DynCall => write!(f, "dyncall"), + Self::ProcRefLocal(index) => write!(f, "procref.{index}"), + Self::ProcRefImported(proc_id) => write!(f, "procref.{proc_id}"), // ----- debug decorators ------------------------------------------------------------- Self::Breakpoint => write!(f, "breakpoint"), diff --git a/assembly/src/ast/nodes/serde/deserialization.rs b/assembly/src/ast/nodes/serde/deserialization.rs index 5b8a9d1b57..8632af6721 100644 --- a/assembly/src/ast/nodes/serde/deserialization.rs +++ b/assembly/src/ast/nodes/serde/deserialization.rs @@ -363,6 +363,10 @@ impl Deserializable for Instruction { OpCode::SysCall => Ok(Instruction::SysCall(ProcedureId::read_from(source)?)), OpCode::DynExec => Ok(Instruction::DynExec), OpCode::DynCall => Ok(Instruction::DynCall), + OpCode::ProcRefLocal => Ok(Instruction::ProcRefLocal(source.read_u16()?)), + OpCode::ProcRefImported => { + Ok(Instruction::ProcRefImported(ProcedureId::read_from(source)?)) + } // ----- debugging -------------------------------------------------------------------- OpCode::Debug => { diff --git a/assembly/src/ast/nodes/serde/mod.rs b/assembly/src/ast/nodes/serde/mod.rs index 1594d394e6..22efc46e6c 100644 --- a/assembly/src/ast/nodes/serde/mod.rs +++ b/assembly/src/ast/nodes/serde/mod.rs @@ -282,12 +282,14 @@ pub enum OpCode { SysCall = 246, DynExec = 247, DynCall = 248, + ProcRefLocal = 249, + ProcRefImported = 250, // ----- debugging ---------------------------------------------------------------------------- - Debug = 249, + Debug = 251, // ----- emit -------------------------------------------------------------------------------- - Emit = 250, + Emit = 252, // ----- control flow ------------------------------------------------------------------------- IfElse = 253, diff --git a/assembly/src/ast/nodes/serde/serialization.rs b/assembly/src/ast/nodes/serde/serialization.rs index d5e7e5848d..4050ebb205 100644 --- a/assembly/src/ast/nodes/serde/serialization.rs +++ b/assembly/src/ast/nodes/serde/serialization.rs @@ -521,6 +521,14 @@ impl Serializable for Instruction { } Self::DynExec => OpCode::DynExec.write_into(target), Self::DynCall => OpCode::DynCall.write_into(target), + Self::ProcRefLocal(v) => { + OpCode::ProcRefLocal.write_into(target); + target.write_u16(*v) + } + Self::ProcRefImported(imported) => { + OpCode::ProcRefImported.write_into(target); + imported.write_into(target) + } // ----- debug decorators ------------------------------------------------------------- Self::Breakpoint => { diff --git a/assembly/src/ast/parsers/context.rs b/assembly/src/ast/parsers/context.rs index 792ab70a2b..24e68cf897 100644 --- a/assembly/src/ast/parsers/context.rs +++ b/assembly/src/ast/parsers/context.rs @@ -214,6 +214,26 @@ impl ParserContext<'_> { } } + // PROCREF PARSERS + // -------------------------------------------------------------------------------------------- + + /// Parse a `procref` token into an instruction node. + pub fn parse_procref(&mut self, token: &Token) -> Result { + match token.parse_invocation(token.parts()[0])? { + InvocationTarget::ProcedureName(proc_name) => { + let index = self.get_local_proc_index(proc_name, token)?; + let inner = Instruction::ProcRefLocal(index); + Ok(Node::Instruction(inner)) + } + InvocationTarget::ProcedurePath { name, module } => { + let proc_id = self.import_info.add_invoked_proc(&name, module, token)?; + let inner = Instruction::ProcRefImported(proc_id); + Ok(Node::Instruction(inner)) + } + _ => Err(ParsingError::invalid_param(token, 1)), + } + } + // PROCEDURE PARSERS // -------------------------------------------------------------------------------------------- @@ -622,6 +642,7 @@ impl ParserContext<'_> { "syscall" => self.parse_syscall(op), "dynexec" => simple_instruction(op, DynExec), "dyncall" => simple_instruction(op, DynCall), + "procref" => self.parse_procref(op), // ----- constant statements ---------------------------------------------------------- "const" => Err(ParsingError::const_invalid_scope(op)), diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 10af9bbf8a..b06edddd57 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -1,6 +1,7 @@ use crate::{ ast::{ModuleAst, ProgramAst}, - Assembler, AssemblyContext, Library, LibraryNamespace, LibraryPath, Module, Version, + Assembler, AssemblyContext, AssemblyError, Library, LibraryNamespace, LibraryPath, MaslLibrary, + Module, ProcedureName, Version, }; use core::slice::Iter; @@ -216,6 +217,138 @@ fn call_without_path() { .unwrap(); } +// PROGRAM WITH PROCREF +// ================================================================================================ + +#[test] +fn procref_call() { + // instantiate assembler + let assembler = Assembler::default(); + + // compile first module + let module_path1 = LibraryPath::new("module::path::one").unwrap(); + let module_source1 = ModuleAst::parse( + " + export.aaa + push.7.8 + end + + export.foo + push.1.2 + end", + ) + .unwrap(); + + let _roots1 = assembler + .compile_module( + &module_source1, + Some(&module_path1), + &mut AssemblyContext::for_module(false), + ) + .unwrap(); + + // compile second module + let module_path2 = LibraryPath::new("module::path::two").unwrap(); + let module_source2 = ModuleAst::parse( + " + use.module::path::one + export.one::foo + + export.bar + procref.one::aaa + end", + ) + .unwrap(); + + let _roots2 = assembler + .compile_module( + &module_source2, + Some(&module_path2), + &mut AssemblyContext::for_module(false), + ) + .unwrap(); + + // compile program with procref calls + let program_source = ProgramAst::parse( + " + use.module::path::two + + proc.baz.4 + push.3.4 + end + + begin + procref.two::bar + procref.two::foo + procref.baz + end", + ) + .unwrap(); + + let _compiled_program = assembler + .compile_in_context( + &program_source, + &mut AssemblyContext::for_program(Some(&program_source)), + ) + .unwrap(); +} + +#[test] +fn get_proc_name_of_unknown_module() { + // Module `two` is unknown. This program should return + // `AssemblyError::imported_proc_module_not_found` error with `bar` procedure name. + let module_source1 = " + use.module::path::two + + export.foo + procref.two::bar + end"; + let module_ast1 = ModuleAst::parse(module_source1).unwrap(); + let module_path1 = LibraryPath::new("module::path::one").unwrap(); + let module1 = Module::new(module_path1, module_ast1); + + let masl_lib = MaslLibrary::new( + LibraryNamespace::new("module").unwrap(), + Version::default(), + false, + vec![module1], + vec![], + ) + .unwrap(); + + // instantiate assembler + let assembler = Assembler::default().with_library(&masl_lib).unwrap(); + + // compile program with procref calls + let program_source = ProgramAst::parse( + " + use.module::path::one + + begin + procref.one::foo + end", + ) + .unwrap(); + + let compilation_error = assembler + .compile_in_context( + &program_source, + &mut AssemblyContext::for_program(Some(&program_source)), + ) + .err() + .unwrap(); + + let expected_error = AssemblyError::imported_proc_module_not_found( + &crate::ProcedureId([ + 17, 137, 148, 17, 42, 108, 60, 23, 205, 115, 62, 70, 16, 121, 221, 142, 51, 247, 250, + 43, + ]), + ProcedureName::try_from("bar").ok(), + ); + + assert_eq!(compilation_error, expected_error); +} + // CONSTANTS // ================================================================================================ diff --git a/docs/src/user_docs/assembly/io_operations.md b/docs/src/user_docs/assembly/io_operations.md index f6f267e7b2..0f1f8000af 100644 --- a/docs/src/user_docs/assembly/io_operations.md +++ b/docs/src/user_docs/assembly/io_operations.md @@ -25,10 +25,11 @@ In both case the values must still encode valid field elements. | Instruction | Stack_input | Stack_output | Notes | | ------------------------------- | ------------ | ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| clk
- *(1 cycle)* | [ ... ] | [t, ... ] | $t \leftarrow clock\_value()$
Pushes the current value of the clock cycle counter onto the stack. | -| sdepth
- *(1 cycle)* | [ ... ] | [d, ... ] | $d \leftarrow stack.depth()$
Pushes the current depth of the stack onto the stack. | -| caller
- *(1 cycle)* | [A, b, ... ] | [H, b, ... ] | $H \leftarrow context.fn\_hash()$
Overwrites the top four stack items with the hash of a function which initiated the current SYSCALL.
Executing this instruction outside of SYSCALL context will fail. | -| locaddr.*i*
- *(2 cycles)* | [ ... ] | [a, ... ] | $a \leftarrow address\_of(i)$
Pushes the absolute memory address of local memory at index $i$ onto the stack. | +| clk
- *(1 cycle)* | [ ... ] | [t, ... ] | $t \leftarrow clock\_value()$
Pushes the current value of the clock cycle counter onto the stack. | +| sdepth
- *(1 cycle)* | [ ... ] | [d, ... ] | $d \leftarrow stack.depth()$
Pushes the current depth of the stack onto the stack. | +| caller
- *(1 cycle)* | [A, b, ... ] | [H, b, ... ] | $H \leftarrow context.fn\_hash()$
Overwrites the top four stack items with the hash of a function which initiated the current SYSCALL.
Executing this instruction outside of SYSCALL context will fail. | +| locaddr.*i*
- *(2 cycles)* | [ ... ] | [a, ... ] | $a \leftarrow address\_of(i)$
Pushes the absolute memory address of local memory at index $i$ onto the stack. | +| procref.*name*
- *(4 cycles)* | [ ... ] | [A, ... ] | $A \leftarrow mast\_root()$
Pushes MAST root of the procedure with name $name$ onto the stack. | ### Nondeterministic inputs diff --git a/miden/tests/integration/flow_control/mod.rs b/miden/tests/integration/flow_control/mod.rs index aec1e0c046..002e6a0297 100644 --- a/miden/tests/integration/flow_control/mod.rs +++ b/miden/tests/integration/flow_control/mod.rs @@ -1,4 +1,8 @@ +use assembly::{Assembler, AssemblyContext, LibraryPath}; +use miden::ModuleAst; +use stdlib::StdLibrary; use test_utils::{build_test, AdviceInputs, StackInputs, Test, TestError}; +use vm_core::StarkField; // SIMPLE FLOW CONTROL TESTS // ================================================================================================ @@ -352,3 +356,57 @@ fn simple_dyncall() { false, ); } + +// PROCREF INSTRUCTION +// ================================================================================================ + +#[test] +fn procref() { + let assembler = Assembler::default().with_library(&StdLibrary::default()).unwrap(); + + let module_source = " + use.std::math::u64 + export.u64::overflowing_add + + export.foo.4 + push.3.4 + end + "; + + // obtain procedures' MAST roots by compiling them as module + let module_ast = ModuleAst::parse(module_source).unwrap(); + let module_path = LibraryPath::new("test::foo").unwrap(); + let mast_roots = assembler + .compile_module(&module_ast, Some(&module_path), &mut AssemblyContext::for_module(false)) + .unwrap(); + + let source = " + use.std::math::u64 + + proc.foo.4 + push.3.4 + end + + begin + procref.u64::overflowing_add + push.0 + procref.foo + end"; + + let mut test = build_test!(source, &[]); + test.libraries = vec![StdLibrary::default().into()]; + + test.expect_stack(&[ + mast_roots[1][3].as_int(), + mast_roots[1][2].as_int(), + mast_roots[1][1].as_int(), + mast_roots[1][0].as_int(), + 0, + mast_roots[0][3].as_int(), + mast_roots[0][2].as_int(), + mast_roots[0][1].as_int(), + mast_roots[0][0].as_int(), + ]); + + test.prove_and_verify(vec![], false); +}