Skip to content

Commit

Permalink
aws oidc skip aurora clusters without instances (#47605)
Browse files Browse the repository at this point in the history
Instead of returning an error, just skip the cluster that have no
instances.
  • Loading branch information
GavinFrazar authored Oct 18, 2024
1 parent 944e2b7 commit ebeac20
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 31 deletions.
2 changes: 1 addition & 1 deletion lib/auth/integration/integrationv1/awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func (s *AWSOIDCService) ListDatabases(ctx context.Context, req *integrationpb.L
return nil, trace.Wrap(err)
}

listDBsResp, err := awsoidc.ListDatabases(ctx, listDBsClient, awsoidc.ListDatabasesRequest{
listDBsResp, err := awsoidc.ListDatabases(ctx, listDBsClient, s.logger, awsoidc.ListDatabasesRequest{
Region: req.Region,
RDSType: req.RdsType,
Engines: req.Engines,
Expand Down
33 changes: 18 additions & 15 deletions lib/integrations/awsoidc/listdatabases.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package awsoidc

import (
"context"
"log/slog"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/rds"
Expand Down Expand Up @@ -116,14 +117,14 @@ var listDatabasesPageSize int32 = 50
// https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html
// https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html
// It returns a list of Databases and an optional NextToken that can be used to fetch the next page
func ListDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
func ListDatabases(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}

all := &ListDatabasesResponse{}
for {
res, err := listDatabases(ctx, clt, req)
res, err := listDatabases(ctx, clt, log, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -140,7 +141,7 @@ func ListDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabas
}
}

func listDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
func listDatabases(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
// Uses https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html
if req.RDSType == rdsTypeInstance {
ret, err := listDBInstances(ctx, clt, req)
Expand All @@ -151,7 +152,7 @@ func listDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabas
}

// Uses https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html
ret, err := listDBClusters(ctx, clt, req)
ret, err := listDBClusters(ctx, clt, log, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -199,7 +200,7 @@ func listDBInstances(ctx context.Context, clt ListDatabasesClient, req ListDatab
return ret, nil
}

func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
func listDBClusters(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) {
describeDBClusterInput := &rds.DescribeDBClustersInput{
Filters: []rdsTypes.Filter{
{Name: &filterEngine, Values: req.Engines},
Expand Down Expand Up @@ -231,16 +232,23 @@ func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDataba
// To get this value, a member of the cluster is fetched and its Network Information is used to
// populate the RDS Cluster information.
// All the members have the same network information, so picking one at random should not matter.
clusterInstance, err := fetchSingleRDSDBInstance(ctx, clt, req, aws.ToString(db.DBClusterIdentifier))
instances, err := fetchRDSClusterInstances(ctx, clt, req, aws.ToString(db.DBClusterIdentifier))
if err != nil {
return nil, trace.Wrap(err)
}
if len(instances) == 0 {
log.InfoContext(ctx, "Skipping RDS cluster because it has no instances",
"cluster", aws.ToString(db.DBClusterIdentifier),
)
continue
}
instance := &instances[0]

if req.VpcId != "" && !subnetGroupIsInVPC(clusterInstance.DBSubnetGroup, req.VpcId) {
if req.VpcId != "" && !subnetGroupIsInVPC(instance.DBSubnetGroup, req.VpcId) {
continue
}

awsDB, err := common.NewDatabaseFromRDSV2Cluster(&db, clusterInstance)
awsDB, err := common.NewDatabaseFromRDSV2Cluster(&db, instance)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -251,7 +259,7 @@ func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDataba
return ret, nil
}

func fetchSingleRDSDBInstance(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest, clusterID string) (*rdsTypes.DBInstance, error) {
func fetchRDSClusterInstances(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest, clusterID string) ([]rdsTypes.DBInstance, error) {
describeDBInstanceInput := &rds.DescribeDBInstancesInput{
Filters: []rdsTypes.Filter{
{Name: &filterDBClusterID, Values: []string{clusterID}},
Expand All @@ -262,12 +270,7 @@ func fetchSingleRDSDBInstance(ctx context.Context, clt ListDatabasesClient, req
if err != nil {
return nil, trace.Wrap(err)
}

if len(rdsDBs.DBInstances) == 0 {
return nil, trace.BadParameter("database cluster %s has no instance", clusterID)
}

return &rdsDBs.DBInstances[0], nil
return rdsDBs.DBInstances, nil
}

// subnetGroupIsInVPC is a simple helper to check if a db subnet group is in
Expand Down
88 changes: 73 additions & 15 deletions lib/integrations/awsoidc/listdatabases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
)

func stringPointer(s string) *string {
Expand Down Expand Up @@ -140,6 +141,7 @@ func TestListDatabases(t *testing.T) {

t.Run("without vpc filter", func(t *testing.T) {
t.Parallel()
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
// First page must return pageSize number of DBs
req := ListDatabasesRequest{
Region: "us-east-1",
Expand All @@ -149,25 +151,26 @@ func TestListDatabases(t *testing.T) {
NextToken: "",
}
for i := 0; i < totalDBs/int(listDatabasesPageSize); i++ {
resp, err := ListDatabases(ctx, mockListClient, req)
resp, err := ListDatabases(ctx, mockListClient, logger, req)
require.NoError(t, err)
require.Len(t, resp.Databases, int(listDatabasesPageSize))
require.NotEmpty(t, resp.NextToken)
req.NextToken = resp.NextToken
}
// Last page must return remaining databases and an empty token.
resp, err := ListDatabases(ctx, mockListClient, req)
resp, err := ListDatabases(ctx, mockListClient, logger, req)
require.NoError(t, err)
require.Len(t, resp.Databases, totalDBs%int(listDatabasesPageSize))
require.Empty(t, resp.NextToken)
})

t.Run("with vpc filter", func(t *testing.T) {
t.Parallel()
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
// First page must return at least pageSize number of DBs
var gotDatabases []types.Database
wantVPC := "vpc-2"
resp, err := ListDatabases(ctx, mockListClient, ListDatabasesRequest{
resp, err := ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{
Region: "us-east-1",
RDSType: "instance",
Engines: []string{"postgres"},
Expand All @@ -188,7 +191,7 @@ func TestListDatabases(t *testing.T) {
gotDatabases = append(gotDatabases, resp.Databases...)

// Second page must return pageSize number of DBs
resp, err = ListDatabases(ctx, mockListClient, ListDatabasesRequest{
resp, err = ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{
Region: "us-east-1",
RDSType: "instance",
Engines: []string{"postgres"},
Expand All @@ -202,7 +205,7 @@ func TestListDatabases(t *testing.T) {
gotDatabases = append(gotDatabases, resp.Databases...)

// Third page must return only the remaining DBs and an empty nextToken
resp, err = ListDatabases(ctx, mockListClient, ListDatabasesRequest{
resp, err = ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{
Region: "us-east-1",
RDSType: "instance",
Engines: []string{"postgres"},
Expand Down Expand Up @@ -583,23 +586,77 @@ func TestListDatabases(t *testing.T) {
},

{
name: "cluster exists but no instance exists, returns an error",
name: "listing clusters returns all valid clusters and ignores the others",
req: ListDatabasesRequest{
Region: "us-east-1",
RDSType: "cluster",
Engines: []string{"postgres"},
NextToken: "",
},
mockClusters: []rdsTypes.DBCluster{{
Status: stringPointer("available"),
mockInstances: []rdsTypes.DBInstance{{
DBClusterIdentifier: stringPointer("my-dbc"),
DbClusterResourceId: stringPointer("db-123"),
Engine: stringPointer("aurora-postgresql"),
Endpoint: stringPointer("aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com"),
Port: &clusterPort,
DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"),
DBSubnetGroup: &rdsTypes.DBSubnetGroup{
Subnets: []rdsTypes.Subnet{{SubnetIdentifier: aws.String("subnet-999")}},
VpcId: aws.String("vpc-999"),
},
}},
errCheck: trace.IsBadParameter,
mockClusters: []rdsTypes.DBCluster{
{
Status: stringPointer("available"),
DBClusterIdentifier: stringPointer("my-empty-cluster"),
DbClusterResourceId: stringPointer("db-456"),
Engine: stringPointer("aurora-mysql"),
Endpoint: stringPointer("aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com"),
Port: &clusterPort,
DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"),
},
{
Status: stringPointer("available"),
DBClusterIdentifier: stringPointer("my-dbc"),
DbClusterResourceId: stringPointer("db-123"),
Engine: stringPointer("aurora-postgresql"),
Endpoint: stringPointer("aurora-instance-2.abcdefghijklmnop.us-west-1.rds.amazonaws.com"),
Port: &clusterPort,
DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"),
},
},
respCheck: func(t *testing.T, ldr *ListDatabasesResponse) {
require.Len(t, ldr.Databases, 1, "expected 1 database, got %d", len(ldr.Databases))
require.Empty(t, ldr.NextToken, "expected an empty NextToken")
expectedDB, err := types.NewDatabaseV3(
types.Metadata{
Name: "my-dbc",
Description: "Aurora cluster in ",
Labels: map[string]string{
"account-id": "123456789012",
"endpoint-type": "primary",
"engine": "aurora-postgresql",
"engine-version": "",
"region": "",
"status": "available",
"vpc-id": "vpc-999",
"teleport.dev/cloud": "AWS",
},
},
types.DatabaseSpecV3{
Protocol: "postgres",
URI: "aurora-instance-2.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432",
AWS: types.AWS{
AccountID: "123456789012",
RDS: types.RDS{
ClusterID: "my-dbc",
InstanceID: "aurora-instance-2",
ResourceID: "db-123",
Subnets: []string{"subnet-999"},
VPCID: "vpc-999",
},
},
},
)
require.NoError(t, err)
require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0]))
},
errCheck: noErrorFunc,
},
{
name: "no region",
Expand Down Expand Up @@ -637,7 +694,8 @@ func TestListDatabases(t *testing.T) {
dbInstances: tt.mockInstances,
dbClusters: tt.mockClusters,
}
resp, err := ListDatabases(ctx, mockListClient, tt.req)
logger := utils.NewSlogLoggerForTests().With("test", t.Name())
resp, err := ListDatabases(ctx, mockListClient, logger, tt.req)
require.True(t, tt.errCheck(err), "unexpected err: %v", err)
if err != nil {
return
Expand Down

0 comments on commit ebeac20

Please sign in to comment.