Skip to content

Commit

Permalink
graceful shutdown: some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
rvflash committed Sep 3, 2019
1 parent 7dceed7 commit 119b145
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 48 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.idea
example/client/client
example/server/server
example/graceful_server/graceful_server
example/start/start
25 changes: 12 additions & 13 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,24 @@ func (c *conn) newRequest(segment string, body io.Reader) *Request {
func (c *conn) serve(ctx context.Context) {
// New connection
c.bySegment(ctx, SYN, nil)

// Connection closed
defer c.bySegment(ctx, FIN, nil)
// Waiting for messages
r := bufio.NewReader(c.rwc)
for {
cb := make(chan []byte, 1)
go func() {
d, err := r.ReadBytes('\n')
if err != nil {
return
}
cb <- d
}()
select {
case <-ctx.Done():
// Connection closing, stops serving.
c.bySegment(ctx, FIN, r)
return
default:
}
d, err := r.ReadBytes('\n')
r := bytes.NewReader(d)
if err != nil {
// Unable to read on it: closing the connection.
c.bySegment(ctx, FIN, r)
return
case b := <-cb:
c.bySegment(ctx, ACK, bytes.NewReader(b))
}
// new message received
c.bySegment(ctx, ACK, r)
}
}
7 changes: 5 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ type Err interface {
Recovered() bool
}

// ErrRequest is returned if the request is invalid.
var ErrRequest = NewError("invalid request")
// List of common errors
var (
// ErrRequest is returned if the request is invalid.
ErrRequest = NewError("invalid request")
)

// NewError returns a new Error based of the given cause.
func NewError(msg string, cause ...error) *Error {
Expand Down
25 changes: 24 additions & 1 deletion example/graceful_server/main.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package main

import (
"context"
"log"
"os"
"os/signal"
"syscall"
"time"

"github.com/rvflash/tcp"
)

func main() {
bye := make(chan os.Signal, 1)
signal.Notify(bye, os.Interrupt, syscall.SIGTERM)

r := tcp.Default()
r.ReadTimeout = 20 * time.Second
r.ACK(func(c *tcp.Context) {
// new message received
body, err := c.ReadAll()
Expand All @@ -24,5 +33,19 @@ func main() {
r.FIN(func(c *tcp.Context) {
log.Println("bye")
})
log.Fatal(r.Run(":9090"))

go func() {
err := r.Run(":9090")
if err != nil {
log.Printf("server: %q\n", err)
}
}()

<-bye
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
err := r.Shutdown(ctx)
cancel()
if err != nil {
log.Fatal(err)
}
}
88 changes: 56 additions & 32 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func Default() *Server {
func New() *Server {
s := &Server{
handlers: map[string][]HandlerFunc{},
shutdown: make(chan struct{}),
closing: make(chan struct{}),
closed: make(chan struct{}),
}
s.pool.New = func() interface{} {
Expand All @@ -81,11 +81,14 @@ type Server struct {
// A zero value for t means Read will not time out.
ReadTimeout time.Duration

cancel context.CancelFunc
listener net.Listener
handlers map[string][]HandlerFunc
pool sync.Pool

// graceful shutdown
cancelCtx context.CancelFunc
closed,
shutdown chan struct{}
closing chan struct{}
}

// Any attaches handlers on the given segment.
Expand Down Expand Up @@ -130,12 +133,12 @@ const network = "tcp"

// Run starts listening on TCP address.
// This method will block the calling goroutine indefinitely unless an error happens.
func (s *Server) Run(addr string) error {
l, err := net.Listen(network, addr)
func (s *Server) Run(addr string) (err error) {
s.listener, err = net.Listen(network, addr)
if err != nil {
return err
}
return s.serve(l)
return s.serve()
}

// RunTLS acts identically to the Run method, except that it uses the TLS protocol.
Expand All @@ -145,40 +148,58 @@ func (s *Server) RunTLS(addr, certFile, keyFile string) error {
if err != nil {
return err
}
l, err := tls.Listen(network, addr, c)
s.listener, err = tls.Listen(network, addr, c)
if err != nil {
return err
}
return s.serve(l)
return s.serve()
}

func (s *Server) serve(l net.Listener) (err error) {
func (s *Server) close() {
select {
case <-s.closed:
// Already closed.
return
default:
close(s.closed)
}
}

func (s *Server) closeListener() error {
if s.cancelCtx == nil {
return nil
}
s.cancelCtx()
return s.listener.Close()
}

func (s *Server) serve() (err error) {
var (
w8 sync.WaitGroup
ctx context.Context
)
ctx, s.cancel = context.WithCancel(context.Background())
ctx, s.cancelCtx = context.WithCancel(context.Background())
defer func() {
s.cancel()
cErr := l.Close()
if err != nil {
err = cErr
}
}()
for {
select {
case <-s.shutdown:
// Stops listening but does not interrupt any active connections.
// See the Shutdown method to gracefully shuts down the server.
w8.Wait()
close(s.closed)
return
case <-s.closed:
default:
err = s.closeListener()
return
}
}()
for {
var c net.Conn
c, err = read(l, s.ReadTimeout)
c, err = read(s.listener, s.ReadTimeout)
if err != nil {
return
select {
case <-s.closing:
// Stops listening but does not interrupt any active connections.
w8.Wait()
s.close()
return nil
default:
return
}
}
rwc := s.newConn(c)
w8.Add(1)
Expand Down Expand Up @@ -225,22 +246,25 @@ func (s *Server) computeHandlers(segment string) []HandlerFunc {
// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open listeners and
// then waiting indefinitely for connections to return to idle and then shut down.
// If the provided context expires before the shutdown is complete,
// If the provided context expires before the closing is complete,
// Shutdown returns the context's error.
func (s *Server) Shutdown(ctx context.Context) error {
if s.shutdown == nil {
if s.closing == nil {
// Nothing to do
return nil
}
// Stops listening.
close(s.shutdown)
close(s.closing)

// Stops all.
// Stops listening.
err := s.closeListener()
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
// Forces closing of actives connections.
s.cancel()
// Forces closing of all actives connections.
s.close()
return ctx.Err()
case <-s.closed:
return nil
Expand Down

0 comments on commit 119b145

Please sign in to comment.