diff --git a/.travis.yml b/.travis.yml index 89c4ec7..3f8252e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,8 @@ language: go sudo: required go: - - 1.x + - "1.10.x" + - master install: - go get -u github.com/golang/dep/cmd/dep diff --git a/Makefile b/Makefile index 4dc227b..2781045 100644 --- a/Makefile +++ b/Makefile @@ -7,4 +7,7 @@ test: go test -race -cover ./... test-cover: - go test -race -coverprofile=test.out ./... && go tool cover --html=test.out \ No newline at end of file + go test -race -coverprofile=test.out ./... && go tool cover --html=test.out + +bench: + go test -bench=. ./ diff --git a/bench_test.go b/bench_test.go new file mode 100644 index 0000000..c6d859f --- /dev/null +++ b/bench_test.go @@ -0,0 +1,19 @@ +package cod + +import ( + "net/http/httptest" + "testing" +) + +func BenchmarkRoutes(b *testing.B) { + d := NewWithoutServer() + d.GET("/", func(c *Context) error { + return nil + }) + b.ReportAllocs() + req := httptest.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + for i := 0; i < b.N; i++ { + d.ServeHTTP(resp, req) + } +} diff --git a/cod.go b/cod.go index 6fc7e06..5f231d8 100644 --- a/cod.go +++ b/cod.go @@ -1,43 +1,51 @@ package cod import ( + "errors" "net/http" + "sync" "github.com/julienschmidt/httprouter" ) +var ( + // ErrOutOfHandlerRange out of handler range (call next over handler's size) + ErrOutOfHandlerRange = errors.New("out of handler range") +) + type ( // Cod web framework instance Cod struct { - Server *http.Server - Router *httprouter.Router - Middlewares []Handle + Server *http.Server + Router *httprouter.Router + // Middlewares middleware function + Middlewares []Handler errorLinsteners []ErrorLinstener - ErrorHandle ErrorHandle + ErrorHandler ErrorHandler + // NotFoundHandler not found handler + NotFoundHandler http.HandlerFunc GenerateID GenerateID + ctxPool sync.Pool } // Group group router Group struct { - Path string - HandleList []Handle - Cod *Cod + Path string + HandlerList []Handler + Cod *Cod } - // ErrorHandle error handle function - ErrorHandle func(error, *Context) - // GenerateID generate id + // ErrorHandler error handle function + ErrorHandler func(*Context, error) + // GenerateID generate context id GenerateID func() string - // Handle cod handle function - Handle func(*Context) error + // Handler cod handle function + Handler func(*Context) error // ErrorLinstener error listener function ErrorLinstener func(*Context, error) ) // New create a cod instance func New() *Cod { - d := &Cod{ - Router: httprouter.New(), - Middlewares: make([]Handle, 0), - } + d := NewWithoutServer() s := &http.Server{ Handler: d, } @@ -45,8 +53,23 @@ func New() *Cod { return d } +// NewWithoutServer create a cod instance without server +func NewWithoutServer() *Cod { + d := &Cod{ + Router: httprouter.New(), + Middlewares: make([]Handler, 0), + } + d.ctxPool.New = func() interface{} { + return &Context{} + } + return d +} + // ListenAndServe listen and serve for http server func (d *Cod) ListenAndServe(addr string) error { + if d.Server == nil { + panic("server is not inited") + } d.Server.Addr = addr return d.Server.ListenAndServe() } @@ -67,11 +90,21 @@ func (d *Cod) ServeHTTP(resp http.ResponseWriter, req *http.Request) { d.NotFound(resp, req) } +// fillContext fill the context +func (d *Cod) fillContext(c *Context, resp http.ResponseWriter, req *http.Request) { + c.Reset() + c.Request = req + c.Response = resp + if resp != nil { + c.Headers = resp.Header() + } +} + // Handle add http handle function -func (d *Cod) Handle(method, path string, handleList ...Handle) { +func (d *Cod) Handle(method, path string, handlerList ...Handler) { d.Router.Handle(method, path, func(resp http.ResponseWriter, req *http.Request, params httprouter.Params) { - c := NewContext(resp, req) - defer ReleaseContext(c) + c := d.ctxPool.Get().(*Context) + d.fillContext(c, resp, req) c.Params = make(map[string]string) for _, item := range params { c.Params[item.Key] = item.Value @@ -83,96 +116,106 @@ func (d *Cod) Handle(method, path string, handleList ...Handle) { c.cod = d mids := d.Middlewares maxMid := len(mids) + maxNext := maxMid + len(handlerList) index := -1 c.Next = func() error { index++ - // 如果已到最后,执行handle list + // 如果调用过多的next,则会导致panic + if index >= maxNext { + panic(ErrOutOfHandlerRange) + } + // 如果已执行完公共添加的中间件,执行handler list if index >= maxMid { - return handleList[index-maxMid](c) + return handlerList[index-maxMid](c) } return mids[index](c) } err := c.Next() if err != nil { d.EmitError(c, err) - fn := d.ErrorHandle - if fn == nil { - fn = d.Error - } - fn(err, c) + d.Error(c, err) } + d.ctxPool.Put(c) }) } // GET add http get method handle -func (d *Cod) GET(path string, handleList ...Handle) { - d.Handle(http.MethodGet, path, handleList...) +func (d *Cod) GET(path string, handlerList ...Handler) { + d.Handle(http.MethodGet, path, handlerList...) } // POST add http post method handle -func (d *Cod) POST(path string, handleList ...Handle) { - d.Handle(http.MethodPost, path, handleList...) +func (d *Cod) POST(path string, handlerList ...Handler) { + d.Handle(http.MethodPost, path, handlerList...) } // PUT add http put method handle -func (d *Cod) PUT(path string, handleList ...Handle) { - d.Handle(http.MethodPut, path, handleList...) +func (d *Cod) PUT(path string, handlerList ...Handler) { + d.Handle(http.MethodPut, path, handlerList...) } // PATCH add http patch method handle -func (d *Cod) PATCH(path string, handleList ...Handle) { - d.Handle(http.MethodPatch, path, handleList...) +func (d *Cod) PATCH(path string, handlerList ...Handler) { + d.Handle(http.MethodPatch, path, handlerList...) } // DELETE add http delete method handle -func (d *Cod) DELETE(path string, handleList ...Handle) { - d.Handle(http.MethodDelete, path, handleList...) +func (d *Cod) DELETE(path string, handlerList ...Handler) { + d.Handle(http.MethodDelete, path, handlerList...) } // HEAD add http head method handle -func (d *Cod) HEAD(path string, handleList ...Handle) { - d.Handle(http.MethodHead, path, handleList...) +func (d *Cod) HEAD(path string, handlerList ...Handler) { + d.Handle(http.MethodHead, path, handlerList...) } // OPTIONS add http options method handle -func (d *Cod) OPTIONS(path string, handleList ...Handle) { - d.Handle(http.MethodOptions, path, handleList...) +func (d *Cod) OPTIONS(path string, handlerList ...Handler) { + d.Handle(http.MethodOptions, path, handlerList...) } // TRACE add http trace method handle -func (d *Cod) TRACE(path string, handleList ...Handle) { - d.Handle(http.MethodTrace, path, handleList...) +func (d *Cod) TRACE(path string, handlerList ...Handler) { + d.Handle(http.MethodTrace, path, handlerList...) } // ALL add http all method handle -func (d *Cod) ALL(path string, handleList ...Handle) { +func (d *Cod) ALL(path string, handlerList ...Handler) { for _, method := range methods { - d.Handle(method, path, handleList...) + d.Handle(method, path, handlerList...) } } // Group create a http handle group -func (d *Cod) Group(path string, handleList ...Handle) (g *Group) { +func (d *Cod) Group(path string, handlerList ...Handler) (g *Group) { return &Group{ - Cod: d, - Path: path, - HandleList: handleList, + Cod: d, + Path: path, + HandlerList: handlerList, } } // Use add middleware function handle -func (d *Cod) Use(handleList ...Handle) { - d.Middlewares = append(d.Middlewares, handleList...) +func (d *Cod) Use(handlerList ...Handler) { + d.Middlewares = append(d.Middlewares, handlerList...) } // NotFound not found handle func (d *Cod) NotFound(resp http.ResponseWriter, req *http.Request) { + if d.NotFoundHandler != nil { + d.NotFoundHandler(resp, req) + return + } resp.WriteHeader(http.StatusNotFound) resp.Write([]byte("Not found")) } // Error error handle -func (d *Cod) Error(err error, c *Context) { +func (d *Cod) Error(c *Context, err error) { + if d.ErrorHandler != nil { + d.ErrorHandler(c, err) + return + } resp := c.Response he, ok := err.(*HTTPError) if ok { @@ -200,73 +243,73 @@ func (d *Cod) OnError(ln ErrorLinstener) { d.errorLinsteners = append(d.errorLinsteners, ln) } -func (g *Group) merge(s2 []Handle) []Handle { - s1 := g.HandleList - fns := make([]Handle, len(s1)+len(s2)) +func (g *Group) merge(s2 []Handler) []Handler { + s1 := g.HandlerList + fns := make([]Handler, len(s1)+len(s2)) copy(fns, s1) copy(fns[len(s1):], s2) return fns } // GET add group http get method handl -func (g *Group) GET(path string, handleList ...Handle) { +func (g *Group) GET(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.GET(p, fns...) } // POST add group http post method handl -func (g *Group) POST(path string, handleList ...Handle) { +func (g *Group) POST(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.POST(p, fns...) } // PUT add group http put method handl -func (g *Group) PUT(path string, handleList ...Handle) { +func (g *Group) PUT(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.PUT(p, fns...) } // PATCH add group http patch method handl -func (g *Group) PATCH(path string, handleList ...Handle) { +func (g *Group) PATCH(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.PATCH(p, fns...) } // DELETE add group http delete method handl -func (g *Group) DELETE(path string, handleList ...Handle) { +func (g *Group) DELETE(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.DELETE(p, fns...) } // HEAD add group http head method handl -func (g *Group) HEAD(path string, handleList ...Handle) { +func (g *Group) HEAD(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.HEAD(p, fns...) } // OPTIONS add group http options method handl -func (g *Group) OPTIONS(path string, handleList ...Handle) { +func (g *Group) OPTIONS(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.OPTIONS(p, fns...) } // TRACE add group http trace method handl -func (g *Group) TRACE(path string, handleList ...Handle) { +func (g *Group) TRACE(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.TRACE(p, fns...) } // ALL add group http all method handl -func (g *Group) ALL(path string, handleList ...Handle) { +func (g *Group) ALL(path string, handlerList ...Handler) { p := g.Path + path - fns := g.merge(handleList) + fns := g.merge(handlerList) g.Cod.ALL(p, fns...) } diff --git a/cod_test.go b/cod_test.go index fb848dd..07045ed 100644 --- a/cod_test.go +++ b/cod_test.go @@ -17,6 +17,13 @@ func TestListenAndServe(t *testing.T) { } } +func TestNewWithoutServer(t *testing.T) { + d := NewWithoutServer() + if d.Server != nil { + t.Fatalf("new without server fail") + } +} + func TestHandle(t *testing.T) { d := New() t.Run("group", func(t *testing.T) { @@ -239,6 +246,41 @@ func TestHandle(t *testing.T) { }) } +func TestErrorHandler(t *testing.T) { + d := New() + d.GET("/", func(c *Context) error { + return errors.New("abc") + }) + + done := false + d.ErrorHandler = func(c *Context, err error) { + done = true + } + req := httptest.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + d.ServeHTTP(resp, req) + if !done || resp.Code != http.StatusOK { + t.Fatalf("custom error handler is not called") + } +} + +func TestNotFoundHandler(t *testing.T) { + d := New() + d.GET("/", func(c *Context) error { + return nil + }) + done := false + d.NotFoundHandler = func(resp http.ResponseWriter, req *http.Request) { + done = true + } + req := httptest.NewRequest("GET", "/users/me", nil) + resp := httptest.NewRecorder() + d.ServeHTTP(resp, req) + if !done { + t.Fatalf("custom not found handler is not called") + } +} + func TestOnError(t *testing.T) { d := New() c := NewContext(nil, nil) diff --git a/context.go b/context.go index 90de347..8f3c2bb 100644 --- a/context.go +++ b/context.go @@ -5,7 +5,6 @@ import ( "net/http" "strconv" "strings" - "sync" "time" ) @@ -15,7 +14,7 @@ type ( Request *http.Request Response http.ResponseWriter Headers http.Header - // ID request id + // ID context id ID string // Route route path Route string @@ -42,14 +41,16 @@ type ( func (c *Context) Reset() { c.Request = nil c.Response = nil + c.Headers = nil + c.ID = "" c.Route = "" c.Next = nil c.Params = nil - c.RequestBody = nil c.StatusCode = 0 + c.Body = nil c.BodyBytes = nil + c.RequestBody = nil c.m = nil - c.ID = "" } // RealIP get the real ip @@ -106,7 +107,7 @@ func (c *Context) Redirect(code int, url string) (err error) { return } -// Set set the value +// Set store the value in the context func (c *Context) Set(key string, value interface{}) { if c.m == nil { c.m = make(map[string]interface{}) @@ -140,7 +141,7 @@ func (c *Context) SetCookie(cookie *http.Cookie) error { return nil } -// Get get the value +// Get get the value from context func (c *Context) Get(key string) interface{} { if c.m == nil { return nil @@ -154,17 +155,17 @@ func (c *Context) NoContent() { c.Body = nil } -// NoCache set http no cache +// NoCache set http response no cache func (c *Context) NoCache() { c.SetHeader(HeaderCacheControl, "no-cache, max-age=0") } -// NoStore set http no store +// NoStore set http response no store func (c *Context) NoStore() { c.SetHeader(HeaderCacheControl, "no-store") } -// CacheMaxAge set http cache for max age +// CacheMaxAge set http response to cache for max age func (c *Context) CacheMaxAge(age string) { d, _ := time.ParseDuration(age) cache := "public, max-age=" + strconv.Itoa(int(d.Seconds())) @@ -182,16 +183,9 @@ func (c *Context) Cod() *Cod { return c.cod } -var contextPool = sync.Pool{ - New: func() interface{} { - return &Context{} - }, -} - // NewContext new a context func NewContext(resp http.ResponseWriter, req *http.Request) *Context { - c := contextPool.Get().(*Context) - c.Reset() + c := &Context{} c.Request = req c.Response = resp if resp != nil { @@ -199,8 +193,3 @@ func NewContext(resp http.ResponseWriter, req *http.Request) *Context { } return c } - -// ReleaseContext release context -func ReleaseContext(c *Context) { - contextPool.Put(c) -} diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index c2317d8..5ac95e1 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -24,6 +24,7 @@ type ( ) var ( + // errUnauthorized unauthorized err errUnauthorized = getBasicAuthError("unAuthorized", http.StatusUnauthorized) ) @@ -36,7 +37,7 @@ func getBasicAuthError(message string, statusCode int) *cod.HTTPError { } // NewBasicAuth new basic auth -func NewBasicAuth(config BasicAuthConfig) cod.Handle { +func NewBasicAuth(config BasicAuthConfig) cod.Handler { if config.Validate == nil { panic("require validate function") } diff --git a/middleware/body_parser.go b/middleware/body_parser.go index 2d9880a..7fa3827 100644 --- a/middleware/body_parser.go +++ b/middleware/body_parser.go @@ -36,7 +36,7 @@ var ( ) // NewBodyParser new json parser -func NewBodyParser(config BodyParserConfig) cod.Handle { +func NewBodyParser(config BodyParserConfig) cod.Handler { limit := defaultRequestBodyLimit if config.Limit != 0 { limit = config.Limit diff --git a/middleware/logger.go b/middleware/logger.go index 547ba1e..45509dd 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -235,7 +235,7 @@ func format(c *cod.Context, tags []*Tag, startedAt time.Time) string { } // NewLogger create a logger -func NewLogger(config LoggerConfig) cod.Handle { +func NewLogger(config LoggerConfig) cod.Handler { if config.Format == "" { panic("logger require format") } diff --git a/middleware/recover.go b/middleware/recover.go index b217fdc..faf6a73 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -8,9 +8,11 @@ import ( ) // NewRecover new recover -func NewRecover() cod.Handle { +func NewRecover() cod.Handler { return func(c *cod.Context) error { defer func() { + // 此recover只是示例,在实际使用中, + // 需要针对实际需求调整,如对于每个recover增加邮件通知等 if r := recover(); r != nil { err, ok := r.(error) if !ok { diff --git a/middleware/responder.go b/middleware/responder.go index 993fd0b..06dd248 100644 --- a/middleware/responder.go +++ b/middleware/responder.go @@ -18,7 +18,7 @@ type ( ) // NewResponder create a responder -func NewResponder(config ResponderConfig) cod.Handle { +func NewResponder(config ResponderConfig) cod.Handler { return func(c *cod.Context) error { e := c.Next() var err *cod.HTTPError @@ -41,8 +41,8 @@ func NewResponder(config ResponderConfig) cod.Handle { respHeader := c.Headers ct := cod.HeaderContentType + // 从出错中获取响应数据,响应状态码 if err != nil { - c.StatusCode = err.StatusCode c.Body, _ = json.Marshal(err) respHeader.Set(ct, cod.MIMEApplicationJSON)