Skip to content

Commit

Permalink
Refactor generic middleware stack to run before route is resolved and…
Browse files Browse the repository at this point in the history
… per route middlewares (#50)

* Refactor generic middlewares to run before routes are resolved

 - refactor generic middlewares to run before routes are resolved
 - add per route middleware mechanism

Signed-off-by: George Georgiou <[email protected]>

* Update README

Signed-off-by: George Georgiou <[email protected]>

* Fix linting

Signed-off-by: George Georgiou <[email protected]>

* Fix import order

Signed-off-by: George Georgiou <[email protected]>

* Codecov config for circleci

Signed-off-by: George Georgiou <[email protected]>
  • Loading branch information
georgegg authored and sotirispl committed May 27, 2019
1 parent 0aac399 commit 09c9de7
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 195 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
- run:
name: Running with test coverage and send to codecov
command: |
export CODECOV_TOKEN="90085daa-03f9-4ada-82e4-c24c50706e6f"
go get -u github.com/jstemmer/go-junit-report
mkdir -p $TEST_RESULTS
trap "go-junit-report <${TEST_RESULTS}/go-test.out > ${TEST_RESULTS}/go-test-report.xml" EXIT
Expand Down
22 changes: 17 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ A `MiddlewareFunc` preserves the default net/http middleware pattern.
You can create new middleware functions and pass them to Service to be chained on all routes in the default Http Component.

```go
type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc
type MiddlewareFunc func(next http.Handler) http.Handler

// Setup a simple middleware for CORS
newMiddleware := func(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
newMiddleware := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
// Next
h(w, r)
}
h.ServeHTTP(w, r)
})
}
```

Expand Down Expand Up @@ -172,6 +172,18 @@ The `Response` model contains the following properties (which are provided when

- Payload, which may hold a struct of type `interface{}`

### Middlewares per Route

Middlewares can also run per routes using the processor as Handler.
So using the `Route` helpers:

```go
// A route with ...MiddlewareFunc that will run for this route only + tracing
route := NewRoute("/index", "GET" ProcessorFunc, true, ...MiddlewareFunc)
// A route with ...MiddlewareFunc that will run for this route only + auth + tracing
routeWithAuth := NewAuthRoute("/index", "GET" ProcessorFunc, true, Authendicator, ...MiddlewareFunc)
```

### Asynchronous

The implementation of the async processor follows exactly the same principle as the sync processor.
Expand Down
8 changes: 4 additions & 4 deletions examples/first/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ func main() {
}

// Setup a simple CORS middleware
middlewareCors := func(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
middlewareCors := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Methods", "GET, POST")
w.Header().Add("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type")
w.Header().Add("Access-Control-Allow-Credentials", "Allow")
h(w, r)
}
h.ServeHTTP(w, r)
})
}
sig := patron.SIGHUP(func() {
fmt.Println("exit gracefully...")
Expand Down
11 changes: 6 additions & 5 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import (
phttp "github.com/thebeatapp/patron/sync/http"
)

func middleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r)
})
}

