Skip to content

Commit

Permalink
feat: use MySQL proxy for auth (#3519)
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartwdouglas authored Nov 26, 2024
1 parent 8b511fe commit 8131257
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 35 deletions.
24 changes: 1 addition & 23 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,6 @@ func (s *Service) GetModuleContext(ctx context.Context, req *connect.Request[ftl
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not get deployments: %w", err))
}
databases := map[string]modulecontext.Database{}
for _, dep := range deps {
if dep.Module == name {
for _, decl := range dep.Schema.Decls {
Expand All @@ -711,17 +710,6 @@ func (s *Service) GetModuleContext(ctx context.Context, req *connect.Request[ftl
continue
}
dbTypes[db.Name] = dbType
// TODO: Move the DSN resolution to the runtime once MySQL proxy is working
if db.Runtime != nil && dbType == modulecontext.DBTypeMySQL {
if dsn, ok := db.Runtime.(*schema.DSNDatabaseRuntime); ok {
databases[db.Name] = modulecontext.Database{
DSN: dsn.DSN,
DBType: dbType,
}
} else {
return connect.NewError(connect.CodeInternal, fmt.Errorf("unknown database runtime type: %T", db.Runtime))
}
}
}
}
break
Expand All @@ -738,28 +726,18 @@ func (s *Service) GetModuleContext(ctx context.Context, req *connect.Request[ftl
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not get secrets: %w", err))
}
secretDbs, err := modulecontext.DatabasesFromSecrets(ctx, name, secrets, dbTypes)
if err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not get databases: %w", err))
}
for k, v := range secretDbs {
databases[k] = v
}

if err := hashConfigurationMap(h, configs); err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not detect change on configs: %w", err))
}
if err := hashConfigurationMap(h, secrets); err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not detect change on secrets: %w", err))
}
if err := hashDatabaseConfiguration(h, databases); err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not detect change on databases: %w", err))
}

checksum := int64(binary.BigEndian.Uint64((h.Sum(nil))[0:8]))

if checksum != lastChecksum {
response := modulecontext.NewBuilder(name).AddConfigs(configs).AddSecrets(secrets).AddDatabases(databases).Build().ToProto()
response := modulecontext.NewBuilder(name).AddConfigs(configs).AddSecrets(secrets).Build().ToProto()

if err := resp.Send(response); err != nil {
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not send response: %w", err))
Expand Down
87 changes: 75 additions & 12 deletions backend/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"connectrpc.com/connect"
"github.com/alecthomas/atomic"
"github.com/alecthomas/types/optional"
mysql "github.com/block/ftl-mysql-auth-proxy"
"github.com/jpillora/backoff"
"github.com/otiai10/copy"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -144,27 +145,29 @@ func Start(ctx context.Context, config Config) error {
return fmt.Errorf("failed to get module: %w", err)
}

pgProxyStarted := make(chan optional.Option[pgproxy.Started])
startedLatch := &sync.WaitGroup{}
startedLatch.Add(2)
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
return svc.startPgProxy(ctx, module, pgProxyStarted)
return svc.startPgProxy(ctx, module, startedLatch)
})
g.Go(func() error {
return svc.startMySQLProxy(ctx, module, startedLatch)
})
g.Go(func() error {
startedLatch.Wait()
select {
case pgProxy := <-pgProxyStarted:
return svc.startDeployment(ctx, deploymentKey, module, pgProxy)
case <-ctx.Done():
return ctx.Err()
default:
return svc.startDeployment(ctx, deploymentKey, module)
}
})

return fmt.Errorf("failure in runner: %w", g.Wait())
}

