Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
palango committed Oct 26, 2023
1 parent 96f0875 commit 2673d28
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 9 deletions.
32 changes: 23 additions & 9 deletions core/txpool/legacypool/celo.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ import (
"github.com/ethereum/go-ethereum/params"
)

type ExchangeRates = map[common.Address]*big.Rat

type celoContext interface {
IsWhitelisted(feeCurrency common.Address) (bool, error)
GetBalanceOf(account common.Address, feeCurrency common.Address) (*big.Int, error)
CompareFees(fee1 *big.Int, feCurrency1 common.Address, fee2 *big.Int, feeCurrency2 common.Address) int
CompareValue(fee1 *big.Int, feCurrency1 common.Address, fee2 *big.Int, feeCurrency2 common.Address) int
}

type celoContextImpl struct {
backend *core.CeloBackend
registry *abigen.RegistryCaller
exchangeRates map[common.Address]*big.Rat
exchangeRates ExchangeRates
}

func newCeloContext(config *params.ChainConfig, state *state.StateDB) *celoContextImpl {
Expand Down Expand Up @@ -85,14 +87,26 @@ func (cc *celoContextImpl) GetBalanceOf(account common.Address, feeCurrency comm
}

// CmpValues compares values of potentially different currencies
func (cc *celoContextImpl) CompareFees(val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) int {
// Short circuit if the fee currency is the same. nil currency => native currency
// nil currency => native currency
func (cc *celoContextImpl) CompareValue(val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) int {
// Short circuit if the fee currency is the same.
if areEqualAddresses(feeCurrency1, feeCurrency2) {
return val1.Cmp(val2)
}

exchangeRate1, ok1 := cc.exchangeRates[*feeCurrency1]
exchangeRate2, ok2 := cc.exchangeRates[*feeCurrency2]
var exchangeRate1, exchangeRate2 *big.Rat
ok1, ok2 := true, true
if feeCurrency1 == nil {
exchangeRate1 = big.NewRat(1, 1)
} else {
exchangeRate1, ok1 = cc.exchangeRates[*feeCurrency1]
}

if feeCurrency2 == nil {
exchangeRate2 = big.NewRat(1, 1)
} else {
exchangeRate2, ok2 = cc.exchangeRates[*feeCurrency2]
}

if !ok1 || !ok2 {
currency1Output := "nil"
Expand All @@ -104,14 +118,14 @@ func (cc *celoContextImpl) CompareFees(val1 *big.Int, feeCurrency1 *common.Addre
currency2Output = feeCurrency2.Hex()
}
// TODO(pl): I guess we should just error here?
log.Warn("Error in retrieving exchange rate. Will do comparison of two values without exchange rate conversion.", "currency1", currency1Output, "err1", err1, "currency2", currency2Output, "err2", err2)
log.Warn("Error in retrieving exchange rate. Will do comparison of two values without exchange rate conversion.", "currency1", currency1Output, "currency2", currency2Output)
return val1.Cmp(val2)
}

// Below code block is basically evaluating this comparison:
// currencyAmount * c.toCELORate.denominator / c.toCELORate.numerator < sndCurrencyAmount * sndCurrency.toCELORate.denominator / sndCurrency.toCELORate.numerator
// val1 * echangeRate1.denominator / echangeRate1.numerator < val2 * exchangeRate2.denominator / exchangeRate2.numerator
// It will transform that comparison to this, to remove having to deal with fractional values.
// currencyAmount * c.toCELORate.denominator * sndCurrency.toCELORate.numerator < sndCurrencyAmount * sndCurrency.toCELORate.denominator * c.toCELORate.numerator
// val1 * echangeRate1.denominator * exchangeRate2.numerator < val2 * exchangeRate2.denominator * c.toCELORate.numerator
leftSide := new(big.Int).Mul(
val1,
new(big.Int).Mul(
Expand Down
129 changes: 129 additions & 0 deletions core/txpool/legacypool/celo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package legacypool

import (
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
)

func Test_celoContextImpl_CompareFees(t *testing.T) {
currA := common.HexToAddress("0xA")
currB := common.HexToAddress("0xB")

exchangeRates := ExchangeRates{
currA: big.NewRat(1, 2), // token is worth 2 celo
currB: big.NewRat(2, 1), // token is worth 0.5 celo
}
type args struct {
val1 *big.Int
feeCurrency1 *common.Address
val2 *big.Int
feeCurrency2 *common.Address
}
tests := []struct {
name string
args args
want int
}{
// Native currency
{
name: "Same amount of native currency",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: nil,
},
want: 0,
}, {
name: "Different amounts of native currency 1",
args: args{
val1: big.NewInt(2),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: nil,
},
want: 1,
}, {
name: "Different amounts of native currency 2",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(5),
feeCurrency2: nil,
},
want: -1,
},
// Mixed currency
{
name: "Same amount of mixed currency",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
want: -1,
}, {
name: "Different amounts of mixed currency 1",
args: args{
val1: big.NewInt(2),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
want: 0,
}, {
name: "Different amounts of mixed currency 2",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(2),
feeCurrency2: &currB,
},
want: 0,
},
// Two fee currencies
{
name: "Same amount of same currency",
args: args{
val1: big.NewInt(1),
feeCurrency1: &currA,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
want: 0,
}, {
name: "Different amounts of same currency 1",
args: args{
val1: big.NewInt(3),
feeCurrency1: &currA,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
want: 1,
}, {
name: "Different amounts of same currency 2",
args: args{
val1: big.NewInt(1),
feeCurrency1: &currA,
val2: big.NewInt(7),
feeCurrency2: &currA,
},
want: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cc := &celoContextImpl{
backend: nil,
registry: nil,
exchangeRates: exchangeRates,
}
if got := cc.CompareValue(tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2); got != tt.want {
t.Errorf("celoContextImpl.CompareFees() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 2673d28

Please sign in to comment.