Skip to content

Commit

Permalink
feat: improve instrumentation for google cloud mTLS issues
Browse files Browse the repository at this point in the history
When deployed in google cloud, mTLS is off-loaded to the cloud
load balancer as it needs to terminate TLS. However, the
instrumentation around failures due to this off-load was not
great.

Moved the creation of the initial trace message to middleware.
This allows the middleware that processes the Google Cloud
proxy headers for mTLS to add lots of additional information to
the span.
  • Loading branch information
subnova committed Feb 19, 2024
1 parent 063c8ef commit 84bf26a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
43 changes: 43 additions & 0 deletions gateway/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,78 @@ package server
import (
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/go-chi/chi/v5"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"net/http"
"strconv"

semconv "go.opentelemetry.io/otel/semconv/v1.17.0"

"github.com/thoughtworks/maeve-csms/gateway/registry"
"golang.org/x/exp/slog"
)

func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Info("websocket connection received", "path", r.URL.Path, "method", r.Method)
slog.Info("processing connection", "uri", r.RequestURI)

newCtx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(
semconv.HTTPScheme(getScheme(r)),
semconv.HTTPMethod(r.Method),
semconv.HTTPURL(r.URL.String())))
defer span.End()

h.ServeHTTP(w, r.WithContext(newCtx))

routePattern := chi.RouteContext(r.Context()).RoutePattern()
if routePattern != "" {
span.SetName(fmt.Sprintf("%s %s", r.Method, routePattern))
} else {
span.SetStatus(codes.Error, "not found")
span.SetAttributes(semconv.HTTPStatusCode(http.StatusNotFound))
}
span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern()))
})
}
}

func TLSOffload(registry registry.DeviceRegistry) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
span := trace.SpanFromContext(r.Context())

forwardedProtoHeader := r.Header.Get("X-Forwarded-Proto")
span.SetAttributes(attribute.String("http.proto", forwardedProtoHeader))

if forwardedProtoHeader == "https" {
r.TLS = &tls.ConnectionState{
HandshakeComplete: true,
}

clientCertPresentHeader := r.Header.Get("X-Client-Cert-Present")
clientCertPresent, err := strconv.ParseBool(clientCertPresentHeader)
span.SetAttributes(attribute.Bool("cert.present", clientCertPresent))
if err == nil && clientCertPresent {
clientCertChainValidHeader := r.Header.Get("X-Client-Cert-Chain-Verified")
clientCertChainValid, err := strconv.ParseBool(clientCertChainValidHeader)
span.SetAttributes(attribute.Bool("cert.valid", clientCertChainValid))
if err == nil && clientCertChainValid {
clientCertHashHeader := r.Header.Get("X-Client-Cert-Hash")
span.SetAttributes(attribute.String("cert.hash", clientCertHashHeader))
certificate, err := registry.LookupCertificate(clientCertHashHeader)
if err == nil && certificate != nil {
r.TLS.PeerCertificates = []*x509.Certificate{certificate}
} else if err != nil {
span.SetAttributes(attribute.String("cert.lookup.error", err.Error()))
slog.Error("lookup certificate", "clientCertHashHeader", clientCertHashHeader, "err", err)
} else {
span.SetAttributes(attribute.String("cert.lookup.error", "NotFound"))
slog.Warn("certificate not found", "clientCertHashHeader", clientCertHashHeader)
}
}
Expand Down
17 changes: 6 additions & 11 deletions gateway/server/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func NewWebsocketHandler(opts ...WebsocketOpt) http.Handler {

r := chi.NewRouter()
r.Use(middleware.Recoverer)
r.Use(TraceRequest(s.tracer))
if s.trustProxyHeaders {
r.Use(TLSOffload(s.deviceRegistry))
}
Expand Down Expand Up @@ -177,13 +178,7 @@ func (s *WebsocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
slog.Info("websocket connection received", "path", r.URL.Path, "method", r.Method)
slog.Info("processing connection", "uri", r.RequestURI)

newCtx, span := s.tracer.Start(r.Context(), "GET /ws/{id}", trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(
semconv.HTTPScheme(getScheme(r)),
semconv.HTTPMethod("GET"),
semconv.HTTPURL(r.URL.String()),
semconv.HTTPRoute("/ws/{id}")))
defer span.End()
span := trace.SpanFromContext(r.Context())

clientId := chi.URLParam(r, "id")
if clientId == "" {
Expand Down Expand Up @@ -213,7 +208,7 @@ func (s *WebsocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

switch cs.SecurityProfile {
case registry.UnsecuredTransportWithBasicAuth:
if r.TLS != nil || !checkAuthorization(newCtx, r, cs) {
if r.TLS != nil || !checkAuthorization(r.Context(), r, cs) {
if r.TLS != nil {
span.SetAttributes(attribute.String("auth.failure_reason", "tls for unsecured transport"))
}
Expand All @@ -223,7 +218,7 @@ func (s *WebsocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
case registry.TLSWithBasicAuth:
if r.TLS == nil || !checkAuthorization(newCtx, r, cs) {
if r.TLS == nil || !checkAuthorization(r.Context(), r, cs) {
if r.TLS == nil {
span.SetAttributes(attribute.String("auth.failure_reason", "no tls for secured transport"))
}
Expand All @@ -233,7 +228,7 @@ func (s *WebsocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
case registry.TLSWithClientSideCertificates:
if r.TLS == nil || !checkCertificate(newCtx, r, s.orgNames, cs) {
if r.TLS == nil || !checkCertificate(r.Context(), r, s.orgNames, cs) {
if r.TLS == nil {
span.SetAttributes(attribute.String("auth.failure_reason", "no tls for secured transport"))
}
Expand Down Expand Up @@ -265,7 +260,7 @@ func (s *WebsocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.Start()
defer p.Close()

ctx, cancel := context.WithCancel(newCtx)
ctx, cancel := context.WithCancel(r.Context())
defer cancel()

mqttBrokerURLStrings := make([]string, len(s.mqttBrokerURLs))
Expand Down

0 comments on commit 84bf26a

Please sign in to comment.