From 119b145398ed1a95a71c3fbe0b0df0309ca91aa7 Mon Sep 17 00:00:00 2001 From: hgouchet Date: Tue, 3 Sep 2019 14:24:03 +0200 Subject: [PATCH] graceful shutdown: some improvements --- .gitignore | 1 + conn.go | 25 +++++----- errors.go | 7 ++- example/graceful_server/main.go | 25 +++++++++- server.go | 88 +++++++++++++++++++++------------ 5 files changed, 98 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index 1285731..4d9bae3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea example/client/client example/server/server +example/graceful_server/graceful_server example/start/start \ No newline at end of file diff --git a/conn.go b/conn.go index 83199e0..45038c2 100644 --- a/conn.go +++ b/conn.go @@ -32,25 +32,24 @@ func (c *conn) newRequest(segment string, body io.Reader) *Request { func (c *conn) serve(ctx context.Context) { // New connection c.bySegment(ctx, SYN, nil) - + // Connection closed + defer c.bySegment(ctx, FIN, nil) // Waiting for messages r := bufio.NewReader(c.rwc) for { + cb := make(chan []byte, 1) + go func() { + d, err := r.ReadBytes('\n') + if err != nil { + return + } + cb <- d + }() select { case <-ctx.Done(): - // Connection closing, stops serving. - c.bySegment(ctx, FIN, r) - return - default: - } - d, err := r.ReadBytes('\n') - r := bytes.NewReader(d) - if err != nil { - // Unable to read on it: closing the connection. - c.bySegment(ctx, FIN, r) return + case b := <-cb: + c.bySegment(ctx, ACK, bytes.NewReader(b)) } - // new message received - c.bySegment(ctx, ACK, r) } } diff --git a/errors.go b/errors.go index f7abfa2..c176502 100644 --- a/errors.go +++ b/errors.go @@ -11,8 +11,11 @@ type Err interface { Recovered() bool } -// ErrRequest is returned if the request is invalid. -var ErrRequest = NewError("invalid request") +// List of common errors +var ( + // ErrRequest is returned if the request is invalid. + ErrRequest = NewError("invalid request") +) // NewError returns a new Error based of the given cause. func NewError(msg string, cause ...error) *Error { diff --git a/example/graceful_server/main.go b/example/graceful_server/main.go index 7440a41..5baa2ed 100644 --- a/example/graceful_server/main.go +++ b/example/graceful_server/main.go @@ -1,13 +1,22 @@ package main import ( + "context" "log" + "os" + "os/signal" + "syscall" + "time" "github.com/rvflash/tcp" ) func main() { + bye := make(chan os.Signal, 1) + signal.Notify(bye, os.Interrupt, syscall.SIGTERM) + r := tcp.Default() + r.ReadTimeout = 20 * time.Second r.ACK(func(c *tcp.Context) { // new message received body, err := c.ReadAll() @@ -24,5 +33,19 @@ func main() { r.FIN(func(c *tcp.Context) { log.Println("bye") }) - log.Fatal(r.Run(":9090")) + + go func() { + err := r.Run(":9090") + if err != nil { + log.Printf("server: %q\n", err) + } + }() + + <-bye + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := r.Shutdown(ctx) + cancel() + if err != nil { + log.Fatal(err) + } } diff --git a/server.go b/server.go index d4d198a..4f9d994 100644 --- a/server.go +++ b/server.go @@ -62,7 +62,7 @@ func Default() *Server { func New() *Server { s := &Server{ handlers: map[string][]HandlerFunc{}, - shutdown: make(chan struct{}), + closing: make(chan struct{}), closed: make(chan struct{}), } s.pool.New = func() interface{} { @@ -81,11 +81,14 @@ type Server struct { // A zero value for t means Read will not time out. ReadTimeout time.Duration - cancel context.CancelFunc + listener net.Listener handlers map[string][]HandlerFunc pool sync.Pool + + // graceful shutdown + cancelCtx context.CancelFunc closed, - shutdown chan struct{} + closing chan struct{} } // Any attaches handlers on the given segment. @@ -130,12 +133,12 @@ const network = "tcp" // Run starts listening on TCP address. // This method will block the calling goroutine indefinitely unless an error happens. -func (s *Server) Run(addr string) error { - l, err := net.Listen(network, addr) +func (s *Server) Run(addr string) (err error) { + s.listener, err = net.Listen(network, addr) if err != nil { return err } - return s.serve(l) + return s.serve() } // RunTLS acts identically to the Run method, except that it uses the TLS protocol. @@ -145,40 +148,58 @@ func (s *Server) RunTLS(addr, certFile, keyFile string) error { if err != nil { return err } - l, err := tls.Listen(network, addr, c) + s.listener, err = tls.Listen(network, addr, c) if err != nil { return err } - return s.serve(l) + return s.serve() } -func (s *Server) serve(l net.Listener) (err error) { +func (s *Server) close() { + select { + case <-s.closed: + // Already closed. + return + default: + close(s.closed) + } +} + +func (s *Server) closeListener() error { + if s.cancelCtx == nil { + return nil + } + s.cancelCtx() + return s.listener.Close() +} + +func (s *Server) serve() (err error) { var ( w8 sync.WaitGroup ctx context.Context ) - ctx, s.cancel = context.WithCancel(context.Background()) + ctx, s.cancelCtx = context.WithCancel(context.Background()) defer func() { - s.cancel() - cErr := l.Close() - if err != nil { - err = cErr - } - }() - for { select { - case <-s.shutdown: - // Stops listening but does not interrupt any active connections. - // See the Shutdown method to gracefully shuts down the server. - w8.Wait() - close(s.closed) - return + case <-s.closed: default: + err = s.closeListener() + return } + }() + for { var c net.Conn - c, err = read(l, s.ReadTimeout) + c, err = read(s.listener, s.ReadTimeout) if err != nil { - return + select { + case <-s.closing: + // Stops listening but does not interrupt any active connections. + w8.Wait() + s.close() + return nil + default: + return + } } rwc := s.newConn(c) w8.Add(1) @@ -225,22 +246,25 @@ func (s *Server) computeHandlers(segment string) []HandlerFunc { // Shutdown gracefully shuts down the server without interrupting any // active connections. Shutdown works by first closing all open listeners and // then waiting indefinitely for connections to return to idle and then shut down. -// If the provided context expires before the shutdown is complete, +// If the provided context expires before the closing is complete, // Shutdown returns the context's error. func (s *Server) Shutdown(ctx context.Context) error { - if s.shutdown == nil { + if s.closing == nil { // Nothing to do return nil } - // Stops listening. - close(s.shutdown) + close(s.closing) - // Stops all. + // Stops listening. + err := s.closeListener() + if err != nil { + return err + } for { select { case <-ctx.Done(): - // Forces closing of actives connections. - s.cancel() + // Forces closing of all actives connections. + s.close() return ctx.Err() case <-s.closed: return nil