diff --git a/pkg/fetch/auth_headers.go b/pkg/fetch/auth_headers.go index e497509..74fffe1 100644 --- a/pkg/fetch/auth_headers.go +++ b/pkg/fetch/auth_headers.go @@ -15,6 +15,19 @@ func NewAuthHeadersFromQualifier(value string) (*AuthHeaders, error) { return &ah, err } +// NewAuthHeaders creates an empty AuthHeaders +func NewAuthHeaders() *AuthHeaders { + return &AuthHeaders{} +} + +// AddHeader adds a header to the AuthHeaders +func (ah AuthHeaders) AddHeader(uri, header, value string) { + if _, ok := ah[uri]; !ok { + ah[uri] = make(map[string]string) + } + ah[uri][header] = value +} + // ApplyHeaders mutates a http.Request to apply headers requested by the client. func (ah AuthHeaders) ApplyHeaders(uri string, req *http.Request) { if headers, ok := ah[uri]; ok { diff --git a/pkg/fetch/http_fetcher.go b/pkg/fetch/http_fetcher.go index 66bd254..1ee11ee 100644 --- a/pkg/fetch/http_fetcher.go +++ b/pkg/fetch/http_fetcher.go @@ -8,6 +8,7 @@ import ( "io" "log" "net/http" + "strconv" "strings" "github.com/buildbarn/bb-storage/pkg/blobstore" @@ -23,6 +24,17 @@ import ( "google.golang.org/grpc/status" ) +const ( + // QualifierLegacyBazelHTTPHeaders is the qualifier older versions of bazel sends. + QualifierLegacyBazelHTTPHeaders = "bazel.auth_headers" + // QualifierHTTPHeaderPrefix is a qualifer to add a header to all URIs. + // Qualifier will be in the form http_header:
+ QualifierHTTPHeaderPrefix = "http_header:" + // QualifierHTTPHeaderURLPrefix is a qualifier to add a header to a specific URI. + // Qualifier will be in the form http_header_url::
+ QualifierHTTPHeaderURLPrefix = "http_header_url:" +) + type httpFetcher struct { httpClient *http.Client contentAddressableStorage blobstore.BlobAccess @@ -58,7 +70,7 @@ func (hf *httpFetcher) FetchBlob(ctx context.Context, req *remoteasset.FetchBlob digestFunctionEnum = remoteexecution.DigestFunction_SHA256 } - auth, err := getAuthHeaders(req.Qualifiers) + auth, err := getAuthHeaders(req.Uris, req.Qualifiers) if err != nil { return nil, err } @@ -200,11 +212,44 @@ func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, remoteexecutio return expectedDigest, digestFunctionEnum, nil } -func getAuthHeaders(qualifiers []*remoteasset.Qualifier) (*AuthHeaders, error) { +func getAuthHeaders(uris []string, qualifiers []*remoteasset.Qualifier) (*AuthHeaders, error) { + ah := AuthHeaders{} + perURLQualifiers := map[string]string{} for _, qualifier := range qualifiers { - if qualifier.Name == "bazel.auth_headers" { + // If this is set, then any other headers are ignored + // as this is the only way to set headers in older versions of bazel + if qualifier.Name == QualifierLegacyBazelHTTPHeaders { return NewAuthHeadersFromQualifier(qualifier.Value) } + + if strings.HasPrefix(qualifier.Name, QualifierHTTPHeaderPrefix) { + header := strings.TrimPrefix(qualifier.Name, QualifierHTTPHeaderPrefix) + for _, uri := range uris { + ah.AddHeader(uri, header, qualifier.Value) + } + } + + if strings.HasPrefix(qualifier.Name, QualifierHTTPHeaderURLPrefix) { + perURLQualifiers[qualifier.Name] = qualifier.Value + } } - return nil, nil + // If we have per URL headers, we need to go through and apply them after applying the global headers. + for k, v := range perURLQualifiers { + parts := strings.Split(k, ":") + if len(parts) != 3 { + return nil, status.Errorf(codes.InvalidArgument, "Invalid http_header_url qualifier: %s", k) + } + uriIdx, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "Invalid http_header_url qualifier: %s: Bad URL index: %v: %v", k, parts[1], err) + } + if uriIdx < 0 || uriIdx >= int64(len(uris)) { + return nil, status.Errorf(codes.InvalidArgument, "Invalid http_header_url qualifier: %s: URL index out of range: %v", k, uriIdx) + } + header := parts[2] + ah.AddHeader(uris[uriIdx], header, v) + + } + + return &ah, nil } diff --git a/pkg/fetch/http_fetcher_test.go b/pkg/fetch/http_fetcher_test.go index 57850df..09e6562 100644 --- a/pkg/fetch/http_fetcher_test.go +++ b/pkg/fetch/http_fetcher_test.go @@ -2,6 +2,7 @@ package fetch_test import ( "context" + "fmt" "io" "net/http" "testing" @@ -25,7 +26,7 @@ type headerMatcher struct { } func (hm *headerMatcher) String() string { - return "has headers" + return fmt.Sprintf("has headers: %v", hm.headers) } func (hm *headerMatcher) Matches(x interface{}) bool { @@ -318,7 +319,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) { require.Equal(t, status.Code(err), codes.NotFound) }) - t.Run("WithAuthHeaders", func(t *testing.T) { + t.Run("WithLegacyAuthHeaders", func(t *testing.T) { request := &remoteasset.FetchBlobRequest{ InstanceName: "", Uris: []string{uri}, @@ -351,6 +352,60 @@ func TestHTTPFetcherFetchBlob(t *testing.T) { require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto())) require.Equal(t, response.Status.Code, int32(codes.OK)) }) + + t.Run("WithAuthHeaders", func(t *testing.T) { + request := &remoteasset.FetchBlobRequest{ + InstanceName: "", + Uris: []string{"www.another.com", uri}, + Qualifiers: []*remoteasset.Qualifier{ + { + Name: "http_header:Authorization", + Value: `Bearer anothertoken`, + }, + { + Name: "http_header:Accept", + Value: "application/vnd.docker.distribution.manifest.list.v2+json", + }, + { + Name: "http_header_url:1:Authorization", + Value: `Bearer letmein1`, + }, + { + Name: "checksum.sri", + Value: "sha256-GF+NsyJx/iX1Yab8k4suJkMG7DBO2lGAB9F2SCY4GWk=", + }, + }, + } + matcherReq1 := &headerMatcher{ + headers: map[string]string{ + "Authorization": "Bearer anothertoken", + "Accept": "application/vnd.docker.distribution.manifest.list.v2+json", + }, + } + matcherReq2 := &headerMatcher{ + headers: map[string]string{ + "Authorization": "Bearer letmein1", + "Accept": "application/vnd.docker.distribution.manifest.list.v2+json", + }, + } + roundTripper.EXPECT().RoundTrip(matcherReq1).Return(&http.Response{ + Status: "404 NotFound", + StatusCode: 404, + }, nil) + httpDoCall2 := roundTripper.EXPECT().RoundTrip(matcherReq2).Return(&http.Response{ + Status: "200 Success", + StatusCode: 200, + Body: body, + ContentLength: 5, + }, nil) + + casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall2) + + response, err := HTTPFetcher.FetchBlob(ctx, request) + require.Nil(t, err) + require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto())) + require.Equal(t, response.Status.Code, int32(codes.OK)) + }) } func TestHTTPFetcherFetchDirectory(t *testing.T) {