diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go
new file mode 100644
index 0000000000000..a60e31eaacff2
--- /dev/null
+++ b/lib/resumption/handover.go
@@ -0,0 +1,292 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package resumption
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/base64"
+	"errors"
+	"io"
+	"net"
+	"net/netip"
+	"os"
+	"path/filepath"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/gravitational/trace"
+
+	"github.com/gravitational/teleport"
+	"github.com/gravitational/teleport/lib/multiplexer"
+	"github.com/gravitational/teleport/lib/utils"
+)
+
+// sockSuffix is the file name suffix of handover sockets.
+const sockSuffix = ".sock"
+
+// sockPath returns the path of the handover socket for the given resumption
+// token, inside the directory returned by sockDir.
+func sockPath(dataDir string, token resumptionToken) string {
+	hash := sha256.Sum256(token[:])
+	// unix domain sockets are limited to 108 or 104 characters, so the full
+	// sha256 hash is a bit too much (64 bytes in hex or 44 in b64); truncating
+	// the hash to 128 bits still gives us more than enough headroom to just
+	// assume that we'll have no collisions (a probability of one in a
+	// quintillion with 26 billion concurrent connections)
+	return filepath.Join(sockDir(dataDir), base64.RawURLEncoding.EncodeToString(hash[:16])+sockSuffix)
+}
+
+// sockDir returns the directory that holds the handover sockets.
+func sockDir(dataDir string) string {
+	return filepath.Join(dataDir, "handover")
+}
+
+// errNoDataDir is the NotFound error returned when dialing or creating a
+// handover socket without a configured data directory.
+var errNoDataDir error = &trace.NotFoundError{Message: "data dir not configured"}
+
+// attemptHandover forwards a resuming connection to the handover socket for
+// the given token, after sending it the 16-byte form of the remote IP address.
+// If the dial fails with a NotFound error the client gets a "not found" tag.
+func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumptionToken) {
+	handoverConn, err := r.dialHandover(token)
+	if err != nil {
+		if trace.IsNotFound(err) {
+			r.log.Debug("Resumable connection not found or already deleted.")
+			_, _ = conn.Write([]byte{notFoundServerExchangeTag})
+			return
+		}
+		r.log.WithError(err).Error("Error while connecting to handover socket.")
+		return
+	}
+	defer handoverConn.Close()
+
+	var remoteIP netip.Addr
+	if t, _ := conn.RemoteAddr().(*net.TCPAddr); t != nil {
+		remoteIP, _ = netip.AddrFromSlice(t.IP)
+	}
+	remoteIP16 := remoteIP.As16()
+
+	if _, err := handoverConn.Write(remoteIP16[:]); err != nil {
+		if !utils.IsOKNetworkError(err) {
+			r.log.WithError(err).Error("Error while forwarding remote address to handover socket.")
+		}
+		return
+	}
+
+	r.log.Debug("Forwarding resuming connection to handover socket.")
+	if err := utils.ProxyConn(context.Background(), conn, handoverConn); err != nil && !utils.IsOKNetworkError(err) {
+		r.log.WithError(err).Debug("Finished forwarding resuming connection to handover socket.")
+	}
+}
+
+// dialHandover connects to the handover socket for the given token, giving up
+// after one second.
+func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) {
+	if r.dataDir == "" {
+		return nil, trace.Wrap(errNoDataDir)
+	}
+
+	c, err := net.DialTimeout("unix", sockPath(r.dataDir, token), time.Second)
+	if err != nil {
+		return nil, trace.ConvertSystemError(err)
+	}
+	return c, nil
+}
+
+// startHandoverListener creates the handover listener for the given token and
+// serves it in the background on behalf of entry; the listener is closed once
+// ctx is done.
+func (r *SSHServerWrapper) startHandoverListener(ctx context.Context, token resumptionToken, entry *connEntry) error {
+	l, err := r.createHandoverListener(token)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	go r.runHandoverListener(l, entry)
+	context.AfterFunc(ctx, func() { _ = l.Close() })
+
+	return nil
+}
+
+// createHandoverListener listens on the handover socket for the given token,
+// creating the handover directory if it doesn't exist.
+func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) {
+	if r.dataDir == "" {
+		return nil, trace.Wrap(errNoDataDir)
+	}
+
+	// without the directory the listen below is not expected to succeed, so
+	// report the actual failure rather than a misleading error from bind()
+	if err := os.MkdirAll(sockDir(r.dataDir), teleport.PrivateDirMode); err != nil {
+		return nil, trace.ConvertSystemError(err)
+	}
+	l, err := net.Listen("unix", sockPath(r.dataDir, token))
+	if err != nil {
+		return nil, trace.ConvertSystemError(err)
+	}
+	return l, nil
+}
+
+// runHandoverListener accepts connections from l until a non-temporary error
+// occurs, handing each one to handleHandoverConnection in a new goroutine.
+func (r *SSHServerWrapper) runHandoverListener(l net.Listener, entry *connEntry) {
+	defer l.Close()
+
+	var tempDelay time.Duration
+	for {
+		// the logic for this Accept loop is copied from [net/http.Server]
+		c, err := l.Accept()
+		if err == nil {
+			tempDelay = 0
+			go r.handleHandoverConnection(c, entry)
+			continue
+		}
+
+		if tempErr, ok := err.(interface{ Temporary() bool }); !ok || !tempErr.Temporary() {
+			if !utils.IsOKNetworkError(err) {
+				r.log.WithError(err).Warn("Accept error in handover listener.")
+			}
+			return
+		}
+
+		tempDelay = max(5*time.Millisecond, min(2*tempDelay, time.Second))
+		r.log.WithError(err).WithField("delay", tempDelay).Warn("Temporary accept error in handover listener, continuing after delay.")
+		time.Sleep(tempDelay)
+	}
+}
+
+// handleHandoverConnection reads the 16-byte remote IP address that
+// attemptHandover sends first on a handover connection, then attempts to
+// resume entry over it.
+func (r *SSHServerWrapper) handleHandoverConnection(conn net.Conn, entry *connEntry) {
+	defer conn.Close()
+
+	var remoteIP16 [16]byte
+	if _, err := io.ReadFull(conn, remoteIP16[:]); err != nil {
+		if !utils.IsOKNetworkError(err) {
+			r.log.WithError(err).Error("Error while reading remote address from handover socket.")
+		}
+		return
+	}
+	remoteIP := netip.AddrFrom16(remoteIP16).Unmap()
+
+	r.resumeConnection(entry, conn, remoteIP)
+}
+
+// HandoverCleanup deletes handover sockets that were left over from previous
+// runs of Teleport that failed to clean up after themselves (because of an
+// uncatchable signal or a system crash). It will exhaustively clean up the
+// current left over sockets, so it's sufficient to call it once per process.
+func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error {
+	const cleanupDelay = time.Second
+	return trace.Wrap(r.handoverCleanup(ctx, cleanupDelay))
+}
+
+func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay time.Duration) error {
+	if r.dataDir == "" {
+		return nil
+	}
+
+	dir := sockDir(r.dataDir)
+	entries, err := os.ReadDir(dir)
+	if err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			return nil
+		}
+		return trace.ConvertSystemError(err)
+	}
+
+	paths := make([]string, 0, len(entries))
+	for _, entry := range entries {
+		if strings.HasSuffix(entry.Name(), sockSuffix) {
+			paths = append(paths, filepath.Join(dir, entry.Name()))
+		}
+	}
+
+	paths, firstErr := retainNonConnectableSockets(ctx, paths)
+
+	if len(paths) < 1 {
+		return trace.Wrap(firstErr)
+	}
+
+	// unix domain sockets exist on disk between bind() and listen() but
+	// connecting before listen() results in ECONNREFUSED, so we just wait a
+	// little bit before testing them again; the first check lets us be done
+	// with the check immediately in the happy case where there are no
+	// unconnectable sockets
+	r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.")
+
+	select {
+	case <-ctx.Done():
+		return trace.NewAggregate(firstErr, ctx.Err())
+	case <-time.After(cleanupDelay):
+	}
+
+	paths, secondErr := retainNonConnectableSockets(ctx, paths)
+
+	if len(paths) < 1 {
+		r.log.Debug("Found no unconnectable handover sockets after waiting.")
+		return trace.NewAggregate(firstErr, secondErr)
+	}
+
+	r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets from old Teleport instances.")
+
+	errs := []error{firstErr, secondErr}
+	for _, path := range paths {
+		if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
+			if len(errs) < 10 {
+				errs = append(errs, trace.ConvertSystemError(err))
+			}
+		}
+	}
+
+	return trace.NewAggregate(errs...)
+}
+
+// retainNonConnectableSockets attempts to connect to the given UNIX domain
+// sockets, returning (in the storage of paths) only those that refuse the
+// connection, and an aggregate of at most ten unexpected dial errors.
+func retainNonConnectableSockets(ctx context.Context, paths []string) ([]string, error) {
+	var d net.Dialer
+	var errs []error
+	filtered := paths[:0]
+	for _, path := range paths {
+		c, err := d.DialContext(ctx, "unix", path)
+		if err == nil {
+			_ = c.Close()
+			continue
+		}
+
+		if errors.Is(err, os.ErrNotExist) {
+			continue
+		}
+
+		if errors.Is(err, syscall.ECONNREFUSED) {
+			filtered = append(filtered, path)
+			continue
+		}
+
+		if len(errs) < 10 {
+			errs = append(errs, trace.ConvertSystemError(err))
+		}
+	}
+
+	return filtered, trace.NewAggregate(errs...)
+}
diff --git a/lib/resumption/handover_test.go b/lib/resumption/handover_test.go
new file mode 100644
index 0000000000000..6b504866209ee
--- /dev/null
+++ b/lib/resumption/handover_test.go
@@ -0,0 +1,198 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package resumption
+
+import (
+	"context"
+	"encoding/binary"
+	"math/rand"
+	"net"
+	"net/netip"
+	"os"
+	"path/filepath"
+	"runtime"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/gravitational/trace"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/gravitational/teleport"
+	"github.com/gravitational/teleport/lib/utils"
+	"github.com/gravitational/teleport/lib/utils/uds"
+)
+
+func TestHandover(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	sshServer := discardingSSHServer(t)
+	hostID := uuid.NewString()
+	// unix domain socket names have a very tight length limit
+	dataDir := shortTempDir(t)
+
+	s1 := NewSSHServerWrapper(SSHServerWrapperConfig{
+		SSHServer: sshServer,
+		HostID:    hostID,
+		DataDir:   dataDir,
+	})
+	s2 := NewSSHServerWrapper(SSHServerWrapperConfig{
+		SSHServer: sshServer,
+		HostID:    hostID,
+		DataDir:   dataDir,
+	})
+
+	dial := func(handleConnection func(net.Conn), clientAddr netip.Addr) net.Conn {
+		c1, c2, err := uds.NewSocketpair(uds.SocketTypeStream)
+		require.NoError(err)
+
+		srv := &net.TCPAddr{
+			IP:   net.IPv4(127, 0, 0, 1),
+			Port: 1 + rand.Intn(65535),
+		}
+		clt := &net.TCPAddr{
+			IP:   clientAddr.AsSlice(),
+			Zone: clientAddr.Zone(),
+			Port: 1 + rand.Intn(65535),
+		}
+
+		go handleConnection(utils.NewConnWithAddr(c2, srv, clt))
+		conn := utils.NewConnWithAddr(c1, clt, srv)
+		t.Cleanup(func() { _ = conn.Close() })
+		return conn
+	}
+
+	originalNC := dial(s1.HandleConnection, netip.MustParseAddr("127.0.0.1"))
+
+	redialConns := make(chan net.Conn)
+	defer close(redialConns)
+
+	wrappedNC, err := WrapSSHClientConn(context.Background(), originalNC, func(ctx context.Context, receivedHostID string) (net.Conn, error) {
+		if receivedHostID != hostID {
+			return nil, trace.BadParameter("expected hostID %q, got %q", hostID, receivedHostID)
+		}
+		conn := <-redialConns
+		if conn == nil {
+			return nil, trace.ConnectionProblem(nil, "no redial connection received")
+		}
+		return conn, nil
+	})
+	require.NoError(err)
+	defer wrappedNC.Close()
+
+	require.IsType((*Conn)(nil), wrappedNC)
+
+	clt, err := sshClient(wrappedNC)
+	require.NoError(err)
+	defer clt.Close()
+
+	const wantReplyTrue = true
+	_, _, err = clt.SendRequest("foo", wantReplyTrue, nil)
+	require.NoError(err)
+
+	_ = originalNC.Close()
+	nextNC := dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.1"))
+	redialConns <- nextNC
+
+	_, _, err = clt.SendRequest("foo", wantReplyTrue, nil)
+	require.NoError(err)
+
+	_ = nextNC.Close()
+	// this will result in a closed connection, because changing network address
+	// stops further reconnection attempts
+	redialConns <- dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.2"))
+
+	require.ErrorIs(clt.Wait(), net.ErrClosed)
+}
+
+func shortTempDir(t *testing.T) string {
+	t.Helper()
+	base := ""
+	if runtime.GOOS == "darwin" {
+		base = "/tmp"
+	}
+	d, err := os.MkdirTemp(base, "")
+	require.NoError(t, err)
+	t.Cleanup(func() { require.NoError(t, os.RemoveAll(d)) })
+	return d
+}
+
+func TestHandoverCleanup(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+
+	hostID := uuid.NewString()
+	// unix domain socket names have a very tight length limit
+	dataDir := shortTempDir(t)
+
+	var tok resumptionToken
+	binary.NativeEndian.PutUint64(tok[:8], rand.Uint64())
+	binary.NativeEndian.PutUint64(tok[8:], rand.Uint64())
+
+	s := NewSSHServerWrapper(SSHServerWrapperConfig{
+		SSHServer: func(c net.Conn) {
+			defer c.Close()
+			assert.Fail(t, "unexpected connection")
+		},
+		HostID:  hostID,
+		DataDir: dataDir,
+	})
+
+	handoverDir := filepath.Join(dataDir, "handover")
+	require.NoError(os.MkdirAll(handoverDir, teleport.PrivateDirMode))
+
+	d, err := os.ReadDir(handoverDir)
+	require.NoError(err)
+	require.Empty(d)
+
+	l, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockPath(dataDir, tok)})
+	require.NoError(err)
+	l.SetUnlinkOnClose(false)
+	defer l.Close()
+	go func() {
+		defer l.Close()
+		for {
+			c, err := l.Accept()
+			if err != nil {
+				break
+			}
+			_ = c.Close()
+		}
+	}()
+
+	d, err = os.ReadDir(handoverDir)
+	require.NoError(err)
+	require.NotEmpty(d)
+
+	ctx := context.Background()
+
+	const cleanupDelayZero time.Duration = 0
+	require.NoError(s.handoverCleanup(ctx, cleanupDelayZero))
+
+	d, err = os.ReadDir(handoverDir)
+	require.NoError(err)
+	require.NotEmpty(d)
+
+	l.Close()
+	require.NoError(s.handoverCleanup(ctx, cleanupDelayZero))
+
+	d, err = os.ReadDir(handoverDir)
+	require.NoError(err)
+	require.Empty(d)
+}
diff --git a/lib/resumption/resumption_test.go b/lib/resumption/resumption_test.go
index 13fa191b835cf..de6f74a89f58c 100644
--- a/lib/resumption/resumption_test.go
+++ b/lib/resumption/resumption_test.go
@@ -42,7 +42,10 @@ func TestResumption(t *testing.T) {
 
 	hostID := uuid.NewString()
 	sshServer := discardingSSHServer(t)
-	resumableServer := NewSSHServerWrapper(nil, sshServer, hostID)
+	resumableServer := NewSSHServerWrapper(SSHServerWrapperConfig{
+		SSHServer: sshServer,
+		HostID:    hostID,
+	})
 
 	directListener, err := net.Listen("tcp", "localhost:")
 	require.NoError(t, err)
diff --git a/lib/resumption/server_detect.go b/lib/resumption/server_detect.go
index 7805bc19ddc09..9b247271b2c02 100644
--- a/lib/resumption/server_detect.go
+++ b/lib/resumption/server_detect.go
@@ -58,18 +58,36 @@ func serverVersionCRLFV1(pubKey *ecdh.PublicKey, hostID string) string {
 	)
 }
 
