diff --git a/src/modules/RuleEngineOperation.sol b/src/modules/RuleEngineOperation.sol index c4c0040..80b7e20 100644 --- a/src/modules/RuleEngineOperation.sol +++ b/src/modules/RuleEngineOperation.sol @@ -21,22 +21,10 @@ abstract contract RuleEngineOperation is AccessControl, RuleInternal, IRuleEngin function setRulesOperation( address[] calldata rules_ ) public onlyRole(RULE_ENGINE_ROLE) { - if(rules_.length == 0){ - revert RuleEngine_ArrayIsEmpty(); - } - for (uint256 i = 0; i < rules_.length; ) { - if( address(rules_[i]) == address(0x0)){ - revert RuleEngine_RuleAddressZeroNotAllowed(); - } - if(_ruleIsPresent[rules_[i]]){ - revert RuleEngine_RuleAlreadyExists(); - } - _ruleIsPresent[rules_[i]] = true; - emit AddRule(rules_[i]); - unchecked { - ++i; - } + if(rules_.length > 0){ + clearRulesOperation(); } + _setRules(rules_); _rulesOperation = rules_; } @@ -45,8 +33,30 @@ abstract contract RuleEngineOperation is AccessControl, RuleInternal, IRuleEngin * */ function clearRulesOperation() public onlyRole(RULE_ENGINE_ROLE) { + uint256 rulesLength = _rulesOperation.length; + for(uint256 i = 0; i < rulesLength; ++i){ + _removeRuleOperation(_rulesOperation[i], i); + } + emit ClearRules(_rulesOperation); + // No longer useful + //_rulesOperation = new address[](0); + } + + /** + * @notice Clear all the rules of the array of rules + * + */ + function _clearRulesOperation() internal { + uint256 index; + // we remove the last element first since it is more optimized. + for(uint256 i = _rulesOperation.length; i > 0; --i){ + unchecked { + // don't underflow since i > 0 + index = i - 1; + } + _removeRuleOperation(_rulesOperation[index], index); + } emit ClearRules(_rulesOperation); - _rulesOperation = new address[](0); } /** @@ -54,9 +64,9 @@ abstract contract RuleEngineOperation is AccessControl, RuleInternal, IRuleEngin * Revert if one rule is a zero address or if the rule is already present * */ - function addRuleOperation(address rule_) public onlyRole(RULE_ENGINE_ROLE) { - RuleInternal.addRule( _rulesOperation, rule_); - emit AddRule(rule_); + function addRuleOperation(IRuleOperation rule_) public onlyRole(RULE_ENGINE_ROLE) { + RuleInternal._addRule( _rulesOperation, address(rule_)); + emit AddRule(address(rule_)); } /** @@ -73,7 +83,24 @@ abstract contract RuleEngineOperation is AccessControl, RuleInternal, IRuleEngin IRuleOperation rule_, uint256 index ) public onlyRole(RULE_ENGINE_ROLE) { - RuleInternal.removeRule(_rulesOperation, address(rule_), index); + _removeRuleOperation(address(rule_), index); + } + + /** + * @notice Remove a rule from the array of rules + * Revert if the rule found at the specified index does not match the rule in argument + * @param rule_ address of the target rule + * @param index the position inside the array of rule + * @dev To reduce the array size, the last rule is moved to the location occupied + * by the rule to remove + * + * + */ + function _removeRuleOperation( + address rule_, + uint256 index + ) internal { + RuleInternal._removeRule(_rulesOperation, rule_, index); emit RemoveRule(address(rule_)); } diff --git a/src/modules/RuleEngineValidation.sol b/src/modules/RuleEngineValidation.sol index ee73425..db0c4ec 100644 --- a/src/modules/RuleEngineValidation.sol +++ b/src/modules/RuleEngineValidation.sol @@ -11,7 +11,7 @@ import "../interfaces/IRuleValidation.sol"; */ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngineValidation { /// @dev Array of rules - address[] internal _rules; + address[] internal _rulesValidation; /** @@ -23,23 +23,11 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi function setRulesValidation( address[] calldata rules_ ) public override onlyRole(RULE_ENGINE_ROLE) { - if(rules_.length == 0){ - revert RuleEngine_ArrayIsEmpty(); + if(rules_.length > 0){ + _clearRulesValidation(); } - for (uint256 i = 0; i < rules_.length; ) { - if( address(rules_[i]) == address(0x0)){ - revert RuleEngine_RuleAddressZeroNotAllowed(); - } - if(_ruleIsPresent[rules_[i]]){ - revert RuleEngine_RuleAlreadyExists(); - } - _ruleIsPresent[rules_[i]] = true; - emit AddRule(rules_[i]); - unchecked { - ++i; - } - } - _rules = rules_; + _setRules(rules_); + _rulesValidation = rules_; } /** @@ -47,8 +35,24 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi * */ function clearRulesValidation() public onlyRole(RULE_ENGINE_ROLE) { - emit ClearRules(_rules); - _rules = new address[](0); + _clearRulesValidation(); + } + + /** + * @notice Clear all the rules of the array of rules + * + */ + function _clearRulesValidation() internal { + uint256 index; + // we remove the last element first since it is more optimized. + for(uint256 i = _rulesValidation.length; i > 0; --i){ + unchecked { + // don't underflow since i > 0 + index = i - 1; + } + _removeRuleValidation(_rulesValidation[index], index); + } + emit ClearRules(_rulesValidation); } /** @@ -57,16 +61,7 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi * */ function addRuleValidation(IRuleValidation rule_) public onlyRole(RULE_ENGINE_ROLE) { - if( address(rule_) == address(0x0)) - { - revert RuleEngine_RuleAddressZeroNotAllowed(); - } - if( _ruleIsPresent[address(rule_)]) - { - revert RuleEngine_RuleAlreadyExists(); - } - _rules.push(address(rule_)); - _ruleIsPresent[address(rule_)] = true; + RuleInternal._addRule( _rulesValidation, address(rule_)); emit AddRule(address(rule_)); } @@ -84,7 +79,24 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi IRuleValidation rule_, uint256 index ) public onlyRole(RULE_ENGINE_ROLE) { - RuleInternal.removeRule(_rules, address(rule_), index); + _removeRuleValidation(address(rule_), index); + } + + /** + * @notice Remove a rule from the array of rules + * Revert if the rule found at the specified index does not match the rule in argument + * @param rule_ address of the target rule + * @param index the position inside the array of rule + * @dev To reduce the array size, the last rule is moved to the location occupied + * by the rule to remove + * + * + */ + function _removeRuleValidation( + address rule_, + uint256 index + ) internal { + RuleInternal._removeRule(_rulesValidation, rule_, index); emit RemoveRule(address(rule_)); } @@ -92,15 +104,15 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi * @return The number of rules inside the array */ function rulesCountValidation() external view override returns (uint256) { - return _rules.length; + return _rulesValidation.length; } /** * @notice Get the index of a rule inside the list - * @return index if the rule is found, _rules.length otherwise + * @return index if the rule is found, _rulesValidation.length otherwise */ function getRuleIndexValidation(IRuleValidation rule_) external view returns (uint256 index) { - return RuleInternal.getRuleIndex(_rules, address(rule_)); + return RuleInternal.getRuleIndex(_rulesValidation, address(rule_)); } /** @@ -109,7 +121,7 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi * @return a rule address */ function ruleValidation(uint256 ruleId) external view override returns (address) { - return _rules[ruleId]; + return _rulesValidation[ruleId]; } /** @@ -117,7 +129,7 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi * @return An array of rules */ function rulesValidation() external view override returns (address[] memory) { - return _rules; + return _rulesValidation; } /** @@ -132,9 +144,9 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi address _to, uint256 _amount ) public view override returns (uint8) { - uint256 rulesLength = _rules.length; + uint256 rulesLength = _rulesValidation.length; for (uint256 i = 0; i < rulesLength; ) { - uint8 restriction = IRuleValidation(_rules[i]).detectTransferRestriction( + uint8 restriction = IRuleValidation(_rulesValidation[i]).detectTransferRestriction( _from, _to, _amount @@ -176,11 +188,11 @@ abstract contract RuleEngineValidation is AccessControl, RuleInternal, IRuleEngi function messageForTransferRestriction( uint8 _restrictionCode ) external view override returns (string memory) { - uint256 rulesLength = _rules.length; + uint256 rulesLength = _rulesValidation.length; for (uint256 i = 0; i < rulesLength; ) { - if (IRuleValidation(_rules[i]).canReturnTransferRestrictionCode(_restrictionCode)) { + if (IRuleValidation(_rulesValidation[i]).canReturnTransferRestrictionCode(_restrictionCode)) { return - IRuleValidation(_rules[i]).messageForTransferRestriction(_restrictionCode); + IRuleValidation(_rulesValidation[i]).messageForTransferRestriction(_restrictionCode); } unchecked { ++i; diff --git a/src/modules/RuleInternal.sol b/src/modules/RuleInternal.sol index 5e69fda..b62337e 100644 --- a/src/modules/RuleInternal.sol +++ b/src/modules/RuleInternal.sol @@ -39,6 +39,33 @@ abstract contract RuleInternal is RuleEngineInvariantStorage { _rules = rules_; }*/ + + /** + * @notice Set all the rules, will overwrite all the previous rules. \n + * Revert if one rule is a zero address or if the rule is already present + * + */ + function _setRules( + address[] calldata rules_ + ) internal { + if(rules_.length == 0){ + revert RuleEngine_ArrayIsEmpty(); + } + for (uint256 i = 0; i < rules_.length; ) { + if( address(rules_[i]) == address(0x0)){ + revert RuleEngine_RuleAddressZeroNotAllowed(); + } + if(_ruleIsPresent[rules_[i]]){ + revert RuleEngine_RuleAlreadyExists(); + } + _ruleIsPresent[rules_[i]] = true; + emit AddRule(rules_[i]); + unchecked { + ++i; + } + } + } + /** * @notice Clear all the rules of the array of rules * @@ -53,7 +80,7 @@ abstract contract RuleInternal is RuleEngineInvariantStorage { * Revert if one rule is a zero address or if the rule is already present * */ - function addRule(address[] storage _rules, address rule_) internal { + function _addRule(address[] storage _rules, address rule_) internal { if( address(rule_) == address(0x0)) { revert RuleEngine_RuleAddressZeroNotAllowed(); @@ -77,7 +104,7 @@ abstract contract RuleInternal is RuleEngineInvariantStorage { * * */ - function removeRule( + function _removeRule( address[] storage _rules, address rule_, uint256 index @@ -91,7 +118,7 @@ abstract contract RuleInternal is RuleEngineInvariantStorage { } _rules.pop(); _ruleIsPresent[rule_] = false; - //emit RemoveRule(rule_); + emit RemoveRule(rule_); } /** diff --git a/test/RuleEngine/RuleEngine.t.sol b/test/RuleEngine/RuleEngineValidation.t.sol similarity index 89% rename from test/RuleEngine/RuleEngine.t.sol rename to test/RuleEngine/RuleEngineValidation.t.sol index ca5e60d..1d23808 100644 --- a/test/RuleEngine/RuleEngine.t.sol +++ b/test/RuleEngine/RuleEngineValidation.t.sol @@ -159,7 +159,7 @@ contract RuleEngineTest is Test, HelperContract { (bool resCallBool, ) = address(ruleEngineMock).call( abi.encodeCall(ruleEngineMock.setRulesValidation, ruleWhitelistTab) ); - + ruleEngineMock.rulesValidation(); // Assert - Arrange assertEq(resCallBool, true); resUint256 = ruleEngineMock.rulesCountValidation(); @@ -174,6 +174,56 @@ contract RuleEngineTest is Test, HelperContract { assertEq(resUint256, 0); } + function testCanClearRulesAndAddAgain() public { + // Arrange + vm.prank(WHITELIST_OPERATOR_ADDRESS); + RuleWhitelist ruleWhitelist1 = new RuleWhitelist( + WHITELIST_OPERATOR_ADDRESS, + ZERO_ADDRESS + ); + vm.prank(WHITELIST_OPERATOR_ADDRESS); + RuleWhitelist ruleWhitelist2 = new RuleWhitelist( + WHITELIST_OPERATOR_ADDRESS, + ZERO_ADDRESS + ); + address[] memory ruleWhitelistTab = new address[](2); + ruleWhitelistTab[0] = address(IRuleValidation(ruleWhitelist1)); + ruleWhitelistTab[1] = address(IRuleValidation(ruleWhitelist2)); + + vm.prank(RULE_ENGINE_OPERATOR_ADDRESS); + (bool resCallBool, ) = address(ruleEngineMock).call( + abi.encodeCall(ruleEngineMock.setRulesValidation, ruleWhitelistTab) + ); + + // Act + vm.prank(RULE_ENGINE_OPERATOR_ADDRESS); + ruleEngineMock.clearRulesValidation(); + + // Assert + resUint256 = ruleEngineMock.rulesCountValidation(); + assertEq(resUint256, 0); + + // Can set again the previous rules + vm.prank(RULE_ENGINE_OPERATOR_ADDRESS); + (resCallBool, ) = address(ruleEngineMock).call( + abi.encodeCall(ruleEngineMock.setRulesValidation, ruleWhitelistTab) + ); + assertEq(resCallBool, true); + // Arrange before assert + + // Act + vm.prank(RULE_ENGINE_OPERATOR_ADDRESS); + ruleEngineMock.clearRulesValidation(); + resUint256 = ruleEngineMock.rulesCountValidation(); + assertEq(resUint256, 0); + + // Can add previous rule again + vm.expectEmit(true, false, false, false); + emit AddRule(address(ruleWhitelist1)); + vm.prank(RULE_ENGINE_OPERATOR_ADDRESS); + ruleEngineMock.addRuleValidation(ruleWhitelist1); + } + function testCanAddRule() public { // Arrange vm.prank(WHITELIST_OPERATOR_ADDRESS);