Skip to content

Commit

Permalink
Presign S3 URLs using fresh credentials and cap their TTL (#6392)
Browse files Browse the repository at this point in the history
* Refresh web identity token file periodically

Fixes #6385.

* Trim S3 token expiry to match client expiry

Otherwise the AWS SDK pre-signs the URL to be valid supposedly until the
requested configured timeout of 15m, and encode that into the X-Amz-Expires
head.

When used, S3 will (correctly!) refuse the URL after expiration of the
client token that lakeFS used to sign.  And the poor client gets a "Token
Expired" error message.

This is particularly unfortunate with DataBricks Unity, which appears to
ignore the reported "expirationTimestamp".

* Actively refresh client credentials before pre-signing

* [CR] Make IRSA "web identity" expiration params configurable
  • Loading branch information
arielshaqed authored Aug 16, 2023
1 parent 775b243 commit fb8d0c5
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 53 deletions.
31 changes: 26 additions & 5 deletions pkg/block/factory/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"cloud.google.com/go/storage"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/treeverse/lakefs/pkg/block"
Expand All @@ -24,8 +25,10 @@ import (
"google.golang.org/api/option"
)

// googleAuthCloudPlatform - Cloud Storage authentication https://cloud.google.com/storage/docs/authentication
const googleAuthCloudPlatform = "https://www.googleapis.com/auth/cloud-platform"
const (
// googleAuthCloudPlatform - Cloud Storage authentication https://cloud.google.com/storage/docs/authentication
googleAuthCloudPlatform = "https://www.googleapis.com/auth/cloud-platform"
)

func BuildBlockAdapter(ctx context.Context, statsCollector stats.Collector, c params.AdapterConfig) (block.Adapter, error) {
blockstore := c.BlockstoreType()
Expand Down Expand Up @@ -82,15 +85,30 @@ func buildLocalAdapter(ctx context.Context, params params.Local) (*local.Adapter
return adapter, nil
}

func BuildS3Client(params *aws.Config, skipVerifyCertificateTestOnly bool) (*session.Session, error) {
func BuildS3Client(awsConfig *aws.Config, webIdentity *params.S3WebIdentity, skipVerifyCertificateTestOnly bool) (*session.Session, error) {
client := http.DefaultClient
if skipVerifyCertificateTestOnly {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec
}
client = &http.Client{Transport: tr}
}
sess, err := session.NewSession(params, aws.NewConfig().WithHTTPClient(client))
opts := session.Options{}
opts.Config.MergeIn(awsConfig, aws.NewConfig().WithHTTPClient(client))
if webIdentity != nil {
wi := *webIdentity // Copy WebIdentity: it will be used asynchronously.
opts.CredentialsProviderOptions = &session.CredentialsProviderOptions{
WebIdentityRoleProviderOptions: func(wirp *stscreds.WebIdentityRoleProvider) {
if wi.SessionDuration > 0 {
wirp.Duration = wi.SessionDuration
}
if wi.SessionExpiryWindow > 0 {
wirp.ExpiryWindow = wi.SessionExpiryWindow
}
},
}
}
sess, err := session.NewSessionWithOptions(opts)
if err != nil {
return nil, err
}
Expand All @@ -99,7 +117,7 @@ func BuildS3Client(params *aws.Config, skipVerifyCertificateTestOnly bool) (*ses
}

func buildS3Adapter(ctx context.Context, statsCollector stats.Collector, params params.S3) (*s3a.Adapter, error) {
sess, err := BuildS3Client(params.AwsConfig, params.SkipVerifyCertificateTestOnly)
sess, err := BuildS3Client(params.AwsConfig, params.WebIdentity, params.SkipVerifyCertificateTestOnly)
if err != nil {
return nil, err
}
Expand All @@ -118,6 +136,9 @@ func buildS3Adapter(ctx context.Context, statsCollector stats.Collector, params
if params.ServerSideEncryptionKmsKeyID != "" {
opts = append(opts, s3a.WithServerSideEncryptionKmsKeyID(params.ServerSideEncryptionKmsKeyID))
}
if params.WebIdentity != nil && params.WebIdentity.SessionExpiryWindow > 0 {
opts = append(opts, s3a.WithPreSignedRefreshWindow(params.WebIdentity.SessionExpiryWindow))
}
adapter := s3a.NewAdapter(sess, opts...)
logging.FromContext(ctx).WithField("type", "s3").Info("initialized blockstore adapter")
return adapter, nil
Expand Down
14 changes: 14 additions & 0 deletions pkg/block/params/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ type Local struct {
AllowedExternalPrefixes []string
}

// S3WebIdentity contains parameters for customizing S3 web identity. This
// is also used when configuring S3 with IRSA in EKS (Kubernetes).
type S3WebIdentity struct {
// SessionDuration is the duration WebIdentityRoleProvider will
// request for a token for its assumed role. It can be 1 hour or
// more, but its maximum is configurable on AWS.
SessionDuration time.Duration

// SessionExpiryWindow is the time before credentials expiry that
// the WebIdentityRoleProvider may request a fresh token.
SessionExpiryWindow time.Duration
}

type S3 struct {
AwsConfig *aws.Config
StreamingChunkSize int
Expand All @@ -35,6 +48,7 @@ type S3 struct {
PreSignedExpiry time.Duration
DisablePreSigned bool
DisablePreSignedUI bool
WebIdentity *S3WebIdentity
}

type GS struct {
Expand Down
102 changes: 84 additions & 18 deletions pkg/block/s3/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Adapter struct {
ServerSideEncryption string
ServerSideEncryptionKmsKeyID string
preSignedExpiry time.Duration
preSignedRefreshWindow time.Duration
disablePreSigned bool
disablePreSignedUI bool
}
Expand Down Expand Up @@ -74,6 +75,12 @@ func WithDiscoverBucketRegion(b bool) func(a *Adapter) {
}
}

func WithPreSignedRefreshWindow(v time.Duration) func(a *Adapter) {
return func(a *Adapter) {
a.preSignedRefreshWindow = v
}
}

func WithPreSignedExpiry(v time.Duration) func(a *Adapter) {
return func(a *Adapter) {
a.preSignedExpiry = v
Expand Down Expand Up @@ -321,6 +328,43 @@ func (a *Adapter) GetWalker(uri *url.URL) (block.Walker, error) {
return NewS3Walker(a.clients.awsSession), nil
}

// refreshClientIfNeeded ensure client has some time before its token
// expires. It returns the updated expiry time. It fails with
// ErrDoesntExpire if the client is not an Expirer,
func refreshClientIfNeeded(ctx context.Context, client S3APIWithExpirer, refreshWindow time.Duration) (time.Time, error) {
if client == nil {
return time.Time{}, nil
}

expiry, err := client.ExpiresAt()
if errors.Is(err, ErrDoesntExpire) {
return time.Time{}, ErrDoesntExpire
} else if err != nil {
return time.Time{}, fmt.Errorf("refresh client if needed: get current expiry: %w", err)
}

ttl := time.Until(expiry)
l := logging.FromContext(ctx).WithFields(logging.Fields{
"expiry": expiry,
"TTL": ttl.String(),
})
if ttl < refreshWindow {
l.Info("Refresh client as it will expire soon")
expiry, err = client.Refresh()
if err != nil {
return time.Time{}, fmt.Errorf("refresh client if needed: refreshing: %w", err)
}
ttl = time.Until(expiry)
l = l.WithFields(logging.Fields{
"expiry": expiry,
"TTL": ttl.String(),
})
l.Info("Refreshed client")
}
l.Trace("Got client")
return expiry, nil
}

func (a *Adapter) GetPreSignedURL(ctx context.Context, obj block.ObjectPointer, mode block.PreSignMode) (string, time.Time, error) {
if a.disablePreSigned {
return "", time.Time{}, block.ErrOperationNotSupported
Expand All @@ -334,43 +378,65 @@ func (a *Adapter) GetPreSignedURL(ctx context.Context, obj block.ObjectPointer,
WithError(err).Error("could not resolve namespace")
return "", time.Time{}, err
}
var preSignedURL string

client := a.clients.Get(ctx, qualifiedKey.GetStorageNamespace())

clientExpiry, clientExpiryErr := refreshClientIfNeeded(ctx, client, a.preSignedRefreshWindow)

expiry := time.Now().Add(a.preSignedExpiry)
log = log.WithField("expiry", expiry)
switch {
case clientExpiryErr == nil:
if clientExpiry.Before(expiry) && !clientExpiry.IsZero() {
log.WithField("client_expiry", clientExpiry).
Trace("URL expiry shortened by client expiry")
// TODO(ariels): Monitor this?
expiry = clientExpiry
log = log.WithField("expiry", expiry)
}
case errors.Is(clientExpiryErr, ErrDoesntExpire):
break
default:
log.WithFields(logging.Fields{
"namespace": obj.StorageNamespace,
"identifier": obj.Identifier,
}).
WithError(err).
Warning("Failed to get client (token) expiry: URL expiry may be too high")
}

// BUG(ariels): This is an inherent race. urlLifetime is computed
// relative to the local clock. If expiry was shortened because
// of clientExpiry then AWS will determine _remotely_ whether
// the URL expired. So this URL can expire before the client or
// even lakeFS think that it has.
//
// This is a limitation of the AWS SDK, which signs locally, and
// of the AWS S3 API, which does not allow a meaningful
// workaround.
urlLifetime := time.Until(expiry)
log = log.WithField("TTL", urlLifetime)
var preSignedURL string
if mode == block.PreSignModeWrite {
putObjectInput := &s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
}
req, _ := client.PutObjectRequest(putObjectInput)
preSignedURL, err = req.Presign(a.preSignedExpiry)
preSignedURL, err = req.Presign(urlLifetime)
} else {
getObjectInput := &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
}
req, _ := client.GetObjectRequest(getObjectInput)
preSignedURL, err = req.Presign(a.preSignedExpiry)
preSignedURL, err = req.Presign(urlLifetime)
}
if err != nil {
log.WithField("namespace", obj.StorageNamespace).
WithField("identifier", obj.Identifier).
WithError(err).Error("could not pre-sign request")
}
expiry := time.Now().Add(a.preSignedExpiry)
clientExpiry, clientExpiryErr := client.ExpiresAt()
switch {
case clientExpiryErr == nil:
if clientExpiry.Before(expiry) {
expiry = clientExpiry
}
case errors.Is(clientExpiryErr, ErrDoesntExpire):
break
default:
log.WithFields(logging.Fields{
"namespace": obj.StorageNamespace,
"identifier": obj.Identifier,
}).WithError(err).Warning("Failed to get client (token) expiry: URL expiry may be too high")
}
return preSignedURL, expiry, err
}

Expand Down
43 changes: 14 additions & 29 deletions pkg/block/s3/client_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type Expirer interface {
// a ErrDoesntExpire if it cannot determine expiry times -- for
// instance, if AWS is configured using an access key.
ExpiresAt() (time.Time, error)
// Refresh attempts to refresh and returns ExpiresAt().
Refresh() (time.Time, error)
}

type S3APIWithExpirer interface {
Expand Down Expand Up @@ -87,6 +89,15 @@ func (c *s3Client) ExpiresAt() (time.Time, error) {
return expiryTime, err
}

func (c *s3Client) Refresh() (time.Time, error) {
c.awsSession.Config.Credentials.Expire()
_, err := c.awsSession.Config.Credentials.Get()
if err != nil {
return time.Time{}, fmt.Errorf("refresh credentials: %w", err)
}
return c.ExpiresAt()
}

func NewClientCache(awsSession *session.Session) *ClientCache {
return &ClientCache{
awsSession: awsSession,
Expand Down Expand Up @@ -122,37 +133,12 @@ func (c *ClientCache) getBucketRegion(ctx context.Context, bucket string) string
}

// Get returns an AWS client configured to the region of the given bucket.
func (c *ClientCache) Get(ctx context.Context, bucket string) (ret S3APIWithExpirer) {
defer func() {
if ret == nil {
return
}
expiry, err := ret.ExpiresAt()
ttl := time.Until(expiry)
l := logging.FromContext(ctx)
if !l.IsTracing() && ttl > 0 {
return
}
if err != nil {
l = l.WithField("error", err)
} else if !expiry.IsZero() {
l = l.WithFields(logging.Fields{
"expiry": expiry,
"TTL": ttl.String(),
})
}
ll := l.Trace
if ttl <= 5*time.Second {
ll = l.Warn
}
ll("Got client")
}()

func (c *ClientCache) Get(ctx context.Context, bucket string) S3APIWithExpirer {
region := c.getBucketRegion(ctx, bucket)
svc, hasClient := c.regionToS3Client.Load(region)
if !hasClient {
logging.FromContext(ctx).WithField("bucket", bucket).WithField("region", region).Debug("creating client for region")
ret = c.clientFactory(c.awsSession, &aws.Config{Region: swag.String(region)})
ret := c.clientFactory(c.awsSession, &aws.Config{Region: swag.String(region)})
c.regionToS3Client.Store(region, ret)
if c.collector != nil {
c.collector.CollectEvent(stats.Event{
Expand All @@ -162,8 +148,7 @@ func (c *ClientCache) Get(ctx context.Context, bucket string) (ret S3APIWithExpi
}
return ret
} else {
ret = svc.(S3APIWithExpirer)
return ret
return svc.(S3APIWithExpirer)
}
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/block/s3/client_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func (f FakeS3APIWithExpirer) ExpiresAt() (time.Time, error) {
return time.Time{}, errFakeExpires
}

func (f FakeS3APIWithExpirer) Refresh() (time.Time, error) {
return f.ExpiresAt()
}

func TestClientCache(t *testing.T) {
defaultRegion := "us-west-2"
sess, err := session.NewSession(&aws.Config{Region: &defaultRegion})
Expand Down
12 changes: 12 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ type Config struct {
PreSignedExpiry time.Duration `mapstructure:"pre_signed_expiry"`
DisablePreSigned bool `mapstructure:"disable_pre_signed"`
DisablePreSignedUI bool `mapstructure:"disable_pre_signed_ui"`
WebIdentity *struct {
SessionDuration time.Duration `mapstructure:"session_duration"`
SessionExpiryWindow time.Duration `mapstructure:"session_expiry_window"`
} `mapstructure:"web_identity"`
} `mapstructure:"s3"`
Azure *struct {
TryTimeout time.Duration `mapstructure:"try_timeout"`
Expand Down Expand Up @@ -547,6 +551,13 @@ func (c *Config) BlockstoreType() string {
}

func (c *Config) BlockstoreS3Params() (blockparams.S3, error) {
var webIdentity *blockparams.S3WebIdentity
if c.Blockstore.S3.WebIdentity != nil {
webIdentity = &blockparams.S3WebIdentity{
SessionDuration: c.Blockstore.S3.WebIdentity.SessionDuration,
SessionExpiryWindow: c.Blockstore.S3.WebIdentity.SessionExpiryWindow,
}
}
return blockparams.S3{
AwsConfig: c.GetAwsConfig(),
StreamingChunkSize: c.Blockstore.S3.StreamingChunkSize,
Expand All @@ -558,6 +569,7 @@ func (c *Config) BlockstoreS3Params() (blockparams.S3, error) {
PreSignedExpiry: c.Blockstore.S3.PreSignedExpiry,
DisablePreSigned: c.Blockstore.S3.DisablePreSigned,
DisablePreSignedUI: c.Blockstore.S3.DisablePreSignedUI,
WebIdentity: webIdentity,
}, nil
}

Expand Down
1 change: 1 addition & 0 deletions pkg/config/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func setDefaults(cfgType string) {
viper.SetDefault("blockstore.s3.max_retries", 5)
viper.SetDefault("blockstore.s3.discover_bucket_region", true)
viper.SetDefault("blockstore.s3.pre_signed_expiry", 15*time.Minute)
viper.SetDefault("blockstore.s3.web_identity.session_expiry_window", 5*time.Minute)
viper.SetDefault("blockstore.s3.disable_pre_signed_ui", true)

viper.SetDefault("committed.local_cache.size_bytes", 1*1024*1024*1024)
Expand Down
2 changes: 1 addition & 1 deletion pkg/ingest/store/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (f *WalkerFactory) buildS3Walker(opts WalkerOptions) (*s3.Walker, error) {
if err != nil {
return nil, err
}
sess, err = factory.BuildS3Client(s3params.AwsConfig, s3params.SkipVerifyCertificateTestOnly)
sess, err = factory.BuildS3Client(s3params.AwsConfig, s3params.WebIdentity, s3params.SkipVerifyCertificateTestOnly)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit fb8d0c5

Please sign in to comment.