From 66094dde9611189c7988005129c3d146ab7f8cb8 Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 14 Nov 2024 12:23:33 -0800 Subject: [PATCH] Add individual event validation. --- lib/srv/regular/sshserver_test.go | 85 ++++++++++++++++++------------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 76940f8b8d9bb..953e6d38d30a7 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -448,6 +448,7 @@ loop: // // Note: This is a regression test for a bug which resulted in extra, empty session.data // events for networking requests. +// See https://github.com/gravitational/teleport/issues/48728. func TestSessionAuditLog(t *testing.T) { ctx := context.Background() t.Parallel() @@ -479,6 +480,16 @@ func TestSessionAuditLog(t *testing.T) { _, err = f.testSrv.Auth().UpsertRole(ctx, role) require.NoError(t, err) + nextEvent := func() apievents.AuditEvent { + select { + case event := <-emitter.C(): + return event + case <-time.After(time.Second): + require.Fail(t, "timed out waiting for event") + } + return nil + } + // Start a new session se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) @@ -487,17 +498,26 @@ func TestSessionAuditLog(t *testing.T) { err = se.Shell(ctx) require.NoError(t, err) - // Request agent forwarding + e := nextEvent() + startEvent, ok := e.(*apievents.SessionStart) + require.True(t, ok, "expected SessionStart event but got event of type %T", e) + require.NotEmpty(t, startEvent.SessionID, "expected non empty sessionID") + sessionID := startEvent.SessionID + + // Request agent forwarding, no individual event emitted. err = agent.RequestAgentForwarding(se.Session) require.NoError(t, err) - // Request x11 forwarding + // Request x11 forwarding, event should be emitted immediately. clientXAuthEntry, err := x11.NewFakeXAuthEntry(x11.Display{}) require.NoError(t, err) err = x11.RequestForwarding(se.Session, clientXAuthEntry) require.NoError(t, err) - // Request a remote port forwarding listener. + x11Event := nextEvent() + require.IsType(t, x11Event, &apievents.X11Forward{}, "expected X11Forward event but got event of tgsype %T", x11Event) + + // Request a remote port forwarding listener. The event is logged at the end of the session. listener, err := f.ssh.clt.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -509,44 +529,39 @@ func TestSessionAuditLog(t *testing.T) { ts.Listener = listener ts.Start() - // Request forward to remote port. + // Request forward to remote port. Each dial should result in a new event. Note that we don't + // know what port the server will forward the connection on, so we don't have an easy way to + // validate the event's addr field. conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String()) require.NoError(t, err) conn.Close() - // End the session. - se.Close() - - expectedEvents := []apievents.AuditEvent{ - &apievents.SessionStart{}, - &apievents.X11Forward{}, - &apievents.PortForward{}, - &apievents.SessionLeave{}, - &apievents.SessionData{}, - &apievents.PortForward{}, - &apievents.SessionEnd{}, - } - - timeout := time.After(time.Second) - var gotEvents []apievents.AuditEvent -loop: - for { - select { - case event := <-emitter.C(): - fmt.Printf("type %T\n", event) - gotEvents = append(gotEvents, event) - case <-timeout: - break loop - } - } + directPortForwardEvent := nextEvent() + require.IsType(t, directPortForwardEvent, &apievents.PortForward{}, "expected PortForward event but got event of type %T", directPortForwardEvent) - if len(expectedEvents) != len(gotEvents) { - require.Fail(t, "collected audit events don't match expected events", "got %v, expected %v", gotEvents, expectedEvents) - } + // End the session. Session leave, data, and end events should be emitted, along with the remote + // port forwarding event. + se.Close() - for i, ee := range expectedEvents { - require.IsType(t, ee, gotEvents[i]) - } + e = nextEvent() + leaveEvent, ok := e.(*apievents.SessionLeave) + require.True(t, ok, "expected SessionLeave event but got event of type %T", e) + require.Equal(t, sessionID, leaveEvent.SessionID) + + e = nextEvent() + dataEvent, ok := e.(*apievents.SessionData) + require.True(t, ok, "expected SessionData event but got event of type %T", e) + require.Equal(t, sessionID, dataEvent.SessionID) + + e = nextEvent() + remotePortForwardEvent, ok := e.(*apievents.PortForward) + require.True(t, ok, "expected PortForward event but got event of type %T", e) + require.Equal(t, listener.Addr().String(), remotePortForwardEvent.Addr) + + e = nextEvent() + endEvent, ok := e.(*apievents.SessionEnd) + require.True(t, ok, "expected SessionEnd event but got event of type %T", e) + require.Equal(t, sessionID, endEvent.SessionID) } func newProxyClient(t *testing.T, testSvr *auth.TestServer) (*authclient.Client, string) {