From ceba3781307613475b864318ac3f016ca86d7d91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=A1s=20Venturo?= Date: Wed, 29 Sep 2021 20:05:07 -0300 Subject: [PATCH] Avoid double invariant computation of invariant on Phantom Pools (#872) --- .../contracts/StablePhantomPool.sol | 62 ++++++++++++------- pkg/pool-stable/contracts/StableMath.sol | 16 ++--- pkg/pool-stable/contracts/StablePool.sol | 24 ++++++- .../contracts/test/MockStableMath.sol | 20 +++++- 4 files changed, 84 insertions(+), 38 deletions(-) diff --git a/pkg/pool-stable-phantom/contracts/StablePhantomPool.sol b/pkg/pool-stable-phantom/contracts/StablePhantomPool.sol index 0d999bc3f5..e9fba55a04 100644 --- a/pkg/pool-stable-phantom/contracts/StablePhantomPool.sol +++ b/pkg/pool-stable-phantom/contracts/StablePhantomPool.sol @@ -205,26 +205,33 @@ contract StablePhantomPool is StablePool { // the equivalent BPT amount that accounts for that growth and finally extract the percentage that // corresponds to protocol fees. - uint256 newIndexIn = _skipBptIndex(indexIn); - uint256 newIndexOut = _skipBptIndex(indexOut); - - // Note however that we can skip all of this if there are no protocol fees to be paid! - if (protocolSwapFeePercentage == 0) { - amountOut = super._onSwapGivenIn(request, balances, _skipBptIndex(indexIn), _skipBptIndex(indexOut)); - } else { - (uint256 amp, ) = _getAmplificationParameter(); - - uint256 previousInvariant = StableMath._calculateInvariant(amp, balances, true); + // Since the original StablePool._onSwapGivenIn implementation already computes the invariant, we fully + // replace it and reimplement it here to take advtange of that. + + (uint256 currentAmp, ) = _getAmplificationParameter(); + uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true); + + amountOut = StableMath._calcOutGivenIn( + currentAmp, + balances, + _skipBptIndex(indexIn), + _skipBptIndex(indexOut), + request.amount, + invariant + ); - amountOut = super._onSwapGivenIn(request, balances, _skipBptIndex(indexIn), _skipBptIndex(indexOut)); + if (protocolSwapFeePercentage > 0) { + // We could've stored these indices in stack variables, but that causes stack-too-deep issues. + uint256 newIndexIn = _skipBptIndex(indexIn); + uint256 newIndexOut = _skipBptIndex(indexOut); uint256 amountInWithFee = _addSwapFeeAmount(request.amount); balances[newIndexIn] = balances[newIndexIn].add(amountInWithFee); balances[newIndexOut] = balances[newIndexOut].sub(amountOut); _trackDueProtocolFeeByInvariantIncrement( - previousInvariant, - amp, + invariant, + currentAmp, balances, virtualSupply, protocolSwapFeePercentage @@ -264,26 +271,33 @@ contract StablePhantomPool is StablePool { // the equivalent BPT amount that accounts for that growth and finally extract the percentage that // corresponds to protocol fees. - uint256 newIndexIn = _skipBptIndex(indexIn); - uint256 newIndexOut = _skipBptIndex(indexOut); + // Since the original StablePool._onSwapGivenOut implementation already computes the invariant, we fully + // replace it and reimplement it here to take advtange of that. - // Note however that we can skip all of this if there are no protocol fees to be paid! - if (protocolSwapFeePercentage == 0) { - amountIn = super._onSwapGivenOut(request, balances, newIndexIn, newIndexOut); - } else { - (uint256 amp, ) = _getAmplificationParameter(); + (uint256 currentAmp, ) = _getAmplificationParameter(); + uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true); - uint256 previousInvariant = StableMath._calculateInvariant(amp, balances, true); + amountIn = StableMath._calcInGivenOut( + currentAmp, + balances, + _skipBptIndex(indexIn), + _skipBptIndex(indexOut), + request.amount, + invariant + ); - amountIn = super._onSwapGivenOut(request, balances, newIndexIn, newIndexOut); + if (protocolSwapFeePercentage > 0) { + // We could've stored these indices in stack variables, but that causes stack-too-deep issues. + uint256 newIndexIn = _skipBptIndex(indexIn); + uint256 newIndexOut = _skipBptIndex(indexOut); uint256 amountInWithFee = _addSwapFeeAmount(amountIn); balances[newIndexIn] = balances[newIndexIn].add(amountInWithFee); balances[newIndexOut] = balances[newIndexOut].sub(request.amount); _trackDueProtocolFeeByInvariantIncrement( - previousInvariant, - amp, + invariant, + currentAmp, balances, virtualSupply, protocolSwapFeePercentage diff --git a/pkg/pool-stable/contracts/StableMath.sol b/pkg/pool-stable/contracts/StableMath.sol index 4123c73a5c..a42b857db6 100644 --- a/pkg/pool-stable/contracts/StableMath.sol +++ b/pkg/pool-stable/contracts/StableMath.sol @@ -112,12 +112,14 @@ library StableMath { // Computes how many tokens can be taken out of a pool if `tokenAmountIn` are sent, given the current balances. // The amplification parameter equals: A n^(n-1) + // The invariant should be rounded up. function _calcOutGivenIn( uint256 amplificationParameter, uint256[] memory balances, uint256 tokenIndexIn, uint256 tokenIndexOut, - uint256 tokenAmountIn + uint256 tokenAmountIn, + uint256 invariant ) internal pure returns (uint256) { /************************************************************************************************************** // outGivenIn token x for y - polynomial equation to solve // @@ -132,10 +134,6 @@ library StableMath { **************************************************************************************************************/ // Amount out, so we round down overall. - - // Given that we need to have a greater final balance out, the invariant needs to be rounded up - uint256 invariant = _calculateInvariant(amplificationParameter, balances, true); - balances[tokenIndexIn] = balances[tokenIndexIn].add(tokenAmountIn); uint256 finalBalanceOut = _getTokenBalanceGivenInvariantAndAllOtherBalances( @@ -155,12 +153,14 @@ library StableMath { // Computes how many tokens must be sent to a pool if `tokenAmountOut` are sent given the // current balances, using the Newton-Raphson approximation. // The amplification parameter equals: A n^(n-1) + // The invariant should be rounded up. function _calcInGivenOut( uint256 amplificationParameter, uint256[] memory balances, uint256 tokenIndexIn, uint256 tokenIndexOut, - uint256 tokenAmountOut + uint256 tokenAmountOut, + uint256 invariant ) internal pure returns (uint256) { /************************************************************************************************************** // inGivenOut token x for y - polynomial equation to solve // @@ -175,10 +175,6 @@ library StableMath { **************************************************************************************************************/ // Amount in, so we round up overall. - - // Given that we need to have a greater final balance in, the invariant needs to be rounded up - uint256 invariant = _calculateInvariant(amplificationParameter, balances, true); - balances[tokenIndexOut] = balances[tokenIndexOut].sub(tokenAmountOut); uint256 finalBalanceIn = _getTokenBalanceGivenInvariantAndAllOtherBalances( diff --git a/pkg/pool-stable/contracts/StablePool.sol b/pkg/pool-stable/contracts/StablePool.sol index b38eee5ef1..21b50f91b4 100644 --- a/pkg/pool-stable/contracts/StablePool.sol +++ b/pkg/pool-stable/contracts/StablePool.sol @@ -144,7 +144,17 @@ contract StablePool is BaseGeneralPool, BaseMinimalSwapInfoPool, IRateProvider { uint256 indexOut ) internal virtual override whenNotPaused returns (uint256) { (uint256 currentAmp, ) = _getAmplificationParameter(); - uint256 amountOut = StableMath._calcOutGivenIn(currentAmp, balances, indexIn, indexOut, swapRequest.amount); + + uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true); + uint256 amountOut = StableMath._calcOutGivenIn( + currentAmp, + balances, + indexIn, + indexOut, + swapRequest.amount, + invariant + ); + return amountOut; } @@ -155,7 +165,17 @@ contract StablePool is BaseGeneralPool, BaseMinimalSwapInfoPool, IRateProvider { uint256 indexOut ) internal virtual override whenNotPaused returns (uint256) { (uint256 currentAmp, ) = _getAmplificationParameter(); - uint256 amountIn = StableMath._calcInGivenOut(currentAmp, balances, indexIn, indexOut, swapRequest.amount); + + uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true); + uint256 amountIn = StableMath._calcInGivenOut( + currentAmp, + balances, + indexIn, + indexOut, + swapRequest.amount, + invariant + ); + return amountIn; } diff --git a/pkg/pool-stable/contracts/test/MockStableMath.sol b/pkg/pool-stable/contracts/test/MockStableMath.sol index 7990004461..5fce2bf83e 100644 --- a/pkg/pool-stable/contracts/test/MockStableMath.sol +++ b/pkg/pool-stable/contracts/test/MockStableMath.sol @@ -32,7 +32,15 @@ contract MockStableMath { uint256 tokenIndexOut, uint256 tokenAmountIn ) external pure returns (uint256) { - return StableMath._calcOutGivenIn(amp, balances, tokenIndexIn, tokenIndexOut, tokenAmountIn); + return + StableMath._calcOutGivenIn( + amp, + balances, + tokenIndexIn, + tokenIndexOut, + tokenAmountIn, + StableMath._calculateInvariant(amp, balances, true) + ); } function inGivenOut( @@ -42,7 +50,15 @@ contract MockStableMath { uint256 tokenIndexOut, uint256 tokenAmountOut ) external pure returns (uint256) { - return StableMath._calcInGivenOut(amp, balances, tokenIndexIn, tokenIndexOut, tokenAmountOut); + return + StableMath._calcInGivenOut( + amp, + balances, + tokenIndexIn, + tokenIndexOut, + tokenAmountOut, + StableMath._calculateInvariant(amp, balances, true) + ); } function exactTokensInForBPTOut(