diff --git a/internal/cadence/linter.go b/internal/cadence/linter.go index 0a54b9700..179a10f05 100644 --- a/internal/cadence/linter.go +++ b/internal/cadence/linter.go @@ -25,6 +25,8 @@ import ( "errors" + "github.com/onflow/flow-cli/internal/util" + cdclint "github.com/onflow/cadence-tools/lint" cdctests "github.com/onflow/cadence-tools/test/helpers" "github.com/onflow/cadence/runtime/ast" @@ -67,8 +69,8 @@ func newLinter(state *flowkit.State) *linter { // Create checker configs for both standard and script // Scripts have a different stdlib than contracts and transactions - l.checkerStandardConfig = l.newCheckerConfig(newStandardLibrary()) - l.checkerScriptConfig = l.newCheckerConfig(newScriptStandardLibrary()) + l.checkerStandardConfig = l.newCheckerConfig(util.NewStandardLibrary()) + l.checkerScriptConfig = l.newCheckerConfig(util.NewScriptStandardLibrary()) return l } @@ -152,10 +154,10 @@ func (l *linter) lintFile( } // Create a new checker config with the given standard library -func (l *linter) newCheckerConfig(lib standardLibrary) *sema.Config { +func (l *linter) newCheckerConfig(lib util.StandardLibrary) *sema.Config { return &sema.Config{ BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { - return lib.baseValueActivation + return lib.BaseValueActivation }, AccessCheckMode: sema.AccessCheckModeStrict, PositionInfoEnabled: true, // Must be enabled for linters diff --git a/internal/migrate/stage_contract.go b/internal/migrate/stage_contract.go index e9b2bed13..e95c3df5d 100644 --- a/internal/migrate/stage_contract.go +++ b/internal/migrate/stage_contract.go @@ -20,9 +20,13 @@ package migrate import ( "context" + "errors" "fmt" + "strings" + "github.com/manifoldco/promptui" "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/contract-updater/lib/go/templates" flowsdk "github.com/onflow/flow-go-sdk" "github.com/onflow/flowkit/v2" @@ -35,7 +39,9 @@ import ( "github.com/onflow/flow-cli/internal/command" ) -var stageContractflags struct{} +var stageContractflags struct { + SkipValidation bool `default:"false" flag:"skip-validation" info:"Do not validate the contract code against staged dependencies"` +} var stageContractCommand = &command.Command{ Cmd: &cobra.Command{ @@ -51,7 +57,7 @@ var stageContractCommand = &command.Command{ func stageContract( args []string, globalFlags command.GlobalFlags, - _ output.Logger, + logger output.Logger, flow flowkit.Services, state *flowkit.State, ) (command.Result, error) { @@ -82,6 +88,55 @@ func stageContract( return nil, fmt.Errorf("failed to get account by contract name: %w", err) } + // Validate the contract code by default + if !stageContractflags.SkipValidation { + logger.StartProgress("Validating contract code against any staged dependencies") + validator := newStagingValidator(flow, state) + + var missingDependenciesErr *missingDependenciesError + contractLocation := common.NewAddressLocation(nil, common.Address(account.Address), contractName) + err = validator.ValidateContractUpdate(contractLocation, common.StringLocation(contract.Location), replacedCode) + + logger.StopProgress() + + // Errors when the contract's dependencies have not been staged yet are non-fatal + // This is because the contract may be dependent on contracts that are not yet staged + // and we do not want to require in-order staging of contracts + // Instead, we will prompt the user to continue staging the contract. Other errors + // will be fatal and require manual intervention using the --skip-validation flag if desired + if errors.As(err, &missingDependenciesErr) { + infoMessage := strings.Builder{} + infoMessage.WriteString("Validation cannot be performed as some of your contract's dependencies could not be found (have they been staged yet?)\n") + for _, contract := range missingDependenciesErr.MissingContracts { + infoMessage.WriteString(fmt.Sprintf(" - %s\n", contract)) + } + infoMessage.WriteString("\nYou may still stage your contract, however it will be unable to be migrated until the missing contracts are staged by their respective owners. It is important to monitor the status of your contract using the `flow migrate is-validated` command\n") + logger.Error(infoMessage.String()) + + continuePrompt := promptui.Select{ + Label: "Do you wish to continue staging your contract?", + Items: []string{"Yes", "No"}, + } + + _, result, err := continuePrompt.Run() + if err != nil { + return nil, err + } + + if result == "No" { + return nil, fmt.Errorf("staging cancelled") + } + } else if err != nil { + logger.Error(validator.prettyPrintError(err, common.StringLocation(contract.Location))) + return nil, fmt.Errorf("errors were found while validating the contract code, and your contract HAS NOT been staged, you can use the --skip-validation flag to bypass this check") + } else { + logger.Info("No issues found while validating contract code\n") + logger.Info("DISCLAIMER: Pre-staging validation checks are not exhaustive and do not guarantee the contract will work as expected, please monitor the status of your contract using the `flow migrate is-validated` command\n") + } + } else { + logger.Info("Skipping contract code validation, you may monitor the status of your contract using the `flow migrate is-validated` command\n") + } + tx, res, err := flow.SendTransaction( context.Background(), transactions.SingleAccountRole(*account), diff --git a/internal/migrate/stage_contract_test.go b/internal/migrate/stage_contract_test.go index 9881ad0cc..02533d794 100644 --- a/internal/migrate/stage_contract_test.go +++ b/internal/migrate/stage_contract_test.go @@ -88,6 +88,9 @@ func Test_StageContract(t *testing.T) { BlockHeight: 1, }, nil) + // disable validation + stageContractflags.SkipValidation = true + result, err := stageContract( []string{testContract.Name}, command.GlobalFlags{ @@ -97,6 +100,9 @@ func Test_StageContract(t *testing.T) { srv.Mock, state, ) + // reset flags + stageContractflags.SkipValidation = false + assert.NoError(t, err) assert.NotNil(t, result) }) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go new file mode 100644 index 000000000..c65e83921 --- /dev/null +++ b/internal/migrate/staging_validator.go @@ -0,0 +1,440 @@ +/* + * Flow CLI + * + * Copyright 2019 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrate + +import ( + "context" + "fmt" + "strings" + + "github.com/onflow/flow-cli/internal/util" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/old_parser" + "github.com/onflow/cadence/runtime/parser" + "github.com/onflow/cadence/runtime/pretty" + "github.com/onflow/cadence/runtime/sema" + "github.com/onflow/cadence/runtime/stdlib" + "github.com/onflow/contract-updater/lib/go/templates" + flowsdk "github.com/onflow/flow-go-sdk" + "github.com/onflow/flow-go/cmd/util/ledger/migrations" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flowkit/v2" +) + +type stagingValidator struct { + flow flowkit.Services + state *flowkit.State + + // Cache for account contract names so we don't have to fetch them multiple times + accountContractNames map[common.Address][]string + // All resolved contract code + contracts map[common.Location][]byte + // Record errors related to missing staged dependencies, as these are reported separately + missingDependencies []common.AddressLocation + // Cache for contract elaborations which are reused during program checking & used for the update checker + elaborations map[common.Location]*sema.Elaboration +} + +type accountContractNamesProviderImpl struct { + resolverFunc func(address common.Address) ([]string, error) +} + +var _ stdlib.AccountContractNamesProvider = &accountContractNamesProviderImpl{} + +type missingDependenciesError struct { + MissingContracts []common.AddressLocation +} + +func (e *missingDependenciesError) Error() string { + contractNames := make([]string, len(e.MissingContracts)) + for i, location := range e.MissingContracts { + contractNames[i] = location.Name + } + return fmt.Sprintf("the following staged contract dependencies could not be found (have they been staged yet?): %v", contractNames) +} + +var _ error = &missingDependenciesError{} + +var chainIdMap = map[string]flow.ChainID{ + "mainnet": flow.Mainnet, + "testnet": flow.Testnet, +} + +func newStagingValidator(flow flowkit.Services, state *flowkit.State) *stagingValidator { + return &stagingValidator{ + flow: flow, + state: state, + contracts: make(map[common.Location][]byte), + elaborations: make(map[common.Location]*sema.Elaboration), + accountContractNames: make(map[common.Address][]string), + } +} + +func (v *stagingValidator) ValidateContractUpdate( + // Network location of the contract to be updated + location common.AddressLocation, + // Location of the source code, ensures that the error messages reference this instead of a network location + sourceCodeLocation common.Location, + // Code of the updated contract + updatedCode []byte, +) error { + // Resolve all system contract code & add to cache + v.loadSystemContracts() + + // Get the account for the contract + address := flowsdk.Address(location.Address) + account, err := v.flow.GetAccount(context.Background(), address) + if err != nil { + return fmt.Errorf("failed to get account: %w", err) + } + + // Get the target contract old code + contractName := location.Name + contractCode, ok := account.Contracts[contractName] + if !ok { + return fmt.Errorf("old contract code not found for contract: %s", contractName) + } + + // Parse the old contract code + oldProgram, err := old_parser.ParseProgram(nil, contractCode, old_parser.Config{}) + if err != nil { + return fmt.Errorf("failed to parse old contract code: %w", err) + } + + // Store contract code for error pretty printing + v.contracts[sourceCodeLocation] = updatedCode + + // Parse and check the contract code + _, newProgramChecker, err := v.parseAndCheckContract(sourceCodeLocation) + + // Errors related to missing dependencies are separate from other errors + // They may be handled differently by the caller, and it's parsing/checking + // errors are not relevant/informative if these are present (they are expected) + if len(v.missingDependencies) > 0 { + return &missingDependenciesError{MissingContracts: v.missingDependencies} + } + + if err != nil { + return err + } + + // Convert the new program checker to an interpreter program + interpreterProgram := interpreter.ProgramFromChecker(newProgramChecker) + + // Check if contract code is valid according to Cadence V1 Update Checker + validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( + sourceCodeLocation, + contractName, + &accountContractNamesProviderImpl{ + resolverFunc: v.resolveAddressContractNames, + }, + oldProgram, + interpreterProgram, + v.elaborations, + ) + + err = validator.Validate() + if err != nil { + return err + } + + return nil +} + +func (v *stagingValidator) parseAndCheckContract( + location common.Location, +) (*ast.Program, *sema.Checker, error) { + code := v.contracts[location] + + // Parse the contract code + program, err := parser.ParseProgram(nil, code, parser.Config{}) + if err != nil { + return nil, nil, err + } + + // Check the contract code + checker, err := sema.NewChecker( + program, + location, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + AttachmentsEnabled: true, + BaseValueActivationHandler: func(_ common.Location) *sema.VariableActivation { + // Only checking contracts, so no need to consider script standard library + return util.NewStandardLibrary().BaseValueActivation + }, + LocationHandler: v.resolveLocation, + ImportHandler: v.resolveImport, + }, + ) + if err != nil { + return nil, nil, err + } + + err = checker.Check() + if err != nil { + return nil, nil, err + } + + return program, checker, nil +} + +func (v *stagingValidator) getStagedContractCode( + location common.AddressLocation, +) ([]byte, error) { + // First check if the code is already known + // This may be true for system contracts since they are not staged + // Or any other staged contracts that have been resolved + if code, ok := v.contracts[location]; ok { + return code, nil + } + + cAddr := cadence.BytesToAddress(location.Address.Bytes()) + cName, err := cadence.NewString(location.Name) + if err != nil { + return nil, fmt.Errorf("failed to get cadence string from contract name: %w", err) + } + + value, err := v.flow.ExecuteScript( + context.Background(), + flowkit.Script{ + Code: templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress(v.flow.Network().Name)), + Args: []cadence.Value{cAddr, cName}, + }, + flowkit.LatestScriptQuery, + ) + if err != nil { + return nil, err + } + + optValue, ok := value.(cadence.Optional) + if !ok { + return nil, fmt.Errorf("invalid script return value type: %T", value) + } + + strValue, ok := optValue.Value.(cadence.String) + if !ok { + return nil, fmt.Errorf("invalid script return value type: %T", value) + } + + v.contracts[location] = []byte(strValue) + return v.contracts[location], nil +} + +func (v *stagingValidator) resolveImport(checker *sema.Checker, importedLocation common.Location, _ ast.Range) (sema.Import, error) { + // Check if the imported location is the crypto checker + if importedLocation == stdlib.CryptoCheckerLocation { + cryptoChecker := stdlib.CryptoChecker() + return sema.ElaborationImport{ + Elaboration: cryptoChecker.Elaboration, + }, nil + } + + // Check if the imported location is an address location + // No other location types are supported (as is the case with code on-chain) + addrLocation, ok := importedLocation.(common.AddressLocation) + if !ok { + return nil, fmt.Errorf("expected address location") + } + + // Check if this contract has already been resolved + elaboration, ok := v.elaborations[importedLocation] + + // If not resolved, parse and check the contract code + if !ok { + importedCode, err := v.getStagedContractCode(addrLocation) + if err != nil { + v.missingDependencies = append(v.missingDependencies, addrLocation) + return nil, fmt.Errorf("failed to get staged contract code: %w", err) + } + v.contracts[addrLocation] = importedCode + + _, checker, err = v.parseAndCheckContract(addrLocation) + if err != nil { + return nil, fmt.Errorf("failed to parse and check contract code: %w", err) + } + + v.elaborations[importedLocation] = checker.Elaboration + elaboration = checker.Elaboration + } + + return sema.ElaborationImport{ + Elaboration: elaboration, + }, nil +} + +func (v *stagingValidator) loadSystemContracts() { + chainId, ok := chainIdMap[v.flow.Network().Name] + if !ok { + return + } + + stagedSystemContracts := migrations.SystemContractChanges(chainId, migrations.SystemContractChangesOptions{ + Burner: migrations.BurnerContractChangeUpdate, // needs to be update for now since BurnerChangeDeploy is a no-op in flow-go + EVM: migrations.EVMContractChangeFull, + }) + for _, stagedSystemContract := range stagedSystemContracts { + location := common.AddressLocation{ + Address: stagedSystemContract.Address, + Name: stagedSystemContract.Name, + } + + v.contracts[location] = stagedSystemContract.Code + v.accountContractNames[stagedSystemContract.Address] = append(v.accountContractNames[stagedSystemContract.Address], stagedSystemContract.Name) + } +} + +// This is a copy of the resolveLocation function from the linter/language server +func (v *stagingValidator) resolveLocation( + identifiers []ast.Identifier, + location common.Location, +) ( + []sema.ResolvedLocation, + error, +) { + addressLocation, isAddress := location.(common.AddressLocation) + + // if the location is not an address location, e.g. an identifier location (`import Crypto`), + // then return a single resolved location which declares all identifiers. + + if !isAddress { + return []runtime.ResolvedLocation{ + { + Location: location, + Identifiers: identifiers, + }, + }, nil + } + + // if the location is an address, + // and no specific identifiers where requested in the import statement, + // then fetch all identifiers at this address + + if len(identifiers) == 0 { + // if there is no contract name resolver, + // then return no resolved locations + contractNames, err := v.resolveAddressContractNames(addressLocation.Address) + if err != nil { + return nil, err + } + + // if there are no contracts deployed, + // then return no resolved locations + + if len(contractNames) == 0 { + return nil, nil + } + + identifiers = make([]ast.Identifier, len(contractNames)) + + for i := range identifiers { + identifiers[i] = runtime.Identifier{ + Identifier: contractNames[i], + } + } + } + + // return one resolved location per identifier. + // each resolved location is an address contract location + + resolvedLocations := make([]runtime.ResolvedLocation, len(identifiers)) + for i := range resolvedLocations { + identifier := identifiers[i] + resolvedLocations[i] = runtime.ResolvedLocation{ + Location: common.AddressLocation{ + Address: addressLocation.Address, + Name: identifier.Identifier, + }, + Identifiers: []runtime.Identifier{identifier}, + } + } + + return resolvedLocations, nil +} + +func (v *stagingValidator) resolveAddressContractNames(address common.Address) ([]string, error) { + // Check if the contract names are already cached + if names, ok := v.accountContractNames[address]; ok { + return names, nil + } + + cAddr := cadence.BytesToAddress(address.Bytes()) + value, err := v.flow.ExecuteScript( + context.Background(), + flowkit.Script{ + Code: templates.GenerateGetStagedContractNamesForAddressScript(MigrationContractStagingAddress(v.flow.Network().Name)), + Args: []cadence.Value{cAddr}, + }, + flowkit.LatestScriptQuery, + ) + + if err != nil { + return nil, err + } + + optValue, ok := value.(cadence.Optional) + if !ok { + return nil, fmt.Errorf("invalid script return value type: %T", value) + } + + arrValue, ok := optValue.Value.(cadence.Array) + if !ok { + return nil, fmt.Errorf("invalid script return value type: %T", value) + } + + // Cache the contract names + for _, name := range arrValue.Values { + strName, ok := name.(cadence.String) + if !ok { + return nil, fmt.Errorf("invalid array value type: %T", name) + } + v.accountContractNames[address] = append(v.accountContractNames[address], string(strName)) + } + + return v.accountContractNames[address], nil +} + +// Helper for pretty printing errors +// While it is done by default in checker/parser errors, this has two purposes: +// 1. Add color to the error message +// 2. Use pretty printing on contract update errors which do not do this by default +func (v *stagingValidator) prettyPrintError(err error, location common.Location) string { + var sb strings.Builder + printErr := pretty.NewErrorPrettyPrinter(&sb, true). + PrettyPrintError(err, location, v.contracts) + if printErr != nil { + return fmt.Sprintf("failed to pretty print error: %v", printErr) + } else { + return sb.String() + } +} + +// Stdlib handler used by the Cadence V1 Update Checker to resolve contract names +// When an address import with no identifiers is used, the contract names are resolved +func (a *accountContractNamesProviderImpl) GetAccountContractNames( + address common.Address, +) ([]string, error) { + return a.resolverFunc(address) +} diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go new file mode 100644 index 000000000..48db6e70f --- /dev/null +++ b/internal/migrate/staging_validator_test.go @@ -0,0 +1,291 @@ +/* + * Flow CLI + * + * Copyright 2019 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrate + +import ( + "testing" + + "github.com/onflow/flow-cli/internal/util" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/sema" + "github.com/onflow/cadence/runtime/stdlib" + "github.com/onflow/contract-updater/lib/go/templates" + "github.com/onflow/flow-go-sdk" + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_StagingValidator(t *testing.T) { + srv, state, rw := util.TestMocks(t) + t.Run("valid contract update with no dependencies", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + pub contract Test { + pub fun test() {} + }` + newContract := ` + access(all) contract Test { + access(all) fun test() {} + }` + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + + validator := newStagingValidator(srv.Mock, state) + err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + require.NoError(t, err) + }) + + t.Run("contract update with update error", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + pub contract Test { + pub fun test() {} + }` + newContract := ` + access(all) contract Test { + access(all) let x: Int + access(all) fun test() {} + + init() { + self.x = 1 + } + }` + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + + validator := newStagingValidator(srv.Mock, state) + err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + var updateErr *stdlib.ContractUpdateError + require.ErrorAs(t, err, &updateErr) + }) + + t.Run("contract update with checker error", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + pub contract Test { + let x: Int + init() { + self.x = 1 + } + }` + newContract := ` + access(all) contract Test { + access(all) let x: Int + init() { + self.x = "bad type :(" + } + }` + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + + validator := newStagingValidator(srv.Mock, state) + err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + var checkerErr *sema.CheckerError + require.ErrorAs(t, err, &checkerErr) + }) + + t.Run("valid contract update with dependencies", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + pub contract Test { + pub fun test() {} + }` + newContract := ` + import ImpContract from 0x02 + access(all) contract Test { + access(all) fun test() {} + }` + impContract := ` + access(all) contract ImpContract { + access(all) let x: Int + init() { + self.x = 1 + } + }` + mockScriptResultString, err := cadence.NewString(impContract) + require.NoError(t, err) + + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + srv.ExecuteScript.Run(func(args mock.Arguments) { + script := args.Get(1).(flowkit.Script) + + assert.Equal(t, templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet")), script.Code) + + assert.Equal(t, 2, len(script.Args)) + actualContractAddressArg, actualContractNameArg := script.Args[0], script.Args[1] + + contractName, _ := cadence.NewString("ImpContract") + contractAddr := cadence.NewAddress(flow.HexToAddress("02")) + assert.Equal(t, contractName, actualContractNameArg) + assert.Equal(t, contractAddr, actualContractAddressArg) + }).Return(cadence.NewOptional(mockScriptResultString), nil) + + // validate + validator := newStagingValidator(srv.Mock, state) + err = validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + require.NoError(t, err) + }) + + t.Run("contract update missing dependency", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + impLocation := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, "ImpContract") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + pub contract Test { + pub fun test() {} + }` + newContract := ` + import ImpContract from 0x02 + access(all) contract Test { + access(all) fun test() {} + }` + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + srv.ExecuteScript.Return(cadence.NewOptional(nil), nil) + + validator := newStagingValidator(srv.Mock, state) + err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + var missingDepsErr *missingDependenciesError + require.ErrorAs(t, err, &missingDepsErr) + require.Equal(t, 1, len(missingDepsErr.MissingContracts)) + require.Equal(t, impLocation, missingDepsErr.MissingContracts[0]) + }) + + t.Run("valid contract update with system contract imports", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + import FlowToken from 0x7e60df042a9c0868 + pub contract Test { + pub fun test() {} + }` + newContract := ` + import FlowToken from 0x7e60df042a9c0868 + import Burner from 0x9a0766d93b6608b7 + access(all) contract Test { + access(all) fun test() {} + }` + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + + validator := newStagingValidator(srv.Mock, state) + err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + require.NoError(t, err) + }) +} diff --git a/internal/cadence/stdlib.go b/internal/util/stdlib.go similarity index 72% rename from internal/cadence/stdlib.go rename to internal/util/stdlib.go index 15c053887..194f274fc 100644 --- a/internal/cadence/stdlib.go +++ b/internal/util/stdlib.go @@ -18,7 +18,7 @@ // NOTE: This file is a copy of the file https://github.com/onflow/cadence-tools/blob/master/languageserver/server/stdlib.go -package cadence +package util import ( "github.com/onflow/cadence/runtime/common" @@ -28,85 +28,85 @@ import ( "github.com/onflow/cadence/runtime/stdlib" ) -type standardLibrary struct { - baseValueActivation *sema.VariableActivation +type StandardLibrary struct { + BaseValueActivation *sema.VariableActivation } -var _ stdlib.StandardLibraryHandler = standardLibrary{} +var _ stdlib.StandardLibraryHandler = StandardLibrary{} -func (standardLibrary) ProgramLog(_ string, _ interpreter.LocationRange) error { +func (StandardLibrary) ProgramLog(_ string, _ interpreter.LocationRange) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) UnsafeRandom() (uint64, error) { +func (StandardLibrary) UnsafeRandom() (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetBlockAtHeight(_ uint64) (stdlib.Block, bool, error) { +func (StandardLibrary) GetBlockAtHeight(_ uint64) (stdlib.Block, bool, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetCurrentBlockHeight() (uint64, error) { +func (StandardLibrary) GetCurrentBlockHeight() (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetAccountBalance(_ common.Address) (uint64, error) { +func (StandardLibrary) GetAccountBalance(_ common.Address) (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetAccountAvailableBalance(_ common.Address) (uint64, error) { +func (StandardLibrary) GetAccountAvailableBalance(_ common.Address) (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) CommitStorageTemporarily(_ *interpreter.Interpreter) error { +func (StandardLibrary) CommitStorageTemporarily(_ *interpreter.Interpreter) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetStorageUsed(_ common.Address) (uint64, error) { +func (StandardLibrary) GetStorageUsed(_ common.Address) (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetStorageCapacity(_ common.Address) (uint64, error) { +func (StandardLibrary) GetStorageCapacity(_ common.Address) (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetAccountKey(_ common.Address, _ int) (*stdlib.AccountKey, error) { +func (StandardLibrary) GetAccountKey(_ common.Address, _ int) (*stdlib.AccountKey, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetAccountContractNames(_ common.Address) ([]string, error) { +func (StandardLibrary) GetAccountContractNames(_ common.Address) ([]string, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) GetAccountContractCode(_ common.AddressLocation) ([]byte, error) { +func (StandardLibrary) GetAccountContractCode(_ common.AddressLocation) ([]byte, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) EmitEvent( +func (StandardLibrary) EmitEvent( _ *interpreter.Interpreter, _ *sema.CompositeType, _ []interpreter.Value, @@ -117,19 +117,19 @@ func (standardLibrary) EmitEvent( panic(errors.NewUnreachableError()) } -func (standardLibrary) AddEncodedAccountKey(_ common.Address, _ []byte) error { +func (StandardLibrary) AddEncodedAccountKey(_ common.Address, _ []byte) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) RevokeEncodedAccountKey(_ common.Address, _ int) ([]byte, error) { +func (StandardLibrary) RevokeEncodedAccountKey(_ common.Address, _ int) ([]byte, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) AddAccountKey( +func (StandardLibrary) AddAccountKey( _ common.Address, _ *stdlib.PublicKey, _ sema.HashAlgorithm, @@ -143,37 +143,37 @@ func (standardLibrary) AddAccountKey( panic(errors.NewUnreachableError()) } -func (standardLibrary) RevokeAccountKey(_ common.Address, _ int) (*stdlib.AccountKey, error) { +func (StandardLibrary) RevokeAccountKey(_ common.Address, _ int) (*stdlib.AccountKey, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) ParseAndCheckProgram(_ []byte, _ common.Location, _ bool) (*interpreter.Program, error) { +func (StandardLibrary) ParseAndCheckProgram(_ []byte, _ common.Location, _ bool) (*interpreter.Program, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) UpdateAccountContractCode(_ common.AddressLocation, _ []byte) error { +func (StandardLibrary) UpdateAccountContractCode(_ common.AddressLocation, _ []byte) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) RecordContractUpdate(_ common.AddressLocation, _ *interpreter.CompositeValue) { +func (StandardLibrary) RecordContractUpdate(_ common.AddressLocation, _ *interpreter.CompositeValue) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) ContractUpdateRecorded(_ common.AddressLocation) bool { +func (StandardLibrary) ContractUpdateRecorded(_ common.AddressLocation) bool { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) InterpretContract( +func (StandardLibrary) InterpretContract( _ common.AddressLocation, _ *interpreter.Program, _ string, @@ -184,37 +184,37 @@ func (standardLibrary) InterpretContract( panic(errors.NewUnreachableError()) } -func (standardLibrary) TemporarilyRecordCode(_ common.AddressLocation, _ []byte) { +func (StandardLibrary) TemporarilyRecordCode(_ common.AddressLocation, _ []byte) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) RemoveAccountContractCode(_ common.AddressLocation) error { +func (StandardLibrary) RemoveAccountContractCode(_ common.AddressLocation) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) RecordContractRemoval(_ common.AddressLocation) { +func (StandardLibrary) RecordContractRemoval(_ common.AddressLocation) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) CreateAccount(_ common.Address) (address common.Address, err error) { +func (StandardLibrary) CreateAccount(_ common.Address) (address common.Address, err error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) ValidatePublicKey(_ *stdlib.PublicKey) error { +func (StandardLibrary) ValidatePublicKey(_ *stdlib.PublicKey) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) VerifySignature( +func (StandardLibrary) VerifySignature( _ []byte, _ string, _ []byte, @@ -227,78 +227,78 @@ func (standardLibrary) VerifySignature( panic(errors.NewUnreachableError()) } -func (standardLibrary) BLSVerifyPOP(_ *stdlib.PublicKey, _ []byte) (bool, error) { +func (StandardLibrary) BLSVerifyPOP(_ *stdlib.PublicKey, _ []byte) (bool, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) Hash(_ []byte, _ string, _ sema.HashAlgorithm) ([]byte, error) { +func (StandardLibrary) Hash(_ []byte, _ string, _ sema.HashAlgorithm) ([]byte, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) AccountKeysCount(_ common.Address) (uint64, error) { +func (StandardLibrary) AccountKeysCount(_ common.Address) (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) BLSAggregatePublicKeys(_ []*stdlib.PublicKey) (*stdlib.PublicKey, error) { +func (StandardLibrary) BLSAggregatePublicKeys(_ []*stdlib.PublicKey) (*stdlib.PublicKey, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) BLSAggregateSignatures(_ [][]byte) ([]byte, error) { +func (StandardLibrary) BLSAggregateSignatures(_ [][]byte) ([]byte, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (l standardLibrary) GenerateAccountID(_ common.Address) (uint64, error) { +func (l StandardLibrary) GenerateAccountID(_ common.Address) (uint64, error) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (l standardLibrary) ReadRandom(_ []byte) error { +func (l StandardLibrary) ReadRandom(_ []byte) error { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) StartContractAddition(_ common.AddressLocation) { +func (StandardLibrary) StartContractAddition(_ common.AddressLocation) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) EndContractAddition(_ common.AddressLocation) { +func (StandardLibrary) EndContractAddition(_ common.AddressLocation) { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func (standardLibrary) IsContractBeingAdded(_ common.AddressLocation) bool { +func (StandardLibrary) IsContractBeingAdded(_ common.AddressLocation) bool { // Implementation should never be called, // only its definition is used for type-checking panic(errors.NewUnreachableError()) } -func newStandardLibrary() (result standardLibrary) { - result.baseValueActivation = sema.NewVariableActivation(sema.BaseValueActivation) +func NewStandardLibrary() (result StandardLibrary) { + result.BaseValueActivation = sema.NewVariableActivation(sema.BaseValueActivation) for _, valueDeclaration := range stdlib.DefaultStandardLibraryValues(result) { - result.baseValueActivation.DeclareValue(valueDeclaration) + result.BaseValueActivation.DeclareValue(valueDeclaration) } return } -func newScriptStandardLibrary() (result standardLibrary) { - result.baseValueActivation = sema.NewVariableActivation(sema.BaseValueActivation) +func NewScriptStandardLibrary() (result StandardLibrary) { + result.BaseValueActivation = sema.NewVariableActivation(sema.BaseValueActivation) for _, declaration := range stdlib.DefaultScriptStandardLibraryValues(result) { - result.baseValueActivation.DeclareValue(declaration) + result.BaseValueActivation.DeclareValue(declaration) } return }