Skip to content

Commit

Permalink
[patch] refactored accesslog & CORS middleware (moved to new package)
Browse files Browse the repository at this point in the history
[patch] fixed a bug which did not send appropriate JSON response header for the convenience methods
  • Loading branch information
bnkamalesh committed Jun 14, 2020
1 parent e27b288 commit 0254c1b
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 6 deletions.
32 changes: 32 additions & 0 deletions middleware/accesslog/accesslog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
Package accesslogs provides a simple straight forward access log middleware. The logs are of the
following format:
<timestamp> <HTTP request method> <full URL including query string parameters> <duration of execution> <HTTP response status code>
*/
package accesslog

import (
"fmt"
"net/http"
"time"

"github.com/bnkamalesh/webgo/v4"
)

// AccessLog is a middleware which prints access log to stdout
func AccessLog(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
start := time.Now()
next(rw, req)
end := time.Now()

webgo.LOGHANDLER.Info(
fmt.Sprintf(
"%s %s %s %s %d",
end.Format("2006-01-02 15:04:05 -0700 MST"),
req.Method,
req.URL.String(),
end.Sub(start).String(),
webgo.ResponseStatus(rw),
),
)
}
198 changes: 198 additions & 0 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
Package cors sets the appropriate CORS(https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS)
response headers, and lets you customize. Following customizations are allowed:
- provide a list of allowed domains
- provide a list of headers
- set the max-age of CORS headers
The list of allowed methods are
*/
package cors

import (
"fmt"
"net/http"
"regexp"
"sort"
"strings"

"github.com/bnkamalesh/webgo/v4"
)

const (
headerOrigin = "Access-Control-Allow-Origin"
headerMethods = "Access-Control-Allow-Methods"
headerCreds = "Access-Control-Allow-Credentials"
headerAllowHeaders = "Access-Control-Allow-Headers"
headerReqHeaders = "Access-Control-Request-Headers"
headerAccessControlAge = "Access-Control-Max-Age"
allowHeaders = "Accept,Content-Type,Content-Length,Accept-Encoding,Access-Control-Request-Headers,"
)

var (
defaultAllowMethods = "HEAD,GET,POST,PUT,PATCH,DELETE,OPTIONS"
)

func allowedDomains() []string {
// The domains mentioned here are default
domains := []string{"*"}
return domains
}

func getReqOrigin(r *http.Request) string {
return r.Header.Get("Origin")
}

func allowedOriginsRegex(allowedOrigins ...string) []regexp.Regexp {
if len(allowedOrigins) == 0 {
allowedOrigins = []string{"*"}
} else {
// If "*" is one of the allowed domains, i.e. all domains, then rest of the values are ignored
for _, val := range allowedOrigins {
val = strings.TrimSpace(val)

if val == "*" {
allowedOrigins = []string{"*"}
break
}
}
}

allowedOriginRegex := make([]regexp.Regexp, 0, len(allowedOrigins))
for _, ao := range allowedOrigins {
parts := strings.Split(ao, ":")
str := strings.TrimSpace(parts[0])
if str == "" {
continue
}

if str == "*" {
allowedOriginRegex = append(
allowedOriginRegex,
*(regexp.MustCompile(".+")),
)
break
}

regStr := fmt.Sprintf(`^(http)?(https)?(:\/\/)?(.+\.)?%s(:[0-9]+)?$`, str)

allowedOriginRegex = append(
allowedOriginRegex,
// Allow any port number of the specified domain
*(regexp.MustCompile(regStr)),
)
}

return allowedOriginRegex
}

func allowedMethods(routes []*webgo.Route) string {
if len(routes) == 0 {
return defaultAllowMethods
}

methods := make([]string, 0, len(routes))
for _, r := range routes {
found := false
for _, m := range methods {
if m == r.Method {
found = true
break
}
}
if found {
continue
}
methods = append(methods, r.Method)
}
sort.Strings(methods)
return strings.Join(methods, ",")
}

// Config holds all the configurations which is available for customizing this middleware
type Config struct {
TimeoutSecs int
Routes []*webgo.Route
AllowedOrigins []string
AllowedHeaders []string
}

func (cfg *Config) normalize() {
if cfg.TimeoutSecs < 60 {
cfg.TimeoutSecs = 60
}
}

func allowedHeaders(headers []string) string {
allowedHeaders := strings.Join(headers, ",")
if allowedHeaders[len(allowedHeaders)-1] != ',' {
allowedHeaders += ","
}
return allowedHeaders
}

func allowOrigin(reqOrigin string, allowedOriginRegex []regexp.Regexp) bool {

for _, o := range allowedOriginRegex {
// Set appropriate response headers required for CORS
if o.MatchString(reqOrigin) || reqOrigin == "" {
return true
}
}
return false
}

