diff --git a/go-runtime/ftl/database.go b/go-runtime/ftl/database.go index d23da37064..87465db955 100644 --- a/go-runtime/ftl/database.go +++ b/go-runtime/ftl/database.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" + "github.com/alecthomas/types/once" _ "github.com/jackc/pgx/v5/stdlib" // Register Postgres driver "github.com/TBD54566975/ftl/internal/modulecontext" @@ -13,6 +14,8 @@ import ( type Database struct { Name string DBType modulecontext.DBType + + db *once.Handle[*sql.DB] } // PostgresDatabase returns a handler for the named database. @@ -20,17 +23,28 @@ func PostgresDatabase(name string) Database { return Database{ Name: name, DBType: modulecontext.DBTypePostgres, + db: once.Once(func(ctx context.Context) (*sql.DB, error) { + provider := modulecontext.FromContext(ctx).CurrentContext() + dsn, err := provider.GetDatabase(name, modulecontext.DBTypePostgres) + if err != nil { + return nil, fmt.Errorf("failed to get database %q: %w", name, err) + } + db, err := sql.Open("pgx", dsn) + if err != nil { + return nil, fmt.Errorf("failed to open database %q: %w", name, err) + } + return db, nil + }), } } func (d Database) String() string { return fmt.Sprintf("database %q", d.Name) } -// Get returns the sql db connection for the database. +// Get returns the SQL DB connection for the database. func (d Database) Get(ctx context.Context) *sql.DB { - provider := modulecontext.FromContext(ctx).CurrentContext() - db, err := provider.GetDatabase(d.Name, d.DBType) + db, err := d.db.Get(ctx) if err != nil { - panic(err.Error()) + panic(err) } return db } diff --git a/internal/modulecontext/database.go b/internal/modulecontext/database.go index fef41e6f54..c69d5c433b 100644 --- a/internal/modulecontext/database.go +++ b/internal/modulecontext/database.go @@ -1,7 +1,6 @@ package modulecontext import ( - "database/sql" "fmt" "strconv" @@ -14,19 +13,13 @@ type Database struct { DSN string DBType DBType isTestDB bool - db *sql.DB } // NewDatabase creates a Database that can be added to ModuleContext func NewDatabase(dbType DBType, dsn string) (Database, error) { - db, err := sql.Open("pgx", dsn) - if err != nil { - return Database{}, fmt.Errorf("failed to bring up DB connection: %w", err) - } return Database{ DSN: dsn, DBType: dbType, - db: db, }, nil } diff --git a/internal/modulecontext/module_context.go b/internal/modulecontext/module_context.go index a8c1522c90..57985929df 100644 --- a/internal/modulecontext/module_context.go +++ b/internal/modulecontext/module_context.go @@ -2,7 +2,6 @@ package modulecontext import ( "context" - "database/sql" "encoding/json" "errors" "fmt" @@ -130,22 +129,23 @@ func (m ModuleContext) GetSecret(name string, value any) error { return json.Unmarshal(data, value) } -// GetDatabase gets a database connection +// GetDatabase gets a database DSN by name and type. // -// Returns an error if no database with that name is found or it is not the expected type -// When in a testing context (via ftltest), an error is returned if the database is not a test database -func (m ModuleContext) GetDatabase(name string, dbType DBType) (*sql.DB, error) { +// Returns an error if no database with that name is found or it is not the +// expected type. When in a testing context (via ftltest), an error is returned +// if the database is not a test database. +func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, error) { db, ok := m.databases[name] if !ok { - return nil, fmt.Errorf("missing DSN for database %s", name) + return "", fmt.Errorf("missing DSN for database %s", name) } if db.DBType != dbType { - return nil, fmt.Errorf("database %s does not match expected type of %s", name, dbType) + return "", fmt.Errorf("database %s does not match expected type of %s", name, dbType) } if m.isTesting && !db.isTestDB { - return nil, fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name) + return "", fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name) } - return db.db, nil + return db.DSN, nil } // LeaseClient is the interface for acquiring, heartbeating and releasing leases