From 102658f5168bff2985838a8660a5439762990d4d Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 25 Oct 2024 17:21:07 +0000 Subject: [PATCH] Prevent racy access to session parties (#47936) Prefer using session.getParties instead of using session.parties directly to prevent races when new parties are added. Any functions that are using session.parties AND are called from another function that already obtains the lock have been renamed to reflect that they must only be called if the session lock is held. --- lib/srv/sess.go | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 94abc49b08c53..26211ee32dca7 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -413,7 +413,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann return trace.Wrap(err) } - canStart, _, err := sess.checkIfStart() + sess.mu.Lock() + canStart, _, err := sess.checkIfStartUnderLock() + sess.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -500,7 +502,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro sess.fileTransferReq = nil sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser) - _ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx) + _ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return false, trace.AccessDenied("Teleport user does not match original requester") } @@ -533,9 +535,9 @@ const ( FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny" ) -// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied. +// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied. // The notification is a global ssh request and requires the client to update its UI state accordingly. -func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { +func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { session := scx.getSession() if session == nil { s.log.Debugf("Unable to notify %s, no session found in context.", res) @@ -1074,7 +1076,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { // Notify all members of the party that a new member has joined over the // "x-teleport-event" channel. - for _, p := range s.parties { + for _, p := range s.getParties() { if len(notifyPartyPayload) == 0 { s.log.Warnf("No join event to send to %v", p.sconn.RemoteAddr()) continue @@ -1092,10 +1094,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { } } -// emitSessionLeaveEvent emits a session leave event to both the Audit Log as +// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as // well as sending a "x-teleport-event" global request on the SSH connection. // Must be called under session Lock. -func (s *session) emitSessionLeaveEvent(ctx *ServerContext) { +func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) { sessionLeaveEvent := &apievents.SessionLeave{ Metadata: apievents.Metadata{ Type: events.SessionLeaveEvent, @@ -1289,7 +1291,9 @@ func (s *session) launch() { // startInteractive starts a new interactive process (or a shell) in the // current session. func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error { - canStart, _, err := s.checkIfStart() + s.mu.Lock() + canStart, _, err := s.checkIfStartUnderLock() + s.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -1554,11 +1558,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve } func (s *session) broadcastResult(r ExecResult) { - s.mu.Lock() - defer s.mu.Unlock() - payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)}) - for _, p := range s.parties { + for _, p := range s.getParties() { if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil { s.log.Infof("Failed to send exit status for %v: %v", r.Command, err) } @@ -1566,7 +1567,7 @@ func (s *session) broadcastResult(r ExecResult) { } func (s *session) String() string { - return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties)) + return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties())) } // removePartyUnderLock removes the party from the in-memory map that holds all party members @@ -1592,9 +1593,9 @@ func (s *session) removePartyUnderLock(p *party) error { // Emit session leave event to both the Audit Log and over the // "x-teleport-event" channel in the SSH connection. - s.emitSessionLeaveEvent(p.ctx) + s.emitSessionLeaveEventUnderLock(p.ctx) - canRun, policyOptions, err := s.checkIfStart() + canRun, policyOptions, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1819,7 +1820,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar } else { s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location) } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx) return trace.Wrap(err) } @@ -1862,7 +1863,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi } else { eventType = FileTransferUpdate } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx) return trace.Wrap(err) } @@ -1895,12 +1896,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP s.fileTransferReq = nil s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID) - err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx) + err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return trace.Wrap(err) } -func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { +// checkIfStartUnderLock determines if any moderation policies associated with +// the session are satisfied. +// Must be called under session Lock. +func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) { var participants []auth.SessionAccessContext for _, party := range s.parties { @@ -1939,7 +1943,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if len(s.parties) == 0 { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1992,7 +1996,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if s.tracker.GetState() == types.SessionState_SessionStatePending { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) }