Skip to content

Commit

Permalink
refactor: decouple conflicting ModuleContext requirements
Browse files Browse the repository at this point in the history
The ModuleContext was designed to be an abstract data model in the
Controller for the resources required by a module, but along the way it
started to be used for storing DB connections for use by the go-runtime.
This change cleanly separates those requirements so that the go-runtime
is entirely responsible for creating new connections from the DSN
provided by the ModuleContext.
  • Loading branch information
alecthomas committed Aug 15, 2024
1 parent af62b9e commit 17900c6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
22 changes: 18 additions & 4 deletions go-runtime/ftl/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"

"github.com/alecthomas/types/once"
_ "github.com/jackc/pgx/v5/stdlib" // Register Postgres driver

"github.com/TBD54566975/ftl/internal/modulecontext"
Expand All @@ -13,24 +14,37 @@ import (
type Database struct {
Name string
DBType modulecontext.DBType

db *once.Handle[*sql.DB]
}

// PostgresDatabase returns a handler for the named database.
func PostgresDatabase(name string) Database {
return Database{
Name: name,
DBType: modulecontext.DBTypePostgres,
db: once.Once(func(ctx context.Context) (*sql.DB, error) {
provider := modulecontext.FromContext(ctx).CurrentContext()
dsn, err := provider.GetDatabase(name, modulecontext.DBTypePostgres)
if err != nil {
return nil, fmt.Errorf("failed to get database %q: %w", name, err)
}
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open database %q: %w", name, err)
}
return db, nil
}),
}
}

func (d Database) String() string { return fmt.Sprintf("database %q", d.Name) }

// Get returns the sql db connection for the database.
// Get returns the SQL DB connection for the database.
func (d Database) Get(ctx context.Context) *sql.DB {
provider := modulecontext.FromContext(ctx).CurrentContext()
db, err := provider.GetDatabase(d.Name, d.DBType)
db, err := d.db.Get(ctx)
if err != nil {
panic(err.Error())
panic(err)
}
return db
}
7 changes: 0 additions & 7 deletions internal/modulecontext/database.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package modulecontext

import (
"database/sql"
"fmt"
"strconv"

Expand All @@ -14,19 +13,13 @@ type Database struct {
DSN string
DBType DBType
isTestDB bool
db *sql.DB
}

// NewDatabase creates a Database that can be added to ModuleContext
func NewDatabase(dbType DBType, dsn string) (Database, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return Database{}, fmt.Errorf("failed to bring up DB connection: %w", err)
}
return Database{
DSN: dsn,
DBType: dbType,
db: db,
}, nil
}

Expand Down
18 changes: 9 additions & 9 deletions internal/modulecontext/module_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package modulecontext

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -130,22 +129,23 @@ func (m ModuleContext) GetSecret(name string, value any) error {
return json.Unmarshal(data, value)
}

// GetDatabase gets a database connection
// GetDatabase gets a database DSN by name and type.
//
// Returns an error if no database with that name is found or it is not the expected type
// When in a testing context (via ftltest), an error is returned if the database is not a test database
func (m ModuleContext) GetDatabase(name string, dbType DBType) (*sql.DB, error) {
// Returns an error if no database with that name is found or it is not the
// expected type. When in a testing context (via ftltest), an error is returned
// if the database is not a test database.
func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, error) {
db, ok := m.databases[name]
if !ok {
return nil, fmt.Errorf("missing DSN for database %s", name)
return "", fmt.Errorf("missing DSN for database %s", name)
}
if db.DBType != dbType {
return nil, fmt.Errorf("database %s does not match expected type of %s", name, dbType)
return "", fmt.Errorf("database %s does not match expected type of %s", name, dbType)
}
if m.isTesting && !db.isTestDB {
return nil, fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name)
return "", fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name)
}
return db.db, nil
return db.DSN, nil
}

// LeaseClient is the interface for acquiring, heartbeating and releasing leases
Expand Down

0 comments on commit 17900c6

Please sign in to comment.