Skip to content

Commit

Permalink
[v16] Cache Kubernetes App Discovery port check results (#50342)
Browse files Browse the repository at this point in the history
* Cache Kubernetes App Discovery port check results

For the K8S Services that we couldn't auto detect the protocol, after
trying to infer the port, cache the result.
Cache is evicted when the K8S Service changes (we check the Service's
ResourceVersion).

* add mutex to cached responses

* fix flaky test when hitting the HTTPS server
  • Loading branch information
marcoandredinis authored Dec 17, 2024
1 parent d9cba81 commit acbae80
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 34 deletions.
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ kubernetes matchers are present.`)
c.LegacyLogger = logrus.New()
}
if c.protocolChecker == nil {
c.protocolChecker = fetchers.NewProtoChecker(false)
c.protocolChecker = fetchers.NewProtoChecker()
}

if c.PollInterval == 0 {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation
type noopProtocolChecker struct{}

// CheckProtocol for noopProtocolChecker just returns 'tcp'
func (*noopProtocolChecker) CheckProtocol(uri string) string {
func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string {
return "tcp"
}

Expand Down
81 changes: 58 additions & 23 deletions lib/srv/discovery/fetchers/kube_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package fetchers

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"slices"
Expand Down Expand Up @@ -73,7 +73,7 @@ func (k *KubeAppsFetcherConfig) CheckAndSetDefaults() error {
return trace.BadParameter("missing parameter ClusterName")
}
if k.ProtocolChecker == nil {
k.ProtocolChecker = NewProtoChecker(false)
k.ProtocolChecker = NewProtoChecker()
}

return nil
Expand Down Expand Up @@ -194,7 +194,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er
case protoHTTPS, protoHTTP, protoTCP:
portProtocols[port] = protocolAnnotation
default:
if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP {
if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP {
portProtocols[port] = p
}
}
Expand Down Expand Up @@ -259,7 +259,7 @@ func (f *KubeAppFetcher) String() string {
// - If port's name is `http` or number is 80 or 8080, we return `http`
// - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it
// gives us result `http` or `https` we return it
func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string {
func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string {
if port.AppProtocol != nil {
switch p := strings.ToLower(*port.AppProtocol); p {
case protoHTTP, protoHTTPS:
Expand All @@ -281,8 +281,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC
}

if pc != nil {
result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port))
if result != protoTCP {
if result := pc.CheckProtocol(service, port); result != protoTCP {
return result
}
}
Expand All @@ -292,7 +291,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC

// ProtocolChecker is an interface used to check what protocol uri serves
type ProtocolChecker interface {
CheckProtocol(uri string) string
CheckProtocol(service v1.Service, port v1.ServicePort) string
}

func getServicePorts(s v1.Service) ([]v1.ServicePort, error) {
Expand Down Expand Up @@ -324,40 +323,76 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) {
}

type ProtoChecker struct {
InsecureSkipVerify bool
client *http.Client
client *http.Client

// cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol.
// When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated.
// Only protocol checkers that require a network connection are cached.
cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol
cacheMU sync.RWMutex
}

type appResourceVersionProtocol struct {
resourceVersion string
protocol string
}

type kubernetesNameNamespace struct {
namespace string
name string
}

func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker {
func NewProtoChecker() *ProtoChecker {
p := &ProtoChecker{
InsecureSkipVerify: insecureSkipVerify,
client: &http.Client{
// This is a best-effort scenario, where teleport tries to guess which protocol is being used.
// Ideally it should either be inferred by the Service's ports or explicitly configured by using annotations on the service.
Timeout: 500 * time.Millisecond,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecureSkipVerify,
},
},
},
cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol),
}

return p
}

func (p *ProtoChecker) CheckProtocol(uri string) string {
func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string {
if p.client == nil {
return protoTCP
}

resp, err := p.client.Head(fmt.Sprintf("https://%s", uri))
if err == nil {
key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name}

p.cacheMU.RLock()
versionProtocol, keyIsCached := p.cacheKubernetesServiceProtocol[key]
p.cacheMU.RUnlock()

if keyIsCached && versionProtocol.resourceVersion == service.ResourceVersion {
return versionProtocol.protocol
}

var result string

uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port)
resp, err := p.client.Head(uri)
switch {
case err == nil:
result = protoHTTPS
_ = resp.Body.Close()
return protoHTTPS
} else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") {
return protoHTTP

case errors.Is(err, http.ErrSchemeMismatch):
result = protoHTTP

default:
result = protoTCP

}

return protoTCP
p.cacheMU.Lock()
p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{
resourceVersion: service.ResourceVersion,
protocol: result,
}
p.cacheMU.Unlock()

return result
}
74 changes: 65 additions & 9 deletions lib/srv/discovery/fetchers/kube_services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ package fetchers

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
Expand All @@ -36,6 +41,7 @@ import (
"k8s.io/client-go/kubernetes/fake"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -305,7 +311,8 @@ type mockProtocolChecker struct {
results map[string]string
}

func (m *mockProtocolChecker) CheckProtocol(uri string) string {
func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string {
uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port)
if result, ok := m.results[uri]; ok {
return result
}
Expand Down Expand Up @@ -453,46 +460,95 @@ func TestGetServicePorts(t *testing.T) {

func TestProtoChecker_CheckProtocol(t *testing.T) {
t.Parallel()
checker := NewProtoChecker(true)
checker := NewProtoChecker()
// Increasing client Timeout because CI/CD fails with a lower value.
checker.client.Timeout = 5 * time.Second

// Allow connections to HTTPS server created below.
checker.client.Transport = &http.Transport{TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
}}

totalNetworkHits := &atomic.Int32{}

httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
totalNetworkHits.Add(1)
_, _ = fmt.Fprintln(w, "httpsServer")
}))
httpsServerBaseURL, err := url.Parse(httpsServer.URL)
require.NoError(t, err)

httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintln(w, "httpServer")
// this never gets called because the HTTP server will not accept the HTTPS request.
}))
httpServerBaseURL, err := url.Parse(httpServer.URL)
require.NoError(t, err)

tcpServer := newTCPServer(t, func(conn net.Conn) {
totalNetworkHits.Add(1)
_, _ = conn.Write([]byte("tcpServer"))
_ = conn.Close()
})
tcpServerBaseURL := &url.URL{
Host: tcpServer.Addr().String(),
}

t.Cleanup(func() {
httpsServer.Close()
httpServer.Close()
_ = tcpServer.Close()
})

tests := []struct {
uri string
host string
expected string
}{
{
uri: strings.TrimPrefix(httpServer.URL, "http://"),
host: httpServerBaseURL.Host,
expected: "http",
},
{
uri: strings.TrimPrefix(httpsServer.URL, "https://"),
host: httpsServerBaseURL.Host,
expected: "https",
},
{
uri: tcpServer.Addr().String(),
host: tcpServerBaseURL.Host,
expected: "tcp",
},
}

for _, tt := range tests {
res := checker.CheckProtocol(tt.uri)
service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host)
res := checker.CheckProtocol(service, servicePort)
require.Equal(t, tt.expected, res)
}

t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) {
service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host)
checker.CheckProtocol(service, servicePort)
// There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server.
// The HTTP Server does not generate a network hit. See above.
require.Equal(t, int32(2), totalNetworkHits.Load())
})
}

func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) {
host, portString, err := net.SplitHostPort(host)
require.NoError(t, err)
port, err := strconv.Atoi(portString)
require.NoError(t, err)
service := corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: serviceName,
Namespace: "default",
},
Spec: corev1.ServiceSpec{
Type: corev1.ServiceTypeExternalName,
ExternalName: host,
},
}
servicePort := corev1.ServicePort{Port: int32(port)}
return service, servicePort
}

func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener {
Expand Down Expand Up @@ -579,7 +635,7 @@ func TestAutoProtocolDetection(t *testing.T) {
port.AppProtocol = &tt.appProtocol
}

result := autoProtocolDetection("192.1.1.1", port, nil)
result := autoProtocolDetection(corev1.Service{}, port, nil)

require.Equal(t, tt.expected, result)
})
Expand Down

0 comments on commit acbae80

Please sign in to comment.