From 2a844ba8695e7f7914d9d2c49a80a372f79771dc Mon Sep 17 00:00:00 2001 From: hgouchet Date: Sat, 2 Mar 2019 15:34:35 +0100 Subject: [PATCH] some improvements and first tests --- conn.go | 22 ++-------------- context.go | 9 ++++--- errors.go | 12 +++++---- errors_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 +++- go.sum | 2 ++ logger.go | 27 +++++++++---------- recovery_test.go | 15 +++++++++++ request.go | 3 ++- response.go | 51 +++++++++++++++++++++++++++++++++--- server.go | 48 ++++++++++++++++++++-------------- 11 files changed, 195 insertions(+), 67 deletions(-) create mode 100644 errors_test.go create mode 100644 recovery_test.go diff --git a/conn.go b/conn.go index ad5981c..8c64e49 100644 --- a/conn.go +++ b/conn.go @@ -15,28 +15,10 @@ type conn struct { rwc net.Conn } -// ServeTCP implements the Handler interface. -func (c *conn) ServeTCP(w ResponseWriter, req *Request) { - ctx := c.srv.get() - ctx.writer.rebase(w) - ctx.Request = req - ctx.reset() - c.handle(ctx) - c.srv.put(ctx) -} - func (c *conn) bySegment(segment string, body io.Reader) { - req := c.newRequest(segment, body) w := c.newResponseWriter() - c.ServeTCP(w, req) -} - -func (c *conn) handle(ctx *Context) { - ctx.handlers = c.srv.computeHandlers(ctx.Request.Segment) - if len(ctx.handlers) == 0 { - return - } - ctx.Next() + req := c.newRequest(segment, body) + c.srv.ServeTCP(w, req) } func (c *conn) newResponseWriter() *responseWriter { diff --git a/context.go b/context.go index bcf3001..6678eef 100644 --- a/context.go +++ b/context.go @@ -130,7 +130,7 @@ func (c *Context) Next() { } } -// ReadAll return stream data. +// ReadAll return the stream data. func (c *Context) ReadAll() ([]byte, error) { if c.Request == nil { return nil, ErrRequest @@ -143,9 +143,10 @@ func (c *Context) ReadAll() ([]byte, error) { // String writes the given string on the current connection. func (c *Context) String(s string) { - if !strings.HasSuffix(s, "\n") { - // sends it now - s += "\n" + const eom = "\n" + if !strings.HasSuffix(s, eom) { + // sends it now, ending the message. + s += eom } _, err := c.writer.WriteString(s) if err != nil { diff --git a/errors.go b/errors.go index 8f9ab52..f7abfa2 100644 --- a/errors.go +++ b/errors.go @@ -15,7 +15,7 @@ type Err interface { var ErrRequest = NewError("invalid request") // NewError returns a new Error based of the given cause. -func NewError(msg string, cause ...error) error { +func NewError(msg string, cause ...error) *Error { if cause == nil { return &Error{msg: msg} } @@ -32,10 +32,11 @@ type Error struct { // Error implements the Err interface. func (e *Error) Error() string { + const prefix = "tcp: " if e.cause == nil { - return "tcp: " + e.msg + return prefix + e.msg } - return "tcp: " + e.msg + ": " + e.cause.Error() + return prefix + e.msg + ": " + e.cause.Error() } // Recovered implements the Err interface. @@ -67,9 +68,10 @@ func (e Errors) Error() string { // Recovered implements the Err interface. func (e Errors) Recovered() (ok bool) { + var err Err for _, r := range e { - _, ok = r.(*Error) - if ok { + err, ok = r.(Err) + if ok && err.Recovered() { return } } diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..607317a --- /dev/null +++ b/errors_test.go @@ -0,0 +1,68 @@ +package tcp_test + +import ( + "errors" + "strconv" + "testing" + + "github.com/matryer/is" + "github.com/rvflash/tcp" +) + +const ( + hiWorld = "hello world" + prefix = "tcp: " +) + +func TestNewError(t *testing.T) { + var ( + dt = []struct { + in string + err error + out string + }{ + {out: prefix}, + {in: "hi!", out: prefix + "hi!"}, + {in: "hello", err: errors.New("world"), out: prefix + "hello: world"}, + } + are = is.New(t) + err error + ) + for i, tt := range dt { + t.Run("#"+strconv.Itoa(i), func(t *testing.T) { + err = tcp.NewError(tt.in, tt.err) + are.Equal(err.Error(), tt.out) + }) + } +} + +func TestError_Recovered(t *testing.T) { + e := &tcp.Error{} + is.New(t).True(!e.Recovered()) +} + +func TestErrors_Error(t *testing.T) { + var err tcp.Errors + err = append(err, errors.New(hiWorld)) + err = append(err, tcp.NewError(hiWorld)) + are := is.New(t) + are.Equal(err.Error(), hiWorld+", "+prefix+hiWorld) +} + +func TestErrors_Recovered(t *testing.T) { + var ( + dt = []struct { + err tcp.Errors + ok bool + }{ + {err: tcp.Errors{}}, + {err: tcp.Errors{tcp.NewError(hiWorld)}, ok: true}, + } + are = is.New(t) + ) + for i, tt := range dt { + t.Run("#"+strconv.Itoa(i), func(t *testing.T) { + are.Equal(tt.err.Recovered(), tt.ok) + }) + } +} diff --git a/go.mod b/go.mod index ab4e2f1..450b086 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,6 @@ module github.com/rvflash/tcp -require github.com/sirupsen/logrus v1.3.0 +require ( + github.com/matryer/is v1.2.0 + github.com/sirupsen/logrus v1.3.0 +) diff --git a/go.sum b/go.sum index 1a0296f..e4a5c15 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A= +github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= diff --git a/logger.go b/logger.go index bab5030..9e78f92 100644 --- a/logger.go +++ b/logger.go @@ -11,11 +11,11 @@ import ( ) const ( - RemoteAddr = "addr" - RequestLength = "req_size" - ResponseLength = "resp_size" - Latency = "latency" - Hostname = "server" + remoteAddr = "addr" + reqLength = "req_size" + respLength = "resp_size" + latency = "latency" + hostname = "server" ) // Logger returns a middleware to log each TCP request. @@ -30,9 +30,9 @@ func Logger(log *logrus.Logger, fields logrus.Fields) HandlerFunc { if e := c.Err(); e == nil { entry.Info(m.String()) } else if e.Recovered() { - entry.Error(m.String() + " " + e.Error()) + entry.Errorf("%s %s", m, e) } else { - entry.Warn(m.String() + " " + e.Error()) + entry.Warnf("%s %s", m, e) } } } @@ -63,16 +63,16 @@ func (m *message) fields(w ResponseWriter, f logrus.Fields) logrus.Fields { d := make(logrus.Fields) for k := range f { switch k { - case RemoteAddr: + case remoteAddr: d[k] = m.req.RemoteAddr - case RequestLength: + case reqLength: d[k] = w.Size() - case ResponseLength: + case respLength: d[k] = m.reqSize - case Latency: + case latency: m.latency = time.Since(m.start) d[k] = int(math.Ceil(float64(m.latency.Nanoseconds()) / 1000.0)) - case Hostname: + case hostname: d[k], _ = os.Hostname() } } @@ -81,6 +81,5 @@ func (m *message) fields(w ResponseWriter, f logrus.Fields) logrus.Fields { // String implements the fmt.Stringer interface. func (m *message) String() string { - sep := " | " - return "[TCP] " + m.start.Format(time.RFC3339) + sep + m.req.Segment + return "[TCP] " + m.start.Format(time.RFC3339) + " | " + m.req.Segment } diff --git a/recovery_test.go b/recovery_test.go new file mode 100644 index 0000000..c7d925c --- /dev/null +++ b/recovery_test.go @@ -0,0 +1,15 @@ +package tcp_test + +import ( + "testing" + + "github.com/rvflash/tcp" +) + +func TestRecovery(t *testing.T) { + srv := tcp.New() + srv.Use(tcp.Recovery()) + req := tcp.NewRequest(tcp.SYN, nil) + w := tcp.NewRecorder() + srv.ServeTCP(w, req) +} diff --git a/request.go b/request.go index 2d2a01f..6c80709 100644 --- a/request.go +++ b/request.go @@ -12,8 +12,9 @@ type Request struct { Segment string // Body is the request's body. Body io.ReadCloser - // RemoteAddr returns the remote network address. + // remoteAddr returns the remote network address. RemoteAddr string + // Context of the request. ctx context.Context cancel context.CancelFunc diff --git a/response.go b/response.go index b0866ff..bf744d2 100644 --- a/response.go +++ b/response.go @@ -1,15 +1,16 @@ package tcp import ( + "bytes" "io" ) // ResponseWriter interface is used by a TCP handler to write the response. type ResponseWriter interface { - io.WriteCloser // Size returns the number of bytes already written into the response body. // -1: not already written Size() int + io.WriteCloser } type responseWriter struct { @@ -44,8 +45,8 @@ func (r *responseWriter) WriteString(s string) (n int, err error) { } func (r *responseWriter) incr(n int) { - if n == noWritten { - n = 0 + if r.size == noWritten { + r.size = 0 } r.size += n } @@ -54,3 +55,47 @@ func (r *responseWriter) rebase(w ResponseWriter) { r.ResponseWriter = w r.size = noWritten } + +// ResponseRecorder is an implementation of http.ResponseWriter that records its changes. +type ResponseRecorder struct { + // Body is the buffer to which the Handler's Write calls are sent. + Body *bytes.Buffer +} + +// NewRecorder returns an initialized writer to record the response. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + Body: new(bytes.Buffer), + } +} + +// Close implements the ResponseWriter interface. +func (r *ResponseRecorder) Close() error { + return nil +} + +// Size implements the ResponseWriter interface. +func (r *ResponseRecorder) Size() int { + if r.Body == nil { + return noWritten + } + return r.Body.Len() +} + +// Write implements the ResponseWriter interface. +func (r *ResponseRecorder) Write(p []byte) (n int, err error) { + if r.Body == nil { + return 0, io.EOF + } + n, err = r.Body.Write(p) + return +} + +// WriteString allows to directly write string. +func (r *ResponseRecorder) WriteString(s string) (n int, err error) { + if r.Body == nil { + return 0, io.EOF + } + n, err = r.Body.WriteString(s) + return +} diff --git a/server.go b/server.go index 87b5c33..f711a1c 100644 --- a/server.go +++ b/server.go @@ -42,15 +42,15 @@ const ( // Default returns an instance of TCP server with a Logger and a Recover on panic attached. func Default() *Server { - l := logrus.New() - l.Formatter = &logrus.TextFormatter{DisableTimestamp: true} f := logrus.Fields{ - Latency: 0, - Hostname: "", - RemoteAddr: "", - RequestLength: 0, - ResponseLength: 0, + latency: 0, + hostname: "", + remoteAddr: "", + reqLength: 0, + respLength: 0, } + l := logrus.New() + l.Formatter = &logrus.TextFormatter{DisableTimestamp: true} h := New() h.Use(Logger(l, f)) h.Use(Recovery()) @@ -68,6 +68,10 @@ func New() *Server { return s } +func (s *Server) allocateContext() *Context { + return &Context{srv: s} +} + // Server is the TCP server. It contains type Server struct { // ReadTimeout is the maximum duration for reading the entire request, including the body. @@ -130,7 +134,7 @@ func (s *Server) Run(addr string) (err error) { }() ctx := context.Background() for { - c, err := newConn(l, s.ReadTimeout) + c, err := read(l, s.ReadTimeout) if err != nil { return err } @@ -148,8 +152,22 @@ func (s *Server) newConn(ctx context.Context, c net.Conn) *conn { } } -func (s *Server) allocateContext() *Context { - return &Context{srv: s} +// ServeTCP implements the Handler interface; +func (s *Server) ServeTCP(w ResponseWriter, req *Request) { + ctx := s.pool.Get().(*Context) + ctx.writer.rebase(w) + ctx.Request = req + ctx.reset() + s.handle(ctx) + s.pool.Put(ctx) +} + +func (s *Server) handle(ctx *Context) { + ctx.handlers = s.computeHandlers(ctx.Request.Segment) + if len(ctx.handlers) == 0 { + return + } + ctx.Next() } func (s *Server) computeHandlers(segment string) []HandlerFunc { @@ -159,15 +177,7 @@ func (s *Server) computeHandlers(segment string) []HandlerFunc { return m } -func (s *Server) get() *Context { - return s.pool.Get().(*Context) -} - -func (s *Server) put(c *Context) { - s.pool.Put(c) -} - -func newConn(l net.Listener, to time.Duration) (net.Conn, error) { +func read(l net.Listener, to time.Duration) (net.Conn, error) { c, err := l.Accept() if err != nil { return nil, err