Skip to content

Commit

Permalink
insert and remove optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
mllwchrry committed Jun 6, 2024
1 parent 33afd89 commit 7b90762
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 211 deletions.
168 changes: 75 additions & 93 deletions contracts/libs/data-structures/AvlTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ library AvlTree {
*/
function setComparator(
UintAVL storage tree,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) internal {
_setComparator(tree._tree, comparator_);
}
Expand Down Expand Up @@ -201,9 +201,8 @@ library AvlTree {
* @param key_ the key of the node to search for.
* @return True if the node exists, false otherwise.
*/
function search(UintAVL storage tree, uint256 key_) internal view returns (uint64) {
return
_search(tree._tree.tree, tree._tree.root, bytes32(key_), _getComparator(tree._tree));
function get(UintAVL storage tree, uint256 key_) internal view returns (bytes32) {
return _get(tree._tree, bytes32(key_));
}

/**
Expand All @@ -214,27 +213,25 @@ library AvlTree {
* @param key_ the key to get the value for.
* @return The value associated with the key.
*/
function getValue(UintAVL storage tree, uint256 key_) internal view returns (bool, bytes32) {
return _getValue(tree._tree, bytes32(key_));
function tryGet(UintAVL storage tree, uint256 key_) internal view returns (bool, bytes32) {
return _tryGet(tree._tree, bytes32(key_));
}

/**
* @notice The function to retrieve the size of the uint256 tree.
* @param tree self.
* @return The size of the tree.
*/
function treeSize(UintAVL storage tree) internal view returns (uint64) {
return uint64(_treeSize(tree._tree));
function size(UintAVL storage tree) internal view returns (uint64) {
return uint64(_size(tree._tree));
}

function beginTraversal(
UintAVL storage tree
) internal view returns (Traversal.Iterator memory) {
return _beginTraversal(tree._tree);
function first(UintAVL storage tree) internal view returns (Traversal.Iterator memory) {
return _first(tree._tree);
}

function endTraversal(UintAVL storage tree) internal view returns (Traversal.Iterator memory) {
return _endTraversal(tree._tree);
function last(UintAVL storage tree) internal view returns (Traversal.Iterator memory) {
return _last(tree._tree);
}

/**
Expand Down Expand Up @@ -263,7 +260,7 @@ library AvlTree {
*/
function setComparator(
Bytes32AVL storage tree,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) internal {
_setComparator(tree._tree, comparator_);
}
Expand Down Expand Up @@ -299,8 +296,8 @@ library AvlTree {
* @param key_ the key of the node to search for.
* @return True if the node exists, false otherwise.
*/
function search(Bytes32AVL storage tree, bytes32 key_) internal view returns (uint64) {
return _search(tree._tree.tree, tree._tree.root, key_, _getComparator(tree._tree));
function get(Bytes32AVL storage tree, bytes32 key_) internal view returns (bytes32) {
return _get(tree._tree, key_);
}

/**
Expand All @@ -311,32 +308,25 @@ library AvlTree {
* @param key_ the key to get the value for.
* @return The value associated with the key.
*/
function getValue(
Bytes32AVL storage tree,
bytes32 key_
) internal view returns (bool, bytes32) {
return _getValue(tree._tree, key_);
function tryGet(Bytes32AVL storage tree, bytes32 key_) internal view returns (bool, bytes32) {
return _tryGet(tree._tree, key_);
}

/**
* @notice The function to retrieve the size of the bytes32 tree.
* @param tree self.
* @return The size of the tree.
*/
function treeSize(Bytes32AVL storage tree) internal view returns (uint64) {
return uint64(_treeSize(tree._tree));
function size(Bytes32AVL storage tree) internal view returns (uint64) {
return uint64(_size(tree._tree));
}

function beginTraversal(
Bytes32AVL storage tree
) internal view returns (Traversal.Iterator memory) {
return _beginTraversal(tree._tree);
function first(Bytes32AVL storage tree) internal view returns (Traversal.Iterator memory) {
return _first(tree._tree);
}

function endTraversal(
Bytes32AVL storage tree
) internal view returns (Traversal.Iterator memory) {
return _endTraversal(tree._tree);
function last(Bytes32AVL storage tree) internal view returns (Traversal.Iterator memory) {
return _last(tree._tree);
}

/**
Expand Down Expand Up @@ -365,7 +355,7 @@ library AvlTree {
*/
function setComparator(
AddressAVL storage tree,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) internal {
_setComparator(tree._tree, comparator_);
}
Expand Down Expand Up @@ -401,14 +391,8 @@ library AvlTree {
* @param key_ the key of the node to search for.
* @return True if the node exists, false otherwise.
*/
function search(AddressAVL storage tree, address key_) internal view returns (uint64) {
return
_search(
tree._tree.tree,
tree._tree.root,
bytes32(uint256(uint160(key_))),
_getComparator(tree._tree)
);
function get(AddressAVL storage tree, address key_) internal view returns (bytes32) {
return _get(tree._tree, bytes32(uint256(uint160(key_))));
}

/**
Expand All @@ -419,32 +403,25 @@ library AvlTree {
* @param key_ the key to get the value for.
* @return The value associated with the key.
*/
function getValue(
AddressAVL storage tree,
address key_
) internal view returns (bool, bytes32) {
return _getValue(tree._tree, bytes32(uint256(uint160(key_))));
function tryGet(AddressAVL storage tree, address key_) internal view returns (bool, bytes32) {
return _tryGet(tree._tree, bytes32(uint256(uint160(key_))));
}

/**
* @notice The function to retrieve the size of the address tree.
* @param tree self.
* @return The size of the tree.
*/
function treeSize(AddressAVL storage tree) internal view returns (uint64) {
return uint64(_treeSize(tree._tree));
function size(AddressAVL storage tree) internal view returns (uint64) {
return uint64(_size(tree._tree));
}

function beginTraversal(
AddressAVL storage tree
) internal view returns (Traversal.Iterator memory) {
return _beginTraversal(tree._tree);
function first(AddressAVL storage tree) internal view returns (Traversal.Iterator memory) {
return _first(tree._tree);
}

function endTraversal(
AddressAVL storage tree
) internal view returns (Traversal.Iterator memory) {
return _endTraversal(tree._tree);
function last(AddressAVL storage tree) internal view returns (Traversal.Iterator memory) {
return _last(tree._tree);
}

/**
Expand Down Expand Up @@ -477,14 +454,14 @@ library AvlTree {
uint64 removedCount;
bool isCustomComparatorSet;
mapping(uint64 => Node) tree;
function(bytes32, bytes32) view returns (int8) comparator;
function(bytes32, bytes32) view returns (int256) comparator;
}

function _setComparator(
Tree storage tree,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) private {
require(_treeSize(tree) == 0, "AvlTree: the tree must be empty");
require(_size(tree) == 0, "AvlTree: the tree must be empty");

tree.isCustomComparatorSet = true;

Expand All @@ -493,10 +470,6 @@ library AvlTree {

function _insert(Tree storage tree, bytes32 key_, bytes32 value_) private {
require(key_ != 0, "AvlTree: key is not allowed to be 0");
require(
_search(tree.tree, tree.root, key_, _getComparator(tree)) == 0,
"AvlTree: the node already exists"
);

tree.totalCount++;

Expand All @@ -513,10 +486,6 @@ library AvlTree {

function _remove(Tree storage tree, bytes32 key_) private {
require(key_ != 0, "AvlTree: key is not allowed to be 0");
require(
_search(tree.tree, tree.root, key_, _getComparator(tree)) != 0,
"AvlTree: the node doesn't exist"
);

tree.root = _removeNode(tree.tree, tree.root, 0, bytes32(key_), _getComparator(tree));

Expand All @@ -530,8 +499,10 @@ library AvlTree {
uint64 parent_,
bytes32 key_,
bytes32 value_,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) private returns (uint64) {
int256 comparison_ = comparator_(key_, _tree[node_].key);

if (_tree[node_].key == 0) {
_tree[index_] = Node({
key: key_,
Expand All @@ -545,7 +516,7 @@ library AvlTree {
return index_;
}

if (comparator_(key_, _tree[node_].key) <= 0) {
if (comparison_ < 0) {
_tree[node_].left = _insertNode(
_tree,
index_,
Expand All @@ -555,6 +526,8 @@ library AvlTree {
value_,
comparator_
);
} else if (comparison_ == 0) {
revert("AvlTree: the node already exists");
} else {
_tree[node_].right = _insertNode(
_tree,
Expand All @@ -575,9 +548,11 @@ library AvlTree {
uint64 node_,
uint64 parent_,
bytes32 key_,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) private returns (uint64) {
int8 comparison_ = comparator_(key_, _tree[node_].key);
require(node_ != 0, "AvlTree: the node doesn't exist");

int256 comparison_ = comparator_(key_, _tree[node_].key);

if (comparison_ == 0) {
uint64 left_ = _tree[node_].left;
Expand Down Expand Up @@ -633,7 +608,7 @@ library AvlTree {
return _balance(_tree, node_);
}

function _rotateLeft(
function _rotateRight(
mapping(uint64 => Node) storage _tree,
uint64 node_
) private returns (uint64) {
Expand All @@ -657,7 +632,7 @@ library AvlTree {
return temp_;
}

function _rotateRight(
function _rotateLeft(
mapping(uint64 => Node) storage _tree,
uint64 node_
) private returns (uint64) {
Expand Down Expand Up @@ -692,16 +667,16 @@ library AvlTree {

if (_left.height > _right.height + 1) {
if (_tree[_left.right].height > _tree[_left.left].height) {
_tree[node_].left = _rotateRight(_tree, _tree[node_].left);
_tree[node_].left = _rotateLeft(_tree, _tree[node_].left);
}

return _rotateLeft(_tree, node_);
return _rotateRight(_tree, node_);
} else if (_right.height > _left.height + 1) {
if (_tree[_right.left].height > _tree[_right.right].height) {
_tree[node_].right = _rotateLeft(_tree, _tree[node_].right);
_tree[node_].right = _rotateRight(_tree, _tree[node_].right);
}

return _rotateRight(_tree, node_);
return _rotateLeft(_tree, node_);
}

return node_;
Expand All @@ -717,13 +692,13 @@ library AvlTree {
mapping(uint64 => Node) storage _tree,
uint64 node_,
bytes32 key_,
function(bytes32, bytes32) view returns (int8) comparator_
function(bytes32, bytes32) view returns (int256) comparator_
) private view returns (uint64) {
if (node_ == 0) {
return 0;
}

int8 comparison_ = comparator_(key_, _tree[node_].key);
int256 comparison_ = comparator_(key_, _tree[node_].key);

if (comparison_ == 0) {
return node_;
Expand All @@ -734,7 +709,15 @@ library AvlTree {
}
}

function _getValue(Tree storage tree, bytes32 key_) private view returns (bool, bytes32) {
function _get(Tree storage tree, bytes32 key_) private view returns (bytes32) {
uint64 index_ = _search(tree.tree, tree.root, key_, _getComparator(tree));

require(index_ != 0, "AvlTree: the node doesn't exist");

return tree.tree[index_].value;
}

function _tryGet(Tree storage tree, bytes32 key_) private view returns (bool, bytes32) {
uint64 index_ = _search(tree.tree, tree.root, key_, _getComparator(tree));

if (index_ == 0) {
Expand All @@ -744,46 +727,45 @@ library AvlTree {
return (true, tree.tree[index_].value);
}

function _treeSize(Tree storage tree) private view returns (uint256) {
function _size(Tree storage tree) private view returns (uint256) {
return tree.totalCount - tree.removedCount;
}

function _beginTraversal(Tree storage tree) private view returns (Traversal.Iterator memory) {
function _first(Tree storage tree) private view returns (Traversal.Iterator memory) {
uint256 treeMappingSlot_;
assembly {
treeMappingSlot_ := add(tree.slot, 1)
}

uint64 root_ = tree.root;

if (root_ == 0) {
return Traversal.Iterator({treeMappingSlot: treeMappingSlot_, currentNode: 0});
}

uint64 current_ = root_;
uint64 current_ = tree.root;
while (tree.tree[current_].left != 0) {
current_ = tree.tree[current_].left;
}

return Traversal.Iterator({treeMappingSlot: treeMappingSlot_, currentNode: current_});
}

function _endTraversal(Tree storage tree) private pure returns (Traversal.Iterator memory) {
function _last(Tree storage tree) private view returns (Traversal.Iterator memory) {
uint256 treeMappingSlot_;
assembly {
treeMappingSlot_ := add(tree.slot, 1)
}

return Traversal.Iterator({treeMappingSlot: treeMappingSlot_, currentNode: 0});
uint64 current_ = tree.root;
while (tree.tree[current_].right != 0) {
current_ = tree.tree[current_].right;
}

return Traversal.Iterator({treeMappingSlot: treeMappingSlot_, currentNode: current_});
}

function _getComparator(
Tree storage tree
) private view returns (function(bytes32, bytes32) view returns (int8)) {
) private view returns (function(bytes32, bytes32) view returns (int256)) {
return tree.isCustomComparatorSet ? tree.comparator : _defaultComparator;
}

function _defaultComparator(bytes32 key1_, bytes32 key2_) private pure returns (int8) {
function _defaultComparator(bytes32 key1_, bytes32 key2_) private pure returns (int256) {
if (key1_ < key2_) {
return -1;
}
Expand Down
Loading

0 comments on commit 7b90762

Please sign in to comment.