Skip to content

Commit

Permalink
Port ssm client to v2 sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
atburke committed Dec 4, 2024
1 parent db9505d commit ecf1735
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 307 deletions.
23 changes: 0 additions & 23 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ import (
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/gravitational/trace"
Expand Down Expand Up @@ -133,8 +131,6 @@ type AWSClients interface {
GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error)
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error)
// GetAWSSSMClient returns AWS SSM client for the specified region.
GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error)
// GetAWSEKSClient returns AWS EKS client for the specified region.
GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error)
// GetAWSKMSClient returns AWS KMS client for the specified region.
Expand Down Expand Up @@ -592,15 +588,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts
return sts.New(session), nil
}

// GetAWSSSMClient returns AWS SSM client for the specified region.
func (c *cloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return ssm.New(session), nil
}

// GetAWSEKSClient returns AWS EKS client for the specified region.
func (c *cloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
Expand Down Expand Up @@ -1021,7 +1008,6 @@ type TestCloudClients struct {
GCPGKE gcp.GKEClient
GCPProjects gcp.ProjectsClient
GCPInstances gcp.InstancesClient
SSM ssmiface.SSMAPI
InstanceMetadata imds.Client
EKS eksiface.EKSAPI
KMS kmsiface.KMSAPI
Expand Down Expand Up @@ -1191,15 +1177,6 @@ func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, o
return c.KMS, nil
}

// GetAWSSSMClient returns an AWS SSM client
func (c *TestCloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) {
_, err := c.GetAWSSession(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return c.SSM, nil
}

// GetGCPIAMClient returns GCP IAM client.
func (c *TestCloudClients) GetGCPIAMClient(ctx context.Context) (*gcpcredentials.IamCredentialsClient, error) {
return gcpcredentials.NewIamCredentialsClient(ctx,
Expand Down
77 changes: 46 additions & 31 deletions lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go-v2/service/ssm"
ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -116,6 +115,8 @@ type Config struct {
CloudClients cloud.Clients
// GetEC2Client gets an AWS EC2 client for the given region.
GetEC2Client server.EC2ClientGetter
// GetSSMClient gets an AWS SSM client for the given region.
GetSSMClient func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error)
// IntegrationOnlyCredentials discards any Matcher that don't have an Integration.
// When true, ambient credentials (used by the Cloud SDKs) are not used.
IntegrationOnlyCredentials bool
Expand Down Expand Up @@ -221,32 +222,22 @@ kubernetes matchers are present.`)
}
if c.GetEC2Client == nil {
c.GetEC2Client = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) {
opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (awsv2.CredentialsProvider, error) {
integration, err := c.AccessPoint.GetIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}
if integration.GetAWSOIDCIntegrationSpec() == nil {
return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName)
}
token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}
cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: region,
})
return cred, trace.Wrap(err)
}))
cfg, err := config.GetAWSConfig(ctx, region, opts...)
cfg, err := c.getAWSConfig(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return ec2.NewFromConfig(cfg), nil
}
}
if c.GetSSMClient == nil {
c.GetSSMClient = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) {
cfg, err := c.getAWSConfig(ctx, region, opts...)
if err != nil {
return nil, trace.Wrap(err)
}
return ssm.NewFromConfig(cfg), nil
}
}
if c.KubernetesClient == nil && len(c.Matchers.Kubernetes) > 0 {
cfg, err := rest.InClusterConfig()
if err != nil {
Expand Down Expand Up @@ -302,6 +293,30 @@ kubernetes matchers are present.`)
return nil
}

func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...config.AWSOptionsFn) (aws.Config, error) {
opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) {
integration, err := c.AccessPoint.GetIntegration(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}
if integration.GetAWSOIDCIntegrationSpec() == nil {
return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName)
}
token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName)
if err != nil {
return nil, trace.Wrap(err)
}
cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{
Token: token,
RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN,
Region: region,
})
return cred, trace.Wrap(err)
}))
cfg, err := config.GetAWSConfig(ctx, region, opts...)
return cfg, trace.Wrap(err)
}

// Server is a discovery server, used to discover cloud resources for
// inclusion in Teleport
type Server struct {
Expand Down Expand Up @@ -863,7 +878,7 @@ func genEC2InstancesLogStr(instances []server.EC2Instance) string {

func genAzureInstancesLogStr(instances []*armcompute.VirtualMachine) string {
return genInstancesLogStr(instances, func(i *armcompute.VirtualMachine) string {
return aws.StringValue(i.Name)
return aws.ToString(i.Name)
})
}

Expand Down Expand Up @@ -1019,9 +1034,9 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) {

func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) error {
// TODO(gavin): support assume_role_arn for ec2.
ec2Client, err := s.CloudClients.GetAWSSSMClient(s.ctx,
ssmClient, err := s.GetSSMClient(s.ctx,
instances.Region,
cloud.WithCredentialsMaybeIntegration(instances.Integration),
config.WithCredentialsMaybeIntegration(instances.Integration),
)
if err != nil {
return trace.Wrap(err)
Expand All @@ -1031,7 +1046,7 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err

req := server.SSMRunRequest{
DocumentName: instances.DocumentName,
SSM: ec2Client,
SSM: ssmClient,
Instances: instances.Instances,
Params: instances.Parameters,
Region: instances.Region,
Expand Down Expand Up @@ -1070,8 +1085,8 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err
}

func (s *Server) logHandleInstancesErr(err error) {
var aErr awserr.Error
if errors.As(err, &aErr) && aErr.Code() == ssm.ErrCodeInvalidInstanceId {
var instanceIDErr *ssmtypes.InvalidInstanceId
if errors.As(err, &instanceIDErr) {
const errorMessage = "SSM SendCommand failed with ErrCodeInvalidInstanceId. " +
"Make sure that the instances have AmazonSSMManagedInstanceCore policy assigned. " +
"Also check that SSM agent is running and registered with the SSM endpoint on that instance and try restarting or reinstalling it in case of issues. " +
Expand Down Expand Up @@ -1210,7 +1225,7 @@ outer:
for _, node := range nodes {
var vmID string
if inst.Properties != nil {
vmID = aws.StringValue(inst.Properties.VMID)
vmID = aws.ToString(inst.Properties.VMID)
}
match := types.MatchLabels(node, map[string]string{
types.SubscriptionIDLabel: instances.SubscriptionID,
Expand Down
Loading

0 comments on commit ecf1735

Please sign in to comment.