Skip to content

Commit

Permalink
Limit redundant work in HTTP fetcher
Browse files Browse the repository at this point in the history
Where possible, redundant in-memory buffering and hashing of downloaded
data is avoided in the HTTP fetcher

In order for these things to be avoided two conditions must be met:

1. The REAPI client must provide the expected checksum of the data being
   requested
2. The HTTP response must include the length of the content being
   downloaded

If the checksum is missing, the file being downloaded must be fetched
and buffered in memory so that the checksum can be determined (despite
the fact that it will end up being hashed again later).

If only the length is missing, the file being downloaded will still be
fetched an buffered in memory in its entirety, but the calculation of
the checksum can be skipped.

Otherwise, the body of the HTTP response is used to create an internal
buffer directly, leading to it being fetched in a more sensible way
as it is being written to the underlying storage.
  • Loading branch information
sdclarke authored and tomcoldrick-ct committed Jun 4, 2024
1 parent 00658e8 commit d7a9200
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 43 deletions.
72 changes: 49 additions & 23 deletions pkg/fetch/http_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"io/ioutil"
"io"
"log"
"net/http"
"strings"

"github.com/buildbarn/bb-remote-asset/pkg/qualifier"
"github.com/buildbarn/bb-storage/pkg/blobstore"
"github.com/buildbarn/bb-storage/pkg/blobstore/buffer"
bb_digest "github.com/buildbarn/bb-storage/pkg/digest"
"github.com/buildbarn/bb-storage/pkg/util"

"github.com/buildbarn/bb-remote-asset/pkg/qualifier"

remoteasset "github.com/bazelbuild/remote-apis/build/bazel/remote/asset/v1"
remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -60,12 +62,14 @@ func (hf *httpFetcher) FetchBlob(ctx context.Context, req *remoteasset.FetchBlob

for _, uri := range req.Uris {

buffer, digest := hf.DownloadBlob(ctx, uri, instanceName, expectedDigest, auth)
buffer, digest := hf.downloadBlob(ctx, uri, instanceName, expectedDigest, auth)
if _, err = buffer.GetSizeBytes(); err != nil {
log.Printf("Error downloading blob with URI %s: %v", uri, err)
continue
}

if err := hf.contentAddressableStorage.Put(ctx, digest, buffer); err != nil {
if err = hf.contentAddressableStorage.Put(ctx, digest, buffer); err != nil {
log.Printf("Error downloading blob with URI %s: %v", uri, err)
return nil, util.StatusWrapWithCode(err, codes.Internal, "Failed to place blob into CAS")
}
return &remoteasset.FetchBlobResponse{
Expand All @@ -87,7 +91,7 @@ func (hf *httpFetcher) CheckQualifiers(qualifiers qualifier.Set) qualifier.Set {
return qualifier.Difference(qualifiers, qualifier.NewSet([]string{"checksum.sri", "bazel.auth_headers", "bazel.canonical_id"}))
}

func (hf *httpFetcher) DownloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, instanceName bb_digest.InstanceName, expectedDigest string, auth *AuthHeaders) (buffer.Buffer, bb_digest.Digest) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to create HTTP request")), bb_digest.BadDigest
Expand All @@ -99,39 +103,61 @@ func (hf *httpFetcher) DownloadBlob(ctx context.Context, uri string, instanceNam

resp, err := hf.httpClient.Do(req)
if err != nil {
log.Printf("Error downloading blob with URI %s: %v", uri, err)
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "HTTP request failed")), bb_digest.BadDigest
}
if resp.StatusCode != http.StatusOK {
log.Printf("Error downloading blob with URI %s: %v", uri, resp.StatusCode)
return buffer.NewBufferFromError(status.Errorf(codes.Internal, "HTTP request failed with status %#v", resp.Status)), bb_digest.BadDigest
}

// Read all of the content (Not ideal) | // TODO: find a way to avoid internal buffering here
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to read response body")), bb_digest.BadDigest
}
nBytes := len(body)

hasher := sha256.New()
hasher.Write(body)
hash := hasher.Sum(nil)
hexHash := hex.EncodeToString(hash)
// Work out the digest of the downloaded data
//
// If the HTTP response includes the content length (indicated by the value
// of the field being >= 0) and the client has provided an expected hash of
// the content, we can avoid holding the contents of the entire file in
// memory at one time by creating a new buffer from the response body
// directly
//
// If either one (or both) of these things is not available, we will need to
// read the enitre response body into a byte slice in order to be able to
// determine the digest
length := resp.ContentLength
body := resp.Body
if length < 0 || expectedDigest == "" {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to read response body")), bb_digest.BadDigest
}
err = resp.Body.Close()
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Failed to close response body")), bb_digest.BadDigest
}
length = int64(len(bodyBytes))

