diff --git a/runtime/contract_test.go b/runtime/contract_test.go index f242fb59ac..d09166b68b 100644 --- a/runtime/contract_test.go +++ b/runtime/contract_test.go @@ -26,8 +26,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/cadence/runtime/errors" "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/stdlib" + "github.com/onflow/cadence/runtime/tests/checker" . "github.com/onflow/cadence/runtime/tests/utils" "github.com/onflow/cadence" @@ -801,3 +803,288 @@ func TestRuntimeImportMultipleContracts(t *testing.T) { require.NoError(t, err) }) } + +func TestRuntimeContractTryUpdate(t *testing.T) { + t.Parallel() + + newTestRuntimeInterface := func(onUpdate func()) *testRuntimeInterface { + var actualEvents []cadence.Event + storage := newTestLedger(nil, nil) + accountCodes := map[Location][]byte{} + + return &testRuntimeInterface{ + storage: storage, + log: func(message string) {}, + emitEvent: func(event cadence.Event) error { + actualEvents = append(actualEvents, event) + return nil + }, + resolveLocation: singleIdentifierLocationResolver(t), + getSigningAccounts: func() ([]Address, error) { + return []Address{[8]byte{0, 0, 0, 0, 0, 0, 0, 1}}, nil + }, + updateAccountContractCode: func(location common.AddressLocation, code []byte) error { + onUpdate() + accountCodes[location] = code + return nil + }, + getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + } + } + + t.Run("tryUpdate simple", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + deployTx := DeploymentTransaction("Foo", []byte(`access(all) contract Foo {}`)) + + updateTx := []byte(` + transaction { + prepare(signer: AuthAccount) { + let code = "access(all) contract Foo { access(all) fun sayHello(): String {return \"hello\"} }".utf8 + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + code: code, + ) + let deployedContract = deploymentResult.deployedContract! + assert(deployedContract.name == "Foo") + assert(deployedContract.address == 0x1) + assert(deployedContract.code == code) + } + } + `) + + invokeTx := []byte(` + import Foo from 0x1 + transaction { + prepare(signer: AuthAccount) { + assert(Foo.sayHello() == "hello") + } + } + `) + + runtimeInterface := newTestRuntimeInterface(func() {}) + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy 'Foo' + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Update 'Foo' + err = rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Test the updated 'Foo' + err = rt.ExecuteTransaction( + Script{ + Source: invokeTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + }) + + t.Run("tryUpdate non existing", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + updateTx := []byte(` + transaction { + prepare(signer: AuthAccount) { + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + code: "access(all) contract Foo { access(all) fun sayHello(): String {return \"hello\"} }".utf8, + ) + assert(deploymentResult.deployedContract == nil) + } + } + `) + + invokeTx := []byte(` + import Foo from 0x1 + transaction { + prepare(signer: AuthAccount) { + assert(Foo.sayHello() == "hello") + } + } + `) + + runtimeInterface := newTestRuntimeInterface(func() {}) + nextTransactionLocation := newTransactionLocationGenerator() + + // Update non-existing 'Foo'. Should not panic. + err := rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Test the updated 'Foo'. + // Foo must not be available. + + err = rt.ExecuteTransaction( + Script{ + Source: invokeTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + RequireError(t, err) + + errs := checker.RequireCheckerErrors(t, err, 1) + var notExportedError *sema.NotExportedError + require.ErrorAs(t, errs[0], ¬ExportedError) + }) + + t.Run("tryUpdate with checking error", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + deployTx := DeploymentTransaction("Foo", []byte(`access(all) contract Foo {}`)) + + updateTx := []byte(` + transaction { + prepare(signer: AuthAccount) { + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + // Has a semantic error! + code: "access(all) contract Foo { access(all) fun sayHello(): Int { return \"hello\" } }".utf8, + ) + assert(deploymentResult.deployedContract == nil) + } + } + `) + + runtimeInterface := newTestRuntimeInterface(func() {}) + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy 'Foo' + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Update 'Foo'. + // User errors (parsing, checking and interpreting) should be handled gracefully. + + err = rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + require.NoError(t, err) + }) + + t.Run("tryUpdate panic with internal error", func(t *testing.T) { + + t.Parallel() + + rt := newTestInterpreterRuntime() + + deployTx := DeploymentTransaction("Foo", []byte(`access(all) contract Foo {}`)) + + updateTx := []byte(` + transaction { + prepare(signer: AuthAccount) { + let deploymentResult = signer.contracts.tryUpdate( + name: "Foo", + code: "access(all) contract Foo { access(all) fun sayHello(): String {return \"hello\"} }".utf8, + ) + assert(deploymentResult.deployedContract == nil) + } + } + `) + + shouldPanic := false + didPanic := false + + runtimeInterface := newTestRuntimeInterface(func() { + if shouldPanic { + didPanic = true + panic("panic during update") + } + }) + + nextTransactionLocation := newTransactionLocationGenerator() + + // Deploy 'Foo' + err := rt.ExecuteTransaction( + Script{ + Source: deployTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + assert.False(t, didPanic) + + // Update 'Foo'. + // Internal errors should NOT be handled gracefully. + + shouldPanic = true + err = rt.ExecuteTransaction( + Script{ + Source: updateTx, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + var unexpectedError errors.UnexpectedError + require.ErrorAs(t, err, &unexpectedError) + + assert.True(t, didPanic) + }) +} diff --git a/runtime/convertValues_test.go b/runtime/convertValues_test.go index eff3689143..5b79b88975 100644 --- a/runtime/convertValues_test.go +++ b/runtime/convertValues_test.go @@ -5649,3 +5649,153 @@ func TestDestroyedResourceReferenceExport(t *testing.T) { require.Error(t, err) require.ErrorAs(t, err, &interpreter.DestroyedResourceError{}) } + +func TestRuntimeDeploymentResultValueImportExport(t *testing.T) { + + t.Parallel() + + t.Run("import", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(v: DeploymentResult) {} + ` + + rt := newTestInterpreterRuntime() + runtimeInterface := &testRuntimeInterface{} + + _, err := rt.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + RequireError(t, err) + + var notImportableError *ScriptParameterTypeNotImportableError + require.ErrorAs(t, err, ¬ImportableError) + }) + + t.Run("export", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(): DeploymentResult? { + return nil + } + ` + + rt := newTestInterpreterRuntime() + runtimeInterface := &testRuntimeInterface{} + + _, err := rt.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + RequireError(t, err) + + var invalidReturnTypeError *InvalidScriptReturnTypeError + require.ErrorAs(t, err, &invalidReturnTypeError) + }) +} + +func TestRuntimeDeploymentResultTypeImportExport(t *testing.T) { + + t.Parallel() + + t.Run("import", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(v: Type) { + assert(v == Type()) + } + ` + + rt := newTestInterpreterRuntime() + + typeValue := cadence.NewTypeValue(&cadence.StructType{ + QualifiedIdentifier: "DeploymentResult", + Fields: []cadence.Field{ + { + Type: cadence.NewOptionalType(cadence.TheDeployedContractType), + Identifier: "deployedContract", + }, + }, + }) + + encodedArg, err := json.Encode(typeValue) + require.NoError(t, err) + + runtimeInterface := &testRuntimeInterface{} + + runtimeInterface.decodeArgument = func(b []byte, t cadence.Type) (value cadence.Value, err error) { + return json.Decode(runtimeInterface, b) + } + + _, err = rt.ExecuteScript( + Script{ + Source: []byte(script), + Arguments: [][]byte{encodedArg}, + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + require.NoError(t, err) + }) + + t.Run("export", func(t *testing.T) { + + t.Parallel() + + script := ` + access(all) fun main(): Type { + return Type() + } + ` + + rt := newTestInterpreterRuntime() + runtimeInterface := &testRuntimeInterface{} + + result, err := rt.ExecuteScript( + Script{ + Source: []byte(script), + }, + Context{ + Interface: runtimeInterface, + Location: common.ScriptLocation{}, + }, + ) + + require.NoError(t, err) + + assert.Equal(t, + cadence.NewTypeValue(&cadence.StructType{ + QualifiedIdentifier: "DeploymentResult", + Fields: []cadence.Field{ + { + Type: cadence.NewOptionalType(cadence.TheDeployedContractType), + Identifier: "deployedContract", + }, + }, + }), + result, + ) + }) +} diff --git a/runtime/interpreter/value_account_contracts.go b/runtime/interpreter/value_account_contracts.go index 47daddd3df..4cb21c9287 100644 --- a/runtime/interpreter/value_account_contracts.go +++ b/runtime/interpreter/value_account_contracts.go @@ -38,6 +38,7 @@ func NewAuthAccountContractsValue( address AddressValue, addFunction FunctionValue, updateFunction FunctionValue, + tryUpdateFunction FunctionValue, getFunction FunctionValue, borrowFunction FunctionValue, removeFunction FunctionValue, @@ -50,6 +51,7 @@ func NewAuthAccountContractsValue( sema.AuthAccountContractsTypeBorrowFunctionName: borrowFunction, sema.AuthAccountContractsTypeRemoveFunctionName: removeFunction, sema.AuthAccountContractsTypeUpdate__experimentalFunctionName: updateFunction, + sema.AuthAccountContractsTypeTryUpdateFunctionName: tryUpdateFunction, } computeField := func( diff --git a/runtime/interpreter/value_deployment_result.go b/runtime/interpreter/value_deployment_result.go new file mode 100644 index 0000000000..c39f2a25ed --- /dev/null +++ b/runtime/interpreter/value_deployment_result.go @@ -0,0 +1,49 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package interpreter + +import ( + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/sema" +) + +// DeploymentResult + +var deploymentResultTypeID = sema.DeploymentResultType.ID() +var deploymentResultStaticType = ConvertSemaToStaticType(nil, sema.DeploymentResultType) // unmetered +var deploymentResultFieldNames []string = nil + +func NewDeploymentResultValue( + gauge common.MemoryGauge, + deployedContract OptionalValue, +) Value { + + return NewSimpleCompositeValue( + gauge, + deploymentResultTypeID, + deploymentResultStaticType, + deploymentResultFieldNames, + map[string]Value{ + sema.DeploymentResultTypeDeployedContractFieldName: deployedContract, + }, + nil, + nil, + nil, + ) +} diff --git a/runtime/sema/authaccount.cdc b/runtime/sema/authaccount.cdc index 4814762454..b9bba93e5d 100644 --- a/runtime/sema/authaccount.cdc +++ b/runtime/sema/authaccount.cdc @@ -244,6 +244,25 @@ pub struct AuthAccount { /// Returns the deployed contract for the updated contract. pub fun update__experimental(name: String, code: [UInt8]): DeployedContract + /// Updates the code for the contract/contract interface in the account, + /// and handle any deployment errors gracefully. + /// + /// The `code` parameter is the UTF-8 encoded representation of the source code. + /// The code must contain exactly one contract or contract interface, + /// which must have the same name as the `name` parameter. + /// + /// Does **not** run the initializer of the contract/contract interface again. + /// The contract instance in the world state stays as is. + /// + /// Fails if no contract/contract interface with the given name exists in the account, + /// if the given code does not declare exactly one contract or contract interface, + /// or if the given name does not match the name of the contract/contract interface declaration in the code. + /// + /// Returns the deployment result. + /// Result would contain the deployed contract for the updated contract, if the update was successfull. + /// Otherwise, the deployed contract would be nil. + pub fun tryUpdate(name: String, code: [UInt8]): DeploymentResult + /// Returns the deployed contract for the contract/contract interface with the given name in the account, if any. /// /// Returns nil if no contract/contract interface with the given name exists in the account. diff --git a/runtime/sema/authaccount.gen.go b/runtime/sema/authaccount.gen.go index bc39b67733..a3cdb7ceaa 100644 --- a/runtime/sema/authaccount.gen.go +++ b/runtime/sema/authaccount.gen.go @@ -773,6 +773,46 @@ or if the given name does not match the name of the contract/contract interface Returns the deployed contract for the updated contract. ` +const AuthAccountContractsTypeTryUpdateFunctionName = "tryUpdate" + +var AuthAccountContractsTypeTryUpdateFunctionType = &FunctionType{ + Parameters: []Parameter{ + { + Identifier: "name", + TypeAnnotation: NewTypeAnnotation(StringType), + }, + { + Identifier: "code", + TypeAnnotation: NewTypeAnnotation(&VariableSizedType{ + Type: UInt8Type, + }), + }, + }, + ReturnTypeAnnotation: NewTypeAnnotation( + DeploymentResultType, + ), +} + +const AuthAccountContractsTypeTryUpdateFunctionDocString = ` +Updates the code for the contract/contract interface in the account, +and handle any deployment errors gracefully. + +The ` + "`code`" + ` parameter is the UTF-8 encoded representation of the source code. +The code must contain exactly one contract or contract interface, +which must have the same name as the ` + "`name`" + ` parameter. + +Does **not** run the initializer of the contract/contract interface again. +The contract instance in the world state stays as is. + +Fails if no contract/contract interface with the given name exists in the account, +if the given code does not declare exactly one contract or contract interface, +or if the given name does not match the name of the contract/contract interface declaration in the code. + +Returns the deployment result. +Result would contain the deployed contract for the updated contract, if the update was successfull. +Otherwise, the deployed contract would be nil. +` + const AuthAccountContractsTypeGetFunctionName = "get" var AuthAccountContractsTypeGetFunctionType = &FunctionType{ @@ -891,6 +931,13 @@ func init() { AuthAccountContractsTypeUpdate__experimentalFunctionType, AuthAccountContractsTypeUpdate__experimentalFunctionDocString, ), + NewUnmeteredFunctionMember( + AuthAccountContractsType, + ast.AccessPublic, + AuthAccountContractsTypeTryUpdateFunctionName, + AuthAccountContractsTypeTryUpdateFunctionType, + AuthAccountContractsTypeTryUpdateFunctionDocString, + ), NewUnmeteredFunctionMember( AuthAccountContractsType, ast.AccessPublic, diff --git a/runtime/sema/deployment_result.cdc b/runtime/sema/deployment_result.cdc new file mode 100644 index 0000000000..2f00d67553 --- /dev/null +++ b/runtime/sema/deployment_result.cdc @@ -0,0 +1,11 @@ +#compositeType +access(all) +struct DeploymentResult { + + /// The deployed contract. + /// + /// If the the deployment was unsuccessful, this will be nil. + /// + access(all) + let deployedContract: DeployedContract? +} \ No newline at end of file diff --git a/runtime/sema/deployment_result.gen.go b/runtime/sema/deployment_result.gen.go new file mode 100644 index 0000000000..9d3290e11d --- /dev/null +++ b/runtime/sema/deployment_result.gen.go @@ -0,0 +1,66 @@ +// Code generated from deployment_result.cdc. DO NOT EDIT. +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +import ( + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" +) + +const DeploymentResultTypeDeployedContractFieldName = "deployedContract" + +var DeploymentResultTypeDeployedContractFieldType = &OptionalType{ + Type: DeployedContractType, +} + +const DeploymentResultTypeDeployedContractFieldDocString = ` +The deployed contract. + +If the the deployment was unsuccessful, this will be nil. +` + +const DeploymentResultTypeName = "DeploymentResult" + +var DeploymentResultType = func() *CompositeType { + var t = &CompositeType{ + Identifier: DeploymentResultTypeName, + Kind: common.CompositeKindStructure, + importable: false, + hasComputedMembers: true, + } + + return t +}() + +func init() { + var members = []*Member{ + NewUnmeteredFieldMember( + DeploymentResultType, + ast.AccessPublic, + ast.VariableKindConstant, + DeploymentResultTypeDeployedContractFieldName, + DeploymentResultTypeDeployedContractFieldType, + DeploymentResultTypeDeployedContractFieldDocString, + ), + } + + DeploymentResultType.Members = MembersAsMap(members) + DeploymentResultType.Fields = MembersFieldNames(members) +} diff --git a/runtime/sema/deployment_result.go b/runtime/sema/deployment_result.go new file mode 100644 index 0000000000..d360f76a71 --- /dev/null +++ b/runtime/sema/deployment_result.go @@ -0,0 +1,21 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +//go:generate go run ./gen deployment_result.cdc deployment_result.gen.go diff --git a/runtime/sema/gen/main.go b/runtime/sema/gen/main.go index 4fa981605b..2767a56198 100644 --- a/runtime/sema/gen/main.go +++ b/runtime/sema/gen/main.go @@ -156,8 +156,9 @@ type typeDecl struct { } type generator struct { - typeStack []*typeDecl - decls []dst.Decl + typeStack []*typeDecl + decls []dst.Decl + leadingPragma map[string]struct{} } var _ ast.DeclarationVisitor[struct{}] = &generator{} @@ -355,15 +356,32 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ g.typeStack = g.typeStack[:lastIndex] }() - // We can generate a SimpleType declaration, - // if this is a top-level type, - // and this declaration has no nested type declarations. - // Otherwise, we have to generate a CompositeType + var generateSimpleType bool - canGenerateSimpleType := len(g.typeStack) == 1 + // Check if the declaration is explicitly marked to be generated as a composite type. + if _, ok := g.leadingPragma["compositeType"]; ok { + generateSimpleType = false + } else { + // If not, decide what to generate depending on the type. + + // We can generate a SimpleType declaration, + // if this is a top-level type, + // and this declaration has no nested type declarations. + // Otherwise, we have to generate a CompositeType + generateSimpleType = len(g.typeStack) == 1 + if generateSimpleType { + switch compositeKind { + case common.CompositeKindStructure, + common.CompositeKindResource: + break + default: + generateSimpleType = false + } + } + } for _, memberDeclaration := range decl.Members.Declarations() { - ast.AcceptDeclaration[struct{}](memberDeclaration, g) + generateDeclaration(g, memberDeclaration) // Visiting unsupported declarations panics, // so only supported member declarations are added @@ -378,14 +396,14 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ break default: - canGenerateSimpleType = false + generateSimpleType = false } } for _, conformance := range decl.Conformances { switch conformance.Identifier.Identifier { case "Storable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as storable: %s", g.currentTypeID(), @@ -394,7 +412,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.storable = true case "Equatable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as equatable: %s", g.currentTypeID(), @@ -403,7 +421,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.equatable = true case "Comparable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as comparable: %s", g.currentTypeID(), @@ -412,7 +430,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ typeDecl.comparable = true case "Exportable": - if !canGenerateSimpleType { + if !generateSimpleType { panic(fmt.Errorf( "composite types cannot be explicitly marked as exportable: %s", g.currentTypeID(), @@ -426,7 +444,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ } var typeVarDecl dst.Expr - if canGenerateSimpleType { + if generateSimpleType { typeVarDecl = simpleTypeLiteral(typeDecl) } else { typeVarDecl = compositeTypeExpr(typeDecl) @@ -449,7 +467,7 @@ func (g *generator) VisitCompositeDeclaration(decl *ast.CompositeDeclaration) (_ if len(memberDeclarations) > 0 { - if canGenerateSimpleType { + if generateSimpleType { // func init() { // t.Members = func(t *SimpleType) map[string]MemberResolver { @@ -926,14 +944,39 @@ func (*generator) VisitEnumCaseDeclaration(_ *ast.EnumCaseDeclaration) struct{} panic("enum case declarations are not supported") } -func (*generator) VisitPragmaDeclaration(_ *ast.PragmaDeclaration) struct{} { - panic("pragma declarations are not supported") +func (g *generator) VisitPragmaDeclaration(pragma *ast.PragmaDeclaration) (_ struct{}) { + // Treat pragmas as part of the declaration to follow. + + identifierExpr, ok := pragma.Expression.(*ast.IdentifierExpression) + if !ok { + panic("only identifier pragmas are supported") + } + + if g.leadingPragma == nil { + g.leadingPragma = map[string]struct{}{} + } + g.leadingPragma[identifierExpr.Identifier.Identifier] = struct{}{} + + return } func (*generator) VisitImportDeclaration(_ *ast.ImportDeclaration) struct{} { panic("import declarations are not supported") } +func generateDeclaration(gen *generator, declaration ast.Declaration) { + // Treat leading pragmas as part of this declaration. + // Reset them after finishing the current decl. This is to handle nested declarations. + if declaration.DeclarationKind() != common.DeclarationKindPragma { + prevLeadingPragma := gen.leadingPragma + defer func() { + gen.leadingPragma = prevLeadingPragma + }() + } + + _ = ast.AcceptDeclaration[struct{}](declaration, gen) +} + func (g *generator) newFullTypeName(typeName string) string { if len(g.typeStack) == 0 { return typeName diff --git a/runtime/sema/type.go b/runtime/sema/type.go index ec47eee467..347c435e29 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -3411,6 +3411,7 @@ func init() { HashAlgorithmType, StorageCapabilityControllerType, AccountCapabilityControllerType, + DeploymentResultType, }, ) @@ -7358,6 +7359,7 @@ func init() { SignatureAlgorithmType, AuthAccountType, PublicAccountType, + DeploymentResultType, } for len(compositeTypes) > 0 { diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index f56cfada2d..0fcf70d38c 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -22,6 +22,7 @@ import ( "fmt" "golang.org/x/crypto/sha3" + "golang.org/x/xerrors" "github.com/onflow/atree" @@ -270,6 +271,12 @@ func newAuthAccountContractsValue( addressValue, true, ), + newAuthAccountContractsTryUpdateFunction( + sema.AuthAccountContractsTypeTryUpdateFunctionType, + gauge, + handler, + addressValue, + ), newAccountContractsGetFunction( sema.AuthAccountContractsTypeGetFunctionType, gauge, @@ -1431,302 +1438,358 @@ func newAccountContractsBorrowFunction( ) } -type AccountContractAdditionHandler interface { - EventEmitter - AccountContractProvider - ParseAndCheckProgram( - code []byte, - location common.Location, - getAndSetProgram bool, - ) (*interpreter.Program, error) - // UpdateAccountContractCode updates the code associated with an account contract. - UpdateAccountContractCode(location common.AddressLocation, code []byte) error - RecordContractUpdate( - location common.AddressLocation, - value *interpreter.CompositeValue, - ) - ContractUpdateRecorded(location common.AddressLocation) bool - InterpretContract( - location common.AddressLocation, - program *interpreter.Program, - name string, - invocation DeployedContractConstructorInvocation, - ) ( - *interpreter.CompositeValue, - error, - ) - TemporarilyRecordCode(location common.AddressLocation, code []byte) +func changeAccountContracts( + invocation interpreter.Invocation, + handler AccountContractAdditionHandler, + addressValue interpreter.AddressValue, + isUpdate bool, +) interpreter.Value { - // StartContractAddition starts adding a contract. - StartContractAddition(location common.AddressLocation) + locationRange := invocation.LocationRange - // EndContractAddition ends adding the contract - EndContractAddition(location common.AddressLocation) + const requiredArgumentCount = 2 - // IsContractBeingAdded checks whether a contract is being added in the current execution. - IsContractBeingAdded(location common.AddressLocation) bool -} + nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) + if !ok { + panic(errors.NewUnreachableError()) + } -// newAuthAccountContractsChangeFunction called when e.g. -// - adding: `AuthAccount.contracts.add(name: "Foo", code: [...])` (isUpdate = false) -// - updating: `AuthAccount.contracts.update__experimental(name: "Foo", code: [...])` (isUpdate = true) -func newAuthAccountContractsChangeFunction( - functionType *sema.FunctionType, - gauge common.MemoryGauge, - handler AccountContractAdditionHandler, - addressValue interpreter.AddressValue, - isUpdate bool, -) *interpreter.HostFunctionValue { - return interpreter.NewHostFunctionValue( - gauge, - functionType, - func(invocation interpreter.Invocation) interpreter.Value { + newCodeValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) + if !ok { + panic(errors.NewUnreachableError()) + } - locationRange := invocation.LocationRange + constructorArguments := invocation.Arguments[requiredArgumentCount:] + constructorArgumentTypes := invocation.ArgumentTypes[requiredArgumentCount:] - const requiredArgumentCount = 2 + code, err := interpreter.ByteArrayValueToByteSlice(invocation.Interpreter, newCodeValue, locationRange) + if err != nil { + panic(errors.NewDefaultUserError("add requires the second argument to be an array")) + } - nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) - if !ok { - panic(errors.NewUnreachableError()) - } + // Get the existing code - newCodeValue, ok := invocation.Arguments[1].(*interpreter.ArrayValue) - if !ok { - panic(errors.NewUnreachableError()) - } + contractName := nameValue.Str - constructorArguments := invocation.Arguments[requiredArgumentCount:] - constructorArgumentTypes := invocation.ArgumentTypes[requiredArgumentCount:] + if contractName == "" { + panic(errors.NewDefaultUserError( + "contract name argument cannot be empty." + + "it must match the name of the deployed contract declaration or contract interface declaration", + )) + } - code, err := interpreter.ByteArrayValueToByteSlice(invocation.Interpreter, newCodeValue, locationRange) - if err != nil { - panic(errors.NewDefaultUserError("add requires the second argument to be an array")) - } + address := addressValue.ToAddress() + location := common.NewAddressLocation(invocation.Interpreter, address, contractName) - // Get the existing code + existingCode, err := handler.GetAccountContractCode(location) + if err != nil { + panic(err) + } - contractName := nameValue.Str + if isUpdate { + // We are updating an existing contract. + // Ensure that there's a contract/contract-interface with the given name exists already - if contractName == "" { - panic(errors.NewDefaultUserError( - "contract name argument cannot be empty." + - "it must match the name of the deployed contract declaration or contract interface declaration", - )) - } + if len(existingCode) == 0 { + panic(errors.NewDefaultUserError( + "cannot update non-existing contract with name %q in account %s", + contractName, + address.ShortHexWithPrefix(), + )) + } - address := addressValue.ToAddress() - location := common.NewAddressLocation(invocation.Interpreter, address, contractName) + } else { + // We are adding a new contract. + // Ensure that no contract/contract interface with the given name exists already, + // and no contract deploy or update was recorded before + + if len(existingCode) > 0 || + handler.ContractUpdateRecorded(location) || + handler.IsContractBeingAdded(location) { + + panic(errors.NewDefaultUserError( + "cannot overwrite existing contract with name %q in account %s", + contractName, + address.ShortHexWithPrefix(), + )) + } + } - existingCode, err := handler.GetAccountContractCode(location) - if err != nil { - panic(err) - } + // Check the code + handleContractUpdateError := func(err error) { + if err == nil { + return + } - if isUpdate { - // We are updating an existing contract. - // Ensure that there's a contract/contract-interface with the given name exists already + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - if len(existingCode) == 0 { - panic(errors.NewDefaultUserError( - "cannot update non-existing contract with name %q in account %s", - contractName, - address.ShortHexWithPrefix(), - )) - } + handler.TemporarilyRecordCode(location, code) - } else { - // We are adding a new contract. - // Ensure that no contract/contract interface with the given name exists already, - // and no contract deploy or update was recorded before - - if len(existingCode) > 0 || - handler.ContractUpdateRecorded(location) || - handler.IsContractBeingAdded(location) { - - panic(errors.NewDefaultUserError( - "cannot overwrite existing contract with name %q in account %s", - contractName, - address.ShortHexWithPrefix(), - )) - } - } + panic(&InvalidContractDeploymentError{ + Err: err, + LocationRange: locationRange, + }) + } - // Check the code - handleContractUpdateError := func(err error) { - if err == nil { - return - } + // NOTE: do NOT use the program obtained from the host environment, as the current program. + // Always re-parse and re-check the new program. - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + // NOTE: *DO NOT* store the program – the new or updated program + // should not be effective during the execution - handler.TemporarilyRecordCode(location, code) + const getAndSetProgram = false - panic(&InvalidContractDeploymentError{ - Err: err, - LocationRange: locationRange, - }) + program, err := handler.ParseAndCheckProgram( + code, + location, + getAndSetProgram, + ) + handleContractUpdateError(err) + + // The code may declare exactly one contract or one contract interface. + + var contractTypes []*sema.CompositeType + var contractInterfaceTypes []*sema.InterfaceType + + program.Elaboration.ForEachGlobalType(func(_ string, variable *sema.Variable) { + switch ty := variable.Type.(type) { + case *sema.CompositeType: + if ty.Kind == common.CompositeKindContract { + contractTypes = append(contractTypes, ty) } - // NOTE: do NOT use the program obtained from the host environment, as the current program. - // Always re-parse and re-check the new program. + case *sema.InterfaceType: + if ty.CompositeKind == common.CompositeKindContract { + contractInterfaceTypes = append(contractInterfaceTypes, ty) + } + } + }) - // NOTE: *DO NOT* store the program – the new or updated program - // should not be effective during the execution + var deployedType sema.Type + var contractType *sema.CompositeType + var contractInterfaceType *sema.InterfaceType + var declaredName string + var declarationKind common.DeclarationKind + + switch { + case len(contractTypes) == 1 && len(contractInterfaceTypes) == 0: + contractType = contractTypes[0] + declaredName = contractType.Identifier + deployedType = contractType + declarationKind = common.DeclarationKindContract + case len(contractInterfaceTypes) == 1 && len(contractTypes) == 0: + contractInterfaceType = contractInterfaceTypes[0] + declaredName = contractInterfaceType.Identifier + deployedType = contractInterfaceType + declarationKind = common.DeclarationKindContractInterface + } - const getAndSetProgram = false + if deployedType == nil { + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - program, err := handler.ParseAndCheckProgram( - code, - location, - getAndSetProgram, - ) - handleContractUpdateError(err) + handler.TemporarilyRecordCode(location, code) - // The code may declare exactly one contract or one contract interface. + panic(errors.NewDefaultUserError( + "invalid %s: the code must declare exactly one contract or contract interface", + declarationKind.Name(), + )) + } - var contractTypes []*sema.CompositeType - var contractInterfaceTypes []*sema.InterfaceType + // The declared contract or contract interface must have the name + // passed to the constructor as the first argument - program.Elaboration.ForEachGlobalType(func(_ string, variable *sema.Variable) { - switch ty := variable.Type.(type) { - case *sema.CompositeType: - if ty.Kind == common.CompositeKindContract { - contractTypes = append(contractTypes, ty) - } + if declaredName != contractName { + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - case *sema.InterfaceType: - if ty.CompositeKind == common.CompositeKindContract { - contractInterfaceTypes = append(contractInterfaceTypes, ty) - } - } - }) + handler.TemporarilyRecordCode(location, code) - var deployedType sema.Type - var contractType *sema.CompositeType - var contractInterfaceType *sema.InterfaceType - var declaredName string - var declarationKind common.DeclarationKind + panic(errors.NewDefaultUserError( + "invalid %s: the name argument must match the name of the declaration: got %q, expected %q", + declarationKind.Name(), + contractName, + declaredName, + )) + } - switch { - case len(contractTypes) == 1 && len(contractInterfaceTypes) == 0: - contractType = contractTypes[0] - declaredName = contractType.Identifier - deployedType = contractType - declarationKind = common.DeclarationKindContract - case len(contractInterfaceTypes) == 1 && len(contractTypes) == 0: - contractInterfaceType = contractInterfaceTypes[0] - declaredName = contractInterfaceType.Identifier - deployedType = contractInterfaceType - declarationKind = common.DeclarationKindContractInterface - } + // Validate the contract update - if deployedType == nil { - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + if isUpdate { + oldCode, err := handler.GetAccountContractCode(location) + handleContractUpdateError(err) - handler.TemporarilyRecordCode(location, code) + oldProgram, err := parser.ParseProgram( + invocation.Interpreter.SharedState.Config.MemoryGauge, + oldCode, + parser.Config{}, + ) - panic(errors.NewDefaultUserError( - "invalid %s: the code must declare exactly one contract or contract interface", - declarationKind.Name(), - )) - } + if !ignoreUpdatedProgramParserError(err) { + handleContractUpdateError(err) + } - // The declared contract or contract interface must have the name - // passed to the constructor as the first argument + validator := NewContractUpdateValidator( + location, + contractName, + oldProgram, + program.Program, + ) + err = validator.Validate() + handleContractUpdateError(err) + } - if declaredName != contractName { - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs + inter := invocation.Interpreter - handler.TemporarilyRecordCode(location, code) + err = updateAccountContractCode( + handler, + location, + program, + code, + contractType, + constructorArguments, + constructorArgumentTypes, + updateAccountContractCodeOptions{ + createContract: !isUpdate, + }, + ) + if err != nil { + // Update the code for the error pretty printing + // NOTE: only do this when an error occurs - panic(errors.NewDefaultUserError( - "invalid %s: the name argument must match the name of the declaration: got %q, expected %q", - declarationKind.Name(), - contractName, - declaredName, - )) - } + handler.TemporarilyRecordCode(location, code) - // Validate the contract update + panic(err) + } - if isUpdate { - oldCode, err := handler.GetAccountContractCode(location) - handleContractUpdateError(err) + var eventType *sema.CompositeType - oldProgram, err := parser.ParseProgram( - gauge, - oldCode, - parser.Config{}, - ) + if isUpdate { + eventType = AccountContractUpdatedEventType + } else { + eventType = AccountContractAddedEventType + } - if !ignoreUpdatedProgramParserError(err) { - handleContractUpdateError(err) - } + codeHashValue := CodeToHashValue(inter, code) - validator := NewContractUpdateValidator( - location, - contractName, - oldProgram, - program.Program, - ) - err = validator.Validate() - handleContractUpdateError(err) - } + handler.EmitEvent( + inter, + eventType, + []interpreter.Value{ + addressValue, + codeHashValue, + nameValue, + }, + locationRange, + ) - inter := invocation.Interpreter + return interpreter.NewDeployedContractValue( + inter, + addressValue, + nameValue, + newCodeValue, + ) +} - err = updateAccountContractCode( - handler, - location, - program, - code, - contractType, - constructorArguments, - constructorArgumentTypes, - updateAccountContractCodeOptions{ - createContract: !isUpdate, - }, - ) - if err != nil { - // Update the code for the error pretty printing - // NOTE: only do this when an error occurs +func newAuthAccountContractsTryUpdateFunction( + functionType *sema.FunctionType, + gauge common.MemoryGauge, + handler AccountContractAdditionHandler, + addressValue interpreter.AddressValue, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + functionType, + func(invocation interpreter.Invocation) (deploymentResult interpreter.Value) { + var deployedContract interpreter.Value + + defer func() { + if r := recover(); r != nil { + rootError := r + for { + switch err := r.(type) { + case errors.UserError, errors.ExternalError: + // Error is ignored for now. + // Simply return with a `nil` deployed-contract + case xerrors.Wrapper: + r = err.Unwrap() + continue + default: + panic(rootError) + } + + break + } + } - handler.TemporarilyRecordCode(location, code) + var optionalDeployedContract interpreter.OptionalValue + if deployedContract == nil { + optionalDeployedContract = interpreter.NilOptionalValue + } else { + optionalDeployedContract = interpreter.NewSomeValueNonCopying(invocation.Interpreter, deployedContract) + } - panic(err) - } + deploymentResult = interpreter.NewDeploymentResultValue(gauge, optionalDeployedContract) + }() - var eventType *sema.CompositeType + deployedContract = changeAccountContracts(invocation, handler, addressValue, true) + return + }, + ) +} - if isUpdate { - eventType = AccountContractUpdatedEventType - } else { - eventType = AccountContractAddedEventType - } +type AccountContractAdditionHandler interface { + EventEmitter + AccountContractProvider + ParseAndCheckProgram( + code []byte, + location common.Location, + getAndSetProgram bool, + ) (*interpreter.Program, error) + // UpdateAccountContractCode updates the code associated with an account contract. + UpdateAccountContractCode(location common.AddressLocation, code []byte) error + RecordContractUpdate( + location common.AddressLocation, + value *interpreter.CompositeValue, + ) + ContractUpdateRecorded(location common.AddressLocation) bool + InterpretContract( + location common.AddressLocation, + program *interpreter.Program, + name string, + invocation DeployedContractConstructorInvocation, + ) ( + *interpreter.CompositeValue, + error, + ) + TemporarilyRecordCode(location common.AddressLocation, code []byte) - codeHashValue := CodeToHashValue(inter, code) + // StartContractAddition starts adding a contract. + StartContractAddition(location common.AddressLocation) - handler.EmitEvent( - inter, - eventType, - []interpreter.Value{ - addressValue, - codeHashValue, - nameValue, - }, - locationRange, - ) + // EndContractAddition ends adding the contract + EndContractAddition(location common.AddressLocation) - return interpreter.NewDeployedContractValue( - inter, - addressValue, - nameValue, - newCodeValue, - ) + // IsContractBeingAdded checks whether a contract is being added in the current execution. + IsContractBeingAdded(location common.AddressLocation) bool +} + +// newAuthAccountContractsChangeFunction called when e.g. +// - adding: `AuthAccount.contracts.add(name: "Foo", code: [...])` (isUpdate = false) +// - updating: `AuthAccount.contracts.update__experimental(name: "Foo", code: [...])` (isUpdate = true) +func newAuthAccountContractsChangeFunction( + functionType *sema.FunctionType, + gauge common.MemoryGauge, + handler AccountContractAdditionHandler, + addressValue interpreter.AddressValue, + isUpdate bool, +) *interpreter.HostFunctionValue { + return interpreter.NewHostFunctionValue( + gauge, + functionType, + func(invocation interpreter.Invocation) interpreter.Value { + return changeAccountContracts(invocation, handler, addressValue, isUpdate) }, ) } diff --git a/runtime/tests/checker/account_test.go b/runtime/tests/checker/account_test.go index 12ace1b81b..08f2104db0 100644 --- a/runtime/tests/checker/account_test.go +++ b/runtime/tests/checker/account_test.go @@ -1572,6 +1572,48 @@ func TestAuthAccountContracts(t *testing.T) { require.NoError(t, err) }) + t.Run("try update, unauthorized", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(contracts: &PublicAccount.Contracts): DeploymentResult { + return contracts.tryUpdate(name: "foo", code: "012".decodeHex()) + } + `) + + errors := RequireCheckerErrors(t, err, 1) + + var missingMemberError *sema.NotDeclaredMemberError + require.ErrorAs(t, errors[0], &missingMemberError) + assert.Equal(t, "tryUpdate", missingMemberError.Name) + }) + + t.Run("try update, authorized", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(contracts: &AuthAccount.Contracts): DeploymentResult { + return contracts.tryUpdate(name: "foo", code: "012".decodeHex()) + } + `) + require.NoError(t, err) + }) + + t.Run("deployment result fields", func(t *testing.T) { + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(contracts: &AuthAccount.Contracts) { + let deploymentResult: DeploymentResult = contracts.tryUpdate(name: "foo", code: "012".decodeHex()) + let deployedContract: DeployedContract = deploymentResult.deployedContract! + let name: String = deployedContract.name + let address: Address = deployedContract.address + let code: [UInt8] = deployedContract.code + } + `) + require.NoError(t, err) + }) + } func TestPublicAccountContracts(t *testing.T) { diff --git a/runtime/tests/checker/type_inference_test.go b/runtime/tests/checker/type_inference_test.go index af024b4f16..5c05b1386a 100644 --- a/runtime/tests/checker/type_inference_test.go +++ b/runtime/tests/checker/type_inference_test.go @@ -1221,3 +1221,29 @@ func TestCheckTypeInferenceForTypesWithDifferentTypeMaskRanges(t *testing.T) { require.IsType(t, &sema.OptionalType{Type: sema.AnyStructType}, xType) }) } + +func TestCheckDeploymentResultInference(t *testing.T) { + + t.Parallel() + + code := ` + let x: DeploymentResult = getDeploymentResult() + let y: DeploymentResult = getDeploymentResult() + // Function is just to get a 'DeploymentResult' return type. + fun getDeploymentResult(): DeploymentResult { + let v: DeploymentResult? = nil + return v! + } + let z = [x, y] + ` + + checker, err := ParseAndCheck(t, code) + require.NoError(t, err) + + zType := RequireGlobalValue(t, checker.Elaboration, "z") + + require.IsType(t, &sema.VariableSizedType{}, zType) + variableSizedType := zType.(*sema.VariableSizedType) + + assert.Equal(t, sema.DeploymentResultType, variableSizedType.Type) +} diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 447b8cf04f..95d2d652ad 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -9554,6 +9554,7 @@ func newTestAuthAccountValue(gauge common.MemoryGauge, addressValue interpreter. panicFunctionValue, panicFunctionValue, panicFunctionValue, + panicFunctionValue, func( inter *interpreter.Interpreter, locationRange interpreter.LocationRange,