Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace big.Int and add test #16

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
)

const PoolInitCodeHash = "0xe34f199b19b2b4f47f68442619d555527d244f78a3297ea89325f843f87b8b54"
Expand Down Expand Up @@ -51,5 +52,7 @@ var (
Q96 = new(big.Int).Exp(big.NewInt(2), big.NewInt(96), nil)
Q192 = new(big.Int).Exp(Q96, big.NewInt(2), nil)

Q96U256 = new(uint256.Int).Exp(uint256.NewInt(2), uint256.NewInt(96))

PercentZero = entities.NewFraction(big.NewInt(0), big.NewInt(1))
)
2 changes: 1 addition & 1 deletion entities/nearestusabletick.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package entities
import (
"math"

"github.com/daoleno/uniswapv3-sdk/utils"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
)

/**
Expand Down
2 changes: 1 addition & 1 deletion entities/nearestusabletick_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package entities
import (
"testing"

"github.com/daoleno/uniswapv3-sdk/utils"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
"github.com/stretchr/testify/assert"
)

Expand Down
139 changes: 88 additions & 51 deletions entities/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"errors"
"math/big"

"github.com/KyberNetwork/int256"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/constants"
"github.com/KyberNetwork/uniswapv3-sdk-uint256/utils"
"github.com/daoleno/uniswap-sdk-core/entities"
"github.com/daoleno/uniswapv3-sdk/constants"
"github.com/daoleno/uniswapv3-sdk/utils"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
)

var (
Expand All @@ -19,22 +21,22 @@ var (
)

type StepComputations struct {
sqrtPriceStartX96 *big.Int
sqrtPriceStartX96 *utils.Uint160
tickNext int
initialized bool
sqrtPriceNextX96 *big.Int
amountIn *big.Int
amountOut *big.Int
feeAmount *big.Int
sqrtPriceNextX96 *utils.Uint160
amountIn *utils.Uint256
amountOut *utils.Uint256
feeAmount *utils.Uint256
}

// Represents a V3 pool
type Pool struct {
Token0 *entities.Token
Token1 *entities.Token
Fee constants.FeeAmount
SqrtRatioX96 *big.Int
Liquidity *big.Int
SqrtRatioX96 *utils.Uint160
Liquidity *utils.Uint128
TickCurrent int
TickDataProvider TickDataProvider

Expand All @@ -43,10 +45,10 @@ type Pool struct {
}

type SwapResult struct {
amountCalculated *big.Int
sqrtRatioX96 *big.Int
liquidity *big.Int
remainingAmountIn *big.Int
amountCalculated *utils.Int256
sqrtRatioX96 *utils.Uint160
liquidity *utils.Uint128
remainingAmountIn *utils.Int256
currentTick int
crossInitTickLoops int
}
Expand All @@ -62,6 +64,17 @@ func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCod
return utils.ComputePoolAddress(constants.FactoryAddress, tokenA, tokenB, fee, initCodeHashManualOverride)
}

// deprecated
func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
return NewPoolV2(
tokenA, tokenB, fee,
uint256.MustFromBig(sqrtRatioX96),
uint256.MustFromBig(liquidity),
tickCurrent,
ticks,
)
}

/**
* Construct a pool
* @param tokenA One of the tokens in the pool
Expand All @@ -72,16 +85,16 @@ func GetAddress(tokenA, tokenB *entities.Token, fee constants.FeeAmount, initCod
* @param tickCurrent The current tick of the pool
* @param ticks The current state of the pool ticks or a data provider that can return tick data
*/
func NewPool(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *big.Int, liquidity *big.Int, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
func NewPoolV2(tokenA, tokenB *entities.Token, fee constants.FeeAmount, sqrtRatioX96 *utils.Uint160, liquidity *utils.Uint128, tickCurrent int, ticks TickDataProvider) (*Pool, error) {
if fee >= constants.FeeMax {
return nil, ErrFeeTooHigh
}

tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent)
tickCurrentSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent)
if err != nil {
return nil, err
}
nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTick(tickCurrent + 1)
nextTickSqrtRatioX96, err := utils.GetSqrtRatioAtTickV2(tickCurrent + 1)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -125,7 +138,7 @@ func (p *Pool) Token0Price() *entities.Price {
if p.token0Price != nil {
return p.token0Price
}
p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96))
p.token0Price = entities.NewPrice(p.Token0, p.Token1, constants.Q192, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig())
return p.token0Price
}

Expand All @@ -134,7 +147,7 @@ func (p *Pool) Token1Price() *entities.Price {
if p.token1Price != nil {
return p.token1Price
}
p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(big.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96), constants.Q192)
p.token1Price = entities.NewPrice(p.Token1, p.Token0, new(uint256.Int).Mul(p.SqrtRatioX96, p.SqrtRatioX96).ToBig(), constants.Q192)
return p.token1Price
}

