From 341a968e214769108ff5fc25ab7234eacd616097 Mon Sep 17 00:00:00 2001 From: Marco Dinis Date: Mon, 16 Dec 2024 17:10:18 +0000 Subject: [PATCH] Cache Kubernetes App Discovery port check results For the K8S Services that we couldn't auto detect the protocol, after trying to infer the port, cache the result. Cache is evicted when the K8S Service changes (we check the Service's ResourceVersion). --- lib/srv/discovery/discovery_test.go | 2 +- lib/srv/discovery/fetchers/kube_services.go | 61 ++++++++++++++---- .../discovery/fetchers/kube_services_test.go | 63 ++++++++++++++++--- 3 files changed, 105 insertions(+), 21 deletions(-) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index e2a187357d084..f575d640b543d 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -996,7 +996,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation type noopProtocolChecker struct{} // CheckProtocol for noopProtocolChecker just returns 'tcp' -func (*noopProtocolChecker) CheckProtocol(uri string) string { +func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { return "tcp" } diff --git a/lib/srv/discovery/fetchers/kube_services.go b/lib/srv/discovery/fetchers/kube_services.go index 3574e0a31a851..6cb72070cb137 100644 --- a/lib/srv/discovery/fetchers/kube_services.go +++ b/lib/srv/discovery/fetchers/kube_services.go @@ -21,6 +21,7 @@ package fetchers import ( "context" "crypto/tls" + "errors" "fmt" "net/http" "slices" @@ -194,7 +195,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er case protoHTTPS, protoHTTP, protoTCP: portProtocols[port] = protocolAnnotation default: - if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP { + if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP { portProtocols[port] = p } } @@ -259,7 +260,7 @@ func (f *KubeAppFetcher) String() string { // - If port's name is `http` or number is 80 or 8080, we return `http` // - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it // gives us result `http` or `https` we return it -func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string { +func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string { if port.AppProtocol != nil { switch p := strings.ToLower(*port.AppProtocol); p { case protoHTTP, protoHTTPS: @@ -281,8 +282,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC } if pc != nil { - result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port)) - if result != protoTCP { + if result := pc.CheckProtocol(service, port); result != protoTCP { return result } } @@ -292,7 +292,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC // ProtocolChecker is an interface used to check what protocol uri serves type ProtocolChecker interface { - CheckProtocol(uri string) string + CheckProtocol(service v1.Service, port v1.ServicePort) string } func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { @@ -326,6 +326,21 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { type ProtoChecker struct { InsecureSkipVerify bool client *http.Client + + // cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol. + // When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated. + // Only protocol checkers that require a network connection are cached. + cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol +} + +type appResourceVersionProtocol struct { + resourceVersion string + protocol string +} + +type kubernetesNameNamespace struct { + namespace string + name string } func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker { @@ -341,23 +356,45 @@ func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker { }, }, }, + cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol), } return p } -func (p *ProtoChecker) CheckProtocol(uri string) string { +func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string { if p.client == nil { return protoTCP } - resp, err := p.client.Head(fmt.Sprintf("https://%s", uri)) - if err == nil { + key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name} + if versionProtocol, ok := p.cacheKubernetesServiceProtocol[key]; ok { + if versionProtocol.resourceVersion == service.ResourceVersion { + return versionProtocol.protocol + } + } + + var result string + + uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port) + resp, err := p.client.Head(uri) + switch { + case err == nil: + result = protoHTTPS _ = resp.Body.Close() - return protoHTTPS - } else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") { - return protoHTTP + + case errors.Is(err, http.ErrSchemeMismatch): + result = protoHTTP + + default: + result = protoTCP + } - return protoTCP + p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{ + resourceVersion: service.ResourceVersion, + protocol: result, + } + + return result } diff --git a/lib/srv/discovery/fetchers/kube_services_test.go b/lib/srv/discovery/fetchers/kube_services_test.go index 9b139e35b151f..ea32105c86334 100644 --- a/lib/srv/discovery/fetchers/kube_services_test.go +++ b/lib/srv/discovery/fetchers/kube_services_test.go @@ -24,8 +24,11 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "slices" + "strconv" "strings" + "sync/atomic" "testing" "github.com/google/go-cmp/cmp" @@ -36,6 +39,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -305,7 +309,8 @@ type mockProtocolChecker struct { results map[string]string } -func (m *mockProtocolChecker) CheckProtocol(uri string) string { +func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { + uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port) if result, ok := m.results[uri]; ok { return result } @@ -455,16 +460,30 @@ func TestProtoChecker_CheckProtocol(t *testing.T) { t.Parallel() checker := NewProtoChecker(true) + totalNetworkHits := &atomic.Int32{} + httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + totalNetworkHits.Add(1) _, _ = fmt.Fprintln(w, "httpsServer") })) + httpsServerBaseURL, err := url.Parse(httpsServer.URL) + require.NoError(t, err) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "httpServer") + // this never gets called because the HTTP server will not accept the HTTPS request. })) + httpServerBaseURL, err := url.Parse(httpServer.URL) + require.NoError(t, err) + tcpServer := newTCPServer(t, func(conn net.Conn) { + totalNetworkHits.Add(1) _, _ = conn.Write([]byte("tcpServer")) _ = conn.Close() }) + tcpServerBaseURL := &url.URL{ + Host: tcpServer.Addr().String(), + } + t.Cleanup(func() { httpsServer.Close() httpServer.Close() @@ -472,27 +491,55 @@ func TestProtoChecker_CheckProtocol(t *testing.T) { }) tests := []struct { - uri string + host string expected string }{ { - uri: strings.TrimPrefix(httpServer.URL, "http://"), + host: httpServerBaseURL.Host, expected: "http", }, { - uri: strings.TrimPrefix(httpsServer.URL, "https://"), + host: httpsServerBaseURL.Host, expected: "https", }, { - uri: tcpServer.Addr().String(), + host: tcpServerBaseURL.Host, expected: "tcp", }, } for _, tt := range tests { - res := checker.CheckProtocol(tt.uri) + service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host) + res := checker.CheckProtocol(service, servicePort) require.Equal(t, tt.expected, res) } + + t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) { + service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host) + checker.CheckProtocol(service, servicePort) + // There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server. + // The HTTP Server does not generate a network hit. See above. + require.Equal(t, int32(2), totalNetworkHits.Load()) + }) +} + +func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) { + host, portString, err := net.SplitHostPort(host) + require.NoError(t, err) + port, err := strconv.Atoi(portString) + require.NoError(t, err) + service := corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeExternalName, + ExternalName: host, + }, + } + servicePort := corev1.ServicePort{Port: int32(port)} + return service, servicePort } func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener { @@ -579,7 +626,7 @@ func TestAutoProtocolDetection(t *testing.T) { port.AppProtocol = &tt.appProtocol } - result := autoProtocolDetection("192.1.1.1", port, nil) + result := autoProtocolDetection(corev1.Service{}, port, nil) require.Equal(t, tt.expected, result) })