Skip to content

Commit

Permalink
convert db lib to use slog, part 1 (#43382) (#43749)
Browse files Browse the repository at this point in the history
* convert db lib to use slog, part 1

* fix typo

* fix build and add deprecate

* fix lint
  • Loading branch information
greedy52 authored Jul 2, 2024
1 parent 632f230 commit 21b1eaf
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 122 deletions.
60 changes: 30 additions & 30 deletions lib/srv/db/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package db
import (
"context"
"crypto/tls"
"log/slog"
"testing"
"time"

Expand Down Expand Up @@ -263,8 +264,8 @@ type testAuth struct {
// Auth is the wrapped "real" auth that handles everything except for
// cloud auth tokens generation.
realAuth common.Auth
// FieldLogger is used for logging.
logger logrus.FieldLogger
// Logger is used for logging.
*slog.Logger
}

func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
Expand All @@ -274,7 +275,7 @@ func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
}
return &testAuth{
realAuth: auth,
logger: logrus.WithField(teleport.ComponentKey, "auth:test"),
Logger: slog.With(teleport.ComponentKey, "auth:test"),
}, nil
}

Expand Down Expand Up @@ -310,34 +311,34 @@ const (
)

type fakeTokenSource struct {
logrus.FieldLogger
*slog.Logger

token string
exp time.Time
}

func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
f.Info("Generating Cloud Spanner auth token source")
f.InfoContext(context.Background(), "Generating Cloud Spanner auth token source")
return &oauth2.Token{
Expiry: f.exp,
AccessToken: f.token,
}, nil
}

func (a *testAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
Info("Generating RDS auth token")
a.InfoContext(ctx, "Generating RDS auth token.",
"database", database,
"database_user", databaseUser,
)
return rdsAuthToken, nil
}

func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
WithField("databaseName", databaseName).
Info("Generating Redshift auth token")
a.InfoContext(ctx, "Generating Redshift auth token",
"database", database,
"database_user", databaseUser,
"database_name", databaseName,
)
return redshiftAuthUser, redshiftAuthToken, nil
}

Expand All @@ -354,34 +355,32 @@ func (a *testAuth) GetMemoryDBToken(ctx context.Context, database types.Database
}

func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) {
a.logger.WithField("database_user", databaseUser).Info("Generating Cloud SQL auth token")
a.InfoContext(ctx, "Generating Cloud SQL auth token", "database_user", databaseUser)
return cloudSQLAuthToken, nil
}

func (a *testAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) {
return &fakeTokenSource{
token: cloudSpannerAuthToken,
FieldLogger: a.logger.WithField("database_user", databaseUser),
token: cloudSpannerAuthToken,
Logger: a.Logger.With("database_user", databaseUser),
}, nil
}

func (a *testAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating Cloud SQL password")
a.InfoContext(ctx, "Generating Cloud SQL password",
"database", database,
"database_user", databaseUser,
)
return cloudSQLPassword, nil
}

func (a *testAuth) GetAzureAccessToken(ctx context.Context) (string, error) {
a.logger.Info("Generating Azure access token")
a.InfoContext(ctx, "Generating Azure access token")
return azureAccessToken, nil
}

func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) {
a.logger.
WithField("database", database).
Info("Generating Azure Redis token")
a.InfoContext(ctx, "Generating Azure Redis token", "database", database)
return azureRedisToken, nil
}

Expand All @@ -398,17 +397,18 @@ func (a *testAuth) GetAzureIdentityResourceID(ctx context.Context, identityName
}

func (a *testAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating AWS IAM credentials")
a.InfoContext(ctx, "Generating AWS IAM credentials",
"database", database,
"database_user", databaseUser,
)
return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil
}

func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth {
// TODO(greedy52) update WithLogger to use slog.
return &testAuth{
realAuth: a.realAuth,
logger: getUpdatedLogger(a.logger),
Logger: a.Logger,
}
}

