Skip to content

Commit

Permalink
feat: Create psql DB in the CF provisioner (#3334)
Browse files Browse the repository at this point in the history
Connects to the freshly created psql instance, and creates a database if
it does not exist.

This closes #3117 with
tbdeng/ftl-aws#35

We should refactor the DB creation from the success step to be part of
the actual provisioning flow. Thicket for this:
#3333

We should also not use the root credentials when connecting to the DB
from a module. However, we should change this when looking at DB
migrations, when we won't expect modules to execute DDL anymore.

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
jvmakine and github-actions[bot] authored Nov 6, 2024
1 parent 85f894f commit e4b72b5
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 43 deletions.
1 change: 1 addition & 0 deletions cmd/devel-provisioner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func main() {
"ftl-provisioner-cloudformation",
provisionerconnect.NewProvisionerPluginServiceClient,
plugin.WithEnvars("FTL_PROVISIONER_CF_DB_SUBNET_GROUP=aurora-postgres-subnet-group"),
plugin.WithEnvars("FTL_PROVISIONER_CF_DB_SECURITY_GROUP=sg-08e06d6f8327024de"),
)
if err != nil {
panic(err)
Expand Down
14 changes: 14 additions & 0 deletions cmd/ftl-provisioner-cloudformation/cloudformation_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/cloudformation/types"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
"github.com/aws/smithy-go"
goformation "github.com/awslabs/goformation/v7/cloudformation"
"github.com/jpillora/backoff"
Expand Down Expand Up @@ -104,6 +105,19 @@ func createClient(ctx context.Context) (*cloudformation.Client, error) {
), nil
}

func createSecretsClient(ctx context.Context) (*secretsmanager.Client, error) {
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load default aws config: %w", err)
}
return secretsmanager.New(
secretsmanager.Options{
Credentials: cfg.Credentials,
Region: cfg.Region,
},
), nil
}

