Skip to content

Commit

Permalink
Convert lib/auth/middleware to use slog (#50521)
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy authored Dec 21, 2024
1 parent db73df3 commit 0fcdc9d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
8 changes: 5 additions & 3 deletions lib/auth/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

var log = logrus.WithFields(logrus.Fields{
teleport.ComponentKey: teleport.ComponentAuth,
})
var (
log = logrus.WithField(teleport.ComponentKey, teleport.ComponentAuth)
logger = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentAuth)
)

// VersionStorage local storage for saving the version.
type VersionStorage interface {
Expand Down
50 changes: 31 additions & 19 deletions lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"os"
Expand All @@ -36,7 +37,6 @@ import (
"github.com/gravitational/trace"
grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"golang.org/x/net/http2"
"google.golang.org/grpc"
Expand All @@ -57,6 +57,7 @@ import (
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

const (
Expand Down Expand Up @@ -139,7 +140,7 @@ type TLSServer struct {
// cfg is TLS server configuration used for auth server
cfg TLSServerConfig
// log is TLS server logging entry
log *logrus.Entry
log *slog.Logger
// mux is a listener that multiplexes HTTP/2 and HTTP/1.1
// on different listeners
mux *multiplexer.TLSListener
Expand Down Expand Up @@ -215,9 +216,7 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error)
return authz.ContextWithConn(ctx, c)
},
},
log: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: cfg.Component,
}),
log: slog.With(teleport.ComponentKey, cfg.Component),
}

tlsConfig := cfg.TLS.Clone()
Expand Down Expand Up @@ -306,7 +305,7 @@ func (t *TLSServer) Serve() error {
errC := make(chan error, 2)
go func() {
err := t.mux.Serve()
t.log.WithError(err).Warningf("Mux serve failed.")
t.log.WarnContext(context.Background(), "Mux serve failed", "error", err)
}()
go func() {
errC <- t.httpServer.Serve(t.mux.HTTP())
Expand Down Expand Up @@ -372,7 +371,9 @@ func getCustomRate(endpoint string) *limiter.RateSet {
rates := limiter.NewRateSet()
// This limit means: 1 request per minute with bursts up to 10 requests.
if err := rates.Add(time.Minute, 1, 10); err != nil {
log.WithError(err).Debugf("Failed to define a custom rate for rpc method %q, using default rate", endpoint)
logger.DebugContext(context.Background(), "Failed to define a custom rate for rpc method, using default rate",
"error", err,
"rpc_method", endpoint)
return nil
}
return rates
Expand All @@ -383,7 +384,10 @@ func getCustomRate(endpoint string) *limiter.RateSet {
const burst = defaults.LimiterBurst
rates := limiter.NewRateSet()
if err := rates.Add(period, average, burst); err != nil {
log.WithError(err).Debugf("Failed to define a custom rate for rpc method %q, using default rate", endpoint)
logger.DebugContext(context.Background(), "Failed to define a custom rate for rpc method, using default rate",
"error", err,
"rpc_method", endpoint,
)
return nil
}
return rates
Expand All @@ -406,24 +410,29 @@ func (a *Middleware) ValidateClientVersion(ctx context.Context, info IdentityInf

ua := metadata.UserAgentFromContext(ctx)

logger := log.WithFields(logrus.Fields{"user_agent": ua, "identity": info.IdentityGetter.GetIdentity().Username, "version": clientVersionString, "addr": info.Conn.RemoteAddr().String()})
logger := slog.With(
"user_agent", ua,
"identity", info.IdentityGetter.GetIdentity().Username,
"version", clientVersionString,
"addr", logutils.StringerAttr(info.Conn.RemoteAddr()),
)
clientVersion, err := semver.NewVersion(clientVersionString)
if err != nil {
logger.WithError(err).Warn("Failed to determine client version")
logger.WarnContext(ctx, "Failed to determine client version", "error", err)
a.displayRejectedClientAlert(ctx, clientVersionString, info.Conn.RemoteAddr(), ua, info.IdentityGetter)
if err := info.Conn.Close(); err != nil {
logger.WithError(err).Warn("Failed to close client connection")
logger.WarnContext(ctx, "Failed to close client connection", "error", err)
}

return trace.AccessDenied("client version is unsupported")
}

if clientVersion.LessThan(*a.OldestSupportedVersion) {
logger.Info("Terminating connection of client using unsupported version")
logger.InfoContext(ctx, "Terminating connection of client using unsupported version")
a.displayRejectedClientAlert(ctx, clientVersionString, info.Conn.RemoteAddr(), ua, info.IdentityGetter)

if err := info.Conn.Close(); err != nil {
logger.WithError(err).Warn("Failed to close client connection")
logger.WarnContext(ctx, "Failed to close client connection", "error", err)
}

return trace.AccessDenied("client version is unsupported")
Expand Down Expand Up @@ -486,12 +495,12 @@ func (a *Middleware) displayRejectedClientAlert(ctx context.Context, clientVersi
types.WithAlertLabel(types.AlertVerbPermit, fmt.Sprintf("%s:%s", types.KindToken, types.VerbCreate)),
)
if err != nil {
log.WithError(err).Warn("failed to create rejected-unsupported-connection alert")
logger.WarnContext(ctx, "failed to create rejected-unsupported-connection alert", "error", err)
return
}

if err := a.AlertCreator(ctx, alert); err != nil {
log.WithError(err).Warn("failed to persist rejected-unsupported-connection alert")
logger.WarnContext(ctx, "failed to persist rejected-unsupported-connection alert", "error", err)
return
}
}
Expand Down Expand Up @@ -656,7 +665,7 @@ func (a *Middleware) GetUser(connState tls.ConnectionState) (authz.IdentityGette
if certClusterName == "" {
certClusterName, err = tlsca.ClusterName(clientCert.Issuer)
if err != nil {
log.Warnf("Failed to parse client certificate %v.", err)
logger.WarnContext(context.Background(), "Failed to parse client certificate", "error", err)
return nil, trace.AccessDenied("access denied: invalid client certificate")
}
identity.TeleportCluster = certClusterName
Expand All @@ -667,8 +676,11 @@ func (a *Middleware) GetUser(connState tls.ConnectionState) (authz.IdentityGette
// against auth server. Later on we can extend more
// advanced cert usage, but for now this is the safest option.
if len(identity.Usage) != 0 && !slices.Equal(a.AcceptedUsage, identity.Usage) {
log.Warningf("Restricted certificate of user %q with usage %v rejected while accessing the auth endpoint with acceptable usage %v.",
identity.Username, identity.Usage, a.AcceptedUsage)
logger.WarnContext(context.Background(), "Restricted certificate rejected while accessing the auth endpoint",
"user", identity.Username,
"cert_usage", identity.Usage,
"acceptable_usage", a.AcceptedUsage,
)
return nil, trace.AccessDenied("access denied: invalid client certificate")
}

Expand Down Expand Up @@ -734,7 +746,7 @@ func extractAdditionalSystemRoles(roles []string) types.SystemRoles {
if err != nil {
// ignore unknown system roles rather than rejecting them, since new unknown system
// roles may be present on certs if we rolled back from a newer version.
log.Warnf("Ignoring unknown system role: %q", role)
logger.WarnContext(context.Background(), "Ignoring unknown system role", "unknown_role", role)
continue
}
systemRoles = append(systemRoles, systemRole)
Expand Down

0 comments on commit 0fcdc9d

Please sign in to comment.