Expand Down
14 changes: 7 additions & 7 deletions lib/srv/db/ca.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *Server) startCARenewer(ctx context.Context) {
case <-schedule.Chan():
for _, database := range s.getProxiedDatabases() {
if err := s.initCACert(ctx, database); err != nil {
s.log.WithError(err).Errorf("Failed to renew database %q CA.", database.GetName())
s.log.ErrorContext(ctx, "Failed to renew database CA.", "db", database.GetName(), "error", err)
}
}
case <-ctx.Done():
Expand Down Expand Up @@ -154,7 +154,7 @@ func (s *Server) getCACert(ctx context.Context, database types.Database, filePat
}
// The update flow is going to create/update the cached CA, so we can read
// the contents from it.
s.log.Debugf("Loaded CA certificate %v.", filePath)
s.log.DebugContext(ctx, "Loaded CA certificate.", "path", filePath)
return os.ReadFile(filePath)
}

Expand Down Expand Up @@ -211,7 +211,7 @@ func (s *Server) getCACertPaths(database types.Database) ([]string, error) {
}

// saveCACert saves the downloaded certificate to the filesystem.
func (s *Server) saveCACert(filePath string, content []byte, version []byte) error {
func (s *Server) saveCACert(ctx context.Context, filePath string, content []byte, version []byte) error {
// Save CA contents.
err := os.WriteFile(filePath, content, teleport.FileMaskOwnerOnly)
if err != nil {
Expand All @@ -224,7 +224,7 @@ func (s *Server) saveCACert(filePath string, content []byte, version []byte) err
return trace.Wrap(err)
}

s.log.Debugf("Saved CA certificate %v.", filePath)
s.log.DebugContext(ctx, "Saved CA certificate.", "path", filePath)
return nil
}

Expand All @@ -250,7 +250,7 @@ func (s *Server) updateCACert(ctx context.Context, database types.Database, file
}

if equal {
s.log.Debugf("Database %q CA is up-to-date.", database.GetName())
s.log.DebugContext(ctx, "Database CA is up-to-date.", "db", database.GetName(), "type", database.GetType())
return nil
}

Expand All @@ -262,12 +262,12 @@ func (s *Server) updateCACert(ctx context.Context, database types.Database, file
}
}

err = s.saveCACert(filePath, contents, version)
err = s.saveCACert(ctx, filePath, contents, version)
if err != nil {
return trace.Wrap(err)
}

s.log.Infof("Database %q CA updated.", database.GetName())
s.log.InfoContext(ctx, "Database CA updated.", "db", database.GetName(), "type", database.GetType())
return nil
}