// CloudformationOutputKey is structured key to be used as an output from a CF stack
type CloudformationOutputKey struct {
ResourceID string `json:"r"`
Expand Down
17 changes: 14 additions & 3 deletions cmd/ftl-provisioner-cloudformation/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"connectrpc.com/connect"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
goformation "github.com/awslabs/goformation/v7/cloudformation"
cf "github.com/awslabs/goformation/v7/cloudformation/cloudformation"
"github.com/awslabs/goformation/v7/cloudformation/rds"
Expand All @@ -26,6 +27,7 @@ import (
const (
PropertyDBReadEndpoint = "db:read_endpoint"
PropertyDBWriteEndpoint = "db:write_endpoint"
PropertyMasterUserARN = "db:maser_user_secret_arn"
)

type Config struct {
Expand All @@ -35,8 +37,9 @@ type Config struct {
}

type CloudformationProvisioner struct {
client *cloudformation.Client
confg *Config
client *cloudformation.Client
secrets *secretsmanager.Client
confg *Config
}

var _ provisionerconnect.ProvisionerPluginServiceHandler = (*CloudformationProvisioner)(nil)
Expand All @@ -46,8 +49,12 @@ func NewCloudformationProvisioner(ctx context.Context, config Config) (context.C
if err != nil {
return nil, nil, fmt.Errorf("failed to create cloudformation client: %w", err)
}
secrets, err := createSecretsClient(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to create secretsmanager client: %w", err)
}

return ctx, &CloudformationProvisioner{client: client, confg: &config}, nil
return ctx, &CloudformationProvisioner{client: client, secrets: secrets, confg: &config}, nil
}

func (c *CloudformationProvisioner) Ping(context.Context, *connect.Request[ftlv1.PingRequest]) (*connect.Response[ftlv1.PingResponse], error) {
Expand Down Expand Up @@ -165,6 +172,10 @@ func (c *CloudformationProvisioner) resourceToCF(cluster, module string, templat
ResourceID: resource.ResourceId,
PropertyName: PropertyDBReadEndpoint,
})
addOutput(template.Outputs, goformation.GetAtt(clusterID, "MasterUserSecret.SecretArn"), &CloudformationOutputKey{
ResourceID: resource.ResourceId,
PropertyName: PropertyMasterUserARN,
})
return nil
}
return errors.New("unsupported resource type")
Expand Down
148 changes: 108 additions & 40 deletions cmd/ftl-provisioner-cloudformation/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ package main

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"strings"

"connectrpc.com/connect"
"github.com/aws/aws-sdk-go-v2/service/cloudformation"
"github.com/aws/aws-sdk-go-v2/service/cloudformation/types"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
_ "github.com/lib/pq"

"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1beta1/provisioner"
)
Expand All @@ -32,7 +39,7 @@ func (c *CloudformationProvisioner) Status(ctx context.Context, req *connect.Req
case types.StackStatusCreateFailed:
return failure(&stack)
case types.StackStatusCreateComplete:
return success(&stack, req.Msg.DesiredResources)
return c.success(ctx, &stack, req.Msg.DesiredResources)
case types.StackStatusRollbackInProgress:
return failure(&stack)
case types.StackStatusRollbackFailed:
Expand All @@ -44,13 +51,13 @@ func (c *CloudformationProvisioner) Status(ctx context.Context, req *connect.Req
case types.StackStatusDeleteFailed:
return failure(&stack)
case types.StackStatusDeleteComplete:
return success(&stack, req.Msg.DesiredResources)
return c.success(ctx, &stack, req.Msg.DesiredResources)
case types.StackStatusUpdateInProgress:
return running()
case types.StackStatusUpdateCompleteCleanupInProgress:
return running()
case types.StackStatusUpdateComplete:
return success(&stack, req.Msg.DesiredResources)
return c.success(ctx, &stack, req.Msg.DesiredResources)
case types.StackStatusUpdateFailed:
return failure(&stack)
case types.StackStatusUpdateRollbackInProgress:
Expand All @@ -60,8 +67,8 @@ func (c *CloudformationProvisioner) Status(ctx context.Context, req *connect.Req
}
}

func success(stack *types.Stack, resources []*provisioner.Resource) (*connect.Response[provisioner.StatusResponse], error) {
err := updateResources(stack.Outputs, resources)
func (c *CloudformationProvisioner) success(ctx context.Context, stack *types.Stack, resources []*provisioner.Resource) (*connect.Response[provisioner.StatusResponse], error) {
err := c.updateResources(ctx, stack.Outputs, resources)
if err != nil {
return nil, err
}
Expand All @@ -86,49 +93,110 @@ func failure(stack *types.Stack) (*connect.Response[provisioner.StatusResponse],
return nil, connect.NewError(connect.CodeUnknown, errors.New(*stack.StackStatusReason))
}

func updateResources(outputs []types.Output, update []*provisioner.Resource) error {
func outputsByResourceID(outputs []types.Output) (map[string][]types.Output, error) {
m := make(map[string][]types.Output)
for _, output := range outputs {
key, err := decodeOutputKey(output)
if err != nil {
return fmt.Errorf("failed to decode output key: %w", err)
return nil, fmt.Errorf("failed to decode output key: %w", err)
}
for _, resource := range update {
if resource.ResourceId == key.ResourceID {
if postgres, ok := resource.Resource.(*provisioner.Resource_Postgres); ok {
if postgres.Postgres == nil {
postgres.Postgres = &provisioner.PostgresResource{}
}
if postgres.Postgres.Output == nil {
postgres.Postgres.Output = &provisioner.PostgresResource_PostgresResourceOutput{}
}

switch key.PropertyName {
case PropertyDBReadEndpoint:
postgres.Postgres.Output.ReadDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 5432)
case PropertyDBWriteEndpoint:
postgres.Postgres.Output.WriteDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 5432)
}
} else if mysql, ok := resource.Resource.(*provisioner.Resource_Mysql); ok {
if mysql.Mysql == nil {
mysql.Mysql = &provisioner.MysqlResource{}
}
if mysql.Mysql.Output == nil {
mysql.Mysql.Output = &provisioner.MysqlResource_MysqlResourceOutput{}
}

switch key.PropertyName {
case PropertyDBReadEndpoint:
mysql.Mysql.Output.ReadDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 5432)
case PropertyDBWriteEndpoint:
mysql.Mysql.Output.WriteDsn = endpointToDSN(*output.OutputValue, key.ResourceID, 3306)
}
}
m[key.ResourceID] = append(m[key.ResourceID], output)
}
return m, nil
}

func outputsByPropertyName(outputs []types.Output) (map[string]types.Output, error) {
m := make(map[string]types.Output)
for _, output := range outputs {
key, err := decodeOutputKey(output)
if err != nil {
return nil, fmt.Errorf("failed to decode output key: %w", err)
}
m[key.PropertyName] = output
}
return m, nil
}

func (c *CloudformationProvisioner) updateResources(ctx context.Context, outputs []types.Output, update []*provisioner.Resource) error {
byResourceID, err := outputsByResourceID(outputs)
if err != nil {
return fmt.Errorf("failed to group outputs by resource ID: %w", err)
}

for _, resource := range update {
if postgres, ok := resource.Resource.(*provisioner.Resource_Postgres); ok {
if postgres.Postgres == nil {
postgres.Postgres = &provisioner.PostgresResource{}
}
if postgres.Postgres.Output == nil {
postgres.Postgres.Output = &provisioner.PostgresResource_PostgresResourceOutput{}
}

if err := c.updatePostgresOutputs(ctx, postgres.Postgres.Output, resource.ResourceId, byResourceID[resource.ResourceId]); err != nil {
return fmt.Errorf("failed to update postgres outputs: %w", err)
}
} else if _, ok := resource.Resource.(*provisioner.Resource_Mysql); ok {
panic("mysql not implemented")
}
}
return nil
}

func endpointToDSN(endpoint, database string, port int) string {
return fmt.Sprintf("postgres://%s:%d/%s?user=postgres&password=password", endpoint, port, database)
func (c *CloudformationProvisioner) updatePostgresOutputs(ctx context.Context, to *provisioner.PostgresResource_PostgresResourceOutput, resourceID string, outputs []types.Output) error {
byName, err := outputsByPropertyName(outputs)
if err != nil {
return fmt.Errorf("failed to group outputs by property name: %w", err)
}

fmt.Fprintf(os.Stderr, "byName: %v\n", byName)

// TODO: Move to provisioner workflow
secretARN := *byName[PropertyMasterUserARN].OutputValue
username, password, err := c.secretARNToUsernamePassword(ctx, secretARN)
if err != nil {
return fmt.Errorf("failed to get username and password from secret ARN: %w", err)
}

to.ReadDsn = endpointToDSN(*byName[PropertyDBReadEndpoint].OutputValue, resourceID, 5432, username, password)
to.WriteDsn = endpointToDSN(*byName[PropertyDBWriteEndpoint].OutputValue, resourceID, 5432, username, password)
adminEndpoint := endpointToDSN(*byName[PropertyDBReadEndpoint].OutputValue, "postgres", 5432, username, password)

// Connect to postgres without a specific database to create the new one
db, err := sql.Open("postgres", adminEndpoint)
if err != nil {
return fmt.Errorf("failed to connect to postgres: %w", err)
}
defer db.Close()

// Create the database if it doesn't exist
if _, err := db.ExecContext(ctx, "CREATE DATABASE "+resourceID); err != nil {
// Ignore if database already exists
if !strings.Contains(err.Error(), "already exists") {
return fmt.Errorf("failed to create database: %w", err)
}
}

return nil
}

func endpointToDSN(endpoint, database string, port int, username, password string) string {
urlEncodedPassword := url.QueryEscape(password)
return fmt.Sprintf("postgres://%s:%d/%s?user=%s&password=%s", endpoint, port, database, username, urlEncodedPassword)
}

func (c *CloudformationProvisioner) secretARNToUsernamePassword(ctx context.Context, secretARN string) (string, string, error) {
secret, err := c.secrets.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{
SecretId: &secretARN,
})
if err != nil {
return "", "", fmt.Errorf("failed to get secret value: %w", err)
}
secretString := *secret.SecretString

var secretData map[string]string
if err := json.Unmarshal([]byte(secretString), &secretData); err != nil {
return "", "", fmt.Errorf("failed to unmarshal secret data: %w", err)
}

return secretData["username"], secretData["password"], nil
}

0 comments on commit e4b72b5

Please sign in to comment.