Skip to content

Commit

Permalink
Avoid double invariant computation of invariant on Phantom Pools (bal…
Browse files Browse the repository at this point in the history
  • Loading branch information
nventuro authored Sep 29, 2021
1 parent 1fe8822 commit ceba378
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 38 deletions.
62 changes: 38 additions & 24 deletions pkg/pool-stable-phantom/contracts/StablePhantomPool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -205,26 +205,33 @@ contract StablePhantomPool is StablePool {
// the equivalent BPT amount that accounts for that growth and finally extract the percentage that
// corresponds to protocol fees.

uint256 newIndexIn = _skipBptIndex(indexIn);
uint256 newIndexOut = _skipBptIndex(indexOut);

// Note however that we can skip all of this if there are no protocol fees to be paid!
if (protocolSwapFeePercentage == 0) {
amountOut = super._onSwapGivenIn(request, balances, _skipBptIndex(indexIn), _skipBptIndex(indexOut));
} else {
(uint256 amp, ) = _getAmplificationParameter();

uint256 previousInvariant = StableMath._calculateInvariant(amp, balances, true);
// Since the original StablePool._onSwapGivenIn implementation already computes the invariant, we fully
// replace it and reimplement it here to take advtange of that.

(uint256 currentAmp, ) = _getAmplificationParameter();
uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true);

amountOut = StableMath._calcOutGivenIn(
currentAmp,
balances,
_skipBptIndex(indexIn),
_skipBptIndex(indexOut),
request.amount,
invariant
);

amountOut = super._onSwapGivenIn(request, balances, _skipBptIndex(indexIn), _skipBptIndex(indexOut));
if (protocolSwapFeePercentage > 0) {
// We could've stored these indices in stack variables, but that causes stack-too-deep issues.
uint256 newIndexIn = _skipBptIndex(indexIn);
uint256 newIndexOut = _skipBptIndex(indexOut);

uint256 amountInWithFee = _addSwapFeeAmount(request.amount);
balances[newIndexIn] = balances[newIndexIn].add(amountInWithFee);
balances[newIndexOut] = balances[newIndexOut].sub(amountOut);

_trackDueProtocolFeeByInvariantIncrement(
previousInvariant,
amp,
invariant,
currentAmp,
balances,
virtualSupply,
protocolSwapFeePercentage
Expand Down Expand Up @@ -264,26 +271,33 @@ contract StablePhantomPool is StablePool {
// the equivalent BPT amount that accounts for that growth and finally extract the percentage that
// corresponds to protocol fees.

uint256 newIndexIn = _skipBptIndex(indexIn);
uint256 newIndexOut = _skipBptIndex(indexOut);
// Since the original StablePool._onSwapGivenOut implementation already computes the invariant, we fully
// replace it and reimplement it here to take advtange of that.

// Note however that we can skip all of this if there are no protocol fees to be paid!
if (protocolSwapFeePercentage == 0) {
amountIn = super._onSwapGivenOut(request, balances, newIndexIn, newIndexOut);
} else {
(uint256 amp, ) = _getAmplificationParameter();
(uint256 currentAmp, ) = _getAmplificationParameter();
uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true);

uint256 previousInvariant = StableMath._calculateInvariant(amp, balances, true);
amountIn = StableMath._calcInGivenOut(
currentAmp,
balances,
_skipBptIndex(indexIn),
_skipBptIndex(indexOut),
request.amount,
invariant
);

amountIn = super._onSwapGivenOut(request, balances, newIndexIn, newIndexOut);
if (protocolSwapFeePercentage > 0) {
// We could've stored these indices in stack variables, but that causes stack-too-deep issues.
uint256 newIndexIn = _skipBptIndex(indexIn);
uint256 newIndexOut = _skipBptIndex(indexOut);

uint256 amountInWithFee = _addSwapFeeAmount(amountIn);
balances[newIndexIn] = balances[newIndexIn].add(amountInWithFee);
balances[newIndexOut] = balances[newIndexOut].sub(request.amount);

_trackDueProtocolFeeByInvariantIncrement(
previousInvariant,
amp,
invariant,
currentAmp,
balances,
virtualSupply,
protocolSwapFeePercentage
Expand Down
16 changes: 6 additions & 10 deletions pkg/pool-stable/contracts/StableMath.sol
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ library StableMath {

// Computes how many tokens can be taken out of a pool if `tokenAmountIn` are sent, given the current balances.
// The amplification parameter equals: A n^(n-1)
// The invariant should be rounded up.
function _calcOutGivenIn(
uint256 amplificationParameter,
uint256[] memory balances,
uint256 tokenIndexIn,
uint256 tokenIndexOut,
uint256 tokenAmountIn
uint256 tokenAmountIn,
uint256 invariant
) internal pure returns (uint256) {
/**************************************************************************************************************
// outGivenIn token x for y - polynomial equation to solve //
Expand All @@ -132,10 +134,6 @@ library StableMath {
**************************************************************************************************************/

// Amount out, so we round down overall.

// Given that we need to have a greater final balance out, the invariant needs to be rounded up
uint256 invariant = _calculateInvariant(amplificationParameter, balances, true);

balances[tokenIndexIn] = balances[tokenIndexIn].add(tokenAmountIn);

uint256 finalBalanceOut = _getTokenBalanceGivenInvariantAndAllOtherBalances(
Expand All @@ -155,12 +153,14 @@ library StableMath {
// Computes how many tokens must be sent to a pool if `tokenAmountOut` are sent given the
// current balances, using the Newton-Raphson approximation.
// The amplification parameter equals: A n^(n-1)
// The invariant should be rounded up.
function _calcInGivenOut(
uint256 amplificationParameter,
uint256[] memory balances,
uint256 tokenIndexIn,
uint256 tokenIndexOut,
uint256 tokenAmountOut
uint256 tokenAmountOut,
uint256 invariant
) internal pure returns (uint256) {
/**************************************************************************************************************
// inGivenOut token x for y - polynomial equation to solve //
Expand All @@ -175,10 +175,6 @@ library StableMath {
**************************************************************************************************************/

// Amount in, so we round up overall.

// Given that we need to have a greater final balance in, the invariant needs to be rounded up
uint256 invariant = _calculateInvariant(amplificationParameter, balances, true);

balances[tokenIndexOut] = balances[tokenIndexOut].sub(tokenAmountOut);

uint256 finalBalanceIn = _getTokenBalanceGivenInvariantAndAllOtherBalances(
Expand Down
24 changes: 22 additions & 2 deletions pkg/pool-stable/contracts/StablePool.sol
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,17 @@ contract StablePool is BaseGeneralPool, BaseMinimalSwapInfoPool, IRateProvider {
uint256 indexOut
) internal virtual override whenNotPaused returns (uint256) {
(uint256 currentAmp, ) = _getAmplificationParameter();
uint256 amountOut = StableMath._calcOutGivenIn(currentAmp, balances, indexIn, indexOut, swapRequest.amount);

uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true);
uint256 amountOut = StableMath._calcOutGivenIn(
currentAmp,
balances,
indexIn,
indexOut,
swapRequest.amount,
invariant
);

return amountOut;
}

Expand All @@ -155,7 +165,17 @@ contract StablePool is BaseGeneralPool, BaseMinimalSwapInfoPool, IRateProvider {
uint256 indexOut
) internal virtual override whenNotPaused returns (uint256) {
(uint256 currentAmp, ) = _getAmplificationParameter();
uint256 amountIn = StableMath._calcInGivenOut(currentAmp, balances, indexIn, indexOut, swapRequest.amount);

uint256 invariant = StableMath._calculateInvariant(currentAmp, balances, true);
uint256 amountIn = StableMath._calcInGivenOut(
currentAmp,
balances,
indexIn,
indexOut,
swapRequest.amount,
invariant
);

return amountIn;
}

Expand Down
20 changes: 18 additions & 2 deletions pkg/pool-stable/contracts/test/MockStableMath.sol
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ contract MockStableMath {
uint256 tokenIndexOut,
uint256 tokenAmountIn
) external pure returns (uint256) {
return StableMath._calcOutGivenIn(amp, balances, tokenIndexIn, tokenIndexOut, tokenAmountIn);
return
StableMath._calcOutGivenIn(
amp,
balances,
tokenIndexIn,
tokenIndexOut,
tokenAmountIn,
StableMath._calculateInvariant(amp, balances, true)
);
}

function inGivenOut(
Expand All @@ -42,7 +50,15 @@ contract MockStableMath {
uint256 tokenIndexOut,
uint256 tokenAmountOut
) external pure returns (uint256) {
return StableMath._calcInGivenOut(amp, balances, tokenIndexIn, tokenIndexOut, tokenAmountOut);
return
StableMath._calcInGivenOut(
amp,
balances,
tokenIndexIn,
tokenIndexOut,
tokenAmountOut,
StableMath._calculateInvariant(amp, balances, true)
);
}

function exactTokensInForBPTOut(
Expand Down

0 comments on commit ceba378

Please sign in to comment.