diff --git a/pkg/apiserver/server.go b/pkg/apiserver/server.go index 24cb3fd..6b45342 100644 --- a/pkg/apiserver/server.go +++ b/pkg/apiserver/server.go @@ -55,6 +55,10 @@ func Start(params APIServerParams) { authRouter := r.NewRoute().Subrouter() authRouter.Use(handlers.RequireValidLicenseIDMiddleware) + cacheHandler := handlers.CacheMiddleware(handlers.NewCache(), handlers.CacheMiddlewareDefaultTTL) + cachedRouter := r.NewRoute().Subrouter() + cachedRouter.Use(cacheHandler) + r.HandleFunc("/healthz", handlers.Healthz) // license @@ -66,9 +70,9 @@ func Start(params APIServerParams) { r.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET") r.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET") r.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET") - r.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH") - r.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE") - r.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST") + cachedRouter.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH") + cachedRouter.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE") + cachedRouter.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST") // integration r.HandleFunc("/api/v1/integration/mock-data", handlers.EnforceMockAccess(handlers.PostIntegrationMockData)).Methods("POST") diff --git a/pkg/handlers/middleware.go b/pkg/handlers/middleware.go index 9373901..246d11b 100644 --- a/pkg/handlers/middleware.go +++ b/pkg/handlers/middleware.go @@ -1,9 +1,20 @@ package handlers import ( + "bytes" + "crypto/sha256" + "encoding/json" + "fmt" + "io" "net/http" + "reflect" + "sync" + "time" + "github.com/gorilla/mux" + "github.com/pkg/errors" "github.com/replicatedhq/replicated-sdk/pkg/handlers/types" + "github.com/replicatedhq/replicated-sdk/pkg/logger" "github.com/replicatedhq/replicated-sdk/pkg/store" ) @@ -44,3 +55,134 @@ func RequireValidLicenseIDMiddleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +// Code for the cache middleware +type CacheEntry struct { + RequestBody []byte + ResponseBody []byte + StatusCode int + Expiry time.Time +} + +type cache struct { + store map[string]CacheEntry + mu sync.RWMutex +} + +func NewCache() *cache { + return &cache{ + store: map[string]CacheEntry{}, + } +} + +func (c *cache) Get(key string) (CacheEntry, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, found := c.store[key] + if !found || time.Now().After(entry.Expiry) { + return CacheEntry{}, false + } + return entry, true +} + +func (c *cache) Set(key string, entry CacheEntry, duration time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + // Clean up expired entries + for k, v := range c.store { + if time.Now().After(v.Expiry) { + delete(c.store, k) + } + } + + entry.Expiry = time.Now().Add(duration) + c.store[key] = entry +} + +type responseRecorder struct { + http.ResponseWriter + Body *bytes.Buffer + StatusCode int +} + +func (r *responseRecorder) WriteHeader(code int) { + r.StatusCode = code + r.ResponseWriter.WriteHeader(code) +} + +func (r *responseRecorder) Write(b []byte) (int, error) { + r.Body.Write(b) + return r.ResponseWriter.Write(b) +} + +const CacheMiddlewareDefaultTTL = 1 * time.Minute + +func CacheMiddleware(cache *cache, duration time.Duration) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return cacheMiddleware(next, cache, duration) + } +} + +func cacheMiddleware(next http.Handler, cache *cache, duration time.Duration) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + logger.Error(errors.Wrap(err, "cache middleware - failed to read request body")) + http.Error(w, "cache middleware: unable to read request body", http.StatusInternalServerError) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + hash := sha256.Sum256([]byte(r.Method + "::" + r.URL.Path + "::" + r.URL.Query().Encode())) + key := fmt.Sprintf("%x", hash) + + if entry, found := cache.Get(key); found && IsSamePayload(entry.RequestBody, body) { + logger.Infof("cache middleware: serving cached payload for method: %s path: %s ttl: %s ", r.Method, r.URL.Path, time.Until(entry.Expiry).Round(time.Second).String()) + w.Header().Set("X-Replicated-Rate-Limited", "true") + JSONCached(w, entry.StatusCode, json.RawMessage(entry.ResponseBody)) + return + } + + recorder := &responseRecorder{ResponseWriter: w, Body: &bytes.Buffer{}} + next.ServeHTTP(recorder, r) + + // Save only successful responses in the cache + if recorder.StatusCode < 200 || recorder.StatusCode >= 300 { + return + } + + cache.Set(key, CacheEntry{ + StatusCode: recorder.StatusCode, + RequestBody: body, + ResponseBody: recorder.Body.Bytes(), + }, duration) + + } +} + +func IsSamePayload(a, b []byte) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + + if len(a) == 0 { + a = []byte(`{}`) + } + + if len(b) == 0 { + b = []byte(`{}`) + } + + var aPayload, bPayload map[string]interface{} + if err := json.Unmarshal(a, &aPayload); err != nil { + logger.Error(errors.Wrap(err, "failed to unmarshal payload A")) + return false + } + if err := json.Unmarshal(b, &bPayload); err != nil { + logger.Error(errors.Wrap(err, "failed to unmarshal payload B")) + return false + } + return reflect.DeepEqual(aPayload, bPayload) +} diff --git a/pkg/handlers/middleware_test.go b/pkg/handlers/middleware_test.go new file mode 100644 index 0000000..8a97388 --- /dev/null +++ b/pkg/handlers/middleware_test.go @@ -0,0 +1,268 @@ +package handlers + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func Test_IsSamePayload(t *testing.T) { + tests := []struct { + name string + payloadA []byte + payloadB []byte + expect bool + }{ + { + name: "should return true for empty nil payloads", + payloadA: nil, + payloadB: nil, + expect: true, + }, + { + name: "should return true despite one payload being nil", + payloadA: []byte{}, + payloadB: nil, + expect: true, + }, + { + name: "should tolerate empty non-nil byte payloads", + payloadA: []byte{}, + payloadB: []byte{}, + expect: true, + }, + { + name: "should return false for different payloads where one payload is empty", + payloadA: []byte{}, + payloadB: []byte(`{"numPeople": 10}`), + expect: false, + }, + { + name: "should return false for different payloads", + payloadA: []byte(`{"numProjects": 2000}`), + payloadB: []byte(`{"numPeople": 10}`), + expect: false, + }, + { + name: "should return true for same payloads", + payloadA: []byte(`{"numPeople": 10}`), + payloadB: []byte(`{"numPeople": 10}`), + expect: true, + }, + { + name: "should return true for same payload despite differences in key ordering and spacing", + payloadA: []byte(`{"numProjects": 2000, "numPeople": 10 }`), + payloadB: []byte(`{"numPeople": 10, "numProjects": 2000}`), + expect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsSamePayload(tt.payloadA, tt.payloadB) + require.Equal(t, tt.expect, got) + }) + } + +} +func Test_cache(t *testing.T) { + tests := []struct { + name string + assertFn func(*testing.T, *cache) + }{ + { + name: "cache should be able to set and get KV pair with valid ttl", + assertFn: func(t *testing.T, c *cache) { + now := time.Now() + + entry := CacheEntry{ + RequestBody: []byte("request body"), + ResponseBody: []byte("response body"), + StatusCode: http.StatusOK, + } + c.Set("cache-key", entry, 1*time.Minute) + + cachedEntry, exists := c.Get("cache-key") + + require.True(t, exists) + require.Equal(t, entry.RequestBody, cachedEntry.RequestBody) + require.Equal(t, entry.ResponseBody, cachedEntry.ResponseBody) + require.Equal(t, entry.StatusCode, cachedEntry.StatusCode) + + // TTL should be valid + require.Equal(t, true, cachedEntry.Expiry.After(now)) + }, + }, + { + name: "cache get should return false for non-existent key", + assertFn: func(t *testing.T, c *cache) { + _, exists := c.Get("cache-key-does-not-exist") + require.False(t, exists) + }, + }, + { + name: "cache get should return false for expired cache entry", + assertFn: func(t *testing.T, c *cache) { + + entry := CacheEntry{ + RequestBody: []byte("request body"), + ResponseBody: []byte("response body"), + StatusCode: http.StatusOK, + } + c.Set("cache-key", entry, 5*time.Millisecond) + + time.Sleep(10 * time.Millisecond) + + _, exists := c.Get("cache-key") + + require.False(t, exists) + + }, + }, + { + name: "cache set should delete expired cache entries", + assertFn: func(t *testing.T, c *cache) { + + entry := CacheEntry{ + RequestBody: []byte("request body"), + ResponseBody: []byte("response body"), + StatusCode: http.StatusOK, + } + + c.Set("first-cache-key", entry, 10*time.Millisecond) + _, exists := c.Get("first-cache-key") + require.True(t, exists) + + time.Sleep(20 * time.Millisecond) + + c.Set("second-cache-key", entry, 1*time.Minute) + _, exists = c.Get("first-cache-key") + require.False(t, exists) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewCache() + tt.assertFn(t, c) + }) + } +} + +func newTestRequest(method, url string, body []byte) (*http.Request, *httptest.ResponseRecorder) { + req := httptest.NewRequest(method, url, bytes.NewReader(body)) + rec := httptest.NewRecorder() + return req, rec +} + +func Test_CacheMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + JSON(w, http.StatusOK, map[string]interface{}{"message": "Hello, World!"}) + }) + + duration := 1 * time.Minute + cache := NewCache() + cachedHandler := CacheMiddleware(cache, duration).Middleware(handler) + + /* First request should not be served from cache */ + req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) + require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited + + /* Second request should be served from cache since the payload it the same */ + req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) + require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache + require.Equal(t, "true", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should exist because the response is rate limited + + /* Third request should not be served from cache since the payload is different */ + req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 1111}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) + require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache + +} + +func Test_CacheMiddleware_Expiry(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + JSON(w, http.StatusOK, map[string]interface{}{"message": "Hello, World!"}) + }) + + duration := 100 * time.Millisecond + cache := NewCache() + cachedHandler := CacheMiddleware(cache, duration).Middleware(handler) + + /* First request should not be served from cache */ + req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) + require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache + + /* Second request should be served from cache since the payload it the same and under the expiry time */ + req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) + require.Equal(t, "true", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should exist because the response is served from cache + require.Equal(t, "true", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should exist because the response is rate limited + + time.Sleep(110 * time.Millisecond) + + /* Third request should not be served from cache due to expiry */ + req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, `{"message":"Hello, World!"}`, recorder.Body.String()) + require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited + +} + +func Test_CacheMiddleware_DoNotCacheErroredPayload(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + JSON(w, http.StatusInternalServerError, map[string]interface{}{"error": "Something went wrong!"}) + }) + + duration := 1 * time.Minute + cache := NewCache() + cachedHandler := CacheMiddleware(cache, duration).Middleware(handler) + + /* First request should not be served from cache */ + req, recorder := newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String()) + require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT served from cache + + /* Second request should not be served from cache - err'ed payloads are not cached */ + req, recorder = newTestRequest("POST", "/custom-metric", []byte(`{"data": {"numProjects": 2000}}`)) + cachedHandler.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) + require.Equal(t, `{"error":"Something went wrong!"}`, recorder.Body.String()) + require.Equal(t, "", recorder.Header().Get("X-Replicated-Served-From-Cache")) // Header should NOT exist because the response is NOT served from cache + require.Equal(t, "", recorder.Header().Get("X-Replicated-Rate-Limited")) // Header should NOT exist because the response is NOT rate limited + +} diff --git a/pkg/util/http.go b/pkg/util/http_client.go similarity index 100% rename from pkg/util/http.go rename to pkg/util/http_client.go