Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache Kubernetes App Discovery port check results #50286

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ kubernetes matchers are present.`)
}

if c.protocolChecker == nil {
c.protocolChecker = fetchers.NewProtoChecker(false)
c.protocolChecker = fetchers.NewProtoChecker()
}

if c.PollInterval == 0 {
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation
type noopProtocolChecker struct{}

// CheckProtocol for noopProtocolChecker just returns 'tcp'
func (*noopProtocolChecker) CheckProtocol(uri string) string {
func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string {
return "tcp"
}

Expand Down
81 changes: 58 additions & 23 deletions lib/srv/discovery/fetchers/kube_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package fetchers

import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net/http"
Expand Down Expand Up @@ -73,7 +73,7 @@ func (k *KubeAppsFetcherConfig) CheckAndSetDefaults() error {
return trace.BadParameter("missing parameter ClusterName")
}
if k.ProtocolChecker == nil {
k.ProtocolChecker = NewProtoChecker(false)
k.ProtocolChecker = NewProtoChecker()
}

return nil
Expand Down Expand Up @@ -194,7 +194,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er
case protoHTTPS, protoHTTP, protoTCP:
portProtocols[port] = protocolAnnotation
default:
if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP {
if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP {
portProtocols[port] = p
}
}
Expand Down Expand Up @@ -259,7 +259,7 @@ func (f *KubeAppFetcher) String() string {
// - If port's name is `http` or number is 80 or 8080, we return `http`
// - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it
// gives us result `http` or `https` we return it
func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string {
func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string {
if port.AppProtocol != nil {
switch p := strings.ToLower(*port.AppProtocol); p {
case protoHTTP, protoHTTPS:
Expand All @@ -281,8 +281,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC
}

if pc != nil {
result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port))
if result != protoTCP {
if result := pc.CheckProtocol(service, port); result != protoTCP {
return result
}
}
Expand All @@ -292,7 +291,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC

// ProtocolChecker is an interface used to check what protocol uri serves
type ProtocolChecker interface {
CheckProtocol(uri string) string
CheckProtocol(service v1.Service, port v1.ServicePort) string
}

func getServicePorts(s v1.Service) ([]v1.ServicePort, error) {
Expand Down Expand Up @@ -324,40 +323,76 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) {
}

type ProtoChecker struct {
InsecureSkipVerify bool
client *http.Client
client *http.Client

// cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol.
// When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated.
// Only protocol checkers that require a network connection are cached.
cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol
marcoandredinis marked this conversation as resolved.
Show resolved Hide resolved
cacheMU sync.RWMutex
}

type appResourceVersionProtocol struct {
resourceVersion string
protocol string
}

type kubernetesNameNamespace struct {
namespace string
name string
}

func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker {
func NewProtoChecker() *ProtoChecker {
p := &ProtoChecker{
InsecureSkipVerify: insecureSkipVerify,
client: &http.Client{
// This is a best-effort scenario, where teleport tries to guess which protocol is being used.
// Ideally it should either be inferred by the Service's ports or explicitly configured by using annotations on the service.
Timeout: 500 * time.Millisecond,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecureSkipVerify,
},
},
},
cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol),
}

return p
}

func (p *ProtoChecker) CheckProtocol(uri string) string {
func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string {
if p.client == nil {
return protoTCP
}

resp, err := p.client.Head(fmt.Sprintf("https://%s", uri))
if err == nil {
key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name}

p.cacheMU.RLock()
versionProtocol, keyIsCached := p.cacheKubernetesServiceProtocol[key]
p.cacheMU.RUnlock()

if keyIsCached && versionProtocol.resourceVersion == service.ResourceVersion {
return versionProtocol.protocol
}

var result string

uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port)
resp, err := p.client.Head(uri)
switch {
case err == nil:
result = protoHTTPS
_ = resp.Body.Close()
return protoHTTPS
} else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") {
return protoHTTP

case errors.Is(err, http.ErrSchemeMismatch):
result = protoHTTP

default:
result = protoTCP

}

return protoTCP
p.cacheMU.Lock()
p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{
resourceVersion: service.ResourceVersion,
protocol: result,
}
p.cacheMU.Unlock()

return result
}
74 changes: 65 additions & 9 deletions lib/srv/discovery/fetchers/kube_services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ package fetchers

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
Expand All @@ -36,6 +41,7 @@ import (
"k8s.io/client-go/kubernetes/fake"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -305,7 +311,8 @@ type mockProtocolChecker struct {
results map[string]string
}

func (m *mockProtocolChecker) CheckProtocol(uri string) string {
func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string {
uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port)
if result, ok := m.results[uri]; ok {
return result
}
Expand Down Expand Up @@ -453,46 +460,95 @@ func TestGetServicePorts(t *testing.T) {

func TestProtoChecker_CheckProtocol(t *testing.T) {
t.Parallel()
checker := NewProtoChecker(true)
checker := NewProtoChecker()
// Increasing client Timeout because CI/CD fails with a lower value.
checker.client.Timeout = 5 * time.Second

// Allow connections to HTTPS server created below.
checker.client.Transport = &http.Transport{TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
}}

totalNetworkHits := &atomic.Int32{}

httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
totalNetworkHits.Add(1)
_, _ = fmt.Fprintln(w, "httpsServer")
}))
httpsServerBaseURL, err := url.Parse(httpsServer.URL)
require.NoError(t, err)

httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintln(w, "httpServer")
// this never gets called because the HTTP server will not accept the HTTPS request.
}))
httpServerBaseURL, err := url.Parse(httpServer.URL)
require.NoError(t, err)

tcpServer := newTCPServer(t, func(conn net.Conn) {
totalNetworkHits.Add(1)
_, _ = conn.Write([]byte("tcpServer"))
_ = conn.Close()
})
tcpServerBaseURL := &url.URL{
Host: tcpServer.Addr().String(),
}

t.Cleanup(func() {
httpsServer.Close()
httpServer.Close()
_ = tcpServer.Close()
})

tests := []struct {
uri string
host string
expected string
}{
{
uri: strings.TrimPrefix(httpServer.URL, "http://"),
host: httpServerBaseURL.Host,
expected: "http",
},
{
uri: strings.TrimPrefix(httpsServer.URL, "https://"),
host: httpsServerBaseURL.Host,
expected: "https",
},
{
uri: tcpServer.Addr().String(),
host: tcpServerBaseURL.Host,
expected: "tcp",
},
}

for _, tt := range tests {
res := checker.CheckProtocol(tt.uri)
service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host)
res := checker.CheckProtocol(service, servicePort)
require.Equal(t, tt.expected, res)
}

t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) {
service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host)
checker.CheckProtocol(service, servicePort)
// There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server.
// The HTTP Server does not generate a network hit. See above.
require.Equal(t, int32(2), totalNetworkHits.Load())
})
}

func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) {
host, portString, err := net.SplitHostPort(host)
require.NoError(t, err)
port, err := strconv.Atoi(portString)
require.NoError(t, err)
service := corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: serviceName,
Namespace: "default",
},
Spec: corev1.ServiceSpec{
Type: corev1.ServiceTypeExternalName,
ExternalName: host,
},
}
servicePort := corev1.ServicePort{Port: int32(port)}
return service, servicePort
}

func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener {
Expand Down Expand Up @@ -579,7 +635,7 @@ func TestAutoProtocolDetection(t *testing.T) {
port.AppProtocol = &tt.appProtocol
}

result := autoProtocolDetection("192.1.1.1", port, nil)
result := autoProtocolDetection(corev1.Service{}, port, nil)

require.Equal(t, tt.expected, result)
})
Expand Down
Loading