Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pr/253'
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelkarp committed Feb 11, 2021
2 parents 0ad55a1 + e252fbf commit 6ddba75
Show file tree
Hide file tree
Showing 232 changed files with 27,558 additions and 91,085 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ ECR registry:
```json
{
"credHelpers": {
"public.ecr.aws": "ecr-login",
"<aws_account_id>.dkr.ecr.<region>.amazonaws.com": "ecr-login"
}
}
Expand Down Expand Up @@ -204,6 +205,8 @@ The credentials must have a policy applied that

`docker push 123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repository:my-tag`

`docker pull public.ecr.aws/amazonlinux/amazonlinux:latest`

If you have configured additional profiles for use with the AWS CLI, you can use
those profiles by specifying the `AWS_PROFILE` environment variable when invoking `docker`.
For example:
Expand Down
3 changes: 2 additions & 1 deletion docs/docker-credential-ecr-login.1
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ section with the URI of your ECR registry:
.nf
{
"credHelpers": {
"aws_account_id.dkr.ecr.region.amazonaws.com":"ecr-login"
"public.ecr.aws": "ecr-login",
"aws_account_id.dkr.ecr.region.amazonaws.com": "ecr-login"
}
}
.fi
Expand Down
168 changes: 127 additions & 41 deletions ecr-login/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,70 @@ package api
import (
"encoding/base64"
"fmt"
"net/url"
"regexp"
"strings"
"time"

"github.com/pkg/errors"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
"github.com/aws/aws-sdk-go/service/ecrpublic"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"

"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
)

const (
proxyEndpointScheme = "https://"
programName = "docker-credential-ecr-login"
ecrPublicName = "public.ecr.aws"
ecrPublicEndpoint = proxyEndpointScheme + ecrPublicName
)

const proxyEndpointScheme = "https://"
const programName = "docker-credential-ecr-login"
var ecrPattern = regexp.MustCompile(`(^[a-zA-Z0-9][a-zA-Z0-9-_]*)\.dkr\.ecr(-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.amazonaws\.com(\.cn)?$`)

type Service string

var ecrPattern = regexp.MustCompile(`(^[a-zA-Z0-9][a-zA-Z0-9-_]*)\.dkr\.ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.amazonaws\.com(\.cn)?`)
const (
ServiceECR Service = "ecr"
ServiceECRPublic Service = "ecr-public"
)

// Registry in ECR
type Registry struct {
ID string
FIPS bool
Region string
Service Service
ID string
FIPS bool
Region string
}

// ExtractRegistry returns the ECR registry behind a given service endpoint
func ExtractRegistry(serverURL string) (*Registry, error) {
if strings.HasPrefix(serverURL, proxyEndpointScheme) {
serverURL = strings.TrimPrefix(serverURL, proxyEndpointScheme)
func ExtractRegistry(input string) (*Registry, error) {
if strings.HasPrefix(input, proxyEndpointScheme) {
input = strings.TrimPrefix(input, proxyEndpointScheme)
}
serverURL, err := url.Parse(proxyEndpointScheme + input)
if err != nil {
return nil, err
}
if serverURL.Hostname() == ecrPublicName {
return &Registry{
Service: ServiceECRPublic,
}, nil
}
matches := ecrPattern.FindStringSubmatch(serverURL)
matches := ecrPattern.FindStringSubmatch(serverURL.Hostname())
if len(matches) == 0 {
return nil, fmt.Errorf(programName + " can only be used with Amazon Elastic Container Registry.")
} else if len(matches) < 3 {
return nil, fmt.Errorf(serverURL + "is not a valid repository URI for Amazon Elastic Container Registry.")
return nil, fmt.Errorf("%q is not a valid repository URI for Amazon Elastic Container Registry.", input)
}
registry := &Registry{
ID: matches[1],
FIPS: matches[2] == "-fips",
Region: matches[3],
}
return registry, nil
return &Registry{
Service: ServiceECR,
ID: matches[1],
FIPS: matches[2] == "-fips",
Region: matches[3],
}, nil
}

// Client used for calling ECR service
Expand All @@ -65,19 +88,26 @@ type Client interface {
GetCredentialsByRegistryID(registryID string) (*Auth, error)
ListCredentials() ([]*Auth, error)
}

// Auth credentials returned by ECR service to allow docker login
type Auth struct {
ProxyEndpoint string
Username string
Password string
}

type defaultClient struct {
ecrClient ECRAPI
ecrPublicClient ECRPublicAPI
credentialCache cache.CredentialsCache
}

type ECRAPI interface {
GetAuthorizationToken(*ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error)
}

// Auth credentials returned by ECR service to allow docker login
type Auth struct {
ProxyEndpoint string
Username string
Password string
type ECRPublicAPI interface {
GetAuthorizationToken(*ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error)
}

// GetCredentials returns username, password, and proxyEndpoint
Expand All @@ -87,11 +117,18 @@ func (c *defaultClient) GetCredentials(serverURL string) (*Auth, error) {
return nil, err
}
logrus.
WithField("service", registry.Service).
WithField("registry", registry.ID).
WithField("region", registry.Region).
WithField("serverURL", serverURL).
Debug("Retrieving credentials")
return c.GetCredentialsByRegistryID(registry.ID)
switch registry.Service {
case ServiceECR:
return c.GetCredentialsByRegistryID(registry.ID)
case ServiceECRPublic:
return c.GetPublicCredentials()
}
return nil, fmt.Errorf("unknown service %q", registry.Service)
}

// GetCredentialsByRegistryID returns username, password, and proxyEndpoint
Expand Down Expand Up @@ -120,8 +157,42 @@ func (c *defaultClient) GetCredentialsByRegistryID(registryID string) (*Auth, er
return auth, err
}

func (c *defaultClient) GetPublicCredentials() (*Auth, error) {
cachedEntry := c.credentialCache.GetPublic()
if cachedEntry != nil {
if cachedEntry.IsValid(time.Now()) {
logrus.WithField("registry", ecrPublicName).Debug("Using cached token")
return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
}
logrus.
WithField("requestedAt", cachedEntry.RequestedAt).
WithField("expiresAt", cachedEntry.ExpiresAt).
Debug("Cached token is no longer valid")
}

auth, err := c.getPublicAuthorizationToken()
// if we have a cached token, fall back to avoid failing the request. This may result an expired token
// being returned, but if there is a 500 or timeout from the service side, we'd like to attempt to re-use an
// old token. We invalidate tokens prior to their expiration date to help mitigate this scenario.
if err != nil && cachedEntry != nil {
logrus.WithError(err).Info("Got error fetching authorization token. Falling back to cached token.")
return extractToken(cachedEntry.AuthorizationToken, cachedEntry.ProxyEndpoint)
}
return auth, err
}

func (c *defaultClient) ListCredentials() ([]*Auth, error) {
auths := []*Auth{}
// prime the cache with default authorization tokens
_, err := c.GetCredentialsByRegistryID("")
if err != nil {
logrus.WithError(err).Debug("couldn't get authorization token for default registry")
}
_, err = c.GetPublicCredentials()
if err != nil {
logrus.WithError(err).Debug("couldn't get authorization token for public registry")
}

auths := make([]*Auth, 0)
for _, authEntry := range c.credentialCache.List() {
auth, err := extractToken(authEntry.AuthorizationToken, authEntry.ProxyEndpoint)
if err != nil {
Expand All @@ -131,18 +202,6 @@ func (c *defaultClient) ListCredentials() ([]*Auth, error) {
}
}

// If cache is empty, get authorization token of default registry
if len(auths) == 0 {
logrus.Debug("No credential cache")
auth, err := c.getAuthorizationToken("")
if err != nil {
logrus.WithError(err).Debugf("Couldn't get authorization token")
} else {
auths = append(auths, auth)
}
return auths, err
}

return auths, nil
}

Expand Down Expand Up @@ -177,6 +236,7 @@ func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error)
RequestedAt: time.Now(),
ExpiresAt: aws.TimeValue(authData.ExpiresAt),
ProxyEndpoint: aws.StringValue(authData.ProxyEndpoint),
Service: cache.ServiceECR,
}
registry, err := ExtractRegistry(authEntry.ProxyEndpoint)
if err != nil {
Expand All @@ -196,15 +256,41 @@ func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error)
return nil, fmt.Errorf("No AuthorizationToken found for %s", registryID)
}

func (c *defaultClient) getPublicAuthorizationToken() (*Auth, error) {
var input *ecrpublic.GetAuthorizationTokenInput

output, err := c.ecrPublicClient.GetAuthorizationToken(input)
if err != nil {
return nil, errors.Wrap(err, "ecr: failed to get authorization token")
}
if output == nil || output.AuthorizationData == nil {
return nil, fmt.Errorf("ecr: missing AuthorizationData in ECR Public response")
}
authData := output.AuthorizationData
token, err := extractToken(aws.StringValue(authData.AuthorizationToken), ecrPublicEndpoint)
if err != nil {
return nil, err
}
authEntry := cache.AuthEntry{
AuthorizationToken: aws.StringValue(authData.AuthorizationToken),
RequestedAt: time.Now(),
ExpiresAt: aws.TimeValue(authData.ExpiresAt),
ProxyEndpoint: ecrPublicEndpoint,
Service: cache.ServiceECRPublic,
}
c.credentialCache.Set(ecrPublicName, &authEntry)
return token, nil
}

func extractToken(token string, proxyEndpoint string) (*Auth, error) {
decodedToken, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return nil, fmt.Errorf("Invalid token: %v", err)
return nil, errors.Wrap(err, "invalid token")
}

parts := strings.SplitN(string(decodedToken), ":", 2)
if len(parts) < 2 {
return nil, fmt.Errorf("Invalid token: expected two parts, got %d", len(parts))
return nil, fmt.Errorf("invalid token: expected two parts, got %d", len(parts))
}

return &Auth{
Expand Down
Loading

0 comments on commit 6ddba75

Please sign in to comment.