Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Siva Manivannan committed Jun 26, 2024
1 parent 709d00e commit 5aeba44
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 5 deletions.
17 changes: 12 additions & 5 deletions pkg/handlers/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers

import (
"bytes"
"encoding/json"
"io"
"net/http"
"sync"
Expand Down Expand Up @@ -66,7 +67,7 @@ type cache struct {

func NewCache() *cache {
return &cache{
store: make(map[string]CacheEntry),
store: map[string]CacheEntry{},
}
}

Expand All @@ -92,11 +93,11 @@ func (c *cache) Set(key string, entry CacheEntry, duration time.Duration) {
type responseRecorder struct {
http.ResponseWriter
Body *bytes.Buffer
statusCode int
StatusCode int
}

func (r *responseRecorder) WriteHeader(code int) {
r.statusCode = code
r.StatusCode = code
r.ResponseWriter.WriteHeader(code)
}

Expand All @@ -123,17 +124,23 @@ func CacheMiddleware(next http.HandlerFunc, duration time.Duration) http.Handler

if entry, found := cache.Get(key); found {
logger.Infof("cache middleware: serving cached payload for method: %s path: %s ttl: %s ", r.Method, r.URL.Path, time.Until(entry.Expiry).Round(time.Second).String())
JSONCached(w, entry.StatusCode, entry.ResponseBody)
JSONCached(w, entry.StatusCode, json.RawMessage(entry.ResponseBody))
return
}

recorder := &responseRecorder{ResponseWriter: w, Body: &bytes.Buffer{}}
next(recorder, r)

// Save only successful responses in the cache
if recorder.StatusCode < 200 || recorder.StatusCode >= 300 {
return
}

cache.Set(key, CacheEntry{
StatusCode: recorder.statusCode,
StatusCode: recorder.StatusCode,
RequestBody: body,
ResponseBody: recorder.Body.Bytes(),
}, duration)

}
}
169 changes: 169 additions & 0 deletions pkg/handlers/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package handlers

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func Test_cache(t *testing.T) {
tests := []struct {
name string
assertFn func(*testing.T, *cache)
}{
{
name: "cache should be able to set and get KV pair with valid ttl",
assertFn: func(t *testing.T, c *cache) {
now := time.Now()

entry := CacheEntry{
RequestBody: []byte("request body"),
ResponseBody: []byte("response body"),
StatusCode: http.StatusOK,
}
c.Set("cache-key", entry, 1*time.Minute)

cachedEntry, exists := c.Get("cache-key")

require.True(t, exists)
require.Equal(t, entry.RequestBody, cachedEntry.RequestBody)
require.Equal(t, entry.ResponseBody, cachedEntry.ResponseBody)
require.Equal(t, entry.StatusCode, cachedEntry.StatusCode)

// TTL should be valid
require.Equal(t, true, cachedEntry.Expiry.After(now))
},
},
{
name: "cache get should return false for non-existent key",
assertFn: func(t *testing.T, c *cache) {
_, exists := c.Get("cache-key-does-not-exist")
require.False(t, exists)
},
},
{
name: "cache get should return false for expired cache entry",
assertFn: func(t *testing.T, c *cache) {

entry := CacheEntry{
RequestBody: []byte("request body"),
ResponseBody: []byte("response body"),
StatusCode: http.StatusOK,
}
c.Set("cache-key", entry, 5*time.Millisecond)

time.Sleep(10 * time.Millisecond)

_, exists := c.Get("cache-key")

require.False(t, exists)

},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewCache()
tt.assertFn(t, c)
})
}
}

func newTestRequest(method, url string, body []byte) (*http.Request, *httptest.ResponseRecorder) {
req := httptest.NewRequest(method, url, bytes.NewReader(body))
rec := httptest.NewRecorder()
return req, rec
}

func Test_CacheMiddleware(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, map[string]interface{}{"message": "Hello, World!"})
})

duration := 1 * time.Minute
cachedHandler := CacheMiddleware(handler, duration)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should not exist because the response is NOT served from cache

/* Second request should be served from cache since the payload it the same */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache

/* Third request should not be served from cache since the payload is different */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 1111}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should not exist because the response is NOT served from cache

}

func Test_CacheMiddleware_Expiry(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, map[string]interface{}{"message": "Hello, World!"})
})

duration := 100 * time.Millisecond
cachedHandler := CacheMiddleware(handler, duration)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should not exist because the response is NOT served from cache

/* Second request should be served from cache since the payload it the same and under the expiry time */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache

time.Sleep(110 * time.Millisecond)

/* Third request should not be served from cache due to expiry */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should not exist because the response is NOT served from cache

}

func Test_CacheMiddleware_DoNotCacheErroredPayload(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusInternalServerError, map[string]interface{}{"error": "Something went wrong!"})
})

duration := 1 * time.Minute
cachedHandler := CacheMiddleware(handler, duration)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should not exist because the response is NOT served from cache

/* Second request should not be served from cache - err'ed payloads not saved to cache */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache

}

0 comments on commit 5aeba44

Please sign in to comment.