diff --git a/core/txpool/legacypool/celo.go b/core/txpool/legacypool/celo.go index d9b179ac8c..fa3ab5d119 100644 --- a/core/txpool/legacypool/celo.go +++ b/core/txpool/legacypool/celo.go @@ -15,16 +15,18 @@ import ( "github.com/ethereum/go-ethereum/params" ) +type ExchangeRates = map[common.Address]*big.Rat + type celoContext interface { IsWhitelisted(feeCurrency common.Address) (bool, error) GetBalanceOf(account common.Address, feeCurrency common.Address) (*big.Int, error) - CompareFees(fee1 *big.Int, feCurrency1 common.Address, fee2 *big.Int, feeCurrency2 common.Address) int + CompareValue(fee1 *big.Int, feCurrency1 common.Address, fee2 *big.Int, feeCurrency2 common.Address) int } type celoContextImpl struct { backend *core.CeloBackend registry *abigen.RegistryCaller - exchangeRates map[common.Address]*big.Rat + exchangeRates ExchangeRates } func newCeloContext(config *params.ChainConfig, state *state.StateDB) *celoContextImpl { @@ -85,14 +87,26 @@ func (cc *celoContextImpl) GetBalanceOf(account common.Address, feeCurrency comm } // CmpValues compares values of potentially different currencies -func (cc *celoContextImpl) CompareFees(val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) int { - // Short circuit if the fee currency is the same. nil currency => native currency +// nil currency => native currency +func (cc *celoContextImpl) CompareValue(val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) int { + // Short circuit if the fee currency is the same. if areEqualAddresses(feeCurrency1, feeCurrency2) { return val1.Cmp(val2) } - exchangeRate1, ok1 := cc.exchangeRates[*feeCurrency1] - exchangeRate2, ok2 := cc.exchangeRates[*feeCurrency2] + var exchangeRate1, exchangeRate2 *big.Rat + ok1, ok2 := true, true + if feeCurrency1 == nil { + exchangeRate1 = big.NewRat(1, 1) + } else { + exchangeRate1, ok1 = cc.exchangeRates[*feeCurrency1] + } + + if feeCurrency2 == nil { + exchangeRate2 = big.NewRat(1, 1) + } else { + exchangeRate2, ok2 = cc.exchangeRates[*feeCurrency2] + } if !ok1 || !ok2 { currency1Output := "nil" @@ -104,14 +118,14 @@ func (cc *celoContextImpl) CompareFees(val1 *big.Int, feeCurrency1 *common.Addre currency2Output = feeCurrency2.Hex() } // TODO(pl): I guess we should just error here? - log.Warn("Error in retrieving exchange rate. Will do comparison of two values without exchange rate conversion.", "currency1", currency1Output, "err1", err1, "currency2", currency2Output, "err2", err2) + log.Warn("Error in retrieving exchange rate. Will do comparison of two values without exchange rate conversion.", "currency1", currency1Output, "currency2", currency2Output) return val1.Cmp(val2) } // Below code block is basically evaluating this comparison: - // currencyAmount * c.toCELORate.denominator / c.toCELORate.numerator < sndCurrencyAmount * sndCurrency.toCELORate.denominator / sndCurrency.toCELORate.numerator + // val1 * echangeRate1.denominator / echangeRate1.numerator < val2 * exchangeRate2.denominator / exchangeRate2.numerator // It will transform that comparison to this, to remove having to deal with fractional values. - // currencyAmount * c.toCELORate.denominator * sndCurrency.toCELORate.numerator < sndCurrencyAmount * sndCurrency.toCELORate.denominator * c.toCELORate.numerator + // val1 * echangeRate1.denominator * exchangeRate2.numerator < val2 * exchangeRate2.denominator * c.toCELORate.numerator leftSide := new(big.Int).Mul( val1, new(big.Int).Mul( diff --git a/core/txpool/legacypool/celo_test.go b/core/txpool/legacypool/celo_test.go new file mode 100644 index 0000000000..092ee043d8 --- /dev/null +++ b/core/txpool/legacypool/celo_test.go @@ -0,0 +1,129 @@ +package legacypool + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func Test_celoContextImpl_CompareFees(t *testing.T) { + currA := common.HexToAddress("0xA") + currB := common.HexToAddress("0xB") + + exchangeRates := ExchangeRates{ + currA: big.NewRat(1, 2), // token is worth 2 celo + currB: big.NewRat(2, 1), // token is worth 0.5 celo + } + type args struct { + val1 *big.Int + feeCurrency1 *common.Address + val2 *big.Int + feeCurrency2 *common.Address + } + tests := []struct { + name string + args args + want int + }{ + // Native currency + { + name: "Same amount of native currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: nil, + }, + want: 0, + }, { + name: "Different amounts of native currency 1", + args: args{ + val1: big.NewInt(2), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: nil, + }, + want: 1, + }, { + name: "Different amounts of native currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(5), + feeCurrency2: nil, + }, + want: -1, + }, + // Mixed currency + { + name: "Same amount of mixed currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + want: -1, + }, { + name: "Different amounts of mixed currency 1", + args: args{ + val1: big.NewInt(2), + feeCurrency1: nil, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + want: 0, + }, { + name: "Different amounts of mixed currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: nil, + val2: big.NewInt(2), + feeCurrency2: &currB, + }, + want: 0, + }, + // Two fee currencies + { + name: "Same amount of same currency", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + want: 0, + }, { + name: "Different amounts of same currency 1", + args: args{ + val1: big.NewInt(3), + feeCurrency1: &currA, + val2: big.NewInt(1), + feeCurrency2: &currA, + }, + want: 1, + }, { + name: "Different amounts of same currency 2", + args: args{ + val1: big.NewInt(1), + feeCurrency1: &currA, + val2: big.NewInt(7), + feeCurrency2: &currA, + }, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cc := &celoContextImpl{ + backend: nil, + registry: nil, + exchangeRates: exchangeRates, + } + if got := cc.CompareValue(tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2); got != tt.want { + t.Errorf("celoContextImpl.CompareFees() = %v, want %v", got, tt.want) + } + }) + } +}