Skip to content

Commit

Permalink
Add CeloContext
Browse files Browse the repository at this point in the history
Add tests

Rebase

tmp

Add error

PR review

CHanges after discussion
  • Loading branch information
palango committed Nov 7, 2023
1 parent 10e5420 commit c7f30a9
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 7 deletions.
2 changes: 2 additions & 0 deletions common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,5 @@ func (d *Decimal) UnmarshalJSON(input []byte) error {
return err
}
}

type ExchangeRates = map[Address]*big.Rat
8 changes: 4 additions & 4 deletions core/celo_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import (
// CeloBackend provide a partial ContractBackend implementation, so that we can
// access core contracts during block processing.
type CeloBackend struct {
chainConfig *params.ChainConfig
state *state.StateDB
ChainConfig *params.ChainConfig
State *state.StateDB
}

// ContractCaller implementation

func (b *CeloBackend) CodeAt(ctx context.Context, contract common.Address, blockNumber *big.Int) ([]byte, error) {
return b.state.GetCode(contract), nil
return b.State.GetCode(contract), nil
}

func (b *CeloBackend) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) {
Expand All @@ -44,7 +44,7 @@ func (b *CeloBackend) CallContract(ctx context.Context, call ethereum.CallMsg, b
txCtx := vm.TxContext{}
vmConfig := vm.Config{}

evm := vm.NewEVM(blockCtx, txCtx, b.state, b.chainConfig, vmConfig)
evm := vm.NewEVM(blockCtx, txCtx, b.State, b.ChainConfig, vmConfig)
ret, _, err := evm.StaticCall(vm.AccountRef(evm.Origin), *call.To, call.Data, call.Gas)

return ret, err
Expand Down
4 changes: 2 additions & 2 deletions core/celo_evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

// Returns the exchange rates for all gas currencies from CELO
func getExchangeRates(caller *CeloBackend) (map[common.Address]*big.Rat, error) {
func GetExchangeRates(caller bind.ContractCaller) (common.ExchangeRates, error) {
exchangeRates := map[common.Address]*big.Rat{}
whitelist, err := abigen.NewFeeCurrencyWhitelistCaller(contracts.FeeCurrencyWhitelistAddress, caller)
if err != nil {
Expand Down Expand Up @@ -57,7 +57,7 @@ func setCeloFieldsInBlockContext(blockContext *vm.BlockContext, header *types.He

// Add fee currency exchange rates
var err error
blockContext.ExchangeRates, err = getExchangeRates(caller)
blockContext.ExchangeRates, err = GetExchangeRates(caller)
if err != nil {
log.Error("Error fetching exchange rates!", "err", err)
}
Expand Down
92 changes: 92 additions & 0 deletions core/txpool/legacypool/celo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package legacypool

import (
"errors"
"fmt"
"math/big"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/contracts/celo/abigen"
)

var (
unitRate = big.NewRat(1, 1)
)

// IsWhitelisted checks if a given fee currency is whitelisted
func IsWhitelisted(exchangeRates common.ExchangeRates, feeCurrency *common.Address) bool {
if feeCurrency == nil {
return true
}
_, ok := exchangeRates[*feeCurrency]
return ok
}

// Compares values in different currencies
// nil currency => native currency
func CompareValue(exchangeRates common.ExchangeRates, val1 *big.Int, feeCurrency1 *common.Address, val2 *big.Int, feeCurrency2 *common.Address) (int, error) {
// Short circuit if the fee currency is the same.
if areEqualAddresses(feeCurrency1, feeCurrency2) {
return val1.Cmp(val2), nil
}

var exchangeRate1, exchangeRate2 *big.Rat
var ok bool
if feeCurrency1 == nil {
exchangeRate1 = unitRate
} else {
exchangeRate1, ok = exchangeRates[*feeCurrency1]
if !ok {
return 0, fmt.Errorf("fee currency not registered: %s", feeCurrency1.Hex())
}
}

if feeCurrency2 == nil {
exchangeRate2 = unitRate
} else {
exchangeRate2, ok = exchangeRates[*feeCurrency2]
if !ok {
return 0, fmt.Errorf("fee currency not registered: %s", feeCurrency1.Hex())
}
}

// Below code block is basically evaluating this comparison:
// val1 * exchangeRate1.denominator / exchangeRate1.numerator < val2 * exchangeRate2.denominator / exchangeRate2.numerator
// It will transform that comparison to this, to remove having to deal with fractional values.
// val1 * exchangeRate1.denominator * exchangeRate2.numerator < val2 * exchangeRate2.denominator * exchangeRate1.numerator
leftSide := new(big.Int).Mul(
val1,
new(big.Int).Mul(
exchangeRate1.Denom(),
exchangeRate2.Num(),
),
)
rightSide := new(big.Int).Mul(
val2,
new(big.Int).Mul(
exchangeRate2.Denom(),
exchangeRate1.Num(),
),
)

return leftSide.Cmp(rightSide), nil
}

func areEqualAddresses(addr1, addr2 *common.Address) bool {
return (addr1 == nil && addr2 == nil) || (addr1 != nil && addr2 != nil && *addr1 == *addr2)
}

func GetBalanceOf(backend *bind.ContractCaller, account common.Address, feeCurrency common.Address) (*big.Int, error) {
token, err := abigen.NewFeeCurrencyCaller(feeCurrency, *backend)
if err != nil {
return nil, errors.New("failed to access fee currency token")
}

balance, err := token.BalanceOf(&bind.CallOpts{}, account)
if err != nil {
return nil, errors.New("failed to access token balance")
}

return balance, nil
}
175 changes: 175 additions & 0 deletions core/txpool/legacypool/celo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package legacypool

import (
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
)

var (
currA = common.HexToAddress("0xA")
currB = common.HexToAddress("0xB")
currX = common.HexToAddress("0xF")
exchangeRates = common.ExchangeRates{
currA: big.NewRat(47, 100),
currB: big.NewRat(45, 100),
}
)

func TestCompareFees(t *testing.T) {
type args struct {
val1 *big.Int
feeCurrency1 *common.Address
val2 *big.Int
feeCurrency2 *common.Address
}
tests := []struct {
name string
args args
wantResult int
wantErr bool
}{
// Native currency
{
name: "Same amount of native currency",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: nil,
},
wantResult: 0,
}, {
name: "Different amounts of native currency 1",
args: args{
val1: big.NewInt(2),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: nil,
},
wantResult: 1,
}, {
name: "Different amounts of native currency 2",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(5),
feeCurrency2: nil,
},
wantResult: -1,
},
// Mixed currency
{
name: "Same amount of mixed currency",
args: args{
val1: big.NewInt(1),
feeCurrency1: nil,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
wantResult: -1,
}, {
name: "Different amounts of mixed currency 1",
args: args{
val1: big.NewInt(100),
feeCurrency1: nil,
val2: big.NewInt(47),
feeCurrency2: &currA,
},
wantResult: 0,
}, {
name: "Different amounts of mixed currency 2",
args: args{
val1: big.NewInt(100),
feeCurrency1: nil,
val2: big.NewInt(45),
feeCurrency2: &currB,
},
wantResult: 0,
},
// Two fee currencies
{
name: "Same amount of same currency",
args: args{
val1: big.NewInt(1),
feeCurrency1: &currA,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
wantResult: 0,
}, {
name: "Different amounts of same currency 1",
args: args{
val1: big.NewInt(3),
feeCurrency1: &currA,
val2: big.NewInt(1),
feeCurrency2: &currA,
},
wantResult: 1,
}, {
name: "Different amounts of same currency 2",
args: args{
val1: big.NewInt(1),
feeCurrency1: &currA,
val2: big.NewInt(7),
feeCurrency2: &currA,
},
wantResult: -1,
},
// Unregistered fee currency
{
name: "Different amounts of different currencies",
args: args{
val1: big.NewInt(1),
feeCurrency1: &currA,
val2: big.NewInt(1),
feeCurrency2: &currX,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CompareValue(exchangeRates, tt.args.val1, tt.args.feeCurrency1, tt.args.val2, tt.args.feeCurrency2)

if tt.wantErr && err == nil {
t.Error("Expected error in celoContextImpl.CompareFees")
}
if got != tt.wantResult {
t.Errorf("celoContextImpl.CompareFees() = %v, want %v", got, tt.wantResult)
}
})
}
}

func TestIsWhitelisted(t *testing.T) {
tests := []struct {
name string
feeCurrency *common.Address
want bool
}{
{
name: "no fee currency",
feeCurrency: nil,
want: true,
},
{
name: "valid fee currency",
feeCurrency: &currA,
want: true,
},
{
name: "invalid fee currency",
feeCurrency: &currX,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsWhitelisted(exchangeRates, tt.feeCurrency); got != tt.want {
t.Errorf("IsWhitelisted() = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion core/vm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ type BlockContext struct {
ExcessBlobGas *uint64 // ExcessBlobGas field in the header, needed to compute the data

// Celo specific information
ExchangeRates map[common.Address]*big.Rat
ExchangeRates common.ExchangeRates
}

// TxContext provides the EVM with information about a transaction.
Expand Down

0 comments on commit c7f30a9

Please sign in to comment.