Skip to content

Commit

Permalink
[v16] prevent session id reuse (#42133)
Browse files Browse the repository at this point in the history
* prevent SSH session id reuse

If a user sets `TELEPORT_SESSION` to a session ID
that isn't running, the sent session ID is ignored
and a new one is used to prevent the user
controlling the value. When parsing a session ID
the parsed result is used to ensure all session
IDs will have the same format.
`CreateSessionTracker` is updated to create
instead of upsert to ensure `SessionTracker`
resources cannot be overwritten by accident.

Clients often need to know the current session ID,
and before when they were in control of it that
wasn't an issue. To fix this servers will send the
current session ID in a SSH channel request on the
channel the SSH session was created on. This
should be fully backwards compatible, as when an
unknown SSH request type is received clients will
just ignore it/reply with false.

* add test for wrappedSSHConn
  • Loading branch information
capnspacehook authored May 30, 2024
1 parent 836b145 commit 646bd5f
Show file tree
Hide file tree
Showing 21 changed files with 449 additions and 87 deletions.
92 changes: 87 additions & 5 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"

"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -164,6 +165,18 @@ func (c *Client) OpenChannel(
// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
return c.newSession(ctx, nil)
}

// NewSessionWithRequestCallback creates a new SSH session that is passed
// tracing context so that spans may be correlated properly over the ssh
// connection. The handling of channel requests from the underlying SSH
// session can be controlled with chanReqCallback.
func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
return c.newSession(ctx, chanReqCallback)
}

func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

ctx, span := tracer.Start(
Expand Down Expand Up @@ -191,7 +204,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) {
}

// get a session from the wrapper
session, err := wrapper.NewSession()
session, err := wrapper.NewSession(chanReqCallback)
return session, trace.Wrap(err)
}

Expand All @@ -216,18 +229,54 @@ type clientWrapper struct {
contexts map[string][]context.Context
}

// ChannelRequestCallback allows the handling of channel requests
// to be customized. nil can be returned if you don't want
// golang/x/crypto/ssh to handle the request.
type ChannelRequestCallback func(req *ssh.Request) *ssh.Request

// NewSession opens a new Session for this client.
func (c *clientWrapper) NewSession() (*Session, error) {
func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) {
// create a client that will defer to us when
// opening the "session" channel so that we
// can add an Envelope to the request
client := &ssh.Client{
Conn: c,
}

session, err := client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
var session *ssh.Session
var err error
if callback != nil {
// open a session manually so we can take ownership of the
// requests chan
ch, originalReqs, openChannelErr := client.OpenChannel("session", nil)
if openChannelErr != nil {
return nil, trace.Wrap(openChannelErr)
}

// pass the channel requests to the provided callback and
// forward them to another chan so golang.org/x/crypto/ssh
// can handle Session exiting correctly
reqs := make(chan *ssh.Request, cap(originalReqs))
go func() {
defer close(reqs)

for req := range originalReqs {
if req := callback(req); req != nil {
reqs <- req
}
}
}()

session, err = newCryptoSSHSession(ch, reqs)
if err != nil {
_ = ch.Close()
return nil, trace.Wrap(err)
}
} else {
session, err = client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}
}

// wrap the session so all session requests on the channel
Expand All @@ -238,6 +287,39 @@ func (c *clientWrapper) NewSession() (*Session, error) {
}, nil
}

// wrappedSSHConn allows an SSH session to be created while also allowing
// callers to take ownership of the SSH channel requests chan.
type wrappedSSHConn struct {
ssh.Conn

channelOpened atomic.Bool

ch ssh.Channel
reqs <-chan *ssh.Request
}

func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) {
if !f.channelOpened.CompareAndSwap(false, true) {
panic("wrappedSSHConn OpenChannel called more than once")
}

return f.ch, f.reqs, nil
}

// newCryptoSSHSession allows callers to take ownership of the SSH
// channel requests chan and allow callers to handle SSH channel requests.
// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all
// SSH channel requests and doesn't allow the caller to view or reply
// to them, so this workaround is needed.
func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) {
return (&ssh.Client{
Conn: &wrappedSSHConn{
ch: ch,
reqs: reqs,
},
}).NewSession()
}

// Dial initiates a connection to the addr from the remote host.
func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) {
// create a client that will defer to us when
Expand Down
26 changes: 26 additions & 0 deletions api/observability/tracing/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,29 @@ func TestSetEnvs(t *testing.T) {
default:
}
}

