Skip to content

Commit

Permalink
fix: #38 add ssmClient to App struct and createSSMClient to leverage …
Browse files Browse the repository at this point in the history
…shared config options, fixes sso credential issue with port forwarding
  • Loading branch information
tedsmitt committed Dec 29, 2023
1 parent 8ceb1d0 commit 10069b5
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 75 deletions.
40 changes: 22 additions & 18 deletions internal/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ecs/ecsiface"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/spf13/viper"
)

Expand All @@ -18,7 +19,8 @@ type App struct {
input chan string
err chan error
exit chan error
client ecsiface.ECSAPI
ecsClient ecsiface.ECSAPI
ssmClient ssmiface.SSMAPI
region string
endpoint string
cluster string
Expand All @@ -31,14 +33,16 @@ type App struct {

// CreateApp initialises a new App struct with the required initial values
func CreateApp() *App {
client := createEcsClient()
ecsClient := createEcsClient()
ssmClient := createSSMClient()
e := &App{
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
client: client,
region: client.SigningRegion,
endpoint: client.Endpoint,
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
ecsClient: ecsClient,
ssmClient: ssmClient,
region: ecsClient.SigningRegion,
endpoint: ecsClient.Endpoint,
}

return e
Expand Down Expand Up @@ -112,7 +116,7 @@ func (e *App) getCluster() {
return
}

list, err := e.client.ListClusters(&ecs.ListClustersInput{
list, err := e.ecsClient.ListClusters(&ecs.ListClustersInput{
MaxResults: awsMaxResults,
})
if err != nil {
Expand All @@ -124,7 +128,7 @@ func (e *App) getCluster() {

if nextToken != nil {
for {
list, err := e.client.ListClusters(&ecs.ListClustersInput{
list, err := e.ecsClient.ListClusters(&ecs.ListClustersInput{
MaxResults: awsMaxResults,
NextToken: nextToken,
})
Expand Down Expand Up @@ -184,7 +188,7 @@ func (e *App) getService() {
return
}

list, err := e.client.ListServices(&ecs.ListServicesInput{
list, err := e.ecsClient.ListServices(&ecs.ListServicesInput{
Cluster: aws.String(e.cluster),
MaxResults: awsMaxResults,
})
Expand All @@ -197,7 +201,7 @@ func (e *App) getService() {

if nextToken != nil {
for {
list, err := e.client.ListServices(&ecs.ListServicesInput{
list, err := e.ecsClient.ListServices(&ecs.ListServicesInput{
Cluster: aws.String(e.cluster),
MaxResults: awsMaxResults,
NextToken: nextToken,
Expand Down Expand Up @@ -262,7 +266,7 @@ func (e *App) getTask() {

cliArg := viper.GetString("task")
if cliArg != "" {
describe, err := e.client.DescribeTasks(&ecs.DescribeTasksInput{
describe, err := e.ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
Cluster: aws.String(e.cluster),
Tasks: []*string{aws.String(cliArg)},
})
Expand Down Expand Up @@ -298,7 +302,7 @@ func (e *App) getTask() {
}
}

list, err := e.client.ListTasks(input)
list, err := e.ecsClient.ListTasks(input)
if err != nil {
e.err <- err
return
Expand All @@ -309,7 +313,7 @@ func (e *App) getTask() {

if nextToken != nil {
for {
list, err := e.client.ListTasks(&ecs.ListTasksInput{
list, err := e.ecsClient.ListTasks(&ecs.ListTasksInput{
Cluster: aws.String(e.cluster),
MaxResults: awsMaxResults,
NextToken: nextToken,
Expand All @@ -329,7 +333,7 @@ func (e *App) getTask() {

e.tasks = make(map[string]*ecs.Task)
if len(taskArns) > 0 {
describe, err := e.client.DescribeTasks(&ecs.DescribeTasksInput{
describe, err := e.ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
Cluster: aws.String(e.cluster),
Tasks: taskArns,
})
Expand Down Expand Up @@ -419,7 +423,7 @@ func (e *App) getContainer() {
func (e *App) getContainerOS() {
// Get associated task definition and determine OS family if EC2 launch-type
if *e.task.LaunchType == "EC2" {
family, err := getPlatformFamily(e.client, e.cluster, e.task)
family, err := getPlatformFamily(e.ecsClient, e.cluster, e.task)
if err != nil {
e.err <- err
return
Expand All @@ -428,7 +432,7 @@ func (e *App) getContainerOS() {
// then we refer to the container instance to determine the OS
if family == "" {
ec2Client := createEc2Client()
family, err = getContainerInstanceOS(e.client, ec2Client, e.cluster, *e.task.ContainerInstanceArn)
family, err = getContainerInstanceOS(e.ecsClient, ec2Client, e.cluster, *e.task.ContainerInstanceArn)
if err != nil {
e.err <- err
return
Expand Down
74 changes: 37 additions & 37 deletions internal/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ func (m *MockEC2API) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.
// CreateMockApp initialises a new App struct and takes a MockClient as an argument - only used in tests
func CreateMockApp(c *MockECSAPI) *App {
e := &App{
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
client: c,
region: "eu-west-1",
endpoint: "ecs.eu-west-1.amazonaws.com",
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
ecsClient: c,
region: "eu-west-1",
endpoint: "ecs.eu-west-1.amazonaws.com",
}

return e
Expand All @@ -106,13 +106,13 @@ func CreateMockApp(c *MockECSAPI) *App {
func TestGetCluster(t *testing.T) {
paginationCall := 0
cases := []struct {
name string
client *MockECSAPI
expected string
name string
ecsClient *MockECSAPI
expected string
}{
{
name: "TestGetClusterWithResults",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) {
return &ecs.ListClustersOutput{
ClusterArns: []*string{
Expand All @@ -126,7 +126,7 @@ func TestGetCluster(t *testing.T) {
},
{
name: "TestGetClusterWithResultsPaginated",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) {
var clusters []*string
for i := paginationCall; i < (paginationCall * 100); i++ {
Expand All @@ -149,7 +149,7 @@ func TestGetCluster(t *testing.T) {
},
{
name: "TestGetClusterWithSingleResult",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) {
return &ecs.ListClustersOutput{
ClusterArns: []*string{
Expand All @@ -162,7 +162,7 @@ func TestGetCluster(t *testing.T) {
},
{
name: "TestGetClusterWithoutResults",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListClustersMock: func(input *ecs.ListClustersInput) (*ecs.ListClustersOutput, error) {
return &ecs.ListClustersOutput{
ClusterArns: []*string{},
Expand All @@ -174,7 +174,7 @@ func TestGetCluster(t *testing.T) {
}

for _, c := range cases {
input := CreateMockApp(c.client)
input := CreateMockApp(c.ecsClient)
input.getCluster()
if ok := assert.Equal(t, c.expected, input.cluster); ok != true {
fmt.Printf("%s FAILED\n", c.name)
Expand All @@ -186,13 +186,13 @@ func TestGetCluster(t *testing.T) {
func TestGetService(t *testing.T) {
paginationCall := 1
cases := []struct {
name string
client *MockECSAPI
expected string
name string
ecsClient *MockECSAPI
expected string
}{
{
name: "TestGetServiceWithResults",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListServicesMock: func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) {
return &ecs.ListServicesOutput{
ServiceArns: []*string{
Expand All @@ -206,7 +206,7 @@ func TestGetService(t *testing.T) {
},
{
name: "TestGetServiceWithResultsPaginated",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListServicesMock: func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) {
var services []*string
for i := paginationCall; i < (paginationCall * 100); i++ {
Expand All @@ -229,7 +229,7 @@ func TestGetService(t *testing.T) {
},
{
name: "TestGetServiceWithoutResults",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListServicesMock: func(input *ecs.ListServicesInput) (*ecs.ListServicesOutput, error) {
return &ecs.ListServicesOutput{
ServiceArns: []*string{},
Expand All @@ -241,7 +241,7 @@ func TestGetService(t *testing.T) {
}

for _, c := range cases {
input := CreateMockApp(c.client)
input := CreateMockApp(c.ecsClient)
input.cluster = "App"
input.getService()
if ok := assert.Equal(t, c.expected, input.service); ok != true {
Expand All @@ -254,13 +254,13 @@ func TestGetService(t *testing.T) {
func TestGetTask(t *testing.T) {
paginationCall := 1
cases := []struct {
name string
client *MockECSAPI
expected *ecs.Task
name string
ecsClient *MockECSAPI
expected *ecs.Task
}{
{
name: "TestGetTaskWithResults",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) {
return &ecs.ListTasksOutput{
TaskArns: []*string{
Expand All @@ -285,7 +285,7 @@ func TestGetTask(t *testing.T) {
},
{
name: "TestGetTaskWithResultsPaginated",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) {
var taskArns []*string
var tasks []*ecs.Task
Expand Down Expand Up @@ -323,7 +323,7 @@ func TestGetTask(t *testing.T) {
},
{
name: "TestGetTaskWithoutResults",
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ListTasksMock: func(input *ecs.ListTasksInput) (*ecs.ListTasksOutput, error) {
return &ecs.ListTasksOutput{
TaskArns: []*string{},
Expand All @@ -335,7 +335,7 @@ func TestGetTask(t *testing.T) {
}

for _, c := range cases {
input := CreateMockApp(c.client)
input := CreateMockApp(c.ecsClient)
input.cluster = "App"
input.service = "test-service-1"
input.getTask()
Expand All @@ -348,14 +348,14 @@ func TestGetTask(t *testing.T) {

func TestGetContainer(t *testing.T) {
cases := []struct {
name string
client *MockECSAPI
task *ecs.Task
expected *ecs.Container
name string
ecsClient *MockECSAPI
task *ecs.Task
expected *ecs.Container
}{
{
name: "TestGetContainerWithMultipleContainers",
client: &MockECSAPI{},
name: "TestGetContainerWithMultipleContainers",
ecsClient: &MockECSAPI{},
task: &ecs.Task{
Containers: []*ecs.Container{
{
Expand All @@ -371,8 +371,8 @@ func TestGetContainer(t *testing.T) {
},
},
{
name: "TestGetContainerWithSingleContainer",
client: &MockECSAPI{},
name: "TestGetContainerWithSingleContainer",
ecsClient: &MockECSAPI{},
task: &ecs.Task{
Containers: []*ecs.Container{
{
Expand All @@ -387,7 +387,7 @@ func TestGetContainer(t *testing.T) {
}

for _, c := range cases {
input := CreateMockApp(c.client)
input := CreateMockApp(c.ecsClient)
input.task = c.task
input.getContainer()
if ok := assert.Equal(t, c.expected, input.container); ok != true {
Expand Down
2 changes: 1 addition & 1 deletion internal/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (e *App) executeCommand() error {
command = "/bin/sh"
}
}
App, err := e.client.ExecuteCommand(&ecs.ExecuteCommandInput{
App, err := e.ecsClient.ExecuteCommand(&ecs.ExecuteCommandInput{
Cluster: aws.String(e.cluster),
Interactive: aws.Bool(true),
Task: e.task.TaskArn,
Expand Down
28 changes: 14 additions & 14 deletions internal/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (

func TestExecuteInput(t *testing.T) {
cases := []struct {
name string
expected error
client *MockECSAPI
cluster string
task *ecs.Task
name string
expected error
ecsClient *MockECSAPI
cluster string
task *ecs.Task
}{
{
name: "TestExecuteInput",
Expand All @@ -30,7 +30,7 @@ func TestExecuteInput(t *testing.T) {
},
PlatformFamily: aws.String("Linux"),
},
client: &MockECSAPI{
ecsClient: &MockECSAPI{
ExecuteCommandMock: func(input *ecs.ExecuteCommandInput) (*ecs.ExecuteCommandOutput, error) {
return &ecs.ExecuteCommandOutput{
Session: &ecs.Session{
Expand All @@ -47,14 +47,14 @@ func TestExecuteInput(t *testing.T) {

for _, c := range cases {
app := &App{
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
client: c.client,
region: "eu-west-1",
endpoint: "ecs.eu-west-1.amazonaws.com",
cluster: c.cluster,
task: c.task,
input: make(chan string, 1),
err: make(chan error, 1),
exit: make(chan error, 1),
ecsClient: c.ecsClient,
region: "eu-west-1",
endpoint: "ecs.eu-west-1.amazonaws.com",
cluster: c.cluster,
task: c.task,
}
app.container = c.task.Containers[0]
err := app.executeCommand()
Expand Down
Loading

0 comments on commit 10069b5

Please sign in to comment.