From 16d7b1630b0a30d48b1bb64b02923078f3b27533 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 10:26:34 +0100 Subject: [PATCH 01/13] Start hacking on removing deprecated session api --- lib/auth/apiserver.go | 72 ----------------------------- lib/auth/auth_with_roles.go | 28 ----------- lib/auth/authclient/http_client.go | 45 ------------------ web/packages/teleport/src/config.ts | 7 --- 4 files changed, 152 deletions(-) diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 358d389b9ab72..92c87d029bba2 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -40,7 +40,6 @@ import ( "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" ) @@ -147,11 +146,6 @@ func NewAPIServer(config *APIConfig) (http.Handler, error) { // Tokens srv.POST("/:version/tokens/register", srv.WithAuth(srv.registerUsingToken)) - // Active sessions - // TODO(zmb3): remove these endpoints when Assist no longer needs them - srv.GET("/:version/namespaces/:namespace/sessions/:id/stream", srv.WithAuth(srv.getSessionChunk)) - srv.GET("/:version/namespaces/:namespace/sessions/:id/events", srv.WithAuth(srv.getSessionEvents)) - // Namespaces srv.POST("/:version/namespaces", srv.WithAuth(srv.upsertNamespace)) srv.GET("/:version/namespaces", srv.WithAuth(srv.getNamespaces)) @@ -697,72 +691,6 @@ func (s *APIServer) searchSessionEvents(auth *ServerWithRoles, w http.ResponseWr return eventsList, nil } -// HTTP GET /:version/sessions/:id/stream?offset=x&bytes=y -// Query parameters: -// -// "offset" : bytes from the beginning -// "bytes" : number of bytes to read (it won't return more than 512Kb) -func (s *APIServer) getSessionChunk(auth *ServerWithRoles, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - sid, err := session.ParseID(p.ByName("id")) - if err != nil { - return nil, trace.BadParameter("missing parameter id") - } - namespace := p.ByName("namespace") - if !types.IsValidNamespace(namespace) { - return nil, trace.BadParameter("invalid namespace %q", namespace) - } - - // "offset bytes" query param - offsetBytes, err := strconv.Atoi(r.URL.Query().Get("offset")) - if err != nil || offsetBytes < 0 { - offsetBytes = 0 - } - // "max bytes" query param - max, err := strconv.Atoi(r.URL.Query().Get("bytes")) - if err != nil || offsetBytes < 0 { - offsetBytes = 0 - } - s.AuthServer.logger.DebugContext( - r.Context(), "apiserver.GetSessionChunk called", - "namespace", namespace, - "sid", sid, - "offset", offsetBytes, - ) - - w.Header().Set("Content-Type", "text/plain") - - buffer, err := auth.GetSessionChunk(namespace, *sid, offsetBytes, max) - if err != nil { - return nil, trace.Wrap(err) - } - if _, err = w.Write(buffer); err != nil { - return nil, trace.Wrap(err) - } - w.Header().Set("Content-Type", "application/octet-stream") - return nil, nil -} - -// HTTP GET /:version/sessions/:id/events?maxage=n -// Query: -// -// 'after' : cursor value to return events newer than N. Defaults to 0, (return all) -func (s *APIServer) getSessionEvents(auth *ServerWithRoles, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - sid, err := session.ParseID(p.ByName("id")) - if err != nil { - return nil, trace.Wrap(err) - } - namespace := p.ByName("namespace") - if !types.IsValidNamespace(namespace) { - return nil, trace.BadParameter("invalid namespace %q", namespace) - } - afterN, err := strconv.Atoi(r.URL.Query().Get("after")) - if err != nil { - afterN = 0 - } - - return auth.GetSessionEvents(namespace, *sid, afterN) -} - type upsertNamespaceReq struct { Namespace types.Namespace `json:"namespace"` } diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 4a0e524d0df72..4f9db7c427b17 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -4235,34 +4235,6 @@ func (s *streamWithRoles) RecordEvent(ctx context.Context, pe apievents.Prepared return s.stream.RecordEvent(ctx, pe) } -func (a *ServerWithRoles) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - if err := a.actionForKindSession(namespace, sid); err != nil { - return nil, trace.Wrap(err) - } - - return a.alog.GetSessionChunk(namespace, sid, offsetBytes, maxBytes) -} - -func (a *ServerWithRoles) GetSessionEvents(namespace string, sid session.ID, afterN int) ([]events.EventFields, error) { - if err := a.actionForKindSession(namespace, sid); err != nil { - return nil, trace.Wrap(err) - } - - // emit a session recording view event for the audit log - if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ - Metadata: apievents.Metadata{ - Type: events.SessionRecordingAccessEvent, - Code: events.SessionRecordingAccessCode, - }, - SessionID: sid.String(), - UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(), - }); err != nil { - return nil, trace.Wrap(err) - } - - return a.alog.GetSessionEvents(namespace, sid, afterN) -} - func (a *ServerWithRoles) findSessionEndEvent(namespace string, sid session.ID) (apievents.AuditEvent, error) { sessionEvents, _, err := a.alog.SearchSessionEvents(context.TODO(), events.SearchSessionEventsRequest{ From: time.Time{}, diff --git a/lib/auth/authclient/http_client.go b/lib/auth/authclient/http_client.go index 8dcac07f57707..00b494af1824b 100644 --- a/lib/auth/authclient/http_client.go +++ b/lib/auth/authclient/http_client.go @@ -24,7 +24,6 @@ import ( "encoding/json" "net/http" "net/url" - "strconv" "strings" "time" @@ -41,10 +40,8 @@ import ( tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" ) @@ -864,48 +861,6 @@ func (c *HTTPClient) ValidateGithubAuthCallback(ctx context.Context, q url.Value return &response, nil } -// GetSessionChunk allows clients to receive a byte array (chunk) from a recorded -// session stream, starting from 'offset', up to 'max' in length. The upper bound -// of 'max' is set to events.MaxChunkBytes -// -// Deprecated: use StreamSessionEvents API instead -func (c *HTTPClient) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - // DELETE IN 16(zmb3): v15 web UIs stopped calling this - if namespace == "" { - return nil, trace.BadParameter(MissingNamespaceError) - } - response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "stream"), url.Values{ - "offset": []string{strconv.Itoa(offsetBytes)}, - "bytes": []string{strconv.Itoa(maxBytes)}, - }) - if err != nil { - return nil, trace.Wrap(err) - } - return response.Bytes(), nil -} - -// Deprecated: use StreamSessionEvents API instead. -// TODO(zmb3): remove from ClientI interface -func (c *HTTPClient) GetSessionEvents(namespace string, sid session.ID, afterN int) (retval []events.EventFields, err error) { - // DELETE IN 16(zmb3): v15 web UIs stopped calling this - if namespace == "" { - return nil, trace.BadParameter(MissingNamespaceError) - } - query := make(url.Values) - if afterN > 0 { - query.Set("after", strconv.Itoa(afterN)) - } - response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "events"), query) - if err != nil { - return nil, trace.Wrap(err) - } - retval = make([]events.EventFields, 0) - if err := json.Unmarshal(response.Bytes(), &retval); err != nil { - return nil, trace.Wrap(err) - } - return retval, nil -} - // GetNamespaces returns a list of namespaces func (c *HTTPClient) GetNamespaces() ([]types.Namespace, error) { out, err := c.Get(context.TODO(), c.Endpoint("namespaces"), url.Values{}) diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index 61800d5c97ac2..217b48e52b7a8 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -245,8 +245,6 @@ const cfg = { 'wss://:fqdn/v1/webapi/sites/:clusterId/ttyplayback/:sid?access_token=:token', // TODO(zmb3): get token out of URL activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', - // TODO(zmb3): remove this when Assist is no longer using it - sshPlaybackPrefix: '/v1/webapi/sites/:clusterId/sessions/:sid', // prefix because this is eventually concatenated with "/stream" or "/events" kubernetesPath: '/v1/webapi/sites/:clusterId/kubernetes?searchAsRoles=:searchAsRoles?&limit=:limit?&startKey=:startKey?&query=:query?&search=:search?&sort=:sort?', @@ -706,11 +704,6 @@ const cfg = { return generatePath(cfg.api.userWithUsernamePath, { username }); }, - getSshPlaybackPrefixUrl({ clusterId, sid }: UrlParams) { - // TODO(zmb3): remove this when Assist is no longer using it - return generatePath(cfg.api.sshPlaybackPrefix, { clusterId, sid }); - }, - getActiveAndPendingSessionsUrl({ clusterId }: UrlParams) { return generatePath(cfg.api.activeAndPendingSessionsPath, { clusterId }); }, From 80a078a94d66b5355f5d4af33aeaf0398cb7350c Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 10:32:35 +0100 Subject: [PATCH 02/13] Start clearing out unused methods --- lib/events/api.go | 13 --- lib/events/auditlog.go | 127 ------------------------------ lib/events/discard.go | 8 -- lib/events/eventstest/streamer.go | 10 --- lib/web/apiserver.go | 122 ---------------------------- lib/web/apiserver_test.go | 4 - lib/web/terminal.go | 2 - 7 files changed, 286 deletions(-) diff --git a/lib/events/api.go b/lib/events/api.go index 06a4f67c60148..6aa605d1bb3ed 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -1003,19 +1003,6 @@ type AuditLogSessionStreamer interface { // SessionStreamer supports streaming session chunks or events. type SessionStreamer interface { - // GetSessionChunk returns a reader which can be used to read a byte stream - // of a recorded session starting from 'offsetBytes' (pass 0 to start from the - // beginning) up to maxBytes bytes. - // - // If maxBytes > MaxChunkBytes, it gets rounded down to MaxChunkBytes - GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) - - // Returns all events that happen during a session sorted by time - // (oldest first). - // - // after is used to return events after a specified cursor ID - GetSessionEvents(namespace string, sid session.ID, after int) ([]EventFields, error) - // StreamSessionEvents streams all events from a given session recording. An // error is returned on the first channel if one is encountered. Otherwise // the event channel is closed when the stream ends. The event channel is diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index af952c07145b5..09d9d201c8e3a 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -20,7 +20,6 @@ package events import ( "bufio" - "bytes" "compress/gzip" "context" "encoding/json" @@ -667,31 +666,6 @@ func (l *AuditLog) downloadSession(namespace string, sid session.ID) error { return nil } -// GetSessionChunk returns a reader which console and web clients request -// to receive a live stream of a given session. The reader allows access to a -// session stream range from offsetBytes to offsetBytes+maxBytes -func (l *AuditLog) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - if err := l.downloadSession(namespace, sid); err != nil { - return nil, trace.Wrap(err) - } - var data []byte - for { - out, err := l.getSessionChunk(namespace, sid, offsetBytes, maxBytes) - if err != nil { - if errors.Is(err, io.EOF) { - return data, nil - } - return nil, trace.Wrap(err) - } - data = append(data, out...) - if len(data) == maxBytes || len(out) == 0 { - return data, nil - } - maxBytes = maxBytes - len(out) - offsetBytes = offsetBytes + len(out) - } -} - func (l *AuditLog) cleanupOldPlaybacks() error { // scan the log directory and clean files last // accessed after an hour @@ -785,107 +759,6 @@ func (l *AuditLog) unpackFile(fileName string) (readSeekCloser, error) { return dest, nil } -func (l *AuditLog) getSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - if namespace == "" { - return nil, trace.BadParameter("missing parameter namespace") - } - idx, err := l.readSessionIndex(namespace, sid) - if err != nil { - return nil, trace.Wrap(err) - } - fileName, fileOffset, err := idx.chunksFile(int64(offsetBytes)) - if err != nil { - return nil, trace.Wrap(err) - } - reader, err := l.unpackFile(fileName) - if err != nil { - return nil, trace.Wrap(err) - } - defer reader.Close() - - // seek to 'offset' from the beginning - if _, err := reader.Seek(int64(offsetBytes)-fileOffset, 0); err != nil { - return nil, trace.Wrap(err) - } - - // copy up to maxBytes from the offset position: - var buff bytes.Buffer - _, err = io.Copy(&buff, io.LimitReader(reader, int64(maxBytes))) - return buff.Bytes(), err -} - -// Returns all events that happen during a session sorted by time -// (oldest first). -// -// Can be filtered by 'after' (cursor value to return events newer than) -func (l *AuditLog) GetSessionEvents(namespace string, sid session.ID, afterN int) ([]EventFields, error) { - l.log.WithFields(log.Fields{"sid": string(sid), "afterN": afterN}).Debugf("GetSessionEvents.") - if namespace == "" { - return nil, trace.BadParameter("missing parameter namespace") - } - - // If code has to fetch print events (for playback) it has to download - // the playback from external storage first - if err := l.downloadSession(namespace, sid); err != nil { - return nil, trace.Wrap(err) - } - idx, err := l.readSessionIndex(namespace, sid) - if err != nil { - return nil, trace.Wrap(err) - } - fileIndex, err := idx.eventsFile(afterN) - if err != nil { - return nil, trace.Wrap(err) - } - events := make([]EventFields, 0, 256) - for i := fileIndex; i < len(idx.events); i++ { - skip := 0 - if i == fileIndex { - skip = afterN - } - out, err := l.fetchSessionEvents(idx.eventsFileName(i), skip) - if err != nil { - return nil, trace.Wrap(err) - } - events = append(events, out...) - } - return events, nil -} - -func (l *AuditLog) fetchSessionEvents(fileName string, afterN int) ([]EventFields, error) { - logFile, err := os.OpenFile(fileName, os.O_RDONLY, 0o640) - if err != nil { - // no file found? this means no events have been logged yet - if os.IsNotExist(err) { - return nil, nil - } - return nil, trace.Wrap(err) - } - defer logFile.Close() - reader, err := gzip.NewReader(logFile) - if err != nil { - return nil, trace.Wrap(err) - } - defer reader.Close() - - retval := make([]EventFields, 0, 256) - // read line by line: - scanner := bufio.NewScanner(reader) - for lineNo := 0; scanner.Scan(); lineNo++ { - if lineNo < afterN { - continue - } - var fields EventFields - if err = json.Unmarshal(scanner.Bytes(), &fields); err != nil { - log.Error(err) - return nil, trace.Wrap(err) - } - fields[EventCursor] = lineNo - retval = append(retval, fields) - } - return retval, nil -} - // EmitAuditEvent adds a new event to the local file log func (l *AuditLog) EmitAuditEvent(ctx context.Context, event apievents.AuditEvent) error { ctx = context.WithoutCancel(ctx) diff --git a/lib/events/discard.go b/lib/events/discard.go index 2d74b896da276..01d26fab026b9 100644 --- a/lib/events/discard.go +++ b/lib/events/discard.go @@ -44,14 +44,6 @@ func (d *DiscardAuditLog) Close() error { return nil } -func (d *DiscardAuditLog) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - return make([]byte, 0), nil -} - -func (d *DiscardAuditLog) GetSessionEvents(namespace string, sid session.ID, after int) ([]EventFields, error) { - return make([]EventFields, 0), nil -} - func (d *DiscardAuditLog) SearchEvents(ctx context.Context, req SearchEventsRequest) ([]apievents.AuditEvent, string, error) { return make([]apievents.AuditEvent, 0), "", nil } diff --git a/lib/events/eventstest/streamer.go b/lib/events/eventstest/streamer.go index 8a00bc9607c89..5827d87cc2bd3 100644 --- a/lib/events/eventstest/streamer.go +++ b/lib/events/eventstest/streamer.go @@ -22,8 +22,6 @@ import ( "context" "time" - "github.com/gravitational/trace" - apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/session" @@ -70,11 +68,3 @@ func (f fakeStreamer) StreamSessionEvents(ctx context.Context, sessionID session return events, errors } - -func (f fakeStreamer) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - return nil, trace.NotImplemented("GetSessionChunk") -} - -func (f fakeStreamer) GetSessionEvents(namespace string, sid session.ID, after int) ([]events.EventFields, error) { - return nil, trace.NotImplemented("GetSessionEvents") -} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index c53a988472851..6ee3c628f7ca4 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -21,7 +21,6 @@ package web import ( - "compress/gzip" "context" "crypto/tls" "encoding/base64" @@ -810,11 +809,6 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/events/search/sessions", h.WithClusterAuth(h.clusterSearchSessionEvents)) // search site session events h.GET("/webapi/sites/:site/ttyplayback/:sid", h.WithClusterAuth(h.ttyPlaybackHandle)) - // TODO(zmb3): remove these endpoints when Assist is no longer using them - // (assist calls the proxy's web API, and the proxy uses an HTTP client to call auth's API) - h.GET("/webapi/sites/:site/sessions/:sid/events", h.WithClusterAuth(h.siteSessionEventsGet)) // get recorded session's timing information (from events) - h.GET("/webapi/sites/:site/sessions/:sid/stream", h.siteSessionStreamGet) // get recorded session's bytes (from events) - // scp file transfer h.GET("/webapi/sites/:site/nodes/:server/:login/scp", h.WithClusterAuth(h.transferFile)) h.POST("/webapi/sites/:site/nodes/:server/:login/scp", h.WithClusterAuth(h.transferFile)) @@ -4044,83 +4038,6 @@ func queryOrder(query url.Values, name string, def types.EventOrder) (types.Even } } -// siteSessionStreamGet returns a byte array from a session's stream -// -// GET /v1/webapi/sites/:site/namespaces/:namespace/sessions/:sid/stream?query -// -// Query parameters: -// -// "offset" : bytes from the beginning -// "bytes" : number of bytes to read (it won't return more than 512Kb) -// -// Unlike other request handlers, this one does not return JSON. -// It returns the binary stream unencoded, directly in the respose body, -// with Content-Type of application/octet-stream, gzipped with up to 95% -// compression ratio. -func (h *Handler) siteSessionStreamGet(w http.ResponseWriter, r *http.Request, p httprouter.Params) { - httplib.SetNoCacheHeaders(w.Header()) - - onError := func(err error) { - h.log.WithError(err).Debug("Unable to retrieve session chunk.") - http.Error(w, err.Error(), trace.ErrorToCode(err)) - } - - // authenticate first - sctx, site, err := h.authenticateRequestWithCluster(w, r, p) - if err != nil { - onError(trace.Wrap(err)) - return - } - - // get the session - sid, err := session.ParseID(p.ByName("sid")) - if err != nil { - onError(trace.Wrap(err)) - return - } - clt, err := sctx.GetUserClient(r.Context(), site) - if err != nil { - onError(trace.Wrap(err)) - return - } - - // look at 'offset' parameter - // (skip error check and treat an invalid offset as offset 0) - query := r.URL.Query() - offset, _ := strconv.Atoi(query.Get("offset")) - - max, err := strconv.Atoi(query.Get("bytes")) - if err != nil || max <= 0 { - max = maxStreamBytes - } - if max > maxStreamBytes { - max = maxStreamBytes - } - - // call the site API to get the chunk: - bytes, err := clt.GetSessionChunk(apidefaults.Namespace, *sid, offset, max) - if err != nil { - onError(trace.Wrap(err)) - return - } - // see if we can gzip it: - var writer io.Writer = w - for _, acceptedEnc := range strings.Split(r.Header.Get("Accept-Encoding"), ",") { - if strings.TrimSpace(acceptedEnc) == "gzip" { - gzipper := gzip.NewWriter(w) - writer = gzipper - defer gzipper.Close() - w.Header().Set("Content-Encoding", "gzip") - } - } - w.Header().Set("Content-Type", "application/octet-stream") - _, err = writer.Write(bytes) - if err != nil { - onError(trace.Wrap(err)) - return - } -} - type eventsListGetResponse struct { // Events is list of events retrieved. Events []events.EventFields `json:"events"` @@ -4128,45 +4045,6 @@ type eventsListGetResponse struct { StartKey string `json:"startKey"` } -// siteSessionEventsGet gets the site session by id -// -// GET /v1/webapi/sites/:site/namespaces/:namespace/sessions/:sid/events?after=N -// -// Query: -// -// "after" : cursor value of an event to return "newer than" events -// good for repeated polling -// -// Response body (each event is an arbitrary JSON structure) -// -// {"events": [{...}, {...}, ...} -func (h *Handler) siteSessionEventsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { - sessionID, err := session.ParseID(p.ByName("sid")) - if err != nil { - return nil, trace.BadParameter("invalid session ID %q", p.ByName("sid")) - } - - clt, err := sctx.GetUserClient(r.Context(), site) - if err != nil { - return nil, trace.Wrap(err) - } - afterN, err := strconv.Atoi(r.URL.Query().Get("after")) - if err != nil { - afterN = 0 - } - - e, err := clt.GetSessionEvents(apidefaults.Namespace, *sessionID, afterN) - if err != nil { - h.log.WithError(err).Debugf("Unable to find events for session %v.", sessionID) - if trace.IsNotFound(err) { - return nil, trace.NotFound("unable to find events for session %q", sessionID) - } - - return nil, trace.Wrap(err) - } - return eventsListGetResponse{Events: e}, nil -} - // hostCredentials sends a registration token and metadata to the Auth Server // and gets back SSH and TLS certificates. func (h *Handler) hostCredentials(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index b4d67945c9f62..acafbda958202 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7734,10 +7734,6 @@ func (mock authProviderMock) GetNode(ctx context.Context, namespace, name string return &mock.server, nil } -func (mock authProviderMock) GetSessionEvents(n string, s session.ID, c int) ([]events.EventFields, error) { - return []events.EventFields{}, nil -} - func (mock authProviderMock) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) { return nil, trace.NotFound("foo") } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 02c2d1cce313f..6c99ce864b52e 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -55,7 +55,6 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/proxy" @@ -100,7 +99,6 @@ type TerminalRequest struct { // UserAuthClient is a subset of the Auth API that performs // operations on behalf of the user so that the correct RBAC is applied. type UserAuthClient interface { - GetSessionEvents(namespace string, sid session.ID, after int) ([]events.EventFields, error) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) IsMFARequired(ctx context.Context, req *authproto.IsMFARequiredRequest) (*authproto.IsMFARequiredResponse, error) CreateAuthenticateChallenge(ctx context.Context, req *authproto.CreateAuthenticateChallengeRequest) (*authproto.MFAAuthenticateChallenge, error) From 86046b8ad814cf31aeb80fa55324b586c3ccb664 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 10:37:03 +0100 Subject: [PATCH 03/13] Remove a bunch of unused helpers --- lib/events/auditlog.go | 340 ----------------------------------------- 1 file changed, 340 deletions(-) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 09d9d201c8e3a..62517fe2bed12 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -19,18 +19,12 @@ package events import ( - "bufio" - "compress/gzip" "context" - "encoding/json" "errors" - "fmt" "io" "io/fs" "os" "path/filepath" - "sort" - "strings" "sync" "time" @@ -393,279 +387,6 @@ func getAuthServers(dataDir string) ([]string, error) { return authServers, nil } -type sessionIndex struct { - dataDir string - namespace string - sid session.ID - events []indexEntry - enhancedEvents map[string][]indexEntry - chunks []indexEntry - indexFiles []string -} - -func (idx *sessionIndex) sort() { - sort.Slice(idx.events, func(i, j int) bool { - return idx.events[i].Index < idx.events[j].Index - }) - sort.Slice(idx.chunks, func(i, j int) bool { - return idx.chunks[i].Offset < idx.chunks[j].Offset - }) - - // Enhanced events. - for _, events := range idx.enhancedEvents { - sort.Slice(events, func(i, j int) bool { - return events[i].Index < events[j].Index - }) - } -} - -func (idx *sessionIndex) eventsFileName(index int) string { - entry := idx.events[index] - return filepath.Join(idx.dataDir, entry.authServer, SessionLogsDir, idx.namespace, entry.FileName) -} - -func (idx *sessionIndex) eventsFile(afterN int) (int, error) { - for i := len(idx.events) - 1; i >= 0; i-- { - entry := idx.events[i] - if int64(afterN) >= entry.Index { - return i, nil - } - } - return -1, trace.NotFound("%v not found", afterN) -} - -// chunkFileNames returns file names of all session chunk files -func (idx *sessionIndex) chunkFileNames() []string { - fileNames := make([]string, len(idx.chunks)) - for i := 0; i < len(idx.chunks); i++ { - fileNames[i] = idx.chunksFileName(i) - } - return fileNames -} - -func (idx *sessionIndex) chunksFile(offset int64) (string, int64, error) { - for i := len(idx.chunks) - 1; i >= 0; i-- { - entry := idx.chunks[i] - if offset >= entry.Offset { - return idx.chunksFileName(i), entry.Offset, nil - } - } - return "", 0, trace.NotFound("offset %v not found for session %v", offset, idx.sid) -} - -func (idx *sessionIndex) chunksFileName(index int) string { - entry := idx.chunks[index] - return filepath.Join(idx.dataDir, entry.authServer, SessionLogsDir, idx.namespace, entry.FileName) -} - -func (l *AuditLog) readSessionIndex(namespace string, sid session.ID) (*sessionIndex, error) { - index, err := readSessionIndex(l.DataDir, []string{PlaybackDir}, namespace, sid) - if err == nil { - return index, nil - } - if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - // some legacy records may be stored unpacked in the JSON format - // in the data dir, under server format - authServers, err := getAuthServers(l.DataDir) - if err != nil { - return nil, trace.Wrap(err) - } - return readSessionIndex(l.DataDir, authServers, namespace, sid) -} - -func readSessionIndex(dataDir string, authServers []string, namespace string, sid session.ID) (*sessionIndex, error) { - index := sessionIndex{ - sid: sid, - dataDir: dataDir, - namespace: namespace, - enhancedEvents: map[string][]indexEntry{ - SessionCommandEvent: {}, - SessionDiskEvent: {}, - SessionNetworkEvent: {}, - }, - } - for _, authServer := range authServers { - indexFileName := filepath.Join(dataDir, authServer, SessionLogsDir, namespace, fmt.Sprintf("%v.index", sid)) - indexFile, err := os.OpenFile(indexFileName, os.O_RDONLY, 0o640) - err = trace.ConvertSystemError(err) - if err != nil { - if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - continue - } - index.indexFiles = append(index.indexFiles, indexFileName) - - entries, err := readIndexEntries(indexFile, authServer) - if err != nil { - return nil, trace.Wrap(err) - } - for _, entry := range entries { - switch entry.Type { - case fileTypeEvents: - index.events = append(index.events, entry) - case fileTypeChunks: - index.chunks = append(index.chunks, entry) - // Enhanced events. - case SessionCommandEvent, SessionDiskEvent, SessionNetworkEvent: - index.enhancedEvents[entry.Type] = append(index.enhancedEvents[entry.Type], entry) - default: - return nil, trace.BadParameter("found unknown event type: %q", entry.Type) - } - } - - err = indexFile.Close() - if err != nil { - return nil, trace.Wrap(err) - } - } - - if len(index.indexFiles) == 0 { - return nil, trace.NotFound("session %q not found", sid) - } - - index.sort() - return &index, nil -} - -func readIndexEntries(file *os.File, authServer string) ([]indexEntry, error) { - var entries []indexEntry - - scanner := bufio.NewScanner(file) - for lineNo := 0; scanner.Scan(); lineNo++ { - var entry indexEntry - if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil { - return nil, trace.Wrap(err) - } - entry.authServer = authServer - entries = append(entries, entry) - } - - return entries, nil -} - -// createOrGetDownload creates a new download sync entry for a given session, -// if there is no active download in progress, or returns an existing one. -// if the new context has been created, cancel function is returned as a -// second argument. Caller should call this function to signal that download has been -// completed or failed. -func (l *AuditLog) createOrGetDownload(path string) (context.Context, context.CancelFunc) { - l.Lock() - defer l.Unlock() - ctx, ok := l.activeDownloads[path] - if ok { - return ctx, nil - } - ctx, cancel := context.WithCancel(context.TODO()) - l.activeDownloads[path] = ctx - return ctx, func() { - cancel() - l.Lock() - defer l.Unlock() - delete(l.activeDownloads, path) - } -} - -func (l *AuditLog) downloadSession(namespace string, sid session.ID) error { - tarballPath := filepath.Join(l.playbackDir, string(sid)+".tar") - - ctx, cancel := l.createOrGetDownload(tarballPath) - // means that another download is in progress, so simply wait until - // it finishes - if cancel == nil { - l.log.Debugf("Another download is in progress for %v, waiting until it gets completed.", sid) - select { - case <-ctx.Done(): - return nil - case <-l.ctx.Done(): - return trace.BadParameter("audit log is closing, aborting the download") - } - } - defer cancel() - _, err := os.Stat(tarballPath) - err = trace.ConvertSystemError(err) - if err == nil { - l.log.Debugf("Recording %v is already downloaded and unpacked to %v.", sid, tarballPath) - return nil - } - if !trace.IsNotFound(err) { - return trace.Wrap(err) - } - start := time.Now() - l.log.Debugf("Starting download of %v.", sid) - tarball, err := os.OpenFile(tarballPath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o640) - if err != nil { - return trace.ConvertSystemError(err) - } - defer func() { - if err := tarball.Close(); err != nil { - l.log.WithError(err).Errorf("Failed to close file %q.", tarballPath) - } - }() - if err := l.UploadHandler.Download(l.ctx, sid, tarball); err != nil { - // remove partially downloaded tarball - if rmErr := os.Remove(tarballPath); rmErr != nil { - l.log.WithError(rmErr).Warningf("Failed to remove file %v.", tarballPath) - } - return trace.Wrap(err) - } - l.log.WithField("duration", time.Since(start)).Debugf("Downloaded %v to %v.", sid, tarballPath) - - _, err = tarball.Seek(0, 0) - if err != nil { - return trace.ConvertSystemError(err) - } - format, err := DetectFormat(tarball) - if err != nil { - l.log.WithError(err).Debugf("Failed to detect playback %v format.", tarballPath) - return trace.Wrap(err) - } - _, err = tarball.Seek(0, 0) - if err != nil { - return trace.ConvertSystemError(err) - } - switch { - case format.Proto: - start = time.Now() - l.log.Debugf("Converting %v to playback format.", tarballPath) - protoReader := NewProtoReader(tarball) - _, err = WriteForSSHPlayback(l.Context, sid, protoReader, l.playbackDir) - if err != nil { - l.log.WithError(err).Error("Failed to convert.") - return trace.Wrap(err) - } - stats := protoReader.GetStats().ToFields() - stats["duration"] = time.Since(start) - l.log.WithFields(stats).Debugf("Converted %v to %v.", tarballPath, l.playbackDir) - case format.Tar: - if err := utils.Extract(tarball, l.playbackDir); err != nil { - return trace.Wrap(err) - } - default: - return trace.BadParameter("Unexpected format %v.", format) - } - - // Extract every chunks file on disk while holding the context, - // otherwise parallel downloads will try to unpack the file at the same time. - idx, err := l.readSessionIndex(namespace, sid) - if err != nil { - return trace.Wrap(err) - } - for _, fileName := range idx.chunkFileNames() { - reader, err := l.unpackFile(fileName) - if err != nil { - return trace.Wrap(err) - } - if err := reader.Close(); err != nil { - l.log.Warningf("Failed to close file: %v.", err) - } - } - l.log.WithField("duration", time.Since(start)).Debugf("Unpacked %v to %v.", tarballPath, l.playbackDir) - return nil -} - func (l *AuditLog) cleanupOldPlaybacks() error { // scan the log directory and clean files last // accessed after an hour @@ -698,67 +419,6 @@ func (l *AuditLog) cleanupOldPlaybacks() error { return nil } -type readSeekCloser interface { - io.Reader - io.Seeker - io.Closer -} - -func (l *AuditLog) unpackFile(fileName string) (readSeekCloser, error) { - basename := filepath.Base(fileName) - unpackedFile := filepath.Join(l.playbackDir, strings.TrimSuffix(basename, filepath.Ext(basename))) - - // If client has called GetSessionChunk before session is over - // this could lead to cases when not all data will be returned, - // because unpackFile will be called concurrently with the unfinished write - unpackedInfo, err := os.Stat(unpackedFile) - err = trace.ConvertSystemError(err) - switch { - case err != nil && !trace.IsNotFound(err): - return nil, trace.Wrap(err) - case err == nil: - packedInfo, err := os.Stat(fileName) - if err != nil { - return nil, trace.ConvertSystemError(err) - } - // no new data has been added - if unpackedInfo.ModTime().Unix() >= packedInfo.ModTime().Unix() { - return os.OpenFile(unpackedFile, os.O_RDONLY, 0o640) - } - } - - start := l.Clock.Now() - dest, err := os.OpenFile(unpackedFile, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o640) - if err != nil { - return nil, trace.ConvertSystemError(err) - } - source, err := os.OpenFile(fileName, os.O_RDONLY, 0o640) - if err != nil { - return nil, trace.ConvertSystemError(err) - } - defer source.Close() - reader, err := gzip.NewReader(source) - if err != nil { - return nil, trace.Wrap(err) - } - defer reader.Close() - if _, err := io.Copy(dest, reader); err != nil { - // Unexpected EOF is returned by gzip reader - // when the file has not been closed yet, - // ignore this error - if !errors.Is(err, io.ErrUnexpectedEOF) { - dest.Close() - return nil, trace.Wrap(err) - } - } - if _, err := dest.Seek(0, 0); err != nil { - dest.Close() - return nil, trace.Wrap(err) - } - l.log.Debugf("Uncompressed %v into %v in %v", fileName, unpackedFile, l.Clock.Now().Sub(start)) - return dest, nil -} - // EmitAuditEvent adds a new event to the local file log func (l *AuditLog) EmitAuditEvent(ctx context.Context, event apievents.AuditEvent) error { ctx = context.WithoutCancel(ctx) From ec5d0ed73900d2d25406849ccade8e4bedddaa60 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 11:41:58 +0100 Subject: [PATCH 04/13] Switch integration tests to use new StreamSessionEvents API --- integration/integration_test.go | 217 ++++++++++++--------------- integration/kube_integration_test.go | 62 +++++++- lib/events/api.go | 6 - 3 files changed, 152 insertions(+), 133 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index 17a73d1e4f260..8855fac1d5611 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -544,108 +544,60 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } } - // read back the entire session (we have to try several times until we get back - // everything because the session is closing) - var sessionStream []byte - for i := 0; i < 6; i++ { - sessionStream, err = site.GetSessionChunk(defaults.Namespace, session.ID(tracker.GetSessionID()), 0, events.MaxChunkBytes) - require.NoError(t, err) - if strings.Contains(string(sessionStream), "exit") { - break - } - time.Sleep(time.Millisecond * 250) - if i >= 5 { - // session stream keeps coming back short - t.Fatalf("%s: Stream is not getting data: %q.", tt.comment, string(sessionStream)) - } - } - - // see what we got. It looks different based on bash settings, but here it is - // on Ev's machine (hostname is 'edsger'): - // - // edsger ~: echo hi - // hi - // edsger ~: exit - // logout - // - text := string(sessionStream) - require.Contains(t, text, "echo hi") - require.Contains(t, text, "exit") - - // Wait until session.start, session.leave, and session.end events have arrived. - getSessions := func(site authclient.ClientI) ([]events.EventFields, error) { - tickCh := time.Tick(500 * time.Millisecond) - stopCh := time.After(10 * time.Second) - for { - select { - case <-tickCh: - // Get all session events from the backend. - sessionEvents, err := site.GetSessionEvents(defaults.Namespace, session.ID(tracker.GetSessionID()), 0) - if err != nil { - return nil, trace.Wrap(err) - } - - // Look through all session events for the three wanted. - var hasStart bool - var hasEnd bool - var hasLeave bool - for _, se := range sessionEvents { - var isAuditEvent bool - if se.GetType() == events.SessionStartEvent { - isAuditEvent = true - hasStart = true - } - if se.GetType() == events.SessionEndEvent { - isAuditEvent = true - hasEnd = true - } - if se.GetType() == events.SessionLeaveEvent { - isAuditEvent = true - hasLeave = true - } - - // ensure session events are also in audit log - if !isAuditEvent { - continue - } - auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ - To: time.Now(), - EventTypes: []string{se.GetType()}, - }) - require.NoError(t, err) - - found := slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { - return ae.GetID() == se.GetID() - }) - require.True(t, found) - } - - // Make sure all three events were found. - if hasStart && hasEnd && hasLeave { - return sessionEvents, nil - } - case <-stopCh: - return nil, trace.BadParameter("unable to find all session events after 10s (mode=%v)", tt.inRecordLocation) + // Stream all the session events into a slice to make them easier + // to work with. + evtCh, errCh := site.StreamSessionEvents(ctx, session.ID(tracker.GetSessionID()), 0) + sessionEvents := make([]apievents.AuditEvent, 0) + readLoop: + for { + select { + case evt := <-evtCh: + if evt == nil { + break readLoop } + sessionEvents = append(sessionEvents, evt) + case err := <-errCh: + require.NoError(t, err) } } - history, err := getSessions(site) - require.NoError(t, err) - getChunk := func(e events.EventFields, maxlen int) string { - offset := e.GetInt("offset") - length := e.GetInt("bytes") - if length == 0 { - return "" + var hasStart bool + var hasEnd bool + var hasLeave bool + for _, se := range sessionEvents { + var isAuditEvent bool + if se.GetType() == events.SessionStartEvent { + isAuditEvent = true + hasStart = true + } + if se.GetType() == events.SessionEndEvent { + isAuditEvent = true + hasEnd = true } - if length > maxlen { - length = maxlen + if se.GetType() == events.SessionLeaveEvent { + isAuditEvent = true + hasLeave = true } - return string(sessionStream[offset : offset+length]) + + // ensure session events are also in audit log + if !isAuditEvent { + continue + } + auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ + To: time.Now(), + EventTypes: []string{se.GetType()}, + }) + require.NoError(t, err) + + found := slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == se.GetID() + }) + require.True(t, found) } + require.True(t, hasStart && hasEnd && hasLeave) - findByType := func(et string) events.EventFields { - for _, e := range history { + findByType := func(et string) apievents.AuditEvent { + for _, e := range sessionEvents { if e.GetType() == et { return e } @@ -654,38 +606,38 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } // there should always be 'session.start' event (and it must be first) - first := history[0] - start := findByType(events.SessionStartEvent) + first := sessionEvents[0].(*apievents.SessionStart) + start := findByType(events.SessionStartEvent).(*apievents.SessionStart) require.Equal(t, first, start) - require.Equal(t, 0, start.GetInt("bytes")) - require.Equal(t, sessionID, start.GetString(events.SessionEventID)) - require.NotEmpty(t, start.GetString(events.TerminalSize)) - - // make sure data is recorded properly - out := &bytes.Buffer{} - for _, e := range history { - out.WriteString(getChunk(e, 1000)) - } - recorded := replaceNewlines(out.String()) - require.Regexp(t, ".*exit.*", recorded) - require.Regexp(t, ".*echo hi.*", recorded) + require.Equal(t, sessionID, start.SessionID) + require.NotEmpty(t, start.TerminalSize) // there should always be 'session.end' event - end := findByType(events.SessionEndEvent) + end := findByType(events.SessionEndEvent).(*apievents.SessionEnd) require.NotNil(t, end) - require.Equal(t, 0, end.GetInt("bytes")) - require.Equal(t, sessionID, end.GetString(events.SessionEventID)) + require.Equal(t, sessionID, end.SessionID) // there should always be 'session.leave' event - leave := findByType(events.SessionLeaveEvent) + leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave) require.NotNil(t, leave) - require.Equal(t, 0, leave.GetInt("bytes")) - require.Equal(t, sessionID, leave.GetString(events.SessionEventID)) + require.Equal(t, sessionID, leave.SessionID) // all of them should have a proper time - for _, e := range history { - require.False(t, e.GetTime("time").IsZero()) + for _, e := range sessionEvents { + require.False(t, e.GetTime().IsZero()) } + + // Check data was recorded properly + out := &bytes.Buffer{} + for _, e := range sessionEvents { + if e.GetType() != events.SessionPrintEvent { + continue + } + out.Write(e.(*apievents.SessionPrint).Data) + } + recorded := replaceNewlines(out.String()) + require.Regexp(t, ".*exit.*", recorded) + require.Regexp(t, ".*echo hi.*", recorded) }) } } @@ -1279,9 +1231,22 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { } require.EventuallyWithT(t, func(t *assert.CollectT) { - events, err := authSrv.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0) - assert.NoError(t, err) - assert.NotEmpty(t, events) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + eventsCh, errCh := authSrv.StreamSessionEvents(ctx, session.ID(sessionID), 0) + for { + select { + case err := <-errCh: + assert.NoError(t, err) + return + case evt := <-eventsCh: + if evt != nil { + return + } + assert.Fail(t, "expected event, got nil") + return + } + } }, 15*time.Second, 200*time.Millisecond) }) } @@ -5018,7 +4983,19 @@ func testAuditOff(t *testing.T, suite *integrationTestSuite) { // however, attempts to read the actual sessions should fail because it was // not actually recorded - _, err = site.GetSessionChunk(defaults.Namespace, session.ID(tracker.GetSessionID()), 0, events.MaxChunkBytes) + eventsCh, errCh := site.StreamSessionEvents(ctx, session.ID(tracker.GetSessionID()), 0) + err = nil +readLoop: + for { + select { + case evt := <-eventsCh: + if evt != nil { + t.Fatalf("Unexpected event: %v", evt) + } + case err = <-errCh: + break readLoop + } + } require.Error(t, err) // ensure that session related events were emitted to audit log diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 5f4761a96a870..314749579d415 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -69,9 +69,9 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/constants" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/entitlements" "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/integration/kube" @@ -343,10 +343,26 @@ loop: } // read back the entire session and verify that it matches the stated output - capturedStream, err := teleport.Process.GetAuthServer().GetSessionChunk(apidefaults.Namespace, session.ID(sessionID), 0, events.MaxChunkBytes) + evtCh, errCh := teleport.Process.GetAuthServer().StreamSessionEvents(ctx, session.ID(sessionID), 0) require.NoError(t, err) + capturedStream := &bytes.Buffer{} +readLoop: + for { + select { + case evt := <-evtCh: + if evt == nil { + break readLoop + } + if evt.GetType() != events.SessionPrintEvent { + continue + } + capturedStream.Write(evt.(*apievents.SessionPrint).Data) + case err := <-errCh: + require.NoError(t, err) + } + } - require.Equal(t, sessionStream, string(capturedStream)) + require.Equal(t, sessionStream, capturedStream.String()) // impersonating kube exec should be denied // interactive command, allocate pty @@ -778,10 +794,26 @@ loop: } // read back the entire session and verify that it matches the stated output - capturedStream, err := main.Process.GetAuthServer().GetSessionChunk(apidefaults.Namespace, session.ID(sessionID), 0, events.MaxChunkBytes) + evtCh, errCh := main.Process.GetAuthServer().StreamSessionEvents(ctx, session.ID(sessionID), 0) require.NoError(t, err) + capturedStream := &bytes.Buffer{} +readLoop: + for { + select { + case evt := <-evtCh: + if evt == nil { + break readLoop + } + if evt.GetType() != events.SessionPrintEvent { + continue + } + capturedStream.Write(evt.(*apievents.SessionPrint).Data) + case err := <-errCh: + require.NoError(t, err) + } + } - require.Equal(t, sessionStream, string(capturedStream)) + require.Equal(t, sessionStream, capturedStream.String()) // impersonating kube exec should be denied // interactive command, allocate pty @@ -1052,10 +1084,26 @@ loop: } // read back the entire session and verify that it matches the stated output - capturedStream, err := main.Process.GetAuthServer().GetSessionChunk(apidefaults.Namespace, session.ID(sessionID), 0, events.MaxChunkBytes) + evtCh, errCh := main.Process.GetAuthServer().StreamSessionEvents(ctx, session.ID(sessionID), 0) require.NoError(t, err) + capturedStream := &bytes.Buffer{} +readLoop: + for { + select { + case evt := <-evtCh: + if evt == nil { + break readLoop + } + if evt.GetType() != events.SessionPrintEvent { + continue + } + capturedStream.Write(evt.(*apievents.SessionPrint).Data) + case err := <-errCh: + require.NoError(t, err) + } + } - require.Equal(t, sessionStream, string(capturedStream)) + require.Equal(t, sessionStream, capturedStream.String()) // impersonating kube exec should be denied // interactive command, allocate pty diff --git a/lib/events/api.go b/lib/events/api.go index 6aa605d1bb3ed..d6b9de473360d 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -813,12 +813,6 @@ const ( // Add an entry to eventsMap in lib/events/events_test.go when you add // a new event name here. -const ( - // MaxChunkBytes defines the maximum size of a session stream chunk that - // can be requested via AuditLog.GetSessionChunk(). Set to 5MB - MaxChunkBytes = 1024 * 1024 * 5 -) - const ( // V1 is the V1 version of slice chunks API, // it is 0 because it was not defined before From 13e49973b8a5315c53a5989ab07835c8038e3afb Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 12:14:48 +0100 Subject: [PATCH 05/13] Add helper for draining events/stream --- integration/integration_test.go | 51 ++++++++++++--------- integration/kube_integration_test.go | 68 +++------------------------- 2 files changed, 35 insertions(+), 84 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index 8855fac1d5611..e75276d3082de 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -546,20 +546,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { // Stream all the session events into a slice to make them easier // to work with. - evtCh, errCh := site.StreamSessionEvents(ctx, session.ID(tracker.GetSessionID()), 0) - sessionEvents := make([]apievents.AuditEvent, 0) - readLoop: - for { - select { - case evt := <-evtCh: - if evt == nil { - break readLoop - } - sessionEvents = append(sessionEvents, evt) - case err := <-errCh: - require.NoError(t, err) - } - } + capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID) var hasStart bool var hasEnd bool @@ -628,20 +615,40 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } // Check data was recorded properly - out := &bytes.Buffer{} - for _, e := range sessionEvents { - if e.GetType() != events.SessionPrintEvent { - continue - } - out.Write(e.(*apievents.SessionPrint).Data) - } - recorded := replaceNewlines(out.String()) + recorded := replaceNewlines(capturedStream) require.Regexp(t, ".*exit.*", recorded) require.Regexp(t, ".*echo hi.*", recorded) }) } } +func streamSession( + ctx context.Context, + t *testing.T, + streamer events.SessionStreamer, + sessionID string, +) (string, []apievents.AuditEvent) { + evtCh, errCh := streamer.StreamSessionEvents(ctx, session.ID(sessionID), 0) + capturedStream := &bytes.Buffer{} + evts := make([]apievents.AuditEvent, 0) +readLoop: + for { + select { + case evt := <-evtCh: + if evt == nil { + break readLoop + } + if evt.GetType() != events.SessionPrintEvent { + capturedStream.Write(evt.(*apievents.SessionPrint).Data) + } + evts = append(evts, evt) + case err := <-errCh: + require.NoError(t, err) + } + } + return capturedStream.String(), evts +} + // testInteroperability checks if Teleport and OpenSSH behave in the same way // when executing commands. func testInteroperability(t *testing.T, suite *integrationTestSuite) { diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 314749579d415..9c25eccdffad2 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -71,7 +71,6 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" - apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/entitlements" "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/integration/kube" @@ -79,7 +78,6 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/events" kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/service" @@ -343,26 +341,8 @@ loop: } // read back the entire session and verify that it matches the stated output - evtCh, errCh := teleport.Process.GetAuthServer().StreamSessionEvents(ctx, session.ID(sessionID), 0) - require.NoError(t, err) - capturedStream := &bytes.Buffer{} -readLoop: - for { - select { - case evt := <-evtCh: - if evt == nil { - break readLoop - } - if evt.GetType() != events.SessionPrintEvent { - continue - } - capturedStream.Write(evt.(*apievents.SessionPrint).Data) - case err := <-errCh: - require.NoError(t, err) - } - } - - require.Equal(t, sessionStream, capturedStream.String()) + capturedStream, _ := streamSession(ctx, t, teleport.Process.GetAuthServer(), sessionID) + require.Equal(t, sessionStream, capturedStream) // impersonating kube exec should be denied // interactive command, allocate pty @@ -794,26 +774,8 @@ loop: } // read back the entire session and verify that it matches the stated output - evtCh, errCh := main.Process.GetAuthServer().StreamSessionEvents(ctx, session.ID(sessionID), 0) - require.NoError(t, err) - capturedStream := &bytes.Buffer{} -readLoop: - for { - select { - case evt := <-evtCh: - if evt == nil { - break readLoop - } - if evt.GetType() != events.SessionPrintEvent { - continue - } - capturedStream.Write(evt.(*apievents.SessionPrint).Data) - case err := <-errCh: - require.NoError(t, err) - } - } - - require.Equal(t, sessionStream, capturedStream.String()) + capturedStream, _ := streamSession(ctx, t, main.Process.GetAuthServer(), sessionID) + require.Equal(t, sessionStream, capturedStream) // impersonating kube exec should be denied // interactive command, allocate pty @@ -1084,26 +1046,8 @@ loop: } // read back the entire session and verify that it matches the stated output - evtCh, errCh := main.Process.GetAuthServer().StreamSessionEvents(ctx, session.ID(sessionID), 0) - require.NoError(t, err) - capturedStream := &bytes.Buffer{} -readLoop: - for { - select { - case evt := <-evtCh: - if evt == nil { - break readLoop - } - if evt.GetType() != events.SessionPrintEvent { - continue - } - capturedStream.Write(evt.(*apievents.SessionPrint).Data) - case err := <-errCh: - require.NoError(t, err) - } - } - - require.Equal(t, sessionStream, capturedStream.String()) + capturedStream, _ := streamSession(ctx, t, main.Process.GetAuthServer(), sessionID) + require.Equal(t, sessionStream, capturedStream) // impersonating kube exec should be denied // interactive command, allocate pty From 24efd2a3ad222840e388033a971cc7b6383095d1 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 12:17:55 +0100 Subject: [PATCH 06/13] Fix dodgy comparison in helper --- integration/integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index e75276d3082de..c17c07e688f54 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -638,7 +638,7 @@ readLoop: if evt == nil { break readLoop } - if evt.GetType() != events.SessionPrintEvent { + if evt.GetType() == events.SessionPrintEvent { capturedStream.Write(evt.(*apievents.SessionPrint).Data) } evts = append(evts, evt) From 7a92ec91b84ecbc2a50a465477552eb843308d49 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 15:12:18 +0100 Subject: [PATCH 07/13] Remove more unused fields/types --- lib/events/sessionlog.go | 9 ++++----- lib/web/apiserver.go | 2 -- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/lib/events/sessionlog.go b/lib/events/sessionlog.go index 76f820d0d4195..7511ada732026 100644 --- a/lib/events/sessionlog.go +++ b/lib/events/sessionlog.go @@ -56,11 +56,10 @@ func chunksFileName(dataDir string, sessionID session.ID, offset int64) string { } type indexEntry struct { - FileName string `json:"file_name"` - Type string `json:"type"` - Index int64 `json:"index"` - Offset int64 `json:"offset,"` - authServer string + FileName string `json:"file_name"` + Type string `json:"type"` + Index int64 `json:"index"` + Offset int64 `json:"offset,"` } // gzipWriter wraps file, on close close both gzip writer and file diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 6ee3c628f7ca4..0e05fb23c9c2d 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3852,8 +3852,6 @@ func (h *Handler) clusterActiveAndPendingSessionsGet(w http.ResponseWriter, r *h return siteSessionsGetResponse{Sessions: sessions}, nil } -const maxStreamBytes = 5 * 1024 * 1024 - func toFieldsSlice(rawEvents []apievents.AuditEvent) ([]events.EventFields, error) { el := make([]events.EventFields, 0, len(rawEvents)) for _, event := range rawEvents { From cd72217a0cb3aff7eb3ba3849d96be898c0afbc1 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 15:35:12 +0100 Subject: [PATCH 08/13] Update integration/integration_test.go Co-authored-by: Zac Bergquist --- integration/integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/integration/integration_test.go b/integration/integration_test.go index c17c07e688f54..50294b50b665a 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -628,6 +628,7 @@ func streamSession( streamer events.SessionStreamer, sessionID string, ) (string, []apievents.AuditEvent) { + t.Helper() evtCh, errCh := streamer.StreamSessionEvents(ctx, session.ID(sessionID), 0) capturedStream := &bytes.Buffer{} evts := make([]apievents.AuditEvent, 0) From 94b892512e48f51f81d48cda43d24de9b1d002d4 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 15:42:28 +0100 Subject: [PATCH 09/13] Remove SSHPlaybackWriter --- lib/events/playback.go | 292 ----------------------------------------- 1 file changed, 292 deletions(-) diff --git a/lib/events/playback.go b/lib/events/playback.go index b6390b01c417b..8c8c21cfdf7d7 100644 --- a/lib/events/playback.go +++ b/lib/events/playback.go @@ -20,22 +20,15 @@ package events import ( "archive/tar" - "bufio" - "compress/gzip" "context" "encoding/binary" "errors" "fmt" "io" - "os" - "path/filepath" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport" - apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" ) @@ -125,288 +118,3 @@ func Export(ctx context.Context, rs io.ReadSeeker, w io.Writer, exportFormat str return trace.BadParameter("unsupported format %v", format) } } - -// WriteForSSHPlayback reads events from an SessionReader and writes them to disk in a format optimized for playback. -func WriteForSSHPlayback(ctx context.Context, sid session.ID, reader SessionReader, dir string) (*SSHPlaybackWriter, error) { - w := &SSHPlaybackWriter{ - sid: sid, - reader: reader, - dir: dir, - eventIndex: -1, - } - defer func() { - if err := w.Close(); err != nil { - log.WithError(err).Warningf("Failed to close writer.") - } - }() - return w, w.Write(ctx) -} - -// SessionEvents returns slice of event fields from gzipped events file. -func (w *SSHPlaybackWriter) SessionEvents() ([]EventFields, error) { - var sessionEvents []EventFields - // events - eventFile, err := os.Open(w.EventsPath) - if err != nil { - return nil, trace.Wrap(err) - } - defer eventFile.Close() - - grEvents, err := gzip.NewReader(eventFile) - if err != nil { - return nil, trace.Wrap(err) - } - defer grEvents.Close() - scanner := bufio.NewScanner(grEvents) - for scanner.Scan() { - var f EventFields - err := utils.FastUnmarshal(scanner.Bytes(), &f) - if err != nil { - if errors.Is(err, io.EOF) { - return sessionEvents, nil - } - return nil, trace.Wrap(err) - } - sessionEvents = append(sessionEvents, f) - } - - if err := scanner.Err(); err != nil { - return nil, trace.Wrap(err) - } - - return sessionEvents, nil -} - -// SessionChunks interprets the file at the given path as gzip-compressed list of session events and returns -// the uncompressed contents as a result. -func (w *SSHPlaybackWriter) SessionChunks() ([]byte, error) { - var stream []byte - chunkFile, err := os.Open(w.ChunksPath) - if err != nil { - return nil, trace.Wrap(err) - } - defer chunkFile.Close() - grChunk, err := gzip.NewReader(chunkFile) - if err != nil { - return nil, trace.Wrap(err) - } - defer grChunk.Close() - stream, err = io.ReadAll(grChunk) - if err != nil { - return nil, trace.Wrap(err) - } - return stream, nil -} - -// SSHPlaybackWriter reads messages from an SessionReader and writes them -// to disk in a format suitable for SSH session playback. -type SSHPlaybackWriter struct { - sid session.ID - dir string - reader SessionReader - indexFile *os.File - eventsFile *gzipWriter - chunksFile *gzipWriter - eventIndex int64 - EventsPath string - ChunksPath string -} - -// Close closes all files -func (w *SSHPlaybackWriter) Close() error { - if w.indexFile != nil { - if err := w.indexFile.Close(); err != nil { - log.Warningf("Failed to close index file: %v.", err) - } - w.indexFile = nil - } - - if w.chunksFile != nil { - if err := w.chunksFile.Flush(); err != nil { - log.Warningf("Failed to flush chunks file: %v.", err) - } - - if err := w.chunksFile.Close(); err != nil { - log.Warningf("Failed closing chunks file: %v.", err) - } - } - - if w.eventsFile != nil { - if err := w.eventsFile.Flush(); err != nil { - log.Warningf("Failed to flush events file: %v.", err) - } - - if err := w.eventsFile.Close(); err != nil { - log.Warningf("Failed closing events file: %v.", err) - } - } - - return nil -} - -// Write writes all events from the SessionReader and writes -// files to disk in the format optimized for playback. -func (w *SSHPlaybackWriter) Write(ctx context.Context) error { - if err := w.openIndexFile(); err != nil { - return trace.Wrap(err) - } - for { - event, err := w.reader.Read(ctx) - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } - return trace.Wrap(err) - } - if err := w.writeEvent(event); err != nil { - return trace.Wrap(err) - } - } -} - -func (w *SSHPlaybackWriter) writeEvent(event apievents.AuditEvent) error { - switch event.GetType() { - // Timing events for TTY playback go to both a chunks file (the raw bytes) as - // well as well as the events file (structured events). - case SessionPrintEvent: - return trace.Wrap(w.writeSessionPrintEvent(event)) - - // Playback does not use enhanced events at the moment, - // so they are skipped - case SessionCommandEvent, SessionDiskEvent, SessionNetworkEvent: - return nil - - // PlaybackWriter is not used for desktop playback, so we should never see - // these events, but skip them if a user or developer somehow tries to playback - // a desktop session using this TTY PlaybackWriter - case DesktopRecordingEvent: - return nil - - // All other events get put into the general events file. These are events like - // session.join, session.end, etc. - default: - return trace.Wrap(w.writeRegularEvent(event)) - } -} - -func (w *SSHPlaybackWriter) writeSessionPrintEvent(event apievents.AuditEvent) error { - print, ok := event.(*apievents.SessionPrint) - if !ok { - return trace.BadParameter("expected session print event, got %T", event) - } - w.eventIndex++ - event.SetIndex(w.eventIndex) - if err := w.openEventsFile(0); err != nil { - return trace.Wrap(err) - } - if err := w.openChunksFile(0); err != nil { - return trace.Wrap(err) - } - data := print.Data - print.Data = nil - bytes, err := utils.FastMarshal(event) - if err != nil { - return trace.Wrap(err) - } - _, err = w.eventsFile.Write(append(bytes, '\n')) - if err != nil { - return trace.Wrap(err) - } - _, err = w.chunksFile.Write(data) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -func (w *SSHPlaybackWriter) writeRegularEvent(event apievents.AuditEvent) error { - w.eventIndex++ - event.SetIndex(w.eventIndex) - if err := w.openEventsFile(0); err != nil { - return trace.Wrap(err) - } - bytes, err := utils.FastMarshal(event) - if err != nil { - return trace.Wrap(err) - } - _, err = w.eventsFile.Write(append(bytes, '\n')) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -func (w *SSHPlaybackWriter) openIndexFile() error { - if w.indexFile != nil { - return nil - } - var err error - w.indexFile, err = os.OpenFile( - filepath.Join(w.dir, fmt.Sprintf("%v.index", w.sid.String())), os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o640) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -func (w *SSHPlaybackWriter) openEventsFile(eventIndex int64) error { - if w.eventsFile != nil { - return nil - } - w.EventsPath = eventsFileName(w.dir, w.sid, "", eventIndex) - - // update the index file to write down that new events file has been created - data, err := utils.FastMarshal(indexEntry{ - FileName: filepath.Base(w.EventsPath), - Type: fileTypeEvents, - Index: eventIndex, - }) - if err != nil { - return trace.Wrap(err) - } - - _, err = fmt.Fprintf(w.indexFile, "%v\n", string(data)) - if err != nil { - return trace.Wrap(err) - } - - // open new events file for writing - file, err := os.OpenFile(w.EventsPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o640) - if err != nil { - return trace.Wrap(err) - } - w.eventsFile = newGzipWriter(file) - return nil -} - -func (w *SSHPlaybackWriter) openChunksFile(offset int64) error { - if w.chunksFile != nil { - return nil - } - w.ChunksPath = chunksFileName(w.dir, w.sid, offset) - - // Update the index file to write down that new chunks file has been created. - data, err := utils.FastMarshal(indexEntry{ - FileName: filepath.Base(w.ChunksPath), - Type: fileTypeChunks, - Offset: offset, - }) - if err != nil { - return trace.Wrap(err) - } - - // index file will contain file name with extension .gz (assuming it was gzipped) - _, err = fmt.Fprintf(w.indexFile, "%v\n", string(data)) - if err != nil { - return trace.Wrap(err) - } - - // open the chunks file for writing, but because the file is written without - // compression, remove the .gz - file, err := os.OpenFile(w.ChunksPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o640) - if err != nil { - return trace.Wrap(err) - } - w.chunksFile = newGzipWriter(file) - return nil -} From f20c33f822e5e88e6a3784aeefcc424dc3237db6 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 15:46:00 +0100 Subject: [PATCH 10/13] Remove more unused helpers --- lib/events/sessionlog.go | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/lib/events/sessionlog.go b/lib/events/sessionlog.go index 7511ada732026..e0b43a50a1671 100644 --- a/lib/events/sessionlog.go +++ b/lib/events/sessionlog.go @@ -20,48 +20,12 @@ package events import ( "compress/gzip" - "fmt" "io" - "path/filepath" "sync" "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/session" -) - -const ( - fileTypeChunks = "chunks" - fileTypeEvents = "events" - - // eventsSuffix is the suffix of the archive that contains session events. - eventsSuffix = "events.gz" - - // chunksSuffix is the suffix of the archive that contains session chunks. - chunksSuffix = "chunks.gz" ) -// eventsFileName consists of session id and the first global event index -// recorded. Optionally for enhanced session recording events, the event type. -func eventsFileName(dataDir string, sessionID session.ID, eventType string, eventIndex int64) string { - if eventType != "" { - return filepath.Join(dataDir, fmt.Sprintf("%v-%v.%v-%v", sessionID.String(), eventIndex, eventType, eventsSuffix)) - } - return filepath.Join(dataDir, fmt.Sprintf("%v-%v.%v", sessionID.String(), eventIndex, eventsSuffix)) -} - -// chunksFileName consists of session id and the first global offset recorded -func chunksFileName(dataDir string, sessionID session.ID, offset int64) string { - return filepath.Join(dataDir, fmt.Sprintf("%v-%v.%v", sessionID.String(), offset, chunksSuffix)) -} - -type indexEntry struct { - FileName string `json:"file_name"` - Type string `json:"type"` - Index int64 `json:"index"` - Offset int64 `json:"offset,"` -} - // gzipWriter wraps file, on close close both gzip writer and file type gzipWriter struct { *gzip.Writer From e4c007cd91d505c1479c8bd91573876748fa70ef Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 15:46:50 +0100 Subject: [PATCH 11/13] Remove unused constants --- lib/events/stream.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/lib/events/stream.go b/lib/events/stream.go index 7f44db4b8f418..36bda7c70b600 100644 --- a/lib/events/stream.go +++ b/lib/events/stream.go @@ -55,17 +55,10 @@ const ( // MaxProtoMessageSizeBytes is maximum protobuf marshaled message size MaxProtoMessageSizeBytes = 64 * 1024 - // MaxUploadParts is the maximum allowed number of parts in a multi-part upload - // on Amazon S3. - MaxUploadParts = 10000 - // MinUploadPartSizeBytes is the minimum allowed part size when uploading a part to // Amazon S3. MinUploadPartSizeBytes = 1024 * 1024 * 5 - // ReservedParts is the amount of parts reserved by default - ReservedParts = 100 - // ProtoStreamV1 is a version of the binary protocol ProtoStreamV1 = 1 From fa63cdfbc20dd73d72c9cc82ee46ce3d5a4c1a10 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 20 Sep 2024 17:03:10 +0100 Subject: [PATCH 12/13] Simplify how integration test checks for audit log presence --- integration/integration_test.go | 53 ++++++++++----------------------- 1 file changed, 16 insertions(+), 37 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index 50294b50b665a..6f9f4e5c42614 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -544,45 +544,8 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } } - // Stream all the session events into a slice to make them easier - // to work with. capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID) - var hasStart bool - var hasEnd bool - var hasLeave bool - for _, se := range sessionEvents { - var isAuditEvent bool - if se.GetType() == events.SessionStartEvent { - isAuditEvent = true - hasStart = true - } - if se.GetType() == events.SessionEndEvent { - isAuditEvent = true - hasEnd = true - } - if se.GetType() == events.SessionLeaveEvent { - isAuditEvent = true - hasLeave = true - } - - // ensure session events are also in audit log - if !isAuditEvent { - continue - } - auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ - To: time.Now(), - EventTypes: []string{se.GetType()}, - }) - require.NoError(t, err) - - found := slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { - return ae.GetID() == se.GetID() - }) - require.True(t, found) - } - require.True(t, hasStart && hasEnd && hasLeave) - findByType := func(et string) apievents.AuditEvent { for _, e := range sessionEvents { if e.GetType() == et { @@ -591,6 +554,19 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } return nil } + // helper that asserts that a session event is also included in the + // general audit log. + requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) { + t.Helper() + auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ + To: time.Now(), + EventTypes: []string{sessionEvent.GetType()}, + }) + require.NoError(t, err) + require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == sessionEvent.GetID() + })) + } // there should always be 'session.start' event (and it must be first) first := sessionEvents[0].(*apievents.SessionStart) @@ -598,16 +574,19 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { require.Equal(t, first, start) require.Equal(t, sessionID, start.SessionID) require.NotEmpty(t, start.TerminalSize) + requireInAuditLog(t, start) // there should always be 'session.end' event end := findByType(events.SessionEndEvent).(*apievents.SessionEnd) require.NotNil(t, end) require.Equal(t, sessionID, end.SessionID) + requireInAuditLog(t, end) // there should always be 'session.leave' event leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave) require.NotNil(t, leave) require.Equal(t, sessionID, leave.SessionID) + requireInAuditLog(t, leave) // all of them should have a proper time for _, e := range sessionEvents { From 81ec91d955589bd4df963ed6cf2d5ee85f7ebe30 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Mon, 23 Sep 2024 09:27:47 +0100 Subject: [PATCH 13/13] Remove outdated comment --- lib/events/complete.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/events/complete.go b/lib/events/complete.go index 62e610df1d7ce..20bf4fc4ca997 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -349,8 +349,7 @@ func (u *UploadCompleter) ensureSessionEndEvent(ctx context.Context, uploadData var desktopSessionEnd events.WindowsDesktopSessionEnd // We use the streaming events API to search through the session events, because it works - // for both Desktop and SSH sessions, where as the GetSessionEvents API relies on downloading - // a copy of the session and using the SSH-specific index to iterate through events. + // for both Desktop and SSH sessions var lastEvent events.AuditEvent ctx, cancel := context.WithCancel(ctx) defer cancel()