type mockSSHChannel struct {
ssh.Channel
}

func TestWrappedSSHConn(t *testing.T) {
sshCh := new(mockSSHChannel)
reqs := make(<-chan *ssh.Request)

// ensure that OpenChannel returns the same SSH channel and requests
// chan that wrappedSSHConn was given
wrappedConn := &wrappedSSHConn{
ch: sshCh,
reqs: reqs,
}
retCh, retReqs, err := wrappedConn.OpenChannel("", nil)
require.NoError(t, err)
require.Equal(t, sshCh, retCh)
require.Equal(t, reqs, retReqs)

// ensure the wrapped SSH conn will panic if OpenChannel is called
// twice
require.Panics(t, func() {
wrappedConn.OpenChannel("", nil)
})
}
8 changes: 8 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,14 @@ const (
// version they are running.
VersionRequest = "x-teleport-version"

// CurrentSessionIDRequest is sent by servers to inform clients of
// the session ID that is being used.
CurrentSessionIDRequest = "[email protected]"

// SessionIDQueryRequest is sent by clients to ask servers if they
// will generate their own session ID when a new session is created.
SessionIDQueryRequest = "[email protected]"

// ForceTerminateRequest is an SSH request to forcefully terminate a session.
ForceTerminateRequest = "x-teleport-force-terminate"

Expand Down
2 changes: 1 addition & 1 deletion integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
)
assert.NoError(t, err)

errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)
errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil, nil)
assert.NoError(t, nodeClient.Close())
}()

Expand Down
2 changes: 1 addition & 1 deletion lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5717,7 +5717,7 @@ func TestLocalServiceRolesHavePermissionsForUploaderService(t *testing.T) {
}

