diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index ea784c8b9718a..f74b8233c0f46 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -238,7 +238,7 @@ kubernetes matchers are present.`) c.LegacyLogger = logrus.New() } if c.protocolChecker == nil { - c.protocolChecker = fetchers.NewProtoChecker(false) + c.protocolChecker = fetchers.NewProtoChecker() } if c.PollInterval == 0 { diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 9c9ebbee0c788..79ec90f2113d3 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -993,7 +993,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation type noopProtocolChecker struct{} // CheckProtocol for noopProtocolChecker just returns 'tcp' -func (*noopProtocolChecker) CheckProtocol(uri string) string { +func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { return "tcp" } diff --git a/lib/srv/discovery/fetchers/kube_services.go b/lib/srv/discovery/fetchers/kube_services.go index bc44a9c5cc153..1ba23b03c1ee3 100644 --- a/lib/srv/discovery/fetchers/kube_services.go +++ b/lib/srv/discovery/fetchers/kube_services.go @@ -20,7 +20,7 @@ package fetchers import ( "context" - "crypto/tls" + "errors" "fmt" "net/http" "slices" @@ -73,7 +73,7 @@ func (k *KubeAppsFetcherConfig) CheckAndSetDefaults() error { return trace.BadParameter("missing parameter ClusterName") } if k.ProtocolChecker == nil { - k.ProtocolChecker = NewProtoChecker(false) + k.ProtocolChecker = NewProtoChecker() } return nil @@ -194,7 +194,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er case protoHTTPS, protoHTTP, protoTCP: portProtocols[port] = protocolAnnotation default: - if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP { + if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP { portProtocols[port] = p } } @@ -259,7 +259,7 @@ func (f *KubeAppFetcher) String() string { // - If port's name is `http` or number is 80 or 8080, we return `http` // - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it // gives us result `http` or `https` we return it -func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string { +func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string { if port.AppProtocol != nil { switch p := strings.ToLower(*port.AppProtocol); p { case protoHTTP, protoHTTPS: @@ -281,8 +281,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC } if pc != nil { - result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port)) - if result != protoTCP { + if result := pc.CheckProtocol(service, port); result != protoTCP { return result } } @@ -292,7 +291,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC // ProtocolChecker is an interface used to check what protocol uri serves type ProtocolChecker interface { - CheckProtocol(uri string) string + CheckProtocol(service v1.Service, port v1.ServicePort) string } func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { @@ -324,40 +323,76 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) { } type ProtoChecker struct { - InsecureSkipVerify bool - client *http.Client + client *http.Client + + // cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol. + // When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated. + // Only protocol checkers that require a network connection are cached. + cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol + cacheMU sync.RWMutex +} + +type appResourceVersionProtocol struct { + resourceVersion string + protocol string +} + +type kubernetesNameNamespace struct { + namespace string + name string } -func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker { +func NewProtoChecker() *ProtoChecker { p := &ProtoChecker{ - InsecureSkipVerify: insecureSkipVerify, client: &http.Client{ // This is a best-effort scenario, where teleport tries to guess which protocol is being used. // Ideally it should either be inferred by the Service's ports or explicitly configured by using annotations on the service. Timeout: 500 * time.Millisecond, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecureSkipVerify, - }, - }, }, + cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol), } return p } -func (p *ProtoChecker) CheckProtocol(uri string) string { +func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string { if p.client == nil { return protoTCP } - resp, err := p.client.Head(fmt.Sprintf("https://%s", uri)) - if err == nil { + key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name} + + p.cacheMU.RLock() + versionProtocol, keyIsCached := p.cacheKubernetesServiceProtocol[key] + p.cacheMU.RUnlock() + + if keyIsCached && versionProtocol.resourceVersion == service.ResourceVersion { + return versionProtocol.protocol + } + + var result string + + uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port) + resp, err := p.client.Head(uri) + switch { + case err == nil: + result = protoHTTPS _ = resp.Body.Close() - return protoHTTPS - } else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") { - return protoHTTP + + case errors.Is(err, http.ErrSchemeMismatch): + result = protoHTTP + + default: + result = protoTCP + } - return protoTCP + p.cacheMU.Lock() + p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{ + resourceVersion: service.ResourceVersion, + protocol: result, + } + p.cacheMU.Unlock() + + return result } diff --git a/lib/srv/discovery/fetchers/kube_services_test.go b/lib/srv/discovery/fetchers/kube_services_test.go index 9b139e35b151f..4502c0ab0b6ff 100644 --- a/lib/srv/discovery/fetchers/kube_services_test.go +++ b/lib/srv/discovery/fetchers/kube_services_test.go @@ -20,13 +20,18 @@ package fetchers import ( "context" + "crypto/tls" "fmt" "net" "net/http" "net/http/httptest" + "net/url" "slices" + "strconv" "strings" + "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -36,6 +41,7 @@ import ( "k8s.io/client-go/kubernetes/fake" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -305,7 +311,8 @@ type mockProtocolChecker struct { results map[string]string } -func (m *mockProtocolChecker) CheckProtocol(uri string) string { +func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string { + uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port) if result, ok := m.results[uri]; ok { return result } @@ -453,18 +460,39 @@ func TestGetServicePorts(t *testing.T) { func TestProtoChecker_CheckProtocol(t *testing.T) { t.Parallel() - checker := NewProtoChecker(true) + checker := NewProtoChecker() + // Increasing client Timeout because CI/CD fails with a lower value. + checker.client.Timeout = 5 * time.Second + + // Allow connections to HTTPS server created below. + checker.client.Transport = &http.Transport{TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }} + + totalNetworkHits := &atomic.Int32{} httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + totalNetworkHits.Add(1) _, _ = fmt.Fprintln(w, "httpsServer") })) + httpsServerBaseURL, err := url.Parse(httpsServer.URL) + require.NoError(t, err) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "httpServer") + // this never gets called because the HTTP server will not accept the HTTPS request. })) + httpServerBaseURL, err := url.Parse(httpServer.URL) + require.NoError(t, err) + tcpServer := newTCPServer(t, func(conn net.Conn) { + totalNetworkHits.Add(1) _, _ = conn.Write([]byte("tcpServer")) _ = conn.Close() }) + tcpServerBaseURL := &url.URL{ + Host: tcpServer.Addr().String(), + } + t.Cleanup(func() { httpsServer.Close() httpServer.Close() @@ -472,27 +500,55 @@ func TestProtoChecker_CheckProtocol(t *testing.T) { }) tests := []struct { - uri string + host string expected string }{ { - uri: strings.TrimPrefix(httpServer.URL, "http://"), + host: httpServerBaseURL.Host, expected: "http", }, { - uri: strings.TrimPrefix(httpsServer.URL, "https://"), + host: httpsServerBaseURL.Host, expected: "https", }, { - uri: tcpServer.Addr().String(), + host: tcpServerBaseURL.Host, expected: "tcp", }, } for _, tt := range tests { - res := checker.CheckProtocol(tt.uri) + service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host) + res := checker.CheckProtocol(service, servicePort) require.Equal(t, tt.expected, res) } + + t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) { + service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host) + checker.CheckProtocol(service, servicePort) + // There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server. + // The HTTP Server does not generate a network hit. See above. + require.Equal(t, int32(2), totalNetworkHits.Load()) + }) +} + +func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) { + host, portString, err := net.SplitHostPort(host) + require.NoError(t, err) + port, err := strconv.Atoi(portString) + require.NoError(t, err) + service := corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeExternalName, + ExternalName: host, + }, + } + servicePort := corev1.ServicePort{Port: int32(port)} + return service, servicePort } func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener { @@ -579,7 +635,7 @@ func TestAutoProtocolDetection(t *testing.T) { port.AppProtocol = &tt.appProtocol } - result := autoProtocolDetection("192.1.1.1", port, nil) + result := autoProtocolDetection(corev1.Service{}, port, nil) require.Equal(t, tt.expected, result) })