diff --git a/internal/cadence/lint.go b/internal/cadence/lint.go index 525eb3bec..85fc0d944 100644 --- a/internal/cadence/lint.go +++ b/internal/cadence/lint.go @@ -32,6 +32,7 @@ import ( "github.com/onflow/flowkit/v2/output" "github.com/onflow/flow-cli/internal/command" + "github.com/onflow/flow-cli/internal/util" ) type lintFlagsCollection struct{} @@ -202,7 +203,15 @@ func (r *lintResult) String() string { total := numErrors + numWarnings if total > 0 { - sb.WriteString(aurora.Colorize(fmt.Sprintf("%d %s (%d %s, %d %s)", total, pluralize("problem", total), numErrors, pluralize("error", numErrors), numWarnings, pluralize("warning", numWarnings)), color).String()) + sb.WriteString(aurora.Colorize(fmt.Sprintf( + "%d %s (%d %s, %d %s)", + total, + util.Pluralize("problem", total), + numErrors, + util.Pluralize("error", numErrors), + numWarnings, + util.Pluralize("warning", numWarnings), + ), color).String()) } else { sb.WriteString(aurora.Green("Lint passed").String()) } @@ -219,7 +228,7 @@ func (r *lintResult) Oneliner() string { total := numErrors + numWarnings if total > 0 { - return fmt.Sprintf("%d %s (%d %s, %d %s)", total, pluralize("problem", total), numErrors, pluralize("error", numErrors), numWarnings, pluralize("warning", numWarnings)) + return fmt.Sprintf("%d %s (%d %s, %d %s)", total, util.Pluralize("problem", total), numErrors, util.Pluralize("error", numErrors), numWarnings, util.Pluralize("warning", numWarnings)) } return "Lint passed" } @@ -227,10 +236,3 @@ func (r *lintResult) Oneliner() string { func (r *lintResult) ExitCode() int { return r.exitCode } - -func pluralize(word string, count int) string { - if count == 1 { - return word - } - return word + "s" -} diff --git a/internal/migrate/get_staged_code.go b/internal/migrate/get_staged_code.go index 0a4779f52..1083263c7 100644 --- a/internal/migrate/get_staged_code.go +++ b/internal/migrate/get_staged_code.go @@ -23,6 +23,7 @@ import ( "fmt" "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/flowkit/v2" "github.com/onflow/flowkit/v2/output" "github.com/spf13/cobra" @@ -64,24 +65,56 @@ func getStagedCode( return nil, fmt.Errorf("error getting address by contract name: %w", err) } - cName, err := cadence.NewString(contractName) + location := common.NewAddressLocation(nil, common.Address(addr), contractName) + code, err := getStagedContractCode(context.Background(), flow, location) if err != nil { - return nil, fmt.Errorf("error creating cadence string: %w", err) + return nil, err + } + + // If the contract is not staged, return nil + if code == nil { + return scripts.NewScriptResult(cadence.NewOptional(nil)), nil } - caddr := cadence.NewAddress(addr) + return scripts.NewScriptResult(cadence.NewOptional(cadence.String(code))), nil +} + +func getStagedContractCode( + ctx context.Context, + flow flowkit.Services, + location common.AddressLocation, +) ([]byte, error) { + cAddr := cadence.BytesToAddress(location.Address.Bytes()) + cName, err := cadence.NewString(location.Name) + if err != nil { + return nil, fmt.Errorf("failed to get cadence string from contract name: %w", err) + } value, err := flow.ExecuteScript( context.Background(), flowkit.Script{ Code: templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress(flow.Network().Name)), - Args: []cadence.Value{caddr, cName}, + Args: []cadence.Value{cAddr, cName}, }, flowkit.LatestScriptQuery, ) if err != nil { - return nil, fmt.Errorf("error executing script: %w", err) + return nil, err + } + + optValue, ok := value.(cadence.Optional) + if !ok { + return nil, fmt.Errorf("invalid script return value type: %T", value) + } + + if optValue.Value == nil { + return nil, nil + } + + strValue, ok := optValue.Value.(cadence.String) + if !ok { + return nil, fmt.Errorf("invalid script return value type: %T", value) } - return scripts.NewScriptResult(value), nil + return []byte(strValue), nil } diff --git a/internal/migrate/get_staged_code_test.go b/internal/migrate/get_staged_code_test.go index a80f7fd94..681564c4f 100644 --- a/internal/migrate/get_staged_code_test.go +++ b/internal/migrate/get_staged_code_test.go @@ -79,7 +79,7 @@ func Test_GetStagedCode(t *testing.T) { contractAddr := cadence.NewAddress(account.Address) assert.Equal(t, contractAddr.String(), actualContractAddressArg.String()) - }).Return(cadence.NewString(string(testContract.Source))) + }).Return(cadence.NewOptional(cadence.String(testContract.Source)), nil) result, err := getStagedCode( []string{testContract.Name}, diff --git a/internal/migrate/is_validated.go b/internal/migrate/is_validated.go index e7d8f890b..e0d3d7cac 100644 --- a/internal/migrate/is_validated.go +++ b/internal/migrate/is_validated.go @@ -39,14 +39,14 @@ import ( "github.com/onflow/flow-cli/internal/util" ) -//go:generate mockery --name GitHubRepositoriesService --output ./mocks --case underscore -type GitHubRepositoriesService interface { +//go:generate mockery --name gitHubRepositoriesService --inpackage --testonly --case underscore +type gitHubRepositoriesService interface { GetContents(ctx context.Context, owner string, repo string, path string, opt *github.RepositoryContentGetOptions) (fileContent *github.RepositoryContent, directoryContent []*github.RepositoryContent, resp *github.Response, err error) DownloadContents(ctx context.Context, owner string, repo string, filepath string, opt *github.RepositoryContentGetOptions) (io.ReadCloser, error) } type validator struct { - repoService GitHubRepositoriesService + repoService gitHubRepositoriesService state *flowkit.State logger output.Logger network config.Network @@ -102,7 +102,7 @@ func isValidated( return v.validate(contractName) } -func newValidator(repoService GitHubRepositoriesService, network config.Network, state *flowkit.State, logger output.Logger) *validator { +func newValidator(repoService gitHubRepositoriesService, network config.Network, state *flowkit.State, logger output.Logger) *validator { return &validator{ repoService: repoService, state: state, diff --git a/internal/migrate/is_validated_test.go b/internal/migrate/is_validated_test.go index 18e59b449..33c5ab7e8 100644 --- a/internal/migrate/is_validated_test.go +++ b/internal/migrate/is_validated_test.go @@ -32,7 +32,6 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-cli/internal/command" - "github.com/onflow/flow-cli/internal/migrate/mocks" "github.com/onflow/flow-cli/internal/util" ) @@ -47,7 +46,7 @@ func Test_IsValidated(t *testing.T) { // Helper function to test the isValidated function // with all of the necessary mocks testIsValidatedWithStatuses := func(statuses []contractUpdateStatus) (command.Result, error) { - mockClient := mocks.NewGitHubRepositoriesService(t) + mockClient := newMockGitHubRepositoriesService(t) // mock github file download data, _ := json.Marshal(statuses) diff --git a/internal/migrate/migrate.go b/internal/migrate/migrate.go index 84820b723..3f0348759 100644 --- a/internal/migrate/migrate.go +++ b/internal/migrate/migrate.go @@ -34,7 +34,7 @@ func init() { getStagedCodeCommand.AddToParent(Cmd) IsStagedCommand.AddToParent(Cmd) listStagedContractsCommand.AddToParent(Cmd) - stageContractCommand.AddToParent(Cmd) + stageCommand.AddToParent(Cmd) unstageContractCommand.AddToParent(Cmd) stateCommand.AddToParent(Cmd) IsValidatedCommand.AddToParent(Cmd) diff --git a/internal/migrate/mocks/git_hub_repositories_service.go b/internal/migrate/mock_git_hub_repositories_service_test.go similarity index 74% rename from internal/migrate/mocks/git_hub_repositories_service.go rename to internal/migrate/mock_git_hub_repositories_service_test.go index e586ee016..6c5d0cb5e 100644 --- a/internal/migrate/mocks/git_hub_repositories_service.go +++ b/internal/migrate/mock_git_hub_repositories_service_test.go @@ -1,6 +1,6 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.3. DO NOT EDIT. -package mocks +package migrate import ( context "context" @@ -11,13 +11,13 @@ import ( mock "github.com/stretchr/testify/mock" ) -// GitHubRepositoriesService is an autogenerated mock type for the GitHubRepositoriesService type -type GitHubRepositoriesService struct { +// mockGitHubRepositoriesService is an autogenerated mock type for the gitHubRepositoriesService type +type mockGitHubRepositoriesService struct { mock.Mock } // DownloadContents provides a mock function with given fields: ctx, owner, repo, filepath, opt -func (_m *GitHubRepositoriesService) DownloadContents(ctx context.Context, owner string, repo string, filepath string, opt *github.RepositoryContentGetOptions) (io.ReadCloser, error) { +func (_m *mockGitHubRepositoriesService) DownloadContents(ctx context.Context, owner string, repo string, filepath string, opt *github.RepositoryContentGetOptions) (io.ReadCloser, error) { ret := _m.Called(ctx, owner, repo, filepath, opt) if len(ret) == 0 { @@ -47,7 +47,7 @@ func (_m *GitHubRepositoriesService) DownloadContents(ctx context.Context, owner } // GetContents provides a mock function with given fields: ctx, owner, repo, path, opt -func (_m *GitHubRepositoriesService) GetContents(ctx context.Context, owner string, repo string, path string, opt *github.RepositoryContentGetOptions) (*github.RepositoryContent, []*github.RepositoryContent, *github.Response, error) { +func (_m *mockGitHubRepositoriesService) GetContents(ctx context.Context, owner string, repo string, path string, opt *github.RepositoryContentGetOptions) (*github.RepositoryContent, []*github.RepositoryContent, *github.Response, error) { ret := _m.Called(ctx, owner, repo, path, opt) if len(ret) == 0 { @@ -94,13 +94,13 @@ func (_m *GitHubRepositoriesService) GetContents(ctx context.Context, owner stri return r0, r1, r2, r3 } -// NewGitHubRepositoriesService creates a new instance of GitHubRepositoriesService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// newMockGitHubRepositoriesService creates a new instance of mockGitHubRepositoriesService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewGitHubRepositoriesService(t interface { +func newMockGitHubRepositoriesService(t interface { mock.TestingT Cleanup(func()) -}) *GitHubRepositoriesService { - mock := &GitHubRepositoriesService{} +}) *mockGitHubRepositoriesService { + mock := &mockGitHubRepositoriesService{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/internal/migrate/mock_staging_service_test.go b/internal/migrate/mock_staging_service_test.go new file mode 100644 index 000000000..a7f79ce54 --- /dev/null +++ b/internal/migrate/mock_staging_service_test.go @@ -0,0 +1,80 @@ +// Code generated by mockery v2.40.3. DO NOT EDIT. + +package migrate + +import ( + context "context" + + common "github.com/onflow/cadence/runtime/common" + + mock "github.com/stretchr/testify/mock" + + project "github.com/onflow/flowkit/v2/project" +) + +// mockStagingService is an autogenerated mock type for the stagingService type +type mockStagingService struct { + mock.Mock +} + +// PrettyPrintValidationError provides a mock function with given fields: err, location +func (_m *mockStagingService) PrettyPrintValidationError(err error, location common.Location) string { + ret := _m.Called(err, location) + + if len(ret) == 0 { + panic("no return value specified for PrettyPrintValidationError") + } + + var r0 string + if rf, ok := ret.Get(0).(func(error, common.Location) string); ok { + r0 = rf(err, location) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// StageContracts provides a mock function with given fields: ctx, contracts +func (_m *mockStagingService) StageContracts(ctx context.Context, contracts []*project.Contract) (map[common.AddressLocation]stagingResult, error) { + ret := _m.Called(ctx, contracts) + + if len(ret) == 0 { + panic("no return value specified for StageContracts") + } + + var r0 map[common.AddressLocation]stagingResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []*project.Contract) (map[common.AddressLocation]stagingResult, error)); ok { + return rf(ctx, contracts) + } + if rf, ok := ret.Get(0).(func(context.Context, []*project.Contract) map[common.AddressLocation]stagingResult); ok { + r0 = rf(ctx, contracts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[common.AddressLocation]stagingResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []*project.Contract) error); ok { + r1 = rf(ctx, contracts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// newMockStagingService creates a new instance of mockStagingService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockStagingService(t interface { + mock.TestingT + Cleanup(func()) +}) *mockStagingService { + mock := &mockStagingService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/migrate/mock_staging_validator_test.go b/internal/migrate/mock_staging_validator_test.go new file mode 100644 index 000000000..e16fab916 --- /dev/null +++ b/internal/migrate/mock_staging_validator_test.go @@ -0,0 +1,63 @@ +// Code generated by mockery v2.40.3. DO NOT EDIT. + +package migrate + +import ( + common "github.com/onflow/cadence/runtime/common" + mock "github.com/stretchr/testify/mock" +) + +// mockStagingValidator is an autogenerated mock type for the stagingValidator type +type mockStagingValidator struct { + mock.Mock +} + +// PrettyPrintError provides a mock function with given fields: err, location +func (_m *mockStagingValidator) PrettyPrintError(err error, location common.Location) string { + ret := _m.Called(err, location) + + if len(ret) == 0 { + panic("no return value specified for PrettyPrintError") + } + + var r0 string + if rf, ok := ret.Get(0).(func(error, common.Location) string); ok { + r0 = rf(err, location) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Validate provides a mock function with given fields: stagedContracts +func (_m *mockStagingValidator) Validate(stagedContracts []stagedContractUpdate) error { + ret := _m.Called(stagedContracts) + + if len(ret) == 0 { + panic("no return value specified for Validate") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]stagedContractUpdate) error); ok { + r0 = rf(stagedContracts) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// newMockStagingValidator creates a new instance of mockStagingValidator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockStagingValidator(t interface { + mock.TestingT + Cleanup(func()) +}) *mockStagingValidator { + mock := &mockStagingValidator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/migrate/stage.go b/internal/migrate/stage.go new file mode 100644 index 000000000..9bbcd4d4b --- /dev/null +++ b/internal/migrate/stage.go @@ -0,0 +1,335 @@ +/* + * Flow CLI + * + * Copyright 2019 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrate + +import ( + "context" + "fmt" + "strings" + + "github.com/logrusorgru/aurora/v4" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/flow-go-sdk" + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/output" + "github.com/onflow/flowkit/v2/project" + "github.com/spf13/cobra" + + "github.com/onflow/flow-cli/internal/command" + "github.com/onflow/flow-cli/internal/util" +) + +const stagingLimit = 200 + +type stagingResults struct { + Results map[common.AddressLocation]stagingResult + prettyPrinter func(err error, location common.Location) string +} + +var _ command.ResultWithExitCode = &stagingResults{} + +var stageProjectFlags struct { + Accounts []string `default:"" flag:"account" info:"Accounts to stage the contract under"` + SkipValidation bool `default:"false" flag:"skip-validation" info:"Do not validate the contract code against staged dependencies"` +} + +var stageCommand = &command.Command{ + Cmd: &cobra.Command{ + Use: "stage [contract names...]", + Short: "Stage a contract, or many contracts, for migration", + Example: `# Stage all contracts +flow migrate stage --network testnet + +# Stage by contract name(s) +flow migrate stage Foo Bar --network testnet + +# Stage by account name(s) +flow migrate stage --account my-account --network testnet`, + Args: cobra.ArbitraryArgs, + Aliases: []string{"stage"}, + }, + Flags: &stageProjectFlags, + RunS: stageProject, +} + +func stageProject( + args []string, + globalFlags command.GlobalFlags, + logger output.Logger, + flow flowkit.Services, + state *flowkit.State, +) (command.Result, error) { + err := checkNetwork(flow.Network()) + if err != nil { + return nil, err + } + + // Validate command arguments + if len(stageProjectFlags.Accounts) > 0 && len(args) > 0 { + return nil, fmt.Errorf("only one of contract names or --account can be provided") + } + + // Stage based on flags + var v stagingValidator + if !stageProjectFlags.SkipValidation { + v = newStagingValidator(flow) + } + s := newStagingService(flow, state, logger, v, promptStagingUnvalidatedContracts(logger)) + + if len(args) == 0 && len(stageProjectFlags.Accounts) == 0 { + return stageAll(s, state, flow) + } + + if len(stageProjectFlags.Accounts) > 0 { + return stageByAccountNames(s, state, flow, stageProjectFlags.Accounts) + } + + return stageByContractNames(s, state, flow, args) +} + +func promptStagingUnvalidatedContracts(logger output.Logger) func(validatorError *stagingValidatorError) bool { + return func(validatorError *stagingValidatorError) bool { + infoMessage := strings.Builder{} + infoMessage.WriteString("Preliminary validation could not be performed on the following contracts:\n") + + // Sort the locations for consistent output + missingDependencyErrors := validatorError.MissingDependencyErrors() + sortedLocations := make([]common.AddressLocation, 0, len(missingDependencyErrors)) + for deployLocation := range missingDependencyErrors { + sortedLocations = append(sortedLocations, deployLocation) + } + sortAddressLocations(sortedLocations) + + // Print the locations + for _, deployLocation := range sortedLocations { + infoMessage.WriteString(fmt.Sprintf(" - %s\n", deployLocation)) + } + + infoMessage.WriteString("\nThese contracts depend on the following contracts which have not been staged yet:\n") + + // Print the missing dependencies + missingDependencies := validatorError.MissingDependencies() + for _, depLocation := range missingDependencies { + infoMessage.WriteString(fmt.Sprintf(" - %s\n", depLocation)) + } + + infoMessage.WriteString("\nYou may still stage your contract, however it will be unable to be migrated until the missing contracts are staged by their respective owners. It is important to monitor the status of your contract using the `flow migrate is-validated` command\n") + logger.Error(infoMessage.String()) + + return util.GenericBoolPrompt("Do you wish to continue staging your contract?") + } +} + +func stageAll( + s stagingService, + state *flowkit.State, + flow flowkit.Services, +) (*stagingResults, error) { + contracts, err := state.DeploymentContractsByNetwork(flow.Network()) + if err != nil { + return nil, err + } + + if len(contracts) > stagingLimit { + return nil, fmt.Errorf("cannot stage more than %d contracts at once", stagingLimit) + } + + results, err := s.StageContracts(context.Background(), contracts) + if err != nil { + return nil, err + } + + return &stagingResults{Results: results, prettyPrinter: s.PrettyPrintValidationError}, nil +} + +func stageByContractNames( + s stagingService, + state *flowkit.State, + flow flowkit.Services, + contractNames []string, +) (*stagingResults, error) { + contracts, err := state.DeploymentContractsByNetwork(flow.Network()) + if err != nil { + return nil, err + } + + filteredContracts := make([]*project.Contract, 0) + for _, name := range contractNames { + found := false + for _, contract := range contracts { + if contract.Name == name { + filteredContracts = append(filteredContracts, contract) + found = true + } + } + if !found { + return nil, fmt.Errorf("deployment not found for contract %s on network %s", name, flow.Network().Name) + } + } + + if len(contracts) > stagingLimit { + return nil, fmt.Errorf("cannot stage more than %d contracts at once", stagingLimit) + } + + results, err := s.StageContracts(context.Background(), filteredContracts) + if err != nil { + return nil, err + } + + return &stagingResults{Results: results, prettyPrinter: s.PrettyPrintValidationError}, nil +} + +func stageByAccountNames( + s stagingService, + state *flowkit.State, + flow flowkit.Services, + accountNames []string, +) (*stagingResults, error) { + contracts, err := state.DeploymentContractsByNetwork(flow.Network()) + if err != nil { + return nil, err + } + + filteredContracts := make([]*project.Contract, 0) + for _, accountName := range accountNames { + account, err := state.Accounts().ByName(accountName) + if err != nil { + return nil, err + } + + found := false + for _, contract := range contracts { + if contract.AccountName == account.Name { + filteredContracts = append(filteredContracts, contract) + found = true + } + } + + if !found { + return nil, fmt.Errorf("no deployments found for account %s on network %s", account.Name, flow.Network().Name) + } + } + + if len(contracts) > stagingLimit { + return nil, fmt.Errorf("cannot stage more than %d contracts at once", stagingLimit) + } + + results, err := s.StageContracts(context.Background(), filteredContracts) + if err != nil { + return nil, err + } + + return &stagingResults{Results: results, prettyPrinter: s.PrettyPrintValidationError}, nil +} + +func (r *stagingResults) ExitCode() int { + for _, r := range r.Results { + if r.Err != nil { + return 1 + } + } + return 0 +} + +func (r *stagingResults) String() string { + var sb strings.Builder + + // First print out any errors that occurred during staging + for _, result := range r.Results { + if result.Err != nil { + sb.WriteString(r.prettyPrinter(result.Err, nil)) + sb.WriteString("\n") + } + } + + numStaged := 0 + numUnvalidated := 0 + numFailed := 0 + + for location, result := range r.Results { + var color aurora.Color + var prefix string + + if result.Err == nil { + if result.WasValidated { + color = aurora.GreenFg + prefix = "✔" + numStaged++ + } else { + color = aurora.YellowFg + prefix = "⚠" + numUnvalidated++ + } + } else { + color = aurora.RedFg + prefix = "✘" + numFailed++ + } + + sb.WriteString(aurora.Colorize(fmt.Sprintf("%s %s ", prefix, location.String()), color).String()) + if result.TxId != flow.EmptyID { + sb.WriteString(fmt.Sprintf(" (txId: %s)", result.TxId)) + } else if result.Err == nil { + sb.WriteString(" (no changes)") + } + sb.WriteString("\n") + } + + sb.WriteString("\n") + + reports := []string{} + if numStaged > 0 { + reports = append(reports, aurora.Green(fmt.Sprintf("%d %s staged & validated", numStaged, util.Pluralize("contract", numStaged))).String()) + } + if numUnvalidated > 0 { + reports = append(reports, aurora.Yellow(fmt.Sprintf("%d %s staged without validation", numUnvalidated, util.Pluralize("contract", numStaged))).String()) + } + if numFailed > 0 { + reports = append(reports, aurora.Red(fmt.Sprintf("%d %s failed to stage", numFailed, util.Pluralize("contract", numFailed))).String()) + } + + sb.WriteString(fmt.Sprintf("Staging results: %s\n\n", strings.Join(reports, ", "))) + + sb.WriteString("DISCLAIMER: Pre-staging validation checks are not exhaustive and do not guarantee the contract will work as expected, please monitor the status of your contract using the `flow migrate is-validated` command\n\n") + sb.WriteString("You may use the --skip-validation flag to disable these checks and stage all contracts regardless") + + return sb.String() +} + +func (s *stagingResults) JSON() interface{} { + return s +} + +func (r *stagingResults) Oneliner() string { + if len(r.Results) == 0 { + return "no contracts staged" + } + return fmt.Sprintf("staged %d contracts", len(r.Results)) +} + +// helpers +func boolCount(flags ...bool) int { + count := 0 + for _, flag := range flags { + if flag { + count++ + } + } + return count +} diff --git a/internal/migrate/stage_contract.go b/internal/migrate/stage_contract.go deleted file mode 100644 index b2493cd6c..000000000 --- a/internal/migrate/stage_contract.go +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Flow CLI - * - * Copyright 2019 Dapper Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package migrate - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/manifoldco/promptui" - "github.com/onflow/cadence" - "github.com/onflow/cadence/runtime/common" - "github.com/onflow/contract-updater/lib/go/templates" - flowsdk "github.com/onflow/flow-go-sdk" - "github.com/onflow/flowkit/v2" - "github.com/onflow/flowkit/v2/output" - "github.com/onflow/flowkit/v2/transactions" - "github.com/spf13/cobra" - - internaltx "github.com/onflow/flow-cli/internal/transactions" - - "github.com/onflow/flow-cli/internal/command" -) - -var stageContractflags struct { - SkipValidation bool `default:"false" flag:"skip-validation" info:"Do not validate the contract code against staged dependencies"` -} - -var stageContractCommand = &command.Command{ - Cmd: &cobra.Command{ - Use: "stage-contract ", - Short: "stage a contract for migration", - Example: `flow migrate stage-contract HelloWorld`, - Args: cobra.MinimumNArgs(1), - }, - Flags: &stageContractflags, - RunS: stageContract, -} - -func stageContract( - args []string, - globalFlags command.GlobalFlags, - logger output.Logger, - flow flowkit.Services, - state *flowkit.State, -) (command.Result, error) { - err := checkNetwork(flow.Network()) - if err != nil { - return nil, err - } - - contractName := args[0] - contract, err := state.Contracts().ByName(contractName) - if err != nil { - return nil, fmt.Errorf("no contracts found in state") - } - - replacedCode, err := replaceImportsIfExists(state, flow, contract.Location) - if err != nil { - return nil, fmt.Errorf("failed to replace imports: %w", err) - } - - cName, err := cadence.NewString(contractName) - if err != nil { - return nil, fmt.Errorf("failed to get cadence string from contract name: %w", err) - } - - cCode, err := cadence.NewString(string(replacedCode)) - if err != nil { - return nil, fmt.Errorf("failed to get cadence string from contract code: %w", err) - } - - account, err := getAccountByContractName(state, contractName, flow.Network()) - if err != nil { - return nil, fmt.Errorf("failed to get account by contract name: %w", err) - } - - // Validate the contract code by default - if !stageContractflags.SkipValidation { - logger.StartProgress("Validating contract code against any staged dependencies") - validator := newStagingValidator(flow, state) - - var missingDependenciesErr *missingDependenciesError - contractLocation := common.NewAddressLocation(nil, common.Address(account.Address), contractName) - err = validator.ValidateContractUpdate(contractLocation, common.StringLocation(contract.Location), replacedCode) - - logger.StopProgress() - - // Errors when the contract's dependencies have not been staged yet are non-fatal - // This is because the contract may be dependent on contracts that are not yet staged - // and we do not want to require in-order staging of contracts - // Instead, we will prompt the user to continue staging the contract. Other errors - // will be fatal and require manual intervention using the --skip-validation flag if desired - if errors.As(err, &missingDependenciesErr) { - infoMessage := strings.Builder{} - infoMessage.WriteString("Validation cannot be performed as some of your contract's dependencies could not be found (have they been staged yet?)\n") - for _, contract := range missingDependenciesErr.MissingContracts { - infoMessage.WriteString(fmt.Sprintf(" - %s\n", contract)) - } - infoMessage.WriteString("\nYou may still stage your contract, however it will be unable to be migrated until the missing contracts are staged by their respective owners. It is important to monitor the status of your contract using the `flow migrate is-validated` command\n") - logger.Error(infoMessage.String()) - - continuePrompt := promptui.Select{ - Label: "Do you wish to continue staging your contract?", - Items: []string{"Yes", "No"}, - } - - _, result, err := continuePrompt.Run() - if err != nil { - return nil, err - } - - if result == "No" { - return nil, fmt.Errorf("staging cancelled") - } - } else if err != nil { - logger.Error(validator.prettyPrintError(err, common.StringLocation(contract.Location))) - return nil, fmt.Errorf("errors were found while attempting to perform preliminary validation of the contract code, and your contract HAS NOT been staged, however you can use the --skip-validation flag to bypass this check & stage the contract anyway") - } else { - logger.Info("No issues found while validating contract code\n") - logger.Info("DISCLAIMER: Pre-staging validation checks are not exhaustive and do not guarantee the contract will work as expected, please monitor the status of your contract using the `flow migrate is-validated` command\n") - } - } else { - logger.Info("Skipping contract code validation, you may monitor the status of your contract using the `flow migrate is-validated` command\n") - } - - tx, res, err := flow.SendTransaction( - context.Background(), - transactions.SingleAccountRole(*account), - flowkit.Script{ - Code: templates.GenerateStageContractScript(MigrationContractStagingAddress(flow.Network().Name)), - Args: []cadence.Value{cName, cCode}, - }, - flowsdk.DefaultTransactionGasLimit, - ) - - if err != nil { - return nil, fmt.Errorf("failed to send transaction: %w", err) - } - - return internaltx.NewTransactionResult(tx, res), nil -} diff --git a/internal/migrate/stage_contract_test.go b/internal/migrate/stage_contract_test.go deleted file mode 100644 index 02533d794..000000000 --- a/internal/migrate/stage_contract_test.go +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Flow CLI - * - * Copyright 2019 Dapper Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package migrate - -import ( - "testing" - - "github.com/onflow/cadence" - "github.com/onflow/contract-updater/lib/go/templates" - "github.com/onflow/flow-go-sdk" - "github.com/onflow/flowkit/v2" - "github.com/onflow/flowkit/v2/config" - "github.com/onflow/flowkit/v2/tests" - "github.com/onflow/flowkit/v2/transactions" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/onflow/flow-cli/internal/command" - "github.com/onflow/flow-cli/internal/util" -) - -func Test_StageContract(t *testing.T) { - testContract := tests.ContractSimple - - t.Run("Success", func(t *testing.T) { - srv, state, _ := util.TestMocks(t) - - // Add contract to state - state.Contracts().AddOrUpdate( - config.Contract{ - Name: testContract.Name, - Location: testContract.Filename, - }, - ) - - // Add deployment to state - state.Deployments().AddOrUpdate( - config.Deployment{ - Network: "testnet", - Account: "emulator-account", - Contracts: []config.ContractDeployment{ - { - Name: testContract.Name, - }, - }, - }, - ) - - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - - srv.SendTransaction.Run(func(args mock.Arguments) { - accountRoles := args.Get(1).(transactions.AccountRoles) - script := args.Get(2).(flowkit.Script) - - assert.Equal(t, templates.GenerateStageContractScript(MigrationContractStagingAddress("testnet")), script.Code) - - assert.Equal(t, 1, len(accountRoles.Signers())) - assert.Equal(t, "emulator-account", accountRoles.Signers()[0].Name) - assert.Equal(t, 2, len(script.Args)) - - actualContractNameArg, actualContractCodeArg := script.Args[0], script.Args[1] - - contractName, _ := cadence.NewString(testContract.Name) - contractBody, _ := cadence.NewString(string(testContract.Source)) - assert.Equal(t, contractName, actualContractNameArg) - assert.Equal(t, contractBody, actualContractCodeArg) - }).Return(flow.NewTransaction(), &flow.TransactionResult{ - Status: flow.TransactionStatusSealed, - Error: nil, - BlockHeight: 1, - }, nil) - - // disable validation - stageContractflags.SkipValidation = true - - result, err := stageContract( - []string{testContract.Name}, - command.GlobalFlags{ - Network: "testnet", - }, - util.NoLogger, - srv.Mock, - state, - ) - // reset flags - stageContractflags.SkipValidation = false - - assert.NoError(t, err) - assert.NotNil(t, result) - }) - - t.Run("missing contract", func(t *testing.T) { - srv, state, _ := util.TestMocks(t) - result, err := stageContract( - []string{testContract.Name}, - command.GlobalFlags{ - Network: "testnet", - }, - util.NoLogger, - srv.Mock, - state, - ) - assert.Error(t, err) - assert.Nil(t, result) - }) -} diff --git a/internal/migrate/stage_test.go b/internal/migrate/stage_test.go new file mode 100644 index 000000000..eaf0cb166 --- /dev/null +++ b/internal/migrate/stage_test.go @@ -0,0 +1,189 @@ +/* + * Flow CLI + * + * Copyright 2019 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrate + +import ( + "testing" + + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/config" + flowkitMocks "github.com/onflow/flowkit/v2/mocks" + "github.com/onflow/flowkit/v2/tests" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_StageContract(t *testing.T) { + setupMocks := func( + accts []mockAccount, + ) (*mockStagingService, flowkit.Services, *flowkit.State) { + ss := newMockStagingService(t) + srv := flowkitMocks.NewServices(t) + rw, _ := tests.ReaderWriter() + state, _ := flowkit.Init(rw) + + addAccountsToState(t, state, accts) + + srv.On("Network").Return(config.Network{ + Name: "testnet", + }, nil) + + return ss, srv, state + } + + t.Run("all contracts filter", func(t *testing.T) { + ss, srv, state := setupMocks([]mockAccount{ + { + name: "my-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: `FooCode`, + }, + }, + }, + }) + + mockResult := make(map[common.AddressLocation]stagingResult) + mockResult[common.NewAddressLocation(nil, common.Address{0x01}, "Foo")] = stagingResult{ + Err: nil, + } + + ss.On("StageContracts", mock.Anything, mock.Anything).Return(mockResult, nil) + + result, err := stageAll( + ss, + state, + srv, + ) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, mockResult, result.Results) + }) + + t.Run("contract name filter", func(t *testing.T) { + ss, srv, state := setupMocks([]mockAccount{ + { + name: "my-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: `FooCode`, + }, + { + name: "Bar", + code: `BarCode`, + }, + }, + }, + }) + + mockResult := make(map[common.AddressLocation]stagingResult) + mockResult[common.NewAddressLocation(nil, common.Address{0x01}, "Foo")] = stagingResult{ + Err: nil, + } + + ss.On("StageContracts", mock.Anything, mock.Anything).Return(mockResult, nil).Once() + + result, err := stageByContractNames( + ss, + state, + srv, + []string{"Foo"}, + ) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, mockResult, result.Results) + }) + + t.Run("contract name filter", func(t *testing.T) { + ss, srv, state := setupMocks([]mockAccount{ + { + name: "my-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: `FooCode`, + }, + }, + }, + { + name: "other-account", + address: "0x02", + deployments: []mockDeployment{ + { + name: "Bar", + code: `BarCode`, + }, + }, + }, + }) + + mockResult := make(map[common.AddressLocation]stagingResult) + mockResult[common.NewAddressLocation(nil, common.Address{0x01}, "Foo")] = stagingResult{ + Err: nil, + } + ss.On("StageContracts", mock.Anything, mock.Anything).Return(mockResult, nil).Once() + + result, err := stageByAccountNames( + ss, + state, + srv, + []string{"my-account"}, + ) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, mockResult, result.Results) + }) + + t.Run("contract name not found", func(t *testing.T) { + ss, srv, state := setupMocks(nil) + + result, err := stageByContractNames( + ss, + state, + srv, + []string{"my-contract"}, + ) + + assert.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("account not found", func(t *testing.T) { + ss, srv, state := setupMocks(nil) + + result, err := stageByAccountNames( + ss, + state, + srv, + []string{"my-account"}, + ) + + assert.Error(t, err) + assert.Nil(t, result) + }) +} diff --git a/internal/migrate/staging_service.go b/internal/migrate/staging_service.go new file mode 100644 index 000000000..2ed9ab368 --- /dev/null +++ b/internal/migrate/staging_service.go @@ -0,0 +1,283 @@ +/* + * Flow CLI + * + * Copyright 2019 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrate + +import ( + "context" + "errors" + "fmt" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/contract-updater/lib/go/templates" + "github.com/onflow/flow-go-sdk" + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/output" + "github.com/onflow/flowkit/v2/project" + "github.com/onflow/flowkit/v2/transactions" +) + +//go:generate mockery --name stagingService --inpackage --testonly --case underscore +type stagingService interface { + StageContracts(ctx context.Context, contracts []*project.Contract) (map[common.AddressLocation]stagingResult, error) + PrettyPrintValidationError(err error, location common.Location) string +} + +type stagingServiceImpl struct { + flow flowkit.Services + state *flowkit.State + logger output.Logger + validator stagingValidator + unvalidatedContractsHandler func(*stagingValidatorError) bool +} + +type stagingResult struct { + Err error + WasValidated bool + TxId flow.Identifier +} + +var _ stagingService = &stagingServiceImpl{} + +func newStagingService( + flow flowkit.Services, + state *flowkit.State, + logger output.Logger, + validator stagingValidator, + unvalidatedContractsHandler func(*stagingValidatorError) bool, +) *stagingServiceImpl { + handler := func(err *stagingValidatorError) bool { + return false + } + if unvalidatedContractsHandler != nil { + handler = unvalidatedContractsHandler + } + + return &stagingServiceImpl{ + flow: flow, + state: state, + logger: logger, + validator: validator, + unvalidatedContractsHandler: handler, + } +} + +func (s *stagingServiceImpl) StageContracts(ctx context.Context, contracts []*project.Contract) (map[common.AddressLocation]stagingResult, error) { + // Convert contracts to staged contracts + stagedContracts, err := s.convertToStagedContracts(contracts) + if err != nil { + return nil, err + } + + // If validation is disabled, just stage the contracts + if s.validator == nil { + s.logger.Info("Skipping contract code validation, you may monitor the status of your contract using the `flow migrate is-validated` command\n") + s.logger.StartProgress(fmt.Sprintf("Staging %d contracts for accounts: %s", len(contracts), s.state.AccountsForNetwork(s.flow.Network()).String())) + defer s.logger.StopProgress() + + results := s.stageContracts(ctx, stagedContracts) + return results, nil + } + + // Otherwise, validate and stage the contracts + return s.validateAndStageContracts(ctx, stagedContracts) +} + +func (s *stagingServiceImpl) validateAndStageContracts(ctx context.Context, contracts []stagedContractUpdate) (map[common.AddressLocation]stagingResult, error) { + s.logger.StartProgress(fmt.Sprintf("Validating and staging %d contracts", len(contracts))) + defer s.logger.StopProgress() + + // Validate all contracts + var validatorError *stagingValidatorError + err := s.validator.Validate(contracts) + + // We will handle validation errors separately per contract to allow for partial staging + if err != nil && !errors.As(err, &validatorError) { + return nil, fmt.Errorf("failed to validate contracts: %w", err) + } + + // Collect all staging errors to report to the user + results := make(map[common.AddressLocation]stagingResult) + if validatorError != nil { + for location, err := range validatorError.errors { + results[location] = stagingResult{ + Err: err, + WasValidated: true, + } + } + } + + // Split contracts into valid, and contracts with missing dependencies + missingDepsContracts := make([]stagedContractUpdate, 0) + validContracts := make([]stagedContractUpdate, 0) + if validatorError == nil { + validContracts = contracts + } else { + for _, contract := range contracts { + contractErr := validatorError.errors[contract.DeployLocation] + + var missingDepsError *missingDependenciesError + if errors.As(contractErr, &missingDepsError) { + missingDepsContracts = append(missingDepsContracts, contract) + } else if contractErr == nil { + validContracts = append(validContracts, contract) + } + } + } + + s.logger.StopProgress() + + // Now, handle contracts that were not validated due to missing dependencies + if len(missingDepsContracts) > 0 && s.unvalidatedContractsHandler(validatorError) { + for location, res := range s.stageContracts(ctx, missingDepsContracts) { + res.WasValidated = false + results[location] = res + } + } + + // Stage contracts that passed validation + for contractLocation, res := range s.stageContracts(ctx, validContracts) { + res.WasValidated = true + results[contractLocation] = res + } + + return results, nil +} + +func (s *stagingServiceImpl) stageContracts(ctx context.Context, contracts []stagedContractUpdate) map[common.AddressLocation]stagingResult { + results := make(map[common.AddressLocation]stagingResult) + for _, contract := range contracts { + txId, err := s.stageContract( + ctx, + contract, + ) + if err != nil { + results[contract.DeployLocation] = stagingResult{ + Err: fmt.Errorf("failed to stage contract: %w", err), + TxId: txId, + } + } else { + results[contract.DeployLocation] = stagingResult{ + Err: nil, + TxId: txId, + } + } + } + + return results +} + +func (s *stagingServiceImpl) stageContract(ctx context.Context, contract stagedContractUpdate) (flow.Identifier, error) { + s.logger.StartProgress(fmt.Sprintf("Staging contract %s", contract.DeployLocation)) + defer s.logger.StopProgress() + + // Check if the staged contract has changed + if !s.hasStagedContractChanged(contract) { + return flow.EmptyID, nil + } + + cName := cadence.String(contract.DeployLocation.Name) + cCode := cadence.String(contract.Code) + + // Get the account for the contract + account, err := s.state.Accounts().ByAddress(flow.Address(contract.DeployLocation.Address)) + if err != nil { + return flow.Identifier{}, fmt.Errorf("failed to get account for contract %s: %w", contract.DeployLocation.Name, err) + } + + tx, _, err := s.flow.SendTransaction( + context.Background(), + transactions.SingleAccountRole(*account), + flowkit.Script{ + Code: templates.GenerateStageContractScript(MigrationContractStagingAddress(s.flow.Network().Name)), + Args: []cadence.Value{cName, cCode}, + }, + flow.DefaultTransactionGasLimit, + ) + if err != nil { + return flow.Identifier{}, err + } + + return tx.ID(), nil +} + +func (s *stagingServiceImpl) hasStagedContractChanged(contract stagedContractUpdate) bool { + // Get the staged contract code + stagedCode, err := getStagedContractCode(context.Background(), s.flow, contract.DeployLocation) + if err != nil { + // swallow error, if we can't get the staged contract code, we should stage + return true + } + + if stagedCode == nil { + return true + } + + // If the staged contract code is different from the contract code, we need to stage it + if string(stagedCode) != string(contract.Code) { + return true + } + + return false +} + +func (s *stagingServiceImpl) convertToStagedContracts(contracts []*project.Contract) ([]stagedContractUpdate, error) { + // Collect all staged contracts + stagedContracts := make([]stagedContractUpdate, 0) + for _, contract := range contracts { + rawScript := flowkit.Script{ + Code: contract.Code(), + Location: contract.Location(), + Args: contract.Args, + } + + // Replace imports in the contract + script, err := s.flow.ReplaceImportsInScript(context.Background(), rawScript) + if err != nil { + return nil, fmt.Errorf("failed to replace imports in contract %s: %w", contract.Name, err) + } + + // We need the real name of the contract, not the name in flow.json + program, err := project.NewProgram(script.Code, script.Args, script.Location) + if err != nil { + return nil, fmt.Errorf("failed to parse contract %s: %w", contract.Name, err) + } + + name, err := program.Name() + if err != nil { + return nil, fmt.Errorf("failed to parse contract name: %w", err) + } + + // Convert relevant information to Cadence types + deployLocation := common.NewAddressLocation(nil, common.Address(contract.AccountAddress), name) + sourceLocation := common.StringLocation(contract.Location()) + + stagedContracts = append(stagedContracts, stagedContractUpdate{ + DeployLocation: deployLocation, + SourceLocation: sourceLocation, + Code: script.Code, + }) + } + + return stagedContracts, nil +} + +func (v *stagingServiceImpl) PrettyPrintValidationError(err error, location common.Location) string { + return v.validator.PrettyPrintError(err, location) +} diff --git a/internal/migrate/staging_service_test.go b/internal/migrate/staging_service_test.go new file mode 100644 index 000000000..1ee22d3fb --- /dev/null +++ b/internal/migrate/staging_service_test.go @@ -0,0 +1,559 @@ +/* + * Flow CLI + * + * Copyright 2019 Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package migrate + +import ( + "context" + "errors" + "path/filepath" + "reflect" + "testing" + + "github.com/onflow/cadence" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/contract-updater/lib/go/templates" + "github.com/onflow/flow-go-sdk" + "github.com/onflow/flow-go-sdk/crypto" + "github.com/onflow/flowkit/v2" + "github.com/onflow/flowkit/v2/accounts" + "github.com/onflow/flowkit/v2/config" + flowkitMocks "github.com/onflow/flowkit/v2/mocks" + "github.com/onflow/flowkit/v2/project" + "github.com/onflow/flowkit/v2/tests" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-cli/internal/util" +) + +type mockDeployment struct { + name string + code string +} + +type mockAccount struct { + name string + address string + deployments []mockDeployment +} + +func addAccountsToState( + t *testing.T, + state *flowkit.State, + accts []mockAccount, +) { + for _, account := range accts { + key, err := crypto.GeneratePrivateKey(crypto.ECDSA_P256, make([]byte, 32)) + require.NoError(t, err) + + state.Accounts().AddOrUpdate( + &accounts.Account{ + Name: account.name, + Address: flow.HexToAddress(account.address), + Key: accounts.NewHexKeyFromPrivateKey( + 0, + crypto.SHA3_256, + key, + ), + }, + ) + + contractDeployments := make([]config.ContractDeployment, 0) + for _, deployment := range account.deployments { + fname := account.address + "/" + deployment.name + ".cdc" + require.NoError(t, state.ReaderWriter().WriteFile(fname, []byte(deployment.code), 0644)) + + state.Contracts().AddOrUpdate( + config.Contract{ + Name: deployment.name, + Location: fname, + }, + ) + + contractDeployments = append( + contractDeployments, + config.ContractDeployment{ + Name: deployment.name, + }, + ) + } + + state.Deployments().AddOrUpdate( + config.Deployment{ + Network: "testnet", + Account: account.name, + Contracts: contractDeployments, + }, + ) + } +} + +func Test_StagingService(t *testing.T) { + setupMocks := func( + accts []mockAccount, + mocksStagedContracts map[common.AddressLocation][]byte, + ) (*flowkitMocks.Services, *flowkit.State, []*project.Contract) { + srv := flowkitMocks.NewServices(t) + rw, _ := tests.ReaderWriter() + state, err := flowkit.Init(rw) + require.NoError(t, err) + + addAccountsToState(t, state, accts) + + srv.On("Network", mock.Anything).Return(config.Network{ + Name: "testnet", + }, nil).Maybe() + + srv.On("ReplaceImportsInScript", mock.Anything, mock.Anything).Return(func(_ context.Context, script flowkit.Script) (flowkit.Script, error) { + return script, nil + }).Maybe() + + deploymentContracts, err := state.DeploymentContractsByNetwork(config.TestnetNetwork) + require.NoError(t, err) + + srv.On("SendTransaction", mock.Anything, mock.Anything, mock.MatchedBy(func(script flowkit.Script) bool { + expectedScript := templates.GenerateStageContractScript(MigrationContractStagingAddress("testnet")) + if string(script.Code) != string(expectedScript) { + return false + } + + if len(script.Args) != 2 { + return false + } + + _, ok := script.Args[0].(cadence.String) + if !ok { + return false + } + + _, ok = script.Args[1].(cadence.String) + return ok + }), mock.Anything).Return(tests.NewTransaction(), nil, nil).Maybe() + + // Mock staged contracts on network + for location, code := range mocksStagedContracts { + srv.On( + "ExecuteScript", + mock.Anything, + mock.MatchedBy(func(script flowkit.Script) bool { + if string(script.Code) != string(templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet"))) { + return false + } + + if len(script.Args) != 2 { + return false + } + + callContractAddress, callContractName := script.Args[0], script.Args[1] + if callContractName != cadence.String(location.Name) { + return false + } + if callContractAddress != cadence.Address(location.Address) { + return false + } + + return true + }), + mock.Anything, + ).Return(cadence.NewOptional(cadence.String(code)), nil).Maybe() + } + + // Default all staged contracts to nil + srv.On( + "ExecuteScript", + mock.Anything, + mock.MatchedBy(func(script flowkit.Script) bool { + if string(script.Code) != string(templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet"))) { + return false + } + + if len(script.Args) != 2 { + return false + } + + return true + }), + mock.Anything, + ).Return(cadence.NewOptional(nil), nil).Maybe() + + return srv, state, deploymentContracts + } + + t.Run("stages valid contracts", func(t *testing.T) { + mockAccount := []mockAccount{ + { + name: "some-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: "access(all) contract Foo {}", + }, + { + name: "Bar", + code: "access(all) contract Bar {}", + }, + }, + }, + } + srv, state, deploymentContracts := setupMocks(mockAccount, nil) + + v := newMockStagingValidator(t) + v.On("Validate", mock.MatchedBy(func(stagedContracts []stagedContractUpdate) bool { + return reflect.DeepEqual(stagedContracts, []stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Foo.cdc")), + Code: []byte("access(all) contract Foo {}"), + }, + { + DeployLocation: simpleAddressLocation("0x01.Bar"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Bar.cdc")), + Code: []byte("access(all) contract Bar {}"), + }, + }) + })).Return(nil).Once() + + s := newStagingService( + srv, + state, + util.NoLogger, + v, + func(sve *stagingValidatorError) bool { + return false + }, + ) + + results, err := s.StageContracts( + context.Background(), + deploymentContracts, + ) + + require.NoError(t, err) + require.NotNil(t, results) + + require.Equal(t, 2, len(results)) + require.Contains(t, results, simpleAddressLocation("0x01.Foo")) + require.Nil(t, results[simpleAddressLocation("0x01.Foo")].Err) + require.Equal(t, results[simpleAddressLocation("0x01.Foo")].WasValidated, true) + require.IsType(t, flow.Identifier{}, results[simpleAddressLocation("0x01.Foo")].TxId) + + require.Contains(t, results, simpleAddressLocation("0x01.Bar")) + require.Nil(t, results[simpleAddressLocation("0x01.Bar")].Err) + require.Equal(t, results[simpleAddressLocation("0x01.Bar")].WasValidated, true) + require.IsType(t, flow.Identifier{}, results[simpleAddressLocation("0x01.Bar")].TxId) + + srv.AssertNumberOfCalls(t, "SendTransaction", 2) + }) + + t.Run("stages unvalidated contracts if chosen", func(t *testing.T) { + mockAccount := []mockAccount{ + { + name: "some-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: "access(all) contract Foo {}", + }, + }, + }, + } + srv, state, deploymentContracts := setupMocks(mockAccount, nil) + + v := newMockStagingValidator(t) + v.On("Validate", mock.MatchedBy(func(stagedContracts []stagedContractUpdate) bool { + return reflect.DeepEqual(stagedContracts, []stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Foo.cdc")), + Code: []byte("access(all) contract Foo {}"), + }, + }) + })).Return(&stagingValidatorError{ + errors: map[common.AddressLocation]error{ + simpleAddressLocation("0x01.Foo"): &missingDependenciesError{ + MissingContracts: []common.AddressLocation{ + simpleAddressLocation("0x02.Bar"), + }, + }, + }, + }).Once() + + s := newStagingService( + srv, + state, + util.NoLogger, + v, + func(sve *stagingValidatorError) bool { + require.NotNil(t, sve) + return true + }, + ) + + results, err := s.StageContracts( + context.Background(), + deploymentContracts, + ) + + require.NoError(t, err) + require.NotNil(t, results) + + require.Equal(t, 1, len(results)) + require.Contains(t, results, simpleAddressLocation("0x01.Foo")) + require.Nil(t, results[simpleAddressLocation("0x01.Foo")].Err) + require.Equal(t, results[simpleAddressLocation("0x01.Foo")].WasValidated, false) + require.IsType(t, flow.Identifier{}, results[simpleAddressLocation("0x01.Foo")].TxId) + + srv.AssertNumberOfCalls(t, "SendTransaction", 1) + }) + + t.Run("skips validation if no validator", func(t *testing.T) { + mockAccount := []mockAccount{ + { + name: "some-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: "access(all) contract Foo {}", + }, + }, + }, + } + srv, state, deploymentContracts := setupMocks(mockAccount, nil) + + s := newStagingService( + srv, + state, + util.NoLogger, + nil, + func(sve *stagingValidatorError) bool { + require.NotNil(t, sve) + return true + }, + ) + + results, err := s.StageContracts( + context.Background(), + deploymentContracts, + ) + + require.NoError(t, err) + require.NotNil(t, results) + + require.Equal(t, 1, len(results)) + require.Contains(t, results, simpleAddressLocation("0x01.Foo")) + require.Nil(t, results[simpleAddressLocation("0x01.Foo")].Err) + require.Equal(t, results[simpleAddressLocation("0x01.Foo")].WasValidated, false) + require.IsType(t, flow.Identifier{}, results[simpleAddressLocation("0x01.Foo")].TxId) + + srv.AssertNumberOfCalls(t, "SendTransaction", 1) + }) + + t.Run("returns missing dependency error if staging not chosen", func(t *testing.T) { + mockAccount := []mockAccount{ + { + name: "some-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: "access(all) contract Foo {}", + }, + }, + }, + } + srv, state, deploymentContracts := setupMocks(mockAccount, nil) + + v := newMockStagingValidator(t) + v.On("Validate", mock.MatchedBy(func(stagedContracts []stagedContractUpdate) bool { + return reflect.DeepEqual(stagedContracts, []stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Foo.cdc")), + Code: []byte("access(all) contract Foo {}"), + }, + }) + })).Return(&stagingValidatorError{ + errors: map[common.AddressLocation]error{ + simpleAddressLocation("0x01.Foo"): &missingDependenciesError{ + MissingContracts: []common.AddressLocation{ + simpleAddressLocation("0x02.Bar"), + }, + }, + }, + }).Once() + + s := newStagingService( + srv, + state, + util.NoLogger, + v, + func(sve *stagingValidatorError) bool { + require.NotNil(t, sve) + return false + }, + ) + + results, err := s.StageContracts( + context.Background(), + deploymentContracts, + ) + + require.NoError(t, err) + require.NotNil(t, results) + + require.Equal(t, 1, len(results)) + require.Contains(t, results, simpleAddressLocation("0x01.Foo")) + + var mde *missingDependenciesError + require.ErrorAs(t, results[simpleAddressLocation("0x01.Foo")].Err, &mde) + require.NotNil(t, results[simpleAddressLocation("0x01.Foo")].Err) + require.Equal(t, []common.AddressLocation{simpleAddressLocation("0x02.Bar")}, mde.MissingContracts) + require.Equal(t, results[simpleAddressLocation("0x01.Foo")].WasValidated, true) + + srv.AssertNumberOfCalls(t, "SendTransaction", 0) + }) + + t.Run("reports and does not stage invalid contracts", func(t *testing.T) { + mockAccount := []mockAccount{ + { + name: "some-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: "access(all) contract Foo {}", + }, + { + name: "Bar", + code: "access(all) contract Bar {}", + }, + }, + }, + } + srv, state, deploymentContracts := setupMocks(mockAccount, nil) + + v := newMockStagingValidator(t) + v.On("Validate", mock.MatchedBy(func(stagedContracts []stagedContractUpdate) bool { + return reflect.DeepEqual(stagedContracts, []stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Foo.cdc")), + Code: []byte("access(all) contract Foo {}"), + }, + { + DeployLocation: simpleAddressLocation("0x01.Bar"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Bar.cdc")), + Code: []byte("access(all) contract Bar {}"), + }, + }) + })).Return(&stagingValidatorError{ + errors: map[common.AddressLocation]error{ + simpleAddressLocation("0x01.Foo"): errors.New("FooError"), + }, + }).Once() + + s := newStagingService( + srv, + state, + util.NoLogger, + v, + func(sve *stagingValidatorError) bool { + return false + }, + ) + + results, err := s.StageContracts( + context.Background(), + deploymentContracts, + ) + + require.NoError(t, err) + require.NotNil(t, results) + + require.Equal(t, 2, len(results)) + require.Contains(t, results, simpleAddressLocation("0x01.Foo")) + require.ErrorContains(t, results[simpleAddressLocation("0x01.Foo")].Err, "FooError") + require.Equal(t, results[simpleAddressLocation("0x01.Foo")].WasValidated, true) + + require.Contains(t, results, simpleAddressLocation("0x01.Bar")) + require.Nil(t, results[simpleAddressLocation("0x01.Bar")].Err) + require.Equal(t, results[simpleAddressLocation("0x01.Bar")].WasValidated, true) + require.IsType(t, flow.Identifier{}, results[simpleAddressLocation("0x01.Bar")].TxId) + + srv.AssertNumberOfCalls(t, "SendTransaction", 1) + }) + + t.Run("skips staging contracts without changes", func(t *testing.T) { + mockAccount := []mockAccount{ + { + name: "some-account", + address: "0x01", + deployments: []mockDeployment{ + { + name: "Foo", + code: "access(all) contract Foo {}", + }, + }, + }, + } + srv, state, deploymentContracts := setupMocks(mockAccount, map[common.AddressLocation][]byte{ + simpleAddressLocation("0x01.Foo"): []byte("access(all) contract Foo {}"), + }) + + v := newMockStagingValidator(t) + v.On("Validate", mock.MatchedBy(func(stagedContracts []stagedContractUpdate) bool { + return reflect.DeepEqual(stagedContracts, []stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation(filepath.FromSlash("0x01/Foo.cdc")), + Code: []byte("access(all) contract Foo {}"), + }, + }) + })).Return(nil).Once() + + s := newStagingService( + srv, + state, + util.NoLogger, + v, + func(sve *stagingValidatorError) bool { + return false + }, + ) + + results, err := s.StageContracts( + context.Background(), + deploymentContracts, + ) + + require.NoError(t, err) + require.NotNil(t, results) + + require.Equal(t, 1, len(results)) + require.Contains(t, results, simpleAddressLocation("0x01.Foo")) + require.Nil(t, results[simpleAddressLocation("0x01.Foo")].Err) + require.Equal(t, true, results[simpleAddressLocation("0x01.Foo")].WasValidated) + require.Equal(t, flow.Identifier{}, results[simpleAddressLocation("0x01.Foo")].TxId) + + srv.AssertNumberOfCalls(t, "SendTransaction", 0) + }) +} diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index dac94e35b..2e92693c2 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -19,10 +19,14 @@ package migrate import ( + "bytes" "context" + "errors" "fmt" "strings" + "golang.org/x/exp/slices" + "github.com/onflow/cadence" "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/ast" @@ -42,23 +46,41 @@ import ( "github.com/onflow/flow-cli/internal/util" ) -type stagingValidator struct { - flow flowkit.Services - state *flowkit.State +//go:generate mockery --name stagingValidator --inpackage --testonly --case underscore +type stagingValidator interface { + Validate(stagedContracts []stagedContractUpdate) error + PrettyPrintError(err error, location common.Location) string +} + +type stagingValidatorImpl struct { + flow flowkit.Services - // Location of the source code that is used for the update - sourceCodeLocation common.Location - // Location of the target contract that is being updated - targetLocation common.AddressLocation + stagedContracts map[common.AddressLocation]stagedContractUpdate // Cache for account contract names so we don't have to fetch them multiple times accountContractNames map[common.Address][]string // All resolved contract code contracts map[common.Location][]byte - // Record errors related to missing staged dependencies, as these are reported separately - missingDependencies []common.AddressLocation - // Cache for contract elaborations which are reused during program checking & used for the update checker - elaborations map[common.Location]*sema.Elaboration + + // Dependency graph for staged contracts + // This root level map holds all nodes + graph map[common.Location]node + + // Cache for contract checkers which are reused during program checking & used for the update checker + checkingCache map[common.Location]*cachedCheckingResult +} + +type node map[common.Location]node + +type cachedCheckingResult struct { + checker *sema.Checker + err error +} + +type stagedContractUpdate struct { + DeployLocation common.AddressLocation + SourceLocation common.StringLocation + Code []byte } type accountContractNamesProviderImpl struct { @@ -81,44 +103,202 @@ func (e *missingDependenciesError) Error() string { var _ error = &missingDependenciesError{} +type upstreamValidationError struct { + Location common.Location + BadDependencies []common.Location +} + +func (e *upstreamValidationError) Error() string { + return fmt.Sprintf("contract %s has upstream validation errors, related to the following dependencies: %v", e.Location, e.BadDependencies) +} + +var _ error = &upstreamValidationError{} + +type stagingValidatorError struct { + errors map[common.AddressLocation]error +} + +func (e *stagingValidatorError) Error() string { + var sb strings.Builder + for location, err := range e.errors { + sb.WriteString(fmt.Sprintf("error for contract %s: %s\n", location, err)) + } + return sb.String() +} + +// MissingDependencies returns the contracts dependended on by the staged contracts that are missing +func (e *stagingValidatorError) MissingDependencies() []common.AddressLocation { + missingDepsMap := make(map[common.AddressLocation]struct{}) + for _, err := range e.MissingDependencyErrors() { + for _, missingDep := range err.MissingContracts { + missingDepsMap[missingDep] = struct{}{} + } + } + + missingDependencies := make([]common.AddressLocation, 0) + for missingDep := range missingDepsMap { + missingDependencies = append(missingDependencies, missingDep) + } + + sortAddressLocations(missingDependencies) + return missingDependencies +} + +// ContractsMissingDependencies returns the contracts attempted to be validated that are missing dependencies +func (e *stagingValidatorError) MissingDependencyErrors() map[common.AddressLocation]*missingDependenciesError { + missingDependencyErrors := make(map[common.AddressLocation]*missingDependenciesError) + for location := range e.errors { + var missingDependenciesErr *missingDependenciesError + if errors.As(e.errors[location], &missingDependenciesErr) { + missingDependencyErrors[location] = missingDependenciesErr + } + } + return missingDependencyErrors +} + +var _ error = &stagingValidatorError{} + var chainIdMap = map[string]flow.ChainID{ "mainnet": flow.Mainnet, "testnet": flow.Testnet, } -func newStagingValidator(flow flowkit.Services, state *flowkit.State) *stagingValidator { - return &stagingValidator{ +func newStagingValidator(flow flowkit.Services) *stagingValidatorImpl { + return &stagingValidatorImpl{ flow: flow, - state: state, contracts: make(map[common.Location][]byte), - elaborations: make(map[common.Location]*sema.Elaboration), + checkingCache: make(map[common.Location]*cachedCheckingResult), accountContractNames: make(map[common.Address][]string), + graph: make(map[common.Location]node), } } -func (v *stagingValidator) ValidateContractUpdate( - // Network location of the contract to be updated - location common.AddressLocation, - // Location of the source code, ensures that the error messages reference this instead of a network location - sourceCodeLocation common.Location, - // Code of the updated contract - updatedCode []byte, -) error { - v.sourceCodeLocation = sourceCodeLocation - v.targetLocation = location - - // Resolve all system contract code & add to cache +func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate) error { + v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate) + for _, stagedContract := range stagedContracts { + v.stagedContracts[stagedContract.DeployLocation] = stagedContract + + // Add the contract code to the contracts map for pretty printing + v.contracts[stagedContract.SourceLocation] = stagedContract.Code + } + + // Load system contracts v.loadSystemContracts() + // Parse and check all staged contracts + errs := v.checkAllStaged() + + // Validate all contract updates + for _, contract := range v.stagedContracts { + // Don't validate contracts with existing errors + if errs[contract.SourceLocation] != nil { + continue + } + + // Validate the contract update + checker := v.checkingCache[contract.SourceLocation].checker + err := v.validateContractUpdate(contract, checker) + if err != nil { + errs[contract.SourceLocation] = err + } + } + + // Check for any upstream contract update failures + for _, contract := range v.stagedContracts { + err := errs[contract.SourceLocation] + + // We will override any errors other than those related + // to missing dependencies, since they are more specific + // forms of upstream validation errors + var missingDependenciesErr *missingDependenciesError + if errors.As(err, &missingDependenciesErr) { + continue + } + + // Leave cyclic import errors to the checker + var cyclicImportErr *sema.CyclicImportsError + if errors.As(err, &cyclicImportErr) { + continue + } + + badDeps := make([]common.Location, 0) + v.forEachDependency(contract, func(dependency common.Location) { + strLocation, ok := dependency.(common.StringLocation) + if !ok { + return + } + + if errs[strLocation] != nil { + badDeps = append(badDeps, dependency) + } + }) + + if len(badDeps) > 0 { + errs[contract.SourceLocation] = &upstreamValidationError{ + Location: contract.SourceLocation, + BadDependencies: badDeps, + } + } + } + + // Return a validator error if there are any errors + if len(errs) > 0 { + // Map errors to address locations + errsByAddress := make(map[common.AddressLocation]error) + for _, contract := range v.stagedContracts { + err := errs[contract.SourceLocation] + if err != nil { + errsByAddress[contract.DeployLocation] = err + } + } + return &stagingValidatorError{errors: errsByAddress} + } + return nil +} + +func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error { + errors := make(map[common.StringLocation]error) + for _, contract := range v.stagedContracts { + _, err := v.checkContract(contract.SourceLocation) + if err != nil { + errors[contract.SourceLocation] = err + } + } + + // Report any missing dependency errors separately + // These will override any other errors parsing/checking errors + // Note: nodes are not visited more than once so cyclic imports are not an issue + // They will be reported, however, by the checker, if they do exist + for _, contract := range v.stagedContracts { + // Create a set of all dependencies + missingDependencies := make([]common.AddressLocation, 0) + v.forEachDependency(contract, func(dependency common.Location) { + if code := v.contracts[dependency]; code == nil { + if dependency, ok := dependency.(common.AddressLocation); ok { + missingDependencies = append(missingDependencies, dependency) + } + } + }) + + if len(missingDependencies) > 0 { + errors[contract.SourceLocation] = &missingDependenciesError{ + MissingContracts: missingDependencies, + } + } + } + return errors +} + +func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) error { // Get the account for the contract - address := flowsdk.Address(location.Address) + address := flowsdk.Address(contract.DeployLocation.Address) account, err := v.flow.GetAccount(context.Background(), address) if err != nil { return fmt.Errorf("failed to get account: %w", err) } // Get the target contract old code - contractName := location.Name + contractName := contract.DeployLocation.Name contractCode, ok := account.Contracts[contractName] if !ok { return fmt.Errorf("old contract code not found for contract: %s", contractName) @@ -130,36 +310,19 @@ func (v *stagingValidator) ValidateContractUpdate( return fmt.Errorf("failed to parse old contract code: %w", err) } - // Store contract code for error pretty printing - v.contracts[sourceCodeLocation] = updatedCode - - // Parse and check the contract code - _, newProgramChecker, err := v.parseAndCheckContract(sourceCodeLocation) - - // Errors related to missing dependencies are separate from other errors - // They may be handled differently by the caller, and it's parsing/checking - // errors are not relevant/informative if these are present (they are expected) - if len(v.missingDependencies) > 0 { - return &missingDependenciesError{MissingContracts: v.missingDependencies} - } - - if err != nil { - return err - } - // Convert the new program checker to an interpreter program - interpreterProgram := interpreter.ProgramFromChecker(newProgramChecker) + interpreterProgram := interpreter.ProgramFromChecker(checker) // Check if contract code is valid according to Cadence V1 Update Checker validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( - sourceCodeLocation, + contract.SourceLocation, contractName, &accountContractNamesProviderImpl{ resolverFunc: v.resolveAddressContractNames, }, oldProgram, interpreterProgram, - v.elaborations, + v.elaborations(), ) // Set the user defined type change checker @@ -177,21 +340,56 @@ func (v *stagingValidator) ValidateContractUpdate( return nil } -func (v *stagingValidator) parseAndCheckContract( - location common.Location, -) (*ast.Program, *sema.Checker, error) { - code := v.contracts[location] +// Check a contract by location +func (v *stagingValidatorImpl) checkContract( + importedLocation common.Location, +) (checker *sema.Checker, err error) { + // Try to load cached checker + if cacheItem, ok := v.checkingCache[importedLocation]; ok { + return cacheItem.checker, cacheItem.err + } + + // Cache the checking result + defer func() { + var cacheItem *cachedCheckingResult + if existingCacheItem, ok := v.checkingCache[importedLocation]; ok { + cacheItem = existingCacheItem + } else { + cacheItem = &cachedCheckingResult{} + } + + cacheItem.checker = checker + cacheItem.err = err + }() + + // Resolve the contract code and real location based on whether this is a staged update + var code []byte + + // If it's an address location, get the staged contract code from the network + if addressLocation, ok := importedLocation.(common.AddressLocation); ok { + code, err = v.getStagedContractCode(addressLocation) + if err != nil { + return nil, err + } + } else { + // Otherwise, the code is already known + code = v.contracts[importedLocation] + if code == nil { + return nil, fmt.Errorf("contract code not found for location: %s", importedLocation) + } + } // Parse the contract code - program, err := parser.ParseProgram(nil, code, parser.Config{}) + var program *ast.Program + program, err = parser.ParseProgram(nil, code, parser.Config{}) if err != nil { - return nil, nil, err + return nil, err } // Check the contract code - checker, err := sema.NewChecker( + checker, err = sema.NewChecker( program, - location, + importedLocation, nil, &sema.Config{ AccessCheckMode: sema.AccessCheckModeStrict, @@ -206,18 +404,19 @@ func (v *stagingValidator) parseAndCheckContract( }, ) if err != nil { - return nil, nil, err + return nil, err } - err = checker.Check() - if err != nil { - return nil, nil, err + // We must add this checker to the cache before checking to prevent cyclic imports + v.checkingCache[importedLocation] = &cachedCheckingResult{ + checker: checker, } - return program, checker, nil + err = checker.Check() + return checker, err } -func (v *stagingValidator) getStagedContractCode( +func (v *stagingValidatorImpl) getStagedContractCode( location common.AddressLocation, ) ([]byte, error) { // First check if the code is already known @@ -227,39 +426,21 @@ func (v *stagingValidator) getStagedContractCode( return code, nil } - cAddr := cadence.BytesToAddress(location.Address.Bytes()) - cName, err := cadence.NewString(location.Name) - if err != nil { - return nil, fmt.Errorf("failed to get cadence string from contract name: %w", err) - } - - value, err := v.flow.ExecuteScript( - context.Background(), - flowkit.Script{ - Code: templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress(v.flow.Network().Name)), - Args: []cadence.Value{cAddr, cName}, - }, - flowkit.LatestScriptQuery, - ) + code, err := getStagedContractCode(context.Background(), v.flow, location) if err != nil { return nil, err } - optValue, ok := value.(cadence.Optional) - if !ok { - return nil, fmt.Errorf("invalid script return value type: %T", value) - } - - strValue, ok := optValue.Value.(cadence.String) - if !ok { - return nil, fmt.Errorf("invalid script return value type: %T", value) - } - - v.contracts[location] = []byte(strValue) + v.contracts[location] = code return v.contracts[location], nil } -func (v *stagingValidator) resolveImport(checker *sema.Checker, importedLocation common.Location, _ ast.Range) (sema.Import, error) { +func (v *stagingValidatorImpl) resolveImport(parentChecker *sema.Checker, importedLocation common.Location, _ ast.Range) (sema.Import, error) { + // Add this to the dependency graph + if parentChecker != nil { + v.addDependency(parentChecker.Location, importedLocation) + } + // Check if the imported location is the crypto checker if importedLocation == stdlib.CryptoCheckerLocation { cryptoChecker := stdlib.CryptoChecker() @@ -268,62 +449,19 @@ func (v *stagingValidator) resolveImport(checker *sema.Checker, importedLocation }, nil } - // Check if the imported location is an address location - // No other location types are supported (as is the case with code on-chain) - addrLocation, ok := importedLocation.(common.AddressLocation) - if !ok { - return nil, fmt.Errorf("expected address location") - } - // Check if this contract has already been resolved - elaboration, ok := v.elaborations[importedLocation] - - // If not resolved, parse and check the contract code - if !ok { - importedCode, err := v.getStagedContractCode(addrLocation) - if err != nil { - v.missingDependencies = append(v.missingDependencies, addrLocation) - return nil, fmt.Errorf("failed to get staged contract code: %w", err) - } - v.contracts[addrLocation] = importedCode - - _, checker, err = v.parseAndCheckContract(addrLocation) - if err != nil { - return nil, fmt.Errorf("failed to parse and check contract code: %w", err) - } - - v.elaborations[importedLocation] = checker.Elaboration - elaboration = checker.Elaboration + subChecker, err := v.checkContract(importedLocation) + if err != nil { + return nil, err } return sema.ElaborationImport{ - Elaboration: elaboration, + Elaboration: subChecker.Elaboration, }, nil } -func (v *stagingValidator) loadSystemContracts() { - chainId, ok := chainIdMap[v.flow.Network().Name] - if !ok { - return - } - - stagedSystemContracts := migrations.SystemContractChanges(chainId, migrations.SystemContractsMigrationOptions{ - Burner: migrations.BurnerContractChangeUpdate, // needs to be update for now since BurnerChangeDeploy is a no-op in flow-go - EVM: migrations.EVMContractChangeFull, - }) - for _, stagedSystemContract := range stagedSystemContracts { - location := common.AddressLocation{ - Address: stagedSystemContract.Address, - Name: stagedSystemContract.Name, - } - - v.contracts[location] = stagedSystemContract.Code - v.accountContractNames[stagedSystemContract.Address] = append(v.accountContractNames[stagedSystemContract.Address], stagedSystemContract.Name) - } -} - -// This is a copy of the resolveLocation function from the linter/language server -func (v *stagingValidator) resolveLocation( +// This mostly is a copy of the resolveLocation function from the linter/language server +func (v *stagingValidatorImpl) resolveLocation( identifiers []ast.Identifier, location common.Location, ) ( @@ -378,11 +516,22 @@ func (v *stagingValidator) resolveLocation( resolvedLocations := make([]runtime.ResolvedLocation, len(identifiers)) for i := range resolvedLocations { identifier := identifiers[i] + + var resolvedLocation common.Location + resovledAddrLocation := common.AddressLocation{ + Address: addressLocation.Address, + Name: identifier.Identifier, + } + + // If the contract one of our staged contract updates, use the source location + if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok { + resolvedLocation = stagedUpdate.SourceLocation + } else { + resolvedLocation = resovledAddrLocation + } + resolvedLocations[i] = runtime.ResolvedLocation{ - Location: common.AddressLocation{ - Address: addressLocation.Address, - Name: identifier.Identifier, - }, + Location: resolvedLocation, Identifiers: []runtime.Identifier{identifier}, } } @@ -390,27 +539,49 @@ func (v *stagingValidator) resolveLocation( return resolvedLocations, nil } -func (v *stagingValidator) resolveAccountAccess(checker *sema.Checker, memberLocation common.Location) bool { +func (v *stagingValidatorImpl) resolveAccountAccess(checker *sema.Checker, memberLocation common.Location) bool { if checker == nil { return false } - checkerLocation, ok := checker.Location.(common.StringLocation) - if !ok { - return false + var memberAddress common.Address + if memberAddressLocation, ok := memberLocation.(common.AddressLocation); ok { + memberAddress = memberAddressLocation.Address + } else if memberStringLocation, ok := memberLocation.(common.StringLocation); ok { + found := false + for _, stagedContract := range v.stagedContracts { + if stagedContract.SourceLocation == memberStringLocation { + memberAddress = stagedContract.DeployLocation.Address + found = true + break + } + } + if !found { + return false + } } - memberAddressLocation, ok := memberLocation.(common.AddressLocation) - if !ok { - return false + var checkerAddress common.Address + if checkerAddressLocation, ok := checker.Location.(common.AddressLocation); ok { + checkerAddress = checkerAddressLocation.Address + } else if checkerStringLocation, ok := checker.Location.(common.StringLocation); ok { + found := false + for _, stagedContract := range v.stagedContracts { + if stagedContract.SourceLocation == checkerStringLocation { + checkerAddress = stagedContract.DeployLocation.Address + found = true + break + } + } + if !found { + return false + } } - // If the source code of the update is being checked, we should check account access based on the - // targeted network location of the contract & not the source code location - return checkerLocation == v.sourceCodeLocation && memberAddressLocation.Address == v.targetLocation.Address + return memberAddress == checkerAddress } -func (v *stagingValidator) resolveAddressContractNames(address common.Address) ([]string, error) { +func (v *stagingValidatorImpl) resolveAddressContractNames(address common.Address) ([]string, error) { // Check if the contract names are already cached if names, ok := v.accountContractNames[address]; ok { return names, nil @@ -452,11 +623,78 @@ func (v *stagingValidator) resolveAddressContractNames(address common.Address) ( return v.accountContractNames[address], nil } +func (v *stagingValidatorImpl) loadSystemContracts() { + chainId, ok := chainIdMap[v.flow.Network().Name] + if !ok { + return + } + + stagedSystemContracts := migrations.SystemContractChanges(chainId, migrations.SystemContractsMigrationOptions{ + Burner: migrations.BurnerContractChangeUpdate, // needs to be update for now since BurnerChangeDeploy is a no-op in flow-go + EVM: migrations.EVMContractChangeFull, + }) + for _, stagedSystemContract := range stagedSystemContracts { + location := common.AddressLocation{ + Address: stagedSystemContract.Address, + Name: stagedSystemContract.Name, + } + + v.contracts[location] = stagedSystemContract.Code + v.accountContractNames[stagedSystemContract.Address] = append(v.accountContractNames[stagedSystemContract.Address], stagedSystemContract.Name) + } +} + +func (v *stagingValidatorImpl) elaborations() map[common.Location]*sema.Elaboration { + elaborations := make(map[common.Location]*sema.Elaboration) + for location, cacheItem := range v.checkingCache { + checker := cacheItem.checker + if checker == nil { + continue + } + elaborations[location] = checker.Elaboration + } + return elaborations +} + +func (v *stagingValidatorImpl) addDependency(dependent common.Location, dependency common.Location) { + // Create the dependent node if it does not exist + if _, ok := v.graph[dependent]; !ok { + v.graph[dependent] = make(node) + } + + // Create the dependency node if it does not exist + if _, ok := v.graph[dependency]; !ok { + v.graph[dependency] = make(node) + } + + // Add the dependency + v.graph[dependent][dependency] = v.graph[dependency] +} + +func (v *stagingValidatorImpl) forEachDependency( + contract stagedContractUpdate, + visitor func(dependency common.Location), +) { + seen := make(map[common.Location]bool) + var traverse func(location common.Location) + traverse = func(location common.Location) { + seen[location] = true + + for dep := range v.graph[location] { + if !seen[dep] { + visitor(dep) + traverse(dep) + } + } + } + traverse(contract.SourceLocation) +} + // Helper for pretty printing errors // While it is done by default in checker/parser errors, this has two purposes: // 1. Add color to the error message // 2. Use pretty printing on contract update errors which do not do this by default -func (v *stagingValidator) prettyPrintError(err error, location common.Location) string { +func (v *stagingValidatorImpl) PrettyPrintError(err error, location common.Location) string { var sb strings.Builder printErr := pretty.NewErrorPrettyPrinter(&sb, true). PrettyPrintError(err, location, v.contracts) @@ -474,3 +712,14 @@ func (a *accountContractNamesProviderImpl) GetAccountContractNames( ) ([]string, error) { return a.resolverFunc(address) } + +// util to sort address locations +func sortAddressLocations(locations []common.AddressLocation) { + slices.SortFunc(locations, func(a common.AddressLocation, b common.AddressLocation) int { + addrCmp := bytes.Compare(a.Address.Bytes(), b.Address.Bytes()) + if addrCmp != 0 { + return addrCmp + } + return strings.Compare(a.Name, b.Name) + }) +} diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go index fd2e5070b..afa612adf 100644 --- a/internal/migrate/staging_validator_test.go +++ b/internal/migrate/staging_validator_test.go @@ -19,10 +19,9 @@ package migrate import ( + "strings" "testing" - "github.com/onflow/flow-cli/internal/util" - "github.com/onflow/cadence" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/sema" @@ -31,15 +30,109 @@ import ( "github.com/onflow/flow-go-sdk" "github.com/onflow/flowkit/v2" "github.com/onflow/flowkit/v2/config" - "github.com/stretchr/testify/assert" + flowkitMocks "github.com/onflow/flowkit/v2/mocks" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +type mockNetworkAccount struct { + address flow.Address + contracts map[string][]byte + stagedContracts map[string][]byte +} + +func setupValidatorMocks( + t *testing.T, + accounts []mockNetworkAccount, +) *flowkitMocks.Services { + t.Helper() + srv := flowkitMocks.NewServices(t) + + // Mock all accounts & staged contracts + for _, acc := range accounts { + mockAcct := &flow.Account{ + Address: acc.address, + Balance: 1000, + Keys: nil, + Contracts: acc.contracts, + } + + srv.On("GetAccount", mock.Anything, acc.address).Return(mockAcct, nil).Maybe() + + for contractName, code := range acc.stagedContracts { + srv.On( + "ExecuteScript", + mock.Anything, + mock.MatchedBy(func(script flowkit.Script) bool { + if string(script.Code) != string(templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet"))) { + return false + } + + if len(script.Args) != 2 { + return false + } + + callContractAddress, callContractName := script.Args[0], script.Args[1] + if callContractName != cadence.String(contractName) { + return false + } + if callContractAddress != cadence.Address(acc.address) { + return false + } + + return true + }), + mock.Anything, + ).Return(cadence.NewOptional(cadence.String(code)), nil).Maybe() + } + } + + // Mock trying to get staged contract code for a contract that doesn't exist + // This is the fallback mock for all other staged contract code requests + srv.On( + "ExecuteScript", + mock.Anything, + mock.MatchedBy(func(script flowkit.Script) bool { + if string(script.Code) != string(templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet"))) { + return false + } + + if len(script.Args) != 2 { + return false + } + + callContractAddress, callContractName := script.Args[0], script.Args[1] + + if callContractAddress.Type() != cadence.AddressType { + return false + } + + if callContractName.Type() != cadence.StringType { + return false + } + + return true + }), + mock.Anything, + ).Return(cadence.NewOptional(nil), nil).Maybe() + + srv.On("Network", mock.Anything).Return(config.Network{ + Name: "testnet", + }, nil).Maybe() + + return srv +} + +// Helper for creating address locations from strings in tests +func simpleAddressLocation(location string) common.AddressLocation { + split := strings.Split(location, ".") + addr, _ := common.HexToAddress(split[0]) + return common.NewAddressLocation(nil, addr, split[1]) +} + func Test_StagingValidator(t *testing.T) { - srv, state, rw := util.TestMocks(t) t.Run("valid contract update with no dependencies", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + location := simpleAddressLocation("0x01.Test") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` pub contract Test { @@ -49,31 +142,23 @@ func Test_StagingValidator(t *testing.T) { access(all) contract Test { access(all) fun test() {} }` - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), - }, - } // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - - validator := newStagingValidator(srv.Mock, state) - err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, + stagedContracts: nil, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) require.NoError(t, err) }) t.Run("contract update with update error", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + location := simpleAddressLocation("0x01.Test") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` pub contract Test { @@ -88,32 +173,28 @@ func Test_StagingValidator(t *testing.T) { self.x = 1 } }` - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), - }, - } // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - - validator := newStagingValidator(srv.Mock, state) - err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, + stagedContracts: nil, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) + + var validatorErr *stagingValidatorError + require.ErrorAs(t, err, &validatorErr) + var updateErr *stdlib.ContractUpdateError - require.ErrorAs(t, err, &updateErr) + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x01.Test")], &updateErr) }) t.Run("contract update with checker error", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + location := simpleAddressLocation("0x01.Test") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` pub contract Test { @@ -129,32 +210,28 @@ func Test_StagingValidator(t *testing.T) { self.x = "bad type :(" } }` - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), - }, - } // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - - validator := newStagingValidator(srv.Mock, state) - err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, + stagedContracts: nil, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) + + var validatorErr *stagingValidatorError + require.ErrorAs(t, err, &validatorErr) + var checkerErr *sema.CheckerError - require.ErrorAs(t, err, &checkerErr) + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x01.Test")], &checkerErr) }) t.Run("valid contract update with dependencies", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + location := simpleAddressLocation("0x01.Test") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` pub contract Test { @@ -172,88 +249,66 @@ func Test_StagingValidator(t *testing.T) { self.x = 1 } }` - mockScriptResultString, err := cadence.NewString(impContract) - require.NoError(t, err) - - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), - }, - } // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - srv.ExecuteScript.Run(func(args mock.Arguments) { - script := args.Get(1).(flowkit.Script) - - assert.Equal(t, templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet")), script.Code) - - assert.Equal(t, 2, len(script.Args)) - actualContractAddressArg, actualContractNameArg := script.Args[0], script.Args[1] - - contractName, _ := cadence.NewString("ImpContract") - contractAddr := cadence.NewAddress(flow.HexToAddress("02")) - assert.Equal(t, contractName, actualContractNameArg) - assert.Equal(t, contractAddr, actualContractAddressArg) - }).Return(cadence.NewOptional(mockScriptResultString), nil) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, + }, + { + address: flow.HexToAddress("02"), + stagedContracts: map[string][]byte{"ImpContract": []byte(impContract)}, + }, + }) // validate - validator := newStagingValidator(srv.Mock, state) - err = validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) require.NoError(t, err) }) t.Run("contract update missing dependency", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") - impLocation := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}, "ImpContract") + location := simpleAddressLocation("0x01.Test") + impLocation := simpleAddressLocation("0x02.ImpContract") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` pub contract Test { pub fun test() {} }` newContract := ` + // staged contract does not exist import ImpContract from 0x02 access(all) contract Test { access(all) fun test() {} }` - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), - }, - } // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - srv.ExecuteScript.Return(cadence.NewOptional(nil), nil) - - validator := newStagingValidator(srv.Mock, state) - err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) - var missingDepsErr *missingDependenciesError - require.ErrorAs(t, err, &missingDepsErr) - require.Equal(t, 1, len(missingDepsErr.MissingContracts)) - require.Equal(t, impLocation, missingDepsErr.MissingContracts[0]) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, + }, + { + address: flow.HexToAddress("02"), + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) + + var validatorErr *stagingValidatorError + require.ErrorAs(t, err, &validatorErr) + require.Equal(t, 1, len(validatorErr.errors)) + + var missingDependenciesErr *missingDependenciesError + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x01.Test")], &missingDependenciesErr) + require.Equal(t, 1, len(missingDependenciesErr.MissingContracts)) + require.Equal(t, impLocation, missingDependenciesErr.MissingContracts[0]) }) t.Run("valid contract update with system contract imports", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + location := simpleAddressLocation("0x01.Test") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` import FlowToken from 0x7e60df042a9c0868 @@ -266,31 +321,190 @@ func Test_StagingValidator(t *testing.T) { access(all) contract Test { access(all) fun test() {} }` - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), + + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, }, - } + }) + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) + require.NoError(t, err) + }) + + t.Run("resolves account access correctly", func(t *testing.T) { // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - - validator := newStagingValidator(srv.Mock, state) - err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{ + "Test": []byte(` + import ImpContract from 0x01 + pub contract Test { + pub fun test() {} + }`), + "Imp2": []byte(` + pub contract Imp2 { + access(account) fun test() {} + }`), + }, + stagedContracts: map[string][]byte{"ImpContract": []byte(` + access(all) contract ImpContract { + access(account) fun test() {} + init() {} + }`)}, + }, + }) + + // validate + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + simpleAddressLocation("0x01.Test"), + common.StringLocation("./Test.cdc"), + []byte(` + import ImpContract from 0x01 + import Imp2 from 0x01 + access(all) contract Test { + access(all) fun test() {} + init() { + ImpContract.test() + Imp2.test() + } + }`), + }, + { + simpleAddressLocation("0x01.Imp2"), + common.StringLocation("./Imp2.cdc"), + []byte(` + access(all) contract Imp2 { + access(account) fun test() {} + }`), + }, + }) require.NoError(t, err) }) + t.Run("validates multiple contracts, no error", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract Foo { + pub fun test() {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + import Foo from 0x01 + pub contract Bar { + pub fun test() {} + init() { + Foo.test() + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract Foo { + access(all) fun test() {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract Bar { + access(all) fun test() {} + init() { + Foo.test() + } + }`), + }, + }) + + require.NoError(t, err) + }) + + t.Run("validates multiple contracts with errors", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract Foo { + pub fun test() {} + init() {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + pub contract Bar { + pub fun test() {} + init() { + Foo.test() + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract Foo { + access(all) fun test() {} + init() { + let x: Int = "bad type" + } + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract Bar { + access(all) fun test() {} + init() { + Foo.test() + } + }`), + }, + }) + + var validatorErr *stagingValidatorError + require.ErrorAs(t, err, &validatorErr) + + require.Equal(t, 2, len(validatorErr.errors)) + + // check that error exists & ensure that the local contract names are used (not the deploy locations) + fooErr := validatorErr.errors[simpleAddressLocation("0x01.Foo")] + require.ErrorContains(t, fooErr, "mismatched types") + require.ErrorContains(t, fooErr, "Foo.cdc") + + // Bar should have an error related to + var upstreamErr *upstreamValidationError + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x02.Bar")], &upstreamErr) + }) + t.Run("resolves account access correctly", func(t *testing.T) { - location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + location := simpleAddressLocation("0x01.Test") sourceCodeLocation := common.StringLocation("./Test.cdc") oldContract := ` import ImpContract from 0x01 @@ -310,43 +524,242 @@ func Test_StagingValidator(t *testing.T) { access(account) fun test() {} init() {} }` - mockScriptResultString, err := cadence.NewString(impContract) - require.NoError(t, err) - mockAccount := &flow.Account{ - Address: flow.HexToAddress("01"), - Balance: 1000, - Keys: nil, - Contracts: map[string][]byte{ - "Test": []byte(oldContract), + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Test": []byte(oldContract)}, + stagedContracts: map[string][]byte{"ImpContract": []byte(impContract)}, }, - } + }) + // validate + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{{location, sourceCodeLocation, []byte(newContract)}}) + require.NoError(t, err) + }) + + t.Run("validates multiple contracts, no error", func(t *testing.T) { // setup mocks - require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) - srv.GetAccount.Run(func(args mock.Arguments) { - require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) - }).Return(mockAccount, nil) - srv.Network.Return(config.Network{ - Name: "testnet", - }, nil) - srv.ExecuteScript.Run(func(args mock.Arguments) { - script := args.Get(1).(flowkit.Script) - - assert.Equal(t, templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet")), script.Code) - - assert.Equal(t, 2, len(script.Args)) - actualContractAddressArg, actualContractNameArg := script.Args[0], script.Args[1] - - contractName, _ := cadence.NewString("ImpContract") - contractAddr := cadence.NewAddress(flow.HexToAddress("01")) - assert.Equal(t, contractName, actualContractNameArg) - assert.Equal(t, contractAddr, actualContractAddressArg) - }).Return(cadence.NewOptional(mockScriptResultString), nil) + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract Foo { + pub fun test() {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + import Foo from 0x01 + pub contract Bar { + pub fun test() {} + init() { + Foo.test() + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + access(all) contract Foo { + access(all) fun test() {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract Bar { + access(all) fun test() {} + init() { + Foo.test() + } + }`), + }, + }) - // validate - validator := newStagingValidator(srv.Mock, state) - err = validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) require.NoError(t, err) }) + + t.Run("validates cyclic imports", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + pub contract Foo { + pub fun test() {} + init() {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + pub contract Bar { + pub fun test() {} + init() { + Foo.test() + } + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + import Bar from 0x02 + access(all) contract Foo { + access(all) fun test() {} + init() {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + access(all) contract Bar { + access(all) fun test() {} + init() { + Foo.test() + } + }`), + }, + }) + + var validatorErr *stagingValidatorError + require.ErrorAs(t, err, &validatorErr) + + require.Equal(t, 2, len(validatorErr.errors)) + + // check that error exists & ensure that the local contract names are used (not the deploy locations) + var cyclicImportError *sema.CyclicImportsError + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x01.Foo")], &cyclicImportError) + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x02.Bar")], &cyclicImportError) + }) + + t.Run("upstream missing dependency errors", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + import ImpContract from 0x03 + pub contract Foo { + pub fun test() {} + init() {} + }`)}, + }, + { + address: flow.HexToAddress("02"), + contracts: map[string][]byte{"Bar": []byte(` + pub contract Bar { + pub fun test() {} + init() { + Foo.test() + } + }`)}, + }, + { + address: flow.HexToAddress("03"), + contracts: map[string][]byte{"ImpContract": []byte(` + pub contract ImpContract {} + `)}, + }, + { + address: flow.HexToAddress("04"), + contracts: map[string][]byte{"AnotherImp": []byte(` + pub contract AnotherImp {} + `)}, + }, + }) + + validator := newStagingValidator(srv) + + // ordering is important here, e.g. even though Foo is checked + // first, Bar will still recognize the missing dependency + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + // staged contract does not exist + import ImpContract from 0x03 + access(all) contract Foo { + access(all) fun test() {} + init() {} + }`), + }, + { + DeployLocation: simpleAddressLocation("0x02.Bar"), + SourceLocation: common.StringLocation("./Bar.cdc"), + Code: []byte(` + import Foo from 0x01 + import AnotherImp from 0x04 + access(all) contract Bar { + access(all) fun test() {} + init() { + Foo.test() + } + }`), + }, + }) + + var validatorErr *stagingValidatorError + require.ErrorAs(t, err, &validatorErr) + require.Equal(t, 2, len(validatorErr.errors)) + + var missingDependenciesErr *missingDependenciesError + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x01.Foo")], &missingDependenciesErr) + require.Equal(t, 1, len(missingDependenciesErr.MissingContracts)) + require.Equal(t, simpleAddressLocation("0x03.ImpContract"), missingDependenciesErr.MissingContracts[0]) + + require.ErrorAs(t, validatorErr.errors[simpleAddressLocation("0x02.Bar")], &missingDependenciesErr) + require.Equal(t, 2, len(missingDependenciesErr.MissingContracts)) + require.ElementsMatch(t, []common.AddressLocation{ + simpleAddressLocation("0x03.ImpContract"), + simpleAddressLocation("0x04.AnotherImp"), + }, missingDependenciesErr.MissingContracts) + }) + + t.Run("import Crypto checker", func(t *testing.T) { + // setup mocks + srv := setupValidatorMocks(t, []mockNetworkAccount{ + { + address: flow.HexToAddress("01"), + contracts: map[string][]byte{"Foo": []byte(` + import Crypto + pub contract Foo { + init() {} + }`)}, + }, + }) + + validator := newStagingValidator(srv) + err := validator.Validate([]stagedContractUpdate{ + { + DeployLocation: simpleAddressLocation("0x01.Foo"), + SourceLocation: common.StringLocation("./Foo.cdc"), + Code: []byte(` + import Crypto + access(all) contract Foo { + init() {} + }`), + }, + }) + + require.Nil(t, err) + }) } diff --git a/internal/migrate/unstage_contract_test.go b/internal/migrate/unstage_contract_test.go index 846947afe..c5866ea3d 100644 --- a/internal/migrate/unstage_contract_test.go +++ b/internal/migrate/unstage_contract_test.go @@ -101,7 +101,7 @@ func Test_UnstageContract(t *testing.T) { t.Run("missing contract file", func(t *testing.T) { srv, state, _ := util.TestMocks(t) - result, err := stageContract( + result, err := unstageContract( []string{testContract.Name}, command.GlobalFlags{ Network: "testnet", diff --git a/internal/util/util.go b/internal/util/util.go index 59b481fbf..2ace0b12e 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -116,3 +116,10 @@ func removeFromStringArray(s []string, el string) []string { func NormalizeLineEndings(s string) string { return strings.ReplaceAll(s, "\r\n", "\n") } + +func Pluralize(word string, count int) string { + if count == 1 { + return word + } + return word + "s" +}