Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Database Access: decouple Auth from Session #43344

Merged
merged 5 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 61 additions & 33 deletions lib/srv/db/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package db

import (
"context"
"crypto/tls"
"testing"
"time"

Expand Down Expand Up @@ -261,9 +262,9 @@ func TestAuthTokens(t *testing.T) {
type testAuth struct {
// Auth is the wrapped "real" auth that handles everything except for
// cloud auth tokens generation.
common.Auth
realAuth common.Auth
// FieldLogger is used for logging.
logrus.FieldLogger
logger logrus.FieldLogger
}

func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
Expand All @@ -272,11 +273,13 @@ func newTestAuth(ac common.AuthConfig) (*testAuth, error) {
return nil, trace.Wrap(err)
}
return &testAuth{
Auth: auth,
FieldLogger: logrus.WithField(teleport.ComponentKey, "auth:test"),
realAuth: auth,
logger: logrus.WithField(teleport.ComponentKey, "auth:test"),
}, nil
}

var _ common.Auth = (*testAuth)(nil)

const (
// rdsAuthToken is a mock RDS IAM auth token.
rdsAuthToken = "rds-auth-token"
Expand Down Expand Up @@ -314,76 +317,101 @@ type fakeTokenSource struct {
}

func (f *fakeTokenSource) Token() (*oauth2.Token, error) {
f.Infof("Generating Cloud Spanner auth token source")
f.Info("Generating Cloud Spanner auth token source")
return &oauth2.Token{
Expiry: f.exp,
AccessToken: f.token,
}, nil
}

// GetRDSAuthToken generates RDS/Aurora auth token.
func (a *testAuth) GetRDSAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating RDS auth token for %v.", sessionCtx)
func (a *testAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
Info("Generating RDS auth token")
return rdsAuthToken, nil
}

// GetRedshiftAuthToken generates Redshift auth token.
func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
a.Infof("Generating Redshift auth token for %v.", sessionCtx)
func (a *testAuth) GetRedshiftAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
a.logger.
WithField("database", database).
WithField("databaseUser", databaseUser).
WithField("databaseName", databaseName).
Info("Generating Redshift auth token")
return redshiftAuthUser, redshiftAuthToken, nil
}

func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx *common.Session) (string, string, error) {
func (a *testAuth) GetRedshiftServerlessAuthToken(ctx context.Context, database types.Database, databaseUser string, databaseName string) (string, string, error) {
return "", "", trace.NotImplemented("GetRedshiftServerlessAuthToken is not implemented")
}

func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
func (a *testAuth) GetElastiCacheRedisToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
return elastiCacheRedisToken, nil
}

func (a *testAuth) GetMemoryDBToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
func (a *testAuth) GetMemoryDBToken(ctx context.Context, database types.Database, databaseUser string) (string, error) {
return memorydbToken, nil
}

// GetCloudSQLAuthToken generates Cloud SQL auth token.
func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Cloud SQL auth token for %v.", sessionCtx)
func (a *testAuth) GetCloudSQLAuthToken(ctx context.Context, databaseUser string) (string, error) {
a.logger.WithField("database_user", databaseUser).Info("Generating Cloud SQL auth token")
return cloudSQLAuthToken, nil
}

// GetSpannerTokenSource returns an oauth token source for GCP Spanner.
func (a *testAuth) GetSpannerTokenSource(ctx context.Context, sessionCtx *common.Session) (oauth2.TokenSource, error) {
func (a *testAuth) GetSpannerTokenSource(ctx context.Context, databaseUser string) (oauth2.TokenSource, error) {
return &fakeTokenSource{
token: cloudSpannerAuthToken,
FieldLogger: a.WithField("session", sessionCtx),
FieldLogger: a.logger.WithField("database_user", databaseUser),
}, nil
}

// GetCloudSQLPassword generates Cloud SQL user password.
func (a *testAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Cloud SQL user password %v.", sessionCtx)
func (a *testAuth) GetCloudSQLPassword(ctx context.Context, database types.Database, databaseUser string) (string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating Cloud SQL password")
return cloudSQLPassword, nil
}

// GetAzureAccessToken generates Azure access token.
func (a *testAuth) GetAzureAccessToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure access token for %v.", sessionCtx)
func (a *testAuth) GetAzureAccessToken(ctx context.Context) (string, error) {
a.logger.Info("Generating Azure access token")
return azureAccessToken, nil
}

// GetAzureCacheForRedisToken retrieves auth token for Azure Cache for Redis.
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, sessionCtx *common.Session) (string, error) {
a.Infof("Generating Azure Redis token for %v.", sessionCtx)
func (a *testAuth) GetAzureCacheForRedisToken(ctx context.Context, database types.Database) (string, error) {
a.logger.
WithField("database", database).
Info("Generating Azure Redis token")
return azureRedisToken, nil
}

// GetAWSIAMCreds returns the AWS IAM credentials, including access key, secret
// access key and session token.
func (a *testAuth) GetAWSIAMCreds(ctx context.Context, sessionCtx *common.Session) (string, string, string, error) {
a.Infof("Generating AWS IAM credentials for %v.", sessionCtx)
func (a *testAuth) GetTLSConfig(ctx context.Context, expiry time.Time, database types.Database, databaseUser string) (*tls.Config, error) {
return a.realAuth.GetTLSConfig(ctx, expiry, database, databaseUser)
}

func (a *testAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
return a.realAuth.GetAuthPreference(ctx)
}

func (a *testAuth) GetAzureIdentityResourceID(ctx context.Context, identityName string) (string, error) {
return a.realAuth.GetAzureIdentityResourceID(ctx, identityName)
}

func (a *testAuth) GetAWSIAMCreds(ctx context.Context, database types.Database, databaseUser string) (string, string, string, error) {
a.logger.
WithField("database", database).
WithField("database_user", databaseUser).
Info("Generating AWS IAM credentials")
return atlasAuthUser, atlasAuthToken, atlasAuthSessionToken, nil
}

func (a *testAuth) WithLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) common.Auth {
return &testAuth{
realAuth: a.realAuth,
logger: getUpdatedLogger(a.logger),
}
}

func TestMongoDBAtlas(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/cassandra/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ func (e *Engine) authorizeConnection(ctx context.Context) error {
}

func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*protocol.Conn, error) {
config, err := e.Auth.GetTLSConfig(ctx, sessionCtx)
config, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
3 changes: 1 addition & 2 deletions lib/srv/db/clickhouse/engine_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ func (e *Engine) getTransport(ctx context.Context) (*http.Transport, error) {
if err != nil {
return nil, trace.Wrap(err)
}

tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx)
tlsConfig, err := e.Auth.GetTLSConfig(ctx, e.sessionCtx.GetExpiry(), e.sessionCtx.Database, e.sessionCtx.DatabaseUser)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/clickhouse/engine_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (e *Engine) handleNativeConnection(ctx context.Context, sessionCtx *common.
return trace.Wrap(err)
}

tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx)
tlsConfig, err := e.Auth.GetTLSConfig(ctx, sessionCtx.GetExpiry(), sessionCtx.Database, sessionCtx.DatabaseUser)
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading
Loading