From 68b5f58b00532ebd96efebfd19fe5d00262fa1b6 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 13 Feb 2024 16:58:47 +0100 Subject: [PATCH 01/18] WIP: Connection handover --- lib/resumption/handover.go | 112 ++++++++++++++++++++++ lib/resumption/handover_unix.go | 152 ++++++++++++++++++++++++++++++ lib/resumption/resumption_test.go | 5 +- lib/resumption/server_detect.go | 29 ++++-- lib/resumption/server_exchange.go | 53 ++++++----- lib/service/service.go | 18 +++- 6 files changed, 331 insertions(+), 38 deletions(-) create mode 100644 lib/resumption/handover.go create mode 100644 lib/resumption/handover_unix.go diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go new file mode 100644 index 0000000000000..e442632c35f76 --- /dev/null +++ b/lib/resumption/handover.go @@ -0,0 +1,112 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package resumption + +import ( + "context" + "io" + "net" + "net/netip" + "time" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/multiplexer" + "github.com/gravitational/teleport/lib/utils" +) + +func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumptionToken) { + handoverConn, err := r.dialHandover(token) + if err != nil { + if trace.IsNotFound(err) { + r.log.Debug("Resumable connection not found or already deleted.") + _, _ = conn.Write([]byte{notFoundServerExchangeTag}) + return + } + r.log.WithError(err).Error("Error while connecting to handover socket.") + return + } + defer handoverConn.Close() + + var remoteIP netip.Addr + if t, _ := conn.RemoteAddr().(*net.TCPAddr); t != nil { + remoteIP, _ = netip.AddrFromSlice(t.IP) + } + remoteIP16 := remoteIP.As16() + + if _, err := handoverConn.Write(remoteIP16[:]); err != nil { + if !utils.IsOKNetworkError(err) { + r.log.WithError(err).Error("Error while forwarding remote address to handover socket.") + } + return + } + + r.log.Debug("Forwarding resuming connection to handover socket.") + _ = utils.ProxyConn(context.Background(), conn, handoverConn) +} + +func (r *SSHServerWrapper) setupHandoverListener(ctx context.Context, token resumptionToken, entry *connEntry) error { + l, err := r.listenHandover(token) + if err != nil { + return trace.Wrap(err) + } + + go r.runHandoverListener(l, entry) + context.AfterFunc(ctx, func() { _ = l.Close() }) + + return nil +} + +func (r *SSHServerWrapper) runHandoverListener(l net.Listener, entry *connEntry) { + defer l.Close() + + var tempDelay time.Duration + for { + c, err := l.Accept() + if err == nil { + tempDelay = 0 + go r.handleHandoverConnection(c, entry) + continue + } + + if tempErr, ok := err.(interface{ Temporary() bool }); !ok || !tempErr.Temporary() { + if !utils.IsOKNetworkError(err) { + r.log.WithError(err).Warn("Accept error in handover listener.") + } + return + } + + tempDelay = max(5*time.Millisecond, min(2*tempDelay, time.Second)) + r.log.WithError(err).WithField("delay", tempDelay).Warn("Temporary accept error in handover listener, continuing after delay.") + time.Sleep(tempDelay) + } +} + +func (r *SSHServerWrapper) handleHandoverConnection(conn net.Conn, entry *connEntry) { + defer conn.Close() + + var remoteIP16 [16]byte + if _, err := io.ReadFull(conn, remoteIP16[:]); err != nil { + if !utils.IsOKNetworkError(err) { + r.log.WithError(err).Error("Error while reading remote address from handover socket.") + } + return + } + remoteIP := netip.AddrFrom16(remoteIP16).Unmap() + + r.resumeConnection(entry, conn, remoteIP) +} diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go new file mode 100644 index 0000000000000..37a35e63c3a83 --- /dev/null +++ b/lib/resumption/handover_unix.go @@ -0,0 +1,152 @@ +//go:build unix + +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package resumption + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/gravitational/trace" +) + +func sockPath(dataDir string, token resumptionToken) string { + return filepath.Join(dataDir, "handover", fmt.Sprintf("%x.sock", sha256.Sum256(token[:]))) +} + +func sockDir(dataDir string) string { + return filepath.Join(dataDir, "handover") +} + +var errNoDataDir error = &trace.NotFoundError{Message: "data dir not configured"} + +func (r *SSHServerWrapper) listenHandover(token resumptionToken) (net.Listener, error) { + if r.dataDir == "" { + return nil, trace.Wrap(errNoDataDir) + } + + _ = os.MkdirAll(sockDir(r.dataDir), 0o700) + l, err := net.Listen("unix", sockPath(r.dataDir, token)) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + return l, nil +} + +func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) { + if r.dataDir == "" { + return nil, trace.Wrap(errNoDataDir) + } + + c, err := net.DialTimeout("unix", sockPath(r.dataDir, token), time.Second) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + return c, nil +} + +func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered []string, lastErr error) { + filtered = paths[:0] + + var d net.Dialer + for _, path := range paths { + c, err := d.DialContext(ctx, "unix", path) + if err == nil { + _ = c.Close() + continue + } + + if errors.Is(err, os.ErrNotExist) { + continue + } + + if errors.Is(err, syscall.ECONNREFUSED) { + filtered = append(filtered, path) + continue + } + + lastErr = trace.ConvertSystemError(err) + } + + return filtered, lastErr +} + +func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { + if r.dataDir == "" { + return nil + } + + dir := sockDir(r.dataDir) + entries, err := os.ReadDir(dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return trace.ConvertSystemError(err) + } + + paths := make([]string, 0, len(entries)) + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".sock") { + paths = append(paths, filepath.Join(dir, entry.Name())) + } + } + + paths, firstErr := filterNonConnectableSockets(ctx, paths) + + if len(paths) < 1 { + return trace.Wrap(firstErr) + } + + // unix domain sockets exist on disk between bind() and listen() but + // connecting before listen() results in ECONNREFUSED, so we just wait a + // little bit before testing them again; the first check lets us be done + // with the check immediately in the happy case where there's no + // unconnectable sockets + r.log.WithField("sockets", len(paths)).Debug("Found some non-connectable handover sockets, waiting before checking them again.") + + select { + case <-ctx.Done(): + return trace.NewAggregate(firstErr, ctx.Err()) + case <-time.After(3 * time.Second): + } + + paths, secondErr := filterNonConnectableSockets(ctx, paths) + + if len(paths) < 1 { + r.log.Debug("Found no non-connectable handover socket after waiting.") + return trace.NewAggregate(firstErr, secondErr) + } + + r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets, left over from previous Teleport instances.") + + errs := []error{firstErr, secondErr} + for _, path := range paths { + errs = append(errs, trace.ConvertSystemError(os.Remove(path))) + } + + return trace.NewAggregate(errs...) +} diff --git a/lib/resumption/resumption_test.go b/lib/resumption/resumption_test.go index 13fa191b835cf..de6f74a89f58c 100644 --- a/lib/resumption/resumption_test.go +++ b/lib/resumption/resumption_test.go @@ -42,7 +42,10 @@ func TestResumption(t *testing.T) { hostID := uuid.NewString() sshServer := discardingSSHServer(t) - resumableServer := NewSSHServerWrapper(nil, sshServer, hostID) + resumableServer := NewSSHServerWrapper(SSHServerWrapperConfig{ + SSHServer: sshServer, + HostID: hostID, + }) directListener, err := net.Listen("tcp", "localhost:") require.NoError(t, err) diff --git a/lib/resumption/server_detect.go b/lib/resumption/server_detect.go index 7805bc19ddc09..e9fdd41ab23c8 100644 --- a/lib/resumption/server_detect.go +++ b/lib/resumption/server_detect.go @@ -58,18 +58,28 @@ func serverVersionCRLFV1(pubKey *ecdh.PublicKey, hostID string) string { ) } -// NewSSHServerWrapper wraps a given SSH server as to support connection -// resumption. -func NewSSHServerWrapper(log logrus.FieldLogger, sshServer func(net.Conn), hostID string) *SSHServerWrapper { - if log == nil { - log = logrus.WithField(trace.Component, Component) +type SSHServerWrapperConfig struct { + Log logrus.FieldLogger + SSHServer func(net.Conn) + HostID string + + DataDir string +} + +// NewSSHServerWrapper wraps a given SSH server to support connection +// resumption, providing a connection handler method and a +// [multiplexer.PreDetectFunc] method. +func NewSSHServerWrapper(cfg SSHServerWrapperConfig) *SSHServerWrapper { + if cfg.Log == nil { + cfg.Log = logrus.WithField(trace.Component, Component) } return &SSHServerWrapper{ - sshServer: sshServer, - log: log, + sshServer: cfg.SSHServer, + log: cfg.Log, - hostID: hostID, + hostID: cfg.HostID, + dataDir: cfg.DataDir, conns: make(map[resumptionToken]*connEntry), } @@ -85,7 +95,8 @@ type SSHServerWrapper struct { sshServer func(net.Conn) log logrus.FieldLogger - hostID string + hostID string + dataDir string mu sync.Mutex conns map[resumptionToken]*connEntry diff --git a/lib/resumption/server_exchange.go b/lib/resumption/server_exchange.go index 96e2e01f139ab..37253ada451c0 100644 --- a/lib/resumption/server_exchange.go +++ b/lib/resumption/server_exchange.go @@ -17,6 +17,7 @@ package resumption import ( + "context" "crypto/ecdh" "crypto/sha256" "io" @@ -99,24 +100,32 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh remoteIP, _ = netip.AddrFromSlice(t.IP) remoteIP = remoteIP.Unmap() } - if !remoteIP.IsValid() { - r.log.Warn("Tracking resumable connection with an invalid remote address (this is a bug).") - } token := resumptionToken(otp32[:16]) - - r.mu.Lock() entry := &connEntry{ conn: resumableConn, remoteIP: remoteIP, timeout: time.AfterFunc(detachedTimeout, func() { resumableConn.Close() }), } - r.conns[token] = entry - r.mu.Unlock() + + // this context is only used for the convenience of [context.AfterFunc] + handoverContext, handoverCancel := context.WithCancel(context.Background()) + if remoteIP.IsValid() { + r.mu.Lock() + r.conns[token] = entry + r.mu.Unlock() + + if err := r.setupHandoverListener(handoverContext, token, entry); err != nil { + r.log.WithError(err).Warn("Unable to create handover listener for resumable connection, connection resumption will not work across graceful restarts.") + } + } else { + r.log.Warn("Refusing to track resumable connection with an invalid remote IP address, connection resumption will not work (this is a bug).") + } go func() { defer r.log.Info("Resumable connection completed.") defer resumableConn.Close() + defer handoverCancel() defer func() { r.mu.Lock() defer r.mu.Unlock() @@ -156,25 +165,23 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh r.mu.Unlock() if entry == nil { - r.log.Debug("Resumable connection not found or already deleted.") - _, _ = conn.Write([]byte{notFoundServerExchangeTag}) + r.attemptHandover(conn, token) return } - if entry.remoteIP.IsValid() { - var remoteIP netip.Addr - if t, ok := conn.RemoteAddr().(*net.TCPAddr); ok { - remoteIP, _ = netip.AddrFromSlice(t.IP) - } - if entry.remoteIP != remoteIP.Unmap() { - r.log.Warn("Resumable connection attempted resumption from a different remote address.") - _, _ = conn.Write([]byte{badAddressServerExchangeTag}) - return - } - } else { - r.log.Warn("Resumable connection attempted resumption on a connection with an invalid remote address.") - // the client can't recover from this, just send a "not found" - _, _ = conn.Write([]byte{notFoundServerExchangeTag}) + var remoteIP netip.Addr + if t, _ := conn.RemoteAddr().(*net.TCPAddr); t != nil { + remoteIP, _ = netip.AddrFromSlice(t.IP) + remoteIP = remoteIP.Unmap() + } + + r.resumeConnection(entry, conn, remoteIP) +} + +func (r *SSHServerWrapper) resumeConnection(entry *connEntry, conn net.Conn, remoteIP netip.Addr) { + if entry.remoteIP != remoteIP { + r.log.Warn("Resumable connection attempted resumption from a different remote address.") + _, _ = conn.Write([]byte{badAddressServerExchangeTag}) return } diff --git a/lib/service/service.go b/lib/service/service.go index 4700517f8cd0e..84c1455b0d557 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -2681,11 +2681,19 @@ func (process *TeleportProcess) initSSH() error { var resumableServer *resumption.SSHServerWrapper if os.Getenv("TELEPORT_UNSTABLE_DISABLE_SSH_RESUMPTION") == "" { - resumableServer = resumption.NewSSHServerWrapper( - log.WithField(trace.Component, teleport.Component(teleport.ComponentNode, resumption.Component)), - s.HandleConnection, - serverID, - ) + resumableServer = resumption.NewSSHServerWrapper(resumption.SSHServerWrapperConfig{ + Log: log.WithField(trace.Component, teleport.Component(teleport.ComponentNode, resumption.Component)), + SSHServer: s.HandleConnection, + + HostID: serverID, + DataDir: cfg.DataDir, + }) + + go func() { + if err := resumableServer.HandoverCleanup(process.GracefulExitContext()); err != nil { + log.WithError(err).Warn("Failed to clean up handover sockets.") + } + }() } var agentPool *reversetunnel.AgentPool From 226decc31069be6e7b8bd7790b7d0fee0aa7933a Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 15 Feb 2024 19:26:42 +0100 Subject: [PATCH 02/18] Add a test, make socket names shorter --- lib/resumption/handover_test.go | 114 ++++++++++++++++++++++++++++++++ lib/resumption/handover_unix.go | 8 ++- 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 lib/resumption/handover_test.go diff --git a/lib/resumption/handover_test.go b/lib/resumption/handover_test.go new file mode 100644 index 0000000000000..7d21b3b61a039 --- /dev/null +++ b/lib/resumption/handover_test.go @@ -0,0 +1,114 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package resumption + +import ( + "context" + "math/rand" + "net" + "os" + "runtime" + "testing" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/uds" +) + +func TestHandover(t *testing.T) { + t.Parallel() + require := require.New(t) + + sshServer := discardingSSHServer(t) + hostID := uuid.NewString() + dataDir := shortTempDir(t) + t.Logf("using temporary data dir %q", dataDir) + + s1 := NewSSHServerWrapper(SSHServerWrapperConfig{ + SSHServer: sshServer, + HostID: hostID, + DataDir: dataDir, + }) + s2 := NewSSHServerWrapper(SSHServerWrapperConfig{ + SSHServer: sshServer, + HostID: hostID, + DataDir: dataDir, + }) + + dial := func(handleConnection func(net.Conn)) (net.Conn, error) { + c1, c2, err := uds.NewSocketpair(uds.SocketTypeStream) + if err != nil { + return nil, err + } + + a1 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)} + a2 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)} + + go handleConnection(utils.NewConnWithAddr(c2, a2, a1)) + return utils.NewConnWithAddr(c1, a1, a2), nil + } + + originalNC, err := dial(s1.HandleConnection) + require.NoError(err) + defer originalNC.Close() + + redialDestination := make(chan func(net.Conn)) + defer close(redialDestination) + + wrappedNC, err := WrapSSHClientConn(context.Background(), originalNC, func(ctx context.Context, receivedHostID string) (net.Conn, error) { + if receivedHostID != hostID { + return nil, trace.BadParameter("expected hostID %q, got %q", hostID, receivedHostID) + } + handleConnection := <-redialDestination + if handleConnection == nil { + return nil, trace.ConnectionProblem(nil, "no redial destination received") + } + return dial(handleConnection) + }) + require.NoError(err) + defer wrappedNC.Close() + require.IsType((*Conn)(nil), wrappedNC) + + clt, err := sshClient(wrappedNC) + require.NoError(err) + defer clt.Close() + + const wantReplyTrue = true + _, _, err = clt.SendRequest("foo", wantReplyTrue, nil) + require.NoError(err) + + _ = originalNC.Close() + redialDestination <- s2.HandleConnection + + _, _, err = clt.SendRequest("foo", wantReplyTrue, nil) + require.NoError(err) +} + +func shortTempDir(t *testing.T) string { + t.Helper() + base := "" + if runtime.GOOS == "darwin" { + base = "/tmp" + } + d, err := os.MkdirTemp(base, "") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, os.RemoveAll(d)) }) + return d +} diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 37a35e63c3a83..221ef08d8bc8e 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -34,7 +34,13 @@ import ( ) func sockPath(dataDir string, token resumptionToken) string { - return filepath.Join(dataDir, "handover", fmt.Sprintf("%x.sock", sha256.Sum256(token[:]))) + hash := sha256.Sum256(token[:]) + // unix domain sockets are limited to 108 or 104 characters, so spending 64 + // for the full sha256 hash is a bit too much; truncating the hash to 128 + // bits still gives us more than enough headroom to just assume that we'll + // have no collisions (a probability of one in a quintillion with 26 billion + // concurrent connections) + return filepath.Join(dataDir, "handover", fmt.Sprintf("%x.sock", hash[:16])) } func sockDir(dataDir string) string { From 4b4eb5a64c7cf8c68ad81aabd2d2a9b9ab9b9450 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 16 Feb 2024 16:48:24 +0100 Subject: [PATCH 03/18] Encode token hashes with base 64 --- lib/resumption/handover_unix.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 221ef08d8bc8e..c95b342c2a24a 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -21,8 +21,8 @@ package resumption import ( "context" "crypto/sha256" + "encoding/base64" "errors" - "fmt" "net" "os" "path/filepath" @@ -35,12 +35,12 @@ import ( func sockPath(dataDir string, token resumptionToken) string { hash := sha256.Sum256(token[:]) - // unix domain sockets are limited to 108 or 104 characters, so spending 64 - // for the full sha256 hash is a bit too much; truncating the hash to 128 - // bits still gives us more than enough headroom to just assume that we'll - // have no collisions (a probability of one in a quintillion with 26 billion - // concurrent connections) - return filepath.Join(dataDir, "handover", fmt.Sprintf("%x.sock", hash[:16])) + // unix domain sockets are limited to 108 or 104 characters, so the full + // sha256 hash is a bit too much (64 bytes in hex or 44 in b64); truncating + // the hash to 128 bits still gives us more than enough headroom to just + // assume that we'll have no collisions (a probability of one in a + // quintillion with 26 billion concurrent connections) + return filepath.Join(dataDir, "handover", base64.RawURLEncoding.EncodeToString(hash[:16])) } func sockDir(dataDir string) string { From 71cabe8dcd35f72374acb3d54a3069c8822ceeb3 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 16 Feb 2024 16:54:54 +0100 Subject: [PATCH 04/18] Document temporary directory creation more --- lib/resumption/handover_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/resumption/handover_test.go b/lib/resumption/handover_test.go index 7d21b3b61a039..8e0554a4e801f 100644 --- a/lib/resumption/handover_test.go +++ b/lib/resumption/handover_test.go @@ -38,8 +38,8 @@ func TestHandover(t *testing.T) { sshServer := discardingSSHServer(t) hostID := uuid.NewString() + // unix domain socket names have a very tight length limit dataDir := shortTempDir(t) - t.Logf("using temporary data dir %q", dataDir) s1 := NewSSHServerWrapper(SSHServerWrapperConfig{ SSHServer: sshServer, From 56c74eceb5462ee51e38da3af4629a1ce0c64b8f Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Wed, 21 Feb 2024 17:28:34 +0100 Subject: [PATCH 05/18] Minor cleanups --- lib/resumption/handover_unix.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index c95b342c2a24a..78f361dcc1035 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -33,6 +33,8 @@ import ( "github.com/gravitational/trace" ) +const sockSuffix = ".sock" + func sockPath(dataDir string, token resumptionToken) string { hash := sha256.Sum256(token[:]) // unix domain sockets are limited to 108 or 104 characters, so the full @@ -40,7 +42,7 @@ func sockPath(dataDir string, token resumptionToken) string { // the hash to 128 bits still gives us more than enough headroom to just // assume that we'll have no collisions (a probability of one in a // quintillion with 26 billion concurrent connections) - return filepath.Join(dataDir, "handover", base64.RawURLEncoding.EncodeToString(hash[:16])) + return filepath.Join(dataDir, "handover", base64.RawURLEncoding.EncodeToString(hash[:16])+sockSuffix) } func sockDir(dataDir string) string { @@ -100,6 +102,10 @@ func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered return filtered, lastErr } +// HandoverCleanup deletes hand-over sockets that were left over from previous +// runs of Teleport that failed to clean up after themselves (because of an +// uncatchable signal or a system crash). It will exhaustively clean up the +// current left over sockets, so it's sufficient to call it once per process. func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { if r.dataDir == "" { return nil @@ -116,7 +122,7 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { paths := make([]string, 0, len(entries)) for _, entry := range entries { - if strings.HasSuffix(entry.Name(), ".sock") { + if strings.HasSuffix(entry.Name(), sockSuffix) { paths = append(paths, filepath.Join(dir, entry.Name())) } } @@ -132,7 +138,7 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { // little bit before testing them again; the first check lets us be done // with the check immediately in the happy case where there's no // unconnectable sockets - r.log.WithField("sockets", len(paths)).Debug("Found some non-connectable handover sockets, waiting before checking them again.") + r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.") select { case <-ctx.Done(): @@ -143,7 +149,7 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { paths, secondErr := filterNonConnectableSockets(ctx, paths) if len(paths) < 1 { - r.log.Debug("Found no non-connectable handover socket after waiting.") + r.log.Debug("Found no unconnectable handover socket after waiting.") return trace.NewAggregate(firstErr, secondErr) } @@ -151,7 +157,9 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { errs := []error{firstErr, secondErr} for _, path := range paths { - errs = append(errs, trace.ConvertSystemError(os.Remove(path))) + if err := trace.ConvertSystemError(os.Remove(path)); err != nil { + errs = append(errs, err) + } } return trace.NewAggregate(errs...) From cfbec5f65dfb60cb4abea8644c0a7ead3b053652 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Wed, 21 Feb 2024 17:29:27 +0100 Subject: [PATCH 06/18] fix non-unix builds --- lib/resumption/handover_other.go | 40 +++++++++++++++++++ ...handover_test.go => handover_unix_test.go} | 2 + lib/resumption/server_exchange.go | 4 +- 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 lib/resumption/handover_other.go rename lib/resumption/{handover_test.go => handover_unix_test.go} (99%) diff --git a/lib/resumption/handover_other.go b/lib/resumption/handover_other.go new file mode 100644 index 0000000000000..9b5e7720b0584 --- /dev/null +++ b/lib/resumption/handover_other.go @@ -0,0 +1,40 @@ +//go:build !unix + +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package resumption + +import ( + "context" + "net" + + "github.com/gravitational/trace" +) + +func (r *SSHServerWrapper) listenHandover(token resumptionToken) (net.Listener, error) { + return nil, trace.NotImplemented("handover is not implemented for the current platform") +} + +func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) { + return nil, trace.NotFound("handover is not implemented for the current platform") +} + +// HandoverCleanup does nothing, because on this platform we don't support +// hand-over sockets, so there can't be anything to clean up. +func (r *SSHServerWrapper) HandoverCleanup(context.Context) error { + return nil +} diff --git a/lib/resumption/handover_test.go b/lib/resumption/handover_unix_test.go similarity index 99% rename from lib/resumption/handover_test.go rename to lib/resumption/handover_unix_test.go index 8e0554a4e801f..38a308dd6bef0 100644 --- a/lib/resumption/handover_test.go +++ b/lib/resumption/handover_unix_test.go @@ -1,3 +1,5 @@ +//go:build unix + // Teleport // Copyright (C) 2024 Gravitational, Inc. // diff --git a/lib/resumption/server_exchange.go b/lib/resumption/server_exchange.go index 37253ada451c0..d5bd08337d9cb 100644 --- a/lib/resumption/server_exchange.go +++ b/lib/resumption/server_exchange.go @@ -25,6 +25,8 @@ import ( "net/netip" "time" + "github.com/gravitational/trace" + "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/utils" ) @@ -115,7 +117,7 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh r.conns[token] = entry r.mu.Unlock() - if err := r.setupHandoverListener(handoverContext, token, entry); err != nil { + if err := r.setupHandoverListener(handoverContext, token, entry); err != nil && !trace.IsNotImplemented(err) { r.log.WithError(err).Warn("Unable to create handover listener for resumable connection, connection resumption will not work across graceful restarts.") } } else { From 63794f54a75ade0ae109b1ebb6c9afd1bf2c8169 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Wed, 21 Feb 2024 18:38:54 +0100 Subject: [PATCH 07/18] Expand handover tests --- lib/resumption/handover_unix.go | 9 +- lib/resumption/handover_unix_test.go | 118 +++++++++++++++++++++++---- 2 files changed, 108 insertions(+), 19 deletions(-) diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 78f361dcc1035..77c95bb8e9320 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -102,6 +102,8 @@ func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered return filtered, lastErr } +type cleanupDelayContextKey struct{} + // HandoverCleanup deletes hand-over sockets that were left over from previous // runs of Teleport that failed to clean up after themselves (because of an // uncatchable signal or a system crash). It will exhaustively clean up the @@ -140,10 +142,15 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { // unconnectable sockets r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.") + cleanupDelay := time.Second + if d, ok := ctx.Value((*cleanupDelayContextKey)(nil)).(time.Duration); ok { + cleanupDelay = d + } + select { case <-ctx.Done(): return trace.NewAggregate(firstErr, ctx.Err()) - case <-time.After(3 * time.Second): + case <-time.After(cleanupDelay): } paths, secondErr := filterNonConnectableSockets(ctx, paths) diff --git a/lib/resumption/handover_unix_test.go b/lib/resumption/handover_unix_test.go index 38a308dd6bef0..b7baae75990fe 100644 --- a/lib/resumption/handover_unix_test.go +++ b/lib/resumption/handover_unix_test.go @@ -20,14 +20,19 @@ package resumption import ( "context" + "encoding/binary" "math/rand" "net" + "net/netip" "os" + "path/filepath" "runtime" "testing" + "time" "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/utils" @@ -54,38 +59,44 @@ func TestHandover(t *testing.T) { DataDir: dataDir, }) - dial := func(handleConnection func(net.Conn)) (net.Conn, error) { + dial := func(handleConnection func(net.Conn), clientAddr netip.Addr) net.Conn { c1, c2, err := uds.NewSocketpair(uds.SocketTypeStream) - if err != nil { - return nil, err - } + require.NoError(err) - a1 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)} - a2 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)} + srv := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1 + rand.Intn(65535), + } + clt := &net.TCPAddr{ + IP: clientAddr.AsSlice(), + Zone: clientAddr.Zone(), + Port: 1 + rand.Intn(65535), + } - go handleConnection(utils.NewConnWithAddr(c2, a2, a1)) - return utils.NewConnWithAddr(c1, a1, a2), nil + go handleConnection(utils.NewConnWithAddr(c2, srv, clt)) + conn := utils.NewConnWithAddr(c1, clt, srv) + t.Cleanup(func() { _ = conn.Close() }) + return conn } - originalNC, err := dial(s1.HandleConnection) - require.NoError(err) - defer originalNC.Close() + originalNC := dial(s1.HandleConnection, netip.MustParseAddr("127.0.0.1")) - redialDestination := make(chan func(net.Conn)) - defer close(redialDestination) + redialConns := make(chan net.Conn) + defer close(redialConns) wrappedNC, err := WrapSSHClientConn(context.Background(), originalNC, func(ctx context.Context, receivedHostID string) (net.Conn, error) { if receivedHostID != hostID { return nil, trace.BadParameter("expected hostID %q, got %q", hostID, receivedHostID) } - handleConnection := <-redialDestination - if handleConnection == nil { - return nil, trace.ConnectionProblem(nil, "no redial destination received") + conn := <-redialConns + if conn == nil { + return nil, trace.ConnectionProblem(nil, "no redial connection received") } - return dial(handleConnection) + return conn, nil }) require.NoError(err) defer wrappedNC.Close() + require.IsType((*Conn)(nil), wrappedNC) clt, err := sshClient(wrappedNC) @@ -97,10 +108,18 @@ func TestHandover(t *testing.T) { require.NoError(err) _ = originalNC.Close() - redialDestination <- s2.HandleConnection + nextNC := dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.1")) + redialConns <- nextNC _, _, err = clt.SendRequest("foo", wantReplyTrue, nil) require.NoError(err) + + _ = nextNC.Close() + // this will result in a closed connection, because changing network address + // stops further reconnection attempts + redialConns <- dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.2")) + + require.ErrorIs(clt.Wait(), net.ErrClosed) } func shortTempDir(t *testing.T) string { @@ -114,3 +133,66 @@ func shortTempDir(t *testing.T) string { t.Cleanup(func() { require.NoError(t, os.RemoveAll(d)) }) return d } + +func TestHandoverCleanup(t *testing.T) { + t.Parallel() + require := require.New(t) + + hostID := uuid.NewString() + // unix domain socket names have a very tight length limit + dataDir := shortTempDir(t) + + var tok resumptionToken + binary.NativeEndian.PutUint64(tok[:8], rand.Uint64()) + binary.NativeEndian.PutUint64(tok[8:], rand.Uint64()) + + s := NewSSHServerWrapper(SSHServerWrapperConfig{ + SSHServer: func(c net.Conn) { + defer c.Close() + assert.Fail(t, "unexpected connection") + }, + HostID: hostID, + DataDir: dataDir, + }) + + handoverDir := filepath.Join(dataDir, "handover") + require.NoError(os.MkdirAll(handoverDir, 0o700)) + + d, err := os.ReadDir(handoverDir) + require.NoError(err) + require.Empty(d) + + l, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockPath(dataDir, tok)}) + require.NoError(err) + l.SetUnlinkOnClose(false) + defer l.Close() + go func() { + defer l.Close() + for { + c, err := l.Accept() + if err != nil { + break + } + _ = c.Close() + } + }() + + d, err = os.ReadDir(handoverDir) + require.NoError(err) + require.NotEmpty(d) + + ctx := context.WithValue(context.Background(), (*cleanupDelayContextKey)(nil), time.Duration(0)) + + require.NoError(s.HandoverCleanup(ctx)) + + d, err = os.ReadDir(handoverDir) + require.NoError(err) + require.NotEmpty(d) + + l.Close() + require.NoError(s.HandoverCleanup(ctx)) + + d, err = os.ReadDir(handoverDir) + require.NoError(err) + require.Empty(d) +} From 50e4b60f8ec620517374791d07adcea29cb97c55 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 16:18:43 +0100 Subject: [PATCH 08/18] use PrivateDirMode constant --- lib/resumption/handover_unix.go | 4 +++- lib/resumption/handover_unix_test.go | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 77c95bb8e9320..47b83ce5f460a 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -31,6 +31,8 @@ import ( "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport" ) const sockSuffix = ".sock" @@ -56,7 +58,7 @@ func (r *SSHServerWrapper) listenHandover(token resumptionToken) (net.Listener, return nil, trace.Wrap(errNoDataDir) } - _ = os.MkdirAll(sockDir(r.dataDir), 0o700) + _ = os.MkdirAll(sockDir(r.dataDir), teleport.PrivateDirMode) l, err := net.Listen("unix", sockPath(r.dataDir, token)) if err != nil { return nil, trace.ConvertSystemError(err) diff --git a/lib/resumption/handover_unix_test.go b/lib/resumption/handover_unix_test.go index b7baae75990fe..1df14f73f70bf 100644 --- a/lib/resumption/handover_unix_test.go +++ b/lib/resumption/handover_unix_test.go @@ -35,6 +35,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/uds" ) @@ -156,7 +157,7 @@ func TestHandoverCleanup(t *testing.T) { }) handoverDir := filepath.Join(dataDir, "handover") - require.NoError(os.MkdirAll(handoverDir, 0o700)) + require.NoError(os.MkdirAll(handoverDir, teleport.PrivateDirMode)) d, err := os.ReadDir(handoverDir) require.NoError(err) From 0608afd663235c80510964da83edd72e48667aef Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 16:22:01 +0100 Subject: [PATCH 09/18] name changes: {start,create}HandoverListener --- lib/resumption/handover.go | 4 ++-- lib/resumption/handover_other.go | 2 +- lib/resumption/handover_unix.go | 2 +- lib/resumption/server_exchange.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index e442632c35f76..46f1ab7c29ff7 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -59,8 +59,8 @@ func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumpt _ = utils.ProxyConn(context.Background(), conn, handoverConn) } -func (r *SSHServerWrapper) setupHandoverListener(ctx context.Context, token resumptionToken, entry *connEntry) error { - l, err := r.listenHandover(token) +func (r *SSHServerWrapper) startHandoverListener(ctx context.Context, token resumptionToken, entry *connEntry) error { + l, err := r.createHandoverListener(token) if err != nil { return trace.Wrap(err) } diff --git a/lib/resumption/handover_other.go b/lib/resumption/handover_other.go index 9b5e7720b0584..ac7be32d21433 100644 --- a/lib/resumption/handover_other.go +++ b/lib/resumption/handover_other.go @@ -25,7 +25,7 @@ import ( "github.com/gravitational/trace" ) -func (r *SSHServerWrapper) listenHandover(token resumptionToken) (net.Listener, error) { +func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) { return nil, trace.NotImplemented("handover is not implemented for the current platform") } diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 47b83ce5f460a..9723d17181d13 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -53,7 +53,7 @@ func sockDir(dataDir string) string { var errNoDataDir error = &trace.NotFoundError{Message: "data dir not configured"} -func (r *SSHServerWrapper) listenHandover(token resumptionToken) (net.Listener, error) { +func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) { if r.dataDir == "" { return nil, trace.Wrap(errNoDataDir) } diff --git a/lib/resumption/server_exchange.go b/lib/resumption/server_exchange.go index d5bd08337d9cb..4ea1f9950581d 100644 --- a/lib/resumption/server_exchange.go +++ b/lib/resumption/server_exchange.go @@ -117,7 +117,7 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh r.conns[token] = entry r.mu.Unlock() - if err := r.setupHandoverListener(handoverContext, token, entry); err != nil && !trace.IsNotImplemented(err) { + if err := r.startHandoverListener(handoverContext, token, entry); err != nil && !trace.IsNotImplemented(err) { r.log.WithError(err).Warn("Unable to create handover listener for resumable connection, connection resumption will not work across graceful restarts.") } } else { From 44c62fa9d8940af53a2d92ee6b36fe8f4b57cafa Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 16:28:27 +0100 Subject: [PATCH 10/18] SSHServerWrapper godocs --- lib/resumption/server_detect.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/lib/resumption/server_detect.go b/lib/resumption/server_detect.go index e9fdd41ab23c8..97ac10d5c65ef 100644 --- a/lib/resumption/server_detect.go +++ b/lib/resumption/server_detect.go @@ -59,16 +59,24 @@ func serverVersionCRLFV1(pubKey *ecdh.PublicKey, hostID string) string { } type SSHServerWrapperConfig struct { - Log logrus.FieldLogger + Log logrus.FieldLogger + + // SSHServer is a function that takes ownership of a [net.Conn] and uses it + // as a SSH server. If the Conn is a [sshutils.SSHServerVersionOverrider], + // the server should use the overridden server version. SSHServer func(net.Conn) - HostID string + // HostID is the host ID of the Teleport instance running the server; + // compliant connection resumption clients will reconnect to the host ID + // expecting to reach the instance. + HostID string + + // DataDir is the path to the Teleport data directory. Depending on the + // platform, it might be used to store temporary hand-over sockets. DataDir string } -// NewSSHServerWrapper wraps a given SSH server to support connection -// resumption, providing a connection handler method and a -// [multiplexer.PreDetectFunc] method. +// NewSSHServerWrapper creates a [SSHServerWrapper]. func NewSSHServerWrapper(cfg SSHServerWrapperConfig) *SSHServerWrapper { if cfg.Log == nil { cfg.Log = logrus.WithField(trace.Component, Component) From fa81fa62c5962a6127410962701d9a7377b382ea Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 16:32:12 +0100 Subject: [PATCH 11/18] avoid passing cleanupDelay in a context --- lib/resumption/handover_unix.go | 12 +++++------- lib/resumption/handover_unix_test.go | 7 ++++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 9723d17181d13..810f9cdf0749f 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -104,13 +104,16 @@ func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered return filtered, lastErr } -type cleanupDelayContextKey struct{} - // HandoverCleanup deletes hand-over sockets that were left over from previous // runs of Teleport that failed to clean up after themselves (because of an // uncatchable signal or a system crash). It will exhaustively clean up the // current left over sockets, so it's sufficient to call it once per process. func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { + const cleanupDelay = time.Second + return trace.Wrap(r.handoverCleanup(ctx, cleanupDelay)) +} + +func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay time.Duration) error { if r.dataDir == "" { return nil } @@ -144,11 +147,6 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { // unconnectable sockets r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.") - cleanupDelay := time.Second - if d, ok := ctx.Value((*cleanupDelayContextKey)(nil)).(time.Duration); ok { - cleanupDelay = d - } - select { case <-ctx.Done(): return trace.NewAggregate(firstErr, ctx.Err()) diff --git a/lib/resumption/handover_unix_test.go b/lib/resumption/handover_unix_test.go index 1df14f73f70bf..524f918521ded 100644 --- a/lib/resumption/handover_unix_test.go +++ b/lib/resumption/handover_unix_test.go @@ -182,16 +182,17 @@ func TestHandoverCleanup(t *testing.T) { require.NoError(err) require.NotEmpty(d) - ctx := context.WithValue(context.Background(), (*cleanupDelayContextKey)(nil), time.Duration(0)) + ctx := context.Background() - require.NoError(s.HandoverCleanup(ctx)) + const cleanupDelayZero time.Duration = 0 + require.NoError(s.handoverCleanup(ctx, cleanupDelayZero)) d, err = os.ReadDir(handoverDir) require.NoError(err) require.NotEmpty(d) l.Close() - require.NoError(s.HandoverCleanup(ctx)) + require.NoError(s.handoverCleanup(ctx, cleanupDelayZero)) d, err = os.ReadDir(handoverDir) require.NoError(err) From a2924f2f507d01c6d7cc0ca6b54e58b9b812d036 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 16:35:44 +0100 Subject: [PATCH 12/18] name change: retainNonConnectableSockets --- lib/resumption/handover_unix.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 810f9cdf0749f..96c7ea1572b34 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -78,7 +78,10 @@ func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) return c, nil } -func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered []string, lastErr error) { +// retainNonConnectableSockets attempts to connect to the given UNIX domain +// sockets, returning all and only the ones that exist and that refuse the +// connection. +func retainNonConnectableSockets(ctx context.Context, paths []string) (filtered []string, lastErr error) { filtered = paths[:0] var d net.Dialer @@ -98,10 +101,10 @@ func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered continue } - lastErr = trace.ConvertSystemError(err) + lastErr = err } - return filtered, lastErr + return filtered, trace.ConvertSystemError(lastErr) } // HandoverCleanup deletes hand-over sockets that were left over from previous @@ -134,7 +137,7 @@ func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay tim } } - paths, firstErr := filterNonConnectableSockets(ctx, paths) + paths, firstErr := retainNonConnectableSockets(ctx, paths) if len(paths) < 1 { return trace.Wrap(firstErr) @@ -153,7 +156,7 @@ func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay tim case <-time.After(cleanupDelay): } - paths, secondErr := filterNonConnectableSockets(ctx, paths) + paths, secondErr := retainNonConnectableSockets(ctx, paths) if len(paths) < 1 { r.log.Debug("Found no unconnectable handover socket after waiting.") From aefbc4574ca6c1c174e6c9b36428587d66e8bfa3 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 16:38:21 +0100 Subject: [PATCH 13/18] add a reference for the logic in the Accept loop --- lib/resumption/handover.go | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index 46f1ab7c29ff7..3824e0529360f 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -76,6 +76,7 @@ func (r *SSHServerWrapper) runHandoverListener(l net.Listener, entry *connEntry) var tempDelay time.Duration for { + // the logic for this Accept loop is copied from [net/http.Server] c, err := l.Accept() if err == nil { tempDelay = 0 From 3f14a3532e3b30b6ea5d61e6e13eecde388451ff Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 21:28:35 +0100 Subject: [PATCH 14/18] get rid of handover_other, windows supports AF_UNIX --- lib/resumption/handover.go | 148 +++++++++++++++ lib/resumption/handover_other.go | 40 ---- ...handover_unix_test.go => handover_test.go} | 2 - lib/resumption/handover_unix.go | 176 ------------------ lib/resumption/server_exchange.go | 4 +- 5 files changed, 149 insertions(+), 221 deletions(-) delete mode 100644 lib/resumption/handover_other.go rename lib/resumption/{handover_unix_test.go => handover_test.go} (99%) delete mode 100644 lib/resumption/handover_unix.go diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index 3824e0529360f..c3ebf97ab9b19 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -18,17 +18,43 @@ package resumption import ( "context" + "crypto/sha256" + "encoding/base64" + "errors" "io" "net" "net/netip" + "os" + "path/filepath" + "strings" + "syscall" "time" "github.com/gravitational/trace" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/utils" ) +const sockSuffix = ".sock" + +func sockPath(dataDir string, token resumptionToken) string { + hash := sha256.Sum256(token[:]) + // unix domain sockets are limited to 108 or 104 characters, so the full + // sha256 hash is a bit too much (64 bytes in hex or 44 in b64); truncating + // the hash to 128 bits still gives us more than enough headroom to just + // assume that we'll have no collisions (a probability of one in a + // quintillion with 26 billion concurrent connections) + return filepath.Join(dataDir, "handover", base64.RawURLEncoding.EncodeToString(hash[:16])+sockSuffix) +} + +func sockDir(dataDir string) string { + return filepath.Join(dataDir, "handover") +} + +var errNoDataDir error = &trace.NotFoundError{Message: "data dir not configured"} + func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumptionToken) { handoverConn, err := r.dialHandover(token) if err != nil { @@ -59,6 +85,18 @@ func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumpt _ = utils.ProxyConn(context.Background(), conn, handoverConn) } +func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) { + if r.dataDir == "" { + return nil, trace.Wrap(errNoDataDir) + } + + c, err := net.DialTimeout("unix", sockPath(r.dataDir, token), time.Second) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + return c, nil +} + func (r *SSHServerWrapper) startHandoverListener(ctx context.Context, token resumptionToken, entry *connEntry) error { l, err := r.createHandoverListener(token) if err != nil { @@ -71,6 +109,19 @@ func (r *SSHServerWrapper) startHandoverListener(ctx context.Context, token resu return nil } +func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) { + if r.dataDir == "" { + return nil, trace.Wrap(errNoDataDir) + } + + _ = os.MkdirAll(sockDir(r.dataDir), teleport.PrivateDirMode) + l, err := net.Listen("unix", sockPath(r.dataDir, token)) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + return l, nil +} + func (r *SSHServerWrapper) runHandoverListener(l net.Listener, entry *connEntry) { defer l.Close() @@ -111,3 +162,100 @@ func (r *SSHServerWrapper) handleHandoverConnection(conn net.Conn, entry *connEn r.resumeConnection(entry, conn, remoteIP) } + +// HandoverCleanup deletes hand-over sockets that were left over from previous +// runs of Teleport that failed to clean up after themselves (because of an +// uncatchable signal or a system crash). It will exhaustively clean up the +// current left over sockets, so it's sufficient to call it once per process. +func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { + const cleanupDelay = time.Second + return trace.Wrap(r.handoverCleanup(ctx, cleanupDelay)) +} + +func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay time.Duration) error { + if r.dataDir == "" { + return nil + } + + dir := sockDir(r.dataDir) + entries, err := os.ReadDir(dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return trace.ConvertSystemError(err) + } + + paths := make([]string, 0, len(entries)) + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), sockSuffix) { + paths = append(paths, filepath.Join(dir, entry.Name())) + } + } + + paths, firstErr := retainNonConnectableSockets(ctx, paths) + + if len(paths) < 1 { + return trace.Wrap(firstErr) + } + + // unix domain sockets exist on disk between bind() and listen() but + // connecting before listen() results in ECONNREFUSED, so we just wait a + // little bit before testing them again; the first check lets us be done + // with the check immediately in the happy case where there's no + // unconnectable sockets + r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.") + + select { + case <-ctx.Done(): + return trace.NewAggregate(firstErr, ctx.Err()) + case <-time.After(cleanupDelay): + } + + paths, secondErr := retainNonConnectableSockets(ctx, paths) + + if len(paths) < 1 { + r.log.Debug("Found no unconnectable handover socket after waiting.") + return trace.NewAggregate(firstErr, secondErr) + } + + r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets, left over from previous Teleport instances.") + + errs := []error{firstErr, secondErr} + for _, path := range paths { + if err := trace.ConvertSystemError(os.Remove(path)); err != nil { + errs = append(errs, err) + } + } + + return trace.NewAggregate(errs...) +} + +// retainNonConnectableSockets attempts to connect to the given UNIX domain +// sockets, returning all and only the ones that exist and that refuse the +// connection. +func retainNonConnectableSockets(ctx context.Context, paths []string) (filtered []string, lastErr error) { + filtered = paths[:0] + + var d net.Dialer + for _, path := range paths { + c, err := d.DialContext(ctx, "unix", path) + if err == nil { + _ = c.Close() + continue + } + + if errors.Is(err, os.ErrNotExist) { + continue + } + + if errors.Is(err, syscall.ECONNREFUSED) { + filtered = append(filtered, path) + continue + } + + lastErr = err + } + + return filtered, trace.ConvertSystemError(lastErr) +} diff --git a/lib/resumption/handover_other.go b/lib/resumption/handover_other.go deleted file mode 100644 index ac7be32d21433..0000000000000 --- a/lib/resumption/handover_other.go +++ /dev/null @@ -1,40 +0,0 @@ -//go:build !unix - -// Teleport -// Copyright (C) 2024 Gravitational, Inc. -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package resumption - -import ( - "context" - "net" - - "github.com/gravitational/trace" -) - -func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) { - return nil, trace.NotImplemented("handover is not implemented for the current platform") -} - -func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) { - return nil, trace.NotFound("handover is not implemented for the current platform") -} - -// HandoverCleanup does nothing, because on this platform we don't support -// hand-over sockets, so there can't be anything to clean up. -func (r *SSHServerWrapper) HandoverCleanup(context.Context) error { - return nil -} diff --git a/lib/resumption/handover_unix_test.go b/lib/resumption/handover_test.go similarity index 99% rename from lib/resumption/handover_unix_test.go rename to lib/resumption/handover_test.go index 524f918521ded..6b504866209ee 100644 --- a/lib/resumption/handover_unix_test.go +++ b/lib/resumption/handover_test.go @@ -1,5 +1,3 @@ -//go:build unix - // Teleport // Copyright (C) 2024 Gravitational, Inc. // diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go deleted file mode 100644 index 96c7ea1572b34..0000000000000 --- a/lib/resumption/handover_unix.go +++ /dev/null @@ -1,176 +0,0 @@ -//go:build unix - -// Teleport -// Copyright (C) 2024 Gravitational, Inc. -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package resumption - -import ( - "context" - "crypto/sha256" - "encoding/base64" - "errors" - "net" - "os" - "path/filepath" - "strings" - "syscall" - "time" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport" -) - -const sockSuffix = ".sock" - -func sockPath(dataDir string, token resumptionToken) string { - hash := sha256.Sum256(token[:]) - // unix domain sockets are limited to 108 or 104 characters, so the full - // sha256 hash is a bit too much (64 bytes in hex or 44 in b64); truncating - // the hash to 128 bits still gives us more than enough headroom to just - // assume that we'll have no collisions (a probability of one in a - // quintillion with 26 billion concurrent connections) - return filepath.Join(dataDir, "handover", base64.RawURLEncoding.EncodeToString(hash[:16])+sockSuffix) -} - -func sockDir(dataDir string) string { - return filepath.Join(dataDir, "handover") -} - -var errNoDataDir error = &trace.NotFoundError{Message: "data dir not configured"} - -func (r *SSHServerWrapper) createHandoverListener(token resumptionToken) (net.Listener, error) { - if r.dataDir == "" { - return nil, trace.Wrap(errNoDataDir) - } - - _ = os.MkdirAll(sockDir(r.dataDir), teleport.PrivateDirMode) - l, err := net.Listen("unix", sockPath(r.dataDir, token)) - if err != nil { - return nil, trace.ConvertSystemError(err) - } - return l, nil -} - -func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) { - if r.dataDir == "" { - return nil, trace.Wrap(errNoDataDir) - } - - c, err := net.DialTimeout("unix", sockPath(r.dataDir, token), time.Second) - if err != nil { - return nil, trace.ConvertSystemError(err) - } - return c, nil -} - -// retainNonConnectableSockets attempts to connect to the given UNIX domain -// sockets, returning all and only the ones that exist and that refuse the -// connection. -func retainNonConnectableSockets(ctx context.Context, paths []string) (filtered []string, lastErr error) { - filtered = paths[:0] - - var d net.Dialer - for _, path := range paths { - c, err := d.DialContext(ctx, "unix", path) - if err == nil { - _ = c.Close() - continue - } - - if errors.Is(err, os.ErrNotExist) { - continue - } - - if errors.Is(err, syscall.ECONNREFUSED) { - filtered = append(filtered, path) - continue - } - - lastErr = err - } - - return filtered, trace.ConvertSystemError(lastErr) -} - -// HandoverCleanup deletes hand-over sockets that were left over from previous -// runs of Teleport that failed to clean up after themselves (because of an -// uncatchable signal or a system crash). It will exhaustively clean up the -// current left over sockets, so it's sufficient to call it once per process. -func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { - const cleanupDelay = time.Second - return trace.Wrap(r.handoverCleanup(ctx, cleanupDelay)) -} - -func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay time.Duration) error { - if r.dataDir == "" { - return nil - } - - dir := sockDir(r.dataDir) - entries, err := os.ReadDir(dir) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return trace.ConvertSystemError(err) - } - - paths := make([]string, 0, len(entries)) - for _, entry := range entries { - if strings.HasSuffix(entry.Name(), sockSuffix) { - paths = append(paths, filepath.Join(dir, entry.Name())) - } - } - - paths, firstErr := retainNonConnectableSockets(ctx, paths) - - if len(paths) < 1 { - return trace.Wrap(firstErr) - } - - // unix domain sockets exist on disk between bind() and listen() but - // connecting before listen() results in ECONNREFUSED, so we just wait a - // little bit before testing them again; the first check lets us be done - // with the check immediately in the happy case where there's no - // unconnectable sockets - r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.") - - select { - case <-ctx.Done(): - return trace.NewAggregate(firstErr, ctx.Err()) - case <-time.After(cleanupDelay): - } - - paths, secondErr := retainNonConnectableSockets(ctx, paths) - - if len(paths) < 1 { - r.log.Debug("Found no unconnectable handover socket after waiting.") - return trace.NewAggregate(firstErr, secondErr) - } - - r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets, left over from previous Teleport instances.") - - errs := []error{firstErr, secondErr} - for _, path := range paths { - if err := trace.ConvertSystemError(os.Remove(path)); err != nil { - errs = append(errs, err) - } - } - - return trace.NewAggregate(errs...) -} diff --git a/lib/resumption/server_exchange.go b/lib/resumption/server_exchange.go index 4ea1f9950581d..292c8495cd58d 100644 --- a/lib/resumption/server_exchange.go +++ b/lib/resumption/server_exchange.go @@ -25,8 +25,6 @@ import ( "net/netip" "time" - "github.com/gravitational/trace" - "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/utils" ) @@ -117,7 +115,7 @@ func (r *SSHServerWrapper) handleResumptionExchangeV1(conn *multiplexer.Conn, dh r.conns[token] = entry r.mu.Unlock() - if err := r.startHandoverListener(handoverContext, token, entry); err != nil && !trace.IsNotImplemented(err) { + if err := r.startHandoverListener(handoverContext, token, entry); err != nil { r.log.WithError(err).Warn("Unable to create handover listener for resumable connection, connection resumption will not work across graceful restarts.") } } else { From da3998926b792b2122b2ee4fe94ca9967fa00ca1 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 21:39:24 +0100 Subject: [PATCH 15/18] don't log "not exist" errors when cleaning up --- lib/resumption/handover.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index c3ebf97ab9b19..4a130b62b639c 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -223,8 +223,8 @@ func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay tim errs := []error{firstErr, secondErr} for _, path := range paths { - if err := trace.ConvertSystemError(os.Remove(path)); err != nil { - errs = append(errs, err) + if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { + errs = append(errs, trace.ConvertSystemError(err)) } } From 347ca067ffd465c357c18b4c1af49d8c8fedd410 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 21:42:36 +0100 Subject: [PATCH 16/18] grammar and wording --- lib/resumption/handover.go | 6 +++--- lib/resumption/server_detect.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index 4a130b62b639c..4131b26c4bfd2 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -163,7 +163,7 @@ func (r *SSHServerWrapper) handleHandoverConnection(conn net.Conn, entry *connEn r.resumeConnection(entry, conn, remoteIP) } -// HandoverCleanup deletes hand-over sockets that were left over from previous +// HandoverCleanup deletes handover sockets that were left over from previous // runs of Teleport that failed to clean up after themselves (because of an // uncatchable signal or a system crash). It will exhaustively clean up the // current left over sockets, so it's sufficient to call it once per process. @@ -215,11 +215,11 @@ func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay tim paths, secondErr := retainNonConnectableSockets(ctx, paths) if len(paths) < 1 { - r.log.Debug("Found no unconnectable handover socket after waiting.") + r.log.Debug("Found no unconnectable handover sockets after waiting.") return trace.NewAggregate(firstErr, secondErr) } - r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets, left over from previous Teleport instances.") + r.log.WithField("sockets", len(paths)).Info("Cleaning up some non-connectable handover sockets from old Teleport instances.") errs := []error{firstErr, secondErr} for _, path := range paths { diff --git a/lib/resumption/server_detect.go b/lib/resumption/server_detect.go index 97ac10d5c65ef..9b247271b2c02 100644 --- a/lib/resumption/server_detect.go +++ b/lib/resumption/server_detect.go @@ -71,8 +71,8 @@ type SSHServerWrapperConfig struct { // expecting to reach the instance. HostID string - // DataDir is the path to the Teleport data directory. Depending on the - // platform, it might be used to store temporary hand-over sockets. + // DataDir is the path to the Teleport data directory, used to store + // temporary handover sockets. DataDir string } From f434358aac6303015e71e3644c3e5cd75d029d95 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 27 Feb 2024 21:58:25 +0100 Subject: [PATCH 17/18] retain multiple errors in cleanup, up to a limit --- lib/resumption/handover.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index 4131b26c4bfd2..d2de8de7c1342 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -224,7 +224,9 @@ func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay tim errs := []error{firstErr, secondErr} for _, path := range paths { if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { - errs = append(errs, trace.ConvertSystemError(err)) + if len(errs) < 10 { + errs = append(errs, trace.ConvertSystemError(err)) + } } } @@ -234,10 +236,10 @@ func (r *SSHServerWrapper) handoverCleanup(ctx context.Context, cleanupDelay tim // retainNonConnectableSockets attempts to connect to the given UNIX domain // sockets, returning all and only the ones that exist and that refuse the // connection. -func retainNonConnectableSockets(ctx context.Context, paths []string) (filtered []string, lastErr error) { - filtered = paths[:0] - +func retainNonConnectableSockets(ctx context.Context, paths []string) ([]string, error) { var d net.Dialer + var errs []error + filtered := paths[:0] for _, path := range paths { c, err := d.DialContext(ctx, "unix", path) if err == nil { @@ -254,8 +256,10 @@ func retainNonConnectableSockets(ctx context.Context, paths []string) (filtered continue } - lastErr = err + if len(errs) < 10 { + errs = append(errs, trace.ConvertSystemError(err)) + } } - return filtered, trace.ConvertSystemError(lastErr) + return filtered, trace.NewAggregate(errs...) } From 006bd9e2d36b7dac808eb18641a0252e2832dbfd Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Wed, 28 Feb 2024 15:09:11 +0100 Subject: [PATCH 18/18] Log non-ok-network-error errors from ProxyConn --- lib/resumption/handover.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/resumption/handover.go b/lib/resumption/handover.go index d2de8de7c1342..a60e31eaacff2 100644 --- a/lib/resumption/handover.go +++ b/lib/resumption/handover.go @@ -82,7 +82,9 @@ func (r *SSHServerWrapper) attemptHandover(conn *multiplexer.Conn, token resumpt } r.log.Debug("Forwarding resuming connection to handover socket.") - _ = utils.ProxyConn(context.Background(), conn, handoverConn) + if err := utils.ProxyConn(context.Background(), conn, handoverConn); err != nil && !utils.IsOKNetworkError(err) { + r.log.WithError(err).Debug("Finished forwarding resuming connection to handover socket.") + } } func (r *SSHServerWrapper) dialHandover(token resumptionToken) (net.Conn, error) {