Skip to content

Commit

Permalink
Database Access: decouple Auth from Session, add SessionAuth wrapper,…
Browse files Browse the repository at this point in the history
… share CloudClients instance, update tests.
  • Loading branch information
Tener committed Jun 21, 2024
1 parent 0fbdb4a commit 35c5b94
Show file tree
Hide file tree
Showing 11 changed files with 499 additions and 287 deletions.
99 changes: 66 additions & 33 deletions lib/srv/db/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package db

import (
"context"
"crypto/tls"
"testing"
"time"

Expand Down Expand Up @@ -261,9 +262,9 @@ func TestAuthTokens(t *testing.T) {
type testAuth struct {
// Auth is the wrapped "real" auth that handles everything except for
// cloud auth tokens generation.
common.Auth
realAuth common.Auth
// FieldLogger is used for logging.
logrus.FieldLogger
logger logrus.FieldLogger
}

func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
Expand All @@ -272,11 +273,13 @@ func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
return nil, trace.Wrap(err)
}
return &testAuth{
Auth: auth,
FieldLogger: logrus.WithField(teleport.ComponentKey, "auth:test"),
realAuth: auth,
logger: logrus.WithField(teleport.ComponentKey, "auth:test"),
}, nil
}

var _ common.Auth = (*testAuth)(nil)

const (
// rdsAuthToken is a mock RDS IAM auth token.
rdsAuthToken = "rds-auth-token"
Expand Down Expand Up @@ -314,76 +317,106 @@ type fakeTokenSource struct {
}

func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
f.Infof("Generating Cloud Spanner auth token source")
f.Info("Generating Cloud Spanner auth token source")
return &oauth2.Token{
Expiry: f.exp,
AccessToken: f.token,
}, nil
}

// GetRDSAuthToken generates RDS/Aurora auth token.
func (a *testAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating RDS auth token for %v.", sessionCtx)
func (a *testAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
Info("Generating RDS auth token")
return rdsAuthToken, nil
}

// GetRedshiftAuthToken generates Redshift auth token.
func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
a.Infof("Generating Redshift auth token for %v.", sessionCtx)
func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
WithField("databaseName", databaseName).
Info("Generating Redshift auth token")
return redshiftAuthUser, redshiftAuthToken, nil
}

func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
return "", "", trace.NotImplemented("GetRedshiftServerlessAuthToken is not implemented")
}

func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) {
return elastiCacheRedisToken, nil
}

func (a *testAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
func (a *testAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) {
return memorydbToken, nil
}

// GetCloudSQLAuthToken generates Cloud SQL auth token.
func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Cloud SQL auth token for %v.", sessionCtx)
func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) {
a.logger.WithField("database_user", databaseUser).Info("Generating Cloud SQL auth token")
return cloudSQLAuthToken, nil
}

// GetSpannerTokenSource returns an oauth token source for GCP Spanner.
func (a *testAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *common.Session) (oauth2.TokenSource, error) {
func (a *testAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) {
return &fakeTokenSource{
token: cloudSpannerAuthToken,
FieldLogger: a.WithField("session", sessionCtx),
FieldLogger: a.logger.WithField("database_user", databaseUser),
}, nil
}

// GetCloudSQLPassword generates Cloud SQL user password.
func (a *testAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Cloud SQL user password %v.", sessionCtx)
func (a *testAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
WithField("database_name", databaseName).
Info("Generating Cloud SQL password")
return cloudSQLPassword, nil
}

// GetAzureAccessToken generates Azure access token.
func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure access token for %v.", sessionCtx)
func (a *testAuth) GetAzureAccessToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
WithField("database_name", databaseName).
Info("Generating Azure access token")
return azureAccessToken, nil
}

// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure Redis token for %v.", sessionCtx)
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) {
a.logger.
WithField("database", database).
Info("Generating Azure Redis token")
return azureRedisToken, nil
}

// GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret
// access key and session token.
func (a *testAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *common.Session) (string, string, string, error) {
a.Infof("Generating AWS IAM credentials for %v.", sessionCtx)
func (a *testAuth) GetTLSConfig(ctx context.Context, certTTL time.Duration, database types.Database, databaseUser string, databaseName string) (*tls.Config, error) {
return a.realAuth.GetTLSConfig(ctx, certTTL, database, databaseUser, databaseName)
}

func (a *testAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
return a.realAuth.GetAuthPreference(ctx)
}

func (a *testAuth) GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) {
return a.realAuth.GetAzureIdentityResourceID(ctx, identityName)
}

func (a *testAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating AWS IAM credentials")
return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil
}

func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth {
return &testAuth{
realAuth: a.realAuth,
logger: getUpdatedLogger(a.logger),
}
}

func TestMongoDBAtlas(t *testing.T) {
t.Parallel()

Expand Down
Loading

0 comments on commit 35c5b94

Please sign in to comment.