diff --git a/lib/resumption/handover_unix.go b/lib/resumption/handover_unix.go index 78f361dcc1035..77c95bb8e9320 100644 --- a/lib/resumption/handover_unix.go +++ b/lib/resumption/handover_unix.go @@ -102,6 +102,8 @@ func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered return filtered, lastErr } +type cleanupDelayContextKey struct{} + // HandoverCleanup deletes hand-over sockets that were left over from previous // runs of Teleport that failed to clean up after themselves (because of an // uncatchable signal or a system crash). It will exhaustively clean up the @@ -140,10 +142,15 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error { // unconnectable sockets r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.") + cleanupDelay := time.Second + if d, ok := ctx.Value((*cleanupDelayContextKey)(nil)).(time.Duration); ok { + cleanupDelay = d + } + select { case <-ctx.Done(): return trace.NewAggregate(firstErr, ctx.Err()) - case <-time.After(3 * time.Second): + case <-time.After(cleanupDelay): } paths, secondErr := filterNonConnectableSockets(ctx, paths) diff --git a/lib/resumption/handover_unix_test.go b/lib/resumption/handover_unix_test.go index 38a308dd6bef0..b7baae75990fe 100644 --- a/lib/resumption/handover_unix_test.go +++ b/lib/resumption/handover_unix_test.go @@ -20,14 +20,19 @@ package resumption import ( "context" + "encoding/binary" "math/rand" "net" + "net/netip" "os" + "path/filepath" "runtime" "testing" + "time" "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/utils" @@ -54,38 +59,44 @@ func TestHandover(t *testing.T) { DataDir: dataDir, }) - dial := func(handleConnection func(net.Conn)) (net.Conn, error) { + dial := func(handleConnection func(net.Conn), clientAddr netip.Addr) net.Conn { c1, c2, err := uds.NewSocketpair(uds.SocketTypeStream) - if err != nil { - return nil, err - } + require.NoError(err) - a1 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)} - a2 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)} + srv := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1 + rand.Intn(65535), + } + clt := &net.TCPAddr{ + IP: clientAddr.AsSlice(), + Zone: clientAddr.Zone(), + Port: 1 + rand.Intn(65535), + } - go handleConnection(utils.NewConnWithAddr(c2, a2, a1)) - return utils.NewConnWithAddr(c1, a1, a2), nil + go handleConnection(utils.NewConnWithAddr(c2, srv, clt)) + conn := utils.NewConnWithAddr(c1, clt, srv) + t.Cleanup(func() { _ = conn.Close() }) + return conn } - originalNC, err := dial(s1.HandleConnection) - require.NoError(err) - defer originalNC.Close() + originalNC := dial(s1.HandleConnection, netip.MustParseAddr("127.0.0.1")) - redialDestination := make(chan func(net.Conn)) - defer close(redialDestination) + redialConns := make(chan net.Conn) + defer close(redialConns) wrappedNC, err := WrapSSHClientConn(context.Background(), originalNC, func(ctx context.Context, receivedHostID string) (net.Conn, error) { if receivedHostID != hostID { return nil, trace.BadParameter("expected hostID %q, got %q", hostID, receivedHostID) } - handleConnection := <-redialDestination - if handleConnection == nil { - return nil, trace.ConnectionProblem(nil, "no redial destination received") + conn := <-redialConns + if conn == nil { + return nil, trace.ConnectionProblem(nil, "no redial connection received") } - return dial(handleConnection) + return conn, nil }) require.NoError(err) defer wrappedNC.Close() + require.IsType((*Conn)(nil), wrappedNC) clt, err := sshClient(wrappedNC) @@ -97,10 +108,18 @@ func TestHandover(t *testing.T) { require.NoError(err) _ = originalNC.Close() - redialDestination <- s2.HandleConnection + nextNC := dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.1")) + redialConns <- nextNC _, _, err = clt.SendRequest("foo", wantReplyTrue, nil) require.NoError(err) + + _ = nextNC.Close() + // this will result in a closed connection, because changing network address + // stops further reconnection attempts + redialConns <- dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.2")) + + require.ErrorIs(clt.Wait(), net.ErrClosed) } func shortTempDir(t *testing.T) string { @@ -114,3 +133,66 @@ func shortTempDir(t *testing.T) string { t.Cleanup(func() { require.NoError(t, os.RemoveAll(d)) }) return d } + +func TestHandoverCleanup(t *testing.T) { + t.Parallel() + require := require.New(t) + + hostID := uuid.NewString() + // unix domain socket names have a very tight length limit + dataDir := shortTempDir(t) + + var tok resumptionToken + binary.NativeEndian.PutUint64(tok[:8], rand.Uint64()) + binary.NativeEndian.PutUint64(tok[8:], rand.Uint64()) + + s := NewSSHServerWrapper(SSHServerWrapperConfig{ + SSHServer: func(c net.Conn) { + defer c.Close() + assert.Fail(t, "unexpected connection") + }, + HostID: hostID, + DataDir: dataDir, + }) + + handoverDir := filepath.Join(dataDir, "handover") + require.NoError(os.MkdirAll(handoverDir, 0o700)) + + d, err := os.ReadDir(handoverDir) + require.NoError(err) + require.Empty(d) + + l, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockPath(dataDir, tok)}) + require.NoError(err) + l.SetUnlinkOnClose(false) + defer l.Close() + go func() { + defer l.Close() + for { + c, err := l.Accept() + if err != nil { + break + } + _ = c.Close() + } + }() + + d, err = os.ReadDir(handoverDir) + require.NoError(err) + require.NotEmpty(d) + + ctx := context.WithValue(context.Background(), (*cleanupDelayContextKey)(nil), time.Duration(0)) + + require.NoError(s.HandoverCleanup(ctx)) + + d, err = os.ReadDir(handoverDir) + require.NoError(err) + require.NotEmpty(d) + + l.Close() + require.NoError(s.HandoverCleanup(ctx)) + + d, err = os.ReadDir(handoverDir) + require.NoError(err) + require.Empty(d) +}