Skip to content

Commit

Permalink
Rename lib/cloud/aws/config to lib/cloud/awsconfig
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar committed Dec 14, 2024
1 parent 821708e commit f4b3f6b
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 72 deletions.
84 changes: 42 additions & 42 deletions lib/cloud/aws/config/config.go → lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package config
package awsconfig

import (
"context"
Expand Down Expand Up @@ -43,13 +43,13 @@ const (
credentialsSourceIntegration
)

// AWSIntegrationSessionProvider defines a function that creates a credential provider from a region and an integration.
// IntegrationSessionProviderFunc defines a function that creates a credential provider from a region and an integration.
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type AWSIntegrationCredentialProvider func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)

// awsOptions is a struct of additional options for assuming an AWS role
// options is a struct of additional options for assuming an AWS role
// when construction an underlying AWS config.
type awsOptions struct {
type options struct {
// baseConfigis a config to use instead of the default config for an
// AWS region, which is used to enable role chaining.
baseConfig *aws.Config
Expand All @@ -61,15 +61,15 @@ type awsOptions struct {
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
integration string
// awsIntegrationCredentialsProvider is the integration credential provider to use.
awsIntegrationCredentialsProvider AWSIntegrationCredentialProvider
// integrationCredentialsProvider is the integration credential provider to use.
integrationCredentialsProvider IntegrationCredentialProviderFunc
// customRetryer is a custom retryer to use for the config.
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
}

func (a *awsOptions) checkAndSetDefaults() error {
func (a *options) checkAndSetDefaults() error {
switch a.credentialsSource {
case credentialsSourceAmbient:
if a.integration != "" {
Expand All @@ -86,36 +86,36 @@ func (a *awsOptions) checkAndSetDefaults() error {
return nil
}

// AWSOptionsFn is an option function for setting additional options
// OptionsFn is an option function for setting additional options
// when getting an AWS config.
type AWSOptionsFn func(*awsOptions)
type OptionsFn func(*options)

// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) AWSOptionsFn {
return func(options *awsOptions) {
func WithAssumeRole(roleARN, externalID string) OptionsFn {
return func(options *options) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
}
}

// WithRetryer sets a custom retryer for the config.
func WithRetryer(retryer func() aws.Retryer) AWSOptionsFn {
return func(options *awsOptions) {
func WithRetryer(retryer func() aws.Retryer) OptionsFn {
return func(options *options) {
options.customRetryer = retryer
}
}

// WithMaxRetries sets the maximum allowed value for the sdk to keep retrying.
func WithMaxRetries(maxRetries int) AWSOptionsFn {
return func(options *awsOptions) {
func WithMaxRetries(maxRetries int) OptionsFn {
return func(options *options) {
options.maxRetries = &maxRetries
}
}

// WithCredentialsMaybeIntegration sets the credential source to be
// - ambient if the integration is an empty string
// - integration, otherwise
func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn {
func WithCredentialsMaybeIntegration(integration string) OptionsFn {
if integration != "" {
return withIntegrationCredentials(integration)
}
Expand All @@ -125,36 +125,36 @@ func WithCredentialsMaybeIntegration(integration string) AWSOptionsFn {

// withIntegrationCredentials configures options with an Integration that must be used to fetch Credentials to assume a role.
// This prevents the usage of AWS environment credentials.
func withIntegrationCredentials(integration string) AWSOptionsFn {
return func(options *awsOptions) {
func withIntegrationCredentials(integration string) OptionsFn {
return func(options *options) {
options.credentialsSource = credentialsSourceIntegration
options.integration = integration
}
}

// WithAmbientCredentials configures options to use the ambient credentials.
func WithAmbientCredentials() AWSOptionsFn {
return func(options *awsOptions) {
func WithAmbientCredentials() OptionsFn {
return func(options *options) {
options.credentialsSource = credentialsSourceAmbient
}
}

// WithAWSIntegrationCredentialProvider sets the integration credential provider.
func WithAWSIntegrationCredentialProvider(cred AWSIntegrationCredentialProvider) AWSOptionsFn {
return func(options *awsOptions) {
options.awsIntegrationCredentialsProvider = cred
// WithIntegrationCredentialProvider sets the integration credential provider.
func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) OptionsFn {
return func(options *options) {
options.integrationCredentialsProvider = cred
}
}

// GetAWSConfig returns an AWS config for the specified region, optionally
// GetConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws.Config, error) {
var options awsOptions
func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) {
var options options
for _, opt := range opts {
opt(&options)
}
if options.baseConfig == nil {
cfg, err := getAWSConfigForRegion(ctx, region, options)
cfg, err := getConfigForRegion(ctx, region, options)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
Expand All @@ -163,17 +163,17 @@ func GetAWSConfig(ctx context.Context, region string, opts ...AWSOptionsFn) (aws
if options.assumeRoleARN == "" {
return *options.baseConfig, nil
}
return getAWSConfigForRole(ctx, region, options)
return getConfigForRole(ctx, region, options)
}

// awsAmbientConfigProvider loads a new config using the environment variables.
func awsAmbientConfigProvider(region string, cred aws.CredentialsProvider, options awsOptions) (aws.Config, error) {
opts := buildAWSConfigOptions(region, cred, options)
// ambientConfigProvider loads a new config using the environment variables.
func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) {
opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
return cfg, trace.Wrap(err)
}

func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options awsOptions) []func(*config.LoadOptions) error {
func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error {
opts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
Expand All @@ -191,34 +191,34 @@ func buildAWSConfigOptions(region string, cred aws.CredentialsProvider, options
return opts
}

// getAWSConfigForRegion returns AWS config for the specified region.
func getAWSConfigForRegion(ctx context.Context, region string, options awsOptions) (aws.Config, error) {
// getConfigForRegion returns AWS config for the specified region.
func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

var cred aws.CredentialsProvider
if options.credentialsSource == credentialsSourceIntegration {
if options.awsIntegrationCredentialsProvider == nil {
if options.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
var err error
cred, err = options.awsIntegrationCredentialsProvider(ctx, region, options.integration)
cred, err = options.integrationCredentialsProvider(ctx, region, options.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
}

cfg, err := awsAmbientConfigProvider(region, cred, options)
cfg, err := ambientConfigProvider(region, cred, options)
return cfg, trace.Wrap(err)
}

// getAWSConfigForRole returns an AWS config for the specified region and role.
func getAWSConfigForRole(ctx context.Context, region string, options awsOptions) (aws.Config, error) {
// getConfigForRole returns an AWS config for the specified region and role.
func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}
Expand All @@ -235,7 +235,7 @@ func getAWSConfigForRole(ctx context.Context, region string, options awsOptions)
return aws.Config{}, trace.Wrap(err)
}

opts := buildAWSConfigOptions(region, cred, options)
opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(ctx, opts...)
return cfg, trace.Wrap(err)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package config
package awsconfig

import (
"context"
Expand All @@ -33,24 +33,24 @@ func (m *mockCredentialProvider) Retrieve(ctx context.Context) (aws.Credentials,
return m.cred, nil
}

func TestGetAWSConfigIntegration(t *testing.T) {
func TestGetConfigIntegration(t *testing.T) {
t.Parallel()
dummyIntegration := "integration-test"
dummyRegion := "test-region-123"

t.Run("without an integration credential provider, must return missing credential provider error", func(t *testing.T) {
ctx := context.Background()
_, err := GetAWSConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
_, err := GetConfig(ctx, dummyRegion, WithCredentialsMaybeIntegration(dummyIntegration))
require.True(t, trace.IsBadParameter(err), "unexpected error: %v", err)
require.ErrorContains(t, err, "missing aws integration credential provider")
})

t.Run("with an integration credential provider, must return the credentials", func(t *testing.T) {
ctx := context.Background()

cfg, err := GetAWSConfig(ctx, dummyRegion,
cfg, err := GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(dummyIntegration),
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
if region == dummyRegion && integration == dummyIntegration {
return &mockCredentialProvider{
cred: aws.Credentials{
Expand All @@ -69,9 +69,9 @@ func TestGetAWSConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but using an empty integration falls back to ambient credentials", func(t *testing.T) {
ctx := context.Background()

_, err := GetAWSConfig(ctx, dummyRegion,
_, err := GetConfig(ctx, dummyRegion,
WithCredentialsMaybeIntegration(""),
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
}))
Expand All @@ -81,9 +81,9 @@ func TestGetAWSConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but using ambient credentials", func(t *testing.T) {
ctx := context.Background()

_, err := GetAWSConfig(ctx, dummyRegion,
_, err := GetConfig(ctx, dummyRegion,
WithAmbientCredentials(),
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
}))
Expand All @@ -93,8 +93,8 @@ func TestGetAWSConfigIntegration(t *testing.T) {
t.Run("with an integration credential provider, but no credential source", func(t *testing.T) {
ctx := context.Background()

_, err := GetAWSConfig(ctx, dummyRegion,
WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
_, err := GetConfig(ctx, dummyRegion,
WithIntegrationCredentialProvider(func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error) {
require.Fail(t, "this function should not be called")
return nil, nil
}))
Expand Down
16 changes: 8 additions & 8 deletions lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import (
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/aws/config"
"github.com/gravitational/teleport/lib/cloud/awsconfig"
gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/integrations/awsoidc"
Expand Down Expand Up @@ -116,7 +116,7 @@ type Config struct {
// GetEC2Client gets an AWS EC2 client for the given region.
GetEC2Client server.EC2ClientGetter
// GetSSMClient gets an AWS SSM client for the given region.
GetSSMClient func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error)
GetSSMClient func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error)
// IntegrationOnlyCredentials discards any Matcher that don't have an Integration.
// When true, ambient credentials (used by the Cloud SDKs) are not used.
IntegrationOnlyCredentials bool
Expand Down Expand Up @@ -221,7 +221,7 @@ kubernetes matchers are present.`)
c.CloudClients = cloudClients
}
if c.GetEC2Client == nil {
c.GetEC2Client = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
c.GetEC2Client = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) {
cfg, err := c.getAWSConfig(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -230,7 +230,7 @@ kubernetes matchers are present.`)
}
}
if c.GetSSMClient == nil {
c.GetSSMClient = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) {
c.GetSSMClient = func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) {
cfg, err := c.getAWSConfig(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -293,8 +293,8 @@ kubernetes matchers are present.`)
return nil
}

func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...config.AWSOptionsFn) (aws.Config, error) {
opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) {
func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (aws.Config, error) {
opts = append(opts, awsconfig.WithIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) {
integration, err := c.AccessPoint.GetIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -313,7 +313,7 @@ func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...config
})
return cred, trace.Wrap(err)
}))
cfg, err := config.GetAWSConfig(ctx, region, opts...)
cfg, err := awsconfig.GetConfig(ctx, region, opts...)
return cfg, trace.Wrap(err)
}

Expand Down Expand Up @@ -1037,7 +1037,7 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err
// TODO(gavin): support assume_role_arn for ec2.
ssmClient, err := s.GetSSMClient(s.ctx,
instances.Region,
config.WithCredentialsMaybeIntegration(instances.Integration),
awsconfig.WithCredentialsMaybeIntegration(instances.Integration),
)
if err != nil {
return trace.Wrap(err)
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/aws/config"
"github.com/gravitational/teleport/lib/cloud/awsconfig"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/cloud/gcp"
gcpimds "github.com/gravitational/teleport/lib/cloud/imds/gcp"
Expand Down Expand Up @@ -772,10 +772,10 @@ func TestDiscoveryServer(t *testing.T) {
}

server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{
GetEC2Client: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
GetEC2Client: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) {
return ec2Client, nil
},
GetSSMClient: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) {
GetSSMClient: func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (server.SSMClient, error) {
return tc.ssm, nil
},
ClusterFeatures: func() proto.Features { return proto.Features{} },
Expand Down Expand Up @@ -914,7 +914,7 @@ func TestDiscoveryServerConcurrency(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, authClient.Close()) })

getEC2Client := func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
getEC2Client := func(ctx context.Context, region string, opts ...awsconfig.OptionsFn) (ec2.DescribeInstancesAPIClient, error) {
return ec2Client, nil
}

Expand Down
Loading

0 comments on commit f4b3f6b

Please sign in to comment.