diff --git a/esti/s3_gateway_test.go b/esti/s3_gateway_test.go
index d4936e30809..00295616a72 100644
--- a/esti/s3_gateway_test.go
+++ b/esti/s3_gateway_test.go
@@ -328,6 +328,88 @@ func TestS3HeadBucket(t *testing.T) {
 	})
 }
 
+func TestS3CopyObjectMultipart(t *testing.T) {
+	ctx, _, repo := setupTest(t)
+	defer tearDownTest(repo)
+
+	// additional repository for copy between repos
+	const destRepoName = "tests3copyobjectmultipartdest"
+	destRepo := createRepositoryByName(ctx, t, destRepoName)
+	defer deleteRepositoryIfAskedTo(ctx, destRepoName)
+
+	// content
+	r := rand.New(rand.NewSource(17))
+	objContent := testutil.NewRandomReader(r, largeDataContentLength)
+	srcPath := gatewayTestPrefix + "source-file"
+	destPath := gatewayTestPrefix + "dest-file"
+
+	// upload data
+	s3lakefsClient := newMinioClient(t, credentials.NewStaticV4)
+	_, err := s3lakefsClient.PutObject(ctx, repo, srcPath, objContent, largeDataContentLength,
+		minio.PutObjectOptions{})
+	require.NoError(t, err)
+
+	dest := minio.CopyDestOptions{
+		Bucket: destRepo,
+		Object: destPath,
+	}
+
+	srcs := []minio.CopySrcOptions{
+		{
+			Bucket:     repo,
+			Object:     srcPath,
+			MatchRange: true,
+			Start:      0,
+			End:        minDataContentLengthForMultipart - 1,
+		}, {
+			Bucket:     repo,
+			Object:     srcPath,
+			MatchRange: true,
+			Start:      minDataContentLengthForMultipart,
+			End:        largeDataContentLength - 1,
+		},
+	}
+
+	ui, err := s3lakefsClient.ComposeObject(ctx, dest, srcs...)
+	if err != nil {
+		t.Fatalf("minio.Client.ComposeObject from(%+v) to(%+v): %s", srcs, dest, err)
+	}
+
+	if ui.Size != largeDataContentLength {
+		t.Errorf("Copied %d bytes when expecting %d", ui.Size, largeDataContentLength)
+	}
+
+	// Comparing 2 readers is too much work. Instead just hash them.
+	// This will fail for malicious bad S3 gateways, but otherwise is
+	// fine.
+	uploadedReader, err := s3lakefsClient.GetObject(ctx, repo, srcPath, minio.GetObjectOptions{})
+	if err != nil {
+		t.Fatalf("Get uploaded object: %s", err)
+	}
+	defer uploadedReader.Close()
+	uploadedCRC, err := testutil.ChecksumReader(uploadedReader)
+	if err != nil {
+		t.Fatalf("Read uploaded object: %s", err)
+	}
+	if uploadedCRC == 0 {
+		t.Fatal("Impossibly bad luck: uploaded data with CRC64 == 0!")
+	}
+
+	copiedReader, err := s3lakefsClient.GetObject(ctx, destRepo, destPath, minio.GetObjectOptions{})
+	if err != nil {
+		t.Fatalf("Get copied object: %s", err)
+	}
+	defer copiedReader.Close()
+	copiedCRC, err := testutil.ChecksumReader(copiedReader)
+	if err != nil {
+		t.Fatalf("Read copied object: %s", err)
+	}
+
+	if uploadedCRC != copiedCRC {
+		t.Fatal("Copy not equal")
+	}
+}
+
 func TestS3CopyObject(t *testing.T) {
 	ctx, _, repo := setupTest(t)
 	defer tearDownTest(repo)
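For readers unfamiliar with minio-go: ComposeObject drives exactly the server-side multipart flow this test exercises — CreateMultipartUpload, one UploadPartCopy per source range, then CompleteMultipartUpload. A rough equivalent using the AWS SDK for Go (v1) is sketched below; the bucket, key, and endpoint wiring are illustrative placeholders, not part of this change, and session configuration against the gateway is elided.

```go
package main

import (
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
)

func main() {
	// Assumes a session already configured to point at the S3 gateway.
	svc := s3.New(session.Must(session.NewSession()))

	// 1. Start a multipart upload on the destination.
	mpu, err := svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
		Bucket: aws.String("dest-repo"),
		Key:    aws.String("main/dest-file"),
	})
	if err != nil {
		log.Fatal(err)
	}

	// 2. Copy a 5MiB range of the source object as part 1. The copy
	// source names another bucket (repository), which is what the
	// cross-repo test above relies on.
	part, err := svc.UploadPartCopy(&s3.UploadPartCopyInput{
		Bucket:          aws.String("dest-repo"),
		Key:             aws.String("main/dest-file"),
		UploadId:        mpu.UploadId,
		PartNumber:      aws.Int64(1),
		CopySource:      aws.String("source-repo/main/source-file"),
		CopySourceRange: aws.String(fmt.Sprintf("bytes=%d-%d", 0, (5<<20)-1)),
	})
	if err != nil {
		log.Fatal(err)
	}

	// 3. Complete the upload using the etag returned for each copied part.
	_, err = svc.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
		Bucket:   aws.String("dest-repo"),
		Key:      aws.String("main/dest-file"),
		UploadId: mpu.UploadId,
		MultipartUpload: &s3.CompletedMultipartUpload{
			Parts: []*s3.CompletedPart{
				{ETag: part.CopyPartResult.ETag, PartNumber: aws.Int64(1)},
			},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
}
```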
diff --git a/esti/system_test.go b/esti/system_test.go
index 6b0685d0e34..2d64255ab7d 100644
--- a/esti/system_test.go
+++ b/esti/system_test.go
@@ -144,7 +144,20 @@ func deleteRepositoryIfAskedTo(ctx context.Context, repositoryName string) {
 	}
 }
 
-const randomDataContentLength = 16
+const (
+	// randomDataContentLength is the content length used for small
+	// objects. It is intentionally not a round number.
+	randomDataContentLength = 16
+
+	// minDataContentLengthForMultipart is the content length for all
+	// parts of a multipart upload except the last. Its value -- 5MiB
+	// -- is defined in the S3 protocol, and cannot be changed.
+	minDataContentLengthForMultipart = 5 << 20
+
+	// largeDataContentLength is greater than
+	// minDataContentLengthForMultipart, so an object of this size
+	// requires multipart operations.
+	largeDataContentLength = 6 << 20
+)
 
 func uploadFileRandomDataAndReport(ctx context.Context, repo, branch, objPath string, direct bool) (checksum, content string, err error) {
 	objContent := randstr.String(randomDataContentLength)
diff --git a/pkg/block/azure/adapter.go b/pkg/block/azure/adapter.go
index 0e0da45df7c..104cc47d4f8 100644
--- a/pkg/block/azure/adapter.go
+++ b/pkg/block/azure/adapter.go
@@ -533,18 +533,7 @@ func (a *Adapter) copyPartRange(ctx context.Context, sourceObj, destinationObj b
 		return nil, err
 	}
 
-	destinationContainer, err := a.clientCache.NewContainerClient(qualifiedDestinationKey.StorageAccountName, qualifiedDestinationKey.ContainerName)
-	if err != nil {
-		return nil, err
-	}
-	sourceContainer, err := a.clientCache.NewContainerClient(qualifiedSourceKey.StorageAccountName, qualifiedSourceKey.ContainerName)
-	if err != nil {
-		return nil, err
-	}
-
-	sourceBlobURL := sourceContainer.NewBlockBlobClient(qualifiedSourceKey.BlobURL)
-
-	return copyPartRange(ctx, *destinationContainer, qualifiedDestinationKey.BlobURL, *sourceBlobURL, startPosition, count)
+	return copyPartRange(ctx, a.clientCache, qualifiedDestinationKey, qualifiedSourceKey, startPosition, count)
 }
 
 func (a *Adapter) AbortMultiPartUpload(_ context.Context, _ block.ObjectPointer, _ string) error {
diff --git a/pkg/block/azure/multipart_block_writer.go b/pkg/block/azure/multipart_block_writer.go
index 6b0ac03ef44..e82b1f0dec4 100644
--- a/pkg/block/azure/multipart_block_writer.go
+++ b/pkg/block/azure/multipart_block_writer.go
@@ -174,9 +174,20 @@ func getMultipartSize(ctx context.Context, container container.Client, objName s
 	return int64(size), nil
 }
 
-func copyPartRange(ctx context.Context, destinationContainer container.Client, destinationObjName string, sourceBlobURL blockblob.Client, startPosition, count int64) (*block.UploadPartResponse, error) {
+func copyPartRange(ctx context.Context, clientCache *ClientCache, destinationKey, sourceKey BlobURLInfo, startPosition, count int64) (*block.UploadPartResponse, error) {
+	destinationContainer, err := clientCache.NewContainerClient(destinationKey.StorageAccountName, destinationKey.ContainerName)
+	if err != nil {
+		return nil, fmt.Errorf("copy part: get destination client: %w", err)
+	}
+	sourceContainer, err := clientCache.NewContainerClient(sourceKey.StorageAccountName, sourceKey.ContainerName)
+	if err != nil {
+		return nil, fmt.Errorf("copy part: get source client: %w", err)
+	}
 	base64BlockID := generateRandomBlockID()
-	_, err := sourceBlobURL.StageBlockFromURL(ctx, base64BlockID, sourceBlobURL.URL(),
+	destinationBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL)
+	sourceBlob := sourceContainer.NewBlockBlobClient(sourceKey.BlobURL)
+
+	stageBlockResponse, err := destinationBlob.StageBlockFromURL(ctx, base64BlockID, sourceBlob.URL(),
 		&blockblob.StageBlockFromURLOptions{
 			Range: blob.HTTPRange{
 				Offset: startPosition,
@@ -187,25 +198,20 @@ func copyPartRange(ctx context.Context, destinationContainer container.Client, d
 		return nil, err
 	}
 
-	// add size and id to etag
-	response, err := sourceBlobURL.GetProperties(ctx, nil)
-	if err != nil {
-		return nil, err
-	}
-	etag := "\"" + hex.EncodeToString(response.ContentMD5) + "\""
-	size := response.ContentLength
+	// build the part etag from the staged block's content MD5
+	etag := "\"" + hex.EncodeToString(stageBlockResponse.ContentMD5) + "\""
 	base64Etag := base64.StdEncoding.EncodeToString([]byte(etag))
 
 	// stage id data
-	blobIDsURL := destinationContainer.NewBlockBlobClient(destinationObjName + idSuffix)
-	_, err = blobIDsURL.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(base64BlockID+"\n")), nil)
+	blobIDsBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + idSuffix)
+	_, err = blobIDsBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(base64BlockID+"\n")), nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed staging part data: %w", err)
 	}
 
 	// stage size data
-	sizeData := fmt.Sprintf("%d\n", size)
-	blobSizesURL := destinationContainer.NewBlockBlobClient(destinationObjName + sizeSuffix)
-	_, err = blobSizesURL.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sizeData)), nil)
+	sizeData := fmt.Sprintf("%d\n", count)
+	blobSizesBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + sizeSuffix)
+	_, err = blobSizesBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sizeData)), nil)
 	if err != nil {
 		return nil, fmt.Errorf("failed staging part size: %w", err)
 	}
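The substance of the Azure fix, for reference: blocks are now staged on the destination blob (pulling the byte range from the source URL), and the recorded part size is the requested count rather than the source blob's total ContentLength from GetProperties. Below is a minimal sketch of that staging pattern with the azblob SDK; client construction is elided, the helper name and parameters are illustrative, and srcURL is assumed to be readable by the service (e.g. it carries a SAS token).

```go
package azsketch

import (
	"context"
	"encoding/base64"

	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
)

// copyRangeSketch stages one byte range of a source blob URL as a block
// of the destination blob, then commits the block list to materialize
// the destination. The staging direction matches the fix: the
// destination client stages, the source only provides a URL.
func copyRangeSketch(ctx context.Context, dst *blockblob.Client, srcURL string, offset, count int64) error {
	// Block IDs must be base64-encoded, and every block of a given
	// blob must use IDs of the same length.
	blockID := base64.StdEncoding.EncodeToString([]byte("part-00001"))

	_, err := dst.StageBlockFromURL(ctx, blockID, srcURL, &blockblob.StageBlockFromURLOptions{
		Range: blob.HTTPRange{Offset: offset, Count: count},
	})
	if err != nil {
		return err
	}

	// Committing the ordered block list makes the staged bytes the
	// blob's content.
	_, err = dst.CommitBlockList(ctx, []string{blockID}, nil)
	return err
}
```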
diff --git a/pkg/block/local/adapter.go b/pkg/block/local/adapter.go
index e08abf3031d..6b604c61506 100644
--- a/pkg/block/local/adapter.go
+++ b/pkg/block/local/adapter.go
@@ -239,13 +239,13 @@ func (l *Adapter) UploadCopyPart(ctx context.Context, sourceObj, destinationObj
 	}
 	r, err := l.Get(ctx, sourceObj, 0)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("copy get: %w", err)
 	}
 	md5Read := block.NewHashingReader(r, block.HashFunctionMD5)
 	fName := uploadID + fmt.Sprintf("-%05d", partNumber)
 	err = l.Put(ctx, block.ObjectPointer{StorageNamespace: destinationObj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("copy put: %w", err)
 	}
 	etag := hex.EncodeToString(md5Read.Md5.Sum(nil))
 	return &block.UploadPartResponse{
@@ -259,13 +259,13 @@ func (l *Adapter) UploadCopyPartRange(ctx context.Context, sourceObj, destinatio
 	}
 	r, err := l.GetRange(ctx, sourceObj, startPosition, endPosition)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("copy range get: %w", err)
 	}
 	md5Read := block.NewHashingReader(r, block.HashFunctionMD5)
 	fName := uploadID + fmt.Sprintf("-%05d", partNumber)
 	err = l.Put(ctx, block.ObjectPointer{StorageNamespace: destinationObj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("copy range put: %w", err)
 	}
 	etag := hex.EncodeToString(md5Read.Md5.Sum(nil))
 	return &block.UploadPartResponse{
diff --git a/pkg/gateway/operations/putobject.go b/pkg/gateway/operations/putobject.go
index 48c5651ce4b..cf5f16e81de 100644
--- a/pkg/gateway/operations/putobject.go
+++ b/pkg/gateway/operations/putobject.go
@@ -66,14 +66,8 @@ func (controller *PutObject) RequiredPermissions(req *http.Request, repoID, _, d
 }
 
 // extractEntryFromCopyReq: get metadata from source file
-func extractEntryFromCopyReq(w http.ResponseWriter, req *http.Request, o *PathOperation, copySource string) *catalog.DBEntry {
-	p, err := getPathFromSource(copySource)
-	if err != nil {
-		o.Log(req).WithError(err).Error("could not parse copy source path")
-		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
-		return nil
-	}
-	ent, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, p.Reference, p.Path, catalog.GetEntryParams{})
+func extractEntryFromCopyReq(w http.ResponseWriter, req *http.Request, o *PathOperation, copySource path.ResolvedAbsolutePath) *catalog.DBEntry {
+	ent, err := o.Catalog.GetEntry(req.Context(), copySource.Repo, copySource.Reference, copySource.Path, catalog.GetEntryParams{})
 	if err != nil {
 		o.Log(req).WithError(err).Error("could not read copy source")
 		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
@@ -151,13 +145,31 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
 	// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPartCopy.html#API_UploadPartCopy_RequestSyntax
 	if copySource := req.Header.Get(CopySourceHeader); copySource != "" {
 		// see if there's a range passed as well
-		ent := extractEntryFromCopyReq(w, req, o, copySource)
+		resolvedCopySource, err := getPathFromSource(copySource)
+		if err != nil {
+			o.Log(req).WithField("copy_source", copySource).WithError(err).Error("could not parse copy source path")
+			_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
+			return
+		}
+		ent := extractEntryFromCopyReq(w, req, o, resolvedCopySource)
 		if ent == nil {
 			return // operation already failed
 		}
+		srcRepo := o.Repository
+		if resolvedCopySource.Repo != o.Repository.Name {
+			srcRepo, err = o.Catalog.GetRepository(req.Context(), resolvedCopySource.Repo)
+			if err != nil {
+				o.Log(req).
+					WithField("copy_source", copySource).
+					WithError(err).
+					Error("Failed to get repository")
+				_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
+				return
+			}
+		}
 		src := block.ObjectPointer{
-			StorageNamespace: o.Repository.StorageNamespace,
+			StorageNamespace: srcRepo.StorageNamespace,
 			IdentifierType:   ent.AddressType.ToIdentifierType(),
 			Identifier:       ent.PhysicalAddress,
 		}
@@ -184,7 +196,13 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
 		}
 
 		if err != nil {
-			o.Log(req).WithError(err).WithField("copy_source", ent.Path).Error("copy part " + partNumberStr + " upload failed")
+			o.Log(req).
+				WithFields(logging.Fields{
+					"copy_source": ent.Path,
+					"part":        partNumberStr,
+				}).
+				WithError(err).
+				Error("copy part: upload failed")
 			_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
 			return
 		}
@@ -204,7 +222,8 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
 	}, byteSize, req.Body, uploadID, partNumber)
 	if err != nil {
-		o.Log(req).WithError(err).Error("part " + partNumberStr + " upload failed")
+		o.Log(req).WithField("part", partNumberStr).
+			WithError(err).Error("part upload failed")
 		_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
 		return
 	}
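The gateway change hinges on resolving the x-amz-copy-source header, which carries a lakeFS path of the form /&lt;repository&gt;/&lt;reference&gt;/&lt;object path&gt;, before deciding whether the source repository differs from the destination. The real parsing is done by getPathFromSource in the gateway's path package; the stand-in below is purely hypothetical and only illustrates the expected header shape.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// resolved mirrors the fields the gateway needs from a copy source:
// repository, reference (branch or commit), and the object path.
type resolved struct {
	Repo, Reference, Path string
}

// parseCopySource is a hypothetical stand-in for getPathFromSource.
// It accepts "/repo/ref/path" or "repo/ref/path"; the header value is
// percent-encoded on the wire, so it is decoded first.
func parseCopySource(copySource string) (resolved, error) {
	decoded, err := url.PathUnescape(copySource)
	if err != nil {
		return resolved{}, err
	}
	parts := strings.SplitN(strings.TrimPrefix(decoded, "/"), "/", 3)
	if len(parts) != 3 {
		return resolved{}, fmt.Errorf("invalid copy source %q", copySource)
	}
	return resolved{Repo: parts[0], Reference: parts[1], Path: parts[2]}, nil
}

func main() {
	p, err := parseCopySource("/source-repo/main/datasets/a.parquet")
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", p) // {Repo:source-repo Reference:main Path:datasets/a.parquet}
}
```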
diff --git a/pkg/gateway/sig/v2.go b/pkg/gateway/sig/v2.go
index b58924e8d6e..6949e0cbdd7 100644
--- a/pkg/gateway/sig/v2.go
+++ b/pkg/gateway/sig/v2.go
@@ -32,7 +32,7 @@ var (
 //nolint:gochecknoinits
 func init() {
 	interestingResourcesContainer := []string{
-		"accelerate", "acl", "cors", "defaultObjectAcl",
+		"accelerate", "acl", "copy-source", "cors", "defaultObjectAcl",
 		"location", "logging", "partNumber", "policy",
 		"requestPayment", "torrent", "versioning", "versionId",
 		"versions", "website",
diff --git a/pkg/testutil/checksum.go b/pkg/testutil/checksum.go
new file mode 100644
index 00000000000..c6adbba6816
--- /dev/null
+++ b/pkg/testutil/checksum.go
@@ -0,0 +1,31 @@
+package testutil
+
+import (
+	"hash/crc64"
+	"io"
+)
+
+const bufSize = 4 << 16
+
+var table *crc64.Table = crc64.MakeTable(crc64.ECMA)
+
+// ChecksumReader returns the checksum (CRC-64) of the contents of reader.
+func ChecksumReader(reader io.Reader) (uint64, error) {
+	buf := make([]byte, bufSize)
+	var val uint64
+	for {
+		n, err := reader.Read(buf)
+		if n > 0 {
+			val = crc64.Update(val, table, buf[:n])
+		}
+		if err != nil {
+			if err == io.EOF {
+				return val, nil
+			}
+			return val, err
+		}
+		if n == 0 {
+			return val, nil
+		}
+	}
+}
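Note the order of operations in ChecksumReader: the io.Reader contract allows a Read to return data alongside io.EOF, so the buffer is consumed before the error is inspected. A quick usage sketch follows; hash/crc64 defines Checksum as an update starting from zero, so streaming in bufSize chunks must agree with a one-shot checksum of the same bytes. The import path is assumed from the repository layout.

```go
package main

import (
	"fmt"
	"hash/crc64"
	"strings"

	"github.com/treeverse/lakefs/pkg/testutil"
)

func main() {
	const data = "hello multipart world"

	// Checksum the data as a stream.
	sum, err := testutil.ChecksumReader(strings.NewReader(data))
	if err != nil {
		panic(err)
	}

	// The streaming result must equal the one-shot checksum.
	want := crc64.Checksum([]byte(data), crc64.MakeTable(crc64.ECMA))
	fmt.Println(sum == want) // true
}
```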
diff --git a/pkg/testutil/random.go b/pkg/testutil/random.go
index 064b9aa14c9..70373d5271b 100644
--- a/pkg/testutil/random.go
+++ b/pkg/testutil/random.go
@@ -1,6 +1,8 @@
 package testutil
 
 import (
+	"io"
+	"math"
 	"math/rand"
 	"strings"
 	"unicode/utf8"
@@ -33,3 +35,27 @@ func RandomString(rand *rand.Rand, size int) string {
 	_, lastRuneSize := utf8.DecodeLastRuneInString(ret)
 	return ret[0 : len(ret)-lastRuneSize]
 }
+
+type randomReader struct {
+	rand      *rand.Rand
+	remaining int64
+}
+
+func (r *randomReader) Read(p []byte) (int, error) {
+	if r.remaining <= 0 {
+		return 0, io.EOF
+	}
+	n := len(p)
+	if math.MaxInt >= r.remaining && n > int(r.remaining) {
+		n = int(r.remaining)
+	}
+	// n still fits into int!
+	r.rand.Read(p[:n])
+	r.remaining -= int64(n)
+	return n, nil
+}
+
+// NewRandomReader returns a reader that will return size bytes from rand.
+func NewRandomReader(rand *rand.Rand, size int64) io.Reader {
+	return &randomReader{rand: rand, remaining: size}
+}
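Finally, a sketch of how the two new test utilities compose. NewRandomReader is deterministic for a fixed seed — two readers built from the same seed produce the same byte stream — which is what lets tests generate large content on the fly rather than holding it in memory. The wiring below is illustrative.

```go
package main

import (
	"fmt"
	"math/rand"

	"github.com/treeverse/lakefs/pkg/testutil"
)

func main() {
	const size = 6 << 20 // matches largeDataContentLength in the tests

	// Two readers seeded identically yield identical streams, so their
	// checksums must match.
	a, err := testutil.ChecksumReader(testutil.NewRandomReader(rand.New(rand.NewSource(17)), size))
	if err != nil {
		panic(err)
	}
	b, err := testutil.ChecksumReader(testutil.NewRandomReader(rand.New(rand.NewSource(17)), size))
	if err != nil {
		panic(err)
	}
	fmt.Println(a == b) // true
}
```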