diff --git a/go-runtime/ftl/call.go b/go-runtime/ftl/call.go index 4e2ecad15d..ed8e9c2a25 100644 --- a/go-runtime/ftl/call.go +++ b/go-runtime/ftl/call.go @@ -16,13 +16,12 @@ import ( ) func call[Req, Resp any](ctx context.Context, callee Ref, req Req, inline Verb[Req, Resp]) (resp Resp, err error) { - behavior, err := modulecontext.FromContext(ctx).BehaviorForVerb(schema.Ref{Module: callee.Module, Name: callee.Name}) + override, err := modulecontext.FromContext(ctx).BehaviorForVerb(schema.Ref{Module: callee.Module, Name: callee.Name}) if err != nil { return resp, fmt.Errorf("%s: %w", callee, err) } - switch behavior := behavior.(type) { - case modulecontext.MockBehavior: - uncheckedResp, err := behavior.Mock(ctx, req) + if behavior, ok := override.Get(); ok { + uncheckedResp, err := behavior.Call(ctx, modulecontext.Verb(widenVerb(inline)), req) if err != nil { return resp, fmt.Errorf("%s: %w", callee, err) } @@ -30,39 +29,31 @@ func call[Req, Resp any](ctx context.Context, callee Ref, req Req, inline Verb[R return r, nil } return resp, fmt.Errorf("%s: overridden verb had invalid response type %T, expected %v", callee, uncheckedResp, reflect.TypeFor[Resp]()) - case modulecontext.DirectBehavior: - resp, err = inline(ctx, req) - if err != nil { - return resp, fmt.Errorf("%s: %w", callee, err) - } - return resp, nil - case modulecontext.StandardBehavior: - reqData, err := encoding.Marshal(req) - if err != nil { - return resp, fmt.Errorf("%s: failed to marshal request: %w", callee, err) - } + } - client := rpc.ClientFromContext[ftlv1connect.VerbServiceClient](ctx) - cresp, err := client.Call(ctx, connect.NewRequest(&ftlv1.CallRequest{Verb: callee.ToProto(), Body: reqData})) - if err != nil { - return resp, fmt.Errorf("%s: failed to call Verb: %w", callee, err) - } - switch cresp := cresp.Msg.Response.(type) { - case *ftlv1.CallResponse_Error_: - return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) + reqData, err := encoding.Marshal(req) + if err != nil { + return resp, fmt.Errorf("%s: failed to marshal request: %w", callee, err) + } - case *ftlv1.CallResponse_Body: - err = encoding.Unmarshal(cresp.Body, &resp) - if err != nil { - return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) - } - return resp, nil + client := rpc.ClientFromContext[ftlv1connect.VerbServiceClient](ctx) + cresp, err := client.Call(ctx, connect.NewRequest(&ftlv1.CallRequest{Verb: callee.ToProto(), Body: reqData})) + if err != nil { + return resp, fmt.Errorf("%s: failed to call Verb: %w", callee, err) + } + switch cresp := cresp.Msg.Response.(type) { + case *ftlv1.CallResponse_Error_: + return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) - default: - panic(fmt.Sprintf("%s: invalid response type %T", callee, cresp)) + case *ftlv1.CallResponse_Body: + err = encoding.Unmarshal(cresp.Body, &resp) + if err != nil { + return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) } + return resp, nil + default: - panic(fmt.Sprintf("unknown behavior: %s", behavior)) + panic(fmt.Sprintf("%s: invalid response type %T", callee, cresp)) } } @@ -93,3 +84,13 @@ func CallEmpty(ctx context.Context, empty Empty) error { }) return err } + +func widenVerb[Req, Resp any](verb Verb[Req, Resp]) Verb[any, any] { + return func(ctx context.Context, uncheckedReq any) (any, error) { + req, ok := uncheckedReq.(Req) + if !ok { + return nil, fmt.Errorf("invalid request type %T for %v, expected %v", uncheckedReq, FuncRef(verb), reflect.TypeFor[Req]()) + } + return verb(ctx, req) + } +} diff --git a/go-runtime/ftl/ftltest/ftltest.go b/go-runtime/ftl/ftltest/ftltest.go index e791e0ad6d..cff389f55d 100644 --- a/go-runtime/ftl/ftltest/ftltest.go +++ b/go-runtime/ftl/ftltest/ftltest.go @@ -19,7 +19,7 @@ type OptionsState struct { configs map[string][]byte secrets map[string][]byte databases map[string]modulecontext.Database - mockVerbs map[schema.RefKey]modulecontext.MockVerb + mockVerbs map[schema.RefKey]modulecontext.Verb allowDirectVerbBehavior bool } @@ -39,7 +39,7 @@ func Context(options ...Option) context.Context { configs: make(map[string][]byte), secrets: make(map[string][]byte), databases: databases, - mockVerbs: make(map[schema.RefKey]modulecontext.MockVerb), + mockVerbs: make(map[schema.RefKey]modulecontext.Verb), } for _, option := range options { err := option(ctx, state) diff --git a/internal/modulecontext/module_context.go b/internal/modulecontext/module_context.go index 19f925eb7d..8e62177d9a 100644 --- a/internal/modulecontext/module_context.go +++ b/internal/modulecontext/module_context.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" + "github.com/alecthomas/types/optional" _ "github.com/jackc/pgx/v5/stdlib" // SQL driver ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" @@ -15,8 +16,9 @@ import ( "github.com/TBD54566975/ftl/internal/reflect" ) -type MockVerb func(ctx context.Context, req any) (resp any, err error) - +// Database represents a database connection based on a DSN +// +// It holds a private field for the database which is accessible through moduleCtx.GetDatabase(name) type Database struct { DSN string DBType DBType @@ -52,6 +54,11 @@ func (x DBType) String() string { } } +// Verb is a function that takes a request and returns a response but is not constrained by request/response type like ftl.Verb +// +// It is used for definitions of mock verbs as well as real implementations of verbs to directly execute +type Verb func(ctx context.Context, req any) (resp any, err error) + // ModuleContext holds the context needed for a module, including configs, secrets and DSNs // // ModuleContext is immutable @@ -62,7 +69,7 @@ type ModuleContext struct { databases map[string]Database isTesting bool - mockVerbs map[schema.RefKey]MockVerb + mockVerbs map[schema.RefKey]Verb allowDirectVerbBehavior bool } @@ -78,7 +85,7 @@ func NewBuilder(module string) *Builder { configs: map[string][]byte{}, secrets: map[string][]byte{}, databases: map[string]Database{}, - mockVerbs: map[schema.RefKey]MockVerb{}, + mockVerbs: map[schema.RefKey]Verb{}, } } @@ -107,7 +114,7 @@ func (b *Builder) AddDatabases(databases map[string]Database) *Builder { } // UpdateForTesting marks the builder as part of a test environment and adds mock verbs and flags for other test features. -func (b *Builder) UpdateForTesting(mockVerbs map[schema.RefKey]MockVerb, allowDirectVerbBehavior bool) *Builder { +func (b *Builder) UpdateForTesting(mockVerbs map[schema.RefKey]Verb, allowDirectVerbBehavior bool) *Builder { b.isTesting = true for name, verb := range mockVerbs { b.mockVerbs[name] = verb @@ -173,46 +180,40 @@ func (m ModuleContext) GetDatabase(name string, dbType DBType) (*sql.DB, error) // BehaviorForVerb returns what to do to execute a verb // // This allows module context to dictate behavior based on testing options -func (m ModuleContext) BehaviorForVerb(ref schema.Ref) (VerbBehavior, error) { +// Returning optional.Nil indicates the verb should be executed normally via the controller +func (m ModuleContext) BehaviorForVerb(ref schema.Ref) (optional.Option[VerbBehavior], error) { if mock, ok := m.mockVerbs[ref.ToRefKey()]; ok { - return MockBehavior{Mock: mock}, nil + return optional.Some(VerbBehavior(MockBehavior{Mock: mock})), nil } else if m.allowDirectVerbBehavior && ref.Module == m.module { - return DirectBehavior{}, nil + return optional.Some(VerbBehavior(DirectBehavior{})), nil } else if m.isTesting { if ref.Module == m.module { - return StandardBehavior{}, fmt.Errorf("no mock found: provide a mock with ftltest.WhenVerb(%s, ...) or enable all calls within the module with ftltest.WithCallsAllowedWithinModule()", strings.ToUpper(ref.Name[:1])+ref.Name[1:]) + return optional.None[VerbBehavior](), fmt.Errorf("no mock found: provide a mock with ftltest.WhenVerb(%s, ...) or enable all calls within the module with ftltest.WithCallsAllowedWithinModule()", strings.ToUpper(ref.Name[:1])+ref.Name[1:]) } - return StandardBehavior{}, fmt.Errorf("no mock found: provide a mock with ftltest.WhenVerb(%s.%s, ...)", ref.Module, strings.ToUpper(ref.Name[:1])+ref.Name[1:]) + return optional.None[VerbBehavior](), fmt.Errorf("no mock found: provide a mock with ftltest.WhenVerb(%s.%s, ...)", ref.Module, strings.ToUpper(ref.Name[:1])+ref.Name[1:]) } - return StandardBehavior{}, nil + return optional.None[VerbBehavior](), nil } // VerbBehavior indicates how to execute a verb -// -//sumtype:decl type VerbBehavior interface { - verbBehavior() + Call(ctx context.Context, verb Verb, request any) (any, error) } -// StandardBehavior indicates that the verb should be executed via the controller -type StandardBehavior struct{} - -func (StandardBehavior) verbBehavior() {} - -var _ VerbBehavior = StandardBehavior{} - // DirectBehavior indicates that the verb should be executed by calling the function directly (for testing) type DirectBehavior struct{} -func (DirectBehavior) verbBehavior() {} +func (DirectBehavior) Call(ctx context.Context, verb Verb, req any) (any, error) { + return verb(ctx, req) +} var _ VerbBehavior = DirectBehavior{} // MockBehavior indicates the verb has a mock implementation type MockBehavior struct { - Mock MockVerb + Mock Verb } -func (MockBehavior) verbBehavior() {} - -var _ VerbBehavior = MockBehavior{} +func (b MockBehavior) Call(ctx context.Context, verb Verb, req any) (any, error) { + return b.Mock(ctx, req) +}