diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go
new file mode 100644
index 0000000000000..a60e31eaacff2
--- /dev/null
+++ b/lib/resumption/handover.go
@@ -0,0 +1,267 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package resumption
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "errors"
+ "io"
+ "net"
+ "net/netip"
+ "os"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/multiplexer"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+const sockSuffix = ".sock"
+
+func sockPath(dataDir string, token resumptionToken) string {
+ hash := sha256.Sum256(token[:])
+ // unix domain sockets are limited to 108 or 104 characters, so the full
+ // sha256 hash is a bit too much (64 bytes in hex or 44 in b64); truncating
+ // the hash to 128 bits still gives us more than enough headroom to just
+ // assume that we'll have no collisions (a probability of one in a
+ // quintillion with 26 billion concurrent connections)
+ return filepath.Join(dataDir, "handover", base64.RawURLEncoding.EncodeToString(hash[:16])+sockSuffix)
+}
+
+func sockDir(dataDir string) string {
+ return filepath.Join(dataDir, "handover")
+}
+
+var errNoDataDir error = &trace.NotFoundError{Message: "data dir not configured"}
+
+func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumptionToken) {
+ handoverConn, err := r.dialHandover(token)
+ if err != nil {
+ if trace.IsNotFound(err) {
+ r.log.Debug("Resumable connection not found or already deleted.")
+ _, _ = conn.Write([]byte{notFoundServerExchangeTag})
+ return
+ }
+ r.log.WithError(err).Error("Error while connecting to handover socket.")
+ return
+ }
+ defer handoverConn.Close()
+
+ var remoteIP netip.Addr
+ if t, _ := conn.RemoteAddr().(*net.TCPAddr); t != nil {
+ remoteIP, _ = netip.AddrFromSlice(t.IP)
+ }
+ remoteIP16 := remoteIP.As16()
+
+ if _, err := handoverConn.Write(remoteIP16[:]); err != nil {
+ if !utils.IsOKNetworkError(err) {
+ r.log.WithError(err).Error("Error while forwarding remote address to handover socket.")
+ }
+ return
+ }
+
+ r.log.Debug("Forwarding resuming connection to handover socket.")
+ if err := utils.ProxyConn(context.Background(), conn, handoverConn); err != nil && !utils.IsOKNetworkError(err) {
+ r.log.WithError(err).Debug("Finished forwarding resuming connection to handover socket.")
+ }
+}
+
+func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) {
+ if r.dataDir == "" {
+ return nil, trace.Wrap(errNoDataDir)
+ }
+
+ c, err := net.DialTimeout("unix", sockPath(r.dataDir, token), time.Second)
+ if err != nil {
+ return nil, trace.ConvertSystemError(err)
+ }
+ return c, nil
+}
+
+func (r *SSHServerWrapper) startHandoverListener(ctx context.Context, token resumptionToken, entry *connEntry) error {
+ l, err := r.createHandoverListener(token)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ go r.runHandoverListener(l, entry)
+ context.AfterFunc(ctx, func() { _ = l.Close() })
+
+ return nil
+}
+
+func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) {
+ if r.dataDir == "" {
+ return nil, trace.Wrap(errNoDataDir)
+ }
+
+ _ = os.MkdirAll(sockDir(r.dataDir), teleport.PrivateDirMode)
+ l, err := net.Listen("unix", sockPath(r.dataDir, token))
+ if err != nil {
+ return nil, trace.ConvertSystemError(err)
+ }
+ return l, nil
+}
+
+func (r *SSHServerWrapper) runHandoverListener(l net.Listener, entry *connEntry) {
+ defer l.Close()
+
+ var tempDelay time.Duration
+ for {
+ // the logic for this Accept loop is copied from [net/http.Server]
+ c, err := l.Accept()
+ if err == nil {
+ tempDelay = 0
+ go r.handleHandoverConnection(c, entry)
+ continue
+ }
+
+ if tempErr, ok := err.(interface{ Temporary() bool }); !ok || !tempErr.Temporary() {
+ if !utils.IsOKNetworkError(err) {
+ r.log.WithError(err).Warn("Accept error in handover listener.")
+ }
+ return
+ }
+
+ tempDelay = max(5*time.Millisecond, min(2*tempDelay, time.Second))
+ r.log.WithError(err).WithField("delay", tempDelay).Warn("Temporary accept error in handover listener, continuing after delay.")
+ time.Sleep(tempDelay)
+ }
+}
+
+func (r *SSHServerWrapper) handleHandoverConnection(conn net.Conn, entry *connEntry) {
+ defer conn.Close()
+
+ var remoteIP16 [16]byte
+ if _, err := io.ReadFull(conn, remoteIP16[:]); err != nil {
+ if !utils.IsOKNetworkError(err) {
+ r.log.WithError(err).Error("Error while reading remote address from handover socket.")
+ }
+ return
+ }
+ remoteIP := netip.AddrFrom16(remoteIP16).Unmap()
+
+ r.resumeConnection(entry, conn, remoteIP)
+}
+
+// HandoverCleanup deletes handover sockets that were left over from previous
+// runs of Teleport that failed to clean up after themselves (because of an
+// uncatchable signal or a system crash). It will exhaustively clean up the
+// current left over sockets, so it's sufficient to call it once per process.
+func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error {
+ const cleanupDelay = time.Second
+ return trace.Wrap(r.handoverCleanup(ctx, cleanupDelay))
+}
+
+func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay time.Duration) error {
+ if r.dataDir == "" {
+ return nil
+ }
+
+ dir := sockDir(r.dataDir)
+ entries, err := os.ReadDir(dir)
+ if err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ return nil
+ }
+ return trace.ConvertSystemError(err)
+ }
+
+ paths := make([]string, 0, len(entries))
+ for _, entry := range entries {
+ if strings.HasSuffix(entry.Name(), sockSuffix) {
+ paths = append(paths, filepath.Join(dir, entry.Name()))
+ }
+ }
+
+ paths, firstErr := retainNonConnectableSockets(ctx, paths)
+
+ if len(paths) < 1 {
+ return trace.Wrap(firstErr)
+ }
+
+ // unix domain sockets exist on disk between bind() and listen() but
+ // connecting before listen() results in ECONNREFUSED, so we just wait a
+ // little bit before testing them again; the first check lets us be done
+ // with the check immediately in the happy case where there's no
+ // unconnectable sockets
+ r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.")
+
+ select {
+ case <-ctx.Done():
+ return trace.NewAggregate(firstErr, ctx.Err())
+ case <-time.After(cleanupDelay):
+ }
+
+ paths, secondErr := retainNonConnectableSockets(ctx, paths)
+
+ if len(paths) < 1 {
+ r.log.Debug("Found no unconnectable handover sockets after waiting.")
+ return trace.NewAggregate(firstErr, secondErr)
+ }
+
+ r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets from old Teleport instances.")
+
+ errs := []error{firstErr, secondErr}
+ for _, path := range paths {
+ if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
+ if len(errs) < 10 {
+ errs = append(errs, trace.ConvertSystemError(err))
+ }
+ }
+ }
+
+ return trace.NewAggregate(errs...)
+}
+
+// retainNonConnectableSockets attempts to connect to the given UNIX domain
+// sockets, returning all and only the ones that exist and that refuse the
+// connection.
+func retainNonConnectableSockets(ctx context.Context, paths []string) ([]string, error) {
+ var d net.Dialer
+ var errs []error
+ filtered := paths[:0]
+ for _, path := range paths {
+ c, err := d.DialContext(ctx, "unix", path)
+ if err == nil {
+ _ = c.Close()
+ continue
+ }
+
+ if errors.Is(err, os.ErrNotExist) {
+ continue
+ }
+
+ if errors.Is(err, syscall.ECONNREFUSED) {
+ filtered = append(filtered, path)
+ continue
+ }
+
+ if len(errs) < 10 {
+ errs = append(errs, trace.ConvertSystemError(err))
+ }
+ }
+
+ return filtered, trace.NewAggregate(errs...)
+}
diff --git a/lib/resumption/handover_test.go b/lib/resumption/handover_test.go
new file mode 100644
index 0000000000000..6b504866209ee
--- /dev/null
+++ b/lib/resumption/handover_test.go
@@ -0,0 +1,198 @@
+// Teleport
+// Copyright (C) 2024 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package resumption
+
+import (
+ "context"
+ "encoding/binary"
+ "math/rand"
+ "net"
+ "net/netip"
+ "os"
+ "path/filepath"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/teleport/lib/utils/uds"
+)
+
+func TestHandover(t *testing.T) {
+ t.Parallel()
+ require := require.New(t)
+
+ sshServer := discardingSSHServer(t)
+ hostID := uuid.NewString()
+ // unix domain socket names have a very tight length limit
+ dataDir := shortTempDir(t)
+
+ s1 := NewSSHServerWrapper(SSHServerWrapperConfig{
+ SSHServer: sshServer,
+ HostID: hostID,
+ DataDir: dataDir,
+ })
+ s2 := NewSSHServerWrapper(SSHServerWrapperConfig{
+ SSHServer: sshServer,
+ HostID: hostID,
+ DataDir: dataDir,
+ })
+
+ dial := func(handleConnection func(net.Conn), clientAddr netip.Addr) net.Conn {
+ c1, c2, err := uds.NewSocketpair(uds.SocketTypeStream)
+ require.NoError(err)
+
+ srv := &net.TCPAddr{
+ IP: net.IPv4(127, 0, 0, 1),
+ Port: 1 + rand.Intn(65535),
+ }
+ clt := &net.TCPAddr{
+ IP: clientAddr.AsSlice(),
+ Zone: clientAddr.Zone(),
+ Port: 1 + rand.Intn(65535),
+ }
+
+ go handleConnection(utils.NewConnWithAddr(c2, srv, clt))
+ conn := utils.NewConnWithAddr(c1, clt, srv)
+ t.Cleanup(func() { _ = conn.Close() })
+ return conn
+ }
+
+ originalNC := dial(s1.HandleConnection, netip.MustParseAddr("127.0.0.1"))
+
+ redialConns := make(chan net.Conn)
+ defer close(redialConns)
+
+ wrappedNC, err := WrapSSHClientConn(context.Background(), originalNC, func(ctx context.Context, receivedHostID string) (net.Conn, error) {
+ if receivedHostID != hostID {
+ return nil, trace.BadParameter("expected hostID %q, got %q", hostID, receivedHostID)
+ }
+ conn := <-redialConns
+ if conn == nil {
+ return nil, trace.ConnectionProblem(nil, "no redial connection received")
+ }
+ return conn, nil
+ })
+ require.NoError(err)
+ defer wrappedNC.Close()
+
+ require.IsType((*Conn)(nil), wrappedNC)
+
+ clt, err := sshClient(wrappedNC)
+ require.NoError(err)
+ defer clt.Close()
+
+ const wantReplyTrue = true
+ _, _, err = clt.SendRequest("foo", wantReplyTrue, nil)
+ require.NoError(err)
+
+ _ = originalNC.Close()
+ nextNC := dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.1"))
+ redialConns <- nextNC
+
+ _, _, err = clt.SendRequest("foo", wantReplyTrue, nil)
+ require.NoError(err)
+
+ _ = nextNC.Close()
+ // this will result in a closed connection, because changing network address
+ // stops further reconnection attempts
+ redialConns <- dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.2"))
+
+ require.ErrorIs(clt.Wait(), net.ErrClosed)
+}
+
+func shortTempDir(t *testing.T) string {
+ t.Helper()
+ base := ""
+ if runtime.GOOS == "darwin" {
+ base = "/tmp"
+ }
+ d, err := os.MkdirTemp(base, "")
+ require.NoError(t, err)
+ t.Cleanup(func() { require.NoError(t, os.RemoveAll(d)) })
+ return d
+}
+
+func TestHandoverCleanup(t *testing.T) {
+ t.Parallel()
+ require := require.New(t)
+
+ hostID := uuid.NewString()
+ // unix domain socket names have a very tight length limit
+ dataDir := shortTempDir(t)
+
+ var tok resumptionToken
+ binary.NativeEndian.PutUint64(tok[:8], rand.Uint64())
+ binary.NativeEndian.PutUint64(tok[8:], rand.Uint64())
+
+ s := NewSSHServerWrapper(SSHServerWrapperConfig{
+ SSHServer: func(c net.Conn) {
+ defer c.Close()
+ assert.Fail(t, "unexpected connection")
+ },
+ HostID: hostID,
+ DataDir: dataDir,
+ })
+
+ handoverDir := filepath.Join(dataDir, "handover")
+ require.NoError(os.MkdirAll(handoverDir, teleport.PrivateDirMode))
+
+ d, err := os.ReadDir(handoverDir)
+ require.NoError(err)
+ require.Empty(d)
+
+ l, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockPath(dataDir, tok)})
+ require.NoError(err)
+ l.SetUnlinkOnClose(false)
+ defer l.Close()
+ go func() {
+ defer l.Close()
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ break
+ }
+ _ = c.Close()
+ }
+ }()
+
+ d, err = os.ReadDir(handoverDir)
+ require.NoError(err)
+ require.NotEmpty(d)
+
+ ctx := context.Background()
+
+ const cleanupDelayZero time.Duration = 0
+ require.NoError(s.handoverCleanup(ctx, cleanupDelayZero))
+
+ d, err = os.ReadDir(handoverDir)
+ require.NoError(err)
+ require.NotEmpty(d)
+
+ l.Close()
+ require.NoError(s.handoverCleanup(ctx, cleanupDelayZero))
+
+ d, err = os.ReadDir(handoverDir)
+ require.NoError(err)
+ require.Empty(d)
+}
diff --git a/lib/resumption/resumption_test.go b/lib/resumption/resumption_test.go
index 13fa191b835cf..de6f74a89f58c 100644
--- a/lib/resumption/resumption_test.go
+++ b/lib/resumption/resumption_test.go
@@ -42,7 +42,10 @@ func TestResumption(t *testing.T) {
hostID := uuid.NewString()
sshServer := discardingSSHServer(t)
- resumableServer := NewSSHServerWrapper(nil, sshServer, hostID)
+ resumableServer := NewSSHServerWrapper(SSHServerWrapperConfig{
+ SSHServer: sshServer,
+ HostID: hostID,
+ })
directListener, err := net.Listen("tcp", "localhost:")
require.NoError(t, err)
diff --git a/lib/resumption/server_detect.go b/lib/resumption/server_detect.go
index 7805bc19ddc09..9b247271b2c02 100644
--- a/lib/resumption/server_detect.go
+++ b/lib/resumption/server_detect.go
@@ -58,18 +58,36 @@ func serverVersionCRLFV1(pubKey *ecdh.PublicKey, hostID string) string {
)
}
-// NewSSHServerWrapper wraps a given SSH server as to support connection
-// resumption.
-func NewSSHServerWrapper(log logrus.FieldLogger, sshServer func(net.Conn), hostID string) *SSHServerWrapper {
- if log == nil {
- log = logrus.WithField(trace.Component, Component)
+type SSHServerWrapperConfig struct {
+ Log logrus.FieldLogger
+
+ // SSHServer is a function that takes ownership of a [net.Conn] and uses it
+ // as a SSH server. If the Conn is a [sshutils.SSHServerVersionOverrider],
+ // the server should use the overridden server version.
+ SSHServer func(net.Conn)
+
+ // HostID is the host ID of the Teleport instance running the server;
+ // compliant connection resumption clients will reconnect to the host ID
+ // expecting to reach the instance.
+ HostID string
+
+ // DataDir is the path to the Teleport data directory, used to store
+ // temporary handover sockets.
+ DataDir string
+}
+
+// NewSSHServerWrapper creates a [SSHServerWrapper].
+func NewSSHServerWrapper(cfg SSHServerWrapperConfig) *SSHServerWrapper {
+ if cfg.Log == nil {
+ cfg.Log = logrus.WithField(trace.Component, Component)
}
return &SSHServerWrapper{
- sshServer: sshServer,
- log: log,
+ sshServer: cfg.SSHServer,
+ log: cfg.Log,
- hostID: hostID,
+ hostID: cfg.HostID,
+ dataDir: cfg.DataDir,
conns: make(map[resumptionToken]*connEntry),
}
@@ -85,7 +103,8 @@ type SSHServerWrapper struct {
sshServer func(net.Conn)
log logrus.FieldLogger
- hostID string
+ hostID string
+ dataDir string
mu sync.Mutex
conns map[resumptionToken]*connEntry
diff --git a/lib/resumption/server_exchange.go b/lib/resumption/server_exchange.go
index 96e2e01f139ab..292c8495cd58d 100644
--- a/lib/resumption/server_exchange.go
+++ b/lib/resumption/server_exchange.go
@@ -17,6 +17,7 @@
package resumption
import (
+ "context"
"crypto/ecdh"
"crypto/sha256"
"io"
@@ -99,24 +100,32 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh
remoteIP, _ = netip.AddrFromSlice(t.IP)
remoteIP = remoteIP.Unmap()
}
- if !remoteIP.IsValid() {
- r.log.Warn("Tracking resumable connection with an invalid remote address (this is a bug).")
- }
token := resumptionToken(otp32[:16])
-
- r.mu.Lock()
entry := &connEntry{
conn: resumableConn,
remoteIP: remoteIP,
timeout: time.AfterFunc(detachedTimeout, func() { resumableConn.Close() }),
}
- r.conns[token] = entry
- r.mu.Unlock()
+
+ // this context is only used for the convenience of [context.AfterFunc]
+ handoverContext, handoverCancel := context.WithCancel(context.Background())
+ if remoteIP.IsValid() {
+ r.mu.Lock()
+ r.conns[token] = entry
+ r.mu.Unlock()
+
+ if err := r.startHandoverListener(handoverContext, token, entry); err != nil {
+ r.log.WithError(err).Warn("Unable to create handover listener for resumable connection, connection resumption will not work across graceful restarts.")
+ }
+ } else {
+ r.log.Warn("Refusing to track resumable connection with an invalid remote IP address, connection resumption will not work (this is a bug).")
+ }
go func() {
defer r.log.Info("Resumable connection completed.")
defer resumableConn.Close()
+ defer handoverCancel()
defer func() {
r.mu.Lock()
defer r.mu.Unlock()
@@ -156,25 +165,23 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh
r.mu.Unlock()
if entry == nil {
- r.log.Debug("Resumable connection not found or already deleted.")
- _, _ = conn.Write([]byte{notFoundServerExchangeTag})
+ r.attemptHandover(conn, token)
return
}
- if entry.remoteIP.IsValid() {
- var remoteIP netip.Addr
- if t, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
- remoteIP, _ = netip.AddrFromSlice(t.IP)
- }
- if entry.remoteIP != remoteIP.Unmap() {
- r.log.Warn("Resumable connection attempted resumption from a different remote address.")
- _, _ = conn.Write([]byte{badAddressServerExchangeTag})
- return
- }
- } else {
- r.log.Warn("Resumable connection attempted resumption on a connection with an invalid remote address.")
- // the client can't recover from this, just send a "not found"
- _, _ = conn.Write([]byte{notFoundServerExchangeTag})
+ var remoteIP netip.Addr
+ if t, _ := conn.RemoteAddr().(*net.TCPAddr); t != nil {
+ remoteIP, _ = netip.AddrFromSlice(t.IP)
+ remoteIP = remoteIP.Unmap()
+ }
+
+ r.resumeConnection(entry, conn, remoteIP)
+}
+
+func (r *SSHServerWrapper) resumeConnection(entry *connEntry, conn net.Conn, remoteIP netip.Addr) {
+ if entry.remoteIP != remoteIP {
+ r.log.Warn("Resumable connection attempted resumption from a different remote address.")
+ _, _ = conn.Write([]byte{badAddressServerExchangeTag})
return
}
diff --git a/lib/service/service.go b/lib/service/service.go
index 4700517f8cd0e..84c1455b0d557 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -2681,11 +2681,19 @@ func (process *TeleportProcess) initSSH() error {
var resumableServer *resumption.SSHServerWrapper
if os.Getenv("TELEPORT_UNSTABLE_DISABLE_SSH_RESUMPTION") == "" {
- resumableServer = resumption.NewSSHServerWrapper(
- log.WithField(trace.Component, teleport.Component(teleport.ComponentNode, resumption.Component)),
- s.HandleConnection,
- serverID,
- )
+ resumableServer = resumption.NewSSHServerWrapper(resumption.SSHServerWrapperConfig{
+ Log: log.WithField(trace.Component, teleport.Component(teleport.ComponentNode, resumption.Component)),
+ SSHServer: s.HandleConnection,
+
+ HostID: serverID,
+ DataDir: cfg.DataDir,
+ })
+
+ go func() {
+ if err := resumableServer.HandoverCleanup(process.GracefulExitContext()); err != nil {
+ log.WithError(err).Warn("Failed to clean up handover sockets.")
+ }
+ }()
}
var agentPool *reversetunnel.AgentPool