Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: clean up async calls with expired leases #2435

Merged
merged 11 commits into from
Aug 27, 2024
38 changes: 36 additions & 2 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ func New(ctx context.Context, conn *sql.DB, config Config, runnerScaling scaling
svc.tasks.Singleton(maybeDevelTask(svc.releaseExpiredReservations, time.Second*2, time.Second, time.Second*20))
svc.tasks.Singleton(maybeDevelTask(svc.reconcileDeployments, time.Second*2, time.Second, time.Second*5))
svc.tasks.Singleton(maybeDevelTask(svc.reconcileRunners, time.Second*2, time.Second, time.Second*5))
svc.tasks.Singleton(maybeDevelTask(svc.reapAsyncCalls, time.Second*5, time.Second, time.Second*5))
return svc, nil
}

Expand Down Expand Up @@ -1401,7 +1402,7 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration
logger := log.FromContext(ctx)
logger.Tracef("Acquiring async call")

call, err := s.dal.AcquireAsyncCall(ctx)
call, leaseCtx, err := s.dal.AcquireAsyncCall(ctx)
if errors.Is(err, dalerrs.ErrNotFound) {
logger.Tracef("No async calls to execute")
return time.Second * 2, nil
Expand All @@ -1413,6 +1414,9 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration
}
return 0, err
}
// use originalCtx for things that should be done outside of the lease lifespan
originalCtx := ctx
ctx = leaseCtx

// Extract the otel context from the call
ctx, err = observability.ExtractTraceContextToContext(ctx, call.TraceContext)
Expand Down Expand Up @@ -1444,7 +1448,7 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration
break

case dal.AsyncOriginPubSub:
go s.pubSub.AsyncCallDidCommit(ctx, origin)
go s.pubSub.AsyncCallDidCommit(originalCtx, origin)

default:
break
Expand Down Expand Up @@ -1576,6 +1580,36 @@ func (s *Service) catchAsyncCall(ctx context.Context, logger *log.Logger, call *
return nil
}

// reapAsyncCalls fails async calls that have had their leases reaped.
//
// It runs inside a single transaction: up to zombieBatchSize zombie calls are
// fetched and each one is completed with the failure result
// "async call lease expired", going through the same finalise step
// (finaliseAsyncCall) that CompleteAsyncCall is given here. The deferred
// CommitOrRollback receives the named err result, so any error returned from
// this function is visible to it.
//
// The returned interval is 0 when a full batch was processed (more zombies may
// remain — presumably this asks the scheduler to run again right away; confirm
// against the task runner) and 5 seconds otherwise.
func (s *Service) reapAsyncCalls(ctx context.Context) (nextInterval time.Duration, err error) {
	// Single source of truth for the batch size: it is used both for the query
	// and for the "was the batch full?" check below, so the two cannot drift.
	const zombieBatchSize = 20

	tx, err := s.dal.Begin(ctx)
	if err != nil {
		return 0, connect.NewError(connect.CodeInternal, fmt.Errorf("could not start transaction: %w", err))
	}
	defer tx.CommitOrRollback(ctx, &err)

	calls, err := tx.GetZombieAsyncCalls(ctx, zombieBatchSize)
	if err != nil {
		return 0, fmt.Errorf("failed to get zombie async calls: %w", err)
	}
	for _, call := range calls {
		// A Right value is the failure branch of the result.
		callResult := either.RightOf[[]byte]("async call lease expired")
		_, err := tx.CompleteAsyncCall(ctx, call, callResult, func(tx *dal.DAL, isFinalResult bool) error {
			return s.finaliseAsyncCall(ctx, tx, call, callResult, isFinalResult)
		})
		if err != nil {
			return 0, fmt.Errorf("failed to complete zombie async call: %w", err)
		}
		observability.AsyncCalls.Executed(ctx, call.Verb, call.CatchVerb, call.Origin.String(), call.ScheduledAt, true, optional.Some("async call lease failed"))
	}

	// A full batch means there may be more zombies waiting.
	if len(calls) == zombieBatchSize {
		return 0, nil
	}
	return time.Second * 5, nil
}

