diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 94abc49b08c53..26211ee32dca7 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -413,7 +413,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann return trace.Wrap(err) } - canStart, _, err := sess.checkIfStart() + sess.mu.Lock() + canStart, _, err := sess.checkIfStartUnderLock() + sess.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -500,7 +502,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro sess.fileTransferReq = nil sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser) - _ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx) + _ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return false, trace.AccessDenied("Teleport user does not match original requester") } @@ -533,9 +535,9 @@ const ( FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny" ) -// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied. +// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied. // The notification is a global ssh request and requires the client to update its UI state accordingly. -func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { +func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { session := scx.getSession() if session == nil { s.log.Debugf("Unable to notify %s, no session found in context.", res) @@ -1074,7 +1076,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { // Notify all members of the party that a new member has joined over the // "x-teleport-event" channel. - for _, p := range s.parties { + for _, p := range s.getParties() { if len(notifyPartyPayload) == 0 { s.log.Warnf("No join event to send to %v", p.sconn.RemoteAddr()) continue @@ -1092,10 +1094,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { } } -// emitSessionLeaveEvent emits a session leave event to both the Audit Log as +// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as // well as sending a "x-teleport-event" global request on the SSH connection. // Must be called under session Lock. -func (s *session) emitSessionLeaveEvent(ctx *ServerContext) { +func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) { sessionLeaveEvent := &apievents.SessionLeave{ Metadata: apievents.Metadata{ Type: events.SessionLeaveEvent, @@ -1289,7 +1291,9 @@ func (s *session) launch() { // startInteractive starts a new interactive process (or a shell) in the // current session. func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error { - canStart, _, err := s.checkIfStart() + s.mu.Lock() + canStart, _, err := s.checkIfStartUnderLock() + s.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -1554,11 +1558,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve } func (s *session) broadcastResult(r ExecResult) { - s.mu.Lock() - defer s.mu.Unlock() - payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)}) - for _, p := range s.parties { + for _, p := range s.getParties() { if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil { s.log.Infof("Failed to send exit status for %v: %v", r.Command, err) } @@ -1566,7 +1567,7 @@ func (s *session) broadcastResult(r ExecResult) { } func (s *session) String() string { - return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties)) + return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties())) } // removePartyUnderLock removes the party from the in-memory map that holds all party members @@ -1592,9 +1593,9 @@ func (s *session) removePartyUnderLock(p *party) error { // Emit session leave event to both the Audit Log and over the // "x-teleport-event" channel in the SSH connection. - s.emitSessionLeaveEvent(p.ctx) + s.emitSessionLeaveEventUnderLock(p.ctx) - canRun, policyOptions, err := s.checkIfStart() + canRun, policyOptions, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1819,7 +1820,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar } else { s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location) } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx) return trace.Wrap(err) } @@ -1862,7 +1863,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi } else { eventType = FileTransferUpdate } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx) return trace.Wrap(err) } @@ -1895,12 +1896,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP s.fileTransferReq = nil s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID) - err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx) + err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return trace.Wrap(err) } -func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { +// checkIfStartUnderLock determines if any moderation policies associated with +// the session are satisfied. +// Must be called under session Lock. +func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) { var participants []auth.SessionAccessContext for _, party := range s.parties { @@ -1939,7 +1943,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if len(s.parties) == 0 { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1992,7 +1996,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if s.tracker.GetState() == types.SessionState_SessionStatePending { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) }