From 953e090eef107a8425a64fb3d30cca83920620a2 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Thu, 28 Mar 2024 16:14:00 -0700 Subject: [PATCH] update AWS RDS db e2e tests * update GHA workflow env vars to be RDS specific * add auto db user provisioning tests * provision any db users needed in test code as setup, instead of relying on the cloud-terraform repo provisioning them --- .github/workflows/aws-e2e-tests-non-root.yaml | 5 +- e2e/aws/eks_test.go | 4 +- e2e/aws/fixtures_test.go | 66 +- e2e/aws/main_test.go | 19 +- e2e/aws/rds_test.go | 686 +++++++++++++----- go.mod | 1 + go.sum | 2 + 7 files changed, 582 insertions(+), 201 deletions(-) diff --git a/.github/workflows/aws-e2e-tests-non-root.yaml b/.github/workflows/aws-e2e-tests-non-root.yaml index 0f9cf44fabb5a..0167970c4196f 100644 --- a/.github/workflows/aws-e2e-tests-non-root.yaml +++ b/.github/workflows/aws-e2e-tests-non-root.yaml @@ -14,9 +14,8 @@ env: KUBERNETES_SERVICE_ASSUME_ROLE: arn:aws:iam::307493967395:role/tf-eks-discovery-ci-cluster-kubernetes-service-access-role KUBE_DISCOVERY_SERVICE_ASSUME_ROLE: arn:aws:iam::307493967395:role/tf-eks-discovery-ci-cluster-discovery-service-access-role EKS_CLUSTER_NAME: gha-discovery-ci-eks-us-west-2-307493967395 - DATABASE_USER: teleport-ci-e2e-test - DATABASE_SERVICE_ASSUME_ROLE: arn:aws:iam::307493967395:role/ci-database-e2e-tests-database-svc - DATABASE_DISCOVERY_SERVICE_ASSUME_ROLE: arn:aws:iam::307493967395:role/ci-database-e2e-tests-discovery-svc + RDS_ACCESS_ROLE: arn:aws:iam::307493967395:role/ci-database-e2e-tests-rds-access + RDS_DISCOVERY_ROLE: arn:aws:iam::307493967395:role/ci-database-e2e-tests-rds-discovery RDS_POSTGRES_INSTANCE_NAME: ci-database-e2e-tests-rds-postgres-instance-us-west-2-307493967395 RDS_MYSQL_INSTANCE_NAME: ci-database-e2e-tests-rds-mysql-instance-us-west-2-307493967395 DISCOVERY_MATCHER_LABELS: "*=*" diff --git a/e2e/aws/eks_test.go b/e2e/aws/eks_test.go index 1295cfc474157..9a936894c8b78 100644 --- a/e2e/aws/eks_test.go +++ b/e2e/aws/eks_test.go @@ -128,7 +128,7 @@ func awsEKSDiscoveryMatchedCluster(t *testing.T) { // the permissions are correct. kubeClient, _, err := kube.ProxyClient(kube.ProxyConfig{ T: teleport, - Username: username, + Username: hostUser, KubeUsers: kubeUsers, KubeGroups: kubeGroups, KubeCluster: clusters[0].GetName(), @@ -182,7 +182,7 @@ func awsEKSDiscoveryUnmatchedCluster(t *testing.T) { // clusters. func withFullKubeAccessUserRole(t *testing.T) testOptionsFunc { // Create a new role with full access to all kube clusters. - return withUserRole(t, "kubemaster", types.RoleSpecV6{ + return withUserRole(t, hostUser, "kubemaster", types.RoleSpecV6{ Allow: types.RoleConditions{ KubeGroups: kubeGroups, KubeUsers: kubeUsers, diff --git a/e2e/aws/fixtures_test.go b/e2e/aws/fixtures_test.go index 4e32d828e19f4..c9f2d7abdded7 100644 --- a/e2e/aws/fixtures_test.go +++ b/e2e/aws/fixtures_test.go @@ -35,15 +35,15 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// username is the name of the host user used for tests. -var username string +// hostUser is the name of the host user used for tests. 
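+// It also doubles as the Teleport username for the default, non-auto-provisioned
+// test user.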
+var hostUser string func init() { me, err := user.Current() if err != nil { panic(err) } - username = me.Username + hostUser = me.Username } // mustGetEnv is a test helper that fetches an env variable or fails with an @@ -67,6 +67,13 @@ func mustGetDiscoveryMatcherLabels(t *testing.T) types.Labels { return out } +func mustGetDBAdmin(t *testing.T, db types.Database) types.DatabaseAdminUser { + t.Helper() + adminUser := db.GetAdminUser() + require.NotEmpty(t, adminUser.Name, "unknown db auto user provisioning admin, should have been imported using the %q label", types.DatabaseAdminLabel) + return adminUser +} + // testOptionsFunc is a test option configuration func. type testOptionsFunc func(*testOptions) @@ -78,15 +85,16 @@ type testOptions struct { // serviceConfigFuncs are a list of functions that configure the Teleport // cluster before it starts. serviceConfigFuncs []func(*servicecfg.Config) - // userRoles are roles that will be bootstrapped and added to the Teleport - // user under test. - userRoles []types.Role + // userRoles is a map from username to that user's roles that will be + // bootstrapped and added to the Teleport test cluster. + userRoles map[string][]types.Role } // createTeleportCluster sets up a Teleport cluster for tests. func createTeleportCluster(t *testing.T, opts ...testOptionsFunc) *helpers.TeleInstance { t.Helper() var options testOptions + options.userRoles = make(map[string][]types.Role) for _, opt := range opts { opt(&options) } @@ -98,7 +106,9 @@ func createTeleportCluster(t *testing.T, opts ...testOptionsFunc) *helpers.TeleI teleport := helpers.NewInstance(t, cfg) // Create a new user with the role created above. - teleport.AddUserWithRole(username, options.userRoles...) + for name, roles := range options.userRoles { + teleport.AddUserWithRole(name, roles...) + } tconf := newTeleportConfig(t) for _, optFn := range options.serviceConfigFuncs { @@ -149,13 +159,13 @@ func newTeleportConfig(t *testing.T) *servicecfg.Config { // withUserRole creates a new role that will be bootstraped and then granted to // the Teleport user under test. -func withUserRole(t *testing.T, name string, spec types.RoleSpecV6) testOptionsFunc { +func withUserRole(t *testing.T, user, name string, spec types.RoleSpecV6) testOptionsFunc { t.Helper() // Create a new role with full access to all databases. role, err := types.NewRole(name, spec) require.NoError(t, err) return func(options *testOptions) { - options.userRoles = append(options.userRoles, role) + options.userRoles[user] = append(options.userRoles[user], role) } } @@ -208,11 +218,35 @@ func withDatabaseService(t *testing.T, matchers ...services.ResourceMatcher) tes func withFullDatabaseAccessUserRole(t *testing.T) testOptionsFunc { t.Helper() // Create a new role with full access to all databases. 
- return withUserRole(t, "db-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, - DatabaseUsers: []string{types.Wildcard}, - DatabaseNames: []string{types.Wildcard}, - }, - }) + return withUserRole(t, hostUser, "db-access", allowDatabaseAccessRoleSpec) +} + +var allowDatabaseAccessRoleSpec = types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + DatabaseUsers: []string{types.Wildcard}, + DatabaseNames: []string{types.Wildcard}, + }, +} + +var autoDBUserKeepSpec = types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + DatabaseUsers: []string{types.Wildcard}, + DatabaseNames: []string{types.Wildcard}, + }, + Options: types.RoleOptions{ + CreateDatabaseUserMode: types.CreateDatabaseUserMode_DB_USER_MODE_KEEP, + }, +} + +var autoDBUserDropSpec = types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + DatabaseUsers: []string{types.Wildcard}, + DatabaseNames: []string{types.Wildcard}, + }, + Options: types.RoleOptions{ + CreateDatabaseUserMode: types.CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP, + }, } diff --git a/e2e/aws/main_test.go b/e2e/aws/main_test.go index e29e806886043..03df021073692 100644 --- a/e2e/aws/main_test.go +++ b/e2e/aws/main_test.go @@ -32,18 +32,15 @@ const ( // discoveryMatcherLabelsEnv is the env variable that specifies the matcher // labels to use in test discovery services. discoveryMatcherLabelsEnv = "DISCOVERY_MATCHER_LABELS" - // dbSvcRoleARNEnv is the environment variable that specifies the IAM role - // that Teleport Database Service will assume to access databases. - // check modules/databases-ci/ from cloud-terraform repo for more details. - dbSvcRoleARNEnv = "DATABASE_SERVICE_ASSUME_ROLE" - // dbDiscoverySvcRoleARNEnv is the environment variable that specifies the + // rdsAccessRoleEnv is the environment variable that specifies the IAM role + // that Teleport Database Service will assume to access RDS databases. + // See modules/databases-ci/ from cloud-terraform repo for more details. + rdsAccessRoleEnv = "RDS_ACCESS_ROLE" + // rdsDiscoveryRoleEnv is the environment variable that specifies the // IAM role that Teleport Discovery Service will assume to discover - // databases. - // check modules/databases-ci/ from cloud-terraform repo for more details. - dbDiscoverySvcRoleARNEnv = "DATABASE_DISCOVERY_SERVICE_ASSUME_ROLE" - // dbUserEnv is the database user configured in databases for access via - // Teleport. - dbUserEnv = "DATABASE_USER" + // RDS databases. + // See modules/databases-ci/ from cloud-terraform repo for more details. + rdsDiscoveryRoleEnv = "RDS_DISCOVERY_ROLE" // rdsPostgresInstanceNameEnv is the environment variable that specifies the // name of the RDS Postgres DB instance that will be created by the Teleport // Discovery Service. 
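Taken together, the role-spec fixtures and the renamed RDS env-var constants above are what the new RDS test wires into a cluster. As a condensed sketch of that wiring (simplified from the testRDS hunk below; exampleRDSTestSetup is an illustrative name, not code added by this patch):

// Sketch only: how the RDS env vars and the auto-user role specs feed test setup.
func exampleRDSTestSetup(t *testing.T) {
	accessRole := mustGetEnv(t, rdsAccessRoleEnv)       // RDS_ACCESS_ROLE
	discoveryRole := mustGetEnv(t, rdsDiscoveryRoleEnv) // RDS_DISCOVERY_ROLE
	// random suffix avoids collisions between parallel test runs
	autoUserKeep := "auto_keep_" + randASCII(t, 6)
	cluster := makeDBTestCluster(t, accessRole, discoveryRole, types.AWSMatcherRDS,
		withUserRole(t, autoUserKeep, "db-auto-user-keeper", autoDBUserKeepSpec),
	)
	waitForDatabases(t, cluster.Process, mustGetEnv(t, rdsPostgresInstanceNameEnv))
}
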
diff --git a/e2e/aws/rds_test.go b/e2e/aws/rds_test.go index 952aca1954b91..3c1c47c0f416b 100644 --- a/e2e/aws/rds_test.go +++ b/e2e/aws/rds_test.go @@ -20,16 +20,24 @@ package e2e import ( "context" + "crypto/rand" "crypto/tls" + "encoding/json" "fmt" "net" "os" "strconv" + "sync" "testing" "time" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/rds" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" mysqlclient "github.com/go-mysql-org/go-mysql/client" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -53,38 +61,43 @@ func TestDatabases(t *testing.T) { t.Parallel() testEnabled := os.Getenv(teleport.AWSRunDBTests) if ok, _ := strconv.ParseBool(testEnabled); !ok { - t.Skip("Skipping Databases test suite.") + t.Skip("Skipping AWS Databases test suite.") } + // when adding a new type of AWS db e2e test, you should add to this + // unmatched discovery test and add a test for matched discovery/connection + // as well below. t.Run("unmatched discovery", awsDBDiscoveryUnmatched) - t.Run("matched discovery", awsDBDiscoveryMatched) + t.Run("rds", testRDS) } func awsDBDiscoveryUnmatched(t *testing.T) { t.Parallel() // get test settings awsRegion := mustGetEnv(t, awsRegionEnv) - dbDiscoverySvcRoleARN := mustGetEnv(t, dbDiscoverySvcRoleARNEnv) - dbSvcRoleARN := mustGetEnv(t, dbSvcRoleARNEnv) - cluster := createTeleportCluster(t, - withSingleProxyPort(t), - withDiscoveryService(t, "db-e2e-test", types.AWSMatcher{ - Types: []string{types.AWSMatcherRDS}, + + // setup discovery matchers + var matchers []types.AWSMatcher + for matcherType, assumeRoleARN := range map[string]string{ + // add a new matcher/role here to test that discovery properly + // does *not* that kind of database for some unmatched tag. + types.AWSMatcherRDS: mustGetEnv(t, rdsDiscoveryRoleEnv), + } { + matchers = append(matchers, types.AWSMatcher{ + Types: []string{matcherType}, Tags: types.Labels{ // This label should not match. "env": {"tag_not_found"}, }, Regions: []string{awsRegion}, AssumeRole: &types.AssumeRole{ - RoleARN: dbDiscoverySvcRoleARN, + RoleARN: assumeRoleARN, }, - }), - withDatabaseService(t, services.ResourceMatcher{ - Labels: types.Labels{types.Wildcard: {types.Wildcard}}, - AWS: services.ResourceMatcherAWS{ - AssumeRoleARN: dbSvcRoleARN, - }, - }), - withFullDatabaseAccessUserRole(t), + }) + } + + cluster := createTeleportCluster(t, + withSingleProxyPort(t), + withDiscoveryService(t, "db-e2e-test", matchers...), ) // Get the auth server. @@ -100,186 +113,262 @@ func awsDBDiscoveryUnmatched(t *testing.T) { }, 2*time.Minute, 10*time.Second, "discovery service incorrectly created a database") } -func awsDBDiscoveryMatched(t *testing.T) { - t.Parallel() - // get test settings - awsRegion := mustGetEnv(t, awsRegionEnv) - dbDiscoverySvcRoleARN := mustGetEnv(t, dbDiscoverySvcRoleARNEnv) - dbSvcRoleARN := mustGetEnv(t, dbSvcRoleARNEnv) - dbUser := mustGetEnv(t, dbUserEnv) - rdsPostgresInstanceName := mustGetEnv(t, rdsPostgresInstanceNameEnv) - rdsMySQLInstanceName := mustGetEnv(t, rdsMySQLInstanceNameEnv) - - cluster := createTeleportCluster(t, +// makeDBTestCluster is a test helper to set up a typical test cluster for +// database e2e tests. 
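+// It wires up a discovery service (assuming discoveryRole for the given
+// matcher type), a database service (assuming accessRole), and a full
+// database-access role for the default test user; any extra opts are
+// appended on top of these defaults.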
+func makeDBTestCluster(t *testing.T, accessRole, discoveryRole, discoveryMatcherType string, opts ...testOptionsFunc) *helpers.TeleInstance { + t.Helper() + opts = append([]testOptionsFunc{ withSingleProxyPort(t), withDiscoveryService(t, "db-e2e-test", types.AWSMatcher{ - Types: []string{types.AWSMatcherRDS}, + Types: []string{discoveryMatcherType}, Tags: mustGetDiscoveryMatcherLabels(t), - Regions: []string{awsRegion}, + Regions: []string{mustGetEnv(t, awsRegionEnv)}, AssumeRole: &types.AssumeRole{ - RoleARN: dbDiscoverySvcRoleARN, + RoleARN: discoveryRole, }, }), withDatabaseService(t, services.ResourceMatcher{ Labels: types.Labels{types.Wildcard: {types.Wildcard}}, AWS: services.ResourceMatcherAWS{ - AssumeRoleARN: dbSvcRoleARN, + AssumeRoleARN: accessRole, }, }), withFullDatabaseAccessUserRole(t), - ) + }, opts...) + return createTeleportCluster(t, opts...) +} - wantDBNames := []string{ - rdsPostgresInstanceName, - rdsMySQLInstanceName, - } - // wait for the databases to be discovered - waitForDatabases(t, cluster.Process, wantDBNames...) - // wait for the database heartbeats from database service - waitForDatabaseServers(t, cluster.Process, wantDBNames...) - - rdsPostgresInstance := tlsca.RouteToDatabase{ - ServiceName: rdsPostgresInstanceName, - Protocol: defaults.ProtocolPostgres, - Username: dbUser, - Database: "postgres", - } - rdsMySQLInstance := tlsca.RouteToDatabase{ - ServiceName: rdsMySQLInstanceName, - Protocol: defaults.ProtocolMySQL, - Username: dbUser, - Database: "", // not needed +// testRDS tests AWS RDS database discovery and connections. +// Since RDS has many different db engines available, this test groups all +// the engines together into subtests: postgres, mysql, etc. +func testRDS(t *testing.T) { + t.Parallel() + // give everything 2 minutes to finish. Realistically it takes ~10-20 + // seconds, but let's be generous to maybe avoid flakey failures. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + t.Cleanup(cancel) + + // use random names so we can test auto provisioning these users via + // Teleport, without tests colliding with eachother across parallel test + // runs. + autoUserKeep := "auto_keep_" + randASCII(t, 6) + autoUserDrop := "auto_drop_" + randASCII(t, 6) + + accessRole := mustGetEnv(t, rdsAccessRoleEnv) + discoveryRole := mustGetEnv(t, rdsDiscoveryRoleEnv) + opts := []testOptionsFunc{ + withUserRole(t, autoUserKeep, "db-auto-user-keeper", autoDBUserKeepSpec), + withUserRole(t, autoUserDrop, "db-auto-user-dropper", autoDBUserDropSpec), } - t.Run("connection", func(t *testing.T) { - tests := []struct { - name string - route tlsca.RouteToDatabase - testDBConnection dbConnectionTestFunc - }{ - { - name: "RDS postgres instance", - route: rdsPostgresInstance, - testDBConnection: postgresConnTestFn(cluster), - }, - { - name: "RDS postgres instance via local proxy", - route: rdsPostgresInstance, - testDBConnection: postgresLocalProxyConnTestFn(cluster), - }, - { - name: "RDS MySQL instance via local proxy", - route: rdsMySQLInstance, - testDBConnection: mySQLLocalProxyConnTestFn(cluster), - }, - } - for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - test.testDBConnection(t, ctx, test.route) - }) - } + cluster := makeDBTestCluster(t, accessRole, discoveryRole, types.AWSMatcherRDS, opts...) 
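+
+	// Each RDS engine below follows the same pattern: wait for the database to
+	// be discovered, look up the imported database and its admin user, connect
+	// directly as the RDS master user to provision that admin, then run the
+	// connection tests and verify the auto-user keep/drop behavior.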
+ + t.Run("postgres", func(t *testing.T) { + t.Parallel() + + // wait for the database to be discovered + pgDBName := mustGetEnv(t, rdsPostgresInstanceNameEnv) + waitForDatabases(t, cluster.Process, pgDBName) + db, err := cluster.Process.GetAuthServer().GetDatabase(ctx, pgDBName) + require.NoError(t, err) + adminUser := mustGetDBAdmin(t, db) + + conn := connectAsRDSPostgresAdmin(t, ctx, db.GetAWS().RDS.InstanceID) + provisionRDSPostgresAutoUsersAdmin(t, ctx, conn, adminUser.Name) + t.Cleanup(func() { + // best effort cleanup all the users created for the tests, + // including the auto drop user in case Teleport fails to do so. + _, _ = conn.Exec(ctx, fmt.Sprintf("DROP USER %q", autoUserKeep)) + _, _ = conn.Exec(ctx, fmt.Sprintf("DROP USER %q", autoUserDrop)) + }) + + t.Run("connect", func(t *testing.T) { + for name, test := range map[string]struct { + user string + dbUser string + }{ + "existing user": {user: hostUser, dbUser: adminUser.Name}, + "auto user keep": {user: autoUserKeep, dbUser: autoUserKeep}, + "auto user drop": {user: autoUserDrop, dbUser: autoUserDrop}, + } { + test := test + route := tlsca.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: defaults.ProtocolPostgres, + Username: test.dbUser, + Database: "postgres", + } + t.Run(name+"/via proxy", func(t *testing.T) { + t.Parallel() + postgresConnTest(t, cluster, test.user, route) + }) + t.Run(name+"/via local proxy", func(t *testing.T) { + t.Parallel() + postgresLocalProxyConnTest(t, cluster, test.user, route) + }) + } + }) + + t.Run("connect/auto user keep", func(t *testing.T) { + t.Parallel() + waitForPostgresAutoUserDeactivate(t, ctx, conn, autoUserKeep) + }) + t.Run("connect/auto user drop", func(t *testing.T) { + t.Parallel() + waitForPostgresAutoUserDrop(t, ctx, conn, autoUserDrop) + }) }) -} -type dbConnectionTestFunc func(*testing.T, context.Context, tlsca.RouteToDatabase) + t.Run("mysql", func(t *testing.T) { + t.Parallel() + + // wait for the database to be discovered + myDBName := mustGetEnv(t, rdsMySQLInstanceNameEnv) + waitForDatabases(t, cluster.Process, myDBName) + db, err := cluster.Process.GetAuthServer().GetDatabase(ctx, myDBName) + require.NoError(t, err) + adminUser := mustGetDBAdmin(t, db) + + conn := connectAsRDSMySQLAdmin(t, ctx, db.GetAWS().RDS.InstanceID) + provisionRDSMySQLAutoUsersAdmin(t, ctx, conn, adminUser.Name) + t.Cleanup(func() { + // best effort cleanup all the users created for the tests, + // including the auto drop user in case Teleport fails to do so. 
+ _, _ = conn.Execute(fmt.Sprintf("DROP USER %q", autoUserKeep)) + _, _ = conn.Execute(fmt.Sprintf("DROP USER %q", autoUserDrop)) + }) + + t.Run("connect", func(t *testing.T) { + for name, test := range map[string]struct { + user string + dbUser string + }{ + "existing user": {user: hostUser, dbUser: adminUser.Name}, + "auto user keep": {user: autoUserKeep, dbUser: autoUserKeep}, + "auto user drop": {user: autoUserDrop, dbUser: autoUserDrop}, + } { + test := test + route := tlsca.RouteToDatabase{ + ServiceName: myDBName, + Protocol: defaults.ProtocolMySQL, + Username: test.dbUser, + Database: "", // not needed + } + t.Run(name+"/via local proxy", func(t *testing.T) { + t.Parallel() + mysqlLocalProxyConnTest(t, cluster, test.user, route) + }) + } + }) + + t.Run("connect/auto user keep", func(t *testing.T) { + t.Parallel() + waitForMySQLAutoUserDeactivate(t, conn, autoUserKeep) + }) + t.Run("connect/auto user drop", func(t *testing.T) { + t.Parallel() + waitForMySQLAutoUserDrop(t, conn, autoUserDrop) + }) + }) +} // postgresConnTestFn tests connection to a postgres database via proxy web // multiplexer. -func postgresConnTestFn(cluster *helpers.TeleInstance) dbConnectionTestFunc { - return func(t *testing.T, ctx context.Context, route tlsca.RouteToDatabase) { - var pgConn *pgconn.PgConn - // retry for a while, the database service might need time to give - // itself IAM rds:connect permissions. - require.EventuallyWithT(t, func(t *assert.CollectT) { - var err error - pgConn, err = postgres.MakeTestClient(ctx, common.TestClientConfig{ - AuthClient: cluster.GetSiteAPI(cluster.Secrets.SiteName), - AuthServer: cluster.Process.GetAuthServer(), - Address: cluster.Web, - Cluster: cluster.Secrets.SiteName, - Username: username, - RouteToDatabase: route, - }) - assert.NoError(t, err) - assert.NotNil(t, pgConn) - }, time.Second*10, time.Second, "connecting to postgres") - - // Execute a query. - results, err := pgConn.Exec(ctx, "select 1").ReadAll() - require.NoError(t, err) - for i, r := range results { - require.NoError(t, r.Err, "error in result %v", i) - } +func postgresConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + var pgConn *pgconn.PgConn + // retry for a while, the database service might need time to give + // itself IAM rds:connect permissions. + require.EventuallyWithT(t, func(t *assert.CollectT) { + var err error + pgConn, err = postgres.MakeTestClient(ctx, common.TestClientConfig{ + AuthClient: cluster.GetSiteAPI(cluster.Secrets.SiteName), + AuthServer: cluster.Process.GetAuthServer(), + Address: cluster.Web, + Cluster: cluster.Secrets.SiteName, + Username: user, + RouteToDatabase: route, + }) + assert.NoError(t, err) + assert.NotNil(t, pgConn) + }, time.Second*10, time.Second, "connecting to postgres") - // Disconnect. - err = pgConn.Close(ctx) - require.NoError(t, err) + // Execute a query. + results, err := pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) + for i, r := range results { + require.NoError(t, r.Err, "error in result %v", i) } + + // Disconnect. + err = pgConn.Close(ctx) + require.NoError(t, err) } -// postgresLocalProxyConnTestFn tests connection to a postgres database via +// postgresLocalProxyConnTest tests connection to a postgres database via // local proxy tunnel. 
-func postgresLocalProxyConnTestFn(cluster *helpers.TeleInstance) dbConnectionTestFunc { - return func(t *testing.T, ctx context.Context, route tlsca.RouteToDatabase) { - lp := startLocalALPNProxy(t, ctx, cluster, route) - defer lp.Close() - - connString := fmt.Sprintf("postgres://%s@%v/%s", - route.Username, lp.GetAddr(), route.Database) - var pgConn *pgconn.PgConn - // retry for a while, the database service might need time to give - // itself IAM rds:connect permissions. - require.EventuallyWithT(t, func(t *assert.CollectT) { - var err error - pgConn, err = pgconn.Connect(ctx, connString) - assert.NoError(t, err) - assert.NotNil(t, pgConn) - }, time.Second*10, time.Second, "connecting to postgres") - - // Execute a query. - results, err := pgConn.Exec(ctx, "select 1").ReadAll() - require.NoError(t, err) - for i, r := range results { - require.NoError(t, r.Err, "error in result %v", i) - } +func postgresLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + lp := startLocalALPNProxy(t, ctx, user, cluster, route) + defer lp.Close() + + connString := fmt.Sprintf("postgres://%s@%v/%s", + route.Username, lp.GetAddr(), route.Database) + var pgConn *pgconn.PgConn + // retry for a while, the database service might need time to give + // itself IAM rds:connect permissions. + require.EventuallyWithT(t, func(t *assert.CollectT) { + var err error + pgConn, err = pgconn.Connect(ctx, connString) + assert.NoError(t, err) + assert.NotNil(t, pgConn) + }, time.Second*10, time.Second, "connecting to postgres") - // Disconnect. - err = pgConn.Close(ctx) - require.NoError(t, err) + // Execute a query. + results, err := pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) + for i, r := range results { + require.NoError(t, r.Err, "error in result %v", i) } + + // Disconnect. + err = pgConn.Close(ctx) + require.NoError(t, err) } -// mySQLLocalProxyConnTestFn tests connection to a MySQL database via +// mysqlLocalProxyConnTest tests connection to a MySQL database via // local proxy tunnel. -func mySQLLocalProxyConnTestFn(cluster *helpers.TeleInstance) dbConnectionTestFunc { - return func(t *testing.T, ctx context.Context, route tlsca.RouteToDatabase) { - lp := startLocalALPNProxy(t, ctx, cluster, route) - defer lp.Close() - - var conn *mysqlclient.Conn - // retry for a while, the database service might need time to give - // itself IAM rds:connect permissions. - require.EventuallyWithT(t, func(t *assert.CollectT) { - var err error - conn, err = mysqlclient.Connect(lp.GetAddr(), route.Username, "" /*no password*/, route.Database) - assert.NoError(t, err) - assert.NotNil(t, conn) - }, time.Second*10, time.Second, "connecting to mysql") - - // Execute a query. - _, err := conn.Execute("select 1") - require.NoError(t, err) +func mysqlLocalProxyConnTest(t *testing.T, cluster *helpers.TeleInstance, user string, route tlsca.RouteToDatabase) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) - // Disconnect. - require.NoError(t, conn.Close()) - } + lp := startLocalALPNProxy(t, ctx, user, cluster, route) + defer lp.Close() + + var conn *mysqlclient.Conn + // retry for a while, the database service might need time to give + // itself IAM rds:connect permissions. 
+ require.EventuallyWithT(t, func(t *assert.CollectT) { + var err error + conn, err = mysqlclient.Connect(lp.GetAddr(), route.Username, "" /*no password*/, route.Database) + assert.NoError(t, err) + assert.NotNil(t, conn) + }, time.Second*10, time.Second, "connecting to mysql") + + // Execute a query. + _, err := conn.Execute("select 1") + require.NoError(t, err) + + // Disconnect. + require.NoError(t, conn.Close()) } // startLocalALPNProxy starts local ALPN proxy for the specified database. -func startLocalALPNProxy(t *testing.T, ctx context.Context, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy { +func startLocalALPNProxy(t *testing.T, ctx context.Context, user string, cluster *helpers.TeleInstance, route tlsca.RouteToDatabase) *alpnproxy.LocalProxy { t.Helper() proto, err := alpncommon.ToALPNProtocol(route.Protocol) require.NoError(t, err) @@ -291,7 +380,7 @@ func startLocalALPNProxy(t *testing.T, ctx context.Context, cluster *helpers.Tel require.NoError(t, err) authSrv := cluster.Process.GetAuthServer() - tlsCert := generateClientDBCert(t, authSrv, username, route) + tlsCert := generateClientDBCert(t, authSrv, user, route) proxy, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{ RemoteProxyAddr: proxyNetAddr.String(), @@ -332,6 +421,7 @@ func generateClientDBCert(t *testing.T, authSrv *auth.Server, user string, route } func waitForDatabases(t *testing.T, auth *service.TeleportProcess, wantNames ...string) { + t.Helper() require.EventuallyWithT(t, func(t *assert.CollectT) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -347,10 +437,7 @@ func waitForDatabases(t *testing.T, auth *service.TeleportProcess, wantNames ... for _, name := range wantNames { assert.Contains(t, seen, name) } - }, 3*time.Minute, 3*time.Second, "waiting for the discovery service to create databases") -} - -func waitForDatabaseServers(t *testing.T, auth *service.TeleportProcess, wantNames ...string) { + }, 3*time.Minute, 3*time.Second, "waiting for the discovery service to create db resources") require.EventuallyWithT(t, func(t *assert.CollectT) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -368,3 +455,264 @@ func waitForDatabaseServers(t *testing.T, auth *service.TeleportProcess, wantNam } }, 1*time.Minute, time.Second, "waiting for the database service to heartbeat the databases") } + +// rdsAdminInfo contains common info needed to connect as an RDS admin user via +// password auth. +type rdsAdminInfo struct { + endpoint, + username, + password string +} + +func connectAsRDSPostgresAdmin(t *testing.T, ctx context.Context, instanceID string) *pgx.Conn { + t.Helper() + info := getRDSAdminInfo(t, ctx, instanceID) + pgCfg, err := pgx.ParseConfig(fmt.Sprintf("postgres://%s/?sslmode=require", info.endpoint)) + require.NoError(t, err) + pgCfg.User = info.username + pgCfg.Password = info.password + pgCfg.Database = "postgres" + + conn, err := pgx.ConnectConfig(ctx, pgCfg) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close(ctx) + }) + return conn +} + +// mySQLConn wraps a go-mysql conn to provide a client that's thread safe. +type mySQLConn struct { + mu sync.Mutex + conn *mysqlclient.Conn +} + +func (c *mySQLConn) Execute(command string, args ...interface{}) (*mysql.Result, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.Execute(command, args...) 
+} + +func connectAsRDSMySQLAdmin(t *testing.T, ctx context.Context, instanceID string) *mySQLConn { + t.Helper() + const dbName = "mysql" + info := getRDSAdminInfo(t, ctx, instanceID) + conn, err := mysqlclient.Connect(info.endpoint, info.username, info.password, dbName) + require.NoError(t, err) + t.Cleanup(func() { + _ = conn.Close() + }) + return &mySQLConn{conn: conn} +} + +func getRDSAdminInfo(t *testing.T, ctx context.Context, instanceID string) rdsAdminInfo { + t.Helper() + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(mustGetEnv(t, awsRegionEnv)), + ) + require.NoError(t, err) + + rdsClt := rds.NewFromConfig(cfg) + result, err := rdsClt.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{ + DBInstanceIdentifier: &instanceID, + }) + require.NoError(t, err) + require.Len(t, result.DBInstances, 1) + dbInstance := result.DBInstances[0] + return rdsAdminInfo{ + endpoint: fmt.Sprintf("%s:%d", *dbInstance.Endpoint.Address, *dbInstance.Endpoint.Port), + username: *dbInstance.MasterUsername, + password: getRDSMasterUserPassword(t, ctx, *dbInstance.MasterUserSecret.SecretArn), + } +} + +func getRDSMasterUserPassword(t *testing.T, ctx context.Context, secretID string) string { + t.Helper() + secretVal := getSecretValue(t, ctx, secretID) + type rdsMasterSecret struct { + User string `json:"username"` + Pass string `json:"password"` + } + var secret rdsMasterSecret + if err := json.Unmarshal([]byte(*secretVal.SecretString), &secret); err != nil { + // being paranoid. I don't want to leak the secret string in test error + // logs. + require.FailNow(t, "error unmarshaling secret string") + } + return secret.Pass +} + +func getSecretValue(t *testing.T, ctx context.Context, secretID string) *secretsmanager.GetSecretValueOutput { + t.Helper() + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(mustGetEnv(t, awsRegionEnv)), + ) + require.NoError(t, err) + + secretsClt := secretsmanager.NewFromConfig(cfg) + secretVal, err := secretsClt.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{ + SecretId: &secretID, + }) + require.NoError(t, err) + return secretVal +} + +// provisionRDSPostgresAutoUsersAdmin provisions an admin user suitable for auto-user +// provisioning. +func provisionRDSPostgresAutoUsersAdmin(t *testing.T, ctx context.Context, conn *pgx.Conn, adminUser string) { + t.Helper() + // Create the admin user and grant rds_iam so Teleport can auth + // with IAM as an existing user. + // Also needed so the auto-user admin can auto-provision others. + // If the admin already exists, ignore errors - there's only + // one admin because the admin has to own all the functions + // we provision and creating a different admin for each test + // is not necessary. + // Don't cleanup the db admin after, because test runs would interfere + // with each other. + _, _ = conn.Exec(ctx, fmt.Sprintf("CREATE USER %q WITH login createrole", adminUser)) + _, err := conn.Exec(ctx, fmt.Sprintf("GRANT rds_iam TO %q WITH ADMIN OPTION", adminUser)) + if err != nil { + require.ErrorContains(t, err, "already a member") + } +} + +// provisionRDSMySQLAutoUsersAdmin provisions an admin user suitable for auto-user +// provisioning. +func provisionRDSMySQLAutoUsersAdmin(t *testing.T, ctx context.Context, conn *mySQLConn, adminUser string) { + t.Helper() + // provision the IAM user to test with. + // ignore errors from user creation. If the user doesn't exist + // later steps will catch it. 
The error we might get back when + // another test runner already created the admin is + // unpredictable: all we need to know is the user exists for + // test setup. + // Don't cleanup the db admin after, because test runs would interfere + // with each other. + _, _ = conn.Execute(fmt.Sprintf("CREATE USER IF NOT EXISTS %q IDENTIFIED WITH AWSAuthenticationPlugin AS 'RDS'", adminUser)) + + // these statements are all idempotent - they should not return + // an error even if run in parallel by many test runners. + _, err := conn.Execute(fmt.Sprintf("GRANT SELECT ON mysql.role_edges TO %q", adminUser)) + require.NoError(t, err) + _, err = conn.Execute(fmt.Sprintf("GRANT PROCESS, ROLE_ADMIN, CREATE USER ON *.* TO %q", adminUser)) + require.NoError(t, err) + _, err = conn.Execute("CREATE DATABASE IF NOT EXISTS `teleport`") + require.NoError(t, err) + _, err = conn.Execute(fmt.Sprintf("GRANT ALTER ROUTINE, CREATE ROUTINE, EXECUTE ON `teleport`.* TO %q", adminUser)) + require.NoError(t, err) +} + +// randASCII is a helper func that returns a random string of ascii characters. +func randASCII(t *testing.T, length int) string { + t.Helper() + const charset = "abcdefghijklmnopqrstuvwxyz" + b := make([]byte, length) + + _, err := rand.Read(b) + require.NoError(t, err) + + for i := 0; i < length; i++ { + b[i] = charset[int(b[i])%len(charset)] + } + return string(b) +} + +const ( + autoUserWaitDur = 20 * time.Second + autoUserWaitStep = 2 * time.Second +) + +func waitForPostgresAutoUserDeactivate(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) { + t.Helper() + require.EventuallyWithT(t, func(c *assert.CollectT) { + rows, err := conn.Query(ctx, "SELECT 1 FROM pg_roles WHERE rolname=$1", user) + if !assert.NoError(c, err) { + return + } + if !assert.True(c, rows.Next(), "user %q should not have been dropped after disconnecting", user) { + rows.Close() + return + } + rows.Close() + + rows, err = conn.Query(ctx, "SELECT 1 FROM pg_roles WHERE rolname = $1 AND rolcanlogin = false", user) + if !assert.NoError(c, err) { + return + } + if !assert.True(c, rows.Next(), "user %q should not be able to login after deactivating", user) { + rows.Close() + return + } + rows.Close() + + rows, err = conn.Query(ctx, "SELECT 1 FROM pg_roles AS a WHERE pg_has_role($1, a.oid, 'member') AND a.rolname NOT IN ($1, 'teleport-auto-user')", user) + if !assert.NoError(c, err) { + return + } + if !assert.False(c, rows.Next(), "user %q should have lost all additional roles after deactivating", user) { + rows.Close() + return + } + rows.Close() + }, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user) +} + +func waitForPostgresAutoUserDrop(t *testing.T, ctx context.Context, conn *pgx.Conn, user string) { + t.Helper() + require.EventuallyWithT(t, func(c *assert.CollectT) { + rows, err := conn.Query(ctx, "SELECT 1 FROM pg_roles WHERE rolname=$1", user) + if !assert.NoError(c, err) { + return + } + assert.False(c, rows.Next(), "user %q should have been dropped automatically after disconnecting", user) + rows.Close() + }, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be dropped", user) +} + +func waitForMySQLAutoUserDeactivate(t *testing.T, conn *mySQLConn, user string) { + t.Helper() + require.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := conn.Execute("SELECT 1 FROM mysql.user AS u WHERE u.user = ?", user) + if !assert.NoError(c, err) { + return + } + if !assert.Equal(c, 1, result.RowNumber(), "user %q should not have been dropped after disconnecting", 
user) { + result.Close() + return + } + result.Close() + + result, err = conn.Execute("SELECT 1 FROM mysql.user AS u WHERE u.user = ? AND u.account_locked = 'Y'", user) + if !assert.NoError(c, err) { + return + } + if !assert.Equal(c, 1, result.RowNumber(), "user %q should not be able to login after deactivating", user) { + result.Close() + return + } + result.Close() + + result, err = conn.Execute("SELECT 1 FROM mysql.role_edges AS u WHERE u.to_user = ? AND u.from_user != 'teleport-auto-user'", user) + if !assert.NoError(c, err) { + return + } + if !assert.Equal(c, 0, result.RowNumber(), "user %q should have lost all additional roles after deactivating", user) { + result.Close() + return + } + result.Close() + }, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be deactivated", user) +} + +func waitForMySQLAutoUserDrop(t *testing.T, conn *mySQLConn, user string) { + t.Helper() + require.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := conn.Execute("SELECT 1 FROM mysql.user AS u WHERE u.user = ?", user) + if !assert.NoError(c, err) { + return + } + assert.Equal(c, 0, result.RowNumber(), "user %q should have been dropped automatically after disconnecting", user) + result.Close() + }, autoUserWaitDur, autoUserWaitStep, "waiting for auto user %q to be dropped", user) +} diff --git a/go.mod b/go.mod index 746ed2e410e74..1ac2f06aa5389 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/iam v1.31.3 github.com/aws/aws-sdk-go-v2/service/rds v1.76.0 github.com/aws/aws-sdk-go-v2/service/s3 v1.53.0 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.5 github.com/aws/aws-sdk-go-v2/service/sns v1.29.3 github.com/aws/aws-sdk-go-v2/service/sqs v1.31.3 github.com/aws/aws-sdk-go-v2/service/sts v1.28.5 diff --git a/go.sum b/go.sum index 275ea088893d3..6853ebb378d11 100644 --- a/go.sum +++ b/go.sum @@ -832,6 +832,8 @@ github.com/aws/aws-sdk-go-v2/service/rds v1.76.0 h1:cQUdm2sU/71O1vCCV627GrQz5b9R github.com/aws/aws-sdk-go-v2/service/rds v1.76.0/go.mod h1:TsRoxafRyxgt1c1JWQXmxj/dCEwOkBapTwskET8vgFo= github.com/aws/aws-sdk-go-v2/service/s3 v1.53.0 h1:r3o2YsgW9zRcIP3Q0WCmttFVhTuugeKIvT5z9xDspc0= github.com/aws/aws-sdk-go-v2/service/s3 v1.53.0/go.mod h1:w2E4f8PUfNtyjfL6Iu+mWI96FGttE03z3UdNcUEC4tA= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.5 h1:1i3Pq5g1NaXI/u8lTHRVMHyCc0HoZzSk2EFmiy14Hbk= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.5/go.mod h1:slgOMs1CQu8UVgwoFqEvCi71L4HVoZgM0r8MtcNP6Mc= github.com/aws/aws-sdk-go-v2/service/sns v1.29.3 h1:R2MIMza/lZex1wIawXmo6S+suwFv/JcxOFSJPpsSVBY= github.com/aws/aws-sdk-go-v2/service/sns v1.29.3/go.mod h1:tr9l7BHYU/SvlJAL9CH56XZNcOBb/d24j3RrXkzzaTA= github.com/aws/aws-sdk-go-v2/service/sqs v1.31.3 h1:AOQ5bXiVWqoEAv8Ag7zgJoDVhOz3lUrZyk1/M45/keU=