Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecated GetSessionChunk and GetSessionEvents APIs #46797

Merged
merged 14 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 91 additions & 127 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,152 +544,91 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
}
}

// read back the entire session (we have to try several times until we get back
// everything because the session is closing)
var sessionStream []byte
for i := 0; i < 6; i++ {
sessionStream, err = site.GetSessionChunk(defaults.Namespace, session.ID(tracker.GetSessionID()), 0, events.MaxChunkBytes)
require.NoError(t, err)
if strings.Contains(string(sessionStream), "exit") {
break
}
time.Sleep(time.Millisecond * 250)
if i >= 5 {
// session stream keeps coming back short
t.Fatalf("%s: Stream is not getting data: %q.", tt.comment, string(sessionStream))
}
}

// see what we got. It looks different based on bash settings, but here it is
// on Ev's machine (hostname is 'edsger'):
//
// edsger ~: echo hi
// hi
// edsger ~: exit
// logout
//
text := string(sessionStream)
require.Contains(t, text, "echo hi")
require.Contains(t, text, "exit")

// Wait until session.start, session.leave, and session.end events have arrived.
getSessions := func(site authclient.ClientI) ([]events.EventFields, error) {
tickCh := time.Tick(500 * time.Millisecond)
stopCh := time.After(10 * time.Second)
for {
select {
case <-tickCh:
// Get all session events from the backend.
sessionEvents, err := site.GetSessionEvents(defaults.Namespace, session.ID(tracker.GetSessionID()), 0)
if err != nil {
return nil, trace.Wrap(err)
}

// Look through all session events for the three wanted.
var hasStart bool
var hasEnd bool
var hasLeave bool
for _, se := range sessionEvents {
var isAuditEvent bool
if se.GetType() == events.SessionStartEvent {
isAuditEvent = true
hasStart = true
}
if se.GetType() == events.SessionEndEvent {
isAuditEvent = true
hasEnd = true
}
if se.GetType() == events.SessionLeaveEvent {
isAuditEvent = true
hasLeave = true
}

// ensure session events are also in audit log
if !isAuditEvent {
continue
}
auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{
To: time.Now(),
EventTypes: []string{se.GetType()},
})
require.NoError(t, err)

found := slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == se.GetID()
})
require.True(t, found)
}

// Make sure all three events were found.
if hasStart && hasEnd && hasLeave {
return sessionEvents, nil
}
case <-stopCh:
return nil, trace.BadParameter("unable to find all session events after 10s (mode=%v)", tt.inRecordLocation)
}
}
}
history, err := getSessions(site)
require.NoError(t, err)

getChunk := func(e events.EventFields, maxlen int) string {
offset := e.GetInt("offset")
length := e.GetInt("bytes")
if length == 0 {
return ""
}
if length > maxlen {
length = maxlen
}
return string(sessionStream[offset : offset+length])
}
capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID)

findByType := func(et string) events.EventFields {
for _, e := range history {
findByType := func(et string) apievents.AuditEvent {
for _, e := range sessionEvents {
if e.GetType() == et {
return e
}
}
return nil
}
// helper that asserts that a session event is also included in the
// general audit log.
requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) {
t.Helper()
auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{
To: time.Now(),
EventTypes: []string{sessionEvent.GetType()},
})
require.NoError(t, err)
require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool {
return ae.GetID() == sessionEvent.GetID()
}))
}

// there should always be 'session.start' event (and it must be first)
first := history[0]
start := findByType(events.SessionStartEvent)
first := sessionEvents[0].(*apievents.SessionStart)
start := findByType(events.SessionStartEvent).(*apievents.SessionStart)
require.Equal(t, first, start)
require.Equal(t, 0, start.GetInt("bytes"))
require.Equal(t, sessionID, start.GetString(events.SessionEventID))
require.NotEmpty(t, start.GetString(events.TerminalSize))

// make sure data is recorded properly
out := &bytes.Buffer{}
for _, e := range history {
out.WriteString(getChunk(e, 1000))
}
recorded := replaceNewlines(out.String())
require.Regexp(t, ".*exit.*", recorded)
require.Regexp(t, ".*echo hi.*", recorded)
require.Equal(t, sessionID, start.SessionID)
require.NotEmpty(t, start.TerminalSize)
requireInAuditLog(t, start)

// there should always be 'session.end' event
end := findByType(events.SessionEndEvent)
end := findByType(events.SessionEndEvent).(*apievents.SessionEnd)
require.NotNil(t, end)
require.Equal(t, 0, end.GetInt("bytes"))
require.Equal(t, sessionID, end.GetString(events.SessionEventID))
require.Equal(t, sessionID, end.SessionID)
requireInAuditLog(t, end)

// there should always be 'session.leave' event
leave := findByType(events.SessionLeaveEvent)
leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave)
require.NotNil(t, leave)
require.Equal(t, 0, leave.GetInt("bytes"))
require.Equal(t, sessionID, leave.GetString(events.SessionEventID))
require.Equal(t, sessionID, leave.SessionID)
requireInAuditLog(t, leave)

