From 5a16d8d71de48cfbd3c476dc01bbc6b6cbd27526 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 4 Oct 2024 18:53:04 +0200 Subject: [PATCH 1/5] Make the peer clientConn generic --- lib/proxy/peer/client.go | 313 ++++++++++++++++++++------------- lib/proxy/peer/client_test.go | 47 ++--- lib/proxy/peer/helpers_test.go | 4 +- lib/proxy/peer/server_test.go | 8 +- 4 files changed, 223 insertions(+), 149 deletions(-) diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index 36d70b87779f6..ed322c3307936 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -21,7 +21,7 @@ package peer import ( "context" "crypto/tls" - "math/rand" + "math/rand/v2" "net" "sync" "time" @@ -86,11 +86,11 @@ type ClientConfig struct { } // connShuffler shuffles the order of client connections. -type connShuffler func([]*clientConn) +type connShuffler func([]clientConn) // randomConnShuffler returns a conn shuffler that randomizes the order of connections. func randomConnShuffler() connShuffler { - return func(conns []*clientConn) { + return func(conns []clientConn) { rand.Shuffle(len(conns), func(i, j int) { conns[i], conns[j] = conns[j], conns[i] }) @@ -99,7 +99,7 @@ func randomConnShuffler() connShuffler { // noopConnShutffler returns a conn shuffler that keeps the original connection ordering. func noopConnShuffler() connShuffler { - return func([]*clientConn) {} + return func([]clientConn) {} } // checkAndSetDefaults checks and sets default values @@ -160,9 +160,36 @@ func (c *ClientConfig) checkAndSetDefaults() error { return nil } -// clientConn hold info about a dialed grpc connection -type clientConn struct { - *grpc.ClientConn +// clientConn manages client connections to a specific peer proxy (with a fixed +// host ID and address). +type clientConn interface { + // peerID returns the host ID of the peer proxy. + peerID() string + // peerAddr returns the address of the peer proxy. + peerAddr() string + + // dial opens a connection of a given tunnel type to a node with the given + // ID through the peer proxy managed by the clientConn. + dial( + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, + ) (net.Conn, error) + + // close closes all connections and releases any background resources + // immediately. + close() error + + // shutdown waits until all connections are closed or the context is done, + // then acts like close. + shutdown(context.Context) +} + +// grpcClientConn manages client connections to a specific peer proxy over gRPC. +type grpcClientConn struct { + cc *grpc.ClientConn + metrics *clientMetrics id string addr string @@ -175,7 +202,17 @@ type clientConn struct { count int } -func (c *clientConn) maybeAcquire() (release func()) { +var _ clientConn = (*grpcClientConn)(nil) + +// peerID implements [clientConn]. +func (c *grpcClientConn) peerID() string { return c.id } + +// peerAddr implements [clientConn]. +func (c *grpcClientConn) peerAddr() string { return c.addr } + +// maybeAcquire returns a non-nil release func if the grpcClientConn is +// currently allowed to open connections; i.e., if it hasn't fully shut down. +func (c *grpcClientConn) maybeAcquire() (release func()) { c.mu.Lock() defer c.mu.Unlock() @@ -194,10 +231,9 @@ func (c *clientConn) maybeAcquire() (release func()) { }) } -// Shutdown closes the clientConn after all connections through it are closed, -// or after the context is done. -func (c *clientConn) Shutdown(ctx context.Context) { - defer c.Close() +// shutdown implements [clientConn]. +func (c *grpcClientConn) shutdown(ctx context.Context) { + defer c.cc.Close() c.mu.Lock() defer c.mu.Unlock() @@ -216,14 +252,87 @@ func (c *clientConn) Shutdown(ctx context.Context) { } } -// Client is a peer proxy service client using grpc and tls. +// close implements [clientConn]. +func (c *grpcClientConn) close() error { + return c.cc.Close() +} + +// dial implements [clientConn]. +func (c *grpcClientConn) dial( + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, error) { + release := c.maybeAcquire() + if release == nil { + c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) + return nil, trace.ConnectionProblem(nil, "error starting stream: connection is shutting down") + } + + ctx, cancel := context.WithCancel(context.Background()) + context.AfterFunc(ctx, release) + + stream, err := clientapi.NewProxyServiceClient(c.cc).DialNode(ctx) + if err != nil { + cancel() + c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) + return nil, trace.ConnectionProblem(err, "error starting stream: %v", err) + } + + err = stream.Send(&clientapi.Frame{ + Message: &clientapi.Frame_DialRequest{ + DialRequest: &clientapi.DialRequest{ + NodeID: nodeID, + TunnelType: tunnelType, + Source: &clientapi.NetAddr{ + Addr: src.String(), + Network: src.Network(), + }, + Destination: &clientapi.NetAddr{ + Addr: dst.String(), + Network: dst.Network(), + }, + }, + }, + }) + if err != nil { + cancel() + return nil, trace.ConnectionProblem(err, "error sending dial frame: %v", err) + } + msg, err := stream.Recv() + if err != nil { + cancel() + return nil, trace.ConnectionProblem(err, "error receiving dial response: %v", err) + } + if msg.GetConnectionEstablished() == nil { + cancel() + return nil, trace.ConnectionProblem(nil, "received malformed connection established frame") + } + + source := &frameStream{ + stream: stream, + cancel: cancel, + } + + streamRW, err := streamutils.NewReadWriter(source) + if err != nil { + _ = source.Close() + return nil, trace.Wrap(err) + } + + return streamutils.NewConn(streamRW, src, dst), nil +} + +// Client manages connections to known peer proxies and allows to open +// connections to agents through them. type Client struct { sync.RWMutex ctx context.Context cancel context.CancelFunc config ClientConfig - conns map[string]*clientConn + conns map[string]clientConn metrics *clientMetrics reporter *reporter } @@ -248,7 +357,7 @@ func NewClient(config ClientConfig) (*Client, error) { config: config, ctx: closeContext, cancel: cancel, - conns: make(map[string]*clientConn), + conns: make(map[string]clientConn), metrics: metrics, reporter: reporter, } @@ -276,17 +385,20 @@ func (c *Client) monitor() { c.RLock() c.reporter.resetConnections() for _, conn := range c.conns { - switch conn.GetState() { - case connectivity.Idle: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Idle.String()) - case connectivity.Connecting: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Connecting.String()) - case connectivity.Ready: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Ready.String()) - case connectivity.TransientFailure: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.TransientFailure.String()) - case connectivity.Shutdown: - c.reporter.incConnection(c.config.ID, conn.id, connectivity.Shutdown.String()) + switch conn := conn.(type) { + case *grpcClientConn: + switch conn.cc.GetState() { + case connectivity.Idle: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Idle.String()) + case connectivity.Connecting: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Connecting.String()) + case connectivity.Ready: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Ready.String()) + case connectivity.TransientFailure: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.TransientFailure.String()) + case connectivity.Shutdown: + c.reporter.incConnection(c.config.ID, conn.id, connectivity.Shutdown.String()) + } } } c.RUnlock() @@ -337,7 +449,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { } var toDelete []string - toKeep := make(map[string]*clientConn) + toKeep := make(map[string]clientConn) for id, conn := range c.conns { proxy, ok := toDial[id] @@ -348,7 +460,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { } // peer address changed - if conn.addr != proxy.GetPeerAddr() { + if conn.peerAddr() != proxy.GetPeerAddr() { toDelete = append(toDelete, id) continue } @@ -386,7 +498,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { for _, id := range toDelete { if conn, ok := c.conns[id]; ok { - go conn.Shutdown(c.ctx) + go conn.shutdown(c.ctx) } } c.conns = toKeep @@ -394,39 +506,6 @@ func (c *Client) updateConnections(proxies []types.Server) error { return trace.NewAggregate(errs...) } -// DialNode dials a node through a peer proxy. -func (c *Client) DialNode( - proxyIDs []string, - nodeID string, - src net.Addr, - dst net.Addr, - tunnelType types.TunnelType, -) (net.Conn, error) { - stream, _, err := c.dial(proxyIDs, &clientapi.DialRequest{ - NodeID: nodeID, - TunnelType: tunnelType, - Source: &clientapi.NetAddr{ - Addr: src.String(), - Network: src.Network(), - }, - Destination: &clientapi.NetAddr{ - Addr: dst.String(), - Network: dst.Network(), - }, - }) - if err != nil { - return nil, trace.ConnectionProblem(err, "error dialing peer proxies %s: %v", proxyIDs, err) - } - - streamRW, err := streamutils.NewReadWriter(stream) - if err != nil { - _ = stream.Close() - return nil, trace.Wrap(err) - } - - return streamutils.NewConn(streamRW, src, dst), nil -} - // stream is the common subset of the [clientapi.ProxyService_DialNodeClient] and // [clientapi.ProxyService_DialNodeServer] interfaces. type stream interface { @@ -472,9 +551,9 @@ func (c *Client) Shutdown(ctx context.Context) { var wg sync.WaitGroup for _, conn := range c.conns { wg.Add(1) - go func(conn *clientConn) { + go func(conn clientConn) { defer wg.Done() - conn.Shutdown(ctx) + conn.shutdown(ctx) }(conn) } wg.Wait() @@ -488,7 +567,7 @@ func (c *Client) Stop() error { var errs []error for _, conn := range c.conns { - if err := conn.Close(); err != nil { + if err := conn.close(); err != nil { errs = append(errs, err) } } @@ -502,67 +581,56 @@ func (c *Client) GetConnectionsCount() int { return len(c.conns) } -// dial opens a new stream to one of the supplied proxy ids. -// it tries to find an existing grpc.ClientConn or initializes a new rpc -// to one of the proxies otherwise. -// The boolean returned in the second argument is intended for testing purposes, -// to indicates whether the connection was cached or newly established. -func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (frameStream, bool, error) { - conns, existing, err := c.getConnections(proxyIDs) +// DialNode dials a node through a peer proxy. +func (c *Client) DialNode( + proxyIDs []string, + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, error) { + conn, _, err := c.dial( + proxyIDs, + nodeID, + src, + dst, + tunnelType, + ) if err != nil { - return frameStream{}, existing, trace.Wrap(err) + return nil, trace.Wrap(err) } - var errs []error - for _, conn := range conns { - release := conn.maybeAcquire() - if release == nil { - c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) - errs = append(errs, trace.ConnectionProblem(nil, "error starting stream: connection is shutting down")) - continue - } - - ctx, cancel := context.WithCancel(context.Background()) - context.AfterFunc(ctx, release) + return conn, nil +} - stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) - if err != nil { - cancel() - c.metrics.reportTunnelError(errorProxyPeerTunnelRPC) - c.config.Log.Debugf("Error opening tunnel rpc to proxy %+v at %+v", conn.id, conn.addr) - errs = append(errs, trace.ConnectionProblem(err, "error starting stream: %v", err)) - continue - } +// dial opens a new connection through one of the given proxy ids. It tries to +// find an existing [clientConn] or initializes new clientConns to the given +// proxies otherwise. The boolean returned in the second argument is intended +// for testing purposes, to indicates whether the connection used an existing +// clientConn or a newly established one. +func (c *Client) dial( + proxyIDs []string, + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, +) (net.Conn, bool, error) { + conns, existing, err := c.getConnections(proxyIDs) + if err != nil { + return nil, false, trace.Wrap(err) + } - err = stream.Send(&clientapi.Frame{ - Message: &clientapi.Frame_DialRequest{ - DialRequest: dialRequest, - }, - }) - if err != nil { - cancel() - errs = append(errs, trace.ConnectionProblem(err, "error sending dial frame: %v", err)) - continue - } - msg, err := stream.Recv() + var errs []error + for _, clientConn := range conns { + conn, err := clientConn.dial(nodeID, src, dst, tunnelType) if err != nil { - cancel() - errs = append(errs, trace.ConnectionProblem(err, "error receiving dial response: %v", err)) + errs = append(errs, trace.Wrap(err)) continue } - if msg.GetConnectionEstablished() == nil { - cancel() - errs = append(errs, trace.ConnectionProblem(nil, "received malformed connection established frame")) - continue - } - - return frameStream{ - stream: stream, - cancel: cancel, - }, existing, nil + return conn, existing, nil } - return frameStream{}, existing, trace.NewAggregate(errs...) + return nil, existing, trace.NewAggregate(errs...) } // getConnections returns connections to the supplied proxy ids. @@ -570,13 +638,13 @@ func (c *Client) dial(proxyIDs []string, dialRequest *clientapi.DialRequest) (fr // otherwise. // The boolean returned in the second argument is intended for testing purposes, // to indicates whether the connection was cached or newly established. -func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) { +func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { if len(proxyIDs) == 0 { return nil, false, trace.BadParameter("failed to dial: no proxy ids given") } ids := make(map[string]struct{}) - var conns []*clientConn + var conns []clientConn // look for existing matching connections. c.RLock() @@ -633,7 +701,7 @@ func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) defer c.Unlock() for _, conn := range conns { - c.conns[conn.id] = conn + c.conns[conn.peerID()] = conn } c.config.connShuffler(conns) @@ -641,7 +709,7 @@ func (c *Client) getConnections(proxyIDs []string) ([]*clientConn, bool, error) } // connect dials a new connection to proxyAddr. -func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) { +func (c *Client) connect(peerID string, peerAddr string) (clientConn, error) { tlsConfig, err := c.config.getConfigForServer() if err != nil { return nil, trace.Wrap(err, "Error updating client tls config") @@ -665,8 +733,9 @@ func (c *Client) connect(peerID string, peerAddr string) (*clientConn, error) { return nil, trace.Wrap(err, "Error dialing proxy %q", peerID) } - return &clientConn{ - ClientConn: conn, + return &grpcClientConn{ + cc: conn, + metrics: c.metrics, id: peerID, addr: peerAddr, diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index 017eeb8b3f4dd..c403b209ea03d 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" clientapi "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" ) // TestClientConn checks the client's connection caching capabilities @@ -48,33 +49,33 @@ func TestClientConn(t *testing.T) { require.Len(t, client.conns, 2) // dial first server and send a test data frame - stream, cached, err := client.dial([]string{"s1"}, &proto.DialRequest{}) + stream, cached, err := client.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream.stream) + require.NotNil(t, stream) stream.Close() // dial second server - stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + stream, cached, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream.stream) + require.NotNil(t, stream) stream.Close() // redial second server - stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + stream, cached, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, cached) - require.NotNil(t, stream.stream) + require.NotNil(t, stream) stream.Close() // close second server // and attempt to redial it server2.Shutdown() - stream, cached, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + stream, cached, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.Error(t, err) require.True(t, cached) - require.Nil(t, stream.stream) + require.Nil(t, stream) } // TestClientUpdate checks the client's watcher update behavior @@ -92,12 +93,12 @@ func TestClientUpdate(t *testing.T) { require.Contains(t, client.conns, "s1") require.Contains(t, client.conns, "s2") - s1, _, err := client.dial([]string{"s1"}, &proto.DialRequest{}) + s1, _, err := client.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) - require.NotNil(t, s1.stream) - s2, _, err := client.dial([]string{"s2"}, &proto.DialRequest{}) + require.NotNil(t, s1) + s2, _, err := client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) - require.NotNil(t, s2.stream) + require.NotNil(t, s2) // watcher finds one of the two servers err = client.updateConnections([]types.Server{def1}) @@ -116,7 +117,7 @@ func TestClientUpdate(t *testing.T) { require.Len(t, client.conns, 2) require.Contains(t, client.conns, "s1") sendMsg(t, s1) // stream is still going strong - _, _, err = client.dial([]string{"s2"}, &proto.DialRequest{}) + _, _, err = client.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.Error(t, err) // can't dial server2, obviously // peer address change @@ -126,7 +127,7 @@ func TestClientUpdate(t *testing.T) { require.Len(t, client.conns, 1) require.Contains(t, client.conns, "s1") sendMsg(t, s1) // stream is not forcefully closed. ClientConn waits for a graceful shutdown before it closes. - s3, _, err := client.dial([]string{"s1"}, &proto.DialRequest{}) + s3, _, err := client.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.NotNil(t, s3) @@ -145,9 +146,10 @@ func TestCAChange(t *testing.T) { conn, err := client.connect("s1", server.config.Listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) + require.IsType(t, (*grpcClientConn)(nil), conn) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - stream, err := clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) + stream, err := clientapi.NewProxyServiceClient(conn.(*grpcClientConn).cc).DialNode(ctx) require.NoError(t, err) require.NotNil(t, stream) @@ -161,7 +163,8 @@ func TestCAChange(t *testing.T) { conn, err = client.connect("s1", server.config.Listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) - stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) + require.IsType(t, (*grpcClientConn)(nil), conn) + stream, err = clientapi.NewProxyServiceClient(conn.(*grpcClientConn).cc).DialNode(ctx) require.Error(t, err) require.Nil(t, stream) @@ -178,7 +181,8 @@ func TestCAChange(t *testing.T) { conn, err = client.connect("s1", server.config.Listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) - stream, err = clientapi.NewProxyServiceClient(conn.ClientConn).DialNode(ctx) + require.IsType(t, (*grpcClientConn)(nil), conn) + stream, err = clientapi.NewProxyServiceClient(conn.(*grpcClientConn).cc).DialNode(ctx) require.NoError(t, err) require.NotNil(t, stream) } @@ -201,17 +205,18 @@ func TestBackupClient(t *testing.T) { err := client.updateConnections([]types.Server{def1, def2}) require.NoError(t, err) - waitForConns(t, client.conns, time.Second*2) + waitForGRPCConns(t, client.conns, time.Second*2) - _, _, err = client.dial([]string{def1.GetName(), def2.GetName()}, &proto.DialRequest{}) + _, _, err = client.dial([]string{def1.GetName(), def2.GetName()}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.True(t, dialCalled) } -func waitForConns(t *testing.T, conns map[string]*clientConn, d time.Duration) { +func waitForGRPCConns(t *testing.T, conns map[string]clientConn, d time.Duration) { require.Eventually(t, func() bool { for _, conn := range conns { - if conn.GetState() != connectivity.Ready { + // panic if we hit a non-grpc client conn + if conn.(*grpcClientConn).cc.GetState() != connectivity.Ready { return false } } diff --git a/lib/proxy/peer/helpers_test.go b/lib/proxy/peer/helpers_test.go index a07a4fd01f1dc..8ade7b887d8da 100644 --- a/lib/proxy/peer/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -262,7 +262,7 @@ func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthor return server, ts } -func sendMsg(t *testing.T, stream frameStream) { - err := stream.Send([]byte("ping")) +func sendMsg(t *testing.T, stream net.Conn) { + _, err := stream.Write([]byte("ping")) require.NoError(t, err) } diff --git a/lib/proxy/peer/server_test.go b/lib/proxy/peer/server_test.go index 0f88b332c4864..772c5802fb15f 100644 --- a/lib/proxy/peer/server_test.go +++ b/lib/proxy/peer/server_test.go @@ -23,8 +23,8 @@ import ( "github.com/stretchr/testify/require" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" ) // TestServerTLS ensures that only trusted certificates with the proxy role @@ -38,7 +38,7 @@ func TestServerTLS(t *testing.T) { _, serverDef1 := setupServer(t, "s1", ca1, ca1, types.RoleProxy) err := client1.updateConnections([]types.Server{serverDef1}) require.NoError(t, err) - stream, _, err := client1.dial([]string{"s1"}, &proto.DialRequest{}) + stream, _, err := client1.dial([]string{"s1"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.NotNil(t, stream) stream.Close() @@ -48,7 +48,7 @@ func TestServerTLS(t *testing.T) { _, serverDef2 := setupServer(t, "s2", ca1, ca1, types.RoleProxy) err = client2.updateConnections([]types.Server{serverDef2}) require.NoError(t, err) // connection succeeds but is in transient failure state - _, _, err = client2.dial([]string{"s2"}, &proto.DialRequest{}) + _, _, err = client2.dial([]string{"s2"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.Error(t, err) // certificates with correct role from different CAs @@ -56,7 +56,7 @@ func TestServerTLS(t *testing.T) { _, serverDef3 := setupServer(t, "s3", ca2, ca1, types.RoleProxy) err = client3.updateConnections([]types.Server{serverDef3}) require.NoError(t, err) - stream, _, err = client3.dial([]string{"s3"}, &proto.DialRequest{}) + stream, _, err = client3.dial([]string{"s3"}, "", &utils.NetAddr{}, &utils.NetAddr{}, "") require.NoError(t, err) require.NotNil(t, stream) stream.Close() From e162dcede0dcd814769ebf5f4b4d4b355d35097b Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Mon, 7 Oct 2024 12:07:42 +0200 Subject: [PATCH 2/5] Convert the peer server to slog --- lib/proxy/peer/credentials.go | 5 +++-- lib/proxy/peer/server.go | 8 ++++---- lib/proxy/peer/service.go | 18 +++++++++--------- lib/proxy/peer/service_test.go | 4 ++-- lib/service/service.go | 2 +- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/lib/proxy/peer/credentials.go b/lib/proxy/peer/credentials.go index 3c7c9026a0320..6ab47123c6159 100644 --- a/lib/proxy/peer/credentials.go +++ b/lib/proxy/peer/credentials.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "log/slog" "net" "github.com/gravitational/trace" @@ -167,13 +168,13 @@ func validatePeer(peerID string, identity *tlsca.Identity) error { // getConfigForClient clones and updates the server's tls config with the // appropriate client certificate authorities. -func getConfigForClient(tlsConfig *tls.Config, ap authclient.CAGetter, log logrus.FieldLogger, clusterName string) func(*tls.ClientHelloInfo) (*tls.Config, error) { +func getConfigForClient(tlsConfig *tls.Config, ap authclient.CAGetter, log *slog.Logger, clusterName string) func(*tls.ClientHelloInfo) (*tls.Config, error) { return func(info *tls.ClientHelloInfo) (*tls.Config, error) { tlsCopy := tlsConfig.Clone() pool, err := getCertPool(info.Context(), ap, clusterName) if err != nil { - log.WithError(err).Error("Failed to retrieve client CA pool.") + log.ErrorContext(info.Context(), "Failed to retrieve client CA pool.", "error", err) return tlsCopy, nil } diff --git a/lib/proxy/peer/server.go b/lib/proxy/peer/server.go index ec46bb8697588..cc0b97128b575 100644 --- a/lib/proxy/peer/server.go +++ b/lib/proxy/peer/server.go @@ -21,12 +21,12 @@ package peer import ( "crypto/tls" "errors" + "log/slog" "math" "net" "time" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" @@ -50,7 +50,7 @@ type ServerConfig struct { Listener net.Listener TLSConfig *tls.Config ClusterDialer ClusterDialer - Log logrus.FieldLogger + Log *slog.Logger ClusterName string // getConfigForClient gets the client tls config. @@ -65,9 +65,9 @@ type ServerConfig struct { // checkAndSetDefaults checks and sets default values func (c *ServerConfig) checkAndSetDefaults() error { if c.Log == nil { - c.Log = logrus.New() + c.Log = slog.Default() } - c.Log = c.Log.WithField( + c.Log = c.Log.With( teleport.ComponentKey, teleport.Component(teleport.ComponentProxy, "peer"), ) diff --git a/lib/proxy/peer/service.go b/lib/proxy/peer/service.go index 43fb0e20230dd..70610b75120b8 100644 --- a/lib/proxy/peer/service.go +++ b/lib/proxy/peer/service.go @@ -19,11 +19,11 @@ package peer import ( + "log/slog" "net" "strings" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" @@ -34,7 +34,7 @@ import ( // proxyService implements the grpc ProxyService. type proxyService struct { clusterDialer ClusterDialer - log logrus.FieldLogger + log *slog.Logger } // DialNode opens a bidirectional stream to the requested node. @@ -54,12 +54,12 @@ func (s *proxyService) DialNode(stream proto.ProxyService_DialNodeServer) error return trace.BadParameter("invalid dial request: source and destination must not be nil") } - log := s.log.WithFields(logrus.Fields{ - "node": dial.NodeID, - "src": dial.Source.Addr, - "dst": dial.Destination.Addr, - }) - log.Debugf("Dial request from peer.") + log := s.log.With( + "node", dial.NodeID, + "src", dial.Source.Addr, + "dst", dial.Destination.Addr, + ) + log.DebugContext(stream.Context(), "Dial request from peer.") _, clusterName, err := splitServerID(dial.NodeID) if err != nil { @@ -103,7 +103,7 @@ func (s *proxyService) DialNode(stream proto.ProxyService_DialNodeServer) error err = utils.ProxyConn(stream.Context(), streamConn, nodeConn) sent, received := streamConn.Stat() - log.Debugf("Closing dial request from peer. sent: %d received %d", sent, received) + log.DebugContext(stream.Context(), "Closing dial request from peer.", "sent", sent, "received", received) return trace.Wrap(err) } diff --git a/lib/proxy/peer/service_test.go b/lib/proxy/peer/service_test.go index 9f66860b92978..687759e0892d8 100644 --- a/lib/proxy/peer/service_test.go +++ b/lib/proxy/peer/service_test.go @@ -20,11 +20,11 @@ package peer import ( "context" + "log/slog" "net" "testing" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -52,7 +52,7 @@ func setupService(t *testing.T) (*proxyService, proto.ProxyServiceClient) { require.NoError(t, err) proxyService := &proxyService{ - log: logrus.New(), + log: slog.Default(), } proto.RegisterProxyServiceServer(server, proxyService) diff --git a/lib/service/service.go b/lib/service/service.go index 83c11173898ba..09cf6106d7c63 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4588,7 +4588,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Listener: listeners.proxyPeer, TLSConfig: serverTLSConfig, ClusterDialer: clusterdial.NewClusterDialer(tsrv), - Log: process.log.WithField(teleport.ComponentKey, teleport.Component(teleport.ComponentReverseTunnelServer, process.id)), + Log: process.logger, ClusterName: clusterName, }) if err != nil { From 35678ff8ecbe8b62c6a7f9b306c0750847bc243a Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 12 Nov 2024 18:55:52 +0100 Subject: [PATCH 3/5] Move lib/proxy/clusterdial to lib/peer/dial --- lib/proxy/clusterdial/dial.go | 60 --------------------------------- lib/proxy/peer/dial/dial.go | 37 ++++++++++++++++++++ lib/proxy/peer/helpers_test.go | 2 +- lib/proxy/peer/server.go | 21 ++++++------ lib/proxy/peer/service.go | 21 +++--------- lib/proxy/peer/service_test.go | 9 ++--- lib/reversetunnelclient/peer.go | 57 +++++++++++++++++++++++++++++++ lib/service/service.go | 14 ++++---- 8 files changed, 121 insertions(+), 100 deletions(-) delete mode 100644 lib/proxy/clusterdial/dial.go create mode 100644 lib/proxy/peer/dial/dial.go create mode 100644 lib/reversetunnelclient/peer.go diff --git a/lib/proxy/clusterdial/dial.go b/lib/proxy/clusterdial/dial.go deleted file mode 100644 index dc8ce3f4b1f73..0000000000000 --- a/lib/proxy/clusterdial/dial.go +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package clusterdial - -import ( - "net" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/proxy/peer" - "github.com/gravitational/teleport/lib/reversetunnelclient" -) - -// ClusterDialerFunc is a function that implements a peer.ClusterDialer. -type ClusterDialerFunc func(clusterName string, request peer.DialParams) (net.Conn, error) - -// Dial dials makes a dial request to the given cluster. -func (f ClusterDialerFunc) Dial(clusterName string, request peer.DialParams) (net.Conn, error) { - return f(clusterName, request) -} - -// NewClusterDialer implements proxy.ClusterDialer for a reverse tunnel server. -func NewClusterDialer(server reversetunnelclient.Server) ClusterDialerFunc { - return func(clusterName string, request peer.DialParams) (net.Conn, error) { - site, err := server.GetSite(clusterName) - if err != nil { - return nil, trace.Wrap(err) - } - - dialParams := reversetunnelclient.DialParams{ - ServerID: request.ServerID, - ConnType: request.ConnType, - From: request.From, - To: request.To, - FromPeerProxy: true, - } - - conn, err := site.Dial(dialParams) - if err != nil { - return nil, trace.Wrap(err) - } - return conn, nil - } -} diff --git a/lib/proxy/peer/dial/dial.go b/lib/proxy/peer/dial/dial.go new file mode 100644 index 0000000000000..22a23344bf77d --- /dev/null +++ b/lib/proxy/peer/dial/dial.go @@ -0,0 +1,37 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package peerdial + +import ( + "net" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" +) + +// Dialer dials a node in the given cluster. +type Dialer interface { + Dial(clusterName string, request DialParams) (net.Conn, error) +} + +// DialParams defines the target for a [Dialer.Dial]. +type DialParams struct { + From *utils.NetAddr + To *utils.NetAddr + ServerID string + ConnType types.TunnelType +} diff --git a/lib/proxy/peer/helpers_test.go b/lib/proxy/peer/helpers_test.go index 8ade7b887d8da..2c7e34ef439e3 100644 --- a/lib/proxy/peer/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -236,7 +236,7 @@ func setupServer(t *testing.T, name string, serverCA, clientCA *tlsca.CertAuthor AccessCache: &mockCAGetter{}, Listener: listener, TLSConfig: tlsConf, - ClusterDialer: &mockClusterDialer{}, + Dialer: &mockClusterDialer{}, getConfigForClient: getConfigForClient, service: &mockProxyService{}, ClusterName: "test", diff --git a/lib/proxy/peer/server.go b/lib/proxy/peer/server.go index cc0b97128b575..be83bd4705757 100644 --- a/lib/proxy/peer/server.go +++ b/lib/proxy/peer/server.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/utils/grpc/interceptors" "github.com/gravitational/teleport/lib/auth/authclient" + peerdial "github.com/gravitational/teleport/lib/proxy/peer/dial" "github.com/gravitational/teleport/lib/utils" ) @@ -46,12 +47,12 @@ const ( // ServerConfig configures a Server instance. type ServerConfig struct { - AccessCache authclient.CAGetter - Listener net.Listener - TLSConfig *tls.Config - ClusterDialer ClusterDialer - Log *slog.Logger - ClusterName string + AccessCache authclient.CAGetter + Listener net.Listener + TLSConfig *tls.Config + Dialer peerdial.Dialer + Log *slog.Logger + ClusterName string // getConfigForClient gets the client tls config. // configurable for testing purposes. @@ -80,8 +81,8 @@ func (c *ServerConfig) checkAndSetDefaults() error { return trace.BadParameter("missing listener") } - if c.ClusterDialer == nil { - return trace.BadParameter("missing cluster dialer server") + if c.Dialer == nil { + return trace.BadParameter("missing Dialer") } if c.ClusterName == "" { @@ -106,8 +107,8 @@ func (c *ServerConfig) checkAndSetDefaults() error { if c.service == nil { c.service = &proxyService{ - c.ClusterDialer, - c.Log, + dialer: c.Dialer, + log: c.Log, } } diff --git a/lib/proxy/peer/service.go b/lib/proxy/peer/service.go index 70610b75120b8..772399074051f 100644 --- a/lib/proxy/peer/service.go +++ b/lib/proxy/peer/service.go @@ -20,21 +20,20 @@ package peer import ( "log/slog" - "net" "strings" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/types" streamutils "github.com/gravitational/teleport/api/utils/grpc/stream" + peerdial "github.com/gravitational/teleport/lib/proxy/peer/dial" "github.com/gravitational/teleport/lib/utils" ) // proxyService implements the grpc ProxyService. type proxyService struct { - clusterDialer ClusterDialer - log *slog.Logger + dialer peerdial.Dialer + log *slog.Logger } // DialNode opens a bidirectional stream to the requested node. @@ -75,7 +74,7 @@ func (s *proxyService) DialNode(stream proto.ProxyService_DialNodeServer) error AddrNetwork: dial.Destination.Network, } - nodeConn, err := s.clusterDialer.Dial(clusterName, DialParams{ + nodeConn, err := s.dialer.Dial(clusterName, peerdial.DialParams{ From: source, To: destination, ServerID: dial.NodeID, @@ -116,15 +115,3 @@ func splitServerID(address string) (string, string, error) { return split[0], strings.Join(split[1:], "."), nil } - -// ClusterDialer dials a node in the given cluster. -type ClusterDialer interface { - Dial(clusterName string, request DialParams) (net.Conn, error) -} - -type DialParams struct { - From *utils.NetAddr - To *utils.NetAddr - ServerID string - ConnType types.TunnelType -} diff --git a/lib/proxy/peer/service_test.go b/lib/proxy/peer/service_test.go index 687759e0892d8..1510d65d9de2d 100644 --- a/lib/proxy/peer/service_test.go +++ b/lib/proxy/peer/service_test.go @@ -31,13 +31,14 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + peerdial "github.com/gravitational/teleport/lib/proxy/peer/dial" ) type mockClusterDialer struct { - MockDialCluster func(string, DialParams) (net.Conn, error) + MockDialCluster func(string, peerdial.DialParams) (net.Conn, error) } -func (m *mockClusterDialer) Dial(clusterName string, request DialParams) (net.Conn, error) { +func (m *mockClusterDialer) Dial(clusterName string, request peerdial.DialParams) (net.Conn, error) { if m.MockDialCluster == nil { return nil, trace.NotImplemented("") } @@ -93,8 +94,8 @@ func TestSendReceive(t *testing.T) { } local, remote := net.Pipe() - service.clusterDialer = &mockClusterDialer{ - MockDialCluster: func(clusterName string, request DialParams) (net.Conn, error) { + service.dialer = &mockClusterDialer{ + MockDialCluster: func(clusterName string, request peerdial.DialParams) (net.Conn, error) { require.Equal(t, "test-cluster", clusterName) require.Equal(t, dialRequest.TunnelType, request.ConnType) require.Equal(t, dialRequest.NodeID, request.ServerID) diff --git a/lib/reversetunnelclient/peer.go b/lib/reversetunnelclient/peer.go new file mode 100644 index 0000000000000..00266d53b7df1 --- /dev/null +++ b/lib/reversetunnelclient/peer.go @@ -0,0 +1,57 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package reversetunnelclient + +import ( + "net" + + "github.com/gravitational/trace" + + peerdial "github.com/gravitational/teleport/lib/proxy/peer/dial" +) + +// PeerDialerFunc is a function that implements [peerdial.Dialer]. +type PeerDialerFunc func(clusterName string, request peerdial.DialParams) (net.Conn, error) + +// Dial implements [peerdial.Dialer]. +func (f PeerDialerFunc) Dial(clusterName string, request peerdial.DialParams) (net.Conn, error) { + return f(clusterName, request) +} + +// NewPeerDialer implements [peerdial.Dialer] for a reverse tunnel server. +func NewPeerDialer(server Tunnel) PeerDialerFunc { + return func(clusterName string, request peerdial.DialParams) (net.Conn, error) { + site, err := server.GetSite(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + dialParams := DialParams{ + ServerID: request.ServerID, + ConnType: request.ConnType, + From: request.From, + To: request.To, + FromPeerProxy: true, + } + + conn, err := site.Dial(dialParams) + if err != nil { + return nil, trace.Wrap(err) + } + return conn, nil + } +} diff --git a/lib/service/service.go b/lib/service/service.go index 09cf6106d7c63..42ed6fe5737a9 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -134,7 +134,6 @@ import ( "github.com/gravitational/teleport/lib/openssh" "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/proxy" - "github.com/gravitational/teleport/lib/proxy/clusterdial" "github.com/gravitational/teleport/lib/proxy/peer" "github.com/gravitational/teleport/lib/resumption" "github.com/gravitational/teleport/lib/reversetunnel" @@ -4584,12 +4583,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } peerAddrString = peerAddr.String() proxyServer, err = peer.NewServer(peer.ServerConfig{ - AccessCache: accessPoint, - Listener: listeners.proxyPeer, - TLSConfig: serverTLSConfig, - ClusterDialer: clusterdial.NewClusterDialer(tsrv), - Log: process.logger, - ClusterName: clusterName, + AccessCache: accessPoint, + Listener: listeners.proxyPeer, + TLSConfig: serverTLSConfig, + Dialer: reversetunnelclient.NewPeerDialer(tsrv), + Log: process.logger, + ClusterName: clusterName, }) if err != nil { return trace.Wrap(err) @@ -6439,7 +6438,6 @@ func (process *TeleportProcess) initPublicGRPCServer( }) if err != nil { return nil, trace.Wrap(err) - } accessgraphsecretsv1pb.RegisterSecretsScannerServiceServer(server, accessGraphProxySvc) From 7eb2af1821d682f2fa7f3b5b7eba156123087cfc Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 12 Nov 2024 19:05:34 +0100 Subject: [PATCH 4/5] Move peer.clientConn to lib/proxy/peer/internal --- lib/proxy/peer/client.go | 81 +++++++++------------------ lib/proxy/peer/client_test.go | 3 +- lib/proxy/peer/internal/clientconn.go | 50 +++++++++++++++++ 3 files changed, 80 insertions(+), 54 deletions(-) create mode 100644 lib/proxy/peer/internal/clientconn.go diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index ed322c3307936..e4e425c93ad6a 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -42,6 +42,7 @@ import ( streamutils "github.com/gravitational/teleport/api/utils/grpc/stream" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/proxy/peer/internal" "github.com/gravitational/teleport/lib/services" ) @@ -86,11 +87,11 @@ type ClientConfig struct { } // connShuffler shuffles the order of client connections. -type connShuffler func([]clientConn) +type connShuffler func([]internal.ClientConn) // randomConnShuffler returns a conn shuffler that randomizes the order of connections. func randomConnShuffler() connShuffler { - return func(conns []clientConn) { + return func(conns []internal.ClientConn) { rand.Shuffle(len(conns), func(i, j int) { conns[i], conns[j] = conns[j], conns[i] }) @@ -99,7 +100,7 @@ func randomConnShuffler() connShuffler { // noopConnShutffler returns a conn shuffler that keeps the original connection ordering. func noopConnShuffler() connShuffler { - return func([]clientConn) {} + return func([]internal.ClientConn) {} } // checkAndSetDefaults checks and sets default values @@ -160,32 +161,6 @@ func (c *ClientConfig) checkAndSetDefaults() error { return nil } -// clientConn manages client connections to a specific peer proxy (with a fixed -// host ID and address). -type clientConn interface { - // peerID returns the host ID of the peer proxy. - peerID() string - // peerAddr returns the address of the peer proxy. - peerAddr() string - - // dial opens a connection of a given tunnel type to a node with the given - // ID through the peer proxy managed by the clientConn. - dial( - nodeID string, - src net.Addr, - dst net.Addr, - tunnelType types.TunnelType, - ) (net.Conn, error) - - // close closes all connections and releases any background resources - // immediately. - close() error - - // shutdown waits until all connections are closed or the context is done, - // then acts like close. - shutdown(context.Context) -} - // grpcClientConn manages client connections to a specific peer proxy over gRPC. type grpcClientConn struct { cc *grpc.ClientConn @@ -202,13 +177,13 @@ type grpcClientConn struct { count int } -var _ clientConn = (*grpcClientConn)(nil) +var _ internal.ClientConn = (*grpcClientConn)(nil) -// peerID implements [clientConn]. -func (c *grpcClientConn) peerID() string { return c.id } +// PeerID implements [internal.ClientConn]. +func (c *grpcClientConn) PeerID() string { return c.id } -// peerAddr implements [clientConn]. -func (c *grpcClientConn) peerAddr() string { return c.addr } +// PeerAddr implements [internal.ClientConn]. +func (c *grpcClientConn) PeerAddr() string { return c.addr } // maybeAcquire returns a non-nil release func if the grpcClientConn is // currently allowed to open connections; i.e., if it hasn't fully shut down. @@ -231,8 +206,8 @@ func (c *grpcClientConn) maybeAcquire() (release func()) { }) } -// shutdown implements [clientConn]. -func (c *grpcClientConn) shutdown(ctx context.Context) { +// Shutdown implements [internal.ClientConn]. +func (c *grpcClientConn) Shutdown(ctx context.Context) { defer c.cc.Close() c.mu.Lock() @@ -252,13 +227,13 @@ func (c *grpcClientConn) shutdown(ctx context.Context) { } } -// close implements [clientConn]. -func (c *grpcClientConn) close() error { +// Close implements [internal.ClientConn]. +func (c *grpcClientConn) Close() error { return c.cc.Close() } -// dial implements [clientConn]. -func (c *grpcClientConn) dial( +// Dial implements [internal.ClientConn]. +func (c *grpcClientConn) Dial( nodeID string, src net.Addr, dst net.Addr, @@ -332,7 +307,7 @@ type Client struct { cancel context.CancelFunc config ClientConfig - conns map[string]clientConn + conns map[string]internal.ClientConn metrics *clientMetrics reporter *reporter } @@ -357,7 +332,7 @@ func NewClient(config ClientConfig) (*Client, error) { config: config, ctx: closeContext, cancel: cancel, - conns: make(map[string]clientConn), + conns: make(map[string]internal.ClientConn), metrics: metrics, reporter: reporter, } @@ -449,7 +424,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { } var toDelete []string - toKeep := make(map[string]clientConn) + toKeep := make(map[string]internal.ClientConn) for id, conn := range c.conns { proxy, ok := toDial[id] @@ -460,7 +435,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { } // peer address changed - if conn.peerAddr() != proxy.GetPeerAddr() { + if conn.PeerAddr() != proxy.GetPeerAddr() { toDelete = append(toDelete, id) continue } @@ -498,7 +473,7 @@ func (c *Client) updateConnections(proxies []types.Server) error { for _, id := range toDelete { if conn, ok := c.conns[id]; ok { - go conn.shutdown(c.ctx) + go conn.Shutdown(c.ctx) } } c.conns = toKeep @@ -551,9 +526,9 @@ func (c *Client) Shutdown(ctx context.Context) { var wg sync.WaitGroup for _, conn := range c.conns { wg.Add(1) - go func(conn clientConn) { + go func(conn internal.ClientConn) { defer wg.Done() - conn.shutdown(ctx) + conn.Shutdown(ctx) }(conn) } wg.Wait() @@ -567,7 +542,7 @@ func (c *Client) Stop() error { var errs []error for _, conn := range c.conns { - if err := conn.close(); err != nil { + if err := conn.Close(); err != nil { errs = append(errs, err) } } @@ -622,7 +597,7 @@ func (c *Client) dial( var errs []error for _, clientConn := range conns { - conn, err := clientConn.dial(nodeID, src, dst, tunnelType) + conn, err := clientConn.Dial(nodeID, src, dst, tunnelType) if err != nil { errs = append(errs, trace.Wrap(err)) continue @@ -638,13 +613,13 @@ func (c *Client) dial( // otherwise. // The boolean returned in the second argument is intended for testing purposes, // to indicates whether the connection was cached or newly established. -func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { +func (c *Client) getConnections(proxyIDs []string) ([]internal.ClientConn, bool, error) { if len(proxyIDs) == 0 { return nil, false, trace.BadParameter("failed to dial: no proxy ids given") } ids := make(map[string]struct{}) - var conns []clientConn + var conns []internal.ClientConn // look for existing matching connections. c.RLock() @@ -701,7 +676,7 @@ func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { defer c.Unlock() for _, conn := range conns { - c.conns[conn.peerID()] = conn + c.conns[conn.PeerID()] = conn } c.config.connShuffler(conns) @@ -709,7 +684,7 @@ func (c *Client) getConnections(proxyIDs []string) ([]clientConn, bool, error) { } // connect dials a new connection to proxyAddr. -func (c *Client) connect(peerID string, peerAddr string) (clientConn, error) { +func (c *Client) connect(peerID string, peerAddr string) (internal.ClientConn, error) { tlsConfig, err := c.config.getConfigForServer() if err != nil { return nil, trace.Wrap(err, "Error updating client tls config") diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index c403b209ea03d..81c4d6d814437 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" clientapi "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/proxy/peer/internal" "github.com/gravitational/teleport/lib/utils" ) @@ -212,7 +213,7 @@ func TestBackupClient(t *testing.T) { require.True(t, dialCalled) } -func waitForGRPCConns(t *testing.T, conns map[string]clientConn, d time.Duration) { +func waitForGRPCConns(t *testing.T, conns map[string]internal.ClientConn, d time.Duration) { require.Eventually(t, func() bool { for _, conn := range conns { // panic if we hit a non-grpc client conn diff --git a/lib/proxy/peer/internal/clientconn.go b/lib/proxy/peer/internal/clientconn.go new file mode 100644 index 0000000000000..f44e64afd7b52 --- /dev/null +++ b/lib/proxy/peer/internal/clientconn.go @@ -0,0 +1,50 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package internal + +import ( + "context" + "net" + + "github.com/gravitational/teleport/api/types" +) + +// ClientConn manages client connections to a specific peer proxy (with a fixed +// host ID and address). +type ClientConn interface { + // PeerID returns the host ID of the peer proxy. + PeerID() string + // PeerAddr returns the address of the peer proxy. + PeerAddr() string + + // Dial opens a connection of a given tunnel type to a node with the given + // ID through the peer proxy managed by the clientConn. + Dial( + nodeID string, + src net.Addr, + dst net.Addr, + tunnelType types.TunnelType, + ) (net.Conn, error) + + // Close closes all connections and releases any background resources + // immediately. + Close() error + + // Shutdown waits until all connections are closed or the context is done, + // then acts like Close. + Shutdown(context.Context) +} From 63eee1c103e062f47e8bb5c16223a2f5c08fcdac Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Wed, 13 Nov 2024 18:46:14 +0100 Subject: [PATCH 5/5] Periodically check connectivity between peer proxies (#48838) --- api/client/proto/proxyservice.pb.go | 359 ++++++++++++++++-- .../legacy/client/proto/proxyservice.proto | 7 + lib/proxy/peer/client.go | 84 +++- lib/proxy/peer/client_test.go | 21 +- lib/proxy/peer/helpers_test.go | 1 + lib/proxy/peer/internal/clientconn.go | 3 + lib/proxy/peer/internal/metrics.go | 110 ++++++ lib/proxy/peer/service.go | 5 + 8 files changed, 543 insertions(+), 47 deletions(-) create mode 100644 lib/proxy/peer/internal/metrics.go diff --git a/api/client/proto/proxyservice.pb.go b/api/client/proto/proxyservice.pb.go index 4b088e27d928d..d2695c0784251 100644 --- a/api/client/proto/proxyservice.pb.go +++ b/api/client/proto/proxyservice.pb.go @@ -353,12 +353,92 @@ func (m *ConnectionEstablished) XXX_DiscardUnknown() { var xxx_messageInfo_ConnectionEstablished proto.InternalMessageInfo +type ProxyServicePingRequest struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ProxyServicePingRequest) Reset() { *m = ProxyServicePingRequest{} } +func (m *ProxyServicePingRequest) String() string { return proto.CompactTextString(m) } +func (*ProxyServicePingRequest) ProtoMessage() {} +func (*ProxyServicePingRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_b76fff22d4479739, []int{5} +} +func (m *ProxyServicePingRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *ProxyServicePingRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_ProxyServicePingRequest.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *ProxyServicePingRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ProxyServicePingRequest.Merge(m, src) +} +func (m *ProxyServicePingRequest) XXX_Size() int { + return m.Size() +} +func (m *ProxyServicePingRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ProxyServicePingRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ProxyServicePingRequest proto.InternalMessageInfo + +type ProxyServicePingResponse struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ProxyServicePingResponse) Reset() { *m = ProxyServicePingResponse{} } +func (m *ProxyServicePingResponse) String() string { return proto.CompactTextString(m) } +func (*ProxyServicePingResponse) ProtoMessage() {} +func (*ProxyServicePingResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_b76fff22d4479739, []int{6} +} +func (m *ProxyServicePingResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *ProxyServicePingResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_ProxyServicePingResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *ProxyServicePingResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ProxyServicePingResponse.Merge(m, src) +} +func (m *ProxyServicePingResponse) XXX_Size() int { + return m.Size() +} +func (m *ProxyServicePingResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ProxyServicePingResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ProxyServicePingResponse proto.InternalMessageInfo + func init() { proto.RegisterType((*Frame)(nil), "proto.Frame") proto.RegisterType((*DialRequest)(nil), "proto.DialRequest") proto.RegisterType((*NetAddr)(nil), "proto.NetAddr") proto.RegisterType((*Data)(nil), "proto.Data") proto.RegisterType((*ConnectionEstablished)(nil), "proto.ConnectionEstablished") + proto.RegisterType((*ProxyServicePingRequest)(nil), "proto.ProxyServicePingRequest") + proto.RegisterType((*ProxyServicePingResponse)(nil), "proto.ProxyServicePingResponse") } func init() { @@ -366,34 +446,36 @@ func init() { } var fileDescriptor_b76fff22d4479739 = []byte{ - // 422 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x52, 0x41, 0x6f, 0xd3, 0x30, - 0x14, 0xae, 0xa1, 0xed, 0xe8, 0x4b, 0xc5, 0xc1, 0x1a, 0x10, 0x4d, 0x53, 0x81, 0x1c, 0xd0, 0xc4, - 0x21, 0x99, 0x8a, 0x34, 0xa4, 0x9d, 0x20, 0x04, 0x54, 0x0e, 0x4c, 0xc8, 0xeb, 0x69, 0x37, 0x37, - 0x79, 0xca, 0x2c, 0xb2, 0x38, 0xd8, 0xce, 0x20, 0xbf, 0x8f, 0x0b, 0x47, 0xee, 0x48, 0x08, 0xf5, - 0x67, 0x70, 0x42, 0x71, 0x5c, 0x35, 0x95, 0x8a, 0xc4, 0x25, 0x7e, 0xcf, 0xef, 0x7b, 0x5f, 0xbe, - 0xf7, 0xf9, 0x41, 0x64, 0xb0, 0xc0, 0x4a, 0x2a, 0x13, 0x15, 0x98, 0xf3, 0xb4, 0x89, 0xd2, 0x42, - 0x60, 0x69, 0xa2, 0x4a, 0x49, 0x23, 0xdb, 0xef, 0xd7, 0x46, 0xa3, 0xba, 0x15, 0x29, 0x86, 0xf6, - 0x8a, 0x8e, 0xec, 0x71, 0x74, 0x98, 0xcb, 0x5c, 0x76, 0xa0, 0x36, 0xea, 0x8a, 0xc1, 0x37, 0x02, - 0xa3, 0x77, 0x8a, 0xdf, 0x20, 0x3d, 0x03, 0x2f, 0x11, 0xbc, 0x60, 0xf8, 0xb9, 0x46, 0x6d, 0x7c, - 0xf2, 0x84, 0x9c, 0x78, 0x73, 0xda, 0xc1, 0xc2, 0x5e, 0x65, 0x31, 0x60, 0x7d, 0x20, 0x5d, 0xc2, - 0x83, 0x37, 0xb2, 0x2c, 0x31, 0x35, 0x42, 0x96, 0x6f, 0xb5, 0xe1, 0xab, 0x42, 0xe8, 0x6b, 0xcc, - 0xfc, 0x3b, 0x96, 0xe1, 0xd8, 0x31, 0xec, 0xc5, 0x2c, 0x06, 0x6c, 0x7f, 0x33, 0x7d, 0x0a, 0xc3, - 0x84, 0x1b, 0xee, 0xdf, 0xb5, 0x24, 0xde, 0x46, 0x06, 0x37, 0x7c, 0x31, 0x60, 0xb6, 0x14, 0x4f, - 0xe0, 0xe0, 0x03, 0x6a, 0xcd, 0x73, 0x0c, 0x7e, 0x92, 0x1d, 0xf1, 0xf4, 0x21, 0x8c, 0x2f, 0x64, - 0x86, 0xef, 0x13, 0x3b, 0xc6, 0x84, 0xb9, 0x8c, 0x5e, 0x01, 0x2c, 0xeb, 0xb2, 0xc4, 0x62, 0xd9, - 0x54, 0x68, 0x05, 0x4e, 0xe2, 0xf3, 0x3f, 0xbf, 0x1e, 0x9f, 0xe5, 0xc2, 0x5c, 0xd7, 0xab, 0x30, - 0x95, 0x37, 0x51, 0xae, 0xf8, 0xad, 0x30, 0xbc, 0x15, 0xc4, 0x8b, 0xad, 0xd9, 0xbc, 0x12, 0x91, - 0x69, 0x2a, 0xd4, 0xe1, 0x96, 0x81, 0xf5, 0xd8, 0xe8, 0x33, 0x18, 0x5f, 0xca, 0x5a, 0xa5, 0xe8, - 0x34, 0xdf, 0x77, 0x9a, 0x2f, 0xd0, 0xbc, 0xce, 0x32, 0xc5, 0x5c, 0x95, 0x9e, 0x82, 0x97, 0xa0, - 0x36, 0xa2, 0xb4, 0xbf, 0xf0, 0x87, 0x7b, 0xc1, 0x7d, 0x48, 0xf0, 0x12, 0x0e, 0xdc, 0x3d, 0xf5, - 0x6d, 0xf8, 0x45, 0xaa, 0x4f, 0x6e, 0xb2, 0x4d, 0x4a, 0x29, 0x0c, 0x5b, 0x44, 0x37, 0x14, 0xb3, - 0x71, 0x70, 0xdc, 0x99, 0x48, 0x0f, 0x61, 0x14, 0x37, 0x06, 0xb5, 0xed, 0x99, 0xb2, 0x2e, 0x09, - 0x1e, 0xfd, 0xe3, 0xe1, 0xe6, 0xe7, 0x30, 0xfd, 0xd8, 0xae, 0xd1, 0x65, 0xb7, 0x46, 0xf4, 0x39, - 0xdc, 0x6b, 0xcd, 0x6d, 0x3d, 0xa4, 0x53, 0x27, 0xd4, 0xee, 0xcc, 0xd1, 0x4e, 0x76, 0x42, 0x4e, - 0x49, 0xfc, 0xea, 0xfb, 0x7a, 0x46, 0x7e, 0xac, 0x67, 0xe4, 0xf7, 0x7a, 0x46, 0xae, 0xe6, 0xff, - 0xe7, 0x6b, 0x7f, 0x83, 0x57, 0x63, 0x7b, 0xbc, 0xf8, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xe0, 0x08, - 0x45, 0x85, 0xe8, 0x02, 0x00, 0x00, + // 460 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x53, 0xbf, 0x6f, 0xd3, 0x40, + 0x14, 0xce, 0x41, 0x92, 0x92, 0x97, 0x88, 0xe1, 0xa9, 0x50, 0x13, 0x55, 0x29, 0x78, 0x40, 0x15, + 0x43, 0x5c, 0x05, 0xa9, 0x48, 0x4c, 0x10, 0x02, 0x84, 0x81, 0xaa, 0xba, 0x66, 0xea, 0x76, 0x71, + 0x9e, 0xdc, 0x13, 0xee, 0x9d, 0xf1, 0x5d, 0x0a, 0x9e, 0xf9, 0xd3, 0x58, 0x18, 0xd9, 0x91, 0x10, + 0xca, 0x9f, 0xc1, 0x84, 0x7c, 0xbe, 0xaa, 0x8e, 0x48, 0xa4, 0x2e, 0xb9, 0xf7, 0xee, 0xfb, 0xee, + 0xcb, 0xf7, 0x7e, 0x18, 0x22, 0x4b, 0x29, 0x65, 0x3a, 0xb7, 0x51, 0x4a, 0x89, 0x88, 0x8b, 0x28, + 0x4e, 0x25, 0x29, 0x1b, 0x65, 0xb9, 0xb6, 0xba, 0xfc, 0xfd, 0x5a, 0x18, 0xca, 0xaf, 0x64, 0x4c, + 0x43, 0x77, 0x85, 0x2d, 0x77, 0xf4, 0x77, 0x13, 0x9d, 0xe8, 0x8a, 0x54, 0x46, 0x15, 0x18, 0x7e, + 0x67, 0xd0, 0x7a, 0x97, 0x8b, 0x4b, 0xc2, 0x63, 0xe8, 0x4e, 0xa4, 0x48, 0x39, 0x7d, 0x5e, 0x92, + 0xb1, 0x01, 0x7b, 0xcc, 0x0e, 0xbb, 0x23, 0xac, 0x68, 0xc3, 0x1a, 0x32, 0x6d, 0xf0, 0x3a, 0x11, + 0x67, 0xf0, 0xe0, 0x8d, 0x56, 0x8a, 0x62, 0x2b, 0xb5, 0x7a, 0x6b, 0xac, 0x98, 0xa7, 0xd2, 0x5c, + 0xd0, 0x22, 0xb8, 0xe3, 0x14, 0xf6, 0xbd, 0xc2, 0x46, 0xce, 0xb4, 0xc1, 0x37, 0x3f, 0xc6, 0x27, + 0xd0, 0x9c, 0x08, 0x2b, 0x82, 0xbb, 0x4e, 0xa4, 0x7b, 0x6d, 0x43, 0x58, 0x31, 0x6d, 0x70, 0x07, + 0x8d, 0x3b, 0xb0, 0xf3, 0x91, 0x8c, 0x11, 0x09, 0x85, 0xbf, 0xd8, 0x9a, 0x79, 0x7c, 0x08, 0xed, + 0x13, 0xbd, 0xa0, 0x0f, 0x13, 0x57, 0x46, 0x87, 0xfb, 0x0c, 0xcf, 0x01, 0x66, 0x4b, 0xa5, 0x28, + 0x9d, 0x15, 0x19, 0x39, 0x83, 0x9d, 0xf1, 0xcb, 0xbf, 0xbf, 0x0f, 0x8e, 0x13, 0x69, 0x2f, 0x96, + 0xf3, 0x61, 0xac, 0x2f, 0xa3, 0x24, 0x17, 0x57, 0xd2, 0x8a, 0xd2, 0x90, 0x48, 0x6f, 0x9a, 0x2d, + 0x32, 0x19, 0xd9, 0x22, 0x23, 0x33, 0xbc, 0x51, 0xe0, 0x35, 0x35, 0x7c, 0x0a, 0xed, 0x33, 0xbd, + 0xcc, 0x63, 0xf2, 0x9e, 0xef, 0x7b, 0xcf, 0x27, 0x64, 0x5f, 0x2f, 0x16, 0x39, 0xf7, 0x28, 0x1e, + 0x41, 0x77, 0x42, 0xc6, 0x4a, 0xe5, 0xfe, 0x22, 0x68, 0x6e, 0x24, 0xd7, 0x29, 0xe1, 0x0b, 0xd8, + 0xf1, 0xf7, 0x18, 0xb8, 0xf0, 0x8b, 0xce, 0x3f, 0xf9, 0xca, 0xae, 0x53, 0x44, 0x68, 0x96, 0x8c, + 0xaa, 0x28, 0xee, 0xe2, 0x70, 0xbf, 0x6a, 0x22, 0xee, 0x42, 0x6b, 0x5c, 0x58, 0x32, 0xee, 0x4d, + 0x8f, 0x57, 0x49, 0xb8, 0xb7, 0x65, 0x70, 0xe1, 0x23, 0xd8, 0x3b, 0x2d, 0xd7, 0xe8, 0xac, 0x5a, + 0xa3, 0x53, 0xa9, 0x12, 0xdf, 0xd8, 0xb0, 0x0f, 0xc1, 0xff, 0x90, 0xc9, 0xb4, 0x32, 0x34, 0xfa, + 0xc6, 0xa0, 0x57, 0x07, 0xf1, 0x19, 0xdc, 0x2b, 0x87, 0x52, 0xf6, 0x1e, 0x7b, 0xbe, 0x40, 0xb7, + 0x6b, 0xfd, 0xb5, 0xec, 0x90, 0x1d, 0x31, 0x7c, 0x0f, 0xcd, 0x52, 0x0c, 0x07, 0x1e, 0xd9, 0x62, + 0xa0, 0x7f, 0xb0, 0x15, 0xaf, 0x5c, 0x8c, 0x5f, 0xfd, 0x58, 0x0d, 0xd8, 0xcf, 0xd5, 0x80, 0xfd, + 0x59, 0x0d, 0xd8, 0xf9, 0xe8, 0x76, 0x83, 0xad, 0x7f, 0x42, 0xf3, 0xb6, 0x3b, 0x9e, 0xff, 0x0b, + 0x00, 0x00, 0xff, 0xff, 0xf7, 0xa9, 0x04, 0x60, 0x69, 0x03, 0x00, 0x00, } // Reference imports to suppress errors if they are not otherwise used. @@ -410,6 +492,8 @@ const _ = grpc.SupportPackageIsVersion4 type ProxyServiceClient interface { // DialNode opens a bidrectional stream to the requested node. DialNode(ctx context.Context, opts ...grpc.CallOption) (ProxyService_DialNodeClient, error) + // Ping checks if the peer is reachable and responsive. + Ping(ctx context.Context, in *ProxyServicePingRequest, opts ...grpc.CallOption) (*ProxyServicePingResponse, error) } type proxyServiceClient struct { @@ -451,10 +535,21 @@ func (x *proxyServiceDialNodeClient) Recv() (*Frame, error) { return m, nil } +func (c *proxyServiceClient) Ping(ctx context.Context, in *ProxyServicePingRequest, opts ...grpc.CallOption) (*ProxyServicePingResponse, error) { + out := new(ProxyServicePingResponse) + err := c.cc.Invoke(ctx, "/proto.ProxyService/Ping", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // ProxyServiceServer is the server API for ProxyService service. type ProxyServiceServer interface { // DialNode opens a bidrectional stream to the requested node. DialNode(ProxyService_DialNodeServer) error + // Ping checks if the peer is reachable and responsive. + Ping(context.Context, *ProxyServicePingRequest) (*ProxyServicePingResponse, error) } // UnimplementedProxyServiceServer can be embedded to have forward compatible implementations. @@ -464,6 +559,9 @@ type UnimplementedProxyServiceServer struct { func (*UnimplementedProxyServiceServer) DialNode(srv ProxyService_DialNodeServer) error { return status.Errorf(codes.Unimplemented, "method DialNode not implemented") } +func (*UnimplementedProxyServiceServer) Ping(ctx context.Context, req *ProxyServicePingRequest) (*ProxyServicePingResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Ping not implemented") +} func RegisterProxyServiceServer(s *grpc.Server, srv ProxyServiceServer) { s.RegisterService(&_ProxyService_serviceDesc, srv) @@ -495,10 +593,33 @@ func (x *proxyServiceDialNodeServer) Recv() (*Frame, error) { return m, nil } +func _ProxyService_Ping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ProxyServicePingRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ProxyServiceServer).Ping(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.ProxyService/Ping", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ProxyServiceServer).Ping(ctx, req.(*ProxyServicePingRequest)) + } + return interceptor(ctx, in, info, handler) +} + var _ProxyService_serviceDesc = grpc.ServiceDesc{ ServiceName: "proto.ProxyService", HandlerType: (*ProxyServiceServer)(nil), - Methods: []grpc.MethodDesc{}, + Methods: []grpc.MethodDesc{ + { + MethodName: "Ping", + Handler: _ProxyService_Ping_Handler, + }, + }, Streams: []grpc.StreamDesc{ { StreamName: "DialNode", @@ -776,6 +897,60 @@ func (m *ConnectionEstablished) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *ProxyServicePingRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ProxyServicePingRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *ProxyServicePingRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + return len(dAtA) - i, nil +} + +func (m *ProxyServicePingResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ProxyServicePingResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *ProxyServicePingResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + return len(dAtA) - i, nil +} + func encodeVarintProxyservice(dAtA []byte, offset int, v uint64) int { offset -= sovProxyservice(v) base := offset @@ -914,6 +1089,30 @@ func (m *ConnectionEstablished) Size() (n int) { return n } +func (m *ProxyServicePingRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *ProxyServicePingResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + func sovProxyservice(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -1514,6 +1713,108 @@ func (m *ConnectionEstablished) Unmarshal(dAtA []byte) error { } return nil } +func (m *ProxyServicePingRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProxyservice + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ProxyServicePingRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ProxyServicePingRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := skipProxyservice(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthProxyservice + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ProxyServicePingResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowProxyservice + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ProxyServicePingResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ProxyServicePingResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := skipProxyservice(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthProxyservice + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipProxyservice(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/api/proto/teleport/legacy/client/proto/proxyservice.proto b/api/proto/teleport/legacy/client/proto/proxyservice.proto index 0cabcc6f2c8e8..30940ef7c932e 100644 --- a/api/proto/teleport/legacy/client/proto/proxyservice.proto +++ b/api/proto/teleport/legacy/client/proto/proxyservice.proto @@ -24,6 +24,9 @@ option go_package = "github.com/gravitational/teleport/api/client/proto"; service ProxyService { // DialNode opens a bidrectional stream to the requested node. rpc DialNode(stream Frame) returns (stream Frame); + + // Ping checks if the peer is reachable and responsive. + rpc Ping(ProxyServicePingRequest) returns (ProxyServicePingResponse); } // Frame wraps different message types to be sent over a stream. @@ -63,3 +66,7 @@ message Data { // ConnectionEstablished signals to the client a connection to the node has been established. message ConnectionEstablished {} + +message ProxyServicePingRequest {} + +message ProxyServicePingResponse {} diff --git a/lib/proxy/peer/client.go b/lib/proxy/peer/client.go index e4e425c93ad6a..2f52d31edf6da 100644 --- a/lib/proxy/peer/client.go +++ b/lib/proxy/peer/client.go @@ -166,8 +166,10 @@ type grpcClientConn struct { cc *grpc.ClientConn metrics *clientMetrics - id string - addr string + id string + addr string + host string + group string // if closing is set, count is not allowed to increase from zero; upon // reaching zero, cond should be broadcast @@ -175,6 +177,8 @@ type grpcClientConn struct { cond sync.Cond closing bool count int + + pingCancel context.CancelFunc } var _ internal.ClientConn = (*grpcClientConn)(nil) @@ -208,7 +212,7 @@ func (c *grpcClientConn) maybeAcquire() (release func()) { // Shutdown implements [internal.ClientConn]. func (c *grpcClientConn) Shutdown(ctx context.Context) { - defer c.cc.Close() + defer c.Close() c.mu.Lock() defer c.mu.Unlock() @@ -229,9 +233,25 @@ func (c *grpcClientConn) Shutdown(ctx context.Context) { // Close implements [internal.ClientConn]. func (c *grpcClientConn) Close() error { + c.pingCancel() return c.cc.Close() } +// Ping implements [internal.ClientConn]. +func (c *grpcClientConn) Ping(ctx context.Context) error { + release := c.maybeAcquire() + if release == nil { + return trace.ConnectionProblem(nil, "error starting stream: connection is shutting down") + } + defer release() + + _, err := clientapi.NewProxyServiceClient(c.cc).Ping(ctx, new(clientapi.ProxyServicePingRequest)) + if trace.IsNotImplemented(err) { + err = nil + } + return trace.Wrap(err) +} + // Dial implements [internal.ClientConn]. func (c *grpcClientConn) Dial( nodeID string, @@ -456,7 +476,13 @@ func (c *Client) updateConnections(proxies []types.Server) error { } // establish new connections - conn, err := c.connect(id, proxy.GetPeerAddr()) + proxyGroup, _ := proxy.GetLabel(types.ProxyGroupIDLabel) + conn, err := c.connect(connectParams{ + peerID: id, + peerAddr: proxy.GetPeerAddr(), + peerHost: proxy.GetHostname(), + peerGroup: proxyGroup, + }) if err != nil { c.metrics.reportTunnelError(errorProxyPeerTunnelDial) c.config.Log.Debugf("Error dialing peer proxy %+v at %+v", id, proxy.GetPeerAddr()) @@ -656,7 +682,13 @@ func (c *Client) getConnections(proxyIDs []string) ([]internal.ClientConn, bool, continue } - conn, err := c.connect(id, proxy.GetPeerAddr()) + proxyGroup, _ := proxy.GetLabel(types.ProxyGroupIDLabel) + conn, err := c.connect(connectParams{ + peerID: id, + peerAddr: proxy.GetPeerAddr(), + peerHost: proxy.GetHostname(), + peerGroup: proxyGroup, + }) if err != nil { c.metrics.reportTunnelError(errorProxyPeerTunnelDirectDial) c.config.Log.Debugf("Error direct dialing peer proxy %+v at %+v", id, proxy.GetPeerAddr()) @@ -683,18 +715,25 @@ func (c *Client) getConnections(proxyIDs []string) ([]internal.ClientConn, bool, return conns, false, nil } -// connect dials a new connection to proxyAddr. -func (c *Client) connect(peerID string, peerAddr string) (internal.ClientConn, error) { +type connectParams struct { + peerID string + peerAddr string + peerHost string + peerGroup string +} + +// connect dials a new connection to a peer proxy with the given ID and address. +func (c *Client) connect(params connectParams) (internal.ClientConn, error) { tlsConfig, err := c.config.getConfigForServer() if err != nil { return nil, trace.Wrap(err, "Error updating client tls config") } - expectedPeer := authclient.HostFQDN(peerID, c.config.ClusterName) + expectedPeer := authclient.HostFQDN(params.peerID, c.config.ClusterName) conn, err := grpc.Dial( - peerAddr, - grpc.WithTransportCredentials(newClientCredentials(expectedPeer, peerAddr, c.config.Log, credentials.NewTLS(tlsConfig))), + params.peerAddr, + grpc.WithTransportCredentials(newClientCredentials(expectedPeer, params.peerAddr, c.config.Log, credentials.NewTLS(tlsConfig))), grpc.WithStatsHandler(newStatsHandler(c.reporter)), grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor), grpc.WithKeepaliveParams(keepalive.ClientParameters{ @@ -705,14 +744,29 @@ func (c *Client) connect(peerID string, peerAddr string) (internal.ClientConn, e grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`), ) if err != nil { - return nil, trace.Wrap(err, "Error dialing proxy %q", peerID) + return nil, trace.Wrap(err, "Error dialing proxy %q", params.peerID) } - return &grpcClientConn{ + pingCtx, pingCancel := context.WithCancel(context.Background()) + cc := &grpcClientConn{ cc: conn, metrics: c.metrics, - id: peerID, - addr: peerAddr, - }, nil + id: params.peerID, + addr: params.peerAddr, + host: params.peerHost, + group: params.peerGroup, + + pingCancel: pingCancel, + } + + pings, pingFailures := internal.ClientPingsMetrics(internal.ClientPingsMetricsParams{ + LocalID: c.config.ID, + PeerID: params.peerID, + PeerHost: params.peerHost, + PeerGroup: params.peerGroup, + }) + go internal.RunClientPing(pingCtx, cc, pings, pingFailures) + + return cc, nil } diff --git a/lib/proxy/peer/client_test.go b/lib/proxy/peer/client_test.go index 81c4d6d814437..15c6228283570 100644 --- a/lib/proxy/peer/client_test.go +++ b/lib/proxy/peer/client_test.go @@ -144,7 +144,12 @@ func TestCAChange(t *testing.T) { server, _ := setupServer(t, "s1", serverCA, clientCA, types.RoleProxy) // dial server and send a test data frame - conn, err := client.connect("s1", server.config.Listener.Addr().String()) + conn, err := client.connect(connectParams{ + peerID: "s1", + peerAddr: server.config.Listener.Addr().String(), + peerHost: "s1", + peerGroup: "", + }) require.NoError(t, err) require.NotNil(t, conn) require.IsType(t, (*grpcClientConn)(nil), conn) @@ -161,7 +166,12 @@ func TestCAChange(t *testing.T) { // new connection should fail because client tls config still references old // RootCAs. - conn, err = client.connect("s1", server.config.Listener.Addr().String()) + conn, err = client.connect(connectParams{ + peerID: "s1", + peerAddr: server.config.Listener.Addr().String(), + peerHost: "s1", + peerGroup: "", + }) require.NoError(t, err) require.NotNil(t, conn) require.IsType(t, (*grpcClientConn)(nil), conn) @@ -179,7 +189,12 @@ func TestCAChange(t *testing.T) { return config, nil } - conn, err = client.connect("s1", server.config.Listener.Addr().String()) + conn, err = client.connect(connectParams{ + peerID: "s1", + peerAddr: server.config.Listener.Addr().String(), + peerHost: "s1", + peerGroup: "", + }) require.NoError(t, err) require.NotNil(t, conn) require.IsType(t, (*grpcClientConn)(nil), conn) diff --git a/lib/proxy/peer/helpers_test.go b/lib/proxy/peer/helpers_test.go index 2c7e34ef439e3..a5fd995f59835 100644 --- a/lib/proxy/peer/helpers_test.go +++ b/lib/proxy/peer/helpers_test.go @@ -59,6 +59,7 @@ type mockProxyAccessPoint struct { } type mockProxyService struct { + clientapi.UnimplementedProxyServiceServer mockDialNode func(stream clientapi.ProxyService_DialNodeServer) error } diff --git a/lib/proxy/peer/internal/clientconn.go b/lib/proxy/peer/internal/clientconn.go index f44e64afd7b52..b4fa423f07c03 100644 --- a/lib/proxy/peer/internal/clientconn.go +++ b/lib/proxy/peer/internal/clientconn.go @@ -40,6 +40,9 @@ type ClientConn interface { tunnelType types.TunnelType, ) (net.Conn, error) + // Ping checks if the peer is reachable and responsive. + Ping(context.Context) error + // Close closes all connections and releases any background resources // immediately. Close() error diff --git a/lib/proxy/peer/internal/metrics.go b/lib/proxy/peer/internal/metrics.go new file mode 100644 index 0000000000000..885f095684c77 --- /dev/null +++ b/lib/proxy/peer/internal/metrics.go @@ -0,0 +1,110 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package internal + +import ( + "context" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/interval" +) + +var ( + clientPingInitOnce sync.Once + + clientPingsTotal *prometheus.CounterVec + clientFailedPingsTotal *prometheus.CounterVec +) + +func clientPingInit() { + clientPingsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: teleport.MetricNamespace, + Subsystem: "proxy_peer_client", + Name: "pings_total", + Help: "Total number of proxy peering client pings per peer proxy, both successful and failed.", + }, []string{"local_id", "peer_id", "peer_host", "peer_group"}) + + clientFailedPingsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: teleport.MetricNamespace, + Subsystem: "proxy_peer_client", + Name: "failed_pings_total", + Help: "Total number of failed proxy peering client pings per peer proxy.", + }, []string{"local_id", "peer_id", "peer_host", "peer_group"}) +} + +// ClientPingsMetricsParams contains the parameters for [ClientPingsMetrics]. +type ClientPingsMetricsParams struct { + // LocalID is the host ID of the current proxy. + LocalID string + // PeerID is the host ID of the peer proxy. + PeerID string + // PeerHost is the hostname of the peer proxy. + PeerHost string + // PeerGroup is the peer group ID of the peer proxy. Can be blank. + PeerGroup string +} + +// ClientPingsMetrics returns the Prometheus metrics for a given peer proxy, +// given host ID, hostname and (optionally) peer group. +func ClientPingsMetrics(params ClientPingsMetricsParams) (pings, failedPings prometheus.Counter) { + clientPingInitOnce.Do(clientPingInit) + + pings = clientPingsTotal.WithLabelValues(params.LocalID, params.PeerID, params.PeerHost, params.PeerGroup) + failedPings = clientFailedPingsTotal.WithLabelValues(params.LocalID, params.PeerID, params.PeerHost, params.PeerGroup) + + return pings, failedPings +} + +// RunClientPing periodically pings the peer proxy reachable through the given +// [ClientConn], accumulating counts in the given Prometheus metrics. Returns +// when the context is canceled. +func RunClientPing(ctx context.Context, cc ClientConn, pings, failedPings prometheus.Counter) { + const pingInterval = time.Minute + ivl := interval.New(interval.Config{ + Duration: pingInterval * 14 / 13, + FirstDuration: utils.HalfJitter(pingInterval), + Jitter: retryutils.NewSeventhJitter(), + }) + defer ivl.Stop() + + for ctx.Err() == nil { + select { + case <-ivl.Next(): + func() { + timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + err := cc.Ping(timeoutCtx) + if err != nil { + if ctx.Err() != nil { + return + } + failedPings.Inc() + } + pings.Inc() + }() + case <-ctx.Done(): + } + } +} diff --git a/lib/proxy/peer/service.go b/lib/proxy/peer/service.go index 772399074051f..bb8bfe3c7af3d 100644 --- a/lib/proxy/peer/service.go +++ b/lib/proxy/peer/service.go @@ -19,6 +19,7 @@ package peer import ( + "context" "log/slog" "strings" @@ -106,6 +107,10 @@ func (s *proxyService) DialNode(stream proto.ProxyService_DialNodeServer) error return trace.Wrap(err) } +func (s *proxyService) Ping(ctx context.Context, _ *proto.ProxyServicePingRequest) (*proto.ProxyServicePingResponse, error) { + return new(proto.ProxyServicePingResponse), nil +} + // splitServerID splits a server id in to a node id and cluster name. func splitServerID(address string) (string, string, error) { split := strings.Split(address, ".")