diff --git a/aws_config_test.go b/aws_config_test.go index 6ed0073b..55c0ca18 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -3071,6 +3071,197 @@ web_identity_token_file = no-such-file } } +func TestStsEndpoint(t *testing.T) { + testcases := map[string]struct { + Config Config + SetConfig bool + SetEnv string + SetInvalidEnv string + ConfigFile string + InvalidConfigFile string + ExpectedCredentials aws.Credentials + }{ + "config": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "global envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnv: "AWS_ENDPOINT_URL", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + "service envvar overrides global envvar": { + Config: Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + SetInvalidEnv: "AWS_ENDPOINT_URL", + ExpectedCredentials: mockdata.MockStaticCredentials, + }, + + // "service config_file": { + // Config: Config{ + // Profile: "default", + // }, + // ConfigFile: ` + // [default] + // aws_access_key_id = DefaultSharedCredentialsAccessKey + // aws_secret_access_key = DefaultSharedCredentialsSecretKey + // services = sts-test + + // [services sts-test] + // sts = + // endpoint_url = %s + // `, + // ExpectedCredentials: aws.Credentials{ + // AccessKeyID: "DefaultSharedCredentialsAccessKey", + // SecretAccessKey: "DefaultSharedCredentialsSecretKey", + // Source: sharedConfigCredentialsProvider, + // }, + // }, + + // TODO: service envvar overrides service config_file + + // TODO: does global envvar override service config_file? + + "global config_file": { + Config: Config{ + Profile: "default", + }, + ConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "global envvar overrides global config_file": { + Config: Config{ + Profile: "default", + }, + SetEnv: "AWS_ENDPOINT_URL", + InvalidConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + + "service envvar overrides global config_file": { + Config: Config{ + Profile: "default", + }, + SetEnv: "AWS_ENDPOINT_URL_STS", + InvalidConfigFile: ` +[default] +aws_access_key_id = DefaultSharedCredentialsAccessKey +aws_secret_access_key = DefaultSharedCredentialsSecretKey +endpoint_url = %s +`, + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "DefaultSharedCredentialsAccessKey", + SecretAccessKey: "DefaultSharedCredentialsSecretKey", + Source: sharedConfigCredentialsProvider, + }, + }, + } + + for name, testcase := range testcases { + testcase := testcase + + t.Run(name, func(t *testing.T) { + servicemocks.InitSessionTestEnv(t) + + ctx := context.Background() + + ts := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityValidEndpoint, + }) + defer ts.Close() + stsEndpoint := ts.URL + + invalidTS := servicemocks.MockAwsApiServer("STS", []*servicemocks.MockEndpoint{ + servicemocks.MockStsGetCallerIdentityInvalidEndpointAccessDenied, + }) + defer invalidTS.Close() + stsInvalidEndpoint := invalidTS.URL + + if testcase.SetConfig { + testcase.Config.StsEndpoint = stsEndpoint + } + if testcase.SetEnv != "" { + t.Setenv(testcase.SetEnv, stsEndpoint) + } + if testcase.SetInvalidEnv != "" { + t.Setenv(testcase.SetInvalidEnv, stsInvalidEndpoint) + } + if testcase.ConfigFile != "" { + tempDir := t.TempDir() + filename := writeSharedConfigFile(t, &testcase.Config, tempDir, fmt.Sprintf(testcase.ConfigFile, stsEndpoint)) + testcase.ExpectedCredentials.Source = sharedConfigCredentialsSource(filename) + } + if testcase.InvalidConfigFile != "" { + tempDir := t.TempDir() + filename := writeSharedConfigFile(t, &testcase.Config, tempDir, fmt.Sprintf(testcase.InvalidConfigFile, stsInvalidEndpoint)) + testcase.ExpectedCredentials.Source = sharedConfigCredentialsSource(filename) + } + + ctx, awsConfig, diags := GetAwsConfig(ctx, &testcase.Config) + + if diff := cmp.Diff(diags, diag.Diagnostics{}); diff != "" { + t.Errorf("Unexpected response (+wanted, -got): %s", diff) + } + if diags.HasError() { + return + } + + credentialsValue, err := awsConfig.Credentials.Retrieve(ctx) + if err != nil { + t.Fatalf("unexpected credentials Retrieve() error: %s", err) + } + + if diff := cmp.Diff(credentialsValue, testcase.ExpectedCredentials, cmpopts.IgnoreFields(aws.Credentials{}, "Expires")); diff != "" { + t.Fatalf("unexpected credentials: (- got, + expected)\n%s", diff) + } + }) + } +} + var _ configtesting.TestDriver = &testDriver{} type testDriver struct { @@ -4001,3 +4192,21 @@ func configureHcLogger(name string, output io.Writer) hclog.Logger { return logger } + +func writeSharedConfigFile(t *testing.T, config *Config, tempDir, content string) string { + t.Helper() + + file, err := os.Create(filepath.Join(tempDir, "aws-sdk-go-base-shared-configuration-file")) + if err != nil { + t.Fatalf("creating shared configuration file: %s", err) + } + + _, err = file.WriteString(content) + if err != nil { + t.Fatalf(" writing shared configuration file: %s", err) + } + + config.SharedConfigFiles = append(config.SharedConfigFiles, file.Name()) + + return file.Name() +}