feat(api/sessionrecording): move lib/events' ProtoReader
This moves ProtoReader from lib/events to
api/sessionrecording so that it may be imported by users.
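For downstream users, reading a recording through the new package might look like the following; a minimal sketch, assuming a recording file already on disk (the path is hypothetical):

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/gravitational/teleport/api/sessionrecording"
)

func main() {
	// Open a previously downloaded session recording (hypothetical path).
	f, err := os.Open("session.rec")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	reader := sessionrecording.NewProtoReader(f)
	defer reader.Close()

	events, err := reader.ReadAll(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Println("decoded events:", len(events))
}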
dustinspecker committed Dec 20, 2024
1 parent c61461b commit 82fea9f
Showing 17 changed files with 401 additions and 329 deletions.
5 changes: 5 additions & 0 deletions api/defaults/defaults.go
@@ -139,6 +139,11 @@ func EnhancedEvents() []string {
}
}

const (
// MaxIterationLimit caps the number of loop iterations to guard against accidental endless loops
MaxIterationLimit = 1000
)

const (
// DefaultChunkSize is the default chunk size for paginated endpoints.
DefaultChunkSize = 1000
44 changes: 44 additions & 0 deletions api/sessionrecording/sessionlog.go
@@ -0,0 +1,44 @@
package sessionrecording

import (
"compress/gzip"
"io"

"github.com/gravitational/trace"
)

// gzipReader wraps a compressed file; closing it closes both the gzip reader and the underlying file
type gzipReader struct {
io.ReadCloser
inner io.ReadCloser
}

// Close closes both the gzip reader and the underlying file
func (f *gzipReader) Close() error {
var errors []error
if f.ReadCloser != nil {
errors = append(errors, f.ReadCloser.Close())
f.ReadCloser = nil
}
if f.inner != nil {
errors = append(errors, f.inner.Close())
f.inner = nil
}
return trace.NewAggregate(errors...)
}

func newGzipReader(reader io.ReadCloser) (*gzipReader, error) {
gzReader, err := gzip.NewReader(reader)
if err != nil {
return nil, trace.Wrap(err)
}
// older bugged versions of teleport would sometimes incorrectly inject padding bytes into
// the gzip section of the archive. this causes gzip readers with multistream enabled (the
// default behavior) to fail. we disable multistream here in order to ensure that the gzip
// reader halts when it reaches the end of the current (only) valid gzip entry.
gzReader.Multistream(false)
return &gzipReader{
ReadCloser: gzReader,
inner: reader,
}, nil
}
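To illustrate why Multistream(false) matters here, the following standalone sketch (standard library only) appends stray padding bytes after a valid gzip member; a single-stream reader stops cleanly at the member boundary where a multistream reader would fail:

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
)

func main() {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	zw.Write([]byte("hello"))
	zw.Close()
	// Simulate the padding bug: stray zero bytes after the gzip member.
	buf.Write(make([]byte, 8))

	zr, err := gzip.NewReader(&buf)
	if err != nil {
		panic(err)
	}
	zr.Multistream(false) // halt at the end of the first (only) gzip member
	data, err := io.ReadAll(zr)
	fmt.Println(string(data), err) // prints "hello <nil>"
}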
287 changes: 287 additions & 0 deletions api/sessionrecording/stream.go
@@ -0,0 +1,287 @@
package sessionrecording

import (
"context"
"encoding/binary"
"errors"
"io"
"log/slog"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/defaults"
apievents "github.com/gravitational/teleport/api/types/events"
)

const (
// Int32Size is the size of a 32-bit integer in bytes
Int32Size = 4

// Int64Size is the size of a 64-bit integer in bytes
Int64Size = 8

// MaxProtoMessageSizeBytes is the maximum size of a marshaled protobuf message
MaxProtoMessageSizeBytes = 64 * 1024

// ProtoStreamV1 is version 1 of the binary protocol
ProtoStreamV1 = 1

// ProtoStreamV1PartHeaderSize is the size of the header that starts
// each part in the V1 on-disk stream format. It consists of:
// * 8 bytes for the format version
// * 8 bytes for the meaningful size of the part
// * 8 bytes for the size of the optional padding at the end of the part
ProtoStreamV1PartHeaderSize = Int64Size * 3

// ProtoStreamV1RecordHeaderSize is the size of the record header;
// it consists solely of the record length
ProtoStreamV1RecordHeaderSize = Int32Size
)
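As a worked illustration of the layout these constants describe, the sketch below writes a V1 part header by hand; the part and padding sizes are illustrative values, not taken from this change:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	const int64Size = 8
	// A V1 part header is three big-endian uint64 fields:
	// format version, meaningful part size, and trailing padding size.
	header := make([]byte, int64Size*3)
	binary.BigEndian.PutUint64(header[0:int64Size], 1)              // ProtoStreamV1
	binary.BigEndian.PutUint64(header[int64Size:2*int64Size], 1024) // gzipped part size (illustrative)
	binary.BigEndian.PutUint64(header[2*int64Size:], 0)             // padding size (often zero)
	fmt.Printf("% x\n", header)
}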

// NewProtoReader returns a new ProtoReader that reads events from r
func NewProtoReader(r io.Reader) *ProtoReader {
return &ProtoReader{
reader: r,
lastIndex: -1,
}
}

const (
// protoReaderStateInit means the reader is ready to start reading the next part
protoReaderStateInit = 0
// protoReaderStateCurrent means the reader is reading data from the current part
protoReaderStateCurrent = iota
// protoReaderStateEOF indicates that the reader has completed reading
// all parts
protoReaderStateEOF = iota
// protoReaderStateError indicates that the reader has encountered an
// internal error and should close
protoReaderStateError = iota
)

// ProtoReader reads protobuf-encoded audit events from a reader
type ProtoReader struct {
gzipReader *gzipReader
padding int64
reader io.Reader
sizeBytes [Int64Size]byte
messageBytes [MaxProtoMessageSizeBytes]byte
state int
error error
lastIndex int64
stats ProtoReaderStats
}

// ProtoReaderStats contains some reader statistics
type ProtoReaderStats struct {
// SkippedEvents counts events that were skipped because
// they were recorded several times or received out of order
SkippedEvents int64
// OutOfOrderEvents counts events received out of order
OutOfOrderEvents int64
// TotalEvents is the total number of processed events,
// including duplicates
TotalEvents int64
}

// ToFields returns a copy of the stats to be used as log fields
func (p ProtoReaderStats) ToFields() map[string]any {
return map[string]any{
"skipped-events": p.SkippedEvents,
"out-of-order-events": p.OutOfOrderEvents,
"total-events": p.TotalEvents,
}
}
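Downstream callers might surface these stats once a read pass completes; a one-line sketch, assuming reader is a *ProtoReader and ctx is a context.Context:

slog.DebugContext(ctx, "session recording stats", "stats", reader.GetStats().ToFields())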

// Close releases reader resources
func (r *ProtoReader) Close() error {
if r.gzipReader != nil {
return r.gzipReader.Close()
}
return nil
}

// Reset sets the reader to read from a new reader without
// resetting the stats or the last event index; this can be
// used to deduplicate events across multiple sources
func (r *ProtoReader) Reset(reader io.Reader) error {
if r.error != nil {
return r.error
}
if r.gzipReader != nil {
if r.error = r.gzipReader.Close(); r.error != nil {
return trace.Wrap(r.error)
}
r.gzipReader = nil
}
r.reader = reader
r.state = protoReaderStateInit
return nil
}
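Because Reset preserves the last event index and the stats, it can deduplicate a session that was uploaded in overlapping pieces; a hedged sketch, with hypothetical file names and error handling elided:

f1, _ := os.Open("upload-a") // hypothetical
f2, _ := os.Open("upload-b") // hypothetical, may overlap upload-a

r := sessionrecording.NewProtoReader(f1)
first, _ := r.ReadAll(ctx)

_ = r.Reset(f2)
rest, _ := r.ReadAll(ctx) // events already seen in upload-a are skipped

all := append(first, rest...)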

func (r *ProtoReader) setError(err error) error {
r.state = protoReaderStateError
r.error = err
return err
}

// GetStats returns stats about processed events
func (r *ProtoReader) GetStats() ProtoReaderStats {
return r.stats
}

// Read returns the next event, or io.EOF once all parts have been read
func (r *ProtoReader) Read(ctx context.Context) (apievents.AuditEvent, error) {
// periodically check the context after a fixed number of iterations
// as an extra precaution against an accidental endless loop caused
// by a logic error, and to let a ctx timeout kick in if one is set
var checkpointIteration int64
for {
checkpointIteration++
if checkpointIteration%defaults.MaxIterationLimit == 0 {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return nil, trace.Wrap(ctx.Err())
}
return nil, trace.LimitExceeded("context has been canceled")
default:
}
}
switch r.state {
case protoReaderStateEOF:
return nil, io.EOF
case protoReaderStateError:
return nil, r.error
case protoReaderStateInit:
// read the part header that consists of the protocol version
// and the part size (for the V1 version of the protocol)
_, err := io.ReadFull(r.reader, r.sizeBytes[:Int64Size])
if err != nil {
// reached the end of the stream
if errors.Is(err, io.EOF) {
r.state = protoReaderStateEOF
return nil, err
}
return nil, r.setError(trace.ConvertSystemError(err))
}
protocolVersion := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size])
if protocolVersion != ProtoStreamV1 {
return nil, trace.BadParameter("unsupported protocol version %v", protocolVersion)
}
// read size of this gzipped part as encoded by V1 protocol version
_, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
partSize := binary.BigEndian.Uint64(r.sizeBytes[:Int64Size])
// read padding size (could be 0)
_, err = io.ReadFull(r.reader, r.sizeBytes[:Int64Size])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
r.padding = int64(binary.BigEndian.Uint64(r.sizeBytes[:Int64Size]))
gzipReader, err := newGzipReader(io.NopCloser(io.LimitReader(r.reader, int64(partSize))))
if err != nil {
return nil, r.setError(trace.Wrap(err))
}
r.gzipReader = gzipReader
r.state = protoReaderStateCurrent
continue
// read the next record from the gzip reader
case protoReaderStateCurrent:
// the record consists of length of the protobuf encoded
// message and the message itself
_, err := io.ReadFull(r.gzipReader, r.sizeBytes[:Int32Size])
if err != nil {
if !errors.Is(err, io.EOF) {
return nil, r.setError(trace.ConvertSystemError(err))
}

// due to a bug in older versions of teleport it was possible that padding
// bytes would end up inside of the gzip section of the archive. we should
// skip any dangling data in the gzip section.
n, err := io.CopyBuffer(io.Discard, r.gzipReader.inner, r.messageBytes[:])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}

if n != 0 {
// log the number of bytes that were skipped
slog.DebugContext(ctx, "skipped dangling data in session recording section", "length", n)
}

// reached the end of the current part, but not necessarily
// the end of the stream
if err := r.gzipReader.Close(); err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
if r.padding != 0 {
skipped, err := io.CopyBuffer(io.Discard, io.LimitReader(r.reader, r.padding), r.messageBytes[:])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
if skipped != r.padding {
return nil, r.setError(trace.BadParameter(
"data truncated, expected to read %v bytes, but got %v", r.padding, skipped))
}
}
r.padding = 0
r.gzipReader = nil
r.state = protoReaderStateInit
continue
}
messageSize := binary.BigEndian.Uint32(r.sizeBytes[:Int32Size])
// a zero message size marks the end of the part; it can appear in
// partially submitted parts that were filled with zeroes to reach
// the minimum allowed part size
if messageSize == 0 {
return nil, r.setError(trace.BadParameter("unexpected message size 0"))
}
_, err = io.ReadFull(r.gzipReader, r.messageBytes[:messageSize])
if err != nil {
return nil, r.setError(trace.ConvertSystemError(err))
}
oneof := apievents.OneOf{}
err = oneof.Unmarshal(r.messageBytes[:messageSize])
if err != nil {
return nil, trace.Wrap(err)
}
event, err := apievents.FromOneOf(oneof)
if err != nil {
return nil, trace.Wrap(err)
}
r.stats.TotalEvents++
if event.GetIndex() <= r.lastIndex {
r.stats.SkippedEvents++
continue
}
if r.lastIndex > 0 && event.GetIndex() != r.lastIndex+1 {
r.stats.OutOfOrderEvents++
}
r.lastIndex = event.GetIndex()
return event, nil
default:
return nil, trace.BadParameter("unsupported reader size")
}
}
}

// ReadAll reads all events until EOF
func (r *ProtoReader) ReadAll(ctx context.Context) ([]apievents.AuditEvent, error) {
var events []apievents.AuditEvent
for {
event, err := r.Read(ctx)
if err != nil {
if errors.Is(err, io.EOF) {
return events, nil
}
return nil, trace.Wrap(err)
}
events = append(events, event)
}
}
31 changes: 31 additions & 0 deletions api/sessionrecording/stream_test.go
@@ -0,0 +1,31 @@
package sessionrecording_test

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/sessionrecording"
)

// TestReadCorruptedRecording tests that the streamer can successfully decode the kind of corrupted
// recordings that some older bugged versions of teleport might end up producing when under heavy load/throttling.
func TestReadCorruptedRecording(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

f, err := os.Open("testdata/corrupted-session")
require.NoError(t, err)
defer f.Close()

reader := sessionrecording.NewProtoReader(f)
defer reader.Close()

events, err := reader.ReadAll(ctx)
require.NoError(t, err)

// verify that the expected number of events are extracted
require.Len(t, events, 12)
}
File renamed without changes.
4 changes: 2 additions & 2 deletions lib/client/player.go
@@ -24,8 +24,8 @@ import (

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/sessionrecording"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"
)

@@ -51,7 +51,7 @@ func (p *playFromFileStreamer) StreamSessionEvents(
}
defer f.Close()

pr := events.NewProtoReader(f)
pr := sessionrecording.NewProtoReader(f)
for i := int64(0); ; i++ {
evt, err := pr.Read(ctx)
if err != nil {