Skip to content

Commit

Permalink
Only apply dynamic AWS settings to dynamic AWS dbs (#50970)
Browse files Browse the repository at this point in the history
* Only apply dynamic AWS settings to dynamic AWS dbs

Dynamic database resource matchers can include AWS settings to assume an
AWS IAM role when they match a database.
The settings should only be applied to dynamic AWS databases.

The db service will no longer apply these settings to non-AWS databases.

It will also no longer apply these settings to databases discovered by
the legacy cloud watchers in db_service.aws - the cloud watchers have an
assume_role_arn setting that should not be overridden by dynamic
database matcher settings.

* fix reconcilitation race
  • Loading branch information
GavinFrazar authored and mvbrock committed Jan 18, 2025
1 parent 2ed5072 commit f8c9466
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 39 deletions.
27 changes: 16 additions & 11 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ func (m *monitoredDatabases) setCloud(databases types.Databases) {
m.cloud = databases
}

func (m *monitoredDatabases) isCloud(database types.Database) bool {
m.mu.RLock()
defer m.mu.RUnlock()
// isCloud_Locked returns whether a database was discovered by the cloud
// watchers, aka legacy database discovery done by the db service.
// The lock must be held when calling this function.
func (m *monitoredDatabases) isCloud_Locked(database types.Database) bool {
for i := range m.cloud {
if m.cloud[i] == database {
return true
Expand All @@ -402,13 +403,17 @@ func (m *monitoredDatabases) isCloud(database types.Database) bool {
return false
}

func (m *monitoredDatabases) isDiscoveryResource(database types.Database) bool {
return database.Origin() == types.OriginCloud && m.isResource(database)
// isDiscoveryResource_Locked returns whether a database was discovered by the
// discovery service.
// The lock must be held when calling this function.
func (m *monitoredDatabases) isDiscoveryResource_Locked(database types.Database) bool {
return database.Origin() == types.OriginCloud && m.isResource_Locked(database)
}

func (m *monitoredDatabases) isResource(database types.Database) bool {
m.mu.RLock()
defer m.mu.RUnlock()
// isResource_Locked returns whether a database is a dynamic database, aka a db
// object.
// The lock must be held when calling this function.
func (m *monitoredDatabases) isResource_Locked(database types.Database) bool {
for i := range m.resources {
if m.resources[i] == database {
return true
Expand All @@ -417,9 +422,9 @@ func (m *monitoredDatabases) isResource(database types.Database) bool {
return false
}

func (m *monitoredDatabases) get() map[string]types.Database {
m.mu.RLock()
defer m.mu.RUnlock()
// getLocked returns a slice containing all of the monitored databases.
// The lock must be held when calling this function.
func (m *monitoredDatabases) getLocked() map[string]types.Database {
return utils.FromSlice(append(append(m.static, m.resources...), m.cloud...), types.Database.GetName)
}

Expand Down
33 changes: 25 additions & 8 deletions lib/srv/db/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (s *Server) startReconciler(ctx context.Context) error {
reconciler, err := services.NewReconciler(services.ReconcilerConfig[types.Database]{
Matcher: s.matcher,
GetCurrentResources: s.getResources,
GetNewResources: s.monitoredDatabases.get,
GetNewResources: s.monitoredDatabases.getLocked,
OnCreate: s.onCreate,
OnUpdate: s.onUpdate,
OnDelete: s.onDelete,
Expand All @@ -53,12 +53,15 @@ func (s *Server) startReconciler(ctx context.Context) error {
for {
select {
case <-s.reconcileCh:
// don't let monitored dbs change during reconciliation
s.monitoredDatabases.mu.RLock()
if err := reconciler.Reconcile(ctx); err != nil {
s.log.ErrorContext(ctx, "Failed to reconcile.", "error", err)
}
if s.cfg.OnReconcile != nil {
s.cfg.OnReconcile(s.getProxiedDatabases())
}
s.monitoredDatabases.mu.RUnlock()
case <-ctx.Done():
s.log.DebugContext(ctx, "Reconciler done.")
return
Expand Down Expand Up @@ -169,11 +172,15 @@ func (s *Server) onCreate(ctx context.Context, database types.Database) error {
// copy here so that any attribute changes to the proxied database will not
// affect database objects tracked in s.monitoredDatabases.
databaseCopy := database.Copy()
applyResourceMatchersToDatabase(databaseCopy, s.cfg.ResourceMatchers)

// only apply resource matcher settings to dynamic resources.
if s.monitoredDatabases.isResource_Locked(database) {
s.applyAWSResourceMatcherSettings(databaseCopy)
}

// Run DiscoveryResourceChecker after resource matchers are applied to make
// sure the correct AssumeRoleARN is used.
if s.monitoredDatabases.isDiscoveryResource(database) {
if s.monitoredDatabases.isDiscoveryResource_Locked(database) {
if err := s.cfg.discoveryResourceChecker.Check(ctx, databaseCopy); err != nil {
return trace.Wrap(err)
}
Expand All @@ -187,7 +194,11 @@ func (s *Server) onUpdate(ctx context.Context, database, _ types.Database) error
// copy here so that any attribute changes to the proxied database will not
// affect database objects tracked in s.monitoredDatabases.
databaseCopy := database.Copy()
applyResourceMatchersToDatabase(databaseCopy, s.cfg.ResourceMatchers)

// only apply resource matcher settings to dynamic resources.
if s.monitoredDatabases.isResource_Locked(database) {
s.applyAWSResourceMatcherSettings(databaseCopy)
}
return s.updateDatabase(ctx, databaseCopy)
}

Expand All @@ -200,7 +211,7 @@ func (s *Server) onDelete(ctx context.Context, database types.Database) error {
func (s *Server) matcher(database types.Database) bool {
// In the case of databases discovered by this database server, matchers
// should be skipped.
if s.monitoredDatabases.isCloud(database) {
if s.monitoredDatabases.isCloud_Locked(database) {
return true // Cloud fetchers return only matching databases.
}

Expand All @@ -209,12 +220,18 @@ func (s *Server) matcher(database types.Database) bool {
return services.MatchResourceLabels(s.cfg.ResourceMatchers, database.GetAllLabels())
}

func applyResourceMatchersToDatabase(database types.Database, resourceMatchers []services.ResourceMatcher) {
for _, matcher := range resourceMatchers {
func (s *Server) applyAWSResourceMatcherSettings(database types.Database) {
if !database.IsAWSHosted() {
// dynamic matchers only apply AWS settings (for now), so skip non-AWS
// databases.
return
}
dbLabels := database.GetAllLabels()
for _, matcher := range s.cfg.ResourceMatchers {
if len(matcher.Labels) == 0 || matcher.AWS.AssumeRoleARN == "" {
continue
}
if match, _, _ := services.MatchLabels(matcher.Labels, database.GetAllLabels()); !match {
if match, _, _ := services.MatchLabels(matcher.Labels, dbLabels); !match {
continue
}

Expand Down
63 changes: 43 additions & 20 deletions lib/srv/db/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package db
import (
"context"
"fmt"
"maps"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -60,11 +61,13 @@ func TestWatcher(t *testing.T) {
// watches for databases with label group=a.
testCtx.setupDatabaseServer(ctx, t, agentParams{
Databases: []types.Database{db0},
ResourceMatchers: []services.ResourceMatcher{
{Labels: types.Labels{
ResourceMatchers: []services.ResourceMatcher{{
Labels: types.Labels{
"group": []string{"a"},
}},
},
},
// these should not be applied to non-AWS databases.
AWS: services.ResourceMatcherAWS{AssumeRoleARN: "some-role", ExternalID: "some-externalid"},
}},
OnReconcile: func(d types.Databases) {
reconcileCh <- d
},
Expand Down Expand Up @@ -137,7 +140,7 @@ func TestWatcher(t *testing.T) {
// ResourceMatchers should be always evaluated for the dynamic registered
// resources.
func TestWatcherDynamicResource(t *testing.T) {
var db1, db2, db3, db4, db5 *types.DatabaseV3
var db1, db2, db3, db4, db5, db6 *types.DatabaseV3
ctx := context.Background()
testCtx := setupTestContext(ctx, t)

Expand Down Expand Up @@ -247,6 +250,7 @@ func TestWatcherDynamicResource(t *testing.T) {
// ResourceMatchers and has AssumeRoleARN set by the discovery service.
discoveredDB5, err := makeDiscoveryDatabase("db5", map[string]string{"group": "b"}, withRDSURL, withDiscoveryAssumeRoleARN)
require.NoError(t, err)
require.True(t, discoveredDB5.IsAWSHosted())
require.True(t, discoveredDB5.IsRDS())

err = testCtx.authServer.CreateDatabase(ctx, discoveredDB5)
Expand All @@ -260,6 +264,23 @@ func TestWatcherDynamicResource(t *testing.T) {
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5})
})

t.Run("non-AWS discovery resource - AssumeRoleARN not applied", func(t *testing.T) {
// Created a discovery service created database resource that matches
// ResourceMatchers but is not an AWS database
_, azureDB := makeAzureSQLServer(t, "discovery-azure", "group")
setDiscoveryTypeLabel(azureDB, types.AzureMatcherSQLServer)
setLabels(azureDB, map[string]string{"group": "b"})
azureDB.SetOrigin(types.OriginCloud)
require.False(t, azureDB.IsAWSHosted())
require.True(t, azureDB.GetAWS().IsEmpty())
require.True(t, azureDB.IsAzure())
err = testCtx.authServer.CreateDatabase(ctx, azureDB)
require.NoError(t, err)

db6 = azureDB.Copy()
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5, db6})
})

t.Run("discovery resource - fail check", func(t *testing.T) {
// Created a discovery service created database resource that fails the
// fakeDiscoveryResourceChecker.
Expand All @@ -268,27 +289,20 @@ func TestWatcherDynamicResource(t *testing.T) {
require.NoError(t, testCtx.authServer.CreateDatabase(ctx, dbFailCheck))

// dbFailCheck should not be proxied.
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5})
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5, db6})
})
}

func setDiscoveryGroupLabel(r types.ResourceWithLabels, discoveryGroup string) {
staticLabels := r.GetStaticLabels()
if staticLabels == nil {
staticLabels = make(map[string]string)
}
if discoveryGroup != "" {
staticLabels[types.TeleportInternalDiscoveryGroupName] = discoveryGroup
}
r.SetStaticLabels(staticLabels)
func setDiscoveryTypeLabel(r types.ResourceWithLabels, matcherType string) {
setLabels(r, map[string]string{types.DiscoveryTypeLabel: matcherType})
}

func setDiscoveryTypeLabel(r types.ResourceWithLabels, matcherType string) {
func setLabels(r types.ResourceWithLabels, newLabels map[string]string) {
staticLabels := r.GetStaticLabels()
if staticLabels == nil {
staticLabels = make(map[string]string)
}
staticLabels[types.DiscoveryTypeLabel] = matcherType
maps.Copy(staticLabels, newLabels)
r.SetStaticLabels(staticLabels)
}

Expand All @@ -301,15 +315,16 @@ func TestWatcherCloudFetchers(t *testing.T) {
redshiftServerlessDatabase, err := discovery.NewDatabaseFromRedshiftServerlessWorkgroup(redshiftServerlessWorkgroup, nil)
require.NoError(t, err)
redshiftServerlessDatabase.SetStatusAWS(redshiftServerlessDatabase.GetAWS())
setDiscoveryGroupLabel(redshiftServerlessDatabase, "")
setDiscoveryTypeLabel(redshiftServerlessDatabase, types.AWSMatcherRedshiftServerless)
redshiftServerlessDatabase.SetOrigin(types.OriginCloud)
discovery.ApplyAWSDatabaseNameSuffix(redshiftServerlessDatabase, types.AWSMatcherRedshiftServerless)
require.Empty(t, redshiftServerlessDatabase.GetAWS().AssumeRoleARN)
require.Empty(t, redshiftServerlessDatabase.GetAWS().ExternalID)
// Test an Azure fetcher.
azSQLServer, azSQLServerDatabase := makeAzureSQLServer(t, "discovery-azure", "group")
setDiscoveryGroupLabel(azSQLServerDatabase, "")
setDiscoveryTypeLabel(azSQLServerDatabase, types.AzureMatcherSQLServer)
azSQLServerDatabase.SetOrigin(types.OriginCloud)
require.False(t, azSQLServerDatabase.IsAWSHosted())
ctx := context.Background()
testCtx := setupTestContext(ctx, t)

Expand All @@ -319,7 +334,15 @@ func TestWatcherCloudFetchers(t *testing.T) {
OnReconcile: func(d types.Databases) {
reconcileCh <- d
},
ResourceMatchers: []services.ResourceMatcher{{
Labels: types.Labels{types.Wildcard: []string{types.Wildcard}},
AWS: services.ResourceMatcherAWS{
AssumeRoleARN: "role-arn",
ExternalID: "external-id",
},
}},
CloudClients: &clients.TestCloudClients{
STS: &mocks.STSClientV1{},
RedshiftServerless: &mocks.RedshiftServerlessMock{
Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup},
},
Expand Down Expand Up @@ -351,7 +374,7 @@ func assertReconciledResource(t *testing.T, ch chan types.Databases, databases t
select {
case d := <-ch:
sort.Sort(d)
require.Equal(t, len(d), len(databases))
require.Equal(t, len(databases), len(d))
require.Empty(t, cmp.Diff(databases, d,
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"),
Expand Down

0 comments on commit f8c9466

Please sign in to comment.