Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Resolve DSN from runtime in PG Proxy #3458

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,8 @@ func (s *Service) GetModuleContext(ctx context.Context, req *connect.Request[ftl
continue
}
dbTypes[db.Name] = dbType
if db.Runtime != nil {
// TODO: Move the DSN resolution to the runtime once MySQL proxy is working
if db.Runtime != nil && dbType == modulecontext.DBTypeMySQL {
databases[db.Name] = modulecontext.Database{
DSN: db.Runtime.DSN,
DBType: dbType,
Expand Down
111 changes: 81 additions & 30 deletions backend/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/alecthomas/types/optional"
"github.com/jpillora/backoff"
"github.com/otiai10/copy"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"

Expand All @@ -39,6 +40,7 @@ import (
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/model"
ftlobservability "github.com/TBD54566975/ftl/internal/observability"
"github.com/TBD54566975/ftl/internal/pgproxy"
"github.com/TBD54566975/ftl/internal/rpc"
"github.com/TBD54566975/ftl/internal/schema"
"github.com/TBD54566975/ftl/internal/slices"
Expand Down Expand Up @@ -129,7 +131,20 @@ func Start(ctx context.Context, config Config) error {
cancelFunc: doneFunc,
devEndpoint: config.DevEndpoint,
}
err = svc.deploy(ctx)

deploymentKey, err := model.ParseDeploymentKey(config.Deployment)
if err != nil {
observability.Deployment.Failure(ctx, optional.None[string]())
svc.cancelFunc()
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid deployment key: %w", err))
}

module, err := svc.getModule(ctx, deploymentKey)
if err != nil {
return fmt.Errorf("failed to get module: %w", err)
}

err = svc.deploy(ctx, deploymentKey, module)
if err != nil {
// If we fail to deploy we just exit
// Kube or local scaling will start a new instance to continue
Expand All @@ -143,11 +158,24 @@ func Start(ctx context.Context, config Config) error {
go rpc.RetryStreamingClientStream(ctx, backoff.Backoff{}, controllerClient.StreamDeploymentLogs, svc.streamLogsLoop)
}()

return rpc.Serve(ctx, config.Bind,
rpc.GRPC(ftlv1connect.NewVerbServiceHandler, svc),
rpc.HTTP("/", svc),
rpc.HealthCheck(svc.healthCheck),
)
pgProxyStarted := make(chan pgproxy.Started)

g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return svc.startPgProxy(ctx, module, pgProxyStarted)
})
g.Go(func() error {
pgProxy := <-pgProxyStarted
os.Setenv("PG_PROXY_ADDRESS", fmt.Sprintf("127.0.0.1:%d", pgProxy.Address.Port))
logger.Debugf("PG_PROXY_ADDRESS: %s", os.Getenv("PG_PROXY_ADDRESS"))

return rpc.Serve(ctx, config.Bind,
rpc.GRPC(ftlv1connect.NewVerbServiceHandler, svc),
rpc.HTTP("/", svc),
rpc.HealthCheck(svc.healthCheck),
)
})
return fmt.Errorf("failure in runner: %w", g.Wait())
}

func newIdentityStore(ctx context.Context, config Config, key model.RunnerKey, controllerClient ftlv1connect.ControllerServiceClient) (*identity.Store, error) {
Expand Down Expand Up @@ -294,52 +322,51 @@ func (s *Service) Ping(ctx context.Context, req *connect.Request[ftlv1.PingReque
return connect.NewResponse(&ftlv1.PingResponse{}), nil
}

func (s *Service) deploy(ctx context.Context) error {
logger := log.FromContext(ctx)
if err, ok := s.registrationFailure.Load().Get(); ok {
observability.Deployment.Failure(ctx, optional.None[string]())
return connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to register runner: %w", err))
func (s *Service) getModule(ctx context.Context, key model.DeploymentKey) (*schema.Module, error) {
gdResp, err := s.controllerClient.GetDeployment(ctx, connect.NewRequest(&ftlv1.GetDeploymentRequest{DeploymentKey: s.config.Deployment}))
if err != nil {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return nil, fmt.Errorf("failed to get deployment: %w", err)
}

key, err := model.ParseDeploymentKey(s.config.Deployment)
module, err := schema.ModuleFromProto(gdResp.Msg.Schema)
if err != nil {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return nil, fmt.Errorf("invalid module: %w", err)
}
return module, nil
}

func (s *Service) deploy(ctx context.Context, key model.DeploymentKey, module *schema.Module) error {
logger := log.FromContext(ctx)

if err, ok := s.registrationFailure.Load().Get(); ok {
observability.Deployment.Failure(ctx, optional.None[string]())
s.cancelFunc()
return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid deployment key: %w", err))
return connect.NewError(connect.CodeUnavailable, fmt.Errorf("failed to register runner: %w", err))
}

observability.Deployment.Started(ctx, key.String())
defer observability.Deployment.Completed(ctx, key.String())

deploymentLogger := s.getDeploymentLogger(ctx, key)
ctx = log.ContextWithLogger(ctx, deploymentLogger)

s.lock.Lock()
defer s.lock.Unlock()
if s.deployment.Load().Ok() {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return errors.New("already deployed")
}

gdResp, err := s.controllerClient.GetDeployment(ctx, connect.NewRequest(&ftlv1.GetDeploymentRequest{DeploymentKey: s.config.Deployment}))
if err != nil {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return fmt.Errorf("failed to get deployment: %w", err)
}
module, err := schema.ModuleFromProto(gdResp.Msg.Schema)
if err != nil {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return fmt.Errorf("invalid module: %w", err)
}
deploymentLogger := s.getDeploymentLogger(ctx, key)
ctx = log.ContextWithLogger(ctx, deploymentLogger)

deploymentDir := filepath.Join(s.config.DeploymentDir, module.Name, key.String())
if s.config.TemplateDir != "" {
err = copy.Copy(s.config.TemplateDir, deploymentDir)
err := copy.Copy(s.config.TemplateDir, deploymentDir)
if err != nil {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return fmt.Errorf("failed to copy template directory: %w", err)
}
} else {
err = os.MkdirAll(deploymentDir, 0700)
err := os.MkdirAll(deploymentDir, 0700)
if err != nil {
observability.Deployment.Failure(ctx, optional.Some(key.String()))
return fmt.Errorf("failed to create deployment directory: %w", err)
Expand Down Expand Up @@ -377,7 +404,7 @@ func (s *Service) deploy(ctx context.Context) error {
deployment, cmdCtx, err := plugin.Spawn(
unstoppable.Context(verbCtx),
log.FromContext(ctx).GetLevel(),
gdResp.Msg.Schema.Name,
module.Name,
deploymentDir,
"./launch",
ftlv1connect.NewVerbServiceClient,
Expand Down Expand Up @@ -568,3 +595,27 @@ func (s *Service) healthCheck(writer http.ResponseWriter, request *http.Request)
}
writer.WriteHeader(http.StatusServiceUnavailable)
}

func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, started chan<- pgproxy.Started) error {
logger := log.FromContext(ctx)

databases := map[string]*schema.Database{}
for _, decl := range module.Decls {
if db, ok := decl.(*schema.Database); ok {
databases[db.Name] = db
}
}

if err := pgproxy.New(":0", func(ctx context.Context, params map[string]string) (string, error) {
db, ok := databases[params["database"]]
if !ok {
return "", fmt.Errorf("database %s not found", params["database"])
}
logger.Debugf("Resolved DSN (%s): %s", params["database"], db.Runtime.DSN)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is where the logic to use a dsn resolver based on provisioned runtime (like type: aws-iam-auth) will be added


return db.Runtime.DSN, nil
}).Start(ctx, started); err != nil {
return fmt.Errorf("failed to start pgproxy: %w", err)
}
return nil
}
4 changes: 2 additions & 2 deletions cmd/ftl-proxy-pg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ func main() {
err = observability.Init(ctx, false, "", "ftl-provisioner", ftl.Version, cli.ObservabilityConfig)
kctx.FatalIfErrorf(err, "failed to initialize observability")

proxy := pgproxy.New(cli.Config, func(ctx context.Context, params map[string]string) (string, error) {
proxy := pgproxy.New(cli.Config.Listen, func(ctx context.Context, params map[string]string) (string, error) {
return "postgres://localhost:5432/postgres?user=" + params["user"], nil
})
if err := proxy.Start(ctx); err != nil {
if err := proxy.Start(ctx, nil); err != nil {
kctx.FatalIfErrorf(err, "failed to start proxy")
}
}
22 changes: 1 addition & 21 deletions cmd/ftl-runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,21 @@ package main

import (
"context"
"fmt"
"os"
"path/filepath"

"github.com/alecthomas/kong"
"golang.org/x/sync/errgroup"

"github.com/TBD54566975/ftl"
"github.com/TBD54566975/ftl/backend/runner"
_ "github.com/TBD54566975/ftl/internal/automaxprocs" // Set GOMAXPROCS to match Linux container CPU quota.
"github.com/TBD54566975/ftl/internal/log"
"github.com/TBD54566975/ftl/internal/pgproxy"
)

var cli struct {
Version kong.VersionFlag `help:"Show version."`
LogConfig log.Config `prefix:"log-" embed:""`
RunnerConfig runner.Config `embed:""`
ProxyConfig pgproxy.Config `embed:"" prefix:"pgproxy-"`
}

func main() {
Expand All @@ -47,21 +43,5 @@ and route to user code.
logger := log.Configure(os.Stderr, cli.LogConfig)
ctx := log.ContextWithLogger(context.Background(), logger)

g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return runPGProxy(ctx, cli.ProxyConfig)
})
g.Go(func() error {
return runner.Start(ctx, cli.RunnerConfig)
})
kctx.FatalIfErrorf(g.Wait())
}

func runPGProxy(ctx context.Context, config pgproxy.Config) error {
if err := pgproxy.New(config, func(ctx context.Context, params map[string]string) (string, error) {
return "postgres://127.0.0.1:5432/postgres?user=" + params["user"], nil
}).Start(ctx); err != nil {
return fmt.Errorf("failed to start pgproxy: %w", err)
}
return nil
kctx.FatalIfErrorf(runner.Start(ctx, cli.RunnerConfig))
}
5 changes: 5 additions & 0 deletions internal/modulecontext/module_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -159,6 +160,10 @@ func (m ModuleContext) GetSecret(name string, value any) error {
func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, bool, error) {
db, ok := m.databases[name]
if !ok {
if dbType == DBTypePostgres {
proxyAddress := os.Getenv("PG_PROXY_ADDRESS")
return "postgres://" + proxyAddress + "/" + name, false, nil
}
return "", false, fmt.Errorf("missing DSN for database %s", name)
}
if db.DBType != dbType {
Expand Down
20 changes: 16 additions & 4 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ type DSNConstructor func(ctx context.Context, params map[string]string) (string,
//
// address is the address to listen on for incoming connections.
// connectionFn is a function that constructs a new connection string from parameters of the incoming connection.
func New(config Config, connectionFn DSNConstructor) *PgProxy {
func New(listenAddress string, connectionFn DSNConstructor) *PgProxy {
return &PgProxy{
listenAddress: config.Listen,
listenAddress: listenAddress,
connectionStringFn: connectionFn,
}
}

// Start the proxy.
func (p *PgProxy) Start(ctx context.Context) error {
type Started struct {
Address *net.TCPAddr
}

// Start the proxy
func (p *PgProxy) Start(ctx context.Context, started chan<- Started) error {
logger := log.FromContext(ctx)

listener, err := net.Listen("tcp", p.listenAddress)
Expand All @@ -47,6 +51,14 @@ func (p *PgProxy) Start(ctx context.Context) error {
}
defer listener.Close()

if started != nil {
addr, ok := listener.Addr().(*net.TCPAddr)
if !ok {
panic("failed to get TCP address")
}
started <- Started{Address: addr}
}

for {
conn, err := listener.Accept()
if err != nil {
Expand Down
Loading