Skip to content

Commit

Permalink
TAG: Poll AWS RDS instances and clusters fom AWS (#39013)
Browse files Browse the repository at this point in the history
This PR polls the AWS RDS instances & clusters and syncs them into TAG.

Part of gravitational/access-graph#459

Signed-off-by: Tiago Silva <[email protected]>
  • Loading branch information
tigrato committed Mar 12, 2024
1 parent d9f330f commit 84bd915
Show file tree
Hide file tree
Showing 7 changed files with 1,096 additions and 422 deletions.
1,100 changes: 678 additions & 422 deletions gen/proto/go/accessgraph/v1alpha/aws.pb.go

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/aws-sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ type Resources struct {
AssociatedAccessPolicies []*accessgraphv1alpha.AWSEKSAssociatedAccessPolicyV1
// AccessEntries is the list of Access Entries.
AccessEntries []*accessgraphv1alpha.AWSEKSClusterAccessEntryV1
// RDSDatabases is a list of RDS instances and clusters.
RDSDatabases []*accessgraphv1alpha.AWSRDSDatabaseV1
}

// NewAWSFetcher creates a new AWS fetcher.
Expand Down Expand Up @@ -181,6 +183,9 @@ func (a *awsFetcher) poll(ctx context.Context) (*Resources, error) {
// fetch AWS EKS clusters
eGroup.Go(a.pollAWSEKSClusters(ctx, result, collectErr))

// fetch AWS RDS instances and clusters
eGroup.Go(a.pollAWSRDSDatabases(ctx, result, collectErr))

if err := eGroup.Wait(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
1 change: 1 addition & 0 deletions lib/srv/discovery/fetchers/aws-sync/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func MergeResources(results ...*Resources) *Resources {
result.AssociatedAccessPolicies = append(result.AssociatedAccessPolicies, r.AssociatedAccessPolicies...)
result.EKSClusters = append(result.EKSClusters, r.EKSClusters...)
result.AccessEntries = append(result.AccessEntries, r.AccessEntries...)
result.RDSDatabases = append(result.RDSDatabases, r.RDSDatabases...)
}
return result
}
172 changes: 172 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/rds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package aws_sync

import (
"context"
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/gravitational/trace"
"golang.org/x/sync/errgroup"

accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
)

// pollAWSRDSDatabases is a function that returns a function that fetches
// RDS instances and clusters.
func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources, collectErr func(error)) func() error {
return func() error {
var err error
result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx)
if err != nil {
collectErr(trace.Wrap(err, "failed to fetch databases"))
}
return nil
}
}

// fetchAWSRDSDatabases fetches RDS databases from all regions.
func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context) (
[]*accessgraphv1alpha.AWSRDSDatabaseV1,
error,
) {
var (
dbs []*accessgraphv1alpha.AWSRDSDatabaseV1
hostsMu sync.Mutex
errs []error
)
eG, ctx := errgroup.WithContext(ctx)
// Set the limit to 10 to avoid too many concurrent requests.
// This is a temporary solution until we have a better way to limit the
// number of concurrent requests.
eG.SetLimit(10)
collectDBs := func(db *accessgraphv1alpha.AWSRDSDatabaseV1, err error) {
hostsMu.Lock()
defer hostsMu.Unlock()
if err != nil {
errs = append(errs, err)
}
if db != nil {
dbs = append(dbs, db)
}

}

for _, region := range a.Regions {
region := region
eG.Go(func() error {
rdsClient, err := a.CloudClients.GetAWSRDSClient(ctx, region, a.getAWSOptions()...)
if err != nil {
collectDBs(nil, trace.Wrap(err))
return nil
}
err = rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{},
func(output *rds.DescribeDBInstancesOutput, lastPage bool) bool {
for _, db := range output.DBInstances {
// if instance belongs to a cluster, skip it as we want to represent the cluster itself
// and we pull it using DescribeDBClustersPagesWithContext instead.
if aws.StringValue(db.DBClusterIdentifier) != "" {
continue
}
protoRDS := awsRDSInstanceToRDS(db, region, a.AccountID)
collectDBs(protoRDS, nil)
}
return !lastPage
},
)
if err != nil {
collectDBs(nil, trace.Wrap(err))
}

err = rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{},
func(output *rds.DescribeDBClustersOutput, lastPage bool) bool {
for _, db := range output.DBClusters {
protoRDS := awsRDSClusterToRDS(db, region, a.AccountID)
collectDBs(protoRDS, nil)
}
return !lastPage
},
)
if err != nil {
collectDBs(nil, trace.Wrap(err))
}

return nil
})
}

err := eG.Wait()
return dbs, trace.NewAggregate(append(errs, err)...)
}

