Skip to content

Commit

Permalink
add support for sts endpoint when authenticating with aws
Browse files Browse the repository at this point in the history
  • Loading branch information
Tiago Posse committed Jul 15, 2024
1 parent 527510d commit c7dbce0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 6 deletions.
10 changes: 10 additions & 0 deletions docs/sources/shared/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,11 @@ dynamodb:
# CLI flag: -s3.endpoint
[endpoint: <string> | default = ""]

# Accessing S3 resources using temporary, secure credentials provided by AWS
# Security Token Service.
# CLI flag: -s3.sts-endpoint
[sts_endpoint: <string> | default = ""]

# AWS region to use.
# CLI flag: -s3.region
[region: <string> | default = ""]
Expand Down Expand Up @@ -4965,6 +4970,11 @@ The `s3_storage_config` block configures the connection to Amazon S3 object stor
# CLI flag: -<prefix>.storage.s3.endpoint
[endpoint: <string> | default = ""]

# Accessing S3 resources using temporary, secure credentials provided by AWS
# Security Token Service.
# CLI flag: -<prefix>.storage.s3.sts-endpoint
[sts_endpoint: <string> | default = ""]

# AWS region to use.
# CLI flag: -<prefix>.storage.s3.region
[region: <string> | default = ""]
Expand Down
1 change: 1 addition & 0 deletions pkg/storage/bucket/s3/bucket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func newS3Config(cfg Config) (s3.Config, error) {
Bucket: cfg.BucketName,
Endpoint: cfg.Endpoint,
Region: cfg.Region,
STSEndpoint: cfg.STSEndpoint,
AccessKey: cfg.AccessKeyID,
SecretKey: cfg.SecretAccessKey.String(),
SessionToken: cfg.SessionToken.String(),
Expand Down
7 changes: 7 additions & 0 deletions pkg/storage/bucket/s3/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
errUnsupportedSignatureVersion = errors.New("unsupported signature version")
errUnsupportedSSEType = errors.New("unsupported S3 SSE type")
errInvalidSSEContext = errors.New("invalid S3 SSE encryption context")
errInvalidSTSEndpoint = errors.New("sts-endpoint must be a valid url")
)

// HTTPConfig stores the http.Transport configuration for the s3 minio client.
Expand All @@ -63,6 +64,7 @@ type Config struct {
Insecure bool `yaml:"insecure"`
SignatureVersion string `yaml:"signature_version"`
StorageClass string `yaml:"storage_class"`
STSEndpoint string `yaml:"sts_endpoint"`

SSE SSEConfig `yaml:"sse"`
HTTP HTTPConfig `yaml:"http"`
Expand All @@ -84,6 +86,7 @@ func (cfg *Config) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
f.BoolVar(&cfg.Insecure, prefix+"s3.insecure", false, "If enabled, use http:// for the S3 endpoint instead of https://. This could be useful in local dev/test environments while using an S3-compatible backend storage, like Minio.")
f.StringVar(&cfg.SignatureVersion, prefix+"s3.signature-version", SignatureVersionV4, fmt.Sprintf("The signature version to use for authenticating against S3. Supported values are: %s.", strings.Join(supportedSignatureVersions, ", ")))
f.StringVar(&cfg.StorageClass, prefix+"s3.storage-class", aws.StorageClassStandard, "The S3 storage class to use. Details can be found at https://aws.amazon.com/s3/storage-classes/.")
f.StringVar(&cfg.STSEndpoint, prefix+"s3.sts-endpoint", "", "Accessing S3 resources using temporary, secure credentials provided by AWS Security Token Service.")
cfg.SSE.RegisterFlagsWithPrefix(prefix+"s3.sse.", f)
cfg.HTTP.RegisterFlagsWithPrefix(prefix, f)
}
Expand All @@ -94,6 +97,10 @@ func (cfg *Config) Validate() error {
return errUnsupportedSignatureVersion
}

if cfg.STSEndpoint != "" && !util.IsValidURL(cfg.STSEndpoint) {
return errInvalidSTSEndpoint
}