// Middleware can be used as well, it lets the user use this middleware without webgo
func Middleware(allowedOriginRegex []regexp.Regexp, corsTimeout, allowedMethods, allowedHeaders string) webgo.Middleware {
return func(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
reqOrigin := getReqOrigin(req)
allowed := allowOrigin(reqOrigin, allowedOriginRegex)

if !allowed {
// If CORS failed, no respective headers are set. But the execution is allowed to continue
// Earlier this middleware blocked access altogether, which was considered an added
// security measure despite it being outside the scope of this middelware. Though, such
// restrictions create unnecessary complexities during inter-app communication.
next(rw, req)
return
}

// Set appropriate response headers required for CORS
rw.Header().Set(headerOrigin, reqOrigin)
rw.Header().Set(headerAccessControlAge, corsTimeout)
rw.Header().Set(headerCreds, "true")
rw.Header().Set(headerMethods, allowedMethods)
rw.Header().Set(headerAllowHeaders, allowedHeaders+req.Header.Get(headerReqHeaders))

if req.Method == http.MethodOptions {
webgo.SendHeader(rw, http.StatusOK)
return
}

next(rw, req)
}
}

// CORS is a single CORS middleware which can be applied to the whole app at once
func CORS(cfg *Config) webgo.Middleware {
if cfg == nil {
cfg = new(Config)
}

allowedOrigins := cfg.AllowedOrigins
if len(allowedOrigins) == 0 {
allowedOrigins = allowedDomains()
}

allowedOriginRegex := allowedOriginsRegex(allowedOrigins...)
allowedmethods := allowedMethods(cfg.Routes)
allowedHeaders := allowedHeaders(cfg.AllowedHeaders)
corsTimeout := fmt.Sprintf("%d", cfg.TimeoutSecs)

return Middleware(
allowedOriginRegex,
corsTimeout,
allowedmethods,
allowedHeaders,
)
}
7 changes: 7 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,13 @@ const (
allowHeaders = "Accept,Content-Type,Content-Length,Accept-Encoding,Access-Control-Request-Headers,"
)

func deprecationLog() {
webgo.LOGHANDLER.Warn("this middleware is deprecated, use github.com/bnkamalesh/middleware/cors")
}

// Cors is a basic CORS middleware which can be added to individual handlers
func Cors(allowedOrigins ...string) http.HandlerFunc {
deprecationLog()
if len(allowedOrigins) == 0 {
allowedOrigins = append(allowedOrigins, "*")
}
Expand Down Expand Up @@ -73,6 +78,7 @@ func Cors(allowedOrigins ...string) http.HandlerFunc {

// CorsOptions is a CORS middleware only for OPTIONS request method
func CorsOptions(allowedOrigins ...string) http.HandlerFunc {
deprecationLog()
if len(allowedOrigins) == 0 {
allowedOrigins = append(allowedOrigins, "*")
}
Expand Down Expand Up @@ -103,6 +109,7 @@ func CorsOptions(allowedOrigins ...string) http.HandlerFunc {

// CorsWrap is a single Cors middleware which can be applied to the whole app at once
func CorsWrap(allowedOrigins ...string) func(http.ResponseWriter, *http.Request, http.HandlerFunc) {
deprecationLog()
if len(allowedOrigins) == 0 {
allowedOrigins = append(allowedOrigins, "*")
}
Expand Down
12 changes: 6 additions & 6 deletions responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ func Send(w http.ResponseWriter, contentType string, data interface{}, rCode int
// SendResponse is used to respond to any request (JSON response) based on the code, data etc.
func SendResponse(w http.ResponseWriter, data interface{}, rCode int) {
w = crwAsserter(w, rCode)

w.Header().Add(HeaderContentType, JSONContentType)
err := json.NewEncoder(w).Encode(dOutput{Data: data, Status: rCode})
if err != nil {
/*
In case of encoding error, send "internal server error" after
logging the actual error.
In case of encoding error, send "internal server error" and
log the actual error.
*/
R500(w, ErrInternalServer)
LOGHANDLER.Error(err)
Expand All @@ -83,12 +83,12 @@ func SendResponse(w http.ResponseWriter, data interface{}, rCode int) {
// SendError is used to respond to any request with an error
func SendError(w http.ResponseWriter, data interface{}, rCode int) {
w = crwAsserter(w, rCode)

w.Header().Add(HeaderContentType, JSONContentType)
err := json.NewEncoder(w).Encode(errOutput{data, rCode})
if err != nil {
/*
In case of encoding error, send "internal server error" after
logging the actual error.
In case of encoding error, send "internal server error" and
log the actual error.
*/
R500(w, ErrInternalServer)
LOGHANDLER.Error(err)
Expand Down

0 comments on commit 0254c1b

Please sign in to comment.