Skip to content

Commit

Permalink
Merge pull request #35 from air-hand/feat/aws-ssm
Browse files Browse the repository at this point in the history
feat: aws session manager support
  • Loading branch information
pascalbreuninger authored Jul 10, 2024
2 parents ff8ba39 + e959bfa commit e949dfd
Show file tree
Hide file tree
Showing 11 changed files with 215 additions and 28 deletions.
7 changes: 6 additions & 1 deletion .devcontainer.json
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
{"image":"mcr.microsoft.com/devcontainers/go","build":{}}
{
"image":"mcr.microsoft.com/devcontainers/go:1.19",
"mounts": [
"source=${localEnv:HOME}/.aws,target=/home/vscode/.aws,type=bind,consistency=cached"
]
}
3 changes: 2 additions & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ jobs:
- id: get_version
run: |
RELEASE_VERSION=$(echo $GITHUB_REF | sed -nE 's!refs/tags/!!p')
echo "::set-output name=release_version::$RELEASE_VERSION"
echo "release_version=$RELEASE_VERSION" >> $GITHUB_OUTPUT
- name: Compile binaries
run: |
chmod +x ./hack/build.sh
./hack/build.sh
env:
RELEASE_VERSION: ${{ steps.get_version.outputs.release_version }}
GITHUB_OWNER: ${{ github.repository_owner }}
- name: Save release assets
uses: softprops/action-gh-release@v1
with:
Expand Down
52 changes: 47 additions & 5 deletions cmd/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,15 @@ func (cmd *CommandCmd) Run(
if err != nil {
return err
}
addr := "localhost:" + port
portStr := strconv.Itoa(port)
addr := "localhost:" + portStr
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
connectArgs := []string{
"ec2-instance-connect",
"open-tunnel",
"--instance-id", instanceID,
"--local-port", port,
"--local-port", portStr,
}
if endpointID != "" {
connectArgs = append(connectArgs, "--instance-connect-endpoint-id", endpointID)
Expand Down Expand Up @@ -121,6 +122,47 @@ func (cmd *CommandCmd) Run(
return err
}

// try session manager
if providerAws.Config.UseSessionManager {
instanceID := *instance.Reservations[0].Instances[0].InstanceId

cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()

var err error
port, err := findAvailablePort()
if err != nil {
return err
}

addr := fmt.Sprintf("localhost:%d", port)
connectArgs, err := aws.CommandArgsSSMTunneling(instanceID, port)
if err != nil {
return err
}

cmd := exec.CommandContext(cancelCtx, "aws", connectArgs...)
// open tunnel in background
if err = cmd.Start(); err != nil {
return fmt.Errorf("start tunnel: %w", err)
}
defer func() {
err = cmd.Process.Kill()
}()
timeoutCtx, cancelFn := context.WithTimeout(ctx, 30*time.Second)
defer cancelFn()
waitForPort(timeoutCtx, addr)

client, err := ssh.NewSSHClient("devpod", addr, privateKey)
if err != nil {
logs.Debugf("error connecting by session manager: %v", err)
return err
}

defer client.Close()
return ssh.Run(ctx, client, command, os.Stdin, os.Stdout, os.Stderr)
}

// try public ip
if instance.Reservations[0].Instances[0].PublicIpAddress != nil {
ip := *instance.Reservations[0].Instances[0].PublicIpAddress
Expand Down Expand Up @@ -174,12 +216,12 @@ func waitForPort(ctx context.Context, addr string) {
}

}
func findAvailablePort() (string, error) {
func findAvailablePort() (int, error) {
l, err := net.Listen("tcp", ":0")
if err != nil {
return "", err
return -1, err
}
defer l.Close()

return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil
return l.Addr().(*net.TCPAddr).Port, nil
}
2 changes: 1 addition & 1 deletion hack/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ for OS in ${PROVIDER_BUILD_PLATFORMS[@]}; do
done

# generate provider.yaml
go run -mod vendor "${PROVIDER_ROOT}/hack/provider/main.go" ${RELEASE_VERSION} ${BUILD_VERSION} ${PROVIDER_ROOT} > "${PROVIDER_ROOT}/release/provider.yaml"
GITHUB_OWNER=${GITHUB_OWNER:-"loft-sh"} go run -mod vendor "${PROVIDER_ROOT}/hack/provider/main.go" ${RELEASE_VERSION} ${BUILD_VERSION} ${PROVIDER_ROOT} > "${PROVIDER_ROOT}/release/provider.yaml"
7 changes: 7 additions & 0 deletions hack/provider/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ func main() {

if buildVersion == "dev" {
replaced = strings.Replace(replaced, "##PROJECT_ROOT##", projectRoot, -1)
} else {
githubOwner, found := os.LookupEnv("GITHUB_OWNER")
if !found {
fmt.Fprintln(os.Stderr, "Expected GITHUB_OWNER environment variable")
os.Exit(1)
}
replaced = strings.ReplaceAll(replaced, "##GITHUB_OWNER##", githubOwner)
}

for k, v := range checksumMap {
Expand Down
9 changes: 9 additions & 0 deletions hack/provider/provider-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ optionGroups:
- AWS_USE_INSTANCE_CONNECT_ENDPOINT
- AWS_INSTANCE_CONNECT_ENDPOINT_ID
- AWS_USE_SPOT_INSTANCE
- AWS_USE_SESSION_MANAGER
- AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER
name: "AWS options"
defaultVisible: false
- options:
Expand Down Expand Up @@ -242,6 +244,13 @@ options:
description: "Prefer the Spot instead of On-Demand instances."
type: boolean
default: false
AWS_USE_SESSION_MANAGER:
description: "If defined, will try to connect to the ec2 instance via the AWS Session Manager"
type: boolean
default: false
AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER:
description: "Specify the KMS key ARN to use for the AWS Session Manager"
default: ""
INACTIVITY_TIMEOUT:
description: If defined, will automatically stop the VM after the inactivity period.
default: 10m
Expand Down
23 changes: 16 additions & 7 deletions hack/provider/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ optionGroups:
- AWS_USE_INSTANCE_CONNECT_ENDPOINT
- AWS_INSTANCE_CONNECT_ENDPOINT_ID
- AWS_USE_SPOT_INSTANCE
- AWS_USE_SESSION_MANAGER
- AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER
name: "AWS options"
defaultVisible: false
- options:
Expand Down Expand Up @@ -242,6 +244,13 @@ options:
description: "Prefer the Spot instead of On-Demand instances."
type: boolean
default: false
AWS_USE_SESSION_MANAGER:
description: "If defined, will try to connect to the ec2 instance via the AWS Session Manager"
type: boolean
default: false
AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER:
description: "Specify the KMS key ARN to use for the AWS Session Manager"
default: ""
INACTIVITY_TIMEOUT:
description: If defined, will automatically stop the VM after the inactivity period.
default: 10m
Expand All @@ -263,11 +272,11 @@ agent:
AWS_PROVIDER:
- os: linux
arch: amd64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-amd64
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-amd64
checksum: ##CHECKSUM_LINUX_AMD64##
- os: linux
arch: arm64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-arm64
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-arm64
checksum: ##CHECKSUM_LINUX_ARM64##
exec:
shutdown: |-
Expand All @@ -276,23 +285,23 @@ binaries:
AWS_PROVIDER:
- os: linux
arch: amd64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-amd64
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-amd64
checksum: ##CHECKSUM_LINUX_AMD64##
- os: linux
arch: arm64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-arm64
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-linux-arm64
checksum: ##CHECKSUM_LINUX_ARM64##
- os: darwin
arch: amd64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-darwin-amd64
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-darwin-amd64
checksum: ##CHECKSUM_DARWIN_AMD64##
- os: darwin
arch: arm64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-darwin-arm64
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-darwin-arm64
checksum: ##CHECKSUM_DARWIN_ARM64##
- os: windows
arch: amd64
path: https://github.com/loft-sh/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-windows-amd64.exe
path: https://github.com/##GITHUB_OWNER##/devpod-provider-aws/releases/download/##VERSION##/devpod-provider-aws-windows-amd64.exe
checksum: ##CHECKSUM_WINDOWS_AMD64##
exec:
init: ${AWS_PROVIDER} init
Expand Down
40 changes: 40 additions & 0 deletions pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,41 @@ func CreateDevpodInstanceProfile(ctx context.Context, provider *AwsProvider) (st
return "", err
}

ssmManagedInstanceCorePolicyInput := &iam.AttachRolePolicyInput{
PolicyArn: aws.String("arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore"),
RoleName: aws.String("devpod-ec2-role"),
}

_, err = svc.AttachRolePolicy(ctx, ssmManagedInstanceCorePolicyInput)
if err != nil {
return "", err
}

if provider.Config.KmsKeyARNForSessionManager != "" {
kmsDecryptPolicyInput := &iam.PutRolePolicyInput{
PolicyDocument: aws.String(fmt.Sprintf(`{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "DecryptSSM",
"Action": [
"kms:Decrypt"
],
"Effect": "Allow",
"Resource": "%s"
}
]
}`, provider.Config.KmsKeyARNForSessionManager)),
PolicyName: aws.String("ssm-kms-decrypt-policy"),
RoleName: aws.String("devpod-ec2-role"),
}

_, err = svc.PutRolePolicy(ctx, kmsDecryptPolicyInput)
if err != nil {
return "", err
}
}

instanceProfile := &iam.CreateInstanceProfileInput{
InstanceProfileName: aws.String("devpod-ec2-role"),
}
Expand Down Expand Up @@ -445,6 +480,11 @@ func CreateDevpodSecurityGroup(ctx context.Context, provider *AwsProvider) (stri

groupID := *result.GroupId

// No need to open ssh port if use session manager.
if provider.Config.UseSessionManager {
return groupID, nil
}

// Add permissions to the security group
_, err = svc.AuthorizeSecurityGroupIngress(ctx, &ec2.AuthorizeSecurityGroupIngressInput{
GroupId: aws.String(groupID),
Expand Down
30 changes: 30 additions & 0 deletions pkg/aws/ssm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package aws

import (
"encoding/json"
"fmt"
)

type ssmPortForwardingParameters struct {
PortNumber []string `json:"portNumber"`
LocalPortNumber []string `json:"localPortNumber"`
}

func CommandArgsSSMTunneling(instanceID string, localPort int) ([]string, error) {
parameters := &ssmPortForwardingParameters{
PortNumber: []string{"22"},
LocalPortNumber: []string{fmt.Sprintf("%d", localPort)},
}

parameters_as_json, err := json.Marshal(parameters)
if err != nil {
return []string{}, err
}

return []string{
"ssm", "start-session",
"--target", instanceID,
"--document-name", "AWS-StartPortForwardingSession",
fmt.Sprintf("--parameters=%s", string(parameters_as_json)),
}, nil
}
38 changes: 38 additions & 0 deletions pkg/aws/ssm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package aws

import (
"fmt"
"strings"
"testing"
)

func TestCommandArgsSSMTunneling(t *testing.T) {
tests := []struct {
testName string
instanceId string
localPort int
expect []string
}{
{
testName: "Test 1",
instanceId: "i-0011223344",
localPort: 30114,
expect: []string{
"ssm", "start-session", "--target", "i-0011223344",
"--document-name", "AWS-StartPortForwardingSession",
fmt.Sprintf("--parameters={\"portNumber\":[\"22\"],\"localPortNumber\":[\"%d\"]}", 30114),
},
},
}

for _, tt := range tests {
t.Run(tt.testName, func(t *testing.T) {
args, _ := CommandArgsSSMTunneling(tt.instanceId, tt.localPort)
expect_str := strings.Join(tt.expect, " ")
args_str := strings.Join(args, " ")
if expect_str != args_str {
t.Errorf("Expected %v but got %v", tt.expect, args)
}
})
}
}
32 changes: 19 additions & 13 deletions pkg/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@ import (
)

var (
AWS_AMI = "AWS_AMI"
AWS_DISK_SIZE = "AWS_DISK_SIZE"
AWS_ROOT_DEVICE = "AWS_ROOT_DEVICE"
AWS_INSTANCE_TYPE = "AWS_INSTANCE_TYPE"
AWS_REGION = "AWS_REGION"
AWS_SECURITY_GROUP_ID = "AWS_SECURITY_GROUP_ID"
AWS_SUBNET_ID = "AWS_SUBNET_ID"
AWS_VPC_ID = "AWS_VPC_ID"
AWS_INSTANCE_TAGS = "AWS_INSTANCE_TAGS"
AWS_INSTANCE_PROFILE_ARN = "AWS_INSTANCE_PROFILE_ARN"
AWS_USE_INSTANCE_CONNECT_ENDPOINT = "AWS_USE_INSTANCE_CONNECT_ENDPOINT"
AWS_INSTANCE_CONNECT_ENDPOINT_ID = "AWS_INSTANCE_CONNECT_ENDPOINT_ID"
AWS_USE_SPOT_INSTANCE = "AWS_USE_SPOT_INSTANCE"
AWS_AMI = "AWS_AMI"
AWS_DISK_SIZE = "AWS_DISK_SIZE"
AWS_ROOT_DEVICE = "AWS_ROOT_DEVICE"
AWS_INSTANCE_TYPE = "AWS_INSTANCE_TYPE"
AWS_REGION = "AWS_REGION"
AWS_SECURITY_GROUP_ID = "AWS_SECURITY_GROUP_ID"
AWS_SUBNET_ID = "AWS_SUBNET_ID"
AWS_VPC_ID = "AWS_VPC_ID"
AWS_INSTANCE_TAGS = "AWS_INSTANCE_TAGS"
AWS_INSTANCE_PROFILE_ARN = "AWS_INSTANCE_PROFILE_ARN"
AWS_USE_INSTANCE_CONNECT_ENDPOINT = "AWS_USE_INSTANCE_CONNECT_ENDPOINT"
AWS_INSTANCE_CONNECT_ENDPOINT_ID = "AWS_INSTANCE_CONNECT_ENDPOINT_ID"
AWS_USE_SPOT_INSTANCE = "AWS_USE_SPOT_INSTANCE"
AWS_USE_SESSION_MANAGER = "AWS_USE_SESSION_MANAGER"
AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER = "AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER"
)

type Options struct {
Expand All @@ -38,6 +40,8 @@ type Options struct {
UseInstanceConnectEndpoint bool
InstanceConnectEndpointID string
UseSpotInstance bool
UseSessionManager bool
KmsKeyARNForSessionManager string
}

func FromEnv(init bool) (*Options, error) {
Expand Down Expand Up @@ -71,6 +75,8 @@ func FromEnv(init bool) (*Options, error) {
retOptions.UseInstanceConnectEndpoint = os.Getenv(AWS_USE_INSTANCE_CONNECT_ENDPOINT) == "true"
retOptions.InstanceConnectEndpointID = os.Getenv(AWS_INSTANCE_CONNECT_ENDPOINT_ID)
retOptions.UseSpotInstance = os.Getenv(AWS_USE_SPOT_INSTANCE) == "true"
retOptions.UseSessionManager = os.Getenv(AWS_USE_SESSION_MANAGER) == "true"
retOptions.KmsKeyARNForSessionManager = os.Getenv(AWS_KMS_KEY_ARN_FOR_SESSION_MANAGER)

// Return eraly if we're just doing init
if init {
Expand Down

0 comments on commit e949dfd

Please sign in to comment.