Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Prevent exiting a session prior to output being consumed #45373

Merged
merged 2 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {

// wait until we've found the session in the audit log
getSession := func(site authclient.ClientI) (types.SessionTracker, error) {
timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second)
timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
sessions, err := waitForSessionToBeEstablished(timeout, defaults.Namespace, site)
if err != nil {
Expand Down Expand Up @@ -520,7 +520,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) {
select {
case err := <-endC:
require.NoError(t, err)
case <-time.After(10 * time.Second):
case <-time.After(15 * time.Second):
t.Fatalf("%s: Timeout waiting for session to finish.", tt.comment)
}

Expand Down Expand Up @@ -1251,7 +1251,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
return
}
sessionID = trackers[0].GetSessionID()
}, time.Second*5, time.Millisecond*100)
}, time.Second*15, time.Millisecond*100)

// Send stuff to the session.
term.Type("echo Hello\n\r")
Expand All @@ -1265,7 +1265,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {

// Wait for the session to terminate without error.
term.Type("exit\n\r")
require.NoError(t, waitForError(errCh, 5*time.Second))
require.NoError(t, waitForError(errCh, 15*time.Second))

// Wait for the session recording to be uploaded and available
var uploaded bool
Expand All @@ -1285,7 +1285,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
events, err := authSrv.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0)
assert.NoError(t, err)
assert.NotEmpty(t, events)
}, 5*time.Second, 200*time.Millisecond)
}, 15*time.Second, 200*time.Millisecond)
})
}
}
Expand Down Expand Up @@ -2032,7 +2032,7 @@ func testShutdown(t *testing.T, suite *integrationTestSuite) {
select {
case err := <-sshErr:
require.NoError(t, err)
case <-time.After(5 * time.Second):
case <-time.After(15 * time.Second):
require.FailNow(t, "failed to shutdown ssh session")
}

Expand Down Expand Up @@ -8898,7 +8898,7 @@ func TestConnectivityWithoutAuth(t *testing.T) {
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(5 * time.Second):
case <-time.After(15 * time.Second):
t.Fatal("timeout waiting for session to exit")
}
require.Contains(t, term.AllOutput(), "hi")
Expand Down Expand Up @@ -8930,7 +8930,7 @@ func TestConnectivityWithoutAuth(t *testing.T) {
if !authRunning {
require.Empty(t, term.AllOutput())
}
case <-time.After(5 * time.Second):
case <-time.After(15 * time.Second):
t.Fatal("timeout waiting for session to exit")
}
},
Expand Down
30 changes: 17 additions & 13 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,18 +915,19 @@ func (s *session) haltTerminal() {
// prematurely can result in missing audit events, session recordings, and other
// unexpected errors.
func (s *session) Close() error {
s.Stop()

s.BroadcastMessage("Closing session...")
s.log.Infof("Closing session")

serverSessions.Dec()

// Remove session parties and close client connections.
// Remove session parties and close client connections. Since terminals
// might await for all the parties to be released, we must close them first.
// Closing the parties will cause their SSH channel to be closed, meaning
// any goroutine reading from it will be released.
for _, p := range s.getParties() {
p.Close()
}

s.Stop()
serverSessions.Dec()
s.registry.removeSession(s)

// Complete the session recording
Expand Down Expand Up @@ -1343,21 +1344,21 @@ func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *p
s.log.WithError(err).Error("Received error waiting for the interactive session to finish")
}

if result != nil {
if err := s.registry.broadcastResult(s.id, *result); err != nil {
s.log.Warningf("Failed to broadcast session result: %v", err)
}
}

// wait for copying from the pty to be complete or a timeout before
// broadcasting the result (which will close the pty) if it has not been
// closed already.
select {
case <-time.After(defaults.WaitCopyTimeout):
s.log.Error("Timed out waiting for PTY copy to finish, session data may be missing.")
s.log.Debug("Timed out waiting for PTY copy to finish, session data may be missing.")
case <-s.doneCh:
}

if result != nil {
if err := s.registry.broadcastResult(s.id, *result); err != nil {
s.log.Warningf("Failed to broadcast session result: %v", err)
}
}

if execRequest, err := scx.GetExecRequest(); err == nil && execRequest.GetCommand() != "" {
emitExecAuditEvent(scx, execRequest.GetCommand(), err)
}
Expand Down Expand Up @@ -1527,8 +1528,11 @@ func (s *session) broadcastResult(r ExecResult) {
s.mu.Lock()
defer s.mu.Unlock()

payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)})
for _, p := range s.parties {
p.ctx.SendExecResult(r)
if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil {
s.log.Infof("Failed to send exit status for %v: %v", r.Command, err)
}
}
}

Expand Down
188 changes: 188 additions & 0 deletions lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package srv

import (
"context"
"crypto/ed25519"
"io"
"net"
"os/user"
"slices"
"sync/atomic"
Expand All @@ -30,10 +32,12 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/term"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -1235,3 +1239,187 @@ func TestSessionRecordingMode(t *testing.T) {
})
}
}