Expand Down
18 changes: 9 additions & 9 deletions lib/srv/db/mysql/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ package mysql
import (
"context"
"crypto/tls"
"log/slog"
"net"
"time"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/server"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
Expand All @@ -52,7 +52,7 @@ type Proxy struct {
// Service is used to connect to a remote database service.
Service common.Service
// Log is used for logging.
Log logrus.FieldLogger
Log *slog.Logger
// Limiter limits the number of active connections per client IP.
Limiter *limiter.Limiter
// IngressReporter reports new and active connections.
Expand Down Expand Up @@ -81,18 +81,18 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err
// has a chance to close the connection from its side.
defer func() {
if r := recover(); r != nil {
p.Log.Warnf("Recovered in MySQL proxy while handling connection from %v: %v.", clientConn.RemoteAddr(), r)
p.Log.WarnContext(ctx, "Recovered in MySQL proxy while handling connectionv.", "from", clientConn.RemoteAddr(), "to", r)
err = trace.BadParameter("failed to handle MySQL client connection")
}
if err != nil {
if writeErr := mysqlServer.WriteError(err); writeErr != nil {
p.Log.WithError(writeErr).Debugf("Failed to send error %q to MySQL client.", err)
p.Log.DebugContext(ctx, "Failed to send error to MySQL client.", "original_err", err.Error(), "error", writeErr)
}
}
}()
// Perform first part of the handshake, up to the point where client sends
// us certificate and connection upgrades to TLS.
tlsConn, err := p.performHandshake(conn, mysqlServer)
tlsConn, err := p.performHandshake(ctx, conn, mysqlServer)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func (p *Proxy) makeServer(clientConn net.Conn, serverVersion string) *server.Co
// performHandshake performs the initial handshake between MySQL client and
// this server, up to the point where the client sends us a certificate for
// authentication, and returns the upgraded connection.
func (p *Proxy) performHandshake(conn *multiplexer.Conn, server *server.Conn) (utils.TLSConn, error) {
func (p *Proxy) performHandshake(ctx context.Context, conn *multiplexer.Conn, server *server.Conn) (utils.TLSConn, error) {
// MySQL protocol is server-initiated which means the client will expect
// server to send initial handshake message.
err := server.WriteInitialHandshake()
Expand All @@ -198,7 +198,7 @@ func (p *Proxy) performHandshake(conn *multiplexer.Conn, server *server.Conn) (u
}
// See if we need to read the proxy-line which could happen if Teleport
// is running behind a load balancer with proxy protocol enabled.
err = p.maybeReadProxyLine(conn)
err = p.maybeReadProxyLine(ctx, conn)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -228,7 +228,7 @@ func (p *Proxy) performHandshake(conn *multiplexer.Conn, server *server.Conn) (u
// maybeReadProxyLine peeks into the connection to see if instead of regular
// MySQL protocol we were sent a proxy-line. This usually happens when Teleport
// is running behind a load balancer with proxy protocol enabled.
func (p *Proxy) maybeReadProxyLine(conn *multiplexer.Conn) error {
func (p *Proxy) maybeReadProxyLine(ctx context.Context, conn *multiplexer.Conn) error {
proto, err := conn.Detect()
if err != nil {
return trace.Wrap(err)
Expand All @@ -240,7 +240,7 @@ func (p *Proxy) maybeReadProxyLine(conn *multiplexer.Conn) error {
if err != nil {
return trace.Wrap(err)
}
p.Log.Debugf("MySQL listener proxy-line: %s.", proxyLine)
p.Log.DebugContext(ctx, "MySQL listener proxy-line.", "proxy_line", proxyLine)
return nil
}

Expand Down
10 changes: 5 additions & 5 deletions lib/srv/db/postgres/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ package postgres
import (
"context"
"crypto/tls"
"log/slog"
"net"

"github.com/gravitational/trace"
"github.com/jackc/pgproto3/v2"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/limiter"
Expand All @@ -46,7 +46,7 @@ type Proxy struct {
// Service is used to connect to a remote database service.
Service common.Service
// Log is used for logging.
Log logrus.FieldLogger
Log *slog.Logger
// Limiter limits the number of active connections per client IP.
Limiter *limiter.Limiter
// IngressReporter reports new and active connections.
Expand All @@ -67,7 +67,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err
}
if err := p.handleConnection(ctx, tlsConn, startupMessage); err != nil {
if serr := backend.Send(toErrorResponse(err)); serr != nil {
p.Log.WithError(serr).Warn("Failed to send error to backend.")
p.Log.WarnContext(ctx, "Failed to send error to backend.", "error", serr)
}
return trace.Wrap(err)
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func (p *Proxy) handleStartup(ctx context.Context, clientConn net.Conn) (pgproto
// We don't want to log the cancel request secret key, so we handle
// this case separately.
if m, ok := startupMessage.(*pgproto3.CancelRequest); ok {
p.Log.Debugf("Received cancel request for pid: %v.", m.ProcessID)
p.Log.DebugContext(ctx, "Received cancel request for pid.", "pid", m.ProcessID)
tlsConn, ok := clientConn.(utils.TLSConn)
if !ok {
return nil, nil, nil, trace.BadParameter(
Expand All @@ -156,7 +156,7 @@ func (p *Proxy) handleStartup(ctx context.Context, clientConn net.Conn) (pgproto
// user name, database name, etc.
//
// https://www.postgresql.org/docs/13/protocol-flow.html#id-1.10.5.7.11
p.Log.Debugf("Received startup message: %#v.", startupMessage)
p.Log.DebugContext(ctx, "Received startup message.", "message", startupMessage)
switch m := startupMessage.(type) {
case *pgproto3.SSLRequest:
if receivedSSLRequest {
Expand Down
Loading

0 comments on commit 21b1eaf

Please sign in to comment.