diff --git a/.forge-snapshots/NonfungiblePositionManager#collect.snap b/.forge-snapshots/NonfungiblePositionManager#collect.snap index dc983cf..d257b8c 100644 --- a/.forge-snapshots/NonfungiblePositionManager#collect.snap +++ b/.forge-snapshots/NonfungiblePositionManager#collect.snap @@ -1 +1 @@ -196939 \ No newline at end of file +199462 \ No newline at end of file diff --git a/src/pool-cl/NonfungiblePositionManager.sol b/src/pool-cl/NonfungiblePositionManager.sol index bc00fd1..324dc0c 100644 --- a/src/pool-cl/NonfungiblePositionManager.sol +++ b/src/pool-cl/NonfungiblePositionManager.sol @@ -436,6 +436,15 @@ contract NonfungiblePositionManager is nftPosition.tokensOwed0 = tokensOwed0 - amount0Collect; nftPosition.tokensOwed1 = tokensOwed1 - amount1Collect; + /// @dev due to rounding down calculation in FullMath, some wei might be loss if the fee is too small + /// if that happen we need to ignore the loss part and take the rest of the fee otherwise it will revert whole tx + uint128 actualFee0Left = uint128(vault.balanceOf(address(this), poolKey.currency0)); + uint128 actualFee1Left = uint128(vault.balanceOf(address(this), poolKey.currency1)); + (amount0Collect, amount1Collect) = ( + actualFee0Left > amount0Collect ? amount0Collect : actualFee0Left, + actualFee1Left > amount1Collect ? amount1Collect : actualFee1Left + ); + // cash out from vault burnAndTake(poolKey.currency0, params.recipient, amount0Collect); burnAndTake(poolKey.currency1, params.recipient, amount1Collect); diff --git a/test/pool-cl/CLSwapRouterInvariant.t.sol b/test/pool-cl/CLSwapRouterInvariant.t.sol index e563a29..02f6895 100644 --- a/test/pool-cl/CLSwapRouterInvariant.t.sol +++ b/test/pool-cl/CLSwapRouterInvariant.t.sol @@ -47,6 +47,9 @@ contract CLSwapRouterHandler is Test { uint256 public token1Minted; uint256 public nativeTokenMinted; + uint256 public token0FeeAccrued; + uint256 public token1FeeAccrued; + constructor() { WETH weth = new WETH(); vault = new Vault(); @@ -110,6 +113,9 @@ contract CLSwapRouterHandler is Test { PoolKey memory pk = isNativePool ? nativePoolKey : poolKey; // if native pool, have to ensure call method with value uint256 value = isNativePool ? amtIn : 0; + + vm.recordLogs(); + vm.prank(alice); router.exactInputSingle{value: value}( ICLSwapRouterBase.V4CLExactInputSingleParams({ @@ -123,6 +129,8 @@ contract CLSwapRouterHandler is Test { }), block.timestamp + 100 ); + + _accumulateFee(); } function exactSwapInput(uint128 amtIn, bool isNativePool) public { @@ -138,7 +146,6 @@ contract CLSwapRouterHandler is Test { // Step 3: swap PoolKey memory pk = isNativePool ? nativePoolKey : poolKey; - vm.prank(alice); ISwapRouterBase.PathKey[] memory path = new ISwapRouterBase.PathKey[](1); path[0] = ISwapRouterBase.PathKey({ intermediateCurrency: Currency.wrap(address(token1)), @@ -149,8 +156,11 @@ contract CLSwapRouterHandler is Test { parameters: pk.parameters }); + vm.recordLogs(); + // if native pool, have to ensure call method with value uint256 value = isNativePool ? amtIn : 0; + vm.prank(alice); router.exactInput{value: value}( ICLSwapRouterBase.V4CLExactInputParams({ currencyIn: isNativePool ? CurrencyLibrary.NATIVE : currency0, @@ -161,6 +171,8 @@ contract CLSwapRouterHandler is Test { }), block.timestamp + 100 ); + + _accumulateFee(); } function exactSwapOutputSingle(uint128 amtIn, bool isNativePool) public { @@ -174,6 +186,8 @@ contract CLSwapRouterHandler is Test { isNativePool ? nativeTokenMinted += amtIn : token0Minted += amtIn; // Step 3: swap + vm.recordLogs(); + PoolKey memory pk = isNativePool ? nativePoolKey : poolKey; // if native pool, have to ensure call method with value uint256 value = isNativePool ? amtIn : 0; @@ -190,6 +204,8 @@ contract CLSwapRouterHandler is Test { }), block.timestamp + 100 ); + + _accumulateFee(); } function exactSwapOutput(uint128 amtIn, bool isNativePool) public { @@ -203,8 +219,9 @@ contract CLSwapRouterHandler is Test { isNativePool ? nativeTokenMinted += amtIn : token0Minted += amtIn; // Step 3: swap + vm.recordLogs(); + PoolKey memory pk = isNativePool ? nativePoolKey : poolKey; - vm.prank(alice); ISwapRouterBase.PathKey[] memory path = new ISwapRouterBase.PathKey[](1); path[0] = ISwapRouterBase.PathKey({ intermediateCurrency: isNativePool ? CurrencyLibrary.NATIVE : currency0, @@ -217,6 +234,7 @@ contract CLSwapRouterHandler is Test { // if native pool, have to ensure call method with value uint256 value = isNativePool ? amtIn : 0; + vm.prank(alice); router.exactOutput{value: value}( ICLSwapRouterBase.V4CLExactOutputParams({ currencyOut: currency1, @@ -227,6 +245,8 @@ contract CLSwapRouterHandler is Test { }), block.timestamp + 60 ); + + _accumulateFee(); } function _mint(uint128 amt, bool isNativePool) private { @@ -239,8 +259,6 @@ contract CLSwapRouterHandler is Test { token1Minted += amt; token1.mint(alice, amt); - vm.startPrank(alice); - PoolKey memory pk = isNativePool ? nativePoolKey : poolKey; INonfungiblePositionManager.MintParams memory mintParams = INonfungiblePositionManager.MintParams({ poolKey: pk, @@ -250,10 +268,11 @@ contract CLSwapRouterHandler is Test { amount1Desired: amt, amount0Min: 0, amount1Min: 0, - recipient: address(this), + recipient: alice, deadline: block.timestamp }); + vm.startPrank(alice); if (isNativePool) { positionManager.mint{value: amt}(mintParams); } else { @@ -261,6 +280,29 @@ contract CLSwapRouterHandler is Test { } vm.stopPrank(); } + + function _accumulateFee() private { + // event Swap( + // PoolId indexed id, + // address indexed sender, + // int128 amount0, + // int128 amount1, + // uint160 sqrtPriceX96, + // uint128 liquidity, + // int24 tick, + // uint24 fee, + // uint256 protocolFee + // ); + Vm.Log[] memory entries = vm.getRecordedLogs(); + (int128 amount0, int128 amount1,,,,,) = + abi.decode(entries[0].data, (int128, int128, uint160, uint128, int24, uint24, uint256)); + + if (amount0 < 0) { + token1FeeAccrued += uint128(amount1) * 3000 / 1e6; + } else { + token0FeeAccrued += uint128(amount0) * 3000 / 1e6; + } + } } contract CLSwapRouterInvariant is Test { @@ -302,4 +344,32 @@ contract CLSwapRouterInvariant is Test { uint256 routerBalance = address(_handler.router()).balance; assertEq(nativeTokenInVault + nativeTokenWithAlice + routerBalance, _handler.nativeTokenMinted()); } + + function invariant_AllSwapFeeGoesToLP() public { + INonfungiblePositionManager positionManager = INonfungiblePositionManager(_handler.positionManager()); + + uint256 positionTokenAmt = positionManager.balanceOf(_handler.alice()); + + uint256 realFee0Accrued = 0; + uint256 realFee1Accrued = 0; + for (uint256 i = 0; i < positionTokenAmt; i++) { + uint256 tokenId = positionManager.tokenOfOwnerByIndex(_handler.alice(), i); + vm.prank(_handler.alice()); + positionManager.approve(address(this), tokenId); + (uint256 _realFee0Accrued, uint256 _realFee1Accrued) = positionManager.collect( + INonfungiblePositionManager.CollectParams({ + tokenId: tokenId, + recipient: _handler.alice(), + amount0Max: type(uint128).max, + amount1Max: type(uint128).max + }) + ); + realFee0Accrued += _realFee0Accrued; + realFee1Accrued += _realFee1Accrued; + } + + /// @dev due to the precision loss, fee accrued might not be exactly the same + assertLe(_handler.token0FeeAccrued() - realFee0Accrued, 10); + assertLe(_handler.token1FeeAccrued() - realFee1Accrued, 10); + } }