Skip to content

Commit

Permalink
use context.AfterFunc (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
ktong authored May 5, 2024
1 parent adeef1b commit 372bfb0
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 159 deletions.
26 changes: 14 additions & 12 deletions grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,21 @@ func Run(server *grpc.Server, opts ...Option) func(context.Context) error { //no
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

defer context.AfterFunc(ctx, func() {
slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown gRPC Server...")
if healthServer != nil {
// Shutdown health server so client knows it's not serving.
slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown gRPC Health service...")
healthServer.Shutdown()
slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown gRPC Health service completed.")
}
server.GracefulStop()
slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown gRPC Server completed.")
})()

slog.LogAttrs(ctx, slog.LevelInfo, "Starting gRPC Server...")
var waitGroup sync.WaitGroup
waitGroup.Add(len(option.addresses) + 1)
waitGroup.Add(len(option.addresses))
for _, addr := range option.addresses {
addr := addr
go func() {
Expand All @@ -111,17 +124,6 @@ func Run(server *grpc.Server, opts ...Option) func(context.Context) error { //no
}
}()
}
go func() {
defer waitGroup.Done()

<-ctx.Done()
if healthServer != nil {
// Shutdown health server so client knows it's not serving.
healthServer.Shutdown()
}
server.GracefulStop()
slog.LogAttrs(ctx, slog.LevelInfo, "gRPC Server is stopped.")
}()
waitGroup.Wait()

if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) {
Expand Down
20 changes: 10 additions & 10 deletions http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,17 @@ func Run(server *http.Server, opts ...Option) func(context.Context) error { //no
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

defer context.AfterFunc(ctx, func() {
slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown HTTP Server...")
if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil {
cancel(fmt.Errorf("shutdown HTTP Server: %w", err))
}
slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown HTTP Server completed.")
})()

slog.LogAttrs(ctx, slog.LevelInfo, "Starting HTTP Server...")
var waitGroup sync.WaitGroup
waitGroup.Add(len(option.addresses) + 1)
waitGroup.Add(len(option.addresses))
if slices.ContainsFunc(option.addresses, func(addr socket) bool { return addr.network == unix }) {
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
internal.RegisterUnixProtocol(transport)
Expand Down Expand Up @@ -112,15 +121,6 @@ func Run(server *http.Server, opts ...Option) func(context.Context) error { //no
}
}()
}
go func() {
defer waitGroup.Done()

<-ctx.Done()
if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil {
cancel(fmt.Errorf("shutdown HTTP Server: %w", err))
}
slog.LogAttrs(ctx, slog.LevelInfo, "HTTP Server is stopped.")
}()
waitGroup.Wait()

if err := context.Cause(ctx); err != nil && !errors.Is(err, ctx.Err()) {
Expand Down
51 changes: 23 additions & 28 deletions pprof.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
"net/http/pprof"
"runtime"
"time"

"github.com/nil-go/nilgo/run"
)

// PProf starts a pprof server at localhost:6060.
Expand All @@ -27,38 +25,35 @@ func PProf(ctx context.Context) error {
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)

server := &http.Server{
Handler: mux,
ReadTimeout: time.Second,
}

return run.WithCloser(
func(ctx context.Context) error {
listener, err := net.Listen("tcp", "localhost:6060")
if err != nil {
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
slog.LogAttrs(ctx, slog.LevelWarn, "Fail to find port for pprof server.", slog.Any("error", err))

return nil
}
}
slog.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf("Start pprof server at http://%s/debug/pprof/.", listener.Addr()))

runtime.SetBlockProfileRate(1) // Required by gRPC server.
if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
slog.LogAttrs(ctx, slog.LevelWarn, "Fail to start pprof server.", slog.Any("error", err))
}
defer context.AfterFunc(ctx, func() {
slog.LogAttrs(ctx, slog.LevelInfo, "Starting shutdown pprof Server...")
if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil {
slog.LogAttrs(ctx, slog.LevelWarn, "Fail to shutdown pprof server.", slog.Any("error", err))
}
slog.LogAttrs(ctx, slog.LevelInfo, "Shutdown pprof Server completed.")
})()

slog.LogAttrs(ctx, slog.LevelInfo, "Starting pprof server.")
listener, err := net.Listen("tcp", "localhost:6060")
if err != nil {
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
slog.LogAttrs(ctx, slog.LevelWarn, "Fail to find port for pprof server.", slog.Any("error", err))

return nil
},
func() error {
if err := server.Shutdown(context.WithoutCancel(ctx)); err != nil {
slog.LogAttrs(ctx, slog.LevelWarn, "Fail to shutdown pprof server.", slog.Any("error", err))
}
}
}
slog.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf("pprof server started at http://%s/debug/pprof/.", listener.Addr()))