// awsRDSInstanceToRDS converts an rds.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1
// representation.
func awsRDSInstanceToRDS(instance *rds.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 {
var tags []*accessgraphv1alpha.AWSTag
for _, v := range instance.TagList {
tags = append(tags, &accessgraphv1alpha.AWSTag{
Key: aws.StringValue(v.Key),
Value: strPtrToWrapper(v.Value),
})
}

return &accessgraphv1alpha.AWSRDSDatabaseV1{
Name: aws.StringValue(instance.DBName),
Arn: aws.StringValue(instance.DBInstanceArn),
CreatedAt: awsTimeToProtoTime(instance.InstanceCreateTime),
Status: aws.StringValue(instance.DBInstanceStatus),
Region: region,
AccountId: accountID,
Tags: tags,
EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
Engine: aws.StringValue(instance.Engine),
Version: aws.StringValue(instance.EngineVersion),
},
IsCluster: false,
}
}

// awsRDSInstanceToRDS converts an rds.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1
// representation.
func awsRDSClusterToRDS(instance *rds.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 {
var tags []*accessgraphv1alpha.AWSTag
for _, v := range instance.TagList {
tags = append(tags, &accessgraphv1alpha.AWSTag{
Key: aws.StringValue(v.Key),
Value: strPtrToWrapper(v.Value),
})
}

return &accessgraphv1alpha.AWSRDSDatabaseV1{
Name: aws.StringValue(instance.DatabaseName),
Arn: aws.StringValue(instance.DBClusterArn),
CreatedAt: awsTimeToProtoTime(instance.ClusterCreateTime),
Status: aws.StringValue(instance.Status),
Region: region,
AccountId: accountID,
Tags: tags,
EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
Engine: aws.StringValue(instance.Engine),
Version: aws.StringValue(instance.EngineVersion),
},
IsCluster: true,
}
}
180 changes: 180 additions & 0 deletions lib/srv/discovery/fetchers/aws-sync/rds_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package aws_sync

import (
"context"
"sync"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"

accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/cloud/mocks"
)

func TestPollAWSRDS(t *testing.T) {
const (
accountID = "12345678"
)
var (
regions = []string{"eu-west-1"}
)

tests := []struct {
name string
want *Resources
}{
{
name: "poll rds databases",
want: &Resources{
RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{
{
Arn: "arn:us-west1:rds:instance1",
Status: rds.DBProxyStatusAvailable,
Name: "db1",
EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
Engine: rds.EngineFamilyMysql,
Version: "v1.1",
},
CreatedAt: timestamppb.New(date),
Tags: []*accessgraphv1alpha.AWSTag{
{
Key: "tag",
Value: wrapperspb.String("val"),
},
},
Region: "eu-west-1",
IsCluster: false,
AccountId: "12345678",
},
{
Arn: "arn:us-west1:rds:cluster1",
Status: rds.DBProxyStatusAvailable,
Name: "cluster1",
EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{
Engine: rds.EngineFamilyMysql,
Version: "v1.1",
},
CreatedAt: timestamppb.New(date),
Tags: []*accessgraphv1alpha.AWSTag{
{
Key: "tag",
Value: wrapperspb.String("val"),
},
},
Region: "eu-west-1",
IsCluster: true,
AccountId: "12345678",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockedClients := &cloud.TestCloudClients{
RDS: &mocks.RDSMock{
DBInstances: dbInstances(),
DBClusters: dbClusters(),
},
}

var (
errs []error
mu sync.Mutex
)

collectErr := func(err error) {
mu.Lock()
defer mu.Unlock()
errs = append(errs, err)
}
a := &awsFetcher{
Config: Config{
AccountID: accountID,
CloudClients: mockedClients,
Regions: regions,
Integration: accountID,
},
}
result := &Resources{}
execFunc := a.pollAWSRDSDatabases(context.Background(), result, collectErr)
require.NoError(t, execFunc())
require.Empty(t, cmp.Diff(
tt.want,
result,
protocmp.Transform(),
// tags originate from a map so we must sort them before comparing.
protocmp.SortRepeated(
func(a, b *accessgraphv1alpha.AWSTag) bool {
return a.Key < b.Key
},
),
),
)

})
}
}

func dbInstances() []*rds.DBInstance {
return []*rds.DBInstance{
{
DBName: aws.String("db1"),
DBInstanceArn: aws.String("arn:us-west1:rds:instance1"),
InstanceCreateTime: aws.Time(date),
Engine: aws.String(rds.EngineFamilyMysql),
DBInstanceStatus: aws.String(rds.DBProxyStatusAvailable),
EngineVersion: aws.String("v1.1"),
TagList: []*rds.Tag{
{
Key: aws.String("tag"),
Value: aws.String("val"),
},
},
},
}
}

func dbClusters() []*rds.DBCluster {
return []*rds.DBCluster{
{
DatabaseName: aws.String("cluster1"),
DBClusterArn: aws.String("arn:us-west1:rds:cluster1"),
ClusterCreateTime: aws.Time(date),
Engine: aws.String(rds.EngineFamilyMysql),
Status: aws.String(rds.DBProxyStatusAvailable),
EngineVersion: aws.String("v1.1"),
TagList: []*rds.Tag{
{
Key: aws.String("tag"),
Value: aws.String("val"),
},
},
},
}
}
Loading

0 comments on commit 84bd915

Please sign in to comment.