diff --git a/.github/ISSUE_TEMPLATE/1-bugreport.yml b/.github/ISSUE_TEMPLATE/1-bugreport.yml new file mode 100644 index 0000000000..f55d1abd00 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/1-bugreport.yml @@ -0,0 +1,38 @@ +name: "Bug Report" +description: "File a bug report" +labels: ["bug"] +body: + - type: markdown + attributes: + value: | + Thank you for taking the time to fill out this bug report! + - type: textarea + id: version + attributes: + label: "Packages versions" + description: "Let us know the versions of any other packages used. For example, which version of the VM are you using?" + placeholder: "miden-vm: 0.1.0" + validations: + required: true + - type: textarea + id: bug-description + attributes: + label: "Bug description" + description: "Describe the behavior you are experiencing." + placeholder: "Tell us what happened and what should have happened." + validations: + required: true + - type: textarea + id: reproduce-steps + attributes: + label: "How can this be reproduced?" + description: "If possible, describe how to replicate the unexpected behavior that you see." + placeholder: "Steps!" + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This is automatically formatted as code, no need for backticks. + render: shell diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml new file mode 100644 index 0000000000..24edfc9406 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -0,0 +1,20 @@ +name: "Feature request" +description: "Request new goodies" +labels: ["enhancement"] +body: + - type: markdown + attributes: + value: | + Thank you for taking the time to fill a feature request! + - type: textarea + id: scenario-why + attributes: + label: "Feature description" + validations: + required: true + - type: textarea + id: scenario-how + attributes: + label: "Why is this feature needed?" + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/3-task.yml b/.github/ISSUE_TEMPLATE/3-task.yml new file mode 100644 index 0000000000..3bc108578f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/3-task.yml @@ -0,0 +1,35 @@ +name: "Task" +description: "Work item" +body: + - type: markdown + attributes: + value: | + A task should be less than a week worth of work! + - type: textarea + id: task-what + attributes: + label: "What should be done?" + placeholder: "Impose restrictions on DYN and DYNCALL operation" + validations: + required: true + - type: textarea + id: task-how + attributes: + label: "How should it be done?" + placeholder: "Users should be able to specify whether DYN/DYNCALL operations are allowed in a given program" + validations: + required: true + - type: textarea + id: task-done + attributes: + label: "When is this task done?" + placeholder: "The task is done when users are able to specify whether DYN/DYNCALL operations are allowed in a given program" + validations: + required: true + - type: textarea + id: task-related + attributes: + label: "Additional context" + description: "Add context to the tasks. E.g. other related tasks or relevant discussions on PRs/chats." + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..0086358db1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index 53f8242a9f..0000000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,6 +0,0 @@ -version: 2 -updates: - - package-ecosystem: "cargo" - directory: "/" - schedule: - interval: "weekly" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md deleted file mode 100644 index ffd0f236fb..0000000000 --- a/.github/pull_request_template.md +++ /dev/null @@ -1,9 +0,0 @@ -## Describe your changes - - -## Checklist before requesting a review -- Repo forked and branch created from `next` according to naming convention. -- Commit messages and codestyle follow [conventions](./CONTRIBUTING.md). -- Relevant issues are linked in the PR description. -- Tests added for new functionality. -- Documentation/comments updated according to changes. diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d0385dbaa..d37a550c67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,43 @@ # Changelog -## 0.10.6 (2024-09-12) - `miden-processor` crate only. +## 0.11.0 (2024-11-04) + +#### Enhancements + +- Added `miden_core::utils::sync::racy_lock` module (#1463). +- Updated `miden_core::utils` to re-export `std::sync::LazyLock` and `racy_lock::RacyLock as LazyLock` for std and no_std environments, respectively (#1463). +- Debug instructions can be enabled in the cli `run` command using `--debug` flag (#1502). +- Added support for procedure annotation (attribute) syntax to Miden Assembly (#1510). +- Make `miden-prover::prove()` method conditionally asynchronous (#1563). + +#### Changes + +- [BREAKING] Wrapped `MastForest`s in `Program` and `Library` structs in `Arc` (#1465). +- `MastForestBuilder`: use `MastNodeId` instead of MAST root to uniquely identify procedures (#1473). +- Made the undocumented behavior of the VM with regard to undefined behavior of u32 operations, stricter (#1480). +- Introduced the `Emit` instruction (#1496). +- [BREAKING] ExecutionOptions::new constructor requires a boolean to explicitly set debug mode (#1502). +- [BREAKING] The `run` and the `prove` commands in the cli will accept `--trace` flag instead of `--tracing` (#1502). +- Migrated to new padding rule for RPO (#1343). +- Migrated to `miden-crypto` v0.11.0 (#1343). +- Implemented `MastForest` merging (#1534). +- Rename `EqHash` to `MastNodeFingerprint` and make it `pub` (#1539). +- Updated Winterfell dependency to v0.10 (#1533). +- [BREAKING] `DYN` operation now expects a memory address pointing to the procedure hash (#1535). +- [BREAKING] `DYNCALL` operation fixed, and now expects a memory address pointing to the procedure hash (#1535). +- Permit child `MastNodeId`s to exceed the `MastNodeId`s of their parents (#1542). +- Don't validate export names on `Library` deserialization (#1554) + +#### Fixes + +- Fixed an issue with formatting of blocks in Miden Assembly syntax +- Fixed the construction of the block hash table (#1506) +- Fixed a bug in the block stack table (#1511) (#1512) (#1557) +- Fixed the construction of the chiplets virtual table (#1514) (#1556) +- Fixed the construction of the chiplets bus (#1516) (#1525) +- Decorators are now allowed in empty basic blocks (#1466) + +## 0.10.6 (2024-09-12) - `miden-processor` crate only #### Enhancements @@ -42,6 +79,7 @@ - [BREAKING] Replaced `SourceManager` parameter with `Assembler` in `Library::from_dir` (#1445). - [BREAKING] Moved `Library` and `KernelLibrary` exports to the root of the `miden-assembly` crate. (#1445). +- [BREAKING] Depth of the input and output stack was restricted to 16 (#1456). ## 0.10.2 (2024-08-10) @@ -126,6 +164,8 @@ #### Stdlib - Added `init_no_padding` procedure to `std::crypto::hashes::native` (#1313). +- [BREAKING] `native` module was renamed to the `rpo`, `hash_memory` procedure was renamed to the `hash_memory_words` (#1368). +- Added `hash_memory` procedure to `std::crypto::hashes::rpo` (#1368). #### VM Internals diff --git a/Cargo.lock b/Cargo.lock index 9fd88568cf..0ecc546d00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,18 +4,18 @@ version = 3 [[package]] name = "addr2line" -version = "0.22.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "aho-corasick" @@ -43,9 +43,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.15" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -58,43 +58,43 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "anstyle-wincon" -version = "3.0.4" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" dependencies = [ "anstyle", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "arrayref" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" [[package]] name = "arrayvec" @@ -129,23 +129,23 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "backtrace" -version = "0.3.73" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -219,7 +219,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" dependencies = [ "memchr", - "regex-automata 0.4.7", + "regex-automata 0.4.8", "serde", ] @@ -243,9 +243,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.13" +version = "1.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72db2f7947ecee9b03b510377e8bb9077afa27176fdbff55c51027e976fdcc48" +checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9" dependencies = [ "jobserver", "libc", @@ -287,9 +287,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.16" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -297,9 +297,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.15" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -309,9 +309,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.13" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck", "proc-macro2", @@ -336,15 +336,15 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "constant_time_eq" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "core-foundation" @@ -375,9 +375,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" dependencies = [ "libc", ] @@ -556,15 +556,15 @@ dependencies = [ [[package]] name = "error-code" -version = "3.2.0" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0474425d51df81997e2f90a21591180b38eccf27292d755f3e30750225c175b" +checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" [[package]] name = "escargot" -version = "0.5.12" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c000f23e9d459aef148b7267e02b03b94a0aaacf4ec64c65612f67e02f525fb6" +checksum = "05a3ac187a16b5382fef8c69fd1bad123c67b7cf3932240a2d43dcdd32cded88" dependencies = [ "log", "once_cell", @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fixedbitset" @@ -628,9 +628,9 @@ checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -642,9 +642,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -652,33 +652,33 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", "futures-sink", @@ -689,9 +689,9 @@ dependencies = [ [[package]] name = "generator" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "979f00864edc7516466d6b3157706e06c032f22715700ddd878228a91d02bc56" +checksum = "dbb949699c3e4df3a183b1d2142cb24277057055ed23c68ed58894f76c517223" dependencies = [ "cfg-if", "libc", @@ -723,9 +723,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.29.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -745,9 +745,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" [[package]] name = "heck" @@ -775,9 +775,9 @@ checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" [[package]] name = "indexmap" -version = "2.4.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown", @@ -841,9 +841,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -870,7 +870,7 @@ dependencies = [ "lalrpop-util", "petgraph", "regex", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "string_cache", "term", "tiny-keccak", @@ -892,15 +892,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.158" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libredox" @@ -986,9 +986,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "metal" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" +checksum = "9c3572083504c43e14aec05447f8a3d57cce0f66d7a3c1b9058572eca4d70ab9" dependencies = [ "bitflags 2.6.0", "block", @@ -1001,7 +1001,7 @@ dependencies = [ [[package]] name = "miden-air" -version = "0.10.5" +version = "0.11.0" dependencies = [ "criterion", "miden-core", @@ -1014,7 +1014,7 @@ dependencies = [ [[package]] name = "miden-assembly" -version = "0.10.5" +version = "0.11.0" dependencies = [ "aho-corasick", "lalrpop", @@ -1024,15 +1024,15 @@ dependencies = [ "miden-thiserror", "pretty_assertions", "regex", - "rustc_version 0.4.0", + "rustc_version 0.4.1", "smallvec", "tracing", - "unicode-width", + "unicode-width 0.2.0", ] [[package]] name = "miden-core" -version = "0.10.5" +version = "0.11.0" dependencies = [ "lock_api", "loom", @@ -1052,9 +1052,9 @@ dependencies = [ [[package]] name = "miden-crypto" -version = "0.10.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6fad06fc3af260ed3c4235821daa2132813d993f96d446856036ae97e9606dd" +checksum = "f50a68deed96cde1f51eb623f75828e320f699e0d798f11592f8958ba8b512c3" dependencies = [ "blake3", "cc", @@ -1075,14 +1075,14 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e392e0a8c34b32671012b439de35fa8987bf14f0f8aac279b97f8b8cc6e263b" dependencies = [ - "unicode-width", + "unicode-width 0.1.14", ] [[package]] name = "miden-gpu" -version = "0.2.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade33603aa2eaf78c6f06fd60f4dfe22b7ae1f5606698e386baf71eb9d246d50" +checksum = "271d375ea8bfdb0995f30d27d5eccc1468326e502e6ad6bdb0f9ee2d91ade50c" dependencies = [ "metal", "once_cell", @@ -1117,7 +1117,7 @@ dependencies = [ "terminal_size", "textwrap", "trybuild", - "unicode-width", + "unicode-width 0.1.14", ] [[package]] @@ -1133,7 +1133,7 @@ dependencies = [ [[package]] name = "miden-processor" -version = "0.10.6" +version = "0.11.0" dependencies = [ "logtest", "miden-air", @@ -1148,7 +1148,7 @@ dependencies = [ [[package]] name = "miden-prover" -version = "0.10.5" +version = "0.11.0" dependencies = [ "elsa", "miden-air", @@ -1156,12 +1156,13 @@ dependencies = [ "miden-processor", "pollster", "tracing", + "winter-maybe-async", "winter-prover", ] [[package]] name = "miden-stdlib" -version = "0.10.5" +version = "0.11.0" dependencies = [ "blake3", "criterion", @@ -1219,7 +1220,7 @@ dependencies = [ [[package]] name = "miden-verifier" -version = "0.10.5" +version = "0.11.0" dependencies = [ "miden-air", "miden-core", @@ -1229,7 +1230,7 @@ dependencies = [ [[package]] name = "miden-vm" -version = "0.10.5" +version = "0.11.0" dependencies = [ "assert_cmd", "blake3", @@ -1259,11 +1260,11 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" dependencies = [ - "adler", + "adler2", ] [[package]] @@ -1391,32 +1392,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" dependencies = [ "malloc_buf", - "objc_exception", -] - -[[package]] -name = "objc_exception" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" -dependencies = [ - "cc", ] [[package]] name = "object" -version = "0.36.3" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" @@ -1432,9 +1423,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "owo-colors" -version = "4.0.0" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caff54706df99d2a78a5a4e3455ff45448d81ef1bb63c22cd14052ca0e993a3f" +checksum = "fb37767f6569cd834a413442455e0f066d0d522de8630436e2a1761d9726ba56" [[package]] name = "parking_lot" @@ -1486,9 +1477,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -1498,9 +1489,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "plotters" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" dependencies = [ "num-traits", "plotters-backend", @@ -1511,24 +1502,24 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" [[package]] name = "plotters-svg" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" dependencies = [ "plotters-backend", ] [[package]] name = "pollster" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22686f4785f02a4fcc856d3b3bb19bf6c8160d103f7a99cc258bddd0251dc7f2" +checksum = "2f3a9f18d041e6d0e102a0a46750538147e5e8992d3b4873aaafee2520b00ce3" [[package]] name = "ppv-lite86" @@ -1577,9 +1568,9 @@ dependencies = [ [[package]] name = "pretty_assertions" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7cee1a6c8a5b9208b3cb1061f10c0cb689087b3d8ce85fb9d2dd7a29b6ba66" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" dependencies = [ "diff", "yansi", @@ -1587,9 +1578,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] @@ -1608,7 +1599,7 @@ dependencies = [ "rand", "rand_chacha", "rand_xorshift", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", "rusty-fork", "tempfile", "unarray", @@ -1622,9 +1613,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -1690,9 +1681,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ "bitflags 2.6.0", ] @@ -1710,14 +1701,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.7", - "regex-syntax 0.8.4", + "regex-automata 0.4.8", + "regex-syntax 0.8.5", ] [[package]] @@ -1731,13 +1722,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.4", + "regex-syntax 0.8.5", ] [[package]] @@ -1748,9 +1739,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rustc-demangle" @@ -1769,18 +1760,18 @@ dependencies = [ [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver 1.0.23", ] [[package]] name = "rustix" -version = "0.38.34" +version = "0.38.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" dependencies = [ "bitflags 2.6.0", "errno", @@ -1791,9 +1782,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "rusty-fork" @@ -1821,7 +1812,7 @@ dependencies = [ "memchr", "nix", "unicode-segmentation", - "unicode-width", + "unicode-width 0.1.14", "utf8parse", "winapi", ] @@ -1876,18 +1867,18 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.208" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff085d2cb684faa248efb494c39b68e522822ac0de72ccf08109abde717cfb2" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.208" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24008e81ff7613ed8e5ba0cfaf24e2c2f1e5b8a0495711e44fcd4882fca62bcf" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", @@ -1896,9 +1887,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.125" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -1908,9 +1899,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" +checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" dependencies = [ "serde", ] @@ -2011,9 +2002,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "supports-color" -version = "3.0.0" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9829b314621dfc575df4e409e79f9d6a66a3bd707ab73f23cb4aa3a854ac854f" +checksum = "8775305acf21c96926c900ad056abeef436701108518cf890020387236ac5a77" dependencies = [ "is_ci", ] @@ -2032,20 +2023,26 @@ checksum = "b7401a30af6cb5818bb64852270bb722533397edcfc7344954a38f420819ece2" [[package]] name = "syn" -version = "2.0.75" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "target-triple" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" + [[package]] name = "tempfile" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", "fastrand", @@ -2131,23 +2128,23 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" dependencies = [ "smawk", "unicode-linebreak", - "unicode-width", + "unicode-width 0.1.14", ] [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" dependencies = [ "proc-macro2", "quote", @@ -2206,9 +2203,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.20" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap", "serde", @@ -2293,15 +2290,16 @@ dependencies = [ [[package]] name = "trybuild" -version = "1.0.99" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "207aa50d36c4be8d8c6ea829478be44a372c6a77669937bb39c698e52f1491e8" +checksum = "8dcd332a5496c026f1e14b7f3d2b7bd98e509660c04239c58b0ba38a12daded4" dependencies = [ "dissimilar", "glob", "serde", "serde_derive", "serde_json", + "target-triple", "termcolor", "toml", ] @@ -2320,9 +2318,9 @@ checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-linebreak" @@ -2332,21 +2330,27 @@ checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.13" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-width" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unicode-xid" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "utf8parse" @@ -2362,9 +2366,9 @@ checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" [[package]] name = "value-bag" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101" +checksum = "3ef4c4aa54d5d05a279399bfa921ec387b7aba77caf7a682ae8d86785b8fdad2" [[package]] name = "version_check" @@ -2419,9 +2423,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -2430,9 +2434,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -2445,9 +2449,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2455,9 +2459,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -2468,15 +2472,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -2727,18 +2731,18 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.18" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] [[package]] name = "winter-air" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b72f12b88ebb060b52c0e9aece9bb64a9fc38daf7ba689dd5ce63271b456c883" +checksum = "29bec0b06b741543f43e3a6677b95b200d4cad2daab76e6721e14345345bfd0e" dependencies = [ "libm", "winter-crypto", @@ -2749,9 +2753,9 @@ dependencies = [ [[package]] name = "winter-crypto" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00fbb724d2d9fbfd3aa16ea27f5e461d4fe1d74b0c9e0ed1bf79e9e2a955f4d5" +checksum = "163da45f1d4d65cac361b8df4835a6daa95b3399154e16eb0305c178c6f6c1f4" dependencies = [ "blake3", "sha3", @@ -2761,9 +2765,9 @@ dependencies = [ [[package]] name = "winter-fri" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab6077cf4c23c0411f591f4ba29378e27f26acb8cef3c51cadd93daaf6080b3" +checksum = "3b7b394670d68979a4cc21a37a95ef8ef350cf84be9256c53effe3052df50d26" dependencies = [ "winter-crypto", "winter-math", @@ -2772,29 +2776,28 @@ dependencies = [ [[package]] name = "winter-math" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "004f85bb051ce986ec0b9a2bd90aaf81b83e3c67464becfdf7db31f14c1019ba" +checksum = "5a8ba832121679e79b004b0003018c85873956d742a39c348c247f680fe15e00" dependencies = [ "winter-utils", ] [[package]] name = "winter-maybe-async" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ce0f4161cdde50de809b3869c1cb083a09e92e949428ea28f04c0d64045875c" +checksum = "be43529f43f70306437d2c2c9f9e2b3a4d39b42e86702d8d7577f2357ea32fa6" dependencies = [ - "proc-macro2", "quote", "syn", ] [[package]] name = "winter-prover" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f17e3dbae97050f58e01ed4f12906e247841575a0518632e052941a1c37468df" +checksum = "2f55f0153d26691caaf969066a13a824bcf3c98719d71b0f569bf8dc40a06fb9" dependencies = [ "tracing", "winter-air", @@ -2807,9 +2810,9 @@ dependencies = [ [[package]] name = "winter-rand-utils" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b827c901ab0c316d89812858ff451d60855c0a5c7ae734b098c62a28624181" +checksum = "4a7616d11fcc26552dada45c803a884ac97c253218835b83a2c63e1c2a988639" dependencies = [ "rand", "winter-utils", @@ -2817,18 +2820,18 @@ dependencies = [ [[package]] name = "winter-utils" -version = "0.9.1" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0568612a95bcae3c94fb14da2686f8279ca77723dbdf1e97cf3673798faf6485" +checksum = "76b116c8ade0172506f8bda32dc674cf6b230adc8516e5138a0173ae69158a4f" dependencies = [ "rayon", ] [[package]] name = "winter-verifier" -version = "0.9.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324002ade90f21e85599d51a232a80781efc8cb46f511f8bc89f9c5a4eb9cb65" +checksum = "2ae1648768f96f5e6321a48a5bff5cc3101d2e51b23a6a095c6c9c9e133ecb61" dependencies = [ "winter-air", "winter-crypto", @@ -2839,9 +2842,9 @@ dependencies = [ [[package]] name = "yansi" -version = "0.5.1" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "zerocopy" diff --git a/Cargo.toml b/Cargo.toml index ac69f241ed..30d5cdbdff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ resolver = "2" [workspace.package] edition = "2021" -rust-version = "1.80" +rust-version = "1.82" license = "MIT" readme = "README.md" authors = ["Miden contributors"] diff --git a/Makefile b/Makefile index e46a8f8c74..7a60dc155a 100644 --- a/Makefile +++ b/Makefile @@ -11,17 +11,18 @@ DEBUG_ASSERTIONS=RUSTFLAGS="-C debug-assertions" FEATURES_CONCURRENT_EXEC=--features concurrent,executable FEATURES_LOG_TREE=--features concurrent,executable,tracing-forest FEATURES_METAL_EXEC=--features concurrent,executable,metal +ALL_FEATURES_BUT_ASYNC=--features concurrent,executable,metal,testing,with-debug-info # -- linting -------------------------------------------------------------------------------------- .PHONY: clippy clippy: ## Runs Clippy with configs - cargo +nightly clippy --workspace --all-targets --all-features -- -D warnings + cargo +nightly clippy --workspace --all-targets ${ALL_FEATURES_BUT_ASYNC} -- -D warnings .PHONY: fix fix: ## Runs Fix with configs - cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features + cargo +nightly fix --allow-staged --allow-dirty --all-targets ${ALL_FEATURES_BUT_ASYNC} .PHONY: format @@ -41,7 +42,7 @@ lint: format fix clippy ## Runs all linting tasks at once (Clippy, fixing, forma .PHONY: doc doc: ## Generates & checks documentation - $(WARNINGS) cargo doc --all-features --keep-going --release + $(WARNINGS) cargo doc ${ALL_FEATURES_BUT_ASYNC} --keep-going --release .PHONY: mdbook mdbook: ## Generates mdbook documentation @@ -65,11 +66,15 @@ test-skip-proptests: ## Runs all tests, except property-based tests test-loom: ## Runs all loom-based tests RUSTFLAGS="--cfg loom" cargo nextest run --cargo-profile test-release --features testing -E 'test(#*loom)' +.PHONY: test-package +test-package: ## Tests specific package: make test-package package=miden-vm + $(DEBUG_ASSERTIONS) cargo nextest run --cargo-profile test-release --features testing -p $(package) + # --- checking ------------------------------------------------------------------------------------ .PHONY: check check: ## Checks all targets and features for errors without code generation - cargo check --all-targets --all-features + cargo check --all-targets ${ALL_FEATURES_BUT_ASYNC} # --- building ------------------------------------------------------------------------------------ diff --git a/README.md b/README.md index 42a6dbbcd5..18a9782b2d 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,15 @@ [![LICENSE](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/0xPolygonMiden/miden-vm/blob/main/LICENSE) [![Test](https://github.com/0xPolygonMiden/miden-vm/actions/workflows/test.yml/badge.svg)](https://github.com/0xPolygonMiden/miden-vm/actions/workflows/test.yml) [![Build](https://github.com/0xPolygonMiden/miden-vm/actions/workflows/build.yml/badge.svg)](https://github.com/0xPolygonMiden/miden-vm/actions/workflows/build.yml) -[![RUST_VERSION](https://img.shields.io/badge/rustc-1.80+-lightgray.svg)](https://www.rust-lang.org/tools/install) +[![RUST_VERSION](https://img.shields.io/badge/rustc-1.82+-lightgray.svg)](https://www.rust-lang.org/tools/install) [![Crates.io](https://img.shields.io/crates/v/miden-vm)](https://crates.io/crates/miden-vm) A STARK-based virtual machine. **WARNING:** This project is in an alpha stage. It has not been audited and may contain bugs and security flaws. This implementation is NOT ready for production use. +**WARNING:** For `no_std`, only the `wasm32-unknown-unknown` and `wasm32-wasip1` targets are officially supported. + ## Overview Miden VM is a zero-knowledge virtual machine written in Rust. For any program executed on Miden VM, a STARK-based proof of execution is automatically generated. This proof can then be used by anyone to verify that the program was executed correctly without the need for re-executing the program or even knowing the contents of the program. @@ -20,7 +22,7 @@ Miden VM is a zero-knowledge virtual machine written in Rust. For any program ex ### Status and features -Miden VM is currently on release v0.10. In this release, most of the core features of the VM have been stabilized, and most of the STARK proof generation has been implemented. While we expect to keep making changes to the VM internals, the external interfaces should remain relatively stable, and we will do our best to minimize the amount of breaking changes going forward. +Miden VM is currently on release v0.11. In this release, most of the core features of the VM have been stabilized, and most of the STARK proof generation has been implemented. While we expect to keep making changes to the VM internals, the external interfaces should remain relatively stable, and we will do our best to minimize the amount of breaking changes going forward. The next version of the VM is being developed in the [next](https://github.com/0xPolygonMiden/miden-vm/tree/next) branch. There is also a documentation for the latest features and changes in the next branch [documentation next branch](https://0xpolygonmiden.github.io/miden-vm/intro/main.html). diff --git a/air/Cargo.toml b/air/Cargo.toml index 23c1937131..ad70f6b999 100644 --- a/air/Cargo.toml +++ b/air/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-air" -version = "0.10.5" +version = "0.11.0" description = "Algebraic intermediate representation of Miden VM processor" -documentation = "https://docs.rs/miden-air/0.10.5" +documentation = "https://docs.rs/miden-air/0.11.0" readme = "README.md" categories = ["cryptography", "no-std"] keywords = ["air", "arithmetization", "crypto", "miden"] @@ -32,11 +32,11 @@ testing = [] [dependencies] thiserror = { package = "miden-thiserror", version = "1.0", default-features = false } -vm-core = { package = "miden-core", path = "../core", version = "0.10", default-features = false } -winter-air = { package = "winter-air", version = "0.9", default-features = false } -winter-prover = { package = "winter-prover", version = "0.9", default-features = false } +vm-core = { package = "miden-core", path = "../core", version = "0.11", default-features = false } +winter-air = { package = "winter-air", version = "0.10", default-features = false } +winter-prover = { package = "winter-prover", version = "0.10", default-features = false } [dev-dependencies] criterion = "0.5" -proptest = "1.3" -rand-utils = { package = "winter-rand-utils", version = "0.9" } +proptest = "1.5" +rand-utils = { package = "winter-rand-utils", version = "0.10" } diff --git a/air/src/constraints/stack/mod.rs b/air/src/constraints/stack/mod.rs index 7cc021824c..367b334acb 100644 --- a/air/src/constraints/stack/mod.rs +++ b/air/src/constraints/stack/mod.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use vm_core::{stack::STACK_TOP_SIZE, StackOutputs}; +use vm_core::{stack::MIN_STACK_DEPTH, StackOutputs}; use super::super::{ Assertion, EvaluationFrame, Felt, FieldElement, TransitionConstraintDegree, CLK_COL_IDX, @@ -22,7 +22,7 @@ pub mod u32_ops; // CONSTANTS // ================================================================================================ -const B0_COL_IDX: usize = STACK_TRACE_OFFSET + STACK_TOP_SIZE; +const B0_COL_IDX: usize = STACK_TRACE_OFFSET + MIN_STACK_DEPTH; const B1_COL_IDX: usize = B0_COL_IDX + 1; const H0_COL_IDX: usize = B1_COL_IDX + 1; @@ -30,7 +30,7 @@ const H0_COL_IDX: usize = B1_COL_IDX + 1; /// The number of boundary constraints required by the Stack, which is all stack positions for /// inputs and outputs as well as the initial values of the bookkeeping columns. -pub const NUM_ASSERTIONS: usize = 2 * STACK_TOP_SIZE + 2; +pub const NUM_ASSERTIONS: usize = 2 * MIN_STACK_DEPTH + 2; /// The number of general constraints in the stack operations. pub const NUM_GENERAL_CONSTRAINTS: usize = 17; @@ -193,19 +193,19 @@ pub fn enforce_general_constraints( /// Returns the stack's boundary assertions for the main trace at the first step. pub fn get_assertions_first_step(result: &mut Vec>, stack_inputs: &[Felt]) { // stack columns at the first step should be set to stack inputs, excluding overflow inputs. - for (i, &value) in stack_inputs.iter().take(STACK_TOP_SIZE).enumerate() { + for (i, &value) in stack_inputs.iter().take(MIN_STACK_DEPTH).enumerate() { result.push(Assertion::single(STACK_TRACE_OFFSET + i, 0, value)); } // if there are remaining slots on top of the stack without specified values, set them to ZERO. - for i in stack_inputs.len()..STACK_TOP_SIZE { + for i in stack_inputs.len()..MIN_STACK_DEPTH { result.push(Assertion::single(STACK_TRACE_OFFSET + i, 0, ZERO)); } // get the initial values for the bookkeeping columns. - let mut depth = STACK_TOP_SIZE; + let mut depth = MIN_STACK_DEPTH; let mut overflow_addr = ZERO; - if stack_inputs.len() > STACK_TOP_SIZE { + if stack_inputs.len() > MIN_STACK_DEPTH { depth = stack_inputs.len(); overflow_addr = -ONE; } @@ -225,7 +225,7 @@ pub fn get_assertions_last_step( stack_outputs: &StackOutputs, ) { // stack columns at the last step should be set to stack outputs, excluding overflow outputs - for (i, value) in stack_outputs.stack_top().iter().enumerate() { + for (i, value) in stack_outputs.iter().enumerate() { result.push(Assertion::single(STACK_TRACE_OFFSET + i, step, *value)); } } @@ -233,92 +233,19 @@ pub fn get_assertions_last_step( // --- AUXILIARY COLUMNS -------------------------------------------------------------------------- /// Returns the stack's boundary assertions for auxiliary columns at the first step. -pub fn get_aux_assertions_first_step( - result: &mut Vec>, - alphas: &[E], - stack_inputs: &[Felt], -) where - E: FieldElement, -{ - let step = 0; - let value = if stack_inputs.len() > STACK_TOP_SIZE { - get_overflow_table_init(alphas, &stack_inputs[STACK_TOP_SIZE..]) - } else { - E::ONE - }; - - result.push(Assertion::single(STACK_AUX_TRACE_OFFSET, step, value)); -} - -/// Returns the stack's boundary assertions for auxiliary columns at the last step. -pub fn get_aux_assertions_last_step( - result: &mut Vec>, - alphas: &[E], - stack_outputs: &StackOutputs, - step: usize, -) where - E: FieldElement, -{ - let value = if stack_outputs.has_overflow() { - get_overflow_table_final(alphas, stack_outputs) - } else { - E::ONE - }; - - result.push(Assertion::single(STACK_AUX_TRACE_OFFSET, step, value)); -} - -// BOUNDARY CONSTRAINT HELPERS -// ================================================================================================ - -// --- AUX TRACE ---------------------------------------------------------------------------------- - -/// Gets the initial value of the overflow table auxiliary column from the provided sets of initial -/// values and random elements. -fn get_overflow_table_init(alphas: &[E], init_values: &[Felt]) -> E +pub fn get_aux_assertions_first_step(result: &mut Vec>) where E: FieldElement, { - let mut value = E::ONE; - let mut prev_clk = ZERO; - let mut clk = -Felt::from(init_values.len() as u32); - - // the values are in the overflow table in reverse order, since the deepest stack - // value is added to the overflow table first. - for &input in init_values.iter().rev() { - value *= alphas[0] - + alphas[1].mul_base(clk) - + alphas[2].mul_base(input) - + alphas[3].mul_base(prev_clk); - prev_clk = clk; - clk += ONE; - } - - value + result.push(Assertion::single(STACK_AUX_TRACE_OFFSET, 0, E::ONE)); } -/// Gets the final value of the overflow table auxiliary column from the provided program outputs -/// and random elements. -fn get_overflow_table_final(alphas: &[E], stack_outputs: &StackOutputs) -> E +/// Returns the stack's boundary assertions for auxiliary columns at the last step. +pub fn get_aux_assertions_last_step(result: &mut Vec>, step: usize) where E: FieldElement, { - let mut value = E::ONE; - - // When the overflow table is non-empty, we expect at least 2 addresses (the `prev` value of - // the first row and the address value(s) of the row(s)) and more than STACK_TOP_SIZE - // elements in the stack. - let mut prev = stack_outputs.overflow_prev(); - for (clk, val) in stack_outputs.stack_overflow() { - value *= alphas[0] - + alphas[1].mul_base(clk) - + alphas[2].mul_base(val) - + alphas[3].mul_base(prev); - - prev = clk; - } - - value + result.push(Assertion::single(STACK_AUX_TRACE_OFFSET, step, E::ONE)); } // STACK OPERATION EXTENSION TRAIT @@ -362,9 +289,10 @@ trait EvaluationFrameExt { /// Gets the current value of user op helper register located at the specified index. fn user_op_helper(&self, index: usize) -> E; - /// Returns the value if the `h6` helper register in the decoder which is set to ONE if the - /// ending block is a `CALL` block. - fn is_call_end(&self) -> E; + /// Returns ONE if the block being `END`ed is a `CALL` or `DYNCALL`, or ZERO otherwise. + /// + /// This must only be used when an `END` operation is being executed. + fn is_call_or_dyncall_end(&self) -> E; /// Returns the value if the `h7` helper register in the decoder which is set to ONE if the /// ending block is a `SYSCALL` block. @@ -432,7 +360,7 @@ impl EvaluationFrameExt for &EvaluationFrame { } #[inline] - fn is_call_end(&self) -> E { + fn is_call_or_dyncall_end(&self) -> E { self.current()[DECODER_TRACE_OFFSET + IS_CALL_FLAG_COL_IDX] } diff --git a/air/src/constraints/stack/op_flags/mod.rs b/air/src/constraints/stack/op_flags/mod.rs index 7913ed6738..caa8302c15 100644 --- a/air/src/constraints/stack/op_flags/mod.rs +++ b/air/src/constraints/stack/op_flags/mod.rs @@ -267,7 +267,7 @@ impl OpFlags { // degree 6 flags do not use the first two bits (op_bits[0], op_bits[1]) degree4_op_flags[0] = not_2_not_3; // MRUPDATE - degree4_op_flags[1] = yes_2_not_3; // PUSH + degree4_op_flags[1] = yes_2_not_3; // (unused) degree4_op_flags[2] = not_2_yes_3; // SYSCALL degree4_op_flags[3] = yes_2_yes_3; // CALL @@ -292,6 +292,7 @@ impl OpFlags { + degree5_op_flags[1] // MPVERIFY + degree5_op_flags[6] // SPAN + degree5_op_flags[7] // JOIN + + degree5_op_flags[10] // EMIT + degree4_op_flags[6] // RESPAN + degree4_op_flags[7] // HALT + degree4_op_flags[3] // CALL @@ -347,7 +348,9 @@ impl OpFlags { + degree7_op_flags[47] + degree7_op_flags[46] + split_loop_flag - + shift_left_on_end; + + shift_left_on_end + + degree5_op_flags[8] // DYN + + degree5_op_flags[12]; // DYNCALL left_shift_flags[2] = left_shift_flags[1] + left_change_1_flag; left_shift_flags[3] = @@ -375,7 +378,7 @@ impl OpFlags { + degree7_op_flags[22] + degree7_op_flags[26]; - right_shift_flags[0] = f011 + degree4_op_flags[1] + movupn_flag; + right_shift_flags[0] = f011 + degree5_op_flags[11] + movupn_flag; // degree 5: PUSH right_shift_flags[1] = right_shift_flags[0] + degree6_op_flags[4]; // degree 6: U32SPLIT @@ -395,11 +398,17 @@ impl OpFlags { right_shift_flags[15] = right_shift_flags[8]; // Flag if the stack has been shifted to the right. - let right_shift = f011 + degree4_op_flags[1] + degree6_op_flags[4]; // PUSH; U32SPLIT + let right_shift = f011 + degree5_op_flags[11] + degree6_op_flags[4]; // PUSH; U32SPLIT - // Flag if the stack has been shifted to the left. - let left_shift = - f010 + add3_madd_flag + split_loop_flag + degree4_op_flags[5] + shift_left_on_end; + // Flag if the stack has been shifted to the left. Note that `DYNCALL` is not included in + // this flag even if it shifts the stack to the left. See `Opflags::left_shift()` for more + // information. + let left_shift = f010 + + add3_madd_flag + + split_loop_flag + + degree4_op_flags[5] + + shift_left_on_end + + degree5_op_flags[8]; // DYN // Flag if the current operation being executed is a control flow operation. // first row: SPAN, JOIN, SPLIT, LOOP @@ -907,7 +916,7 @@ impl OpFlags { /// Operation Flag of PUSH operation. #[inline(always)] pub fn push(&self) -> E { - self.degree4_op_flags[get_op_index(Operation::Push(ONE).op_code())] + self.degree5_op_flags[get_op_index(Operation::Push(ONE).op_code())] } /// Operation Flag of CALL operation. @@ -922,6 +931,12 @@ impl OpFlags { self.degree4_op_flags[get_op_index(Operation::SysCall.op_code())] } + /// Operation Flag of DYNCALL operation. + #[inline(always)] + pub fn dyncall(&self) -> E { + self.degree5_op_flags[get_op_index(Operation::Dyncall.op_code())] + } + /// Operation Flag of END operation. #[inline(always)] pub fn end(&self) -> E { @@ -981,6 +996,11 @@ impl OpFlags { } /// Returns the flag when the stack operation shifts the flag to the left. + /// + /// Note that although `DYNCALL` shifts the entire stack, it is not included in this flag. This + /// is because this "aggregate left shift" flag is used in constraints related to the stack + /// helper columns, and `DYNCALL` uses them unconventionally. + /// /// Degree: 5 #[inline(always)] pub fn left_shift(&self) -> E { diff --git a/air/src/constraints/stack/op_flags/tests.rs b/air/src/constraints/stack/op_flags/tests.rs index bfbc594f0f..d4d9a469d0 100644 --- a/air/src/constraints/stack/op_flags/tests.rs +++ b/air/src/constraints/stack/op_flags/tests.rs @@ -144,7 +144,8 @@ fn degree_4_op_flags() { fn composite_flags() { // ------ no change 0 --------------------------------------------------------------------- - let op_no_change_0 = [Operation::MpVerify(0), Operation::Span, Operation::Halt]; + let op_no_change_0 = + [Operation::MpVerify(0), Operation::Span, Operation::Halt, Operation::Emit(42)]; for op in op_no_change_0 { // frame initialised with an op operation. let frame = generate_evaluation_frame(op.op_code().into()); @@ -168,7 +169,7 @@ fn composite_flags() { assert_eq!(op_flags.left_shift(), ZERO); assert_eq!(op_flags.top_binary(), ZERO); - if op == Operation::MpVerify(0) { + if op == Operation::MpVerify(0) || op == Operation::Emit(42) { assert_eq!(op_flags.control_flow(), ZERO); } else if op == Operation::Span || op == Operation::Halt { assert_eq!(op_flags.control_flow(), ONE); diff --git a/air/src/constraints/stack/overflow/mod.rs b/air/src/constraints/stack/overflow/mod.rs index 5e055eedd5..6cac64a6c9 100644 --- a/air/src/constraints/stack/overflow/mod.rs +++ b/air/src/constraints/stack/overflow/mod.rs @@ -65,7 +65,8 @@ pub fn enforce_constraints( /// - If the operation is a left shift op, then, depth should be decreased by 1 provided the /// existing depth of the stack is not 16. In the case of depth being 16, depth will not be /// updated. -/// - If the current op being executed is `CALL` or `SYSCALL`, then the depth should be reset to 16. +/// - If the current op being executed is `CALL`, `SYSCALL` or `DYNCALL`, then the depth should be +/// reset to 16. /// /// TODO: This skips the operation when `END` is exiting for a `CALL` or a `SYSCALL` block. It /// should be handled later in multiset constraints. @@ -77,13 +78,15 @@ pub fn enforce_stack_depth_constraints( let depth = frame.stack_depth(); let depth_next = frame.stack_depth_next(); - let call_or_syscall = op_flag.call() + op_flag.syscall(); - let call_or_syscall_end = op_flag.end() * (frame.is_call_end() + frame.is_syscall_end()); + let call_or_dyncall_or_syscall = op_flag.call() + op_flag.dyncall() + op_flag.syscall(); + let call_or_dyncall_or_syscall_end = + op_flag.end() * (frame.is_call_or_dyncall_end() + frame.is_syscall_end()); - let no_shift_part = (depth_next - depth) * (E::ONE - call_or_syscall - call_or_syscall_end); + let no_shift_part = (depth_next - depth) + * (E::ONE - call_or_dyncall_or_syscall - call_or_dyncall_or_syscall_end); let left_shift_part = op_flag.left_shift() * op_flag.overflow(); let right_shift_part = op_flag.right_shift(); - let call_part = call_or_syscall * (depth_next - E::from(16u32)); + let call_part = call_or_dyncall_or_syscall * (depth_next - E::from(16u32)); // Enforces constraints of the transition of depth of the stack. result[0] = no_shift_part + left_shift_part - right_shift_part + call_part; diff --git a/air/src/constraints/stack/u32_ops/mod.rs b/air/src/constraints/stack/u32_ops/mod.rs index e4190c3c8e..e4d0c3fe15 100644 --- a/air/src/constraints/stack/u32_ops/mod.rs +++ b/air/src/constraints/stack/u32_ops/mod.rs @@ -232,8 +232,10 @@ pub fn enforce_u32madd_constraints>( 1 } -/// Enforces constraints of the U32DIV operation. The U32DIV operation divides the second element -/// with the first element in the current trace. Therefore, the following constraints are enforced: +/// Enforces constraints of the U32DIV operation. +/// +/// The U32DIV operation divides the second element with the first element in the current trace. +/// Therefore, the following constraints are enforced: /// - The second element in the current trace should be equal to the sum of the first element in the /// next trace with the product of the first element in the current trace and second element in /// the next trace. diff --git a/air/src/lib.rs b/air/src/lib.rs index b126c07a3a..da934e4d28 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -41,7 +41,9 @@ pub use vm_core::{ utils::{DeserializationError, ToElements}, Felt, FieldElement, StarkField, }; -pub use winter_air::{AuxRandElements, FieldExtension, LagrangeKernelEvaluationFrame}; +pub use winter_air::{ + AuxRandElements, FieldExtension, LagrangeKernelEvaluationFrame, PartitionOptions, +}; // PROCESSOR AIR // ================================================================================================ @@ -146,7 +148,7 @@ impl Air for ProcessorAir { result.push(Assertion::single(FMP_COL_IDX, 0, Felt::new(2u64.pow(30)))); // add initial assertions for the stack. - stack::get_assertions_first_step(&mut result, self.stack_inputs.values()); + stack::get_assertions_first_step(&mut result, &*self.stack_inputs); // Add initial assertions for the range checker. range::get_assertions_first_step(&mut result); @@ -165,18 +167,14 @@ impl Air for ProcessorAir { fn get_aux_assertions>( &self, - aux_rand_elements: &[E], + _aux_rand_elements: &AuxRandElements, ) -> Vec> { let mut result: Vec> = Vec::new(); // --- set assertions for the first step -------------------------------------------------- // add initial assertions for the stack's auxiliary columns. - stack::get_aux_assertions_first_step( - &mut result, - aux_rand_elements, - self.stack_inputs.values(), - ); + stack::get_aux_assertions_first_step(&mut result); // Add initial assertions for the range checker's auxiliary columns. range::get_aux_assertions_first_step::(&mut result); @@ -185,12 +183,7 @@ impl Air for ProcessorAir { let last_step = self.last_step(); // add the stack's auxiliary column assertions for the last step. - stack::get_aux_assertions_last_step( - &mut result, - aux_rand_elements, - &self.stack_outputs, - last_step, - ); + stack::get_aux_assertions_last_step(&mut result, last_step); // Add the range checker's auxiliary column assertions for the last step. range::get_aux_assertions_last_step::(&mut result, last_step); @@ -239,14 +232,19 @@ impl Air for ProcessorAir { main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], - aux_rand_elements: &[E], + aux_rand_elements: &AuxRandElements, result: &mut [E], ) where F: FieldElement, E: FieldElement + ExtensionOf, { // --- range checker ---------------------------------------------------------------------- - range::enforce_aux_constraints::(main_frame, aux_frame, aux_rand_elements, result); + range::enforce_aux_constraints::( + main_frame, + aux_frame, + aux_rand_elements.rand_elements(), + result, + ); } fn context(&self) -> &AirContext { @@ -281,8 +279,8 @@ impl PublicInputs { impl vm_core::ToElements for PublicInputs { fn to_elements(&self) -> Vec { let mut result = self.program_info.to_elements(); - result.append(&mut self.stack_inputs.to_elements()); - result.append(&mut self.stack_outputs.to_elements()); + result.append(&mut self.stack_inputs.to_vec()); + result.append(&mut self.stack_outputs.to_vec()); result } } diff --git a/air/src/options.rs b/air/src/options.rs index f880563ccd..fbbbbb31ce 100644 --- a/air/src/options.rs +++ b/air/src/options.rs @@ -192,6 +192,7 @@ impl ExecutionOptions { max_cycles: Option, expected_cycles: u32, enable_tracing: bool, + enable_debugging: bool, ) -> Result { let max_cycles = max_cycles.unwrap_or(u32::MAX); if max_cycles < MIN_TRACE_LEN as u32 { @@ -209,7 +210,7 @@ impl ExecutionOptions { max_cycles, expected_cycles, enable_tracing, - enable_debugging: false, + enable_debugging, }) } diff --git a/air/src/trace/chiplets/memory.rs b/air/src/trace/chiplets/memory.rs index 8fa9648dae..6e531d30d1 100644 --- a/air/src/trace/chiplets/memory.rs +++ b/air/src/trace/chiplets/memory.rs @@ -11,8 +11,8 @@ pub const NUM_SELECTORS: usize = 2; /// Type for Memory trace selectors. /// -/// These selectors are used to define which operation and memory state update (init & read / copy -/// & read / write) is to be applied at a specific row of the memory execution trace. +/// These selectors are used to define which operation and memory state update (init & read / copy & +/// read / write) is to be applied at a specific row of the memory execution trace. pub type Selectors = [Felt; NUM_SELECTORS]; // --- OPERATION SELECTORS ------------------------------------------------------------------------ diff --git a/air/src/trace/decoder/mod.rs b/air/src/trace/decoder/mod.rs index 77a7531d7f..aaad96cc42 100644 --- a/air/src/trace/decoder/mod.rs +++ b/air/src/trace/decoder/mod.rs @@ -83,7 +83,7 @@ pub const IS_LOOP_BODY_FLAG_COL_IDX: usize = HASHER_STATE_RANGE.start + 4; /// Index of a flag column which indicates whether an ending block is a LOOP block. pub const IS_LOOP_FLAG_COL_IDX: usize = HASHER_STATE_RANGE.start + 5; -/// Index of a flag column which indicates whether an ending block is a CALL block. +/// Index of a flag column which indicates whether an ending block is a CALL or DYNCALL block. pub const IS_CALL_FLAG_COL_IDX: usize = HASHER_STATE_RANGE.start + 6; /// Index of a flag column which indicates whether an ending block is a SYSCALL block. diff --git a/air/src/trace/main_trace.rs b/air/src/trace/main_trace.rs index 79092a1b06..eebc8e2779 100644 --- a/air/src/trace/main_trace.rs +++ b/air/src/trace/main_trace.rs @@ -137,7 +137,7 @@ impl MainTrace { /// Returns a specific element from the hasher state at row i. pub fn decoder_hasher_state_element(&self, element: usize, i: RowIndex) -> Felt { - self.columns.get_column(DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + element)[i + 1] + self.columns.get_column(DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET + element)[i] } /// Returns the current function hash (i.e., root) at row i. @@ -240,6 +240,8 @@ impl MainTrace { ([e0, b3, b2, b1] == [ONE, ZERO, ONE, ZERO]) || // REPEAT ([b6, b5, b4, b3, b2, b1, b0] == [ONE, ONE, ONE, ZERO, ONE, ZERO, ZERO]) || + // DYN + ([b6, b5, b4, b3, b2, b1, b0] == [ONE, ZERO, ONE, ONE, ZERO, ZERO, ZERO]) || // END of a loop ([b6, b5, b4, b3, b2, b1, b0] == [ONE, ONE, ONE, ZERO, ZERO, ZERO, ZERO] && h5 == ONE) } @@ -259,8 +261,8 @@ impl MainTrace { [b6, b5, b4] == [ZERO, ONE, ONE]|| // u32SPLIT 100_1000 ([b6, b5, b4, b3, b2, b1, b0] == [ONE, ZERO, ZERO, ONE, ZERO, ZERO, ZERO]) || - // PUSH i.e., 110_0100 - ([b6, b5, b4, b3, b2, b1, b0] == [ONE, ONE, ZERO, ZERO, ONE, ZERO, ZERO]) + // PUSH i.e., 101_1011 + ([b6, b5, b4, b3, b2, b1, b0] == [ONE, ZERO, ONE, ONE, ZERO, ONE, ONE]) } // STACK COLUMNS diff --git a/air/src/trace/stack/mod.rs b/air/src/trace/stack/mod.rs index f4dbf4035d..1960933cc4 100644 --- a/air/src/trace/stack/mod.rs +++ b/air/src/trace/stack/mod.rs @@ -1,6 +1,6 @@ use core::ops::Range; -use vm_core::utils::range; +use vm_core::{stack::MIN_STACK_DEPTH, utils::range}; // CONSTANTS // ================================================================================================ @@ -8,12 +8,8 @@ use vm_core::utils::range; /// Index at which stack item columns start in the stack trace. pub const STACK_TOP_OFFSET: usize = 0; -/// The number of stack registers which can be accessed by the VM directly. This is also the -/// minimum stack depth enforced by the VM. -pub const STACK_TOP_SIZE: usize = 16; - /// Location of stack top items in the stack trace. -pub const STACK_TOP_RANGE: Range = range(STACK_TOP_OFFSET, STACK_TOP_SIZE); +pub const STACK_TOP_RANGE: Range = range(STACK_TOP_OFFSET, MIN_STACK_DEPTH); /// Number of bookkeeping and helper columns in the stack trace. pub const NUM_STACK_HELPER_COLS: usize = 3; diff --git a/assembly/Cargo.toml b/assembly/Cargo.toml index da461bc398..e72f780fb6 100644 --- a/assembly/Cargo.toml +++ b/assembly/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-assembly" -version = "0.10.5" +version = "0.11.0" description = "Miden VM assembly language" -documentation = "https://docs.rs/miden-assembly/0.10.5" +documentation = "https://docs.rs/miden-assembly/0.11.0" readme = "README.md" categories = ["compilers", "no-std"] keywords = ["assembler", "assembly", "language", "miden"] @@ -33,8 +33,8 @@ regex = { version = "1.10", optional = true, default-features = false, features smallvec = { version = "1.13", features = ["union", "const_generics", "const_new"] } thiserror = { package = "miden-thiserror", version = "1.0", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"] } -unicode-width = { version = "0.1", features = ["no_std"] } -vm-core = { package = "miden-core", path = "../core", version = "0.10", default-features = false, features = [ +unicode-width = { version = "0.2", features = ["no_std"] } +vm-core = { package = "miden-core", path = "../core", version = "0.11", default-features = false, features = [ "diagnostics", ] } diff --git a/assembly/src/assembler/basic_block_builder.rs b/assembly/src/assembler/basic_block_builder.rs index 5a6a05cabb..3246484d21 100644 --- a/assembly/src/assembler/basic_block_builder.rs +++ b/assembly/src/assembler/basic_block_builder.rs @@ -1,6 +1,9 @@ use alloc::{borrow::Borrow, string::ToString, vec::Vec}; -use vm_core::{mast::MastNodeId, AdviceInjector, AssemblyOp, Decorator, Operation}; +use vm_core::{ + mast::{DecoratorId, MastNodeId}, + AdviceInjector, AssemblyOp, Decorator, Operation, +}; use super::{mast_forest_builder::MastForestBuilder, BodyWrapper, DecoratorList, ProcedureContext}; use crate::{ast::Instruction, AssemblyError, Span}; @@ -17,36 +20,60 @@ use crate::{ast::Instruction, AssemblyError, Span}; /// The same basic block builder can be used to construct many blocks. It is expected that when the /// last basic block in a procedure's body is constructed [`Self::try_into_basic_block`] will be /// used. -#[derive(Default)] -pub struct BasicBlockBuilder { +#[derive(Debug)] +pub struct BasicBlockBuilder<'a> { ops: Vec, decorators: DecoratorList, epilogue: Vec, last_asmop_pos: usize, + mast_forest_builder: &'a mut MastForestBuilder, } /// Constructors -impl BasicBlockBuilder { +impl<'a> BasicBlockBuilder<'a> { /// Returns a new [`BasicBlockBuilder`] instantiated with the specified optional wrapper. /// /// If the wrapper is provided, the prologue of the wrapper is immediately appended to the /// vector of span operations. The epilogue of the wrapper is appended to the list of operations /// upon consumption of the builder via the [`Self::try_into_basic_block`] method. - pub(super) fn new(wrapper: Option) -> Self { + pub(super) fn new( + wrapper: Option, + mast_forest_builder: &'a mut MastForestBuilder, + ) -> Self { match wrapper { Some(wrapper) => Self { ops: wrapper.prologue, decorators: Vec::new(), epilogue: wrapper.epilogue, last_asmop_pos: 0, + mast_forest_builder, + }, + None => Self { + ops: Default::default(), + decorators: Default::default(), + epilogue: Default::default(), + last_asmop_pos: 0, + mast_forest_builder, }, - None => Self::default(), } } } +/// Accessors +impl BasicBlockBuilder<'_> { + /// Returns a reference to the internal [`MastForestBuilder`]. + pub fn mast_forest_builder(&self) -> &MastForestBuilder { + self.mast_forest_builder + } + + /// Returns a mutable reference to the internal [`MastForestBuilder`]. + pub fn mast_forest_builder_mut(&mut self) -> &mut MastForestBuilder { + self.mast_forest_builder + } +} + /// Operations -impl BasicBlockBuilder { +impl BasicBlockBuilder<'_> { /// Adds the specified operation to the list of basic block operations. pub fn push_op(&mut self, op: Operation) { self.ops.push(op); @@ -69,15 +96,18 @@ impl BasicBlockBuilder { } /// Decorators -impl BasicBlockBuilder { +impl BasicBlockBuilder<'_> { /// Add the specified decorator to the list of basic block decorators. - pub fn push_decorator(&mut self, decorator: Decorator) { - self.decorators.push((self.ops.len(), decorator)); + pub fn push_decorator(&mut self, decorator: Decorator) -> Result<(), AssemblyError> { + let decorator_id = self.mast_forest_builder.ensure_decorator(decorator)?; + self.decorators.push((self.ops.len(), decorator_id)); + + Ok(()) } /// Adds the specified advice injector to the list of basic block decorators. - pub fn push_advice_injector(&mut self, injector: AdviceInjector) { - self.push_decorator(Decorator::Advice(injector)); + pub fn push_advice_injector(&mut self, injector: AdviceInjector) -> Result<(), AssemblyError> { + self.push_decorator(Decorator::Advice(injector)) } /// Adds an AsmOp decorator to the list of basic block decorators. @@ -88,7 +118,7 @@ impl BasicBlockBuilder { &mut self, instruction: &Span, proc_ctx: &ProcedureContext, - ) { + ) -> Result<(), AssemblyError> { let span = instruction.span(); let location = proc_ctx.source_manager().location(span).ok(); let context_name = proc_ctx.name().to_string(); @@ -96,8 +126,10 @@ impl BasicBlockBuilder { let op = instruction.to_string(); let should_break = instruction.should_break(); let op = AssemblyOp::new(location, context_name, num_cycles, op, should_break); - self.push_decorator(Decorator::AsmOp(op)); + self.push_decorator(Decorator::AsmOp(op))?; self.last_asmop_pos = self.decorators.len() - 1; + + Ok(()) } /// Computes the number of cycles elapsed since the last invocation of track_instruction() @@ -108,8 +140,10 @@ impl BasicBlockBuilder { /// call, and syscall. pub fn set_instruction_cycle_count(&mut self) { // get the last asmop decorator and the cycle at which it was added - let (op_start, assembly_op) = + let (op_start, assembly_op_id) = self.decorators.get_mut(self.last_asmop_pos).expect("no asmop decorator"); + + let assembly_op = &mut self.mast_forest_builder[*assembly_op_id]; assert!(matches!(assembly_op, Decorator::AsmOp(_))); // compute the cycle count for the instruction @@ -125,45 +159,77 @@ impl BasicBlockBuilder { } /// Span Constructors -impl BasicBlockBuilder { - /// Creates and returns a new BASIC BLOCK node from the operations and decorators currently in - /// this builder. If the builder is empty, then no node is created and `None` is returned. +impl BasicBlockBuilder<'_> { + /// Creates and returns a new basic block node from the operations and decorators currently in + /// this builder. /// - /// This consumes all operations and decorators in the builder, but does not touch the - /// operations in the epilogue of the builder. - pub fn make_basic_block( - &mut self, - mast_forest_builder: &mut MastForestBuilder, - ) -> Result, AssemblyError> { + /// If there are no operations however, then no node is created, the decorators are left + /// untouched and `None` is returned. Use [`Self::drain_decorators`] to retrieve the decorators + /// in this case. + /// + /// This consumes all operations in the builder, but does not touch the operations in the + /// epilogue of the builder. + pub fn make_basic_block(&mut self) -> Result, AssemblyError> { if !self.ops.is_empty() { let ops = self.ops.drain(..).collect(); - let decorators = self.decorators.drain(..).collect(); + let decorators = if !self.decorators.is_empty() { + Some(self.decorators.drain(..).collect()) + } else { + None + }; - let basic_block_node_id = mast_forest_builder.ensure_block(ops, Some(decorators))?; + let basic_block_node_id = self.mast_forest_builder.ensure_block(ops, decorators)?; Ok(Some(basic_block_node_id)) - } else if !self.decorators.is_empty() { - // this is a bug in the assembler. we shouldn't have decorators added without their - // associated operations - // TODO: change this to an error or allow decorators in empty span blocks - unreachable!("decorators in an empty SPAN block") } else { Ok(None) } } - /// Creates and returns a new BASIC BLOCK node from the operations and decorators currently in - /// this builder. If the builder is empty, then no node is created and `None` is returned. + /// Creates and returns a new basic block node from the operations and decorators currently in + /// this builder. If there are no operations however, we return the decorators that were + /// accumulated up until this point. If the builder is empty, then no node is created and + /// `Nothing` is returned. /// - /// The main differences with [`Self::to_basic_block`] are: + /// The main differences with [`Self::make_basic_block`] are: /// - Operations contained in the epilogue of the builder are appended to the list of ops which /// go into the new BASIC BLOCK node. /// - The builder is consumed in the process. - pub fn try_into_basic_block( - mut self, - mast_forest_builder: &mut MastForestBuilder, - ) -> Result, AssemblyError> { + /// - Hence, any remaining decorators if no basic block was created are drained and returned. + pub fn try_into_basic_block(mut self) -> Result { self.ops.append(&mut self.epilogue); - self.make_basic_block(mast_forest_builder) + + if let Some(basic_block_node_id) = self.make_basic_block()? { + Ok(BasicBlockOrDecorators::BasicBlock(basic_block_node_id)) + } else if let Some(decorator_ids) = self.drain_decorators() { + Ok(BasicBlockOrDecorators::Decorators(decorator_ids)) + } else { + Ok(BasicBlockOrDecorators::Nothing) + } + } + + /// Drains and returns the decorators in the builder, if any. + /// + /// This should only be called after [`Self::make_basic_block`], when no blocks were created. + /// In other words, there MUST NOT be any operations left in the builder when this is called. + /// + /// # Panics + /// + /// Panics if there are still operations left in the builder. + pub fn drain_decorators(&mut self) -> Option> { + assert!(self.ops.is_empty()); + if !self.decorators.is_empty() { + Some(self.decorators.drain(..).map(|(_, decorator_id)| decorator_id).collect()) + } else { + None + } } } + +/// Holds either the node id of a basic block, or a list of decorators that are currently not +/// attached to any node. +pub enum BasicBlockOrDecorators { + BasicBlock(MastNodeId), + Decorators(Vec), + Nothing, +} diff --git a/assembly/src/assembler/instruction/adv_ops.rs b/assembly/src/assembler/instruction/adv_ops.rs index 2ec27fc8a8..86a38c1fa8 100644 --- a/assembly/src/assembler/instruction/adv_ops.rs +++ b/assembly/src/assembler/instruction/adv_ops.rs @@ -13,9 +13,9 @@ use crate::{ast::AdviceInjectorNode, AssemblyError, ADVICE_READ_LIMIT}; /// # Errors /// Returns an error if the specified number of values to pushed is smaller than 1 or greater /// than 16. -pub fn adv_push(span: &mut BasicBlockBuilder, n: u8) -> Result<(), AssemblyError> { +pub fn adv_push(block_builder: &mut BasicBlockBuilder, n: u8) -> Result<(), AssemblyError> { validate_param(n, 1..=ADVICE_READ_LIMIT)?; - span.push_op_many(Operation::AdvPop, n as usize); + block_builder.push_op_many(Operation::AdvPop, n as usize); Ok(()) } @@ -23,6 +23,9 @@ pub fn adv_push(span: &mut BasicBlockBuilder, n: u8) -> Result<(), AssemblyError // ================================================================================================ /// Appends advice injector decorator to the span. -pub fn adv_inject(span: &mut BasicBlockBuilder, injector: &AdviceInjectorNode) { - span.push_advice_injector(injector.into()); +pub fn adv_inject( + block_builder: &mut BasicBlockBuilder, + injector: &AdviceInjectorNode, +) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(injector.into()) } diff --git a/assembly/src/assembler/instruction/crypto_ops.rs b/assembly/src/assembler/instruction/crypto_ops.rs index 0b72975ead..54e6a2acff 100644 --- a/assembly/src/assembler/instruction/crypto_ops.rs +++ b/assembly/src/assembler/instruction/crypto_ops.rs @@ -1,6 +1,7 @@ -use vm_core::{AdviceInjector, Operation::*}; +use vm_core::{AdviceInjector, Felt, Operation::*}; use super::BasicBlockBuilder; +use crate::AssemblyError; // HASHING // ================================================================================================ @@ -15,26 +16,27 @@ use super::BasicBlockBuilder; /// To perform the operation we do the following: /// 1. Prepare the stack with 12 elements for HPERM by pushing 4 more elements for the capacity, /// then reordering the stack and pushing an additional 4 elements so that the stack looks like: -/// [0, 0, 0, 1, a3, a2, a1, a0, 0, 0, 0, 1, ...]. The first capacity element is set to ONE as -/// we are hashing a number of elements which is not a multiple of the rate width. We also set -/// the next element in the rate after `A` to ONE. All other capacity and rate elements are set -/// to ZERO, in accordance with the RPO rules. +/// [0, 0, 0, 0, a3, a2, a1, a0, 0, 0, 0, 4, ...]. The first capacity element is set to Felt(4) +/// as we are hashing a number of elements which is equal to 4 modulo the rate width, while the +/// other capacity elements are set to ZERO. A sequence of 4 ZERO elements is used as padding. +/// The padding rule used follows the one described in this [work](https://eprint.iacr.org/2023/1045). /// 2. Append the HPERM operation, which performs a permutation of RPO on the top 12 elements and /// leaves the an output of [D, C, B, ...] on the stack. C is our 1-to-1 has result. /// 3. Drop D and B to achieve our result [C, ...] /// /// This operation takes 20 VM cycles. -pub(super) fn hash(span: &mut BasicBlockBuilder) { +pub(super) fn hash(block_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ - // add 4 elements to the stack to be used as the capacity elements for the RPO permutation - Pad, Incr, Pad, Pad, Pad, + // add 4 elements to the stack to be used as the capacity elements for the RPO permutation. + // Since we are hashing 4 field elements, the first capacity element is set to 4. + Push(Felt::from(4_u32)), Pad, Pad, Pad, // swap capacity elements such that they are below the elements to be hashed SwapW, - // Duplicate capacity elements in the rate portion of the stack - Dup7, Dup7, Dup7, Dup7, + // add 4 ZERO elements for the second half of the rate portion + Pad, Dup7, Dup7, Dup7, // Apply a hashing permutation on the top 12 elements in the stack HPerm, @@ -48,7 +50,7 @@ pub(super) fn hash(span: &mut BasicBlockBuilder) { // Drop 4 elements (the capacity portion) Drop, Drop, Drop, Drop, ]; - span.push_ops(ops); + block_builder.push_ops(ops); } /// Appends HPERM and stack manipulation operations to the span block as required to compute a @@ -70,7 +72,7 @@ pub(super) fn hash(span: &mut BasicBlockBuilder) { /// 4. Drop F and D to return our result [E, ...]. /// /// This operation takes 16 VM cycles. -pub(super) fn hmerge(span: &mut BasicBlockBuilder) { +pub(super) fn hmerge(block_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ // Add 4 elements to the stack to prepare the capacity portion for the RPO permutation @@ -93,7 +95,7 @@ pub(super) fn hmerge(span: &mut BasicBlockBuilder) { // Drop 4 elements (the capacity portion) Drop, Drop, Drop, Drop, ]; - span.push_ops(ops); + block_builder.push_ops(ops); } // MERKLE TREES @@ -111,10 +113,10 @@ pub(super) fn hmerge(span: &mut BasicBlockBuilder) { /// - root of the tree, 4 elements. /// /// This operation takes 9 VM cycles. -pub(super) fn mtree_get(span: &mut BasicBlockBuilder) { +pub(super) fn mtree_get(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { // stack: [d, i, R, ...] // pops the value of the node we are looking for from the advice stack - read_mtree_node(span); + read_mtree_node(block_builder)?; #[rustfmt::skip] let ops = [ // verify the node V for root R with depth d and index i @@ -125,7 +127,9 @@ pub(super) fn mtree_get(span: &mut BasicBlockBuilder) { // no longer needed => [V, R, ...] MovUp4, Drop, MovUp4, Drop, ]; - span.push_ops(ops); + block_builder.push_ops(ops); + + Ok(()) } /// Appends the MRUPDATE op with a parameter of "false" and stack manipulations to the span block @@ -141,11 +145,11 @@ pub(super) fn mtree_get(span: &mut BasicBlockBuilder) { /// - new root of the tree after the update, 4 elements /// /// This operation takes 29 VM cycles. -pub(super) fn mtree_set(span: &mut BasicBlockBuilder) { +pub(super) fn mtree_set(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { // stack: [d, i, R_old, V_new, ...] // stack: [V_old, R_new, ...] (29 cycles) - update_mtree(span); + update_mtree(block_builder) } /// Creates a new Merkle tree in the advice provider by combining trees with the specified roots. @@ -161,16 +165,18 @@ pub(super) fn mtree_set(span: &mut BasicBlockBuilder) { /// It is not checked whether the provided roots exist as Merkle trees in the advide providers. /// /// This operation takes 16 VM cycles. -pub(super) fn mtree_merge(span: &mut BasicBlockBuilder) { +pub(super) fn mtree_merge(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { // stack input: [R_rhs, R_lhs, ...] // stack output: [R_merged, ...] // invoke the advice provider function to merge 2 Merkle trees defined by the roots on the top // of the operand stack - span.push_advice_injector(AdviceInjector::MerkleNodeMerge); + block_builder.push_advice_injector(AdviceInjector::MerkleNodeMerge)?; // perform the `hmerge`, updating the operand stack - hmerge(span); + hmerge(block_builder); + + Ok(()) } // MERKLE TREES - HELPERS @@ -193,31 +199,33 @@ pub(super) fn mtree_merge(span: &mut BasicBlockBuilder) { /// - new value of the node, 4 elements (only in the case of mtree_set) /// /// This operation takes 4 VM cycles. -fn read_mtree_node(span: &mut BasicBlockBuilder) { +fn read_mtree_node(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { // The stack should be arranged in the following way: [d, i, R, ...] so that the decorator // can fetch the node value from the root. In the `mtree.get` operation we have the stack in // the following format: [d, i, R], whereas in the case of `mtree.set` we would also have the // new node value post the tree root: [d, i, R, V_new] // // pops the value of the node we are looking for from the advice stack - span.push_advice_injector(AdviceInjector::MerkleNodeToStack); + block_builder.push_advice_injector(AdviceInjector::MerkleNodeToStack)?; // pops the old node value from advice the stack => MPVERIFY: [V_old, d, i, R, ...] // MRUPDATE: [V_old, d, i, R, V_new, ...] - span.push_op_many(AdvPop, 4); + block_builder.push_op_many(AdvPop, 4); + + Ok(()) } /// Update a node in the merkle tree. This operation will always copy the tree into a new instance, /// and perform the mutation on the copied tree. /// /// This operation takes 29 VM cycles. -fn update_mtree(span: &mut BasicBlockBuilder) { +fn update_mtree(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { // stack: [d, i, R_old, V_new, ...] // output: [R_new, R_old, V_new, V_old, ...] // Inject the old node value onto the stack for the call to MRUPDATE. // stack: [V_old, d, i, R_old, V_new, ...] (4 cycles) - read_mtree_node(span); + read_mtree_node(block_builder)?; #[rustfmt::skip] let ops = [ @@ -279,5 +287,7 @@ fn update_mtree(span: &mut BasicBlockBuilder) { ]; // stack: [V_old, R_new, ...] (25 cycles) - span.push_ops(ops); + block_builder.push_ops(ops); + + Ok(()) } diff --git a/assembly/src/assembler/instruction/env_ops.rs b/assembly/src/assembler/instruction/env_ops.rs index ef30001adc..2026af71f5 100644 --- a/assembly/src/assembler/instruction/env_ops.rs +++ b/assembly/src/assembler/instruction/env_ops.rs @@ -11,11 +11,11 @@ use crate::{assembler::ProcedureContext, AssemblyError, Felt, SourceSpan}; /// In cases when the immediate value is 0, `PUSH` operation is replaced with `PAD`. Also, in cases /// when immediate value is 1, `PUSH` operation is replaced with `PAD INCR` because in most cases /// this will be more efficient than doing a `PUSH`. -pub fn push_one(imm: T, span: &mut BasicBlockBuilder) +pub fn push_one(imm: T, block_builder: &mut BasicBlockBuilder) where T: Into, { - push_felt(span, imm.into()); + push_felt(block_builder, imm.into()); } /// Appends `PUSH` operations to the span block to push two or more provided constant values onto @@ -24,11 +24,11 @@ where /// In cases when the immediate value is 0, `PUSH` operation is replaced with `PAD`. Also, in cases /// when immediate value is 1, `PUSH` operation is replaced with `PAD INCR` because in most cases /// this will be more efficient than doing a `PUSH`. -pub fn push_many(imms: &[T], span: &mut BasicBlockBuilder) +pub fn push_many(imms: &[T], block_builder: &mut BasicBlockBuilder) where T: Into + Copy, { - imms.iter().for_each(|imm| push_felt(span, (*imm).into())); + imms.iter().for_each(|imm| push_felt(block_builder, (*imm).into())); } // ENVIRONMENT INPUTS @@ -40,11 +40,11 @@ where /// # Errors /// Returns an error if index is greater than the number of procedure locals. pub fn locaddr( - span: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, index: u16, proc_ctx: &ProcedureContext, ) -> Result<(), AssemblyError> { - local_to_absolute_addr(span, index, proc_ctx.num_locals()) + local_to_absolute_addr(block_builder, index, proc_ctx.num_locals()) } /// Appends CALLER operation to the span which puts the hash of the function which initiated the @@ -53,7 +53,7 @@ pub fn locaddr( /// # Errors /// Returns an error if the instruction is being executed outside of kernel context. pub fn caller( - span: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, proc_ctx: &ProcedureContext, source_span: SourceSpan, ) -> Result<(), AssemblyError> { @@ -63,6 +63,6 @@ pub fn caller( source_file: proc_ctx.source_manager().get(source_span.source_id()).ok(), }); } - span.push_op(Caller); + block_builder.push_op(Caller); Ok(()) } diff --git a/assembly/src/assembler/instruction/ext2_ops.rs b/assembly/src/assembler/instruction/ext2_ops.rs index be3749e70f..6cc3ae88a4 100644 --- a/assembly/src/assembler/instruction/ext2_ops.rs +++ b/assembly/src/assembler/instruction/ext2_ops.rs @@ -1,13 +1,14 @@ use vm_core::{AdviceInjector::Ext2Inv, Operation::*}; use super::BasicBlockBuilder; +use crate::AssemblyError; /// Given a stack in the following initial configuration [b1, b0, a1, a0, ...] where a = (a0, a1) /// and b = (b0, b1) represent elements in the extension field of degree 2, this series of /// operations outputs the result c = (c1, c0) where c1 = a1 + b1 and c0 = a0 + b0. /// /// This operation takes 5 VM cycles. -pub fn ext2_add(span: &mut BasicBlockBuilder) { +pub fn ext2_add(block_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ Swap, // [b0, b1, a1, a0, ...] @@ -16,7 +17,7 @@ pub fn ext2_add(span: &mut BasicBlockBuilder) { MovDn2, // [b1, a1, a0+b0, ...] Add // [b1+a1, a0+b0, ...] ]; - span.push_ops(ops); + block_builder.push_ops(ops); } /// Given a stack in the following initial configuration [b1, b0, a1, a0, ...] where a = (a0, a1) @@ -24,7 +25,7 @@ pub fn ext2_add(span: &mut BasicBlockBuilder) { /// operations outputs the result c = (c1, c0) where c1 = a1 - b1 and c0 = a0 - b0. /// /// This operation takes 7 VM cycles. -pub fn ext2_sub(span: &mut BasicBlockBuilder) { +pub fn ext2_sub(block_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ Neg, // [-b1, b0, a1, a0, ...] @@ -35,7 +36,7 @@ pub fn ext2_sub(span: &mut BasicBlockBuilder) { MovDn2, // [-b1, a1, a0-b0, ...] Add // [a1-b1, a0-b0, ...] ]; - span.push_ops(ops); + block_builder.push_ops(ops); } /// Given a stack with initial configuration given by [b1, b0, a1, a0, ...] where a = (a0, a1) and @@ -43,8 +44,8 @@ pub fn ext2_sub(span: &mut BasicBlockBuilder) { /// outputs the product c = (c1, c0) where c0 = a0b0 - 2(a1b1) and c1 = (a0 + a1)(b0 + b1) - a0b0 /// /// This operation takes 3 VM cycles. -pub fn ext2_mul(span: &mut BasicBlockBuilder) { - span.push_ops([Ext2Mul, Drop, Drop]); +pub fn ext2_mul(block_builder: &mut BasicBlockBuilder) { + block_builder.push_ops([Ext2Mul, Drop, Drop]); } /// Given a stack in the following initial configuration [b1, b0, a1, a0, ...] where a = (a0, a1) @@ -52,8 +53,8 @@ pub fn ext2_mul(span: &mut BasicBlockBuilder) { /// operations outputs the result c = (c1, c0) where c = a * b^-1. /// /// This operation takes 11 VM cycles. -pub fn ext2_div(span: &mut BasicBlockBuilder) { - span.push_advice_injector(Ext2Inv); +pub fn ext2_div(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(Ext2Inv)?; #[rustfmt::skip] let ops = [ AdvPop, // [b0', b1, b0, a1, a0, ...] @@ -68,7 +69,9 @@ pub fn ext2_div(span: &mut BasicBlockBuilder) { Drop, // [b0', a1*b1', a0*b0'...] Drop // [a1*b1', a0*b0'...] ]; - span.push_ops(ops); + block_builder.push_ops(ops); + + Ok(()) } /// Given a stack with initial configuration given by [a1, a0, ...] where a = (a0, a1) represents @@ -76,7 +79,7 @@ pub fn ext2_div(span: &mut BasicBlockBuilder) { /// [-a1, -a0, ...] /// /// This operation takes 4 VM cycles. -pub fn ext2_neg(span: &mut BasicBlockBuilder) { +pub fn ext2_neg(block_builder: &mut BasicBlockBuilder) { #[rustfmt::skip] let ops = [ Neg, // [a1, a0, ...] @@ -84,7 +87,7 @@ pub fn ext2_neg(span: &mut BasicBlockBuilder) { Neg, // [-a0, -a1, ...] Swap // [-a1, -a0, ...] ]; - span.push_ops(ops); + block_builder.push_ops(ops); } /// Given an invertible quadratic extension field element on the stack, this routine computes @@ -112,8 +115,8 @@ pub fn ext2_neg(span: &mut BasicBlockBuilder) { /// assert b = (1, 0) | (1, 0) is the multiplicative identity of extension field. /// /// This operation takes 8 VM cycles. -pub fn ext2_inv(span: &mut BasicBlockBuilder) { - span.push_advice_injector(Ext2Inv); +pub fn ext2_inv(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(Ext2Inv)?; #[rustfmt::skip] let ops = [ AdvPop, // [a0', a1, a0, ...] @@ -125,5 +128,7 @@ pub fn ext2_inv(span: &mut BasicBlockBuilder) { MovUp2, // [1, a1', a0', ...] Assert(0), // [a1', a0', ...] ]; - span.push_ops(ops); + block_builder.push_ops(ops); + + Ok(()) } diff --git a/assembly/src/assembler/instruction/field_ops.rs b/assembly/src/assembler/instruction/field_ops.rs index cb65ec63aa..74063bb489 100644 --- a/assembly/src/assembler/instruction/field_ops.rs +++ b/assembly/src/assembler/instruction/field_ops.rs @@ -259,13 +259,13 @@ fn perform_exp_for_small_power(span_builder: &mut BasicBlockBuilder, pow: u64) { /// /// # Errors /// Returns an error if the logarithm argument (top stack element) equals ZERO. -pub fn ilog2(span: &mut BasicBlockBuilder) { - span.push_advice_injector(AdviceInjector::ILog2); - span.push_op(AdvPop); // [ilog2, n, ...] +pub fn ilog2(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(AdviceInjector::ILog2)?; + block_builder.push_op(AdvPop); // [ilog2, n, ...] // compute the power-of-two for the value given in the advice tape (17 cycles) - span.push_op(Dup0); - append_pow2_op(span); + block_builder.push_op(Dup0); + append_pow2_op(block_builder); // => [pow2, ilog2, n, ...] #[rustfmt::skip] @@ -289,7 +289,9 @@ pub fn ilog2(span: &mut BasicBlockBuilder) { // => [ilog2, ...] ]; - span.push_ops(ops); + block_builder.push_ops(ops); + + Ok(()) } // COMPARISON OPERATIONS diff --git a/assembly/src/assembler/instruction/mem_ops.rs b/assembly/src/assembler/instruction/mem_ops.rs index b9615ef319..cdfbdc79c5 100644 --- a/assembly/src/assembler/instruction/mem_ops.rs +++ b/assembly/src/assembler/instruction/mem_ops.rs @@ -23,7 +23,7 @@ use crate::{assembler::ProcedureContext, diagnostics::Report, AssemblyError}; /// Returns an error if we are reading from local memory and local memory index is greater than /// the number of procedure locals. pub fn mem_read( - span: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, proc_ctx: &ProcedureContext, addr: Option, is_local: bool, @@ -33,9 +33,9 @@ pub fn mem_read( if let Some(addr) = addr { if is_local { let num_locals = proc_ctx.num_locals(); - local_to_absolute_addr(span, addr as u16, num_locals)?; + local_to_absolute_addr(block_builder, addr as u16, num_locals)?; } else { - push_u32_value(span, addr); + push_u32_value(block_builder, addr); } } else { assert!(!is_local, "local always contains addr value"); @@ -43,9 +43,9 @@ pub fn mem_read( // load from the memory address on top of the stack if is_single { - span.push_op(MLoad); + block_builder.push_op(MLoad); } else { - span.push_op(MLoadW); + block_builder.push_op(MLoadW); } Ok(()) @@ -74,23 +74,23 @@ pub fn mem_read( /// Returns an error if we are writing to local memory and local memory index is greater than /// the number of procedure locals. pub fn mem_write_imm( - span: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, proc_ctx: &ProcedureContext, addr: u32, is_local: bool, is_single: bool, ) -> Result<(), AssemblyError> { if is_local { - local_to_absolute_addr(span, addr as u16, proc_ctx.num_locals())?; + local_to_absolute_addr(block_builder, addr as u16, proc_ctx.num_locals())?; } else { - push_u32_value(span, addr); + push_u32_value(block_builder, addr); } if is_single { - span.push_op(MStore); - span.push_op(Drop); + block_builder.push_op(MStore); + block_builder.push_op(Drop); } else { - span.push_op(MStoreW); + block_builder.push_op(MStoreW); } Ok(()) @@ -110,7 +110,7 @@ pub fn mem_write_imm( /// # Errors /// Returns an error if index is greater than the number of procedure locals. pub fn local_to_absolute_addr( - span: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, index: u16, num_proc_locals: u16, ) -> Result<(), AssemblyError> { @@ -127,8 +127,8 @@ pub fn local_to_absolute_addr( let max = num_proc_locals - 1; validate_param(index, 0..=max)?; - push_felt(span, -Felt::from(max - index)); - span.push_op(FmpAdd); + push_felt(block_builder, -Felt::from(max - index)); + block_builder.push_op(FmpAdd); Ok(()) } diff --git a/assembly/src/assembler/instruction/mod.rs b/assembly/src/assembler/instruction/mod.rs index 9465f380c3..49e8bec164 100644 --- a/assembly/src/assembler/instruction/mod.rs +++ b/assembly/src/assembler/instruction/mod.rs @@ -3,10 +3,7 @@ use core::ops::RangeBounds; use miette::miette; use vm_core::{mast::MastNodeId, Decorator, ONE, ZERO}; -use super::{ - ast::InvokeKind, mast_forest_builder::MastForestBuilder, Assembler, BasicBlockBuilder, Felt, - Operation, ProcedureContext, -}; +use super::{ast::InvokeKind, Assembler, BasicBlockBuilder, Felt, Operation, ProcedureContext}; use crate::{ast::Instruction, utils::bound_into_included_u64, AssemblyError, Span}; mod adv_ops; @@ -25,27 +22,21 @@ impl Assembler { pub(super) fn compile_instruction( &self, instruction: &Span, - span_builder: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, proc_ctx: &mut ProcedureContext, - mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { // if the assembler is in debug mode, start tracking the instruction about to be executed; // this will allow us to map the instruction to the sequence of operations which were // executed as a part of this instruction. if self.in_debug_mode() { - span_builder.track_instruction(instruction, proc_ctx); + block_builder.track_instruction(instruction, proc_ctx)?; } - let result = self.compile_instruction_impl( - instruction, - span_builder, - proc_ctx, - mast_forest_builder, - )?; + let result = self.compile_instruction_impl(instruction, block_builder, proc_ctx)?; // compute and update the cycle count of the instruction which just finished executing if self.in_debug_mode() { - span_builder.set_instruction_cycle_count(); + block_builder.set_instruction_cycle_count(); } Ok(result) @@ -54,423 +45,399 @@ impl Assembler { fn compile_instruction_impl( &self, instruction: &Span, - basic_block_builder: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, proc_ctx: &mut ProcedureContext, - mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { use Operation::*; match &**instruction { - Instruction::Nop => basic_block_builder.push_op(Noop), - Instruction::Assert => basic_block_builder.push_op(Assert(0)), + Instruction::Nop => block_builder.push_op(Noop), + Instruction::Assert => block_builder.push_op(Assert(0)), Instruction::AssertWithError(err_code) => { - basic_block_builder.push_op(Assert(err_code.expect_value())) + block_builder.push_op(Assert(err_code.expect_value())) }, - Instruction::AssertEq => basic_block_builder.push_ops([Eq, Assert(0)]), + Instruction::AssertEq => block_builder.push_ops([Eq, Assert(0)]), Instruction::AssertEqWithError(err_code) => { - basic_block_builder.push_ops([Eq, Assert(err_code.expect_value())]) + block_builder.push_ops([Eq, Assert(err_code.expect_value())]) }, - Instruction::AssertEqw => field_ops::assertw(basic_block_builder, 0), + Instruction::AssertEqw => field_ops::assertw(block_builder, 0), Instruction::AssertEqwWithError(err_code) => { - field_ops::assertw(basic_block_builder, err_code.expect_value()) + field_ops::assertw(block_builder, err_code.expect_value()) }, - Instruction::Assertz => basic_block_builder.push_ops([Eqz, Assert(0)]), + Instruction::Assertz => block_builder.push_ops([Eqz, Assert(0)]), Instruction::AssertzWithError(err_code) => { - basic_block_builder.push_ops([Eqz, Assert(err_code.expect_value())]) + block_builder.push_ops([Eqz, Assert(err_code.expect_value())]) }, - Instruction::Add => basic_block_builder.push_op(Add), - Instruction::AddImm(imm) => field_ops::add_imm(basic_block_builder, imm.expect_value()), - Instruction::Sub => basic_block_builder.push_ops([Neg, Add]), - Instruction::SubImm(imm) => field_ops::sub_imm(basic_block_builder, imm.expect_value()), - Instruction::Mul => basic_block_builder.push_op(Mul), - Instruction::MulImm(imm) => field_ops::mul_imm(basic_block_builder, imm.expect_value()), - Instruction::Div => basic_block_builder.push_ops([Inv, Mul]), + Instruction::Add => block_builder.push_op(Add), + Instruction::AddImm(imm) => field_ops::add_imm(block_builder, imm.expect_value()), + Instruction::Sub => block_builder.push_ops([Neg, Add]), + Instruction::SubImm(imm) => field_ops::sub_imm(block_builder, imm.expect_value()), + Instruction::Mul => block_builder.push_op(Mul), + Instruction::MulImm(imm) => field_ops::mul_imm(block_builder, imm.expect_value()), + Instruction::Div => block_builder.push_ops([Inv, Mul]), Instruction::DivImm(imm) => { - field_ops::div_imm(basic_block_builder, proc_ctx, imm.expect_spanned_value())?; + field_ops::div_imm(block_builder, proc_ctx, imm.expect_spanned_value())?; }, - Instruction::Neg => basic_block_builder.push_op(Neg), - Instruction::Inv => basic_block_builder.push_op(Inv), - Instruction::Incr => basic_block_builder.push_op(Incr), + Instruction::Neg => block_builder.push_op(Neg), + Instruction::Inv => block_builder.push_op(Inv), + Instruction::Incr => block_builder.push_op(Incr), - Instruction::Pow2 => field_ops::pow2(basic_block_builder), - Instruction::Exp => field_ops::exp(basic_block_builder, 64)?, - Instruction::ExpImm(pow) => { - field_ops::exp_imm(basic_block_builder, pow.expect_value())? - }, + Instruction::Pow2 => field_ops::pow2(block_builder), + Instruction::Exp => field_ops::exp(block_builder, 64)?, + Instruction::ExpImm(pow) => field_ops::exp_imm(block_builder, pow.expect_value())?, Instruction::ExpBitLength(num_pow_bits) => { - field_ops::exp(basic_block_builder, *num_pow_bits)? - }, - Instruction::ILog2 => field_ops::ilog2(basic_block_builder), - - Instruction::Not => basic_block_builder.push_op(Not), - Instruction::And => basic_block_builder.push_op(And), - Instruction::Or => basic_block_builder.push_op(Or), - Instruction::Xor => { - basic_block_builder.push_ops([Dup0, Dup2, Or, MovDn2, And, Not, And]) - }, - - Instruction::Eq => basic_block_builder.push_op(Eq), - Instruction::EqImm(imm) => field_ops::eq_imm(basic_block_builder, imm.expect_value()), - Instruction::Eqw => field_ops::eqw(basic_block_builder), - Instruction::Neq => basic_block_builder.push_ops([Eq, Not]), - Instruction::NeqImm(imm) => field_ops::neq_imm(basic_block_builder, imm.expect_value()), - Instruction::Lt => field_ops::lt(basic_block_builder), - Instruction::Lte => field_ops::lte(basic_block_builder), - Instruction::Gt => field_ops::gt(basic_block_builder), - Instruction::Gte => field_ops::gte(basic_block_builder), - Instruction::IsOdd => field_ops::is_odd(basic_block_builder), + field_ops::exp(block_builder, *num_pow_bits)? + }, + Instruction::ILog2 => field_ops::ilog2(block_builder)?, + + Instruction::Not => block_builder.push_op(Not), + Instruction::And => block_builder.push_op(And), + Instruction::Or => block_builder.push_op(Or), + Instruction::Xor => block_builder.push_ops([Dup0, Dup2, Or, MovDn2, And, Not, And]), + + Instruction::Eq => block_builder.push_op(Eq), + Instruction::EqImm(imm) => field_ops::eq_imm(block_builder, imm.expect_value()), + Instruction::Eqw => field_ops::eqw(block_builder), + Instruction::Neq => block_builder.push_ops([Eq, Not]), + Instruction::NeqImm(imm) => field_ops::neq_imm(block_builder, imm.expect_value()), + Instruction::Lt => field_ops::lt(block_builder), + Instruction::Lte => field_ops::lte(block_builder), + Instruction::Gt => field_ops::gt(block_builder), + Instruction::Gte => field_ops::gte(block_builder), + Instruction::IsOdd => field_ops::is_odd(block_builder), // ----- ext2 instructions ------------------------------------------------------------ - Instruction::Ext2Add => ext2_ops::ext2_add(basic_block_builder), - Instruction::Ext2Sub => ext2_ops::ext2_sub(basic_block_builder), - Instruction::Ext2Mul => ext2_ops::ext2_mul(basic_block_builder), - Instruction::Ext2Div => ext2_ops::ext2_div(basic_block_builder), - Instruction::Ext2Neg => ext2_ops::ext2_neg(basic_block_builder), - Instruction::Ext2Inv => ext2_ops::ext2_inv(basic_block_builder), + Instruction::Ext2Add => ext2_ops::ext2_add(block_builder), + Instruction::Ext2Sub => ext2_ops::ext2_sub(block_builder), + Instruction::Ext2Mul => ext2_ops::ext2_mul(block_builder), + Instruction::Ext2Div => ext2_ops::ext2_div(block_builder)?, + Instruction::Ext2Neg => ext2_ops::ext2_neg(block_builder), + Instruction::Ext2Inv => ext2_ops::ext2_inv(block_builder)?, // ----- u32 manipulation ------------------------------------------------------------- - Instruction::U32Test => basic_block_builder.push_ops([Dup0, U32split, Swap, Drop, Eqz]), - Instruction::U32TestW => u32_ops::u32testw(basic_block_builder), - Instruction::U32Assert => basic_block_builder.push_ops([Pad, U32assert2(0), Drop]), + Instruction::U32Test => block_builder.push_ops([Dup0, U32split, Swap, Drop, Eqz]), + Instruction::U32TestW => u32_ops::u32testw(block_builder), + Instruction::U32Assert => block_builder.push_ops([Pad, U32assert2(0), Drop]), Instruction::U32AssertWithError(err_code) => { - basic_block_builder.push_ops([Pad, U32assert2(err_code.expect_value()), Drop]) + block_builder.push_ops([Pad, U32assert2(err_code.expect_value()), Drop]) }, - Instruction::U32Assert2 => basic_block_builder.push_op(U32assert2(0)), + Instruction::U32Assert2 => block_builder.push_op(U32assert2(0)), Instruction::U32Assert2WithError(err_code) => { - basic_block_builder.push_op(U32assert2(err_code.expect_value())) + block_builder.push_op(U32assert2(err_code.expect_value())) }, - Instruction::U32AssertW => u32_ops::u32assertw(basic_block_builder, 0), + Instruction::U32AssertW => u32_ops::u32assertw(block_builder, 0), Instruction::U32AssertWWithError(err_code) => { - u32_ops::u32assertw(basic_block_builder, err_code.expect_value()) + u32_ops::u32assertw(block_builder, err_code.expect_value()) }, - Instruction::U32Cast => basic_block_builder.push_ops([U32split, Drop]), - Instruction::U32Split => basic_block_builder.push_op(U32split), + Instruction::U32Cast => block_builder.push_ops([U32split, Drop]), + Instruction::U32Split => block_builder.push_op(U32split), - Instruction::U32OverflowingAdd => { - u32_ops::u32add(basic_block_builder, Overflowing, None) - }, + Instruction::U32OverflowingAdd => u32_ops::u32add(block_builder, Overflowing, None), Instruction::U32OverflowingAddImm(v) => { - u32_ops::u32add(basic_block_builder, Overflowing, Some(v.expect_value())) + u32_ops::u32add(block_builder, Overflowing, Some(v.expect_value())) }, - Instruction::U32WrappingAdd => u32_ops::u32add(basic_block_builder, Wrapping, None), + Instruction::U32WrappingAdd => u32_ops::u32add(block_builder, Wrapping, None), Instruction::U32WrappingAddImm(v) => { - u32_ops::u32add(basic_block_builder, Wrapping, Some(v.expect_value())) + u32_ops::u32add(block_builder, Wrapping, Some(v.expect_value())) }, - Instruction::U32OverflowingAdd3 => basic_block_builder.push_op(U32add3), - Instruction::U32WrappingAdd3 => basic_block_builder.push_ops([U32add3, Drop]), + Instruction::U32OverflowingAdd3 => block_builder.push_op(U32add3), + Instruction::U32WrappingAdd3 => block_builder.push_ops([U32add3, Drop]), - Instruction::U32OverflowingSub => { - u32_ops::u32sub(basic_block_builder, Overflowing, None) - }, + Instruction::U32OverflowingSub => u32_ops::u32sub(block_builder, Overflowing, None), Instruction::U32OverflowingSubImm(v) => { - u32_ops::u32sub(basic_block_builder, Overflowing, Some(v.expect_value())) + u32_ops::u32sub(block_builder, Overflowing, Some(v.expect_value())) }, - Instruction::U32WrappingSub => u32_ops::u32sub(basic_block_builder, Wrapping, None), + Instruction::U32WrappingSub => u32_ops::u32sub(block_builder, Wrapping, None), Instruction::U32WrappingSubImm(v) => { - u32_ops::u32sub(basic_block_builder, Wrapping, Some(v.expect_value())) + u32_ops::u32sub(block_builder, Wrapping, Some(v.expect_value())) }, - Instruction::U32OverflowingMul => { - u32_ops::u32mul(basic_block_builder, Overflowing, None) - }, + Instruction::U32OverflowingMul => u32_ops::u32mul(block_builder, Overflowing, None), Instruction::U32OverflowingMulImm(v) => { - u32_ops::u32mul(basic_block_builder, Overflowing, Some(v.expect_value())) + u32_ops::u32mul(block_builder, Overflowing, Some(v.expect_value())) }, - Instruction::U32WrappingMul => u32_ops::u32mul(basic_block_builder, Wrapping, None), + Instruction::U32WrappingMul => u32_ops::u32mul(block_builder, Wrapping, None), Instruction::U32WrappingMulImm(v) => { - u32_ops::u32mul(basic_block_builder, Wrapping, Some(v.expect_value())) + u32_ops::u32mul(block_builder, Wrapping, Some(v.expect_value())) }, - Instruction::U32OverflowingMadd => basic_block_builder.push_op(U32madd), - Instruction::U32WrappingMadd => basic_block_builder.push_ops([U32madd, Drop]), + Instruction::U32OverflowingMadd => block_builder.push_op(U32madd), + Instruction::U32WrappingMadd => block_builder.push_ops([U32madd, Drop]), - Instruction::U32Div => u32_ops::u32div(basic_block_builder, proc_ctx, None)?, + Instruction::U32Div => u32_ops::u32div(block_builder, proc_ctx, None)?, Instruction::U32DivImm(v) => { - u32_ops::u32div(basic_block_builder, proc_ctx, Some(v.expect_spanned_value()))? + u32_ops::u32div(block_builder, proc_ctx, Some(v.expect_spanned_value()))? }, - Instruction::U32Mod => u32_ops::u32mod(basic_block_builder, proc_ctx, None)?, + Instruction::U32Mod => u32_ops::u32mod(block_builder, proc_ctx, None)?, Instruction::U32ModImm(v) => { - u32_ops::u32mod(basic_block_builder, proc_ctx, Some(v.expect_spanned_value()))? + u32_ops::u32mod(block_builder, proc_ctx, Some(v.expect_spanned_value()))? }, - Instruction::U32DivMod => u32_ops::u32divmod(basic_block_builder, proc_ctx, None)?, + Instruction::U32DivMod => u32_ops::u32divmod(block_builder, proc_ctx, None)?, Instruction::U32DivModImm(v) => { - u32_ops::u32divmod(basic_block_builder, proc_ctx, Some(v.expect_spanned_value()))? - }, - Instruction::U32And => basic_block_builder.push_op(U32and), - Instruction::U32Or => basic_block_builder.push_ops([Dup1, Dup1, U32and, Neg, Add, Add]), - Instruction::U32Xor => basic_block_builder.push_op(U32xor), - Instruction::U32Not => u32_ops::u32not(basic_block_builder), - Instruction::U32Shl => u32_ops::u32shl(basic_block_builder, None)?, - Instruction::U32ShlImm(v) => { - u32_ops::u32shl(basic_block_builder, Some(v.expect_value()))? - }, - Instruction::U32Shr => u32_ops::u32shr(basic_block_builder, None)?, - Instruction::U32ShrImm(v) => { - u32_ops::u32shr(basic_block_builder, Some(v.expect_value()))? - }, - Instruction::U32Rotl => u32_ops::u32rotl(basic_block_builder, None)?, - Instruction::U32RotlImm(v) => { - u32_ops::u32rotl(basic_block_builder, Some(v.expect_value()))? - }, - Instruction::U32Rotr => u32_ops::u32rotr(basic_block_builder, None)?, - Instruction::U32RotrImm(v) => { - u32_ops::u32rotr(basic_block_builder, Some(v.expect_value()))? - }, - Instruction::U32Popcnt => u32_ops::u32popcnt(basic_block_builder), - Instruction::U32Clz => u32_ops::u32clz(basic_block_builder), - Instruction::U32Ctz => u32_ops::u32ctz(basic_block_builder), - Instruction::U32Clo => u32_ops::u32clo(basic_block_builder), - Instruction::U32Cto => u32_ops::u32cto(basic_block_builder), - Instruction::U32Lt => u32_ops::u32lt(basic_block_builder), - Instruction::U32Lte => u32_ops::u32lte(basic_block_builder), - Instruction::U32Gt => u32_ops::u32gt(basic_block_builder), - Instruction::U32Gte => u32_ops::u32gte(basic_block_builder), - Instruction::U32Min => u32_ops::u32min(basic_block_builder), - Instruction::U32Max => u32_ops::u32max(basic_block_builder), + u32_ops::u32divmod(block_builder, proc_ctx, Some(v.expect_spanned_value()))? + }, + Instruction::U32And => block_builder.push_op(U32and), + Instruction::U32Or => block_builder.push_ops([Dup1, Dup1, U32and, Neg, Add, Add]), + Instruction::U32Xor => block_builder.push_op(U32xor), + Instruction::U32Not => u32_ops::u32not(block_builder), + Instruction::U32Shl => u32_ops::u32shl(block_builder, None)?, + Instruction::U32ShlImm(v) => u32_ops::u32shl(block_builder, Some(v.expect_value()))?, + Instruction::U32Shr => u32_ops::u32shr(block_builder, None)?, + Instruction::U32ShrImm(v) => u32_ops::u32shr(block_builder, Some(v.expect_value()))?, + Instruction::U32Rotl => u32_ops::u32rotl(block_builder, None)?, + Instruction::U32RotlImm(v) => u32_ops::u32rotl(block_builder, Some(v.expect_value()))?, + Instruction::U32Rotr => u32_ops::u32rotr(block_builder, None)?, + Instruction::U32RotrImm(v) => u32_ops::u32rotr(block_builder, Some(v.expect_value()))?, + Instruction::U32Popcnt => u32_ops::u32popcnt(block_builder), + Instruction::U32Clz => u32_ops::u32clz(block_builder)?, + Instruction::U32Ctz => u32_ops::u32ctz(block_builder)?, + Instruction::U32Clo => u32_ops::u32clo(block_builder)?, + Instruction::U32Cto => u32_ops::u32cto(block_builder)?, + Instruction::U32Lt => u32_ops::u32lt(block_builder), + Instruction::U32Lte => u32_ops::u32lte(block_builder), + Instruction::U32Gt => u32_ops::u32gt(block_builder), + Instruction::U32Gte => u32_ops::u32gte(block_builder), + Instruction::U32Min => u32_ops::u32min(block_builder), + Instruction::U32Max => u32_ops::u32max(block_builder), // ----- stack manipulation ----------------------------------------------------------- - Instruction::Drop => basic_block_builder.push_op(Drop), - Instruction::DropW => basic_block_builder.push_ops([Drop; 4]), - Instruction::PadW => basic_block_builder.push_ops([Pad; 4]), - Instruction::Dup0 => basic_block_builder.push_op(Dup0), - Instruction::Dup1 => basic_block_builder.push_op(Dup1), - Instruction::Dup2 => basic_block_builder.push_op(Dup2), - Instruction::Dup3 => basic_block_builder.push_op(Dup3), - Instruction::Dup4 => basic_block_builder.push_op(Dup4), - Instruction::Dup5 => basic_block_builder.push_op(Dup5), - Instruction::Dup6 => basic_block_builder.push_op(Dup6), - Instruction::Dup7 => basic_block_builder.push_op(Dup7), - Instruction::Dup8 => basic_block_builder.push_ops([Pad, Dup9, Add]), - Instruction::Dup9 => basic_block_builder.push_op(Dup9), - Instruction::Dup10 => basic_block_builder.push_ops([Pad, Dup11, Add]), - Instruction::Dup11 => basic_block_builder.push_op(Dup11), - Instruction::Dup12 => basic_block_builder.push_ops([Pad, Dup13, Add]), - Instruction::Dup13 => basic_block_builder.push_op(Dup13), - Instruction::Dup14 => basic_block_builder.push_ops([Pad, Dup15, Add]), - Instruction::Dup15 => basic_block_builder.push_op(Dup15), - Instruction::DupW0 => basic_block_builder.push_ops([Dup3; 4]), - Instruction::DupW1 => basic_block_builder.push_ops([Dup7; 4]), - Instruction::DupW2 => basic_block_builder.push_ops([Dup11; 4]), - Instruction::DupW3 => basic_block_builder.push_ops([Dup15; 4]), - Instruction::Swap1 => basic_block_builder.push_op(Swap), - Instruction::Swap2 => basic_block_builder.push_ops([Swap, MovUp2]), - Instruction::Swap3 => basic_block_builder.push_ops([MovDn2, MovUp3]), - Instruction::Swap4 => basic_block_builder.push_ops([MovDn3, MovUp4]), - Instruction::Swap5 => basic_block_builder.push_ops([MovDn4, MovUp5]), - Instruction::Swap6 => basic_block_builder.push_ops([MovDn5, MovUp6]), - Instruction::Swap7 => basic_block_builder.push_ops([MovDn6, MovUp7]), - Instruction::Swap8 => basic_block_builder.push_ops([MovDn7, MovUp8]), - Instruction::Swap9 => { - basic_block_builder.push_ops([MovDn8, SwapDW, Swap, SwapDW, MovUp8]) - }, + Instruction::Drop => block_builder.push_op(Drop), + Instruction::DropW => block_builder.push_ops([Drop; 4]), + Instruction::PadW => block_builder.push_ops([Pad; 4]), + Instruction::Dup0 => block_builder.push_op(Dup0), + Instruction::Dup1 => block_builder.push_op(Dup1), + Instruction::Dup2 => block_builder.push_op(Dup2), + Instruction::Dup3 => block_builder.push_op(Dup3), + Instruction::Dup4 => block_builder.push_op(Dup4), + Instruction::Dup5 => block_builder.push_op(Dup5), + Instruction::Dup6 => block_builder.push_op(Dup6), + Instruction::Dup7 => block_builder.push_op(Dup7), + Instruction::Dup8 => block_builder.push_ops([Pad, Dup9, Add]), + Instruction::Dup9 => block_builder.push_op(Dup9), + Instruction::Dup10 => block_builder.push_ops([Pad, Dup11, Add]), + Instruction::Dup11 => block_builder.push_op(Dup11), + Instruction::Dup12 => block_builder.push_ops([Pad, Dup13, Add]), + Instruction::Dup13 => block_builder.push_op(Dup13), + Instruction::Dup14 => block_builder.push_ops([Pad, Dup15, Add]), + Instruction::Dup15 => block_builder.push_op(Dup15), + Instruction::DupW0 => block_builder.push_ops([Dup3; 4]), + Instruction::DupW1 => block_builder.push_ops([Dup7; 4]), + Instruction::DupW2 => block_builder.push_ops([Dup11; 4]), + Instruction::DupW3 => block_builder.push_ops([Dup15; 4]), + Instruction::Swap1 => block_builder.push_op(Swap), + Instruction::Swap2 => block_builder.push_ops([Swap, MovUp2]), + Instruction::Swap3 => block_builder.push_ops([MovDn2, MovUp3]), + Instruction::Swap4 => block_builder.push_ops([MovDn3, MovUp4]), + Instruction::Swap5 => block_builder.push_ops([MovDn4, MovUp5]), + Instruction::Swap6 => block_builder.push_ops([MovDn5, MovUp6]), + Instruction::Swap7 => block_builder.push_ops([MovDn6, MovUp7]), + Instruction::Swap8 => block_builder.push_ops([MovDn7, MovUp8]), + Instruction::Swap9 => block_builder.push_ops([MovDn8, SwapDW, Swap, SwapDW, MovUp8]), Instruction::Swap10 => { - basic_block_builder.push_ops([MovDn8, SwapDW, Swap, MovUp2, SwapDW, MovUp8]) + block_builder.push_ops([MovDn8, SwapDW, Swap, MovUp2, SwapDW, MovUp8]) }, Instruction::Swap11 => { - basic_block_builder.push_ops([MovDn8, SwapDW, MovDn2, MovUp3, SwapDW, MovUp8]) + block_builder.push_ops([MovDn8, SwapDW, MovDn2, MovUp3, SwapDW, MovUp8]) }, Instruction::Swap12 => { - basic_block_builder.push_ops([MovDn8, SwapDW, MovDn3, MovUp4, SwapDW, MovUp8]) + block_builder.push_ops([MovDn8, SwapDW, MovDn3, MovUp4, SwapDW, MovUp8]) }, Instruction::Swap13 => { - basic_block_builder.push_ops([MovDn8, SwapDW, MovDn4, MovUp5, SwapDW, MovUp8]) + block_builder.push_ops([MovDn8, SwapDW, MovDn4, MovUp5, SwapDW, MovUp8]) }, Instruction::Swap14 => { - basic_block_builder.push_ops([MovDn8, SwapDW, MovDn5, MovUp6, SwapDW, MovUp8]) + block_builder.push_ops([MovDn8, SwapDW, MovDn5, MovUp6, SwapDW, MovUp8]) }, Instruction::Swap15 => { - basic_block_builder.push_ops([MovDn8, SwapDW, MovDn6, MovUp7, SwapDW, MovUp8]) - }, - Instruction::SwapW1 => basic_block_builder.push_op(SwapW), - Instruction::SwapW2 => basic_block_builder.push_op(SwapW2), - Instruction::SwapW3 => basic_block_builder.push_op(SwapW3), - Instruction::SwapDw => basic_block_builder.push_op(SwapDW), - Instruction::MovUp2 => basic_block_builder.push_op(MovUp2), - Instruction::MovUp3 => basic_block_builder.push_op(MovUp3), - Instruction::MovUp4 => basic_block_builder.push_op(MovUp4), - Instruction::MovUp5 => basic_block_builder.push_op(MovUp5), - Instruction::MovUp6 => basic_block_builder.push_op(MovUp6), - Instruction::MovUp7 => basic_block_builder.push_op(MovUp7), - Instruction::MovUp8 => basic_block_builder.push_op(MovUp8), - Instruction::MovUp9 => basic_block_builder.push_ops([SwapDW, Swap, SwapDW, MovUp8]), - Instruction::MovUp10 => basic_block_builder.push_ops([SwapDW, MovUp2, SwapDW, MovUp8]), - Instruction::MovUp11 => basic_block_builder.push_ops([SwapDW, MovUp3, SwapDW, MovUp8]), - Instruction::MovUp12 => basic_block_builder.push_ops([SwapDW, MovUp4, SwapDW, MovUp8]), - Instruction::MovUp13 => basic_block_builder.push_ops([SwapDW, MovUp5, SwapDW, MovUp8]), - Instruction::MovUp14 => basic_block_builder.push_ops([SwapDW, MovUp6, SwapDW, MovUp8]), - Instruction::MovUp15 => basic_block_builder.push_ops([SwapDW, MovUp7, SwapDW, MovUp8]), - Instruction::MovUpW2 => basic_block_builder.push_ops([SwapW, SwapW2]), - Instruction::MovUpW3 => basic_block_builder.push_ops([SwapW, SwapW2, SwapW3]), - Instruction::MovDn2 => basic_block_builder.push_op(MovDn2), - Instruction::MovDn3 => basic_block_builder.push_op(MovDn3), - Instruction::MovDn4 => basic_block_builder.push_op(MovDn4), - Instruction::MovDn5 => basic_block_builder.push_op(MovDn5), - Instruction::MovDn6 => basic_block_builder.push_op(MovDn6), - Instruction::MovDn7 => basic_block_builder.push_op(MovDn7), - Instruction::MovDn8 => basic_block_builder.push_op(MovDn8), - Instruction::MovDn9 => basic_block_builder.push_ops([MovDn8, SwapDW, Swap, SwapDW]), - Instruction::MovDn10 => basic_block_builder.push_ops([MovDn8, SwapDW, MovDn2, SwapDW]), - Instruction::MovDn11 => basic_block_builder.push_ops([MovDn8, SwapDW, MovDn3, SwapDW]), - Instruction::MovDn12 => basic_block_builder.push_ops([MovDn8, SwapDW, MovDn4, SwapDW]), - Instruction::MovDn13 => basic_block_builder.push_ops([MovDn8, SwapDW, MovDn5, SwapDW]), - Instruction::MovDn14 => basic_block_builder.push_ops([MovDn8, SwapDW, MovDn6, SwapDW]), - Instruction::MovDn15 => basic_block_builder.push_ops([MovDn8, SwapDW, MovDn7, SwapDW]), - Instruction::MovDnW2 => basic_block_builder.push_ops([SwapW2, SwapW]), - Instruction::MovDnW3 => basic_block_builder.push_ops([SwapW3, SwapW2, SwapW]), - - Instruction::CSwap => basic_block_builder.push_op(CSwap), - Instruction::CSwapW => basic_block_builder.push_op(CSwapW), - Instruction::CDrop => basic_block_builder.push_ops([CSwap, Drop]), - Instruction::CDropW => basic_block_builder.push_ops([CSwapW, Drop, Drop, Drop, Drop]), + block_builder.push_ops([MovDn8, SwapDW, MovDn6, MovUp7, SwapDW, MovUp8]) + }, + Instruction::SwapW1 => block_builder.push_op(SwapW), + Instruction::SwapW2 => block_builder.push_op(SwapW2), + Instruction::SwapW3 => block_builder.push_op(SwapW3), + Instruction::SwapDw => block_builder.push_op(SwapDW), + Instruction::MovUp2 => block_builder.push_op(MovUp2), + Instruction::MovUp3 => block_builder.push_op(MovUp3), + Instruction::MovUp4 => block_builder.push_op(MovUp4), + Instruction::MovUp5 => block_builder.push_op(MovUp5), + Instruction::MovUp6 => block_builder.push_op(MovUp6), + Instruction::MovUp7 => block_builder.push_op(MovUp7), + Instruction::MovUp8 => block_builder.push_op(MovUp8), + Instruction::MovUp9 => block_builder.push_ops([SwapDW, Swap, SwapDW, MovUp8]), + Instruction::MovUp10 => block_builder.push_ops([SwapDW, MovUp2, SwapDW, MovUp8]), + Instruction::MovUp11 => block_builder.push_ops([SwapDW, MovUp3, SwapDW, MovUp8]), + Instruction::MovUp12 => block_builder.push_ops([SwapDW, MovUp4, SwapDW, MovUp8]), + Instruction::MovUp13 => block_builder.push_ops([SwapDW, MovUp5, SwapDW, MovUp8]), + Instruction::MovUp14 => block_builder.push_ops([SwapDW, MovUp6, SwapDW, MovUp8]), + Instruction::MovUp15 => block_builder.push_ops([SwapDW, MovUp7, SwapDW, MovUp8]), + Instruction::MovUpW2 => block_builder.push_ops([SwapW, SwapW2]), + Instruction::MovUpW3 => block_builder.push_ops([SwapW, SwapW2, SwapW3]), + Instruction::MovDn2 => block_builder.push_op(MovDn2), + Instruction::MovDn3 => block_builder.push_op(MovDn3), + Instruction::MovDn4 => block_builder.push_op(MovDn4), + Instruction::MovDn5 => block_builder.push_op(MovDn5), + Instruction::MovDn6 => block_builder.push_op(MovDn6), + Instruction::MovDn7 => block_builder.push_op(MovDn7), + Instruction::MovDn8 => block_builder.push_op(MovDn8), + Instruction::MovDn9 => block_builder.push_ops([MovDn8, SwapDW, Swap, SwapDW]), + Instruction::MovDn10 => block_builder.push_ops([MovDn8, SwapDW, MovDn2, SwapDW]), + Instruction::MovDn11 => block_builder.push_ops([MovDn8, SwapDW, MovDn3, SwapDW]), + Instruction::MovDn12 => block_builder.push_ops([MovDn8, SwapDW, MovDn4, SwapDW]), + Instruction::MovDn13 => block_builder.push_ops([MovDn8, SwapDW, MovDn5, SwapDW]), + Instruction::MovDn14 => block_builder.push_ops([MovDn8, SwapDW, MovDn6, SwapDW]), + Instruction::MovDn15 => block_builder.push_ops([MovDn8, SwapDW, MovDn7, SwapDW]), + Instruction::MovDnW2 => block_builder.push_ops([SwapW2, SwapW]), + Instruction::MovDnW3 => block_builder.push_ops([SwapW3, SwapW2, SwapW]), + + Instruction::CSwap => block_builder.push_op(CSwap), + Instruction::CSwapW => block_builder.push_op(CSwapW), + Instruction::CDrop => block_builder.push_ops([CSwap, Drop]), + Instruction::CDropW => block_builder.push_ops([CSwapW, Drop, Drop, Drop, Drop]), // ----- input / output instructions -------------------------------------------------- - Instruction::Push(imm) => env_ops::push_one(imm.expect_value(), basic_block_builder), - Instruction::PushU8(imm) => env_ops::push_one(*imm, basic_block_builder), - Instruction::PushU16(imm) => env_ops::push_one(*imm, basic_block_builder), - Instruction::PushU32(imm) => env_ops::push_one(*imm, basic_block_builder), - Instruction::PushFelt(imm) => env_ops::push_one(*imm, basic_block_builder), - Instruction::PushWord(imms) => env_ops::push_many(imms, basic_block_builder), - Instruction::PushU8List(imms) => env_ops::push_many(imms, basic_block_builder), - Instruction::PushU16List(imms) => env_ops::push_many(imms, basic_block_builder), - Instruction::PushU32List(imms) => env_ops::push_many(imms, basic_block_builder), - Instruction::PushFeltList(imms) => env_ops::push_many(imms, basic_block_builder), - Instruction::Sdepth => basic_block_builder.push_op(SDepth), - Instruction::Caller => { - env_ops::caller(basic_block_builder, proc_ctx, instruction.span())? - }, - Instruction::Clk => basic_block_builder.push_op(Clk), - Instruction::AdvPipe => basic_block_builder.push_op(Pipe), - Instruction::AdvPush(n) => adv_ops::adv_push(basic_block_builder, n.expect_value())?, - Instruction::AdvLoadW => basic_block_builder.push_op(AdvPopW), - - Instruction::MemStream => basic_block_builder.push_op(MStream), - Instruction::Locaddr(v) => { - env_ops::locaddr(basic_block_builder, v.expect_value(), proc_ctx)? - }, - Instruction::MemLoad => { - mem_ops::mem_read(basic_block_builder, proc_ctx, None, false, true)? - }, - Instruction::MemLoadImm(v) => mem_ops::mem_read( - basic_block_builder, - proc_ctx, - Some(v.expect_value()), - false, - true, - )?, + Instruction::Push(imm) => env_ops::push_one(imm.expect_value(), block_builder), + Instruction::PushU8(imm) => env_ops::push_one(*imm, block_builder), + Instruction::PushU16(imm) => env_ops::push_one(*imm, block_builder), + Instruction::PushU32(imm) => env_ops::push_one(*imm, block_builder), + Instruction::PushFelt(imm) => env_ops::push_one(*imm, block_builder), + Instruction::PushWord(imms) => env_ops::push_many(imms, block_builder), + Instruction::PushU8List(imms) => env_ops::push_many(imms, block_builder), + Instruction::PushU16List(imms) => env_ops::push_many(imms, block_builder), + Instruction::PushU32List(imms) => env_ops::push_many(imms, block_builder), + Instruction::PushFeltList(imms) => env_ops::push_many(imms, block_builder), + Instruction::Sdepth => block_builder.push_op(SDepth), + Instruction::Caller => env_ops::caller(block_builder, proc_ctx, instruction.span())?, + Instruction::Clk => block_builder.push_op(Clk), + Instruction::AdvPipe => block_builder.push_op(Pipe), + Instruction::AdvPush(n) => adv_ops::adv_push(block_builder, n.expect_value())?, + Instruction::AdvLoadW => block_builder.push_op(AdvPopW), + + Instruction::MemStream => block_builder.push_op(MStream), + Instruction::Locaddr(v) => env_ops::locaddr(block_builder, v.expect_value(), proc_ctx)?, + Instruction::MemLoad => mem_ops::mem_read(block_builder, proc_ctx, None, false, true)?, + Instruction::MemLoadImm(v) => { + mem_ops::mem_read(block_builder, proc_ctx, Some(v.expect_value()), false, true)? + }, Instruction::MemLoadW => { - mem_ops::mem_read(basic_block_builder, proc_ctx, None, false, false)? + mem_ops::mem_read(block_builder, proc_ctx, None, false, false)? + }, + Instruction::MemLoadWImm(v) => { + mem_ops::mem_read(block_builder, proc_ctx, Some(v.expect_value()), false, false)? }, - Instruction::MemLoadWImm(v) => mem_ops::mem_read( - basic_block_builder, - proc_ctx, - Some(v.expect_value()), - false, - false, - )?, Instruction::LocLoad(v) => mem_ops::mem_read( - basic_block_builder, + block_builder, proc_ctx, Some(v.expect_value() as u32), true, true, )?, Instruction::LocLoadW(v) => mem_ops::mem_read( - basic_block_builder, + block_builder, proc_ctx, Some(v.expect_value() as u32), true, false, )?, - Instruction::MemStore => basic_block_builder.push_ops([MStore, Drop]), - Instruction::MemStoreW => basic_block_builder.push_ops([MStoreW]), - Instruction::MemStoreImm(v) => mem_ops::mem_write_imm( - basic_block_builder, - proc_ctx, - v.expect_value(), - false, - true, - )?, - Instruction::MemStoreWImm(v) => mem_ops::mem_write_imm( - basic_block_builder, - proc_ctx, - v.expect_value(), - false, - false, - )?, + Instruction::MemStore => block_builder.push_ops([MStore, Drop]), + Instruction::MemStoreW => block_builder.push_ops([MStoreW]), + Instruction::MemStoreImm(v) => { + mem_ops::mem_write_imm(block_builder, proc_ctx, v.expect_value(), false, true)? + }, + Instruction::MemStoreWImm(v) => { + mem_ops::mem_write_imm(block_builder, proc_ctx, v.expect_value(), false, false)? + }, Instruction::LocStore(v) => mem_ops::mem_write_imm( - basic_block_builder, + block_builder, proc_ctx, v.expect_value() as u32, true, true, )?, Instruction::LocStoreW(v) => mem_ops::mem_write_imm( - basic_block_builder, + block_builder, proc_ctx, v.expect_value() as u32, true, false, )?, - Instruction::AdvInject(injector) => adv_ops::adv_inject(basic_block_builder, injector), + Instruction::AdvInject(injector) => adv_ops::adv_inject(block_builder, injector)?, // ----- cryptographic instructions --------------------------------------------------- - Instruction::Hash => crypto_ops::hash(basic_block_builder), - Instruction::HPerm => basic_block_builder.push_op(HPerm), - Instruction::HMerge => crypto_ops::hmerge(basic_block_builder), - Instruction::MTreeGet => crypto_ops::mtree_get(basic_block_builder), - Instruction::MTreeSet => crypto_ops::mtree_set(basic_block_builder), - Instruction::MTreeMerge => crypto_ops::mtree_merge(basic_block_builder), - Instruction::MTreeVerify => basic_block_builder.push_op(MpVerify(0)), + Instruction::Hash => crypto_ops::hash(block_builder), + Instruction::HPerm => block_builder.push_op(HPerm), + Instruction::HMerge => crypto_ops::hmerge(block_builder), + Instruction::MTreeGet => crypto_ops::mtree_get(block_builder)?, + Instruction::MTreeSet => crypto_ops::mtree_set(block_builder)?, + Instruction::MTreeMerge => crypto_ops::mtree_merge(block_builder)?, + Instruction::MTreeVerify => block_builder.push_op(MpVerify(0)), Instruction::MTreeVerifyWithError(err_code) => { - basic_block_builder.push_op(MpVerify(err_code.expect_value())) + block_builder.push_op(MpVerify(err_code.expect_value())) }, // ----- STARK proof verification ----------------------------------------------------- - Instruction::FriExt2Fold4 => basic_block_builder.push_op(FriE2F4), - Instruction::RCombBase => basic_block_builder.push_op(RCombBase), + Instruction::FriExt2Fold4 => block_builder.push_op(FriE2F4), + Instruction::RCombBase => block_builder.push_op(RCombBase), // ----- exec/call instructions ------------------------------------------------------- Instruction::Exec(ref callee) => { - return self.invoke(InvokeKind::Exec, callee, proc_ctx, mast_forest_builder) + return self + .invoke( + InvokeKind::Exec, + callee, + proc_ctx, + block_builder.mast_forest_builder_mut(), + ) + .map(Into::into); }, Instruction::Call(ref callee) => { - return self.invoke(InvokeKind::Call, callee, proc_ctx, mast_forest_builder) + return self + .invoke( + InvokeKind::Call, + callee, + proc_ctx, + block_builder.mast_forest_builder_mut(), + ) + .map(Into::into); }, Instruction::SysCall(ref callee) => { - return self.invoke(InvokeKind::SysCall, callee, proc_ctx, mast_forest_builder) - }, - Instruction::DynExec => return self.dynexec(mast_forest_builder), - Instruction::DynCall => return self.dyncall(mast_forest_builder), - Instruction::ProcRef(ref callee) => { - self.procref(callee, proc_ctx, basic_block_builder, mast_forest_builder)? - }, + return self + .invoke( + InvokeKind::SysCall, + callee, + proc_ctx, + block_builder.mast_forest_builder_mut(), + ) + .map(Into::into); + }, + Instruction::DynExec => return self.dynexec(block_builder.mast_forest_builder_mut()), + Instruction::DynCall => return self.dyncall(block_builder.mast_forest_builder_mut()), + Instruction::ProcRef(ref callee) => self.procref(callee, proc_ctx, block_builder)?, // ----- debug decorators ------------------------------------------------------------- Instruction::Breakpoint => { if self.in_debug_mode() { - basic_block_builder.push_op(Noop); - basic_block_builder.track_instruction(instruction, proc_ctx); + block_builder.push_op(Noop); + block_builder.track_instruction(instruction, proc_ctx)?; } }, Instruction::Debug(options) => { if self.in_debug_mode() { - basic_block_builder.push_decorator(Decorator::Debug( + block_builder.push_decorator(Decorator::Debug( options.clone().try_into().expect("unresolved constant"), - )) + ))?; } }, // ----- emit instruction ------------------------------------------------------------- Instruction::Emit(event_id) => { - basic_block_builder.push_decorator(Decorator::Event(event_id.expect_value())); + block_builder.push_op(Operation::Emit(event_id.expect_value())); }, // ----- trace instruction ------------------------------------------------------------ Instruction::Trace(trace_id) => { - basic_block_builder.push_decorator(Decorator::Trace(trace_id.expect_value())); + block_builder.push_decorator(Decorator::Trace(trace_id.expect_value()))?; }, } diff --git a/assembly/src/assembler/instruction/procedures.rs b/assembly/src/assembler/instruction/procedures.rs index 1dfbfabc2c..0c7b2600c0 100644 --- a/assembly/src/assembler/instruction/procedures.rs +++ b/assembly/src/assembler/instruction/procedures.rs @@ -5,122 +5,31 @@ use super::{Assembler, BasicBlockBuilder, Operation}; use crate::{ assembler::{mast_forest_builder::MastForestBuilder, ProcedureContext}, ast::{InvocationTarget, InvokeKind}, - AssemblyError, RpoDigest, SourceSpan, Spanned, + AssemblyError, RpoDigest, }; /// Procedure Invocation impl Assembler { + /// Returns the [`MastNodeId`] of the invoked procedure specified by `callee`. + /// + /// For example, given `exec.f`, this method would return the procedure body id of `f`. If the + /// only representation of `f` that we have is its MAST root, then this method will also insert + /// a [`core::mast::ExternalNode`] that wraps `f`'s MAST root and return the corresponding id. pub(super) fn invoke( &self, kind: InvokeKind, callee: &InvocationTarget, - proc_ctx: &mut ProcedureContext, - mast_forest_builder: &mut MastForestBuilder, - ) -> Result, AssemblyError> { - let span = callee.span(); - let digest = self.resolve_target(kind, callee, proc_ctx, mast_forest_builder)?; - self.invoke_mast_root(kind, span, digest, proc_ctx, mast_forest_builder) - } - - fn invoke_mast_root( - &self, - kind: InvokeKind, - span: SourceSpan, - mast_root: RpoDigest, - proc_ctx: &mut ProcedureContext, + proc_ctx: &ProcedureContext, mast_forest_builder: &mut MastForestBuilder, - ) -> Result, AssemblyError> { - // Get the procedure from the assembler - let current_source_file = self.source_manager.get(span.source_id()).ok(); - - // If the procedure is cached, register the call to ensure the callset - // is updated correctly. - match mast_forest_builder.find_procedure(&mast_root) { - Some(proc) if matches!(kind, InvokeKind::SysCall) => { - // Verify if this is a syscall, that the callee is a kernel procedure - // - // NOTE: The assembler is expected to know the full set of all kernel - // procedures at this point, so if we can't identify the callee as a - // kernel procedure, it is a definite error. - if !proc.visibility().is_syscall() { - return Err(AssemblyError::InvalidSysCallTarget { - span, - source_file: current_source_file, - callee: proc.fully_qualified_name().clone(), - }); - } - let maybe_kernel_path = proc.path(); - self.module_graph - .find_module(maybe_kernel_path) - .ok_or_else(|| AssemblyError::InvalidSysCallTarget { - span, - source_file: current_source_file.clone(), - callee: proc.fully_qualified_name().clone(), - }) - .and_then(|module| { - // Note: this module is guaranteed to be of AST variant, since we have the - // AST of a procedure contained in it (i.e. `proc`). Hence, it must be that - // the entire module is in AST representation as well. - if module.unwrap_ast().is_kernel() { - Ok(()) - } else { - Err(AssemblyError::InvalidSysCallTarget { - span, - source_file: current_source_file.clone(), - callee: proc.fully_qualified_name().clone(), - }) - } - })?; - proc_ctx.register_external_call(&proc, false)?; - }, - Some(proc) => proc_ctx.register_external_call(&proc, false)?, - None => (), + ) -> Result { + let invoked_proc_node_id = + self.resolve_target(kind, callee, proc_ctx, mast_forest_builder)?; + + match kind { + InvokeKind::ProcRef | InvokeKind::Exec => Ok(invoked_proc_node_id), + InvokeKind::Call => mast_forest_builder.ensure_call(invoked_proc_node_id), + InvokeKind::SysCall => mast_forest_builder.ensure_syscall(invoked_proc_node_id), } - - let mast_root_node_id = { - match kind { - InvokeKind::Exec | InvokeKind::ProcRef => { - // Note that here we rely on the fact that we topologically sorted the - // procedures, such that when we assemble a procedure, all - // procedures that it calls will have been assembled, and - // hence be present in the `MastForest`. - match mast_forest_builder.find_procedure_node_id(mast_root) { - Some(root) => root, - None => { - // If the MAST root called isn't known to us, make it an external - // reference. - mast_forest_builder.ensure_external(mast_root)? - }, - } - }, - InvokeKind::Call => { - let callee_id = match mast_forest_builder.find_procedure_node_id(mast_root) { - Some(callee_id) => callee_id, - None => { - // If the MAST root called isn't known to us, make it an external - // reference. - mast_forest_builder.ensure_external(mast_root)? - }, - }; - - mast_forest_builder.ensure_call(callee_id)? - }, - InvokeKind::SysCall => { - let callee_id = match mast_forest_builder.find_procedure_node_id(mast_root) { - Some(callee_id) => callee_id, - None => { - // If the MAST root called isn't known to us, make it an external - // reference. - mast_forest_builder.ensure_external(mast_root)? - }, - }; - - mast_forest_builder.ensure_syscall(callee_id)? - }, - } - }; - - Ok(Some(mast_root_node_id)) } /// Creates a new DYN block for the dynamic code execution and return. @@ -133,15 +42,12 @@ impl Assembler { Ok(Some(dyn_node_id)) } - /// Creates a new CALL block whose target is DYN. + /// Creates a new DYNCALL block for the dynamic function call and return. pub(super) fn dyncall( &self, mast_forest_builder: &mut MastForestBuilder, ) -> Result, AssemblyError> { - let dyn_call_node_id = { - let dyn_node_id = mast_forest_builder.ensure_dyn()?; - mast_forest_builder.ensure_call(dyn_node_id)? - }; + let dyn_call_node_id = mast_forest_builder.ensure_dyncall()?; Ok(Some(dyn_call_node_id)) } @@ -150,34 +56,38 @@ impl Assembler { &self, callee: &InvocationTarget, proc_ctx: &mut ProcedureContext, - span_builder: &mut BasicBlockBuilder, - mast_forest_builder: &MastForestBuilder, + block_builder: &mut BasicBlockBuilder, ) -> Result<(), AssemblyError> { - let digest = - self.resolve_target(InvokeKind::ProcRef, callee, proc_ctx, mast_forest_builder)?; - self.procref_mast_root(digest, proc_ctx, span_builder, mast_forest_builder) + let mast_root = { + let proc_body_id = self.resolve_target( + InvokeKind::ProcRef, + callee, + proc_ctx, + block_builder.mast_forest_builder_mut(), + )?; + // Note: it's ok to `unwrap()` here since `proc_body_id` was returned from + // `mast_forest_builder` + block_builder + .mast_forest_builder() + .get_mast_node(proc_body_id) + .unwrap() + .digest() + }; + + self.procref_mast_root(mast_root, block_builder) } fn procref_mast_root( &self, mast_root: RpoDigest, - proc_ctx: &mut ProcedureContext, - span_builder: &mut BasicBlockBuilder, - mast_forest_builder: &MastForestBuilder, + block_builder: &mut BasicBlockBuilder, ) -> Result<(), AssemblyError> { - // Add the root to the callset to be able to use dynamic instructions - // with the referenced procedure later - - if let Some(proc) = mast_forest_builder.find_procedure(&mast_root) { - proc_ctx.register_external_call(&proc, false)?; - } - // Create an array with `Push` operations containing root elements let ops = mast_root .iter() .map(|elem| Operation::Push(*elem)) .collect::>(); - span_builder.push_ops(ops); + block_builder.push_ops(ops); Ok(()) } } diff --git a/assembly/src/assembler/instruction/u32_ops.rs b/assembly/src/assembler/instruction/u32_ops.rs index 54e56b4d11..8826d05939 100644 --- a/assembly/src/assembler/instruction/u32_ops.rs +++ b/assembly/src/assembler/instruction/u32_ops.rs @@ -236,25 +236,25 @@ pub fn u32rotl(span_builder: &mut BasicBlockBuilder, imm: Option) -> Result< /// b is the shift amount, then adding the overflow limb to the shifted limb. /// /// VM cycles per mode: -/// - u32rotr: 22 cycles +/// - u32rotr: 23 cycles /// - u32rotr.b: 3 cycles pub fn u32rotr(span_builder: &mut BasicBlockBuilder, imm: Option) -> Result<(), AssemblyError> { match imm { Some(0) => { // if rotation is performed by 0, do nothing (Noop) span_builder.push_op(Noop); - return Ok(()); }, Some(imm) => { validate_param(imm, 1..=MAX_U32_ROTATE_VALUE)?; span_builder.push_op(Push(Felt::new(1 << (32 - imm)))); + span_builder.push_ops([U32mul, Add]); }, None => { span_builder.push_ops([Push(Felt::new(32)), Swap, U32sub, Drop]); append_pow2_op(span_builder); + span_builder.push_ops([Mul, U32split, Add]); }, } - span_builder.push_ops([U32mul, Add]); Ok(()) } @@ -297,12 +297,13 @@ pub fn u32popcnt(span_builder: &mut BasicBlockBuilder) { /// leading zeros of the value using non-deterministic technique (i.e. it takes help of advice /// provider). /// -/// This operation takes 37 VM cycles. -pub fn u32clz(span: &mut BasicBlockBuilder) { - span.push_advice_injector(AdviceInjector::U32Clz); - span.push_op(AdvPop); // [clz, n, ...] +/// This operation takes 42 VM cycles. +pub fn u32clz(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(AdviceInjector::U32Clz)?; + block_builder.push_op(AdvPop); // [clz, n, ...] - calculate_clz(span); + verify_clz(block_builder); + Ok(()) } /// Translates `u32ctz` assembly instruction to VM operations. `u32ctz` counts the number of @@ -310,23 +311,25 @@ pub fn u32clz(span: &mut BasicBlockBuilder) { /// provider). /// /// This operation takes 34 VM cycles. -pub fn u32ctz(span: &mut BasicBlockBuilder) { - span.push_advice_injector(AdviceInjector::U32Ctz); - span.push_op(AdvPop); // [ctz, n, ...] +pub fn u32ctz(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(AdviceInjector::U32Ctz)?; + block_builder.push_op(AdvPop); // [ctz, n, ...] - calculate_ctz(span); + verify_ctz(block_builder); + Ok(()) } /// Translates `u32clo` assembly instruction to VM operations. `u32clo` counts the number of /// leading ones of the value using non-deterministic technique (i.e. it takes help of advice /// provider). /// -/// This operation takes 36 VM cycles. -pub fn u32clo(span: &mut BasicBlockBuilder) { - span.push_advice_injector(AdviceInjector::U32Clo); - span.push_op(AdvPop); // [clo, n, ...] +/// This operation takes 41 VM cycles. +pub fn u32clo(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(AdviceInjector::U32Clo)?; + block_builder.push_op(AdvPop); // [clo, n, ...] - calculate_clo(span); + verify_clo(block_builder); + Ok(()) } /// Translates `u32cto` assembly instruction to VM operations. `u32cto` counts the number of @@ -334,11 +337,12 @@ pub fn u32clo(span: &mut BasicBlockBuilder) { /// provider). /// /// This operation takes 33 VM cycles. -pub fn u32cto(span: &mut BasicBlockBuilder) { - span.push_advice_injector(AdviceInjector::U32Cto); - span.push_op(AdvPop); // [cto, n, ...] +pub fn u32cto(block_builder: &mut BasicBlockBuilder) -> Result<(), AssemblyError> { + block_builder.push_advice_injector(AdviceInjector::U32Cto)?; + block_builder.push_op(AdvPop); // [cto, n, ...] - calculate_cto(span); + verify_cto(block_builder); + Ok(()) } /// Specifically handles these specific inputs per the spec. @@ -347,27 +351,27 @@ pub fn u32cto(span: &mut BasicBlockBuilder) { /// - Overflowing: does not check if the inputs are u32 values; overflow or underflow bits are /// pushed onto the stack. fn handle_arithmetic_operation( - span_builder: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, op: Operation, op_mode: U32OpMode, imm: Option, ) { if let Some(imm) = imm { - push_u32_value(span_builder, imm); + push_u32_value(block_builder, imm); } - span_builder.push_op(op); + block_builder.push_op(op); // in the wrapping mode, drop high 32 bits if matches!(op_mode, U32OpMode::Wrapping) { - span_builder.push_op(Drop); + block_builder.push_op(Drop); } } /// Handles common parts of u32div, u32mod, and u32divmod operations, including handling of /// immediate parameters. fn handle_division( - span_builder: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, proc_ctx: &ProcedureContext, imm: Option>, ) -> Result<(), AssemblyError> { @@ -382,10 +386,10 @@ fn handle_division( Err(AssemblyError::Other(RelatedError::new(error))) }; } - push_u32_value(span_builder, imm.into_inner()); + push_u32_value(block_builder, imm.into_inner()); } - span_builder.push_op(U32div); + block_builder.push_op(U32div); Ok(()) } @@ -395,20 +399,20 @@ fn handle_division( /// Mutate the first two elements of the stack from `[b, a, ..]` into `[2^b, a, ..]`, with `b` /// either as a provided immediate value, or as an element that already exists in the stack. fn prepare_bitwise( - span_builder: &mut BasicBlockBuilder, + block_builder: &mut BasicBlockBuilder, imm: Option, ) -> Result<(), AssemblyError> { match imm { Some(0) => { // if shift/rotation is performed by 0, do nothing (Noop) - span_builder.push_op(Noop); + block_builder.push_op(Noop); }, Some(imm) => { validate_param(imm, 1..=MAX_VALUE)?; - span_builder.push_op(Push(Felt::new(1 << imm))); + block_builder.push_op(Push(Felt::new(1 << imm))); }, None => { - append_pow2_op(span_builder); + append_pow2_op(block_builder); }, } Ok(()) @@ -449,44 +453,50 @@ fn prepare_bitwise( /// /// `[clz, n, ... ] -> [clz, ... ]` /// -/// VM cycles: 36 -fn calculate_clz(span: &mut BasicBlockBuilder) { +/// VM cycles: 42 +fn verify_clz(block_builder: &mut BasicBlockBuilder) { // [clz, n, ...] #[rustfmt::skip] let ops_group_1 = [ - Swap, Push(32u8.into()), Dup2, Neg, Add // [32 - clz, n, clz, ...] + Push(32u8.into()), Dup1, Neg, Add // [32 - clz, clz, n, ...] ]; - span.push_ops(ops_group_1); + block_builder.push_ops(ops_group_1); - append_pow2_op(span); // [pow2(32 - clz), n, clz, ...] + append_pow2_op(block_builder); // [pow2(32 - clz), clz, n, ...] #[rustfmt::skip] let ops_group_2 = [ - Push(Felt::new(u32::MAX as u64 + 1)), // [2^32, pow2(32 - clz), n, clz, ...] - - Dup1, Neg, Add, // [2^32 - pow2(32 - clz), pow2(32 - clz), n, clz, ...] - // `2^32 - pow2(32 - clz)` is equal to `clz` leading ones and `32 - clz` - // zeros: - // 1111111111...1110000...0 - // └─ `clz` ones ─┘ - - Swap, Push(2u8.into()), U32div, Drop, // [pow2(32 - clz) / 2, 2^32 - pow2(32 - clz), n, clz, ...] - // pow2(32 - clz) / 2 is equal to `clz` leading - // zeros, `1` one and all other zeros. - - Swap, Dup1, Add, // [bit_mask, pow2(32 - clz) / 2, n, clz, ...] - // 1111111111...111000...0 <-- bitmask - // └─ clz ones ─┘│ - // └─ additional one - - MovUp2, U32and, // [m, pow2(32 - clz) / 2, clz] - // If calcualtion of `clz` is correct, m should be equal to - // pow2(32 - clz) / 2 - - Eq, Assert(0) // [clz, ...] + // 1. Obtain a mask for all `32 - clz` trailing bits + // + // #=> [2^(32 - clz) - 1, clz, n] + Push(1u8.into()), Neg, Add, + // 2. Compute a value that represents setting the first non-zero bit to 1, i.e. if there + // are 2 leading zeros, this would set the 3rd most significant bit to 1, with all other + // bits set to zero. + // + // NOTE: This first step is an intermediate computation. + // + // #=> [(2^(32 - clz) - 1) / 2, clz, n, ...] + Push(2u8.into()), U32div, Drop, + // Save the intermediate result of dividing by 2 for reuse in the next step + // + // #=> [((2^(32 - clz) - 1) / 2) + 1, (2^(32 - clz) - 1) / 2, clz, n, ...] + Dup0, Incr, + // 3. Obtain a mask for `clz + 1` leading bits + // + // #=> [u32::MAX - (2^(32 - clz) - 1 / 2), ((2^(32 - clz) - 1) / 2) + 1, clz, n, ...] + Push(u32::MAX.into()), MovUp2, Neg, Add, + // 4. Set zero flag if input was zero, and apply the mask to the input value + // + // #=> [n & mask, (2^(32 - clz) - 1 / 2) + 1, clz, is_zero] + Dup3, Eqz, MovDn3, MovUp4, U32and, + // 6. Assert that the masked input, and the mask representing `clz` leading zeros, followed + // by at least one trailing one, if `clz < 32`, are equal; OR that the input was zero if `clz` + // is 32. + Eq, MovUp2, Or, Assert(0), ]; - span.push_ops(ops_group_2); + block_builder.push_ops(ops_group_2); } /// Appends relevant operations to the span block for the correctness check of the `U32Clo` @@ -524,44 +534,44 @@ fn calculate_clz(span: &mut BasicBlockBuilder) { /// /// `[clo, n, ... ] -> [clo, ... ]` /// -/// VM cycles: 35 -fn calculate_clo(span: &mut BasicBlockBuilder) { +/// VM cycle: 40 +fn verify_clo(block_builder: &mut BasicBlockBuilder) { // [clo, n, ...] #[rustfmt::skip] let ops_group_1 = [ - Swap, Push(32u8.into()), Dup2, Neg, Add // [32 - clo, n, clo, ...] + Push(32u8.into()), Dup1, Neg, Add // [32 - clo, clo, n, ...] ]; - span.push_ops(ops_group_1); + block_builder.push_ops(ops_group_1); - append_pow2_op(span); // [pow2(32 - clo), n, clo, ...] + append_pow2_op(block_builder); // [pow2(32 - clo), clo, n, ...] #[rustfmt::skip] let ops_group_2 = [ - Push(Felt::new(u32::MAX as u64 + 1)), // [2^32, pow2(32 - clo), n, clo, ...] - - Dup1, Neg, Add, // [2^32 - pow2(32 - clo), pow2(32 - clo), n, clo, ...] - // `2^32 - pow2(32 - clo)` is equal to `clo` leading ones and `32 - clo` - // zeros: - // 11111111...1110000...0 - // └─ clo ones ─┘ - - Swap, Push(2u8.into()), U32div, Drop, // [pow2(32 - clo) / 2, 2^32 - pow2(32 - clo), n, clo, ...] - // pow2(32 - clo) / 2 is equal to `clo` leading - // zeros, `1` one and all other zeros. - - Dup1, Add, // [bit_mask, 2^32 - pow2(32 - clo), n, clo, ...] - // 111111111...111000...0 <-- bitmask - // └─ clo ones ─┘│ - // └─ additional one - - MovUp2, U32and, // [m, 2^32 - pow2(32 - clo), clo] - // If calcualtion of `clo` is correct, m should be equal to - // 2^32 - pow2(32 - clo) - - Eq, Assert(0) // [clo, ...] + // 1. Obtain a mask for all `32 - clo` trailing bits + // + // #=> [2^(32 - clo) - 1, clo, n] + Push(1u8.into()), Neg, Add, + // 2. Obtain a mask for `32 - clo - 1` trailing bits + // + // #=> [(2^(32 - clo) - 1) / 2, 2^(32 - clo) - 1, clo, n] + Dup0, Push(2u8.into()), U32div, Drop, + // 3. Invert the mask from Step 2, to get one that covers `clo + 1` leading bits + // + // #=> [u32::MAX - ((2^(32 - clo) - 1) / 2), 2^(32 - clo) - 1, clo, n] + Push(u32::MAX.into()), Swap, Neg, Add, + // 4. Apply the mask to the input value + // + // #=> [n & mask, 2^(32 - clo) - 1, clo] + MovUp3, U32and, + // 5. Invert the mask from Step 1, to get one that covers `clo` leading bits + // + // #=> [u32::MAX - 2^(32 - clo) - 1, n & mask, clo] + Push(u32::MAX.into()), MovUp2, Neg, Add, + // 6. Assert that the masked input, and the mask representing `clo` leading ones, are equal + Eq, Assert(0), ]; - span.push_ops(ops_group_2); + block_builder.push_ops(ops_group_2); } /// Appends relevant operations to the span block for the correctness check of the `U32Ctz` @@ -600,15 +610,15 @@ fn calculate_clo(span: &mut BasicBlockBuilder) { /// `[ctz, n, ... ] -> [ctz, ... ]` /// /// VM cycles: 33 -fn calculate_ctz(span: &mut BasicBlockBuilder) { +fn verify_ctz(block_builder: &mut BasicBlockBuilder) { // [ctz, n, ...] #[rustfmt::skip] let ops_group_1 = [ Swap, Dup1, // [ctz, n, ctz, ...] ]; - span.push_ops(ops_group_1); + block_builder.push_ops(ops_group_1); - append_pow2_op(span); // [pow2(ctz), n, ctz, ...] + append_pow2_op(block_builder); // [pow2(ctz), n, ctz, ...] #[rustfmt::skip] let ops_group_2 = [ @@ -635,7 +645,7 @@ fn calculate_ctz(span: &mut BasicBlockBuilder) { Eq, Assert(0), // [ctz, ...] ]; - span.push_ops(ops_group_2); + block_builder.push_ops(ops_group_2); } /// Appends relevant operations to the span block for the correctness check of the `U32Cto` @@ -674,15 +684,15 @@ fn calculate_ctz(span: &mut BasicBlockBuilder) { /// `[cto, n, ... ] -> [cto, ... ]` /// /// VM cycles: 32 -fn calculate_cto(span: &mut BasicBlockBuilder) { +fn verify_cto(block_builder: &mut BasicBlockBuilder) { // [cto, n, ...] #[rustfmt::skip] let ops_group_1 = [ Swap, Dup1, // [cto, n, cto, ...] ]; - span.push_ops(ops_group_1); + block_builder.push_ops(ops_group_1); - append_pow2_op(span); // [pow2(cto), n, cto, ...] + append_pow2_op(block_builder); // [pow2(cto), n, cto, ...] #[rustfmt::skip] let ops_group_2 = [ @@ -709,7 +719,7 @@ fn calculate_cto(span: &mut BasicBlockBuilder) { Eq, Assert(0), // [cto, ...] ]; - span.push_ops(ops_group_2); + block_builder.push_ops(ops_group_2); } // COMPARISON OPERATIONS @@ -720,8 +730,8 @@ fn calculate_cto(span: &mut BasicBlockBuilder) { /// This operation takes: /// - 3 cycles without immediate value. /// - 4 cycles with immediate value. -pub fn u32lt(span_builder: &mut BasicBlockBuilder) { - compute_lt(span_builder); +pub fn u32lt(block_builder: &mut BasicBlockBuilder) { + compute_lt(block_builder); } /// Translates u32lte assembly instruction to VM operations. @@ -729,13 +739,13 @@ pub fn u32lt(span_builder: &mut BasicBlockBuilder) { /// This operation takes: /// - 5 cycles without immediate value. /// - 6 cycles with immediate value. -pub fn u32lte(span_builder: &mut BasicBlockBuilder) { +pub fn u32lte(block_builder: &mut BasicBlockBuilder) { // Compute the lt with reversed number to get a gt check - span_builder.push_op(Swap); - compute_lt(span_builder); + block_builder.push_op(Swap); + compute_lt(block_builder); // Flip the final results to get the lte results. - span_builder.push_op(Not); + block_builder.push_op(Not); } /// Translates u32gt assembly instruction to VM operations. @@ -743,11 +753,11 @@ pub fn u32lte(span_builder: &mut BasicBlockBuilder) { /// This operation takes: /// - 4 cycles without immediate value. /// - 5 cycles with immediate value. -pub fn u32gt(span_builder: &mut BasicBlockBuilder) { +pub fn u32gt(block_builder: &mut BasicBlockBuilder) { // Reverse the numbers so we can get a gt check. - span_builder.push_op(Swap); + block_builder.push_op(Swap); - compute_lt(span_builder); + compute_lt(block_builder); } /// Translates u32gte assembly instruction to VM operations. @@ -755,11 +765,11 @@ pub fn u32gt(span_builder: &mut BasicBlockBuilder) { /// This operation takes: /// - 4 cycles without immediate value. /// - 5 cycles with immediate value. -pub fn u32gte(span_builder: &mut BasicBlockBuilder) { - compute_lt(span_builder); +pub fn u32gte(block_builder: &mut BasicBlockBuilder) { + compute_lt(block_builder); // Flip the final results to get the gte results. - span_builder.push_op(Not); + block_builder.push_op(Not); } /// Translates u32min assembly instruction to VM operations. @@ -771,11 +781,11 @@ pub fn u32gte(span_builder: &mut BasicBlockBuilder) { /// This operation takes: /// - 8 cycles without immediate value. /// - 9 cycles with immediate value. -pub fn u32min(span_builder: &mut BasicBlockBuilder) { - compute_max_and_min(span_builder); +pub fn u32min(block_builder: &mut BasicBlockBuilder) { + compute_max_and_min(block_builder); // Drop the max and keep the min - span_builder.push_op(Drop); + block_builder.push_op(Drop); } /// Translates u32max assembly instruction to VM operations. @@ -787,11 +797,11 @@ pub fn u32min(span_builder: &mut BasicBlockBuilder) { /// This operation takes: /// - 9 cycles without immediate value. /// - 10 cycles with immediate value. -pub fn u32max(span_builder: &mut BasicBlockBuilder) { - compute_max_and_min(span_builder); +pub fn u32max(block_builder: &mut BasicBlockBuilder) { + compute_max_and_min(block_builder); // Drop the min and keep the max - span_builder.push_ops([Swap, Drop]); + block_builder.push_ops([Swap, Drop]); } // COMPARISON OPERATIONS - HELPERS @@ -799,8 +809,8 @@ pub fn u32max(span_builder: &mut BasicBlockBuilder) { /// Inserts the VM operations to check if the second element is less than /// the top element. This takes 3 cycles. -fn compute_lt(span_builder: &mut BasicBlockBuilder) { - span_builder.push_ops([ +fn compute_lt(block_builder: &mut BasicBlockBuilder) { + block_builder.push_ops([ U32sub, Swap, Drop, // Perform the operations ]) } @@ -808,12 +818,12 @@ fn compute_lt(span_builder: &mut BasicBlockBuilder) { /// Duplicate the top two elements in the stack and determine the min and max between them. /// /// The maximum number will be at the top of the stack and minimum will be at the 2nd index. -fn compute_max_and_min(span_builder: &mut BasicBlockBuilder) { +fn compute_max_and_min(block_builder: &mut BasicBlockBuilder) { // Copy top two elements of the stack. - span_builder.push_ops([Dup1, Dup1]); + block_builder.push_ops([Dup1, Dup1]); #[rustfmt::skip] - span_builder.push_ops([ + block_builder.push_ops([ U32sub, Swap, Drop, // Check the underflow flag, if it's zero diff --git a/assembly/src/assembler/mast_forest_builder.rs b/assembly/src/assembler/mast_forest_builder.rs index e072bfff27..96e7d51b52 100644 --- a/assembly/src/assembler/mast_forest_builder.rs +++ b/assembly/src/assembler/mast_forest_builder.rs @@ -1,14 +1,15 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, - sync::Arc, vec::Vec, }; -use core::ops::Index; +use core::ops::{Index, IndexMut}; use vm_core::{ crypto::hash::RpoDigest, - mast::{MastForest, MastNode, MastNodeId}, - DecoratorList, Operation, + mast::{ + DecoratorFingerprint, DecoratorId, MastForest, MastNode, MastNodeFingerprint, MastNodeId, + }, + Decorator, DecoratorList, Operation, }; use super::{GlobalProcedureIndex, Procedure}; @@ -24,14 +25,40 @@ const PROCEDURE_INLINING_THRESHOLD: usize = 32; // ================================================================================================ /// Builder for a [`MastForest`]. +/// +/// The purpose of the builder is to ensure that the underlying MAST forest contains as little +/// information as possible needed to adequately describe the logical MAST forest. Specifically: +/// - The builder ensures that only one copy of nodes that have the same MAST root and decorators is +/// added to the MAST forest (i.e., two nodes that have the same MAST root and decorators will +/// have the same [`MastNodeId`]). +/// - The builder tries to merge adjacent basic blocks and eliminate the source block whenever this +/// does not have an impact on other nodes in the forest. #[derive(Clone, Debug, Default)] pub struct MastForestBuilder { + /// The MAST forest being built by this builder; this MAST forest is up-to-date - i.e., all + /// nodes added to the MAST forest builder are also immediately added to the underlying MAST + /// forest. mast_forest: MastForest, - node_id_by_hash: BTreeMap, - procedures: BTreeMap>, - procedure_hashes: BTreeMap, - proc_gid_by_hash: BTreeMap, - merged_node_ids: BTreeSet, + /// A map of all procedures added to the MAST forest indexed by their global procedure ID. + /// This includes all local, exported, and re-exported procedures. In case multiple procedures + /// with the same digest are added to the MAST forest builder, only the first procedure is + /// added to the map, and all subsequent insertions are ignored. + procedures: BTreeMap, + /// A map from procedure MAST root to its global procedure index. Similar to the `procedures` + /// map, this map contains only the first inserted procedure for procedures with the same MAST + /// root. + proc_gid_by_mast_root: BTreeMap, + /// A map of MAST node fingerprints to their corresponding positions in the MAST forest. + node_id_by_fingerprint: BTreeMap, + /// The reverse mapping of `node_id_by_fingerprint`. This map caches the fingerprints of all + /// nodes (for performance reasons). + hash_by_node_id: BTreeMap, + /// A map of decorator fingerprints to their corresponding positions in the MAST forest. + decorator_id_by_fingerprint: BTreeMap, + /// A set of IDs for basic blocks which have been merged into a bigger basic blocks. This is + /// used as a candidate set of nodes that may be eliminated if the are not referenced by any + /// other node in the forest and are not a root of any procedure. + merged_basic_block_ids: BTreeSet, } impl MastForestBuilder { @@ -42,7 +69,7 @@ impl MastForestBuilder { /// unchanged. Any [`MastNodeId`] used in reference to the old [`MastForest`] should be remapped /// using this map. pub fn build(mut self) -> (MastForest, Option>) { - let nodes_to_remove = get_nodes_to_remove(self.merged_node_ids, &self.mast_forest); + let nodes_to_remove = get_nodes_to_remove(self.merged_basic_block_ids, &self.mast_forest); let id_remappings = self.mast_forest.remove_nodes(&nodes_to_remove); (self.mast_forest, id_remappings) @@ -96,7 +123,7 @@ fn get_nodes_to_remove( nodes_to_remove.remove(&node.callee()); } }, - MastNode::Block(_) | MastNode::Dyn | MastNode::External(_) => (), + MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => (), } } @@ -109,29 +136,17 @@ impl MastForestBuilder { /// Returns a reference to the procedure with the specified [`GlobalProcedureIndex`], or None /// if such a procedure is not present in this MAST forest builder. #[inline(always)] - pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option> { - self.procedures.get(&gid).cloned() - } - - /// Returns the hash of the procedure with the specified [`GlobalProcedureIndex`], or None if - /// such a procedure is not present in this MAST forest builder. - #[inline(always)] - pub fn get_procedure_hash(&self, gid: GlobalProcedureIndex) -> Option { - self.procedure_hashes.get(&gid).cloned() + pub fn get_procedure(&self, gid: GlobalProcedureIndex) -> Option<&Procedure> { + self.procedures.get(&gid) } /// Returns a reference to the procedure with the specified MAST root, or None /// if such a procedure is not present in this MAST forest builder. #[inline(always)] - pub fn find_procedure(&self, mast_root: &RpoDigest) -> Option> { - self.proc_gid_by_hash.get(mast_root).and_then(|gid| self.get_procedure(*gid)) - } - - /// Returns the [`MastNodeId`] of the procedure associated with a given MAST root, or None - /// if such a procedure is not present in this MAST forest builder. - #[inline(always)] - pub fn find_procedure_node_id(&self, digest: RpoDigest) -> Option { - self.mast_forest.find_procedure_root(digest) + pub fn find_procedure_by_mast_root(&self, mast_root: &RpoDigest) -> Option<&Procedure> { + self.proc_gid_by_mast_root + .get(mast_root) + .and_then(|gid| self.get_procedure(*gid)) } /// Returns the [`MastNode`] for the provided MAST node ID, or None if a node with this ID is @@ -141,18 +156,9 @@ impl MastForestBuilder { } } +// ------------------------------------------------------------------------------------------------ +/// Procedure insertion impl MastForestBuilder { - pub fn insert_procedure_hash( - &mut self, - gid: GlobalProcedureIndex, - proc_hash: RpoDigest, - ) -> Result<(), AssemblyError> { - // TODO(plafer): Check if exists - self.procedure_hashes.insert(gid, proc_hash); - - Ok(()) - } - /// Inserts a procedure into this MAST forest builder. /// /// If the procedure with the same ID already exists in this forest builder, this will have @@ -162,8 +168,6 @@ impl MastForestBuilder { gid: GlobalProcedureIndex, procedure: Procedure, ) -> Result<(), AssemblyError> { - let proc_root = self.mast_forest[procedure.body_node_id()].digest(); - // Check if an entry is already in this cache slot. // // If there is already a cache entry, but it conflicts with what we're trying to cache, @@ -181,7 +185,7 @@ impl MastForestBuilder { // We don't have a cache entry yet, but we do want to make sure we don't have a conflicting // cache entry with the same MAST root: - if let Some(cached) = self.find_procedure(&proc_root) { + if let Some(cached) = self.find_procedure_by_mast_root(&procedure.mast_root()) { // Handle the case where a procedure with no locals is lowered to a MastForest // consisting only of an `External` node to another procedure which has one or more // locals. This will result in the calling procedure having the same digest as the @@ -202,19 +206,17 @@ impl MastForestBuilder { } } - self.make_root(procedure.body_node_id()); - self.proc_gid_by_hash.insert(proc_root, gid); - self.insert_procedure_hash(gid, procedure.mast_root())?; - self.procedures.insert(gid, Arc::new(procedure)); + self.mast_forest.make_root(procedure.body_node_id()); + self.proc_gid_by_mast_root.insert(procedure.mast_root(), gid); + self.procedures.insert(gid, procedure); Ok(()) } +} - /// Marks the given [`MastNodeId`] as being the root of a procedure. - pub fn make_root(&mut self, new_root_id: MastNodeId) { - self.mast_forest.make_root(new_root_id) - } - +// ------------------------------------------------------------------------------------------------ +/// Joining nodes +impl MastForestBuilder { /// Builds a tree of `JOIN` operations to combine the provided MAST node IDs. pub fn join_nodes(&mut self, node_ids: Vec) -> Result { debug_assert!(!node_ids.is_empty(), "cannot combine empty MAST node id list"); @@ -254,7 +256,7 @@ impl MastForestBuilder { let mut contiguous_basic_block_ids: Vec = Vec::new(); for mast_node_id in node_ids { - if self[mast_node_id].is_basic_block() { + if self.mast_forest[mast_node_id].is_basic_block() { contiguous_basic_block_ids.push(mast_node_id); } else { merged_node_ids.extend(self.merge_basic_blocks(&contiguous_basic_block_ids)?); @@ -293,15 +295,16 @@ impl MastForestBuilder { for &basic_block_id in contiguous_basic_block_ids { // It is safe to unwrap here, since we already checked that all IDs in // `contiguous_basic_block_ids` are `BasicBlockNode`s - let basic_block_node = self[basic_block_id].get_basic_block().unwrap().clone(); + let basic_block_node = + self.mast_forest[basic_block_id].get_basic_block().unwrap().clone(); // check if the block should be merged with other blocks if should_merge( self.mast_forest.is_procedure_root(basic_block_id), basic_block_node.num_op_batches(), ) { - for (op_idx, decorator) in basic_block_node.decorators() { - decorators.push((*op_idx + operations.len(), decorator.clone())); + for &(op_idx, decorator) in basic_block_node.decorators() { + decorators.push((op_idx + operations.len(), decorator)); } for batch in basic_block_node.op_batches() { operations.extend_from_slice(batch.ops()); @@ -322,7 +325,7 @@ impl MastForestBuilder { } // Mark the removed basic blocks as merged - self.merged_node_ids.extend(contiguous_basic_block_ids.iter()); + self.merged_basic_block_ids.extend(contiguous_basic_block_ids.iter()); if !operations.is_empty() || !decorators.is_empty() { let merged_basic_block = self.ensure_block(operations, Some(decorators))?; @@ -336,20 +339,36 @@ impl MastForestBuilder { // ------------------------------------------------------------------------------------------------ /// Node inserters impl MastForestBuilder { + /// Adds a decorator to the forest, and returns the [`Decorator`] associated with it. + pub fn ensure_decorator(&mut self, decorator: Decorator) -> Result { + let decorator_hash = decorator.fingerprint(); + + if let Some(decorator_id) = self.decorator_id_by_fingerprint.get(&decorator_hash) { + // decorator already exists in the forest; return previously assigned id + Ok(*decorator_id) + } else { + let new_decorator_id = self.mast_forest.add_decorator(decorator)?; + self.decorator_id_by_fingerprint.insert(decorator_hash, new_decorator_id); + + Ok(new_decorator_id) + } + } + /// Adds a node to the forest, and returns the [`MastNodeId`] associated with it. /// - /// If a [`MastNode`] which is equal to the current node was previously added, the previously - /// returned [`MastNodeId`] will be returned. This enforces this invariant that equal - /// [`MastNode`]s have equal [`MastNodeId`]s. - fn ensure_node(&mut self, node: MastNode) -> Result { - let node_digest = node.digest(); + /// Note that only one copy of nodes that have the same MAST root and decorators is added to the + /// MAST forest; two nodes that have the same MAST root and decorators will have the same + /// [`MastNodeId`]. + pub fn ensure_node(&mut self, node: MastNode) -> Result { + let node_fingerprint = self.fingerprint_for_node(&node); - if let Some(node_id) = self.node_id_by_hash.get(&node_digest) { + if let Some(node_id) = self.node_id_by_fingerprint.get(&node_fingerprint) { // node already exists in the forest; return previously assigned id Ok(*node_id) } else { let new_node_id = self.mast_forest.add_node(node)?; - self.node_id_by_hash.insert(node_digest, new_node_id); + self.node_id_by_fingerprint.insert(node_fingerprint, new_node_id); + self.hash_by_node_id.insert(new_node_id, node_fingerprint); Ok(new_node_id) } @@ -408,10 +427,36 @@ impl MastForestBuilder { self.ensure_node(MastNode::new_dyn()) } + /// Adds a dyncall node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn ensure_dyncall(&mut self) -> Result { + self.ensure_node(MastNode::new_dyncall()) + } + /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result { self.ensure_node(MastNode::new_external(mast_root)) } + + pub fn set_before_enter(&mut self, node_id: MastNodeId, decorator_ids: Vec) { + self.mast_forest[node_id].set_before_enter(decorator_ids); + + let new_node_fingerprint = self.fingerprint_for_node(&self[node_id]); + self.hash_by_node_id.insert(node_id, new_node_fingerprint); + } + + pub fn set_after_exit(&mut self, node_id: MastNodeId, decorator_ids: Vec) { + self.mast_forest[node_id].set_after_exit(decorator_ids); + + let new_node_fingerprint = self.fingerprint_for_node(&self[node_id]); + self.hash_by_node_id.insert(node_id, new_node_fingerprint); + } +} + +impl MastForestBuilder { + fn fingerprint_for_node(&self, node: &MastNode) -> MastNodeFingerprint { + MastNodeFingerprint::from_mast_node(&self.mast_forest, &self.hash_by_node_id, node) + .expect("hash_by_node_id should contain the fingerprints of all children of `node`") + } } impl Index for MastForestBuilder { @@ -423,6 +468,22 @@ impl Index for MastForestBuilder { } } +impl Index for MastForestBuilder { + type Output = Decorator; + + #[inline(always)] + fn index(&self, decorator_id: DecoratorId) -> &Self::Output { + &self.mast_forest[decorator_id] + } +} + +impl IndexMut for MastForestBuilder { + #[inline(always)] + fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output { + &mut self.mast_forest[decorator_id] + } +} + // HELPER FUNCTIONS // ================================================================================================ diff --git a/assembly/src/assembler/mast_forest_merger_tests.rs b/assembly/src/assembler/mast_forest_merger_tests.rs new file mode 100644 index 0000000000..96e533992c --- /dev/null +++ b/assembly/src/assembler/mast_forest_merger_tests.rs @@ -0,0 +1,73 @@ +use miette::{IntoDiagnostic, Report}; +use vm_core::mast::{MastForest, MastForestRootMap}; + +use crate::{testing::TestContext, Assembler}; + +#[allow(clippy::type_complexity)] +fn merge_programs( + program_a: &str, + program_b: &str, +) -> Result<(MastForest, MastForest, MastForest, MastForestRootMap), Report> { + let context = TestContext::new(); + let module = context.parse_module_with_path("lib::mod".parse().unwrap(), program_a)?; + + let lib_a = Assembler::new(context.source_manager()).assemble_library([module])?; + + let mut assembler = Assembler::new(context.source_manager()); + assembler.add_library(lib_a.clone())?; + let lib_b = assembler.assemble_library([program_b])?.mast_forest().as_ref().clone(); + let lib_a = lib_a.mast_forest().as_ref().clone(); + + let (merged, root_maps) = MastForest::merge([&lib_a, &lib_b]).into_diagnostic()?; + + Ok((lib_a, lib_b, merged, root_maps)) +} + +/// Tests that an assembler-produced library's forests can be merged and that external nodes are +/// replaced by their referenced procedures. +#[test] +fn mast_forest_merge_assembler() { + let lib_a = r#" + export.foo + push.19 + end + + export.qux + swap drop + end +"#; + + let lib_b = r#" + use.lib::mod + + export.qux_duplicate + swap drop + end + + export.bar + push.2 + if.true + push.3 + else + while.true + add + push.23 + end + end + exec.mod::foo + end"#; + + let (forest_a, forest_b, merged, root_maps) = merge_programs(lib_a, lib_b).unwrap(); + + for (forest_idx, forest) in [forest_a, forest_b].into_iter().enumerate() { + for root in forest.procedure_roots() { + let original_digest = forest.nodes()[root.as_usize()].digest(); + let new_root = root_maps.map_root(forest_idx, root).unwrap(); + let new_digest = merged.nodes()[new_root.as_usize()].digest(); + assert_eq!(original_digest, new_digest); + } + } + + // Assert that the external node for the import was removed during merging. + merged.nodes().iter().for_each(|node| assert!(!node.is_external())); +} diff --git a/assembly/src/assembler/mod.rs b/assembly/src/assembler/mod.rs index ed67ed1546..2ddc94e31a 100644 --- a/assembly/src/assembler/mod.rs +++ b/assembly/src/assembler/mod.rs @@ -1,16 +1,21 @@ use alloc::{collections::BTreeMap, sync::Arc, vec::Vec}; +use basic_block_builder::BasicBlockOrDecorators; use mast_forest_builder::MastForestBuilder; use module_graph::{ProcedureWrapper, WrappedModule}; -use vm_core::{mast::MastNodeId, DecoratorList, Felt, Kernel, Operation, Program}; +use vm_core::{ + crypto::hash::RpoDigest, + debuginfo::SourceSpan, + mast::{DecoratorId, MastNodeId}, + DecoratorList, Felt, Kernel, Operation, Program, +}; use crate::{ ast::{self, Export, InvocationTarget, InvokeKind, ModuleKind, QualifiedProcedureName}, diagnostics::Report, library::{KernelLibrary, Library}, sema::SemanticAnalysisError, - AssemblyError, Compile, CompileOptions, LibraryNamespace, LibraryPath, RpoDigest, - SourceManager, Spanned, + AssemblyError, Compile, CompileOptions, LibraryNamespace, LibraryPath, SourceManager, Spanned, }; mod basic_block_builder; @@ -23,6 +28,9 @@ mod procedure; #[cfg(test)] mod tests; +#[cfg(test)] +mod mast_forest_merger_tests; + use self::{ basic_block_builder::BasicBlockBuilder, module_graph::{CallerInfo, ModuleGraph, ResolvedTarget}, @@ -202,6 +210,21 @@ impl Assembler { } /// Adds the compiled library to provide modules for the compilation. + /// + /// We only current support adding non-vendored libraries - that is, the source code of exported + /// procedures is not included in the program that compiles against the library. The library's + /// source code is instead expected to be loaded in the processor at execution time. Hence, all + /// calls to library procedures will be compiled down to a [`vm_core::mast::ExternalNode`] (i.e. + /// a reference to the procedure's MAST root). This means that when executing a program compiled + /// against a library, the processor will not be able to differentiate procedures with the same + /// MAST root but different decorators. + /// + /// Hence, it is not recommended to export two procedures that have the same MAST root (i.e. are + /// identical except for their decorators). Note however that we don't expect this scenario to + /// be frequent in practice. For example, this could occur when APIs are being renamed and/or + /// moved between modules, and for some deprecation period, the same is exported under both its + /// old and new paths. Or possibly with common small functions that are implemented by the main + /// program and one of its dependencies. pub fn add_library(&mut self, library: impl AsRef) -> Result<(), Report> { self.module_graph .add_compiled_modules(library.as_ref().module_infos()) @@ -210,6 +233,8 @@ impl Assembler { } /// Adds the compiled library to provide modules for the compilation. + /// + /// See [`Self::add_library`] for more detailed information. pub fn with_library(mut self, library: impl AsRef) -> Result { self.add_library(library)?; Ok(self) @@ -277,7 +302,7 @@ impl Assembler { let mut mast_forest_builder = MastForestBuilder::default(); - let exports = { + let mut exports = { let mut exports = BTreeMap::new(); for module_idx in ast_module_indices { @@ -289,10 +314,11 @@ impl Assembler { let gid = module_idx + proc_idx; self.compile_subgraph(gid, &mut mast_forest_builder)?; - let proc_hash = mast_forest_builder - .get_procedure_hash(gid) - .expect("compilation succeeded but root not found in cache"); - exports.insert(fqn, proc_hash); + let proc_root_node_id = mast_forest_builder + .get_procedure(gid) + .expect("compilation succeeded but root not found in cache") + .body_node_id(); + exports.insert(fqn, proc_root_node_id); } } @@ -300,8 +326,16 @@ impl Assembler { }; // TODO: show a warning if library exports are empty? - let (mast_forest, _) = mast_forest_builder.build(); - Ok(Library::new(mast_forest, exports)) + let (mast_forest, id_remappings) = mast_forest_builder.build(); + if let Some(id_remappings) = id_remappings { + for (_proc_name, node_id) in exports.iter_mut() { + if let Some(&new_node_id) = id_remappings.get(node_id) { + *node_id = new_node_id; + } + } + } + + Ok(Library::new(mast_forest.into(), exports)?) } /// Assembles the provided module into a [KernelLibrary] intended to be used as a Kernel. @@ -327,23 +361,30 @@ impl Assembler { // AST (we just added them to the module graph) let ast_module = self.module_graph[module_idx].unwrap_ast().clone(); - let exports = ast_module + let mut exports = ast_module .exported_procedures() .map(|(proc_idx, fqn)| { let gid = module_idx + proc_idx; self.compile_subgraph(gid, &mut mast_forest_builder)?; - let proc_hash = mast_forest_builder - .get_procedure_hash(gid) - .expect("compilation succeeded but root not found in cache"); - Ok((fqn, proc_hash)) + let proc_root_node_id = mast_forest_builder + .get_procedure(gid) + .expect("compilation succeeded but root not found in cache") + .body_node_id(); + Ok((fqn, proc_root_node_id)) }) - .collect::, Report>>()?; + .collect::, Report>>()?; // TODO: show a warning if library exports are empty? - - let (mast_forest, _) = mast_forest_builder.build(); - let library = Library::new(mast_forest, exports); + let (mast_forest, id_remappings) = mast_forest_builder.build(); + if let Some(id_remappings) = id_remappings { + for (_proc_name, node_id) in exports.iter_mut() { + if let Some(&new_node_id) = id_remappings.get(node_id) { + *node_id = new_node_id; + } + } + } + let library = Library::new(mast_forest.into(), exports)?; Ok(library.try_into()?) } @@ -379,21 +420,19 @@ impl Assembler { // Compile the module graph rooted at the entrypoint let mut mast_forest_builder = MastForestBuilder::default(); self.compile_subgraph(entrypoint, &mut mast_forest_builder)?; - let entry_procedure = mast_forest_builder + let entry_node_id = mast_forest_builder .get_procedure(entrypoint) - .expect("compilation succeeded but root not found in cache"); + .expect("compilation succeeded but root not found in cache") + .body_node_id(); + // in case the node IDs changed, update the entrypoint ID to the new value let (mast_forest, id_remappings) = mast_forest_builder.build(); - let entry_node_id = { - let old_entry_node_id = entry_procedure.body_node_id(); - - id_remappings - .map(|id_remappings| id_remappings[&old_entry_node_id]) - .unwrap_or(old_entry_node_id) - }; + let entry_node_id = id_remappings + .map(|id_remappings| id_remappings[&entry_node_id]) + .unwrap_or(entry_node_id); Ok(Program::with_kernel( - mast_forest, + mast_forest.into(), entry_node_id, self.module_graph.kernel().clone(), )) @@ -441,7 +480,7 @@ impl Assembler { while let Some(procedure_gid) = worklist.pop() { // If we have already compiled this procedure, do not recompile if let Some(proc) = mast_forest_builder.get_procedure(procedure_gid) { - self.module_graph.register_mast_root(procedure_gid, proc.mast_root())?; + self.module_graph.register_procedure_root(procedure_gid, proc.mast_root())?; continue; } // Fetch procedure metadata from the graph @@ -473,9 +512,15 @@ impl Assembler { // Compile this procedure let procedure = self.compile_procedure(pctx, mast_forest_builder)?; - - // Cache the compiled procedure. - self.module_graph.register_mast_root(procedure_gid, procedure.mast_root())?; + // TODO: if a re-exported procedure with the same MAST root had been previously + // added to the builder, this will result in unreachable nodes added to the + // MAST forest. This is because while we won't insert a duplicate node for the + // procedure body node itself, all nodes that make up the procedure body would + // be added to the forest. + + // Cache the compiled procedure + self.module_graph + .register_procedure_root(procedure_gid, procedure.mast_root())?; mast_forest_builder.insert_procedure(procedure_gid, procedure)?; }, Export::Alias(proc_alias) => { @@ -493,15 +538,20 @@ impl Assembler { ) .with_span(proc_alias.span()); - let proc_alias_root = self.resolve_target( + let proc_node_id = self.resolve_target( InvokeKind::ProcRef, &proc_alias.target().into(), &pctx, mast_forest_builder, )?; + let proc_mast_root = + mast_forest_builder.get_mast_node(proc_node_id).unwrap().digest(); + + let procedure = pctx.into_procedure(proc_mast_root, proc_node_id); + // Make the MAST root available to all dependents - self.module_graph.register_mast_root(procedure_gid, proc_alias_root)?; - mast_forest_builder.insert_procedure_hash(procedure_gid, proc_alias_root)?; + self.module_graph.register_procedure_root(procedure_gid, proc_mast_root)?; + mast_forest_builder.insert_procedure(procedure_gid, procedure)?; }, } } @@ -554,95 +604,161 @@ impl Assembler { { use ast::Op; - let mut node_ids: Vec = Vec::new(); - let mut basic_block_builder = BasicBlockBuilder::new(wrapper); + let mut body_node_ids: Vec = Vec::new(); + let mut block_builder = BasicBlockBuilder::new(wrapper, mast_forest_builder); for op in body { match op { Op::Inst(inst) => { - if let Some(mast_node_id) = self.compile_instruction( - inst, - &mut basic_block_builder, - proc_ctx, - mast_forest_builder, - )? { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder)? - { - node_ids.push(basic_block_id); + if let Some(node_id) = + self.compile_instruction(inst, &mut block_builder, proc_ctx)? + { + if let Some(basic_block_id) = block_builder.make_basic_block()? { + body_node_ids.push(basic_block_id); + } else if let Some(decorator_ids) = block_builder.drain_decorators() { + block_builder + .mast_forest_builder_mut() + .set_before_enter(node_id, decorator_ids); } - node_ids.push(mast_node_id); + body_node_ids.push(node_id); } }, Op::If { then_blk, else_blk, .. } => { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder)? - { - node_ids.push(basic_block_id); + if let Some(basic_block_id) = block_builder.make_basic_block()? { + body_node_ids.push(basic_block_id); } - let then_blk = - self.compile_body(then_blk.iter(), proc_ctx, None, mast_forest_builder)?; - let else_blk = - self.compile_body(else_blk.iter(), proc_ctx, None, mast_forest_builder)?; + let then_blk = self.compile_body( + then_blk.iter(), + proc_ctx, + None, + block_builder.mast_forest_builder_mut(), + )?; + let else_blk = self.compile_body( + else_blk.iter(), + proc_ctx, + None, + block_builder.mast_forest_builder_mut(), + )?; + + let split_node_id = + block_builder.mast_forest_builder_mut().ensure_split(then_blk, else_blk)?; + if let Some(decorator_ids) = block_builder.drain_decorators() { + block_builder + .mast_forest_builder_mut() + .set_before_enter(split_node_id, decorator_ids) + } - let split_node_id = mast_forest_builder.ensure_split(then_blk, else_blk)?; - node_ids.push(split_node_id); + body_node_ids.push(split_node_id); }, Op::Repeat { count, body, .. } => { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder)? - { - node_ids.push(basic_block_id); + if let Some(basic_block_id) = block_builder.make_basic_block()? { + body_node_ids.push(basic_block_id); } - let repeat_node_id = - self.compile_body(body.iter(), proc_ctx, None, mast_forest_builder)?; + let repeat_node_id = self.compile_body( + body.iter(), + proc_ctx, + None, + block_builder.mast_forest_builder_mut(), + )?; - for _ in 0..*count { - node_ids.push(repeat_node_id); + if let Some(decorator_ids) = block_builder.drain_decorators() { + // Attach the decorators before the first instance of the repeated node + let mut first_repeat_node = + block_builder.mast_forest_builder_mut()[repeat_node_id].clone(); + first_repeat_node.set_before_enter(decorator_ids); + let first_repeat_node_id = block_builder + .mast_forest_builder_mut() + .ensure_node(first_repeat_node)?; + + body_node_ids.push(first_repeat_node_id); + for _ in 0..(*count - 1) { + body_node_ids.push(repeat_node_id); + } + } else { + for _ in 0..*count { + body_node_ids.push(repeat_node_id); + } } }, Op::While { body, .. } => { - if let Some(basic_block_id) = - basic_block_builder.make_basic_block(mast_forest_builder)? - { - node_ids.push(basic_block_id); + if let Some(basic_block_id) = block_builder.make_basic_block()? { + body_node_ids.push(basic_block_id); } - let loop_body_node_id = - self.compile_body(body.iter(), proc_ctx, None, mast_forest_builder)?; + let loop_node_id = { + let loop_body_node_id = self.compile_body( + body.iter(), + proc_ctx, + None, + block_builder.mast_forest_builder_mut(), + )?; + block_builder.mast_forest_builder_mut().ensure_loop(loop_body_node_id)? + }; + if let Some(decorator_ids) = block_builder.drain_decorators() { + block_builder + .mast_forest_builder_mut() + .set_before_enter(loop_node_id, decorator_ids) + } - let loop_node_id = mast_forest_builder.ensure_loop(loop_body_node_id)?; - node_ids.push(loop_node_id); + body_node_ids.push(loop_node_id); }, } } - if let Some(basic_block_id) = - basic_block_builder.try_into_basic_block(mast_forest_builder)? - { - node_ids.push(basic_block_id); - } + let maybe_post_decorators: Option> = + match block_builder.try_into_basic_block()? { + BasicBlockOrDecorators::BasicBlock(basic_block_id) => { + body_node_ids.push(basic_block_id); + None + }, + BasicBlockOrDecorators::Decorators(decorator_ids) => { + // the procedure body ends with a list of decorators + Some(decorator_ids) + }, + BasicBlockOrDecorators::Nothing => None, + }; + + let procedure_body_id = if body_node_ids.is_empty() { + // We cannot allow only decorators in a procedure body, since decorators don't change + // the MAST digest of a node. Hence, two empty procedures with different decorators + // would look the same to the `MastForestBuilder`. + if maybe_post_decorators.is_some() { + return Err(AssemblyError::EmptyProcedureBodyWithDecorators { + span: proc_ctx.span(), + source_file: proc_ctx.source_manager().get(proc_ctx.span().source_id()).ok(), + })?; + } - Ok(if node_ids.is_empty() { mast_forest_builder.ensure_block(vec![Operation::Noop], None)? } else { - mast_forest_builder.join_nodes(node_ids)? - }) + mast_forest_builder.join_nodes(body_node_ids)? + }; + + // Make sure that any post decorators are added at the end of the procedure body + if let Some(post_decorator_ids) = maybe_post_decorators { + mast_forest_builder.set_after_exit(procedure_body_id, post_decorator_ids); + } + + Ok(procedure_body_id) } + /// Resolves the specified target to the corresponding procedure root [`MastNodeId`]. + /// + /// If no [`MastNodeId`] exists for that procedure root, we wrap the root in an + /// [`crate::mast::ExternalNode`], and return the resulting [`MastNodeId`]. pub(super) fn resolve_target( &self, kind: InvokeKind, target: &InvocationTarget, proc_ctx: &ProcedureContext, - mast_forest_builder: &MastForestBuilder, - ) -> Result { + mast_forest_builder: &mut MastForestBuilder, + ) -> Result { let caller = CallerInfo { span: target.span(), module: proc_ctx.id().module, @@ -650,17 +766,86 @@ impl Assembler { }; let resolved = self.module_graph.resolve_target(&caller, target)?; match resolved { - ResolvedTarget::Phantom(digest) => Ok(digest), + ResolvedTarget::Phantom(mast_root) => self.ensure_valid_procedure_mast_root( + kind, + target.span(), + mast_root, + mast_forest_builder, + ), ResolvedTarget::Exact { gid } | ResolvedTarget::Resolved { gid, .. } => { - match mast_forest_builder.get_procedure_hash(gid) { - Some(proc_hash) => Ok(proc_hash), + match mast_forest_builder.get_procedure(gid) { + Some(proc) => Ok(proc.body_node_id()), + // We didn't find the procedure in our current MAST forest. We still need to + // check if it exists in one of a library dependency. None => match self.module_graph.get_procedure_unsafe(gid) { - ProcedureWrapper::Info(p) => Ok(p.digest), - ProcedureWrapper::Ast(_) => panic!("Did not find procedure {gid:?} neither in module graph nor procedure cache"), + ProcedureWrapper::Info(p) => self.ensure_valid_procedure_mast_root( + kind, + target.span(), + p.digest, + mast_forest_builder, + ) + , + ProcedureWrapper::Ast(_) => panic!("AST procedure {gid:?} exits in the module graph but not in the MastForestBuilder"), }, } - } + }, + } + } + + /// Verifies the validity of the MAST root as a procedure root hash, and returns the ID of the + /// [`core::mast::ExternalNode`] that wraps it. + fn ensure_valid_procedure_mast_root( + &self, + kind: InvokeKind, + span: SourceSpan, + mast_root: RpoDigest, + mast_forest_builder: &mut MastForestBuilder, + ) -> Result { + // Get the procedure from the assembler + let current_source_file = self.source_manager.get(span.source_id()).ok(); + + // If the procedure is cached and is a system call, ensure that the call is valid. + match mast_forest_builder.find_procedure_by_mast_root(&mast_root) { + Some(proc) if matches!(kind, InvokeKind::SysCall) => { + // Verify if this is a syscall, that the callee is a kernel procedure + // + // NOTE: The assembler is expected to know the full set of all kernel + // procedures at this point, so if we can't identify the callee as a + // kernel procedure, it is a definite error. + if !proc.visibility().is_syscall() { + return Err(AssemblyError::InvalidSysCallTarget { + span, + source_file: current_source_file, + callee: proc.fully_qualified_name().clone(), + }); + } + let maybe_kernel_path = proc.path(); + self.module_graph + .find_module(maybe_kernel_path) + .ok_or_else(|| AssemblyError::InvalidSysCallTarget { + span, + source_file: current_source_file.clone(), + callee: proc.fully_qualified_name().clone(), + }) + .and_then(|module| { + // Note: this module is guaranteed to be of AST variant, since we have the + // AST of a procedure contained in it (i.e. `proc`). Hence, it must be that + // the entire module is in AST representation as well. + if module.unwrap_ast().is_kernel() { + Ok(()) + } else { + Err(AssemblyError::InvalidSysCallTarget { + span, + source_file: current_source_file.clone(), + callee: proc.fully_qualified_name().clone(), + }) + } + })?; + }, + Some(_) | None => (), } + + mast_forest_builder.ensure_external(mast_root) } } diff --git a/assembly/src/assembler/module_graph/debug.rs b/assembly/src/assembler/module_graph/debug.rs index bf584fad1b..eea58adf48 100644 --- a/assembly/src/assembler/module_graph/debug.rs +++ b/assembly/src/assembler/module_graph/debug.rs @@ -14,7 +14,7 @@ impl fmt::Debug for ModuleGraph { #[doc(hidden)] struct DisplayModuleGraph<'a>(&'a ModuleGraph); -impl<'a> fmt::Debug for DisplayModuleGraph<'a> { +impl fmt::Debug for DisplayModuleGraph<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_set() .entries(self.0.modules.iter().enumerate().flat_map(|(module_index, m)| { @@ -56,7 +56,7 @@ impl<'a> fmt::Debug for DisplayModuleGraph<'a> { #[doc(hidden)] struct DisplayModuleGraphNodes<'a>(&'a Vec); -impl<'a> fmt::Debug for DisplayModuleGraphNodes<'a> { +impl fmt::Debug for DisplayModuleGraphNodes<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_list() .entries(self.0.iter().enumerate().flat_map(|(module_index, m)| { @@ -111,7 +111,7 @@ struct DisplayModuleGraphNode<'a> { ty: GraphNodeType, } -impl<'a> fmt::Debug for DisplayModuleGraphNode<'a> { +impl fmt::Debug for DisplayModuleGraphNode<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Node") .field("id", &format_args!("{}:{}", &self.module.as_usize(), &self.index.as_usize())) @@ -128,7 +128,7 @@ struct DisplayModuleGraphNodeWithEdges<'a> { out_edges: &'a [GlobalProcedureIndex], } -impl<'a> fmt::Debug for DisplayModuleGraphNodeWithEdges<'a> { +impl fmt::Debug for DisplayModuleGraphNodeWithEdges<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Edge") .field( diff --git a/assembly/src/assembler/module_graph/mod.rs b/assembly/src/assembler/module_graph/mod.rs index d4a2fff6ac..f8116f9985 100644 --- a/assembly/src/assembler/module_graph/mod.rs +++ b/assembly/src/assembler/module_graph/mod.rs @@ -8,7 +8,7 @@ use alloc::{boxed::Box, collections::BTreeMap, sync::Arc, vec::Vec}; use core::ops::Index; use smallvec::{smallvec, SmallVec}; -use vm_core::Kernel; +use vm_core::{crypto::hash::RpoDigest, Kernel}; use self::{analysis::MaybeRewriteCheck, name_resolver::NameResolver, rewrites::ModuleRewriter}; pub use self::{ @@ -22,7 +22,7 @@ use crate::{ ResolvedProcedure, }, library::{ModuleInfo, ProcedureInfo}, - AssemblyError, LibraryNamespace, LibraryPath, RpoDigest, SourceManager, Spanned, + AssemblyError, LibraryNamespace, LibraryPath, SourceManager, Spanned, }; // WRAPPER STRUCTS @@ -39,7 +39,7 @@ pub enum ProcedureWrapper<'a> { Info(&'a ProcedureInfo), } -impl<'a> ProcedureWrapper<'a> { +impl ProcedureWrapper<'_> { /// Returns the name of the procedure. pub fn name(&self) -> &ProcedureName { match self { @@ -160,7 +160,7 @@ pub struct ModuleGraph { callgraph: CallGraph, /// The set of MAST roots which have procedure definitions in this graph. There can be /// multiple procedures bound to the same root due to having identical code. - roots: BTreeMap>, + procedures_by_mast_root: BTreeMap>, kernel_index: Option, kernel: Kernel, source_manager: Arc, @@ -175,7 +175,7 @@ impl ModuleGraph { modules: Default::default(), pending: Default::default(), callgraph: Default::default(), - roots: Default::default(), + procedures_by_mast_root: Default::default(), kernel_index: None, kernel: Default::default(), source_manager, @@ -198,7 +198,7 @@ impl ModuleGraph { for &module_index in module_indices.iter() { for (proc_index, proc) in self[module_index].unwrap_info().clone().procedures() { let gid = module_index + proc_index; - self.register_mast_root(gid, proc.digest)?; + self.register_procedure_root(gid, proc.digest)?; } } @@ -520,11 +520,16 @@ impl ModuleGraph { } } + /// Returns a procedure index which corresponds to the provided procedure digest. + /// + /// Note that there can be many procedures with the same digest - due to having the same code, + /// and/or using different decorators which don't affect the MAST root. This method returns an + /// arbitrary one. pub fn get_procedure_index_by_digest( &self, - digest: &RpoDigest, + procedure_digest: &RpoDigest, ) -> Option { - self.roots.get(digest).map(|indices| indices[0]) + self.procedures_by_mast_root.get(procedure_digest).map(|indices| indices[0]) } /// Resolves `target` from the perspective of `caller`. @@ -537,7 +542,7 @@ impl ModuleGraph { resolver.resolve_target(caller, target) } - /// Registers a [RpoDigest] as corresponding to a given [GlobalProcedureIndex]. + /// Registers a [MastNodeId] as corresponding to a given [GlobalProcedureIndex]. /// /// # SAFETY /// @@ -545,13 +550,13 @@ impl ModuleGraph { /// procedure. It is fine if there are multiple procedures with the same digest, but it _must_ /// be the case that if a given digest is specified, it can be used as if it was the definition /// of the referenced procedure, i.e. they are referentially transparent. - pub(crate) fn register_mast_root( + pub(crate) fn register_procedure_root( &mut self, id: GlobalProcedureIndex, - digest: RpoDigest, + procedure_mast_root: RpoDigest, ) -> Result<(), AssemblyError> { use alloc::collections::btree_map::Entry; - match self.roots.entry(digest) { + match self.procedures_by_mast_root.entry(procedure_mast_root) { Entry::Occupied(ref mut entry) => { let prev_id = entry.get()[0]; if prev_id != id { diff --git a/assembly/src/assembler/procedure.rs b/assembly/src/assembler/procedure.rs index c1a59063b7..167a25806e 100644 --- a/assembly/src/assembler/procedure.rs +++ b/assembly/src/assembler/procedure.rs @@ -1,4 +1,4 @@ -use alloc::{collections::BTreeSet, sync::Arc}; +use alloc::sync::Arc; use vm_core::mast::MastNodeId; @@ -6,11 +6,9 @@ use super::GlobalProcedureIndex; use crate::{ ast::{ProcedureName, QualifiedProcedureName, Visibility}, diagnostics::{SourceManager, SourceSpan, Spanned}, - AssemblyError, LibraryPath, RpoDigest, + LibraryPath, RpoDigest, }; -pub type CallSet = BTreeSet; - // PROCEDURE CONTEXT // ================================================================================================ @@ -23,7 +21,6 @@ pub struct ProcedureContext { visibility: Visibility, is_kernel: bool, num_locals: u16, - callset: CallSet, } // ------------------------------------------------------------------------------------------------ @@ -44,7 +41,6 @@ impl ProcedureContext { visibility, is_kernel, num_locals: 0, - callset: Default::default(), } } @@ -93,43 +89,13 @@ impl ProcedureContext { // ------------------------------------------------------------------------------------------------ /// State mutators impl ProcedureContext { - pub fn insert_callee(&mut self, callee: RpoDigest) { - self.callset.insert(callee); - } - - pub fn extend_callset(&mut self, callees: I) - where - I: IntoIterator, - { - self.callset.extend(callees); - } - - /// Registers a call to an externally-defined procedure which we have previously compiled. - /// - /// The call set of the callee is added to the call set of the procedure we are currently - /// compiling, to reflect that all of the code reachable from the callee is by extension - /// reachable by the caller. - pub fn register_external_call( - &mut self, - callee: &Procedure, - inlined: bool, - ) -> Result<(), AssemblyError> { - // If we call the callee, it's callset is by extension part of our callset - self.extend_callset(callee.callset().iter().cloned()); - - // If the callee is not being inlined, add it to our callset - if !inlined { - self.insert_callee(callee.mast_root()); - } - - Ok(()) - } - /// Transforms this procedure context into a [Procedure]. /// /// The passed-in `mast_root` defines the MAST root of the procedure's body while /// `mast_node_id` specifies the ID of the procedure's body node in the MAST forest in - /// which the procedure is defined. + /// which the procedure is defined. Note that if the procedure is re-exported (i.e., the body + /// of the procedure is defined in some other MAST forest) `mast_node_id` will point to a + /// single `External` node. /// ///
/// `mast_root` and `mast_node_id` must be consistent. That is, the node located in the MAST @@ -138,7 +104,6 @@ impl ProcedureContext { pub fn into_procedure(self, mast_root: RpoDigest, mast_node_id: MastNodeId) -> Procedure { Procedure::new(self.name, self.visibility, self.num_locals as u32, mast_root, mast_node_id) .with_span(self.span) - .with_callset(self.callset) } } @@ -170,8 +135,6 @@ pub struct Procedure { mast_root: RpoDigest, /// The MAST node id which resolves to the above MAST root. body_node_id: MastNodeId, - /// The set of MAST roots called by this procedure - callset: CallSet, } // ------------------------------------------------------------------------------------------------ @@ -191,7 +154,6 @@ impl Procedure { num_locals, mast_root, body_node_id, - callset: Default::default(), } } @@ -199,11 +161,6 @@ impl Procedure { self.span = span; self } - - pub(crate) fn with_callset(mut self, callset: CallSet) -> Self { - self.callset = callset; - self - } } // ------------------------------------------------------------------------------------------------ @@ -244,12 +201,6 @@ impl Procedure { pub fn body_node_id(&self) -> MastNodeId { self.body_node_id } - - /// Returns a reference to a set of all procedures (identified by their MAST roots) which may - /// be called during the execution of this procedure. - pub fn callset(&self) -> &CallSet { - &self.callset - } } impl Spanned for Procedure { diff --git a/assembly/src/assembler/tests.rs b/assembly/src/assembler/tests.rs index 0cd220b0bc..cb45e88ca8 100644 --- a/assembly/src/assembler/tests.rs +++ b/assembly/src/assembler/tests.rs @@ -1,5 +1,12 @@ +use alloc::vec::Vec; + use pretty_assertions::assert_eq; -use vm_core::{assert_matches, mast::MastForest, Program}; +use vm_core::{ + assert_matches, + crypto::hash::RpoDigest, + mast::{MastForest, MastNode}, + Program, +}; use super::{Assembler, Operation}; use crate::{ @@ -142,7 +149,9 @@ fn nested_blocks() -> Result<(), Report> { .join_nodes(vec![before, r#if1, nested, exec_foo_bar_baz_node_id, syscall_foo_node_id]) .unwrap(); - let expected_program = Program::new(expected_mast_forest_builder.build().0, combined_node_id); + let mut expected_mast_forest = expected_mast_forest_builder.build().0; + expected_mast_forest.make_root(combined_node_id); + let expected_program = Program::new(expected_mast_forest.into(), combined_node_id); assert_eq!(expected_program.hash(), program.hash()); // also check that the program has the right number of procedures (which excludes the dummy @@ -152,8 +161,41 @@ fn nested_blocks() -> Result<(), Report> { Ok(()) } -/// Ensures that a single copy of procedures with the same MAST root are added only once to the MAST -/// forest. +/// Ensures that the arguments of `emit` do indeed modify the digest of a basic block +#[test] +fn emit_instruction_digest() { + let context = TestContext::new(); + + let program_source = r#" + proc.foo + emit.1 + end + + proc.bar + emit.2 + end + + begin + # specific impl irrelevant + exec.foo + exec.bar + end + "#; + + let program = context.assemble(program_source).unwrap(); + + let procedure_digests: Vec = program.mast_forest().procedure_digests().collect(); + + // foo, bar and entrypoint + assert_eq!(3, procedure_digests.len()); + + // Ensure that foo, bar and entrypoint all have different digests + assert_ne!(procedure_digests[0], procedure_digests[1]); + assert_ne!(procedure_digests[0], procedure_digests[2]); + assert_ne!(procedure_digests[1], procedure_digests[2]); +} + +/// Since `foo` and `bar` have the same body, we only expect them to be added once to the program. #[test] fn duplicate_procedure() { let context = TestContext::new(); @@ -180,6 +222,38 @@ fn duplicate_procedure() { assert_eq!(program.num_procedures(), 2); } +#[test] +fn distinguish_grandchildren_correctly() { + let context = TestContext::new(); + + let program_source = r#" + begin + if.true + while.true + trace.1234 + push.1 + end + end + + if.true + while.true + push.1 + end + end + end + "#; + + let program = context.assemble(program_source).unwrap(); + + let join_node = match &program.mast_forest()[program.entrypoint()] { + MastNode::Join(node) => node, + _ => panic!("expected join node"), + }; + + // Make sure that both `if.true` blocks compile down to a different MAST node. + assert_ne!(join_node.first(), join_node.second()); +} + /// Ensures that equal MAST nodes don't get added twice to a MAST forest #[test] fn duplicate_nodes() { @@ -199,10 +273,8 @@ fn duplicate_nodes() { let mut expected_mast_forest = MastForest::new(); - // basic block: mul let mul_basic_block_id = expected_mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - // basic block: add let add_basic_block_id = expected_mast_forest.add_block(vec![Operation::Add], None).unwrap(); // inner split: `if.true add else mul end` @@ -214,9 +286,9 @@ fn duplicate_nodes() { expected_mast_forest.make_root(root_id); - let expected_program = Program::new(expected_mast_forest, root_id); + let expected_program = Program::new(expected_mast_forest.into(), root_id); - assert_eq!(program, expected_program); + assert_eq!(expected_program, program); } #[test] diff --git a/assembly/src/ast/attribute/meta.rs b/assembly/src/ast/attribute/meta.rs new file mode 100644 index 0000000000..29973cf752 --- /dev/null +++ b/assembly/src/ast/attribute/meta.rs @@ -0,0 +1,253 @@ +mod expr; +mod kv; +mod list; + +use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; +use core::fmt; + +pub use self::{expr::MetaExpr, kv::MetaKeyValue, list::MetaList}; +use crate::{ast::Ident, parser::HexEncodedValue, Felt, SourceSpan, Span}; + +/// Represents the metadata provided as arguments to an attribute. +#[derive(Clone, PartialEq, Eq)] +pub enum Meta { + /// Represents empty metadata, e.g. `@foo` + Unit, + /// A list of metadata expressions, e.g. `@foo(a, "some text", 0x01)` + /// + /// The list should always have at least one element, and this is guaranteed by the parser. + List(Vec), + /// A set of uniquely-named metadata expressions, e.g. `@foo(letter = a, text = "some text")` + /// + /// The set should always have at least one key-value pair, and this is guaranteed by the + /// parser. + KeyValue(BTreeMap), +} +impl Meta { + /// Borrow the metadata without unwrapping the specific type + /// + /// Returns `None` if there is no meaningful metadata + #[inline] + pub fn borrow(&self) -> Option> { + match self { + Self::Unit => None, + Self::List(ref list) => Some(BorrowedMeta::List(list)), + Self::KeyValue(ref kv) => Some(BorrowedMeta::KeyValue(kv)), + } + } +} +impl FromIterator for Meta { + #[inline] + fn from_iter>(iter: T) -> Self { + let mut iter = iter.into_iter(); + match iter.next() { + None => Self::Unit, + Some(MetaItem::Expr(expr)) => Self::List( + core::iter::once(expr) + .chain(iter.map(|item| match item { + MetaItem::Expr(expr) => expr, + MetaItem::KeyValue(..) => unsafe { core::hint::unreachable_unchecked() }, + })) + .collect(), + ), + Some(MetaItem::KeyValue(k, v)) => Self::KeyValue( + core::iter::once((k, v)) + .chain(iter.map(|item| match item { + MetaItem::KeyValue(k, v) => (k, v), + MetaItem::Expr(_) => unsafe { core::hint::unreachable_unchecked() }, + })) + .collect(), + ), + } + } +} + +impl FromIterator for Meta { + #[inline] + fn from_iter>(iter: T) -> Self { + Self::List(iter.into_iter().collect()) + } +} + +impl FromIterator<(Ident, MetaExpr)> for Meta { + #[inline] + fn from_iter>(iter: T) -> Self { + Self::KeyValue(iter.into_iter().collect()) + } +} + +impl<'a> FromIterator<(&'a str, MetaExpr)> for Meta { + #[inline] + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self::KeyValue( + iter.into_iter() + .map(|(k, v)| { + let k = Ident::new_unchecked(Span::new(SourceSpan::UNKNOWN, Arc::from(k))); + (k, v) + }) + .collect(), + ) + } +} + +impl From for Meta +where + Meta: FromIterator, + I: IntoIterator, +{ + #[inline] + fn from(iter: I) -> Self { + Self::from_iter(iter) + } +} + +/// Represents a reference to the metadata for an [super::Attribute] +/// +/// See [Meta] for what metadata is represented, and its syntax. +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum BorrowedMeta<'a> { + /// A list of metadata expressions + List(&'a [MetaExpr]), + /// A list of uniquely-named metadata expressions + KeyValue(&'a BTreeMap), +} +impl fmt::Debug for BorrowedMeta<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::List(items) => write!(f, "{items:#?}"), + Self::KeyValue(items) => write!(f, "{items:#?}"), + } + } +} + +/// Represents a single metadata item provided as an argument to an attribute. +/// +/// For example, the `foo` attribute in `@foo(bar, baz)` has two metadata items, both of `Expr` +/// type, which compose a ` +#[derive(Clone, PartialEq, Eq)] +pub enum MetaItem { + /// A metadata expression, e.g. `"some text"` in `@foo("some text")` + /// + /// This represents the element type for `Meta::List`-based attributes. + Expr(MetaExpr), + /// A named metadata expression, e.g. `letter = a` in `@foo(letter = a)` + /// + /// This represents the element type for `Meta::KeyValue`-based attributes. + KeyValue(Ident, MetaExpr), +} + +impl MetaItem { + /// Unwrap this item to extract the contained [MetaExpr]. + /// + /// Panics if this item is not the `Expr` variant. + #[inline] + #[track_caller] + pub fn unwrap_expr(self) -> MetaExpr { + match self { + Self::Expr(expr) => expr, + Self::KeyValue(..) => unreachable!("tried to unwrap key-value as expression"), + } + } + + /// Unwrap this item to extract the contained key-value pair. + /// + /// Panics if this item is not the `KeyValue` variant. + #[inline] + #[track_caller] + pub fn unwrap_key_value(self) -> (Ident, MetaExpr) { + match self { + Self::KeyValue(k, v) => (k, v), + Self::Expr(_) => unreachable!("tried to unwrap expression as key-value"), + } + } +} + +impl From for MetaItem { + fn from(value: Ident) -> Self { + Self::Expr(MetaExpr::Ident(value)) + } +} + +impl From<&str> for MetaItem { + fn from(value: &str) -> Self { + Self::Expr(MetaExpr::String(Ident::new_unchecked(Span::new( + SourceSpan::UNKNOWN, + Arc::from(value), + )))) + } +} + +impl From for MetaItem { + fn from(value: String) -> Self { + Self::Expr(MetaExpr::String(Ident::new_unchecked(Span::new( + SourceSpan::UNKNOWN, + Arc::from(value.into_boxed_str()), + )))) + } +} + +impl From for MetaItem { + fn from(value: u8) -> Self { + Self::Expr(MetaExpr::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::U8(value)))) + } +} + +impl From for MetaItem { + fn from(value: u16) -> Self { + Self::Expr(MetaExpr::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::U16(value)))) + } +} + +impl From for MetaItem { + fn from(value: u32) -> Self { + Self::Expr(MetaExpr::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::U32(value)))) + } +} + +impl From for MetaItem { + fn from(value: Felt) -> Self { + Self::Expr(MetaExpr::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::Felt(value)))) + } +} + +impl From<[Felt; 4]> for MetaItem { + fn from(value: [Felt; 4]) -> Self { + Self::Expr(MetaExpr::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::Word(value)))) + } +} + +impl From<(Ident, V)> for MetaItem +where + V: Into, +{ + fn from(entry: (Ident, V)) -> Self { + let (key, value) = entry; + Self::KeyValue(key, value.into()) + } +} + +impl From<(&str, V)> for MetaItem +where + V: Into, +{ + fn from(entry: (&str, V)) -> Self { + let (key, value) = entry; + let key = Ident::new_unchecked(Span::new(SourceSpan::UNKNOWN, Arc::from(key))); + Self::KeyValue(key, value.into()) + } +} + +impl From<(String, V)> for MetaItem +where + V: Into, +{ + fn from(entry: (String, V)) -> Self { + let (key, value) = entry; + let key = + Ident::new_unchecked(Span::new(SourceSpan::UNKNOWN, Arc::from(key.into_boxed_str()))); + Self::KeyValue(key, value.into()) + } +} diff --git a/assembly/src/ast/attribute/meta/expr.rs b/assembly/src/ast/attribute/meta/expr.rs new file mode 100644 index 0000000000..8c92899b25 --- /dev/null +++ b/assembly/src/ast/attribute/meta/expr.rs @@ -0,0 +1,86 @@ +use alloc::{string::String, sync::Arc}; + +use crate::{ast::Ident, parser::HexEncodedValue, prettier, Felt, SourceSpan, Span, Spanned}; + +/// Represents a metadata expression of an [crate::ast::Attribute] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MetaExpr { + /// An identifier/keyword, e.g. `inline` + Ident(Ident), + /// A decimal or hexadecimal integer value + Int(Span), + /// A quoted string or identifier + String(Ident), +} + +impl prettier::PrettyPrint for MetaExpr { + fn render(&self) -> prettier::Document { + use prettier::*; + + match self { + Self::Ident(id) => text(id), + Self::Int(value) => text(value), + Self::String(id) => text(format!("\"{}\"", id.as_str().escape_default())), + } + } +} + +impl From for MetaExpr { + fn from(value: Ident) -> Self { + Self::Ident(value) + } +} + +impl From<&str> for MetaExpr { + fn from(value: &str) -> Self { + Self::String(Ident::new_unchecked(Span::new(SourceSpan::UNKNOWN, Arc::from(value)))) + } +} + +impl From for MetaExpr { + fn from(value: String) -> Self { + Self::String(Ident::new_unchecked(Span::new( + SourceSpan::UNKNOWN, + Arc::from(value.into_boxed_str()), + ))) + } +} + +impl From for MetaExpr { + fn from(value: u8) -> Self { + Self::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::U8(value))) + } +} + +impl From for MetaExpr { + fn from(value: u16) -> Self { + Self::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::U16(value))) + } +} + +impl From for MetaExpr { + fn from(value: u32) -> Self { + Self::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::U32(value))) + } +} + +impl From for MetaExpr { + fn from(value: Felt) -> Self { + Self::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::Felt(value))) + } +} + +impl From<[Felt; 4]> for MetaExpr { + fn from(value: [Felt; 4]) -> Self { + Self::Int(Span::new(SourceSpan::UNKNOWN, HexEncodedValue::Word(value))) + } +} + +impl Spanned for MetaExpr { + fn span(&self) -> SourceSpan { + match self { + Self::Ident(spanned) | Self::String(spanned) => spanned.span(), + Self::Int(spanned) => spanned.span(), + } + } +} diff --git a/assembly/src/ast/attribute/meta/kv.rs b/assembly/src/ast/attribute/meta/kv.rs new file mode 100644 index 0000000000..69ed38936b --- /dev/null +++ b/assembly/src/ast/attribute/meta/kv.rs @@ -0,0 +1,135 @@ +use alloc::collections::BTreeMap; +use core::borrow::Borrow; + +use super::MetaExpr; +use crate::{ast::Ident, SourceSpan, Spanned}; + +/// Represents the metadata of a key-value [crate::ast::Attribute], i.e. `@props(key = value)` +#[derive(Clone)] +pub struct MetaKeyValue { + pub span: SourceSpan, + /// The name of the key-value dictionary + pub name: Ident, + /// The set of key-value pairs provided as arguments to this attribute + pub items: BTreeMap, +} + +impl Spanned for MetaKeyValue { + #[inline(always)] + fn span(&self) -> SourceSpan { + self.span + } +} + +impl MetaKeyValue { + pub fn new(name: Ident, items: I) -> Self + where + I: IntoIterator, + K: Into, + V: Into, + { + let items = items.into_iter().map(|(k, v)| (k.into(), v.into())).collect(); + Self { span: SourceSpan::default(), name, items } + } + + pub fn with_span(mut self, span: SourceSpan) -> Self { + self.span = span; + self + } + + /// Get the name of this metadata as a string + #[inline] + pub fn name(&self) -> &str { + self.name.as_str() + } + + /// Get the name of this metadata as an [Ident] + #[inline] + pub fn id(&self) -> Ident { + self.name.clone() + } + + /// Returns true if this metadata contains an entry for `key` + pub fn contains_key(&self, key: &Q) -> bool + where + Ident: Borrow + Ord, + Q: ?Sized + Ord, + { + self.items.contains_key(key) + } + + /// Returns the value associated with `key`, if present in this metadata + pub fn get(&self, key: &Q) -> Option<&MetaExpr> + where + Ident: Borrow + Ord, + Q: ?Sized + Ord, + { + self.items.get(key) + } + + /// Inserts a new key-value entry in this metadata + pub fn insert(&mut self, key: impl Into, value: impl Into) { + self.items.insert(key.into(), value.into()); + } + + /// Removes the entry associated with `key`, if present in this metadata, and returns it + pub fn remove(&mut self, key: &Q) -> Option + where + Ident: Borrow + Ord, + Q: ?Sized + Ord, + { + self.items.remove(key) + } + + /// Get an entry in the key-value map of this metadata for `key` + pub fn entry( + &mut self, + key: Ident, + ) -> alloc::collections::btree_map::Entry<'_, Ident, MetaExpr> { + self.items.entry(key) + } + + /// Get an iterator over the the key-value items of this metadata + #[inline] + pub fn iter(&self) -> impl Iterator { + self.items.iter() + } +} + +impl IntoIterator for MetaKeyValue { + type Item = (Ident, MetaExpr); + type IntoIter = alloc::collections::btree_map::IntoIter; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} + +impl Eq for MetaKeyValue {} + +impl PartialEq for MetaKeyValue { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.items == other.items + } +} + +impl PartialOrd for MetaKeyValue { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for MetaKeyValue { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.name.cmp(&other.name).then_with(|| self.items.cmp(&other.items)) + } +} + +impl core::hash::Hash for MetaKeyValue { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.items.hash(state); + } +} diff --git a/assembly/src/ast/attribute/meta/list.rs b/assembly/src/ast/attribute/meta/list.rs new file mode 100644 index 0000000000..99cb6e61ed --- /dev/null +++ b/assembly/src/ast/attribute/meta/list.rs @@ -0,0 +1,102 @@ +use alloc::vec::Vec; + +use super::MetaExpr; +use crate::{ast::Ident, SourceSpan, Spanned}; + +/// Represents the metadata of a named list [crate::ast::Attribute], i.e. `@name(item0, .., itemN)` +#[derive(Clone)] +pub struct MetaList { + pub span: SourceSpan, + /// The identifier used as the name of this attribute + pub name: Ident, + /// The list of items representing the value of this attribute - will always contain at least + /// one element when parsed. + pub items: Vec, +} + +impl Spanned for MetaList { + #[inline(always)] + fn span(&self) -> SourceSpan { + self.span + } +} + +impl MetaList { + pub fn new(name: Ident, items: I) -> Self + where + I: IntoIterator, + { + Self { + span: SourceSpan::default(), + name, + items: items.into_iter().collect(), + } + } + + pub fn with_span(mut self, span: SourceSpan) -> Self { + self.span = span; + self + } + + /// Get the name of this attribute as a string + pub fn name(&self) -> &str { + self.name.as_str() + } + + /// Get the name of this attribute as an [Ident] + pub fn id(&self) -> Ident { + self.name.clone() + } + + /// Returns true if the metadata list is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Returns the number of items in the metadata list + #[inline] + pub fn len(&self) -> usize { + self.items.len() + } + + /// Get the metadata list as a slice + #[inline] + pub fn as_slice(&self) -> &[MetaExpr] { + self.items.as_slice() + } + + /// Get the metadata list as a mutable slice + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [MetaExpr] { + self.items.as_mut_slice() + } +} + +impl Eq for MetaList {} + +impl PartialEq for MetaList { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.items == other.items + } +} + +impl PartialOrd for MetaList { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for MetaList { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.name.cmp(&other.name).then_with(|| self.items.cmp(&other.items)) + } +} + +impl core::hash::Hash for MetaList { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.items.hash(state); + } +} diff --git a/assembly/src/ast/attribute/mod.rs b/assembly/src/ast/attribute/mod.rs new file mode 100644 index 0000000000..3c7937ba5c --- /dev/null +++ b/assembly/src/ast/attribute/mod.rs @@ -0,0 +1,246 @@ +mod meta; +mod set; + +use core::fmt; + +pub use self::{ + meta::{BorrowedMeta, Meta, MetaExpr, MetaItem, MetaKeyValue, MetaList}, + set::{AttributeSet, AttributeSetEntry}, +}; +use crate::{ast::Ident, prettier, SourceSpan, Spanned}; + +/// An [Attribute] represents some named metadata attached to a Miden Assembly procedure. +/// +/// An attribute has no predefined structure per se, but syntactically there are three types: +/// +/// * Marker attributes, i.e. just a name and no associated data. Attributes of this type are used +/// to "mark" the item they are attached to with some unique trait or behavior implied by the +/// name. For example, `@inline`. NOTE: `@inline()` is not valid syntax. +/// +/// * List attributes, i.e. a name and one or more comma-delimited expressions. Attributes of this +/// type are used for cases where you want to parameterize a marker-like trait. To use a Rust +/// example, `#[derive(Trait)]` is a list attribute, where `derive` is the marker, but we want to +/// instruct whatever processes derives, what traits it needs to derive. The equivalent syntax in +/// Miden Assembly would be `@derive(Trait)`. Lists must always have at least one item. +/// +/// * Key-value attributes, i.e. a name and a value. Attributes of this type are used to attach +/// named properties to an item. For example, `@storage(offset = 1)`. Possible value types are: +/// bare identifiers, decimal or hexadecimal integers, and quoted strings. +/// +/// There are no restrictions on what attributes can exist or be used. However, there are a set of +/// attributes that the assembler knows about, and acts on, which will be stripped during assembly. +/// Any remaining attributes we don't explicitly handle in the assembler, will be passed along as +/// metadata attached to the procedures in the MAST output by the assembler. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Attribute { + /// A named behavior, trait or action; e.g. `@inline` + Marker(Ident), + /// A parameterized behavior, trait or action; e.g. `@inline(always)` or `@derive(Foo, Bar)` + List(MetaList), + /// A named property; e.g. `@props(key = "value")`, `@props(a = 1, b = 0x1)` + KeyValue(MetaKeyValue), +} + +impl Attribute { + /// Create a new [Attribute] with the given metadata. + /// + /// The metadata value must be convertible to [Meta]. + /// + /// For marker attributes, you can either construct the `Marker` variant directly, or pass + /// either `Meta::Unit` or `None` as the metadata argument. + /// + /// If the metadata is empty, a `Marker` attribute will be produced, otherwise the type depends + /// on the metadata. If the metadata is _not_ key-value shaped, a `List` is produced, otherwise + /// a `KeyValue`. + pub fn new(name: Ident, metadata: impl Into) -> Self { + let metadata = metadata.into(); + match metadata { + Meta::Unit => Self::Marker(name), + Meta::List(items) => Self::List(MetaList { span: Default::default(), name, items }), + Meta::KeyValue(items) => { + Self::KeyValue(MetaKeyValue { span: Default::default(), name, items }) + }, + } + } + + /// Create a new [Attribute] from an metadata-producing iterator. + /// + /// If the iterator is empty, a `Marker` attribute will be produced, otherwise the type depends + /// on the metadata. If the metadata is _not_ key-value shaped, a `List` is produced, otherwise + /// a `KeyValue`. + pub fn from_iter(name: Ident, metadata: I) -> Self + where + Meta: FromIterator, + I: IntoIterator, + { + Self::new(name, Meta::from_iter(metadata)) + } + + /// Set the source location for this attribute + pub fn with_span(self, span: SourceSpan) -> Self { + match self { + Self::Marker(id) => Self::Marker(id.with_span(span)), + Self::List(list) => Self::List(list.with_span(span)), + Self::KeyValue(kv) => Self::KeyValue(kv.with_span(span)), + } + } + + /// Get the name of this attribute as a string + pub fn name(&self) -> &str { + match self { + Self::Marker(id) => id.as_str(), + Self::List(list) => list.name(), + Self::KeyValue(kv) => kv.name(), + } + } + + /// Get the name of this attribute as an [Ident] + pub fn id(&self) -> Ident { + match self { + Self::Marker(id) => id.clone(), + Self::List(list) => list.id(), + Self::KeyValue(kv) => kv.id(), + } + } + + /// Returns true if this is a marker attribute + pub fn is_marker(&self) -> bool { + matches!(self, Self::Marker(_)) + } + + /// Returns true if this is a list attribute + pub fn is_list(&self) -> bool { + matches!(self, Self::List(_)) + } + + /// Returns true if this is a key-value attribute + pub fn is_key_value(&self) -> bool { + matches!(self, Self::KeyValue(_)) + } + + /// Get the metadata for this attribute + /// + /// Returns `None` if this is a marker attribute, and thus has no metadata + pub fn metadata(&self) -> Option> { + match self { + Self::Marker(_) => None, + Self::List(ref list) => Some(BorrowedMeta::List(&list.items)), + Self::KeyValue(ref kv) => Some(BorrowedMeta::KeyValue(&kv.items)), + } + } +} + +impl fmt::Debug for Attribute { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Marker(id) => f.debug_tuple("Marker").field(&id).finish(), + Self::List(meta) => f + .debug_struct("List") + .field("name", &meta.name) + .field("items", &meta.items) + .finish(), + Self::KeyValue(meta) => f + .debug_struct("KeyValue") + .field("name", &meta.name) + .field("items", &meta.items) + .finish(), + } + } +} + +impl fmt::Display for Attribute { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use prettier::PrettyPrint; + self.pretty_print(f) + } +} + +impl prettier::PrettyPrint for Attribute { + fn render(&self) -> prettier::Document { + use prettier::*; + let doc = text(format!("@{}", &self.name())); + match self { + Self::Marker(_) => doc, + Self::List(meta) => { + let singleline_items = meta + .items + .iter() + .map(|item| item.render()) + .reduce(|acc, item| acc + const_text(", ") + item) + .unwrap_or(Document::Empty); + let multiline_items = indent( + 4, + nl() + meta + .items + .iter() + .map(|item| item.render()) + .reduce(|acc, item| acc + nl() + item) + .unwrap_or(Document::Empty), + ) + nl(); + doc + const_text("(") + (singleline_items | multiline_items) + const_text(")") + }, + Self::KeyValue(meta) => { + let singleline_items = meta + .items + .iter() + .map(|(k, v)| text(k) + const_text(" = ") + v.render()) + .reduce(|acc, item| acc + const_text(", ") + item) + .unwrap_or(Document::Empty); + let multiline_items = indent( + 4, + nl() + meta + .items + .iter() + .map(|(k, v)| text(k) + const_text(" = ") + v.render()) + .reduce(|acc, item| acc + nl() + item) + .unwrap_or(Document::Empty), + ) + nl(); + doc + const_text("(") + (singleline_items | multiline_items) + const_text(")") + }, + } + } +} + +impl Spanned for Attribute { + fn span(&self) -> SourceSpan { + match self { + Self::Marker(id) => id.span(), + Self::List(list) => list.span(), + Self::KeyValue(kv) => kv.span(), + } + } +} + +impl From for Attribute { + fn from(value: Ident) -> Self { + Self::Marker(value) + } +} + +impl From<(K, V)> for Attribute +where + K: Into, + V: Into, +{ + fn from(kv: (K, V)) -> Self { + let (key, value) = kv; + Self::List(MetaList { + span: SourceSpan::default(), + name: key.into(), + items: vec![value.into()], + }) + } +} + +impl From for Attribute { + fn from(value: MetaList) -> Self { + Self::List(value) + } +} + +impl From for Attribute { + fn from(value: MetaKeyValue) -> Self { + Self::KeyValue(value) + } +} diff --git a/assembly/src/ast/attribute/set.rs b/assembly/src/ast/attribute/set.rs new file mode 100644 index 0000000000..f215f96db0 --- /dev/null +++ b/assembly/src/ast/attribute/set.rs @@ -0,0 +1,239 @@ +use alloc::vec::Vec; +use core::fmt; + +use super::*; +use crate::ast::Ident; + +/// An [AttributeSet] provides storage and access to all of the attributes attached to a Miden +/// Assembly item, e.g. procedure definition. +/// +/// Attributes are uniqued by name, so if you attempt to add multiple attributes with the same name, +/// the last write wins. In Miden Assembly syntax, multiple key-value attributes are merged +/// automatically, and a syntax error is only generated when keys conflict. All other attribute +/// types produce an error if they are declared multiple times on the same item. +#[derive(Default, Clone, PartialEq, Eq)] +pub struct AttributeSet { + /// The attributes in this set. + /// + /// The [AttributeSet] structure has map-like semantics, so why are we using a vector here? + /// + /// * We expect attributes to be relatively rare, with no more than a handful on the same item + /// at any given time. + /// * A vector is much more space and time efficient to search for small numbers of items + /// * We can acheive map-like semantics without O(N) complexity by keeping the vector sorted by + /// the attribute name, and using binary search to search it. This gives us O(1) best-case + /// performance, and O(log N) in the worst case. + attrs: Vec, +} + +impl AttributeSet { + /// Create a new [AttributeSet] from `attrs` + /// + /// If the input attributes have duplicate entries for the same name, only one will be selected, + /// but it is unspecified which. + pub fn new(attrs: I) -> Self + where + I: IntoIterator, + { + let mut this = Self { attrs: attrs.into_iter().collect() }; + this.attrs.sort_by_key(|attr| attr.id()); + this.attrs.dedup_by_key(|attr| attr.id()); + this + } + + /// Returns true if there are no attributes in this set + #[inline] + pub fn is_empty(&self) -> bool { + self.attrs.is_empty() + } + + /// Returns the number of attributes in this set + #[inline] + pub fn len(&self) -> usize { + self.attrs.len() + } + + /// Check if this set has an attributed named `name` + pub fn has(&self, name: impl AsRef) -> bool { + self.get(name).is_some() + } + + /// Get the attribute named `name`, if one is present. + pub fn get(&self, name: impl AsRef) -> Option<&Attribute> { + let name = name.as_ref(); + match self.attrs.binary_search_by_key(&name, |attr| attr.name()) { + Ok(index) => self.attrs.get(index), + Err(_) => None, + } + } + + /// Get a mutable reference to the attribute named `name`, if one is present. + pub fn get_mut(&mut self, name: impl AsRef) -> Option<&mut Attribute> { + let name = name.as_ref(); + match self.attrs.binary_search_by_key(&name, |attr| attr.name()) { + Ok(index) => self.attrs.get_mut(index), + Err(_) => None, + } + } + + /// Get an iterator over the attributes in this set + #[inline] + pub fn iter(&self) -> core::slice::Iter<'_, Attribute> { + self.attrs.iter() + } + + /// Get a mutable iterator over the attributes in this set + #[inline] + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, Attribute> { + self.attrs.iter_mut() + } + + /// Insert `attr` in the attribute set, replacing any existing attribute with the same name + /// + /// Returns true if the insertion was new, or false if the insertion replaced an existing entry. + pub fn insert(&mut self, attr: Attribute) -> bool { + let name = attr.name(); + match self.attrs.binary_search_by_key(&name, |attr| attr.name()) { + Ok(index) => { + // Replace existing attribute + self.attrs[index] = attr; + false + }, + Err(index) => { + self.attrs.insert(index, attr); + true + }, + } + } + + /// Insert `attr` in the attribute set, but only if there is no existing attribute with the same + /// name. + /// + /// Returns `Err` with `attr` if there is already an existing attribute with the same name. + pub fn insert_new(&mut self, attr: Attribute) -> Result<(), Attribute> { + if self.has(attr.name()) { + Err(attr) + } else { + self.insert(attr); + Ok(()) + } + } + + /// Removes the attribute named `name`, if present. + pub fn remove(&mut self, name: impl AsRef) -> Option { + let name = name.as_ref(); + match self.attrs.binary_search_by_key(&name, |attr| attr.name()) { + Ok(index) => Some(self.attrs.remove(index)), + Err(_) => None, + } + } + + /// Gets the given key's corresponding entry in the set for in-place modfication + pub fn entry(&mut self, key: Ident) -> AttributeSetEntry<'_> { + match self.attrs.binary_search_by_key(&key.as_str(), |attr| attr.name()) { + Ok(index) => AttributeSetEntry::occupied(self, index), + Err(index) => AttributeSetEntry::vacant(self, key, index), + } + } + + /// Clear all attributes from the set + #[inline] + pub fn clear(&mut self) { + self.attrs.clear(); + } +} + +impl fmt::Debug for AttributeSet { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = f.debug_map(); + for attr in self.iter() { + match attr.metadata() { + None => { + builder.entry(&attr.name(), &"None"); + }, + Some(meta) => { + builder.entry(&attr.name(), &meta); + }, + } + } + builder.finish() + } +} + +impl FromIterator for AttributeSet { + #[inline] + fn from_iter>(iter: T) -> Self { + Self::new(iter) + } +} + +impl Extend for AttributeSet { + fn extend>(&mut self, iter: T) { + for attr in iter { + self.insert(attr); + } + } +} + +/// Represents an entry under a specific key in a [AttributeSet] +pub enum AttributeSetEntry<'a> { + /// The entry is currently occupied with a value + Occupied(AttributeSetOccupiedEntry<'a>), + /// The entry is currently vacant + Vacant(AttributeSetVacantEntry<'a>), +} +impl<'a> AttributeSetEntry<'a> { + fn occupied(set: &'a mut AttributeSet, index: usize) -> Self { + Self::Occupied(AttributeSetOccupiedEntry { set, index }) + } + + fn vacant(set: &'a mut AttributeSet, key: Ident, index: usize) -> Self { + Self::Vacant(AttributeSetVacantEntry { set, key, index }) + } +} + +#[doc(hidden)] +pub struct AttributeSetOccupiedEntry<'a> { + set: &'a mut AttributeSet, + index: usize, +} +impl AttributeSetOccupiedEntry<'_> { + #[inline] + pub fn get(&self) -> &Attribute { + &self.set.attrs[self.index] + } + + #[inline] + pub fn get_mut(&mut self) -> &mut Attribute { + &mut self.set.attrs[self.index] + } + + pub fn insert(self, attr: Attribute) { + if attr.name() != self.get().name() { + self.set.insert(attr); + } else { + self.set.attrs[self.index] = attr; + } + } + + #[inline] + pub fn remove(self) -> Attribute { + self.set.attrs.remove(self.index) + } +} + +#[doc(hidden)] +pub struct AttributeSetVacantEntry<'a> { + set: &'a mut AttributeSet, + key: Ident, + index: usize, +} +impl AttributeSetVacantEntry<'_> { + pub fn insert(self, attr: Attribute) { + if self.key != attr.id() { + self.set.insert(attr); + } else { + self.set.attrs.insert(self.index, attr); + } + } +} diff --git a/assembly/src/ast/block.rs b/assembly/src/ast/block.rs index d263a97e97..6f880ef91b 100644 --- a/assembly/src/ast/block.rs +++ b/assembly/src/ast/block.rs @@ -70,7 +70,7 @@ impl crate::prettier::PrettyPrint for Block { .map(PrettyPrint::render) .reduce(|acc, doc| acc + nl() + doc); - body.map(|body| indent(4, body)).unwrap_or(Document::Empty) + body.map(|body| indent(4, nl() + body)).unwrap_or(Document::Empty) } } diff --git a/assembly/src/ast/ident.rs b/assembly/src/ast/ident.rs index 0f7e594c8a..ea2d13a238 100644 --- a/assembly/src/ast/ident.rs +++ b/assembly/src/ast/ident.rs @@ -12,8 +12,8 @@ use crate::{SourceSpan, Span, Spanned}; pub enum IdentError { #[error("invalid identifier: cannot be empty")] Empty, - #[error("invalid identifier: must contain only lowercase, ascii alphanumeric characters, or underscores")] - InvalidChars, + #[error("invalid identifier '{ident}': must contain only lowercase, ascii alphanumeric characters, or underscores")] + InvalidChars { ident: Arc }, #[error("invalid identifier: must start with lowercase ascii alphabetic character")] InvalidStart, #[error("invalid identifier: length exceeds the maximum of {max} bytes")] @@ -109,7 +109,7 @@ impl Ident { return Err(IdentError::InvalidStart); } if !source.chars().all(|c| c.is_ascii_alphabetic() || matches!(c, '_' | '0'..='9')) { - return Err(IdentError::InvalidChars); + return Err(IdentError::InvalidChars { ident: source.into() }); } Ok(()) } diff --git a/assembly/src/ast/instruction/print.rs b/assembly/src/ast/instruction/print.rs index a76239f70e..b7ca11c37c 100644 --- a/assembly/src/ast/instruction/print.rs +++ b/assembly/src/ast/instruction/print.rs @@ -427,7 +427,7 @@ mod tests { let target = InvocationTarget::MastRoot(Span::unknown(digest)); let instruction = format!("{}", Instruction::Exec(target)); assert_eq!( - "exec.0x03b49d98981575360dd1f8c8b5a7feefcadadd56ec2a33e3e43edae3577de150", + "exec.0x90b3926941061b28638b6cc0bbdb3bcb335e834dc9ab8044250875055202d2fe", instruction ); } diff --git a/assembly/src/ast/mod.rs b/assembly/src/ast/mod.rs index 256173191d..a606d9e906 100644 --- a/assembly/src/ast/mod.rs +++ b/assembly/src/ast/mod.rs @@ -1,5 +1,6 @@ //! Abstract syntax tree (AST) components of Miden programs, modules, and procedures. +mod attribute; mod block; mod constants; mod form; @@ -16,6 +17,10 @@ mod tests; pub mod visit; pub use self::{ + attribute::{ + Attribute, AttributeSet, AttributeSetEntry, BorrowedMeta, Meta, MetaExpr, MetaItem, + MetaKeyValue, MetaList, + }, block::Block, constants::{Constant, ConstantExpr, ConstantOp}, form::Form, diff --git a/assembly/src/ast/procedure/mod.rs b/assembly/src/ast/procedure/mod.rs index e74bd1fc26..8878018cf4 100644 --- a/assembly/src/ast/procedure/mod.rs +++ b/assembly/src/ast/procedure/mod.rs @@ -14,7 +14,10 @@ pub use self::{ procedure::{Procedure, Visibility}, resolver::{LocalNameResolver, ResolvedProcedure}, }; -use crate::{ast::Invoke, SourceSpan, Span, Spanned}; +use crate::{ + ast::{AttributeSet, Invoke}, + SourceSpan, Span, Spanned, +}; // EXPORT // ================================================================================================ @@ -56,6 +59,14 @@ impl Export { } } + /// Returns the attributes for this procedure. + pub fn attributes(&self) -> Option<&AttributeSet> { + match self { + Self::Procedure(ref proc) => Some(proc.attributes()), + Self::Alias(_) => None, + } + } + /// Returns the visibility of this procedure (e.g. public or private). /// /// See [Visibility] for more details on what visibilities are supported. diff --git a/assembly/src/ast/procedure/name.rs b/assembly/src/ast/procedure/name.rs index f7979bd429..731610aaf9 100644 --- a/assembly/src/ast/procedure/name.rs +++ b/assembly/src/ast/procedure/name.rs @@ -298,17 +298,17 @@ impl FromStr for ProcedureName { match c { '"' => { if chars.next().is_some() { - break Err(IdentError::InvalidChars); + break Err(IdentError::InvalidChars { ident: s.into() }); } let tok = &s[1..pos]; break Ok(Arc::from(tok.to_string().into_boxed_str())); }, c if c.is_alphanumeric() => continue, '_' | '$' | '-' | '!' | '?' | '<' | '>' | ':' | '.' => continue, - _ => break Err(IdentError::InvalidChars), + _ => break Err(IdentError::InvalidChars { ident: s.into() }), } } else { - break Err(IdentError::InvalidChars); + break Err(IdentError::InvalidChars { ident: s.into() }); } }, Some((_, c)) if c.is_ascii_lowercase() || c == '_' || c == '$' => { @@ -317,13 +317,13 @@ impl FromStr for ProcedureName { '_' | '$' => false, _ => true, }) { - Err(IdentError::InvalidChars) + Err(IdentError::InvalidChars { ident: s.into() }) } else { Ok(Arc::from(s.to_string().into_boxed_str())) } }, Some((_, c)) if c.is_ascii_uppercase() => Err(IdentError::Casing(CaseKindError::Snake)), - Some(_) => Err(IdentError::InvalidChars), + Some(_) => Err(IdentError::InvalidChars { ident: s.into() }), }?; Ok(Self(Ident::new_unchecked(Span::unknown(raw)))) } diff --git a/assembly/src/ast/procedure/procedure.rs b/assembly/src/ast/procedure/procedure.rs index 2555af2967..70a0ac6a9f 100644 --- a/assembly/src/ast/procedure/procedure.rs +++ b/assembly/src/ast/procedure/procedure.rs @@ -3,7 +3,7 @@ use core::fmt; use super::ProcedureName; use crate::{ - ast::{Block, Invoke}, + ast::{Attribute, AttributeSet, Block, Invoke}, SourceSpan, Span, Spanned, }; @@ -55,6 +55,8 @@ pub struct Procedure { span: SourceSpan, /// The documentation attached to this procedure docs: Option>, + /// The attributes attached to this procedure + attrs: AttributeSet, /// The local name of this procedure name: ProcedureName, /// The visibility of this procedure (i.e. whether it is exported or not) @@ -81,6 +83,7 @@ impl Procedure { Self { span, docs: None, + attrs: Default::default(), name, visibility, num_locals, @@ -95,6 +98,15 @@ impl Procedure { self } + /// Adds attributes to this procedure definition + pub fn with_attributes(mut self, attrs: I) -> Self + where + I: IntoIterator, + { + self.attrs.extend(attrs); + self + } + /// Modifies the visibility of this procedure. /// /// This is made crate-local as the visibility of a procedure is virtually always determined @@ -134,6 +146,30 @@ impl Procedure { self.docs.as_ref() } + /// Get the attributes attached to this procedure + #[inline] + pub fn attributes(&self) -> &AttributeSet { + &self.attrs + } + + /// Get the attributes attached to this procedure, mutably + #[inline] + pub fn attributes_mut(&mut self) -> &mut AttributeSet { + &mut self.attrs + } + + /// Returns true if this procedure has an attribute named `name` + #[inline] + pub fn has_attribute(&self, name: impl AsRef) -> bool { + self.attrs.has(name) + } + + /// Returns the attribute named `name`, if present + #[inline] + pub fn get_attribute(&self, name: impl AsRef) -> Option<&Attribute> { + self.attrs.get(name) + } + /// Returns a reference to the [Block] containing the body of this procedure. pub fn body(&self) -> &Block { &self.body @@ -216,6 +252,15 @@ impl crate::prettier::PrettyPrint for Procedure { .unwrap_or(Document::Empty); } + if !self.attrs.is_empty() { + doc = self + .attrs + .iter() + .map(|attr| attr.render()) + .reduce(|acc, attr| acc + nl() + attr) + .unwrap_or(Document::Empty); + } + doc += display(self.visibility) + const_text(".") + display(&self.name); if self.num_locals > 0 { doc += const_text(".") + display(self.num_locals); @@ -231,6 +276,7 @@ impl fmt::Debug for Procedure { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Procedure") .field("docs", &self.docs) + .field("attrs", &self.attrs) .field("name", &self.name) .field("visibility", &self.visibility) .field("num_locals", &self.num_locals) @@ -248,6 +294,7 @@ impl PartialEq for Procedure { && self.visibility == other.visibility && self.num_locals == other.num_locals && self.body == other.body + && self.attrs == other.attrs && self.docs == other.docs } } diff --git a/assembly/src/ast/tests.rs b/assembly/src/ast/tests.rs index 50ef338a87..7fff35c587 100644 --- a/assembly/src/ast/tests.rs +++ b/assembly/src/ast/tests.rs @@ -11,6 +11,12 @@ use crate::{ Felt, Span, }; +macro_rules! id { + ($name:ident) => { + Ident::new(stringify!($name)).unwrap() + }; +} + macro_rules! inst { ($inst:ident($value:expr)) => { Op::Inst(Span::unknown(Instruction::$inst($value))) @@ -145,7 +151,20 @@ macro_rules! proc { ))) }; - ($docs:expr, $name:ident, $num_locals:literal, $body:expr) => { + ([$($attr:expr),*], $name:ident, $num_locals:literal, $body:expr) => { + Form::Procedure(Export::Procedure( + Procedure::new( + Default::default(), + Visibility::Private, + stringify!($name).parse().expect("invalid procedure name"), + $num_locals, + $body, + ) + .with_attributes([$($attr),*]), + )) + }; + + ($docs:literal, $name:ident, $num_locals:literal, $body:expr) => { Form::Procedure(Export::Procedure( Procedure::new( Default::default(), @@ -157,6 +176,20 @@ macro_rules! proc { .with_docs(Some(Span::unknown($docs.to_string()))), )) }; + + ($docs:literal, [$($attr:expr),*], $name:ident, $num_locals:literal, $body:expr) => { + Form::Procedure(Export::Procedure( + Procedure::new( + Default::default(), + Visibility::Private, + stringify!($name).parse().expect("invalid procedure name"), + $num_locals, + $body, + ) + .with_docs($docs) + .with_attributes([$($attr),*]), + )) + }; } macro_rules! export { @@ -569,7 +602,7 @@ fn test_ast_parsing_module_sequential_if() -> Result<(), Report> { } #[test] -fn parsed_while_if_body() { +fn test_ast_parsing_while_if_body() { let context = TestContext::new(); let source = source_file!( &context, @@ -599,6 +632,65 @@ fn parsed_while_if_body() { assert_forms!(context, source, forms); } +#[test] +fn test_ast_parsing_attributes() -> Result<(), Report> { + let context = TestContext::new(); + + let source = source_file!( + &context, + r#" + # Simple marker attribute + @inline + proc.foo.1 + loc_load.0 + end + + # List attribute + @inline(always) + proc.bar.2 + padw + end + + # Key value attributes of various kinds + @numbers(decimal = 1, hex = 0xdeadbeef) + @props(name = baz) + @props(string = "not a valid quoted identifier") + proc.baz.2 + padw + end + + begin + exec.foo + exec.bar + exec.baz + end"# + ); + + let inline = Attribute::Marker(id!(inline)); + let inline_always = Attribute::List(MetaList::new(id!(inline), [MetaExpr::Ident(id!(always))])); + let numbers = Attribute::new( + id!(numbers), + [(id!(decimal), MetaExpr::from(1u8)), (id!(hex), MetaExpr::from(0xdeadbeefu32))], + ); + let props = Attribute::new( + id!(props), + [ + (id!(name), MetaExpr::from(id!(baz))), + (id!(string), MetaExpr::from("not a valid quoted identifier")), + ], + ); + + let forms = module!( + proc!([inline], foo, 1, block!(inst!(LocLoad(0u16.into())))), + proc!([inline_always], bar, 2, block!(inst!(PadW))), + proc!([numbers, props], baz, 2, block!(inst!(PadW))), + begin!(exec!(foo), exec!(bar), exec!(baz)) + ); + assert_eq!(context.parse_forms(source)?, forms); + + Ok(()) +} + // PROCEDURE IMPORTS // ================================================================================================ @@ -1078,6 +1170,6 @@ fn assert_parsing_line_unexpected_token() { " : ^|^", " : `-- found a mul here", " `----", - r#" help: expected "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# + r#" help: expected "@", or "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# ); } diff --git a/assembly/src/ast/visit.rs b/assembly/src/ast/visit.rs index b9d60cbe65..ecd11fb86e 100644 --- a/assembly/src/ast/visit.rs +++ b/assembly/src/ast/visit.rs @@ -139,7 +139,7 @@ pub trait Visit { } } -impl<'a, V, T> Visit for &'a mut V +impl Visit for &mut V where V: ?Sized + Visit, { @@ -575,7 +575,7 @@ pub trait VisitMut { } } -impl<'a, V, T> VisitMut for &'a mut V +impl VisitMut for &mut V where V: ?Sized + VisitMut, { diff --git a/assembly/src/compile.rs b/assembly/src/compile.rs index 1d2fda3a45..2e36858cdd 100644 --- a/assembly/src/compile.rs +++ b/assembly/src/compile.rs @@ -130,7 +130,7 @@ impl Compile for Module { } } -impl<'a> Compile for &'a Module { +impl Compile for &Module { #[inline(always)] fn compile_with_options( self, @@ -197,7 +197,7 @@ impl Compile for Arc { } } -impl<'a> Compile for &'a str { +impl Compile for &str { #[inline(always)] fn compile_with_options( self, @@ -208,7 +208,7 @@ impl<'a> Compile for &'a str { } } -impl<'a> Compile for &'a String { +impl Compile for &String { #[inline(always)] fn compile_with_options( self, @@ -251,7 +251,7 @@ impl Compile for Box { } } -impl<'a> Compile for Cow<'a, str> { +impl Compile for Cow<'_, str> { #[inline(always)] fn compile_with_options( self, @@ -265,7 +265,7 @@ impl<'a> Compile for Cow<'a, str> { // COMPILE IMPLEMENTATIONS FOR BYTES // ------------------------------------------------------------------------------------------------ -impl<'a> Compile for &'a [u8] { +impl Compile for &[u8] { #[inline] fn compile_with_options( self, @@ -350,7 +350,7 @@ where // ------------------------------------------------------------------------------------------------ #[cfg(feature = "std")] -impl<'a> Compile for &'a std::path::Path { +impl Compile for &std::path::Path { fn compile_with_options( self, source_manager: &dyn SourceManager, diff --git a/assembly/src/errors.rs b/assembly/src/errors.rs index 8c4a3d22cd..4fd06609ef 100644 --- a/assembly/src/errors.rs +++ b/assembly/src/errors.rs @@ -71,6 +71,14 @@ pub enum AssemblyError { #[source_code] source_file: Option>, }, + + #[error("invalid procedure: body must contain at least one instruction if it has decorators")] + #[diagnostic()] + EmptyProcedureBodyWithDecorators { + span: SourceSpan, + #[source_code] + source_file: Option>, + }, #[error(transparent)] #[diagnostic(transparent)] Other(#[from] RelatedError), diff --git a/assembly/src/library/error.rs b/assembly/src/library/error.rs index c8e31f79db..3df795ca3a 100644 --- a/assembly/src/library/error.rs +++ b/assembly/src/library/error.rs @@ -11,4 +11,6 @@ pub enum LibraryError { InvalidKernelExport { procedure_path: QualifiedProcedureName }, #[error(transparent)] Kernel(#[from] KernelError), + #[error("invalid export: no procedure root for {procedure_path} procedure")] + NoProcedureRootForExport { procedure_path: QualifiedProcedureName }, } diff --git a/assembly/src/library/mod.rs b/assembly/src/library/mod.rs index 8ae3415bd8..7e2d35789a 100644 --- a/assembly/src/library/mod.rs +++ b/assembly/src/library/mod.rs @@ -1,17 +1,19 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, - string::{String, ToString}, + string::String, + sync::Arc, vec::Vec, }; use vm_core::{ crypto::hash::RpoDigest, + debuginfo::Span, mast::{MastForest, MastNodeId}, utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}, Kernel, }; -use crate::ast::{ProcedureName, QualifiedProcedureName}; +use crate::ast::{Ident, ProcedureName, QualifiedProcedureName}; mod error; mod module; @@ -43,12 +45,18 @@ pub struct Library { /// The content hash of this library, formed by hashing the roots of all exports in /// lexicographical order (by digest, not procedure name) digest: RpoDigest, - /// A map between procedure paths and the corresponding procedure toots in the MAST forest. + /// A map between procedure paths and the corresponding procedure roots in the MAST forest. /// Multiple paths can map to the same root, and also, some roots may not be associated with /// any paths. - exports: BTreeMap, + /// + /// Note that we use `MastNodeId` as an identifier for procedures instead of MAST root, since 2 + /// different procedures with the same MAST root can be different due to the decorators they + /// contain. However, note that `MastNodeId` is also not a unique identifier for procedures; if + /// the procedures have the same MAST root and decorators, they will have the same + /// `MastNodeId`. + exports: BTreeMap, /// The MAST forest underlying this library. - mast_forest: MastForest, + mast_forest: Arc, } impl AsRef for Library { @@ -58,47 +66,32 @@ impl AsRef for Library { } } -#[derive(Debug, Clone, PartialEq, Eq)] -#[repr(u8)] -enum Export { - /// The export is contained in the [MastForest] of this library - Local(MastNodeId), - /// The export is a re-export of an externally-defined procedure from another library - External(RpoDigest), -} - +// ------------------------------------------------------------------------------------------------ /// Constructors impl Library { /// Constructs a new [`Library`] from the provided MAST forest and a set of exports. + /// + /// # Errors + /// Returns an error if any of the specified exports do not have a corresponding procedure root + /// in the provided MAST forest. pub fn new( - mast_forest: MastForest, - exports: BTreeMap, - ) -> Self { - let mut fqn_to_export = BTreeMap::new(); - - // convert fqn |-> mast_root map into fqn |-> mast_node_id map - for (fqn, mast_root) in exports.into_iter() { - match mast_forest.find_procedure_root(mast_root) { - Some(node_id) => { - fqn_to_export.insert(fqn, Export::Local(node_id)); - }, - None => { - fqn_to_export.insert(fqn, Export::External(mast_root)); - }, + mast_forest: Arc, + exports: BTreeMap, + ) -> Result { + for (fqn, &proc_body_id) in exports.iter() { + if !mast_forest.is_procedure_root(proc_body_id) { + return Err(LibraryError::NoProcedureRootForExport { procedure_path: fqn.clone() }); } } - let digest = content_hash(&fqn_to_export, &mast_forest); + let digest = compute_content_hash(&exports, &mast_forest); - Self { - digest, - exports: fqn_to_export, - mast_forest, - } + Ok(Self { digest, exports, mast_forest }) } } -/// Accessors +// ------------------------------------------------------------------------------------------------ +/// Public accessors impl Library { /// Returns the [RpoDigest] representing the content hash of this library pub fn digest(&self) -> &RpoDigest { @@ -110,8 +103,29 @@ impl Library { self.exports.keys() } - /// Returns the inner [`MastForest`]. - pub fn mast_forest(&self) -> &MastForest { + /// Returns the number of exports in this library. + pub fn num_exports(&self) -> usize { + self.exports.len() + } + + /// Returns a MAST node ID associated with the specified exported procedure. + /// + /// # Panics + /// Panics if the specified procedure is not exported from this library. + pub fn get_export_node_id(&self, proc_name: &QualifiedProcedureName) -> MastNodeId { + *self.exports.get(proc_name).expect("procedure not exported from the library") + } + + /// Returns true if the specified exported procedure is re-exported from a dependency. + pub fn is_reexport(&self, proc_name: &QualifiedProcedureName) -> bool { + self.exports + .get(proc_name) + .map(|&node_id| self.mast_forest[node_id].is_external()) + .unwrap_or(false) + } + + /// Returns a reference to the inner [`MastForest`]. + pub fn mast_forest(&self) -> &Arc { &self.mast_forest } } @@ -122,17 +136,17 @@ impl Library { pub fn module_infos(&self) -> impl Iterator { let mut modules_by_path: BTreeMap = BTreeMap::new(); - for (proc_name, export) in self.exports.iter() { + for (proc_name, &proc_root_node_id) in self.exports.iter() { modules_by_path .entry(proc_name.module.clone()) .and_modify(|compiled_module| { - let proc_digest = export.digest(&self.mast_forest); + let proc_digest = self.mast_forest[proc_root_node_id].digest(); compiled_module.add_procedure(proc_name.name.clone(), proc_digest); }) .or_insert_with(|| { let mut module_info = ModuleInfo::new(proc_name.module.clone()); - let proc_digest = export.digest(&self.mast_forest); + let proc_digest = self.mast_forest[proc_root_node_id].digest(); module_info.add_procedure(proc_name.name.clone(), proc_digest); module_info @@ -143,12 +157,6 @@ impl Library { } } -impl From for MastForest { - fn from(value: Library) -> Self { - value.mast_forest - } -} - impl Serializable for Library { fn write_into(&self, target: &mut W) { let Self { digest: _, exports, mast_forest } = self; @@ -156,42 +164,43 @@ impl Serializable for Library { mast_forest.write_into(target); target.write_usize(exports.len()); - for (proc_name, export) in exports { + for (proc_name, proc_node_id) in exports { proc_name.module.write_into(target); proc_name.name.as_str().write_into(target); - export.write_into(target); + target.write_u32(proc_node_id.as_u32()); } } } impl Deserializable for Library { fn read_from(source: &mut R) -> Result { - let mast_forest = MastForest::read_from(source)?; + let mast_forest = Arc::new(MastForest::read_from(source)?); let num_exports = source.read_usize()?; let mut exports = BTreeMap::new(); for _ in 0..num_exports { let proc_module = source.read()?; let proc_name: String = source.read()?; - let proc_name = ProcedureName::new(proc_name) - .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; + let proc_name = ProcedureName::new_unchecked(Ident::new_unchecked(Span::unknown( + Arc::from(proc_name), + ))); let proc_name = QualifiedProcedureName::new(proc_module, proc_name); - let export = Export::read_with_forest(source, &mast_forest)?; + let proc_node_id = MastNodeId::from_u32_safe(source.read_u32()?, &mast_forest)?; - exports.insert(proc_name, export); + exports.insert(proc_name, proc_node_id); } - let digest = content_hash(&exports, &mast_forest); + let digest = compute_content_hash(&exports, &mast_forest); Ok(Self { digest, exports, mast_forest }) } } -fn content_hash( - exports: &BTreeMap, +fn compute_content_hash( + exports: &BTreeMap, mast_forest: &MastForest, ) -> RpoDigest { - let digests = BTreeSet::from_iter(exports.values().map(|export| export.digest(mast_forest))); + let digests = BTreeSet::from_iter(exports.values().map(|&id| mast_forest[id].digest())); digests .into_iter() .reduce(|a, b| vm_core::crypto::hash::Rpo256::merge(&[a, b])) @@ -254,10 +263,15 @@ mod use_std_library { /// For example, let's say I call this function like so: /// /// ```rust + /// use std::sync::Arc; + /// + /// use miden_assembly::{Assembler, Library, LibraryNamespace}; + /// use vm_core::debuginfo::DefaultSourceManager; + /// /// Library::from_dir( /// "~/masm/std", - /// LibraryNamespace::new("std").unwrap() - /// Arc::new(crate::DefaultSourceManager::default()), + /// LibraryNamespace::new("std").unwrap(), + /// Assembler::new(Arc::new(DefaultSourceManager::default())), /// ); /// ``` /// @@ -295,58 +309,6 @@ mod use_std_library { } } -impl Export { - pub fn digest(&self, mast_forest: &MastForest) -> RpoDigest { - match self { - Self::Local(node_id) => mast_forest[*node_id].digest(), - Self::External(digest) => *digest, - } - } - - fn tag(&self) -> u8 { - // SAFETY: This is safe because we have given this enum a primitive representation with - // #[repr(u8)], with the first field of the underlying union-of-structs the discriminant. - // - // See the section on "accessing the numeric value of the discriminant" - // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html - unsafe { *<*const _>::from(self).cast::() } - } -} - -impl Serializable for Export { - fn write_into(&self, target: &mut W) { - target.write_u8(self.tag()); - match self { - Self::Local(node_id) => target.write_u32(node_id.into()), - Self::External(digest) => digest.write_into(target), - } - } -} - -impl Export { - pub fn read_with_forest( - source: &mut R, - mast_forest: &MastForest, - ) -> Result { - match source.read_u8()? { - 0 => { - let node_id = MastNodeId::from_u32_safe(source.read_u32()?, mast_forest)?; - if !mast_forest.is_procedure_root(node_id) { - return Err(DeserializationError::InvalidValue(format!( - "node with id {node_id} is not a procedure root" - ))); - } - Ok(Self::Local(node_id)) - }, - 1 => RpoDigest::read_from(source).map(Self::External), - n => Err(DeserializationError::InvalidValue(format!( - "{} is not a valid compiled library export entry", - n - ))), - } - } -} - // KERNEL LIBRARY // ================================================================================================ @@ -356,7 +318,7 @@ impl Export { /// - All exported procedures must be exported directly from the kernel namespace (i.e., `#sys`). /// - There must be at least one exported procedure. /// - The number of exported procedures cannot exceed [Kernel::MAX_NUM_PROCEDURES] (i.e., 256). -#[derive(Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct KernelLibrary { kernel: Kernel, kernel_info: ModuleInfo, @@ -376,13 +338,13 @@ impl KernelLibrary { &self.kernel } - /// Returns the inner [`MastForest`]. - pub fn mast_forest(&self) -> &MastForest { + /// Returns a reference to the inner [`MastForest`]. + pub fn mast_forest(&self) -> &Arc { self.library.mast_forest() } /// Destructures this kernel library into individual parts. - pub fn into_parts(self) -> (Kernel, ModuleInfo, MastForest) { + pub fn into_parts(self) -> (Kernel, ModuleInfo, Arc) { (self.kernel, self.kernel_info, self.library.mast_forest) } } @@ -400,7 +362,7 @@ impl TryFrom for KernelLibrary { let mut kernel_module = ModuleInfo::new(kernel_path.clone()); - for (proc_path, export) in library.exports.iter() { + for (proc_path, &proc_node_id) in library.exports.iter() { // make sure all procedures are exported only from the kernel root if proc_path.module != kernel_path { return Err(LibraryError::InvalidKernelExport { @@ -408,7 +370,7 @@ impl TryFrom for KernelLibrary { }); } - let proc_digest = export.digest(&library.mast_forest); + let proc_digest = library.mast_forest[proc_node_id].digest(); proc_digests.push(proc_digest); kernel_module.add_procedure(proc_path.name.clone(), proc_digest); } @@ -423,12 +385,6 @@ impl TryFrom for KernelLibrary { } } -impl From for MastForest { - fn from(value: KernelLibrary) -> Self { - value.library.mast_forest - } -} - impl Serializable for KernelLibrary { fn write_into(&self, target: &mut W) { let Self { kernel: _, kernel_info: _, library } = self; diff --git a/assembly/src/library/module.rs b/assembly/src/library/module.rs index eb7287769b..ed010d0e2b 100644 --- a/assembly/src/library/module.rs +++ b/assembly/src/library/module.rs @@ -9,7 +9,7 @@ use crate::{ // MODULE INFO // ================================================================================================ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ModuleInfo { path: LibraryPath, procedures: Vec, @@ -68,7 +68,7 @@ impl ModuleInfo { } /// Stores the name and digest of a procedure. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ProcedureInfo { pub name: ProcedureName, pub digest: RpoDigest, diff --git a/assembly/src/library/path.rs b/assembly/src/library/path.rs index a77ea1b287..8175339dd2 100644 --- a/assembly/src/library/path.rs +++ b/assembly/src/library/path.rs @@ -65,9 +65,9 @@ impl<'a> LibraryPathComponent<'a> { } } -impl<'a> Eq for LibraryPathComponent<'a> {} +impl Eq for LibraryPathComponent<'_> {} -impl<'a> PartialEq for LibraryPathComponent<'a> { +impl PartialEq for LibraryPathComponent<'_> { fn eq(&self, other: &Self) -> bool { match (self, other) { (Self::Namespace(a), Self::Namespace(b)) => a == b, @@ -77,13 +77,13 @@ impl<'a> PartialEq for LibraryPathComponent<'a> { } } -impl<'a> PartialEq for LibraryPathComponent<'a> { +impl PartialEq for LibraryPathComponent<'_> { fn eq(&self, other: &str) -> bool { self.as_ref().eq(other) } } -impl<'a> AsRef for LibraryPathComponent<'a> { +impl AsRef for LibraryPathComponent<'_> { fn as_ref(&self) -> &str { match self { Self::Namespace(ns) => ns.as_str(), @@ -92,7 +92,7 @@ impl<'a> AsRef for LibraryPathComponent<'a> { } } -impl<'a> fmt::Display for LibraryPathComponent<'a> { +impl fmt::Display for LibraryPathComponent<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str(self.as_ref()) } @@ -575,7 +575,10 @@ mod tests { assert_matches!(path, Err(PathError::InvalidComponent(IdentError::InvalidStart))); let path = LibraryPath::new("foo::b@r"); - assert_matches!(path, Err(PathError::InvalidComponent(IdentError::InvalidChars))); + assert_matches!( + path, + Err(PathError::InvalidComponent(IdentError::InvalidChars { ident: _ })) + ); let path = LibraryPath::new("#foo::bar"); assert_matches!( diff --git a/assembly/src/library/tests.rs b/assembly/src/library/tests.rs index ef13c8a197..adadc53029 100644 --- a/assembly/src/library/tests.rs +++ b/assembly/src/library/tests.rs @@ -1,6 +1,5 @@ -use alloc::{string::ToString, vec::Vec}; - -use vm_core::utils::SliceReader; +use alloc::string::ToString; +use core::str::FromStr; use super::*; use crate::{ @@ -19,6 +18,159 @@ macro_rules! parse_module { }}; } +// TESTS +// ================================================================================================ + +#[test] +fn library_exports() -> Result<(), Report> { + let context = TestContext::new(); + + // build the first library + let baz = r#" + export.baz1 + push.7 push.8 sub + end + "#; + let baz = parse_module!(&context, "lib1::baz", baz); + + let lib1 = Assembler::new(context.source_manager()).assemble_library([baz])?; + + // build the second library + let foo = r#" + proc.foo1 + push.1 add + end + + export.foo2 + push.2 add + exec.foo1 + end + + export.foo3 + push.3 mul + exec.foo1 + exec.foo2 + end + "#; + let foo = parse_module!(&context, "lib2::foo", foo); + + // declare bar module + let bar = r#" + use.lib1::baz + use.lib2::foo + + export.baz::baz1->bar1 + + export.foo::foo2->bar2 + + export.bar3 + exec.foo::foo2 + end + + proc.bar4 + push.1 push.2 mul + end + + export.bar5 + push.3 sub + exec.foo::foo2 + exec.bar1 + exec.bar2 + exec.bar4 + end + "#; + let bar = parse_module!(&context, "lib2::bar", bar); + let lib2_modules = [foo, bar]; + + let lib2 = Assembler::new(context.source_manager()) + .with_library(lib1)? + .assemble_library(lib2_modules.iter().cloned())?; + + let foo2 = QualifiedProcedureName::from_str("lib2::foo::foo2").unwrap(); + let foo3 = QualifiedProcedureName::from_str("lib2::foo::foo3").unwrap(); + let bar1 = QualifiedProcedureName::from_str("lib2::bar::bar1").unwrap(); + let bar2 = QualifiedProcedureName::from_str("lib2::bar::bar2").unwrap(); + let bar3 = QualifiedProcedureName::from_str("lib2::bar::bar3").unwrap(); + let bar5 = QualifiedProcedureName::from_str("lib2::bar::bar5").unwrap(); + + // make sure the library exports all exported procedures + let expected_exports: BTreeSet<_> = [&foo2, &foo3, &bar1, &bar2, &bar3, &bar5].into(); + let actual_exports: BTreeSet<_> = lib2.exports().collect(); + assert_eq!(expected_exports, actual_exports); + + // make sure foo2, bar2, and bar3 map to the same MastNode + assert_eq!(lib2.get_export_node_id(&foo2), lib2.get_export_node_id(&bar2)); + assert_eq!(lib2.get_export_node_id(&foo2), lib2.get_export_node_id(&bar3)); + + // make sure there are 6 roots in the MAST (foo1, foo2, foo3, bar1, bar4, and bar5) + assert_eq!(lib2.mast_forest.num_procedures(), 6); + + // bar1 should be the only re-export (i.e. the only procedure re-exported from a dependency) + assert!(!lib2.is_reexport(&foo2)); + assert!(!lib2.is_reexport(&foo3)); + assert!(lib2.is_reexport(&bar1)); + assert!(!lib2.is_reexport(&bar2)); + assert!(!lib2.is_reexport(&bar3)); + assert!(!lib2.is_reexport(&bar5)); + + Ok(()) +} + +#[test] +fn library_procedure_collision() -> Result<(), Report> { + let context = TestContext::new(); + + // build the first library + let foo = r#" + export.foo1 + push.1 + if.true + push.1 push.2 add + else + push.1 push.2 mul + end + end + "#; + let foo = parse_module!(&context, "lib1::foo", foo); + let lib1 = Assembler::new(context.source_manager()).assemble_library([foo])?; + + // build the second library which defines the same procedure as the first one + let bar = r#" + use.lib1::foo + + export.foo::foo1->bar1 + + export.bar2 + push.1 + if.true + push.1 push.2 add + else + push.1 push.2 mul + end + end + "#; + let bar = parse_module!(&context, "lib2::bar", bar); + let lib2 = Assembler::new(context.source_manager()) + .with_library(lib1)? + .assemble_library([bar])?; + + // make sure lib2 has the expected exports (i.e., bar1 and bar2) + assert_eq!(lib2.num_exports(), 2); + + // make sure that bar1 and bar2 are equal nodes in the MAST forest + let lib2_bar_bar1 = QualifiedProcedureName::from_str("lib2::bar::bar1").unwrap(); + let lib2_bar_bar2 = QualifiedProcedureName::from_str("lib2::bar::bar2").unwrap(); + assert_eq!(lib2.get_export_node_id(&lib2_bar_bar1), lib2.get_export_node_id(&lib2_bar_bar2)); + + // make sure only one node was added to the forest + // NOTE: the MAST forest should actually have only 1 node (external node for the re-exported + // procedure), because nodes for the local procedure nodes should be pruned from the forest, + // but this is not implemented yet + assert_eq!(lib2.mast_forest().num_nodes(), 5); + + Ok(()) +} + #[test] fn library_serialization() -> Result<(), Report> { let context = TestContext::new(); @@ -46,13 +198,11 @@ fn library_serialization() -> Result<(), Report> { let modules = [foo, bar]; // serialize/deserialize the bundle with locations - let bundle = Assembler::new(context.source_manager()) - .assemble_library(modules.iter().cloned()) - .unwrap(); + let bundle = + Assembler::new(context.source_manager()).assemble_library(modules.iter().cloned())?; - let mut bytes = Vec::new(); - bundle.write_into(&mut bytes); - let deserialized = Library::read_from(&mut SliceReader::new(&bytes)).unwrap(); + let bytes = bundle.to_bytes(); + let deserialized = Library::read_from_bytes(&bytes).unwrap(); assert_eq!(bundle, deserialized); Ok(()) diff --git a/assembly/src/parser/error.rs b/assembly/src/parser/error.rs index 5d88994473..4b0fd1bf2b 100644 --- a/assembly/src/parser/error.rs +++ b/assembly/src/parser/error.rs @@ -234,6 +234,26 @@ pub enum ParsingError { #[label] span: SourceSpan, }, + #[error("conflicting attributes for procedure definition")] + #[diagnostic()] + AttributeConflict { + #[label( + "conflict occurs because an attribute with the same name has already been defined" + )] + span: SourceSpan, + #[label("previously defined here")] + prev: SourceSpan, + }, + #[error("conflicting key-value attributes for procedure definition")] + #[diagnostic()] + AttributeKeyValueConflict { + #[label( + "conflict occurs because a key with the same name has already been set in a previous declaration" + )] + span: SourceSpan, + #[label("previously defined here")] + prev: SourceSpan, + }, } impl ParsingError { diff --git a/assembly/src/parser/grammar.lalrpop b/assembly/src/parser/grammar.lalrpop index a8be8ce46d..65b78ae427 100644 --- a/assembly/src/parser/grammar.lalrpop +++ b/assembly/src/parser/grammar.lalrpop @@ -1,6 +1,6 @@ use alloc::{ boxed::Box, - collections::{VecDeque, BTreeSet}, + collections::{VecDeque, BTreeSet, BTreeMap}, string::ToString, sync::Arc, vec::Vec, @@ -34,6 +34,7 @@ extern { bare_ident => Token::Ident(<&'input str>), const_ident => Token::ConstantIdent(<&'input str>), quoted_ident => Token::QuotedIdent(<&'input str>), + quoted_string => Token::QuotedString(<&'input str>), hex_value => Token::HexValue(), bin_value => Token::BinValue(), doc_comment => Token::DocComment(), @@ -192,22 +193,36 @@ extern { "u32xor" => Token::U32Xor, "while" => Token::While, "xor" => Token::Xor, + "@" => Token::At, "!" => Token::Bang, "::" => Token::ColonColon, "." => Token::Dot, + "," => Token::Comma, "=" => Token::Equal, "(" => Token::Lparen, + "[" => Token::Lbracket, "-" => Token::Minus, "+" => Token::Plus, "//" => Token::SlashSlash, "/" => Token::Slash, "*" => Token::Star, ")" => Token::Rparen, + "]" => Token::Rbracket, "->" => Token::Rstab, EOF => Token::Eof, } } + +// comma-delimited with at least one element +#[inline] +CommaDelimited: Vec = { + ",")*> => { + v.push(e); + v + } +}; + // dot-delimited with at least one element #[inline] DotDelimited: Vec = { @@ -290,6 +305,66 @@ Begin: Form = { } Proc: Form = { + =>? { + use alloc::collections::btree_map::Entry; + let attributes = proc.attributes_mut(); + for attr in annotations { + match attr { + Attribute::KeyValue(kv) => { + match attributes.entry(kv.id()) { + AttributeSetEntry::Vacant(entry) => { + entry.insert(Attribute::KeyValue(kv)); + } + AttributeSetEntry::Occupied(mut entry) => { + let value = entry.get_mut(); + match value { + Attribute::KeyValue(ref mut existing_kvs) => { + for (k, v) in kv.into_iter() { + let span = k.span(); + match existing_kvs.entry(k) { + Entry::Vacant(entry) => { + entry.insert(v); + } + Entry::Occupied(ref entry) => { + let prev = entry.get(); + return Err(ParseError::User { + error: ParsingError::AttributeKeyValueConflict { span, prev: prev.span() }, + }); + } + } + } + } + other => { + return Err(ParseError::User { + error: ParsingError::AttributeConflict { span: kv.span(), prev: other.span() }, + }); + } + } + } + } + } + attr => { + match attributes.entry(attr.id()) { + AttributeSetEntry::Vacant(entry) => { + entry.insert(attr); + } + AttributeSetEntry::Occupied(ref entry) => { + let prev_attr = entry.get(); + return Err(ParseError::User { + error: ParsingError::AttributeConflict { span: attr.span(), prev: prev_attr.span() }, + }); + } + } + } + } + } + Ok(Form::Procedure(Export::Procedure(proc))) + }, + AliasDef => Form::Procedure(Export::Alias(<>)), +} + +#[inline] +ProcedureDef: Procedure = { "." > "end" =>? { let num_locals = num_locals.unwrap_or(0); let procedure = Procedure::new( @@ -299,9 +374,12 @@ Proc: Form = { num_locals, body ); - Ok(Form::Procedure(Export::Procedure(procedure))) + Ok(procedure) }, +} +#[inline] +AliasDef: ProcedureAlias = { "export" "." " )?> =>? { let span = span!(source_file.id(), l, r); let alias = match name { @@ -341,7 +419,7 @@ Proc: Form = { ProcedureAlias::new(export_name, AliasTarget::AbsoluteProcedurePath(target)) } }; - Ok(Form::Procedure(Export::Alias(alias))) + Ok(alias) } } @@ -351,6 +429,68 @@ Visibility: Visibility = { "export" => Visibility::Public, } +// ANNOTATIONS +// ================================================================================================ + +Annotation: Attribute = { + "@" => attr.with_span(span!(source_file.id(), l, r)), +} + +#[inline] +Attribute: Attribute = { + "(" > ")" => { + Attribute::List(MetaList { span: span!(source_file.id(), l, r), name, items }) + }, + + "(" > ")" =>? { + use alloc::collections::btree_map::Entry; + + let mut map = BTreeMap::::default(); + for meta_kv in items { + let (span, (k, v)) = meta_kv.into_parts(); + match map.entry(k) { + Entry::Occupied(ref entry) => { + let prev = entry.key().span(); + return Err(ParseError::User { + error: ParsingError::AttributeKeyValueConflict { span, prev }, + }); + } + Entry::Vacant(entry) => { + entry.insert(v); + } + } + } + Ok(Attribute::KeyValue(MetaKeyValue { span: span!(source_file.id(), l, r), name, items: map })) + }, + + => Attribute::Marker(<>), +} + +MetaKeyValue: Span<(Ident, MetaExpr)> = { + "=" => { + let span = span!(source_file.id(), l, r); + Span::new(span, (key, value)) + } +} + +MetaExpr: MetaExpr = { + BareIdent => MetaExpr::Ident(<>), + QuotedString => MetaExpr::String(<>), + => MetaExpr::Int(Span::new(span!(source_file.id(), l, r), value)), +} + +#[inline] +QuotedString: Ident = { + => { + let value = interned.get(value).cloned().unwrap_or_else(|| { + let value = Arc::::from(value.to_string().into_boxed_str()); + interned.insert(value.clone()); + value + }); + Ident::new_unchecked(Span::new(span!(source_file.id(), l, r), value)) + } +} + // CODE BLOCKS // ================================================================================================ diff --git a/assembly/src/parser/lexer.rs b/assembly/src/parser/lexer.rs index b256addd75..0664ab0e9d 100644 --- a/assembly/src/parser/lexer.rs +++ b/assembly/src/parser/lexer.rs @@ -274,15 +274,19 @@ impl<'input> Lexer<'input> { } match self.read() { + '@' => pop!(self, Token::At), '!' => pop!(self, Token::Bang), ':' => match self.peek() { ':' => pop2!(self, Token::ColonColon), _ => Err(ParsingError::InvalidToken { span: self.span() }), }, '.' => pop!(self, Token::Dot), + ',' => pop!(self, Token::Comma), '=' => pop!(self, Token::Equal), '(' => pop!(self, Token::Lparen), + '[' => pop!(self, Token::Lbracket), ')' => pop!(self, Token::Rparen), + ']' => pop!(self, Token::Rbracket), '-' => match self.peek() { '>' => pop2!(self, Token::Rstab), _ => pop!(self, Token::Minus), @@ -293,7 +297,7 @@ impl<'input> Lexer<'input> { _ => pop!(self, Token::Slash), }, '*' => pop!(self, Token::Star), - '"' => self.lex_quoted_identifier(), + '"' => self.lex_quoted_identifier_or_string(), '0' => match self.peek() { 'x' => { self.skip(); @@ -414,10 +418,11 @@ impl<'input> Lexer<'input> { } } - fn lex_quoted_identifier(&mut self) -> Result, ParsingError> { + fn lex_quoted_identifier_or_string(&mut self) -> Result, ParsingError> { // Skip quotation mark self.skip(); + let mut is_identifier = true; let quote_size = ByteOffset::from_char_len('"'); loop { match self.read() { @@ -426,27 +431,37 @@ impl<'input> Lexer<'input> { start: SourceSpan::at(self.source_id, self.span().start()), }); }, + '\\' => { + is_identifier = false; + self.skip(); + match self.read() { + '"' | '\n' => { + self.skip(); + }, + _ => (), + } + }, '"' => { let span = self.span(); let start = span.start() + quote_size; let span = SourceSpan::new(self.source_id, start..span.end()); self.skip(); - break Ok(Token::QuotedIdent(self.slice_span(span))); + break Ok(if is_identifier { + Token::QuotedIdent(self.slice_span(span)) + } else { + Token::QuotedString(self.slice_span(span)) + }); }, c if c.is_ascii_alphanumeric() => { self.skip(); - continue; }, '_' | '$' | '-' | '!' | '?' | '<' | '>' | ':' | '.' => { self.skip(); - continue; }, - c => { - let loc = self.span().end() - ByteOffset::from_char_len(c); - break Err(ParsingError::InvalidIdentCharacter { - span: SourceSpan::at(self.source_id, loc), - }); + _ => { + is_identifier = false; + self.skip(); }, } } diff --git a/assembly/src/parser/mod.rs b/assembly/src/parser/mod.rs index 440f0cb5ae..d61e5619a1 100644 --- a/assembly/src/parser/mod.rs +++ b/assembly/src/parser/mod.rs @@ -21,8 +21,6 @@ mod token; use alloc::{boxed::Box, collections::BTreeSet, string::ToString, sync::Arc, vec::Vec}; -use miette::miette; - pub use self::{ error::{BinErrorKind, HexErrorKind, LiteralErrorKind, ParsingError}, lexer::Lexer, @@ -182,6 +180,7 @@ pub fn read_modules_from_dir( ) -> Result>, Report> { use std::collections::{btree_map::Entry, BTreeMap}; + use miette::miette; use module_walker::{ModuleEntry, WalkModules}; use crate::diagnostics::{IntoDiagnostic, WrapErr}; @@ -232,7 +231,8 @@ mod module_walker { path::{Path, PathBuf}, }; - use super::miette; + use miette::miette; + use crate::{ ast::Module, diagnostics::{IntoDiagnostic, Report}, @@ -304,7 +304,7 @@ mod module_walker { } } - impl<'a> Iterator for WalkModules<'a> { + impl Iterator for WalkModules<'_> { type Item = Result; fn next(&mut self) -> Option { diff --git a/assembly/src/parser/token.rs b/assembly/src/parser/token.rs index 7ff7976ddd..602e274912 100644 --- a/assembly/src/parser/token.rs +++ b/assembly/src/parser/token.rs @@ -50,6 +50,69 @@ pub enum HexEncodedValue { /// A set of 4 field elements, 32 bytes, encoded as a contiguous string of 64 hex digits Word([Felt; 4]), } +impl fmt::Display for HexEncodedValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::U8(value) => write!(f, "{value}"), + Self::U16(value) => write!(f, "{value}"), + Self::U32(value) => write!(f, "{value:#04x}"), + Self::Felt(value) => write!(f, "{:#08x}", &value.as_int().to_be()), + Self::Word(value) => write!( + f, + "{:#08x}{:08x}{:08x}{:08x}", + &value[0].as_int(), + &value[1].as_int(), + &value[2].as_int(), + &value[3].as_int(), + ), + } + } +} +impl PartialOrd for HexEncodedValue { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for HexEncodedValue { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + use core::cmp::Ordering; + match (self, other) { + (Self::U8(l), Self::U8(r)) => l.cmp(r), + (Self::U8(_), _) => Ordering::Less, + (Self::U16(_), Self::U8(_)) => Ordering::Greater, + (Self::U16(l), Self::U16(r)) => l.cmp(r), + (Self::U16(_), _) => Ordering::Less, + (Self::U32(_), Self::U8(_) | Self::U16(_)) => Ordering::Greater, + (Self::U32(l), Self::U32(r)) => l.cmp(r), + (Self::U32(_), _) => Ordering::Less, + (Self::Felt(_), Self::U8(_) | Self::U16(_) | Self::U32(_)) => Ordering::Greater, + (Self::Felt(l), Self::Felt(r)) => l.as_int().cmp(&r.as_int()), + (Self::Felt(_), _) => Ordering::Less, + (Self::Word([l0, l1, l2, l3]), Self::Word([r0, r1, r2, r3])) => l0 + .as_int() + .cmp(&r0.as_int()) + .then_with(|| l1.as_int().cmp(&r1.as_int())) + .then_with(|| l2.as_int().cmp(&r2.as_int())) + .then_with(|| l3.as_int().cmp(&r3.as_int())), + (Self::Word(_), _) => Ordering::Greater, + } + } +} + +impl core::hash::Hash for HexEncodedValue { + fn hash(&self, state: &mut H) { + core::mem::discriminant(self).hash(state); + match self { + Self::U8(value) => value.hash(state), + Self::U16(value) => value.hash(state), + Self::U32(value) => value.hash(state), + Self::Felt(value) => value.as_int().hash(state), + Self::Word([a, b, c, d]) => { + [a.as_int(), b.as_int(), c.as_int(), d.as_int()].hash(state) + }, + } + } +} // BINARY ENCODED VALUE // ================================================================================================ @@ -225,17 +288,21 @@ pub enum Token<'input> { U32Xor, While, Xor, + At, Bang, ColonColon, Dot, + Comma, Equal, Lparen, + Lbracket, Minus, Plus, SlashSlash, Slash, Star, Rparen, + Rbracket, Rstab, DocComment(DocumentationType), HexValue(HexEncodedValue), @@ -244,11 +311,12 @@ pub enum Token<'input> { Ident(&'input str), ConstantIdent(&'input str), QuotedIdent(&'input str), + QuotedString(&'input str), Comment, Eof, } -impl<'input> fmt::Display for Token<'input> { +impl fmt::Display for Token<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Token::Add => write!(f, "add"), @@ -404,17 +472,21 @@ impl<'input> fmt::Display for Token<'input> { Token::U32Xor => write!(f, "u32xor"), Token::While => write!(f, "while"), Token::Xor => write!(f, "xor"), + Token::At => write!(f, "@"), Token::Bang => write!(f, "!"), Token::ColonColon => write!(f, "::"), Token::Dot => write!(f, "."), + Token::Comma => write!(f, ","), Token::Equal => write!(f, "="), Token::Lparen => write!(f, "("), + Token::Lbracket => write!(f, "["), Token::Minus => write!(f, "-"), Token::Plus => write!(f, "+"), Token::SlashSlash => write!(f, "//"), Token::Slash => write!(f, "/"), Token::Star => write!(f, "*"), Token::Rparen => write!(f, ")"), + Token::Rbracket => write!(f, "]"), Token::Rstab => write!(f, "->"), Token::DocComment(DocumentationType::Module(_)) => f.write_str("module doc"), Token::DocComment(DocumentationType::Form(_)) => f.write_str("doc comment"), @@ -424,6 +496,7 @@ impl<'input> fmt::Display for Token<'input> { Token::Ident(_) => f.write_str("identifier"), Token::ConstantIdent(_) => f.write_str("constant identifier"), Token::QuotedIdent(_) => f.write_str("quoted identifier"), + Token::QuotedString(_) => f.write_str("quoted string"), Token::Comment => f.write_str("comment"), Token::Eof => write!(f, "end of file"), } @@ -804,17 +877,21 @@ impl<'input> Token<'input> { Token::Ident(_) => { // Nope, try again match s { + "@" => Ok(Token::At), "!" => Ok(Token::Bang), "::" => Ok(Token::ColonColon), "." => Ok(Token::Dot), + "," => Ok(Token::Comma), "=" => Ok(Token::Equal), "(" => Ok(Token::Lparen), + "[" => Ok(Token::Lbracket), "-" => Ok(Token::Minus), "+" => Ok(Token::Plus), "//" => Ok(Token::SlashSlash), "/" => Ok(Token::Slash), "*" => Ok(Token::Star), ")" => Ok(Token::Rparen), + "]" => Ok(Token::Rbracket), "->" => Ok(Token::Rstab), "end of file" => Ok(Token::Eof), "module doc" => Ok(Token::DocComment(DocumentationType::Module(String::new()))), @@ -826,6 +903,7 @@ impl<'input> Token<'input> { "identifier" => Ok(Token::Ident("")), "constant identifier" => Ok(Token::ConstantIdent("")), "quoted identifier" => Ok(Token::QuotedIdent("")), + "quoted string" => Ok(Token::QuotedString("")), _ => Err(()), } }, diff --git a/assembly/src/sema/passes/const_eval.rs b/assembly/src/sema/passes/const_eval.rs index f76de03ad2..8394bc5402 100644 --- a/assembly/src/sema/passes/const_eval.rs +++ b/assembly/src/sema/passes/const_eval.rs @@ -17,7 +17,7 @@ impl<'analyzer> ConstEvalVisitor<'analyzer> { } } -impl<'analyzer> ConstEvalVisitor<'analyzer> { +impl ConstEvalVisitor<'_> { fn eval_const(&mut self, imm: &mut Immediate) -> ControlFlow<()> where T: TryFrom, @@ -45,7 +45,7 @@ impl<'analyzer> ConstEvalVisitor<'analyzer> { } } -impl<'analyzer> VisitMut for ConstEvalVisitor<'analyzer> { +impl VisitMut for ConstEvalVisitor<'_> { fn visit_mut_immediate_u8(&mut self, imm: &mut Immediate) -> ControlFlow<()> { self.eval_const(imm) } diff --git a/assembly/src/sema/passes/verify_invoke.rs b/assembly/src/sema/passes/verify_invoke.rs index a1c38ba8ca..16c7ac5ba5 100644 --- a/assembly/src/sema/passes/verify_invoke.rs +++ b/assembly/src/sema/passes/verify_invoke.rs @@ -43,7 +43,7 @@ impl<'a> VerifyInvokeTargets<'a> { } } -impl<'a> VerifyInvokeTargets<'a> { +impl VerifyInvokeTargets<'_> { fn resolve_local(&mut self, name: &ProcedureName) -> ControlFlow<()> { if !self.procedures.contains(name) { self.analyzer @@ -72,7 +72,7 @@ impl<'a> VerifyInvokeTargets<'a> { } } -impl<'a> VisitMut for VerifyInvokeTargets<'a> { +impl VisitMut for VerifyInvokeTargets<'_> { fn visit_mut_inst(&mut self, inst: &mut Span) -> ControlFlow<()> { let span = inst.span(); match &**inst { diff --git a/assembly/src/testing.rs b/assembly/src/testing.rs index e2630fbf0f..2ba14bcfe8 100644 --- a/assembly/src/testing.rs +++ b/assembly/src/testing.rs @@ -199,9 +199,9 @@ impl TestContext { let _ = set_hook(Box::new(|_| Box::new(ReportHandlerOpts::new().build()))); } let source_manager = Arc::new(crate::DefaultSourceManager::default()); - let assembler = Assembler::new(source_manager.clone()) - .with_debug_mode(true) - .with_warnings_as_errors(true); + // Note: we do not set debug mode by default because we do not want AsmOp decorators to be + // inserted in our programs + let assembler = Assembler::new(source_manager.clone()).with_warnings_as_errors(true); Self { source_manager, assembler } } diff --git a/assembly/src/tests.rs b/assembly/src/tests.rs index 8b005f351a..0cc537e3a2 100644 --- a/assembly/src/tests.rs +++ b/assembly/src/tests.rs @@ -1,12 +1,17 @@ -use alloc::string::ToString; +use alloc::{string::ToString, vec::Vec}; + +use vm_core::{ + mast::{MastNode, MastNodeId}, + Program, +}; use crate::{ assert_diagnostic_lines, ast::{Module, ModuleKind}, - diagnostics::Report, + diagnostics::{IntoDiagnostic, Report}, regex, source_file, testing::{Pattern, TestContext}, - Assembler, LibraryPath, ModuleParser, + Assembler, Deserializable, LibraryPath, ModuleParser, Serializable, }; type TestResult = Result<(), Report>; @@ -29,6 +34,15 @@ macro_rules! assert_assembler_diagnostic { }}; } +macro_rules! parse_module { + ($context:expr, $path:literal, $source:expr) => {{ + let path = LibraryPath::new($path).into_diagnostic()?; + let source_file = + $context.source_manager().load(concat!("test", line!()), $source.to_string()); + Module::parse(path, ModuleKind::Library, source_file)? + }}; +} + // SIMPLE PROGRAMS // ================================================================================================ @@ -718,7 +732,7 @@ fn constant_must_be_valid_felt() -> TestResult { " : ^^^|^^^", " : `-- found a constant identifier here", " `----", - " help: expected \"*\", or \"+\", or \"-\", or \"/\", or \"//\", or \"begin\", or \"const\", \ + " help: expected \"*\", or \"+\", or \"-\", or \"/\", or \"//\", or \"@\", or \"begin\", or \"const\", \ or \"export\", or \"proc\", or \"use\", or end of file, or doc comment" ); Ok(()) @@ -986,6 +1000,245 @@ fn const_conversion_failed_to_u32() -> TestResult { Ok(()) } +// DECORATORS +// ================================================================================================ + +#[test] +fn decorators_basic_block() -> TestResult { + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 + add + trace.1 + mul + trace.2 + end" + ); + let expected = "\ +begin + basic_block trace(0) add trace(1) mul trace(2) end +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + +#[test] +fn decorators_repeat_one_basic_block() -> TestResult { + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 + repeat.2 add end + trace.1 + repeat.2 mul end + trace.2 + end" + ); + let expected = "\ +begin + basic_block trace(0) add add trace(1) mul mul trace(2) end +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + +#[test] +fn decorators_repeat_split() -> TestResult { + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 + repeat.2 + if.true + trace.1 push.42 trace.2 + else + trace.3 push.22 trace.3 + end + trace.4 + end + trace.5 + end" + ); + let expected = "\ +begin + join + trace(0) + if.true + basic_block trace(1) push(42) trace(2) end + else + basic_block trace(3) push(22) trace(3) end + end + trace(4) + if.true + basic_block trace(1) push(42) trace(2) end + else + basic_block trace(3) push(22) trace(3) end + end + trace(4) + end + trace(5) +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + +#[test] +fn decorators_call() -> TestResult { + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 trace.1 + call.0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef + trace.2 + end" + ); + let expected = "\ +begin + trace(0) trace(1) + call.0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef + trace(2) +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + +#[test] +fn decorators_dyn() -> TestResult { + // single line + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 + dynexec + trace.1 + end" + ); + let expected = "\ +begin + trace(0) dyn trace(1) +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + + // multi line + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 trace.1 trace.2 trace.3 trace.4 + dynexec + trace.5 trace.6 trace.7 trace.8 trace.9 + end" + ); + let expected = "\ +begin + trace(0) trace(1) trace(2) trace(3) trace(4) + dyn + trace(5) trace(6) trace(7) trace(8) trace(9) +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + +#[test] +fn decorators_external() -> TestResult { + let context = TestContext::default(); + let baz = r#" + export.f + push.7 push.8 sub + end + "#; + let baz = parse_module!(&context, "lib::baz", baz); + + let lib = Assembler::new(context.source_manager()).assemble_library([baz])?; + + let program_source = source_file!( + &context, + "\ + use.lib::baz + begin + trace.0 + exec.baz::f + trace.1 + end" + ); + + let expected = "\ +begin + trace(0) + external.0xe776df8dc02329acc43a09fe8e510b44a87dfd876e375ad383891470ece4f6de + trace(1) +end"; + let program = Assembler::new(context.source_manager()) + .with_library(lib)? + .assemble_program(program_source)?; + assert_str_eq!(expected, format!("{program}")); + + Ok(()) +} + +#[test] +fn decorators_join_and_split() -> TestResult { + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + begin + trace.0 trace.1 + if.true + trace.2 add trace.3 + else + trace.4 mul trace.5 + end + trace.6 + if.true + trace.7 push.42 trace.8 + else + trace.9 push.22 trace.10 + end + trace.11 + end" + ); + let expected = "\ +begin + join + trace(0) trace(1) + if.true + basic_block trace(2) add trace(3) end + else + basic_block trace(4) mul trace(5) end + end + trace(6) + if.true + basic_block trace(7) push(42) trace(8) end + else + basic_block trace(9) push(22) trace(10) end + end + end + trace(11) +end"; + let program = context.assemble(source)?; + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + // ASSERTIONS // ================================================================================================ @@ -1227,6 +1480,131 @@ end"; Ok(()) } +/// Ensure that there is no collision between `Assert`, `U32assert2`, and `MpVerify` instructions +/// with different inner values (which all don't contribute to the MAST root). +#[test] +fn asserts_and_mpverify_with_code_in_duplicate_procedure() -> TestResult { + let context = TestContext::default(); + let source = source_file!( + &context, + "\ + proc.f1 + u32assert.err=1 + end + proc.f2 + u32assert.err=2 + end + proc.f12 + u32assert.err=1 + u32assert.err=2 + end + proc.f21 + u32assert.err=2 + u32assert.err=1 + end + proc.g1 + assert.err=1 + end + proc.g2 + assert.err=2 + end + proc.g12 + assert.err=1 + assert.err=2 + end + proc.g21 + assert.err=2 + assert.err=1 + end + proc.fg + assert.err=1 + u32assert.err=1 + assert.err=2 + u32assert.err=2 + + u32assert.err=1 + assert.err=1 + u32assert.err=2 + assert.err=2 + end + + proc.mpverify + mtree_verify.err=1 + mtree_verify.err=2 + mtree_verify.err=2 + mtree_verify.err=1 + end + + begin + exec.f1 + exec.f2 + exec.f12 + exec.f21 + exec.g1 + exec.g2 + exec.g12 + exec.g21 + exec.fg + exec.mpverify + end + " + ); + let program = context.assemble(source)?; + + let expected = "\ +begin + basic_block + pad + u32assert2(1) + drop + pad + u32assert2(2) + drop + pad + u32assert2(1) + drop + pad + u32assert2(2) + drop + pad + u32assert2(2) + drop + pad + u32assert2(1) + drop + assert(1) + assert(2) + assert(1) + assert(2) + assert(2) + assert(1) + assert(1) + pad + u32assert2(1) + drop + assert(2) + pad + u32assert2(2) + drop + pad + u32assert2(1) + drop + assert(1) + pad + u32assert2(2) + drop + assert(2) + mpverify(1) + mpverify(2) + mpverify(2) + mpverify(1) + end +end"; + + assert_str_eq!(expected, format!("{program}")); + Ok(()) +} + #[test] fn mtree_verify_with_code() -> TestResult { let context = TestContext::default(); @@ -1307,6 +1685,58 @@ end"; // PROGRAMS WITH PROCEDURES // ================================================================================================ +/// If the program has 2 procedures with the same MAST root (but possibly different decorators), the +/// correct procedure is chosen on exec +#[test] +fn ensure_correct_procedure_selection_on_collision() -> TestResult { + let context = TestContext::default(); + + // if with else + let source = source_file!( + &context, + " + proc.f + add + end + + proc.g + trace.2 + add + end + + begin + if.true + exec.f + else + exec.g + end + end" + ); + let program = context.assemble(source)?; + + // Note: those values were taken from adding prints to the assembler at the time of writing. It + // is possible that this test starts failing if we end up ordering procedures differently. + let expected_f_node_id = + MastNodeId::from_u32_safe(1_u32, program.mast_forest().as_ref()).unwrap(); + let expected_g_node_id = + MastNodeId::from_u32_safe(0_u32, program.mast_forest().as_ref()).unwrap(); + + let (exec_f_node_id, exec_g_node_id) = { + let split_node_id = program.entrypoint(); + let split_node = match &program.mast_forest()[split_node_id] { + MastNode::Split(split_node) => split_node, + _ => panic!("expected split node"), + }; + + (split_node.on_true(), split_node.on_false()) + }; + + assert_eq!(program.mast_forest()[expected_f_node_id], program.mast_forest()[exec_f_node_id]); + assert_eq!(program.mast_forest()[expected_g_node_id], program.mast_forest()[exec_g_node_id]); + + Ok(()) +} + #[test] fn program_with_one_procedure() -> TestResult { let context = TestContext::default(); @@ -1471,7 +1901,7 @@ fn program_with_dynamic_code_execution_in_new_context() -> TestResult { let program = context.assemble(source)?; let expected = "\ begin - call.0xc75c340ec6a69e708457544d38783abbb604d881b7dc62d00bfc2b10f52808e6 + dyncall end"; assert_str_eq!(format!("{program}"), expected); Ok(()) @@ -1898,7 +2328,7 @@ end"; " : `-- found a -> here", "3 |", " `----", - r#" help: expected "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# + r#" help: expected "@", or "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# ); // --- duplicate module import -------------------------------------------- @@ -2105,7 +2535,7 @@ fn invalid_empty_program() { "unexpected end of file", regex!(r#",-\[test[\d]+:1:1\]"#), "`----", - r#" help: expected "begin", or "const", or "export", or "proc", or "use", or doc comment"# + r#" help: expected "@", or "begin", or "const", or "export", or "proc", or "use", or doc comment"# ); assert_assembler_diagnostic!( @@ -2114,7 +2544,7 @@ fn invalid_empty_program() { "unexpected end of file", regex!(r#",-\[test[\d]+:1:1\]"#), " `----", - r#" help: expected "begin", or "const", or "export", or "proc", or "use", or doc comment"# + r#" help: expected "@", or "begin", or "const", or "export", or "proc", or "use", or doc comment"# ); } @@ -2130,7 +2560,7 @@ fn invalid_program_unrecognized_token() { " : ^^|^", " : `-- found a identifier here", " `----", - r#" help: expected "begin", or "const", or "export", or "proc", or "use", or doc comment"# + r#" help: expected "@", or "begin", or "const", or "export", or "proc", or "use", or doc comment"# ); } @@ -2160,7 +2590,7 @@ fn invalid_program_invalid_top_level_token() { " : ^|^", " : `-- found a mul here", " `----", - r#" help: expected "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# + r#" help: expected "@", or "begin", or "const", or "export", or "proc", or "use", or end of file, or doc comment"# ); } @@ -2495,10 +2925,10 @@ fn test_reexported_proc_with_same_name_as_local_proc_diff_locals() { let source = source_file!( &context, "export.foo.2 - push.1 - drop -end -" + push.1 + drop + end + " ); mod_parser.parse(LibraryPath::new("test::mod1").unwrap(), source).unwrap() }; @@ -2507,10 +2937,10 @@ end let source = source_file!( &context, "use.test::mod1 -export.foo - exec.mod1::foo -end -" + export.foo + exec.mod1::foo + end + " ); mod_parser.parse(LibraryPath::new("test::mod2").unwrap(), source).unwrap() }; @@ -2543,3 +2973,58 @@ end let _program = assembler.assemble_program(program_source).unwrap(); } + +// PROGRAM SERIALIZATION AND DESERIALIZATION +// ================================================================================================ +#[test] +fn test_program_serde_simple() { + let source = " + begin + push.1.2 + add + drop + end + "; + + let assembler = Assembler::default(); + let original_program = assembler.assemble_program(source).unwrap(); + + let mut target = Vec::new(); + original_program.write_into(&mut target); + let deserialized_program = Program::read_from_bytes(&target).unwrap(); + + assert_eq!(original_program, deserialized_program); +} + +#[test] +fn test_program_serde_with_decorators() { + let source = " + const.DEFAULT_CONST=100 + + proc.foo + push.1.2 add + debug.stack.8 + end + + begin + emit.DEFAULT_CONST + + exec.foo + + debug.stack.4 + + drop + + trace.DEFAULT_CONST + end + "; + + let assembler = Assembler::default().with_debug_mode(true); + let original_program = assembler.assemble_program(source).unwrap(); + + let mut target = Vec::new(); + original_program.write_into(&mut target); + let deserialized_program = Program::read_from_bytes(&target).unwrap(); + + assert_eq!(original_program, deserialized_program); +} diff --git a/core/Cargo.toml b/core/Cargo.toml index ae7da5c505..106f228858 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-core" -version = "0.10.5" +version = "0.11.0" description = "Miden VM core components" -documentation = "https://docs.rs/miden-core/0.10.5" +documentation = "https://docs.rs/miden-core/0.11.0" readme = "README.md" categories = ["emulators", "no-std"] keywords = ["instruction-set", "miden", "program"] @@ -32,24 +32,24 @@ std = [ [dependencies] lock_api = { version = "0.4", features = ["arc_lock"] } -math = { package = "winter-math", version = "0.9", default-features = false } +math = { package = "winter-math", version = "0.10", default-features = false } memchr = { version = "2.7", default-features = false } -miden-crypto = { version = "0.10", default-features = false } +miden-crypto = { version = "0.12", default-features = false } miden-formatting = { version = "0.1", default-features = false } -miette = { package = "miden-miette", version = "7.1", default-features = false, features = [ +miette = { package = "miden-miette", version = "7.1", default-features = false, optional = true, features = [ "fancy-no-syscall", "derive" -], optional = true } +] } num-derive = { version = "0.4", default-features = false } num-traits = { version = "0.2", default-features = false } parking_lot = { version = "0.12", optional = true } thiserror = { package = "miden-thiserror", version = "1.0", default-features = false } -winter-utils = { package = "winter-utils", version = "0.9", default-features = false } +winter-utils = { package = "winter-utils", version = "0.10", default-features = false } [dev-dependencies] loom = "0.7" proptest = "1.5" -rand_utils = { version = "0.9", package = "winter-rand-utils" } +rand-utils = { package = "winter-rand-utils", version = "0.10" } [target.'cfg(loom)'.dependencies] loom = "0.7" diff --git a/core/README.md b/core/README.md index 3d2175b6b1..9d8ff7387a 100644 --- a/core/README.md +++ b/core/README.md @@ -11,3 +11,7 @@ This crate contains core components used by Miden VM. These components include: ## License This project is [MIT licensed](../LICENSE). + +## Acknowledgements + +The `racy_lock` module found under `core/src/utils/sync` is based on the [once_cell](https://crates.io/crates/once_cell) crate's implementation of `race::OnceBox`. diff --git a/core/src/lib.rs b/core/src/lib.rs index f95ec04b8f..15f14969d7 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -73,7 +73,7 @@ pub mod crypto { blake::{Blake3Digest, Blake3_160, Blake3_192, Blake3_256}, rpo::{Rpo256, RpoDigest}, rpx::{Rpx256, RpxDigest}, - ElementHasher, Hasher, + Digest, ElementHasher, Hasher, }; } @@ -124,8 +124,3 @@ pub mod stack; pub use stack::{StackInputs, StackOutputs}; pub mod utils; - -// TYPE ALIASES -// ================================================================================================ - -pub type StackTopState = [Felt; stack::STACK_TOP_SIZE]; diff --git a/core/src/mast/merger/mod.rs b/core/src/mast/merger/mod.rs new file mode 100644 index 0000000000..ae0273371a --- /dev/null +++ b/core/src/mast/merger/mod.rs @@ -0,0 +1,418 @@ +use alloc::{collections::BTreeMap, vec::Vec}; + +use miden_crypto::hash::blake::Blake3Digest; + +use crate::mast::{ + DecoratorId, MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId, + MultiMastForestIteratorItem, MultiMastForestNodeIter, +}; + +#[cfg(test)] +mod tests; + +/// A type that allows merging [`MastForest`]s. +/// +/// This functionality is exposed via [`MastForest::merge`]. See its documentation for more details. +pub(crate) struct MastForestMerger { + mast_forest: MastForest, + // Internal indices needed for efficient duplicate checking and MastNodeFingerprint + // computation. + // + // These are always in-sync with the nodes in `mast_forest`, i.e. all nodes added to the + // `mast_forest` are also added to the indices. + node_id_by_hash: BTreeMap, + hash_by_node_id: BTreeMap, + decorators_by_hash: BTreeMap, DecoratorId>, + /// Mappings from old decorator and node ids to their new ids. + /// + /// Any decorator in `mast_forest` is present as the target of some mapping in this map. + decorator_id_mappings: Vec, + /// Mappings from previous `MastNodeId`s to their new ids. + /// + /// Any `MastNodeId` in `mast_forest` is present as the target of some mapping in this map. + node_id_mappings: Vec, +} + +impl MastForestMerger { + /// Creates a new merger with an initially empty forest and merges all provided [`MastForest`]s + /// into it. + pub(crate) fn merge<'forest>( + forests: impl IntoIterator, + ) -> Result<(MastForest, MastForestRootMap), MastForestError> { + let forests = forests.into_iter().collect::>(); + let decorator_id_mappings = Vec::with_capacity(forests.len()); + let node_id_mappings = vec![MastForestNodeIdMap::new(); forests.len()]; + + let mut merger = Self { + node_id_by_hash: BTreeMap::new(), + hash_by_node_id: BTreeMap::new(), + decorators_by_hash: BTreeMap::new(), + mast_forest: MastForest::new(), + decorator_id_mappings, + node_id_mappings, + }; + + merger.merge_inner(forests.clone())?; + + let Self { mast_forest, node_id_mappings, .. } = merger; + + let root_maps = MastForestRootMap::from_node_id_map(node_id_mappings, forests); + + Ok((mast_forest, root_maps)) + } + + /// Merges all `forests` into self. + /// + /// It does this in three steps: + /// + /// 1. Merge all decorators, which is a case of deduplication and creating a decorator id + /// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the + /// merged forest. + /// 2. Merge all nodes of forests. + /// - Similar to decorators, node indices might move during merging, so the merger keeps a + /// node id mapping as it merges nodes. + /// - This is a depth-first traversal over all forests to ensure all children are processed + /// before their parents. See the documentation of [`MultiMastForestNodeIter`] for details + /// on this traversal. + /// - Because all parents are processed after their children, we can use the node id mapping + /// to remap all [`MastNodeId`]s of the children to their potentially new id in the merged + /// forest. + /// - If any external node is encountered during this traversal with a digest `foo` for which + /// a `replacement` node exists in another forest with digest `foo`, then the external node + /// will be replaced by that node. In particular, it means we do not want to add the + /// external node to the merged forest, so it is never yielded from the iterator. + /// - Assuming the simple case, where the `replacement` was not visited yet and is just a + /// single node (not a tree), the iterator would first yield the `replacement` node which + /// means it is going to be merged into the forest. + /// - Next the iterator yields [`MultiMastForestIteratorItem::ExternalNodeReplacement`] + /// which signals that an external node was replaced by another node. In this example, + /// the `replacement_*` indices contained in that variant would point to the + /// `replacement` node. Now we can simply add a mapping from the external node to the + /// `replacement` node in our node id mapping which means all nodes that referenced the + /// external node will point to the `replacement` instead. + /// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to + /// their potentially new indices in the merged forest and add them to the forest, + /// deduplicating in the process, too. + fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> { + for other_forest in forests.iter() { + self.merge_decorators(other_forest)?; + } + + let iterator = MultiMastForestNodeIter::new(forests.clone()); + for item in iterator { + match item { + MultiMastForestIteratorItem::Node { forest_idx, node_id } => { + let node = &forests[forest_idx][node_id]; + self.merge_node(forest_idx, node_id, node)?; + }, + MultiMastForestIteratorItem::ExternalNodeReplacement { + // forest index of the node which replaces the external node + replacement_forest_idx, + // ID of the node that replaces the external node + replacement_mast_node_id, + // forest index of the external node + replaced_forest_idx, + // ID of the external node + replaced_mast_node_id, + } => { + // The iterator is not aware of the merged forest, so the node indices it yields + // are for the existing forests. That means we have to map the ID of the + // replacement to its new location, since it was previously merged and its IDs + // have very likely changed. + let mapped_replacement = self.node_id_mappings[replacement_forest_idx] + .get(&replacement_mast_node_id) + .copied() + .expect("every merged node id should be mapped"); + + // SAFETY: The iterator only yields valid forest indices, so it is safe to index + // directly. + self.node_id_mappings[replaced_forest_idx] + .insert(replaced_mast_node_id, mapped_replacement); + }, + } + } + + for (forest_idx, forest) in forests.iter().enumerate() { + self.merge_roots(forest_idx, forest)?; + } + + Ok(()) + } + + fn merge_decorators(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> { + let mut decorator_id_remapping = DecoratorIdMap::new(other_forest.decorators.len()); + + for (merging_id, merging_decorator) in other_forest.decorators.iter().enumerate() { + let merging_decorator_hash = merging_decorator.fingerprint(); + let new_decorator_id = if let Some(existing_decorator) = + self.decorators_by_hash.get(&merging_decorator_hash) + { + *existing_decorator + } else { + let new_decorator_id = self.mast_forest.add_decorator(merging_decorator.clone())?; + self.decorators_by_hash.insert(merging_decorator_hash, new_decorator_id); + new_decorator_id + }; + + decorator_id_remapping + .insert(DecoratorId::new_unchecked(merging_id as u32), new_decorator_id); + } + + self.decorator_id_mappings.push(decorator_id_remapping); + + Ok(()) + } + + fn merge_node( + &mut self, + forest_idx: usize, + merging_id: MastNodeId, + node: &MastNode, + ) -> Result<(), MastForestError> { + // We need to remap the node prior to computing the MastNodeFingerprint. + // + // This is because the MastNodeFingerprint computation looks up its descendants and + // decorators in the internal index, and if we were to pass the original node to + // that computation, it would look up the incorrect descendants and decorators + // (since the descendant's indices may have changed). + // + // Remapping at this point is guaranteed to be "complete", meaning all ids of children + // will be present in the node id mapping since the DFS iteration guarantees + // that all children of this `node` have been processed before this node and + // their indices have been added to the mappings. + let remapped_node = self.remap_node(forest_idx, node)?; + + let node_fingerprint = MastNodeFingerprint::from_mast_node( + &self.mast_forest, + &self.hash_by_node_id, + &remapped_node, + ) + .expect( + "hash_by_node_id should contain the fingerprints of all children of `remapped_node`", + ); + + match self.lookup_node_by_fingerprint(&node_fingerprint) { + Some(matching_node_id) => { + // If a node with a matching fingerprint exists, then the merging node is a + // duplicate and we remap it to the existing node. + self.node_id_mappings[forest_idx].insert(merging_id, matching_node_id); + }, + None => { + // If no node with a matching fingerprint exists, then the merging node is + // unique and we can add it to the merged forest. + let new_node_id = self.mast_forest.add_node(remapped_node)?; + self.node_id_mappings[forest_idx].insert(merging_id, new_node_id); + + // We need to update the indices with the newly inserted nodes + // since the MastNodeFingerprint computation requires all descendants of a node + // to be in this index. Hence when we encounter a node in the merging forest + // which has descendants (Call, Loop, Split, ...), then their descendants need to be + // in the indices. + self.node_id_by_hash.insert(node_fingerprint, new_node_id); + self.hash_by_node_id.insert(new_node_id, node_fingerprint); + }, + } + + Ok(()) + } + + fn merge_roots( + &mut self, + forest_idx: usize, + other_forest: &MastForest, + ) -> Result<(), MastForestError> { + for root_id in other_forest.roots.iter() { + // Map the previous root to its possibly new id. + let new_root = self.node_id_mappings[forest_idx] + .get(root_id) + .expect("all node ids should have an entry"); + // This takes O(n) where n is the number of roots in the merged forest every time to + // check if the root already exists. As the number of roots is relatively low generally, + // this should be okay. + self.mast_forest.make_root(*new_root); + } + + Ok(()) + } + + /// Remaps a nodes' potentially contained children and decorators to their new IDs according to + /// the given maps. + fn remap_node(&self, forest_idx: usize, node: &MastNode) -> Result { + let map_decorator_id = |decorator_id: &DecoratorId| { + self.decorator_id_mappings[forest_idx].get(decorator_id).ok_or_else(|| { + MastForestError::DecoratorIdOverflow( + *decorator_id, + self.decorator_id_mappings[forest_idx].len(), + ) + }) + }; + let map_decorators = |decorators: &[DecoratorId]| -> Result, MastForestError> { + decorators.iter().map(map_decorator_id).collect() + }; + + let map_node_id = |node_id: MastNodeId| { + self.node_id_mappings[forest_idx] + .get(&node_id) + .copied() + .expect("every node id should have an entry") + }; + + // Due to DFS postorder iteration all children of node's should have been inserted before + // their parents which is why we can `expect` the constructor calls here. + let mut mapped_node = match node { + MastNode::Join(join_node) => { + let first = map_node_id(join_node.first()); + let second = map_node_id(join_node.second()); + + MastNode::new_join(first, second, &self.mast_forest) + .expect("JoinNode children should have been mapped to a lower index") + }, + MastNode::Split(split_node) => { + let if_branch = map_node_id(split_node.on_true()); + let else_branch = map_node_id(split_node.on_false()); + + MastNode::new_split(if_branch, else_branch, &self.mast_forest) + .expect("SplitNode children should have been mapped to a lower index") + }, + MastNode::Loop(loop_node) => { + let body = map_node_id(loop_node.body()); + MastNode::new_loop(body, &self.mast_forest) + .expect("LoopNode children should have been mapped to a lower index") + }, + MastNode::Call(call_node) => { + let callee = map_node_id(call_node.callee()); + MastNode::new_call(callee, &self.mast_forest) + .expect("CallNode children should have been mapped to a lower index") + }, + // Other nodes are simply copied. + MastNode::Block(basic_block_node) => { + MastNode::new_basic_block( + basic_block_node.operations().copied().collect(), + // Operation Indices of decorators stay the same while decorator IDs need to be + // mapped. + Some( + basic_block_node + .decorators() + .iter() + .map(|(idx, decorator_id)| match map_decorator_id(decorator_id) { + Ok(mapped_decorator) => Ok((*idx, mapped_decorator)), + Err(err) => Err(err), + }) + .collect::, _>>()?, + ), + ) + .expect("previously valid BasicBlockNode should still be valid") + }, + MastNode::Dyn(_) => MastNode::new_dyn(), + MastNode::External(external_node) => MastNode::new_external(external_node.digest()), + }; + + // Decorators must be handled specially for basic block nodes. + // For other node types we can handle it centrally. + if !mapped_node.is_basic_block() { + mapped_node.set_before_enter(map_decorators(node.before_enter())?); + mapped_node.set_after_exit(map_decorators(node.after_exit())?); + } + + Ok(mapped_node) + } + + // HELPERS + // ================================================================================================ + + /// Returns a slice of nodes in the merged forest which have the given `mast_root`. + fn lookup_node_by_fingerprint(&self, fingerprint: &MastNodeFingerprint) -> Option { + self.node_id_by_hash.get(fingerprint).copied() + } +} + +// MAST FOREST ROOT MAP +// ================================================================================================ + +/// A mapping for the new location of the roots of a [`MastForest`] after a merge. +/// +/// It maps the roots ([`MastNodeId`]s) of a forest to their new [`MastNodeId`] in the merged +/// forest. See [`MastForest::merge`] for more details. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MastForestRootMap { + root_maps: Vec>, +} + +impl MastForestRootMap { + fn from_node_id_map(id_map: Vec, forests: Vec<&MastForest>) -> Self { + let mut root_maps = vec![BTreeMap::new(); forests.len()]; + + for (forest_idx, forest) in forests.into_iter().enumerate() { + for root in forest.procedure_roots() { + let new_id = id_map[forest_idx] + .get(root) + .copied() + .expect("every node id should be mapped to its new id"); + root_maps[forest_idx].insert(*root, new_id); + } + } + + Self { root_maps } + } + + /// Maps the given root to its new location in the merged forest, if such a mapping exists. + /// + /// It is guaranteed that every root of the map's corresponding forest is contained in the map. + pub fn map_root(&self, forest_index: usize, root: &MastNodeId) -> Option { + self.root_maps.get(forest_index).and_then(|map| map.get(root)).copied() + } +} + +// DECORATOR ID MAP +// ================================================================================================ + +/// A specialized map from [`DecoratorId`] -> [`DecoratorId`]. +/// +/// When mapping Decorator IDs during merging, we always map all IDs of the merging +/// forest to new ids. Hence it is more efficient to use a `Vec` instead of, say, a `BTreeMap`. +/// +/// In other words, this type is similar to `BTreeMap` but takes advantage of the fact that +/// the keys are contiguous. +/// +/// This type is meant to encapsulates some guarantees: +/// +/// - Indexing into the vector for any ID is safe if that ID is valid for the corresponding forest. +/// Despite that, we still cannot index unconditionally in case a node with invalid +/// [`DecoratorId`]s is passed to `merge`. +/// - The entry itself can be either None or Some. However: +/// - For `DecoratorId`s we iterate and insert all decorators into this map before retrieving any +/// entry, so all entries contain `Some`. Because of this, we can use `expect` in `get` for the +/// `Option` value. +/// - Similarly, inserting any ID from the corresponding forest is safe as the map contains a +/// pre-allocated `Vec` of the appropriate size. +struct DecoratorIdMap { + inner: Vec>, +} + +impl DecoratorIdMap { + fn new(num_ids: usize) -> Self { + Self { inner: vec![None; num_ids] } + } + + /// Maps the given key to the given value. + /// + /// It is the caller's responsibility to only pass keys that belong to the forest for which this + /// map was originally created. + fn insert(&mut self, key: DecoratorId, value: DecoratorId) { + self.inner[key.as_usize()] = Some(value); + } + + /// Retrieves the value for the given key. + fn get(&self, key: &DecoratorId) -> Option { + self.inner + .get(key.as_usize()) + .map(|id| id.expect("every id should have a Some entry in the map when calling get")) + } + + fn len(&self) -> usize { + self.inner.len() + } +} + +/// A type definition for increased readability in function signatures. +type MastForestNodeIdMap = BTreeMap; diff --git a/core/src/mast/merger/tests.rs b/core/src/mast/merger/tests.rs new file mode 100644 index 0000000000..b33ae97296 --- /dev/null +++ b/core/src/mast/merger/tests.rs @@ -0,0 +1,796 @@ +use miden_crypto::{hash::rpo::RpoDigest, ONE}; + +use super::*; +use crate::{Decorator, Operation}; + +fn block_foo() -> MastNode { + MastNode::new_basic_block(vec![Operation::Mul, Operation::Add], None).unwrap() +} + +fn block_bar() -> MastNode { + MastNode::new_basic_block(vec![Operation::And, Operation::Eq], None).unwrap() +} + +fn block_qux() -> MastNode { + MastNode::new_basic_block(vec![Operation::Swap, Operation::Push(ONE), Operation::Eq], None) + .unwrap() +} + +/// Asserts that the given forest contains exactly one node with the given digest. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_contains_node_once(forest: &MastForest, digest: RpoDigest) -> Result<(), &str> { + if forest.nodes.iter().filter(|node| node.digest() == digest).count() != 1 { + return Err("node digest contained more than once in the forest"); + } + + Ok(()) +} + +/// Asserts that every root of an original forest has an id to which it is mapped and that this +/// mapped root is in the set of roots in the merged forest. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_root_mapping( + root_map: &MastForestRootMap, + original_roots: Vec<&[MastNodeId]>, + merged_roots: &[MastNodeId], +) -> Result<(), &'static str> { + for (forest_idx, original_root) in original_roots.into_iter().enumerate() { + for root in original_root { + let mapped_root = root_map.map_root(forest_idx, root).unwrap(); + if !merged_roots.contains(&mapped_root) { + return Err("merged root does not contain mapped root"); + } + } + } + + Ok(()) +} + +/// Asserts that all children of nodes in the given forest have an id that is less than the parent's +/// ID. +/// +/// Returns a Result which can be unwrapped in the calling test function to assert. This way, if +/// this assertion fails it'll be clear which exact call failed. +fn assert_child_id_lt_parent_id(forest: &MastForest) -> Result<(), &str> { + for (mast_node_id, node) in forest.nodes().iter().enumerate() { + match node { + MastNode::Join(join_node) => { + if !join_node.first().as_usize() < mast_node_id { + return Err("join node first child id is not < parent id"); + }; + if !join_node.second().as_usize() < mast_node_id { + return Err("join node second child id is not < parent id"); + } + }, + MastNode::Split(split_node) => { + if !split_node.on_true().as_usize() < mast_node_id { + return Err("split node on true id is not < parent id"); + } + if !split_node.on_false().as_usize() < mast_node_id { + return Err("split node on false id is not < parent id"); + } + }, + MastNode::Loop(loop_node) => { + if !loop_node.body().as_usize() < mast_node_id { + return Err("loop node body id is not < parent id"); + } + }, + MastNode::Call(call_node) => { + if !call_node.callee().as_usize() < mast_node_id { + return Err("call node callee id is not < parent id"); + } + }, + MastNode::Block(_) => (), + MastNode::Dyn(_) => (), + MastNode::External(_) => (), + } + } + + Ok(()) +} + +/// Tests that Call(bar) still correctly calls the remapped bar block. +/// +/// [Block(foo), Call(foo)] +/// + +/// [Block(bar), Call(bar)] +/// = +/// [Block(foo), Call(foo), Block(bar), Call(bar)] +#[test] +fn mast_forest_merge_remap() { + let mut forest_a = MastForest::new(); + let id_foo = forest_a.add_node(block_foo()).unwrap(); + let id_call_a = forest_a.add_call(id_foo).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_bar = forest_b.add_node(block_bar()).unwrap(); + let id_call_b = forest_b.add_call(id_bar).unwrap(); + forest_b.make_root(id_call_b); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + assert_eq!(merged.nodes().len(), 4); + assert_eq!(merged.nodes()[0], block_foo()); + assert_matches!(&merged.nodes()[1], MastNode::Call(call_node) if call_node.callee().as_u32() == 0); + assert_eq!(merged.nodes()[2], block_bar()); + assert_matches!(&merged.nodes()[3], MastNode::Call(call_node) if call_node.callee().as_u32() == 2); + + assert_eq!(root_maps.map_root(0, &id_call_a).unwrap().as_u32(), 1); + assert_eq!(root_maps.map_root(1, &id_call_b).unwrap().as_u32(), 3); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that Forest_A + Forest_A = Forest_A (i.e. duplicates are removed). +#[test] +fn mast_forest_merge_duplicate() { + let mut forest_a = MastForest::new(); + forest_a.add_decorator(Decorator::Debug(crate::DebugOptions::MemAll)).unwrap(); + forest_a.add_decorator(Decorator::Trace(25)).unwrap(); + + let id_external = forest_a.add_external(block_bar().digest()).unwrap(); + let id_foo = forest_a.add_node(block_foo()).unwrap(); + let id_call = forest_a.add_call(id_foo).unwrap(); + let id_loop = forest_a.add_loop(id_external).unwrap(); + forest_a.make_root(id_call); + forest_a.make_root(id_loop); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_a]).unwrap(); + + for merged_root in merged.procedure_digests() { + forest_a.procedure_digests().find(|root| root == &merged_root).unwrap(); + } + + // Both maps should map the roots to the same target id. + for original_root in forest_a.procedure_roots() { + assert_eq!(&root_maps.map_root(0, original_root), &root_maps.map_root(1, original_root)); + } + + for merged_node in merged.nodes().iter().map(MastNode::digest) { + forest_a.nodes.iter().find(|node| node.digest() == merged_node).unwrap(); + } + + for merged_decorator in merged.decorators.iter() { + assert!(forest_a.decorators.contains(merged_decorator)); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that External(foo) is replaced by Block(foo) whether it is in forest A or B, and the +/// duplicate Call is removed. +/// +/// [External(foo), Call(foo)] +/// + +/// [Block(foo), Call(foo)] +/// = +/// [Block(foo), Call(foo)] +/// + +/// [External(foo), Call(foo)] +/// = +/// [Block(foo), Call(foo)] +#[test] +fn mast_forest_merge_replace_external() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_foo().digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + let id_call_b = forest_b.add_call(id_foo_b).unwrap(); + forest_b.make_root(id_call_b); + + let (merged_ab, root_maps_ab) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + let (merged_ba, root_maps_ba) = MastForest::merge([&forest_b, &forest_a]).unwrap(); + + for (merged, root_map) in [(merged_ab, root_maps_ab), (merged_ba, root_maps_ba)] { + assert_eq!(merged.nodes().len(), 2); + assert_eq!(merged.nodes()[0], block_foo()); + assert_matches!(&merged.nodes()[1], MastNode::Call(call_node) if call_node.callee().as_u32() == 0); + // The only root node should be the call node. + assert_eq!(merged.roots.len(), 1); + assert_eq!(root_map.map_root(0, &id_call_a).unwrap().as_usize(), 1); + assert_eq!(root_map.map_root(1, &id_call_b).unwrap().as_usize(), 1); + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Test that roots are preserved and deduplicated if appropriate. +/// +/// Nodes: [Block(foo), Call(foo)] +/// Roots: [Call(foo)] +/// + +/// Nodes: [Block(foo), Block(bar), Call(foo)] +/// Roots: [Block(bar), Call(foo)] +/// = +/// Nodes: [Block(foo), Block(bar), Call(foo)] +/// Roots: [Block(bar), Call(foo)] +#[test] +fn mast_forest_merge_roots() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_node(block_foo()).unwrap(); + let call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(call_a); + + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + let id_bar_b = forest_b.add_node(block_bar()).unwrap(); + let call_b = forest_b.add_call(id_foo_b).unwrap(); + forest_b.make_root(id_bar_b); + forest_b.make_root(call_b); + + let root_digest_call_a = forest_a.get_node_by_id(call_a).unwrap().digest(); + let root_digest_bar_b = forest_b.get_node_by_id(id_bar_b).unwrap().digest(); + let root_digest_call_b = forest_b.get_node_by_id(call_b).unwrap().digest(); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + // Asserts (together with the other assertions) that the duplicate Call(foo) roots have been + // deduplicated. + assert_eq!(merged.procedure_roots().len(), 2); + + // Assert that all root digests from A an B are still roots in the merged forest. + let root_digests = merged.procedure_digests().collect::>(); + assert!(root_digests.contains(&root_digest_call_a)); + assert!(root_digests.contains(&root_digest_bar_b)); + assert!(root_digests.contains(&root_digest_call_b)); + + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap(); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Test that multiple trees can be merged when the same merger is reused. +/// +/// Nodes: [Block(foo), Call(foo)] +/// Roots: [Call(foo)] +/// + +/// Nodes: [Block(foo), Block(bar), Call(foo)] +/// Roots: [Block(bar), Call(foo)] +/// + +/// Nodes: [Block(foo), Block(qux), Call(foo)] +/// Roots: [Block(qux), Call(foo)] +/// = +/// Nodes: [Block(foo), Block(bar), Block(qux), Call(foo)] +/// Roots: [Block(bar), Block(qux), Call(foo)] +#[test] +fn mast_forest_merge_multiple() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_node(block_foo()).unwrap(); + let call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(call_a); + + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + let id_bar_b = forest_b.add_node(block_bar()).unwrap(); + let call_b = forest_b.add_call(id_foo_b).unwrap(); + forest_b.make_root(id_bar_b); + forest_b.make_root(call_b); + + let mut forest_c = MastForest::new(); + let id_foo_c = forest_c.add_node(block_foo()).unwrap(); + let id_qux_c = forest_c.add_node(block_qux()).unwrap(); + let call_c = forest_c.add_call(id_foo_c).unwrap(); + forest_c.make_root(id_qux_c); + forest_c.make_root(call_c); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b, &forest_c]).unwrap(); + + let block_foo_digest = forest_b.get_node_by_id(id_foo_b).unwrap().digest(); + let block_bar_digest = forest_b.get_node_by_id(id_bar_b).unwrap().digest(); + let call_foo_digest = forest_b.get_node_by_id(call_b).unwrap().digest(); + let block_qux_digest = forest_c.get_node_by_id(id_qux_c).unwrap().digest(); + + assert_eq!(merged.procedure_roots().len(), 3); + + let root_digests = merged.procedure_digests().collect::>(); + assert!(root_digests.contains(&call_foo_digest)); + assert!(root_digests.contains(&block_bar_digest)); + assert!(root_digests.contains(&block_qux_digest)); + + assert_contains_node_once(&merged, block_foo_digest).unwrap(); + assert_contains_node_once(&merged, block_bar_digest).unwrap(); + assert_contains_node_once(&merged, block_qux_digest).unwrap(); + assert_contains_node_once(&merged, call_foo_digest).unwrap(); + + assert_root_mapping( + &root_maps, + vec![&forest_a.roots, &forest_b.roots, &forest_c.roots], + &merged.roots, + ) + .unwrap(); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that decorators are merged and that nodes who are identical except for their +/// decorators are not deduplicated. +/// +/// Note in particular that the `Loop` nodes only differ in their decorator which ensures that +/// the merging takes decorators into account. +/// +/// Nodes: [Block(foo, [Trace(1), Trace(2)]), Loop(foo, [Trace(0), Trace(2)])] +/// Decorators: [Trace(0), Trace(1), Trace(2)] +/// + +/// Nodes: [Block(foo, [Trace(1), Trace(2)]), Loop(foo, [Trace(1), Trace(3)])] +/// Decorators: [Trace(1), Trace(2), Trace(3)] +/// = +/// Nodes: [ +/// Block(foo, [Trace(1), Trace(2)]), +/// Loop(foo, [Trace(0), Trace(2)]), +/// Loop(foo, [Trace(1), Trace(3)]), +/// ] +/// Decorators: [Trace(0), Trace(1), Trace(2), Trace(3)] +#[test] +fn mast_forest_merge_decorators() { + let mut forest_a = MastForest::new(); + let trace0 = Decorator::Trace(0); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + let trace3 = Decorator::Trace(3); + + // Build Forest A + let deco0_a = forest_a.add_decorator(trace0.clone()).unwrap(); + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2_a = forest_a.add_decorator(trace2.clone()).unwrap(); + + let mut foo_node_a = block_foo(); + foo_node_a.set_before_enter(vec![deco1_a, deco2_a]); + let id_foo_a = forest_a.add_node(foo_node_a).unwrap(); + + let mut loop_node_a = MastNode::new_loop(id_foo_a, &forest_a).unwrap(); + loop_node_a.set_after_exit(vec![deco0_a, deco2_a]); + let id_loop_a = forest_a.add_node(loop_node_a).unwrap(); + + forest_a.make_root(id_loop_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let deco1_b = forest_b.add_decorator(trace1.clone()).unwrap(); + let deco2_b = forest_b.add_decorator(trace2.clone()).unwrap(); + let deco3_b = forest_b.add_decorator(trace3.clone()).unwrap(); + + // This foo node is identical to the one in A, including its decorators. + let mut foo_node_b = block_foo(); + foo_node_b.set_before_enter(vec![deco1_b, deco2_b]); + let id_foo_b = forest_b.add_node(foo_node_b).unwrap(); + + // This loop node's decorators are different from the loop node in a. + let mut loop_node_b = MastNode::new_loop(id_foo_b, &forest_b).unwrap(); + loop_node_b.set_after_exit(vec![deco1_b, deco3_b]); + let id_loop_b = forest_b.add_node(loop_node_b).unwrap(); + + forest_b.make_root(id_loop_b); + + let (merged, root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap(); + + // There are 4 unique decorators across both forests. + assert_eq!(merged.decorators.len(), 4); + assert!(merged.decorators.contains(&trace0)); + assert!(merged.decorators.contains(&trace1)); + assert!(merged.decorators.contains(&trace2)); + assert!(merged.decorators.contains(&trace3)); + + let find_decorator_id = |deco: &Decorator| { + let idx = merged + .decorators + .iter() + .enumerate() + .find_map( + |(deco_id, forest_deco)| if forest_deco == deco { Some(deco_id) } else { None }, + ) + .unwrap(); + DecoratorId::from_u32_safe(idx as u32, &merged).unwrap() + }; + + let merged_deco0 = find_decorator_id(&trace0); + let merged_deco1 = find_decorator_id(&trace1); + let merged_deco2 = find_decorator_id(&trace2); + let merged_deco3 = find_decorator_id(&trace3); + + assert_eq!(merged.nodes.len(), 3); + + let merged_foo_block = merged.nodes.iter().find(|node| node.is_basic_block()).unwrap(); + let MastNode::Block(merged_foo_block) = merged_foo_block else { + panic!("expected basic block node"); + }; + + assert_eq!( + merged_foo_block.decorators().as_slice(), + &[(0, merged_deco1), (0, merged_deco2)] + ); + + // Asserts that there exists exactly one Loop Node with the given decorators. + assert_eq!( + merged + .nodes + .iter() + .filter(|node| { + if let MastNode::Loop(loop_node) = node { + loop_node.after_exit() == [merged_deco0, merged_deco2] + } else { + false + } + }) + .count(), + 1 + ); + + // Asserts that there exists exactly one Loop Node with the given decorators. + assert_eq!( + merged + .nodes + .iter() + .filter(|node| { + if let MastNode::Loop(loop_node) = node { + loop_node.after_exit() == [merged_deco1, merged_deco3] + } else { + false + } + }) + .count(), + 1 + ); + + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots).unwrap(); + + assert_child_id_lt_parent_id(&merged).unwrap(); +} + +/// Tests that an external node without decorators is replaced by its referenced node which has +/// decorators. +/// +/// [External(foo)] +/// + +/// [Block(foo, Trace(1))] +/// = +/// [Block(foo, Trace(1))] +/// + +/// [External(foo)] +/// = +/// [Block(foo, Trace(1))] +#[test] +fn mast_forest_merge_external_node_reference_with_decorator() { + let mut forest_a = MastForest::new(); + let trace = Decorator::Trace(1); + + // Build Forest A + let deco = forest_a.add_decorator(trace.clone()).unwrap(); + + let mut foo_node_a = block_foo(); + foo_node_a.set_before_enter(vec![deco]); + let foo_node_digest = foo_node_a.digest(); + let id_foo_a = forest_a.add_node(foo_node_a).unwrap(); + + forest_a.make_root(id_foo_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let id_external_b = forest_b.add_external(foo_node_digest).unwrap(); + + forest_b.make_root(id_external_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + let id_foo_a_fingerprint = + MastNodeFingerprint::from_mast_node(&forest_a, &BTreeMap::new(), &forest_a[id_foo_a]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| MastNodeFingerprint::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + assert_eq!(merged.nodes.len(), 1); + assert!(fingerprints.contains(&id_foo_a_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that an external node with decorators is replaced by its referenced node which does not +/// have decorators. +/// +/// [External(foo, Trace(1), Trace(2))] +/// + +/// [Block(foo)] +/// = +/// [Block(foo)] +/// + +/// [External(foo, Trace(1), Trace(2))] +/// = +/// [Block(foo)] +#[test] +fn mast_forest_merge_external_node_with_decorator() { + let mut forest_a = MastForest::new(); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let deco1 = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2 = forest_a.add_decorator(trace2.clone()).unwrap(); + + let mut external_node_a = MastNode::new_external(block_foo().digest()); + external_node_a.set_before_enter(vec![deco1]); + external_node_a.set_after_exit(vec![deco2]); + let id_external_a = forest_a.add_node(external_node_a).unwrap(); + + forest_a.make_root(id_external_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let id_foo_b = forest_b.add_node(block_foo()).unwrap(); + + forest_b.make_root(id_foo_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + assert_eq!(merged.nodes.len(), 1); + + let id_foo_b_fingerprint = + MastNodeFingerprint::from_mast_node(&forest_a, &BTreeMap::new(), &forest_b[id_foo_b]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| MastNodeFingerprint::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + // Block foo should be unmodified. + assert!(fingerprints.contains(&id_foo_b_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that an external node with decorators is replaced by its referenced node which also has +/// decorators. +/// +/// [External(foo, Trace(1))] +/// + +/// [Block(foo, Trace(2))] +/// = +/// [Block(foo, Trace(2))] +/// + +/// [External(foo, Trace(1))] +/// = +/// [Block(foo, Trace(2))] +#[test] +fn mast_forest_merge_external_node_and_referenced_node_have_decorators() { + let mut forest_a = MastForest::new(); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + + let mut external_node_a = MastNode::new_external(block_foo().digest()); + external_node_a.set_before_enter(vec![deco1_a]); + let id_external_a = forest_a.add_node(external_node_a).unwrap(); + + forest_a.make_root(id_external_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let deco2_b = forest_b.add_decorator(trace2.clone()).unwrap(); + + let mut foo_node_b = block_foo(); + foo_node_b.set_before_enter(vec![deco2_b]); + let id_foo_b = forest_b.add_node(foo_node_b).unwrap(); + + forest_b.make_root(id_foo_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + assert_eq!(merged.nodes.len(), 1); + + let id_foo_b_fingerprint = + MastNodeFingerprint::from_mast_node(&forest_b, &BTreeMap::new(), &forest_b[id_foo_b]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| MastNodeFingerprint::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + // Block foo should be unmodified. + assert!(fingerprints.contains(&id_foo_b_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that two external nodes with the same MAST root are deduplicated during merging and then +/// replaced by a block with the matching digest. +/// +/// [External(foo, Trace(1), Trace(2)), +/// External(foo, Trace(1))] +/// + +/// [Block(foo, Trace(1))] +/// = +/// [Block(foo, Trace(1))] +/// + +/// [External(foo, Trace(1), Trace(2)), +/// External(foo, Trace(1))] +/// = +/// [Block(foo, Trace(1))] +#[test] +fn mast_forest_merge_multiple_external_nodes_with_decorator() { + let mut forest_a = MastForest::new(); + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2_a = forest_a.add_decorator(trace2.clone()).unwrap(); + + let mut external_node_a = MastNode::new_external(block_foo().digest()); + external_node_a.set_before_enter(vec![deco1_a]); + external_node_a.set_after_exit(vec![deco2_a]); + let id_external_a = forest_a.add_node(external_node_a).unwrap(); + + let mut external_node_b = MastNode::new_external(block_foo().digest()); + external_node_b.set_before_enter(vec![deco1_a]); + let id_external_b = forest_a.add_node(external_node_b).unwrap(); + + forest_a.make_root(id_external_a); + forest_a.make_root(id_external_b); + + // Build Forest B + let mut forest_b = MastForest::new(); + let deco1_b = forest_b.add_decorator(trace1).unwrap(); + let mut block_foo_b = block_foo(); + block_foo_b.set_before_enter(vec![deco1_b]); + let id_foo_b = forest_b.add_node(block_foo_b).unwrap(); + + forest_b.make_root(id_foo_b); + + for (idx, (merged, root_maps)) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + .enumerate() + { + assert_eq!(merged.nodes.len(), 1); + + let id_foo_b_fingerprint = + MastNodeFingerprint::from_mast_node(&forest_a, &BTreeMap::new(), &forest_b[id_foo_b]); + + let fingerprints: Vec<_> = merged + .nodes() + .iter() + .map(|node| MastNodeFingerprint::from_mast_node(&merged, &BTreeMap::new(), node)) + .collect(); + + // Block foo should be unmodified. + assert!(fingerprints.contains(&id_foo_b_fingerprint)); + + if idx == 0 { + assert_root_mapping(&root_maps, vec![&forest_a.roots, &forest_b.roots], &merged.roots) + .unwrap(); + } else { + assert_root_mapping(&root_maps, vec![&forest_b.roots, &forest_a.roots], &merged.roots) + .unwrap(); + } + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that dependencies between External nodes are correctly resolved. +/// +/// [External(foo), Call(0) = qux] +/// + +/// [External(qux), Call(0), Block(foo)] +/// = +/// [External(qux), Call(0), Block(foo)] +/// + +/// [External(foo), Call(0) = qux] +/// = +/// [Block(foo), Call(0), Call(1)] +#[test] +fn mast_forest_merge_external_dependencies() { + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_qux().digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(forest_a[id_call_a].digest()).unwrap(); + let id_call_b = forest_b.add_call(id_ext_b).unwrap(); + let id_qux_b = forest_b.add_node(block_qux()).unwrap(); + forest_b.make_root(id_call_b); + forest_b.make_root(id_qux_b); + + for (merged, _) in [ + MastForest::merge([&forest_a, &forest_b]).unwrap(), + MastForest::merge([&forest_b, &forest_a]).unwrap(), + ] + .into_iter() + { + let digests = merged.nodes().iter().map(|node| node.digest()).collect::>(); + assert_eq!(merged.nodes().len(), 3); + assert!(digests.contains(&forest_b[id_ext_b].digest())); + assert!(digests.contains(&forest_b[id_call_b].digest())); + assert!(digests.contains(&forest_a[id_foo_a].digest())); + assert!(digests.contains(&forest_a[id_call_a].digest())); + assert!(digests.contains(&forest_b[id_qux_b].digest())); + assert_eq!(merged.nodes().iter().filter(|node| node.is_external()).count(), 0); + + assert_child_id_lt_parent_id(&merged).unwrap(); + } +} + +/// Tests that a forest with nodes who reference non-existent decorators return an error during +/// merging and does not panic. +#[test] +fn mast_forest_merge_invalid_decorator_index() { + let trace1 = Decorator::Trace(1); + let trace2 = Decorator::Trace(2); + + // Build Forest A + let mut forest_a = MastForest::new(); + let deco1_a = forest_a.add_decorator(trace1.clone()).unwrap(); + let deco2_a = forest_a.add_decorator(trace2.clone()).unwrap(); + let id_bar_a = forest_a.add_node(block_bar()).unwrap(); + + forest_a.make_root(id_bar_a); + + // Build Forest B + let mut forest_b = MastForest::new(); + let mut block_b = block_foo(); + // We're using a DecoratorId from forest A which is invalid. + block_b.set_before_enter(vec![deco1_a, deco2_a]); + let id_foo_b = forest_b.add_node(block_b).unwrap(); + + forest_b.make_root(id_foo_b); + + let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err(); + assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _)); +} diff --git a/core/src/mast/mod.rs b/core/src/mast/mod.rs index 87a6daa354..444fedc5c6 100644 --- a/core/src/mast/mod.rs +++ b/core/src/mast/mod.rs @@ -2,7 +2,10 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, vec::Vec, }; -use core::{fmt, mem, ops::Index}; +use core::{ + fmt, mem, + ops::{Index, IndexMut}, +}; use miden_crypto::hash::rpo::RpoDigest; @@ -11,12 +14,22 @@ pub use node::{ BasicBlockNode, CallNode, DynNode, ExternalNode, JoinNode, LoopNode, MastNode, OpBatch, OperationOrDecorator, SplitNode, OP_BATCH_SIZE, OP_GROUP_SIZE, }; -use winter_utils::DeserializationError; +use winter_utils::{ByteWriter, DeserializationError, Serializable}; -use crate::{DecoratorList, Operation}; +use crate::{Decorator, DecoratorList, Operation}; mod serialization; +mod merger; +pub(crate) use merger::MastForestMerger; +pub use merger::MastForestRootMap; + +mod multi_forest_node_iterator; +pub(crate) use multi_forest_node_iterator::*; + +mod node_fingerprint; +pub use node_fingerprint::{DecoratorFingerprint, MastNodeFingerprint}; + #[cfg(test)] mod tests; @@ -34,6 +47,9 @@ pub struct MastForest { /// Roots of procedures defined within this MAST forest. roots: Vec, + + /// All the decorators included in the MAST forest. + decorators: Vec, } // ------------------------------------------------------------------------------------------------ @@ -50,6 +66,20 @@ impl MastForest { impl MastForest { /// The maximum number of nodes that can be stored in a single MAST forest. const MAX_NODES: usize = (1 << 30) - 1; + /// The maximum number of decorators that can be stored in a single MAST forest. + const MAX_DECORATORS: usize = Self::MAX_NODES; + + /// Adds a decorator to the forest, and returns the associated [`DecoratorId`]. + pub fn add_decorator(&mut self, decorator: Decorator) -> Result { + if self.decorators.len() >= u32::MAX as usize { + return Err(MastForestError::TooManyDecorators); + } + + let new_decorator_id = DecoratorId(self.decorators.len() as u32); + self.decorators.push(decorator); + + Ok(new_decorator_id) + } /// Adds a node to the forest, and returns the associated [`MastNodeId`]. /// @@ -118,6 +148,11 @@ impl MastForest { self.add_node(MastNode::new_dyn()) } + /// Adds a dyncall node to the forest, and returns the [`MastNodeId`] associated with it. + pub fn add_dyncall(&mut self) -> Result { + self.add_node(MastNode::new_dyncall()) + } + /// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it. pub fn add_external(&mut self, mast_root: RpoDigest) -> Result { self.add_node(MastNode::new_external(mast_root)) @@ -125,6 +160,8 @@ impl MastForest { /// Marks the given [`MastNodeId`] as being the root of a procedure. /// + /// If the specified node is already marked as a root, this will have no effect. + /// /// # Panics /// - if `new_root_id`'s internal index is larger than the number of nodes in this forest (i.e. /// clearly doesn't belong to this MAST forest). @@ -160,6 +197,83 @@ impl MastForest { self.remap_and_add_roots(old_root_ids, &id_remappings); Some(id_remappings) } + + pub fn set_before_enter(&mut self, node_id: MastNodeId, decorator_ids: Vec) { + self[node_id].set_before_enter(decorator_ids) + } + + pub fn set_after_exit(&mut self, node_id: MastNodeId, decorator_ids: Vec) { + self[node_id].set_after_exit(decorator_ids) + } + + /// Merges all `forests` into a new [`MastForest`]. + /// + /// Merging two forests means combining all their constituent parts, i.e. [`MastNode`]s, + /// [`Decorator`]s and roots. During this process, any duplicate or + /// unreachable nodes are removed. Additionally, [`MastNodeId`]s of nodes as well as + /// [`DecoratorId`]s of decorators may change and references to them are remapped to their new + /// location. + /// + /// For example, consider this representation of a forest's nodes with all of these nodes being + /// roots: + /// + /// ```text + /// [Block(foo), Block(bar)] + /// ``` + /// + /// If we merge another forest into it: + /// + /// ```text + /// [Block(bar), Call(0)] + /// ``` + /// + /// then we would expect this forest: + /// + /// ```text + /// [Block(foo), Block(bar), Call(1)] + /// ``` + /// + /// - The `Call` to the `bar` block was remapped to its new index (now 1, previously 0). + /// - The `Block(bar)` was deduplicated any only exists once in the merged forest. + /// + /// The function also returns a vector of [`MastForestRootMap`]s, whose length equals the number + /// of passed `forests`. The indices in the vector correspond to the ones in `forests`. The map + /// of a given forest contains the new locations of its roots in the merged forest. To + /// illustrate, the above example would return a vector of two maps: + /// + /// ```text + /// vec![{0 -> 0, 1 -> 1} + /// {0 -> 1, 1 -> 2}] + /// ``` + /// + /// - The root locations of the original forest are unchanged. + /// - For the second forest, the `bar` block has moved from index 0 to index 1 in the merged + /// forest, and the `Call` has moved from index 1 to 2. + /// + /// If any forest being merged contains an `External(qux)` node and another forest contains a + /// node whose digest is `qux`, then the external node will be replaced with the `qux` node, + /// which is effectively deduplication. Decorators are ignored when it comes to merging + /// External nodes. This means that an External node with decorators may be replaced by a node + /// without decorators or vice versa. + pub fn merge<'forest>( + forests: impl IntoIterator, + ) -> Result<(MastForest, MastForestRootMap), MastForestError> { + MastForestMerger::merge(forests) + } + + /// Adds a basic block node to the forest, and returns the [`MastNodeId`] associated with it. + /// + /// It is assumed that the decorators have not already been added to the MAST forest. If they + /// were, they will be added again (and result in a different set of [`DecoratorId`]s). + #[cfg(test)] + pub fn add_block_with_raw_decorators( + &mut self, + operations: Vec, + decorators: Vec<(usize, Decorator)>, + ) -> Result { + let block = MastNode::new_basic_block_with_raw_decorators(operations, decorators, self)?; + self.add_node(block) + } } /// Helpers @@ -220,7 +334,7 @@ impl MastForest { self.add_call(callee_id).unwrap(); } }, - MastNode::Block(_) | MastNode::Dyn | MastNode::External(_) => { + MastNode::Block(_) | MastNode::Dyn(_) | MastNode::External(_) => { self.add_node(live_node).unwrap(); }, } @@ -275,10 +389,21 @@ fn remove_nodes( /// Public accessors impl MastForest { + /// Returns the [`Decorator`] associated with the provided [`DecoratorId`] if valid, or else + /// `None`. + /// + /// This is the fallible version of indexing (e.g. `mast_forest[decorator_id]`). + #[inline(always)] + pub fn get_decorator_by_id(&self, decorator_id: DecoratorId) -> Option<&Decorator> { + let idx = decorator_id.0 as usize; + + self.decorators.get(idx) + } + /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// - /// This is the failable version of indexing (e.g. `mast_forest[node_id]`). + /// This is the fallible version of indexing (e.g. `mast_forest[node_id]`). #[inline(always)] pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { let idx = node_id.0 as usize; @@ -351,6 +476,34 @@ impl Index for MastForest { } } +impl IndexMut for MastForest { + #[inline(always)] + fn index_mut(&mut self, node_id: MastNodeId) -> &mut Self::Output { + let idx = node_id.0 as usize; + + &mut self.nodes[idx] + } +} + +impl Index for MastForest { + type Output = Decorator; + + #[inline(always)] + fn index(&self, decorator_id: DecoratorId) -> &Self::Output { + let idx = decorator_id.0 as usize; + + &self.decorators[idx] + } +} + +impl IndexMut for MastForest { + #[inline(always)] + fn index_mut(&mut self, decorator_id: DecoratorId) -> &mut Self::Output { + let idx = decorator_id.0 as usize; + &mut self.decorators[idx] + } +} + // MAST NODE ID // ================================================================================================ @@ -372,13 +525,37 @@ impl MastNodeId { value: u32, mast_forest: &MastForest, ) -> Result { - if (value as usize) < mast_forest.nodes.len() { - Ok(Self(value)) + Self::from_u32_with_node_count(value, mast_forest.nodes.len()) + } + + /// Returns a new [`MastNodeId`] from the given `value` without checking its validity. + pub(crate) fn new_unchecked(value: u32) -> Self { + Self(value) + } + + /// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal + /// to `node_count`. The `node_count` is the total number of nodes in the [`MastForest`] for + /// which this ID is being constructed. + /// + /// This function can be used when deserializing an id whose corresponding node is not yet in + /// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids + /// referenced by the Join node in this forest: + /// + /// ```text + /// [Join(1, 2), Block(foo), Block(bar)] + /// ``` + /// + /// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public. + pub(super) fn from_u32_with_node_count( + id: u32, + node_count: usize, + ) -> Result { + if (id as usize) < node_count { + Ok(Self(id)) } else { Err(DeserializationError::InvalidValue(format!( - "Invalid deserialized MAST node ID '{}', but only {} nodes in the forest", - value, - mast_forest.nodes.len(), + "Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest", + id, node_count, ))) } } @@ -416,12 +593,89 @@ impl fmt::Display for MastNodeId { } } +// DECORATOR ID +// ================================================================================================ + +/// An opaque handle to a [`Decorator`] in some [`MastForest`]. It is the responsibility of the user +/// to use a given [`DecoratorId`] with the corresponding [`MastForest`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct DecoratorId(u32); + +impl DecoratorId { + /// Returns a new `DecoratorId` with the provided inner value, or an error if the provided + /// `value` is greater than the number of nodes in the forest. + /// + /// For use in deserialization. + pub fn from_u32_safe( + value: u32, + mast_forest: &MastForest, + ) -> Result { + if (value as usize) < mast_forest.decorators.len() { + Ok(Self(value)) + } else { + Err(DeserializationError::InvalidValue(format!( + "Invalid deserialized MAST decorator id '{}', but only {} decorators in the forest", + value, + mast_forest.nodes.len(), + ))) + } + } + + /// Creates a new [`DecoratorId`] without checking its validity. + pub(crate) fn new_unchecked(value: u32) -> Self { + Self(value) + } + + pub fn as_usize(&self) -> usize { + self.0 as usize + } + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +impl From for usize { + fn from(value: DecoratorId) -> Self { + value.0 as usize + } +} + +impl From for u32 { + fn from(value: DecoratorId) -> Self { + value.0 + } +} + +impl From<&DecoratorId> for u32 { + fn from(value: &DecoratorId) -> Self { + value.0 + } +} + +impl fmt::Display for DecoratorId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DecoratorId({})", self.0) + } +} + +impl Serializable for DecoratorId { + fn write_into(&self, target: &mut W) { + self.0.write_into(target) + } +} + // MAST FOREST ERROR // ================================================================================================ /// Represents the types of errors that can occur when dealing with MAST forest. #[derive(Debug, thiserror::Error, PartialEq)] pub enum MastForestError { + #[error( + "invalid decorator count: MAST forest exceeds the maximum of {} decorators", + u32::MAX + )] + TooManyDecorators, #[error( "invalid node count: MAST forest exceeds the maximum of {} nodes", MastForest::MAX_NODES @@ -429,6 +683,10 @@ pub enum MastForestError { TooManyNodes, #[error("node id: {0} is greater than or equal to forest length: {1}")] NodeIdOverflow(MastNodeId, usize), + #[error("decorator id: {0} is greater than or equal to decorator count: {1}")] + DecoratorIdOverflow(DecoratorId, usize), #[error("basic block cannot be created from an empty list of operations")] EmptyBasicBlock, + #[error("decorator root of child with node id {0} is missing but required for fingerprint computation")] + ChildFingerprintMissing(MastNodeId), } diff --git a/core/src/mast/multi_forest_node_iterator.rs b/core/src/mast/multi_forest_node_iterator.rs new file mode 100644 index 0000000000..be5ce1b5d2 --- /dev/null +++ b/core/src/mast/multi_forest_node_iterator.rs @@ -0,0 +1,490 @@ +use alloc::{ + collections::{BTreeMap, VecDeque}, + vec::Vec, +}; + +use miden_crypto::hash::rpo::RpoDigest; + +use crate::mast::{MastForest, MastForestError, MastNode, MastNodeId}; + +type ForestIndex = usize; + +/// Depth First Search Iterator in Post Order for [`MastForest`]s. +/// +/// This iterator iterates through all **reachable** nodes of all given forests exactly once. +/// +/// Since a `MastForest` has multiple possible entrypoints in the form of its roots, a depth-first +/// search must visit all of those roots and the trees they form. This iterator's `Item` is +/// [`MultiMastForestIteratorItem`]. It contains either a [`MultiMastForestIteratorItem::Node`] of a +/// forest, or the replacement of an external node. This is returned if one forest contains an +/// External node with digest `foo` and another forest contains a non-external node with digest +/// `foo`. In such a case the `foo` node is yielded first (unless it was already visited) and +/// subsequently a "replacement signal" ([`MultiMastForestIteratorItem::ExternalNodeReplacement`]) +/// for the external node is yielded to make the caller aware that this replacement has happened. +/// +/// All of this is useful to ensure that children are always processed before their parents, even if +/// a child is an External node which is replaced by a node in another forest. This guarantees that +/// **all [`MastNodeId`]s of child nodes are strictly less than the [`MastNodeId`] of their +/// parents**. +/// +/// For instance, consider these `MastForest`s being passed to this iterator with the `Call(0)`'s +/// digest being `qux`: +/// +/// ```text +/// Forest A Nodes: [Block(foo), External(qux), Join(0, 1)] +/// Forest A Roots: [2] +/// Forest B Nodes: [Block(bar), Call(0)] +/// Forest B Roots: [0] +/// ``` +/// +/// The only root of A is the `Join` node at index 2. The first three nodes of the forest form a +/// tree, since the `Join` node references index 0 and 1. This tree is discovered by +/// starting at the root at index 2 and following all children until we reach terminal nodes (like +/// `Block`s) and building up a deque of the discovered nodes. The special case here is the +/// `External` node whose digest matches that of a node in forest B. Instead of the External +/// node being added to the deque, the tree of the Call node is added instead. The deque is built +/// such that popping elements off the deque (from the front) yields a postorder. +/// +/// After the first tree is discovered, the deque looks like this: +/// +/// ```text +/// [Node(forest_idx: 0, node_id: 0), +/// Node(forest_idx: 1, node_id: 0), +/// Node(forest_idx: 1, node_id: 1), +/// ExternalNodeReplacement( +/// replacement_forest_idx: 1, replacement_node_id: 1 +/// replaced_forest_idx: 0, replaced_node_id: 1 +/// ), +/// Node(forest_idx: 0, node_id: 2)] +/// ``` +/// +/// If the deque is exhausted we start another discovery if more undiscovered roots exist. In this +/// example, the root of forest B was already discovered and visited due to the External node +/// reference, so the iteration is complete. +/// +/// The iteration on a higher level thus consists of a back and forth between discovering trees and +/// returning nodes from the deque. +pub(crate) struct MultiMastForestNodeIter<'forest> { + /// The forests that we're iterating. + mast_forests: Vec<&'forest MastForest>, + /// The index of the forest we're currently processing and discovering trees in. + /// + /// This value iterates through 0..mast_forests.len() which guarantees that we visit all + /// forests once. + current_forest_idx: ForestIndex, + /// The procedure root index at which we last started a tree discovery in the + /// current_forest_idx. + /// + /// This value iterates through 0..mast_forests[current_forest_idx].num_procedures() which + /// guarantees that we visit all nodes reachable from all roots. + current_procedure_root_idx: u32, + /// A map of MAST roots of all non-external nodes in mast_forests to their forest and node + /// indices. + non_external_nodes: BTreeMap, + /// Describes whether the node identified by [forest_index][node_index] has already been + /// discovered. Note that this is `true` for all nodes that are in the unvisited node deque. + discovered_nodes: Vec>, + /// This deque always contains the discovered, but unvisited nodes. + /// It holds that discovered_nodes[forest_idx][node_id] = true for all elements in this deque. + unvisited_nodes: VecDeque, +} + +impl<'forest> MultiMastForestNodeIter<'forest> { + /// Builds a map of MAST roots to non-external nodes in any of the given forests to initialize + /// the iterator. This enables an efficient check whether for any encountered External node + /// referencing digest `foo` a node with digest `foo` already exists in any forest. + pub(crate) fn new(mast_forests: Vec<&'forest MastForest>) -> Self { + let discovered_nodes = mast_forests + .iter() + .map(|forest| vec![false; forest.num_nodes() as usize]) + .collect(); + + let mut non_external_nodes = BTreeMap::new(); + + for (forest_idx, forest) in mast_forests.iter().enumerate() { + for (node_idx, node) in forest.nodes().iter().enumerate() { + // SAFETY: The passed id comes from the iterator over the nodes, so we never exceed + // the forest's number of nodes. + let node_id = MastNodeId::new_unchecked(node_idx as u32); + if !node.is_external() { + non_external_nodes.insert(node.digest(), (forest_idx, node_id)); + } + } + } + + Self { + mast_forests, + current_forest_idx: 0, + current_procedure_root_idx: 0, + non_external_nodes, + discovered_nodes, + unvisited_nodes: VecDeque::new(), + } + } + + /// Pushes the given node, uniquely identified by the forest and node index onto the deque + /// even if the node was already discovered before. + /// + /// It's the callers responsibility to only pass valid indices. + fn push_node(&mut self, forest_idx: usize, node_id: MastNodeId) { + self.unvisited_nodes + .push_back(MultiMastForestIteratorItem::Node { forest_idx, node_id }); + self.discovered_nodes[forest_idx][node_id.as_usize()] = true; + } + + /// Discovers a tree starting at the given forest index and node id. + /// + /// SAFETY: We only pass valid forest and node indices so we can index directly in this + /// function. + fn discover_tree( + &mut self, + forest_idx: ForestIndex, + node_id: MastNodeId, + ) -> Result<(), MastForestError> { + if self.discovered_nodes[forest_idx][node_id.as_usize()] { + return Ok(()); + } + + let current_node = + &self.mast_forests[forest_idx].nodes.get(node_id.as_usize()).ok_or_else(|| { + MastForestError::NodeIdOverflow( + node_id, + self.mast_forests[forest_idx].num_nodes() as usize, + ) + })?; + + // Note that we can process nodes in postorder, since we push them onto the back of the + // deque but pop them off the front. + match current_node { + MastNode::Block(_) => { + self.push_node(forest_idx, node_id); + }, + MastNode::Join(join_node) => { + self.discover_tree(forest_idx, join_node.first())?; + self.discover_tree(forest_idx, join_node.second())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Split(split_node) => { + self.discover_tree(forest_idx, split_node.on_true())?; + self.discover_tree(forest_idx, split_node.on_false())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Loop(loop_node) => { + self.discover_tree(forest_idx, loop_node.body())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Call(call_node) => { + self.discover_tree(forest_idx, call_node.callee())?; + self.push_node(forest_idx, node_id); + }, + MastNode::Dyn(_) => { + self.push_node(forest_idx, node_id); + }, + MastNode::External(external_node) => { + // When we encounter an external node referencing digest `foo` there are two cases: + // - If there exists a node `replacement` in any forest with digest `foo`, we want + // to replace the external node with that node, which we do in two steps. + // - Discover the `replacement`'s tree and add it to the deque. + // - If `replacement` was already discovered before, it won't actually be + // returned. + // - In any case this means: The `replacement` node is processed before the + // replacement signal we're adding next. + // - Add a replacement signal to the deque, signaling that the `replacement` + // replaced the external node. + // - If no replacement exists, yield the External Node as a regular `Node`. + if let Some((other_forest_idx, other_node_id)) = + self.non_external_nodes.get(&external_node.digest()).copied() + { + self.discover_tree(other_forest_idx, other_node_id)?; + + self.unvisited_nodes.push_back( + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: other_forest_idx, + replacement_mast_node_id: other_node_id, + replaced_forest_idx: forest_idx, + replaced_mast_node_id: node_id, + }, + ); + + self.discovered_nodes[forest_idx][node_id.as_usize()] = true; + } else { + self.push_node(forest_idx, node_id); + } + }, + } + + Ok(()) + } + + /// Finds the next undiscovered procedure root and discovers a tree from it. + /// + /// If the undiscovered node deque is empty after calling this function, the iteration is + /// complete. + /// + /// This function basically consists of two loops: + /// - The outer loop iterates over all forest indices. + /// - The inner loop iterates over all procedure root indices for the current forest. + fn discover_nodes(&mut self) { + 'forest_loop: while self.current_forest_idx < self.mast_forests.len() + && self.unvisited_nodes.is_empty() + { + // If we don't have any forests, there is nothing to do. + if self.mast_forests.is_empty() { + return; + } + + // If the current forest doesn't have roots, advance to the next one. + if self.mast_forests[self.current_forest_idx].num_procedures() == 0 { + self.current_forest_idx += 1; + continue; + } + + let procedure_roots = self.mast_forests[self.current_forest_idx].procedure_roots(); + let discovered_nodes = &self.discovered_nodes[self.current_forest_idx]; + + // Find the next undiscovered procedure root for the current forest by incrementing the + // current procedure root until we find one that was not yet discovered. + while discovered_nodes + [procedure_roots[self.current_procedure_root_idx as usize].as_usize()] + { + // If we have reached the end of the procedure roots for the current forest, + // continue searching in the next forest. + if self.current_procedure_root_idx + 1 + >= self.mast_forests[self.current_forest_idx].num_procedures() + { + // Reset current procedure root. + self.current_procedure_root_idx = 0; + // Increment forest index. + self.current_forest_idx += 1; + + continue 'forest_loop; + } + + // Since the current procedure root was already discovered, check the next one. + self.current_procedure_root_idx += 1; + } + + // We exited the loop, so the current procedure root is undiscovered and so we can start + // a discovery from that root. Since that root is undiscovered, it is guaranteed that + // after this discovery the deque will be non-empty. + let procedure_root_id = procedure_roots[self.current_procedure_root_idx as usize]; + self.discover_tree(self.current_forest_idx, procedure_root_id) + .expect("we should only pass root indices that are valid for the forest"); + } + } +} + +impl Iterator for MultiMastForestNodeIter<'_> { + type Item = MultiMastForestIteratorItem; + + fn next(&mut self) -> Option { + if let Some(deque_item) = self.unvisited_nodes.pop_front() { + return Some(deque_item); + } + + self.discover_nodes(); + + if !self.unvisited_nodes.is_empty() { + self.next() + } else { + // If the deque is empty after tree discovery, all (reachable) nodes have been + // discovered and visited. + None + } + } +} + +/// The iterator item for [`MultiMastForestNodeIter`]. See its documentation for details. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum MultiMastForestIteratorItem { + /// A regular node discovered by the iterator. + Node { + forest_idx: ForestIndex, + node_id: MastNodeId, + }, + /// Signals a replacement of an external node by some other node. + ExternalNodeReplacement { + replacement_forest_idx: usize, + replacement_mast_node_id: MastNodeId, + replaced_forest_idx: usize, + replaced_mast_node_id: MastNodeId, + }, +} + +// TESTS +// ================================================================================================ + +#[cfg(test)] +mod tests { + use miden_crypto::hash::rpo::RpoDigest; + + use super::*; + use crate::Operation; + + fn random_digest() -> RpoDigest { + RpoDigest::new([rand_utils::rand_value(); 4]) + } + + #[test] + fn multi_mast_forest_dfs_empty() { + let forest = MastForest::new(); + let mut iterator = MultiMastForestNodeIter::new(vec![&forest]); + assert!(iterator.next().is_none()); + } + + #[test] + fn multi_mast_forest_multiple_forests_dfs() { + let nodea0_digest = random_digest(); + let nodea1_digest = random_digest(); + let nodea2_digest = random_digest(); + let nodea3_digest = random_digest(); + + let nodeb0_digest = random_digest(); + + let mut forest_a = MastForest::new(); + forest_a.add_external(nodea0_digest).unwrap(); + let id1 = forest_a.add_external(nodea1_digest).unwrap(); + let id2 = forest_a.add_external(nodea2_digest).unwrap(); + let id3 = forest_a.add_external(nodea3_digest).unwrap(); + let id_split = forest_a.add_split(id2, id3).unwrap(); + let id_join = forest_a.add_join(id2, id_split).unwrap(); + + forest_a.make_root(id_join); + forest_a.make_root(id1); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(nodeb0_digest).unwrap(); + let id_block_b = forest_b.add_block(vec![Operation::Eqz], None).unwrap(); + let id_split_b = forest_b.add_split(id_ext_b, id_block_b).unwrap(); + + forest_b.make_root(id_split_b); + + // Note that the node at index 0 is not visited because it is not reachable from any root + // and is not a root itself. + let nodes = MultiMastForestNodeIter::new(vec![&forest_a, &forest_b]).collect::>(); + + assert_eq!(nodes.len(), 8); + assert_eq!(nodes[0], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id2 }); + assert_eq!(nodes[1], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id3 }); + assert_eq!( + nodes[2], + MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id_split } + ); + assert_eq!(nodes[3], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id_join }); + assert_eq!(nodes[4], MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: id1 }); + assert_eq!( + nodes[5], + MultiMastForestIteratorItem::Node { forest_idx: 1, node_id: id_ext_b } + ); + assert_eq!( + nodes[6], + MultiMastForestIteratorItem::Node { forest_idx: 1, node_id: id_block_b } + ); + assert_eq!( + nodes[7], + MultiMastForestIteratorItem::Node { forest_idx: 1, node_id: id_split_b } + ); + } + + #[test] + fn multi_mast_forest_external_dependencies() { + let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap(); + let mut forest_a = MastForest::new(); + let id_foo_a = forest_a.add_external(block_foo.digest()).unwrap(); + let id_call_a = forest_a.add_call(id_foo_a).unwrap(); + forest_a.make_root(id_call_a); + + let mut forest_b = MastForest::new(); + let id_ext_b = forest_b.add_external(forest_a[id_call_a].digest()).unwrap(); + let id_call_b = forest_b.add_call(id_ext_b).unwrap(); + forest_b.add_node(block_foo).unwrap(); + forest_b.make_root(id_call_b); + + let nodes = MultiMastForestNodeIter::new(vec![&forest_a, &forest_b]).collect::>(); + + assert_eq!(nodes.len(), 5); + + // The replacement for the external node from forest A. + assert_eq!( + nodes[0], + MultiMastForestIteratorItem::Node { + forest_idx: 1, + node_id: MastNodeId::new_unchecked(2) + } + ); + // The external node replaced by the block foo from forest B. + assert_eq!( + nodes[1], + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: 1, + replacement_mast_node_id: MastNodeId::new_unchecked(2), + replaced_forest_idx: 0, + replaced_mast_node_id: MastNodeId::new_unchecked(0) + } + ); + // The call from forest A. + assert_eq!( + nodes[2], + MultiMastForestIteratorItem::Node { + forest_idx: 0, + node_id: MastNodeId::new_unchecked(1) + } + ); + // The replacement for the external node that is replaced by the Call in forest A. + assert_eq!( + nodes[3], + MultiMastForestIteratorItem::ExternalNodeReplacement { + replacement_forest_idx: 0, + replacement_mast_node_id: MastNodeId::new_unchecked(1), + replaced_forest_idx: 1, + replaced_mast_node_id: MastNodeId::new_unchecked(0) + } + ); + // The call from forest B. + assert_eq!( + nodes[4], + MultiMastForestIteratorItem::Node { + forest_idx: 1, + node_id: MastNodeId::new_unchecked(1) + } + ); + } + + /// Tests that a node which is referenced twice in a Mast Forest is returned in the required + /// order. + /// + /// In this test we have a MastForest with this graph: + /// + /// 3 <- Split Node + /// / \ + /// 1 2 + /// \ / + /// 0 + /// + /// We need to ensure that 0 is processed before 1 and that it is not processed again when + /// processing the children of node 2. + /// + /// This test and example is essentially a copy from a part of the MastForest of the Miden + /// Stdlib where this failed on a previous implementation. + #[test] + fn multi_mast_forest_child_duplicate() { + let block_foo = MastNode::new_basic_block(vec![Operation::Drop], None).unwrap(); + let mut forest = MastForest::new(); + let id_foo = forest.add_external(block_foo.digest()).unwrap(); + let id_call1 = forest.add_call(id_foo).unwrap(); + let id_call2 = forest.add_call(id_foo).unwrap(); + let id_split = forest.add_split(id_call1, id_call2).unwrap(); + forest.make_root(id_split); + + let nodes = MultiMastForestNodeIter::new(vec![&forest]).collect::>(); + + // The foo node should be yielded first and it should not be yielded twice. + for (i, expected_node_id) in [id_foo, id_call1, id_call2, id_split].into_iter().enumerate() + { + assert_eq!( + nodes[i], + MultiMastForestIteratorItem::Node { forest_idx: 0, node_id: expected_node_id } + ); + } + } +} diff --git a/core/src/mast/node/basic_block_node/mod.rs b/core/src/mast/node/basic_block_node/mod.rs index 44e0a1835e..726decbc72 100644 --- a/core/src/mast/node/basic_block_node/mod.rs +++ b/core/src/mast/node/basic_block_node/mod.rs @@ -1,12 +1,13 @@ use alloc::vec::Vec; -use core::fmt; +use core::{fmt, mem}; use miden_crypto::{hash::rpo::RpoDigest, Felt, ZERO}; use miden_formatting::prettier::PrettyPrint; -use winter_utils::flatten_slice_elements; use crate::{ - chiplets::hasher, mast::MastForestError, Decorator, DecoratorIterator, DecoratorList, Operation, + chiplets::hasher, + mast::{DecoratorId, MastForest, MastForestError}, + DecoratorIterator, DecoratorList, Operation, }; mod op_batch; @@ -110,9 +111,24 @@ impl BasicBlockNode { digest: RpoDigest, ) -> Self { assert!(!operations.is_empty()); - let (op_batches, _) = batch_ops(operations); + let op_batches = batch_ops(operations); Self { op_batches, digest, decorators } } + + /// Returns a new [`BasicBlockNode`] instantiated with the specified operations and decorators. + #[cfg(test)] + pub fn new_with_raw_decorators( + operations: Vec, + decorators: Vec<(usize, crate::Decorator)>, + mast_forest: &mut crate::mast::MastForest, + ) -> Result { + let mut decorator_list = Vec::new(); + for (idx, decorator) in decorators { + decorator_list.push((idx, mast_forest.add_decorator(decorator)?)); + } + + Self::new(operations, Some(decorator_list)) + } } // ------------------------------------------------------------------------------------------------ @@ -167,6 +183,11 @@ impl BasicBlockNode { DecoratorIterator::new(&self.decorators) } + /// Returns an iterator over the operations in the order in which they appear in the program. + pub fn operations(&self) -> impl Iterator { + self.op_batches.iter().flat_map(|batch| batch.ops()) + } + /// Returns the total number of operations and decorators in this basic block. pub fn num_operations_and_decorators(&self) -> u32 { let num_ops: usize = self.num_operations() as usize; @@ -184,10 +205,49 @@ impl BasicBlockNode { } } +/// Mutators +impl BasicBlockNode { + /// Sets the provided list of decorators to be executed before all existing decorators. + pub fn prepend_decorators(&mut self, decorator_ids: Vec) { + let mut new_decorators: DecoratorList = + decorator_ids.into_iter().map(|decorator_id| (0, decorator_id)).collect(); + new_decorators.extend(mem::take(&mut self.decorators)); + + self.decorators = new_decorators; + } + + /// Sets the provided list of decorators to be executed after all existing decorators. + pub fn append_decorators(&mut self, decorator_ids: Vec) { + let after_last_op_idx = self.num_operations() as usize; + + self.decorators.extend( + decorator_ids.into_iter().map(|decorator_id| (after_last_op_idx, decorator_id)), + ); + } +} + // PRETTY PRINTING // ================================================================================================ -impl PrettyPrint for BasicBlockNode { +impl BasicBlockNode { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + BasicBlockNodePrettyPrint { block_node: self, mast_forest } + } + + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + BasicBlockNodePrettyPrint { block_node: self, mast_forest } + } +} + +struct BasicBlockNodePrettyPrint<'a> { + block_node: &'a BasicBlockNode, + mast_forest: &'a MastForest, +} + +impl PrettyPrint for BasicBlockNodePrettyPrint<'_> { #[rustfmt::skip] fn render(&self) -> crate::prettier::Document { use crate::prettier::*; @@ -195,11 +255,13 @@ impl PrettyPrint for BasicBlockNode { // e.g. `basic_block a b c end` let single_line = const_text("basic_block") + const_text(" ") - + self - .op_batches + + self. + block_node .iter() - .flat_map(|batch| batch.ops().iter()) - .map(|p| p.render()) + .map(|op_or_dec| match op_or_dec { + OperationOrDecorator::Operation(op) => op.render(), + OperationOrDecorator::Decorator(&decorator_id) => self.mast_forest[decorator_id].render(), + }) .reduce(|acc, doc| acc + const_text(" ") + doc) .unwrap_or_default() + const_text(" ") @@ -218,10 +280,12 @@ impl PrettyPrint for BasicBlockNode { const_text("basic_block") + nl() + self - .op_batches + .block_node .iter() - .flat_map(|batch| batch.ops().iter()) - .map(|p| p.render()) + .map(|op_or_dec| match op_or_dec { + OperationOrDecorator::Operation(op) => op.render(), + OperationOrDecorator::Decorator(&decorator_id) => self.mast_forest[decorator_id].render(), + }) .reduce(|acc, doc| acc + nl() + doc) .unwrap_or_default(), ) + nl() @@ -231,7 +295,7 @@ impl PrettyPrint for BasicBlockNode { } } -impl fmt::Display for BasicBlockNode { +impl fmt::Display for BasicBlockNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) @@ -241,11 +305,11 @@ impl fmt::Display for BasicBlockNode { // OPERATION OR DECORATOR // ================================================================================================ -/// Encodes either an [`Operation`] or a [`Decorator`]. +/// Encodes either an [`Operation`] or a [`crate::Decorator`]. #[derive(Clone, Debug, Eq, PartialEq)] pub enum OperationOrDecorator<'a> { Operation(&'a Operation), - Decorator(&'a Decorator), + Decorator(&'a DecoratorId), } struct OperationOrDecoratorIterator<'a> { @@ -315,22 +379,20 @@ impl<'a> Iterator for OperationOrDecoratorIterator<'a> { /// Groups the provided operations into batches and computes the hash of the block. fn batch_and_hash_ops(ops: Vec) -> (Vec, RpoDigest) { // Group the operations into batches. - let (batches, batch_groups) = batch_ops(ops); + let batches = batch_ops(ops); // Compute the hash of all operation groups. - let op_groups = &flatten_slice_elements(&batch_groups); - let hash = hasher::hash_elements(op_groups); + let op_groups: Vec = batches.iter().flat_map(|batch| batch.groups).collect(); + let hash = hasher::hash_elements(&op_groups); (batches, hash) } -/// Groups the provided operations into batches as described in the docs for this module (i.e., -/// up to 9 operations per group, and 8 groups per batch). -/// Returns a list of operation batches and a list of operation groups. -fn batch_ops(ops: Vec) -> (Vec, Vec<[Felt; BATCH_SIZE]>) { +/// Groups the provided operations into batches as described in the docs for this module (i.e., up +/// to 9 operations per group, and 8 groups per batch). +fn batch_ops(ops: Vec) -> Vec { let mut batches = Vec::::new(); let mut batch_acc = OpBatchAccumulator::new(); - let mut batch_groups = Vec::<[Felt; BATCH_SIZE]>::new(); for op in ops { // If the operation cannot be accepted into the current accumulator, add the contents of @@ -339,7 +401,6 @@ fn batch_ops(ops: Vec) -> (Vec, Vec<[Felt; BATCH_SIZE]>) { let batch = batch_acc.into_batch(); batch_acc = OpBatchAccumulator::new(); - batch_groups.push(*batch.groups()); batches.push(batch); } @@ -350,10 +411,10 @@ fn batch_ops(ops: Vec) -> (Vec, Vec<[Felt; BATCH_SIZE]>) { // Make sure we finished processing the last batch. if !batch_acc.is_empty() { let batch = batch_acc.into_batch(); - batch_groups.push(*batch.groups()); batches.push(batch); } - (batches, batch_groups) + + batches } /// Checks if a given decorators list is valid (only checked in debug mode) diff --git a/core/src/mast/node/basic_block_node/tests.rs b/core/src/mast/node/basic_block_node/tests.rs index b44b3f1fd8..607793642c 100644 --- a/core/src/mast/node/basic_block_node/tests.rs +++ b/core/src/mast/node/basic_block_node/tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::{Decorator, ONE}; +use crate::{mast::MastForest, Decorator, ONE}; #[test] fn batch_ops() { @@ -295,24 +295,26 @@ fn batch_ops() { #[test] fn operation_or_decorator_iterator() { + let mut mast_forest = MastForest::new(); let operations = vec![Operation::Add, Operation::Mul, Operation::MovDn2, Operation::MovDn3]; // Note: there are 2 decorators after the last instruction let decorators = vec![ - (0, Decorator::Event(0)), - (0, Decorator::Event(1)), - (3, Decorator::Event(2)), - (4, Decorator::Event(3)), - (4, Decorator::Event(4)), + (0, Decorator::Trace(0)), // ID: 0 + (0, Decorator::Trace(1)), // ID: 1 + (3, Decorator::Trace(2)), // ID: 2 + (4, Decorator::Trace(3)), // ID: 3 + (4, Decorator::Trace(4)), // ID: 4 ]; - let node = BasicBlockNode::new(operations, Some(decorators)).unwrap(); + let node = + BasicBlockNode::new_with_raw_decorators(operations, decorators, &mut mast_forest).unwrap(); let mut iterator = node.iter(); // operation index 0 - assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(0)))); - assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(1)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&DecoratorId(0)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&DecoratorId(1)))); assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::Add))); // operations indices 1, 2 @@ -320,12 +322,12 @@ fn operation_or_decorator_iterator() { assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::MovDn2))); // operation index 3 - assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(2)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&DecoratorId(2)))); assert_eq!(iterator.next(), Some(OperationOrDecorator::Operation(&Operation::MovDn3))); // after last operation - assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(3)))); - assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&Decorator::Event(4)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&DecoratorId(3)))); + assert_eq!(iterator.next(), Some(OperationOrDecorator::Decorator(&DecoratorId(4)))); assert_eq!(iterator.next(), None); } diff --git a/core/src/mast/node/call_node.rs b/core/src/mast/node/call_node.rs index f5038dd57c..7f207d386f 100644 --- a/core/src/mast/node/call_node.rs +++ b/core/src/mast/node/call_node.rs @@ -1,11 +1,15 @@ +use alloc::vec::Vec; use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; -use miden_formatting::prettier::PrettyPrint; +use miden_formatting::{ + hex::ToHex, + prettier::{const_text, nl, text, Document, PrettyPrint}, +}; use crate::{ chiplets::hasher, - mast::{MastForest, MastForestError, MastNodeId}, + mast::{DecoratorId, MastForest, MastForestError, MastNodeId}, OPCODE_CALL, OPCODE_SYSCALL, }; @@ -23,6 +27,8 @@ pub struct CallNode { callee: MastNodeId, is_syscall: bool, digest: RpoDigest, + before_enter: Vec, + after_exit: Vec, } //------------------------------------------------------------------------------------------------- @@ -48,13 +54,25 @@ impl CallNode { hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::CALL_DOMAIN) }; - Ok(Self { callee, is_syscall: false, digest }) + Ok(Self { + callee, + is_syscall: false, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + }) } /// Returns a new [`CallNode`] from values that are assumed to be correct. /// Should only be used when the source of the inputs is trusted (e.g. deserialization). pub fn new_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self { - Self { callee, is_syscall: false, digest } + Self { + callee, + is_syscall: false, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + } } /// Returns a new [`CallNode`] instantiated with the specified callee and marked as a kernel @@ -72,13 +90,25 @@ impl CallNode { hasher::merge_in_domain(&[callee_digest, RpoDigest::default()], Self::SYSCALL_DOMAIN) }; - Ok(Self { callee, is_syscall: true, digest }) + Ok(Self { + callee, + is_syscall: true, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + }) } /// Returns a new syscall [`CallNode`] from values that are assumed to be correct. /// Should only be used when the source of the inputs is trusted (e.g. deserialization). pub fn new_syscall_unsafe(callee: MastNodeId, digest: RpoDigest) -> Self { - Self { callee, is_syscall: true, digest } + Self { + callee, + is_syscall: true, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + } } } @@ -125,6 +155,29 @@ impl CallNode { Self::CALL_DOMAIN } } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + &self.before_enter + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + &self.after_exit + } +} + +/// Mutators +impl CallNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + self.before_enter = decorator_ids; + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + self.after_exit = decorator_ids; + } } // PRETTY PRINTING @@ -135,37 +188,84 @@ impl CallNode { &'a self, mast_forest: &'a MastForest, ) -> impl PrettyPrint + 'a { - CallNodePrettyPrint { call_node: self, mast_forest } + CallNodePrettyPrint { node: self, mast_forest } } pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - CallNodePrettyPrint { call_node: self, mast_forest } + CallNodePrettyPrint { node: self, mast_forest } } } struct CallNodePrettyPrint<'a> { - call_node: &'a CallNode, + node: &'a CallNode, mast_forest: &'a MastForest, } -impl<'a> PrettyPrint for CallNodePrettyPrint<'a> { - #[rustfmt::skip] - fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - use miden_formatting::hex::ToHex; - - let callee_digest = self.mast_forest[self.call_node.callee].digest(); +impl CallNodePrettyPrint<'_> { + /// Concatenates the provided decorators in a single line. If the list of decorators is not + /// empty, prepends `prepend` and appends `append` to the decorator document. + fn concatenate_decorators( + &self, + decorator_ids: &[DecoratorId], + prepend: Document, + append: Document, + ) -> Document { + let decorators = decorator_ids + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); - let doc = if self.call_node.is_syscall { - const_text("syscall") + if decorators.is_empty() { + decorators } else { - const_text("call") + prepend + decorators + append + } + } + + fn single_line_pre_decorators(&self) -> Document { + self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" ")) + } + + fn single_line_post_decorators(&self) -> Document { + self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty) + } + + fn multi_line_pre_decorators(&self) -> Document { + self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl()) + } + + fn multi_line_post_decorators(&self) -> Document { + self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty) + } +} + +impl PrettyPrint for CallNodePrettyPrint<'_> { + fn render(&self) -> Document { + let call_or_syscall = { + let callee_digest = self.mast_forest[self.node.callee].digest(); + if self.node.is_syscall { + const_text("syscall") + + const_text(".") + + text(callee_digest.as_bytes().to_hex_with_prefix()) + } else { + const_text("call") + + const_text(".") + + text(callee_digest.as_bytes().to_hex_with_prefix()) + } }; - doc + const_text(".") + text(callee_digest.as_bytes().to_hex_with_prefix()) + + let single_line = self.single_line_pre_decorators() + + call_or_syscall.clone() + + self.single_line_post_decorators(); + let multi_line = + self.multi_line_pre_decorators() + call_or_syscall + self.multi_line_post_decorators(); + + single_line | multi_line } } -impl<'a> fmt::Display for CallNodePrettyPrint<'a> { +impl fmt::Display for CallNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) diff --git a/core/src/mast/node/dyn_node.rs b/core/src/mast/node/dyn_node.rs index 934a8fec2d..8bdaf516a7 100644 --- a/core/src/mast/node/dyn_node.rs +++ b/core/src/mast/node/dyn_node.rs @@ -1,57 +1,201 @@ +use alloc::vec::Vec; use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; +use miden_formatting::prettier::{const_text, nl, Document, PrettyPrint}; -use crate::OPCODE_DYN; +use crate::{ + mast::{DecoratorId, MastForest}, + OPCODE_DYN, OPCODE_DYNCALL, +}; // DYN NODE // ================================================================================================ /// A Dyn node specifies that the node to be executed next is defined dynamically via the stack. -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct DynNode; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DynNode { + is_dyncall: bool, + before_enter: Vec, + after_exit: Vec, +} /// Constants impl DynNode { /// The domain of the Dyn block (used for control block hashing). - pub const DOMAIN: Felt = Felt::new(OPCODE_DYN as u64); + pub const DYN_DOMAIN: Felt = Felt::new(OPCODE_DYN as u64); + + /// The domain of the Dyncall block (used for control block hashing). + pub const DYNCALL_DOMAIN: Felt = Felt::new(OPCODE_DYNCALL as u64); } /// Public accessors impl DynNode { + /// Creates a new [`DynNode`] representing a dynexec operation. + pub fn new_dyn() -> Self { + Self { + is_dyncall: false, + before_enter: Vec::new(), + after_exit: Vec::new(), + } + } + + /// Creates a new [`DynNode`] representing a dyncall operation. + pub fn new_dyncall() -> Self { + Self { + is_dyncall: true, + before_enter: Vec::new(), + after_exit: Vec::new(), + } + } + + /// Returns true if the [`DynNode`] represents a dyncall operation, and false for dynexec. + pub fn is_dyncall(&self) -> bool { + self.is_dyncall + } + + /// Returns the domain of this dyn node. + pub fn domain(&self) -> Felt { + if self.is_dyncall() { + Self::DYNCALL_DOMAIN + } else { + Self::DYN_DOMAIN + } + } + /// Returns a commitment to a Dyn node. /// /// The commitment is computed by hashing two empty words ([ZERO; 4]) in the domain defined - /// by [Self::DOMAIN], i.e.: + /// by [Self::DYN_DOMAIN] or [Self::DYNCALL_DOMAIN], i.e.: /// /// ``` /// # use miden_core::mast::DynNode; /// # use miden_crypto::{hash::rpo::{RpoDigest as Digest, Rpo256 as Hasher}}; - /// Hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DOMAIN); + /// Hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DYN_DOMAIN); + /// Hasher::merge_in_domain(&[Digest::default(), Digest::default()], DynNode::DYNCALL_DOMAIN); /// ``` pub fn digest(&self) -> RpoDigest { - RpoDigest::new([ - Felt::new(8115106948140260551), - Felt::new(13491227816952616836), - Felt::new(15015806788322198710), - Felt::new(16575543461540527115), - ]) + if self.is_dyncall { + RpoDigest::new([ + Felt::new(8751004906421739448), + Felt::new(13469709002495534233), + Felt::new(12584249374630430826), + Felt::new(7624899870831503004), + ]) + } else { + RpoDigest::new([ + Felt::new(8115106948140260551), + Felt::new(13491227816952616836), + Felt::new(15015806788322198710), + Felt::new(16575543461540527115), + ]) + } + } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + &self.before_enter + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + &self.after_exit + } +} + +/// Mutators +impl DynNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + self.before_enter = decorator_ids; + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + self.after_exit = decorator_ids; } } // PRETTY PRINTING // ================================================================================================ -impl crate::prettier::PrettyPrint for DynNode { +impl DynNode { + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + DynNodePrettyPrint { node: self, mast_forest } + } + + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + DynNodePrettyPrint { node: self, mast_forest } + } +} + +struct DynNodePrettyPrint<'a> { + node: &'a DynNode, + mast_forest: &'a MastForest, +} + +impl DynNodePrettyPrint<'_> { + /// Concatenates the provided decorators in a single line. If the list of decorators is not + /// empty, prepends `prepend` and appends `append` to the decorator document. + fn concatenate_decorators( + &self, + decorator_ids: &[DecoratorId], + prepend: Document, + append: Document, + ) -> Document { + let decorators = decorator_ids + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + + if decorators.is_empty() { + decorators + } else { + prepend + decorators + append + } + } + + fn single_line_pre_decorators(&self) -> Document { + self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" ")) + } + + fn single_line_post_decorators(&self) -> Document { + self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty) + } + + fn multi_line_pre_decorators(&self) -> Document { + self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl()) + } + + fn multi_line_post_decorators(&self) -> Document { + self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty) + } +} + +impl crate::prettier::PrettyPrint for DynNodePrettyPrint<'_> { fn render(&self) -> crate::prettier::Document { - use crate::prettier::*; - const_text("dyn") + let dyn_text = if self.node.is_dyncall() { + const_text("dyncall") + } else { + const_text("dyn") + }; + + let single_line = self.single_line_pre_decorators() + + dyn_text.clone() + + self.single_line_post_decorators(); + let multi_line = + self.multi_line_pre_decorators() + dyn_text + self.multi_line_post_decorators(); + + single_line | multi_line } } -impl fmt::Display for DynNode { +impl fmt::Display for DynNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use miden_formatting::prettier::PrettyPrint; self.pretty_print(f) } } @@ -70,8 +214,19 @@ mod tests { #[test] pub fn test_dyn_node_digest() { assert_eq!( - DynNode.digest(), - Rpo256::merge_in_domain(&[RpoDigest::default(), RpoDigest::default()], DynNode::DOMAIN) + DynNode::new_dyn().digest(), + Rpo256::merge_in_domain( + &[RpoDigest::default(), RpoDigest::default()], + DynNode::DYN_DOMAIN + ) + ); + + assert_eq!( + DynNode::new_dyncall().digest(), + Rpo256::merge_in_domain( + &[RpoDigest::default(), RpoDigest::default()], + DynNode::DYNCALL_DOMAIN + ) ); } } diff --git a/core/src/mast/node/external.rs b/core/src/mast/node/external.rs index e38e5f2b18..d966008009 100644 --- a/core/src/mast/node/external.rs +++ b/core/src/mast/node/external.rs @@ -1,8 +1,13 @@ +use alloc::vec::Vec; use core::fmt; use miden_crypto::hash::rpo::RpoDigest; +use miden_formatting::{ + hex::ToHex, + prettier::{const_text, nl, text, Document, PrettyPrint}, +}; -use crate::mast::MastForest; +use crate::mast::{DecoratorId, MastForest}; // EXTERNAL NODE // ================================================================================================ @@ -18,12 +23,18 @@ use crate::mast::MastForest; #[derive(Clone, Debug, PartialEq, Eq)] pub struct ExternalNode { digest: RpoDigest, + before_enter: Vec, + after_exit: Vec, } impl ExternalNode { /// Returns a new [`ExternalNode`] instantiated with the specified procedure hash. pub fn new(procedure_hash: RpoDigest) -> Self { - Self { digest: procedure_hash } + Self { + digest: procedure_hash, + before_enter: Vec::new(), + after_exit: Vec::new(), + } } } @@ -32,27 +43,108 @@ impl ExternalNode { pub fn digest(&self) -> RpoDigest { self.digest } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + &self.before_enter + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + &self.after_exit + } +} + +/// Mutators +impl ExternalNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + self.before_enter = decorator_ids; + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + self.after_exit = decorator_ids; + } } // PRETTY PRINTING // ================================================================================================ impl ExternalNode { - pub(super) fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a { - self + pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { + ExternalNodePrettyPrint { node: self, mast_forest } } + + pub(super) fn to_pretty_print<'a>( + &'a self, + mast_forest: &'a MastForest, + ) -> impl PrettyPrint + 'a { + ExternalNodePrettyPrint { node: self, mast_forest } + } +} + +struct ExternalNodePrettyPrint<'a> { + node: &'a ExternalNode, + mast_forest: &'a MastForest, } -impl crate::prettier::PrettyPrint for ExternalNode { +impl ExternalNodePrettyPrint<'_> { + /// Concatenates the provided decorators in a single line. If the list of decorators is not + /// empty, prepends `prepend` and appends `append` to the decorator document. + fn concatenate_decorators( + &self, + decorator_ids: &[DecoratorId], + prepend: Document, + append: Document, + ) -> Document { + let decorators = decorator_ids + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + + if decorators.is_empty() { + decorators + } else { + prepend + decorators + append + } + } + + fn single_line_pre_decorators(&self) -> Document { + self.concatenate_decorators(self.node.before_enter(), Document::Empty, const_text(" ")) + } + + fn single_line_post_decorators(&self) -> Document { + self.concatenate_decorators(self.node.after_exit(), const_text(" "), Document::Empty) + } + + fn multi_line_pre_decorators(&self) -> Document { + self.concatenate_decorators(self.node.before_enter(), Document::Empty, nl()) + } + + fn multi_line_post_decorators(&self) -> Document { + self.concatenate_decorators(self.node.after_exit(), nl(), Document::Empty) + } +} + +impl crate::prettier::PrettyPrint for ExternalNodePrettyPrint<'_> { fn render(&self) -> crate::prettier::Document { - use miden_formatting::hex::ToHex; + let external = const_text("external") + + const_text(".") + + text(self.node.digest.as_bytes().to_hex_with_prefix()); + + let single_line = self.single_line_pre_decorators() + + external.clone() + + self.single_line_post_decorators(); + let multi_line = + self.multi_line_pre_decorators() + external + self.multi_line_post_decorators(); - use crate::prettier::*; - const_text("external") + const_text(".") + text(self.digest.as_bytes().to_hex_with_prefix()) + single_line | multi_line } } -impl fmt::Display for ExternalNode { +impl fmt::Display for ExternalNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) diff --git a/core/src/mast/node/join_node.rs b/core/src/mast/node/join_node.rs index cb58008cb1..d3d04b4510 100644 --- a/core/src/mast/node/join_node.rs +++ b/core/src/mast/node/join_node.rs @@ -1,10 +1,11 @@ +use alloc::vec::Vec; use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; use crate::{ chiplets::hasher, - mast::{MastForest, MastForestError, MastNodeId}, + mast::{DecoratorId, MastForest, MastForestError, MastNodeId}, prettier::PrettyPrint, OPCODE_JOIN, }; @@ -18,6 +19,8 @@ use crate::{ pub struct JoinNode { children: [MastNodeId; 2], digest: RpoDigest, + before_enter: Vec, + after_exit: Vec, } /// Constants @@ -46,13 +49,23 @@ impl JoinNode { hasher::merge_in_domain(&[left_child_hash, right_child_hash], Self::DOMAIN) }; - Ok(Self { children, digest }) + Ok(Self { + children, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + }) } /// Returns a new [`JoinNode`] from values that are assumed to be correct. /// Should only be used when the source of the inputs is trusted (e.g. deserialization). pub fn new_unsafe(children: [MastNodeId; 2], digest: RpoDigest) -> Self { - Self { children, digest } + Self { + children, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + } } } @@ -83,6 +96,29 @@ impl JoinNode { pub fn second(&self) -> MastNodeId { self.children[1] } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + &self.before_enter + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + &self.after_exit + } +} + +/// Mutators +impl JoinNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + self.before_enter = decorator_ids; + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + self.after_exit = decorator_ids; + } } // PRETTY PRINTING @@ -106,15 +142,48 @@ struct JoinNodePrettyPrint<'a> { mast_forest: &'a MastForest, } -impl<'a> PrettyPrint for JoinNodePrettyPrint<'a> { +impl PrettyPrint for JoinNodePrettyPrint<'_> { #[rustfmt::skip] fn render(&self) -> crate::prettier::Document { use crate::prettier::*; - let first_child = self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest); - let second_child = self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest); + let pre_decorators = { + let mut pre_decorators = self + .join_node + .before_enter() + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + if !pre_decorators.is_empty() { + pre_decorators += nl(); + } + + pre_decorators + }; + + let post_decorators = { + let mut post_decorators = self + .join_node + .after_exit() + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + if !post_decorators.is_empty() { + post_decorators = nl() + post_decorators; + } + + post_decorators + }; + + let first_child = + self.mast_forest[self.join_node.first()].to_pretty_print(self.mast_forest); + let second_child = + self.mast_forest[self.join_node.second()].to_pretty_print(self.mast_forest); - indent( + pre_decorators + + indent( 4, const_text("join") + nl() @@ -122,10 +191,11 @@ impl<'a> PrettyPrint for JoinNodePrettyPrint<'a> { + nl() + second_child.render(), ) + nl() + const_text("end") + + post_decorators } } -impl<'a> fmt::Display for JoinNodePrettyPrint<'a> { +impl fmt::Display for JoinNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) diff --git a/core/src/mast/node/loop_node.rs b/core/src/mast/node/loop_node.rs index 08e30ca69d..6091ce034d 100644 --- a/core/src/mast/node/loop_node.rs +++ b/core/src/mast/node/loop_node.rs @@ -1,3 +1,4 @@ +use alloc::vec::Vec; use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; @@ -5,7 +6,7 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastForestError, MastNodeId}, + mast::{DecoratorId, MastForest, MastForestError, MastNodeId}, OPCODE_LOOP, }; @@ -22,6 +23,8 @@ use crate::{ pub struct LoopNode { body: MastNodeId, digest: RpoDigest, + before_enter: Vec, + after_exit: Vec, } /// Constants @@ -43,13 +46,23 @@ impl LoopNode { hasher::merge_in_domain(&[body_hash, RpoDigest::default()], Self::DOMAIN) }; - Ok(Self { body, digest }) + Ok(Self { + body, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + }) } /// Returns a new [`LoopNode`] from values that are assumed to be correct. /// Should only be used when the source of the inputs is trusted (e.g. deserialization). pub fn new_unsafe(body: MastNodeId, digest: RpoDigest) -> Self { - Self { body, digest } + Self { + body, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + } } } @@ -72,6 +85,29 @@ impl LoopNode { pub fn body(&self) -> MastNodeId { self.body } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + &self.before_enter + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + &self.after_exit + } +} + +/// Mutators +impl LoopNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + self.before_enter = decorator_ids; + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + self.after_exit = decorator_ids; + } } // PRETTY PRINTING @@ -95,17 +131,51 @@ struct LoopNodePrettyPrint<'a> { mast_forest: &'a MastForest, } -impl<'a> crate::prettier::PrettyPrint for LoopNodePrettyPrint<'a> { +impl crate::prettier::PrettyPrint for LoopNodePrettyPrint<'_> { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; + let pre_decorators = { + let mut pre_decorators = self + .loop_node + .before_enter() + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + if !pre_decorators.is_empty() { + pre_decorators += nl(); + } + + pre_decorators + }; + + let post_decorators = { + let mut post_decorators = self + .loop_node + .after_exit() + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + if !post_decorators.is_empty() { + post_decorators = nl() + post_decorators; + } + + post_decorators + }; + let loop_body = self.mast_forest[self.loop_node.body].to_pretty_print(self.mast_forest); - indent(4, const_text("while.true") + nl() + loop_body.render()) + nl() + const_text("end") + pre_decorators + + indent(4, const_text("while.true") + nl() + loop_body.render()) + + nl() + + const_text("end") + + post_decorators } } -impl<'a> fmt::Display for LoopNodePrettyPrint<'a> { +impl fmt::Display for LoopNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) diff --git a/core/src/mast/node/mod.rs b/core/src/mast/node/mod.rs index 63293a3333..2dfa9dd037 100644 --- a/core/src/mast/node/mod.rs +++ b/core/src/mast/node/mod.rs @@ -27,7 +27,7 @@ pub use split_node::SplitNode; mod loop_node; pub use loop_node::LoopNode; -use super::MastForestError; +use super::{DecoratorId, MastForestError}; use crate::{ mast::{MastForest, MastNodeId}, DecoratorList, Operation, @@ -43,7 +43,7 @@ pub enum MastNode { Split(SplitNode), Loop(LoopNode), Call(CallNode), - Dyn, + Dyn(DynNode), External(ExternalNode), } @@ -95,12 +95,25 @@ impl MastNode { } pub fn new_dyn() -> Self { - Self::Dyn + Self::Dyn(DynNode::new_dyn()) + } + pub fn new_dyncall() -> Self { + Self::Dyn(DynNode::new_dyncall()) } pub fn new_external(mast_root: RpoDigest) -> Self { Self::External(ExternalNode::new(mast_root)) } + + #[cfg(test)] + pub fn new_basic_block_with_raw_decorators( + operations: Vec, + decorators: Vec<(usize, crate::Decorator)>, + mast_forest: &mut MastForest, + ) -> Result { + let block = BasicBlockNode::new_with_raw_decorators(operations, decorators, mast_forest)?; + Ok(Self::Block(block)) + } } // ------------------------------------------------------------------------------------------------ @@ -113,7 +126,7 @@ impl MastNode { /// Returns true if this node is a Dyn node. pub fn is_dyn(&self) -> bool { - matches!(self, MastNode::Dyn) + matches!(self, MastNode::Dyn(_)) } /// Returns true if this node is a basic block. @@ -133,7 +146,7 @@ impl MastNode { pub fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> impl PrettyPrint + 'a { match self { MastNode::Block(basic_block_node) => { - MastNodePrettyPrint::new(Box::new(basic_block_node)) + MastNodePrettyPrint::new(Box::new(basic_block_node.to_pretty_print(mast_forest))) }, MastNode::Join(join_node) => { MastNodePrettyPrint::new(Box::new(join_node.to_pretty_print(mast_forest))) @@ -147,8 +160,12 @@ impl MastNode { MastNode::Call(call_node) => { MastNodePrettyPrint::new(Box::new(call_node.to_pretty_print(mast_forest))) }, - MastNode::Dyn => MastNodePrettyPrint::new(Box::new(DynNode)), - MastNode::External(external_node) => MastNodePrettyPrint::new(Box::new(external_node)), + MastNode::Dyn(dyn_node) => { + MastNodePrettyPrint::new(Box::new(dyn_node.to_pretty_print(mast_forest))) + }, + MastNode::External(external_node) => { + MastNodePrettyPrint::new(Box::new(external_node.to_pretty_print(mast_forest))) + }, } } @@ -159,7 +176,7 @@ impl MastNode { MastNode::Split(_) => SplitNode::DOMAIN, MastNode::Loop(_) => LoopNode::DOMAIN, MastNode::Call(call_node) => call_node.domain(), - MastNode::Dyn => DynNode::DOMAIN, + MastNode::Dyn(dyn_node) => dyn_node.domain(), MastNode::External(_) => panic!("Can't fetch domain for an `External` node."), } } @@ -171,22 +188,79 @@ impl MastNode { MastNode::Split(node) => node.digest(), MastNode::Loop(node) => node.digest(), MastNode::Call(node) => node.digest(), - MastNode::Dyn => DynNode.digest(), + MastNode::Dyn(node) => node.digest(), MastNode::External(node) => node.digest(), } } pub fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a { match self { - MastNode::Block(node) => MastNodeDisplay::new(node), + MastNode::Block(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Join(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Split(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Loop(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::Call(node) => MastNodeDisplay::new(node.to_display(mast_forest)), - MastNode::Dyn => MastNodeDisplay::new(DynNode), + MastNode::Dyn(node) => MastNodeDisplay::new(node.to_display(mast_forest)), MastNode::External(node) => MastNodeDisplay::new(node.to_display(mast_forest)), } } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + use MastNode::*; + match self { + Block(_) => &[], + Join(node) => node.before_enter(), + Split(node) => node.before_enter(), + Loop(node) => node.before_enter(), + Call(node) => node.before_enter(), + Dyn(node) => node.before_enter(), + External(node) => node.before_enter(), + } + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + use MastNode::*; + match self { + Block(_) => &[], + Join(node) => node.after_exit(), + Split(node) => node.after_exit(), + Loop(node) => node.after_exit(), + Call(node) => node.after_exit(), + Dyn(node) => node.after_exit(), + External(node) => node.after_exit(), + } + } +} + +/// Mutators +impl MastNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + match self { + MastNode::Block(node) => node.prepend_decorators(decorator_ids), + MastNode::Join(node) => node.set_before_enter(decorator_ids), + MastNode::Split(node) => node.set_before_enter(decorator_ids), + MastNode::Loop(node) => node.set_before_enter(decorator_ids), + MastNode::Call(node) => node.set_before_enter(decorator_ids), + MastNode::Dyn(node) => node.set_before_enter(decorator_ids), + MastNode::External(node) => node.set_before_enter(decorator_ids), + } + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + match self { + MastNode::Block(node) => node.append_decorators(decorator_ids), + MastNode::Join(node) => node.set_after_exit(decorator_ids), + MastNode::Split(node) => node.set_after_exit(decorator_ids), + MastNode::Loop(node) => node.set_after_exit(decorator_ids), + MastNode::Call(node) => node.set_after_exit(decorator_ids), + MastNode::Dyn(node) => node.set_after_exit(decorator_ids), + MastNode::External(node) => node.set_after_exit(decorator_ids), + } + } } // PRETTY PRINTING @@ -202,7 +276,7 @@ impl<'a> MastNodePrettyPrint<'a> { } } -impl<'a> PrettyPrint for MastNodePrettyPrint<'a> { +impl PrettyPrint for MastNodePrettyPrint<'_> { fn render(&self) -> Document { self.node_pretty_print.render() } @@ -218,7 +292,7 @@ impl<'a> MastNodeDisplay<'a> { } } -impl<'a> fmt::Display for MastNodeDisplay<'a> { +impl fmt::Display for MastNodeDisplay<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.node_display.fmt(f) } diff --git a/core/src/mast/node/split_node.rs b/core/src/mast/node/split_node.rs index 635049157f..8a46fcdc70 100644 --- a/core/src/mast/node/split_node.rs +++ b/core/src/mast/node/split_node.rs @@ -1,3 +1,4 @@ +use alloc::vec::Vec; use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt}; @@ -5,7 +6,7 @@ use miden_formatting::prettier::PrettyPrint; use crate::{ chiplets::hasher, - mast::{MastForest, MastForestError, MastNodeId}, + mast::{DecoratorId, MastForest, MastForestError, MastNodeId}, OPCODE_SPLIT, }; @@ -22,6 +23,8 @@ use crate::{ pub struct SplitNode { branches: [MastNodeId; 2], digest: RpoDigest, + before_enter: Vec, + after_exit: Vec, } /// Constants @@ -49,13 +52,23 @@ impl SplitNode { hasher::merge_in_domain(&[if_branch_hash, else_branch_hash], Self::DOMAIN) }; - Ok(Self { branches, digest }) + Ok(Self { + branches, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + }) } /// Returns a new [`SplitNode`] from values that are assumed to be correct. /// Should only be used when the source of the inputs is trusted (e.g. deserialization). pub fn new_unsafe(branches: [MastNodeId; 2], digest: RpoDigest) -> Self { - Self { branches, digest } + Self { + branches, + digest, + before_enter: Vec::new(), + after_exit: Vec::new(), + } } } @@ -85,6 +98,29 @@ impl SplitNode { pub fn on_false(&self) -> MastNodeId { self.branches[1] } + + /// Returns the decorators to be executed before this node is executed. + pub fn before_enter(&self) -> &[DecoratorId] { + &self.before_enter + } + + /// Returns the decorators to be executed after this node is executed. + pub fn after_exit(&self) -> &[DecoratorId] { + &self.after_exit + } +} + +/// Mutators +impl SplitNode { + /// Sets the list of decorators to be executed before this node. + pub fn set_before_enter(&mut self, decorator_ids: Vec) { + self.before_enter = decorator_ids; + } + + /// Sets the list of decorators to be executed after this node. + pub fn set_after_exit(&mut self, decorator_ids: Vec) { + self.after_exit = decorator_ids; + } } // PRETTY PRINTING @@ -108,21 +144,53 @@ struct SplitNodePrettyPrint<'a> { mast_forest: &'a MastForest, } -impl<'a> PrettyPrint for SplitNodePrettyPrint<'a> { +impl PrettyPrint for SplitNodePrettyPrint<'_> { #[rustfmt::skip] fn render(&self) -> crate::prettier::Document { use crate::prettier::*; + let pre_decorators = { + let mut pre_decorators = self + .split_node + .before_enter() + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + if !pre_decorators.is_empty() { + pre_decorators += nl(); + } + + pre_decorators + }; + + let post_decorators = { + let mut post_decorators = self + .split_node + .after_exit() + .iter() + .map(|&decorator_id| self.mast_forest[decorator_id].render()) + .reduce(|acc, doc| acc + const_text(" ") + doc) + .unwrap_or_default(); + if !post_decorators.is_empty() { + post_decorators = nl() + post_decorators; + } + + post_decorators + }; + let true_branch = self.mast_forest[self.split_node.on_true()].to_pretty_print(self.mast_forest); let false_branch = self.mast_forest[self.split_node.on_false()].to_pretty_print(self.mast_forest); - let mut doc = indent(4, const_text("if.true") + nl() + true_branch.render()) + nl(); + let mut doc = pre_decorators; + doc += indent(4, const_text("if.true") + nl() + true_branch.render()) + nl(); doc += indent(4, const_text("else") + nl() + false_branch.render()); - doc + nl() + const_text("end") + doc += nl() + const_text("end"); + doc + post_decorators } } -impl<'a> fmt::Display for SplitNodePrettyPrint<'a> { +impl fmt::Display for SplitNodePrettyPrint<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::prettier::PrettyPrint; self.pretty_print(f) diff --git a/core/src/mast/node_fingerprint.rs b/core/src/mast/node_fingerprint.rs new file mode 100644 index 0000000000..6bf41a2ce1 --- /dev/null +++ b/core/src/mast/node_fingerprint.rs @@ -0,0 +1,198 @@ +use alloc::{collections::BTreeMap, vec::Vec}; + +use miden_crypto::hash::{ + blake::{Blake3Digest, Blake3_256}, + rpo::RpoDigest, + Digest, +}; + +use crate::{ + mast::{DecoratorId, MastForest, MastForestError, MastNode, MastNodeId}, + Operation, +}; + +// MAST NODE EQUALITY +// ================================================================================================ + +pub type DecoratorFingerprint = Blake3Digest<32>; + +/// Represents the hash used to test for equality between [`MastNode`]s. +/// +/// The decorator root will be `None` if and only if there are no decorators attached to the node, +/// and all children have no decorator roots (meaning that there are no decorators in all the +/// descendants). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct MastNodeFingerprint { + mast_root: RpoDigest, + decorator_root: Option, +} + +// ------------------------------------------------------------------------------------------------ +/// Constructors +impl MastNodeFingerprint { + /// Creates a new [`MastNodeFingerprint`] from the given MAST root with an empty decorator root. + pub fn new(mast_root: RpoDigest) -> Self { + Self { mast_root, decorator_root: None } + } + + /// Creates a new [`MastNodeFingerprint`] from the given MAST root and the given + /// [`DecoratorFingerprint`]. + pub fn with_decorator_root(mast_root: RpoDigest, decorator_root: DecoratorFingerprint) -> Self { + Self { + mast_root, + decorator_root: Some(decorator_root), + } + } + + /// Creates a [`MastNodeFingerprint`] from a [`MastNode`]. + /// + /// The `hash_by_node_id` map must contain all children of the node for efficient lookup of + /// their fingerprints. This function returns an error if a child of the given `node` is not in + /// this map. + pub fn from_mast_node( + forest: &MastForest, + hash_by_node_id: &BTreeMap, + node: &MastNode, + ) -> Result { + match node { + MastNode::Block(node) => { + let mut bytes_to_hash = Vec::new(); + + for &(idx, decorator_id) in node.decorators() { + bytes_to_hash.extend(idx.to_le_bytes()); + bytes_to_hash.extend(forest[decorator_id].fingerprint().as_bytes()); + } + + // Add any `Assert`, `U32assert2` and `MpVerify` opcodes present, since these are + // not included in the MAST root. + for (op_idx, op) in node.operations().enumerate() { + if let Operation::U32assert2(inner_value) + | Operation::Assert(inner_value) + | Operation::MpVerify(inner_value) = op + { + let op_idx: u32 = op_idx + .try_into() + .expect("there are more than 2^{32}-1 operations in basic block"); + + // we include the opcode to differentiate between `Assert` and `U32assert2` + bytes_to_hash.push(op.op_code()); + // we include the operation index to distinguish between basic blocks that + // would have the same assert instructions, but in a different order + bytes_to_hash.extend(op_idx.to_le_bytes()); + bytes_to_hash.extend(inner_value.to_le_bytes()); + } + } + + if bytes_to_hash.is_empty() { + Ok(MastNodeFingerprint::new(node.digest())) + } else { + let decorator_root = Blake3_256::hash(&bytes_to_hash); + Ok(MastNodeFingerprint::with_decorator_root(node.digest(), decorator_root)) + } + }, + MastNode::Join(node) => fingerprint_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.first(), node.second()], + node.digest(), + ), + MastNode::Split(node) => fingerprint_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.on_true(), node.on_false()], + node.digest(), + ), + MastNode::Loop(node) => fingerprint_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.body()], + node.digest(), + ), + MastNode::Call(node) => fingerprint_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[node.callee()], + node.digest(), + ), + MastNode::Dyn(node) => fingerprint_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[], + node.digest(), + ), + MastNode::External(node) => fingerprint_from_parts( + forest, + hash_by_node_id, + node.before_enter(), + node.after_exit(), + &[], + node.digest(), + ), + } + } +} + +// ------------------------------------------------------------------------------------------------ +/// Accessors +impl MastNodeFingerprint { + pub fn mast_root(&self) -> &RpoDigest { + &self.mast_root + } +} + +fn fingerprint_from_parts( + forest: &MastForest, + hash_by_node_id: &BTreeMap, + before_enter_ids: &[DecoratorId], + after_exit_ids: &[DecoratorId], + children_ids: &[MastNodeId], + node_digest: RpoDigest, +) -> Result { + let pre_decorator_hash_bytes = + before_enter_ids.iter().flat_map(|&id| forest[id].fingerprint().as_bytes()); + let post_decorator_hash_bytes = + after_exit_ids.iter().flat_map(|&id| forest[id].fingerprint().as_bytes()); + + let children_decorator_roots = children_ids + .iter() + .filter_map(|child_id| { + hash_by_node_id + .get(child_id) + .ok_or(MastForestError::ChildFingerprintMissing(*child_id)) + .map(|child_fingerprint| child_fingerprint.decorator_root) + .transpose() + }) + .collect::, MastForestError>>()?; + + // Reminder: the `MastNodeFingerprint`'s decorator root will be `None` if and only if there are + // no decorators attached to the node, and all children have no decorator roots (meaning + // that there are no decorators in all the descendants). + if pre_decorator_hash_bytes.clone().next().is_none() + && post_decorator_hash_bytes.clone().next().is_none() + && children_decorator_roots.is_empty() + { + Ok(MastNodeFingerprint::new(node_digest)) + } else { + let decorator_bytes_to_hash: Vec = pre_decorator_hash_bytes + .chain(post_decorator_hash_bytes) + .chain( + children_decorator_roots + .into_iter() + .flat_map(|decorator_root| decorator_root.as_bytes()), + ) + .collect(); + + let decorator_root = Blake3_256::hash(&decorator_bytes_to_hash); + Ok(MastNodeFingerprint::with_decorator_root(node_digest, decorator_root)) + } +} diff --git a/core/src/mast/serialization/basic_block_data_builder.rs b/core/src/mast/serialization/basic_block_data_builder.rs deleted file mode 100644 index add9776a17..0000000000 --- a/core/src/mast/serialization/basic_block_data_builder.rs +++ /dev/null @@ -1,189 +0,0 @@ -use alloc::{collections::BTreeMap, vec::Vec}; - -use miden_crypto::hash::blake::{Blake3Digest, Blake3_256}; -use winter_utils::{ByteWriter, Serializable}; - -use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex}; -use crate::{ - mast::{BasicBlockNode, OperationOrDecorator}, - AdviceInjector, DebugOptions, Decorator, SignatureKind, -}; - -// BASIC BLOCK DATA BUILDER -// ================================================================================================ - -/// Builds the `data` section of a serialized [`crate::mast::MastForest`]. -#[derive(Debug, Default)] -pub struct BasicBlockDataBuilder { - data: Vec, - string_table_builder: StringTableBuilder, -} - -/// Constructors -impl BasicBlockDataBuilder { - pub fn new() -> Self { - Self::default() - } -} - -/// Accessors -impl BasicBlockDataBuilder { - /// Returns the current offset into the data buffer. - pub fn get_offset(&self) -> DataOffset { - self.data.len() as DataOffset - } -} - -/// Mutators -impl BasicBlockDataBuilder { - /// Encodes a [`BasicBlockNode`] into the serialized [`crate::mast::MastForest`] data field. - pub fn encode_basic_block(&mut self, basic_block: &BasicBlockNode) { - // 2nd part of `mast_node_to_info()` (inside the match) - for op_or_decorator in basic_block.iter() { - match op_or_decorator { - OperationOrDecorator::Operation(operation) => operation.write_into(&mut self.data), - OperationOrDecorator::Decorator(decorator) => self.encode_decorator(decorator), - } - } - } - - /// Returns the serialized [`crate::mast::MastForest`] data field, as well as the string table. - pub fn into_parts(mut self) -> (Vec, Vec) { - let string_table = self.string_table_builder.into_table(&mut self.data); - (self.data, string_table) - } -} - -/// Helpers -impl BasicBlockDataBuilder { - fn encode_decorator(&mut self, decorator: &Decorator) { - // Set the first byte to the decorator discriminant. - { - let decorator_variant: EncodedDecoratorVariant = decorator.into(); - self.data.push(decorator_variant.discriminant()); - } - - // For decorators that have extra data, encode it in `data` and `strings`. - match decorator { - Decorator::Advice(advice_injector) => match advice_injector { - AdviceInjector::MapValueToStack { include_len, key_offset } => { - self.data.write_bool(*include_len); - self.data.write_usize(*key_offset); - }, - AdviceInjector::HdwordToMap { domain } => { - self.data.extend(domain.as_int().to_le_bytes()) - }, - - // Note: Since there is only 1 variant, we don't need to write any extra bytes. - AdviceInjector::SigToStack { kind } => match kind { - SignatureKind::RpoFalcon512 => (), - }, - AdviceInjector::MerkleNodeMerge - | AdviceInjector::MerkleNodeToStack - | AdviceInjector::UpdateMerkleNode - | AdviceInjector::U64Div - | AdviceInjector::Ext2Inv - | AdviceInjector::Ext2Intt - | AdviceInjector::SmtGet - | AdviceInjector::SmtSet - | AdviceInjector::SmtPeek - | AdviceInjector::U32Clz - | AdviceInjector::U32Ctz - | AdviceInjector::U32Clo - | AdviceInjector::U32Cto - | AdviceInjector::ILog2 - | AdviceInjector::MemToMap - | AdviceInjector::HpermToMap => (), - }, - Decorator::AsmOp(assembly_op) => { - self.data.push(assembly_op.num_cycles()); - self.data.write_bool(assembly_op.should_break()); - - // source location - let loc = assembly_op.location(); - self.data.write_bool(loc.is_some()); - if let Some(loc) = loc { - let str_index_in_table = - self.string_table_builder.add_string(loc.path.as_ref()); - self.data.write_usize(str_index_in_table); - self.data.write_u32(loc.start.to_u32()); - self.data.write_u32(loc.end.to_u32()); - } - - // context name - { - let str_index_in_table = - self.string_table_builder.add_string(assembly_op.context_name()); - self.data.write_usize(str_index_in_table); - } - - // op - { - let str_index_in_table = self.string_table_builder.add_string(assembly_op.op()); - self.data.write_usize(str_index_in_table); - } - }, - Decorator::Debug(debug_options) => match debug_options { - DebugOptions::StackTop(value) => self.data.push(*value), - DebugOptions::MemInterval(start, end) => { - self.data.extend(start.to_le_bytes()); - self.data.extend(end.to_le_bytes()); - }, - DebugOptions::LocalInterval(start, second, end) => { - self.data.extend(start.to_le_bytes()); - self.data.extend(second.to_le_bytes()); - self.data.extend(end.to_le_bytes()); - }, - DebugOptions::StackAll | DebugOptions::MemAll => (), - }, - Decorator::Event(value) | Decorator::Trace(value) => { - self.data.extend(value.to_le_bytes()) - }, - } - } -} - -// STRING TABLE BUILDER -// ================================================================================================ - -#[derive(Debug, Default)] -struct StringTableBuilder { - table: Vec, - str_to_index: BTreeMap, StringIndex>, - strings_data: Vec, -} - -impl StringTableBuilder { - pub fn add_string(&mut self, string: &str) -> StringIndex { - if let Some(str_idx) = self.str_to_index.get(&Blake3_256::hash(string.as_bytes())) { - // return already interned string - *str_idx - } else { - // add new string to table - // NOTE: these string refs' offset will need to be shifted again in `into_table()` - let str_offset = self - .strings_data - .len() - .try_into() - .expect("strings table larger than 2^32 bytes"); - - let str_idx = self.table.len(); - - string.write_into(&mut self.strings_data); - self.table.push(str_offset); - self.str_to_index.insert(Blake3_256::hash(string.as_bytes()), str_idx); - - str_idx - } - } - - pub fn into_table(self, data: &mut Vec) -> Vec { - let table_offset: u32 = data - .len() - .try_into() - .expect("MAST forest serialization: data field longer than 2^32 bytes"); - data.extend(self.strings_data); - - self.table.into_iter().map(|str_offset| str_offset + table_offset).collect() - } -} diff --git a/core/src/mast/serialization/basic_block_data_decoder.rs b/core/src/mast/serialization/basic_block_data_decoder.rs deleted file mode 100644 index 35993951e9..0000000000 --- a/core/src/mast/serialization/basic_block_data_decoder.rs +++ /dev/null @@ -1,244 +0,0 @@ -use alloc::{string::String, sync::Arc, vec::Vec}; -use core::cell::RefCell; - -use miden_crypto::Felt; -use winter_utils::{ByteReader, Deserializable, DeserializationError, SliceReader}; - -use super::{decorator::EncodedDecoratorVariant, DataOffset, StringIndex}; -use crate::{ - AdviceInjector, AssemblyOp, DebugOptions, Decorator, DecoratorList, Operation, SignatureKind, -}; - -pub struct BasicBlockDataDecoder<'a> { - data: &'a [u8], - strings: &'a [DataOffset], - /// This field is used to allocate an `Arc` for any string in `strings` where the decoder - /// requests a reference-counted string rather than a fresh allocation as a `String`. - /// - /// Currently, this is only used for debug information (source file names), but most cases - /// where strings are stored in MAST are stored as `Arc` in practice, we just haven't yet - /// updated all of the decoders. - /// - /// We lazily allocate an `Arc` when strings are decoded as an `Arc`, but the underlying - /// string data corresponds to the same index in `strings`. All future requests for a - /// ref-counted string we've allocated an `Arc` for, will clone the `Arc` rather than - /// allocate a fresh string. - refc_strings: Vec>>>, -} - -/// Constructors -impl<'a> BasicBlockDataDecoder<'a> { - pub fn new(data: &'a [u8], strings: &'a [DataOffset]) -> Self { - let mut refc_strings = Vec::with_capacity(strings.len()); - refc_strings.resize(strings.len(), RefCell::new(None)); - Self { data, strings, refc_strings } - } -} - -/// Mutators -impl<'a> BasicBlockDataDecoder<'a> { - pub fn decode_operations_and_decorators( - &self, - offset: DataOffset, - num_to_decode: u32, - ) -> Result<(Vec, DecoratorList), DeserializationError> { - let mut operations: Vec = Vec::new(); - let mut decorators: DecoratorList = Vec::new(); - - let mut data_reader = SliceReader::new(&self.data[offset as usize..]); - for _ in 0..num_to_decode { - let first_byte = data_reader.peek_u8()?; - - if first_byte & 0b1000_0000 == 0 { - // operation. - operations.push(Operation::read_from(&mut data_reader)?); - } else { - // decorator. - let decorator = self.decode_decorator(&mut data_reader)?; - decorators.push((operations.len(), decorator)); - } - } - - Ok((operations, decorators)) - } -} - -/// Helpers -impl<'a> BasicBlockDataDecoder<'a> { - fn decode_decorator( - &self, - data_reader: &mut SliceReader, - ) -> Result { - let discriminant = data_reader.read_u8()?; - - let decorator_variant = EncodedDecoratorVariant::from_discriminant(discriminant) - .ok_or_else(|| { - DeserializationError::InvalidValue(format!( - "invalid decorator variant discriminant: {discriminant}" - )) - })?; - - match decorator_variant { - EncodedDecoratorVariant::AdviceInjectorMerkleNodeMerge => { - Ok(Decorator::Advice(AdviceInjector::MerkleNodeMerge)) - }, - EncodedDecoratorVariant::AdviceInjectorMerkleNodeToStack => { - Ok(Decorator::Advice(AdviceInjector::MerkleNodeToStack)) - }, - EncodedDecoratorVariant::AdviceInjectorUpdateMerkleNode => { - Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) - }, - EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { - let include_len = data_reader.read_bool()?; - let key_offset = data_reader.read_usize()?; - - Ok(Decorator::Advice(AdviceInjector::MapValueToStack { include_len, key_offset })) - }, - EncodedDecoratorVariant::AdviceInjectorU64Div => { - Ok(Decorator::Advice(AdviceInjector::U64Div)) - }, - EncodedDecoratorVariant::AdviceInjectorExt2Inv => { - Ok(Decorator::Advice(AdviceInjector::Ext2Inv)) - }, - EncodedDecoratorVariant::AdviceInjectorExt2Intt => { - Ok(Decorator::Advice(AdviceInjector::Ext2Intt)) - }, - EncodedDecoratorVariant::AdviceInjectorSmtGet => { - Ok(Decorator::Advice(AdviceInjector::SmtGet)) - }, - EncodedDecoratorVariant::AdviceInjectorSmtSet => { - Ok(Decorator::Advice(AdviceInjector::SmtSet)) - }, - EncodedDecoratorVariant::AdviceInjectorSmtPeek => { - Ok(Decorator::Advice(AdviceInjector::SmtPeek)) - }, - EncodedDecoratorVariant::AdviceInjectorU32Clz => { - Ok(Decorator::Advice(AdviceInjector::U32Clz)) - }, - EncodedDecoratorVariant::AdviceInjectorU32Ctz => { - Ok(Decorator::Advice(AdviceInjector::U32Ctz)) - }, - EncodedDecoratorVariant::AdviceInjectorU32Clo => { - Ok(Decorator::Advice(AdviceInjector::U32Clo)) - }, - EncodedDecoratorVariant::AdviceInjectorU32Cto => { - Ok(Decorator::Advice(AdviceInjector::U32Cto)) - }, - EncodedDecoratorVariant::AdviceInjectorILog2 => { - Ok(Decorator::Advice(AdviceInjector::ILog2)) - }, - EncodedDecoratorVariant::AdviceInjectorMemToMap => { - Ok(Decorator::Advice(AdviceInjector::MemToMap)) - }, - EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { - let domain = data_reader.read_u64()?; - let domain = Felt::try_from(domain).map_err(|err| { - DeserializationError::InvalidValue(format!( - "Error when deserializing HdwordToMap decorator domain: {err}" - )) - })?; - - Ok(Decorator::Advice(AdviceInjector::HdwordToMap { domain })) - }, - EncodedDecoratorVariant::AdviceInjectorHpermToMap => { - Ok(Decorator::Advice(AdviceInjector::HpermToMap)) - }, - EncodedDecoratorVariant::AdviceInjectorSigToStack => { - Ok(Decorator::Advice(AdviceInjector::SigToStack { - kind: SignatureKind::RpoFalcon512, - })) - }, - EncodedDecoratorVariant::AssemblyOp => { - let num_cycles = data_reader.read_u8()?; - let should_break = data_reader.read_bool()?; - - // source location - let location = if data_reader.read_bool()? { - let str_index_in_table = data_reader.read_usize()?; - let path = self.read_arc_str(str_index_in_table)?; - let start = data_reader.read_u32()?; - let end = data_reader.read_u32()?; - Some(crate::debuginfo::Location { - path, - start: start.into(), - end: end.into(), - }) - } else { - None - }; - - let context_name = { - let str_index_in_table = data_reader.read_usize()?; - self.read_string(str_index_in_table)? - }; - - let op = { - let str_index_in_table = data_reader.read_usize()?; - self.read_string(str_index_in_table)? - }; - - Ok(Decorator::AsmOp(AssemblyOp::new( - location, - context_name, - num_cycles, - op, - should_break, - ))) - }, - EncodedDecoratorVariant::DebugOptionsStackAll => { - Ok(Decorator::Debug(DebugOptions::StackAll)) - }, - EncodedDecoratorVariant::DebugOptionsStackTop => { - let value = data_reader.read_u8()?; - - Ok(Decorator::Debug(DebugOptions::StackTop(value))) - }, - EncodedDecoratorVariant::DebugOptionsMemAll => { - Ok(Decorator::Debug(DebugOptions::MemAll)) - }, - EncodedDecoratorVariant::DebugOptionsMemInterval => { - let start = data_reader.read_u32()?; - let end = data_reader.read_u32()?; - - Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) - }, - EncodedDecoratorVariant::DebugOptionsLocalInterval => { - let start = data_reader.read_u16()?; - let second = data_reader.read_u16()?; - let end = data_reader.read_u16()?; - - Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) - }, - EncodedDecoratorVariant::Event => { - let value = data_reader.read_u32()?; - - Ok(Decorator::Event(value)) - }, - EncodedDecoratorVariant::Trace => { - let value = data_reader.read_u32()?; - - Ok(Decorator::Trace(value)) - }, - } - } - - fn read_arc_str(&self, str_idx: StringIndex) -> Result, DeserializationError> { - if let Some(cached) = self.refc_strings.get(str_idx).and_then(|cell| cell.borrow().clone()) - { - return Ok(cached); - } - - let string = Arc::from(self.read_string(str_idx)?.into_boxed_str()); - *self.refc_strings[str_idx].borrow_mut() = Some(Arc::clone(&string)); - Ok(string) - } - - fn read_string(&self, str_idx: StringIndex) -> Result { - let str_offset = self.strings.get(str_idx).copied().ok_or_else(|| { - DeserializationError::InvalidValue(format!("invalid index in strings table: {str_idx}")) - })? as usize; - - let mut reader = SliceReader::new(&self.data[str_offset..]); - reader.read() - } -} diff --git a/core/src/mast/serialization/basic_blocks.rs b/core/src/mast/serialization/basic_blocks.rs new file mode 100644 index 0000000000..86cf884717 --- /dev/null +++ b/core/src/mast/serialization/basic_blocks.rs @@ -0,0 +1,102 @@ +use alloc::vec::Vec; + +use winter_utils::{ByteReader, DeserializationError, Serializable, SliceReader}; + +use super::{DecoratorDataOffset, NodeDataOffset}; +use crate::{ + mast::{BasicBlockNode, DecoratorId, MastForest}, + DecoratorList, Operation, +}; + +// BASIC BLOCK DATA BUILDER +// ================================================================================================ + +/// Builds the node `data` section of a serialized [`crate::mast::MastForest`]. +#[derive(Debug, Default)] +pub struct BasicBlockDataBuilder { + node_data: Vec, +} + +/// Constructors +impl BasicBlockDataBuilder { + pub fn new() -> Self { + Self::default() + } +} + +/// Mutators +impl BasicBlockDataBuilder { + /// Encodes a [`BasicBlockNode`] into the serialized [`crate::mast::MastForest`] data field. + pub fn encode_basic_block( + &mut self, + basic_block: &BasicBlockNode, + ) -> (NodeDataOffset, Option) { + let ops_offset = self.node_data.len() as NodeDataOffset; + + let operations: Vec = basic_block.operations().copied().collect(); + operations.write_into(&mut self.node_data); + + if basic_block.decorators().is_empty() { + (ops_offset, None) + } else { + let decorator_data_offset = self.node_data.len() as DecoratorDataOffset; + basic_block.decorators().write_into(&mut self.node_data); + + (ops_offset, Some(decorator_data_offset)) + } + } + + /// Returns the serialized [`crate::mast::MastForest`] node data field. + pub fn finalize(self) -> Vec { + self.node_data + } +} + +// BASIC BLOCK DATA DECODER +// ================================================================================================ + +pub struct BasicBlockDataDecoder<'a> { + node_data: &'a [u8], +} + +/// Constructors +impl<'a> BasicBlockDataDecoder<'a> { + pub fn new(node_data: &'a [u8]) -> Self { + Self { node_data } + } +} + +/// Decoding methods +impl BasicBlockDataDecoder<'_> { + pub fn decode_operations_and_decorators( + &self, + ops_offset: NodeDataOffset, + decorator_list_offset: NodeDataOffset, + mast_forest: &MastForest, + ) -> Result<(Vec, DecoratorList), DeserializationError> { + // Read ops + let mut ops_data_reader = SliceReader::new(&self.node_data[ops_offset as usize..]); + let operations: Vec = ops_data_reader.read()?; + + // read decorators only if there are some + let decorators = if decorator_list_offset == MastForest::MAX_DECORATORS as u32 { + Vec::new() + } else { + let mut decorators_data_reader = + SliceReader::new(&self.node_data[decorator_list_offset as usize..]); + + let num_decorators: usize = decorators_data_reader.read()?; + (0..num_decorators) + .map(|_| { + let decorator_loc: usize = decorators_data_reader.read()?; + let decorator_id = + DecoratorId::from_u32_safe(decorators_data_reader.read()?, mast_forest)?; + + Ok((decorator_loc, decorator_id)) + }) + .collect::>()? + }; + + Ok((operations, decorators)) + } +} diff --git a/core/src/mast/serialization/decorator.rs b/core/src/mast/serialization/decorator.rs index a2f1e84aaa..a8b2041e3c 100644 --- a/core/src/mast/serialization/decorator.rs +++ b/core/src/mast/serialization/decorator.rs @@ -1,13 +1,218 @@ +use alloc::vec::Vec; + +use miden_crypto::Felt; use num_derive::{FromPrimitive, ToPrimitive}; use num_traits::{FromPrimitive, ToPrimitive}; +use winter_utils::{ + ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, +}; + +use super::{ + string_table::{StringTable, StringTableBuilder}, + DecoratorDataOffset, +}; +use crate::{AdviceInjector, AssemblyOp, DebugOptions, Decorator, SignatureKind}; + +/// Represents a serialized [`Decorator`]. +/// +/// The serialized representation of [`DecoratorInfo`] is guaranteed to be fixed width, so that the +/// decorators stored in the `decorators` table of the serialized [`MastForest`] can be accessed +/// quickly by index. +#[derive(Debug)] +pub struct DecoratorInfo { + variant: EncodedDecoratorVariant, + decorator_data_offset: DecoratorDataOffset, +} + +impl DecoratorInfo { + pub fn from_decorator( + decorator: &Decorator, + data_builder: &mut DecoratorDataBuilder, + string_table_builder: &mut StringTableBuilder, + ) -> Self { + let variant = EncodedDecoratorVariant::from(decorator); + let decorator_data_offset = + data_builder.encode_decorator_data(decorator, string_table_builder).unwrap_or(0); + + Self { variant, decorator_data_offset } + } + + pub fn try_into_decorator( + &self, + string_table: &StringTable, + decorator_data: &[u8], + ) -> Result { + // This is safe because for decorators that don't use the offset, `0` is used (and hence + // will never access an element outside). Note that in this implementation, we trust the + // encoder. + let mut data_reader = + SliceReader::new(&decorator_data[self.decorator_data_offset as usize..]); + match self.variant { + EncodedDecoratorVariant::AdviceInjectorMerkleNodeMerge => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeMerge)) + }, + EncodedDecoratorVariant::AdviceInjectorMerkleNodeToStack => { + Ok(Decorator::Advice(AdviceInjector::MerkleNodeToStack)) + }, + EncodedDecoratorVariant::AdviceInjectorUpdateMerkleNode => { + Ok(Decorator::Advice(AdviceInjector::UpdateMerkleNode)) + }, + EncodedDecoratorVariant::AdviceInjectorMapValueToStack => { + let include_len = data_reader.read_bool()?; + let key_offset = data_reader.read_usize()?; + + Ok(Decorator::Advice(AdviceInjector::MapValueToStack { include_len, key_offset })) + }, + EncodedDecoratorVariant::AdviceInjectorU64Div => { + Ok(Decorator::Advice(AdviceInjector::U64Div)) + }, + EncodedDecoratorVariant::AdviceInjectorExt2Inv => { + Ok(Decorator::Advice(AdviceInjector::Ext2Inv)) + }, + EncodedDecoratorVariant::AdviceInjectorExt2Intt => { + Ok(Decorator::Advice(AdviceInjector::Ext2Intt)) + }, + EncodedDecoratorVariant::AdviceInjectorSmtGet => { + Ok(Decorator::Advice(AdviceInjector::SmtGet)) + }, + EncodedDecoratorVariant::AdviceInjectorSmtSet => { + Ok(Decorator::Advice(AdviceInjector::SmtSet)) + }, + EncodedDecoratorVariant::AdviceInjectorSmtPeek => { + Ok(Decorator::Advice(AdviceInjector::SmtPeek)) + }, + EncodedDecoratorVariant::AdviceInjectorU32Clz => { + Ok(Decorator::Advice(AdviceInjector::U32Clz)) + }, + EncodedDecoratorVariant::AdviceInjectorU32Ctz => { + Ok(Decorator::Advice(AdviceInjector::U32Ctz)) + }, + EncodedDecoratorVariant::AdviceInjectorU32Clo => { + Ok(Decorator::Advice(AdviceInjector::U32Clo)) + }, + EncodedDecoratorVariant::AdviceInjectorU32Cto => { + Ok(Decorator::Advice(AdviceInjector::U32Cto)) + }, + EncodedDecoratorVariant::AdviceInjectorILog2 => { + Ok(Decorator::Advice(AdviceInjector::ILog2)) + }, + EncodedDecoratorVariant::AdviceInjectorMemToMap => { + Ok(Decorator::Advice(AdviceInjector::MemToMap)) + }, + EncodedDecoratorVariant::AdviceInjectorHdwordToMap => { + let domain = data_reader.read_u64()?; + let domain = Felt::try_from(domain).map_err(|err| { + DeserializationError::InvalidValue(format!( + "Error when deserializing HdwordToMap decorator domain: {err}" + )) + })?; + + Ok(Decorator::Advice(AdviceInjector::HdwordToMap { domain })) + }, + EncodedDecoratorVariant::AdviceInjectorHpermToMap => { + Ok(Decorator::Advice(AdviceInjector::HpermToMap)) + }, + EncodedDecoratorVariant::AdviceInjectorSigToStack => { + Ok(Decorator::Advice(AdviceInjector::SigToStack { + kind: SignatureKind::RpoFalcon512, + })) + }, + EncodedDecoratorVariant::AssemblyOp => { + let num_cycles = data_reader.read_u8()?; + let should_break = data_reader.read_bool()?; + + // source location + let location = if data_reader.read_bool()? { + let str_index_in_table = data_reader.read_usize()?; + let path = string_table.read_arc_str(str_index_in_table)?; + let start = data_reader.read_u32()?; + let end = data_reader.read_u32()?; + Some(crate::debuginfo::Location { + path, + start: start.into(), + end: end.into(), + }) + } else { + None + }; + + let context_name = { + let str_index_in_table = data_reader.read_usize()?; + string_table.read_string(str_index_in_table)? + }; + + let op = { + let str_index_in_table = data_reader.read_usize()?; + string_table.read_string(str_index_in_table)? + }; + + Ok(Decorator::AsmOp(AssemblyOp::new( + location, + context_name, + num_cycles, + op, + should_break, + ))) + }, + EncodedDecoratorVariant::DebugOptionsStackAll => { + Ok(Decorator::Debug(DebugOptions::StackAll)) + }, + EncodedDecoratorVariant::DebugOptionsStackTop => { + let value = data_reader.read_u8()?; + + Ok(Decorator::Debug(DebugOptions::StackTop(value))) + }, + EncodedDecoratorVariant::DebugOptionsMemAll => { + Ok(Decorator::Debug(DebugOptions::MemAll)) + }, + EncodedDecoratorVariant::DebugOptionsMemInterval => { + let start = data_reader.read_u32()?; + let end = data_reader.read_u32()?; + + Ok(Decorator::Debug(DebugOptions::MemInterval(start, end))) + }, + EncodedDecoratorVariant::DebugOptionsLocalInterval => { + let start = data_reader.read_u16()?; + let second = data_reader.read_u16()?; + let end = data_reader.read_u16()?; + + Ok(Decorator::Debug(DebugOptions::LocalInterval(start, second, end))) + }, + EncodedDecoratorVariant::Trace => { + let value = data_reader.read_u32()?; + + Ok(Decorator::Trace(value)) + }, + } + } +} + +impl Serializable for DecoratorInfo { + fn write_into(&self, target: &mut W) { + let Self { variant, decorator_data_offset } = self; + + variant.write_into(target); + decorator_data_offset.write_into(target); + } +} -use crate::{AdviceInjector, DebugOptions, Decorator}; +impl Deserializable for DecoratorInfo { + fn read_from(source: &mut R) -> Result { + let variant = source.read()?; + let decorator_data_offset = source.read()?; + + Ok(Self { variant, decorator_data_offset }) + } +} + +// ENCODED DATA VARIANT +// =============================================================================================== /// Stores all the possible [`Decorator`] variants, without any associated data. /// /// This is effectively equivalent to a set of constants, and designed to convert between variant /// discriminant and enum variant conveniently. -#[derive(FromPrimitive, ToPrimitive)] +#[derive(Debug, FromPrimitive, ToPrimitive)] #[repr(u8)] pub enum EncodedDecoratorVariant { AdviceInjectorMerkleNodeMerge, @@ -35,7 +240,6 @@ pub enum EncodedDecoratorVariant { DebugOptionsMemAll, DebugOptionsMemInterval, DebugOptionsLocalInterval, - Event, Trace, } @@ -45,14 +249,12 @@ impl EncodedDecoratorVariant { /// To distinguish them from [`crate::Operation`] discriminants, the most significant bit of /// decorator discriminant is always set to 1. pub fn discriminant(&self) -> u8 { - let discriminant = self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]"); - - discriminant | 0b1000_0000 + self.to_u8().expect("guaranteed to fit in a `u8` due to #[repr(u8)]") } /// The inverse operation of [`Self::discriminant`]. pub fn from_discriminant(discriminant: u8) -> Option { - Self::from_u8(discriminant & 0b0111_1111) + Self::from_u8(discriminant) } } @@ -90,8 +292,149 @@ impl From<&Decorator> for EncodedDecoratorVariant { DebugOptions::MemInterval(..) => Self::DebugOptionsMemInterval, DebugOptions::LocalInterval(..) => Self::DebugOptionsLocalInterval, }, - Decorator::Event(_) => Self::Event, Decorator::Trace(_) => Self::Trace, } } } + +impl Serializable for EncodedDecoratorVariant { + fn write_into(&self, target: &mut W) { + self.discriminant().write_into(target); + } +} + +impl Deserializable for EncodedDecoratorVariant { + fn read_from(source: &mut R) -> Result { + let discriminant: u8 = source.read_u8()?; + + Self::from_discriminant(discriminant).ok_or_else(|| { + DeserializationError::InvalidValue(format!( + "invalid decorator discriminant: {discriminant}" + )) + }) + } +} + +// DECORATOR DATA BUILDER +// =============================================================================================== + +/// Builds the decorator `data` section of a serialized [`crate::mast::MastForest`]. +#[derive(Debug, Default)] +pub struct DecoratorDataBuilder { + decorator_data: Vec, +} + +/// Constructors +impl DecoratorDataBuilder { + pub fn new() -> Self { + Self::default() + } +} + +/// Mutators +impl DecoratorDataBuilder { + /// If a decorator has extra data to store, encode it in internal data buffer, and return the + /// offset of the newly added data. If not, return `None`. + pub fn encode_decorator_data( + &mut self, + decorator: &Decorator, + string_table_builder: &mut StringTableBuilder, + ) -> Option { + let data_offset = self.decorator_data.len() as DecoratorDataOffset; + + match decorator { + Decorator::Advice(advice_injector) => match advice_injector { + AdviceInjector::MapValueToStack { include_len, key_offset } => { + self.decorator_data.write_bool(*include_len); + self.decorator_data.write_usize(*key_offset); + + Some(data_offset) + }, + AdviceInjector::HdwordToMap { domain } => { + self.decorator_data.extend(domain.as_int().to_le_bytes()); + + Some(data_offset) + }, + + // Note: Since there is only 1 variant, we don't need to write any extra bytes. + AdviceInjector::SigToStack { kind } => match kind { + SignatureKind::RpoFalcon512 => None, + }, + AdviceInjector::MerkleNodeMerge + | AdviceInjector::MerkleNodeToStack + | AdviceInjector::UpdateMerkleNode + | AdviceInjector::U64Div + | AdviceInjector::Ext2Inv + | AdviceInjector::Ext2Intt + | AdviceInjector::SmtGet + | AdviceInjector::SmtSet + | AdviceInjector::SmtPeek + | AdviceInjector::U32Clz + | AdviceInjector::U32Ctz + | AdviceInjector::U32Clo + | AdviceInjector::U32Cto + | AdviceInjector::ILog2 + | AdviceInjector::MemToMap + | AdviceInjector::HpermToMap => None, + }, + Decorator::AsmOp(assembly_op) => { + self.decorator_data.push(assembly_op.num_cycles()); + self.decorator_data.write_bool(assembly_op.should_break()); + + // source location + let loc = assembly_op.location(); + self.decorator_data.write_bool(loc.is_some()); + if let Some(loc) = loc { + let str_offset = string_table_builder.add_string(loc.path.as_ref()); + self.decorator_data.write_usize(str_offset); + self.decorator_data.write_u32(loc.start.to_u32()); + self.decorator_data.write_u32(loc.end.to_u32()); + } + + // context name + { + let str_offset = string_table_builder.add_string(assembly_op.context_name()); + self.decorator_data.write_usize(str_offset); + } + + // op + { + let str_index_in_table = string_table_builder.add_string(assembly_op.op()); + self.decorator_data.write_usize(str_index_in_table); + } + + Some(data_offset) + }, + Decorator::Debug(debug_options) => match debug_options { + DebugOptions::StackTop(value) => { + self.decorator_data.push(*value); + Some(data_offset) + }, + DebugOptions::MemInterval(start, end) => { + self.decorator_data.extend(start.to_le_bytes()); + self.decorator_data.extend(end.to_le_bytes()); + + Some(data_offset) + }, + DebugOptions::LocalInterval(start, second, end) => { + self.decorator_data.extend(start.to_le_bytes()); + self.decorator_data.extend(second.to_le_bytes()); + self.decorator_data.extend(end.to_le_bytes()); + + Some(data_offset) + }, + DebugOptions::StackAll | DebugOptions::MemAll => None, + }, + Decorator::Trace(value) => { + self.decorator_data.extend(value.to_le_bytes()); + + Some(data_offset) + }, + } + } + + /// Returns the serialized [`crate::mast::MastForest`] decorator data field. + pub fn finalize(self) -> Vec { + self.decorator_data + } +} diff --git a/core/src/mast/serialization/info.rs b/core/src/mast/serialization/info.rs index d26543720e..fa9efe3371 100644 --- a/core/src/mast/serialization/info.rs +++ b/core/src/mast/serialization/info.rs @@ -1,7 +1,7 @@ use miden_crypto::hash::rpo::RpoDigest; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use super::{basic_block_data_decoder::BasicBlockDataDecoder, DataOffset}; +use super::{basic_blocks::BasicBlockDataDecoder, NodeDataOffset}; use crate::mast::{ BasicBlockNode, CallNode, JoinNode, LoopNode, MastForest, MastNode, MastNodeId, SplitNode, }; @@ -21,55 +21,77 @@ pub struct MastNodeInfo { } impl MastNodeInfo { - pub fn new(mast_node: &MastNode, basic_block_offset: DataOffset) -> Self { - let ty = MastNodeType::new(mast_node, basic_block_offset); + /// Constructs a new [`MastNodeInfo`] from a [`MastNode`], along with an `ops_offset` and + /// `decorator_list_offset` in the case of [`BasicBlockNode`]. + /// + /// If the represented [`MastNode`] is a [`BasicBlockNode`] that has an empty decorator list, + /// use `MastForest::MAX_DECORATORS` for the value of `decorator_list_offset`. For non-basic + /// block nodes, `ops_offset` and `decorator_list_offset` are ignored, and should be set to 0. + pub fn new( + mast_node: &MastNode, + ops_offset: NodeDataOffset, + decorator_list_offset: NodeDataOffset, + ) -> Self { + if !matches!(mast_node, &MastNode::Block(_)) { + debug_assert_eq!(ops_offset, 0); + debug_assert_eq!(decorator_list_offset, 0); + } + + let ty = MastNodeType::new(mast_node, ops_offset, decorator_list_offset); Self { ty, digest: mast_node.digest() } } + /// Attempts to convert this [`MastNodeInfo`] into a [`MastNode`] for the given `mast_forest`. + /// + /// The `node_count` is the total expected number of nodes in the [`MastForest`] **after + /// deserialization**. pub fn try_into_mast_node( self, mast_forest: &MastForest, + node_count: usize, basic_block_data_decoder: &BasicBlockDataDecoder, ) -> Result { match self.ty { - MastNodeType::Block { - offset, - len: num_operations_and_decorators, - } => { + MastNodeType::Block { ops_offset, decorator_list_offset } => { let (operations, decorators) = basic_block_data_decoder - .decode_operations_and_decorators(offset, num_operations_and_decorators)?; + .decode_operations_and_decorators( + ops_offset, + decorator_list_offset, + mast_forest, + )?; let block = BasicBlockNode::new_unsafe(operations, decorators, self.digest); Ok(MastNode::Block(block)) }, MastNodeType::Join { left_child_id, right_child_id } => { - let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?; - let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?; + let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?; + let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?; let join = JoinNode::new_unsafe([left_child, right_child], self.digest); Ok(MastNode::Join(join)) }, MastNodeType::Split { if_branch_id, else_branch_id } => { - let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?; - let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?; + let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?; + let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?; let split = SplitNode::new_unsafe([if_branch, else_branch], self.digest); Ok(MastNode::Split(split)) }, MastNodeType::Loop { body_id } => { - let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?; + let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?; let loop_node = LoopNode::new_unsafe(body_id, self.digest); Ok(MastNode::Loop(loop_node)) }, MastNodeType::Call { callee_id } => { - let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?; let call = CallNode::new_unsafe(callee_id, self.digest); Ok(MastNode::Call(call)) }, MastNodeType::SysCall { callee_id } => { - let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?; + let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?; let syscall = CallNode::new_syscall_unsafe(callee_id, self.digest); Ok(MastNode::Call(syscall)) }, MastNodeType::Dyn => Ok(MastNode::new_dyn()), + MastNodeType::Dyncall => Ok(MastNode::new_dyncall()), MastNodeType::External => Ok(MastNode::new_external(self.digest)), } } @@ -103,7 +125,8 @@ const BLOCK: u8 = 3; const CALL: u8 = 4; const SYSCALL: u8 = 5; const DYN: u8 = 6; -const EXTERNAL: u8 = 7; +const DYNCALL: u8 = 7; +const EXTERNAL: u8 = 8; /// Represents the variant of a [`MastNode`], as well as any additional data. For example, for more /// efficient decoding, and because of the frequency with which these node types appear, we directly @@ -126,10 +149,10 @@ pub enum MastNodeType { body_id: u32, } = LOOP, Block { - /// Offset of the basic block in the data segment - offset: u32, - /// The number of operations and decorators in the basic block - len: u32, + // offset of operations in node data + ops_offset: u32, + // offset of DecoratorList in node data + decorator_list_offset: u32, } = BLOCK, Call { callee_id: u32, @@ -138,21 +161,25 @@ pub enum MastNodeType { callee_id: u32, } = SYSCALL, Dyn = DYN, + Dyncall = DYNCALL, External = EXTERNAL, } /// Constructors impl MastNodeType { /// Constructs a new [`MastNodeType`] from a [`MastNode`]. - pub fn new(mast_node: &MastNode, basic_block_offset: u32) -> Self { + /// + /// If the represented [`MastNode`] is a [`BasicBlockNode`] that has an empty decorator list, + /// use `MastForest::MAX_DECORATORS` for the value of `decorator_list_offset`. + pub fn new( + mast_node: &MastNode, + ops_offset: NodeDataOffset, + decorator_list_offset: NodeDataOffset, + ) -> Self { use MastNode::*; match mast_node { - Block(block_node) => { - let len = block_node.num_operations_and_decorators(); - - Self::Block { len, offset: basic_block_offset } - }, + Block(_block_node) => Self::Block { decorator_list_offset, ops_offset }, Join(join_node) => Self::Join { left_child_id: join_node.first().0, right_child_id: join_node.second().0, @@ -169,7 +196,13 @@ impl MastNodeType { Self::Call { callee_id: call_node.callee().0 } } }, - Dyn => Self::Dyn, + Dyn(dyn_node) => { + if dyn_node.is_dyncall() { + Self::Dyncall + } else { + Self::Dyn + } + }, External(_) => Self::External, } } @@ -190,10 +223,13 @@ impl Serializable for MastNodeType { else_branch_id: else_branch, } => Self::encode_u32_pair(if_branch, else_branch), MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(body), - MastNodeType::Block { offset, len } => Self::encode_u32_pair(offset, len), + MastNodeType::Block { ops_offset, decorator_list_offset } => { + Self::encode_u32_pair(ops_offset, decorator_list_offset) + }, MastNodeType::Call { callee_id } => Self::encode_u32_payload(callee_id), MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(callee_id), MastNodeType::Dyn => 0, + MastNodeType::Dyncall => 0, MastNodeType::External => 0, }; @@ -264,8 +300,8 @@ impl Deserializable for MastNodeType { Ok(Self::Loop { body_id }) }, BLOCK => { - let (offset, len) = Self::decode_u32_pair(payload); - Ok(Self::Block { offset, len }) + let (ops_offset, decorator_list_offset) = Self::decode_u32_pair(payload); + Ok(Self::Block { ops_offset, decorator_list_offset }) }, CALL => { let callee_id = Self::decode_u32_payload(payload)?; @@ -276,6 +312,7 @@ impl Deserializable for MastNodeType { Ok(Self::SysCall { callee_id }) }, DYN => Ok(Self::Dyn), + DYNCALL => Ok(Self::Dyncall), EXTERNAL => Ok(Self::External), _ => Err(DeserializationError::InvalidValue(format!( "Invalid tag for MAST node: {discriminant}" diff --git a/core/src/mast/serialization/mod.rs b/core/src/mast/serialization/mod.rs index 334275698f..e76f775c5e 100644 --- a/core/src/mast/serialization/mod.rs +++ b/core/src/mast/serialization/mod.rs @@ -1,19 +1,46 @@ +//! The serialization format of MastForest is as follows: +//! +//! (Metadata) +//! - MAGIC +//! - VERSION +//! +//! (lengths) +//! - decorators length (`usize`) +//! - nodes length (`usize`) +//! +//! (procedure roots) +//! - procedure roots (`Vec`) +//! +//! (raw data) +//! - Decorator data +//! - Node data +//! - String table +//! +//! (info structs) +//! - decorator infos (`Vec`) +//! - MAST node infos (`Vec`) +//! +//! (before enter and after exit decorators) +//! - before enter decorators (`Vec<(MastNodeId, Vec)>`) +//! - after exit decorators (`Vec<(MastNodeId, Vec)>`) + use alloc::vec::Vec; +use decorator::{DecoratorDataBuilder, DecoratorInfo}; +use string_table::{StringTable, StringTableBuilder}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; -use super::{MastForest, MastNode, MastNodeId}; +use super::{DecoratorId, MastForest, MastNode, MastNodeId}; mod decorator; mod info; use info::MastNodeInfo; -mod basic_block_data_builder; -use basic_block_data_builder::BasicBlockDataBuilder; +mod basic_blocks; +use basic_blocks::{BasicBlockDataBuilder, BasicBlockDataDecoder}; -mod basic_block_data_decoder; -use basic_block_data_decoder::BasicBlockDataDecoder; +mod string_table; #[cfg(test)] mod tests; @@ -21,10 +48,16 @@ mod tests; // TYPE ALIASES // ================================================================================================ -/// Specifies an offset into the `data` section of an encoded [`MastForest`]. -type DataOffset = u32; +/// Specifies an offset into the `node_data` section of an encoded [`MastForest`]. +type NodeDataOffset = u32; + +/// Specifies an offset into the `decorator_data` section of an encoded [`MastForest`]. +type DecoratorDataOffset = u32; + +/// Specifies an offset into the `strings_data` section of an encoded [`MastForest`]. +type StringDataOffset = usize; -/// Specifies an offset into the `strings` table of an encoded [`MastForest`] +/// Specifies an offset into the strings table of an encoded [`MastForest`]. type StringIndex = usize; // CONSTANTS @@ -46,43 +79,88 @@ const VERSION: [u8; 3] = [0, 0, 0]; impl Serializable for MastForest { fn write_into(&self, target: &mut W) { let mut basic_block_data_builder = BasicBlockDataBuilder::new(); + let mut decorator_data_builder = DecoratorDataBuilder::new(); + let mut string_table_builder = StringTableBuilder::default(); + + // Set up "before enter" and "after exit" decorators by `MastNodeId` + let mut before_enter_decorators: Vec<(usize, Vec)> = Vec::new(); + let mut after_exit_decorators: Vec<(usize, Vec)> = Vec::new(); // magic & version target.write_bytes(MAGIC); target.write_bytes(&VERSION); - // node count + // decorator & node counts + target.write_usize(self.decorators.len()); target.write_usize(self.nodes.len()); // roots let roots: Vec = self.roots.iter().map(u32::from).collect(); roots.write_into(target); + // decorators + let decorator_infos: Vec = self + .decorators + .iter() + .map(|decorator| { + DecoratorInfo::from_decorator( + decorator, + &mut decorator_data_builder, + &mut string_table_builder, + ) + }) + .collect(); + // Prepare MAST node infos, but don't store them yet. We store them at the end to make // deserialization more efficient. let mast_node_infos: Vec = self .nodes .iter() - .map(|mast_node| { - let mast_node_info = - MastNodeInfo::new(mast_node, basic_block_data_builder.get_offset()); - - if let MastNode::Block(basic_block) = mast_node { - basic_block_data_builder.encode_basic_block(basic_block); + .enumerate() + .map(|(mast_node_id, mast_node)| { + if !mast_node.before_enter().is_empty() { + before_enter_decorators.push((mast_node_id, mast_node.before_enter().to_vec())); } + if !mast_node.after_exit().is_empty() { + after_exit_decorators.push((mast_node_id, mast_node.after_exit().to_vec())); + } + + let (ops_offset, decorator_data_offset) = if let MastNode::Block(basic_block) = + mast_node + { + let (ops_offset, decorator_data_offset) = + basic_block_data_builder.encode_basic_block(basic_block); - mast_node_info + (ops_offset, decorator_data_offset.unwrap_or(MastForest::MAX_DECORATORS as u32)) + } else { + (0, 0) + }; + + MastNodeInfo::new(mast_node, ops_offset, decorator_data_offset) }) .collect(); - let (data, string_table) = basic_block_data_builder.into_parts(); + let decorator_data = decorator_data_builder.finalize(); + let node_data = basic_block_data_builder.finalize(); + let string_table = string_table_builder.into_table(); + // Write 3 data buffers + decorator_data.write_into(target); + node_data.write_into(target); string_table.write_into(target); - data.write_into(target); + + // Write decorator and node infos + for decorator_info in decorator_infos { + decorator_info.write_into(target); + } for mast_node_info in mast_node_infos { mast_node_info.write_into(target); } + + // Write "before enter" and "after exit" decorators + before_enter_decorators.write_into(target); + after_exit_decorators.write_into(target); } } @@ -103,21 +181,39 @@ impl Deserializable for MastForest { ))); } + let decorator_count = source.read_usize()?; let node_count = source.read_usize()?; let roots: Vec = Deserializable::read_from(source)?; - let strings: Vec = Deserializable::read_from(source)?; - let data: Vec = Deserializable::read_from(source)?; + let decorator_data: Vec = Deserializable::read_from(source)?; + let node_data: Vec = Deserializable::read_from(source)?; + let string_table: StringTable = Deserializable::read_from(source)?; - let basic_block_data_decoder = BasicBlockDataDecoder::new(&data, &strings); - - let mast_forest = { + let mut mast_forest = { let mut mast_forest = MastForest::new(); + // decorators + for _ in 0..decorator_count { + let decorator_info = DecoratorInfo::read_from(source)?; + let decorator = + decorator_info.try_into_decorator(&string_table, &decorator_data)?; + + mast_forest.add_decorator(decorator).map_err(|e| { + DeserializationError::InvalidValue(format!( + "failed to add decorator to MAST forest while deserializing: {e}", + )) + })?; + } + + // nodes + let basic_block_data_decoder = BasicBlockDataDecoder::new(&node_data); for _ in 0..node_count { let mast_node_info = MastNodeInfo::read_from(source)?; - let node = - mast_node_info.try_into_mast_node(&mast_forest, &basic_block_data_decoder)?; + let node = mast_node_info.try_into_mast_node( + &mast_forest, + node_count, + &basic_block_data_decoder, + )?; mast_forest.add_node(node).map_err(|e| { DeserializationError::InvalidValue(format!( @@ -126,6 +222,7 @@ impl Deserializable for MastForest { })?; } + // roots for root in roots { // make sure the root is valid in the context of the MAST forest let root = MastNodeId::from_u32_safe(root, &mast_forest)?; @@ -135,6 +232,59 @@ impl Deserializable for MastForest { mast_forest }; + // read "before enter" and "after exit" decorators, and update the corresponding nodes + let before_enter_decorators: Vec<(usize, Vec)> = + read_before_after_decorators(source, &mast_forest)?; + for (node_id, decorator_ids) in before_enter_decorators { + let node_id: u32 = node_id.try_into().map_err(|_| { + DeserializationError::InvalidValue(format!( + "Invalid node id '{node_id}' while deserializing" + )) + })?; + let node_id = MastNodeId::from_u32_safe(node_id, &mast_forest)?; + mast_forest.set_before_enter(node_id, decorator_ids); + } + + let after_exit_decorators: Vec<(usize, Vec)> = + read_before_after_decorators(source, &mast_forest)?; + for (node_id, decorator_ids) in after_exit_decorators { + let node_id: u32 = node_id.try_into().map_err(|_| { + DeserializationError::InvalidValue(format!( + "Invalid node id '{node_id}' while deserializing" + )) + })?; + let node_id = MastNodeId::from_u32_safe(node_id, &mast_forest)?; + mast_forest.set_after_exit(node_id, decorator_ids); + } + Ok(mast_forest) } } + +/// Reads the `before_enter_decorators` and `after_exit_decorators` of the serialized `MastForest` +/// format. +/// +/// Note that we need this custom format because we cannot implement `Deserializable` for +/// `DecoratorId` (in favor of using [`DecoratorId::from_u32_safe`]). +fn read_before_after_decorators( + source: &mut R, + mast_forest: &MastForest, +) -> Result)>, DeserializationError> { + let vec_len: usize = source.read()?; + let mut out_vec: Vec<_> = Vec::with_capacity(vec_len); + + for _ in 0..vec_len { + let node_id: usize = source.read()?; + + let inner_vec_len: usize = source.read()?; + let mut inner_vec: Vec = Vec::with_capacity(inner_vec_len); + for _ in 0..inner_vec_len { + let decorator_id = DecoratorId::from_u32_safe(source.read()?, mast_forest)?; + inner_vec.push(decorator_id); + } + + out_vec.push((node_id, inner_vec)); + } + + Ok(out_vec) +} diff --git a/core/src/mast/serialization/string_table.rs b/core/src/mast/serialization/string_table.rs new file mode 100644 index 0000000000..9377aaa856 --- /dev/null +++ b/core/src/mast/serialization/string_table.rs @@ -0,0 +1,114 @@ +use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; +use core::cell::RefCell; + +use miden_crypto::hash::blake::{Blake3Digest, Blake3_256}; +use winter_utils::{ + ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, +}; + +use super::{StringDataOffset, StringIndex}; + +pub struct StringTable { + data: Vec, + + table: Vec, + + /// This field is used to allocate an `Arc` for any string in `strings` where the decoder + /// requests a reference-counted string rather than a fresh allocation as a `String`. + /// + /// Currently, this is only used for debug information (source file names), but most cases + /// where strings are stored in MAST are stored as `Arc` in practice, we just haven't yet + /// updated all of the decoders. + /// + /// We lazily allocate an `Arc` when strings are decoded as an `Arc`, but the underlying + /// string data corresponds to the same index in `strings`. All future requests for a + /// ref-counted string we've allocated an `Arc` for, will clone the `Arc` rather than + /// allocate a fresh string. + refc_strings: Vec>>>, +} + +impl StringTable { + pub fn new(table: Vec, data: Vec) -> Self { + let mut refc_strings = Vec::with_capacity(table.len()); + refc_strings.resize(table.len(), RefCell::new(None)); + + Self { table, data, refc_strings } + } + + pub fn read_arc_str(&self, str_idx: StringIndex) -> Result, DeserializationError> { + if let Some(cached) = self.refc_strings.get(str_idx).and_then(|cell| cell.borrow().clone()) + { + return Ok(cached); + } + + let string = Arc::from(self.read_string(str_idx)?.into_boxed_str()); + *self.refc_strings[str_idx].borrow_mut() = Some(Arc::clone(&string)); + Ok(string) + } + + pub fn read_string(&self, str_idx: StringIndex) -> Result { + let str_offset = self.table.get(str_idx).copied().ok_or_else(|| { + DeserializationError::InvalidValue(format!("invalid index in strings table: {str_idx}")) + })?; + + let mut reader = SliceReader::new(&self.data[str_offset..]); + reader.read() + } +} + +impl Serializable for StringTable { + fn write_into(&self, target: &mut W) { + let Self { table, data, refc_strings: _ } = self; + + table.write_into(target); + data.write_into(target); + } +} + +impl Deserializable for StringTable { + fn read_from(source: &mut R) -> Result { + let table = source.read()?; + let data = source.read()?; + + Ok(Self::new(table, data)) + } +} + +// STRING TABLE BUILDER +// ================================================================================================ + +#[derive(Debug, Default)] +pub struct StringTableBuilder { + table: Vec, + str_to_index: BTreeMap, StringIndex>, + strings_data: Vec, +} + +impl StringTableBuilder { + pub fn add_string(&mut self, string: &str) -> StringIndex { + if let Some(str_idx) = self.str_to_index.get(&Blake3_256::hash(string.as_bytes())) { + // return already interned string + *str_idx + } else { + // add new string to table + let str_offset = self.strings_data.len(); + + assert!( + str_offset + string.len() < u32::MAX as usize, + "strings table larger than 2^32 bytes" + ); + + let str_idx = self.table.len(); + + string.write_into(&mut self.strings_data); + self.table.push(str_offset); + self.str_to_index.insert(Blake3_256::hash(string.as_bytes()), str_idx); + + str_idx + } + } + + pub fn into_table(self) -> StringTable { + StringTable::new(self.table, self.strings_data) + } +} diff --git a/core/src/mast/serialization/tests.rs b/core/src/mast/serialization/tests.rs index f99b29168d..cb6e9e2c09 100644 --- a/core/src/mast/serialization/tests.rs +++ b/core/src/mast/serialization/tests.rs @@ -27,6 +27,7 @@ fn confirm_operation_and_decorator_structure() { Operation::Loop => (), Operation::Call => (), Operation::Dyn => (), + Operation::Dyncall => (), Operation::SysCall => (), Operation::Span => (), Operation::End => (), @@ -104,9 +105,10 @@ fn confirm_operation_and_decorator_structure() { Operation::MrUpdate => (), Operation::FriE2F4 => (), Operation::RCombBase => (), + Operation::Emit(_) => (), }; - match Decorator::Event(0) { + match Decorator::Trace(0) { Decorator::Advice(advice) => match advice { AdviceInjector::MerkleNodeMerge => (), AdviceInjector::MerkleNodeToStack => (), @@ -136,7 +138,6 @@ fn confirm_operation_and_decorator_structure() { DebugOptions::MemInterval(..) => (), DebugOptions::LocalInterval(..) => (), }, - Decorator::Event(_) => (), Decorator::Trace(_) => (), }; } @@ -236,6 +237,7 @@ fn serialize_deserialize_all_nodes() { Operation::MrUpdate, Operation::FriE2F4, Operation::RCombBase, + Operation::Emit(42), ]; let num_operations = operations.len(); @@ -288,29 +290,62 @@ fn serialize_deserialize_all_nodes() { (15, Decorator::Debug(DebugOptions::MemAll)), (15, Decorator::Debug(DebugOptions::MemInterval(0, 16))), (17, Decorator::Debug(DebugOptions::LocalInterval(1, 2, 3))), - (num_operations, Decorator::Event(45)), (num_operations, Decorator::Trace(55)), ]; - mast_forest.add_block(operations, Some(decorators)).unwrap() + mast_forest.add_block_with_raw_decorators(operations, decorators).unwrap() }; + // Decorators to add to following nodes + let decorator_id1 = mast_forest.add_decorator(Decorator::Trace(1)).unwrap(); + let decorator_id2 = mast_forest.add_decorator(Decorator::Trace(2)).unwrap(); + + // Call node let call_node_id = mast_forest.add_call(basic_block_id).unwrap(); + mast_forest[call_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[call_node_id].set_after_exit(vec![decorator_id2]); + // Syscall node let syscall_node_id = mast_forest.add_syscall(basic_block_id).unwrap(); + mast_forest[syscall_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[syscall_node_id].set_after_exit(vec![decorator_id2]); + // Loop node let loop_node_id = mast_forest.add_loop(basic_block_id).unwrap(); + mast_forest[loop_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[loop_node_id].set_after_exit(vec![decorator_id2]); + + // Join node let join_node_id = mast_forest.add_join(basic_block_id, call_node_id).unwrap(); + mast_forest[join_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[join_node_id].set_after_exit(vec![decorator_id2]); + + // Split node let split_node_id = mast_forest.add_split(basic_block_id, call_node_id).unwrap(); + mast_forest[split_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[split_node_id].set_after_exit(vec![decorator_id2]); + + // Dyn node let dyn_node_id = mast_forest.add_dyn().unwrap(); + mast_forest[dyn_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[dyn_node_id].set_after_exit(vec![decorator_id2]); + + // Dyncall node + let dyncall_node_id = mast_forest.add_dyncall().unwrap(); + mast_forest[dyncall_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[dyncall_node_id].set_after_exit(vec![decorator_id2]); + // External node let external_node_id = mast_forest.add_external(RpoDigest::default()).unwrap(); + mast_forest[external_node_id].set_before_enter(vec![decorator_id1]); + mast_forest[external_node_id].set_after_exit(vec![decorator_id2]); mast_forest.make_root(join_node_id); mast_forest.make_root(syscall_node_id); mast_forest.make_root(loop_node_id); mast_forest.make_root(split_node_id); mast_forest.make_root(dyn_node_id); + mast_forest.make_root(dyncall_node_id); mast_forest.make_root(external_node_id); let serialized_mast_forest = mast_forest.to_bytes(); @@ -319,6 +354,51 @@ fn serialize_deserialize_all_nodes() { assert_eq!(mast_forest, deserialized_mast_forest); } +/// Test that a forest with a node whose child ids are larger than its own id serializes and +/// deserializes successfully. +#[test] +fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() { + let mut forest = MastForest::new(); + let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap(); + let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap(); + let zero = forest.add_block(vec![Operation::U32div], None).unwrap(); + let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap(); + let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap(); + forest.add_join(first, second).unwrap(); + + // Move the Join node before its child nodes and remove the temporary zero node. + forest.nodes.swap_remove(zero.as_usize()); + + MastForest::read_from_bytes(&forest.to_bytes()).unwrap(); +} + +/// Test that a forest with a node whose referenced index is >= the max number of nodes in +/// the forest returns an error during deserialization. +#[test] +fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() { + let mut overflow_forest = MastForest::new(); + let id0 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap(); + overflow_forest.add_block(vec![Operation::Eqz], None).unwrap(); + let id2 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap(); + let id_join = overflow_forest.add_join(id0, id2).unwrap(); + + let join_node = overflow_forest[id_join].clone(); + + // Add the Join(0, 2) to this forest which does not have a node with index 2. + let mut forest = MastForest::new(); + let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap(); + let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap(); + forest + .add_block(vec![Operation::U32add], Some(vec![(0, deco0), (1, deco1)])) + .unwrap(); + forest.add_node(join_node).unwrap(); + + assert_matches!( + MastForest::read_from_bytes(&forest.to_bytes()), + Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes") + ); +} + #[test] fn mast_forest_invalid_node_id() { // Hydrate a forest smaller than the second diff --git a/core/src/mast/tests.rs b/core/src/mast/tests.rs index 086bf3bb11..31da93ba7b 100644 --- a/core/src/mast/tests.rs +++ b/core/src/mast/tests.rs @@ -10,8 +10,8 @@ use crate::{chiplets::hasher, mast::DynNode, Kernel, ProgramInfo, Word}; #[test] fn dyn_hash_is_correct() { let expected_constant = - hasher::merge_in_domain(&[RpoDigest::default(), RpoDigest::default()], DynNode::DOMAIN); - assert_eq!(expected_constant, DynNode.digest()); + hasher::merge_in_domain(&[RpoDigest::default(), RpoDigest::default()], DynNode::DYN_DOMAIN); + assert_eq!(expected_constant, DynNode::new_dyn().digest()); } proptest! { diff --git a/core/src/operations/decorators/mod.rs b/core/src/operations/decorators/mod.rs index c695d9ffe6..02e054eb0a 100644 --- a/core/src/operations/decorators/mod.rs +++ b/core/src/operations/decorators/mod.rs @@ -1,6 +1,9 @@ -use alloc::vec::Vec; +use alloc::{string::ToString, vec::Vec}; use core::fmt; +use miden_crypto::hash::blake::Blake3_256; +use num_traits::ToBytes; + mod advice; pub use advice::AdviceInjector; @@ -10,6 +13,8 @@ pub use assembly_op::AssemblyOp; mod debug; pub use debug::DebugOptions; +use crate::mast::{DecoratorFingerprint, DecoratorId}; + // DECORATORS // ================================================================================================ @@ -30,12 +35,34 @@ pub enum Decorator { /// Prints out information about the state of the VM based on the specified options. This /// decorator is executed only in debug mode. Debug(DebugOptions), - /// Emits an event to the host. - Event(u32), /// Emits a trace to the host. Trace(u32), } +impl Decorator { + pub fn fingerprint(&self) -> DecoratorFingerprint { + match self { + Self::Advice(advice) => Blake3_256::hash(advice.to_string().as_bytes()), + Self::AsmOp(asm_op) => { + let mut bytes_to_hash = Vec::new(); + if let Some(location) = asm_op.location() { + bytes_to_hash.extend(location.path.as_bytes()); + bytes_to_hash.extend(location.start.to_u32().to_le_bytes()); + bytes_to_hash.extend(location.end.to_u32().to_le_bytes()); + } + bytes_to_hash.extend(asm_op.context_name().as_bytes()); + bytes_to_hash.extend(asm_op.op().as_bytes()); + bytes_to_hash.push(asm_op.num_cycles()); + bytes_to_hash.push(asm_op.should_break() as u8); + + Blake3_256::hash(&bytes_to_hash) + }, + Self::Debug(debug) => Blake3_256::hash(debug.to_string().as_bytes()), + Self::Trace(trace) => Blake3_256::hash(&trace.to_le_bytes()), + } + } +} + impl crate::prettier::PrettyPrint for Decorator { fn render(&self) -> crate::prettier::Document { crate::prettier::display(self) @@ -50,7 +77,6 @@ impl fmt::Display for Decorator { write!(f, "asmOp({}, {})", assembly_op.op(), assembly_op.num_cycles()) }, Self::Debug(options) => write!(f, "debug({options})"), - Self::Event(event_id) => write!(f, "event({})", event_id), Self::Trace(trace_id) => write!(f, "trace({})", trace_id), } } @@ -58,7 +84,7 @@ impl fmt::Display for Decorator { /// Vector consisting of a tuple of operation index (within a span block) and decorator at that /// index -pub type DecoratorList = Vec<(usize, Decorator)>; +pub type DecoratorList = Vec<(usize, DecoratorId)>; /// Iterator used to iterate through the decorator list of a span block /// while executing operation batches of a span block. @@ -76,7 +102,7 @@ impl<'a> DecoratorIterator<'a> { /// Returns the next decorator but only if its position matches the specified position, /// otherwise, None is returned. #[inline(always)] - pub fn next_filtered(&mut self, pos: usize) -> Option<&Decorator> { + pub fn next_filtered(&mut self, pos: usize) -> Option<&DecoratorId> { if self.idx < self.decorators.len() && self.decorators[self.idx].0 == pos { self.idx += 1; Some(&self.decorators[self.idx - 1].1) @@ -87,7 +113,7 @@ impl<'a> DecoratorIterator<'a> { } impl<'a> Iterator for DecoratorIterator<'a> { - type Item = &'a Decorator; + type Item = &'a DecoratorId; fn next(&mut self) -> Option { if self.idx < self.decorators.len() { diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 8a8770139e..7a2eb6e3c0 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -109,9 +109,12 @@ pub(super) mod opcode_constants { pub const OPCODE_JOIN: u8 = 0b0101_0111; pub const OPCODE_DYN: u8 = 0b0101_1000; pub const OPCODE_RCOMBBASE: u8 = 0b0101_1001; + pub const OPCODE_EMIT: u8 = 0b0101_1010; + pub const OPCODE_PUSH: u8 = 0b0101_1011; + pub const OPCODE_DYNCALL: u8 = 0b0101_1100; pub const OPCODE_MRUPDATE: u8 = 0b0110_0000; - pub const OPCODE_PUSH: u8 = 0b0110_0100; + /* unused: 0b0110_0100 */ pub const OPCODE_SYSCALL: u8 = 0b0110_1000; pub const OPCODE_CALL: u8 = 0b0110_1100; pub const OPCODE_END: u8 = 0b0111_0000; @@ -156,6 +159,16 @@ pub enum Operation { /// instruction. Clk = OPCODE_CLK, + /// Emits an event id (`u32` value) to the host. + /// + /// We interpret the event id as follows: + /// - 16 most significant bits identify the event source, + /// - 16 least significant bits identify the actual event. + /// + /// Similar to Noop, this operation does not change the state of user stack. The immediate + /// value affects the program MAST root computation. + Emit(u32) = OPCODE_EMIT, + // ----- flow control operations ------------------------------------------------------------- /// Marks the beginning of a join block. Join = OPCODE_JOIN, @@ -172,6 +185,9 @@ pub enum Operation { /// Marks the beginning of a dynamic code block, where the target is specified by the stack. Dyn = OPCODE_DYN, + /// Marks the beginning of a dynamic function call, where the target is specified by the stack. + Dyncall = OPCODE_DYNCALL, + /// Marks the beginning of a kernel call. SysCall = OPCODE_SYSCALL, @@ -570,14 +586,17 @@ impl Operation { /// Returns an immediate value carried by this operation. pub fn imm_value(&self) -> Option { - match self { - Self::Push(imm) => Some(*imm), + match *self { + Self::Push(imm) => Some(imm), + Self::Emit(imm) => Some(imm.into()), _ => None, } } - /// Returns true if this operation is a control operation. - pub fn is_control_op(&self) -> bool { + /// Returns true if this operation writes any data to the decoder hasher registers. + /// + /// In other words, if so, then the user op helper registers are not available. + pub fn populates_decoder_hasher_registers(&self) -> bool { matches!( self, Self::End @@ -590,7 +609,6 @@ impl Operation { | Self::Halt | Self::Call | Self::SysCall - | Self::Dyn ) } } @@ -621,6 +639,7 @@ impl fmt::Display for Operation { Self::Split => write!(f, "split"), Self::Loop => write!(f, "loop"), Self::Call => writeln!(f, "call"), + Self::Dyncall => writeln!(f, "dyncall"), Self::SysCall => writeln!(f, "syscall"), Self::Dyn => writeln!(f, "dyn"), Self::Span => write!(f, "span"), @@ -718,6 +737,8 @@ impl fmt::Display for Operation { Self::MStream => write!(f, "mstream"), Self::Pipe => write!(f, "pipe"), + Self::Emit(value) => write!(f, "emit({value})"), + // ----- cryptographic operations ----------------------------------------------------- Self::HPerm => write!(f, "hperm"), Self::MpVerify(err_code) => write!(f, "mpverify({err_code})"), @@ -737,9 +758,10 @@ impl Serializable for Operation { Operation::Assert(err_code) | Operation::MpVerify(err_code) | Operation::U32assert2(err_code) => { - err_code.to_le_bytes().write_into(target); + err_code.write_into(target); }, Operation::Push(value) => value.as_int().write_into(target), + Operation::Emit(value) => value.write_into(target), // Note: we explicitly write out all the operations so that whenever we make a // modification to the `Operation` enum, we get a compile error here. This @@ -755,6 +777,7 @@ impl Serializable for Operation { | Operation::Loop | Operation::Call | Operation::Dyn + | Operation::Dyncall | Operation::SysCall | Operation::Span | Operation::End @@ -934,6 +957,7 @@ impl Deserializable for Operation { OPCODE_SPAN => Self::Span, OPCODE_JOIN => Self::Join, OPCODE_DYN => Self::Dyn, + OPCODE_DYNCALL => Self::Dyncall, OPCODE_RCOMBBASE => Self::RCombBase, OPCODE_MRUPDATE => Self::MrUpdate, @@ -947,6 +971,11 @@ impl Deserializable for Operation { Self::Push(value_felt) }, + OPCODE_EMIT => { + let value = source.read_u32()?; + + Self::Emit(value) + }, OPCODE_SYSCALL => Self::SysCall, OPCODE_CALL => Self::Call, OPCODE_END => Self::End, diff --git a/core/src/program.rs b/core/src/program.rs index 093e0902a6..ebff4c7b01 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -1,5 +1,5 @@ -use alloc::vec::Vec; -use core::{fmt, ops::Index}; +use alloc::{sync::Arc, vec::Vec}; +use core::fmt; use miden_crypto::{hash::rpo::RpoDigest, Felt, WORD_SIZE}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -13,9 +13,14 @@ use crate::{ // PROGRAM // =============================================================================================== +/// An executable program for Miden VM. +/// +/// A program consists of a MAST forest, an entrypoint defining the MAST node at which the program +/// execution begins, and a definition of the kernel against which the program must be executed +/// (the kernel can be an empty kernel). #[derive(Clone, Debug, PartialEq, Eq)] pub struct Program { - mast_forest: MastForest, + mast_forest: Arc, /// The "entrypoint" is the node where execution of the program begins. entrypoint: MastNodeId, kernel: Kernel, @@ -27,38 +32,37 @@ impl Program { /// to be empty. /// /// # Panics: - /// - if `mast_forest` doesn't have an entrypoint - pub fn new(mast_forest: MastForest, entrypoint: MastNodeId) -> Self { - assert!(mast_forest.get_node_by_id(entrypoint).is_some()); - - Self { - mast_forest, - entrypoint, - kernel: Kernel::default(), - } + /// - if `mast_forest` doesn't contain the specified entrypoint. + /// - if the specified entrypoint is not a procedure root in the `mast_forest`. + pub fn new(mast_forest: Arc, entrypoint: MastNodeId) -> Self { + Self::with_kernel(mast_forest, entrypoint, Kernel::default()) } /// Construct a new [`Program`] from the given MAST forest, entrypoint, and kernel. /// /// # Panics: - /// - if `mast_forest` doesn't have an entrypoint - pub fn with_kernel(mast_forest: MastForest, entrypoint: MastNodeId, kernel: Kernel) -> Self { - assert!(mast_forest.get_node_by_id(entrypoint).is_some()); + /// - if `mast_forest` doesn't contain the specified entrypoint. + /// - if the specified entrypoint is not a procedure root in the `mast_forest`. + pub fn with_kernel( + mast_forest: Arc, + entrypoint: MastNodeId, + kernel: Kernel, + ) -> Self { + assert!(mast_forest.get_node_by_id(entrypoint).is_some(), "invalid entrypoint"); + assert!(mast_forest.is_procedure_root(entrypoint), "entrypoint not a procedure"); Self { mast_forest, entrypoint, kernel } } } +// ------------------------------------------------------------------------------------------------ /// Public accessors impl Program { - /// Returns the underlying [`MastForest`]. - pub fn mast_forest(&self) -> &MastForest { - &self.mast_forest - } - - /// Returns the kernel associated with this program. - pub fn kernel(&self) -> &Kernel { - &self.kernel + /// Returns the hash of the program's entrypoint. + /// + /// Equivalently, returns the hash of the root of the entrypoint procedure. + pub fn hash(&self) -> RpoDigest { + self.mast_forest[self.entrypoint].digest() } /// Returns the entrypoint associated with this program. @@ -66,17 +70,20 @@ impl Program { self.entrypoint } - /// Returns the hash of the program's entrypoint. - /// - /// Equivalently, returns the hash of the root of the entrypoint procedure. - pub fn hash(&self) -> RpoDigest { - self.mast_forest[self.entrypoint].digest() + /// Returns a reference to the underlying [`MastForest`]. + pub fn mast_forest(&self) -> &Arc { + &self.mast_forest + } + + /// Returns the kernel associated with this program. + pub fn kernel(&self) -> &Kernel { + &self.kernel } /// Returns the [`MastNode`] associated with the provided [`MastNodeId`] if valid, or else /// `None`. /// - /// This is the faillible version of indexing (e.g. `program[node_id]`). + /// This is the fallible version of indexing (e.g. `program[node_id]`). #[inline(always)] pub fn get_node_by_id(&self, node_id: MastNodeId) -> Option<&MastNode> { self.mast_forest.get_node_by_id(node_id) @@ -94,10 +101,11 @@ impl Program { } } +// ------------------------------------------------------------------------------------------------ /// Serialization +#[cfg(feature = "std")] impl Program { /// Writes this [Program] to the provided file path. - #[cfg(feature = "std")] pub fn write_to_file

(&self, path: P) -> std::io::Result<()> where P: AsRef, @@ -139,26 +147,27 @@ impl Serializable for Program { impl Deserializable for Program { fn read_from(source: &mut R) -> Result { - let mast_forest = source.read()?; + let mast_forest = Arc::new(source.read()?); let kernel = source.read()?; let entrypoint = MastNodeId::from_u32_safe(source.read_u32()?, &mast_forest)?; - Ok(Self { mast_forest, kernel, entrypoint }) - } -} - -impl Index for Program { - type Output = MastNode; + if !mast_forest.is_procedure_root(entrypoint) { + return Err(DeserializationError::InvalidValue(format!( + "entrypoint {entrypoint} is not a procedure" + ))); + } - fn index(&self, node_id: MastNodeId) -> &Self::Output { - &self.mast_forest[node_id] + Ok(Self::with_kernel(mast_forest, entrypoint, kernel)) } } +// ------------------------------------------------------------------------------------------------ +// Pretty-printing + impl crate::prettier::PrettyPrint for Program { fn render(&self) -> crate::prettier::Document { use crate::prettier::*; - let entrypoint = self[self.entrypoint()].to_pretty_print(&self.mast_forest); + let entrypoint = self.mast_forest[self.entrypoint()].to_pretty_print(&self.mast_forest); indent(4, const_text("begin") + nl() + entrypoint.render()) + nl() + const_text("end") } @@ -171,12 +180,6 @@ impl fmt::Display for Program { } } -impl From for MastForest { - fn from(program: Program) -> Self { - program.mast_forest - } -} - // PROGRAM INFO // =============================================================================================== @@ -195,17 +198,11 @@ pub struct ProgramInfo { } impl ProgramInfo { - // CONSTRUCTORS - // -------------------------------------------------------------------------------------------- - /// Creates a new instance of a program info. pub const fn new(program_hash: RpoDigest, kernel: Kernel) -> Self { Self { program_hash, kernel } } - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - /// Returns the program hash computed from its code block root. pub const fn program_hash(&self) -> &RpoDigest { &self.program_hash @@ -231,8 +228,8 @@ impl From for ProgramInfo { } } -// SERIALIZATION // ------------------------------------------------------------------------------------------------ +// Serialization impl Serializable for ProgramInfo { fn write_into(&self, target: &mut W) { @@ -249,8 +246,8 @@ impl Deserializable for ProgramInfo { } } -// TO ELEMENTS // ------------------------------------------------------------------------------------------------ +// ToElements implementation impl ToElements for ProgramInfo { fn to_elements(&self) -> Vec { diff --git a/core/src/stack/inputs.rs b/core/src/stack/inputs.rs index db33721c05..b96f4b64fb 100644 --- a/core/src/stack/inputs.rs +++ b/core/src/stack/inputs.rs @@ -1,27 +1,24 @@ use alloc::vec::Vec; -use core::slice; +use core::{ops::Deref, slice}; -use super::{ByteWriter, Felt, InputError, Serializable, ToElements}; +use super::{ + super::ZERO, get_num_stack_values, ByteWriter, Felt, InputError, Serializable, MIN_STACK_DEPTH, +}; use crate::utils::{ByteReader, Deserializable, DeserializationError}; // STACK INPUTS // ================================================================================================ -/// Initial state of the stack to support program execution. +/// Defines the initial state of the VM's operand stack. /// -/// The program execution expects the inputs to be a stack on the VM, and it will be stored in -/// reversed order on this struct. +/// The values in the struct are stored in the "stack order" - i.e., the last input is at the top +/// of the stack (in position 0). #[derive(Clone, Debug, Default)] pub struct StackInputs { - values: Vec, + elements: [Felt; MIN_STACK_DEPTH], } impl StackInputs { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - - pub const MAX_LEN: usize = u16::MAX as usize; - // CONSTRUCTORS // -------------------------------------------------------------------------------------------- @@ -30,12 +27,13 @@ impl StackInputs { /// # Errors /// Returns an error if the number of input values exceeds the allowed maximum. pub fn new(mut values: Vec) -> Result { - if values.len() > Self::MAX_LEN { - return Err(InputError::InputLengthExceeded(Self::MAX_LEN, values.len())); + if values.len() > MIN_STACK_DEPTH { + return Err(InputError::InputLengthExceeded(MIN_STACK_DEPTH, values.len())); } values.reverse(); + values.resize(MIN_STACK_DEPTH, ZERO); - Ok(Self { values }) + Ok(Self { elements: values.try_into().unwrap() }) } /// Attempts to create stack inputs from an iterator of integers. @@ -55,13 +53,19 @@ impl StackInputs { Self::new(values) } +} - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- +impl Deref for StackInputs { + type Target = [Felt; MIN_STACK_DEPTH]; + + fn deref(&self) -> &Self::Target { + &self.elements + } +} - /// Returns the initial stack values in stack/reversed order. - pub fn values(&self) -> &[Felt] { - &self.values +impl From<[Felt; MIN_STACK_DEPTH]> for StackInputs { + fn from(value: [Felt; MIN_STACK_DEPTH]) -> Self { + Self { elements: value } } } @@ -70,22 +74,16 @@ impl<'a> IntoIterator for &'a StackInputs { type IntoIter = slice::Iter<'a, Felt>; fn into_iter(self) -> Self::IntoIter { - self.values.iter() + self.elements.iter() } } impl IntoIterator for StackInputs { type Item = Felt; - type IntoIter = alloc::vec::IntoIter; + type IntoIter = core::array::IntoIter; fn into_iter(self) -> Self::IntoIter { - self.values.into_iter() - } -} - -impl ToElements for StackInputs { - fn to_elements(&self) -> Vec { - self.values.to_vec() + self.elements.into_iter() } } @@ -94,27 +92,24 @@ impl ToElements for StackInputs { impl Serializable for StackInputs { fn write_into(&self, target: &mut W) { - // TODO the length of the stack, by design, will not be greater than `u32::MAX`. however, - // we must define a common serialization format as we might diverge from the implementation - // here and the one provided by default from winterfell. - - debug_assert!(self.values.len() <= Self::MAX_LEN); - target.write_usize(self.values.len()); - target.write_many(&self.values); + let num_stack_values = get_num_stack_values(self); + target.write_u8(num_stack_values); + target.write_many(&self.elements[..num_stack_values as usize]); } } impl Deserializable for StackInputs { fn read_from(source: &mut R) -> Result { - let count = source.read_usize()?; - if count > Self::MAX_LEN { - return Err(DeserializationError::InvalidValue(format!( - "Number of values on the input stack can not be more than {}, but {} was found", - Self::MAX_LEN, - count - ))); - } - let values = source.read_many::(count)?; - Ok(StackInputs { values }) + let num_elements = source.read_u8()?; + + let mut elements = source.read_many::(num_elements.into())?; + elements.reverse(); + + StackInputs::new(elements).map_err(|_| { + DeserializationError::InvalidValue(format!( + "number of stack elements should not be greater than {}, but {} was found", + MIN_STACK_DEPTH, num_elements + )) + }) } } diff --git a/core/src/stack/mod.rs b/core/src/stack/mod.rs index 3b8748f310..9def652702 100644 --- a/core/src/stack/mod.rs +++ b/core/src/stack/mod.rs @@ -1,6 +1,6 @@ use super::{ errors::{InputError, OutputError}, - Felt, StackTopState, ToElements, + Felt, }; use crate::utils::{ByteWriter, Serializable}; @@ -10,9 +10,33 @@ pub use inputs::StackInputs; mod outputs; pub use outputs::StackOutputs; +#[cfg(test)] +mod tests; + // CONSTANTS // ================================================================================================ -/// The number of stack registers which can be accessed by the VM directly. This is also the -/// minimum stack depth enforced by the VM. -pub const STACK_TOP_SIZE: usize = 16; +/// Represents: +/// - Number of elements that can be initialized at the start of execution and remain populated at +/// the end of execution. +/// - Number of elements that can be accessed directly via instructions. +/// - Number of elements that remain visible to the callee when the context is switched via `call` +/// or `syscall` instructions. +/// - Number of elements below which the depth of the stack never drops. +pub const MIN_STACK_DEPTH: usize = 16; + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Get the number of non-zero stack elements. +fn get_num_stack_values(values: &[Felt; MIN_STACK_DEPTH]) -> u8 { + let mut num_trailing_zeros = 0; + for v in values.iter().rev() { + if v.as_int() == 0 { + num_trailing_zeros += 1; + } else { + break; + } + } + (MIN_STACK_DEPTH - num_trailing_zeros) as u8 +} diff --git a/core/src/stack/outputs.rs b/core/src/stack/outputs.rs index 855afa0e10..d4b4799c41 100644 --- a/core/src/stack/outputs.rs +++ b/core/src/stack/outputs.rs @@ -1,10 +1,9 @@ use alloc::vec::Vec; +use core::ops::Deref; use miden_crypto::{Word, ZERO}; -use super::{ - ByteWriter, Felt, OutputError, Serializable, StackTopState, ToElements, STACK_TOP_SIZE, -}; +use super::{get_num_stack_values, ByteWriter, Felt, OutputError, Serializable, MIN_STACK_DEPTH}; use crate::utils::{range, ByteReader, Deserializable, DeserializationError}; // STACK OUTPUTS @@ -12,89 +11,52 @@ use crate::utils::{range, ByteReader, Deserializable, DeserializationError}; /// Output container for Miden VM programs. /// -/// Miden program outputs contain the full state of the stack at the end of execution as well as the -/// addresses in the overflow table which are required to reconstruct the table (when combined with -/// the overflow values from the stack state). +/// Miden program outputs contain the full state of the stack at the end of execution. /// /// `stack` is expected to be ordered as if the elements were popped off the stack one by one. /// Thus, the value at the top of the stack is expected to be in the first position, and the order /// of the rest of the output elements will also match the order on the stack. -/// -/// `overflow_addrs` is expected to start with the `prev` address value from the first row in the -/// overflow table (the row representing the deepest element in the stack) and then be followed by -/// the address (`clk` value) of each row in the table starting from the deepest element in the -/// stack and finishing with the row which was added to the table last. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct StackOutputs { - /// The elements on the stack at the end of execution. - stack: Vec, - /// The overflow table row addresses required to reconstruct the final state of the table. - overflow_addrs: Vec, + elements: [Felt; MIN_STACK_DEPTH], } impl StackOutputs { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - - pub const MAX_LEN: usize = u16::MAX as usize; - // CONSTRUCTORS // -------------------------------------------------------------------------------------------- - /// Constructs a new [StackOutputs] struct from the provided stack elements and overflow - /// addresses. + /// Constructs a new [StackOutputs] struct from the provided stack elements. /// /// # Errors - /// Returns an error if the number of stack elements is greater than `STACK_TOP_SIZE` (16) and - /// `overflow_addrs` does not contain exactly `stack.len() + 1 - STACK_TOP_SIZE` elements. - pub fn new(mut stack: Vec, overflow_addrs: Vec) -> Result { + /// Returns an error if the number of stack elements is greater than `MIN_STACK_DEPTH` (16). + pub fn new(mut stack: Vec) -> Result { // validate stack length - if stack.len() > Self::MAX_LEN { + if stack.len() > MIN_STACK_DEPTH { return Err(OutputError::OutputSizeTooBig(stack.len())); } + stack.resize(MIN_STACK_DEPTH, ZERO); - // get overflow_addrs length - let expected_overflow_addrs_len = get_overflow_addrs_len(stack.len()); - - // validate overflow_addrs length - if overflow_addrs.len() != expected_overflow_addrs_len { - return Err(OutputError::InvalidOverflowAddressLength( - overflow_addrs.len(), - expected_overflow_addrs_len, - )); - } - - // pad stack to the `STACK_TOP_SIZE` - if stack.len() < STACK_TOP_SIZE { - stack.resize(STACK_TOP_SIZE, ZERO); - } - - Ok(Self { stack, overflow_addrs }) + Ok(Self { elements: stack.try_into().unwrap() }) } - /// Attempts to create [StackOutputs] struct from the provided stack elements and overflow - /// addresses represented as vectors of `u64` values. + /// Attempts to create [StackOutputs] struct from the provided stack elements represented as + /// vector of `u64` values. /// /// # Errors /// Returns an error if: /// - Any of the provided stack elements are invalid field elements. - /// - Any of the provided overflow addresses are invalid field elements. - pub fn try_from_ints(stack: Vec, overflow_addrs: Vec) -> Result { + pub fn try_from_ints(iter: I) -> Result + where + I: IntoIterator, + { // Validate stack elements - let stack = stack - .iter() - .map(|v| Felt::try_from(*v)) + let stack = iter + .into_iter() + .map(Felt::try_from) .collect::, _>>() .map_err(OutputError::InvalidStackElement)?; - // Validate overflow address elements - let overflow_addrs = overflow_addrs - .iter() - .map(|v| Felt::try_from(*v)) - .collect::, _>>() - .map_err(OutputError::InvalidOverflowAddress)?; - - Self::new(stack, overflow_addrs) + Self::new(stack) } // PUBLIC ACCESSORS @@ -103,7 +65,7 @@ impl StackOutputs { /// Returns the element located at the specified position on the stack or `None` if out of /// bounds. pub fn get_stack_item(&self, idx: usize) -> Option { - self.stack.get(idx).cloned() + self.elements.get(idx).cloned() } /// Returns the word located starting at the specified Felt position on the stack or `None` if @@ -124,88 +86,38 @@ impl StackOutputs { Some(word_elements) } - /// Returns the stack outputs, which is state of the stack at the end of execution converted to - /// integers. - pub fn stack(&self) -> &[Felt] { - &self.stack - } - /// Returns the number of requested stack outputs or returns the full stack if fewer than the /// requested number of stack values exist. pub fn stack_truncated(&self, num_outputs: usize) -> &[Felt] { - let len = self.stack.len().min(num_outputs); - &self.stack[..len] - } - - /// Returns the state of the top of the stack at the end of execution. - pub fn stack_top(&self) -> StackTopState { - self.stack - .iter() - .take(STACK_TOP_SIZE) - .cloned() - .collect::>() - .try_into() - .expect("failed to convert vector to array") - } - - /// Returns the overflow address outputs, which are the addresses required to reconstruct the - /// overflow table (when combined with the stack overflow values) converted to integers. - pub fn overflow_addrs(&self) -> &[Felt] { - &self.overflow_addrs - } - - /// Returns true if the overflow table outputs are non-empty. - pub fn has_overflow(&self) -> bool { - !self.overflow_addrs.is_empty() - } - - /// Returns the previous address `prev` for the first row in the stack overflow table - pub fn overflow_prev(&self) -> Felt { - self.overflow_addrs[0] - } - - /// Returns (address, value) for all rows which were on the overflow table at the end of - /// execution in the order in which they were added to the table (deepest stack item first). - pub fn stack_overflow(&self) -> Vec<(Felt, Felt)> { - let mut overflow = Vec::with_capacity(self.overflow_addrs.len() - 1); - for (addr, val) in self - .overflow_addrs - .iter() - .skip(1) - .zip(self.stack.iter().skip(STACK_TOP_SIZE).rev()) - { - overflow.push((*addr, *val)); - } - - overflow + let len = self.elements.len().min(num_outputs); + &self.elements[..len] } // PUBLIC MUTATORS // -------------------------------------------------------------------------------------------- /// Returns mutable access to the stack outputs, to be used for testing or running examples. - /// TODO: this should be marked with #[cfg(test)] attribute, but that currently won't work with - /// the integration test handler util. pub fn stack_mut(&mut self) -> &mut [Felt] { - &mut self.stack + &mut self.elements + } + + /// Converts the [`StackOutputs`] into the vector of `u64` values. + pub fn as_int_vec(&self) -> Vec { + self.elements.iter().map(|e| (*e).as_int()).collect() } } -// HELPER FUNCTIONS -// ================================================================================================ +impl Deref for StackOutputs { + type Target = [Felt; 16]; -impl ToElements for StackOutputs { - fn to_elements(&self) -> Vec { - self.stack.iter().chain(self.overflow_addrs.iter()).cloned().collect() + fn deref(&self) -> &Self::Target { + &self.elements } } -/// Returs the number of overflow addresses based on the lenght of the stack. -fn get_overflow_addrs_len(stack_len: usize) -> usize { - if stack_len > STACK_TOP_SIZE { - stack_len + 1 - STACK_TOP_SIZE - } else { - 0 +impl From<[Felt; MIN_STACK_DEPTH]> for StackOutputs { + fn from(value: [Felt; MIN_STACK_DEPTH]) -> Self { + Self { elements: value } } } @@ -214,29 +126,23 @@ fn get_overflow_addrs_len(stack_len: usize) -> usize { impl Serializable for StackOutputs { fn write_into(&self, target: &mut W) { - debug_assert!(self.stack.len() <= Self::MAX_LEN); - target.write_usize(self.stack.len()); - target.write_many(&self.stack); - - target.write_many(&self.overflow_addrs); + let num_stack_values = get_num_stack_values(self); + target.write_u8(num_stack_values); + target.write_many(&self.elements[..num_stack_values as usize]); } } impl Deserializable for StackOutputs { fn read_from(source: &mut R) -> Result { - let count = source.read_usize()?; - if count > Self::MAX_LEN { - return Err(DeserializationError::InvalidValue(format!( - "Number of values on the output stack can not be more than {}, but {} was found", - Self::MAX_LEN, - count - ))); - } - let stack = source.read_many::(count)?; + let num_elements = source.read_u8()?; - let count = get_overflow_addrs_len(stack.len()); - let overflow_addrs = source.read_many::(count)?; + let elements = source.read_many::(num_elements.into())?; - Ok(Self { stack, overflow_addrs }) + StackOutputs::new(elements).map_err(|_| { + DeserializationError::InvalidValue(format!( + "number of stack elements should not be greater than {}, but {} was found", + MIN_STACK_DEPTH, num_elements + )) + }) } } diff --git a/core/src/stack/tests.rs b/core/src/stack/tests.rs new file mode 100644 index 0000000000..4a87dfc94a --- /dev/null +++ b/core/src/stack/tests.rs @@ -0,0 +1,130 @@ +use alloc::vec::Vec; + +use crate::{ + utils::{Deserializable, Serializable}, + StackInputs, StackOutputs, +}; + +// SERDE INPUTS TESTS +// ================================================================================================ + +#[test] +fn test_inputs_simple() { + let source = Vec::::from([5, 4, 3, 2, 1]); + let mut serialized = Vec::new(); + let inputs = StackInputs::try_from_ints(source.clone()).unwrap(); + + inputs.write_into(&mut serialized); + + let mut expected_serialized = Vec::new(); + expected_serialized.push(source.len() as u8); + source + .iter() + .rev() + .for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec())); + + assert_eq!(serialized, expected_serialized); + + let result = StackInputs::read_from_bytes(&serialized).unwrap(); + + assert_eq!(*inputs, *result); +} + +#[test] +fn test_inputs_full() { + let source = Vec::::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); + let mut serialized = Vec::new(); + let inputs = StackInputs::try_from_ints(source.clone()).unwrap(); + + inputs.write_into(&mut serialized); + + let mut expected_serialized = Vec::new(); + expected_serialized.push(source.len() as u8); + source + .iter() + .rev() + .for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec())); + + assert_eq!(serialized, expected_serialized); + + let result = StackInputs::read_from_bytes(&serialized).unwrap(); + + assert_eq!(*inputs, *result); +} + +#[test] +fn test_inputs_empty() { + let mut serialized = Vec::new(); + let inputs = StackInputs::try_from_ints([]).unwrap(); + + inputs.write_into(&mut serialized); + + let expected_serialized = vec![0]; + + assert_eq!(serialized, expected_serialized); + + let result = StackInputs::read_from_bytes(&serialized).unwrap(); + + assert_eq!(*inputs, *result); +} + +// SERDE OUTPUTS TESTS +// ================================================================================================ + +#[test] +fn test_outputs_simple() { + let source = Vec::::from([1, 2, 3, 4, 5]); + let mut serialized = Vec::new(); + let inputs = StackOutputs::try_from_ints(source.clone()).unwrap(); + + inputs.write_into(&mut serialized); + + let mut expected_serialized = Vec::new(); + expected_serialized.push(source.len() as u8); + source + .iter() + .for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec())); + + assert_eq!(serialized, expected_serialized); + + let result = StackOutputs::read_from_bytes(&serialized).unwrap(); + + assert_eq!(*inputs, *result); +} + +#[test] +fn test_outputs_full() { + let source = Vec::::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + let mut serialized = Vec::new(); + let inputs = StackOutputs::try_from_ints(source.clone()).unwrap(); + + inputs.write_into(&mut serialized); + + let mut expected_serialized = Vec::new(); + expected_serialized.push(source.len() as u8); + source + .iter() + .for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec())); + + assert_eq!(serialized, expected_serialized); + + let result = StackOutputs::read_from_bytes(&serialized).unwrap(); + + assert_eq!(*inputs, *result); +} + +#[test] +fn test_outputs_empty() { + let mut serialized = Vec::new(); + let inputs = StackOutputs::try_from_ints([]).unwrap(); + + inputs.write_into(&mut serialized); + + let expected_serialized = vec![0]; + + assert_eq!(serialized, expected_serialized); + + let result = StackOutputs::read_from_bytes(&serialized).unwrap(); + + assert_eq!(*inputs, *result); +} diff --git a/core/src/utils/sync.rs b/core/src/utils/sync.rs deleted file mode 100644 index a3739fdc36..0000000000 --- a/core/src/utils/sync.rs +++ /dev/null @@ -1,313 +0,0 @@ -#[cfg(feature = "std")] -pub use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; - -#[cfg(not(feature = "std"))] -pub use self::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; - -pub mod rwlock { - #[cfg(not(loom))] - use core::{ - hint, - sync::atomic::{AtomicUsize, Ordering}, - }; - - use lock_api::RawRwLock; - #[cfg(loom)] - use loom::{ - hint, - sync::atomic::{AtomicUsize, Ordering}, - }; - - /// An implementation of a reader-writer lock, based on a spinlock primitive, no-std compatible - /// - /// See [lock_api::RwLock] for usage. - pub type RwLock = lock_api::RwLock; - - /// See [lock_api::RwLockReadGuard] for usage. - pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, Spinlock, T>; - - /// See [lock_api::RwLockWriteGuard] for usage. - pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, Spinlock, T>; - - /// The underlying raw reader-writer primitive that implements [lock_api::RawRwLock] - /// - /// This is fundamentally a spinlock, in that blocking operations on the lock will spin until - /// they succeed in acquiring/releasing the lock. - /// - /// To acheive the ability to share the underlying data with multiple readers, or hold - /// exclusive access for one writer, the lock state is based on a "locked" count, where shared - /// access increments the count by an even number, and acquiring exclusive access relies on the - /// use of the lowest order bit to stop further shared acquisition, and indicate that the lock - /// is exclusively held (the difference between the two is irrelevant from the perspective of - /// a thread attempting to acquire the lock, but internally the state uses `usize::MAX` as the - /// "exlusively locked" sentinel). - /// - /// This mechanism gets us the following: - /// - /// * Whether the lock has been acquired (shared or exclusive) - /// * Whether the lock is being exclusively acquired - /// * How many times the lock has been acquired - /// * Whether the acquisition(s) are exclusive or shared - /// - /// Further implementation details, such as how we manage draining readers once an attempt to - /// exclusively acquire the lock occurs, are described below. - /// - /// NOTE: This is a simple implementation, meant for use in no-std environments; there are much - /// more robust/performant implementations available when OS primitives can be used. - pub struct Spinlock { - /// The state of the lock, primarily representing the acquisition count, but relying on - /// the distinction between even and odd values to indicate whether or not exclusive access - /// is being acquired. - state: AtomicUsize, - /// A counter used to wake a parked writer once the last shared lock is released during - /// acquisition of an exclusive lock. The actual count is not acutally important, and - /// simply wraps around on overflow, but what is important is that when the value changes, - /// the writer will wake and resume attempting to acquire the exclusive lock. - writer_wake_counter: AtomicUsize, - } - - impl Default for Spinlock { - #[inline(always)] - fn default() -> Self { - Self::new() - } - } - - impl Spinlock { - #[cfg(not(loom))] - pub const fn new() -> Self { - Self { - state: AtomicUsize::new(0), - writer_wake_counter: AtomicUsize::new(0), - } - } - - #[cfg(loom)] - pub fn new() -> Self { - Self { - state: AtomicUsize::new(0), - writer_wake_counter: AtomicUsize::new(0), - } - } - } - - unsafe impl RawRwLock for Spinlock { - #[cfg(loom)] - const INIT: Spinlock = unimplemented!(); - - #[cfg(not(loom))] - // This is intentional on the part of the [RawRwLock] API, basically a hack to provide - // initial values as static items. - #[allow(clippy::declare_interior_mutable_const)] - const INIT: Spinlock = Spinlock::new(); - - type GuardMarker = lock_api::GuardSend; - - /// The operation invoked when calling `RwLock::read`, blocks the caller until acquired - fn lock_shared(&self) { - let mut s = self.state.load(Ordering::Relaxed); - loop { - // If the exclusive bit is unset, attempt to acquire a read lock - if s & 1 == 0 { - match self.state.compare_exchange_weak( - s, - s + 2, - Ordering::Acquire, - Ordering::Relaxed, - ) { - Ok(_) => return, - // Someone else beat us to the punch and acquired a lock - Err(e) => s = e, - } - } - // If an exclusive lock is held/being acquired, loop until the lock state changes - // at which point, try to acquire the lock again - if s & 1 == 1 { - loop { - let next = self.state.load(Ordering::Relaxed); - if s == next { - hint::spin_loop(); - continue; - } else { - s = next; - break; - } - } - } - } - } - - /// The operation invoked when calling `RwLock::try_read`, returns whether or not the - /// lock was acquired - fn try_lock_shared(&self) -> bool { - let s = self.state.load(Ordering::Relaxed); - if s & 1 == 0 { - self.state - .compare_exchange_weak(s, s + 2, Ordering::Acquire, Ordering::Relaxed) - .is_ok() - } else { - false - } - } - - /// The operation invoked when dropping a `RwLockReadGuard` - unsafe fn unlock_shared(&self) { - if self.state.fetch_sub(2, Ordering::Release) == 3 { - // The lock is being exclusively acquired, and we're the last shared acquisition - // to be released, so wake the writer by incrementing the wake counter - self.writer_wake_counter.fetch_add(1, Ordering::Release); - } - } - - /// The operation invoked when calling `RwLock::write`, blocks the caller until acquired - fn lock_exclusive(&self) { - let mut s = self.state.load(Ordering::Relaxed); - loop { - // Attempt to acquire the lock immediately, or complete acquistion of the lock - // if we're continuing the loop after acquiring the exclusive bit. If another - // thread acquired it first, we race to be the first thread to acquire it once - // released, by busy looping here. - if s <= 1 { - match self.state.compare_exchange( - s, - usize::MAX, - Ordering::Acquire, - Ordering::Relaxed, - ) { - Ok(_) => return, - Err(e) => { - s = e; - hint::spin_loop(); - continue; - }, - } - } - - // Only shared locks have been acquired, attempt to acquire the exclusive bit, - // which will prevent further shared locks from being acquired. It does not - // in and of itself grant us exclusive access however. - if s & 1 == 0 { - if let Err(e) = - self.state.compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed) - { - // The lock state has changed before we could acquire the exclusive bit, - // update our view of the lock state and try again - s = e; - continue; - } - } - - // We've acquired the exclusive bit, now we need to busy wait until all shared - // acquisitions are released. - let w = self.writer_wake_counter.load(Ordering::Acquire); - s = self.state.load(Ordering::Relaxed); - - // "Park" the thread here (by busy looping), until the release of the last shared - // lock, which is communicated to us by it incrementing the wake counter. - if s >= 2 { - while self.writer_wake_counter.load(Ordering::Acquire) == w { - hint::spin_loop(); - } - s = self.state.load(Ordering::Relaxed); - } - - // All shared locks have been released, go back to the top and try to complete - // acquisition of exclusive access. - } - } - - /// The operation invoked when calling `RwLock::try_write`, returns whether or not the - /// lock was acquired - fn try_lock_exclusive(&self) -> bool { - let s = self.state.load(Ordering::Relaxed); - if s <= 1 { - self.state - .compare_exchange(s, usize::MAX, Ordering::Acquire, Ordering::Relaxed) - .is_ok() - } else { - false - } - } - - /// The operation invoked when dropping a `RwLockWriteGuard` - unsafe fn unlock_exclusive(&self) { - // Infallible, as we hold an exclusive lock - // - // Note the use of `Release` ordering here, which ensures any loads of the lock state - // by other threads, are ordered after this store. - self.state.store(0, Ordering::Release); - // This fetch_add isn't important for signaling purposes, however it serves a key - // purpose, in that it imposes a memory ordering on any loads of this field that - // have an `Acquire` ordering, i.e. they will read the value stored here. Without - // a `Release` store, loads/stores of this field could be reordered relative to - // each other. - self.writer_wake_counter.fetch_add(1, Ordering::Release); - } - } -} - -#[cfg(all(loom, test))] -mod test { - use alloc::vec::Vec; - - use loom::{model::Builder, sync::Arc}; - - use super::rwlock::{RwLock, Spinlock}; - - #[test] - fn test_rwlock_loom() { - let mut builder = Builder::default(); - builder.max_duration = Some(std::time::Duration::from_secs(60)); - builder.log = true; - builder.check(|| { - let raw_rwlock = Spinlock::new(); - let n = Arc::new(RwLock::from_raw(raw_rwlock, 0usize)); - let mut readers = Vec::new(); - let mut writers = Vec::new(); - - let num_readers = 2; - let num_writers = 2; - let num_iterations = 2; - - // Readers should never observe a non-zero value - for _ in 0..num_readers { - let n0 = n.clone(); - let t = loom::thread::spawn(move || { - for _ in 0..num_iterations { - let guard = n0.read(); - assert_eq!(*guard, 0); - } - }); - - readers.push(t); - } - - // Writers should never observe a non-zero value once they've - // acquired the lock, and should never observe a value > 1 - // while holding the lock - for _ in 0..num_writers { - let n0 = n.clone(); - let t = loom::thread::spawn(move || { - for _ in 0..num_iterations { - let mut guard = n0.write(); - assert_eq!(*guard, 0); - *guard += 1; - assert_eq!(*guard, 1); - *guard -= 1; - assert_eq!(*guard, 0); - } - }); - - writers.push(t); - } - - for t in readers { - t.join().unwrap(); - } - - for t in writers { - t.join().unwrap(); - } - }) - } -} diff --git a/core/src/utils/sync/mod.rs b/core/src/utils/sync/mod.rs new file mode 100644 index 0000000000..ef68428f86 --- /dev/null +++ b/core/src/utils/sync/mod.rs @@ -0,0 +1,12 @@ +pub mod racy_lock; +pub mod rw_lock; + +#[cfg(feature = "std")] +pub use std::sync::LazyLock; + +#[cfg(feature = "std")] +pub use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +#[cfg(not(feature = "std"))] +pub use racy_lock::RacyLock as LazyLock; +#[cfg(not(feature = "std"))] +pub use rw_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; diff --git a/core/src/utils/sync/racy_lock.rs b/core/src/utils/sync/racy_lock.rs new file mode 100644 index 0000000000..157cd41ed0 --- /dev/null +++ b/core/src/utils/sync/racy_lock.rs @@ -0,0 +1,187 @@ +use alloc::boxed::Box; +use core::{ + fmt, + ops::Deref, + ptr, + sync::atomic::{AtomicPtr, Ordering}, +}; + +/// Thread-safe, non-blocking, lazily evaluated lock with the same interface +/// as [`std::sync::LazyLock`]. +/// +/// Concurrent threads will race to set the value atomically, and memory allocated by losing threads +/// will be dropped immediately after they fail to set the pointer. +/// +/// The underlying implementation is based on `once_cell::race::OnceBox` which relies on +/// [`core::sync::atomic::AtomicPtr`] to ensure that the data race results in a single successful +/// write to the relevant pointer, namely the first write. +/// See . +/// +/// Performs lazy evaluation and can be used for statics. +pub struct RacyLock T> +where + F: Fn() -> T, +{ + inner: AtomicPtr, + f: F, +} + +impl RacyLock +where + F: Fn() -> T, +{ + /// Creates a new lazy, racy value with the given initializing function. + pub const fn new(f: F) -> Self { + Self { + inner: AtomicPtr::new(ptr::null_mut()), + f, + } + } + + /// Forces the evaluation of the locked value and returns a reference to + /// the result. This is equivalent to the [`Self::deref`]. + /// + /// There is no blocking involved in this operation. Instead, concurrent + /// threads will race to set the underlying pointer. Memory allocated by + /// losing threads will be dropped immediately after they fail to set the pointer. + /// + /// This function's interface is designed around [`std::sync::LazyLock::force`] but + /// the implementation is derived from `once_cell::race::OnceBox::get_or_try_init`. + pub fn force(this: &RacyLock) -> &T { + let mut ptr = this.inner.load(Ordering::Acquire); + + // Pointer is not yet set, attempt to set it ourselves. + if ptr.is_null() { + // Execute the initialization function and allocate. + let val = (this.f)(); + ptr = Box::into_raw(Box::new(val)); + + // Attempt atomic store. + let exchange = this.inner.compare_exchange( + ptr::null_mut(), + ptr, + Ordering::AcqRel, + Ordering::Acquire, + ); + + // Pointer already set, load. + if let Err(old) = exchange { + drop(unsafe { Box::from_raw(ptr) }); + ptr = old; + } + } + + unsafe { &*ptr } + } +} + +impl Default for RacyLock { + /// Creates a new lock that will evaluate the underlying value based on `T::default`. + #[inline] + fn default() -> RacyLock { + RacyLock::new(T::default) + } +} + +impl fmt::Debug for RacyLock +where + T: fmt::Debug, + F: Fn() -> T, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "RacyLock({:?})", self.inner.load(Ordering::Relaxed)) + } +} + +impl Deref for RacyLock +where + F: Fn() -> T, +{ + type Target = T; + + /// Either sets or retrieves the value, and dereferences it. + /// + /// See [`Self::force`] for more details. + #[inline] + fn deref(&self) -> &T { + RacyLock::force(self) + } +} + +impl Drop for RacyLock +where + F: Fn() -> T, +{ + /// Drops the underlying pointer. + fn drop(&mut self) { + let ptr = *self.inner.get_mut(); + if !ptr.is_null() { + // SAFETY: for any given value of `ptr`, we are guaranteed to have at most a single + // instance of `RacyLock` holding that value. Hence, synchronizing threads + // in `drop()` is not necessary, and we are guaranteed never to double-free. + // In short, since `RacyLock` doesn't implement `Clone`, the only scenario + // where there can be multiple instances of `RacyLock` across multiple threads + // referring to the same `ptr` value is when `RacyLock` is used in a static variable. + drop(unsafe { Box::from_raw(ptr) }); + } + } +} + +#[cfg(test)] +mod tests { + use alloc::vec::Vec; + + use super::*; + + #[test] + fn deref_default() { + // Lock a copy type and validate default value. + let lock: RacyLock = RacyLock::default(); + assert_eq!(*lock, 0); + } + + #[test] + fn deref_copy() { + // Lock a copy type and validate value. + let lock = RacyLock::new(|| 42); + assert_eq!(*lock, 42); + } + + #[test] + fn deref_clone() { + // Lock a no copy type. + let lock = RacyLock::new(|| Vec::from([1, 2, 3])); + + // Use the value so that the compiler forces us to clone. + let mut v = lock.clone(); + v.push(4); + + // Validate the value. + assert_eq!(v, Vec::from([1, 2, 3, 4])); + } + + #[test] + fn deref_static() { + // Create a static lock. + static VEC: RacyLock> = RacyLock::new(|| Vec::from([1, 2, 3])); + + // Validate that the address of the value does not change. + let addr = &*VEC as *const Vec; + for _ in 0..5 { + assert_eq!(*VEC, [1, 2, 3]); + assert_eq!(addr, &(*VEC) as *const Vec) + } + } + + #[test] + fn type_inference() { + // Check that we can infer `T` from closure's type. + let _ = RacyLock::new(|| ()); + } + + #[test] + fn is_sync_send() { + fn assert_traits() {} + assert_traits::>>(); + } +} diff --git a/core/src/utils/sync/rw_lock.rs b/core/src/utils/sync/rw_lock.rs new file mode 100644 index 0000000000..5a5cc5a722 --- /dev/null +++ b/core/src/utils/sync/rw_lock.rs @@ -0,0 +1,305 @@ +#[cfg(not(loom))] +use core::{ + hint, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use lock_api::RawRwLock; +#[cfg(loom)] +use loom::{ + hint, + sync::atomic::{AtomicUsize, Ordering}, +}; + +/// An implementation of a reader-writer lock, based on a spinlock primitive, no-std compatible +/// +/// See [lock_api::RwLock] for usage. +pub type RwLock = lock_api::RwLock; + +/// See [lock_api::RwLockReadGuard] for usage. +pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, Spinlock, T>; + +/// See [lock_api::RwLockWriteGuard] for usage. +pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, Spinlock, T>; + +/// The underlying raw reader-writer primitive that implements [lock_api::RawRwLock] +/// +/// This is fundamentally a spinlock, in that blocking operations on the lock will spin until +/// they succeed in acquiring/releasing the lock. +/// +/// To acheive the ability to share the underlying data with multiple readers, or hold +/// exclusive access for one writer, the lock state is based on a "locked" count, where shared +/// access increments the count by an even number, and acquiring exclusive access relies on the +/// use of the lowest order bit to stop further shared acquisition, and indicate that the lock +/// is exclusively held (the difference between the two is irrelevant from the perspective of +/// a thread attempting to acquire the lock, but internally the state uses `usize::MAX` as the +/// "exlusively locked" sentinel). +/// +/// This mechanism gets us the following: +/// +/// * Whether the lock has been acquired (shared or exclusive) +/// * Whether the lock is being exclusively acquired +/// * How many times the lock has been acquired +/// * Whether the acquisition(s) are exclusive or shared +/// +/// Further implementation details, such as how we manage draining readers once an attempt to +/// exclusively acquire the lock occurs, are described below. +/// +/// NOTE: This is a simple implementation, meant for use in no-std environments; there are much +/// more robust/performant implementations available when OS primitives can be used. +pub struct Spinlock { + /// The state of the lock, primarily representing the acquisition count, but relying on + /// the distinction between even and odd values to indicate whether or not exclusive access + /// is being acquired. + state: AtomicUsize, + /// A counter used to wake a parked writer once the last shared lock is released during + /// acquisition of an exclusive lock. The actual count is not acutally important, and + /// simply wraps around on overflow, but what is important is that when the value changes, + /// the writer will wake and resume attempting to acquire the exclusive lock. + writer_wake_counter: AtomicUsize, +} + +impl Default for Spinlock { + #[inline(always)] + fn default() -> Self { + Self::new() + } +} + +impl Spinlock { + #[cfg(not(loom))] + pub const fn new() -> Self { + Self { + state: AtomicUsize::new(0), + writer_wake_counter: AtomicUsize::new(0), + } + } + + #[cfg(loom)] + pub fn new() -> Self { + Self { + state: AtomicUsize::new(0), + writer_wake_counter: AtomicUsize::new(0), + } + } +} + +unsafe impl RawRwLock for Spinlock { + #[cfg(loom)] + const INIT: Spinlock = unimplemented!(); + + #[cfg(not(loom))] + // This is intentional on the part of the [RawRwLock] API, basically a hack to provide + // initial values as static items. + #[allow(clippy::declare_interior_mutable_const)] + const INIT: Spinlock = Spinlock::new(); + + type GuardMarker = lock_api::GuardSend; + + /// The operation invoked when calling `RwLock::read`, blocks the caller until acquired + fn lock_shared(&self) { + let mut s = self.state.load(Ordering::Relaxed); + loop { + // If the exclusive bit is unset, attempt to acquire a read lock + if s & 1 == 0 { + match self.state.compare_exchange_weak( + s, + s + 2, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => return, + // Someone else beat us to the punch and acquired a lock + Err(e) => s = e, + } + } + // If an exclusive lock is held/being acquired, loop until the lock state changes + // at which point, try to acquire the lock again + if s & 1 == 1 { + loop { + let next = self.state.load(Ordering::Relaxed); + if s == next { + hint::spin_loop(); + continue; + } else { + s = next; + break; + } + } + } + } + } + + /// The operation invoked when calling `RwLock::try_read`, returns whether or not the + /// lock was acquired + fn try_lock_shared(&self) -> bool { + let s = self.state.load(Ordering::Relaxed); + if s & 1 == 0 { + self.state + .compare_exchange_weak(s, s + 2, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } else { + false + } + } + + /// The operation invoked when dropping a `RwLockReadGuard` + unsafe fn unlock_shared(&self) { + if self.state.fetch_sub(2, Ordering::Release) == 3 { + // The lock is being exclusively acquired, and we're the last shared acquisition + // to be released, so wake the writer by incrementing the wake counter + self.writer_wake_counter.fetch_add(1, Ordering::Release); + } + } + + /// The operation invoked when calling `RwLock::write`, blocks the caller until acquired + fn lock_exclusive(&self) { + let mut s = self.state.load(Ordering::Relaxed); + loop { + // Attempt to acquire the lock immediately, or complete acquistion of the lock + // if we're continuing the loop after acquiring the exclusive bit. If another + // thread acquired it first, we race to be the first thread to acquire it once + // released, by busy looping here. + if s <= 1 { + match self.state.compare_exchange( + s, + usize::MAX, + Ordering::Acquire, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(e) => { + s = e; + hint::spin_loop(); + continue; + }, + } + } + + // Only shared locks have been acquired, attempt to acquire the exclusive bit, + // which will prevent further shared locks from being acquired. It does not + // in and of itself grant us exclusive access however. + if s & 1 == 0 { + if let Err(e) = + self.state.compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed) + { + // The lock state has changed before we could acquire the exclusive bit, + // update our view of the lock state and try again + s = e; + continue; + } + } + + // We've acquired the exclusive bit, now we need to busy wait until all shared + // acquisitions are released. + let w = self.writer_wake_counter.load(Ordering::Acquire); + s = self.state.load(Ordering::Relaxed); + + // "Park" the thread here (by busy looping), until the release of the last shared + // lock, which is communicated to us by it incrementing the wake counter. + if s >= 2 { + while self.writer_wake_counter.load(Ordering::Acquire) == w { + hint::spin_loop(); + } + s = self.state.load(Ordering::Relaxed); + } + + // All shared locks have been released, go back to the top and try to complete + // acquisition of exclusive access. + } + } + + /// The operation invoked when calling `RwLock::try_write`, returns whether or not the + /// lock was acquired + fn try_lock_exclusive(&self) -> bool { + let s = self.state.load(Ordering::Relaxed); + if s <= 1 { + self.state + .compare_exchange(s, usize::MAX, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } else { + false + } + } + + /// The operation invoked when dropping a `RwLockWriteGuard` + unsafe fn unlock_exclusive(&self) { + // Infallible, as we hold an exclusive lock + // + // Note the use of `Release` ordering here, which ensures any loads of the lock state + // by other threads, are ordered after this store. + self.state.store(0, Ordering::Release); + // This fetch_add isn't important for signaling purposes, however it serves a key + // purpose, in that it imposes a memory ordering on any loads of this field that + // have an `Acquire` ordering, i.e. they will read the value stored here. Without + // a `Release` store, loads/stores of this field could be reordered relative to + // each other. + self.writer_wake_counter.fetch_add(1, Ordering::Release); + } +} + +#[cfg(all(loom, test))] +mod test { + use alloc::vec::Vec; + + use loom::{model::Builder, sync::Arc}; + + use super::rwlock::{RwLock, Spinlock}; + + #[test] + fn test_rwlock_loom() { + let mut builder = Builder::default(); + builder.max_duration = Some(std::time::Duration::from_secs(60)); + builder.log = true; + builder.check(|| { + let raw_rwlock = Spinlock::new(); + let n = Arc::new(RwLock::from_raw(raw_rwlock, 0usize)); + let mut readers = Vec::new(); + let mut writers = Vec::new(); + + let num_readers = 2; + let num_writers = 2; + let num_iterations = 2; + + // Readers should never observe a non-zero value + for _ in 0..num_readers { + let n0 = n.clone(); + let t = loom::thread::spawn(move || { + for _ in 0..num_iterations { + let guard = n0.read(); + assert_eq!(*guard, 0); + } + }); + + readers.push(t); + } + + // Writers should never observe a non-zero value once they've + // acquired the lock, and should never observe a value > 1 + // while holding the lock + for _ in 0..num_writers { + let n0 = n.clone(); + let t = loom::thread::spawn(move || { + for _ in 0..num_iterations { + let mut guard = n0.write(); + assert_eq!(*guard, 0); + *guard += 1; + assert_eq!(*guard, 1); + *guard -= 1; + assert_eq!(*guard, 0); + } + }); + + writers.push(t); + } + + for t in readers { + t.join().unwrap(); + } + + for t in writers { + t.join().unwrap(); + } + }) + } +} diff --git a/docs/src/assets/design/decoder/decoder_block_stack_table.png b/docs/src/assets/design/decoder/decoder_block_stack_table.png index d56dae6c48..cfa1ac1eeb 100644 Binary files a/docs/src/assets/design/decoder/decoder_block_stack_table.png and b/docs/src/assets/design/decoder/decoder_block_stack_table.png differ diff --git a/docs/src/assets/design/decoder/decoder_call_operation.png b/docs/src/assets/design/decoder/decoder_call_operation.png new file mode 100644 index 0000000000..6572260d14 Binary files /dev/null and b/docs/src/assets/design/decoder/decoder_call_operation.png differ diff --git a/docs/src/assets/design/decoder/decoder_dyn_operation.png b/docs/src/assets/design/decoder/decoder_dyn_operation.png index 93426e0d31..71da6f13a6 100644 Binary files a/docs/src/assets/design/decoder/decoder_dyn_operation.png and b/docs/src/assets/design/decoder/decoder_dyn_operation.png differ diff --git a/docs/src/assets/design/decoder/decoder_dyncall_operation.png b/docs/src/assets/design/decoder/decoder_dyncall_operation.png new file mode 100644 index 0000000000..2c0effc4d0 Binary files /dev/null and b/docs/src/assets/design/decoder/decoder_dyncall_operation.png differ diff --git a/docs/src/assets/design/decoder/decoder_end_operation.png b/docs/src/assets/design/decoder/decoder_end_operation.png index 7b36383573..2dd02bc748 100644 Binary files a/docs/src/assets/design/decoder/decoder_end_operation.png and b/docs/src/assets/design/decoder/decoder_end_operation.png differ diff --git a/docs/src/assets/design/decoder/decoder_syscall_operation.png b/docs/src/assets/design/decoder/decoder_syscall_operation.png new file mode 100644 index 0000000000..adbbf7cc6c Binary files /dev/null and b/docs/src/assets/design/decoder/decoder_syscall_operation.png differ diff --git a/docs/src/design/decoder/constraints.md b/docs/src/design/decoder/constraints.md index bf12f217ca..30f8aced83 100644 --- a/docs/src/design/decoder/constraints.md +++ b/docs/src/design/decoder/constraints.md @@ -1,6 +1,6 @@ # Miden VM decoder AIR constraints -In this section we describe AIR constraint for Miden VM program decoder. These constraints enforce that the execution trace generated by the prover when executing a particular program complies with the rules described in the [previous section](./main.md). +In this section we describe AIR constraints for Miden VM program decoder. These constraints enforce that the execution trace generated by the prover when executing a particular program complies with the rules described in the [previous section](./main.md). To refer to decoder execution trace columns, we use the names shown on the diagram below (these are the same names as in the previous section). Additionally, we denote the register containing the value at the top of the stack as $s_0$. @@ -23,11 +23,12 @@ AIR constraints for the decoder involve operations listed in the table below. Fo | `SYSCALL` | $f_{syscall}$ | 4 | Stack remains unchanged. | | `END` | $f_{end}$ | 4 | When exiting a loop block, top stack element is dropped; otherwise, the stack remains unchanged. | | `HALT` | $f_{halt}$ | 4 | Stack remains unchanged. | -| `PUSH` | $f_{push}$ | 4 | An immediate value is pushed onto the stack. | +| `PUSH` | $f_{push}$ | 5 | An immediate value is pushed onto the stack. | +| `EMIT` | $f_{emit}$ | 5 | Stack remains unchanged. | We also use the [control flow flag](../stack/op_constraints.md#control-flow-flag) $f_{ctrl}$ exposed by the VM, which is set when any one of the above control flow operations is being executed. It has degree $5$. -As described [previously](./main.md#program-decoding), the general idea of the decoder is that the prover provides the program to the VM by populating some of cells in the trace non-deterministically. Values in these are then used to update virtual tables (represented via multiset checks) such as block hash table, block stack table etc. Transition constraints are used to enforce that the tables are updates correctly, and we also apply boundary constraints to enforce the correct initial and final states of these tables. One of these boundary constraints binds the execution trace to the hash of the program being executed. Thus, if the virtual tables were updated correctly and boundary constraints hold, we can be convinced that the prover executed the claimed program on the VM. +As described [previously](./main.md#program-decoding), the general idea of the decoder is that the prover provides the program to the VM by populating some of cells in the trace non-deterministically. Values in these are then used to update virtual tables (represented via multiset checks) such as block hash table, block stack table etc. Transition constraints are used to ensure that the tables are updates correctly, and we also apply boundary constraints to enforce the correct initial and final states of these tables. One of these boundary constraints binds the execution trace to the hash of the program being executed. Thus, if the virtual tables were updated correctly and boundary constraints hold, we can be convinced that the prover executed the claimed program on the VM. In the sections below, we describe constraints according to their logical grouping. However, we start out with a set of general constraints which are applicable to multiple parts of the decoder. @@ -137,10 +138,10 @@ $$ In the above, $a$ represents the address value in the decoder which corresponds to the hasher chiplet address at which the hasher was initialized (or the last absorption took place). As such, $a + 7$ corresponds to the hasher chiplet address at which the result is returned. $$ -f_{ctrli} = f_{join} + f_{split} + f_{loop} + f_{dyn} + f_{call} \text{ | degree} = 5 +f_{ctrli} = f_{join} + f_{split} + f_{loop} + f_{call} \text{ | degree} = 5 $$ -In the above, $f_{ctrli}$ is set to $1$ when a control flow operation that signifies the initialization of a control block is being executed on the VM. Otherwise, it is set to $0$. An exception is made for the `SYSCALL` operation. Although it also signifies the initialization of a control block, it must additionally send a procedure access request to the [kernel ROM chiplet](../chiplets/kernel_rom.md) via the chiplets bus. Therefore, it is excluded from this flag and its communication with the chiplets bus is handled separately. +In the above, $f_{ctrli}$ is set to $1$ when a control flow operation that signifies the initialization of a control block is being executed on the VM (only those control blocks that don't do any concurrent requests to the chiplets but). Otherwise, it is set to $0$. An exception is made for the `DYN`, `DYNCALL`, and `SYSCALL` operations, since although they initialize a control block, they also run another concurrent bus request, and so are handled separately. $$ d = \sum_{b=0}^6(b_i \cdot 2^i) @@ -150,7 +151,7 @@ In the above, $d$ represents the opcode value of the opcode being executed on th Using the above variables, we define operation values as described below. -When a control block initializer operation (`JOIN`, `SPLIT`, `LOOP`, `DYN`, `CALL`, `SYSCALL`) is executed, a new hasher is initialized and the contents of $h_0, ..., h_7$ are absorbed into the hasher. As mentioned above, the opcode value $d$ is populated in the second capacity resister via the $\alpha_5$ term. +When a control block initializer operation (`JOIN`, `SPLIT`, `LOOP`, `CALL`) is executed, a new hasher is initialized and the contents of $h_0, ..., h_7$ are absorbed into the hasher. As mentioned above, the opcode value $d$ is populated in the second capacity register via the $\alpha_5$ term. $$ u_{ctrli} = f_{ctrli} \cdot (h_{init} + \alpha_5 \cdot d) \text{ | degree} = 6 @@ -170,6 +171,22 @@ $$ The above value sends both the hash initialization request and the kernel procedure access request to the chiplets bus when the `SYSCALL` operation is executed. +Similar to `SYSCALL`, `DYN` and `DYNCALL` are handled separately, since in addition to communicating with the hash chiplet they must also issue a memory read operation for the hash of the procedure being called. + +$$ +h_{dynordyncall} = \alpha_0 + \alpha_1 \cdot m_{bp} + \alpha_2 \cdot a' +$$ + +$$ +m_{dynordyncall} = \alpha_0 + \alpha_1 \cdot m_{read} + \alpha_2 \cdot ctx + \alpha_3 \cdot s_0 + \alpha_4 \cdot clk + <[\alpha_5 \dots \alpha_8], h[0 \dots 4]> +$$ + +$$ +u_{dynordyncall} = (f_{dyn} + f_{dyncall}) (h_{dynordyncall} \cdot m_{dynordyncall}) +$$ + +In the above, $h_{dynordyncall}$ can be thought of as $h_{init}$, but where the values used for the hasher decoder trace registers is all 0's. $m_{dynordyncall}$ represents a memory read request from memory address $s_0$ (the top stack element), where the result is placed in the first half of the decoder hasher trace, and where $m_{read}$ is a label that represents a memory read request. + When `SPAN` operation is executed, a new hasher is initialized and contents of $h_0, ..., h_7$ are absorbed into the hasher. The number of operation groups to be hashed is padded to a multiple of the rate width ($8$) and so the $\alpha_4$ is set to 0: $$ @@ -191,8 +208,8 @@ $$ Using the above definitions, we can describe the constraint for computing block hashes as follows: > $$ -b_{chip}' \cdot (u_{ctrli} + u_{syscall} + u_{span} + u_{respan} + u_{end} + \\ -1 - (f_{ctrli} + f_{syscall} + f_{span} + f_{respan} + f_{end})) = b_{chip} +b_{chip}' \cdot (u_{ctrli} + u_{syscall} + u_{dynordyncall} + u_{span} + u_{respan} + u_{end} + \\ +1 - (f_{ctrli} + f_{syscall} + f_{dyn} + f_{dyncall} + f_{span} + f_{respan} + f_{end})) = b_{chip} $$ We need to add $1$ and subtract the sum of the relevant operation flags to ensure that when none of the flags is set to $1$, the above constraint reduces to $b_{chip}' = b_{chip}$. @@ -203,8 +220,11 @@ The degree of this constraint is $8$. As described [previously](./main.md#block-stack-table), block stack table keeps track of program blocks currently executing on the VM. Thus, whenever the VM starts executing a new block, an entry for this block is added to the block stack table. And when execution of a block completes, it is removed from the block stack table. Adding and removing entries to/from the block stack table is accomplished as follows: -* To add an entry, we multiply the value in column $p_1$ by a value representing a tuple `(blk_id, prnt_id, is_loop)`. A constraint to enforce this would look as $p_1' = p_1 \cdot v$, where $v$ is the value representing the row to be added. -* To remove an entry, we divide the value in column $p_1$ by a value representing a tuple `(blk_id, prnt_id, is_loop)`. A constraint to enforce this would look as $p_1' \cdot u = p_1$, where $u$ is the value representing the row to be removed. +* To add an entry, we multiply the value in column $p_1$ by a value representing a tuple `(blk, prnt, is_loop, ctx_next, fmp_next, b0_next, b1_next, fn_hash_next)` +. A constraint to enforce this would look as $p_1' = p_1 \cdot v$, where $v$ is the value representing the row to be added. +* To remove an entry, we divide the value in column $p_1$ by a value representing a tuple `(blk, prnt, is_loop, ctx_next, fmp_next, b0_next, b1_next, fn_hash_next)`. A constraint to enforce this would look as $p_1' \cdot u = p_1$, where $u$ is the value representing the row to be removed. + +> Recall that the columns `ctx_next, fmp_next, b0_next, b1_next, fn_hash_next` are only set on `CALL`, `SYSCALL`, and their corresponding `END` block. Therefore, for simplicity, we will ignore them when documenting all other block types (such that their values are set to `0`). Before describing the constraints for the block stack table, we first describe how we compute the values to be added and removed from the table for each operation. In the below, for block start operations (`JOIN`, `SPLIT`, `LOOP`, `SPAN`) $a$ refers to the ID of the parent block, and $a'$ refers to the ID of the starting block. For `END` operation, the situation is reversed: $a$ is the ID of the ending block, and $a'$ is the ID of the parent block. For `RESPAN` operation, $a$ refers to the ID of the current operation batch, $a'$ refers to the ID of the next batch, and the parent ID for both batches is set by the prover non-deterministically in register $h_1$. @@ -245,25 +265,47 @@ $$ v_{dyn} = f_{dyn} \cdot (\alpha_0 + \alpha_1 \cdot a' + \alpha_2 \cdot a) \text{ | degree} = 6 $$ -When `END` operation is executed, row $(a, a', h_5)$ is removed from the block span table. Register $h_5$ contains the `is_loop` flag: +When a `DYNCALL` operation is executed, row $(a', a, 0, ctx, fmp, b_0, b_1, \mathrm{fnhash}[0..3])$ is added to the block stack table: $$ -u_{end} = f_{end} \cdot (\alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot a' + \alpha_3 \cdot h_5) \text{ | degree} = 5 +\begin{align*} +v_{dyncall} &= f_{dyncall} \cdot (\alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot a' + \alpha_4 \cdot ctx \\ +&+ \alpha_5 \cdot fmp + \alpha_6 \cdot b_0 + \alpha_7 \cdot b_1 + <[\alpha_8, \alpha_{11}], \mathrm{fnhash}[0..3]>) \text{ | degree} = 6 +\end{align*} +$$ + +When a `CALL` or `SYSCALL` operation is executed, row $(a', a, 0, ctx, fmp, b_0, b_1, \mathrm{fnhash}[0..3])$ is added to the block stack table: + +$$ +\begin{align*} +v_{callorsyscall} &= (f_{call} + f_{syscall}) \cdot (\alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot a' + \alpha_4 \cdot ctx \\ +&+ \alpha_5 \cdot fmp + \alpha_6 \cdot b_0 + \alpha_7 \cdot b_1 + <[\alpha_8, \alpha_{11}], \mathrm{fnhash}[0..3]>) \text{ | degree} = 5 +\end{align*} +$$ + +When `END` operation is executed, how we construct the row will depend on whether the `IS_CALL` or `IS_SYSCALL` values are set (stored in registers $h_6$ and $h_7$ respectively). If they are not set, then row $(a, a', h_5)$ is removed from the block span table (where $h_5$ contains the `is_loop` flag); otherwise, row $(a ,a', 0, ctx', fmp', b_0', b_1', \mathrm{fnhash}'[0..3])$. + +$$ +\begin{align*} +u_{endnocall} &=\alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot a' \\ +u_{endcall} &= u_{endnocall} + \alpha_4 \cdot ctx' + \alpha_5 \cdot fmp' + \alpha_6 \cdot b_0' + \alpha_7 \cdot b_1' + <[\alpha_8, \alpha_{11}], \mathrm{fnhash}'[0..3]>\\ +u_{end} &= f_{end} \cdot ((1 - h_6 - h_7) \cdot u_{endnocall} + (h_6 + h_7) \cdot u_{endcall} ) \text{ | degree} = 6 +\end{align*} $$ Using the above definitions, we can describe the constraint for updating the block stack table as follows: > $$ p_1' \cdot (u_{end} + u_{respan} + 1 - (f_{end} + f_{respan})) = p_1 \cdot \\ -(v_{join} + v_{split} + v_{loop} + v_{span} + v_{respan} + v_{dyn} + 1 - \\ -(f_{join} + f_{split} + f_{loop} + f_{span} + f_{respan} + f_{dyn})) +(v_{join} + v_{split} + v_{loop} + v_{span} + v_{respan} + v_{dyn} + v_{dyncall} + v_{callorsyscall} + 1 - \\ +(f_{join} + f_{split} + f_{loop} + f_{span} + f_{respan} + f_{dyn} + f_{dyncall} + f_{call} + f_{syscall})) $$ We need to add $1$ and subtract the sum of the relevant operation flags from each side to ensure that when none of the flags is set to $1$, the above constraint reduces to $p_1' = p_1$. The degree of this constraint is $7$. -In addition to the above transition constraint, we also need to impose boundary constraints against the $p_1$ column to make sure the first and the last value in the column is set to $1$. This enforces that the block stack table starts and ends in an empty state. +In addition to the above transition constraint, we also need to impose boundary constraints against the $p_1$ column to make sure the first and the last values in the column are set to $1$. This enforces that the block stack table starts and ends in an empty state. ## Block hash table constraints As described [previously](./main.md#block-hash-table), when the VM starts executing a new program block, it adds hashes of the block's children to the block hash table. And when the VM finishes executing a block, it removes the block's hash from the block hash table. This means that the block hash table gets updated when we execute the `JOIN`, `SPLIT`, `LOOP`, `REPEAT`, `DYN`, and `END` operations (executing `SPAN` operation does not affect the block hash table because a *span* block has no children). @@ -286,10 +328,10 @@ Graphically, this looks like so: In a similar manner, we define a value representing the result of hash computation as follows: $$ -bh = \alpha_0 + \alpha_1 \cdot a + \sum_{i=0}^3(\alpha_{i+2} \cdot h_i) + \alpha_7 \cdot h_4 \text{ | degree} = 1 +bh = \alpha_0 + \alpha_1 \cdot a' + \sum_{i=0}^3(\alpha_{i+2} \cdot h_i) + \alpha_7 \cdot f_{is\_loop\_body} \text{ | degree} = 1 $$ -Note that in the above we use $a$ (block address from the current row) rather than $a'$ (block address from the next row) as we did for for values of $ch_1$ and $ch_2$. Also, note that we are not adding a flag indicating whether the block is the first child of a join block (i.e., $\alpha_6$ term is missing). It will be added later on. +Above, $f_{is\_loop\_body}$ refers to the value in the `IS_LOOP_BODY` column (already constrained to be 0 or 1), located in $h_4$. Also, note that we are not adding a flag indicating whether the block is the first child of a join block (i.e., $\alpha_6$ term is missing). It will be added later on. Using the above variables, we define row values to be added to and removed from the block hash table as follows. @@ -315,27 +357,23 @@ When `REPEAT` operation is executed, hash of loop body is added to the block has $$v_{repeat} = f_{repeat} \cdot (ch_1 + \alpha_7) \text{ | } \text{degree} = 5$$ -When the `DYN` operation is executed, the hash of the dynamic child is added to the block hash table. Since the child is dynamically specified by the top four elements of the stack, the value representing the *dyn* block's child must be computed based on the stack rather than from the decoder's hasher registers: +When `DYN`, `DYNCALL`, `CALL` or `SYSCALL` operation is executed, the hash of the child is added to the block hash table. In all cases, this child is found in the first half of the decoder hasher state. $$ -ch_{dyn} = \alpha_0 + \alpha_1 \cdot a' + \sum_{i=0}^3(\alpha_{i+2} \cdot s_{3-i}) \text{ | degree} = 1 +v_{allcalls} = (f_{dyn} + f_{dyncall} + f_{call} + f_{syscall}) \cdot ch_1 \text{ | degree} = 6 $$ -$$ -v_{dyn} = f_{dyn} \cdot ch_{dyn} \text{ | degree} = 6 -$$ - -When `END` operation is executed, hash of the completed block is removed from the block hash table. However, we also need to differentiate between removing the first and the second child of a *join* block. We do this by looking at the next operation. Specifically, if the next operation is neither `END` nor `REPEAT` we know that another block is about to be executed, and thus, we have just finished executing the first child of a *join* block. Thus, if the next operation is neither `END` nor `REPEAT` we need to set the term for $\alpha_6$ coefficient to $1$ as shown below: +When `END` operation is executed, hash of the completed block is removed from the block hash table. However, we also need to differentiate between removing the first and the second child of a *join* block. We do this by looking at the next operation. Specifically, if the next operation is neither `END` nor `REPEAT` nor `HALT`, we know that another block is about to be executed, and thus, we have just finished executing the first child of a *join* block. Thus, if the next operation is neither `END` nor `REPEAT` nor `HALT` we need to set the term for $\alpha_6$ coefficient to $1$ as shown below: $$ -u_{end} = f_{end} \cdot (bh + \alpha_6 \cdot (1 - (f_{end}' + f_{repeat}'))) \text{ | } \text{degree} = 8 +u_{end} = f_{end} \cdot (bh + \alpha_6 \cdot (1 - (f_{end}' + f_{repeat}' + f_{halt}'))) \text{ | } \text{degree} = 8 $$ Using the above definitions, we can describe the constraint for updating the block hash table as follows: > $$ p_2' \cdot (u_{end} + 1 - f_{end}) = \\ -p_2 \cdot (v_{join} + v_{split} + v_{loop} + v_{repeat} + v_{dyn} + 1 - (f_{join} + f_{split} + f_{loop} + f_{repeat} + f_{dyn})) +p_2 \cdot (v_{join} + v_{split} + v_{loop} + v_{repeat} + v_{allcalls} + 1 - (f_{join} + f_{split} + f_{loop} + f_{repeat} + f_{dyn} + f_{dyncall} + f_{call} + f_{syscall})) $$ We need to add $1$ and subtract the sum of the relevant operation flags from each side to ensure that when none of the flags is set to $1$, the above constraint reduces to $p_2' = p_2$. @@ -404,9 +442,9 @@ In the beginning of a span block (i.e., when `SPAN` operation is executed), the The rules for decrementing values in the $gc$ column are as follows: * The count cannot be decremented by more than $1$ in a single row. * When an operation group is fully executed (which happens when $h_0 = 0$ inside a span block), the count is decremented by $1$. -* When `SPAN`, `RESPAN`, or `PUSH` operations are executed, the count is decremented by $1$. +* When `SPAN`, `RESPAN`, `EMIT` or `PUSH` operations are executed, the count is decremented by $1$. -Note that these rules imply that `PUSH` operation cannot be the last operation in an operation group (otherwise the count would have to be decremented by $2$). +Note that these rules imply that the `EMIT` and `PUSH` operations cannot be the last operation in an operation group (otherwise the count would have to be decremented by $2$). To simplify the description of the constraints, we will define the following variable: @@ -422,18 +460,18 @@ Inside a *span* block, group count can either stay the same or decrease by one: sp \cdot \Delta gc \cdot (\Delta gc - 1) = 0 \text{ | degree} = 3 $$ -When group count is decremented inside a *span* block, either $h_0$ must be $0$ (we consumed all operations in a group) or we must be executing `PUSH` operation: +When group count is decremented inside a *span* block, either $h_0$ must be $0$ (we consumed all operations in a group) or we must be executing an operation with an immediate value: > $$ -sp \cdot \Delta gc \cdot (1 - f_{push})\cdot h_0 = 0 \text{ | degree} = 7 +sp \cdot \Delta gc \cdot (1 - f_{imm})\cdot h_0 = 0 \text{ | degree} = 7 $$ -Notice that the above constraint does not preclude $f_{push} = 1$ and $h_0 = 0$ from being true at the same time. If this happens, op group decoding constraints (described [here](#op-group-decoding-constraints)) will force that the operation following the `PUSH` operation is a `NOOP`. +Notice that the above constraint does not preclude $f_{imm} = 1$ and $h_0 = 0$ from being true at the same time. If this happens, op group decoding constraints (described [here](#op-group-decoding-constraints)) will force that the operation following the operation with an immediate value is a `NOOP`. -When executing a `SPAN`, a `RESPAN`, or a `PUSH` operation, group count must be decremented by $1$: +When executing a `SPAN`, a `RESPAN`, or an operation with an immediate value, group count must be decremented by $1$: > $$ -(f_{span} + f_{respan} + f_{push}) \cdot (\Delta gc - 1) = 0 \text{ | degree} = 6 +(f_{span} + f_{respan} + f_{imm}) \cdot (\Delta gc - 1) = 0 \text{ | degree} = 6 $$ If the next operation is either an `END` or a `RESPAN`, group count must remain the same: @@ -467,17 +505,17 @@ op = \sum_{i=0}^6 (b_i \cdot 2^i) \\ f_{sgc} = sp \cdot sp' \cdot (1 - \Delta gc) $$ -$op$ is just an opcode value implied by the values in `op_bits` registers. $f_{sgc}$ is a flag which is set to $1$ when the group count within a *span* block does not change. We multiply it by $sp'$ to make sure the flag is $0$ when we are about to end decoding of an operation batch. Note that $f_{sgc}$ flag is mutually exclusive with $f_{span}$, $f_{respan}$, and $f_{push}$ flags as these three operations decrement the group count. +$op$ is just an opcode value implied by the values in `op_bits` registers. $f_{sgc}$ is a flag which is set to $1$ when the group count within a *span* block does not change. We multiply it by $sp'$ to make sure the flag is $0$ when we are about to end decoding of an operation batch. Note that $f_{sgc}$ flag is mutually exclusive with $f_{span}$, $f_{respan}$, and $f_{imm}$ flags as these three operations decrement the group count. Using these variables, we can describe operation group decoding constraints as follows: -When a `SPAN`, a `RESPAN`, or a `PUSH` operation is executed or when the group count does not change, the value in $h_0$ should be decremented by the value of the opcode in the next row. +When a `SPAN`, a `RESPAN`, or an operation with an immediate value is executed or when the group count does not change, the value in $h_0$ should be decremented by the value of the opcode in the next row. > $$ -(f_{span} + f_{respan} + f_{push} + f_{sgc}) \cdot (h_0 - h_0' \cdot 2^7 - op') = 0 \text{ | degree} = 6 +(f_{span} + f_{respan} + f_{imm} + f_{sgc}) \cdot (h_0 - h_0' \cdot 2^7 - op') = 0 \text{ | degree} = 6 $$ -Notice that when the group count does change, and we are not executing $f_{span}$, $f_{respan}$, or $f_{push}$ operations, no constraints are placed against $h_0$, and thus, the prover can populate this register non-deterministically. +Notice that when the group count does change, and we are not executing $f_{span}$, $f_{respan}$, or $f_{imm}$ operations, no constraints are placed against $h_0$, and thus, the prover can populate this register non-deterministically. When we are in a *span* block and the next operation is `END` or `RESPAN`, the current value in $h_0$ column must be $0$. @@ -491,11 +529,11 @@ The `op_index` column (denoted as $ox$) tracks index of an operation within its To simplify the description of the constraints, we will define the following variables: $$ -ng = \Delta gc - f_{push} \\ +ng = \Delta gc - f_{imm} \\ \Delta ox = ox' - ox $$ -The value of $ng$ is set to $1$ when we are about to start executing a new operation group (i.e., group count is decremented but we did not execute a `PUSH` operation). Using these variables, we can describe the constraints against the $ox$ column as follows. +The value of $ng$ is set to $1$ when we are about to start executing a new operation group (i.e., group count is decremented but we did not execute an operation with an immediate value). Using these variables, we can describe the constraints against the $ox$ column as follows. When executing `SPAN` or `RESPAN` operations the next value of `op_index` must be set to $0$: @@ -547,6 +585,12 @@ When `SPAN` or `RESPAN` operations is executed, one of the batch flags must be s (f_{span} + f_{respan}) - (f_{g1} + f_{g2} + f_{g4} + f_{g8}) = 0 \text{ | degree} = 5 $$ +When neither `SPAN` nor `RESPAN` is executed, all batch flags must be set to $0$. + +$$ +(1 - (f_{span} + f_{respan})) \cdot (bc_0 + bc_1 + bc_2) = 0 \text{ | degree} = 6 +$$ + When we have at most 4 groups in a batch, registers $h_4, ..., h_7$ should be set to $0$'s. > $$ @@ -584,15 +628,15 @@ Where $i \in [1, 8)$. Thus, $v_1$ defines row value for group in $h_1$, $v_2$ de We compute the value of the row to be removed from the op group table as follows: $$ -u = \alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot gc + \alpha_3 \cdot ((h_0' \cdot 2^7 + op') \cdot (1 - f_{push}) + s_0' \cdot f_{push}) \text{ | degree} = 5 +u = \alpha_0 + \alpha_1 \cdot a + \alpha_2 \cdot gc + \alpha_3 \cdot ((h_0' \cdot 2^7 + op') \cdot (1 - f_{imm}) + s_0' \cdot f_{push} + h_2 \cdot f_{emit}) \text{ | degree} = 6 $$ -In the above, the value of the group is computed as $(h_0' \cdot 2^7 + op') \cdot (1 - f_{push}) + s_0' \cdot f_{push}$. This basically says that when we execute a `PUSH` operation we need to remove the immediate value from the table. This value is at the top of the stack (column $s_0$) in the next row. However, when we are not executing a `PUSH` operation, the value to be removed is an op group value which is a combination of values in $h_0$ and `op_bits` columns (also in the next row). Note also that value for batch address comes from the current value in the block address column ($a$), and the group position comes from the current value of the group count column ($gc$). +In the above, the value of the group is computed as $(h_0' \cdot 2^7 + op') \cdot (1 - f_{push}) + s_0' \cdot f_{push} + h_2 \cdot f_{emit}$. This basically says that when we execute a `PUSH` or `EMIT` operation we need to remove the immediate value from the table. For `PUSH`, this value is at the top of the stack (column $s_0$) in the next row; for `EMIT`, it is found in $h_2$. However, when we are executing neither a `PUSH` nor `EMIT` operation, the value to be removed is an op group value which is a combination of values in $h_0$ and `op_bits` columns (also in the next row). Note also that value for batch address comes from the current value in the block address column ($a$), and the group position comes from the current value of the group count column ($gc$). We also define a flag which is set to $1$ when a group needs to be removed from the op group table. $$ -f_{dg} = sp \cdot \Delta gc +f_{dg} = sp \cdot \Delta gc \text{ | degree} = 2 $$ The above says that we remove groups from the op group table whenever group count is decremented. We multiply by $sp$ to exclude the cases when the group count is decremented due to `SPAN` or `RESPAN` operations. @@ -600,12 +644,12 @@ The above says that we remove groups from the op group table whenever group coun Using the above variables together with flags $f_{g2}$, $f_{g4}$, $f_{g8}$ defined in the previous section, we describe the constraint for updating op group table as follows (note that we do not use $f_{g1}$ flag as when a batch consists of a single group, nothing is added to the op group table): > $$ -p_3' \cdot (f_{dg} \cdot u + 1 - f_{dg}) = p_3 \cdot (f_{g2} \cdot v_1 + f_{g4} \cdot \prod_{i=1}^3 v_i + f_{g8} \cdot \prod_{i=1}^7 v_i - 1 + (f_{span} + f_{respan})) +p_3' \cdot (f_{dg} \cdot u + 1 - f_{dg}) = p_3 \cdot (f_{g2} \cdot v_1 + f_{g4} \cdot \prod_{i=1}^3 v_i + f_{g8} \cdot (\prod_{i=1}^7 v_i) + 1 - (f_{span} + f_{respan})) $$ The above constraint specifies that: -* When `SPAN` or `RESPAN` operations are executed, we add between $1$ and $7$ groups to the op group table. -* When group count is decremented inside a *span* block, we remove a group from the op group table. +* When `SPAN` or `RESPAN` operations are executed, we add between $1$ and $7$ groups to the op group table; else, leave $p3$ untouched. +* When group count is decremented inside a *span* block, we remove a group from the op group table; else, leave $p3'$ untouched. The degree of this constraint is $9$. diff --git a/docs/src/design/decoder/main.md b/docs/src/design/decoder/main.md index af2020fe8d..b48604cd5d 100644 --- a/docs/src/design/decoder/main.md +++ b/docs/src/design/decoder/main.md @@ -17,7 +17,7 @@ The sections below describe how Miden VM decoder works. Throughout these section Miden VM programs consist of a set of code blocks organized into a binary tree. The leaves of the tree contain linear sequences of instructions, and control flow is defined by the internal nodes of the tree. -Managing control flow in the VM is accomplished by executing control flow operations listed in the table below. Each of these operations require exactly one VM cycle to execute. +Managing control flow in the VM is accomplished by executing control flow operations listed in the table below. Each of these operations requires exactly one VM cycle to execute. | Operation | Description | | --------- | ---------------------------------------------------------------------------- | @@ -118,10 +118,10 @@ These registers have the following meanings: 2. Registers $b_0, ..., b_6$, which encode opcodes for operation to be executed by the VM. Each of these registers can contain a single binary value (either $1$ or $0$). And together these values describe a single opcode. 3. Hasher registers $h_0, ..., h_7$. When control flow operations are executed, these registers are used to provide inputs for the current block's hash computation (e.g., for `JOIN`, `SPLIT`, `LOOP`, `SPAN`, `CALL`, `SYSCALL` operations) or to record the result of the hash computation (i.e., for `END` operation). However, when regular operations are executed, $2$ of these registers are used to help with op group decoding, and the remaining $6$ can be used to hold operation-specific helper variables. 4. Register $sp$ which contains a binary flag indicating whether the VM is currently executing instructions inside a *span* block. The flag is set to $1$ when the VM executes non-control flow instructions, and is set to $0$ otherwise. -5. Register $gc$ which keep track of the number of unprocessed operation groups in a given *span* block. +5. Register $gc$ which keeps track of the number of unprocessed operation groups in a given *span* block. 6. Register $ox$ which keeps track of a currently executing operation's index within its operation group. 7. Operation batch flags $c_0, c_1, c_2$ which indicate how many operation groups a given operation batch contains. These flags are set only for `SPAN` and `RESPAN` operations, and are set to $0$'s otherwise. -8. Two additional registers (not shown) used primarily for constraint degree reduction. +8. Two additional registers (not shown) are used primarily for constraint degree reduction. ### Program block hashing @@ -187,7 +187,9 @@ In addition to the hash chiplet, control flow operations rely on $3$ virtual tab When the VM starts executing a new program block, it adds its block ID together with the ID of its parent block (and some additional info) to the *block stack* table. When a program block is fully executed, it is removed from the table. In this way, the table represents a stack of blocks which are currently executing on the VM. By the time program execution completes, block stack table must be empty. -The table can be thought of as consisting of $3$ columns as shown below: +The block stack table is also used to ensure that execution contexts are managed properly across the `CALL` and `SYSCALL` operations. + +The table can be thought of as consisting of $11$ columns as shown below: ![decoder_block_stack_table](../../assets/design/decoder/decoder_block_stack_table.png) @@ -195,16 +197,20 @@ where: * The first column ($t_0$) contains the ID of the block. * The second column ($t_1$) contains the ID of the parent block. If the block has no parent (i.e., it is a root block of the program), parent ID is 0. * The third column ($t_2$) contains a binary value which is set to $1$ is the block is a *loop* block, and to $0$ otherwise. +* The following 8 columns are only set to non-zero values for `CALL` and `SYSCALL` operations. They save all the necessary information to be able to restore the parent context properly upon the corresponding `END` operation + - the `prnt_b0` and `prnt_b1` columns refer to the stack helper columns B0 and B1 (current stack depth and last overflow address, respectively) + +In the above diagram, the first 2 rows correspond to 2 different `CALL` operations. The first `CALL` operation is called from the root context, and hence its parent fn hash is the zero hash. Additionally, the second `CALL` operation has a parent fn hash of `[h0, h1, h2, h3]`, indicating that the first `CALL` was to a procedure with that hash. Running product column $p_1$ is used to keep track of the state of the table. At any step of the computation, the current value of $p_1$ defines which rows are present in the table. To reduce a row in the block stack table to a single value, we compute the following. $$ -row = \alpha_0 + \sum_{i=0}^3 (\alpha_{i+1} \cdot t_i) +row = \alpha_0 + \sum_{i=0}^{10} (\alpha_{i+1} \cdot t_i), $$ -Where $\alpha_0, ..., \alpha_3$ are the random values provided by the verifier. +where $\alpha_0, ..., \alpha_{11}$ are the random values provided by the verifier. #### Block hash table @@ -273,7 +279,7 @@ In the above diagram, `blk` is the ID of the *join* block which is about to be e When the VM executes a `JOIN` operation, it does the following: -1. Adds a tuple `(blk, prnt, 0)` to the block stack table. +1. Adds a tuple `(blk, prnt, 0, 0...)` to the block stack table. 2. Adds tuples `(blk, left_child_hash, 1, 0)` and `(blk, right_child_hash, 0, 0)` to the block hash table. 3. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and $h_0, ..., h_7$ as input values. @@ -287,7 +293,7 @@ In the above diagram, `blk` is the ID of the *split* block which is about to be When the VM executes a `SPLIT` operation, it does the following: -1. Adds a tuple `(blk, prnt, 0)` to the block stack table. +1. Adds a tuple `(blk, prnt, 0, 0...)` to the block stack table. 2. Pops the stack and:\ a. If the popped value is $1$, adds a tuple `(blk, true_branch_hash, 0, 0)` to the block hash table.\ b. If the popped value is $0$, adds a tuple `(blk, false_branch_hash, 0, 0)` to the block hash table.\ @@ -305,8 +311,8 @@ In the above diagram, `blk` is the ID of the *loop* block which is about to be e When the VM executes a `LOOP` operation, it does the following: 1. Pops the stack and:\ - a. If the popped value is $1$ adds a tuple `(blk, prnt, 1)` to the block stack table (the `1` indicates that the loop's body is expected to be executed). Then, adds a tuple `(blk, loop_body_hash, 0, 1)` to the block hash table.\ - b. If the popped value is $0$, adds `(blk, prnt, 0)` to the block stack table. In this case, nothing is added to the block hash table.\ + a. If the popped value is $1$ adds a tuple `(blk, prnt, 1, 0...)` to the block stack table (the `1` indicates that the loop's body is expected to be executed). Then, adds a tuple `(blk, loop_body_hash, 0, 1)` to the block hash table.\ + b. If the popped value is $0$, adds `(blk, prnt, 0, 0...)` to the block stack table. In this case, nothing is added to the block hash table.\ c. If the popped value is neither $1$ nor $0$, the execution fails. 2. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and $h_0, ..., h_3$ as input values. @@ -320,7 +326,7 @@ In the above diagram, `blk` is the ID of the *span* block which is about to be e When the VM executes a `SPAN` operation, it does the following: -1. Adds a tuple `(blk, prnt, 0)` to the block stack table. +1. Adds a tuple `(blk, prnt, 0, 0...)` to the block stack table. 2. Adds groups of the operation batch, as specified by op batch flags (see [here](#operation-batch-flags)) to the op group table. 3. Initiates a sequential hash computation in the hash chiplet (as described [here](#sequential-hash)) using `blk` as row address in the auxiliary hashing table and $h_0, ..., h_7$ as input values. 4. Sets the `in_span` register to $1$. @@ -330,24 +336,45 @@ When the VM executes a `SPAN` operation, it does the following: #### DYN operation -Before a `DYN` operation is executed by the VM, the prover populates $h_0, ..., h_7$ registers with $0$ as shown in the diagram below. - ![decoder_dyn_operation](../../assets/design/decoder/decoder_dyn_operation.png) -In the above diagram, `blk` is the ID of the *dyn* block which is about to be executed. `blk` is also the address of the hasher row in the auxiliary hasher table. `prnt` is the ID of the block's parent. +In the above diagram, `blk` is the ID of the *dyn* block which is about to be executed. `blk` is also the address of the hasher row in the auxiliary hasher table. `p_addr` is the ID of the block's parent. When the VM executes a `DYN` operation, it does the following: -1. Adds a tuple `(blk, prnt, 0)` to the block stack table. -2. Gets the hash of the dynamic code block `dynamic_block_hash` from the top four elements of the stack. -2. Adds the tuple `(blk, dynamic_block_hash, 0, 0)` to the block hash table. -3. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and $h_0, ..., h_7$ as input values. +1. Adds a tuple `(blk, p_addr, 0, 0...)` to the block stack table. +2. Sends a memory read request to the memory chiplet, using `s0` as the memory address. The result `hash of callee` is placed in the decoder hasher trace at $h_0, h_1, h_2, h_3$. +3. Adds the tuple `(blk, hash of callee, 0, 0)` to the block hash table. +4. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and `[ZERO; 8]` as input values. +5. Performs a stack left shift + - Above `s16` was pulled from the stack overflow table if present; otherwise set to `0`. + +Note that unlike `DYNCALL`, the `fmp`, `ctx`, `in_syscall` and `fn_hash` registers are unchanged. + +#### DYNCALL operation + +![decoder_dyncall_operation](../../assets/design/decoder/decoder_dyncall_operation.png) + +In the above diagram, `blk` is the ID of the *dyn* block which is about to be executed. `blk` is also the address of the hasher row in the auxiliary hasher table. `p_addr` is the ID of the block's parent. + +When the VM executes a `DYNCALL` operation, it does the following: + +1. Adds a tuple `(blk, p_addr, 0, ctx, fmp, b_0, b_1, fn_hash[0..3])` to the block stack table. +2. Sends a memory read request to the memory chiplet, using `s0` as the memory address. The result `hash of callee` is placed in the decoder hasher trace at $h_0, h_1, h_2, h_3$. +3. Adds the tuple `(blk, hash of callee, 0, 0)` to the block hash table. +4. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and `[ZERO; 8]` as input values. +5. Performs a stack left shift + - Above `s16` was pulled from the stack overflow table if present; otherwise set to `0`. + +Similar to `CALL`, `DYNCALL` resets the `fmp`, sets up a new `ctx`, and sets the `fn_hash` registers to the callee hash. `in_syscall` needs to be 0, since calls are not allowed during a syscall. #### END operation Before an `END` operation is executed by the VM, the prover populates $h_0, ..., h_3$ registers with the hash of the block which is about to end. The prover also sets values in $h_4$ and $h_5$ registers as follows: * $h_4$ is set to $1$ if the block is a body of a *loop* block. We denote this value as `f0`. * $h_5$ is set to $1$ if the block is a *loop* block. We denote this value as `f1`. +* $h_6$ is set to $1$ if the block is a *call* block. We denote this value as `f2`. +* $h_7$ is set to $1$ if the block is a *syscall* block. We denote this value as `f3`. ![decoder_end_operation](../../assets/design/decoder/decoder_end_operation.png) @@ -355,7 +382,10 @@ In the above diagram, `blk` is the ID of the block which is about to finish exec When the VM executes an `END` operation, it does the following: -1. Removes a tuple `(blk, prnt, f1)` from the block stack table. +1. Removes a tuple from the block stack table. + - if `f2` or `f3` is set, we remove a row `(blk, prnt, 0, ctx_next, fmp_next, b0_next, b1_next, fn_hash_next[0..4])` + - in the above, the `x_next` variables denote the column `x` in the next row + - else, we remove a row `(blk, prnt, f1, 0, 0, 0, 0, 0)` 2. Removes a tuple `(prnt, current_block_hash, nxt, f0)` from the block hash table, where $nxt=0$ if the next operation is either `END` or `REPEAT`, and $1$ otherwise. 3. Reads the hash result from the hash chiplet (as described [here](#program-block-hashing)) using `blk + 7` as row address in the auxiliary hashing table. 4. If $h_5 = 1$ (i.e., we are exiting a *loop* block), pops the value off the top of the stack and verifies that the value is $0$. @@ -402,14 +432,62 @@ In the above diagram, `g0_op0` is the first operation of the new operation batch When the VM executes a `RESPAN` operation, it does the following: 1. Increments block address by $8$. -2. Removes the tuple `(blk, prnt, 0)` from the block stack table. -3. Adds the tuple `(blk+8, prnt, 0)` to the block stack table. +2. Removes the tuple `(blk, prnt, 0, 0...)` from the block stack table. +3. Adds the tuple `(blk+8, prnt, 0, 0...)` to the block stack table. 4. Absorbs values in registers $h_0, ..., h_7$ into the hasher state of the hash chiplet (as described [here](#sequential-hash)). 5. Sets the `in_span` register to $1$. 6. Adds groups of the operation batch, as specified by op batch flags (see [here](#operation-batch-flags)) to the op group table using `blk+8` as batch ID. The net result of the above is that we incremented the ID of the current block by $8$ and added the next set of operation groups to the op group table. +#### CALL operation + +Recall that the purpose of a `CALL` operation is to execute a procedure in a new execution context. Specifically, this means that the entire memory is zero'd in the new execution context, and the stack is truncated to a depth of 16 (i.e. any element in the stack overflow table is not available in the new context). On the corresponding `END` instruction, the prover will restore the previous execution context (verified by the block stack table). + +Before a `CALL` operation, the prover populates $h_0, ..., h_3$ registers with the hash of the procedure being called. In the next row, the prover + +- resets the FMP register (free memory pointer), +- sets the context ID to the next row's CLK value +- sets the `fn hash` registers to the hash of the callee + - This register is what the `caller` instruction uses to return the hash of the caller in a syscall +- resets the stack `B0` register to 16 (which tracks the current stack depth) +- resets the overflow address to 0 (which tracks the "address" of the last element added to the overflow table) + - it is set to 0 to indicate that the overflow table is empty + +![decoder_call_operation](../../assets/design/decoder/decoder_call_operation.png) + +In the above diagram, `blk` is the ID of the *call* block which is about to be executed. `blk` is also the address of the hasher row in the auxiliary hasher table. `prnt` is the ID of the block's parent. + +When the VM executes a `CALL` operation, it does the following: + +1. Adds a tuple `(blk, prnt, 0, p_ctx, p_fmp, p_b0, p_b1, prnt_fn_hash[0..4])` to the block stack table. +2. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and $h_0, ..., h_3$ as input values. + +#### SYSCALL operation + +Similarly to the `CALL` operation, a `SYSCALL` changes the execution context. However, it always jumps back to the root context, and executes kernel procedures only. + +Before a `SYSCALL` operation, the prover populates $h_0, ..., h_3$ registers with the hash of the procedure being called. In the next row, the prover + +- resets the FMP register (free memory pointer), +- sets the context ID to 0, +- does NOT modify the `fn hash` register + - Hence, the `fn hash` register contains the procedure hash of the caller, to be accessed by the `caller` instruction, +- resets the stack `B0` register to 16 (which tracks the current stack depth) +- resets the overflow address to 0 (which tracks the "address" of the last element added to the overflow table) + - it is set to 0 to indicate that the overflow table is empty + +![decoder_syscall_operation](../../assets/design/decoder/decoder_syscall_operation.png) + +In the above diagram, `blk` is the ID of the *syscall* block which is about to be executed. `blk` is also the address of the hasher row in the auxiliary hasher table. `prnt` is the ID of the block's parent. + +When the VM executes a `SYSCALL` operation, it does the following: + +1. Adds a tuple `(blk, prnt, 0, p_ctx, p_fmp, p_b0, p_b1, prnt_fn_hash[0..4])` to the block stack table. +2. Sends a request to the kernel ROM chiplet indicating that `hash of callee` is being accessed. + - this results in a fault if `hash of callee` does not correspond to the hash of a kernel procedure +3. Initiates a 2-to-1 hash computation in the hash chiplet (as described [here](#simple-2-to-1-hash)) using `blk` as row address in the auxiliary hashing table and $h_0, ..., h_3$ as input values. + ## Program decoding When decoding a program, we start at the root block of the program. We can compute the hash of the root block directly from hashes of its children. The prover provides hashes of the child blocks non-deterministically, and we use them to compute the program's hash (here we rely on the hash chiplet). We then verify the program hash via boundary constraints. Thus, if the prover provided valid hashes for the child blocks, we will get the expected program hash. @@ -511,7 +589,7 @@ In the above, the batch contains $3$ operation groups. To bring the count up to Operation batch flags (denoted as $c_0, c_1, c_2$), encode the number of groups and define how many groups are added to the op group table as follows: -* `(1, 0, 0)` - $8$ groups. Groups in $h_1, ... h_7$ are added to the op group table. +* `(1, -, -)` - $8$ groups. Groups in $h_1, ... h_7$ are added to the op group table. * `(0, 1, 0)` - $4$ groups. Groups in $h_1, ... h_3$ are added to the op group table * `(0, 0, 1)` - $2$ groups. Groups in $h_1$ is added to the op group table. * `(0, 1, 1)` - $1$ group. Nothing is added to the op group table diff --git a/docs/src/design/stack/crypto_ops.md b/docs/src/design/stack/crypto_ops.md index fffad56411..bfb3ca7fe3 100644 --- a/docs/src/design/stack/crypto_ops.md +++ b/docs/src/design/stack/crypto_ops.md @@ -107,7 +107,7 @@ $$ v_{outputnew} = \alpha_0 + \alpha_1 \cdot op_{rethash} + \alpha_2 \cdot (h_0 + 2 \cdot 8 \cdot s_4 - 1) + \sum_{j=0}^3\alpha_{j + 8} \cdot s_{3 - j}' $$ -In the above, the first two expressions correspond to inputs and outputs for verifying the Merkle path between the old node value and the old tree root, while the last two expressions correspond to inputs and outputs for verifying the Merkle path between the new node value and the new tree root. The hash chiplet ensures the same set of sibling nodes are uses in both of these computations. +In the above, the first two expressions correspond to inputs and outputs for verifying the Merkle path between the old node value and the old tree root, while the last two expressions correspond to inputs and outputs for verifying the Merkle path between the new node value and the new tree root. The hash chiplet ensures the same set of sibling nodes are used in both of these computations. The $op_{mruold}$, $op_{mrunew}$, and $op_{rethash}$ are the unique [operation labels](../chiplets/main.md#operation-labels) used by the above computations. @@ -127,7 +127,7 @@ The stack for the operation is expected to be arranged as follows: - The first $8$ stack elements contain $4$ query points to be folded. Each point is represented by two field elements because points to be folded are in the extension field. We denote these points as $q_0 = (v_0, v_1)$, $q_1 = (v_2, v_3)$, $q_2 = (v_4, v_5)$, $q_3 = (v_6, v_7)$. - The next element $f\_pos$ is the query position in the folded domain. It can be computed as $pos \mod n$, where $pos$ is the position in the source domain, and $n$ is size of the folded domain. - The next element $d\_seg$ is a value indicating domain segment from which the position in the original domain was folded. It can be computed as $\lfloor \frac{pos}{n} \rfloor$. Since the size of the source domain is always $4$ times bigger than the size of the folded domain, possible domain segment values can be $0$, $1$, $2$, or $3$. -- The next element $poe$ is a power of initial domain generator which aid in a computation of the domain point $x$. +- The next element $poe$ is a power of initial domain generator which aids in a computation of the domain point $x$. - The next two elements contain the result of the previous layer folding - a single element in the extension field denoted as $pe = (pe_0, pe_1)$. - The next two elements specify a random verifier challenge $\alpha$ for the current layer defined as $\alpha = (a_0, a_1)$. - The last element on the top of the stack ($cptr$) is expected to be a memory address of the layer currently being folded. @@ -202,4 +202,4 @@ $$ $$ u_{mem} = u_{mem, 1} \cdot u_{mem, 2} -$$ \ No newline at end of file +$$ diff --git a/docs/src/design/stack/io_ops.md b/docs/src/design/stack/io_ops.md index 822d0fa831..08df6dfd93 100644 --- a/docs/src/design/stack/io_ops.md +++ b/docs/src/design/stack/io_ops.md @@ -2,7 +2,7 @@ In this section we describe the AIR constraints for Miden VM input / output operations. These operations move values between the stack and other components of the VM such as program code (i.e., decoder), memory, and advice provider. ### PUSH -The `PUSH` operation pushes the provided immediate value onto the stack (i.e., sets the value of $s_0$ register). Currently, it is the only operation in Miden VM which carries an immediate value. The semantics of this operation are explained in the [decoder section](../decoder/main.html#handling-immediate-values). +The `PUSH` operation pushes the provided immediate value onto the stack non-deterministically (i.e., sets the value of $s_0$ register); it is the responsibility of the [Op Group Table](../decoder/main.md#op-group-table) to ensure that the correct value was pushed on the stack. The semantics of this operation are explained in the [decoder section](../decoder/main.html#handling-immediate-values). The effect of this operation on the rest of the stack is: * **Right shift** starting from position $0$. diff --git a/docs/src/design/stack/main.md b/docs/src/design/stack/main.md index 2b71ce38b7..db428f7970 100644 --- a/docs/src/design/stack/main.md +++ b/docs/src/design/stack/main.md @@ -3,6 +3,7 @@ Miden VM is a stack machine. The stack is a push-down stack of practically unlimited depth (in practical terms, the depth will never exceed $2^{32}$), but only the top $16$ items are directly accessible to the VM. Items on the stack are elements in a prime field with modulus $2^{64} - 2^{32} + 1$. To keep the constraint system for the stack manageable, we impose the following rules: + 1. All operations executed on the VM can shift the stack by at most one item. That is, the end result of an operation must be that the stack shrinks by one item, grows by one item, or the number of items on the stack stays the same. 2. Stack depth must always be greater than or equal to $16$. At the start of program execution, the stack is initialized with exactly $16$ input values, all of which could be $0$'s. 3. By the end of program execution, exactly $16$ items must remain on the stack (again, all of them could be $0$'s). These items comprise the output of the program. @@ -12,24 +13,28 @@ To ensure that managing stack depth does not impose significant burden, we adopt * When the stack depth is $16$, removing additional items from the stack does not change its depth. To keep the depth at $16$, $0$'s are inserted into the deep end of the stack for each removed item. ## Stack representation + The VM allocates $19$ trace columns for the stack. The layout of the columns is illustrated below. -![](../../assets/design/stack/trace_layout.png) +![trace_layout](../../assets/design/stack/trace_layout.png) The meaning of the above columns is as follows: + * $s_0 ... s_{15}$ are the columns representing the top $16$ slots of the stack. * Column $b_0$ contains the number of items on the stack (i.e., the stack depth). In the above picture, there are 16 items on the stacks, so $b_0 = 16$. * Column $b_1$ contains an address of a row in the "overflow table" in which we'll store the data that doesn't fit into the top $16$ slots. When $b_1 = 0$, it means that all stack data fits into the top $16$ slots of the stack. * Helper column $h_0$ is used to ensure that stack depth does not drop below $16$. Values in this column are set by the prover non-deterministically to $\frac{1}{b_0 - 16}$ when $b_0 \neq 16$, and to any other value otherwise. ### Overflow table + To keep track of the data which doesn't fit into the top $16$ stack slots, we'll use an overflow table. This will be a [virtual table](../lookups/multiset.md#virtual-tables). To represent this table, we'll use a single auxiliary column $p_1$. The table itself can be thought of as having 3 columns as illustrated below. -![](../../assets/design/stack/overflow_table_layout.png) +![overflow_table_layout](../../assets/design/stack/overflow_table_layout.png) The meaning of the columns is as follows: + * Column $t_0$ contains row address. Every address in the table must be unique. * Column $t_1$ contains the value that overflowed the stack. * Column $t_2$ contains the address of the row containing the value that overflowed the stack right before the value in the current row. For example, in the picture above, first value $a$ overflowed the stack, then $b$ overflowed the stack, and then value $c$ overflowed the stack. Thus, row with value $b$ points back to the row with value $a$, and row with value $c$ points back to the row with value $b$. @@ -55,15 +60,17 @@ $$ The initial value of $p_1$ is set to $1$. Thus, if by the time Miden VM finishes executing a program the table is empty (we added and then removed exactly the same set of rows), $p_1$ will also be equal to $1$. There are a couple of other rules we'll need to enforce: + * We can delete a row only after the row has been inserted into the table. * We can't insert a row with the same address twice into the table (even if the row was inserted and then deleted). How these are enforced will be described a bit later. ## Right shift + If an operation adds data to the stack, we say that the operation caused a right shift. For example, `PUSH` and `DUP` operations cause a right shift. Graphically, this looks like so: -![](../../assets/design/stack/stack_right_shift.png) +![stack_right_shift](../../assets/design/stack/stack_right_shift.png) Here, we pushed value $v_{17}$ onto the stack. All other values on the stack are shifted by one slot to the right and the stack depth increases by $1$. There is not enough space at the top of the stack for all $17$ values, thus, $v_1$ needs to be moved to the overflow table. @@ -71,21 +78,22 @@ To do this, we need to rely on another column: $k_0$. This is a system column wh The row we want to add to the overflow table is defined by tuple $(clk, v1, 0)$, and after it is added, the table would look like so: -![](../../assets/design/stack/stack_overflow_table_post_1_right_shift.png) +![stack_overflow_table_post_1_right_shift](../../assets/design/stack/stack_overflow_table_post_1_right_shift.png) The reason we use VM clock cycle as row address is that the clock cycle is guaranteed to be unique, and thus, the same row can not be added to the table twice. Let's push another item onto the stack: -![](../../assets/design/stack/stack_overflow_push_2nd_item.png) +![stack_overflow_push_2nd_item](../../assets/design/stack/stack_overflow_push_2nd_item.png) Again, as we push $v_{18}$ onto the stack, all items on the stack are shifted to the right, and now $v_2$ needs to be moved to the overflow table. The tuple we want to insert into the table now is $(clk+1, v2, clk)$. After the operation, the overflow table will look like so: -![](../../assets/design/stack/stack_overflow_table_post_2_right_shift.png) +![stack_overflow_table_post_2_right_shift](../../assets/design/stack/stack_overflow_table_post_2_right_shift.png) Notice that $t_2$ for row which contains value $v_2$ points back to the row with address $clk$. Overall, during a right shift we do the following: + * Increment stack depth by $1$. * Shift stack columns $s_0, ..., s_{14}$ right by $1$ slot. * Add a row to the overflow table described by tuple $(k_0, s_{15}, b_0)$. @@ -94,9 +102,10 @@ Overall, during a right shift we do the following: Also, as mentioned previously, the prover sets values in $h_0$ non-deterministically to $\frac{1}{b_0 - 16}$. ## Left shift + If an operation removes an item from the stack, we say that the operation caused a left shift. For example, a `DROP` operation causes a left shift. Assuming the stack is in the state we left it at the end of the previous section, graphically, this looks like so: -![](../../assets/design/stack/stack_1st_left_shift.png) +![stack_1st_left_shift](../../assets/design/stack/stack_1st_left_shift.png) Overall, during the left shift we do the following: @@ -115,6 +124,7 @@ Overall, during the left shift we do the following: If the stack depth becomes (or remains) $16$, the prover can set $h_0$ to any value (e.g., $0$). But if the depth is greater than $16$ the prover sets $h_0$ to $\frac{1}{b_0 - 16}$. ## AIR Constraints + To simplify constraint descriptions, we'll assume that the VM exposes two binary flag values described below. | Flag | Degree | Description | @@ -125,6 +135,7 @@ To simplify constraint descriptions, we'll assume that the VM exposes two binary These flags are mutually exclusive. That is, if $f_{shl}=1$, then $f_{shr}=0$ and vice versa. However, both flags can be set to $0$ simultaneously. This happens when the executed instruction does not shift the stack. How these flags are computed is described [here](./op_constraints.md). ### Stack overflow flag + Additionally, we'll define a flag to indicate whether the overflow table contains values. This flag will be set to $0$ when the overflow table is empty, and to $1$ otherwise (i.e., when stack depth $>16$). This flag can be computed as follows: $$ diff --git a/docs/src/design/stack/op_constraints.md b/docs/src/design/stack/op_constraints.md index f4502b2daa..c2e807effd 100644 --- a/docs/src/design/stack/op_constraints.md +++ b/docs/src/design/stack/op_constraints.md @@ -189,9 +189,9 @@ This group contains operations which require constraints with degree up to $3$. | `JOIN` | $87$ | `101_0111` | [Flow control ops](../decoder/main.md) | $5$ | | `DYN` | $88$ | `101_1000` | [Flow control ops](../decoder/main.md) | $5$ | | `RCOMBBASE` | $89$ | `101_1001` | [Crypto ops](./crypto_ops.md) | $5$ | -| `` | $90$ | `101_1010` | | $5$ | -| `` | $91$ | `101_1011` | | $5$ | -| `` | $92$ | `101_1100` | | $5$ | +| `EMIT` | $90$ | `101_1010` | [System ops](./system_ops.md) | $5$ | +| `PUSH` | $91$ | `101_1011` | [I/O ops](./io_ops.md) | $5$ | +| `DYNCALL` | $92$ | `101_1100` | [Flow control ops](../decoder/main.md) | $5$ | | `` | $93$ | `101_1101` | | $5$ | | `` | $94$ | `101_1110` | | $5$ | | `` | $95$ | `101_1111` | | $5$ | @@ -211,7 +211,7 @@ This group contains operations which require constraints with degree up to $5$. | Operation | Opcode value | Binary encoding | Operation group | Flag degree | | ------------ | :----------: | :-------------: | :-------------------------------------:| :---------: | | `MRUPDATE` | $96$ | `110_0000` | [Crypto ops](./crypto_ops.md) | $4$ | -| `PUSH` | $100$ | `110_0100` | [I/O ops](./io_ops.md) | $4$ | +| `` | $100$ | `110_0100` | | $4$ | | `SYSCALL` | $104$ | `110_1000` | [Flow control ops](../decoder/main.md) | $4$ | | `CALL` | $108$ | `110_1100` | [Flow control ops](../decoder/main.md) | $4$ | | `END` | $112$ | `111_0000` | [Flow control ops](../decoder/main.md) | $4$ | @@ -294,3 +294,13 @@ $$ $$ f_{ctrl} = f_{span,join,split,loop} + f_{end,repeat,respan,halt} + f_{dyn} + f_{call} + f_{syscall} \text{ | degree} = 5 $$ + +### Immediate value flag + +The immediate value flag $f_{imm}$ is set to 1 when an operation has an immediate value, and 0 otherwise: + +$$ +f_{imm} = f_{push} + f_{emit} \text{ | degree} = 4 +$$ + +Note that the `ASSERT`, `MPVERIFY` and other operations have immediate values too. However, these immediate values are not included in the MAST digest, and hence are not considered for the $f_{imm}$ flag. diff --git a/docs/src/design/stack/system_ops.md b/docs/src/design/stack/system_ops.md index bebd639f23..becbbadae3 100644 --- a/docs/src/design/stack/system_ops.md +++ b/docs/src/design/stack/system_ops.md @@ -10,6 +10,17 @@ The `NOOP` operation does not impose any constraints besides the ones needed to s'_i - s_i = 0 \ \text{ for } i \in [0, 16) \text { | degree} = 1 $$ +## EMIT +Similarly to `NOOP`, the `EMIT` operation advances the cycle counter but does not change the state of the operand stack (i.e., the depth of the stack and the values on the stack remain the same). + +The `EMIT` operation does not impose any constraints besides the ones needed to ensure that the entire state of the stack is copied over. This constraint looks like so: + +>$$ +s'_i - s_i = 0 \ \text{ for } i \in [0, 16) \text { | degree} = 1 +$$ + +Additionally, the prover puts `EMIT`'s immediate value in the first user op helper register non-deterministically. The [Op Group Table](../decoder/main.md#op-group-table) is responsible for ensuring that the prover sets the appropriate value. + ## ASSERT The `ASSERT` operation pops an element off the stack and checks if the popped element is equal to $1$. If the element is not equal to $1$, program execution fails. diff --git a/docs/src/intro/main.md b/docs/src/intro/main.md index 4d39927d79..9937727635 100644 --- a/docs/src/intro/main.md +++ b/docs/src/intro/main.md @@ -2,7 +2,7 @@ Miden VM is a zero-knowledge virtual machine written in Rust. For any program executed on Miden VM, a STARK-based proof of execution is automatically generated. This proof can then be used by anyone to verify that the program was executed correctly without the need for re-executing the program or even knowing the contents of the program. ## Status and features -Miden VM is currently on release v0.10. In this release, most of the core features of the VM have been stabilized, and most of the STARK proof generation has been implemented. While we expect to keep making changes to the VM internals, the external interfaces should remain relatively stable, and we will do our best to minimize the amount of breaking changes going forward. +Miden VM is currently on release v0.11. In this release, most of the core features of the VM have been stabilized, and most of the STARK proof generation has been implemented. While we expect to keep making changes to the VM internals, the external interfaces should remain relatively stable, and we will do our best to minimize the amount of breaking changes going forward. At this point, Miden VM is good enough for experimentation, and even for real-world applications, but it is not yet ready for production use. The codebase has not been audited and contains known and unknown bugs and security flaws. diff --git a/docs/src/intro/overview.md b/docs/src/intro/overview.md index 9059cd717e..7e72f54fb0 100644 --- a/docs/src/intro/overview.md +++ b/docs/src/intro/overview.md @@ -1,11 +1,13 @@ -## Miden VM overview +# Miden VM overview + Miden VM is a stack machine. The base data type of the MV is a field element in a 64-bit [prime field](https://en.wikipedia.org/wiki/Finite_field) defined by modulus $p = 2^{64} - 2^{32} + 1$. This means that all values that the VM operates with are field elements in this field (i.e., values between $0$ and $2^{64} - 2^{32}$, both inclusive). Miden VM consists of four high-level components as illustrated below. -![](../assets/intro/vm_components.png) +![vm_components](../assets/intro/vm_components.png) These components are: + * **Stack** which is a push-down stack where each item is a field element. Most assembly instructions operate with values located on the stack. The stack can grow up to $2^{32}$ items deep, however, only the top 16 items are directly accessible. * **Memory** which is a linear random-access read-write memory. The memory is word-addressable, meaning, four elements are located at each address, and we can read and write elements to/from memory in batches of four. Memory addresses can be in the range $[0, 2^{32})$. * **Chiplets** which are specialized circuits for accelerating certain types of computations. These include Rescue Prime Optimized (RPO) hash function, 32-bit binary operations, and 16-bit range checks. @@ -14,32 +16,37 @@ These components are: Miden VM comes with a default implementation of the host interface (with an in-memory advice provider). However, the users are able to provide their own implementations which can connect the VM to arbitrary data sources (e.g., a database or RPC calls) and define custom logic for handling events emitted by the VM. ## Writing programs + Our goal is to make Miden VM an easy compilation target for high-level languages such as Rust, Move, Sway, and others. We believe it is important to let people write programs in the languages of their choice. However, compilers to help with this have not been developed yet. Thus, for now, the primary way to write programs for Miden VM is to use [Miden assembly](../user_docs/assembly/main.md). While writing programs in assembly is far from ideal, Miden assembly does make this task a little bit easier by supporting high-level flow control structures and named procedures. ## Inputs and outputs + External inputs can be provided to Miden VM in two ways: -1. Public inputs can be supplied to the VM by initializing the stack with desired values before a program starts executing. Any number of stack items can be initialized in this way, but providing a large number of public inputs will increase the cost for the verifier. +1. Public inputs can be supplied to the VM by initializing the stack with desired values before a program starts executing. At most 16 values can be initialized in this way, so providing more than 16 values will cause an error. 2. Secret (or nondeterministic) inputs can be supplied to the VM via the [*advice provider*](#nondeterministic-inputs). There is no limit on how much data the advice provider can hold. -After a program finishes executing, the elements remaining on the stack become the outputs of the program. Since these outputs will be public inputs for the verifier, having a large stack at the end of execution will increase cost to the verifier. Therefore, it's best to drop unneeded output values. We've provided the [`truncate_stack`](../user_docs/stdlib/sys.md) utility function in the standard library for this purpose. +After a program finishes executing, the elements remaining on the stack become the outputs of the program. Notice that having more than 16 values on the stack at the end of execution will cause an error, so the values beyond the top 16 elements of the stack should be dropped. We've provided the [`truncate_stack`](../user_docs/stdlib/sys.md) utility procedure in the standard library for this purpose. The number of public inputs and outputs of a program can be reduced by making use of the advice stack and Merkle trees. Just 4 elements are sufficient to represent a root of a Merkle tree, which can be expanded into an arbitrary number of values. For example, if we wanted to provide a thousand public input values to the VM, we could put these values into a Merkle tree, initialize the stack with the root of this tree, initialize the advice provider with the tree itself, and then retrieve values from the tree during program execution using `mtree_get` instruction (described [here](../user_docs/assembly/cryptographic_operations.md#hashing-and-merkle-trees)). ### Stack depth restrictions + For reasons explained [here](../design/stack/main.md), the VM imposes the restriction that the stack depth cannot be smaller than $16$. This has the following effects: -- When initializing a program with fewer than $16$ inputs, the VM will pad the stack with zeros to ensure the depth is $16$ at the beginning of execution. -- If an operation would result in the stack depth dropping below $16$, the VM will insert a zero at the deep end of the stack to make sure the depth stays at $16$. +* When initializing a program with fewer than $16$ inputs, the VM will pad the stack with zeros to ensure the depth is $16$ at the beginning of execution. +* If an operation would result in the stack depth dropping below $16$, the VM will insert a zero at the deep end of the stack to make sure the depth stays at $16$. ### Nondeterministic inputs + The *advice provider* component is responsible for supplying nondeterministic inputs to the VM. These inputs only need to be known to the prover (i.e., they do not need to be shared with the verifier). The advice provider consists of three components: + * **Advice stack** which is a one-dimensional array of field elements. Being a stack, the VM can either push new elements onto the advice stack, or pop the elements from its top. * **Advice map** which is a key-value map where keys are words and values are vectors of field elements. The VM can copy values from the advice map onto the advice stack as well as insert new values into the advice map (e.g., from a region of memory). * **Merkle store** which contain structured data reducible to Merkle paths. Some examples of such structures are: Merkle tree, Sparse Merkle Tree, and a collection of Merkle paths. The VM can request Merkle paths from the Merkle store, as well as mutate it by updating or merging nodes contained in the store. diff --git a/docs/src/intro/usage.md b/docs/src/intro/usage.md index 3df2dad659..67137c8593 100644 --- a/docs/src/intro/usage.md +++ b/docs/src/intro/usage.md @@ -1,6 +1,6 @@ # Usage -Before you can use Miden VM, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). Miden VM v0.10 requires Rust version **1.80** or later. +Before you can use Miden VM, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). Miden VM v0.11 requires Rust version **1.82** or later. Miden VM consists of several crates, each of which exposes a small set of functionality. The most notable of these crates are: @@ -16,19 +16,19 @@ The above functionality is also exposed via the single [miden-vm](https://crates To compile Miden VM into a binary, we have a [Makefile](https://www.gnu.org/software/make/manual/make.html) with the following tasks: -``` +```shell make exec ``` This will place an optimized, multi-threaded `miden` executable into the `./target/optimized` directory. It is equivalent to executing: -``` +```shell cargo build --profile optimized --features concurrent,executable ``` If you would like to enable single-threaded mode, you can compile Miden VM using the following command: -``` +```shell make exec-single ``` @@ -38,9 +38,9 @@ Internally, Miden VM uses [rayon](https://github.com/rayon-rs/rayon) for paralle ### GPU acceleration -Miden VM proof generation can be accelerated via GPUs. Currently, GPU acceleration is enabled only on Apple silicon hardware (via [Metal]()). To compile Miden VM with Metal acceleration enabled, you can run the following command: +Miden VM proof generation can be accelerated via GPUs. Currently, GPU acceleration is enabled only on Apple Silicon hardware (via [Metal]()). To compile Miden VM with Metal acceleration enabled, you can run the following command: -``` +```shell make exec-metal ``` @@ -54,13 +54,13 @@ Miden VM execution and proof generation can be accelerated via vectorized instru To compile Miden VM with AVX2 acceleration enabled, you can run the following command: -``` +```shell make exec-avx2 ``` To compile Miden VM with SVE acceleration enabled, you can run the following command: -``` +```shell make exec-sve ``` @@ -72,7 +72,7 @@ Similar to Metal acceleration, SVE/AVX2 acceleration is currently applicable onl Once the executable has been compiled, you can run Miden VM like so: -``` +```shell ./target/optimized/miden [subcommand] [parameters] ``` @@ -85,17 +85,17 @@ Currently, Miden VM can be executed with the following subcommands: - `debug` - this will instantiate a [Miden debugger](../tools/debugger.md) against the specified Miden assembly program and inputs. - `analyze` - this will run a Miden assembly program against specific inputs and will output stats about its execution. - `repl` - this will initiate the [Miden REPL](../tools/repl.md) tool. -- `example` - this will execute a Miden assembly example program, generate a STARK proof of execution and verify it. Currently it is possible to run `blake3` and `fibonacci` examples. +- `example` - this will execute a Miden assembly example program, generate a STARK proof of execution and verify it. Currently, it is possible to run `blake3` and `fibonacci` examples. All of the above subcommands require various parameters to be provided. To get more detailed help on what is needed for a given subcommand, you can run the following: -``` +```shell ./target/optimized/miden [subcommand] --help ``` For example: -``` +```shell ./target/optimized/miden prove --help ``` @@ -105,18 +105,26 @@ To execute a program using the Miden VM there needs to be a `.masm` file contain You can use `MIDEN_LOG` environment variable to control how much logging output the VM produces. For example: -``` +```shell MIDEN_LOG=trace ./target/optimized/miden [subcommand] [parameters] ``` If the level is not specified, `warn` level is set as default. +#### Enable Debugging features + +You can use the run command with `--debug` parameter to enable debugging with the [debug instruction](../user_docs/assembly/debugging.md) such as `debug.stack`: + +```shell +./target/optimized/miden run -a [path_to.masm] --debug +``` + ### Inputs As described [here](https://0xpolygonmiden.github.io/miden-vm/intro/overview.html#inputs-and-outputs) the Miden VM can consume public and secret inputs. - Public inputs: - - `operand_stack` - can be supplied to the VM to initialize the stack with the desired values before a program starts executing. There is no limit on the number of stack inputs that can be initialized in this way, although increasing the number of public inputs increases the cost to the verifier. + - `operand_stack` - can be supplied to the VM to initialize the stack with the desired values before a program starts executing. If the number of provided input values is less than 16, the input stack will be padded with zeros to the length of 16. The maximum number of the stack inputs is limited by 16 values, providing more than 16 values will cause an error. - Secret (or nondeterministic) inputs: - `advice_stack` - can be supplied to the VM. There is no limit on how much data the advice provider can hold. This is provided as a string array where each string entry represents a field element. - `advice_map` - is supplied as a map of 64-character hex keys, each mapped to an array of numbers. The hex keys are interpreted as 4 field elements and the arrays of numbers are interpreted as arrays of field elements. @@ -127,22 +135,34 @@ As described [here](https://0xpolygonmiden.github.io/miden-vm/intro/overview.htm _Check out the [comparison example](https://github.com/0xPolygonMiden/examples/blob/main/examples/comparison.masm) to see how secret inputs work._ -After a program finishes executing, the elements that remain on the stack become the outputs of the program, along with the overflow addresses (`overflow_addrs`) that are required to reconstruct the [stack overflow table](../design/stack/main.md#overflow-table). +After a program finishes executing, the elements that remain on the stack become the outputs of the program. Notice that the number of values on the operand stack at the end of the program execution can not be greater than 16, otherwise the program will return an error. The [`truncate_stack`](../user_docs/stdlib/sys.md) utility procedure from the standard library could be used to conveniently truncate the stack at the end of the program. ## Fibonacci example In the `miden/examples/fib` directory, we provide a very simple Fibonacci calculator example. This example computes the 1001st term of the Fibonacci sequence. You can execute this example on Miden VM like so: -``` +```shell ./target/optimized/miden run -a miden/examples/fib/fib.masm -n 1 ``` +### Capturing Output + This will run the example code to completion and will output the top element remaining on the stack. If you want the output of the program in a file, you can use the `--output` or `-o` flag and specify the path to the output file. For example: -``` +```shell ./target/optimized/miden run -a miden/examples/fib/fib.masm -o fib.out ``` This will dump the output of the program into the `fib.out` file. The output file will contain the state of the stack at the end of the program execution. + +### Running with debug instruction enabled + +Inside `miden/examples/fib/fib.masm`, insert `debug.stack` instruction anywhere between `begin` and `end`. Then run: + +```shell +./target/optimized/miden run -a miden/examples/fib/fib.masm -n 1 --debug +``` + +You should see output similar to "Stack state before step ..." diff --git a/docs/src/user_docs/assembly/code_organization.md b/docs/src/user_docs/assembly/code_organization.md index 8ad7753df0..4fabdd2d71 100644 --- a/docs/src/user_docs/assembly/code_organization.md +++ b/docs/src/user_docs/assembly/code_organization.md @@ -43,13 +43,19 @@ begin end ``` +Finally, a procedure cannot contain *solely* any number of [advice injectors](./io_operations.md#nondeterministic-inputs), `emit`, `debug` and `trace` instructions. In other words, it must contain at least one instruction which is not in the aforementioned list. + #### Dynamic procedure invocation -It is also possible to invoke procedures dynamically - i.e., without specifying target procedure labels at compile time. Unlike static procedure invocation, recursion is technically possible using dynamic invocation, but dynamic invocation is more expensive, and has less available operand stack capacity for procedure arguments, as 4 elements are required for the MAST root of the callee. There are two instructions, `dynexec` and `dyncall`, which can be used to execute dynamically-specified code targets. Both instructions expect the [MAST root](../../design/programs.md) of the target to be provided via the stack. The difference between `dynexec` and `dyncall` corresponds to the difference between `exec` and `call`, see the documentation on [procedure invocation semantics](./execution_contexts.md#procedure-invocation-semantics) for more detail. +It is also possible to invoke procedures dynamically - i.e., without specifying target procedure labels at compile time. A procedure can only call itself using dynamic invocation. There are two instructions, `dynexec` and `dyncall`, which can be used to execute dynamically-specified code targets. Both instructions expect the [MAST root](../../design/programs.md) of the target to be stored in memory, and the memory address of the MAST root to be on the top of the stack. The difference between `dynexec` and `dyncall` corresponds to the difference between `exec` and `call`, see the documentation on [procedure invocation semantics](./execution_contexts.md#procedure-invocation-semantics) for more details. + -Dynamic code execution in the same context is achieved by setting the top $4$ elements of the stack to the hash of the dynamic code block and then executing the `dynexec` or `dyncall` instruction. You can obtain the hash of a procedure in the current program, by name, using the `procref` instruction. See the following example of pairing the two: +Dynamic code execution in the same context is achieved by setting the top element of the stack to the memory address where the hash of the dynamic code block is stored, and then executing the `dynexec` or `dyncall` instruction. You can obtain the hash of a procedure in the current program, by name, using the `procref` instruction. See the following example of pairing the two: ``` -procref.foo +# Retrieve the hash of `foo`, store it at `ADDR`, and push `ADDR` on top of the stack +procref.foo mem_storew.ADDR dropw push.ADDR + +# Execute `foo` dynamically dynexec ``` @@ -57,14 +63,13 @@ During assembly, the `procref.foo` instruction is compiled to a `push.HASH`, whe During execution of the `dynexec` instruction, the VM does the following: -1. Reads, but does not consume, the top 4 elements of the stack to get the hash of the dynamic target (i.e. the operand stack is left unchanged). -2. Load the code block referenced by the hash, or trap if no such MAST root is known. -3. Execute the loaded code block +1. Read the top stack element $s_0$, and read the memory word at address $s_0$ (the hash of the dynamic target), +2. Shift the stack left by one element, +3. Load the code block referenced by the hash, or trap if no such MAST root is known, +4. Execute the loaded code block. The `dyncall` instruction is used the same way, with the difference that it involves a context switch to a new context when executing the referenced block, and switching back to the calling context once execution of the callee completes. -> **Note**: In both cases, the stack is left unchanged. Therefore, if the dynamic code is intended to manipulate the stack, it should start by either dropping or moving the code block hash from the top of the stack. - ### Modules A *module* consists of one or more procedures. There are two types of modules: *library modules* and *executable modules* (also called *programs*). diff --git a/docs/src/user_docs/assembly/events.md b/docs/src/user_docs/assembly/events.md index f6951a2d93..3636733d45 100644 --- a/docs/src/user_docs/assembly/events.md +++ b/docs/src/user_docs/assembly/events.md @@ -20,4 +20,4 @@ trace.EVENT_ID_1 trace.2 ``` -To make use of the `trace` instruction, programs should be ran with tracing flag (`-t` or `--tracing`), otherwise these instructions will be ignored. +To make use of the `trace` instruction, programs should be ran with tracing flag (`-t` or `--trace`), otherwise these instructions will be ignored. diff --git a/docs/src/user_docs/assembly/execution_contexts.md b/docs/src/user_docs/assembly/execution_contexts.md index abbbbe7e63..106c152b79 100644 --- a/docs/src/user_docs/assembly/execution_contexts.md +++ b/docs/src/user_docs/assembly/execution_contexts.md @@ -18,6 +18,7 @@ When a procedure is invoked via a `call`, `dyncall`, or a `syscall` instruction, - Execution moves into a different context. In case of the `call` and `dyncall` instructions, a new user context is created. In case of a `syscall` instruction, the execution moves back into the root context. - All stack items beyond the 16th item get "hidden" from the invoked procedure. That is, from the standpoint of the invoked procedure, the initial stack depth is set to 16. + - Note that for `dyncall`, the stack is shifted left by one element before being set to 16. When the callee returns, the following happens: @@ -26,8 +27,9 @@ When the callee returns, the following happens: The manipulations of the stack depth described above have the following implications: -- The top 16 elements of the stack can be used to pass parameters and return values between the caller and the callee. NOTE: Except for `dyncall`, as that instruction requires the first 4 elements to be the hash of the callee procedure, so only 12 elements are available in that case. +- The top 16 elements of the stack can be used to pass parameters and return values between the caller and the callee. - Caller's stack beyond the top 16 elements is inaccessible to the callee, and thus, is guaranteed not to change as the result of the call. + - As mentioned above, in the case of `dyncall`, the elements at indices 1 to 17 at the call site will be accessible to the callee (shifted to indices 0 to 16) - At the end of its execution, the callee must ensure that stack depth is exactly 16. If this is difficult to ensure manually, the [`truncate_stack`](../stdlib/sys.md) procedure can be used to drop all elements from the stack except for the top 16. #### Invoking via `exec` instruction @@ -42,7 +44,7 @@ A _kernel_ defines a set of procedures which can be invoked from user contexts t A kernel can be defined similarly to a regular [library module](./code_organization.md#library-modules) - i.e., it can have internal and exported procedures. However, there are some small differences between what procedures can do in a kernel module vs. what they can do in a regular library module. Specifically: -- Procedures in a kernel module cannot use `call` or `syscall` instructions. This means that creating a new context from within a `syscall` is not possible. +- Procedures in a kernel module cannot use `call`, `dyncall` or `syscall` instructions. This means that creating a new context from within a `syscall` is not possible. - Unlike procedures in regular library modules, procedures in a kernel module can use the `caller` instruction. This instruction puts the hash of the procedure which initiated the parent context onto the stack. ### Memory layout diff --git a/docs/src/user_docs/assembly/flow_control.md b/docs/src/user_docs/assembly/flow_control.md index 0b23a268f8..0075f1d334 100644 --- a/docs/src/user_docs/assembly/flow_control.md +++ b/docs/src/user_docs/assembly/flow_control.md @@ -82,9 +82,9 @@ where `instructions` can be a sequence of any instructions, including nested con 1. Pops the top item from the stack. 2. If the value of the item is $1$, `instructions` in the loop body are executed. - a. After the body is executed, the stack is popped again, and if the popped value is $1$, the body is executed again. - b. If the popped value is $0$, the loop is exited. - c. If the popped value is not binary, the execution fails. + 1. After the body is executed, the stack is popped again, and if the popped value is $1$, the body is executed again. + 2. If the popped value is $0$, the loop is exited. + 3. If the popped value is not binary, the execution fails. 3. If the value of the item is $0$, execution of loop body is skipped. 4. If the value is not binary, the execution fails. diff --git a/docs/src/user_docs/assembly/io_operations.md b/docs/src/user_docs/assembly/io_operations.md index 697d6ffc94..fc8c08608c 100644 --- a/docs/src/user_docs/assembly/io_operations.md +++ b/docs/src/user_docs/assembly/io_operations.md @@ -37,8 +37,8 @@ As mentioned above, nondeterministic inputs are provided to the VM via the advic | Instruction | Stack_input | Stack_output | Notes | | -------------------------------- | ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| adv_push.*n*
- *(n cycles)* | [ ... ] | [a, ... ] | $a \leftarrow stack.pop()$
Pops $n$ values from the advice stack and pushes them onto the operand stack. Valid for $n \in \{1, ..., 16\}$.
Fails if the advice stack has fewer than $n$ values. | -| adv_loadw
- *(1 cycle)* | [0, 0, 0, 0, ... ] | [A, ... ] | $A \leftarrow stack.pop(4)$
Pop the next word (4 elements) from the advice stack and overwrites the first word of the operand stack (4 elements) with them.
Fails if the advice stack has fewer than $4$ values. | +| adv_push.*n*
- *(n cycles)* | [ ... ] | [a, ... ] | $a \leftarrow advstack.pop()$
Pops $n$ values from the advice stack and pushes them onto the operand stack. Valid for $n \in \{1, ..., 16\}$.
Fails if the advice stack has fewer than $n$ values. | +| adv_loadw
- *(1 cycle)* | [0, 0, 0, 0, ... ] | [A, ... ] | $A \leftarrow advstack.pop(4)$
Pop the next word (4 elements) from the advice stack and overwrites the first word of the operand stack (4 elements) with them.
Fails if the advice stack has fewer than $4$ values. | | adv_pipe
- *(1 cycle)* | [C, B, A, a, ... ] | [E, D, A, a', ... ] | $[D, E] \leftarrow [adv\_stack.pop(4), adv\_stack.pop(4)]$
$a' \leftarrow a + 2$
Pops the next two words from the advice stack, overwrites the top of the operand stack with them and also writes these words into memory at address $a$ and $a + 1$.
Fails if the advice stack has fewer than $8$ values. | > **Note**: The opcodes above always push data onto the operand stack so that the first element is placed deepest in the stack. For example, if the data on the stack is `a,b,c,d` and you use the opcode `adv_push.4`, the data will be `d,c,b,a` on your stack. This is also the behavior of the other opcodes. @@ -74,7 +74,7 @@ Memory is guaranteed to be initialized to zeros. Thus, when reading from memory | mem_storew
- *(1 cycle)*
mem_storew.*a*
- *(2-3 cycles)* | [a, A, ... ] | [A, ... ] | $A \rightarrow mem[a]$
Stores the top four elements of the stack in memory at address $a$. If $a$ is provided via the stack, it is removed from the stack first.
Fails if $a \ge 2^{32}$ | | mem_stream
- *(1 cycle)* | [C, B, A, a, ... ] | [E, D, A, a', ... ] | $[E, D] \leftarrow [mem[a], mem[a+1]]$
$a' \leftarrow a + 2$
Read two sequential words from memory starting at address $a$ and overwrites the first two words in the operand stack. | -The second way to access memory is via procedure locals using the instructions listed below. These instructions are available only in procedure context. The number of locals available to a given procedure must be specified at [procedure declaration](./code_organization.md#procedures) time, and trying to access more locals than was declared will result in a compile-time error. The number of locals per procedure is not limited, but the total number of locals available to all procedures at runtime must be smaller than $2^{32}$. +The second way to access memory is via procedure locals using the instructions listed below. These instructions are available only in procedure context. The number of locals available to a given procedure must be specified at [procedure declaration](./code_organization.md#procedures) time, and trying to access more locals than was declared will result in a compile-time error. A procedure can have at most $2^{16}$ locals, and the total number of locals available to all procedures at runtime is limited to $2^{30}$. | Instruction | Stack_input | Stack_output | Notes | | ------------------------------------ | ------------------ | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -85,4 +85,4 @@ The second way to access memory is via procedure locals using the instructions l Unlike regular memory, procedure locals are not guaranteed to be initialized to zeros. Thus, when working with locals, one must assume that before a local memory address has been written to, it contains "garbage". -Internally in the VM, procedure locals are stored at memory offset stating at $2^{30}$. Thus, every procedure local has an absolute address in regular memory. The `locaddr.i` instruction is provided specifically to map an index of a procedure's local to an absolute address so that it can be passed to downstream procedures, when needed. +Internally in the VM, procedure locals are stored at memory offset starting at $2^{30}$. Thus, every procedure local has an absolute address in regular memory. The `locaddr.i` instruction is provided specifically to map an index of a procedure's local to an absolute address so that it can be passed to downstream procedures, when needed. diff --git a/docs/src/user_docs/assembly/u32_operations.md b/docs/src/user_docs/assembly/u32_operations.md index e11c3c1ae2..9b2e16d835 100644 --- a/docs/src/user_docs/assembly/u32_operations.md +++ b/docs/src/user_docs/assembly/u32_operations.md @@ -5,6 +5,56 @@ For instructions where one or more operands can be provided as immediate paramet In all the table below, the number of cycles it takes for the VM to execute each instruction is listed beneath the instruction. +### Notes on Undefined Behavior + +Most of the instructions documented below expect to receive operands whose values are valid `u32` +values, i.e. values in the range $0..=(2^{32} - 1)$. Currently, the semantics of the instructions +when given values outside of that range are undefined (as noted in the documented semantics for +each instruction). The rule with undefined behavior generally speaking is that you can make no +assumptions about what will happen if your program exhibits it. + +For purposes of describing the effects of undefined behavior below, we will refer to values which +are not valid for the input type of the affected operation, e.g. `u32`, as _poison_. Any use of a +poison value propagates the poison state. For example, performing `u32div` with a poison operand, +can be considered as producing a poison value as its result, for the purposes of discussing +undefined behavior semantics. + +With that in mind, there are two ways in which the effects of undefined behavior manifest: + +#### Executor Semantics + +From an executor perspective, currently, the semantics are completely undefined. An executor can +do everything from terminate the program, panic, always produce 42 as a result, produce a random +result, or something more principled. + +In practice, the Miden VM, when executing an operation, will almost always trap on _poison_ values. +This is not guaranteed, but is currently the case for most operations which have niches of undefined +behavior. To the extent that some other behavior may occur, it will generally be to truncate/wrap the +poison value, but this is subject to change at any time, and is undocumented. You should assume that +all operations will trap on poison. + +The reason the Miden VM makes the choice to trap on poison, is to ensure that undefined behavior is +caught close to the source, rather than propagated silently throughout the program. It also has the +effect of ensuring you do not execute a program with undefined behavior, and produce a proof that +is not actually valid, as we will describe in a moment. + +#### Verifier Semantics + +From the perspective of the verifier, the implementation details of the executor are completely +unknown. For example, the fact that the Miden VM traps on poison values is not actually verified +by constraints. An alternative executor implementation could choose _not_ to trap, and thus appear +to execute successfully. The resulting proof, however, as a result of the program exhibiting +undefined behavior, is not a valid proof. In effect the use of poison values "poisons" the proof +as well. + +As a result, a program that exhibits undefined behavior, and executes successfully, will produce +a proof that could pass verification, even though it should not. In other words, the proof does +not prove what it says it does. + +In the future, we may attempt to remove niches of undefined behavior in such a way that producing +such invalid proofs is not possible, but for the time being, you must ensure that your program does +not exhibit (or rely on) undefined behavior. + ### Conversions and tests | Instruction | Stack_input | Stack_output | Notes | @@ -53,11 +103,11 @@ If the error code is omitted, the default value of $0$ is assumed. | u32shl
- *(18 cycles)*
u32shl.*b*
- *(3 cycles)* | [b, a, ...] | [c, ...] | $c \leftarrow (a \cdot 2^b) \mod 2^{32}$
Undefined if $a \ge 2^{32}$ or $b > 31$ | | u32shr
- *(18 cycles)*
u32shr.*b*
- *(3 cycles)* | [b, a, ...] | [c, ...] | $c \leftarrow \lfloor a/2^b \rfloor$
Undefined if $a \ge 2^{32}$ or $b > 31$ | | u32rotl
- *(18 cycles)*
u32rotl.*b*
- *(3 cycles)* | [b, a, ...] | [c, ...] | Computes $c$ by rotating a 32-bit representation of $a$ to the left by $b$ bits.
Undefined if $a \ge 2^{32}$ or $b > 31$ | -| u32rotr
- *(22 cycles)*
u32rotr.*b*
- *(3 cycles)* | [b, a, ...] | [c, ...] | Computes $c$ by rotating a 32-bit representation of $a$ to the right by $b$ bits.
Undefined if $a \ge 2^{32}$ or $b > 31$ | +| u32rotr
- *(23 cycles)*
u32rotr.*b*
- *(3 cycles)* | [b, a, ...] | [c, ...] | Computes $c$ by rotating a 32-bit representation of $a$ to the right by $b$ bits.
Undefined if $a \ge 2^{32}$ or $b > 31$ | | u32popcnt
- *(33 cycles)* | [a, ...] | [b, ...] | Computes $b$ by counting the number of set bits in $a$ (hamming weight of $a$).
Undefined if $a \ge 2^{32}$ | -| u32clz
- *(37 cycles)* | [a, ...] | [b, ...] | Computes $b$ as a number of leading zeros of $a$.
Undefined if $a \ge 2^{32}$ | +| u32clz
- *(42 cycles)* | [a, ...] | [b, ...] | Computes $b$ as a number of leading zeros of $a$.
Undefined if $a \ge 2^{32}$ | | u32ctz
- *(34 cycles)* | [a, ...] | [b, ...] | Computes $b$ as a number of trailing zeros of $a$.
Undefined if $a \ge 2^{32}$ | -| u32clo
- *(36 cycles)* | [a, ...] | [b, ...] | Computes $b$ as a number of leading ones of $a$.
Undefined if $a \ge 2^{32}$ | +| u32clo
- *(41 cycles)* | [a, ...] | [b, ...] | Computes $b$ as a number of leading ones of $a$.
Undefined if $a \ge 2^{32}$ | | u32cto
- *(33 cycles)* | [a, ...] | [b, ...] | Computes $b$ as a number of trailing ones of $a$.
Undefined if $a \ge 2^{32}$ | diff --git a/miden/Cargo.toml b/miden/Cargo.toml index 5418fefa37..ce41d362bd 100644 --- a/miden/Cargo.toml +++ b/miden/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-vm" -version = "0.10.5" +version = "0.11.0" description = "Miden virtual machine" -documentation = "https://docs.rs/miden-vm/0.10.5" +documentation = "https://docs.rs/miden-vm/0.11.0" readme = "README.md" categories = ["cryptography", "emulators", "no-std"] keywords = ["miden", "stark", "virtual-machine", "zkp"] @@ -57,38 +57,30 @@ metal = ["prover/metal", "std"] std = ["assembly/std", "processor/std", "prover/std", "verifier/std"] [dependencies] -assembly = { package = "miden-assembly", path = "../assembly", version = "0.10", default-features = false } +assembly = { package = "miden-assembly", path = "../assembly", version = "0.11", default-features = false } blake3 = "1.5" clap = { version = "4.4", features = ["derive"], optional = true } hex = { version = "0.4", optional = true } -processor = { package = "miden-processor", path = "../processor", version = "0.10", default-features = false } -prover = { package = "miden-prover", path = "../prover", version = "0.10", default-features = false } +processor = { package = "miden-processor", path = "../processor", version = "0.11", default-features = false } +prover = { package = "miden-prover", path = "../prover", version = "0.11", default-features = false } rustyline = { version = "13.0", default-features = false, optional = true } serde = { version = "1.0", optional = true } serde_derive = { version = "1.0", optional = true } serde_json = { version = "1.0", optional = true } -stdlib = { package = "miden-stdlib", path = "../stdlib", version = "0.10", default-features = false } -tracing = { version = "0.1", default-features = false, features = [ - "attributes", -] } -tracing-subscriber = { version = "0.3", optional = true, features = [ - "std", - "env-filter", -] } -tracing-forest = { version = "0.1", optional = true, features = [ - "ansi", - "smallvec", -] } -verifier = { package = "miden-verifier", path = "../verifier", version = "0.10", default-features = false } -vm-core = { package = "miden-core", path = "../core", version = "0.10", default-features = false } +stdlib = { package = "miden-stdlib", path = "../stdlib", version = "0.11", default-features = false } +tracing = { version = "0.1", default-features = false, features = ["attributes"] } +tracing-subscriber = { version = "0.3", optional = true, features = ["std", "env-filter"] } +tracing-forest = { version = "0.1", optional = true, features = ["ansi", "smallvec"] } +verifier = { package = "miden-verifier", path = "../verifier", version = "0.11", default-features = false } +vm-core = { package = "miden-core", path = "../core", version = "0.11", default-features = false } [dev-dependencies] assert_cmd = "2.0" criterion = "0.5" escargot = "0.5" num-bigint = "0.4" -predicates = "3.0" +predicates = "3.1" test-utils = { package = "miden-test-utils", path = "../test-utils" } -vm-core = { package = "miden-core", path = "../core", version = "0.10" } -winter-fri = { package = "winter-fri", version = "0.9" } -rand_chacha = "0.3.1" +vm-core = { package = "miden-core", path = "../core", version = "0.11" } +winter-fri = { package = "winter-fri", version = "0.10" } +rand_chacha = "0.3" diff --git a/miden/README.md b/miden/README.md index 951991a067..e8709ed4ca 100644 --- a/miden/README.md +++ b/miden/README.md @@ -21,12 +21,12 @@ All Miden programs can be reduced to a single 32-byte value, called program hash Currently, there are 3 ways to get values onto the stack: 1. You can use `push` instruction to push values onto the stack. These values become a part of the program itself, and, therefore, cannot be changed between program executions. You can think of them as constants. -2. The stack can be initialized to some set of values at the beginning of the program. These inputs are public and must be shared with the verifier for them to verify a proof of the correct execution of a Miden program. While it is possible to initialize the stack with a large number of values, we recommend keeping the number of initial values at 16 or fewer as each initial value beyond 16 increases verifier complexity. +2. The stack can be initialized to some set of values at the beginning of the program. These inputs are public and must be shared with the verifier for them to verify a proof of the correct execution of a Miden program. At most 16 values could be provided for the stack initialization, attempts to provide more than 16 values will cause an error. 3. The program may request nondeterministic advice inputs from the prover. These inputs are secret inputs. This means that the prover does not need to share them with the verifier. There are three types of advice inputs: (1) a single advice stack which can contain any number of elements; (2) a key-mapped element lists which can be pushed onto the advice stack; (3) a Merkle store, which is used to provide nondeterministic inputs for instructions which work with Merkle trees. There are no restrictions on the number of advice inputs a program can request. The stack is provided to Miden VM via `StackInputs` struct. These are public inputs of the execution, and should also be provided to the verifier. The secret inputs for the program are provided via the `Host` interface. The default implementation of the host relies on in-memory advice provider (`MemAdviceProvider`) that can be commonly used for operations that won't require persistence. -Values remaining on the stack after a program is executed can be returned as stack outputs. You can specify exactly how many values (from the top of the stack) should be returned. Similar to stack inputs, a large number of values can be returned via the stack, however, we recommend keeping this number to under 16 not to overburden the verifier. +Values remaining on the stack after a program is executed can be returned as stack outputs. You can specify exactly how many values (from the top of the stack) should be returned. Notice, that, similar to stack inputs, at most 16 values can be returned via the stack. Attempts to return more than 16 values will cause an error. Having a small number elements to describe public inputs and outputs of a program may seem limiting, however, just 4 elements are sufficient to represent a root of a Merkle tree or a sequential hash of elements. Both of these can be expanded into an arbitrary number of values by supplying the actual values non-deterministically via the host interface. @@ -225,7 +225,7 @@ If you want to execute, prove, and verify programs on Miden VM, but don't want t ### Compiling Miden VM -First, make sure you have Rust [installed](https://www.rust-lang.org/tools/install). The current version of Miden VM requires Rust version **1.80** or later. +First, make sure you have Rust [installed](https://www.rust-lang.org/tools/install). The current version of Miden VM requires Rust version **1.82** or later. Then, to compile Miden VM into a binary, run the following `make` command: @@ -308,6 +308,7 @@ Miden VM can be compiled with the following features: - `executable` - required for building Miden VM binary as described above. Implies `std`. - `metal` - enables [Metal]()-based acceleration of proof generation (for recursive proofs) on supported platforms (e.g., Apple silicon). - `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly. + - Only the `wasm32-unknown-unknown` and `wasm32-wasip1` targets are officially supported. To compile with `no_std`, disable default features via `--no-default-features` flag. diff --git a/miden/src/cli/compile.rs b/miden/src/cli/compile.rs index 5e2b112b7a..8ce8d7bce3 100644 --- a/miden/src/cli/compile.rs +++ b/miden/src/cli/compile.rs @@ -32,7 +32,7 @@ impl CompileCmd { let libraries = Libraries::new(&self.library_paths)?; // compile the program - let compiled_program = program.compile(&Debug::Off, &libraries.libraries)?; + let compiled_program = program.compile(Debug::Off, &libraries.libraries)?; // report program hash to user let program_hash: [u8; 32] = compiled_program.hash().into(); diff --git a/miden/src/cli/data.rs b/miden/src/cli/data.rs index c137203024..f49adb1034 100644 --- a/miden/src/cli/data.rs +++ b/miden/src/cli/data.rs @@ -42,6 +42,15 @@ impl Debug { } } +impl From for Debug { + fn from(value: bool) -> Self { + match value { + true => Debug::On, + false => Debug::Off, + } + } +} + // MERKLE DATA // ================================================================================================ @@ -314,7 +323,6 @@ impl InputFile { #[derive(Deserialize, Serialize, Debug)] pub struct OutputFile { pub stack: Vec, - pub overflow_addrs: Vec, } /// Helper methods to interact with the output file @@ -322,12 +330,7 @@ impl OutputFile { /// Returns a new [OutputFile] from the specified outputs vectors pub fn new(stack_outputs: &StackOutputs) -> Self { Self { - stack: stack_outputs.stack().iter().map(|&v| v.to_string()).collect::>(), - overflow_addrs: stack_outputs - .overflow_addrs() - .iter() - .map(|&v| v.to_string()) - .collect::>(), + stack: stack_outputs.iter().map(|&v| v.to_string()).collect::>(), } } @@ -366,17 +369,11 @@ impl OutputFile { .map_err(|err| format!("Failed to write output data - {}", err)) } - /// Converts outputs vectors for stack and overflow addresses to [StackOutputs]. + /// Converts stack output vector to [StackOutputs]. pub fn stack_outputs(&self) -> Result { let stack = self.stack.iter().map(|v| v.parse::().unwrap()).collect::>(); - let overflow_addrs = self - .overflow_addrs - .iter() - .map(|v| v.parse::().unwrap()) - .collect::>(); - - StackOutputs::try_from_ints(stack, overflow_addrs) + StackOutputs::try_from_ints(stack) .map_err(|e| format!("Construct stack outputs failed {e}")) } } @@ -416,7 +413,7 @@ impl ProgramFile { /// Compiles this program file into a [Program]. #[instrument(name = "compile_program", skip_all)] - pub fn compile<'a, I>(&self, debug: &Debug, libraries: I) -> Result + pub fn compile<'a, I>(&self, debug: Debug, libraries: I) -> Result where I: IntoIterator, { @@ -552,7 +549,7 @@ impl Libraries { // ================================================================================================ #[cfg(test)] mod test { - use super::InputFile; + use super::{Debug, InputFile}; #[test] fn test_merkle_data_parsing() { @@ -626,4 +623,16 @@ mod test { let merkle_store = inputs.parse_merkle_store().unwrap(); assert!(merkle_store.is_some()); } + + #[test] + fn test_debug_from_true() { + let debug_mode: Debug = true.into(); // true.into() will also test Debug.from(true) + assert!(matches!(debug_mode, Debug::On)); + } + + #[test] + fn test_debug_from_false() { + let debug_mode: Debug = false.into(); // false.into() will also test Debug.from(false) + assert!(matches!(debug_mode, Debug::Off)); + } } diff --git a/miden/src/cli/debug/mod.rs b/miden/src/cli/debug/mod.rs index 98f268760f..3670692001 100644 --- a/miden/src/cli/debug/mod.rs +++ b/miden/src/cli/debug/mod.rs @@ -42,7 +42,7 @@ impl DebugCmd { // load program from file and compile let program = ProgramFile::read_with(self.assembly_file.clone(), source_manager.clone())? - .compile(&Debug::On, &libraries.libraries)?; + .compile(Debug::On, &libraries.libraries)?; let program_hash: [u8; 32] = program.hash().into(); println!("Debugging program with hash {}...", hex::encode(program_hash)); diff --git a/miden/src/cli/prove.rs b/miden/src/cli/prove.rs index f4f36da9a4..377214bb25 100644 --- a/miden/src/cli/prove.rs +++ b/miden/src/cli/prove.rs @@ -55,14 +55,14 @@ pub struct ProveCmd { security: String, /// Enable tracing to monitor execution of the VM - #[clap(short = 't', long = "tracing")] - tracing: bool, + #[clap(short = 't', long = "trace")] + trace: bool, } impl ProveCmd { pub fn get_proof_options(&self) -> Result { let exec_options = - ExecutionOptions::new(Some(self.max_cycles), self.expected_cycles, self.tracing)?; + ExecutionOptions::new(Some(self.max_cycles), self.expected_cycles, self.trace, false)?; Ok(match self.security.as_str() { "96bits" => { if self.rpx { @@ -145,7 +145,7 @@ fn load_data(params: &ProveCmd) -> Result<(Program, InputFile), Report> { // load program from file and compile let program = - ProgramFile::read(¶ms.assembly_file)?.compile(&Debug::Off, &libraries.libraries)?; + ProgramFile::read(¶ms.assembly_file)?.compile(Debug::Off, &libraries.libraries)?; // load input data from file let input_data = InputFile::read(¶ms.input_file, ¶ms.assembly_file)?; diff --git a/miden/src/cli/run.rs b/miden/src/cli/run.rs index 07ede629f7..afb6c34a6b 100644 --- a/miden/src/cli/run.rs +++ b/miden/src/cli/run.rs @@ -4,7 +4,7 @@ use assembly::diagnostics::{IntoDiagnostic, Report, WrapErr}; use clap::Parser; use processor::{DefaultHost, ExecutionOptions, ExecutionTrace}; -use super::data::{instrument, Debug, InputFile, Libraries, OutputFile, ProgramFile}; +use super::data::{instrument, InputFile, Libraries, OutputFile, ProgramFile}; #[derive(Debug, Clone, Parser)] #[clap(about = "Run a miden program")] @@ -38,8 +38,12 @@ pub struct RunCmd { output_file: Option, /// Enable tracing to monitor execution of the VM - #[clap(short = 't', long = "tracing")] - tracing: bool, + #[clap(short = 't', long = "trace")] + trace: bool, + + /// Enable debug instructions + #[clap(short = 'd', long = "debug")] + debug: bool, } impl RunCmd { @@ -106,16 +110,19 @@ fn run_program(params: &RunCmd) -> Result<(ExecutionTrace, [u8; 32]), Report> { let libraries = Libraries::new(¶ms.library_paths)?; // load program from file and compile - let program = - ProgramFile::read(¶ms.assembly_file)?.compile(&Debug::Off, &libraries.libraries)?; + let program = ProgramFile::read(¶ms.assembly_file)? + .compile(params.debug.into(), &libraries.libraries)?; // load input data from file let input_data = InputFile::read(¶ms.input_file, ¶ms.assembly_file)?; - // get execution options - let execution_options = - ExecutionOptions::new(Some(params.max_cycles), params.expected_cycles, params.tracing) - .into_diagnostic()?; + let execution_options = ExecutionOptions::new( + Some(params.max_cycles), + params.expected_cycles, + params.trace, + params.debug, + ) + .into_diagnostic()?; // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; diff --git a/miden/src/examples/blake3.rs b/miden/src/examples/blake3.rs index 6db77685d3..4d4133b59f 100644 --- a/miden/src/examples/blake3.rs +++ b/miden/src/examples/blake3.rs @@ -22,7 +22,7 @@ pub fn get_example(n: usize) -> Example> { ); let mut host = DefaultHost::default(); - host.load_mast_forest(StdLibrary::default().into()); + host.load_mast_forest(StdLibrary::default().mast_forest().clone()); let stack_inputs = StackInputs::try_from_ints(INITIAL_HASH_VALUE.iter().map(|&v| v as u64)).unwrap(); @@ -41,11 +41,13 @@ fn generate_blake3_program(n: usize) -> Program { let program = format!( " use.std::crypto::hashes::blake3 + use.std::sys begin repeat.{} exec.blake3::hash_1to1 end + exec.sys::truncate_stack end", n ); diff --git a/miden/src/examples/mod.rs b/miden/src/examples/mod.rs index 186983ac9b..c377c8ec7d 100644 --- a/miden/src/examples/mod.rs +++ b/miden/src/examples/mod.rs @@ -76,8 +76,12 @@ pub enum ExampleType { impl ExampleOptions { pub fn get_proof_options(&self) -> Result { - let exec_options = - ExecutionOptions::new(Some(self.max_cycles), self.expected_cycles, self.tracing)?; + let exec_options = ExecutionOptions::new( + Some(self.max_cycles), + self.expected_cycles, + self.tracing, + false, + )?; Ok(match self.security.as_str() { "96bits" => { if self.rpx { diff --git a/miden/src/repl/mod.rs b/miden/src/repl/mod.rs index b0c41f4e2c..692e29df52 100644 --- a/miden/src/repl/mod.rs +++ b/miden/src/repl/mod.rs @@ -6,141 +6,141 @@ use processor::ContextId; use rustyline::{error::ReadlineError, DefaultEditor}; use stdlib::StdLibrary; -/// This work is in continuation to the amazing work done by team `Scribe` -/// [here](https://github.com/ControlCplusControlV/Scribe/blob/main/transpiler/src/repl.rs#L8) -/// -/// The Miden Read–eval–print loop (REPL) is a Miden shell that allows for quick and easy debugging -/// of Miden assembly. To use the repl, simply type "miden repl" after building it with feature -/// "executable" (cargo build --release --feature executable) when in the miden home -/// crate and the repl will launch. After the REPL gets initialized, you can execute any Miden -/// instruction, undo executed instructions, check the state of the stack and memory at a given -/// point, and do many other useful things! When the REPL is exited, a `history.txt` file is saved. -/// One thing to note is that all the REPL native commands start with an `!` to differentiate them -/// from regular assembly instructions. -/// -/// Miden Instructions -/// All Miden instructions mentioned in the -/// [Miden Assembly section](https://0xpolygonmiden.github.io/miden-vm/user_docs/assembly/main.html) -/// are valid. -/// One can either input instructions one by one or multiple instructions in one input. -/// For example, the below two commands will result in the same output. -/// >> push.1 -/// >> push.2 -/// >> push.3 -/// -/// >> push.1 push.2 push.3 -/// -/// In order to execute a control flow operation, one needs to write the entire flow operation in -/// a single line with spaces between individual operations. -/// Ex. -/// ``` -/// repeat.20 -/// pow2 -/// end -/// ``` -/// should be written as -/// `repeat.20 pow2 end` -/// -/// To execute a control flow operation, one must write the entire statement in a single line with -/// spaces between individual operations. -/// ``` -/// >> repeat.20 -/// pow2 -/// end -/// ``` -/// -/// The above example should be written as follows in the REPL tool: -/// >> repeat.20 pow2 end -/// -/// `!stack` -/// The `!stack` command prints out the state of the stack at the last executed instruction. Since -/// the stack always contains at least 16 elements, 16 or more elements will be printed out (even -/// if all of them are zeros). -/// >> push.1 push.2 push.3 push.4 push.5 -/// >> exp -/// >> u32wrapping_mul -/// >> swap -/// >> eq.2 -/// >> assert -/// -/// The `!stack` command will print out the following state of the stack: -/// ``` -/// >> !stack -/// 3072 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -/// ``` -/// -/// `!undo` -/// The `!undo` command reverts to the previous state of the stack and memory by dropping off the -/// last executed assembly instruction from the program. One could use `!undo` as often as they want -/// to restore the state of a stack and memory $n$ instructions ago (provided there are $n$ -/// instructions in the program). The `!undo` command will result in an error if no remaining -/// instructions are left in the miden program. -/// ``` -/// >> push.1 push.2 push.3 -/// >> push.4 -/// >> !stack -/// 4 3 2 1 0 0 0 0 0 0 0 0 0 0 0 0 -/// >> push.5 -/// >> !stack -/// 5 4 3 2 1 0 0 0 0 0 0 0 0 0 0 0 -/// >> !undo -/// 4 3 2 1 0 0 0 0 0 0 0 0 0 0 0 0 -/// >> !undo -/// 3 2 1 0 0 0 0 0 0 0 0 0 0 0 0 0 -/// ``` -/// -///`!program` -/// The `!program` command prints out the entire miden program getting executed. E.g., in the below -/// ``` -/// scenario: >> push.1 -/// >> push.2 -/// >> push.3 -/// >> add -/// >> add -/// >> !program -/// begin -/// push.1 -/// push.2 -/// push.3 -/// add -/// add -/// end -/// ``` -/// -/// `!help` -/// The `!help` command prints out all the available commands in the REPL tool. -/// -/// `!mem` -/// The `!mem` command prints out the contents of all initialized memory locations. For each such -/// location, the address, along with its memory values, is printed. Recall that four elements are -/// stored at each memory address. -/// If the memory has at least one value that has been initialized: -/// ``` -/// >> !mem -/// 7: [1, 2, 0, 3] -/// 8: [5, 7, 3, 32] -/// 9: [9, 10, 2, 0] -/// ``` -/// -/// If the memory is not yet been initialized: -/// ``` -/// >> !mem -/// The memory has not been initialized yet -/// ``` -/// -/// `!mem[addr]` -/// The `!mem[addr]` command prints out memory contents at the address specified by `addr`. -/// If the `addr` has been initialized: -/// ``` -/// >> !mem[9] -/// 9: [9, 10, 2, 0] -/// ``` -/// -/// If the `addr` has not been initialized: -/// ``` -/// >> !mem[87] -/// Memory at address 87 is empty -/// ``` +// This work is in continuation to the amazing work done by team `Scribe` +// [here](https://github.com/ControlCplusControlV/Scribe/blob/main/transpiler/src/repl.rs#L8) +// +// The Miden Read–eval–print loop (REPL) is a Miden shell that allows for quick and easy debugging +// of Miden assembly. To use the repl, simply type "miden repl" after building it with feature +// "executable" (cargo build --release --feature executable) when in the miden home +// crate and the repl will launch. After the REPL gets initialized, you can execute any Miden +// instruction, undo executed instructions, check the state of the stack and memory at a given +// point, and do many other useful things! When the REPL is exited, a `history.txt` file is saved. +// One thing to note is that all the REPL native commands start with an `!` to differentiate them +// from regular assembly instructions. +// +// Miden Instructions +// All Miden instructions mentioned in the +// [Miden Assembly section](https://0xpolygonmiden.github.io/miden-vm/user_docs/assembly/main.html) +// are valid. +// One can either input instructions one by one or multiple instructions in one input. +// For example, the below two commands will result in the same output. +// >> push.1 +// >> push.2 +// >> push.3 +// +// >> push.1 push.2 push.3 +// +// In order to execute a control flow operation, one needs to write the entire flow operation in +// a single line with spaces between individual operations. +// Ex. +// ``` +// repeat.20 +// pow2 +// end +// ``` +// should be written as +// `repeat.20 pow2 end` +// +// To execute a control flow operation, one must write the entire statement in a single line with +// spaces between individual operations. +// ``` +// >> repeat.20 +// pow2 +// end +// ``` +// +// The above example should be written as follows in the REPL tool: +// >> repeat.20 pow2 end +// +// `!stack` +// The `!stack` command prints out the state of the stack at the last executed instruction. Since +// the stack always contains at least 16 elements, 16 or more elements will be printed out (even +// if all of them are zeros). +// >> push.1 push.2 push.3 push.4 push.5 +// >> exp +// >> u32wrapping_mul +// >> swap +// >> eq.2 +// >> assert +// +// The `!stack` command will print out the following state of the stack: +// ``` +// >> !stack +// 3072 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 +// ``` +// +// `!undo` +// The `!undo` command reverts to the previous state of the stack and memory by dropping off the +// last executed assembly instruction from the program. One could use `!undo` as often as they want +// to restore the state of a stack and memory $n$ instructions ago (provided there are $n$ +// instructions in the program). The `!undo` command will result in an error if no remaining +// instructions are left in the miden program. +// ``` +// >> push.1 push.2 push.3 +// >> push.4 +// >> !stack +// 4 3 2 1 0 0 0 0 0 0 0 0 0 0 0 0 +// >> push.5 +// >> !stack +// 5 4 3 2 1 0 0 0 0 0 0 0 0 0 0 0 +// >> !undo +// 4 3 2 1 0 0 0 0 0 0 0 0 0 0 0 0 +// >> !undo +// 3 2 1 0 0 0 0 0 0 0 0 0 0 0 0 0 +// ``` +// +//`!program` +// The `!program` command prints out the entire miden program getting executed. E.g., in the below +// ``` +// scenario: >> push.1 +// >> push.2 +// >> push.3 +// >> add +// >> add +// >> !program +// begin +// push.1 +// push.2 +// push.3 +// add +// add +// end +// ``` +// +// `!help` +// The `!help` command prints out all the available commands in the REPL tool. +// +// `!mem` +// The `!mem` command prints out the contents of all initialized memory locations. For each such +// location, the address, along with its memory values, is printed. Recall that four elements are +// stored at each memory address. +// If the memory has at least one value that has been initialized: +// ``` +// >> !mem +// 7: [1, 2, 0, 3] +// 8: [5, 7, 3, 32] +// 9: [9, 10, 2, 0] +// ``` +// +// If the memory is not yet been initialized: +// ``` +// >> !mem +// The memory has not been initialized yet +// ``` +// +// `!mem[addr]` +// The `!mem[addr]` command prints out memory contents at the address specified by `addr`. +// If the `addr` has been initialized: +// ``` +// >> !mem[9] +// 9: [9, 10, 2, 0] +// ``` +// +// If the `addr` has not been initialized: +// ``` +// >> !mem[87] +// Memory at address 87 is empty +// ``` /// Initiates the Miden Repl tool. pub fn start_repl(library_paths: &Vec, use_stdlib: bool) { @@ -295,8 +295,8 @@ pub fn start_repl(library_paths: &Vec, use_stdlib: bool) { .expect("Couldn't dump the program into the history file"); } -/// HELPER METHODS -/// -------------------------------------------------------------------------------------------- +// HELPER METHODS +// -------------------------------------------------------------------------------------------- /// Compiles and executes a compiled Miden program, returning the stack, memory and any Miden /// errors. The program is passed in as a String, passed to the Miden Assembler, and then passed diff --git a/miden/src/tools/mod.rs b/miden/src/tools/mod.rs index 70501fbfba..39d9d9ea54 100644 --- a/miden/src/tools/mod.rs +++ b/miden/src/tools/mod.rs @@ -38,7 +38,7 @@ impl Analyze { // fetch the stack and program inputs from the arguments let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?; let mut host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?); - host.load_mast_forest(StdLibrary::default().into()); + host.load_mast_forest(StdLibrary::default().mast_forest().clone()); let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host) .expect("Could not retrieve execution details"); diff --git a/miden/tests/integration/air/chiplets/memory.rs b/miden/tests/integration/air/chiplets/memory.rs index 7db597e909..b339720640 100644 --- a/miden/tests/integration/air/chiplets/memory.rs +++ b/miden/tests/integration/air/chiplets/memory.rs @@ -65,7 +65,7 @@ fn write_read() { #[test] fn helper_write_read() { // Sequence of operations: [Span, Pad, MStorew, Drop, Drop, Drop, Drop, Pad, MLoad, ... ] - let source = "begin mem_storew.0 dropw mem_load.0 swapw end"; + let source = "begin mem_storew.0 dropw mem_load.0 movup.4 drop end"; let pub_inputs = vec![4, 3, 2, 1]; let trace = build_test!(source, &pub_inputs).execute().unwrap(); @@ -78,7 +78,13 @@ fn helper_write_read() { #[test] fn update() { - let source = "begin push.0.0.0.0 mem_loadw.0 mem_storew.0 swapw end"; + let source = " + begin + push.0.0.0.0 + mem_loadw.0 + mem_storew.0 + swapw dropw + end"; let pub_inputs = vec![8, 7, 6, 5, 4, 3, 2, 1]; build_test!(source, &pub_inputs).prove_and_verify(pub_inputs, false); diff --git a/miden/tests/integration/air/chiplets/mod.rs b/miden/tests/integration/air/chiplets/mod.rs index ca7137832b..94ec2b64aa 100644 --- a/miden/tests/integration/air/chiplets/mod.rs +++ b/miden/tests/integration/air/chiplets/mod.rs @@ -7,10 +7,12 @@ mod memory; #[test] fn chiplets() { // Test a program that uses all of the chiplets. - let source = "begin + let source = " + begin hperm # hasher operation push.5 push.10 u32or # bitwise operation mem_load # memory operation + drop end"; let pub_inputs = rand_vector::(8); diff --git a/miden/tests/integration/air/stack/mod.rs b/miden/tests/integration/air/stack/mod.rs index 3ba80691be..e678e04fc0 100644 --- a/miden/tests/integration/air/stack/mod.rs +++ b/miden/tests/integration/air/stack/mod.rs @@ -12,15 +12,6 @@ fn empty_input() { build_op_test!(&asm_op, &pub_inputs).prove_and_verify(pub_inputs, false); } -/// Test an empty starting stack but enough outputs that the overflow table is non-empty at the end. -#[test] -fn empty_input_overflow_output() { - let asm_ops = "push.17 push.18"; - let pub_inputs = vec![]; - - build_op_test!(&asm_ops, &pub_inputs).prove_and_verify(pub_inputs, false); -} - /// Test starting stack with some inputs but not full with no overflow outputs. #[test] fn some_inputs() { @@ -38,33 +29,3 @@ fn full_inputs() { build_op_test!(&asm_op, &pub_inputs).prove_and_verify(pub_inputs, false); } - -/// Test a script that finishes with enough outputs that the overflow table is non-empty at the end. -#[test] -fn full_inputs_overflow_outputs() { - let asm_ops = "push.17 push.18"; - let pub_inputs = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - - build_op_test!(&asm_ops, &pub_inputs).prove_and_verify(pub_inputs, false); -} - -/// Test a script initialized with enough inputs that the overflow table is non-empty at the start -/// but there's no overflow output at the end. -#[test] -fn overflow_inputs() { - let asm_op = "push.19 drop"; - - let pub_inputs = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]; - - build_op_test!(&asm_op, &pub_inputs).prove_and_verify(pub_inputs, false); -} - -/// Test a script initialized with enough inputs that the overflow table is non-empty at the start -/// and at the end. -#[test] -fn overflow_inputs_overflow_outputs() { - let asm_op = "push.19 push.20"; - let pub_inputs = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]; - - build_op_test!(&asm_op, &pub_inputs).prove_and_verify(pub_inputs, false); -} diff --git a/miden/tests/integration/flow_control/mod.rs b/miden/tests/integration/flow_control/mod.rs index 0ae5de84e1..d5bb8b880b 100644 --- a/miden/tests/integration/flow_control/mod.rs +++ b/miden/tests/integration/flow_control/mod.rs @@ -5,7 +5,7 @@ use miden_vm::Module; use processor::ExecutionError; use prover::Digest; use stdlib::StdLibrary; -use test_utils::{build_test, expect_exec_error, StackInputs, Test}; +use test_utils::{build_test, expect_exec_error, push_inputs, StackInputs, Test}; // SIMPLE FLOW CONTROL TESTS // ================================================================================================ @@ -65,7 +65,7 @@ fn counter_controlled_loop() { repeat.10 dup.1 mul end - swap drop + movdn.2 drop drop end"; let test = build_test!(source); @@ -143,9 +143,12 @@ fn local_fn_call() { let build_test = build_test!(source, &[1, 2]); expect_exec_error!(build_test, ExecutionError::InvalidStackDepthOnReturn(17)); + let inputs = (1_u64..18).collect::>(); + // dropping values from the stack in the current execution context should not affect values // in the overflow table from the parent execution context - let source = " + let source = format!( + " proc.foo repeat.20 drop @@ -153,19 +156,22 @@ fn local_fn_call() { end begin + {inputs} push.18 call.foo repeat.16 drop end - end"; - let inputs = (1_u64..18).collect::>(); + swapw dropw + end", + inputs = push_inputs(&inputs) + ); - let test = build_test!(source, &inputs); + let test = build_test!(source, &[]); test.expect_stack(&[2, 1]); - test.prove_and_verify(inputs, false); + test.prove_and_verify(vec![], false); } #[test] @@ -182,6 +188,8 @@ fn local_fn_call_with_mem_access() { call.foo mem_load.0 eq.7 + + swap drop end"; let test = build_test!(source, &[3, 7]); @@ -215,6 +223,38 @@ fn simple_syscall() { test.prove_and_verify(vec![1, 2], false); } +#[test] +fn simple_syscall_2() { + let kernel_source = " + export.foo + add + end + export.bar + mul + end + "; + + // Note: we call each twice to ensure that the multiset check handles it correctly + let program_source = " + begin + syscall.foo + syscall.foo + syscall.bar + syscall.bar + end"; + + // TODO: update and use macro? + let mut test = Test::new(&format!("test{}", line!()), program_source, false); + test.stack_inputs = StackInputs::try_from_ints([2, 2, 3, 2, 1]).unwrap(); + test.kernel_source = Some( + test.source_manager + .load(&format!("kernel{}", line!()), kernel_source.to_string()), + ); + test.expect_stack(&[24]); + + test.prove_and_verify(vec![2, 2, 3, 2, 1], false); +} + // DYNAMIC CODE EXECUTION // ================================================================================================ @@ -222,62 +262,51 @@ fn simple_syscall() { fn simple_dyn_exec() { let program_source = " proc.foo - # drop the top 4 values, since that will be the code hash when we call this dynamically - dropw add end begin - # call foo directly so it will get added to the CodeBlockTable - padw + # call foo directly call.foo # move the first result of foo out of the way movdn.4 - # use dynexec to call foo again via its hash, which is on the stack + # use dynexec to call foo again via its hash, which is stored at memory location 42 + mem_storew.42 dropw + push.42 dynexec end"; - // The hash of foo can be obtained from the code block table by: - // let program = test.compile(); - // let cb_table = program.cb_table(); - // Result: - // [BaseElement(14592192105906586403), BaseElement(9256464248508904838), - // BaseElement(17436090329036592832), BaseElement(10814467189528518943)] - // Integer values can be obtained via Felt::from_mont(14592192105906586403).as_int(), etc. + // The hash of foo can be obtained with: + // let context = assembly::testing::TestContext::new(); + // let program = context.assemble(program_source).unwrap(); + // let procedure_digests: Vec = program.mast_forest().procedure_digests().collect(); + // let foo_digest = procedure_digests[0]; + // std::println!("foo digest: {foo_digest:?}"); + // As ints: - // [16045159387802755434, 10308872899350860082, 17306481765929021384, 16642043361554117790] + // [7259075614730273379, 2498922176515930900, 11574583201486131710, 6285975441353882141] + + let stack_init: [u64; 7] = [ + 3, + // put the hash of foo on the stack + 7259075614730273379, + 2498922176515930900, + 11574583201486131710, + 6285975441353882141, + 1, + 2, + ]; let test = Test { - stack_inputs: StackInputs::try_from_ints([ - 3, - // put the hash of foo on the stack - 16045159387802755434, - 10308872899350860082, - 17306481765929021384, - 16642043361554117790, - 1, - 2, - ]) - .unwrap(), - ..Test::new(&format!("test{}", line!()), program_source, false) + stack_inputs: StackInputs::try_from_ints(stack_init).unwrap(), + ..Test::new(&format!("test{}", line!()), program_source, true) }; test.expect_stack(&[6]); - test.prove_and_verify( - vec![ - 3, - 16045159387802755434, - 10308872899350860082, - 17306481765929021384, - 16642043361554117790, - 1, - 2, - ], - false, - ); + test.prove_and_verify(stack_init.to_vec(), false); } #[test] @@ -286,21 +315,22 @@ fn dynexec_with_procref() { use.external::module proc.foo - dropw push.1.2 u32wrapping_add end begin - procref.foo + procref.foo mem_storew.42 dropw push.42 dynexec - procref.module::func + procref.module::func mem_storew.42 dropw push.42 dynexec dup push.4 assert_eq.err=101 + + swap drop end"; let mut test = build_test!(program_source, &[]); @@ -309,7 +339,6 @@ fn dynexec_with_procref() { "external::module".parse().unwrap(), "\ export.func - dropw u32wrapping_add.1 end ", @@ -322,9 +351,6 @@ fn dynexec_with_procref() { fn simple_dyncall() { let program_source = " proc.foo - # drop the top 4 values, since that will be the code hash when we call this dynamically - dropw - # test that the execution context has changed mem_load.0 assertz @@ -336,39 +362,44 @@ fn simple_dyncall() { # write to memory so we can test that `call` and `dyncall` change the execution context push.5 mem_store.0 - # call foo directly so it will get added to the CodeBlockTable - padw + # call foo directly call.foo # move the first result of foo out of the way movdn.4 # use dyncall to call foo again via its hash, which is on the stack + mem_storew.42 dropw + push.42 dyncall + + swapw dropw end"; - // The hash of foo can be obtained from the code block table by: - // let program = test.compile(); - // let cb_table = program.cb_table(); - // Result: - // [BaseElement(3961142802598954486), BaseElement(5305628994393606376), - // BaseElement(7971171833137344204), BaseElement(10465350313512331391)] - // Integer values can be obtained via Felt::from_mont(14592192105906586403).as_int(), etc. + // The hash of foo can be obtained with: + // let context = assembly::testing::TestContext::new(); + // let program = context.assemble(program_source).unwrap(); + // let procedure_digests: Vec = program.mast_forest().procedure_digests().collect(); + // let foo_digest = procedure_digests[0]; + // std::println!("foo digest: {foo_digest:?}"); + + // // As ints: - // [8324248212344458853, 17691992706129158519, 18131640149172243086, 16129275750103409835] + // [6751154577850596602, 235765701633049111, 16334162752640292120, 7786442719091086500] let test = Test { stack_inputs: StackInputs::try_from_ints([ 3, // put the hash of foo on the stack - 8324248212344458853, - 17691992706129158519, - 18131640149172243086, - 16129275750103409835, + 6751154577850596602, + 235765701633049111, + 16334162752640292120, + 7786442719091086500, 1, 2, ]) .unwrap(), + libraries: vec![StdLibrary::default().into()], ..Test::new(&format!("test{}", line!()), program_source, false) }; @@ -377,10 +408,10 @@ fn simple_dyncall() { test.prove_and_verify( vec![ 3, - 8324248212344458853, - 17691992706129158519, - 18131640149172243086, - 16129275750103409835, + 6751154577850596602, + 235765701633049111, + 16334162752640292120, + 7786442719091086500, 1, 2, ], @@ -388,6 +419,55 @@ fn simple_dyncall() { ); } +/// Calls `bar` dynamically, which issues a syscall. We ensure that the `caller` instruction in the +/// kernel procedure correctly returns the hash of `bar`. +/// +/// We also populate the stack before `dyncall` to ensure that stack depth is properly restored +/// after `dyncall`. +#[test] +fn dyncall_with_syscall_and_caller() { + let kernel_source = " + export.foo + caller + end + "; + + let program_source = " + proc.bar + syscall.foo + end + + begin + # Populate stack before call + push.1 push.2 push.3 push.4 padw + + # Prepare dyncall + procref.bar mem_storew.42 dropw push.42 + dyncall + + # Truncate stack + movupw.3 dropw movupw.3 dropw + end"; + + let mut test = Test::new(&format!("test{}", line!()), program_source, true); + test.kernel_source = Some( + test.source_manager + .load(&format!("kernel{}", line!()), kernel_source.to_string()), + ); + test.expect_stack(&[ + 7618101086444903432, + 9972424747203251625, + 14917526361757867843, + 9845116178182948544, + 4, + 3, + 2, + 1, + ]); + + test.prove_and_verify(vec![], false); +} + // PROCREF INSTRUCTION // ================================================================================================ @@ -421,6 +501,7 @@ fn procref() -> Result<(), Report> { let source = " use.std::math::u64 + use.std::sys proc.foo.4 push.3.4 @@ -430,6 +511,8 @@ fn procref() -> Result<(), Report> { procref.u64::overflowing_add push.0 procref.foo + + exec.sys::truncate_stack end"; let mut test = build_test!(source, &[]); diff --git a/miden/tests/integration/main.rs b/miden/tests/integration/main.rs index effe92c656..10720fdb14 100644 --- a/miden/tests/integration/main.rs +++ b/miden/tests/integration/main.rs @@ -1,6 +1,6 @@ extern crate alloc; -use test_utils::build_test; +use test_utils::{build_op_test, build_test}; mod air; mod cli; @@ -13,7 +13,7 @@ mod operations; #[test] fn simple_program() { - build_test!("begin push.1 push.2 add end").expect_stack(&[3]); + build_test!("begin push.1 push.2 add swap drop end").expect_stack(&[3]); } #[test] @@ -21,3 +21,14 @@ fn multi_output_program() { let test = build_test!("begin mul movup.2 drop end", &[1, 2, 3]); test.prove_and_verify(vec![1, 2, 3], false); } + +#[test] +fn program_with_respan() { + let source = " + repeat.49 + swap dup.1 add + end"; + let pub_inputs = vec![]; + + build_op_test!(source, &pub_inputs).prove_and_verify(pub_inputs, false); +} diff --git a/miden/tests/integration/operations/decorators/advice.rs b/miden/tests/integration/operations/decorators/advice.rs index 03e95e5303..4c80720f07 100644 --- a/miden/tests/integration/operations/decorators/advice.rs +++ b/miden/tests/integration/operations/decorators/advice.rs @@ -7,7 +7,7 @@ use test_utils::{ expect_exec_error, rand::{rand_array, rand_value}, serde::Serializable, - Felt, + Felt, TRUNCATE_STACK_PROC, }; const ADVICE_PUSH_SIG: &str = " @@ -31,7 +31,7 @@ const ADVICE_PUSH_SIG: &str = " #[test] fn advice_push_u64div() { // push a/b onto the advice stack and then move these values onto the operand stack. - let source = "begin adv.push_u64div adv_push.4 end"; + let source = "begin adv.push_u64div adv_push.4 movupw.2 dropw end"; // get two random 64-bit integers and split them into 32-bit limbs let a = rand_value::(); @@ -65,7 +65,11 @@ fn advice_push_u64div_repeat() { // - reads quotient from advice stack to the stack // - push 2_u64 to the stack divided into 2 32 bit limbs // Finally the first 2 elements of the stack are removed - let source = "begin + let source = format!( + " + {TRUNCATE_STACK_PROC} + + begin repeat.7 adv.push_u64div drop drop @@ -74,7 +78,10 @@ fn advice_push_u64div_repeat() { push.0 end drop drop - end"; + + exec.truncate_stack + end" + ); let mut a = 256; let a_hi = 0; @@ -103,7 +110,16 @@ fn advice_push_u64div_repeat() { #[test] fn advice_push_u64div_local_procedure() { // push a/b onto the advice stack and then move these values onto the operand stack. - let source = "proc.foo adv.push_u64div adv_push.4 end begin exec.foo end"; + let source = " + proc.foo + adv.push_u64div + adv_push.4 + end + + begin + exec.foo + movupw.2 dropw + end"; // get two random 64-bit integers and split them into 32-bit limbs let a = rand_value::(); @@ -131,7 +147,18 @@ fn advice_push_u64div_local_procedure() { #[test] fn advice_push_u64div_conditional_execution() { - let source = "begin eq if.true adv.push_u64div adv_push.4 else padw end end"; + let source = " + begin + eq + if.true + adv.push_u64div + adv_push.4 + else + padw + end + + movupw.2 dropw + end"; // if branch let test = build_test!(source, &[8, 0, 4, 0, 1, 1]); @@ -202,16 +229,17 @@ fn advice_insert_mem() { #[test] fn advice_push_mapval() { // --- test simple adv.mapval --------------------------------------------- - let source: &str = "begin - # stack: [4, 3, 2, 1, ...] - - # load the advice stack with values from the advice map and drop the key - adv.push_mapval - dropw + let source: &str = " + begin + # stack: [4, 3, 2, 1, ...] - # move the values from the advice stack to the operand stack - adv_push.4 + # load the advice stack with values from the advice map and drop the key + adv.push_mapval + dropw + # move the values from the advice stack to the operand stack + adv_push.4 + swapw dropw end"; let stack_inputs = [1, 2, 3, 4]; @@ -224,19 +252,20 @@ fn advice_push_mapval() { test.expect_stack(&[5, 6, 7, 8]); // --- test adv.mapval with offset ---------------------------------------- - let source: &str = "begin - # stack: [4, 3, 2, 1, ...] - - # shift the key on the stack by 2 slots - push.0 push.0 + let source: &str = " + begin + # stack: [4, 3, 2, 1, ...] - # load the advice stack with values from the advice map and drop the key - adv.push_mapval.2 - dropw drop drop + # shift the key on the stack by 2 slots + push.0 push.0 - # move the values from the advice stack to the operand stack - adv_push.4 + # load the advice stack with values from the advice map and drop the key + adv.push_mapval.2 + dropw drop drop + # move the values from the advice stack to the operand stack + adv_push.4 + swapw dropw end"; let stack_inputs = [1, 2, 3, 4]; @@ -249,17 +278,18 @@ fn advice_push_mapval() { test.expect_stack(&[5, 6, 7, 8]); // --- test simple adv.mapvaln -------------------------------------------- - let source: &str = "begin - # stack: [4, 3, 2, 1, ...] - - # load the advice stack with values from the advice map (including the number - # of elements) and drop the key - adv.push_mapvaln - dropw + let source: &str = " + begin + # stack: [4, 3, 2, 1, ...] - # move the values from the advice stack to the operand stack - adv_push.6 + # load the advice stack with values from the advice map (including the number + # of elements) and drop the key + adv.push_mapvaln + dropw + # move the values from the advice stack to the operand stack + adv_push.6 + swapdw dropw dropw end"; let stack_inputs = [1, 2, 3, 4]; @@ -272,20 +302,21 @@ fn advice_push_mapval() { test.expect_stack(&[15, 14, 13, 12, 11, 5]); // --- test adv.mapval with offset ---------------------------------------- - let source: &str = "begin - # stack: [4, 3, 2, 1, ...] - - # shift the key on the stack by 2 slots - push.0 push.0 + let source: &str = " + begin + # stack: [4, 3, 2, 1, ...] - # load the advice stack with values from the advice map (including the number - # of elements) and drop the key - adv.push_mapvaln.2 - dropw drop drop + # shift the key on the stack by 2 slots + push.0 push.0 - # move the values from the advice stack to the operand stack - adv_push.6 + # load the advice stack with values from the advice map (including the number + # of elements) and drop the key + adv.push_mapvaln.2 + dropw drop drop + # move the values from the advice stack to the operand stack + adv_push.6 + swapdw dropw dropw end"; let stack_inputs = [1, 2, 3, 4]; @@ -301,49 +332,51 @@ fn advice_push_mapval() { #[test] fn advice_insert_hdword() { // --- test hashing without domain ---------------------------------------- - let source: &str = "begin - # stack: [1, 2, 3, 4, 5, 6, 7, 8, ...] - - # hash and insert top two words into the advice map - adv.insert_hdword + let source: &str = " + begin + # stack: [1, 2, 3, 4, 5, 6, 7, 8, ...] - # manually compute the hash of the two words - hmerge - # => [KEY, ...] + # hash and insert top two words into the advice map + adv.insert_hdword - # load the advice stack with values from the advice map and drop the key - adv.push_mapval - dropw + # manually compute the hash of the two words + hmerge + # => [KEY, ...] - # move the values from the advice stack to the operand stack - adv_push.8 + # load the advice stack with values from the advice map and drop the key + adv.push_mapval + dropw + # move the values from the advice stack to the operand stack + adv_push.8 + swapdw dropw dropw end"; let stack_inputs = [8, 7, 6, 5, 4, 3, 2, 1]; let test = build_test!(source, &stack_inputs); test.expect_stack(&[1, 2, 3, 4, 5, 6, 7, 8]); // --- test hashing with domain ------------------------------------------- - let source: &str = "begin - # stack: [1, 2, 3, 4, 5, 6, 7, 8, ...] - - # hash and insert top two words into the advice map - adv.insert_hdword.3 + let source: &str = " + begin + # stack: [1, 2, 3, 4, 5, 6, 7, 8, ...] - # manually compute the hash of the two words - push.0.3.0.0 - swapw.2 swapw - hperm - dropw swapw dropw - # => [KEY, ...] + # hash and insert top two words into the advice map + adv.insert_hdword.3 - # load the advice stack with values from the advice map and drop the key - adv.push_mapval - dropw + # manually compute the hash of the two words + push.0.3.0.0 + swapw.2 swapw + hperm + dropw swapw dropw + # => [KEY, ...] - # move the values from the advice stack to the operand stack - adv_push.8 + # load the advice stack with values from the advice map and drop the key + adv.push_mapval + dropw + # move the values from the advice stack to the operand stack + adv_push.8 + swapdw dropw dropw end"; let stack_inputs = [8, 7, 6, 5, 4, 3, 2, 1]; let test = build_test!(source, &stack_inputs); diff --git a/miden/tests/integration/operations/decorators/asmop.rs b/miden/tests/integration/operations/decorators/asmop.rs index 261b585f2c..fb85ebcdbf 100644 --- a/miden/tests/integration/operations/decorators/asmop.rs +++ b/miden/tests/integration/operations/decorators/asmop.rs @@ -4,7 +4,7 @@ use vm_core::{debuginfo::Location, AssemblyOp, Felt, Operation}; #[test] fn asmop_one_span_block_test() { - let source = "begin push.1 push.2 add end"; + let source = "begin push.1 push.2 add swap drop swap drop end"; let test = build_debug_test!(source); let path = test.source.name(); let push1_loc = Some(Location { @@ -22,6 +22,26 @@ fn asmop_one_span_block_test() { start: 20.into(), end: (20 + 3).into(), }); + let swap1_loc = Some(Location { + path: path.clone(), + start: 24.into(), + end: (24 + 4).into(), + }); + let drop1_loc = Some(Location { + path: path.clone(), + start: 29.into(), + end: (29 + 4).into(), + }); + let swap2_loc = Some(Location { + path: path.clone(), + start: 34.into(), + end: (34 + 4).into(), + }); + let drop2_loc = Some(Location { + path: path.clone(), + start: 39.into(), + end: (39 + 4).into(), + }); let vm_state_iterator = test.execute_iter(); let expected_vm_state = vec![ VmStatePartial { @@ -86,6 +106,62 @@ fn asmop_one_span_block_test() { }, VmStatePartial { clk: RowIndex::from(6), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap1_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(7), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop1_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(8), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap2_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(9), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop2_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(10), asmop: None, op: Some(Operation::End), }, @@ -96,7 +172,7 @@ fn asmop_one_span_block_test() { #[test] fn asmop_with_one_procedure() { - let source = "proc.foo push.1 push.2 add end begin exec.foo end"; + let source = "proc.foo push.1 push.2 add end begin exec.foo swap drop swap drop end"; let test = build_debug_test!(source); let path = test.source.name(); let push1_loc = Some(Location { @@ -114,6 +190,26 @@ fn asmop_with_one_procedure() { start: 23.into(), end: (23 + 3).into(), }); + let swap1_loc = Some(Location { + path: path.clone(), + start: 46.into(), + end: (46 + 4).into(), + }); + let drop1_loc = Some(Location { + path: path.clone(), + start: 51.into(), + end: (51 + 4).into(), + }); + let swap2_loc = Some(Location { + path: path.clone(), + start: 56.into(), + end: (56 + 4).into(), + }); + let drop2_loc = Some(Location { + path: path.clone(), + start: 61.into(), + end: (61 + 4).into(), + }); let vm_state_iterator = test.execute_iter(); let expected_vm_state = vec![ VmStatePartial { @@ -178,6 +274,62 @@ fn asmop_with_one_procedure() { }, VmStatePartial { clk: RowIndex::from(6), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap1_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(7), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop1_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(8), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap2_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(9), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop2_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(10), asmop: None, op: Some(Operation::End), }, @@ -188,27 +340,44 @@ fn asmop_with_one_procedure() { #[test] fn asmop_repeat_test() { - let source = "begin + let source = " + begin repeat.3 push.1 push.2 add end + swapdw dropw dropw end"; let test = build_debug_test!(source); let path = test.source.name(); let push1_loc = Some(Location { path: path.clone(), - start: 43.into(), - end: (43 + 6).into(), + start: 52.into(), + end: (52 + 6).into(), }); let push2_loc = Some(Location { path: path.clone(), - start: 50.into(), - end: (50 + 6).into(), + start: 59.into(), + end: (59 + 6).into(), }); let add_loc = Some(Location { path: path.clone(), - start: 57.into(), - end: (57 + 3).into(), + start: 66.into(), + end: (66 + 3).into(), + }); + let swapdw_loc = Some(Location { + path: path.clone(), + start: 98.into(), + end: (98 + 6).into(), + }); + let dropw1_loc = Some(Location { + path: path.clone(), + start: 105.into(), + end: (105 + 5).into(), + }); + let dropw2_loc = Some(Location { + path: path.clone(), + start: 111.into(), + end: (111 + 5).into(), }); let vm_state_iterator = test.execute_iter(); let expected_vm_state = vec![ @@ -388,21 +557,142 @@ fn asmop_repeat_test() { }, VmStatePartial { clk: RowIndex::from(14), - asmop: None, - op: Some(Operation::Noop), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swapdw_loc, + "#exec::#main".to_string(), + 1, + "swapdw".to_string(), + false, + ), + 1, + )), + op: Some(Operation::SwapDW), }, VmStatePartial { clk: RowIndex::from(15), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw1_loc.clone(), + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(16), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw1_loc.clone(), + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 2, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(17), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw1_loc.clone(), + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 3, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(18), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw1_loc, + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 4, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(19), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw2_loc.clone(), + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(20), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw2_loc.clone(), + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 2, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(21), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw2_loc.clone(), + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 3, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(22), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + dropw2_loc, + "#exec::#main".to_string(), + 4, + "dropw".to_string(), + false, + ), + 4, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(23), asmop: None, op: Some(Operation::Noop), }, VmStatePartial { - clk: RowIndex::from(16), + clk: RowIndex::from(24), asmop: None, op: Some(Operation::Noop), }, VmStatePartial { - clk: RowIndex::from(17), + clk: RowIndex::from(25), asmop: None, op: Some(Operation::End), }, @@ -413,13 +703,16 @@ fn asmop_repeat_test() { #[test] fn asmop_conditional_execution_test() { - let source = "begin + let source = " + begin eq if.true push.1 push.2 add else push.3 push.4 add end + + swap drop swap drop end"; //if branch @@ -427,23 +720,43 @@ fn asmop_conditional_execution_test() { let path = test.source.name(); let eq_loc = Some(Location { path: path.clone(), - start: 18.into(), - end: (18 + 2).into(), + start: 27.into(), + end: (27 + 2).into(), }); let push1_loc = Some(Location { path: path.clone(), - start: 57.into(), - end: (57 + 6).into(), + start: 66.into(), + end: (66 + 6).into(), }); let push2_loc = Some(Location { path: path.clone(), - start: 64.into(), - end: (64 + 6).into(), + start: 73.into(), + end: (73 + 6).into(), }); let add_loc = Some(Location { path: path.clone(), - start: 71.into(), - end: (71 + 3).into(), + start: 80.into(), + end: (80 + 3).into(), + }); + let swap1_loc = Some(Location { + path: path.clone(), + start: 164.into(), + end: (164 + 4).into(), + }); + let drop1_loc = Some(Location { + path: path.clone(), + start: 169.into(), + end: (169 + 4).into(), + }); + let swap2_loc = Some(Location { + path: path.clone(), + start: 174.into(), + end: (174 + 4).into(), + }); + let drop2_loc = Some(Location { + path: path.clone(), + start: 179.into(), + end: (179 + 4).into(), }); let vm_state_iterator = test.execute_iter(); let expected_vm_state = vec![ @@ -460,10 +773,15 @@ fn asmop_conditional_execution_test() { VmStatePartial { clk: RowIndex::from(2), asmop: None, - op: Some(Operation::Span), + op: Some(Operation::Join), }, VmStatePartial { clk: RowIndex::from(3), + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: RowIndex::from(4), asmop: Some(AsmOpInfo::new( AssemblyOp::new(eq_loc, "#exec::#main".to_string(), 1, "eq".to_string(), false), 1, @@ -471,22 +789,22 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Eq), }, VmStatePartial { - clk: RowIndex::from(4), + clk: RowIndex::from(5), asmop: None, op: Some(Operation::End), }, VmStatePartial { - clk: RowIndex::from(5), + clk: RowIndex::from(6), asmop: None, op: Some(Operation::Split), }, VmStatePartial { - clk: RowIndex::from(6), + clk: RowIndex::from(7), asmop: None, op: Some(Operation::Span), }, VmStatePartial { - clk: RowIndex::from(7), + clk: RowIndex::from(8), asmop: Some(AsmOpInfo::new( AssemblyOp::new( push1_loc.clone(), @@ -500,7 +818,7 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Pad), }, VmStatePartial { - clk: RowIndex::from(8), + clk: RowIndex::from(9), asmop: Some(AsmOpInfo::new( AssemblyOp::new( push1_loc, @@ -514,7 +832,7 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Incr), }, VmStatePartial { - clk: RowIndex::from(9), + clk: RowIndex::from(10), asmop: Some(AsmOpInfo::new( AssemblyOp::new( push2_loc, @@ -528,7 +846,7 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Push(Felt::new(2))), }, VmStatePartial { - clk: RowIndex::from(10), + clk: RowIndex::from(11), asmop: Some(AsmOpInfo::new( AssemblyOp::new(add_loc, "#exec::#main".to_string(), 1, "add".to_string(), false), 1, @@ -536,17 +854,88 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Add), }, VmStatePartial { - clk: RowIndex::from(11), + clk: RowIndex::from(12), asmop: None, op: Some(Operation::End), }, VmStatePartial { - clk: RowIndex::from(12), + clk: RowIndex::from(13), asmop: None, op: Some(Operation::End), }, VmStatePartial { - clk: RowIndex::from(13), + clk: RowIndex::from(14), + asmop: None, + op: Some(Operation::End), + }, + VmStatePartial { + clk: RowIndex::from(15), + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: RowIndex::from(16), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap1_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(17), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop1_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(18), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap2_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(19), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop2_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(20), + asmop: None, + op: Some(Operation::End), + }, + VmStatePartial { + clk: RowIndex::from(21), asmop: None, op: Some(Operation::End), }, @@ -559,23 +948,43 @@ fn asmop_conditional_execution_test() { let path = test.source.name(); let eq_loc = Some(Location { path: path.clone(), - start: 18.into(), - end: (18 + 2).into(), + start: 27.into(), + end: (27 + 2).into(), }); let push3_loc = Some(Location { path: path.clone(), - start: 108.into(), - end: (108 + 6).into(), + start: 117.into(), + end: (117 + 6).into(), }); let push4_loc = Some(Location { path: path.clone(), - start: 115.into(), - end: (115 + 6).into(), + start: 124.into(), + end: (124 + 6).into(), }); let add_loc = Some(Location { path: path.clone(), - start: 122.into(), - end: (122 + 3).into(), + start: 131.into(), + end: (131 + 3).into(), + }); + let swap1_loc = Some(Location { + path: path.clone(), + start: 164.into(), + end: (164 + 4).into(), + }); + let drop1_loc = Some(Location { + path: path.clone(), + start: 169.into(), + end: (169 + 4).into(), + }); + let swap2_loc = Some(Location { + path: path.clone(), + start: 174.into(), + end: (174 + 4).into(), + }); + let drop2_loc = Some(Location { + path: path.clone(), + start: 179.into(), + end: (179 + 4).into(), }); let vm_state_iterator = test.execute_iter(); let expected_vm_state = vec![ @@ -592,10 +1001,15 @@ fn asmop_conditional_execution_test() { VmStatePartial { clk: RowIndex::from(2), asmop: None, - op: Some(Operation::Span), + op: Some(Operation::Join), }, VmStatePartial { clk: RowIndex::from(3), + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: RowIndex::from(4), asmop: Some(AsmOpInfo::new( AssemblyOp::new(eq_loc, "#exec::#main".to_string(), 1, "eq".to_string(), false), 1, @@ -603,22 +1017,22 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Eq), }, VmStatePartial { - clk: RowIndex::from(4), + clk: RowIndex::from(5), asmop: None, op: Some(Operation::End), }, VmStatePartial { - clk: RowIndex::from(5), + clk: RowIndex::from(6), asmop: None, op: Some(Operation::Split), }, VmStatePartial { - clk: RowIndex::from(6), + clk: RowIndex::from(7), asmop: None, op: Some(Operation::Span), }, VmStatePartial { - clk: RowIndex::from(7), + clk: RowIndex::from(8), asmop: Some(AsmOpInfo::new( AssemblyOp::new( push3_loc, @@ -632,7 +1046,7 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Push(Felt::new(3))), }, VmStatePartial { - clk: RowIndex::from(8), + clk: RowIndex::from(9), asmop: Some(AsmOpInfo::new( AssemblyOp::new( push4_loc, @@ -646,7 +1060,7 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Push(Felt::new(4))), }, VmStatePartial { - clk: RowIndex::from(9), + clk: RowIndex::from(10), asmop: Some(AsmOpInfo::new( AssemblyOp::new(add_loc, "#exec::#main".to_string(), 1, "add".to_string(), false), 1, @@ -654,22 +1068,93 @@ fn asmop_conditional_execution_test() { op: Some(Operation::Add), }, VmStatePartial { - clk: RowIndex::from(10), + clk: RowIndex::from(11), asmop: None, op: Some(Operation::Noop), }, VmStatePartial { - clk: RowIndex::from(11), + clk: RowIndex::from(12), asmop: None, op: Some(Operation::End), }, VmStatePartial { - clk: RowIndex::from(12), + clk: RowIndex::from(13), asmop: None, op: Some(Operation::End), }, VmStatePartial { - clk: RowIndex::from(13), + clk: RowIndex::from(14), + asmop: None, + op: Some(Operation::End), + }, + VmStatePartial { + clk: RowIndex::from(15), + asmop: None, + op: Some(Operation::Span), + }, + VmStatePartial { + clk: RowIndex::from(16), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap1_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(17), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop1_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(18), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + swap2_loc, + "#exec::#main".to_string(), + 1, + "swap.1".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Swap), + }, + VmStatePartial { + clk: RowIndex::from(19), + asmop: Some(AsmOpInfo::new( + AssemblyOp::new( + drop2_loc, + "#exec::#main".to_string(), + 1, + "drop".to_string(), + false, + ), + 1, + )), + op: Some(Operation::Drop), + }, + VmStatePartial { + clk: RowIndex::from(20), + asmop: None, + op: Some(Operation::End), + }, + VmStatePartial { + clk: RowIndex::from(21), asmop: None, op: Some(Operation::End), }, diff --git a/miden/tests/integration/operations/decorators/events.rs b/miden/tests/integration/operations/decorators/events.rs index ea835fdf78..73ced68efa 100644 --- a/miden/tests/integration/operations/decorators/events.rs +++ b/miden/tests/integration/operations/decorators/events.rs @@ -1,5 +1,6 @@ use assembly::Assembler; use processor::{ExecutionOptions, Program}; +use prover::StackInputs; use super::TestHost; @@ -11,12 +12,14 @@ fn test_event_handling() { emit.1 push.2 emit.2 + swapw dropw end"; // compile and execute program let program: Program = Assembler::default().assemble_program(source).unwrap(); let mut host = TestHost::default(); - processor::execute(&program, Default::default(), &mut host, Default::default()).unwrap(); + processor::execute(&program, StackInputs::default(), &mut host, ExecutionOptions::default()) + .unwrap(); // make sure events were handled correctly let expected = vec![1, 2]; @@ -31,6 +34,7 @@ fn test_trace_handling() { trace.1 push.2 trace.2 + swapw dropw end"; // compile program @@ -38,14 +42,15 @@ fn test_trace_handling() { let mut host = TestHost::default(); // execute program with disabled tracing - processor::execute(&program, Default::default(), &mut host, Default::default()).unwrap(); + processor::execute(&program, StackInputs::default(), &mut host, ExecutionOptions::default()) + .unwrap(); let expected = Vec::::new(); assert_eq!(host.trace_handler, expected); // execute program with enabled tracing processor::execute( &program, - Default::default(), + StackInputs::default(), &mut host, ExecutionOptions::default().with_tracing(), ) @@ -53,3 +58,51 @@ fn test_trace_handling() { let expected = vec![1, 2]; assert_eq!(host.trace_handler, expected); } + +#[test] +fn test_debug_with_debugging() { + let source: &str = "\ + begin + push.1 + debug.stack + debug.mem + drop + end"; + + // compile and execute program + let program: Program = + Assembler::default().with_debug_mode(true).assemble_program(source).unwrap(); + let mut host = TestHost::default(); + processor::execute( + &program, + StackInputs::default(), + &mut host, + ExecutionOptions::default().with_debugging(), + ) + .unwrap(); + + // Expect to see the debug.stack and debug.mem commands + let expected = vec!["stack", "mem"]; + assert_eq!(host.debug_handler, expected); +} + +#[test] +fn test_debug_without_debugging() { + let source: &str = "\ + begin + push.1 + debug.stack + debug.mem + drop + end"; + + // compile and execute program + let program: Program = Assembler::default().assemble_program(source).unwrap(); + let mut host = TestHost::default(); + processor::execute(&program, StackInputs::default(), &mut host, ExecutionOptions::default()) + .unwrap(); + + // Expect to see no debug commands + let expected: Vec = vec![]; + assert_eq!(host.debug_handler, expected); +} diff --git a/miden/tests/integration/operations/decorators/mod.rs b/miden/tests/integration/operations/decorators/mod.rs index 7c5fd8db9f..ce4fb82afc 100644 --- a/miden/tests/integration/operations/decorators/mod.rs +++ b/miden/tests/integration/operations/decorators/mod.rs @@ -4,7 +4,7 @@ use processor::{ AdviceExtractor, AdviceProvider, ExecutionError, Host, HostResponse, MastForest, MemAdviceProvider, ProcessState, }; -use vm_core::AdviceInjector; +use vm_core::{AdviceInjector, DebugOptions}; mod advice; mod asmop; @@ -16,6 +16,7 @@ pub struct TestHost { pub adv_provider: A, pub event_handler: Vec, pub trace_handler: Vec, + pub debug_handler: Vec, } impl Default for TestHost { @@ -24,6 +25,7 @@ impl Default for TestHost { adv_provider: MemAdviceProvider::default(), event_handler: Vec::new(), trace_handler: Vec::new(), + debug_handler: Vec::new(), } } } @@ -63,6 +65,15 @@ impl Host for TestHost { Ok(HostResponse::None) } + fn on_debug( + &mut self, + _process: &S, + _options: &DebugOptions, + ) -> Result { + self.debug_handler.push(_options.to_string()); + Ok(HostResponse::None) + } + fn get_mast_forest(&self, _node_digest: &prover::Digest) -> Option> { // Empty MAST forest store None diff --git a/miden/tests/integration/operations/field_ops.rs b/miden/tests/integration/operations/field_ops.rs index 9070d72516..51310b295c 100644 --- a/miden/tests/integration/operations/field_ops.rs +++ b/miden/tests/integration/operations/field_ops.rs @@ -187,9 +187,10 @@ fn div_b() { test, "invalid constant expression: division by zero", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin div.0 end", - " : ^^^^^", - " `----" + "11 |", + "12 | begin div.0 exec.truncate_stack end", + " : ^^^^^", + " `----" ); let test = build_op_test!(build_asm_op(2), &[4]); @@ -212,7 +213,7 @@ fn div_fail() { // --- test divide by zero -------------------------------------------------------------------- let test = build_op_test!(asm_op, &[1, 0]); - expect_exec_error!(test, ExecutionError::DivideByZero(1.into())); + expect_exec_error!(test, ExecutionError::DivideByZero(2.into())); } #[test] @@ -246,10 +247,11 @@ fn neg_fail() { test, "invalid syntax", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin neg.1 end", - " : |", - " : `-- found a . here", - " `----", + "11 |", + "12 | begin neg.1 exec.truncate_stack end", + " : |", + " : `-- found a . here", + " `----", r#" help: expected primitive opcode (e.g. "add"), or "end", or control flow opcode (e.g. "if.true")"# ); } @@ -277,7 +279,7 @@ fn inv_fail() { // --- test no inv on 0 ----------------------------------------------------------------------- let test = build_op_test!(asm_op, &[0]); - expect_exec_error!(test, ExecutionError::DivideByZero(1.into())); + expect_exec_error!(test, ExecutionError::DivideByZero(2.into())); let asm_op = "inv.1"; @@ -288,10 +290,11 @@ fn inv_fail() { test, "invalid syntax", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin inv.1 end", - " : |", - " : `-- found a . here", - " `----", + "11 |", + "12 | begin inv.1 exec.truncate_stack end", + " : |", + " : `-- found a . here", + " `----", r#" help: expected primitive opcode (e.g. "add"), or "end", or control flow opcode (e.g. "if.true")"# ); } @@ -318,7 +321,7 @@ fn pow2_fail() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 16.into(), + clk: 17.into(), err_code: 0, err_msg: None, } @@ -353,7 +356,7 @@ fn exp_bits_length_fail() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 18.into(), + clk: 19.into(), err_code: 0, err_msg: None } @@ -370,9 +373,11 @@ fn exp_bits_length_fail() { test, "invalid literal: expected value to be a valid bit size, e.g. 0..63", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin exp.u65 end", - " : ^^", - " `----" + "11 |", + "12 | begin exp.u65 exec.truncate_stack end", + " : ^^", + " `----", + r#" help: expected primitive opcode (e.g. "add"), or "end", or control flow opcode (e.g. "if.true")"# ); } @@ -402,7 +407,7 @@ fn ilog2_fail() { let asm_op = "ilog2"; let test = build_op_test!(asm_op, &[0]); - expect_exec_error!(test, ExecutionError::LogArgumentZero(1.into())); + expect_exec_error!(test, ExecutionError::LogArgumentZero(2.into())); } // FIELD OPS BOOLEAN - MANUAL TESTS diff --git a/miden/tests/integration/operations/fri_ops.rs b/miden/tests/integration/operations/fri_ops.rs index 7c1ac6c072..56949c306f 100644 --- a/miden/tests/integration/operations/fri_ops.rs +++ b/miden/tests/integration/operations/fri_ops.rs @@ -1,15 +1,12 @@ -use test_utils::{build_test, rand::rand_array, Felt, FieldElement}; +use test_utils::{ + build_test, push_inputs, rand::rand_array, Felt, FieldElement, TRUNCATE_STACK_PROC, +}; // FRI_EXT2FOLD4 // ================================================================================================ #[test] fn fri_ext2fold4() { - let source = " - begin - fri_ext2fold4 - end"; - // create a set of random inputs let mut inputs = rand_array::().iter().map(|v| v.as_int()).collect::>(); inputs[7] = 2; // domain segment must be < 4 @@ -23,8 +20,21 @@ fn fri_ext2fold4() { let poe = inputs[6]; let f_pos = inputs[8]; + let source = format!( + " + {TRUNCATE_STACK_PROC} + + begin + {inputs} + fri_ext2fold4 + + exec.truncate_stack + end", + inputs = push_inputs(&inputs) + ); + // execute the program - let test = build_test!(source, &inputs); + let test = build_test!(source, &[]); // check some items in the state transition; full state transition is checked in the // processor tests @@ -36,5 +46,5 @@ fn fri_ext2fold4() { assert_eq!(stack_state[15], Felt::new(end_ptr)); // make sure STARK proof can be generated and verified - test.prove_and_verify(inputs, false); + test.prove_and_verify(vec![], false); } diff --git a/miden/tests/integration/operations/io_ops/adv_ops.rs b/miden/tests/integration/operations/io_ops/adv_ops.rs index 09f91e9afb..e172f8aaf7 100644 --- a/miden/tests/integration/operations/io_ops/adv_ops.rs +++ b/miden/tests/integration/operations/io_ops/adv_ops.rs @@ -2,7 +2,7 @@ use processor::{ExecutionError, ExecutionError::AdviceStackReadFailed}; use test_utils::expect_exec_error; use vm_core::{chiplets::hasher::apply_permutation, utils::ToElements, Felt}; -use super::{build_op_test, build_test}; +use super::{build_op_test, build_test, TRUNCATE_STACK_PROC}; // PUSHING VALUES ONTO THE STACK (PUSH) // ================================================================================================ @@ -32,7 +32,7 @@ fn adv_push() { fn adv_push_invalid() { // attempting to read from empty advice stack should throw an error let test = build_op_test!("adv_push.1"); - expect_exec_error!(test, ExecutionError::AdviceStackReadFailed(1.into())); + expect_exec_error!(test, ExecutionError::AdviceStackReadFailed(2.into())); } // OVERWRITING VALUES ON THE STACK (LOAD) @@ -53,7 +53,7 @@ fn adv_loadw() { fn adv_loadw_invalid() { // attempting to read from empty advice stack should throw an error let test = build_op_test!("adv_loadw", &[0, 0, 0, 0]); - expect_exec_error!(test, AdviceStackReadFailed(1.into())); + expect_exec_error!(test, AdviceStackReadFailed(2.into())); } // MOVING ELEMENTS TO MEMORY VIA THE STACK (PIPE) @@ -61,11 +61,17 @@ fn adv_loadw_invalid() { #[test] fn adv_pipe() { - let source = " + let source = format!( + " + {TRUNCATE_STACK_PROC} + begin push.12.11.10.9.8.7.6.5.4.3.2.1 adv_pipe - end"; + + exec.truncate_stack + end" + ); let advice_stack = [1, 2, 3, 4, 5, 6, 7, 8]; @@ -88,11 +94,17 @@ fn adv_pipe() { #[test] fn adv_pipe_with_hperm() { - let source = " + let source = format!( + " + {TRUNCATE_STACK_PROC} + begin push.12.11.10.9.8.7.6.5.4.3.2.1 adv_pipe hperm - end"; + + exec.truncate_stack + end" + ); let advice_stack = [1, 2, 3, 4, 5, 6, 7, 8]; diff --git a/miden/tests/integration/operations/io_ops/env_ops.rs b/miden/tests/integration/operations/io_ops/env_ops.rs index 15048716ef..d8bc13b594 100644 --- a/miden/tests/integration/operations/io_ops/env_ops.rs +++ b/miden/tests/integration/operations/io_ops/env_ops.rs @@ -1,11 +1,13 @@ use assembly::SourceManager; use processor::FMP_MIN; -use test_utils::{build_op_test, build_test, StackInputs, Test, Word, STACK_TOP_SIZE}; +use test_utils::{build_op_test, build_test, StackInputs, Test, Word, MIN_STACK_DEPTH}; use vm_core::{ mast::{MastForest, MastNode}, Operation, }; +use super::TRUNCATE_STACK_PROC; + // SDEPTH INSTRUCTION // ================================================================================================ @@ -15,15 +17,27 @@ fn sdepth() { // --- empty stack ---------------------------------------------------------------------------- let test = build_op_test!(test_op); - test.expect_stack(&[STACK_TOP_SIZE as u64]); + test.expect_stack(&[MIN_STACK_DEPTH as u64]); // --- multi-element stack -------------------------------------------------------------------- let test = build_op_test!(test_op, &[2, 4, 6, 8, 10]); - test.expect_stack(&[STACK_TOP_SIZE as u64, 10, 8, 6, 4, 2]); + test.expect_stack(&[MIN_STACK_DEPTH as u64, 10, 8, 6, 4, 2]); // --- overflowed stack ----------------------------------------------------------------------- // push 2 values to increase the lenth of the stack beyond 16 - let source = format!("begin push.1 push.1 {test_op} end"); + let source = format!( + " + {TRUNCATE_STACK_PROC} + + begin + push.1 + push.1 + {test_op} + + exec.truncate_stack + end + " + ); let test = build_test!(&source, &[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]); test.expect_stack(&[18, 1, 1, 7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3]); } @@ -41,6 +55,7 @@ fn locaddr() { end begin exec.foo + swapw dropw end"; let test = build_test!(source, &[10]); @@ -60,13 +75,17 @@ fn locaddr() { end begin exec.foo + swapdw dropw dropw end"; let test = build_test!(source, &[10, 1, 2, 3, 4, 5]); test.expect_stack(&[4, 3, 2, 1, 5, 10]); // --- locaddr returns expected addresses in nested procedures -------------------------------- - let source = " + let source = format!( + " + {TRUNCATE_STACK_PROC} + proc.foo.3 locaddr.0 locaddr.1 @@ -80,7 +99,10 @@ fn locaddr() { begin exec.bar exec.foo - end"; + + exec.truncate_stack + end" + ); let test = build_test!(source, &[10]); test.expect_stack(&[ @@ -118,6 +140,7 @@ fn locaddr() { end begin exec.bar + swapdw dropw dropw end"; let test = build_test!(source, &[10, 1, 2, 3, 4, 5, 6, 7]); @@ -181,7 +204,7 @@ fn build_bar_hash() -> [u64; 4] { #[test] fn clk() { let test = build_op_test!("clk"); - test.expect_stack(&[1]); + test.expect_stack(&[2]); let source = " proc.foo @@ -189,8 +212,10 @@ fn clk() { push.4 clk end + begin exec.foo + swapw dropw end"; let test = build_test!(source, &[]); diff --git a/miden/tests/integration/operations/io_ops/local_ops.rs b/miden/tests/integration/operations/io_ops/local_ops.rs index e995cdf0de..f1fc61e1b3 100644 --- a/miden/tests/integration/operations/io_ops/local_ops.rs +++ b/miden/tests/integration/operations/io_ops/local_ops.rs @@ -1,4 +1,4 @@ -use super::build_test; +use super::{build_test, TRUNCATE_STACK_PROC}; // PUSHING VALUES ONTO THE STACK (PUSH) // ================================================================================================ @@ -9,8 +9,10 @@ fn push_local() { proc.foo.1 loc_load.0 end + begin exec.foo + movup.5 drop end"; // --- 1 value is pushed & the rest of the stack is unchanged --------------------------------- @@ -37,6 +39,7 @@ fn pop_local() { end begin exec.foo + swapw dropw end"; let test = build_test!(source, &[1, 2, 3, 4]); @@ -86,7 +89,10 @@ fn loadw_local() { #[test] fn storew_local() { // --- test write to local memory ------------------------------------------------------------- - let source = " + let source = format!( + " + {TRUNCATE_STACK_PROC} + proc.foo.2 loc_storew.0 swapw @@ -99,7 +105,10 @@ fn storew_local() { end begin exec.foo - end"; + + exec.truncate_stack + end" + ); let test = build_test!(source, &[1, 2, 3, 4, 5, 6, 7, 8]); test.expect_stack(&[4, 3, 2, 1, 8, 7, 6, 5, 8, 7, 6, 5, 4, 3, 2, 1]); @@ -133,8 +142,10 @@ fn inverse_operations() { loc_store.0 loc_load.0 end + begin exec.foo + movup.5 drop end"; let inputs = [0, 1, 2, 3, 4]; let mut final_stack = inputs; @@ -151,10 +162,12 @@ fn inverse_operations() { push.0.0.0.0 loc_loadw.0 end + begin exec.foo + swapw dropw end"; - let inputs = [0, 1, 2, 3, 4]; + let inputs = [1, 2, 3, 4]; let mut final_stack = inputs; final_stack.reverse(); @@ -167,6 +180,7 @@ fn inverse_operations() { loc_storew.0 loc_loadw.0 end + begin exec.foo end"; @@ -188,6 +202,7 @@ fn read_after_write() { end begin exec.foo + movup.5 drop end"; let test = build_test!(source, &[1, 2, 3, 4]); @@ -202,6 +217,7 @@ fn read_after_write() { end begin exec.foo + swapdw dropw dropw end"; let test = build_test!(source, &[1, 2, 3, 4]); @@ -229,13 +245,16 @@ fn nested_procedures() { proc.foo.1 loc_store.0 end + proc.bar.1 loc_store.0 exec.foo loc_load.0 end + begin exec.bar + movup.3 drop end"; let inputs = [0, 1, 2, 3]; @@ -257,6 +276,7 @@ fn nested_procedures() { end begin exec.bar + swapw dropw end"; let inputs = [0, 1, 2, 3, 4, 5, 6, 7]; @@ -276,6 +296,7 @@ fn nested_procedures() { end begin exec.bar + movup.7 movup.7 drop drop end"; let inputs = [0, 1, 2, 3]; @@ -301,6 +322,8 @@ fn free_memory_pointer() { mem_load.2 mem_load.1 mem_load.0 + + movupw.2 dropw end"; let inputs = [1, 2, 3, 4, 5, 6, 7]; diff --git a/miden/tests/integration/operations/io_ops/mem_ops.rs b/miden/tests/integration/operations/io_ops/mem_ops.rs index dd3610de6a..bb1f1469f1 100644 --- a/miden/tests/integration/operations/io_ops/mem_ops.rs +++ b/miden/tests/integration/operations/io_ops/mem_ops.rs @@ -1,4 +1,4 @@ -use super::{apply_permutation, build_op_test, build_test, Felt, ToElements}; +use super::{apply_permutation, build_op_test, build_test, Felt, ToElements, TRUNCATE_STACK_PROC}; // LOADING SINGLE ELEMENT ONTO THE STACK (MLOAD) // ================================================================================================ @@ -91,7 +91,10 @@ fn mem_storew() { #[test] fn mem_stream() { - let source = " + let source = format!( + " + {TRUNCATE_STACK_PROC} + begin push.1 mem_storew @@ -101,7 +104,10 @@ fn mem_stream() { dropw push.12.11.10.9.8.7.6.5.4.3.2.1 mem_stream - end"; + + exec.truncate_stack + end" + ); let inputs = [1, 2, 3, 4, 5, 6, 7, 8]; @@ -124,7 +130,10 @@ fn mem_stream() { #[test] fn mem_stream_with_hperm() { - let source = " + let source = format!( + " + {TRUNCATE_STACK_PROC} + begin push.1 mem_storew @@ -134,7 +143,10 @@ fn mem_stream_with_hperm() { dropw push.12.11.10.9.8.7.6.5.4.3.2.1 mem_stream hperm - end"; + + exec.truncate_stack + end" + ); let inputs = [1, 2, 3, 4, 5, 6, 7, 8]; @@ -172,6 +184,8 @@ fn inverse_operations() { push.1 mem_load mem_load.0 + + movup.6 movup.6 drop drop end"; let inputs = [0, 1, 2, 3, 4]; diff --git a/miden/tests/integration/operations/io_ops/mod.rs b/miden/tests/integration/operations/io_ops/mod.rs index dc3c3887e5..22f2255e42 100644 --- a/miden/tests/integration/operations/io_ops/mod.rs +++ b/miden/tests/integration/operations/io_ops/mod.rs @@ -1,4 +1,4 @@ -use test_utils::{assert_eq, build_op_test, build_test, Felt, ToElements}; +use test_utils::{assert_eq, build_op_test, build_test, Felt, ToElements, TRUNCATE_STACK_PROC}; use vm_core::chiplets::hasher::apply_permutation; mod adv_ops; @@ -38,6 +38,9 @@ fn mem_stream_pipe() { dropw movup.4 drop + + # truncate the stack + swapdw dropw dropw end"; let advice_stack = [1, 2, 3, 4, 5, 6, 7, 8]; diff --git a/miden/tests/integration/operations/stack_ops.rs b/miden/tests/integration/operations/stack_ops.rs index 60678cf662..89f8174e56 100644 --- a/miden/tests/integration/operations/stack_ops.rs +++ b/miden/tests/integration/operations/stack_ops.rs @@ -1,7 +1,7 @@ use assembly::regex; use test_utils::{ assert_assembler_diagnostic, assert_diagnostic_lines, build_op_test, proptest::prelude::*, - STACK_TOP_SIZE, WORD_SIZE, + MIN_STACK_DEPTH, WORD_SIZE, }; // STACK OPERATIONS TESTS @@ -63,8 +63,9 @@ fn dupn_fail() { test, "invalid immediate: value must be in the range 0..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin dup.16 end", - " : ^^", + "11 |", + "12 | begin dup.16 exec.truncate_stack end", + " : ^^", " `----" ); } @@ -98,8 +99,9 @@ fn dupwn_fail() { test, "invalid immediate: value must be in the range 0..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin dupw.4 end", - " : ^", + "11 |", + "12 | begin dupw.4 exec.truncate_stack end", + " : ^", " `----" ); } @@ -133,8 +135,9 @@ fn swapn_fail() { test, "invalid immediate: value must be in the range 1..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin swap.16 end", - " : ^^", + "11 |", + "12 | begin swap.16 exec.truncate_stack end", + " : ^^", " `----" ); } @@ -168,9 +171,10 @@ fn swapwn_fail() { test, "invalid immediate: value must be in the range 1..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin swapw.4 end", - " : ^", - " `----" + "11 |", + "12 | begin swapw.4 exec.truncate_stack end", + " : ^", + " `----" ); } @@ -200,8 +204,9 @@ fn movup_fail() { test, "invalid immediate: value must be in the range 2..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movup.0 end", - " : ^", + "11 |", + "12 | begin movup.0 exec.truncate_stack end", + " : ^", " `----" ); @@ -212,8 +217,9 @@ fn movup_fail() { test, "invalid immediate: value must be in the range 2..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movup.1 end", - " : ^", + "11 |", + "12 | begin movup.1 exec.truncate_stack end", + " : ^", " `----" ); @@ -224,8 +230,9 @@ fn movup_fail() { test, "invalid immediate: value must be in the range 2..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movup.16 end", - " : ^^", + "11 |", + "12 | begin movup.16 exec.truncate_stack end", + " : ^^", " `----" ); } @@ -247,8 +254,9 @@ fn movupw_fail() { test, "invalid immediate: value must be in the range 2..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movupw.0 end", - " : ^", + "11 |", + "12 | begin movupw.0 exec.truncate_stack end", + " : ^", " `----" ); @@ -259,8 +267,9 @@ fn movupw_fail() { test, "invalid immediate: value must be in the range 2..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movupw.1 end", - " : ^", + "11 |", + "12 | begin movupw.1 exec.truncate_stack end", + " : ^", " `----" ); @@ -271,8 +280,9 @@ fn movupw_fail() { test, "invalid immediate: value must be in the range 2..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movupw.4 end", - " : ^", + "11 |", + "12 | begin movupw.4 exec.truncate_stack end", + " : ^", " `----" ); } @@ -294,8 +304,9 @@ fn movdn_fail() { test, "invalid immediate: value must be in the range 2..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movdn.0 end", - " : ^", + "11 |", + "12 | begin movdn.0 exec.truncate_stack end", + " : ^", " `----" ); @@ -306,8 +317,9 @@ fn movdn_fail() { test, "invalid immediate: value must be in the range 2..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movdn.1 end", - " : ^", + "11 |", + "12 | begin movdn.1 exec.truncate_stack end", + " : ^", " `----" ); @@ -318,8 +330,9 @@ fn movdn_fail() { test, "invalid immediate: value must be in the range 2..16 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movdn.16 end", - " : ^^", + "11 |", + "12 | begin movdn.16 exec.truncate_stack end", + " : ^^", " `----" ); } @@ -341,8 +354,9 @@ fn movdnw_fail() { test, "invalid immediate: value must be in the range 2..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movdnw.0 end", - " : ^", + "11 |", + "12 | begin movdnw.0 exec.truncate_stack end", + " : ^", " `----" ); @@ -353,8 +367,9 @@ fn movdnw_fail() { test, "invalid immediate: value must be in the range 2..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movdnw.1 end", - " : ^", + "11 |", + "12 | begin movdnw.1 exec.truncate_stack end", + " : ^", " `----" ); @@ -365,8 +380,9 @@ fn movdnw_fail() { test, "invalid immediate: value must be in the range 2..4 (exclusive)", regex!(r#",-\[test[\d]+:[\d]+:[\d]+\]"#), - "1 | begin movdnw.4 end", - " : ^", + "11 |", + "12 | begin movdnw.4 exec.truncate_stack end", + " : ^", " `----" ); } @@ -419,27 +435,27 @@ fn cdropw() { proptest! { #[test] - fn drop_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn drop_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "drop"; let mut expected_values = test_values.clone(); - expected_values.remove(STACK_TOP_SIZE - 1); + expected_values.remove(MIN_STACK_DEPTH - 1); expected_values.reverse(); expected_values.push(0); build_op_test!(asm_op, &test_values).prop_expect_stack(&expected_values)?; } #[test] - fn dropw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn dropw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "dropw"; let mut expected_values = test_values.clone(); - expected_values.truncate(STACK_TOP_SIZE - WORD_SIZE); + expected_values.truncate(MIN_STACK_DEPTH - WORD_SIZE); expected_values.reverse(); expected_values.append(&mut vec![0; WORD_SIZE]); build_op_test!(asm_op, &test_values).prop_expect_stack(&expected_values)?; } #[test] - fn padw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn padw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "padw"; let mut expected_values = test_values.clone(); expected_values.drain(0..WORD_SIZE); @@ -449,7 +465,7 @@ proptest! { } #[test] - fn dup_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn dup_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "dup"; let mut expected_values = test_values.clone(); expected_values.remove(0); @@ -459,10 +475,10 @@ proptest! { } #[test] - fn dupn_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), n in 0_usize..STACK_TOP_SIZE) { + fn dupn_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), n in 0_usize..MIN_STACK_DEPTH) { let asm_op = format!("dup.{n}"); let mut expected_values = test_values.clone(); - let dup_idx = STACK_TOP_SIZE - n - 1; + let dup_idx = MIN_STACK_DEPTH - n - 1; let a = expected_values[dup_idx]; expected_values.remove(0); expected_values.push(a); @@ -471,11 +487,11 @@ proptest! { } #[test] - fn dupw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn dupw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "dupw"; let mut expected_values = test_values.clone(); expected_values.drain(0..WORD_SIZE); - let dupw_idx = STACK_TOP_SIZE - WORD_SIZE; + let dupw_idx = MIN_STACK_DEPTH - WORD_SIZE; let mut a = test_values[dupw_idx..].to_vec(); expected_values.append(&mut a); expected_values.reverse(); @@ -483,12 +499,12 @@ proptest! { } #[test] - fn dupwn_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), n in 0_usize..WORD_SIZE) { + fn dupwn_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), n in 0_usize..WORD_SIZE) { let asm_op = format!("dupw.{n}"); let mut expected_values = test_values.clone(); expected_values.drain(0..WORD_SIZE); - let start_dupw_idx = STACK_TOP_SIZE - WORD_SIZE * (n + 1); - let end_dupw_idx = STACK_TOP_SIZE - WORD_SIZE * n; + let start_dupw_idx = MIN_STACK_DEPTH - WORD_SIZE * (n + 1); + let end_dupw_idx = MIN_STACK_DEPTH - WORD_SIZE * n; let mut a = test_values[start_dupw_idx..end_dupw_idx].to_vec(); expected_values.append(&mut a); expected_values.reverse(); @@ -496,7 +512,7 @@ proptest! { } #[test] - fn swap_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn swap_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "swap"; let mut expected_values = test_values.clone(); expected_values.reverse(); @@ -505,7 +521,7 @@ proptest! { } #[test] - fn swapn_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), n in 1_usize..STACK_TOP_SIZE) { + fn swapn_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), n in 1_usize..MIN_STACK_DEPTH) { let asm_op = format!("swap.{n}"); let mut expected_values = test_values.clone(); expected_values.reverse(); @@ -514,7 +530,7 @@ proptest! { } #[test] - fn swapw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn swapw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "swapw"; let mut expected_values = test_values.clone(); let mut a = expected_values.split_off(WORD_SIZE * 3); @@ -526,10 +542,10 @@ proptest! { } #[test] - fn swapwn_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), n in 1_usize..WORD_SIZE) { + fn swapwn_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), n in 1_usize..WORD_SIZE) { let asm_op = format!("swapw.{n}"); let mut expected_values = test_values.clone(); - let start_swapwn_idx = WORD_SIZE * (STACK_TOP_SIZE / WORD_SIZE - n - 1); + let start_swapwn_idx = WORD_SIZE * (MIN_STACK_DEPTH / WORD_SIZE - n - 1); let mut a = expected_values.split_off(start_swapwn_idx); let mut b = a.split_off(WORD_SIZE); let mut c = b.split_off(b.len() - WORD_SIZE); @@ -541,7 +557,7 @@ proptest! { } #[test] - fn swapdw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE)) { + fn swapdw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH)) { let asm_op = "swapdw"; let mut expected_values = test_values[..(WORD_SIZE * 2)].to_vec(); let mut b = test_values[(WORD_SIZE * 2)..].to_vec(); @@ -552,10 +568,10 @@ proptest! { } #[test] - fn movup_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), movup_idx in 2_usize..STACK_TOP_SIZE) { + fn movup_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), movup_idx in 2_usize..MIN_STACK_DEPTH) { let asm_op = format!("movup.{movup_idx}"); let mut expected_values = test_values.clone(); - let idx1 = STACK_TOP_SIZE - movup_idx - 1; + let idx1 = MIN_STACK_DEPTH - movup_idx - 1; let movup_value = expected_values[idx1]; expected_values.remove(idx1); expected_values.push(movup_value); @@ -564,10 +580,10 @@ proptest! { } #[test] - fn movupw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), movupw_idx in 2_usize..WORD_SIZE) { + fn movupw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), movupw_idx in 2_usize..WORD_SIZE) { let asm_op = format!("movupw.{movupw_idx}"); - let start_movupw_idx = STACK_TOP_SIZE - (movupw_idx + 1) * WORD_SIZE; - let end_movupw_idx = STACK_TOP_SIZE - movupw_idx * WORD_SIZE; + let start_movupw_idx = MIN_STACK_DEPTH - (movupw_idx + 1) * WORD_SIZE; + let end_movupw_idx = MIN_STACK_DEPTH - movupw_idx * WORD_SIZE; let mut movupw_values = test_values[start_movupw_idx..end_movupw_idx].to_vec(); let mut expected_values = test_values[..start_movupw_idx].to_vec(); expected_values.append(&mut test_values[end_movupw_idx..].to_vec()); @@ -577,10 +593,10 @@ proptest! { } #[test] - fn movdn_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), movdn_idx in 2_usize..STACK_TOP_SIZE) { + fn movdn_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), movdn_idx in 2_usize..MIN_STACK_DEPTH) { let asm_op = format!("movdn.{movdn_idx}"); let mut expected_values = test_values.clone(); - let idx1 = STACK_TOP_SIZE - 1; + let idx1 = MIN_STACK_DEPTH - 1; let movdn_value = expected_values[idx1]; expected_values.remove(idx1); expected_values.insert(idx1 - movdn_idx, movdn_value); @@ -589,10 +605,10 @@ proptest! { } #[test] - fn movdnw_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), movdnw_idx in 2_usize..WORD_SIZE) { + fn movdnw_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), movdnw_idx in 2_usize..WORD_SIZE) { let asm_op = format!("movdnw.{movdnw_idx}"); - let idx1 = STACK_TOP_SIZE - (movdnw_idx + 1) * WORD_SIZE; - let movdnw_idx = STACK_TOP_SIZE - WORD_SIZE; + let idx1 = MIN_STACK_DEPTH - (movdnw_idx + 1) * WORD_SIZE; + let movdnw_idx = MIN_STACK_DEPTH - WORD_SIZE; let mut movdnw_values = test_values[movdnw_idx..].to_vec(); let mut expected_values = test_values[..idx1].to_vec(); expected_values.append(&mut movdnw_values); @@ -602,7 +618,7 @@ proptest! { } #[test] - fn cswap_proptest(mut test_values in prop::collection::vec(any::(), STACK_TOP_SIZE - 1), c in 0_u64..2) { + fn cswap_proptest(mut test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH - 1), c in 0_u64..2) { let asm_op = "cswap"; test_values.push(c); let mut expected_values = test_values.clone(); @@ -616,7 +632,7 @@ proptest! { } #[test] - fn cswapw_proptest(mut test_values in prop::collection::vec(any::(), STACK_TOP_SIZE - 1), c in 0_u64..2) { + fn cswapw_proptest(mut test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH - 1), c in 0_u64..2) { let asm_op = "cswapw"; let mut a = test_values.clone(); a.reverse(); @@ -634,7 +650,7 @@ proptest! { } #[test] - fn cdrop_proptest(mut test_values in prop::collection::vec(any::(), STACK_TOP_SIZE - 1), c in 0_u64..2) { + fn cdrop_proptest(mut test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH - 1), c in 0_u64..2) { let asm_op = "cdrop"; test_values.push(c); let mut expected_values = test_values.clone(); @@ -651,7 +667,7 @@ proptest! { } #[test] - fn cdropw_proptest(mut test_values in prop::collection::vec(any::(), STACK_TOP_SIZE - 1), c in 0_u64..2) { + fn cdropw_proptest(mut test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH - 1), c in 0_u64..2) { let asm_op = "cdropw"; let mut a = test_values.clone(); a.reverse(); diff --git a/miden/tests/integration/operations/sys_ops.rs b/miden/tests/integration/operations/sys_ops.rs index 366173d0fb..d775a7f01f 100644 --- a/miden/tests/integration/operations/sys_ops.rs +++ b/miden/tests/integration/operations/sys_ops.rs @@ -24,7 +24,7 @@ fn assert_with_code() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 1.into(), + clk: 2.into(), err_code: 123, err_msg: None, } @@ -39,7 +39,7 @@ fn assert_fail() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 1.into(), + clk: 2.into(), err_code: 0, err_msg: None, } @@ -65,7 +65,7 @@ fn assert_eq_fail() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 2.into(), + clk: 3.into(), err_code: 0, err_msg: None, } @@ -75,9 +75,18 @@ fn assert_eq_fail() { expect_exec_error!( test, ExecutionError::FailedAssertion { - clk: 2.into(), + clk: 3.into(), err_code: 0, err_msg: None, } ); } + +// EMITTING EVENTS +// ================================================================================================ + +#[test] +fn emit() { + let test = build_op_test!("emit.42", &[0, 0, 0, 0]); + test.prove_and_verify(vec![], false); +} diff --git a/miden/tests/integration/operations/u32_ops/arithmetic_ops.rs b/miden/tests/integration/operations/u32_ops/arithmetic_ops.rs index c138ab22c9..4c9ba2f5a9 100644 --- a/miden/tests/integration/operations/u32_ops/arithmetic_ops.rs +++ b/miden/tests/integration/operations/u32_ops/arithmetic_ops.rs @@ -3,8 +3,6 @@ use test_utils::{ build_op_test, expect_exec_error, proptest::prelude::*, rand::rand_value, U32_BOUND, }; -use super::test_unchecked_execution; - // U32 OPERATIONS TESTS - MANUAL - ARITHMETIC OPERATIONS // ================================================================================================ @@ -112,9 +110,6 @@ fn u32overflowing_add() { let e = rand_value::(); let test = build_op_test!(asm_op, &[e, a as u64, b as u64]); test.expect_stack(&[d, c as u64, e]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution(asm_op, 2); } #[test] @@ -169,20 +164,6 @@ fn u32overflowing_add3() { let f = rand_value::(); let test = build_op_test!(asm_op, &[f, c as u64, a as u64, b as u64]); test.expect_stack(&[e, d as u64, f]); - - // --- test that out of bounds inputs do not cause a failure ---------------------------------- - - // should not fail if a >= 2^32. - let test = build_op_test!(asm_op, &[0, 0, U32_BOUND]); - assert!(test.execute().is_ok()); - - // should not fail if b >= 2^32. - let test = build_op_test!(asm_op, &[0, U32_BOUND, 0]); - assert!(test.execute().is_ok()); - - // should not fail if c >= 2^32. - let test = build_op_test!(asm_op, &[U32_BOUND, 0, 0]); - assert!(test.execute().is_ok()); } #[test] @@ -283,9 +264,6 @@ fn u32overflowing_sub() { let e = rand_value::(); let test = build_op_test!(asm_op, &[e, a as u64, b as u64]); test.expect_stack(&[d, c as u64, e]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution(asm_op, 2); } #[test] @@ -384,9 +362,6 @@ fn u32overflowing_mul() { let e = rand_value::(); let test = build_op_test!(asm_op, &[e, a as u64, b as u64]); test.expect_stack(&[d, c as u64, e]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution(asm_op, 2); } #[test] @@ -425,9 +400,6 @@ fn u32overflowing_madd() { let f = rand_value::(); let test = build_op_test!(asm_op, &[f, c as u64, a as u64, b as u64]); test.expect_stack(&[e, d, f]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution(asm_op, 3); } #[test] @@ -460,9 +432,6 @@ fn u32div() { let e = rand_value::(); let test = build_op_test!("u32div", &[e, a as u64, b as u64]); test.expect_stack(&[quot, e]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution("u32div", 2); } #[test] @@ -471,7 +440,7 @@ fn u32div_fail() { // should fail if b == 0. let test = build_op_test!(asm_op, &[1, 0]); - expect_exec_error!(test, ExecutionError::DivideByZero(1.into())); + expect_exec_error!(test, ExecutionError::DivideByZero(2.into())); } #[test] @@ -501,9 +470,6 @@ fn u32mod() { let c = rand_value::(); let test = build_op_test!("u32mod", &[c, a as u64, b as u64]); test.expect_stack(&[expected as u64, c]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution("u32mod", 2); } #[test] @@ -512,7 +478,7 @@ fn u32mod_fail() { // should fail if b == 0 let test = build_op_test!(asm_op, &[1, 0]); - expect_exec_error!(test, ExecutionError::DivideByZero(1.into())); + expect_exec_error!(test, ExecutionError::DivideByZero(2.into())); } #[test] @@ -547,9 +513,6 @@ fn u32divmod() { let e = rand_value::(); let test = build_op_test!("u32divmod", &[e, a as u64, b as u64]); test.expect_stack(&[rem, quot, e]); - - // should not fail when inputs are out of bounds. - test_unchecked_execution("u32divmod", 2); } #[test] @@ -558,7 +521,7 @@ fn u32divmod_fail() { // should fail if b == 0. let test = build_op_test!(asm_op, &[1, 0]); - expect_exec_error!(test, ExecutionError::DivideByZero(1.into())); + expect_exec_error!(test, ExecutionError::DivideByZero(2.into())); } // U32 OPERATIONS TESTS - RANDOMIZED - ARITHMETIC OPERATIONS diff --git a/miden/tests/integration/operations/u32_ops/bitwise_ops.rs b/miden/tests/integration/operations/u32_ops/bitwise_ops.rs index 97bec15822..258fa77eb2 100644 --- a/miden/tests/integration/operations/u32_ops/bitwise_ops.rs +++ b/miden/tests/integration/operations/u32_ops/bitwise_ops.rs @@ -317,10 +317,6 @@ fn u32shl() { let test = build_op_test!(asm_op, &[a as u64, b as u64]); test.expect_stack(&[a.wrapping_shl(b) as u64]); - - // --- test out of bounds input (should not fail) -------------------------------------------- - let test = build_op_test!(asm_op, &[U32_BOUND, 1]); - assert!(test.execute().is_ok()); } #[test] @@ -355,11 +351,6 @@ fn u32shl_b() { // let test = build_op_test!(get_asm_op(b).as_str(), &[a as u64]); // test.expect_stack(&[a.wrapping_shl(b) as u64]); - - // --- test out of bounds input (should not fail) --------------------------------------------- - // let b = 1; - // let test = build_op_test!(get_asm_op(b).as_str(), &[U32_BOUND]); - // assert!(test.execute().is_ok()); } #[test] @@ -393,10 +384,6 @@ fn u32shr() { let test = build_op_test!(asm_op, &[a as u64, b as u64]); test.expect_stack(&[a.wrapping_shr(b) as u64]); - - // --- test out of bounds inputs (should not fail) -------------------------------------------- - let test = build_op_test!(asm_op, &[U32_BOUND, 1]); - assert!(test.execute().is_ok()); } #[test] @@ -431,11 +418,6 @@ fn u32shr_b() { let test = build_op_test!(get_asm_op(b).as_str(), &[a as u64]); test.expect_stack(&[a.wrapping_shr(b) as u64]); - - // --- test out of bounds inputs (should not fail) -------------------------------------------- - let b = 1; - let test = build_op_test!(get_asm_op(b).as_str(), &[U32_BOUND]); - assert!(test.execute().is_ok()); } #[test] @@ -480,10 +462,6 @@ fn u32rotl() { let test = build_op_test!(asm_op, &[a as u64, b as u64]); test.expect_stack(&[a.rotate_left(b) as u64]); - - // --- test out of bounds inputs (should not fail) -------------------------------------------- - let test = build_op_test!(asm_op, &[U32_BOUND, 1]); - assert!(test.execute().is_ok()); } #[test] @@ -528,10 +506,6 @@ fn u32rotr() { let test = build_op_test!(asm_op, &[a as u64, b as u64]); test.expect_stack(&[a.rotate_right(b) as u64]); - - // --- test out of bounds inputs (should not fail) -------------------------------------------- - let test = build_op_test!(asm_op, &[U32_BOUND, 1]); - assert!(test.execute().is_ok()); } #[test] diff --git a/miden/tests/integration/operations/u32_ops/comparison_ops.rs b/miden/tests/integration/operations/u32_ops/comparison_ops.rs index 2ec04e3078..18a5c8e192 100644 --- a/miden/tests/integration/operations/u32_ops/comparison_ops.rs +++ b/miden/tests/integration/operations/u32_ops/comparison_ops.rs @@ -2,8 +2,6 @@ use core::cmp::Ordering; use test_utils::{build_op_test, proptest::prelude::*, rand::rand_value}; -use super::test_unchecked_execution; - // U32 OPERATIONS TESTS - MANUAL - COMPARISON OPERATIONS // ================================================================================================ @@ -13,9 +11,6 @@ fn u32lt() { // should push 1 to the stack when a < b and 0 otherwise test_comparison_op(asm_op, 1, 0, 0); - - // should not fail when inputs are out of bounds - test_unchecked_execution(asm_op, 2); } #[test] @@ -24,9 +19,6 @@ fn u32lte() { // should push 1 to the stack when a <= b and 0 otherwise test_comparison_op(asm_op, 1, 1, 0); - - // should not fail when inputs are out of bounds - test_unchecked_execution(asm_op, 2); } #[test] @@ -35,9 +27,6 @@ fn u32gt() { // should push 1 to the stack when a > b and 0 otherwise test_comparison_op(asm_op, 0, 0, 1); - - // should not fail when inputs are out of bounds - test_unchecked_execution(asm_op, 2); } #[test] @@ -46,9 +35,6 @@ fn u32gte() { // should push 1 to the stack when a >= b and 0 otherwise test_comparison_op(asm_op, 0, 1, 1); - - // should not fail when inputs are out of bounds - test_unchecked_execution(asm_op, 2); } #[test] @@ -57,9 +43,6 @@ fn u32min() { // should put the minimum of the 2 inputs on the stack test_min(asm_op); - - // should not fail when inputs are out of bounds - test_unchecked_execution(asm_op, 2); } #[test] @@ -68,9 +51,6 @@ fn u32max() { // should put the maximum of the 2 inputs on the stack test_max(asm_op); - - // should not fail when inputs are out of bounds - test_unchecked_execution(asm_op, 2); } // U32 OPERATIONS TESTS - RANDOMIZED - COMPARISON OPERATIONS diff --git a/miden/tests/integration/operations/u32_ops/mod.rs b/miden/tests/integration/operations/u32_ops/mod.rs index 222fd8de4b..582052ac91 100644 --- a/miden/tests/integration/operations/u32_ops/mod.rs +++ b/miden/tests/integration/operations/u32_ops/mod.rs @@ -30,18 +30,3 @@ pub fn test_inputs_out_of_bounds(asm_op: &str, input_count: usize) { expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(U32_BOUND), ZERO)); } } - -/// This helper function tests that when the given u32 assembly instruction is executed on -/// out-of-bounds inputs it does not fail. Each input is tested independently. -pub fn test_unchecked_execution(asm_op: &str, input_count: usize) { - let values = vec![1_u64; input_count]; - - for i in 0..input_count { - let mut i_values = values.clone(); - // should execute successfully when the value of the input at index i is out of bounds - i_values[i] = U32_BOUND; - - let test = build_op_test!(asm_op, &i_values); - assert!(test.execute().is_ok()); - } -} diff --git a/processor/Cargo.toml b/processor/Cargo.toml index b932d002c3..3f8061c300 100644 --- a/processor/Cargo.toml +++ b/processor/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-processor" -version = "0.10.6" +version = "0.11.0" description = "Miden VM processor" -documentation = "https://docs.rs/miden-processor/0.10.6" +documentation = "https://docs.rs/miden-processor/0.11.0" readme = "README.md" categories = ["emulators", "no-std"] keywords = ["miden", "virtual-machine"] @@ -20,18 +20,18 @@ doctest = false [features] concurrent = ["std", "winter-prover/concurrent"] default = ["std"] -testing = ["miden-air/testing"] std = ["vm-core/std", "winter-prover/std"] +testing = ["miden-air/testing"] [dependencies] -miden-air = { package = "miden-air", path = "../air", version = "0.10", default-features = false } +miden-air = { package = "miden-air", path = "../air", version = "0.11", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"] } -vm-core = { package = "miden-core", path = "../core", version = "0.10", default-features = false } -winter-prover = { package = "winter-prover", version = "0.9", default-features = false } +vm-core = { package = "miden-core", path = "../core", version = "0.11", default-features = false } +winter-prover = { package = "winter-prover", version = "0.10", default-features = false } [dev-dependencies] -assembly = { package = "miden-assembly", path = "../assembly", version = "0.10", default-features = false } +assembly = { package = "miden-assembly", path = "../assembly", version = "0.11", default-features = false } logtest = { version = "2.0", default-features = false } test-utils = { package = "miden-test-utils", path = "../test-utils" } -winter-fri = { package = "winter-fri", version = "0.9" } -winter-utils = { package = "winter-utils", version = "0.9" } +winter-fri = { package = "winter-fri", version = "0.10" } +winter-utils = { package = "winter-utils", version = "0.10" } diff --git a/processor/README.md b/processor/README.md index b0d9935619..837bb5fde0 100644 --- a/processor/README.md +++ b/processor/README.md @@ -63,6 +63,7 @@ Miden processor can be compiled with the following features: * `std` - enabled by default and relies on the Rust standard library. * `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly. + * Only the `wasm32-unknown-unknown` and `wasm32-wasip1` targets are officially supported. To compile with `no_std`, disable default features via `--no-default-features` flag. diff --git a/processor/src/chiplets/aux_trace/mod.rs b/processor/src/chiplets/aux_trace/mod.rs index 0d246c511d..db3e637dc5 100644 --- a/processor/src/chiplets/aux_trace/mod.rs +++ b/processor/src/chiplets/aux_trace/mod.rs @@ -17,10 +17,10 @@ use miden_air::{ RowIndex, }; use vm_core::{ - Word, ONE, OPCODE_CALL, OPCODE_DYN, OPCODE_END, OPCODE_HPERM, OPCODE_JOIN, OPCODE_LOOP, - OPCODE_MLOAD, OPCODE_MLOADW, OPCODE_MPVERIFY, OPCODE_MRUPDATE, OPCODE_MSTORE, OPCODE_MSTOREW, - OPCODE_MSTREAM, OPCODE_RCOMBBASE, OPCODE_RESPAN, OPCODE_SPAN, OPCODE_SPLIT, OPCODE_SYSCALL, - OPCODE_U32AND, OPCODE_U32XOR, ZERO, + Kernel, Word, ONE, OPCODE_CALL, OPCODE_DYN, OPCODE_DYNCALL, OPCODE_END, OPCODE_HPERM, + OPCODE_JOIN, OPCODE_LOOP, OPCODE_MLOAD, OPCODE_MLOADW, OPCODE_MPVERIFY, OPCODE_MRUPDATE, + OPCODE_MSTORE, OPCODE_MSTOREW, OPCODE_MSTREAM, OPCODE_PIPE, OPCODE_RCOMBBASE, OPCODE_RESPAN, + OPCODE_SPAN, OPCODE_SPLIT, OPCODE_SYSCALL, OPCODE_U32AND, OPCODE_U32XOR, ZERO, }; use super::{super::trace::AuxColumnBuilder, Felt, FieldElement}; @@ -34,10 +34,18 @@ const NUM_HEADER_ALPHAS: usize = 4; // ================================================================================================ /// Constructs the execution trace for chiplets-related auxiliary columns (used in multiset checks). -#[derive(Default)] -pub struct AuxTraceBuilder {} +pub struct AuxTraceBuilder { + kernel: Kernel, +} impl AuxTraceBuilder { + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + pub fn new(kernel: Kernel) -> Self { + Self { kernel } + } + // COLUMN TRACE CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -49,10 +57,14 @@ impl AuxTraceBuilder { main_trace: &MainTrace, rand_elements: &[E], ) -> Vec> { - let v_table_col_builder = ChipletsVTableColBuilder::default(); + let v_table_col_builder = ChipletsVTableColBuilder::new(self.kernel.clone()); let bus_col_builder = BusColumnBuilder::default(); let t_chip = v_table_col_builder.build_aux_column(main_trace, rand_elements); let b_chip = bus_col_builder.build_aux_column(main_trace, rand_elements); + + debug_assert_eq!(*t_chip.last().unwrap(), E::ONE); + // TODO: Fix and re-enable after testing with miden-base + // debug_assert_eq!(*b_chip.last().unwrap(), E::ONE); vec![t_chip, b_chip] } } @@ -62,10 +74,30 @@ impl AuxTraceBuilder { /// Describes how to construct the execution trace of the chiplets virtual table auxiliary trace /// column. -#[derive(Default)] -pub struct ChipletsVTableColBuilder {} +pub struct ChipletsVTableColBuilder { + kernel: Kernel, +} + +impl ChipletsVTableColBuilder { + fn new(kernel: Kernel) -> Self { + Self { kernel } + } +} impl> AuxColumnBuilder for ChipletsVTableColBuilder { + fn init_requests(&self, _main_trace: &MainTrace, alphas: &[E]) -> E { + let mut requests = E::ONE; + for (addr, proc_hash) in self.kernel.proc_hashes().iter().enumerate() { + requests *= alphas[0] + + alphas[1].mul_base((addr as u32).into()) + + alphas[2].mul_base(proc_hash[0]) + + alphas[3].mul_base(proc_hash[1]) + + alphas[4].mul_base(proc_hash[2]) + + alphas[5].mul_base(proc_hash[3]); + } + requests + } + fn get_requests_at(&self, main_trace: &MainTrace, alphas: &[E], row: RowIndex) -> E { chiplets_vtable_remove_sibling(main_trace, alphas, row) } @@ -86,14 +118,10 @@ where E: FieldElement, { let f_mu: bool = main_trace.f_mu(row); - let f_mua: bool = if row == 0 { false } else { main_trace.f_mua(row - 1) }; + let f_mua: bool = main_trace.f_mua(row); - if f_mu || f_mua { - let index = if f_mua { - main_trace.chiplet_node_index(row - 1) - } else { - main_trace.chiplet_node_index(row) - }; + if f_mu { + let index = main_trace.chiplet_node_index(row); let lsb = index.as_int() & 1; if lsb == 0 { let sibling = &main_trace.chiplet_hasher_state(row)[DIGEST_RANGE.end..]; @@ -112,6 +140,26 @@ where + alphas[10].mul_base(sibling[2]) + alphas[11].mul_base(sibling[3]) } + } else if f_mua { + let index = main_trace.chiplet_node_index(row); + let lsb = index.as_int() & 1; + if lsb == 0 { + let sibling = &main_trace.chiplet_hasher_state(row + 1)[DIGEST_RANGE.end..]; + alphas[0] + + alphas[3].mul_base(index) + + alphas[12].mul_base(sibling[0]) + + alphas[13].mul_base(sibling[1]) + + alphas[14].mul_base(sibling[2]) + + alphas[15].mul_base(sibling[3]) + } else { + let sibling = &main_trace.chiplet_hasher_state(row + 1)[DIGEST_RANGE]; + alphas[0] + + alphas[3].mul_base(index) + + alphas[8].mul_base(sibling[0]) + + alphas[9].mul_base(sibling[1]) + + alphas[10].mul_base(sibling[2]) + + alphas[11].mul_base(sibling[3]) + } } else { E::ONE } @@ -127,14 +175,10 @@ where E: FieldElement, { let f_mv: bool = main_trace.f_mv(row); - let f_mva: bool = if row == 0 { false } else { main_trace.f_mva(row - 1) }; + let f_mva: bool = main_trace.f_mva(row); - if f_mv || f_mva { - let index = if f_mva { - main_trace.chiplet_node_index(row - 1) - } else { - main_trace.chiplet_node_index(row) - }; + if f_mv { + let index = main_trace.chiplet_node_index(row); let lsb = index.as_int() & 1; if lsb == 0 { let sibling = &main_trace.chiplet_hasher_state(row)[DIGEST_RANGE.end..]; @@ -153,6 +197,26 @@ where + alphas[10].mul_base(sibling[2]) + alphas[11].mul_base(sibling[3]) } + } else if f_mva { + let index = main_trace.chiplet_node_index(row); + let lsb = index.as_int() & 1; + if lsb == 0 { + let sibling = &main_trace.chiplet_hasher_state(row + 1)[DIGEST_RANGE.end..]; + alphas[0] + + alphas[3].mul_base(index) + + alphas[12].mul_base(sibling[0]) + + alphas[13].mul_base(sibling[1]) + + alphas[14].mul_base(sibling[2]) + + alphas[15].mul_base(sibling[3]) + } else { + let sibling = &main_trace.chiplet_hasher_state(row + 1)[DIGEST_RANGE]; + alphas[0] + + alphas[3].mul_base(index) + + alphas[8].mul_base(sibling[0]) + + alphas[9].mul_base(sibling[1]) + + alphas[10].mul_base(sibling[2]) + + alphas[11].mul_base(sibling[3]) + } } else { E::ONE } @@ -169,21 +233,32 @@ where { if main_trace.is_kernel_row(row) { let addr = main_trace.chiplet_kernel_addr(row); - let addr_nxt = main_trace.chiplet_kernel_addr(row + 1); - let addr_delta = addr_nxt - addr; - let root0 = main_trace.chiplet_kernel_root_0(row); - let root1 = main_trace.chiplet_kernel_root_1(row); - let root2 = main_trace.chiplet_kernel_root_2(row); - let root3 = main_trace.chiplet_kernel_root_3(row); - - let v = alphas[0] - + alphas[1].mul_base(addr) - + alphas[2].mul_base(root0) - + alphas[3].mul_base(root1) - + alphas[4].mul_base(root2) - + alphas[5].mul_base(root3); - - v.mul_base(addr_delta) + E::from(ONE - addr_delta) + let addr_delta = { + let addr_nxt = main_trace.chiplet_kernel_addr(row + 1); + addr_nxt - addr + }; + let next_row_is_kernel = main_trace.is_kernel_row(row + 1); + + // We want to add an entry to the table in 2 cases: + // 1. when the next row is a kernel row and the address changes + // - this adds the last row of all rows that share the same address + // 2. when the next row is not a kernel row + // - this is the edge case of (1) + if !next_row_is_kernel || addr_delta == ONE { + let root0 = main_trace.chiplet_kernel_root_0(row); + let root1 = main_trace.chiplet_kernel_root_1(row); + let root2 = main_trace.chiplet_kernel_root_2(row); + let root3 = main_trace.chiplet_kernel_root_3(row); + + alphas[0] + + alphas[1].mul_base(addr) + + alphas[2].mul_base(root0) + + alphas[3].mul_base(root1) + + alphas[4].mul_base(root2) + + alphas[5].mul_base(root3) + } else { + E::ONE + } } else { E::ONE } @@ -206,8 +281,15 @@ impl> AuxColumnBuilder for BusColumnBuilder let op_code = op_code_felt.as_int() as u8; match op_code { - OPCODE_JOIN | OPCODE_SPLIT | OPCODE_LOOP | OPCODE_DYN | OPCODE_CALL => { - build_control_block_request(main_trace, op_code_felt, alphas, row) + OPCODE_JOIN | OPCODE_SPLIT | OPCODE_LOOP | OPCODE_CALL => build_control_block_request( + main_trace, + main_trace.decoder_hasher_state(row), + op_code_felt, + alphas, + row, + ), + OPCODE_DYN | OPCODE_DYNCALL => { + build_dyn_block_request(main_trace, op_code_felt, alphas, row) }, OPCODE_SYSCALL => build_syscall_block_request(main_trace, op_code_felt, alphas, row), OPCODE_SPAN => build_span_block_request(main_trace, alphas, row), @@ -224,6 +306,7 @@ impl> AuxColumnBuilder for BusColumnBuilder OPCODE_HPERM => build_hperm_request(main_trace, alphas, row), OPCODE_MPVERIFY => build_mpverify_request(main_trace, alphas, row), OPCODE_MRUPDATE => build_mrupdate_request(main_trace, alphas, row), + OPCODE_PIPE => build_pipe_request(main_trace, alphas, row), _ => E::ONE, } } @@ -253,21 +336,39 @@ impl> AuxColumnBuilder for BusColumnBuilder /// Builds requests made to the hasher chiplet at the start of a control block. fn build_control_block_request>( main_trace: &MainTrace, + decoder_hasher_state: [Felt; 8], op_code_felt: Felt, alphas: &[E], row: RowIndex, ) -> E { let op_label = LINEAR_HASH_LABEL; let addr_nxt = main_trace.addr(row + 1); - let first_cycle_row = addr_to_row_index(addr_nxt) % HASH_CYCLE_LEN == 0; - let transition_label = if first_cycle_row { op_label + 16 } else { op_label + 32 }; + let transition_label = op_label + 16; let header = alphas[0] + alphas[1].mul_base(Felt::from(transition_label)) + alphas[2].mul_base(addr_nxt); - let state = main_trace.decoder_hasher_state(row); + header + build_value(&alphas[8..16], &decoder_hasher_state) + alphas[5].mul_base(op_code_felt) +} + +/// Builds requests made on a `DYN` or `DYNCALL` operation. +fn build_dyn_block_request>( + main_trace: &MainTrace, + op_code_felt: Felt, + alphas: &[E], + row: RowIndex, +) -> E { + let control_block_req = + build_control_block_request(main_trace, [ZERO; 8], op_code_felt, alphas, row); - header + build_value(&alphas[8..16], &state) + alphas[5].mul_base(op_code_felt) + let memory_req = { + let mem_addr = main_trace.stack_element(0, row); + let mem_value = main_trace.decoder_hasher_state_first_half(row); + + compute_memory_request(main_trace, MEMORY_READ_LABEL, alphas, row, mem_addr, mem_value) + }; + + control_block_req * memory_req } /// Builds requests made to kernel ROM chiplet when initializing a syscall block. @@ -277,7 +378,13 @@ fn build_syscall_block_request>( alphas: &[E], row: RowIndex, ) -> E { - let factor1 = build_control_block_request(main_trace, op_code_felt, alphas, row); + let factor1 = build_control_block_request( + main_trace, + main_trace.decoder_hasher_state(row), + op_code_felt, + alphas, + row, + ); let op_label = KERNEL_PROC_LABEL; let state = main_trace.decoder_hasher_state(row); @@ -299,14 +406,12 @@ fn build_span_block_request>( ) -> E { let op_label = LINEAR_HASH_LABEL; let addr_nxt = main_trace.addr(row + 1); - let first_cycle_row = addr_to_row_index(addr_nxt) % HASH_CYCLE_LEN == 0; - let transition_label = if first_cycle_row { op_label + 16 } else { op_label + 32 }; + let transition_label = op_label + 16; let header = alphas[0] + alphas[1].mul_base(Felt::from(transition_label)) + alphas[2].mul_base(addr_nxt); let state = main_trace.decoder_hasher_state(row); - header + build_value(&alphas[8..16], &state) } @@ -318,19 +423,16 @@ fn build_respan_block_request>( ) -> E { let op_label = LINEAR_HASH_LABEL; let addr_nxt = main_trace.addr(row + 1); - - let first_cycle_row = addr_to_row_index(addr_nxt - ONE) % HASH_CYCLE_LEN == 0; - let transition_label = if first_cycle_row { op_label + 16 } else { op_label + 32 }; + let transition_label = op_label + 32; let header = alphas[0] + alphas[1].mul_base(Felt::from(transition_label)) + alphas[2].mul_base(addr_nxt - ONE) + alphas[3].mul_base(ZERO); - let state = &main_trace.chiplet_hasher_state(row - 2)[CAPACITY_LEN..]; - let state_nxt = &main_trace.chiplet_hasher_state(row - 1)[CAPACITY_LEN..]; + let state = main_trace.decoder_hasher_state(row); - header + build_value(&alphas[8..16], state_nxt) - build_value(&alphas[8..16], state) + header + build_value(&alphas[8..16], &state) } /// Builds requests made to the hasher chiplet at the end of a block. @@ -341,9 +443,7 @@ fn build_end_block_request>( ) -> E { let op_label = RETURN_HASH_LABEL; let addr = main_trace.addr(row) + Felt::from(NUM_ROUNDS as u8); - - let first_cycle_row = addr_to_row_index(addr) % HASH_CYCLE_LEN == 0; - let transition_label = if first_cycle_row { op_label + 16 } else { op_label + 32 }; + let transition_label = op_label + 32; let header = alphas[0] + alphas[1].mul_base(Felt::from(transition_label)) + alphas[2].mul_base(addr); @@ -437,6 +537,33 @@ fn build_mstream_request>( factor1 * factor2 } +/// Builds `PIPE` requests made to the memory chiplet. +fn build_pipe_request>( + main_trace: &MainTrace, + alphas: &[E], + row: RowIndex, +) -> E { + let word1 = [ + main_trace.stack_element(7, row + 1), + main_trace.stack_element(6, row + 1), + main_trace.stack_element(5, row + 1), + main_trace.stack_element(4, row + 1), + ]; + let word2 = [ + main_trace.stack_element(3, row + 1), + main_trace.stack_element(2, row + 1), + main_trace.stack_element(1, row + 1), + main_trace.stack_element(0, row + 1), + ]; + let addr = main_trace.stack_element(12, row); + let op_label = MEMORY_WRITE_LABEL; + + let factor1 = compute_memory_request(main_trace, op_label, alphas, row, addr, word1); + let factor2 = compute_memory_request(main_trace, op_label, alphas, row, addr + ONE, word2); + + factor1 * factor2 +} + /// Builds `RCOMBBASE` requests made to the memory chiplet. fn build_rcomb_base_request>( main_trace: &MainTrace, @@ -499,12 +626,7 @@ fn build_hperm_request>( main_trace.stack_element(11, row + 1), ]; - let op_label = LINEAR_HASH_LABEL; - let op_label = if addr_to_hash_cycle(helper_0) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = LINEAR_HASH_LABEL + 16; let sum_input = alphas[4..16] .iter() @@ -516,12 +638,7 @@ fn build_hperm_request>( + alphas[2].mul_base(helper_0) + sum_input; - let op_label = RETURN_STATE_LABEL; - let op_label = if addr_to_hash_cycle(helper_0 + Felt::new(7)) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = RETURN_STATE_LABEL + 32; let sum_output = alphas[4..16] .iter() @@ -559,12 +676,7 @@ fn build_mpverify_request>( main_trace.stack_element(9, row), ]; - let op_label = MP_VERIFY_LABEL; - let op_label = if addr_to_hash_cycle(helper_0) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = MP_VERIFY_LABEL + 16; let sum_input = alphas[8..12] .iter() @@ -578,12 +690,7 @@ fn build_mpverify_request>( + alphas[3].mul_base(s5) + sum_input; - let op_label = RETURN_HASH_LABEL; - let op_label = if (helper_0).as_int() % 8 == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = RETURN_HASH_LABEL + 32; let sum_output = alphas[8..12] .iter() @@ -633,12 +740,7 @@ fn build_mrupdate_request>( main_trace.stack_element(13, row), ]; - let op_label = MR_UPDATE_OLD_LABEL; - let op_label = if addr_to_hash_cycle(helper_0) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = MR_UPDATE_OLD_LABEL + 16; let sum_input = alphas[8..12] .iter() @@ -651,12 +753,7 @@ fn build_mrupdate_request>( + alphas[3].mul_base(s5) + sum_input; - let op_label = RETURN_HASH_LABEL; - let op_label = if addr_to_hash_cycle(helper_0 + s4.mul_small(8) - ONE) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = RETURN_HASH_LABEL + 32; let sum_output = alphas[8..12] .iter() @@ -668,12 +765,7 @@ fn build_mrupdate_request>( + alphas[2].mul_base(helper_0 + s4.mul_small(8) - ONE) + sum_output; - let op_label = MR_UPDATE_NEW_LABEL; - let op_label = if addr_to_hash_cycle(helper_0 + s4.mul_small(8)) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = MR_UPDATE_NEW_LABEL + 16; let sum_input = alphas[8..12] .iter() .rev() @@ -685,12 +777,7 @@ fn build_mrupdate_request>( + alphas[3].mul_base(s5) + sum_input; - let op_label = RETURN_HASH_LABEL; - let op_label = if addr_to_hash_cycle(helper_0 + s4.mul_small(16) - ONE) == 0 { - op_label + 16 - } else { - op_label + 32 - }; + let op_label = RETURN_HASH_LABEL + 32; let sum_output = alphas[8..12] .iter() @@ -793,13 +880,12 @@ where let state_nxt = main_trace.chiplet_hasher_state(row + 1); - // build the value from the difference of the hasher state's just before and right - // after the absorption of new elements. + // build the value from the hasher state's just right after the absorption of new + // elements. let next_state_value = build_value(&alphas_state[CAPACITY_LEN..], &state_nxt[CAPACITY_LEN..]); - let state_value = build_value(&alphas_state[CAPACITY_LEN..], &state[CAPACITY_LEN..]); - multiplicand = header + next_state_value - state_value; + multiplicand = header + next_state_value; } } multiplicand @@ -898,20 +984,6 @@ fn get_op_label(s0: Felt, s1: Felt, s2: Felt, s3: Felt) -> Felt { s3.mul_small(1 << 3) + s2.mul_small(1 << 2) + s1.mul_small(2) + s0 + ONE } -/// Returns the hash cycle corresponding to the provided Hasher address. -fn addr_to_hash_cycle(addr: Felt) -> usize { - let row = (addr.as_int() - 1) as usize; - let cycle_row = row % HASH_CYCLE_LEN; - debug_assert!(cycle_row == 0 || cycle_row == HASH_CYCLE_LEN - 1, "invalid address for hasher"); - - cycle_row -} - -/// Convenience method to convert from addresses to rows. -fn addr_to_row_index(addr: Felt) -> usize { - (addr.as_int() - 1) as usize -} - /// Computes a memory read or write request at `row` given randomness `alphas`, memory address /// `addr` and value `value`. fn compute_memory_request>( diff --git a/processor/src/chiplets/mod.rs b/processor/src/chiplets/mod.rs index b098a57606..b4a4807247 100644 --- a/processor/src/chiplets/mod.rs +++ b/processor/src/chiplets/mod.rs @@ -74,6 +74,7 @@ mod tests; /// - columns 3-17: unused columns padded with ZERO /// /// The following is a pictorial representation of the chiplet module: +/// ```text /// +---+-------------------------------------------------------+-------------+ /// | 0 | | |-------------| /// | . | Hash chiplet | Hash chiplet |-------------| @@ -111,6 +112,7 @@ mod tests; /// | . | . | . | . |---------------------------------------------------------| /// | 1 | 1 | 1 | 1 |---------------------------------------------------------| /// +---+---+---+---+---------------------------------------------------------+ +/// ``` pub struct Chiplets { /// Current clock cycle of the VM. clk: RowIndex, @@ -391,6 +393,8 @@ impl Chiplets { // make sure that only padding rows will be overwritten by random values assert!(self.trace_len() + num_rand_rows <= trace_len, "target trace length too small"); + let kernel = self.kernel().clone(); + // Allocate columns for the trace of the chiplets. let mut trace = (0..CHIPLETS_WIDTH) .map(|_| vec![Felt::ZERO; trace_len]) @@ -401,7 +405,7 @@ impl Chiplets { ChipletsTrace { trace, - aux_builder: AuxTraceBuilder::default(), + aux_builder: AuxTraceBuilder::new(kernel), } } diff --git a/processor/src/chiplets/tests.rs b/processor/src/chiplets/tests.rs index d5f120113b..a0eab01620 100644 --- a/processor/src/chiplets/tests.rs +++ b/processor/src/chiplets/tests.rs @@ -119,8 +119,9 @@ fn build_trace( let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(operations, None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; process.execute(&program).unwrap(); diff --git a/processor/src/debug.rs b/processor/src/debug.rs index 9d064a8111..c30d53efda 100644 --- a/processor/src/debug.rs +++ b/processor/src/debug.rs @@ -48,8 +48,8 @@ impl fmt::Display for VmState { /// Iterator that iterates through vm state at each step of the execution. /// -/// This allows debugging or replaying ability to view various process state at each clock cycle. -/// If the execution returned an error, it returns that error on the clock cycle it stopped. +/// This allows debugging or replaying ability to view various process state at each clock cycle. If +/// the execution returned an error, it returns that error on the clock cycle it stopped. pub struct VmStateIterator { chiplets: Chiplets, decoder: Decoder, diff --git a/processor/src/decoder/aux_trace/block_hash_table.rs b/processor/src/decoder/aux_trace/block_hash_table.rs index cba054d0bd..75d831b1fb 100644 --- a/processor/src/decoder/aux_trace/block_hash_table.rs +++ b/processor/src/decoder/aux_trace/block_hash_table.rs @@ -1,7 +1,7 @@ use miden_air::RowIndex; use vm_core::{ - Word, OPCODE_DYN, OPCODE_END, OPCODE_HALT, OPCODE_JOIN, OPCODE_LOOP, OPCODE_REPEAT, - OPCODE_SPLIT, ZERO, + Word, OPCODE_CALL, OPCODE_DYN, OPCODE_DYNCALL, OPCODE_END, OPCODE_HALT, OPCODE_JOIN, + OPCODE_LOOP, OPCODE_REPEAT, OPCODE_SPLIT, OPCODE_SYSCALL, ZERO, }; use super::{AuxColumnBuilder, Felt, FieldElement, MainTrace, ONE}; @@ -55,7 +55,9 @@ impl> AuxColumnBuilder for BlockHashTableCo .map(|row| row.collapse(alphas)) .unwrap_or(E::ONE), OPCODE_REPEAT => BlockHashTableRow::from_repeat(main_trace, row).collapse(alphas), - OPCODE_DYN => BlockHashTableRow::from_dyn(main_trace, row).collapse(alphas), + OPCODE_DYN | OPCODE_DYNCALL | OPCODE_CALL | OPCODE_SYSCALL => { + BlockHashTableRow::from_dyn_dyncall_call_syscall(main_trace, row).collapse(alphas) + }, _ => E::ONE, } } @@ -206,21 +208,15 @@ impl BlockHashTableRow { } } - /// Computes the row to add to the block hash table when encountering a `DYN` operation. - pub fn from_dyn(main_trace: &MainTrace, row: RowIndex) -> Self { - let child_block_hash = { - // Note: the child block hash is found on the stack, and hence in reverse order. - let s0 = main_trace.stack_element(0, row); - let s1 = main_trace.stack_element(1, row); - let s2 = main_trace.stack_element(2, row); - let s3 = main_trace.stack_element(3, row); - - [s3, s2, s1, s0] - }; - + /// Computes the row to add to the block hash table when encountering a `DYN`, `DYNCALL`, `CALL` + /// or `SYSCALL` operation. + /// + /// The hash of the child node being called is expected to be in the first half of the decoder + /// hasher state. + pub fn from_dyn_dyncall_call_syscall(main_trace: &MainTrace, row: RowIndex) -> Self { Self { parent_block_id: main_trace.addr(row + 1), - child_block_hash, + child_block_hash: main_trace.decoder_hasher_state_first_half(row), is_first_child: false, is_loop_body: false, } diff --git a/processor/src/decoder/aux_trace/block_stack_table.rs b/processor/src/decoder/aux_trace/block_stack_table.rs index 2516046aa6..e59c668617 100644 --- a/processor/src/decoder/aux_trace/block_stack_table.rs +++ b/processor/src/decoder/aux_trace/block_stack_table.rs @@ -1,7 +1,7 @@ use miden_air::RowIndex; use vm_core::{ - OPCODE_CALL, OPCODE_DYN, OPCODE_END, OPCODE_JOIN, OPCODE_LOOP, OPCODE_RESPAN, OPCODE_SPAN, - OPCODE_SPLIT, OPCODE_SYSCALL, + OPCODE_CALL, OPCODE_DYN, OPCODE_DYNCALL, OPCODE_END, OPCODE_JOIN, OPCODE_LOOP, OPCODE_RESPAN, + OPCODE_SPAN, OPCODE_SPLIT, OPCODE_SYSCALL, }; use super::{AuxColumnBuilder, Felt, FieldElement, MainTrace, ONE, ZERO}; @@ -21,10 +21,8 @@ impl> AuxColumnBuilder for BlockStackColumn let op_code = op_code_felt.as_int() as u8; match op_code { - OPCODE_RESPAN => { - get_block_stack_table_removal_multiplicand(main_trace, i, true, alphas) - }, - OPCODE_END => get_block_stack_table_removal_multiplicand(main_trace, i, false, alphas), + OPCODE_RESPAN => get_block_stack_table_respan_multiplicand(main_trace, i, alphas), + OPCODE_END => get_block_stack_table_end_multiplicand(main_trace, i, alphas), _ => E::ONE, } } @@ -35,8 +33,8 @@ impl> AuxColumnBuilder for BlockStackColumn let op_code = op_code_felt.as_int() as u8; match op_code { - OPCODE_JOIN | OPCODE_SPLIT | OPCODE_SPAN | OPCODE_DYN | OPCODE_LOOP | OPCODE_RESPAN - | OPCODE_CALL | OPCODE_SYSCALL => { + OPCODE_JOIN | OPCODE_SPLIT | OPCODE_SPAN | OPCODE_DYN | OPCODE_DYNCALL + | OPCODE_LOOP | OPCODE_RESPAN | OPCODE_CALL | OPCODE_SYSCALL => { get_block_stack_table_inclusion_multiplicand(main_trace, i, alphas, op_code) }, _ => E::ONE, @@ -47,19 +45,36 @@ impl> AuxColumnBuilder for BlockStackColumn // HELPER FUNCTIONS // ================================================================================================ -/// Computes the multiplicand representing the removal of a row from the block stack table. -fn get_block_stack_table_removal_multiplicand>( +/// Computes the multiplicand representing the removal of a row from the block stack table when +/// encountering a RESPAN operation. +fn get_block_stack_table_respan_multiplicand>( main_trace: &MainTrace, i: RowIndex, - is_respan: bool, alphas: &[E], ) -> E { let block_id = main_trace.addr(i); - let parent_id = if is_respan { - main_trace.decoder_hasher_state_element(1, i + 1) - } else { - main_trace.addr(i + 1) - }; + let parent_id = main_trace.decoder_hasher_state_element(1, i + 1); + let is_loop = ZERO; + + // Note: the last 8 elements are set to ZERO, so we omit them here. + let elements = [ONE, block_id, parent_id, is_loop]; + + let mut table_row = E::ZERO; + for (&alpha, &element) in alphas.iter().zip(elements.iter()) { + table_row += alpha.mul_base(element); + } + table_row +} + +/// Computes the multiplicand representing the removal of a row from the block stack table when +/// encountering an END operation. +fn get_block_stack_table_end_multiplicand>( + main_trace: &MainTrace, + i: RowIndex, + alphas: &[E], +) -> E { + let block_id = main_trace.addr(i); + let parent_id = main_trace.addr(i + 1); let is_loop = main_trace.is_loop_flag(i); let elements = if main_trace.is_call_flag(i) == ONE || main_trace.is_syscall_flag(i) == ONE { @@ -67,7 +82,7 @@ fn get_block_stack_table_removal_multiplicand> let parent_fmp = main_trace.fmp(i + 1); let parent_stack_depth = main_trace.stack_depth(i + 1); let parent_next_overflow_addr = main_trace.parent_overflow_address(i + 1); - let parent_fn_hash = main_trace.fn_hash(i); + let parent_fn_hash = main_trace.fn_hash(i + 1); [ ONE, @@ -81,7 +96,7 @@ fn get_block_stack_table_removal_multiplicand> parent_fn_hash[0], parent_fn_hash[1], parent_fn_hash[2], - parent_fn_hash[0], + parent_fn_hash[3], ] } else { let mut result = [ZERO; 12]; @@ -92,12 +107,11 @@ fn get_block_stack_table_removal_multiplicand> result }; - let mut value = E::ZERO; - + let mut table_row = E::ZERO; for (&alpha, &element) in alphas.iter().zip(elements.iter()) { - value += alpha.mul_base(element); + table_row += alpha.mul_base(element); } - value + table_row } /// Computes the multiplicand representing the inclusion of a new row to the block stack table. @@ -123,7 +137,32 @@ fn get_block_stack_table_inclusion_multiplicand>( ) -> E { let group_count = main_trace.group_count(i); let block_id = main_trace.addr(i); + let group_value = { + let op_code = main_trace.get_op_code(i); - let op_code = main_trace.get_op_code(i); - let tmp = if op_code == Felt::from(OPCODE_PUSH) { - main_trace.stack_element(0, i + 1) - } else { - let h0 = main_trace.decoder_hasher_state_first_half(i + 1)[0]; + if op_code == Felt::from(OPCODE_PUSH) { + main_trace.stack_element(0, i + 1) + } else if op_code == Felt::from(OPCODE_EMIT) { + main_trace.helper_register(0, i) + } else { + let h0 = main_trace.decoder_hasher_state_first_half(i + 1)[0]; - let op_prime = main_trace.get_op_code(i + 1); - h0.mul_small(1 << 7) + op_prime + let op_prime = main_trace.get_op_code(i + 1); + h0.mul_small(1 << 7) + op_prime + } }; + alphas[0] + alphas[1].mul_base(block_id) + alphas[2].mul_base(group_count) - + alphas[3].mul_base(tmp) + + alphas[3].mul_base(group_value) } diff --git a/processor/src/decoder/block_stack.rs b/processor/src/decoder/block_stack.rs index 9341c5916c..ec65e914f9 100644 --- a/processor/src/decoder/block_stack.rs +++ b/processor/src/decoder/block_stack.rs @@ -19,17 +19,20 @@ impl BlockStack { /// Pushes a new code block onto the block stack and returns the address of the block's parent. /// /// The block is identified by its address, and we also need to know what type of a block this - /// is. Additionally, for CALL blocks, execution context info must be provided. Other - /// information (i.e., the block's parent, whether the block is a body of a loop or a first - /// child of a JOIN block) is determined from the information already on the stack. + /// is. Additionally, for CALL, SYSCALL and DYNCALL blocks, execution context info must be + /// provided. Other information (i.e., the block's parent, whether the block is a body of a loop + /// or a first child of a JOIN block) is determined from the information already on the stack. pub fn push( &mut self, addr: Felt, block_type: BlockType, ctx_info: Option, ) -> Felt { - // make sure execution context was provided for CALL and SYSCALL blocks - if block_type == BlockType::Call || block_type == BlockType::SysCall { + // make sure execution context was provided for CALL, SYSCALL and DYNCALL blocks + if block_type == BlockType::Call + || block_type == BlockType::SysCall + || block_type == BlockType::Dyncall + { debug_assert!(ctx_info.is_some(), "no execution context provided for a CALL block"); } else { debug_assert!(ctx_info.is_none(), "execution context provided for a non-CALL block"); @@ -127,10 +130,10 @@ impl BlockInfo { } } - /// Returns ONE if this block is a CALL block; otherwise returns ZERO. + /// Returns ONE if this block is a CALL or DYNCALL block; otherwise returns ZERO. pub const fn is_call(&self) -> Felt { match self.block_type { - BlockType::Call => ONE, + BlockType::Call | BlockType::Dyncall => ONE, _ => ZERO, } } @@ -194,6 +197,7 @@ pub enum BlockType { Loop(bool), // internal value set to false if the loop is never entered Call, Dyn, + Dyncall, SysCall, Span, } diff --git a/processor/src/decoder/mod.rs b/processor/src/decoder/mod.rs index e403a9eb6a..a6248d0efb 100644 --- a/processor/src/decoder/mod.rs +++ b/processor/src/decoder/mod.rs @@ -14,7 +14,7 @@ use vm_core::{ mast::{ BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, MastForest, SplitNode, OP_BATCH_SIZE, }, - stack::STACK_TOP_SIZE, + stack::MIN_STACK_DEPTH, AssemblyOp, }; @@ -249,7 +249,7 @@ where self.system.start_syscall(); self.decoder.start_syscall(callee_hash, addr, ctx_info); } else { - self.system.start_call(callee_hash); + self.system.start_call_or_dyncall(callee_hash); self.decoder.start_call(callee_hash, addr, ctx_info); } @@ -261,7 +261,7 @@ where pub(super) fn end_call_node(&mut self, node: &CallNode) -> Result<(), ExecutionError> { // when a CALL block ends, stack depth must be exactly 16 let stack_depth = self.stack.depth(); - if stack_depth > STACK_TOP_SIZE { + if stack_depth > MIN_STACK_DEPTH { return Err(ExecutionError::InvalidStackDepthOnReturn(stack_depth)); } @@ -292,23 +292,119 @@ where // -------------------------------------------------------------------------------------------- /// Starts decoding of a DYN node. - pub(super) fn start_dyn_node(&mut self, callee_hash: Word) -> Result<(), ExecutionError> { + /// + /// Note: even though we will write the callee hash to h[0..4] for the chiplets bus and block + /// hash table, the issued hash request is still hash([ZERO; 8]). + pub(super) fn start_dyn_node(&mut self, dyn_node: &DynNode) -> Result { + debug_assert!(!dyn_node.is_dyncall()); + + let mem_addr = self.stack.get(0); + // The callee hash is stored in memory, and the address is specified on the top of the + // stack. + let callee_hash = self.read_mem_word(mem_addr)?; + let addr = self.chiplets.hash_control_block( EMPTY_WORD, EMPTY_WORD, - DynNode::DOMAIN, - DynNode.digest(), + dyn_node.domain(), + dyn_node.digest(), ); - self.decoder.start_dyn(callee_hash, addr); - self.execute_op(Operation::Noop) + self.decoder.start_dyn(addr, callee_hash); + + // Pop the memory address off the stack. + self.execute_op(Operation::Drop)?; + + Ok(callee_hash) + } + + /// Starts decoding of a DYNCALL node. + /// + /// Note: even though we will write the callee hash to h[0..4] for the chiplets bus and block + /// hash table, and the stack helper registers to h[4..5], the issued hash request is still + /// hash([ZERO; 8]). + pub(super) fn start_dyncall_node( + &mut self, + dyn_node: &DynNode, + ) -> Result { + debug_assert!(dyn_node.is_dyncall()); + + let mem_addr = self.stack.get(0); + // The callee hash is stored in memory, and the address is specified on the top of the + // stack. + let callee_hash = self.read_mem_word(mem_addr)?; + + // Note: other functions end in "executing a Noop", which + // 1. ensures trace capacity, + // 2. copies the stack over to the next row, + // 3. advances clock. + // + // Dyncall's effect on the trace can't be written in terms of any other operation, and + // therefore can't follow this framework. Hence, we do it "manually". It's probably worth + // refactoring the decoder though to remove this Noop execution pattern. + self.ensure_trace_capacity(); + + let addr = self.chiplets.hash_control_block( + EMPTY_WORD, + EMPTY_WORD, + dyn_node.domain(), + dyn_node.digest(), + ); + + let (stack_depth, next_overflow_addr) = self.stack.shift_left_and_start_context(); + debug_assert!(stack_depth <= u32::MAX as usize, "stack depth too big"); + + let ctx_info = ExecutionContextInfo::new( + self.system.ctx(), + self.system.fn_hash(), + self.system.fmp(), + stack_depth as u32, + next_overflow_addr, + ); + + self.system.start_call_or_dyncall(callee_hash); + self.decoder.start_dyncall(addr, callee_hash, ctx_info); + + self.advance_clock()?; + + Ok(callee_hash) } /// Ends decoding of a DYN node. - pub(super) fn end_dyn_node(&mut self) -> Result<(), ExecutionError> { + pub(super) fn end_dyn_node(&mut self, dyn_node: &DynNode) -> Result<(), ExecutionError> { + // this appends a row with END operation to the decoder trace. when the END operation is + // executed the rest of the VM state does not change + self.decoder.end_control_block(dyn_node.digest().into()); + + self.execute_op(Operation::Noop) + } + + /// Ends decoding of a DYNCALL node. + pub(super) fn end_dyncall_node(&mut self, dyn_node: &DynNode) -> Result<(), ExecutionError> { + // when a DYNCALL block ends, stack depth must be exactly 16 + let stack_depth = self.stack.depth(); + if stack_depth > MIN_STACK_DEPTH { + return Err(ExecutionError::InvalidStackDepthOnReturn(stack_depth)); + } + // this appends a row with END operation to the decoder trace. when the END operation is // executed the rest of the VM state does not change - self.decoder.end_control_block(DynNode.digest().into()); + let ctx_info = self + .decoder + .end_control_block(dyn_node.digest().into()) + .expect("no execution context"); + + // when returning from a function call, restore the context of the system + // registers and the operand stack to what it was prior to the call. + self.system.restore_context( + ctx_info.parent_ctx, + ctx_info.parent_fmp, + ctx_info.parent_fn_hash, + ); + self.stack.restore_context( + ctx_info.parent_stack_depth as usize, + ctx_info.parent_next_overflow_addr, + ); self.execute_op(Operation::Noop) } @@ -532,16 +628,50 @@ impl Decoder { /// Starts decoding of a DYN block. /// + /// Note that even though the hasher decoder columns are populated, the issued hash request is + /// still for [ZERO; 8 | domain=DYN]. This is because a `DYN` node takes its child on the stack, + /// and therefore the child hash cannot be included in the `DYN` node hash computation (see + /// [`vm_core::mast::DynNode`]). The decoder hasher columns are then not needed for the `DYN` + /// node hash computation, and so were used to store the result of the memory read operation for + /// the child hash. + /// /// This pushes a block with ID=addr onto the block stack and appends execution of a DYN /// operation to the trace. - pub fn start_dyn(&mut self, dyn_hash: Word, addr: Felt) { + pub fn start_dyn(&mut self, addr: Felt, callee_hash: Word) { // push DYN block info onto the block stack and append a DYN row to the execution trace let parent_addr = self.block_stack.push(addr, BlockType::Dyn, None); - self.trace.append_block_start(parent_addr, Operation::Dyn, dyn_hash, [ZERO; 4]); + self.trace + .append_block_start(parent_addr, Operation::Dyn, callee_hash, [ZERO; 4]); self.debug_info.append_operation(Operation::Dyn); } + /// Starts decoding of a DYNCALL block. + /// + /// Note that even though the hasher decoder columns are populated, the issued hash request is + /// still for [ZERO; 8 | domain=DYNCALL]. + /// + /// This pushes a block with ID=addr onto the block stack and appends execution of a DYNCALL + /// operation to the trace. The decoder hasher trace columns are populated with the callee hash, + /// as well as the stack helper registers (specifically their state after shifting the stack + /// left). We need to store those in the decoder trace so that the block stack table can access + /// them (since in the next row, we start a new context, and hence the stack registers are reset + /// to their default values). + pub fn start_dyncall(&mut self, addr: Felt, callee_hash: Word, ctx_info: ExecutionContextInfo) { + let parent_stack_depth = ctx_info.parent_stack_depth.into(); + let parent_next_overflow_addr = ctx_info.parent_next_overflow_addr; + + let parent_addr = self.block_stack.push(addr, BlockType::Dyncall, Some(ctx_info)); + self.trace.append_block_start( + parent_addr, + Operation::Dyncall, + callee_hash, + [parent_stack_depth, parent_next_overflow_addr, ZERO, ZERO], + ); + + self.debug_info.append_operation(Operation::Dyncall); + } + /// Ends decoding of a control block (i.e., a non-SPAN block). /// /// This appends an execution of an END operation to the trace. The top block on the block @@ -655,7 +785,10 @@ impl Decoder { /// TODO: it might be better to get the operation information from the decoder trace, rather /// than passing it in as a parameter. pub fn set_user_op_helpers(&mut self, op: Operation, values: &[Felt]) { - debug_assert!(!op.is_control_op(), "op is a control operation"); + debug_assert!( + !op.populates_decoder_hasher_registers(), + "user op helper registers not available for op" + ); self.trace.set_user_op_helpers(values); } diff --git a/processor/src/decoder/tests.rs b/processor/src/decoder/tests.rs index f17f94d1bf..c0a90bdec1 100644 --- a/processor/src/decoder/tests.rs +++ b/processor/src/decoder/tests.rs @@ -53,8 +53,9 @@ fn basic_block_one_group() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -92,23 +93,86 @@ fn basic_block_one_group() { #[test] fn basic_block_small() { let iv = [ONE, TWO]; - let ops = vec![Operation::Push(iv[0]), Operation::Push(iv[1]), Operation::Add]; + let ops = vec![ + Operation::Push(iv[0]), + Operation::Push(iv[1]), + Operation::Add, + Operation::Swap, + Operation::Drop, + ]; + let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); + let program = { + let mut mast_forest = MastForest::new(); + + let basic_block_node = MastNode::Block(basic_block.clone()); + let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); + + Program::new(mast_forest.into(), basic_block_id) + }; + + let (trace, trace_len) = build_trace(&[], &program); + + // --- check block address, op_bits, group count, op_index, and in_span columns --------------- + check_op_decoding(&trace, 0, ZERO, Operation::Span, 4, 0, 0); + check_op_decoding(&trace, 1, INIT_ADDR, Operation::Push(ONE), 3, 0, 1); + check_op_decoding(&trace, 2, INIT_ADDR, Operation::Push(TWO), 2, 1, 1); + check_op_decoding(&trace, 3, INIT_ADDR, Operation::Add, 1, 2, 1); + check_op_decoding(&trace, 4, INIT_ADDR, Operation::Swap, 1, 3, 1); + check_op_decoding(&trace, 5, INIT_ADDR, Operation::Drop, 1, 4, 1); + + // starting new group: NOOP group is inserted by the processor to make sure number of groups + // is a power of two + check_op_decoding(&trace, 6, INIT_ADDR, Operation::Noop, 0, 0, 1); + check_op_decoding(&trace, 7, INIT_ADDR, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 8, ZERO, Operation::Halt, 0, 0, 0); + + // --- check hasher state columns ------------------------------------------------------------- + let program_hash: Word = program.hash().into(); + + check_hasher_state( + &trace, + vec![ + basic_block.op_batches()[0].groups().to_vec(), + vec![build_op_group(&ops[1..])], + vec![build_op_group(&ops[2..])], + vec![build_op_group(&ops[3..])], + vec![build_op_group(&ops[4..])], + vec![], + vec![], + program_hash.to_vec(), // last row should contain program hash + ], + ); + + // HALT opcode and program hash gets propagated to the last row + for i in 8..trace_len { + assert!(contains_op(&trace, i, Operation::Halt)); + assert_eq!(ZERO, trace[OP_BITS_EXTRA_COLS_RANGE.start][i]); + assert_eq!(ONE, trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]); + assert_eq!(program_hash, get_hasher_state1(&trace, i)); + } +} + +#[test] +fn basic_block_small_with_emit() { + let ops = vec![Operation::Push(ONE), Operation::Emit(1), Operation::Add]; let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { let mut mast_forest = MastForest::new(); let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); // --- check block address, op_bits, group count, op_index, and in_span columns --------------- check_op_decoding(&trace, 0, ZERO, Operation::Span, 4, 0, 0); - check_op_decoding(&trace, 1, INIT_ADDR, Operation::Push(iv[0]), 3, 0, 1); - check_op_decoding(&trace, 2, INIT_ADDR, Operation::Push(iv[1]), 2, 1, 1); + check_op_decoding(&trace, 1, INIT_ADDR, Operation::Push(ONE), 3, 0, 1); + check_op_decoding(&trace, 2, INIT_ADDR, Operation::Emit(1), 2, 1, 1); check_op_decoding(&trace, 3, INIT_ADDR, Operation::Add, 1, 2, 1); // starting new group: NOOP group is inserted by the processor to make sure number of groups // is a power of two @@ -123,7 +187,8 @@ fn basic_block_small() { vec![ basic_block.op_batches()[0].groups().to_vec(), vec![build_op_group(&ops[1..])], - vec![build_op_group(&ops[2..])], + // emit(1) + vec![build_op_group(&ops[2..]), ZERO, ONE], vec![], vec![], program_hash.to_vec(), // last row should contain program hash @@ -155,6 +220,8 @@ fn basic_block() { Operation::Mul, Operation::Add, Operation::Inv, + Operation::Swap, + Operation::Drop, ]; let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { @@ -162,8 +229,9 @@ fn basic_block() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -184,10 +252,13 @@ fn basic_block() { check_op_decoding(&trace, 11, INIT_ADDR, Operation::Mul, 1, 1, 1); check_op_decoding(&trace, 12, INIT_ADDR, Operation::Add, 1, 2, 1); check_op_decoding(&trace, 13, INIT_ADDR, Operation::Inv, 1, 3, 1); + check_op_decoding(&trace, 14, INIT_ADDR, Operation::Swap, 1, 4, 1); + check_op_decoding(&trace, 15, INIT_ADDR, Operation::Drop, 1, 5, 1); + // NOOP inserted by the processor to make sure the number of groups is a power of two - check_op_decoding(&trace, 14, INIT_ADDR, Operation::Noop, 0, 0, 1); - check_op_decoding(&trace, 15, INIT_ADDR, Operation::End, 0, 0, 0); - check_op_decoding(&trace, 16, ZERO, Operation::Halt, 0, 0, 0); + check_op_decoding(&trace, 16, INIT_ADDR, Operation::Noop, 0, 0, 1); + check_op_decoding(&trace, 17, INIT_ADDR, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 18, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- let program_hash: Word = program.hash().into(); @@ -207,6 +278,8 @@ fn basic_block() { vec![build_op_group(&ops[9..])], // next group starts vec![build_op_group(&ops[10..])], vec![build_op_group(&ops[11..])], + vec![build_op_group(&ops[12..])], + vec![build_op_group(&ops[13..])], vec![], vec![], // a group with single NOOP added at the end program_hash.to_vec(), // last row should contain program hash @@ -214,7 +287,7 @@ fn basic_block() { ); // HALT opcode and program hash gets propagated to the last row - for i in 17..trace_len { + for i in 18..trace_len { assert!(contains_op(&trace, i, Operation::Halt)); assert_eq!(ZERO, trace[OP_BITS_EXTRA_COLS_RANGE.start][i]); assert_eq!(ONE, trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]); @@ -247,6 +320,15 @@ fn span_block_with_respan() { Operation::Push(iv[7]), Operation::Add, Operation::Push(iv[8]), + Operation::SwapDW, + Operation::Drop, + Operation::Drop, + Operation::Drop, + Operation::Drop, + Operation::Drop, + Operation::Drop, + Operation::Drop, + Operation::Drop, ]; let basic_block = BasicBlockNode::new(ops.clone(), None).unwrap(); let program = { @@ -254,8 +336,9 @@ fn span_block_with_respan() { let basic_block_node = MastNode::Block(basic_block.clone()); let basic_block_id = mast_forest.add_node(basic_block_node).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -276,15 +359,23 @@ fn span_block_with_respan() { check_op_decoding(&trace, 10, batch1_addr, Operation::Push(iv[7]), 3, 0, 1); check_op_decoding(&trace, 11, batch1_addr, Operation::Add, 2, 1, 1); check_op_decoding(&trace, 12, batch1_addr, Operation::Push(iv[8]), 2, 2, 1); - // NOOP inserted by the processor to make sure the group doesn't end with a PUSH - check_op_decoding(&trace, 13, batch1_addr, Operation::Noop, 1, 3, 1); - // NOOP inserted by the processor to make sure the number of groups is a power of two - check_op_decoding(&trace, 14, batch1_addr, Operation::Noop, 0, 0, 1); - check_op_decoding(&trace, 15, batch1_addr, Operation::End, 0, 0, 0); - check_op_decoding(&trace, 16, ZERO, Operation::Halt, 0, 0, 0); + + check_op_decoding(&trace, 13, batch1_addr, Operation::SwapDW, 1, 3, 1); + check_op_decoding(&trace, 14, batch1_addr, Operation::Drop, 1, 4, 1); + check_op_decoding(&trace, 15, batch1_addr, Operation::Drop, 1, 5, 1); + check_op_decoding(&trace, 16, batch1_addr, Operation::Drop, 1, 6, 1); + check_op_decoding(&trace, 17, batch1_addr, Operation::Drop, 1, 7, 1); + check_op_decoding(&trace, 18, batch1_addr, Operation::Drop, 1, 8, 1); + check_op_decoding(&trace, 19, batch1_addr, Operation::Drop, 0, 0, 1); + check_op_decoding(&trace, 20, batch1_addr, Operation::Drop, 0, 1, 1); + check_op_decoding(&trace, 21, batch1_addr, Operation::Drop, 0, 2, 1); + + check_op_decoding(&trace, 22, batch1_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 23, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- let program_hash: Word = program.hash().into(); + check_hasher_state( &trace, vec![ @@ -298,17 +389,24 @@ fn span_block_with_respan() { vec![], vec![], // a NOOP inserted after last PUSH basic_block.op_batches()[1].groups().to_vec(), - vec![build_op_group(&ops[8..])], // next group starts - vec![build_op_group(&ops[9..])], + vec![build_op_group(&ops[8..16])], // next group starts + vec![build_op_group(&ops[9..16])], + vec![build_op_group(&ops[10..16])], + vec![build_op_group(&ops[11..16])], + vec![build_op_group(&ops[12..16])], + vec![build_op_group(&ops[13..16])], + vec![build_op_group(&ops[14..16])], + vec![build_op_group(&ops[15..16])], + vec![], + vec![build_op_group(&ops[17..])], + vec![build_op_group(&ops[18..])], vec![], - vec![], // a NOOP is inserted after last PUSH - vec![], // a group with single NOOP added at the end program_hash.to_vec(), // last row should contain program hash ], ); // HALT opcode and program hash gets propagated to the last row - for i in 17..trace_len { + for i in 23..trace_len { assert!(contains_op(&trace, i, Operation::Halt)); assert_eq!(ZERO, trace[OP_BITS_EXTRA_COLS_RANGE.start][i]); assert_eq!(ONE, trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]); @@ -330,8 +428,9 @@ fn join_node() { let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let join_node_id = mast_forest.add_join(basic_block1_id, basic_block2_id).unwrap(); + mast_forest.make_root(join_node_id); - Program::new(mast_forest, join_node_id) + Program::new(mast_forest.into(), join_node_id) }; let (trace, trace_len) = build_trace(&[], &program); @@ -395,8 +494,9 @@ fn split_node_true() { let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap(); + mast_forest.make_root(split_node_id); - Program::new(mast_forest, split_node_id) + Program::new(mast_forest.into(), split_node_id) }; let (trace, trace_len) = build_trace(&[1], &program); @@ -447,8 +547,9 @@ fn split_node_false() { let basic_block2_id = mast_forest.add_node(basic_block2.clone()).unwrap(); let split_node_id = mast_forest.add_split(basic_block1_id, basic_block2_id).unwrap(); + mast_forest.make_root(split_node_id); - Program::new(mast_forest, split_node_id) + Program::new(mast_forest.into(), split_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -498,10 +599,10 @@ fn loop_node() { let mut mast_forest = MastForest::new(); let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1], &program); @@ -550,10 +651,10 @@ fn loop_node_skip() { let mut mast_forest = MastForest::new(); let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0], &program); @@ -592,10 +693,10 @@ fn loop_node_repeat() { let mut mast_forest = MastForest::new(); let loop_body_id = mast_forest.add_node(loop_body.clone()).unwrap(); - let loop_node_id = mast_forest.add_loop(loop_body_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let (trace, trace_len) = build_trace(&[0, 1, 1], &program); @@ -670,6 +771,8 @@ fn call_block() { // fmp <- fmp + 2 // call.foo // stack[0] <- fmp + // swap + // drop // end let mut mast_forest = MastForest::new(); @@ -686,7 +789,7 @@ fn call_block() { ], None).unwrap(); let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap(); - let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd], None).unwrap(); + let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd, Operation::Swap, Operation::Drop], None).unwrap(); let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); let foo_call_node = MastNode::new_call(foo_root_node_id, &mast_forest).unwrap(); @@ -696,8 +799,9 @@ fn call_block() { let join1_node_id = mast_forest.add_node(join1_node.clone()).unwrap(); let program_root_id = mast_forest.add_join(join1_node_id, last_basic_block_id).unwrap(); + mast_forest.make_root(program_root_id); - let program = Program::new(mast_forest, program_root_id); + let program = Program::new(mast_forest.into(), program_root_id); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, Kernel::default()); @@ -731,10 +835,13 @@ fn call_block() { let last_basic_block_addr = foo_root_addr + EIGHT; check_op_decoding(&dec_trace, 14, INIT_ADDR, Operation::Span, 1, 0, 0); check_op_decoding(&dec_trace, 15, last_basic_block_addr, Operation::FmpAdd, 0, 0, 1); - check_op_decoding(&dec_trace, 16, last_basic_block_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 16, last_basic_block_addr, Operation::Swap, 0, 1, 1); + check_op_decoding(&dec_trace, 17, last_basic_block_addr, Operation::Drop, 0, 2, 1); + + check_op_decoding(&dec_trace, 18, last_basic_block_addr, Operation::End, 0, 0, 0); // ending the program - check_op_decoding(&dec_trace, 17, INIT_ADDR, Operation::End, 0, 0, 0); - check_op_decoding(&dec_trace, 18, ZERO, Operation::Halt, 0, 0, 0); + check_op_decoding(&dec_trace, 19, INIT_ADDR, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 20, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of (join1, span3) @@ -772,16 +879,16 @@ fn call_block() { assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 13)); // span3 ends in the 14th row - assert_eq!(last_basic_block_hash, get_hasher_state1(&dec_trace, 16)); - assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 16)); + assert_eq!(last_basic_block_hash, get_hasher_state1(&dec_trace, 18)); + assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 18)); - // the program ends in the 17th row + // the program ends in the 19th row let program_hash: Word = program.hash().into(); - assert_eq!(program_hash, get_hasher_state1(&dec_trace, 17)); - assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 17)); + assert_eq!(program_hash, get_hasher_state1(&dec_trace, 19)); + assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 19)); // HALT opcode and program hash gets propagated to the last row - for i in 18..trace_len { + for i in 20..trace_len { assert!(contains_op(&dec_trace, i, Operation::Halt)); assert_eq!(ZERO, dec_trace[OP_BITS_EXTRA_COLS_RANGE.start][i]); assert_eq!(ONE, dec_trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]); @@ -882,6 +989,8 @@ fn syscall_block() { // fmp <- fmp + 1 // syscall.bar // stack[0] <- fmp + // swap + // drop // end let mut mast_forest = MastForest::new(); @@ -911,7 +1020,7 @@ fn syscall_block() { ], None).unwrap(); let first_basic_block_id = mast_forest.add_node(first_basic_block.clone()).unwrap(); - let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd], None).unwrap(); + let last_basic_block = MastNode::new_basic_block(vec![Operation::FmpAdd, Operation::Swap, Operation::Drop], None).unwrap(); let last_basic_block_id = mast_forest.add_node(last_basic_block.clone()).unwrap(); let bar_call_node = MastNode::new_call(bar_root_node_id, &mast_forest).unwrap(); @@ -922,8 +1031,9 @@ fn syscall_block() { let program_root_node = MastNode::new_join(inner_join_node_id, last_basic_block_id, &mast_forest).unwrap(); let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); + mast_forest.make_root(program_root_node_id); - let program = Program::with_kernel(mast_forest, program_root_node_id, kernel.clone()); + let program = Program::with_kernel(mast_forest.into(), program_root_node_id, kernel.clone()); let (sys_trace, dec_trace, trace_len) = build_call_trace(&program, kernel); @@ -977,11 +1087,13 @@ fn syscall_block() { let last_basic_block_addr = syscall_basic_block_addr + EIGHT; check_op_decoding(&dec_trace, 22, INIT_ADDR, Operation::Span, 1, 0, 0); check_op_decoding(&dec_trace, 23, last_basic_block_addr, Operation::FmpAdd, 0, 0, 1); - check_op_decoding(&dec_trace, 24, last_basic_block_addr, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 24, last_basic_block_addr, Operation::Swap, 0, 1, 1); + check_op_decoding(&dec_trace, 25, last_basic_block_addr, Operation::Drop, 0, 2, 1); + check_op_decoding(&dec_trace, 26, last_basic_block_addr, Operation::End, 0, 0, 0); // ending the program - check_op_decoding(&dec_trace, 25, INIT_ADDR, Operation::End, 0, 0, 0); - check_op_decoding(&dec_trace, 26, ZERO, Operation::Halt, 0, 0, 0); + check_op_decoding(&dec_trace, 27, INIT_ADDR, Operation::End, 0, 0, 0); + check_op_decoding(&dec_trace, 28, ZERO, Operation::Halt, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- // in the first row, the hasher state is set to hashes of (inner_join, last_span) @@ -1042,17 +1154,17 @@ fn syscall_block() { assert_eq!(inner_join_hash, get_hasher_state1(&dec_trace, 21)); assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 21)); - // last span ends in the 24th row - assert_eq!(last_span_hash, get_hasher_state1(&dec_trace, 24)); - assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 24)); + // last span ends in the 26th row + assert_eq!(last_span_hash, get_hasher_state1(&dec_trace, 26)); + assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 26)); - // the program ends in the 25th row + // the program ends in the 27th row let program_hash: Word = program_root_node.digest().into(); - assert_eq!(program_hash, get_hasher_state1(&dec_trace, 25)); - assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 25)); + assert_eq!(program_hash, get_hasher_state1(&dec_trace, 27)); + assert_eq!(EMPTY_WORD, get_hasher_state2(&dec_trace, 27)); // HALT opcode and program hash gets propagated to the last row - for i in 26..trace_len { + for i in 28..trace_len { assert!(contains_op(&dec_trace, i, Operation::Halt)); assert_eq!(ZERO, dec_trace[OP_BITS_EXTRA_COLS_RANGE.start][i]); assert_eq!(ONE, dec_trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]); @@ -1170,8 +1282,21 @@ fn syscall_block() { // ================================================================================================ #[test] fn dyn_block() { - // build a dynamic block which looks like this: - // push.1 add + // Equivalent masm: + // + // proc.foo + // push.1 add + // end + // + // begin + // # stack: [42, DIGEST] + // mstorew + // push.42 + // dynexec + // end + + const FOO_ROOT_NODE_ADDR: u64 = 42; + const PUSH_42_OP: Operation = Operation::Push(Felt::new(FOO_ROOT_NODE_ADDR)); let mut mast_forest = MastForest::new(); @@ -1180,13 +1305,13 @@ fn dyn_block() { let foo_root_node_id = mast_forest.add_node(foo_root_node.clone()).unwrap(); mast_forest.make_root(foo_root_node_id); - let mul_bb_node = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); - let mul_bb_node_id = mast_forest.add_node(mul_bb_node.clone()).unwrap(); + let mstorew_node = MastNode::new_basic_block(vec![Operation::MStoreW], None).unwrap(); + let mstorew_node_id = mast_forest.add_node(mstorew_node.clone()).unwrap(); - let save_bb_node = MastNode::new_basic_block(vec![Operation::MovDn4], None).unwrap(); - let save_bb_node_id = mast_forest.add_node(save_bb_node.clone()).unwrap(); + let push_node = MastNode::new_basic_block(vec![PUSH_42_OP], None).unwrap(); + let push_node_id = mast_forest.add_node(push_node.clone()).unwrap(); - let join_node = MastNode::new_join(mul_bb_node_id, save_bb_node_id, &mast_forest).unwrap(); + let join_node = MastNode::new_join(mstorew_node_id, push_node_id, &mast_forest).unwrap(); let join_node_id = mast_forest.add_node(join_node.clone()).unwrap(); // This dyn will point to foo. @@ -1195,8 +1320,9 @@ fn dyn_block() { let program_root_node = MastNode::new_join(join_node_id, dyn_node_id, &mast_forest).unwrap(); let program_root_node_id = mast_forest.add_node(program_root_node.clone()).unwrap(); + mast_forest.make_root(program_root_node_id); - let program = Program::new(mast_forest, program_root_node_id); + let program = Program::new(mast_forest.into(), program_root_node_id); let (trace, trace_len) = build_dyn_trace( &[ @@ -1204,8 +1330,7 @@ fn dyn_block() { foo_root_node.digest()[1].as_int(), foo_root_node.digest()[2].as_int(), foo_root_node.digest()[3].as_int(), - 2, - 4, + FOO_ROOT_NODE_ADDR, ], &program, ); @@ -1216,30 +1341,31 @@ fn dyn_block() { let join_addr = INIT_ADDR + EIGHT; check_op_decoding(&trace, 1, INIT_ADDR, Operation::Join, 0, 0, 0); // starting first span - let mul_basic_block_addr = join_addr + EIGHT; + let mstorew_basic_block_addr = join_addr + EIGHT; check_op_decoding(&trace, 2, join_addr, Operation::Span, 1, 0, 0); - check_op_decoding(&trace, 3, mul_basic_block_addr, Operation::Mul, 0, 0, 1); - check_op_decoding(&trace, 4, mul_basic_block_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 3, mstorew_basic_block_addr, Operation::MStoreW, 0, 0, 1); + check_op_decoding(&trace, 4, mstorew_basic_block_addr, Operation::End, 0, 0, 0); // starting second span - let save_basic_block_addr = mul_basic_block_addr + EIGHT; - check_op_decoding(&trace, 5, join_addr, Operation::Span, 1, 0, 0); - check_op_decoding(&trace, 6, save_basic_block_addr, Operation::MovDn4, 0, 0, 1); - check_op_decoding(&trace, 7, save_basic_block_addr, Operation::End, 0, 0, 0); + let push_basic_block_addr = mstorew_basic_block_addr + EIGHT; + check_op_decoding(&trace, 5, join_addr, Operation::Span, 2, 0, 0); + check_op_decoding(&trace, 6, push_basic_block_addr, PUSH_42_OP, 1, 0, 1); + check_op_decoding(&trace, 7, push_basic_block_addr, Operation::Noop, 0, 1, 1); + check_op_decoding(&trace, 8, push_basic_block_addr, Operation::End, 0, 0, 0); // end inner join - check_op_decoding(&trace, 8, join_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 9, join_addr, Operation::End, 0, 0, 0); // dyn - check_op_decoding(&trace, 9, INIT_ADDR, Operation::Dyn, 0, 0, 0); + check_op_decoding(&trace, 10, INIT_ADDR, Operation::Dyn, 0, 0, 0); // starting foo span - let dyn_addr = save_basic_block_addr + EIGHT; + let dyn_addr = push_basic_block_addr + EIGHT; let add_basic_block_addr = dyn_addr + EIGHT; - check_op_decoding(&trace, 10, dyn_addr, Operation::Span, 2, 0, 0); - check_op_decoding(&trace, 11, add_basic_block_addr, Operation::Push(ONE), 1, 0, 1); - check_op_decoding(&trace, 12, add_basic_block_addr, Operation::Add, 0, 1, 1); - check_op_decoding(&trace, 13, add_basic_block_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 11, dyn_addr, Operation::Span, 2, 0, 0); + check_op_decoding(&trace, 12, add_basic_block_addr, Operation::Push(ONE), 1, 0, 1); + check_op_decoding(&trace, 13, add_basic_block_addr, Operation::Add, 0, 1, 1); + check_op_decoding(&trace, 14, add_basic_block_addr, Operation::End, 0, 0, 0); // end dyn - check_op_decoding(&trace, 14, dyn_addr, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 15, dyn_addr, Operation::End, 0, 0, 0); // end outer join - check_op_decoding(&trace, 15, INIT_ADDR, Operation::End, 0, 0, 0); + check_op_decoding(&trace, 16, INIT_ADDR, Operation::End, 0, 0, 0); // --- check hasher state columns ------------------------------------------------------------- @@ -1250,8 +1376,8 @@ fn dyn_block() { assert_eq!(dyn_hash, get_hasher_state2(&trace, 0)); // in the second row, the hasher set is set to hashes of both child nodes of the inner JOIN - let mul_bb_node_hash: Word = mul_bb_node.digest().into(); - let save_bb_node_hash: Word = save_bb_node.digest().into(); + let mul_bb_node_hash: Word = mstorew_node.digest().into(); + let save_bb_node_hash: Word = push_node.digest().into(); assert_eq!(mul_bb_node_hash, get_hasher_state1(&trace, 1)); assert_eq!(save_bb_node_hash, get_hasher_state2(&trace, 1)); @@ -1260,32 +1386,31 @@ fn dyn_block() { assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 4)); // at the end of the second SPAN, the hasher state is set to the hash of the second child - assert_eq!(save_bb_node_hash, get_hasher_state1(&trace, 7)); - assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 7)); + assert_eq!(save_bb_node_hash, get_hasher_state1(&trace, 8)); + assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 8)); // at the end of the inner JOIN, the hasher set is set to the hash of the JOIN - assert_eq!(join_hash, get_hasher_state1(&trace, 8)); - assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 8)); + assert_eq!(join_hash, get_hasher_state1(&trace, 9)); + assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 9)); - // at the start of the DYN block, the hasher state is set to the hash of its child (foo span) + // at the start of the DYN block, the hasher state is set to foo digest let foo_hash: Word = foo_root_node.digest().into(); - assert_eq!(foo_hash, get_hasher_state1(&trace, 9)); - assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 9)); + assert_eq!(foo_hash, get_hasher_state1(&trace, 10)); // at the end of the DYN SPAN, the hasher state is set to the hash of the foo span - assert_eq!(foo_hash, get_hasher_state1(&trace, 13)); - assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 13)); + assert_eq!(foo_hash, get_hasher_state1(&trace, 14)); + assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 14)); // at the end of the DYN block, the hasher state is set to the hash of the DYN node - assert_eq!(dyn_hash, get_hasher_state1(&trace, 14)); + assert_eq!(dyn_hash, get_hasher_state1(&trace, 15)); // at the end of the program, the hasher state is set to the hash of the entire program let program_hash: Word = program_root_node.digest().into(); - assert_eq!(program_hash, get_hasher_state1(&trace, 15)); - assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 15)); + assert_eq!(program_hash, get_hasher_state1(&trace, 16)); + assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 16)); // the HALT opcode and program hash get propagated to the last row - for i in 16..trace_len { + for i in 17..trace_len { assert!(contains_op(&trace, i, Operation::Halt)); assert_eq!(ZERO, trace[OP_BITS_EXTRA_COLS_RANGE.start][i]); assert_eq!(ONE, trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]); @@ -1302,8 +1427,9 @@ fn set_user_op_helpers_many() { let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(vec![Operation::U32div], None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let a = rand_value::(); let b = rand_value::(); diff --git a/processor/src/decoder/trace.rs b/processor/src/decoder/trace.rs index a815395375..47cf9ba15c 100644 --- a/processor/src/decoder/trace.rs +++ b/processor/src/decoder/trace.rs @@ -85,16 +85,18 @@ impl DecoderTrace { // -------------------------------------------------------------------------------------------- /// Appends a trace row marking the start of a flow control block (JOIN, SPLIT, LOOP, CALL, - /// SYSCALL). + /// SYSCALL, DYN, DYNCALL). /// /// When a control block is starting, we do the following: /// - Set the address to the address of the parent block. This is not necessarily equal to the /// address from the previous row because in a SPLIT block, the second child follows the first /// child, rather than the parent. - /// - Set op_bits to opcode of the specified block (e.g., JOIN, SPLIT, LOOP, CALL, SYSCALL). + /// - Set op_bits to opcode of the specified block (e.g., JOIN, SPLIT, LOOP, CALL, SYSCALL, DYN, + /// DYNCALL). /// - Set the first half of the hasher state to the h1 parameter. For JOIN and SPLIT blocks this /// will contain the hash of the left child; for LOOP block this will contain hash of the - /// loop's body, for CALL and SYSCALL block this will contain hash of the called function. + /// loop's body, for CALL, SYSCALL, DYN and DYNCALL blocks this will contain hash of the + /// called function. /// - Set the second half of the hasher state to the h2 parameter. For JOIN and SPLIT blocks /// this will contain hash of the right child. /// - Set is_span to ZERO. @@ -299,6 +301,8 @@ impl DecoderTrace { self.hasher_trace[0].push(group_ops_left); self.hasher_trace[1].push(parent_addr); + // Note: use `Decoder::set_user_op_helpers()` when processing an instruction to set any of + // these values to something other than 0 for idx in USER_OP_HELPERS { self.hasher_trace[idx].push(ZERO); } diff --git a/processor/src/errors.rs b/processor/src/errors.rs index 53eeba9a7a..230a37a1c9 100644 --- a/processor/src/errors.rs +++ b/processor/src/errors.rs @@ -4,7 +4,11 @@ use core::fmt::{Display, Formatter}; use std::error::Error; use miden_air::RowIndex; -use vm_core::{mast::MastNodeId, stack::STACK_TOP_SIZE, utils::to_hex}; +use vm_core::{ + mast::{DecoratorId, MastNodeId}, + stack::MIN_STACK_DEPTH, + utils::to_hex, +}; use winter_prover::{math::FieldElement, ProverError}; use super::{ @@ -23,6 +27,9 @@ pub enum ExecutionError { CallerNotInSyscall, CircularExternalNode(Digest), CycleLimitExceeded(u32), + DecoratorNotFoundInForest { + decorator_id: DecoratorId, + }, DivideByZero(RowIndex), DynamicNodeNotFound(Digest), EventError(String), @@ -72,6 +79,7 @@ pub enum ExecutionError { MerkleStoreUpdateFailed(MerkleError), NotBinaryValue(Felt), NotU32Value(Felt, Felt), + OutputStackOverflow(usize), ProgramAlreadyExecuted, ProverError(ProverError), SmtNodeNotFound(Word), @@ -98,6 +106,9 @@ impl Display for ExecutionError { CycleLimitExceeded(max_cycles) => { write!(f, "Exceeded the allowed number of cycles (max cycles = {max_cycles})") }, + DecoratorNotFoundInForest { decorator_id } => { + write!(f, "Malformed MAST forest, decorator id {decorator_id} doesn't exist") + }, DivideByZero(clk) => write!(f, "Division by zero at clock cycle {clk}"), DynamicNodeNotFound(digest) => { let hex = to_hex(digest.as_bytes()); @@ -134,7 +145,7 @@ impl Display for ExecutionError { write!(f, "Memory range start address cannot exceed end address, but was ({start_addr}, {end_addr})") }, InvalidStackDepthOnReturn(depth) => { - write!(f, "When returning from a call, stack depth must be {STACK_TOP_SIZE}, but was {depth}") + write!(f, "When returning from a call, stack depth must be {MIN_STACK_DEPTH}, but was {depth}") }, InvalidStackWordOffset(offset) => { write!(f, "Stack word offset cannot exceed 12, but was {offset}") @@ -190,6 +201,9 @@ impl Display for ExecutionError { "An operation expected a u32 value, but received {v} (error code: {err_code})" ) }, + OutputStackOverflow(n) => { + write!(f, "The stack should have at most {MIN_STACK_DEPTH} elements at the end of program execution, but had {} elements", MIN_STACK_DEPTH + n) + }, SmtNodeNotFound(node) => { let node_hex = to_hex(Felt::elements_as_bytes(node)); write!(f, "Smt node {node_hex} not found") diff --git a/processor/src/host/advice/mod.rs b/processor/src/host/advice/mod.rs index 6dacc08b45..4eba93fe7c 100644 --- a/processor/src/host/advice/mod.rs +++ b/processor/src/host/advice/mod.rs @@ -718,7 +718,7 @@ pub trait AdviceProvider: Sized { R: Borrow; } -impl<'a, T> AdviceProvider for &'a mut T +impl AdviceProvider for &mut T where T: AdviceProvider, { diff --git a/processor/src/host/mast_forest_store.rs b/processor/src/host/mast_forest_store.rs index f6d05b0253..a2fd0296bb 100644 --- a/processor/src/host/mast_forest_store.rs +++ b/processor/src/host/mast_forest_store.rs @@ -23,9 +23,7 @@ pub struct MemMastForestStore { impl MemMastForestStore { /// Inserts all the procedures of the provided MAST forest in the store. - pub fn insert(&mut self, mast_forest: MastForest) { - let mast_forest = Arc::new(mast_forest); - + pub fn insert(&mut self, mast_forest: Arc) { // only register the procedures which are local to this forest for proc_digest in mast_forest.local_procedure_digests() { self.mast_forests.insert(proc_digest, mast_forest.clone()); diff --git a/processor/src/host/mod.rs b/processor/src/host/mod.rs index 0fbd6ef6b0..4e4289b331 100644 --- a/processor/src/host/mod.rs +++ b/processor/src/host/mod.rs @@ -176,7 +176,7 @@ pub trait Host { } } -impl<'a, H> Host for &'a mut H +impl Host for &mut H where H: Host, { @@ -316,7 +316,7 @@ where } } - pub fn load_mast_forest(&mut self, mast_forest: MastForest) { + pub fn load_mast_forest(&mut self, mast_forest: Arc) { self.store.insert(mast_forest) } diff --git a/processor/src/lib.rs b/processor/src/lib.rs index 19c8fc4e41..037f9269b9 100644 --- a/processor/src/lib.rs +++ b/processor/src/lib.rs @@ -24,8 +24,10 @@ pub use vm_core::{ StackInputs, StackOutputs, Word, EMPTY_WORD, ONE, ZERO, }; use vm_core::{ - mast::{BasicBlockNode, CallNode, JoinNode, LoopNode, OpBatch, SplitNode, OP_GROUP_SIZE}, - Decorator, DecoratorIterator, FieldElement, StackTopState, + mast::{ + BasicBlockNode, CallNode, DynNode, JoinNode, LoopNode, OpBatch, SplitNode, OP_GROUP_SIZE, + }, + Decorator, DecoratorIterator, FieldElement, }; pub use winter_prover::matrix::ColMatrix; @@ -103,7 +105,6 @@ pub struct DecoderTrace { pub struct StackTrace { trace: [Vec; STACK_TRACE_WIDTH], - aux_builder: stack::AuxTraceBuilder, } pub struct RangeCheckTrace { @@ -252,9 +253,9 @@ where return Err(ExecutionError::ProgramAlreadyExecuted); } - self.execute_mast_node(program.entrypoint(), program.mast_forest())?; + self.execute_mast_node(program.entrypoint(), &program.mast_forest().clone())?; - Ok(self.stack.build_stack_outputs()) + self.stack.build_stack_outputs() } // NODE EXECUTORS @@ -269,13 +270,17 @@ where .get_node_by_id(node_id) .ok_or(ExecutionError::MastNodeNotFoundInForest { node_id })?; + for &decorator_id in node.before_enter() { + self.execute_decorator(&program[decorator_id])?; + } + match node { - MastNode::Block(node) => self.execute_basic_block_node(node), - MastNode::Join(node) => self.execute_join_node(node, program), - MastNode::Split(node) => self.execute_split_node(node, program), - MastNode::Loop(node) => self.execute_loop_node(node, program), - MastNode::Call(node) => self.execute_call_node(node, program), - MastNode::Dyn => self.execute_dyn_node(program), + MastNode::Block(node) => self.execute_basic_block_node(node, program)?, + MastNode::Join(node) => self.execute_join_node(node, program)?, + MastNode::Split(node) => self.execute_split_node(node, program)?, + MastNode::Loop(node) => self.execute_loop_node(node, program)?, + MastNode::Call(node) => self.execute_call_node(node, program)?, + MastNode::Dyn(node) => self.execute_dyn_node(node, program)?, MastNode::External(external_node) => { let node_digest = external_node.digest(); let mast_forest = self @@ -296,9 +301,15 @@ where return Err(ExecutionError::CircularExternalNode(node_digest)); } - self.execute_mast_node(root_id, &mast_forest) + self.execute_mast_node(root_id, &mast_forest)?; }, } + + for &decorator_id in node.after_exit() { + self.execute_decorator(&program[decorator_id])?; + } + + Ok(()) } /// Executes the specified [JoinNode]. @@ -399,10 +410,16 @@ where /// The MAST root of the callee is assumed to be at the top of the stack, and the callee is /// expected to be either in the current `program` or in the host. #[inline(always)] - fn execute_dyn_node(&mut self, program: &MastForest) -> Result<(), ExecutionError> { - // get target hash from the stack - let callee_hash = self.stack.get_word(0); - self.start_dyn_node(callee_hash)?; + fn execute_dyn_node( + &mut self, + node: &DynNode, + program: &MastForest, + ) -> Result<(), ExecutionError> { + let callee_hash = if node.is_dyncall() { + self.start_dyncall_node(node)? + } else { + self.start_dyn_node(node)? + }; // if the callee is not in the program's MAST forest, try to find a MAST forest for it in // the host (corresponding to an external library loaded in the host); if none are @@ -426,7 +443,11 @@ where }, } - self.end_dyn_node() + if node.is_dyncall() { + self.end_dyncall_node(node) + } else { + self.end_dyn_node(node) + } } /// Executes the specified [BasicBlockNode]. @@ -434,14 +455,20 @@ where fn execute_basic_block_node( &mut self, basic_block: &BasicBlockNode, + program: &MastForest, ) -> Result<(), ExecutionError> { self.start_basic_block_node(basic_block)?; let mut op_offset = 0; - let mut decorators = basic_block.decorator_iter(); + let mut decorator_ids = basic_block.decorator_iter(); // execute the first operation batch - self.execute_op_batch(&basic_block.op_batches()[0], &mut decorators, op_offset)?; + self.execute_op_batch( + &basic_block.op_batches()[0], + &mut decorator_ids, + op_offset, + program, + )?; op_offset += basic_block.op_batches()[0].ops().len(); // if the span contains more operation batches, execute them. each additional batch is @@ -450,7 +477,7 @@ where for op_batch in basic_block.op_batches().iter().skip(1) { self.respan(op_batch); self.execute_op(Operation::Noop)?; - self.execute_op_batch(op_batch, &mut decorators, op_offset)?; + self.execute_op_batch(op_batch, &mut decorator_ids, op_offset, program)?; op_offset += op_batch.ops().len(); } @@ -460,7 +487,10 @@ where // can happen for decorators appearing after all operations in a block. these decorators // are executed after SPAN block is closed to make sure the VM clock cycle advances beyond // the last clock cycle of the SPAN block ops. - for decorator in decorators { + for &decorator_id in decorator_ids { + let decorator = program + .get_decorator_by_id(decorator_id) + .ok_or(ExecutionError::DecoratorNotFoundInForest { decorator_id })?; self.execute_decorator(decorator)?; } @@ -479,6 +509,7 @@ where batch: &OpBatch, decorators: &mut DecoratorIterator, op_offset: usize, + program: &MastForest, ) -> Result<(), ExecutionError> { let op_counts = batch.op_counts(); let mut op_idx = 0; @@ -492,7 +523,10 @@ where // execute operations in the batch one by one for (i, &op) in batch.ops().iter().enumerate() { - while let Some(decorator) = decorators.next_filtered(i + op_offset) { + while let Some(&decorator_id) = decorators.next_filtered(i + op_offset) { + let decorator = program + .get_decorator_by_id(decorator_id) + .ok_or(ExecutionError::DecoratorNotFoundInForest { decorator_id })?; self.execute_decorator(decorator)?; } @@ -561,16 +595,15 @@ where self.host.borrow_mut().set_advice(self, *injector)?; }, Decorator::Debug(options) => { - self.host.borrow_mut().on_debug(self, options)?; + if self.decoder.in_debug_mode() { + self.host.borrow_mut().on_debug(self, options)?; + } }, Decorator::AsmOp(assembly_op) => { if self.decoder.in_debug_mode() { self.decoder.append_asmop(self.system.clk(), assembly_op.clone()); } }, - Decorator::Event(id) => { - self.host.borrow_mut().on_event(self, *id)?; - }, Decorator::Trace(id) => { if self.enable_tracing { self.host.borrow_mut().on_trace(self, *id)?; diff --git a/processor/src/operations/comb_ops.rs b/processor/src/operations/comb_ops.rs index 699106e024..e0b68af198 100644 --- a/processor/src/operations/comb_ops.rs +++ b/processor/src/operations/comb_ops.rs @@ -173,7 +173,7 @@ where mod tests { use alloc::{borrow::ToOwned, vec::Vec}; - use test_utils::{build_test, rand::rand_array}; + use test_utils::{build_test, rand::rand_array, TRUNCATE_STACK_PROC}; use vm_core::{Felt, FieldElement, Operation, StackInputs, ONE, ZERO}; use crate::{ContextId, Process, QuadFelt}; @@ -272,46 +272,53 @@ mod tests { #[test] fn prove_verify() { - let source = " begin - # I) Prepare memory and stack - - # 1) Load T_i(x) for i=0,..,7 - push.0 padw - adv_pipe - - # 2) Load [T_i(z), T_i(gz)] for i=0,..,7 - repeat.4 - adv_pipe - end - - # 3) Load [a0, a1, 0, 0] for i=0,..,7 - repeat.4 - adv_pipe - end - - # 4) Clean up stack - dropw dropw dropw drop - - # 5) Prepare stack - - ## a) Push pointers - push.10 # a_ptr - push.2 # z_ptr - push.0 # x_ptr - - ## b) Push accumulators - padw - - ## c) Add padding for mem_stream - padw padw - - # II) Execute `rcomb_base` op - mem_stream - repeat.8 - rcomb_base - end - end - "; + let source = format!( + " + {TRUNCATE_STACK_PROC} + + begin + # I) Prepare memory and stack + + # 1) Load T_i(x) for i=0,..,7 + push.0 padw + adv_pipe + + # 2) Load [T_i(z), T_i(gz)] for i=0,..,7 + repeat.4 + adv_pipe + end + + # 3) Load [a0, a1, 0, 0] for i=0,..,7 + repeat.4 + adv_pipe + end + + # 4) Clean up stack + dropw dropw dropw drop + + # 5) Prepare stack + + ## a) Push pointers + push.10 # a_ptr + push.2 # z_ptr + push.0 # x_ptr + + ## b) Push accumulators + padw + + ## c) Add padding for mem_stream + padw padw + + # II) Execute `rcomb_base` op + mem_stream + repeat.8 + rcomb_base + end + + exec.truncate_stack + end + " + ); // generate the data let tx: [Felt; 8] = rand_array(); diff --git a/processor/src/operations/ext2_ops.rs b/processor/src/operations/ext2_ops.rs index 73a767357e..e10d80f4bd 100644 --- a/processor/src/operations/ext2_ops.rs +++ b/processor/src/operations/ext2_ops.rs @@ -37,7 +37,7 @@ mod tests { use vm_core::QuadExtension; use super::{ - super::{Felt, Operation, STACK_TOP_SIZE}, + super::{Felt, Operation, MIN_STACK_DEPTH}, Process, }; use crate::{StackInputs, ZERO}; @@ -60,7 +60,7 @@ mod tests { let c = (b * a).to_base_elements(); let expected = build_expected(&[b1, b0, c[1], c[0]]); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); diff --git a/processor/src/operations/field_ops.rs b/processor/src/operations/field_ops.rs index 5d221c8632..2fa799ed5e 100644 --- a/processor/src/operations/field_ops.rs +++ b/processor/src/operations/field_ops.rs @@ -227,7 +227,7 @@ mod tests { use vm_core::{ONE, ZERO}; use super::{ - super::{Felt, FieldElement, Operation, STACK_TOP_SIZE}, + super::{Felt, FieldElement, Operation, MIN_STACK_DEPTH}, Process, }; use crate::{AdviceInputs, StackInputs}; @@ -246,7 +246,7 @@ mod tests { process.execute_op(Operation::Add).unwrap(); let expected = build_expected(&[a + b, c]); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); @@ -267,7 +267,7 @@ mod tests { let expected = build_expected(&[-a, b, c]); assert_eq!(expected, process.stack.trace_state()); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); } @@ -282,7 +282,7 @@ mod tests { process.execute_op(Operation::Mul).unwrap(); let expected = build_expected(&[a * b, c]); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); @@ -303,7 +303,7 @@ mod tests { process.execute_op(Operation::Inv).unwrap(); let expected = build_expected(&[a.inv(), b, c]); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); } @@ -324,7 +324,7 @@ mod tests { process.execute_op(Operation::Incr).unwrap(); let expected = build_expected(&[a + ONE, b, c]); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); } diff --git a/processor/src/operations/fri_ops.rs b/processor/src/operations/fri_ops.rs index 5756d116d5..01189c2d23 100644 --- a/processor/src/operations/fri_ops.rs +++ b/processor/src/operations/fri_ops.rs @@ -327,8 +327,11 @@ mod tests { ]; // --- execute FRIE2F4 operation -------------------------------------- - let stack_inputs = StackInputs::new(inputs.to_vec()).expect("inputs lenght too long"); + // construct the stack from the first 16 elements and push the 17th using the `push` op + let stack_inputs = + StackInputs::new(inputs[0..16].to_vec()).expect("inputs lenght too long"); let mut process = Process::new_dummy_with_decoder_helpers(stack_inputs); + process.execute_op(Operation::Push(inputs[16])).unwrap(); process.execute_op(Operation::FriE2F4).unwrap(); // --- check the stack state------------------------------------------- diff --git a/processor/src/operations/io_ops.rs b/processor/src/operations/io_ops.rs index 83f2c8872a..d6765bb2ac 100644 --- a/processor/src/operations/io_ops.rs +++ b/processor/src/operations/io_ops.rs @@ -35,12 +35,11 @@ where /// Thus, the net result of the operation is that the stack is shifted left by one item. pub(super) fn op_mloadw(&mut self) -> Result<(), ExecutionError> { // get the address from the stack and read the word from current memory context - let ctx = self.system.ctx(); - let addr = Self::get_valid_address(self.stack.get(0))?; - let word = self.chiplets.read_mem(ctx, addr); + let mut word = self.read_mem_word(self.stack.get(0))?; + word.reverse(); - // reverse the order of the memory word & update the stack state - for (i, &value) in word.iter().rev().enumerate() { + // update the stack state + for (i, &value) in word.iter().enumerate() { self.stack.set(i, value); } self.stack.shift_left(5); @@ -62,10 +61,7 @@ where /// register 0. pub(super) fn op_mload(&mut self) -> Result<(), ExecutionError> { // get the address from the stack and read the word from memory - let ctx = self.system.ctx(); - let addr = Self::get_valid_address(self.stack.get(0))?; - let mut word = self.chiplets.read_mem(ctx, addr); - // put the retrieved word into stack order + let mut word = self.read_mem_word(self.stack.get(0))?; word.reverse(); // update the stack state @@ -252,6 +248,15 @@ where // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- + /// Returns the memory word at address `addr` in the current context. + pub(crate) fn read_mem_word(&mut self, addr: Felt) -> Result { + let ctx = self.system.ctx(); + let mem_addr = Self::get_valid_address(addr)?; + let word_at_addr = self.chiplets.read_mem(ctx, mem_addr); + + Ok(word_at_addr) + } + /// Checks that provided address is less than u32::MAX and returns it cast to u32. /// /// # Errors @@ -273,7 +278,7 @@ mod tests { use vm_core::{utils::ToElements, Word, ONE, ZERO}; use super::{ - super::{super::AdviceProvider, Operation, STACK_TOP_SIZE}, + super::{super::AdviceProvider, Operation, MIN_STACK_DEPTH}, Felt, Host, Process, }; use crate::{AdviceSource, ContextId}; @@ -281,7 +286,7 @@ mod tests { #[test] fn op_push() { let mut process = Process::new_dummy_with_empty_stack(); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); assert_eq!(1, process.stack.current_clk()); assert_eq!([ZERO; 16], process.stack.trace_state()); @@ -291,7 +296,7 @@ mod tests { let mut expected = [ZERO; 16]; expected[0] = ONE; - assert_eq!(STACK_TOP_SIZE + 1, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 1, process.stack.depth()); assert_eq!(2, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); @@ -302,7 +307,7 @@ mod tests { expected[0] = Felt::new(3); expected[1] = ONE; - assert_eq!(STACK_TOP_SIZE + 2, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 2, process.stack.depth()); assert_eq!(3, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); } diff --git a/processor/src/operations/mod.rs b/processor/src/operations/mod.rs index a3b7af5880..58216b7e00 100644 --- a/processor/src/operations/mod.rs +++ b/processor/src/operations/mod.rs @@ -1,4 +1,4 @@ -use vm_core::stack::STACK_TOP_SIZE; +use vm_core::stack::MIN_STACK_DEPTH; use super::{ExecutionError, Felt, FieldElement, Host, Operation, Process}; @@ -41,6 +41,7 @@ where Operation::Caller => self.op_caller()?, Operation::Clk => self.op_clk()?, + Operation::Emit(event_id) => self.op_emit(event_id)?, // ----- flow control operations ------------------------------------------------------ // control flow operations are never executed directly @@ -50,6 +51,7 @@ where Operation::Call => unreachable!("control flow operation"), Operation::SysCall => unreachable!("control flow operation"), Operation::Dyn => unreachable!("control flow operation"), + Operation::Dyncall => unreachable!("control flow operation"), Operation::Span => unreachable!("control flow operation"), Operation::Repeat => unreachable!("control flow operation"), Operation::Respan => unreachable!("control flow operation"), @@ -159,7 +161,7 @@ where } /// Increments the clock cycle for all components of the process. - fn advance_clock(&mut self) -> Result<(), ExecutionError> { + pub(super) fn advance_clock(&mut self) -> Result<(), ExecutionError> { self.system.advance_clock(self.max_cycles)?; self.stack.advance_clock(); self.chiplets.advance_clock(); @@ -167,7 +169,7 @@ where } /// Makes sure there is enough memory allocated for the trace to accommodate a new clock cycle. - fn ensure_trace_capacity(&mut self) { + pub(super) fn ensure_trace_capacity(&mut self) { self.system.ensure_trace_capacity(); self.stack.ensure_trace_capacity(); } diff --git a/processor/src/operations/stack_ops.rs b/processor/src/operations/stack_ops.rs index e16990d4f1..401294272f 100644 --- a/processor/src/operations/stack_ops.rs +++ b/processor/src/operations/stack_ops.rs @@ -1,4 +1,4 @@ -use super::{ExecutionError, Host, Process, STACK_TOP_SIZE}; +use super::{ExecutionError, Host, Process, MIN_STACK_DEPTH}; use crate::ZERO; impl Process @@ -185,7 +185,7 @@ where /// /// Elements between 0 and n are shifted right by one slot. pub(super) fn op_movup(&mut self, n: usize) -> Result<(), ExecutionError> { - debug_assert!(n < STACK_TOP_SIZE - 1, "n too large"); + debug_assert!(n < MIN_STACK_DEPTH - 1, "n too large"); // move the nth value to the top of the stack let value = self.stack.get(n); @@ -206,7 +206,7 @@ where /// /// Elements between 0 and n are shifted left by one slot. pub(super) fn op_movdn(&mut self, n: usize) -> Result<(), ExecutionError> { - debug_assert!(n < STACK_TOP_SIZE - 1, "n too large"); + debug_assert!(n < MIN_STACK_DEPTH - 1, "n too large"); // move the value at the top of the stack to the nth position let value = self.stack.get(0); @@ -304,7 +304,7 @@ where mod tests { use super::{ super::{Operation, Process}, - STACK_TOP_SIZE, + MIN_STACK_DEPTH, }; use crate::{Felt, StackInputs, ONE, ZERO}; @@ -322,7 +322,7 @@ mod tests { process.execute_op(Operation::Pad).unwrap(); let expected = build_expected(&[0, 1]); - assert_eq!(STACK_TOP_SIZE + 2, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 2, process.stack.depth()); assert_eq!(3, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); @@ -330,7 +330,7 @@ mod tests { process.execute_op(Operation::Pad).unwrap(); let expected = build_expected(&[0, 0, 1]); - assert_eq!(STACK_TOP_SIZE + 3, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 3, process.stack.depth()); assert_eq!(4, process.stack.current_clk()); assert_eq!(expected, process.stack.trace_state()); } @@ -347,13 +347,13 @@ mod tests { process.execute_op(Operation::Drop).unwrap(); let expected = build_expected(&[1]); assert_eq!(expected, process.stack.trace_state()); - assert_eq!(STACK_TOP_SIZE + 1, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 1, process.stack.depth()); // drop the next value process.execute_op(Operation::Drop).unwrap(); let expected = build_expected(&[]); assert_eq!(expected, process.stack.trace_state()); - assert_eq!(STACK_TOP_SIZE, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH, process.stack.depth()); // calling drop with a minimum stack depth should be ok assert!(process.execute_op(Operation::Drop).is_ok()); @@ -404,7 +404,7 @@ mod tests { process.execute_op(Operation::Drop).unwrap(); process.execute_op(Operation::Drop).unwrap(); - assert_eq!(STACK_TOP_SIZE + 15, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 15, process.stack.depth()); assert_eq!(&expected[2..], &process.stack.trace_state()[..14]); assert_eq!(ONE, process.stack.trace_state()[14]); @@ -464,17 +464,12 @@ mod tests { fn op_swapw3() { // push a few items onto the stack let stack = - StackInputs::try_from_ints([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]) + StackInputs::try_from_ints([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) .unwrap(); let mut process = Process::new_dummy(stack); process.execute_op(Operation::SwapW3).unwrap(); - let expected = build_expected(&[5, 4, 3, 2, 13, 12, 11, 10, 9, 8, 7, 6, 17, 16, 15, 14]); - assert_eq!(expected, process.stack.trace_state()); - - // value should remain on the overflow table - process.execute_op(Operation::Drop).unwrap(); - let expected = build_expected(&[4, 3, 2, 13, 12, 11, 10, 9, 8, 7, 6, 17, 16, 15, 14, 1]); + let expected = build_expected(&[4, 3, 2, 1, 12, 11, 10, 9, 8, 7, 6, 5, 16, 15, 14, 13]); assert_eq!(expected, process.stack.trace_state()); // swapping with a minimum stack should be ok diff --git a/processor/src/operations/sys_ops.rs b/processor/src/operations/sys_ops.rs index 046c86f748..ea211020fc 100644 --- a/processor/src/operations/sys_ops.rs +++ b/processor/src/operations/sys_ops.rs @@ -1,3 +1,5 @@ +use vm_core::Operation; + use super::{ super::{ system::{FMP_MAX, FMP_MIN}, @@ -108,6 +110,18 @@ where self.stack.shift_right(0); Ok(()) } + + // EVENTS + // -------------------------------------------------------------------------------------------- + + /// Forwards the emitted event id to the host. + pub(super) fn op_emit(&mut self, event_id: u32) -> Result<(), ExecutionError> { + self.stack.copy_state(0); + self.decoder.set_user_op_helpers(Operation::Emit(event_id), &[event_id.into()]); + self.host.borrow_mut().on_event(self, event_id)?; + + Ok(()) + } } // TESTS @@ -116,7 +130,7 @@ where #[cfg(test)] mod tests { use super::{ - super::{Operation, STACK_TOP_SIZE}, + super::{Operation, MIN_STACK_DEPTH}, Felt, Process, FMP_MAX, FMP_MIN, }; use crate::{StackInputs, ONE, ZERO}; @@ -212,27 +226,27 @@ mod tests { // stack is empty let mut process = Process::new_dummy_with_empty_stack(); process.execute_op(Operation::SDepth).unwrap(); - let expected = build_expected_stack(&[STACK_TOP_SIZE as u64]); + let expected = build_expected_stack(&[MIN_STACK_DEPTH as u64]); assert_eq!(expected, process.stack.trace_state()); - assert_eq!(STACK_TOP_SIZE + 1, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 1, process.stack.depth()); // stack has one item process.execute_op(Operation::SDepth).unwrap(); - let expected = build_expected_stack(&[STACK_TOP_SIZE as u64 + 1, STACK_TOP_SIZE as u64]); + let expected = build_expected_stack(&[MIN_STACK_DEPTH as u64 + 1, MIN_STACK_DEPTH as u64]); assert_eq!(expected, process.stack.trace_state()); - assert_eq!(STACK_TOP_SIZE + 2, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 2, process.stack.depth()); // stack has 3 items process.execute_op(Operation::Pad).unwrap(); process.execute_op(Operation::SDepth).unwrap(); let expected = build_expected_stack(&[ - STACK_TOP_SIZE as u64 + 3, + MIN_STACK_DEPTH as u64 + 3, 0, - STACK_TOP_SIZE as u64 + 1, - STACK_TOP_SIZE as u64, + MIN_STACK_DEPTH as u64 + 1, + MIN_STACK_DEPTH as u64, ]); assert_eq!(expected, process.stack.trace_state()); - assert_eq!(STACK_TOP_SIZE + 4, process.stack.depth()); + assert_eq!(MIN_STACK_DEPTH + 4, process.stack.depth()); } #[test] diff --git a/processor/src/operations/u32_ops.rs b/processor/src/operations/u32_ops.rs index c5213bdceb..1800029045 100644 --- a/processor/src/operations/u32_ops.rs +++ b/processor/src/operations/u32_ops.rs @@ -4,6 +4,22 @@ use super::{ }; use crate::ZERO; +const U32_MAX: u64 = u32::MAX as u64; + +macro_rules! require_u32_operand { + ($stack:expr, $idx:literal) => { + require_u32_operand!($stack, $idx, ZERO) + }; + + ($stack:expr, $idx:literal, $errno:expr) => {{ + let operand = $stack.get($idx); + if operand.as_int() > U32_MAX { + return Err(ExecutionError::NotU32Value(operand, $errno)); + } + operand + }}; +} + impl Process where H: Host, @@ -29,15 +45,8 @@ where /// the high values are equal to 0; if they are, puts the original elements back onto the /// stack; if they are not, returns an error. pub(super) fn op_u32assert2(&mut self, err_code: u32) -> Result<(), ExecutionError> { - let a = self.stack.get(0); - let b = self.stack.get(1); - - if a.as_int() >> 32 != 0 { - return Err(ExecutionError::NotU32Value(a, Felt::from(err_code))); - } - if b.as_int() >> 32 != 0 { - return Err(ExecutionError::NotU32Value(b, Felt::from(err_code))); - } + let b = require_u32_operand!(self.stack, 0, Felt::from(err_code)); + let a = require_u32_operand!(self.stack, 1, Felt::from(err_code)); self.add_range_checks(Operation::U32assert2(err_code), a, b, false); @@ -51,11 +60,11 @@ where /// Pops two elements off the stack, adds them, splits the result into low and high 32-bit /// values, and pushes these values back onto the stack. pub(super) fn op_u32add(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0); - let a = self.stack.get(1); - let result = a + b; - let (hi, lo) = split_element(result); + let b = require_u32_operand!(self.stack, 0).as_int(); + let a = require_u32_operand!(self.stack, 1).as_int(); + let result = Felt::new(a + b); + let (hi, lo) = split_element(result); self.add_range_checks(Operation::U32add, lo, hi, false); self.stack.set(0, hi); @@ -67,9 +76,9 @@ where /// Pops three elements off the stack, adds them, splits the result into low and high 32-bit /// values, and pushes these values back onto the stack. pub(super) fn op_u32add3(&mut self) -> Result<(), ExecutionError> { - let c = self.stack.get(0).as_int(); - let b = self.stack.get(1).as_int(); - let a = self.stack.get(2).as_int(); + let c = require_u32_operand!(self.stack, 0).as_int(); + let b = require_u32_operand!(self.stack, 1).as_int(); + let a = require_u32_operand!(self.stack, 2).as_int(); let result = Felt::new(a + b + c); let (hi, lo) = split_element(result); @@ -85,11 +94,11 @@ where /// pushes the result as well as a flag indicating whether there was underflow back onto the /// stack. pub(super) fn op_u32sub(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0).as_int(); - let a = self.stack.get(1).as_int(); + let b = require_u32_operand!(self.stack, 0).as_int(); + let a = require_u32_operand!(self.stack, 1).as_int(); let result = a.wrapping_sub(b); let d = Felt::new(result >> 63); - let c = Felt::new((result as u32) as u64); + let c = Felt::new(result & U32_MAX); // Force this operation to consume 4 range checks, even though only `lo` is needed. // This is required for making the constraints more uniform and grouping the opcodes of @@ -105,8 +114,8 @@ where /// Pops two elements off the stack, multiplies them, splits the result into low and high /// 32-bit values, and pushes these values back onto the stack. pub(super) fn op_u32mul(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0).as_int(); - let a = self.stack.get(1).as_int(); + let b = require_u32_operand!(self.stack, 0).as_int(); + let a = require_u32_operand!(self.stack, 1).as_int(); let result = Felt::new(a * b); let (hi, lo) = split_element(result); @@ -122,9 +131,9 @@ where /// the result, splits the result into low and high 32-bit values, and pushes these values /// back onto the stack. pub(super) fn op_u32madd(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0).as_int(); - let a = self.stack.get(1).as_int(); - let c = self.stack.get(2).as_int(); + let b = require_u32_operand!(self.stack, 0).as_int(); + let a = require_u32_operand!(self.stack, 1).as_int(); + let c = require_u32_operand!(self.stack, 2).as_int(); let result = Felt::new(a * b + c); let (hi, lo) = split_element(result); @@ -142,8 +151,8 @@ where /// # Errors /// Returns an error if the divisor is ZERO. pub(super) fn op_u32div(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0).as_int(); - let a = self.stack.get(1).as_int(); + let b = require_u32_operand!(self.stack, 0).as_int(); + let a = require_u32_operand!(self.stack, 1).as_int(); if b == 0 { return Err(ExecutionError::DivideByZero(self.system.clk())); @@ -170,8 +179,8 @@ where /// Pops two elements off the stack, computes their bitwise AND, and pushes the result back /// onto the stack. pub(super) fn op_u32and(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0); - let a = self.stack.get(1); + let b = require_u32_operand!(self.stack, 0); + let a = require_u32_operand!(self.stack, 1); let result = self.chiplets.u32and(a, b)?; self.stack.set(0, result); @@ -183,8 +192,8 @@ where /// Pops two elements off the stack, computes their bitwise XOR, and pushes the result back onto /// the stack. pub(super) fn op_u32xor(&mut self) -> Result<(), ExecutionError> { - let b = self.stack.get(0); - let a = self.stack.get(1); + let b = require_u32_operand!(self.stack, 0); + let a = require_u32_operand!(self.stack, 1); let result = self.chiplets.u32xor(a, b)?; self.stack.set(0, result); @@ -232,8 +241,9 @@ where #[cfg(test)] mod tests { - use miden_air::trace::{decoder::NUM_USER_OP_HELPERS, stack::STACK_TOP_SIZE}; + use miden_air::trace::decoder::NUM_USER_OP_HELPERS; use test_utils::rand::rand_value; + use vm_core::stack::MIN_STACK_DEPTH; use super::{ super::{Felt, Operation}, @@ -457,8 +467,8 @@ mod tests { (d, c, b, a) } - fn build_expected(values: &[u32]) -> [Felt; STACK_TOP_SIZE] { - let mut expected = [ZERO; STACK_TOP_SIZE]; + fn build_expected(values: &[u32]) -> [Felt; MIN_STACK_DEPTH] { + let mut expected = [ZERO; MIN_STACK_DEPTH]; for (&value, result) in values.iter().zip(expected.iter_mut()) { *result = Felt::new(value as u64); } diff --git a/processor/src/stack/aux_trace.rs b/processor/src/stack/aux_trace.rs index 96f33802d6..8c34985756 100644 --- a/processor/src/stack/aux_trace.rs +++ b/processor/src/stack/aux_trace.rs @@ -1,6 +1,7 @@ use alloc::vec::Vec; use miden_air::{trace::main_trace::MainTrace, RowIndex}; +use vm_core::OPCODE_DYNCALL; use super::{Felt, FieldElement, OverflowTableRow}; use crate::trace::AuxColumnBuilder; @@ -10,12 +11,7 @@ use crate::trace::AuxColumnBuilder; /// Describes how to construct execution traces of stack-related auxiliary trace segment columns /// (used in multiset checks). -pub struct AuxTraceBuilder { - /// A list of all rows that were added to and then removed from the overflow table. - pub(super) overflow_table_rows: Vec, - /// The number of rows in the overflow table when execution begins. - pub(super) num_init_rows: usize, -} +pub struct AuxTraceBuilder; impl AuxTraceBuilder { /// Builds and returns stack auxiliary trace columns. Currently this consists of a single @@ -26,24 +22,17 @@ impl AuxTraceBuilder { rand_elements: &[E], ) -> Vec> { let p1 = self.build_aux_column(main_trace, rand_elements); + + debug_assert_eq!(*p1.last().unwrap(), E::ONE); vec![p1] } } impl> AuxColumnBuilder for AuxTraceBuilder { - /// Initializes the overflow stack auxiliary column. - fn init_responses(&self, _main_trace: &MainTrace, alphas: &[E]) -> E { - let mut initial_column_value = E::ONE; - for row in self.overflow_table_rows.iter().take(self.num_init_rows) { - let value = (*row).to_value(alphas); - initial_column_value *= value; - } - initial_column_value - } - /// Removes a row from the stack overflow table. fn get_requests_at(&self, main_trace: &MainTrace, alphas: &[E], i: RowIndex) -> E { let is_left_shift = main_trace.is_left_shift(i); + let is_dyncall = main_trace.get_op_code(i) == OPCODE_DYNCALL.into(); let is_non_empty_overflow = main_trace.is_non_empty_overflow(i); if is_left_shift && is_non_empty_overflow { @@ -51,8 +40,13 @@ impl> AuxColumnBuilder for AuxTraceBuilder let s15_prime = main_trace.stack_element(15, i + 1); let b1_prime = main_trace.parent_overflow_address(i + 1); - let row = OverflowTableRow::new(b1, s15_prime, b1_prime); - row.to_value(alphas) + OverflowTableRow::new(b1, s15_prime, b1_prime).to_value(alphas) + } else if is_dyncall && is_non_empty_overflow { + let b1 = main_trace.parent_overflow_address(i); + let s15_prime = main_trace.stack_element(15, i + 1); + let b1_prime = main_trace.decoder_hasher_state_element(5, i); + + OverflowTableRow::new(b1, s15_prime, b1_prime).to_value(alphas) } else { E::ONE } diff --git a/processor/src/stack/mod.rs b/processor/src/stack/mod.rs index 172de5dd31..6708840b02 100644 --- a/processor/src/stack/mod.rs +++ b/processor/src/stack/mod.rs @@ -1,10 +1,11 @@ use alloc::vec::Vec; -use core::cmp; use miden_air::RowIndex; -use vm_core::{stack::STACK_TOP_SIZE, Word, WORD_SIZE}; +use vm_core::{stack::MIN_STACK_DEPTH, Word, WORD_SIZE}; -use super::{Felt, FieldElement, StackInputs, StackOutputs, ONE, STACK_TRACE_WIDTH, ZERO}; +use super::{ + ExecutionError, Felt, FieldElement, StackInputs, StackOutputs, ONE, STACK_TRACE_WIDTH, ZERO, +}; mod trace; use trace::StackTrace; @@ -23,7 +24,7 @@ mod tests; // ================================================================================================ /// The last stack index accessible by the VM. -const MAX_TOP_IDX: usize = STACK_TOP_SIZE - 1; +const MAX_TOP_IDX: usize = MIN_STACK_DEPTH - 1; // STACK // ================================================================================================ @@ -71,29 +72,15 @@ impl Stack { init_trace_capacity: usize, keep_overflow_trace: bool, ) -> Self { - let init_values = inputs.values(); - let depth = cmp::max(STACK_TOP_SIZE, init_values.len()); - - let (trace, overflow) = if init_values.len() > STACK_TOP_SIZE { - let overflow = - OverflowTable::new_with_inputs(keep_overflow_trace, &init_values[STACK_TOP_SIZE..]); - let trace = - StackTrace::new(&init_values[..STACK_TOP_SIZE], init_trace_capacity, depth, -ONE); - - (trace, overflow) - } else { - let overflow = OverflowTable::new(keep_overflow_trace); - let trace = StackTrace::new(init_values, init_trace_capacity, depth, ZERO); - - (trace, overflow) - }; + let overflow = OverflowTable::new(keep_overflow_trace); + let trace = StackTrace::new(&**inputs, init_trace_capacity, MIN_STACK_DEPTH, ZERO); Self { clk: RowIndex::from(0), trace, overflow, - active_depth: depth, - full_depth: depth, + active_depth: MIN_STACK_DEPTH, + full_depth: MIN_STACK_DEPTH, } } @@ -140,14 +127,18 @@ impl Stack { result } - /// Returns [StackOutputs] consisting of all values on the stack and all addresses in the - /// overflow table that are required to rebuild the rows in the overflow table. - pub fn build_stack_outputs(&self) -> StackOutputs { + /// Returns [StackOutputs] consisting of all values on the stack. + /// + /// # Errors + /// Returns an error if the overflow table is not empty at the current clock cycle. + pub fn build_stack_outputs(&self) -> Result { + if self.overflow.num_active_rows() != 0 { + return Err(ExecutionError::OutputStackOverflow(self.overflow.num_active_rows())); + } + let mut stack_items = Vec::with_capacity(self.active_depth); self.trace.append_state_into(&mut stack_items, self.clk); - self.overflow.append_into(&mut stack_items); - StackOutputs::new(stack_items, self.overflow.get_addrs()) - .expect("processor stack handling logic is valid") + Ok(StackOutputs::new(stack_items).expect("processor stack handling logic is valid")) } // TRACE ACCESSORS AND MUTATORS @@ -155,7 +146,7 @@ impl Stack { /// Returns the value located at the specified position on the stack at the current clock cycle. pub fn get(&self, pos: usize) -> Felt { - debug_assert!(pos < STACK_TOP_SIZE, "stack underflow"); + debug_assert!(pos < MIN_STACK_DEPTH, "stack underflow"); self.trace.get_stack_value_at(self.clk, pos) } @@ -181,7 +172,7 @@ impl Stack { /// Sets the value at the specified position on the stack at the next clock cycle. pub fn set(&mut self, pos: usize, value: Felt) { - debug_assert!(pos < STACK_TOP_SIZE, "stack underflow"); + debug_assert!(pos < MIN_STACK_DEPTH, "stack underflow"); self.trace.set_stack_value_at(self.clk + 1, pos, value); } @@ -206,18 +197,44 @@ impl Stack { /// stack is set to ZERO. pub fn shift_left(&mut self, start_pos: usize) { debug_assert!(start_pos > 0, "start position must be greater than 0"); - debug_assert!(start_pos <= STACK_TOP_SIZE, "start position cannot exceed stack top size"); + debug_assert!(start_pos <= MIN_STACK_DEPTH, "start position cannot exceed stack top size"); + + let (next_depth, next_overflow_addr) = self.shift_left_no_helpers(start_pos); + self.trace.set_helpers_at(self.clk.as_usize(), next_depth, next_overflow_addr); + } + + /// Copies stack values starting at the specified position at the current clock cycle to + /// position + 1 at the next clock cycle + /// + /// If stack depth grows beyond 16 items, the additional item is pushed into the overflow table. + pub fn shift_right(&mut self, start_pos: usize) { + debug_assert!(start_pos < MIN_STACK_DEPTH, "start position cannot exceed stack top size"); + + // Update the stack. + self.trace.stack_shift_right_at(self.clk, start_pos); + + // Update the overflow table. + let to_overflow = self.trace.get_stack_value_at(self.clk, MAX_TOP_IDX); + self.overflow.push(to_overflow, Felt::from(self.clk)); + + // Stack depth always increases on right shift. + self.active_depth += 1; + self.full_depth += 1; + } + /// Shifts the stack left, and returns the value for the helper columns B0 and B1, without + /// writing them to the trace. + fn shift_left_no_helpers(&mut self, start_pos: usize) -> (Felt, Felt) { match self.active_depth { 0..=MAX_TOP_IDX => unreachable!("stack underflow"), - STACK_TOP_SIZE => { + MIN_STACK_DEPTH => { // Shift in a ZERO, to prevent depth shrinking below the minimum stack depth. - self.trace.stack_shift_left_at(self.clk, start_pos, ZERO, None); + self.trace.stack_shift_left_no_helpers(self.clk, start_pos, ZERO, None) }, _ => { // Update the stack & overflow table. let from_overflow = self.overflow.pop(u64::from(self.clk)); - self.trace.stack_shift_left_at( + let helpers = self.trace.stack_shift_left_no_helpers( self.clk, start_pos, from_overflow, @@ -227,32 +244,43 @@ impl Stack { // Stack depth only decreases when it is greater than the minimum stack depth. self.active_depth -= 1; self.full_depth -= 1; + + helpers }, } } - /// Copies stack values starting at the specified position at the current clock cycle to - /// position + 1 at the next clock cycle - /// - /// If stack depth grows beyond 16 items, the additional item is pushed into the overflow table. - pub fn shift_right(&mut self, start_pos: usize) { - debug_assert!(start_pos < STACK_TOP_SIZE, "start position cannot exceed stack top size"); - - // Update the stack. - self.trace.stack_shift_right_at(self.clk, start_pos); + // CONTEXT MANAGEMENT + // -------------------------------------------------------------------------------------------- - // Update the overflow table. - let to_overflow = self.trace.get_stack_value_at(self.clk, MAX_TOP_IDX); - self.overflow.push(to_overflow, Felt::from(self.clk)); + /// Shifts the stack left, writes the default values for the stack helper registers in the trace + /// (stack depth and next overflow address), and returns the value of those helper registers + /// before the new context wipe. + /// + /// This specialized method is needed because the other ones write the updated helper register + /// values directly to the trace in the next row. However, the dyncall instruction needs to + /// shift the stack left, and start a new context simultaneously (and hence reset the stack + /// helper registers to their default value). It is assumed that the caller will write the + /// return values somewhere else in the trace. + pub fn shift_left_and_start_context(&mut self) -> (usize, Felt) { + const START_POSITION: usize = 1; + + self.shift_left_no_helpers(START_POSITION); + + // reset the helper columns to their default value, and write those to the trace in the next + // row. + let (next_depth, next_overflow_addr) = self.start_context(); + // Note: `start_context()` reset `active_depth` to 16, and `overflow.last_row_addr` to 0. + self.trace.set_helpers_at( + self.clk.as_usize(), + Felt::from(self.active_depth as u32), + self.overflow.last_row_addr(), + ); - // Stack depth always increases on right shift. - self.active_depth += 1; - self.full_depth += 1; + // return the helper registers' state before the new context + (next_depth, next_overflow_addr) } - // CONTEXT MANAGEMENT - // -------------------------------------------------------------------------------------------- - /// Starts a new execution context for this stack and returns a tuple consisting of the current /// stack depth and the address of the overflow table row prior to starting the new context. /// @@ -261,7 +289,7 @@ impl Stack { pub fn start_context(&mut self) -> (usize, Felt) { let current_depth = self.active_depth; let current_overflow_addr = self.overflow.last_row_addr(); - self.active_depth = STACK_TOP_SIZE; + self.active_depth = MIN_STACK_DEPTH; self.overflow.set_last_row_addr(ZERO); (current_depth, current_overflow_addr) } @@ -271,7 +299,7 @@ impl Stack { /// This has the effect bringing back items previously hidden from the overflow table. pub fn restore_context(&mut self, stack_depth: usize, next_overflow_addr: Felt) { debug_assert!(stack_depth <= self.full_depth, "stack depth too big"); - debug_assert_eq!(self.active_depth, STACK_TOP_SIZE, "overflow table not empty"); + debug_assert_eq!(self.active_depth, MIN_STACK_DEPTH, "overflow table not empty"); self.active_depth = stack_depth; self.overflow.set_last_row_addr(next_overflow_addr); } @@ -307,10 +335,7 @@ impl Stack { column.resize(trace_len, last_value); } - super::StackTrace { - trace, - aux_builder: self.overflow.into_aux_builder(), - } + super::StackTrace { trace } } // UTILITY METHODS @@ -334,7 +359,7 @@ impl Stack { /// Returns state of stack item columns at the current clock cycle. This does not include stack /// values in the overflow table. #[cfg(any(test, feature = "testing"))] - pub fn trace_state(&self) -> [Felt; STACK_TOP_SIZE] { + pub fn trace_state(&self) -> [Felt; MIN_STACK_DEPTH] { self.trace.get_stack_state_at(self.clk) } diff --git a/processor/src/stack/overflow.rs b/processor/src/stack/overflow.rs index a006467ecc..bd5e5a3142 100644 --- a/processor/src/stack/overflow.rs +++ b/processor/src/stack/overflow.rs @@ -1,8 +1,6 @@ use alloc::{collections::BTreeMap, vec::Vec}; -use vm_core::{utils::uninit_vector, StarkField}; - -use super::{AuxTraceBuilder, Felt, FieldElement, ZERO}; +use super::{Felt, FieldElement, ZERO}; // OVERFLOW TABLE // ================================================================================================ @@ -26,8 +24,6 @@ pub struct OverflowTable { /// whenever an update happens. This is set to true only when executing programs for debug /// purposes. trace_enabled: bool, - /// The number of rows in the overflow table when execution begins. - num_init_rows: usize, /// Holds the address (the clock cycle) of the row at to top of the overflow table. When /// entering new execution context, this value is set to ZERO, and thus, will differ from the /// row address actually at the top of the table. @@ -44,30 +40,10 @@ impl OverflowTable { active_rows: Vec::new(), trace: BTreeMap::new(), trace_enabled: enable_trace, - num_init_rows: 0, last_row_addr: ZERO, } } - /// Returns a new [OverflowTable]. The returned table contains a row for each of the provided - /// initial values, using a "negative" (mod p) `clk` value as the address for each of the rows, - /// since they are added before the first execution cycle. - /// - /// `init_values` is expected to be ordered such that values will be pushed onto the stack one - /// by one. Thus, the first item in the list will become the deepest item in the stack. - pub fn new_with_inputs(enable_trace: bool, init_values: &[Felt]) -> Self { - let mut overflow_table = Self::new(enable_trace); - overflow_table.num_init_rows = init_values.len(); - - let mut clk = Felt::MODULUS - init_values.len() as u64; - for &val in init_values.iter().rev() { - overflow_table.push(val, Felt::new(clk)); - clk += 1; - } - - overflow_table - } - // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -172,36 +148,9 @@ impl OverflowTable { } } - /// Returns the addresses of active rows in the table required to reconstruct the table (when - /// combined with the values). This is a vector of all of the `clk` values (the address of each - /// row), preceded by the `prev` value in the first row of the table. (It's also equivalent to - /// all of the `prev` values followed by the `clk` value in the last row of the table.) - pub(super) fn get_addrs(&self) -> Vec { - if self.active_rows.is_empty() { - return Vec::new(); - } - - let mut addrs = unsafe { uninit_vector(self.active_rows.len() + 1) }; - // add the previous address of the first row in the overflow table. - addrs[0] = self.all_rows[self.active_rows[0]].prev; - // add the address for all the rows in the overflow table. - for (i, &row_idx) in self.active_rows.iter().enumerate() { - addrs[i + 1] = self.all_rows[row_idx].clk; - } - - addrs - } - - // AUX TRACE BUILDER GENERATION - // -------------------------------------------------------------------------------------------- - - /// Converts this [OverflowTable] into an auxiliary trace builder which can be used to construct - /// the auxiliary trace column describing the state of the overflow table at every cycle. - pub fn into_aux_builder(self) -> AuxTraceBuilder { - AuxTraceBuilder { - num_init_rows: self.num_init_rows, - overflow_table_rows: self.all_rows, - } + /// Returns the number of overflowing stack elements at the current clock cycle. + pub fn num_active_rows(&self) -> usize { + self.active_rows.len() } // HELPER METHODS diff --git a/processor/src/stack/tests.rs b/processor/src/stack/tests.rs index 44e4551adc..d8e16a810b 100644 --- a/processor/src/stack/tests.rs +++ b/processor/src/stack/tests.rs @@ -4,18 +4,16 @@ use miden_air::trace::{ stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, NUM_STACK_HELPER_COLS}, STACK_TRACE_WIDTH, }; -use vm_core::{FieldElement, StarkField}; +use vm_core::FieldElement; -use super::{ - super::StackTopState, Felt, OverflowTableRow, Stack, StackInputs, ONE, STACK_TOP_SIZE, ZERO, -}; +use super::{Felt, OverflowTableRow, Stack, StackInputs, MIN_STACK_DEPTH, ONE, ZERO}; // TYPE ALIASES // ================================================================================================ type StackHelpersState = [Felt; NUM_STACK_HELPER_COLS]; -// INITIALIZATION TESTS +// INITIALIZATION TEST // ================================================================================================ #[test] @@ -28,7 +26,7 @@ fn initialize() { // Prepare the expected results. stack_inputs.reverse(); let expected_stack = build_stack(&stack_inputs); - let expected_helpers = [Felt::new(STACK_TOP_SIZE as u64), ZERO, ZERO]; + let expected_helpers = [Felt::new(MIN_STACK_DEPTH as u64), ZERO, ZERO]; // Check the stack state. assert_eq!(stack.trace_state(), expected_stack); @@ -37,23 +35,44 @@ fn initialize() { assert_eq!(stack.helpers_state(), expected_helpers); } +// OVERFLOW TEST +// ================================================================================================ + #[test] -fn initialize_overflow() { - // Initialize a new stack with enough initial values that the overflow table is non-empty. - let mut stack_inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]; - let stack = StackInputs::try_from_ints(stack_inputs).unwrap(); - let stack = Stack::new(&stack, 4, false); +fn stack_overflow() { + // Initialize a new fully loaded stack. + let mut stack_values_holder: [u64; 19] = + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]; + let stack = StackInputs::try_from_ints(stack_values_holder[0..16].to_vec()).unwrap(); + let mut stack = Stack::new(&stack, 5, false); + + // Push additional values to overflow the stack + stack.copy_state(0); + stack.advance_clock(); + + stack.shift_right(0); + stack.set(0, Felt::from(17u8)); + stack.advance_clock(); + + stack.shift_right(0); + stack.set(0, Felt::from(18u8)); + stack.advance_clock(); + + stack.shift_right(0); + stack.set(0, Felt::from(19u8)); + stack.advance_clock(); // Prepare the expected results. - stack_inputs.reverse(); - let expected_stack = build_stack(&stack_inputs[..STACK_TOP_SIZE]); - let expected_depth = stack_inputs.len() as u64; + stack_values_holder.reverse(); + let expected_stack = build_stack(&stack_values_holder[0..16]); + + let expected_depth = stack_values_holder.len() as u64; let expected_helpers = [ Felt::new(expected_depth), - -ONE, - Felt::new(expected_depth - STACK_TOP_SIZE as u64), + Felt::new(3u64), + Felt::new(expected_depth - MIN_STACK_DEPTH as u64), ]; - let init_addr = Felt::MODULUS - 3; + let init_addr = 1; let expected_overflow_rows = vec![ OverflowTableRow::new(Felt::new(init_addr), ONE, ZERO), OverflowTableRow::new(Felt::new(init_addr + 1), Felt::new(2), Felt::new(init_addr)), @@ -169,8 +188,7 @@ fn shift_right() { #[test] fn start_restore_context() { - let stack_init = (0..16).map(|v| v as u64 + 1); - let stack = StackInputs::try_from_ints(stack_init).unwrap(); + let stack = StackInputs::try_from_ints(1..17).unwrap(); let mut stack = Stack::new(&stack, 8, false); // ----- when overflow table is empty ------------------------------------- @@ -378,8 +396,8 @@ fn generate_trace() { /// Builds a [StackTopState] that starts with the provided stack inputs and is padded with zeros /// until the minimum stack depth. -fn build_stack(stack_inputs: &[u64]) -> StackTopState { - let mut result = [ZERO; STACK_TOP_SIZE]; +fn build_stack(stack_inputs: &[u64]) -> [Felt; MIN_STACK_DEPTH] { + let mut result = [ZERO; MIN_STACK_DEPTH]; for (idx, &input) in stack_inputs.iter().enumerate() { result[idx] = Felt::new(input); } @@ -390,7 +408,7 @@ fn build_stack(stack_inputs: &[u64]) -> StackTopState { fn build_helpers(stack_depth: u64, next_overflow_addr: u64) -> StackHelpersState { let b0 = Felt::new(stack_depth); let b1 = Felt::new(next_overflow_addr); - let h0 = (b0 - Felt::new(STACK_TOP_SIZE as u64)).inv(); + let h0 = (b0 - Felt::new(MIN_STACK_DEPTH as u64)).inv(); [b0, b1, h0] } @@ -399,17 +417,17 @@ fn build_helpers(stack_depth: u64, next_overflow_addr: u64) -> StackHelpersState /// The difference between this function and build_helpers() is that this function does not invert /// h0 value. fn build_helpers_partial(num_overflow: usize, next_overflow_addr: usize) -> StackHelpersState { - let depth = STACK_TOP_SIZE + num_overflow; + let depth = MIN_STACK_DEPTH + num_overflow; let b0 = Felt::new(depth as u64); let b1 = Felt::new(next_overflow_addr as u64); - let h0 = b0 - Felt::new(STACK_TOP_SIZE as u64); + let h0 = b0 - Felt::new(MIN_STACK_DEPTH as u64); [b0, b1, h0] } /// Returns values in stack top columns of the provided trace at the specified row. -fn read_stack_top(trace: &[Vec; STACK_TRACE_WIDTH], row: usize) -> StackTopState { - let mut result = [ZERO; STACK_TOP_SIZE]; +fn read_stack_top(trace: &[Vec; STACK_TRACE_WIDTH], row: usize) -> [Felt; MIN_STACK_DEPTH] { + let mut result = [ZERO; MIN_STACK_DEPTH]; for (value, column) in result.iter_mut().zip(trace) { *value = column[row]; } diff --git a/processor/src/stack/trace.rs b/processor/src/stack/trace.rs index 8483bbd651..f1dcc72101 100644 --- a/processor/src/stack/trace.rs +++ b/processor/src/stack/trace.rs @@ -1,10 +1,10 @@ use alloc::vec::Vec; use miden_air::{ - trace::stack::{H0_COL_IDX, NUM_STACK_HELPER_COLS, STACK_TOP_SIZE}, + trace::stack::{H0_COL_IDX, NUM_STACK_HELPER_COLS}, RowIndex, }; -use vm_core::FieldElement; +use vm_core::{stack::MIN_STACK_DEPTH, FieldElement}; use super::{super::utils::get_trace_len, Felt, MAX_TOP_IDX, ONE, STACK_TRACE_WIDTH, ZERO}; use crate::utils::math::batch_inversion; @@ -18,7 +18,7 @@ use crate::utils::math::batch_inversion; /// - 16 stack columns holding the top of the stack. /// - 3 columns for bookkeeping and helper values that manage left and right shifts. pub struct StackTrace { - stack: [Vec; STACK_TOP_SIZE], + stack: [Vec; MIN_STACK_DEPTH], helpers: [Vec; NUM_STACK_HELPER_COLS], } @@ -27,7 +27,7 @@ impl StackTrace { // -------------------------------------------------------------------------------------------- /// Returns a [StackTrace] instantiated with the provided input values. /// - /// When fewer than `STACK_TOP_SIZE` inputs are provided, the rest of the stack top elements + /// When fewer than `MIN_STACK_DEPTH` inputs are provided, the rest of the stack top elements /// are set to ZERO. The initial stack depth and initial overflow address are used to /// initialize the bookkeeping columns so they are consistent with the initial state of the /// overflow table. @@ -78,7 +78,7 @@ impl StackTrace { next_overflow_addr: Felt, ) { // copy over stack top columns - for i in start_pos..STACK_TOP_SIZE { + for i in start_pos..MIN_STACK_DEPTH { self.stack[i][clk + 1] = self.stack[i][clk]; } @@ -87,25 +87,24 @@ impl StackTrace { } /// Copies the stack values starting at the specified position at the specified clock cycle to - /// position - 1 at the next clock cycle. + /// position - 1 at the next clock cycle. Returns the new value of the helper registers without + /// writing them to the next row (i.e. the stack depth and the next overflow addr). /// /// The final stack item column is filled with the provided value in `last_value`. /// /// If next_overflow_addr is provided, this function assumes that the stack depth has been /// decreased by one and a row has been removed from the overflow table. Thus, it makes the - /// following changes to the helper columns: + /// following changes to the helper columns (without writing them to the next row): /// - Decrement the stack depth (b0) by one. /// - Sets b1 to the address of the top row in the overflow table to the specified /// `next_overflow_addr`. - /// - Set h0 to (depth - 16). Inverses of these values will be computed in into_array() method - /// after the entire trace is constructed. - pub fn stack_shift_left_at( + pub(super) fn stack_shift_left_no_helpers( &mut self, clk: RowIndex, start_pos: usize, last_value: Felt, next_overflow_addr: Option, - ) { + ) -> (Felt, Felt) { let clk = clk.as_usize(); // update stack top columns @@ -114,15 +113,15 @@ impl StackTrace { } self.stack[MAX_TOP_IDX][clk + 1] = last_value; - // update stack helper columns + // return stack helper columns if let Some(next_overflow_addr) = next_overflow_addr { let next_depth = self.helpers[0][clk] - ONE; - self.set_helpers_at(clk, next_depth, next_overflow_addr); + (next_depth, next_overflow_addr) } else { - // if next_overflow_addr was not provide, just copy over the values from the last row + // if next_overflow_addr was not provide, just return the values from the last row let next_depth = self.helpers[0][clk]; let next_overflow_addr = self.helpers[1][clk]; - self.set_helpers_at(clk, next_depth, next_overflow_addr); + (next_depth, next_overflow_addr) } } @@ -193,10 +192,15 @@ impl StackTrace { /// set to (stack_depth - 16) rather than to 1 / (stack_depth - 16). Inverses of these values /// will be computed in into_array() method (using batch inversion) after the entire trace is /// constructed. - fn set_helpers_at(&mut self, clk: usize, stack_depth: Felt, next_overflow_addr: Felt) { + pub(super) fn set_helpers_at( + &mut self, + clk: usize, + stack_depth: Felt, + next_overflow_addr: Felt, + ) { self.helpers[0][clk + 1] = stack_depth; self.helpers[1][clk + 1] = next_overflow_addr; - self.helpers[2][clk + 1] = stack_depth - Felt::from(STACK_TOP_SIZE as u32); + self.helpers[2][clk + 1] = stack_depth - Felt::from(MIN_STACK_DEPTH as u32); } // TEST HELPERS @@ -204,8 +208,8 @@ impl StackTrace { /// Returns the stack trace state at the specified clock cycle. #[cfg(any(test, feature = "testing"))] - pub fn get_stack_state_at(&self, clk: RowIndex) -> [Felt; STACK_TOP_SIZE] { - let mut result = [ZERO; STACK_TOP_SIZE]; + pub fn get_stack_state_at(&self, clk: RowIndex) -> [Felt; MIN_STACK_DEPTH] { + let mut result = [ZERO; MIN_STACK_DEPTH]; for (result, column) in result.iter_mut().zip(self.stack.iter()) { *result = column[clk.as_usize()]; } @@ -230,9 +234,9 @@ impl StackTrace { fn init_stack_columns( init_trace_capacity: usize, init_values: &[Felt], -) -> [Vec; STACK_TOP_SIZE] { - let mut stack: Vec> = Vec::with_capacity(STACK_TOP_SIZE); - for i in 0..STACK_TOP_SIZE { +) -> [Vec; MIN_STACK_DEPTH] { + let mut stack: Vec> = Vec::with_capacity(MIN_STACK_DEPTH); + for i in 0..MIN_STACK_DEPTH { let mut column = vec![Felt::ZERO; init_trace_capacity]; if i < init_values.len() { column[0] = init_values[i]; @@ -260,7 +264,7 @@ fn init_helper_columns( // if the overflow table is not empty, set h0 to (init_depth - 16) let mut h0 = vec![Felt::ZERO; init_trace_capacity]; // TODO: change type of `init_depth` to `u32` - h0[0] = Felt::try_from((init_depth - STACK_TOP_SIZE) as u64) + h0[0] = Felt::try_from((init_depth - MIN_STACK_DEPTH) as u64) .expect("value is greater than or equal to the field modulus"); [b0, b1, h0] diff --git a/processor/src/system/mod.rs b/processor/src/system/mod.rs index f6b50d0eaa..016560b544 100644 --- a/processor/src/system/mod.rs +++ b/processor/src/system/mod.rs @@ -178,8 +178,8 @@ impl System { /// - Sets the free memory pointer to its initial value (FMP_MIN). /// - Sets the hash of the function which initiated the current context to the provided value. /// - /// A CALL cannot be started when the VM is executing a SYSCALL. - pub fn start_call(&mut self, fn_hash: Word) { + /// A CALL or DYNCALL cannot be started when the VM is executing a SYSCALL. + pub fn start_call_or_dyncall(&mut self, fn_hash: Word) { debug_assert!(!self.in_syscall, "call in syscall"); self.ctx = (self.clk + 1).into(); self.fmp = Felt::new(FMP_MIN); diff --git a/processor/src/system/tests.rs b/processor/src/system/tests.rs index 7fcf59fbf7..e59598b0d2 100644 --- a/processor/src/system/tests.rs +++ b/processor/src/system/tests.rs @@ -9,7 +9,7 @@ fn cycles_num_exceeded() { Kernel::default(), stack, host, - ExecutionOptions::new(Some(64), 64, false).unwrap(), + ExecutionOptions::new(Some(64), 64, false, false).unwrap(), ); for _ in 0..64 { process.execute_op(Operation::Noop).unwrap(); diff --git a/processor/src/trace/mod.rs b/processor/src/trace/mod.rs index ffa3e47980..df1e2c4af1 100644 --- a/processor/src/trace/mod.rs +++ b/processor/src/trace/mod.rs @@ -6,7 +6,7 @@ use miden_air::trace::{ AUX_TRACE_RAND_ELEMENTS, AUX_TRACE_WIDTH, DECODER_TRACE_OFFSET, MIN_TRACE_LEN, STACK_TRACE_OFFSET, TRACE_WIDTH, }; -use vm_core::{stack::STACK_TOP_SIZE, ProgramInfo, StackOutputs, ZERO}; +use vm_core::{stack::MIN_STACK_DEPTH, ProgramInfo, StackInputs, StackOutputs, ZERO}; use winter_prover::{crypto::RandomCoin, EvaluationFrame, Trace, TraceInfo}; use super::{ @@ -14,7 +14,7 @@ use super::{ decoder::AuxTraceBuilder as DecoderAuxTraceBuilder, range::AuxTraceBuilder as RangeCheckerAuxTraceBuilder, stack::AuxTraceBuilder as StackAuxTraceBuilder, ColMatrix, Digest, Felt, FieldElement, Host, - Process, StackTopState, + Process, }; mod utils; @@ -121,22 +121,22 @@ impl ExecutionTrace { } /// Returns the initial state of the top 16 stack registers. - pub fn init_stack_state(&self) -> StackTopState { - let mut result = [ZERO; STACK_TOP_SIZE]; + pub fn init_stack_state(&self) -> StackInputs { + let mut result = [ZERO; MIN_STACK_DEPTH]; for (i, result) in result.iter_mut().enumerate() { *result = self.main_trace.get_column(i + STACK_TRACE_OFFSET)[0]; } - result + result.into() } /// Returns the final state of the top 16 stack registers. - pub fn last_stack_state(&self) -> StackTopState { + pub fn last_stack_state(&self) -> StackOutputs { let last_step = self.last_step(); - let mut result = [ZERO; STACK_TOP_SIZE]; + let mut result = [ZERO; MIN_STACK_DEPTH]; for (i, result) in result.iter_mut().enumerate() { *result = self.main_trace.get_column(i + STACK_TRACE_OFFSET)[last_step]; } - result + result.into() } /// Returns helper registers state at the specified `clk` of the VM @@ -338,7 +338,7 @@ where let aux_trace_hints = AuxTraceBuilders { decoder: decoder_trace.aux_builder, - stack: stack_trace.aux_builder, + stack: StackAuxTraceBuilder, range: range_check_trace.aux_builder, chiplets: chiplets_trace.aux_builder, }; diff --git a/processor/src/trace/tests/chiplets/hasher.rs b/processor/src/trace/tests/chiplets/hasher.rs index 0ed99b69d4..813b009d32 100644 --- a/processor/src/trace/tests/chiplets/hasher.rs +++ b/processor/src/trace/tests/chiplets/hasher.rs @@ -57,8 +57,9 @@ pub fn b_chip_span() { let basic_block_id = mast_forest.add_block(vec![Operation::Add, Operation::Mul], None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -129,8 +130,9 @@ pub fn b_chip_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block_id = mast_forest.add_block(ops, None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); @@ -220,12 +222,11 @@ pub fn b_chip_merge() { let mut mast_forest = MastForest::new(); let t_branch_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let f_branch_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let split_id = mast_forest.add_split(t_branch_id, f_branch_id).unwrap(); + mast_forest.make_root(split_id); - Program::new(mast_forest, split_id) + Program::new(mast_forest.into(), split_id) }; let trace = build_trace_from_program(&program, &[]); @@ -336,8 +337,9 @@ pub fn b_chip_permutation() { let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(vec![Operation::HPerm], None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let stack = vec![8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8]; let trace = build_trace_from_program(&program, &stack); @@ -808,9 +810,8 @@ fn build_expected( // include the entire state (words a, b, c) value += build_value(&alphas[4..16], &state); } else if label == LINEAR_HASH_LABEL { - // include the delta between the next and current rate elements (words b and c) + // include the next rate elements value += build_value(&alphas[8..16], &next_state[CAPACITY_LEN..]); - value -= build_value(&alphas[8..16], &state[CAPACITY_LEN..]); } else if label == RETURN_HASH_LABEL { // include the digest (word b) value += build_value(&alphas[8..12], &state[DIGEST_RANGE]); diff --git a/processor/src/trace/tests/decoder.rs b/processor/src/trace/tests/decoder.rs index 1213b05dd4..021102dcc4 100644 --- a/processor/src/trace/tests/decoder.rs +++ b/processor/src/trace/tests/decoder.rs @@ -74,12 +74,11 @@ fn decoder_p1_join() { let mut mast_forest = MastForest::new(); let basic_block_1_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let join_id = mast_forest.add_join(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(join_id); - Program::new(mast_forest, join_id) + Program::new(mast_forest.into(), join_id) }; let trace = build_trace_from_program(&program, &[]); @@ -142,12 +141,11 @@ fn decoder_p1_split() { let mut mast_forest = MastForest::new(); let basic_block_1_id = mast_forest.add_block(vec![Operation::Mul], None).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(split_id); - Program::new(mast_forest, split_id) + Program::new(mast_forest.into(), split_id) }; let trace = build_trace_from_program(&program, &[1]); @@ -197,14 +195,12 @@ fn decoder_p1_loop_with_repeat() { let mut mast_forest = MastForest::new(); let basic_block_1_id = mast_forest.add_block(vec![Operation::Pad], None).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Drop], None).unwrap(); - let join_id = mast_forest.add_join(basic_block_1_id, basic_block_2_id).unwrap(); - let loop_node_id = mast_forest.add_loop(join_id).unwrap(); + mast_forest.make_root(loop_node_id); - Program::new(mast_forest, loop_node_id) + Program::new(mast_forest.into(), loop_node_id) }; let trace = build_trace_from_program(&program, &[0, 1, 1]); @@ -324,8 +320,9 @@ fn decoder_p2_span_with_respan() { let (ops, _) = build_span_with_respan_ops(); let basic_block_id = mast_forest.add_block(ops, None).unwrap(); + mast_forest.make_root(basic_block_id); - Program::new(mast_forest, basic_block_id) + Program::new(mast_forest.into(), basic_block_id) }; let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -366,8 +363,9 @@ fn decoder_p2_join() { let join = MastNode::new_join(basic_block_1_id, basic_block_2_id, &mast_forest).unwrap(); let join_id = mast_forest.add_node(join.clone()).unwrap(); + mast_forest.make_root(join_id); - let program = Program::new(mast_forest, join_id); + let program = Program::new(mast_forest.into(), join_id); let trace = build_trace_from_program(&program, &[]); let alphas = rand_array::(); @@ -425,12 +423,11 @@ fn decoder_p2_split_true() { let basic_block_1 = MastNode::new_basic_block(vec![Operation::Mul], None).unwrap(); let basic_block_1_id = mast_forest.add_node(basic_block_1.clone()).unwrap(); - let basic_block_2_id = mast_forest.add_block(vec![Operation::Add], None).unwrap(); - let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(split_id); - let program = Program::new(mast_forest, split_id); + let program = Program::new(mast_forest.into(), split_id); // build trace from program let trace = build_trace_from_program(&program, &[1]); @@ -484,8 +481,9 @@ fn decoder_p2_split_false() { let basic_block_2_id = mast_forest.add_node(basic_block_2.clone()).unwrap(); let split_id = mast_forest.add_split(basic_block_1_id, basic_block_2_id).unwrap(); + mast_forest.make_root(split_id); - let program = Program::new(mast_forest, split_id); + let program = Program::new(mast_forest.into(), split_id); // build trace from program let trace = build_trace_from_program(&program, &[0]); @@ -542,8 +540,9 @@ fn decoder_p2_loop_with_repeat() { let join_id = mast_forest.add_node(join.clone()).unwrap(); let loop_node_id = mast_forest.add_loop(join_id).unwrap(); + mast_forest.make_root(loop_node_id); - let program = Program::new(mast_forest, loop_node_id); + let program = Program::new(mast_forest.into(), loop_node_id); // build trace from program let trace = build_trace_from_program(&program, &[0, 1, 1]); diff --git a/processor/src/trace/tests/hasher.rs b/processor/src/trace/tests/hasher.rs index f84e320e56..7197c925f5 100644 --- a/processor/src/trace/tests/hasher.rs +++ b/processor/src/trace/tests/hasher.rs @@ -98,23 +98,23 @@ fn hasher_p1_mr_update() { expected_value *= row_values[0]; assert_eq!(expected_value, p1[9]); - // and then again for the next 7 steps the value remains the same - for i in 10..17 { + // and then again for the next 6 steps the value remains the same + for i in 10..16 { assert_eq!(expected_value, p1[i]); } - // on step 16, the next sibling is added to the table in the following row (step 17) + // on step 15, the next sibling is added to the table in the following row (step 16) expected_value *= row_values[1]; - assert_eq!(expected_value, p1[17]); + assert_eq!(expected_value, p1[16]); - // and then again for the next 7 steps the value remains the same - for i in 18..25 { + // and then again for the next 6 steps the value remains the same + for i in 18..24 { assert_eq!(expected_value, p1[i]); } - // on step 24, the last sibling is added to the table in the following row (step 25) + // on step 23, the last sibling is added to the table in the following row (step 24) expected_value *= row_values[2]; - assert_eq!(expected_value, p1[25]); + assert_eq!(expected_value, p1[24]); // and then again for the next 7 steps the value remains the same for i in 25..33 { @@ -126,23 +126,23 @@ fn hasher_p1_mr_update() { expected_value *= row_values[0].inv(); assert_eq!(expected_value, p1[33]); - // then, for the next 7 steps the value remains the same - for i in 33..41 { + // then, for the next 6 steps the value remains the same + for i in 33..40 { assert_eq!(expected_value, p1[i]); } - // on step 40, the next sibling is removed from the table in the following row (step 41) + // on step 39, the next sibling is removed from the table in the following row (step 40) expected_value *= row_values[1].inv(); - assert_eq!(expected_value, p1[41]); + assert_eq!(expected_value, p1[40]); - // and then again for the next 7 steps the value remains the same - for i in 41..49 { + // and then again for the next 6 steps the value remains the same + for i in 41..48 { assert_eq!(expected_value, p1[i]); } - // on step 48, the last sibling is removed from the table in the following row (step 49) + // on step 47, the last sibling is removed from the table in the following row (step 48) expected_value *= row_values[2].inv(); - assert_eq!(expected_value, p1[49]); + assert_eq!(expected_value, p1[48]); // at this point the table should be empty again, and it should stay empty until the end assert_eq!(expected_value, ONE); diff --git a/processor/src/trace/tests/mod.rs b/processor/src/trace/tests/mod.rs index c29c4267fa..42524d68f9 100644 --- a/processor/src/trace/tests/mod.rs +++ b/processor/src/trace/tests/mod.rs @@ -34,8 +34,9 @@ pub fn build_trace_from_ops(operations: Vec, stack: &[u64]) -> Execut let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(operations, None).unwrap(); + mast_forest.make_root(basic_block_id); - let program = Program::new(mast_forest, basic_block_id); + let program = Program::new(mast_forest.into(), basic_block_id); build_trace_from_program(&program, stack) } @@ -55,8 +56,9 @@ pub fn build_trace_from_ops_with_inputs( let mut mast_forest = MastForest::new(); let basic_block_id = mast_forest.add_block(operations, None).unwrap(); + mast_forest.make_root(basic_block_id); - let program = Program::new(mast_forest, basic_block_id); + let program = Program::new(mast_forest.into(), basic_block_id); process.execute(&program).unwrap(); ExecutionTrace::new(process, StackOutputs::default()) diff --git a/processor/src/trace/utils.rs b/processor/src/trace/utils.rs index 70e7f79a75..bee240455d 100644 --- a/processor/src/trace/utils.rs +++ b/processor/src/trace/utils.rs @@ -232,7 +232,7 @@ pub trait AuxColumnBuilder> { responses_prod[0] = self.init_responses(main_trace, alphas); requests[0] = self.init_requests(main_trace, alphas); - let mut requests_running_prod = E::ONE; + let mut requests_running_prod = requests[0]; for row_idx in 0..main_trace.num_rows() - 1 { let row = row_idx.into(); responses_prod[row_idx + 1] = diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 07033045f8..64aa4add07 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-prover" -version = "0.10.5" +version = "0.11.0" description = "Miden VM prover" -documentation = "https://docs.rs/miden-prover/0.10.5" +documentation = "https://docs.rs/miden-prover/0.11.0" readme = "README.md" categories = ["cryptography", "emulators", "no-std"] keywords = ["miden", "prover", "stark", "zkp"] @@ -14,18 +14,20 @@ rust-version.workspace = true edition.workspace = true [features] +async = ["winter-maybe-async/async"] concurrent = ["processor/concurrent", "std", "winter-prover/concurrent"] default = ["std"] metal = ["dep:miden-gpu", "dep:elsa", "dep:pollster", "concurrent", "std"] std = ["air/std", "processor/std", "winter-prover/std"] [dependencies] -air = { package = "miden-air", path = "../air", version = "0.10", default-features = false } -processor = { package = "miden-processor", path = "../processor", version = "0.10", default-features = false } +air = { package = "miden-air", path = "../air", version = "0.11", default-features = false } +processor = { package = "miden-processor", path = "../processor", version = "0.11", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"] } -winter-prover = { package = "winter-prover", version = "0.9", default-features = false } +winter-maybe-async = { package = "winter-maybe-async", version = "0.10", default-features = false } +winter-prover = { package = "winter-prover", version = "0.10", default-features = false } [target.'cfg(all(target_arch = "aarch64", target_os = "macos"))'.dependencies] elsa = { version = "1.9", optional = true } -miden-gpu = { version = "0.2", optional = true } -pollster = { version = "0.3", optional = true } +miden-gpu = { version = "0.3", optional = true } +pollster = { version = "0.4", optional = true } diff --git a/prover/README.md b/prover/README.md index 9f07ee2f00..be22bbe55b 100644 --- a/prover/README.md +++ b/prover/README.md @@ -46,6 +46,7 @@ Miden prover can be compiled with the following features: * `concurrent` - implies `std` and also enables multi-threaded proof generation. * `metal` - enables [Metal](https://en.wikipedia.org/wiki/Metal_(API))-based acceleration of proof generation (for recursive proofs) on supported platforms (e.g., Apple silicon). * `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly. + * Only the `wasm32-unknown-unknown` and `wasm32-wasip1` targets are officially supported. To compile with `no_std`, disable default features via `--no-default-features` flag. diff --git a/prover/src/gpu/metal/mod.rs b/prover/src/gpu/metal/mod.rs index ab7132834e..3aa6ae9dca 100644 --- a/prover/src/gpu/metal/mod.rs +++ b/prover/src/gpu/metal/mod.rs @@ -4,20 +4,17 @@ use std::{boxed::Box, marker::PhantomData, time::Instant, vec::Vec}; -use air::{AuxRandElements, LagrangeKernelEvaluationFrame}; +use air::{AuxRandElements, LagrangeKernelEvaluationFrame, PartitionOptions}; use elsa::FrozenVec; use miden_gpu::{ metal::{build_merkle_tree, utils::page_aligned_uninit_vector, RowHasher}, HashFn, }; use pollster::block_on; -use processor::{ - crypto::{ElementHasher, Hasher}, - ONE, -}; +use processor::crypto::{ElementHasher, Hasher}; use tracing::{event, Level}; use winter_prover::{ - crypto::{Digest, MerkleTree}, + crypto::{Digest, MerkleTree, VectorCommitment}, matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment}, proof::Queries, CompositionPoly, CompositionPolyTrace, ConstraintCommitment, ConstraintCompositionCoefficients, @@ -38,8 +35,6 @@ mod tests; // CONSTANTS // ================================================================================================ -// The Rate for RPO and RPX is the same -const RATE: usize = Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start; const DIGEST_SIZE: usize = Rpo256::DIGEST_RANGE.end - Rpo256::DIGEST_RANGE.start; // METAL RPO/RPX PROVER @@ -71,7 +66,7 @@ where } } - fn build_aligned_segement( + fn build_aligned_segment( polys: &ColMatrix, poly_offset: usize, offsets: &[Felt], @@ -101,7 +96,7 @@ where Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles) } - fn build_aligned_segements( + fn build_aligned_segments( polys: &ColMatrix, twiddles: &[Felt], offsets: &[Felt], @@ -120,20 +115,21 @@ where }; (0..num_segments) - .map(|i| Self::build_aligned_segement(polys, i * N, offsets, twiddles)) + .map(|i| Self::build_aligned_segment(polys, i * N, offsets, twiddles)) .collect() } } impl Prover for MetalExecutionProver where - H: Hasher + ElementHasher, + H: Hasher + ElementHasher + Sync, D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, R: RandomCoin + Send, { type BaseField = Felt; type Air = ProcessorAir; type Trace = ExecutionTrace; + type VC = MerkleTree; type HashFn = H; type RandomCoin = R; type TraceLde> = MetalTraceLde; @@ -148,11 +144,20 @@ where self.execution_prover.options() } + fn build_aux_trace>( + &self, + trace: &Self::Trace, + aux_rand_elements: &AuxRandElements, + ) -> ColMatrix { + trace.build_aux_trace(aux_rand_elements.rand_elements()).unwrap() + } + fn new_trace_lde>( &self, trace_info: &TraceInfo, main_trace: &ColMatrix, domain: &StarkDomain, + _partition_options: PartitionOptions, ) -> (Self::TraceLde, TracePolyTable) { MetalTraceLde::new(trace_info, main_trace, domain, self.metal_hash_fn) } @@ -196,7 +201,10 @@ where composition_poly_trace: CompositionPolyTrace, num_trace_poly_columns: usize, domain: &StarkDomain, - ) -> (ConstraintCommitment, CompositionPoly) { + ) -> ( + ConstraintCommitment>, + CompositionPoly, + ) { // evaluate composition polynomial columns over the LDE domain let now = Instant::now(); let composition_poly = @@ -204,7 +212,7 @@ where let blowup = domain.trace_to_lde_blowup(); let offsets = get_evaluation_offsets::(composition_poly.column_len(), blowup, domain.offset()); - let segments = Self::build_aligned_segements( + let segments = Self::build_aligned_segments( composition_poly.data(), domain.trace_twiddles(), &offsets, @@ -222,31 +230,14 @@ where let lde_domain_size = domain.lde_domain_size(); let num_base_columns = composition_poly.num_columns() * ::EXTENSION_DEGREE; - let rpo_requires_padding = num_base_columns % RATE != 0; - let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE); + let mut row_hasher = RowHasher::new(lde_domain_size, num_base_columns, self.metal_hash_fn); - let mut rpo_padded_segment: Vec<[Felt; RATE]>; - for (segment_idx, segment) in segments.iter().enumerate() { - // check if the segment requires padding - if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) { - // duplicate and modify the last segment with Rpo256's padding - // rule ("1" followed by "0"s). Our segments are already - // padded with "0"s we only need to add the "1"s. - rpo_padded_segment = unsafe { page_aligned_uninit_vector(lde_domain_size) }; - rpo_padded_segment.copy_from_slice(segment); - // For rpx, skip this step - if self.metal_hash_fn == HashFn::Rpo256 { - let rpo_pad_column = num_base_columns % RATE; - rpo_padded_segment.iter_mut().for_each(|row| row[rpo_pad_column] = ONE); - } - row_hasher.update(&rpo_padded_segment); - assert_eq!(segments.len() - 1, segment_idx, "padded segment should be the last"); - break; - } + for segment in segments.iter() { row_hasher.update(segment); } let row_hashes = block_on(row_hasher.finish()); let tree_nodes = build_merkle_tree(&row_hashes, self.metal_hash_fn); + // aggregate segments at the same time as the GPU generates the merkle tree nodes let composed_evaluations = RowMatrix::::from_segments(segments, num_base_columns); let nodes = block_on(tree_nodes).into_iter().map(|dig| H::Digest::from(&dig)).collect(); @@ -256,7 +247,7 @@ where event!( Level::INFO, "Computed constraint evaluation commitment on the GPU (Merkle tree of depth {}) in {} ms", - constraint_commitment.tree_depth(), + lde_domain_size.ilog2(), now.elapsed().as_millis() ); (constraint_commitment, composition_poly) @@ -352,13 +343,14 @@ impl< } } -impl< - E: FieldElement, - H: Hasher + ElementHasher, - D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, - > TraceLde for MetalTraceLde +impl TraceLde for MetalTraceLde +where + E: FieldElement, + H: Hasher + ElementHasher, + D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, { type HashFn = H; + type VC = MerkleTree; /// Returns the commitment to the low-degree extension of the main trace segment. fn get_main_trace_commitment(&self) -> D { @@ -535,34 +527,16 @@ fn build_trace_commitment< let lde_segments = FrozenVec::new(); let lde_domain_size = domain.lde_domain_size(); let num_base_columns = trace.num_base_cols(); - let rpo_requires_padding = num_base_columns % RATE != 0; - let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE); + let mut row_hasher = RowHasher::new(lde_domain_size, num_base_columns, hash_fn); - let mut rpo_padded_segment: Vec<[Felt; RATE]>; let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain); - let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate(); - for (segment_idx, segment) in &mut lde_segment_iter { + for segment in lde_segment_generator.gen_segment_iter() { let segment = lde_segments.push_get(Box::new(segment)); - // check if the segment requires padding - if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) { - // duplicate and modify the last segment with Rpo256's padding - // rule ("1" followed by "0"s). Our segments are already - // padded with "0"s we only need to add the "1"s. - rpo_padded_segment = unsafe { page_aligned_uninit_vector(lde_domain_size) }; - rpo_padded_segment.copy_from_slice(segment); - // skip this in case of Rpx - if hash_fn == HashFn::Rpo256 { - let rpo_pad_column = num_base_columns % RATE; - rpo_padded_segment.iter_mut().for_each(|row| row[rpo_pad_column] = ONE); - } - row_hasher.update(&rpo_padded_segment); - assert!(lde_segment_iter.next().is_none(), "padded segment should be the last"); - break; - } row_hasher.update(segment); } let row_hashes = block_on(row_hasher.finish()); let tree_nodes = build_merkle_tree(&row_hashes, hash_fn); + // aggregate segments at the same time as the GPU generates the merkle tree nodes let lde_segments = lde_segments.into_vec().into_iter().map(|p| *p).collect(); let trace_lde = RowMatrix::from_segments(lde_segments, num_base_columns); @@ -662,25 +636,27 @@ where } } -fn build_segment_queries< - E: FieldElement, - H: Hasher + ElementHasher, ->( +fn build_segment_queries( segment_lde: &RowMatrix, - segment_tree: &MerkleTree, + segment_vector_com: &V, positions: &[usize], -) -> Queries { +) -> Queries +where + E: FieldElement, + H: ElementHasher, + V: VectorCommitment, +{ // for each position, get the corresponding row from the trace segment LDE and put all these // rows into a single vector let trace_states = positions.iter().map(|&pos| segment_lde.row(pos).to_vec()).collect::>(); - // build Merkle authentication paths to the leaves specified by positions - let trace_proof = segment_tree - .prove_batch(positions) - .expect("failed to generate a Merkle proof for trace queries"); + // build a batch opening proof to the leaves specified by positions + let trace_proof = segment_vector_com + .open_many(positions) + .expect("failed to generate a batch opening proof for trace queries"); - Queries::new(trace_proof, trace_states) + Queries::new::(trace_proof.1, trace_states) } struct SegmentIterator<'a, 'b, E, I, const N: usize>(&'b mut SegmentGenerator<'a, E, I, N>) @@ -688,7 +664,7 @@ where E: FieldElement, I: IntoIterator>; -impl<'a, 'b, E, I, const N: usize> Iterator for SegmentIterator<'a, 'b, E, I, N> +impl Iterator for SegmentIterator<'_, '_, E, I, N> where E: FieldElement, I: IntoIterator>, diff --git a/prover/src/gpu/metal/tests.rs b/prover/src/gpu/metal/tests.rs index 6e79ec98f9..533169d49f 100644 --- a/prover/src/gpu/metal/tests.rs +++ b/prover/src/gpu/metal/tests.rs @@ -1,9 +1,9 @@ use alloc::vec::Vec; -use air::{ProvingOptions, StarkField}; -use gpu::metal::{MetalExecutionProver, DIGEST_SIZE, RATE}; +use air::{PartitionOptions, ProvingOptions, StarkField}; +use gpu::metal::{MetalExecutionProver, DIGEST_SIZE}; use processor::{ - crypto::{Hasher, RpoDigest, RpoRandomCoin, Rpx256, RpxDigest, RpxRandomCoin}, + crypto::{Hasher, Rpo256, RpoDigest, RpoRandomCoin, Rpx256, RpxDigest}, math::fft, StackInputs, StackOutputs, }; @@ -11,28 +11,103 @@ use winter_prover::{crypto::Digest, math::fields::CubeExtension, CompositionPoly use crate::*; +const RATE: usize = Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start; + type CubeFelt = CubeExtension; -fn build_trace_commitment_on_gpu_with_padding_matches_cpu< +// TESTS +// ================================================================================================ + +#[test] +fn rpo_build_trace_commitment_on_gpu_with_padding_matches_cpu() { + build_trace_commitment_on_gpu_with_padding_matches_cpu::( + HashFn::Rpo256, + ); +} + +#[test] +fn rpx_build_trace_commitment_on_gpu_with_padding_matches_cpu() { + build_trace_commitment_on_gpu_with_padding_matches_cpu::( + HashFn::Rpx256, + ); +} + +#[test] +fn rpo_build_trace_commitment_on_gpu_without_padding_matches_cpu() { + build_trace_commitment_on_gpu_without_padding_matches_cpu::( + HashFn::Rpo256, + ); +} + +#[test] +fn rpx_build_trace_commitment_on_gpu_without_padding_matches_cpu() { + build_trace_commitment_on_gpu_without_padding_matches_cpu::( + HashFn::Rpx256, + ); +} + +#[test] +fn rpo_build_constraint_commitment_on_gpu_with_padding_matches_cpu() { + build_constraint_commitment_on_gpu_with_padding_matches_cpu::( + HashFn::Rpo256, + ); +} + +#[test] +fn rpx_build_constraint_commitment_on_gpu_with_padding_matches_cpu() { + build_constraint_commitment_on_gpu_with_padding_matches_cpu::( + HashFn::Rpx256, + ); +} + +#[test] +fn rpo_build_constraint_commitment_on_gpu_without_padding_matches_cpu() { + build_constraint_commitment_on_gpu_without_padding_matches_cpu::< + RpoRandomCoin, + Rpo256, + RpoDigest, + >(HashFn::Rpo256); +} + +#[test] +fn rpx_build_constraint_commitment_on_gpu_without_padding_matches_cpu() { + build_constraint_commitment_on_gpu_without_padding_matches_cpu::< + RpxRandomCoin, + Rpx256, + RpxDigest, + >(HashFn::Rpx256); +} + +// TEST FUNCTIONS +// ================================================================================================ + +fn build_trace_commitment_on_gpu_with_padding_matches_cpu(hash_fn: HashFn) +where R: RandomCoin + Send, - H: ElementHasher + Hasher, + H: ElementHasher + Hasher + Sync, D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, ->( - hash_fn: HashFn, -) { +{ let is_rpx = matches!(hash_fn, HashFn::Rpx256); - let cpu_prover = create_test_prover::(is_rpx); - let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); + let cpu_prover = create_test_prover::(is_rpx); + let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); let num_rows = 1 << 8; let trace_info = get_trace_info(1, num_rows); let trace = gen_random_trace(num_rows, RATE + 1); let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR); - let (cpu_trace_lde, cpu_polys) = - cpu_prover.new_trace_lde::(&trace_info, &trace, &domain); - let (gpu_trace_lde, gpu_polys) = - gpu_prover.new_trace_lde::(&trace_info, &trace, &domain); + let (cpu_trace_lde, cpu_polys) = cpu_prover.new_trace_lde::( + &trace_info, + &trace, + &domain, + PartitionOptions::default(), + ); + let (gpu_trace_lde, gpu_polys) = gpu_prover.new_trace_lde::( + &trace_info, + &trace, + &domain, + PartitionOptions::default(), + ); assert_eq!( cpu_trace_lde.get_main_trace_commitment(), @@ -44,26 +119,33 @@ fn build_trace_commitment_on_gpu_with_padding_matches_cpu< ); } -fn build_trace_commitment_on_gpu_without_padding_matches_cpu< +fn build_trace_commitment_on_gpu_without_padding_matches_cpu(hash_fn: HashFn) +where R: RandomCoin + Send, - H: ElementHasher + Hasher, + H: ElementHasher + Hasher + Sync, D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, ->( - hash_fn: HashFn, -) { +{ let is_rpx = matches!(hash_fn, HashFn::Rpx256); - let cpu_prover = create_test_prover::(is_rpx); - let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); + let cpu_prover = create_test_prover::(is_rpx); + let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); let num_rows = 1 << 8; let trace_info = get_trace_info(1, num_rows); let trace = gen_random_trace(num_rows, RATE); let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR); - let (cpu_trace_lde, cpu_polys) = - cpu_prover.new_trace_lde::(&trace_info, &trace, &domain); - let (gpu_trace_lde, gpu_polys) = - gpu_prover.new_trace_lde::(&trace_info, &trace, &domain); + let (cpu_trace_lde, cpu_polys) = cpu_prover.new_trace_lde::( + &trace_info, + &trace, + &domain, + PartitionOptions::default(), + ); + let (gpu_trace_lde, gpu_polys) = gpu_prover.new_trace_lde::( + &trace_info, + &trace, + &domain, + PartitionOptions::default(), + ); assert_eq!( cpu_trace_lde.get_main_trace_commitment(), @@ -75,17 +157,16 @@ fn build_trace_commitment_on_gpu_without_padding_matches_cpu< ); } -fn build_constraint_commitment_on_gpu_with_padding_matches_cpu< +fn build_constraint_commitment_on_gpu_with_padding_matches_cpu(hash_fn: HashFn) +where R: RandomCoin + Send, - H: ElementHasher + Hasher, + H: ElementHasher + Hasher + Sync, D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, ->( - hash_fn: HashFn, -) { +{ let is_rpx = matches!(hash_fn, HashFn::Rpx256); - let cpu_prover = create_test_prover::(is_rpx); - let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); + let cpu_prover = create_test_prover::(is_rpx); + let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); let num_rows = 1 << 8; let ce_blowup_factor = 2; let values = get_random_values::(num_rows * ce_blowup_factor); @@ -99,22 +180,21 @@ fn build_constraint_commitment_on_gpu_with_padding_matches_cpu< let (commitment_gpu, composition_poly_gpu) = gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 2, &domain); - assert_eq!(commitment_cpu.root(), commitment_gpu.root()); + assert_eq!(commitment_cpu.commitment(), commitment_gpu.commitment()); assert_ne!(0, composition_poly_cpu.data().num_base_cols() % RATE); assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns()); } -fn build_constraint_commitment_on_gpu_without_padding_matches_cpu< +fn build_constraint_commitment_on_gpu_without_padding_matches_cpu(hash_fn: HashFn) +where R: RandomCoin + Send, - H: ElementHasher + Hasher, + H: ElementHasher + Hasher + Sync, D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, ->( - hash_fn: HashFn, -) { +{ let is_rpx = matches!(hash_fn, HashFn::Rpx256); - let cpu_prover = create_test_prover::(is_rpx); - let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); + let cpu_prover = create_test_prover::(is_rpx); + let gpu_prover = MetalExecutionProver::new(create_test_prover::(is_rpx), hash_fn); let num_rows = 1 << 8; let ce_blowup_factor = 8; let values = get_random_values::(num_rows * ce_blowup_factor); @@ -128,70 +208,13 @@ fn build_constraint_commitment_on_gpu_without_padding_matches_cpu< let (commitment_gpu, composition_poly_gpu) = gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 8, &domain); - assert_eq!(commitment_cpu.root(), commitment_gpu.root()); + assert_eq!(commitment_cpu.commitment(), commitment_gpu.commitment()); assert_eq!(0, composition_poly_cpu.data().num_base_cols() % RATE); assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns()); } -#[test] -fn rpo_build_trace_commitment_on_gpu_with_padding_matches_cpu() { - build_trace_commitment_on_gpu_with_padding_matches_cpu::( - HashFn::Rpo256, - ); -} - -#[test] -fn rpx_build_trace_commitment_on_gpu_with_padding_matches_cpu() { - build_trace_commitment_on_gpu_with_padding_matches_cpu::( - HashFn::Rpx256, - ); -} - -#[test] -fn rpo_build_trace_commitment_on_gpu_without_padding_matches_cpu() { - build_trace_commitment_on_gpu_without_padding_matches_cpu::( - HashFn::Rpo256, - ); -} - -#[test] -fn rpx_build_trace_commitment_on_gpu_without_padding_matches_cpu() { - build_trace_commitment_on_gpu_without_padding_matches_cpu::( - HashFn::Rpx256, - ); -} - -#[test] -fn rpo_build_constraint_commitment_on_gpu_with_padding_matches_cpu() { - build_constraint_commitment_on_gpu_with_padding_matches_cpu::( - HashFn::Rpo256, - ); -} - -#[test] -fn rpx_build_constraint_commitment_on_gpu_with_padding_matches_cpu() { - build_constraint_commitment_on_gpu_with_padding_matches_cpu::( - HashFn::Rpx256, - ); -} - -#[test] -fn rpo_build_constraint_commitment_on_gpu_without_padding_matches_cpu() { - build_constraint_commitment_on_gpu_without_padding_matches_cpu::< - RpoRandomCoin, - Rpo256, - RpoDigest, - >(HashFn::Rpo256); -} - -#[test] -fn rpx_build_constraint_commitment_on_gpu_without_padding_matches_cpu() { - build_constraint_commitment_on_gpu_without_padding_matches_cpu::< - RpxRandomCoin, - Rpx256, - RpxDigest, - >(HashFn::Rpx256); -} +// HELPER FUNCTIONS +// ================================================================================================ fn gen_random_trace(num_rows: usize, num_cols: usize) -> ColMatrix { ColMatrix::new((0..num_cols as u64).map(|col| vec![Felt::new(col); num_rows]).collect()) @@ -205,12 +228,11 @@ fn get_trace_info(num_cols: usize, num_rows: usize) -> TraceInfo { TraceInfo::new(num_cols, num_rows) } -fn create_test_prover< +fn create_test_prover(use_rpx: bool) -> ExecutionProver +where + H: ElementHasher + Sync, R: RandomCoin + Send, - H: ElementHasher, ->( - use_rpx: bool, -) -> ExecutionProver { +{ if use_rpx { ExecutionProver::new( ProvingOptions::with_128_bit_security_rpx(), diff --git a/prover/src/lib.rs b/prover/src/lib.rs index b8bee06696..2ba063bc2d 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -8,7 +8,7 @@ extern crate std; use core::marker::PhantomData; -use air::{AuxRandElements, ProcessorAir, PublicInputs}; +use air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs}; #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] use miden_gpu::HashFn; use processor::{ @@ -20,6 +20,7 @@ use processor::{ ExecutionTrace, Program, }; use tracing::instrument; +use winter_maybe_async::{maybe_async, maybe_await}; use winter_prover::{ matrix::ColMatrix, ConstraintCompositionCoefficients, DefaultConstraintEvaluator, DefaultTraceLde, ProofOptions as WinterProofOptions, Prover, StarkDomain, TraceInfo, @@ -37,7 +38,7 @@ pub use processor::{ crypto, math, utils, AdviceInputs, Digest, ExecutionError, Host, InputError, MemAdviceProvider, StackInputs, StackOutputs, Word, }; -pub use winter_prover::Proof; +pub use winter_prover::{crypto::MerkleTree as MerkleTreeVC, Proof}; // PROVER // ================================================================================================ @@ -45,13 +46,15 @@ pub use winter_prover::Proof; /// Executes and proves the specified `program` and returns the result together with a STARK-based /// proof of the program's execution. /// -/// * `inputs` specifies the initial state of the stack as well as non-deterministic (secret) inputs -/// for the VM. -/// * `options` defines parameters for STARK proof generation. +/// - `stack_inputs` specifies the initial state of the stack for the VM. +/// - `host` specifies the host environment which contain non-deterministic (secret) inputs for the +/// prover +/// - `options` defines parameters for STARK proof generation. /// /// # Errors /// Returns an error if program execution or STARK proof generation fails for any reason. #[instrument("prove_program", skip_all)] +#[maybe_async] pub fn prove( program: &Program, stack_inputs: StackInputs, @@ -81,18 +84,22 @@ where // generate STARK proof let proof = match hash_fn { - HashFunction::Blake3_192 => ExecutionProver::>::new( - options, - stack_inputs, - stack_outputs.clone(), - ) - .prove(trace), - HashFunction::Blake3_256 => ExecutionProver::>::new( - options, - stack_inputs, - stack_outputs.clone(), - ) - .prove(trace), + HashFunction::Blake3_192 => { + let prover = ExecutionProver::>::new( + options, + stack_inputs, + stack_outputs.clone(), + ); + maybe_await!(prover.prove(trace)) + }, + HashFunction::Blake3_256 => { + let prover = ExecutionProver::>::new( + options, + stack_inputs, + stack_outputs.clone(), + ); + maybe_await!(prover.prove(trace)) + }, HashFunction::Rpo256 => { let prover = ExecutionProver::::new( options, @@ -101,7 +108,7 @@ where ); #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256); - prover.prove(trace) + maybe_await!(prover.prove(trace)) }, HashFunction::Rpx256 => { let prover = ExecutionProver::::new( @@ -111,7 +118,7 @@ where ); #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256); - prover.prove(trace) + maybe_await!(prover.prove(trace)) }, } .map_err(ExecutionError::ProverError)?; @@ -158,7 +165,6 @@ where /// Validates the stack inputs against the provided execution trace and returns true if valid. fn are_inputs_valid(&self, trace: &ExecutionTrace) -> bool { self.stack_inputs - .values() .iter() .zip(trace.init_stack_state().iter()) .all(|(l, r)| l == r) @@ -167,7 +173,6 @@ where /// Validates the stack outputs against the provided execution trace and returns true if valid. fn are_outputs_valid(&self, trace: &ExecutionTrace) -> bool { self.stack_outputs - .stack_top() .iter() .zip(trace.last_stack_state().iter()) .all(|(l, r)| l == r) @@ -176,15 +181,16 @@ where impl Prover for ExecutionProver where - H: ElementHasher, + H: ElementHasher + Sync, R: RandomCoin + Send, { type BaseField = Felt; type Air = ProcessorAir; type Trace = ExecutionTrace; type HashFn = H; + type VC = MerkleTreeVC; type RandomCoin = R; - type TraceLde> = DefaultTraceLde; + type TraceLde> = DefaultTraceLde; type ConstraintEvaluator<'a, E: FieldElement> = DefaultConstraintEvaluator<'a, ProcessorAir, E>; @@ -207,15 +213,18 @@ where PublicInputs::new(program_info, self.stack_inputs.clone(), self.stack_outputs.clone()) } + #[maybe_async] fn new_trace_lde>( &self, trace_info: &TraceInfo, main_trace: &ColMatrix, domain: &StarkDomain, + partition_options: PartitionOptions, ) -> (Self::TraceLde, TracePolyTable) { - DefaultTraceLde::new(trace_info, main_trace, domain) + DefaultTraceLde::new(trace_info, main_trace, domain, partition_options) } + #[maybe_async] fn new_evaluator<'a, E: FieldElement>( &self, air: &'a ProcessorAir, @@ -225,14 +234,13 @@ where DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) } - fn build_aux_trace( + #[instrument(skip_all)] + #[maybe_async] + fn build_aux_trace>( &self, trace: &Self::Trace, aux_rand_elements: &AuxRandElements, - ) -> ColMatrix - where - E: FieldElement, - { + ) -> ColMatrix { trace.build_aux_trace(aux_rand_elements.rand_elements()).unwrap() } } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 11ec1f8414..a1c01e0415 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.80" +channel = "1.82" components = ["rustfmt", "rust-src", "clippy"] targets = ["wasm32-unknown-unknown"] profile = "minimal" diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 170fed6a81..07e24b43dc 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-stdlib" -version = "0.10.5" +version = "0.11.0" description = "Miden VM standard library" -documentation = "https://docs.rs/miden-stdlib/0.10.5" +documentation = "https://docs.rs/miden-stdlib/0.11.0" readme = "README.md" categories = ["cryptography", "mathematics"] keywords = ["miden", "program", "stdlib"] @@ -31,26 +31,26 @@ std = ["assembly/std"] with-debug-info = [] [dependencies] -assembly = { package = "miden-assembly", path = "../assembly", version = "0.10", default-features = false } +assembly = { package = "miden-assembly", path = "../assembly", version = "0.11", default-features = false } [dev-dependencies] blake3 = "1.5" criterion = "0.5" -miden-air = { package = "miden-air", path = "../air", version = "0.10", default-features = false } -num = "0.4.1" +miden-air = { package = "miden-air", path = "../air", version = "0.11", default-features = false } +num = "0.4" num-bigint = "0.4" pretty_assertions = "1.4" -processor = { package = "miden-processor", path = "../processor", version = "0.10", default-features = false, features = [ +processor = { package = "miden-processor", path = "../processor", version = "0.11", default-features = false, features = [ "testing", ] } -rand = { version = "0.8.5", default-features = false } +rand = { version = "0.8", default-features = false } serde_json = "1.0" sha2 = "0.10" sha3 = "0.10" test-utils = { package = "miden-test-utils", path = "../test-utils" } -winter-air = { package = "winter-air", version = "0.9" } -winter-fri = { package = "winter-fri", version = "0.9" } +winter-air = { package = "winter-air", version = "0.10" } +winter-fri = { package = "winter-fri", version = "0.10" } [build-dependencies] -assembly = { package = "miden-assembly", path = "../assembly", version = "0.10" } +assembly = { package = "miden-assembly", path = "../assembly", version = "0.11" } diff --git a/stdlib/asm/collections/mmr.masm b/stdlib/asm/collections/mmr.masm index 3f50adfb09..4ed948fc62 100644 --- a/stdlib/asm/collections/mmr.masm +++ b/stdlib/asm/collections/mmr.masm @@ -1,5 +1,5 @@ use.std::mem -use.std::crypto::hashes::native +use.std::crypto::hashes::rpo use.std::math::u64 #! Loads the leaf at the absolute `pos` in the MMR. @@ -165,7 +165,7 @@ export.unpack # => [C, B, A, mmr_ptr+17, HASH, ...] # drop anything but the hash result, word B (11 cycles) - dropw swapw dropw movup.4 drop + exec.rpo::squeeze_digest movup.4 drop # => [B, HASH, ...] # assert on the resulting hash (11 cycles) @@ -196,8 +196,8 @@ export.pack # hash the memory contents (25 + 3 * num_peaks) padw padw padw - exec.native::hash_memory_even - exec.native::state_to_digest + exec.rpo::absorb_double_words_from_memory + exec.rpo::squeeze_digest # => [HASH, peaks_end, peaks_end, mmr_ptr, ...] # prepare stack for adv.insert_mem (4 cycles) diff --git a/stdlib/asm/crypto/dsa/rpo_falcon512.masm b/stdlib/asm/crypto/dsa/rpo_falcon512.masm index 5cb418dce3..637f490bf9 100644 --- a/stdlib/asm/crypto/dsa/rpo_falcon512.masm +++ b/stdlib/asm/crypto/dsa/rpo_falcon512.masm @@ -1,3 +1,5 @@ +use.std::crypto::hashes::rpo + # CONSTANTS # ================================================================================================= @@ -224,7 +226,7 @@ export.load_h_s2_and_product.1 end # 6) Return the challenge point and the incremented pointer - dropw swapw dropw + exec.rpo::squeeze_digest drop drop #=> [tau1, tau0, ptr + 512] end diff --git a/stdlib/asm/crypto/hashes/native.masm b/stdlib/asm/crypto/hashes/native.masm deleted file mode 100644 index 1cd8ac3355..0000000000 --- a/stdlib/asm/crypto/hashes/native.masm +++ /dev/null @@ -1,102 +0,0 @@ -#! Prepares the top of the stack with the hasher initial state. -#! -#! This procedures does not handle padding, therefore, the user is expected to -#! consume an amount of data which is a multiple of the rate (2 words). -#! -#! Input: [] -#! Ouptut: [PERM, PERM, PERM, ...] -#! Cycles: 12 -export.init_no_padding - padw padw padw -end - -#! Given the hasher state, returns the hash output -#! -#! Input: [C, B, A, ...] -#! Ouptut: [HASH, ...] -#! Where: For the native RPO hasher HASH is B. -#! Cycles: 9 -export.state_to_digest - # drop the first rate word (4 cycles) - dropw - - # save the hash result (1 cycles) - swapw - - # drop the capacity word (4 cycles) - dropw -end - -#! Hashes the memory `start_addr` to `end_addr`. -#! -#! This requires that `end_addr=start_addr + 2n + 1`, otherwise the procedure will enter an infinite -#! loop. `end_addr` is not inclusive. -#! -#! Stack transition: -#! Input: [C, B, A, start_addr, end_addr, ...] -#! Output: [C', B', A', end_addr, end_addr ...] -#! Cycles: 4 + 3 * words, where `words` is the `start_addr - end_addr - 1` -#! -#! Where `A` is the capacity word that will be used by the hashing function, and `B'` the hash output. -export.hash_memory_even - dup.13 dup.13 neq # (4 cycles ) - while.true - mem_stream hperm # (2 cycles) - dup.13 dup.13 neq # (4 cycles ) - end -end - -#! Hashes the memory `start_addr` to `end_addr`, handles odd number of elements. -#! -#! Requires `start_addr < end_addr`, `end_addr` is not inclusive. -#! -#! Stack transition: -#! Input: [start_addr, end_addr, ...] -#! Output: [H, ...] -#! Cycles: -#! even words: 48 cycles + 3 * words -#! odd words: 60 cycles + 3 * words -export.hash_memory - # enforce `start_addr < end_addr` - dup.1 dup.1 u32assert2 u32gt assert - - # figure out if the range is for an odd number of words (9 cycles) - dup.1 dup.1 sub is_odd - # stack: [is_odd, start_addr, end_addr, ...] - - # make the start/end range even (4 cycles) - movup.2 dup.1 sub - # stack: [end_addr, is_odd, start_addr, ...] - - # move start_addr to the right stack position (1 cycles) - movup.2 - # stack: [start_addr, end_addr, is_odd, ...] - - # prepare hasher state (12 cycles) - dup.2 push.0.0.0 padw padw - # stack: [C, B, A, start_addr, end_addr, is_odd, ...] - - # (4 + 3 * words cycles) - exec.hash_memory_even - - # (1 cycles) - movup.14 - - # handle the odd element, if any (12 cycles) - if.true - # start_addr and end_addr are equal after calling `hash_memory_even`, and both point - # to the last element. Load the last word (2 cycles) - dup.13 mem_loadw - - # set the padding (9 cycles) - swapw dropw push.1.0.0.0 - - # (1 cycles) - hperm - end - - exec.state_to_digest - - # drop start_addr/end_addr (4 cycles) - movup.4 drop movup.4 drop -end diff --git a/stdlib/asm/crypto/hashes/rpo.masm b/stdlib/asm/crypto/hashes/rpo.masm new file mode 100644 index 0000000000..1035fa464d --- /dev/null +++ b/stdlib/asm/crypto/hashes/rpo.masm @@ -0,0 +1,266 @@ +#! Prepares the top of the stack with the hasher initial state. +#! +#! This procedures does not handle padding, therefore, the user is expected to +#! consume an amount of data which is a multiple of the rate (2 words). +#! +#! Input: [] +#! Ouptut: [PERM, PERM, PERM, ...] +#! +#! Cycles: 12 +export.init_no_padding + padw padw padw +end + +#! Given the hasher state, returns the hash output. +#! +#! Input: [C, B, A, ...] +#! Ouptut: [HASH, ...] +#! +#! Where : +#! - `A` is the capacity word that will be used by the hashing function. +#! - `B` is the hash output. +#! +#! Cycles: 9 +export.squeeze_digest + # drop the first rate word (4 cycles) + dropw + + # save the hash result (1 cycles) + swapw + + # drop the capacity word (4 cycles) + dropw +end + +#! Hashes the memory `start_addr` to `end_addr` given an RPO state specified by 3 words. +#! +#! This requires that `end_addr = start_addr + 2n` where n = {0, 1, 2 ...}, otherwise the procedure +#! will enter an infinite loop. +#! +#! Input: [C, B, A, start_addr, end_addr, ...] +#! Output: [C', B', A', end_addr, end_addr ...] +#! +#! Where : +#! - `A` is the capacity word that will be used by the hashing function. +#! - `B` is the hash output. +#! +#! Cycles: 4 + 3 * words, where `words` is the `start_addr - end_addr` +export.absorb_double_words_from_memory + dup.13 dup.13 neq # (4 cycles ) + while.true + mem_stream hperm # (2 cycles) + dup.13 dup.13 neq # (4 cycles ) + end +end + +#! Hashes the memory `start_addr` to `end_addr`, handles odd number of elements. +#! +#! Requires `start_addr ≤ end_addr`, `end_addr` is not inclusive. +#! +#! Input: [start_addr, end_addr, ...] +#! Output: [H, ...] +#! +#! Cycles: +#! - even words: 49 cycles + 3 * words +#! - odd words: 61 cycles + 3 * words +#! where `words` is the `start_addr - end_addr - 1` +export.hash_memory_words + # enforce `start_addr ≤ end_addr` + dup.1 dup.1 u32assert2 u32gte assert + + # figure out if the range is for an odd number of words (9 cycles) + dup.1 dup.1 sub is_odd + # => [is_odd, start_addr, end_addr, ...] + + # make the start/end range even (4 cycles) + movup.2 dup.1 sub + # => [end_addr, is_odd, start_addr, ...] + + # move start_addr to the right stack position (1 cycles) + movup.2 + # => [start_addr, end_addr, is_odd, ...] + + # prepare hasher state (14 cycles) + dup.2 mul.4 push.0.0.0 padw padw + # => [C, B, A, start_addr, end_addr, is_odd, ...] + + # (4 + 3 * words cycles) + exec.absorb_double_words_from_memory + # => [C', B', A', end_addr, end_addr, is_odd, ...] + + # (1 cycles) + movup.14 + # => [is_odd, C', B', A', end_addr, end_addr, ...] + + # handle the odd element, if any (12 cycles) + if.true + # start_addr and end_addr are equal after calling `absorb_double_words_from_memory`, and both + # point to the last element. Load the last word (6 cycles) + dropw dup.9 mem_loadw + # => [D, A', end_addr, end_addr, ...] + + # set the padding and compute the permutation (5 cycles) + padw hperm + end + + exec.squeeze_digest + # => [HASH, end_addr, end_addr, ...] + + # drop start_addr/end_addr (4 cycles) + movup.4 drop movup.4 drop + # => [HASH] +end + +#! Computes hash of Felt values starting at the specified memory address. +#! +#! This procedure divides the hashing process into two parts: hashing pairs of words using +#! `absorb_double_words_from_memory` procedure and hashing the remaining values using the `hperm` +#! instruction. +#! +#! Inputs: [ptr, num_elements] +#! Outputs: [HASH] +#! +#! Cycles: +#! - If number of elements divides by 8: 47 cycles + 3 * words +#! - Else: 180 cycles + 3 * words +#! where `words` is the number of quads of input values. +export.hash_memory + # move number of inputs to the top of the stack + swap + # => [num_elements, ptr] + + # get the number of double words + u32divmod.8 swap + # => [num_elements/8, num_elements%8, ptr] + + # get the end_addr for hash_memory_even procedure (end address for pairs of words) + mul.2 dup.2 add movup.2 + # => [ptr, end_addr, num_elements%8] + + # get the capacity element which is equal to num_elements%8 + dup.2 + # => [capacity, ptr, end_addr, num_elements%8] + + # prepare hasher state for RPO permutation + push.0.0.0 padw padw + # => [C, B, A, ptr, end_addr, num_elements%8] + + # hash every pair of words + exec.absorb_double_words_from_memory + # => [C', B', A', ptr', end_addr, num_elements%8] where ptr' = end_addr + + # hash remaining input values if there are any left + # if num_elements%8 is ZERO and there are no elements to hash + dup.14 eq.0 + if.true + # clean the stack + exec.squeeze_digest + swapw drop drop drop movdn.4 + # => [B'] + else + # load the remaining double word + mem_stream + # => [E, D, A', ptr'+2, end_addr, num_elements%8] + + # clean the stack + movup.12 drop movup.12 drop + # => [E, D, A', num_elements%8] + + # get the number of elements we need to drop + # notice that drop_counter could be any number from 1 to 7 + push.8 movup.13 sub movdn.12 + # => [E, D, A', drop_counter] + + ### 0th value ######################################################## + + # we need to drop first value anyway, since number of values is not divisible by 8 + # push the padding 0 on to the stack and move it down to the 6th position + drop push.0 movdn.6 + # => [e_2, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', drop_counter] + + ### 1st value ######################################################## + + # prepare the second element of the E Word for cdrop instruction + push.0 swap + # => [e_2, 0, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', drop_counter] + + # push latch variable onto the stack; this will be the control for the cdrop instruction + push.0 + # => [latch = 0, e_2, 0, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', drop_counter] + + # get the flag whether the drop counter is equal 1 + dup.14 eq.1 + # => [drop_counter == 1, latch = 0, e_2, 0, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', drop_counter] + + # update the latch: if drop_counter == 1, latch will become 1 + or + # => [latch', e_2, 0, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', drop_counter] + + # save the latch value + dup movdn.14 + # => [latch', e_2, 0, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', latch', drop_counter] + + # if latch == 1, drop 0; otherwise drop e_1 + cdrop + # => [e_2_or_0, e_1, e_0, d_3, d_2, d_1, 0, d_0, A', latch', drop_counter] + + # move the calculated value down the stack + movdn.6 + # => [e_1, e_0, d_3, d_2, d_1, 0, e_2_or_0, d_0, A', latch', drop_counter] + + ### 2nd value ######################################################## + + # repeat the above process but now compare drop_counter to 2 + push.0 swap + movup.13 dup.14 eq.2 or + dup movdn.14 + cdrop movdn.6 + # => [e_0, d_3, d_2, d_1, 0, e_2_or_0, e_1_or_0, d_0, A', latch', drop_counter] + + ### 3rd value ######################################################## + + # repeat the above process but now compare drop_counter to 3 + push.0 swap + movup.13 dup.14 eq.3 or + dup movdn.14 + cdrop movdn.6 + # => [d_3, d_2, d_1, 0, e_2_or_0, e_1_or_0, e_0_or_0, d_0, A', latch', drop_counter] + + ### 4th value ######################################################## + + # repeat the above process but now compare drop_counter to 4 + push.0 swap + movup.13 dup.14 eq.4 or + dup movdn.14 + cdrop movdn.6 + # => [d_2, d_1, 0, e_2_or_0, e_1_or_0, e_0_or_0, d_3_or_0, d_0, A', latch', drop_counter] + + ### 5th value ######################################################## + + # repeat the above process but now compare drop_counter to 5 + push.0 swap + movup.13 dup.14 eq.5 or + dup movdn.14 + cdrop movdn.6 + # => [d_1, 0, e_2_or_0, e_1_or_0, e_0_or_0, d_3_or_0, d_2_or_0, d_0, A', latch', drop_counter] + + ### 6th value ######################################################## + + # repeat the above process but now compare drop_counter to 6 + push.0 swap + movup.13 movup.14 eq.6 or + cdrop movdn.6 + # => [0, e_2_or_0, e_1_or_0, e_0_or_0, d_3_or_0, d_2_or_0, d_1_or_0, d_0, A'] + # or in other words + # => [C, B, A', ... ] + # notice that we don't need to check the d_0 value: entering the else branch means that + # we have number of elements not divisible by 8, so we will have at least one element to + # hash here (which turns out to be d_0) + + hperm + # => [F, E, D] + + exec.squeeze_digest + # => [E] + end +end diff --git a/stdlib/asm/crypto/stark/ood_frames.masm b/stdlib/asm/crypto/stark/ood_frames.masm index c6b9e50f26..bb3fabc1bc 100644 --- a/stdlib/asm/crypto/stark/ood_frames.masm +++ b/stdlib/asm/crypto/stark/ood_frames.masm @@ -1,4 +1,5 @@ use.std::crypto::stark::constants +use.std::crypto::hashes::rpo #! Loads OOD evaluation frame, with current and next rows interleaved, into memory. This ouputs @@ -105,7 +106,7 @@ export.load_constraint_evaluations hperm - dropw swapw dropw + exec.rpo::squeeze_digest end #! Computes the H(z) evaluation of the constraint composition polynomial at the OOD element z. diff --git a/stdlib/asm/crypto/stark/random_coin.masm b/stdlib/asm/crypto/stark/random_coin.masm index e66b024a40..07f431cb22 100644 --- a/stdlib/asm/crypto/stark/random_coin.masm +++ b/stdlib/asm/crypto/stark/random_coin.masm @@ -1,6 +1,6 @@ use.std::crypto::stark::constants use.std::crypto::stark::utils - +use.std::crypto::hashes::rpo #! Helper procedure to compute addition of two words component-wise. #! Input: [b3, b2, b1, b0, a3, a2, a1, a0] @@ -665,7 +665,7 @@ export.generate_list_indices exec.get_rate_2 hperm - dropw swapw dropw + exec.rpo::squeeze_digest #=> [R1, query_ptr, mask, depth, num_queries, ...] @@ -698,7 +698,7 @@ export.generate_list_indices exec.get_rate_2 hperm - dropw swapw dropw + exec.rpo::squeeze_digest #=> [R1, query_ptr, mask, depth, num_queries, ...] movup.7 sub.1 dup movdn.8 push.0 neq diff --git a/stdlib/asm/math/u64.masm b/stdlib/asm/math/u64.masm index 63cc2ae957..cc22854b9d 100644 --- a/stdlib/asm/math/u64.masm +++ b/stdlib/asm/math/u64.masm @@ -568,13 +568,11 @@ end #! error. #! Stack transition looks as follows: #! [b, a_hi, a_lo, ...] -> [c_hi, c_lo, ...], where c = a << b mod 2^64. -#! This takes 40 cycles. +#! This takes 44 cycles. export.rotr push.31 dup.1 - u32overflowing_sub - swap - drop + u32lt movdn.3 # Shift the low limb left by 32-b. @@ -582,17 +580,19 @@ export.rotr u32and push.32 swap - u32overflowing_sub - drop + u32wrapping_sub pow2 dup movup.3 - u32overflowing_mul + mul + u32split # Shift the high limb left by 32-b. movup.3 movup.3 - u32overflowing_madd + mul + add + u32split # Carry the overflow shift to the low bits. movup.2 @@ -609,7 +609,7 @@ end #! The input value is assumed to be represented using 32 bit limbs, but this is not checked. #! Stack transition looks as follows: #! [n_hi, n_lo, ...] -> [clz, ...], where clz is a number of leading zeros of value n. -#! This takes 43 cycles. +#! This takes 48 cycles. export.clz dup.0 eq.0 @@ -618,7 +618,7 @@ export.clz drop u32clz add.32 # clz(n_lo) + 32 - else + else swap drop u32clz # clz(n_hi) @@ -639,7 +639,7 @@ export.ctz drop u32ctz add.32 # ctz(n_hi) + 32 - else + else swap drop u32ctz # ctz(n_lo) @@ -650,7 +650,7 @@ end #! The input value is assumed to be represented using 32 bit limbs, but this is not checked. #! Stack transition looks as follows: #! [n_hi, n_lo, ...] -> [clo, ...], where clo is a number of leading ones of value n. -#! This takes 42 cycles. +#! This takes 47 cycles. export.clo dup.0 eq.4294967295 @@ -659,7 +659,7 @@ export.clo drop u32clo add.32 # clo(n_lo) + 32 - else + else swap drop u32clo # clo(n_hi) @@ -680,7 +680,7 @@ export.cto drop u32cto add.32 # cto(n_hi) + 32 - else + else swap drop u32cto # ctz(n_lo) diff --git a/stdlib/asm/mem.masm b/stdlib/asm/mem.masm index f81508faf4..b420ab944a 100644 --- a/stdlib/asm/mem.masm +++ b/stdlib/asm/mem.masm @@ -1,3 +1,5 @@ +use.std::crypto::hashes::rpo + # ===== MEMORY FUNCTIONS ========================================================================== #! Copies `n` words from `read_ptr` to `write_ptr`. @@ -82,10 +84,10 @@ end #! Copies an arbitrary number of words from the advice stack to memory #! #! Input: [num_words, write_ptr, ...] -#! Output: [HASH, write_ptr', ...] +#! Output: [C, B, A, write_ptr', ...] #! Cycles: -#! even num_words: 48 + 9 * num_words / 2 -#! odd num_words: 65 + 9 * round_down(num_words / 2) +#! even num_words: 41 + 9 * num_words / 2 +#! odd num_words: 58 + 9 * round_down(num_words / 2) export.pipe_words_to_memory.0 # check if there is an odd number of words (6 cycles) dup is_odd @@ -99,10 +101,11 @@ export.pipe_words_to_memory.0 sub dup.1 add swap # => [write_ptr, end_ptr, needs_padding, ...] - # Prepare the capacity word. For rescue prime optimized the first element is - # set to `1` when padding is used and `0` otherwse, this is determined by the - # `needs_padding` flag. (4 cycles) - dup.2 push.0.0.0 + # Prepare the capacity word. We use the padding rule which sets the first capacity + # element to `len % 8` where `len` is the length of the hashed sequence. Since `len % 8` + # is either equal to 0 or 4, this is determined by the `needs_padding` flag multiplied + # by 4. (6 cycles) + dup.2 mul.4 push.0.0.0 # => [A, write_ptr, end_ptr, needs_padding, ...] # set initial hasher state (8 cycles) @@ -141,17 +144,13 @@ export.pipe_words_to_memory.0 # => [B', A, write_ptr+1, ...] # Push padding word (4 cycles) - push.1.0.0.0 + padw # => [C, B', A, write_ptr+1, ...] # Run RPO permutation (1 cycles) hperm # => [C', B', A', write_ptr+1, ...] end - - # The RPO result is word B, discard the unused portion of the rate and the capacity. (9 cycles) - dropw swapw dropw - # => [rpo_hash, write_ptr', ...] end #! Moves an arbitrary number of words from the advice stack to memory and asserts it matches the commitment. @@ -159,12 +158,16 @@ end #! Input: [num_words, write_ptr, COM, ...] #! Output: [write_ptr', ...] #! Cycles: -#! even num_words: 58 + 9 * num_words / 2 -#! odd num_words: 75 + 9 * round_down(num_words / 2) +#! even num_words: 62 + 9 * num_words / 2 +#! odd num_words: 79 + 9 * round_down(num_words / 2) export.pipe_preimage_to_memory.0 # Copies the advice stack data to memory exec.pipe_words_to_memory - # => [HASH, write_ptr', COM, ...] + # => [C, B, A, write_ptr', COM, ...] + + # Leave only the digest on the stack + exec.rpo::squeeze_digest + # => [B, write_ptr', COM, ...] # Save the write_ptr (2 cycles) movup.4 movdn.8 diff --git a/stdlib/docs/crypto/hashes/native.md b/stdlib/docs/crypto/hashes/native.md deleted file mode 100644 index db8dc41628..0000000000 --- a/stdlib/docs/crypto/hashes/native.md +++ /dev/null @@ -1,7 +0,0 @@ -Prepares the top of the stack with the hasher initial state.

This procedures does not handle padding, therefore, the user is expected to
consume an amount of data which is a multiple of the rate (2 words).

Input: []
Ouptut: [PERM, PERM, PERM, ...]
Cycles: 12
-## std::crypto::hashes::native -| Procedure | Description | -| ----------- | ------------- | -| state_to_digest | Given the hasher state, returns the hash output

Input: [C, B, A, ...]
Ouptut: [HASH, ...]
Where: For the native RPO hasher HASH is B.
Cycles: 9
| -| hash_memory_even | Hashes the memory `start_addr` to `end_addr`.

This requires that `end_addr=start_addr + 2n + 1`, otherwise the procedure will enter an infinite
loop. `end_addr` is not inclusive.

Stack transition:
Input: [C, B, A, start_addr, end_addr, ...]
Output: [C', B', A', end_addr, end_addr ...]
Cycles: 4 + 3 * words, where `words` is the `start_addr - end_addr - 1`

Where `A` is the capacity word that will be used by the hashing function, and `B'` the hash output.
| -| hash_memory | Hashes the memory `start_addr` to `end_addr`, handles odd number of elements.

Requires `start_addr < end_addr`, `end_addr` is not inclusive.

Stack transition:
Input: [start_addr, end_addr, ...]
Output: [H, ...]
Cycles:
even words: 48 cycles + 3 * words
odd words: 60 cycles + 3 * words
| diff --git a/stdlib/docs/crypto/hashes/rpo.md b/stdlib/docs/crypto/hashes/rpo.md new file mode 100644 index 0000000000..da22de54cc --- /dev/null +++ b/stdlib/docs/crypto/hashes/rpo.md @@ -0,0 +1,8 @@ +Prepares the top of the stack with the hasher initial state.

This procedures does not handle padding, therefore, the user is expected to
consume an amount of data which is a multiple of the rate (2 words).

Input: []
Ouptut: [PERM, PERM, PERM, ...]
Cycles: 12
+## std::crypto::hashes::rpo +| Procedure | Description | +| ----------- | ------------- | +| squeeze_digest | Given the hasher state, returns the hash output.

Input: [C, B, A, ...]
Ouptut: [HASH, ...]
where: For the native RPO hasher HASH is B.
Cycles: 9
| +| absorb_double_words_from_memory | Hashes the memory `start_addr` to `end_addr` given an RPO state specified by 3 words.

This requires that `end_addr=start_addr + 2n + 1`, otherwise the procedure will enter an infinite
loop. `end_addr` is not inclusive.

Stack transition:
Input: [C, B, A, start_addr, end_addr, ...]
Output: [C', B', A', end_addr, end_addr ...]
Cycles: 4 + 3 * words, where `words` is the `start_addr - end_addr - 1`

Where `A` is the capacity word that will be used by the hashing function, and `B'` the hash output.
| +| hash_memory_words | Hashes the memory `start_addr` to `end_addr`, handles odd number of elements.

Requires `start_addr < end_addr`, `end_addr` is not inclusive.

Stack transition:
Input: [start_addr, end_addr, ...]
Output: [H, ...]
Cycles:
even words: 49 cycles + 3 * words
odd words: 61 cycles + 3 * words
| +| hash_memory | Computes hash of Felt values starting at the specified memory address.

This procedure divides the hashing process into two parts: hashing pairs of words using
`absorb_double_words_from_memory` procedure and hashing the remaining values using the `hperm`
instruction.

Inputs: [ptr, num_elements]
Outputs: [HASH]
Cycles:
- If number of elements divides by 8: 47 cycles + 3 * words
- Else: 180 cycles + 3 * words

Panics if number of inputs equals 0.
| diff --git a/stdlib/docs/mem.md b/stdlib/docs/mem.md index 4aa67c46b7..84fdb8134e 100644 --- a/stdlib/docs/mem.md +++ b/stdlib/docs/mem.md @@ -4,5 +4,5 @@ | ----------- | ------------- | | memcopy | Copies `n` words from `read_ptr` to `write_ptr`.

Stack transition looks as follows:
[n, read_ptr, write_ptr, ...] -> [...]
cycles: 15 + 16n
| | pipe_double_words_to_memory | Copies an even number of words from the advice_stack to memory.

Input: [C, B, A, write_ptr, end_ptr, ...]
Output: [C, B, A, write_ptr, ...]

Where:
- The words C, B, and A are the RPO hasher state
- A is the capacity
- C,B are the rate portion of the state
- The value `words = end_ptr - write_ptr` must be positive and even

Cycles: 10 + 9 * word_pairs
| -| pipe_words_to_memory | Copies an arbitrary number of words from the advice stack to memory

Input: [num_words, write_ptr, ...]
Output: [HASH, write_ptr', ...]
Cycles:
even num_words: 48 + 9 * num_words / 2
odd num_words: 65 + 9 * round_down(num_words / 2)
| -| pipe_preimage_to_memory | Moves an arbitrary number of words from the advice stack to memory and asserts it matches the commitment.

Input: [num_words, write_ptr, COM, ...]
Output: [write_ptr', ...]
Cycles:
even num_words: 58 + 9 * num_words / 2
odd num_words: 75 + 9 * round_down(num_words / 2)
| +| pipe_words_to_memory | Copies an arbitrary number of words from the advice stack to memory

Input: [num_words, write_ptr, ...]
Output: [C, B, A, write_ptr', ...]
Cycles:
even num_words: 41 + 9 * num_words / 2
odd num_words: 58 + 9 * round_down(num_words / 2)
| +| pipe_preimage_to_memory | Moves an arbitrary number of words from the advice stack to memory and asserts it matches the commitment.

Input: [num_words, write_ptr, COM, ...]
Output: [write_ptr', ...]
Cycles:
even num_words: 62 + 9 * num_words / 2
odd num_words: 79 + 9 * round_down(num_words / 2)
| diff --git a/stdlib/src/lib.rs b/stdlib/src/lib.rs index 9637aff89c..2ebfb2279f 100644 --- a/stdlib/src/lib.rs +++ b/stdlib/src/lib.rs @@ -2,12 +2,19 @@ extern crate alloc; -use assembly::{mast::MastForest, utils::Deserializable, Library}; +use alloc::sync::Arc; + +use assembly::{ + mast::MastForest, + utils::{sync::LazyLock, Deserializable}, + Library, +}; // STANDARD LIBRARY // ================================================================================================ /// TODO: add docs +#[derive(Clone)] pub struct StdLibrary(Library); impl AsRef for StdLibrary { @@ -22,22 +29,25 @@ impl From for Library { } } -impl From for MastForest { - fn from(value: StdLibrary) -> Self { - value.0.into() - } -} - impl StdLibrary { + /// Serialized representation of the Miden standard library. pub const SERIALIZED: &'static [u8] = include_bytes!(concat!(env!("OUT_DIR"), "/assets/std.masl")); + + /// Returns a reference to the [MastForest] underlying the Miden standard library. + pub fn mast_forest(&self) -> &Arc { + self.0.mast_forest() + } } impl Default for StdLibrary { fn default() -> Self { - let contents = - Library::read_from_bytes(Self::SERIALIZED).expect("failed to read std masl!"); - Self(contents) + static STDLIB: LazyLock = LazyLock::new(|| { + let contents = + Library::read_from_bytes(StdLibrary::SERIALIZED).expect("failed to read std masl!"); + StdLibrary(contents) + }); + STDLIB.clone() } } diff --git a/stdlib/tests/collections/mmr.rs b/stdlib/tests/collections/mmr.rs index c998f8dc2c..36fe1c5486 100644 --- a/stdlib/tests/collections/mmr.rs +++ b/stdlib/tests/collections/mmr.rs @@ -3,7 +3,7 @@ use test_utils::{ init_merkle_leaf, init_merkle_leaves, MerkleError, MerkleStore, MerkleTree, Mmr, NodeIndex, RpoDigest, }, - hash_elements, stack_to_ints, Felt, StarkField, Word, EMPTY_WORD, ONE, ZERO, + felt_slice_to_ints, hash_elements, Felt, StarkField, Word, EMPTY_WORD, ONE, ZERO, }; // TESTS @@ -70,13 +70,16 @@ fn test_mmr_get_single_peak() -> Result<(), MerkleError> { for pos in 0..(leaves.len() as u64) { let source = format!( - "use.std::collections::mmr + " + use.std::collections::mmr begin push.{num_leaves} push.1000 mem_store # leaves count adv_push.4 push.1001 mem_storew dropw # MMR single peak push.1000 push.{pos} exec.mmr::get + + swapw dropw end", num_leaves = leaves.len(), pos = pos, @@ -127,7 +130,8 @@ fn test_mmr_get_two_peaks() -> Result<(), MerkleError> { for (absolute_pos, leaf) in examples { let source = format!( - "use.std::collections::mmr + " + use.std::collections::mmr begin push.{num_leaves} push.1000 mem_store # leaves count @@ -135,6 +139,8 @@ fn test_mmr_get_two_peaks() -> Result<(), MerkleError> { adv_push.4 push.1002 mem_storew dropw # MMR second peak push.1000 push.{pos} exec.mmr::get + + swapw dropw end", num_leaves = num_leaves, pos = absolute_pos, @@ -176,13 +182,16 @@ fn test_mmr_tree_with_one_element() -> Result<(), MerkleError> { // Test case for single element MMR let advice_stack: Vec = merkle_root3.iter().map(StarkField::as_int).collect(); let source = format!( - "use.std::collections::mmr + " + use.std::collections::mmr begin push.{num_leaves} push.1000 mem_store # leaves count adv_push.4 push.1001 mem_storew dropw # MMR first peak push.1000 push.{pos} exec.mmr::get + + swapw dropw end", num_leaves = leaves3.len(), pos = 0, @@ -199,7 +208,8 @@ fn test_mmr_tree_with_one_element() -> Result<(), MerkleError> { .collect(); let num_leaves = leaves1.len() + leaves2.len() + leaves3.len(); let source = format!( - "use.std::collections::mmr + " + use.std::collections::mmr begin push.{num_leaves} push.1000 mem_store # leaves count @@ -208,6 +218,8 @@ fn test_mmr_tree_with_one_element() -> Result<(), MerkleError> { adv_push.4 push.1003 mem_storew dropw # MMR third peak push.1000 push.{pos} exec.mmr::get + + swapw dropw end", num_leaves = num_leaves, pos = num_leaves - 1, @@ -247,7 +259,7 @@ fn test_mmr_unpack() { let hash = hash_elements(&hash_data.concat()); // Set up the VM stack with the MMR hash, and its target address - let mut stack = stack_to_ints(&*hash); + let mut stack = felt_slice_to_ints(&*hash); let mmr_ptr = 1000_u32; stack.insert(0, mmr_ptr as u64); @@ -309,7 +321,7 @@ fn test_mmr_unpack_invalid_hash() { let hash = hash_elements(&hash_data.concat()); // Set up the VM stack with the MMR hash, and its target address - let mut stack = stack_to_ints(&*hash); + let mut stack = felt_slice_to_ints(&*hash); let mmr_ptr = 1000; stack.insert(0, mmr_ptr); @@ -371,7 +383,7 @@ fn test_mmr_unpack_large_mmr() { let hash = hash_elements(&hash_data.concat()); // Set up the VM stack with the MMR hash, and its target address - let mut stack = stack_to_ints(&*hash); + let mut stack = felt_slice_to_ints(&*hash); let mmr_ptr = 1000_u32; stack.insert(0, mmr_ptr as u64); @@ -427,11 +439,11 @@ fn test_mmr_pack_roundtrip() { mmr.add(init_merkle_leaf(2).into()); mmr.add(init_merkle_leaf(3).into()); - let accumulator = mmr.peaks(mmr.forest()).unwrap(); + let accumulator = mmr.peaks(); let hash = accumulator.hash_peaks(); // Set up the VM stack with the MMR hash, and its target address - let mut stack = stack_to_ints(hash.as_elements()); + let mut stack = felt_slice_to_ints(&*hash); let mmr_ptr = 1000; stack.insert(0, mmr_ptr); // first value is used by unpack, to load data to memory stack.insert(0, mmr_ptr); // second is used by pack, to load data from memory @@ -454,9 +466,12 @@ fn test_mmr_pack_roundtrip() { let source = " use.std::collections::mmr + begin exec.mmr::unpack exec.mmr::pack + + swapw dropw end "; let test = build_test!(source, &stack, advice_stack, store, advice_map.iter().cloned()); @@ -486,6 +501,8 @@ fn test_mmr_pack() { push.2.1002 mem_store # peak2 push.1000 exec.mmr::pack + + swapw dropw end "; @@ -509,7 +526,7 @@ fn test_mmr_pack() { let host = process.host.borrow_mut(); let advice_data = host.advice_provider().map().get(&hash_u8).unwrap(); - assert_eq!(stack_to_ints(advice_data), stack_to_ints(&expect_data)); + assert_eq!(advice_data, &expect_data); } #[test] @@ -560,7 +577,7 @@ fn test_mmr_two() { mmr.add([ONE, Felt::new(2), Felt::new(3), Felt::new(4)].into()); mmr.add([Felt::new(5), Felt::new(6), Felt::new(7), Felt::new(8)].into()); - let accumulator = mmr.peaks(mmr.forest()).unwrap(); + let accumulator = mmr.peaks(); let peak = accumulator.peaks()[0]; let num_leaves = accumulator.num_leaves() as u64; @@ -587,6 +604,8 @@ fn test_mmr_large() { push.{mmr_ptr}.0.0.0.7 exec.mmr::add push.{mmr_ptr} exec.mmr::pack + + swapw dropw end " ); @@ -600,7 +619,7 @@ fn test_mmr_large() { mmr.add([ZERO, ZERO, ZERO, Felt::new(6)].into()); mmr.add([ZERO, ZERO, ZERO, Felt::new(7)].into()); - let accumulator = mmr.peaks(mmr.forest()).unwrap(); + let accumulator = mmr.peaks(); let num_leaves = accumulator.num_leaves() as u64; let mut expected_memory = vec![num_leaves, 0, 0, 0]; @@ -625,11 +644,11 @@ fn test_mmr_large_add_roundtrip() { [ZERO, ZERO, ZERO, Felt::new(7)].into(), ]); - let old_accumulator = mmr.peaks(mmr.forest()).unwrap(); + let old_accumulator = mmr.peaks(); let hash = old_accumulator.hash_peaks(); // Set up the VM stack with the MMR hash, and its target address - let mut stack = stack_to_ints(hash.as_elements()); + let mut stack = felt_slice_to_ints(&*hash); stack.insert(0, mmr_ptr as u64); // both the advice stack and merkle store start empty (data is available in @@ -658,13 +677,15 @@ fn test_mmr_large_add_roundtrip() { exec.mmr::unpack push.{mmr_ptr}.0.0.0.8 exec.mmr::add push.{mmr_ptr} exec.mmr::pack + + swapw dropw end " ); mmr.add([ZERO, ZERO, ZERO, Felt::new(8)].into()); - let new_accumulator = mmr.peaks(mmr.forest()).unwrap(); + let new_accumulator = mmr.peaks(); let num_leaves = new_accumulator.num_leaves() as u64; let mut expected_memory = vec![num_leaves, 0, 0, 0]; let mut new_peaks = new_accumulator.peaks().to_vec(); diff --git a/stdlib/tests/collections/smt.rs b/stdlib/tests/collections/smt.rs index 928792be96..0b2d76198f 100644 --- a/stdlib/tests/collections/smt.rs +++ b/stdlib/tests/collections/smt.rs @@ -24,8 +24,9 @@ fn test_smt_get() { fn expect_value_from_get(key: RpoDigest, value: Word, smt: &Smt) { let source = " use.std::collections::smt + begin - exec.smt::get + exec.smt::get end "; let mut initial_stack = Vec::new(); @@ -61,8 +62,10 @@ fn test_smt_set() { let source = " use.std::collections::smt + begin - exec.smt::set + exec.smt::set + movupw.2 dropw end "; @@ -240,8 +243,10 @@ fn test_set_empty_key_in_non_empty_leaf() { let source = " use.std::collections::smt + begin - exec.smt::set + exec.smt::set + movupw.2 dropw end "; let (init_stack, final_stack, store, advice_map) = diff --git a/stdlib/tests/crypto/blake3.rs b/stdlib/tests/crypto/blake3.rs index 633d81610a..419e9e76ea 100644 --- a/stdlib/tests/crypto/blake3.rs +++ b/stdlib/tests/crypto/blake3.rs @@ -7,6 +7,7 @@ fn blake3_hash_64_bytes() { begin exec.blake3::hash_2to1 + swapdw dropw dropw end "; @@ -41,6 +42,7 @@ fn blake3_hash_32_bytes() { begin exec.blake3::hash_1to1 + swapdw dropw dropw end "; diff --git a/stdlib/tests/crypto/elgamal.rs b/stdlib/tests/crypto/elgamal.rs index bebc5e9c83..6955acb476 100644 --- a/stdlib/tests/crypto/elgamal.rs +++ b/stdlib/tests/crypto/elgamal.rs @@ -1,6 +1,6 @@ use std::ops::Add; -use test_utils::{rand::rand_array, Felt, FieldElement}; +use test_utils::{push_inputs, rand::rand_array, Felt, FieldElement}; use crate::math::ecgfp5::{base_field::Ext5, group::ECExt5}; @@ -39,9 +39,12 @@ fn test_elgamal_keygen() { let source = " use.std::crypto::elgamal_ecgfp5 + use.std::sys begin exec.elgamal_ecgfp5::gen_privatekey + + exec.sys::truncate_stack end "; @@ -94,9 +97,12 @@ fn test_elgamal_encrypt() { let source = " use.std::crypto::elgamal_ecgfp5 + use.std::sys begin exec.elgamal_ecgfp5::encrypt_ca + + exec.sys::truncate_stack end "; @@ -164,18 +170,24 @@ fn test_elgamal_encrypt() { pm.y.a4.as_int(), pm.point_at_infinity.as_int(), ]; + stack.reverse(); - let source = " + let source = format!( + " use.std::crypto::elgamal_ecgfp5 + use.std::sys begin + {inputs} exec.elgamal_ecgfp5::encrypt_cb - end - "; - stack.reverse(); + exec.sys::truncate_stack + end + ", + inputs = push_inputs(&stack) + ); - let test = build_test!(source, &stack); + let test = build_test!(source, &[]); let strace = test.get_last_stack_state(); assert_eq!(strace[0], cb.x.a0); @@ -227,14 +239,6 @@ fn test_elgamal_remask() { let c_prime_a = ca.add(r_prime_g); let c_prime_b = cb.add(r_prime_h); - let source = " - use.std::crypto::elgamal_ecgfp5 - - begin - exec.elgamal_ecgfp5::remask_ca - end - "; - let mut stack = [ r_prime[0] as u64, r_prime[1] as u64, @@ -260,7 +264,22 @@ fn test_elgamal_remask() { ]; stack.reverse(); - let test = build_test!(source, &stack); + let source = format!( + " + use.std::crypto::elgamal_ecgfp5 + use.std::sys + + begin + {inputs} + exec.elgamal_ecgfp5::remask_ca + + exec.sys::truncate_stack + end + ", + inputs = push_inputs(&stack) + ); + + let test = build_test!(source, &[]); let strace = test.get_last_stack_state(); assert_eq!(strace[0], c_prime_a.x.a0); @@ -275,14 +294,6 @@ fn test_elgamal_remask() { assert_eq!(strace[9], c_prime_a.y.a4); assert_eq!(strace[10], c_prime_a.point_at_infinity); - let source = " - use.std::crypto::elgamal_ecgfp5 - - begin - exec.elgamal_ecgfp5::remask_cb - end - "; - let mut stack = [ h.x.a0.as_int(), h.x.a1.as_int(), @@ -319,7 +330,22 @@ fn test_elgamal_remask() { ]; stack.reverse(); - let test = build_test!(source, &stack); + let source = format!( + " + use.std::crypto::elgamal_ecgfp5 + use.std::sys + + begin + {inputs} + exec.elgamal_ecgfp5::remask_cb + + exec.sys::truncate_stack + end + ", + inputs = push_inputs(&stack) + ); + + let test = build_test!(source, &[]); let strace = test.get_last_stack_state(); assert_eq!(strace[0], c_prime_b.x.a0); diff --git a/stdlib/tests/crypto/fri/channel.rs b/stdlib/tests/crypto/fri/channel.rs index fd10128bfa..864d810324 100644 --- a/stdlib/tests/crypto/fri/channel.rs +++ b/stdlib/tests/crypto/fri/channel.rs @@ -3,7 +3,7 @@ use test_utils::{ crypto::{BatchMerkleProof, ElementHasher, Hasher as HasherTrait, PartialMerkleTree}, math::fft, serde::DeserializationError, - Felt, FieldElement, StarkField, + Felt, FieldElement, MerkleTreeVC, StarkField, }; use winter_fri::{FriProof, VerifierError}; @@ -16,7 +16,10 @@ pub trait UnBatch { ) -> (Vec, Vec<(Digest, Vec)>); } -pub struct MidenFriVerifierChannel> { +pub struct MidenFriVerifierChannel< + E: FieldElement, + H: ElementHasher + ElementHasher, +> { layer_commitments: Vec, layer_proofs: Vec>, layer_queries: Vec>, @@ -25,8 +28,8 @@ pub struct MidenFriVerifierChannel MidenFriVerifierChannel where - E: FieldElement, - H: ElementHasher, + E: FieldElement, + H: ElementHasher + ElementHasher, { /// Builds a new verifier channel from the specified [FriProof]. /// @@ -40,7 +43,7 @@ where ) -> Result { let remainder = proof.parse_remainder()?; let (layer_queries, layer_proofs) = - proof.parse_layers::(domain_size, folding_factor)?; + proof.parse_layers::>(domain_size, folding_factor)?; Ok(MidenFriVerifierChannel { layer_commitments, diff --git a/stdlib/tests/crypto/fri/remainder.rs b/stdlib/tests/crypto/fri/remainder.rs index 6056aa9d6c..33746d2af6 100644 --- a/stdlib/tests/crypto/fri/remainder.rs +++ b/stdlib/tests/crypto/fri/remainder.rs @@ -1,5 +1,6 @@ use test_utils::{ - math::fft, rand::rand_vector, test_case, Felt, FieldElement, QuadFelt, StarkField, ONE, + math::fft, push_inputs, rand::rand_vector, test_case, Felt, FieldElement, QuadFelt, StarkField, + ONE, }; #[test_case(8, 1; "poly_8 |> evaluated_8 |> interpolated_8")] @@ -19,8 +20,22 @@ fn test_decorator_ext2intt(in_poly_len: usize, blowup: usize) { let eval_mem_req = (eval_len * 2) / 4; let out_mem_req = (in_poly_len * 2) / 4; + let poly = rand_vector::(in_poly_len); + let twiddles = fft::get_twiddles(poly.len()); + let evals = fft::evaluate_poly_with_offset(&poly, &twiddles, ONE, blowup); + + let ifelts = QuadFelt::slice_as_base_elements(&evals); + let iu64s = ifelts.iter().map(|v| v.as_int()).collect::>(); + let ou64s = QuadFelt::slice_as_base_elements(&poly) + .iter() + .rev() + .map(|v| v.as_int()) + .collect::>(); + let source = format!( " + use.std::sys + proc.helper.{} locaddr.{} repeat.{} @@ -48,7 +63,10 @@ fn test_decorator_ext2intt(in_poly_len: usize, blowup: usize) { end begin + {inputs} exec.helper + + exec.sys::truncate_stack end ", eval_mem_req, @@ -56,28 +74,28 @@ fn test_decorator_ext2intt(in_poly_len: usize, blowup: usize) { eval_mem_req, eval_len, in_poly_len, - out_mem_req + out_mem_req, + inputs = push_inputs(&iu64s) ); - let poly = rand_vector::(in_poly_len); - let twiddles = fft::get_twiddles(poly.len()); - let evals = fft::evaluate_poly_with_offset(&poly, &twiddles, ONE, blowup); - - let ifelts = QuadFelt::slice_as_base_elements(&evals); - let iu64s = ifelts.iter().map(|v| v.as_int()).collect::>(); - let ou64s = QuadFelt::slice_as_base_elements(&poly) - .iter() - .rev() - .map(|v| v.as_int()) - .collect::>(); - - let test = build_test!(&source, &iu64s); + let test = build_test!(&source, &[]); test.expect_stack(&ou64s); } #[test] fn test_verify_remainder_64() { - let source = " + let poly = rand_vector::(8); + let twiddles = fft::get_twiddles(poly.len()); + let evals = fft::evaluate_poly_with_offset(&poly, &twiddles, Felt::GENERATOR, 8); + let tau = rand_vector::(1); + + let mut ifelts = QuadFelt::slice_as_base_elements(&tau).to_vec(); + ifelts.extend_from_slice(QuadFelt::slice_as_base_elements(&evals)); + ifelts.extend_from_slice(QuadFelt::slice_as_base_elements(&poly)); + let iu64s = ifelts.iter().map(|v| v.as_int()).collect::>(); + + let source = format!( + " use.std::crypto::fri::ext2fri proc.helper.36 @@ -96,11 +114,20 @@ fn test_verify_remainder_64() { end begin + {inputs} exec.helper end - "; + ", + inputs = push_inputs(&iu64s) + ); - let poly = rand_vector::(8); + let test = build_test!(source, &[]); + test.expect_stack(&[]); +} + +#[test] +fn test_verify_remainder_32() { + let poly = rand_vector::(4); let twiddles = fft::get_twiddles(poly.len()); let evals = fft::evaluate_poly_with_offset(&poly, &twiddles, Felt::GENERATOR, 8); let tau = rand_vector::(1); @@ -110,13 +137,8 @@ fn test_verify_remainder_64() { ifelts.extend_from_slice(QuadFelt::slice_as_base_elements(&poly)); let iu64s = ifelts.iter().map(|v| v.as_int()).collect::>(); - let test = build_test!(source, &iu64s); - test.expect_stack(&[]); -} - -#[test] -fn test_verify_remainder_32() { - let source = " + let source = format!( + " use.std::crypto::fri::ext2fri proc.helper.18 @@ -135,20 +157,13 @@ fn test_verify_remainder_32() { end begin + {inputs} exec.helper end - "; - - let poly = rand_vector::(4); - let twiddles = fft::get_twiddles(poly.len()); - let evals = fft::evaluate_poly_with_offset(&poly, &twiddles, Felt::GENERATOR, 8); - let tau = rand_vector::(1); - - let mut ifelts = QuadFelt::slice_as_base_elements(&tau).to_vec(); - ifelts.extend_from_slice(QuadFelt::slice_as_base_elements(&evals)); - ifelts.extend_from_slice(QuadFelt::slice_as_base_elements(&poly)); - let iu64s = ifelts.iter().map(|v| v.as_int()).collect::>(); + ", + inputs = push_inputs(&iu64s) + ); - let test = build_test!(source, &iu64s); + let test = build_test!(source, &[]); test.expect_stack(&[]); } diff --git a/stdlib/tests/crypto/fri/verifier_fri_e2f4.rs b/stdlib/tests/crypto/fri/verifier_fri_e2f4.rs index a8eef9b4ee..ca4acb420f 100644 --- a/stdlib/tests/crypto/fri/verifier_fri_e2f4.rs +++ b/stdlib/tests/crypto/fri/verifier_fri_e2f4.rs @@ -8,7 +8,7 @@ use test_utils::{ crypto::{MerklePath, NodeIndex, PartialMerkleTree, Rpo256 as MidenHasher}, group_slice_elements, math::fft, - Felt, FieldElement, QuadFelt as QuadExt, StarkField, EMPTY_WORD, + Felt, FieldElement, MerkleTreeVC, QuadFelt as QuadExt, StarkField, EMPTY_WORD, }; use winter_fri::{ folding::fold_positions, DefaultProverChannel, FriOptions, FriProof, FriProver, VerifierError, @@ -66,7 +66,7 @@ pub fn fri_prove_verify_fold4_ext2(trace_length_e: usize) -> Result>::new(options.clone()); prover.build_layers(&mut channel, evaluations.clone()); let positions = channel.draw_query_positions(nonce); let proof = prover.build_proof(&positions); @@ -412,21 +412,17 @@ impl UnBatch for MidenFriVerifierChannel(query); + let leaves: Vec = + x.iter().map(|row| MidenHasher::hash_elements(row)).collect(); + let unbatched_proof = layer_proof.into_openings(&leaves, &folded_positions).unwrap(); assert_eq!(x.len(), unbatched_proof.len()); - let nodes: Vec<[Felt; 4]> = unbatched_proof - .iter_mut() - .map(|list| { - let node = list.remove(0); - let node = node.as_elements().to_owned(); - [node[0], node[1], node[2], node[3]] - }) - .collect(); + let nodes: Vec<[Felt; 4]> = + leaves.iter().map(|leaf| [leaf[0], leaf[1], leaf[2], leaf[3]]).collect(); let paths: Vec = - unbatched_proof.into_iter().map(|list| list.into()).collect(); + unbatched_proof.into_iter().map(|list| list.1.into()).collect(); let iter_pos = folded_positions.iter_mut().map(|a| *a as u64); let nodes_tmp = nodes.clone(); diff --git a/stdlib/tests/crypto/keccak256.rs b/stdlib/tests/crypto/keccak256.rs index 18104defb4..8991c87def 100644 --- a/stdlib/tests/crypto/keccak256.rs +++ b/stdlib/tests/crypto/keccak256.rs @@ -1,7 +1,7 @@ use sha3::{Digest, Keccak256}; use test_utils::{ rand::{rand_array, rand_value}, - Felt, IntoBytes, STACK_TOP_SIZE, + Felt, IntoBytes, MIN_STACK_DEPTH, }; /// Equivalent to https://github.com/itzmeanjan/merklize-sha/blob/1d35aae/include/test_bit_interleaving.hpp#L12-L34 @@ -35,6 +35,7 @@ fn keccak256_2_to_1_hash() { begin exec.keccak256::hash + swapdw dropw dropw end "; @@ -55,13 +56,13 @@ fn keccak256_2_to_1_hash() { // 32 -bytes digest represented in terms eight ( little endian ) // 32 -bit integers such that it's easy to compare against final stack trace - let mut expected_stack = [0u64; STACK_TOP_SIZE >> 1]; + let mut expected_stack = [0u64; MIN_STACK_DEPTH >> 1]; to_stack(&digest, &mut expected_stack); // 64 -bytes input represented in terms of sixteen ( little endian ) 32 -bit // integers so that miden assembly implementation of keccak256 2-to-1 hash can // consume it and produce 32 -bytes digest - let mut in_stack = [0u64; STACK_TOP_SIZE]; + let mut in_stack = [0u64; MIN_STACK_DEPTH]; to_stack(&i_digest, &mut in_stack); in_stack.reverse(); diff --git a/stdlib/tests/crypto/mod.rs b/stdlib/tests/crypto/mod.rs index 848b9e1f62..e68dbc291e 100644 --- a/stdlib/tests/crypto/mod.rs +++ b/stdlib/tests/crypto/mod.rs @@ -6,6 +6,6 @@ mod ecdsa_secp256k1; mod elgamal; mod fri; mod keccak256; -mod native; +mod rpo; mod sha256; mod stark; diff --git a/stdlib/tests/crypto/native.rs b/stdlib/tests/crypto/rpo.rs similarity index 53% rename from stdlib/tests/crypto/native.rs rename to stdlib/tests/crypto/rpo.rs index e63c77a247..48721c4e6a 100644 --- a/stdlib/tests/crypto/native.rs +++ b/stdlib/tests/crypto/rpo.rs @@ -5,34 +5,13 @@ use test_utils::{build_expected_hash, build_expected_perm, expect_exec_error}; fn test_invalid_end_addr() { // end_addr can not be smaller than start_addr let empty_range = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin push.0999 # end address push.1000 # start address - exec.native::hash_memory - end - "; - let test = build_test!(empty_range, &[]); - expect_exec_error!( - test, - ExecutionError::FailedAssertion { - clk: 18.into(), - err_code: 0, - err_msg: None, - } - ); - - // address range can not contain zero elements - let empty_range = " - use.std::crypto::hashes::native - - begin - push.1000 # end address - push.1000 # start address - - exec.native::hash_memory + exec.rpo::hash_memory_words end "; let test = build_test!(empty_range, &[]); @@ -50,13 +29,18 @@ fn test_invalid_end_addr() { fn test_hash_empty() { // computes the hash for 8 consecutive zeros using mem_stream directly let two_zeros_mem_stream = " + use.std::crypto::hashes::rpo + begin # mem_stream state push.1000 padw padw padw mem_stream hperm # drop everything except the hash - dropw swapw dropw movup.4 drop + exec.rpo::squeeze_digest movup.4 drop + + # truncate stack + swapw dropw end "; @@ -67,15 +51,18 @@ fn test_hash_empty() { ]).into_iter().map(|e| e.as_int()).collect(); build_test!(two_zeros_mem_stream, &[]).expect_stack(&zero_hash); - // checks the hash compute from 8 zero elements is the same when using hash_memory + // checks the hash compute from 8 zero elements is the same when using hash_memory_words let two_zeros = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin push.1002 # end address push.1000 # start address - exec.native::hash_memory + exec.rpo::hash_memory_words + + # truncate stack + swapw dropw end "; @@ -86,6 +73,8 @@ fn test_hash_empty() { fn test_single_iteration() { // computes the hash of 1 using mem_stream let one_memstream = " + use.std::crypto::hashes::rpo + begin # insert 1 to memory push.1.1000 mem_store @@ -95,7 +84,10 @@ fn test_single_iteration() { mem_stream hperm # drop everything except the hash - dropw swapw dropw movup.4 drop + exec.rpo::squeeze_digest movup.4 drop + + # truncate stack + swapw dropw end "; @@ -106,11 +98,11 @@ fn test_single_iteration() { ]).into_iter().map(|e| e.as_int()).collect(); build_test!(one_memstream, &[]).expect_stack(&one_hash); - // checks the hash of 1 is the same when using hash_memory + // checks the hash of 1 is the same when using hash_memory_words // Note: This is testing the hashing of two words, so no padding is added // here let one_element = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin # insert 1 to memory @@ -119,7 +111,10 @@ fn test_single_iteration() { push.1002 # end address push.1000 # start address - exec.native::hash_memory + exec.rpo::hash_memory_words + + # truncate stack + swapw dropw end "; @@ -137,9 +132,9 @@ fn test_hash_one_word() { 1, 0, 0, 0, ]).into_iter().map(|e| e.as_int()).collect(); - // checks the hash of 1 is the same when using hash_memory + // checks the hash of 1 is the same when using hash_memory_words let one_element = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin push.1.1000 mem_store # push data to memory @@ -147,7 +142,10 @@ fn test_hash_one_word() { push.1001 # end address push.1000 # start address - exec.native::hash_memory + exec.rpo::hash_memory_words + + # truncate stack + swapw dropw end "; @@ -158,7 +156,7 @@ fn test_hash_one_word() { fn test_hash_even_words() { // checks the hash of two words let even_words = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin push.1.0.0.0.1000 mem_storew dropw @@ -167,7 +165,10 @@ fn test_hash_even_words() { push.1002 # end address push.1000 # start address - exec.native::hash_memory + exec.rpo::hash_memory_words + + # truncate stack + swapw dropw end "; @@ -183,7 +184,7 @@ fn test_hash_even_words() { fn test_hash_odd_words() { // checks the hash of three words let odd_words = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin push.1.0.0.0.1000 mem_storew dropw @@ -193,7 +194,10 @@ fn test_hash_odd_words() { push.1003 # end address push.1000 # start address - exec.native::hash_memory + exec.rpo::hash_memory_words + + # truncate stack + swapw dropw end "; @@ -207,9 +211,10 @@ fn test_hash_odd_words() { } #[test] -fn test_hash_memory_even() { +fn test_absorb_double_words_from_memory() { let even_words = " - use.std::crypto::hashes::native + use.std::sys + use.std::crypto::hashes::rpo begin push.1.0.0.0.1000 mem_storew dropw @@ -218,7 +223,10 @@ fn test_hash_memory_even() { push.1002 # end address push.1000 # start address padw padw padw # hasher state - exec.native::hash_memory_even + exec.rpo::absorb_double_words_from_memory + + # truncate stack + exec.sys::truncate_stack end "; @@ -237,9 +245,9 @@ fn test_hash_memory_even() { } #[test] -fn test_state_to_digest() { +fn test_squeeze_digest() { let even_words = " - use.std::crypto::hashes::native + use.std::crypto::hashes::rpo begin push.1.0.0.0.1000 mem_storew dropw @@ -250,9 +258,12 @@ fn test_state_to_digest() { push.1004 # end address push.1000 # start address padw padw padw # hasher state - exec.native::hash_memory_even + exec.rpo::absorb_double_words_from_memory - exec.native::state_to_digest + exec.rpo::squeeze_digest + + # truncate stack + swapdw dropw dropw end "; @@ -270,3 +281,149 @@ fn test_state_to_digest() { build_test!(even_words, &[]).expect_stack(&even_hash); } + +#[test] +fn test_hash_memory() { + // hash fewer than 8 elements + let compute_inputs_hash_5 = " + use.std::crypto::hashes::rpo + + begin + push.1.2.3.4.1000 mem_storew dropw + push.5.0.0.0.1001 mem_storew dropw + push.11 + + push.5.1000 + + exec.rpo::hash_memory + + # truncate stack + swapdw dropw dropw + end + "; + + #[rustfmt::skip] + let mut expected_hash: Vec = build_expected_hash(&[ + 1, 2, 3, 4, 5 + ]).into_iter().map(|e| e.as_int()).collect(); + // make sure that value `11` stays unchanged + expected_hash.push(11); + build_test!(compute_inputs_hash_5, &[]).expect_stack(&expected_hash); + + // hash exactly 8 elements + let compute_inputs_hash_8 = " + use.std::crypto::hashes::rpo + + begin + push.1.2.3.4.1000 mem_storew dropw + push.5.6.7.8.1001 mem_storew dropw + push.11 + + push.8.1000 + + exec.rpo::hash_memory + + # truncate stack + swapdw dropw dropw + end + "; + + #[rustfmt::skip] + let mut expected_hash: Vec = build_expected_hash(&[ + 1, 2, 3, 4, 5, 6, 7, 8 + ]).into_iter().map(|e| e.as_int()).collect(); + // make sure that value `11` stays unchanged + expected_hash.push(11); + build_test!(compute_inputs_hash_8, &[]).expect_stack(&expected_hash); + + // hash more than 8 elements + let compute_inputs_hash_15 = " + use.std::crypto::hashes::rpo + + begin + push.1.2.3.4.1000 mem_storew dropw + push.5.6.7.8.1001 mem_storew dropw + push.9.10.11.12.1002 mem_storew dropw + push.13.14.15.0.1003 mem_storew dropw + push.11 + + push.15.1000 + + exec.rpo::hash_memory + + # truncate stack + swapdw dropw dropw + end + "; + + #[rustfmt::skip] + let mut expected_hash: Vec = build_expected_hash(&[ + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15 + ]).into_iter().map(|e| e.as_int()).collect(); + // make sure that value `11` stays unchanged + expected_hash.push(11); + build_test!(compute_inputs_hash_15, &[]).expect_stack(&expected_hash); +} + +#[test] +fn test_hash_memory_empty() { + // absorb_double_words_from_memory + let source = " + use.std::sys + use.std::crypto::hashes::rpo + + begin + push.1000 # end address + push.1000 # start address + padw padw padw # hasher state + + exec.rpo::absorb_double_words_from_memory + + # truncate stack + exec.sys::truncate_stack + end + "; + + let mut expected_stack = vec![0; 12]; + expected_stack.push(1000); + expected_stack.push(1000); + + build_test!(source, &[]).expect_stack(&expected_stack); + + // hash_memory_words + let source = " + use.std::crypto::hashes::rpo + + begin + push.1000 # end address + push.1000 # start address + + exec.rpo::hash_memory_words + + # truncate stack + swapw dropw + end + "; + + build_test!(source, &[]).expect_stack(&[0; 4]); + + // hash_memory + let source = " + use.std::crypto::hashes::rpo + + begin + push.0 # number of elements to hash + push.1000 # start address + + exec.rpo::hash_memory + + # truncate stack + swapw dropw + end + "; + + build_test!(source, &[]).expect_stack(&[0; 16]); +} diff --git a/stdlib/tests/crypto/sha256.rs b/stdlib/tests/crypto/sha256.rs index 8a40459f7f..e6a371b9c3 100644 --- a/stdlib/tests/crypto/sha256.rs +++ b/stdlib/tests/crypto/sha256.rs @@ -1,16 +1,34 @@ use sha2::{Digest, Sha256}; use test_utils::{ - group_slice_elements, + group_slice_elements, push_inputs, rand::{rand_array, rand_value, rand_vector}, Felt, IntoBytes, }; #[test] fn sha256_hash_memory() { - let source = " + let length = rand_value::() & 1023; // length: 0-1023 + let ibytes: Vec = rand_vector(length as usize); + let ipadding: Vec = vec![0; (4 - (length as usize % 4)) % 4]; + + let ifelts = [ + group_slice_elements::(&[ibytes.clone(), ipadding].concat()) + .iter() + .map(|&bytes| u32::from_be_bytes(bytes) as u64) + .rev() + .collect::>(), + vec![length as u64; 1], + ] + .concat(); + + let source = format!( + " use.std::crypto::hashes::sha256 begin + # push inputs on the stack + {inputs} + # mem.0 - input data address push.10000 mem_store.0 @@ -32,21 +50,12 @@ fn sha256_hash_memory() { mem_load.1 push.10000 exec.sha256::hash_memory - end"; - - let length = rand_value::() & 1023; // length: 0-1023 - let ibytes: Vec = rand_vector(length as usize); - let ipadding: Vec = vec![0; (4 - (length as usize % 4)) % 4]; - - let ifelts = [ - group_slice_elements::(&[ibytes.clone(), ipadding].concat()) - .iter() - .map(|&bytes| u32::from_be_bytes(bytes) as u64) - .rev() - .collect::>(), - vec![length as u64; 1], - ] - .concat(); + + # truncate the stack + swapdw dropw dropw + end", + inputs = push_inputs(&ifelts) + ); let mut hasher = Sha256::new(); hasher.update(ibytes); @@ -57,7 +66,7 @@ fn sha256_hash_memory() { .map(|&bytes| u32::from_be_bytes(bytes) as u64) .collect::>(); - let test = build_test!(source, &ifelts); + let test = build_test!(source, &[]); test.expect_stack(&ofelts); } diff --git a/stdlib/tests/crypto/stark/verifier_recursive/channel.rs b/stdlib/tests/crypto/stark/verifier_recursive/channel.rs index d52d5d57ae..0ae32bcdfd 100644 --- a/stdlib/tests/crypto/stark/verifier_recursive/channel.rs +++ b/stdlib/tests/crypto/stark/verifier_recursive/channel.rs @@ -8,7 +8,7 @@ use test_utils::{ crypto::{BatchMerkleProof, MerklePath, PartialMerkleTree, Rpo256, RpoDigest}, group_slice_elements, math::{FieldElement, QuadExtension, StarkField}, - Felt, VerifierError, EMPTY_WORD, + Felt, MerkleTreeVC, VerifierError, EMPTY_WORD, }; use winter_air::{ proof::{Proof, Queries, Table, TraceOodFrame}, @@ -89,7 +89,10 @@ impl VerifierChannel { .parse_remainder() .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?; let (fri_layer_queries, fri_layer_proofs) = fri_proof - .parse_layers::(lde_domain_size, fri_options.folding_factor()) + .parse_layers::>( + lde_domain_size, + fri_options.folding_factor(), + ) .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?; // --- parse out-of-domain evaluation frame ----------------------------------------------- @@ -242,22 +245,16 @@ impl VerifierChannel { let mut folded_positions = fold_positions(&positions, current_domain_size, N); let layer_proof = layer_proofs.remove(0); - - let mut unbatched_proof = layer_proof.into_paths(&folded_positions).unwrap(); let x = group_slice_elements::(query); + let leaves: Vec = x.iter().map(|row| Rpo256::hash_elements(row)).collect(); + let unbatched_proof = layer_proof.into_openings(&leaves, &folded_positions).unwrap(); assert_eq!(x.len(), unbatched_proof.len()); - let nodes: Vec<[Felt; 4]> = unbatched_proof - .iter_mut() - .map(|list| { - let node = list.remove(0); - let node = node.as_elements().to_owned(); - [node[0], node[1], node[2], node[3]] - }) - .collect(); + let nodes: Vec<[Felt; 4]> = + leaves.iter().map(|leaf| [leaf[0], leaf[1], leaf[2], leaf[3]]).collect(); let paths: Vec = - unbatched_proof.into_iter().map(|list| list.into()).collect(); + unbatched_proof.into_iter().map(|list| list.1.into()).collect(); let iter_pos = folded_positions.iter_mut().map(|a| *a as u64); let nodes_tmp = nodes.clone(); @@ -292,6 +289,7 @@ impl VerifierChannel { impl FriVerifierChannel for VerifierChannel { type Hasher = Rpo256; + type VectorCommitment = MerkleTreeVC; fn read_fri_num_partitions(&self) -> usize { self.fri_num_partitions @@ -341,7 +339,11 @@ impl TraceQueries { let main_segment_width = air.trace_info().main_trace_width(); let main_segment_queries = queries.remove(0); let (main_segment_query_proofs, main_segment_states) = main_segment_queries - .parse::(air.lde_domain_size(), num_queries, main_segment_width) + .parse::>( + air.lde_domain_size(), + num_queries, + main_segment_width, + ) .map_err(|err| { VerifierError::ProofDeserializationError(format!( "main trace segment query deserialization failed: {err}" @@ -359,7 +361,11 @@ impl TraceQueries { let aux_segment_queries = queries.remove(0); let (segment_query_proof, segment_trace_states) = aux_segment_queries - .parse::(air.lde_domain_size(), num_queries, segment_width) + .parse::>( + air.lde_domain_size(), + num_queries, + segment_width, + ) .map_err(|err| { VerifierError::ProofDeserializationError(format!( "auxiliary trace segment query deserialization failed: {err}" @@ -401,7 +407,11 @@ impl ConstraintQueries { num_queries: usize, ) -> Result { let (query_proofs, evaluations) = queries - .parse::(air.lde_domain_size(), num_queries, air.ce_blowup_factor()) + .parse::>( + air.lde_domain_size(), + num_queries, + air.ce_blowup_factor(), + ) .map_err(|err| { VerifierError::ProofDeserializationError(format!( "constraint evaluation query deserialization failed: {err}" @@ -420,18 +430,14 @@ pub fn unbatch_to_partial_mt( queries: Vec>, proof: BatchMerkleProof, ) -> (PartialMerkleTree, Vec<(RpoDigest, Vec)>) { - let mut unbatched_proof = proof.into_paths(&positions).unwrap(); + let leaves: Vec = queries.iter().map(|row| Rpo256::hash_elements(row)).collect(); + + let unbatched_proof = proof.into_openings(&leaves, &positions).unwrap(); let mut adv_key_map = Vec::new(); - let nodes: Vec<[Felt; 4]> = unbatched_proof - .iter_mut() - .map(|list| { - let node = list.remove(0); - let node = node.as_elements().to_owned(); - [node[0], node[1], node[2], node[3]] - }) - .collect(); + let nodes: Vec<[Felt; 4]> = + queries.iter().map(|node| [node[0], node[1], node[2], node[3]]).collect(); - let paths: Vec = unbatched_proof.into_iter().map(|list| list.into()).collect(); + let paths: Vec = unbatched_proof.into_iter().map(|list| list.1.into()).collect(); let iter_pos = positions.iter_mut().map(|a| *a as u64); let nodes_tmp = nodes.clone(); diff --git a/stdlib/tests/main.rs b/stdlib/tests/main.rs index b50db64413..f159a08612 100644 --- a/stdlib/tests/main.rs +++ b/stdlib/tests/main.rs @@ -12,6 +12,7 @@ macro_rules! build_test { mod collections; mod crypto; +mod mast_forest_merge; mod math; mod mem; mod sys; diff --git a/stdlib/tests/mast_forest_merge.rs b/stdlib/tests/mast_forest_merge.rs new file mode 100644 index 0000000000..3040551b94 --- /dev/null +++ b/stdlib/tests/mast_forest_merge.rs @@ -0,0 +1,19 @@ +use processor::MastForest; + +/// Tests that the stdlib merged with itself produces a forest that has the same procedure +/// roots. +/// +/// This test is added here since we do not have the StdLib in miden-core where merging is +/// implemented and the StdLib serves as a convenient example of a large MastForest. +#[test] +fn mast_forest_merge_stdlib() { + let std_lib = miden_stdlib::StdLibrary::default(); + let std_forest = std_lib.mast_forest().as_ref(); + + let (merged, _) = MastForest::merge([std_forest, std_forest]).unwrap(); + + let merged_digests = merged.procedure_digests().collect::>(); + for digest in std_forest.procedure_digests() { + assert!(merged_digests.contains(&digest)); + } +} diff --git a/stdlib/tests/math/ecgfp5/base_field.rs b/stdlib/tests/math/ecgfp5/base_field.rs index e38d80fd0f..8d204f81a8 100644 --- a/stdlib/tests/math/ecgfp5/base_field.rs +++ b/stdlib/tests/math/ecgfp5/base_field.rs @@ -585,6 +585,7 @@ fn test_ext5_sqrt() { begin exec.base_field::sqrt + movup.6 drop end"; let a = Ext5::rand(); diff --git a/stdlib/tests/math/ecgfp5/group.rs b/stdlib/tests/math/ecgfp5/group.rs index 3fcc067a6f..10b983cfe4 100644 --- a/stdlib/tests/math/ecgfp5/group.rs +++ b/stdlib/tests/math/ecgfp5/group.rs @@ -1,6 +1,6 @@ use std::ops::Add; -use test_utils::{test_case, Felt, ONE, ZERO}; +use test_utils::{push_inputs, test_case, Felt, ONE, ZERO}; use super::base_field::{bv_or, Ext5}; @@ -255,9 +255,12 @@ fn test_ec_ext5_point_validate(a0: u64, a1: u64, a2: u64, a3: u64, a4: u64, shou fn test_ec_ext5_point_decode(a0: u64, a1: u64, a2: u64, a3: u64, a4: u64, should_decode: bool) { let source = " use.std::math::ecgfp5::group + use.std::sys begin exec.group::decode + + exec.sys::truncate_stack end"; let w = Ext5::new(a0, a1, a2, a3, a4); @@ -299,6 +302,7 @@ fn test_ec_ext5_point_encode(a0: u64, a1: u64, a2: u64, a3: u64, a4: u64) { begin exec.group::encode + swapdw dropw dropw end"; let w = Ext5::new(a0, a1, a2, a3, a4); @@ -359,13 +363,6 @@ fn test_ec_ext5_point_addition( c3: u64, c4: u64, ) { - let source = " - use.std::math::ecgfp5::group - - begin - exec.group::add - end"; - let w0 = Ext5::new(a0, a1, a2, a3, a4); let w1 = Ext5::new(b0, b1, b2, b3, b4); let w2 = Ext5::new(c0, c1, c2, c3, c4); @@ -403,7 +400,21 @@ fn test_ec_ext5_point_addition( ]; stack.reverse(); - let test = build_test!(source, &stack); + let source = format!( + " + use.std::math::ecgfp5::group + use.std::sys + + begin + {inputs} + exec.group::add + + exec.sys::truncate_stack + end", + inputs = push_inputs(&stack) + ); + + let test = build_test!(source, &[]); let strace = test.get_last_stack_state(); assert_eq!(strace[0], q2.x.a0); @@ -437,9 +448,12 @@ fn test_ec_ext5_point_doubling( ) { let source = " use.std::math::ecgfp5::group + use.std::sys begin exec.group::double + + exec.sys::truncate_stack end"; let w0 = Ext5::new(a0, a1, a2, a3, a4); @@ -485,13 +499,6 @@ fn test_ec_ext5_point_doubling( // Test vectors taken from https://github.com/pornin/ecgfp5/blob/ce059c6/python/ecGFp5.py#L1528-L1558 #[test] fn test_ec_ext5_point_multiplication() { - let source = " - use.std::math::ecgfp5::group - - begin - exec.group::mul - end"; - let w0 = Ext5::new( 12539254003028696409, 15524144070600887654, @@ -553,7 +560,21 @@ fn test_ec_ext5_point_multiplication() { ]; stack.reverse(); - let test = build_test!(source, &stack); + let source = format!( + " + use.std::math::ecgfp5::group + use.std::sys + + begin + {inputs} + exec.group::mul + + exec.sys::truncate_stack + end", + inputs = push_inputs(&stack) + ); + + let test = build_test!(source, &[]); let strace = test.get_last_stack_state(); assert_eq!(strace[0], p1.x.a0); @@ -574,9 +595,12 @@ fn test_ec_ext5_point_multiplication() { fn test_ec_ext5_gen_multiplication() { let source = " use.std::math::ecgfp5::group + use.std::sys begin exec.group::gen_mul + + exec.sys::truncate_stack end"; // Conventional generator point of this group diff --git a/stdlib/tests/math/ecgfp5/scalar_field.rs b/stdlib/tests/math/ecgfp5/scalar_field.rs index 7cadf33314..b58fe6e8dd 100644 --- a/stdlib/tests/math/ecgfp5/scalar_field.rs +++ b/stdlib/tests/math/ecgfp5/scalar_field.rs @@ -1,6 +1,6 @@ use std::ops::Mul; -use test_utils::rand::rand_value; +use test_utils::{push_inputs, rand::rand_value}; #[derive(Copy, Clone, Debug)] struct Scalar { @@ -219,13 +219,6 @@ fn test_ec_ext5_scalar_arithmetic() { #[test] fn test_ec_ext5_scalar_mont_mul() { - let source = " - use.std::math::ecgfp5::scalar_field - - begin - exec.scalar_field::mont_mul - end"; - let a = Scalar { limbs: [ rand_value::() >> 1, @@ -263,7 +256,21 @@ fn test_ec_ext5_scalar_mont_mul() { } stack.reverse(); - let test = build_test!(source, &stack); + let source = format!( + " + use.std::math::ecgfp5::scalar_field + use.std::sys + + begin + {inputs} + exec.scalar_field::mont_mul + + exec.sys::truncate_stack + end", + inputs = push_inputs(&stack) + ); + + let test = build_test!(source, &[]); let strace = test.get_last_stack_state(); for (i, limb) in c.limbs.iter().enumerate() { @@ -275,10 +282,13 @@ fn test_ec_ext5_scalar_mont_mul() { fn test_ec_ext5_scalar_to_and_from_mont_repr() { let source = " use.std::math::ecgfp5::scalar_field + use.std::sys begin exec.scalar_field::to_mont exec.scalar_field::from_mont + + exec.sys::truncate_stack end"; let a = Scalar { @@ -318,9 +328,12 @@ fn test_ec_ext5_scalar_to_and_from_mont_repr() { fn test_ec_ext5_scalar_inv() { let source = " use.std::math::ecgfp5::scalar_field + use.std::sys begin exec.scalar_field::inv + + exec.sys::truncate_stack end"; let a = Scalar { diff --git a/stdlib/tests/math/u256_mod.rs b/stdlib/tests/math/u256_mod.rs index 14150c321c..d0eaf32c5a 100644 --- a/stdlib/tests/math/u256_mod.rs +++ b/stdlib/tests/math/u256_mod.rs @@ -13,6 +13,7 @@ fn mul_unsafe() { use.std::math::u256 begin exec.u256::mul_unsafe + swapdw dropw dropw end"; let operands = a diff --git a/stdlib/tests/math/u64_mod.rs b/stdlib/tests/math/u64_mod.rs index 862d5715c5..55cf8206fd 100644 --- a/stdlib/tests/math/u64_mod.rs +++ b/stdlib/tests/math/u64_mod.rs @@ -520,7 +520,7 @@ fn checked_and_fail() { end"; let test = build_test!(source, &[a0, a1, b0, b1]); - expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(b0), ZERO)); + expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(a0), ZERO)); } #[test] @@ -558,7 +558,7 @@ fn checked_or_fail() { end"; let test = build_test!(source, &[a0, a1, b0, b1]); - expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(b0), ZERO)); + expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(a0), ZERO)); } #[test] @@ -596,7 +596,7 @@ fn checked_xor_fail() { end"; let test = build_test!(source, &[a0, a1, b0, b1]); - expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(b0), ZERO)); + expect_exec_error!(test, ExecutionError::NotU32Value(Felt::new(a0), ZERO)); } #[test] diff --git a/stdlib/tests/mem/mod.rs b/stdlib/tests/mem/mod.rs index 7ce4bc80d1..61c3836f9d 100644 --- a/stdlib/tests/mem/mod.rs +++ b/stdlib/tests/mem/mod.rs @@ -1,6 +1,6 @@ use processor::{ContextId, DefaultHost, ProcessState, Program}; use test_utils::{ - build_expected_hash, build_expected_perm, stack_to_ints, ExecutionOptions, Process, + build_expected_hash, build_expected_perm, felt_slice_to_ints, ExecutionOptions, Process, StackInputs, ONE, ZERO, }; @@ -31,7 +31,7 @@ fn test_memcopy() { assembler.assemble_program(source).expect("Failed to compile test source."); let mut host = DefaultHost::default(); - host.load_mast_forest(stdlib.into()); + host.load_mast_forest(stdlib.mast_forest().clone()); let mut process = Process::new( program.kernel().clone(), @@ -98,7 +98,9 @@ fn test_memcopy() { fn test_pipe_double_words_to_memory() { let mem_addr = 1000; let source = format!( - "use.std::mem + " + use.std::mem + use.std::sys begin push.1002 # end_addr @@ -106,6 +108,8 @@ fn test_pipe_double_words_to_memory() { padw padw padw # hasher state exec.mem::pipe_double_words_to_memory + + exec.sys::truncate_stack end", mem_addr ); @@ -113,7 +117,7 @@ fn test_pipe_double_words_to_memory() { let operand_stack = &[]; let data = &[1, 2, 3, 4, 5, 6, 7, 8]; let mut expected_stack = - stack_to_ints(&build_expected_perm(&[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8])); + felt_slice_to_ints(&build_expected_perm(&[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8])); expected_stack.push(1002); build_test!(source, operand_stack, &data).expect_stack_and_memory( &expected_stack, @@ -126,20 +130,26 @@ fn test_pipe_double_words_to_memory() { fn test_pipe_words_to_memory() { let mem_addr = 1000; let one_word = format!( - "use.std::mem + " + use.std::mem + use.std::crypto::hashes::rpo begin push.{} # target address push.1 # number of words exec.mem::pipe_words_to_memory + exec.rpo::squeeze_digest + + # truncate stack + swapdw dropw dropw end", mem_addr ); let operand_stack = &[]; let data = &[1, 2, 3, 4]; - let mut expected_stack = stack_to_ints(&build_expected_hash(data)); + let mut expected_stack = felt_slice_to_ints(&build_expected_hash(data)); expected_stack.push(1001); build_test!(one_word, operand_stack, &data).expect_stack_and_memory( &expected_stack, @@ -148,20 +158,26 @@ fn test_pipe_words_to_memory() { ); let three_words = format!( - "use.std::mem + " + use.std::mem + use.std::crypto::hashes::rpo begin push.{} # target address push.3 # number of words exec.mem::pipe_words_to_memory + exec.rpo::squeeze_digest + + # truncate stack + swapdw dropw dropw end", mem_addr ); let operand_stack = &[]; let data = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let mut expected_stack = stack_to_ints(&build_expected_hash(data)); + let mut expected_stack = felt_slice_to_ints(&build_expected_hash(data)); expected_stack.push(1003); build_test!(three_words, operand_stack, &data).expect_stack_and_memory( &expected_stack, @@ -182,13 +198,14 @@ fn test_pipe_preimage_to_memory() { push.3 # number of words exec.mem::pipe_preimage_to_memory + swap drop end", mem_addr ); let operand_stack = &[]; let data = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let mut advice_stack = stack_to_ints(&build_expected_hash(data)); + let mut advice_stack = felt_slice_to_ints(&build_expected_hash(data)); advice_stack.reverse(); advice_stack.extend(data); build_test!(three_words, operand_stack, &advice_stack).expect_stack_and_memory( @@ -214,7 +231,7 @@ fn test_pipe_preimage_to_memory_invalid_preimage() { let operand_stack = &[]; let data = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let mut advice_stack = stack_to_ints(&build_expected_hash(data)); + let mut advice_stack = felt_slice_to_ints(&build_expected_hash(data)); advice_stack.reverse(); advice_stack[0] += 1; // corrupt the expected hash advice_stack.extend(data); diff --git a/stdlib/tests/sys/mod.rs b/stdlib/tests/sys/mod.rs index 50aee4de8c..d9c80dbb99 100644 --- a/stdlib/tests/sys/mod.rs +++ b/stdlib/tests/sys/mod.rs @@ -1,4 +1,4 @@ -use test_utils::{proptest::prelude::*, rand::rand_vector, STACK_TOP_SIZE}; +use test_utils::{proptest::prelude::*, rand::rand_vector, MIN_STACK_DEPTH}; #[test] fn truncate_stack() { @@ -9,7 +9,7 @@ fn truncate_stack() { proptest! { #[test] - fn truncate_stack_proptest(test_values in prop::collection::vec(any::(), STACK_TOP_SIZE), n in 1_usize..100) { + fn truncate_stack_proptest(test_values in prop::collection::vec(any::(), MIN_STACK_DEPTH), n in 1_usize..100) { let mut push_values = rand_vector::(n); let mut source_vec = vec!["use.std::sys".to_string(), "begin".to_string()]; for value in push_values.iter() { @@ -22,7 +22,7 @@ proptest! { let mut expected_values = test_values.clone(); expected_values.append(&mut push_values); expected_values.reverse(); - expected_values.truncate(STACK_TOP_SIZE); + expected_values.truncate(MIN_STACK_DEPTH); build_test!(&source, &test_values).prop_expect_stack(&expected_values)?; } } diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 0d6ac384b2..5f02885ef9 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -24,25 +24,23 @@ std = [ ] [dependencies] -air = { package = "miden-air", path = "../air", version = "0.10", default-features = false } -assembly = { package = "miden-assembly", path = "../assembly", version = "0.10", default-features = false, features = [ +air = { package = "miden-air", path = "../air", version = "0.11", default-features = false } +assembly = { package = "miden-assembly", path = "../assembly", version = "0.11", default-features = false, features = [ "testing", ] } -processor = { package = "miden-processor", path = "../processor", version = "0.10", default-features = false, features = [ +processor = { package = "miden-processor", path = "../processor", version = "0.11", default-features = false, features = [ "testing", ] } -prover = { package = "miden-prover", path = "../prover", version = "0.10", default-features = false } +prover = { package = "miden-prover", path = "../prover", version = "0.11", default-features = false } test-case = "3.2" -verifier = { package = "miden-verifier", path = "../verifier", version = "0.10", default-features = false } -vm-core = { package = "miden-core", path = "../core", version = "0.10", default-features = false } -winter-prover = { package = "winter-prover", version = "0.9", default-features = false } +verifier = { package = "miden-verifier", path = "../verifier", version = "0.11", default-features = false } +vm-core = { package = "miden-core", path = "../core", version = "0.11", default-features = false } +winter-prover = { package = "winter-prover", version = "0.10", default-features = false } [target.'cfg(target_family = "wasm")'.dependencies] -pretty_assertions = { version = "1.4", default-features = false, features = [ - "alloc", -] } +pretty_assertions = { version = "1.4", default-features = false, features = ["alloc"] } [target.'cfg(not(target_family = "wasm"))'.dependencies] pretty_assertions = "1.4" proptest = "1.4" -rand-utils = { package = "winter-rand-utils", version = "0.9" } +rand-utils = { package = "winter-rand-utils", version = "0.10" } diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index c6253a0a16..c2515702d8 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -5,9 +5,6 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; -// IMPORTS -// ================================================================================================ - #[cfg(not(target_family = "wasm"))] use alloc::format; use alloc::{ @@ -16,27 +13,26 @@ use alloc::{ vec::Vec, }; -use assembly::Library; -// EXPORTS -// ================================================================================================ pub use assembly::{diagnostics::Report, LibraryPath, SourceFile, SourceManager}; +use assembly::{KernelLibrary, Library}; pub use pretty_assertions::{assert_eq, assert_ne, assert_str_eq}; +use processor::Program; pub use processor::{ AdviceInputs, AdviceProvider, ContextId, DefaultHost, ExecutionError, ExecutionOptions, - ExecutionTrace, Process, ProcessState, StackInputs, VmStateIterator, + ExecutionTrace, Process, ProcessState, VmStateIterator, }; -use processor::{MastForest, Program}; #[cfg(not(target_family = "wasm"))] use proptest::prelude::{Arbitrary, Strategy}; -pub use prover::{prove, MemAdviceProvider, ProvingOptions}; +pub use prover::{prove, MemAdviceProvider, MerkleTreeVC, ProvingOptions}; pub use test_case::test_case; pub use verifier::{verify, AcceptableOptions, VerifierError}; use vm_core::{chiplets::hasher::apply_permutation, ProgramInfo}; pub use vm_core::{ chiplets::hasher::{hash_elements, STATE_WIDTH}, - stack::STACK_TOP_SIZE, + stack::MIN_STACK_DEPTH, utils::{collections, group_slice_elements, IntoBytes, ToElements}, - Felt, FieldElement, StarkField, Word, EMPTY_WORD, ONE, WORD_SIZE, ZERO, + Felt, FieldElement, StackInputs, StackOutputs, StarkField, Word, EMPTY_WORD, ONE, WORD_SIZE, + ZERO, }; pub mod math { @@ -72,6 +68,19 @@ pub type QuadFelt = vm_core::QuadExtension; /// A value just over what a [u32] integer can hold. pub const U32_BOUND: u64 = u32::MAX as u64 + 1; +/// A source code of the `truncate_stack` procedure. +pub const TRUNCATE_STACK_PROC: &str = " +proc.truncate_stack.1 + loc_storew.0 dropw movupw.3 + sdepth neq.16 + while.true + dropw movupw.3 + sdepth neq.16 + end + loc_loadw.0 +end +"; + // TEST HANDLER // ================================================================================================ @@ -119,8 +128,8 @@ macro_rules! expect_exec_error { }; } -/// Like [assembly::assert_diagnostic], but matches each non-empty line of the rendered output -/// to a corresponding pattern. +/// Like [assembly::assert_diagnostic], but matches each non-empty line of the rendered output to a +/// corresponding pattern. /// /// So if the output has 3 lines, the second of which is empty, and you provide 2 patterns, the /// assertion passes if the first line matches the first pattern, and the third line matches the @@ -212,8 +221,8 @@ impl Test { /// test will result in the expected final stack state. #[track_caller] pub fn expect_stack(&self, final_stack: &[u64]) { - let result = stack_to_ints(&self.get_last_stack_state()); - let expected = stack_top_to_ints(final_stack); + let result = self.get_last_stack_state().as_int_vec(); + let expected = resize_to_min_stack_depth(final_stack); assert_eq!(expected, result, "Expected stack to be {:?}, found {:?}", expected, result); } @@ -231,7 +240,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -253,7 +262,7 @@ impl Test { let mem_state = process.get_mem_value(ContextId::root(), mem_start_addr).unwrap_or(EMPTY_WORD); - let mem_state = stack_to_ints(&mem_state); + let mem_state = felt_slice_to_ints(&mem_state); assert_eq!( data, mem_state, "Expected memory [{}] => {:?}, found {:?}", @@ -262,7 +271,7 @@ impl Test { mem_start_addr += 1; } - // validate the stack state + // validate the stack states self.expect_stack(final_stack); } @@ -274,8 +283,8 @@ impl Test { &self, final_stack: &[u64], ) -> Result<(), proptest::prelude::TestCaseError> { - let result = self.get_last_stack_state(); - proptest::prop_assert_eq!(stack_top_to_ints(final_stack), stack_to_ints(&result)); + let result = self.get_last_stack_state().as_int_vec(); + proptest::prop_assert_eq!(resize_to_min_stack_depth(final_stack), result); Ok(()) } @@ -283,22 +292,26 @@ impl Test { // UTILITY METHODS // -------------------------------------------------------------------------------------------- - /// Compiles a test's source and returns the resulting Program or Assembly error. - pub fn compile(&self) -> Result<(Program, Option), Report> { + /// Compiles a test's source and returns the resulting Program together with the associated + /// kernel library (when specified). + /// + /// # Errors + /// Returns an error if compilation of the program source or the kernel fails. + pub fn compile(&self) -> Result<(Program, Option), Report> { use assembly::{ast::ModuleKind, Assembler, CompileOptions}; - let (assembler, compiled_kernel) = if let Some(kernel) = self.kernel_source.clone() { + let (assembler, kernel_lib) = if let Some(kernel) = self.kernel_source.clone() { let kernel_lib = Assembler::new(self.source_manager.clone()).assemble_kernel(kernel).unwrap(); - let compiled_kernel = kernel_lib.mast_forest().clone(); ( - Assembler::with_kernel(self.source_manager.clone(), kernel_lib), - Some(compiled_kernel), + Assembler::with_kernel(self.source_manager.clone(), kernel_lib.clone()), + Some(kernel_lib), ) } else { (Assembler::new(self.source_manager.clone()), None) }; + let mut assembler = self .add_modules .iter() @@ -315,7 +328,7 @@ impl Test { assembler.add_library(library).unwrap(); } - Ok((assembler.assemble_program(self.source.clone())?, compiled_kernel)) + Ok((assembler.assemble_program(self.source.clone())?, kernel_lib)) } /// Compiles the test's source to a Program and executes it with the tests inputs. Returns a @@ -325,7 +338,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -341,7 +354,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -365,7 +378,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -390,7 +403,7 @@ impl Test { let (program, kernel) = self.compile().expect("Failed to compile test source."); let mut host = DefaultHost::new(MemAdviceProvider::from(self.advice_inputs.clone())); if let Some(kernel) = kernel { - host.load_mast_forest(kernel); + host.load_mast_forest(kernel.mast_forest().clone()); } for library in &self.libraries { host.load_mast_forest(library.mast_forest().clone()); @@ -400,7 +413,7 @@ impl Test { /// Returns the last state of the stack after executing a test. #[track_caller] - pub fn get_last_stack_state(&self) -> [Felt; STACK_TOP_SIZE] { + pub fn get_last_stack_state(&self) -> StackOutputs { let trace = self.execute().unwrap(); trace.last_stack_state() @@ -410,14 +423,14 @@ impl Test { // HELPER FUNCTIONS // ================================================================================================ -/// Converts an array of Felts into u64 -pub fn stack_to_ints(values: &[Felt]) -> Vec { +/// Converts a slice of Felts into a vector of u64 values. +pub fn felt_slice_to_ints(values: &[Felt]) -> Vec { values.iter().map(|e| (*e).as_int()).collect() } -pub fn stack_top_to_ints(values: &[u64]) -> Vec { +pub fn resize_to_min_stack_depth(values: &[u64]) -> Vec { let mut result: Vec = values.to_vec(); - result.resize(STACK_TOP_SIZE, 0); + result.resize(MIN_STACK_DEPTH, 0); result } @@ -453,3 +466,12 @@ pub fn build_expected_hash(values: &[u64]) -> [Felt; 4] { expected } + +// Generates the MASM code which pushes the input values during the execution of the program. +#[cfg(all(feature = "std", not(target_family = "wasm")))] +pub fn push_inputs(inputs: &[u64]) -> String { + let mut result = String::new(); + + inputs.iter().for_each(|v| result.push_str(&format!("push.{}\n", v))); + result +} diff --git a/test-utils/src/test_builders.rs b/test-utils/src/test_builders.rs index 3ca22c940d..2f8ea1f79d 100644 --- a/test-utils/src/test_builders.rs +++ b/test-utils/src/test_builders.rs @@ -17,11 +17,37 @@ #[macro_export] macro_rules! build_op_test { ($op_str:expr) => {{ - let source = format!("begin {} end", $op_str); + let source = format!(" +proc.truncate_stack.1 + loc_storew.0 dropw movupw.3 + sdepth neq.16 + while.true + dropw movupw.3 + sdepth neq.16 + end + loc_loadw.0 +end + +begin {} exec.truncate_stack end", + $op_str + ); $crate::build_test!(&source) }}; ($op_str:expr, $($tail:tt)+) => {{ - let source = format!("begin {} end", $op_str); + let source = format!(" +proc.truncate_stack.1 + loc_storew.0 dropw movupw.3 + sdepth neq.16 + while.true + dropw movupw.3 + sdepth neq.16 + end + loc_loadw.0 +end + +begin {} exec.truncate_stack end", + $op_str + ); $crate::build_test!(&source, $($tail)+) }}; } diff --git a/verifier/Cargo.toml b/verifier/Cargo.toml index a52ef82932..a72dc4e5d2 100644 --- a/verifier/Cargo.toml +++ b/verifier/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "miden-verifier" -version = "0.10.5" +version = "0.11.0" description="Miden VM execution verifier" -documentation = "https://docs.rs/miden-verifier/0.10.5" +documentation = "https://docs.rs/miden-verifier/0.11.0" readme = "README.md" categories = ["cryptography", "no-std"] keywords = ["miden", "stark", "verifier", "zkp"] @@ -22,7 +22,7 @@ default = ["std"] std = ["air/std", "vm-core/std", "winter-verifier/std"] [dependencies] -air = { package = "miden-air", path = "../air", version = "0.10", default-features = false } +air = { package = "miden-air", path = "../air", version = "0.11", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"] } -vm-core = { package = "miden-core", path = "../core", version = "0.10", default-features = false } -winter-verifier = { package = "winter-verifier", version = "0.9", default-features = false } +vm-core = { package = "miden-core", path = "../core", version = "0.11", default-features = false } +winter-verifier = { package = "winter-verifier", version = "0.10", default-features = false } diff --git a/verifier/README.md b/verifier/README.md index 691b9144c4..5f896e022f 100644 --- a/verifier/README.md +++ b/verifier/README.md @@ -28,6 +28,7 @@ Miden verifier can be compiled with the following features: * `std` - enabled by default and relies on the Rust standard library. * `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly. + * Only the `wasm32-unknown-unknown` and `wasm32-wasip1` targets are officially supported. To compile with `no_std`, disable default features via `--no-default-features` flag. diff --git a/verifier/src/lib.rs b/verifier/src/lib.rs index e015f110ac..0ccf8fda2e 100644 --- a/verifier/src/lib.rs +++ b/verifier/src/lib.rs @@ -16,7 +16,7 @@ use vm_core::crypto::{ // EXPORTS // ================================================================================================ pub use vm_core::{chiplets::hasher::Digest, Kernel, ProgramInfo, StackInputs, StackOutputs, Word}; -use winter_verifier::verify as verify_proof; +use winter_verifier::{crypto::MerkleTree, verify as verify_proof}; pub use winter_verifier::{AcceptableOptions, VerifierError}; pub mod math { pub use vm_core::{Felt, FieldElement, StarkField}; @@ -69,25 +69,33 @@ pub fn verify( match hash_fn { HashFunction::Blake3_192 => { let opts = AcceptableOptions::OptionSet(vec![ProvingOptions::REGULAR_96_BITS]); - verify_proof::>(proof, pub_inputs, &opts) + verify_proof::, MerkleTree<_>>( + proof, pub_inputs, &opts, + ) }, HashFunction::Blake3_256 => { let opts = AcceptableOptions::OptionSet(vec![ProvingOptions::REGULAR_128_BITS]); - verify_proof::>(proof, pub_inputs, &opts) + verify_proof::, MerkleTree<_>>( + proof, pub_inputs, &opts, + ) }, HashFunction::Rpo256 => { let opts = AcceptableOptions::OptionSet(vec![ ProvingOptions::RECURSIVE_96_BITS, ProvingOptions::RECURSIVE_128_BITS, ]); - verify_proof::(proof, pub_inputs, &opts) + verify_proof::>( + proof, pub_inputs, &opts, + ) }, HashFunction::Rpx256 => { let opts = AcceptableOptions::OptionSet(vec![ ProvingOptions::RECURSIVE_96_BITS, ProvingOptions::RECURSIVE_128_BITS, ]); - verify_proof::(proof, pub_inputs, &opts) + verify_proof::>( + proof, pub_inputs, &opts, + ) }, } .map_err(VerificationError::VerifierError)?;