From 1cb98cf54cfbe634106fdb9329d838a56d4c71c6 Mon Sep 17 00:00:00 2001 From: Pablo Chacin Date: Thu, 20 Jun 2024 13:05:58 +0200 Subject: [PATCH] WIP: Implement cache server Signed-off-by: Pablo Chacin --- cache_server.go | 142 +++++++++++++++++++++++++++++++++++++++++++ cache_server_test.go | 123 +++++++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+) create mode 100644 cache_server.go create mode 100644 cache_server_test.go diff --git a/cache_server.go b/cache_server.go new file mode 100644 index 0000000..f020433 --- /dev/null +++ b/cache_server.go @@ -0,0 +1,142 @@ +package k6build + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" +) + +// CacheServerResponse is the response to a cache server request +type CacheServerResponse struct { + Error string + Object Object +} + +// CacheServer implements an http server that handles cache requests +type CacheServer struct { + cache Cache + baseURL string +} + +// NewCacheServer returns a CacheServer backed by a cache +func NewCacheServer(baseURL string, cache Cache) http.Handler { + cacheSrv := &CacheServer{ + baseURL: baseURL, + cache: cache, + } + + handler := http.NewServeMux() + handler.HandleFunc("/store", cacheSrv.Store) + handler.HandleFunc("/get", cacheSrv.Get) + handler.HandleFunc("/download", cacheSrv.Download) + + return handler +} + +// Get retrieves an objects if exists in the cache or an error otherwise +func (s *CacheServer) Get(w http.ResponseWriter, r *http.Request) { + resp := CacheServerResponse{} + + id := r.URL.Query().Get("id") + if id == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + object, err := s.cache.Get(context.Background(), id) //nolint:contextcheck + if err != nil { + if errors.Is(err, ErrObjectNotFound) { + w.WriteHeader(http.StatusNotFound) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + + // overwrite URL with own + resp.Object = Object{ + ID: id, + Checksum: object.Checksum, + URL: fmt.Sprintf(url.JoinPath(s.baseURL, object.ID)), + } + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) //nolint:errchkjson +} + +// Store stores the object and returns the metadata +func (s *CacheServer) Store(w http.ResponseWriter, r *http.Request) { + resp := CacheServerResponse{} + + id := r.URL.Query().Get("id") + if id == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + object, err := s.cache.Store(context.Background(), id, r.Body) //nolint:contextcheck + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + // overwrite URL with own + resp.Object = Object{ + ID: id, + Checksum: object.Checksum, + URL: fmt.Sprintf(url.JoinPath(s.baseURL, object.ID)), + } + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) //nolint:errchkjson +} + +// Download returns an object's content given its id +func (s *CacheServer) Download(w http.ResponseWriter, r *http.Request) { + id := r.URL.Query().Get("id") + if id == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + object, err := s.cache.Get(context.Background(), id) //nolint:contextcheck + if err != nil { + if errors.Is(err, ErrObjectNotFound) { + w.WriteHeader(http.StatusNotFound) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + + objectURL, err := url.Parse(object.URL) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + switch objectURL.Scheme { + case "file": + objectFile, err := os.Open(objectURL.Path) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + defer func() { + _ = objectFile.Close() + }() + + w.WriteHeader(http.StatusOK) + w.Header().Add("Content-Type", "application/binary") + w.Header().Add("ETag", object.ID) + _, _ = io.Copy(w, objectFile) + default: + w.WriteHeader(http.StatusInternalServerError) + return + } +} diff --git a/cache_server_test.go b/cache_server_test.go new file mode 100644 index 0000000..8ce4a59 --- /dev/null +++ b/cache_server_test.go @@ -0,0 +1,123 @@ +package k6build + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +type MemoryCache struct { + objects map[string]Object +} + +func NewMemoryCache() *MemoryCache { + return &MemoryCache{ + objects: map[string]Object{}, + } +} + +func (f *MemoryCache) Get(_ context.Context, id string) (Object, error) { + object, found := f.objects[id] + if !found { + return Object{}, ErrObjectNotFound + } + + return object, nil +} + +func (f *MemoryCache) Store(_ context.Context, id string, content io.Reader) (Object, error) { + buffer := bytes.Buffer{} + _, err := buffer.ReadFrom(content) + if err != nil { + return Object{}, ErrCreatingObject + } + + checksum := fmt.Sprintf("%x", sha256.Sum256(buffer.Bytes())) + object := Object{ + ID: id, + Checksum: checksum, + URL: fmt.Sprintf("memory://%s", id), + } + + f.objects[id] = object + + return object, nil +} + +func TestCacheServer(t *testing.T) { + t.Parallel() + + cache := NewMemoryCache() + objects := map[string][]byte{ + "object1": []byte("content object 1"), + } + + for id, content := range objects { + buffer := bytes.NewBuffer(content) + if _, err := cache.Store(context.TODO(), id, buffer); err != nil { + t.Fatalf("test setup: %v", err) + } + } + + cacheSrv := NewCacheServer("", cache) + + srv := httptest.NewServer(cacheSrv) + + testCases := []struct { + title string + id string + status int + epectErr string + }{ + { + title: "return object", + id: "object1", + status: http.StatusOK, + }, + { + title: "object not found", + id: "not_found", + status: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.title, func(t *testing.T) { + t.Parallel() + + url := fmt.Sprintf("%s/get?id=%s", srv.URL, tc.id) + resp, err := http.Get(url) + if err != nil { + t.Fatalf("accessing server %v", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != tc.status { + t.Fatalf("expected %s got %s", http.StatusText(tc.status), resp.Status) + } + + if tc.status != http.StatusOK { + return + } + + cacheResponse := CacheServerResponse{} + err = json.NewDecoder(resp.Body).Decode(&cacheResponse) + if err != nil { + t.Fatalf("reading response content %v", err) + } + + if cacheResponse.Object.ID != tc.id { + t.Fatalf("expected object id %s got %s", tc.id, cacheResponse.Object.ID) + } + }) + } +}