Skip to content

Commit

Permalink
[v16] Database Access: decouple Auth from Session (#43527)
Browse files Browse the repository at this point in the history
* Database Access: decouple Auth from Session, add SessionAuth wrapper, share CloudClients instance, update tests.

* Drop `SessionAuth` interface, shift usages over to updated `Auth` interface

* Drop unused params
  • Loading branch information
Tener authored Jun 27, 2024
1 parent 3de08d2 commit e90d473
Show file tree
Hide file tree
Showing 22 changed files with 358 additions and 290 deletions.
94 changes: 61 additions & 33 deletions lib/srv/db/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package db

import (
"context"
"crypto/tls"
"testing"
"time"

Expand Down Expand Up @@ -261,9 +262,9 @@ func TestAuthTokens(t *testing.T) {
type testAuth struct {
// Auth is the wrapped "real" auth that handles everything except for
// cloud auth tokens generation.
common.Auth
realAuth common.Auth
// FieldLogger is used for logging.
logrus.FieldLogger
logger logrus.FieldLogger
}

func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
Expand All @@ -272,11 +273,13 @@ func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
return nil, trace.Wrap(err)
}
return &testAuth{
Auth: auth,
FieldLogger: logrus.WithField(teleport.ComponentKey, "auth:test"),
realAuth: auth,
logger: logrus.WithField(teleport.ComponentKey, "auth:test"),
}, nil
}

var _ common.Auth = (*testAuth)(nil)

const (
// rdsAuthToken is a mock RDS IAM auth token.
rdsAuthToken = "rds-auth-token"
Expand Down Expand Up @@ -314,76 +317,101 @@ type fakeTokenSource struct {
}

func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
f.Infof("Generating Cloud Spanner auth token source")
f.Info("Generating Cloud Spanner auth token source")
return &oauth2.Token{
Expiry: f.exp,
AccessToken: f.token,
}, nil
}

// GetRDSAuthToken generates RDS/Aurora auth token.
func (a *testAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating RDS auth token for %v.", sessionCtx)
func (a *testAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
Info("Generating RDS auth token")
return rdsAuthToken, nil
}

// GetRedshiftAuthToken generates Redshift auth token.
func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
a.Infof("Generating Redshift auth token for %v.", sessionCtx)
func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
WithField("databaseName", databaseName).
Info("Generating Redshift auth token")
return redshiftAuthUser, redshiftAuthToken, nil
}

func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
return "", "", trace.NotImplemented("GetRedshiftServerlessAuthToken is not implemented")
}

func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
return elastiCacheRedisToken, nil
}

func (a *testAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
func (a *testAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
return memorydbToken, nil
}

// GetCloudSQLAuthToken generates Cloud SQL auth token.
func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Cloud SQL auth token for %v.", sessionCtx)
func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) {
a.logger.WithField("database_user", databaseUser).Info("Generating Cloud SQL auth token")
return cloudSQLAuthToken, nil
}

// GetSpannerTokenSource returns an oauth token source for GCP Spanner.
func (a *testAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *common.Session) (oauth2.TokenSource, error) {
func (a *testAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) {
return &fakeTokenSource{
token: cloudSpannerAuthToken,
FieldLogger: a.WithField("session", sessionCtx),
FieldLogger: a.logger.WithField("database_user", databaseUser),
}, nil
}

// GetCloudSQLPassword generates Cloud SQL user password.
func (a *testAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Cloud SQL user password %v.", sessionCtx)
func (a *testAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating Cloud SQL password")
return cloudSQLPassword, nil
}

// GetAzureAccessToken generates Azure access token.
func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure access token for %v.", sessionCtx)
func (a *testAuth) GetAzureAccessToken(ctx context.Context) (string, error) {
a.logger.Info("Generating Azure access token")
return azureAccessToken, nil
}

// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure Redis token for %v.", sessionCtx)
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) {
a.logger.
WithField("database", database).
Info("Generating Azure Redis token")
return azureRedisToken, nil
}

// GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret
// access key and session token.
func (a *testAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *common.Session) (string, string, string, error) {
a.Infof("Generating AWS IAM credentials for %v.", sessionCtx)
func (a *testAuth) GetTLSConfig(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) {
return a.realAuth.GetTLSConfig(ctx, expiry, database, databaseUser)
}

func (a *testAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
return a.realAuth.GetAuthPreference(ctx)
}

func (a *testAuth) GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) {
return a.realAuth.GetAzureIdentityResourceID(ctx, identityName)
}

func (a *testAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating AWS IAM credentials")
return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil
}

func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth {
return &testAuth{
realAuth: a.realAuth,
logger: getUpdatedLogger(a.logger),
}
}

func TestMongoDBAtlas(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/cassandra/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func (e *Engine) authorizeConnection(ctx context.Context) error {
}

func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*protocol.Conn, error) {
config, err := e.Auth.GetTLSConfig(ctx, sessionCtx)
config, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
3 changes: 1 addition & 2 deletions lib/srv/db/clickhouse/engine_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ func (e *Engine) getTransport(ctx context.Context) (*http.Transport, error) {
if err != nil {
return nil, trace.Wrap(err)
}

tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx)
tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/clickhouse/engine_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (e *Engine) handleNativeConnection(ctx context.Context, sessionCtx *common.
return trace.Wrap(err)
}

tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx)
tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser)
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit e90d473

Please sign in to comment.