t.Run("GetSessionTracker", func(t *testing.T) {
sid := session.ID("test-session")
sid := session.NewID()
tracker, err := s.CreateSessionTracker(ctx, &types.SessionTrackerV1{
ResourceHeader: types.ResourceHeader{
Metadata: types.Metadata{
Expand Down
8 changes: 6 additions & 2 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,10 @@ type TeleportClient struct {
Config
localAgent *LocalKeyAgent

// OnChannelRequest gets called when SSH channel requests are
// received. It's safe to keep it nil.
OnChannelRequest tracessh.ChannelRequestCallback

// OnShellCreated gets called when the shell is created. It's
// safe to keep it nil.
OnShellCreated ShellCreatedCallback
Expand Down Expand Up @@ -1941,7 +1945,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
// Reuse the existing nodeClient we connected above.
return nodeClient.RunCommand(ctx, command)
}
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil))
return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil))
}

func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error {
Expand Down Expand Up @@ -2100,7 +2104,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan
fmt.Printf("Joining session with participant mode: %v. \n\n", mode)

// running shell with a given session means "join" it:
err = nc.RunInteractiveShell(ctx, mode, session, beforeStart)
err = nc.RunInteractiveShell(ctx, mode, session, tc.OnChannelRequest, beforeStart)
return trace.Wrap(err)
}

Expand Down
6 changes: 3 additions & 3 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co
// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr
// to and from the node and local shell. This will block until the interactive shell on the node
// is terminated.
func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error {
func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer)) error {
ctx, span := c.Tracer.Start(
ctx,
"nodeClient/RunInteractiveShell",
Expand Down Expand Up @@ -395,7 +395,7 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session
return trace.Wrap(err)
}

if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil {
if err = nodeSession.runShell(ctx, mode, c.TC.OnChannelRequest, beforeStart, c.TC.OnShellCreated); err != nil {
var exitErr *ssh.ExitError
var exitMissingErr *ssh.ExitMissingError
switch err := trace.Unwrap(err); {
Expand Down Expand Up @@ -539,7 +539,7 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R
return trace.Wrap(err)
}
defer nodeSession.Close()
if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil {
if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnChannelRequest, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil {
originErr := trace.Unwrap(err)
var exitErr *ssh.ExitError
if errors.As(originErr, &exitErr) {
Expand Down
1 change: 0 additions & 1 deletion lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@ func TestListenAndForwardCancel(t *testing.T) {
}
})
}

}

func newTestListener(t *testing.T, handle func(net.Conn)) net.Listener {
Expand Down
38 changes: 21 additions & 17 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ func newSession(ctx context.Context,
}
// new session!
} else {
// TODO(capnspacehook): DELETE IN 17.0.0
// clients shouldn't set TELEPORT_SESSION when they aren't joining
// a session, and won't need to once all supported Proxy/Node
// versions set the session ID for new sessions
sid, ok := ns.env[sshutils.SessionEnvVar]
if !ok {
sid = string(session.NewID())
Expand Down Expand Up @@ -188,35 +192,35 @@ func (ns *NodeSession) NodeClient() *NodeClient {
return ns.nodeClient
}

func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *tracessh.Session) error) error {
func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback func(s *tracessh.Session) error) error {
ctx, span := ns.nodeClient.Tracer.Start(
ctx,
"nodeClient/regularSession",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

session, err := ns.createServerSession(ctx)
session, err := ns.createServerSession(ctx, chanReqCallback)
if err != nil {
return trace.Wrap(err)
}
session.Stdout = ns.terminal.Stdout()
session.Stderr = ns.terminal.Stderr()
session.Stdin = ns.terminal.Stdin()
return trace.Wrap(callback(session))
return trace.Wrap(sessionCallback(session))
}

type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error

func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) {
func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback) (*tracessh.Session, error) {
ctx, span := ns.nodeClient.Tracer.Start(
ctx,
"nodeClient/createServerSession",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

sess, err := ns.nodeClient.Client.NewSession(ctx)
sess, err := ns.nodeClient.Client.NewSessionWithRequestCallback(ctx, chanReqCallback)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -284,7 +288,7 @@ func selectKeyAgent(tc *TeleportClient) agent.ExtendedAgent {

// interactiveSession creates an interactive session on the remote node, executes
// the given callback on it, and waits for the session to end
func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, callback interactiveCallback) error {
func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback interactiveCallback) error {
ctx, span := ns.nodeClient.Tracer.Start(
ctx,
"nodeClient/interactiveSession",
Expand All @@ -298,7 +302,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio
termType = teleport.SafeTerminalType
}
// create the server-side session:
sess, err := ns.createServerSession(ctx)
sess, err := ns.createServerSession(ctx, chanReqCallback)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -311,7 +315,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio

// call the passed callback and give them the established
// ssh session:
if err := callback(sess, remoteTerm); err != nil {
if err := sessionCallback(sess, remoteTerm); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -514,8 +518,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) {
}

// runShell executes user's shell on the remote node under an interactive session
func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), callback ShellCreatedCallback) error {
return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error {
func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error {
return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, shell io.ReadWriteCloser) error {
w := &sessionWriter{
tshOut: ns.nodeClient.TC.Stdout,
session: s,
Expand All @@ -531,8 +535,8 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa
}

// call the client-supplied callback
if callback != nil {
exit, err := callback(s, ns.nodeClient.Client, shell)
if shellCallback != nil {
exit, err := shellCallback(s, ns.nodeClient.Client, shell)
if exit {
return trace.Wrap(err)
}
Expand All @@ -543,7 +547,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa

// runCommand executes a "exec" request either in interactive mode (with a
// TTY attached) or non-intractive mode (no TTY).
func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, callback ShellCreatedCallback, interactive bool) error {
func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, chanReqCallback tracessh.ChannelRequestCallback, shellCallback ShellCreatedCallback, interactive bool) error {
ctx, span := ns.nodeClient.Tracer.Start(
ctx,
"nodeClient/runCommand",
Expand All @@ -564,13 +568,13 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici
// keyboard based signals will be propogated to the TTY on the server which is
// where all signal handling will occur.
if interactive {
return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error {
return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, term io.ReadWriteCloser) error {
err := s.Start(ctx, strings.Join(cmd, " "))
if err != nil {
return trace.Wrap(err)
}
if callback != nil {
exit, err := callback(s, ns.NodeClient().Client, term)
if shellCallback != nil {
exit, err := shellCallback(s, ns.NodeClient().Client, term)
if exit {
return trace.Wrap(err)
}
Expand All @@ -591,7 +595,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici
// Unfortunately at the moment the Go SSH library Teleport uses does not
// support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and
// SSH client, and try and exit as gracefully as possible.
return ns.regularSession(ctx, func(s *tracessh.Session) error {
return ns.regularSession(ctx, chanReqCallback, func(s *tracessh.Session) error {
errCh := make(chan error, 1)
go func() {
errCh <- s.Run(ctx, strings.Join(cmd, " "))
Expand Down
Loading

0 comments on commit 646bd5f

Please sign in to comment.