func (s *Service) startDeployment(ctx context.Context, key model.DeploymentKey, module *schema.Module, pgProxyOpt optional.Option[pgproxy.Started]) error {
if pgProxy, ok := pgProxyOpt.Get(); ok {
os.Setenv("FTL_PROXY_POSTGRES_ADDRESS", fmt.Sprintf("127.0.0.1:%d", pgProxy.Address.Port))
}
func (s *Service) startDeployment(ctx context.Context, key model.DeploymentKey, module *schema.Module) error {

err := s.deploy(ctx, key, module)
if err != nil {
Expand Down Expand Up @@ -603,7 +606,7 @@ func (s *Service) healthCheck(writer http.ResponseWriter, request *http.Request)
writer.WriteHeader(http.StatusServiceUnavailable)
}

func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, started chan<- optional.Option[pgproxy.Started]) error {
func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, started *sync.WaitGroup) error {
logger := log.FromContext(ctx)

databases := map[string]*schema.Database{}
Expand All @@ -614,7 +617,7 @@ func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, start
}

if len(databases) == 0 {
started <- optional.None[pgproxy.Started]()
started.Done()
return nil
}

Expand All @@ -623,13 +626,15 @@ func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, start
go func() {
select {
case pgProxy := <-channel:
started <- optional.Some(pgProxy)
os.Setenv("FTL_PROXY_POSTGRES_ADDRESS", fmt.Sprintf("127.0.0.1:%d", pgProxy.Address.Port))
started.Done()
case <-ctx.Done():
started.Done()
return
}
}()

if err := pgproxy.New(":0", func(ctx context.Context, params map[string]string) (string, error) {
if err := pgproxy.New("127.0.0.1:0", func(ctx context.Context, params map[string]string) (string, error) {
db, ok := databases[params["database"]]
if !ok {
return "", fmt.Errorf("database %s not found", params["database"])
Expand All @@ -642,8 +647,66 @@ func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, start

return "", fmt.Errorf("unknown database runtime type: %T", db.Runtime)
}).Start(ctx, channel); err != nil {
started.Done()
return fmt.Errorf("failed to start pgproxy: %w", err)
}

return nil
}

func (s *Service) startMySQLProxy(ctx context.Context, module *schema.Module, latch *sync.WaitGroup) error {
defer latch.Done()
logger := log.FromContext(ctx)

databases := map[string]*schema.Database{}
for _, decl := range module.Decls {
if db, ok := decl.(*schema.Database); ok && db.Type == "mysql" {
databases[db.Name] = db
}
}

if len(databases) == 0 {
return nil
}
for db, decl := range databases {
logger.Debugf("Starting MySQL proxy for %s", db)
logger := log.FromContext(ctx)
portC := make(chan int)
errorC := make(chan error)
databaseRuntime := decl.Runtime
var proxy *mysql.Proxy
switch db := databaseRuntime.(type) {
case *schema.DSNDatabaseRuntime:
proxy = mysql.NewProxy("localhost", 0, db.DSN, &mysqlLogger{logger: logger}, portC)
default:
return fmt.Errorf("unknown database runtime type: %T", databaseRuntime)
}
go func() {
err := proxy.ListenAndServe(ctx)
if err != nil {
errorC <- err
}
}()
port := 0
select {
case err := <-errorC:
return fmt.Errorf("error: %w", err)
case port = <-portC:
}

os.Setenv(strings.ToUpper("FTL_PROXY_MYSQL_ADDRESS_"+decl.Name), fmt.Sprintf("127.0.0.1:%d", port))
}
return nil
}

var _ mysql.Logger = (*mysqlLogger)(nil)

type mysqlLogger struct {
logger *log.Logger
}

func (m *mysqlLogger) Print(v ...any) {
for _, s := range v {
m.logger.Infof("mysql: %v", s)
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.6
github.com/aws/smithy-go v1.22.1
github.com/beevik/etree v1.4.1
github.com/block/ftl-mysql-auth-proxy v0.0.0-20241126024735-7acb0031b469
github.com/block/scaffolder v1.3.0
github.com/bmatcuk/doublestar/v4 v4.7.1
github.com/deckarep/golang-set/v2 v2.6.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions internal/modulecontext/module_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, bool, er
if dbType == DBTypePostgres {
proxyAddress := os.Getenv("FTL_PROXY_POSTGRES_ADDRESS")
return "postgres://" + proxyAddress + "/" + name, false, nil
} else if dbType == DBTypeMySQL {
proxyAddress := os.Getenv("FTL_PROXY_MYSQL_ADDRESS_" + strings.ToUpper(name))
return "ftl:ftl@tcp(" + proxyAddress + ")/" + name, false, nil
}
return "", false, fmt.Errorf("missing DSN for database %s", name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ public Datasource getDatasource(String name) {
if (databases.get(name) == ModuleContextResponse.DBType.POSTGRES) {
var proxyAddress = System.getenv("FTL_PROXY_POSTGRES_ADDRESS");
return new Datasource("jdbc:postgresql://" + proxyAddress + "/" + name, "ftl", "ftl");
} else if (databases.get(name) == ModuleContextResponse.DBType.MYSQL) {
var proxyAddress = System.getenv("FTL_PROXY_MYSQL_ADDRESS_" + name.toUpperCase());
return new Datasource("jdbc:mysql://" + proxyAddress + "/" + name, "ftl", "ftl");
}
List<ModuleContextResponse.DSN> databasesList = getModuleContext().getDatabasesList();
for (var i : databasesList) {
Expand Down

0 comments on commit 8131257

Please sign in to comment.