From ebeac205b7a1a3487a7046ea12307426642c728a Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Fri, 18 Oct 2024 12:01:37 -0700 Subject: [PATCH] aws oidc skip aurora clusters without instances (#47605) Instead of returning an error, just skip the cluster that have no instances. --- lib/auth/integration/integrationv1/awsoidc.go | 2 +- lib/integrations/awsoidc/listdatabases.go | 33 +++---- .../awsoidc/listdatabases_test.go | 88 +++++++++++++++---- 3 files changed, 92 insertions(+), 31 deletions(-) diff --git a/lib/auth/integration/integrationv1/awsoidc.go b/lib/auth/integration/integrationv1/awsoidc.go index 1e56a1b90e91e..7911daf0d3a40 100644 --- a/lib/auth/integration/integrationv1/awsoidc.go +++ b/lib/auth/integration/integrationv1/awsoidc.go @@ -315,7 +315,7 @@ func (s *AWSOIDCService) ListDatabases(ctx context.Context, req *integrationpb.L return nil, trace.Wrap(err) } - listDBsResp, err := awsoidc.ListDatabases(ctx, listDBsClient, awsoidc.ListDatabasesRequest{ + listDBsResp, err := awsoidc.ListDatabases(ctx, listDBsClient, s.logger, awsoidc.ListDatabasesRequest{ Region: req.Region, RDSType: req.RdsType, Engines: req.Engines, diff --git a/lib/integrations/awsoidc/listdatabases.go b/lib/integrations/awsoidc/listdatabases.go index d6b8af3720b6a..8298ec7aef609 100644 --- a/lib/integrations/awsoidc/listdatabases.go +++ b/lib/integrations/awsoidc/listdatabases.go @@ -20,6 +20,7 @@ package awsoidc import ( "context" + "log/slog" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/rds" @@ -116,14 +117,14 @@ var listDatabasesPageSize int32 = 50 // https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html // https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html // It returns a list of Databases and an optional NextToken that can be used to fetch the next page -func ListDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) { +func ListDatabases(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) { if err := req.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } all := &ListDatabasesResponse{} for { - res, err := listDatabases(ctx, clt, req) + res, err := listDatabases(ctx, clt, log, req) if err != nil { return nil, trace.Wrap(err) } @@ -140,7 +141,7 @@ func ListDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabas } } -func listDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) { +func listDatabases(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) { // Uses https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html if req.RDSType == rdsTypeInstance { ret, err := listDBInstances(ctx, clt, req) @@ -151,7 +152,7 @@ func listDatabases(ctx context.Context, clt ListDatabasesClient, req ListDatabas } // Uses https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html - ret, err := listDBClusters(ctx, clt, req) + ret, err := listDBClusters(ctx, clt, log, req) if err != nil { return nil, trace.Wrap(err) } @@ -199,7 +200,7 @@ func listDBInstances(ctx context.Context, clt ListDatabasesClient, req ListDatab return ret, nil } -func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest) (*ListDatabasesResponse, error) { +func listDBClusters(ctx context.Context, clt ListDatabasesClient, log *slog.Logger, req ListDatabasesRequest) (*ListDatabasesResponse, error) { describeDBClusterInput := &rds.DescribeDBClustersInput{ Filters: []rdsTypes.Filter{ {Name: &filterEngine, Values: req.Engines}, @@ -231,16 +232,23 @@ func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDataba // To get this value, a member of the cluster is fetched and its Network Information is used to // populate the RDS Cluster information. // All the members have the same network information, so picking one at random should not matter. - clusterInstance, err := fetchSingleRDSDBInstance(ctx, clt, req, aws.ToString(db.DBClusterIdentifier)) + instances, err := fetchRDSClusterInstances(ctx, clt, req, aws.ToString(db.DBClusterIdentifier)) if err != nil { return nil, trace.Wrap(err) } + if len(instances) == 0 { + log.InfoContext(ctx, "Skipping RDS cluster because it has no instances", + "cluster", aws.ToString(db.DBClusterIdentifier), + ) + continue + } + instance := &instances[0] - if req.VpcId != "" && !subnetGroupIsInVPC(clusterInstance.DBSubnetGroup, req.VpcId) { + if req.VpcId != "" && !subnetGroupIsInVPC(instance.DBSubnetGroup, req.VpcId) { continue } - awsDB, err := common.NewDatabaseFromRDSV2Cluster(&db, clusterInstance) + awsDB, err := common.NewDatabaseFromRDSV2Cluster(&db, instance) if err != nil { return nil, trace.Wrap(err) } @@ -251,7 +259,7 @@ func listDBClusters(ctx context.Context, clt ListDatabasesClient, req ListDataba return ret, nil } -func fetchSingleRDSDBInstance(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest, clusterID string) (*rdsTypes.DBInstance, error) { +func fetchRDSClusterInstances(ctx context.Context, clt ListDatabasesClient, req ListDatabasesRequest, clusterID string) ([]rdsTypes.DBInstance, error) { describeDBInstanceInput := &rds.DescribeDBInstancesInput{ Filters: []rdsTypes.Filter{ {Name: &filterDBClusterID, Values: []string{clusterID}}, @@ -262,12 +270,7 @@ func fetchSingleRDSDBInstance(ctx context.Context, clt ListDatabasesClient, req if err != nil { return nil, trace.Wrap(err) } - - if len(rdsDBs.DBInstances) == 0 { - return nil, trace.BadParameter("database cluster %s has no instance", clusterID) - } - - return &rdsDBs.DBInstances[0], nil + return rdsDBs.DBInstances, nil } // subnetGroupIsInVPC is a simple helper to check if a db subnet group is in diff --git a/lib/integrations/awsoidc/listdatabases_test.go b/lib/integrations/awsoidc/listdatabases_test.go index 09b8490077a4a..6f831668e4f85 100644 --- a/lib/integrations/awsoidc/listdatabases_test.go +++ b/lib/integrations/awsoidc/listdatabases_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/utils" ) func stringPointer(s string) *string { @@ -140,6 +141,7 @@ func TestListDatabases(t *testing.T) { t.Run("without vpc filter", func(t *testing.T) { t.Parallel() + logger := utils.NewSlogLoggerForTests().With("test", t.Name()) // First page must return pageSize number of DBs req := ListDatabasesRequest{ Region: "us-east-1", @@ -149,14 +151,14 @@ func TestListDatabases(t *testing.T) { NextToken: "", } for i := 0; i < totalDBs/int(listDatabasesPageSize); i++ { - resp, err := ListDatabases(ctx, mockListClient, req) + resp, err := ListDatabases(ctx, mockListClient, logger, req) require.NoError(t, err) require.Len(t, resp.Databases, int(listDatabasesPageSize)) require.NotEmpty(t, resp.NextToken) req.NextToken = resp.NextToken } // Last page must return remaining databases and an empty token. - resp, err := ListDatabases(ctx, mockListClient, req) + resp, err := ListDatabases(ctx, mockListClient, logger, req) require.NoError(t, err) require.Len(t, resp.Databases, totalDBs%int(listDatabasesPageSize)) require.Empty(t, resp.NextToken) @@ -164,10 +166,11 @@ func TestListDatabases(t *testing.T) { t.Run("with vpc filter", func(t *testing.T) { t.Parallel() + logger := utils.NewSlogLoggerForTests().With("test", t.Name()) // First page must return at least pageSize number of DBs var gotDatabases []types.Database wantVPC := "vpc-2" - resp, err := ListDatabases(ctx, mockListClient, ListDatabasesRequest{ + resp, err := ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{ Region: "us-east-1", RDSType: "instance", Engines: []string{"postgres"}, @@ -188,7 +191,7 @@ func TestListDatabases(t *testing.T) { gotDatabases = append(gotDatabases, resp.Databases...) // Second page must return pageSize number of DBs - resp, err = ListDatabases(ctx, mockListClient, ListDatabasesRequest{ + resp, err = ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{ Region: "us-east-1", RDSType: "instance", Engines: []string{"postgres"}, @@ -202,7 +205,7 @@ func TestListDatabases(t *testing.T) { gotDatabases = append(gotDatabases, resp.Databases...) // Third page must return only the remaining DBs and an empty nextToken - resp, err = ListDatabases(ctx, mockListClient, ListDatabasesRequest{ + resp, err = ListDatabases(ctx, mockListClient, logger, ListDatabasesRequest{ Region: "us-east-1", RDSType: "instance", Engines: []string{"postgres"}, @@ -583,23 +586,77 @@ func TestListDatabases(t *testing.T) { }, { - name: "cluster exists but no instance exists, returns an error", + name: "listing clusters returns all valid clusters and ignores the others", req: ListDatabasesRequest{ Region: "us-east-1", RDSType: "cluster", Engines: []string{"postgres"}, NextToken: "", }, - mockClusters: []rdsTypes.DBCluster{{ - Status: stringPointer("available"), + mockInstances: []rdsTypes.DBInstance{{ DBClusterIdentifier: stringPointer("my-dbc"), - DbClusterResourceId: stringPointer("db-123"), - Engine: stringPointer("aurora-postgresql"), - Endpoint: stringPointer("aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com"), - Port: &clusterPort, - DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"), + DBSubnetGroup: &rdsTypes.DBSubnetGroup{ + Subnets: []rdsTypes.Subnet{{SubnetIdentifier: aws.String("subnet-999")}}, + VpcId: aws.String("vpc-999"), + }, }}, - errCheck: trace.IsBadParameter, + mockClusters: []rdsTypes.DBCluster{ + { + Status: stringPointer("available"), + DBClusterIdentifier: stringPointer("my-empty-cluster"), + DbClusterResourceId: stringPointer("db-456"), + Engine: stringPointer("aurora-mysql"), + Endpoint: stringPointer("aurora-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com"), + Port: &clusterPort, + DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"), + }, + { + Status: stringPointer("available"), + DBClusterIdentifier: stringPointer("my-dbc"), + DbClusterResourceId: stringPointer("db-123"), + Engine: stringPointer("aurora-postgresql"), + Endpoint: stringPointer("aurora-instance-2.abcdefghijklmnop.us-west-1.rds.amazonaws.com"), + Port: &clusterPort, + DBClusterArn: stringPointer("arn:aws:iam::123456789012:role/MyARN"), + }, + }, + respCheck: func(t *testing.T, ldr *ListDatabasesResponse) { + require.Len(t, ldr.Databases, 1, "expected 1 database, got %d", len(ldr.Databases)) + require.Empty(t, ldr.NextToken, "expected an empty NextToken") + expectedDB, err := types.NewDatabaseV3( + types.Metadata{ + Name: "my-dbc", + Description: "Aurora cluster in ", + Labels: map[string]string{ + "account-id": "123456789012", + "endpoint-type": "primary", + "engine": "aurora-postgresql", + "engine-version": "", + "region": "", + "status": "available", + "vpc-id": "vpc-999", + "teleport.dev/cloud": "AWS", + }, + }, + types.DatabaseSpecV3{ + Protocol: "postgres", + URI: "aurora-instance-2.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + AWS: types.AWS{ + AccountID: "123456789012", + RDS: types.RDS{ + ClusterID: "my-dbc", + InstanceID: "aurora-instance-2", + ResourceID: "db-123", + Subnets: []string{"subnet-999"}, + VPCID: "vpc-999", + }, + }, + }, + ) + require.NoError(t, err) + require.Empty(t, cmp.Diff(expectedDB, ldr.Databases[0])) + }, + errCheck: noErrorFunc, }, { name: "no region", @@ -637,7 +694,8 @@ func TestListDatabases(t *testing.T) { dbInstances: tt.mockInstances, dbClusters: tt.mockClusters, } - resp, err := ListDatabases(ctx, mockListClient, tt.req) + logger := utils.NewSlogLoggerForTests().With("test", t.Name()) + resp, err := ListDatabases(ctx, mockListClient, logger, tt.req) require.True(t, tt.errCheck(err), "unexpected err: %v", err) if err != nil { return