diff --git a/lib/srv/db/auth_test.go b/lib/srv/db/auth_test.go index f3d827aa47df2..942ffd7e278c3 100644 --- a/lib/srv/db/auth_test.go +++ b/lib/srv/db/auth_test.go @@ -20,6 +20,7 @@ package db import ( "context" + "crypto/tls" "testing" "time" @@ -261,9 +262,9 @@ func TestAuthTokens(t *testing.T) { type testAuth struct { // Auth is the wrapped "real" auth that handles everything except for // cloud auth tokens generation. - common.Auth + realAuth common.Auth // FieldLogger is used for logging. - logrus.FieldLogger + logger logrus.FieldLogger } func newTestAuth(ac common.AuthConfig) (*testAuth, error) { @@ -272,11 +273,13 @@ func newTestAuth(ac common.AuthConfig) (*testAuth, error) { return nil, trace.Wrap(err) } return &testAuth{ - Auth: auth, - FieldLogger: logrus.WithField(teleport.ComponentKey, "auth:test"), + realAuth: auth, + logger: logrus.WithField(teleport.ComponentKey, "auth:test"), }, nil } +var _ common.Auth = (*testAuth)(nil) + const ( // rdsAuthToken is a mock RDS IAM auth token. rdsAuthToken = "rds-auth-token" @@ -314,76 +317,106 @@ type fakeTokenSource struct { } func (f *fakeTokenSource) Token() (*oauth2.Token, error) { - f.Infof("Generating Cloud Spanner auth token source") + f.Info("Generating Cloud Spanner auth token source") return &oauth2.Token{ Expiry: f.exp, AccessToken: f.token, }, nil } -// GetRDSAuthToken generates RDS/Aurora auth token. -func (a *testAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating RDS auth token for %v.", sessionCtx) +func (a *testAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { + a.logger. + WithField("database", database). + WithField("databaseUser", databaseUser). + Info("Generating RDS auth token") return rdsAuthToken, nil } -// GetRedshiftAuthToken generates Redshift auth token. -func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) { - a.Infof("Generating Redshift auth token for %v.", sessionCtx) +func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + a.logger. + WithField("database", database). + WithField("databaseUser", databaseUser). + WithField("databaseName", databaseName). + Info("Generating Redshift auth token") return redshiftAuthUser, redshiftAuthToken, nil } -func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) { +func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { return "", "", trace.NotImplemented("GetRedshiftServerlessAuthToken is not implemented") } -func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) { +func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { return elastiCacheRedisToken, nil } -func (a *testAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *common.Session) (string, error) { +func (a *testAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { return memorydbToken, nil } -// GetCloudSQLAuthToken generates Cloud SQL auth token. -func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Cloud SQL auth token for %v.", sessionCtx) +func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) { + a.logger.WithField("database_user", databaseUser).Info("Generating Cloud SQL auth token") return cloudSQLAuthToken, nil } -// GetSpannerTokenSource returns an oauth token source for GCP Spanner. -func (a *testAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *common.Session) (oauth2.TokenSource, error) { +func (a *testAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) { return &fakeTokenSource{ token: cloudSpannerAuthToken, - FieldLogger: a.WithField("session", sessionCtx), + FieldLogger: a.logger.WithField("database_user", databaseUser), }, nil } -// GetCloudSQLPassword generates Cloud SQL user password. -func (a *testAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Cloud SQL user password %v.", sessionCtx) +func (a *testAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { + a.logger. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Info("Generating Cloud SQL password") return cloudSQLPassword, nil } -// GetAzureAccessToken generates Azure access token. -func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Azure access token for %v.", sessionCtx) +func (a *testAuth) GetAzureAccessToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { + a.logger. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Info("Generating Azure access token") return azureAccessToken, nil } -// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. -func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - a.Infof("Generating Azure Redis token for %v.", sessionCtx) +func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) { + a.logger. + WithField("database", database). + Info("Generating Azure Redis token") return azureRedisToken, nil } -// GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret -// access key and session token. -func (a *testAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *common.Session) (string, string, string, error) { - a.Infof("Generating AWS IAM credentials for %v.", sessionCtx) +func (a *testAuth) GetTLSConfig(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) { + return a.realAuth.GetTLSConfig(ctx, certTTL, database, databaseUser, databaseName) +} + +func (a *testAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return a.realAuth.GetAuthPreference(ctx) +} + +func (a *testAuth) GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) { + return a.realAuth.GetAzureIdentityResourceID(ctx, identityName) +} + +func (a *testAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) { + a.logger. + WithField("database", database). + WithField("database_user", databaseUser). + Info("Generating AWS IAM credentials") return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil } +func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth { + return &testAuth{ + realAuth: a.realAuth, + logger: getUpdatedLogger(a.logger), + } +} + func TestMongoDBAtlas(t *testing.T) { t.Parallel() diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index b1381d97e819f..a1f654a83d2b2 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -24,7 +24,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" - "io" "net/http" "net/url" "strings" @@ -72,27 +71,27 @@ const azureVirtualMachineCacheTTL = 5 * time.Minute // Auth defines interface for creating auth tokens and TLS configurations. type Auth interface { // GetRDSAuthToken generates RDS/Aurora auth token. - GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (string, error) + GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) // GetRedshiftAuthToken generates Redshift auth token. - GetRedshiftAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) + GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) // GetRedshiftServerlessAuthToken generates Redshift Serverless auth token. - GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) + GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) // GetElastiCacheRedisToken generates an ElastiCache Redis auth token. - GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Session) (string, error) + GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) // GetMemoryDBToken generates a MemoryDB auth token. - GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (string, error) + GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) // GetCloudSQLAuthToken generates Cloud SQL auth token. - GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) (string, error) + GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) // GetSpannerTokenSource returns an oauth token source for GCP Spanner. - GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) (oauth2.TokenSource, error) + GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) // GetCloudSQLPassword generates password for a Cloud SQL database user. - GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error) + GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) // GetAzureAccessToken generates Azure database access token. - GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error) + GetAzureAccessToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) // GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. - GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) + GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) // GetTLSConfig builds the client TLS configuration for the session. - GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) + GetTLSConfig(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) // GetAuthPreference returns the cluster authentication config. GetAuthPreference(ctx context.Context) (types.AuthPreference, error) // GetAzureIdentityResourceID returns the Azure identity resource ID @@ -101,9 +100,10 @@ type Auth interface { GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) // GetAWSIAMCreds returns the AWS IAM credentials, including access key, // secret access key and session token. - GetAWSIAMCreds(ctx context.Context, sessionCtx *Session) (string, string, string, error) - // Closer releases all resources used by authenticator. - io.Closer + GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) + // WithLogger returns a new instance of Auth with updated logger. + // The callback function receives the current logger and returns a new one. + WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) Auth } // AuthClient is an interface that defines a subset of libauth.Client's @@ -134,11 +134,7 @@ func (c *AuthConfig) CheckAndSetDefaults() error { return trace.BadParameter("missing AuthClient") } if c.Clients == nil { - cloudClients, err := cloud.NewClients() - if err != nil { - return trace.Wrap(err) - } - c.Clients = cloudClients + return trace.BadParameter("missing Clients") } if c.Clock == nil { c.Clock = clockwork.NewRealClock() @@ -149,6 +145,15 @@ func (c *AuthConfig) CheckAndSetDefaults() error { return nil } +func (c *AuthConfig) withLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) AuthConfig { + return AuthConfig{ + AuthClient: c.AuthClient, + Clients: c.Clients, + Clock: c.Clock, + Log: getUpdatedLogger(c.Log), + } +} + // dbAuth provides utilities for creating TLS configurations and // generating auth tokens when connecting to databases. type dbAuth struct { @@ -179,10 +184,19 @@ func NewAuth(config AuthConfig) (Auth, error) { }, nil } +// WithLogger returns a new instance of Auth with updated logger. +// The callback function receives the current logger and returns a new one. +func (a *dbAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) Auth { + return &dbAuth{ + cfg: a.cfg.withLogger(getUpdatedLogger), + azureVirtualMachineCache: a.azureVirtualMachineCache, + } +} + // GetRDSAuthToken returns authorization token that will be used as a password // when connecting to RDS and Aurora databases. -func (a *dbAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { + meta := database.GetAWS() awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -190,14 +204,17 @@ func (a *dbAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (stri if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating RDS auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + Debug("Generating RDS auth token") token, err := rdsutils.BuildAuthToken( - sessionCtx.Database.GetURI(), + database.GetURI(), meta.Region, - sessionCtx.DatabaseUser, + databaseUser, awsSession.Config.Credentials) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -216,17 +233,17 @@ permissions (note that IAM changes may take a few minutes to propagate): // GetRedshiftAuthToken returns authorization token that will be used as a // password when connecting to Redshift databases. -func (a *dbAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { - if awsutils.IsRoleARN(sessionCtx.DatabaseUser) { - return a.getRedshiftIAMRoleAuthToken(ctx, sessionCtx) +func (a *dbAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + if awsutils.IsRoleARN(databaseUser) { + return a.getRedshiftIAMRoleAuthToken(ctx, database, databaseUser, databaseName) } - return a.getRedshiftDBUserAuthToken(ctx, sessionCtx) + return a.getRedshiftDBUserAuthToken(ctx, database, databaseUser, databaseName) } -func (a *dbAuth) getRedshiftIAMRoleAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { - meta := sessionCtx.Database.GetAWS() - roleARN, err := a.buildAWSRoleARNFromDatabaseUser(ctx, sessionCtx) +func (a *dbAuth) getRedshiftIAMRoleAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + meta := database.GetAWS() + roleARN, err := a.buildAWSRoleARNFromDatabaseUser(ctx, database, databaseUser) if err != nil { return "", "", trace.Wrap(err) } @@ -254,13 +271,17 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent } // Now make the API call to generate the temporary credentials. - a.cfg.Log.Debugf("Generating Redshift IAM role auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Redshift IAM role auth token") resp, err := client.GetClusterCredentialsWithIAMWithContext(ctx, &redshift.GetClusterCredentialsWithIAMInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), - DbName: aws.String(sessionCtx.DatabaseName), + DbName: aws.String(databaseName), }) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -276,8 +297,8 @@ Make sure that IAM role %q has permissions to generate credentials. Here is a sa return aws.StringValue(resp.DbUser), aws.StringValue(resp.DbPassword), nil } -func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { + meta := database.GetAWS() redshiftClient, err := a.cfg.Clients.GetAWSRedshiftClient(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -285,11 +306,15 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, sessionCtx *Ses if err != nil { return "", "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating Redshift auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Redshift auth token") resp, err := redshiftClient.GetClusterCredentialsWithContext(ctx, &redshift.GetClusterCredentialsInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), - DbUser: aws.String(sessionCtx.DatabaseUser), - DbName: aws.String(sessionCtx.DatabaseName), + DbUser: aws.String(databaseUser), + DbName: aws.String(databaseName), // TODO(r0mant): Do not auto-create database account if DbUser doesn't // exist for now, but it may be potentially useful in future. AutoCreate: aws.Bool(false), @@ -298,7 +323,7 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, sessionCtx *Ses DbGroups: []*string{}, }) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -317,13 +342,13 @@ propagate): } // GetRedshiftServerlessAuthToken generates Redshift Serverless auth token. -func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { +func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) { // Redshift Serverless maps caller IAM users/roles to database users. For // example, an IAM role "arn:aws:iam::1234567890:role/my-role-name" will be // mapped to a Postgres user "IAMR:my-role-name" inside the database. So we // first need to assume this IAM role before getting auth token. - meta := sessionCtx.Database.GetAWS() - roleARN, err := redshiftServerlessUsernameToRoleARN(meta, sessionCtx.DatabaseUser) + meta := database.GetAWS() + roleARN, err := redshiftServerlessUsernameToRoleARN(meta, databaseUser) if err != nil { return "", "", trace.Wrap(err) } @@ -350,13 +375,17 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent } // Now make the API call to generate the temporary credentials. - a.cfg.Log.Debugf("Generating Redshift Serverless auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Redshift Serverless auth token") resp, err := client.GetCredentialsWithContext(ctx, &redshiftserverless.GetCredentialsInput{ WorkgroupName: aws.String(meta.RedshiftServerless.WorkgroupName), - DbName: aws.String(sessionCtx.DatabaseName), + DbName: aws.String(databaseName), }) if err != nil { - policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(sessionCtx.Database) + policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocumentForAssumedRole(database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) } @@ -374,12 +403,12 @@ Make sure that IAM role %q has permissions to generate credentials. Here is a sa // GetCloudSQLAuthToken returns authorization token that will be used as a // password when connecting to Cloud SQL databases. -func (a *dbAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) (string, error) { +func (a *dbAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) { // https://developers.google.com/identity/protocols/oauth2/scopes#sqladmin scopes := []string{ "https://www.googleapis.com/auth/sqlservice.admin", } - ts, err := a.getCloudTokenSource(ctx, sessionCtx, scopes) + ts, err := a.getCloudTokenSource(ctx, databaseUser, scopes) if err != nil { return "", trace.Wrap(err) } @@ -391,12 +420,12 @@ func (a *dbAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) } // GetSpannerTokenSource returns an oauth token source for GCP Spanner. -func (a *dbAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) (oauth2.TokenSource, error) { +func (a *dbAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) { // https://developers.google.com/identity/protocols/oauth2/scopes#spanner scopes := []string{ "https://www.googleapis.com/auth/spanner.data", } - ts, err := a.getCloudTokenSource(ctx, sessionCtx, scopes) + ts, err := a.getCloudTokenSource(ctx, databaseUser, scopes) if err != nil { return nil, trace.Wrap(err) } @@ -404,19 +433,19 @@ func (a *dbAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) return oauth2.ReuseTokenSource(nil, ts), nil } -func (a *dbAuth) getCloudTokenSource(ctx context.Context, sessionCtx *Session, scopes []string) (*cloudTokenSource, error) { +func (a *dbAuth) getCloudTokenSource(ctx context.Context, databaseUser string, scopes []string) (*cloudTokenSource, error) { gcpIAM, err := a.cfg.Clients.GetGCPIAMClient(ctx) if err != nil { return nil, trace.Wrap(err) } - serviceAccountName := sessionCtx.DatabaseUser + serviceAccountName := databaseUser if !strings.HasSuffix(serviceAccountName, ".gserviceaccount.com") { serviceAccountName = serviceAccountName + ".gserviceaccount.com" } return &cloudTokenSource{ ctx: ctx, client: gcpIAM, - log: a.cfg.Log.WithField("session", sessionCtx.String()), + log: a.cfg.Log.WithField("database_user", databaseUser), serviceAccount: serviceAccountName, scopes: scopes, }, nil @@ -472,12 +501,16 @@ or "iam.serviceAccounts.getAccessToken" IAM permission. // // It is used to generate a one-time password when connecting to GCP MySQL // databases which don't support IAM authentication. -func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error) { +func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { gcpCloudSQL, err := a.cfg.Clients.GetGCPSQLAdminClient(ctx) if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating GCP user password for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating GCP user password") token, err := utils.CryptoRandomHex(defaults.TokenLenBytes) if err != nil { return "", trace.Wrap(err) @@ -492,7 +525,7 @@ func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) ( retryCtx, cancel := context.WithTimeout(ctx, defaults.DatabaseConnectTimeout) defer cancel() err = retry.For(retryCtx, func() error { - err := a.updateCloudSQLUser(ctx, sessionCtx, gcpCloudSQL, &sqladmin.User{ + err := a.updateCloudSQLUser(ctx, database, databaseUser, gcpCloudSQL, &sqladmin.User{ Password: token, }) if err != nil && !trace.IsCompareFailed(ConvertError(err)) { // We only want to retry on 409. @@ -507,12 +540,12 @@ func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) ( } // updateCloudSQLUser makes a request to Cloud SQL API to update the provided user. -func (a *dbAuth) updateCloudSQLUser(ctx context.Context, sessionCtx *Session, gcpCloudSQL gcp.SQLAdminClient, user *sqladmin.User) error { - err := gcpCloudSQL.UpdateUser(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, user) +func (a *dbAuth) updateCloudSQLUser(ctx context.Context, database types.Database, databaseUser string, gcpCloudSQL gcp.SQLAdminClient, user *sqladmin.User) error { + err := gcpCloudSQL.UpdateUser(ctx, database, databaseUser, user) if err != nil { // Note that mysql client has a 1024 char limit for displaying errors // so we need to keep the message short when possible. This message - // does get cut off when sessionCtx.DatabaseUser or err is long. + // does get cut off when databaseUser or err is long. return trace.AccessDenied(`Could not update Cloud SQL user %q password: %v @@ -523,14 +556,18 @@ ID as "--db-user", or grant the Teleport Database Service the If the db user uses passwords, make sure Teleport Database Service has "Cloud SQL Admin" GCP IAM role, or "cloudsql.users.update" IAM permission. -`, sessionCtx.DatabaseUser, err) +`, databaseUser, err) } return nil } // GetAzureAccessToken generates Azure database access token. -func (a *dbAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error) { - a.cfg.Log.Debugf("Generating Azure access token for %s.", sessionCtx) +func (a *dbAuth) GetAzureAccessToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating Azure access token") cred, err := a.cfg.Clients.GetAzureCredential() if err != nil { return "", trace.Wrap(err) @@ -548,8 +585,8 @@ func (a *dbAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) ( } // GetElastiCacheRedisToken generates an ElastiCache Redis auth token. -func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Session) (string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { + meta := database.GetAWS() awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -557,11 +594,15 @@ func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Sessi if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating ElastiCache Redis auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating ElastiCache Redis auth token") tokenReq := &awsRedisIAMTokenRequest{ // For IAM-enabled ElastiCache users, the username and user id properties must be identical. // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits - userID: sessionCtx.DatabaseUser, + userID: databaseUser, targetID: meta.ElastiCache.ReplicationGroupID, serviceName: elasticache.ServiceName, region: meta.Region, @@ -573,8 +614,8 @@ func (a *dbAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Sessi } // GetMemoryDBToken generates a MemoryDB auth token. -func (a *dbAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (string, error) { - meta := sessionCtx.Database.GetAWS() +func (a *dbAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { + meta := database.GetAWS() awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, cloud.WithAssumeRoleFromAWSMeta(meta), cloud.WithAmbientCredentials(), @@ -582,9 +623,13 @@ func (a *dbAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (str if err != nil { return "", trace.Wrap(err) } - a.cfg.Log.Debugf("Generating MemoryDB auth token for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating MemoryDB auth token") tokenReq := &awsRedisIAMTokenRequest{ - userID: sessionCtx.DatabaseUser, + userID: databaseUser, targetID: meta.MemoryDB.ClusterName, serviceName: strings.ToLower(memorydb.ServiceName), region: meta.Region, @@ -596,8 +641,8 @@ func (a *dbAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (str } // GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. -func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) { - resourceID, err := arm.ParseResourceID(sessionCtx.Database.GetAzure().ResourceID) +func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) { + resourceID, err := arm.ParseResourceID(database.GetAzure().ResourceID) if err != nil { return "", trace.Wrap(err) } @@ -617,24 +662,27 @@ func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Ses default: return "", trace.BadParameter("unknown Azure Cache for Redis resource type: %v", resourceID.ResourceType) } - token, err := client.GetToken(ctx, sessionCtx.Database.GetAzure().ResourceID) + token, err := client.GetToken(ctx, database.GetAzure().ResourceID) if err != nil { // Some Azure error messages are long, multi-lined, and may even // contain divider lines like "------". It's unreadable in redis-cli as // the message has to be merged to a single line string. Thus logging // the original error as debug and returning a more user friendly // message. - a.cfg.Log.WithError(err).Debugf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName()) + a.cfg.Log. + WithField("database", database). + WithError(err). + Debug("Failed to get token for Azure Redis") switch { case trace.IsAccessDenied(err): - return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure the database agent has the \"listKeys\" permission to the database.", sessionCtx.Database.GetName()) + return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure the database agent has the \"listKeys\" permission to the database.", database.GetName()) case trace.IsNotFound(err): // Note that Azure Cache for Redis should always have both keys // generated at all time. Here just checking in case something // wrong with the API. - return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure either the primary key or the secondary key is generated.", sessionCtx.Database.GetName()) + return "", trace.AccessDenied("Failed to get token for Azure Redis %q. Please make sure either the primary key or the secondary key is generated.", database.GetName()) default: - return "", trace.Errorf("Failed to get token for Azure Redis %q.", sessionCtx.Database.GetName()) + return "", trace.Errorf("Failed to get token for Azure Redis %q.", database.GetName()) } } return token, nil @@ -645,30 +693,30 @@ func (a *dbAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Ses // For RDS/Aurora, the config must contain RDS root certificate as a trusted // authority. For on-prem we generate a client certificate signed by the host // CA used to authenticate. -func (a *dbAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - dbTLSConfig := sessionCtx.Database.GetTLS() +func (a *dbAuth) GetTLSConfig(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) { + dbTLSConfig := database.GetTLS() // Mode won't be set for older clients. We will default to VerifyFull then - the same as before. switch dbTLSConfig.Mode { case types.DatabaseTLSMode_INSECURE: - return a.getTLSConfigInsecure(ctx, sessionCtx) + return a.getTLSConfigInsecure(ctx, certTTL, database, databaseUser, databaseName) case types.DatabaseTLSMode_VERIFY_CA: - return a.getTLSConfigVerifyCA(ctx, sessionCtx) + return a.getTLSConfigVerifyCA(ctx, certTTL, database, databaseUser, databaseName) default: - return a.getTLSConfigVerifyFull(ctx, sessionCtx) + return a.getTLSConfigVerifyFull(ctx, certTTL, database, databaseUser, databaseName) } } // getTLSConfigVerifyFull returns tls.Config with full verification enabled ('verify-full' mode). // Config also includes database specific adjustment. -func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { +func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) { tlsConfig := &tls.Config{} // Add CA certificate to the trusted pool if it's present, e.g. when // connecting to RDS/Aurora which require AWS CA or when was provided in config file. // // Some databases may also require the system cert pool, e.g Azure Redis. - if err := setupTLSConfigRootCAs(tlsConfig, sessionCtx); err != nil { + if err := setupTLSConfigRootCAs(tlsConfig, database); err != nil { return nil, trace.Wrap(err) } @@ -690,11 +738,11 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session // // See the following Go issue for more context: // https://github.com/golang/go/issues/40748 - if sessionCtx.Database.IsCloudSQL() { + if database.IsCloudSQL() { // Cloud SQL server presented certificates encode instance names as // ":" in CommonName. This is verified against // the ServerName in a custom connection verification step (see below). - tlsConfig.ServerName = sessionCtx.Database.GetGCP().GetServerName() + tlsConfig.ServerName = database.GetGCP().GetServerName() // This just disables default verification. tlsConfig.InsecureSkipVerify = true // This will verify CN and cert chain on each connection. @@ -702,32 +750,32 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session } // Setup server name for verification. - if err := setupTLSConfigServerName(tlsConfig, sessionCtx); err != nil { + if err := setupTLSConfigServerName(tlsConfig, database); err != nil { return nil, trace.Wrap(err) } // RDS/Aurora/Redshift/ElastiCache and Cloud SQL auth is done with an auth // token so don't generate a client certificate and exit here. - if sessionCtx.Database.IsCloudHosted() { + if database.IsCloudHosted() { return tlsConfig, nil } // MongoDB Atlas doesn't not require client certificates if is using AWS // authentication. - if awsutils.IsRoleARN(sessionCtx.DatabaseUser) && sessionCtx.Database.GetType() == types.DatabaseTypeMongoAtlas { + if awsutils.IsRoleARN(databaseUser) && database.GetType() == types.DatabaseTypeMongoAtlas { return tlsConfig, nil } // Otherwise, when connecting to an onprem database, generate a client // certificate. The database instance should be configured with // Teleport's CA obtained with 'tctl auth sign --type=db'. - return a.appendClientCert(ctx, sessionCtx, tlsConfig) + return a.appendClientCert(ctx, certTTL, database, databaseUser, databaseName, tlsConfig) } // getTLSConfigInsecure generates tls.Config when TLS mode is equal to 'insecure'. // Generated configuration will accept any certificate provided by database. -func (a *dbAuth) getTLSConfigInsecure(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - tlsConfig, err := a.getTLSConfigVerifyFull(ctx, sessionCtx) +func (a *dbAuth) getTLSConfigInsecure(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) { + tlsConfig, err := a.getTLSConfigVerifyFull(ctx, certTTL, database, databaseUser, databaseName) if err != nil { return nil, trace.Wrap(err) } @@ -743,8 +791,8 @@ func (a *dbAuth) getTLSConfigInsecure(ctx context.Context, sessionCtx *Session) // getTLSConfigVerifyCA generates tls.Config when TLS mode is equal to 'verify-ca'. // Generated configuration is the same as 'verify-full' except the server name // verification is disabled. -func (a *dbAuth) getTLSConfigVerifyCA(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - tlsConfig, err := a.getTLSConfigVerifyFull(ctx, sessionCtx) +func (a *dbAuth) getTLSConfigVerifyCA(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) { + tlsConfig, err := a.getTLSConfigVerifyFull(ctx, certTTL, database, databaseUser, databaseName) if err != nil { return nil, trace.Wrap(err) } @@ -761,8 +809,8 @@ func (a *dbAuth) getTLSConfigVerifyCA(ctx context.Context, sessionCtx *Session) } // appendClientCert generates a client certificate and appends it to the provided tlsConfig. -func (a *dbAuth) appendClientCert(ctx context.Context, sessionCtx *Session, tlsConfig *tls.Config) (*tls.Config, error) { - cert, cas, err := a.getClientCert(ctx, sessionCtx) +func (a *dbAuth) appendClientCert(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string, tlsConfig *tls.Config) (*tls.Config, error) { + cert, cas, err := a.getClientCert(ctx, certTTL, database, databaseUser, databaseName) if err != nil { return nil, trace.Wrap(err) } @@ -778,9 +826,9 @@ func (a *dbAuth) appendClientCert(ctx context.Context, sessionCtx *Session, tlsC // setupTLSConfigRootCAs initializes the root CA cert pool for the provided // tlsConfig based on session context. -func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error { +func setupTLSConfigRootCAs(tlsConfig *tls.Config, database types.Database) error { // Start with an empty pool or a system cert pool. - if shouldUseSystemCertPool(sessionCtx) { + if shouldUseSystemCertPool(database) { systemCertPool, err := x509.SystemCertPool() if err != nil { return trace.Wrap(err) @@ -791,8 +839,8 @@ func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error { } // If CAs are provided by the database object, add them to the pool. - if len(sessionCtx.Database.GetCA()) != 0 { - if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(sessionCtx.Database.GetCA())) { + if len(database.GetCA()) != 0 { + if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(database.GetCA())) { return trace.BadParameter("invalid server CA certificate") } return nil @@ -805,8 +853,8 @@ func setupTLSConfigRootCAs(tlsConfig *tls.Config, sessionCtx *Session) error { // shouldUseSystemCertPool returns true for database servers presenting // certificates signed by publicly trusted CAs so a system cert pool can be // used. -func shouldUseSystemCertPool(sessionCtx *Session) bool { - switch sessionCtx.Database.GetType() { +func shouldUseSystemCertPool(database types.Database) bool { + switch database.GetType() { // Azure databases either use Baltimore Root CA or DigiCert Global Root G2. // // https://docs.microsoft.com/en-us/azure/postgresql/concepts-ssl-connection-security @@ -832,9 +880,9 @@ func shouldUseSystemCertPool(sessionCtx *Session) bool { // setupTLSConfigServerName initializes the server name for the provided // tlsConfig based on session context. -func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error { +func setupTLSConfigServerName(tlsConfig *tls.Config, database types.Database) error { // Use user provided server name if set. Override the current value if needed. - if dbTLSConfig := sessionCtx.Database.GetTLS(); dbTLSConfig.ServerName != "" { + if dbTLSConfig := database.GetTLS(); dbTLSConfig.ServerName != "" { tlsConfig.ServerName = dbTLSConfig.ServerName return nil } @@ -844,7 +892,7 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error return nil } - switch sessionCtx.Database.GetProtocol() { + switch database.GetProtocol() { case defaults.ProtocolMongoDB: // Don't set the ServerName when connecting to a MongoDB cluster - in case // of replica set the driver may dial multiple servers and will set @@ -859,8 +907,8 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error // hostnames. However, OSS cluster mode may redirect to an IP address, // and without correct ServerName the handshake will fail as the IPs // are not in SANs. - if sessionCtx.Database.IsAzure() { - serverName, err := azureutils.GetHostFromRedisURI(sessionCtx.Database.GetURI()) + if database.IsAzure() { + serverName, err := azureutils.GetHostFromRedisURI(database.GetURI()) if err != nil { return trace.Wrap(err) } @@ -872,7 +920,7 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error // Redis is using custom URI schema. return nil case defaults.ProtocolClickHouse, defaults.ProtocolClickHouseHTTP: - u, err := url.Parse(sessionCtx.Database.GetURI()) + u, err := url.Parse(database.GetURI()) if err != nil { return trace.Wrap(err) } @@ -885,7 +933,7 @@ func setupTLSConfigServerName(tlsConfig *tls.Config, sessionCtx *Session) error default: // For other databases we're always connecting to the server specified // in URI so set ServerName ourselves. - addr, err := utils.ParseAddr(sessionCtx.Database.GetURI()) + addr, err := utils.ParseAddr(database.GetURI()) if err != nil { return trace.Wrap(err) } @@ -922,24 +970,28 @@ func verifyConnectionFunc(rootCAs *x509.CertPool) func(cs tls.ConnectionState) e // getClientCert signs an ephemeral client certificate used by this // server to authenticate with the database instance. -func (a *dbAuth) getClientCert(ctx context.Context, sessionCtx *Session) (cert *tls.Certificate, cas [][]byte, err error) { +func (a *dbAuth) getClientCert(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (cert *tls.Certificate, cas [][]byte, err error) { privateKey, err := native.GeneratePrivateKey() if err != nil { return nil, nil, trace.Wrap(err) } // Postgres requires the database username to be encoded as a common // name in the client certificate. - subject := pkix.Name{CommonName: sessionCtx.DatabaseUser} + subject := pkix.Name{CommonName: databaseUser} csr, err := tlsca.GenerateCertificateRequestPEM(subject, privateKey) if err != nil { return nil, nil, trace.Wrap(err) } // TODO(r0mant): Cache database certificates to avoid expensive generate // operation on each connection. - a.cfg.Log.Debugf("Generating client certificate for %s.", sessionCtx) + a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser). + WithField("database_name", databaseName). + Debug("Generating client certificate") resp, err := a.cfg.AuthClient.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ CSR: csr, - TTL: proto.Duration(sessionCtx.Identity.Expires.Sub(a.cfg.Clock.Now())), + TTL: proto.Duration(certTTL), }) if err != nil { return nil, nil, trace.Wrap(err) @@ -1012,14 +1064,18 @@ func (a *dbAuth) getCurrentAzureVM(ctx context.Context) (*libazure.VirtualMachin return vm, nil } -func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, sessionCtx *Session) (string, error) { - dbAWS := sessionCtx.Database.GetAWS() +func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, database types.Database, databaseUser string) (string, error) { + dbAWS := database.GetAWS() awsAccountID := dbAWS.AccountID - if awsutils.IsPartialRoleARN(sessionCtx.DatabaseUser) && awsAccountID == "" { + log := a.cfg.Log. + WithField("database", database). + WithField("database_user", databaseUser) + + if awsutils.IsPartialRoleARN(databaseUser) && awsAccountID == "" { switch { case dbAWS.AssumeRoleARN != "": - a.cfg.Log.Debugf("Using AWS Account ID from assumed role") + log.Debug("Using AWS Account ID from assumed role") assumeRoleARN, err := awsutils.ParseRoleARN(dbAWS.AssumeRoleARN) if err != nil { return "", trace.Wrap(err) @@ -1027,7 +1083,7 @@ func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, sessionCtx awsAccountID = assumeRoleARN.AccountID default: - a.cfg.Log.Debugf("Fetching AWS Account ID to build role ARN") + log.Debug("Fetching AWS Account ID to build role ARN") stsClient, err := a.cfg.Clients.GetAWSSTSClient(ctx, dbAWS.Region, cloud.WithAmbientCredentials()) if err != nil { return "", trace.Wrap(err) @@ -1042,15 +1098,15 @@ func (a *dbAuth) buildAWSRoleARNFromDatabaseUser(ctx context.Context, sessionCtx } } - arn, err := awsutils.BuildRoleARN(sessionCtx.DatabaseUser, dbAWS.Region, awsAccountID) + arn, err := awsutils.BuildRoleARN(databaseUser, dbAWS.Region, awsAccountID) return arn, trace.Wrap(err) } // GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret // access key and session token. -func (a *dbAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *Session) (string, string, string, error) { - dbAWS := sessionCtx.Database.GetAWS() - arn, err := a.buildAWSRoleARNFromDatabaseUser(ctx, sessionCtx) +func (a *dbAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) { + dbAWS := database.GetAWS() + arn, err := a.buildAWSRoleARNFromDatabaseUser(ctx, database, databaseUser) if err != nil { return "", "", "", trace.Wrap(err) } @@ -1226,25 +1282,3 @@ func (r *awsRedisIAMTokenRequest) getSignableRequest() (*http.Request, error) { } return req, nil } - -type reportingAuth struct { - Auth - component string - db types.Database -} - -// newReportingAuth returns a reporting version of Auth, wrapping the original Auth instance. -func newReportingAuth(db types.Database, auth Auth) *reportingAuth { - return &reportingAuth{ - Auth: auth, - component: "db:auth", - db: db, - } -} - -func (r *reportingAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { - defer methodCallMetrics("GetTLSConfig", r.component, r.db)() - return r.Auth.GetTLSConfig(ctx, sessionCtx) -} - -var _ Auth = (*reportingAuth)(nil) diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index 51d7b1a88956f..025d15fa8bc02 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -90,9 +90,7 @@ func TestAuthGetAzureCacheForRedisToken(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - token, err := auth.GetAzureCacheForRedisToken(context.TODO(), &Session{ - Database: newAzureRedisDatabase(t, test.resourceID), - }) + token, err := auth.GetAzureCacheForRedisToken(context.TODO(), newAzureRedisDatabase(t, test.resourceID)) if test.expectError { require.Error(t, err) } else { @@ -121,11 +119,11 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) { }) require.NoError(t, err) - dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(context.TODO(), &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRedshiftServerlessDatabase(t), - }) + dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(context.TODO(), + newRedshiftServerlessDatabase(t), + "some-user", + "some-database", + ) require.NoError(t, err) require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) @@ -226,11 +224,12 @@ func TestAuthGetTLSConfig(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - tlsConfig, err := auth.GetTLSConfig(context.TODO(), &Session{ - Identity: tlsca.Identity{}, - DatabaseUser: "default", - Database: test.sessionDatabase, - }) + tlsConfig, err := auth.GetTLSConfig(context.TODO(), + time.Hour, + test.sessionDatabase, + "defaultUser", + "defaultDatabase", + ) require.NoError(t, err) require.Equal(t, test.expectServerName, tlsConfig.ServerName) @@ -456,24 +455,22 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) tests := map[string]struct { - sessionCtx *Session - checkGetAuthFn func(t *testing.T, auth Auth, sessionCtx *Session) + checkGetAuthFn func(t *testing.T, auth Auth) checkSTS func(t *testing.T, stsMock *mocks.STSMock) }{ "Redshift": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRedshiftDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + databaseName := "some-database" + database := newRedshiftDatabase(t, withCA(fixtures.SAMLOktaCertPEM), withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedshiftRole", ExternalID: "externalRedshift", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, sessionCtx) + })) + + dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) @@ -485,19 +482,18 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "Redshift with IAM role": { - sessionCtx: &Session{ - DatabaseUser: "role/some-role", - DatabaseName: "some-database", - Database: newRedshiftDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "role/some-role" + databaseName := "some-database" + database := newRedshiftDatabase(t, withCA(fixtures.SAMLOktaCertPEM), withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedshiftRole", ExternalID: "externalRedshift", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, sessionCtx) + })) + + dbUser, dbPassword, err := auth.GetRedshiftAuthToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) require.Equal(t, "IAM:some-role", dbUser) require.Equal(t, "some-password-for-some-role", dbPassword) @@ -509,18 +505,17 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "Redshift Serverless": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRedshiftServerlessDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + databaseName := "some-database" + database := newRedshiftServerlessDatabase(t, withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedshiftServerlessRole", ExternalID: "externalRedshiftServerless", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(ctx, sessionCtx) + })) + + dbUser, dbPassword, err := auth.GetRedshiftServerlessAuthToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) require.Equal(t, "IAM:some-user", dbUser) require.Equal(t, "some-password", dbPassword) @@ -533,18 +528,15 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "RDS Proxy": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newRDSProxyDatabase(t, "my-proxy.proxy-abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432", + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + database := newRDSProxyDatabase(t, "my-proxy.proxy-abcdefghijklmnop.us-east-1.rds.amazonaws.com:5432", withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RDSProxyRole", ExternalID: "externalRDSProxy", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - token, err := auth.GetRDSAuthToken(ctx, sessionCtx) + })) + token, err := auth.GetRDSAuthToken(ctx, database, databaseUser) require.NoError(t, err) require.Contains(t, token, "DBUser=some-user") }, @@ -555,18 +547,16 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { }, }, "ElastiCache Redis": { - sessionCtx: &Session{ - DatabaseUser: "some-user", - DatabaseName: "some-database", - Database: newElastiCacheRedisDatabase(t, + checkGetAuthFn: func(t *testing.T, auth Auth) { + t.Helper() + databaseUser := "some-user" + databaseName := "some-database" + database := newElastiCacheRedisDatabase(t, withAssumeRole(types.AssumeRole{ RoleARN: "arn:aws:iam::123456789012:role/RedisRole", ExternalID: "externalElastiCacheRedis", - })), - }, - checkGetAuthFn: func(t *testing.T, auth Auth, sessionCtx *Session) { - t.Helper() - token, err := auth.GetElastiCacheRedisToken(ctx, sessionCtx) + })) + token, err := auth.GetElastiCacheRedisToken(ctx, database, databaseUser, databaseName) require.NoError(t, err) u, err := url.Parse(token) require.NoError(t, err) @@ -610,7 +600,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { tt := tt t.Run(name, func(t *testing.T) { t.Parallel() - tt.checkGetAuthFn(t, auth, tt.sessionCtx) + tt.checkGetAuthFn(t, auth) tt.checkSTS(t, stsMock) }) } @@ -687,10 +677,7 @@ func TestGetAWSIAMCreds(t *testing.T) { }) require.NoError(t, err) - keyId, _, _, err := auth.GetAWSIAMCreds(ctx, &Session{ - Database: tt.db, - DatabaseUser: tt.username, - }) + keyId, _, _, err := auth.GetAWSIAMCreds(ctx, tt.db, tt.username) tt.expectErr(t, err) require.Equal(t, tt.expectedKeyId, keyId) require.ElementsMatch(t, tt.expectedAssumedRoles, tt.stsMock.GetAssumedRoleARNs()) diff --git a/lib/srv/db/common/engines.go b/lib/srv/db/common/engines.go index d9803bf5258e6..4003ced441abf 100644 --- a/lib/srv/db/common/engines.go +++ b/lib/srv/db/common/engines.go @@ -60,7 +60,7 @@ func GetEngine(db types.Database, conf EngineConfig) (Engine, error) { if err := conf.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } - conf.Auth = newReportingAuth(db, conf.Auth) + conf.Auth = newReportingSessionAuth(db, conf.Auth) enginesMu.RLock() name := db.GetProtocol() engineFn := engines[name] @@ -97,8 +97,8 @@ func CheckEngines(names ...string) error { // EngineConfig is the common configuration every database engine uses. type EngineConfig struct { - // Auth handles database access authentication. - Auth Auth + // SessionAuth handles database access authentication. + Auth SessionAuth // Audit emits database access audit events. Audit Audit // AuthClient is the cluster auth server client. diff --git a/lib/srv/db/common/engines_test.go b/lib/srv/db/common/engines_test.go index 41d2c5b16c963..feed3bd69b5ba 100644 --- a/lib/srv/db/common/engines_test.go +++ b/lib/srv/db/common/engines_test.go @@ -86,9 +86,9 @@ func TestRegisterEngine(t *testing.T) { // The auth will be replaced with reporting auth internally, but we can unwrap the original auth. engineInst, ok := repEngine.engine.(*testEngine) require.True(t, ok) - repAuth, ok := engineInst.ec.Auth.(*reportingAuth) + repAuth, ok := engineInst.ec.Auth.(*reportingSessionAuth) require.True(t, ok) - require.Equal(t, ec.Auth, repAuth.Auth) + require.Equal(t, ec.Auth, repAuth.SessionAuth) engineInst.ec.Auth = ec.Auth require.Equal(t, ec, engineInst.ec) } @@ -103,5 +103,5 @@ type testAudit struct { } type testAuth struct { - Auth + SessionAuth } diff --git a/lib/srv/db/common/sessionauth.go b/lib/srv/db/common/sessionauth.go new file mode 100644 index 0000000000000..15e6e5b0f146b --- /dev/null +++ b/lib/srv/db/common/sessionauth.go @@ -0,0 +1,159 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "context" + "crypto/tls" + + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "golang.org/x/oauth2" + + "github.com/gravitational/teleport/api/types" +) + +// SessionAuth is a convenience wrapper around Auth interface that uses Session type instead of fine-grained parameters. +type SessionAuth interface { + // GetRDSAuthToken generates RDS/Aurora auth token. + GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (string, error) + // GetRedshiftAuthToken generates Redshift auth token. + GetRedshiftAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) + // GetRedshiftServerlessAuthToken generates Redshift Serverless auth token. + GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) + // GetElastiCacheRedisToken generates an ElastiCache Redis auth token. + GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Session) (string, error) + // GetMemoryDBToken generates a MemoryDB auth token. + GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (string, error) + // GetCloudSQLAuthToken generates Cloud SQL auth token. + GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) (string, error) + // GetSpannerTokenSource returns an oauth token source for GCP Spanner. + GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) (oauth2.TokenSource, error) + // GetCloudSQLPassword generates password for a Cloud SQL database user. + GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error) + // GetAzureAccessToken generates Azure database access token. + GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error) + // GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis. + GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) + // GetTLSConfig builds the client TLS configuration for the session. + GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) + // GetAuthPreference returns the cluster authentication config. + GetAuthPreference(ctx context.Context) (types.AuthPreference, error) + // GetAzureIdentityResourceID returns the Azure identity resource ID + // attached to the current compute instance. If Teleport is not running on + // Azure VM returns an error. + GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) + // GetAWSIAMCreds returns the AWS IAM credentials, including access key, + // secret access key and session token. + GetAWSIAMCreds(ctx context.Context, sessionCtx *Session) (string, string, string, error) + // Closer releases all resources used by authenticator. + // io.Closer +} + +type sessionAuth struct { + auth Auth + clock clockwork.Clock +} + +func NewSessionAuth(auth Auth, clock clockwork.Clock, sessionCtx *Session) SessionAuth { + return &sessionAuth{ + auth: auth.WithLogger(func(logger logrus.FieldLogger) logrus.FieldLogger { + return logger.WithFields(logrus.Fields{ + "session_id": sessionCtx.ID, + "database": sessionCtx.Database.GetName(), + }) + }), + clock: clock, + } +} + +func (s *sessionAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetRDSAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) +} + +func (s *sessionAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { + return s.auth.GetRedshiftAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *Session) (string, string, error) { + return s.auth.GetRedshiftServerlessAuthToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetElastiCacheRedisToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetMemoryDBToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetCloudSQLAuthToken(ctx, sessionCtx.DatabaseUser) +} + +func (s *sessionAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *Session) (oauth2.TokenSource, error) { + return s.auth.GetSpannerTokenSource(ctx, sessionCtx.DatabaseUser) +} + +func (s *sessionAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetCloudSQLPassword(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetAzureAccessToken(ctx, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *Session) (string, error) { + return s.auth.GetAzureCacheForRedisToken(ctx, sessionCtx.Database) +} + +func (s *sessionAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { + ttl := sessionCtx.Identity.Expires.Sub(s.clock.Now()) + return s.auth.GetTLSConfig(ctx, ttl, sessionCtx.Database, sessionCtx.DatabaseUser, sessionCtx.DatabaseName) +} + +func (s *sessionAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { + return s.auth.GetAuthPreference(ctx) +} + +func (s *sessionAuth) GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) { + return s.auth.GetAzureIdentityResourceID(ctx, identityName) +} + +func (s *sessionAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *Session) (string, string, string, error) { + return s.auth.GetAWSIAMCreds(ctx, sessionCtx.Database, sessionCtx.DatabaseUser) +} + +type reportingSessionAuth struct { + SessionAuth + component string + db types.Database +} + +// newReportingSessionAuth returns a reporting version of Auth, wrapping the original Auth instance. +func newReportingSessionAuth(db types.Database, auth SessionAuth) SessionAuth { + return &reportingSessionAuth{ + SessionAuth: auth, + component: "db:auth", + db: db, + } +} + +func (r *reportingSessionAuth) GetTLSConfig(ctx context.Context, sessionCtx *Session) (*tls.Config, error) { + defer methodCallMetrics("GetTLSConfig", r.component, r.db)() + return r.SessionAuth.GetTLSConfig(ctx, sessionCtx) +} diff --git a/lib/srv/db/mysql/gcp_test.go b/lib/srv/db/mysql/gcp_test.go index c06e49dc9c9e9..95394476f4137 100644 --- a/lib/srv/db/mysql/gcp_test.go +++ b/lib/srv/db/mysql/gcp_test.go @@ -41,30 +41,35 @@ type fakeAuth struct { common.Auth } -func (a fakeAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) { - if !isDBUserFullGCPServerAccountID(sessionCtx.DatabaseUser) { +func (a fakeAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) { + if !isDBUserFullGCPServerAccountID(databaseUser) { return "", trace.BadParameter("database user must be a service account") } return "iam-auth-token", nil } -func (a fakeAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) { - if isDBUserFullGCPServerAccountID(sessionCtx.DatabaseUser) { +func (a fakeAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) { + if isDBUserFullGCPServerAccountID(databaseUser) { return "", trace.BadParameter("database user must not be a service account") } return "one-time-password", nil } -func Test_getGCPUserAndPassowrd(t *testing.T) { +func (a fakeAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth { + if a.Auth != nil { + return a.Auth.WithLogger(getUpdatedLogger) + } + return a +} + +func Test_getGCPUserAndPassword(t *testing.T) { ctx := context.Background() authClient := makeAuthClient(t) db := makeGCPMySQLDatabase(t) - dbAuth := &fakeAuth{} tests := []struct { name string inputDatabaseUser string - mockDBAuth common.Auth mockGCPClient gcp.SQLAdminClient wantDatabaseUser string wantPassword string @@ -73,20 +78,17 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "iam auth with full service account", inputDatabaseUser: "iam-auth-user@project-id.iam.gserviceaccount.com", - mockDBAuth: dbAuth, wantDatabaseUser: "iam-auth-user", wantPassword: "iam-auth-token", }, { name: "iam auth with short service account", inputDatabaseUser: "iam-auth-user@project-id.iam", - mockDBAuth: dbAuth, wantError: true, }, { name: "iam auth with CLOUD_IAM_SERVICE_ACCOUNT user", inputDatabaseUser: "iam-auth-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("iam-auth-user", "CLOUD_IAM_SERVICE_ACCOUNT"), }, @@ -96,7 +98,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "iam auth with CLOUD_IAM_GROUP_SERVICE_ACCOUNT user", inputDatabaseUser: "iam-auth-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("iam-auth-user", "CLOUD_IAM_GROUP_SERVICE_ACCOUNT"), }, @@ -106,7 +107,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "password auth without GetUser permission", inputDatabaseUser: "some-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ // Default no permission to GetUser, }, @@ -116,7 +116,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "password auth with BUILT_IN user", inputDatabaseUser: "password-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("password-user", "BUILT_IN"), }, @@ -126,7 +125,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "password auth with empty user type", inputDatabaseUser: "password-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("password-user", ""), }, @@ -136,7 +134,6 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { { name: "unsupported user type", inputDatabaseUser: "some-user", - mockDBAuth: dbAuth, mockGCPClient: &mocks.GCPSQLAdminClientMock{ DatabaseUser: makeGCPDatabaseUser("some-user", "CLOUD_IAM_USER"), }, @@ -147,19 +144,20 @@ func Test_getGCPUserAndPassowrd(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { + sessionCtx := &common.Session{ + Database: db, + DatabaseUser: test.inputDatabaseUser, + ID: "00000000-0000AAAA-0000BBBB-0000CCCC", + } + engine := NewEngine(common.EngineConfig{ - Auth: test.mockDBAuth, + Auth: common.NewSessionAuth(&fakeAuth{}, clockwork.NewRealClock(), sessionCtx), AuthClient: authClient, Context: ctx, Clock: clockwork.NewRealClock(), Log: logrus.StandardLogger(), }).(*Engine) - sessionCtx := &common.Session{ - Database: db, - DatabaseUser: test.inputDatabaseUser, - } - databaseUser, password, err := engine.getGCPUserAndPassword(ctx, sessionCtx, test.mockGCPClient) if test.wantError { require.Error(t, err) diff --git a/lib/srv/db/redis/client.go b/lib/srv/db/redis/client.go index 5946746220893..84407d157dc4b 100644 --- a/lib/srv/db/redis/client.go +++ b/lib/srv/db/redis/client.go @@ -192,7 +192,7 @@ func fetchCredentialsOnConnect(closeCtx context.Context, sessionCtx *common.Sess } // managedUserCredFetchFunc fetches user password on the fly. -func managedUserCredFetchFunc(sessionCtx *common.Session, auth common.Auth, users common.Users) fetchCredentialsFunc { +func managedUserCredFetchFunc(sessionCtx *common.Session, auth common.SessionAuth, users common.Users) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { username := sessionCtx.DatabaseUser password, err := users.GetPassword(ctx, sessionCtx.Database, username) @@ -205,7 +205,7 @@ func managedUserCredFetchFunc(sessionCtx *common.Session, auth common.Auth, user } // azureAccessKeyFetchFunc Azure access key for the "default" user. -func azureAccessKeyFetchFunc(sessionCtx *common.Session, auth common.Auth) fetchCredentialsFunc { +func azureAccessKeyFetchFunc(sessionCtx *common.Session, auth common.SessionAuth) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { // Retrieve the auth token for Azure Cache for Redis. Use default user. password, err := auth.GetAzureCacheForRedisToken(ctx, sessionCtx) @@ -218,7 +218,7 @@ func azureAccessKeyFetchFunc(sessionCtx *common.Session, auth common.Auth) fetch } // elasticacheIAMTokenFetchFunc fetches an AWS ElastiCache IAM auth token. -func elasticacheIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) fetchCredentialsFunc { +func elasticacheIAMTokenFetchFunc(sessionCtx *common.Session, auth common.SessionAuth) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { // Retrieve the auth token for AWS IAM ElastiCache. password, err := auth.GetElastiCacheRedisToken(ctx, sessionCtx) @@ -232,7 +232,7 @@ func elasticacheIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) } // memorydbIAMTokenFetchFunc fetches an AWS MemoryDB IAM auth token. -func memorydbIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) fetchCredentialsFunc { +func memorydbIAMTokenFetchFunc(sessionCtx *common.Session, auth common.SessionAuth) fetchCredentialsFunc { return func(ctx context.Context) (string, string, error) { password, err := auth.GetMemoryDBToken(ctx, sessionCtx) if err != nil { @@ -244,7 +244,7 @@ func memorydbIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) fet } } -func awsIAMTokenFetchFunc(sessionCtx *common.Session, auth common.Auth) (fetchCredentialsFunc, error) { +func awsIAMTokenFetchFunc(sessionCtx *common.Session, auth common.SessionAuth) (fetchCredentialsFunc, error) { switch sessionCtx.Database.GetType() { case types.DatabaseTypeElastiCache: return elasticacheIAMTokenFetchFunc(sessionCtx, auth), nil diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 302c37162d405..93b382430f10a 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -175,10 +175,18 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { if c.NewAudit == nil { c.NewAudit = common.NewAudit } + if c.CloudClients == nil { + cloudClients, err := clients.NewClients() + if err != nil { + return trace.Wrap(err) + } + c.CloudClients = cloudClients + } if c.Auth == nil { c.Auth, err = common.NewAuth(common.AuthConfig{ AuthClient: c.AuthClient, Clock: c.Clock, + Clients: c.CloudClients, }) if err != nil { return trace.Wrap(err) @@ -205,13 +213,6 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { if c.ConnectionMonitor == nil { return trace.BadParameter("missing ConnectionMonitor") } - if c.CloudClients == nil { - cloudClients, err := clients.NewClients() - if err != nil { - return trace.Wrap(err) - } - c.CloudClients = cloudClients - } if c.CloudMeta == nil { c.CloudMeta, err = cloud.NewMetadata(cloud.MetadataConfig{ Clients: c.CloudClients, @@ -891,7 +892,7 @@ func (s *Server) close(ctx context.Context) error { s.watcher.Close() } // Close all cloud clients. - errors = append(errors, s.cfg.Auth.Close()) + errors = append(errors, s.cfg.CloudClients.Close()) return trace.NewAggregate(errors...) } @@ -1090,7 +1091,7 @@ func (s *Server) dispatch(sessionCtx *common.Session, rec events.SessionPreparer // An error is returned when a protocol is not supported. func (s *Server) createEngine(sessionCtx *common.Session, audit common.Audit) (common.Engine, error) { return common.GetEngine(sessionCtx.Database, common.EngineConfig{ - Auth: s.cfg.Auth, + Auth: common.NewSessionAuth(s.cfg.Auth, s.cfg.Clock, sessionCtx), Audit: audit, AuthClient: s.cfg.AuthClient, CloudClients: s.cfg.CloudClients, diff --git a/lib/srv/db/sqlserver/connect.go b/lib/srv/db/sqlserver/connect.go index a3b4f6b291c39..10ce597c85c88 100644 --- a/lib/srv/db/sqlserver/connect.go +++ b/lib/srv/db/sqlserver/connect.go @@ -55,7 +55,7 @@ type Connector interface { type connector struct { // Auth is the database auth client - DBAuth common.Auth + DBAuth common.SessionAuth // AuthClient is the teleport client AuthClient windows.AuthInterface // DataDir is the Teleport data directory diff --git a/lib/srv/db/sqlserver/engine_test.go b/lib/srv/db/sqlserver/engine_test.go index ba45897c00b76..1125279f45f43 100644 --- a/lib/srv/db/sqlserver/engine_test.go +++ b/lib/srv/db/sqlserver/engine_test.go @@ -392,7 +392,7 @@ func (o *mockConn) Close() error { } type mockDBAuth struct { - common.Auth + common.SessionAuth // GetAzureIdentityResourceID mocks. azureIdentityResourceID string azureIdentityResourceIDErr error