Generalize resource watchers (#47561)
Consolidate resource watchers into a single watcher that leverages
generics. While most of the resource watchers were similar, some
resources have one-off functionality; those watchers were left
untouched, but every watcher that could easily be refactored to use
the generic watcher was.
rosstimothy authored Oct 28, 2024
1 parent ac2f8dc commit 24e8b68
Showing 27 changed files with 1,042 additions and 1,362 deletions.
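
For orientation before the per-file diffs: the call sites in this commit suggest the consolidated watcher exposes roughly the surface below. This is an illustrative sketch inferred from usage, not the literal declaration in lib/services; in particular, the real GenericWatcher is a struct whose ResourcesC channel replaces the per-type channels (KubeClustersC, ProxiesC).

```go
package sketch

import "context"

// GenericWatcher is a hypothetical reconstruction of the new watcher's
// API: T is the concrete resource type (e.g. types.Server) and R its
// read-only view (e.g. readonly.Server). The real implementation in
// lib/services is a struct that also exposes a ResourcesC chan []T as
// well as WaitInitialization, Done, and Close.
type GenericWatcher[T any, R any] interface {
	// CurrentResources returns the currently cached resource set.
	// Unlike the old per-type getters, it can return an error.
	CurrentResources(ctx context.Context) ([]T, error)
	// CurrentResourcesWithFilter returns only the resources whose
	// read-only view satisfies fn.
	CurrentResourcesWithFilter(ctx context.Context, fn func(R) bool) ([]T, error)
}
```
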
13 changes: 9 additions & 4 deletions lib/kube/proxy/server.go
@@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/srv/ingress"
)
@@ -98,7 +99,7 @@ type TLSServerConfig struct {
// kubernetes cluster name. Proxy uses this map to route requests to the correct
// kubernetes_service. The servers are kept in memory to avoid making unnecessary
// unmarshal calls followed by filtering and to improve memory usage.
KubernetesServersWatcher *services.KubeServerWatcher
KubernetesServersWatcher *services.GenericWatcher[types.KubeServer, readonly.KubeServer]
// PROXYProtocolMode controls behavior related to unsigned PROXY protocol headers.
PROXYProtocolMode multiplexer.PROXYProtocolMode
// InventoryHandle is used to send kube server heartbeats via the inventory control stream.
@@ -170,7 +171,7 @@ type TLSServer struct {
closeContext context.Context
closeFunc context.CancelFunc
// kubeClusterWatcher monitors changes to kube cluster resources.
kubeClusterWatcher *services.KubeClusterWatcher
kubeClusterWatcher *services.GenericWatcher[types.KubeCluster, readonly.KubeCluster]
// reconciler reconciles proxied kube clusters with kube_clusters resources.
reconciler *services.Reconciler[types.KubeCluster]
// monitoredKubeClusters contains all kube clusters the proxied kube_clusters are
@@ -620,7 +621,9 @@ func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNa
}, nil
case ProxyService:
return func(ctx context.Context, name string) ([]types.KubeServer, error) {
servers, err := t.KubernetesServersWatcher.GetKubeServersByClusterName(ctx, name)
servers, err := t.KubernetesServersWatcher.CurrentResourcesWithFilter(ctx, func(ks readonly.KubeServer) bool {
return ks.GetCluster().GetName() == name
})
return servers, trace.Wrap(err)
}, nil
case LegacyProxyService:
@@ -630,7 +633,9 @@
// and forward the request to the next proxy.
kube, err := t.getKubeClusterWithServiceLabels(name)
if err != nil {
servers, err := t.KubernetesServersWatcher.GetKubeServersByClusterName(ctx, name)
servers, err := t.KubernetesServersWatcher.CurrentResourcesWithFilter(ctx, func(ks readonly.KubeServer) bool {
return ks.GetCluster().GetName() == name
})
return servers, trace.Wrap(err)
}
srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID)
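
The deleted GetKubeServersByClusterName helper maps onto the generic API as a predicate over the read-only view. A minimal sketch of the replacement pattern, using an invented helper name but the exact filter from the hunks above:

```go
package sketch

import (
	"context"

	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/services"
	"github.com/gravitational/teleport/lib/services/readonly"
)

// kubeServersByClusterName reproduces the old per-type lookup on top of
// the generic watcher; only servers whose kube cluster name matches
// survive the filter.
func kubeServersByClusterName(ctx context.Context, w *services.GenericWatcher[types.KubeServer, readonly.KubeServer], name string) ([]types.KubeServer, error) {
	servers, err := w.CurrentResourcesWithFilter(ctx, func(ks readonly.KubeServer) bool {
		return ks.GetCluster().GetName() == name
	})
	return servers, trace.Wrap(err)
}
```
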
3 changes: 2 additions & 1 deletion lib/kube/proxy/utils_testing.go
@@ -294,6 +294,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo
Component: teleport.ComponentKube,
Client: client,
},
KubernetesServerGetter: client,
},
)
require.NoError(t, err)
@@ -387,7 +388,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo

// Ensure watcher has the correct list of clusters.
require.Eventually(t, func() bool {
kubeServers, err := kubeServersWatcher.GetKubernetesServers(ctx)
kubeServers, err := kubeServersWatcher.CurrentResources(ctx)
return err == nil && len(kubeServers) == len(cfg.Clusters)
}, 3*time.Second, time.Millisecond*100)

6 changes: 4 additions & 2 deletions lib/kube/proxy/watcher.go
@@ -29,6 +29,7 @@ import (

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
"github.com/gravitational/teleport/lib/utils"
)

@@ -89,7 +90,7 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) {

// startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and
// registers/unregisters the proxied Kube Cluster accordingly.
func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) {
func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster, readonly.KubeCluster], error) {
if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService {
s.log.Debug("Not initializing Kube Cluster resource watcher.")
return nil, nil
@@ -102,6 +103,7 @@ func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*servi
// Logger: s.log,
Client: s.AccessPoint,
},
KubernetesClusterGetter: s.AccessPoint,
})
if err != nil {
return nil, trace.Wrap(err)
@@ -110,7 +112,7 @@ func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*servi
defer watcher.Close()
for {
select {
case clusters := <-watcher.KubeClustersC:
case clusters := <-watcher.ResourcesC:
s.monitoredKubeClusters.setResources(clusters)
select {
case s.reconcileCh <- struct{}{}:
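
Because every specialized channel collapses into ResourcesC, the consumption loop is now shape-identical across resource types. A generic rendering of the select loop above, with names invented for illustration:

```go
package sketch

import "context"

// consume drains a generic watcher channel and applies each full
// resource snapshot, regardless of whether T is a kube cluster, proxy,
// or node. It mirrors the loop in startKubeClusterResourceWatcher.
func consume[T any](ctx context.Context, resourcesC <-chan []T, apply func([]T)) {
	for {
		select {
		case resources := <-resourcesC:
			apply(resources) // e.g. monitoredKubeClusters.setResources
		case <-ctx.Done():
			return
		}
	}
}
```
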
4 changes: 3 additions & 1 deletion lib/proxy/peer/client.go
@@ -51,6 +51,7 @@ import (
// AccessPoint is the subset of the auth cache consumed by the [Client].
type AccessPoint interface {
types.Events
services.ProxyGetter
}

// ClientConfig configures a Client instance.
@@ -416,6 +417,7 @@ func (c *Client) sync() {
Client: c.config.AccessPoint,
Logger: c.config.Log,
},
ProxyGetter: c.config.AccessPoint,
ProxyDiffer: func(old, new types.Server) bool {
return old.GetPeerAddr() != new.GetPeerAddr()
},
@@ -434,7 +436,7 @@ func (c *Client) sync() {
case <-proxyWatcher.Done():
c.config.Log.DebugContext(c.ctx, "stopping peer proxy sync: proxy watcher done")
return
case proxies := <-proxyWatcher.ProxiesC:
case proxies := <-proxyWatcher.ResourcesC:
if err := c.updateConnections(proxies); err != nil {
c.config.Log.ErrorContext(c.ctx, "error syncing peer proxies", "error", err)
}
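
A pattern repeated throughout the commit: the embedded ResourceWatcherConfig still carries the event client, while resource reads now go through an explicitly supplied getter (ProxyGetter here, KubernetesClusterGetter and KubernetesServerGetter elsewhere), typically satisfied by the same access point. A condensed, self-contained sketch of the peer-proxy sync above; the function name and component label are invented, the field and method names come from this diff:

```go
package sketch

import (
	"context"

	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/services"
)

// AccessPoint mirrors the interface in lib/proxy/peer: one value feeds
// both the event stream and the explicit resource getter.
type AccessPoint interface {
	types.Events
	services.ProxyGetter
}

// syncPeerProxies wires up the watcher and consumes ResourcesC until
// either the watcher or the caller's context is done.
func syncPeerProxies(ctx context.Context, ap AccessPoint, update func([]types.Server) error) error {
	watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{
		ResourceWatcherConfig: services.ResourceWatcherConfig{
			Component: "peer-proxy-sync", // assumed component label
			Client:    ap,                // event stream
		},
		ProxyGetter: ap, // explicit resource reads (new in this commit)
		ProxyDiffer: func(old, new types.Server) bool {
			return old.GetPeerAddr() != new.GetPeerAddr()
		},
	})
	if err != nil {
		return trace.Wrap(err)
	}
	defer watcher.Close()

	for {
		select {
		case proxies := <-watcher.ResourcesC:
			if err := update(proxies); err != nil {
				return trace.Wrap(err)
			}
		case <-watcher.Done():
			return nil
		case <-ctx.Done():
			return nil
		}
	}
}
```
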
9 changes: 5 additions & 4 deletions lib/proxy/router.go
@@ -44,6 +44,7 @@ import (
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
)
@@ -383,7 +384,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check
// site is the minimum interface needed to match servers
// for a reversetunnelclient.RemoteSite. It makes testing easier.
type site interface {
GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error)
GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error)
GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error)
}

@@ -394,13 +395,13 @@ type remoteSite struct {
}

// GetNodes uses the wrapped sites NodeWatcher to filter nodes
func (r remoteSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) {
func (r remoteSite) GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) {
watcher, err := r.site.NodeWatcher()
if err != nil {
return nil, trace.Wrap(err)
}

return watcher.GetNodes(ctx, fn), nil
return watcher.CurrentResourcesWithFilter(ctx, fn)
}

// GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig
@@ -450,7 +451,7 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re

var maxScore int
scores := make(map[string]int)
matches, err := site.GetNodes(ctx, func(server services.Node) bool {
matches, err := site.GetNodes(ctx, func(server readonly.Server) bool {
score := routeMatcher.RouteToServerScore(server)
if score < 1 {
return false
4 changes: 2 additions & 2 deletions lib/proxy/router_test.go
@@ -37,7 +37,7 @@ import (
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/observability/tracing"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
)
@@ -51,7 +51,7 @@ func (t testSite) GetClusterNetworkingConfig(ctx context.Context) (types.Cluster
return t.cfg, nil
}

func (t testSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) {
func (t testSite) GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) {
var out []types.Server
for _, s := range t.nodes {
if fn(s) {
23 changes: 17 additions & 6 deletions lib/reversetunnel/localsite.go
@@ -44,6 +44,7 @@ import (
"github.com/gravitational/teleport/lib/reversetunnel/track"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
@@ -180,7 +181,7 @@ func (s *localSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, err
}

// NodeWatcher returns a services.NodeWatcher for this cluster.
func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) {
func (s *localSite) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) {
return s.srv.NodeWatcher, nil
}

@@ -738,7 +739,11 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
return
case <-proxyResyncTicker.Chan():
var req discoveryRequest
req.SetProxies(s.srv.proxyWatcher.GetCurrent())
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
req.SetProxies(proxies)

if err := rconn.sendDiscoveryRequest(req); err != nil {
logger.WithError(err).Debug("Marking connection invalid on error")
@@ -763,9 +768,12 @@
if firstHeartbeat {
// as soon as the agent connects and sends a first heartbeat
// send it the list of current proxies back
current := s.srv.proxyWatcher.GetCurrent()
if len(current) > 0 {
rconn.updateProxies(current)
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
if len(proxies) > 0 {
rconn.updateProxies(proxies)
}
reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc()
firstHeartbeat = false
@@ -934,7 +942,7 @@ func (s *localSite) periodicFunctions() {

// sshTunnelStats reports SSH tunnel statistics for the cluster.
func (s *localSite) sshTunnelStats() error {
missing := s.srv.NodeWatcher.GetNodes(s.srv.ctx, func(server services.Node) bool {
missing, err := s.srv.NodeWatcher.CurrentResourcesWithFilter(s.srv.ctx, func(server readonly.Server) bool {
// Skip over any servers that have a TTL larger than announce TTL (10
// minutes) and are non-IoT SSH servers (they won't have tunnels).
//
@@ -966,6 +974,9 @@

return err != nil
})
if err != nil {
return trace.Wrap(err)
}

// Update Prometheus metrics and also log if any tunnels are missing.
missingSSHTunnels.Set(float64(len(missing)))
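
The generalization also changes failure semantics in this file: the old GetCurrent and GetNodes could not fail, whereas CurrentResources and CurrentResourcesWithFilter return an error. The commit settles on two caller patterns — warn-and-continue on heartbeat paths, propagate with trace.Wrap on reporting paths such as sshTunnelStats. A small generic sketch of the lenient variant, with invented names:

```go
package sketch

// bestEffort captures the heartbeat-path pattern: if the fetch fails,
// log a warning and proceed with whatever slice (possibly nil) was
// returned, rather than dropping the heartbeat. Strict paths should
// return the error instead.
//
// Example:
//   proxies := bestEffort(func() ([]types.Server, error) {
//       return s.srv.proxyWatcher.CurrentResources(ctx)
//   }, logger.Warnf)
func bestEffort[T any](fetch func() ([]T, error), warnf func(format string, args ...any)) []T {
	resources, err := fetch()
	if err != nil {
		warnf("failed to fetch current resources: %v", err)
	}
	return resources
}
```
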
16 changes: 10 additions & 6 deletions lib/reversetunnel/localsite_test.go
@@ -58,14 +58,16 @@ func TestRemoteConnCleanup(t *testing.T) {

clock := clockwork.NewFakeClock()

clt := &mockLocalSiteClient{}
watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
Logger: utils.NewSlogLoggerForTests(),
Clock: clock,
Client: &mockLocalSiteClient{},
Client: clt,
},
ProxiesC: make(chan []types.Server, 2),
ProxyGetter: clt,
ProxiesC: make(chan []types.Server, 2),
})
require.NoError(t, err)
require.NoError(t, watcher.WaitInitialization())
@@ -249,17 +251,19 @@ func TestProxyResync(t *testing.T) {
proxy2, err := types.NewServer(uuid.NewString(), types.KindProxy, types.ServerSpecV2{})
require.NoError(t, err)

clt := &mockLocalSiteClient{
proxies: []types.Server{proxy1, proxy2},
}
// set up the watcher and wait for it to be initialized
watcher, err := services.NewProxyWatcher(ctx, services.ProxyWatcherConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: "test",
Logger: utils.NewSlogLoggerForTests(),
Clock: clock,
Client: &mockLocalSiteClient{
proxies: []types.Server{proxy1, proxy2},
},
Client: clt,
},
ProxiesC: make(chan []types.Server, 2),
ProxyGetter: clt,
ProxiesC: make(chan []types.Server, 2),
})
require.NoError(t, err)
require.NoError(t, watcher.WaitInitialization())
5 changes: 3 additions & 2 deletions lib/reversetunnel/peer.go
@@ -33,6 +33,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
)

func newClusterPeers(clusterName string) *clusterPeers {
@@ -90,7 +91,7 @@ func (p *clusterPeers) CachingAccessPoint() (authclient.RemoteProxyAccessPoint,
return peer.CachingAccessPoint()
}

func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) {
func (p *clusterPeers) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) {
peer, err := p.pickPeer()
if err != nil {
return nil, trace.Wrap(err)
@@ -202,7 +203,7 @@ func (s *clusterPeer) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, e
return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s)
}

func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) {
func (s *clusterPeer) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) {
return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s)
}

20 changes: 14 additions & 6 deletions lib/reversetunnel/remotesite.go
@@ -42,6 +42,7 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/readonly"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
@@ -85,7 +86,7 @@ type remoteSite struct {
remoteAccessPoint authclient.RemoteProxyAccessPoint

// nodeWatcher provides access the node set for the remote site
nodeWatcher *services.NodeWatcher
nodeWatcher *services.GenericWatcher[types.Server, readonly.Server]

// remoteCA is the last remote certificate authority recorded by the client.
// It is used to detect CA rotation status changes. If the rotation
@@ -164,7 +165,7 @@ func (s *remoteSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, er
}

// NodeWatcher returns the services.NodeWatcher for the remote cluster.
func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) {
func (s *remoteSite) NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) {
return s.nodeWatcher, nil
}

@@ -429,7 +430,11 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
return
case <-proxyResyncTicker.Chan():
var req discoveryRequest
req.SetProxies(s.srv.proxyWatcher.GetCurrent())
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
req.SetProxies(proxies)

if err := conn.sendDiscoveryRequest(req); err != nil {
logger.WithError(err).Debug("Marking connection invalid on error")
Expand Down Expand Up @@ -458,9 +463,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
if firstHeartbeat {
// as soon as the agent connects and sends a first heartbeat
// send it the list of current proxies back
current := s.srv.proxyWatcher.GetCurrent()
if len(current) > 0 {
conn.updateProxies(current)
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
if len(proxies) > 0 {
conn.updateProxies(proxies)
}
firstHeartbeat = false
}
(diff for the remaining changed files not shown)
