From ab26f36a4fe5b655214fac29827863c31b768319 Mon Sep 17 00:00:00 2001 From: Marco Dinis Date: Tue, 17 Dec 2024 16:52:10 +0000 Subject: [PATCH] [v16] Cache Kubernetes App Discovery port check results (#50342) * Cache Kubernetes App Discovery port check results For the K8S Services that we couldn't auto detect the protocol, after trying to infer the port, cache the result. Cache is evicted when the K8S Service changes (we check the Service's ResourceVersion). * add mutex to cached responses * fix flaky test when hitting the HTTPS server --- lib/srv/discovery/discovery.go | 2 +- lib/srv/discovery/discovery_test.go | 2 +- lib/srv/discovery/fetchers/kube_services.go | 81 +++++++++++++------ .../discovery/fetchers/kube_services_test.go | 74 ++++++++++++++--- 4 files changed, 125 insertions(+), 34 deletions(-) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index e5c6b8f892f5c..5f77a6d7a1ca1 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -233,7 +233,7 @@ kubernetes matchers are present.`) c.LegacyLogger = logrus.New() } if c.protocolChecker == nil { - c.protocolChecker = fetchers.NewProtoChecker(false) + c.protocolChecker = fetchers.NewProtoChecker() } if c.PollInterval == 0 { diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 4c8acb864f721..116fc9b15920c 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -899,7 +899,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation type noopProtocolChecker struct{} // CheckProtocol for noopProtocolChecker just returns 'tcp' -func (*noopProtocolChecker) CheckProtocol(uri string) string { +func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { return "tcp" } diff --git a/lib/srv/discovery/fetchers/kube_services.go b/lib/srv/discovery/fetchers/kube_services.go index 3574e0a31a851..2618ef4b97195 100644 --- a/lib/srv/discovery/fetchers/kube_services.go +++ b/lib/srv/discovery/fetchers/kube_services.go @@ -20,7 +20,7 @@ package fetchers import ( "context" - "crypto/tls" + "errors" "fmt" "net/http" "slices" @@ -73,7 +73,7 @@ func (k *KubeAppsFetcherConfig) CheckAndSetDefaults() error { return trace.BadParameter("missing parameter ClusterName") } if k.ProtocolChecker == nil { - k.ProtocolChecker = NewProtoChecker(false) + k.ProtocolChecker = NewProtoChecker() } return nil @@ -194,7 +194,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er case protoHTTPS, protoHTTP, protoTCP: portProtocols[port] = protocolAnnotation default: - if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP { + if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP { portProtocols[port] = p } } @@ -259,7 +259,7 @@ func (f *KubeAppFetcher) String() string { // - If port's name is `http` or number is 80 or 8080, we return `http` // - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it // gives us result `http` or `https` we return it -func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string { +func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string { if port.AppProtocol != nil { switch p := strings.ToLower(*port.AppProtocol); p { case protoHTTP, protoHTTPS: @@ -281,8 +281,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC } if pc != nil { - result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port)) - if result != protoTCP { + if result := pc.CheckProtocol(service, port); result != protoTCP { return result } } @@ -292,7 +291,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC // ProtocolChecker is an interface used to check what protocol uri serves type ProtocolChecker interface { - CheckProtocol(uri string) string + CheckProtocol(service v1.Service, port v1.ServicePort) string } func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { @@ -324,40 +323,76 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { } type ProtoChecker struct { - InsecureSkipVerify bool - client *http.Client + client *http.Client + + // cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol. + // When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated. + // Only protocol checkers that require a network connection are cached. + cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol + cacheMU sync.RWMutex +} + +type appResourceVersionProtocol struct { + resourceVersion string + protocol string +} + +type kubernetesNameNamespace struct { + namespace string + name string } -func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker { +func NewProtoChecker() *ProtoChecker { p := &ProtoChecker{ - InsecureSkipVerify: insecureSkipVerify, client: &http.Client{ // This is a best-effort scenario, where teleport tries to guess which protocol is being used. // Ideally it should either be inferred by the Service's ports or explicitly configured by using annotations on the service. Timeout: 500 * time.Millisecond, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecureSkipVerify, - }, - }, }, + cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol), } return p } -func (p *ProtoChecker) CheckProtocol(uri string) string { +func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string { if p.client == nil { return protoTCP } - resp, err := p.client.Head(fmt.Sprintf("https://%s", uri)) - if err == nil { + key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name} + + p.cacheMU.RLock() + versionProtocol, keyIsCached := p.cacheKubernetesServiceProtocol[key] + p.cacheMU.RUnlock() + + if keyIsCached && versionProtocol.resourceVersion == service.ResourceVersion { + return versionProtocol.protocol + } + + var result string + + uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port) + resp, err := p.client.Head(uri) + switch { + case err == nil: + result = protoHTTPS _ = resp.Body.Close() - return protoHTTPS - } else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") { - return protoHTTP + + case errors.Is(err, http.ErrSchemeMismatch): + result = protoHTTP + + default: + result = protoTCP + } - return protoTCP + p.cacheMU.Lock() + p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{ + resourceVersion: service.ResourceVersion, + protocol: result, + } + p.cacheMU.Unlock() + + return result } diff --git a/lib/srv/discovery/fetchers/kube_services_test.go b/lib/srv/discovery/fetchers/kube_services_test.go index 9b139e35b151f..4502c0ab0b6ff 100644 --- a/lib/srv/discovery/fetchers/kube_services_test.go +++ b/lib/srv/discovery/fetchers/kube_services_test.go @@ -20,13 +20,18 @@ package fetchers import ( "context" + "crypto/tls" "fmt" "net" "net/http" "net/http/httptest" + "net/url" "slices" + "strconv" "strings" + "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -36,6 +41,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -305,7 +311,8 @@ type mockProtocolChecker struct { results map[string]string } -func (m *mockProtocolChecker) CheckProtocol(uri string) string { +func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { + uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port) if result, ok := m.results[uri]; ok { return result } @@ -453,18 +460,39 @@ func TestGetServicePorts(t *testing.T) { func TestProtoChecker_CheckProtocol(t *testing.T) { t.Parallel() - checker := NewProtoChecker(true) + checker := NewProtoChecker() + // Increasing client Timeout because CI/CD fails with a lower value. + checker.client.Timeout = 5 * time.Second + + // Allow connections to HTTPS server created below. + checker.client.Transport = &http.Transport{TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }} + + totalNetworkHits := &atomic.Int32{} httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + totalNetworkHits.Add(1) _, _ = fmt.Fprintln(w, "httpsServer") })) + httpsServerBaseURL, err := url.Parse(httpsServer.URL) + require.NoError(t, err) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "httpServer") + // this never gets called because the HTTP server will not accept the HTTPS request. })) + httpServerBaseURL, err := url.Parse(httpServer.URL) + require.NoError(t, err) + tcpServer := newTCPServer(t, func(conn net.Conn) { + totalNetworkHits.Add(1) _, _ = conn.Write([]byte("tcpServer")) _ = conn.Close() }) + tcpServerBaseURL := &url.URL{ + Host: tcpServer.Addr().String(), + } + t.Cleanup(func() { httpsServer.Close() httpServer.Close() @@ -472,27 +500,55 @@ func TestProtoChecker_CheckProtocol(t *testing.T) { }) tests := []struct { - uri string + host string expected string }{ { - uri: strings.TrimPrefix(httpServer.URL, "http://"), + host: httpServerBaseURL.Host, expected: "http", }, { - uri: strings.TrimPrefix(httpsServer.URL, "https://"), + host: httpsServerBaseURL.Host, expected: "https", }, { - uri: tcpServer.Addr().String(), + host: tcpServerBaseURL.Host, expected: "tcp", }, } for _, tt := range tests { - res := checker.CheckProtocol(tt.uri) + service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host) + res := checker.CheckProtocol(service, servicePort) require.Equal(t, tt.expected, res) } + + t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) { + service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host) + checker.CheckProtocol(service, servicePort) + // There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server. + // The HTTP Server does not generate a network hit. See above. + require.Equal(t, int32(2), totalNetworkHits.Load()) + }) +} + +func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) { + host, portString, err := net.SplitHostPort(host) + require.NoError(t, err) + port, err := strconv.Atoi(portString) + require.NoError(t, err) + service := corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeExternalName, + ExternalName: host, + }, + } + servicePort := corev1.ServicePort{Port: int32(port)} + return service, servicePort } func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener { @@ -579,7 +635,7 @@ func TestAutoProtocolDetection(t *testing.T) { port.AppProtocol = &tt.appProtocol } - result := autoProtocolDetection("192.1.1.1", port, nil) + result := autoProtocolDetection(corev1.Service{}, port, nil) require.Equal(t, tt.expected, result) })