From ecf1735f13cffba7f8561cfd8b8a6b55a8afb1d0 Mon Sep 17 00:00:00 2001 From: Andrew Burke Date: Mon, 2 Dec 2024 15:08:27 -0800 Subject: [PATCH] Port ssm client to v2 sdk --- lib/cloud/clients.go | 23 -- lib/srv/discovery/discovery.go | 77 ++++--- lib/srv/discovery/discovery_test.go | 89 ++++---- lib/srv/server/ssm_install.go | 108 +++++---- lib/srv/server/ssm_install_test.go | 325 ++++++++++++++-------------- 5 files changed, 315 insertions(+), 307 deletions(-) diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 328ee76bcee0e..3cce15285593a 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -60,8 +60,6 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/gravitational/trace" @@ -133,8 +131,6 @@ type AWSClients interface { GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error) // GetAWSSTSClient returns AWS STS client for the specified region. GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error) - // GetAWSSSMClient returns AWS SSM client for the specified region. - GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) // GetAWSEKSClient returns AWS EKS client for the specified region. GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) // GetAWSKMSClient returns AWS KMS client for the specified region. @@ -592,15 +588,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts return sts.New(session), nil } -// GetAWSSSMClient returns AWS SSM client for the specified region. -func (c *cloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return ssm.New(session), nil -} - // GetAWSEKSClient returns AWS EKS client for the specified region. func (c *cloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -1021,7 +1008,6 @@ type TestCloudClients struct { GCPGKE gcp.GKEClient GCPProjects gcp.ProjectsClient GCPInstances gcp.InstancesClient - SSM ssmiface.SSMAPI InstanceMetadata imds.Client EKS eksiface.EKSAPI KMS kmsiface.KMSAPI @@ -1191,15 +1177,6 @@ func (c *TestCloudClients) GetAWSKMSClient(ctx context.Context, region string, o return c.KMS, nil } -// GetAWSSSMClient returns an AWS SSM client -func (c *TestCloudClients) GetAWSSSMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (ssmiface.SSMAPI, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return c.SSM, nil -} - // GetGCPIAMClient returns GCP IAM client. func (c *TestCloudClients) GetGCPIAMClient(ctx context.Context) (*gcpcredentials.IamCredentialsClient, error) { return gcpcredentials.NewIamCredentialsClient(ctx, diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 2f8c4d097b845..3138ccc052f52 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -30,12 +30,11 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ssm" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" @@ -116,6 +115,8 @@ type Config struct { CloudClients cloud.Clients // GetEC2Client gets an AWS EC2 client for the given region. GetEC2Client server.EC2ClientGetter + // GetSSMClient gets an AWS SSM client for the given region. + GetSSMClient func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) // IntegrationOnlyCredentials discards any Matcher that don't have an Integration. // When true, ambient credentials (used by the Cloud SDKs) are not used. IntegrationOnlyCredentials bool @@ -221,32 +222,22 @@ kubernetes matchers are present.`) } if c.GetEC2Client == nil { c.GetEC2Client = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { - opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (awsv2.CredentialsProvider, error) { - integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - if integration.GetAWSOIDCIntegrationSpec() == nil { - return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName) - } - token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{ - Token: token, - RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, - Region: region, - }) - return cred, trace.Wrap(err) - })) - cfg, err := config.GetAWSConfig(ctx, region, opts...) + cfg, err := c.getAWSConfig(ctx, region, opts...) if err != nil { return nil, trace.Wrap(err) } return ec2.NewFromConfig(cfg), nil } } + if c.GetSSMClient == nil { + c.GetSSMClient = func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) { + cfg, err := c.getAWSConfig(ctx, region, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + return ssm.NewFromConfig(cfg), nil + } + } if c.KubernetesClient == nil && len(c.Matchers.Kubernetes) > 0 { cfg, err := rest.InClusterConfig() if err != nil { @@ -302,6 +293,30 @@ kubernetes matchers are present.`) return nil } +func (c *Config) getAWSConfig(ctx context.Context, region string, opts ...config.AWSOptionsFn) (aws.Config, error) { + opts = append(opts, config.WithAWSIntegrationCredentialProvider(func(ctx context.Context, region, integrationName string) (aws.CredentialsProvider, error) { + integration, err := c.AccessPoint.GetIntegration(ctx, integrationName) + if err != nil { + return nil, trace.Wrap(err) + } + if integration.GetAWSOIDCIntegrationSpec() == nil { + return nil, trace.BadParameter("integration does not have aws oidc spec fields %q", integrationName) + } + token, err := c.AccessPoint.GenerateAWSOIDCToken(ctx, integrationName) + if err != nil { + return nil, trace.Wrap(err) + } + cred, err := awsoidc.NewAWSCredentialsProvider(ctx, &awsoidc.AWSClientRequest{ + Token: token, + RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, + Region: region, + }) + return cred, trace.Wrap(err) + })) + cfg, err := config.GetAWSConfig(ctx, region, opts...) + return cfg, trace.Wrap(err) +} + // Server is a discovery server, used to discover cloud resources for // inclusion in Teleport type Server struct { @@ -863,7 +878,7 @@ func genEC2InstancesLogStr(instances []server.EC2Instance) string { func genAzureInstancesLogStr(instances []*armcompute.VirtualMachine) string { return genInstancesLogStr(instances, func(i *armcompute.VirtualMachine) string { - return aws.StringValue(i.Name) + return aws.ToString(i.Name) }) } @@ -1019,9 +1034,9 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) { func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) error { // TODO(gavin): support assume_role_arn for ec2. - ec2Client, err := s.CloudClients.GetAWSSSMClient(s.ctx, + ssmClient, err := s.GetSSMClient(s.ctx, instances.Region, - cloud.WithCredentialsMaybeIntegration(instances.Integration), + config.WithCredentialsMaybeIntegration(instances.Integration), ) if err != nil { return trace.Wrap(err) @@ -1031,7 +1046,7 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err req := server.SSMRunRequest{ DocumentName: instances.DocumentName, - SSM: ec2Client, + SSM: ssmClient, Instances: instances.Instances, Params: instances.Parameters, Region: instances.Region, @@ -1070,8 +1085,8 @@ func (s *Server) handleEC2RemoteInstallation(instances *server.EC2Instances) err } func (s *Server) logHandleInstancesErr(err error) { - var aErr awserr.Error - if errors.As(err, &aErr) && aErr.Code() == ssm.ErrCodeInvalidInstanceId { + var instanceIDErr *ssmtypes.InvalidInstanceId + if errors.As(err, &instanceIDErr) { const errorMessage = "SSM SendCommand failed with ErrCodeInvalidInstanceId. " + "Make sure that the instances have AmazonSSMManagedInstanceCore policy assigned. " + "Also check that SSM agent is running and registered with the SSM endpoint on that instance and try restarting or reinstalling it in case of issues. " + @@ -1210,7 +1225,7 @@ outer: for _, node := range nodes { var vmID string if inst.Properties != nil { - vmID = aws.StringValue(inst.Properties.VMID) + vmID = aws.ToString(inst.Properties.VMID) } match := types.MatchLabels(node, map[string]string{ types.SubscriptionIDLabel: instances.SubscriptionID, diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 8c9c08b0aff95..7f569caee654b 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -39,15 +39,14 @@ import ( awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshift" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" @@ -98,26 +97,19 @@ func TestMain(m *testing.M) { } type mockSSMClient struct { - ssmiface.SSMAPI + server.SSMClient commandOutput *ssm.SendCommandOutput invokeOutput *ssm.GetCommandInvocationOutput } -func (sm *mockSSMClient) SendCommandWithContext(_ context.Context, input *ssm.SendCommandInput, _ ...request.Option) (*ssm.SendCommandOutput, error) { +func (sm *mockSSMClient) SendCommand(_ context.Context, input *ssm.SendCommandInput, _ ...func(*ssm.Options)) (*ssm.SendCommandOutput, error) { return sm.commandOutput, nil } -func (sm *mockSSMClient) GetCommandInvocationWithContext(_ context.Context, input *ssm.GetCommandInvocationInput, _ ...request.Option) (*ssm.GetCommandInvocationOutput, error) { +func (sm *mockSSMClient) GetCommandInvocation(_ context.Context, input *ssm.GetCommandInvocationInput, _ ...func(*ssm.Options)) (*ssm.GetCommandInvocationOutput, error) { return sm.invokeOutput, nil } -func (sm *mockSSMClient) WaitUntilCommandExecutedWithContext(aws.Context, *ssm.GetCommandInvocationInput, ...request.WaiterOption) error { - if aws.StringValue(sm.commandOutput.Command.Status) == ssm.CommandStatusFailed { - return awserr.New(request.WaiterResourceNotReadyErrorCode, "err", nil) - } - return nil -} - type mockEmitter struct { eventHandler func(*testing.T, events.AuditEvent, *Server) server *Server @@ -325,13 +317,13 @@ func TestDiscoveryServer(t *testing.T) { }, ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, emitter: &mockEmitter{ @@ -347,7 +339,7 @@ func TestDiscoveryServer(t *testing.T) { InstanceID: "instance-id-1", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), }, ae) }, }, @@ -383,13 +375,13 @@ func TestDiscoveryServer(t *testing.T) { }, ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, staticMatchers: defaultStaticMatcher, @@ -424,13 +416,13 @@ func TestDiscoveryServer(t *testing.T) { }, ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, emitter: &mockEmitter{}, @@ -443,13 +435,13 @@ func TestDiscoveryServer(t *testing.T) { foundEC2Instances: genEC2Instances(58), ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, emitter: &mockEmitter{}, @@ -473,13 +465,13 @@ func TestDiscoveryServer(t *testing.T) { }, ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, emitter: &mockEmitter{ @@ -495,7 +487,7 @@ func TestDiscoveryServer(t *testing.T) { InstanceID: "instance-id-1", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), }, ae) }, }, @@ -520,13 +512,13 @@ func TestDiscoveryServer(t *testing.T) { }, ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, emitter: &mockEmitter{ @@ -542,7 +534,7 @@ func TestDiscoveryServer(t *testing.T) { InstanceID: "instance-id-1", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), }, ae) }, }, @@ -582,13 +574,13 @@ func TestDiscoveryServer(t *testing.T) { }, ssm: &mockSSMClient{ commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), + Command: &ssmtypes.Command{ + CommandId: awsv2.String("command-id-1"), }, }, invokeOutput: &ssm.GetCommandInvocationOutput{ - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, ssmRunError: trace.BadParameter("ssm run failed"), @@ -605,7 +597,7 @@ func TestDiscoveryServer(t *testing.T) { InstanceID: "instance-id-1", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), }, ae) }, }, @@ -640,9 +632,6 @@ func TestDiscoveryServer(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - testCloudClients := &cloud.TestCloudClients{ - SSM: tc.ssm, - } ec2Client := &mockEC2Client{output: &ec2.DescribeInstancesOutput{ Reservations: []ec2types.Reservation{ { @@ -691,10 +680,12 @@ func TestDiscoveryServer(t *testing.T) { } server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - CloudClients: testCloudClients, GetEC2Client: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (ec2.DescribeInstancesAPIClient, error) { return ec2Client, nil }, + GetSSMClient: func(ctx context.Context, region string, opts ...config.AWSOptionsFn) (server.SSMClient, error) { + return tc.ssm, nil + }, ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), diff --git a/lib/srv/server/ssm_install.go b/lib/srv/server/ssm_install.go index 3c23f672884a3..51943f4400058 100644 --- a/lib/srv/server/ssm_install.go +++ b/lib/srv/server/ssm_install.go @@ -26,12 +26,11 @@ import ( "maps" "slices" "strings" + "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/gravitational/trace" "golang.org/x/sync/errgroup" @@ -42,6 +41,23 @@ import ( libevents "github.com/gravitational/teleport/lib/events" ) +// waiterTimedOutErrorMessage is the error message returned by the AWS SDK command +// executed waiter when it times out. +const waiterTimedOutErrorMessage = "exceeded max wait time for CommandExecuted waiter" + +// SSMClient is the subset of the AWS SSM API required for EC2 discovery. +type SSMClient interface { + ssm.DescribeInstanceInformationAPIClient + ssm.GetCommandInvocationAPIClient + ssm.ListCommandInvocationsAPIClient + // SendCommand runs commands on one or more managed nodes. + SendCommand(ctx context.Context, params *ssm.SendCommandInput, optFns ...func(*ssm.Options)) (*ssm.SendCommandOutput, error) +} + +type commandWaiter interface { + Wait(ctx context.Context, params *ssm.GetCommandInvocationInput, maxWaitDur time.Duration, optFns ...func(*ssm.CommandExecutedWaiterOptions)) error +} + // SSMInstallerConfig represents configuration for an SSM install // script executor. type SSMInstallerConfig struct { @@ -50,6 +66,9 @@ type SSMInstallerConfig struct { // Logger is used to log messages. // Optional. A logger is created if one not supplied. Logger *slog.Logger + // getWaiter replaces the default command waiter for a given SSM client. + // Used in tests. + getWaiter func(SSMClient) commandWaiter } // SSMInstallationResult contains the result of trying to install teleport @@ -84,7 +103,7 @@ type SSMRunRequest struct { // DocumentName is the name of the SSM document to run. DocumentName string // SSM is an SSM API client. - SSM ssmiface.SSMAPI + SSM SSMClient // Instances is the list of instances that will have the SSM // document executed on them. Instances []EC2Instance @@ -124,6 +143,12 @@ func (c *SSMInstallerConfig) checkAndSetDefaults() error { c.Logger = slog.Default().With(teleport.ComponentKey, "ssminstaller") } + if c.getWaiter == nil { + c.getWaiter = func(s SSMClient) commandWaiter { + return ssm.NewCommandExecutedWaiter(s) + } + } + return nil } @@ -144,9 +169,9 @@ func (si *SSMInstaller) Run(ctx context.Context, req SSMRunRequest) error { instances[inst.InstanceID] = inst.InstanceName } - params := make(map[string][]*string) + params := make(map[string][]string) for k, v := range req.Params { - params[k] = []*string{aws.String(v)} + params[k] = []string{v} } validInstances := instances @@ -175,9 +200,9 @@ func (si *SSMInstaller) Run(ctx context.Context, req SSMRunRequest) error { } validInstanceIDs := instanceIDsFrom(validInstances) - output, err := req.SSM.SendCommandWithContext(ctx, &ssm.SendCommandInput{ + output, err := req.SSM.SendCommand(ctx, &ssm.SendCommandInput{ DocumentName: aws.String(req.DocumentName), - InstanceIds: aws.StringSlice(validInstanceIDs), + InstanceIds: validInstanceIDs, Parameters: params, }) if err != nil { @@ -194,9 +219,9 @@ func (si *SSMInstaller) Run(ctx context.Context, req SSMRunRequest) error { // As a best effort, we try to call ssm.SendCommand again but this time without the "sshdConfigPath" param // We must not remove the Param "sshdConfigPath" beforehand because customers might be using custom SSM Documents for ec2 auto discovery. delete(params, ParamSSHDConfigPath) - output, err = req.SSM.SendCommandWithContext(ctx, &ssm.SendCommandInput{ + output, err = req.SSM.SendCommand(ctx, &ssm.SendCommandInput{ DocumentName: aws.String(req.DocumentName), - InstanceIds: aws.StringSlice(validInstanceIDs), + InstanceIds: validInstanceIDs, Parameters: params, }) if err != nil { @@ -296,20 +321,20 @@ func (si *SSMInstaller) describeSSMAgentState(ctx context.Context, req SSMRunReq } instanceIDs := instanceIDsFrom(allInstances) - ssmInstancesInfo, err := req.SSM.DescribeInstanceInformationWithContext(ctx, &ssm.DescribeInstanceInformationInput{ - Filters: []*ssm.InstanceInformationStringFilter{ - {Key: aws.String(ssm.InstanceInformationFilterKeyInstanceIds), Values: aws.StringSlice(instanceIDs)}, + ssmInstancesInfo, err := req.SSM.DescribeInstanceInformation(ctx, &ssm.DescribeInstanceInformationInput{ + Filters: []ssmtypes.InstanceInformationStringFilter{ + {Key: aws.String(string(ssmtypes.InstanceInformationFilterKeyInstanceIds)), Values: instanceIDs}, }, - MaxResults: aws.Int64(awsEC2APIChunkSize), + MaxResults: aws.Int32(awsEC2APIChunkSize), }) if err != nil { return nil, trace.Wrap(awslib.ConvertRequestFailureError(err)) } - instanceStateByInstanceID := make(map[string]*ssm.InstanceInformation, len(ssmInstancesInfo.InstanceInformationList)) + instanceStateByInstanceID := make(map[string]ssmtypes.InstanceInformation, len(ssmInstancesInfo.InstanceInformationList)) for _, instanceState := range ssmInstancesInfo.InstanceInformationList { // instanceState.InstanceId always has the InstanceID value according to AWS Docs. - instanceStateByInstanceID[aws.StringValue(instanceState.InstanceId)] = instanceState + instanceStateByInstanceID[aws.ToString(instanceState.InstanceId)] = instanceState } for instanceID, instanceName := range allInstances { @@ -319,12 +344,12 @@ func (si *SSMInstaller) describeSSMAgentState(ctx context.Context, req SSMRunReq continue } - if aws.StringValue(instanceState.PingStatus) == ssm.PingStatusConnectionLost { + if instanceState.PingStatus == ssmtypes.PingStatusConnectionLost { ret.connectionLost[instanceID] = instanceName continue } - if aws.StringValue(instanceState.PlatformType) != ssm.PlatformTypeLinux { + if instanceState.PlatformType != ssmtypes.PlatformTypeLinux { ret.unsupportedOS[instanceID] = instanceName continue } @@ -336,23 +361,22 @@ func (si *SSMInstaller) describeSSMAgentState(ctx context.Context, req SSMRunReq } // skipAWSWaitErr is used to ignore the error returned from -// WaitUntilCommandExecutedWithContext if it is a resource not ready -// code as this can represent one of several different errors which +// Wait if it times out, as this can represent one of several different errors which // are handled by checking the command invocation after calling this // to get more information about the error. func skipAWSWaitErr(err error) error { - var aErr awserr.Error - if errors.As(err, &aErr) && aErr.Code() == request.WaiterResourceNotReadyErrorCode { + if err != nil && err.Error() == waiterTimedOutErrorMessage { return nil } return trace.Wrap(err) } func (si *SSMInstaller) checkCommand(ctx context.Context, req SSMRunRequest, commandID, instanceID *string, instanceName string) error { - err := req.SSM.WaitUntilCommandExecutedWithContext(ctx, &ssm.GetCommandInvocationInput{ + err := si.getWaiter(req.SSM).Wait(ctx, &ssm.GetCommandInvocationInput{ CommandId: commandID, InstanceId: instanceID, - }) + // 100 seconds to match v1 sdk waiter default. + }, 100*time.Second) if err := skipAWSWaitErr(err); err != nil { return trace.Wrap(err) @@ -378,7 +402,7 @@ func (si *SSMInstaller) checkCommand(ctx context.Context, req SSMRunRequest, com for i, step := range invocationSteps { stepResultEvent, err := si.getCommandStepStatusEvent(ctx, step, req, commandID, instanceID) if err != nil { - var invalidPluginNameErr *ssm.InvalidPluginName + var invalidPluginNameErr *ssmtypes.InvalidPluginName if errors.As(err, &invalidPluginNameErr) { // If using a custom SSM Document and the client does not have access to ssm:ListCommandInvocations // the list of invocationSteps (ie plugin name) might be wrong. @@ -422,10 +446,10 @@ func (si *SSMInstaller) checkCommand(ctx context.Context, req SSMRunRequest, com func (si *SSMInstaller) getInvocationSteps(ctx context.Context, req SSMRunRequest, commandID, instanceID *string) ([]string, error) { // ssm:ListCommandInvocations is used to list the actual steps because users might be using a custom SSM Document. - listCommandInvocationResp, err := req.SSM.ListCommandInvocationsWithContext(ctx, &ssm.ListCommandInvocationsInput{ + listCommandInvocationResp, err := req.SSM.ListCommandInvocations(ctx, &ssm.ListCommandInvocationsInput{ CommandId: commandID, InstanceId: instanceID, - Details: aws.Bool(true), + Details: true, }) if err != nil { return nil, trace.Wrap(awslib.ConvertRequestFailureError(err)) @@ -436,8 +460,8 @@ func (si *SSMInstaller) getInvocationSteps(ctx context.Context, req SSMRunReques if len(listCommandInvocationResp.CommandInvocations) == 0 { si.Logger.WarnContext(ctx, "No command invocation was found.", - "command_id", aws.StringValue(commandID), - "instance_id", aws.StringValue(instanceID), + "command_id", aws.ToString(commandID), + "instance_id", aws.ToString(instanceID), ) return nil, trace.BadParameter("no command invocation was found") } @@ -445,7 +469,7 @@ func (si *SSMInstaller) getInvocationSteps(ctx context.Context, req SSMRunReques documentSteps := make([]string, 0, len(commandInvocation.CommandPlugins)) for _, step := range commandInvocation.CommandPlugins { - documentSteps = append(documentSteps, aws.StringValue(step.Name)) + documentSteps = append(documentSteps, aws.ToString(step.Name)) } return documentSteps, nil } @@ -458,16 +482,16 @@ func (si *SSMInstaller) getCommandStepStatusEvent(ctx context.Context, step stri if step != "" { getCommandInvocationReq.PluginName = aws.String(step) } - stepResult, err := req.SSM.GetCommandInvocationWithContext(ctx, getCommandInvocationReq) + stepResult, err := req.SSM.GetCommandInvocation(ctx, getCommandInvocationReq) if err != nil { return nil, trace.Wrap(err) } - status := aws.StringValue(stepResult.Status) - exitCode := aws.Int64Value(stepResult.ResponseCode) + status := stepResult.Status + exitCode := int64(stepResult.ResponseCode) eventCode := libevents.SSMRunSuccessCode - if status != ssm.CommandStatusSuccess { + if status != ssmtypes.CommandInvocationStatusSuccess { eventCode = libevents.SSMRunFailCode if exitCode == 0 { exitCode = -1 @@ -479,7 +503,7 @@ func (si *SSMInstaller) getCommandStepStatusEvent(ctx context.Context, step stri // Example: // https://eu-west-2.console.aws.amazon.com/systems-manager/run-command/3cb11aaa-11aa-1111-aaaa-2188108225de/i-0775091aa11111111 invocationURL := fmt.Sprintf("https://%s.console.aws.amazon.com/systems-manager/run-command/%s/%s", - req.Region, aws.StringValue(commandID), aws.StringValue(instanceID), + req.Region, aws.ToString(commandID), aws.ToString(instanceID), ) return &apievents.SSMRun{ @@ -487,14 +511,14 @@ func (si *SSMInstaller) getCommandStepStatusEvent(ctx context.Context, step stri Type: libevents.SSMRunEvent, Code: eventCode, }, - CommandID: aws.StringValue(commandID), - InstanceID: aws.StringValue(instanceID), + CommandID: aws.ToString(commandID), + InstanceID: aws.ToString(instanceID), AccountID: req.AccountID, Region: req.Region, ExitCode: exitCode, - Status: status, - StandardOutput: aws.StringValue(stepResult.StandardOutputContent), - StandardError: aws.StringValue(stepResult.StandardErrorContent), + Status: string(status), + StandardOutput: aws.ToString(stepResult.StandardOutputContent), + StandardError: aws.ToString(stepResult.StandardErrorContent), InvocationURL: invocationURL, }, nil } diff --git a/lib/srv/server/ssm_install_test.go b/lib/srv/server/ssm_install_test.go index c56b286258527..102bcbf5a4475 100644 --- a/lib/srv/server/ssm_install_test.go +++ b/lib/srv/server/ssm_install_test.go @@ -21,15 +21,13 @@ package server import ( "context" "fmt" - "net/http" "testing" + "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" - "github.com/google/uuid" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types/events" @@ -37,7 +35,7 @@ import ( ) type mockSSMClient struct { - ssmiface.SSMAPI + SSMClient commandOutput *ssm.SendCommandOutput commandInvokeOutput map[string]*ssm.GetCommandInvocationOutput describeOutput *ssm.DescribeInstanceInformationOutput @@ -46,37 +44,37 @@ type mockSSMClient struct { const docWithoutSSHDConfigPathParam = "ssmdocument-without-sshdConfigPath-param" -func (sm *mockSSMClient) SendCommandWithContext(_ context.Context, input *ssm.SendCommandInput, _ ...request.Option) (*ssm.SendCommandOutput, error) { - if _, hasExtraParam := input.Parameters["sshdConfigPath"]; hasExtraParam && aws.StringValue(input.DocumentName) == docWithoutSSHDConfigPathParam { +func (sm *mockSSMClient) SendCommand(_ context.Context, input *ssm.SendCommandInput, _ ...func(*ssm.Options)) (*ssm.SendCommandOutput, error) { + if _, hasExtraParam := input.Parameters["sshdConfigPath"]; hasExtraParam && aws.ToString(input.DocumentName) == docWithoutSSHDConfigPathParam { return nil, fmt.Errorf("InvalidParameters: document %s does not support parameters", docWithoutSSHDConfigPathParam) } return sm.commandOutput, nil } -func (sm *mockSSMClient) GetCommandInvocationWithContext(_ context.Context, input *ssm.GetCommandInvocationInput, _ ...request.Option) (*ssm.GetCommandInvocationOutput, error) { - if stepResult, found := sm.commandInvokeOutput[aws.StringValue(input.PluginName)]; found { +func (sm *mockSSMClient) GetCommandInvocation(_ context.Context, input *ssm.GetCommandInvocationInput, _ ...func(*ssm.Options)) (*ssm.GetCommandInvocationOutput, error) { + if stepResult, found := sm.commandInvokeOutput[aws.ToString(input.PluginName)]; found { return stepResult, nil } - return nil, &ssm.InvalidPluginName{} + return nil, &ssmtypes.InvalidPluginName{} } -func (sm *mockSSMClient) DescribeInstanceInformationWithContext(_ context.Context, input *ssm.DescribeInstanceInformationInput, _ ...request.Option) (*ssm.DescribeInstanceInformationOutput, error) { +func (sm *mockSSMClient) DescribeInstanceInformation(_ context.Context, input *ssm.DescribeInstanceInformationInput, _ ...func(*ssm.Options)) (*ssm.DescribeInstanceInformationOutput, error) { if sm.describeOutput == nil { - return nil, awserr.NewRequestFailure(awserr.New("AccessDeniedException", "message", nil), http.StatusBadRequest, uuid.NewString()) + return nil, trace.AccessDenied("") } return sm.describeOutput, nil } -func (sm *mockSSMClient) ListCommandInvocationsWithContext(aws.Context, *ssm.ListCommandInvocationsInput, ...request.Option) (*ssm.ListCommandInvocationsOutput, error) { +func (sm *mockSSMClient) ListCommandInvocations(_ context.Context, input *ssm.ListCommandInvocationsInput, _ ...func(*ssm.Options)) (*ssm.ListCommandInvocationsOutput, error) { if sm.listCommandInvocations == nil { - return nil, awserr.NewRequestFailure(awserr.New("AccessDeniedException", "message", nil), http.StatusBadRequest, uuid.NewString()) + return nil, trace.AccessDenied("") } return sm.listCommandInvocations, nil } -func (sm *mockSSMClient) WaitUntilCommandExecutedWithContext(aws.Context, *ssm.GetCommandInvocationInput, ...request.WaiterOption) error { - if aws.StringValue(sm.commandOutput.Command.Status) == ssm.CommandStatusFailed { - return awserr.New(request.WaiterResourceNotReadyErrorCode, "err", nil) +func (sm *mockSSMClient) Wait(ctx context.Context, params *ssm.GetCommandInvocationInput, maxWaitDur time.Duration, optFns ...func(*ssm.CommandExecutedWaiterOptions)) error { + if sm.commandOutput.Command.Status == ssmtypes.CommandStatusFailed { + return trace.Errorf(waiterTimedOutErrorMessage) } return nil } @@ -94,6 +92,7 @@ func TestSSMInstaller(t *testing.T) { document := "ssmdocument" for _, tc := range []struct { + client *mockSSMClient req SSMRunRequest expectedInstallations []*SSMInstallationResult name string @@ -108,25 +107,25 @@ func TestSSMInstaller(t *testing.T) { Params: map[string]string{"token": "abcdefg"}, IntegrationName: "aws-integration", DiscoveryConfig: "dc001", - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "downloadContent": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - }, - "runShellScript": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "downloadContent": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + }, + "runShellScript": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, - Region: "eu-central-1", - AccountID: "account-id", }, expectedInstallations: []*SSMInstallationResult{{ IntegrationName: "aws-integration", @@ -141,7 +140,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", }, IssueType: "ec2-ssm-script-failure", @@ -157,25 +156,25 @@ func TestSSMInstaller(t *testing.T) { }, DocumentName: docWithoutSSHDConfigPathParam, Params: map[string]string{"sshdConfigPath": "abcdefg"}, - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "downloadContent": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - }, - "runShellScript": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "downloadContent": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + }, + "runShellScript": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, - Region: "eu-central-1", - AccountID: "account-id", }, expectedInstallations: []*SSMInstallationResult{{ SSMRunEvent: &events.SSMRun{ @@ -188,7 +187,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", }, IssueType: "ec2-ssm-script-failure", @@ -204,23 +203,23 @@ func TestSSMInstaller(t *testing.T) { }, IntegrationName: "aws-1", Params: map[string]string{"token": "abcdefg"}, - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "downloadContent": { - Status: aws.String(ssm.CommandStatusFailed), - ResponseCode: aws.Int64(1), - StandardErrorContent: aws.String("timeout error"), - StandardOutputContent: aws.String(""), - }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "downloadContent": { + Status: ssmtypes.CommandInvocationStatusFailed, + ResponseCode: 1, + StandardErrorContent: aws.String("timeout error"), + StandardOutputContent: aws.String(""), }, }, - Region: "eu-central-1", - AccountID: "account-id", }, expectedInstallations: []*SSMInstallationResult{{ IntegrationName: "aws-1", @@ -234,7 +233,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 1, - Status: ssm.CommandStatusFailed, + Status: string(ssmtypes.CommandInvocationStatusFailed), StandardOutput: "", StandardError: "timeout error", InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", @@ -250,30 +249,30 @@ func TestSSMInstaller(t *testing.T) { Instances: []EC2Instance{ {InstanceID: "instance-id-1"}, }, - Params: map[string]string{"token": "abcdefg"}, - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Params: map[string]string{"token": "abcdefg"}, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "downloadContent": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - StandardErrorContent: aws.String("no error"), - StandardOutputContent: aws.String(""), - }, - "runShellScript": { - Status: aws.String(ssm.CommandStatusFailed), - ResponseCode: aws.Int64(1), - StandardErrorContent: aws.String("timeout error"), - StandardOutputContent: aws.String(""), - }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "downloadContent": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + StandardErrorContent: aws.String("no error"), + StandardOutputContent: aws.String(""), + }, + "runShellScript": { + Status: ssmtypes.CommandInvocationStatusFailed, + ResponseCode: 1, + StandardErrorContent: aws.String("timeout error"), + StandardOutputContent: aws.String(""), }, }, - Region: "eu-central-1", - AccountID: "account-id", }, expectedInstallations: []*SSMInstallationResult{{ SSMRunEvent: &events.SSMRun{ @@ -286,7 +285,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 1, - Status: ssm.CommandStatusFailed, + Status: string(ssmtypes.CommandInvocationStatusFailed), StandardOutput: "", StandardError: "timeout error", InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", @@ -306,44 +305,44 @@ func TestSSMInstaller(t *testing.T) { }, DocumentName: document, Params: map[string]string{"token": "abcdefg"}, - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), + }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "downloadContent": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + }, + "runShellScript": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "downloadContent": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + }, + describeOutput: &ssm.DescribeInstanceInformationOutput{ + InstanceInformationList: []ssmtypes.InstanceInformation{ + { + InstanceId: aws.String("instance-id-1"), + PingStatus: ssmtypes.PingStatusOnline, + PlatformType: ssmtypes.PlatformTypeLinux, }, - "runShellScript": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), + { + InstanceId: aws.String("instance-id-2"), + PingStatus: ssmtypes.PingStatusConnectionLost, + PlatformType: ssmtypes.PlatformTypeLinux, }, - }, - describeOutput: &ssm.DescribeInstanceInformationOutput{ - InstanceInformationList: []*ssm.InstanceInformation{ - { - InstanceId: aws.String("instance-id-1"), - PingStatus: aws.String("Online"), - PlatformType: aws.String("Linux"), - }, - { - InstanceId: aws.String("instance-id-2"), - PingStatus: aws.String("ConnectionLost"), - PlatformType: aws.String("Linux"), - }, - { - InstanceId: aws.String("instance-id-3"), - PingStatus: aws.String("Online"), - PlatformType: aws.String("Windows"), - }, + { + InstanceId: aws.String("instance-id-3"), + PingStatus: ssmtypes.PingStatusOnline, + PlatformType: ssmtypes.PlatformTypeWindows, }, }, }, - Region: "eu-central-1", - AccountID: "account-id", }, expectedInstallations: []*SSMInstallationResult{ { @@ -357,7 +356,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", }, IssueType: "ec2-ssm-script-failure", @@ -421,34 +420,34 @@ func TestSSMInstaller(t *testing.T) { }, DocumentName: document, Params: map[string]string{"token": "abcdefg"}, - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "downloadContentCustom": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - }, - "runShellScriptCustom": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - StandardOutputContent: aws.String("custom output"), - }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "downloadContentCustom": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, - listCommandInvocations: &ssm.ListCommandInvocationsOutput{ - CommandInvocations: []*ssm.CommandInvocation{{ - CommandPlugins: []*ssm.CommandPlugin{ - {Name: aws.String("downloadContentCustom")}, - {Name: aws.String("runShellScriptCustom")}, - }, - }}, + "runShellScriptCustom": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + StandardOutputContent: aws.String("custom output"), }, }, - Region: "eu-central-1", - AccountID: "account-id", + listCommandInvocations: &ssm.ListCommandInvocationsOutput{ + CommandInvocations: []ssmtypes.CommandInvocation{{ + CommandPlugins: []ssmtypes.CommandPlugin{ + {Name: aws.String("downloadContentCustom")}, + {Name: aws.String("runShellScriptCustom")}, + }, + }}, + }, }, expectedInstallations: []*SSMInstallationResult{{ SSMRunEvent: &events.SSMRun{ @@ -461,7 +460,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), StandardOutput: "custom output", InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", }, @@ -477,21 +476,21 @@ func TestSSMInstaller(t *testing.T) { }, DocumentName: document, Params: map[string]string{"token": "abcdefg"}, - SSM: &mockSSMClient{ - commandOutput: &ssm.SendCommandOutput{ - Command: &ssm.Command{ - CommandId: aws.String("command-id-1"), - }, + Region: "eu-central-1", + AccountID: "account-id", + }, + client: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), }, - commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ - "": { - Status: aws.String(ssm.CommandStatusSuccess), - ResponseCode: aws.Int64(0), - }, + }, + commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{ + "": { + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, }, }, - Region: "eu-central-1", - AccountID: "account-id", }, expectedInstallations: []*SSMInstallationResult{{ SSMRunEvent: &events.SSMRun{ @@ -504,7 +503,7 @@ func TestSSMInstaller(t *testing.T) { AccountID: "account-id", Region: "eu-central-1", ExitCode: 0, - Status: ssm.CommandStatusSuccess, + Status: string(ssmtypes.CommandInvocationStatusSuccess), InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1", }, IssueType: "ec2-ssm-script-failure", @@ -516,9 +515,11 @@ func TestSSMInstaller(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() + tc.req.SSM = tc.client installationResultsCollector := &mockInstallationResults{} inst, err := NewSSMInstaller(SSMInstallerConfig{ ReportSSMInstallationResultFunc: installationResultsCollector.ReportInstallationResult, + getWaiter: func(s SSMClient) commandWaiter { return tc.client }, }) require.NoError(t, err)