Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat [sc 106247]: Add a caching layer to prevent SDK from sending the same payloads too frequently upstream #193

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pkg/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func Start(params APIServerParams) {
r := mux.NewRouter()
r.Use(handlers.CorsMiddleware)

const DefaultCacheTTL = 1 * time.Minute

// TODO: make all routes authenticated
authRouter := r.NewRoute().Subrouter()
authRouter.Use(handlers.RequireValidLicenseIDMiddleware)
Expand All @@ -66,9 +68,9 @@ func Start(params APIServerParams) {
r.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET")
r.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET")
r.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET")
r.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH")
r.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE")
r.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST")
r.HandleFunc("/api/v1/app/custom-metrics", handlers.CacheMiddleware(handlers.SendCustomAppMetrics, DefaultCacheTTL)).Methods("POST", "PATCH")
divolgin marked this conversation as resolved.
Show resolved Hide resolved
r.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.CacheMiddleware(handlers.DeleteCustomAppMetricsKey, DefaultCacheTTL)).Methods("DELETE")
r.HandleFunc("/api/v1/app/instance-tags", handlers.CacheMiddleware(handlers.SendAppInstanceTags, DefaultCacheTTL)).Methods("POST")

// integration
r.HandleFunc("/api/v1/integration/mock-data", handlers.EnforceMockAccess(handlers.PostIntegrationMockData)).Methods("POST")
Expand Down
111 changes: 111 additions & 0 deletions pkg/handlers/middleware.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package handlers

import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"time"

"github.com/pkg/errors"
"github.com/replicatedhq/replicated-sdk/pkg/handlers/types"
"github.com/replicatedhq/replicated-sdk/pkg/logger"
"github.com/replicatedhq/replicated-sdk/pkg/store"
)

Expand Down Expand Up @@ -44,3 +53,105 @@ func RequireValidLicenseIDMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}

// Code for the cache middleware
type CacheEntry struct {
RequestBody []byte
ResponseBody []byte
StatusCode int
Expiry time.Time
}

type cache struct {
store map[string]CacheEntry
mu sync.RWMutex
}

func NewCache() *cache {
return &cache{
store: map[string]CacheEntry{},
}
}

func (c *cache) Get(key string) (CacheEntry, bool) {
c.mu.RLock()
defer c.mu.RUnlock()

entry, found := c.store[key]
if !found || time.Now().After(entry.Expiry) {
return CacheEntry{}, false
}
return entry, true
}

func (c *cache) Set(key string, entry CacheEntry, duration time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()

// Clean up expired entries
divolgin marked this conversation as resolved.
Show resolved Hide resolved
for k, v := range c.store {
if time.Now().After(v.Expiry) {
delete(c.store, k)
}
}

entry.Expiry = time.Now().Add(duration)
c.store[key] = entry
divolgin marked this conversation as resolved.
Show resolved Hide resolved
}

type responseRecorder struct {
http.ResponseWriter
Body *bytes.Buffer
StatusCode int
}

func (r *responseRecorder) WriteHeader(code int) {
r.StatusCode = code
r.ResponseWriter.WriteHeader(code)
}

func (r *responseRecorder) Write(b []byte) (int, error) {
r.Body.Write(b)
return r.ResponseWriter.Write(b)
}

func CacheMiddleware(next http.HandlerFunc, duration time.Duration) http.HandlerFunc {
// Each handler has its own cache to reduce contention for the in-memory store
cache := NewCache()

return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
logger.Error(errors.Wrap(err, "cache middleware - failed to read request body"))
http.Error(w, "cache middleware: unable to read request body", http.StatusInternalServerError)
return
}

r.Body = io.NopCloser(bytes.NewBuffer(body))

hash := sha256.Sum256([]byte(r.Method + "::" + r.URL.Path))
divolgin marked this conversation as resolved.
Show resolved Hide resolved

key := fmt.Sprintf("%x\n", hash)
divolgin marked this conversation as resolved.
Show resolved Hide resolved

if entry, found := cache.Get(key); found && bytes.Equal(entry.RequestBody, body) {
divolgin marked this conversation as resolved.
Show resolved Hide resolved
logger.Infof("cache middleware: serving cached payload for method: %s path: %s ttl: %s ", r.Method, r.URL.Path, time.Until(entry.Expiry).Round(time.Second).String())
JSONCached(w, entry.StatusCode, json.RawMessage(entry.ResponseBody))
return
}

