diff --git a/runtime/sharedstate_test.go b/runtime/sharedstate_test.go index 52dd68f00..b2cd5cc49 100644 --- a/runtime/sharedstate_test.go +++ b/runtime/sharedstate_test.go @@ -255,12 +255,18 @@ func TestRuntimeSharedState(t *testing.T) { test( true, []ownerKeyPair{ - // Read account domain register to check if it is a migrated account + // Read account register to check if it is a migrated account // Read returns no value. { owner: signerAddress[:], key: []byte(AccountStorageKey), }, + // Read contract domain register. + // Read returns no value. + { + owner: signerAddress[:], + key: []byte(common.StorageDomainContract.Identifier()), + }, // Read all available domain registers to check if it is a new account // Read returns no value. { diff --git a/runtime/storage.go b/runtime/storage.go index 49aaa92b4..90d177b96 100644 --- a/runtime/storage.go +++ b/runtime/storage.go @@ -40,6 +40,14 @@ type StorageConfig struct { StorageFormatV2Enabled bool } +type storageFormat uint8 + +const ( + storageFormatUnknown storageFormat = iota + storageFormatV1 + storageFormatV2 +) + type Storage struct { *atree.PersistentSlabStorage @@ -168,19 +176,29 @@ func (s *Storage) GetDomainStorageMap( } }() - if !s.Config.StorageFormatV2Enabled || s.IsV1Account(address) { - domainStorageMap = s.AccountStorageV1.GetDomainStorageMap( + if !s.Config.StorageFormatV2Enabled { + + // When StorageFormatV2 is disabled, handle all accounts as v1 accounts. + + // Only read requested domain register. + + domainStorageMap = s.getDomainStorageMapForV1Account( address, domain, createIfNotExists, ) - if domainStorageMap != nil { - s.cacheIsV1Account(address, true) - } + return + } - } else { - domainStorageMap = s.AccountStorageV2.GetDomainStorageMap( + // StorageFormatV2 is enabled. + + // Check if cached account format is available. + + cachedFormat, known := s.getCachedAccountFormat(address) + if known { + return s.getDomainStorageMap( + cachedFormat, inter, address, domain, @@ -188,47 +206,165 @@ func (s *Storage) GetDomainStorageMap( ) } + // Check if account is v2 (by reading "stored" register). + + if s.isV2Account(address) { + return s.getDomainStorageMapForV2Account( + inter, + address, + domain, + createIfNotExists, + ) + } + + // Check if account is v1 (by reading requested domain register). + + if s.hasDomainRegister(address, domain) { + return s.getDomainStorageMapForV1Account( + address, + domain, + createIfNotExists, + ) + } + + // Domain register doesn't exist. + + // Return early if !createIfNotExists to avoid more register reading. + + if !createIfNotExists { + return nil + } + + // At this point, account is either new account or v1 account without requested domain register. + + // Check if account is v1 (by reading more domain registers) + + if s.isV1Account(address) { + return s.getDomainStorageMapForV1Account( + address, + domain, + createIfNotExists, + ) + } + + // New account is treated as v2 account when feature flag is enabled. + + return s.getDomainStorageMapForV2Account( + inter, + address, + domain, + createIfNotExists, + ) +} + +func (s *Storage) getDomainStorageMapForV1Account( + address common.Address, + domain common.StorageDomain, + createIfNotExists bool, +) *interpreter.DomainStorageMap { + domainStorageMap := s.AccountStorageV1.GetDomainStorageMap( + address, + domain, + createIfNotExists, + ) + + s.cacheIsV1Account(address, true) + return domainStorageMap } -// IsV1Account returns true if given account is in account storage format v1. -func (s *Storage) IsV1Account(address common.Address) (isV1 bool) { +func (s *Storage) getDomainStorageMapForV2Account( + inter *interpreter.Interpreter, + address common.Address, + domain common.StorageDomain, + createIfNotExists bool, +) *interpreter.DomainStorageMap { + domainStorageMap := s.AccountStorageV2.GetDomainStorageMap( + inter, + address, + domain, + createIfNotExists, + ) - // Check cache + s.cacheIsV1Account(address, false) - if isV1, present := s.cachedV1Accounts[address]; present { - return isV1 - } + return domainStorageMap +} - // Cache result +func (s *Storage) getDomainStorageMap( + format storageFormat, + inter *interpreter.Interpreter, + address common.Address, + domain common.StorageDomain, + createIfNotExists bool, +) *interpreter.DomainStorageMap { + switch format { - defer func() { - s.cacheIsV1Account(address, isV1) - }() + case storageFormatV1: + return s.getDomainStorageMapForV1Account( + address, + domain, + createIfNotExists, + ) + + case storageFormatV2: + return s.getDomainStorageMapForV2Account( + inter, + address, + domain, + createIfNotExists, + ) - // First check if account storage map exists. - // In that case the account was already migrated to account storage format v2, - // and we do not need to check the domain storage map registers. + default: + panic(errors.NewUnreachableError()) + } +} +func (s *Storage) getCachedAccountFormat(address common.Address) (format storageFormat, known bool) { + isV1, cached := s.cachedV1Accounts[address] + if !cached { + return storageFormatUnknown, false + } + if isV1 { + return storageFormatV1, true + } else { + return storageFormatV2, true + } +} + +// isV2Account returns true if given account is in account storage format v2. +func (s *Storage) isV2Account(address common.Address) bool { accountStorageMapExists, err := hasAccountStorageMap(s.Ledger, address) if err != nil { panic(err) } - if accountStorageMapExists { - return false + + return accountStorageMapExists +} + +// hasDomainRegister returns true if given account has given domain register. +// NOTE: account storage format v1 has domain registers. +func (s *Storage) hasDomainRegister(address common.Address, domain common.StorageDomain) (domainExists bool) { + _, domainExists, err := readSlabIndexFromRegister( + s.Ledger, + address, + []byte(domain.Identifier()), + ) + if err != nil { + panic(err) } + return domainExists +} + +// isV1Account returns true if given account is in account storage format v1 +// by checking if any of the domain registers exist. +func (s *Storage) isV1Account(address common.Address) (isV1 bool) { + // Check if a storage map register exists for any of the domains. // Check the most frequently used domains first, such as storage, public, private. for _, domain := range common.AllStorageDomains { - _, domainExists, err := readSlabIndexFromRegister( - s.Ledger, - address, - []byte(domain.Identifier()), - ) - if err != nil { - panic(err) - } + domainExists := s.hasDomainRegister(address, domain) if domainExists { return true } @@ -500,7 +636,7 @@ func (s *Storage) CheckHealth() error { // Only accounts in storage format v1 store domain storage maps // directly at the root of the account - if !s.IsV1Account(address) { + if !s.isV1Account(address) { continue } diff --git a/runtime/storage_test.go b/runtime/storage_test.go index f0bbaef75..470510149 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -8156,70 +8156,30 @@ func TestGetDomainStorageMapRegisterReadsForNewAccount(t *testing.T) { owner: address[:], key: []byte(AccountStorageKey), }, - // Read all available domain registers to check if it is a new account - // Read returns no value. - { - owner: address[:], - key: []byte(common.PathDomainStorage.Identifier()), - }, - { - owner: address[:], - key: []byte(common.PathDomainPrivate.Identifier()), - }, - { - owner: address[:], - key: []byte(common.PathDomainPublic.Identifier()), - }, - { - owner: address[:], - key: []byte(common.StorageDomainContract.Identifier()), - }, - { - owner: address[:], - key: []byte(common.StorageDomainInbox.Identifier()), - }, - { - owner: address[:], - key: []byte(common.StorageDomainCapabilityController.Identifier()), - }, - { - owner: address[:], - key: []byte(common.StorageDomainCapabilityControllerTag.Identifier()), - }, - { - owner: address[:], - key: []byte(common.StorageDomainPathCapability.Identifier()), - }, - { - owner: address[:], - key: []byte(common.StorageDomainAccountCapability.Identifier()), - }, - // Try to read account register to create account storage map + // Check domain register { owner: address[:], - key: []byte(AccountStorageKey), + key: []byte(common.StorageDomainPathStorage.Identifier()), }, }, expectedReadsFor2ndGetDomainStorageMapCall: []ownerKeyPair{ - // Second GetDomainStorageMap() get cached account format v2 (cached during first GetDomainStorageMap()). + // Second GetDomainStorageMap() has the same register reading as the first GetDomainStorageMap() + // because account status can't be cached in previous call. - // Try to read account register to create account storage map + // Check if account is v2 { owner: address[:], key: []byte(AccountStorageKey), }, + // Check domain register + { + owner: address[:], + key: []byte(common.StorageDomainPathStorage.Identifier()), + }, }, expectedReadsSet: map[string]struct{}{ - concatRegisterAddressAndKey(address, []byte(AccountStorageKey)): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainPathStorage): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainPathPrivate): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainPathPublic): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainContract): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainInbox): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainCapabilityController): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainCapabilityControllerTag): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainPathCapability): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainAccountCapability): {}, + concatRegisterAddressAndKey(address, []byte(AccountStorageKey)): {}, + concatRegisterAddressAndDomain(address, common.StorageDomainPathStorage): {}, }, }, { @@ -8234,6 +8194,11 @@ func TestGetDomainStorageMapRegisterReadsForNewAccount(t *testing.T) { owner: address[:], key: []byte(AccountStorageKey), }, + // Check domain register + { + owner: address[:], + key: []byte(common.StorageDomainPathStorage.Identifier()), + }, // Check all domain registers { owner: address[:], @@ -8534,27 +8499,19 @@ func TestGetDomainStorageMapRegisterReadsForV1Account(t *testing.T) { owner: address[:], key: []byte(AccountStorageKey), }, - // Check all domain registers until existing domain register is read - { - owner: address[:], - key: []byte(common.PathDomainStorage.Identifier()), - }, - { - owner: address[:], - key: []byte(common.PathDomainPrivate.Identifier()), - }, - { - owner: address[:], - key: []byte(common.PathDomainPublic.Identifier()), - }, - // Read requested domain register + // Check domain register { owner: address[:], key: []byte(common.StorageDomainPathStorage.Identifier()), }, }, expectedReadsFor2ndGetDomainStorageMapCall: []ownerKeyPair{ - // Read requested domain register + // Check if account is v2 + { + owner: address[:], + key: []byte(AccountStorageKey), + }, + // Check domain register { owner: address[:], key: []byte(common.StorageDomainPathStorage.Identifier()), @@ -8563,8 +8520,6 @@ func TestGetDomainStorageMapRegisterReadsForV1Account(t *testing.T) { expectedReadsSet: map[string]struct{}{ concatRegisterAddressAndKey(address, []byte(AccountStorageKey)): {}, concatRegisterAddressAndDomain(address, common.StorageDomainPathStorage): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainPathPrivate): {}, - concatRegisterAddressAndDomain(address, common.StorageDomainPathPublic): {}, }, }, { @@ -8580,7 +8535,12 @@ func TestGetDomainStorageMapRegisterReadsForV1Account(t *testing.T) { owner: address[:], key: []byte(AccountStorageKey), }, - // Check all domain registers until existing domain register is read + // Check domain register + { + owner: address[:], + key: []byte(common.StorageDomainPathStorage.Identifier()), + }, + // Check all domain registers until any existing domain is checked { owner: address[:], key: []byte(common.PathDomainStorage.Identifier()), @@ -8593,7 +8553,7 @@ func TestGetDomainStorageMapRegisterReadsForV1Account(t *testing.T) { owner: address[:], key: []byte(common.PathDomainPublic.Identifier()), }, - // Read requested domain register + // Check domain register { owner: address[:], key: []byte(common.StorageDomainPathStorage.Identifier()),