// all of them should have a proper time
for _, e := range history {
require.False(t, e.GetTime("time").IsZero())
for _, e := range sessionEvents {
require.False(t, e.GetTime().IsZero())
}

// Check data was recorded properly
recorded := replaceNewlines(capturedStream)
require.Regexp(t, ".*exit.*", recorded)
require.Regexp(t, ".*echo hi.*", recorded)
})
}
}

func streamSession(
ctx context.Context,
t *testing.T,
streamer events.SessionStreamer,
sessionID string,
) (string, []apievents.AuditEvent) {
t.Helper()
evtCh, errCh := streamer.StreamSessionEvents(ctx, session.ID(sessionID), 0)
strideynet marked this conversation as resolved.
Show resolved Hide resolved
capturedStream := &bytes.Buffer{}
evts := make([]apievents.AuditEvent, 0)
readLoop:
for {
select {
case evt := <-evtCh:
if evt == nil {
break readLoop
}
if evt.GetType() == events.SessionPrintEvent {
capturedStream.Write(evt.(*apievents.SessionPrint).Data)
}
evts = append(evts, evt)
case err := <-errCh:
require.NoError(t, err)
}
}
return capturedStream.String(), evts
}

// testInteroperability checks if Teleport and OpenSSH behave in the same way
// when executing commands.
func testInteroperability(t *testing.T, suite *integrationTestSuite) {
Expand Down Expand Up @@ -1279,9 +1218,22 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
}

require.EventuallyWithT(t, func(t *assert.CollectT) {
events, err := authSrv.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0)
assert.NoError(t, err)
assert.NotEmpty(t, events)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
eventsCh, errCh := authSrv.StreamSessionEvents(ctx, session.ID(sessionID), 0)
for {
select {
case err := <-errCh:
assert.NoError(t, err)
return
case evt := <-eventsCh:
if evt != nil {
return
}
assert.Fail(t, "expected event, got nil")
return
}
}
}, 15*time.Second, 200*time.Millisecond)
})
}
Expand Down Expand Up @@ -5018,7 +4970,19 @@ func testAuditOff(t *testing.T, suite *integrationTestSuite) {

// however, attempts to read the actual sessions should fail because it was
// not actually recorded
_, err = site.GetSessionChunk(defaults.Namespace, session.ID(tracker.GetSessionID()), 0, events.MaxChunkBytes)
eventsCh, errCh := site.StreamSessionEvents(ctx, session.ID(tracker.GetSessionID()), 0)
err = nil
readLoop:
for {
select {
case evt := <-eventsCh:
if evt != nil {
t.Fatalf("Unexpected event: %v", evt)
}
case err = <-errCh:
break readLoop
}
}
require.Error(t, err)

// ensure that session related events were emitted to audit log
Expand Down
20 changes: 6 additions & 14 deletions integration/kube_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/entitlements"
Expand All @@ -79,7 +78,6 @@ import (
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/service"
Expand Down Expand Up @@ -343,10 +341,8 @@ loop:
}

// read back the entire session and verify that it matches the stated output
capturedStream, err := teleport.Process.GetAuthServer().GetSessionChunk(apidefaults.Namespace, session.ID(sessionID), 0, events.MaxChunkBytes)
require.NoError(t, err)

require.Equal(t, sessionStream, string(capturedStream))
capturedStream, _ := streamSession(ctx, t, teleport.Process.GetAuthServer(), sessionID)
require.Equal(t, sessionStream, capturedStream)

// impersonating kube exec should be denied
// interactive command, allocate pty
Expand Down Expand Up @@ -778,10 +774,8 @@ loop:
}

// read back the entire session and verify that it matches the stated output
capturedStream, err := main.Process.GetAuthServer().GetSessionChunk(apidefaults.Namespace, session.ID(sessionID), 0, events.MaxChunkBytes)
require.NoError(t, err)

require.Equal(t, sessionStream, string(capturedStream))
capturedStream, _ := streamSession(ctx, t, main.Process.GetAuthServer(), sessionID)
require.Equal(t, sessionStream, capturedStream)

// impersonating kube exec should be denied
// interactive command, allocate pty
Expand Down Expand Up @@ -1052,10 +1046,8 @@ loop:
}

// read back the entire session and verify that it matches the stated output
capturedStream, err := main.Process.GetAuthServer().GetSessionChunk(apidefaults.Namespace, session.ID(sessionID), 0, events.MaxChunkBytes)
require.NoError(t, err)

require.Equal(t, sessionStream, string(capturedStream))
capturedStream, _ := streamSession(ctx, t, main.Process.GetAuthServer(), sessionID)
require.Equal(t, sessionStream, capturedStream)

// impersonating kube exec should be denied
// interactive command, allocate pty
Expand Down
Loading
Loading