From e90d4732a5efeb36109a69ffe1659097425e21e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Skrz=C4=99tnicki?= Date: Thu, 27 Jun 2024 20:07:19 +0200 Subject: [PATCH] [v16] Database Access: decouple Auth from Session (#43527) * Database Access: decouple Auth from Session, add SessionAuth wrapper, share CloudClients instance, update tests. * Drop `SessionAuth` interface, shift usages over to updated `Auth` interface * Drop unused params --- lib/srv/db/auth_test.go | 94 +++++--- lib/srv/db/cassandra/engine.go | 2 +- lib/srv/db/clickhouse/engine_http.go | 3 +- lib/srv/db/clickhouse/engine_native.go | 2 +- lib/srv/db/common/auth.go | 313 ++++++++++++++----------- lib/srv/db/common/auth_test.go | 114 ++++----- lib/srv/db/common/session.go | 6 + lib/srv/db/dynamodb/engine.go | 2 +- lib/srv/db/elasticsearch/engine.go | 2 +- lib/srv/db/mongodb/autousers_admin.go | 2 +- lib/srv/db/mongodb/connect.go | 4 +- lib/srv/db/mysql/engine.go | 6 +- lib/srv/db/mysql/gcp.go | 4 +- lib/srv/db/mysql/gcp_test.go | 40 ++-- lib/srv/db/opensearch/engine.go | 2 +- lib/srv/db/postgres/engine.go | 14 +- lib/srv/db/redis/client.go | 6 +- lib/srv/db/redis/engine.go | 2 +- lib/srv/db/server.go | 19 +- lib/srv/db/spanner/grpcserver.go | 4 +- lib/srv/db/sqlserver/connect.go | 4 +- lib/srv/db/sqlserver/engine_test.go | 3 +- 22 files changed, 358 insertions(+), 290 deletions(-) diff --git a/lib/srv/db/auth_test.go b/lib/srv/db/auth_test.go index f3d827aa47df2..994839e7e2c54 100644 --- a/lib/srv/db/auth_test.go +++ b/lib/srv/db/auth_test.go @@ -20,6 +20,7 @@ package db import ( "context" + "crypto/tls" "testing" "time" @@ -261,9 +262,9 @@ func TestAuthTokens(t *testing.T) { type testAuth struct { // Auth is the wrapped "real" auth that handles everything except for // cloud auth tokens generation. - common.Auth + realAuth common.Auth // FieldLogger is used for logging. - logrus.FieldLogger + logger logrus.FieldLogger } func newTestAuth(ac common.AuthConfig) (*testAuth, error) { @@ -272,11 +273,13 @@ func newTestAuth(ac common.AuthConfig) (*testAuth, error) { return nil, trace.Wrap(err) } return &testAuth{ - Auth: auth, - FieldLogger: logrus.WithField(teleport.ComponentKey, "auth:test"), + realAuth: auth, + logger: logrus.WithField(teleport.ComponentKey, "auth:test"), }, nil } +var _ common.Auth = (*testAuth)(nil) + const ( // rdsAuthToken is a mock RDS IAM auth token. rdsAuthToken = "rds-auth-token" @@ -314,76 +317,101 @@ type fakeTokenSource struct { } func (f *fakeTokenSource) Token() (*oauth2.Token, error) { - f.Infof("Generating Cloud Spanner auth token source") + f.Info("Generating Cloud Spanner auth token source") return &oauth2.Token{ Expiry: f.exp, AccessToken: f.token, }, nil } -// GetRDSAuthToken generates RDS/Aurora auth token. -func (a *testAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating RDS auth token for %v.", sessionCtx) +func (a *testAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { + a.logger. + WithField("database", database). + WithField("databaseUser", databaseUser). + Info("Generating RDS auth token") return rdsAuthToken, nil } -// GetRedshiftAuthToken generates Redshift auth token. -func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) { - a.Infof("Generating Redshift auth token for %v.", sessionCtx) +func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + a.logger. + WithField("database", database). + WithField("databaseUser", databaseUser). + WithField("databaseName", databaseName). + Info("Generating Redshift auth token") return redshiftAuthUser, redshiftAuthToken, nil } -func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) { +func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { return "", "", trace.NotImplemented("GetRedshiftServerlessAuthToken is not implemented") } -func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) { +func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { return elastiCacheRedisToken, nil } -func (a *testAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *common.Session) (string, error) { +func (a *testAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { return memorydbToken, nil } -// GetCloudSQLAuthToken generates Cloud SQL auth token. -func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Cloud SQL auth token for %v.", sessionCtx) +func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) { + a.logger.WithField("database_user", databaseUser).Info("Generating Cloud SQL auth token") return cloudSQLAuthToken, nil } -// GetSpannerTokenSource returns an oauth token source for GCP Spanner. -func (a *testAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *common.Session) (oauth2.TokenSource, error) { +func (a *testAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) { return &fakeTokenSource{ token: cloudSpannerAuthToken, - FieldLogger: a.WithField("session", sessionCtx), + FieldLogger: a.logger.WithField("database_user", databaseUser), }, nil } -// GetCloudSQLPassword generates Cloud SQL user password. -func (a *testAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Cloud SQL user password %v.", sessionCtx) +func (a *testAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) { + a.logger. + WithField("database", database). + WithField("database_user", databaseUser). + Info("Generating Cloud SQL password") return cloudSQLPassword, nil } -// GetAzureAccessToken generates Azure access token. -func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Azure access token for %v.", sessionCtx) +func (a *testAuth) GetAzureAccessToken(ctx context.Context) (string, error) { + a.logger.Info("Generating Azure access token") return azureAccessToken, nil } -// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. -func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Azure Redis token for %v.", sessionCtx) +func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) { + a.logger. + WithField("database", database). + Info("Generating Azure Redis token") return azureRedisToken, nil } -// GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret -// access key and session token. -func (a *testAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *common.Session) (string, string, string, error) { - a.Infof("Generating AWS IAM credentials for %v.", sessionCtx) +func (a *testAuth) GetTLSConfig(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { + return a.realAuth.GetTLSConfig(ctx, expiry, database, databaseUser) +} + +func (a *testAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return a.realAuth.GetAuthPreference(ctx) +} + +func (a *testAuth) GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) { + return a.realAuth.GetAzureIdentityResourceID(ctx, identityName) +} + +func (a *testAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) { + a.logger. + WithField("database", database). + WithField("database_user", databaseUser). + Info("Generating AWS IAM credentials") return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil } +func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth { + return &testAuth{ + realAuth: a.realAuth, + logger: getUpdatedLogger(a.logger), + } +} + func TestMongoDBAtlas(t *testing.T) { t.Parallel() diff --git a/lib/srv/db/cassandra/engine.go b/lib/srv/db/cassandra/engine.go index 7553a0771f8c2..2c83a820b0420 100644 --- a/lib/srv/db/cassandra/engine.go +++ b/lib/srv/db/cassandra/engine.go @@ -284,7 +284,7 @@ func (e *Engine) authorizeConnection(ctx context.Context) error { } func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*protocol.Conn, error) { - config, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + config, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/clickhouse/engine_http.go b/lib/srv/db/clickhouse/engine_http.go index f6d7c3030eceb..bfffb20714740 100644 --- a/lib/srv/db/clickhouse/engine_http.go +++ b/lib/srv/db/clickhouse/engine_http.go @@ -210,8 +210,7 @@ func (e *Engine) getTransport(ctx context.Context) (*http.Transport, error) { if err != nil { return nil, trace.Wrap(err) } - - tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/clickhouse/engine_native.go b/lib/srv/db/clickhouse/engine_native.go index 679def5db86e6..3e1d5779b3b95 100644 --- a/lib/srv/db/clickhouse/engine_native.go +++ b/lib/srv/db/clickhouse/engine_native.go @@ -35,7 +35,7 @@ func (e *Engine) handleNativeConnection(ctx context.Context, sessionCtx *common. return trace.Wrap(err) } - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 1ec6d2fe1d284..dc44e8b32cf20 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -24,7 +24,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" - "io" "net/http" "net/url" "strings" @@ -72,27 +71,27 @@ const azureVirtualMachineCacheTTL = 5 * time.Minute // Auth defines interface for creating auth tokens and TLS configurations. type Auth interface { // GetRDSAuthToken generates RDS/Aurora auth token. - GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (string, error) + GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) // GetRedshiftAuthToken generates Redshift auth token. - GetRedshiftAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) + GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) // GetRedshiftServerlessAuthToken generates Redshift Serverless auth token. - GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) + GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) // GetElastiCacheRedisToken generates an ElastiCache Redis auth token. - GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Session) (string, error) + GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string) (string, error) // GetMemoryDBToken generates a MemoryDB auth token. - GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (string, error) + GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string) (string, error) // GetCloudSQLAuthToken generates Cloud SQL auth token. - GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) (string, error) + GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) // GetSpannerTokenSource returns an oauth token source for GCP Spanner. - GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) (oauth2.TokenSource, error) + GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) // GetCloudSQLPassword generates password for a Cloud SQL database user. - GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error) + GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) // GetAzureAccessToken generates Azure database access token. - GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error) + GetAzureAccessToken(ctx context.Context) (string, error) // GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. - GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) + GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) // GetTLSConfig builds the client TLS configuration for the session. - GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) + GetTLSConfig(ctx context.Context, certExpiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) // GetAuthPreference returns the cluster authentication config. GetAuthPreference(ctx context.Context) (types.AuthPreference, error) // GetAzureIdentityResourceID returns the Azure identity resource ID @@ -101,9 +100,10 @@ type Auth interface { GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) // GetAWSIAMCreds returns the AWS IAM credentials, including access key, // secret access key and session token. - GetAWSIAMCreds(ctx context.Context, sessionCtx *Session) (string, string, string, error) - // Closer releases all resources used by authenticator. - io.Closer + GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) + // WithLogger returns a new instance of Auth with updated logger. + // The callback function receives the current logger and returns a new one. + WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) Auth } // AuthClient is an interface that defines a subset of libauth.Client's @@ -134,11 +134,7 @@ func (c *AuthConfig) CheckAndSetDefaults() error { return trace.BadParameter("missing AuthClient") } if c.Clients == nil { - cloudClients, err := cloud.NewClients() - if err != nil { - return trace.Wrap(err) - } - c.Clients = cloudClients + return trace.BadParameter("missing Clients") } if c.Clock == nil { c.Clock = clockwork.NewRealClock() @@ -149,6 +145,15 @@ func (c *AuthConfig) CheckAndSetDefaults() error { return nil } +func (c *AuthConfig) withLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) AuthConfig { + return AuthConfig{ + AuthClient: c.AuthClient, + Clients: c.Clients, + Clock: c.Clock, + Log: getUpdatedLogger(c.Log), + } +} + // dbAuth provides utilities for creating TLS configurations and // generating auth tokens when connecting to databases. type dbAuth struct { @@ -179,10 +184,29 @@ func NewAuth(config AuthConfig) (Auth, error) { }, nil } +// NewAuthForSession returns a copy of Auth with session-specific logging. +func NewAuthForSession(auth Auth, sessionCtx *Session) Auth { + return auth.WithLogger(func(logger logrus.FieldLogger) logrus.FieldLogger { + return logger.WithFields(logrus.Fields{ + "session_id": sessionCtx.ID, + "database": sessionCtx.Database.GetName(), + }) + }) +} + +// WithLogger returns a new instance of Auth with updated logger. +// The callback function receives the current logger and returns a new one. +func (a *dbAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) Auth { + return &dbAuth{ + cfg: a.cfg.withLogger(getUpdatedLogger), + azureVirtualMachineCache: a.azureVirtualMachineCache, + } +} + // GetRDSAuthToken returns authorization token that will be used as a password // when connecting to RDS and Aurora databases. -func (a *dbAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { + meta := database.GetAWS() awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -190,14 +214,17 @@ func (a *dbAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (stri if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating RDS auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + Debug("Generating RDS auth token") token, err := rdsutils.BuildAuthToken( - sessionCtx.Database.GetURI(), + database.GetURI(), meta.Region, - sessionCtx.DatabaseUser, + databaseUser, awsSession.Config.Credentials) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -216,17 +243,17 @@ permissions (note that IAM changes may take a few minutes to propagate): // GetRedshiftAuthToken returns authorization token that will be used as a // password when connecting to Redshift databases. -func (a *dbAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { - if awsutils.IsRoleARN(sessionCtx.DatabaseUser) { - return a.getRedshiftIAMRoleAuthToken(ctx, sessionCtx) +func (a *dbAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + if awsutils.IsRoleARN(databaseUser) { + return a.getRedshiftIAMRoleAuthToken(ctx, database, databaseUser, databaseName) } - return a.getRedshiftDBUserAuthToken(ctx, sessionCtx) + return a.getRedshiftDBUserAuthToken(ctx, database, databaseUser, databaseName) } -func (a *dbAuth) getRedshiftIAMRoleAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { - meta := sessionCtx.Database.GetAWS() - roleARN, err := a.buildAWSRoleARNFromDatabaseUser(ctx, sessionCtx) +func (a *dbAuth) getRedshiftIAMRoleAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + meta := database.GetAWS() + roleARN, err := a.buildAWSRoleARNFromDatabaseUser(ctx, database, databaseUser) if err != nil { return "", "", trace.Wrap(err) } @@ -254,13 +281,17 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent } // Now make the API call to generate the temporary credentials. - a.cfg.Log.Debugf("Generating Redshift IAM role auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Redshift IAM role auth token") resp, err := client.GetClusterCredentialsWithIAMWithContext(ctx, &redshift.GetClusterCredentialsWithIAMInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), - DbName: aws.String(sessionCtx.DatabaseName), + DbName: aws.String(databaseName), }) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -276,8 +307,8 @@ Make sure that IAM role %q has permissions to generate credentials. Here is a sa return aws.StringValue(resp.DbUser), aws.StringValue(resp.DbPassword), nil } -func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + meta := database.GetAWS() redshiftClient, err := a.cfg.Clients.GetAWSRedshiftClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -285,11 +316,15 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, sessionCtx *Ses if err != nil { return "", "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating Redshift auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Redshift auth token") resp, err := redshiftClient.GetClusterCredentialsWithContext(ctx, &redshift.GetClusterCredentialsInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), - DbUser: aws.String(sessionCtx.DatabaseUser), - DbName: aws.String(sessionCtx.DatabaseName), + DbUser: aws.String(databaseUser), + DbName: aws.String(databaseName), // TODO(r0mant): Do not auto-create database account if DbUser doesn't // exist for now, but it may be potentially useful in future. AutoCreate: aws.Bool(false), @@ -298,7 +333,7 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, sessionCtx *Ses DbGroups: []*string{}, }) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -317,13 +352,13 @@ propagate): } // GetRedshiftServerlessAuthToken generates Redshift Serverless auth token. -func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { +func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { // Redshift Serverless maps caller IAM users/roles to database users. For // example, an IAM role "arn:aws:iam::1234567890:role/my-role-name" will be // mapped to a Postgres user "IAMR:my-role-name" inside the database. So we // first need to assume this IAM role before getting auth token. - meta := sessionCtx.Database.GetAWS() - roleARN, err := redshiftServerlessUsernameToRoleARN(meta, sessionCtx.DatabaseUser) + meta := database.GetAWS() + roleARN, err := redshiftServerlessUsernameToRoleARN(meta, databaseUser) if err != nil { return "", "", trace.Wrap(err) } @@ -350,13 +385,17 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent } // Now make the API call to generate the temporary credentials. - a.cfg.Log.Debugf("Generating Redshift Serverless auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Redshift Serverless auth token") resp, err := client.GetCredentialsWithContext(ctx, &redshiftserverless.GetCredentialsInput{ WorkgroupName: aws.String(meta.RedshiftServerless.WorkgroupName), - DbName: aws.String(sessionCtx.DatabaseName), + DbName: aws.String(databaseName), }) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -374,12 +413,12 @@ Make sure that IAM role %q has permissions to generate credentials. Here is a sa // GetCloudSQLAuthToken returns authorization token that will be used as a // password when connecting to Cloud SQL databases. -func (a *dbAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) (string, error) { +func (a *dbAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) { // https://developers.google.com/identity/protocols/oauth2/scopes#sqladmin scopes := []string{ "https://www.googleapis.com/auth/sqlservice.admin", } - ts, err := a.getCloudTokenSource(ctx, sessionCtx, scopes) + ts, err := a.getCloudTokenSource(ctx, databaseUser, scopes) if err != nil { return "", trace.Wrap(err) } @@ -391,12 +430,12 @@ func (a *dbAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) } // GetSpannerTokenSource returns an oauth token source for GCP Spanner. -func (a *dbAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) (oauth2.TokenSource, error) { +func (a *dbAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) { // https://developers.google.com/identity/protocols/oauth2/scopes#spanner scopes := []string{ "https://www.googleapis.com/auth/spanner.data", } - ts, err := a.getCloudTokenSource(ctx, sessionCtx, scopes) + ts, err := a.getCloudTokenSource(ctx, databaseUser, scopes) if err != nil { return nil, trace.Wrap(err) } @@ -404,19 +443,19 @@ func (a *dbAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) return oauth2.ReuseTokenSource(nil, ts), nil } -func (a *dbAuth) getCloudTokenSource(ctx context.Context, sessionCtx *Session, scopes []string) (*cloudTokenSource, error) { +func (a *dbAuth) getCloudTokenSource(ctx context.Context, databaseUser string, scopes []string) (*cloudTokenSource, error) { gcpIAM, err := a.cfg.Clients.GetGCPIAMClient(ctx) if err != nil { return nil, trace.Wrap(err) } - serviceAccountName := sessionCtx.DatabaseUser + serviceAccountName := databaseUser if !strings.HasSuffix(serviceAccountName, ".gserviceaccount.com") { serviceAccountName = serviceAccountName + ".gserviceaccount.com" } return &cloudTokenSource{ ctx: ctx, client: gcpIAM, - log: a.cfg.Log.WithField("session", sessionCtx.String()), + log: a.cfg.Log.WithField("database_user", databaseUser), serviceAccount: serviceAccountName, scopes: scopes, }, nil @@ -472,12 +511,15 @@ or "iam.serviceAccounts.getAccessToken" IAM permission. // // It is used to generate a one-time password when connecting to GCP MySQL // databases which don't support IAM authentication. -func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error) { +func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) { gcpCloudSQL, err := a.cfg.Clients.GetGCPSQLAdminClient(ctx) if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating GCP user password for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + Debug("Generating GCP user password") token, err := utils.CryptoRandomHex(defaults.TokenLenBytes) if err != nil { return "", trace.Wrap(err) @@ -492,7 +534,7 @@ func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) ( retryCtx, cancel := context.WithTimeout(ctx, defaults.DatabaseConnectTimeout) defer cancel() err = retry.For(retryCtx, func() error { - err := a.updateCloudSQLUser(ctx, sessionCtx, gcpCloudSQL, &sqladmin.User{ + err := a.updateCloudSQLUser(ctx, database, databaseUser, gcpCloudSQL, &sqladmin.User{ Password: token, }) if err != nil && !trace.IsCompareFailed(ConvertError(err)) { // We only want to retry on 409. @@ -507,12 +549,12 @@ func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) ( } // updateCloudSQLUser makes a request to Cloud SQL API to update the provided user. -func (a *dbAuth) updateCloudSQLUser(ctx context.Context, sessionCtx *Session, gcpCloudSQL gcp.SQLAdminClient, user *sqladmin.User) error { - err := gcpCloudSQL.UpdateUser(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, user) +func (a *dbAuth) updateCloudSQLUser(ctx context.Context, database types.Database, databaseUser string, gcpCloudSQL gcp.SQLAdminClient, user *sqladmin.User) error { + err := gcpCloudSQL.UpdateUser(ctx, database, databaseUser, user) if err != nil { // Note that mysql client has a 1024 char limit for displaying errors // so we need to keep the message short when possible. This message - // does get cut off when sessionCtx.DatabaseUser or err is long. + // does get cut off when databaseUser or err is long. return trace.AccessDenied(`Could not update Cloud SQL user %q password: %v @@ -523,14 +565,14 @@ ID as "--db-user", or grant the Teleport Database Service the If the db user uses passwords, make sure Teleport Database Service has "Cloud SQL Admin" GCP IAM role, or "cloudsql.users.update" IAM permission. -`, sessionCtx.DatabaseUser, err) +`, databaseUser, err) } return nil } // GetAzureAccessToken generates Azure database access token. -func (a *dbAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error) { - a.cfg.Log.Debugf("Generating Azure access token for %s.", sessionCtx) +func (a *dbAuth) GetAzureAccessToken(ctx context.Context) (string, error) { + a.cfg.Log.Debug("Generating Azure access token") cred, err := a.cfg.Clients.GetAzureCredential() if err != nil { return "", trace.Wrap(err) @@ -548,8 +590,8 @@ func (a *dbAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) ( } // GetElastiCacheRedisToken generates an ElastiCache Redis auth token. -func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Session) (string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { + meta := database.GetAWS() awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -557,11 +599,14 @@ func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Sessi if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating ElastiCache Redis auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + Debug("Generating ElastiCache Redis auth token") tokenReq := &awsRedisIAMTokenRequest{ // For IAM-enabled ElastiCache users, the username and user id properties must be identical. // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits - userID: sessionCtx.DatabaseUser, + userID: databaseUser, targetID: meta.ElastiCache.ReplicationGroupID, serviceName: elasticache.ServiceName, region: meta.Region, @@ -573,8 +618,8 @@ func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Sessi } // GetMemoryDBToken generates a MemoryDB auth token. -func (a *dbAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { + meta := database.GetAWS() awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -582,9 +627,12 @@ func (a *dbAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (str if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating MemoryDB auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + Debug("Generating MemoryDB auth token") tokenReq := &awsRedisIAMTokenRequest{ - userID: sessionCtx.DatabaseUser, + userID: databaseUser, targetID: meta.MemoryDB.ClusterName, serviceName: strings.ToLower(memorydb.ServiceName), region: meta.Region, @@ -596,8 +644,8 @@ func (a *dbAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (str } // GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. -func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) { - resourceID, err := arm.ParseResourceID(sessionCtx.Database.GetAzure().ResourceID) +func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) { + resourceID, err := arm.ParseResourceID(database.GetAzure().ResourceID) if err != nil { return "", trace.Wrap(err) } @@ -617,24 +665,27 @@ func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Ses default: return "", trace.BadParameter("unknown Azure Cache for Redis resource type: %v", resourceID.ResourceType) } - token, err := client.GetToken(ctx, sessionCtx.Database.GetAzure().ResourceID) + token, err := client.GetToken(ctx, database.GetAzure().ResourceID) if err != nil { // Some Azure error messages are long, multi-lined, and may even // contain divider lines like "------". It's unreadable in redis-cli as // the message has to be merged to a single line string. Thus logging // the original error as debug and returning a more user friendly // message. - a.cfg.Log.WithError(err).Debugf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName()) + a.cfg.Log. + WithField("database", database). + WithError(err). + Debug("Failed to get token for Azure Redis") switch { case trace.IsAccessDenied(err): - return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure the database agent has the \"listKeys\" permission to the database.", sessionCtx.Database.GetName()) + return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure the database agent has the \"listKeys\" permission to the database.", database.GetName()) case trace.IsNotFound(err): // Note that Azure Cache for Redis should always have both keys // generated at all time. Here just checking in case something // wrong with the API. - return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure either the primary key or the secondary key is generated.", sessionCtx.Database.GetName()) + return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure either the primary key or the secondary key is generated.", database.GetName()) default: - return "", trace.Errorf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName()) + return "", trace.Errorf("Failed to get token for Azure Redis %q.", database.GetName()) } } return token, nil @@ -645,30 +696,30 @@ func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Ses // For RDS/Aurora, the config must contain RDS root certificate as a trusted // authority. For on-prem we generate a client certificate signed by the host // CA used to authenticate. -func (a *dbAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - dbTLSConfig := sessionCtx.Database.GetTLS() +func (a *dbAuth) GetTLSConfig(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { + dbTLSConfig := database.GetTLS() // Mode won't be set for older clients. We will default to VerifyFull then - the same as before. switch dbTLSConfig.Mode { case types.DatabaseTLSMode_INSECURE: - return a.getTLSConfigInsecure(ctx, sessionCtx) + return a.getTLSConfigInsecure(ctx, expiry, database, databaseUser) case types.DatabaseTLSMode_VERIFY_CA: - return a.getTLSConfigVerifyCA(ctx, sessionCtx) + return a.getTLSConfigVerifyCA(ctx, expiry, database, databaseUser) default: - return a.getTLSConfigVerifyFull(ctx, sessionCtx) + return a.getTLSConfigVerifyFull(ctx, expiry, database, databaseUser) } } // getTLSConfigVerifyFull returns tls.Config with full verification enabled ('verify-full' mode). // Config also includes database specific adjustment. -func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { +func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { tlsConfig := &tls.Config{} // Add CA certificate to the trusted pool if it's present, e.g. when // connecting to RDS/Aurora which require AWS CA or when was provided in config file. // // Some databases may also require the system cert pool, e.g Azure Redis. - if err := setupTLSConfigRootCAs(tlsConfig, sessionCtx); err != nil { + if err := setupTLSConfigRootCAs(tlsConfig, database); err != nil { return nil, trace.Wrap(err) } @@ -690,11 +741,11 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session // // See the following Go issue for more context: // https://github.com/golang/go/issues/40748 - if sessionCtx.Database.IsCloudSQL() { + if database.IsCloudSQL() { // Cloud SQL server presented certificates encode instance names as // ":" in CommonName. This is verified against // the ServerName in a custom connection verification step (see below). - tlsConfig.ServerName = sessionCtx.Database.GetGCP().GetServerName() + tlsConfig.ServerName = database.GetGCP().GetServerName() // This just disables default verification. tlsConfig.InsecureSkipVerify = true // This will verify CN and cert chain on each connection. @@ -702,32 +753,32 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session } // Setup server name for verification. - if err := setupTLSConfigServerName(tlsConfig, sessionCtx); err != nil { + if err := setupTLSConfigServerName(tlsConfig, database); err != nil { return nil, trace.Wrap(err) } // RDS/Aurora/Redshift/ElastiCache and Cloud SQL auth is done with an auth // token so don't generate a client certificate and exit here. - if sessionCtx.Database.IsCloudHosted() { + if database.IsCloudHosted() { return tlsConfig, nil } // MongoDB Atlas doesn't not require client certificates if is using AWS // authentication. - if awsutils.IsRoleARN(sessionCtx.DatabaseUser) && sessionCtx.Database.GetType() == types.DatabaseTypeMongoAtlas { + if awsutils.IsRoleARN(databaseUser) && database.GetType() == types.DatabaseTypeMongoAtlas { return tlsConfig, nil } // Otherwise, when connecting to an onprem database, generate a client // certificate. The database instance should be configured with // Teleport's CA obtained with 'tctl auth sign --type=db'. - return a.appendClientCert(ctx, sessionCtx, tlsConfig) + return a.appendClientCert(ctx, expiry, databaseUser, tlsConfig) } // getTLSConfigInsecure generates tls.Config when TLS mode is equal to 'insecure'. // Generated configuration will accept any certificate provided by database. -func (a *dbAuth) getTLSConfigInsecure(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - tlsConfig, err := a.getTLSConfigVerifyFull(ctx, sessionCtx) +func (a *dbAuth) getTLSConfigInsecure(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { + tlsConfig, err := a.getTLSConfigVerifyFull(ctx, expiry, database, databaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -743,8 +794,8 @@ func (a *dbAuth) getTLSConfigInsecure(ctx context.Context, sessionCtx *Session) // getTLSConfigVerifyCA generates tls.Config when TLS mode is equal to 'verify-ca'. // Generated configuration is the same as 'verify-full' except the server name // verification is disabled. -func (a *dbAuth) getTLSConfigVerifyCA(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - tlsConfig, err := a.getTLSConfigVerifyFull(ctx, sessionCtx) +func (a *dbAuth) getTLSConfigVerifyCA(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { + tlsConfig, err := a.getTLSConfigVerifyFull(ctx, expiry, database, databaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -761,8 +812,8 @@ func (a *dbAuth) getTLSConfigVerifyCA(ctx context.Context, sessionCtx *Session) } // appendClientCert generates a client certificate and appends it to the provided tlsConfig. -func (a *dbAuth) appendClientCert(ctx context.Context, sessionCtx *Session, tlsConfig *tls.Config) (*tls.Config, error) { - cert, cas, err := a.getClientCert(ctx, sessionCtx) +func (a *dbAuth) appendClientCert(ctx context.Context, expiry time.Time, databaseUser string, tlsConfig *tls.Config) (*tls.Config, error) { + cert, cas, err := a.getClientCert(ctx, expiry, databaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -778,9 +829,9 @@ func (a *dbAuth) appendClientCert(ctx context.Context, sessionCtx *Session, tlsC // setupTLSConfigRootCAs initializes the root CA cert pool for the provided // tlsConfig based on session context. -func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error { +func setupTLSConfigRootCAs(tlsConfig *tls.Config, database types.Database) error { // Start with an empty pool or a system cert pool. - if shouldUseSystemCertPool(sessionCtx) { + if shouldUseSystemCertPool(database) { systemCertPool, err := x509.SystemCertPool() if err != nil { return trace.Wrap(err) @@ -791,8 +842,8 @@ func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error { } // If CAs are provided by the database object, add them to the pool. - if len(sessionCtx.Database.GetCA()) != 0 { - if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(sessionCtx.Database.GetCA())) { + if len(database.GetCA()) != 0 { + if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(database.GetCA())) { return trace.BadParameter("invalid server CA certificate") } return nil @@ -805,12 +856,11 @@ func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error { // shouldUseSystemCertPool returns true for database servers presenting // certificates signed by publicly trusted CAs so a system cert pool can be // used. -func shouldUseSystemCertPool(sessionCtx *Session) bool { - if sessionCtx.Database.GetTLS().TrustSystemCertPool { +func shouldUseSystemCertPool(database types.Database) bool { + if database.GetTLS().TrustSystemCertPool { return true } - - switch sessionCtx.Database.GetType() { + switch database.GetType() { // Azure databases either use Baltimore Root CA or DigiCert Global Root G2. // // https://docs.microsoft.com/en-us/azure/postgresql/concepts-ssl-connection-security @@ -836,9 +886,9 @@ func shouldUseSystemCertPool(sessionCtx *Session) bool { // setupTLSConfigServerName initializes the server name for the provided // tlsConfig based on session context. -func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error { +func setupTLSConfigServerName(tlsConfig *tls.Config, database types.Database) error { // Use user provided server name if set. Override the current value if needed. - if dbTLSConfig := sessionCtx.Database.GetTLS(); dbTLSConfig.ServerName != "" { + if dbTLSConfig := database.GetTLS(); dbTLSConfig.ServerName != "" { tlsConfig.ServerName = dbTLSConfig.ServerName return nil } @@ -848,7 +898,7 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error return nil } - switch sessionCtx.Database.GetProtocol() { + switch database.GetProtocol() { case defaults.ProtocolMongoDB: // Don't set the ServerName when connecting to a MongoDB cluster - in case // of replica set the driver may dial multiple servers and will set @@ -863,8 +913,8 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error // hostnames. However, OSS cluster mode may redirect to an IP address, // and without correct ServerName the handshake will fail as the IPs // are not in SANs. - if sessionCtx.Database.IsAzure() { - serverName, err := azureutils.GetHostFromRedisURI(sessionCtx.Database.GetURI()) + if database.IsAzure() { + serverName, err := azureutils.GetHostFromRedisURI(database.GetURI()) if err != nil { return trace.Wrap(err) } @@ -876,7 +926,7 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error // Redis is using custom URI schema. return nil case defaults.ProtocolClickHouse, defaults.ProtocolClickHouseHTTP: - u, err := url.Parse(sessionCtx.Database.GetURI()) + u, err := url.Parse(database.GetURI()) if err != nil { return trace.Wrap(err) } @@ -889,7 +939,7 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error default: // For other databases we're always connecting to the server specified // in URI so set ServerName ourselves. - addr, err := utils.ParseAddr(sessionCtx.Database.GetURI()) + addr, err := utils.ParseAddr(database.GetURI()) if err != nil { return trace.Wrap(err) } @@ -926,24 +976,27 @@ func verifyConnectionFunc(rootCAs *x509.CertPool) func(cs tls.ConnectionState) e // getClientCert signs an ephemeral client certificate used by this // server to authenticate with the database instance. -func (a *dbAuth) getClientCert(ctx context.Context, sessionCtx *Session) (cert *tls.Certificate, cas [][]byte, err error) { +func (a *dbAuth) getClientCert(ctx context.Context, expiry time.Time, databaseUser string) (cert *tls.Certificate, cas [][]byte, err error) { privateKey, err := native.GeneratePrivateKey() if err != nil { return nil, nil, trace.Wrap(err) } // Postgres requires the database username to be encoded as a common // name in the client certificate. - subject := pkix.Name{CommonName: sessionCtx.DatabaseUser} + subject := pkix.Name{CommonName: databaseUser} csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey) if err != nil { return nil, nil, trace.Wrap(err) } // TODO(r0mant): Cache database certificates to avoid expensive generate // operation on each connection. - a.cfg.Log.Debugf("Generating client certificate for %s.", sessionCtx) + a.cfg.Log. + WithField("database_user", databaseUser). + Debug("Generating client certificate") + resp, err := a.cfg.AuthClient.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ CSR: csr, - TTL: proto.Duration(sessionCtx.Identity.Expires.Sub(a.cfg.Clock.Now())), + TTL: proto.Duration(expiry.Sub(a.cfg.Clock.Now())), }) if err != nil { return nil, nil, trace.Wrap(err) @@ -1016,14 +1069,14 @@ func (a *dbAuth) getCurrentAzureVM(ctx context.Context) (*libazure.VirtualMachin return vm, nil } -func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, sessionCtx *Session) (string, error) { - dbAWS := sessionCtx.Database.GetAWS() +func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, database types.Database, databaseUser string) (string, error) { + dbAWS := database.GetAWS() awsAccountID := dbAWS.AccountID - if awsutils.IsPartialRoleARN(sessionCtx.DatabaseUser) && awsAccountID == "" { + if awsutils.IsPartialRoleARN(databaseUser) && awsAccountID == "" { switch { case dbAWS.AssumeRoleARN != "": - a.cfg.Log.Debugf("Using AWS Account ID from assumed role") + a.cfg.Log.Debug("Using AWS Account ID from assumed role") assumeRoleARN, err := awsutils.ParseRoleARN(dbAWS.AssumeRoleARN) if err != nil { return "", trace.Wrap(err) @@ -1031,7 +1084,7 @@ func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, sessionCtx awsAccountID = assumeRoleARN.AccountID default: - a.cfg.Log.Debugf("Fetching AWS Account ID to build role ARN") + a.cfg.Log.Debug("Fetching AWS Account ID to build role ARN") stsClient, err := a.cfg.Clients.GetAWSSTSClient(ctx, dbAWS.Region, cloud.WithAmbientCredentials()) if err != nil { return "", trace.Wrap(err) @@ -1046,15 +1099,15 @@ func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, sessionCtx } } - arn, err := awsutils.BuildRoleARN(sessionCtx.DatabaseUser, dbAWS.Region, awsAccountID) + arn, err := awsutils.BuildRoleARN(databaseUser, dbAWS.Region, awsAccountID) return arn, trace.Wrap(err) } // GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret // access key and session token. -func (a *dbAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *Session) (string, string, string, error) { - dbAWS := sessionCtx.Database.GetAWS() - arn, err := a.buildAWSRoleARNFromDatabaseUser(ctx, sessionCtx) +func (a *dbAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) { + dbAWS := database.GetAWS() + arn, err := a.buildAWSRoleARNFromDatabaseUser(ctx, database, databaseUser) if err != nil { return "", "", "", trace.Wrap(err) } @@ -1238,7 +1291,7 @@ type reportingAuth struct { } // newReportingAuth returns a reporting version of Auth, wrapping the original Auth instance. -func newReportingAuth(db types.Database, auth Auth) *reportingAuth { +func newReportingAuth(db types.Database, auth Auth) Auth { return &reportingAuth{ Auth: auth, component: "db:auth", @@ -1246,9 +1299,7 @@ func newReportingAuth(db types.Database, auth Auth) *reportingAuth { } } -func (r *reportingAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { +func (r *reportingAuth) GetTLSConfig(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { defer methodCallMetrics("GetTLSConfig", r.component, r.db)() - return r.Auth.GetTLSConfig(ctx, sessionCtx) + return r.Auth.GetTLSConfig(ctx, expiry, database, databaseUser) } - -var _ Auth = (*reportingAuth)(nil) diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index 4fe788a68b0dd..63d6247e4d646 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -90,9 +90,7 @@ func TestAuthGetAzureCacheForRedisToken(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - token, err := auth.GetAzureCacheForRedisToken(context.TODO(), &Session{ - Database: newAzureRedisDatabase(t, test.resourceID), - }) + token, err := auth.GetAzureCacheForRedisToken(context.TODO(), newAzureRedisDatabase(t, test.resourceID)) if test.expectError { require.Error(t, err) } else { @@ -121,11 +119,11 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) { }) require.NoError(t, err) - dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(context.TODO(), &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRedshiftServerlessDatabase(t), - }) + dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(context.TODO(), + newRedshiftServerlessDatabase(t), + "some-user", + "some-database", + ) require.NoError(t, err) require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) @@ -233,11 +231,10 @@ func TestAuthGetTLSConfig(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - tlsConfig, err := auth.GetTLSConfig(context.TODO(), &Session{ - Identity: tlsca.Identity{}, - DatabaseUser: "default", - Database: test.sessionDatabase, - }) + tlsConfig, err := auth.GetTLSConfig(context.TODO(), + time.Now().Add(time.Hour), + test.sessionDatabase, + "defaultUser") require.NoError(t, err) require.Equal(t, test.expectServerName, tlsConfig.ServerName) @@ -463,24 +460,22 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) tests := map[string]struct { - sessionCtx *Session - checkGetAuthFn func(t *testing.T, auth Auth, sessionCtx *Session) + checkGetAuthFn func(t *testing.T, auth Auth) checkSTS func(t *testing.T, stsMock *mocks.STSMock) }{ "Redshift": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRedshiftDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + databaseName := "some-database" + database := newRedshiftDatabase(t, withCA(fixtures.SAMLOktaCertPEM), withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedshiftRole", ExternalID: "externalRedshift", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, sessionCtx) + })) + + dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) @@ -492,19 +487,18 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "Redshift with IAM role": { - sessionCtx: &Session{ - DatabaseUser: "role/some-role", - DatabaseName: "some-database", - Database: newRedshiftDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "role/some-role" + databaseName := "some-database" + database := newRedshiftDatabase(t, withCA(fixtures.SAMLOktaCertPEM), withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedshiftRole", ExternalID: "externalRedshift", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, sessionCtx) + })) + + dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) require.Equal(t, "IAM:some-role", dbUser) require.Equal(t, "some-password-for-some-role", dbPassword) @@ -516,18 +510,17 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "Redshift Serverless": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRedshiftServerlessDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + databaseName := "some-database" + database := newRedshiftServerlessDatabase(t, withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedshiftServerlessRole", ExternalID: "externalRedshiftServerless", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(ctx, sessionCtx) + })) + + dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) @@ -540,18 +533,15 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "RDS Proxy": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRDSProxyDatabase(t, "my-proxy.proxy-abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432", + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + database := newRDSProxyDatabase(t, "my-proxy.proxy-abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432", withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RDSProxyRole", ExternalID: "externalRDSProxy", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - token, err := auth.GetRDSAuthToken(ctx, sessionCtx) + })) + token, err := auth.GetRDSAuthToken(ctx, database, databaseUser) require.NoError(t, err) require.Contains(t, token, "DBUser=some-user") }, @@ -562,18 +552,15 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "ElastiCache Redis": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newElastiCacheRedisDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + database := newElastiCacheRedisDatabase(t, withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedisRole", ExternalID: "externalElastiCacheRedis", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - token, err := auth.GetElastiCacheRedisToken(ctx, sessionCtx) + })) + token, err := auth.GetElastiCacheRedisToken(ctx, database, databaseUser) require.NoError(t, err) u, err := url.Parse(token) require.NoError(t, err) @@ -617,7 +604,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { tt := tt t.Run(name, func(t *testing.T) { t.Parallel() - tt.checkGetAuthFn(t, auth, tt.sessionCtx) + tt.checkGetAuthFn(t, auth) tt.checkSTS(t, stsMock) }) } @@ -694,10 +681,7 @@ func TestGetAWSIAMCreds(t *testing.T) { }) require.NoError(t, err) - keyId, _, _, err := auth.GetAWSIAMCreds(ctx, &Session{ - Database: tt.db, - DatabaseUser: tt.username, - }) + keyId, _, _, err := auth.GetAWSIAMCreds(ctx, tt.db, tt.username) tt.expectErr(t, err) require.Equal(t, tt.expectedKeyId, keyId) require.ElementsMatch(t, tt.expectedAssumedRoles, tt.stsMock.GetAssumedRoleARNs()) diff --git a/lib/srv/db/common/session.go b/lib/srv/db/common/session.go index b70b0f4c9b135..fcc7b846c92e0 100644 --- a/lib/srv/db/common/session.go +++ b/lib/srv/db/common/session.go @@ -21,6 +21,7 @@ package common import ( "fmt" "strings" + "time" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -127,3 +128,8 @@ func (c *Session) CheckUsernameForAutoUserProvisioning() error { return trace.AccessDenied("please use your Teleport username (%q) to connect instead of %q", c.Identity.Username, c.DatabaseUser) } + +// GetExpiry returns the expiry time of current session. +func (c *Session) GetExpiry() time.Time { + return c.Identity.Expires +} diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 3c2e3f4fcbf48..48ef1c4a81371 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -323,7 +323,7 @@ func (e *Engine) getRoundTripper(ctx context.Context, URL string) (http.RoundTri if rt, ok := e.RoundTrippers[URL]; ok { return rt, nil } - tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/elasticsearch/engine.go b/lib/srv/db/elasticsearch/engine.go index 0d53f7150ae7c..d2a91833509c1 100644 --- a/lib/srv/db/elasticsearch/engine.go +++ b/lib/srv/db/elasticsearch/engine.go @@ -120,7 +120,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio return trace.BadParameter("database username required for Elasticsearch") } - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/mongodb/autousers_admin.go b/lib/srv/db/mongodb/autousers_admin.go index d6d47b713a0d8..c3ed4e4394aeb 100644 --- a/lib/srv/db/mongodb/autousers_admin.go +++ b/lib/srv/db/mongodb/autousers_admin.go @@ -190,7 +190,7 @@ type makeBasicAdminClientFunc func(context.Context, *common.Session, *Engine) (a func makeBasicAdminClient(ctx context.Context, sessionCtx *common.Session, e *Engine) (adminClient, error) { sessionCtx = sessionCtx.WithUser(sessionCtx.Database.GetAdminUser().Name) - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/mongodb/connect.go b/lib/srv/db/mongodb/connect.go index 063bd80ce7b33..a923d6a89149c 100644 --- a/lib/srv/db/mongodb/connect.go +++ b/lib/srv/db/mongodb/connect.go @@ -132,7 +132,7 @@ func (e *Engine) getServerOptions(ctx context.Context, sessionCtx *common.Sessio // getConnectionOptions constructs connection options for connecting to a MongoDB server. func (e *Engine) getConnectionOptions(ctx context.Context, sessionCtx *common.Session, clientCfg *options.ClientOptions) ([]topology.ConnectionOption, error) { - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -192,7 +192,7 @@ func (e *Engine) getAuthenticator(ctx context.Context, sessionCtx *common.Sessio func (e *Engine) getAWSAuthenticator(ctx context.Context, sessionCtx *common.Session) (auth.Authenticator, error) { e.Log.Debug("Authenticating to database using AWS IAM authentication.") - username, password, sessToken, err := e.Auth.GetAWSIAMCreds(ctx, sessionCtx) + username, password, sessToken, err := e.Auth.GetAWSIAMCreds(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index 69721192ab4f8..acb01f153a78b 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -218,7 +218,7 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er // connect establishes connection to MySQL database. func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*client.Conn, error) { - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -231,7 +231,7 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*clie var password string switch { case sessionCtx.Database.IsRDS(), sessionCtx.Database.IsRDSProxy(): - password, err = e.Auth.GetRDSAuthToken(ctx, sessionCtx) + password, err = e.Auth.GetRDSAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -265,7 +265,7 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*clie dialer = newGCPTLSDialer(tlsConfig) } case sessionCtx.Database.IsAzure(): - password, err = e.Auth.GetAzureAccessToken(ctx, sessionCtx) + password, err = e.Auth.GetAzureAccessToken(ctx) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/mysql/gcp.go b/lib/srv/db/mysql/gcp.go index 4b424f062f00e..ca4723c19caf3 100644 --- a/lib/srv/db/mysql/gcp.go +++ b/lib/srv/db/mysql/gcp.go @@ -128,7 +128,7 @@ func (e *Engine) getGCPIAMAuthToken(ctx context.Context, sessionCtx *common.Sess e.Log.WithField("session", sessionCtx).Debug("Authenticating GCP MySQL with IAM auth.") // Note that sessionCtx.DatabaseUser is the service account. - password, err := e.Auth.GetCloudSQLAuthToken(ctx, sessionCtx) + password, err := e.Auth.GetCloudSQLAuthToken(ctx, sessionCtx.DatabaseUser) return password, trace.Wrap(err) } @@ -154,7 +154,7 @@ func (e *Engine) getGCPOneTimePassword(ctx context.Context, sessionCtx *common.S e.Log.WithError(err).Errorf("Failed to cancel lease: %v.", lease) } }() - password, err := e.Auth.GetCloudSQLPassword(ctx, sessionCtx) + password, err := e.Auth.GetCloudSQLPassword(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/srv/db/mysql/gcp_test.go b/lib/srv/db/mysql/gcp_test.go index c06e49dc9c9e9..bae16e2c877be 100644 --- a/lib/srv/db/mysql/gcp_test.go +++ b/lib/srv/db/mysql/gcp_test.go @@ -41,30 +41,35 @@ type fakeAuth struct { common.Auth } -func (a fakeAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - if !isDBUserFullGCPServerAccountID(sessionCtx.DatabaseUser) { +func (a fakeAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) { + if !isDBUserFullGCPServerAccountID(databaseUser) { return "", trace.BadParameter("database user must be a service account") } return "iam-auth-token", nil } -func (a fakeAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) { - if isDBUserFullGCPServerAccountID(sessionCtx.DatabaseUser) { +func (a fakeAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) { + if isDBUserFullGCPServerAccountID(databaseUser) { return "", trace.BadParameter("database user must not be a service account") } return "one-time-password", nil } -func Test_getGCPUserAndPassowrd(t *testing.T) { +func (a fakeAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth { + if a.Auth != nil { + return a.Auth.WithLogger(getUpdatedLogger) + } + return a +} + +func Test_getGCPUserAndPassword(t *testing.T) { ctx := context.Background() authClient := makeAuthClient(t) db := makeGCPMySQLDatabase(t) - dbAuth := &fakeAuth{} tests := []struct { name string inputDatabaseUser string - mockDBAuth common.Auth mockGCPClient gcp.SQLAdminClient wantDatabaseUser string wantPassword string @@ -73,20 +78,17 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "iam auth with full service account", inputDatabaseUser: "iam-auth-user@project-id.iam.gserviceaccount.com", - mockDBAuth: dbAuth, wantDatabaseUser: "iam-auth-user", wantPassword: "iam-auth-token", }, { name: "iam auth with short service account", inputDatabaseUser: "iam-auth-user@project-id.iam", - mockDBAuth: dbAuth, wantError: true, }, { name: "iam auth with CLOUD_IAM_SERVICE_ACCOUNT user", inputDatabaseUser: "iam-auth-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("iam-auth-user", "CLOUD_IAM_SERVICE_ACCOUNT"), }, @@ -96,7 +98,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "iam auth with CLOUD_IAM_GROUP_SERVICE_ACCOUNT user", inputDatabaseUser: "iam-auth-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("iam-auth-user", "CLOUD_IAM_GROUP_SERVICE_ACCOUNT"), }, @@ -106,7 +107,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "password auth without GetUser permission", inputDatabaseUser: "some-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ // Default no permission to GetUser, }, @@ -116,7 +116,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "password auth with BUILT_IN user", inputDatabaseUser: "password-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("password-user", "BUILT_IN"), }, @@ -126,7 +125,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "password auth with empty user type", inputDatabaseUser: "password-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("password-user", ""), }, @@ -136,7 +134,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "unsupported user type", inputDatabaseUser: "some-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("some-user", "CLOUD_IAM_USER"), }, @@ -147,19 +144,20 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { + sessionCtx := &common.Session{ + Database: db, + DatabaseUser: test.inputDatabaseUser, + ID: "00000000-0000AAAA-0000BBBB-0000CCCC", + } + engine := NewEngine(common.EngineConfig{ - Auth: test.mockDBAuth, + Auth: &fakeAuth{}, AuthClient: authClient, Context: ctx, Clock: clockwork.NewRealClock(), Log: logrus.StandardLogger(), }).(*Engine) - sessionCtx := &common.Session{ - Database: db, - DatabaseUser: test.inputDatabaseUser, - } - databaseUser, password, err := engine.getGCPUserAndPassword(ctx, sessionCtx, test.mockGCPClient) if test.wantError { require.Error(t, err) diff --git a/lib/srv/db/opensearch/engine.go b/lib/srv/db/opensearch/engine.go index e9174afebee2b..5bcf0a0d1837f 100644 --- a/lib/srv/db/opensearch/engine.go +++ b/lib/srv/db/opensearch/engine.go @@ -230,7 +230,7 @@ func (e *Engine) getTransport(ctx context.Context) (*http.Transport, error) { return nil, trace.Wrap(err) } - tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index 47c0c7afe6de2..21a424aa9c658 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -483,7 +483,7 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio } // TLS config will use client certificate for an onprem database or // will contain RDS root certificate for RDS/Aurora. - config.TLSConfig, err = e.Auth.GetTLSConfig(ctx, sessionCtx) + config.TLSConfig, err = e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -499,22 +499,22 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio // auth token and use it as a password. switch sessionCtx.Database.GetType() { case types.DatabaseTypeRDS, types.DatabaseTypeRDSProxy: - config.Password, err = e.Auth.GetRDSAuthToken(ctx, sessionCtx) + config.Password, err = e.Auth.GetRDSAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } case types.DatabaseTypeRedshift: - config.User, config.Password, err = e.Auth.GetRedshiftAuthToken(ctx, sessionCtx) + config.User, config.Password, err = e.Auth.GetRedshiftAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) if err != nil { return nil, trace.Wrap(err) } case types.DatabaseTypeRedshiftServerless: - config.User, config.Password, err = e.Auth.GetRedshiftServerlessAuthToken(ctx, sessionCtx) + config.User, config.Password, err = e.Auth.GetRedshiftServerlessAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) if err != nil { return nil, trace.Wrap(err) } case types.DatabaseTypeCloudSQL: - config.Password, err = e.Auth.GetCloudSQLAuthToken(ctx, sessionCtx) + config.Password, err = e.Auth.GetCloudSQLAuthToken(ctx, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -538,7 +538,7 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio } } case types.DatabaseTypeAzure: - config.Password, err = e.Auth.GetAzureAccessToken(ctx, sessionCtx) + config.Password, err = e.Auth.GetAzureAccessToken(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -553,7 +553,7 @@ func (e *Engine) handleCancelRequest(ctx context.Context, sessionCtx *common.Ses if err != nil { return trace.Wrap(err) } - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/redis/client.go b/lib/srv/db/redis/client.go index 5946746220893..cf0fd9a76e806 100644 --- a/lib/srv/db/redis/client.go +++ b/lib/srv/db/redis/client.go @@ -208,7 +208,7 @@ func managedUserCredFetchFunc(sessionCtx *common.Session, auth common.Auth, user func azureAccessKeyFetchFunc(sessionCtx *common.Session, auth common.Auth) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { // Retrieve the auth token for Azure Cache for Redis. Use default user. - password, err := auth.GetAzureCacheForRedisToken(ctx, sessionCtx) + password, err := auth.GetAzureCacheForRedisToken(ctx, sessionCtx.Database) if err != nil { return "", "", trace.AccessDenied("failed to get Azure access key: %v", err) } @@ -221,7 +221,7 @@ func azureAccessKeyFetchFunc(sessionCtx *common.Session, auth common.Auth) fetch func elasticacheIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { // Retrieve the auth token for AWS IAM ElastiCache. - password, err := auth.GetElastiCacheRedisToken(ctx, sessionCtx) + password, err := auth.GetElastiCacheRedisToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return "", "", trace.AccessDenied( "failed to get AWS ElastiCache IAM auth token for %v: %v", @@ -234,7 +234,7 @@ func elasticacheIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) // memorydbIAMTokenFetchFunc fetches an AWS MemoryDB IAM auth token. func memorydbIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { - password, err := auth.GetMemoryDBToken(ctx, sessionCtx) + password, err := auth.GetMemoryDBToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return "", "", trace.AccessDenied( "failed to get AWS MemoryDB IAM auth token for %v: %v", diff --git a/lib/srv/db/redis/engine.go b/lib/srv/db/redis/engine.go index b5ea50118dffe..ef55a75455912 100644 --- a/lib/srv/db/redis/engine.go +++ b/lib/srv/db/redis/engine.go @@ -216,7 +216,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio // getNewClientFn returns a partial Redis client factory function. func (e *Engine) getNewClientFn(ctx context.Context, sessionCtx *common.Session) (redisClientFactoryFn, error) { - tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 302c37162d405..1fc48aa4e7d5b 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -175,10 +175,18 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { if c.NewAudit == nil { c.NewAudit = common.NewAudit } + if c.CloudClients == nil { + cloudClients, err := clients.NewClients() + if err != nil { + return trace.Wrap(err) + } + c.CloudClients = cloudClients + } if c.Auth == nil { c.Auth, err = common.NewAuth(common.AuthConfig{ AuthClient: c.AuthClient, Clock: c.Clock, + Clients: c.CloudClients, }) if err != nil { return trace.Wrap(err) @@ -205,13 +213,6 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { if c.ConnectionMonitor == nil { return trace.BadParameter("missing ConnectionMonitor") } - if c.CloudClients == nil { - cloudClients, err := clients.NewClients() - if err != nil { - return trace.Wrap(err) - } - c.CloudClients = cloudClients - } if c.CloudMeta == nil { c.CloudMeta, err = cloud.NewMetadata(cloud.MetadataConfig{ Clients: c.CloudClients, @@ -891,7 +892,7 @@ func (s *Server) close(ctx context.Context) error { s.watcher.Close() } // Close all cloud clients. - errors = append(errors, s.cfg.Auth.Close()) + errors = append(errors, s.cfg.CloudClients.Close()) return trace.NewAggregate(errors...) } @@ -1090,7 +1091,7 @@ func (s *Server) dispatch(sessionCtx *common.Session, rec events.SessionPreparer // An error is returned when a protocol is not supported. func (s *Server) createEngine(sessionCtx *common.Session, audit common.Audit) (common.Engine, error) { return common.GetEngine(sessionCtx.Database, common.EngineConfig{ - Auth: s.cfg.Auth, + Auth: common.NewAuthForSession(s.cfg.Auth, sessionCtx), Audit: audit, AuthClient: s.cfg.AuthClient, CloudClients: s.cfg.CloudClients, diff --git a/lib/srv/db/spanner/grpcserver.go b/lib/srv/db/spanner/grpcserver.go index 3af1e7d773170..0d008cff5e831 100644 --- a/lib/srv/db/spanner/grpcserver.go +++ b/lib/srv/db/spanner/grpcserver.go @@ -300,7 +300,7 @@ func (e *Engine) getClientLocked(ctx context.Context) (spannerpb.SpannerClient, return spannerpb.NewSpannerClient(e.gcloudClient), nil } - tlsCfg, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx) + tlsCfg, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser) if err != nil { return nil, trace.Wrap(err) } @@ -311,7 +311,7 @@ func (e *Engine) getClientLocked(ctx context.Context) (spannerpb.SpannerClient, // reason someone would want to use the full service account name, like for // a service account in a different project? dbUser := databaseUserToGCPServiceAccount(e.sessionCtx) - ts, err := e.Auth.GetSpannerTokenSource(e.clientConnContext, e.sessionCtx.WithUser(dbUser)) + ts, err := e.Auth.GetSpannerTokenSource(e.clientConnContext, dbUser) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/sqlserver/connect.go b/lib/srv/db/sqlserver/connect.go index a3b4f6b291c39..e1ec06a31862f 100644 --- a/lib/srv/db/sqlserver/connect.go +++ b/lib/srv/db/sqlserver/connect.go @@ -97,7 +97,7 @@ func (c *connector) Connect(ctx context.Context, sessionCtx *common.Session, log return nil, nil, trace.Wrap(err) } - tlsConfig, err := c.DBAuth.GetTLSConfig(ctx, sessionCtx) + tlsConfig, err := c.DBAuth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser) if err != nil { return nil, nil, trace.Wrap(err) } @@ -190,6 +190,6 @@ func (c *connector) getAzureConnector(ctx context.Context, sessionCtx *common.Se // authenticate. func (c *connector) getAccessTokenConnector(ctx context.Context, sessionCtx *common.Session, dsnConfig msdsn.Config) (*mssql.Connector, error) { return mssql.NewSecurityTokenConnector(dsnConfig, func(ctx context.Context) (string, error) { - return c.DBAuth.GetRDSAuthToken(ctx, sessionCtx) + return c.DBAuth.GetRDSAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) }) } diff --git a/lib/srv/db/sqlserver/engine_test.go b/lib/srv/db/sqlserver/engine_test.go index ba45897c00b76..fd1dbc5ab8b3f 100644 --- a/lib/srv/db/sqlserver/engine_test.go +++ b/lib/srv/db/sqlserver/engine_test.go @@ -25,6 +25,7 @@ import ( "io" "net" "testing" + "time" "github.com/google/go-cmp/cmp" mssql "github.com/microsoft/go-mssqldb" @@ -409,7 +410,7 @@ func (m *mockDBAuth) GetAuthPreference(ctx context.Context) (types.AuthPreferenc }) } -func (m *mockDBAuth) GetTLSConfig(_ context.Context, _ *common.Session) (*tls.Config, error) { +func (m *mockDBAuth) GetTLSConfig(ctx context.Context, certExpiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) { return &tls.Config{}, nil }