diff --git a/.gitignore b/.gitignore index 1285731..4d9bae3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea example/client/client example/server/server +example/graceful_server/graceful_server example/start/start \ No newline at end of file diff --git a/README.md b/README.md index 6f36276..65a3ee3 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,12 @@ The `Next` method on the `Context` should only be used inside middleware. Its al See the `Recovery` or `Logger` methods as sample code. +### Graceful shutdown + +By running the TCP server is in own go routine, you can gracefully shuts down the server without interrupting any active connections. +`Shutdown` works by first closing all open listeners and then waiting indefinitely for connections to return to idle and then shut down. + + ## Quick start Assuming the following code that runs a server on port 9090: @@ -76,28 +82,44 @@ Assuming the following code that runs a server on port 9090: package main import ( - "log" + "context" + "log" + "os" + "os/signal" "github.com/rvflash/tcp" ) func main() { - // creates a server with a logger and a recover on panic as middlewares. + bye := make(chan os.Signal, 1) + signal.Notify(bye, os.Interrupt, syscall.SIGTERM) + + // Creates a server with a logger and a recover on panic as middlewares. r := tcp.Default() r.ACK(func(c *tcp.Context) { - // new message received - // gets the request body + // New message received + // Gets the request body buf, err := c.ReadAll() if err != nil { c.Error(err) return } - // writes something as response + // Writes something as response c.String(string(buf)) }) - err := r.Run(":9090") // listen and serve on 0.0.0.0:9090 + go func() { + err := r.Run(":9090") + if err != nil { + log.Printf("server: %q\n", err) + } + }() + + <-bye + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := r.Shutdown(ctx) + cancel() if err != nil { - log.Fatalf("listen: %s", err) + log.Fatal(err) } } ``` \ No newline at end of file diff --git a/conn.go b/conn.go index fca2dc3..0c5eec1 100644 --- a/conn.go +++ b/conn.go @@ -10,13 +10,12 @@ import ( type conn struct { addr string - ctx context.Context - srv *Server rwc net.Conn + srv *Server } -func (c *conn) bySegment(segment string, body io.Reader) { - ctx, cancel := context.WithCancel(c.ctx) +func (c *conn) bySegment(ctx context.Context, segment string, body io.Reader) { + ctx, cancel := context.WithCancel(ctx) defer cancel() w := newWriter(c.rwc) @@ -30,20 +29,18 @@ func (c *conn) newRequest(segment string, body io.Reader) *Request { return req } -func (c *conn) serve() { - // deals with a new connection - go c.bySegment(SYN, nil) - // waiting for messages +func (c *conn) serve(ctx context.Context) { + // New connection + go c.bySegment(ctx, SYN, nil) + // Waiting for messages r := bufio.NewReader(c.rwc) for { d, err := r.ReadBytes('\n') - r := bytes.NewReader(d) if err != nil { - // unable to read on it: closing the connection. - c.bySegment(FIN, r) - return + break } - // new message received - go c.bySegment(ACK, r) + go c.bySegment(ctx, ACK, bytes.NewReader(d)) } + // Connection closed + c.bySegment(ctx, FIN, nil) } diff --git a/errors.go b/errors.go index f7abfa2..c176502 100644 --- a/errors.go +++ b/errors.go @@ -11,8 +11,11 @@ type Err interface { Recovered() bool } -// ErrRequest is returned if the request is invalid. -var ErrRequest = NewError("invalid request") +// List of common errors +var ( + // ErrRequest is returned if the request is invalid. + ErrRequest = NewError("invalid request") +) // NewError returns a new Error based of the given cause. func NewError(msg string, cause ...error) *Error { diff --git a/example/graceful_server/main.go b/example/graceful_server/main.go new file mode 100644 index 0000000..5baa2ed --- /dev/null +++ b/example/graceful_server/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "log" + "os" + "os/signal" + "syscall" + "time" + + "github.com/rvflash/tcp" +) + +func main() { + bye := make(chan os.Signal, 1) + signal.Notify(bye, os.Interrupt, syscall.SIGTERM) + + r := tcp.Default() + r.ReadTimeout = 20 * time.Second + r.ACK(func(c *tcp.Context) { + // new message received + body, err := c.ReadAll() + if err != nil { + c.Error(err) + return + } + log.Println(string(body)) + c.String("read") + }) + r.SYN(func(c *tcp.Context) { + c.String("hello") + }) + r.FIN(func(c *tcp.Context) { + log.Println("bye") + }) + + go func() { + err := r.Run(":9090") + if err != nil { + log.Printf("server: %q\n", err) + } + }() + + <-bye + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + err := r.Shutdown(ctx) + cancel() + if err != nil { + log.Fatal(err) + } +} diff --git a/server.go b/server.go index 205bbd8..4f9d994 100644 --- a/server.go +++ b/server.go @@ -62,6 +62,8 @@ func Default() *Server { func New() *Server { s := &Server{ handlers: map[string][]HandlerFunc{}, + closing: make(chan struct{}), + closed: make(chan struct{}), } s.pool.New = func() interface{} { return s.allocateContext() @@ -79,8 +81,14 @@ type Server struct { // A zero value for t means Read will not time out. ReadTimeout time.Duration + listener net.Listener handlers map[string][]HandlerFunc pool sync.Pool + + // graceful shutdown + cancelCtx context.CancelFunc + closed, + closing chan struct{} } // Any attaches handlers on the given segment. @@ -126,48 +134,85 @@ const network = "tcp" // Run starts listening on TCP address. // This method will block the calling goroutine indefinitely unless an error happens. func (s *Server) Run(addr string) (err error) { - l, err := net.Listen(network, addr) + s.listener, err = net.Listen(network, addr) if err != nil { - return + return err } - return s.serve(l) + return s.serve() } // RunTLS acts identically to the Run method, except that it uses the TLS protocol. // This method will block the calling goroutine indefinitely unless an error happens. -func (s *Server) RunTLS(addr, certFile, keyFile string) (err error) { +func (s *Server) RunTLS(addr, certFile, keyFile string) error { c, err := tlsConfig(certFile, keyFile) if err != nil { - return + return err } - l, err := tls.Listen(network, addr, c) + s.listener, err = tls.Listen(network, addr, c) if err != nil { + return err + } + return s.serve() +} + +func (s *Server) close() { + select { + case <-s.closed: + // Already closed. return + default: + close(s.closed) + } +} + +func (s *Server) closeListener() error { + if s.cancelCtx == nil { + return nil } - return s.serve(l) + s.cancelCtx() + return s.listener.Close() } -func (s *Server) serve(l net.Listener) (err error) { +func (s *Server) serve() (err error) { + var ( + w8 sync.WaitGroup + ctx context.Context + ) + ctx, s.cancelCtx = context.WithCancel(context.Background()) defer func() { - if err == nil { - err = l.Close() + select { + case <-s.closed: + default: + err = s.closeListener() + return } }() - ctx := context.Background() for { - c, err := read(l, s.ReadTimeout) + var c net.Conn + c, err = read(s.listener, s.ReadTimeout) if err != nil { - return err + select { + case <-s.closing: + // Stops listening but does not interrupt any active connections. + w8.Wait() + s.close() + return nil + default: + return + } } - rwc := s.newConn(ctx, c) - go rwc.serve() + rwc := s.newConn(c) + w8.Add(1) + go func() { + defer w8.Done() + rwc.serve(ctx) + }() } } -func (s *Server) newConn(ctx context.Context, c net.Conn) *conn { +func (s *Server) newConn(c net.Conn) *conn { return &conn{ addr: c.RemoteAddr().String(), - ctx: ctx, srv: s, rwc: c, } @@ -198,6 +243,35 @@ func (s *Server) computeHandlers(segment string) []HandlerFunc { return m } +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open listeners and +// then waiting indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the closing is complete, +// Shutdown returns the context's error. +func (s *Server) Shutdown(ctx context.Context) error { + if s.closing == nil { + // Nothing to do + return nil + } + close(s.closing) + + // Stops listening. + err := s.closeListener() + if err != nil { + return err + } + for { + select { + case <-ctx.Done(): + // Forces closing of all actives connections. + s.close() + return ctx.Err() + case <-s.closed: + return nil + } + } +} + func tlsConfig(certFile, keyFile string) (*tls.Config, error) { var err error c := make([]tls.Certificate, 1) @@ -211,7 +285,7 @@ func read(l net.Listener, to time.Duration) (net.Conn, error) { return nil, err } if to == 0 { - return c, err + return c, nil } err = c.SetReadDeadline(time.Now().Add(to)) if err != nil {