diff --git a/lib/srv/alpnproxy/aws_local_proxy.go b/lib/srv/alpnproxy/aws_local_proxy.go index 794fc6b9b78c9..56c1bbabe4dd5 100644 --- a/lib/srv/alpnproxy/aws_local_proxy.go +++ b/lib/srv/alpnproxy/aws_local_proxy.go @@ -22,6 +22,7 @@ import ( "net/http" "strings" + awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/service/sts" @@ -33,6 +34,7 @@ import ( appcommon "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" + "github.com/gravitational/teleport/lib/utils/aws/migration" ) // AWSAccessMiddleware verifies the requests to AWS proxy are properly signed. @@ -42,6 +44,11 @@ type AWSAccessMiddleware struct { // AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification. AWSCredentials *credentials.Credentials + // AWSCredentialsV2Provider is an aws sdk v2 credential provider used by + // LocalProxy for request's signature verification if AWSCredentials is not + // specified. + AWSCredentialsV2Provider awsv2.CredentialsProvider + Log logrus.FieldLogger assumedRoles utils.SyncMap[string, *sts.AssumeRoleOutput] @@ -55,7 +62,10 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error { } if m.AWSCredentials == nil { - return trace.BadParameter("missing AWSCredentials") + if m.AWSCredentialsV2Provider == nil { + return trace.BadParameter("missing AWSCredentials") + } + m.AWSCredentials = credentials.NewCredentials(migration.NewProviderAdapter(m.AWSCredentialsV2Provider)) } return nil diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go index 59a044ec8a45c..30f39290d697b 100644 --- a/lib/srv/alpnproxy/aws_local_proxy_test.go +++ b/lib/srv/alpnproxy/aws_local_proxy_test.go @@ -25,6 +25,7 @@ import ( "testing" "time" + credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" @@ -42,60 +43,80 @@ func TestAWSAccessMiddleware(t *testing.T) { localProxyCred := credentials.NewStaticCredentials("local-proxy", "local-proxy-secret", "") assumedRoleCred := credentials.NewStaticCredentials("assumed-role", "assumed-role-secret", "assumed-role-token") - stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) - v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) - - requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) - v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + tests := []struct { + name string + middleware *AWSAccessMiddleware + }{ + { + name: "v1", + middleware: &AWSAccessMiddleware{ + AWSCredentials: localProxyCred, + }, + }, + { + name: "v2", + middleware: &AWSAccessMiddleware{ + AWSCredentialsV2Provider: credentialsv2.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), + }, + }, + } - m := &AWSAccessMiddleware{ - AWSCredentials: localProxyCred, + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + m := test.middleware + require.NoError(t, m.CheckAndSetDefaults()) + + stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) + v4.NewSigner(localProxyCred).Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) + + requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) + v4.NewSigner(assumedRoleCred).Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + + t.Run("request no authorization", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by unknown credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusForbidden, recorder.Code) + }) + + t.Run("request signed by local proxy credentials", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies sts:AssumeRole output can be handled successfully. The + // credentials should be saved afterwards. + t.Run("handle sts:AssumeRole response", func(t *testing.T) { + response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) + + // This is the same request as the "unknown credentials" test above. But at + // this point, the assumed role credentials should have been saved by the + // middleware so the request can be handled successfully now. + t.Run("request signed by assumed role", func(t *testing.T) { + recorder := httptest.NewRecorder() + require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) + require.Equal(t, http.StatusOK, recorder.Code) + }) + + // Verifies non sts:AssumeRole responses do not give errors. + t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { + response := getCallerIdentityResponse(t, assumedRoleARN) + response.Request = stsRequestByLocalProxyCred + defer response.Body.Close() + require.NoError(t, m.HandleResponse(response)) + }) + }) } - require.NoError(t, m.CheckAndSetDefaults()) - - t.Run("request no authorization", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, httptest.NewRequest("", "http://localhost", nil))) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by unknown credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.True(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusForbidden, recorder.Code) - }) - - t.Run("request signed by local proxy credentials", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, stsRequestByLocalProxyCred)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies sts:AssumeRole output can be handled successfully. The - // credentials should be saved afterwards. - t.Run("handle sts:AssumeRole response", func(t *testing.T) { - response := assumeRoleResponse(t, assumedRoleARN, assumedRoleCred) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) - - // This is the same request as the "unknown credentials" test above. But at - // this point, the assumed role credentials should have been saved by the - // middleware so the request can be handled successfully now. - t.Run("request signed by assumed role", func(t *testing.T) { - recorder := httptest.NewRecorder() - require.False(t, m.HandleRequest(recorder, requestByAssumedRole)) - require.Equal(t, http.StatusOK, recorder.Code) - }) - - // Verifies non sts:AssumeRole responses do not give errors. - t.Run("handle sts:GetCallerIdentity response", func(t *testing.T) { - response := getCallerIdentityResponse(t, assumedRoleARN) - response.Request = stsRequestByLocalProxyCred - defer response.Body.Close() - require.NoError(t, m.HandleResponse(response)) - }) } func assumeRoleResponse(t *testing.T, roleARN string, cred *credentials.Credentials) *http.Response { diff --git a/lib/utils/aws/migration/migration.go b/lib/utils/aws/migration/migration.go new file mode 100644 index 0000000000000..5288f2ab2a13c --- /dev/null +++ b/lib/utils/aws/migration/migration.go @@ -0,0 +1,92 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package migration + +import ( + "context" + "sync" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/gravitational/trace" +) + +// NewProviderAdapter returns a [ProviderAdapter] that can be used as an AWS SDK +// v1 credentials provider. +func NewProviderAdapter(providerV2 awsv2.CredentialsProvider) *ProviderAdapter { + return &ProviderAdapter{ + providerV2: providerV2, + } +} + +var _ credentials.ProviderWithContext = (*ProviderAdapter)(nil) + +// ProviderAdapter adapts an [awsv2.CredentialsProvider] to an AWS SDK v1 +// credentials provider. +type ProviderAdapter struct { + providerV2 awsv2.CredentialsProvider + + m sync.RWMutex + // creds are retrieved and saved to satisfy IsExpired. + creds awsv2.Credentials +} + +func (a *ProviderAdapter) IsExpired() bool { + a.m.RLock() + defer a.m.RUnlock() + + var emptyCreds awsv2.Credentials + return a.creds == emptyCreds || a.creds.Expired() +} + +func (a *ProviderAdapter) Retrieve() (credentials.Value, error) { + return a.RetrieveWithContext(context.Background()) +} + +func (a *ProviderAdapter) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + creds, err := a.retrieveLocked(ctx) + if err != nil { + return credentials.Value{}, trace.Wrap(err) + } + + return credentials.Value{ + AccessKeyID: creds.AccessKeyID, + SecretAccessKey: creds.SecretAccessKey, + SessionToken: creds.SessionToken, + ProviderName: creds.Source, + }, nil +} + +func (a *ProviderAdapter) retrieveLocked(ctx context.Context) (awsv2.Credentials, error) { + a.m.Lock() + defer a.m.Unlock() + + var emptyCreds awsv2.Credentials + if a.creds != emptyCreds && !a.creds.Expired() { + return a.creds, nil + } + + creds, err := a.providerV2.Retrieve(ctx) + if err != nil { + return emptyCreds, trace.Wrap(err) + } + + a.creds = creds + return creds, nil +} diff --git a/tool/tsh/common/app_aws.go b/tool/tsh/common/app_aws.go index 0869d5c09cc01..7c11458b2f300 100644 --- a/tool/tsh/common/app_aws.go +++ b/tool/tsh/common/app_aws.go @@ -27,8 +27,9 @@ import ( "strings" "sync" - awsarn "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/google/uuid" "github.com/gravitational/trace" @@ -136,7 +137,7 @@ type awsApp struct { cf *CLIConf - credentials *credentials.Credentials + credentials aws.CredentialsProvider credentialsOnce sync.Once } @@ -168,13 +169,8 @@ func (a *awsApp) GetAppName() string { // The first method is always preferred as the original hostname is preserved // through forward proxy. func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error { - cred, err := a.GetAWSCredentials() - if err != nil { - return trace.Wrap(err) - } - awsMiddleware := &alpnproxy.AWSAccessMiddleware{ - AWSCredentials: cred, + AWSCredentialsV2Provider: a.GetAWSCredentialsProvider(), } // AWS endpoint URL mode @@ -184,14 +180,14 @@ func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalP } // HTTPS proxy mode - err = a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware)) + err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware)) return trace.Wrap(err) } -// GetAWSCredentials generates fake AWS credentials that are used for -// signing an AWS request during AWS API calls and verified on local AWS proxy -// side. -func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) { +// GetAWSCredentialsProvider returns an [aws.CredentialsProvider] that generates +// fake AWS credentials that are used for signing an AWS request during AWS API +// calls and verified on local AWS proxy side. +func (a *awsApp) GetAWSCredentialsProvider() aws.CredentialsProvider { // There is no specific format or value required for access key and secret, // as long as the AWS clients and the local proxy are using the same // credentials. The only constraint is the access key must have a length @@ -200,17 +196,13 @@ func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) { // // https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html a.credentialsOnce.Do(func() { - a.credentials = credentials.NewStaticCredentials( + a.credentials = credentials.NewStaticCredentialsProvider( getEnvOrDefault(awsAccessKeyIDEnvVar, uuid.NewString()), getEnvOrDefault(awsSecretAccessKeyEnvVar, uuid.NewString()), "", ) }) - - if a.credentials == nil { - return nil, trace.BadParameter("missing credentials") - } - return a.credentials, nil + return a.credentials } // GetEnvVars returns required environment variables to configure the @@ -220,12 +212,7 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) { return nil, trace.NotFound("ALPN proxy is not running") } - cred, err := a.GetAWSCredentials() - if err != nil { - return nil, trace.Wrap(err) - } - - credValues, err := cred.Get() + cred, err := a.GetAWSCredentialsProvider().Retrieve(context.Background()) if err != nil { return nil, trace.Wrap(err) } @@ -234,8 +221,8 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) { // AWS CLI and SDKs can load credentials through environment variables. // // https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html - "AWS_ACCESS_KEY_ID": credValues.AccessKeyID, - "AWS_SECRET_ACCESS_KEY": credValues.SecretAccessKey, + "AWS_ACCESS_KEY_ID": cred.AccessKeyID, + "AWS_SECRET_ACCESS_KEY": cred.SecretAccessKey, "AWS_CA_BUNDLE": a.appInfo.appLocalCAPath(a.cf.SiteName), } @@ -318,7 +305,7 @@ func getARNFromFlags(cf *CLIConf, app types.Application, logins []string) (strin } // Match by role ARN. - if awsarn.IsARN(cf.AWSRole) { + if arn.IsARN(cf.AWSRole) { if role, found := roles.FindRoleByARN(cf.AWSRole); found { return role.ARN, nil }