Skip to content

Commit

Permalink
Implement retry for eventual consistency in IAM updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexVulaj committed Oct 23, 2023
1 parent 3e569d6 commit 4b81bba
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 170 deletions.
98 changes: 70 additions & 28 deletions cmd/ocm-backplane/cloud/assume.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cloud
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
Expand All @@ -17,14 +18,20 @@ import (
)

const (
DefaultInitialRoleArn = "arn:aws:iam::922711891673:role/SRE-Support-Role"
InitialRoleArnTemplate = "arn:aws:iam::%v:role/SRE-Support-Role"
EnvProd Environment = "prod"
ProdPayerAccountID = 922711891673
EnvStg Environment = "stg"
StgPayerAccountId = 277304166082
EnvInt Environment = "int"
IntPayerAccountId = 277304166082
)

var assumeArgs struct {
initialRoleArn string
output string
debugFile string
console bool
environment Environment
output string
debugFile string
console bool
}

var StsClientWithProxy = awsutil.StsClientWithProxy
Expand All @@ -37,18 +44,20 @@ var AssumeCmd = &cobra.Command{
Short: "Performs the assume role chaining necessary to generate temporary access to the customer's AWS account",
Long: `Performs the assume role chaining necessary to generate temporary access to the customer's AWS account
This command is the equivalent of running "aws sts assume-role-with-web-identity --initial-role-arn [role-arn] --web-identity-token [ocm token] --role-session-name [email from OCM token]" behind the scenes,
This command is the equivalent of running "aws sts assume-role-with-web-identity --role-arn [role-arn] --web-identity-token [ocm token] --role-session-name [email from OCM token]" behind the scenes,
where the ocm token used is the result of running "ocm token". Then, the command makes a call to the backplane API to get the necessary jump roles for the cluster's account. It then calls the
equivalent of "aws sts assume-role --initial-role-arn [role-arn] --role-session-name [email from OCM token]" repeatedly for each role arn in the chain, using the previous role's credentials to assume the next
equivalent of "aws sts assume-role --role-arn [role-arn] --role-session-name [email from OCM token]" repeatedly for each role arn in the chain, using the previous role's credentials to assume the next
role in the chain.
This command will output sts credentials for the target role in the given cluster in formatted JSON. If no "role-arn" is provided, a default role will be used.
If running in an OCM environment other than production, the environment must be manually specified with the "env" flag, for which accepted values are "prod", "stg", and "int".
This command will output sts credentials for the target role in the given cluster in formatted JSON.
`,
Example: `With default role:
Example: `With -o flag specified:
backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c -oenv
With given role:
backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --initial-role-arn arn:aws:iam::1234567890:role/read-only -oenv
With given OCM environment:
backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --env stg -oenv
With a debug file:
backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --debug-file test_arns
Expand All @@ -61,8 +70,8 @@ backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --console`,

func init() {
flags := AssumeCmd.Flags()
flags.StringVar(&assumeArgs.initialRoleArn, "initial-role-arn", DefaultInitialRoleArn, "The arn of the role for which to start the role assume process.")
flags.StringVarP(&assumeArgs.output, "output", "o", "env", "Format the output of the console response.")
flags.Var(&assumeArgs.environment, "env", "The OCM environment in which the target cluster is deployed. Valid values are `prod`, `stg`, and `int`.")
flags.StringVarP(&assumeArgs.output, "output", "o", "env", "Format the output of the console response. Valid values are `env`, `json`, and `yaml`.")
flags.StringVar(&assumeArgs.debugFile, "debug-file", "", "A file containing the list of ARNs to assume in order, not including the initial role ARN. Providing this flag will bypass calls to the backplane API to retrieve the assume role chain. The file should be a plain text file with each ARN on a new line.")
flags.BoolVar(&assumeArgs.console, "console", false, "Outputs a console url to access the targeted cluster instead of the STS credentials.")
}
Expand All @@ -86,30 +95,38 @@ func runAssume(_ *cobra.Command, args []string) error {
return fmt.Errorf("failed to retrieve OCM token: %w", err)
}

email, err := utils.GetStringFieldFromJWT(*ocmToken, "email")
if err != nil {
return fmt.Errorf("unable to extract email from given token: %w", err)
}

bpConfig, err := GetBackplaneConfiguration()
if err != nil {
return fmt.Errorf("error retrieving backplane configuration: %w", err)
}

initialRoleArn := ""
switch assumeArgs.environment {
case EnvProd, "":
initialRoleArn = fmt.Sprintf(InitialRoleArnTemplate, ProdPayerAccountID)
case EnvStg:
initialRoleArn = fmt.Sprintf(InitialRoleArnTemplate, StgPayerAccountId)
case EnvInt:
initialRoleArn = fmt.Sprintf(InitialRoleArnTemplate, IntPayerAccountId)
default:
return fmt.Errorf("unrecognized environment %v", assumeArgs.environment)
}

initialClient, err := StsClientWithProxy(bpConfig.ProxyURL)
if err != nil {
return fmt.Errorf("failed to create sts client: %w", err)
}
seedCredentials, err := AssumeRoleWithJWT(*ocmToken, assumeArgs.initialRoleArn, initialClient)
if err != nil {
return fmt.Errorf("failed to assume role using JWT: %w", err)
}

email, err := utils.GetStringFieldFromJWT(*ocmToken, "email")
seedCredentials, err := AssumeRoleWithJWT(*ocmToken, initialRoleArn, initialClient)
if err != nil {
return fmt.Errorf("unable to extract email from given token: %w", err)
return fmt.Errorf("failed to assume role using JWT: %w", err)
}

seedClient := sts.NewFromConfig(aws.Config{
Region: "us-east-1",
Credentials: NewStaticCredentialsProvider(*seedCredentials.AccessKeyId, *seedCredentials.SecretAccessKey, *seedCredentials.SessionToken),
})

var roleAssumeSequence []string
if assumeArgs.debugFile == "" {
clusterID, _, err := utils.DefaultOCMInterface.GetTargetCluster(args[0])
Expand Down Expand Up @@ -154,6 +171,11 @@ func runAssume(_ *cobra.Command, args []string) error {
roleAssumeSequence = append(roleAssumeSequence, strings.Split(string(arnBytes), "\n")...)
}

seedClient := sts.NewFromConfig(aws.Config{
Region: "us-east-1",
Credentials: NewStaticCredentialsProvider(seedCredentials.AccessKeyID, seedCredentials.SecretAccessKey, seedCredentials.SessionToken),
})

targetCredentials, err := AssumeRoleSequence(email, seedClient, roleAssumeSequence, bpConfig.ProxyURL, awsutil.DefaultSTSClientProviderFunc)
if err != nil {
return fmt.Errorf("failed to assume role sequence: %w", err)
Expand All @@ -173,10 +195,10 @@ func runAssume(_ *cobra.Command, args []string) error {
fmt.Printf("The AWS Console URL is:\n%s\n", signInFederationURL.String())
} else {
credsResponse := awsutil.AWSCredentialsResponse{
AccessKeyID: *targetCredentials.AccessKeyId,
SecretAccessKey: *targetCredentials.SecretAccessKey,
SessionToken: *targetCredentials.SessionToken,
Expiration: targetCredentials.Expiration.String(),
AccessKeyID: targetCredentials.AccessKeyID,
SecretAccessKey: targetCredentials.SecretAccessKey,
SessionToken: targetCredentials.SessionToken,
Expiration: targetCredentials.Expires.String(),
}
formattedResult, err := credsResponse.RenderOutput(assumeArgs.output)
if err != nil {
Expand All @@ -186,3 +208,23 @@ func runAssume(_ *cobra.Command, args []string) error {
}
return nil
}

type Environment string

func (e *Environment) String() string {
return string(*e)
}

func (e *Environment) Set(env string) error {
switch strings.ToLower(env) {
case "int", "stg", "prod":
*e = Environment(env)
return nil
default:
return errors.New(`must be one of "int", "stg", or "prod"`)
}
}

func (e *Environment) Type() string {
return "Environment"
}
105 changes: 53 additions & 52 deletions cmd/ocm-backplane/cloud/assume_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -71,11 +72,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -87,12 +88,12 @@ var _ = Describe("Cloud assume command", func() {
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{"assumption_sequence":[{"name": "name_one", "arn": "arn_one"},{"name": "name_two", "arn": "arn_two"}]}`)),
}, nil).Times(1)
AssumeRoleSequence = func(roleSessionName string, seedClient awsutil.STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
Expiration: &testExpiration,
AssumeRoleSequence = func(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
Expires: testExpiration,
}, nil
}

