Skip to content

Commit

Permalink
feat: allow FSMs to specify the next event to send (#1723)
Browse files Browse the repository at this point in the history
closes #1664
closes #2334

This is particularly useful when an FSM state itself knows which state
it should be in next, for example transitioning a payment to "voided" if
an error occurs communicating with an external vendor.

- Adds `fsm.Next(ctx, event)` to the go runtime
    - Only one next transition is allowed per fsm instance
- When a transition completes, the next transition is queued up if it
exists
- When a transition returns an error, we wipe the next transition so
that the retry attempt can set the next transition again

In another PR we will move away from using `ftl.Next()` to avoid the
reference cycle issue.

---------

Co-authored-by: Matt Toohey <[email protected]>
  • Loading branch information
alecthomas and matt2e authored Aug 15, 2024
1 parent 2b27ccc commit ea46c80
Show file tree
Hide file tree
Showing 40 changed files with 1,371 additions and 315 deletions.
161 changes: 129 additions & 32 deletions backend/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,17 +831,13 @@ func (s *Service) Call(ctx context.Context, req *connect.Request[ftlv1.CallReque

func (s *Service) SendFSMEvent(ctx context.Context, req *connect.Request[ftlv1.SendFSMEventRequest]) (resp *connect.Response[ftlv1.SendFSMEventResponse], err error) {
msg := req.Msg
sch := s.schema.Load()

// Resolve the FSM.
fsm := &schema.FSM{}
if err := sch.ResolveToType(schema.RefFromProto(msg.Fsm), fsm); err != nil {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("fsm not found: %w", err))
fsm, eventType, fsmKey, err := s.resolveFSMEvent(msg)
if err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
}

eventType := schema.TypeFromProto(msg.Event)

fsmKey := schema.RefFromProto(msg.Fsm).ToRefKey()

tx, err := s.dal.Begin(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not start transaction: %w", err))
Expand All @@ -854,12 +850,24 @@ func (s *Service) SendFSMEvent(ctx context.Context, req *connect.Request[ftlv1.S
}
defer instance.Release() //nolint:errcheck

err = s.sendFSMEventInTx(ctx, tx, instance, fsm, eventType, msg.Body, false)
if err != nil {
return nil, err
}
return connect.NewResponse(&ftlv1.SendFSMEventResponse{}), nil
}

// schedules an event for a FSM instance within a db transaction
// body may already be encrypted, which is denoted by the encrypted flag
func (s *Service) sendFSMEventInTx(ctx context.Context, tx *dal.Tx, instance *dal.FSMInstance, fsm *schema.FSM, eventType schema.Type, body []byte, encrypted bool) error {
// Populated if we find a matching transition.
var destinationRef *schema.Ref
var destinationVerb *schema.Verb

var candidates []string

sch := s.schema.Load()

updateCandidates := func(ref *schema.Ref) (brk bool, err error) {
verb := &schema.Verb{}
if err := sch.ResolveToType(ref, verb); err != nil {
Expand All @@ -879,7 +887,7 @@ func (s *Service) SendFSMEvent(ctx context.Context, req *connect.Request[ftlv1.S
if !instance.CurrentState.Ok() {
for _, start := range fsm.Start {
if brk, err := updateCandidates(start); err != nil {
return nil, err
return err
} else if brk {
break
}
Expand All @@ -892,7 +900,7 @@ func (s *Service) SendFSMEvent(ctx context.Context, req *connect.Request[ftlv1.S
continue
}
if brk, err := updateCandidates(transition.To); err != nil {
return nil, err
return err
} else if brk {
break
}
Expand All @@ -901,20 +909,59 @@ func (s *Service) SendFSMEvent(ctx context.Context, req *connect.Request[ftlv1.S

if destinationRef == nil {
if len(candidates) > 0 {
return nil, connect.NewError(connect.CodeFailedPrecondition,
return connect.NewError(connect.CodeFailedPrecondition,
fmt.Errorf("no transition found from state %s for type %s, candidates are %s", instance.CurrentState, eventType, strings.Join(candidates, ", ")))
}
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("no transition found from state %s for type %s", instance.CurrentState, eventType))
return connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("no transition found from state %s for type %s", instance.CurrentState, eventType))
}

retryParams, err := schema.RetryParamsForFSMTransition(fsm, destinationVerb)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
return connect.NewError(connect.CodeInternal, err)
}

err = tx.StartFSMTransition(ctx, instance.FSM, instance.Key, destinationRef.ToRefKey(), msg.Body, retryParams)
err = tx.StartFSMTransition(ctx, instance.FSM, instance.Key, destinationRef.ToRefKey(), body, encrypted, retryParams)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not start fsm transition: %w", err))
return connect.NewError(connect.CodeInternal, fmt.Errorf("could not start fsm transition: %w", err))
}
return nil
}

func (s *Service) SetNextFSMEvent(ctx context.Context, req *connect.Request[ftlv1.SendFSMEventRequest]) (resp *connect.Response[ftlv1.SendFSMEventResponse], err error) {
tx, err := s.dal.Begin(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not start transaction: %w", err))
}
defer tx.CommitOrRollback(ctx, &err)
sch := s.schema.Load()
msg := req.Msg
fsm, eventType, fsmKey, err := s.resolveFSMEvent(msg)
if err != nil {
return nil, connect.NewError(connect.CodeNotFound, err)
}

// Get the current state the instance is transitioning to.
_, currentDestinationState, err := tx.GetFSMStates(ctx, fsmKey, req.Msg.Instance)
if err != nil {
if errors.Is(err, dalerrs.ErrNotFound) {
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("fsm instance not found: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not get fsm instance: %w", err))
}

// Check if the transition is valid from the current state.
nextState, ok := fsm.NextState(sch, currentDestinationState, eventType).Get()
if !ok {
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("invalid event %q for state %q", eventType, currentDestinationState))
}

// Set the next event.
err = tx.SetNextFSMEvent(ctx, fsmKey, msg.Instance, nextState.ToRefKey(), msg.Body, eventType)
if err != nil {
if errors.Is(err, dalerrs.ErrConflict) {
return nil, connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("fsm instance already has its next state set: %w", err))
}
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("could not set next fsm event: %w", err))
}
return connect.NewResponse(&ftlv1.SendFSMEventResponse{}), nil
}
Expand Down Expand Up @@ -1406,8 +1453,9 @@ func (s *Service) executeAsyncCalls(ctx context.Context) (interval time.Duration

logger.Tracef("Executing async call")
req := &ftlv1.CallRequest{
Verb: call.Verb.ToProto(),
Body: call.Request,
Verb: call.Verb.ToProto(),
Body: call.Request,
Metadata: metadataForAsyncCall(call),
}
resp, err := s.callWithRequest(ctx, connect.NewRequest(req), optional.None[model.RequestKey](), parentRequestKey, s.config.Advertise.String())
var callResult either.Either[[]byte, string]
Expand Down Expand Up @@ -1464,8 +1512,9 @@ func (s *Service) catchAsyncCall(ctx context.Context, logger *log.Logger, call *
}

req := &ftlv1.CallRequest{
Verb: catchVerb.ToProto(),
Body: body,
Verb: catchVerb.ToProto(),
Body: body,
Metadata: metadataForAsyncCall(call),
}
resp, err := s.callWithRequest(ctx, connect.NewRequest(req), optional.None[model.RequestKey](), optional.None[model.RequestKey](), s.config.Advertise.String())
var catchResult either.Either[[]byte, string]
Expand Down Expand Up @@ -1503,23 +1552,42 @@ func (s *Service) catchAsyncCall(ctx context.Context, logger *log.Logger, call *
return nil
}

func (s *Service) finaliseAsyncCall(ctx context.Context, tx *dal.Tx, call *dal.AsyncCall, callResult either.Either[[]byte, string], isFinalResult bool) error {
if !isFinalResult {
// Will retry, do not propagate yet.
return nil
func metadataForAsyncCall(call *dal.AsyncCall) *ftlv1.Metadata {
switch origin := call.Origin.(type) {
case dal.AsyncOriginFSM:
return &ftlv1.Metadata{
Values: []*ftlv1.Metadata_Pair{
{
Key: "fsmName",
Value: origin.FSM.Name,
},
{
Key: "fsmInstance",
Value: origin.Key,
},
},
}

case dal.AsyncOriginPubSub:
return &ftlv1.Metadata{}

default:
panic(fmt.Errorf("unsupported async call origin: %v", call.Origin))
}
}

func (s *Service) finaliseAsyncCall(ctx context.Context, tx *dal.Tx, call *dal.AsyncCall, callResult either.Either[[]byte, string], isFinalResult bool) error {
_, failed := callResult.(either.Right[[]byte, string])

// Allow for handling of completion based on origin
switch origin := call.Origin.(type) {
case dal.AsyncOriginFSM:
if err := s.onAsyncFSMCallCompletion(ctx, tx, origin, failed); err != nil {
if err := s.onAsyncFSMCallCompletion(ctx, tx, origin, failed, isFinalResult); err != nil {
return fmt.Errorf("failed to finalize FSM async call: %w", err)
}

case dal.AsyncOriginPubSub:
if err := s.pubSub.OnCallCompletion(ctx, tx, origin, failed); err != nil {
if err := s.pubSub.OnCallCompletion(ctx, tx, origin, failed, isFinalResult); err != nil {
return fmt.Errorf("failed to finalize pubsub async call: %w", err)
}

Expand All @@ -1529,20 +1597,30 @@ func (s *Service) finaliseAsyncCall(ctx context.Context, tx *dal.Tx, call *dal.A
return nil
}

func (s *Service) onAsyncFSMCallCompletion(ctx context.Context, tx *dal.Tx, origin dal.AsyncOriginFSM, failed bool) error {
func (s *Service) onAsyncFSMCallCompletion(ctx context.Context, tx *dal.Tx, origin dal.AsyncOriginFSM, failed bool, isFinalResult bool) error {
logger := log.FromContext(ctx).Scope(origin.FSM.String())

// retrieve the next fsm event and delete it
next, err := tx.PopNextFSMEvent(ctx, origin.FSM, origin.Key)
if err != nil {
return fmt.Errorf("%s: failed to get next FSM event: %w", origin, err)
}
if !isFinalResult {
// Will retry, so we only want next fsm to be removed
return nil
}

instance, err := tx.AcquireFSMInstance(ctx, origin.FSM, origin.Key)
if err != nil {
return fmt.Errorf("could not acquire lock on FSM instance: %w", err)
return fmt.Errorf("%s: could not acquire lock on FSM instance: %w", origin, err)
}
defer instance.Release() //nolint:errcheck

if failed {
logger.Warnf("FSM %s failed async call", origin.FSM)
err := tx.FailFSMInstance(ctx, origin.FSM, origin.Key)
if err != nil {
return fmt.Errorf("failed to fail FSM instance: %w", err)
return fmt.Errorf("%s: failed to fail FSM instance: %w", origin, err)
}
return nil
}
Expand All @@ -1552,7 +1630,7 @@ func (s *Service) onAsyncFSMCallCompletion(ctx context.Context, tx *dal.Tx, orig
fsm := &schema.FSM{}
err = sch.ResolveToType(origin.FSM.ToRef(), fsm)
if err != nil {
return fmt.Errorf("could not resolve FSM: %w", err)
return fmt.Errorf("%s: could not resolve FSM: %w", origin, err)
}

destinationState, _ := instance.DestinationState.Get()
Expand All @@ -1562,20 +1640,39 @@ func (s *Service) onAsyncFSMCallCompletion(ctx context.Context, tx *dal.Tx, orig
logger.Debugf("FSM reached terminal state %s", destinationState)
err := tx.SucceedFSMInstance(ctx, origin.FSM, origin.Key)
if err != nil {
return fmt.Errorf("failed to succeed FSM instance: %w", err)
return fmt.Errorf("%s: failed to succeed FSM instance: %w", origin, err)
}
return nil
}

}

err = tx.FinishFSMTransition(ctx, origin.FSM, origin.Key)
instance, err = tx.FinishFSMTransition(ctx, instance)
if err != nil {
return fmt.Errorf("failed to complete FSM transition: %w", err)
return fmt.Errorf("%s: failed to complete FSM transition: %w", origin, err)
}

// If there's a next event enqueued, we immediately start it.
if next, ok := next.Get(); ok {
return s.sendFSMEventInTx(ctx, tx, instance, fsm, next.RequestType, next.Request, true)
}
return nil
}

func (s *Service) resolveFSMEvent(msg *ftlv1.SendFSMEventRequest) (fsm *schema.FSM, eventType schema.Type, fsmKey schema.RefKey, err error) {
sch := s.schema.Load()

fsm = &schema.FSM{}
if err := sch.ResolveToType(schema.RefFromProto(msg.Fsm), fsm); err != nil {
return nil, nil, schema.RefKey{}, fmt.Errorf("fsm not found: %w", err)
}

eventType = schema.TypeFromProto(msg.Event)

fsmKey = schema.RefFromProto(msg.Fsm).ToRefKey()
return fsm, eventType, fsmKey, nil
}

func (s *Service) expireStaleLeases(ctx context.Context) (time.Duration, error) {
err := s.dal.ExpireLeases(ctx)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions backend/controller/cronjobs/sql/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions backend/controller/cronjobs/sql/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sql

import csql "github.com/TBD54566975/ftl/backend/controller/sql"

type Type = csql.Type
Loading

0 comments on commit ea46c80

Please sign in to comment.