Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert lib/srv/alpnproxy to use slog #50018

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {

// advance the fake clock and verify that the local proxy thinks its cert expired.
fakeClock.Advance(time.Hour * 48)
err = lp.CheckDBCert(routeToDatabase)
err = lp.CheckDBCert(context.Background(), routeToDatabase)
require.Error(t, err)
var x509Err x509.CertificateInvalidError
require.ErrorAs(t, err, &x509Err)
Expand Down
12 changes: 6 additions & 6 deletions lib/auth/authclient/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"log/slog"
"math"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
Expand Down Expand Up @@ -96,7 +96,7 @@ func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName str

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName string, logger *slog.Logger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
var clusterName string
var err error
Expand All @@ -105,14 +105,14 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
clusterName, err = apiutils.DecodeClusterName(info.ServerName)
if err != nil {
if !trace.IsNotFound(err) {
log.Debugf("Ignoring unsupported cluster name name %q.", info.ServerName)
logger.DebugContext(info.Context(), "Ignoring unsupported cluster name name", "cluster_name", info.ServerName)
clusterName = ""
}
}
}
pool, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
logger.ErrorContext(info.Context(), "Failed to retrieve client pool for cluster", "error", err, "cluster", clusterName)
// this falls back to the default config
return nil, nil
}
Expand All @@ -130,11 +130,11 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
// the current cluster CA. In the unlikely case where it's wrong, the
// client will be rejected.
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")
logger.DebugContext(info.Context(), "Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate")

pool, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
logger.ErrorContext(info.Context(), "Failed to retrieve client pool for cluster", "error", err, "cluster", currentClusterName)
// this falls back to the default config
return nil, nil
}
Expand Down
8 changes: 4 additions & 4 deletions lib/kube/grpc/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"log/slog"
"net"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -639,7 +639,7 @@ func initGRPCServer(t *testing.T, testCtx *kubeproxy.TestContext, listener net.L
AcceptedUsage: []string{teleport.UsageKubeOnly},
}

