Skip to content

Commit

Permalink
Implement retry for eventual consistency in IAM updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexVulaj committed Oct 23, 2023
1 parent 3e569d6 commit 624f36b
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 163 deletions.
82 changes: 61 additions & 21 deletions cmd/ocm-backplane/cloud/assume.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cloud
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
Expand All @@ -17,14 +18,20 @@ import (
)

// Role ARN values, OCM environment names, and the AWS account IDs that are
// substituted into InitialRoleArnTemplate to build the initial role ARN.
const (
// DefaultInitialRoleArn is the SRE-Support-Role ARN in account 922711891673.
DefaultInitialRoleArn = "arn:aws:iam::922711891673:role/SRE-Support-Role"
// InitialRoleArnTemplate is a format string whose %v verb is filled with one
// of the payer account IDs below.
InitialRoleArnTemplate = "arn:aws:iam::%v:role/SRE-Support-Role"
// EnvProd is the Environment value "prod".
EnvProd Environment = "prod"
// ProdPayerAccountID is the account ID used for EnvProd.
ProdPayerAccountID = 922711891673
// EnvStg is the Environment value "stg".
EnvStg Environment = "stg"
// StgPayerAccountId is the account ID used for EnvStg.
// NOTE(review): spelled "Id" while ProdPayerAccountID uses "ID"; renaming
// would touch callers, so it is left unchanged here.
StgPayerAccountId = 277304166082
// EnvInt is the Environment value "int".
EnvInt Environment = "int"
// IntPayerAccountId is the account ID used for EnvInt.
// NOTE(review): same value as StgPayerAccountId — confirm that int and stg
// are meant to share one account.
IntPayerAccountId = 277304166082
)

var assumeArgs struct {
initialRoleArn string
output string
debugFile string
console bool
environment Environment
output string
debugFile string
console bool
}