func metadataForAsyncCall(call *dal.AsyncCall) *ftlv1.Metadata {
switch origin := call.Origin.(type) {
case dal.AsyncOriginCron:
Expand Down
6 changes: 3 additions & 3 deletions backend/controller/cronjobs/cronjobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestNewCronJobsForModule(t *testing.T) {
assert.Equal(t, len(unscheduledJobs), 2)

// No async calls yet
_, err = parentDAL.AcquireAsyncCall(ctx)
_, _, err = parentDAL.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")

Expand All @@ -76,7 +76,7 @@ func TestNewCronJobsForModule(t *testing.T) {
// Now there should be async calls
calls := []*parentdal.AsyncCall{}
for i, job := range jobsToCreate {
call, err := parentDAL.AcquireAsyncCall(ctx)
call, _, err := parentDAL.AcquireAsyncCall(ctx)
assert.NoError(t, err)
assert.Equal(t, call.Verb, job.Verb.ToRefKey())
assert.Equal(t, call.Origin.String(), fmt.Sprintf("cron:%s", job.Key))
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestNewCronJobsForModule(t *testing.T) {
}
expectUnscheduledJobs(t, dal, clk, 0)
for i, job := range jobsToCreate {
call, err := parentDAL.AcquireAsyncCall(ctx)
call, _, err := parentDAL.AcquireAsyncCall(ctx)
assert.NoError(t, err)
assert.Equal(t, call.Verb, job.Verb.ToRefKey())
assert.Equal(t, call.Origin.String(), fmt.Sprintf("cron:%s", job.Key))
Expand Down
74 changes: 59 additions & 15 deletions backend/controller/dal/async_calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dal

import (
"context"
dbsql "database/sql"
"errors"
"fmt"
"time"
Expand Down Expand Up @@ -116,10 +117,10 @@ type AsyncCall struct {
// AcquireAsyncCall acquires a pending async call to execute.
//
// Returns ErrNotFound if there are no async calls to acquire.
func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error) {
func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, leaseCtx context.Context, err error) {
tx, err := d.Begin(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
return nil, ctx, fmt.Errorf("failed to begin transaction: %w", err)
}
defer tx.CommitOrRollback(ctx, &err)

Expand All @@ -128,21 +129,21 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error)
if err != nil {
err = dalerrs.TranslatePGError(err)
if errors.Is(err, dalerrs.ErrNotFound) {
return nil, fmt.Errorf("no pending async calls: %w", dalerrs.ErrNotFound)
return nil, ctx, fmt.Errorf("no pending async calls: %w", dalerrs.ErrNotFound)
}
return nil, fmt.Errorf("failed to acquire async call: %w", err)
return nil, ctx, fmt.Errorf("failed to acquire async call: %w", err)
}
origin, err := ParseAsyncOrigin(row.Origin)
if err != nil {
return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
return nil, ctx, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
}

decryptedRequest, err := d.decrypt(&row.Request)
if err != nil {
return nil, fmt.Errorf("failed to decrypt async call request: %w", err)
return nil, ctx, fmt.Errorf("failed to decrypt async call request: %w", err)
}

lease, _ := d.newLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
lease, leaseCtx := d.newLease(ctx, row.LeaseKey, row.LeaseIdempotencyKey, ttl)
return &AsyncCall{
ID: row.AsyncCallID,
Verb: row.Verb,
Expand All @@ -159,29 +160,38 @@ func (d *DAL) AcquireAsyncCall(ctx context.Context) (call *AsyncCall, err error)
Backoff: time.Duration(row.Backoff),
MaxBackoff: time.Duration(row.MaxBackoff),
Catching: row.Catching,
}, nil
}, leaseCtx, nil
}

// CompleteAsyncCall completes an async call.
// The call will use the existing transaction if d is a transaction. Otherwise it will create and commit a new transaction.
//
// "result" is either a []byte representing the successful response, or a string
// representing a failure message.
func (d *DAL) CompleteAsyncCall(ctx context.Context,
call *AsyncCall,
result either.Either[[]byte, string],
finalise func(tx *DAL, isFinalResult bool) error) (didScheduleAnotherCall bool, err error) {
tx, err := d.Begin(ctx)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
var tx *DAL
switch d.Connection.(type) {
case *dbsql.DB:
tx, err = d.Begin(ctx)
if err != nil {
return false, dalerrs.TranslatePGError(err) //nolint:wrapcheck
}
defer tx.CommitOrRollback(ctx, &err)
case *dbsql.Tx:
tx = d
default:
return false, errors.New("invalid connection type")
}
defer tx.CommitOrRollback(ctx, &err)

isFinalResult := true
didScheduleAnotherCall = false
switch result := result.(type) {
case either.Left[[]byte, string]: // Successful response.
var encryptedResult encryption.EncryptedAsyncColumn
err := d.encrypt(result.Get(), &encryptedResult)
err := tx.encrypt(result.Get(), &encryptedResult)
if err != nil {
return false, fmt.Errorf("failed to encrypt async call result: %w", err)
}
Expand All @@ -192,7 +202,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,

case either.Right[[]byte, string]: // Failure message.
if call.RemainingAttempts > 0 {
_, err = d.db.FailAsyncCallWithRetry(ctx, sql.FailAsyncCallWithRetryParams{
_, err = tx.db.FailAsyncCallWithRetry(ctx, sql.FailAsyncCallWithRetryParams{
ID: call.ID,
Error: result.Get(),
RemainingAttempts: call.RemainingAttempts - 1,
Expand All @@ -213,7 +223,7 @@ func (d *DAL) CompleteAsyncCall(ctx context.Context,
if call.Catching {
scheduledAt = scheduledAt.Add(call.Backoff)
}
_, err = d.db.FailAsyncCallWithRetry(ctx, sql.FailAsyncCallWithRetryParams{
_, err = tx.db.FailAsyncCallWithRetry(ctx, sql.FailAsyncCallWithRetryParams{
ID: call.ID,
Error: result.Get(),
RemainingAttempts: 0,
Expand Down Expand Up @@ -261,3 +271,37 @@ func (d *DAL) LoadAsyncCall(ctx context.Context, id int64) (*AsyncCall, error) {
Request: request,
}, nil
}

// GetZombieAsyncCalls loads at most limit rows from the GetZombieAsyncCalls
// query and converts each one into an AsyncCall, parsing its origin key and
// decrypting its request payload.
//
// The first row that fails to parse or decrypt aborts the whole load and
// returns that error. When the query yields no rows the returned slice is nil.
func (d *DAL) GetZombieAsyncCalls(ctx context.Context, limit int) ([]*AsyncCall, error) {
	zombieRows, queryErr := d.db.GetZombieAsyncCalls(ctx, int32(limit))
	if queryErr != nil {
		return nil, dalerrs.TranslatePGError(queryErr)
	}

	var zombies []*AsyncCall
	for _, row := range zombieRows {
		parsedOrigin, err := ParseAsyncOrigin(row.Origin)
		if err != nil {
			return nil, fmt.Errorf("failed to parse origin key %q: %w", row.Origin, err)
		}

		plainRequest, err := d.decrypt(&row.Request)
		if err != nil {
			return nil, fmt.Errorf("failed to decrypt async call request: %w", err)
		}

		zombie := &AsyncCall{
			ID:                row.ID,
			Origin:            parsedOrigin,
			ScheduledAt:       row.ScheduledAt,
			Verb:              row.Verb,
			CatchVerb:         row.CatchVerb,
			Request:           plainRequest,
			ParentRequestKey:  row.ParentRequestKey,
			TraceContext:      row.TraceContext.RawMessage,
			Error:             row.Error,
			RemainingAttempts: row.RemainingAttempts,
			Backoff:           time.Duration(row.Backoff),
			MaxBackoff:        time.Duration(row.MaxBackoff),
			Catching:          row.Catching,
		}
		zombies = append(zombies, zombie)
	}
	return zombies, nil
}
2 changes: 1 addition & 1 deletion backend/controller/dal/async_calls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestNoCallToAcquire(t *testing.T) {
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
_, _, err = dal.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)
assert.EqualError(t, err, "no pending async calls: not found")
}
Expand Down
4 changes: 2 additions & 2 deletions backend/controller/dal/fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestSendFSMEvent(t *testing.T) {
dal, err := New(ctx, conn, encryption.NewBuilder())
assert.NoError(t, err)

_, err = dal.AcquireAsyncCall(ctx)
_, _, err = dal.AcquireAsyncCall(ctx)
assert.IsError(t, err, dalerrs.ErrNotFound)

ref := schema.RefKey{Module: "module", Name: "verb"}
Expand All @@ -32,7 +32,7 @@ func TestSendFSMEvent(t *testing.T) {
assert.IsError(t, err, dalerrs.ErrConflict)
assert.EqualError(t, err, "transition already executing: conflict")

call, err := dal.AcquireAsyncCall(ctx)
call, _, err := dal.AcquireAsyncCall(ctx)
assert.NoError(t, err)
t.Cleanup(func() {
err := call.Lease.Release()
Expand Down
41 changes: 41 additions & 0 deletions backend/controller/pubsub/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package pubsub

import (
"fmt"
"path/filepath"
"testing"
"time"

Expand Down Expand Up @@ -169,3 +170,43 @@ func TestExternalPublishRuntimeCheck(t *testing.T) {
),
)
}

// TestLeaseFailure deploys the "slow" module, publishes two events (durations
// 20 and 1), deletes the lease row backing the slow.consume async call while
// the first event is being consumed, and then checks that the first call is
// recorded as an error with "async call lease expired" while the second one
// succeeds. For Go modules it additionally expects the log file named by
// FSM_LOG_FILE not to contain "slept for 5s".
func TestLeaseFailure(t *testing.T) {
	logFilePath := filepath.Join(t.TempDir(), "pubsub.log")
	t.Setenv("FSM_LOG_FILE", logFilePath)

	const (
		// Removes the lease held by the in-flight slow.consume call and
		// reports how many lease rows were deleted.
		deleteLeaseQuery = `
WITH deleted_rows AS (
DELETE FROM leases WHERE id = (
SELECT lease_id FROM async_calls WHERE verb = 'slow.consume'
)
RETURNING *
)
SELECT COUNT(*) FROM deleted_rows;
`
		firstCallQuery  = `SELECT state, error FROM async_calls WHERE verb = 'slow.consume' ORDER BY created_at`
		secondCallQuery = `SELECT state, error FROM async_calls WHERE verb = 'slow.consume' ORDER BY created_at OFFSET 1`
	)

	// publish 2 events, with the first taking a long time to consume
	publishEvents := in.Call("slow", "publish", in.Obj{
		"durations": []int{20, 1},
	}, func(t testing.TB, resp in.Obj) {})

	// while it is consuming the first event, force delete the lease in the db
	dropLease := in.QueryRow("ftl", deleteLeaseQuery, 1)

	waitForReaper := in.Sleep(time.Second * 7)

	// confirm that the first event failed and the second event succeeded,
	firstFailed := in.QueryRow("ftl", firstCallQuery, "error", "async call lease expired")
	secondSucceeded := in.QueryRow("ftl", secondCallQuery, "success", nil)

	// confirm that the first call did not keep executing for too long after the lease was expired
	stoppedEarly := in.IfLanguage("go",
		in.ExpectError(
			in.FileContains(logFilePath, "slept for 5s"),
			"Haystack does not contain needle",
		),
	)

	in.Run(t,
		in.CopyModule("slow"),
		in.Deploy("slow"),
		publishEvents,
		dropLease,
		waitForReaper,
		firstFailed,
		secondSucceeded,
		stoppedEarly,
	)
}
2 changes: 2 additions & 0 deletions backend/controller/pubsub/testdata/go/slow/ftl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
module = "slow"
language = "go"
48 changes: 48 additions & 0 deletions backend/controller/pubsub/testdata/go/slow/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module ftl/slow

go 1.23.0

require github.com/TBD54566975/ftl v1.1.5

require (
connectrpc.com/connect v1.16.2 // indirect
connectrpc.com/grpcreflect v1.2.0 // indirect
connectrpc.com/otelconnect v0.7.1 // indirect
github.com/alecthomas/atomic v0.1.0-alpha2 // indirect
github.com/alecthomas/concurrency v0.0.2 // indirect
github.com/alecthomas/participle/v2 v2.1.1 // indirect
github.com/alecthomas/types v0.16.0 // indirect
github.com/alessio/shellescape v1.4.2 // indirect
github.com/benbjohnson/clock v1.3.5 // indirect
github.com/danieljoos/wincred v1.2.0 // indirect
github.com/deckarep/golang-set/v2 v2.6.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/hashicorp/cronexpr v1.1.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jpillora/backoff v1.0.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/multiformats/go-base36 v0.2.0 // indirect
github.com/puzpuzpuz/xsync/v3 v3.4.0 // indirect
github.com/swaggest/jsonschema-go v0.3.72 // indirect
github.com/swaggest/refl v1.3.0 // indirect
github.com/zalando/go-keyring v0.2.5 // indirect
go.opentelemetry.io/otel v1.29.0 // indirect
go.opentelemetry.io/otel/metric v1.29.0 // indirect
go.opentelemetry.io/otel/trace v1.29.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa // indirect
golang.org/x/mod v0.20.0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
)

replace github.com/TBD54566975/ftl => ../../../../../..
Loading
Loading