From 709decfab916b9338919bf0550b935afdea13287 Mon Sep 17 00:00:00 2001 From: Artem Chystiakov <47551140+Arvolear@users.noreply.github.com> Date: Sat, 6 Apr 2024 11:31:40 +0300 Subject: [PATCH] Support of `update` and `remove` operations in SMT (#96) * the meat * fixed tests * fix remove + tests * typo * added idempotence test * fix test * optimization * Updated SparseMerkleTree doc --------- Co-authored-by: Kyryl Riabov --- .../libs/data-structures/SparseMerkleTree.sol | 379 +++++++++++++----- .../data-structures/SparseMerkleTreeMock.sol | 40 +- package-lock.json | 22 +- package.json | 6 +- .../data-structures/SparseMerkleTree.test.ts | 319 +++++++++++++-- 5 files changed, 625 insertions(+), 141 deletions(-) diff --git a/contracts/libs/data-structures/SparseMerkleTree.sol b/contracts/libs/data-structures/SparseMerkleTree.sol index 75c886b1..d5acf940 100644 --- a/contracts/libs/data-structures/SparseMerkleTree.sol +++ b/contracts/libs/data-structures/SparseMerkleTree.sol @@ -12,27 +12,33 @@ pragma solidity ^0.8.4; * in using different types of keys and values. * * The main differences from the original implementation include: + * - Added the ability to remove or update nodes in the tree. * - Optimized storage usage to reduce the number of storage slots. * - Added the ability to set custom hash functions. * - Removed methods and associated storage for managing the tree root's history. * - * Gas usage for adding (addBytes32) 16,001 leaves to a tree of size 80 is detailed below: + * Gas usage for adding (addUint) 20,000 leaves to a tree of size 80 "based" on the Poseidon Hash function is detailed below: * - * | Statistic | Value | + * | Statistic | Add | * |-----------|-------------- | - * | Count | 16,001 | - * | Mean | 1,444,220 gas | - * | Std Dev | 209,147.6 gas | - * | Min | 177,853 gas | - * | 25% | 1,317,555 gas | - * | 50% | 1,461,562 gas | - * | 75% | 1,554,030 gas | - * | Max | 2,723,812 gas | + * | Count | 20,000 | + * | Mean | 890,446 gas | + * | Std Dev | 147,775 gas | + * | Min | 177,797 gas | + * | 25% | 784,961 gas | + * | 50% | 866,482 gas | + * | 75% | 959,075 gas | + * | Max | 1,937,554 gas | * * The gas cost increases linearly with the depth of the leaves added. This growth can be approximated by the following formula: - * Linear regression formula: y = 92,457x + 255,689 + * Linear regression formula: y = 46,377x + 215,088 * - * This implies that adding an element at depth 80 would approximately cost 7.5M gas. + * This implies that adding an element at depth 80 would approximately cost 3.93M gas. + * + * On the other hand, the growth of the gas cost for removing leaves can be approximated by the following formula: + * Linear regression formula: y = 44840*x + 88821 + * + * This implies that removing an element at depth 80 would approximately cost 3.68M gas. * * ## Usage Example: * @@ -50,6 +56,8 @@ pragma solidity ^0.8.4; * SparseMerkleTree.Proof memory proof = uintTree.getProof(100); * * uintTree.getNodeByKey(100); + * + * uintTree.remove(100); * ``` */ library SparseMerkleTree { @@ -74,7 +82,7 @@ library SparseMerkleTree { * @param tree self. * @param maxDepth_ The max depth of the Merkle tree. */ - function initialize(UintSMT storage tree, uint64 maxDepth_) internal { + function initialize(UintSMT storage tree, uint32 maxDepth_) internal { _initialize(tree._tree, maxDepth_); } @@ -89,7 +97,7 @@ library SparseMerkleTree { * @param tree self. * @param maxDepth_ The max depth of the Merkle tree. */ - function setMaxDepth(UintSMT storage tree, uint64 maxDepth_) internal { + function setMaxDepth(UintSMT storage tree, uint32 maxDepth_) internal { _setMaxDepth(tree._tree, maxDepth_); } @@ -119,10 +127,33 @@ library SparseMerkleTree { * @param key_ The key of the element. * @param value_ The value of the element. */ - function add(UintSMT storage tree, uint256 key_, uint256 value_) internal { + function add(UintSMT storage tree, bytes32 key_, uint256 value_) internal { _add(tree._tree, bytes32(key_), bytes32(value_)); } + /** + * @notice The function to remove a (leaf) element from the uint256 tree. + * Complexity is O(log(n)), where n is the max depth of the tree. + * + * @param tree self. + * @param key_ The key of the element. + */ + function remove(UintSMT storage tree, bytes32 key_) internal { + _remove(tree._tree, key_); + } + + /** + * @notice The function to update a (leaf) element in the uint256 tree. + * Complexity is O(log(n)), where n is the max depth of the tree. + * + * @param tree self. + * @param key_ The key of the element. + * @param newValue_ The new value of the element. + */ + function update(UintSMT storage tree, bytes32 key_, uint256 newValue_) internal { + _update(tree._tree, key_, bytes32(newValue_)); + } + /** * @notice The function to get the proof if a node with specific key exists or not exists in the SMT. * Complexity is O(log(n)), where n is the max depth of the tree. @@ -131,7 +162,7 @@ library SparseMerkleTree { * @param key_ The key of the element. * @return SMT proof struct. */ - function getProof(UintSMT storage tree, uint256 key_) internal view returns (Proof memory) { + function getProof(UintSMT storage tree, bytes32 key_) internal view returns (Proof memory) { return _proof(tree._tree, bytes32(key_)); } @@ -221,7 +252,7 @@ library SparseMerkleTree { * @param tree self. * @param maxDepth_ The max depth of the Merkle tree. */ - function initialize(Bytes32SMT storage tree, uint64 maxDepth_) internal { + function initialize(Bytes32SMT storage tree, uint32 maxDepth_) internal { _initialize(tree._tree, maxDepth_); } @@ -236,7 +267,7 @@ library SparseMerkleTree { * @param tree self. * @param maxDepth_ The max depth of the Merkle tree. */ - function setMaxDepth(Bytes32SMT storage tree, uint64 maxDepth_) internal { + function setMaxDepth(Bytes32SMT storage tree, uint32 maxDepth_) internal { _setMaxDepth(tree._tree, maxDepth_); } @@ -270,6 +301,29 @@ library SparseMerkleTree { _add(tree._tree, key_, value_); } + /** + * @notice The function to remove a (leaf) element from the bytes32 tree. + * Complexity is O(log(n)), where n is the max depth of the tree. + * + * @param tree self. + * @param key_ The key of the element. + */ + function remove(Bytes32SMT storage tree, bytes32 key_) internal { + _remove(tree._tree, key_); + } + + /** + * @notice The function to update a (leaf) element in the bytes32 tree. + * Complexity is O(log(n)), where n is the max depth of the tree. + * + * @param tree self. + * @param key_ The key of the element. + * @param newValue_ The new value of the element. + */ + function update(Bytes32SMT storage tree, bytes32 key_, bytes32 newValue_) internal { + _update(tree._tree, key_, newValue_); + } + /** * @notice The function to get the proof if a node with specific key exists or not exists in the SMT. * Complexity is O(log(n)), where n is the max depth of the tree. @@ -374,7 +428,7 @@ library SparseMerkleTree { * @param tree self. * @param maxDepth_ The max depth of the Merkle tree. */ - function initialize(AddressSMT storage tree, uint64 maxDepth_) internal { + function initialize(AddressSMT storage tree, uint32 maxDepth_) internal { _initialize(tree._tree, maxDepth_); } @@ -389,7 +443,7 @@ library SparseMerkleTree { * @param tree self. * @param maxDepth_ The max depth of the Merkle tree. */ - function setMaxDepth(AddressSMT storage tree, uint64 maxDepth_) internal { + function setMaxDepth(AddressSMT storage tree, uint32 maxDepth_) internal { _setMaxDepth(tree._tree, maxDepth_); } @@ -423,6 +477,29 @@ library SparseMerkleTree { _add(tree._tree, key_, bytes32(uint256(uint160(value_)))); } + /** + * @notice The function to remove a (leaf) element from the address tree. + * Complexity is O(log(n)), where n is the max depth of the tree. + * + * @param tree self. + * @param key_ The key of the element. + */ + function remove(AddressSMT storage tree, bytes32 key_) internal { + _remove(tree._tree, key_); + } + + /** + * @notice The function to update a (leaf) element in the address tree. + * Complexity is O(log(n)), where n is the max depth of the tree. + * + * @param tree self. + * @param key_ The key of the element. + * @param newValue_ The new value of the element. + */ + function update(AddressSMT storage tree, bytes32 key_, address newValue_) internal { + _update(tree._tree, key_, bytes32(uint256(uint160(newValue_)))); + } + /** * @notice The function to get the proof if a node with specific key exists or not exists in the SMT. * Complexity is O(log(n)), where n is the max depth of the tree. @@ -548,8 +625,9 @@ library SparseMerkleTree { struct SMT { mapping(uint256 => Node) nodes; uint64 merkleRootId; - uint64 maxDepth; uint64 nodesCount; + uint64 deletedNodesCount; + uint32 maxDepth; bool isCustomHasherSet; function(bytes32, bytes32) view returns (bytes32) hash2; function(bytes32, bytes32, bytes32) view returns (bytes32) hash3; @@ -605,13 +683,13 @@ library SparseMerkleTree { _; } - function _initialize(SMT storage tree, uint64 maxDepth_) private { + function _initialize(SMT storage tree, uint32 maxDepth_) private { require(!_isInitialized(tree), "SparseMerkleTree: tree is already initialized"); _setMaxDepth(tree, maxDepth_); } - function _setMaxDepth(SMT storage tree, uint64 maxDepth_) private { + function _setMaxDepth(SMT storage tree, uint32 maxDepth_) private { require(maxDepth_ > 0, "SparseMerkleTree: max depth must be greater than zero"); require(maxDepth_ > tree.maxDepth, "SparseMerkleTree: max depth can only be increased"); require( @@ -648,56 +726,25 @@ library SparseMerkleTree { tree.merkleRootId = uint64(_add(tree, node_, tree.merkleRootId, 0)); } - function _proof(SMT storage tree, bytes32 key_) private view returns (Proof memory) { - uint256 maxDepth_ = _maxDepth(tree); + function _remove(SMT storage tree, bytes32 key_) private onlyInitialized(tree) { + tree.merkleRootId = uint64(_remove(tree, key_, tree.merkleRootId, 0)); + } - Proof memory proof_ = Proof({ - root: _root(tree), - siblings: new bytes32[](maxDepth_), - existence: false, + function _update( + SMT storage tree, + bytes32 key_, + bytes32 newValue_ + ) private onlyInitialized(tree) { + Node memory node_ = Node({ + nodeType: NodeType.LEAF, + childLeft: ZERO_IDX, + childRight: ZERO_IDX, + nodeHash: ZERO_HASH, key: key_, - value: ZERO_HASH, - auxExistence: false, - auxKey: ZERO_HASH, - auxValue: ZERO_HASH + value: newValue_ }); - Node memory node_; - uint256 nextNodeId_ = tree.merkleRootId; - - for (uint256 i = 0; i <= maxDepth_; i++) { - node_ = _node(tree, nextNodeId_); - - if (node_.nodeType == NodeType.EMPTY) { - break; - } else if (node_.nodeType == NodeType.LEAF) { - if (node_.key == proof_.key) { - proof_.existence = true; - proof_.value = node_.value; - - break; - } else { - proof_.auxExistence = true; - proof_.auxKey = node_.key; - proof_.auxValue = node_.value; - proof_.value = node_.value; - - break; - } - } else { - if ((uint256(proof_.key) >> i) & 1 == 1) { - nextNodeId_ = node_.childRight; - - proof_.siblings[i] = tree.nodes[node_.childLeft].nodeHash; - } else { - nextNodeId_ = node_.childLeft; - - proof_.siblings[i] = tree.nodes[node_.childRight].nodeHash; - } - } - } - - return proof_; + _update(tree, node_, tree.merkleRootId, 0); } /** @@ -713,48 +760,128 @@ library SparseMerkleTree { uint16 currentDepth_ ) private returns (uint256) { Node memory currentNode_ = tree.nodes[nodeId_]; - uint256 leafId_; if (currentNode_.nodeType == NodeType.EMPTY) { - leafId_ = _setNode(tree, newLeaf_); + return _setNode(tree, newLeaf_); } else if (currentNode_.nodeType == NodeType.LEAF) { if (currentNode_.key == newLeaf_.key) { revert("SparseMerkleTree: the key already exists"); } - leafId_ = _pushLeaf(tree, newLeaf_, currentNode_, nodeId_, currentDepth_); + return _pushLeaf(tree, newLeaf_, currentNode_, nodeId_, currentDepth_); } else { - Node memory newNodeMiddle_; uint256 nextNodeId_; if ((uint256(newLeaf_.key) >> currentDepth_) & 1 == 1) { nextNodeId_ = _add(tree, newLeaf_, currentNode_.childRight, currentDepth_ + 1); - newNodeMiddle_ = Node({ - nodeType: NodeType.MIDDLE, - childLeft: currentNode_.childLeft, - childRight: uint64(nextNodeId_), - nodeHash: ZERO_HASH, - key: ZERO_HASH, - value: ZERO_HASH - }); + tree.nodes[nodeId_].childRight = uint64(nextNodeId_); } else { nextNodeId_ = _add(tree, newLeaf_, currentNode_.childLeft, currentDepth_ + 1); - newNodeMiddle_ = Node({ - nodeType: NodeType.MIDDLE, - childLeft: uint64(nextNodeId_), - childRight: currentNode_.childRight, - nodeHash: ZERO_HASH, - key: ZERO_HASH, - value: ZERO_HASH - }); + tree.nodes[nodeId_].childLeft = uint64(nextNodeId_); + } + + tree.nodes[nodeId_].nodeHash = _getNodeHash(tree, tree.nodes[nodeId_]); + + return nodeId_; + } + } + + function _remove( + SMT storage tree, + bytes32 key_, + uint256 nodeId_, + uint16 currentDepth_ + ) private returns (uint256) { + Node memory currentNode_ = tree.nodes[nodeId_]; + + if (currentNode_.nodeType == NodeType.EMPTY) { + revert("SparseMerkleTree: the node does not exist"); + } else if (currentNode_.nodeType == NodeType.LEAF) { + if (currentNode_.key != key_) { + revert("SparseMerkleTree: the leaf does not match"); + } + + _deleteNode(tree, nodeId_); + + return ZERO_IDX; + } else { + uint256 nextNodeId_; + + if ((uint256(key_) >> currentDepth_) & 1 == 1) { + nextNodeId_ = _remove(tree, key_, currentNode_.childRight, currentDepth_ + 1); + } else { + nextNodeId_ = _remove(tree, key_, currentNode_.childLeft, currentDepth_ + 1); + } + + NodeType rightType_ = tree.nodes[currentNode_.childRight].nodeType; + NodeType leftType_ = tree.nodes[currentNode_.childLeft].nodeType; + + if (rightType_ == NodeType.EMPTY && leftType_ == NodeType.EMPTY) { + _deleteNode(tree, nodeId_); + + return nextNodeId_; } - leafId_ = _setNode(tree, newNodeMiddle_); + NodeType nextType_ = tree.nodes[nextNodeId_].nodeType; + + if ( + (rightType_ == NodeType.EMPTY || leftType_ == NodeType.EMPTY) && + nextType_ != NodeType.MIDDLE + ) { + if ( + nextType_ == NodeType.EMPTY && + (leftType_ == NodeType.LEAF || rightType_ == NodeType.LEAF) + ) { + _deleteNode(tree, nodeId_); + + if (rightType_ == NodeType.LEAF) { + return currentNode_.childRight; + } + + return currentNode_.childLeft; + } + + if (rightType_ == NodeType.EMPTY) { + tree.nodes[nodeId_].childRight = uint64(nextNodeId_); + } else { + tree.nodes[nodeId_].childLeft = uint64(nextNodeId_); + } + } + + tree.nodes[nodeId_].nodeHash = _getNodeHash(tree, tree.nodes[nodeId_]); + + return nodeId_; + } + } + + function _update( + SMT storage tree, + Node memory newLeaf_, + uint256 nodeId_, + uint16 currentDepth_ + ) private { + Node memory currentNode_ = tree.nodes[nodeId_]; + + if (currentNode_.nodeType == NodeType.EMPTY) { + revert("SparseMerkleTree: the node does not exist"); + } else if (currentNode_.nodeType == NodeType.LEAF) { + if (currentNode_.key != newLeaf_.key) { + revert("SparseMerkleTree: the leaf does not match"); + } + + tree.nodes[nodeId_] = newLeaf_; + currentNode_ = newLeaf_; + } else { + if ((uint256(newLeaf_.key) >> currentDepth_) & 1 == 1) { + _update(tree, newLeaf_, currentNode_.childRight, currentDepth_ + 1); + } else { + _update(tree, newLeaf_, currentNode_.childLeft, currentDepth_ + 1); + } } - return leafId_; + tree.nodes[nodeId_].nodeHash = _getNodeHash(tree, currentNode_); } function _pushLeaf( @@ -831,15 +958,23 @@ library SparseMerkleTree { } /** - * @dev The function used to add only new nodes. + * @dev The function used to add new nodes. */ function _setNode(SMT storage tree, Node memory node_) private returns (uint256) { node_.nodeHash = _getNodeHash(tree, node_); - uint256 newSize_ = ++tree.nodesCount; - tree.nodes[newSize_] = node_; + uint256 newCount_ = ++tree.nodesCount; + tree.nodes[newCount_] = node_; - return newSize_; + return newCount_; + } + + /** + * @dev The function used to delete removed nodes. + */ + function _deleteNode(SMT storage tree, uint256 nodeId_) private { + delete tree.nodes[nodeId_]; + ++tree.deletedNodesCount; } /** @@ -861,6 +996,58 @@ library SparseMerkleTree { return hash2_(tree.nodes[node_.childLeft].nodeHash, tree.nodes[node_.childRight].nodeHash); } + function _proof(SMT storage tree, bytes32 key_) private view returns (Proof memory) { + uint256 maxDepth_ = _maxDepth(tree); + + Proof memory proof_ = Proof({ + root: _root(tree), + siblings: new bytes32[](maxDepth_), + existence: false, + key: key_, + value: ZERO_HASH, + auxExistence: false, + auxKey: ZERO_HASH, + auxValue: ZERO_HASH + }); + + Node memory node_; + uint256 nextNodeId_ = tree.merkleRootId; + + for (uint256 i = 0; i <= maxDepth_; i++) { + node_ = _node(tree, nextNodeId_); + + if (node_.nodeType == NodeType.EMPTY) { + break; + } else if (node_.nodeType == NodeType.LEAF) { + if (node_.key == proof_.key) { + proof_.existence = true; + proof_.value = node_.value; + + break; + } else { + proof_.auxExistence = true; + proof_.auxKey = node_.key; + proof_.auxValue = node_.value; + proof_.value = node_.value; + + break; + } + } else { + if ((uint256(proof_.key) >> i) & 1 == 1) { + nextNodeId_ = node_.childRight; + + proof_.siblings[i] = tree.nodes[node_.childLeft].nodeHash; + } else { + nextNodeId_ = node_.childLeft; + + proof_.siblings[i] = tree.nodes[node_.childRight].nodeHash; + } + } + } + + return proof_; + } + function _hash2(bytes32 a, bytes32 b) private pure returns (bytes32 result) { assembly { mstore(0, a) @@ -933,7 +1120,7 @@ library SparseMerkleTree { } function _nodesCount(SMT storage tree) private view returns (uint256) { - return tree.nodesCount; + return tree.nodesCount - tree.deletedNodesCount; } function _isInitialized(SMT storage tree) private view returns (bool) { diff --git a/contracts/mock/libs/data-structures/SparseMerkleTreeMock.sol b/contracts/mock/libs/data-structures/SparseMerkleTreeMock.sol index 274169f0..2e3ad90a 100644 --- a/contracts/mock/libs/data-structures/SparseMerkleTreeMock.sol +++ b/contracts/mock/libs/data-structures/SparseMerkleTreeMock.sol @@ -18,27 +18,27 @@ contract SparseMerkleTreeMock { SparseMerkleTree.Bytes32SMT internal _bytes32Tree; SparseMerkleTree.AddressSMT internal _addressTree; - function initializeUintTree(uint64 maxDepth_) external { + function initializeUintTree(uint32 maxDepth_) external { _uintTree.initialize(maxDepth_); } - function initializeBytes32Tree(uint64 maxDepth_) external { + function initializeBytes32Tree(uint32 maxDepth_) external { _bytes32Tree.initialize(maxDepth_); } - function initializeAddressTree(uint64 maxDepth_) external { + function initializeAddressTree(uint32 maxDepth_) external { _addressTree.initialize(maxDepth_); } - function setMaxDepthUintTree(uint64 maxDepth_) external { + function setMaxDepthUintTree(uint32 maxDepth_) external { _uintTree.setMaxDepth(maxDepth_); } - function setMaxDepthBytes32Tree(uint64 maxDepth_) external { + function setMaxDepthBytes32Tree(uint32 maxDepth_) external { _bytes32Tree.setMaxDepth(maxDepth_); } - function setMaxDepthAddressTree(uint64 maxDepth_) external { + function setMaxDepthAddressTree(uint32 maxDepth_) external { _addressTree.setMaxDepth(maxDepth_); } @@ -54,19 +54,43 @@ contract SparseMerkleTreeMock { _addressTree.setHashers(_hash2, _hash3); } - function addUint(uint256 key_, uint256 value_) external { + function addUint(bytes32 key_, uint256 value_) external { _uintTree.add(key_, value_); } + function removeUint(bytes32 key_) external { + _uintTree.remove(key_); + } + + function updateUint(bytes32 key_, uint256 newValue_) external { + _uintTree.update(key_, newValue_); + } + function addBytes32(bytes32 key_, bytes32 value_) external { _bytes32Tree.add(key_, value_); } + function removeBytes32(bytes32 key_) external { + _bytes32Tree.remove(key_); + } + + function updateBytes32(bytes32 key_, bytes32 newValue_) external { + _bytes32Tree.update(key_, newValue_); + } + function addAddress(bytes32 key_, address value_) external { _addressTree.add(key_, value_); } - function getUintProof(uint256 key_) external view returns (SparseMerkleTree.Proof memory) { + function removeAddress(bytes32 key_) external { + _addressTree.remove(key_); + } + + function updateAddress(bytes32 key_, address newValue_) external { + _addressTree.update(key_, newValue_); + } + + function getUintProof(bytes32 key_) external view returns (SparseMerkleTree.Proof memory) { return _uintTree.getProof(key_); } diff --git a/package-lock.json b/package-lock.json index ea302279..e4919f6c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@solarity/solidity-lib", - "version": "2.7.2", + "version": "2.7.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@solarity/solidity-lib", - "version": "2.7.2", + "version": "2.7.3", "license": "MIT", "dependencies": { "@openzeppelin/contracts": "4.9.5", @@ -17,8 +17,8 @@ "@uniswap/v3-periphery": "1.4.4" }, "devDependencies": { - "@iden3/js-crypto": "^1.0.3", - "@iden3/js-merkletree": "^1.1.2", + "@iden3/js-crypto": "^1.1.0", + "@iden3/js-merkletree": "^1.2.0", "@metamask/eth-sig-util": "^7.0.1", "@nomicfoundation/hardhat-chai-matchers": "^2.0.6", "@nomicfoundation/hardhat-ethers": "^3.0.5", @@ -906,18 +906,18 @@ } }, "node_modules/@iden3/js-crypto": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/@iden3/js-crypto/-/js-crypto-1.0.3.tgz", - "integrity": "sha512-IFBLIN1O26mM5MVWO8dlABDC6HKLuhYs+30BT+p6dGWsNXB4Rr5JWuhKBUbKlkW78ly3j3+YSoY+J63q7vPs5Q==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@iden3/js-crypto/-/js-crypto-1.1.0.tgz", + "integrity": "sha512-MbL7OpOxBoCybAPoorxrp+fwjDVESyDe6giIWxErjEIJy0Q2n1DU4VmKh4vDoCyhJx/RdVgT8Dkb59lKwISqsw==", "dev": true }, "node_modules/@iden3/js-merkletree": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@iden3/js-merkletree/-/js-merkletree-1.1.2.tgz", - "integrity": "sha512-NT0L+Nk6barcEnSV5q2M6LkZuR889E856e+awnok6iDlmzYMt2l3gulo//zMqGsO6wQvzVECaSn0LJQ7uM5c1A==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@iden3/js-merkletree/-/js-merkletree-1.2.0.tgz", + "integrity": "sha512-tM6jj1v/41qQ6V2K6CTrv0KsNHQ2y/O6Q9RSB1SdN2LTu+cgA9FnD2Qr3whzSvwgUs7X3SjuJgb9OTgs0lDemQ==", "dev": true, "peerDependencies": { - "@iden3/js-crypto": "1.0.3", + "@iden3/js-crypto": "1.1.0", "idb-keyval": "^6.2.0" } }, diff --git a/package.json b/package.json index 2ad29344..9c5a6db9 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@solarity/solidity-lib", - "version": "2.7.2", + "version": "2.7.3", "license": "MIT", "author": "Distributed Lab", "readme": "README.md", @@ -42,8 +42,8 @@ "@uniswap/v3-periphery": "1.4.4" }, "devDependencies": { - "@iden3/js-crypto": "^1.0.3", - "@iden3/js-merkletree": "^1.1.2", + "@iden3/js-crypto": "^1.1.0", + "@iden3/js-merkletree": "^1.2.0", "@metamask/eth-sig-util": "^7.0.1", "@nomicfoundation/hardhat-chai-matchers": "^2.0.6", "@nomicfoundation/hardhat-ethers": "^3.0.5", diff --git a/test/libs/data-structures/SparseMerkleTree.test.ts b/test/libs/data-structures/SparseMerkleTree.test.ts index 630c3fc5..6d5f2639 100644 --- a/test/libs/data-structures/SparseMerkleTree.test.ts +++ b/test/libs/data-structures/SparseMerkleTree.test.ts @@ -9,6 +9,7 @@ import { SparseMerkleTree } from "@/generated-types/ethers/contracts/mock/libs/d import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers"; import { Reverter } from "@/test/helpers/reverter"; +import { ZERO_BYTES32 } from "@/scripts/utils/constants"; import { getPoseidon, poseidonHash } from "@/test/helpers/poseidon-hash"; import "mock-local-storage"; @@ -118,7 +119,7 @@ describe("SparseMerkleTree", () => { expect(await merkleTree.getUintMaxDepth()).to.equal(21); }); - it("should revert if trying to call add function on non-initialized tree", async () => { + it("should revert if trying to call add/remove/update functions on non-initialized tree", async () => { const SparseMerkleTreeMock = await ethers.getContractFactory("SparseMerkleTreeMock", { libraries: { PoseidonUnit2L: await (await getPoseidon(2)).getAddress(), @@ -127,47 +128,55 @@ describe("SparseMerkleTree", () => { }); const newMerkleTree = await SparseMerkleTreeMock.deploy(); - await expect(newMerkleTree.addUint(1n, 1n)).to.be.rejectedWith("SparseMerkleTree: tree is not initialized"); + await expect(newMerkleTree.addUint(ethers.toBeHex(1n, 32), 1n)).to.be.rejectedWith( + "SparseMerkleTree: tree is not initialized", + ); + await expect(newMerkleTree.removeUint(ethers.toBeHex(1n, 32))).to.be.rejectedWith( + "SparseMerkleTree: tree is not initialized", + ); + await expect(newMerkleTree.updateUint(ethers.toBeHex(1n, 32), 1n)).to.be.rejectedWith( + "SparseMerkleTree: tree is not initialized", + ); }); it("should build a Merkle Tree of a predefined size with correct initial values", async () => { const value = 2341n; - const key = BigInt(poseidonHash(ethers.toBeHex(value))); + const key = poseidonHash(ethers.toBeHex(value)); expect(await merkleTree.getUintRoot()).to.equal(await getRoot(localMerkleTree)); await merkleTree.addUint(key, value); - await localMerkleTree.add(key, value); + await localMerkleTree.add(BigInt(key), value); expect(await merkleTree.getUintRoot()).to.equal(await getRoot(localMerkleTree)); expect(await merkleTree.getUintMaxDepth()).to.equal(20); expect(await merkleTree.getUintNodesCount()).to.equal(1); - await compareNodes(await merkleTree.getUintNode(1), key); - await compareNodes(await merkleTree.getUintNodeByKey(key), key); + await compareNodes(await merkleTree.getUintNode(1), BigInt(key)); + await compareNodes(await merkleTree.getUintNodeByKey(key), BigInt(key)); const onchainProof = getOnchainProof(await merkleTree.getUintProof(key)); - expect(await verifyProof(await localMerkleTree.root(), onchainProof, key, value)).to.be.true; + expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), value)).to.be.true; }); it("should build a Merkle Tree correctly with multiple elements", async () => { for (let i = 1n; i < 20n; i++) { const value = BigInt(ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32)); - const key = BigInt(poseidonHash(ethers.toBeHex(`0x` + value.toString(16), 32))); + const key = poseidonHash(ethers.toBeHex(`0x` + value.toString(16), 32)); await merkleTree.addUint(key, value); - await localMerkleTree.add(key, value); + await localMerkleTree.add(BigInt(key), value); expect(await merkleTree.getUintRoot()).to.equal(await getRoot(localMerkleTree)); - await compareNodes(await merkleTree.getUintNodeByKey(key), key); + await compareNodes(await merkleTree.getUintNodeByKey(key), BigInt(key)); const onchainProof = getOnchainProof(await merkleTree.getUintProof(key)); - expect(await verifyProof(await localMerkleTree.root(), onchainProof, key, value)).to.be.true; + expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), value)).to.be.true; } expect(await merkleTree.isUintCustomHasherSet()).to.be.true; @@ -175,16 +184,168 @@ describe("SparseMerkleTree", () => { await expect(merkleTree.setUintPoseidonHasher()).to.be.rejectedWith("SparseMerkleTree: tree is not empty"); }); + it("should add and full remove elements from Merkle Tree correctly", async () => { + const keys: string[] = []; + + for (let i = 1n; i < 20n; i++) { + const value = BigInt(ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32)); + const key = poseidonHash(ethers.toBeHex(`0x` + value.toString(16), 32)); + + await merkleTree.addUint(key, value); + + keys.push(key); + } + + for (let i = 1n; i < 20n; i++) { + const key = ethers.toBeHex(keys[Number(i) - 1], 32); + + await merkleTree.removeUint(key); + } + + expect(await merkleTree.getUintRoot()).to.equal(ZERO_BYTES32); + + expect(await merkleTree.getUintNodesCount()).to.equal(0); + + expect(await merkleTree.isUintCustomHasherSet()).to.be.true; + expect(merkleTree.setUintPoseidonHasher()).to.not.be.rejected; + }); + + it("should maintain idempotence", async () => { + const keys: string[] = []; + let proof; + + for (let i = 1n; i < 20n; i++) { + const value = BigInt(ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32)); + const key = poseidonHash(ethers.toBeHex(`0x` + value.toString(16), 32)); + + await merkleTree.addUint(key, value); + + if (i > 1n) { + await merkleTree.removeUint(key); + + const hexKey = ethers.toBeHex(keys[Number(i - 2n)], 32); + expect(await merkleTree.getUintProof(hexKey)).to.deep.equal(proof); + + await merkleTree.addUint(key, value); + } + + proof = await merkleTree.getUintProof(key); + + keys.push(key); + } + + for (let key of keys) { + const hexKey = ethers.toBeHex(key, 32); + const value = (await merkleTree.getUintNodeByKey(hexKey)).value; + + proof = await merkleTree.getUintProof(hexKey); + + await merkleTree.removeUint(hexKey); + await merkleTree.addUint(hexKey, value); + + expect(await merkleTree.getUintProof(hexKey)).to.deep.equal(proof); + } + }); + + it("should rebalance elements in Merkle Tree correctly", async () => { + const expectedRoot = "0x2f9bbaa7ab83da6e8d1d8dd05bac16e65fa40b4f6455c1d2ee77e968dfc382dc"; + const keys = [7n, 1n, 5n]; + + for (let key of keys) { + const hexKey = ethers.toBeHex(key, 32); + + await merkleTree.addUint(hexKey, key); + } + + const oldRoot = await merkleTree.getUintRoot(); + + expect(oldRoot).to.equal(expectedRoot); + expect(await merkleTree.getUintNodesCount()).to.equal(6); + + for (let key of keys) { + const hexKey = ethers.toBeHex(key, 32); + + await merkleTree.removeUint(hexKey); + await merkleTree.addUint(hexKey, key); + } + + expect(await merkleTree.getUintRoot()).to.equal(oldRoot); + expect(await merkleTree.getUintNodesCount()).to.equal(6); + }); + + it("should not remove non-existant leaves", async () => { + const keys = [7n, 1n, 5n]; + + for (let key of keys) { + const hexKey = ethers.toBeHex(key, 32); + + await merkleTree.addUint(hexKey, key); + } + + expect(merkleTree.removeUint(ethers.toBeHex(8, 32))).to.be.revertedWith( + "SparseMerkleTree: the node does not exist", + ); + expect(merkleTree.removeUint(ethers.toBeHex(9, 32))).to.be.revertedWith( + "SparseMerkleTree: the leaf does not match", + ); + }); + + it("should update existing leaves", async () => { + const keys: string[] = []; + + for (let i = 1n; i < 20n; i++) { + const value = BigInt(ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32)); + const key = poseidonHash(ethers.toBeHex(`0x` + value.toString(16), 32)); + + await merkleTree.addUint(key, value); + await localMerkleTree.add(BigInt(key), value); + + keys.push(key); + } + + for (let i = 1n; i < 20n; i++) { + const key = ethers.toBeHex(keys[Number(i) - 1], 32); + const value = BigInt(ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32)); + + await merkleTree.updateUint(key, value); + await localMerkleTree.update(BigInt(key), value); + + expect(await merkleTree.getUintRoot()).to.equal(await getRoot(localMerkleTree)); + + await compareNodes(await merkleTree.getUintNodeByKey(key), BigInt(key)); + + const onchainProof = getOnchainProof(await merkleTree.getUintProof(key)); + expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), value)).to.be.true; + } + }); + + it("should not update non-existant leaves", async () => { + const keys = [7n, 1n, 5n]; + + for (let key of keys) { + const hexKey = ethers.toBeHex(key, 32); + + await merkleTree.addUint(hexKey, key); + } + + expect(merkleTree.updateUint(ethers.toBeHex(8, 32), 1n)).to.be.revertedWith( + "SparseMerkleTree: the node does not exist", + ); + expect(merkleTree.updateUint(ethers.toBeHex(9, 32), 1n)).to.be.revertedWith( + "SparseMerkleTree: the leaf does not match", + ); + }); + it("should generate empty proof on empty tree", async () => { - const onchainProof = getOnchainProof(await merkleTree.getUintProof(1n)); + const onchainProof = getOnchainProof(await merkleTree.getUintProof(ethers.toBeHex(1n, 32))); expect(onchainProof.allSiblings()).to.have.length(0); }); it("should generate an empty proof for but with aux fields", async () => { - await merkleTree.addUint(7n, 1n); + await merkleTree.addUint(ethers.toBeHex(7n, 32), 1n); - const onchainProof = await merkleTree.getUintProof(5n); + const onchainProof = await merkleTree.getUintProof(ethers.toBeHex(5n, 32)); expect(onchainProof.auxKey).to.equal(7n); expect(onchainProof.auxValue).to.equal(1n); @@ -196,19 +357,19 @@ describe("SparseMerkleTree", () => { await localMerkleTree.add(3n, 15n); // key -> 0b011 await localMerkleTree.add(7n, 15n); // key -> 0b111 - await merkleTree.addUint(3n, 15n); - await merkleTree.addUint(7n, 15n); + await merkleTree.addUint(ethers.toBeHex(3n, 32), 15n); + await merkleTree.addUint(ethers.toBeHex(7n, 32), 15n); - let onchainProof = getOnchainProof(await merkleTree.getUintProof(5n)); + let onchainProof = getOnchainProof(await merkleTree.getUintProof(ethers.toBeHex(5n, 32))); expect(await verifyProof(await localMerkleTree.root(), onchainProof, 5n, 0n)).to.be.true; - onchainProof = getOnchainProof(await merkleTree.getUintProof(15n)); + onchainProof = getOnchainProof(await merkleTree.getUintProof(ethers.toBeHex(15n, 32))); expect(await verifyProof(await localMerkleTree.root(), onchainProof, 15n, 15n)).to.be.true; }); it("should revert if trying to add a node with the same key", async () => { const value = 2341n; - const key = BigInt(poseidonHash(ethers.toBeHex(value))); + const key = poseidonHash(ethers.toBeHex(value)); await merkleTree.addUint(key, value); @@ -226,16 +387,18 @@ describe("SparseMerkleTree", () => { await newMerkleTree.initializeUintTree(1); - await newMerkleTree.addUint(1n, 1n); - await newMerkleTree.addUint(2n, 1n); + await newMerkleTree.addUint(ethers.toBeHex(1n, 32), 1n); + await newMerkleTree.addUint(ethers.toBeHex(2n, 32), 1n); - await expect(newMerkleTree.addUint(3n, 1n)).to.be.rejectedWith("SparseMerkleTree: max depth reached"); + await expect(newMerkleTree.addUint(ethers.toBeHex(3n, 32), 1n)).to.be.rejectedWith( + "SparseMerkleTree: max depth reached", + ); }); it("should get empty Node by non-existing key", async () => { expect((await merkleTree.getUintNodeByKey(1n)).nodeType).to.be.equal(0); - await merkleTree.addUint(7n, 1n); + await merkleTree.addUint(ethers.toBeHex(7n, 32), 1n); expect((await merkleTree.getUintNodeByKey(5n)).nodeType).to.be.equal(0); }); @@ -297,6 +460,61 @@ describe("SparseMerkleTree", () => { await expect(merkleTree.setBytes32PoseidonHasher()).to.be.rejectedWith("SparseMerkleTree: tree is not empty"); }); + + it("should add and full remove elements from Merkle Tree correctly", async () => { + const keys: string[] = []; + + for (let i = 1n; i < 20n; i++) { + const value = ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32); + const key = poseidonHash(value); + + await merkleTree.addBytes32(key, value); + + keys.push(key); + } + + for (let i = 1n; i < 20n; i++) { + const key = ethers.toBeHex(keys[Number(i) - 1], 32); + + await merkleTree.removeBytes32(key); + } + + expect(await merkleTree.getBytes32Root()).to.equal(ZERO_BYTES32); + + expect(await merkleTree.getBytes32NodesCount()).to.equal(0); + + expect(await merkleTree.isBytes32CustomHasherSet()).to.be.true; + expect(merkleTree.setBytes32PoseidonHasher()).to.not.be.rejected; + }); + + it("should update existing leaves", async () => { + const keys: string[] = []; + + for (let i = 1n; i < 20n; i++) { + const value = ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32); + const key = poseidonHash(value); + + await merkleTree.addBytes32(key, value); + await localMerkleTree.add(BigInt(key), BigInt(value)); + + keys.push(key); + } + + for (let i = 1n; i < 20n; i++) { + const key = ethers.toBeHex(keys[Number(i) - 1], 32); + const value = BigInt(ethers.toBeHex(ethers.hexlify(ethers.randomBytes(28)), 32)); + + await merkleTree.updateBytes32(key, ethers.toBeHex(value, 32)); + await localMerkleTree.update(BigInt(key), BigInt(value)); + + expect(await merkleTree.getBytes32Root()).to.equal(await getRoot(localMerkleTree)); + + await compareNodes(await merkleTree.getBytes32NodeByKey(key), BigInt(key)); + + const onchainProof = getOnchainProof(await merkleTree.getBytes32Proof(key)); + expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), BigInt(value))).to.be.true; + } + }); }); describe("Address SMT", () => { @@ -355,5 +573,60 @@ describe("SparseMerkleTree", () => { await expect(merkleTree.setAddressPoseidonHasher()).to.be.rejectedWith("SparseMerkleTree: tree is not empty"); }); + + it("should add and full remove elements from Merkle Tree correctly", async () => { + const keys: string[] = []; + + for (let i = 1n; i < 20n; i++) { + const value = ethers.toBeHex(BigInt(await USER1.getAddress()) + i); + const key = poseidonHash(value); + + await merkleTree.addAddress(key, value); + + keys.push(key); + } + + for (let i = 1n; i < 20n; i++) { + const key = ethers.toBeHex(keys[Number(i) - 1], 32); + + await merkleTree.removeAddress(key); + } + + expect(await merkleTree.getAddressRoot()).to.equal(ZERO_BYTES32); + + expect(await merkleTree.getAddressNodesCount()).to.equal(0); + + expect(await merkleTree.isAddressCustomHasherSet()).to.be.true; + expect(merkleTree.setAddressPoseidonHasher()).to.not.be.rejected; + }); + + it("should update existing leaves", async () => { + const keys: string[] = []; + + for (let i = 1n; i < 20n; i++) { + const value = ethers.toBeHex(BigInt(await USER1.getAddress()) + i); + const key = poseidonHash(value); + + await merkleTree.addAddress(key, value); + await localMerkleTree.add(BigInt(key), BigInt(value)); + + keys.push(key); + } + + for (let i = 1n; i < 20n; i++) { + const key = ethers.toBeHex(keys[Number(i) - 1], 32); + const value = ethers.toBeHex(ethers.hexlify(ethers.randomBytes(20))); + + await merkleTree.updateAddress(key, ethers.toBeHex(value)); + await localMerkleTree.update(BigInt(key), BigInt(value)); + + expect(await merkleTree.getAddressRoot()).to.equal(await getRoot(localMerkleTree)); + + await compareNodes(await merkleTree.getAddressNodeByKey(key), BigInt(key)); + + const onchainProof = getOnchainProof(await merkleTree.getAddressProof(key)); + expect(await verifyProof(await localMerkleTree.root(), onchainProof, BigInt(key), BigInt(value))).to.be.true; + } + }); }); });