Expand All @@ -118,7 +119,7 @@ var _ = Describe("Cloud assume command", func() {
err := runAssume(nil, []string{testClusterID})
Expect(err.Error()).To(Equal("error retrieving backplane configuration: oops"))
})
It("should fail if cannot create create sts client with proxy", func() {
It("should fail if cannot create sts client with proxy", func() {
mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1)
GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) {
return config.BackplaneConfiguration{
Expand All @@ -144,8 +145,8 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return nil, errors.New("failure")
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{}, errors.New("failure")
}

err := runAssume(nil, []string{testClusterID})
Expand All @@ -163,11 +164,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}

Expand All @@ -185,11 +186,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
mockOcmInterface.EXPECT().GetTargetCluster(testClusterID).Return("", "", errors.New("oh no")).Times(1)
Expand All @@ -208,11 +209,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -235,11 +236,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -263,11 +264,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand Down Expand Up @@ -295,11 +296,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -326,11 +327,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -342,8 +343,8 @@ var _ = Describe("Cloud assume command", func() {
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{"assumption_sequence":[{"name": "name_one", "arn": "arn_one"},{"name": "name_two", "arn": "arn_two"}]}`)),
}, nil).Times(1)
AssumeRoleSequence = func(roleSessionName string, seedClient awsutil.STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (*types.Credentials, error) {
return nil, errors.New("oops")
AssumeRoleSequence = func(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (aws.Credentials, error) {
return aws.Credentials{}, errors.New("oops")
}

err := runAssume(nil, []string{testClusterID})
Expand Down
8 changes: 4 additions & 4 deletions cmd/ocm-backplane/cloud/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func runToken(*cobra.Command, []string) error {
}

credsResponse := awsutil.AWSCredentialsResponse{
AccessKeyID: *result.AccessKeyId,
SecretAccessKey: *result.SecretAccessKey,
SessionToken: *result.SessionToken,
Expiration: result.Expiration.String(),
AccessKeyID: result.AccessKeyID,
SecretAccessKey: result.SecretAccessKey,
SessionToken: result.SessionToken,
Expiration: result.Expires.String(),
}

formattedResult, err := credsResponse.RenderOutput(tokenArgs.output)
Expand Down
Loading

0 comments on commit 4b81bba

Please sign in to comment.