From 639acd81e3b9b4347f9f7805d49a8e75d681306c Mon Sep 17 00:00:00 2001 From: Siva Manivannan Date: Thu, 27 Jun 2024 22:34:14 -0500 Subject: [PATCH] make cache gorilla mux compatible --- pkg/apiserver/server.go | 24 +++++++++++++----------- pkg/handlers/middleware.go | 17 +++++++++++------ pkg/handlers/middleware_test.go | 25 ++++++++++++++++++++++--- 3 files changed, 46 insertions(+), 20 deletions(-) diff --git a/pkg/apiserver/server.go b/pkg/apiserver/server.go index e857aef..382fdde 100644 --- a/pkg/apiserver/server.go +++ b/pkg/apiserver/server.go @@ -51,26 +51,28 @@ func Start(params APIServerParams) { r := mux.NewRouter() r.Use(handlers.CorsMiddleware) - const DefaultCacheTTL = 1 * time.Minute - // TODO: make all routes authenticated authRouter := r.NewRoute().Subrouter() authRouter.Use(handlers.RequireValidLicenseIDMiddleware) + cachedRouter := r.NewRoute().Subrouter() + cacheHandler := handlers.CacheMiddleware(handlers.NewCache(), handlers.CacheMiddlewareDefaultTTL) + cachedRouter.Use(cacheHandler) + r.HandleFunc("/healthz", handlers.Healthz) // license - r.HandleFunc("/api/v1/license/info", handlers.GetLicenseInfo).Methods("GET") - r.HandleFunc("/api/v1/license/fields", handlers.GetLicenseFields).Methods("GET") - r.HandleFunc("/api/v1/license/fields/{fieldName}", handlers.GetLicenseField).Methods("GET") + cachedRouter.HandleFunc("/api/v1/license/info", handlers.GetLicenseInfo).Methods("GET") + cachedRouter.HandleFunc("/api/v1/license/fields", handlers.GetLicenseFields).Methods("GET") + cachedRouter.HandleFunc("/api/v1/license/fields/{fieldName}", handlers.GetLicenseField).Methods("GET") // app - r.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET") - r.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET") - r.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET") - r.HandleFunc("/api/v1/app/custom-metrics", handlers.CacheMiddleware(handlers.SendCustomAppMetrics, DefaultCacheTTL)).Methods("POST", "PATCH") - r.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.CacheMiddleware(handlers.DeleteCustomAppMetricsKey, DefaultCacheTTL)).Methods("DELETE") - r.HandleFunc("/api/v1/app/instance-tags", handlers.CacheMiddleware(handlers.SendAppInstanceTags, DefaultCacheTTL)).Methods("POST") + cachedRouter.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET") + cachedRouter.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET") + cachedRouter.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET") + cachedRouter.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH") + cachedRouter.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE") + cachedRouter.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST") // integration r.HandleFunc("/api/v1/integration/mock-data", handlers.EnforceMockAccess(handlers.PostIntegrationMockData)).Methods("POST") diff --git a/pkg/handlers/middleware.go b/pkg/handlers/middleware.go index b9856fa..246d11b 100644 --- a/pkg/handlers/middleware.go +++ b/pkg/handlers/middleware.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/gorilla/mux" "github.com/pkg/errors" "github.com/replicatedhq/replicated-sdk/pkg/handlers/types" "github.com/replicatedhq/replicated-sdk/pkg/logger" @@ -116,10 +117,15 @@ func (r *responseRecorder) Write(b []byte) (int, error) { return r.ResponseWriter.Write(b) } -func CacheMiddleware(next http.HandlerFunc, duration time.Duration) http.HandlerFunc { - // Each handler has its own cache to reduce contention for the in-memory store - cache := NewCache() +const CacheMiddlewareDefaultTTL = 1 * time.Minute +func CacheMiddleware(cache *cache, duration time.Duration) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return cacheMiddleware(next, cache, duration) + } +} + +func cacheMiddleware(next http.Handler, cache *cache, duration time.Duration) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -127,21 +133,20 @@ func CacheMiddleware(next http.HandlerFunc, duration time.Duration) http.Handler http.Error(w, "cache middleware: unable to read request body", http.StatusInternalServerError) return } - r.Body = io.NopCloser(bytes.NewBuffer(body)) hash := sha256.Sum256([]byte(r.Method + "::" + r.URL.Path + "::" + r.URL.Query().Encode())) - key := fmt.Sprintf("%x", hash) if entry, found := cache.Get(key); found && IsSamePayload(entry.RequestBody, body) { logger.Infof("cache middleware: serving cached payload for method: %s path: %s ttl: %s ", r.Method, r.URL.Path, time.Until(entry.Expiry).Round(time.Second).String()) + w.Header().Set("X-Replicated-Rate-Limited", "true") JSONCached(w, entry.StatusCode, json.RawMessage(entry.ResponseBody)) return } recorder := &responseRecorder{ResponseWriter: w, Body: &bytes.Buffer{}} - next(recorder, r) + next.ServeHTTP(recorder, r) // Save only successful responses in the cache if recorder.StatusCode < 200 || recorder.StatusCode >= 300 { diff --git a/pkg/handlers/middleware_test.go b/pkg/handlers/middleware_test.go index a5b4cce..8a97388 100644 --- a/pkg/handlers/middleware_test.go +++ b/pkg/handlers/middleware_test.go @@ -166,28 +166,35 @@ func Test_CacheMiddleware(t *testing.T) { }) duration := 1 * time.Minute - cachedHandler := CacheMiddleware(handler, duration) + cache := NewCache() + cachedHandler := CacheMiddleware(cache, duration).Middleware(handler) /* First request should not be served from cache */ req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited /* Second request should be served from cache since the payload it the same */ req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache + require.Equal(t, "true", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should exist because the response is rate limited /* Third request should not be served from cache since the payload is different */ req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 1111}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache } @@ -197,30 +204,37 @@ func Test_CacheMiddleware_Expiry(t *testing.T) { }) duration := 100 * time.Millisecond - cachedHandler := CacheMiddleware(handler, duration) + cache := NewCache() + cachedHandler := CacheMiddleware(cache, duration).Middleware(handler) /* First request should not be served from cache */ req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache /* Second request should be served from cache since the payload it the same and under the expiry time */ req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache + require.Equal(t, "true", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should exist because the response is rate limited time.Sleep(110 * time.Millisecond) /* Third request should not be served from cache due to expiry */ req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited } @@ -230,20 +244,25 @@ func Test_CacheMiddleware_DoNotCacheErroredPayload(t *testing.T) { }) duration := 1 * time.Minute - cachedHandler := CacheMiddleware(handler, duration) + cache := NewCache() + cachedHandler := CacheMiddleware(cache, duration).Middleware(handler) /* First request should not be served from cache */ req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusInternalServerError, recorder.Code) require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String()) require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache /* Second request should not be served from cache - err'ed payloads are not cached */ req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) cachedHandler.ServeHTTP(recorder, req) + require.Equal(t, http.StatusInternalServerError, recorder.Code) require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String()) require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited }