Skip to content

Commit

Permalink
fix: resolve pg DSN in the runner
Browse files Browse the repository at this point in the history
  • Loading branch information
jvmakine committed Nov 22, 2024
1 parent f2a91ef commit 6f92cd2
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 27 deletions.
17 changes: 3 additions & 14 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,21 +769,10 @@ func (s *Service) GetModuleContext(ctx context.Context, req *connect.Request[ftl
continue
}
dbTypes[db.Name] = dbType
// TODO: Move the DSN resolution to the runtime
if db.Runtime != nil {
var dsn string
switch dbType {
case modulecontext.DBTypePostgres:
// TODO: Get the port from config
dsn = "postgres://127.0.0.1:5678/" + db.Name
case modulecontext.DBTypeMySQL:
// TODO: Route MySQL through a proxy as well
dsn = db.Runtime.DSN
default:
return connect.NewError(connect.CodeInternal, fmt.Errorf("unknown DB type: %s", db.Type))
}
// TODO: Move the DSN resolution to the runtime once MySQL proxy is working
if db.Runtime != nil && dbType == modulecontext.DBTypeMySQL {
databases[db.Name] = modulecontext.Database{
DSN: dsn,
DSN: db.Runtime.DSN,
DBType: dbType,
}
}
Expand Down
22 changes: 15 additions & 7 deletions backend/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
Expand Down Expand Up @@ -63,8 +64,6 @@ type Config struct {
Registry artefacts.RegistryConfig `embed:"" prefix:"oci-"`
ObservabilityConfig ftlobservability.Config `embed:"" prefix:"o11y-"`
DevEndpoint optional.Option[url.URL] `help:"An existing endpoint to connect to in development mode" env:"FTL_DEV_ENDPOINT"`

PgProxyConfig pgproxy.Config `embed:"" prefix:"pgproxy-"`
}

func Start(ctx context.Context, config Config) error {
Expand Down Expand Up @@ -160,12 +159,17 @@ func Start(ctx context.Context, config Config) error {
go rpc.RetryStreamingClientStream(ctx, backoff.Backoff{}, controllerClient.StreamDeploymentLogs, svc.streamLogsLoop)
}()

pgProxyStarted := make(chan pgproxy.Started)

g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return svc.startPgProxy(ctx, module)
return svc.startPgProxy(ctx, module, pgProxyStarted)
})
g.Go(func() error {
// TODO: Make sure pgproxy is ready before starting the runner
pgProxy := <-pgProxyStarted
os.Setenv("PG_PROXY_ADDRESS", fmt.Sprintf("127.0.0.1:%d", pgProxy.Address.(*net.TCPAddr).Port))

Check failure on line 170 in backend/runner/runner.go

View workflow job for this annotation

GitHub Actions / Lint

type assertion must be checked (forcetypeassert)
logger.Debugf("PG_PROXY_ADDRESS: %s", os.Getenv("PG_PROXY_ADDRESS"))

return rpc.Serve(ctx, config.Bind,
rpc.GRPC(ftlv1connect.NewVerbServiceHandler, svc),
rpc.HTTP("/", svc),
Expand Down Expand Up @@ -593,21 +597,25 @@ func (s *Service) healthCheck(writer http.ResponseWriter, request *http.Request)
writer.WriteHeader(http.StatusServiceUnavailable)
}

func (s *Service) startPgProxy(ctx context.Context, module *schema.Module) error {
func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, started chan<- pgproxy.Started) error {
logger := log.FromContext(ctx)

databases := map[string]*schema.Database{}
for _, decl := range module.Decls {
if db, ok := decl.(*schema.Database); ok {
databases[db.Name] = db
}
}

if err := pgproxy.New(s.config.PgProxyConfig, func(ctx context.Context, params map[string]string) (string, error) {
if err := pgproxy.New(":0", func(ctx context.Context, params map[string]string) (string, error) {
db, ok := databases[params["database"]]
if !ok {
return "", fmt.Errorf("database %s not found", params["database"])
}
logger.Debugf("Resolved DSN (%s): %s", params["database"], db.Runtime.DSN)

return db.Runtime.DSN, nil
}).Start(ctx); err != nil {
}).Start(ctx, started); err != nil {
return fmt.Errorf("failed to start pgproxy: %w", err)
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions cmd/ftl-proxy-pg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ func main() {
err = observability.Init(ctx, false, "", "ftl-provisioner", ftl.Version, cli.ObservabilityConfig)
kctx.FatalIfErrorf(err, "failed to initialize observability")

proxy := pgproxy.New(cli.Config, func(ctx context.Context, params map[string]string) (string, error) {
proxy := pgproxy.New(cli.Config.Listen, func(ctx context.Context, params map[string]string) (string, error) {
return "postgres://localhost:5432/postgres?user=" + params["user"], nil
})
if err := proxy.Start(ctx); err != nil {
if err := proxy.Start(ctx, nil); err != nil {
kctx.FatalIfErrorf(err, "failed to start proxy")
}
}
5 changes: 5 additions & 0 deletions internal/modulecontext/module_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -159,6 +160,10 @@ func (m ModuleContext) GetSecret(name string, value any) error {
func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, bool, error) {
db, ok := m.databases[name]
if !ok {
if dbType == DBTypePostgres {
proxyAddress := os.Getenv("PG_PROXY_ADDRESS")
return "postgres://" + proxyAddress + "/" + name, false, nil
}
return "", false, fmt.Errorf("missing DSN for database %s", name)
}
if db.DBType != dbType {
Expand Down
16 changes: 12 additions & 4 deletions internal/pgproxy/pgproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ type DSNConstructor func(ctx context.Context, params map[string]string) (string,
//
// address is the address to listen on for incoming connections.
// connectionFn is a function that constructs a new connection string from parameters of the incoming connection.
func New(config Config, connectionFn DSNConstructor) *PgProxy {
func New(listenAddress string, connectionFn DSNConstructor) *PgProxy {
return &PgProxy{
listenAddress: config.Listen,
listenAddress: listenAddress,
connectionStringFn: connectionFn,
}
}

// Start the proxy.
func (p *PgProxy) Start(ctx context.Context) error {
type Started struct {
Address net.Addr
}

// Start the proxy
func (p *PgProxy) Start(ctx context.Context, started chan<- Started) error {
logger := log.FromContext(ctx)

listener, err := net.Listen("tcp", p.listenAddress)
Expand All @@ -47,6 +51,10 @@ func (p *PgProxy) Start(ctx context.Context) error {
}
defer listener.Close()

if started != nil {
started <- Started{Address: listener.Addr()}
}

for {
conn, err := listener.Accept()
if err != nil {
Expand Down

0 comments on commit 6f92cd2

Please sign in to comment.