Skip to content

Commit

Permalink
fix(lib): use MaxProtoMessageSizeBytes from api/sessionrecording
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinspecker committed Dec 20, 2024
1 parent 8ca748e commit cf51226
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 18 deletions.
3 changes: 2 additions & 1 deletion lib/events/auditwriter_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ import (

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/sessionrecording"
apievents "github.com/gravitational/teleport/api/types/events"
)

func TestBytesToSessionPrintEvents(t *testing.T) {
b := make([]byte, MaxProtoMessageSizeBytes+1)
b := make([]byte, sessionrecording.MaxProtoMessageSizeBytes+1)
_, err := rand.Read(b)
require.NoError(t, err)

Expand Down
5 changes: 3 additions & 2 deletions lib/events/filesessions/filestream.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/sessionrecording"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -69,7 +70,7 @@ func GetOpenFileFunc() utils.OpenFileWithFlagsFunc {
}

// minUploadBytes is the minimum part file size required to trigger its upload.
const minUploadBytes = events.MaxProtoMessageSizeBytes * 2
const minUploadBytes = sessionrecording.MaxProtoMessageSizeBytes * 2

// NewStreamer creates a streamer sending uploads to disk
func NewStreamer(dir string) (*events.ProtoStreamer, error) {
Expand Down Expand Up @@ -355,7 +356,7 @@ func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpl
}

// Create a buffer with the max size that a part file can have.
buf := make([]byte, minUploadBytes+events.MaxProtoMessageSizeBytes)
buf := make([]byte, minUploadBytes+sessionrecording.MaxProtoMessageSizeBytes)

_, err = file.Write(buf)
if err = trace.NewAggregate(err, file.Close()); err != nil {
Expand Down
5 changes: 3 additions & 2 deletions lib/events/session_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/sessionrecording"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/session"
Expand Down Expand Up @@ -141,8 +142,8 @@ func bytesToSessionPrintEvents(b []byte) []apievents.AuditEvent {
},
Data: b,
}
if printEvent.Size() > MaxProtoMessageSizeBytes {
extraBytes := printEvent.Size() - MaxProtoMessageSizeBytes
if printEvent.Size() > sessionrecording.MaxProtoMessageSizeBytes {
extraBytes := printEvent.Size() - sessionrecording.MaxProtoMessageSizeBytes
printEvent.Data = b[:extraBytes]
printEvent.Bytes = int64(len(printEvent.Data))
b = b[extraBytes:]
Expand Down
13 changes: 5 additions & 8 deletions lib/events/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ const (
// per stream
ConcurrentUploadsPerStream = 1

// MaxProtoMessageSizeBytes is maximum protobuf marshaled message size
MaxProtoMessageSizeBytes = 64 * 1024

// MinUploadPartSizeBytes is the minimum allowed part size when uploading a part to
// Amazon S3.
MinUploadPartSizeBytes = 1024 * 1024 * 5
Expand Down Expand Up @@ -108,7 +105,7 @@ func NewProtoStreamer(cfg ProtoStreamerConfig) (*ProtoStreamer, error) {
// Min upload bytes + some overhead to prevent buffer growth (gzip writer is not precise)
bufferPool: utils.NewBufferSyncPool(cfg.MinUploadBytes + cfg.MinUploadBytes/3),
// MaxProtoMessage size + length of the message record
slicePool: utils.NewSliceSyncPool(MaxProtoMessageSizeBytes + ProtoStreamV1RecordHeaderSize),
slicePool: utils.NewSliceSyncPool(sessionrecording.MaxProtoMessageSizeBytes + ProtoStreamV1RecordHeaderSize),
}, nil
}

Expand Down Expand Up @@ -383,10 +380,10 @@ func (s *ProtoStream) Done() <-chan struct{} {
func (s *ProtoStream) RecordEvent(ctx context.Context, pe apievents.PreparedSessionEvent) error {
event := pe.GetAuditEvent()
messageSize := event.Size()
if messageSize > MaxProtoMessageSizeBytes {
event = event.TrimToMaxSize(MaxProtoMessageSizeBytes)
if event.Size() > MaxProtoMessageSizeBytes {
return trace.BadParameter("record size %v exceeds max message size of %v bytes", messageSize, MaxProtoMessageSizeBytes)
if messageSize > sessionrecording.MaxProtoMessageSizeBytes {
event = event.TrimToMaxSize(sessionrecording.MaxProtoMessageSizeBytes)
if event.Size() > sessionrecording.MaxProtoMessageSizeBytes {
return trace.BadParameter("record size %v exceeds max message size of %v bytes", messageSize, sessionrecording.MaxProtoMessageSizeBytes)
}
}

Expand Down
5 changes: 3 additions & 2 deletions lib/events/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/sessionrecording"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
Expand Down Expand Up @@ -174,12 +175,12 @@ func TestProtoStreamLargeEvent(t *testing.T) {
}{
{
name: "large trimmable event is trimmed",
event: makeQueryEvent("1", strings.Repeat("A", events.MaxProtoMessageSizeBytes)),
event: makeQueryEvent("1", strings.Repeat("A", sessionrecording.MaxProtoMessageSizeBytes)),
errAssertion: require.NoError,
},
{
name: "large untrimmable event returns error",
event: makeAccessRequestEvent("1", strings.Repeat("A", events.MaxProtoMessageSizeBytes)),
event: makeAccessRequestEvent("1", strings.Repeat("A", sessionrecording.MaxProtoMessageSizeBytes)),
errAssertion: require.Error,
},
}
Expand Down
5 changes: 3 additions & 2 deletions lib/srv/desktop/windows_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/sessionrecording"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth"
Expand Down Expand Up @@ -1063,7 +1064,7 @@ func (s *WindowsService) makeTDPSendHandler(
Message: b,
DelayMilliseconds: delay(),
}
if e.Size() > libevents.MaxProtoMessageSizeBytes {
if e.Size() > sessionrecording.MaxProtoMessageSizeBytes {
// Technically a PNG frame is unbounded and could be too big for a single protobuf.
// In practice though, Windows limits RDP bitmaps to 64x64 pixels, and we compress
// the PNGs before they get here, so most PNG frames are under 500 bytes. The largest
Expand Down Expand Up @@ -1137,7 +1138,7 @@ func (s *WindowsService) makeTDPReceiveHandler(
Message: b,
DelayMilliseconds: delay(),
}
if e.Size() > libevents.MaxProtoMessageSizeBytes {
if e.Size() > sessionrecording.MaxProtoMessageSizeBytes {
// screen spec, mouse button, and mouse move are fixed size messages,
// so they cannot exceed the maximum size
s.cfg.Logger.WarnContext(ctx, "refusing to record message", "len", len(b), "type", logutils.TypeAttr(m))
Expand Down
3 changes: 2 additions & 1 deletion lib/srv/desktop/windows_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/sessionrecording"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth"
Expand Down Expand Up @@ -249,7 +250,7 @@ func TestSkipsExtremelyLargePNGs(t *testing.T) {
emitterPreparer := libevents.WithNoOpPreparer(emitter)

// a fake PNG Frame message, which is way too big to be legitimate
maliciousPNG := make([]byte, libevents.MaxProtoMessageSizeBytes+1)
maliciousPNG := make([]byte, sessionrecording.MaxProtoMessageSizeBytes+1)
rand.Read(maliciousPNG)
maliciousPNG[0] = byte(tdp.TypePNGFrame)

Expand Down

0 comments on commit cf51226

Please sign in to comment.