From 7260d057d2e5cb2107f40f1c76cfe33ecd3ed882 Mon Sep 17 00:00:00 2001 From: Ed Smith Date: Thu, 10 Oct 2024 12:06:53 +0100 Subject: [PATCH] Closes #49 - Upgrade to AWS SDK v2 (#50) * feat: upgrade AWS SDK to v2 Signed-off-by: Theo Bob Massard * tests: adapt tests to SDK v2 Signed-off-by: Theo Bob Massard * feat: add ECSClient interface and replace usage of concrete ecs.Client type in App * tests: update tests to use ECSClient interface and update mocks * feat: create EC2Client interface * tests: fix remaining tests * feat: swap use of ecs.client for ECSClient in getContainerPort --------- Signed-off-by: Theo Bob Massard Co-authored-by: Theo Bob Massard --- go.mod | 17 +- go.sum | 35 +++- internal/app.go | 122 +++++++----- internal/app_test.go | 404 +++++++++++++++++--------------------- internal/command.go | 13 +- internal/command_test.go | 60 +++--- internal/forward.go | 28 ++- internal/internal.go | 113 ++++++----- internal/internal_test.go | 128 +++++++----- internal/select.go | 28 +-- 10 files changed, 515 insertions(+), 433 deletions(-) diff --git a/go.mod b/go.mod index 918bab4..f6f6edf 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,11 @@ go 1.18 require ( github.com/AlecAivazis/survey/v2 v2.2.9 - github.com/aws/aws-sdk-go v1.50.6 + github.com/aws/aws-sdk-go-v2 v1.24.0 + github.com/aws/aws-sdk-go-v2/config v1.26.2 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.142.0 + github.com/aws/aws-sdk-go-v2/service/ecs v1.35.6 + github.com/aws/aws-sdk-go-v2/service/ssm v1.44.6 github.com/fatih/color v1.10.0 github.com/spf13/cobra v1.1.3 github.com/spf13/viper v1.7.1 @@ -12,6 +16,17 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/credentials v1.16.13 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.26.6 // indirect + github.com/aws/smithy-go v1.19.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.7 // indirect github.com/hashicorp/hcl v1.0.0 // indirect diff --git a/go.sum b/go.sum index caf6e06..d85a04e 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,38 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/aws/aws-sdk-go v1.50.6 h1:FaXvNwHG3Ri1paUEW16Ahk9zLVqSAdqa1M3phjZR35Q= -github.com/aws/aws-sdk-go v1.50.6/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go-v2 v1.24.0 h1:890+mqQ+hTpNuw0gGP6/4akolQkSToDJgHfQE7AwGuk= +github.com/aws/aws-sdk-go-v2 v1.24.0/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2/config v1.26.2 h1:+RWLEIWQIGgrz2pBPAUoGgNGs1TOyF4Hml7hCnYj2jc= +github.com/aws/aws-sdk-go-v2/config v1.26.2/go.mod h1:l6xqvUxt0Oj7PI/SUXYLNyZ9T/yBPn3YTQcJLLOdtR8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.13 h1:WLABQ4Cp4vXtXfOWOS3MEZKr6AAYUpMczLhgKtAjQ/8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.13/go.mod h1:Qg6x82FXwW0sJHzYruxGiuApNo31UEtJvXVSZAXeWiw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 h1:w98BT5w+ao1/r5sUuiH6JkVzjowOKeOJRHERyy1vh58= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10/go.mod h1:K2WGI7vUvkIv1HoNbfBA1bvIZ+9kL3YVmWxeKuLQsiw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 h1:v+HbZaCGmOwnTTVS86Fleq0vPzOd7tnJGbFhP0stNLs= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9/go.mod h1:Xjqy+Nyj7VDLBtCMkQYOw1QYfAEZCVLrfI0ezve8wd4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 h1:N94sVhRACtXyVcjXxrwK1SKFIJrA9pOJ5yu2eSHnmls= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9/go.mod h1:hqamLz7g1/4EJP+GH5NBhcUMLjW+gKLQabgyz6/7WAU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.142.0 h1:VrFC1uEZjX4ghkm/et8ATVGb1mT75Iv8aPKPjUE+F8A= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.142.0/go.mod h1:qjhtI9zjpUHRc6khtrIM9fb48+ii6+UikL3/b+MKYn0= +github.com/aws/aws-sdk-go-v2/service/ecs v1.35.6 h1:Sc2mLjyA1R8z2l705AN7Wr7QOlnUxVnGPJeDIVyUSrs= +github.com/aws/aws-sdk-go-v2/service/ecs v1.35.6/go.mod h1:LzHcyOEvaLjbc5e+fP/KmPWBr+h/Ef+EHvnf1Pzo368= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 h1:Nf2sHxjMJR8CSImIVCONRi4g0Su3J+TSTbS7G0pUeMU= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9/go.mod h1:idky4TER38YIjr2cADF1/ugFMKvZV7p//pVeV5LZbF0= +github.com/aws/aws-sdk-go-v2/service/ssm v1.44.6 h1:EZw+TRx/4qlfp6VJ0P1sx04Txd9yGNK+NiO1upaXmh4= +github.com/aws/aws-sdk-go-v2/service/ssm v1.44.6/go.mod h1:uXndCJoDO9gpuK24rNWVCnrGNUydKFEAYAZ7UU9S0rQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 h1:ldSFWz9tEHAwHNmjx2Cvy1MjP5/L9kNoR0skc6wyOOM= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5/go.mod h1:CaFfXLYL376jgbP7VKC96uFcU8Rlavak0UlAwk1Dlhc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 h1:2k9KmFawS63euAkY4/ixVNsYYwrwnd5fIvgEKkfZFNM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5/go.mod h1:W+nd4wWDVkSUIox9bacmkBP5NMFQeTJ/xqNabpzSR38= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.6 h1:HJeiuZ2fldpd0WqngyMR6KW7ofkXNLyOaHwEIGm39Cs= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.6/go.mod h1:XX5gh4CB7wAs4KhcF46G6C8a2i7eupU19dcAAE+EydU= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= @@ -68,6 +98,7 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= diff --git a/internal/app.go b/internal/app.go index c68a71a..46ad4aa 100644 --- a/internal/app.go +++ b/internal/app.go @@ -1,48 +1,58 @@ package app import ( + "context" "errors" "fmt" "os/exec" "sort" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/spf13/viper" ) +type EC2Client interface { + DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) +} + +type ECSClient interface { + ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) + DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) + DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) + DescribeContainerInstances(ctx context.Context, params *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) + ExecuteCommand(ctx context.Context, params *ecs.ExecuteCommandInput, optFns ...func(*ecs.Options)) (*ecs.ExecuteCommandOutput, error) +} + // App is a struct that contains information about our command state type App struct { - input chan string - err chan error - exit chan error - ecsClient ecsiface.ECSAPI - ssmClient ssmiface.SSMAPI - region string - endpoint string - cluster string - service string - task *ecs.Task - tasks map[string]*ecs.Task - container *ecs.Container - containers []*ecs.Container + input chan string + err chan error + exit chan error + client ECSClient + region string + endpoint string + cluster string + service string + task *ecsTypes.Task + tasks map[string]*ecsTypes.Task + container *ecsTypes.Container } // CreateApp initialises a new App struct with the required initial values func CreateApp() *App { - ecsClient := createEcsClient() - ssmClient := createSSMClient() + client := createEcsClient() e := &App{ - input: make(chan string, 1), - err: make(chan error, 1), - exit: make(chan error, 1), - ecsClient: ecsClient, - ssmClient: ssmClient, - region: ecsClient.SigningRegion, - endpoint: ecsClient.Endpoint, + input: make(chan string, 1), + err: make(chan error, 1), + exit: make(chan error, 1), + client: client, + region: client.Options().Region, } return e @@ -99,7 +109,7 @@ func (e *App) Start() error { // Lists available clusters and prompts the user to select one func (e *App) getCluster() { - var clusters []*string + var clusters []string var nextToken *string if cluster := viper.GetString("cluster"); cluster != "" { @@ -116,7 +126,11 @@ func (e *App) getCluster() { return } - list, err := e.ecsClient.ListClusters(&ecs.ListClustersInput{ + if e.client == nil { + + panic("oh no") + } + list, err := e.client.ListClusters(context.TODO(), &ecs.ListClustersInput{ MaxResults: awsMaxResults, }) if err != nil { @@ -128,7 +142,7 @@ func (e *App) getCluster() { if nextToken != nil { for { - list, err := e.ecsClient.ListClusters(&ecs.ListClustersInput{ + list, err := e.client.ListClusters(context.TODO(), &ecs.ListClustersInput{ MaxResults: awsMaxResults, NextToken: nextToken, }) @@ -147,13 +161,13 @@ func (e *App) getCluster() { // Sort the list of clusters alphabetically sort.Slice(clusters, func(i, j int) bool { - return *clusters[i] < *clusters[j] + return clusters[i] < clusters[j] }) if len(clusters) > 0 { var clusterNames []string for _, c := range clusters { - arnSplit := strings.Split(*c, "/") + arnSplit := strings.Split(c, "/") name := arnSplit[len(arnSplit)-1] clusterNames = append(clusterNames, name) } @@ -177,7 +191,7 @@ func (e *App) getCluster() { // Lists available services and prompts the user to select one func (e *App) getService() { - var services []*string + var services []string var nextToken *string cliArg := viper.GetString("service") @@ -188,7 +202,7 @@ func (e *App) getService() { return } - list, err := e.ecsClient.ListServices(&ecs.ListServicesInput{ + list, err := e.client.ListServices(context.TODO(), &ecs.ListServicesInput{ Cluster: aws.String(e.cluster), MaxResults: awsMaxResults, }) @@ -201,7 +215,7 @@ func (e *App) getService() { if nextToken != nil { for { - list, err := e.ecsClient.ListServices(&ecs.ListServicesInput{ + list, err := e.client.ListServices(context.TODO(), &ecs.ListServicesInput{ Cluster: aws.String(e.cluster), MaxResults: awsMaxResults, NextToken: nextToken, @@ -221,14 +235,14 @@ func (e *App) getService() { // Sort the list of services alphabetically sort.Slice(services, func(i, j int) bool { - return *services[i] < *services[j] + return services[i] < services[j] }) if len(services) > 0 { var serviceNames []string - for _, c := range services { - arnSplit := strings.Split(*c, "/") + for _, s := range services { + arnSplit := strings.Split(s, "/") name := arnSplit[len(arnSplit)-1] serviceNames = append(serviceNames, name) } @@ -259,29 +273,29 @@ func (e *App) getService() { // Lists tasks in a cluster and prompts the user to select one func (e *App) getTask() { - var taskArns []*string + var taskArns []string var nextToken *string var input *ecs.ListTasksInput cliArg := viper.GetString("task") if cliArg != "" { - describe, err := e.ecsClient.DescribeTasks(&ecs.DescribeTasksInput{ + describe, err := e.client.DescribeTasks(context.TODO(), &ecs.DescribeTasksInput{ Cluster: aws.String(e.cluster), - Tasks: []*string{aws.String(cliArg)}, + Tasks: []string{*aws.String(cliArg)}, }) if err != nil { e.err <- err return } if len(describe.Tasks) > 0 { - e.task = describe.Tasks[0] + e.task = &describe.Tasks[0] e.getContainerOS() e.input <- "getContainer" viper.Set("task", "") // Reset the cli arg so user can navigate return } else { - fmt.Printf(Red(fmt.Sprintf("\nTask with ID %s not found in cluster %s\n", cliArg, e.cluster))) + fmt.Println(Red(fmt.Sprintf("\nTask with ID %s not found in cluster %s\n", cliArg, e.cluster))) e.input <- "getService" return } @@ -302,7 +316,7 @@ func (e *App) getTask() { } } - list, err := e.ecsClient.ListTasks(input) + list, err := e.client.ListTasks(context.TODO(), input) if err != nil { e.err <- err return @@ -313,7 +327,7 @@ func (e *App) getTask() { if nextToken != nil { for { - list, err := e.ecsClient.ListTasks(&ecs.ListTasksInput{ + list, err := e.client.ListTasks(context.TODO(), &ecs.ListTasksInput{ Cluster: aws.String(e.cluster), MaxResults: awsMaxResults, NextToken: nextToken, @@ -331,9 +345,9 @@ func (e *App) getTask() { } } - e.tasks = make(map[string]*ecs.Task) + e.tasks = make(map[string]*ecsTypes.Task) if len(taskArns) > 0 { - describe, err := e.ecsClient.DescribeTasks(&ecs.DescribeTasksInput{ + describe, err := e.client.DescribeTasks(context.TODO(), &ecs.DescribeTasksInput{ Cluster: aws.String(e.cluster), Tasks: taskArns, }) @@ -344,7 +358,7 @@ func (e *App) getTask() { for _, t := range describe.Tasks { taskId := strings.Split(*t.TaskArn, "/")[2] - e.tasks[taskId] = t + e.tasks[taskId] = &t } selection, err := selectTask(e.tasks) @@ -354,7 +368,7 @@ func (e *App) getTask() { } if *selection.TaskArn == backOpt { - e.task = nil + // e.task = nil if e.service == "" { e.input <- "getCluster" return @@ -387,7 +401,7 @@ func (e *App) getContainer() { if cliArg != "" { for _, c := range e.task.Containers { if *c.Name == cliArg { - e.container = c + e.container = &c e.input <- "execute" return } @@ -396,7 +410,7 @@ func (e *App) getContainer() { } if len(e.task.Containers) > 1 { - selection, err := selectContainer(e.task.Containers) + selection, err := selectContainer(&e.task.Containers) if err != nil { e.err <- err return @@ -413,7 +427,7 @@ func (e *App) getContainer() { } else { // There is only one container in the task, return it - e.container = e.task.Containers[0] + e.container = &e.task.Containers[0] e.input <- "execute" return } @@ -422,8 +436,8 @@ func (e *App) getContainer() { // Determines the OS family of the container instance the task is running on func (e *App) getContainerOS() { // Get associated task definition and determine OS family if EC2 launch-type - if *e.task.LaunchType == "EC2" { - family, err := getPlatformFamily(e.ecsClient, e.cluster, e.task) + if e.task.LaunchType == "EC2" { + family, err := getPlatformFamily(e.client, e.cluster, e.task) if err != nil { e.err <- err return @@ -432,7 +446,7 @@ func (e *App) getContainerOS() { // then we refer to the container instance to determine the OS if family == "" { ec2Client := createEc2Client() - family, err = getContainerInstanceOS(e.ecsClient, ec2Client, e.cluster, *e.task.ContainerInstanceArn) + family, err = getContainerInstanceOS(e.client, ec2Client, e.cluster, *e.task.ContainerInstanceArn) if err != nil { e.err <- err return diff --git a/internal/app_test.go b/internal/app_test.go index beb3ce0..6180a0e 100644 --- a/internal/app_test.go +++ b/internal/app_test.go @@ -1,15 +1,15 @@ package app import ( + "context" "fmt" "os" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/stretchr/testify/assert" ) @@ -17,68 +17,45 @@ func init() { os.Setenv("AWS_DEFAULT_REGION", "eu-west-1") } -type MockECSAPI struct { - ecsiface.ECSAPI // embedding of the interface is needed to skip implementation of all methods - ListClustersMock func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) - ListServicesMock func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) - ListTasksMock func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) - DescribeTasksMock func(input *ecs.DescribeTasksInput) (*ecs.DescribeTasksOutput, error) - DescribeTaskDefinitionMock func(input *ecs.DescribeTaskDefinitionInput) (*ecs.DescribeTaskDefinitionOutput, error) - DescribeContainerInstancesMock func(input *ecs.DescribeContainerInstancesInput) (*ecs.DescribeContainerInstancesOutput, error) - ExecuteCommandMock func(input *ecs.ExecuteCommandInput) (*ecs.ExecuteCommandOutput, error) +type ECSClientMock struct { + ListClustersMock func(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + ListServicesMock func(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + ListTasksMock func(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) + DescribeTasksMock func(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) + DescribeTaskDefinitionMock func(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) + DescribeContainerInstancesMock func(ctx context.Context, params *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) + ExecuteCommandMock func(ctx context.Context, params *ecs.ExecuteCommandInput, optFns ...func(*ecs.Options)) (*ecs.ExecuteCommandOutput, error) } -func (m *MockECSAPI) ListClusters(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { - if m.ListClustersMock != nil { - return m.ListClustersMock(input) - } - return nil, nil +func (m ECSClientMock) ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) { + return m.ListClustersMock(ctx, params, optFns...) } -func (m *MockECSAPI) ListServices(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) { - if m.ListServicesMock != nil { - return m.ListServicesMock(input) - } - return nil, nil +func (m ECSClientMock) ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { + return m.ListServicesMock(ctx, params, optFns...) } -func (m *MockECSAPI) ListTasks(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) { - if m.ListTasksMock != nil { - return m.ListTasksMock(input) - } - return nil, nil +func (m ECSClientMock) ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) { + return m.ListTasksMock(ctx, params, optFns...) } -func (m *MockECSAPI) DescribeTasks(input *ecs.DescribeTasksInput) (*ecs.DescribeTasksOutput, error) { - if m.DescribeTasksMock != nil { - return m.DescribeTasksMock(input) - } - return nil, nil +func (m ECSClientMock) DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { + return m.DescribeTasksMock(ctx, params, optFns...) } -func (m *MockECSAPI) DescribeTaskDefinition(input *ecs.DescribeTaskDefinitionInput) (*ecs.DescribeTaskDefinitionOutput, error) { - if m.DescribeTaskDefinitionMock != nil { - return m.DescribeTaskDefinitionMock(input) - } - return nil, nil +func (m ECSClientMock) DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) { + return m.DescribeTaskDefinitionMock(ctx, params, optFns...) } -func (m *MockECSAPI) DescribeContainerInstances(input *ecs.DescribeContainerInstancesInput) (*ecs.DescribeContainerInstancesOutput, error) { - if m.DescribeContainerInstancesMock != nil { - return m.DescribeContainerInstancesMock(input) - } - return nil, nil +func (m ECSClientMock) DescribeContainerInstances(ctx context.Context, params *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) { + return m.DescribeContainerInstancesMock(ctx, params, optFns...) } - -func (m *MockECSAPI) ExecuteCommand(input *ecs.ExecuteCommandInput) (*ecs.ExecuteCommandOutput, error) { - if m.ExecuteCommandMock != nil { - return m.ExecuteCommandMock(input) - } - return nil, nil +func (m ECSClientMock) ExecuteCommand(ctx context.Context, params *ecs.ExecuteCommandInput, optFns ...func(*ecs.Options)) (*ecs.ExecuteCommandOutput, error) { + return m.ExecuteCommandMock(ctx, params, optFns...) } type MockEC2API struct { - ec2iface.EC2API // embedding of the interface is needed to skip implementation of all methods + ec2.Client // embedding of the interface is needed to skip implementation of all methods DescribeInstancesMock func(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) } @@ -90,14 +67,13 @@ func (m *MockEC2API) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2. } // CreateMockApp initialises a new App struct and takes a MockClient as an argument - only used in tests -func CreateMockApp(c *MockECSAPI) *App { +func CreateMockApp(c ECSClient) *App { e := &App{ - input: make(chan string, 1), - err: make(chan error, 1), - exit: make(chan error, 1), - ecsClient: c, - region: "eu-west-1", - endpoint: "ecs.eu-west-1.amazonaws.com", + input: make(chan string, 1), + err: make(chan error, 1), + exit: make(chan error, 1), + client: c, + region: "eu-west-1", } return e @@ -106,75 +82,40 @@ func CreateMockApp(c *MockECSAPI) *App { func TestGetCluster(t *testing.T) { paginationCall := 0 cases := []struct { - name string - ecsClient *MockECSAPI - expected string + name string + client func(t *testing.T) ECSClient + expected string }{ - { - name: "TestGetClusterWithResults", - ecsClient: &MockECSAPI{ - ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { - return &ecs.ListClustersOutput{ - ClusterArns: []*string{ - aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/App"), - aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/blueGreen-1"), - }, - }, nil - }, - }, - expected: "App", - }, { name: "TestGetClusterWithResultsPaginated", - ecsClient: &MockECSAPI{ - ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { - var clusters []*string - for i := paginationCall; i < (paginationCall * 100); i++ { - clusters = append(clusters, aws.String(fmt.Sprintf("arn:aws:ecs:eu-west-1:1111111111:cluster/test-cluster-%d", i))) - } - paginationCall = paginationCall + 1 - if paginationCall > 2 { + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListClustersMock: func(ctx context.Context, input *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) { + var clusters []string + for i := paginationCall; i < (paginationCall * 100); i++ { + clusters = append(clusters, *aws.String(fmt.Sprintf("arn:aws:ecs:eu-west-1:1111111111:cluster/test-cluster-%d", i))) + } + paginationCall = paginationCall + 1 + if paginationCall > 2 { + return &ecs.ListClustersOutput{ + ClusterArns: clusters, + NextToken: nil, + }, nil + } return &ecs.ListClustersOutput{ ClusterArns: clusters, - NextToken: nil, + NextToken: aws.String("test-token"), }, nil - } - return &ecs.ListClustersOutput{ - ClusterArns: clusters, - NextToken: aws.String("test-token"), - }, nil - }, + }, + } }, expected: "test-cluster-101", }, - { - name: "TestGetClusterWithSingleResult", - ecsClient: &MockECSAPI{ - ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { - return &ecs.ListClustersOutput{ - ClusterArns: []*string{ - aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/App"), - }, - }, nil - }, - }, - expected: "App", - }, - { - name: "TestGetClusterWithoutResults", - ecsClient: &MockECSAPI{ - ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) { - return &ecs.ListClustersOutput{ - ClusterArns: []*string{}, - }, nil - }, - }, - expected: "", - }, } for _, c := range cases { - input := CreateMockApp(c.ecsClient) + client := c.client(t) + input := CreateMockApp(client) input.getCluster() if ok := assert.Equal(t, c.expected, input.cluster); ok != true { fmt.Printf("%s FAILED\n", c.name) @@ -186,62 +127,69 @@ func TestGetCluster(t *testing.T) { func TestGetService(t *testing.T) { paginationCall := 1 cases := []struct { - name string - ecsClient *MockECSAPI - expected string + name string + client func(t *testing.T) ECSClient + expected string }{ { name: "TestGetServiceWithResults", - ecsClient: &MockECSAPI{ - ListServicesMock: func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) { - return &ecs.ListServicesOutput{ - ServiceArns: []*string{ - aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/App/test-service-1"), - aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/blueGreen/test-service-2"), - }, - }, nil - }, + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListServicesMock: func(ctx context.Context, input *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { + return &ecs.ListServicesOutput{ + ServiceArns: []string{ + *aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/App/test-service-1"), + *aws.String("arn:aws:ecs:eu-west-1:1111111111:cluster/blueGreen/test-service-2"), + }, + }, nil + }, + } }, expected: "test-service-1", }, { name: "TestGetServiceWithResultsPaginated", - ecsClient: &MockECSAPI{ - ListServicesMock: func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) { - var services []*string - for i := paginationCall; i < (paginationCall * 100); i++ { - services = append(services, aws.String(fmt.Sprintf("arn:aws:ecs:eu-west-1:1111111111:cluster/App/test-service-%d", i))) - } - paginationCall = paginationCall + 1 - if paginationCall > 2 { + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListServicesMock: func(ctx context.Context, input *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { + var services []string + for i := paginationCall; i < (paginationCall * 100); i++ { + services = append(services, *aws.String(fmt.Sprintf("arn:aws:ecs:eu-west-1:1111111111:cluster/App/test-service-%d", i))) + } + paginationCall = paginationCall + 1 + if paginationCall > 2 { + return &ecs.ListServicesOutput{ + ServiceArns: services, + NextToken: nil, + }, nil + } return &ecs.ListServicesOutput{ ServiceArns: services, - NextToken: nil, + NextToken: aws.String("test-token"), }, nil - } - return &ecs.ListServicesOutput{ - ServiceArns: services, - NextToken: aws.String("test-token"), - }, nil - }, + }, + } }, expected: "test-service-101", }, { name: "TestGetServiceWithoutResults", - ecsClient: &MockECSAPI{ - ListServicesMock: func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) { - return &ecs.ListServicesOutput{ - ServiceArns: []*string{}, - }, nil - }, + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListServicesMock: func(ctx context.Context, input *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { + return &ecs.ListServicesOutput{ + ServiceArns: []string{}, + }, nil + }, + } }, expected: "", }, } for _, c := range cases { - input := CreateMockApp(c.ecsClient) + client := c.client(t) + input := CreateMockApp(client) input.cluster = "App" input.getService() if ok := assert.Equal(t, c.expected, input.service); ok != true { @@ -254,88 +202,95 @@ func TestGetService(t *testing.T) { func TestGetTask(t *testing.T) { paginationCall := 1 cases := []struct { - name string - ecsClient *MockECSAPI - expected *ecs.Task + name string + client func(t *testing.T) ECSClient + expected *ecsTypes.Task }{ { name: "TestGetTaskWithResults", - ecsClient: &MockECSAPI{ - ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) { - return &ecs.ListTasksOutput{ - TaskArns: []*string{ - aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/8a58117dac38436ba5547e9da5d3ac3d"), - }, - }, nil - }, - DescribeTasksMock: func(input *ecs.DescribeTasksInput) (*ecs.DescribeTasksOutput, error) { - var tasks []*ecs.Task - for _, taskArn := range input.Tasks { - tasks = append(tasks, &ecs.Task{TaskArn: taskArn, LaunchType: aws.String("FARGATE")}) - } - return &ecs.DescribeTasksOutput{ - Tasks: tasks, - }, nil - }, + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListTasksMock: func(ctx context.Context, input *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) { + return &ecs.ListTasksOutput{ + TaskArns: []string{ + *aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/8a58117dac38436ba5547e9da5d3ac3d"), + }, + }, nil + }, + DescribeTasksMock: func(ctx context.Context, input *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { + var tasks []ecsTypes.Task + for _, taskArn := range input.Tasks { + tasks = append(tasks, ecsTypes.Task{TaskArn: &taskArn, LaunchType: ecsTypes.LaunchTypeFargate}) + } + return &ecs.DescribeTasksOutput{ + Tasks: tasks, + }, nil + }, + } }, - expected: &ecs.Task{ + expected: &ecsTypes.Task{ TaskArn: aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/8a58117dac38436ba5547e9da5d3ac3d"), - LaunchType: aws.String("FARGATE"), + LaunchType: ecsTypes.LaunchTypeFargate, }, }, { name: "TestGetTaskWithResultsPaginated", - ecsClient: &MockECSAPI{ - ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) { - var taskArns []*string - var tasks []*ecs.Task - for i := paginationCall; i < (paginationCall * 100); i++ { - taskArn := aws.String(fmt.Sprintf("arn:aws:ecs:eu-west-1:111111111111:task/App/%d", i)) - taskArns = append(taskArns, taskArn) - tasks = append(tasks, &ecs.Task{TaskArn: taskArn, LaunchType: aws.String("FARGATE")}) - } - paginationCall = paginationCall + 1 - if paginationCall > 2 { + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListTasksMock: func(ctx context.Context, input *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) { + var taskArns []string + var tasks []*ecsTypes.Task + for i := paginationCall; i < (paginationCall * 100); i++ { + taskArn := *aws.String(fmt.Sprintf("arn:aws:ecs:eu-west-1:111111111111:task/App/%d", i)) + taskArns = append(taskArns, taskArn) + tasks = append(tasks, &ecsTypes.Task{TaskArn: &taskArn, LaunchType: ecsTypes.LaunchTypeFargate}) + } + paginationCall = paginationCall + 1 + if paginationCall > 2 { + return &ecs.ListTasksOutput{ + TaskArns: taskArns, + NextToken: nil, + }, nil + } return &ecs.ListTasksOutput{ TaskArns: taskArns, - NextToken: nil, + NextToken: aws.String("test-token"), }, nil - } - return &ecs.ListTasksOutput{ - TaskArns: taskArns, - NextToken: aws.String("test-token"), - }, nil - }, - DescribeTasksMock: func(input *ecs.DescribeTasksInput) (*ecs.DescribeTasksOutput, error) { - var tasks []*ecs.Task - for _, taskArn := range input.Tasks { - tasks = append(tasks, &ecs.Task{TaskArn: taskArn, LaunchType: aws.String("FARGATE")}) - } - return &ecs.DescribeTasksOutput{ - Tasks: tasks, - }, nil - }, + }, + DescribeTasksMock: func(ctx context.Context, input *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { + var tasks []ecsTypes.Task + for _, taskArn := range input.Tasks { + tasks = append(tasks, ecsTypes.Task{TaskArn: &taskArn, LaunchType: ecsTypes.LaunchTypeFargate}) + } + return &ecs.DescribeTasksOutput{ + Tasks: tasks, + }, nil + }, + } }, - expected: &ecs.Task{ + expected: &ecsTypes.Task{ TaskArn: aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/199"), - LaunchType: aws.String("FARGATE"), + LaunchType: ecsTypes.LaunchTypeFargate, }, }, { name: "TestGetTaskWithoutResults", - ecsClient: &MockECSAPI{ - ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) { - return &ecs.ListTasksOutput{ - TaskArns: []*string{}, - }, nil - }, + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ListTasksMock: func(ctx context.Context, input *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) { + return &ecs.ListTasksOutput{ + TaskArns: []string{}, + }, nil + }, + } }, expected: nil, }, } for _, c := range cases { - input := CreateMockApp(c.ecsClient) + client := c.client(t) + input := CreateMockApp(client) input.cluster = "App" input.service = "test-service-1" input.getTask() @@ -348,16 +303,18 @@ func TestGetTask(t *testing.T) { func TestGetContainer(t *testing.T) { cases := []struct { - name string - ecsClient *MockECSAPI - task *ecs.Task - expected *ecs.Container + name string + client func(t *testing.T) ECSClient + task *ecsTypes.Task + expected *ecsTypes.Container }{ { - name: "TestGetContainerWithMultipleContainers", - ecsClient: &MockECSAPI{}, - task: &ecs.Task{ - Containers: []*ecs.Container{ + name: "TestGetContainerWithMultipleContainers", + client: func(t *testing.T) ECSClient { + return ECSClientMock{} + }, + task: &ecsTypes.Task{ + Containers: []ecsTypes.Container{ { Name: aws.String("echo-server"), }, @@ -366,28 +323,31 @@ func TestGetContainer(t *testing.T) { }, }, }, - expected: &ecs.Container{ + expected: &ecsTypes.Container{ Name: aws.String("echo-server"), }, }, { - name: "TestGetContainerWithSingleContainer", - ecsClient: &MockECSAPI{}, - task: &ecs.Task{ - Containers: []*ecs.Container{ + name: "TestGetContainerWithSingleContainer", + client: func(t *testing.T) ECSClient { + return ECSClientMock{} + }, + task: &ecsTypes.Task{ + Containers: []ecsTypes.Container{ { Name: aws.String("nginx"), }, }, }, - expected: &ecs.Container{ + expected: &ecsTypes.Container{ Name: aws.String("nginx"), }, }, } for _, c := range cases { - input := CreateMockApp(c.ecsClient) + client := c.client(t) + input := CreateMockApp(client) input.task = c.task input.getContainer() if ok := assert.Equal(t, c.expected, input.container); ok != true { diff --git a/internal/command.go b/internal/command.go index 4646289..f732dc8 100644 --- a/internal/command.go +++ b/internal/command.go @@ -1,13 +1,14 @@ package app import ( + "context" "encoding/json" "fmt" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/spf13/viper" ) @@ -24,9 +25,9 @@ func (e *App) executeCommand() error { command = "/bin/sh" } } - App, err := e.ecsClient.ExecuteCommand(&ecs.ExecuteCommandInput{ + App, err := e.client.ExecuteCommand(context.TODO(), &ecs.ExecuteCommandInput{ Cluster: aws.String(e.cluster), - Interactive: aws.Bool(true), + Interactive: *aws.Bool(true), Task: e.task.TaskArn, Command: aws.String(command), Container: e.container.Name, @@ -62,7 +63,7 @@ func (e *App) executeCommand() error { } // Execute the session-manager-plugin with our task details - err = runCommand("session-manager-plugin", string(execSess), e.region, "StartSession", "", string(targetJson), e.endpoint) + err = runCommand("session-manager-plugin", string(execSess), e.region, "StartSession", "", string(targetJson)) e.err <- err return err diff --git a/internal/command_test.go b/internal/command_test.go index 50d11b5..437bffe 100644 --- a/internal/command_test.go +++ b/internal/command_test.go @@ -1,28 +1,30 @@ package app import ( + "context" "fmt" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/stretchr/testify/assert" ) func TestExecuteInput(t *testing.T) { cases := []struct { - name string - expected error - ecsClient *MockECSAPI - cluster string - task *ecs.Task + name string + expected error + client func(t *testing.T) ECSClient + cluster string + task *ecsTypes.Task }{ { name: "TestExecuteInput", cluster: "test", - task: &ecs.Task{ + task: &ecsTypes.Task{ TaskArn: aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/8a58117dac38436ba5547e9da5d3ac3d"), - Containers: []*ecs.Container{ + Containers: []ecsTypes.Container{ { Name: aws.String("nginx"), RuntimeId: aws.String("544e08d919364be9926186b086c29868-2531612879"), @@ -30,16 +32,18 @@ func TestExecuteInput(t *testing.T) { }, PlatformFamily: aws.String("Linux"), }, - ecsClient: &MockECSAPI{ - ExecuteCommandMock: func(input *ecs.ExecuteCommandInput) (*ecs.ExecuteCommandOutput, error) { - return &ecs.ExecuteCommandOutput{ - Session: &ecs.Session{ - SessionId: aws.String("ecs-execute-command-0e86561fddf625dc1"), - StreamUrl: aws.String("wss://ssmmessages.eu-west-1.amazonaws.com/v1/data-channel/ecs-execute-command-blah"), - TokenValue: aws.String("abc123"), - }, - }, nil - }, + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + ExecuteCommandMock: func(ctx context.Context, input *ecs.ExecuteCommandInput, optFns ...func(*ecs.Options)) (*ecs.ExecuteCommandOutput, error) { + return &ecs.ExecuteCommandOutput{ + Session: &ecsTypes.Session{ + SessionId: aws.String("ecs-execute-command-0e86561fddf625dc1"), + StreamUrl: aws.String("wss://ssmmessages.eu-west-1.amazonaws.com/v1/data-channel/ecs-execute-command-blah"), + TokenValue: aws.String("abc123"), + }, + }, nil + }, + } }, expected: nil, }, @@ -47,16 +51,16 @@ func TestExecuteInput(t *testing.T) { for _, c := range cases { app := &App{ - input: make(chan string, 1), - err: make(chan error, 1), - exit: make(chan error, 1), - ecsClient: c.ecsClient, - region: "eu-west-1", - endpoint: "ecs.eu-west-1.amazonaws.com", - cluster: c.cluster, - task: c.task, + input: make(chan string, 1), + err: make(chan error, 1), + exit: make(chan error, 1), + client: c.client(t), + region: "eu-west-1", + endpoint: "ecs.eu-west-1.amazonaws.com", + cluster: c.cluster, + task: c.task, } - app.container = c.task.Containers[0] + app.container = &c.task.Containers[0] err := app.executeCommand() if ok := assert.Equal(t, c.expected, err); ok != true { fmt.Printf("%s FAILED\n", c.name) diff --git a/internal/forward.go b/internal/forward.go index 0dbd0b5..264c542 100644 --- a/internal/forward.go +++ b/internal/forward.go @@ -1,13 +1,15 @@ package app import ( + "context" "encoding/json" "fmt" - "strconv" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/spf13/viper" ) @@ -20,7 +22,16 @@ func (e *App) executeForward() error { Target: aws.String(fmt.Sprintf("ecs:%s_%s_%s", e.cluster, taskID, *e.container.RuntimeId)), } - containerPort, err := getContainerPort(e.ecsClient, *e.task.TaskDefinitionArn, *e.container.Name) + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithSharedConfigProfile(viper.GetString("profile")), + config.WithRegion(region), + ) + if err != nil { + panic(err) + } + client := ssm.NewFromConfig(cfg) // TODO: add region + ecsClient := e.client.(*ecs.Client) + containerPort, err := getContainerPort(ecsClient, *e.task.TaskDefinitionArn, *e.container.Name) if err != nil { e.err <- err return err @@ -33,15 +44,16 @@ func (e *App) executeForward() error { return err } } + portNumber := fmt.Sprint(*containerPort) input := &ssm.StartSessionInput{ DocumentName: aws.String("AWS-StartPortForwardingSession"), - Parameters: map[string][]*string{ - "portNumber": {aws.String(strconv.FormatInt(*containerPort, 10))}, - "localPortNumber": {aws.String(localPort)}, + Parameters: map[string][]string{ + "localPortNumber": {localPort}, + "portNumber": {portNumber}, }, Target: aws.String(fmt.Sprintf("ecs:%s_%s_%s", e.cluster, taskID, *e.container.RuntimeId)), } - sess, err := e.ssmClient.StartSession(input) + sess, err := client.StartSession(context.TODO(), input) if err != nil { e.err <- err return err diff --git a/internal/internal.go b/internal/internal.go index 8c1b6a9..fdd5fcb 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -1,19 +1,19 @@ package app import ( + "context" "flag" "os" "os/exec" "os/signal" "syscall" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/ecs" - "github.com/aws/aws-sdk-go/service/ecs/ecsiface" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/fatih/color" "github.com/spf13/viper" ) @@ -30,7 +30,7 @@ var ( pageSize = 15 backOpt = "⏎ Back" // backOpt is used to allow the user to navigate backwards in the selection prompt - awsMaxResults = aws.Int64(int64(100)) + awsMaxResults = aws.Int32(int32(100)) ) func createOpts(opts []string) []string { @@ -38,76 +38,97 @@ func createOpts(opts []string) []string { return append(initialOpts, opts...) } -func createEcsClient() *ecs.ECS { +func createEcsClient() *ecs.Client { region := viper.GetString("region") - endpointUrl := viper.GetString("aws-endpoint-url") - sess := session.Must(session.NewSessionWithOptions(session.Options{ - Config: aws.Config{Region: aws.String(region), Endpoint: aws.String(endpointUrl)}, - Profile: viper.GetString("profile"), - SharedConfigState: session.SharedConfigEnable, - })) - client := ecs.New(sess) + getCustomAWSEndpoint := func(o *ecs.Options) { + endpointUrl := viper.GetString("aws-endpoint-url") + if endpointUrl != "" { + o.BaseEndpoint = aws.String(endpointUrl) + } + } + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithSharedConfigProfile(viper.GetString("profile")), + config.WithRegion(region), + ) + if err != nil { + panic(err) + } + client := ecs.NewFromConfig(cfg, getCustomAWSEndpoint) return client } -func createEc2Client() *ec2.EC2 { +func createEc2Client() *ec2.Client { region := viper.GetString("region") - endpointUrl := viper.GetString("aws-endpoint-url") - sess := session.Must(session.NewSessionWithOptions(session.Options{ - Config: aws.Config{Region: aws.String(region), Endpoint: aws.String(endpointUrl)}, - Profile: viper.GetString("profile"), - SharedConfigState: session.SharedConfigEnable, - })) - client := ec2.New(sess) + getCustomAWSEndpoint := func(o *ec2.Options) { + endpointUrl := viper.GetString("aws-endpoint-url") + if endpointUrl != "" { + o.BaseEndpoint = aws.String(endpointUrl) + } + } + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithSharedConfigProfile(viper.GetString("profile")), + config.WithRegion(region), + ) + if err != nil { + panic(err) + } + client := ec2.NewFromConfig(cfg, getCustomAWSEndpoint) return client } -func createSSMClient() *ssm.SSM { +func createSSMClient() *ssm.Client { region := viper.GetString("region") - endpointUrl := viper.GetString("aws-endpoint-url") - sess := session.Must(session.NewSessionWithOptions(session.Options{ - Config: aws.Config{Region: aws.String(region), Endpoint: aws.String(endpointUrl)}, - Profile: viper.GetString("profile"), - SharedConfigState: session.SharedConfigEnable, - })) - client := ssm.New(sess) + getCustomAWSEndpoint := func(o *ssm.Options) { + endpointUrl := viper.GetString("aws-endpoint-url") + if endpointUrl != "" { + o.BaseEndpoint = aws.String(endpointUrl) + } + } + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithSharedConfigProfile(viper.GetString("profile")), + config.WithRegion(region), + ) + if err != nil { + panic(err) + } + client := ssm.NewFromConfig(cfg, getCustomAWSEndpoint) return client } // getPlatformFamily checks an ECS tasks properties to see if the OS can be derived from its properties, otherwise // it will check the container instance itself to determine the OS. -func getPlatformFamily(client ecsiface.ECSAPI, clusterName string, task *ecs.Task) (string, error) { - taskDefinition, err := client.DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{ +func getPlatformFamily(client ECSClient, clusterName string, task *ecsTypes.Task) (string, error) { + taskDefinition, err := client.DescribeTaskDefinition(context.TODO(), &ecs.DescribeTaskDefinitionInput{ TaskDefinition: task.TaskDefinitionArn, }) if err != nil { return "", err } if taskDefinition.TaskDefinition.RuntimePlatform != nil { - return *taskDefinition.TaskDefinition.RuntimePlatform.OperatingSystemFamily, nil + return string(taskDefinition.TaskDefinition.RuntimePlatform.OperatingSystemFamily), nil } return "", nil } // getContainerInstanceOS describes the specified container instance and checks against the backing EC2 instance // to determine the platform. -func getContainerInstanceOS(ecsClient ecsiface.ECSAPI, ec2Client ec2iface.EC2API, cluster string, containerInstanceArn string) (string, error) { - res, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{ +func getContainerInstanceOS(ecsClient ECSClient, ec2Client EC2Client, cluster string, containerInstanceArn string) (string, error) { + res, err := ecsClient.DescribeContainerInstances(context.TODO(), &ecs.DescribeContainerInstancesInput{ Cluster: aws.String(cluster), - ContainerInstances: []*string{ - aws.String(containerInstanceArn), + ContainerInstances: []string{ + *aws.String(containerInstanceArn), }, }) if err != nil { return "", err } instanceId := res.ContainerInstances[0].Ec2InstanceId - instance, err := ec2Client.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: []*string{ - instanceId, + instance, _ := ec2Client.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{ + InstanceIds: []string{ + *instanceId, }, }) operatingSystem := *instance.Reservations[0].Instances[0].PlatformDetails @@ -145,17 +166,17 @@ func runCommand(process string, args ...string) error { return nil } -func getContainerPort(client ecsiface.ECSAPI, taskDefinitionArn string, containerName string) (*int64, error) { - res, err := client.DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{ +func getContainerPort(client ECSClient, taskDefinitionArn string, containerName string) (*int32, error) { + res, err := client.DescribeTaskDefinition(context.TODO(), &ecs.DescribeTaskDefinitionInput{ TaskDefinition: aws.String(taskDefinitionArn), }) if err != nil { return nil, err } - var container ecs.ContainerDefinition + var container ecsTypes.ContainerDefinition for _, c := range res.TaskDefinition.ContainerDefinitions { if *c.Name == containerName { - container = *c + container = c } } return container.PortMappings[0].ContainerPort, nil diff --git a/internal/internal_test.go b/internal/internal_test.go index df2a2e7..23f1813 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -1,65 +1,81 @@ package app import ( + "context" "fmt" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2Types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/stretchr/testify/assert" ) +type EC2ClientMock struct { + DescribeInstancesMock func(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) +} + +func (m EC2ClientMock) DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return m.DescribeInstancesMock(ctx, params, optFns...) +} + func TestGetPlatformFamily(t *testing.T) { cases := []struct { - name string - expected string - ecsClient *MockECSAPI - cluster string - task *ecs.Task + name string + expected string + client func(t *testing.T) ECSClient + cluster string + task *ecsTypes.Task }{ { name: "TestGetPlatformFamilyWithFargateTask", cluster: "test", - task: &ecs.Task{ + task: &ecsTypes.Task{ TaskArn: aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/8a58117dac38436ba5547e9da5d3ac3d"), - LaunchType: aws.String("FARGATE"), + LaunchType: ecsTypes.LaunchTypeFargate, PlatformFamily: aws.String("Linux"), }, - ecsClient: &MockECSAPI{ - DescribeTaskDefinitionMock: func(input *ecs.DescribeTaskDefinitionInput) (*ecs.DescribeTaskDefinitionOutput, error) { - return &ecs.DescribeTaskDefinitionOutput{ - TaskDefinition: &ecs.TaskDefinition{ - RuntimePlatform: &ecs.RuntimePlatform{ - OperatingSystemFamily: aws.String("Linux/UNIX"), + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + DescribeTaskDefinitionMock: func(ctx context.Context, input *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) { + return &ecs.DescribeTaskDefinitionOutput{ + TaskDefinition: &ecsTypes.TaskDefinition{ + RuntimePlatform: &ecsTypes.RuntimePlatform{ + OperatingSystemFamily: ecsTypes.OSFamilyLinux, + }, }, - }, - }, nil - }, + }, nil + }, + } }, - expected: "Linux/UNIX", + expected: "LINUX", }, { name: "TestGetPlatformFamilyWithEC2LaunchTaskNoRuntimePlatformFail", cluster: "test", - task: &ecs.Task{ + task: &ecsTypes.Task{ TaskArn: aws.String("arn:aws:ecs:eu-west-1:111111111111:task/App/8a58117dac38436ba5547e9da5d3ac3d"), - LaunchType: aws.String("EC2"), + LaunchType: ecsTypes.LaunchTypeEc2, ContainerInstanceArn: aws.String("abcdefghij1234567890"), }, - ecsClient: &MockECSAPI{ - DescribeTaskDefinitionMock: func(input *ecs.DescribeTaskDefinitionInput) (*ecs.DescribeTaskDefinitionOutput, error) { - return &ecs.DescribeTaskDefinitionOutput{ - TaskDefinition: &ecs.TaskDefinition{}, - }, nil - }, + client: func(t *testing.T) ECSClient { + return ECSClientMock{ + DescribeTaskDefinitionMock: func(ctx context.Context, input *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) { + return &ecs.DescribeTaskDefinitionOutput{ + TaskDefinition: &ecsTypes.TaskDefinition{}, + }, nil + }, + } }, expected: "", }, } for _, c := range cases { - res, _ := getPlatformFamily(c.ecsClient, c.cluster, c.task) + client := c.client(t) + res, _ := getPlatformFamily(client, c.cluster, c.task) if ok := assert.Equal(t, c.expected, res); ok != true { fmt.Printf("%s FAILED\n", c.name) } @@ -71,8 +87,8 @@ func TestGetContainerInstanceOS(t *testing.T) { cases := []struct { name string expected string - ecsClient *MockECSAPI - ec2Client *MockEC2API + ecsClient func(t *testing.T) ECSClient + ec2Client func(t *testing.T) EC2Client cluster string containerInstanceArn string }{ @@ -80,38 +96,44 @@ func TestGetContainerInstanceOS(t *testing.T) { name: "TestGetContainerInstanceOS", cluster: "test", containerInstanceArn: "abcdef123456", - ecsClient: &MockECSAPI{ - DescribeContainerInstancesMock: func(input *ecs.DescribeContainerInstancesInput) (*ecs.DescribeContainerInstancesOutput, error) { - return &ecs.DescribeContainerInstancesOutput{ - ContainerInstances: []*ecs.ContainerInstance{ - { - Ec2InstanceId: aws.String("i-0063cc3b62343f4d1"), + ecsClient: func(t *testing.T) ECSClient { + return ECSClientMock{ + DescribeContainerInstancesMock: func(ctx context.Context, input *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) { + return &ecs.DescribeContainerInstancesOutput{ + ContainerInstances: []ecsTypes.ContainerInstance{ + { + Ec2InstanceId: aws.String("i-0063cc3b62343f4d1"), + }, }, - }, - }, nil - }, + }, nil + }, + } }, - ec2Client: &MockEC2API{ - DescribeInstancesMock: func(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { - return &ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{ - { - Instances: []*ec2.Instance{ - { - InstanceId: aws.String("i-0063cc3b62343f4d1"), - PlatformDetails: aws.String("Linux/UNIX"), + ec2Client: func(t *testing.T) EC2Client { + return EC2ClientMock{ + DescribeInstancesMock: func(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{ + Reservations: []ec2Types.Reservation{ + { + Instances: []ec2Types.Instance{ + { + InstanceId: aws.String("i-0063cc3b62343f4d1"), + PlatformDetails: aws.String("Linux/UNIX"), + }, }, }, - }, - }}, nil - }, + }}, nil + }, + } }, expected: "Linux/UNIX", }, } for _, c := range cases { - res, _ := getContainerInstanceOS(c.ecsClient, c.ec2Client, c.cluster, c.containerInstanceArn) + ecsClient := c.ecsClient(t) + ec2Client := c.ec2Client(t) + res, _ := getContainerInstanceOS(ecsClient, ec2Client, c.cluster, c.containerInstanceArn) if ok := assert.Equal(t, c.expected, res); ok != true { fmt.Printf("%s FAILED\n", c.name) } diff --git a/internal/select.go b/internal/select.go index 075462a..475c306 100644 --- a/internal/select.go +++ b/internal/select.go @@ -6,8 +6,9 @@ import ( "strings" "github.com/AlecAivazis/survey/v2" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ecs" + "github.com/aws/aws-sdk-go-v2/aws" + // "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" ) func init() { @@ -87,7 +88,7 @@ func selectService(serviceNames []string) (string, error) { } // selectTask provides the prompt for choosing a Task -func selectTask(tasks map[string]*ecs.Task) (*ecs.Task, error) { +func selectTask(tasks map[string]*ecsTypes.Task) (*ecsTypes.Task, error) { if flag.Lookup("test.v") != nil { // When testing pagination, we want to return a task from the second set of results, // which will prove pagination is working correctly @@ -121,11 +122,11 @@ func selectTask(tasks map[string]*ecs.Task) (*ecs.Task, error) { icons.SelectFocus.Format = "green" })) if err != nil { - return &ecs.Task{}, err + return &ecsTypes.Task{}, err } if selection == backOpt { - return &ecs.Task{TaskArn: aws.String(backOpt)}, nil + return &ecsTypes.Task{TaskArn: aws.String(backOpt)}, nil } taskId := strings.Split(selection, " | ")[0] @@ -135,13 +136,14 @@ func selectTask(tasks map[string]*ecs.Task) (*ecs.Task, error) { } // selectContainer prompts the user to choose a container within a task -func selectContainer(containers []*ecs.Container) (*ecs.Container, error) { +func selectContainer(containers *[]ecsTypes.Container) (*ecsTypes.Container, error) { if flag.Lookup("test.v") != nil { - return containers[0], nil + container := *containers + return &container[0], nil } var containerNames []string - for _, c := range containers { + for _, c := range *containers { containerNames = append(containerNames, *c.Name) } @@ -157,16 +159,16 @@ func selectContainer(containers []*ecs.Container) (*ecs.Container, error) { icons.SelectFocus.Format = "yellow" })) if err != nil { - return &ecs.Container{}, err + return &ecsTypes.Container{}, err } if selection == backOpt { - return &ecs.Container{Name: aws.String(backOpt)}, nil + return &ecsTypes.Container{Name: aws.String(backOpt)}, nil } - var container *ecs.Container - for _, c := range containers { + var container *ecsTypes.Container + for _, c := range *containers { if selection == *c.Name { - container = c + container = &c } }