From eda3671595b6507c695dd460ae9d659f0505654a Mon Sep 17 00:00:00 2001 From: Trent Clarke Date: Thu, 24 Oct 2024 00:27:56 +1100 Subject: [PATCH] Moves and exposes the AWS OIDC credentials cache (#47840) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Moves and exposes the AWS OIDC credentials cache The IdentityCenter integration users OIDC to authenticate with AWS, and so will re-use the existing OIDC credential caching code used by the external audit storage package. This change - extracts the credential cache from the `externalauditstorage` package, - moves it to the `awsoidc` package to indicate that is generally useful, not just for storage access), and - makes it public. This patch also copies the applicable cache tests from `externalauditstorage`. The credential cache tests in `externalauditstorage` have been preserved because they also test backwards compatibility with AWS SDK v1 credential provider, which the new tests do not. * Test fixup * Linter fixups * Remove ttlValuer * Apply suggestions from code review Co-authored-by: Marek Smoliński --------- Co-authored-by: Marek Smoliński --- lib/integrations/awsoidc/credentialscache.go | 284 ++++++++++++++++++ .../awsoidc/credentialscache_test.go | 226 ++++++++++++++ .../externalauditstorage/configurator.go | 206 ++----------- .../externalauditstorage/configurator_test.go | 2 +- lib/service/service_test.go | 2 +- 5 files changed, 530 insertions(+), 190 deletions(-) create mode 100644 lib/integrations/awsoidc/credentialscache.go create mode 100644 lib/integrations/awsoidc/credentialscache_test.go diff --git a/lib/integrations/awsoidc/credentialscache.go b/lib/integrations/awsoidc/credentialscache.go new file mode 100644 index 0000000000000..1d1ddffe3bf1c --- /dev/null +++ b/lib/integrations/awsoidc/credentialscache.go @@ -0,0 +1,284 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package awsoidc + +import ( + "context" + "errors" + "log/slog" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport" +) + +const ( + // TokenLifetime is the lifetime of OIDC tokens used by the + // ExternalAuditStorage service with the AWS OIDC integration. + TokenLifetime = time.Hour + + refreshBeforeExpirationPeriod = 15 * time.Minute + refreshCheckInterval = 30 * time.Second + retrieveTimeout = 30 * time.Second +) + +// GenerateOIDCTokenFn is a function that should return a valid, signed JWT for +// authenticating to AWS via OIDC. +type GenerateOIDCTokenFn func(ctx context.Context, integration string) (string, error) + +type credsOrErr struct { + creds aws.Credentials + err error +} + +// CredentialsCache is used to store and refresh AWS credentials used with +// AWS OIDC integration. +// +// Credentials are valid for 1h, but they cannot be refreshed if Proxy is down, +// so we attempt to refresh the credentials early and retry on failure. +type CredentialsCache struct { + log *slog.Logger + + roleARN arn.ARN + integration string + + // generateOIDCTokenFn is dynamically set after auth is initialized. + generateOIDCTokenFn GenerateOIDCTokenFn + + // initialized communicates (via closing channel) that generateOIDCTokenFn is set. + initialized chan struct{} + closeInitialized func() + + // gotFirstCredsOrErr communicates (via closing channel) that the first + // credsOrErr has been set. + gotFirstCredsOrErr chan struct{} + closeGotFirstCredsOrErr func() + + credsOrErr credsOrErr + credsOrErrMu sync.RWMutex + + stsClient stscreds.AssumeRoleWithWebIdentityAPIClient + clock clockwork.Clock +} + +type CredentialsCacheOptions struct { + // Integration is the name of the Teleport OIDC integration to use + Integration string + + // RoleARN is the ARN of the role to assume once authenticated + RoleARN arn.ARN + + // STSClient is the AWS sts client implementation to use when communicating + // with AWS + STSClient stscreds.AssumeRoleWithWebIdentityAPIClient + + // Log is the logger to use. A default will be supplied if no logger is + // explicitly set + Log *slog.Logger + + // Clock is the clock to use. A default system clock will be provided if + // none is supplied. + Clock clockwork.Clock +} + +func (opts *CredentialsCacheOptions) CheckAndSetDefaults() error { + if opts.STSClient == nil { + return trace.BadParameter("stsClient must be provided") + } + + if opts.Log == nil { + opts.Log = slog.Default().With(teleport.ComponentKey, "AWS-OIDC-CredentialCache") + } + + if opts.Clock == nil { + opts.Clock = clockwork.NewRealClock() + } + + return nil +} + +var errNotReady = errors.New("ExternalAuditStorage: credential cache not yet initialized") + +func NewCredentialsCache(options CredentialsCacheOptions) (*CredentialsCache, error) { + if err := options.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err, "creating credentials cache") + } + + initialized := make(chan struct{}) + gotFirstCredsOrErr := make(chan struct{}) + + return &CredentialsCache{ + roleARN: options.RoleARN, + integration: options.Integration, + log: options.Log.With("integration", options.Integration), + initialized: initialized, + closeInitialized: sync.OnceFunc(func() { close(initialized) }), + gotFirstCredsOrErr: gotFirstCredsOrErr, + closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), + credsOrErr: credsOrErr{err: errNotReady}, + clock: options.Clock, + stsClient: options.STSClient, + }, nil +} + +func (cc *CredentialsCache) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { + cc.generateOIDCTokenFn = fn + cc.closeInitialized() +} + +// Retrieve implements [aws.CredentialsProvider] and returns the latest cached +// credentials, or an error if no credentials have been generated yet or the +// last generated credentials have expired. +func (cc *CredentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) { + cc.credsOrErrMu.RLock() + defer cc.credsOrErrMu.RUnlock() + + if cc.credsOrErr.err != nil { + cc.log.WarnContext(ctx, "Returning error to AWS client", errorValue(cc.credsOrErr.err)) + } + + return cc.credsOrErr.creds, cc.credsOrErr.err +} + +func (cc *CredentialsCache) Run(ctx context.Context) { + // Wait for initialized signal before running loop. + select { + case <-cc.initialized: + case <-ctx.Done(): + cc.log.DebugContext(ctx, "Context canceled before initialized.") + return + } + + cc.refreshIfNeeded(ctx) + + ticker := cc.clock.NewTicker(refreshCheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.Chan(): + cc.refreshIfNeeded(ctx) + case <-ctx.Done(): + cc.log.DebugContext(ctx, "Context canceled, stopping refresh loop.") + return + } + } +} + +func (cc *CredentialsCache) refreshIfNeeded(ctx context.Context) { + credsFromCache, err := cc.Retrieve(ctx) + if err == nil && + credsFromCache.HasKeys() && + cc.clock.Now().Add(refreshBeforeExpirationPeriod).Before(credsFromCache.Expires) { + // No need to refresh, credentials in cache are still valid for longer + // than refreshBeforeExpirationPeriod + return + } + cc.log.DebugContext(ctx, "Refreshing credentials.") + + creds, err := cc.refresh(ctx) + if err != nil { + cc.log.WarnContext(ctx, "Failed to retrieve new credentials", errorValue(err)) + now := cc.clock.Now() + // If we were not able to refresh, check if existing credentials in + // cache are still valid. If yes, just log debug, it will be retried on + // next interval check. + if credsFromCache.HasKeys() && now.Before(credsFromCache.Expires) { + cc.log.DebugContext(ctx, "Continuing to use existing credentials", + slog.Duration( + "ttl", + credsFromCache.Expires.Sub(now).Round(time.Second))) + return + } + // If existing creds are expired, update cached error. + cc.log.ErrorContext(ctx, "Setting cached error", "error", err) + cc.setCredsOrErr(credsOrErr{err: trace.Wrap(err)}) + return + } + + // Refresh went well, update cached creds. + cc.setCredsOrErr(credsOrErr{creds: creds}) + cc.log.DebugContext(ctx, "Successfully refreshed credentials", + slog.Time("expires", creds.Expires)) +} + +func (cc *CredentialsCache) setCredsOrErr(coe credsOrErr) { + cc.credsOrErrMu.Lock() + defer cc.credsOrErrMu.Unlock() + cc.credsOrErr = coe + cc.closeGotFirstCredsOrErr() +} + +func (cc *CredentialsCache) refresh(ctx context.Context) (aws.Credentials, error) { + cc.log.InfoContext(ctx, "Refreshing AWS credentials") + defer cc.log.InfoContext(ctx, "Exiting AWS credentials refresh") + + cc.log.InfoContext(ctx, "Generating Token") + oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration) + if err != nil { + cc.log.ErrorContext(ctx, "Token generation failed", errorValue(err)) + return aws.Credentials{}, trace.Wrap(err) + } + + roleProvider := stscreds.NewWebIdentityRoleProvider( + cc.stsClient, + cc.roleARN.String(), + identityToken(oidcToken), + func(wiro *stscreds.WebIdentityRoleOptions) { + wiro.Duration = TokenLifetime + }, + ) + + ctx, cancel := context.WithTimeout(ctx, retrieveTimeout) + defer cancel() + + cc.log.InfoContext(ctx, "Retrieving AWS role credentials") + + creds, err := roleProvider.Retrieve(ctx) + if err != nil { + cc.log.ErrorContext(ctx, "Role retrieval failed", errorValue(err)) + } + + return creds, trace.Wrap(err) +} + +func (cc *CredentialsCache) WaitForFirstCredsOrErr(ctx context.Context) { + cc.log.InfoContext(ctx, "Entering wait on first credential refresh") + defer cc.log.InfoContext(ctx, "Exiting wait on first credential refresh") + + select { + case <-ctx.Done(): + case <-cc.gotFirstCredsOrErr: + } +} + +// identityToken is an implementation of [stscreds.IdentityTokenRetriever] for returning a static token. +type identityToken string + +// GetIdentityToken returns the token configured. +func (j identityToken) GetIdentityToken() ([]byte, error) { + return []byte(j), nil +} + +func errorValue(v error) slog.Attr { + return slog.Any("error", v) +} diff --git a/lib/integrations/awsoidc/credentialscache_test.go b/lib/integrations/awsoidc/credentialscache_test.go new file mode 100644 index 0000000000000..cc997758f70be --- /dev/null +++ b/lib/integrations/awsoidc/credentialscache_test.go @@ -0,0 +1,226 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package awsoidc + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/google/uuid" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/entitlements" + "github.com/gravitational/teleport/lib/modules" +) + +type fakeSTSClient struct { + clock clockwork.Clock + err error + sync.Mutex +} + +func (f *fakeSTSClient) setError(err error) { + f.Lock() + f.err = err + f.Unlock() +} + +func (f *fakeSTSClient) getError() error { + f.Lock() + defer f.Unlock() + return f.err +} + +func (f *fakeSTSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if err := f.getError(); err != nil { + return nil, err + } + + expiration := f.clock.Now().Add(time.Second * time.Duration(*params.DurationSeconds)) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + Expiration: &expiration, + // These are example values taken from https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + SessionToken: aws.String("AQoDYXdzEE0a8ANXXXXXXXXNO1ewxE5TijQyp+IEXAMPLE"), + SecretAccessKey: aws.String("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"), + AccessKeyId: aws.String("ASgeIAIOSFODNN7EXAMPLE"), + }, + }, nil +} + +func TestCredentialsCache(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + modules.SetTestModules(t, &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + Entitlements: map[entitlements.EntitlementKind]modules.EntitlementInfo{ + entitlements.ExternalAuditStorage: {Enabled: true}, + }, + }, + }) + + // GIVEN a configured and running credential cache... + clock := clockwork.NewFakeClock() + stsClient := &fakeSTSClient{ + clock: clock, + } + cacheUnderTest, err := NewCredentialsCache(CredentialsCacheOptions{ + STSClient: stsClient, + Integration: "test", + Clock: clock, + }) + require.NoError(t, err) + require.NotNil(t, cacheUnderTest) + go cacheUnderTest.Run(ctx) + + advanceClock := func(d time.Duration) { + // Wait for the run loop to actually wait on the clock ticker before advancing. If we advance before + // the loop waits on the ticker, it may never tick. + clock.BlockUntil(1) + clock.Advance(d) + } + + // Set the GenerateOIDCTokenFn to a dumb faked function. + cacheUnderTest.SetGenerateOIDCTokenFn( + func(ctx context.Context, integration string) (string, error) { + return uuid.NewString(), nil + }) + + checkRetrieveCredentials := func(t require.TestingT, expectErr error) { + _, err := cacheUnderTest.Retrieve(ctx) + assert.ErrorIs(t, err, expectErr) + } + + checkRetrieveCredentialsWithExpiry := func(t require.TestingT, expectExpiry time.Time) { + creds, err := cacheUnderTest.Retrieve(ctx) + assert.NoError(t, err) + if err == nil { + assert.WithinDuration(t, expectExpiry, creds.Expires, time.Minute) + } + } + + const ( + // Using a longer wait time to avoid test flakes observed with 1s wait. + waitFor = 10 * time.Second + // We're using a short sleep (1ms) to allow the refresh loop goroutine to get scheduled. + // This keeps the test fast under normal conditions. If there's CPU starvation in CI, + // neither the test goroutine nor the refresh loop are likely getting scheduled often, + // so this shouldn't result in a busy loop. + tick = 1 * time.Millisecond + ) + + t.Run("Retrieve", func(t *testing.T) { + // Assert that credentials can be retrieved when everything is happy. + // EventuallyWithT is necessary to allow credentialsCache.run to be + // scheduled after SetGenerateOIDCTokenFn above. + initialCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + }, waitFor, tick) + }) + + t.Run("CachedCredsArePreservedOnError", func(t *testing.T) { + initialCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + }, waitFor, tick) + + // Assert that the good cached credentials are still used even if sts starts + // returning errors. + stsError := errors.New("test error") + stsClient.setError(stsError) + // Test immediately + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute before first refresh attempt + advanceClock(TokenLifetime - refreshBeforeExpirationPeriod - time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute after first refresh attempt + advanceClock(2 * time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute before credential expiry + advanceClock(refreshBeforeExpirationPeriod - 2*time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + + // Advance 1 minute past the credential expiry and make sure we get the + // expected error. + advanceClock(2 * time.Minute) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentials(t, stsError) + }, waitFor, tick) + + // Fix STS and make sure we stop getting errors within refreshCheckInterval + stsClient.setError(nil) + advanceClock(refreshCheckInterval) + newCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, newCredentialExpiry) + }, waitFor, tick) + }) + + t.Run("WindowedErrors", func(t *testing.T) { + // Test a scenario where STS is returning errors in two different 10-minute windows: the first surrounding + // the expected cert refresh time, and the second surrounding the cert expiry time. + // In this case the credentials cache should refresh the certs somewhere between those two outages, and + // clients should never see an error retrieving credentials. + newCredentialExpiry := clock.Now().Add(TokenLifetime) + expectedRefreshTime := newCredentialExpiry.Add(-refreshBeforeExpirationPeriod) + credentialsUpdated := false + done := newCredentialExpiry.Add(10 * time.Minute) + stsError := errors.New("test error") + for clock.Now().Before(done) { + if clock.Now().Sub(expectedRefreshTime).Abs() < 5*time.Minute || + clock.Now().Sub(newCredentialExpiry).Abs() < 5*time.Minute { + // Within one of the 10-minute outage windows, make the STS client return errors. + stsClient.setError(stsError) + advanceClock(time.Minute) + } else { + // Not within an outage window, STS client should not return errors. + stsClient.setError(nil) + advanceClock(time.Minute) + + if !credentialsUpdated && clock.Now().After(expectedRefreshTime) { + // This is after the expected refresh time and not within an outage window, for the test to + // not be flaky we need to wait for the cache run loop to get a chance to refresh the + // credentials. + expectedExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + creds, err := cacheUnderTest.Retrieve(ctx) + assert.NoError(t, err) + assert.WithinDuration(t, expectedExpiry, creds.Expires, 2*time.Minute) + }, waitFor, tick) + credentialsUpdated = true + } + } + + // Assert that there is never an error getting credentials. + checkRetrieveCredentials(t, nil) + } + }) +} diff --git a/lib/integrations/externalauditstorage/configurator.go b/lib/integrations/externalauditstorage/configurator.go index 050298f9e13e0..66cea204a57cc 100644 --- a/lib/integrations/externalauditstorage/configurator.go +++ b/lib/integrations/externalauditstorage/configurator.go @@ -20,23 +20,21 @@ package externalauditstorage import ( "context" - "errors" - "sync" "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/externalauditstorage" "github.com/gravitational/teleport/entitlements" + "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" ) @@ -86,7 +84,7 @@ type Configurator struct { spec *externalauditstorage.ExternalAuditStorageSpec isUsed bool - credentialsCache *credentialsCache + credentialsCache *awsoidc.CredentialsCache } // Options holds options for the Configurator. @@ -202,7 +200,10 @@ func newConfigurator(ctx context.Context, spec *externalauditstorage.ExternalAud "ExternalAuditStorage: configured integration %q does not appear to be an AWS OIDC integration", oidcIntegrationName) } - awsRoleARN := awsOIDCSpec.RoleARN + awsRoleARN, err := arn.Parse(awsOIDCSpec.RoleARN) + if err != nil { + return nil, trace.Wrap(err, "AWS role is not a valid ARN") + } options := &Options{} for _, optFn := range optFns { @@ -212,11 +213,16 @@ func newConfigurator(ctx context.Context, spec *externalauditstorage.ExternalAud return nil, trace.Wrap(err) } - credentialsCache, err := newCredentialsCache(oidcIntegrationName, awsRoleARN, options) + credentialsCache, err := awsoidc.NewCredentialsCache(awsoidc.CredentialsCacheOptions{ + Integration: oidcIntegrationName, + RoleARN: awsRoleARN, + STSClient: options.stsClient, + Clock: options.clock, + }) if err != nil { return nil, trace.Wrap(err) } - go credentialsCache.run(ctx) + go credentialsCache.Run(ctx) // Draft configurator does not need to count errors or create cluster // alerts. @@ -245,13 +251,9 @@ func (c *Configurator) GetSpec() *externalauditstorage.ExternalAuditStorageSpec return c.spec } -// GenerateOIDCTokenFn is a function that should return a valid, signed JWT for -// authenticating to AWS via OIDC. -type GenerateOIDCTokenFn func(ctx context.Context, integration string) (string, error) - // SetGenerateOIDCTokenFn sets the source of OIDC tokens for this Configurator. -func (c *Configurator) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { - c.credentialsCache.setGenerateOIDCTokenFn(fn) +func (c *Configurator) SetGenerateOIDCTokenFn(fn awsoidc.GenerateOIDCTokenFn) { + c.credentialsCache.SetGenerateOIDCTokenFn(fn) } // CredentialsProvider returns an aws.CredentialsProvider that can be used to @@ -274,185 +276,13 @@ func (p *Configurator) CredentialsProviderSDKV1() credentials.ProviderWithContex // credential providers won't return errors simply due to the cache not being // ready yet. func (p *Configurator) WaitForFirstCredentials(ctx context.Context) { - p.credentialsCache.waitForFirstCredsOrErr(ctx) -} - -// credentialsCache is used to store and refresh AWS credentials used with -// AWS OIDC integration. -// -// Credentials are valid for 1h, but they cannot be refreshed if Proxy is down, -// so we attempt to refresh the credentials early and retry on failure. -// -// credentialsCache is a dependency to both the s3 session uploader and the -// athena audit logger. They are both initialized before auth. However AWS -// credentials using OIDC integration can be obtained only after auth is -// initialized. That's why generateOIDCTokenFn is injected dynamically after -// auth is initialized. Before initialization, credentialsCache will return -// an error on any Retrieve call. -type credentialsCache struct { - log *logrus.Entry - - roleARN string - integration string - - // generateOIDCTokenFn is dynamically set after auth is initialized. - generateOIDCTokenFn GenerateOIDCTokenFn - - // initialized communicates (via closing channel) that generateOIDCTokenFn is set. - initialized chan struct{} - closeInitialized func() - - // gotFirstCredsOrErr communicates (via closing channel) that the first - // credsOrErr has been set. - gotFirstCredsOrErr chan struct{} - closeGotFirstCredsOrErr func() - - credsOrErr credsOrErr - credsOrErrMu sync.RWMutex - - stsClient stscreds.AssumeRoleWithWebIdentityAPIClient - clock clockwork.Clock -} - -type credsOrErr struct { - creds aws.Credentials - err error -} - -func newCredentialsCache(integration, roleARN string, options *Options) (*credentialsCache, error) { - initialized := make(chan struct{}) - gotFirstCredsOrErr := make(chan struct{}) - return &credentialsCache{ - roleARN: roleARN, - integration: integration, - log: logrus.WithField(teleport.ComponentKey, "ExternalAuditStorage.CredentialsCache"), - initialized: initialized, - closeInitialized: sync.OnceFunc(func() { close(initialized) }), - gotFirstCredsOrErr: gotFirstCredsOrErr, - closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), - credsOrErr: credsOrErr{ - err: errors.New("ExternalAuditStorage: credential cache not yet initialized"), - }, - clock: options.clock, - stsClient: options.stsClient, - }, nil -} - -func (cc *credentialsCache) setGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { - cc.generateOIDCTokenFn = fn - cc.closeInitialized() -} - -// Retrieve implements [aws.CredentialsProvider] and returns the latest cached -// credentials, or an error if no credentials have been generated yet or the -// last generated credentials have expired. -func (cc *credentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) { - cc.credsOrErrMu.RLock() - defer cc.credsOrErrMu.RUnlock() - return cc.credsOrErr.creds, cc.credsOrErr.err -} - -func (cc *credentialsCache) run(ctx context.Context) { - // Wait for initialized signal before running loop. - select { - case <-cc.initialized: - case <-ctx.Done(): - cc.log.Debug("Context canceled before initialized.") - return - } - - cc.refreshIfNeeded(ctx) - - ticker := cc.clock.NewTicker(refreshCheckInterval) - defer ticker.Stop() - for { - select { - case <-ticker.Chan(): - cc.refreshIfNeeded(ctx) - case <-ctx.Done(): - cc.log.Debugf("Context canceled, stopping refresh loop.") - return - } - } -} - -func (cc *credentialsCache) refreshIfNeeded(ctx context.Context) { - credsFromCache, err := cc.Retrieve(ctx) - if err == nil && - credsFromCache.HasKeys() && - cc.clock.Now().Add(refreshBeforeExpirationPeriod).Before(credsFromCache.Expires) { - // No need to refresh, credentials in cache are still valid for longer - // than refreshBeforeExpirationPeriod - return - } - cc.log.Debugf("Refreshing credentials.") - - creds, err := cc.refresh(ctx) - if err != nil { - cc.log.Warnf("Failed to retrieve new credentials: %v", err) - // If we were not able to refresh, check if existing credentials in cache are still valid. - // If yes, just log debug, it will be retried on next interval check. - if credsFromCache.HasKeys() && cc.clock.Now().Before(credsFromCache.Expires) { - cc.log.Debugf("Using existing credentials expiring in %s.", credsFromCache.Expires.Sub(cc.clock.Now()).Round(time.Second).String()) - return - } - // If existing creds are expired, update cached error. - cc.setCredsOrErr(credsOrErr{err: trace.Wrap(err)}) - return - } - // Refresh went well, update cached creds. - cc.setCredsOrErr(credsOrErr{creds: creds}) - cc.log.Debugf("Successfully refreshed credentials, new expiry at %v", creds.Expires) -} - -func (cc *credentialsCache) setCredsOrErr(coe credsOrErr) { - cc.credsOrErrMu.Lock() - defer cc.credsOrErrMu.Unlock() - cc.credsOrErr = coe - cc.closeGotFirstCredsOrErr() -} - -func (cc *credentialsCache) refresh(ctx context.Context) (aws.Credentials, error) { - oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration) - if err != nil { - return aws.Credentials{}, trace.Wrap(err) - } - - roleProvider := stscreds.NewWebIdentityRoleProvider( - cc.stsClient, - cc.roleARN, - identityToken(oidcToken), - func(wiro *stscreds.WebIdentityRoleOptions) { - wiro.Duration = TokenLifetime - }, - ) - - ctx, cancel := context.WithTimeout(ctx, retrieveTimeout) - defer cancel() - - creds, err := roleProvider.Retrieve(ctx) - return creds, trace.Wrap(err) -} - -func (cc *credentialsCache) waitForFirstCredsOrErr(ctx context.Context) { - select { - case <-ctx.Done(): - case <-cc.gotFirstCredsOrErr: - } -} - -// identityToken is an implementation of [stscreds.IdentityTokenRetriever] for returning a static token. -type identityToken string - -// GetIdentityToken returns the token configured. -func (j identityToken) GetIdentityToken() ([]byte, error) { - return []byte(j), nil + p.credentialsCache.WaitForFirstCredsOrErr(ctx) } // v1Adapter wraps the credentialsCache to implement // [credentials.ProviderWithContext] used by aws-sdk-go (v1). type v1Adapter struct { - cc *credentialsCache + cc *awsoidc.CredentialsCache } var _ credentials.ProviderWithContext = (*v1Adapter)(nil) diff --git a/lib/integrations/externalauditstorage/configurator_test.go b/lib/integrations/externalauditstorage/configurator_test.go index abb1ce1425b9e..ba86e5f8e0c27 100644 --- a/lib/integrations/externalauditstorage/configurator_test.go +++ b/lib/integrations/externalauditstorage/configurator_test.go @@ -47,7 +47,7 @@ func testOIDCIntegration(t *testing.T) *types.IntegrationV1 { oidcIntegration, err := types.NewIntegrationAWSOIDC( types.Metadata{Name: "aws-integration-1"}, &types.AWSOIDCIntegrationSpecV1{ - RoleARN: "role1", + RoleARN: "arn:aws:iam::account:role/role1", }, ) require.NoError(t, err) diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 8b029bdab2ca8..ec596200d1edc 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -503,7 +503,7 @@ func TestAthenaAuditLogSetup(t *testing.T) { oidcIntegration, err := types.NewIntegrationAWSOIDC( types.Metadata{Name: "aws-integration-1"}, &types.AWSOIDCIntegrationSpecV1{ - RoleARN: "role1", + RoleARN: "arn:aws:iam::account:role/role1", }, ) require.NoError(t, err)