Skip to content

Commit

Permalink
Convert lib/srv/db to use slog
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Dec 16, 2024
1 parent 0644b99 commit 4eed9f0
Show file tree
Hide file tree
Showing 31 changed files with 338 additions and 315 deletions.
2 changes: 1 addition & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2611,7 +2611,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p a
for _, db := range p.Databases {
select {
case sender := <-inventoryHandle.Sender():
dbServer, err := server.getServerInfo(db)
dbServer, err := server.getServerInfo(ctx, db)
require.NoError(t, err)
require.NoError(t, sender.Send(ctx, proto.InventoryHeartbeat{
DatabaseServer: dbServer,
Expand Down
4 changes: 1 addition & 3 deletions lib/srv/db/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/memorydb"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

Expand Down Expand Up @@ -411,8 +410,7 @@ func (a *testAuth) GenerateDatabaseClientKey(ctx context.Context) (*keys.Private
return key, trace.Wrap(err)
}

func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth {
// TODO(greedy52) update WithLogger to use slog.
func (a *testAuth) WithLogger(getUpdatedLogger func(*slog.Logger) *slog.Logger) common.Auth {
return &testAuth{
realAuth: a.realAuth,
Logger: a.Logger,
Expand Down
15 changes: 8 additions & 7 deletions lib/srv/db/cassandra/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"bytes"
"context"
"crypto/tls"
"log/slog"
"net"
"strings"
"time"
Expand All @@ -33,11 +34,11 @@ import (
"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/gocql/gocql"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/utils"
)

// Session alias for easier use.
Expand Down Expand Up @@ -97,7 +98,7 @@ type TestServer struct {
cfg common.TestServerConfig
port string
tlsConfig *tls.Config
log logrus.FieldLogger
logger *slog.Logger
server *client.CqlServer
}

Expand Down Expand Up @@ -142,10 +143,10 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (*T
port: port,
tlsConfig: tlsConfig,
server: server,
log: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: defaults.ProtocolCassandra,
"name": config.Name,
}),
logger: utils.NewSlogLoggerForTests().With(
teleport.ComponentKey, defaults.ProtocolCassandra,
"name", config.Name,
),
}
for _, opt := range opts {
opt(testServer)
Expand Down Expand Up @@ -336,7 +337,7 @@ func handleMessageBatch(request *frame.Frame, conn *client.CqlServerConnection,
}
responseFrame, err := codec.ConvertFromRawFrame(resp)
if err != nil {
logrus.Errorf("Error converting raw frame to frame: %v", err)
slog.ErrorContext(context.Background(), "Error converting raw frame to frame", "error", err)
return nil
}
return responseFrame
Expand Down
21 changes: 11 additions & 10 deletions lib/srv/db/clickhouse/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"crypto/tls"
"database/sql"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptest"
Expand All @@ -32,11 +33,11 @@ import (
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/utils"
)

// TestServerOption allows setting test server options.
Expand All @@ -63,7 +64,7 @@ type TestServer struct {
listener net.Listener
port string
tlsConfig *tls.Config
log logrus.FieldLogger
logger *slog.Logger
protocol string
}

Expand Down Expand Up @@ -92,10 +93,10 @@ func NewTestServer(config common.TestServerConfig, opts ...TestServerOption) (*T
listener: config.Listener,
port: port,
tlsConfig: tlsConfig,
log: logrus.WithFields(logrus.Fields{
teleport.ComponentKey: defaults.ProtocolClickHouse,
"name": config.Name,
}),
logger: utils.NewSlogLoggerForTests().With(
teleport.ComponentKey, defaults.ProtocolClickHouse,
"name", config.Name,
),
}

for _, opt := range opts {
Expand Down Expand Up @@ -182,27 +183,27 @@ func (s *TestServer) serveHTTP() error {
mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
buff, err := io.ReadAll(request.Body)
if err != nil {
s.log.Errorf("Got unexpected error %q", err)
s.logger.ErrorContext(request.Context(), "Got unexpected error", "error", err)
}
defer request.Body.Close()

query := string(buff)
enc, ok := encHandler[query]
if !ok {
s.log.Errorf("Got unexpected query %q", query)
s.logger.ErrorContext(request.Context(), "Got unexpected query", "query", query)
writer.WriteHeader(http.StatusInternalServerError)
return
}
respBuff, err := enc()
if err != nil {
s.log.Errorf("Got unexpected error: %v", err)
s.logger.ErrorContext(request.Context(), "Got unexpected error", "error", err)
writer.WriteHeader(http.StatusInternalServerError)
return
}

_, err = writer.Write(respBuff)
if err != nil {
s.log.Errorf("Got unexpected error: %v", err)
s.logger.ErrorContext(request.Context(), "Got unexpected error", "error", err)
writer.WriteHeader(http.StatusInternalServerError)
return
}
Expand Down
47 changes: 26 additions & 21 deletions lib/srv/db/cloud/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ package cloud
import (
"context"
"encoding/json"
"log/slog"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -71,10 +71,10 @@ func newAWS(ctx context.Context, config awsConfig) (*awsClient, error) {
return nil, trace.Wrap(err)
}

logger := logrus.WithFields(logrus.Fields{
teleport.ComponentKey: "aws",
"db": config.database.GetName(),
})
logger := slog.With(
teleport.ComponentKey, "aws",
"db", config.database.GetName(),
)
dbConfigurator, err := getDBConfigurator(ctx, logger, config.clients, config.database)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -92,7 +92,7 @@ func newAWS(ctx context.Context, config awsConfig) (*awsClient, error) {
cfg: config,
dbConfigurator: dbConfigurator,
iam: iam,
log: logger,
logger: logger,
}, nil
}

Expand All @@ -102,10 +102,10 @@ type dbIAMAuthConfigurator interface {
}

// getDBConfigurator returns a database IAM Auth configurator.
func getDBConfigurator(ctx context.Context, log logrus.FieldLogger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) {
func getDBConfigurator(ctx context.Context, logger *slog.Logger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) {
if db.IsRDS() {
// Only setting for RDS instances and Aurora clusters.
return &rdsDBConfigurator{clients: clients, log: log}, nil
return &rdsDBConfigurator{clients: clients, logger: logger}, nil
}
// IAM Auth for Redshift, ElastiCache, and RDS Proxy is always enabled.
return &nopDBConfigurator{}, nil
Expand All @@ -115,14 +115,14 @@ type awsClient struct {
cfg awsConfig
dbConfigurator dbIAMAuthConfigurator
iam iamiface.IAMAPI
log logrus.FieldLogger
logger *slog.Logger
}

// setupIAMAuth ensures the IAM Authentication is enbaled for RDS, Aurora, ElastiCache or Redshift database.
func (r *awsClient) setupIAMAuth(ctx context.Context) error {
if err := r.dbConfigurator.ensureIAMAuth(ctx, r.cfg.database); err != nil {
if trace.IsAccessDenied(err) { // Permission errors are expected.
r.log.Debugf("No permissions to enable IAM auth: %v.", err)
r.logger.DebugContext(ctx, "No permissions to enable IAM auth", "error", err)
return nil
}
return trace.Wrap(err)
Expand All @@ -137,7 +137,7 @@ func (r *awsClient) setupIAMAuth(ctx context.Context) error {
func (r *awsClient) setupIAMPolicy(ctx context.Context) (bool, error) {
if err := r.ensureIAMPolicy(ctx); err != nil {
if trace.IsAccessDenied(err) { // Permission errors are expected.
r.log.Debugf("No permissions to ensure IAM policy: %v.", err)
r.logger.DebugContext(ctx, "No permissions to ensure IAM policy", "error", err)
return false, nil
}

Expand All @@ -152,7 +152,7 @@ func (r *awsClient) teardownIAM(ctx context.Context) error {
var errors []error
if err := r.deleteIAMPolicy(ctx); err != nil {
if trace.IsAccessDenied(err) { // Permission errors are expected.
r.log.Debugf("No permissions to delete IAM policy: %v.", err)
r.logger.DebugContext(ctx, "No permissions to delete IAM policy", "error", err)
} else {
errors = append(errors, err)
}
Expand All @@ -174,10 +174,13 @@ func (r *awsClient) ensureIAMPolicy(ctx context.Context) error {
var changed bool
dbIAM.ForEach(func(effect, action, resource string, conditions awslib.Conditions) {
if policy.EnsureResourceAction(effect, action, resource, conditions) {
r.log.Debugf("Adding permission %q for %q to policy.", action, resource)
r.logger.DebugContext(ctx, "Adding database permission to policy",
"action", action,
"resource", resource,
)
changed = true
} else {
r.log.Debugf("Permission %q for %q is already part of policy.", action, resource)
r.logger.DebugContext(ctx, "Permission is already part of policy", "action", action, "resource", resource)
}
})
if !changed {
Expand All @@ -189,8 +192,10 @@ func (r *awsClient) ensureIAMPolicy(ctx context.Context) error {
}

if len(placeholders) > 0 {
r.log.Warnf("Please make sure the database agent has the IAM permissions to fetch cloud metadata, or make sure these values are set in the static config. Placeholders %q are found when configuring the IAM policy for database %v.",
placeholders, r.cfg.database.GetName())
r.logger.WarnContext(ctx, "Please make sure the database agent has the IAM permissions to fetch cloud metadata, or make sure these values are set in the static config. Placeholders were found when configuring the IAM policy for database.",
"placeholders", placeholders,
"database", r.cfg.database.GetName(),
)
}
return nil
}
Expand Down Expand Up @@ -252,7 +257,7 @@ func (r *awsClient) getIAMPolicy(ctx context.Context) (*awslib.PolicyDocument, e

// updateIAMPolicy attaches IAM access policy to the identity this agent is running as.
func (r *awsClient) updateIAMPolicy(ctx context.Context, policy *awslib.PolicyDocument) error {
r.log.Debugf("Updating IAM policy for %v.", r.cfg.identity)
r.logger.DebugContext(ctx, "Updating IAM policy", "identity", r.cfg.identity)
document, err := json.Marshal(policy)
if err != nil {
return trace.Wrap(err)
Expand All @@ -278,7 +283,7 @@ func (r *awsClient) updateIAMPolicy(ctx context.Context, policy *awslib.PolicyDo

// detachIAMPolicy detaches IAM access policy from the identity this agent is running as.
func (r *awsClient) detachIAMPolicy(ctx context.Context) error {
r.log.Debugf("Detaching IAM policy from %v.", r.cfg.identity)
r.logger.DebugContext(ctx, "Detaching IAM policy", "identity", r.cfg.identity)
var err error
switch r.cfg.identity.(type) {
case awslib.Role:
Expand All @@ -299,13 +304,13 @@ func (r *awsClient) detachIAMPolicy(ctx context.Context) error {

type rdsDBConfigurator struct {
clients cloud.Clients
log logrus.FieldLogger
logger *slog.Logger
}

// ensureIAMAuth enables RDS instance IAM auth if it isn't already enabled.
func (r *rdsDBConfigurator) ensureIAMAuth(ctx context.Context, db types.Database) error {
if db.GetAWS().RDS.IAMAuth {
r.log.Debug("IAM auth already enabled.")
r.logger.DebugContext(ctx, "IAM auth already enabled")
return nil
}
if err := r.enableIAMAuth(ctx, db); err != nil {
Expand All @@ -316,7 +321,7 @@ func (r *rdsDBConfigurator) ensureIAMAuth(ctx context.Context, db types.Database

// enableIAMAuth turns on IAM auth setting on the RDS instance.
func (r *rdsDBConfigurator) enableIAMAuth(ctx context.Context, db types.Database) error {
r.log.Debug("Enabling IAM auth for RDS.")
r.logger.DebugContext(ctx, "Enabling IAM auth for RDS")
meta := db.GetAWS()
rdsClt, err := r.clients.GetAWSRDSClient(ctx, meta.Region,
cloud.WithAssumeRoleFromAWSMeta(meta),
Expand Down
Loading

0 comments on commit 4eed9f0

Please sign in to comment.