Skip to content

Commit

Permalink
replace big.Int and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
it4rb committed Mar 15, 2024
1 parent d376be1 commit 2935736
Show file tree
Hide file tree
Showing 28 changed files with 1,120 additions and 268 deletions.
3 changes: 3 additions & 0 deletions constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
)

const PoolInitCodeHash = "0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54"
Expand Down Expand Up @@ -51,5 +52,7 @@ var (
Q96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
Q192 = new(big.Int).Exp(Q96, big.NewInt(2), nil)

Q96U256 = new(uint256.Int).Exp(uint256.NewInt(2), uint256.NewInt(96))

PercentZero = entities.NewFraction(big.NewInt(0), big.NewInt(1))
)
135 changes: 86 additions & 49 deletions entities/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"errors"
"math/big"

"github.com/KyberNetwork/int256"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
)

var (
Expand All @@ -19,22 +21,22 @@ var (
)

type StepComputations struct {
sqrtPriceStartX96 *big.Int
sqrtPriceStartX96 *utils.Uint160
tickNext int
initialized bool
sqrtPriceNextX96 *big.Int
amountIn *big.Int
amountOut *big.Int
feeAmount *big.Int
sqrtPriceNextX96 *utils.Uint160
amountIn *utils.Uint256
amountOut *utils.Uint256
feeAmount *utils.Uint256
}

// Represents a V3 pool
type Pool struct {
Token0 *entities.Token
Token1 *entities.Token
Fee constants.FeeAmount
SqrtRatioX96 *big.Int
Liquidity *big.Int
SqrtRatioX96 *utils.Uint160
Liquidity *utils.Uint128
TickCurrent int
TickDataProvider TickDataProvider

Expand All @@ -43,10 +45,10 @@ type Pool struct {
}

type SwapResult struct {
amountCalculated *big.Int
sqrtRatioX96 *big.Int
liquidity *big.Int
remainingAmountIn *big.Int
amountCalculated *utils.Int256
sqrtRatioX96 *utils.Uint160
liquidity *utils.Uint128
remainingAmountIn *utils.Int256
currentTick int
crossInitTickLoops int
}
Expand All @@ -62,6 +64,17 @@ func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCod
return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride)
}

// deprecated
func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
return NewPoolV2(
tokenA, tokenB, fee,
uint256.MustFromBig(sqrtRatioX96),
uint256.MustFromBig(liquidity),
tickCurrent,
ticks,
)
}

/**
* Construct a pool
* @param tokenA One of the tokens in the pool
Expand All @@ -72,16 +85,16 @@ func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCod
* @param tickCurrent The current tick of the pool
* @param ticks The current state of the pool ticks or a data provider that can return tick data
*/
func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *utils.Uint160, liquidity *utils.Uint128, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
if fee >= constants.FeeMax {
return nil, ErrFeeTooHigh
}

tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent)
tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent)
if err != nil {
return nil, err
}
nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent + 1)
nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent + 1)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -125,7 +138,7 @@ func (p *Pool) Token0Price() *entities.Price {
if p.token0Price != nil {
return p.token0Price
}
p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96))
p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig())
return p.token0Price
}

Expand All @@ -134,7 +147,7 @@ func (p *Pool) Token1Price() *entities.Price {
if p.token1Price != nil {
return p.token1Price
}
p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96), constants.Q192)
p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig(), constants.Q192)
return p.token1Price
}

