From 0f9109f2524a90083f5e5ebd6b8fd261df8a3591 Mon Sep 17 00:00:00 2001 From: Ryan Sauge Date: Thu, 30 Nov 2023 16:47:22 +0100 Subject: [PATCH 1/2] Separate snapshotModuleInternal in two contracts --- contracts/modules/CMTAT_BASE.sol | 19 +- .../internal/ERC20SnapshotModuleInternal.sol | 427 ++---------------- .../internal/base/SnapshotModuleBase.sol | 410 +++++++++++++++++ .../CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol | 19 +- package-lock.json | 2 + 5 files changed, 459 insertions(+), 418 deletions(-) create mode 100644 contracts/modules/internal/base/SnapshotModuleBase.sol diff --git a/contracts/modules/CMTAT_BASE.sol b/contracts/modules/CMTAT_BASE.sol index bc0c284a..395a17f6 100644 --- a/contracts/modules/CMTAT_BASE.sol +++ b/contracts/modules/CMTAT_BASE.sol @@ -110,9 +110,11 @@ abstract contract CMTAT_BASE is __Enforcement_init_unchained(); /* SnapshotModule: - Add this call in case you add the SnapshotModule + Add these two calls in case you add the SnapshotModule + + __SnapshotModuleBase_init_unchained(); __ERC20Snapshot_init_unchained(); - */ + */ __Validation_init_unchained(ruleEngine_); /* Wrapper */ @@ -174,10 +176,7 @@ abstract contract CMTAT_BASE is /** * @dev - * SnapshotModule: - * - override SnapshotModuleInternal if you add the SnapshotModule - * e.g. override(ERC20SnapshotModuleInternal, ERC20Upgradeable) - * - remove the keyword view + * */ function _update( address from, @@ -187,13 +186,13 @@ abstract contract CMTAT_BASE is if (!ValidationModule.validateTransfer(from, to, amount)) { revert Errors.CMTAT_InvalidTransfer(from, to, amount); } - ERC20Upgradeable._update(from, to, amount); - // We call the SnapshotModule only if the transfer is valid /* SnapshotModule: - Add this call in case you add the SnapshotModule - ERC20SnapshotModuleInternal._update(from, to, amount); + Add this in case you add the SnapshotModule + We call the SnapshotModule only if the transfer is valid */ + // ERC20SnapshotModuleInternal._snapshotUpdate(from, to); + ERC20Upgradeable._update(from, to, amount); } /** diff --git a/contracts/modules/internal/ERC20SnapshotModuleInternal.sol b/contracts/modules/internal/ERC20SnapshotModuleInternal.sol index c868a15a..83c1a91d 100644 --- a/contracts/modules/internal/ERC20SnapshotModuleInternal.sol +++ b/contracts/modules/internal/ERC20SnapshotModuleInternal.sol @@ -8,9 +8,9 @@ import "../../../openzeppelin-contracts-upgradeable/contracts/token/ERC20/ERC20U import {Arrays} from '@openzeppelin/contracts/utils/Arrays.sol'; import "../../libraries/Errors.sol"; - +import "./base/SnapshotModuleBase.sol"; /** - * @dev Snapshot module. + * @dev Snapshot module internal. * * Useful to take a snapshot of token holder balance and total supply at a specific time * Inspired by Openzeppelin - ERC20Snapshot but use the time as Id instead of a counter. @@ -18,43 +18,9 @@ import "../../libraries/Errors.sol"; because overriding this function can break the contract. */ -abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { +abstract contract ERC20SnapshotModuleInternal is SnapshotModuleBase, ERC20Upgradeable { using Arrays for uint256[]; - /** - @notice Emitted when the snapshot with the specified oldTime was scheduled or rescheduled at the specified newTime. - */ - event SnapshotSchedule(uint256 indexed oldTime, uint256 indexed newTime); - - /** - @notice Emitted when the scheduled snapshot with the specified time was cancelled. - */ - event SnapshotUnschedule(uint256 indexed time); - - /** - @dev See {OpenZeppelin - ERC20Snapshot} - Snapshotted values have arrays of ids (time) and the value corresponding to that id. - ids is expected to be sorted in ascending order, and to contain no repeated elements - because we use findUpperBound in the function _valueAt - */ - struct Snapshots { - uint256[] ids; - uint256[] values; - } - - /** - @dev See {OpenZeppelin - ERC20Snapshot} - */ - mapping(address => Snapshots) private _accountBalanceSnapshots; - Snapshots private _totalSupplySnapshots; - - /** - @dev time instead of a counter for OpenZeppelin - */ - // Initialized to zero - uint256 private _currentSnapshotTime; - // Initialized to zero - uint256 private _currentSnapshotIndex; /** @dev @@ -72,6 +38,7 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { ) internal onlyInitializing { __Context_init_unchained(); __ERC20_init(name_, symbol_); + __SnapshotModuleBase_init_unchained(); __ERC20Snapshot_init_unchained(); } @@ -80,242 +47,15 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { // _currentSnapshotTime & _currentSnapshotIndex are initialized to zero } - /** - @dev schedule a snapshot at the specified time - You can only add a snapshot after the last previous - */ - function _scheduleSnapshot(uint256 time) internal { - // Check the time firstly to avoid an useless read of storage - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( - time, - block.timestamp - ); - } - - if (_scheduledSnapshots.length > 0) { - // We check the last snapshot on the list - uint256 nextSnapshotTime = _scheduledSnapshots[ - _scheduledSnapshots.length - 1 - ]; - if (time < nextSnapshotTime) { - revert Errors - .CMTAT_SnapshotModule_SnapshotTimestampBeforeLastSnapshot( - time, - nextSnapshotTime - ); - } else if (time == nextSnapshotTime) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); - } - } - _scheduledSnapshots.push(time); - emit SnapshotSchedule(0, time); - } - - /** - @dev schedule a snapshot at the specified time - */ - function _scheduleSnapshotNotOptimized(uint256 time) internal { - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( - time, - block.timestamp - ); - } - (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); - // Perfect match - if (isFound) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); - } - // if no upper bound match found, we push the snapshot at the end of the list - if (index == _scheduledSnapshots.length) { - _scheduledSnapshots.push(time); - } else { - _scheduledSnapshots.push( - _scheduledSnapshots[_scheduledSnapshots.length - 1] - ); - for (uint256 i = _scheduledSnapshots.length - 2; i > index; ) { - _scheduledSnapshots[i] = _scheduledSnapshots[i - 1]; - unchecked { - --i; - } - } - _scheduledSnapshots[index] = time; - } - emit SnapshotSchedule(0, time); - } - - /** - @dev reschedule a scheduled snapshot at the specified newTime - */ - function _rescheduleSnapshot(uint256 oldTime, uint256 newTime) internal { - // Check the time firstly to avoid an useless read of storage - if (oldTime <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); - } - if (newTime <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( - newTime, - block.timestamp - ); - } - if (_scheduledSnapshots.length == 0) { - revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); - } - (bool foundOld, uint256 index) = _findScheduledSnapshotIndex(oldTime); - if (!foundOld) { - revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); - } - if (index + 1 < _scheduledSnapshots.length) { - uint256 nextSnapshotTime = _scheduledSnapshots[index + 1]; - if (newTime > nextSnapshotTime) { - revert Errors - .CMTAT_SnapshotModule_SnapshotTimestampAfterNextSnapshot( - newTime, - nextSnapshotTime - ); - } else if (newTime == nextSnapshotTime) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); - } - } - if (index > 0) { - if (newTime <= _scheduledSnapshots[index - 1]) - revert Errors - .CMTAT_SnapshotModule_SnapshotTimestampBeforePreviousSnapshot( - newTime, - _scheduledSnapshots[index - 1] - ); - } - _scheduledSnapshots[index] = newTime; - - emit SnapshotSchedule(oldTime, newTime); - } - - /** - @dev unschedule the last scheduled snapshot - */ - function _unscheduleLastSnapshot(uint256 time) internal { - // Check the time firstly to avoid an useless read of storage - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); - } - if (_scheduledSnapshots.length == 0) { - revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); - } - // All snapshot time are unique, so we do not check the indice - if (time != _scheduledSnapshots[_scheduledSnapshots.length - 1]) { - revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); - } - _scheduledSnapshots.pop(); - emit SnapshotUnschedule(time); - } - - /** - @dev unschedule (remove) a scheduled snapshot in three steps: - - search the snapshot in the list - - If found, move all next snapshots one position to the left - - Reduce the array size by deleting the last snapshot - */ - function _unscheduleSnapshotNotOptimized(uint256 time) internal { - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); - } - (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); - if (!isFound) { - revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); - } - for (uint256 i = index; i + 1 < _scheduledSnapshots.length; ) { - _scheduledSnapshots[i] = _scheduledSnapshots[i + 1]; - unchecked { - ++i; - } - } - _scheduledSnapshots.pop(); - } - - /** - @dev - Get the next scheduled snapshots - */ - function getNextSnapshots() public view returns (uint256[] memory) { - uint256[] memory nextScheduledSnapshot = new uint256[](0); - // no snapshot were planned - if (_scheduledSnapshots.length > 0) { - ( - uint256 timeLowerBound, - uint256 indexLowerBound - ) = _findScheduledMostRecentPastSnapshot(); - // All snapshots are situated in the futur - if ((timeLowerBound == 0) && (_currentSnapshotTime == 0)) { - return _scheduledSnapshots; - } else { - // There are snapshots situated in the futur - if (indexLowerBound + 1 != _scheduledSnapshots.length) { - // All next snapshots are located after the snapshot specified by indexLowerBound - uint256 arraySize = _scheduledSnapshots.length - - indexLowerBound - - 1; - nextScheduledSnapshot = new uint256[](arraySize); - for (uint256 i; i < arraySize; ) { - nextScheduledSnapshot[i] = _scheduledSnapshots[ - indexLowerBound + 1 + i - ]; - unchecked { - ++i; - } - } - } - } - } - return nextScheduledSnapshot; - } - - /** - @dev - Get all snapshots - */ - function getAllSnapshots() public view returns (uint256[] memory) { - return _scheduledSnapshots; - } - - /** - @notice Return the number of tokens owned by the given owner at the time when the snapshot with the given time was created. - @return value stored in the snapshot, or the actual balance if no snapshot - */ - function snapshotBalanceOf( - uint256 time, - address owner - ) public view returns (uint256) { - (bool snapshotted, uint256 value) = _valueAt( - time, - _accountBalanceSnapshots[owner] - ); - - return snapshotted ? value : balanceOf(owner); - } - - /** - @dev See {OpenZeppelin - ERC20Snapshot} - Retrieves the total supply at the specified time. - @return value stored in the snapshot, or the actual totalSupply if no snapshot - */ - function snapshotTotalSupply(uint256 time) public view returns (uint256) { - (bool snapshotted, uint256 value) = _valueAt( - time, - _totalSupplySnapshots - ); - return snapshotted ? value : totalSupply(); - } /** @dev Update balance and/or total supply snapshots before the values are modified. This is implemented in the _beforeTokenTransfer hook, which is executed for _mint, _burn, and _transfer operations. */ - function _update( + function _snapshotUpdate( address from, - address to, - uint256 amount - ) internal virtual override { + address to + ) internal virtual { _setCurrentSnapshot(); if (from != address(0)) { // for both burn and transfer @@ -332,42 +72,8 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { _updateAccountSnapshot(to); _updateTotalSupplySnapshot(); } - ERC20Upgradeable._update(from, to, amount); } - /** - @dev See {OpenZeppelin - ERC20Snapshot} - @param time where we want a snapshot - @param snapshots the struct where are stored the snapshots - @return snapshotExist true if a snapshot is found, false otherwise - value 0 if no snapshot, balance value if a snapshot exists - */ - function _valueAt( - uint256 time, - Snapshots storage snapshots - ) private view returns (bool snapshotExist, uint256 value) { - // When a valid snapshot is queried, there are three possibilities: - // a) The queried value was not modified after the snapshot was taken. Therefore, a snapshot entry was never - // created for this id, and all stored snapshot ids are smaller than the requested one. The value that corresponds - // to this id is the current one. - // b) The queried value was modified after the snapshot was taken. Therefore, there will be an entry with the - // requested id, and its value is the one to return. - // c) More snapshots were created after the requested one, and the queried value was later modified. There will be - // no entry for the requested id: the value that corresponds to it is that of the smallest snapshot id that is - // larger than the requested one. - // - // In summary, we need to find an element in an array, returning the index of the smallest value that is larger if - // it is not found, unless said value doesn't exist (e.g. when all values are smaller). Arrays.findUpperBound does - // exactly this. - - uint256 index = snapshots.ids.findUpperBound(time); - - if (index == snapshots.ids.length) { - return (false, 0); - } else { - return (true, snapshots.values[index]); - } - } /** @dev See {OpenZeppelin - ERC20Snapshot} @@ -383,111 +89,34 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { _updateSnapshot(_totalSupplySnapshots, totalSupply()); } - /** - @dev - Inside a struct Snapshots: - - Update the array ids to the current Snapshot time if this one is greater than the snapshot times stored in ids. - - Update the value to the corresponding value. - */ - function _updateSnapshot( - Snapshots storage snapshots, - uint256 currentValue - ) private { - uint256 current = _currentSnapshotTime; - if (_lastSnapshot(snapshots.ids) < current) { - snapshots.ids.push(current); - snapshots.values.push(currentValue); - } - } /** - @dev - Set the currentSnapshotTime by retrieving the most recent snapshot - if a snapshot exists, clear all past scheduled snapshot - */ - function _setCurrentSnapshot() internal { - ( - uint256 scheduleSnapshotTime, - uint256 scheduleSnapshotIndex - ) = _findScheduledMostRecentPastSnapshot(); - if (scheduleSnapshotTime > 0) { - _currentSnapshotTime = scheduleSnapshotTime; - _currentSnapshotIndex = scheduleSnapshotIndex; - } - } - - /** - @return the last snapshot time inside a snapshot ids array + @notice Return the number of tokens owned by the given owner at the time when the snapshot with the given time was created. + @return value stored in the snapshot, or the actual balance if no snapshot */ - function _lastSnapshot( - uint256[] storage ids - ) private view returns (uint256) { - if (ids.length == 0) { - return 0; - } else { - return ids[ids.length - 1]; - } - } + function snapshotBalanceOf( + uint256 time, + address owner + ) public view returns (uint256) { + (bool snapshotted, uint256 value) = _valueAt( + time, + _accountBalanceSnapshots[owner] + ); - /** - @dev Find the snapshot index at the specified time - @return (true, index) if the snapshot exists, (false, 0) otherwise - */ - function _findScheduledSnapshotIndex( - uint256 time - ) private view returns (bool, uint256) { - uint256 indexFound = _scheduledSnapshots.findUpperBound(time); - uint256 _scheduledSnapshotsLength = _scheduledSnapshots.length; - // Exact match - if ( - indexFound != _scheduledSnapshotsLength && - _scheduledSnapshots[indexFound] == time - ) { - return (true, indexFound); - } - // Upper bound match - else if (indexFound != _scheduledSnapshotsLength) { - return (false, indexFound); - } - // no match - else { - return (false, _scheduledSnapshotsLength); - } + return snapshotted ? value : balanceOf(owner); } - /** - @dev find the most recent past snapshot - The complexity of this function is O(N) because we go through the whole list + /** + @dev See {OpenZeppelin - ERC20Snapshot} + Retrieves the total supply at the specified time. + @return value stored in the snapshot, or the actual totalSupply if no snapshot */ - function _findScheduledMostRecentPastSnapshot() - private - view - returns (uint256 time, uint256 index) - { - uint256 currentArraySize = _scheduledSnapshots.length; - // no snapshot or the current snapshot already points on the last snapshot - if ( - currentArraySize == 0 || - ((_currentSnapshotIndex + 1 == currentArraySize) && (time != 0)) - ) { - return (0, currentArraySize); - } - // mostRecent is initialized in the loop - uint256 mostRecent; - index = currentArraySize; - for (uint256 i = _currentSnapshotIndex; i < currentArraySize; ) { - if (_scheduledSnapshots[i] <= block.timestamp) { - mostRecent = _scheduledSnapshots[i]; - index = i; - } else { - // All snapshot are planned in the futur - break; - } - unchecked { - ++i; - } - } - return (mostRecent, index); + function snapshotTotalSupply(uint256 time) public view returns (uint256) { + (bool snapshotted, uint256 value) = _valueAt( + time, + _totalSupplySnapshots + ); + return snapshotted ? value : totalSupply(); } uint256[50] private __gap; diff --git a/contracts/modules/internal/base/SnapshotModuleBase.sol b/contracts/modules/internal/base/SnapshotModuleBase.sol new file mode 100644 index 00000000..935eba94 --- /dev/null +++ b/contracts/modules/internal/base/SnapshotModuleBase.sol @@ -0,0 +1,410 @@ +//SPDX-License-Identifier: MPL-2.0 + +pragma solidity ^0.8.20; + +import "../../../../openzeppelin-contracts-upgradeable/contracts/proxy/utils/Initializable.sol"; +import {Arrays} from '@openzeppelin/contracts/utils/Arrays.sol'; + +import "../../../libraries/Errors.sol"; + +/** + * @dev Base for the Snapshot module + * + * Useful to take a snapshot of token holder balance and total supply at a specific time + * Inspired by Openzeppelin - ERC20Snapshot but use the time as Id instead of a counter. + * Contrary to OpenZeppelin, the function _getCurrentSnapshotId is not available + because overriding this function can break the contract. + */ + +abstract contract SnapshotModuleBase is Initializable { + using Arrays for uint256[]; + + /** + @notice Emitted when the snapshot with the specified oldTime was scheduled or rescheduled at the specified newTime. + */ + event SnapshotSchedule(uint256 indexed oldTime, uint256 indexed newTime); + + /** + @notice Emitted when the scheduled snapshot with the specified time was cancelled. + */ + event SnapshotUnschedule(uint256 indexed time); + + /** + @dev See {OpenZeppelin - ERC20Snapshot} + Snapshotted values have arrays of ids (time) and the value corresponding to that id. + ids is expected to be sorted in ascending order, and to contain no repeated elements + because we use findUpperBound in the function _valueAt + */ + struct Snapshots { + uint256[] ids; + uint256[] values; + } + + /** + @dev See {OpenZeppelin - ERC20Snapshot} + */ + mapping(address => Snapshots) internal _accountBalanceSnapshots; + Snapshots internal _totalSupplySnapshots; + + /** + @dev time instead of a counter for OpenZeppelin + */ + // Initialized to zero + uint256 private _currentSnapshotTime; + // Initialized to zero + uint256 private _currentSnapshotIndex; + + /** + @dev + list of scheduled snapshot (time) + This list is sorted in ascending order + */ + uint256[] private _scheduledSnapshots; + + function __SnapshotModuleBase_init_unchained() internal onlyInitializing { + // Nothing to do + // _currentSnapshotTime & _currentSnapshotIndex are initialized to zero + } + + /** + @dev schedule a snapshot at the specified time + You can only add a snapshot after the last previous + */ + function _scheduleSnapshot(uint256 time) internal { + // Check the time firstly to avoid an useless read of storage + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( + time, + block.timestamp + ); + } + + if (_scheduledSnapshots.length > 0) { + // We check the last snapshot on the list + uint256 nextSnapshotTime = _scheduledSnapshots[ + _scheduledSnapshots.length - 1 + ]; + if (time < nextSnapshotTime) { + revert Errors + .CMTAT_SnapshotModule_SnapshotTimestampBeforeLastSnapshot( + time, + nextSnapshotTime + ); + } else if (time == nextSnapshotTime) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); + } + } + _scheduledSnapshots.push(time); + emit SnapshotSchedule(0, time); + } + + /** + @dev schedule a snapshot at the specified time + */ + function _scheduleSnapshotNotOptimized(uint256 time) internal { + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( + time, + block.timestamp + ); + } + (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); + // Perfect match + if (isFound) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); + } + // if no upper bound match found, we push the snapshot at the end of the list + if (index == _scheduledSnapshots.length) { + _scheduledSnapshots.push(time); + } else { + _scheduledSnapshots.push( + _scheduledSnapshots[_scheduledSnapshots.length - 1] + ); + for (uint256 i = _scheduledSnapshots.length - 2; i > index; ) { + _scheduledSnapshots[i] = _scheduledSnapshots[i - 1]; + unchecked { + --i; + } + } + _scheduledSnapshots[index] = time; + } + emit SnapshotSchedule(0, time); + } + + /** + @dev reschedule a scheduled snapshot at the specified newTime + */ + function _rescheduleSnapshot(uint256 oldTime, uint256 newTime) internal { + // Check the time firstly to avoid an useless read of storage + if (oldTime <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); + } + if (newTime <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( + newTime, + block.timestamp + ); + } + if (_scheduledSnapshots.length == 0) { + revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); + } + (bool foundOld, uint256 index) = _findScheduledSnapshotIndex(oldTime); + if (!foundOld) { + revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); + } + if (index + 1 < _scheduledSnapshots.length) { + uint256 nextSnapshotTime = _scheduledSnapshots[index + 1]; + if (newTime > nextSnapshotTime) { + revert Errors + .CMTAT_SnapshotModule_SnapshotTimestampAfterNextSnapshot( + newTime, + nextSnapshotTime + ); + } else if (newTime == nextSnapshotTime) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); + } + } + if (index > 0) { + if (newTime <= _scheduledSnapshots[index - 1]) + revert Errors + .CMTAT_SnapshotModule_SnapshotTimestampBeforePreviousSnapshot( + newTime, + _scheduledSnapshots[index - 1] + ); + } + _scheduledSnapshots[index] = newTime; + + emit SnapshotSchedule(oldTime, newTime); + } + + /** + @dev unschedule the last scheduled snapshot + */ + function _unscheduleLastSnapshot(uint256 time) internal { + // Check the time firstly to avoid an useless read of storage + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); + } + if (_scheduledSnapshots.length == 0) { + revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); + } + // All snapshot time are unique, so we do not check the indice + if (time != _scheduledSnapshots[_scheduledSnapshots.length - 1]) { + revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); + } + _scheduledSnapshots.pop(); + emit SnapshotUnschedule(time); + } + + /** + @dev unschedule (remove) a scheduled snapshot in three steps: + - search the snapshot in the list + - If found, move all next snapshots one position to the left + - Reduce the array size by deleting the last snapshot + */ + function _unscheduleSnapshotNotOptimized(uint256 time) internal { + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); + } + (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); + if (!isFound) { + revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); + } + for (uint256 i = index; i + 1 < _scheduledSnapshots.length; ) { + _scheduledSnapshots[i] = _scheduledSnapshots[i + 1]; + unchecked { + ++i; + } + } + _scheduledSnapshots.pop(); + } + + /** + @dev + Get the next scheduled snapshots + */ + function getNextSnapshots() public view returns (uint256[] memory) { + uint256[] memory nextScheduledSnapshot = new uint256[](0); + // no snapshot were planned + if (_scheduledSnapshots.length > 0) { + ( + uint256 timeLowerBound, + uint256 indexLowerBound + ) = _findScheduledMostRecentPastSnapshot(); + // All snapshots are situated in the futur + if ((timeLowerBound == 0) && (_currentSnapshotTime == 0)) { + return _scheduledSnapshots; + } else { + // There are snapshots situated in the futur + if (indexLowerBound + 1 != _scheduledSnapshots.length) { + // All next snapshots are located after the snapshot specified by indexLowerBound + uint256 arraySize = _scheduledSnapshots.length - + indexLowerBound - + 1; + nextScheduledSnapshot = new uint256[](arraySize); + for (uint256 i; i < arraySize; ) { + nextScheduledSnapshot[i] = _scheduledSnapshots[ + indexLowerBound + 1 + i + ]; + unchecked { + ++i; + } + } + } + } + } + return nextScheduledSnapshot; + } + + /** + @dev + Get all snapshots + */ + function getAllSnapshots() public view returns (uint256[] memory) { + return _scheduledSnapshots; + } + + + /** + @dev See {OpenZeppelin - ERC20Snapshot} + @param time where we want a snapshot + @param snapshots the struct where are stored the snapshots + @return snapshotExist true if a snapshot is found, false otherwise + value 0 if no snapshot, balance value if a snapshot exists + */ + function _valueAt( + uint256 time, + Snapshots storage snapshots + ) internal view returns (bool snapshotExist, uint256 value) { + // When a valid snapshot is queried, there are three possibilities: + // a) The queried value was not modified after the snapshot was taken. Therefore, a snapshot entry was never + // created for this id, and all stored snapshot ids are smaller than the requested one. The value that corresponds + // to this id is the current one. + // b) The queried value was modified after the snapshot was taken. Therefore, there will be an entry with the + // requested id, and its value is the one to return. + // c) More snapshots were created after the requested one, and the queried value was later modified. There will be + // no entry for the requested id: the value that corresponds to it is that of the smallest snapshot id that is + // larger than the requested one. + // + // In summary, we need to find an element in an array, returning the index of the smallest value that is larger if + // it is not found, unless said value doesn't exist (e.g. when all values are smaller). Arrays.findUpperBound does + // exactly this. + + uint256 index = snapshots.ids.findUpperBound(time); + + if (index == snapshots.ids.length) { + return (false, 0); + } else { + return (true, snapshots.values[index]); + } + } + + /** + @dev + Inside a struct Snapshots: + - Update the array ids to the current Snapshot time if this one is greater than the snapshot times stored in ids. + - Update the value to the corresponding value. + */ + function _updateSnapshot( + Snapshots storage snapshots, + uint256 currentValue + ) internal { + uint256 current = _currentSnapshotTime; + if (_lastSnapshot(snapshots.ids) < current) { + snapshots.ids.push(current); + snapshots.values.push(currentValue); + } + } + + /** + @dev + Set the currentSnapshotTime by retrieving the most recent snapshot + if a snapshot exists, clear all past scheduled snapshot + */ + function _setCurrentSnapshot() internal { + ( + uint256 scheduleSnapshotTime, + uint256 scheduleSnapshotIndex + ) = _findScheduledMostRecentPastSnapshot(); + if (scheduleSnapshotTime > 0) { + _currentSnapshotTime = scheduleSnapshotTime; + _currentSnapshotIndex = scheduleSnapshotIndex; + } + } + + /** + @return the last snapshot time inside a snapshot ids array + */ + function _lastSnapshot( + uint256[] storage ids + ) private view returns (uint256) { + if (ids.length == 0) { + return 0; + } else { + return ids[ids.length - 1]; + } + } + + /** + @dev Find the snapshot index at the specified time + @return (true, index) if the snapshot exists, (false, 0) otherwise + */ + function _findScheduledSnapshotIndex( + uint256 time + ) private view returns (bool, uint256) { + uint256 indexFound = _scheduledSnapshots.findUpperBound(time); + uint256 _scheduledSnapshotsLength = _scheduledSnapshots.length; + // Exact match + if ( + indexFound != _scheduledSnapshotsLength && + _scheduledSnapshots[indexFound] == time + ) { + return (true, indexFound); + } + // Upper bound match + else if (indexFound != _scheduledSnapshotsLength) { + return (false, indexFound); + } + // no match + else { + return (false, _scheduledSnapshotsLength); + } + } + + /** + @dev find the most recent past snapshot + The complexity of this function is O(N) because we go through the whole list + */ + function _findScheduledMostRecentPastSnapshot() + private + view + returns (uint256 time, uint256 index) + { + uint256 currentArraySize = _scheduledSnapshots.length; + // no snapshot or the current snapshot already points on the last snapshot + if ( + currentArraySize == 0 || + ((_currentSnapshotIndex + 1 == currentArraySize) && (time != 0)) + ) { + return (0, currentArraySize); + } + // mostRecent is initialized in the loop + uint256 mostRecent; + index = currentArraySize; + for (uint256 i = _currentSnapshotIndex; i < currentArraySize; ) { + if (_scheduledSnapshots[i] <= block.timestamp) { + mostRecent = _scheduledSnapshots[i]; + index = i; + } else { + // All snapshot are planned in the futur + break; + } + unchecked { + ++i; + } + } + return (mostRecent, index); + } + + uint256[50] private __gap; +} diff --git a/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol b/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol index ab588cfc..a86567b7 100644 --- a/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol +++ b/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol @@ -100,8 +100,9 @@ abstract contract CMTAT_BASE_SnapshotTest is __Enforcement_init_unchained(); /* SnapshotModule: - Add this call in case you add the SnapshotModule + Add these two calls in case you add the SnapshotModule */ + __SnapshotModuleBase_init_unchained(); __ERC20Snapshot_init_unchained(); __Validation_init_unchained(ruleEngine_); @@ -175,18 +176,18 @@ abstract contract CMTAT_BASE_SnapshotTest is address from, address to, uint256 amount - ) internal override(ERC20SnapshotModuleInternal, ERC20Upgradeable) { - // We call the SnapshotModule only if the transfer is valid - if (!ValidationModule.validateTransfer(from, to, amount)) + ) internal override(ERC20Upgradeable) { + + if (!ValidationModule.validateTransfer(from, to, amount)){ revert Errors.CMTAT_InvalidTransfer(from, to, amount); - /* - We do not call ERC20Upgradeable._update(from, to, amount) here because it is called inside the SnapshotModule - */ + } /* SnapshotModule: - Add this call in case you add the SnapshotModule + Add this in case you add the SnapshotModule + We call the SnapshotModule only if the transfer is valid */ - ERC20SnapshotModuleInternal._update(from, to, amount); + ERC20SnapshotModuleInternal._snapshotUpdate(from, to); + ERC20Upgradeable._update(from, to, amount); } /** diff --git a/package-lock.json b/package-lock.json index a0c0b9de..ff374341 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13916,6 +13916,7 @@ "version": "4.0.5", "resolved": "https://registry.npmjs.org/bufferutil/-/bufferutil-4.0.5.tgz", "integrity": "sha512-HTm14iMQKK2FjFLRTM5lAVcyaUzOnqbPtesFIvREgXpJHdQm8bWS+GkQgIkfaBYRHuCnea7w8UVNfwiAQhlr9A==", + "hasInstallScript": true, "optional": true, "dependencies": { "node-gyp-build": "^4.3.0" @@ -14245,6 +14246,7 @@ "version": "5.0.7", "resolved": "https://registry.npmjs.org/utf-8-validate/-/utf-8-validate-5.0.7.tgz", "integrity": "sha512-vLt1O5Pp+flcArHGIyKEQq883nBt8nN8tVBcoL0qUXj2XT1n7p70yGIq2VK98I5FdZ1YHc0wk/koOnHjnXWk1Q==", + "hasInstallScript": true, "optional": true, "dependencies": { "node-gyp-build": "^4.3.0" From 81aa58ffa7fe8a4c039999782d0f7f081fbf3dff Mon Sep 17 00:00:00 2001 From: Ryan Sauge Date: Fri, 1 Dec 2023 11:45:36 +0100 Subject: [PATCH 2/2] Add ruleEngine with OperateOnTransfer + remove useless init function in internal modules --- contracts/CMTAT_STANDALONE.sol | 2 +- .../draft-IERC1404/IRuleEngineCMTAT.sol | 17 +++++++++++++ contracts/mocks/RuleEngine/RuleEngineMock.sol | 10 ++++++++ .../RuleEngine/interfaces/IRuleEngine.sol | 4 +-- contracts/modules/CMTAT_BASE.sol | 8 +++--- .../internal/ERC20SnapshotModuleInternal.sol | 13 ---------- .../internal/EnforcementModuleInternal.sol | 8 ------ .../internal/ValidationModuleInternal.sol | 22 ++++++---------- .../modules/security/AuthorizationModule.sol | 17 ------------- .../wrapper/controllers/ValidationModule.sol | 25 +++++++++++++++++-- .../CMTATSnapshotStandaloneTest.sol | 2 +- .../CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol | 7 +++--- 12 files changed, 69 insertions(+), 66 deletions(-) create mode 100644 contracts/interfaces/draft-IERC1404/IRuleEngineCMTAT.sol diff --git a/contracts/CMTAT_STANDALONE.sol b/contracts/CMTAT_STANDALONE.sol index 29fbed43..ad52c01a 100644 --- a/contracts/CMTAT_STANDALONE.sol +++ b/contracts/CMTAT_STANDALONE.sol @@ -28,7 +28,7 @@ contract CMTAT_STANDALONE is CMTAT_BASE { uint8 decimalsIrrevocable, string memory tokenId_, string memory terms_, - IERC1404Wrapper ruleEngine_, + IRuleEngineCMTAT ruleEngine_, string memory information_, uint256 flag_ ) MetaTxModule(forwarderIrrevocable) { diff --git a/contracts/interfaces/draft-IERC1404/IRuleEngineCMTAT.sol b/contracts/interfaces/draft-IERC1404/IRuleEngineCMTAT.sol new file mode 100644 index 00000000..40e34023 --- /dev/null +++ b/contracts/interfaces/draft-IERC1404/IRuleEngineCMTAT.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MPL-2.0 + +pragma solidity ^0.8.0; + +import "./draft-IERC1404Wrapper.sol"; + +interface IRuleEngineCMTAT is IERC1404Wrapper { + /** + * @dev Returns true if the operation is a success, and false otherwise. + */ + function operateOnTransfer( + address _from, + address _to, + uint256 _amount + ) external returns (bool isValid); + +} diff --git a/contracts/mocks/RuleEngine/RuleEngineMock.sol b/contracts/mocks/RuleEngine/RuleEngineMock.sol index 46473108..5bd94774 100644 --- a/contracts/mocks/RuleEngine/RuleEngineMock.sol +++ b/contracts/mocks/RuleEngine/RuleEngineMock.sol @@ -67,6 +67,16 @@ contract RuleEngineMock is IRuleEngine { return detectTransferRestriction(_from, _to, _amount) == 0; } + /* + @dev + Warning: if you want to use this mock, you have to restrict the access to this function through an an access control + */ + function operateOnTransfer( address _from, + address _to, + uint256 _amount) public override returns (bool){ + return validateTransfer(_from, _to, _amount); + } + /** @dev For all the rules, each restriction code has to be unique. diff --git a/contracts/mocks/RuleEngine/interfaces/IRuleEngine.sol b/contracts/mocks/RuleEngine/interfaces/IRuleEngine.sol index 44720657..f7a5b4bc 100644 --- a/contracts/mocks/RuleEngine/interfaces/IRuleEngine.sol +++ b/contracts/mocks/RuleEngine/interfaces/IRuleEngine.sol @@ -3,9 +3,9 @@ pragma solidity ^0.8.0; import "./IRule.sol"; -import "../../../interfaces/draft-IERC1404/draft-IERC1404Wrapper.sol"; +import "../../../interfaces/draft-IERC1404/IRuleEngineCMTAT.sol"; -interface IRuleEngine is IERC1404Wrapper { +interface IRuleEngine is IRuleEngineCMTAT { /** * @dev define the rules, the precedent rules will be overwritten */ diff --git a/contracts/modules/CMTAT_BASE.sol b/contracts/modules/CMTAT_BASE.sol index 395a17f6..59da035a 100644 --- a/contracts/modules/CMTAT_BASE.sol +++ b/contracts/modules/CMTAT_BASE.sol @@ -62,8 +62,8 @@ abstract contract CMTAT_BASE is uint8 decimalsIrrevocable, string memory tokenId_, string memory terms_, - IERC1404Wrapper ruleEngine_, - string memory information_, + IRuleEngineCMTAT ruleEngine_, + string memory information_, uint256 flag_ ) public initializer { __CMTAT_init( @@ -91,7 +91,7 @@ abstract contract CMTAT_BASE is uint8 decimalsIrrevocable, string memory tokenId_, string memory terms_, - IERC1404Wrapper ruleEngine_, + IRuleEngineCMTAT ruleEngine_, string memory information_, uint256 flag_ ) internal onlyInitializing { @@ -183,7 +183,7 @@ abstract contract CMTAT_BASE is address to, uint256 amount ) internal override(ERC20Upgradeable) { - if (!ValidationModule.validateTransfer(from, to, amount)) { + if (!ValidationModule._operateOnTransfer(from, to, amount)) { revert Errors.CMTAT_InvalidTransfer(from, to, amount); } /* diff --git a/contracts/modules/internal/ERC20SnapshotModuleInternal.sol b/contracts/modules/internal/ERC20SnapshotModuleInternal.sol index 83c1a91d..f80ec992 100644 --- a/contracts/modules/internal/ERC20SnapshotModuleInternal.sol +++ b/contracts/modules/internal/ERC20SnapshotModuleInternal.sol @@ -29,19 +29,6 @@ abstract contract ERC20SnapshotModuleInternal is SnapshotModuleBase, ERC20Upgrad */ uint256[] private _scheduledSnapshots; - /** - * @dev Initializes the contract - */ - function __ERC20Snapshot_init( - string memory name_, - string memory symbol_ - ) internal onlyInitializing { - __Context_init_unchained(); - __ERC20_init(name_, symbol_); - __SnapshotModuleBase_init_unchained(); - __ERC20Snapshot_init_unchained(); - } - function __ERC20Snapshot_init_unchained() internal onlyInitializing { // Nothing to do // _currentSnapshotTime & _currentSnapshotIndex are initialized to zero diff --git a/contracts/modules/internal/EnforcementModuleInternal.sol b/contracts/modules/internal/EnforcementModuleInternal.sol index 173db731..63230bfb 100644 --- a/contracts/modules/internal/EnforcementModuleInternal.sol +++ b/contracts/modules/internal/EnforcementModuleInternal.sol @@ -37,14 +37,6 @@ abstract contract EnforcementModuleInternal is mapping(address => bool) private _frozen; - /** - * @dev Initializes the contract - */ - function __Enforcement_init() internal onlyInitializing { - __Context_init_unchained(); - __Enforcement_init_unchained(); - } - function __Enforcement_init_unchained() internal onlyInitializing { // no variable to initialize } diff --git a/contracts/modules/internal/ValidationModuleInternal.sol b/contracts/modules/internal/ValidationModuleInternal.sol index 12c03c3e..f780fef5 100644 --- a/contracts/modules/internal/ValidationModuleInternal.sol +++ b/contracts/modules/internal/ValidationModuleInternal.sol @@ -5,7 +5,7 @@ pragma solidity ^0.8.20; import "../../../openzeppelin-contracts-upgradeable/contracts/utils/ContextUpgradeable.sol"; import "../../../openzeppelin-contracts-upgradeable/contracts/proxy/utils/Initializable.sol"; import "../../interfaces/draft-IERC1404/draft-IERC1404Wrapper.sol"; - +import "../../interfaces/draft-IERC1404/IRuleEngineCMTAT.sol"; /** * @dev Validation module. * @@ -18,22 +18,12 @@ abstract contract ValidationModuleInternal is /** * @dev Emitted when a rule engine is set. */ - event RuleEngine(IERC1404Wrapper indexed newRuleEngine); - - IERC1404Wrapper public ruleEngine; + event RuleEngine(IRuleEngineCMTAT indexed newRuleEngine); - /** - * @dev Initializes the contract with rule engine. - */ - function __Validation_init( - IERC1404Wrapper ruleEngine_ - ) internal onlyInitializing { - __Context_init_unchained(); - __Validation_init_unchained(ruleEngine_); - } + IRuleEngineCMTAT public ruleEngine; function __Validation_init_unchained( - IERC1404Wrapper ruleEngine_ + IRuleEngineCMTAT ruleEngine_ ) internal onlyInitializing { if (address(ruleEngine_) != address(0)) { ruleEngine = ruleEngine_; @@ -72,5 +62,9 @@ abstract contract ValidationModuleInternal is return ruleEngine.detectTransferRestriction(from, to, amount); } + function _operateOnTransfer(address from, address to, uint256 amount) virtual internal returns (bool) { + return ruleEngine.operateOnTransfer(from, to, amount); + } + uint256[50] private __gap; } diff --git a/contracts/modules/security/AuthorizationModule.sol b/contracts/modules/security/AuthorizationModule.sol index 43dbf630..2df2a049 100644 --- a/contracts/modules/security/AuthorizationModule.sol +++ b/contracts/modules/security/AuthorizationModule.sol @@ -24,23 +24,6 @@ abstract contract AuthorizationModule is AccessControlDefaultAdminRulesUpgradeab // SnapshotModule bytes32 public constant SNAPSHOOTER_ROLE = keccak256("SNAPSHOOTER_ROLE"); - - - function __AuthorizationModule_init( - address admin, - uint48 initialDelay - ) internal onlyInitializing { - /* OpenZeppelin */ - __Context_init_unchained(); - // AccessControlUpgradeable inherits from ERC165Upgradeable - __ERC165_init_unchained(); - __AccessControl_init_unchained(); - __AccessControlDefaultAdminRules_init_unchained(initialDelay, admin); - - /* own function */ - __AuthorizationModule_init_unchained(); - } - /** * @dev * diff --git a/contracts/modules/wrapper/controllers/ValidationModule.sol b/contracts/modules/wrapper/controllers/ValidationModule.sol index d7b0ed86..ef699c20 100644 --- a/contracts/modules/wrapper/controllers/ValidationModule.sol +++ b/contracts/modules/wrapper/controllers/ValidationModule.sol @@ -33,7 +33,7 @@ abstract contract ValidationModule is @param ruleEngine_ the call will be reverted if the new value of ruleEngine is the same as the current one */ function setRuleEngine( - IERC1404Wrapper ruleEngine_ + IRuleEngineCMTAT ruleEngine_ ) external onlyRole(DEFAULT_ADMIN_ROLE) { if (ruleEngine == ruleEngine_) revert Errors.CMTAT_ValidationModule_SameValue(); @@ -97,13 +97,24 @@ abstract contract ValidationModule is return TEXT_UNKNOWN_CODE; } } + + function validateTransferByModule( + address from, + address to, + uint256 /*amount*/ + ) internal view returns (bool) { + if (paused() || frozen(from) || frozen(to)) { + return false; + } + return true; + } function validateTransfer( address from, address to, uint256 amount ) public view override returns (bool) { - if (paused() || frozen(from) || frozen(to)) { + if (!validateTransferByModule(from, to, amount)) { return false; } if (address(ruleEngine) != address(0)) { @@ -112,5 +123,15 @@ abstract contract ValidationModule is return true; } + function _operateOnTransfer(address from, address to, uint256 amount) override internal returns (bool){ + if (!validateTransferByModule(from, to, amount)){ + return false; + } + if (address(ruleEngine) != address(0)) { + return ValidationModuleInternal._operateOnTransfer(from, to, amount); + } + return true; + } + uint256[50] private __gap; } diff --git a/contracts/test/CMTATSnapshot/CMTATSnapshotStandaloneTest.sol b/contracts/test/CMTATSnapshot/CMTATSnapshotStandaloneTest.sol index 654c6b5d..d372faac 100644 --- a/contracts/test/CMTATSnapshot/CMTATSnapshotStandaloneTest.sol +++ b/contracts/test/CMTATSnapshot/CMTATSnapshotStandaloneTest.sol @@ -27,7 +27,7 @@ contract CMTATSnapshotStandaloneTest is CMTAT_BASE_SnapshotTest { uint8 decimalsIrrevocable, string memory tokenId_, string memory terms_, - IERC1404Wrapper ruleEngine_, + IRuleEngineCMTAT ruleEngine_, string memory information_, uint256 flag_ ) MetaTxModule(forwarderIrrevocable) { diff --git a/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol b/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol index a86567b7..2b051940 100644 --- a/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol +++ b/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol @@ -53,7 +53,7 @@ abstract contract CMTAT_BASE_SnapshotTest is uint8 decimalsIrrevocable, string memory tokenId_, string memory terms_, - IERC1404Wrapper ruleEngine_, + IRuleEngineCMTAT ruleEngine_, string memory information_, uint256 flag_ ) public initializer { @@ -82,7 +82,7 @@ abstract contract CMTAT_BASE_SnapshotTest is uint8 decimalsIrrevocable, string memory tokenId_, string memory terms_, - IERC1404Wrapper ruleEngine_, + IRuleEngineCMTAT ruleEngine_, string memory information_, uint256 flag_ ) internal onlyInitializing { @@ -177,8 +177,7 @@ abstract contract CMTAT_BASE_SnapshotTest is address to, uint256 amount ) internal override(ERC20Upgradeable) { - - if (!ValidationModule.validateTransfer(from, to, amount)){ + if (!ValidationModule._operateOnTransfer(from, to, amount)){ revert Errors.CMTAT_InvalidTransfer(from, to, amount); } /*