Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only apply dynamic AWS settings to dynamic AWS dbs #50970

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,10 @@ func (m *monitoredDatabases) setCloud(databases types.Databases) {
m.cloud = databases
}

func (m *monitoredDatabases) isCloud(database types.Database) bool {
m.mu.RLock()
defer m.mu.RUnlock()
// isCloud_Locked returns whether a database was discovered by the cloud
// watchers, aka legacy database discovery done by the db service.
// The lock must be held when calling this function.
func (m *monitoredDatabases) isCloud_Locked(database types.Database) bool {
for i := range m.cloud {
if m.cloud[i] == database {
return true
Expand All @@ -402,13 +403,17 @@ func (m *monitoredDatabases) isCloud(database types.Database) bool {
return false
}

func (m *monitoredDatabases) isDiscoveryResource(database types.Database) bool {
return database.Origin() == types.OriginCloud && m.isResource(database)
// isDiscoveryResource_Locked returns whether a database was discovered by the
// discovery service.
// The lock must be held when calling this function.
func (m *monitoredDatabases) isDiscoveryResource_Locked(database types.Database) bool {
return database.Origin() == types.OriginCloud && m.isResource_Locked(database)
}

func (m *monitoredDatabases) isResource(database types.Database) bool {
m.mu.RLock()
defer m.mu.RUnlock()
// isResource_Locked returns whether a database is a dynamic database, aka a db
// object.
// The lock must be held when calling this function.
func (m *monitoredDatabases) isResource_Locked(database types.Database) bool {
for i := range m.resources {
if m.resources[i] == database {
return true
Expand All @@ -417,9 +422,9 @@ func (m *monitoredDatabases) isResource(database types.Database) bool {
return false
}

func (m *monitoredDatabases) get() map[string]types.Database {
m.mu.RLock()
defer m.mu.RUnlock()
// getLocked returns a slice containing all of the monitored databases.
// The lock must be held when calling this function.
func (m *monitoredDatabases) getLocked() map[string]types.Database {
return utils.FromSlice(append(append(m.static, m.resources...), m.cloud...), types.Database.GetName)
}

Expand Down
33 changes: 25 additions & 8 deletions lib/srv/db/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (s *Server) startReconciler(ctx context.Context) error {
reconciler, err := services.NewReconciler(services.ReconcilerConfig[types.Database]{
Matcher: s.matcher,
GetCurrentResources: s.getResources,
GetNewResources: s.monitoredDatabases.get,
GetNewResources: s.monitoredDatabases.getLocked,
OnCreate: s.onCreate,
OnUpdate: s.onUpdate,
OnDelete: s.onDelete,
Expand All @@ -53,12 +53,15 @@ func (s *Server) startReconciler(ctx context.Context) error {
for {
select {
case <-s.reconcileCh:
// don't let monitored dbs change during reconciliation
s.monitoredDatabases.mu.RLock()
if err := reconciler.Reconcile(ctx); err != nil {
s.log.ErrorContext(ctx, "Failed to reconcile.", "error", err)
}
if s.cfg.OnReconcile != nil {
s.cfg.OnReconcile(s.getProxiedDatabases())
}
s.monitoredDatabases.mu.RUnlock()
case <-ctx.Done():
s.log.DebugContext(ctx, "Reconciler done.")
return
Expand Down Expand Up @@ -169,11 +172,15 @@ func (s *Server) onCreate(ctx context.Context, database types.Database) error {
// copy here so that any attribute changes to the proxied database will not
// affect database objects tracked in s.monitoredDatabases.
databaseCopy := database.Copy()
applyResourceMatchersToDatabase(databaseCopy, s.cfg.ResourceMatchers)

// only apply resource matcher settings to dynamic resources.
if s.monitoredDatabases.isResource_Locked(database) {
s.applyAWSResourceMatcherSettings(databaseCopy)
}

// Run DiscoveryResourceChecker after resource matchers are applied to make
// sure the correct AssumeRoleARN is used.
if s.monitoredDatabases.isDiscoveryResource(database) {
if s.monitoredDatabases.isDiscoveryResource_Locked(database) {
if err := s.cfg.discoveryResourceChecker.Check(ctx, databaseCopy); err != nil {
return trace.Wrap(err)
}
Expand All @@ -187,7 +194,11 @@ func (s *Server) onUpdate(ctx context.Context, database, _ types.Database) error
// copy here so that any attribute changes to the proxied database will not
// affect database objects tracked in s.monitoredDatabases.
databaseCopy := database.Copy()
applyResourceMatchersToDatabase(databaseCopy, s.cfg.ResourceMatchers)

// only apply resource matcher settings to dynamic resources.
if s.monitoredDatabases.isResource_Locked(database) {
s.applyAWSResourceMatcherSettings(databaseCopy)
}
return s.updateDatabase(ctx, databaseCopy)
}

Expand All @@ -200,7 +211,7 @@ func (s *Server) onDelete(ctx context.Context, database types.Database) error {
func (s *Server) matcher(database types.Database) bool {
// In the case of databases discovered by this database server, matchers
// should be skipped.
if s.monitoredDatabases.isCloud(database) {
if s.monitoredDatabases.isCloud_Locked(database) {
return true // Cloud fetchers return only matching databases.
}

Expand All @@ -209,12 +220,18 @@ func (s *Server) matcher(database types.Database) bool {
return services.MatchResourceLabels(s.cfg.ResourceMatchers, database.GetAllLabels())
}

func applyResourceMatchersToDatabase(database types.Database, resourceMatchers []services.ResourceMatcher) {
for _, matcher := range resourceMatchers {
func (s *Server) applyAWSResourceMatcherSettings(database types.Database) {
if !database.IsAWSHosted() {
// dynamic matchers only apply AWS settings (for now), so skip non-AWS
// databases.
return
}
dbLabels := database.GetAllLabels()
for _, matcher := range s.cfg.ResourceMatchers {
if len(matcher.Labels) == 0 || matcher.AWS.AssumeRoleARN == "" {
continue
}
if match, _, _ := services.MatchLabels(matcher.Labels, database.GetAllLabels()); !match {
if match, _, _ := services.MatchLabels(matcher.Labels, dbLabels); !match {
continue
}

Expand Down
63 changes: 43 additions & 20 deletions lib/srv/db/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package db
import (
"context"
"fmt"
"maps"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -60,11 +61,13 @@ func TestWatcher(t *testing.T) {
// watches for databases with label group=a.
testCtx.setupDatabaseServer(ctx, t, agentParams{
Databases: []types.Database{db0},
ResourceMatchers: []services.ResourceMatcher{
{Labels: types.Labels{
ResourceMatchers: []services.ResourceMatcher{{
Labels: types.Labels{
"group": []string{"a"},
}},
},
},
// these should not be applied to non-AWS databases.
AWS: services.ResourceMatcherAWS{AssumeRoleARN: "some-role", ExternalID: "some-externalid"},
}},
OnReconcile: func(d types.Databases) {
reconcileCh <- d
},
Expand Down Expand Up @@ -137,7 +140,7 @@ func TestWatcher(t *testing.T) {
// ResourceMatchers should be always evaluated for the dynamic registered
// resources.
func TestWatcherDynamicResource(t *testing.T) {
var db1, db2, db3, db4, db5 *types.DatabaseV3
var db1, db2, db3, db4, db5, db6 *types.DatabaseV3
ctx := context.Background()
testCtx := setupTestContext(ctx, t)

Expand Down Expand Up @@ -247,6 +250,7 @@ func TestWatcherDynamicResource(t *testing.T) {
// ResourceMatchers and has AssumeRoleARN set by the discovery service.
discoveredDB5, err := makeDiscoveryDatabase("db5", map[string]string{"group": "b"}, withRDSURL, withDiscoveryAssumeRoleARN)
require.NoError(t, err)
require.True(t, discoveredDB5.IsAWSHosted())
require.True(t, discoveredDB5.IsRDS())

err = testCtx.authServer.CreateDatabase(ctx, discoveredDB5)
Expand All @@ -260,6 +264,23 @@ func TestWatcherDynamicResource(t *testing.T) {
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5})
})

t.Run("non-AWS discovery resource - AssumeRoleARN not applied", func(t *testing.T) {
// Created a discovery service created database resource that matches
// ResourceMatchers but is not an AWS database
_, azureDB := makeAzureSQLServer(t, "discovery-azure", "group")
setDiscoveryTypeLabel(azureDB, types.AzureMatcherSQLServer)
setLabels(azureDB, map[string]string{"group": "b"})
azureDB.SetOrigin(types.OriginCloud)
require.False(t, azureDB.IsAWSHosted())
require.True(t, azureDB.GetAWS().IsEmpty())
require.True(t, azureDB.IsAzure())
err = testCtx.authServer.CreateDatabase(ctx, azureDB)
require.NoError(t, err)

db6 = azureDB.Copy()
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5, db6})
})

t.Run("discovery resource - fail check", func(t *testing.T) {
// Created a discovery service created database resource that fails the
// fakeDiscoveryResourceChecker.
Expand All @@ -268,27 +289,20 @@ func TestWatcherDynamicResource(t *testing.T) {
require.NoError(t, testCtx.authServer.CreateDatabase(ctx, dbFailCheck))

// dbFailCheck should not be proxied.
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5})
assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5, db6})
})
}

func setDiscoveryGroupLabel(r types.ResourceWithLabels, discoveryGroup string) {
staticLabels := r.GetStaticLabels()
if staticLabels == nil {
staticLabels = make(map[string]string)
}
if discoveryGroup != "" {
staticLabels[types.TeleportInternalDiscoveryGroupName] = discoveryGroup
}
r.SetStaticLabels(staticLabels)
func setDiscoveryTypeLabel(r types.ResourceWithLabels, matcherType string) {
setLabels(r, map[string]string{types.DiscoveryTypeLabel: matcherType})
}

func setDiscoveryTypeLabel(r types.ResourceWithLabels, matcherType string) {
func setLabels(r types.ResourceWithLabels, newLabels map[string]string) {
staticLabels := r.GetStaticLabels()
if staticLabels == nil {
staticLabels = make(map[string]string)
}
staticLabels[types.DiscoveryTypeLabel] = matcherType
maps.Copy(staticLabels, newLabels)
r.SetStaticLabels(staticLabels)
}

Expand All @@ -301,15 +315,16 @@ func TestWatcherCloudFetchers(t *testing.T) {
redshiftServerlessDatabase, err := discovery.NewDatabaseFromRedshiftServerlessWorkgroup(redshiftServerlessWorkgroup, nil)
require.NoError(t, err)
redshiftServerlessDatabase.SetStatusAWS(redshiftServerlessDatabase.GetAWS())
setDiscoveryGroupLabel(redshiftServerlessDatabase, "")
setDiscoveryTypeLabel(redshiftServerlessDatabase, types.AWSMatcherRedshiftServerless)
redshiftServerlessDatabase.SetOrigin(types.OriginCloud)
discovery.ApplyAWSDatabaseNameSuffix(redshiftServerlessDatabase, types.AWSMatcherRedshiftServerless)
require.Empty(t, redshiftServerlessDatabase.GetAWS().AssumeRoleARN)
require.Empty(t, redshiftServerlessDatabase.GetAWS().ExternalID)
// Test an Azure fetcher.
azSQLServer, azSQLServerDatabase := makeAzureSQLServer(t, "discovery-azure", "group")
setDiscoveryGroupLabel(azSQLServerDatabase, "")
setDiscoveryTypeLabel(azSQLServerDatabase, types.AzureMatcherSQLServer)
azSQLServerDatabase.SetOrigin(types.OriginCloud)
require.False(t, azSQLServerDatabase.IsAWSHosted())
ctx := context.Background()
testCtx := setupTestContext(ctx, t)

Expand All @@ -319,7 +334,15 @@ func TestWatcherCloudFetchers(t *testing.T) {
OnReconcile: func(d types.Databases) {
reconcileCh <- d
},
ResourceMatchers: []services.ResourceMatcher{{
Labels: types.Labels{types.Wildcard: []string{types.Wildcard}},
AWS: services.ResourceMatcherAWS{
AssumeRoleARN: "role-arn",
ExternalID: "external-id",
},
}},
CloudClients: &clients.TestCloudClients{
STS: &mocks.STSClientV1{},
RedshiftServerless: &mocks.RedshiftServerlessMock{
Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup},
},
Expand Down Expand Up @@ -351,7 +374,7 @@ func assertReconciledResource(t *testing.T, ch chan types.Databases, databases t
select {
case d := <-ch:
sort.Sort(d)
require.Equal(t, len(d), len(databases))
require.Equal(t, len(databases), len(d))
require.Empty(t, cmp.Diff(databases, d,
cmpopts.IgnoreFields(types.Metadata{}, "Revision"),
cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"),
Expand Down
Loading