tlsConf := copyAndConfigureTLS(tlsConfig, logrus.New(), testCtx.AuthClient, clusterName)
tlsConf := copyAndConfigureTLS(tlsConfig, testCtx.AuthClient, clusterName)
creds, err := auth.NewTransportCredentials(auth.TransportCredentialsConfig{
TransportCredentials: credentials.NewTLS(tlsConf),
UserGetter: authMiddleware,
Expand Down Expand Up @@ -693,7 +693,7 @@ func initGRPCServer(t *testing.T, testCtx *kubeproxy.TestContext, listener net.L

// copyAndConfigureTLS can be used to copy and modify an existing *tls.Config
// for Teleport application proxy servers.
func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
func copyAndConfigureTLS(config *tls.Config, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
tlsConfig := config.Clone()

// Require clients to present a certificate
Expand All @@ -703,7 +703,7 @@ func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint
// client's certificate to verify the chain presented. If the client does not
// pass in the cluster name, this functions pulls back all CA to try and
// match the certificate presented against any CA.
tlsConfig.GetConfigForClient = authclient.WithClusterCAs(tlsConfig.Clone(), accessPoint, clusterName, log)
tlsConfig.GetConfigForClient = authclient.WithClusterCAs(tlsConfig.Clone(), accessPoint, clusterName, slog.Default())

return tlsConfig
}
Expand Down
4 changes: 3 additions & 1 deletion lib/kube/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package proxy
import (
"context"
"crypto/tls"
"log/slog"
"maps"
"net"
"net/http"
Expand Down Expand Up @@ -421,7 +422,8 @@ func (t *TLSServer) close(ctx context.Context) error {
// and server's GetConfigForClient reloads the list of trusted
// local and remote certificate authorities
func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
return authclient.WithClusterCAs(t.TLS, t.AccessPoint, t.ClusterName, t.log)(info)
// TODO(tross): remove slog.Default once the TLSServer is updated to use a slog.Logger
return authclient.WithClusterCAs(t.TLS, t.AccessPoint, t.ClusterName, slog.Default())(info)
}

// GetServerInfo returns a services.Server object for heartbeats (aka
Expand Down
5 changes: 2 additions & 3 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5273,7 +5273,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
clusterName,
utils.NetAddrsToStrings(process.Config.AuthServerAddresses()),
proxySigner,
process.log,
process.TracingProvider.Tracer(teleport.ComponentProxy))

alpnRouter.Add(alpnproxy.HandlerDecs{
Expand Down Expand Up @@ -6698,7 +6697,7 @@ func (process *TeleportProcess) initSecureGRPCServer(cfg initSecureGRPCServerCfg

tlsConf := serverTLSConfig.Clone()
tlsConf.NextProtos = []string{string(alpncommon.ProtocolHTTP2), string(alpncommon.ProtocolProxyGRPCSecure)}
tlsConf = copyAndConfigureTLS(tlsConf, process.log, cfg.accessPoint, clusterName)
tlsConf = copyAndConfigureTLS(tlsConf, process.logger, cfg.accessPoint, clusterName)
creds, err := auth.NewTransportCredentials(auth.TransportCredentialsConfig{
TransportCredentials: credentials.NewTLS(tlsConf),
UserGetter: authMiddleware,
Expand Down Expand Up @@ -6752,7 +6751,7 @@ type initSecureGRPCServerCfg struct {

// copyAndConfigureTLS can be used to copy and modify an existing *tls.Config
// for Teleport application proxy servers.
func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
func copyAndConfigureTLS(config *tls.Config, log *slog.Logger, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
tlsConfig := config.Clone()

// Require clients to present a certificate
Expand Down
5 changes: 1 addition & 4 deletions lib/srv/alpnproxy/auth/auth_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
oteltrace "go.opentelemetry.io/otel/trace"

Expand All @@ -44,13 +43,12 @@ type sitesGetter interface {
}

// NewAuthProxyDialerService create new instance of AuthProxyDialerService.
func NewAuthProxyDialerService(reverseTunnelServer sitesGetter, localClusterName string, authServers []string, proxySigner multiplexer.PROXYHeaderSigner, log logrus.FieldLogger, tracer oteltrace.Tracer) *AuthProxyDialerService {
func NewAuthProxyDialerService(reverseTunnelServer sitesGetter, localClusterName string, authServers []string, proxySigner multiplexer.PROXYHeaderSigner, tracer oteltrace.Tracer) *AuthProxyDialerService {
return &AuthProxyDialerService{
reverseTunnelServer: reverseTunnelServer,
localClusterName: localClusterName,
authServers: authServers,
proxySigner: proxySigner,
log: log,
tracer: tracer,
}
}
Expand All @@ -62,7 +60,6 @@ type AuthProxyDialerService struct {
localClusterName string
authServers []string
proxySigner multiplexer.PROXYHeaderSigner
log logrus.FieldLogger
tracer oteltrace.Tracer
}

Expand Down
6 changes: 3 additions & 3 deletions lib/srv/alpnproxy/auth/auth_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
)

func TestDialLocalAuthServerNoServers(t *testing.T) {
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", nil /* authServers */, nil, nil, tracing.NoopTracer("test"))
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", nil /* authServers */, nil, tracing.NoopTracer("test"))
_, err := s.dialLocalAuthServer(context.Background(), nil, nil)
require.Error(t, err, "dialLocalAuthServer expected to fail")
require.Equal(t, "empty auth servers list", err.Error())
Expand All @@ -40,7 +40,7 @@ func TestDialLocalAuthServerNoServers(t *testing.T) {
func TestDialLocalAuthServerNoAvailableServers(t *testing.T) {
// The 203.0.113.0/24 range is part of block TEST-NET-3 as defined in RFC-5735 (https://www.rfc-editor.org/rfc/rfc5735).
// IPs in this range do not appear on the public internet.
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", []string{"203.0.113.1:3025"}, nil, nil, tracing.NoopTracer("test"))
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", []string{"203.0.113.1:3025"}, nil, tracing.NoopTracer("test"))
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
t.Cleanup(cancel)
_, err := s.dialLocalAuthServer(ctx, nil, nil)
Expand All @@ -64,7 +64,7 @@ func TestDialLocalAuthServerAvailableServers(t *testing.T) {
// IPs in this range do not appear on the public internet.
authServers = append(authServers, fmt.Sprintf("203.0.113.%d:3025", i+1))
}
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", authServers, nil, nil, tracing.NoopTracer("test"))
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", authServers, nil, tracing.NoopTracer("test"))
require.Eventually(t, func() bool {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
t.Cleanup(cancel)
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/alpnproxy/auth_checker_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ package alpnproxy

import (
"crypto/subtle"
"log/slog"
"net/http"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
)
Expand All @@ -35,7 +35,7 @@ type AuthorizationCheckerMiddleware struct {
DefaultLocalProxyHTTPMiddleware

// Log is the Logger.
Log logrus.FieldLogger
Log *slog.Logger
// Secret is the expected value of a bearer token.
Secret string
}
Expand All @@ -45,7 +45,7 @@ var _ LocalProxyHTTPMiddleware = (*AuthorizationCheckerMiddleware)(nil)
// CheckAndSetDefaults checks configuration validity and sets defaults.
func (m *AuthorizationCheckerMiddleware) CheckAndSetDefaults() error {
if m.Log == nil {
m.Log = logrus.WithField(teleport.ComponentKey, "gcp")
m.Log = slog.With(teleport.ComponentKey, "authz")
}

if m.Secret == "" {
Expand All @@ -58,7 +58,7 @@ func (m *AuthorizationCheckerMiddleware) CheckAndSetDefaults() error {
func (m *AuthorizationCheckerMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
auth := req.Header.Get("Authorization")
if auth == "" {
m.Log.Debugf("No Authorization header present, ignoring request.")
m.Log.DebugContext(req.Context(), "No Authorization header present, ignoring request")
return false
}

Expand Down
20 changes: 10 additions & 10 deletions lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
package alpnproxy

import (
"log/slog"
"net/http"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
awsapiutils "github.com/gravitational/teleport/api/utils/aws"
Expand All @@ -43,7 +43,7 @@ type AWSAccessMiddleware struct {
// signature verification.
AWSCredentialsProvider aws.CredentialsProvider

Log logrus.FieldLogger
Log *slog.Logger

assumedRoles utils.SyncMap[string, *sts.AssumeRoleOutput]
}
Expand All @@ -52,7 +52,7 @@ var _ LocalProxyHTTPMiddleware = &AWSAccessMiddleware{}

func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
if m.Log == nil {
m.Log = logrus.WithField(teleport.ComponentKey, "aws_access")
m.Log = slog.With(teleport.ComponentKey, "aws_access")
}

if m.AWSCredentialsProvider == nil {
Expand Down Expand Up @@ -113,7 +113,7 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
sigV4, err := awsutils.ParseSigV4(req.Header.Get(awsutils.AuthorizationHeader))
if err != nil {
m.Log.WithError(err).Error("Failed to parse AWS request authorization header.")
m.Log.ErrorContext(req.Context(), "Failed to parse AWS request authorization header", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
}
Expand All @@ -135,7 +135,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re

func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool {
if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
}
Expand All @@ -150,12 +150,12 @@ func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter,
)

if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
}

m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
m.Log.DebugContext(req.Context(), "Rewriting headers for AWS request by assumed role", "assumed_role", aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Add a custom header for marking the special request.
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn))
Expand All @@ -178,7 +178,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error {

sigV4, err := awsutils.ParseSigV4(authHeader)
if err != nil {
m.Log.WithError(err).Error("Failed to parse AWS request authorization header.")
m.Log.ErrorContext(response.Request.Context(), "Failed to parse AWS request authorization header", "error", err)
return nil
}

Expand All @@ -205,13 +205,13 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error {
assumedRole, err := unmarshalAssumeRoleResponse(body)
if err != nil {
if !trace.IsNotFound(err) {
m.Log.Warnf("Failed to unmarshal AssumeRoleResponse: %v.", err)
m.Log.WarnContext(response.Request.Context(), "Failed to unmarshal AssumeRoleResponse", "error", err)
}
return nil
}

m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
m.Log.DebugContext(response.Request.Context(), "Saved credentials for assumed role", "assumed_role", aws.ToString(assumedRole.AssumedRoleUser.Arn))
return nil
}

Expand Down
12 changes: 6 additions & 6 deletions lib/srv/alpnproxy/azure_msi_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import (
"crypto"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"sync"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
Expand All @@ -49,7 +49,7 @@ type AzureMSIMiddleware struct {
// Clock is used to override time in tests.
Clock clockwork.Clock
// Log is the Logger.
Log logrus.FieldLogger
Log *slog.Logger
// Secret to be provided by the client.
Secret string

Expand All @@ -65,7 +65,7 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
m.Clock = clockwork.NewRealClock()
}
if m.Log == nil {
m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi")
m.Log = slog.With(teleport.ComponentKey, "azure_msi")
}

if m.Secret == "" {
Expand All @@ -86,7 +86,7 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
if req.Host == types.TeleportAzureMSIEndpoint {
if err := m.msiEndpoint(rw, req); err != nil {
m.Log.Warnf("Bad MSI request: %v", err)
m.Log.WarnContext(req.Context(), "Bad MSI request", "error", err)
trace.WriteError(rw, trace.Wrap(err))
}
return true
Expand Down Expand Up @@ -135,7 +135,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
// check that msi_res_id matches expected Azure Identity
requestedAzureIdentity := req.Form.Get("msi_res_id")
if requestedAzureIdentity != m.Identity {
m.Log.Warnf("Requested unexpected identity %q, expected %q", requestedAzureIdentity, m.Identity)
m.Log.WarnContext(req.Context(), "Requested unexpected identity", "requested_identity", requestedAzureIdentity, "expected_identity", m.Identity)
return trace.BadParameter("unexpected value for parameter 'msi_res_id': %v", requestedAzureIdentity)
}

Expand All @@ -144,7 +144,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
return trace.Wrap(err)
}

m.Log.Infof("MSI: returning token for identity %v", m.Identity)
m.Log.InfoContext(req.Context(), "MSI: returning token for identity", "identity", m.Identity)

rw.Header().Add("Content-Type", "application/json; charset=utf-8")
rw.Header().Add("Content-Length", fmt.Sprintf("%v", len(respBody)))
Expand Down
Loading
Loading