From e5234dd7fe4baa5f72c8eeeefdaf59f01de3364b Mon Sep 17 00:00:00 2001 From: pablochacin Date: Fri, 13 Dec 2024 17:15:05 +0100 Subject: [PATCH] refactor store download method (#76) * refactor store download method Signed-off-by: Pablo Chacin --- pkg/store/downloader/downloader.go | 82 +++++++++++++ pkg/store/downloader/downloader_test.go | 124 +++++++++++++++++++ pkg/store/file/file.go | 46 ------- pkg/store/file/file_test.go | 154 +++++++----------------- pkg/store/s3/s3.go | 24 ---- pkg/store/server/server.go | 3 +- pkg/store/server/server_test.go | 79 +++--------- pkg/store/store.go | 5 +- 8 files changed, 268 insertions(+), 249 deletions(-) create mode 100644 pkg/store/downloader/downloader.go create mode 100644 pkg/store/downloader/downloader_test.go diff --git a/pkg/store/downloader/downloader.go b/pkg/store/downloader/downloader.go new file mode 100644 index 0000000..670be54 --- /dev/null +++ b/pkg/store/downloader/downloader.go @@ -0,0 +1,82 @@ +// Package downloader implements utility functions for downloading objects from a store +package downloader + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + + "github.com/grafana/k6build" + "github.com/grafana/k6build/pkg/store" + "github.com/grafana/k6build/pkg/util" +) + +// Download returns the content referenced by object.URL (file, http or https). The caller must close the reader. +func Download(ctx context.Context, object store.Object) (io.ReadCloser, error) { + objectURL, err := url.Parse(object.URL) + if err != nil { + return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) + } + + switch objectURL.Scheme { + case "file": + objectPath, err := util.URLToFilePath(objectURL) + if err != nil { + return nil, err + } + + // prevent malicious path + objectPath, err = sanitizePath(objectPath) + if err != nil { + return nil, err + } + + objectFile, err := os.Open(objectPath) //nolint:gosec // path is sanitized + if err != nil { + // FIXME: if the path has invalid characters, it will still return ErrNotExist + if errors.Is(err, 
os.ErrNotExist) { + return nil, store.ErrObjectNotFound + } + return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) + } + + return objectFile, nil + case "http", "https": + req, err := http.NewRequestWithContext(ctx, http.MethodGet, object.URL, nil) + if err != nil { + return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() // not returned on error: close it so the connection is not leaked + if resp.StatusCode == http.StatusNotFound { + return nil, store.ErrObjectNotFound + } + return nil, k6build.NewWrappedError(store.ErrAccessingObject, fmt.Errorf("HTTP response: %s", resp.Status)) + } + + return resp.Body, nil + default: + return nil, fmt.Errorf("%w unsupported schema: %s", store.ErrInvalidURL, objectURL.Scheme) + } +} + +func sanitizePath(path string) (string, error) { + path = filepath.Clean(path) + + if !filepath.IsAbs(path) { + return "", fmt.Errorf("%w : invalid path %s", store.ErrInvalidURL, path) + } + + return path, nil +} diff --git a/pkg/store/downloader/downloader_test.go b/pkg/store/downloader/downloader_test.go new file mode 100644 index 0000000..b2bb50d --- /dev/null +++ b/pkg/store/downloader/downloader_test.go @@ -0,0 +1,124 @@ +package downloader + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/grafana/k6build/pkg/store" + "github.com/grafana/k6build/pkg/util" +) + +func fileURL(dir string, path string) string { + url, err := util.URLFromFilePath(filepath.Join(dir, path)) + if err != nil { + panic(err) + } + return url.String() +} + +func httpURL(srv *httptest.Server, path string) string { + return srv.URL + "/" + path +} + +func TestDownload(t *testing.T) { + t.Parallel() + + storeDir := t.TempDir() + + objects := []struct { + id string + content []byte + }{ + { + id: "object", + content: []byte("content"), + }, + } + + for _, o 
:= range objects { + if err := os.WriteFile(filepath.Join(storeDir, o.id), o.content, 0o600); err != nil { + t.Fatalf("test setup %v", err) + } + } + + srv := httptest.NewServer(http.FileServer(http.Dir(storeDir))) + t.Cleanup(srv.Close) + + testCases := []struct { + title string + id string + url string + expected []byte + expectErr error + }{ + { + title: "download file url", + id: "object", + url: fileURL(storeDir, "object"), + expected: []byte("content"), + expectErr: nil, + }, + { + title: "download non existing file url", + id: "object", + url: fileURL(storeDir, "another_object"), + expectErr: store.ErrObjectNotFound, + }, + // FIXME: can't check url is outside object store's directory + // { + // title: "download malicious file url", + // id: "object", + // url: fileURL(storeDir, "/../../object"), + // expectErr: store.ErrInvalidURL, + // }, + { + title: "download http url", + id: "object", + url: httpURL(srv, "object"), + expected: []byte("content"), + expectErr: nil, + }, + { + title: "download non existing http url", + id: "object", + url: httpURL(srv, "another-object"), + expectErr: store.ErrObjectNotFound, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.title, func(t *testing.T) { + t.Parallel() + + object := store.Object{ID: tc.id, URL: tc.url} + content, err := Download(context.TODO(), object) + if !errors.Is(err, tc.expectErr) { + t.Fatalf("expected %v got %v", tc.expectErr, err) + } + + // if expected error, don't check returned object + if tc.expectErr != nil { + return + } + + defer content.Close() //nolint:errcheck + + data := bytes.Buffer{} + _, err = data.ReadFrom(content) + if err != nil { + t.Fatalf("reading content: %v", err) + } + + if !bytes.Equal(data.Bytes(), tc.expected) { + t.Fatalf("expected %v got %v", tc.expected, data) + } + }) + } +} diff --git a/pkg/store/file/file.go b/pkg/store/file/file.go index 7e5bf69..e68e3e4 100644 --- a/pkg/store/file/file.go +++ b/pkg/store/file/file.go @@ -8,7 +8,6 @@ import ( "errors" 
"fmt" "io" - "net/url" "os" "path/filepath" "strings" @@ -129,51 +128,6 @@ func (f *Store) Get(_ context.Context, id string) (store.Object, error) { }, nil } -// Download returns the content of the object given its url -func (f *Store) Download(_ context.Context, object store.Object) (io.ReadCloser, error) { - url, err := url.Parse(object.URL) - if err != nil { - return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) - } - - switch url.Scheme { - case "file": - objectPath, err := util.URLToFilePath(url) - if err != nil { - return nil, err - } - - // prevent malicious path - objectPath, err = f.sanitizePath(objectPath) - if err != nil { - return nil, err - } - - objectFile, err := os.Open(objectPath) //nolint:gosec // path is sanitized - if err != nil { - // FIXME: is the path has invalid characters, still will return ErrNotExists - if errors.Is(err, os.ErrNotExist) { - return nil, store.ErrObjectNotFound - } - return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) - } - - return objectFile, nil - default: - return nil, fmt.Errorf("%w unsupported schema: %s", store.ErrInvalidURL, url.Scheme) - } -} - -func (f *Store) sanitizePath(path string) (string, error) { - path = filepath.Clean(path) - - if !filepath.IsAbs(path) || !strings.HasPrefix(path, f.dir) { - return "", fmt.Errorf("%w : invalid path %s", store.ErrInvalidURL, path) - } - - return path, nil -} - // lockObject obtains a mutex used to prevent concurrent builds of the same artifact and // returns a function that will unlock the mutex associated to the given id in the object store. // The lock is also removed from the map. 
Subsequent calls will get another lock on the same diff --git a/pkg/store/file/file_test.go b/pkg/store/file/file_test.go index 9c5d6f7..e881f8e 100644 --- a/pkg/store/file/file_test.go +++ b/pkg/store/file/file_test.go @@ -7,7 +7,6 @@ import ( "fmt" "net/url" "os" - "path/filepath" "testing" "github.com/grafana/k6build/pkg/store" @@ -123,7 +122,7 @@ func TestFileStoreStoreObject(t *testing.T) { } } -func TestFileStoreRetrieval(t *testing.T) { +func TestFileStoreGet(t *testing.T) { t.Parallel() preload := []object{ @@ -139,118 +138,53 @@ func TestFileStoreRetrieval(t *testing.T) { t.Fatalf("test setup: %v", err) } - t.Run("TestFileStoreGet", func(t *testing.T) { - testCases := []struct { - title string - id string - expected []byte - expectErr error - }{ - { - title: "retrieve existing object", - id: "object", - expected: []byte("content"), - expectErr: nil, - }, - { - title: "retrieve non existing object", - id: "another object", - expectErr: store.ErrObjectNotFound, - }, - } - - for _, tc := range testCases { - t.Run(tc.title, func(t *testing.T) { - t.Parallel() - - obj, err := fileStore.Get(context.TODO(), tc.id) - if !errors.Is(err, tc.expectErr) { - t.Fatalf("expected %v got %v", tc.expectErr, err) - } - - // if expected error, don't check returned object - if tc.expectErr != nil { - return - } - - objectURL, _ := url.Parse(obj.URL) - fileUPath, err := util.URLToFilePath(objectURL) - if err != nil { - t.Fatalf("invalid url %v", err) - } - - data, err := os.ReadFile(fileUPath) - if err != nil { - t.Fatalf("reading object url %v", err) - } - - if !bytes.Equal(data, tc.expected) { - t.Fatalf("expected %v got %v", tc.expected, data) - } - }) - } - }) - - // FIXME: This test is leaking how the file store creates the URLs for the objects - t.Run("TestFileStoreDownload", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - title string - id string - url string - expected []byte - expectErr error - }{ - { - title: "download existing object", - id: 
"object", - url: filepath.Join(storeDir, "/object/data"), - expected: []byte("content"), - expectErr: nil, - }, - { - title: "download non existing object", - id: "object", - url: filepath.Join(storeDir, "/another_object/data"), - expectErr: store.ErrObjectNotFound, - }, - { - title: "download malicious url", - id: "object", - url: filepath.Join(storeDir, "/../../data"), - expectErr: store.ErrInvalidURL, - }, - } + testCases := []struct { + title string + id string + expected []byte + expectErr error + }{ + { + title: "retrieve existing object", + id: "object", + expected: []byte("content"), + expectErr: nil, + }, + { + title: "retrieve non existing object", + id: "another object", + expectErr: store.ErrObjectNotFound, + }, + } - for _, tc := range testCases { - t.Run(tc.title, func(t *testing.T) { - t.Parallel() + for _, tc := range testCases { + t.Run(tc.title, func(t *testing.T) { + t.Parallel() - objectURL, _ := util.URLFromFilePath(tc.url) - object := store.Object{ID: tc.id, URL: objectURL.String()} - content, err := fileStore.Download(context.TODO(), object) - if !errors.Is(err, tc.expectErr) { - t.Fatalf("expected %v got %v", tc.expectErr, err) - } + obj, err := fileStore.Get(context.TODO(), tc.id) + if !errors.Is(err, tc.expectErr) { + t.Fatalf("expected %v got %v", tc.expectErr, err) + } - // if expected error, don't check returned object - if tc.expectErr != nil { - return - } + // if expected error, don't check returned object + if tc.expectErr != nil { + return + } - defer content.Close() //nolint:errcheck + objectURL, _ := url.Parse(obj.URL) + fileUPath, err := util.URLToFilePath(objectURL) + if err != nil { + t.Fatalf("invalid url %v", err) + } - data := bytes.Buffer{} - _, err = data.ReadFrom(content) - if err != nil { - t.Fatalf("reading content: %v", err) - } + data, err := os.ReadFile(fileUPath) + if err != nil { + t.Fatalf("reading object url %v", err) + } - if !bytes.Equal(data.Bytes(), tc.expected) { - t.Fatalf("expected %v got %v", 
tc.expected, data) - } - }) - } - }) + if !bytes.Equal(data, tc.expected) { + t.Fatalf("expected %v got %v", tc.expected, data) + } + }) + } } diff --git a/pkg/store/s3/s3.go b/pkg/store/s3/s3.go index 619d71d..3f165e9 100644 --- a/pkg/store/s3/s3.go +++ b/pkg/store/s3/s3.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "net/http" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -178,29 +177,6 @@ func (s *Store) Get(ctx context.Context, id string) (store.Object, error) { }, nil } -// Download returns the content of the object given its url -func (s *Store) Download(ctx context.Context, object store.Object) (io.ReadCloser, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, object.URL, nil) - if err != nil { - return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) - } - - if resp.StatusCode == http.StatusNotFound { - return nil, k6build.NewWrappedError(store.ErrAccessingObject, err) - } - - if resp.StatusCode != http.StatusOK { - return nil, k6build.NewWrappedError(store.ErrAccessingObject, fmt.Errorf("HTTP response: %s", resp.Status)) - } - - return resp.Body, nil -} - func (s *Store) getDownloadURL(ctx context.Context, id string) (string, error) { // create a presigned get request to get the download URL request, err := s3.NewPresignClient(s.client).PresignGetObject( diff --git a/pkg/store/server/server.go b/pkg/store/server/server.go index 2328e4f..1a8a15d 100644 --- a/pkg/store/server/server.go +++ b/pkg/store/server/server.go @@ -13,6 +13,7 @@ import ( "github.com/grafana/k6build" "github.com/grafana/k6build/pkg/store" "github.com/grafana/k6build/pkg/store/api" + "github.com/grafana/k6build/pkg/store/downloader" ) // StoreServer implements an http server that handles object store requests @@ -162,7 +163,7 @@ func (s *StoreServer) Download(w http.ResponseWriter, r *http.Request) { return } - 
objectContent, err := s.store.Download(context.Background(), object) //nolint:contextcheck + objectContent, err := downloader.Download(context.Background(), object) //nolint:contextcheck if err != nil { w.WriteHeader(http.StatusInternalServerError) return diff --git a/pkg/store/server/server_test.go b/pkg/store/server/server_test.go index a1fbe23..89d0940 100644 --- a/pkg/store/server/server_test.go +++ b/pkg/store/server/server_test.go @@ -3,81 +3,23 @@ package server import ( "bytes" "context" - "crypto/sha256" "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" - "net/url" - "strings" "testing" - "github.com/grafana/k6build/pkg/store" "github.com/grafana/k6build/pkg/store/api" + "github.com/grafana/k6build/pkg/store/file" ) -// MemoryStore implements a memory backed object store -type MemoryStore struct { - objects map[string]store.Object - content map[string][]byte -} - -func NewMemoryStore() *MemoryStore { - return &MemoryStore{ - objects: map[string]store.Object{}, - content: map[string][]byte{}, - } -} - -func (f *MemoryStore) Get(_ context.Context, id string) (store.Object, error) { - object, found := f.objects[id] - if !found { - return store.Object{}, store.ErrObjectNotFound - } - - return object, nil -} - -func (f *MemoryStore) Put(_ context.Context, id string, content io.Reader) (store.Object, error) { - buffer := bytes.Buffer{} - _, err := buffer.ReadFrom(content) - if err != nil { - return store.Object{}, store.ErrCreatingObject - } - - checksum := fmt.Sprintf("%x", sha256.Sum256(buffer.Bytes())) - object := store.Object{ - ID: id, - Checksum: checksum, - URL: fmt.Sprintf("memory:///%s", id), - } - - f.objects[id] = object - f.content[id] = buffer.Bytes() - - return object, nil -} - -func (f *MemoryStore) Download(_ context.Context, object store.Object) (io.ReadCloser, error) { - url, err := url.Parse(object.URL) - if err != nil { - return nil, err - } - - id, _ := strings.CutPrefix(url.Path, "/") - content, found := f.content[id] - if !found 
{ - return nil, store.ErrObjectNotFound - } - - return io.NopCloser(bytes.NewBuffer(content)), nil -} - func TestStoreServerGet(t *testing.T) { t.Parallel() - store := NewMemoryStore() + store, err := file.NewFileStore(t.TempDir()) + if err != nil { + t.Fatalf("creating test file store %v", err) + } objects := map[string][]byte{ "object1": []byte("content object 1"), } @@ -155,7 +97,10 @@ func TestStoreServerGet(t *testing.T) { func TestStoreServerPut(t *testing.T) { t.Parallel() - store := NewMemoryStore() + store, err := file.NewFileStore(t.TempDir()) + if err != nil { + t.Fatalf("creating test file store %v", err) + } config := StoreServerConfig{ Store: store, @@ -223,7 +168,11 @@ func TestStoreServerPut(t *testing.T) { func TestStoreServerDownload(t *testing.T) { t.Parallel() - store := NewMemoryStore() + store, err := file.NewFileStore(t.TempDir()) + if err != nil { + t.Fatalf("creating test file store %v", err) + } + objects := map[string][]byte{ "object1": []byte("content object 1"), } diff --git a/pkg/store/store.go b/pkg/store/store.go index d1a8c3b..29c7001 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -13,8 +13,9 @@ var ( ErrAccessingObject = errors.New("accessing object") //nolint:revive ErrCreatingObject = errors.New("creating object") //nolint:revive ErrInitializingStore = errors.New("initializing store") //nolint:revive - ErrObjectNotFound = errors.New("object not found") //nolint:revive ErrInvalidURL = errors.New("invalid object URL") //nolint:revive + ErrObjectNotFound = errors.New("object not found") //nolint:revive + ErrNotSupported = errors.New("not supported") //nolint:revive ) @@ -42,6 +43,4 @@ type ObjectStore interface { Get(ctx context.Context, id string) (Object, error) // Put stores the object and returns the metadata Put(ctx context.Context, id string, content io.Reader) (Object, error) - // Download returns the content of the object - Download(ctx context.Context, object Object) (io.ReadCloser, error) }