func TestRoutes(t *testing.T) {
type args struct {
rr []phttp.Route
Expand Down Expand Up @@ -36,11 +42,6 @@ func TestRoutes(t *testing.T) {
}

func TestMiddlewares(t *testing.T) {
middleware := func(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h(w, r)
}
}
type args struct {
mm []phttp.MiddlewareFunc
}
Expand Down
8 changes: 4 additions & 4 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (

func TestNewServer(t *testing.T) {
route := phttp.NewRoute("/", "GET", nil, true, nil)
middleware := func(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h(w, r)
}
middleware := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r)
})
}
type args struct {
name string
Expand Down
18 changes: 12 additions & 6 deletions sync/http/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ func (c *Component) Info() map[string]interface{} {
func (c *Component) Run(ctx context.Context) error {
c.Lock()
log.Debug("applying tracing to routes")
for i := 0; i < len(c.routes); i++ {
c.routes[i].Handler = MiddlewareDefaults(c.routes[i].Trace, c.routes[i].Auth, c.routes[i].Pattern, c.routes[i].Handler)
c.routes[i].Handler = MiddlewareChain(c.routes[i].Handler, c.middlewares...)
}
chFail := make(chan error)
srv := c.createHTTPServer()
go c.listenAndServe(srv, chFail)
Expand Down Expand Up @@ -106,15 +102,25 @@ func (c *Component) createHTTPServer() *http.Server {
log.Debugf("adding %d routes", len(c.routes))
router := httprouter.New()
for _, route := range c.routes {
router.HandlerFunc(route.Method, route.Pattern, route.Handler)
if len(route.Middlewares) > 0 {
h := MiddlewareChain(route.Handler, route.Middlewares...)
router.Handler(route.Method, route.Pattern, h)
} else {
router.HandlerFunc(route.Method, route.Pattern, route.Handler)
}

log.Debugf("added route %s %s", route.Method, route.Pattern)
}
// Add first the recovery middleware to ensure that no panic occur.
routerAfterMiddleware := MiddlewareChain(router, NewRecoveryMiddleware())
routerAfterMiddleware = MiddlewareChain(routerAfterMiddleware, c.middlewares...)

return &http.Server{
Addr: fmt.Sprintf(":%d", c.httpPort),
ReadTimeout: c.httpReadTimeout,
WriteTimeout: c.httpWriteTimeout,
IdleTimeout: httpIdleTimeout,
Handler: router,
Handler: routerAfterMiddleware,
}
}

Expand Down
112 changes: 54 additions & 58 deletions sync/http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,73 +52,69 @@ func (w *responseWriter) WriteHeader(code int) {
w.statusHeaderWritten = true
}

// MiddlewareChain chains middlewares to a handler func.
func MiddlewareChain(f http.HandlerFunc, mm ...MiddlewareFunc) http.HandlerFunc {
for i := len(mm) - 1; i >= 0; i-- {
f = mm[i](f)
}
return f
}

// MiddlewareDefaults chains all default middlewares to handler function and returns the handler func.
func MiddlewareDefaults(trace bool, auth auth.Authenticator, path string, next http.HandlerFunc) http.HandlerFunc {
next = recoveryMiddleware(next)
if auth != nil {
next = authMiddleware(auth, next)
}
if trace {
next = tracingMiddleware(path, next)
}
return next
}

// MiddlewareFunc type declaration of middleware func.
type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc

func tracingMiddleware(path string, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sp, r := trace.HTTPSpan(path, r)
lw := newResponseWriter(w)
next(lw, r)
trace.FinishHTTPSpan(sp, lw.Status())
type MiddlewareFunc func(next http.Handler) http.Handler

// NewRecoveryMiddleware creates a MiddlewareFunc that ensures recovery and no panic.
func NewRecoveryMiddleware() MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
var err error
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("unknown panic")
}
_ = err
log.Errorf("recovering from an error %v", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
}

func recoveryMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
var err error
switch x := r.(type) {
case string:
err = errors.New(x)
case error:
err = x
default:
err = errors.New("unknown panic")
}
_ = err
log.Errorf("recovering from an error %v", err)
// NewAuthMiddleware creates a MiddlewareFunc that implements authentication using an Authenticator.
func NewAuthMiddleware(auth auth.Authenticator) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authenticated, err := auth.Authenticate(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

if !authenticated {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
}()
next(w, r)
next.ServeHTTP(w, r)
})
}
}

func authMiddleware(auth auth.Authenticator, next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
authenticated, err := auth.Authenticate(r)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

if !authenticated {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
// NewTracingMiddleware creates a MiddlewareFunc that continues a tracing span and finishes it.
func NewTracingMiddleware(path string) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sp, r := trace.HTTPSpan(path, r)
lw := newResponseWriter(w)
next.ServeHTTP(w, r)
trace.FinishHTTPSpan(sp, lw.Status())
})
}
}

next(w, r)
// MiddlewareChain chains middlewares to a handler func.
func MiddlewareChain(f http.Handler, mm ...MiddlewareFunc) http.Handler {
for i := len(mm) - 1; i >= 0; i-- {
f = mm[i](f)
}
return f
}
Loading

0 comments on commit 09c9de7

Please sign in to comment.