diff --git a/lib/events/auditwriter_internal_test.go b/lib/events/auditwriter_internal_test.go index ed282f9a810e7..5fc7010034a59 100644 --- a/lib/events/auditwriter_internal_test.go +++ b/lib/events/auditwriter_internal_test.go @@ -24,11 +24,12 @@ import ( "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" ) func TestBytesToSessionPrintEvents(t *testing.T) { - b := make([]byte, MaxProtoMessageSizeBytes+1) + b := make([]byte, sessionrecording.MaxProtoMessageSizeBytes+1) _, err := rand.Read(b) require.NoError(t, err) diff --git a/lib/events/filesessions/filestream.go b/lib/events/filesessions/filestream.go index 36362ab7aa1c8..f5fdbcbd54de7 100644 --- a/lib/events/filesessions/filestream.go +++ b/lib/events/filesessions/filestream.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/sessionrecording" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" @@ -69,7 +70,7 @@ func GetOpenFileFunc() utils.OpenFileWithFlagsFunc { } // minUploadBytes is the minimum part file size required to trigger its upload. -const minUploadBytes = events.MaxProtoMessageSizeBytes * 2 +const minUploadBytes = sessionrecording.MaxProtoMessageSizeBytes * 2 // NewStreamer creates a streamer sending uploads to disk func NewStreamer(dir string) (*events.ProtoStreamer, error) { @@ -355,7 +356,7 @@ func (h *Handler) ReserveUploadPart(ctx context.Context, upload events.StreamUpl } // Create a buffer with the max size that a part file can have. - buf := make([]byte, minUploadBytes+events.MaxProtoMessageSizeBytes) + buf := make([]byte, minUploadBytes+sessionrecording.MaxProtoMessageSizeBytes) _, err = file.Write(buf) if err = trace.NewAggregate(err, file.Close()); err != nil { diff --git a/lib/events/session_writer.go b/lib/events/session_writer.go index 0f52296b2332a..fd5c8d09388e3 100644 --- a/lib/events/session_writer.go +++ b/lib/events/session_writer.go @@ -31,6 +31,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/session" @@ -141,8 +142,8 @@ func bytesToSessionPrintEvents(b []byte) []apievents.AuditEvent { }, Data: b, } - if printEvent.Size() > MaxProtoMessageSizeBytes { - extraBytes := printEvent.Size() - MaxProtoMessageSizeBytes + if printEvent.Size() > sessionrecording.MaxProtoMessageSizeBytes { + extraBytes := printEvent.Size() - sessionrecording.MaxProtoMessageSizeBytes printEvent.Data = b[:extraBytes] printEvent.Bytes = int64(len(printEvent.Data)) b = b[extraBytes:] diff --git a/lib/events/stream.go b/lib/events/stream.go index 7f23b65d560eb..d02e70ef41752 100644 --- a/lib/events/stream.go +++ b/lib/events/stream.go @@ -47,9 +47,6 @@ const ( // per stream ConcurrentUploadsPerStream = 1 - // MaxProtoMessageSizeBytes is maximum protobuf marshaled message size - MaxProtoMessageSizeBytes = 64 * 1024 - // MinUploadPartSizeBytes is the minimum allowed part size when uploading a part to // Amazon S3. MinUploadPartSizeBytes = 1024 * 1024 * 5 @@ -108,7 +105,7 @@ func NewProtoStreamer(cfg ProtoStreamerConfig) (*ProtoStreamer, error) { // Min upload bytes + some overhead to prevent buffer growth (gzip writer is not precise) bufferPool: utils.NewBufferSyncPool(cfg.MinUploadBytes + cfg.MinUploadBytes/3), // MaxProtoMessage size + length of the message record - slicePool: utils.NewSliceSyncPool(MaxProtoMessageSizeBytes + ProtoStreamV1RecordHeaderSize), + slicePool: utils.NewSliceSyncPool(sessionrecording.MaxProtoMessageSizeBytes + ProtoStreamV1RecordHeaderSize), }, nil } @@ -383,10 +380,10 @@ func (s *ProtoStream) Done() <-chan struct{} { func (s *ProtoStream) RecordEvent(ctx context.Context, pe apievents.PreparedSessionEvent) error { event := pe.GetAuditEvent() messageSize := event.Size() - if messageSize > MaxProtoMessageSizeBytes { - event = event.TrimToMaxSize(MaxProtoMessageSizeBytes) - if event.Size() > MaxProtoMessageSizeBytes { - return trace.BadParameter("record size %v exceeds max message size of %v bytes", messageSize, MaxProtoMessageSizeBytes) + if messageSize > sessionrecording.MaxProtoMessageSizeBytes { + event = event.TrimToMaxSize(sessionrecording.MaxProtoMessageSizeBytes) + if event.Size() > sessionrecording.MaxProtoMessageSizeBytes { + return trace.BadParameter("record size %v exceeds max message size of %v bytes", messageSize, sessionrecording.MaxProtoMessageSizeBytes) } } diff --git a/lib/events/stream_test.go b/lib/events/stream_test.go index 6b1dab52e6575..a9d3fa659eb7e 100644 --- a/lib/events/stream_test.go +++ b/lib/events/stream_test.go @@ -28,6 +28,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/sessionrecording" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -174,12 +175,12 @@ func TestProtoStreamLargeEvent(t *testing.T) { }{ { name: "large trimmable event is trimmed", - event: makeQueryEvent("1", strings.Repeat("A", events.MaxProtoMessageSizeBytes)), + event: makeQueryEvent("1", strings.Repeat("A", sessionrecording.MaxProtoMessageSizeBytes)), errAssertion: require.NoError, }, { name: "large untrimmable event returns error", - event: makeAccessRequestEvent("1", strings.Repeat("A", events.MaxProtoMessageSizeBytes)), + event: makeAccessRequestEvent("1", strings.Repeat("A", sessionrecording.MaxProtoMessageSizeBytes)), errAssertion: require.Error, }, } diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index 8dbbad96b3fb6..373a0e4d0dfc4 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/sessionrecording" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth" @@ -1063,7 +1064,7 @@ func (s *WindowsService) makeTDPSendHandler( Message: b, DelayMilliseconds: delay(), } - if e.Size() > libevents.MaxProtoMessageSizeBytes { + if e.Size() > sessionrecording.MaxProtoMessageSizeBytes { // Technically a PNG frame is unbounded and could be too big for a single protobuf. // In practice though, Windows limits RDP bitmaps to 64x64 pixels, and we compress // the PNGs before they get here, so most PNG frames are under 500 bytes. The largest @@ -1137,7 +1138,7 @@ func (s *WindowsService) makeTDPReceiveHandler( Message: b, DelayMilliseconds: delay(), } - if e.Size() > libevents.MaxProtoMessageSizeBytes { + if e.Size() > sessionrecording.MaxProtoMessageSizeBytes { // screen spec, mouse button, and mouse move are fixed size messages, // so they cannot exceed the maximum size s.cfg.Logger.WarnContext(ctx, "refusing to record message", "len", len(b), "type", logutils.TypeAttr(m)) diff --git a/lib/srv/desktop/windows_server_test.go b/lib/srv/desktop/windows_server_test.go index 6ce3dd70b9160..ae3d64065f7c3 100644 --- a/lib/srv/desktop/windows_server_test.go +++ b/lib/srv/desktop/windows_server_test.go @@ -31,6 +31,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/sessionrecording" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth" @@ -249,7 +250,7 @@ func TestSkipsExtremelyLargePNGs(t *testing.T) { emitterPreparer := libevents.WithNoOpPreparer(emitter) // a fake PNG Frame message, which is way too big to be legitimate - maliciousPNG := make([]byte, libevents.MaxProtoMessageSizeBytes+1) + maliciousPNG := make([]byte, sessionrecording.MaxProtoMessageSizeBytes+1) rand.Read(maliciousPNG) maliciousPNG[0] = byte(tdp.TypePNGFrame)