diff --git a/lib/cloud/aws/config/config.go b/lib/cloud/awsconfig/awsconfig.go similarity index 68% rename from lib/cloud/aws/config/config.go rename to lib/cloud/awsconfig/awsconfig.go index 13e032bf3fe65..92f7e8aa96e86 100644 --- a/lib/cloud/aws/config/config.go +++ b/lib/cloud/awsconfig/awsconfig.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package config +package awsconfig import ( "context" @@ -43,13 +43,13 @@ const ( credentialsSourceIntegration ) -// AWSIntegrationSessionProvider defines a function that creates a credential provider from a region and an integration. +// IntegrationSessionProviderFunc defines a function that creates a credential provider from a region and an integration. // This is used to generate aws configs for clients that must use an integration instead of ambient credentials. -type AWSIntegrationCredentialProvider func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) +type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) -// awsOptions is a struct of additional options for assuming an AWS role +// options is a struct of additional options for assuming an AWS role // when construction an underlying AWS config. -type awsOptions struct { +type options struct { // baseConfigis a config to use instead of the default config for an // AWS region, which is used to enable role chaining. baseConfig *aws.Config @@ -61,15 +61,15 @@ type awsOptions struct { credentialsSource credentialsSource // integration is the name of the integration to be used to fetch the credentials. integration string - // awsIntegrationCredentialsProvider is the integration credential provider to use. - awsIntegrationCredentialsProvider AWSIntegrationCredentialProvider + // integrationCredentialsProvider is the integration credential provider to use. + integrationCredentialsProvider IntegrationCredentialProviderFunc // customRetryer is a custom retryer to use for the config. customRetryer func() aws.Retryer // maxRetries is the maximum number of retries to use for the config. maxRetries *int } -func (a *awsOptions) checkAndSetDefaults() error { +func (a *options) checkAndSetDefaults() error { switch a.credentialsSource { case credentialsSourceAmbient: if a.integration != "" { @@ -86,28 +86,28 @@ func (a *awsOptions) checkAndSetDefaults() error { return nil } -// AWSOptionsFn is an option function for setting additional options +// OptionsFn is an option function for setting additional options // when getting an AWS config. -type AWSOptionsFn func(*awsOptions) +type OptionsFn func(*options) // WithAssumeRole configures options needed for assuming an AWS role. -func WithAssumeRole(roleARN, externalID string) AWSOptionsFn { - return func(options *awsOptions) { +func WithAssumeRole(roleARN, externalID string) OptionsFn { + return func(options *options) { options.assumeRoleARN = roleARN options.assumeRoleExternalID = externalID } } // WithRetryer sets a custom retryer for the config. -func WithRetryer(retryer func() aws.Retryer) AWSOptionsFn { - return func(options *awsOptions) { +func WithRetryer(retryer func() aws.Retryer) OptionsFn { + return func(options *options) { options.customRetryer = retryer } } // WithMaxRetries sets the maximum allowed value for the sdk to keep retrying. -func WithMaxRetries(maxRetries int) AWSOptionsFn { - return func(options *awsOptions) { +func WithMaxRetries(maxRetries int) OptionsFn { + return func(options *options) { options.maxRetries = &maxRetries } } @@ -115,7 +115,7 @@ func WithMaxRetries(maxRetries int) AWSOptionsFn { // WithCredentialsMaybeIntegration sets the credential source to be // - ambient if the integration is an empty string // - integration, otherwise -func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn { +func WithCredentialsMaybeIntegration(integration string) OptionsFn { if integration != "" { return withIntegrationCredentials(integration) } @@ -125,36 +125,36 @@ func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn { // withIntegrationCredentials configures options with an Integration that must be used to fetch Credentials to assume a role. // This prevents the usage of AWS environment credentials. -func withIntegrationCredentials(integration string) AWSOptionsFn { - return func(options *awsOptions) { +func withIntegrationCredentials(integration string) OptionsFn { + return func(options *options) { options.credentialsSource = credentialsSourceIntegration options.integration = integration } } // WithAmbientCredentials configures options to use the ambient credentials. -func WithAmbientCredentials() AWSOptionsFn { - return func(options *awsOptions) { +func WithAmbientCredentials() OptionsFn { + return func(options *options) { options.credentialsSource = credentialsSourceAmbient } } -// WithAWSIntegrationCredentialProvider sets the integration credential provider. -func WithAWSIntegrationCredentialProvider(cred AWSIntegrationCredentialProvider) AWSOptionsFn { - return func(options *awsOptions) { - options.awsIntegrationCredentialsProvider = cred +// WithIntegrationCredentialProvider sets the integration credential provider. +func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) OptionsFn { + return func(options *options) { + options.integrationCredentialsProvider = cred } } -// GetAWSConfig returns an AWS config for the specified region, optionally +// GetConfig returns an AWS config for the specified region, optionally // assuming AWS IAM Roles. -func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws.Config, error) { - var options awsOptions +func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) { + var options options for _, opt := range opts { opt(&options) } if options.baseConfig == nil { - cfg, err := getAWSConfigForRegion(ctx, region, options) + cfg, err := getConfigForRegion(ctx, region, options) if err != nil { return aws.Config{}, trace.Wrap(err) } @@ -163,17 +163,17 @@ func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws if options.assumeRoleARN == "" { return *options.baseConfig, nil } - return getAWSConfigForRole(ctx, region, options) + return getConfigForRole(ctx, region, options) } -// awsAmbientConfigProvider loads a new config using the environment variables. -func awsAmbientConfigProvider(region string, cred aws.CredentialsProvider, options awsOptions) (aws.Config, error) { - opts := buildAWSConfigOptions(region, cred, options) +// ambientConfigProvider loads a new config using the environment variables. +func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) { + opts := buildConfigOptions(region, cred, options) cfg, err := config.LoadDefaultConfig(context.Background(), opts...) return cfg, trace.Wrap(err) } -func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options awsOptions) []func(*config.LoadOptions) error { +func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error { opts := []func(*config.LoadOptions) error{ config.WithDefaultRegion(defaultRegion), config.WithRegion(region), @@ -191,21 +191,21 @@ func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options return opts } -// getAWSConfigForRegion returns AWS config for the specified region. -func getAWSConfigForRegion(ctx context.Context, region string, options awsOptions) (aws.Config, error) { +// getConfigForRegion returns AWS config for the specified region. +func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) { if err := options.checkAndSetDefaults(); err != nil { return aws.Config{}, trace.Wrap(err) } var cred aws.CredentialsProvider if options.credentialsSource == credentialsSourceIntegration { - if options.awsIntegrationCredentialsProvider == nil { + if options.integrationCredentialsProvider == nil { return aws.Config{}, trace.BadParameter("missing aws integration credential provider") } slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration) var err error - cred, err = options.awsIntegrationCredentialsProvider(ctx, region, options.integration) + cred, err = options.integrationCredentialsProvider(ctx, region, options.integration) if err != nil { return aws.Config{}, trace.Wrap(err) } @@ -213,12 +213,12 @@ func getAWSConfigForRegion(ctx context.Context, region string, options awsOption slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region) } - cfg, err := awsAmbientConfigProvider(region, cred, options) + cfg, err := ambientConfigProvider(region, cred, options) return cfg, trace.Wrap(err) } -// getAWSConfigForRole returns an AWS config for the specified region and role. -func getAWSConfigForRole(ctx context.Context, region string, options awsOptions) (aws.Config, error) { +// getConfigForRole returns an AWS config for the specified region and role. +func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) { if err := options.checkAndSetDefaults(); err != nil { return aws.Config{}, trace.Wrap(err) } @@ -235,7 +235,7 @@ func getAWSConfigForRole(ctx context.Context, region string, options awsOptions) return aws.Config{}, trace.Wrap(err) } - opts := buildAWSConfigOptions(region, cred, options) + opts := buildConfigOptions(region, cred, options) cfg, err := config.LoadDefaultConfig(ctx, opts...) return cfg, trace.Wrap(err) } diff --git a/lib/cloud/aws/config/config_test.go b/lib/cloud/awsconfig/awsconfig_test.go similarity index 77% rename from lib/cloud/aws/config/config_test.go rename to lib/cloud/awsconfig/awsconfig_test.go index b6a0b867b7965..5c0ab10ed6abb 100644 --- a/lib/cloud/aws/config/config_test.go +++ b/lib/cloud/awsconfig/awsconfig_test.go @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -package config +package awsconfig import ( "context" @@ -33,14 +33,14 @@ func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials, return m.cred, nil } -func TestGetAWSConfigIntegration(t *testing.T) { +func TestGetConfigIntegration(t *testing.T) { t.Parallel() dummyIntegration := "integration-test" dummyRegion := "test-region-123" t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) { ctx := context.Background() - _, err := GetAWSConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) + _, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration)) require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err) require.ErrorContains(t, err, "missing aws integration credential provider") }) @@ -48,9 +48,9 @@ func TestGetAWSConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) { ctx := context.Background() - cfg, err := GetAWSConfig(ctx, dummyRegion, + cfg, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration), - WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { if region == dummyRegion && integration == dummyIntegration { return &mockCredentialProvider{ cred: aws.Credentials{ @@ -69,9 +69,9 @@ func TestGetAWSConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) { ctx := context.Background() - _, err := GetAWSConfig(ctx, dummyRegion, + _, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(""), - WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { require.Fail(t, "this function should not be called") return nil, nil })) @@ -81,9 +81,9 @@ func TestGetAWSConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) { ctx := context.Background() - _, err := GetAWSConfig(ctx, dummyRegion, + _, err := GetConfig(ctx, dummyRegion, WithAmbientCredentials(), - WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { require.Fail(t, "this function should not be called") return nil, nil })) @@ -93,8 +93,8 @@ func TestGetAWSConfigIntegration(t *testing.T) { t.Run("with an integration credential provider, but no credential source", func(t *testing.T) { ctx := context.Background() - _, err := GetAWSConfig(ctx, dummyRegion, - WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { + _, err := GetConfig(ctx, dummyRegion, + WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) { require.Fail(t, "this function should not be called") return nil, nil })) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 078ab234faca4..1f8b40af2cdbf 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -53,7 +53,7 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/cloud" - "github.com/gravitational/teleport/lib/cloud/aws/config" + "github.com/gravitational/teleport/lib/cloud/awsconfig" gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/integrations/awsoidc" @@ -119,7 +119,7 @@ type Config struct { // GetEC2Client gets an AWS EC2 client for the given region. GetEC2Client server.EC2ClientGetter // GetSSMClient gets an AWS SSM client for the given region. - GetSSMClient func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) + GetSSMClient func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) // IntegrationOnlyCredentials discards any Matcher that don't have an Integration. // When true, ambient credentials (used by the Cloud SDKs) are not used. IntegrationOnlyCredentials bool @@ -224,7 +224,7 @@ kubernetes matchers are present.`) c.CloudClients = cloudClients } if c.GetEC2Client == nil { - c.GetEC2Client = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + c.GetEC2Client = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) { cfg, err := c.getAWSConfig(ctx, region, opts...) if err != nil { return nil, trace.Wrap(err) @@ -233,7 +233,7 @@ kubernetes matchers are present.`) } } if c.GetSSMClient == nil { - c.GetSSMClient = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) { + c.GetSSMClient = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) { cfg, err := c.getAWSConfig(ctx, region, opts...) if err != nil { return nil, trace.Wrap(err) @@ -296,8 +296,8 @@ kubernetes matchers are present.`) return nil } -func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...config.AWSOptionsFn) (aws.Config, error) { - opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { +func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws.Config, error) { + opts = append(opts, awsconfig.WithIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) if err != nil { return nil, trace.Wrap(err) @@ -316,7 +316,7 @@ func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...config }) return cred, trace.Wrap(err) })) - cfg, err := config.GetAWSConfig(ctx, region, opts...) + cfg, err := awsconfig.GetConfig(ctx, region, opts...) return cfg, trace.Wrap(err) } @@ -1066,7 +1066,7 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err // TODO(gavin): support assume_role_arn for ec2. ssmClient, err := s.GetSSMClient(s.ctx, instances.Region, - config.WithCredentialsMaybeIntegration(instances.Integration), + awsconfig.WithCredentialsMaybeIntegration(instances.Integration), ) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 42516c9cf7491..7e3e90722d4bb 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -77,7 +77,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/cloud" - "github.com/gravitational/teleport/lib/cloud/aws/config" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp" @@ -772,10 +772,10 @@ func TestDiscoveryServer(t *testing.T) { } server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - GetEC2Client: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + GetEC2Client: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) { return ec2Client, nil }, - GetSSMClient: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) { + GetSSMClient: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) { return tc.ssm, nil }, ClusterFeatures: func() proto.Features { return proto.Features{} }, @@ -914,7 +914,7 @@ func TestDiscoveryServerConcurrency(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, authClient.Close()) }) - getEC2Client := func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + getEC2Client := func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) { return ec2Client, nil } diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go index d96e44075195e..2a7e928370091 100644 --- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go +++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go @@ -35,7 +35,7 @@ import ( usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" "github.com/gravitational/teleport/lib/cloud" - "github.com/gravitational/teleport/lib/cloud/aws/config" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/server" ) @@ -326,16 +326,16 @@ func (a *awsFetcher) getAWSOptions() []cloud.AWSOptionsFn { // getAWSV2Options returns a list of options to be used when // creating AWS clients with the v2 sdk. -func (a *awsFetcher) getAWSV2Options() []config.AWSOptionsFn { - opts := []config.AWSOptionsFn{ - config.WithCredentialsMaybeIntegration(a.Config.Integration), +func (a *awsFetcher) getAWSV2Options() []awsconfig.OptionsFn { + opts := []awsconfig.OptionsFn{ + awsconfig.WithCredentialsMaybeIntegration(a.Config.Integration), } if a.Config.AssumeRole != nil { - opts = append(opts, config.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID)) + opts = append(opts, awsconfig.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID)) } const maxRetries = 10 - opts = append(opts, config.WithRetryer(func() awsv2.Retryer { + opts = append(opts, awsconfig.WithRetryer(func() awsv2.Retryer { return retry.NewStandard(func(so *retry.StandardOptions) { so.MaxAttempts = maxRetries so.Backoff = retry.NewExponentialJitterBackoff(300 * time.Second) diff --git a/lib/srv/server/ec2_watcher.go b/lib/srv/server/ec2_watcher.go index 5c61ff178210f..79d20905408ac 100644 --- a/lib/srv/server/ec2_watcher.go +++ b/lib/srv/server/ec2_watcher.go @@ -32,7 +32,7 @@ import ( usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" "github.com/gravitational/teleport/api/types" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" - "github.com/gravitational/teleport/lib/cloud/aws/config" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/labels" ) @@ -189,7 +189,7 @@ func NewEC2Watcher(ctx context.Context, fetchersFn func() []Fetcher, missedRotat } // EC2ClientGetter gets an AWS EC2 client for the given region. -type EC2ClientGetter func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) +type EC2ClientGetter func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) // MatchersToEC2InstanceFetchers converts a list of AWS EC2 Matchers into a list of AWS EC2 Fetchers. func MatchersToEC2InstanceFetchers(ctx context.Context, matchers []types.AWSMatcher, getEC2Client EC2ClientGetter, discoveryConfigName string) ([]Fetcher, error) { @@ -198,7 +198,7 @@ func MatchersToEC2InstanceFetchers(ctx context.Context, matchers []types.AWSMatc for _, region := range matcher.Regions { // TODO(gavin): support assume_role_arn for ec2. ec2Client, err := getEC2Client(ctx, region, - config.WithCredentialsMaybeIntegration(matcher.Integration), + awsconfig.WithCredentialsMaybeIntegration(matcher.Integration), ) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go index f7c9c0a85458d..f62cbb737d5f4 100644 --- a/lib/srv/server/ec2_watcher_test.go +++ b/lib/srv/server/ec2_watcher_test.go @@ -32,7 +32,7 @@ import ( usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/lib/cloud/aws/config" + "github.com/gravitational/teleport/lib/cloud/awsconfig" ) type mockEC2Client struct { @@ -228,7 +228,7 @@ func TestEC2Watcher(t *testing.T) { const noDiscoveryConfig = "" fetchersFn := func() []Fetcher { - fetchers, err := MatchersToEC2InstanceFetchers(ctx, matchers, func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { + fetchers, err := MatchersToEC2InstanceFetchers(ctx, matchers, func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) { return client, nil }, noDiscoveryConfig) require.NoError(t, err)