Skip to content

Commit

Permalink
Merge pull request #2 from lithictech/better-middleware
Browse files Browse the repository at this point in the history
Better middleware
  • Loading branch information
rgalanakis authored Nov 4, 2021
2 parents 33a0fa5 + 7c1e325 commit e122067
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 26 deletions.
13 changes: 7 additions & 6 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ import (
)

type Config struct {
Logger *logrus.Entry
HealthHandler echo.HandlerFunc
CorsOrigins []string
HealthResponse map[string]interface{}
StatusResponse map[string]interface{}
Logger *logrus.Entry
LoggingMiddlwareConfig LoggingMiddlwareConfig
HealthHandler echo.HandlerFunc
CorsOrigins []string
HealthResponse map[string]interface{}
StatusResponse map[string]interface{}
}

func New(cfg Config) *echo.Echo {
Expand All @@ -54,7 +55,7 @@ func New(cfg Config) *echo.Echo {
e.Logger.SetOutput(os.Stdout)
e.HideBanner = true
e.HTTPErrorHandler = NewHTTPErrorHandler(e)
e.Use(LoggingMiddleware(cfg.Logger))
e.Use(LoggingMiddlewareWithConfig(cfg.Logger, cfg.LoggingMiddlwareConfig))
if cfg.CorsOrigins != nil {
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: cfg.CorsOrigins,
Expand Down
61 changes: 61 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,52 @@ var _ = Describe("API", func() {
Expect(logHook.Entries).To(HaveLen(1))
Expect(logHook.Entries[0].Level).To(Equal(logrus.DebugLevel))
})
It("can log request and response headers", func() {
e = api.New(api.Config{
Logger: logEntry,
LoggingMiddlwareConfig: api.LoggingMiddlwareConfig{
RequestHeaders: true,
ResponseHeaders: true,
},
})
e.GET("/", func(c echo.Context) error {
c.Response().Header().Set("ResHead", "ResHeadVal")
return c.String(200, "ok")
})
Expect(Serve(e, GetRequest("/", SetReqHeader("ReqHead", "ReqHeadVal")))).To(HaveResponseCode(200))
Expect(logHook.Entries).To(HaveLen(1))
Expect(logHook.Entries[0].Data).To(And(
HaveKeyWithValue("request_header.Reqhead", "ReqHeadVal"),
HaveKeyWithValue("response_header.Reshead", "ResHeadVal"),
))
})
It("can use custom DoLog, BeforeRequest, and AfterRequest hooks", func() {
doLogCalled := false
e = api.New(api.Config{
Logger: logEntry,
LoggingMiddlwareConfig: api.LoggingMiddlwareConfig{
BeforeRequest: func(_ echo.Context, e *logrus.Entry) *logrus.Entry {
return e.WithField("before", 1)
},
AfterRequest: func(_ echo.Context, e *logrus.Entry) *logrus.Entry {
return e.WithField("after", 2)
},
DoLog: func(c echo.Context, e *logrus.Entry) {
doLogCalled = true
api.LoggingMiddlewareDefaultDoLog(c, e)
},
},
})
e.GET("/", func(c echo.Context) error {
return c.String(400, "")
})
Expect(Serve(e, GetRequest("/"))).To(HaveResponseCode(400))
Expect(doLogCalled).To(BeTrue())
Expect(logHook.Entries[len(logHook.Entries)-1].Data).To(And(
HaveKeyWithValue("before", 1),
HaveKeyWithValue("after", 2),
))
})
})

Describe("error handling", func() {
Expand Down Expand Up @@ -300,5 +346,20 @@ var _ = Describe("API", func() {
HaveKeyWithValue("debug_response_body", ContainSubstring("ok")),
))
})
It("can print memory stats every n requests", func() {
e.Use(api.DebugMiddleware(api.DebugMiddlewareConfig{Enabled: true, DumpMemoryEvery: 2}))
e.GET("/endpoint", func(c echo.Context) error {
return c.String(200, "ok")
})
Serve(e, NewRequest("GET", "/endpoint", nil, SetReqHeader("Foo", "x")))
Serve(e, NewRequest("GET", "/endpoint", nil, SetReqHeader("Foo", "x")))
Expect(logHook.Entries).To(HaveLen(4))
Expect(logHook.Entries[0].Message).To(Equal("request_debug"))
Expect(logHook.Entries[0].Data).ToNot(HaveKey("memory_sys"))
Expect(logHook.Entries[1].Message).To(Equal("request_finished"))
Expect(logHook.Entries[2].Message).To(Equal("request_debug"))
Expect(logHook.Entries[2].Data).To(HaveKey("memory_sys"))
Expect(logHook.Entries[3].Message).To(Equal("request_finished"))
})
})
})
32 changes: 32 additions & 0 deletions api/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"github.com/labstack/echo/middleware"
"github.com/lithictech/go-aperitif/logctx"
"net/http"
"runtime"
"sync/atomic"
)

