Skip to content

Commit

Permalink
Fix S3 gateway cross-repo copies
Browse files Browse the repository at this point in the history
putobject and friends were confusing src and dst repositories, which will
not do.

Fixes #7467.
  • Loading branch information
arielshaqed committed Feb 15, 2024
1 parent 19defdd commit 754398d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 16 deletions.
8 changes: 4 additions & 4 deletions pkg/block/local/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ func (l *Adapter) UploadCopyPart(ctx context.Context, sourceObj, destinationObj
}
r, err := l.Get(ctx, sourceObj, 0)
if err != nil {
return nil, err
return nil, fmt.Errorf("copy get: %w", err)
}
md5Read := block.NewHashingReader(r, block.HashFunctionMD5)
fName := uploadID + fmt.Sprintf("-%05d", partNumber)
err = l.Put(ctx, block.ObjectPointer{StorageNamespace: destinationObj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
if err != nil {
return nil, err
return nil, fmt.Errorf("copy put: %w", err)
}
etag := hex.EncodeToString(md5Read.Md5.Sum(nil))
return &block.UploadPartResponse{
Expand All @@ -259,13 +259,13 @@ func (l *Adapter) UploadCopyPartRange(ctx context.Context, sourceObj, destinatio
}
r, err := l.GetRange(ctx, sourceObj, startPosition, endPosition)
if err != nil {
return nil, err
return nil, fmt.Errorf("copy range get: %w", err)
}
md5Read := block.NewHashingReader(r, block.HashFunctionMD5)
fName := uploadID + fmt.Sprintf("-%05d", partNumber)
err = l.Put(ctx, block.ObjectPointer{StorageNamespace: destinationObj.StorageNamespace, Identifier: fName}, -1, md5Read, block.PutOpts{})
if err != nil {
return nil, err
return nil, fmt.Errorf("copy range put: %w", err)
}
etag := hex.EncodeToString(md5Read.Md5.Sum(nil))
return &block.UploadPartResponse{
Expand Down
43 changes: 31 additions & 12 deletions pkg/gateway/operations/putobject.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,8 @@ func (controller *PutObject) RequiredPermissions(req *http.Request, repoID, _, d
}

// extractEntryFromCopyReq: get metadata from source file
func extractEntryFromCopyReq(w http.ResponseWriter, req *http.Request, o *PathOperation, copySource string) *catalog.DBEntry {
p, err := getPathFromSource(copySource)
if err != nil {
o.Log(req).WithError(err).Error("could not parse copy source path")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
return nil
}
ent, err := o.Catalog.GetEntry(req.Context(), o.Repository.Name, p.Reference, p.Path, catalog.GetEntryParams{})
func extractEntryFromCopyReq(w http.ResponseWriter, req *http.Request, o *PathOperation, copySource path.ResolvedAbsolutePath) *catalog.DBEntry {
ent, err := o.Catalog.GetEntry(req.Context(), copySource.Repo, copySource.Reference, copySource.Path, catalog.GetEntryParams{})
if err != nil {
o.Log(req).WithError(err).Error("could not read copy source")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
Expand Down Expand Up @@ -151,13 +145,31 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPartCopy.html#API_UploadPartCopy_RequestSyntax
if copySource := req.Header.Get(CopySourceHeader); copySource != "" {
// see if there's a range passed as well
ent := extractEntryFromCopyReq(w, req, o, copySource)
resolvedCopySource, err := getPathFromSource(copySource)
if err != nil {
o.Log(req).WithField("copy_source", copySource).WithError(err).Error("could not parse copy source path")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
return
}
ent := extractEntryFromCopyReq(w, req, o, resolvedCopySource)
if ent == nil {
return // operation already failed
}
srcRepo := o.Repository
if resolvedCopySource.Repo != o.Repository.Name {
srcRepo, err = o.Catalog.GetRepository(req.Context(), resolvedCopySource.Repo)
if err != nil {
o.Log(req).
WithField("copy_source", copySource).
WithError(err).
Error("Failed to get repository")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInvalidCopySource))
return
}
}

src := block.ObjectPointer{
StorageNamespace: o.Repository.StorageNamespace,
StorageNamespace: srcRepo.StorageNamespace,
IdentifierType: ent.AddressType.ToIdentifierType(),
Identifier: ent.PhysicalAddress,
}
Expand All @@ -184,7 +196,13 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
}

if err != nil {
o.Log(req).WithError(err).WithField("copy_source", ent.Path).Error("copy part " + partNumberStr + " upload failed")
o.Log(req).
WithFields(logging.Fields{
"copy_source": ent.Path,
"part": partNumberStr,
}).
WithError(err).
Error("copy part: upload failed")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
return
}
Expand All @@ -204,7 +222,8 @@ func handleUploadPart(w http.ResponseWriter, req *http.Request, o *PathOperation
},
byteSize, req.Body, uploadID, partNumber)
if err != nil {
o.Log(req).WithError(err).Error("part " + partNumberStr + " upload failed")
o.Log(req).WithField("part", partNumberStr).
WithError(err).Error("part upload failed")
_ = o.EncodeError(w, req, err, gatewayErrors.Codes.ToAPIErr(gatewayErrors.ErrInternalError))
return
}
Expand Down

0 comments on commit 754398d

Please sign in to comment.