From 09c9de77fecb8a2521403b9648b7a9c576ee2a92 Mon Sep 17 00:00:00 2001 From: George Georgiou Date: Mon, 27 May 2019 10:57:20 +0300 Subject: [PATCH] Refactor generic middleware stack to run before route is resolved and per route middlewares (#50) * Refactor generic middlewares to run before routes are resolved - refactor generic middlewares to run before routes are resolved - add per route middleware mechanism Signed-off-by: George Georgiou * Update README Signed-off-by: George Georgiou * Fix linting Signed-off-by: George Georgiou * Fix import order Signed-off-by: George Georgiou * Codecov config for circleci Signed-off-by: George Georgiou --- .circleci/config.yml | 1 + README.md | 22 +++++-- examples/first/main.go | 8 +-- option_test.go | 11 ++-- service_test.go | 8 +-- sync/http/component.go | 18 ++++-- sync/http/middleware.go | 112 ++++++++++++++++---------------- sync/http/middleware_test.go | 122 ++++++++++++++--------------------- sync/http/option_test.go | 2 +- sync/http/route.go | 106 +++++++++++++++++++----------- sync/http/route_test.go | 34 +++++++++- 11 files changed, 249 insertions(+), 195 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 2846799fa..76c0e3de3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -14,6 +14,7 @@ jobs: - run: name: Running with test coverage and send to codecov command: | + export CODECOV_TOKEN="90085daa-03f9-4ada-82e4-c24c50706e6f" go get -u github.com/jstemmer/go-junit-report mkdir -p $TEST_RESULTS trap "go-junit-report <${TEST_RESULTS}/go-test.out > ${TEST_RESULTS}/go-test-report.xml" EXIT diff --git a/README.md b/README.md index b5f9c55ee..28ba40ae8 100644 --- a/README.md +++ b/README.md @@ -122,15 +122,15 @@ A `MiddlewareFunc` preserves the default net/http middleware pattern. You can create new middleware functions and pass them to Service to be chained on all routes in the default Http Component. ```go -type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc +type MiddlewareFunc func(next http.Handler) http.Handler // Setup a simple middleware for CORS -newMiddleware := func(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +newMiddleware := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Access-Control-Allow-Origin", "*") // Next - h(w, r) - } + h.ServeHTTP(w, r) + }) } ``` @@ -172,6 +172,18 @@ The `Response` model contains the following properties (which are provided when - Payload, which may hold a struct of type `interface{}` +### Middlewares per Route + +Middlewares can also run per routes using the processor as Handler. +So using the `Route` helpers: + +```go +// A route with ...MiddlewareFunc that will run for this route only + tracing +route := NewRoute("/index", "GET" ProcessorFunc, true, ...MiddlewareFunc) +// A route with ...MiddlewareFunc that will run for this route only + auth + tracing +routeWithAuth := NewAuthRoute("/index", "GET" ProcessorFunc, true, Authendicator, ...MiddlewareFunc) +``` + ### Asynchronous The implementation of the async processor follows exactly the same principle as the sync processor. diff --git a/examples/first/main.go b/examples/first/main.go index 890c45146..af834348a 100644 --- a/examples/first/main.go +++ b/examples/first/main.go @@ -47,14 +47,14 @@ func main() { } // Setup a simple CORS middleware - middlewareCors := func(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { + middlewareCors := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Access-Control-Allow-Origin", "*") w.Header().Add("Access-Control-Allow-Methods", "GET, POST") w.Header().Add("Access-Control-Allow-Headers", "Origin, Authorization, Content-Type") w.Header().Add("Access-Control-Allow-Credentials", "Allow") - h(w, r) - } + h.ServeHTTP(w, r) + }) } sig := patron.SIGHUP(func() { fmt.Println("exit gracefully...") diff --git a/option_test.go b/option_test.go index 1c5f833c6..266a65487 100644 --- a/option_test.go +++ b/option_test.go @@ -8,6 +8,12 @@ import ( phttp "github.com/thebeatapp/patron/sync/http" ) +func middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) +} + func TestRoutes(t *testing.T) { type args struct { rr []phttp.Route @@ -36,11 +42,6 @@ func TestRoutes(t *testing.T) { } func TestMiddlewares(t *testing.T) { - middleware := func(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - h(w, r) - } - } type args struct { mm []phttp.MiddlewareFunc } diff --git a/service_test.go b/service_test.go index e9a7e3bb5..beab8c5fb 100644 --- a/service_test.go +++ b/service_test.go @@ -15,10 +15,10 @@ import ( func TestNewServer(t *testing.T) { route := phttp.NewRoute("/", "GET", nil, true, nil) - middleware := func(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - h(w, r) - } + middleware := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h.ServeHTTP(w, r) + }) } type args struct { name string diff --git a/sync/http/component.go b/sync/http/component.go index af86bba2b..2beb7ee64 100644 --- a/sync/http/component.go +++ b/sync/http/component.go @@ -74,10 +74,6 @@ func (c *Component) Info() map[string]interface{} { func (c *Component) Run(ctx context.Context) error { c.Lock() log.Debug("applying tracing to routes") - for i := 0; i < len(c.routes); i++ { - c.routes[i].Handler = MiddlewareDefaults(c.routes[i].Trace, c.routes[i].Auth, c.routes[i].Pattern, c.routes[i].Handler) - c.routes[i].Handler = MiddlewareChain(c.routes[i].Handler, c.middlewares...) - } chFail := make(chan error) srv := c.createHTTPServer() go c.listenAndServe(srv, chFail) @@ -106,15 +102,25 @@ func (c *Component) createHTTPServer() *http.Server { log.Debugf("adding %d routes", len(c.routes)) router := httprouter.New() for _, route := range c.routes { - router.HandlerFunc(route.Method, route.Pattern, route.Handler) + if len(route.Middlewares) > 0 { + h := MiddlewareChain(route.Handler, route.Middlewares...) + router.Handler(route.Method, route.Pattern, h) + } else { + router.HandlerFunc(route.Method, route.Pattern, route.Handler) + } + log.Debugf("added route %s %s", route.Method, route.Pattern) } + // Add first the recovery middleware to ensure that no panic occur. + routerAfterMiddleware := MiddlewareChain(router, NewRecoveryMiddleware()) + routerAfterMiddleware = MiddlewareChain(routerAfterMiddleware, c.middlewares...) + return &http.Server{ Addr: fmt.Sprintf(":%d", c.httpPort), ReadTimeout: c.httpReadTimeout, WriteTimeout: c.httpWriteTimeout, IdleTimeout: httpIdleTimeout, - Handler: router, + Handler: routerAfterMiddleware, } } diff --git a/sync/http/middleware.go b/sync/http/middleware.go index 9d6219c6f..c55a8b3f7 100644 --- a/sync/http/middleware.go +++ b/sync/http/middleware.go @@ -52,73 +52,69 @@ func (w *responseWriter) WriteHeader(code int) { w.statusHeaderWritten = true } -// MiddlewareChain chains middlewares to a handler func. -func MiddlewareChain(f http.HandlerFunc, mm ...MiddlewareFunc) http.HandlerFunc { - for i := len(mm) - 1; i >= 0; i-- { - f = mm[i](f) - } - return f -} - -// MiddlewareDefaults chains all default middlewares to handler function and returns the handler func. -func MiddlewareDefaults(trace bool, auth auth.Authenticator, path string, next http.HandlerFunc) http.HandlerFunc { - next = recoveryMiddleware(next) - if auth != nil { - next = authMiddleware(auth, next) - } - if trace { - next = tracingMiddleware(path, next) - } - return next -} - // MiddlewareFunc type declaration of middleware func. -type MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc - -func tracingMiddleware(path string, next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - sp, r := trace.HTTPSpan(path, r) - lw := newResponseWriter(w) - next(lw, r) - trace.FinishHTTPSpan(sp, lw.Status()) +type MiddlewareFunc func(next http.Handler) http.Handler + +// NewRecoveryMiddleware creates a MiddlewareFunc that ensures recovery and no panic. +func NewRecoveryMiddleware() MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + var err error + switch x := r.(type) { + case string: + err = errors.New(x) + case error: + err = x + default: + err = errors.New("unknown panic") + } + _ = err + log.Errorf("recovering from an error %v", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) } } -func recoveryMiddleware(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - defer func() { - if r := recover(); r != nil { - var err error - switch x := r.(type) { - case string: - err = errors.New(x) - case error: - err = x - default: - err = errors.New("unknown panic") - } - _ = err - log.Errorf("recovering from an error %v", err) +// NewAuthMiddleware creates a MiddlewareFunc that implements authentication using an Authenticator. +func NewAuthMiddleware(auth auth.Authenticator) MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authenticated, err := auth.Authenticate(r) + if err != nil { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + if !authenticated { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return } - }() - next(w, r) + next.ServeHTTP(w, r) + }) } } -func authMiddleware(auth auth.Authenticator, next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - authenticated, err := auth.Authenticate(r) - if err != nil { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } - - if !authenticated { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } +// NewTracingMiddleware creates a MiddlewareFunc that continues a tracing span and finishes it. +func NewTracingMiddleware(path string) MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sp, r := trace.HTTPSpan(path, r) + lw := newResponseWriter(w) + next.ServeHTTP(w, r) + trace.FinishHTTPSpan(sp, lw.Status()) + }) + } +} - next(w, r) +// MiddlewareChain chains middlewares to a handler func. +func MiddlewareChain(f http.Handler, mm ...MiddlewareFunc) http.Handler { + for i := len(mm) - 1; i >= 0; i-- { + f = mm[i](f) } + return f } diff --git a/sync/http/middleware_test.go b/sync/http/middleware_test.go index 6f541bea9..dec27d015 100644 --- a/sync/http/middleware_test.go +++ b/sync/http/middleware_test.go @@ -7,76 +7,76 @@ import ( "github.com/stretchr/testify/assert" "github.com/thebeatapp/patron/errors" - "github.com/thebeatapp/patron/sync/http/auth" ) -func testHandle(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(202) -} - -func testPanicHandleString(w http.ResponseWriter, r *http.Request) { - panic("test") -} - -func testPanicHandleError(w http.ResponseWriter, r *http.Request) { - panic(errors.New("TEST")) -} - -func testPanicHandleInt(w http.ResponseWriter, r *http.Request) { - panic(1000) -} - // A middleware generator that tags resp for assertions func tagMiddleware(tag string) MiddlewareFunc { - return func(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(tag)) - h(w, r) - } + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(tag)) + //next + h.ServeHTTP(w, r) + }) } } -func TestMiddlewareDefaults(t *testing.T) { +// Panic middleware to test recovery middleware +func panicMiddleware(v interface{}) MiddlewareFunc { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(v) + }) + } +} + +func TestMiddlewareChain(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(202) + }) + r, err := http.NewRequest("POST", "/test", nil) assert.NoError(t, err) + t1 := tagMiddleware("t1\n") + t2 := tagMiddleware("t2\n") + t3 := tagMiddleware("t3\n") + type args struct { - next http.HandlerFunc - trace bool - auth auth.Authenticator + next http.Handler + mws []MiddlewareFunc } tests := []struct { name string args args expectedCode int + expectedBody string }{ - {"middleware success", args{next: testHandle, trace: false, auth: &MockAuthenticator{success: true}}, 202}, - {"middleware trace success", args{next: testHandle, trace: true, auth: &MockAuthenticator{success: true}}, 202}, - {"middleware panic string", args{next: testPanicHandleString, trace: true, auth: &MockAuthenticator{success: true}}, 500}, - {"middleware panic error", args{next: testPanicHandleError, trace: true, auth: &MockAuthenticator{success: true}}, 500}, - {"middleware panic other", args{next: testPanicHandleInt, trace: true, auth: &MockAuthenticator{success: true}}, 500}, - {"middleware auth error", args{next: testPanicHandleInt, trace: true, auth: &MockAuthenticator{err: errors.New("TEST")}}, 500}, - {"middleware auth failure", args{next: testPanicHandleInt, trace: true, auth: &MockAuthenticator{success: false}}, 401}, + {"middleware 1,2,3 and finish", args{next: handler, mws: []MiddlewareFunc{t1, t2, t3}}, 202, "t1\nt2\nt3\n"}, + {"middleware 1,2 and finish", args{next: handler, mws: []MiddlewareFunc{t1, t2}}, 202, "t1\nt2\n"}, + {"no middleware and finish", args{next: handler, mws: []MiddlewareFunc{}}, 202, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := httptest.NewRecorder() - MiddlewareDefaults(tt.args.trace, tt.args.auth, "path", tt.args.next)(resp, r) - assert.Equal(t, tt.expectedCode, resp.Code) + rc := httptest.NewRecorder() + rw := newResponseWriter(rc) + tt.args.next = MiddlewareChain(tt.args.next, tt.args.mws...) + tt.args.next.ServeHTTP(rw, r) + assert.Equal(t, tt.expectedCode, rw.Status()) + assert.Equal(t, tt.expectedBody, rc.Body.String()) }) } } -func TestMiddlewareChain(t *testing.T) { +func TestMiddlewares(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(202) + }) + r, err := http.NewRequest("POST", "/test", nil) assert.NoError(t, err) - t1 := tagMiddleware("t1\n") - t2 := tagMiddleware("t2\n") - t3 := tagMiddleware("t3\n") - type args struct { - next http.HandlerFunc + next http.Handler mws []MiddlewareFunc } tests := []struct { @@ -85,16 +85,20 @@ func TestMiddlewareChain(t *testing.T) { expectedCode int expectedBody string }{ - {"middleware 1,2,3 and finish", args{next: testHandle, mws: []MiddlewareFunc{t1, t2, t3}}, 202, "t1\nt2\nt3\n"}, - {"middleware 1,2 and finish", args{next: testHandle, mws: []MiddlewareFunc{t1, t2}}, 202, "t1\nt2\n"}, - {"no middleware and finish", args{next: testHandle, mws: []MiddlewareFunc{}}, 202, ""}, + {"auth middleware success", args{next: handler, mws: []MiddlewareFunc{NewAuthMiddleware(&MockAuthenticator{success: true})}}, 202, ""}, + {"auth middleware false", args{next: handler, mws: []MiddlewareFunc{NewAuthMiddleware(&MockAuthenticator{success: false})}}, 401, "Unauthorized\n"}, + {"auth middleware error", args{next: handler, mws: []MiddlewareFunc{NewAuthMiddleware(&MockAuthenticator{err: errors.New("auth error")})}}, 500, "Internal Server Error\n"}, + {"tracing middleware", args{next: handler, mws: []MiddlewareFunc{NewTracingMiddleware("/index")}}, 202, ""}, + {"recovery middleware from panic 1", args{next: handler, mws: []MiddlewareFunc{NewRecoveryMiddleware(), panicMiddleware("error")}}, 500, "Internal Server Error\n"}, + {"recovery middleware from panic 2", args{next: handler, mws: []MiddlewareFunc{NewRecoveryMiddleware(), panicMiddleware(errors.New("error"))}}, 500, "Internal Server Error\n"}, + {"recovery middleware from panic 3", args{next: handler, mws: []MiddlewareFunc{NewRecoveryMiddleware(), panicMiddleware(-1)}}, 500, "Internal Server Error\n"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rc := httptest.NewRecorder() rw := newResponseWriter(rc) tt.args.next = MiddlewareChain(tt.args.next, tt.args.mws...) - tt.args.next(rw, r) + tt.args.next.ServeHTTP(rw, r) assert.Equal(t, tt.expectedCode, rw.Status()) assert.Equal(t, tt.expectedBody, rc.Body.String()) }) @@ -114,29 +118,3 @@ func TestResponseWriter(t *testing.T) { assert.True(t, rw.statusHeaderWritten, "expected to be true") assert.Equal(t, "test", rc.Body.String(), "body expected to be test but was %s", rc.Body.String()) } - -// func Test_authMiddleware(t *testing.T) { -// r, err := http.NewRequest("POST", "/test", nil) -// assert.NoError(t, err) - -// type args struct { -// auth Authenticator -// next http.HandlerFunc -// resp *httptest.ResponseRecorder -// } -// tests := []struct { -// name string -// args args -// expectedCode int -// }{ -// {name: "authenticated", args: args{auth: &MockAuthenticator{success: true}}, expectedCode: 202}, -// {name: "unauthorized", args: args{auth: &MockAuthenticator{success: false}}, expectedCode: 401}, -// {name: "error", args: args{auth: &MockAuthenticator{err: errors.New("TEST")}}, expectedCode: 500}, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// authMiddleware(tt.args.auth, testHandle)(tt.args.resp, r) -// assert.Equal(t, tt.expectedCode, tt.args.resp.Code) -// }) -// } -// } diff --git a/sync/http/option_test.go b/sync/http/option_test.go index 1f068e654..dc2c33012 100644 --- a/sync/http/option_test.go +++ b/sync/http/option_test.go @@ -62,7 +62,7 @@ func TestSetMiddlewares(t *testing.T) { mm []MiddlewareFunc wantErr bool }{ - {"success", []MiddlewareFunc{func(next http.HandlerFunc) http.HandlerFunc { return next }}, false}, + {"success", []MiddlewareFunc{func(next http.Handler) http.Handler { return next }}, false}, {"error for empty middlewares", []MiddlewareFunc{}, true}, {"error for nil middlewares", nil, true}, } diff --git a/sync/http/route.go b/sync/http/route.go index 254a80bfb..c672d7c14 100644 --- a/sync/http/route.go +++ b/sync/http/route.go @@ -9,94 +9,122 @@ import ( // Route definition of a HTTP route. type Route struct { - Pattern string - Method string - Handler http.HandlerFunc - Trace bool - Auth auth.Authenticator + Pattern string + Method string + Handler http.HandlerFunc + Trace bool + Auth auth.Authenticator + Middlewares []MiddlewareFunc } // NewGetRoute creates a new GET route from a generic handler. -func NewGetRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodGet, pr, trace, nil) +func NewGetRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodGet, pr, trace, nil, mm...) } // NewPostRoute creates a new POST route from a generic handler. -func NewPostRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodPost, pr, trace, nil) +func NewPostRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodPost, pr, trace, nil, mm...) } // NewPutRoute creates a new PUT route from a generic handler. -func NewPutRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodPut, pr, trace, nil) +func NewPutRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodPut, pr, trace, nil, mm...) } // NewDeleteRoute creates a new DELETE route from a generic handler. -func NewDeleteRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodDelete, pr, trace, nil) +func NewDeleteRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodDelete, pr, trace, nil, mm...) } // NewPatchRoute creates a new PATCH route from a generic handler. -func NewPatchRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodPatch, pr, trace, nil) +func NewPatchRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodPatch, pr, trace, nil, mm...) } // NewHeadRoute creates a new HEAD route from a generic handler. -func NewHeadRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodHead, pr, trace, nil) +func NewHeadRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodHead, pr, trace, nil, mm...) } // NewOptionsRoute creates a new OPTIONS route from a generic handler. -func NewOptionsRoute(p string, pr sync.ProcessorFunc, trace bool) Route { - return NewRoute(p, http.MethodOptions, pr, trace, nil) +func NewOptionsRoute(p string, pr sync.ProcessorFunc, trace bool, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodOptions, pr, trace, nil, mm...) } // NewRoute creates a new route from a generic handler with auth capability. -func NewRoute(p string, m string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return Route{Pattern: p, Method: m, Handler: handler(pr), Trace: trace, Auth: auth} +func NewRoute(p string, m string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + var middlewares []MiddlewareFunc + if trace { + middlewares = append(middlewares, NewTracingMiddleware(p)) + } + if auth != nil { + middlewares = append(middlewares, NewAuthMiddleware(auth)) + } + if len(mm) > 0 { + middlewares = append(middlewares, mm...) + } + return Route{Pattern: p, Method: m, Handler: handler(pr), Trace: trace, Auth: auth, Middlewares: middlewares} } // NewRouteRaw creates a new route from a HTTP handler. -func NewRouteRaw(p string, m string, h http.HandlerFunc, trace bool) Route { - return Route{Pattern: p, Method: m, Handler: h, Trace: trace} +func NewRouteRaw(p string, m string, h http.HandlerFunc, trace bool, mm ...MiddlewareFunc) Route { + var middlewares []MiddlewareFunc + if trace { + middlewares = append(middlewares, NewTracingMiddleware(p)) + } + if len(mm) > 0 { + middlewares = append(middlewares, mm...) + } + return Route{Pattern: p, Method: m, Handler: h, Trace: trace, Middlewares: middlewares} } // NewAuthGetRoute creates a new GET route from a generic handler with auth capability. -func NewAuthGetRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodGet, pr, trace, auth) +func NewAuthGetRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodGet, pr, trace, auth, mm...) } // NewAuthPostRoute creates a new POST route from a generic handler with auth capability. -func NewAuthPostRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodPost, pr, trace, auth) +func NewAuthPostRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodPost, pr, trace, auth, mm...) } // NewAuthPutRoute creates a new PUT route from a generic handler with auth capability. -func NewAuthPutRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodPut, pr, trace, auth) +func NewAuthPutRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodPut, pr, trace, auth, mm...) } // NewAuthDeleteRoute creates a new DELETE route from a generic handler with auth capability. -func NewAuthDeleteRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodDelete, pr, trace, auth) +func NewAuthDeleteRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodDelete, pr, trace, auth, mm...) } // NewAuthPatchRoute creates a new PATCH route from a generic handler with auth capability. -func NewAuthPatchRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodPatch, pr, trace, auth) +func NewAuthPatchRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodPatch, pr, trace, auth, mm...) } // NewAuthHeadRoute creates a new HEAD route from a generic handler with auth capability. -func NewAuthHeadRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodHead, pr, trace, auth) +func NewAuthHeadRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodHead, pr, trace, auth, mm...) } // NewAuthOptionsRoute creates a new OPTIONS route from a generic handler with auth capability. -func NewAuthOptionsRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator) Route { - return NewRoute(p, http.MethodOptions, pr, trace, auth) +func NewAuthOptionsRoute(p string, pr sync.ProcessorFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + return NewRoute(p, http.MethodOptions, pr, trace, auth, mm...) } // NewAuthRouteRaw creates a new route from a HTTP handler with auth capability. -func NewAuthRouteRaw(p string, m string, h http.HandlerFunc, trace bool, auth auth.Authenticator) Route { - return Route{Pattern: p, Method: m, Handler: h, Trace: trace, Auth: auth} +func NewAuthRouteRaw(p string, m string, h http.HandlerFunc, trace bool, auth auth.Authenticator, mm ...MiddlewareFunc) Route { + var middlewares []MiddlewareFunc + if trace { + middlewares = append(middlewares, NewTracingMiddleware(p)) + } + if auth != nil { + middlewares = append(middlewares, NewAuthMiddleware(auth)) + } + if len(mm) > 0 { + middlewares = append(middlewares, mm...) + } + return Route{Pattern: p, Method: m, Handler: h, Trace: trace, Auth: auth, Middlewares: middlewares} } diff --git a/sync/http/route_test.go b/sync/http/route_test.go index fc79f3f50..a125d502e 100644 --- a/sync/http/route_test.go +++ b/sync/http/route_test.go @@ -28,11 +28,14 @@ func TestNewRoute(t *testing.T) { } func TestNewGetRoute(t *testing.T) { - r := NewGetRoute("/index", nil, true) + t1 := tagMiddleware("t1\n") + t2 := tagMiddleware("t2\n") + r := NewGetRoute("/index", nil, true, t1, t2) assert.Equal(t, "/index", r.Pattern) assert.Equal(t, http.MethodGet, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 3) } func TestNewPostRoute(t *testing.T) { @@ -41,6 +44,7 @@ func TestNewPostRoute(t *testing.T) { assert.Equal(t, http.MethodPost, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) } func TestNewPutRoute(t *testing.T) { @@ -49,6 +53,7 @@ func TestNewPutRoute(t *testing.T) { assert.Equal(t, http.MethodPut, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) } func TestNewDeleteRoute(t *testing.T) { @@ -57,6 +62,7 @@ func TestNewDeleteRoute(t *testing.T) { assert.Equal(t, http.MethodDelete, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) } func TestNewPatchRoute(t *testing.T) { @@ -65,6 +71,7 @@ func TestNewPatchRoute(t *testing.T) { assert.Equal(t, http.MethodPatch, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) } func TestNewHeadRoute(t *testing.T) { @@ -73,6 +80,7 @@ func TestNewHeadRoute(t *testing.T) { assert.Equal(t, http.MethodHead, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) } func TestNewOptionsRoute(t *testing.T) { @@ -81,6 +89,7 @@ func TestNewOptionsRoute(t *testing.T) { assert.Equal(t, http.MethodOptions, r.Method) assert.True(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) } func TestNewRouteRaw(t *testing.T) { r := NewRouteRaw("/index", http.MethodGet, nil, false) @@ -88,6 +97,14 @@ func TestNewRouteRaw(t *testing.T) { assert.Equal(t, "GET", r.Method) assert.False(t, r.Trace) assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 0) + + r = NewRouteRaw("/index", http.MethodGet, nil, true, tagMiddleware("t1")) + assert.Equal(t, "/index", r.Pattern) + assert.Equal(t, "GET", r.Method) + assert.True(t, r.Trace) + assert.Nil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthGetRoute(t *testing.T) { @@ -96,6 +113,7 @@ func TestNewAuthGetRoute(t *testing.T) { assert.Equal(t, http.MethodGet, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthPostRoute(t *testing.T) { @@ -104,6 +122,7 @@ func TestNewAuthPostRoute(t *testing.T) { assert.Equal(t, http.MethodPost, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthPutRoute(t *testing.T) { @@ -112,6 +131,7 @@ func TestNewAuthPutRoute(t *testing.T) { assert.Equal(t, http.MethodPut, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthDeleteRoute(t *testing.T) { @@ -120,6 +140,7 @@ func TestNewAuthDeleteRoute(t *testing.T) { assert.Equal(t, http.MethodDelete, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthPatchRoute(t *testing.T) { @@ -128,6 +149,7 @@ func TestNewAuthPatchRoute(t *testing.T) { assert.Equal(t, http.MethodPatch, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthHeadRoute(t *testing.T) { @@ -136,6 +158,7 @@ func TestNewAuthHeadRoute(t *testing.T) { assert.Equal(t, http.MethodHead, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthOptionsRoute(t *testing.T) { @@ -144,6 +167,7 @@ func TestNewAuthOptionsRoute(t *testing.T) { assert.Equal(t, http.MethodOptions, r.Method) assert.True(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 2) } func TestNewAuthRouteRaw(t *testing.T) { @@ -152,4 +176,12 @@ func TestNewAuthRouteRaw(t *testing.T) { assert.Equal(t, "GET", r.Method) assert.False(t, r.Trace) assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 1) + + r = NewAuthRouteRaw("/index", http.MethodGet, nil, true, &MockAuthenticator{}, tagMiddleware("tag1")) + assert.Equal(t, "/index", r.Pattern) + assert.Equal(t, "GET", r.Method) + assert.True(t, r.Trace) + assert.NotNil(t, r.Auth) + assert.Len(t, r.Middlewares, 3) }