type DebugMiddlewareConfig struct {
Expand All @@ -14,6 +16,9 @@ type DebugMiddlewareConfig struct {
DumpRequestHeaders bool
DumpResponseHeaders bool
DumpAll bool
// Log out memory stats every 'n' requests.
// If <= 0, do not log them.
DumpMemoryEvery int
}

func DebugMiddleware(cfg DebugMiddlewareConfig) echo.MiddlewareFunc {
Expand All @@ -30,7 +35,10 @@ func DebugMiddleware(cfg DebugMiddlewareConfig) echo.MiddlewareFunc {
cfg.DumpResponseHeaders = true
cfg.DumpResponseBody = true
}
var requestCounter uint64
dumpEveryUint := uint64(cfg.DumpMemoryEvery)
bd := middleware.BodyDump(func(c echo.Context, reqBody []byte, resBody []byte) {
atomic.AddUint64(&requestCounter, 1)
log := logctx.Logger(StdContext(c))
if cfg.DumpRequestBody {
log = log.WithField("debug_request_body", string(reqBody))
Expand All @@ -44,6 +52,30 @@ func DebugMiddleware(cfg DebugMiddlewareConfig) echo.MiddlewareFunc {
if cfg.DumpResponseHeaders {
log = log.WithField("debug_response_headers", headerToMap(c.Response().Header()))
}
if cfg.DumpMemoryEvery > 0 && (requestCounter%dumpEveryUint) == 0 {
var ms runtime.MemStats
runtime.ReadMemStats(&ms)
log = log.WithFields(map[string]interface{}{
"memory_alloc": ms.Alloc,
"memory_total_alloc": ms.TotalAlloc,
"memory_sys": ms.Sys,
"memory_mallocs": ms.Mallocs,
"memory_frees": ms.Frees,
"memory_heap_alloc": ms.HeapAlloc,
"memory_heap_sys": ms.HeapSys,
"memory_heap_idle": ms.HeapIdle,
"memory_heap_inuse": ms.HeapInuse,
"memory_heap_released": ms.HeapReleased,
"memory_heap_objects": ms.HeapObjects,
"memory_stack_inuse": ms.StackInuse,
"memory_stack_sys": ms.StackSys,
"memory_other_sys": ms.OtherSys,
"memory_next_gc": ms.NextGC,
"memory_last_gc": ms.LastGC,
"memory_pause_total_ns": ms.PauseTotalNs,
"memory_num_gc": ms.NumGC,
})
}
log.Debug("request_debug")
})
return bd
Expand Down
77 changes: 60 additions & 17 deletions api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,31 @@ func SetLogger(c echo.Context, logger *logrus.Entry) {
c.Set(logctx.LoggerKey, logger)
}

type LoggingMiddlwareConfig struct {
// If true, log request headers.
RequestHeaders bool
// If true, log response headers.
ResponseHeaders bool
// If provided, the returned logger is stored in the context
// which is eventually passed to the handler.
// Use to add additional fields to the logger based on the request.
BeforeRequest func(echo.Context, *logrus.Entry) *logrus.Entry
// If provided, the returned logger is used for response logging.
// Use to add additional fields to the logger based on the request or response.
AfterRequest func(echo.Context, *logrus.Entry) *logrus.Entry
// The function that does the actual logging.
// By default, it will log at a certain level based on the status code of the response.
DoLog func(echo.Context, *logrus.Entry)
}

func LoggingMiddleware(outerLogger *logrus.Entry) echo.MiddlewareFunc {
return LoggingMiddlewareWithConfig(outerLogger, LoggingMiddlwareConfig{})
}

func LoggingMiddlewareWithConfig(outerLogger *logrus.Entry, cfg LoggingMiddlwareConfig) echo.MiddlewareFunc {
if cfg.DoLog == nil {
cfg.DoLog = LoggingMiddlewareDefaultDoLog
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
start := time.Now()
Expand Down Expand Up @@ -57,11 +81,16 @@ func LoggingMiddleware(outerLogger *logrus.Entry) echo.MiddlewareFunc {
"request_bytes_in": bytesIn,
string(logctx.RequestTraceIdKey): TraceId(c),
})
//for k, v := range req.Header {
// if len(v) > 0 && k != "Authorization" && k != "Cookie" {
// logger = logger.WithField("header."+k, v[0])
// }
//}
if cfg.RequestHeaders {
for k, v := range req.Header {
if len(v) > 0 && k != "Authorization" && k != "Cookie" {
logger = logger.WithField("request_header."+k, v[0])
}
}
}
if cfg.BeforeRequest != nil {
logger = cfg.BeforeRequest(c, logger)
}

SetLogger(c, logger)

Expand All @@ -80,28 +109,42 @@ func LoggingMiddleware(outerLogger *logrus.Entry) echo.MiddlewareFunc {
"request_latency_ms": int(stop.Sub(start)) / 1000 / 1000,
"request_bytes_out": strconv.FormatInt(res.Size, 10),
})
if cfg.ResponseHeaders {
for k, v := range res.Header() {
if len(v) > 0 && k != "Set-Cookie" {
logger = logger.WithField("response_header."+k, v[0])
}
}
}
if err != nil {
logger = logger.WithField("request_error", err)
}

logMethod := logger.Info
if req.Method == http.MethodOptions {
logMethod = logger.Debug
} else if res.Status >= 500 {
logMethod = logger.Error
} else if res.Status >= 400 {
logMethod = logger.Warn
} else if req.URL.Path == HealthPath || req.URL.Path == StatusPath {
logMethod = logger.Debug
if cfg.BeforeRequest != nil {
logger = cfg.AfterRequest(c, logger)
}
logMethod("request_finished")

cfg.DoLog(c, logger)
// c.Error is already called
return nil
}
}
}

func LoggingMiddlewareDefaultDoLog(c echo.Context, logger *logrus.Entry) {
req := c.Request()
res := c.Response()
logMethod := logger.Info
if req.Method == http.MethodOptions {
logMethod = logger.Debug
} else if res.Status >= 500 {
logMethod = logger.Error
} else if res.Status >= 400 {
logMethod = logger.Warn
} else if req.URL.Path == HealthPath || req.URL.Path == StatusPath {
logMethod = logger.Debug
}
logMethod("request_finished")
}

// Invoke next(c) within a function wrapped with defer,
// so that if it panics, we can recover from it and pass on a 500.
// Use the "named return parameter can be set in defer" trick so we can
Expand Down
8 changes: 8 additions & 0 deletions logctx/logctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func Logger(c context.Context) *logrus.Entry {
return logger
}

// ActiveTraceId returns the first valid trace value and type from the given context,
// or MissingTraceIdKey if there is none.
func ActiveTraceId(c context.Context) (TraceIdKey, string) {
if trace, ok := c.Value(RequestTraceIdKey).(string); ok {
return RequestTraceIdKey, trace
Expand All @@ -80,6 +82,12 @@ func ActiveTraceId(c context.Context) (TraceIdKey, string) {
return MissingTraceIdKey, "no-trace-id-in-context"
}

// ActiveTraceIdValue returns the value part of ActiveTraceId (does not return the TradeIdKey type part).
func ActiveTraceIdValue(c context.Context) string {
_, v := ActiveTraceId(c)
return v
}

func AddFieldsAndGet(c context.Context, fields map[string]interface{}) (context.Context, *logrus.Entry) {
logger := Logger(c)
logger = logger.WithFields(fields)
Expand Down
1 change: 1 addition & 0 deletions logctx/logctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var _ = Describe("logtools", func() {
key, val := logctx.ActiveTraceId(c)
Expect(key).To(Equal(logctx.RequestTraceIdKey))
Expect(val).To(Equal("abc"))
Expect(logctx.ActiveTraceIdValue(c)).To(Equal("abc"))
})
It("returns a process trace id", func() {
c := context.WithValue(bg, logctx.ProcessTraceIdKey, "abc")
Expand Down
13 changes: 10 additions & 3 deletions parallel/parallel.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package parallel

import (
"errors"
"github.com/hashicorp/go-multierror"
"github.com/lithictech/go-aperitif/mariobros"
"sync"
)

var ErrInvalidParallelism = errors.New("degree of parallelism must be > 0")

type empty struct{}
type Processor func(idx int) error

Expand All @@ -22,8 +25,12 @@ type Processor func(idx int) error
// and assign to the slice index while processing.
// See ParallelForFiles for an example usage.
func ForEach(total int, n int, process Processor) error {
if n <= 0 {
return ErrInvalidParallelism
}

semaphore := make(chan empty, n)
errors := make([]error, total)
errs := make([]error, total)

wg := sync.WaitGroup{}
wg.Add(total)
Expand All @@ -32,11 +39,11 @@ func ForEach(total int, n int, process Processor) error {
mario := mariobros.Yo("parallel.foreach")
defer mario()
semaphore <- empty{}
errors[i] = process(i)
errs[i] = process(i)
<-semaphore
wg.Done()
}(i)
}
wg.Wait()
return multierror.Append(nil, errors...).ErrorOrNil()
return multierror.Append(nil, errs...).ErrorOrNil()
}
4 changes: 4 additions & 0 deletions parallel/parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,8 @@ var _ = Describe("ParallelFor", func() {
Expect(called).To(Equal(1000))
Expect(active).To(Equal(0))
})
It("errors for 0 or negative n", func() {
err := parallel.ForEach(1, 0, nil)
Expect(err).To(BeIdenticalTo(parallel.ErrInvalidParallelism))
})
})

0 comments on commit e122067

Please sign in to comment.