Skip to content

Commit

Permalink
migrate tool/tsh/common to AWS SDK v2
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar authored and github-actions committed Dec 9, 2024
1 parent 51ed796 commit af0da5d
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions tool/tsh/common/app_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ import (
"strings"
"sync"

awsarn "github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/google/uuid"
"github.com/gravitational/trace"

Expand Down Expand Up @@ -136,7 +137,7 @@ type awsApp struct {

cf *CLIConf

credentials *credentials.Credentials
credentials aws.CredentialsProvider
credentialsOnce sync.Once
}

Expand Down Expand Up @@ -168,13 +169,8 @@ func (a *awsApp) GetAppName() string {
// The first method is always preferred as the original hostname is preserved
// through forward proxy.
func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalProxyConfigOpt) error {
cred, err := a.GetAWSCredentials()
if err != nil {
return trace.Wrap(err)
}

awsMiddleware := &alpnproxy.AWSAccessMiddleware{
AWSCredentials: cred,
AWSCredentialsV2Provider: a.GetAWSCredentialsProvider(),
}

// AWS endpoint URL mode
Expand All @@ -184,14 +180,14 @@ func (a *awsApp) StartLocalProxies(ctx context.Context, opts ...alpnproxy.LocalP
}

// HTTPS proxy mode
err = a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware))
err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAWSRequests, alpnproxy.WithHTTPMiddleware(awsMiddleware))
return trace.Wrap(err)
}

// GetAWSCredentials generates fake AWS credentials that are used for
// signing an AWS request during AWS API calls and verified on local AWS proxy
// side.
func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) {
// GetAWSCredentialsProvider returns an [aws.CredentialsProvider] that generates
// fake AWS credentials that are used for signing an AWS request during AWS API
// calls and verified on local AWS proxy side.
func (a *awsApp) GetAWSCredentialsProvider() aws.CredentialsProvider {
// There is no specific format or value required for access key and secret,
// as long as the AWS clients and the local proxy are using the same
// credentials. The only constraint is the access key must have a length
Expand All @@ -200,17 +196,13 @@ func (a *awsApp) GetAWSCredentials() (*credentials.Credentials, error) {
//
// https://docs.aws.amazon.com/STS/latest/APIReference/API_Credentials.html
a.credentialsOnce.Do(func() {
a.credentials = credentials.NewStaticCredentials(
a.credentials = credentials.NewStaticCredentialsProvider(
getEnvOrDefault(awsAccessKeyIDEnvVar, uuid.NewString()),
getEnvOrDefault(awsSecretAccessKeyEnvVar, uuid.NewString()),
"",
)
})

if a.credentials == nil {
return nil, trace.BadParameter("missing credentials")
}
return a.credentials, nil
return a.credentials
}

// GetEnvVars returns required environment variables to configure the
Expand All @@ -220,12 +212,7 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) {
return nil, trace.NotFound("ALPN proxy is not running")
}

cred, err := a.GetAWSCredentials()
if err != nil {
return nil, trace.Wrap(err)
}

credValues, err := cred.Get()
cred, err := a.GetAWSCredentialsProvider().Retrieve(context.Background())
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -234,8 +221,8 @@ func (a *awsApp) GetEnvVars() (map[string]string, error) {
// AWS CLI and SDKs can load credentials through environment variables.
//
// https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html
"AWS_ACCESS_KEY_ID": credValues.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": credValues.SecretAccessKey,
"AWS_ACCESS_KEY_ID": cred.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": cred.SecretAccessKey,
"AWS_CA_BUNDLE": a.appInfo.appLocalCAPath(a.cf.SiteName),
}

Expand Down Expand Up @@ -318,7 +305,7 @@ func getARNFromFlags(cf *CLIConf, app types.Application, logins []string) (strin
}

// Match by role ARN.
if awsarn.IsARN(cf.AWSRole) {
if arn.IsARN(cf.AWSRole) {
if role, found := roles.FindRoleByARN(cf.AWSRole); found {
return role.ARN, nil
}
Expand Down

0 comments on commit af0da5d

Please sign in to comment.