Skip to content

Commit

Permalink
Fix an issue "tsh aws ssm start-session" fails when KMS encryption is…
Browse files Browse the repository at this point in the history
… enabled (#50402) (#50798)

* Fix an issue "tsh aws ssm start-session" fails when KMS encryption is enabled

* remove httputil dump

* fix ut

* remove unused funcs
  • Loading branch information
greedy52 authored Jan 8, 2025
1 parent 062bf3a commit e2dcafd
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 89 deletions.
22 changes: 21 additions & 1 deletion lib/srv/alpnproxy/forward_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,27 @@ func MatchAllRequests(req *http.Request) bool {
// MatchAWSRequests is a MatchFunc that returns true if request is an AWS API
// request.
func MatchAWSRequests(req *http.Request) bool {
return awsapiutils.IsAWSEndpoint(req.Host)
return awsapiutils.IsAWSEndpoint(req.Host) &&
// Avoid proxying SSM session WebSocket requests and let the forward proxy
// send it directly to AWS.
//
// `aws ssm start-session` first calls ssm.<region>.amazonaws.com to get
// a stream URL and a token. Then it makes a wss connection with the
// provided token to the provided stream URL. The stream URL looks like:
// wss://ssmmessages.region.amazonaws.com/v1/data-channel/session-id?stream=(input|output)
//
// The wss request currently respects HTTPS_PROXY but does not
// respect local CA bundle we provided thus causing a failure. The
// request is not signed with SigV4 either.
//
// Reference:
// https://github.com/aws/session-manager-plugin/
!isAWSSSMWebsocketRequest(req)
}

func isAWSSSMWebsocketRequest(req *http.Request) bool {
return awsapiutils.IsAWSEndpoint(req.Host) &&
strings.HasPrefix(req.Host, "ssmmessages.")
}

// MatchAzureRequests is a MatchFunc that returns true if request is an Azure API
Expand Down
40 changes: 40 additions & 0 deletions lib/srv/alpnproxy/forward_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,43 @@ func TestMatchGCPRequests(t *testing.T) {
})
}
}

func TestMatchAWSRequests(t *testing.T) {
makeRequest := func(url string) *http.Request {
// Forward proxy always receives CONNECT requests.
request, err := http.NewRequest("CONNECT", url, nil)
require.NoError(t, err)
return request
}
tests := []struct {
name string
req *http.Request
check require.BoolAssertionFunc
}{
{
name: "AWS request",
req: makeRequest("http://s3.ca-central-1.amazonaws.com"),
check: require.True,
},
{
name: "non-AWS request",
req: makeRequest("https://registry.terraform.io"),
check: require.False,
},
{
name: "SSM API",
req: makeRequest("https://ssm.ca-central-1.amazonaws.com"),
check: require.True,
},
{
name: "SSM session WebSocket",
req: makeRequest("wss://ssmmessages.ca-central-1.amazonaws.com"),
check: require.False,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.check(t, MatchAWSRequests(tt.req))
})
}
}
55 changes: 0 additions & 55 deletions tool/tsh/common/app_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"net"
"os"
"os/exec"
"strings"
"sync"

awsarn "github.com/aws/aws-sdk-go/aws/arn"
Expand All @@ -49,11 +48,6 @@ func onAWS(cf *CLIConf) error {
return trace.Wrap(err)
}

if shouldUseAWSEndpointURLMode(cf) {
log.Debugf("Forcing endpoint URL mode for AWS command %q.", cf.AWSCommandArgs)
cf.AWSEndpointURLMode = true
}

err = awsApp.StartLocalProxies()
if err != nil {
return trace.Wrap(err)
Expand All @@ -79,55 +73,6 @@ func onAWS(cf *CLIConf) error {
return awsApp.RunCommand(cmd)
}

func shouldUseAWSEndpointURLMode(cf *CLIConf) bool {
inputAWSCommand := strings.Join(removeAWSCommandFlags(cf.AWSCommandArgs), " ")
switch inputAWSCommand {
// `aws ssm start-session` first calls ssm.<region>.amazonaws.com to get an
// stream URL and an token. Then it makes a wss connection with the
// provided token to the provided stream URL. The wss request currently
// respects HTTPS_PROXY but does not respect local CA bundle we provided
// thus causing a failure. Even if this is resolved one day, the wss send
// the token through websocket data channel for authentication, instead of
// sigv4, which likely we won't support.
//
// When using the endpoint URL mode, only the first request goes through
// Teleport Proxy. The wss connection does not respect the endpoint URL and
// goes to AWS directly (thus working fine).
//
// Reference:
// https://github.com/aws/session-manager-plugin/
//
// "aws ecs execute-command" also start SSM sessions.
case "ssm start-session", "ecs execute-command":
return true
default:
return false
}
}

func removeAWSCommandFlags(args []string) (ret []string) {
for i := 0; i < len(args); i++ {
switch {
case isAWSFlag(args, i):
// Skip next arg, if next arg is not a flag but a flag value.
if !isAWSFlag(args, i+1) {
i++
}
continue
default:
ret = append(ret, args[i])
}
}
return
}

func isAWSFlag(args []string, i int) bool {
if i >= len(args) {
return false
}
return strings.HasPrefix(args[i], "--")
}

// awsApp is an AWS app that can start local proxies to serve AWS APIs.
type awsApp struct {
cf *CLIConf
Expand Down
33 changes: 0 additions & 33 deletions tool/tsh/common/app_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,39 +141,6 @@ func TestAWS(t *testing.T) {
setCmdRunner(validateCmd),
)
require.NoError(t, err)

t.Run("aws ssm start-session", func(t *testing.T) {
// Validate --endpoint-url 127.0.0.1:<port> is added to the command.
validateCmd := func(cmd *exec.Cmd) error {
require.Len(t, cmd.Args, 9)
require.Equal(t, []string{"aws", "ssm", "--region", "us-west-1", "start-session", "--target", "target-id", "--endpoint-url"}, cmd.Args[:8])
require.Contains(t, cmd.Args[8], "127.0.0.1:")
return nil
}
err = Run(
context.Background(),
[]string{"aws", "ssm", "--region", "us-west-1", "start-session", "--target", "target-id"},
setHomePath(tmpHomePath),
setCmdRunner(validateCmd),
)
require.NoError(t, err)
})
t.Run("aws ecs execute-command", func(t *testing.T) {
// Validate --endpoint-url 127.0.0.1:<port> is added to the command.
validateCmd := func(cmd *exec.Cmd) error {
require.Len(t, cmd.Args, 13)
require.Equal(t, []string{"aws", "ecs", "execute-command", "--debug", "--cluster", "cluster-name", "--task", "task-name", "--command", "/bin/bash", "--interactive", "--endpoint-url"}, cmd.Args[:12])
require.Contains(t, cmd.Args[12], "127.0.0.1:")
return nil
}
err = Run(
context.Background(),
[]string{"aws", "ecs", "execute-command", "--debug", "--cluster", "cluster-name", "--task", "task-name", "--command", "/bin/bash", "--interactive"},
setHomePath(tmpHomePath),
setCmdRunner(validateCmd),
)
require.NoError(t, err)
})
}

func makeUserWithAWSRole(t *testing.T) (types.User, types.Role) {
Expand Down

0 comments on commit e2dcafd

Please sign in to comment.