diff --git a/cmd/util/cmd/verify_execution_result/cmd.go b/cmd/util/cmd/verify_execution_result/cmd.go
index a5d26d75281..9263773aa5b 100644
--- a/cmd/util/cmd/verify_execution_result/cmd.go
+++ b/cmd/util/cmd/verify_execution_result/cmd.go
@@ -18,6 +18,7 @@ var (
 	flagChunkDataPackDir string
 	flagChain            string
 	flagFromTo           string
+	flagWorkerCount      uint // number of workers to verify the blocks concurrently
 )
 
 // # verify the last 100 sealed blocks
@@ -47,12 +48,20 @@ func init() {
 
 	Cmd.Flags().StringVar(&flagFromTo, "from_to", "",
 		"the height range to verify blocks (inclusive), i.e, 1-1000, 1000-2000, 2000-3000, etc.")
+
+	Cmd.Flags().UintVar(&flagWorkerCount, "worker_count", 1,
+		"number of workers to use for verification, default is 1")
+
 }
 
 func run(*cobra.Command, []string) {
 	chainID := flow.ChainID(flagChain)
 	_ = chainID.Chain()
 
+	if flagWorkerCount < 1 {
+		log.Fatal().Msgf("worker count must be at least 1, but got %v", flagWorkerCount)
+	}
+
 	lg := log.With().
 		Str("chain", string(chainID)).
 		Str("datadir", flagDatadir).
@@ -66,7 +75,7 @@ func run(*cobra.Command, []string) {
 		}
 
 		lg.Info().Msgf("verifying range from %d to %d", from, to)
-		err = verifier.VerifyRange(from, to, chainID, flagDatadir, flagChunkDataPackDir)
+		err = verifier.VerifyRange(from, to, chainID, flagDatadir, flagChunkDataPackDir, flagWorkerCount)
 		if err != nil {
 			lg.Fatal().Err(err).Msgf("could not verify range from %d to %d", from, to)
 		}
@@ -74,7 +83,7 @@ func run(*cobra.Command, []string) {
 	} else {
 		lg.Info().Msgf("verifying last %d sealed blocks", flagLastK)
 
-		err := verifier.VerifyLastKHeight(flagLastK, chainID, flagDatadir, flagChunkDataPackDir)
+		err := verifier.VerifyLastKHeight(flagLastK, chainID, flagDatadir, flagChunkDataPackDir, flagWorkerCount)
 		if err != nil {
 			lg.Fatal().Err(err).Msg("could not verify last k height")
 		}
diff --git a/engine/verification/verifier/verifiers.go b/engine/verification/verifier/verifiers.go
index f92a25ad97e..7ccef14982f 100644
--- a/engine/verification/verifier/verifiers.go
+++ b/engine/verification/verifier/verifiers.go
@@ -1,8 +1,10 @@
 package verifier
 
 import (
+	"context"
 	"errors"
 	"fmt"
+	"sync"
 
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
@@ -25,7 +27,7 @@ import (
 // It assumes the latest sealed block has been executed, and the chunk data packs have not been
 // pruned.
 // Note, it returns nil if certain block is not executed, in this case warning will be logged
-func VerifyLastKHeight(k uint64, chainID flow.ChainID, protocolDataDir string, chunkDataPackDir string) (err error) {
+func VerifyLastKHeight(k uint64, chainID flow.ChainID, protocolDataDir string, chunkDataPackDir string, nWorker uint) (err error) {
 	closer, storages, chunkDataPacks, state, verifier, err := initStorages(chainID, protocolDataDir, chunkDataPackDir)
 	if err != nil {
 		return fmt.Errorf("could not init storages: %w", err)
 	}
@@ -62,12 +64,9 @@ func VerifyLastKHeight(k uint64, chainID flow.ChainID, protocolDataDir string, c
 
 	log.Info().Msgf("verifying blocks from %d to %d", from, to)
 
-	for height := from; height <= to; height++ {
-		log.Info().Uint64("height", height).Msg("verifying height")
-		err := verifyHeight(height, storages.Headers, chunkDataPacks, storages.Results, state, verifier)
-		if err != nil {
-			return fmt.Errorf("could not verify height %d: %w", height, err)
-		}
+	err = verifyConcurrently(from, to, nWorker, storages.Headers, chunkDataPacks, storages.Results, state, verifier, verifyHeight)
+	if err != nil {
+		return err
 	}
 
 	return nil
@@ -79,6 +78,7 @@ func VerifyRange(
 	from, to uint64,
 	chainID flow.ChainID,
 	protocolDataDir string, chunkDataPackDir string,
+	nWorker uint,
 ) (err error) {
 	closer, storages, chunkDataPacks, state, verifier, err := initStorages(chainID, protocolDataDir, chunkDataPackDir)
 	if err != nil {
@@ -99,12 +99,94 @@ func VerifyRange(
 		return fmt.Errorf("cannot verify blocks before the root block, from: %d, root: %d", from, root)
 	}
 
-	for height := from; height <= to; height++ {
-		log.Info().Uint64("height", height).Msg("verifying height")
-		err := verifyHeight(height, storages.Headers, chunkDataPacks, storages.Results, state, verifier)
-		if err != nil {
-			return fmt.Errorf("could not verify height %d: %w", height, err)
+	err = verifyConcurrently(from, to, nWorker, storages.Headers, chunkDataPacks, storages.Results, state, verifier, verifyHeight)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func verifyConcurrently(
+	from, to uint64,
+	nWorker uint,
+	headers storage.Headers,
+	chunkDataPacks storage.ChunkDataPacks,
+	results storage.ExecutionResults,
+	state protocol.State,
+	verifier module.ChunkVerifier,
+	verifyHeight func(uint64, storage.Headers, storage.ChunkDataPacks, storage.ExecutionResults, protocol.State, module.ChunkVerifier) error,
+) error {
+	tasks := make(chan uint64, int(nWorker))
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel() // Ensure cancel is called to release resources
+
+	var lowestErr error
+	var lowestErrHeight uint64 = ^uint64(0) // Initialize to max value of uint64
+	var mu sync.Mutex                       // To protect access to lowestErr and lowestErrHeight
+
+	// Worker function
+	worker := func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return // Stop processing tasks if context is canceled
+			case height, ok := <-tasks:
+				if !ok {
+					return // Exit if the tasks channel is closed
+				}
+				log.Info().Uint64("height", height).Msg("verifying height")
+				err := verifyHeight(height, headers, chunkDataPacks, results, state, verifier)
+				if err != nil {
+					log.Error().Uint64("height", height).Err(err).Msg("error encountered while verifying height")
+
+					// the error might not come from the lowest failing height, since workers
+					// process heights out of order. Cancel the context to stop dispatching new
+					// heights; any lower height that fails has already been dispatched, so its
+					// worker will still update lowestErr below before all workers finish.
+					mu.Lock()
+					if height < lowestErrHeight {
+						lowestErr = err
+						lowestErrHeight = height
+						cancel() // Cancel context to stop further task dispatch
+					}
+					mu.Unlock()
+				} else {
+					log.Info().Uint64("height", height).Msg("verified height successfully")
+				}
+			}
+		}
+	}
+
+	// Start nWorker workers
+	var wg sync.WaitGroup
+	for i := 0; i < int(nWorker); i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			worker()
+		}()
+	}
+
+	// Send tasks to workers
+	go func() {
+		defer close(tasks) // Close tasks channel once all tasks are pushed
+		for height := from; height <= to; height++ {
+			select {
+			case <-ctx.Done():
+				return // Stop pushing tasks if context is canceled
+			case tasks <- height:
+			}
 		}
+	}()
+
+	// Wait for all workers to complete
+	wg.Wait()
+
+	// Check if there was an error
+	if lowestErr != nil {
+		log.Error().Uint64("height", lowestErrHeight).Err(lowestErr).Msg("error encountered while verifying height")
+		return fmt.Errorf("could not verify height %d: %w", lowestErrHeight, lowestErr)
 	}
 
 	return nil
diff --git a/engine/verification/verifier/verifiers_test.go b/engine/verification/verifier/verifiers_test.go
new file mode 100644
index 00000000000..c16486bb098
--- /dev/null
+++ b/engine/verification/verifier/verifiers_test.go
@@ -0,0 +1,86 @@
+package verifier
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/onflow/flow-go/module"
+	mockmodule "github.com/onflow/flow-go/module/mock"
+	"github.com/onflow/flow-go/state/protocol"
+	"github.com/onflow/flow-go/storage"
+	"github.com/onflow/flow-go/storage/mock"
+	unittestMocks "github.com/onflow/flow-go/utils/unittest/mocks"
+)
+
+func TestVerifyConcurrently(t *testing.T) {
+
+	tests := []struct {
+		name        string
+		from        uint64
+		to          uint64
+		nWorker     uint
+		errors      map[uint64]error // Map of heights to errors
+		expectedErr error
+	}{
+		{
+			name:        "All heights verified successfully",
+			from:        1,
+			to:          5,
+			nWorker:     3,
+			errors:      nil,
+			expectedErr: nil,
+		},
+		{
+			name:        "Single error at a height",
+			from:        1,
+			to:          5,
+			nWorker:     3,
+			errors:      map[uint64]error{3: errors.New("mock error")},
+			expectedErr: fmt.Errorf("mock error"),
+		},
+		{
+			name:        "Multiple errors, lowest height returned",
+			from:        1,
+			to:          5,
+			nWorker:     3,
+			errors:      map[uint64]error{2: errors.New("error 2"), 4: errors.New("error 4")},
+			expectedErr: fmt.Errorf("error 2"),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Reset mockVerifyHeight for each test
+			mockVerifyHeight := func(
+				height uint64,
+				headers storage.Headers,
+				chunkDataPacks storage.ChunkDataPacks,
+				results storage.ExecutionResults,
+				state protocol.State,
+				verifier module.ChunkVerifier,
+			) error {
+				if err, ok := tt.errors[height]; ok {
+					return err
+				}
+				return nil
+			}
+
+			mockHeaders := mock.NewHeaders(t)
+			mockChunkDataPacks := mock.NewChunkDataPacks(t)
+			mockResults := mock.NewExecutionResults(t)
+			mockState := unittestMocks.NewProtocolState()
+			mockVerifier := mockmodule.NewChunkVerifier(t)
+
+			err := verifyConcurrently(tt.from, tt.to, tt.nWorker, mockHeaders, mockChunkDataPacks, mockResults, mockState, mockVerifier, mockVerifyHeight)
+			if tt.expectedErr != nil {
+				if err == nil || !strings.Contains(err.Error(), tt.expectedErr.Error()) {
+					t.Fatalf("expected error: %v, got: %v", tt.expectedErr, err)
+				}
+			} else if err != nil {
+				t.Fatalf("expected no error, got: %v", err)
+			}
+		})
+	}
+}
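The concurrency pattern added in verifyConcurrently above can be read in isolation. The following standalone sketch is illustrative only and not part of the diff: it uses a hypothetical runConcurrently helper with a stand-in process function in place of verifyHeight and the storage dependencies, but follows the same structure — a bounded task channel, nWorker consumers, cancellation of dispatch on the first failure, and reporting the error from the lowest failing height that was processed.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// runConcurrently is a hypothetical, simplified version of the pattern used by
// verifyConcurrently: nWorker goroutines drain a height channel, the first
// failure stops further dispatch, and the lowest failing height wins.
func runConcurrently(from, to uint64, nWorker uint, process func(uint64) error) error {
	tasks := make(chan uint64, nWorker)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var (
		mu              sync.Mutex
		lowestErr       error
		lowestErrHeight = ^uint64(0) // max uint64
	)

	var wg sync.WaitGroup
	for i := uint(0); i < nWorker; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for height := range tasks {
				if err := process(height); err != nil {
					mu.Lock()
					if height < lowestErrHeight {
						lowestErr, lowestErrHeight = err, height
					}
					mu.Unlock()
					cancel() // stop dispatching further heights
				}
			}
		}()
	}

	go func() {
		defer close(tasks)
		for height := from; height <= to; height++ {
			select {
			case <-ctx.Done():
				return
			case tasks <- height:
			}
		}
	}()

	wg.Wait()

	if lowestErr != nil {
		return fmt.Errorf("could not process height %d: %w", lowestErrHeight, lowestErr)
	}
	return nil
}

func main() {
	err := runConcurrently(1, 10, 3, func(height uint64) error {
		if height == 4 {
			return fmt.Errorf("boom at %d", height)
		}
		return nil
	})
	fmt.Println(err) // wraps the error from the lowest failing height that was dispatched (4 here)
}
```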
diff --git a/module/metrics/server.go b/module/metrics/server.go
index cd8187b1fbd..09221cb35ce 100644
--- a/module/metrics/server.go
+++ b/module/metrics/server.go
@@ -3,18 +3,29 @@ package metrics
 import (
 	"context"
 	"errors"
+	"net"
 	"net/http"
 	"strconv"
 	"time"
 
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"github.com/rs/zerolog"
+	"github.com/rs/zerolog/log"
+
+	"github.com/onflow/flow-go/module/component"
+	"github.com/onflow/flow-go/module/irrecoverable"
 )
 
+// metricsServerShutdownTimeout is the time to wait for the server to shut down gracefully
+const metricsServerShutdownTimeout = 5 * time.Second
+
 // Server is the http server that will be serving the /metrics request for prometheus
 type Server struct {
-	server *http.Server
-	log    zerolog.Logger
+	component.Component
+
+	address string
+	server  *http.Server
+	log     zerolog.Logger
 }
 
 // NewServer creates a new server that will start on the specified port,
@@ -25,44 +36,71 @@ func NewServer(log zerolog.Logger, port uint) *Server {
 	mux := http.NewServeMux()
 	endpoint := "/metrics"
 	mux.Handle(endpoint, promhttp.Handler())
-	log.Info().Str("address", addr).Str("endpoint", endpoint).Msg("metrics server started")
 
 	m := &Server{
-		server: &http.Server{Addr: addr, Handler: mux},
-		log:    log,
+		address: addr,
+		server:  &http.Server{Addr: addr, Handler: mux},
+		log:     log.With().Str("address", addr).Str("endpoint", endpoint).Logger(),
 	}
 
+	m.Component = component.NewComponentManagerBuilder().
+		AddWorker(m.serve).
+		AddWorker(m.shutdownOnContextDone).
+		Build()
+
 	return m
 }
 
-// Ready returns a channel that will close when the network stack is ready.
-func (m *Server) Ready() <-chan struct{} {
-	ready := make(chan struct{})
-	go func() {
-		if err := m.server.ListenAndServe(); err != nil {
-			// http.ErrServerClosed is returned when Close or Shutdown is called
-			// we don't consider this an error, so print this with debug level instead
-			if errors.Is(err, http.ErrServerClosed) {
-				m.log.Debug().Err(err).Msg("metrics server shutdown")
-			} else {
-				m.log.Err(err).Msg("error shutting down metrics server")
-			}
+func (m *Server) serve(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
+	m.log.Info().Msg("starting metrics server on address")
+
+	l, err := net.Listen("tcp", m.address)
+	if err != nil {
+		m.log.Err(err).Msg("failed to start the metrics server")
+		ctx.Throw(err)
+		return
+	}
+
+	ready()
+
+	// pass the signaler context to the server so that the signaler context
+	// can control the server's lifetime
+	m.server.BaseContext = func(_ net.Listener) context.Context {
+		return ctx
+	}
+
+	err = m.server.Serve(l) // blocking call
+	if err != nil {
+		if errors.Is(err, http.ErrServerClosed) {
+			return
 		}
-	}()
-	go func() {
-		close(ready)
-	}()
-	return ready
+
+		log.Err(err).Msg("fatal error in the metrics server")
+		ctx.Throw(err)
+	}
 }
 
-// Done returns a channel that will close when shutdown is complete.
-func (m *Server) Done() <-chan struct{} {
-	done := make(chan struct{})
-	go func() {
-		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-		_ = m.server.Shutdown(ctx)
-		cancel()
-		close(done)
-	}()
-	return done
+func (m *Server) shutdownOnContextDone(ictx irrecoverable.SignalerContext, ready component.ReadyFunc) {
+	ready()
+	<-ictx.Done()
+
+	ctx, cancel := context.WithTimeout(context.Background(), metricsServerShutdownTimeout)
+	defer cancel()
+
+	// shutdown the server gracefully
+	err := m.server.Shutdown(ctx)
+	if err == nil {
+		m.log.Info().Msg("metrics server graceful shutdown completed")
+		return
+	}
+
+	if errors.Is(err, ctx.Err()) {
+		m.log.Warn().Msg("metrics server graceful shutdown timed out")
+		// shutdown the server forcefully
+		err := m.server.Close()
+		if err != nil {
+			m.log.Err(err).Msg("error closing metrics server")
+		}
+	} else {
+		m.log.Err(err).Msg("error shutting down metrics server")
+	}
 }
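For orientation, here is a minimal sketch (not part of the diff) of how a caller might drive the now component-based metrics server. It assumes the standard flow-go component lifecycle (Start with a SignalerContext, then wait on Ready/Done) and a hypothetical main wiring; port, logger setup, and timings are placeholders.

```go
package main

import (
	"context"
	"os"
	"time"

	"github.com/rs/zerolog"

	"github.com/onflow/flow-go/module/irrecoverable"
	"github.com/onflow/flow-go/module/metrics"
)

func main() {
	logger := zerolog.New(os.Stderr).With().Timestamp().Logger()

	// the signaler context carries irrecoverable errors thrown by workers
	// (e.g. a failed net.Listen in serve) back to the caller
	ctx, cancel := context.WithCancel(context.Background())
	signalerCtx, errCh := irrecoverable.WithSignaler(ctx)

	server := metrics.NewServer(logger, 8080)
	server.Start(signalerCtx)

	go func() {
		// surfaces errors reported via ctx.Throw
		if err := <-errCh; err != nil {
			logger.Fatal().Err(err).Msg("metrics server failed")
		}
	}()

	<-server.Ready() // the listener is bound once Ready closes

	// ... run the application ...
	time.Sleep(time.Minute)

	// cancelling the parent context triggers shutdownOnContextDone, which
	// attempts a graceful Shutdown and falls back to Close on timeout
	cancel()
	<-server.Done()
}
```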