Skip to content

Commit

Permalink
refactor: move leases DAL+queries into its own package (#2566)
Browse files Browse the repository at this point in the history
This is a first step in splitting out functionality from the core
controller into discrete sub-packages.

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
alecthomas and github-actions[bot] authored Aug 31, 2024
1 parent 6eabae2 commit 37aed73
Show file tree
Hide file tree
Showing 65 changed files with 1,067 additions and 558 deletions.
30 changes: 17 additions & 13 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ import (
"github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/ingress"
"github.com/TBD54566975/ftl/backend/controller/leases"
leasesdal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
"github.com/TBD54566975/ftl/backend/controller/observability"
"github.com/TBD54566975/ftl/backend/controller/pubsub"
"github.com/TBD54566975/ftl/backend/controller/scaling"
"github.com/TBD54566975/ftl/backend/controller/scaling/localscaling"
"github.com/TBD54566975/ftl/backend/controller/scheduledtask"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1"
"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/console/pbconsoleconnect"
"github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect"
Expand Down Expand Up @@ -194,6 +195,7 @@ type ControllerListListener interface {

type Service struct {
conn *sql.DB
leasesdal *leasesdal.DAL
dal *dal.DAL
key model.ControllerKey
deploymentLogsSink *deploymentLogsSink
Expand Down Expand Up @@ -231,14 +233,16 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling
config.ControllerTimeout = time.Second * 5
}

ldb := leasesdal.New(conn)
db, err := dal.New(ctx, conn, encryption.NewBuilder().WithKMSURI(optional.Ptr(config.KMSURI)))
if err != nil {
return nil, fmt.Errorf("failed to create DAL: %w", err)
}

svc := &Service{
tasks: scheduledtask.New(ctx, key, db),
tasks: scheduledtask.New(ctx, key, ldb),
dal: db,
leasesdal: ldb,
conn: conn,
key: key,
deploymentLogsSink: newDeploymentLogsSink(ctx, db),
Expand Down Expand Up @@ -307,7 +311,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {

routes, err := s.dal.GetIngressRoutes(r.Context(), r.Method)
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
http.NotFound(w, r)
observability.Ingress.Request(r.Context(), r.Method, r.URL.Path, optional.None[*schemapb.Ref](), start, optional.Some("route not found in dal"))
return
Expand Down Expand Up @@ -509,7 +513,7 @@ func (s *Service) UpdateDeploy(ctx context.Context, req *connect.Request[ftlv1.U

err = s.dal.SetDeploymentReplicas(ctx, deploymentKey, int(req.Msg.MinReplicas))
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
logger.Errorf(err, "Deployment not found: %s", deploymentKey)
return nil, connect.NewError(connect.CodeNotFound, errors.New("deployment not found"))
}
Expand All @@ -531,7 +535,7 @@ func (s *Service) ReplaceDeploy(ctx context.Context, c *connect.Request[ftlv1.Re

err = s.dal.ReplaceDeployment(ctx, newDeploymentKey, int(c.Msg.MinReplicas))
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
logger.Errorf(err, "Deployment not found: %s", newDeploymentKey)
return nil, connect.NewError(connect.CodeNotFound, errors.New("deployment not found"))
} else if errors.Is(err, dal.ErrReplaceDeploymentAlreadyActive) {
Expand Down Expand Up @@ -591,7 +595,7 @@ func (s *Service) RegisterRunner(ctx context.Context, stream *connect.ClientStre
Deployment: maybeDeployment,
Labels: msg.Labels.AsMap(),
})
if errors.Is(err, dalerrs.ErrConflict) {
if errors.Is(err, libdal.ErrConflict) {
return nil, connect.NewError(connect.CodeAlreadyExists, err)
} else if err != nil {
return nil, err
Expand All @@ -608,7 +612,7 @@ func (s *Service) RegisterRunner(ctx context.Context, stream *connect.ClientStre
}

routes, err := s.dal.GetRoutingTable(ctx, nil)
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
routes = map[string][]dal.Route{}
} else if err != nil {
return nil, err
Expand Down Expand Up @@ -815,7 +819,7 @@ func (s *Service) AcquireLease(ctx context.Context, stream *connect.BidiStream[f
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not receive lease request: %w", err))
}
if lease == nil {
lease, _, err = s.dal.AcquireLease(ctx, leases.ModuleKey(msg.Module, msg.Key...), msg.Ttl.AsDuration(), optional.None[any]())
lease, _, err = s.leasesdal.AcquireLease(ctx, leases.ModuleKey(msg.Module, msg.Key...), msg.Ttl.AsDuration(), optional.None[any]())
if err != nil {
if errors.Is(err, leases.ErrConflict) {
return connect.NewError(connect.CodeResourceExhausted, fmt.Errorf("lease is held: %w", err))
Expand Down Expand Up @@ -948,7 +952,7 @@ func (s *Service) SetNextFSMEvent(ctx context.Context, req *connect.Request[ftlv
// Get the current state the instance is transitioning to.
_, currentDestinationState, err := tx.GetFSMStates(ctx, fsmKey, req.Msg.Instance)
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("fsm instance not found: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not get fsm instance: %w", err))
Expand All @@ -963,7 +967,7 @@ func (s *Service) SetNextFSMEvent(ctx context.Context, req *connect.Request[ftlv
// Set the next event.
err = tx.SetNextFSMEvent(ctx, fsmKey, msg.Instance, nextState.ToRefKey(), msg.Body, eventType)
if err != nil {
if errors.Is(err, dalerrs.ErrConflict) {
if errors.Is(err, libdal.ErrConflict) {
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("fsm instance already has its next state set: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not set next fsm event: %w", err))
Expand Down Expand Up @@ -1403,7 +1407,7 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration
logger.Tracef("Acquiring async call")

call, leaseCtx, err := s.dal.AcquireAsyncCall(ctx)
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
logger.Tracef("No async calls to execute")
return time.Second * 2, nil
} else if err != nil {
Expand Down Expand Up @@ -1740,7 +1744,7 @@ func (s *Service) resolveFSMEvent(msg *ftlv1.SendFSMEventRequest) (fsm *schema.F
}

func (s *Service) expireStaleLeases(ctx context.Context) (time.Duration, error) {
err := s.dal.ExpireLeases(ctx)
err := s.leasesdal.ExpireLeases(ctx)
if err != nil {
return 0, fmt.Errorf("failed to expire leases: %w", err)
}
Expand Down Expand Up @@ -1972,7 +1976,7 @@ func (s *Service) getDeploymentLogger(ctx context.Context, deploymentKey model.D
// Periodically sync the routing table from the DB.
func (s *Service) syncRoutes(ctx context.Context) (time.Duration, error) {
routes, err := s.dal.GetRoutingTable(ctx, nil)
if errors.Is(err, dalerrs.ErrNotFound) {
if errors.Is(err, libdal.ErrNotFound) {
routes = map[string][]dal.Route{}
} else if err != nil {
return 0, err
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/TBD54566975/ftl/backend/controller/cronjobs/dal"
parentdal "github.com/TBD54566975/ftl/backend/controller/dal"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/cron"
"github.com/TBD54566975/ftl/internal/encryption"
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestNewCronJobsForModule(t *testing.T) {

// No async calls yet
_, _, err = parentDAL.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)
assert.IsError(t, err, libdal.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")

err = cjs.scheduleCronJobs(ctx)
Expand Down
18 changes: 9 additions & 9 deletions backend/controller/cronjobs/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ import (

"github.com/TBD54566975/ftl/backend/controller/cronjobs/sql"
"github.com/TBD54566975/ftl/backend/controller/observability"
"github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/model"
"github.com/TBD54566975/ftl/internal/slices"
)

type DAL struct {
*dal.Handle[DAL]
*libdal.Handle[DAL]
db sql.Querier
}

func New(conn dal.Connection) *DAL {
func New(conn libdal.Connection) *DAL {
return &DAL{
db: sql.New(conn),
Handle: dal.New(conn, func(h *dal.Handle[DAL]) *DAL {
Handle: libdal.New(conn, func(h *libdal.Handle[DAL]) *DAL {
return &DAL{Handle: h, db: sql.New(h.Connection)}
}),
}
Expand All @@ -45,7 +45,7 @@ func cronJobFromRow(c sql.CronJob, d sql.Deployment) model.CronJob {
func (d *DAL) CreateAsyncCall(ctx context.Context, params sql.CreateAsyncCallParams) (int64, error) {
id, err := d.db.CreateAsyncCall(ctx, params)
if err != nil {
return 0, fmt.Errorf("failed to create async call: %w", dal.TranslatePGError(err))
return 0, fmt.Errorf("failed to create async call: %w", libdal.TranslatePGError(err))
}
observability.AsyncCalls.Created(ctx, params.Verb, optional.None[schema.RefKey](), params.Origin, 0, err)
queueDepth, err := d.db.AsyncCallQueueDepth(ctx)
Expand All @@ -62,7 +62,7 @@ func (d *DAL) CreateAsyncCall(ctx context.Context, params sql.CreateAsyncCallPar
func (d *DAL) GetUnscheduledCronJobs(ctx context.Context, startTime time.Time) ([]model.CronJob, error) {
rows, err := d.db.GetUnscheduledCronJobs(ctx, startTime)
if err != nil {
return nil, fmt.Errorf("failed to get cron jobs: %w", dal.TranslatePGError(err))
return nil, fmt.Errorf("failed to get cron jobs: %w", libdal.TranslatePGError(err))
}
return slices.Map(rows, func(r sql.GetUnscheduledCronJobsRow) model.CronJob {
return cronJobFromRow(r.CronJob, r.Deployment)
Expand All @@ -73,7 +73,7 @@ func (d *DAL) GetUnscheduledCronJobs(ctx context.Context, startTime time.Time) (
func (d *DAL) GetCronJobByKey(ctx context.Context, key model.CronJobKey) (model.CronJob, error) {
row, err := d.db.GetCronJobByKey(ctx, key)
if err != nil {
return model.CronJob{}, fmt.Errorf("failed to get cron job %q: %w", key, dal.TranslatePGError(err))
return model.CronJob{}, fmt.Errorf("failed to get cron job %q: %w", key, libdal.TranslatePGError(err))
}
return cronJobFromRow(row.CronJob, row.Deployment), nil
}
Expand All @@ -82,7 +82,7 @@ func (d *DAL) GetCronJobByKey(ctx context.Context, key model.CronJobKey) (model.
func (d *DAL) IsCronJobPending(ctx context.Context, key model.CronJobKey, startTime time.Time) (bool, error) {
pending, err := d.db.IsCronJobPending(ctx, key, startTime)
if err != nil {
return false, fmt.Errorf("failed to check if cron job %q is pending: %w", key, dal.TranslatePGError(err))
return false, fmt.Errorf("failed to check if cron job %q is pending: %w", key, libdal.TranslatePGError(err))
}
return pending, nil
}
Expand All @@ -92,7 +92,7 @@ func (d *DAL) IsCronJobPending(ctx context.Context, key model.CronJobKey, startT
func (d *DAL) UpdateCronJobExecution(ctx context.Context, params sql.UpdateCronJobExecutionParams) error {
err := d.db.UpdateCronJobExecution(ctx, params)
if err != nil {
return fmt.Errorf("failed to update cron job %q: %w", params.Key, dal.TranslatePGError(err))
return fmt.Errorf("failed to update cron job %q: %w", params.Key, libdal.TranslatePGError(err))
}
return nil
}
2 changes: 1 addition & 1 deletion backend/controller/cronjobs/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions backend/controller/cronjobs/sql/types.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package sql

import csql "github.com/TBD54566975/ftl/backend/controller/sql"
import "github.com/TBD54566975/ftl/backend/controller/sql/sqltypes"

type Type = csql.Type
type Type = sqltypes.Type
27 changes: 14 additions & 13 deletions backend/controller/dal/async_calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (
"github.com/alecthomas/types/either"
"github.com/alecthomas/types/optional"

leasedal "github.com/TBD54566975/ftl/backend/controller/leases/dal"
"github.com/TBD54566975/ftl/backend/controller/sql"
"github.com/TBD54566975/ftl/backend/controller/sql/sqltypes"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/model"
Expand Down Expand Up @@ -95,7 +96,7 @@ func ParseAsyncOrigin(origin string) (AsyncOrigin, error) {
}

type AsyncCall struct {
*Lease // May be nil
*leasedal.Lease // May be nil
ID int64
Origin AsyncOrigin
Verb schema.RefKey
Expand Down Expand Up @@ -127,9 +128,9 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, leaseCtx c
ttl := time.Second * 5
row, err := tx.db.AcquireAsyncCall(ctx, sqltypes.Duration(ttl))
if err != nil {
err = dalerrs.TranslatePGError(err)
if errors.Is(err, dalerrs.ErrNotFound) {
return nil, ctx, fmt.Errorf("no pending async calls: %w", dalerrs.ErrNotFound)
err = libdal.TranslatePGError(err)
if errors.Is(err, libdal.ErrNotFound) {
return nil, ctx, fmt.Errorf("no pending async calls: %w", libdal.ErrNotFound)
}
return nil, ctx, fmt.Errorf("failed to acquire async call: %w", err)
}
Expand All @@ -143,7 +144,7 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, leaseCtx c
return nil, ctx, fmt.Errorf("failed to decrypt async call request: %w", err)
}

lease, leaseCtx := d.newLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
lease, leaseCtx := d.leasedal.NewLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
return &AsyncCall{
ID: row.AsyncCallID,
Verb: row.Verb,
Expand Down Expand Up @@ -177,7 +178,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
case *dbsql.DB:
tx, err = d.Begin(ctx)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
defer tx.CommitOrRollback(ctx, &err)
case *dbsql.Tx:
Expand All @@ -197,7 +198,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
}
_, err = tx.db.SucceedAsyncCall(ctx, optional.Some(encryptedResult), call.ID)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}

case either.Right[[]byte, string]: // Failure message.
Expand All @@ -211,7 +212,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
ScheduledAt: time.Now().Add(call.Backoff),
})
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
isFinalResult = false
didScheduleAnotherCall = true
Expand All @@ -234,14 +235,14 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
OriginalError: optional.Some(originalError),
})
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
isFinalResult = false
didScheduleAnotherCall = true
} else {
_, err = tx.db.FailAsyncCall(ctx, result.Get(), call.ID)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
return false, libdal.TranslatePGError(err) //nolint:wrapcheck
}
}
}
Expand All @@ -254,7 +255,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
row, err := d.db.LoadAsyncCall(ctx, id)
if err != nil {
return nil, dalerrs.TranslatePGError(err)
return nil, libdal.TranslatePGError(err)
}
origin, err := ParseAsyncOrigin(row.Origin)
if err != nil {
Expand All @@ -275,7 +276,7 @@ func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
func (d *DAL) GetZombieAsyncCalls(ctx context.Context, limit int) ([]*AsyncCall, error) {
rows, err := d.db.GetZombieAsyncCalls(ctx, int32(limit))
if err != nil {
return nil, dalerrs.TranslatePGError(err)
return nil, libdal.TranslatePGError(err)
}
var calls []*AsyncCall
for _, row := range rows {
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/alecthomas/assert/v2"

"github.com/TBD54566975/ftl/backend/controller/sql/sqltest"
dalerrs "github.com/TBD54566975/ftl/backend/dal"
"github.com/TBD54566975/ftl/backend/libdal"
"github.com/TBD54566975/ftl/backend/schema"
"github.com/TBD54566975/ftl/internal/encryption"
"github.com/TBD54566975/ftl/internal/log"
Expand All @@ -21,7 +21,7 @@ func TestNoCallToAcquire(t *testing.T) {
assert.NoError(t, err)

_, _, err = dal.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)
assert.IsError(t, err, libdal.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")
}

Expand Down
Loading

0 comments on commit 37aed73

Please sign in to comment.