recorder := &responseRecorder{ResponseWriter: w, Body: &bytes.Buffer{}}
next(recorder, r)

// Save only successful responses in the cache
if recorder.StatusCode < 200 || recorder.StatusCode >= 300 {
return
}

cache.Set(key, CacheEntry{
StatusCode: recorder.StatusCode,
RequestBody: body,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should consider making this a digest too because even without historical record this map can get big

ResponseBody: recorder.Body.Bytes(),
}, duration)

}
}
190 changes: 190 additions & 0 deletions pkg/handlers/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package handlers

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func Test_cache(t *testing.T) {
tests := []struct {
name string
assertFn func(*testing.T, *cache)
}{
{
name: "cache should be able to set and get KV pair with valid ttl",
assertFn: func(t *testing.T, c *cache) {
now := time.Now()

entry := CacheEntry{
RequestBody: []byte("request body"),
ResponseBody: []byte("response body"),
StatusCode: http.StatusOK,
}
c.Set("cache-key", entry, 1*time.Minute)

cachedEntry, exists := c.Get("cache-key")

require.True(t, exists)
require.Equal(t, entry.RequestBody, cachedEntry.RequestBody)
require.Equal(t, entry.ResponseBody, cachedEntry.ResponseBody)
require.Equal(t, entry.StatusCode, cachedEntry.StatusCode)

// TTL should be valid
require.Equal(t, true, cachedEntry.Expiry.After(now))
},
},
{
name: "cache get should return false for non-existent key",
assertFn: func(t *testing.T, c *cache) {
_, exists := c.Get("cache-key-does-not-exist")
require.False(t, exists)
},
},
{
name: "cache get should return false for expired cache entry",
assertFn: func(t *testing.T, c *cache) {

entry := CacheEntry{
RequestBody: []byte("request body"),
ResponseBody: []byte("response body"),
StatusCode: http.StatusOK,
}
c.Set("cache-key", entry, 5*time.Millisecond)

time.Sleep(10 * time.Millisecond)

_, exists := c.Get("cache-key")

require.False(t, exists)

},
},
{
name: "cache set should delete expired cache entries",
assertFn: func(t *testing.T, c *cache) {

entry := CacheEntry{
RequestBody: []byte("request body"),
ResponseBody: []byte("response body"),
StatusCode: http.StatusOK,
}

c.Set("first-cache-key", entry, 10*time.Millisecond)
_, exists := c.Get("first-cache-key")
require.True(t, exists)

time.Sleep(20 * time.Millisecond)

c.Set("second-cache-key", entry, 1*time.Minute)
_, exists = c.Get("first-cache-key")
require.False(t, exists)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewCache()
tt.assertFn(t, c)
})
}
}

func newTestRequest(method, url string, body []byte) (*http.Request, *httptest.ResponseRecorder) {
req := httptest.NewRequest(method, url, bytes.NewReader(body))
rec := httptest.NewRecorder()
return req, rec
}

func Test_CacheMiddleware(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, map[string]interface{}{"message": "Hello, World!"})
})

duration := 1 * time.Minute
cachedHandler := CacheMiddleware(handler, duration)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache

/* Second request should be served from cache since the payload it the same */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache

/* Third request should not be served from cache since the payload is different */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 1111}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache

}

func Test_CacheMiddleware_Expiry(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusOK, map[string]interface{}{"message": "Hello, World!"})
})

duration := 100 * time.Millisecond
cachedHandler := CacheMiddleware(handler, duration)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache

/* Second request should be served from cache since the payload it the same and under the expiry time */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache

time.Sleep(110 * time.Millisecond)

/* Third request should not be served from cache due to expiry */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache

}

func Test_CacheMiddleware_DoNotCacheErroredPayload(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
JSON(w, http.StatusInternalServerError, map[string]interface{}{"error": "Something went wrong!"})
})

duration := 1 * time.Minute
cachedHandler := CacheMiddleware(handler, duration)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache

/* Second request should not be served from cache - err'ed payloads are not cached */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)
require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache

}
File renamed without changes.
Loading