var StsClientWithProxy = awsutil.StsClientWithProxy
Expand Down Expand Up @@ -61,7 +68,7 @@ backplane cloud assume e3b2fdc5-d9a7-435e-8870-312689cfb29c --console`,

func init() {
flags := AssumeCmd.Flags()
flags.StringVar(&assumeArgs.initialRoleArn, "initial-role-arn", DefaultInitialRoleArn, "The arn of the role for which to start the role assume process.")
flags.Var(&assumeArgs.environment, "env", "The OCM environment in which the target cluster is deployed. Valid values are `prod`, `stg`, and `int`.")
flags.StringVarP(&assumeArgs.output, "output", "o", "env", "Format the output of the console response.")
flags.StringVar(&assumeArgs.debugFile, "debug-file", "", "A file containing the list of ARNs to assume in order, not including the initial role ARN. Providing this flag will bypass calls to the backplane API to retrieve the assume role chain. The file should be a plain text file with each ARN on a new line.")
flags.BoolVar(&assumeArgs.console, "console", false, "Outputs a console url to access the targeted cluster instead of the STS credentials.")
Expand All @@ -86,30 +93,38 @@ func runAssume(_ *cobra.Command, args []string) error {
return fmt.Errorf("failed to retrieve OCM token: %w", err)
}

email, err := utils.GetStringFieldFromJWT(*ocmToken, "email")
if err != nil {
return fmt.Errorf("unable to extract email from given token: %w", err)
}

bpConfig, err := GetBackplaneConfiguration()
if err != nil {
return fmt.Errorf("error retrieving backplane configuration: %w", err)
}

initialRoleArn := ""
switch assumeArgs.environment {
case EnvProd, "":
initialRoleArn = fmt.Sprintf(InitialRoleArnTemplate, ProdPayerAccountID)
case EnvStg:
initialRoleArn = fmt.Sprintf(InitialRoleArnTemplate, StgPayerAccountId)
case EnvInt:
initialRoleArn = fmt.Sprintf(InitialRoleArnTemplate, IntPayerAccountId)
default:
return fmt.Errorf("unrecognized environment %v", assumeArgs.environment)
}

initialClient, err := StsClientWithProxy(bpConfig.ProxyURL)
if err != nil {
return fmt.Errorf("failed to create sts client: %w", err)
}
seedCredentials, err := AssumeRoleWithJWT(*ocmToken, assumeArgs.initialRoleArn, initialClient)
if err != nil {
return fmt.Errorf("failed to assume role using JWT: %w", err)
}

email, err := utils.GetStringFieldFromJWT(*ocmToken, "email")
seedCredentials, err := AssumeRoleWithJWT(*ocmToken, initialRoleArn, initialClient)
if err != nil {
return fmt.Errorf("unable to extract email from given token: %w", err)
return fmt.Errorf("failed to assume role using JWT: %w", err)
}

seedClient := sts.NewFromConfig(aws.Config{
Region: "us-east-1",
Credentials: NewStaticCredentialsProvider(*seedCredentials.AccessKeyId, *seedCredentials.SecretAccessKey, *seedCredentials.SessionToken),
})

var roleAssumeSequence []string
if assumeArgs.debugFile == "" {
clusterID, _, err := utils.DefaultOCMInterface.GetTargetCluster(args[0])
Expand Down Expand Up @@ -154,6 +169,11 @@ func runAssume(_ *cobra.Command, args []string) error {
roleAssumeSequence = append(roleAssumeSequence, strings.Split(string(arnBytes), "\n")...)
}

seedClient := sts.NewFromConfig(aws.Config{
Region: "us-east-1",
Credentials: NewStaticCredentialsProvider(seedCredentials.AccessKeyID, seedCredentials.SecretAccessKey, seedCredentials.SessionToken),
})

targetCredentials, err := AssumeRoleSequence(email, seedClient, roleAssumeSequence, bpConfig.ProxyURL, awsutil.DefaultSTSClientProviderFunc)
if err != nil {
return fmt.Errorf("failed to assume role sequence: %w", err)
Expand All @@ -173,10 +193,10 @@ func runAssume(_ *cobra.Command, args []string) error {
fmt.Printf("The AWS Console URL is:\n%s\n", signInFederationURL.String())
} else {
credsResponse := awsutil.AWSCredentialsResponse{
AccessKeyID: *targetCredentials.AccessKeyId,
SecretAccessKey: *targetCredentials.SecretAccessKey,
SessionToken: *targetCredentials.SessionToken,
Expiration: targetCredentials.Expiration.String(),
AccessKeyID: targetCredentials.AccessKeyID,
SecretAccessKey: targetCredentials.SecretAccessKey,
SessionToken: targetCredentials.SessionToken,
Expiration: targetCredentials.Expires.String(),
}
formattedResult, err := credsResponse.RenderOutput(assumeArgs.output)
if err != nil {
Expand All @@ -186,3 +206,23 @@ func runAssume(_ *cobra.Command, args []string) error {
}
return nil
}

// Environment names the OCM environment in which a target cluster is
// deployed. Its String, Set, and Type methods allow a *Environment to be
// registered as a custom command-line flag value.
type Environment string

// String returns the environment name as a plain string.
func (e *Environment) String() string {
	return string(*e)
}

// Set validates env case-insensitively and, when it is one of "int", "stg",
// or "prod", stores its lower-case form in e.
//
// It returns an error, leaving e unchanged, for any other value.
func (e *Environment) Set(env string) error {
	normalized := strings.ToLower(env)
	switch normalized {
	case "int", "stg", "prod":
		// Store the normalized value rather than the raw input: the Env*
		// constants are lower-case, so keeping e.g. "PROD" would pass
		// validation here yet fail every later comparison against them.
		*e = Environment(normalized)
		return nil
	default:
		return errors.New(`must be one of "int", "stg", or "prod"`)
	}
}

// Type returns the name used to describe this flag's value type.
func (e *Environment) Type() string {
	return "Environment"
}
105 changes: 53 additions & 52 deletions cmd/ocm-backplane/cloud/assume_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -71,11 +72,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -87,12 +88,12 @@ var _ = Describe("Cloud assume command", func() {
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{"assumption_sequence":[{"name": "name_one", "arn": "arn_one"},{"name": "name_two", "arn": "arn_two"}]}`)),
}, nil).Times(1)
AssumeRoleSequence = func(roleSessionName string, seedClient awsutil.STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
Expiration: &testExpiration,
AssumeRoleSequence = func(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
Expires: testExpiration,
}, nil
}

