Fix S3 gateway cross-repo copies #7468

Merged (12 commits) on Feb 19, 2024
82 changes: 82 additions & 0 deletions esti/s3_gateway_test.go
@@ -328,6 +328,88 @@ func TestS3HeadBucket(t *testing.T) {
})
}

func TestS3CopyObjectMultipart(t *testing.T) {
Member: 🎉

Contributor Author:
The only thing tests are good for is finding bugs. And fixing them. And preventing regressions. Automatically. Not sure it's worth the effort.

Although... in this case this test actually found bugs. And helped fix one of them. And another test in the file prevented a regression. Automatically. Actually the case against writing tests has never been weaker 🥳.

ctx, _, repo := setupTest(t)
defer tearDownTest(repo)

// additional repository for copy between repos
const destRepoName = "tests3copyobjectmultipartdest"
destRepo := createRepositoryByName(ctx, t, destRepoName)
defer deleteRepositoryIfAskedTo(ctx, destRepoName)

// content
r := rand.New(rand.NewSource(17))
objContent := testutil.NewRandomReader(r, largeDataContentLength)
srcPath := gatewayTestPrefix + "source-file"
destPath := gatewayTestPrefix + "dest-file"

// upload data
s3lakefsClient := newMinioClient(t, credentials.NewStaticV4)
_, err := s3lakefsClient.PutObject(ctx, repo, srcPath, objContent, largeDataContentLength,
minio.PutObjectOptions{})
require.NoError(t, err)

dest := minio.CopyDestOptions{
Bucket: destRepo,
Object: destPath,
}

srcs := []minio.CopySrcOptions{
{
Bucket: repo,
Object: srcPath,
MatchRange: true,
Start: 0,
End: minDataContentLengthForMultipart - 1,
}, {
Bucket: repo,
Object: srcPath,
MatchRange: true,
Start: minDataContentLengthForMultipart,
End: largeDataContentLength - 1,
},
}

ui, err := s3lakefsClient.ComposeObject(ctx, dest, srcs...)
if err != nil {
t.Fatalf("minio.Client.ComposeObject from(%+v) to(%+v): %s", srcs, dest, err)
}

if ui.Size != largeDataContentLength {
t.Errorf("Copied %d bytes when expecting %d", ui.Size, largeDataContentLength)
}

// Comparing 2 readers is too much work. Instead just hash them.
// This will fail for malicious bad S3 gateways, but otherwise is
// fine.
uploadedReader, err := s3lakefsClient.GetObject(ctx, repo, srcPath, minio.GetObjectOptions{})
if err != nil {
t.Fatalf("Get uploaded object: %s", err)
}
defer uploadedReader.Close()
uploadedCRC, err := testutil.ChecksumReader(uploadedReader)
if err != nil {
t.Fatalf("Read uploaded object: %s", err)
}
if uploadedCRC == 0 {
t.Fatal("Impossibly bad luck: uploaded data with CRC64 == 0!")
}

copiedReader, err := s3lakefsClient.GetObject(ctx, destRepo, destPath, minio.GetObjectOptions{})
if err != nil {
t.Fatalf("Get copied object: %s", err)
}
defer copiedReader.Close()
copiedCRC, err := testutil.ChecksumReader(copiedReader)
if err != nil {
t.Fatalf("Read copied object: %s", err)
}

if uploadedCRC != copiedCRC {
t.Fatal("Copy not equal")
}
}

func TestS3CopyObject(t *testing.T) {
ctx, _, repo := setupTest(t)
defer tearDownTest(repo)
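
For reference, minio's ComposeObject in TestS3CopyObjectMultipart above exercises exactly the path this PR fixes: each CopySrcOptions range becomes one S3 UploadPartCopy request against the gateway. A minimal sketch of such a request follows; the endpoint, repository names, and upload ID are placeholders, not values from the test:

package main

import "net/http"

// buildUploadPartCopy sketches the request minio issues per source range.
func buildUploadPartCopy() (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPut,
		"https://lakefs.example/dest-repo/main/dest-file?partNumber=1&uploadId=UPLOAD-ID", nil)
	if err != nil {
		return nil, err
	}
	// The copy source and its byte range travel in headers, not the body.
	req.Header.Set("x-amz-copy-source", "/source-repo/main/source-file")
	req.Header.Set("x-amz-copy-source-range", "bytes=0-5242879") // first 5 MiB
	return req, nil
}
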
15 changes: 14 additions & 1 deletion esti/system_test.go
@@ -144,7 +144,20 @@ func deleteRepositoryIfAskedTo(ctx context.Context, repositoryName string) {
}
}

const randomDataContentLength = 16
const (
// randomDataContentLength is the content length used for small
// objects. It is intentionally not a round number.
randomDataContentLength = 16

// minDataContentLengthForMultipart is the content length for all
// parts of a multipart upload except the last. Its value -- 5MiB
// -- is defined in the S3 protocol, and cannot be changed.
minDataContentLengthForMultipart = 5 << 20

// largeDataContentLength is >minDataContentLengthForMultipart,
// which is large enough to require multipart operations.
largeDataContentLength = 6 << 20
)

func uploadFileRandomDataAndReport(ctx context.Context, repo, branch, objPath string, direct bool) (checksum, content string, err error) {
objContent := randstr.String(randomDataContentLength)
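
As a quick illustration of how these constants interact (not part of the diff): a 6 MiB object splits into exactly two S3-legal parts, and only the last part may be smaller than the 5 MiB minimum.

package main

import "fmt"

func main() {
	const minPart = 5 << 20 // minimum size for every part except the last
	const total = 6 << 20   // largeDataContentLength
	for start := int64(0); start < total; start += minPart {
		end := start + minPart - 1
		if end >= total {
			end = total - 1 // the final, short part
		}
		fmt.Printf("part: bytes=%d-%d\n", start, end)
	}
	// Output:
	// part: bytes=0-5242879
	// part: bytes=5242880-6291455
}
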
13 changes: 1 addition & 12 deletions pkg/block/azure/adapter.go
@@ -533,18 +533,7 @@ func (a *Adapter) copyPartRange(ctx context.Context, sourceObj, destinationObj b
return nil, err
}

destinationContainer, err := a.clientCache.NewContainerClient(qualifiedDestinationKey.StorageAccountName, qualifiedDestinationKey.ContainerName)
if err != nil {
return nil, err
}
sourceContainer, err := a.clientCache.NewContainerClient(qualifiedSourceKey.StorageAccountName, qualifiedSourceKey.ContainerName)
if err != nil {
return nil, err
}

sourceBlobURL := sourceContainer.NewBlockBlobClient(qualifiedSourceKey.BlobURL)

return copyPartRange(ctx, *destinationContainer, qualifiedDestinationKey.BlobURL, *sourceBlobURL, startPosition, count)
return copyPartRange(ctx, a.clientCache, qualifiedDestinationKey, qualifiedSourceKey, startPosition, count)
}

func (a *Adapter) AbortMultiPartUpload(_ context.Context, _ block.ObjectPointer, _ string) error {
34 changes: 20 additions & 14 deletions pkg/block/azure/multipart_block_writer.go
@@ -174,9 +174,20 @@ func getMultipartSize(ctx context.Context, container container.Client, objName s
return int64(size), nil
}

func copyPartRange(ctx context.Context, destinationContainer container.Client, destinationObjName string, sourceBlobURL blockblob.Client, startPosition, count int64) (*block.UploadPartResponse, error) {
func copyPartRange(ctx context.Context, clientCache *ClientCache, destinationKey, sourceKey BlobURLInfo, startPosition, count int64) (*block.UploadPartResponse, error) {
destinationContainer, err := clientCache.NewContainerClient(destinationKey.StorageAccountName, destinationKey.ContainerName)
if err != nil {
return nil, fmt.Errorf("copy part: get destination client: %w", err)
}
sourceContainer, err := clientCache.NewContainerClient(sourceKey.StorageAccountName, sourceKey.ContainerName)
if err != nil {
return nil, fmt.Errorf("copy part: get source client: %w", err)
}
base64BlockID := generateRandomBlockID()
_, err := sourceBlobURL.StageBlockFromURL(ctx, base64BlockID, sourceBlobURL.URL(),
destinationBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL)
sourceBlob := sourceContainer.NewBlockBlobClient(sourceKey.BlobURL)

stageBlockResponse, err := destinationBlob.StageBlockFromURL(ctx, base64BlockID, sourceBlob.URL(),
&blockblob.StageBlockFromURLOptions{
Range: blob.HTTPRange{
Offset: startPosition,
@@ -187,25 +198,20 @@ func copyPartRange(ctx context.Context, destinationContainer container.Client, d
return nil, err
}

// add size and id to etag
response, err := sourceBlobURL.GetProperties(ctx, nil)
if err != nil {
return nil, err
}
etag := "\"" + hex.EncodeToString(response.ContentMD5) + "\""
size := response.ContentLength
// add size, etag
etag := "\"" + hex.EncodeToString(stageBlockResponse.ContentMD5) + "\""
base64Etag := base64.StdEncoding.EncodeToString([]byte(etag))
// stage id data
blobIDsURL := destinationContainer.NewBlockBlobClient(destinationObjName + idSuffix)
_, err = blobIDsURL.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(base64BlockID+"\n")), nil)
blobIDsBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + idSuffix)
_, err = blobIDsBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(base64BlockID+"\n")), nil)
if err != nil {
return nil, fmt.Errorf("failed staging part data: %w", err)
}

// stage size data
sizeData := fmt.Sprintf("%d\n", size)
blobSizesURL := destinationContainer.NewBlockBlobClient(destinationObjName + sizeSuffix)
_, err = blobSizesURL.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sizeData)), nil)
sizeData := fmt.Sprintf("%d\n", count)
blobSizesBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + sizeSuffix)

Member: Consider moving sizeSuffix and idSuffix to this file

Contributor Author: Not done in this PR.

_, err = blobSizesBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sizeData)), nil)
if err != nil {
return nil, fmt.Errorf("failed staging part data: %w", err)
}
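
The essential shape of the fix above, isolated (names as in the diff; a sketch, not a drop-in snippet): the destination block blob stages a block straight from the source blob's URL, so the ranged copy happens server-side, and the recorded size is the requested count rather than the whole source's content length.

stageResp, err := destinationBlob.StageBlockFromURL(ctx, base64BlockID, sourceBlob.URL(),
	&blockblob.StageBlockFromURLOptions{
		Range: blob.HTTPRange{Offset: startPosition, Count: count},
	})
// stageResp.ContentMD5 then feeds the part's etag, exactly as above.
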
8 changes: 4 additions & 4 deletions pkg/block/local/adapter.go
@@ -239,13 +239,13 @@ func (l *Adapter) UploadCopyPart(ctx context.Context, sourceObj, destinationObj
}
r, err := l.Get(ctx, sourceObj, 0)
if err != nil {
return nil, err
return nil, fmt.Errorf("copy get: %w", err)
}
md5Read := block.NewHashingReader(r, block.HashFunctionMD5)
fName := uploadID + fmt.Sprintf("-%05d", partNumber)
err = l.Put(ctx, block.ObjectPointer{StorageNamespace: destinationObj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
if err != nil {
return nil, err
return nil, fmt.Errorf("copy put: %w", err)
}
etag := hex.EncodeToString(md5Read.Md5.Sum(nil))
return &block.UploadPartResponse{
@@ -259,13 +259,13 @@ func (l *Adapter) UploadCopyPartRange(ctx context.Context, sourceObj, destinatio
}
r, err := l.GetRange(ctx, sourceObj, startPosition, endPosition)
if err != nil {
return nil, err
return nil, fmt.Errorf("copy range get: %w", err)
}
md5Read := block.NewHashingReader(r, block.HashFunctionMD5)
fName := uploadID + fmt.Sprintf("-%05d", partNumber)
err = l.Put(ctx, block.ObjectPointer{StorageNamespace: destinationObj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
if err != nil {
return nil, err
return nil, fmt.Errorf("copy range put: %w", err)
}
etag := hex.EncodeToString(md5Read.Md5.Sum(nil))
return &block.UploadPartResponse{
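
The local-adapter hunks are pure error wrapping. The reason to prefer %w over %v here is that callers can still match the underlying cause; a self-contained illustration:

package main

import (
	"errors"
	"fmt"
	"io"
)

func main() {
	err := fmt.Errorf("copy range get: %w", io.ErrUnexpectedEOF)
	// The wrapped cause remains visible to errors.Is and errors.As.
	fmt.Println(errors.Is(err, io.ErrUnexpectedEOF)) // true
}
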
43 changes: 31 additions & 12 deletions pkg/gateway/operations/putobject.go
@@ -66,14 +66,8 @@ func (controller *PutObject) RequiredPermissions(req *http.Request, repoID, _, d
}

// extractEntryFromCopyReq: get metadata from source file
func extractEntryFromCopyReq(w http.ResponseWriter, req *http.Request, o *PathOperation, copySource string) *catalog.DBEntry {
p, err := getPathFromSource(copySource)
if err != nil {
o.Log(req).WithError(err).Error("could not parse copy source path")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
return nil
}
ent, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, p.Reference, p.Path, catalog.GetEntryParams{})
func extractEntryFromCopyReq(w http.ResponseWriter, req *http.Request, o *PathOperation, copySource path.ResolvedAbsolutePath) *catalog.DBEntry {
ent, err := o.Catalog.GetEntry(req.Context(), copySource.Repo, copySource.Reference, copySource.Path, catalog.GetEntryParams{})
if err != nil {
o.Log(req).WithError(err).Error("could not read copy source")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
@@ -151,13 +145,31 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPartCopy.html#API_UploadPartCopy_RequestSyntax
if copySource := req.Header.Get(CopySourceHeader); copySource != "" {
// see if there's a range passed as well
ent := extractEntryFromCopyReq(w, req, o, copySource)
resolvedCopySource, err := getPathFromSource(copySource)
if err != nil {
o.Log(req).WithField("copy_source", copySource).WithError(err).Error("could not parse copy source path")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
return
}
ent := extractEntryFromCopyReq(w, req, o, resolvedCopySource)
if ent == nil {
return // operation already failed
}
srcRepo := o.Repository
if resolvedCopySource.Repo != o.Repository.Name {
srcRepo, err = o.Catalog.GetRepository(req.Context(), resolvedCopySource.Repo)
if err != nil {
o.Log(req).
WithField("copy_source", copySource).
WithError(err).
Error("Failed to get repository")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
return
}
}

src := block.ObjectPointer{
StorageNamespace: o.Repository.StorageNamespace,
StorageNamespace: srcRepo.StorageNamespace,
IdentifierType: ent.AddressType.ToIdentifierType(),
Identifier: ent.PhysicalAddress,
}
@@ -184,7 +196,13 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
}

if err != nil {
o.Log(req).WithError(err).WithField("copy_source", ent.Path).Error("copy part " + partNumberStr + " upload failed")
o.Log(req).
WithFields(logging.Fields{
"copy_source": ent.Path,
"part": partNumberStr,
}).
WithError(err).
Error("copy part: upload failed")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
return
}
@@ -204,7 +222,8 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
},
byteSize, req.Body, uploadID, partNumber)
if err != nil {
o.Log(req).WithError(err).Error("part " + partNumberStr + " upload failed")
o.Log(req).WithField("part", partNumberStr).
WithError(err).Error("part upload failed")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
return
}
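
One detail the hunk only gestures at ("see if there's a range passed as well"): the x-amz-copy-source-range header carries an inclusive byte range. A parsing sketch; this helper is hypothetical, not the gateway's actual parser:

package main

import "fmt"

// parseCopySourceRange parses "bytes=<start>-<end>", both ends inclusive.
func parseCopySourceRange(header string) (start, end int64, err error) {
	if _, err = fmt.Sscanf(header, "bytes=%d-%d", &start, &end); err != nil {
		return 0, 0, fmt.Errorf("invalid copy source range %q: %w", header, err)
	}
	if start < 0 || end < start {
		return 0, 0, fmt.Errorf("invalid copy source range %q", header)
	}
	return start, end, nil
}

func main() {
	start, end, _ := parseCopySourceRange("bytes=5242880-6291455")
	fmt.Println(start, end, end-start+1) // 5242880 6291455 1048576
}
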
2 changes: 1 addition & 1 deletion pkg/gateway/sig/v2.go
@@ -32,7 +32,7 @@ var (
//nolint:gochecknoinits
func init() {
interestingResourcesContainer := []string{
"accelerate", "acl", "cors", "defaultObjectAcl",
"accelerate", "acl", "copy-source", "cors", "defaultObjectAcl",
"location", "logging", "partNumber", "policy",
"requestPayment", "torrent",
"versioning", "versionId", "versions", "website",
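
For context on why one word fixes anything: in AWS Signature V2, every sub-resource in this list is folded, sorted, into the CanonicalizedResource that both client and server sign, so a query parameter the client signs but the gateway ignores produces a signature mismatch; presumably some clients sign copy-source. A simplified sketch of SigV2 signing per the public spec (not lakeFS code):

package sketch

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
)

// signV2 computes the Signature V2 value for an already-canonicalized request.
func signV2(secret, method, contentMD5, contentType, date, amzHeaders, resource string) string {
	stringToSign := method + "\n" + contentMD5 + "\n" + contentType + "\n" + date + "\n" +
		amzHeaders + resource // resource e.g. "/bucket/key?copy-source=...&partNumber=1"
	mac := hmac.New(sha1.New, []byte(secret))
	mac.Write([]byte(stringToSign))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}
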
29 changes: 29 additions & 0 deletions pkg/testutil/checksum.go
@@ -0,0 +1,29 @@
package testutil

import (
"hash/crc64"
"io"
)

const bufSize = 4 << 16

var table *crc64.Table = crc64.MakeTable(crc64.ECMA)

// ChecksumReader returns the checksum (CRC-64) of the contents of reader.
func ChecksumReader(reader io.Reader) (uint64, error) {
buf := make([]byte, bufSize)
var val uint64
for {
n, err := reader.Read(buf)
// Per the io.Reader contract, Read may return n > 0 together with
// io.EOF, so fold the bytes into the checksum before checking the error.
val = crc64.Update(val, table, buf[:n])
if err == io.EOF {
return val, nil
}
if err != nil {
return val, err
}
if n == 0 {
return val, nil
}
}
}
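
A usage sketch for the new helper (the import path is assumed from the repository layout):

package main

import (
	"fmt"
	"strings"

	"github.com/treeverse/lakefs/pkg/testutil"
)

func main() {
	a, _ := testutil.ChecksumReader(strings.NewReader("hello"))
	b, _ := testutil.ChecksumReader(strings.NewReader("hello"))
	fmt.Println(a == b) // true: equal contents yield equal CRC-64 values
}
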
26 changes: 26 additions & 0 deletions pkg/testutil/random.go
@@ -1,6 +1,8 @@
package testutil

import (
"io"
"math"
"math/rand"
"strings"
"unicode/utf8"
@@ -33,3 +35,27 @@ func RandomString(rand *rand.Rand, size int) string {
_, lastRuneSize := utf8.DecodeLastRuneInString(ret)
return ret[0 : len(ret)-lastRuneSize]
}

type randomReader struct {
rand *rand.Rand
remaining int64
}

func (r *randomReader) Read(p []byte) (int, error) {
if r.remaining <= 0 {
return 0, io.EOF
}
n := len(p)
if math.MaxInt >= r.remaining && n > int(r.remaining) {
n = int(r.remaining)
}
// n still fits into int!
r.rand.Read(p[:n])
r.remaining -= int64(n)
return n, nil
}

// NewRandomReader returns a reader that will return size bytes from rand.
func NewRandomReader(rand *rand.Rand, size int64) io.Reader {
return &randomReader{rand: rand, remaining: size}
}
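
And a sketch of how the two new helpers combine, as the gateway test uses them; the seed is illustrative and the import path is assumed:

package main

import (
	"fmt"
	"math/rand"

	"github.com/treeverse/lakefs/pkg/testutil"
)

func main() {
	// Same seed, same stream: the reader is deterministic, so test
	// failures are reproducible without storing 6 MiB fixtures.
	r1 := testutil.NewRandomReader(rand.New(rand.NewSource(17)), 6<<20)
	r2 := testutil.NewRandomReader(rand.New(rand.NewSource(17)), 6<<20)
	c1, _ := testutil.ChecksumReader(r1)
	c2, _ := testutil.ChecksumReader(r2)
	fmt.Println(c1 == c2) // true
}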