diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index e79686adf8b42..dd0a20c1d3217 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -20,6 +20,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" @@ -164,6 +165,18 @@ func (c *Client) OpenChannel( // NewSession creates a new SSH session that is passed tracing context // so that spans may be correlated properly over the ssh connection. func (c *Client) NewSession(ctx context.Context) (*Session, error) { + return c.newSession(ctx, nil) +} + +// NewSessionWithRequestCallback creates a new SSH session that is passed +// tracing context so that spans may be correlated properly over the ssh +// connection. The handling of channel requests from the underlying SSH +// session can be controlled with chanReqCallback. +func (c *Client) NewSessionWithRequestCallback(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { + return c.newSession(ctx, chanReqCallback) +} + +func (c *Client) newSession(ctx context.Context, chanReqCallback ChannelRequestCallback) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -191,7 +204,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { } // get a session from the wrapper - session, err := wrapper.NewSession() + session, err := wrapper.NewSession(chanReqCallback) return session, trace.Wrap(err) } @@ -216,8 +229,13 @@ type clientWrapper struct { contexts map[string][]context.Context } +// ChannelRequestCallback allows the handling of channel requests +// to be customized. nil can be returned if you don't want +// golang/x/crypto/ssh to handle the request. +type ChannelRequestCallback func(req *ssh.Request) *ssh.Request + // NewSession opens a new Session for this client. -func (c *clientWrapper) NewSession() (*Session, error) { +func (c *clientWrapper) NewSession(callback ChannelRequestCallback) (*Session, error) { // create a client that will defer to us when // opening the "session" channel so that we // can add an Envelope to the request @@ -225,9 +243,40 @@ func (c *clientWrapper) NewSession() (*Session, error) { Conn: c, } - session, err := client.NewSession() - if err != nil { - return nil, trace.Wrap(err) + var session *ssh.Session + var err error + if callback != nil { + // open a session manually so we can take ownership of the + // requests chan + ch, originalReqs, openChannelErr := client.OpenChannel("session", nil) + if openChannelErr != nil { + return nil, trace.Wrap(openChannelErr) + } + + // pass the channel requests to the provided callback and + // forward them to another chan so golang.org/x/crypto/ssh + // can handle Session exiting correctly + reqs := make(chan *ssh.Request, cap(originalReqs)) + go func() { + defer close(reqs) + + for req := range originalReqs { + if req := callback(req); req != nil { + reqs <- req + } + } + }() + + session, err = newCryptoSSHSession(ch, reqs) + if err != nil { + _ = ch.Close() + return nil, trace.Wrap(err) + } + } else { + session, err = client.NewSession() + if err != nil { + return nil, trace.Wrap(err) + } } // wrap the session so all session requests on the channel @@ -238,6 +287,39 @@ func (c *clientWrapper) NewSession() (*Session, error) { }, nil } +// wrappedSSHConn allows an SSH session to be created while also allowing +// callers to take ownership of the SSH channel requests chan. +type wrappedSSHConn struct { + ssh.Conn + + channelOpened atomic.Bool + + ch ssh.Channel + reqs <-chan *ssh.Request +} + +func (f *wrappedSSHConn) OpenChannel(_ string, _ []byte) (ssh.Channel, <-chan *ssh.Request, error) { + if !f.channelOpened.CompareAndSwap(false, true) { + panic("wrappedSSHConn OpenChannel called more than once") + } + + return f.ch, f.reqs, nil +} + +// newCryptoSSHSession allows callers to take ownership of the SSH +// channel requests chan and allow callers to handle SSH channel requests. +// golang.org/x/crypto/ssh.(Client).NewSession takes ownership of all +// SSH channel requests and doesn't allow the caller to view or reply +// to them, so this workaround is needed. +func newCryptoSSHSession(ch ssh.Channel, reqs <-chan *ssh.Request) (*ssh.Session, error) { + return (&ssh.Client{ + Conn: &wrappedSSHConn{ + ch: ch, + reqs: reqs, + }, + }).NewSession() +} + // Dial initiates a connection to the addr from the remote host. func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) { // create a client that will defer to us when diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index 3fb7a297369ac..b59549f2181bc 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -242,3 +242,29 @@ func TestSetEnvs(t *testing.T) { default: } } + +type mockSSHChannel struct { + ssh.Channel +} + +func TestWrappedSSHConn(t *testing.T) { + sshCh := new(mockSSHChannel) + reqs := make(<-chan *ssh.Request) + + // ensure that OpenChannel returns the same SSH channel and requests + // chan that wrappedSSHConn was given + wrappedConn := &wrappedSSHConn{ + ch: sshCh, + reqs: reqs, + } + retCh, retReqs, err := wrappedConn.OpenChannel("", nil) + require.NoError(t, err) + require.Equal(t, sshCh, retCh) + require.Equal(t, reqs, retReqs) + + // ensure the wrapped SSH conn will panic if OpenChannel is called + // twice + require.Panics(t, func() { + wrappedConn.OpenChannel("", nil) + }) +} diff --git a/constants.go b/constants.go index 5d068f67bb6cd..e02ddac503a62 100644 --- a/constants.go +++ b/constants.go @@ -731,6 +731,14 @@ const ( // version they are running. VersionRequest = "x-teleport-version" + // CurrentSessionIDRequest is sent by servers to inform clients of + // the session ID that is being used. + CurrentSessionIDRequest = "current-session-id@goteleport.com" + + // SessionIDQueryRequest is sent by clients to ask servers if they + // will generate their own session ID when a new session is created. + SessionIDQueryRequest = "session-id-query@goteleport.com" + // ForceTerminateRequest is an SSH request to forcefully terminate a session. ForceTerminateRequest = "x-teleport-force-terminate" diff --git a/integration/integration_test.go b/integration/integration_test.go index fc5dd0fa7a228..7742ad071e2fd 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1236,7 +1236,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { ) assert.NoError(t, err) - errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) + errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil, nil) assert.NoError(t, nodeClient.Close()) }() diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index b0e8e2b4f0f78..db0820f3df0dc 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -5717,7 +5717,7 @@ func TestLocalServiceRolesHavePermissionsForUploaderService(t *testing.T) { } t.Run("GetSessionTracker", func(t *testing.T) { - sid := session.ID("test-session") + sid := session.NewID() tracker, err := s.CreateSessionTracker(ctx, &types.SessionTrackerV1{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ diff --git a/lib/client/api.go b/lib/client/api.go index 49e5aadc5212d..8dfce0e029be5 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1111,6 +1111,10 @@ type TeleportClient struct { Config localAgent *LocalKeyAgent + // OnChannelRequest gets called when SSH channel requests are + // received. It's safe to keep it nil. + OnChannelRequest tracessh.ChannelRequestCallback + // OnShellCreated gets called when the shell is created. It's // safe to keep it nil. OnShellCreated ShellCreatedCallback @@ -1941,7 +1945,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // Reuse the existing nodeClient we connected above. return nodeClient.RunCommand(ctx, command) } - return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)) + return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, tc.OnChannelRequest, nil)) } func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []targetNode, command []string) error { @@ -2100,7 +2104,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan fmt.Printf("Joining session with participant mode: %v. \n\n", mode) // running shell with a given session means "join" it: - err = nc.RunInteractiveShell(ctx, mode, session, beforeStart) + err = nc.RunInteractiveShell(ctx, mode, session, tc.OnChannelRequest, beforeStart) return trace.Wrap(err) } diff --git a/lib/client/client.go b/lib/client/client.go index 1df323824e2a2..9fe7bd1508486 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -366,7 +366,7 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co // RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", @@ -395,7 +395,7 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session return trace.Wrap(err) } - if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil { + if err = nodeSession.runShell(ctx, mode, c.TC.OnChannelRequest, beforeStart, c.TC.OnShellCreated); err != nil { var exitErr *ssh.ExitError var exitMissingErr *ssh.ExitMissingError switch err := trace.Unwrap(err); { @@ -539,7 +539,7 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R return trace.Wrap(err) } defer nodeSession.Close() - if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil { + if err := nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnChannelRequest, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand); err != nil { originErr := trace.Unwrap(err) var exitErr *ssh.ExitError if errors.As(originErr, &exitErr) { diff --git a/lib/client/client_test.go b/lib/client/client_test.go index 43a1017266a32..8a7273eb04d95 100644 --- a/lib/client/client_test.go +++ b/lib/client/client_test.go @@ -248,7 +248,6 @@ func TestListenAndForwardCancel(t *testing.T) { } }) } - } func newTestListener(t *testing.T, handle func(net.Conn)) net.Listener { diff --git a/lib/client/session.go b/lib/client/session.go index eda85b389be4a..4bf05e0b8bbff 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -157,6 +157,10 @@ func newSession(ctx context.Context, } // new session! } else { + // TODO(capnspacehook): DELETE IN 17.0.0 + // clients shouldn't set TELEPORT_SESSION when they aren't joining + // a session, and won't need to once all supported Proxy/Node + // versions set the session ID for new sessions sid, ok := ns.env[sshutils.SessionEnvVar] if !ok { sid = string(session.NewID()) @@ -188,7 +192,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -196,19 +200,19 @@ func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *trac ) defer span.End() - session, err := ns.createServerSession(ctx) + session, err := ns.createServerSession(ctx, chanReqCallback) if err != nil { return trace.Wrap(err) } session.Stdout = ns.terminal.Stdout() session.Stderr = ns.terminal.Stderr() session.Stdin = ns.terminal.Stdin() - return trace.Wrap(callback(session)) + return trace.Wrap(sessionCallback(session)) } type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback tracessh.ChannelRequestCallback) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -216,7 +220,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi ) defer span.End() - sess, err := ns.nodeClient.Client.NewSession(ctx) + sess, err := ns.nodeClient.Client.NewSessionWithRequestCallback(ctx, chanReqCallback) if err != nil { return nil, trace.Wrap(err) } @@ -284,7 +288,7 @@ func selectKeyAgent(tc *TeleportClient) agent.ExtendedAgent { // interactiveSession creates an interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, callback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -298,7 +302,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio termType = teleport.SafeTerminalType } // create the server-side session: - sess, err := ns.createServerSession(ctx) + sess, err := ns.createServerSession(ctx, chanReqCallback) if err != nil { return trace.Wrap(err) } @@ -311,7 +315,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio // call the passed callback and give them the established // ssh session: - if err := callback(sess, remoteTerm); err != nil { + if err := sessionCallback(sess, remoteTerm); err != nil { return trace.Wrap(err) } @@ -514,8 +518,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), callback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { +func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, chanReqCallback tracessh.ChannelRequestCallback, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { + return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, session: s, @@ -531,8 +535,8 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa } // call the client-supplied callback - if callback != nil { - exit, err := callback(s, ns.nodeClient.Client, shell) + if shellCallback != nil { + exit, err := shellCallback(s, ns.nodeClient.Client, shell) if exit { return trace.Wrap(err) } @@ -543,7 +547,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, callback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, chanReqCallback tracessh.ChannelRequestCallback, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand", @@ -564,13 +568,13 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, mode, chanReqCallback, func(s *tracessh.Session, term io.ReadWriteCloser) error { err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) } - if callback != nil { - exit, err := callback(s, ns.NodeClient().Client, term) + if shellCallback != nil { + exit, err := shellCallback(s, ns.NodeClient().Client, term) if exit { return trace.Wrap(err) } @@ -591,7 +595,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. - return ns.regularSession(ctx, func(s *tracessh.Session) error { + return ns.regularSession(ctx, chanReqCallback, func(s *tracessh.Session) error { errCh := make(chan error, 1) go func() { errCh <- s.Run(ctx, strings.Join(cmd, " ")) diff --git a/lib/services/local/sessiontracker.go b/lib/services/local/sessiontracker.go index f24c5f88ba820..780344b208013 100644 --- a/lib/services/local/sessiontracker.go +++ b/lib/services/local/sessiontracker.go @@ -192,7 +192,7 @@ func (s *sessionTracker) CreateSessionTracker(ctx context.Context, tracker types Value: json, Expires: tracker.Expiry(), } - _, err = s.bk.Put(ctx, item) + _, err = s.bk.Create(ctx, item) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/session/session.go b/lib/session/session.go index 76c02d1110029..c145823e5f79b 100644 --- a/lib/session/session.go +++ b/lib/session/session.go @@ -54,17 +54,22 @@ func (s *ID) Check() error { // ParseID parses ID and checks if it's correct. func ParseID(id string) (*ID, error) { - _, err := uuid.Parse(id) + parsed, err := uuid.Parse(id) if err != nil { return nil, trace.BadParameter("%v is not a valid UUID", id) } - uid := ID(id) + // use the parsed UUID to build the ID instead of the string that + // was passed in. id is user controlled and uuid.Parse accepts + // several UUID formats that are not supported correctly across + // Teleport. (uuid.UUID).String always uses the same format that + // is supported by Teleport everywhere, so use that. + uid := ID(parsed.String()) return &uid, nil } // NewID returns new session ID. The session ID is based on UUIDv4. func NewID() ID { - return ID(uuid.New().String()) + return ID(uuid.NewString()) } // Session is a session of any kind (SSH, Kubernetes, Desktop, etc) diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index f836a2f3e4748..815b94df99a12 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -319,6 +319,10 @@ type ServerContext struct { // term holds PTY if it was requested by the session. term Terminal + // sessionID holds the session ID that will be used when a new + // session is created. + sessionID rsession.ID + // session holds the active session (if there's an active one). session *session @@ -634,7 +638,7 @@ func (c *ServerContext) SessionID() rsession.ID { c.mu.RLock() defer c.mu.RUnlock() if c.session == nil { - return "" + return c.sessionID } return c.session.id } @@ -653,19 +657,31 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error { // its ID will be added to the environment ssid, found := c.getEnvLocked(sshutils.SessionEnvVar) if !found { + c.sessionID = rsession.NewID() + c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr()) return nil } + // make sure whatever session is requested is a valid session id, err := rsession.ParseID(ssid) if err != nil { - return trace.BadParameter("invalid session id") + return trace.BadParameter("invalid session ID") } // update ctx with the session if it exists if sess, found := reg.findSession(*id); found { + c.sessionID = *id c.session = sess c.Logger.Debugf("Will join session %v for SSH connection %v.", c.session.id, c.ServerConn.RemoteAddr()) } else { + // TODO(capnspacehook): DELETE IN 17.0.0 - by then all supported + // clients should only set TELEPORT_SESSION when they want to + // join a session. Always return an error instead of using a + // new ID. + // + // to prevent the user from controlling the session ID, generate + // a new one + c.sessionID = rsession.NewID() c.Logger.Debugf("Will create new session for SSH connection %v.", c.ServerConn.RemoteAddr()) } @@ -737,10 +753,20 @@ func (c *ServerContext) getEnvLocked(key string) (string, bool) { } // setSession sets the context's session -func (c *ServerContext) setSession(sess *session) { +func (c *ServerContext) setSession(sess *session, ch ssh.Channel) { c.mu.Lock() defer c.mu.Unlock() c.session = sess + + // inform the client of the session ID that is being used in a new + // goroutine to reduce latency + go func() { + c.Logger.Debug("Sending current session ID.") + _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID())) + if err != nil { + c.Logger.WithError(err).Debug("Failed to send the current session ID.") + } + }() } // getSession returns the context's session diff --git a/lib/srv/ctx_test.go b/lib/srv/ctx_test.go index d85ea20f7001f..4a630e8d9741d 100644 --- a/lib/srv/ctx_test.go +++ b/lib/srv/ctx_test.go @@ -20,6 +20,7 @@ package srv import ( "bytes" + "context" "os/user" "testing" @@ -34,6 +35,8 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/sshutils" ) // TestDecodeChildError ensures that child error message marshaling @@ -302,3 +305,75 @@ type fixedRolesChecker struct { func (c *fixedRolesChecker) RoleNames() []string { return c.roleNames } + +func TestCreateOrJoinSession(t *testing.T) { + t.Parallel() + + srv := newMockServer(t) + registry, err := NewSessionRegistry(SessionRegistryConfig{ + clock: srv.clock, + Srv: srv, + SessionTrackerService: srv.auth, + }) + require.NoError(t, err) + + runningSessionID := rsession.NewID() + sess, _, err := newSession(context.Background(), runningSessionID, registry, newTestServerContext(t, srv, nil), newMockSSHChannel()) + require.NoError(t, err) + registry.sessions[runningSessionID] = sess + + tests := []struct { + name string + sessionID string + wantSameSessionID bool + }{ + { + name: "no session ID", + wantSameSessionID: false, + }, + // TODO(capnspacehook): Check that an error is returned in v17 + { + name: "new session ID", + sessionID: string(rsession.NewID()), + wantSameSessionID: false, + }, + { + name: "existing session ID", + sessionID: runningSessionID.String(), + wantSameSessionID: true, + }, + { + name: "existing session ID in Windows format", + sessionID: "{" + runningSessionID.String() + "}", + wantSameSessionID: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + parsedSessionID := new(rsession.ID) + var err error + if tt.sessionID != "" { + parsedSessionID, err = rsession.ParseID(tt.sessionID) + require.NoError(t, err) + } + + ctx := newTestServerContext(t, srv, nil) + if tt.sessionID != "" { + ctx.SetEnv(sshutils.SessionEnvVar, tt.sessionID) + } + + err = ctx.CreateOrJoinSession(registry) + require.NoError(t, err) + require.False(t, ctx.sessionID.IsZero()) + if tt.wantSameSessionID { + require.Equal(t, parsedSessionID.String(), ctx.sessionID.String()) + } else { + require.NotEqual(t, parsedSessionID.String(), ctx.sessionID.String()) + } + }) + } +} diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 202a3a875cfa6..a0e2a70adf93f 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -967,6 +967,13 @@ func (s *Server) handleGlobalRequest(ctx context.Context, req *ssh.Request) { return } // Pass request on unchanged. + case teleport.SessionIDQueryRequest: + // Reply true to session ID query requests, we will set new + // session IDs for new sessions + if err := req.Reply(true, nil); err != nil { + s.log.WithError(err).Warnf("Failed to reply to session ID query request") + } + return case teleport.KeepAliveReqType: default: s.log.Debugf("Rejecting unknown global request %q.", req.Type) diff --git a/lib/srv/mock.go b/lib/srv/mock.go index fa0859cdded69..f3302bf0dbf64 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" ) @@ -65,16 +66,16 @@ func newTestServerContext(t *testing.T, srv Server, roleSet services.RoleSet) *S recConfig := types.DefaultSessionRecordingConfig() recConfig.SetMode(types.RecordOff) clusterName := "localhost" + _, connCtx := sshutils.NewConnectionContext(ctx, nil, &ssh.ServerConn{Conn: sshConn}) scx := &ServerContext{ - Entry: logrus.NewEntry(logrus.StandardLogger()), - ConnectionContext: &sshutils.ConnectionContext{ - ServerConn: &ssh.ServerConn{Conn: sshConn}, - }, + Entry: logrus.NewEntry(logrus.StandardLogger()), + ConnectionContext: connCtx, env: make(map[string]string), SessionRecordingConfig: recConfig, IsTestStub: true, ClusterName: clusterName, srv: srv, + sessionID: rsession.NewID(), Identity: IdentityContext{ Login: usr.Username, TeleportUser: "teleportUser", diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 74e7ce3f10c99..ed8e1c347fbf6 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1381,6 +1381,13 @@ func (s *Server) HandleRequest(ctx context.Context, ccx *sshutils.ConnectionCont s.Logger.Warnf("Failed to reply to %q request: %v", r.Type, err) } } + case teleport.SessionIDQueryRequest: + // Reply true to session ID query requests, we will set new + // session IDs for new sessions + if err := r.Reply(true, nil); err != nil { + s.Logger.WithError(err).Warnf("Failed to reply to session ID query request") + } + return default: if r.WantReply { if err := r.Reply(false, nil); err != nil { diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 91f56772ccc52..0204c1529c399 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -55,7 +55,6 @@ import ( "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" rsession "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/utils" ) @@ -325,20 +324,18 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx * return trace.AccessDenied("join-only mode was used to create this connection but attempted to create a new session.") } - // session not found? need to create one. start by getting/generating an ID for it - sid, found := scx.GetEnv(sshutils.SessionEnvVar) - if !found { - sid = string(rsession.NewID()) - scx.SetEnv(sshutils.SessionEnvVar, sid) + sid := scx.SessionID() + if sid.IsZero() { + return trace.BadParameter("session ID is not set") } // This logic allows concurrent request to create a new session // to fail, what is ok because we should never have this condition - sess, p, err := newSession(ctx, rsession.ID(sid), s, scx, ch) + sess, p, err := newSession(ctx, sid, s, scx, ch) if err != nil { return trace.Wrap(err) } - scx.setSession(sess) + scx.setSession(sess, ch) s.addSession(sess) scx.Infof("Creating (interactive) session %v.", sid) @@ -353,22 +350,18 @@ func (s *SessionRegistry) OpenSession(ctx context.Context, ch ssh.Channel, scx * // OpenExecSession opens a non-interactive exec session. func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Channel, scx *ServerContext) error { - var sessionID rsession.ID + sessionID := scx.SessionID() - sid, found := scx.GetEnv(sshutils.SessionEnvVar) - if !found { - // Create a new session ID. These sessions can not be joined + if sessionID.IsZero() { sessionID = rsession.NewID() scx.Tracef("Session not found, creating a new session %s", sessionID) } else { // Use passed session ID. Assist uses this "feature" to record // the execution output. - sessionID = rsession.ID(sid) scx.Tracef("Session found, reusing it %s", sessionID) } - _, found = scx.GetEnv(teleport.EnableNonInteractiveSessionRecording) - if found { + if _, found := scx.GetEnv(teleport.EnableNonInteractiveSessionRecording); found { scx.recordNonInteractiveSession = true } @@ -398,7 +391,7 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann // Start a non-interactive session (TTY attached). Close the session if an error // occurs, otherwise it will be closed by the callee. - scx.setSession(sess) + scx.setSession(sess, channel) err = sess.startExec(ctx, channel, scx) if err != nil { diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index f55c4c517ac7c..1cd0acc2870b4 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -609,10 +609,10 @@ func TestParties(t *testing.T) { func testJoinSession(t *testing.T, reg *SessionRegistry, sess *session) { scx := newTestServerContext(t, reg.Srv, nil) - scx.setSession(sess) + sshChanOpen := newMockSSHChannel() + scx.setSession(sess, sshChanOpen) // Open a new session - sshChanOpen := newMockSSHChannel() go func() { // Consume stdout sent to the channel io.ReadAll(sshChanOpen) diff --git a/lib/web/sessions.go b/lib/web/sessions.go index 126d143a30d45..90594cc1ed60d 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -27,6 +27,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/gravitational/trace" @@ -56,6 +57,7 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" @@ -1004,7 +1006,6 @@ func (s *sessionCache) getOrCreateSession(ctx context.Context, user, sessionID s } return sctx, nil - } func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionContext) error { @@ -1326,3 +1327,84 @@ func (c *remoteClientCache) Close() error { return trace.NewAggregate(errors...) } + +// sessionIDStatus indicates whether the session ID was received from +// the server or not, and if not why +type sessionIDStatus int + +const ( + // sessionIDReceived indicates the the session ID was received + sessionIDReceived sessionIDStatus = iota + 1 + // sessionIDNotSent indicates that the server set the session ID + // but didn't send it to us + sessionIDNotSent + // sessionIDNotModified indicates that the server used the session + // ID that was set by us + sessionIDNotModified +) + +// prepareToReceiveSessionID configures the TeleportClient to listen for +// the server to send the session ID it's using. The returned function +// will return the current session ID from the server or a reason why +// one wasn't received. +func prepareToReceiveSessionID(ctx context.Context, log *logrus.Entry, nc *client.NodeClient) func() (session.ID, sessionIDStatus) { + // send the session ID received from the server + var gotSessionID atomic.Bool + sessionIDFromServer := make(chan session.ID, 1) + nc.TC.OnChannelRequest = func(req *ssh.Request) *ssh.Request { + // ignore unrelated requests and handle only the first session + // ID request + if req.Type != teleport.CurrentSessionIDRequest || gotSessionID.Load() { + return req + } + + sid, err := session.ParseID(string(req.Payload)) + if err != nil { + log.WithError(err).Warn("Unable to parse session ID.") + return nil + } + + if gotSessionID.CompareAndSwap(false, true) { + sessionIDFromServer <- *sid + } + + return nil + } + + // If the session is about to close and we haven't received a session + // ID yet, ask if the server even supports sending one. Send the + // request in a new goroutine so session establishment won't be + // blocked on making this request + serverWillSetSessionID := make(chan bool, 1) + go func() { + resp, _, err := nc.Client.SendRequest(ctx, teleport.SessionIDQueryRequest, true, nil) + if err != nil { + log.WithError(err).Warn("Failed to send session ID query request") + serverWillSetSessionID <- false + } else { + serverWillSetSessionID <- resp + } + }() + + return func() (session.ID, sessionIDStatus) { + timer := time.NewTimer(10 * time.Second) + defer timer.Stop() + + for { + select { + case sessionID := <-sessionIDFromServer: + return sessionID, sessionIDReceived + case sessionIDIsComing := <-serverWillSetSessionID: + if !sessionIDIsComing { + return session.ID(""), sessionIDNotModified + } + // the server will send the session ID, continue + // waiting for it + case <-ctx.Done(): + return session.ID(""), sessionIDNotSent + case <-timer.C: + return session.ID(""), sessionIDNotSent + } + } + } +} diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 14d73b08bbbec..9a4a3b474036f 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -201,9 +201,11 @@ type TerminalHandlerConfig struct { func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { // Make sure whatever session is requested is a valid session id. - _, err := session.ParseID(t.SessionData.ID.String()) - if err != nil { - return trace.BadParameter("sid: invalid session id") + if !t.SessionData.ID.IsZero() { + _, err := session.ParseID(t.SessionData.ID.String()) + if err != nil { + return trace.BadParameter("sid: invalid session id") + } } if t.SessionData.Login == "" { @@ -349,7 +351,7 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.handler(ws, r) } -func (t *TerminalHandler) writeSessionData(ctx context.Context) error { +func (t *TerminalHandler) writeSessionData() error { envelope := &terminal.Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketSessionMetadata, @@ -786,10 +788,6 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor return } - if err := t.writeSessionData(ctx); err != nil { - t.log.WithError(err).Warn("Unable to stream terminal - failure sending session data") - } - var beforeStart func(io.Writer) if t.participantMode == types.SessionModeratorMode { beforeStart = func(out io.Writer) { @@ -810,9 +808,47 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor } }() + sessionDataSent := make(chan struct{}) + // If we are joining a session, send the session data right away, we + // know the session ID + if t.tracker != nil { + if err := t.writeSessionData(); err != nil { + t.log.WithError(err).Warn("Failure sending session data") + } + close(sessionDataSent) + } else { + // We are creating a new session and the server will generate a + // new session ID, send the session data once the session is + // created and the server sends us the session ID it is using + writeSessionCtx, writeSessionCancel := context.WithCancel(ctx) + defer writeSessionCancel() + waitForSessionID := prepareToReceiveSessionID(writeSessionCtx, t.log, nc) + + // wait in a new goroutine because the server won't set a + // session ID until we open a shell + go func() { + defer close(sessionDataSent) + + sid, status := waitForSessionID() + switch status { + case sessionIDReceived: + t.sessionData.ID = sid + fallthrough + case sessionIDNotModified: + if err := t.writeSessionData(); err != nil { + t.log.WithError(err).Warn("Failure sending session data") + } + case sessionIDNotSent: + t.log.Warn("Failed to receive session data") + default: + t.log.Warnf("Invalid session ID status %v", status) + } + }() + } + // Establish SSH connection to the server. This function will block until // either an error occurs or it completes successfully. - if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, beforeStart); err != nil { + if err = nc.RunInteractiveShell(ctx, t.participantMode, t.tracker, nil, beforeStart); err != nil { if !t.closedByClient.Load() { t.stream.WriteError(err.Error()) } @@ -823,6 +859,9 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor return } + // Wait for the session data to be sent before closing the session + <-sessionDataSent + // Send close envelope to web terminal upon exit without an error. if err := t.stream.SendCloseMessage(t.sessionData.ServerID); err != nil { t.log.WithError(err).Error("Unable to send close event to web client.") diff --git a/lib/web/terminal/terminal.go b/lib/web/terminal/terminal.go index 2ad18331a673b..c92dae03fb2d1 100644 --- a/lib/web/terminal/terminal.go +++ b/lib/web/terminal/terminal.go @@ -87,9 +87,11 @@ type WSStream struct { // fit into the buffer provided by the callee to Read method buffer []byte - // mu protects writes to ws - mu sync.Mutex - // ws the connection to the UI + // readMu protects reads to WSConn + readMu sync.Mutex + // writeMu protects writes to WSConn + writeMu sync.Mutex + // WSConn the connection to the UI WSConn // log holds the structured logger. @@ -99,6 +101,18 @@ type WSStream struct { // Replace \n with \r\n so the message is correctly aligned. var replacer = strings.NewReplacer("\r\n", "\r\n", "\n", "\r\n") +func (t *WSStream) ReadMessage() (messageType int, p []byte, err error) { + t.readMu.Lock() + defer t.readMu.Unlock() + return t.WSConn.ReadMessage() +} + +func (t *WSStream) WriteMessage(messageType int, data []byte) error { + t.writeMu.Lock() + defer t.writeMu.Unlock() + return t.WSConn.WriteMessage(messageType, data) +} + // WriteError displays an error in the terminal window. func (t *WSStream) WriteError(msg string) { if _, writeErr := replacer.WriteString(t, msg); writeErr != nil { @@ -214,9 +228,7 @@ func (t *WSStream) WriteChallenge(challenge *client.MFAAuthenticateChallenge, co return trace.Wrap(err) } - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, msg)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, msg)) } // ReadChallengeResponse reads and decodes the challenge response from the @@ -262,9 +274,7 @@ func (t *WSStream) WriteAuditEvent(event []byte) error { } // Send bytes over the websocket to the web client. - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } // SSHSessionLatencyStats contain latency measurements for both @@ -302,9 +312,7 @@ func (t *WSStream) WriteLatency(latency SSHSessionLatencyStats) error { } // Send bytes over the websocket to the web client. - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } // Write wraps the data bytes in a raw envelope and sends. @@ -325,9 +333,7 @@ func (t *WSStream) Write(data []byte) (n int, err error) { } // Send bytes over the websocket to the web client. - t.mu.Lock() - err = t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes) - t.mu.Unlock() + err = t.WriteMessage(websocket.BinaryMessage, envelopeBytes) if err != nil { return 0, trace.Wrap(err) } @@ -391,9 +397,7 @@ func (t *WSStream) SendCloseMessage(id string) error { return trace.Wrap(err) } - t.mu.Lock() - defer t.mu.Unlock() - return trace.Wrap(t.WSConn.WriteMessage(websocket.BinaryMessage, envelopeBytes)) + return trace.Wrap(t.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } func (t *WSStream) close() {