Expand All @@ -118,7 +119,7 @@ var _ = Describe("Cloud assume command", func() {
err := runAssume(nil, []string{testClusterID})
Expect(err.Error()).To(Equal("error retrieving backplane configuration: oops"))
})
It("should fail if cannot create create sts client with proxy", func() {
It("should fail if cannot create sts client with proxy", func() {
mockOcmInterface.EXPECT().GetOCMAccessToken().Return(&testOcmToken, nil).Times(1)
GetBackplaneConfiguration = func() (bpConfig config.BackplaneConfiguration, err error) {
return config.BackplaneConfiguration{
Expand All @@ -144,8 +145,8 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return nil, errors.New("failure")
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{}, errors.New("failure")
}

err := runAssume(nil, []string{testClusterID})
Expand All @@ -163,11 +164,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}

Expand All @@ -185,11 +186,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
mockOcmInterface.EXPECT().GetTargetCluster(testClusterID).Return("", "", errors.New("oh no")).Times(1)
Expand All @@ -208,11 +209,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -235,11 +236,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -263,11 +264,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand Down Expand Up @@ -295,11 +296,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -326,11 +327,11 @@ var _ = Describe("Cloud assume command", func() {
StsClientWithProxy = func(proxyURL string) (*sts.Client, error) {
return &sts.Client{}, nil
}
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient awsutil.STSRoleWithWebIdentityAssumer) (*types.Credentials, error) {
return &types.Credentials{
AccessKeyId: &testAccessKeyID,
SecretAccessKey: &testSecretAccessKey,
SessionToken: &testSessionToken,
AssumeRoleWithJWT = func(jwt string, roleArn string, stsClient stscreds.AssumeRoleWithWebIdentityAPIClient) (aws.Credentials, error) {
return aws.Credentials{
AccessKeyID: testAccessKeyID,
SecretAccessKey: testSecretAccessKey,
SessionToken: testSessionToken,
}, nil
}
NewStaticCredentialsProvider = func(key, secret, session string) credentials.StaticCredentialsProvider {
Expand All @@ -342,8 +343,8 @@ var _ = Describe("Cloud assume command", func() {
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(`{"assumption_sequence":[{"name": "name_one", "arn": "arn_one"},{"name": "name_two", "arn": "arn_two"}]}`)),
}, nil).Times(1)
AssumeRoleSequence = func(roleSessionName string, seedClient awsutil.STSRoleAssumer, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (*types.Credentials, error) {
return nil, errors.New("oops")
AssumeRoleSequence = func(roleSessionName string, seedClient stscreds.AssumeRoleAPIClient, roleArnSequence []string, proxyURL string, stsClientProviderFunc awsutil.STSClientProviderFunc) (aws.Credentials, error) {
return aws.Credentials{}, errors.New("oops")
}

err := runAssume(nil, []string{testClusterID})
Expand Down
8 changes: 4 additions & 4 deletions cmd/ocm-backplane/cloud/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ func runToken(*cobra.Command, []string) error {
}

credsResponse := awsutil.AWSCredentialsResponse{
AccessKeyID: *result.AccessKeyId,
SecretAccessKey: *result.SecretAccessKey,
SessionToken: *result.SessionToken,
Expiration: result.Expiration.String(),
AccessKeyID: result.AccessKeyID,
SecretAccessKey: result.SecretAccessKey,
SessionToken: result.SessionToken,
Expiration: result.Expires.String(),
}

formattedResult, err := credsResponse.RenderOutput(tokenArgs.output)
Expand Down
Loading

0 comments on commit 624f36b

Please sign in to comment.