-// NewSSHServerWrapper wraps a given SSH server as to support connection
-// resumption.
-func NewSSHServerWrapper(log logrus.FieldLogger, sshServer func(net.Conn), hostID string) *SSHServerWrapper {
-	if log == nil {
-		log = logrus.WithField(trace.Component, Component)
+type SSHServerWrapperConfig struct {
+	Log logrus.FieldLogger
+
+	// SSHServer is a function that takes ownership of a [net.Conn] and uses it
+	// as an SSH server. If the Conn is a [sshutils.SSHServerVersionOverrider],
+	// the server should use the overridden server version.
+	SSHServer func(net.Conn)
+
+	// HostID is the host ID of the Teleport instance running the server;
+	// compliant connection resumption clients will reconnect to the host ID
+	// expecting to reach the instance.
+	HostID string
+
+	// DataDir is the path to the Teleport data directory, used to store
+	// temporary handover sockets.
+	DataDir string
+}
+
+// NewSSHServerWrapper creates an [SSHServerWrapper].
+func NewSSHServerWrapper(cfg SSHServerWrapperConfig) *SSHServerWrapper {
+	if cfg.Log == nil {
+		cfg.Log = logrus.WithField(trace.Component, Component)
 	}
 
 	return &SSHServerWrapper{
-		sshServer: sshServer,
-		log:       log,
+		sshServer: cfg.SSHServer,
+		log:       cfg.Log,
 
-		hostID: hostID,
+		hostID:  cfg.HostID,
+		dataDir: cfg.DataDir,
 
 		conns: make(map[resumptionToken]*connEntry),
 	}
@@ -85,7 +103,8 @@ type SSHServerWrapper struct {
 	sshServer func(net.Conn)
 	log       logrus.FieldLogger
 
-	hostID string
+	hostID  string
+	dataDir string
 
 	mu    sync.Mutex
 	conns map[resumptionToken]*connEntry
diff --git a/lib/resumption/server_exchange.go b/lib/resumption/server_exchange.go
index 96e2e01f139ab..292c8495cd58d 100644
--- a/lib/resumption/server_exchange.go
+++ b/lib/resumption/server_exchange.go
@@ -17,6 +17,7 @@
 package resumption
 
 import (
+	"context"
 	"crypto/ecdh"
 	"crypto/sha256"
 	"io"
@@ -99,24 +100,32 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh
 		remoteIP, _ = netip.AddrFromSlice(t.IP)
 		remoteIP = remoteIP.Unmap()
 	}
-	if !remoteIP.IsValid() {
-		r.log.Warn("Tracking resumable connection with an invalid remote address (this is a bug).")
-	}
 
 	token := resumptionToken(otp32[:16])
-
-	r.mu.Lock()
 	entry := &connEntry{
 		conn:     resumableConn,
 		remoteIP: remoteIP,
 		timeout:  time.AfterFunc(detachedTimeout, func() { resumableConn.Close() }),
 	}
-	r.conns[token] = entry
-	r.mu.Unlock()
+
+	// this context is only used for the convenience of [context.AfterFunc]
+	handoverContext, handoverCancel := context.WithCancel(context.Background())
+	if remoteIP.IsValid() {
+		r.mu.Lock()
+		r.conns[token] = entry
+		r.mu.Unlock()
+
+		if err := r.startHandoverListener(handoverContext, token, entry); err != nil {
+			r.log.WithError(err).Warn("Unable to create handover listener for resumable connection, connection resumption will not work across graceful restarts.")
+		}
+	} else {
+		r.log.Warn("Refusing to track resumable connection with an invalid remote IP address, connection resumption will not work (this is a bug).")
+	}
 
 	go func() {
 		defer r.log.Info("Resumable connection completed.")
 		defer resumableConn.Close()
+		defer handoverCancel()
 		defer func() {
 			r.mu.Lock()
 			defer r.mu.Unlock()
@@ -156,25 +165,23 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh
 	r.mu.Unlock()
 
 	if entry == nil {
-		r.log.Debug("Resumable connection not found or already deleted.")
-		_, _ = conn.Write([]byte{notFoundServerExchangeTag})
+		r.attemptHandover(conn, token)
 		return
 	}
 
-	if entry.remoteIP.IsValid() {
-		var remoteIP netip.Addr
-		if t, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
-			remoteIP, _ = netip.AddrFromSlice(t.IP)
-		}
-		if entry.remoteIP != remoteIP.Unmap() {
-			r.log.Warn("Resumable connection attempted resumption from a different remote address.")
-			_, _ = conn.Write([]byte{badAddressServerExchangeTag})
-			return
-		}
-	} else {
-		r.log.Warn("Resumable connection attempted resumption on a connection with an invalid remote address.")
-		// the client can't recover from this, just send a "not found"
-		_, _ = conn.Write([]byte{notFoundServerExchangeTag})
+	var remoteIP netip.Addr
+	if t, _ := conn.RemoteAddr().(*net.TCPAddr); t != nil {
+		remoteIP, _ = netip.AddrFromSlice(t.IP)
+		remoteIP = remoteIP.Unmap()
+	}
+
+	r.resumeConnection(entry, conn, remoteIP)
+}
+
+func (r *SSHServerWrapper) resumeConnection(entry *connEntry, conn net.Conn, remoteIP netip.Addr) {
+	if entry.remoteIP != remoteIP {
+		r.log.Warn("Resumable connection attempted resumption from a different remote address.")
+		_, _ = conn.Write([]byte{badAddressServerExchangeTag})
 		return
 	}
 
diff --git a/lib/service/service.go b/lib/service/service.go
index 4700517f8cd0e..84c1455b0d557 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -2681,11 +2681,19 @@ func (process *TeleportProcess) initSSH() error {
 
 		var resumableServer *resumption.SSHServerWrapper
 		if os.Getenv("TELEPORT_UNSTABLE_DISABLE_SSH_RESUMPTION") == "" {
-			resumableServer = resumption.NewSSHServerWrapper(
-				log.WithField(trace.Component, teleport.Component(teleport.ComponentNode, resumption.Component)),
-				s.HandleConnection,
-				serverID,
-			)
+			resumableServer = resumption.NewSSHServerWrapper(resumption.SSHServerWrapperConfig{
+				Log:       log.WithField(trace.Component, teleport.Component(teleport.ComponentNode, resumption.Component)),
+				SSHServer: s.HandleConnection,
+
+				HostID:  serverID,
+				DataDir: cfg.DataDir,
+			})
+
+			go func() {
+				if err := resumableServer.HandoverCleanup(process.GracefulExitContext()); err != nil {
+					log.WithError(err).Warn("Failed to clean up handover sockets.")
+				}
+			}()
 		}
 
 		var agentPool *reversetunnel.AgentPool