diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index b210529854..15d47bc1f4 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -6382,7 +6382,7 @@ type AccountValue interface { // AuthAccountValue type AuthAccountValue struct { Address AddressValue - storageUsedGet func() UInt64Value + storageUsedGet func(interpreter *Interpreter) UInt64Value storageCapacityGet func() UInt64Value addPublicKeyFunction FunctionValue removePublicKeyFunction FunctionValue @@ -6391,7 +6391,7 @@ type AuthAccountValue struct { func NewAuthAccountValue( address AddressValue, - storageUsedGet func() UInt64Value, + storageUsedGet func(interpreter *Interpreter) UInt64Value, storageCapacityGet func() UInt64Value, addPublicKeyFunction FunctionValue, removePublicKeyFunction FunctionValue, @@ -6518,7 +6518,7 @@ func (v AuthAccountValue) GetMember(inter *Interpreter, _ LocationRange, name st return v.Address case "storageUsed": - return v.storageUsedGet() + return v.storageUsedGet(inter) case "storageCapacity": return v.storageCapacityGet() @@ -6568,14 +6568,14 @@ func (AuthAccountValue) SetMember(_ *Interpreter, _ LocationRange, _ string, _ V type PublicAccountValue struct { Address AddressValue - storageUsedGet func() UInt64Value + storageUsedGet func(interpreter *Interpreter) UInt64Value storageCapacityGet func() UInt64Value Identifier string } func NewPublicAccountValue( address AddressValue, - storageUsedGet func() UInt64Value, + storageUsedGet func(interpreter *Interpreter) UInt64Value, storageCapacityGet func() UInt64Value, ) PublicAccountValue { return PublicAccountValue{ @@ -6632,7 +6632,7 @@ func (v PublicAccountValue) GetMember(inter *Interpreter, _ LocationRange, name return v.Address case "storageUsed": - return v.storageUsedGet() + return v.storageUsedGet(inter) case "storageCapacity": return v.storageCapacityGet() diff --git a/runtime/runtime.go b/runtime/runtime.go index 3608855d2d..e1328b9382 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -128,7 +128,7 @@ func (r *interpreterRuntime) ExecuteScript( location Location, ) (cadence.Value, error) { - runtimeStorage := newInterpreterRuntimeStorage(runtimeInterface) + runtimeStorage := newRuntimeStorage(runtimeInterface) functions := r.standardLibraryFunctions(runtimeInterface, runtimeStorage) @@ -215,7 +215,7 @@ type interpretFunc func(inter *interpreter.Interpreter) (interpreter.Value, erro func (r *interpreterRuntime) interpret( location ast.Location, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, checker *sema.Checker, functions stdlib.StandardLibraryFunctions, options []interpreter.Option, @@ -261,11 +261,11 @@ func (r *interpreterRuntime) interpret( func (r *interpreterRuntime) newAuthAccountValue( addressValue interpreter.AddressValue, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) interpreter.AuthAccountValue { return interpreter.NewAuthAccountValue( addressValue, - storageUsedGetFunction(addressValue, runtimeInterface), + storageUsedGetFunction(addressValue, runtimeInterface, runtimeStorage), storageCapacityGetFunction(addressValue, runtimeInterface), r.newAddPublicKeyFunction(addressValue, runtimeInterface), r.newRemovePublicKeyFunction(addressValue, runtimeInterface), @@ -283,7 +283,7 @@ func (r *interpreterRuntime) ExecuteTransaction( runtimeInterface Interface, location Location, ) error { - runtimeStorage := newInterpreterRuntimeStorage(runtimeInterface) + runtimeStorage := newRuntimeStorage(runtimeInterface) functions := r.standardLibraryFunctions(runtimeInterface, runtimeStorage) @@ -487,7 +487,7 @@ func validateArgumentParams( // ParseAndCheckProgram parses the given script and runs type check. func (r *interpreterRuntime) ParseAndCheckProgram(code []byte, runtimeInterface Interface, location Location) (*sema.Checker, error) { - runtimeStorage := newInterpreterRuntimeStorage(runtimeInterface) + runtimeStorage := newRuntimeStorage(runtimeInterface) functions := r.standardLibraryFunctions(runtimeInterface, runtimeStorage) checker, err := r.parseAndCheckProgram(code, runtimeInterface, location, functions, nil, true) @@ -618,7 +618,7 @@ func (r *interpreterRuntime) newInterpreter( checker *sema.Checker, functions stdlib.StandardLibraryFunctions, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, options []interpreter.Option, ) (*interpreter.Interpreter, error) { @@ -709,7 +709,7 @@ func (r *interpreterRuntime) importLocationHandler(runtimeInterface Interface) i func (r *interpreterRuntime) injectedCompositeFieldsHandler( runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) interpreter.InjectedCompositeFieldsHandlerFunc { return func( _ *interpreter.Interpreter, @@ -746,7 +746,7 @@ func (r *interpreterRuntime) injectedCompositeFieldsHandler( } } -func (r *interpreterRuntime) storageInterpreterOptions(runtimeStorage *interpreterRuntimeStorage) []interpreter.Option { +func (r *interpreterRuntime) storageInterpreterOptions(runtimeStorage *runtimeStorage) []interpreter.Option { return []interpreter.Option{ interpreter.WithStorageExistenceHandler( func(_ *interpreter.Interpreter, address common.Address, key string) bool { @@ -814,12 +814,12 @@ func (r *interpreterRuntime) meteringInterpreterOptions(runtimeInterface Interfa func (r *interpreterRuntime) standardLibraryFunctions( runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) stdlib.StandardLibraryFunctions { return append( stdlib.FlowBuiltInFunctions(stdlib.FlowBuiltinImpls{ CreateAccount: r.newCreateAccountFunction(runtimeInterface, runtimeStorage), - GetAccount: r.newGetAccountFunction(runtimeInterface), + GetAccount: r.newGetAccountFunction(runtimeInterface, runtimeStorage), Log: r.newLogFunction(runtimeInterface), GetCurrentBlock: r.newGetCurrentBlockFunction(runtimeInterface), GetBlock: r.newGetBlockFunction(runtimeInterface), @@ -961,7 +961,7 @@ func CodeToHashValue(code []byte) *interpreter.ArrayValue { func (r *interpreterRuntime) newCreateAccountFunction( runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) interpreter.HostFunction { return func(invocation interpreter.Invocation) trampoline.Trampoline { payer, ok := invocation.Arguments[0].(interpreter.AuthAccountValue) @@ -996,9 +996,19 @@ func (r *interpreterRuntime) newCreateAccountFunction( return trampoline.Done{Result: account} } } -func storageUsedGetFunction(addressValue interpreter.AddressValue, runtimeInterface Interface) func() interpreter.UInt64Value { + +func storageUsedGetFunction( + addressValue interpreter.AddressValue, + runtimeInterface Interface, + runtimeStorage *runtimeStorage, +) func(inter *interpreter.Interpreter) interpreter.UInt64Value { address := addressValue.ToAddress() - return func() interpreter.UInt64Value { + return func(inter *interpreter.Interpreter) interpreter.UInt64Value { + + // NOTE: flush the cached values, so the host environment + // can properly calculate the amount of storage used by the account + runtimeStorage.writeCached(inter) + var capacity uint64 var err error wrapPanic(func() { @@ -1096,7 +1106,7 @@ func (r *interpreterRuntime) newRemovePublicKeyFunction( } func (r *interpreterRuntime) writeContract( - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, addressValue interpreter.AddressValue, name string, contractValue interpreter.OptionalValue, @@ -1121,7 +1131,7 @@ func (r *interpreterRuntime) loadContract( constructor interpreter.FunctionValue, invocationRange ast.Range, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) *interpreter.CompositeValue { switch compositeType.Location { @@ -1169,7 +1179,7 @@ func (r *interpreterRuntime) instantiateContract( constructorArguments []interpreter.Value, argumentTypes []sema.Type, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, checker *sema.Checker, functions stdlib.StandardLibraryFunctions, invocationRange ast.Range, @@ -1277,11 +1287,12 @@ func (r *interpreterRuntime) instantiateContract( return contract, err } -func (r *interpreterRuntime) newGetAccountFunction(runtimeInterface Interface) interpreter.HostFunction { +func (r *interpreterRuntime) newGetAccountFunction(runtimeInterface Interface, runtimeStorage *runtimeStorage) interpreter.HostFunction { return func(invocation interpreter.Invocation) trampoline.Trampoline { accountAddress := invocation.Arguments[0].(interpreter.AddressValue) - publicAccount := interpreter.NewPublicAccountValue(accountAddress, - storageUsedGetFunction(accountAddress, runtimeInterface), + publicAccount := interpreter.NewPublicAccountValue( + accountAddress, + storageUsedGetFunction(accountAddress, runtimeInterface, runtimeStorage), storageCapacityGetFunction(accountAddress, runtimeInterface), ) return trampoline.Done{Result: publicAccount} @@ -1369,7 +1380,7 @@ func (r *interpreterRuntime) newUnsafeRandomFunction(runtimeInterface Interface) func (r *interpreterRuntime) newAuthAccountContracts( addressValue interpreter.AddressValue, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) interpreter.AuthAccountContractsValue { return interpreter.AuthAccountContractsValue{ Address: addressValue, @@ -1383,7 +1394,7 @@ func (r *interpreterRuntime) newAuthAccountContracts( func (r *interpreterRuntime) newAuthAccountContractsChangeFunction( addressValue interpreter.AddressValue, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, isUpdate bool, ) interpreter.HostFunctionValue { return interpreter.NewHostFunctionValue( @@ -1564,7 +1575,7 @@ type updateAccountContractCodeOptions struct { // func (r *interpreterRuntime) updateAccountContractCode( runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, name string, code []byte, addressValue interpreter.AddressValue, @@ -1674,7 +1685,7 @@ func (r *interpreterRuntime) newAuthAccountContractsGetFunction( func (r *interpreterRuntime) newAuthAccountContractsRemoveFunction( addressValue interpreter.AddressValue, runtimeInterface Interface, - runtimeStorage *interpreterRuntimeStorage, + runtimeStorage *runtimeStorage, ) interpreter.HostFunctionValue { return interpreter.NewHostFunctionValue( func(invocation interpreter.Invocation) trampoline.Trampoline { diff --git a/runtime/runtime_storage.go b/runtime/runtime_storage.go index 7755e83a5d..a0b323e9be 100644 --- a/runtime/runtime_storage.go +++ b/runtime/runtime_storage.go @@ -41,21 +41,21 @@ type CacheEntry struct { Value interpreter.Value } -type interpreterRuntimeStorage struct { +type runtimeStorage struct { runtimeInterface Interface highLevelStorageEnabled bool highLevelStorage HighLevelStorage cache Cache } -func newInterpreterRuntimeStorage(runtimeInterface Interface) *interpreterRuntimeStorage { +func newRuntimeStorage(runtimeInterface Interface) *runtimeStorage { highLevelStorageEnabled := false highLevelStorage, ok := runtimeInterface.(HighLevelStorage) if ok { highLevelStorageEnabled = highLevelStorage.HighLevelStorageEnabled() } - return &interpreterRuntimeStorage{ + return &runtimeStorage{ runtimeInterface: runtimeInterface, cache: Cache{}, highLevelStorage: highLevelStorage, @@ -71,7 +71,7 @@ func newInterpreterRuntimeStorage(runtimeInterface Interface) *interpreterRuntim // If there is a cache miss, the key is read from storage through the runtime interface, // places in the cache, and returned. // -func (s *interpreterRuntimeStorage) valueExists( +func (s *runtimeStorage) valueExists( address common.Address, key string, ) bool { @@ -116,7 +116,7 @@ func (s *interpreterRuntimeStorage) valueExists( // If there is a cache miss, the key is read from storage through the runtime interface, // places in the cache, and returned. // -func (s *interpreterRuntimeStorage) readValue( +func (s *runtimeStorage) readValue( address common.Address, key string, deferred bool, @@ -192,7 +192,7 @@ func (s *interpreterRuntimeStorage) readValue( // It does *not* serialize/save the value in storage (through the runtime interface). // (The Cache is finally written back through the runtime interface in `writeCached`.) // -func (s *interpreterRuntimeStorage) writeValue( +func (s *runtimeStorage) writeValue( address common.Address, key string, value interpreter.OptionalValue, @@ -224,7 +224,7 @@ func (s *interpreterRuntimeStorage) writeValue( // writeCached serializes/saves all values in the cache in storage (through the runtime interface). // -func (s *interpreterRuntimeStorage) writeCached(inter *interpreter.Interpreter) { +func (s *runtimeStorage) writeCached(inter *interpreter.Interpreter) { type writeItem struct { storageKey StorageKey @@ -321,7 +321,7 @@ func (s *interpreterRuntimeStorage) writeCached(inter *interpreter.Interpreter) } } -func (s *interpreterRuntimeStorage) encodeValue( +func (s *runtimeStorage) encodeValue( value interpreter.Value, path string, ) ( @@ -341,7 +341,7 @@ func (s *interpreterRuntimeStorage) encodeValue( return } -func (s *interpreterRuntimeStorage) move( +func (s *runtimeStorage) move( oldOwner common.Address, oldKey string, newOwner common.Address, newKey string, ) { diff --git a/runtime/storage_test.go b/runtime/storage_test.go index f3b5ba881b..848d9134e1 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -289,3 +289,59 @@ func TestRuntimeMagic(t *testing.T) { writes, ) } + +func TestAccountStorageStorage(t *testing.T) { + runtime := NewInterpreterRuntime() + + script := []byte(` + transaction { + prepare(signer: AuthAccount) { + let before = signer.storageUsed + signer.save(42, to: /storage/answer) + let after = signer.storageUsed + log(after != before) + } + } + `) + + var loggedMessages []string + + storage := newTestStorage(nil, nil) + + runtimeInterface := &testRuntimeInterface{ + storage: storage, + getSigningAccounts: func() ([]Address, error) { + return []Address{{42}}, nil + }, + getStorageUsed: func(_ Address) (uint64, error) { + var amount uint64 = 0 + + for _, data := range storage.storedValues { + amount += uint64(len(data)) + } + + return amount, nil + }, + log: func(message string) { + loggedMessages = append(loggedMessages, message) + }, + } + + nextTransactionLocation := newTransactionLocationGenerator() + + err := runtime.ExecuteTransaction( + Script{ + Source: script, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + require.Equal(t, + []string{"true"}, + loggedMessages, + ) +} diff --git a/runtime/tests/interpreter/account_test.go b/runtime/tests/interpreter/account_test.go index 90ae6870e3..4f28e13c00 100644 --- a/runtime/tests/interpreter/account_test.go +++ b/runtime/tests/interpreter/account_test.go @@ -55,7 +55,9 @@ func testAccount(t *testing.T, auth bool, code string) (*interpreter.Interpreter values["authAccount"] = interpreter.NewAuthAccountValue( address, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, panicFunction, panicFunction, @@ -73,7 +75,9 @@ func testAccount(t *testing.T, auth bool, code string) (*interpreter.Interpreter values["pubAccount"] = interpreter.NewPublicAccountValue( address, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, ) diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index cd0ded5f3b..48d33b147c 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -6615,7 +6615,9 @@ func TestInterpretContractAccountFieldUse(t *testing.T) { return map[string]interpreter.Value{ "account": interpreter.NewAuthAccountValue( addressValue, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, panicFunction, panicFunction, @@ -7360,7 +7362,9 @@ func TestInterpretResourceOwnerFieldUse(t *testing.T) { values := map[string]interpreter.Value{ "account": interpreter.NewAuthAccountValue( addressValue, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, panicFunction, panicFunction, diff --git a/runtime/tests/interpreter/transactions_test.go b/runtime/tests/interpreter/transactions_test.go index 29eb5024ef..fd543e6c5b 100644 --- a/runtime/tests/interpreter/transactions_test.go +++ b/runtime/tests/interpreter/transactions_test.go @@ -221,7 +221,9 @@ func TestInterpretTransactions(t *testing.T) { signer1 := interpreter.NewAuthAccountValue( interpreter.AddressValue{0, 0, 0, 0, 0, 0, 0, 1}, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, panicFunction, panicFunction, @@ -229,7 +231,9 @@ func TestInterpretTransactions(t *testing.T) { ) signer2 := interpreter.NewAuthAccountValue( interpreter.AddressValue{0, 0, 0, 0, 0, 0, 0, 2}, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, panicFunction, panicFunction, @@ -268,7 +272,9 @@ func TestInterpretTransactions(t *testing.T) { prepareArguments := []interpreter.Value{ interpreter.NewAuthAccountValue( interpreter.AddressValue{}, - returnZero, + func(interpreter *interpreter.Interpreter) interpreter.UInt64Value { + return 0 + }, returnZero, panicFunction, panicFunction,