From 646bd5f68be3cc9b44cd1190caca00956b5e92df Mon Sep 17 00:00:00 2001 From: Andrew LeFevre Date: Wed, 29 May 2024 22:29:59 -0400 Subject: [PATCH] [v16] prevent session id reuse (#42133) * prevent SSH session id reuse If a user sets `TELEPORT_SESSION` to a session ID that isn't running, the sent session ID is ignored and a new one is used to prevent the user controlling the value. When parsing a session ID the parsed result is used to ensure all session IDs will have the same format. `CreateSessionTracker` is updated to create instead of upsert to ensure `SessionTracker` resources cannot be overwritten by accident. Clients often need to know the current session ID, and before when they were in control of it that wasn't an issue. To fix this servers will send the current session ID in a SSH channel request on the channel the SSH session was created on. This should be fully backwards compatible, as when an unknown SSH request type is received clients will just ignore it/reply with false. * add test for wrappedSSHConn --- api/observability/tracing/ssh/client.go | 92 ++++++++++++++++++-- api/observability/tracing/ssh/client_test.go | 26 ++++++ constants.go | 8 ++ integration/integration_test.go | 2 +- lib/auth/auth_with_roles_test.go | 2 +- lib/client/api.go | 8 +- lib/client/client.go | 6 +- lib/client/client_test.go | 1 - lib/client/session.go | 38 ++++---- lib/services/local/sessiontracker.go | 2 +- lib/session/session.go | 11 ++- lib/srv/ctx.go | 32 ++++++- lib/srv/ctx_test.go | 75 ++++++++++++++++ lib/srv/forward/sshserver.go | 7 ++ lib/srv/mock.go | 9 +- lib/srv/regular/sshserver.go | 7 ++ lib/srv/sess.go | 25 ++---- lib/srv/sess_test.go | 4 +- lib/web/sessions.go | 84 +++++++++++++++++- lib/web/terminal.go | 57 ++++++++++-- lib/web/terminal/terminal.go | 40 +++++---- 21 files changed, 449 insertions(+), 87 deletions(-) diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index e79686adf8b42..dd0a20c1d3217 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" @@ -164,6 +165,18 @@ func (c *Client) OpenChannel( // NewSession creates a new SSH session that is passed tracing context // so that spans may be correlated properly over the ssh connection. func (c *Client) NewSession(ctx context.Context) (*Session, error) { + return c.newSession(ctx, nil) +} + +// NewSessionWithRequestCallback creates a new SSH session that is passed +// tracing context so that spans may be correlated properly over the ssh +// connection. The handling of channel requests from the underlying SSH +// session can be controlled with chanReqCallback. +func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { + return c.newSession(ctx, chanReqCallback) +} + +func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -191,7 +204,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { } // get a session from the wrapper - session, err := wrapper.NewSession() + session, err := wrapper.NewSession(chanReqCallback) return session, trace.Wrap(err) } @@ -216,8 +229,13 @@ type clientWrapper struct { contexts map[string][]context.Context } +// ChannelRequestCallback allows the handling of channel requests +// to be customized. nil can be returned if you don't want +// golang/x/crypto/ssh to handle the request. +type ChannelRequestCallback func(req *ssh.Request) *ssh.Request + // NewSession opens a new Session for this client. -func (c *clientWrapper) NewSession() (*Session, error) { +func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) { // create a client that will defer to us when // opening the "session" channel so that we // can add an Envelope to the request @@ -225,9 +243,40 @@ func (c *clientWrapper) NewSession() (*Session, error) { Conn: c, } - session, err := client.NewSession() - if err != nil { - return nil, trace.Wrap(err) + var session *ssh.Session + var err error + if callback != nil { + // open a session manually so we can take ownership of the + // requests chan + ch, originalReqs, openChannelErr := client.OpenChannel("session", nil) + if openChannelErr != nil { + return nil, trace.Wrap(openChannelErr) + } + + // pass the channel requests to the provided callback and + // forward them to another chan so golang.org/x/crypto/ssh + // can handle Session exiting correctly + reqs := make(chan *ssh.Request, cap(originalReqs)) + go func() { + defer close(reqs) + + for req := range originalReqs { + if req := callback(req); req != nil { + reqs <- req + } + } + }() + + session, err = newCryptoSSHSession(ch, reqs) + if err != nil { + _ = ch.Close() + return nil, trace.Wrap(err) + } + } else { + session, err = client.NewSession() + if err != nil { + return nil, trace.Wrap(err) + } } // wrap the session so all session requests on the channel @@ -238,6 +287,39 @@ func (c *clientWrapper) NewSession() (*Session, error) { }, nil } +// wrappedSSHConn allows an SSH session to be created while also allowing +// callers to take ownership of the SSH channel requests chan. +type wrappedSSHConn struct { + ssh.Conn + + channelOpened atomic.Bool + + ch ssh.Channel + reqs <-chan *ssh.Request +} + +func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) { + if !f.channelOpened.CompareAndSwap(false, true) { + panic("wrappedSSHConn OpenChannel called more than once") + } + + return f.ch, f.reqs, nil +} + +// newCryptoSSHSession allows callers to take ownership of the SSH +// channel requests chan and allow callers to handle SSH channel requests. +// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all +// SSH channel requests and doesn't allow the caller to view or reply +// to them, so this workaround is needed. +func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) { + return (&ssh.Client{ + Conn: &wrappedSSHConn{ + ch: ch, + reqs: reqs, + }, + }).NewSession() +} + // Dial initiates a connection to the addr from the remote host. func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) { // create a client that will defer to us when diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index 3fb7a297369ac..b59549f2181bc 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -242,3 +242,29 @@ func TestSetEnvs(t *testing.T) { default: } } + +type mockSSHChannel struct { + ssh.Channel +} + +func TestWrappedSSHConn(t *testing.T) { + sshCh := new(mockSSHChannel) + reqs := make(<-chan *ssh.Request) + + // ensure that OpenChannel returns the same SSH channel and requests + // chan that wrappedSSHConn was given + wrappedConn := &wrappedSSHConn{ + ch: sshCh, + reqs: reqs, + } + retCh, retReqs, err := wrappedConn.OpenChannel("", nil) + require.NoError(t, err) + require.Equal(t, sshCh, retCh) + require.Equal(t, reqs, retReqs) + + // ensure the wrapped SSH conn will panic if OpenChannel is called + // twice + require.Panics(t, func() { + wrappedConn.OpenChannel("", nil) + }) +} diff --git a/constants.go b/constants.go index 5d068f67bb6cd..e02ddac503a62 100644 --- a/constants.go +++ b/constants.go @@ -731,6 +731,14 @@ const ( // version they are running. VersionRequest = "x-teleport-version" + // CurrentSessionIDRequest is sent by servers to inform clients of + // the session ID that is being used. + CurrentSessionIDRequest = "current-session-id@goteleport.com" + + // SessionIDQueryRequest is sent by clients to ask servers if they + // will generate their own session ID when a new session is created. + SessionIDQueryRequest = "session-id-query@goteleport.com" + // ForceTerminateRequest is an SSH request to forcefully terminate a session. ForceTerminateRequest = "x-teleport-force-terminate" diff --git a/integration/integration_test.go b/integration/integration_test.go index fc5dd0fa7a228..7742ad071e2fd 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1236,7 +1236,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { ) assert.NoError(t, err) - errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) + errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil, nil) assert.NoError(t, nodeClient.Close()) }() diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index b0e8e2b4f0f78..db0820f3df0dc 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -5717,7 +5717,7 @@ func TestLocalServiceRolesHavePermissionsForUploaderService(t *testing.T) { } t.Run("GetSessionTracker", func(t *testing.T) { - sid := session.ID("test-session") + sid := session.NewID() tracker, err := s.CreateSessionTracker(ctx, &types.SessionTrackerV1{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ diff --git a/lib/client/api.go b/lib/client/api.go index 49e5aadc5212d..8dfce0e029be5 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1111,6 +1111,10 @@ type TeleportClient struct { Config localAgent *LocalKeyAgent + // OnChannelRequest gets called when SSH channel requests are + // received. It's safe to keep it nil. + OnChannelRequest tracessh.ChannelRequestCallback + // OnShellCreated gets called when the shell is created. It's // safe to keep it nil. OnShellCreated ShellCreatedCallback @@ -1941,7 +1945,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // Reuse the existing nodeClient we connected above. return nodeClient.RunCommand(ctx, command) } - return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)) + return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil)) } func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error { @@ -2100,7 +2104,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan fmt.Printf("Joining session with participant mode: %v. \n\n", mode) // running shell with a given session means "join" it: - err = nc.RunInteractiveShell(ctx, mode, session, beforeStart) + err = nc.RunInteractiveShell(ctx, mode, session, tc.OnChannelRequest, beforeStart) return trace.Wrap(err) } diff --git a/lib/client/client.go b/lib/client/client.go index 1df323824e2a2..9fe7bd1508486 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -366,7 +366,7 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co // RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", @@ -395,7 +395,7 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session return trace.Wrap(err) } - if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil { + if err = nodeSession.runShell(ctx, mode, c.TC.OnChannelRequest, beforeStart, c.TC.OnShellCreated); err != nil { var exitErr *ssh.ExitError var exitMissingErr *ssh.ExitMissingError switch err := trace.Unwrap(err); { @@ -539,7 +539,7 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R return trace.Wrap(err) } defer nodeSession.Close() - if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil { + if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnChannelRequest, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil { originErr := trace.Unwrap(err) var exitErr *ssh.ExitError if errors.As(originErr, &exitErr) { diff --git a/lib/client/client_test.go b/lib/client/client_test.go index 43a1017266a32..8a7273eb04d95 100644 --- a/lib/client/client_test.go +++ b/lib/client/client_test.go @@ -248,7 +248,6 @@ func TestListenAndForwardCancel(t *testing.T) { } }) } - } func newTestListener(t *testing.T, handle func(net.Conn)) net.Listener { diff --git a/lib/client/session.go b/lib/client/session.go index eda85b389be4a..4bf05e0b8bbff 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -157,6 +157,10 @@ func newSession(ctx context.Context, } // new session! } else { + // TODO(capnspacehook): DELETE IN 17.0.0 + // clients shouldn't set TELEPORT_SESSION when they aren't joining + // a session, and won't need to once all supported Proxy/Node + // versions set the session ID for new sessions sid, ok := ns.env[sshutils.SessionEnvVar] if !ok { sid = string(session.NewID()) @@ -188,7 +192,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -196,19 +200,19 @@ func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *trac ) defer span.End() - session, err := ns.createServerSession(ctx) + session, err := ns.createServerSession(ctx, chanReqCallback) if err != nil { return trace.Wrap(err) } session.Stdout = ns.terminal.Stdout() session.Stderr = ns.terminal.Stderr() session.Stdin = ns.terminal.Stdin() - return trace.Wrap(callback(session)) + return trace.Wrap(sessionCallback(session)) } type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -216,7 +220,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi ) defer span.End() - sess, err := ns.nodeClient.Client.NewSession(ctx) + sess, err := ns.nodeClient.Client.NewSessionWithRequestCallback(ctx, chanReqCallback) if err != nil { return nil, trace.Wrap(err) } @@ -284,7 +288,7 @@ func selectKeyAgent(tc *TeleportClient) agent.ExtendedAgent { // interactiveSession creates an interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, callback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -298,7 +302,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio termType = teleport.SafeTerminalType } // create the server-side session: - sess, err := ns.createServerSession(ctx) + sess, err := ns.createServerSession(ctx, chanReqCallback) if err != nil { return trace.Wrap(err) } @@ -311,7 +315,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio // call the passed callback and give them the established // ssh session: - if err := callback(sess, remoteTerm); err != nil { + if err := sessionCallback(sess, remoteTerm); err != nil { return trace.Wrap(err) } @@ -514,8 +518,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), callback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { +func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { + return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, session: s, @@ -531,8 +535,8 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa } // call the client-supplied callback - if callback != nil { - exit, err := callback(s, ns.nodeClient.Client, shell) + if shellCallback != nil { + exit, err := shellCallback(s, ns.nodeClient.Client, shell) if exit { return trace.Wrap(err) } @@ -543,7 +547,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, callback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, chanReqCallback tracessh.ChannelRequestCallback, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand", @@ -564,13 +568,13 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, term io.ReadWriteCloser) error { err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) } - if callback != nil { - exit, err := callback(s, ns.NodeClient().Client, term) + if shellCallback != nil { + exit, err := shellCallback(s, ns.NodeClient().Client, term) if exit { return trace.Wrap(err) } @@ -591,7 +595,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. - return ns.regularSession(ctx, func(s *tracessh.Session) error { + return ns.regularSession(ctx, chanReqCallback, func(s *tracessh.Session) error { errCh := make(chan error, 1) go func() { errCh <- s.Run(ctx, strings.Join(cmd, " ")) diff --git a/lib/services/local/sessiontracker.go b/lib/services/local/sessiontracker.go index f24c5f88ba820..780344b208013 100644 --- a/lib/services/local/sessiontracker.go +++ b/lib/services/local/sessiontracker.go @@ -192,7 +192,7 @@ func (s *sessionTracker) CreateSessionTracker(ctx context.Context, tracker types Value: json, Expires: tracker.Expiry(), } - _, err = s.bk.Put(ctx, item) + _, err = s.bk.Create(ctx, item) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/session/session.go b/lib/session/session.go index 76c02d1110029..c145823e5f79b 100644 --- a/lib/session/session.go +++ b/lib/session/session.go @@ -54,17 +54,22 @@ func (s *ID) Check() error { // ParseID parses ID and checks if it's correct. func ParseID(id string) (*ID, error) { - _, err := uuid.Parse(id) + parsed, err := uuid.Parse(id) if err != nil { return nil, trace.BadParameter("%v is not a valid UUID", id) } - uid := ID(id) + // use the parsed UUID to build the ID instead of the string that + // was passed in. id is user controlled and uuid.Parse accepts + // several UUID formats that are not supported correctly across + // Teleport. (uuid.UUID).String always uses the same format that + // is supported by Teleport everywhere, so use that. + uid := ID(parsed.String()) return &uid, nil } // NewID returns new session ID. The session ID is based on UUIDv4. func NewID() ID { - return ID(uuid.New().String()) + return ID(uuid.NewString()) } // Session is a session of any kind (SSH, Kubernetes, Desktop, etc) diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index f836a2f3e4748..815b94df99a12 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -319,6 +319,10 @@ type ServerContext struct { // term holds PTY if it was requested by the session. term Terminal + // sessionID holds the session ID that will be used when a new + // session is created. + sessionID rsession.ID + // session holds the active session (if there's an active one). session *session @@ -634,7 +638,7 @@ func (c *ServerContext) SessionID() rsession.ID { c.mu.RLock() defer c.mu.RUnlock() if c.session == nil { - return "" + return c.sessionID } return c.session.id } @@ -653,19 +657,31 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error { // its ID will be added to the environment ssid, found := c.getEnvLocked(sshutils.SessionEnvVar) if !found { + c.sessionID = rsession.NewID() + c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr()) return nil } + // make sure whatever session is requested is a valid session id, err := rsession.ParseID(ssid) if err != nil { - return trace.BadParameter("invalid session id") + return trace.BadParameter("invalid session ID") } // update ctx with the session if it exists if sess, found := reg.findSession(*id); found { + c.sessionID = *id c.session = sess c.Logger.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr()) } else { + // TODO(capnspacehook): DELETE IN 17.0.0 - by then all supported + // clients should only set TELEPORT_SESSION when they want to + // join a session. Always return an error instead of using a + // new ID. + // + // to prevent the user from controlling the session ID, generate + // a new one + c.sessionID = rsession.NewID() c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr()) } @@ -737,10 +753,20 @@ func (c *ServerContext) getEnvLocked(key string) (string, bool) { } // setSession sets the context's session -func (c *ServerContext) setSession(sess *session) { +func (c *ServerContext) setSession(sess *session, ch ssh.Channel) { c.mu.Lock() defer c.mu.Unlock() c.session = sess + + // inform the client of the session ID that is being used in a new + // goroutine to reduce latency + go func() { + c.Logger.Debug("Sending current session ID.") + _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID())) + if err != nil { + c.Logger.WithError(err).Debug("Failed to send the current session ID.") + } + }() } // getSession returns the context's session diff --git a/lib/srv/ctx_test.go b/lib/srv/ctx_test.go index d85ea20f7001f..4a630e8d9741d 100644 --- a/lib/srv/ctx_test.go +++ b/lib/srv/ctx_test.go @@ -20,6 +20,7 @@ package srv import ( "bytes" + "context" "os/user" "testing" @@ -34,6 +35,8 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/sshutils" ) // TestDecodeChildError ensures that child error message marshaling @@ -302,3 +305,75 @@ type fixedRolesChecker struct { func (c *fixedRolesChecker) RoleNames() []string { return c.roleNames } + +func TestCreateOrJoinSession(t *testing.T) { + t.Parallel() + + srv := newMockServer(t) + registry, err := NewSessionRegistry(SessionRegistryConfig{ + clock: srv.clock, + Srv: srv, + SessionTrackerService: srv.auth, + }) + require.NoError(t, err) + + runningSessionID := rsession.NewID() + sess, _, err := newSession(context.Background(), runningSessionID, registry, newTestServerContext(t, srv, nil), newMockSSHChannel()) + require.NoError(t, err) + registry.sessions[runningSessionID] = sess + + tests := []struct { + name string + sessionID string + wantSameSessionID bool + }{ + { + name: "no session ID", + wantSameSessionID: false, + }, + // TODO(capnspacehook): Check that an error is returned in v17 + { + name: "new session ID", + sessionID: string(rsession.NewID()), + wantSameSessionID: false, + }, + { + name: "existing session ID", + sessionID: runningSessionID.String(), + wantSameSessionID: true, + }, + { + name: "existing session ID in Windows format", + sessionID: "{" + runningSessionID.String() + "}", + wantSameSessionID: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + parsedSessionID := new(rsession.ID) + var err error + if tt.sessionID != "" { + parsedSessionID, err = rsession.ParseID(tt.sessionID) + require.NoError(t, err) + } + + ctx := newTestServerContext(t, srv, nil) + if tt.sessionID != "" { + ctx.SetEnv(sshutils.SessionEnvVar, tt.sessionID) + } + + err = ctx.CreateOrJoinSession(registry) + require.NoError(t, err) + require.False(t, ctx.sessionID.IsZero()) + if tt.wantSameSessionID { + require.Equal(t, parsedSessionID.String(), ctx.sessionID.String()) + } else { + require.NotEqual(t, parsedSessionID.String(), ctx.sessionID.String()) + } + }) + } +} diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 202a3a875cfa6..a0e2a70adf93f 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -967,6 +967,13 @@ func (s *Server) handleGlobalRequest(ctx context.Context, req *ssh.Request) { return } // Pass request on unchanged. + case teleport.SessionIDQueryRequest: + // Reply true to session ID query requests, we will set new + // session IDs for new sessions + if err := req.Reply(true, nil); err != nil { + s.log.WithError(err).Warnf("Failed to reply to session ID query request") + } + return case teleport.KeepAliveReqType: default: s.log.Debugf("Rejecting unknown global request %q.", req.Type) diff --git a/lib/srv/mock.go b/lib/srv/mock.go index fa0859cdded69..f3302bf0dbf64 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" ) @@ -65,16 +66,16 @@ func newTestServerContext(t *testing.T, srv Server, roleSet services.RoleSet) *S recConfig := types.DefaultSessionRecordingConfig() recConfig.SetMode(types.RecordOff) clusterName := "localhost" + _, connCtx := sshutils.NewConnectionContext(ctx, nil, &ssh.ServerConn{Conn: sshConn}) scx := &ServerContext{ - Entry: logrus.NewEntry(logrus.StandardLogger()), - ConnectionContext: &sshutils.ConnectionContext{ - ServerConn: &ssh.ServerConn{Conn: sshConn}, - }, + Entry: logrus.NewEntry(logrus.StandardLogger()), + ConnectionContext: connCtx, env: make(map[string]string), SessionRecordingConfig: recConfig, IsTestStub: true, ClusterName: clusterName, srv: srv, + sessionID: rsession.NewID(), Identity: IdentityContext{ Login: usr.Username, TeleportUser: "teleportUser", diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 74e7ce3f10c99..ed8e1c347fbf6 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1381,6 +1381,13 @@ func (s *Server) HandleRequest(ctx context.Context, ccx *sshutils.ConnectionCont s.Logger.Warnf("Failed to reply to %q request: %v", r.Type, err) } } + case teleport.SessionIDQueryRequest: + // Reply true to session ID query requests, we will set new + // session IDs for new sessions + if err := r.Reply(true, nil); err != nil { + s.Logger.WithError(err).Warnf("Failed to reply to session ID query request") + } + return default: if r.WantReply { if err := r.Reply(false, nil); err != nil { diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 91f56772ccc52..0204c1529c399 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -55,7 +55,6 @@ import ( "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" rsession "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/utils" ) @@ -325,20 +324,18 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx * return trace.AccessDenied("join-only mode was used to create this connection but attempted to create a new session.") } - // session not found? need to create one. start by getting/generating an ID for it - sid, found := scx.GetEnv(sshutils.SessionEnvVar) - if !found { - sid = string(rsession.NewID()) - scx.SetEnv(sshutils.SessionEnvVar, sid) + sid := scx.SessionID() + if sid.IsZero() { + return trace.BadParameter("session ID is not set") } // This logic allows concurrent request to create a new session // to fail, what is ok because we should never have this condition - sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch) + sess, p, err := newSession(ctx, sid, s, scx, ch) if err != nil { return trace.Wrap(err) } - scx.setSession(sess) + scx.setSession(sess, ch) s.addSession(sess) scx.Infof("Creating (interactive) session %v.", sid) @@ -353,22 +350,18 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx * // OpenExecSession opens a non-interactive exec session. func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Channel, scx *ServerContext) error { - var sessionID rsession.ID + sessionID := scx.SessionID() - sid, found := scx.GetEnv(sshutils.SessionEnvVar) - if !found { - // Create a new session ID. These sessions can not be joined + if sessionID.IsZero() { sessionID = rsession.NewID() scx.Tracef("Session not found, creating a new session %s", sessionID) } else { // Use passed session ID. Assist uses this "feature" to record // the execution output. - sessionID = rsession.ID(sid) scx.Tracef("Session found, reusing it %s", sessionID) } - _, found = scx.GetEnv(teleport.EnableNonInteractiveSessionRecording) - if found { + if _, found := scx.GetEnv(teleport.EnableNonInteractiveSessionRecording); found { scx.recordNonInteractiveSession = true } @@ -398,7 +391,7 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann // Start a non-interactive session (TTY attached). Close the session if an error // occurs, otherwise it will be closed by the callee. - scx.setSession(sess) + scx.setSession(sess, channel) err = sess.startExec(ctx, channel, scx) if err != nil { diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index f55c4c517ac7c..1cd0acc2870b4 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -609,10 +609,10 @@ func TestParties(t *testing.T) { func testJoinSession(t *testing.T, reg *SessionRegistry, sess *session) { scx := newTestServerContext(t, reg.Srv, nil) - scx.setSession(sess) + sshChanOpen := newMockSSHChannel() + scx.setSession(sess, sshChanOpen) // Open a new session - sshChanOpen := newMockSSHChannel() go func() { // Consume stdout sent to the channel io.ReadAll(sshChanOpen) diff --git a/lib/web/sessions.go b/lib/web/sessions.go index 126d143a30d45..90594cc1ed60d 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -27,6 +27,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/gravitational/trace" @@ -56,6 +57,7 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" @@ -1004,7 +1006,6 @@ func (s *sessionCache) getOrCreateSession(ctx context.Context, user, sessionID s } return sctx, nil - } func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionContext) error { @@ -1326,3 +1327,84 @@ func (c *remoteClientCache) Close() error { return trace.NewAggregate(errors...) } + +// sessionIDStatus indicates whether the session ID was received from +// the server or not, and if not why +type sessionIDStatus int + +const ( + // sessionIDReceived indicates the the session ID was received + sessionIDReceived sessionIDStatus = iota + 1 + // sessionIDNotSent indicates that the server set the session ID + // but didn't send it to us + sessionIDNotSent + // sessionIDNotModified indicates that the server used the session + // ID that was set by us + sessionIDNotModified +) + +// prepareToReceiveSessionID configures the TeleportClient to listen for +// the server to send the session ID it's using. The returned function +// will return the current session ID from the server or a reason why +// one wasn't received. +func prepareToReceiveSessionID(ctx context.Context, log *logrus.Entry, nc *client.NodeClient) func() (session.ID, sessionIDStatus) { + // send the session ID received from the server + var gotSessionID atomic.Bool + sessionIDFromServer := make(chan session.ID, 1) + nc.TC.OnChannelRequest = func(req *ssh.Request) *ssh.Request { + // ignore unrelated requests and handle only the first session + // ID request + if req.Type != teleport.CurrentSessionIDRequest || gotSessionID.Load() { + return req + } + + sid, err := session.ParseID(string(req.Payload)) + if err != nil { + log.WithError(err).Warn("Unable to parse session ID.") + return nil + } + + if gotSessionID.CompareAndSwap(false, true) { + sessionIDFromServer <- *sid + } + + return nil + } + + // If the session is about to close and we haven't received a session + // ID yet, ask if the server even supports sending one. Send the + // request in a new goroutine so session establishment won't be + // blocked on making this request + serverWillSetSessionID := make(chan bool, 1) + go func() { + resp, _, err := nc.Client.SendRequest(ctx, teleport.SessionIDQueryRequest, true, nil) + if err != nil { + log.WithError(err).Warn("Failed to send session ID query request") + serverWillSetSessionID <- false + } else { + serverWillSetSessionID <- resp + } + }() + + return func() (session.ID, sessionIDStatus) { + timer := time.NewTimer(10 * time.Second) + defer timer.Stop() + + for { + select { + case sessionID := <-sessionIDFromServer: + return sessionID, sessionIDReceived + case sessionIDIsComing := <-serverWillSetSessionID: + if !sessionIDIsComing { + return session.ID(""), sessionIDNotModified + } + // the server will send the session ID, continue + // waiting for it + case <-ctx.Done(): + return session.ID(""), sessionIDNotSent + case <-timer.C: + return session.ID(""), sessionIDNotSent + } + } + } +} diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 14d73b08bbbec..9a4a3b474036f 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -201,9 +201,11 @@ type TerminalHandlerConfig struct { func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { // Make sure whatever session is requested is a valid session id. - _, err := session.ParseID(t.SessionData.ID.String()) - if err != nil { - return trace.BadParameter("sid: invalid session id") + if !t.SessionData.ID.IsZero() { + _, err := session.ParseID(t.SessionData.ID.String()) + if err != nil { + return trace.BadParameter("sid: invalid session id") + } } if t.SessionData.Login == "" { @@ -349,7 +351,7 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.handler(ws, r) } -func (t *TerminalHandler) writeSessionData(ctx context.Context) error { +func (t *TerminalHandler) writeSessionData() error { envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketSessionMetadata, @@ -786,10 +788,6 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor return } - if err := t.writeSessionData(ctx); err != nil { - t.log.WithError(err).Warn("Unable to stream terminal - failure sending session data") - } - var beforeStart func(io.Writer) if t.participantMode == types.SessionModeratorMode { beforeStart = func(out io.Writer) { @@ -810,9 +808,47 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor } }() + sessionDataSent := make(chan struct{}) + // If we are joining a session, send the session data right away, we + // know the session ID + if t.tracker != nil { + if err := t.writeSessionData(); err != nil { + t.log.WithError(err).Warn("Failure sending session data") + } + close(sessionDataSent) + } else { + // We are creating a new session and the server will generate a + // new session ID, send the session data once the session is + // created and the server sends us the session ID it is using + writeSessionCtx, writeSessionCancel := context.WithCancel(ctx) + defer writeSessionCancel() + waitForSessionID := prepareToReceiveSessionID(writeSessionCtx, t.log, nc) + + // wait in a new goroutine because the server won't set a + // session ID until we open a shell + go func() { + defer close(sessionDataSent) + + sid, status := waitForSessionID() + switch status { + case sessionIDReceived: + t.sessionData.ID = sid + fallthrough + case sessionIDNotModified: + if err := t.writeSessionData(); err != nil { + t.log.WithError(err).Warn("Failure sending session data") + } + case sessionIDNotSent: + t.log.Warn("Failed to receive session data") + default: + t.log.Warnf("Invalid session ID status %v", status) + } + }() + } + // Establish SSH connection to the server. This function will block until // either an error occurs or it completes successfully. - if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, beforeStart); err != nil { + if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, nil, beforeStart); err != nil { if !t.closedByClient.Load() { t.stream.WriteError(err.Error()) } @@ -823,6 +859,9 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor return } + // Wait for the session data to be sent before closing the session + <-sessionDataSent + // Send close envelope to web terminal upon exit without an error. if err := t.stream.SendCloseMessage(t.sessionData.ServerID); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") diff --git a/lib/web/terminal/terminal.go b/lib/web/terminal/terminal.go index 2ad18331a673b..c92dae03fb2d1 100644 --- a/lib/web/terminal/terminal.go +++ b/lib/web/terminal/terminal.go @@ -87,9 +87,11 @@ type WSStream struct { // fit into the buffer provided by the callee to Read method buffer []byte - // mu protects writes to ws - mu sync.Mutex - // ws the connection to the UI + // readMu protects reads to WSConn + readMu sync.Mutex + // writeMu protects writes to WSConn + writeMu sync.Mutex + // WSConn the connection to the UI WSConn // log holds the structured logger. @@ -99,6 +101,18 @@ type WSStream struct { // Replace \n with \r\n so the message is correctly aligned. var replacer = strings.NewReplacer("\r\n", "\r\n", "\n", "\r\n") +func (t *WSStream) ReadMessage() (messageType int, p []byte, err error) { + t.readMu.Lock() + defer t.readMu.Unlock() + return t.WSConn.ReadMessage() +} + +func (t *WSStream) WriteMessage(messageType int, data []byte) error { + t.writeMu.Lock() + defer t.writeMu.Unlock() + return t.WSConn.WriteMessage(messageType, data) +} + // WriteError displays an error in the terminal window. func (t *WSStream) WriteError(msg string) { if _, writeErr := replacer.WriteString(t, msg); writeErr != nil { @@ -214,9 +228,7 @@ func (t *WSStream) WriteChallenge(challenge *client.MFAAuthenticateChallenge, co return trace.Wrap(err) } - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, msg)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, msg)) } // ReadChallengeResponse reads and decodes the challenge response from the @@ -262,9 +274,7 @@ func (t *WSStream) WriteAuditEvent(event []byte) error { } // Send bytes over the websocket to the web client. - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } // SSHSessionLatencyStats contain latency measurements for both @@ -302,9 +312,7 @@ func (t *WSStream) WriteLatency(latency SSHSessionLatencyStats) error { } // Send bytes over the websocket to the web client. - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } // Write wraps the data bytes in a raw envelope and sends. @@ -325,9 +333,7 @@ func (t *WSStream) Write(data []byte) (n int, err error) { } // Send bytes over the websocket to the web client. - t.mu.Lock() - err = t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes) - t.mu.Unlock() + err = t.WriteMessage(websocket.BinaryMessage, envelopeBytes) if err != nil { return 0, trace.Wrap(err) } @@ -391,9 +397,7 @@ func (t *WSStream) SendCloseMessage(id string) error { return trace.Wrap(err) } - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } func (t *WSStream) close() {