if err := aws.ValidateStorageClass(cfg.StorageClass); err != nil {
return err
}
Expand Down
39 changes: 33 additions & 6 deletions pkg/storage/chunk/client/aws/s3_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/sts"
awscommon "github.com/grafana/dskit/aws"
"github.com/grafana/dskit/backoff"
"github.com/grafana/dskit/flagext"
Expand All @@ -44,6 +46,7 @@ const (
var (
supportedSignatureVersions = []string{SignatureVersionV4}
errUnsupportedSignatureVersion = errors.New("unsupported signature version")
errInvalidSTSEndpoint = errors.New("sts-endpoint must be a valid url")
)

var s3RequestDuration = instrument.NewHistogramCollector(prometheus.NewHistogramVec(prometheus.HistogramOpts{
Expand All @@ -68,6 +71,7 @@ type S3Config struct {

BucketNames string
Endpoint string `yaml:"endpoint"`
STSEndpoint string `yaml:"sts_endpoint"`
Region string `yaml:"region"`
AccessKeyID string `yaml:"access_key_id"`
SecretAccessKey flagext.Secret `yaml:"secret_access_key"`
Expand Down Expand Up @@ -109,6 +113,7 @@ func (cfg *S3Config) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
f.Var(&cfg.SecretAccessKey, prefix+"s3.secret-access-key", "AWS Secret Access Key")
f.Var(&cfg.SessionToken, prefix+"s3.session-token", "AWS Session Token")
f.BoolVar(&cfg.Insecure, prefix+"s3.insecure", false, "Disable https on s3 connection.")
f.StringVar(&cfg.STSEndpoint, prefix+"s3.sts-endpoint", "", "Accessing S3 resources using temporary, secure credentials provided by AWS Security Token Service.")

cfg.SSEConfig.RegisterFlagsWithPrefix(prefix+"s3.sse.", f)

Expand All @@ -131,7 +136,15 @@ func (cfg *S3Config) Validate() error {
return errUnsupportedSignatureVersion
}

return storageawscommon.ValidateStorageClass(cfg.StorageClass)
if cfg.STSEndpoint != "" && !util.IsValidURL(cfg.STSEndpoint) {
return errInvalidSTSEndpoint
}

if err := storageawscommon.ValidateStorageClass(cfg.StorageClass); err != nil {
return err
}

return cfg.SSEConfig.Validate()
}

type S3ObjectClient struct {
Expand Down Expand Up @@ -196,13 +209,27 @@ func buildS3Client(cfg S3Config, hedgingCfg hedging.Config, hedging bool) (*s3.S
s3Config = s3Config.WithRegion("dummy")
}

s3Config = s3Config.WithMaxRetries(0) // We do our own retries, so we can monitor them
s3Config = s3Config.WithS3ForcePathStyle(cfg.S3ForcePathStyle) // support for Path Style S3 url if has the flag
customEndpointResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if cfg.Endpoint != "" && service == s3.EndpointsID {
return endpoints.ResolvedEndpoint{
URL: cfg.Endpoint,
}, nil
}

if cfg.Endpoint != "" {
s3Config = s3Config.WithEndpoint(cfg.Endpoint)
if cfg.STSEndpoint != "" && service == sts.EndpointsID {
return endpoints.ResolvedEndpoint{
URL: cfg.STSEndpoint,
}, nil
}

return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}

s3Config = s3Config.WithEndpointResolver(endpoints.ResolverFunc(customEndpointResolver))

s3Config = s3Config.WithMaxRetries(0) // We do our own retries, so we can monitor them
s3Config = s3Config.WithS3ForcePathStyle(cfg.S3ForcePathStyle) // support for Path Style S3 url if has the flag

if cfg.Insecure {
s3Config = s3Config.WithDisableSSL(true)
}
Expand Down Expand Up @@ -257,6 +284,7 @@ func buildS3Client(cfg S3Config, hedgingCfg hedging.Config, hedging bool) (*s3.S
if cfg.Inject != nil {
transport = cfg.Inject(transport)
}

httpClient := &http.Client{
Transport: transport,
Timeout: cfg.HTTPConfig.Timeout,
Expand All @@ -270,7 +298,6 @@ func buildS3Client(cfg S3Config, hedgingCfg hedging.Config, hedging bool) (*s3.S
}

s3Config = s3Config.WithHTTPClient(httpClient)

sess, err := session.NewSession(s3Config)
if err != nil {
return nil, errors.Wrap(err, "failed to create new s3 session")
Expand Down
9 changes: 9 additions & 0 deletions pkg/util/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,12 @@ func FlagFromValues(values url.Values, key string, d bool) bool {
return d
}
}

func IsValidURL(endpoint string) bool {
u, err := url.Parse(endpoint)
if err != nil {
return false
}

return u.Scheme != "" && u.Host != ""
}

0 comments on commit c7dbce0

Please sign in to comment.