diff --git a/api/defaults/defaults.go b/api/defaults/defaults.go index 88624cd12ce1d..feb61929fbbc7 100644 --- a/api/defaults/defaults.go +++ b/api/defaults/defaults.go @@ -139,6 +139,11 @@ func EnhancedEvents() []string { } } +const ( + // MaxIterationLimit is max iteration limit + MaxIterationLimit = 1000 +) + const ( // DefaultChunkSize is the default chunk size for paginated endpoints. DefaultChunkSize = 1000 diff --git a/api/sessionrecording/sessionlog.go b/api/sessionrecording/sessionlog.go new file mode 100644 index 0000000000000..ac242535a5b19 --- /dev/null +++ b/api/sessionrecording/sessionlog.go @@ -0,0 +1,44 @@ +package sessionrecording + +import ( + "compress/gzip" + "io" + + "github.com/gravitational/trace" +) + +// gzipReader wraps file, on close close both gzip writer and file +type gzipReader struct { + io.ReadCloser + inner io.ReadCloser +} + +// Close closes file and gzip writer +func (f *gzipReader) Close() error { + var errors []error + if f.ReadCloser != nil { + errors = append(errors, f.ReadCloser.Close()) + f.ReadCloser = nil + } + if f.inner != nil { + errors = append(errors, f.inner.Close()) + f.inner = nil + } + return trace.NewAggregate(errors...) +} + +func newGzipReader(reader io.ReadCloser) (*gzipReader, error) { + gzReader, err := gzip.NewReader(reader) + if err != nil { + return nil, trace.Wrap(err) + } + // older bugged versions of teleport would sometimes incorrectly inject padding bytes into + // the gzip section of the archive. this causes gzip readers with multistream enabled (the + // default behavior) to fail. we disable multistream here in order to ensure that the gzip + // reader halts when it reaches the end of the current (only) valid gzip entry. + gzReader.Multistream(false) + return &gzipReader{ + ReadCloser: gzReader, + inner: reader, + }, nil +} diff --git a/api/sessionrecording/stream.go b/api/sessionrecording/stream.go new file mode 100644 index 0000000000000..731f2866b3a74 --- /dev/null +++ b/api/sessionrecording/stream.go @@ -0,0 +1,287 @@ +package sessionrecording + +import ( + "context" + "encoding/binary" + "errors" + "io" + "log/slog" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/defaults" + apievents "github.com/gravitational/teleport/api/types/events" +) + +const ( + // Int32Size is a constant for 32 bit integer byte size + Int32Size = 4 + + // Int64Size is a constant for 64 bit integer byte size + Int64Size = 8 + + // MaxProtoMessageSizeBytes is maximum protobuf marshaled message size + MaxProtoMessageSizeBytes = 64 * 1024 + + // ProtoStreamV1 is a version of the binary protocol + ProtoStreamV1 = 1 + + // ProtoStreamV1PartHeaderSize is the size of the part of the protocol stream + // on disk format, it consists of + // * 8 bytes for the format version + // * 8 bytes for meaningful size of the part + // * 8 bytes for optional padding size at the end of the slice + ProtoStreamV1PartHeaderSize = Int64Size * 3 + + // ProtoStreamV1RecordHeaderSize is the size of the header + // of the record header, it consists of the record length + ProtoStreamV1RecordHeaderSize = Int32Size +) + +// NewProtoReader returns a new proto reader with slice pool +func NewProtoReader(r io.Reader) *ProtoReader { + return &ProtoReader{ + reader: r, + lastIndex: -1, + } +} + +const ( + // protoReaderStateInit is ready to start reading the next part + protoReaderStateInit = 0 + // protoReaderStateCurrent will read the data from the current part + protoReaderStateCurrent = iota + // protoReaderStateEOF indicates that reader has completed reading + // all parts + protoReaderStateEOF = iota + // protoReaderStateError indicates that reader has reached internal + // error and should close + protoReaderStateError = iota +) + +// ProtoReader reads protobuf encoding from reader +type ProtoReader struct { + gzipReader *gzipReader + padding int64 + reader io.Reader + sizeBytes [Int64Size]byte + messageBytes [MaxProtoMessageSizeBytes]byte + state int + error error + lastIndex int64 + stats ProtoReaderStats +} + +// ProtoReaderStats contains some reader statistics +type ProtoReaderStats struct { + // SkippedEvents is a counter with encountered + // events recorded several times or events + // that have been out of order as skipped + SkippedEvents int64 + // OutOfOrderEvents is a counter with events + // received out of order + OutOfOrderEvents int64 + // TotalEvents contains total amount of + // processed events (including duplicates) + TotalEvents int64 +} + +// ToFields returns a copy of the stats to be used as log fields +func (p ProtoReaderStats) ToFields() map[string]any { + return map[string]any{ + "skipped-events": p.SkippedEvents, + "out-of-order-events": p.OutOfOrderEvents, + "total-events": p.TotalEvents, + } +} + +// Close releases reader resources +func (r *ProtoReader) Close() error { + if r.gzipReader != nil { + return r.gzipReader.Close() + } + return nil +} + +// Reset sets reader to read from the new reader +// without resetting the stats, could be used +// to deduplicate the events +func (r *ProtoReader) Reset(reader io.Reader) error { + if r.error != nil { + return r.error + } + if r.gzipReader != nil { + if r.error = r.gzipReader.Close(); r.error != nil { + return trace.Wrap(r.error) + } + r.gzipReader = nil + } + r.reader = reader + r.state = protoReaderStateInit + return nil +} + +func (r *ProtoReader) setError(err error) error { + r.state = protoReaderStateError + r.error = err + return err +} + +// GetStats returns stats about processed events +func (r *ProtoReader) GetStats() ProtoReaderStats { + return r.stats +} + +// Read returns next event or io.EOF in case of the end of the parts +func (r *ProtoReader) Read(ctx context.Context) (apievents.AuditEvent, error) { + // periodic checks of context after fixed amount of iterations + // is an extra precaution to avoid + // accidental endless loop due to logic error crashing the system + // and allows ctx timeout to kick in if specified + var checkpointIteration int64 + for { + checkpointIteration++ + if checkpointIteration%defaults.MaxIterationLimit == 0 { + select { + case <-ctx.Done(): + if ctx.Err() != nil { + return nil, trace.Wrap(ctx.Err()) + } + return nil, trace.LimitExceeded("context has been canceled") + default: + } + } + switch r.state { + case protoReaderStateEOF: + return nil, io.EOF + case protoReaderStateError: + return nil, r.error + case protoReaderStateInit: + // read the part header that consists of the protocol version + // and the part size (for the V1 version of the protocol) + _, err := io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) + if err != nil { + // reached the end of the stream + if errors.Is(err, io.EOF) { + r.state = protoReaderStateEOF + return nil, err + } + return nil, r.setError(trace.ConvertSystemError(err)) + } + protocolVersion := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size]) + if protocolVersion != ProtoStreamV1 { + return nil, trace.BadParameter("unsupported protocol version %v", protocolVersion) + } + // read size of this gzipped part as encoded by V1 protocol version + _, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) + if err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + partSize := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size]) + // read padding size (could be 0) + _, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) + if err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + r.padding = int64(binary.BigEndian.Uint64(r.sizeBytes[:Int64Size])) + gzipReader, err := newGzipReader(io.NopCloser(io.LimitReader(r.reader, int64(partSize)))) + if err != nil { + return nil, r.setError(trace.Wrap(err)) + } + r.gzipReader = gzipReader + r.state = protoReaderStateCurrent + continue + // read the next version from the gzip reader + case protoReaderStateCurrent: + // the record consists of length of the protobuf encoded + // message and the message itself + _, err := io.ReadFull(r.gzipReader, r.sizeBytes[:Int32Size]) + if err != nil { + if !errors.Is(err, io.EOF) { + return nil, r.setError(trace.ConvertSystemError(err)) + } + + // due to a bug in older versions of teleport it was possible that padding + // bytes would end up inside of the gzip section of the archive. we should + // skip any dangling data in the gzip secion. + n, err := io.CopyBuffer(io.Discard, r.gzipReader.inner, r.messageBytes[:]) + if err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + + if n != 0 { + // log the number of bytes that were skipped + slog.DebugContext(ctx, "skipped dangling data in session recording section", "length", n) + } + + // reached the end of the current part, but not necessarily + // the end of the stream + if err := r.gzipReader.Close(); err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + if r.padding != 0 { + skipped, err := io.CopyBuffer(io.Discard, io.LimitReader(r.reader, r.padding), r.messageBytes[:]) + if err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + if skipped != r.padding { + return nil, r.setError(trace.BadParameter( + "data truncated, expected to read %v bytes, but got %v", r.padding, skipped)) + } + } + r.padding = 0 + r.gzipReader = nil + r.state = protoReaderStateInit + continue + } + messageSize := binary.BigEndian.Uint32(r.sizeBytes[:Int32Size]) + // zero message size indicates end of the part + // that sometimes is present in partially submitted parts + // that have to be filled with zeroes for parts smaller + // than minimum allowed size + if messageSize == 0 { + return nil, r.setError(trace.BadParameter("unexpected message size 0")) + } + _, err = io.ReadFull(r.gzipReader, r.messageBytes[:messageSize]) + if err != nil { + return nil, r.setError(trace.ConvertSystemError(err)) + } + oneof := apievents.OneOf{} + err = oneof.Unmarshal(r.messageBytes[:messageSize]) + if err != nil { + return nil, trace.Wrap(err) + } + event, err := apievents.FromOneOf(oneof) + if err != nil { + return nil, trace.Wrap(err) + } + r.stats.TotalEvents++ + if event.GetIndex() <= r.lastIndex { + r.stats.SkippedEvents++ + continue + } + if r.lastIndex > 0 && event.GetIndex() != r.lastIndex+1 { + r.stats.OutOfOrderEvents++ + } + r.lastIndex = event.GetIndex() + return event, nil + default: + return nil, trace.BadParameter("unsupported reader size") + } + } +} + +// ReadAll reads all events until EOF +func (r *ProtoReader) ReadAll(ctx context.Context) ([]apievents.AuditEvent, error) { + var events []apievents.AuditEvent + for { + event, err := r.Read(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + return events, nil + } + return nil, trace.Wrap(err) + } + events = append(events, event) + } +} diff --git a/api/sessionrecording/stream_test.go b/api/sessionrecording/stream_test.go new file mode 100644 index 0000000000000..218c753bcfcd9 --- /dev/null +++ b/api/sessionrecording/stream_test.go @@ -0,0 +1,31 @@ +package sessionrecording_test + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/sessionrecording" +) + +// TestReadCorruptedRecording tests that the streamer can successfully decode the kind of corrupted +// recordings that some older bugged versions of teleport might end up producing when under heavy load/throttling. +func TestReadCorruptedRecording(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + f, err := os.Open("testdata/corrupted-session") + require.NoError(t, err) + defer f.Close() + + reader := sessionrecording.NewProtoReader(f) + defer reader.Close() + + events, err := reader.ReadAll(ctx) + require.NoError(t, err) + + // verify that the expected number of events are extracted + require.Len(t, events, 12) +} diff --git a/lib/events/testdata/corrupted-session b/api/sessionrecording/testdata/corrupted-session similarity index 100% rename from lib/events/testdata/corrupted-session rename to api/sessionrecording/testdata/corrupted-session diff --git a/lib/client/player.go b/lib/client/player.go index 4a87d352a35e6..fc35680c1ee68 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -24,8 +24,8 @@ import ( "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/session" ) @@ -51,7 +51,7 @@ func (p *playFromFileStreamer) StreamSessionEvents( } defer f.Close() - pr := events.NewProtoReader(f) + pr := sessionrecording.NewProtoReader(f) for i := int64(0); ; i++ { evt, err := pr.Read(ctx) if err != nil { diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 2bfd34c557d69..9a0af9985b372 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -210,7 +210,7 @@ const ( MaxPasswordLength = 128 // MaxIterationLimit is max iteration limit - MaxIterationLimit = 1000 + MaxIterationLimit = defaults.MaxIterationLimit // EventsIterationLimit is a default limit if it's not set for events EventsIterationLimit = 500 diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 3570171f40996..a80aa68979d24 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -37,6 +37,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" "github.com/gravitational/teleport/api/internalutils/stream" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" @@ -554,7 +555,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID return } - protoReader := NewProtoReader(rawSession) + protoReader := sessionrecording.NewProtoReader(rawSession) defer protoReader.Close() for { diff --git a/lib/events/emitter_test.go b/lib/events/emitter_test.go index c7b4cc2076d77..822ce331ad6fe 100644 --- a/lib/events/emitter_test.go +++ b/lib/events/emitter_test.go @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -108,7 +109,7 @@ func TestProtoStreamer(t *testing.T) { require.NoError(t, err) for _, part := range parts { - reader := events.NewProtoReader(bytes.NewReader(part)) + reader := sessionrecording.NewProtoReader(bytes.NewReader(part)) out, err := reader.ReadAll(ctx) require.NoError(t, err, "part crash %#v", part) outEvents = append(outEvents, out...) @@ -256,7 +257,7 @@ func TestExport(t *testing.T) { _, err := f.Write(part) require.NoError(t, err) } - reader := events.NewProtoReader(io.MultiReader(readers...)) + reader := sessionrecording.NewProtoReader(io.MultiReader(readers...)) outEvents, err := reader.ReadAll(ctx) require.NoError(t, err) diff --git a/lib/events/filesessions/fileasync.go b/lib/events/filesessions/fileasync.go index 27b14f4408fd7..fef1c99b57154 100644 --- a/lib/events/filesessions/fileasync.go +++ b/lib/events/filesessions/fileasync.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/defaults" @@ -308,7 +309,7 @@ func (u *Uploader) sessionErrorFilePath(sid session.ID) string { type upload struct { sessionID session.ID - reader *events.ProtoReader + reader *sessionrecording.ProtoReader file *os.File fileUnlockFn func() error checkpointFile *os.File @@ -441,7 +442,7 @@ func (u *Uploader) startUpload(ctx context.Context, fileName string) (err error) upload := &upload{ sessionID: sessionID, - reader: events.NewProtoReader(sessionFile), + reader: sessionrecording.NewProtoReader(sessionFile), file: sessionFile, fileUnlockFn: unlock, } diff --git a/lib/events/filesessions/fileasync_test.go b/lib/events/filesessions/fileasync_test.go index 7e34693ac0ea9..213a7633192a9 100644 --- a/lib/events/filesessions/fileasync_test.go +++ b/lib/events/filesessions/fileasync_test.go @@ -33,6 +33,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -666,10 +667,10 @@ func readStream(ctx context.Context, t *testing.T, uploadID string, uploader *ev require.NoError(t, err) var outEvents []apievents.AuditEvent - var reader *events.ProtoReader + var reader *sessionrecording.ProtoReader for i, part := range parts { if i == 0 { - reader = events.NewProtoReader(bytes.NewReader(part)) + reader = sessionrecording.NewProtoReader(bytes.NewReader(part)) } else { err := reader.Reset(bytes.NewReader(part)) require.NoError(t, err) diff --git a/lib/events/playback.go b/lib/events/playback.go index 8c8c21cfdf7d7..66ba59779069b 100644 --- a/lib/events/playback.go +++ b/lib/events/playback.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/sessionrecording" "github.com/gravitational/teleport/lib/utils" ) @@ -46,13 +47,13 @@ type Header struct { // of the header. Callers should call Seek() // to reuse reader after calling this function. func DetectFormat(r io.ReadSeeker) (*Header, error) { - version := make([]byte, Int64Size) + version := make([]byte, sessionrecording.Int64Size) _, err := io.ReadFull(r, version) if err != nil { return nil, trace.ConvertSystemError(err) } protocolVersion := binary.BigEndian.Uint64(version) - if protocolVersion == ProtoStreamV1 { + if protocolVersion == sessionrecording.ProtoStreamV1 { return &Header{ Proto: true, ProtoVersion: int64(protocolVersion), @@ -88,7 +89,7 @@ func Export(ctx context.Context, rs io.ReadSeeker, w io.Writer, exportFormat str } switch { case format.Proto: - protoReader := NewProtoReader(rs) + protoReader := sessionrecording.NewProtoReader(rs) for { event, err := protoReader.Read(ctx) if err != nil { diff --git a/lib/events/session_writer_test.go b/lib/events/session_writer_test.go index d0b9fa72189eb..a3311e0c6cba3 100644 --- a/lib/events/session_writer_test.go +++ b/lib/events/session_writer_test.go @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/require" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -74,7 +75,7 @@ func TestSessionWriter(t *testing.T) { require.NoError(t, err) for _, part := range parts { - reader := events.NewProtoReader(bytes.NewReader(part)) + reader := sessionrecording.NewProtoReader(bytes.NewReader(part)) out, err := reader.ReadAll(test.ctx) require.NoError(t, err, "part crash %#v", part) outEvents = append(outEvents, out...) @@ -420,7 +421,7 @@ func (a *sessionWriterTest) collectEvents(t *testing.T) []apievents.AuditEvent { for _, part := range parts { readers = append(readers, bytes.NewReader(part)) } - reader := events.NewProtoReader(io.MultiReader(readers...)) + reader := sessionrecording.NewProtoReader(io.MultiReader(readers...)) outEvents, err := reader.ReadAll(a.ctx) require.NoError(t, err, "failed to read") t.Logf("Reader stats :%v", reader.GetStats().ToFields()) diff --git a/lib/events/sessionlog.go b/lib/events/sessionlog.go index e0b43a50a1671..78f5e6b81c798 100644 --- a/lib/events/sessionlog.go +++ b/lib/events/sessionlog.go @@ -67,39 +67,3 @@ func newGzipWriter(writer io.WriteCloser) *gzipWriter { inner: writer, } } - -// gzipReader wraps file, on close close both gzip writer and file -type gzipReader struct { - io.ReadCloser - inner io.ReadCloser -} - -// Close closes file and gzip writer -func (f *gzipReader) Close() error { - var errors []error - if f.ReadCloser != nil { - errors = append(errors, f.ReadCloser.Close()) - f.ReadCloser = nil - } - if f.inner != nil { - errors = append(errors, f.inner.Close()) - f.inner = nil - } - return trace.NewAggregate(errors...) -} - -func newGzipReader(reader io.ReadCloser) (*gzipReader, error) { - gzReader, err := gzip.NewReader(reader) - if err != nil { - return nil, trace.Wrap(err) - } - // older bugged versions of teleport would sometimes incorrectly inject padding bytes into - // the gzip section of the archive. this causes gzip readers with multistream enabled (the - // default behavior) to fail. we disable multistream here in order to ensure that the gzip - // reader halts when it reaches the end of the current (only) valid gzip entry. - gzReader.Multistream(false) - return &gzipReader{ - ReadCloser: gzReader, - inner: reader, - }, nil -} diff --git a/lib/events/stream.go b/lib/events/stream.go index 0f00327d16116..54c4c58b00fce 100644 --- a/lib/events/stream.go +++ b/lib/events/stream.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/defaults" @@ -42,12 +43,6 @@ import ( ) const ( - // Int32Size is a constant for 32 bit integer byte size - Int32Size = 4 - - // Int64Size is a constant for 64 bit integer byte size - Int64Size = 8 - // ConcurrentUploadsPerStream limits the amount of concurrent uploads // per stream ConcurrentUploadsPerStream = 1 @@ -59,20 +54,6 @@ const ( // Amazon S3. MinUploadPartSizeBytes = 1024 * 1024 * 5 - // ProtoStreamV1 is a version of the binary protocol - ProtoStreamV1 = 1 - - // ProtoStreamV1PartHeaderSize is the size of the part of the protocol stream - // on disk format, it consists of - // * 8 bytes for the format version - // * 8 bytes for meaningful size of the part - // * 8 bytes for optional padding size at the end of the slice - ProtoStreamV1PartHeaderSize = Int64Size * 3 - - // ProtoStreamV1RecordHeaderSize is the size of the header - // of the record header, it consists of the record length - ProtoStreamV1RecordHeaderSize = Int32Size - // uploaderReservePartErrorMessage error message present when // `ReserveUploadPart` fails. uploaderReservePartErrorMessage = "uploader failed to reserve upload part" @@ -116,7 +97,7 @@ func NewProtoStreamer(cfg ProtoStreamerConfig) (*ProtoStreamer, error) { // Min upload bytes + some overhead to prevent buffer growth (gzip writer is not precise) bufferPool: utils.NewBufferSyncPool(cfg.MinUploadBytes + cfg.MinUploadBytes/3), // MaxProtoMessage size + length of the message record - slicePool: utils.NewSliceSyncPool(MaxProtoMessageSizeBytes + ProtoStreamV1RecordHeaderSize), + slicePool: utils.NewSliceSyncPool(MaxProtoMessageSizeBytes + sessionrecording.ProtoStreamV1RecordHeaderSize), }, nil } @@ -488,7 +469,7 @@ type sliceWriter struct { completedParts []StreamPart // emptyHeader is used to write empty header // to preserve some bytes - emptyHeader [ProtoStreamV1PartHeaderSize]byte + emptyHeader [sessionrecording.ProtoStreamV1PartHeaderSize]byte } func (w *sliceWriter) updateCompletedParts(part StreamPart, lastEventIndex int64) { @@ -871,9 +852,9 @@ func (s *slice) reader() (io.ReadSeeker, error) { data := s.buffer.Bytes() // when the slice was created, the first bytes were reserved // for the protocol version number and size of the slice in bytes - binary.BigEndian.PutUint64(data[0:], ProtoStreamV1) - binary.BigEndian.PutUint64(data[Int64Size:], uint64(wroteBytes-ProtoStreamV1PartHeaderSize)) - binary.BigEndian.PutUint64(data[Int64Size*2:], uint64(paddingBytes)) + binary.BigEndian.PutUint64(data[0:], sessionrecording.ProtoStreamV1) + binary.BigEndian.PutUint64(data[sessionrecording.Int64Size:], uint64(wroteBytes-sessionrecording.ProtoStreamV1PartHeaderSize)) + binary.BigEndian.PutUint64(data[sessionrecording.Int64Size*2:], uint64(paddingBytes)) return bytes.NewReader(data), nil } @@ -905,7 +886,7 @@ func (s *slice) recordEvent(event protoEvent) error { s.eventCount++ messageSize := event.oneof.Size() - recordSize := ProtoStreamV1RecordHeaderSize + messageSize + recordSize := sessionrecording.ProtoStreamV1RecordHeaderSize + messageSize if len(bytes) < recordSize { return trace.BadParameter( @@ -913,7 +894,7 @@ func (s *slice) recordEvent(event protoEvent) error { } binary.BigEndian.PutUint32(bytes, uint32(messageSize)) - _, err := event.oneof.MarshalTo(bytes[Int32Size:]) + _, err := event.oneof.MarshalTo(bytes[sessionrecording.Int32Size:]) if err != nil { return trace.Wrap(err) } @@ -930,14 +911,6 @@ func (s *slice) recordEvent(event protoEvent) error { return nil } -// NewProtoReader returns a new proto reader with slice pool -func NewProtoReader(r io.Reader) *ProtoReader { - return &ProtoReader{ - reader: r, - lastIndex: -1, - } -} - // SessionReader provides method to read // session events one by one type SessionReader interface { @@ -945,246 +918,6 @@ type SessionReader interface { Read(context.Context) (apievents.AuditEvent, error) } -const ( - // protoReaderStateInit is ready to start reading the next part - protoReaderStateInit = 0 - // protoReaderStateCurrent will read the data from the current part - protoReaderStateCurrent = iota - // protoReaderStateEOF indicates that reader has completed reading - // all parts - protoReaderStateEOF = iota - // protoReaderStateError indicates that reader has reached internal - // error and should close - protoReaderStateError = iota -) - -// ProtoReader reads protobuf encoding from reader -type ProtoReader struct { - gzipReader *gzipReader - padding int64 - reader io.Reader - sizeBytes [Int64Size]byte - messageBytes [MaxProtoMessageSizeBytes]byte - state int - error error - lastIndex int64 - stats ProtoReaderStats -} - -// ProtoReaderStats contains some reader statistics -type ProtoReaderStats struct { - // SkippedEvents is a counter with encountered - // events recorded several times or events - // that have been out of order as skipped - SkippedEvents int64 - // OutOfOrderEvents is a counter with events - // received out of order - OutOfOrderEvents int64 - // TotalEvents contains total amount of - // processed events (including duplicates) - TotalEvents int64 -} - -// ToFields returns a copy of the stats to be used as log fields -func (p ProtoReaderStats) ToFields() map[string]any { - return map[string]any{ - "skipped-events": p.SkippedEvents, - "out-of-order-events": p.OutOfOrderEvents, - "total-events": p.TotalEvents, - } -} - -// Close releases reader resources -func (r *ProtoReader) Close() error { - if r.gzipReader != nil { - return r.gzipReader.Close() - } - return nil -} - -// Reset sets reader to read from the new reader -// without resetting the stats, could be used -// to deduplicate the events -func (r *ProtoReader) Reset(reader io.Reader) error { - if r.error != nil { - return r.error - } - if r.gzipReader != nil { - if r.error = r.gzipReader.Close(); r.error != nil { - return trace.Wrap(r.error) - } - r.gzipReader = nil - } - r.reader = reader - r.state = protoReaderStateInit - return nil -} - -func (r *ProtoReader) setError(err error) error { - r.state = protoReaderStateError - r.error = err - return err -} - -// GetStats returns stats about processed events -func (r *ProtoReader) GetStats() ProtoReaderStats { - return r.stats -} - -// Read returns next event or io.EOF in case of the end of the parts -func (r *ProtoReader) Read(ctx context.Context) (apievents.AuditEvent, error) { - // periodic checks of context after fixed amount of iterations - // is an extra precaution to avoid - // accidental endless loop due to logic error crashing the system - // and allows ctx timeout to kick in if specified - var checkpointIteration int64 - for { - checkpointIteration++ - if checkpointIteration%defaults.MaxIterationLimit == 0 { - select { - case <-ctx.Done(): - if ctx.Err() != nil { - return nil, trace.Wrap(ctx.Err()) - } - return nil, trace.LimitExceeded("context has been canceled") - default: - } - } - switch r.state { - case protoReaderStateEOF: - return nil, io.EOF - case protoReaderStateError: - return nil, r.error - case protoReaderStateInit: - // read the part header that consists of the protocol version - // and the part size (for the V1 version of the protocol) - _, err := io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) - if err != nil { - // reached the end of the stream - if errors.Is(err, io.EOF) { - r.state = protoReaderStateEOF - return nil, err - } - return nil, r.setError(trace.ConvertSystemError(err)) - } - protocolVersion := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size]) - if protocolVersion != ProtoStreamV1 { - return nil, trace.BadParameter("unsupported protocol version %v", protocolVersion) - } - // read size of this gzipped part as encoded by V1 protocol version - _, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) - if err != nil { - return nil, r.setError(trace.ConvertSystemError(err)) - } - partSize := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size]) - // read padding size (could be 0) - _, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) - if err != nil { - return nil, r.setError(trace.ConvertSystemError(err)) - } - r.padding = int64(binary.BigEndian.Uint64(r.sizeBytes[:Int64Size])) - gzipReader, err := newGzipReader(io.NopCloser(io.LimitReader(r.reader, int64(partSize)))) - if err != nil { - return nil, r.setError(trace.Wrap(err)) - } - r.gzipReader = gzipReader - r.state = protoReaderStateCurrent - continue - // read the next version from the gzip reader - case protoReaderStateCurrent: - // the record consists of length of the protobuf encoded - // message and the message itself - _, err := io.ReadFull(r.gzipReader, r.sizeBytes[:Int32Size]) - if err != nil { - if !errors.Is(err, io.EOF) { - return nil, r.setError(trace.ConvertSystemError(err)) - } - - // due to a bug in older versions of teleport it was possible that padding - // bytes would end up inside of the gzip section of the archive. we should - // skip any dangling data in the gzip secion. - n, err := io.CopyBuffer(io.Discard, r.gzipReader.inner, r.messageBytes[:]) - if err != nil { - return nil, r.setError(trace.ConvertSystemError(err)) - } - - if n != 0 { - // log the number of bytes that were skipped - slog.DebugContext(ctx, "skipped dangling data in session recording section", "length", n) - } - - // reached the end of the current part, but not necessarily - // the end of the stream - if err := r.gzipReader.Close(); err != nil { - return nil, r.setError(trace.ConvertSystemError(err)) - } - if r.padding != 0 { - skipped, err := io.CopyBuffer(io.Discard, io.LimitReader(r.reader, r.padding), r.messageBytes[:]) - if err != nil { - return nil, r.setError(trace.ConvertSystemError(err)) - } - if skipped != r.padding { - return nil, r.setError(trace.BadParameter( - "data truncated, expected to read %v bytes, but got %v", r.padding, skipped)) - } - } - r.padding = 0 - r.gzipReader = nil - r.state = protoReaderStateInit - continue - } - messageSize := binary.BigEndian.Uint32(r.sizeBytes[:Int32Size]) - // zero message size indicates end of the part - // that sometimes is present in partially submitted parts - // that have to be filled with zeroes for parts smaller - // than minimum allowed size - if messageSize == 0 { - return nil, r.setError(trace.BadParameter("unexpected message size 0")) - } - _, err = io.ReadFull(r.gzipReader, r.messageBytes[:messageSize]) - if err != nil { - return nil, r.setError(trace.ConvertSystemError(err)) - } - oneof := apievents.OneOf{} - err = oneof.Unmarshal(r.messageBytes[:messageSize]) - if err != nil { - return nil, trace.Wrap(err) - } - event, err := apievents.FromOneOf(oneof) - if err != nil { - return nil, trace.Wrap(err) - } - r.stats.TotalEvents++ - if event.GetIndex() <= r.lastIndex { - r.stats.SkippedEvents++ - continue - } - if r.lastIndex > 0 && event.GetIndex() != r.lastIndex+1 { - r.stats.OutOfOrderEvents++ - } - r.lastIndex = event.GetIndex() - return event, nil - default: - return nil, trace.BadParameter("unsupported reader size") - } - } -} - -// ReadAll reads all events until EOF -func (r *ProtoReader) ReadAll(ctx context.Context) ([]apievents.AuditEvent, error) { - var events []apievents.AuditEvent - for { - event, err := r.Read(ctx) - if err != nil { - if errors.Is(err, io.EOF) { - return events, nil - } - return nil, trace.Wrap(err) - } - events = append(events, event) - } -} - // isReserveUploadPartError identifies uploader reserve part errors. func isReserveUploadPartError(err error) bool { return strings.Contains(err.Error(), uploaderReservePartErrorMessage) diff --git a/lib/events/stream_test.go b/lib/events/stream_test.go index a9d0ee1c94f60..bde6f9168c09f 100644 --- a/lib/events/stream_test.go +++ b/lib/events/stream_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -213,7 +214,7 @@ func TestReadCorruptedRecording(t *testing.T) { require.NoError(t, err) defer f.Close() - reader := events.NewProtoReader(f) + reader := sessionrecording.NewProtoReader(f) defer reader.Close() events, err := reader.ReadAll(ctx) diff --git a/lib/events/test/streamsuite.go b/lib/events/test/streamsuite.go index 24dba55ec0ac2..76ef8b851da6a 100644 --- a/lib/events/test/streamsuite.go +++ b/lib/events/test/streamsuite.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/sessionrecording" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/fixtures" @@ -251,7 +252,7 @@ func StreamWithParameters(t *testing.T, handler events.MultipartHandler, params _, err = f.Seek(0, 0) require.NoError(t, err) - reader := events.NewProtoReader(f) + reader := sessionrecording.NewProtoReader(f) out, err := reader.ReadAll(ctx) require.NoError(t, err) @@ -318,7 +319,7 @@ func StreamResumeWithParameters(t *testing.T, handler events.MultipartHandler, p _, err = f.Seek(0, 0) require.NoError(t, err) - reader := events.NewProtoReader(f) + reader := sessionrecording.NewProtoReader(f) out, err := reader.ReadAll(ctx) require.NoError(t, err)