Skip to content

Commit

Permalink
Cache Kubernetes App Discovery port check results
Browse files Browse the repository at this point in the history
For the K8S Services that we couldn't auto detect the protocol, after
trying to infer the port, cache the result.
Cache is evicted when the K8S Service changes (we check the Service's
ResourceVersion).
  • Loading branch information
marcoandredinis committed Dec 16, 2024
1 parent d5c6911 commit 341a968
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 21 deletions.
2 changes: 1 addition & 1 deletion lib/srv/discovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ func newMockKubeService(name, namespace, externalName string, labels, annotation
type noopProtocolChecker struct{}

// CheckProtocol for noopProtocolChecker just returns 'tcp'
func (*noopProtocolChecker) CheckProtocol(uri string) string {
func (*noopProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string {
return "tcp"
}

Expand Down
61 changes: 49 additions & 12 deletions lib/srv/discovery/fetchers/kube_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package fetchers
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"slices"
Expand Down Expand Up @@ -194,7 +195,7 @@ func (f *KubeAppFetcher) Get(ctx context.Context) (types.ResourcesWithLabels, er
case protoHTTPS, protoHTTP, protoTCP:
portProtocols[port] = protocolAnnotation
default:
if p := autoProtocolDetection(services.GetServiceFQDN(service), port, f.ProtocolChecker); p != protoTCP {
if p := autoProtocolDetection(service, port, f.ProtocolChecker); p != protoTCP {
portProtocols[port] = p
}
}
Expand Down Expand Up @@ -259,7 +260,7 @@ func (f *KubeAppFetcher) String() string {
// - If port's name is `http` or number is 80 or 8080, we return `http`
// - If protocol checker is available it will perform HTTP request to the service fqdn trying to find out protocol. If it
// gives us result `http` or `https` we return it
func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolChecker) string {
func autoProtocolDetection(service v1.Service, port v1.ServicePort, pc ProtocolChecker) string {
if port.AppProtocol != nil {
switch p := strings.ToLower(*port.AppProtocol); p {
case protoHTTP, protoHTTPS:
Expand All @@ -281,8 +282,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC
}

if pc != nil {
result := pc.CheckProtocol(fmt.Sprintf("%s:%d", serviceFQDN, port.Port))
if result != protoTCP {
if result := pc.CheckProtocol(service, port); result != protoTCP {
return result
}
}
Expand All @@ -292,7 +292,7 @@ func autoProtocolDetection(serviceFQDN string, port v1.ServicePort, pc ProtocolC

// ProtocolChecker is an interface used to check what protocol uri serves
type ProtocolChecker interface {
CheckProtocol(uri string) string
CheckProtocol(service v1.Service, port v1.ServicePort) string
}

func getServicePorts(s v1.Service) ([]v1.ServicePort, error) {
Expand Down Expand Up @@ -326,6 +326,21 @@ func getServicePorts(s v1.Service) ([]v1.ServicePort, error) {
type ProtoChecker struct {
InsecureSkipVerify bool
client *http.Client

// cacheKubernetesServiceProtocol maps a Kubernetes Service Namespace/Name to a tuple containing the Service's ResourceVersion and the Protocol.
// When the Kubernetes Service ResourceVersion changes, then we assume the protocol might've changed as well, so the cache is invalidated.
// Only protocol checkers that require a network connection are cached.
cacheKubernetesServiceProtocol map[kubernetesNameNamespace]appResourceVersionProtocol
}

type appResourceVersionProtocol struct {
resourceVersion string
protocol string
}

type kubernetesNameNamespace struct {
namespace string
name string
}

func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker {
Expand All @@ -341,23 +356,45 @@ func NewProtoChecker(insecureSkipVerify bool) *ProtoChecker {
},
},
},
cacheKubernetesServiceProtocol: make(map[kubernetesNameNamespace]appResourceVersionProtocol),
}

return p
}

func (p *ProtoChecker) CheckProtocol(uri string) string {
func (p *ProtoChecker) CheckProtocol(service v1.Service, port v1.ServicePort) string {
if p.client == nil {
return protoTCP
}

resp, err := p.client.Head(fmt.Sprintf("https://%s", uri))
if err == nil {
key := kubernetesNameNamespace{namespace: service.Namespace, name: service.Name}
if versionProtocol, ok := p.cacheKubernetesServiceProtocol[key]; ok {
if versionProtocol.resourceVersion == service.ResourceVersion {
return versionProtocol.protocol
}
}

var result string

uri := fmt.Sprintf("https://%s:%d", services.GetServiceFQDN(service), port.Port)
resp, err := p.client.Head(uri)
switch {
case err == nil:
result = protoHTTPS
_ = resp.Body.Close()
return protoHTTPS
} else if strings.Contains(err.Error(), "server gave HTTP response to HTTPS client") {
return protoHTTP

case errors.Is(err, http.ErrSchemeMismatch):
result = protoHTTP

default:
result = protoTCP

}

return protoTCP
p.cacheKubernetesServiceProtocol[key] = appResourceVersionProtocol{
resourceVersion: service.ResourceVersion,
protocol: result,
}

return result
}
63 changes: 55 additions & 8 deletions lib/srv/discovery/fetchers/kube_services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strconv"
"strings"
"sync/atomic"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -36,6 +39,7 @@ import (
"k8s.io/client-go/kubernetes/fake"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -305,7 +309,8 @@ type mockProtocolChecker struct {
results map[string]string
}

func (m *mockProtocolChecker) CheckProtocol(uri string) string {
func (m *mockProtocolChecker) CheckProtocol(service corev1.Service, port corev1.ServicePort) string {
uri := fmt.Sprintf("%s:%d", services.GetServiceFQDN(service), port.Port)
if result, ok := m.results[uri]; ok {
return result
}
Expand Down Expand Up @@ -455,44 +460,86 @@ func TestProtoChecker_CheckProtocol(t *testing.T) {
t.Parallel()
checker := NewProtoChecker(true)

totalNetworkHits := &atomic.Int32{}

httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
totalNetworkHits.Add(1)
_, _ = fmt.Fprintln(w, "httpsServer")
}))
httpsServerBaseURL, err := url.Parse(httpsServer.URL)
require.NoError(t, err)

httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintln(w, "httpServer")
// this never gets called because the HTTP server will not accept the HTTPS request.
}))
httpServerBaseURL, err := url.Parse(httpServer.URL)
require.NoError(t, err)

tcpServer := newTCPServer(t, func(conn net.Conn) {
totalNetworkHits.Add(1)
_, _ = conn.Write([]byte("tcpServer"))
_ = conn.Close()
})
tcpServerBaseURL := &url.URL{
Host: tcpServer.Addr().String(),
}

t.Cleanup(func() {
httpsServer.Close()
httpServer.Close()
_ = tcpServer.Close()
})

tests := []struct {
uri string
host string
expected string
}{
{
uri: strings.TrimPrefix(httpServer.URL, "http://"),
host: httpServerBaseURL.Host,
expected: "http",
},
{
uri: strings.TrimPrefix(httpsServer.URL, "https://"),
host: httpsServerBaseURL.Host,
expected: "https",
},
{
uri: tcpServer.Addr().String(),
host: tcpServerBaseURL.Host,
expected: "tcp",
},
}

for _, tt := range tests {
res := checker.CheckProtocol(tt.uri)
service, servicePort := createServiceAndServicePort(t, tt.expected, tt.host)
res := checker.CheckProtocol(service, servicePort)
require.Equal(t, tt.expected, res)
}

t.Run("caching prevents more than 1 network request to the same service", func(t *testing.T) {
service, servicePort := createServiceAndServicePort(t, "https", httpsServerBaseURL.Host)
checker.CheckProtocol(service, servicePort)
// There can only be two hits recorded: one for the HTTPS Server and another one for the TCP Server.
// The HTTP Server does not generate a network hit. See above.
require.Equal(t, int32(2), totalNetworkHits.Load())
})
}

func createServiceAndServicePort(t *testing.T, serviceName, host string) (corev1.Service, corev1.ServicePort) {
host, portString, err := net.SplitHostPort(host)
require.NoError(t, err)
port, err := strconv.Atoi(portString)
require.NoError(t, err)
service := corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: serviceName,
Namespace: "default",
},
Spec: corev1.ServiceSpec{
Type: corev1.ServiceTypeExternalName,
ExternalName: host,
},
}
servicePort := corev1.ServicePort{Port: int32(port)}
return service, servicePort
}

func newTCPServer(t *testing.T, handleConn func(net.Conn)) net.Listener {
Expand Down Expand Up @@ -579,7 +626,7 @@ func TestAutoProtocolDetection(t *testing.T) {
port.AppProtocol = &tt.appProtocol
}

result := autoProtocolDetection("192.1.1.1", port, nil)
result := autoProtocolDetection(corev1.Service{}, port, nil)

require.Equal(t, tt.expected, result)
})
Expand Down

0 comments on commit 341a968

Please sign in to comment.