Expand Down Expand Up @@ -164,12 +177,16 @@ func (p *Pool) ChainID() uint {
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit
* @returns The output amount and the pool with updated state
*/
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*GetAmountResult, error) {
func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*GetAmountResult, error) {
if !(inputAmount.Currency.IsToken() && p.InvolvesToken(inputAmount.Currency.Wrapped())) {
return nil, ErrTokenNotInvolved
}
zeroForOne := inputAmount.Currency.Equal(p.Token0)
swapResult, err := p.swap(zeroForOne, inputAmount.Quotient(), sqrtPriceLimitX96)
q, err := int256.FromBig(inputAmount.Quotient())
if err != nil {
return nil, err
}
swapResult, err := p.swap(zeroForOne, q, sqrtPriceLimitX96)
if err != nil {
return nil, err
}
Expand All @@ -179,7 +196,7 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
} else {
outputToken = p.Token0
}
pool, err := NewPool(
pool, err := NewPoolV2(
p.Token0,
p.Token1,
p.Fee,
Expand All @@ -192,8 +209,8 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
return nil, err
}
return &GetAmountResult{
ReturnedAmount: entities.FromRawAmount(outputToken, new(big.Int).Mul(swapResult.amountCalculated, constants.NegativeOne)),
RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn),
ReturnedAmount: entities.FromRawAmount(outputToken, new(utils.Int256).Neg(swapResult.amountCalculated).ToBig()),
RemainingAmountIn: entities.FromRawAmount(inputAmount.Currency, swapResult.remainingAmountIn.ToBig()),
NewPoolState: pool,
CrossInitTickLoops: swapResult.crossInitTickLoops,
}, nil
Expand All @@ -205,12 +222,17 @@ func (p *Pool) GetOutputAmount(inputAmount *entities.CurrencyAmount, sqrtPriceLi
* @param sqrtPriceLimitX96 The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value after the swap. If one for zero, the price cannot be greater than this value after the swap
* @returns The input amount and the pool with updated state
*/
func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *big.Int) (*entities.CurrencyAmount, *Pool, error) {
func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLimitX96 *utils.Uint160) (*entities.CurrencyAmount, *Pool, error) {
if !(outputAmount.Currency.IsToken() && p.InvolvesToken(outputAmount.Currency.Wrapped())) {
return nil, nil, ErrTokenNotInvolved
}
zeroForOne := outputAmount.Currency.Equal(p.Token1)
swapResult, err := p.swap(zeroForOne, new(big.Int).Mul(outputAmount.Quotient(), constants.NegativeOne), sqrtPriceLimitX96)
q, err := int256.FromBig(outputAmount.Quotient())
if err != nil {
return nil, nil, err
}
q.Neg(q)
swapResult, err := p.swap(zeroForOne, q, sqrtPriceLimitX96)
if err != nil {
return nil, nil, err
}
Expand All @@ -220,7 +242,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
} else {
inputToken = p.Token1
}
pool, err := NewPool(
pool, err := NewPoolV2(
p.Token0,
p.Token1,
p.Fee,
Expand All @@ -232,7 +254,7 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
if err != nil {
return nil, nil, err
}
return entities.FromRawAmount(inputToken, swapResult.amountCalculated), pool, nil
return entities.FromRawAmount(inputToken, swapResult.amountCalculated.ToBig()), pool, nil
}

/**
Expand All @@ -245,56 +267,56 @@ func (p *Pool) GetInputAmount(outputAmount *entities.CurrencyAmount, sqrtPriceLi
* @returns swapResult.liquidity
* @returns swapResult.tickCurrent
*/
func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int) (*SwapResult, error) {
func (p *Pool) swap(zeroForOne bool, amountSpecified *utils.Int256, sqrtPriceLimitX96 *utils.Uint160) (*SwapResult, error) {
var err error
if sqrtPriceLimitX96 == nil {
if zeroForOne {
sqrtPriceLimitX96 = new(big.Int).Add(utils.MinSqrtRatio, constants.One)
sqrtPriceLimitX96 = new(uint256.Int).AddUint64(utils.MinSqrtRatioU256, 1)
} else {
sqrtPriceLimitX96 = new(big.Int).Sub(utils.MaxSqrtRatio, constants.One)
sqrtPriceLimitX96 = new(uint256.Int).SubUint64(utils.MaxSqrtRatioU256, 1)
}
}

if zeroForOne {
if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatio) < 0 {
if sqrtPriceLimitX96.Cmp(utils.MinSqrtRatioU256) < 0 {
return nil, ErrSqrtPriceLimitX96TooLow
}
if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) >= 0 {
return nil, ErrSqrtPriceLimitX96TooHigh
}
} else {
if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatio) > 0 {
if sqrtPriceLimitX96.Cmp(utils.MaxSqrtRatioU256) > 0 {
return nil, ErrSqrtPriceLimitX96TooHigh
}
if sqrtPriceLimitX96.Cmp(p.SqrtRatioX96) <= 0 {
return nil, ErrSqrtPriceLimitX96TooLow
}
}

exactInput := amountSpecified.Cmp(constants.Zero) >= 0
exactInput := amountSpecified.Sign() >= 0

// keep track of swap state

state := struct {
amountSpecifiedRemaining *big.Int
amountCalculated *big.Int
sqrtPriceX96 *big.Int
amountSpecifiedRemaining *utils.Int256
amountCalculated *utils.Int256
sqrtPriceX96 *utils.Uint160
tick int
liquidity *big.Int
liquidity *utils.Uint128
}{
amountSpecifiedRemaining: amountSpecified,
amountCalculated: constants.Zero,
sqrtPriceX96: p.SqrtRatioX96,
amountSpecifiedRemaining: new(utils.Int256).Set(amountSpecified),
amountCalculated: int256.NewInt(0),
sqrtPriceX96: new(utils.Uint160).Set(p.SqrtRatioX96),
tick: p.TickCurrent,
liquidity: p.Liquidity,
liquidity: new(utils.Uint128).Set(p.Liquidity),
}

// crossInitTickLoops is the number of loops that cross an initialized tick.
// We only count when tick passes an initialized tick, since gas only significant in this case.
crossInitTickLoops := 0

// start swap while loop
for state.amountSpecifiedRemaining.Cmp(constants.Zero) != 0 && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 {
for !state.amountSpecifiedRemaining.IsZero() && state.sqrtPriceX96.Cmp(sqrtPriceLimitX96) != 0 {
var step StepComputations
step.sqrtPriceStartX96 = state.sqrtPriceX96

Expand All @@ -312,11 +334,11 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
step.tickNext = utils.MaxTick
}

step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTick(step.tickNext)
step.sqrtPriceNextX96, err = utils.GetSqrtRatioAtTickV2(step.tickNext)
if err != nil {
return nil, err
}
var targetValue *big.Int
var targetValue *utils.Uint160
if zeroForOne {
if step.sqrtPriceNextX96.Cmp(sqrtPriceLimitX96) < 0 {
targetValue = sqrtPriceLimitX96
Expand All @@ -336,12 +358,27 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
return nil, err
}

var amountInPlusFee utils.Uint256
amountInPlusFee.Add(step.amountIn, step.feeAmount)

var amountInPlusFeeSigned utils.Int256
err = utils.ToInt256(&amountInPlusFee, &amountInPlusFeeSigned)
if err != nil {
return nil, err
}

var amountOutSigned utils.Int256
err = utils.ToInt256(step.amountOut, &amountOutSigned)
if err != nil {
return nil, err
}

if exactInput {
state.amountSpecifiedRemaining = new(big.Int).Sub(state.amountSpecifiedRemaining, new(big.Int).Add(step.amountIn, step.feeAmount))
state.amountCalculated = new(big.Int).Sub(state.amountCalculated, step.amountOut)
state.amountSpecifiedRemaining.Sub(state.amountSpecifiedRemaining, &amountInPlusFeeSigned)
state.amountCalculated.Sub(state.amountCalculated, &amountOutSigned)
} else {
state.amountSpecifiedRemaining = new(big.Int).Add(state.amountSpecifiedRemaining, step.amountOut)
state.amountCalculated = new(big.Int).Add(state.amountCalculated, new(big.Int).Add(step.amountIn, step.feeAmount))
state.amountSpecifiedRemaining.Add(state.amountSpecifiedRemaining, &amountOutSigned)
state.amountCalculated.Add(state.amountCalculated, &amountInPlusFeeSigned)
}

// TODO
Expand All @@ -357,9 +394,9 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int
// if we're moving leftward, we interpret liquidityNet as the opposite sign
// safe because liquidityNet cannot be type(int128).min
if zeroForOne {
liquidityNet = new(big.Int).Mul(liquidityNet, constants.NegativeOne)
liquidityNet = new(utils.Int128).Neg(liquidityNet)
}
state.liquidity = utils.AddDelta(state.liquidity, liquidityNet)
utils.AddDeltaInPlace(state.liquidity, liquidityNet)

crossInitTickLoops++
}
Expand All @@ -371,7 +408,7 @@ func (p *Pool) swap(zeroForOne bool, amountSpecified, sqrtPriceLimitX96 *big.Int

} else if state.sqrtPriceX96.Cmp(step.sqrtPriceStartX96) != 0 {
// recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved
state.tick, err = utils.GetTickAtSqrtRatio(state.sqrtPriceX96)
state.tick, err = utils.GetTickAtSqrtRatioV2(state.sqrtPriceX96)
if err != nil {
return nil, err
}
Expand Down
Loading