From 372bfb09a31347417fdda7c5320f8a37b9daf8e2 Mon Sep 17 00:00:00 2001 From: Kuisong Tong Date: Sun, 5 May 2024 07:54:46 -0700 Subject: [PATCH] use context.AfterFunc (#38) --- grpc/server.go | 26 +++++++++++---------- http/server.go | 20 ++++++++-------- pprof.go | 51 ++++++++++++++++++---------------------- run/run.go | 44 ++--------------------------------- run/run_test.go | 62 ------------------------------------------------- run/runner.go | 10 ++++---- 6 files changed, 54 insertions(+), 159 deletions(-) delete mode 100644 run/run_test.go diff --git a/grpc/server.go b/grpc/server.go index d4e4cf8..5ea411c 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -86,8 +86,21 @@ func Run(server *grpc.Server, opts ...Option) func(context.Context) error { //no ctx, cancel := context.WithCancelCause(ctx) defer cancel(nil) + defer context.AfterFunc(ctx, func() { + slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown gRPC Server...") + if healthServer != nil { + // Shutdown health server so client knows it's not serving. + slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown gRPC Health service...") + healthServer.Shutdown() + slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown gRPC Health service completed.") + } + server.GracefulStop() + slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown gRPC Server completed.") + })() + + slog.LogAttrs(ctx, slog.LevelInfo, "Starting gRPC Server...") var waitGroup sync.WaitGroup - waitGroup.Add(len(option.addresses) + 1) + waitGroup.Add(len(option.addresses)) for _, addr := range option.addresses { addr := addr go func() { @@ -111,17 +124,6 @@ func Run(server *grpc.Server, opts ...Option) func(context.Context) error { //no } }() } - go func() { - defer waitGroup.Done() - - <-ctx.Done() - if healthServer != nil { - // Shutdown health server so client knows it's not serving. - healthServer.Shutdown() - } - server.GracefulStop() - slog.LogAttrs(ctx, slog.LevelInfo, "gRPC Server is stopped.") - }() waitGroup.Wait() if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) { diff --git a/http/server.go b/http/server.go index bef4ca4..b82a04f 100644 --- a/http/server.go +++ b/http/server.go @@ -82,8 +82,17 @@ func Run(server *http.Server, opts ...Option) func(context.Context) error { //no ctx, cancel := context.WithCancelCause(ctx) defer cancel(nil) + defer context.AfterFunc(ctx, func() { + slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown HTTP Server...") + if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil { + cancel(fmt.Errorf("shutdown HTTP Server: %w", err)) + } + slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown HTTP Server completed.") + })() + + slog.LogAttrs(ctx, slog.LevelInfo, "Starting HTTP Server...") var waitGroup sync.WaitGroup - waitGroup.Add(len(option.addresses) + 1) + waitGroup.Add(len(option.addresses)) if slices.ContainsFunc(option.addresses, func(addr socket) bool { return addr.network == unix }) { if transport, ok := http.DefaultTransport.(*http.Transport); ok { internal.RegisterUnixProtocol(transport) @@ -112,15 +121,6 @@ func Run(server *http.Server, opts ...Option) func(context.Context) error { //no } }() } - go func() { - defer waitGroup.Done() - - <-ctx.Done() - if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil { - cancel(fmt.Errorf("shutdown HTTP Server: %w", err)) - } - slog.LogAttrs(ctx, slog.LevelInfo, "HTTP Server is stopped.") - }() waitGroup.Wait() if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) { diff --git a/pprof.go b/pprof.go index b5651da..dc42128 100644 --- a/pprof.go +++ b/pprof.go @@ -13,8 +13,6 @@ import ( "net/http/pprof" "runtime" "time" - - "github.com/nil-go/nilgo/run" ) // PProf starts a pprof server at localhost:6060. @@ -27,38 +25,35 @@ func PProf(ctx context.Context) error { mux.HandleFunc("/debug/pprof/profile", pprof.Profile) mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - server := &http.Server{ Handler: mux, ReadTimeout: time.Second, } - return run.WithCloser( - func(ctx context.Context) error { - listener, err := net.Listen("tcp", "localhost:6060") - if err != nil { - listener, err = net.Listen("tcp", "localhost:0") - if err != nil { - slog.LogAttrs(ctx, slog.LevelWarn, "Fail to find port for pprof server.", slog.Any("error", err)) - - return nil - } - } - slog.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf("Start pprof server at http://%s/debug/pprof/.", listener.Addr())) - - runtime.SetBlockProfileRate(1) // Required by gRPC server. - if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { - slog.LogAttrs(ctx, slog.LevelWarn, "Fail to start pprof server.", slog.Any("error", err)) - } + defer context.AfterFunc(ctx, func() { + slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown pprof Server...") + if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil { + slog.LogAttrs(ctx, slog.LevelWarn, "Fail to shutdown pprof server.", slog.Any("error", err)) + } + slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown pprof Server completed.") + })() + + slog.LogAttrs(ctx, slog.LevelInfo, "Starting pprof server.") + listener, err := net.Listen("tcp", "localhost:6060") + if err != nil { + listener, err = net.Listen("tcp", "localhost:0") + if err != nil { + slog.LogAttrs(ctx, slog.LevelWarn, "Fail to find port for pprof server.", slog.Any("error", err)) return nil - }, - func() error { - if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil { - slog.LogAttrs(ctx, slog.LevelWarn, "Fail to shutdown pprof server.", slog.Any("error", err)) - } + } + } + slog.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf("pprof server started at http://%s/debug/pprof/.", listener.Addr())) - return nil - }, - )(ctx) + runtime.SetBlockProfileRate(1) // Required by gRPC server. + if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.LogAttrs(ctx, slog.LevelWarn, "Fail to start pprof server.", slog.Any("error", err)) + } + + return nil } diff --git a/run/run.go b/run/run.go index 0f3ec01..4eec700 100644 --- a/run/run.go +++ b/run/run.go @@ -9,11 +9,11 @@ import ( "sync" ) -// Parallel executes the given runs in parallel. +// parallel executes the given runs in parallel. // // It blocks until all runs complete or ctx is done, then // returns the first non-nil error if received from any run. -func Parallel(ctx context.Context, runs ...func(context.Context) error) error { +func parallel(ctx context.Context, runs ...func(context.Context) error) error { ctx, cancel := context.WithCancelCause(ctx) defer cancel(nil) @@ -37,43 +37,3 @@ func Parallel(ctx context.Context, runs ...func(context.Context) error) error { return nil } - -// WithCloser wraps the given run with a dedicated closer function, -// which should cause the run returns when closer function executes. -// -// closer function executes even if run function returns non-nil error. -// It guarantees both run and closer functions complete -// and return the first non-nil error if any. -func WithCloser(run func(context.Context) error, closer func() error) func(context.Context) error { - if run == nil { - run = func(context.Context) error { return nil } - } - if closer == nil { - closer = func() error { return nil } - } - - return func(ctx context.Context) error { - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - - var waitGroup sync.WaitGroup - waitGroup.Add(1) - go func() { - defer waitGroup.Done() - - if err := run(ctx); err != nil { - cancel(err) - } - }() - - <-ctx.Done() - err := closer() - waitGroup.Wait() - - if e := context.Cause(ctx); e != nil && !errors.Is(e, ctx.Err()) { - return e //nolint:wrapcheck - } - - return err - } -} diff --git a/run/run_test.go b/run/run_test.go deleted file mode 100644 index 67bd3e9..0000000 --- a/run/run_test.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2024 The nilgo authors -// Use of this source code is governed by a MIT license found in the LICENSE file. - -package run_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/nil-go/nilgo/internal/assert" - "github.com/nil-go/nilgo/run" -) - -func TestWithCloser(t *testing.T) { - testcases := []struct { - description string - run func(ctx context.Context) error - closer func() error - err string - }{ - { - description: "no error", - run: func(context.Context) error { - return nil - }, - closer: func() error { - return nil - }, - }, - { - description: "run error", - run: func(context.Context) error { - return errors.New("run error") - }, - err: "run error", - }, - { - description: "closer error", - closer: func() error { - return errors.New("closer error") - }, - err: "closer error", - }, - } - - for _, testcase := range testcases { - testcase := testcase - - t.Run(testcase.description, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - err := run.WithCloser(testcase.run, testcase.closer)(ctx) - if testcase.err == "" { - assert.NoError(t, err) - } else { - assert.EqualError(t, err, testcase.err) - } - }) - } -} diff --git a/run/runner.go b/run/runner.go index 5cf47d0..b18ba40 100644 --- a/run/runner.go +++ b/run/runner.go @@ -89,7 +89,7 @@ func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) er runCtx, runCancel := context.WithCancel(rootCtx) defer runCancel() - return Parallel(rootCtx, + return parallel(rootCtx, append(preRuns, func(context.Context) error { defer runCancel() // Notify all main runs to stop. @@ -97,14 +97,14 @@ func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) er <-signalCtx.Done() // Wait for all stop gates to open. - return Parallel(runCtx, r.stopGates...) + return parallel(runCtx, r.stopGates...) }, func(context.Context) (err error) { //nolint:nonamedreturns defer func() { signalCancel() // Stop listening to OS signals. // Wait for all post runs to finish. - e := Parallel(rootCtx, r.postRuns...) + e := parallel(rootCtx, r.postRuns...) if err == nil { err = e } @@ -112,11 +112,11 @@ func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) er }() // Wait for all start gates to open. - if err = Parallel(runCtx, startGates...); err != nil { + if err = parallel(runCtx, startGates...); err != nil { return err } - return Parallel(runCtx, runs...) + return parallel(runCtx, runs...) }, )..., )