diff --git a/cmd/util/cmd/execution-state-extract/cmd.go b/cmd/util/cmd/execution-state-extract/cmd.go index 9bfba99bad4..b7edecd8936 100644 --- a/cmd/util/cmd/execution-state-extract/cmd.go +++ b/cmd/util/cmd/execution-state-extract/cmd.go @@ -1,8 +1,10 @@ package extract import ( + "compress/gzip" "encoding/hex" "fmt" + "io" "os" "path" "runtime/pprof" @@ -32,6 +34,8 @@ var ( flagChain string flagNWorker int flagNoMigration bool + flagMigration string + flagAuthorizationFixes string flagNoReport bool flagValidateMigration bool flagAllowPartialStateFromPayloads bool @@ -87,6 +91,12 @@ func init() { Cmd.Flags().BoolVar(&flagNoMigration, "no-migration", false, "don't migrate the state") + Cmd.Flags().StringVar(&flagMigration, "migration", "cadence-1.0", + "migration name. 'cadence-1.0' (default) or 'fix-authorizations'") + + Cmd.Flags().StringVar(&flagAuthorizationFixes, "authorization-fixes", "", + "authorization fixes to apply. requires '--migration=fix-authorizations'") + Cmd.Flags().BoolVar(&flagNoReport, "no-report", false, "don't report the state") @@ -195,7 +205,10 @@ func run(*cobra.Command, []string) { defer pprof.StopCPUProfile() } - var stateCommitment flow.StateCommitment + err := os.MkdirAll(flagOutputDir, 0755) + if err != nil { + log.Fatal().Err(err).Msgf("cannot create output directory %s", flagOutputDir) + } if len(flagBlockHash) > 0 && len(flagStateCommitment) > 0 { log.Fatal().Msg("cannot run the command with both block hash and state commitment as inputs, only one of them should be provided") @@ -219,6 +232,21 @@ func run(*cobra.Command, []string) { log.Fatal().Msg("Both --validate and --diff are enabled, please specify only one (or none) of these") } + switch flagMigration { + case "cadence-1.0": + // valid, no-op + + case "fix-authorizations": + if flagAuthorizationFixes == "" { + log.Fatal().Msg("--migration=fix-authorizations requires --authorization-fixes") + } + + default: + log.Fatal().Msg("Invalid --migration: got %s, expected 'cadence-1.0' or 'fix-authorizations'") + } + + var stateCommitment flow.StateCommitment + if len(flagBlockHash) > 0 { blockID, err := flow.HexStringToIdentifier(flagBlockHash) if err != nil { @@ -429,9 +457,29 @@ func run(*cobra.Command, []string) { // Migrate payloads. if !flagNoMigration { - migrations := newMigrations(log.Logger, flagOutputDir, opts) + var migs []migrations.NamedMigration + + switch flagMigration { + case "cadence-1.0": + migs = newCadence1Migrations( + log.Logger, + flagOutputDir, + opts, + ) + + case "fix-authorizations": + migs = newFixAuthorizationsMigrations( + log.Logger, + flagAuthorizationFixes, + flagOutputDir, + opts, + ) + + default: + log.Fatal().Msgf("unknown migration: %s", flagMigration) + } - migration := newMigration(log.Logger, migrations, flagNWorker) + migration := newMigration(log.Logger, migs, flagNWorker) payloads, err = migration(payloads) if err != nil { @@ -509,3 +557,35 @@ func ensureCheckpointFileExist(dir string) error { return fmt.Errorf("no checkpoint file was found, no root checkpoint file was found in %v, check the --execution-state-dir flag", dir) } + +func readAuthorizationFixes(path string) migrations.AuthorizationFixes { + + file, err := os.Open(path) + if err != nil { + log.Fatal().Err(err).Msgf("can't open authorization fixes: %s", path) + } + defer file.Close() + + var reader io.Reader = file + if isGzip(file) { + reader, err = gzip.NewReader(file) + if err != nil { + log.Fatal().Err(err).Msgf("failed to create gzip reader for %s", path) + } + } + + log.Info().Msgf("Reading authorization fixes from %s ...", path) + + fixes, err := migrations.ReadAuthorizationFixes(reader, nil) + if err != nil { + log.Fatal().Err(err).Msgf("failed to read authorization fixes %s", path) + } + + log.Info().Msgf("Read %d authorization fixes", len(fixes)) + + return fixes +} + +func isGzip(file *os.File) bool { + return strings.HasSuffix(file.Name(), ".gz") +} diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract.go b/cmd/util/cmd/execution-state-extract/execution_state_extract.go index f9055f2f2d0..49b1728bb69 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract.go @@ -358,13 +358,13 @@ func createTrieFromPayloads(logger zerolog.Logger, payloads []*ledger.Payload) ( return newTrie, nil } -func newMigrations( +func newCadence1Migrations( log zerolog.Logger, outputDir string, opts migrators.Options, ) []migrators.NamedMigration { - log.Info().Msg("initializing migrations") + log.Info().Msg("initializing Cadence 1.0 migrations ...") rwf := reporters.NewReportFileWriterFactory(outputDir, log) @@ -394,3 +394,43 @@ func newMigrations( return namedMigrations } + +func newFixAuthorizationsMigrations( + log zerolog.Logger, + authorizationFixesPath string, + outputDir string, + opts migrators.Options, +) []migrators.NamedMigration { + + log.Info().Msg("initializing authorization fix migrations ...") + + rwf := reporters.NewReportFileWriterFactory(outputDir, log) + + authorizationFixes := readAuthorizationFixes(authorizationFixesPath) + + namedMigrations := migrators.NewFixAuthorizationsMigrations( + log, + rwf, + authorizationFixes, + opts, + ) + + // At the end, fix up storage-used discrepancies + namedMigrations = append( + namedMigrations, + migrators.NamedMigration{ + Name: "account-usage-migration", + Migrate: migrators.NewAccountBasedMigration( + log, + opts.NWorker, + []migrators.AccountBasedMigration{ + migrators.NewAccountUsageMigration(rwf), + }, + ), + }, + ) + + log.Info().Msg("initialized migrations") + + return namedMigrations +} diff --git a/cmd/util/cmd/generate-authorization-fixes/cmd.go b/cmd/util/cmd/generate-authorization-fixes/cmd.go new file mode 100644 index 00000000000..ce034ef3d79 --- /dev/null +++ b/cmd/util/cmd/generate-authorization-fixes/cmd.go @@ -0,0 +1,443 @@ +package generate_authorization_fixes + +import ( + "compress/gzip" + "encoding/json" + "io" + "os" + "strings" + "sync" + + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + "github.com/onflow/cadence/runtime/stdlib" + "github.com/rs/zerolog/log" + "github.com/schollz/progressbar/v3" + "github.com/spf13/cobra" + + common2 "github.com/onflow/flow-go/cmd/util/common" + "github.com/onflow/flow-go/cmd/util/ledger/migrations" + "github.com/onflow/flow-go/cmd/util/ledger/reporters" + "github.com/onflow/flow-go/cmd/util/ledger/util" + "github.com/onflow/flow-go/cmd/util/ledger/util/registers" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/model/flow" +) + +var ( + flagPayloads string + flagState string + flagStateCommitment string + flagOutputDirectory string + flagChain string + flagLinkMigrationReport string + flagAddresses string +) + +var Cmd = &cobra.Command{ + Use: "generate-authorization-fixes", + Short: "generate authorization fixes for capability controllers", + Run: run, +} + +func init() { + + Cmd.Flags().StringVar( + &flagPayloads, + "payloads", + "", + "Input payload file name", + ) + + Cmd.Flags().StringVar( + &flagState, + "state", + "", + "Input state file name", + ) + + Cmd.Flags().StringVar( + &flagStateCommitment, + "state-commitment", + "", + "Input state commitment", + ) + + Cmd.Flags().StringVar( + &flagOutputDirectory, + "output-directory", + "", + "Output directory", + ) + + Cmd.Flags().StringVar( + &flagChain, + "chain", + "", + "Chain name", + ) + _ = Cmd.MarkFlagRequired("chain") + + Cmd.Flags().StringVar( + &flagLinkMigrationReport, + "link-migration-report", + "", + "Input link migration report file name", + ) + _ = Cmd.MarkFlagRequired("link-migration-report") + + Cmd.Flags().StringVar( + &flagAddresses, + "addresses", + "", + "only generate fixes for given accounts (comma-separated hex-encoded addresses)", + ) +} + +func run(*cobra.Command, []string) { + + var addressFilter map[common.Address]struct{} + + if len(flagAddresses) > 0 { + for _, hexAddr := range strings.Split(flagAddresses, ",") { + + hexAddr = strings.TrimSpace(hexAddr) + + if len(hexAddr) == 0 { + continue + } + + addr, err := common2.ParseAddress(hexAddr) + if err != nil { + log.Fatal().Err(err).Msgf("failed to parse address: %s", hexAddr) + } + + if addressFilter == nil { + addressFilter = make(map[common.Address]struct{}) + } + addressFilter[common.Address(addr)] = struct{}{} + } + + addresses := make([]string, 0, len(addressFilter)) + for addr := range addressFilter { + addresses = append(addresses, addr.HexWithPrefix()) + } + log.Info().Msgf( + "Only generating fixes for %d accounts: %s", + len(addressFilter), + addresses, + ) + } + + if flagPayloads == "" && flagState == "" { + log.Fatal().Msg("Either --payloads or --state must be provided") + } else if flagPayloads != "" && flagState != "" { + log.Fatal().Msg("Only one of --payloads or --state must be provided") + } + if flagState != "" && flagStateCommitment == "" { + log.Fatal().Msg("--state-commitment must be provided when --state is provided") + } + + rwf := reporters.NewReportFileWriterFactory(flagOutputDirectory, log.Logger) + + chainID := flow.ChainID(flagChain) + // Validate chain ID + _ = chainID.Chain() + + migratedPublicLinkSetChan := make(chan MigratedPublicLinkSet, 1) + go func() { + migratedPublicLinkSetChan <- readMigratedPublicLinkSet( + flagLinkMigrationReport, + addressFilter, + ) + }() + + registersByAccountChan := make(chan *registers.ByAccount, 1) + go func() { + registersByAccountChan <- loadRegistersByAccount() + }() + + migratedPublicLinkSet := <-migratedPublicLinkSetChan + registersByAccount := <-registersByAccountChan + + fixReporter := rwf.ReportWriter("authorization-fixes") + defer fixReporter.Close() + + authorizationFixGenerator := &AuthorizationFixGenerator{ + registersByAccount: registersByAccount, + chainID: chainID, + migratedPublicLinkSet: migratedPublicLinkSet, + reporter: fixReporter, + } + + log.Info().Msg("Generating authorization fixes ...") + + if len(addressFilter) > 0 { + authorizationFixGenerator.generateFixesForAccounts(addressFilter) + } else { + authorizationFixGenerator.generateFixesForAllAccounts() + } +} + +func loadRegistersByAccount() *registers.ByAccount { + // Read payloads from payload file or checkpoint file + + var payloads []*ledger.Payload + var err error + + if flagPayloads != "" { + log.Info().Msgf("Reading payloads from %s", flagPayloads) + + _, payloads, err = util.ReadPayloadFile(log.Logger, flagPayloads) + if err != nil { + log.Fatal().Err(err).Msg("failed to read payloads") + } + } else { + log.Info().Msgf("Reading trie %s", flagStateCommitment) + + stateCommitment := util.ParseStateCommitment(flagStateCommitment) + payloads, err = util.ReadTrie(flagState, stateCommitment) + if err != nil { + log.Fatal().Err(err).Msg("failed to read state") + } + } + + log.Info().Msgf("creating registers from payloads (%d)", len(payloads)) + + registersByAccount, err := registers.NewByAccountFromPayloads(payloads) + if err != nil { + log.Fatal().Err(err) + } + log.Info().Msgf( + "created %d registers from payloads (%d accounts)", + registersByAccount.Count(), + registersByAccount.AccountCount(), + ) + + return registersByAccount +} + +func readMigratedPublicLinkSet(path string, addressFilter map[common.Address]struct{}) MigratedPublicLinkSet { + + file, err := os.Open(path) + if err != nil { + log.Fatal().Err(err).Msgf("can't open link migration report: %s", path) + } + defer file.Close() + + var reader io.Reader = file + if isGzip(file) { + reader, err = gzip.NewReader(file) + if err != nil { + log.Fatal().Err(err).Msgf("failed to create gzip reader for %s", path) + } + } + + log.Info().Msgf("Reading link migration report from %s ...", path) + + migratedPublicLinkSet, err := ReadMigratedPublicLinkSet(reader, addressFilter) + if err != nil { + log.Fatal().Err(err).Msgf("failed to read public link report: %s", path) + } + + log.Info().Msgf("Read %d public link migration entries", len(migratedPublicLinkSet)) + + return migratedPublicLinkSet +} + +func jsonEncodeAuthorization(authorization interpreter.Authorization) string { + switch authorization { + case interpreter.UnauthorizedAccess, interpreter.InaccessibleAccess: + return "" + default: + return string(authorization.ID()) + } +} + +type fixEntitlementsEntry struct { + CapabilityAddress common.Address + CapabilityID uint64 + ReferencedType interpreter.StaticType + Authorization interpreter.Authorization +} + +var _ json.Marshaler = fixEntitlementsEntry{} + +func (e fixEntitlementsEntry) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + CapabilityAddress string `json:"capability_address"` + CapabilityID uint64 `json:"capability_id"` + ReferencedType string `json:"referenced_type"` + Authorization string `json:"authorization"` + }{ + CapabilityAddress: e.CapabilityAddress.String(), + CapabilityID: e.CapabilityID, + ReferencedType: string(e.ReferencedType.ID()), + Authorization: jsonEncodeAuthorization(e.Authorization), + }) +} + +type AuthorizationFixGenerator struct { + registersByAccount *registers.ByAccount + chainID flow.ChainID + migratedPublicLinkSet MigratedPublicLinkSet + reporter reporters.ReportWriter +} + +func (g *AuthorizationFixGenerator) generateFixesForAllAccounts() { + var wg sync.WaitGroup + progress := progressbar.Default(int64(g.registersByAccount.AccountCount()), "Processing:") + + err := g.registersByAccount.ForEachAccount(func(accountRegisters *registers.AccountRegisters) error { + address := common.MustBytesToAddress([]byte(accountRegisters.Owner())) + wg.Add(1) + go func(address common.Address) { + defer wg.Done() + g.generateFixesForAccount(address) + _ = progress.Add(1) + }(address) + return nil + }) + if err != nil { + log.Fatal().Err(err) + } + + wg.Wait() + _ = progress.Finish() +} + +func (g *AuthorizationFixGenerator) generateFixesForAccounts(addresses map[common.Address]struct{}) { + var wg sync.WaitGroup + progress := progressbar.Default(int64(len(addresses)), "Processing:") + + for address := range addresses { + wg.Add(1) + go func(address common.Address) { + defer wg.Done() + g.generateFixesForAccount(address) + _ = progress.Add(1) + }(address) + } + + wg.Wait() + _ = progress.Finish() +} + +func (g *AuthorizationFixGenerator) generateFixesForAccount(address common.Address) { + mr, err := migrations.NewInterpreterMigrationRuntime( + g.registersByAccount, + g.chainID, + migrations.InterpreterMigrationRuntimeConfig{}, + ) + if err != nil { + log.Fatal().Err(err) + } + + capabilityControllerStorage := mr.Storage.GetStorageMap( + address, + stdlib.CapabilityControllerStorageDomain, + false, + ) + if capabilityControllerStorage == nil { + return + } + + iterator := capabilityControllerStorage.Iterator(nil) + for { + k, v := iterator.Next() + + if k == nil || v == nil { + break + } + + key, ok := k.(interpreter.Uint64AtreeValue) + if !ok { + log.Fatal().Msgf("unexpected key type: %T", k) + } + + capabilityID := uint64(key) + + value := interpreter.MustConvertUnmeteredStoredValue(v) + + capabilityController, ok := value.(*interpreter.StorageCapabilityControllerValue) + if !ok { + continue + } + + borrowType := capabilityController.BorrowType + + switch borrowType.Authorization.(type) { + case interpreter.EntitlementSetAuthorization: + g.maybeGenerateFixForEntitledCapabilityController( + address, + capabilityID, + borrowType, + ) + + case interpreter.Unauthorized: + // Already unauthorized, nothing to do + + case interpreter.Inaccessible: + log.Warn().Msgf( + "capability controller %d in account %s has borrow type with inaccessible authorization", + capabilityID, + address.HexWithPrefix(), + ) + + case interpreter.EntitlementMapAuthorization: + log.Warn().Msgf( + "capability controller %d in account %s has borrow type with entitlement map authorization", + capabilityID, + address.HexWithPrefix(), + ) + + default: + log.Warn().Msgf( + "capability controller %d in account %s has borrow type with entitlement map authorization", + capabilityID, + address.HexWithPrefix(), + ) + } + } +} + +func newEntitlementSetAuthorizationFromTypeIDs( + typeIDs []common.TypeID, + setKind sema.EntitlementSetKind, +) interpreter.EntitlementSetAuthorization { + return interpreter.NewEntitlementSetAuthorization( + nil, + func() []common.TypeID { + return typeIDs + }, + len(typeIDs), + setKind, + ) +} + +func (g *AuthorizationFixGenerator) maybeGenerateFixForEntitledCapabilityController( + capabilityAddress common.Address, + capabilityID uint64, + borrowType *interpreter.ReferenceStaticType, +) { + // Only remove the authorization if the capability controller was migrated from a public link + _, ok := g.migratedPublicLinkSet[AccountCapabilityID{ + Address: capabilityAddress, + CapabilityID: capabilityID, + }] + if !ok { + return + } + + g.reporter.Write(fixEntitlementsEntry{ + CapabilityAddress: capabilityAddress, + CapabilityID: capabilityID, + ReferencedType: borrowType.ReferencedType, + Authorization: borrowType.Authorization, + }) +} + +func isGzip(file *os.File) bool { + return strings.HasSuffix(file.Name(), ".gz") +} diff --git a/cmd/util/cmd/generate-authorization-fixes/cmd_test.go b/cmd/util/cmd/generate-authorization-fixes/cmd_test.go new file mode 100644 index 00000000000..7a5f8f0f459 --- /dev/null +++ b/cmd/util/cmd/generate-authorization-fixes/cmd_test.go @@ -0,0 +1,258 @@ +package generate_authorization_fixes + +import ( + "fmt" + "testing" + + "github.com/onflow/cadence" + jsoncdc "github.com/onflow/cadence/encoding/json" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/cmd/util/ledger/migrations" + "github.com/onflow/flow-go/cmd/util/ledger/reporters" + "github.com/onflow/flow-go/cmd/util/ledger/util/registers" + "github.com/onflow/flow-go/fvm" + "github.com/onflow/flow-go/fvm/storage/snapshot" + "github.com/onflow/flow-go/ledger" + "github.com/onflow/flow-go/ledger/common/convert" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +func newBootstrapPayloads( + chainID flow.ChainID, + bootstrapProcedureOptions ...fvm.BootstrapProcedureOption, +) ([]*ledger.Payload, error) { + + ctx := fvm.NewContext( + fvm.WithChain(chainID.Chain()), + ) + + vm := fvm.NewVirtualMachine() + + storageSnapshot := snapshot.MapStorageSnapshot{} + + bootstrapProcedure := fvm.Bootstrap( + unittest.ServiceAccountPublicKey, + bootstrapProcedureOptions..., + ) + + executionSnapshot, _, err := vm.Run( + ctx, + bootstrapProcedure, + storageSnapshot, + ) + if err != nil { + return nil, err + } + + payloads := make([]*ledger.Payload, 0, len(executionSnapshot.WriteSet)) + + for registerID, registerValue := range executionSnapshot.WriteSet { + payloadKey := convert.RegisterIDToLedgerKey(registerID) + payload := ledger.NewPayload(payloadKey, registerValue) + payloads = append(payloads, payload) + } + + return payloads, nil +} + +type testReportWriter struct { + entries []any +} + +func (t *testReportWriter) Write(entry interface{}) { + t.entries = append(t.entries, entry) +} + +func (*testReportWriter) Close() { + // NO-OP +} + +var _ reporters.ReportWriter = &testReportWriter{} + +func TestGenerateAuthorizationFixes(t *testing.T) { + t.Parallel() + + const chainID = flow.Emulator + chain := chainID.Chain() + + address, err := chain.AddressAtIndex(1000) + require.NoError(t, err) + + require.Equal(t, "bf519681cdb888b1", address.Hex()) + + log := zerolog.New(zerolog.NewTestWriter(t)) + + bootstrapPayloads, err := newBootstrapPayloads(chainID) + require.NoError(t, err) + + registersByAccount, err := registers.NewByAccountFromPayloads(bootstrapPayloads) + require.NoError(t, err) + + mr := migrations.NewBasicMigrationRuntime(registersByAccount) + + err = mr.Accounts.Create(nil, address) + require.NoError(t, err) + + expectedWriteAddresses := map[flow.Address]struct{}{ + address: {}, + } + + err = mr.Commit(expectedWriteAddresses, log) + require.NoError(t, err) + + const contractCode = ` + access(all) contract Test { + + access(all) entitlement E + + access(all) struct S {} + } + ` + + deployTX := flow.NewTransactionBody(). + SetScript([]byte(` + transaction(code: String) { + prepare(signer: auth(Contracts) &Account) { + signer.contracts.add(name: "Test", code: code.utf8) + } + } + `)). + AddAuthorizer(address). + AddArgument(jsoncdc.MustEncode(cadence.String(contractCode))) + + runDeployTx := migrations.NewTransactionBasedMigration( + deployTX, + chainID, + log, + expectedWriteAddresses, + ) + err = runDeployTx(registersByAccount) + require.NoError(t, err) + + setupTx := flow.NewTransactionBody(). + SetScript([]byte(fmt.Sprintf(` + import Test from %s + + transaction { + prepare(signer: auth(Storage, Capabilities) &Account) { + // Capability 1 was a public, unauthorized capability, which is now authorized. + // It should lose its entitlement + let cap1 = signer.capabilities.storage.issue(/storage/s) + signer.capabilities.publish(cap1, at: /public/s1) + + // Capability 2 was a public, unauthorized capability, which is now authorized. + // It is currently only stored, nested, in storage, and is not published. + // It should lose its entitlement + let cap2 = signer.capabilities.storage.issue(/storage/s) + signer.storage.save([cap2], to: /storage/caps2) + + // Capability 3 was a private, authorized capability. + // It is currently only stored, nested, in storage, and is not published. + // It should keep its entitlement + let cap3 = signer.capabilities.storage.issue(/storage/s) + signer.storage.save([cap3], to: /storage/caps3) + + // Capability 4 was a private, authorized capability. + // It is currently both stored, nested, in storage, and is published. + // It should keep its entitlement + let cap4 = signer.capabilities.storage.issue(/storage/s) + signer.storage.save([cap4], to: /storage/caps4) + signer.capabilities.publish(cap4, at: /public/s4) + + // Capability 5 was a public, unauthorized capability, which is still unauthorized. + // It is currently both stored, nested, in storage, and is published. + // There is no need to fix it. + let cap5 = signer.capabilities.storage.issue<&Test.S>(/storage/s) + signer.storage.save([cap5], to: /storage/caps5) + signer.capabilities.publish(cap5, at: /public/s5) + } + } + `, + address.HexWithPrefix(), + ))). + AddAuthorizer(address) + + runSetupTx := migrations.NewTransactionBasedMigration( + setupTx, + chainID, + log, + expectedWriteAddresses, + ) + err = runSetupTx(registersByAccount) + require.NoError(t, err) + + testContractLocation := common.AddressLocation{ + Address: common.Address(address), + Name: "Test", + } + + migratedPublicLinkSet := MigratedPublicLinkSet{ + { + Address: common.Address(address), + CapabilityID: 1, + }: {}, + { + Address: common.Address(address), + CapabilityID: 2, + }: {}, + { + Address: common.Address(address), + CapabilityID: 5, + }: {}, + } + + reporter := &testReportWriter{} + + generator := &AuthorizationFixGenerator{ + registersByAccount: registersByAccount, + chainID: chainID, + migratedPublicLinkSet: migratedPublicLinkSet, + reporter: reporter, + } + generator.generateFixesForAllAccounts() + + eTypeID := testContractLocation.TypeID(nil, "Test.E") + + assert.Equal(t, + []any{ + fixEntitlementsEntry{ + CapabilityAddress: common.Address(address), + CapabilityID: 1, + ReferencedType: interpreter.NewCompositeStaticTypeComputeTypeID( + nil, + testContractLocation, + "Test.S", + ), + Authorization: newEntitlementSetAuthorizationFromTypeIDs( + []common.TypeID{ + eTypeID, + }, + sema.Conjunction, + ), + }, + fixEntitlementsEntry{ + CapabilityAddress: common.Address(address), + CapabilityID: 2, + ReferencedType: interpreter.NewCompositeStaticTypeComputeTypeID( + nil, + testContractLocation, + "Test.S", + ), + Authorization: newEntitlementSetAuthorizationFromTypeIDs( + []common.TypeID{ + eTypeID, + }, + sema.Conjunction, + ), + }, + }, + reporter.entries, + ) +} diff --git a/cmd/util/cmd/generate-authorization-fixes/link_migration_report.go b/cmd/util/cmd/generate-authorization-fixes/link_migration_report.go new file mode 100644 index 00000000000..b5888d8cf92 --- /dev/null +++ b/cmd/util/cmd/generate-authorization-fixes/link_migration_report.go @@ -0,0 +1,93 @@ +package generate_authorization_fixes + +import ( + "encoding/json" + "fmt" + "io" + "strings" + + "github.com/onflow/cadence/runtime/common" +) + +// AccountCapabilityID is a capability ID in an account. +type AccountCapabilityID struct { + Address common.Address + CapabilityID uint64 +} + +// MigratedPublicLinkSet is a set of capability controller IDs which were migrated from public links. +type MigratedPublicLinkSet map[AccountCapabilityID]struct{} + +// ReadMigratedPublicLinkSet reads a link migration report from the given reader, +// and returns a set of all capability controller IDs which were migrated from public links. +// +// The report is expected to be a JSON array of objects with the following structure: +// +// [ +// {"kind":"link-migration-success","account_address":"0x1","path":"/public/foo","capability_id":1}, +// ] +func ReadMigratedPublicLinkSet( + reader io.Reader, + filter map[common.Address]struct{}, +) (MigratedPublicLinkSet, error) { + + set := MigratedPublicLinkSet{} + + dec := json.NewDecoder(reader) + + token, err := dec.Token() + if err != nil { + return nil, fmt.Errorf("failed to read token: %w", err) + } + if token != json.Delim('[') { + return nil, fmt.Errorf("expected start of array, got %s", token) + } + + for dec.More() { + var entry struct { + Kind string `json:"kind"` + Address string `json:"account_address"` + Path string `json:"path"` + CapabilityID uint64 `json:"capability_id"` + } + err := dec.Decode(&entry) + if err != nil { + return nil, fmt.Errorf("failed to decode entry: %w", err) + } + + if entry.Kind != "link-migration-success" { + continue + } + + if !strings.HasPrefix(entry.Path, "/public/") { + continue + } + + address, err := common.HexToAddress(entry.Address) + if err != nil { + return nil, fmt.Errorf("failed to parse address: %w", err) + } + + if filter != nil { + if _, ok := filter[address]; !ok { + continue + } + } + + accountCapabilityID := AccountCapabilityID{ + Address: address, + CapabilityID: entry.CapabilityID, + } + set[accountCapabilityID] = struct{}{} + } + + token, err = dec.Token() + if err != nil { + return nil, fmt.Errorf("failed to read token: %w", err) + } + if token != json.Delim(']') { + return nil, fmt.Errorf("expected end of array, got %s", token) + } + + return set, nil +} diff --git a/cmd/util/cmd/generate-authorization-fixes/link_migration_report_test.go b/cmd/util/cmd/generate-authorization-fixes/link_migration_report_test.go new file mode 100644 index 00000000000..fd7ffac5ed9 --- /dev/null +++ b/cmd/util/cmd/generate-authorization-fixes/link_migration_report_test.go @@ -0,0 +1,70 @@ +package generate_authorization_fixes + +import ( + "strings" + "testing" + + "github.com/onflow/cadence/runtime/common" + "github.com/stretchr/testify/require" +) + +func TestReadPublicLinkMigrationReport(t *testing.T) { + t.Parallel() + + contents := ` + [ + {"kind":"link-migration-success","account_address":"0x1","path":"/public/foo","capability_id":1}, + {"kind":"link-migration-success","account_address":"0x2","path":"/private/bar","capability_id":2}, + {"kind":"link-migration-success","account_address":"0x3","path":"/public/baz","capability_id":3} + ] + ` + + t.Run("unfiltered", func(t *testing.T) { + t.Parallel() + + reader := strings.NewReader(contents) + + mapping, err := ReadMigratedPublicLinkSet(reader, nil) + require.NoError(t, err) + + require.Equal(t, + MigratedPublicLinkSet{ + { + Address: common.MustBytesToAddress([]byte{0x1}), + CapabilityID: 1, + }: struct{}{}, + { + Address: common.MustBytesToAddress([]byte{0x3}), + CapabilityID: 3, + }: struct{}{}, + }, + mapping, + ) + }) + + t.Run("filtered", func(t *testing.T) { + t.Parallel() + + address1 := common.MustBytesToAddress([]byte{0x1}) + + reader := strings.NewReader(contents) + + mapping, err := ReadMigratedPublicLinkSet( + reader, + map[common.Address]struct{}{ + address1: {}, + }, + ) + require.NoError(t, err) + + require.Equal(t, + MigratedPublicLinkSet{ + { + Address: address1, + CapabilityID: 1, + }: struct{}{}, + }, + mapping, + ) + }) +} diff --git a/cmd/util/cmd/root.go b/cmd/util/cmd/root.go index dd11c40b14b..281d1dbebbf 100644 --- a/cmd/util/cmd/root.go +++ b/cmd/util/cmd/root.go @@ -29,6 +29,7 @@ import ( extractpayloads "github.com/onflow/flow-go/cmd/util/cmd/extract-payloads-by-address" find_inconsistent_result "github.com/onflow/flow-go/cmd/util/cmd/find-inconsistent-result" find_trie_root "github.com/onflow/flow-go/cmd/util/cmd/find-trie-root" + generate_authorization_fixes "github.com/onflow/flow-go/cmd/util/cmd/generate-authorization-fixes" read_badger "github.com/onflow/flow-go/cmd/util/cmd/read-badger/cmd" read_execution_state "github.com/onflow/flow-go/cmd/util/cmd/read-execution-state" read_hotstuff "github.com/onflow/flow-go/cmd/util/cmd/read-hotstuff/cmd" @@ -122,6 +123,7 @@ func addCommands() { rootCmd.AddCommand(check_storage.Cmd) rootCmd.AddCommand(debug_tx.Cmd) rootCmd.AddCommand(debug_script.Cmd) + rootCmd.AddCommand(generate_authorization_fixes.Cmd) } func initConfig() { diff --git a/cmd/util/cmd/run-script/cmd.go b/cmd/util/cmd/run-script/cmd.go index 7f12cef5b35..1f24d2599c2 100644 --- a/cmd/util/cmd/run-script/cmd.go +++ b/cmd/util/cmd/run-script/cmd.go @@ -1,19 +1,29 @@ package run_script import ( + "context" + "errors" + "fmt" "io" "os" jsoncdc "github.com/onflow/cadence/encoding/json" + "github.com/onflow/flow/protobuf/go/flow/entities" "github.com/rs/zerolog/log" "github.com/spf13/cobra" + "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/cmd/util/ledger/util" "github.com/onflow/flow-go/cmd/util/ledger/util/registers" + "github.com/onflow/flow-go/engine/access/rest" + "github.com/onflow/flow-go/engine/access/state_stream/backend" + "github.com/onflow/flow-go/engine/access/subscription" "github.com/onflow/flow-go/engine/execution/computation" "github.com/onflow/flow-go/fvm" + "github.com/onflow/flow-go/fvm/storage/snapshot" "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/module/metrics" ) var ( @@ -21,6 +31,8 @@ var ( flagState string flagStateCommitment string flagChain string + flagServe bool + flagPort int ) var Cmd = &cobra.Command{ @@ -60,6 +72,20 @@ func init() { "Chain name", ) _ = Cmd.MarkFlagRequired("chain") + + Cmd.Flags().BoolVar( + &flagServe, + "serve", + false, + "serve with an HTTP server", + ) + + Cmd.Flags().IntVar( + &flagPort, + "port", + 8000, + "port for HTTP server", + ) } func run(*cobra.Command, []string) { @@ -75,15 +101,14 @@ func run(*cobra.Command, []string) { chainID := flow.ChainID(flagChain) // Validate chain ID - _ = chainID.Chain() + chain := chainID.Chain() - code, err := io.ReadAll(os.Stdin) - if err != nil { - log.Fatal().Msgf("failed to read script: %s", err) - } - - var payloads []*ledger.Payload + log.Info().Msg("loading state ...") + var ( + err error + payloads []*ledger.Payload + ) if flagPayloads != "" { _, payloads, err = util.ReadPayloadFile(log.Logger, flagPayloads) } else { @@ -125,23 +150,390 @@ func run(*cobra.Command, []string) { vm := fvm.NewVirtualMachine() + if flagServe { + + api := &api{ + chainID: chainID, + vm: vm, + ctx: ctx, + storageSnapshot: storageSnapshot, + } + + server, err := rest.NewServer( + api, + rest.Config{ + ListenAddress: fmt.Sprintf(":%d", flagPort), + }, + log.Logger, + chain, + metrics.NewNoopCollector(), + nil, + backend.Config{}, + ) + if err != nil { + log.Fatal().Err(err).Msg("failed to create server") + } + + log.Info().Msgf("serving on port %d", flagPort) + + err = server.ListenAndServe() + if err != nil { + log.Info().Msg("server stopped") + } + } else { + code, err := io.ReadAll(os.Stdin) + if err != nil { + log.Fatal().Msgf("failed to read script: %s", err) + } + + encodedResult, err := runScript(vm, ctx, storageSnapshot, code, nil) + if err != nil { + log.Fatal().Err(err).Msg("failed to run script") + } + + _, _ = os.Stdout.Write(encodedResult) + } +} + +func runScript( + vm *fvm.VirtualMachine, + ctx fvm.Context, + storageSnapshot snapshot.StorageSnapshot, + code []byte, + arguments [][]byte, +) ( + encodedResult []byte, + err error, +) { _, res, err := vm.Run( ctx, - fvm.Script(code), + fvm.Script(code).WithArguments(arguments...), storageSnapshot, ) if err != nil { - log.Fatal().Msgf("failed to run script: %s", err) + return nil, err } if res.Err != nil { - log.Fatal().Msgf("script failed: %s", res.Err) + return nil, res.Err } encoded, err := jsoncdc.Encode(res.Value) if err != nil { - log.Fatal().Msgf("failed to encode result: %s", err) + return nil, err } - _, _ = os.Stdout.Write(encoded) + return encoded, nil +} + +type api struct { + chainID flow.ChainID + vm *fvm.VirtualMachine + ctx fvm.Context + storageSnapshot registers.StorageSnapshot +} + +var _ access.API = &api{} + +func (*api) Ping(_ context.Context) error { + return nil +} + +func (a *api) GetNetworkParameters(_ context.Context) access.NetworkParameters { + return access.NetworkParameters{ + ChainID: a.chainID, + } +} + +func (*api) GetNodeVersionInfo(_ context.Context) (*access.NodeVersionInfo, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetLatestBlockHeader(_ context.Context, _ bool) (*flow.Header, flow.BlockStatus, error) { + return nil, flow.BlockStatusUnknown, errors.New("unimplemented") +} + +func (*api) GetBlockHeaderByHeight(_ context.Context, _ uint64) (*flow.Header, flow.BlockStatus, error) { + return nil, flow.BlockStatusUnknown, errors.New("unimplemented") +} + +func (*api) GetBlockHeaderByID(_ context.Context, _ flow.Identifier) (*flow.Header, flow.BlockStatus, error) { + return nil, flow.BlockStatusUnknown, errors.New("unimplemented") +} + +func (*api) GetLatestBlock(_ context.Context, _ bool) (*flow.Block, flow.BlockStatus, error) { + return nil, flow.BlockStatusUnknown, errors.New("unimplemented") +} + +func (*api) GetBlockByHeight(_ context.Context, _ uint64) (*flow.Block, flow.BlockStatus, error) { + return nil, flow.BlockStatusUnknown, errors.New("unimplemented") +} + +func (*api) GetBlockByID(_ context.Context, _ flow.Identifier) (*flow.Block, flow.BlockStatus, error) { + return nil, flow.BlockStatusUnknown, errors.New("unimplemented") +} + +func (*api) GetCollectionByID(_ context.Context, _ flow.Identifier) (*flow.LightCollection, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetFullCollectionByID(_ context.Context, _ flow.Identifier) (*flow.Collection, error) { + return nil, errors.New("unimplemented") +} + +func (*api) SendTransaction(_ context.Context, _ *flow.TransactionBody) error { + return errors.New("unimplemented") +} + +func (*api) GetTransaction(_ context.Context, _ flow.Identifier) (*flow.TransactionBody, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetTransactionsByBlockID(_ context.Context, _ flow.Identifier) ([]*flow.TransactionBody, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetTransactionResult( + _ context.Context, + _ flow.Identifier, + _ flow.Identifier, + _ flow.Identifier, + _ entities.EventEncodingVersion, +) (*access.TransactionResult, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetTransactionResultByIndex( + _ context.Context, + _ flow.Identifier, + _ uint32, + _ entities.EventEncodingVersion, +) (*access.TransactionResult, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetTransactionResultsByBlockID( + _ context.Context, + _ flow.Identifier, + _ entities.EventEncodingVersion, +) ([]*access.TransactionResult, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetSystemTransaction( + _ context.Context, + _ flow.Identifier, +) (*flow.TransactionBody, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetSystemTransactionResult( + _ context.Context, + _ flow.Identifier, + _ entities.EventEncodingVersion, +) (*access.TransactionResult, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccount(_ context.Context, _ flow.Address) (*flow.Account, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccountAtLatestBlock(_ context.Context, _ flow.Address) (*flow.Account, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccountAtBlockHeight(_ context.Context, _ flow.Address, _ uint64) (*flow.Account, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccountBalanceAtLatestBlock(_ context.Context, _ flow.Address) (uint64, error) { + return 0, errors.New("unimplemented") +} + +func (*api) GetAccountBalanceAtBlockHeight( + _ context.Context, + _ flow.Address, + _ uint64, +) (uint64, error) { + return 0, errors.New("unimplemented") +} + +func (*api) GetAccountKeyAtLatestBlock( + _ context.Context, + _ flow.Address, + _ uint32, +) (*flow.AccountPublicKey, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccountKeyAtBlockHeight( + _ context.Context, + _ flow.Address, + _ uint32, + _ uint64, +) (*flow.AccountPublicKey, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccountKeysAtLatestBlock( + _ context.Context, + _ flow.Address, +) ([]flow.AccountPublicKey, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetAccountKeysAtBlockHeight( + _ context.Context, + _ flow.Address, + _ uint64, +) ([]flow.AccountPublicKey, error) { + return nil, errors.New("unimplemented") +} + +func (a *api) ExecuteScriptAtLatestBlock( + _ context.Context, + script []byte, + arguments [][]byte, +) ([]byte, error) { + return runScript( + a.vm, + a.ctx, + a.storageSnapshot, + script, + arguments, + ) +} + +func (*api) ExecuteScriptAtBlockHeight( + _ context.Context, + _ uint64, + _ []byte, + _ [][]byte, +) ([]byte, error) { + return nil, errors.New("unimplemented") +} + +func (*api) ExecuteScriptAtBlockID( + _ context.Context, + _ flow.Identifier, + _ []byte, + _ [][]byte, +) ([]byte, error) { + return nil, errors.New("unimplemented") +} + +func (a *api) GetEventsForHeightRange( + _ context.Context, + _ string, + _, _ uint64, + _ entities.EventEncodingVersion, +) ([]flow.BlockEvents, error) { + return nil, errors.New("unimplemented") +} + +func (a *api) GetEventsForBlockIDs( + _ context.Context, + _ string, + _ []flow.Identifier, + _ entities.EventEncodingVersion, +) ([]flow.BlockEvents, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetLatestProtocolStateSnapshot(_ context.Context) ([]byte, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetProtocolStateSnapshotByBlockID(_ context.Context, _ flow.Identifier) ([]byte, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetProtocolStateSnapshotByHeight(_ context.Context, _ uint64) ([]byte, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetExecutionResultForBlockID(_ context.Context, _ flow.Identifier) (*flow.ExecutionResult, error) { + return nil, errors.New("unimplemented") +} + +func (*api) GetExecutionResultByID(_ context.Context, _ flow.Identifier) (*flow.ExecutionResult, error) { + return nil, errors.New("unimplemented") +} + +func (*api) SubscribeBlocksFromStartBlockID( + _ context.Context, + _ flow.Identifier, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlocksFromStartHeight( + _ context.Context, + _ uint64, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlocksFromLatest( + _ context.Context, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlockHeadersFromStartBlockID( + _ context.Context, + _ flow.Identifier, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlockHeadersFromStartHeight( + _ context.Context, + _ uint64, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlockHeadersFromLatest( + _ context.Context, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlockDigestsFromStartBlockID( + _ context.Context, + _ flow.Identifier, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlockDigestsFromStartHeight( + _ context.Context, + _ uint64, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeBlockDigestsFromLatest( + _ context.Context, + _ flow.BlockStatus, +) subscription.Subscription { + return nil +} + +func (*api) SubscribeTransactionStatuses( + _ context.Context, + _ *flow.TransactionBody, + _ entities.EventEncodingVersion, +) subscription.Subscription { + return nil } diff --git a/cmd/util/ledger/migrations/contract_checking_migration.go b/cmd/util/ledger/migrations/contract_checking_migration.go index bffc05c9dfd..91aa7bbb902 100644 --- a/cmd/util/ledger/migrations/contract_checking_migration.go +++ b/cmd/util/ledger/migrations/contract_checking_migration.go @@ -51,7 +51,7 @@ func NewContractCheckingMigration( return fmt.Errorf("failed to create interpreter migration runtime: %w", err) } - contracts, err := gatherContractsFromRegisters(registersByAccount, log) + contracts, err := GatherContractsFromRegisters(registersByAccount, log) if err != nil { return err } @@ -64,7 +64,7 @@ func NewContractCheckingMigration( // Check all contracts for _, contract := range contracts { - checkContract( + CheckContract( contract, log, mr, @@ -80,7 +80,7 @@ func NewContractCheckingMigration( } } -func gatherContractsFromRegisters(registersByAccount *registers.ByAccount, log zerolog.Logger) ([]AddressContract, error) { +func GatherContractsFromRegisters(registersByAccount *registers.ByAccount, log zerolog.Logger) ([]AddressContract, error) { log.Info().Msg("Gathering contracts ...") contracts := make([]AddressContract, 0, contractCountEstimate) @@ -142,7 +142,7 @@ func gatherContractsFromRegisters(registersByAccount *registers.ByAccount, log z return contracts, nil } -func checkContract( +func CheckContract( contract AddressContract, log zerolog.Logger, mr *InterpreterMigrationRuntime, diff --git a/cmd/util/ledger/migrations/fix_authorizations_migration.go b/cmd/util/ledger/migrations/fix_authorizations_migration.go new file mode 100644 index 00000000000..ff4ab7898af --- /dev/null +++ b/cmd/util/ledger/migrations/fix_authorizations_migration.go @@ -0,0 +1,461 @@ +package migrations + +import ( + "encoding/json" + "errors" + "fmt" + "io" + + "github.com/onflow/cadence/migrations" + "github.com/onflow/cadence/runtime/common" + cadenceErrors "github.com/onflow/cadence/runtime/errors" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "github.com/onflow/flow-go/cmd/util/ledger/reporters" + "github.com/onflow/flow-go/fvm/environment" +) + +type AccountCapabilityID struct { + Address common.Address + CapabilityID uint64 +} + +// FixAuthorizationsMigration + +type FixAuthorizationsMigrationReporter interface { + MigratedCapability( + storageKey interpreter.StorageKey, + capabilityAddress common.Address, + capabilityID uint64, + ) + MigratedCapabilityController( + storageKey interpreter.StorageKey, + capabilityID uint64, + ) +} + +type FixAuthorizationsMigration struct { + Reporter FixAuthorizationsMigrationReporter + AuthorizationFixes AuthorizationFixes +} + +var _ migrations.ValueMigration = &FixAuthorizationsMigration{} + +func (*FixAuthorizationsMigration) Name() string { + return "FixAuthorizationsMigration" +} + +func (*FixAuthorizationsMigration) Domains() map[string]struct{} { + return nil +} + +func (m *FixAuthorizationsMigration) Migrate( + storageKey interpreter.StorageKey, + _ interpreter.StorageMapKey, + value interpreter.Value, + _ *interpreter.Interpreter, + _ migrations.ValueMigrationPosition, +) ( + interpreter.Value, + error, +) { + switch value := value.(type) { + case *interpreter.IDCapabilityValue: + capabilityAddress := common.Address(value.Address()) + capabilityID := uint64(value.ID) + + _, ok := m.AuthorizationFixes[AccountCapabilityID{ + Address: capabilityAddress, + CapabilityID: capabilityID, + }] + if !ok { + // This capability does not need to be fixed + return nil, nil + } + + oldBorrowType := value.BorrowType + if oldBorrowType == nil { + log.Warn().Msgf( + "missing borrow type for capability with target %s#%d", + capabilityAddress.HexWithPrefix(), + capabilityID, + ) + } + + oldBorrowReferenceType, ok := oldBorrowType.(*interpreter.ReferenceStaticType) + if !ok { + log.Warn().Msgf( + "invalid non-reference borrow type for capability with target %s#%d: %s", + capabilityAddress.HexWithPrefix(), + capabilityID, + oldBorrowType, + ) + return nil, nil + } + + newBorrowType := interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + oldBorrowReferenceType.ReferencedType, + ) + newCapabilityValue := interpreter.NewUnmeteredCapabilityValue( + interpreter.UInt64Value(capabilityID), + interpreter.AddressValue(capabilityAddress), + newBorrowType, + ) + + m.Reporter.MigratedCapability( + storageKey, + capabilityAddress, + capabilityID, + ) + + return newCapabilityValue, nil + + case *interpreter.StorageCapabilityControllerValue: + // The capability controller's address is implicitly + // the address of the account in which it is stored + capabilityAddress := storageKey.Address + capabilityID := uint64(value.CapabilityID) + + _, ok := m.AuthorizationFixes[AccountCapabilityID{ + Address: capabilityAddress, + CapabilityID: capabilityID, + }] + if !ok { + // This capability controller does not need to be fixed + return nil, nil + } + + oldBorrowReferenceType := value.BorrowType + + newBorrowType := interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + oldBorrowReferenceType.ReferencedType, + ) + newStorageCapabilityControllerValue := interpreter.NewUnmeteredStorageCapabilityControllerValue( + newBorrowType, + interpreter.UInt64Value(capabilityID), + value.TargetPath, + ) + + m.Reporter.MigratedCapabilityController( + storageKey, + capabilityID, + ) + + return newStorageCapabilityControllerValue, nil + } + + return nil, nil +} + +func (*FixAuthorizationsMigration) CanSkip(valueType interpreter.StaticType) bool { + return CanSkipFixAuthorizationsMigration(valueType) +} + +func CanSkipFixAuthorizationsMigration(valueType interpreter.StaticType) bool { + switch valueType := valueType.(type) { + case *interpreter.DictionaryStaticType: + return CanSkipFixAuthorizationsMigration(valueType.KeyType) && + CanSkipFixAuthorizationsMigration(valueType.ValueType) + + case interpreter.ArrayStaticType: + return CanSkipFixAuthorizationsMigration(valueType.ElementType()) + + case *interpreter.OptionalStaticType: + return CanSkipFixAuthorizationsMigration(valueType.Type) + + case *interpreter.CapabilityStaticType: + return false + + case interpreter.PrimitiveStaticType: + + switch valueType { + case interpreter.PrimitiveStaticTypeCapability, + interpreter.PrimitiveStaticTypeStorageCapabilityController: + return false + + case interpreter.PrimitiveStaticTypeBool, + interpreter.PrimitiveStaticTypeVoid, + interpreter.PrimitiveStaticTypeAddress, + interpreter.PrimitiveStaticTypeMetaType, + interpreter.PrimitiveStaticTypeBlock, + interpreter.PrimitiveStaticTypeString, + interpreter.PrimitiveStaticTypeCharacter: + + return true + } + + if !valueType.IsDeprecated() { //nolint:staticcheck + semaType := valueType.SemaType() + + if sema.IsSubType(semaType, sema.NumberType) || + sema.IsSubType(semaType, sema.PathType) { + + return true + } + } + } + + return false +} + +const fixAuthorizationsMigrationReporterName = "fix-authorizations-migration" + +func NewFixAuthorizationsMigration( + rwf reporters.ReportWriterFactory, + authorizationFixes AuthorizationFixes, + opts Options, +) *CadenceBaseMigration { + var diffReporter reporters.ReportWriter + if opts.DiffMigrations { + diffReporter = rwf.ReportWriter("fix-authorizations-migration-diff") + } + + reporter := rwf.ReportWriter(fixAuthorizationsMigrationReporterName) + + return &CadenceBaseMigration{ + name: "fix_authorizations_migration", + reporter: reporter, + diffReporter: diffReporter, + logVerboseDiff: opts.LogVerboseDiff, + verboseErrorOutput: opts.VerboseErrorOutput, + checkStorageHealthBeforeMigration: opts.CheckStorageHealthBeforeMigration, + valueMigrations: func( + _ *interpreter.Interpreter, + _ environment.Accounts, + _ *cadenceValueMigrationReporter, + ) []migrations.ValueMigration { + + return []migrations.ValueMigration{ + &FixAuthorizationsMigration{ + AuthorizationFixes: authorizationFixes, + Reporter: &fixAuthorizationsMigrationReporter{ + reportWriter: reporter, + verboseErrorOutput: opts.VerboseErrorOutput, + }, + }, + } + }, + chainID: opts.ChainID, + } +} + +type fixAuthorizationsMigrationReporter struct { + reportWriter reporters.ReportWriter + errorMessageHandler *errorMessageHandler + verboseErrorOutput bool +} + +var _ FixAuthorizationsMigrationReporter = &fixAuthorizationsMigrationReporter{} +var _ migrations.Reporter = &fixAuthorizationsMigrationReporter{} + +func (r *fixAuthorizationsMigrationReporter) Migrated( + storageKey interpreter.StorageKey, + storageMapKey interpreter.StorageMapKey, + migration string, +) { + r.reportWriter.Write(cadenceValueMigrationEntry{ + StorageKey: storageKey, + StorageMapKey: storageMapKey, + Migration: migration, + }) +} + +func (r *fixAuthorizationsMigrationReporter) Error(err error) { + + var migrationErr migrations.StorageMigrationError + + if !errors.As(err, &migrationErr) { + panic(cadenceErrors.NewUnreachableError()) + } + + message, showStack := r.errorMessageHandler.FormatError(migrationErr.Err) + + storageKey := migrationErr.StorageKey + storageMapKey := migrationErr.StorageMapKey + migration := migrationErr.Migration + + if showStack && len(migrationErr.Stack) > 0 { + message = fmt.Sprintf("%s\n%s", message, migrationErr.Stack) + } + + if r.verboseErrorOutput { + r.reportWriter.Write(cadenceValueMigrationFailureEntry{ + StorageKey: storageKey, + StorageMapKey: storageMapKey, + Migration: migration, + Message: message, + }) + } +} + +func (r *fixAuthorizationsMigrationReporter) DictionaryKeyConflict(accountAddressPath interpreter.AddressPath) { + r.reportWriter.Write(dictionaryKeyConflictEntry{ + AddressPath: accountAddressPath, + }) +} + +func (r *fixAuthorizationsMigrationReporter) MigratedCapabilityController( + storageKey interpreter.StorageKey, + capabilityID uint64, +) { + r.reportWriter.Write(capabilityControllerAuthorizationFixedEntry{ + StorageKey: storageKey, + CapabilityID: capabilityID, + }) +} + +func (r *fixAuthorizationsMigrationReporter) MigratedCapability( + storageKey interpreter.StorageKey, + capabilityAddress common.Address, + capabilityID uint64, +) { + r.reportWriter.Write(capabilityAuthorizationFixedEntry{ + StorageKey: storageKey, + CapabilityAddress: capabilityAddress, + CapabilityID: capabilityID, + }) +} + +// capabilityControllerAuthorizationFixedEntry +type capabilityControllerAuthorizationFixedEntry struct { + StorageKey interpreter.StorageKey + CapabilityID uint64 +} + +var _ json.Marshaler = capabilityControllerAuthorizationFixedEntry{} + +func (e capabilityControllerAuthorizationFixedEntry) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Kind string `json:"kind"` + AccountAddress string `json:"account_address"` + StorageDomain string `json:"domain"` + CapabilityID uint64 `json:"capability_id"` + }{ + Kind: "capability-controller-authorizations-fixed", + AccountAddress: e.StorageKey.Address.HexWithPrefix(), + StorageDomain: e.StorageKey.Key, + CapabilityID: e.CapabilityID, + }) +} + +// capabilityAuthorizationFixedEntry +type capabilityAuthorizationFixedEntry struct { + StorageKey interpreter.StorageKey + CapabilityAddress common.Address + CapabilityID uint64 +} + +var _ json.Marshaler = capabilityAuthorizationFixedEntry{} + +func (e capabilityAuthorizationFixedEntry) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Kind string `json:"kind"` + AccountAddress string `json:"account_address"` + StorageDomain string `json:"domain"` + CapabilityAddress string `json:"capability_address"` + CapabilityID uint64 `json:"capability_id"` + }{ + Kind: "capability-authorizations-fixed", + AccountAddress: e.StorageKey.Address.HexWithPrefix(), + StorageDomain: e.StorageKey.Key, + CapabilityAddress: e.CapabilityAddress.HexWithPrefix(), + CapabilityID: e.CapabilityID, + }) +} + +func NewFixAuthorizationsMigrations( + log zerolog.Logger, + rwf reporters.ReportWriterFactory, + authorizationFixes AuthorizationFixes, + opts Options, +) []NamedMigration { + + return []NamedMigration{ + { + Name: "fix-authorizations", + Migrate: NewAccountBasedMigration( + log, + opts.NWorker, + []AccountBasedMigration{ + NewFixAuthorizationsMigration( + rwf, + authorizationFixes, + opts, + ), + }, + ), + }, + } +} + +type AuthorizationFixes map[AccountCapabilityID]struct{} + +// ReadAuthorizationFixes reads a report of authorization fixes from the given reader. +// The report is expected to be a JSON array of objects with the following structure: +// +// [ +// {"capability_address":"0x1","capability_id":1} +// ] +func ReadAuthorizationFixes( + reader io.Reader, + filter map[common.Address]struct{}, +) (AuthorizationFixes, error) { + + fixes := AuthorizationFixes{} + + dec := json.NewDecoder(reader) + + token, err := dec.Token() + if err != nil { + return nil, fmt.Errorf("failed to read token: %w", err) + } + if token != json.Delim('[') { + return nil, fmt.Errorf("expected start of array, got %s", token) + } + + for dec.More() { + var entry struct { + CapabilityAddress string `json:"capability_address"` + CapabilityID uint64 `json:"capability_id"` + } + err := dec.Decode(&entry) + if err != nil { + return nil, fmt.Errorf("failed to decode entry: %w", err) + } + + address, err := common.HexToAddress(entry.CapabilityAddress) + if err != nil { + return nil, fmt.Errorf("failed to parse address: %w", err) + } + + if filter != nil { + if _, ok := filter[address]; !ok { + continue + } + } + + accountCapabilityID := AccountCapabilityID{ + Address: address, + CapabilityID: entry.CapabilityID, + } + + fixes[accountCapabilityID] = struct{}{} + } + + token, err = dec.Token() + if err != nil { + return nil, fmt.Errorf("failed to read token: %w", err) + } + if token != json.Delim(']') { + return nil, fmt.Errorf("expected end of array, got %s", token) + } + + return fixes, nil +} diff --git a/cmd/util/ledger/migrations/fix_authorizations_migration_test.go b/cmd/util/ledger/migrations/fix_authorizations_migration_test.go new file mode 100644 index 00000000000..37c2c06fb7f --- /dev/null +++ b/cmd/util/ledger/migrations/fix_authorizations_migration_test.go @@ -0,0 +1,346 @@ +package migrations + +import ( + "fmt" + "strings" + "testing" + + "github.com/onflow/cadence" + jsoncdc "github.com/onflow/cadence/encoding/json" + "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/onflow/flow-go/cmd/util/ledger/util/registers" + "github.com/onflow/flow-go/model/flow" +) + +func newEntitlementSetAuthorizationFromTypeIDs( + typeIDs []common.TypeID, + setKind sema.EntitlementSetKind, +) interpreter.EntitlementSetAuthorization { + return interpreter.NewEntitlementSetAuthorization( + nil, + func() []common.TypeID { + return typeIDs + }, + len(typeIDs), + setKind, + ) +} + +func TestFixAuthorizationsMigration(t *testing.T) { + t.Parallel() + + const chainID = flow.Emulator + chain := chainID.Chain() + + const nWorker = 2 + + address, err := chain.AddressAtIndex(1000) + require.NoError(t, err) + + require.Equal(t, "bf519681cdb888b1", address.Hex()) + + log := zerolog.New(zerolog.NewTestWriter(t)) + + bootstrapPayloads, err := newBootstrapPayloads(chainID) + require.NoError(t, err) + + registersByAccount, err := registers.NewByAccountFromPayloads(bootstrapPayloads) + require.NoError(t, err) + + mr := NewBasicMigrationRuntime(registersByAccount) + err = mr.Accounts.Create(nil, address) + require.NoError(t, err) + + expectedWriteAddresses := map[flow.Address]struct{}{ + address: {}, + } + + err = mr.Commit(expectedWriteAddresses, log) + require.NoError(t, err) + + const contractCode = ` + access(all) contract Test { + + access(all) entitlement E + + access(all) struct S {} + } + ` + + deployTX := flow.NewTransactionBody(). + SetScript([]byte(` + transaction(code: String) { + prepare(signer: auth(Contracts) &Account) { + signer.contracts.add(name: "Test", code: code.utf8) + } + } + `)). + AddAuthorizer(address). + AddArgument(jsoncdc.MustEncode(cadence.String(contractCode))) + + runDeployTx := NewTransactionBasedMigration( + deployTX, + chainID, + log, + expectedWriteAddresses, + ) + err = runDeployTx(registersByAccount) + require.NoError(t, err) + + setupTx := flow.NewTransactionBody(). + SetScript([]byte(fmt.Sprintf(` + import Test from %s + + transaction { + prepare(signer: auth(Storage, Capabilities) &Account) { + signer.storage.save(Test.S(), to: /storage/s) + + // Capability 1 was a public, unauthorized capability, which is now authorized. + // It should lose its entitlement + let cap1 = signer.capabilities.storage.issue(/storage/s) + assert(cap1.borrow() != nil) + signer.capabilities.publish(cap1, at: /public/s1) + + // Capability 2 was a public, unauthorized capability, which is now authorized. + // It is currently only stored, nested, in storage, and is not published. + // It should lose its entitlement + let cap2 = signer.capabilities.storage.issue(/storage/s) + assert(cap2.borrow() != nil) + signer.storage.save([cap2], to: /storage/caps2) + + // Capability 3 was a private, authorized capability. + // It is currently only stored, nested, in storage, and is not published. + // It should keep its entitlement + let cap3 = signer.capabilities.storage.issue(/storage/s) + assert(cap3.borrow() != nil) + signer.storage.save([cap3], to: /storage/caps3) + + // Capability 4 was a private, authorized capability. + // It is currently both stored, nested, in storage, and is published. + // It should keep its entitlement + let cap4 = signer.capabilities.storage.issue(/storage/s) + assert(cap4.borrow() != nil) + signer.storage.save([cap4], to: /storage/caps4) + signer.capabilities.publish(cap4, at: /public/s4) + + // Capability 5 was a public, unauthorized capability, which is still unauthorized. + // It is currently both stored, nested, in storage, and is published. + // There is no need to fix it. + let cap5 = signer.capabilities.storage.issue<&Test.S>(/storage/s) + assert(cap5.borrow() != nil) + signer.storage.save([cap5], to: /storage/caps5) + signer.capabilities.publish(cap5, at: /public/s5) + } + } + `, + address.HexWithPrefix(), + ))). + AddAuthorizer(address) + + runSetupTx := NewTransactionBasedMigration( + setupTx, + chainID, + log, + expectedWriteAddresses, + ) + + err = runSetupTx(registersByAccount) + require.NoError(t, err) + + rwf := &testReportWriterFactory{} + + options := Options{ + ChainID: chainID, + NWorker: nWorker, + } + + fixes := AuthorizationFixes{ + AccountCapabilityID{ + Address: common.Address(address), + CapabilityID: 1, + }: {}, + AccountCapabilityID{ + Address: common.Address(address), + CapabilityID: 2, + }: {}, + } + + migrations := NewFixAuthorizationsMigrations( + log, + rwf, + fixes, + options, + ) + + for _, namedMigration := range migrations { + err = namedMigration.Migrate(registersByAccount) + require.NoError(t, err) + } + + reporter := rwf.reportWriters[fixAuthorizationsMigrationReporterName] + require.NotNil(t, reporter) + + var entries []any + + for _, entry := range reporter.entries { + switch entry := entry.(type) { + case capabilityAuthorizationFixedEntry, + capabilityControllerAuthorizationFixedEntry: + + entries = append(entries, entry) + } + } + + require.ElementsMatch(t, + []any{ + capabilityControllerAuthorizationFixedEntry{ + StorageKey: interpreter.StorageKey{ + Key: "cap_con", + Address: common.Address(address), + }, + CapabilityID: 1, + }, + capabilityControllerAuthorizationFixedEntry{ + StorageKey: interpreter.StorageKey{ + Key: "cap_con", + Address: common.Address(address), + }, + CapabilityID: 2, + }, + capabilityAuthorizationFixedEntry{ + StorageKey: interpreter.StorageKey{ + Key: "public", + Address: common.Address(address), + }, + CapabilityAddress: common.Address(address), + CapabilityID: 1, + }, + capabilityAuthorizationFixedEntry{ + StorageKey: interpreter.StorageKey{ + Key: "storage", + Address: common.Address(address), + }, + CapabilityAddress: common.Address(address), + CapabilityID: 2, + }, + }, + entries, + ) + + // Check account + + _, err = runScript( + chainID, + registersByAccount, + fmt.Sprintf( + //language=Cadence + ` + import Test from %s + + access(all) + fun main() { + let account = getAuthAccount(%[1]s) + // NOTE: capability can NOT be borrowed with E anymore + assert(account.capabilities.borrow(/public/s1) == nil) + assert(account.capabilities.borrow<&Test.S>(/public/s1) != nil) + + let caps2 = account.storage.copy<[Capability]>(from: /storage/caps2)! + // NOTE: capability can NOT be borrowed with E anymore + assert(caps2[0].borrow() == nil) + assert(caps2[0].borrow<&Test.S>() != nil) + + let caps3 = account.storage.copy<[Capability]>(from: /storage/caps3)! + // NOTE: capability can still be borrowed with E + assert(caps3[0].borrow() != nil) + assert(caps3[0].borrow<&Test.S>() != nil) + + let caps4 = account.storage.copy<[Capability]>(from: /storage/caps4)! + // NOTE: capability can still be borrowed with E + assert(account.capabilities.borrow(/public/s4) != nil) + assert(account.capabilities.borrow<&Test.S>(/public/s4) != nil) + assert(caps4[0].borrow() != nil) + assert(caps4[0].borrow<&Test.S>() != nil) + } + `, + address.HexWithPrefix(), + ), + ) + require.NoError(t, err) +} + +func TestReadAuthorizationFixes(t *testing.T) { + t.Parallel() + + validContents := ` + [ + {"capability_address":"01","capability_id":4}, + {"capability_address":"02","capability_id":5}, + {"capability_address":"03","capability_id":6} + ] + ` + + t.Run("unfiltered", func(t *testing.T) { + + t.Parallel() + + reader := strings.NewReader(validContents) + + mapping, err := ReadAuthorizationFixes(reader, nil) + require.NoError(t, err) + + require.Equal(t, + AuthorizationFixes{ + { + Address: common.MustBytesToAddress([]byte{0x1}), + CapabilityID: 4, + }: {}, + { + Address: common.MustBytesToAddress([]byte{0x2}), + CapabilityID: 5, + }: {}, + { + Address: common.MustBytesToAddress([]byte{0x3}), + CapabilityID: 6, + }: {}, + }, + mapping, + ) + }) + + t.Run("filtered", func(t *testing.T) { + + t.Parallel() + + address1 := common.MustBytesToAddress([]byte{0x1}) + address3 := common.MustBytesToAddress([]byte{0x3}) + + addressFilter := map[common.Address]struct{}{ + address1: {}, + address3: {}, + } + + reader := strings.NewReader(validContents) + + mapping, err := ReadAuthorizationFixes(reader, addressFilter) + require.NoError(t, err) + + require.Equal(t, + AuthorizationFixes{ + { + Address: common.MustBytesToAddress([]byte{0x1}), + CapabilityID: 4, + }: {}, + { + Address: common.MustBytesToAddress([]byte{0x3}), + CapabilityID: 6, + }: {}, + }, + mapping, + ) + }) +} diff --git a/cmd/util/ledger/migrations/type_requirements_extractor.go b/cmd/util/ledger/migrations/type_requirements_extractor.go index 9fcfe70396d..67de5cb52c9 100644 --- a/cmd/util/ledger/migrations/type_requirements_extractor.go +++ b/cmd/util/ledger/migrations/type_requirements_extractor.go @@ -37,7 +37,7 @@ func NewTypeRequirementsExtractingMigration( // Gather all contracts - contracts, err := gatherContractsFromRegisters(registersByAccount, log) + contracts, err := GatherContractsFromRegisters(registersByAccount, log) if err != nil { return err }