diff --git a/backend/controller/sql/testdata/go/database/database.go b/backend/controller/sql/testdata/go/database/database.go index 692319a411..84672f56f2 100644 --- a/backend/controller/sql/testdata/go/database/database.go +++ b/backend/controller/sql/testdata/go/database/database.go @@ -6,7 +6,7 @@ import ( "github.com/TBD54566975/ftl/go-runtime/ftl" // Import the FTL SDK. ) -var db = ftl.PostgresDatabase("testdb") +type testdb = ftl.PostgresDatabaseHandle type InsertRequest struct { Data string @@ -15,8 +15,8 @@ type InsertRequest struct { type InsertResponse struct{} //ftl:verb -func Insert(ctx context.Context, req InsertRequest) (InsertResponse, error) { - err := persistRequest(ctx, req) +func Insert(ctx context.Context, req InsertRequest, db testdb) (InsertResponse, error) { + err := persistRequest(ctx, req, db) if err != nil { return InsertResponse{}, err } @@ -24,7 +24,7 @@ func Insert(ctx context.Context, req InsertRequest) (InsertResponse, error) { return InsertResponse{}, nil } -func persistRequest(ctx context.Context, req InsertRequest) error { +func persistRequest(ctx context.Context, req InsertRequest, db testdb) error { _, err := db.Get(ctx).Exec(`CREATE TABLE IF NOT EXISTS requests ( data TEXT, diff --git a/backend/controller/sql/testdata/go/database/database_test.go b/backend/controller/sql/testdata/go/database/database_test.go index d5b026600f..97f7b8d653 100644 --- a/backend/controller/sql/testdata/go/database/database_test.go +++ b/backend/controller/sql/testdata/go/database/database_test.go @@ -11,22 +11,26 @@ import ( func TestDatabase(t *testing.T) { ctx := ftltest.Context( + ftltest.WithCallsAllowedWithinModule(), ftltest.WithProjectFile("ftl-project.toml"), - ftltest.WithDatabase(db), + ftltest.WithDatabases(), ) - Insert(ctx, InsertRequest{Data: "unit test 1"}) + _, err := ftltest.Call[InsertClient, InsertRequest, InsertResponse](ctx, InsertRequest{Data: "unit test 1"}) + assert.NoError(t, err) list, err := getAll(ctx) assert.NoError(t, err) assert.Equal(t, 1, len(list)) assert.Equal(t, "unit test 1", list[0]) ctx = ftltest.Context( + ftltest.WithCallsAllowedWithinModule(), ftltest.WithProjectFile("ftl-project.toml"), - ftltest.WithDatabase(db), + ftltest.WithDatabases(), ) - Insert(ctx, InsertRequest{Data: "unit test 2"}) + _, err = ftltest.Call[InsertClient, InsertRequest, InsertResponse](ctx, InsertRequest{Data: "unit test 2"}) + assert.NoError(t, err) list, err = getAll(ctx) assert.NoError(t, err) assert.Equal(t, 1, len(list)) @@ -35,18 +39,35 @@ func TestDatabase(t *testing.T) { func TestOptionOrdering(t *testing.T) { ctx := ftltest.Context( - ftltest.WithDatabase(db), // <--- consumes DSNs + ftltest.WithCallsAllowedWithinModule(), + ftltest.WithDatabases(), // <--- consumes DSNs ftltest.WithProjectFile("ftl-project.toml"), // <--- provides DSNs ) - Insert(ctx, InsertRequest{Data: "unit test 1"}) + _, err := ftltest.Call[InsertClient, InsertRequest, InsertResponse](ctx, InsertRequest{Data: "unit test 1"}) + assert.NoError(t, err) list, err := getAll(ctx) assert.NoError(t, err) assert.Equal(t, 1, len(list)) assert.Equal(t, "unit test 1", list[0]) } +func TestWrongDbNameFetch(t *testing.T) { + ctx := ftltest.Context( + ftltest.WithCallsAllowedWithinModule(), + ftltest.WithProjectFile("ftl-project.toml"), + ftltest.WithDatabases(), + ) + + _, err := ftltest.GetDatabaseHandle(ctx, "Testdb") + assert.Error(t, err, `could not find database "Testdb"; did you mean "testdb"?`) +} + func getAll(ctx context.Context) ([]string, error) { + db, err := ftltest.GetDatabaseHandle(ctx, "testdb") + if err != nil { + return nil, err + } rows, err := db.Get(ctx).Query("SELECT data FROM requests ORDER BY created_at;") if err != nil { return nil, err diff --git a/backend/controller/sql/testdata/go/database/types.ftl.go b/backend/controller/sql/testdata/go/database/types.ftl.go new file mode 100644 index 0000000000..3bd03b26d0 --- /dev/null +++ b/backend/controller/sql/testdata/go/database/types.ftl.go @@ -0,0 +1,20 @@ +// Code generated by FTL. DO NOT EDIT. +package database + +import ( + "context" + "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" + "github.com/TBD54566975/ftl/go-runtime/server" +) + +type InsertClient func(context.Context, InsertRequest) (InsertResponse, error) + +func init() { + reflection.Register( + reflection.Database("database", "testdb", server.InitPostgres), + reflection.ProvideResourcesForVerb( + Insert, + server.PostgresDatabase("database", "testdb"), + ), + ) +} diff --git a/backend/provisioner/testdata/go/echo/echo.go b/backend/provisioner/testdata/go/echo/echo.go index 8e5afd05e1..0b675fe13d 100644 --- a/backend/provisioner/testdata/go/echo/echo.go +++ b/backend/provisioner/testdata/go/echo/echo.go @@ -8,7 +8,7 @@ import ( "github.com/TBD54566975/ftl/go-runtime/ftl" ) -var db = ftl.PostgresDatabase("echodb") +var EchoDb = ftl.PostgresDatabaseHandle // Echo returns a greeting with the current time. // diff --git a/docs/content/docs/reference/unittests.md b/docs/content/docs/reference/unittests.md index 16641a3152..b74590afb5 100644 --- a/docs/content/docs/reference/unittests.md +++ b/docs/content/docs/reference/unittests.md @@ -62,16 +62,19 @@ ctx := ftltest.Context( ### Databases By default, calling `Get(ctx)` on a database panics. -To enable database access in a test, you must first [provide a DSN via a project file](#project-files-configs-and-secrets). You can then set up a test database: +To enable database access in a test, you must first [provide a DSN via a project file](#project-files-configs-and-secrets). + +You can then opt for `WithDatabases()` in your context, and all databases declared in your module will be +automatically provided for tests. ```go ctx := ftltest.Context( ftltest.WithDefaultProjectFile(), - ftltest.WithDatabase(db), + ftltest.WithDatabases(), ) ``` -This will: -- Take the provided DSN and appends `_test` to the database name. Eg: `accounts` becomes `accounts_test` -- Wipe all tables in the database so each test run happens on a clean database +Note: +- Database names from the provided DSNs will be appended with `_test`. Eg: `accounts` becomes `accounts_test` +- All tables in the database are wiped between tests, so each test run happens on a clean database ### Maps diff --git a/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl b/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl index 7febb945a2..bd34f18425 100644 --- a/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl +++ b/go-runtime/compile/build-template/.ftl.tmpl/go/main/main.go.tmpl @@ -1,4 +1,5 @@ {{- $verbs := .Verbs -}} +{{- $dbs := .Databases -}} {{- $name := .Name -}} {{- with .MainCtx -}} @@ -24,6 +25,11 @@ func init() { {{- range .ExternalTypes}} reflection.ExternalType(*new({{.TypeName}})), {{- end}} +{{- range $dbs}} + {{- if eq .Type "postgres" }} + reflection.Database("{{.Module}}", "{{.Name}}", server.InitPostgres), + {{- end }} +{{- end}} {{- range $verbs}} reflection.ProvideResourcesForVerb( {{.TypeName}}, @@ -38,6 +44,11 @@ func init() { {{- else }} server.VerbClient[{{.TypeName}}, {{.Request.TypeName}}, {{.Response.TypeName}}](), {{- end -}} + {{- end }} + {{- with getDatabaseHandle . }} + {{- if eq .Type "postgres" }} + server.PostgresDatabase("{{.Module}}", "{{.Name}}"), + {{- end }} {{- end }} {{- end}} ), diff --git a/go-runtime/compile/build-template/types.ftl.go.tmpl b/go-runtime/compile/build-template/types.ftl.go.tmpl index 4d30ea5267..494d1d3501 100644 --- a/go-runtime/compile/build-template/types.ftl.go.tmpl +++ b/go-runtime/compile/build-template/types.ftl.go.tmpl @@ -1,4 +1,5 @@ {{- $verbs := .Verbs -}} +{{- $dbs := .Databases -}} {{- $name := .Name -}} {{- with .TypesCtx -}} {{- $moduleName := .MainModulePkg -}} @@ -42,6 +43,11 @@ func init() { {{- range .ExternalTypes}} reflection.ExternalType(*new({{.TypeName}})), {{- end}} +{{- range $dbs}} + {{- if eq .Type "postgres" }} + reflection.Database("{{.Module}}", "{{.Name}}", server.InitPostgres), + {{- end }} +{{- end}} {{- range $verbs}} reflection.ProvideResourcesForVerb( {{ trimModuleQualifier $moduleName .TypeName }}, @@ -58,6 +64,11 @@ func init() { {{- else }} server.VerbClient[{{$verb}}, {{.Request.LocalTypeName}}, {{.Response.LocalTypeName}}](), {{- end }} + {{- end }} + {{- with getDatabaseHandle . }} + {{- if eq .Type "postgres" }} + server.PostgresDatabase("{{.Module}}", "{{.Name}}"), + {{- end }} {{- end }} {{- end}} ), diff --git a/go-runtime/compile/build.go b/go-runtime/compile/build.go index fcf8e67250..3b76c6f72a 100644 --- a/go-runtime/compile/build.go +++ b/go-runtime/compile/build.go @@ -13,6 +13,7 @@ import ( "strings" "unicode" + "github.com/TBD54566975/ftl/go-runtime/schema/common" "github.com/TBD54566975/scaffolder" "github.com/alecthomas/types/optional" sets "github.com/deckarep/golang-set/v2" @@ -56,6 +57,7 @@ type mainModuleContext struct { Name string SharedModulesPaths []string Verbs []goVerb + Databases []goDBHandle Replacements []*modfile.Replace MainCtx mainFileContext TypesCtx typesFileContext @@ -99,6 +101,9 @@ func (c *mainModuleContext) generateTypesImports(mainModuleImport string) []stri if len(c.Verbs) > 0 { imports.Add(`"context"`) } + if len(c.Databases) > 0 { + imports.Add(`"github.com/TBD54566975/ftl/go-runtime/server"`) + } for _, st := range c.TypesCtx.SumTypes { imports.Add(st.importStatement()) for _, v := range st.Variants { @@ -249,6 +254,20 @@ type verbClient struct { func (v verbClient) resource() {} +type goDBHandle struct { + Type string + Name string + Module string + + nativeType +} + +func (d goDBHandle) resource() {} + +func (d goDBHandle) getNativeType() nativeType { + return d.nativeType +} + type ModifyFilesTransaction interface { Begin() error ModifiedFiles(paths ...string) error @@ -571,6 +590,7 @@ func (b *mainModuleContextBuilder) build(goModVersion, ftlVersion, projectName s SharedModulesPaths: sharedModulesPaths, Replacements: replacements, Verbs: make([]goVerb, 0, len(b.mainModule.Decls)), + Databases: make([]goDBHandle, 0, len(b.mainModule.Decls)), MainCtx: mainFileContext{ ProjectName: projectName, SumTypes: []goSumType{}, @@ -653,6 +673,8 @@ func (b *mainModuleContextBuilder) visit( case goExternalType: ctx.TypesCtx.ExternalTypes = append(ctx.TypesCtx.ExternalTypes, n) ctx.MainCtx.ExternalTypes = append(ctx.MainCtx.ExternalTypes, n) + case goDBHandle: + ctx.Databases = append(ctx.Databases, n) } return next() }) @@ -690,6 +712,15 @@ func (b *mainModuleContextBuilder) getGoType(module *schema.Module, node schema. return optional.None[goType](), isLocal, nil } return b.processExternalTypeAlias(n), isLocal, nil + case *schema.Database: + if !isLocal { + return optional.None[goType](), false, nil + } + dbHandle, err := b.processDatabase(module.Name, n) + if err != nil { + return optional.None[goType](), isLocal, err + } + return optional.Some[goType](dbHandle), isLocal, nil default: } @@ -788,6 +819,26 @@ func (b *mainModuleContextBuilder) processVerb(verb *schema.Verb) (goVerb, error calleeverb, }) } + case *schema.MetadataDatabases: + for _, call := range md.Calls { + resolved, ok := b.sch.Resolve(call).Get() + if !ok { + return goVerb{}, fmt.Errorf("failed to resolve %s database, used by %s.%s", call, + b.mainModule.Name, verb.Name) + } + db, ok := resolved.(*schema.Database) + if !ok { + return goVerb{}, fmt.Errorf("%s.%s uses %s database handle, but %s is not a database", + b.mainModule.Name, verb.Name, call, call) + } + + dbHandle, err := b.processDatabase(call.Module, db) + if err != nil { + return goVerb{}, err + } + resources = append(resources, dbHandle) + } + default: // TODO: implement other resources } @@ -800,6 +851,19 @@ func (b *mainModuleContextBuilder) processVerb(verb *schema.Verb) (goVerb, error return b.getGoVerb(nativeName, verb, resources...) } +func (b *mainModuleContextBuilder) processDatabase(moduleName string, db *schema.Database) (goDBHandle, error) { + nt, err := b.getNativeType(common.FtlPostgresDBTypePath) + if err != nil { + return goDBHandle{}, err + } + return goDBHandle{ + Name: db.Name, + Module: moduleName, + Type: db.Type, + nativeType: nt, + }, nil +} + func (b *mainModuleContextBuilder) getGoVerb(nativeName string, verb *schema.Verb, resources ...verbResource) (goVerb, error) { nt, err := b.getNativeType(nativeName) if err != nil { @@ -985,6 +1049,12 @@ var scaffoldFuncs = scaffolder.FuncMap{ } return nil }, + "getDatabaseHandle": func(resource verbResource) *goDBHandle { + if c, ok := resource.(goDBHandle); ok { + return &c + } + return nil + }, } // returns the import path and the directory name for a type alias if there is an associated go library diff --git a/go-runtime/compile/testdata/go/one/one.go b/go-runtime/compile/testdata/go/one/one.go index 2e07d9466d..388a15cda2 100644 --- a/go-runtime/compile/testdata/go/one/one.go +++ b/go-runtime/compile/testdata/go/one/one.go @@ -133,7 +133,8 @@ type ExportedData struct { var configValue = ftl.Config[Config]("configValue") var secretValue = ftl.Secret[string]("secretValue") -var testDb = ftl.PostgresDatabase("testDb") + +type testDb = ftl.PostgresDatabaseHandle //ftl:verb func Verb(ctx context.Context, req Req) (Resp, error) { diff --git a/go-runtime/ftl/database.go b/go-runtime/ftl/database.go index c0c6717654..cc99c8303b 100644 --- a/go-runtime/ftl/database.go +++ b/go-runtime/ftl/database.go @@ -4,61 +4,35 @@ import ( "context" "database/sql" "fmt" - "time" - "github.com/XSAM/otelsql" + "github.com/TBD54566975/ftl/internal/modulecontext" "github.com/alecthomas/types/once" _ "github.com/jackc/pgx/v5/stdlib" // Register Postgres driver - "go.opentelemetry.io/otel/attribute" - semconv "go.opentelemetry.io/otel/semconv/v1.4.0" - - "github.com/TBD54566975/ftl/internal/modulecontext" ) -type Database struct { - Name string - DBType modulecontext.DBType - - db *once.Handle[*sql.DB] +type DatabaseHandle interface { + Name() string + DBType() modulecontext.DBType + Get(ctx context.Context) *sql.DB + String() string } -// PostgresDatabase returns a handler for the named database. -func PostgresDatabase(name string) Database { - return Database{ - Name: name, - DBType: modulecontext.DBTypePostgres, - db: once.Once(func(ctx context.Context) (*sql.DB, error) { - provider := modulecontext.FromContext(ctx).CurrentContext() - dsn, err := provider.GetDatabase(name, modulecontext.DBTypePostgres) - if err != nil { - return nil, fmt.Errorf("failed to get database %q: %w", name, err) - } - db, err := otelsql.Open("pgx", dsn) - if err != nil { - return nil, fmt.Errorf("failed to open database %q: %w", name, err) - } +type PostgresDatabaseHandle struct { + name string + db *once.Handle[*sql.DB] +} - // sets db.system and db.name attributes - metricAttrs := otelsql.WithAttributes( - semconv.DBSystemPostgreSQL, - semconv.DBNameKey.String(name), - attribute.Bool("ftl.is_user_service", true), - ) - err = otelsql.RegisterDBStatsMetrics(db, metricAttrs) - if err != nil { - return nil, fmt.Errorf("failed to register database metrics: %w", err) - } - db.SetConnMaxIdleTime(time.Minute) - db.SetMaxOpenConns(20) - return db, nil - }), - } +// NewPostgresDatabaseHandle is managed by FTL. +func NewPostgresDatabaseHandle(name string, db *once.Handle[*sql.DB]) PostgresDatabaseHandle { + return PostgresDatabaseHandle{name: name, db: db} } -func (d Database) String() string { return fmt.Sprintf("database %q", d.Name) } +func (d PostgresDatabaseHandle) Name() string { return d.name } +func (d PostgresDatabaseHandle) DBType() modulecontext.DBType { return modulecontext.DBTypePostgres } +func (d PostgresDatabaseHandle) String() string { return fmt.Sprintf("database %q", d.name) } // Get returns the SQL DB connection for the database. -func (d Database) Get(ctx context.Context) *sql.DB { +func (d PostgresDatabaseHandle) Get(ctx context.Context) *sql.DB { db, err := d.db.Get(ctx) if err != nil { panic(err) diff --git a/go-runtime/ftl/ftltest/ftltest.go b/go-runtime/ftl/ftltest/ftltest.go index 748890830d..d272b6bb92 100644 --- a/go-runtime/ftl/ftltest/ftltest.go +++ b/go-runtime/ftl/ftltest/ftltest.go @@ -12,6 +12,7 @@ import ( "sort" "strings" + "github.com/TBD54566975/ftl/internal/schema/strcase" _ "github.com/jackc/pgx/v5/stdlib" // SQL driver "github.com/TBD54566975/ftl/go-runtime/ftl" @@ -33,6 +34,7 @@ type OptionsState struct { databases map[string]modulecontext.Database mockVerbs map[schema.RefKey]modulecontext.Verb allowDirectVerbBehavior bool + provideDatabases bool } type optionRank int @@ -73,6 +75,15 @@ func newContext(ctx context.Context, module string, options ...Option) context.C } } + if state.provideDatabases { + for _, db := range reflection.GetDatabases(module) { + err := withDatabase(db.Name).apply(ctx, state) + if err != nil { + panic(fmt.Sprintf("error applying option: %v", err)) + } + } + } + builder := modulecontext.NewBuilder(module).AddDatabases(state.databases) builder = builder.UpdateForTesting(state.mockVerbs, state.allowDirectVerbBehavior, newFakeLeaseClient()) @@ -211,63 +222,20 @@ func WithSecret[T ftl.SecretType](secret ftl.SecretValue[T], value T) Option { } } -// WithDatabase sets up a database for testing by appending "_test" to the DSN and emptying all tables +// WithDatabases sets up all databases declared in the current module for testing by appending "_test" to the DSN and +// emptying all tables // // To be used when setting up a context for a test: // // ctx := ftltest.Context( -// ftltest.WithDatabase(db), +// ftltest.WithDatabases(), // // ... other options // ) -func WithDatabase(dbHandle ftl.Database) Option { +func WithDatabases() Option { return Option{ rank: other, apply: func(ctx context.Context, state *OptionsState) error { - fftl := internal.FromContext(ctx) - originalDSN, err := getDSNFromSecret(fftl, moduleGetter(), dbHandle.Name) - if err != nil { - return err - } - - // convert DSN by appending "_test" to table name - // postgres DSN format: postgresql://[user[:password]@][netloc][:port][/dbname][?param1=value1&...] - // source: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING - dsnURL, err := url.Parse(originalDSN) - if err != nil { - return fmt.Errorf("could not parse DSN: %w", err) - } - if dsnURL.Path == "" { - return fmt.Errorf("DSN for %s must include table name: %s", dbHandle.Name, originalDSN) - } - dsnURL.Path += "_test" - dsn := dsnURL.String() - - // connect to db and clear out the contents of each table - sqlDB, err := sql.Open("pgx", dsn) - if err != nil { - return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err) - } - _, err = sqlDB.ExecContext(ctx, `DO $$ - DECLARE - table_name text; - BEGIN - FOR table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') - LOOP - EXECUTE 'ALTER TABLE ' || quote_ident(table_name) || ' DISABLE TRIGGER ALL;'; - EXECUTE 'DELETE FROM ' || quote_ident(table_name) || ';'; - EXECUTE 'ALTER TABLE ' || quote_ident(table_name) || ' ENABLE TRIGGER ALL;'; - END LOOP; - END $$;`) - if err != nil { - return fmt.Errorf("could not clear tables in database %q: %w", dbHandle.Name, err) - } - - // replace original database with test database - replacementDB, err := modulecontext.NewTestDatabase(modulecontext.DBTypePostgres, dsn) - if err != nil { - return fmt.Errorf("could not create database %q with DSN %q: %w", dbHandle.Name, dsn, err) - } - state.databases[dbHandle.Name] = replacementDB + state.provideDatabases = true return nil }, } @@ -521,9 +489,39 @@ func CallEmpty[VerbClient any](ctx context.Context) error { return err } +// GetDatabaseHandle returns a database handle for the given database name. Database names are transformed +// to lower camel when provisioned and represented in the schema. +// +// E.g. +// +// type MyDatabase = ftl.PostgresDatabase +// +// can be referenced as "myDatabase" +func GetDatabaseHandle(ctx context.Context, dbname string) (ftl.DatabaseHandle, error) { + var reflectedDB *reflection.ReflectedDatabaseHandle + var containsCaseTransformedDBName bool + for _, db := range reflection.GetDatabases(moduleGetter()) { + if db.Name == dbname { + reflectedDB = reflection.GetDatabase(reflection.Ref{moduleGetter(), dbname}) + } + if db.Name == strcase.ToLowerCamel(dbname) { + containsCaseTransformedDBName = true + } + } + if reflectedDB == nil { + suffix := "" + if containsCaseTransformedDBName { + suffix = fmt.Sprintf("; did you mean %q?", strcase.ToLowerCamel(dbname)) + } + return nil, fmt.Errorf("could not find database %q%s", dbname, suffix) + } + + return ftl.NewPostgresDatabaseHandle(reflectedDB.Name, reflectedDB.DB), nil +} + func call[VerbClient, Req, Resp any](ctx context.Context, req Req) (resp Resp, err error) { ref := reflection.ClientRef[VerbClient]() - inline := server.Call[Req, Resp](ref) + inline := server.InvokeVerb[Req, Resp](ref) moduleCtx := modulecontext.FromContext(ctx).CurrentContext() override, err := moduleCtx.BehaviorForVerb(schema.Ref{Module: ref.Module, Name: ref.Name}) if err != nil { @@ -551,3 +549,58 @@ func widenVerb[Req, Resp any](verb ftl.Verb[Req, Resp]) ftl.Verb[any, any] { return verb(ctx, req) } } + +// withDatabase sets up a database for testing by appending "_test" to the DSN and emptying all tables +func withDatabase(name string) Option { + return Option{ + rank: other, + apply: func(ctx context.Context, state *OptionsState) error { + fftl := internal.FromContext(ctx) + originalDSN, err := getDSNFromSecret(fftl, moduleGetter(), name) + if err != nil { + return fmt.Errorf("could not get DSN for database %q, try adding ftltest.WithProject(db) as an option with ftltest.Context(...): %w", name, err) + } + + // convert DSN by appending "_test" to table name + // postgres DSN format: postgresql://[user[:password]@][netloc][:port][/dbname][?param1=value1&...] + // source: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING + dsnURL, err := url.Parse(originalDSN) + if err != nil { + return fmt.Errorf("could not parse DSN: %w", err) + } + if dsnURL.Path == "" { + return fmt.Errorf("DSN for %s must include table name: %s", name, originalDSN) + } + dsnURL.Path += "_test" + dsn := dsnURL.String() + + // connect to db and clear out the contents of each table + sqlDB, err := sql.Open("pgx", dsn) + if err != nil { + return fmt.Errorf("could not create database %q with DSN %q: %w", name, dsn, err) + } + _, err = sqlDB.ExecContext(ctx, `DO $$ + DECLARE + table_name text; + BEGIN + FOR table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') + LOOP + EXECUTE 'ALTER TABLE ' || quote_ident(table_name) || ' DISABLE TRIGGER ALL;'; + EXECUTE 'DELETE FROM ' || quote_ident(table_name) || ';'; + EXECUTE 'ALTER TABLE ' || quote_ident(table_name) || ' ENABLE TRIGGER ALL;'; + END LOOP; + END $$;`) + if err != nil { + return fmt.Errorf("could not clear tables in database %q: %w", name, err) + } + + // replace original database with test database + replacementDB, err := modulecontext.NewTestDatabase(modulecontext.DBTypePostgres, dsn) + if err != nil { + return fmt.Errorf("could not create database %q with DSN %q: %w", name, dsn, err) + } + state.databases[name] = replacementDB + return nil + }, + } +} diff --git a/go-runtime/ftl/reflection/database.go b/go-runtime/ftl/reflection/database.go new file mode 100644 index 0000000000..00d4548e6d --- /dev/null +++ b/go-runtime/ftl/reflection/database.go @@ -0,0 +1,29 @@ +package reflection + +import ( + "database/sql" + + "github.com/alecthomas/types/once" +) + +type DBType int + +const ( + DBTypePostgres DBType = iota +) + +type ReflectedDatabaseHandle struct { + Name string + DBType DBType + DB *once.Handle[*sql.DB] +} + +func Database(module, name string, init func(ref Ref) *ReflectedDatabaseHandle) Registree { + // databases are declared as a true alias, e.g. `type MyDatabase = ftl.PostgresDatabaseHandle`, so we can't + // derive the ref from the reflected type (the reflected type will always be ftl.PostgresDatabaseHandle). + // Passing module/name explicitly instead. + ref := Ref{Module: module, Name: name} + return func(t *TypeRegistry) { + t.databases[ref] = init(ref) + } +} diff --git a/go-runtime/ftl/reflection/singleton.go b/go-runtime/ftl/reflection/singleton.go index 8fecf751ed..87f6d44f0a 100644 --- a/go-runtime/ftl/reflection/singleton.go +++ b/go-runtime/ftl/reflection/singleton.go @@ -39,6 +39,21 @@ func GetVariantByName(discriminator reflect.Type, name string) optional.Option[r return singletonTypeRegistry.getVariantByName(discriminator, name) } +func GetDatabase(ref Ref) *ReflectedDatabaseHandle { + return singletonTypeRegistry.databases[ref] +} + +func GetDatabases(module string) []*ReflectedDatabaseHandle { + var databases []*ReflectedDatabaseHandle + for ref, db := range singletonTypeRegistry.databases { + if ref.Module == module { + databases = append(databases, db) + } + + } + return databases +} + // GetDiscriminatorByVariant returns the discriminator type for the given variant type. func GetDiscriminatorByVariant(variant reflect.Type) optional.Option[reflect.Type] { return singletonTypeRegistry.getDiscriminatorByVariant(variant) diff --git a/go-runtime/ftl/reflection/type_registry.go b/go-runtime/ftl/reflection/type_registry.go index f74708c6c8..f5af97a5c0 100644 --- a/go-runtime/ftl/reflection/type_registry.go +++ b/go-runtime/ftl/reflection/type_registry.go @@ -18,6 +18,7 @@ type TypeRegistry struct { fsm map[string]ReflectedFSM externalTypes map[reflect.Type]struct{} verbCalls map[Ref]verbCall + databases map[Ref]*ReflectedDatabaseHandle } type sumTypeVariant struct { @@ -73,6 +74,7 @@ func newTypeRegistry(options ...Registree) *TypeRegistry { fsm: map[string]ReflectedFSM{}, externalTypes: map[reflect.Type]struct{}{}, verbCalls: map[Ref]verbCall{}, + databases: map[Ref]*ReflectedDatabaseHandle{}, } for _, o := range options { o(t) diff --git a/go-runtime/schema/common/common.go b/go-runtime/schema/common/common.go index f81b3bef45..33bffc03f4 100644 --- a/go-runtime/schema/common/common.go +++ b/go-runtime/schema/common/common.go @@ -23,7 +23,8 @@ var ( // FtlUnitTypePath is the path to the FTL unit type. FtlUnitTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Unit" // FtlOptionTypePath is the path to the FTL option type. - FtlOptionTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Option" + FtlOptionTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Option" + FtlPostgresDBTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.PostgresDatabaseHandle" extractorRegistery = xsync.NewMapOf[reflect.Type, ExtractDeclFunc[schema.Decl, ast.Node]]() ) diff --git a/go-runtime/schema/database/analyzer.go b/go-runtime/schema/database/analyzer.go index 8ce7ed7011..a6834b1c55 100644 --- a/go-runtime/schema/database/analyzer.go +++ b/go-runtime/schema/database/analyzer.go @@ -4,6 +4,7 @@ import ( "go/ast" "go/types" + "github.com/TBD54566975/ftl/internal/schema/strcase" "github.com/TBD54566975/golang-tools/go/analysis" "github.com/alecthomas/types/optional" @@ -11,34 +12,29 @@ import ( "github.com/TBD54566975/ftl/internal/schema" ) -const ftlPostgresDBFuncPath = "github.com/TBD54566975/ftl/go-runtime/ftl.PostgresDatabase" +const ftlPostgresDBTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.PostgresDatabaseHandle" // Extractor extracts databases to the module schema. -var Extractor = common.NewCallDeclExtractor[*schema.Database]("database", Extract, ftlPostgresDBFuncPath) +var Extractor = common.NewResourceDeclExtractor[*schema.Database]("database", Extract, ftlPostgresDBTypePath) -func Extract(pass *analysis.Pass, obj types.Object, node *ast.GenDecl, callExpr *ast.CallExpr, - callPath string) optional.Option[*schema.Database] { +func Extract(pass *analysis.Pass, obj types.Object, node *ast.TypeSpec, typePath string) optional.Option[*schema.Database] { var comments []string if md, ok := common.GetFactForObject[*common.ExtractedMetadata](pass, obj).Get(); ok { comments = md.Comments } - if callPath == ftlPostgresDBFuncPath { - return extractDatabase(pass, callExpr, schema.PostgresDatabaseType, comments) + if typePath == ftlPostgresDBTypePath { + return extractDatabase(pass, node, schema.PostgresDatabaseType, comments) } return optional.None[*schema.Database]() } func extractDatabase( pass *analysis.Pass, - node *ast.CallExpr, + node *ast.TypeSpec, dbType string, comments []string, ) optional.Option[*schema.Database] { - name := common.ExtractStringLiteralArg(pass, node, 0) - if name == "" { - return optional.None[*schema.Database]() - } - + name := strcase.ToLowerCamel(node.Name.Name) if !schema.ValidateName(name) { common.Errorf(pass, node, "invalid database name %q", name) return optional.None[*schema.Database]() diff --git a/go-runtime/schema/schema_fuzz_test.go b/go-runtime/schema/schema_fuzz_test.go index 8f2a1fcc02..4406678262 100644 --- a/go-runtime/schema/schema_fuzz_test.go +++ b/go-runtime/schema/schema_fuzz_test.go @@ -139,7 +139,7 @@ func DataFunc(ctx context.Context, req Data) (Data, error) { } -var db = ftl.PostgresDatabase("testDb") +type testDb = ftl.PostgresDatabaseHandle ` + (func() string { if symbol == "int" || symbol == "string" { diff --git a/go-runtime/schema/testdata/failing/child/child.go b/go-runtime/schema/testdata/failing/child/child.go index e51e75175d..d9d7aec099 100644 --- a/go-runtime/schema/testdata/failing/child/child.go +++ b/go-runtime/schema/testdata/failing/child/child.go @@ -38,4 +38,4 @@ var duplConfig = ftl.Config[string]("FTL_CONFIG_ENDPOINT") var duplSecret = ftl.Secret[string]("FTL_SECRET_ENDPOINT") var duplicateDeclName = ftl.Config[string]("PrivateData") -var duplDB = ftl.PostgresDatabase("testDb") +var TestDb = ftl.PostgresDatabaseHandle diff --git a/go-runtime/schema/testdata/failing/failing.go b/go-runtime/schema/testdata/failing/failing.go index 85aabcf0c0..7057026745 100644 --- a/go-runtime/schema/testdata/failing/failing.go +++ b/go-runtime/schema/testdata/failing/failing.go @@ -17,8 +17,8 @@ var goodConfig = ftl.Config[string]("FTL_CONFIG_ENDPOINT") // var duplSecret = ftl.Secret[string]("FTL_ENDPOINT") var goodSecret = ftl.Secret[string]("FTL_SECRET_ENDPOINT") -// var duplDB = ftl.PostgresDatabase("testDb") -var goodDB = ftl.PostgresDatabase("testDb") +// type TestDb = ftl.PostgresDatabaseHandle +type TestDb = ftl.PostgresDatabaseHandle type Request struct { BadParam error diff --git a/go-runtime/schema/testdata/one/one.go b/go-runtime/schema/testdata/one/one.go index 56f02c0116..fd29e11ddf 100644 --- a/go-runtime/schema/testdata/one/one.go +++ b/go-runtime/schema/testdata/one/one.go @@ -134,7 +134,8 @@ type ExportedData struct { var configValue = ftl.Config[Config]("configValue") var secretValue = ftl.Secret[string]("secretValue") -var testDb = ftl.PostgresDatabase("testDb") + +type testDb = ftl.PostgresDatabaseHandle //ftl:verb func Verb(ctx context.Context, req Req) (Resp, error) { diff --git a/go-runtime/schema/verb/analyzer.go b/go-runtime/schema/verb/analyzer.go index 373fadecaa..4b15185b85 100644 --- a/go-runtime/schema/verb/analyzer.go +++ b/go-runtime/schema/verb/analyzer.go @@ -20,6 +20,7 @@ type resourceType int const ( none resourceType = iota verbClient + databaseHandle ) // Extractor extracts verbs to the module schema. @@ -57,6 +58,8 @@ func Extract(pass *analysis.Pass, node *ast.FuncDecl, obj types.Object) optional calleeRef.Name = strings.TrimSuffix(calleeRef.Name, "Client") verb.AddCall(calleeRef) common.MarkIncludeNativeName(pass, paramObj, calleeRef) + case databaseHandle: + verb.AddDatabase(getResourceRef(paramObj, pass, param)) } } }) { @@ -147,7 +150,12 @@ func getParamResourceType(paramObj types.Object) resourceType { if paramObj == nil { return none } - + if paramObj.Pkg() == nil { + return none + } + if paramObj.Pkg().Path()+"."+paramObj.Name() == common.FtlPostgresDBTypePath { + return databaseHandle + } switch t := paramObj.Type().(type) { case *types.Named: if _, ok := t.Underlying().(*types.Signature); !ok { @@ -155,6 +163,13 @@ func getParamResourceType(paramObj types.Object) resourceType { } return verbClient + case *types.Alias: + named, ok := t.Rhs().(*types.Named) + if !ok { + return none + } + return getParamResourceType(named.Obj()) + default: return none } diff --git a/go-runtime/server/database.go b/go-runtime/server/database.go new file mode 100644 index 0000000000..b35c882f19 --- /dev/null +++ b/go-runtime/server/database.go @@ -0,0 +1,58 @@ +package server + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "time" + + "github.com/XSAM/otelsql" + + "github.com/TBD54566975/ftl/go-runtime/ftl" + "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" + "github.com/TBD54566975/ftl/internal/modulecontext" + "github.com/alecthomas/types/once" + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" +) + +func PostgresDatabase(module string, name string) reflection.VerbResource { + return func() reflect.Value { + reflectedDB := reflection.GetDatabase(reflection.Ref{Module: module, Name: name}) + db := ftl.NewPostgresDatabaseHandle(reflectedDB.Name, reflectedDB.DB) + return reflect.ValueOf(db) + } +} + +func InitPostgres(ref reflection.Ref) *reflection.ReflectedDatabaseHandle { + return &reflection.ReflectedDatabaseHandle{ + Name: ref.Name, + DBType: reflection.DBTypePostgres, + DB: once.Once(func(ctx context.Context) (*sql.DB, error) { + provider := modulecontext.FromContext(ctx).CurrentContext() + dsn, err := provider.GetDatabase(ref.Name, modulecontext.DBTypePostgres) + if err != nil { + return nil, fmt.Errorf("failed to get database %q: %w", ref.Name, err) + } + db, err := otelsql.Open("pgx", dsn) + if err != nil { + return nil, fmt.Errorf("failed to open database %q: %w", ref.Name, err) + } + + // sets db.system and db.name attributes + metricAttrs := otelsql.WithAttributes( + semconv.DBSystemPostgreSQL, + semconv.DBNameKey.String(ref.Name), + attribute.Bool("ftl.is_user_service", true), + ) + err = otelsql.RegisterDBStatsMetrics(db, metricAttrs) + if err != nil { + return nil, fmt.Errorf("failed to register database metrics: %w", err) + } + db.SetConnMaxIdleTime(time.Minute) + db.SetMaxOpenConns(20) + return db, nil + }), + } +} diff --git a/go-runtime/server/server.go b/go-runtime/server/server.go index 56f99a6d39..c3671cbf4e 100644 --- a/go-runtime/server/server.go +++ b/go-runtime/server/server.go @@ -9,8 +9,6 @@ import ( "strings" "connectrpc.com/connect" - "github.com/alecthomas/types/optional" - ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect" "github.com/TBD54566975/ftl/common/plugin" @@ -24,6 +22,7 @@ import ( "github.com/TBD54566975/ftl/internal/observability" "github.com/TBD54566975/ftl/internal/rpc" "github.com/TBD54566975/ftl/internal/schema" + "github.com/alecthomas/types/optional" ) type UserVerbConfig struct { @@ -80,8 +79,8 @@ func HandleCall[Req, Resp any](verb any) Handler { return nil, fmt.Errorf("invalid request to verb %s: %w", ref, err) } - // Call Verb. - resp, err := Call[Req, Resp](ref)(ctx, req) + // InvokeVerb Verb. + resp, err := InvokeVerb[Req, Resp](ref)(ctx, req) if err != nil { return nil, fmt.Errorf("call to verb %s failed: %w", ref, err) } @@ -108,56 +107,29 @@ func HandleEmpty(verb any) Handler { return HandleCall[ftl.Unit, ftl.Unit](verb) } -func call[Verb, Req, Resp any]() func(ctx context.Context, req Req) (resp Resp, err error) { - typ := reflect.TypeFor[Verb]() - if typ.Kind() != reflect.Func { - panic(fmt.Sprintf("Cannot register %s: expected function, got %s", typ, typ.Kind())) - } - callee := reflection.TypeRef[Verb]() - callee.Name = strings.TrimSuffix(callee.Name, "Client") +func InvokeVerb[Req, Resp any](ref reflection.Ref) func(ctx context.Context, req Req) (resp Resp, err error) { return func(ctx context.Context, req Req) (resp Resp, err error) { - ref := reflection.Ref{Module: callee.Module, Name: callee.Name} - moduleCtx := modulecontext.FromContext(ctx).CurrentContext() - override, err := moduleCtx.BehaviorForVerb(schema.Ref{Module: ref.Module, Name: ref.Name}) - if err != nil { - return resp, fmt.Errorf("%s: %w", ref, err) - } - if behavior, ok := override.Get(); ok { - uncheckedResp, err := behavior.Call(ctx, modulecontext.Verb(widenVerb(Call[Req, Resp](ref))), req) - if err != nil { - return resp, fmt.Errorf("%s: %w", ref, err) - } - if r, ok := uncheckedResp.(Resp); ok { - return r, nil - } - return resp, fmt.Errorf("%s: overridden verb had invalid response type %T, expected %v", ref, - uncheckedResp, reflect.TypeFor[Resp]()) + request := optional.Some[any](req) + if reflect.TypeFor[Req]() == reflect.TypeFor[ftl.Unit]() { + request = optional.None[any]() } - reqData, err := encoding.Marshal(req) + out, err := reflection.CallVerb(reflection.Ref{Module: ref.Module, Name: ref.Name})(ctx, request) if err != nil { - return resp, fmt.Errorf("%s: failed to marshal request: %w", callee, err) + return resp, err } - client := rpc.ClientFromContext[ftlv1connect.VerbServiceClient](ctx) - cresp, err := client.Call(ctx, connect.NewRequest(&ftlv1.CallRequest{Verb: callee.ToProto(), Body: reqData})) - if err != nil { - return resp, fmt.Errorf("%s: failed to call Verb: %w", callee, err) + var respValue any + if r, ok := out.Get(); ok { + respValue = r + } else { + respValue = ftl.Unit{} } - switch cresp := cresp.Msg.Response.(type) { - case *ftlv1.CallResponse_Error_: - return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) - - case *ftlv1.CallResponse_Body: - err = encoding.Unmarshal(cresp.Body, &resp) - if err != nil { - return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) - } - return resp, nil - - default: - panic(fmt.Sprintf("%s: invalid response type %T", callee, cresp)) + resp, ok := respValue.(Resp) + if !ok { + return resp, fmt.Errorf("unexpected response type from verb %s: %T", ref, resp) } + return resp, err } } @@ -200,29 +172,56 @@ func EmptyClient[Verb any]() reflection.VerbResource { } } -func Call[Req, Resp any](ref reflection.Ref) func(ctx context.Context, req Req) (resp Resp, err error) { +func call[Verb, Req, Resp any]() func(ctx context.Context, req Req) (resp Resp, err error) { + typ := reflect.TypeFor[Verb]() + if typ.Kind() != reflect.Func { + panic(fmt.Sprintf("Cannot register %s: expected function, got %s", typ, typ.Kind())) + } + callee := reflection.TypeRef[Verb]() + callee.Name = strings.TrimSuffix(callee.Name, "Client") return func(ctx context.Context, req Req) (resp Resp, err error) { - request := optional.Some[any](req) - if reflect.TypeFor[Req]() == reflect.TypeFor[ftl.Unit]() { - request = optional.None[any]() + ref := reflection.Ref{Module: callee.Module, Name: callee.Name} + moduleCtx := modulecontext.FromContext(ctx).CurrentContext() + override, err := moduleCtx.BehaviorForVerb(schema.Ref{Module: ref.Module, Name: ref.Name}) + if err != nil { + return resp, fmt.Errorf("%s: %w", ref, err) + } + if behavior, ok := override.Get(); ok { + uncheckedResp, err := behavior.Call(ctx, modulecontext.Verb(widenVerb(InvokeVerb[Req, Resp](ref))), req) + if err != nil { + return resp, fmt.Errorf("%s: %w", ref, err) + } + if r, ok := uncheckedResp.(Resp); ok { + return r, nil + } + return resp, fmt.Errorf("%s: overridden verb had invalid response type %T, expected %v", ref, + uncheckedResp, reflect.TypeFor[Resp]()) } - out, err := reflection.CallVerb(reflection.Ref{Module: ref.Module, Name: ref.Name})(ctx, request) + reqData, err := encoding.Marshal(req) if err != nil { - return resp, err + return resp, fmt.Errorf("%s: failed to marshal request: %w", callee, err) } - var respValue any - if r, ok := out.Get(); ok { - respValue = r - } else { - respValue = ftl.Unit{} + client := rpc.ClientFromContext[ftlv1connect.VerbServiceClient](ctx) + cresp, err := client.Call(ctx, connect.NewRequest(&ftlv1.CallRequest{Verb: callee.ToProto(), Body: reqData})) + if err != nil { + return resp, fmt.Errorf("%s: failed to call Verb: %w", callee, err) } - resp, ok := respValue.(Resp) - if !ok { - return resp, fmt.Errorf("unexpected response type from verb %s: %T", ref, resp) + switch cresp := cresp.Msg.Response.(type) { + case *ftlv1.CallResponse_Error_: + return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) + + case *ftlv1.CallResponse_Body: + err = encoding.Unmarshal(cresp.Body, &resp) + if err != nil { + return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) + } + return resp, nil + + default: + panic(fmt.Sprintf("%s: invalid response type %T", callee, cresp)) } - return resp, err } } diff --git a/internal/buildengine/testdata/another/go.mod b/internal/buildengine/testdata/another/go.mod index 9981baf754..77469d3428 100644 --- a/internal/buildengine/testdata/another/go.mod +++ b/internal/buildengine/testdata/another/go.mod @@ -8,7 +8,6 @@ require ( connectrpc.com/connect v1.16.2 // indirect connectrpc.com/grpcreflect v1.2.0 // indirect connectrpc.com/otelconnect v0.7.1 // indirect - github.com/XSAM/otelsql v0.34.0 // indirect github.com/alecthomas/atomic v0.1.0-alpha2 // indirect github.com/alecthomas/concurrency v0.0.2 // indirect github.com/alecthomas/participle/v2 v2.1.1 // indirect diff --git a/internal/buildengine/testdata/another/go.sum b/internal/buildengine/testdata/another/go.sum index e86889ebc9..105f8053ac 100644 --- a/internal/buildengine/testdata/another/go.sum +++ b/internal/buildengine/testdata/another/go.sum @@ -6,8 +6,6 @@ connectrpc.com/otelconnect v0.7.1 h1:scO5pOb0i4yUE66CnNrHeK1x51yq0bE0ehPg6WvzXJY connectrpc.com/otelconnect v0.7.1/go.mod h1:dh3bFgHBTb2bkqGCeVVOtHJreSns7uu9wwL2Tbz17ms= github.com/TBD54566975/scaffolder v1.1.0 h1:R92zjC4XiS/lGCxJ8Ebn93g8gC0LU9qo06AAKo9cEJE= github.com/TBD54566975/scaffolder v1.1.0/go.mod h1:dRi67GryEhZ5u0XRSiR294SYaqAfnCkZ7u3rmc4W6iI= -github.com/XSAM/otelsql v0.34.0 h1:YdCRKy17Xn0MH717LEwqpVL/a+4nexmSCBrgoycYY6E= -github.com/XSAM/otelsql v0.34.0/go.mod h1:xaE+ybu+kJOYvtDyThbe0VoKWngvKHmNlrM1rOn8f94= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/atomic v0.1.0-alpha2 h1:dqwXmax66gXvHhsOS4pGPZKqYOlTkapELkLb3MNdlH8= diff --git a/internal/buildengine/testdata/other/go.mod b/internal/buildengine/testdata/other/go.mod index ce8705f5ac..176cdd6e16 100644 --- a/internal/buildengine/testdata/other/go.mod +++ b/internal/buildengine/testdata/other/go.mod @@ -8,7 +8,6 @@ require ( connectrpc.com/connect v1.16.2 // indirect connectrpc.com/grpcreflect v1.2.0 // indirect connectrpc.com/otelconnect v0.7.1 // indirect - github.com/XSAM/otelsql v0.34.0 // indirect github.com/alecthomas/atomic v0.1.0-alpha2 // indirect github.com/alecthomas/concurrency v0.0.2 // indirect github.com/alecthomas/participle/v2 v2.1.1 // indirect diff --git a/internal/buildengine/testdata/other/go.sum b/internal/buildengine/testdata/other/go.sum index e86889ebc9..105f8053ac 100644 --- a/internal/buildengine/testdata/other/go.sum +++ b/internal/buildengine/testdata/other/go.sum @@ -6,8 +6,6 @@ connectrpc.com/otelconnect v0.7.1 h1:scO5pOb0i4yUE66CnNrHeK1x51yq0bE0ehPg6WvzXJY connectrpc.com/otelconnect v0.7.1/go.mod h1:dh3bFgHBTb2bkqGCeVVOtHJreSns7uu9wwL2Tbz17ms= github.com/TBD54566975/scaffolder v1.1.0 h1:R92zjC4XiS/lGCxJ8Ebn93g8gC0LU9qo06AAKo9cEJE= github.com/TBD54566975/scaffolder v1.1.0/go.mod h1:dRi67GryEhZ5u0XRSiR294SYaqAfnCkZ7u3rmc4W6iI= -github.com/XSAM/otelsql v0.34.0 h1:YdCRKy17Xn0MH717LEwqpVL/a+4nexmSCBrgoycYY6E= -github.com/XSAM/otelsql v0.34.0/go.mod h1:xaE+ybu+kJOYvtDyThbe0VoKWngvKHmNlrM1rOn8f94= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/atomic v0.1.0-alpha2 h1:dqwXmax66gXvHhsOS4pGPZKqYOlTkapELkLb3MNdlH8= diff --git a/internal/modulecontext/module_context.go b/internal/modulecontext/module_context.go index 55cf425442..f583783755 100644 --- a/internal/modulecontext/module_context.go +++ b/internal/modulecontext/module_context.go @@ -143,7 +143,7 @@ func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, error) { return "", fmt.Errorf("database %s does not match expected type of %s", name, dbType) } if m.isTesting && !db.isTestDB { - return "", fmt.Errorf("accessing non-test database %q while testing: try adding ftltest.WithDatabase(db) as an option with ftltest.Context(...)", name) + return "", fmt.Errorf("accessing non-test database %q while testing", name) } return db.DSN, nil } diff --git a/internal/schema/verb.go b/internal/schema/verb.go index dfc7d30479..4c68f94812 100644 --- a/internal/schema/verb.go +++ b/internal/schema/verb.go @@ -119,6 +119,15 @@ func (v *Verb) AddCall(verb *Ref) { v.Metadata = append(v.Metadata, &MetadataCalls{Calls: []*Ref{verb}}) } +// AddDatabase adds a DB reference to the Verb. +func (v *Verb) AddDatabase(db *Ref) { + if c, ok := slices.FindVariant[*MetadataDatabases](v.Metadata); ok { + c.Calls = append(c.Calls, db) + return + } + v.Metadata = append(v.Metadata, &MetadataDatabases{Calls: []*Ref{db}}) +} + func (v *Verb) GetMetadataIngress() optional.Option[*MetadataIngress] { if m, ok := slices.FindVariant[*MetadataIngress](v.Metadata); ok { return optional.Some(m) diff --git a/smoketest/relay/relay.go b/smoketest/relay/relay.go index 65c8f86a51..e3aa8d260f 100644 --- a/smoketest/relay/relay.go +++ b/smoketest/relay/relay.go @@ -13,7 +13,8 @@ import ( ) var logFile = ftl.Config[string]("log_file") -var db = ftl.PostgresDatabase("exemplardb") + +type Exemplardb = ftl.PostgresDatabaseHandle // PubSub