Skip to content

Commit

Permalink
Periodically check connectivity between peer proxies
Browse files Browse the repository at this point in the history
  • Loading branch information
espadolini authored and github-actions committed Nov 13, 2024
1 parent c761ad1 commit d910bfd
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 19 deletions.
74 changes: 58 additions & 16 deletions lib/proxy/peer/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,19 @@ type grpcClientConn struct {
cc *grpc.ClientConn
metrics *clientMetrics

id string
addr string
id string
addr string
host string
group string

// if closing is set, count is not allowed to increase from zero; upon
// reaching zero, cond should be broadcast
mu sync.Mutex
cond sync.Cond
closing bool
count int

pingCancel context.CancelFunc
}

var _ internal.ClientConn = (*grpcClientConn)(nil)
Expand Down Expand Up @@ -211,7 +215,7 @@ func (c *grpcClientConn) maybeAcquire() (release func()) {

// Shutdown implements [internal.ClientConn].
func (c *grpcClientConn) Shutdown(ctx context.Context) {
defer c.cc.Close()
defer c.Close()

c.mu.Lock()
defer c.mu.Unlock()
Expand All @@ -232,6 +236,7 @@ func (c *grpcClientConn) Shutdown(ctx context.Context) {

// Close implements [internal.ClientConn].
func (c *grpcClientConn) Close() error {
c.pingCancel()
return c.cc.Close()
}

Expand Down Expand Up @@ -476,7 +481,14 @@ func (c *Client) updateConnections(proxies []types.Server) error {

// establish new connections
supportsQUIC, _ := proxy.GetLabel(types.UnstableProxyPeerQUICLabel)
conn, err := c.connect(id, proxy.GetPeerAddr(), supportsQUIC == "yes")
proxyGroup, _ := proxy.GetLabel(types.ProxyGroupIDLabel)
conn, err := c.connect(connectParams{
peerID: id,
peerAddr: proxy.GetPeerAddr(),
peerHost: proxy.GetHostname(),
peerGroup: proxyGroup,
supportsQUIC: supportsQUIC == "yes",
})
if err != nil {
c.metrics.reportTunnelError(errorProxyPeerTunnelDial)
c.config.Log.DebugContext(c.ctx, "error dialing peer proxy", "peer_id", id, "peer_addr", proxy.GetPeerAddr())
Expand Down Expand Up @@ -677,7 +689,14 @@ func (c *Client) getConnections(proxyIDs []string) ([]internal.ClientConn, bool,
}

supportsQUIC, _ := proxy.GetLabel(types.UnstableProxyPeerQUICLabel)
conn, err := c.connect(id, proxy.GetPeerAddr(), supportsQUIC == "yes")
proxyGroup, _ := proxy.GetLabel(types.ProxyGroupIDLabel)
conn, err := c.connect(connectParams{
peerID: id,
peerAddr: proxy.GetPeerAddr(),
peerHost: proxy.GetHostname(),
peerGroup: proxyGroup,
supportsQUIC: supportsQUIC == "yes",
})
if err != nil {
c.metrics.reportTunnelError(errorProxyPeerTunnelDirectDial)
c.config.Log.DebugContext(c.ctx, "error direct dialing peer proxy", "peer_id", id, "peer_addr", proxy.GetPeerAddr())
Expand All @@ -704,9 +723,17 @@ func (c *Client) getConnections(proxyIDs []string) ([]internal.ClientConn, bool,
return conns, false, nil
}

// connect dials a new connection to proxyAddr.
func (c *Client) connect(peerID string, peerAddr string, supportsQUIC bool) (internal.ClientConn, error) {
if supportsQUIC && c.config.QUICTransport != nil {
type connectParams struct {
peerID string
peerAddr string
peerHost string
peerGroup string
supportsQUIC bool
}

// connect dials a new connection to a peer proxy with the given ID and address.
func (c *Client) connect(params connectParams) (internal.ClientConn, error) {
if params.supportsQUIC && c.config.QUICTransport != nil {
panic("QUIC proxy peering is not implemented")
}
tlsConfig := utils.TLSConfig(c.config.TLSCipherSuites)
Expand All @@ -721,11 +748,11 @@ func (c *Client) connect(peerID string, peerAddr string, supportsQUIC bool) (int
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyConnection = utils.VerifyConnectionWithRoots(c.config.GetTLSRoots)

expectedPeer := authclient.HostFQDN(peerID, c.config.ClusterName)
expectedPeer := authclient.HostFQDN(params.peerID, c.config.ClusterName)

conn, err := grpc.Dial(
peerAddr,
grpc.WithTransportCredentials(newClientCredentials(expectedPeer, peerAddr, c.config.Log, credentials.NewTLS(tlsConfig))),
params.peerAddr,
grpc.WithTransportCredentials(newClientCredentials(expectedPeer, params.peerAddr, c.config.Log, credentials.NewTLS(tlsConfig))),
grpc.WithStatsHandler(newStatsHandler(c.reporter)),
grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Expand All @@ -736,14 +763,29 @@ func (c *Client) connect(peerID string, peerAddr string, supportsQUIC bool) (int
grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`),
)
if err != nil {
return nil, trace.Wrap(err, "Error dialing proxy %q", peerID)
return nil, trace.Wrap(err, "Error dialing proxy %q", params.peerID)
}

return &grpcClientConn{
pingCtx, pingCancel := context.WithCancel(context.Background())
cc := &grpcClientConn{
cc: conn,
metrics: c.metrics,

id: peerID,
addr: peerAddr,
}, nil
id: params.peerID,
addr: params.peerAddr,
host: params.peerHost,
group: params.peerGroup,

pingCancel: pingCancel,
}

pings, pingFailures := internal.ClientPingsMetrics(internal.ClientPingsMetricsParams{
LocalID: c.config.ID,
PeerID: params.peerID,
PeerHost: params.peerHost,
PeerGroup: params.peerGroup,
})
go internal.RunClientPing(pingCtx, cc, pings, pingFailures)

return cc, nil
}
24 changes: 21 additions & 3 deletions lib/proxy/peer/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,13 @@ func TestCAChange(t *testing.T) {

// dial server and send a test data frame
const supportsQUICFalse = false
conn, err := client.connect("s1", ts.GetPeerAddr(), supportsQUICFalse)
conn, err := client.connect(connectParams{
peerID: "s1",
peerAddr: ts.GetPeerAddr(),
peerHost: "s1",
peerGroup: "",
supportsQUIC: supportsQUICFalse,
})
require.NoError(t, err)
require.NotNil(t, conn)
require.IsType(t, (*grpcClientConn)(nil), conn)
Expand All @@ -163,7 +169,13 @@ func TestCAChange(t *testing.T) {

// new connection should fail because client tls config still references old
// RootCAs.
conn, err = client.connect("s1", ts.GetPeerAddr(), supportsQUICFalse)
conn, err = client.connect(connectParams{
peerID: "s1",
peerAddr: ts.GetPeerAddr(),
peerHost: "s1",
peerGroup: "",
supportsQUIC: supportsQUICFalse,
})
require.NoError(t, err)
require.NotNil(t, conn)
require.IsType(t, (*grpcClientConn)(nil), conn)
Expand All @@ -175,7 +187,13 @@ func TestCAChange(t *testing.T) {
// RootCAs.
currentServerCA.Store(newServerCA)

conn, err = client.connect("s1", ts.GetPeerAddr(), supportsQUICFalse)
conn, err = client.connect(connectParams{
peerID: "s1",
peerAddr: ts.GetPeerAddr(),
peerHost: "s1",
peerGroup: "",
supportsQUIC: supportsQUICFalse,
})
require.NoError(t, err)
require.NotNil(t, conn)
require.IsType(t, (*grpcClientConn)(nil), conn)
Expand Down
110 changes: 110 additions & 0 deletions lib/proxy/peer/internal/metrics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package internal

import (
"context"
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/interval"
)

var (
clientPingInitOnce sync.Once

clientPingsTotal *prometheus.CounterVec
clientFailedPingsTotal *prometheus.CounterVec
)

func clientPingInit() {
clientPingsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Subsystem: "proxy_peer_client",
Name: "pings_total",
Help: "Total number of proxy peering client pings per peer proxy, both successful and failed.",
}, []string{"local_id", "peer_id", "peer_host", "peer_group"})

clientFailedPingsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Subsystem: "proxy_peer_client",
Name: "failed_pings_total",
Help: "Total number of failed proxy peering client pings per peer proxy.",
}, []string{"local_id", "peer_id", "peer_host", "peer_group"})
}

// ClientPingsMetricsParams contains the parameters for [ClientPingsMetrics].
type ClientPingsMetricsParams struct {
// LocalID is the host ID of the current proxy.
LocalID string
// PeerID is the host ID of the peer proxy.
PeerID string
// PeerHost is the hostname of the peer proxy.
PeerHost string
// PeerGroup is the peer group ID of the peer proxy. Can be blank.
PeerGroup string
}

// ClientPingsMetrics returns the Prometheus metrics for a given peer proxy,
// given host ID, hostname and (optionally) peer group.
func ClientPingsMetrics(params ClientPingsMetricsParams) (pings, failedPings prometheus.Counter) {
clientPingInitOnce.Do(clientPingInit)

pings = clientPingsTotal.WithLabelValues(params.LocalID, params.PeerID, params.PeerHost, params.PeerGroup)
failedPings = clientFailedPingsTotal.WithLabelValues(params.LocalID, params.PeerID, params.PeerHost, params.PeerGroup)

return pings, failedPings
}

// RunClientPing periodically pings the peer proxy reachable through the given
// [ClientConn], accumulating counts in the given Prometheus metrics. Returns
// when the context is canceled.
func RunClientPing(ctx context.Context, cc ClientConn, pings, failedPings prometheus.Counter) {
const pingInterval = time.Minute
ivl := interval.New(interval.Config{
Duration: pingInterval * 14 / 13,
FirstDuration: utils.HalfJitter(pingInterval),
Jitter: retryutils.NewSeventhJitter(),
})
defer ivl.Stop()

for ctx.Err() == nil {
select {
case <-ivl.Next():
func() {
timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

err := cc.Ping(timeoutCtx)
if err != nil {
if ctx.Err() != nil {
return
}
failedPings.Inc()
}
pings.Inc()
}()
case <-ctx.Done():
}
}
}

0 comments on commit d910bfd

Please sign in to comment.