Skip to content

Commit

Permalink
make cache gorilla mux compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
Siva Manivannan committed Jun 28, 2024
1 parent 28bc4e7 commit 639acd8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 20 deletions.
24 changes: 13 additions & 11 deletions pkg/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,28 @@ func Start(params APIServerParams) {
r := mux.NewRouter()
r.Use(handlers.CorsMiddleware)

const DefaultCacheTTL = 1 * time.Minute

// TODO: make all routes authenticated
authRouter := r.NewRoute().Subrouter()
authRouter.Use(handlers.RequireValidLicenseIDMiddleware)

cachedRouter := r.NewRoute().Subrouter()
cacheHandler := handlers.CacheMiddleware(handlers.NewCache(), handlers.CacheMiddlewareDefaultTTL)
cachedRouter.Use(cacheHandler)

r.HandleFunc("/healthz", handlers.Healthz)

// license
r.HandleFunc("/api/v1/license/info", handlers.GetLicenseInfo).Methods("GET")
r.HandleFunc("/api/v1/license/fields", handlers.GetLicenseFields).Methods("GET")
r.HandleFunc("/api/v1/license/fields/{fieldName}", handlers.GetLicenseField).Methods("GET")
cachedRouter.HandleFunc("/api/v1/license/info", handlers.GetLicenseInfo).Methods("GET")
cachedRouter.HandleFunc("/api/v1/license/fields", handlers.GetLicenseFields).Methods("GET")
cachedRouter.HandleFunc("/api/v1/license/fields/{fieldName}", handlers.GetLicenseField).Methods("GET")

// app
r.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET")
r.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET")
r.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET")
r.HandleFunc("/api/v1/app/custom-metrics", handlers.CacheMiddleware(handlers.SendCustomAppMetrics, DefaultCacheTTL)).Methods("POST", "PATCH")
r.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.CacheMiddleware(handlers.DeleteCustomAppMetricsKey, DefaultCacheTTL)).Methods("DELETE")
r.HandleFunc("/api/v1/app/instance-tags", handlers.CacheMiddleware(handlers.SendAppInstanceTags, DefaultCacheTTL)).Methods("POST")
cachedRouter.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET")
cachedRouter.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET")
cachedRouter.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET")
cachedRouter.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH")
cachedRouter.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE")
cachedRouter.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST")

// integration
r.HandleFunc("/api/v1/integration/mock-data", handlers.EnforceMockAccess(handlers.PostIntegrationMockData)).Methods("POST")
Expand Down
17 changes: 11 additions & 6 deletions pkg/handlers/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sync"
"time"

"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/replicatedhq/replicated-sdk/pkg/handlers/types"
"github.com/replicatedhq/replicated-sdk/pkg/logger"
Expand Down Expand Up @@ -116,32 +117,36 @@ func (r *responseRecorder) Write(b []byte) (int, error) {
return r.ResponseWriter.Write(b)
}

func CacheMiddleware(next http.HandlerFunc, duration time.Duration) http.HandlerFunc {
// Each handler has its own cache to reduce contention for the in-memory store
cache := NewCache()
const CacheMiddlewareDefaultTTL = 1 * time.Minute

func CacheMiddleware(cache *cache, duration time.Duration) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return cacheMiddleware(next, cache, duration)
}
}

func cacheMiddleware(next http.Handler, cache *cache, duration time.Duration) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
logger.Error(errors.Wrap(err, "cache middleware - failed to read request body"))
http.Error(w, "cache middleware: unable to read request body", http.StatusInternalServerError)
return
}

r.Body = io.NopCloser(bytes.NewBuffer(body))

hash := sha256.Sum256([]byte(r.Method + "::" + r.URL.Path + "::" + r.URL.Query().Encode()))

key := fmt.Sprintf("%x", hash)

if entry, found := cache.Get(key); found && IsSamePayload(entry.RequestBody, body) {
logger.Infof("cache middleware: serving cached payload for method: %s path: %s ttl: %s ", r.Method, r.URL.Path, time.Until(entry.Expiry).Round(time.Second).String())
w.Header().Set("X-Replicated-Rate-Limited", "true")
JSONCached(w, entry.StatusCode, json.RawMessage(entry.ResponseBody))
return
}

recorder := &responseRecorder{ResponseWriter: w, Body: &bytes.Buffer{}}
next(recorder, r)
next.ServeHTTP(recorder, r)

// Save only successful responses in the cache
if recorder.StatusCode < 200 || recorder.StatusCode >= 300 {
Expand Down
25 changes: 22 additions & 3 deletions pkg/handlers/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,28 +166,35 @@ func Test_CacheMiddleware(t *testing.T) {
})

duration := 1 * time.Minute
cachedHandler := CacheMiddleware(handler, duration)
cache := NewCache()
cachedHandler := CacheMiddleware(cache, duration).Middleware(handler)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache
require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited

/* Second request should be served from cache since the payload it the same */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should exist because the response is rate limited

/* Third request should not be served from cache since the payload is different */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 1111}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache
require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache

}

Expand All @@ -197,30 +204,37 @@ func Test_CacheMiddleware_Expiry(t *testing.T) {
})

duration := 100 * time.Millisecond
cachedHandler := CacheMiddleware(handler, duration)
cache := NewCache()
cachedHandler := CacheMiddleware(cache, duration).Middleware(handler)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache
require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache

/* Second request should be served from cache since the payload it the same and under the expiry time */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache
require.Equal(t, "true", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should exist because the response is rate limited

time.Sleep(110 * time.Millisecond)

/* Third request should not be served from cache due to expiry */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache
require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited

}

Expand All @@ -230,20 +244,25 @@ func Test_CacheMiddleware_DoNotCacheErroredPayload(t *testing.T) {
})

duration := 1 * time.Minute
cachedHandler := CacheMiddleware(handler, duration)
cache := NewCache()
cachedHandler := CacheMiddleware(cache, duration).Middleware(handler)

/* First request should not be served from cache */
req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache
require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache

/* Second request should not be served from cache - err'ed payloads are not cached */
req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`))
cachedHandler.ServeHTTP(recorder, req)

require.Equal(t, http.StatusInternalServerError, recorder.Code)
require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String())
require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache
require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited

}

0 comments on commit 639acd8

Please sign in to comment.