diff --git a/.golangci.yml b/.golangci.yml index 1f069c05f9..54fd702a36 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -58,6 +58,7 @@ linters: - tagalign - nolintlint - protogetter + - thelper linters-settings: exhaustive: diff --git a/backend/controller/controller.go b/backend/controller/controller.go index 1ba66372d8..b2fc96923c 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -862,7 +862,7 @@ func (s *Service) SendFSMEvent(ctx context.Context, req *connect.Request[ftlv1.S // schedules an event for a FSM instance within a db transaction // body may already be encrypted, which is denoted by the encrypted flag -func (s *Service) sendFSMEventInTx(ctx context.Context, tx *dal.Tx, instance *dal.FSMInstance, fsm *schema.FSM, eventType schema.Type, body []byte, encrypted bool) error { +func (s *Service) sendFSMEventInTx(ctx context.Context, tx *dal.DAL, instance *dal.FSMInstance, fsm *schema.FSM, eventType schema.Type, body []byte, encrypted bool) error { // Populated if we find a matching transition. var destinationRef *schema.Ref var destinationVerb *schema.Verb @@ -1484,7 +1484,7 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration } queueDepth := call.QueueDepth - didScheduleAnotherCall, err := s.dal.CompleteAsyncCall(ctx, call, callResult, func(tx *dal.Tx, isFinalResult bool) error { + didScheduleAnotherCall, err := s.dal.CompleteAsyncCall(ctx, call, callResult, func(tx *dal.DAL, isFinalResult bool) error { return s.finaliseAsyncCall(ctx, tx, call, callResult, isFinalResult) }) if err != nil { @@ -1556,7 +1556,7 @@ func (s *Service) catchAsyncCall(ctx context.Context, logger *log.Logger, call * catchResult = either.LeftOf[string](resp.Msg.GetBody()) } queueDepth := call.QueueDepth - didScheduleAnotherCall, err := s.dal.CompleteAsyncCall(ctx, call, catchResult, func(tx *dal.Tx, isFinalResult bool) error { + didScheduleAnotherCall, err := s.dal.CompleteAsyncCall(ctx, call, catchResult, func(tx *dal.DAL, isFinalResult bool) error { // Exposes the original error to external components such as PubSub and FSM return s.finaliseAsyncCall(ctx, tx, call, originalResult, isFinalResult) }) @@ -1602,7 +1602,7 @@ func metadataForAsyncCall(call *dal.AsyncCall) *ftlv1.Metadata { } } -func (s *Service) finaliseAsyncCall(ctx context.Context, tx *dal.Tx, call *dal.AsyncCall, callResult either.Either[[]byte, string], isFinalResult bool) error { +func (s *Service) finaliseAsyncCall(ctx context.Context, tx *dal.DAL, call *dal.AsyncCall, callResult either.Either[[]byte, string], isFinalResult bool) error { _, failed := callResult.(either.Right[[]byte, string]) // Allow for handling of completion based on origin @@ -1628,7 +1628,7 @@ func (s *Service) finaliseAsyncCall(ctx context.Context, tx *dal.Tx, call *dal.A return nil } -func (s *Service) onAsyncFSMCallCompletion(ctx context.Context, tx *dal.Tx, origin dal.AsyncOriginFSM, failed bool, isFinalResult bool) error { +func (s *Service) onAsyncFSMCallCompletion(ctx context.Context, tx *dal.DAL, origin dal.AsyncOriginFSM, failed bool, isFinalResult bool) error { logger := log.FromContext(ctx).Scope(origin.FSM.String()) // retrieve the next fsm event and delete it diff --git a/backend/controller/cronjobs/cronjobs.go b/backend/controller/cronjobs/cronjobs.go index 2f1a9efcd2..f1436d7b6b 100644 --- a/backend/controller/cronjobs/cronjobs.go +++ b/backend/controller/cronjobs/cronjobs.go @@ -145,7 +145,7 @@ func (s *Service) OnJobCompletion(ctx context.Context, key model.CronJobKey, fai } // scheduleCronJob schedules the next execution of a single cron job. -func (s *Service) scheduleCronJob(ctx context.Context, tx *dal.Tx, job model.CronJob) error { +func (s *Service) scheduleCronJob(ctx context.Context, tx *dal.DAL, job model.CronJob) error { logger := log.FromContext(ctx).Scope("cron") now := s.clock.Now().UTC() pending, err := tx.IsCronJobPending(ctx, job.Key, now) diff --git a/backend/controller/cronjobs/cronjobs_test.go b/backend/controller/cronjobs/cronjobs_test.go index 7ff9c6052f..d97583c212 100644 --- a/backend/controller/cronjobs/cronjobs_test.go +++ b/backend/controller/cronjobs/cronjobs_test.go @@ -94,7 +94,7 @@ func TestNewCronJobsForModule(t *testing.T) { // Complete all calls for _, call := range calls { callResult := either.LeftOf[string]([]byte("{}")) - _, err = parentDAL.CompleteAsyncCall(ctx, call, callResult, func(tx *parentdal.Tx, isFinalResult bool) error { + _, err = parentDAL.CompleteAsyncCall(ctx, call, callResult, func(tx *parentdal.DAL, isFinalResult bool) error { return nil }) assert.NoError(t, err) diff --git a/backend/controller/cronjobs/dal/dal.go b/backend/controller/cronjobs/dal/dal.go index e75d278155..ada64fa446 100644 --- a/backend/controller/cronjobs/dal/dal.go +++ b/backend/controller/cronjobs/dal/dal.go @@ -9,62 +9,24 @@ import ( "github.com/TBD54566975/ftl/backend/controller/cronjobs/sql" "github.com/TBD54566975/ftl/backend/controller/observability" - dalerrs "github.com/TBD54566975/ftl/backend/dal" + "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/backend/schema" "github.com/TBD54566975/ftl/internal/model" "github.com/TBD54566975/ftl/internal/slices" ) type DAL struct { - db sql.DBI + *dal.Handle[DAL] + db sql.Querier } -func New(conn sql.ConnI) *DAL { - return &DAL{db: sql.NewDB(conn)} -} - -type Tx struct { - *DAL -} - -func (d *DAL) Begin(ctx context.Context) (*Tx, error) { - tx, err := d.db.Begin(ctx) - if err != nil { - return nil, fmt.Errorf("failed to begin transaction: %w", dalerrs.TranslatePGError(err)) - } - return &Tx{DAL: &DAL{db: tx}}, nil -} - -func (t *Tx) CommitOrRollback(ctx context.Context, err *error) { - tx, ok := t.db.(*sql.Tx) - if !ok { - panic("inconceivable") - } - tx.CommitOrRollback(ctx, err) -} - -func (t *Tx) Commit(ctx context.Context) error { - tx, ok := t.db.(*sql.Tx) - if !ok { - panic("inconcievable") - } - err := tx.Commit(ctx) - if err != nil { - return fmt.Errorf("failed to commit transaction: %w", dalerrs.TranslatePGError(err)) - } - return nil -} - -func (t *Tx) Rollback(ctx context.Context) error { - tx, ok := t.db.(*sql.Tx) - if !ok { - panic("inconcievable") +func New(conn dal.Connection) *DAL { + return &DAL{ + db: sql.New(conn), + Handle: dal.New(conn, func(h *dal.Handle[DAL]) *DAL { + return &DAL{Handle: h, db: sql.New(h.Connection)} + }), } - err := tx.Rollback(ctx) - if err != nil { - return fmt.Errorf("failed to rollback transaction: %w", dalerrs.TranslatePGError(err)) - } - return nil } func cronJobFromRow(c sql.CronJob, d sql.Deployment) model.CronJob { @@ -83,7 +45,7 @@ func cronJobFromRow(c sql.CronJob, d sql.Deployment) model.CronJob { func (d *DAL) CreateAsyncCall(ctx context.Context, params sql.CreateAsyncCallParams) (int64, error) { id, err := d.db.CreateAsyncCall(ctx, params) if err != nil { - return 0, fmt.Errorf("failed to create async call: %w", dalerrs.TranslatePGError(err)) + return 0, fmt.Errorf("failed to create async call: %w", dal.TranslatePGError(err)) } observability.AsyncCalls.Created(ctx, params.Verb, optional.None[schema.RefKey](), params.Origin, 0, err) queueDepth, err := d.db.AsyncCallQueueDepth(ctx) @@ -100,7 +62,7 @@ func (d *DAL) CreateAsyncCall(ctx context.Context, params sql.CreateAsyncCallPar func (d *DAL) GetUnscheduledCronJobs(ctx context.Context, startTime time.Time) ([]model.CronJob, error) { rows, err := d.db.GetUnscheduledCronJobs(ctx, startTime) if err != nil { - return nil, fmt.Errorf("failed to get cron jobs: %w", dalerrs.TranslatePGError(err)) + return nil, fmt.Errorf("failed to get cron jobs: %w", dal.TranslatePGError(err)) } return slices.Map(rows, func(r sql.GetUnscheduledCronJobsRow) model.CronJob { return cronJobFromRow(r.CronJob, r.Deployment) @@ -111,7 +73,7 @@ func (d *DAL) GetUnscheduledCronJobs(ctx context.Context, startTime time.Time) ( func (d *DAL) GetCronJobByKey(ctx context.Context, key model.CronJobKey) (model.CronJob, error) { row, err := d.db.GetCronJobByKey(ctx, key) if err != nil { - return model.CronJob{}, fmt.Errorf("failed to get cron job %q: %w", key, dalerrs.TranslatePGError(err)) + return model.CronJob{}, fmt.Errorf("failed to get cron job %q: %w", key, dal.TranslatePGError(err)) } return cronJobFromRow(row.CronJob, row.Deployment), nil } @@ -120,7 +82,7 @@ func (d *DAL) GetCronJobByKey(ctx context.Context, key model.CronJobKey) (model. func (d *DAL) IsCronJobPending(ctx context.Context, key model.CronJobKey, startTime time.Time) (bool, error) { pending, err := d.db.IsCronJobPending(ctx, key, startTime) if err != nil { - return false, fmt.Errorf("failed to check if cron job %q is pending: %w", key, dalerrs.TranslatePGError(err)) + return false, fmt.Errorf("failed to check if cron job %q is pending: %w", key, dal.TranslatePGError(err)) } return pending, nil } @@ -130,7 +92,7 @@ func (d *DAL) IsCronJobPending(ctx context.Context, key model.CronJobKey, startT func (d *DAL) UpdateCronJobExecution(ctx context.Context, params sql.UpdateCronJobExecutionParams) error { err := d.db.UpdateCronJobExecution(ctx, params) if err != nil { - return fmt.Errorf("failed to update cron job %q: %w", params.Key, dalerrs.TranslatePGError(err)) + return fmt.Errorf("failed to update cron job %q: %w", params.Key, dal.TranslatePGError(err)) } return nil } diff --git a/backend/controller/cronjobs/sql/conn.go b/backend/controller/cronjobs/sql/conn.go deleted file mode 100644 index 62699e9ef6..0000000000 --- a/backend/controller/cronjobs/sql/conn.go +++ /dev/null @@ -1,94 +0,0 @@ -package sql - -import ( - "context" - "database/sql" - "errors" - "fmt" -) - -type DBI interface { - Querier - Conn() ConnI - Begin(ctx context.Context) (*Tx, error) -} - -type ConnI interface { - DBTX - Begin() (*sql.Tx, error) -} - -type DB struct { - conn ConnI - *Queries -} - -func NewDB(conn ConnI) *DB { - return &DB{conn: conn, Queries: New(conn)} -} - -func (d *DB) Conn() ConnI { return d.conn } - -func (d *DB) Begin(ctx context.Context) (*Tx, error) { - tx, err := d.conn.Begin() - if err != nil { - return nil, fmt.Errorf("beginning transaction: %w", err) - } - return &Tx{tx: tx, Queries: New(tx)}, nil -} - -type noopSubConn struct { - DBTX -} - -func (noopSubConn) Begin() (*sql.Tx, error) { - return nil, errors.New("sql: not implemented") -} - -type Tx struct { - tx *sql.Tx - *Queries -} - -func (t *Tx) Conn() ConnI { return noopSubConn{t.tx} } - -func (t *Tx) Tx() *sql.Tx { return t.tx } - -func (t *Tx) Begin(ctx context.Context) (*Tx, error) { - return nil, fmt.Errorf("cannot nest transactions") -} - -func (t *Tx) Commit(ctx context.Context) error { - err := t.tx.Commit() - if err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} - -func (t *Tx) Rollback(ctx context.Context) error { - err := t.tx.Rollback() - if err != nil { - return fmt.Errorf("rolling back transaction: %w", err) - } - - return nil -} - -// CommitOrRollback can be used in a defer statement to commit or rollback a -// transaction depending on whether the enclosing function returned an error. -// -// func myFunc() (err error) { -// tx, err := db.Begin(ctx) -// if err != nil { return err } -// defer tx.CommitOrRollback(ctx, &err) -// ... -// } -func (t *Tx) CommitOrRollback(ctx context.Context, err *error) { - if *err != nil { - *err = errors.Join(*err, t.Rollback(ctx)) - } else { - *err = t.Commit(ctx) - } -} diff --git a/backend/controller/dal/async_calls.go b/backend/controller/dal/async_calls.go index 5c621ffeb0..d806974d74 100644 --- a/backend/controller/dal/async_calls.go +++ b/backend/controller/dal/async_calls.go @@ -169,7 +169,7 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error) func (d *DAL) CompleteAsyncCall(ctx context.Context, call *AsyncCall, result either.Either[[]byte, string], - finalise func(tx *Tx, isFinalResult bool) error) (didScheduleAnotherCall bool, err error) { + finalise func(tx *DAL, isFinalResult bool) error) (didScheduleAnotherCall bool, err error) { tx, err := d.Begin(ctx) if err != nil { return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck diff --git a/backend/controller/dal/dal.go b/backend/controller/dal/dal.go index 90dfd6670d..8fbace64cf 100644 --- a/backend/controller/dal/dal.go +++ b/backend/controller/dal/dal.go @@ -3,7 +3,6 @@ package dal import ( "context" - stdsql "database/sql" "encoding/json" "errors" "fmt" @@ -18,7 +17,7 @@ import ( "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/controller/sql/sqltypes" - dalerrs "github.com/TBD54566975/ftl/backend/dal" + "github.com/TBD54566975/ftl/backend/dal" ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/schema" "github.com/TBD54566975/ftl/internal/encryption" @@ -210,89 +209,45 @@ func WithReservation(ctx context.Context, reservation Reservation, fn func() err return reservation.Commit(ctx) } -func New(ctx context.Context, conn *stdsql.DB, encryptionBuilder encryption.Builder) (*DAL, error) { - d := &DAL{ - db: sql.NewDB(conn), +func New(ctx context.Context, conn dal.Connection, encryptionBuilder encryption.Builder) (*DAL, error) { + var d *DAL + d = &DAL{ + db: sql.New(conn), + Handle: dal.New(conn, func(h *dal.Handle[DAL]) *DAL { + return &DAL{ + Handle: h, + db: sql.New(h.Connection), + encryptor: d.encryptor, + DeploymentChanges: d.DeploymentChanges, + } + }), DeploymentChanges: pubsub.New[DeploymentNotification](), } - encryptor, err := encryptionBuilder.Build(ctx, d) if err != nil { - return nil, fmt.Errorf("failed to setup encryptor: %w", err) + return nil, fmt.Errorf("build encryptor: %w", err) } - d.encryptor = encryptor - if err = d.verifyEncryptor(ctx); err != nil { - return nil, fmt.Errorf("failed to verify encryption: %w", err) + if err := d.verifyEncryptor(ctx, encryptor); err != nil { + return nil, fmt.Errorf("verify encryptor: %w", err) } - + d.encryptor = encryptor return d, nil } type DAL struct { - db sql.DBI + *dal.Handle[DAL] + db sql.Querier - kmsURL optional.Option[string] encryptor encryption.DataEncryptor // DeploymentChanges is a Topic that receives changes to the deployments table. DeploymentChanges *pubsub.Topic[DeploymentNotification] } -// Tx is DAL within a transaction. -type Tx struct { - *DAL -} - -// CommitOrRollback can be used in a defer statement to commit or rollback a -// transaction depending on whether the enclosing function returned an error. -// -// func myFunc() (err error) { -// tx, err := dal.Begin(ctx) -// if err != nil { return err } -// defer tx.CommitOrRollback(ctx, &err) -// ... -// } -func (t *Tx) CommitOrRollback(ctx context.Context, err *error) { - tx, ok := t.db.(*sql.Tx) - if !ok { - panic("inconceivable") - } - tx.CommitOrRollback(ctx, err) -} - -func (t *Tx) Commit(ctx context.Context) error { - tx, ok := t.db.(*sql.Tx) - if !ok { - panic("inconcievable") - } - return tx.Commit(ctx) -} - -func (t *Tx) Rollback(ctx context.Context) error { - tx, ok := t.db.(*sql.Tx) - if !ok { - panic("inconcievable") - } - return tx.Rollback(ctx) -} - -func (d *DAL) Begin(ctx context.Context) (*Tx, error) { - tx, err := d.db.Begin(ctx) - if err != nil { - return nil, dalerrs.TranslatePGError(err) - } - return &Tx{&DAL{ - db: tx, - DeploymentChanges: d.DeploymentChanges, - kmsURL: d.kmsURL, - encryptor: d.encryptor, - }}, nil -} - func (d *DAL) GetActiveControllers(ctx context.Context) ([]Controller, error) { controllers, err := d.db.GetActiveControllers(ctx) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.Map(controllers, func(in sql.Controller) Controller { return Controller{ @@ -305,23 +260,23 @@ func (d *DAL) GetActiveControllers(ctx context.Context) ([]Controller, error) { func (d *DAL) GetStatus(ctx context.Context) (Status, error) { controllers, err := d.GetActiveControllers(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get control planes: %w", dalerrs.TranslatePGError(err)) + return Status{}, fmt.Errorf("could not get control planes: %w", dal.TranslatePGError(err)) } runners, err := d.db.GetActiveRunners(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get active runners: %w", dalerrs.TranslatePGError(err)) + return Status{}, fmt.Errorf("could not get active runners: %w", dal.TranslatePGError(err)) } deployments, err := d.db.GetActiveDeployments(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get active deployments: %w", dalerrs.TranslatePGError(err)) + return Status{}, fmt.Errorf("could not get active deployments: %w", dal.TranslatePGError(err)) } ingressRoutes, err := d.db.GetActiveIngressRoutes(ctx) if err != nil { - return Status{}, fmt.Errorf("could not get ingress routes: %w", dalerrs.TranslatePGError(err)) + return Status{}, fmt.Errorf("could not get ingress routes: %w", dal.TranslatePGError(err)) } routes, err := d.db.GetRoutingTable(ctx, nil) if err != nil { - return Status{}, fmt.Errorf("could not get routing table: %w", dalerrs.TranslatePGError(err)) + return Status{}, fmt.Errorf("could not get routing table: %w", dal.TranslatePGError(err)) } statusDeployments, err := slices.MapErr(deployments, func(in sql.GetActiveDeploymentsRow) (Deployment, error) { labels := model.Labels{} @@ -394,7 +349,7 @@ func (d *DAL) GetRunnersForDeployment(ctx context.Context, deployment model.Depl runners := []Runner{} rows, err := d.db.GetRunnersForDeployment(ctx, deployment) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } for _, row := range rows { attrs := model.Labels{} @@ -415,14 +370,14 @@ func (d *DAL) GetRunnersForDeployment(ctx context.Context, deployment model.Depl func (d *DAL) UpsertModule(ctx context.Context, language, name string) (err error) { _, err = d.db.UpsertModule(ctx, language, name) - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } // GetMissingArtefacts returns the digests of the artefacts that are missing from the database. func (d *DAL) GetMissingArtefacts(ctx context.Context, digests []sha256.SHA256) ([]sha256.SHA256, error) { have, err := d.db.GetArtefactDigests(ctx, sha256esToBytes(digests)) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } haveStr := slices.Map(have, func(in sql.GetArtefactDigestsRow) sha256.SHA256 { return sha256.FromBytes(in.Digest) @@ -434,7 +389,7 @@ func (d *DAL) GetMissingArtefacts(ctx context.Context, digests []sha256.SHA256) func (d *DAL) CreateArtefact(ctx context.Context, content []byte) (digest sha256.SHA256, err error) { sha256digest := sha256.Sum(content) _, err = d.db.CreateArtefact(ctx, sha256digest[:], content) - return sha256digest, dalerrs.TranslatePGError(err) + return sha256digest, dal.TranslatePGError(err) } type IngressRoutingEntry struct { @@ -451,7 +406,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem logger := log.FromContext(ctx) // Start the parent transaction - tx, err := d.db.Begin(ctx) + tx, err := d.Begin(ctx) if err != nil { return model.DeploymentKey{}, fmt.Errorf("could not start transaction: %w", err) } @@ -475,9 +430,9 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem } // TODO(aat): "schema" containing language? - _, err = tx.UpsertModule(ctx, language, moduleSchema.Name) + _, err = tx.db.UpsertModule(ctx, language, moduleSchema.Name) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to upsert module: %w", dalerrs.TranslatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to upsert module: %w", dal.TranslatePGError(err)) } // upsert topics @@ -486,27 +441,27 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem if !ok { continue } - err := tx.UpsertTopic(ctx, sql.UpsertTopicParams{ + err := tx.db.UpsertTopic(ctx, sql.UpsertTopicParams{ Topic: model.NewTopicKey(moduleSchema.Name, t.Name), Module: moduleSchema.Name, Name: t.Name, EventType: t.Event.String(), }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("could not insert topic: %w", dalerrs.TranslatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("could not insert topic: %w", dal.TranslatePGError(err)) } } deploymentKey := model.NewDeploymentKey(moduleSchema.Name) // Create the deployment - err = tx.CreateDeployment(ctx, moduleSchema.Name, schemaBytes, deploymentKey) + err = tx.db.CreateDeployment(ctx, moduleSchema.Name, schemaBytes, deploymentKey) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to create deployment: %w", dalerrs.TranslatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to create deployment: %w", dal.TranslatePGError(err)) } uploadedDigests := slices.Map(artefacts, func(in DeploymentArtefact) []byte { return in.Digest[:] }) - artefactDigests, err := tx.GetArtefactDigests(ctx, uploadedDigests) + artefactDigests, err := tx.db.GetArtefactDigests(ctx, uploadedDigests) if err != nil { return model.DeploymentKey{}, fmt.Errorf("failed to get artefact digests: %w", err) } @@ -518,19 +473,19 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem // Associate the artefacts with the deployment for _, row := range artefactDigests { artefact := artefactsByDigest[sha256.FromBytes(row.Digest)] - err = tx.AssociateArtefactWithDeployment(ctx, sql.AssociateArtefactWithDeploymentParams{ + err = tx.db.AssociateArtefactWithDeployment(ctx, sql.AssociateArtefactWithDeploymentParams{ Key: deploymentKey, ArtefactID: row.ID, Executable: artefact.Executable, Path: artefact.Path, }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to associate artefact with deployment: %w", dalerrs.TranslatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to associate artefact with deployment: %w", dal.TranslatePGError(err)) } } for _, ingressRoute := range ingressRoutes { - err = tx.CreateIngressRoute(ctx, sql.CreateIngressRouteParams{ + err = tx.db.CreateIngressRoute(ctx, sql.CreateIngressRouteParams{ Key: deploymentKey, Method: ingressRoute.Method, Path: ingressRoute.Path, @@ -538,14 +493,14 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem Verb: ingressRoute.Verb, }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to create ingress route: %w", dalerrs.TranslatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to create ingress route: %w", dal.TranslatePGError(err)) } } for _, job := range cronJobs { // Start time must be calculated by the caller rather than generated by db // This ensures that nextExecution is after start time, otherwise the job will never be triggered - err := tx.CreateCronJob(ctx, sql.CreateCronJobParams{ + err := tx.db.CreateCronJob(ctx, sql.CreateCronJobParams{ Key: job.Key, DeploymentKey: deploymentKey, ModuleName: job.Verb.Module, @@ -555,7 +510,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem NextExecution: job.NextExecution, }) if err != nil { - return model.DeploymentKey{}, fmt.Errorf("failed to create cron job: %w", dalerrs.TranslatePGError(err)) + return model.DeploymentKey{}, fmt.Errorf("failed to create cron job: %w", dal.TranslatePGError(err)) } } @@ -565,7 +520,7 @@ func (d *DAL) CreateDeployment(ctx context.Context, language string, moduleSchem func (d *DAL) GetDeployment(ctx context.Context, key model.DeploymentKey) (*model.Deployment, error) { deployment, err := d.db.GetDeployment(ctx, key) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return d.loadDeployment(ctx, deployment) } @@ -587,7 +542,7 @@ func (d *DAL) UpsertRunner(ctx context.Context, runner Runner) error { Labels: attrBytes, }) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } if runner.Deployment.Ok() && !deploymentID.Ok() { return fmt.Errorf("deployment %s not found", runner.Deployment) @@ -611,10 +566,10 @@ func (d *DAL) KillStaleControllers(ctx context.Context, age time.Duration) (int6 func (d *DAL) DeregisterRunner(ctx context.Context, key model.RunnerKey) error { count, err := d.db.DeregisterRunner(ctx, key) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } if count == 0 { - return dalerrs.ErrNotFound + return dal.ErrNotFound } return nil } @@ -628,24 +583,27 @@ func (d *DAL) ReserveRunnerForDeployment(ctx context.Context, deployment model.D return nil, fmt.Errorf("failed to JSON encode labels: %w", err) } ctx, cancel := context.WithTimeout(ctx, reservationTimeout) - tx, err := d.db.Begin(ctx) + tx, err := d.Begin(ctx) if err != nil { cancel() - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } - runner, err := tx.ReserveRunner(ctx, time.Now().Add(reservationTimeout), deployment, jsonLabels) + runner, err := tx.db.ReserveRunner(ctx, time.Now().Add(reservationTimeout), deployment, jsonLabels) if err != nil { if rerr := tx.Rollback(context.Background()); rerr != nil { - err = errors.Join(err, dalerrs.TranslatePGError(rerr)) + err = errors.Join(err, dal.TranslatePGError(rerr)) } cancel() - if dalerrs.IsNotFound(err) { - return nil, fmt.Errorf("no idle runners found matching labels %s: %w", jsonLabels, dalerrs.ErrNotFound) + if dal.IsNotFound(err) { + return nil, fmt.Errorf("no idle runners found matching labels %s: %w", jsonLabels, dal.ErrNotFound) } - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } runnerLabels := model.Labels{} if err := json.Unmarshal(runner.Labels, &runnerLabels); err != nil { + if rerr := tx.Rollback(context.Background()); rerr != nil { + err = errors.Join(err, dal.TranslatePGError(rerr)) + } cancel() return nil, fmt.Errorf("failed to JSON decode labels for runner %d: %w", runner.ID, err) } @@ -666,19 +624,19 @@ func (d *DAL) ReserveRunnerForDeployment(ctx context.Context, deployment model.D var _ Reservation = (*postgresClaim)(nil) type postgresClaim struct { - tx *sql.Tx + tx *DAL runner Runner cancel context.CancelFunc } func (p *postgresClaim) Commit(ctx context.Context) error { defer p.cancel() - return dalerrs.TranslatePGError(p.tx.Commit(ctx)) + return dal.TranslatePGError(p.tx.Commit(ctx)) } func (p *postgresClaim) Rollback(ctx context.Context) error { defer p.cancel() - return dalerrs.TranslatePGError(p.tx.Rollback(ctx)) + return dal.TranslatePGError(p.tx.Rollback(ctx)) } func (p *postgresClaim) Runner() Runner { return p.runner } @@ -686,31 +644,30 @@ func (p *postgresClaim) Runner() Runner { return p.runner } // SetDeploymentReplicas activates the given deployment. func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey, minReplicas int) (err error) { // Start the transaction - tx, err := d.db.Begin(ctx) + tx, err := d.Begin(ctx) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } - defer tx.CommitOrRollback(ctx, &err) - deployment, err := d.db.GetDeployment(ctx, key) + deployment, err := tx.db.GetDeployment(ctx, key) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } - err = d.db.SetDeploymentDesiredReplicas(ctx, key, int32(minReplicas)) + err = tx.db.SetDeploymentDesiredReplicas(ctx, key, int32(minReplicas)) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } if minReplicas == 0 { - err = d.deploymentWillDeactivate(ctx, tx, key) + err = tx.deploymentWillDeactivate(ctx, key) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } } else if deployment.MinReplicas == 0 { - err = d.deploymentWillActivate(ctx, tx, key) + err = tx.deploymentWillActivate(ctx, key) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } } var payload encryption.EncryptedTimelineColumn @@ -721,12 +678,12 @@ func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey if err != nil { return fmt.Errorf("failed to encrypt payload for InsertDeploymentUpdatedEvent: %w", err) } - err = tx.InsertTimelineDeploymentUpdatedEvent(ctx, sql.InsertTimelineDeploymentUpdatedEventParams{ + err = tx.db.InsertTimelineDeploymentUpdatedEvent(ctx, sql.InsertTimelineDeploymentUpdatedEventParams{ DeploymentKey: key, Payload: payload, }) if err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } return nil @@ -739,51 +696,51 @@ var ErrReplaceDeploymentAlreadyActive = errors.New("deployment already active") // returns ErrReplaceDeploymentAlreadyActive if the new deployment is already active. func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.DeploymentKey, minReplicas int) (err error) { // Start the transaction - tx, err := d.db.Begin(ctx) + tx, err := d.Begin(ctx) if err != nil { - return fmt.Errorf("replace deployment failed to begin transaction for %v: %w", newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to begin transaction for %v: %w", newDeploymentKey, dal.TranslatePGError(err)) } defer tx.CommitOrRollback(ctx, &err) - newDeployment, err := tx.GetDeployment(ctx, newDeploymentKey) + newDeployment, err := tx.db.GetDeployment(ctx, newDeploymentKey) if err != nil { - return fmt.Errorf("replace deployment failed to get deployment for %v: %w", newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to get deployment for %v: %w", newDeploymentKey, dal.TranslatePGError(err)) } // must be called before deploymentWillDeactivate for the old deployment - err = d.deploymentWillActivate(ctx, tx, newDeploymentKey) + err = tx.deploymentWillActivate(ctx, newDeploymentKey) if err != nil { - return fmt.Errorf("replace deployment failed willActivate trigger for %v: %w", newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed willActivate trigger for %v: %w", newDeploymentKey, dal.TranslatePGError(err)) } // If there's an existing deployment, set its desired replicas to 0 var replacedDeploymentKey optional.Option[model.DeploymentKey] - oldDeployment, err := tx.GetExistingDeploymentForModule(ctx, newDeployment.ModuleName) + oldDeployment, err := tx.db.GetExistingDeploymentForModule(ctx, newDeployment.ModuleName) if err == nil { if oldDeployment.Key.String() == newDeploymentKey.String() { return fmt.Errorf("replace deployment failed: deployment already exists from %v to %v: %w", oldDeployment.Key, newDeploymentKey, ErrReplaceDeploymentAlreadyActive) } - err = tx.SetDeploymentDesiredReplicas(ctx, oldDeployment.Key, 0) + err = tx.db.SetDeploymentDesiredReplicas(ctx, oldDeployment.Key, 0) if err != nil { - return fmt.Errorf("replace deployment failed to set old deployment replicas from %v to %v: %w", oldDeployment.Key, newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to set old deployment replicas from %v to %v: %w", oldDeployment.Key, newDeploymentKey, dal.TranslatePGError(err)) } - err = tx.SetDeploymentDesiredReplicas(ctx, newDeploymentKey, int32(minReplicas)) + err = tx.db.SetDeploymentDesiredReplicas(ctx, newDeploymentKey, int32(minReplicas)) if err != nil { - return fmt.Errorf("replace deployment failed to set new deployment replicas from %v to %v: %w", oldDeployment.Key, newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to set new deployment replicas from %v to %v: %w", oldDeployment.Key, newDeploymentKey, dal.TranslatePGError(err)) } - err = d.deploymentWillDeactivate(ctx, tx, oldDeployment.Key) + err = d.deploymentWillDeactivate(ctx, oldDeployment.Key) if err != nil { - return fmt.Errorf("replace deployment failed willDeactivate trigger from %v to %v: %w", oldDeployment.Key, newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed willDeactivate trigger from %v to %v: %w", oldDeployment.Key, newDeploymentKey, dal.TranslatePGError(err)) } replacedDeploymentKey = optional.Some(oldDeployment.Key) - } else if !dalerrs.IsNotFound(err) { + } else if !dal.IsNotFound(err) { // any error other than not found - return fmt.Errorf("replace deployment failed to get existing deployment for %v: %w", newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to get existing deployment for %v: %w", newDeploymentKey, dal.TranslatePGError(err)) } else { // Set the desired replicas for the new deployment - err = tx.SetDeploymentDesiredReplicas(ctx, newDeploymentKey, int32(minReplicas)) + err = tx.db.SetDeploymentDesiredReplicas(ctx, newDeploymentKey, int32(minReplicas)) if err != nil { - return fmt.Errorf("replace deployment failed to set replicas for %v: %w", newDeploymentKey, dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to set replicas for %v: %w", newDeploymentKey, dal.TranslatePGError(err)) } } @@ -796,14 +753,14 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl return fmt.Errorf("replace deployment failed to encrypt payload: %w", err) } - err = tx.InsertTimelineDeploymentCreatedEvent(ctx, sql.InsertTimelineDeploymentCreatedEventParams{ + err = tx.db.InsertTimelineDeploymentCreatedEvent(ctx, sql.InsertTimelineDeploymentCreatedEventParams{ DeploymentKey: newDeploymentKey, Language: newDeployment.Language, ModuleName: newDeployment.ModuleName, Payload: payload, }) if err != nil { - return fmt.Errorf("replace deployment failed to create event: %w", dalerrs.TranslatePGError(err)) + return fmt.Errorf("replace deployment failed to create event: %w", dal.TranslatePGError(err)) } return nil @@ -813,23 +770,23 @@ func (d *DAL) ReplaceDeployment(ctx context.Context, newDeploymentKey model.Depl // // when replacing a deployment, this should be called first before calling deploymentWillDeactivate on the old deployment. // This allows the new deployment to migrate from the old deployment (such as subscriptions). -func (d *DAL) deploymentWillActivate(ctx context.Context, tx *sql.Tx, key model.DeploymentKey) error { - module, err := tx.GetSchemaForDeployment(ctx, key) +func (d *DAL) deploymentWillActivate(ctx context.Context, key model.DeploymentKey) error { + module, err := d.db.GetSchemaForDeployment(ctx, key) if err != nil { - return fmt.Errorf("could not get schema: %w", dalerrs.TranslatePGError(err)) + return fmt.Errorf("could not get schema: %w", dal.TranslatePGError(err)) } - err = d.createSubscriptions(ctx, tx, key, module) + err = d.createSubscriptions(ctx, key, module) if err != nil { return err } - return d.createSubscribers(ctx, tx, key, module) + return d.createSubscribers(ctx, key, module) } // deploymentWillDeactivate is called whenever a deployment goes to min_replicas=0. // // it may be called when min_replicas was already 0 -func (d *DAL) deploymentWillDeactivate(ctx context.Context, tx *sql.Tx, key model.DeploymentKey) error { - return d.removeSubscriptionsAndSubscribers(ctx, tx, key) +func (d *DAL) deploymentWillDeactivate(ctx context.Context, key model.DeploymentKey) error { + return d.removeSubscriptionsAndSubscribers(ctx, key) } // GetDeploymentsNeedingReconciliation returns deployments that have a @@ -837,10 +794,10 @@ func (d *DAL) deploymentWillDeactivate(ctx context.Context, tx *sql.Tx, key mode func (d *DAL) GetDeploymentsNeedingReconciliation(ctx context.Context) ([]Reconciliation, error) { counts, err := d.db.GetDeploymentsNeedingReconciliation(ctx) if err != nil { - if dalerrs.IsNotFound(err) { + if dal.IsNotFound(err) { return nil, nil } - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.Map(counts, func(t sql.GetDeploymentsNeedingReconciliationRow) Reconciliation { return Reconciliation{ @@ -857,10 +814,10 @@ func (d *DAL) GetDeploymentsNeedingReconciliation(ctx context.Context) ([]Reconc func (d *DAL) GetActiveDeployments(ctx context.Context) ([]Deployment, error) { rows, err := d.db.GetActiveDeployments(ctx) if err != nil { - if dalerrs.IsNotFound(err) { + if dal.IsNotFound(err) { return nil, nil } - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.MapErr(rows, func(in sql.GetActiveDeploymentsRow) (Deployment, error) { return Deployment{ @@ -895,10 +852,10 @@ func (d *DAL) GetActiveSchema(ctx context.Context) (*schema.Schema, error) { func (d *DAL) GetDeploymentsWithMinReplicas(ctx context.Context) ([]Deployment, error) { rows, err := d.db.GetDeploymentsWithMinReplicas(ctx) if err != nil { - if dalerrs.IsNotFound(err) { + if dal.IsNotFound(err) { return nil, nil } - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.MapErr(rows, func(in sql.GetDeploymentsWithMinReplicasRow) (Deployment, error) { return Deployment{ @@ -915,7 +872,7 @@ func (d *DAL) GetDeploymentsWithMinReplicas(ctx context.Context) ([]Deployment, func (d *DAL) GetActiveDeploymentSchemas(ctx context.Context) ([]*schema.Module, error) { rows, err := d.db.GetActiveDeploymentSchemas(ctx) if err != nil { - return nil, fmt.Errorf("could not get active deployments: %w", dalerrs.TranslatePGError(err)) + return nil, fmt.Errorf("could not get active deployments: %w", dal.TranslatePGError(err)) } return slices.MapErr(rows, func(in sql.GetActiveDeploymentSchemasRow) (*schema.Module, error) { return in.Schema, nil }) } @@ -937,7 +894,7 @@ type Process struct { func (d *DAL) GetProcessList(ctx context.Context) ([]Process, error) { rows, err := d.db.GetProcessList(ctx) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.MapErr(rows, func(row sql.GetProcessListRow) (Process, error) { var runner optional.Option[ProcessRunner] @@ -982,10 +939,10 @@ func (d *DAL) GetIdleRunners(ctx context.Context, limit int, labels model.Labels return nil, fmt.Errorf("could not marshal labels: %w", err) } runners, err := d.db.GetIdleRunners(ctx, jsonb, int64(limit)) - if dalerrs.IsNotFound(err) { + if dal.IsNotFound(err) { return nil, nil } else if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.MapErr(runners, func(row sql.Runner) (Runner, error) { rowLabels := model.Labels{} @@ -1010,10 +967,10 @@ func (d *DAL) GetIdleRunners(ctx context.Context, limit int, labels model.Labels func (d *DAL) GetRoutingTable(ctx context.Context, modules []string) (map[string][]Route, error) { routes, err := d.db.GetRoutingTable(ctx, modules) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } if len(routes) == 0 { - return nil, fmt.Errorf("no routes found: %w", dalerrs.ErrNotFound) + return nil, fmt.Errorf("no routes found: %w", dal.ErrNotFound) } out := make(map[string][]Route, len(routes)) for _, route := range routes { @@ -1033,7 +990,7 @@ func (d *DAL) GetRoutingTable(ctx context.Context, modules []string) (map[string func (d *DAL) GetRunnerState(ctx context.Context, runnerKey model.RunnerKey) (RunnerState, error) { state, err := d.db.GetRunnerState(ctx, runnerKey) if err != nil { - return "", dalerrs.TranslatePGError(err) + return "", dal.TranslatePGError(err) } return RunnerState(state), nil } @@ -1041,14 +998,14 @@ func (d *DAL) GetRunnerState(ctx context.Context, runnerKey model.RunnerKey) (Ru func (d *DAL) GetRunner(ctx context.Context, runnerKey model.RunnerKey) (Runner, error) { row, err := d.db.GetRunner(ctx, runnerKey) if err != nil { - return Runner{}, dalerrs.TranslatePGError(err) + return Runner{}, dal.TranslatePGError(err) } return runnerFromDB(row), nil } func (d *DAL) ExpireRunnerClaims(ctx context.Context) (int64, error) { count, err := d.db.ExpireRunnerReservations(ctx) - return count, dalerrs.TranslatePGError(err) + return count, dal.TranslatePGError(err) } func (d *DAL) InsertLogEvent(ctx context.Context, log *LogEvent) error { @@ -1068,7 +1025,7 @@ func (d *DAL) InsertLogEvent(ctx context.Context, log *LogEvent) error { if err != nil { return fmt.Errorf("failed to encrypt log payload: %w", err) } - return dalerrs.TranslatePGError(d.db.InsertTimelineLogEvent(ctx, sql.InsertTimelineLogEventParams{ + return dal.TranslatePGError(d.db.InsertTimelineLogEvent(ctx, sql.InsertTimelineLogEventParams{ DeploymentKey: log.DeploymentKey, RequestKey: requestKey, TimeStamp: log.Time, @@ -1086,7 +1043,7 @@ func (d *DAL) loadDeployment(ctx context.Context, deployment sql.GetDeploymentRo } artefacts, err := d.db.GetDeploymentArtefacts(ctx, deployment.Deployment.ID) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } out.Artefacts = slices.Map(artefacts, func(row sql.GetDeploymentArtefactsRow) *model.Artefact { return &model.Artefact{ @@ -1101,7 +1058,7 @@ func (d *DAL) loadDeployment(ctx context.Context, deployment sql.GetDeploymentRo func (d *DAL) CreateRequest(ctx context.Context, key model.RequestKey, addr string) error { if err := d.db.CreateRequest(ctx, sql.Origin(key.Payload.Origin), key, addr); err != nil { - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } return nil } @@ -1109,10 +1066,10 @@ func (d *DAL) CreateRequest(ctx context.Context, key model.RequestKey, addr stri func (d *DAL) GetIngressRoutes(ctx context.Context, method string) ([]IngressRoute, error) { routes, err := d.db.GetIngressRoutes(ctx, method) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } if len(routes) == 0 { - return nil, dalerrs.ErrNotFound + return nil, dal.ErrNotFound } return slices.Map(routes, func(row sql.GetIngressRoutesRow) IngressRoute { return IngressRoute{ @@ -1128,7 +1085,7 @@ func (d *DAL) GetIngressRoutes(ctx context.Context, method string) ([]IngressRou func (d *DAL) UpsertController(ctx context.Context, key model.ControllerKey, addr string) (int64, error) { id, err := d.db.UpsertController(ctx, key, addr) - return id, dalerrs.TranslatePGError(err) + return id, dal.TranslatePGError(err) } func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { @@ -1155,7 +1112,7 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { if err != nil { return fmt.Errorf("failed to encrypt call payload: %w", err) } - return dalerrs.TranslatePGError(d.db.InsertTimelineCallEvent(ctx, sql.InsertTimelineCallEventParams{ + return dal.TranslatePGError(d.db.InsertTimelineCallEvent(ctx, sql.InsertTimelineCallEventParams{ DeploymentKey: call.DeploymentKey, RequestKey: requestKey, ParentRequestKey: parentRequestKey, @@ -1170,13 +1127,13 @@ func (d *DAL) InsertCallEvent(ctx context.Context, call *CallEvent) error { func (d *DAL) DeleteOldEvents(ctx context.Context, eventType EventType, age time.Duration) (int64, error) { count, err := d.db.DeleteOldTimelineEvents(ctx, sqltypes.Duration(age), eventType) - return count, dalerrs.TranslatePGError(err) + return count, dal.TranslatePGError(err) } func (d *DAL) GetActiveRunners(ctx context.Context) ([]Runner, error) { rows, err := d.db.GetActiveRunners(ctx) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return slices.Map(rows, func(row sql.GetActiveRunnersRow) Runner { return runnerFromDB(sql.GetRunnerRow(row)) @@ -1184,12 +1141,12 @@ func (d *DAL) GetActiveRunners(ctx context.Context) ([]Runner, error) { } // Check if a deployment exists that exactly matches the given artefacts and schema. -func (*DAL) checkForExistingDeployments(ctx context.Context, tx *sql.Tx, moduleSchema *schema.Module, artefacts []DeploymentArtefact) (model.DeploymentKey, error) { +func (*DAL) checkForExistingDeployments(ctx context.Context, tx *DAL, moduleSchema *schema.Module, artefacts []DeploymentArtefact) (model.DeploymentKey, error) { schemaBytes, err := schema.ModuleToBytes(moduleSchema) if err != nil { return model.DeploymentKey{}, fmt.Errorf("failed to marshal schema: %w", err) } - existing, err := tx.GetDeploymentsWithArtefacts(ctx, + existing, err := tx.db.GetDeploymentsWithArtefacts(ctx, sha256esToBytes(slices.Map(artefacts, func(in DeploymentArtefact) sha256.SHA256 { return in.Digest })), schemaBytes, int64(len(artefacts)), @@ -1209,7 +1166,7 @@ func sha256esToBytes(digests []sha256.SHA256) [][]byte { type artefactReader struct { id int64 - db sql.DBI + db sql.Querier offset int32 } @@ -1218,7 +1175,7 @@ func (r *artefactReader) Close() error { return nil } func (r *artefactReader) Read(p []byte) (n int, err error) { content, err := r.db.GetArtefactContentRange(context.Background(), r.offset+1, int32(len(p)), r.id) if err != nil { - return 0, dalerrs.TranslatePGError(err) + return 0, dal.TranslatePGError(err) } copy(p, content) clen := len(content) diff --git a/backend/controller/dal/dal_test.go b/backend/controller/dal/dal_test.go index 6dab71cf0b..669ebb6afe 100644 --- a/backend/controller/dal/dal_test.go +++ b/backend/controller/dal/dal_test.go @@ -29,7 +29,7 @@ func TestDAL(t *testing.T) { conn := sqltest.OpenForTesting(ctx, t) dal, err := New(ctx, conn, encryption.NewBuilder()) assert.NoError(t, err) - assert.NotZero(t, dal) + var testContent = bytes.Repeat([]byte("sometestcontentthatislongerthanthereadbuffer"), 100) var testSHA = sha256.Sum(testContent) diff --git a/backend/controller/dal/encryption.go b/backend/controller/dal/encryption.go index 51af56ca0f..844f0b5231 100644 --- a/backend/controller/dal/encryption.go +++ b/backend/controller/dal/encryption.go @@ -13,10 +13,6 @@ import ( ) func (d *DAL) encrypt(cleartext []byte, dest encryption.Encrypted) error { - if d.encryptor == nil { - return fmt.Errorf("encryptor not set") - } - err := d.encryptor.Encrypt(cleartext, dest) if err != nil { return fmt.Errorf("failed to encrypt binary with subkey %s: %w", dest.SubKey(), err) @@ -26,10 +22,6 @@ func (d *DAL) encrypt(cleartext []byte, dest encryption.Encrypted) error { } func (d *DAL) decrypt(encrypted encryption.Encrypted) ([]byte, error) { - if d.encryptor == nil { - return nil, fmt.Errorf("encryptor not set") - } - v, err := d.encryptor.Decrypt(encrypted) if err != nil { return nil, fmt.Errorf("failed to decrypt binary with subkey %s: %w", encrypted.SubKey(), err) @@ -93,9 +85,8 @@ func (d *DAL) EnsureKey(ctx context.Context, generateKey func() ([]byte, error)) const verification = "FTL - Towards a 𝝺-calculus for large-scale systems" -func (d *DAL) verifyEncryptor(ctx context.Context) (err error) { - var tx *Tx - tx, err = d.Begin(ctx) +func (d *DAL) verifyEncryptor(ctx context.Context, encryptor encryption.DataEncryptor) (err error) { + tx, err := d.Begin(ctx) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } @@ -111,7 +102,7 @@ func (d *DAL) verifyEncryptor(ctx context.Context) (err error) { } needsUpdate := false - newTimeline, err := verifySubkey(d.encryptor, row.VerifyTimeline) + newTimeline, err := verifySubkey(encryptor, row.VerifyTimeline) if err != nil { return fmt.Errorf("failed to verify timeline subkey: %w", err) } @@ -120,7 +111,7 @@ func (d *DAL) verifyEncryptor(ctx context.Context) (err error) { row.VerifyTimeline = optional.Some(newTimeline) } - newAsync, err := verifySubkey(d.encryptor, row.VerifyAsync) + newAsync, err := verifySubkey(encryptor, row.VerifyAsync) if err != nil { return fmt.Errorf("failed to verify async subkey: %w", err) } diff --git a/backend/controller/dal/events.go b/backend/controller/dal/events.go index 85d0c0bb8c..3c26ade38b 100644 --- a/backend/controller/dal/events.go +++ b/backend/controller/dal/events.go @@ -260,7 +260,7 @@ func (d *DAL) QueryTimeline(ctx context.Context, limit int, filters ...TimelineF deploymentQuery += ` WHERE key = ANY($1::TEXT[])` deploymentArgs = append(deploymentArgs, filter.deployments) } - rows, err := d.db.Conn().QueryContext(ctx, deploymentQuery, deploymentArgs...) + rows, err := d.Handle.Connection.QueryContext(ctx, deploymentQuery, deploymentArgs...) if err != nil { return nil, dalerrs.TranslatePGError(err) } @@ -316,7 +316,7 @@ func (d *DAL) QueryTimeline(ctx context.Context, limit int, filters ...TimelineF q += fmt.Sprintf(" LIMIT %d", limit) // Issue query. - rows, err = d.db.Conn().QueryContext(ctx, q, args...) + rows, err = d.Handle.Connection.QueryContext(ctx, q, args...) if err != nil { return nil, fmt.Errorf("%s: %w", q, dalerrs.TranslatePGError(err)) } diff --git a/backend/controller/dal/fsm_test.go b/backend/controller/dal/fsm_test.go index 957a3c3f4e..cb12000118 100644 --- a/backend/controller/dal/fsm_test.go +++ b/backend/controller/dal/fsm_test.go @@ -52,7 +52,7 @@ func TestSendFSMEvent(t *testing.T) { } assert.Equal(t, expectedCall, call, assert.Exclude[*Lease](), assert.Exclude[time.Time]()) - _, err = dal.CompleteAsyncCall(ctx, call, either.LeftOf[string]([]byte(`{}`)), func(tx *Tx, isFinalResult bool) error { return nil }) + _, err = dal.CompleteAsyncCall(ctx, call, either.LeftOf[string]([]byte(`{}`)), func(tx *DAL, isFinalResult bool) error { return nil }) assert.NoError(t, err) actual, err := dal.LoadAsyncCall(ctx, call.ID) diff --git a/backend/controller/dal/lease.go b/backend/controller/dal/lease.go index 8bbc1fd5b0..23634173a1 100644 --- a/backend/controller/dal/lease.go +++ b/backend/controller/dal/lease.go @@ -24,7 +24,7 @@ var _ leases.Leaser = (*DAL)(nil) type Lease struct { key leases.Key idempotencyKey uuid.UUID - db sql.DBI + db sql.Querier ttl time.Duration errch chan error release chan bool diff --git a/backend/controller/dal/lease_test.go b/backend/controller/dal/lease_test.go index 34b58e58a3..182dc8275b 100644 --- a/backend/controller/dal/lease_test.go +++ b/backend/controller/dal/lease_test.go @@ -11,20 +11,19 @@ import ( "github.com/google/uuid" "github.com/TBD54566975/ftl/backend/controller/leases" - "github.com/TBD54566975/ftl/backend/controller/sql" "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" - dalerrs "github.com/TBD54566975/ftl/backend/dal" + "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/internal/encryption" "github.com/TBD54566975/ftl/internal/log" ) -func leaseExists(t *testing.T, conn sql.ConnI, idempotencyKey uuid.UUID, key leases.Key) bool { +func leaseExists(t *testing.T, conn dal.Connection, idempotencyKey uuid.UUID, key leases.Key) bool { t.Helper() var count int - err := dalerrs.TranslatePGError(conn. + err := dal.TranslatePGError(conn. QueryRowContext(context.Background(), "SELECT COUNT(*) FROM leases WHERE idempotency_key = $1 AND key = $2", idempotencyKey, key). Scan(&count)) - if errors.Is(err, dalerrs.ErrNotFound) { + if errors.Is(err, dal.ErrNotFound) { return false } assert.NoError(t, err) diff --git a/backend/controller/dal/pubsub.go b/backend/controller/dal/pubsub.go index 924ef17834..bee0d36528 100644 --- a/backend/controller/dal/pubsub.go +++ b/backend/controller/dal/pubsub.go @@ -177,13 +177,13 @@ func (d *DAL) CompleteEventForSubscription(ctx context.Context, module, name str // ResetSubscription resets the subscription cursor to the topic's head. func (d *DAL) ResetSubscription(ctx context.Context, module, name string) (err error) { - tx, err := d.db.Begin(ctx) + tx, err := d.Begin(ctx) if err != nil { return fmt.Errorf("could not start transaction: %w", err) } defer tx.CommitOrRollback(ctx, &err) - subscription, err := tx.GetSubscription(ctx, name, module) + subscription, err := tx.db.GetSubscription(ctx, name, module) if err != nil { if dalerrs.IsNotFound(err) { return fmt.Errorf("subscription %s.%s not found", module, name) @@ -191,7 +191,7 @@ func (d *DAL) ResetSubscription(ctx context.Context, module, name string) (err e return fmt.Errorf("could not fetch subscription: %w", dalerrs.TranslatePGError(err)) } - topic, err := tx.GetTopic(ctx, subscription.TopicID) + topic, err := tx.db.GetTopic(ctx, subscription.TopicID) if err != nil { return fmt.Errorf("could not fetch topic: %w", dalerrs.TranslatePGError(err)) } @@ -201,12 +201,12 @@ func (d *DAL) ResetSubscription(ctx context.Context, module, name string) (err e return fmt.Errorf("no events published to topic %s", topic.Name) } - headEvent, err := tx.GetTopicEvent(ctx, headEventID) + headEvent, err := tx.db.GetTopicEvent(ctx, headEventID) if err != nil { return fmt.Errorf("could not fetch topic head: %w", dalerrs.TranslatePGError(err)) } - err = tx.SetSubscriptionCursor(ctx, subscription.Key, headEvent.Key) + err = tx.db.SetSubscriptionCursor(ctx, subscription.Key, headEvent.Key) if err != nil { return fmt.Errorf("failed to reset subscription: %w", dalerrs.TranslatePGError(err)) } @@ -214,7 +214,7 @@ func (d *DAL) ResetSubscription(ctx context.Context, module, name string) (err e return nil } -func (d *DAL) createSubscriptions(ctx context.Context, tx *sql.Tx, key model.DeploymentKey, module *schema.Module) error { +func (d *DAL) createSubscriptions(ctx context.Context, key model.DeploymentKey, module *schema.Module) error { logger := log.FromContext(ctx) for _, decl := range module.Decls { @@ -232,7 +232,7 @@ func (d *DAL) createSubscriptions(ctx context.Context, tx *sql.Tx, key model.Dep continue } subscriptionKey := model.NewSubscriptionKey(module.Name, s.Name) - result, err := tx.UpsertSubscription(ctx, sql.UpsertSubscriptionParams{ + result, err := d.db.UpsertSubscription(ctx, sql.UpsertSubscriptionParams{ Key: subscriptionKey, Module: module.Name, Deployment: key, @@ -271,7 +271,7 @@ func hasSubscribers(subscription *schema.Subscription, decls []schema.Decl) bool return false } -func (d *DAL) createSubscribers(ctx context.Context, tx *sql.Tx, key model.DeploymentKey, module *schema.Module) error { +func (d *DAL) createSubscribers(ctx context.Context, key model.DeploymentKey, module *schema.Module) error { logger := log.FromContext(ctx) for _, decl := range module.Decls { v, ok := decl.(*schema.Verb) @@ -296,7 +296,7 @@ func (d *DAL) createSubscribers(ctx context.Context, tx *sql.Tx, key model.Deplo } } subscriberKey := model.NewSubscriberKey(module.Name, s.Name, v.Name) - err = tx.InsertSubscriber(ctx, sql.InsertSubscriberParams{ + err = d.db.InsertSubscriber(ctx, sql.InsertSubscriberParams{ Key: subscriberKey, Module: module.Name, SubscriptionName: s.Name, @@ -316,10 +316,10 @@ func (d *DAL) createSubscribers(ctx context.Context, tx *sql.Tx, key model.Deplo return nil } -func (d *DAL) removeSubscriptionsAndSubscribers(ctx context.Context, tx *sql.Tx, key model.DeploymentKey) error { +func (d *DAL) removeSubscriptionsAndSubscribers(ctx context.Context, key model.DeploymentKey) error { logger := log.FromContext(ctx) - subscribers, err := tx.DeleteSubscribers(ctx, key) + subscribers, err := d.db.DeleteSubscribers(ctx, key) if err != nil { return fmt.Errorf("could not delete old subscribers: %w", dalerrs.TranslatePGError(err)) } @@ -327,7 +327,7 @@ func (d *DAL) removeSubscriptionsAndSubscribers(ctx context.Context, tx *sql.Tx, logger.Debugf("Deleted subscribers for %s: %s", key, strings.Join(slices.Map(subscribers, func(key model.SubscriberKey) string { return key.String() }), ", ")) } - subscriptions, err := tx.DeleteSubscriptions(ctx, key) + subscriptions, err := d.db.DeleteSubscriptions(ctx, key) if err != nil { return fmt.Errorf("could not delete old subscriptions: %w", dalerrs.TranslatePGError(err)) } diff --git a/backend/controller/pubsub/manager.go b/backend/controller/pubsub/manager.go index 23c34ebe59..9a7e340e33 100644 --- a/backend/controller/pubsub/manager.go +++ b/backend/controller/pubsub/manager.go @@ -66,7 +66,7 @@ func (m *Manager) progressSubscriptions(ctx context.Context) (time.Duration, err } // OnCallCompletion is called within a transaction after an async call has completed to allow the subscription state to be updated. -func (m *Manager) OnCallCompletion(ctx context.Context, tx *dal.Tx, origin dal.AsyncOriginPubSub, failed bool, isFinalResult bool) error { +func (m *Manager) OnCallCompletion(ctx context.Context, tx *dal.DAL, origin dal.AsyncOriginPubSub, failed bool, isFinalResult bool) error { if !isFinalResult { // Wait for the async call's retries to complete before progressing the subscription return nil diff --git a/backend/controller/sql/conn.go b/backend/controller/sql/conn.go deleted file mode 100644 index d3a29c87cf..0000000000 --- a/backend/controller/sql/conn.go +++ /dev/null @@ -1,94 +0,0 @@ -package sql - -import ( - "context" - "database/sql" - "errors" - "fmt" -) - -type DBI interface { - Querier - Conn() ConnI - Begin(ctx context.Context) (*Tx, error) -} - -type ConnI interface { - DBTX - Begin() (*sql.Tx, error) -} - -type DB struct { - conn ConnI - *Queries -} - -func NewDB(conn ConnI) *DB { - return &DB{conn: conn, Queries: New(conn)} -} - -func (d *DB) Conn() ConnI { return d.conn } - -func (d *DB) Begin(ctx context.Context) (*Tx, error) { - tx, err := d.conn.Begin() - if err != nil { - return nil, err - } - return &Tx{tx: tx, Queries: New(tx)}, nil -} - -type noopSubConn struct { - DBTX -} - -func (noopSubConn) Begin() (*sql.Tx, error) { - return nil, errors.New("sql: not implemented") -} - -type Tx struct { - tx *sql.Tx - *Queries -} - -func (t *Tx) Conn() ConnI { return noopSubConn{t.tx} } - -func (t *Tx) Tx() *sql.Tx { return t.tx } - -func (t *Tx) Begin(ctx context.Context) (*Tx, error) { - return nil, fmt.Errorf("cannot nest transactions") -} - -func (t *Tx) Commit(ctx context.Context) error { - err := t.tx.Commit() - if err != nil { - return fmt.Errorf("committing transaction: %w", err) - } - - return nil -} - -func (t *Tx) Rollback(ctx context.Context) error { - err := t.tx.Rollback() - if err != nil { - return fmt.Errorf("rolling back transaction: %w", err) - } - - return nil -} - -// CommitOrRollback can be used in a defer statement to commit or rollback a -// transaction depending on whether the enclosing function returned an error. -// -// func myFunc() (err error) { -// tx, err := db.Begin(ctx) -// if err != nil { return err } -// defer tx.CommitOrRollback(ctx, &err) -// ... -// } -func (t *Tx) CommitOrRollback(ctx context.Context, err *error) { - if *err != nil { - *err = errors.Join(*err, t.Rollback(ctx)) - } else { - *err = t.Commit(ctx) - } -} diff --git a/backend/dal/dal.go b/backend/dal/dal.go new file mode 100644 index 0000000000..14fcb3de47 --- /dev/null +++ b/backend/dal/dal.go @@ -0,0 +1,132 @@ +package dal + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync/atomic" +) + +// Counters for testing. +var ( + testCommitCounter atomic.Int64 + testRollbackCounter atomic.Int64 +) + +// Connection is a common interface for *sql.DB and *sql.Tx. +type Connection interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// MakeWithHandle is a function that can be used to create a new T with a SQLHandle. +type MakeWithHandle[T any] func(*Handle[T]) *T + +// Handle is a wrapper around a database connection that can be embedded within a struct to provide access +// to a database connection and methods for managing transactions. +type Handle[T any] struct { + Connection Connection + txCounter int64 + Make MakeWithHandle[T] +} + +// New creates a new Handle +func New[T any](sql Connection, fn MakeWithHandle[T]) *Handle[T] { + return &Handle[T]{Connection: sql, Make: fn} +} + +// Begin creates a new transaction or increments the transaction counter if the handle is already in a transaction. +// +// In all cases a new handle is returned. +func (h *Handle[T]) Begin(ctx context.Context) (*T, error) { + switch conn := h.Connection.(type) { + case *sql.DB: + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to begin transaction: %w", TranslatePGError(err)) + } + + txConn := &Handle[T]{Connection: tx, Make: h.Make} + return h.Make(txConn), nil + + case *sql.Tx: + sub := &Handle[T]{Connection: conn, Make: h.Make, txCounter: h.txCounter + 1} + _, err := conn.ExecContext(ctx, fmt.Sprintf("SAVEPOINT sp%d", sub.txCounter)) + if err != nil { + return nil, fmt.Errorf("failed to begin savepoint: %w", TranslatePGError(err)) + } + + return h.Make(sub), nil + default: + return nil, errors.New("invalid connection type") + } +} + +// CommitOrRollback commits the transaction if err is nil, otherwise rolls it back. +// +// Use it in a defer like so, particularly taking note of named return value +// `(err error)`. Without this it will not work. +// +// func (d *DAL) SomeMethod() (err error) { +// tx, err := d.Begin() +// if err != nil { return err } +// defer tx.CommitOrRollback(&err) +// // ... +// return nil +// } +func (h *Handle[T]) CommitOrRollback(ctx context.Context, err *error) { + _, ok := h.Connection.(*sql.Tx) + if !ok { + *err = errors.New("can only commit or rollback a transaction") + return + } + + if h.txCounter > 0 { + return + } + + if *err != nil { + *err = errors.Join(*err, h.Rollback(ctx)) + } else { + *err = h.Commit(ctx) + } +} + +// Commit the transaction or savepoint. +func (h *Handle[T]) Commit(ctx context.Context) error { + sqlTx, ok := h.Connection.(*sql.Tx) + if !ok { + return errors.New("can only commit or rollback a transaction") + } + + testCommitCounter.Add(1) + if h.txCounter == 0 { + return TranslatePGError(sqlTx.Commit()) + } + _, err := sqlTx.Exec(fmt.Sprintf("RELEASE SAVEPOINT sp%d", h.txCounter)) + if err != nil { + return TranslatePGError(err) + } + return nil +} + +// Rollback the transaction or savepoint. +func (h *Handle[T]) Rollback(ctx context.Context) error { + sqlTx, ok := h.Connection.(*sql.Tx) + if !ok { + return errors.New("can only commit or rollback a transaction") + } + + testRollbackCounter.Add(1) + if h.txCounter == 0 { + return TranslatePGError(sqlTx.Rollback()) + } + _, err := sqlTx.Exec(fmt.Sprintf("ROLLBACK TO SAVEPOINT sp%d", h.txCounter)) + if err != nil { + return TranslatePGError(err) + } + return nil +} diff --git a/backend/dal/dal_test.go b/backend/dal/dal_test.go new file mode 100644 index 0000000000..22140afe47 --- /dev/null +++ b/backend/dal/dal_test.go @@ -0,0 +1,174 @@ +package dal + +import ( + "context" + "database/sql" + "errors" + "fmt" + "testing" + + "github.com/alecthomas/assert/v2" + _ "modernc.org/sqlite" // Pure Go SQLite driver. +) + +type DAL struct { + *Handle[DAL] +} + +// New creates a new Data Access Layer instance. +func NewConn(sqlConn *sql.DB) *DAL { + return NewWithConn(New(sqlConn, NewWithConn)) +} + +func NewWithConn(conn *Handle[DAL]) *DAL { + return &DAL{conn} +} + +func (d *DAL) CreateUser(ctx context.Context, username string, name string) error { + _, err := d.Connection.ExecContext(ctx, ` + INSERT INTO users (username, name) + VALUES ($1, $2) + `, username, name) + if err != nil { + return fmt.Errorf("create user %s: %w", username, err) + } + return nil +} + +func (d *DAL) CreateUsers(ctx context.Context, users [][]string) (err error) { + txn, err := d.Begin(ctx) + if err != nil { + return err + } + + defer txn.CommitOrRollback(ctx, &err) + + for _, user := range users { + err = txn.CreateUser(ctx, user[0], user[1]) + if err != nil { + return err + } + } + + return err +} + +func (d *DAL) GetUserByUsername(ctx context.Context, username string) (string, error) { + var user string + err := d.Connection.QueryRowContext(ctx, ` + SELECT name + FROM users + WHERE username = $1 + `, username).Scan(&user) + if err != nil { + return user, fmt.Errorf("user by username %s: %w", username, err) + } + return user, nil +} + +func TestDAL(t *testing.T) { + for _, test := range []struct { + name string + fn func(ctx context.Context, t *testing.T, dal *DAL) + }{ + {"WriteAndRead", func(ctx context.Context, t *testing.T, dal *DAL) { + err := dal.CreateUser(ctx, "bob", "Bob Smith") + assert.NoError(t, err) + + user, err := dal.GetUserByUsername(ctx, "bob") + assert.NoError(t, err) + assert.Equal(t, "Bob Smith", user) + }}, + {"CommitOrRollbackWillRollbackOnError", func(ctx context.Context, t *testing.T, dal *DAL) { + f := func() (err error) { + tx, err := dal.Begin(ctx) + assert.NoError(t, err) + defer tx.CommitOrRollback(ctx, &err) + + err = tx.CreateUser(ctx, "bob", "Bob Smith") + assert.NoError(t, err) + + return errors.New("some error") + } + + err := f() + assert.EqualError(t, err, "some error") + assert.Equal(t, 1, testRollbackCounter.Load()) + assert.Equal(t, 0, testCommitCounter.Load()) + + _, err = dal.GetUserByUsername(ctx, "bob") + assert.IsError(t, err, sql.ErrNoRows) + }}, + {"CommitOrRollbackWillCommitOnSuccess", func(ctx context.Context, t *testing.T, dal *DAL) { + f := func() (err error) { + tx, err := dal.Begin(ctx) + assert.NoError(t, err) + defer tx.CommitOrRollback(ctx, &err) + + err = tx.CreateUser(ctx, "bob", "Bob Smith") + assert.NoError(t, err) + + return nil + } + + err := f() + assert.NoError(t, err) + assert.Equal(t, 0, testRollbackCounter.Load()) + assert.Equal(t, 1, testCommitCounter.Load()) + + user, err := dal.GetUserByUsername(ctx, "bob") + assert.NoError(t, err) + assert.Equal(t, "Bob Smith", user) + }}, + {"TestMultipleTxn", func(ctx context.Context, t *testing.T, dal *DAL) { + f := func() (err error) { + tx, err := dal.Begin(ctx) + assert.NoError(t, err) + defer tx.CommitOrRollback(ctx, &err) + + err = tx.CreateUser(ctx, "bob", "Bob Smith") + assert.NoError(t, err) + + err = tx.CreateUsers(ctx, [][]string{ + {"randy", "Randy McRando"}, + {"hehe", "Jimmy DROP TABLES"}, + }) + + assert.NoError(t, err) + + return nil + } + + err := f() + assert.NoError(t, err) + assert.Equal(t, 0, testRollbackCounter.Load()) + assert.Equal(t, 1, testCommitCounter.Load()) + + user, err := dal.GetUserByUsername(ctx, "bob") + assert.NoError(t, err) + assert.Equal(t, "Bob Smith", user) + + user2, err := dal.GetUserByUsername(ctx, "randy") + assert.NoError(t, err) + assert.Equal(t, "Randy McRando", user2) + + user3, err := dal.GetUserByUsername(ctx, "hehe") + assert.NoError(t, err) + assert.Equal(t, "Jimmy DROP TABLES", user3) + }}, + } { + t.Run(test.name, func(t *testing.T) { + t.Cleanup(func() { + testRollbackCounter.Store(0) + testCommitCounter.Store(0) + }) + ctx := context.Background() + db, err := sql.Open("sqlite", ":memory:") + assert.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, db.Close()) }) + _, err = db.Exec(`CREATE TABLE users (username TEXT, name TEXT)`) + assert.NoError(t, err) + test.fn(ctx, t, NewConn(db)) + }) + } +} diff --git a/cmd/ftl-controller/main.go b/cmd/ftl-controller/main.go index fda7db9e4c..d95d054855 100644 --- a/cmd/ftl-controller/main.go +++ b/cmd/ftl-controller/main.go @@ -61,7 +61,7 @@ func main() { dal, err := dal.New(ctx, conn, encryptionBuilder) kctx.FatalIfErrorf(err) - configDal, err := cfdal.New(ctx, conn) + configDal := cfdal.New(conn) kctx.FatalIfErrorf(err) configProviders := []cf.Provider[cf.Configuration]{cf.NewDBConfigProvider(configDal)} configResolver := cf.NewDBConfigResolver(configDal) diff --git a/go-runtime/schema/testdata/failing/go.sum b/go-runtime/schema/testdata/failing/go.sum index a0bde06e0d..34b6a7ddd4 100644 --- a/go-runtime/schema/testdata/failing/go.sum +++ b/go-runtime/schema/testdata/failing/go.sum @@ -134,6 +134,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= diff --git a/go-runtime/schema/testdata/validation/go.sum b/go-runtime/schema/testdata/validation/go.sum index a0bde06e0d..34b6a7ddd4 100644 --- a/go-runtime/schema/testdata/validation/go.sum +++ b/go-runtime/schema/testdata/validation/go.sum @@ -134,6 +134,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= diff --git a/internal/buildengine/testdata/alpha/go.sum b/internal/buildengine/testdata/alpha/go.sum index a0bde06e0d..34b6a7ddd4 100644 --- a/internal/buildengine/testdata/alpha/go.sum +++ b/internal/buildengine/testdata/alpha/go.sum @@ -134,6 +134,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= diff --git a/internal/buildengine/testdata/another/go.sum b/internal/buildengine/testdata/another/go.sum index a0bde06e0d..34b6a7ddd4 100644 --- a/internal/buildengine/testdata/another/go.sum +++ b/internal/buildengine/testdata/another/go.sum @@ -134,6 +134,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= diff --git a/internal/buildengine/testdata/other/go.sum b/internal/buildengine/testdata/other/go.sum index a0bde06e0d..34b6a7ddd4 100644 --- a/internal/buildengine/testdata/other/go.sum +++ b/internal/buildengine/testdata/other/go.sum @@ -134,6 +134,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= diff --git a/internal/configuration/dal/dal.go b/internal/configuration/dal/dal.go index 0496e09f7a..b77ea16287 100644 --- a/internal/configuration/dal/dal.go +++ b/internal/configuration/dal/dal.go @@ -7,41 +7,46 @@ import ( "github.com/alecthomas/types/optional" - dalerrs "github.com/TBD54566975/ftl/backend/dal" + "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/internal/configuration/sql" ) type DAL struct { - db sql.DBI + *dal.Handle[DAL] + db sql.Querier } -func New(ctx context.Context, conn sql.ConnI) (*DAL, error) { - dal := &DAL{db: sql.NewDB(conn)} - return dal, nil +func New(conn dal.Connection) *DAL { + return &DAL{ + db: sql.New(conn), + Handle: dal.New(conn, func(h *dal.Handle[DAL]) *DAL { + return &DAL{Handle: h, db: sql.New(h.Connection)} + }), + } } func (d *DAL) GetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) ([]byte, error) { b, err := d.db.GetModuleConfiguration(ctx, module, name) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return b, nil } func (d *DAL) SetModuleConfiguration(ctx context.Context, module optional.Option[string], name string, value []byte) error { err := d.db.SetModuleConfiguration(ctx, module, name, value) - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } func (d *DAL) UnsetModuleConfiguration(ctx context.Context, module optional.Option[string], name string) error { err := d.db.UnsetModuleConfiguration(ctx, module, name) - return dalerrs.TranslatePGError(err) + return dal.TranslatePGError(err) } func (d *DAL) ListModuleConfiguration(ctx context.Context) ([]sql.ModuleConfiguration, error) { l, err := d.db.ListModuleConfiguration(ctx) if err != nil { - return nil, dalerrs.TranslatePGError(err) + return nil, dal.TranslatePGError(err) } return l, nil } @@ -49,7 +54,7 @@ func (d *DAL) ListModuleConfiguration(ctx context.Context) ([]sql.ModuleConfigur func (d *DAL) GetModuleSecretURL(ctx context.Context, module optional.Option[string], name string) (string, error) { b, err := d.db.GetModuleSecretURL(ctx, module, name) if err != nil { - return "", fmt.Errorf("could not get secret URL: %w", dalerrs.TranslatePGError(err)) + return "", fmt.Errorf("could not get secret URL: %w", dal.TranslatePGError(err)) } return b, nil } @@ -57,7 +62,7 @@ func (d *DAL) GetModuleSecretURL(ctx context.Context, module optional.Option[str func (d *DAL) SetModuleSecretURL(ctx context.Context, module optional.Option[string], name string, url string) error { err := d.db.SetModuleSecretURL(ctx, module, name, url) if err != nil { - return fmt.Errorf("could not set secret URL: %w", dalerrs.TranslatePGError(err)) + return fmt.Errorf("could not set secret URL: %w", dal.TranslatePGError(err)) } return nil } @@ -65,7 +70,7 @@ func (d *DAL) SetModuleSecretURL(ctx context.Context, module optional.Option[str func (d *DAL) UnsetModuleSecret(ctx context.Context, module optional.Option[string], name string) error { err := d.db.UnsetModuleSecret(ctx, module, name) if err != nil { - return fmt.Errorf("could not unset secret: %w", dalerrs.TranslatePGError(err)) + return fmt.Errorf("could not unset secret: %w", dal.TranslatePGError(err)) } return nil } @@ -75,7 +80,7 @@ type ModuleSecret sql.ModuleSecret func (d *DAL) ListModuleSecrets(ctx context.Context) ([]ModuleSecret, error) { l, err := d.db.ListModuleSecrets(ctx) if err != nil { - return nil, fmt.Errorf("could not list secrets: %w", dalerrs.TranslatePGError(err)) + return nil, fmt.Errorf("could not list secrets: %w", dal.TranslatePGError(err)) } // Convert []sql.ModuleSecret to []ModuleSecret diff --git a/internal/configuration/dal/dal_test.go b/internal/configuration/dal/dal_test.go index b555e7489c..d673648693 100644 --- a/internal/configuration/dal/dal_test.go +++ b/internal/configuration/dal/dal_test.go @@ -6,22 +6,21 @@ import ( "fmt" "testing" + "github.com/alecthomas/assert/v2" + "github.com/alecthomas/types/optional" + "github.com/TBD54566975/ftl/backend/controller/sql/sqltest" libdal "github.com/TBD54566975/ftl/backend/dal" "github.com/TBD54566975/ftl/internal/log" - "github.com/alecthomas/assert/v2" - "github.com/alecthomas/types/optional" ) func TestDALConfiguration(t *testing.T) { t.Run("ModuleConfiguration", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleConfiguration(ctx, optional.Some("echo"), "my_config", []byte(`""`)) + err := dal.SetModuleConfiguration(ctx, optional.Some("echo"), "my_config", []byte(`""`)) assert.NoError(t, err) value, err := dal.GetModuleConfiguration(ctx, optional.Some("echo"), "my_config") @@ -40,11 +39,9 @@ func TestDALConfiguration(t *testing.T) { t.Run("GlobalConfiguration", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`""`)) + err := dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`""`)) assert.NoError(t, err) value, err := dal.GetModuleConfiguration(ctx, optional.None[string](), "my_config") @@ -64,11 +61,9 @@ func TestDALConfiguration(t *testing.T) { t.Run("SetSameGlobalConfigTwice", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`""`)) + err := dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`""`)) assert.NoError(t, err) err = dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`"hehe"`)) @@ -82,11 +77,9 @@ func TestDALConfiguration(t *testing.T) { t.Run("SetModuleOverridesGlobal", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`""`)) + err := dal.SetModuleConfiguration(ctx, optional.None[string](), "my_config", []byte(`""`)) assert.NoError(t, err) err = dal.SetModuleConfiguration(ctx, optional.Some("echo"), "my_config", []byte(`"hehe"`)) assert.NoError(t, err) @@ -99,11 +92,9 @@ func TestDALConfiguration(t *testing.T) { t.Run("HandlesConflicts", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleConfiguration(ctx, optional.Some("echo"), "my_config", []byte(`""`)) + err := dal.SetModuleConfiguration(ctx, optional.Some("echo"), "my_config", []byte(`""`)) assert.NoError(t, err) err = dal.SetModuleConfiguration(ctx, optional.Some("echo"), "my_config", []byte(`""`)) assert.NoError(t, err) @@ -123,11 +114,9 @@ func TestDALSecrets(t *testing.T) { t.Run("ModuleSecret", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleSecretURL(ctx, optional.Some("echo"), "my_secret", "http://example.com") + err := dal.SetModuleSecretURL(ctx, optional.Some("echo"), "my_secret", "http://example.com") assert.NoError(t, err) value, err := dal.GetModuleSecretURL(ctx, optional.Some("echo"), "my_secret") @@ -146,11 +135,9 @@ func TestDALSecrets(t *testing.T) { t.Run("GlobalSecret", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example.com") + err := dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example.com") assert.NoError(t, err) value, err := dal.GetModuleSecretURL(ctx, optional.None[string](), "my_secret") @@ -169,11 +156,9 @@ func TestDALSecrets(t *testing.T) { t.Run("SetSameGlobalSecretTwice", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example.com") + err := dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example.com") assert.NoError(t, err) err = dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example2.com") @@ -187,11 +172,9 @@ func TestDALSecrets(t *testing.T) { t.Run("SetModuleOverridesGlobal", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example.com") + err := dal.SetModuleSecretURL(ctx, optional.None[string](), "my_secret", "http://example.com") assert.NoError(t, err) err = dal.SetModuleSecretURL(ctx, optional.Some("echo"), "my_secret", "http://example2.com") assert.NoError(t, err) @@ -204,11 +187,9 @@ func TestDALSecrets(t *testing.T) { t.Run("HandlesConflicts", func(t *testing.T) { ctx := log.ContextWithNewDefaultLogger(context.Background()) conn := sqltest.OpenForTesting(ctx, t) - dal, err := New(ctx, conn) - assert.NoError(t, err) - assert.NotZero(t, dal) + dal := New(conn) - err = dal.SetModuleSecretURL(ctx, optional.Some("echo"), "my_secret", "http://example.com") + err := dal.SetModuleSecretURL(ctx, optional.Some("echo"), "my_secret", "http://example.com") assert.NoError(t, err) err = dal.SetModuleSecretURL(ctx, optional.Some("echo"), "my_secret", "http://example2.com") assert.NoError(t, err) diff --git a/internal/configuration/sql/conn.go b/internal/configuration/sql/conn.go deleted file mode 100644 index 065487cefa..0000000000 --- a/internal/configuration/sql/conn.go +++ /dev/null @@ -1,21 +0,0 @@ -package sql - -type DBI interface { - Querier - Conn() ConnI -} - -type ConnI interface { - DBTX -} - -type DB struct { - conn ConnI - *Queries -} - -func NewDB(conn ConnI) *DB { - return &DB{conn: conn, Queries: New(conn)} -} - -func (d *DB) Conn() ConnI { return d.conn }