diff --git a/lib/configurators/aws/aws.go b/lib/configurators/aws/aws.go index cece7047dfc13..839bf08de2df1 100644 --- a/lib/configurators/aws/aws.go +++ b/lib/configurators/aws/aws.go @@ -22,6 +22,8 @@ import ( "context" "errors" "fmt" + "io" + "os" "slices" "strings" @@ -38,6 +40,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" apiawsutils "github.com/gravitational/teleport/api/utils/aws" awslib "github.com/gravitational/teleport/lib/cloud/aws" + awsimds "github.com/gravitational/teleport/lib/cloud/imds/aws" "github.com/gravitational/teleport/lib/configurators" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" @@ -316,6 +319,41 @@ type ssmClient interface { CreateDocument(ctx context.Context, params *ssm.CreateDocumentInput, optFns ...func(*ssm.Options)) (*ssm.CreateDocumentOutput, error) } +type localRegionGetter interface { + GetRegion(context.Context) (string, error) +} + +func getLocalRegion(ctx context.Context, localRegionGetter localRegionGetter) (string, bool) { + if localRegionGetter == nil { + imdsClient, err := awsimds.NewInstanceMetadataClient(ctx) + if err != nil || !imdsClient.IsAvailable(ctx) { + return "", false + } + localRegionGetter = imdsClient + } + + region, err := localRegionGetter.GetRegion(ctx) + if err != nil || region == "" { + return "", false + } + return region, true +} + +func getFallbackRegion(ctx context.Context, w io.Writer, localRegionGetter localRegionGetter) string { + if localRegion, ok := getLocalRegion(ctx, localRegionGetter); ok { + fmt.Fprintf(w, "Using region %q from instance metadata.\n", localRegion) + return localRegion + } + + // Fallback to us-east-1, which also supports fips. + fmt.Fprint(w, ` +Warning: No region found from the default AWS config or instance metadata. Defaulting to 'us-east-1'. +To avoid seeing this warning, please provide a region in your AWS config or through the AWS_REGION environment variable. + +`) + return "us-east-1" +} + // CheckAndSetDefaults checks and set configuration default values. func (c *ConfiguratorConfig) CheckAndSetDefaults() error { ctx := context.Background() @@ -342,6 +380,10 @@ func (c *ConfiguratorConfig) CheckAndSetDefaults() error { if err != nil { return trace.Wrap(err) } + + if cfg.Region == "" { + cfg.Region = getFallbackRegion(ctx, os.Stdout, nil) + } c.awsCfg = &cfg } diff --git a/lib/configurators/aws/aws_test.go b/lib/configurators/aws/aws_test.go index a64552620b442..c0b34dffaad2e 100644 --- a/lib/configurators/aws/aws_test.go +++ b/lib/configurators/aws/aws_test.go @@ -21,6 +21,7 @@ package aws import ( "context" "fmt" + "io" "regexp" "sort" "testing" @@ -1918,3 +1919,42 @@ func (m *iamMock) GetRole(ctx context.Context, input *iam.GetRoleInput, optFns . arn := fmt.Sprintf("arn:%s:iam::%s:role%s%s", m.partition, m.account, path, roleName) return &iam.GetRoleOutput{Role: &iamtypes.Role{Arn: &arn}}, nil } + +type mockLocalRegionGetter struct { + region string + err error +} + +func (m mockLocalRegionGetter) GetRegion(context.Context) (string, error) { + return m.region, m.err +} + +func Test_getFallbackRegion(t *testing.T) { + tests := []struct { + name string + localRegionGetter localRegionGetter + wantRegion string + }{ + { + name: "fallback to retrieved local region", + localRegionGetter: mockLocalRegionGetter{ + region: "my-local-region", + }, + wantRegion: "my-local-region", + }, + { + name: "fallback to us-east", + localRegionGetter: mockLocalRegionGetter{ + err: fmt.Errorf("failed to get local region"), + }, + wantRegion: "us-east-1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + region := getFallbackRegion(context.Background(), io.Discard, test.localRegionGetter) + require.Equal(t, test.wantRegion, region) + }) + } +}