Skip to content

Commit

Permalink
feat: prevent db access to dev dbs in tests (#1462)
Browse files Browse the repository at this point in the history
fixes #1460
Issue says to panic, but in the end I just returned an error

Example error message:
> accessing non-test database "oidcauth" while testing: try adding
ftltest.WithDatabase(db) as an option with ftltest.Context(...)
  • Loading branch information
matt2e authored May 10, 2024
1 parent e1b3c41 commit 3066d64
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 41 deletions.
2 changes: 1 addition & 1 deletion go-runtime/ftl/ftltest/ftltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func WithDatabase(dbHandle ftl.Database) Option {
}

// replace original database with test database
replacementDB, err := modulecontext.NewDatabase(modulecontext.DBTypePostgres, dsn)
replacementDB, err := modulecontext.NewTestDatabase(modulecontext.DBTypePostgres, dsn)
if err != nil {
return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err)
}
Expand Down
58 changes: 58 additions & 0 deletions internal/modulecontext/database.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package modulecontext

import (
"database/sql"
"fmt"
"strconv"

ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
)

// Database represents a database connection based on a DSN
// It holds a private field for the database which is accessible through moduleCtx.GetDatabase(name)
type Database struct {
DSN string
DBType DBType
isTestDB bool
db *sql.DB
}

// NewDatabase creates a Database that can be added to ModuleContext
func NewDatabase(dbType DBType, dsn string) (Database, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return Database{}, err
}
return Database{
DSN: dsn,
DBType: dbType,
db: db,
}, nil
}

// NewTestDatabase creates a Database that can be added to ModuleContext
//
// Test databases can be used within module tests
func NewTestDatabase(dbType DBType, dsn string) (Database, error) {
db, err := NewDatabase(dbType, dsn)
if err != nil {
return Database{}, err
}
db.isTestDB = true
return db, nil
}

type DBType ftlv1.ModuleContextResponse_DBType

const (
DBTypePostgres = DBType(ftlv1.ModuleContextResponse_POSTGRES)
)

func (x DBType) String() string {
switch x {
case DBTypePostgres:
return "Postgres"
default:
panic(fmt.Sprintf("unknown DB type: %s", strconv.Itoa(int(x))))
}
}
44 changes: 4 additions & 40 deletions internal/modulecontext/module_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,55 +5,15 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strconv"
"strings"

"github.com/alecthomas/types/optional"
_ "github.com/jackc/pgx/v5/stdlib" // SQL driver

ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/reflect"
)

// Database represents a database connection based on a DSN
//
// It holds a private field for the database which is accessible through moduleCtx.GetDatabase(name)
type Database struct {
DSN string
DBType DBType

db *sql.DB
}

// NewDatabase creates a Database that can be added to ModuleContext
func NewDatabase(dbType DBType, dsn string) (Database, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return Database{}, err
}
return Database{
DSN: dsn,
DBType: dbType,
db: db,
}, nil
}

type DBType ftlv1.ModuleContextResponse_DBType

const (
DBTypePostgres = DBType(ftlv1.ModuleContextResponse_POSTGRES)
)

func (x DBType) String() string {
switch x {
case DBTypePostgres:
return "Postgres"
default:
panic(fmt.Sprintf("unknown DB type: %s", strconv.Itoa(int(x))))
}
}

// Verb is a function that takes a request and returns a response but is not constrained by request/response type like ftl.Verb
//
// It is used for definitions of mock verbs as well as real implementations of verbs to directly execute
Expand Down Expand Up @@ -166,6 +126,7 @@ func (m ModuleContext) GetSecret(name string, value any) error {
// GetDatabase gets a database connection
//
// Returns an error if no database with that name is found or it is not the expected type
// When in a testing context (via ftltest), an error is returned if the database is not a test database
func (m ModuleContext) GetDatabase(name string, dbType DBType) (*sql.DB, error) {
db, ok := m.databases[name]
if !ok {
Expand All @@ -174,6 +135,9 @@ func (m ModuleContext) GetDatabase(name string, dbType DBType) (*sql.DB, error)
if db.DBType != dbType {
return nil, fmt.Errorf("database %s does not match expected type of %s", name, dbType)
}
if m.isTesting && !db.isTestDB {
return nil, fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name)
}
return db.db, nil
}

Expand Down

0 comments on commit 3066d64

Please sign in to comment.