return nil
},
)(ctx)
runtime.SetBlockProfileRate(1) // Required by gRPC server.
if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
slog.LogAttrs(ctx, slog.LevelWarn, "Fail to start pprof server.", slog.Any("error", err))
}

return nil
}
44 changes: 2 additions & 42 deletions run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"sync"
)

// Parallel executes the given runs in parallel.
// parallel executes the given runs in parallel.
//
// It blocks until all runs complete or ctx is done, then
// returns the first non-nil error if received from any run.
func Parallel(ctx context.Context, runs ...func(context.Context) error) error {
func parallel(ctx context.Context, runs ...func(context.Context) error) error {
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

Expand All @@ -37,43 +37,3 @@ func Parallel(ctx context.Context, runs ...func(context.Context) error) error {

return nil
}

// WithCloser wraps the given run with a dedicated closer function,
// which should cause the run returns when closer function executes.
//
// closer function executes even if run function returns non-nil error.
// It guarantees both run and closer functions complete
// and return the first non-nil error if any.
func WithCloser(run func(context.Context) error, closer func() error) func(context.Context) error {
if run == nil {
run = func(context.Context) error { return nil }
}
if closer == nil {
closer = func() error { return nil }
}

return func(ctx context.Context) error {
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

var waitGroup sync.WaitGroup
waitGroup.Add(1)
go func() {
defer waitGroup.Done()

if err := run(ctx); err != nil {
cancel(err)
}
}()

<-ctx.Done()
err := closer()
waitGroup.Wait()

if e := context.Cause(ctx); e != nil && !errors.Is(e, ctx.Err()) {
return e //nolint:wrapcheck
}

return err
}
}
62 changes: 0 additions & 62 deletions run/run_test.go

This file was deleted.

10 changes: 5 additions & 5 deletions run/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,34 +89,34 @@ func (r Runner) Run(ctx context.Context, runs ...func(context.Context) error) er
runCtx, runCancel := context.WithCancel(rootCtx)
defer runCancel()

return Parallel(rootCtx,
return parallel(rootCtx,
append(preRuns,
func(context.Context) error {
defer runCancel() // Notify all main runs to stop.

<-signalCtx.Done()

// Wait for all stop gates to open.
return Parallel(runCtx, r.stopGates...)
return parallel(runCtx, r.stopGates...)
},
func(context.Context) (err error) { //nolint:nonamedreturns
defer func() {
signalCancel() // Stop listening to OS signals.

// Wait for all post runs to finish.
e := Parallel(rootCtx, r.postRuns...)
e := parallel(rootCtx, r.postRuns...)
if err == nil {
err = e
}
rootCancel(err) // Notify all pre runs to stop.
}()

// Wait for all start gates to open.
if err = Parallel(runCtx, startGates...); err != nil {
if err = parallel(runCtx, startGates...); err != nil {
return err
}

return Parallel(runCtx, runs...)
return parallel(runCtx, runs...)
},
)...,
)
Expand Down

0 comments on commit 372bfb0

Please sign in to comment.