// If we don't know what the hash should be we will need to work out the
// actual hash of the content
if expectedDigest == "" {
hasher := sha256.New()
hasher.Write(bodyBytes)
hash := hasher.Sum(nil)
expectedDigest = hex.EncodeToString(hash)
}

if expectedDigest != "" && hexHash != expectedDigest {
return buffer.NewBufferFromError(
status.Errorf(codes.PermissionDenied, "Checksum invalid for fetched blob. Expected: %s, Found: %s", expectedDigest, hexHash)), bb_digest.BadDigest
body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

digestFunction, err := instanceName.GetDigestFunction(remoteexecution.DigestFunction_UNKNOWN, len(hexHash))
digestFunction, err := instanceName.GetDigestFunction(remoteexecution.DigestFunction_UNKNOWN, len(expectedDigest))
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapfWithCode(err, codes.Internal, "Failed to get digest function for instance: %v", instanceName)), bb_digest.BadDigest
}
digest, err := digestFunction.NewDigest(hexHash, int64(nBytes))
digest, err := digestFunction.NewDigest(expectedDigest, length)
if err != nil {
return buffer.NewBufferFromError(util.StatusWrapWithCode(err, codes.Internal, "Digest Creation failed")), bb_digest.BadDigest
}

return buffer.NewCASBufferFromReader(digest, ioutil.NopCloser(bytes.NewBuffer(body)), buffer.UserProvided), digest
// An error will be generated down the line if the data does not match the
// digest
return buffer.NewCASBufferFromReader(digest, body, buffer.UserProvided), digest
}

func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, error) {
Expand Down
111 changes: 91 additions & 20 deletions pkg/fetch/http_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
},
},
}
casBlobAccess := mock.NewMockBlobAccess(ctrl)
roundTripper := mock.NewMockRoundTripper(ctrl)
Expand All @@ -68,15 +74,82 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {

t.Run("Success", func(t *testing.T) {
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: 5,
}, nil)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("SuccessNoContentLength", func(t *testing.T) {
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: -1,
}, nil)
bodyReadCall := body.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
copy(p, "Hello")
return 5, io.EOF
}).After(httpDoCall)
bodyCloseCall := body.EXPECT().Close().Return(nil).After(bodyReadCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("SuccessNoExpectedDigest", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{},
}
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: 5,
}, nil)
bodyReadCall := body.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
copy(p, "Hello")
return 5, io.EOF
}).After(httpDoCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyReadCall)
bodyCloseCall := body.EXPECT().Close().Return(nil).After(bodyReadCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("SuccessNoExpectedDigestOrContentLength", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: "",
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{},
}
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: -1,
}, nil)
bodyReadCall := body.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
copy(p, "Hello")
return 5, io.EOF
}).After(httpDoCall)
bodyCloseCall := body.EXPECT().Close().Return(nil).After(bodyReadCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyCloseCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
Expand All @@ -90,15 +163,12 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
StatusCode: 404,
}, nil)
httpSuccessCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: 5,
}, nil).After(httpFailCall)
bodyReadCall := body.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
copy(p, "Hello")
return 5, io.EOF
}).After(httpSuccessCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyReadCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpSuccessCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
Expand Down Expand Up @@ -126,6 +196,10 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
Name: "bazel.auth_headers",
Value: `{ "www.example.com": {"Authorization": "Bearer letmein"}}`,
},
{
Name: "checksum.sri",
Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=",
},
},
}
matcher := &headerMatcher{
Expand All @@ -134,15 +208,12 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
},
}
httpDoCall := roundTripper.EXPECT().RoundTrip(matcher).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: 5,
}, nil)
bodyReadCall := body.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
copy(p, "Hello")
return 5, io.EOF
}).After(httpDoCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(bodyReadCall)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
require.Nil(t, err)
Expand Down

0 comments on commit d7a9200

Please sign in to comment.