diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9145b53..5eb4c82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,11 +3,11 @@ name: CI on: push: branches: - - develop + - dev - master pull_request: branches: - - develop + - dev - master jobs: @@ -17,18 +17,17 @@ jobs: - uses: actions/checkout@v3 - name: Update Path run: echo "$RUNNER_WORKSPACE/$(basename $GITHUB_REPOSITORY)" >> $GITHUB_PATH # Make it accessible from runner - - name: Install solc - run: | - set -x - wget -c https://github.com/ethereum/solidity/releases/download/v0.5.17/solc-static-linux - mv solc-static-linux solc - chmod +x solc - solc --version - name: Setup Node.js environment uses: actions/setup-node@v3 with: node-version: '16' registry-url: 'https://registry.npmjs.org' + - uses: dtolnay/rust-toolchain@nightly + - name: Install svm + run: | + cargo install svm-rs + svm install 0.5.17 + svm install 0.6.12 - name: Install Foundry uses: foundry-rs/foundry-toolchain@v1 with: diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..420a7b3 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,16 @@ +{ + "tabWidth": 2, + "useTabs": false, + "semi": false, + "singleQuote": true, + "trailingComma": "none", + "overrides": [ + { + "files": "*.sol", + "options": { + "printWidth": 120, + "singleQuote": false + } + } + ] +} diff --git a/contracts/IStateReceiver.sol b/contracts/IStateReceiver.sol index 03a8ee1..4107085 100644 --- a/contracts/IStateReceiver.sol +++ b/contracts/IStateReceiver.sol @@ -1,4 +1,4 @@ -pragma solidity ^0.5.11; +pragma solidity >0.5.11; // IStateReceiver represents interface to receive state interface IStateReceiver { diff --git a/contracts/StateReceiver.sol b/contracts/StateReceiver.sol index 0575c49..2fd8843 100644 --- a/contracts/StateReceiver.sol +++ b/contracts/StateReceiver.sol @@ -1,8 +1,8 @@ -pragma solidity ^0.5.11; +pragma solidity 0.6.12; -import { RLPReader } from "solidity-rlp/contracts/RLPReader.sol"; - -import { System } from "./System.sol"; +import {RLPReader} from "./utils/RLPReader.sol"; +import {System} from "./System.sol"; +import {IStateReceiver} from "./IStateReceiver.sol"; contract StateReceiver is System { using RLPReader for bytes; @@ -10,34 +10,97 @@ contract StateReceiver is System { uint256 public lastStateId; + bytes32 public failedStateSyncsRoot; + mapping(bytes32 => bool) public nullifier; + + mapping(uint256 => bytes) public failedStateSyncs; + + address public immutable rootSetter; + uint256 public leafCount; + uint256 public replayCount; + uint256 public constant TREE_DEPTH = 16; + event StateCommitted(uint256 indexed stateId, bool success); + event StateSyncReplay(uint256 indexed stateId); + + constructor(address _rootSetter) public { + rootSetter = _rootSetter; + } - function commitState(uint256 syncTime, bytes calldata recordBytes) external onlySystem returns(bool success) { + function commitState(uint256 syncTime, bytes calldata recordBytes) external onlySystem returns (bool success) { // parse state data RLPReader.RLPItem[] memory dataList = recordBytes.toRlpItem().toList(); uint256 stateId = dataList[0].toUint(); - require( - lastStateId + 1 == stateId, - "StateIds are not sequential" - ); + require(lastStateId + 1 == stateId, "StateIds are not sequential"); lastStateId++; - address receiver = dataList[1].toAddress(); bytes memory stateData = dataList[2].toBytes(); // notify state receiver contract, in a non-revert manner if (isContract(receiver)) { uint256 txGas = 5000000; + bytes memory data = abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, stateData); // solium-disable-next-line security/no-inline-assembly assembly { success := call(txGas, receiver, 0, add(data, 0x20), mload(data), 0, 0) } emit StateCommitted(stateId, success); + if (!success) failedStateSyncs[stateId] = abi.encode(receiver, stateData); } } + function replayFailedStateSync(uint256 stateId) external { + bytes memory stateSyncData = failedStateSyncs[stateId]; + require(stateSyncData.length != 0, "!found"); + delete failedStateSyncs[stateId]; + + (address receiver, bytes memory stateData) = abi.decode(stateSyncData, (address, bytes)); + emit StateSyncReplay(stateId); + IStateReceiver(receiver).onStateReceive(stateId, stateData); // revertable + } + + function setRootAndLeafCount(bytes32 _root, uint256 _leafCount) external { + require(msg.sender == rootSetter, "!rootSetter"); + failedStateSyncsRoot = _root; + leafCount = _leafCount; + } + + function replayHistoricFailedStateSync( + bytes32[TREE_DEPTH] calldata proof, + uint256 leafIndex, + uint256 stateId, + address receiver, + bytes calldata data + ) external { + require(leafIndex < 2 ** TREE_DEPTH, "invalid leafIndex"); + require(++replayCount <= leafCount, "end"); + bytes32 root = failedStateSyncsRoot; + require(root != bytes32(0), "!root"); + + bytes32 leafHash = keccak256(abi.encode(stateId, receiver, data)); + bytes32 zeroHash = 0x28cf91ac064e179f8a42e4b7a20ba080187781da55fd4f3f18870b7a25bacb55; // keccak256(abi.encode(uint256(0), address(0), new bytes(0))); + require(leafHash != zeroHash && !nullifier[leafHash], "used"); + nullifier[leafHash] = true; + + require(root == _getRoot(proof, leafIndex, leafHash), "!proof"); + + emit StateSyncReplay(stateId); + IStateReceiver(receiver).onStateReceive(stateId, data); + } + + function _getRoot(bytes32[TREE_DEPTH] memory proof, uint256 index, bytes32 leafHash) private pure returns (bytes32) { + bytes32 node = leafHash; + + for (uint256 height = 0; height < TREE_DEPTH; height++) { + if (((index >> height) & 1) == 1) node = keccak256(abi.encodePacked(proof[height], node)); + else node = keccak256(abi.encodePacked(node, proof[height])); + } + + return node; + } + // check if address is contract - function isContract(address _addr) private view returns (bool){ + function isContract(address _addr) private view returns (bool) { uint32 size; // solium-disable-next-line security/no-inline-assembly assembly { diff --git a/contracts/System.sol b/contracts/System.sol index aec687c..d0de549 100644 --- a/contracts/System.sol +++ b/contracts/System.sol @@ -1,4 +1,4 @@ -pragma solidity ^0.5.11; +pragma solidity >0.5.11; contract System { address public constant SYSTEM_ADDRESS = 0xffffFFFfFFffffffffffffffFfFFFfffFFFfFFfE; diff --git a/contracts/test/TestStateReceiver.sol b/contracts/test/TestStateReceiver.sol deleted file mode 100644 index a555e8f..0000000 --- a/contracts/test/TestStateReceiver.sol +++ /dev/null @@ -1,7 +0,0 @@ -pragma solidity ^0.5.11; -pragma experimental ABIEncoderV2; - -import {StateReceiver} from "../StateReceiver.sol"; -import {TestSystem} from "./TestSystem.sol"; - -contract TestStateReceiver is StateReceiver, TestSystem {} diff --git a/contracts/utils/RLPReader.sol b/contracts/utils/RLPReader.sol new file mode 100644 index 0000000..3410c01 --- /dev/null +++ b/contracts/utils/RLPReader.sol @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: Apache-2.0 + +/* + * @author Hamdi Allam hamdi.allam97@gmail.com + * Please reach out with any questions or concerns + */ +pragma solidity >=0.5.10 <0.9.0; + +library RLPReader { + uint8 constant STRING_SHORT_START = 0x80; + uint8 constant STRING_LONG_START = 0xb8; + uint8 constant LIST_SHORT_START = 0xc0; + uint8 constant LIST_LONG_START = 0xf8; + uint8 constant WORD_SIZE = 32; + + struct RLPItem { + uint256 len; + uint256 memPtr; + } + + struct Iterator { + RLPItem item; // Item that's being iterated over. + uint256 nextPtr; // Position of the next item in the list. + } + + /* + * @dev Returns the next element in the iteration. Reverts if it has not next element. + * @param self The iterator. + * @return The next element in the iteration. + */ + function next(Iterator memory self) internal pure returns (RLPItem memory) { + require(hasNext(self)); + + uint256 ptr = self.nextPtr; + uint256 itemLength = _itemLength(ptr); + self.nextPtr = ptr + itemLength; + + return RLPItem(itemLength, ptr); + } + + /* + * @dev Returns true if the iteration has more elements. + * @param self The iterator. + * @return true if the iteration has more elements. + */ + function hasNext(Iterator memory self) internal pure returns (bool) { + RLPItem memory item = self.item; + return self.nextPtr < item.memPtr + item.len; + } + + /* + * @param item RLP encoded bytes + */ + function toRlpItem(bytes memory item) internal pure returns (RLPItem memory) { + uint256 memPtr; + assembly { + memPtr := add(item, 0x20) + } + + return RLPItem(item.length, memPtr); + } + + /* + * @dev Create an iterator. Reverts if item is not a list. + * @param self The RLP item. + * @return An 'Iterator' over the item. + */ + function iterator(RLPItem memory self) internal pure returns (Iterator memory) { + require(isList(self)); + + uint256 ptr = self.memPtr + _payloadOffset(self.memPtr); + return Iterator(self, ptr); + } + + /* + * @param the RLP item. + */ + function rlpLen(RLPItem memory item) internal pure returns (uint256) { + return item.len; + } + + /* + * @param the RLP item. + * @return (memPtr, len) pair: location of the item's payload in memory. + */ + function payloadLocation(RLPItem memory item) internal pure returns (uint256, uint256) { + uint256 offset = _payloadOffset(item.memPtr); + uint256 memPtr = item.memPtr + offset; + uint256 len = item.len - offset; // data length + return (memPtr, len); + } + + /* + * @param the RLP item. + */ + function payloadLen(RLPItem memory item) internal pure returns (uint256) { + (, uint256 len) = payloadLocation(item); + return len; + } + + /* + * @param the RLP item containing the encoded list. + */ + function toList(RLPItem memory item) internal pure returns (RLPItem[] memory) { + require(isList(item)); + + uint256 items = numItems(item); + RLPItem[] memory result = new RLPItem[](items); + + uint256 memPtr = item.memPtr + _payloadOffset(item.memPtr); + uint256 dataLen; + for (uint256 i = 0; i < items; i++) { + dataLen = _itemLength(memPtr); + result[i] = RLPItem(dataLen, memPtr); + memPtr = memPtr + dataLen; + } + require(memPtr - item.memPtr == item.len, "Wrong total length."); + + return result; + } + + // @return indicator whether encoded payload is a list. negate this function call for isData. + function isList(RLPItem memory item) internal pure returns (bool) { + if (item.len == 0) return false; + + uint8 byte0; + uint256 memPtr = item.memPtr; + assembly { + byte0 := byte(0, mload(memPtr)) + } + + if (byte0 < LIST_SHORT_START) return false; + return true; + } + + /* + * @dev A cheaper version of keccak256(toRlpBytes(item)) that avoids copying memory. + * @return keccak256 hash of RLP encoded bytes. + */ + function rlpBytesKeccak256(RLPItem memory item) internal pure returns (bytes32) { + uint256 ptr = item.memPtr; + uint256 len = item.len; + bytes32 result; + assembly { + result := keccak256(ptr, len) + } + return result; + } + + /* + * @dev A cheaper version of keccak256(toBytes(item)) that avoids copying memory. + * @return keccak256 hash of the item payload. + */ + function payloadKeccak256(RLPItem memory item) internal pure returns (bytes32) { + (uint256 memPtr, uint256 len) = payloadLocation(item); + bytes32 result; + assembly { + result := keccak256(memPtr, len) + } + return result; + } + + /** RLPItem conversions into data types **/ + + // @returns raw rlp encoding in bytes + function toRlpBytes(RLPItem memory item) internal pure returns (bytes memory) { + bytes memory result = new bytes(item.len); + if (result.length == 0) return result; + + uint256 ptr; + assembly { + ptr := add(0x20, result) + } + + copy(item.memPtr, ptr, item.len); + return result; + } + + // any non-zero byte except "0x80" is considered true + function toBoolean(RLPItem memory item) internal pure returns (bool) { + require(item.len == 1); + uint256 result; + uint256 memPtr = item.memPtr; + assembly { + result := byte(0, mload(memPtr)) + } + + // SEE Github Issue #5. + // Summary: Most commonly used RLP libraries (i.e Geth) will encode + // "0" as "0x80" instead of as "0". We handle this edge case explicitly + // here. + if (result == 0 || result == STRING_SHORT_START) { + return false; + } else { + return true; + } + } + + function toAddress(RLPItem memory item) internal pure returns (address) { + // 1 byte for the length prefix + require(item.len == 21); + + return address(uint160(toUint(item))); + } + + function toUint(RLPItem memory item) internal pure returns (uint256) { + require(item.len > 0 && item.len <= 33); + + (uint256 memPtr, uint256 len) = payloadLocation(item); + + uint256 result; + assembly { + result := mload(memPtr) + + // shift to the correct location if neccesary + if lt(len, 32) { + result := div(result, exp(256, sub(32, len))) + } + } + + return result; + } + + // enforces 32 byte length + function toUintStrict(RLPItem memory item) internal pure returns (uint256) { + // one byte prefix + require(item.len == 33); + + uint256 result; + uint256 memPtr = item.memPtr + 1; + assembly { + result := mload(memPtr) + } + + return result; + } + + function toBytes(RLPItem memory item) internal pure returns (bytes memory) { + require(item.len > 0); + + (uint256 memPtr, uint256 len) = payloadLocation(item); + bytes memory result = new bytes(len); + + uint256 destPtr; + assembly { + destPtr := add(0x20, result) + } + + copy(memPtr, destPtr, len); + return result; + } + + /* + * Private Helpers + */ + + // @return number of payload items inside an encoded list. + function numItems(RLPItem memory item) private pure returns (uint256) { + if (item.len == 0) return 0; + + uint256 count = 0; + uint256 currPtr = item.memPtr + _payloadOffset(item.memPtr); + uint256 endPtr = item.memPtr + item.len; + while (currPtr < endPtr) { + currPtr = currPtr + _itemLength(currPtr); // skip over an item + count++; + } + + return count; + } + + // @return entire rlp item byte length + function _itemLength(uint256 memPtr) private pure returns (uint256) { + uint256 itemLen; + uint256 byte0; + assembly { + byte0 := byte(0, mload(memPtr)) + } + + if (byte0 < STRING_SHORT_START) { + itemLen = 1; + } else if (byte0 < STRING_LONG_START) { + itemLen = byte0 - STRING_SHORT_START + 1; + } else if (byte0 < LIST_SHORT_START) { + assembly { + let byteLen := sub(byte0, 0xb7) // # of bytes the actual length is + memPtr := add(memPtr, 1) // skip over the first byte + + /* 32 byte word size */ + let dataLen := div(mload(memPtr), exp(256, sub(32, byteLen))) // right shifting to get the len + itemLen := add(dataLen, add(byteLen, 1)) + } + } else if (byte0 < LIST_LONG_START) { + itemLen = byte0 - LIST_SHORT_START + 1; + } else { + assembly { + let byteLen := sub(byte0, 0xf7) + memPtr := add(memPtr, 1) + + let dataLen := div(mload(memPtr), exp(256, sub(32, byteLen))) // right shifting to the correct length + itemLen := add(dataLen, add(byteLen, 1)) + } + } + + return itemLen; + } + + // @return number of bytes until the data + function _payloadOffset(uint256 memPtr) private pure returns (uint256) { + uint256 byte0; + assembly { + byte0 := byte(0, mload(memPtr)) + } + + if (byte0 < STRING_SHORT_START) { + return 0; + } else if (byte0 < STRING_LONG_START || (byte0 >= LIST_SHORT_START && byte0 < LIST_LONG_START)) { + return 1; + } else if (byte0 < LIST_SHORT_START) { + // being explicit + return byte0 - (STRING_LONG_START - 1) + 1; + } else { + return byte0 - (LIST_LONG_START - 1) + 1; + } + } + + /* + * @param src Pointer to source + * @param dest Pointer to destination + * @param len Amount of memory to copy from the source + */ + function copy(uint256 src, uint256 dest, uint256 len) private pure { + if (len == 0) return; + + // copy as many word sizes as possible + for (; len >= WORD_SIZE; len -= WORD_SIZE) { + assembly { + mstore(dest, mload(src)) + } + + src += WORD_SIZE; + dest += WORD_SIZE; + } + + if (len > 0) { + // left over bytes. Mask is used to remove unwanted bytes from the word + uint256 mask = 256 ** (WORD_SIZE - len) - 1; + assembly { + let srcpart := and(mload(src), not(mask)) // zero out src + let destpart := and(mload(dest), mask) // retrieve the bytes + mstore(dest, or(destpart, srcpart)) + } + } + } +} diff --git a/foundry.toml b/foundry.toml index c576fce..4f7fcac 100644 --- a/foundry.toml +++ b/foundry.toml @@ -3,8 +3,9 @@ src = "contracts" out = "out" libs = ["lib"] optimizer = true -optimizer_runs = 999999 -via_ir = true +optimizer_runs = 200 +via_ir = false +bytecode_hash = "none" verbosity = 2 ffi = true fs_permissions = [{ access = "read", path = "./out/"}] @@ -42,4 +43,4 @@ mumbai = { key = "${POLYGONSCAN_API_KEY}" } polygon_zkevm = { key = "${POLYGONSCAN_ZKEVM_API_KEY}" } polygon_zkevm_testnet = { key = "${POLYGONSCAN_ZKEVM_API_KEY}" } -# See more config options https://github.com/foundry-rs/foundry/tree/master/config \ No newline at end of file +# See more config options https://github.com/foundry-rs/foundry/tree/master/config diff --git a/generate-genesis.js b/generate-genesis.js index 7865a36..bf128bb 100644 --- a/generate-genesis.js +++ b/generate-genesis.js @@ -26,18 +26,12 @@ program.option( program.parse(process.argv) // compile contract -function compileContract(key, contractFile, contractName) { +async function compileContract(key, contractFile, contractName, solcVersion) { return new Promise((resolve, reject) => { - const ls = spawn("solc", [ - "--bin-runtime", - "openzeppelin-solidity/=node_modules/openzeppelin-solidity/", - "solidity-rlp/=node_modules/solidity-rlp/", - "/=/", - // "--optimize", - // "--optimize-runs", - // "200", - contractFile - ]) + const ls = spawn( + `svm use ${solcVersion} && solc --bin-runtime openzeppelin-solidity/=node_modules/openzeppelin-solidity/ solidity-rlp/=node_modules/solidity-rlp/ /=/ ${contractFile}`, + { shell: true } + ) const result = [] ls.stdout.on("data", data => { @@ -64,25 +58,32 @@ function compileContract(key, contractFile, contractName) { }) } -// compile files -Promise.all([ - compileContract( - "borValidatorSetContract", - "contracts/BorValidatorSet.sol", - "BorValidatorSet" - ), - compileContract( - "borStateReceiverContract", - "contracts/StateReceiver.sol", - "StateReceiver" - ), - compileContract( - "maticChildERC20Contract", - "matic-contracts/contracts/child/MRC20.sol", - "MRC20" - ) -]).then(result => { - const totalMaticSupply = web3.utils.toBN("10000000000") +// compile files sequentially +async function main() { + const result = [] + for (const file of [ + [ + "borValidatorSetContract", + "contracts/BorValidatorSet.sol", + "BorValidatorSet", + "0.5.17" + ], + [ + "borStateReceiverContract", + "contracts/StateReceiver.sol", + "StateReceiver", + "0.6.12" + ], + [ + "maticChildERC20Contract", + "matic-contracts/contracts/child/MRC20.sol", + "MRC20", + "0.5.17" + ] + ]) { + result.push(await compileContract(...file)) + } + const totalMaticSupply = web3.utils.toBN('10000000000') var validatorsBalance = web3.utils.toBN(0) validators.forEach(v => { @@ -108,7 +109,9 @@ Promise.all([ const templateString = fs.readFileSync(program.template).toString() const resultString = nunjucks.renderString(templateString, data) fs.writeFileSync(program.output, resultString) -}).catch(err => { +} + +main().catch(err => { console.log(err) process.exit(1) }) diff --git a/generate.sh b/generate.sh old mode 100644 new mode 100755 index fb5616c..f23ddef --- a/generate.sh +++ b/generate.sh @@ -27,5 +27,4 @@ node scripts/process-templates.js --bor-chain-id $1 npm run truffle:compile cd .. node generate-borvalidatorset.js --bor-chain-id $1 --heimdall-chain-id $2 -npm run truffle:compile node generate-genesis.js --bor-chain-id $1 --heimdall-chain-id $2 diff --git a/lib/forge-std b/lib/forge-std index 978ac6f..58d3051 160000 --- a/lib/forge-std +++ b/lib/forge-std @@ -1 +1 @@ -Subproject commit 978ac6fadb62f5f0b723c996f64be52eddba6801 +Subproject commit 58d30519826c313ce47345abedfdc07679e944d1 diff --git a/test/BorValidatorSet.t.sol b/test/BorValidatorSet.t.sol index fa6255f..2f5245c 100644 --- a/test/BorValidatorSet.t.sol +++ b/test/BorValidatorSet.t.sol @@ -1,4 +1,4 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; @@ -85,50 +85,62 @@ contract BorValidatorSetTest is Test { bytes memory validatorBytes; bytes memory producerBytes; - string[] memory cmd = new string[](4); - cmd[0] = "node"; - cmd[1] = "test/helpers/rlpEncodeValidatorsAndProducers.js"; - cmd[2] = vm.toString(numOfValidators); - cmd[3] = vm.toString(numOfProducers); - bytes memory result = vm.ffi(cmd); - (ids, powers, signers, validatorBytes, producerBytes) = abi.decode(result, (uint256[], uint256[], address[], bytes, bytes)); + { + string[] memory cmd = new string[](4); + cmd[0] = "node"; + cmd[1] = "test/helpers/rlpEncodeValidatorsAndProducers.js"; + cmd[2] = vm.toString(numOfValidators); + cmd[3] = vm.toString(numOfProducers); + bytes memory result = vm.ffi(cmd); + (ids, powers, signers, validatorBytes, producerBytes) = abi.decode(result, (uint256[], uint256[], address[], bytes, bytes)); + } vm.prank(SYSTEM_ADDRESS); borValidatorSet.commitSpan(newSpan, startBlock, endBlock, validatorBytes, producerBytes); - - (uint256 number_, uint256 startBlock_, uint256 endBlock_) = borValidatorSet.spans(0); - assertEq(number_, 0); - assertEq(startBlock_, 0); - assertEq(endBlock_, FIRST_END_BLOCK); - assertEq(borValidatorSet.spanNumbers(0), 0); - (uint256 id, uint256 power, address signer) = borValidatorSet.validators(0, 0); - (address[] memory initialAddresses, uint256[] memory initialPowers) = borValidatorSet.getInitialValidators(); - assertEq(id, 0); - assertEq(power, initialPowers[0]); - assertEq(signer, initialAddresses[0]); - vm.expectRevert(); - borValidatorSet.validators(0, 1); - (id, power, signer) = borValidatorSet.producers(0, 0); - assertEq(id, 0); - assertEq(power, initialPowers[0]); - assertEq(signer, initialAddresses[0]); - vm.expectRevert(); - borValidatorSet.producers(0, 1); - (number_, startBlock_, endBlock_) = borValidatorSet.spans(newSpan); - assertEq(number_, newSpan); - assertEq(startBlock_, startBlock); - assertEq(endBlock_, endBlock); - assertEq(borValidatorSet.spanNumbers(1), newSpan); + { + (uint256 number_, uint256 startBlock_, uint256 endBlock_) = borValidatorSet.spans(0); + assertEq(number_, 0); + assertEq(startBlock_, 0); + assertEq(endBlock_, FIRST_END_BLOCK); + assertEq(borValidatorSet.spanNumbers(0), 0); + } + { + (uint256 id, uint256 power, address signer) = borValidatorSet.validators(0, 0); + (address[] memory initialAddresses, uint256[] memory initialPowers) = borValidatorSet.getInitialValidators(); + assertEq(id, 0); + assertEq(power, initialPowers[0]); + assertEq(signer, initialAddresses[0]); + vm.expectRevert(); + borValidatorSet.validators(0, 1); + } + { + (uint256 id, uint256 power, address signer) = borValidatorSet.producers(0, 0); + assertEq(id, 0); + vm.expectRevert(); + borValidatorSet.producers(0, 1); + } + { + (uint256 number_, uint256 startBlock_, uint256 endBlock_) = borValidatorSet.spans(newSpan); + assertEq(number_, newSpan); + assertEq(startBlock_, startBlock); + assertEq(endBlock_, endBlock); + assertEq(borValidatorSet.spanNumbers(1), newSpan); + } + for (uint256 i = 0; i < numOfValidators; i++) { - (id, power, signer) = borValidatorSet.validators(newSpan, i); - assertEq(id, ids[i]); - assertEq(power, powers[i]); - assertEq(signer, signers[i]); + { + (uint256 id, uint256 power, address signer) = borValidatorSet.validators(newSpan, i); + assertEq(id, ids[i]); + assertEq(power, powers[i]); + assertEq(signer, signers[i]); + } if (i >= numOfProducers) continue; - (id, power, signer) = borValidatorSet.producers(newSpan, i); - assertEq(id, ids[i]); - assertEq(power, powers[i]); - assertEq(signer, signers[i]); + { + (uint256 id, uint256 power, address signer) = borValidatorSet.producers(newSpan, i); + assertEq(id, ids[i]); + assertEq(power, powers[i]); + assertEq(signer, signers[i]); + } } vm.expectRevert(); borValidatorSet.validators(newSpan, numOfValidators); diff --git a/test/ECVerify.t.sol b/test/ECVerify.t.sol index 0d61de6..639db69 100644 --- a/test/ECVerify.t.sol +++ b/test/ECVerify.t.sol @@ -1,4 +1,4 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; diff --git a/test/IterableMapping.t.sol b/test/IterableMapping.t.sol index ac78702..5240292 100644 --- a/test/IterableMapping.t.sol +++ b/test/IterableMapping.t.sol @@ -1,4 +1,4 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; diff --git a/test/Migrations.t.sol b/test/Migrations.t.sol index c190eb3..f083244 100644 --- a/test/Migrations.t.sol +++ b/test/Migrations.t.sol @@ -1,4 +1,4 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; diff --git a/test/StateReceiver.t.sol b/test/StateReceiver.t.sol index b78d7d2..d8cb2a8 100644 --- a/test/StateReceiver.t.sol +++ b/test/StateReceiver.t.sol @@ -1,129 +1,306 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; import "./helpers/IStateReceiver.sol"; +import {TestReenterer} from "test/helpers/TestReenterer.sol"; +import {TestRevertingReceiver} from "test/helpers/TestRevertingReceiver.sol"; contract StateReceiverTest is Test { - address public constant SYSTEM_ADDRESS = 0xffffFFFfFFffffffffffffffFfFFFfffFFFfFFfE; - uint8 constant LIST_SHORT_START = 0xc0; + address public constant SYSTEM_ADDRESS = 0xffffFFFfFFffffffffffffffFfFFFfffFFFfFFfE; + uint8 constant LIST_SHORT_START = 0xc0; - IStateReceiver internal stateReceiver; + IStateReceiver internal stateReceiver = IStateReceiver(0x0000000000000000000000000000000000001001); + address internal rootSetter = makeAddr("rootSetter"); - function setUp() public { - stateReceiver = IStateReceiver(deployCode("out/StateReceiver.sol/StateReceiver.json")); - } + TestReenterer internal reenterer = new TestReenterer(); + TestRevertingReceiver internal revertingReceiver = new TestRevertingReceiver(); - function testRevert_commitState_OnlySystem() public { - vm.expectRevert("Not System Addess!"); - stateReceiver.commitState(0, ""); - } + function setUp() public { + address tmp = deployCode("out/StateReceiver.sol/StateReceiver.json", abi.encode(rootSetter)); + vm.etch(address(stateReceiver), tmp.code); + vm.label(address(stateReceiver), "stateReceiver"); + } - function testRevert_commitState_StateIdsAreNotSequential() public { - bytes memory recordBytes = _encodeRecord(2, address(0), ""); - vm.expectRevert("StateIds are not sequential"); - vm.prank(SYSTEM_ADDRESS); - stateReceiver.commitState(0, recordBytes); - } + function test_deployment() public view { + assertEq(stateReceiver.rootSetter(), rootSetter); + } - function test_commitState_ReceiverNotContract() public { - uint256 stateId = 1; - address receiver = makeAddr("receiver"); - bytes memory recordBytes = _encodeRecord(stateId, receiver, ""); + function testRevert_commitState_OnlySystem() public { + vm.expectRevert("Not System Addess!"); + stateReceiver.commitState(0, ""); + } - vm.prank(SYSTEM_ADDRESS); - assertFalse(stateReceiver.commitState(0, recordBytes)); - assertEq(stateReceiver.lastStateId(), 1); - } + function testRevert_commitState_StateIdsAreNotSequential() public { + bytes memory recordBytes = _encodeRecord(2, address(0), ""); + vm.expectRevert("StateIds are not sequential"); + vm.prank(SYSTEM_ADDRESS); + stateReceiver.commitState(0, recordBytes); + } - function test_commitState_ReceiverReverts() public { - uint256 stateId = 1; - address receiver = makeAddr("receiver"); - vm.etch(receiver, "00"); - vm.mockCallRevert(receiver, "", ""); - bytes memory stateData = "State data"; - bytes memory recordBytes = _encodeRecord(stateId, receiver, stateData); - - vm.expectCallMinGas(receiver, 0, 5_000_000, abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, stateData)); - vm.prank(SYSTEM_ADDRESS); - vm.expectEmit(); - emit StateCommitted(stateId, false); - assertFalse(stateReceiver.commitState(0, recordBytes)); - assertEq(stateReceiver.lastStateId(), 1); - } + function test_commitState_ReceiverNotContract() public { + uint256 stateId = 1; + address receiver = makeAddr("receiver"); + bytes memory recordBytes = _encodeRecord(stateId, receiver, ""); + + vm.prank(SYSTEM_ADDRESS); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.lastStateId(), 1); + } + + function test_commitState_ReceiverReverts() public { + uint256 stateId = 1; + address receiver = makeAddr("receiver"); + vm.etch(receiver, "00"); + vm.mockCallRevert(receiver, "", ""); + bytes memory stateData = "State data"; + bytes memory recordBytes = _encodeRecord(stateId, receiver, stateData); + + vm.expectCallMinGas( + receiver, + 0, + 5_000_000, + abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, stateData) + ); + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.lastStateId(), 1); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(receiver, stateData)); + } + + function test_commitState_Success() public { + uint256 stateId = 1; + address receiver = makeAddr("receiver"); + vm.etch(receiver, "00"); + bytes memory stateData = "State data"; + bytes memory recordBytes = _encodeRecord(stateId, receiver, stateData); + + vm.expectCallMinGas( + receiver, + 0, + 5_000_000, + abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, stateData) + ); + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, true); + assertTrue(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.lastStateId(), 1); + } + + function testRevert_ReplayFailedStateSync(uint256 stateId, bytes memory callData) public { + vm.assume(stateId > 0); + vm.store(address(stateReceiver), bytes32(0), bytes32(stateId - 1)); + assertTrue(revertingReceiver.shouldIRevert()); + bytes memory recordBytes = _encodeRecord(stateId, address(revertingReceiver), callData); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(address(revertingReceiver), callData)); + + assertTrue(revertingReceiver.shouldIRevert()); + + vm.expectRevert("TestRevertingReceiver"); + stateReceiver.replayFailedStateSync(stateId); + } + + function test_ReplayFailedStateSync(uint256 stateId, bytes memory callData) public { + vm.assume(stateId > 0); + vm.store(address(stateReceiver), bytes32(0), bytes32(stateId - 1)); + assertTrue(revertingReceiver.shouldIRevert()); + bytes memory recordBytes = _encodeRecord(stateId, address(revertingReceiver), callData); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(address(revertingReceiver), callData)); + + revertingReceiver.toggle(); + assertFalse(revertingReceiver.shouldIRevert()); + + vm.expectEmit(); + emit StateSyncReplay(stateId); + vm.expectCall( + address(revertingReceiver), + 0, + abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, callData) + ); + stateReceiver.replayFailedStateSync(stateId); + + vm.expectRevert("!found"); + stateReceiver.replayFailedStateSync(stateId); + } - function test_commitState_Success() public { - uint256 stateId = 1; - address receiver = makeAddr("receiver"); - vm.etch(receiver, "00"); - bytes memory stateData = "State data"; - bytes memory recordBytes = _encodeRecord(stateId, receiver, stateData); - - vm.expectCallMinGas(receiver, 0, 5_000_000, abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, stateData)); - vm.prank(SYSTEM_ADDRESS); - vm.expectEmit(); - emit StateCommitted(stateId, true); - assertTrue(stateReceiver.commitState(0, recordBytes)); - assertEq(stateReceiver.lastStateId(), 1); + function test_ReplayFailFromReenterer(uint256 stateId, bytes memory callData) public { + vm.assume(stateId > 0); + vm.store(address(stateReceiver), bytes32(0), bytes32(stateId - 1)); + bytes memory recordBytes = _encodeRecord(stateId, address(reenterer), callData); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(stateId, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); + assertEq(stateReceiver.failedStateSyncs(stateId), abi.encode(address(reenterer), callData)); + + revertingReceiver.toggle(); + assertFalse(revertingReceiver.shouldIRevert()); + + vm.expectCall(address(reenterer), 0, abi.encodeWithSignature("onStateReceive(uint256,bytes)", stateId, callData)); + vm.expectRevert("!found"); + stateReceiver.replayFailedStateSync(stateId); + } + + function test_rootSetter(address random) public { + vm.prank(random); + if (random != rootSetter) vm.expectRevert("!rootSetter"); + stateReceiver.setRootAndLeafCount(bytes32(uint(0x1337)), 0); + } + + function test_shouldNotReplayZeroLeaf(bytes32 root, bytes32[16] memory proof) public { + vm.prank(rootSetter); + stateReceiver.setRootAndLeafCount(root, 1); + + vm.expectRevert(bytes("used")); + stateReceiver.replayHistoricFailedStateSync(proof, 0, 0, address(0), new bytes(0)); + } + + function test_shouldNotReplayInvalidProof(bytes32 root, bytes32[16] memory proof, bytes memory stateData) public { + vm.prank(rootSetter); + stateReceiver.setRootAndLeafCount(root, 1); + + vm.expectRevert("!proof"); + stateReceiver.replayHistoricFailedStateSync( + proof, + vm.randomUint(0, 2 ** 16), + vm.randomUint(), + vm.randomAddress(), + stateData + ); + } + + function test_FailedStateSyncs(bytes[] memory stateDatas) external { + vm.assume(stateDatas.length > 1 && stateDatas.length < 10); + + address receiver = address(revertingReceiver); + + for (uint256 i = 0; i < stateDatas.length; ++i) { + bytes memory recordBytes = _encodeRecord(i + 1, receiver, stateDatas[i]); + + vm.prank(SYSTEM_ADDRESS); + vm.expectEmit(); + emit StateCommitted(i + 1, false); + assertFalse(stateReceiver.commitState(0, recordBytes)); } - function _encodeRecord(uint256 stateId, address receiver, bytes memory stateData) public returns (bytes memory recordBytes) { - return abi.encodePacked(LIST_SHORT_START, _rlpEncodeUint(stateId), _rlpEncodeAddress(receiver), _rlpEncodeBytes(stateData)); + uint256 leafCount = stateDatas.length; + bytes32 root; + bytes[] memory proofs = new bytes[](leafCount); + (root, proofs) = _getRootAndProofs(receiver, abi.encode(stateDatas)); + + vm.prank(rootSetter); + stateReceiver.setRootAndLeafCount(root, leafCount); + + revertingReceiver.toggle(); + + for (uint256 i = 0; i < leafCount; ++i) { + vm.expectCall(receiver, 0, abi.encodeWithSignature("onStateReceive(uint256,bytes)", i + 1, stateDatas[i])); + vm.expectEmit(); + emit StateSyncReplay(i + 1); + stateReceiver.replayHistoricFailedStateSync( + abi.decode(proofs[i], (bytes32[16])), + i, + i + 1, + receiver, + stateDatas[i] + ); } + } + + function _getRootAndProofs( + address receiver, + bytes memory stateDatasEncoded + ) internal returns (bytes32 root, bytes[] memory proofs) { + string[] memory inputs = new string[](4); + inputs[0] = "node"; + inputs[1] = "test/helpers/merkle.js"; + inputs[2] = vm.toString(receiver); + inputs[3] = vm.toString(stateDatasEncoded); + + (root, proofs) = abi.decode(vm.ffi(inputs), (bytes32, bytes[])); + } + + function _encodeRecord( + uint256 stateId, + address receiver, + bytes memory stateData + ) public pure returns (bytes memory recordBytes) { + return + abi.encodePacked( + LIST_SHORT_START, + _rlpEncodeUint(stateId), + _rlpEncodeAddress(receiver), + _rlpEncodeBytes(stateData) + ); + } - function _rlpEncodeUint(uint256 value) internal pure returns (bytes memory) { - if (value == 0) { - return hex"80"; - } else if (value < 0x80) { - return abi.encodePacked(uint8(value)); - } else { - bytes memory result = new bytes(33); - uint256 length = 0; - while (value != 0) { - length++; - result[33 - length] = bytes1(uint8(value)); - value >>= 8; - } - bytes memory encoded = new bytes(length + 1); - encoded[0] = bytes1(uint8(0x80 + length)); - for (uint256 i = 0; i < length; i++) { - encoded[i + 1] = result[33 - length + i]; - } - return encoded; - } + function _rlpEncodeUint(uint256 value) internal pure returns (bytes memory) { + if (value == 0) { + return hex"80"; + } else if (value < 0x80) { + return abi.encodePacked(uint8(value)); + } else { + bytes memory result = new bytes(33); + uint256 length = 0; + while (value != 0) { + length++; + result[33 - length] = bytes1(uint8(value)); + value >>= 8; + } + bytes memory encoded = new bytes(length + 1); + encoded[0] = bytes1(uint8(0x80 + length)); + for (uint256 i = 0; i < length; i++) { + encoded[i + 1] = result[33 - length + i]; + } + return encoded; } + } - function _rlpEncodeAddress(address value) internal pure returns (bytes memory) { - bytes memory encoded = new bytes(21); - encoded[0] = bytes1(uint8(0x94)); - for (uint256 i = 0; i < 20; i++) { - encoded[i + 1] = bytes1(uint8(uint256(uint160(value)) >> (8 * (19 - i)))); - } - return encoded; + function _rlpEncodeAddress(address value) internal pure returns (bytes memory) { + bytes memory encoded = new bytes(21); + encoded[0] = bytes1(uint8(0x94)); + for (uint256 i = 0; i < 20; i++) { + encoded[i + 1] = bytes1(uint8(uint256(uint160(value)) >> (8 * (19 - i)))); } + return encoded; + } - function _rlpEncodeBytes(bytes memory value) internal pure returns (bytes memory) { - uint256 length = value.length; - if (length == 1 && uint8(value[0]) < 0x80) { - return value; - } else if (length <= 55) { - bytes memory encoded = new bytes(length + 1); - encoded[0] = bytes1(uint8(0x80 + length)); - for (uint256 i = 0; i < length; i++) { - encoded[i + 1] = value[i]; - } - return encoded; - } else { - bytes memory lengthEncoded = _rlpEncodeUint(length); - bytes memory encoded = new bytes(1 + lengthEncoded.length + length); - encoded[0] = bytes1(uint8(0xb7 + lengthEncoded.length)); - for (uint256 i = 0; i < lengthEncoded.length; i++) { - encoded[i + 1] = lengthEncoded[i]; - } - for (uint256 i = 0; i < length; i++) { - encoded[i + 1 + lengthEncoded.length] = value[i]; - } - return encoded; - } + function _rlpEncodeBytes(bytes memory value) internal pure returns (bytes memory) { + uint256 length = value.length; + if (length == 1 && uint8(value[0]) < 0x80) { + return value; + } else if (length <= 55) { + bytes memory encoded = new bytes(length + 1); + encoded[0] = bytes1(uint8(0x80 + length)); + for (uint256 i = 0; i < length; i++) { + encoded[i + 1] = value[i]; + } + return encoded; + } else { + bytes memory lengthEncoded = _rlpEncodeUint(length); + bytes memory encoded = new bytes(1 + lengthEncoded.length + length); + encoded[0] = bytes1(uint8(0xb7 + lengthEncoded.length)); + for (uint256 i = 0; i < lengthEncoded.length; i++) { + encoded[i + 1] = lengthEncoded[i]; + } + for (uint256 i = 0; i < length; i++) { + encoded[i + 1 + lengthEncoded.length] = value[i]; + } + return encoded; } + } } diff --git a/test/System.t.sol b/test/System.t.sol index 1ac00a9..95e6e98 100644 --- a/test/System.t.sol +++ b/test/System.t.sol @@ -1,4 +1,4 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; @@ -8,7 +8,7 @@ contract SystemTest is Test { ISystem internal system; function setUp() public { - system = ISystem(deployCode("out/System.sol/System.json")); + system = ISystem(deployCode("out/System.sol/System.0.6.12.json")); } function test_constants() public { diff --git a/test/ValidatorVerifier.t.sol b/test/ValidatorVerifier.t.sol index 67f2b06..9b117b1 100644 --- a/test/ValidatorVerifier.t.sol +++ b/test/ValidatorVerifier.t.sol @@ -1,4 +1,4 @@ -pragma solidity >0.5.0; +pragma solidity 0.8.26; import "../lib/forge-std/src/Test.sol"; diff --git a/test/helpers/IStateReceiver.sol b/test/helpers/IStateReceiver.sol index 4756457..dbdae1b 100644 --- a/test/helpers/IStateReceiver.sol +++ b/test/helpers/IStateReceiver.sol @@ -2,9 +2,27 @@ pragma solidity >0.5.0; event StateCommitted(uint256 indexed stateId, bool success); +event StateSyncReplay(uint256 indexed stateId); interface IStateReceiver { - function SYSTEM_ADDRESS() external view returns (address); - function commitState(uint256 syncTime, bytes memory recordBytes) external returns (bool success); - function lastStateId() external view returns (uint256); + function SYSTEM_ADDRESS() external view returns (address); + function commitState(uint256 syncTime, bytes memory recordBytes) external returns (bool success); + function lastStateId() external view returns (uint256); + function rootSetter() external view returns (address); + function failedStateSyncsRoot() external view returns (bytes32); + function nullifier(bytes32) external view returns (bool); + function failedStateSyncs(uint256) external view returns (bytes memory); + function leafCount() external view returns (uint256); + function replayCount() external view returns (uint256); + function TREE_DEPTH() external view returns (uint256); + + function replayFailedStateSync(uint256 stateId) external; + function setRootAndLeafCount(bytes32 _root, uint256 _leafCount) external; + function replayHistoricFailedStateSync( + bytes32[16] calldata proof, + uint256 leafIndex, + uint256 stateId, + address receiver, + bytes calldata data + ) external; } diff --git a/test/helpers/TestReenterer.sol b/test/helpers/TestReenterer.sol new file mode 100644 index 0000000..f19a351 --- /dev/null +++ b/test/helpers/TestReenterer.sol @@ -0,0 +1,17 @@ +pragma solidity 0.8.26; + +contract TestReenterer { + uint256 public reenterCount; + + function onStateReceive(uint256 id, bytes calldata _data) external { + if (reenterCount++ == 0) { + (bool success, bytes memory ret) = msg.sender.call(abi.encodeWithSignature("replayFailedStateSync(uint256)", id)); + // bubble up revert for tests + if (!success) { + assembly { + revert(add(ret, 0x20), mload(ret)) + } + } + } + } +} diff --git a/test/helpers/TestRevertingReceiver.sol b/test/helpers/TestRevertingReceiver.sol new file mode 100644 index 0000000..89fea93 --- /dev/null +++ b/test/helpers/TestRevertingReceiver.sol @@ -0,0 +1,12 @@ +pragma solidity 0.8.26; + +contract TestRevertingReceiver { + bool public shouldIRevert = true; + function onStateReceive(uint256 _id, bytes calldata _data) external { + if (shouldIRevert) revert("TestRevertingReceiver"); + } + + function toggle() external { + shouldIRevert = !shouldIRevert; + } +} diff --git a/test/helpers/merkle.js b/test/helpers/merkle.js new file mode 100644 index 0000000..6107e35 --- /dev/null +++ b/test/helpers/merkle.js @@ -0,0 +1,132 @@ +const AbiCoder = require('web3-eth-abi') +const { keccak256 } = require('web3-utils') + +const abi = AbiCoder + +class SparseMerkleTree { + constructor(height) { + if (height <= 1) { + throw new Error('invalid height, must be greater than 1') + } + this.height = height + this.zeroHashes = this.generateZeroHashes(height) + const tree = [] + for (let i = 0; i <= height; i++) { + tree.push([]) + } + this.tree = tree + this.leafCount = 0 + this.dirty = false + } + + add(leaf) { + this.dirty = true + this.leafCount++ + this.tree[0].push(leaf) + } + + calcBranches() { + for (let i = 0; i < this.height; i++) { + const parent = this.tree[i + 1] + const child = this.tree[i] + for (let j = 0; j < child.length; j += 2) { + const leftNode = child[j] + const rightNode = + j + 1 < child.length ? child[j + 1] : this.zeroHashes[i] + parent[j / 2] = keccak256( + abi.encodeParameters(['bytes32', 'bytes32'], [leftNode, rightNode]) + ) + } + } + this.dirty = false + } + + getProofTreeByIndex(index) { + if (this.dirty) this.calcBranches() + const proof = [] + let currentIndex = index + for (let i = 0; i < this.height; i++) { + currentIndex = + currentIndex % 2 === 1 ? currentIndex - 1 : currentIndex + 1 + if (currentIndex < this.tree[i].length) + proof.push(this.tree[i][currentIndex]) + else proof.push(this.zeroHashes[i]) + currentIndex = Math.floor(currentIndex / 2) + } + + return proof + } + + getProofTreeByValue(value) { + const index = this.tree[0].indexOf(value) + if (index === -1) throw new Error('value not found') + return this.getProofTreeByIndex(index) + } + + getRoot() { + if (this.tree[0][0] === undefined) { + // No leafs in the tree, calculate root with all leafs to 0 + return keccak256( + abi.encodeParameters( + ['bytes32', 'bytes32'], + [this.zeroHashes[this.height - 1], this.zeroHashes[this.height - 1]] + ) + ) + } + if (this.dirty) this.calcBranches() + + return this.tree[this.height][0] + } + + generateZeroHashes(height) { + // keccak256(abi.encode(uint256(0), address(0), new bytes(0))); + const zeroHashes = [ + keccak256( + abi.encodeParameters( + ['uint256', 'address', 'bytes'], + [0, '0x' + '0'.repeat(40), '0x'] + ) + ) + ] + for (let i = 1; i < height; i++) { + zeroHashes.push( + keccak256( + abi.encodeParameters( + ['bytes32', 'bytes32'], + [zeroHashes[i - 1], zeroHashes[i - 1]] + ) + ) + ) + } + + return zeroHashes + } +} + +function getLeaf(stateID, receiverAddress, stateData) { + return keccak256( + abi.encodeParameters( + ['uint256', 'address', 'bytes'], + [stateID, receiverAddress, stateData] + ) + ) +} + +const [receiver, stateDatasEncoded] = process.argv.slice(2) + +const stateDatas = abi.decodeParameter('bytes[]', stateDatasEncoded) + +const tree = new SparseMerkleTree(16) + +for (let i = 0; i < stateDatas.length; i++) { + tree.add(getLeaf(i + 1, receiver, stateDatas[i])) +} +const root = tree.getRoot() +const proofs = stateDatas.map((_, i) => tree.getProofTreeByIndex(i)) + +console.log( + abi.encodeParameters( + ['bytes32', 'bytes[]'], + [root, proofs.map((proof) => abi.encodeParameter('bytes32[16]', proof))] + ) +)