Skip to content

Commit

Permalink
Expand handover tests
Browse files Browse the repository at this point in the history
  • Loading branch information
espadolini committed Feb 22, 2024
1 parent cfbec5f commit 63794f5
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 19 deletions.
9 changes: 8 additions & 1 deletion lib/resumption/handover_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ func filterNonConnectableSockets(ctx context.Context, paths []string) (filtered
return filtered, lastErr
}

// cleanupDelayContextKey is the type of the context key used to override the
// delay that HandoverCleanup waits before rechecking unconnectable sockets.
// The key is a nil *cleanupDelayContextKey and the associated value is
// expected to be a time.Duration; values of any other type are ignored.
type cleanupDelayContextKey struct{}

// HandoverCleanup deletes hand-over sockets that were left over from previous
// runs of Teleport that failed to clean up after themselves (because of an
// uncatchable signal or a system crash). It will exhaustively clean up the
Expand Down Expand Up @@ -140,10 +142,15 @@ func (r *SSHServerWrapper) HandoverCleanup(ctx context.Context) error {
// unconnectable sockets
r.log.WithField("sockets", len(paths)).Debug("Found some unconnectable handover sockets, waiting before checking them again.")

cleanupDelay := time.Second
if d, ok := ctx.Value((*cleanupDelayContextKey)(nil)).(time.Duration); ok {
cleanupDelay = d
}

select {
case <-ctx.Done():
return trace.NewAggregate(firstErr, ctx.Err())
case <-time.After(3 * time.Second):
case <-time.After(cleanupDelay):
}

paths, secondErr := filterNonConnectableSockets(ctx, paths)
Expand Down
118 changes: 100 additions & 18 deletions lib/resumption/handover_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@ package resumption

import (
"context"
"encoding/binary"
"math/rand"
"net"
"net/netip"
"os"
"path/filepath"
"runtime"
"testing"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils"
Expand All @@ -54,38 +59,44 @@ func TestHandover(t *testing.T) {
DataDir: dataDir,
})

dial := func(handleConnection func(net.Conn)) (net.Conn, error) {
dial := func(handleConnection func(net.Conn), clientAddr netip.Addr) net.Conn {
c1, c2, err := uds.NewSocketpair(uds.SocketTypeStream)
if err != nil {
return nil, err
}
require.NoError(err)

a1 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)}
a2 := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: rand.Intn(65536)}
srv := &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1 + rand.Intn(65535),
}
clt := &net.TCPAddr{
IP: clientAddr.AsSlice(),
Zone: clientAddr.Zone(),
Port: 1 + rand.Intn(65535),
}

go handleConnection(utils.NewConnWithAddr(c2, a2, a1))
return utils.NewConnWithAddr(c1, a1, a2), nil
go handleConnection(utils.NewConnWithAddr(c2, srv, clt))
conn := utils.NewConnWithAddr(c1, clt, srv)
t.Cleanup(func() { _ = conn.Close() })
return conn
}

originalNC, err := dial(s1.HandleConnection)
require.NoError(err)
defer originalNC.Close()
originalNC := dial(s1.HandleConnection, netip.MustParseAddr("127.0.0.1"))

redialDestination := make(chan func(net.Conn))
defer close(redialDestination)
redialConns := make(chan net.Conn)
defer close(redialConns)

wrappedNC, err := WrapSSHClientConn(context.Background(), originalNC, func(ctx context.Context, receivedHostID string) (net.Conn, error) {
if receivedHostID != hostID {
return nil, trace.BadParameter("expected hostID %q, got %q", hostID, receivedHostID)
}
handleConnection := <-redialDestination
if handleConnection == nil {
return nil, trace.ConnectionProblem(nil, "no redial destination received")
conn := <-redialConns
if conn == nil {
return nil, trace.ConnectionProblem(nil, "no redial connection received")
}
return dial(handleConnection)
return conn, nil
})
require.NoError(err)
defer wrappedNC.Close()

require.IsType((*Conn)(nil), wrappedNC)

clt, err := sshClient(wrappedNC)
Expand All @@ -97,10 +108,18 @@ func TestHandover(t *testing.T) {
require.NoError(err)

_ = originalNC.Close()
redialDestination <- s2.HandleConnection
nextNC := dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.1"))
redialConns <- nextNC

_, _, err = clt.SendRequest("foo", wantReplyTrue, nil)
require.NoError(err)

_ = nextNC.Close()
// this will result in a closed connection, because changing network address
// stops further reconnection attempts
redialConns <- dial(s2.HandleConnection, netip.MustParseAddr("127.0.0.2"))

require.ErrorIs(clt.Wait(), net.ErrClosed)
}

func shortTempDir(t *testing.T) string {
Expand All @@ -114,3 +133,66 @@ func shortTempDir(t *testing.T) string {
t.Cleanup(func() { require.NoError(t, os.RemoveAll(d)) })
return d
}

// TestHandoverCleanup checks that HandoverCleanup keeps a handover socket that
// still has a listener behind it, and removes the socket file once the listener
// has been closed.
func TestHandoverCleanup(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	hostID := uuid.NewString()
	// unix domain socket names have a very tight length limit
	dataDir := shortTempDir(t)

	// fill the token with random bytes so the socket name is unique
	var token resumptionToken
	binary.NativeEndian.PutUint64(token[:8], rand.Uint64())
	binary.NativeEndian.PutUint64(token[8:], rand.Uint64())

	wrapper := NewSSHServerWrapper(SSHServerWrapperConfig{
		// nothing in this test is supposed to reach the SSH server
		SSHServer: func(conn net.Conn) {
			defer conn.Close()
			assert.Fail(t, "unexpected connection")
		},
		HostID:  hostID,
		DataDir: dataDir,
	})

	handoverDir := filepath.Join(dataDir, "handover")
	require.NoError(os.MkdirAll(handoverDir, 0o700))

	// listHandoverDir returns the current entries of the handover directory,
	// failing the test if the directory cannot be read.
	listHandoverDir := func() []os.DirEntry {
		entries, readErr := os.ReadDir(handoverDir)
		require.NoError(readErr)
		return entries
	}

	require.Empty(listHandoverDir())

	listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: sockPath(dataDir, token)})
	require.NoError(err)
	// keep the socket file on disk after Close, so that the cleanup has
	// something to remove
	listener.SetUnlinkOnClose(false)
	defer listener.Close()
	go func() {
		defer listener.Close()
		// accept and immediately drop connections until the listener is closed
		for {
			conn, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			_ = conn.Close()
		}
	}()

	require.NotEmpty(listHandoverDir())

	// a zero delay stored under the cleanup delay key avoids waiting between
	// the connectability checks
	ctx := context.WithValue(context.Background(), (*cleanupDelayContextKey)(nil), time.Duration(0))

	// the socket is still being listened on, so it must survive the cleanup
	require.NoError(wrapper.HandoverCleanup(ctx))
	require.NotEmpty(listHandoverDir())

	// once the listener is gone the leftover socket file must be deleted
	listener.Close()
	require.NoError(wrapper.HandoverCleanup(ctx))
	require.Empty(listHandoverDir())
}

0 comments on commit 63794f5

Please sign in to comment.