From 0f9109f2524a90083f5e5ebd6b8fd261df8a3591 Mon Sep 17 00:00:00 2001 From: Ryan Sauge Date: Thu, 30 Nov 2023 16:47:22 +0100 Subject: [PATCH] Separate snapshotModuleInternal in two contracts --- contracts/modules/CMTAT_BASE.sol | 19 +- .../internal/ERC20SnapshotModuleInternal.sol | 427 ++---------------- .../internal/base/SnapshotModuleBase.sol | 410 +++++++++++++++++ .../CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol | 19 +- package-lock.json | 2 + 5 files changed, 459 insertions(+), 418 deletions(-) create mode 100644 contracts/modules/internal/base/SnapshotModuleBase.sol diff --git a/contracts/modules/CMTAT_BASE.sol b/contracts/modules/CMTAT_BASE.sol index bc0c284a..395a17f6 100644 --- a/contracts/modules/CMTAT_BASE.sol +++ b/contracts/modules/CMTAT_BASE.sol @@ -110,9 +110,11 @@ abstract contract CMTAT_BASE is __Enforcement_init_unchained(); /* SnapshotModule: - Add this call in case you add the SnapshotModule + Add these two calls in case you add the SnapshotModule + + __SnapshotModuleBase_init_unchained(); __ERC20Snapshot_init_unchained(); - */ + */ __Validation_init_unchained(ruleEngine_); /* Wrapper */ @@ -174,10 +176,7 @@ abstract contract CMTAT_BASE is /** * @dev - * SnapshotModule: - * - override SnapshotModuleInternal if you add the SnapshotModule - * e.g. override(ERC20SnapshotModuleInternal, ERC20Upgradeable) - * - remove the keyword view + * */ function _update( address from, @@ -187,13 +186,13 @@ abstract contract CMTAT_BASE is if (!ValidationModule.validateTransfer(from, to, amount)) { revert Errors.CMTAT_InvalidTransfer(from, to, amount); } - ERC20Upgradeable._update(from, to, amount); - // We call the SnapshotModule only if the transfer is valid /* SnapshotModule: - Add this call in case you add the SnapshotModule - ERC20SnapshotModuleInternal._update(from, to, amount); + Add this in case you add the SnapshotModule + We call the SnapshotModule only if the transfer is valid */ + // ERC20SnapshotModuleInternal._snapshotUpdate(from, to); + ERC20Upgradeable._update(from, to, amount); } /** diff --git a/contracts/modules/internal/ERC20SnapshotModuleInternal.sol b/contracts/modules/internal/ERC20SnapshotModuleInternal.sol index c868a15a..83c1a91d 100644 --- a/contracts/modules/internal/ERC20SnapshotModuleInternal.sol +++ b/contracts/modules/internal/ERC20SnapshotModuleInternal.sol @@ -8,9 +8,9 @@ import "../../../openzeppelin-contracts-upgradeable/contracts/token/ERC20/ERC20U import {Arrays} from '@openzeppelin/contracts/utils/Arrays.sol'; import "../../libraries/Errors.sol"; - +import "./base/SnapshotModuleBase.sol"; /** - * @dev Snapshot module. + * @dev Snapshot module internal. * * Useful to take a snapshot of token holder balance and total supply at a specific time * Inspired by Openzeppelin - ERC20Snapshot but use the time as Id instead of a counter. @@ -18,43 +18,9 @@ import "../../libraries/Errors.sol"; because overriding this function can break the contract. */ -abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { +abstract contract ERC20SnapshotModuleInternal is SnapshotModuleBase, ERC20Upgradeable { using Arrays for uint256[]; - /** - @notice Emitted when the snapshot with the specified oldTime was scheduled or rescheduled at the specified newTime. - */ - event SnapshotSchedule(uint256 indexed oldTime, uint256 indexed newTime); - - /** - @notice Emitted when the scheduled snapshot with the specified time was cancelled. - */ - event SnapshotUnschedule(uint256 indexed time); - - /** - @dev See {OpenZeppelin - ERC20Snapshot} - Snapshotted values have arrays of ids (time) and the value corresponding to that id. - ids is expected to be sorted in ascending order, and to contain no repeated elements - because we use findUpperBound in the function _valueAt - */ - struct Snapshots { - uint256[] ids; - uint256[] values; - } - - /** - @dev See {OpenZeppelin - ERC20Snapshot} - */ - mapping(address => Snapshots) private _accountBalanceSnapshots; - Snapshots private _totalSupplySnapshots; - - /** - @dev time instead of a counter for OpenZeppelin - */ - // Initialized to zero - uint256 private _currentSnapshotTime; - // Initialized to zero - uint256 private _currentSnapshotIndex; /** @dev @@ -72,6 +38,7 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { ) internal onlyInitializing { __Context_init_unchained(); __ERC20_init(name_, symbol_); + __SnapshotModuleBase_init_unchained(); __ERC20Snapshot_init_unchained(); } @@ -80,242 +47,15 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { // _currentSnapshotTime & _currentSnapshotIndex are initialized to zero } - /** - @dev schedule a snapshot at the specified time - You can only add a snapshot after the last previous - */ - function _scheduleSnapshot(uint256 time) internal { - // Check the time firstly to avoid an useless read of storage - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( - time, - block.timestamp - ); - } - - if (_scheduledSnapshots.length > 0) { - // We check the last snapshot on the list - uint256 nextSnapshotTime = _scheduledSnapshots[ - _scheduledSnapshots.length - 1 - ]; - if (time < nextSnapshotTime) { - revert Errors - .CMTAT_SnapshotModule_SnapshotTimestampBeforeLastSnapshot( - time, - nextSnapshotTime - ); - } else if (time == nextSnapshotTime) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); - } - } - _scheduledSnapshots.push(time); - emit SnapshotSchedule(0, time); - } - - /** - @dev schedule a snapshot at the specified time - */ - function _scheduleSnapshotNotOptimized(uint256 time) internal { - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( - time, - block.timestamp - ); - } - (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); - // Perfect match - if (isFound) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); - } - // if no upper bound match found, we push the snapshot at the end of the list - if (index == _scheduledSnapshots.length) { - _scheduledSnapshots.push(time); - } else { - _scheduledSnapshots.push( - _scheduledSnapshots[_scheduledSnapshots.length - 1] - ); - for (uint256 i = _scheduledSnapshots.length - 2; i > index; ) { - _scheduledSnapshots[i] = _scheduledSnapshots[i - 1]; - unchecked { - --i; - } - } - _scheduledSnapshots[index] = time; - } - emit SnapshotSchedule(0, time); - } - - /** - @dev reschedule a scheduled snapshot at the specified newTime - */ - function _rescheduleSnapshot(uint256 oldTime, uint256 newTime) internal { - // Check the time firstly to avoid an useless read of storage - if (oldTime <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); - } - if (newTime <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( - newTime, - block.timestamp - ); - } - if (_scheduledSnapshots.length == 0) { - revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); - } - (bool foundOld, uint256 index) = _findScheduledSnapshotIndex(oldTime); - if (!foundOld) { - revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); - } - if (index + 1 < _scheduledSnapshots.length) { - uint256 nextSnapshotTime = _scheduledSnapshots[index + 1]; - if (newTime > nextSnapshotTime) { - revert Errors - .CMTAT_SnapshotModule_SnapshotTimestampAfterNextSnapshot( - newTime, - nextSnapshotTime - ); - } else if (newTime == nextSnapshotTime) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); - } - } - if (index > 0) { - if (newTime <= _scheduledSnapshots[index - 1]) - revert Errors - .CMTAT_SnapshotModule_SnapshotTimestampBeforePreviousSnapshot( - newTime, - _scheduledSnapshots[index - 1] - ); - } - _scheduledSnapshots[index] = newTime; - - emit SnapshotSchedule(oldTime, newTime); - } - - /** - @dev unschedule the last scheduled snapshot - */ - function _unscheduleLastSnapshot(uint256 time) internal { - // Check the time firstly to avoid an useless read of storage - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); - } - if (_scheduledSnapshots.length == 0) { - revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); - } - // All snapshot time are unique, so we do not check the indice - if (time != _scheduledSnapshots[_scheduledSnapshots.length - 1]) { - revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); - } - _scheduledSnapshots.pop(); - emit SnapshotUnschedule(time); - } - - /** - @dev unschedule (remove) a scheduled snapshot in three steps: - - search the snapshot in the list - - If found, move all next snapshots one position to the left - - Reduce the array size by deleting the last snapshot - */ - function _unscheduleSnapshotNotOptimized(uint256 time) internal { - if (time <= block.timestamp) { - revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); - } - (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); - if (!isFound) { - revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); - } - for (uint256 i = index; i + 1 < _scheduledSnapshots.length; ) { - _scheduledSnapshots[i] = _scheduledSnapshots[i + 1]; - unchecked { - ++i; - } - } - _scheduledSnapshots.pop(); - } - - /** - @dev - Get the next scheduled snapshots - */ - function getNextSnapshots() public view returns (uint256[] memory) { - uint256[] memory nextScheduledSnapshot = new uint256[](0); - // no snapshot were planned - if (_scheduledSnapshots.length > 0) { - ( - uint256 timeLowerBound, - uint256 indexLowerBound - ) = _findScheduledMostRecentPastSnapshot(); - // All snapshots are situated in the futur - if ((timeLowerBound == 0) && (_currentSnapshotTime == 0)) { - return _scheduledSnapshots; - } else { - // There are snapshots situated in the futur - if (indexLowerBound + 1 != _scheduledSnapshots.length) { - // All next snapshots are located after the snapshot specified by indexLowerBound - uint256 arraySize = _scheduledSnapshots.length - - indexLowerBound - - 1; - nextScheduledSnapshot = new uint256[](arraySize); - for (uint256 i; i < arraySize; ) { - nextScheduledSnapshot[i] = _scheduledSnapshots[ - indexLowerBound + 1 + i - ]; - unchecked { - ++i; - } - } - } - } - } - return nextScheduledSnapshot; - } - - /** - @dev - Get all snapshots - */ - function getAllSnapshots() public view returns (uint256[] memory) { - return _scheduledSnapshots; - } - - /** - @notice Return the number of tokens owned by the given owner at the time when the snapshot with the given time was created. - @return value stored in the snapshot, or the actual balance if no snapshot - */ - function snapshotBalanceOf( - uint256 time, - address owner - ) public view returns (uint256) { - (bool snapshotted, uint256 value) = _valueAt( - time, - _accountBalanceSnapshots[owner] - ); - - return snapshotted ? value : balanceOf(owner); - } - - /** - @dev See {OpenZeppelin - ERC20Snapshot} - Retrieves the total supply at the specified time. - @return value stored in the snapshot, or the actual totalSupply if no snapshot - */ - function snapshotTotalSupply(uint256 time) public view returns (uint256) { - (bool snapshotted, uint256 value) = _valueAt( - time, - _totalSupplySnapshots - ); - return snapshotted ? value : totalSupply(); - } /** @dev Update balance and/or total supply snapshots before the values are modified. This is implemented in the _beforeTokenTransfer hook, which is executed for _mint, _burn, and _transfer operations. */ - function _update( + function _snapshotUpdate( address from, - address to, - uint256 amount - ) internal virtual override { + address to + ) internal virtual { _setCurrentSnapshot(); if (from != address(0)) { // for both burn and transfer @@ -332,42 +72,8 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { _updateAccountSnapshot(to); _updateTotalSupplySnapshot(); } - ERC20Upgradeable._update(from, to, amount); } - /** - @dev See {OpenZeppelin - ERC20Snapshot} - @param time where we want a snapshot - @param snapshots the struct where are stored the snapshots - @return snapshotExist true if a snapshot is found, false otherwise - value 0 if no snapshot, balance value if a snapshot exists - */ - function _valueAt( - uint256 time, - Snapshots storage snapshots - ) private view returns (bool snapshotExist, uint256 value) { - // When a valid snapshot is queried, there are three possibilities: - // a) The queried value was not modified after the snapshot was taken. Therefore, a snapshot entry was never - // created for this id, and all stored snapshot ids are smaller than the requested one. The value that corresponds - // to this id is the current one. - // b) The queried value was modified after the snapshot was taken. Therefore, there will be an entry with the - // requested id, and its value is the one to return. - // c) More snapshots were created after the requested one, and the queried value was later modified. There will be - // no entry for the requested id: the value that corresponds to it is that of the smallest snapshot id that is - // larger than the requested one. - // - // In summary, we need to find an element in an array, returning the index of the smallest value that is larger if - // it is not found, unless said value doesn't exist (e.g. when all values are smaller). Arrays.findUpperBound does - // exactly this. - - uint256 index = snapshots.ids.findUpperBound(time); - - if (index == snapshots.ids.length) { - return (false, 0); - } else { - return (true, snapshots.values[index]); - } - } /** @dev See {OpenZeppelin - ERC20Snapshot} @@ -383,111 +89,34 @@ abstract contract ERC20SnapshotModuleInternal is ERC20Upgradeable { _updateSnapshot(_totalSupplySnapshots, totalSupply()); } - /** - @dev - Inside a struct Snapshots: - - Update the array ids to the current Snapshot time if this one is greater than the snapshot times stored in ids. - - Update the value to the corresponding value. - */ - function _updateSnapshot( - Snapshots storage snapshots, - uint256 currentValue - ) private { - uint256 current = _currentSnapshotTime; - if (_lastSnapshot(snapshots.ids) < current) { - snapshots.ids.push(current); - snapshots.values.push(currentValue); - } - } /** - @dev - Set the currentSnapshotTime by retrieving the most recent snapshot - if a snapshot exists, clear all past scheduled snapshot - */ - function _setCurrentSnapshot() internal { - ( - uint256 scheduleSnapshotTime, - uint256 scheduleSnapshotIndex - ) = _findScheduledMostRecentPastSnapshot(); - if (scheduleSnapshotTime > 0) { - _currentSnapshotTime = scheduleSnapshotTime; - _currentSnapshotIndex = scheduleSnapshotIndex; - } - } - - /** - @return the last snapshot time inside a snapshot ids array + @notice Return the number of tokens owned by the given owner at the time when the snapshot with the given time was created. + @return value stored in the snapshot, or the actual balance if no snapshot */ - function _lastSnapshot( - uint256[] storage ids - ) private view returns (uint256) { - if (ids.length == 0) { - return 0; - } else { - return ids[ids.length - 1]; - } - } + function snapshotBalanceOf( + uint256 time, + address owner + ) public view returns (uint256) { + (bool snapshotted, uint256 value) = _valueAt( + time, + _accountBalanceSnapshots[owner] + ); - /** - @dev Find the snapshot index at the specified time - @return (true, index) if the snapshot exists, (false, 0) otherwise - */ - function _findScheduledSnapshotIndex( - uint256 time - ) private view returns (bool, uint256) { - uint256 indexFound = _scheduledSnapshots.findUpperBound(time); - uint256 _scheduledSnapshotsLength = _scheduledSnapshots.length; - // Exact match - if ( - indexFound != _scheduledSnapshotsLength && - _scheduledSnapshots[indexFound] == time - ) { - return (true, indexFound); - } - // Upper bound match - else if (indexFound != _scheduledSnapshotsLength) { - return (false, indexFound); - } - // no match - else { - return (false, _scheduledSnapshotsLength); - } + return snapshotted ? value : balanceOf(owner); } - /** - @dev find the most recent past snapshot - The complexity of this function is O(N) because we go through the whole list + /** + @dev See {OpenZeppelin - ERC20Snapshot} + Retrieves the total supply at the specified time. + @return value stored in the snapshot, or the actual totalSupply if no snapshot */ - function _findScheduledMostRecentPastSnapshot() - private - view - returns (uint256 time, uint256 index) - { - uint256 currentArraySize = _scheduledSnapshots.length; - // no snapshot or the current snapshot already points on the last snapshot - if ( - currentArraySize == 0 || - ((_currentSnapshotIndex + 1 == currentArraySize) && (time != 0)) - ) { - return (0, currentArraySize); - } - // mostRecent is initialized in the loop - uint256 mostRecent; - index = currentArraySize; - for (uint256 i = _currentSnapshotIndex; i < currentArraySize; ) { - if (_scheduledSnapshots[i] <= block.timestamp) { - mostRecent = _scheduledSnapshots[i]; - index = i; - } else { - // All snapshot are planned in the futur - break; - } - unchecked { - ++i; - } - } - return (mostRecent, index); + function snapshotTotalSupply(uint256 time) public view returns (uint256) { + (bool snapshotted, uint256 value) = _valueAt( + time, + _totalSupplySnapshots + ); + return snapshotted ? value : totalSupply(); } uint256[50] private __gap; diff --git a/contracts/modules/internal/base/SnapshotModuleBase.sol b/contracts/modules/internal/base/SnapshotModuleBase.sol new file mode 100644 index 00000000..935eba94 --- /dev/null +++ b/contracts/modules/internal/base/SnapshotModuleBase.sol @@ -0,0 +1,410 @@ +//SPDX-License-Identifier: MPL-2.0 + +pragma solidity ^0.8.20; + +import "../../../../openzeppelin-contracts-upgradeable/contracts/proxy/utils/Initializable.sol"; +import {Arrays} from '@openzeppelin/contracts/utils/Arrays.sol'; + +import "../../../libraries/Errors.sol"; + +/** + * @dev Base for the Snapshot module + * + * Useful to take a snapshot of token holder balance and total supply at a specific time + * Inspired by Openzeppelin - ERC20Snapshot but use the time as Id instead of a counter. + * Contrary to OpenZeppelin, the function _getCurrentSnapshotId is not available + because overriding this function can break the contract. + */ + +abstract contract SnapshotModuleBase is Initializable { + using Arrays for uint256[]; + + /** + @notice Emitted when the snapshot with the specified oldTime was scheduled or rescheduled at the specified newTime. + */ + event SnapshotSchedule(uint256 indexed oldTime, uint256 indexed newTime); + + /** + @notice Emitted when the scheduled snapshot with the specified time was cancelled. + */ + event SnapshotUnschedule(uint256 indexed time); + + /** + @dev See {OpenZeppelin - ERC20Snapshot} + Snapshotted values have arrays of ids (time) and the value corresponding to that id. + ids is expected to be sorted in ascending order, and to contain no repeated elements + because we use findUpperBound in the function _valueAt + */ + struct Snapshots { + uint256[] ids; + uint256[] values; + } + + /** + @dev See {OpenZeppelin - ERC20Snapshot} + */ + mapping(address => Snapshots) internal _accountBalanceSnapshots; + Snapshots internal _totalSupplySnapshots; + + /** + @dev time instead of a counter for OpenZeppelin + */ + // Initialized to zero + uint256 private _currentSnapshotTime; + // Initialized to zero + uint256 private _currentSnapshotIndex; + + /** + @dev + list of scheduled snapshot (time) + This list is sorted in ascending order + */ + uint256[] private _scheduledSnapshots; + + function __SnapshotModuleBase_init_unchained() internal onlyInitializing { + // Nothing to do + // _currentSnapshotTime & _currentSnapshotIndex are initialized to zero + } + + /** + @dev schedule a snapshot at the specified time + You can only add a snapshot after the last previous + */ + function _scheduleSnapshot(uint256 time) internal { + // Check the time firstly to avoid an useless read of storage + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( + time, + block.timestamp + ); + } + + if (_scheduledSnapshots.length > 0) { + // We check the last snapshot on the list + uint256 nextSnapshotTime = _scheduledSnapshots[ + _scheduledSnapshots.length - 1 + ]; + if (time < nextSnapshotTime) { + revert Errors + .CMTAT_SnapshotModule_SnapshotTimestampBeforeLastSnapshot( + time, + nextSnapshotTime + ); + } else if (time == nextSnapshotTime) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); + } + } + _scheduledSnapshots.push(time); + emit SnapshotSchedule(0, time); + } + + /** + @dev schedule a snapshot at the specified time + */ + function _scheduleSnapshotNotOptimized(uint256 time) internal { + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( + time, + block.timestamp + ); + } + (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); + // Perfect match + if (isFound) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); + } + // if no upper bound match found, we push the snapshot at the end of the list + if (index == _scheduledSnapshots.length) { + _scheduledSnapshots.push(time); + } else { + _scheduledSnapshots.push( + _scheduledSnapshots[_scheduledSnapshots.length - 1] + ); + for (uint256 i = _scheduledSnapshots.length - 2; i > index; ) { + _scheduledSnapshots[i] = _scheduledSnapshots[i - 1]; + unchecked { + --i; + } + } + _scheduledSnapshots[index] = time; + } + emit SnapshotSchedule(0, time); + } + + /** + @dev reschedule a scheduled snapshot at the specified newTime + */ + function _rescheduleSnapshot(uint256 oldTime, uint256 newTime) internal { + // Check the time firstly to avoid an useless read of storage + if (oldTime <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); + } + if (newTime <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotScheduledInThePast( + newTime, + block.timestamp + ); + } + if (_scheduledSnapshots.length == 0) { + revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); + } + (bool foundOld, uint256 index) = _findScheduledSnapshotIndex(oldTime); + if (!foundOld) { + revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); + } + if (index + 1 < _scheduledSnapshots.length) { + uint256 nextSnapshotTime = _scheduledSnapshots[index + 1]; + if (newTime > nextSnapshotTime) { + revert Errors + .CMTAT_SnapshotModule_SnapshotTimestampAfterNextSnapshot( + newTime, + nextSnapshotTime + ); + } else if (newTime == nextSnapshotTime) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyExists(); + } + } + if (index > 0) { + if (newTime <= _scheduledSnapshots[index - 1]) + revert Errors + .CMTAT_SnapshotModule_SnapshotTimestampBeforePreviousSnapshot( + newTime, + _scheduledSnapshots[index - 1] + ); + } + _scheduledSnapshots[index] = newTime; + + emit SnapshotSchedule(oldTime, newTime); + } + + /** + @dev unschedule the last scheduled snapshot + */ + function _unscheduleLastSnapshot(uint256 time) internal { + // Check the time firstly to avoid an useless read of storage + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); + } + if (_scheduledSnapshots.length == 0) { + revert Errors.CMTAT_SnapshotModule_NoSnapshotScheduled(); + } + // All snapshot time are unique, so we do not check the indice + if (time != _scheduledSnapshots[_scheduledSnapshots.length - 1]) { + revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); + } + _scheduledSnapshots.pop(); + emit SnapshotUnschedule(time); + } + + /** + @dev unschedule (remove) a scheduled snapshot in three steps: + - search the snapshot in the list + - If found, move all next snapshots one position to the left + - Reduce the array size by deleting the last snapshot + */ + function _unscheduleSnapshotNotOptimized(uint256 time) internal { + if (time <= block.timestamp) { + revert Errors.CMTAT_SnapshotModule_SnapshotAlreadyDone(); + } + (bool isFound, uint256 index) = _findScheduledSnapshotIndex(time); + if (!isFound) { + revert Errors.CMTAT_SnapshotModule_SnapshotNotFound(); + } + for (uint256 i = index; i + 1 < _scheduledSnapshots.length; ) { + _scheduledSnapshots[i] = _scheduledSnapshots[i + 1]; + unchecked { + ++i; + } + } + _scheduledSnapshots.pop(); + } + + /** + @dev + Get the next scheduled snapshots + */ + function getNextSnapshots() public view returns (uint256[] memory) { + uint256[] memory nextScheduledSnapshot = new uint256[](0); + // no snapshot were planned + if (_scheduledSnapshots.length > 0) { + ( + uint256 timeLowerBound, + uint256 indexLowerBound + ) = _findScheduledMostRecentPastSnapshot(); + // All snapshots are situated in the futur + if ((timeLowerBound == 0) && (_currentSnapshotTime == 0)) { + return _scheduledSnapshots; + } else { + // There are snapshots situated in the futur + if (indexLowerBound + 1 != _scheduledSnapshots.length) { + // All next snapshots are located after the snapshot specified by indexLowerBound + uint256 arraySize = _scheduledSnapshots.length - + indexLowerBound - + 1; + nextScheduledSnapshot = new uint256[](arraySize); + for (uint256 i; i < arraySize; ) { + nextScheduledSnapshot[i] = _scheduledSnapshots[ + indexLowerBound + 1 + i + ]; + unchecked { + ++i; + } + } + } + } + } + return nextScheduledSnapshot; + } + + /** + @dev + Get all snapshots + */ + function getAllSnapshots() public view returns (uint256[] memory) { + return _scheduledSnapshots; + } + + + /** + @dev See {OpenZeppelin - ERC20Snapshot} + @param time where we want a snapshot + @param snapshots the struct where are stored the snapshots + @return snapshotExist true if a snapshot is found, false otherwise + value 0 if no snapshot, balance value if a snapshot exists + */ + function _valueAt( + uint256 time, + Snapshots storage snapshots + ) internal view returns (bool snapshotExist, uint256 value) { + // When a valid snapshot is queried, there are three possibilities: + // a) The queried value was not modified after the snapshot was taken. Therefore, a snapshot entry was never + // created for this id, and all stored snapshot ids are smaller than the requested one. The value that corresponds + // to this id is the current one. + // b) The queried value was modified after the snapshot was taken. Therefore, there will be an entry with the + // requested id, and its value is the one to return. + // c) More snapshots were created after the requested one, and the queried value was later modified. There will be + // no entry for the requested id: the value that corresponds to it is that of the smallest snapshot id that is + // larger than the requested one. + // + // In summary, we need to find an element in an array, returning the index of the smallest value that is larger if + // it is not found, unless said value doesn't exist (e.g. when all values are smaller). Arrays.findUpperBound does + // exactly this. + + uint256 index = snapshots.ids.findUpperBound(time); + + if (index == snapshots.ids.length) { + return (false, 0); + } else { + return (true, snapshots.values[index]); + } + } + + /** + @dev + Inside a struct Snapshots: + - Update the array ids to the current Snapshot time if this one is greater than the snapshot times stored in ids. + - Update the value to the corresponding value. + */ + function _updateSnapshot( + Snapshots storage snapshots, + uint256 currentValue + ) internal { + uint256 current = _currentSnapshotTime; + if (_lastSnapshot(snapshots.ids) < current) { + snapshots.ids.push(current); + snapshots.values.push(currentValue); + } + } + + /** + @dev + Set the currentSnapshotTime by retrieving the most recent snapshot + if a snapshot exists, clear all past scheduled snapshot + */ + function _setCurrentSnapshot() internal { + ( + uint256 scheduleSnapshotTime, + uint256 scheduleSnapshotIndex + ) = _findScheduledMostRecentPastSnapshot(); + if (scheduleSnapshotTime > 0) { + _currentSnapshotTime = scheduleSnapshotTime; + _currentSnapshotIndex = scheduleSnapshotIndex; + } + } + + /** + @return the last snapshot time inside a snapshot ids array + */ + function _lastSnapshot( + uint256[] storage ids + ) private view returns (uint256) { + if (ids.length == 0) { + return 0; + } else { + return ids[ids.length - 1]; + } + } + + /** + @dev Find the snapshot index at the specified time + @return (true, index) if the snapshot exists, (false, 0) otherwise + */ + function _findScheduledSnapshotIndex( + uint256 time + ) private view returns (bool, uint256) { + uint256 indexFound = _scheduledSnapshots.findUpperBound(time); + uint256 _scheduledSnapshotsLength = _scheduledSnapshots.length; + // Exact match + if ( + indexFound != _scheduledSnapshotsLength && + _scheduledSnapshots[indexFound] == time + ) { + return (true, indexFound); + } + // Upper bound match + else if (indexFound != _scheduledSnapshotsLength) { + return (false, indexFound); + } + // no match + else { + return (false, _scheduledSnapshotsLength); + } + } + + /** + @dev find the most recent past snapshot + The complexity of this function is O(N) because we go through the whole list + */ + function _findScheduledMostRecentPastSnapshot() + private + view + returns (uint256 time, uint256 index) + { + uint256 currentArraySize = _scheduledSnapshots.length; + // no snapshot or the current snapshot already points on the last snapshot + if ( + currentArraySize == 0 || + ((_currentSnapshotIndex + 1 == currentArraySize) && (time != 0)) + ) { + return (0, currentArraySize); + } + // mostRecent is initialized in the loop + uint256 mostRecent; + index = currentArraySize; + for (uint256 i = _currentSnapshotIndex; i < currentArraySize; ) { + if (_scheduledSnapshots[i] <= block.timestamp) { + mostRecent = _scheduledSnapshots[i]; + index = i; + } else { + // All snapshot are planned in the futur + break; + } + unchecked { + ++i; + } + } + return (mostRecent, index); + } + + uint256[50] private __gap; +} diff --git a/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol b/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol index ab588cfc..a86567b7 100644 --- a/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol +++ b/contracts/test/CMTATSnapshot/CMTAT_BASE_SnapshotTest.sol @@ -100,8 +100,9 @@ abstract contract CMTAT_BASE_SnapshotTest is __Enforcement_init_unchained(); /* SnapshotModule: - Add this call in case you add the SnapshotModule + Add these two calls in case you add the SnapshotModule */ + __SnapshotModuleBase_init_unchained(); __ERC20Snapshot_init_unchained(); __Validation_init_unchained(ruleEngine_); @@ -175,18 +176,18 @@ abstract contract CMTAT_BASE_SnapshotTest is address from, address to, uint256 amount - ) internal override(ERC20SnapshotModuleInternal, ERC20Upgradeable) { - // We call the SnapshotModule only if the transfer is valid - if (!ValidationModule.validateTransfer(from, to, amount)) + ) internal override(ERC20Upgradeable) { + + if (!ValidationModule.validateTransfer(from, to, amount)){ revert Errors.CMTAT_InvalidTransfer(from, to, amount); - /* - We do not call ERC20Upgradeable._update(from, to, amount) here because it is called inside the SnapshotModule - */ + } /* SnapshotModule: - Add this call in case you add the SnapshotModule + Add this in case you add the SnapshotModule + We call the SnapshotModule only if the transfer is valid */ - ERC20SnapshotModuleInternal._update(from, to, amount); + ERC20SnapshotModuleInternal._snapshotUpdate(from, to); + ERC20Upgradeable._update(from, to, amount); } /** diff --git a/package-lock.json b/package-lock.json index a0c0b9de..ff374341 100644 --- a/package-lock.json +++ b/package-lock.json @@ -13916,6 +13916,7 @@ "version": "4.0.5", "resolved": "https://registry.npmjs.org/bufferutil/-/bufferutil-4.0.5.tgz", "integrity": "sha512-HTm14iMQKK2FjFLRTM5lAVcyaUzOnqbPtesFIvREgXpJHdQm8bWS+GkQgIkfaBYRHuCnea7w8UVNfwiAQhlr9A==", + "hasInstallScript": true, "optional": true, "dependencies": { "node-gyp-build": "^4.3.0" @@ -14245,6 +14246,7 @@ "version": "5.0.7", "resolved": "https://registry.npmjs.org/utf-8-validate/-/utf-8-validate-5.0.7.tgz", "integrity": "sha512-vLt1O5Pp+flcArHGIyKEQq883nBt8nN8tVBcoL0qUXj2XT1n7p70yGIq2VK98I5FdZ1YHc0wk/koOnHjnXWk1Q==", + "hasInstallScript": true, "optional": true, "dependencies": { "node-gyp-build": "^4.3.0"