Skip to content

Commit

Permalink
some improvements and first tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rvflash committed Mar 2, 2019
1 parent 7a6af71 commit 2a844ba
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 67 deletions.
22 changes: 2 additions & 20 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,10 @@ type conn struct {
rwc net.Conn
}

// ServeTCP implements the Handler interface.
func (c *conn) ServeTCP(w ResponseWriter, req *Request) {
ctx := c.srv.get()
ctx.writer.rebase(w)
ctx.Request = req
ctx.reset()
c.handle(ctx)
c.srv.put(ctx)
}

func (c *conn) bySegment(segment string, body io.Reader) {
req := c.newRequest(segment, body)
w := c.newResponseWriter()
c.ServeTCP(w, req)
}

func (c *conn) handle(ctx *Context) {
ctx.handlers = c.srv.computeHandlers(ctx.Request.Segment)
if len(ctx.handlers) == 0 {
return
}
ctx.Next()
req := c.newRequest(segment, body)
c.srv.ServeTCP(w, req)
}

func (c *conn) newResponseWriter() *responseWriter {
Expand Down
9 changes: 5 additions & 4 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (c *Context) Next() {
}
}

// ReadAll return stream data.
// ReadAll return the stream data.
func (c *Context) ReadAll() ([]byte, error) {
if c.Request == nil {
return nil, ErrRequest
Expand All @@ -143,9 +143,10 @@ func (c *Context) ReadAll() ([]byte, error) {

// String writes the given string on the current connection.
func (c *Context) String(s string) {
if !strings.HasSuffix(s, "\n") {
// sends it now
s += "\n"
const eom = "\n"
if !strings.HasSuffix(s, eom) {
// sends it now, ending the message.
s += eom
}
_, err := c.writer.WriteString(s)
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Err interface {
var ErrRequest = NewError("invalid request")

// NewError returns a new Error based of the given cause.
func NewError(msg string, cause ...error) error {
func NewError(msg string, cause ...error) *Error {
if cause == nil {
return &Error{msg: msg}
}
Expand All @@ -32,10 +32,11 @@ type Error struct {

// Error implements the Err interface.
func (e *Error) Error() string {
const prefix = "tcp: "
if e.cause == nil {
return "tcp: " + e.msg
return prefix + e.msg
}
return "tcp: " + e.msg + ": " + e.cause.Error()
return prefix + e.msg + ": " + e.cause.Error()
}

// Recovered implements the Err interface.
Expand Down Expand Up @@ -67,9 +68,10 @@ func (e Errors) Error() string {

// Recovered implements the Err interface.
func (e Errors) Recovered() (ok bool) {
var err Err
for _, r := range e {
_, ok = r.(*Error)
if ok {
err, ok = r.(Err)
if ok && err.Recovered() {
return
}
}
Expand Down
68 changes: 68 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package tcp_test

import (
"errors"
"strconv"
"testing"

"github.com/matryer/is"
"github.com/rvflash/tcp"
)

const (
hiWorld = "hello world"
prefix = "tcp: "
)

func TestNewError(t *testing.T) {
var (
dt = []struct {
in string
err error
out string
}{
{out: prefix},
{in: "hi!", out: prefix + "hi!"},
{in: "hello", err: errors.New("world"), out: prefix + "hello: world"},
}
are = is.New(t)
err error
)
for i, tt := range dt {
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
err = tcp.NewError(tt.in, tt.err)
are.Equal(err.Error(), tt.out)
})
}
}

func TestError_Recovered(t *testing.T) {
e := &tcp.Error{}
is.New(t).True(!e.Recovered())
}

func TestErrors_Error(t *testing.T) {
var err tcp.Errors
err = append(err, errors.New(hiWorld))
err = append(err, tcp.NewError(hiWorld))
are := is.New(t)
are.Equal(err.Error(), hiWorld+", "+prefix+hiWorld)
}

func TestErrors_Recovered(t *testing.T) {
var (
dt = []struct {
err tcp.Errors
ok bool
}{
{err: tcp.Errors{}},
{err: tcp.Errors{tcp.NewError(hiWorld)}, ok: true},
}
are = is.New(t)
)
for i, tt := range dt {
t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
are.Equal(tt.err.Recovered(), tt.ok)
})
}
}
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
module github.com/rvflash/tcp

require github.com/sirupsen/logrus v1.3.0
require (
github.com/matryer/is v1.2.0
github.com/sirupsen/logrus v1.3.0
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/matryer/is v1.2.0 h1:92UTHpy8CDwaJ08GqLDzhhuixiBUUD1p3AU6PHddz4A=
github.com/matryer/is v1.2.0/go.mod h1:2fLPjFQM9rhQ15aVEtbuwhJinnOqrmgXPNdZsdwlWXA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
Expand Down
27 changes: 13 additions & 14 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
)

const (
RemoteAddr = "addr"
RequestLength = "req_size"
ResponseLength = "resp_size"
Latency = "latency"
Hostname = "server"
remoteAddr = "addr"
reqLength = "req_size"
respLength = "resp_size"
latency = "latency"
hostname = "server"
)

// Logger returns a middleware to log each TCP request.
Expand All @@ -30,9 +30,9 @@ func Logger(log *logrus.Logger, fields logrus.Fields) HandlerFunc {
if e := c.Err(); e == nil {
entry.Info(m.String())
} else if e.Recovered() {
entry.Error(m.String() + " " + e.Error())
entry.Errorf("%s %s", m, e)
} else {
entry.Warn(m.String() + " " + e.Error())
entry.Warnf("%s %s", m, e)
}
}
}
Expand Down Expand Up @@ -63,16 +63,16 @@ func (m *message) fields(w ResponseWriter, f logrus.Fields) logrus.Fields {
d := make(logrus.Fields)
for k := range f {
switch k {
case RemoteAddr:
case remoteAddr:
d[k] = m.req.RemoteAddr
case RequestLength:
case reqLength:
d[k] = w.Size()
case ResponseLength:
case respLength:
d[k] = m.reqSize
case Latency:
case latency:
m.latency = time.Since(m.start)
d[k] = int(math.Ceil(float64(m.latency.Nanoseconds()) / 1000.0))
case Hostname:
case hostname:
d[k], _ = os.Hostname()
}
}
Expand All @@ -81,6 +81,5 @@ func (m *message) fields(w ResponseWriter, f logrus.Fields) logrus.Fields {

// String implements the fmt.Stringer interface.
func (m *message) String() string {
sep := " | "
return "[TCP] " + m.start.Format(time.RFC3339) + sep + m.req.Segment
return "[TCP] " + m.start.Format(time.RFC3339) + " | " + m.req.Segment
}
15 changes: 15 additions & 0 deletions recovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package tcp_test

import (
"testing"

"github.com/rvflash/tcp"
)

func TestRecovery(t *testing.T) {
srv := tcp.New()
srv.Use(tcp.Recovery())
req := tcp.NewRequest(tcp.SYN, nil)
w := tcp.NewRecorder()
srv.ServeTCP(w, req)
}
3 changes: 2 additions & 1 deletion request.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ type Request struct {
Segment string
// Body is the request's body.
Body io.ReadCloser
// RemoteAddr returns the remote network address.
// remoteAddr returns the remote network address.
RemoteAddr string

// Context of the request.
ctx context.Context
cancel context.CancelFunc
Expand Down
51 changes: 48 additions & 3 deletions response.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package tcp

import (
"bytes"
"io"
)

// ResponseWriter interface is used by a TCP handler to write the response.
type ResponseWriter interface {
io.WriteCloser
// Size returns the number of bytes already written into the response body.
// -1: not already written
Size() int
io.WriteCloser
}

type responseWriter struct {
Expand Down Expand Up @@ -44,8 +45,8 @@ func (r *responseWriter) WriteString(s string) (n int, err error) {
}

func (r *responseWriter) incr(n int) {
if n == noWritten {
n = 0
if r.size == noWritten {
r.size = 0
}
r.size += n
}
Expand All @@ -54,3 +55,47 @@ func (r *responseWriter) rebase(w ResponseWriter) {
r.ResponseWriter = w
r.size = noWritten
}

// ResponseRecorder is an implementation of http.ResponseWriter that records its changes.
type ResponseRecorder struct {
// Body is the buffer to which the Handler's Write calls are sent.
Body *bytes.Buffer
}

// NewRecorder returns an initialized writer to record the response.
func NewRecorder() *ResponseRecorder {
return &ResponseRecorder{
Body: new(bytes.Buffer),
}
}

// Close implements the ResponseWriter interface.
func (r *ResponseRecorder) Close() error {
return nil
}

// Size implements the ResponseWriter interface.
func (r *ResponseRecorder) Size() int {
if r.Body == nil {
return noWritten
}
return r.Body.Len()
}

// Write implements the ResponseWriter interface.
func (r *ResponseRecorder) Write(p []byte) (n int, err error) {
if r.Body == nil {
return 0, io.EOF
}
n, err = r.Body.Write(p)
return
}

// WriteString allows to directly write string.
func (r *ResponseRecorder) WriteString(s string) (n int, err error) {
if r.Body == nil {
return 0, io.EOF
}
n, err = r.Body.WriteString(s)
return
}
Loading

0 comments on commit 2a844ba

Please sign in to comment.