From bac405b208430f83a73d176f68e9165b7bf710dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Mon, 2 Dec 2024 08:53:17 +0000 Subject: [PATCH] return sql.ErrNotFound if account doesn't exist (#6508) ## Motivation The decision how to handle a situation when the queried account doesn't exist should be left to the caller - it's not the responsibility of the code accessing the cache/database. --- common/types/address.go | 6 ---- common/types/hashes.go | 12 ------- common/types/testutil.go | 10 ++++++ genvm/core/context.go | 6 +++- genvm/core/context_test.go | 6 ++-- genvm/core/staged_cache.go | 8 ++++- genvm/core/staged_cache_test.go | 1 + genvm/core/types.go | 7 ++++ genvm/rewards.go | 6 +++- genvm/vm.go | 17 +++++++--- genvm/vm_test.go | 10 ++++-- node/relay.go | 2 +- sql/accounts/accounts.go | 24 +++++--------- sql/accounts/accounts_test.go | 57 +++++++++++++++++++++++++++++++++ 14 files changed, 124 insertions(+), 48 deletions(-) diff --git a/common/types/address.go b/common/types/address.go index bc631e77cc..d2252855b2 100644 --- a/common/types/address.go +++ b/common/types/address.go @@ -112,12 +112,6 @@ func (a Address) String() string { return result } -// Format implements fmt.Formatter, forcing the byte slice to be formatted as is, -// without going through the stringer interface used for logging. -func (a Address) Format(s fmt.State, c rune) { - fmt.Fprintf(s, "%"+string(c), a[:]) -} - // EncodeScale implements scale codec interface. func (a *Address) EncodeScale(e *scale.Encoder) (int, error) { return scale.EncodeByteArray(e, a[:]) diff --git a/common/types/hashes.go b/common/types/hashes.go index 179f1384c0..562a702a01 100644 --- a/common/types/hashes.go +++ b/common/types/hashes.go @@ -44,12 +44,6 @@ func (h Hash20) ShortString() string { return hex.EncodeToString(h[:5]) } -// Format implements fmt.Formatter, forcing the byte slice to be formatted as is, -// without going through the stringer interface used for logging. -func (h Hash20) Format(s fmt.State, c rune) { - fmt.Fprintf(s, "%"+string(c), h[:]) -} - // UnmarshalText parses a hash in hex syntax. func (h *Hash20) UnmarshalText(input []byte) error { if err := util.UnmarshalFixedText("Hash", input, h[:]); err != nil { @@ -163,12 +157,6 @@ func (h Hash32) ShortString() string { return hex.EncodeToString(h[:5]) } -// Format implements fmt.Formatter, forcing the byte slice to be formatted as is, -// without going through the stringer interface used for logging. -func (h Hash32) Format(s fmt.State, c rune) { - fmt.Fprintf(s, "%"+string(c), h[:]) -} - // UnmarshalText parses a hash in hex syntax. func (h *Hash32) UnmarshalText(input []byte) error { if err := util.UnmarshalFixedText("Hash", input, h[:]); err != nil { diff --git a/common/types/testutil.go b/common/types/testutil.go index 50d11d481f..d03f56dad6 100644 --- a/common/types/testutil.go +++ b/common/types/testutil.go @@ -2,6 +2,9 @@ package types import ( "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" ) // RandomBytes generates random data in bytes for testing. @@ -137,3 +140,10 @@ func RandomVrfSignature() VrfSignature { } return VrfSignature(b) } + +func RandomAddress(tb testing.TB) Address { + var a Address + _, err := rand.Read(a[:]) + require.NoError(tb, err) + return a +} diff --git a/genvm/core/context.go b/genvm/core/context.go index e5f46a1e41..0718ce2702 100644 --- a/genvm/core/context.go +++ b/genvm/core/context.go @@ -2,6 +2,7 @@ package core import ( "bytes" + "errors" "fmt" "github.com/spacemeshos/go-scale" @@ -235,7 +236,10 @@ func (c *Context) load(address types.Address) (*Account, error) { account, exist := c.changed[address] if !exist { loaded, err := c.Loader.Get(address) - if err != nil { + switch { + case errors.Is(err, ErrNotFound): + loaded = types.Account{Address: address} + case err != nil: return nil, fmt.Errorf("%w: %w", ErrInternal, err) } account = &loaded diff --git a/genvm/core/context_test.go b/genvm/core/context_test.go index 90f861f65d..0f35671c15 100644 --- a/genvm/core/context_test.go +++ b/genvm/core/context_test.go @@ -193,10 +193,8 @@ func TestRelay(t *testing.T) { require.Equal(t, amount1, int(rec1state.Balance)) require.NotEqual(t, encoded, rec1state.State) - rec2state, err := cache.Get(receiver2) - require.NoError(t, err) - require.Equal(t, 0, int(rec2state.Balance)) - require.NotEqual(t, encoded, rec2state.State) + _, err = cache.Get(receiver2) + require.ErrorIs(t, err, core.ErrNotFound) // relay to receiver2 failed remoteState, err := cache.Get(remote) require.NoError(t, err) diff --git a/genvm/core/staged_cache.go b/genvm/core/staged_cache.go index c116cd8a8f..885a2efd24 100644 --- a/genvm/core/staged_cache.go +++ b/genvm/core/staged_cache.go @@ -1,6 +1,8 @@ package core import ( + "errors" + "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" @@ -11,7 +13,11 @@ type DBLoader struct { } func (db DBLoader) Get(address types.Address) (types.Account, error) { - return accounts.Latest(db.Executor, address) + account, err := accounts.Latest(db.Executor, address) + if errors.Is(err, sql.ErrNotFound) { + return types.Account{}, ErrNotFound + } + return account, err } // NewStagedCache returns instance of the staged cache. diff --git a/genvm/core/staged_cache_test.go b/genvm/core/staged_cache_test.go index 4153b626de..d4f29043ca 100644 --- a/genvm/core/staged_cache_test.go +++ b/genvm/core/staged_cache_test.go @@ -13,6 +13,7 @@ func TestCacheGetCopies(t *testing.T) { db := statesql.InMemoryTest(t) ss := core.NewStagedCache(core.DBLoader{db}) address := core.Address{1} + ss.Update(core.Account{Address: address}) account, err := ss.Get(address) require.NoError(t, err) account.Balance = 100 diff --git a/genvm/core/types.go b/genvm/core/types.go index 7c56487754..eb51824747 100644 --- a/genvm/core/types.go +++ b/genvm/core/types.go @@ -1,6 +1,8 @@ package core import ( + "errors" + "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" @@ -76,8 +78,13 @@ type Template interface { Verify(Host, []byte, *scale.Decoder) bool } +var ErrNotFound = errors.New("not found") + // AccountLoader is an interface for loading accounts. type AccountLoader interface { + // Get account for given address + // + // Returns ErrNotFound if the account doesn't exist. Get(Address) (Account, error) } diff --git a/genvm/rewards.go b/genvm/rewards.go index b84e82ccd8..3fed84fe97 100644 --- a/genvm/rewards.go +++ b/genvm/rewards.go @@ -1,6 +1,7 @@ package vm import ( + "errors" "fmt" "math/big" @@ -67,7 +68,10 @@ func (v *VM) addRewards( } result = append(result, reward) account, err := ss.Get(blockReward.Coinbase) - if err != nil { + switch { + case errors.Is(err, core.ErrNotFound): + account = types.Account{Address: blockReward.Coinbase} + case err != nil: return nil, fmt.Errorf("%w: %w", core.ErrInternal, err) } account.Balance += reward.TotalReward diff --git a/genvm/vm.go b/genvm/vm.go index 5e8671513d..652d63a989 100644 --- a/genvm/vm.go +++ b/genvm/vm.go @@ -156,7 +156,10 @@ func (v *VM) AccountExists(address core.Address) (bool, error) { // GetNonce returns expected next nonce for the address. func (v *VM) GetNonce(address core.Address) (core.Nonce, error) { account, err := accounts.Latest(v.db, address) - if err != nil { + switch { + case errors.Is(err, sql.ErrNotFound): + return 0, nil + case err != nil: return 0, err } return account.NextNonce, nil @@ -165,7 +168,10 @@ func (v *VM) GetNonce(address core.Address) (core.Nonce, error) { // GetBalance returns balance for an address. func (v *VM) GetBalance(address types.Address) (uint64, error) { account, err := accounts.Latest(v.db, address) - if err != nil { + switch { + case errors.Is(err, sql.ErrNotFound): + return 0, nil + case err != nil: return 0, err } return account.Balance, nil @@ -501,9 +507,12 @@ func parse( return nil, nil, nil, fmt.Errorf("%w: failed to decode method selector %w", core.ErrMalformed, err) } account, err := loader.Get(principal) - if err != nil { + switch { + case errors.Is(err, core.ErrNotFound): + account = types.Account{Address: principal} + case err != nil: return nil, nil, nil, fmt.Errorf( - "%w: failed load state for principal %s - %w", + "%w: failed load state for principal %s: %w", core.ErrInternal, principal, err, diff --git a/genvm/vm_test.go b/genvm/vm_test.go index 54e2943e13..7dc2785130 100644 --- a/genvm/vm_test.go +++ b/genvm/vm_test.go @@ -34,6 +34,7 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/templates/wallet" "github.com/spacemeshos/go-spacemesh/hash" "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/statesql" @@ -1441,9 +1442,13 @@ func runTestCases(t *testing.T, tcs []templateTestCase, genTester func(t *testin } for account, changes := range layer.expected { prev, err := accounts.Get(tt.db, tt.accounts[account].getAddress(), lid.Sub(1)) - require.NoError(tt, err) + if err != nil { + require.ErrorIs(t, err, sql.ErrNotFound) + } current, err := accounts.Get(tt.db, tt.accounts[account].getAddress(), lid) - require.NoError(tt, err) + if err != nil { + require.ErrorIs(t, err, sql.ErrNotFound) + } tt.Logf("verifying account index=%d in layer index=%d", account, i) changes.verify(tt, &prev, ¤t) } @@ -1652,6 +1657,7 @@ func testValidation(t *testing.T, tt *tester, template core.Address) { if tc.err != nil { require.ErrorIs(t, err, tc.err) } else { + require.NoError(t, err) require.Equal(t, tc.verified, req.Verify()) if tc.verified { require.Equal(t, tc.header, header) diff --git a/node/relay.go b/node/relay.go index 7088fbc03a..938ab2698b 100644 --- a/node/relay.go +++ b/node/relay.go @@ -46,7 +46,7 @@ func runRelay(ctx context.Context, cfg *config.Config) error { types.SetLayersPerEpoch(cfg.LayersPerEpoch) prologue := fmt.Sprintf("%x-%v", - cfg.Genesis.GenesisID(), + cfg.Genesis.GenesisID().Bytes(), types.GetEffectiveGenesis(), ) // Prevent testnet nodes from working on the mainnet, but diff --git a/sql/accounts/accounts.go b/sql/accounts/accounts.go index bb04b2f50f..55f165e407 100644 --- a/sql/accounts/accounts.go +++ b/sql/accounts/accounts.go @@ -24,7 +24,7 @@ func Has(db sql.Executor, address types.Address) (bool, error) { // Latest latest account data for an address. func Latest(db sql.Executor, address types.Address) (types.Account, error) { var account types.Account - _, err := db.Exec(` + rows, err := db.Exec(` select balance, next_nonce, layer_updated, template, state from accounts where address = ?1 order by layer_updated desc;`, @@ -46,20 +46,16 @@ func Latest(db sql.Executor, address types.Address) (types.Account, error) { if err != nil { return types.Account{}, fmt.Errorf("failed to load %v: %w", address, err) } - // TODO(mafa): returning `sql.ErrNotFound` causes a bunch of tests to fail, some even panic - // this needs to be investigated and fixed - // - // if account.Address != address { - // return types.Account{}, sql.ErrNotFound - // } - account.Address = address // without this tests are failing not only assertions but are also panicking + if rows == 0 { + return types.Account{}, sql.ErrNotFound + } return account, nil } // Get account data that was valid at the specified layer. func Get(db sql.Executor, address types.Address, layer types.LayerID) (types.Account, error) { var account types.Account - _, err := db.Exec(` + rows, err := db.Exec(` select balance, next_nonce, layer_updated, template, state from accounts where address = ?1 and layer_updated <= ?2 order by layer_updated desc;`, @@ -84,13 +80,9 @@ func Get(db sql.Executor, address types.Address, layer types.LayerID) (types.Acc if err != nil { return types.Account{}, fmt.Errorf("failed to load %v for layer %v: %w", address, layer, err) } - // TODO(mafa): returning `sql.ErrNotFound` causes a bunch of tests to fail, some even panic - // this needs to be investigated and fixed - // - // if account.Address != address { - // return types.Account{}, sql.ErrNotFound - // } - account.Address = address // without this tests are failing not only assertions but are also panicking + if rows == 0 { + return types.Account{}, sql.ErrNotFound + } return account, nil } diff --git a/sql/accounts/accounts_test.go b/sql/accounts/accounts_test.go index ebc1554337..10c25f8ad8 100644 --- a/sql/accounts/accounts_test.go +++ b/sql/accounts/accounts_test.go @@ -48,6 +48,63 @@ func TestHas(t *testing.T) { require.True(t, has) } +func TestLatest(t *testing.T) { + t.Run("doesn't exist", func(t *testing.T) { + db := statesql.InMemoryTest(t) + account, err := Latest(db, types.RandomAddress(t)) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Empty(t, account) + }) + t.Run("picks latest", func(t *testing.T) { + address := types.RandomAddress(t) + db := statesql.InMemoryTest(t) + err := Update(db, &types.Account{ + Address: address, + }) + require.NoError(t, err) + account := types.Account{ + Layer: 1, + NextNonce: 1, + Balance: 100, + Address: address, + } + err = Update(db, &account) + require.NoError(t, err) + + got, err := Latest(db, address) + require.NoError(t, err) + require.Equal(t, account, got) + }) +} + +func TestGet(t *testing.T) { + t.Run("doesn't exist", func(t *testing.T) { + db := statesql.InMemoryTest(t) + account, err := Get(db, types.RandomAddress(t), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Empty(t, account) + }) + t.Run("picks the right one", func(t *testing.T) { + address := types.RandomAddress(t) + db := statesql.InMemoryTest(t) + account := types.Account{ + Layer: 1, + NextNonce: 1, + Balance: 100, + Address: address, + } + err := Update(db, &account) + require.NoError(t, err) + + _, err = Get(db, address, 0) + require.ErrorIs(t, err, sql.ErrNotFound) + + got, err := Get(db, address, 1) + require.NoError(t, err) + require.Equal(t, account, got) + }) +} + func TestRevert(t *testing.T) { address := types.Address{1, 1} seq := genSeq(address, 10)