Expand Down Expand Up @@ -164,12 +177,16 @@ func (p *Pool) ChainID() uint {
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit
* @returns The output amount and the pool with updated state
*/
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*GetAmountResult, error) {
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResult, error) {
if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) {
return nil, ErrTokenNotInvolved
}
zeroForOne := inputAmount.Currency.Equal(p.Token0)
swapResult, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96)
q, err := int256.FromBig(inputAmount.Quotient())
if err != nil {
return nil, err
}
swapResult, err := p.swap(zeroForOne, q, sqrtPriceLimitX96)
if err != nil {
return nil, err
}
Expand All @@ -179,7 +196,7 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
} else {
outputToken = p.Token0
}
pool, err := NewPool(
pool, err := NewPoolV2(
p.Token0,
p.Token1,
p.Fee,
Expand All @@ -192,8 +209,8 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
return nil, err
}
return &GetAmountResult{
ReturnedAmount: entities.FromRawAmount(outputToken, new(big.Int).Mul(swapResult.amountCalculated, constants.NegativeOne)),
RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn),
ReturnedAmount: entities.FromRawAmount(outputToken, new(utils.Int256).Neg(swapResult.amountCalculated).ToBig()),
RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn.ToBig()),
NewPoolState: pool,
CrossInitTickLoops: swapResult.crossInitTickLoops,
}, nil
Expand All @@ -205,12 +222,17 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap
* @returns The input amount and the pool with updated state
*/
func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) {
func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*entities.CurrencyAmount, *Pool, error) {
if !(outputAmount.Currency.IsToken() && p.InvolvesToken(outputAmount.Currency.Wrapped())) {
return nil, nil, ErrTokenNotInvolved
}
zeroForOne := outputAmount.Currency.Equal(p.Token1)
swapResult, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96)
q, err := int256.FromBig(outputAmount.Quotient())
if err != nil {
return nil, nil, err
}
q.Neg(q)
swapResult, err := p.swap(zeroForOne, q, sqrtPriceLimitX96)
if err != nil {
return nil, nil, err
}
Expand All @@ -220,7 +242,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
} else {
inputToken = p.Token1
}
pool, err := NewPool(
pool, err := NewPoolV2(
p.Token0,
p.Token1,
p.Fee,
Expand All @@ -232,7 +254,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
if err != nil {
return nil, nil, err
}
return entities.FromRawAmount(inputToken, swapResult.amountCalculated), pool, nil
return entities.FromRawAmount(inputToken, swapResult.amountCalculated.ToBig()), pool, nil
}

/**
Expand All @@ -245,56 +267,56 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
* @returns swapResult.liquidity
* @returns swapResult.tickCurrent
*/
func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (*SwapResult, error) {
func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLimitX96 *utils.Uint160) (*SwapResult, error) {
var err error
if sqrtPriceLimitX96 == nil {
if zeroForOne {
sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One)
sqrtPriceLimitX96 = new(uint256.Int).AddUint64(utils.MinSqrtRatioU256, 1)
} else {
sqrtPriceLimitX96 = new(big.Int).Sub(utils.MaxSqrtRatio, constants.One)
sqrtPriceLimitX96 = new(uint256.Int).SubUint64(utils.MaxSqrtRatioU256, 1)
}
}

if zeroForOne {
if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 {
if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatioU256) < 0 {
return nil, ErrSqrtPriceLimitX96TooLow
}
if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 {
return nil, ErrSqrtPriceLimitX96TooHigh
}
} else {
if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 {
if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatioU256) > 0 {
return nil, ErrSqrtPriceLimitX96TooHigh
}
if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 {
return nil, ErrSqrtPriceLimitX96TooLow
}
}

exactInput := amountSpecified.Cmp(constants.Zero) >= 0
exactInput := amountSpecified.Sign() >= 0

// keep track of swap state

state := struct {
amountSpecifiedRemaining *big.Int
amountCalculated *big.Int
sqrtPriceX96 *big.Int
amountSpecifiedRemaining *utils.Int256
amountCalculated *utils.Int256
sqrtPriceX96 *utils.Uint160
tick int
liquidity *big.Int
liquidity *utils.Uint128
}{
amountSpecifiedRemaining: amountSpecified,
amountCalculated: constants.Zero,
sqrtPriceX96: p.SqrtRatioX96,
amountSpecifiedRemaining: new(utils.Int256).Set(amountSpecified),
amountCalculated: int256.NewInt(0),
sqrtPriceX96: new(utils.Uint160).Set(p.SqrtRatioX96),
tick: p.TickCurrent,
liquidity: p.Liquidity,
liquidity: new(utils.Uint128).Set(p.Liquidity),
}

// crossInitTickLoops is the number of loops that cross an initialized tick.
// We only count when tick passes an initialized tick, since gas only significant in this case.
crossInitTickLoops := 0

// start swap while loop
for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 {
for !state.amountSpecifiedRemaining.IsZero() && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 {
var step StepComputations
step.sqrtPriceStartX96 = state.sqrtPriceX96

Expand All @@ -312,11 +334,11 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
step.tickNext = utils.MaxTick
}

step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext)
step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTickV2(step.tickNext)
if err != nil {
return nil, err
}
var targetValue *big.Int
var targetValue *utils.Uint160
if zeroForOne {
if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) < 0 {
targetValue = sqrtPriceLimitX96
Expand All @@ -336,12 +358,27 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
return nil, err
}

var amountInPlusFee utils.Uint256
amountInPlusFee.Add(step.amountIn, step.feeAmount)

var amountInPlusFeeSigned utils.Int256
err = utils.ToInt256(&amountInPlusFee, &amountInPlusFeeSigned)
if err != nil {
return nil, err
}

var amountOutSigned utils.Int256
err = utils.ToInt256(step.amountOut, &amountOutSigned)
if err != nil {
return nil, err
}

if exactInput {
state.amountSpecifiedRemaining = new(big.Int).Sub(state.amountSpecifiedRemaining, new(big.Int).Add(step.amountIn, step.feeAmount))
state.amountCalculated = new(big.Int).Sub(state.amountCalculated, step.amountOut)
state.amountSpecifiedRemaining.Sub(state.amountSpecifiedRemaining, &amountInPlusFeeSigned)
state.amountCalculated.Sub(state.amountCalculated, &amountOutSigned)
} else {
state.amountSpecifiedRemaining = new(big.Int).Add(state.amountSpecifiedRemaining, step.amountOut)
state.amountCalculated = new(big.Int).Add(state.amountCalculated, new(big.Int).Add(step.amountIn, step.feeAmount))
state.amountSpecifiedRemaining.Add(state.amountSpecifiedRemaining, &amountOutSigned)
state.amountCalculated.Add(state.amountCalculated, &amountInPlusFeeSigned)
}

// TODO
Expand All @@ -357,9 +394,9 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
// if we're moving leftward, we interpret liquidityNet as the opposite sign
// safe because liquidityNet cannot be type(int128).min
if zeroForOne {
liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne)
liquidityNet = new(utils.Int128).Neg(liquidityNet)
}
state.liquidity = utils.AddDelta(state.liquidity, liquidityNet)
utils.AddDeltaInPlace(state.liquidity, liquidityNet)

