diff --git a/cmd/ocm-backplane/cloud/assume.go b/cmd/ocm-backplane/cloud/assume.go index 91371317..d727c4f8 100644 --- a/cmd/ocm-backplane/cloud/assume.go +++ b/cmd/ocm-backplane/cloud/assume.go @@ -3,6 +3,7 @@ package cloud import ( "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -16,15 +17,10 @@ import ( "github.com/spf13/cobra" ) -const ( - DefaultInitialRoleArn = "arn:aws:iam::922711891673:role/SRE-Support-Role" -) - var assumeArgs struct { - initialRoleArn string - output string - debugFile string - console bool + output string + debugFile string + console bool } var StsClientWithProxy = awsutil.StsClientWithProxy @@ -37,19 +33,19 @@ var AssumeCmd = &cobra.Command{ Short: "Performs the assume role chaining necessary to generate temporary access to the customer's AWS account", Long: `Performs the assume role chaining necessary to generate temporary access to the customer's AWS account -This command is the equivalent of running "aws sts assume-role-with-web-identity --initial-role-arn [role-arn] --web-identity-token [ocm token] --role-session-name [email from OCM token]" behind the scenes, -where the ocm token used is the result of running "ocm token". Then, the command makes a call to the backplane API to get the necessary jump roles for the cluster's account. It then calls the -equivalent of "aws sts assume-role --initial-role-arn [role-arn] --role-session-name [email from OCM token]" repeatedly for each role arn in the chain, using the previous role's credentials to assume the next -role in the chain. +This command is the equivalent of running "aws sts assume-role-with-web-identity --role-arn [role-arn] --web-identity-token [ocm token] --role-session-name [email from OCM token]" +behind the scenes, where the ocm token used is the result of running "ocm token" and the role-arn is the value of "assume-initial-arn" from the backplane configuration. -This command will output sts credentials for the target role in the given cluster in formatted JSON. If no "role-arn" is provided, a default role will be used. +Then, the command makes a call to the backplane API to get the necessary jump roles for the cluster's account. It then calls the +equivalent of "aws sts assume-role --role-arn [role-arn] --role-session-name [email from OCM token]" repeatedly for each +role arn in the chain, using the previous role's credentials to assume the next role in the chain. + +By default this command will output sts credentials for the support in the given cluster account formatted as terminal envars. +If the "--console" flag is provided, it will output a link to the web console for the target cluster's account. `, - Example: `With default role: + Example: `With -o flag specified: backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c -oenv -With given role: -backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --initial-role-arn arn:aws:iam::1234567890:role/read-only -oenv - With a debug file: backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --debug-file test_arns @@ -61,8 +57,7 @@ backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --console`, func init() { flags := AssumeCmd.Flags() - flags.StringVar(&assumeArgs.initialRoleArn, "initial-role-arn", DefaultInitialRoleArn, "The arn of the role for which to start the role assume process.") - flags.StringVarP(&assumeArgs.output, "output", "o", "env", "Format the output of the console response.") + flags.StringVarP(&assumeArgs.output, "output", "o", "env", "Format the output of the console response. Valid values are `env`, `json`, and `yaml`.") flags.StringVar(&assumeArgs.debugFile, "debug-file", "", "A file containing the list of ARNs to assume in order, not including the initial role ARN. Providing this flag will bypass calls to the backplane API to retrieve the assume role chain. The file should be a plain text file with each ARN on a new line.") flags.BoolVar(&assumeArgs.console, "console", false, "Outputs a console url to access the targeted cluster instead of the STS credentials.") } @@ -86,30 +81,30 @@ func runAssume(_ *cobra.Command, args []string) error { return fmt.Errorf("failed to retrieve OCM token: %w", err) } + email, err := utils.GetStringFieldFromJWT(*ocmToken, "email") + if err != nil { + return fmt.Errorf("unable to extract email from given token: %w", err) + } + bpConfig, err := GetBackplaneConfiguration() if err != nil { return fmt.Errorf("error retrieving backplane configuration: %w", err) } + if bpConfig.AssumeInitialArn == "" { + return errors.New("backplane config is missing required `assume-initial-arn` property") + } + initialClient, err := StsClientWithProxy(bpConfig.ProxyURL) if err != nil { return fmt.Errorf("failed to create sts client: %w", err) } - seedCredentials, err := AssumeRoleWithJWT(*ocmToken, assumeArgs.initialRoleArn, initialClient) - if err != nil { - return fmt.Errorf("failed to assume role using JWT: %w", err) - } - email, err := utils.GetStringFieldFromJWT(*ocmToken, "email") + seedCredentials, err := AssumeRoleWithJWT(*ocmToken, bpConfig.AssumeInitialArn, initialClient) if err != nil { - return fmt.Errorf("unable to extract email from given token: %w", err) + return fmt.Errorf("failed to assume role using JWT: %w", err) } - seedClient := sts.NewFromConfig(aws.Config{ - Region: "us-east-1", - Credentials: NewStaticCredentialsProvider(*seedCredentials.AccessKeyId, *seedCredentials.SecretAccessKey, *seedCredentials.SessionToken), - }) - var roleAssumeSequence []string if assumeArgs.debugFile == "" { clusterID, _, err := utils.DefaultOCMInterface.GetTargetCluster(args[0]) @@ -154,6 +149,11 @@ func runAssume(_ *cobra.Command, args []string) error { roleAssumeSequence = append(roleAssumeSequence, strings.Split(string(arnBytes), "\n")...) } + seedClient := sts.NewFromConfig(aws.Config{ + Region: "us-east-1", + Credentials: NewStaticCredentialsProvider(seedCredentials.AccessKeyID, seedCredentials.SecretAccessKey, seedCredentials.SessionToken), + }) + targetCredentials, err := AssumeRoleSequence(email, seedClient, roleAssumeSequence, bpConfig.ProxyURL, awsutil.DefaultSTSClientProviderFunc) if err != nil { return fmt.Errorf("failed to assume role sequence: %w", err) @@ -173,10 +173,10 @@ func runAssume(_ *cobra.Command, args []string) error { fmt.Printf("The AWS Console URL is:\n%s\n", signInFederationURL.String()) } else { credsResponse := awsutil.AWSCredentialsResponse{ - AccessKeyID: *targetCredentials.AccessKeyId, - SecretAccessKey: *targetCredentials.SecretAccessKey, - SessionToken: *targetCredentials.SessionToken, - Expiration: targetCredentials.Expiration.String(), + AccessKeyID: targetCredentials.AccessKeyID, + SecretAccessKey: targetCredentials.SecretAccessKey, + SessionToken: targetCredentials.SessionToken, + Expiration: targetCredentials.Expires.String(), } formattedResult, err := credsResponse.RenderOutput(assumeArgs.output) if err != nil { diff --git a/cmd/ocm-backplane/cloud/assume_test.go b/cmd/ocm-backplane/cloud/assume_test.go index 71af4b9b..4dc6ef25 100644 --- a/cmd/ocm-backplane/cloud/assume_test.go +++ b/cmd/ocm-backplane/cloud/assume_test.go @@ -4,9 +4,10 @@ import ( "context" "errors" "fmt" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -64,18 +65,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider { @@ -87,12 +89,12 @@ var _ = Describe("Cloud assume command", func() { StatusCode: 200, Body: io.NopCloser(strings.NewReader(`{"assumption_sequence":[{"name": "name_one", "arn": "arn_one"},{"name": "name_two", "arn": "arn_two"}]}`)), }, nil).Times(1) - AssumeRoleSequence = func(roleSessionName string, seedClient awsutil.STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, - Expiration: &testExpiration, + AssumeRoleSequence = func(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, + Expires: testExpiration, }, nil } @@ -118,12 +120,13 @@ var _ = Describe("Cloud assume command", func() { err := runAssume(nil, []string{testClusterID}) Expect(err.Error()).To(Equal("error retrieving backplane configuration: oops")) }) - It("should fail if cannot create create sts client with proxy", func() { + It("should fail if cannot create sts client with proxy", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { @@ -137,15 +140,16 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return nil, errors.New("failure") + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{}, errors.New("failure") } err := runAssume(nil, []string{testClusterID}) @@ -156,18 +160,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } @@ -178,18 +183,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } mockOcmInterface.EXPECT().GetTargetCluster(testClusterID).Return("", "", errors.New("oh no")).Times(1) @@ -201,18 +207,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider { @@ -228,18 +235,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider { @@ -256,18 +264,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider { @@ -288,18 +297,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider { @@ -319,18 +329,19 @@ var _ = Describe("Cloud assume command", func() { mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1) GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) { return config.BackplaneConfiguration{ - URL: "testUrl.com", - ProxyURL: "testProxyUrl.com", + URL: "testUrl.com", + ProxyURL: "testProxyUrl.com", + AssumeInitialArn: "arn:aws:iam::123456789:role/ManagedOpenShift-Support-Role", }, nil } StsClientWithProxy = func(proxyURL string) (*sts.Client, error) { return &sts.Client{}, nil } - AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { - return &types.Credentials{ - AccessKeyId: &testAccessKeyID, - SecretAccessKey: &testSecretAccessKey, - SessionToken: &testSessionToken, + AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: testAccessKeyID, + SecretAccessKey: testSecretAccessKey, + SessionToken: testSessionToken, }, nil } NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider { @@ -342,8 +353,8 @@ var _ = Describe("Cloud assume command", func() { StatusCode: 200, Body: io.NopCloser(strings.NewReader(`{"assumption_sequence":[{"name": "name_one", "arn": "arn_one"},{"name": "name_two", "arn": "arn_two"}]}`)), }, nil).Times(1) - AssumeRoleSequence = func(roleSessionName string, seedClient awsutil.STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (*types.Credentials, error) { - return nil, errors.New("oops") + AssumeRoleSequence = func(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (aws.Credentials, error) { + return aws.Credentials{}, errors.New("oops") } err := runAssume(nil, []string{testClusterID}) diff --git a/cmd/ocm-backplane/cloud/token.go b/cmd/ocm-backplane/cloud/token.go index 2836df41..c23b3245 100644 --- a/cmd/ocm-backplane/cloud/token.go +++ b/cmd/ocm-backplane/cloud/token.go @@ -56,10 +56,10 @@ func runToken(*cobra.Command, []string) error { } credsResponse := awsutil.AWSCredentialsResponse{ - AccessKeyID: *result.AccessKeyId, - SecretAccessKey: *result.SecretAccessKey, - SessionToken: *result.SessionToken, - Expiration: result.Expiration.String(), + AccessKeyID: result.AccessKeyID, + SecretAccessKey: result.SecretAccessKey, + SessionToken: result.SessionToken, + Expiration: result.Expires.String(), } formattedResult, err := credsResponse.RenderOutput(tokenArgs.output) diff --git a/pkg/awsutil/sts.go b/pkg/awsutil/sts.go index 92724b73..7953b70d 100644 --- a/pkg/awsutil/sts.go +++ b/pkg/awsutil/sts.go @@ -5,25 +5,25 @@ import ( "encoding/json" "errors" "fmt" - "io" - "net/http" - "net/url" - "time" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go-v2/service/sts/types" - "github.com/openshift/backplane-cli/pkg/utils" + "io" + "net/http" + "net/url" + "time" ) const ( AwsFederatedSigninEndpoint = "https://signin.aws.amazon.com/federation" AwsConsoleURL = "https://console.aws.amazon.com/" DefaultIssuer = "Red Hat SRE" + + assumeRoleMaxRetries = 5 + assumeRoleRetryBackoff = 5 * time.Second ) type AWSFederatedSessionData struct { @@ -56,49 +56,52 @@ func StsClientWithProxy(proxyURL string) (*sts.Client, error) { return sts.NewFromConfig(cfg), nil } -type STSRoleAssumer interface { - AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) -} +// IdentityTokenValue is for retrieving an identity token from the given file name +type IdentityTokenValue string -type STSRoleWithWebIdentityAssumer interface { - AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) +// GetIdentityToken retrieves the JWT token from the file and returns the contents as a []byte +func (j IdentityTokenValue) GetIdentityToken() ([]byte, error) { + return []byte(j), nil } -func AssumeRoleWithJWT(jwt string, roleArn string, stsClient STSRoleWithWebIdentityAssumer) (*types.Credentials, error) { +func AssumeRoleWithJWT(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) { email, err := utils.GetStringFieldFromJWT(jwt, "email") if err != nil { - return nil, fmt.Errorf("unable to extract email from given token: %w", err) - } - input := &sts.AssumeRoleWithWebIdentityInput{ - RoleArn: aws.String(roleArn), - RoleSessionName: aws.String(email), - WebIdentityToken: aws.String(jwt), + return aws.Credentials{}, fmt.Errorf("unable to extract email from given token: %w", err) } - result, err := stsClient.AssumeRoleWithWebIdentity(context.TODO(), input) + credentialsCache := aws.NewCredentialsCache(stscreds.NewWebIdentityRoleProvider( + stsClient, + roleArn, + IdentityTokenValue(jwt), + func(options *stscreds.WebIdentityRoleOptions) { + options.RoleSessionName = email + }, + )) + + result, err := credentialsCache.Retrieve(context.TODO()) if err != nil { - return nil, fmt.Errorf("unable to assume the given role with the token provided: %w", err) + return aws.Credentials{}, fmt.Errorf("unable to assume the given role with the token provided: %w", err) } - return result.Credentials, nil + return result, nil } -func AssumeRole(roleSessionName string, stsClient STSRoleAssumer, roleArn string) (*types.Credentials, error) { - input := &sts.AssumeRoleInput{ - RoleArn: aws.String(roleArn), - RoleSessionName: aws.String(roleSessionName), - } - result, err := stsClient.AssumeRole(context.TODO(), input) +func AssumeRole(roleSessionName string, stsClient stscreds.AssumeRoleAPIClient, roleArn string) (aws.Credentials, error) { + assumeRoleProvider := stscreds.NewAssumeRoleProvider(stsClient, roleArn, func(options *stscreds.AssumeRoleOptions) { + options.RoleSessionName = roleSessionName + }) + result, err := assumeRoleProvider.Retrieve(context.TODO()) if err != nil { - return nil, fmt.Errorf("failed to assume role %v: %w", roleArn, err) + return aws.Credentials{}, fmt.Errorf("failed to assume role %v: %w", roleArn, err) } - return result.Credentials, nil + return result, nil } -type STSClientProviderFunc func(optFns ...func(*config.LoadOptions) error) (STSRoleAssumer, error) +type STSClientProviderFunc func(optFns ...func(*config.LoadOptions) error) (stscreds.AssumeRoleAPIClient, error) -var DefaultSTSClientProviderFunc STSClientProviderFunc = func(optnFns ...func(options *config.LoadOptions) error) (STSRoleAssumer, error) { +var DefaultSTSClientProviderFunc STSClientProviderFunc = func(optnFns ...func(options *config.LoadOptions) error) (stscreds.AssumeRoleAPIClient, error) { cfg, err := config.LoadDefaultConfig(context.TODO(), optnFns...) if err != nil { return nil, fmt.Errorf("failed to load default AWS config: %w", err) @@ -106,44 +109,39 @@ var DefaultSTSClientProviderFunc STSClientProviderFunc = func(optnFns ...func(op return sts.NewFromConfig(cfg), nil } -func AssumeRoleSequence(roleSessionName string, seedClient STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc STSClientProviderFunc) (*types.Credentials, error) { +func AssumeRoleSequence(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc STSClientProviderFunc) (aws.Credentials, error) { if len(roleArnSequence) == 0 { - return nil, errors.New("role ARN sequence cannot be empty") + return aws.Credentials{}, errors.New("role ARN sequence cannot be empty") } nextClient := seedClient - var lastCredentials *types.Credentials + var lastCredentials aws.Credentials for i, roleArn := range roleArnSequence { result, err := AssumeRole(roleSessionName, nextClient, roleArn) - if err != nil { - return nil, fmt.Errorf("failed to assume role %v: %w", roleArn, err) + retryCount := 0 + for err != nil { + // IAM policy updates can take a few seconds to resolve, and the sts.Client in AWS' Go SDK doesn't refresh itself on retries. + // https://github.com/aws/aws-sdk-go-v2/issues/2332 + if retryCount < assumeRoleMaxRetries { + fmt.Println("Waiting for IAM policy changes to resolve...") + time.Sleep(assumeRoleRetryBackoff) + nextClient, err = createAssumeRoleSequenceClient(stsClientProviderFunc, lastCredentials, proxyURL) + if err != nil { + return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArn, err) + } + result, err = AssumeRole(roleSessionName, nextClient, roleArn) + retryCount++ + } else { + return aws.Credentials{}, fmt.Errorf("failed to assume role %v: %w", roleArn, err) + } } lastCredentials = result if i < len(roleArnSequence)-1 { - nextClient, err = stsClientProviderFunc( - config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(*lastCredentials.AccessKeyId, *lastCredentials.SecretAccessKey, *lastCredentials.SessionToken)), - config.WithHTTPClient(&http.Client{ - Transport: &http.Transport{ - Proxy: func(*http.Request) (*url.URL, error) { - return url.Parse(proxyURL) - }, - }, - }), - config.WithRetryer(func() aws.Retryer { - return retry.NewStandard(func(options *retry.StandardOptions) { - options.Retryables = append(options.Retryables, retry.RetryableHTTPStatusCode{ - Codes: map[int]struct{}{401: {}, 403: {}, 404: {}}, // Handle IAM eventual consistency because backplane api modifies trust policy - }) - options.MaxAttempts = 5 - options.MaxBackoff = 20 * time.Second - }) - }), - config.WithRegion("us-east-1"), // We don't care about region here, but the API still wants to see one set - ) + nextClient, err = createAssumeRoleSequenceClient(stsClientProviderFunc, lastCredentials, proxyURL) if err != nil { - return nil, fmt.Errorf("failed to create client with credentials from role %v: %w", roleArn, err) + return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArn, err) } } } @@ -151,11 +149,25 @@ func AssumeRoleSequence(roleSessionName string, seedClient STSRoleAssumer, roleA return lastCredentials, nil } -func GetSigninToken(awsCredentials *types.Credentials) (*AWSSigninTokenResponse, error) { +func createAssumeRoleSequenceClient(stsClientProviderFunc STSClientProviderFunc, creds aws.Credentials, proxyURL string) (stscreds.AssumeRoleAPIClient, error) { + return stsClientProviderFunc( + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken)), + config.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: func(*http.Request) (*url.URL, error) { + return url.Parse(proxyURL) + }, + }, + }), + config.WithRegion("us-east-1"), // We don't care about region here, but the API still wants to see one set + ) +} + +func GetSigninToken(awsCredentials aws.Credentials) (*AWSSigninTokenResponse, error) { sessionData := AWSFederatedSessionData{ - SessionID: *awsCredentials.AccessKeyId, - SessionKey: *awsCredentials.SecretAccessKey, - SessionToken: *awsCredentials.SessionToken, + SessionID: awsCredentials.AccessKeyID, + SessionKey: awsCredentials.SecretAccessKey, + SessionToken: awsCredentials.SessionToken, } data, err := json.Marshal(sessionData) diff --git a/pkg/awsutil/sts_test.go b/pkg/awsutil/sts_test.go index 39538525..79ac865a 100644 --- a/pkg/awsutil/sts_test.go +++ b/pkg/awsutil/sts_test.go @@ -4,11 +4,13 @@ import ( "bytes" "context" "errors" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "io" "net/http" "net/url" "reflect" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" @@ -36,12 +38,14 @@ func defaultSuccessMockSTSClient() STSRoleAssumerMock { AccessKeyId: aws.String("test-access-key-id"), SecretAccessKey: aws.String("test-secret-access-key"), SessionToken: aws.String("test-session-token"), + Expiration: aws.Time(time.UnixMilli(1)), }, }, &sts.AssumeRoleWithWebIdentityOutput{ Credentials: &types.Credentials{ AccessKeyId: aws.String("test-access-key-id"), SecretAccessKey: aws.String("test-secret-access-key"), SessionToken: aws.String("test-session-token"), + Expiration: aws.Time(time.UnixMilli(1)), }, }, nil) } @@ -62,12 +66,12 @@ func TestAssumeRoleWithJWT(t *testing.T) { type args struct { jwt string roleArn string - stsClient STSRoleWithWebIdentityAssumer + stsClient stscreds.AssumeRoleWithWebIdentityAPIClient } tests := []struct { name string args args - want *types.Credentials + want aws.Credentials wantErr bool }{ { @@ -94,10 +98,13 @@ func TestAssumeRoleWithJWT(t *testing.T) { roleArn: "arn:aws:iam::1234567890:role/read-only", stsClient: defaultSuccessMockSTSClient(), }, - want: &types.Credentials{ - AccessKeyId: aws.String("test-access-key-id"), - SecretAccessKey: aws.String("test-secret-access-key"), - SessionToken: aws.String("test-session-token"), + want: aws.Credentials{ + AccessKeyID: "test-access-key-id", + SecretAccessKey: "test-secret-access-key", + SessionToken: "test-session-token", + Source: "WebIdentityCredentials", + CanExpire: true, + Expires: time.UnixMilli(1), }, }, } @@ -118,8 +125,8 @@ func TestAssumeRoleWithJWT(t *testing.T) { func TestAssumeRole(t *testing.T) { tests := []struct { name string - stsClient STSRoleAssumer - want *types.Credentials + stsClient stscreds.AssumeRoleAPIClient + want aws.Credentials wantErr bool }{ { @@ -130,10 +137,13 @@ func TestAssumeRole(t *testing.T) { { name: "Successfully assumes role", stsClient: defaultSuccessMockSTSClient(), - want: &types.Credentials{ - AccessKeyId: aws.String("test-access-key-id"), - SecretAccessKey: aws.String("test-secret-access-key"), - SessionToken: aws.String("test-session-token"), + want: aws.Credentials{ + AccessKeyID: "test-access-key-id", + SecretAccessKey: "test-secret-access-key", + SessionToken: "test-session-token", + Source: "AssumeRoleProvider", + CanExpire: true, + Expires: time.UnixMilli(1), }, }, } @@ -153,14 +163,14 @@ func TestAssumeRole(t *testing.T) { func TestAssumeRoleSequence(t *testing.T) { type args struct { - seedClient STSRoleAssumer + seedClient stscreds.AssumeRoleAPIClient roleArnSequence []string stsClientProviderFunc STSClientProviderFunc } tests := []struct { name string args args - want *types.Credentials + want aws.Credentials wantErr bool }{ { @@ -182,14 +192,17 @@ func TestAssumeRoleSequence(t *testing.T) { args: args{ seedClient: defaultSuccessMockSTSClient(), roleArnSequence: []string{"a"}, - stsClientProviderFunc: func(optFns ...func(*config.LoadOptions) error) (STSRoleAssumer, error) { + stsClientProviderFunc: func(optFns ...func(*config.LoadOptions) error) (stscreds.AssumeRoleAPIClient, error) { return defaultSuccessMockSTSClient(), nil }, }, - want: &types.Credentials{ - AccessKeyId: aws.String("test-access-key-id"), - SecretAccessKey: aws.String("test-secret-access-key"), - SessionToken: aws.String("test-session-token"), + want: aws.Credentials{ + AccessKeyID: "test-access-key-id", + SecretAccessKey: "test-secret-access-key", + SessionToken: "test-session-token", + Source: "AssumeRoleProvider", + CanExpire: true, + Expires: time.UnixMilli(1), }, }, } @@ -208,10 +221,10 @@ func TestAssumeRoleSequence(t *testing.T) { } func TestGetSigninToken(t *testing.T) { - awsCredentials := &types.Credentials{ - AccessKeyId: aws.String("testAccessKeyId"), - SecretAccessKey: aws.String("testSecretAccessKey"), - SessionToken: aws.String("testSessionToken"), + awsCredentials := aws.Credentials{ + AccessKeyID: "testAccessKeyId", + SecretAccessKey: "testSecretAccessKey", + SessionToken: "testSessionToken", } tests := []struct { name string diff --git a/pkg/cli/config/config.go b/pkg/cli/config/config.go index 5bcf5b8f..76503c16 100644 --- a/pkg/cli/config/config.go +++ b/pkg/cli/config/config.go @@ -16,6 +16,7 @@ type BackplaneConfiguration struct { URL string ProxyURL string SessionDirectory string + AssumeInitialArn string } // GetConfigFilePath returns the Backplane CLI configuration filepath @@ -71,6 +72,7 @@ func GetBackplaneConfiguration() (bpConfig BackplaneConfiguration, err error) { bpConfig.URL = viper.GetString("url") bpConfig.ProxyURL = viper.GetString("proxy-url") bpConfig.SessionDirectory = viper.GetString("session-dir") + bpConfig.AssumeInitialArn = viper.GetString("assume-initial-arn") return bpConfig, nil }