From 8b06220ed10719687735b94e8987886712c0ad7e Mon Sep 17 00:00:00 2001 From: Jordan Ribbink Date: Wed, 17 Apr 2024 16:13:02 -0700 Subject: [PATCH] Fix `access(account)` bug in contract update validator (#1523) --- internal/migrate/staging_validator.go | 33 +++++++++++- internal/migrate/staging_validator_test.go | 61 ++++++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/internal/migrate/staging_validator.go b/internal/migrate/staging_validator.go index 11decd165..ba4d1c9e6 100644 --- a/internal/migrate/staging_validator.go +++ b/internal/migrate/staging_validator.go @@ -46,6 +46,11 @@ type stagingValidator struct { flow flowkit.Services state *flowkit.State + // Location of the source code that is used for the update + sourceCodeLocation common.Location + // Location of the target contract that is being updated + targetLocation common.AddressLocation + // Cache for account contract names so we don't have to fetch them multiple times accountContractNames map[common.Address][]string // All resolved contract code @@ -99,6 +104,9 @@ func (v *stagingValidator) ValidateContractUpdate( // Code of the updated contract updatedCode []byte, ) error { + v.sourceCodeLocation = sourceCodeLocation + v.targetLocation = location + // Resolve all system contract code & add to cache v.loadSystemContracts() @@ -192,8 +200,9 @@ func (v *stagingValidator) parseAndCheckContract( // Only checking contracts, so no need to consider script standard library return util.NewStandardLibrary().BaseValueActivation }, - LocationHandler: v.resolveLocation, - ImportHandler: v.resolveImport, + LocationHandler: v.resolveLocation, + ImportHandler: v.resolveImport, + MemberAccountAccessHandler: v.resolveAccountAccess, }, ) if err != nil { @@ -381,6 +390,26 @@ func (v *stagingValidator) resolveLocation( return resolvedLocations, nil } +func (v *stagingValidator) resolveAccountAccess(checker *sema.Checker, memberLocation common.Location) bool { + if checker == nil { + return false + } + + checkerLocation, ok := checker.Location.(common.StringLocation) + if !ok { + return false + } + + memberAddressLocation, ok := memberLocation.(common.AddressLocation) + if !ok { + return false + } + + // If the source code of the update is being checked, we should check account access based on the + // targeted network location of the contract & not the source code location + return checkerLocation == v.sourceCodeLocation && memberAddressLocation.Address == v.targetLocation.Address +} + func (v *stagingValidator) resolveAddressContractNames(address common.Address) ([]string, error) { // Check if the contract names are already cached if names, ok := v.accountContractNames[address]; ok { diff --git a/internal/migrate/staging_validator_test.go b/internal/migrate/staging_validator_test.go index 48db6e70f..fd2e5070b 100644 --- a/internal/migrate/staging_validator_test.go +++ b/internal/migrate/staging_validator_test.go @@ -288,4 +288,65 @@ func Test_StagingValidator(t *testing.T) { err := validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) require.NoError(t, err) }) + + t.Run("resolves account access correctly", func(t *testing.T) { + location := common.NewAddressLocation(nil, common.Address{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, "Test") + sourceCodeLocation := common.StringLocation("./Test.cdc") + oldContract := ` + import ImpContract from 0x01 + pub contract Test { + pub fun test() {} + }` + newContract := ` + import ImpContract from 0x01 + access(all) contract Test { + access(all) fun test() {} + init() { + ImpContract.test() + } + }` + impContract := ` + access(all) contract ImpContract { + access(account) fun test() {} + init() {} + }` + mockScriptResultString, err := cadence.NewString(impContract) + require.NoError(t, err) + + mockAccount := &flow.Account{ + Address: flow.HexToAddress("01"), + Balance: 1000, + Keys: nil, + Contracts: map[string][]byte{ + "Test": []byte(oldContract), + }, + } + + // setup mocks + require.NoError(t, rw.WriteFile(sourceCodeLocation.String(), []byte(newContract), 0o644)) + srv.GetAccount.Run(func(args mock.Arguments) { + require.Equal(t, flow.HexToAddress("01"), args.Get(1).(flow.Address)) + }).Return(mockAccount, nil) + srv.Network.Return(config.Network{ + Name: "testnet", + }, nil) + srv.ExecuteScript.Run(func(args mock.Arguments) { + script := args.Get(1).(flowkit.Script) + + assert.Equal(t, templates.GenerateGetStagedContractCodeScript(MigrationContractStagingAddress("testnet")), script.Code) + + assert.Equal(t, 2, len(script.Args)) + actualContractAddressArg, actualContractNameArg := script.Args[0], script.Args[1] + + contractName, _ := cadence.NewString("ImpContract") + contractAddr := cadence.NewAddress(flow.HexToAddress("01")) + assert.Equal(t, contractName, actualContractNameArg) + assert.Equal(t, contractAddr, actualContractAddressArg) + }).Return(cadence.NewOptional(mockScriptResultString), nil) + + // validate + validator := newStagingValidator(srv.Mock, state) + err = validator.ValidateContractUpdate(location, sourceCodeLocation, []byte(newContract)) + require.NoError(t, err) + }) }