diff --git a/cmd/util/cmd/verify_execution_result/cmd.go b/cmd/util/cmd/verify_execution_result/cmd.go index a5d26d75281..9263773aa5b 100644 --- a/cmd/util/cmd/verify_execution_result/cmd.go +++ b/cmd/util/cmd/verify_execution_result/cmd.go @@ -18,6 +18,7 @@ var ( flagChunkDataPackDir string flagChain string flagFromTo string + flagWorkerCount uint // number of workers to verify the blocks concurrently ) // # verify the last 100 sealed blocks @@ -47,12 +48,20 @@ func init() { Cmd.Flags().StringVar(&flagFromTo, "from_to", "", "the height range to verify blocks (inclusive), i.e, 1-1000, 1000-2000, 2000-3000, etc.") + + Cmd.Flags().UintVar(&flagWorkerCount, "worker_count", 1, + "number of workers to use for verification, default is 1") + } func run(*cobra.Command, []string) { chainID := flow.ChainID(flagChain) _ = chainID.Chain() + if flagWorkerCount < 1 { + log.Fatal().Msgf("worker count must be at least 1, but got %v", flagWorkerCount) + } + lg := log.With(). Str("chain", string(chainID)). Str("datadir", flagDatadir). @@ -66,7 +75,7 @@ func run(*cobra.Command, []string) { } lg.Info().Msgf("verifying range from %d to %d", from, to) - err = verifier.VerifyRange(from, to, chainID, flagDatadir, flagChunkDataPackDir) + err = verifier.VerifyRange(from, to, chainID, flagDatadir, flagChunkDataPackDir, flagWorkerCount) if err != nil { lg.Fatal().Err(err).Msgf("could not verify range from %d to %d", from, to) } @@ -74,7 +83,7 @@ func run(*cobra.Command, []string) { } else { lg.Info().Msgf("verifying last %d sealed blocks", flagLastK) - err := verifier.VerifyLastKHeight(flagLastK, chainID, flagDatadir, flagChunkDataPackDir) + err := verifier.VerifyLastKHeight(flagLastK, chainID, flagDatadir, flagChunkDataPackDir, flagWorkerCount) if err != nil { lg.Fatal().Err(err).Msg("could not verify last k height") } diff --git a/engine/verification/verifier/verifiers.go b/engine/verification/verifier/verifiers.go index f92a25ad97e..7ccef14982f 100644 --- a/engine/verification/verifier/verifiers.go +++ b/engine/verification/verifier/verifiers.go @@ -1,8 +1,10 @@ package verifier import ( + "context" "errors" "fmt" + "sync" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -25,7 +27,7 @@ import ( // It assumes the latest sealed block has been executed, and the chunk data packs have not been // pruned. // Note, it returns nil if certain block is not executed, in this case warning will be logged -func VerifyLastKHeight(k uint64, chainID flow.ChainID, protocolDataDir string, chunkDataPackDir string) (err error) { +func VerifyLastKHeight(k uint64, chainID flow.ChainID, protocolDataDir string, chunkDataPackDir string, nWorker uint) (err error) { closer, storages, chunkDataPacks, state, verifier, err := initStorages(chainID, protocolDataDir, chunkDataPackDir) if err != nil { return fmt.Errorf("could not init storages: %w", err) @@ -62,12 +64,9 @@ func VerifyLastKHeight(k uint64, chainID flow.ChainID, protocolDataDir string, c log.Info().Msgf("verifying blocks from %d to %d", from, to) - for height := from; height <= to; height++ { - log.Info().Uint64("height", height).Msg("verifying height") - err := verifyHeight(height, storages.Headers, chunkDataPacks, storages.Results, state, verifier) - if err != nil { - return fmt.Errorf("could not verify height %d: %w", height, err) - } + err = verifyConcurrently(from, to, nWorker, storages.Headers, chunkDataPacks, storages.Results, state, verifier, verifyHeight) + if err != nil { + return err } return nil @@ -79,6 +78,7 @@ func VerifyRange( from, to uint64, chainID flow.ChainID, protocolDataDir string, chunkDataPackDir string, + nWorker uint, ) (err error) { closer, storages, chunkDataPacks, state, verifier, err := initStorages(chainID, protocolDataDir, chunkDataPackDir) if err != nil { @@ -99,12 +99,94 @@ func VerifyRange( return fmt.Errorf("cannot verify blocks before the root block, from: %d, root: %d", from, root) } - for height := from; height <= to; height++ { - log.Info().Uint64("height", height).Msg("verifying height") - err := verifyHeight(height, storages.Headers, chunkDataPacks, storages.Results, state, verifier) - if err != nil { - return fmt.Errorf("could not verify height %d: %w", height, err) + err = verifyConcurrently(from, to, nWorker, storages.Headers, chunkDataPacks, storages.Results, state, verifier, verifyHeight) + if err != nil { + return err + } + + return nil +} + +func verifyConcurrently( + from, to uint64, + nWorker uint, + headers storage.Headers, + chunkDataPacks storage.ChunkDataPacks, + results storage.ExecutionResults, + state protocol.State, + verifier module.ChunkVerifier, + verifyHeight func(uint64, storage.Headers, storage.ChunkDataPacks, storage.ExecutionResults, protocol.State, module.ChunkVerifier) error, +) error { + tasks := make(chan uint64, int(nWorker)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // Ensure cancel is called to release resources + + var lowestErr error + var lowestErrHeight uint64 = ^uint64(0) // Initialize to max value of uint64 + var mu sync.Mutex // To protect access to lowestErr and lowestErrHeight + + // Worker function + worker := func() { + for { + select { + case <-ctx.Done(): + return // Stop processing tasks if context is canceled + case height, ok := <-tasks: + if !ok { + return // Exit if the tasks channel is closed + } + log.Info().Uint64("height", height).Msg("verifying height") + err := verifyHeight(height, headers, chunkDataPacks, results, state, verifier) + if err != nil { + log.Error().Uint64("height", height).Err(err).Msg("error encountered while verifying height") + + // when encountered an error, the error might not be from the lowest height that had + // error, so we need to first cancel the context to stop worker from processing further tasks + // and wait until all workers are done, which will ensure all the heights before this height + // that had error are processed. Then we can safely update the lowestErr and lowestErrHeight + mu.Lock() + if height < lowestErrHeight { + lowestErr = err + lowestErrHeight = height + cancel() // Cancel context to stop further task dispatch + } + mu.Unlock() + } else { + log.Info().Uint64("height", height).Msg("verified height successfully") + } + } + } + } + + // Start nWorker workers + var wg sync.WaitGroup + for i := 0; i < int(nWorker); i++ { + wg.Add(1) + go func() { + defer wg.Done() + worker() + }() + } + + // Send tasks to workers + go func() { + defer close(tasks) // Close tasks channel once all tasks are pushed + for height := from; height <= to; height++ { + select { + case <-ctx.Done(): + return // Stop pushing tasks if context is canceled + case tasks <- height: + } } + }() + + // Wait for all workers to complete + wg.Wait() + + // Check if there was an error + if lowestErr != nil { + log.Error().Uint64("height", lowestErrHeight).Err(lowestErr).Msg("error encountered while verifying height") + return fmt.Errorf("could not verify height %d: %w", lowestErrHeight, lowestErr) } return nil diff --git a/engine/verification/verifier/verifiers_test.go b/engine/verification/verifier/verifiers_test.go new file mode 100644 index 00000000000..c16486bb098 --- /dev/null +++ b/engine/verification/verifier/verifiers_test.go @@ -0,0 +1,85 @@ +package verifier + +import ( + "errors" + "fmt" + "testing" + + "github.com/onflow/flow-go/module" + mockmodule "github.com/onflow/flow-go/module/mock" + "github.com/onflow/flow-go/state/protocol" + "github.com/onflow/flow-go/storage" + "github.com/onflow/flow-go/storage/mock" + unittestMocks "github.com/onflow/flow-go/utils/unittest/mocks" +) + +func TestVerifyConcurrently(t *testing.T) { + + tests := []struct { + name string + from uint64 + to uint64 + nWorker uint + errors map[uint64]error // Map of heights to errors + expectedErr error + }{ + { + name: "All heights verified successfully", + from: 1, + to: 5, + nWorker: 3, + errors: nil, + expectedErr: nil, + }, + { + name: "Single error at a height", + from: 1, + to: 5, + nWorker: 3, + errors: map[uint64]error{3: errors.New("mock error")}, + expectedErr: fmt.Errorf("mock error"), + }, + { + name: "Multiple errors, lowest height returned", + from: 1, + to: 5, + nWorker: 3, + errors: map[uint64]error{2: errors.New("error 2"), 4: errors.New("error 4")}, + expectedErr: fmt.Errorf("error 2"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset mockVerifyHeight for each test + mockVerifyHeight := func( + height uint64, + headers storage.Headers, + chunkDataPacks storage.ChunkDataPacks, + results storage.ExecutionResults, + state protocol.State, + verifier module.ChunkVerifier, + ) error { + if err, ok := tt.errors[height]; ok { + return err + } + return nil + } + + mockHeaders := mock.NewHeaders(t) + mockChunkDataPacks := mock.NewChunkDataPacks(t) + mockResults := mock.NewExecutionResults(t) + mockState := unittestMocks.NewProtocolState() + mockVerifier := mockmodule.NewChunkVerifier(t) + + err := verifyConcurrently(tt.from, tt.to, tt.nWorker, mockHeaders, mockChunkDataPacks, mockResults, mockState, mockVerifier, mockVerifyHeight) + if tt.expectedErr != nil { + if err == nil || errors.Is(err, tt.expectedErr) { + t.Fatalf("expected error: %v, got: %v", tt.expectedErr, err) + } + } else if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + }) + } +}