crossInitTickLoops++
}
Expand All @@ -371,7 +408,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int

} else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 {
// recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved
state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96)
state.tick, err = utils.GetTickAtSqrtRatioV2(state.sqrtPriceX96)
if err != nil {
return nil, err
}
Expand Down
13 changes: 9 additions & 4 deletions entities/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@ import (
"math/big"
"testing"

"github.com/KyberNetwork/int256"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
"github.com/stretchr/testify/assert"
)

var (
USDC = entities.NewToken(1, common.HexToAddress("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"), 6, "USDC", "USD Coin")
DAI = entities.NewToken(1, common.HexToAddress("0x6B175474E89094C44Da98b954EedeAC495271d0F"), 18, "DAI", "Dai Stablecoin")
OneEther = big.NewInt(1e18)

OneEtherI256 = int256.NewInt(1e18)
OneEtherUI256 = uint256.NewInt(1e18)
)

func TestNewPool(t *testing.T) {
Expand Down Expand Up @@ -116,13 +121,13 @@ func newTestPool() *Pool {
ticks := []Tick{
{
Index: NearestUsableTick(utils.MinTick, constants.TickSpacings[constants.FeeLow]),
LiquidityNet: OneEther,
LiquidityGross: OneEther,
LiquidityNet: OneEtherI256,
LiquidityGross: OneEtherUI256,
},
{
Index: NearestUsableTick(utils.MaxTick, constants.TickSpacings[constants.FeeLow]),
LiquidityNet: new(big.Int).Mul(OneEther, constants.NegativeOne),
LiquidityGross: OneEther,
LiquidityNet: new(int256.Int).Neg(OneEtherI256),
LiquidityGross: OneEtherUI256,
},
}

Expand Down
Loading

0 comments on commit 2935736

Please sign in to comment.