diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index e2a187357d084..f575d640b543d 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -996,7 +996,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation type noopProtocolChecker struct{} // CheckProtocol for noopProtocolChecker just returns 'tcp' -func (*noopProtocolChecker) CheckProtocol(uri string) string { +func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { return "tcp" } diff --git a/lib/srv/discovery/fetchers/kube_services.go b/lib/srv/discovery/fetchers/kube_services.go index 3574e0a31a851..6cb72070cb137 100644 --- a/lib/srv/discovery/fetchers/kube_services.go +++ b/lib/srv/discovery/fetchers/kube_services.go @@ -21,6 +21,7 @@ package fetchers import ( "context" "crypto/tls" + "errors" "fmt" "net/http" "slices" @@ -194,7 +195,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er case protoHTTPS, protoHTTP, protoTCP: portProtocols[port] = protocolAnnotation default: - if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP { + if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP { portProtocols[port] = p } } @@ -259,7 +260,7 @@ func (f *KubeAppFetcher) String() string { // - If port's name is `http` or number is 80 or 8080, we return `http` // - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it // gives us result `http` or `https` we return it -func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string { +func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string { if port.AppProtocol != nil { switch p := strings.ToLower(*port.AppProtocol); p { case protoHTTP, protoHTTPS: @@ -281,8 +282,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC } if pc != nil { - result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port)) - if result != protoTCP { + if result := pc.CheckProtocol(service, port); result != protoTCP { return result } } @@ -292,7 +292,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC // ProtocolChecker is an interface used to check what protocol uri serves type ProtocolChecker interface { - CheckProtocol(uri string) string + CheckProtocol(service v1.Service, port v1.ServicePort) string } func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { @@ -326,6 +326,21 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { type ProtoChecker struct { InsecureSkipVerify bool client *http.Client + + // cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol. + // When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated. + // Only protocol checkers that require a network connection are cached. + cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol +} + +type appResourceVersionProtocol struct { + resourceVersion string + protocol string +} + +type kubernetesNameNamespace struct { + namespace string + name string } func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker { @@ -341,23 +356,45 @@ func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker { }, }, }, + cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol), } return p } -func (p *ProtoChecker) CheckProtocol(uri string) string { +func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string { if p.client == nil { return protoTCP } - resp, err := p.client.Head(fmt.Sprintf("https://%s", uri)) - if err == nil { + key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name} + if versionProtocol, ok := p.cacheKubernetesServiceProtocol[key]; ok { + if versionProtocol.resourceVersion == service.ResourceVersion { + return versionProtocol.protocol + } + } + + var result string + + uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port) + resp, err := p.client.Head(uri) + switch { + case err == nil: + result = protoHTTPS _ = resp.Body.Close() - return protoHTTPS - } else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") { - return protoHTTP + + case errors.Is(err, http.ErrSchemeMismatch): + result = protoHTTP + + default: + result = protoTCP + } - return protoTCP + p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{ + resourceVersion: service.ResourceVersion, + protocol: result, + } + + return result } diff --git a/lib/srv/discovery/fetchers/kube_services_test.go b/lib/srv/discovery/fetchers/kube_services_test.go index 9b139e35b151f..ea32105c86334 100644 --- a/lib/srv/discovery/fetchers/kube_services_test.go +++ b/lib/srv/discovery/fetchers/kube_services_test.go @@ -24,8 +24,11 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "slices" + "strconv" "strings" + "sync/atomic" "testing" "github.com/google/go-cmp/cmp" @@ -36,6 +39,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -305,7 +309,8 @@ type mockProtocolChecker struct { results map[string]string } -func (m *mockProtocolChecker) CheckProtocol(uri string) string { +func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { + uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port) if result, ok := m.results[uri]; ok { return result } @@ -455,16 +460,30 @@ func TestProtoChecker_CheckProtocol(t *testing.T) { t.Parallel() checker := NewProtoChecker(true) + totalNetworkHits := &atomic.Int32{} + httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + totalNetworkHits.Add(1) _, _ = fmt.Fprintln(w, "httpsServer") })) + httpsServerBaseURL, err := url.Parse(httpsServer.URL) + require.NoError(t, err) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "httpServer") + // this never gets called because the HTTP server will not accept the HTTPS request. })) + httpServerBaseURL, err := url.Parse(httpServer.URL) + require.NoError(t, err) + tcpServer := newTCPServer(t, func(conn net.Conn) { + totalNetworkHits.Add(1) _, _ = conn.Write([]byte("tcpServer")) _ = conn.Close() }) + tcpServerBaseURL := &url.URL{ + Host: tcpServer.Addr().String(), + } + t.Cleanup(func() { httpsServer.Close() httpServer.Close() @@ -472,27 +491,55 @@ func TestProtoChecker_CheckProtocol(t *testing.T) { }) tests := []struct { - uri string + host string expected string }{ { - uri: strings.TrimPrefix(httpServer.URL, "http://"), + host: httpServerBaseURL.Host, expected: "http", }, { - uri: strings.TrimPrefix(httpsServer.URL, "https://"), + host: httpsServerBaseURL.Host, expected: "https", }, { - uri: tcpServer.Addr().String(), + host: tcpServerBaseURL.Host, expected: "tcp", }, } for _, tt := range tests { - res := checker.CheckProtocol(tt.uri) + service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host) + res := checker.CheckProtocol(service, servicePort) require.Equal(t, tt.expected, res) } + + t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) { + service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host) + checker.CheckProtocol(service, servicePort) + // There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server. + // The HTTP Server does not generate a network hit. See above. + require.Equal(t, int32(2), totalNetworkHits.Load()) + }) +} + +func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) { + host, portString, err := net.SplitHostPort(host) + require.NoError(t, err) + port, err := strconv.Atoi(portString) + require.NoError(t, err) + service := corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeExternalName, + ExternalName: host, + }, + } + servicePort := corev1.ServicePort{Port: int32(port)} + return service, servicePort } func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener { @@ -579,7 +626,7 @@ func TestAutoProtocolDetection(t *testing.T) { port.AppProtocol = &tt.appProtocol } - result := autoProtocolDetection("192.1.1.1", port, nil) + result := autoProtocolDetection(corev1.Service{}, port, nil) require.Equal(t, tt.expected, result) })