From db92f423e2c75507bebae287b4874a28079e2a87 Mon Sep 17 00:00:00 2001 From: Paul Lange Date: Tue, 7 Nov 2023 14:09:55 +0100 Subject: [PATCH] CHanges after discussion --- core/txpool/legacypool/celo.go | 64 ++++++++++------------------- core/txpool/legacypool/celo_test.go | 51 ++++++++++++++++++----- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/core/txpool/legacypool/celo.go b/core/txpool/legacypool/celo.go index 823564da84..049ac27699 100644 --- a/core/txpool/legacypool/celo.go +++ b/core/txpool/legacypool/celo.go @@ -8,58 +8,24 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/contracts/celo/abigen" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/log" ) var ( unitRate = big.NewRat(1, 1) ) -type celoContext interface { - IsWhitelisted(feeCurrency common.Address) bool - GetBalanceOf(account common.Address, feeCurrency common.Address) (*big.Int, error) - CompareValue(fee1 *big.Int, feeCurrency1 common.Address, fee2 *big.Int, feeCurrency2 common.Address) int -} - -type celoContextImpl struct { - backend *core.CeloBackend - exchangeRates common.ExchangeRates -} - -func newCeloContext(backend *core.CeloBackend) *celoContextImpl { - // GetExchangeRates retrieves currency-to-CELO exchange rate - exchangeRates, err := core.GetExchangeRates(backend) - if err != nil { - log.Error("Error fetching exchange rates!", "err", err) +// IsWhitelisted checks if a given fee currency is whitelisted +func IsWhitelisted(exchangeRates common.ExchangeRates, feeCurrency *common.Address) bool { + if feeCurrency == nil { + return true } - - return &celoContextImpl{backend: backend, exchangeRates: exchangeRates} -} - -func (cc *celoContextImpl) IsWhitelisted(feeCurrency common.Address) bool { - _, ok := cc.exchangeRates[feeCurrency] + _, ok := exchangeRates[*feeCurrency] return ok } -func (cc *celoContextImpl) GetBalanceOf(account common.Address, feeCurrency common.Address) (*big.Int, error) { - // OPT: Cache fee currency caller? - token, err := abigen.NewFeeCurrencyCaller(feeCurrency, cc.backend) - if err != nil { - return nil, errors.New("failed to access fee currency token") - } - - balance, err := token.BalanceOf(&bind.CallOpts{}, account) - if err != nil { - return nil, errors.New("failed to access token balance") - } - - return balance, nil -} - // Compares values in different currencies // nil currency => native currency -func (cc *celoContextImpl) CompareValue(val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) (int, error) { +func CompareValue(exchangeRates common.ExchangeRates, val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) (int, error) { // Short circuit if the fee currency is the same. if areEqualAddresses(feeCurrency1, feeCurrency2) { return val1.Cmp(val2), nil @@ -70,7 +36,7 @@ func (cc *celoContextImpl) CompareValue(val1 *big.Int, feeCurrency1 *common.Addr if feeCurrency1 == nil { exchangeRate1 = unitRate } else { - exchangeRate1, ok = cc.exchangeRates[*feeCurrency1] + exchangeRate1, ok = exchangeRates[*feeCurrency1] if !ok { return 0, fmt.Errorf("fee currency not registered: %s", feeCurrency1.Hex()) } @@ -79,7 +45,7 @@ func (cc *celoContextImpl) CompareValue(val1 *big.Int, feeCurrency1 *common.Addr if feeCurrency2 == nil { exchangeRate2 = unitRate } else { - exchangeRate2, ok = cc.exchangeRates[*feeCurrency2] + exchangeRate2, ok = exchangeRates[*feeCurrency2] if !ok { return 0, fmt.Errorf("fee currency not registered: %s", feeCurrency1.Hex()) } @@ -110,3 +76,17 @@ func (cc *celoContextImpl) CompareValue(val1 *big.Int, feeCurrency1 *common.Addr func areEqualAddresses(addr1, addr2 *common.Address) bool { return (addr1 == nil && addr2 == nil) || (addr1 != nil && addr2 != nil && *addr1 == *addr2) } + +func GetBalanceOf(backend *bind.ContractCaller, account common.Address, feeCurrency common.Address) (*big.Int, error) { + token, err := abigen.NewFeeCurrencyCaller(feeCurrency, *backend) + if err != nil { + return nil, errors.New("failed to access fee currency token") + } + + balance, err := token.BalanceOf(&bind.CallOpts{}, account) + if err != nil { + return nil, errors.New("failed to access token balance") + } + + return balance, nil +} diff --git a/core/txpool/legacypool/celo_test.go b/core/txpool/legacypool/celo_test.go index 8aa3fe0da2..49fa943f70 100644 --- a/core/txpool/legacypool/celo_test.go +++ b/core/txpool/legacypool/celo_test.go @@ -7,15 +7,17 @@ import ( "github.com/ethereum/go-ethereum/common" ) -func Test_celoContextImpl_CompareFees(t *testing.T) { - currA := common.HexToAddress("0xA") - currB := common.HexToAddress("0xB") - currX := common.HexToAddress("0xF") - - exchangeRates := common.ExchangeRates{ +var ( + currA = common.HexToAddress("0xA") + currB = common.HexToAddress("0xB") + currX = common.HexToAddress("0xF") + exchangeRates = common.ExchangeRates{ currA: big.NewRat(47, 100), currB: big.NewRat(45, 100), } +) + +func TestCompareFees(t *testing.T) { type args struct { val1 *big.Int feeCurrency1 *common.Address @@ -129,11 +131,7 @@ func Test_celoContextImpl_CompareFees(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cc := &celoContextImpl{ - backend: nil, - exchangeRates: exchangeRates, - } - got, err := cc.CompareValue(tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2) + got, err := CompareValue(exchangeRates, tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2) if tt.wantErr && err == nil { t.Error("Expected error in celoContextImpl.CompareFees") @@ -144,3 +142,34 @@ func Test_celoContextImpl_CompareFees(t *testing.T) { }) } } + +func TestIsWhitelisted(t *testing.T) { + tests := []struct { + name string + feeCurrency *common.Address + want bool + }{ + { + name: "no fee currency", + feeCurrency: nil, + want: true, + }, + { + name: "valid fee currency", + feeCurrency: &currA, + want: true, + }, + { + name: "invalid fee currency", + feeCurrency: &currX, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsWhitelisted(exchangeRates, tt.feeCurrency); got != tt.want { + t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want) + } + }) + } +}