func TestCloseProxySession(t *testing.T) {
srv := newMockServer(t)
srv.component = teleport.ComponentProxy

reg, err := NewSessionRegistry(SessionRegistryConfig{
Srv: srv,
SessionTrackerService: srv.auth,
})
require.NoError(t, err)
t.Cleanup(func() { reg.Close() })

scx := newTestServerContext(t, reg.Srv, nil)

// Open a new session
sshChanOpen := newMockSSHChannel()
// Always close the session from the client side to avoid it being stuck
// on closing (server side).
t.Cleanup(func() { sshChanOpen.Close() })
go func() {
// Consume stdout sent to the channel
io.ReadAll(sshChanOpen)
}()

err = reg.OpenSession(context.Background(), sshChanOpen, scx)
require.NoError(t, err)
require.NotNil(t, scx.session)

// After the session is open, we force a close coming from the server. Do
// this inside a goroutine to avoid being blocked.
closeChan := make(chan error)
go func() {
closeChan <- scx.session.Close()
}()

select {
case err := <-closeChan:
require.NoError(t, err)
case <-time.After(10 * time.Second):
require.Fail(t, "expected session to be closed")
}
}

// TestClodeRemoteSession given a remote session recording at proxy ensure that
// closing the session releases all the resources, and return properly to the
// user.
func TestCloseRemoteSession(t *testing.T) {
srv := newMockServer(t)
srv.component = teleport.ComponentProxy

// init a session registry
reg, _ := NewSessionRegistry(SessionRegistryConfig{
Srv: srv,
SessionTrackerService: srv.auth,
})
t.Cleanup(func() { reg.Close() })

scx := newTestServerContext(t, reg.Srv, nil)
scx.SessionRecordingConfig.SetMode(types.RecordAtProxy)
scx.RemoteSession = mockSSHSession(t)

// Open a new session
sshChanOpen := newMockSSHChannel()
// Always close the session from the client side to avoid it being stuck
// on closing (server side).
t.Cleanup(func() { sshChanOpen.Close() })
go func() {
// Consume stdout sent to the channel
io.ReadAll(sshChanOpen)
}()

err := reg.OpenSession(context.Background(), sshChanOpen, scx)
require.NoError(t, err)
require.NotNil(t, scx.session)

// After the session is open, we force a close coming from the server. Do
// this inside a goroutine to avoid being blocked.
closeChan := make(chan error)
go func() {
closeChan <- scx.session.Close()
}()

select {
case err := <-closeChan:
require.NoError(t, err)
case <-time.After(10 * time.Second):
require.Fail(t, "expected session to be closed")
}
}

func mockSSHSession(t *testing.T) *tracessh.Session {
t.Helper()

ctx := context.Background()

_, key, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
signer, err := ssh.NewSignerFromKey(key)
require.NoError(t, err)

cfg := &ssh.ServerConfig{NoClientAuth: true}
cfg.AddHostKey(signer)

listener, err := net.Listen("tcp", "localhost:")
require.NoError(t, err)
t.Cleanup(func() { listener.Close() })

go func() {
conn, err := listener.Accept()
if err != nil {
t.Logf("error while accepting ssh connections: %s", err)
return
}

srvConn, chCh, reqCh, err := ssh.NewServerConn(conn, cfg)
if err != nil {
t.Logf("error while accepting creating a new ssh server conn: %s", err)
return
}
t.Cleanup(func() { srvConn.Close() })

go ssh.DiscardRequests(reqCh)
for newChannel := range chCh {
channel, requests, err := newChannel.Accept()
if err != nil {
t.Logf("failed to accept channel: %s", err)
continue
}

go func() {
for req := range requests {
req.Reply(true, nil)
}
}()

sessTerm := term.NewTerminal(channel, "> ")
go func() {
defer channel.Close()
for {
_, err := sessTerm.ReadLine()
if err != nil {
break
}
}
}()
}
}()

// Establish a connection to the newly created server.
sessCh := make(chan *tracessh.Session)
go func() {
client, err := tracessh.Dial(ctx, listener.Addr().Network(), listener.Addr().String(), &ssh.ClientConfig{
Timeout: 10 * time.Second,
User: "user",
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
})
if err != nil {
t.Logf("failed to dial test ssh server: %s", err)
close(sessCh)
return
}
t.Cleanup(func() { client.Close() })

sess, err := client.NewSession(ctx)
if err != nil {
t.Logf("failed to dial test ssh server: %s", err)
close(sessCh)
return
}
t.Cleanup(func() { sess.Close() })

sessCh <- sess
}()

select {
case sess, ok := <-sessCh:
require.True(t, ok, "expected SSH session but got nothing")
return sess
case <-time.After(10 * time.Second):
require.Fail(t, "timeout while waiting for the SSH session")
return nil
}
}
7 changes: 5 additions & 2 deletions lib/srv/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,16 +654,19 @@ func (t *remoteTerminal) PID() int {
}

func (t *remoteTerminal) Close() error {
t.wg.Wait()
// this closes the underlying stdin,stdout,stderr which is what ptyBuffer is
// hooked to directly
err := t.session.Close()
if err != nil {
return trace.Wrap(err)
}

t.log.Debugf("Closed remote terminal and underlying SSH session")
// Wait for parties to be relased after closing the remote session. This
// avoid cases where the parties are blocked, reading from the remote
// session.
t.wg.Wait()

t.log.Debugf("Closed remote terminal and underlying SSH session")
return nil
}

Expand Down
Loading