diff --git a/backend/controller/ingress/handler_test.go b/backend/controller/ingress/handler_test.go index 012571d390..72af762c4f 100644 --- a/backend/controller/ingress/handler_test.go +++ b/backend/controller/ingress/handler_test.go @@ -100,7 +100,7 @@ func TestIngress(t *testing.T) { req.URL.RawQuery = test.query.Encode() reqKey := model.NewRequestKey(model.OriginIngress, "test") ingress.Handle(sch, reqKey, routes, rec, req, func(ctx context.Context, r *connect.Request[ftlv1.CallRequest], requestKey optional.Option[model.RequestKey], requestSource string) (*connect.Response[ftlv1.CallResponse], error) { - body, err := encoding.Marshal(ctx, response) + body, err := encoding.Marshal(response) assert.NoError(t, err) return connect.NewResponse(&ftlv1.CallResponse{Response: &ftlv1.CallResponse_Body{Body: body}}), nil }) diff --git a/backend/controller/ingress/request_test.go b/backend/controller/ingress/request_test.go index 82cf9b9b6c..fab0664702 100644 --- a/backend/controller/ingress/request_test.go +++ b/backend/controller/ingress/request_test.go @@ -2,7 +2,6 @@ package ingress import ( "bytes" - "context" "fmt" "net/http" "net/url" @@ -164,7 +163,7 @@ func TestBuildRequestBody(t *testing.T) { if test.body == nil { test.body = obj{} } - body, err := encoding.Marshal(context.Background(), test.body) + body, err := encoding.Marshal(test.body) assert.NoError(t, err) requestURL := "http://127.0.0.1" + test.path if test.query != nil { @@ -184,7 +183,7 @@ func TestBuildRequestBody(t *testing.T) { assert.NoError(t, err) actualrv := reflect.New(reflect.TypeOf(test.expected)) actual := actualrv.Interface() - err = encoding.Unmarshal(context.Background(), requestBody, actual) + err = encoding.Unmarshal(requestBody, actual) assert.NoError(t, err) assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty()) }) diff --git a/buildengine/build_go_test.go b/buildengine/build_go_test.go index 8f9253ea13..96616082ea 100644 --- a/buildengine/build_go_test.go +++ b/buildengine/build_go_test.go @@ -87,6 +87,8 @@ package other import ( "context" + + "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" ) var _ = context.Background @@ -161,6 +163,15 @@ func Source(context.Context) (SourceResp, error) { func Nothing(context.Context) error { panic("Verb stubs should not be called directly, instead use github.com/TBD54566975/ftl/runtime-go/ftl.CallEmpty()") } + +func init() { + reflection.Register( + reflection.WithSumType[TypeEnum]( + *new(A), + *new(B), + ), + ) +} ` bctx := buildContext{ moduleDir: "testdata/projects/another", diff --git a/buildengine/testdata/type_registry_main.go b/buildengine/testdata/type_registry_main.go index c3fbba29c6..db670bd19b 100644 --- a/buildengine/testdata/type_registry_main.go +++ b/buildengine/testdata/type_registry_main.go @@ -3,6 +3,7 @@ package main import ( "context" + "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect" "github.com/TBD54566975/ftl/common/plugin" "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" @@ -12,13 +13,8 @@ import ( "ftl/other" ) -func main() { - verbConstructor := server.NewUserVerbServer("other", - server.HandleCall(other.Echo), - ) - ctx := context.Background() - - tr := reflection.NewTypeRegistry( +func init() { + reflection.Register( reflection.WithSumType[another.SecondTypeEnum]( *new(another.One), *new(another.Two), @@ -45,7 +41,11 @@ func main() { *new(other.MyUnit), ), ) - ctx = reflection.ContextWithTypeRegistry(ctx, tr) +} - plugin.Start(ctx, "other", verbConstructor, ftlv1connect.VerbServiceName, ftlv1connect.NewVerbServiceHandler) +func main() { + verbConstructor := server.NewUserVerbServer("other", + server.HandleCall(other.Echo), + ) + plugin.Start(context.Background(), "other", verbConstructor, ftlv1connect.VerbServiceName, ftlv1connect.NewVerbServiceHandler) } diff --git a/go-runtime/compile/build-template/_ftl.tmpl/go/main/main.go b/go-runtime/compile/build-template/_ftl.tmpl/go/main/main.go index acace4f713..4f74317b88 100644 --- a/go-runtime/compile/build-template/_ftl.tmpl/go/main/main.go +++ b/go-runtime/compile/build-template/_ftl.tmpl/go/main/main.go @@ -3,6 +3,7 @@ package main import ( "context" + "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect" "github.com/TBD54566975/ftl/common/plugin" {{- if .SumTypes }} @@ -13,6 +14,20 @@ import ( "ftl/{{.}}" {{- end}} ) +{{- if .SumTypes}} + +func init() { + reflection.Register( +{{- range .SumTypes}} + reflection.WithSumType[{{.Discriminator}}]( + {{- range .Variants}} + *new({{.Type}}), + {{- end}} + ), +{{- end}} + ) +} +{{- end}} func main() { verbConstructor := server.NewUserVerbServer("{{.Name}}", @@ -28,23 +43,5 @@ func main() { {{- end}} {{- end}} ) - ctx := context.Background() - -{{- if .SumTypes}} - - tr := reflection.NewTypeRegistry( -{{- range .SumTypes}} - reflection.WithSumType[{{.Discriminator}}]( - {{- range .Variants}} - *new({{.Type}}), - {{- end}} - ), -{{- end}} - ) -{{- end}} -{{- if .SumTypes}} - ctx = reflection.ContextWithTypeRegistry(ctx, tr) -{{- end}} - - plugin.Start(ctx, "{{.Name}}", verbConstructor, ftlv1connect.VerbServiceName, ftlv1connect.NewVerbServiceHandler) + plugin.Start(context.Background(), "{{.Name}}", verbConstructor, ftlv1connect.VerbServiceName, ftlv1connect.NewVerbServiceHandler) } diff --git a/go-runtime/compile/build.go b/go-runtime/compile/build.go index 82b36edac8..02ef0f256a 100644 --- a/go-runtime/compile/build.go +++ b/go-runtime/compile/build.go @@ -363,6 +363,20 @@ var scaffoldFuncs = scaffolder.FuncMap{ return false }, + "sumTypes": func(m *schema.Module) []*schema.Enum { + out := []*schema.Enum{} + for _, d := range m.Decls { + switch d := d.(type) { + // Type enums (i.e. sum types) are all the non-value enums + case *schema.Enum: + if !d.IsValueEnum() && d.IsExported() { + out = append(out, d) + } + default: + } + } + return out + }, } func schemaType(t schema.Type) string { diff --git a/go-runtime/compile/external-module-template/_ftl/go/modules/{{ range .NonMainModules }}{{ push .Name . }}{{ end }}/external_module.go b/go-runtime/compile/external-module-template/_ftl/go/modules/{{ range .NonMainModules }}{{ push .Name . }}{{ end }}/external_module.go index 2c55ad8d27..11e682ea6a 100644 --- a/go-runtime/compile/external-module-template/_ftl/go/modules/{{ range .NonMainModules }}{{ push .Name . }}{{ end }}/external_module.go +++ b/go-runtime/compile/external-module-template/_ftl/go/modules/{{ range .NonMainModules }}{{ push .Name . }}{{ end }}/external_module.go @@ -7,6 +7,11 @@ import ( {{- range $import, $alias := (.|imports)}} {{if $alias}}{{$alias}} {{end}}"{{$import}}" {{- end}} +{{- $sumTypes := $ | sumTypes}} +{{- if $sumTypes}} + + "github.com/TBD54566975/ftl/go-runtime/ftl/reflection" +{{- end}} ) var _ = context.Background @@ -72,3 +77,17 @@ func {{.Name|title}}(context.Context, {{type $ .Request}}) ({{type $ .Response}} {{- end}} {{- end}} {{- end}} +{{- if $sumTypes}} + +func init() { + reflection.Register( +{{- range $sumTypes}} + reflection.WithSumType[{{.Name|title}}]( +{{- range .Variants}} + *new({{.Name|title}}), +{{- end}} + ), +{{- end}} + ) +} +{{- end}} diff --git a/go-runtime/encoding/encoding.go b/go-runtime/encoding/encoding.go index 7dd4367ee3..1989bb3783 100644 --- a/go-runtime/encoding/encoding.go +++ b/go-runtime/encoding/encoding.go @@ -4,7 +4,6 @@ package encoding import ( "bytes" - "context" "encoding/base64" "encoding/json" "fmt" @@ -23,19 +22,19 @@ var ( ) type OptionMarshaler interface { - Marshal(ctx context.Context, w *bytes.Buffer, encode func(ctx context.Context, v reflect.Value, w *bytes.Buffer) error) error + Marshal(w *bytes.Buffer, encode func(v reflect.Value, w *bytes.Buffer) error) error } type OptionUnmarshaler interface { - Unmarshal(ctx context.Context, d *json.Decoder, isNull bool, decode func(ctx context.Context, d *json.Decoder, v reflect.Value) error) error + Unmarshal(d *json.Decoder, isNull bool, decode func(d *json.Decoder, v reflect.Value) error) error } -func Marshal(ctx context.Context, v any) ([]byte, error) { +func Marshal(v any) ([]byte, error) { w := &bytes.Buffer{} - err := encodeValue(ctx, reflect.ValueOf(v), w) + err := encodeValue(reflect.ValueOf(v), w) return w.Bytes(), err } -func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { +func encodeValue(v reflect.Value, w *bytes.Buffer) error { if !v.IsValid() { w.WriteString("null") return nil @@ -58,7 +57,7 @@ func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { case t.Implements(optionMarshaler): enc := v.Interface().(OptionMarshaler) //nolint:forcetypeassert - return enc.Marshal(ctx, w, encodeValue) + return enc.Marshal(w, encodeValue) // TODO(Issue #1439): remove this special case by removing all usage of // json.RawMessage, which is not a type we support. @@ -73,16 +72,16 @@ func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { switch v.Kind() { case reflect.Struct: - return encodeStruct(ctx, v, w) + return encodeStruct(v, w) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { return encodeBytes(v, w) } - return encodeSlice(ctx, v, w) + return encodeSlice(v, w) case reflect.Map: - return encodeMap(ctx, v, w) + return encodeMap(v, w) case reflect.String: return encodeString(v, w) @@ -98,17 +97,15 @@ func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { case reflect.Interface: if t == reflect.TypeFor[any]() { - return encodeValue(ctx, v.Elem(), w) + return encodeValue(v.Elem(), w) } - if tr, ok := reflection.TypeRegistryFromContext(ctx).Get(); ok { - if vName, ok := tr.GetVariantByType(v.Type(), v.Elem().Type()).Get(); ok { - sumType := struct { - Name string - Value any - }{Name: vName, Value: v.Elem().Interface()} - return encodeValue(ctx, reflect.ValueOf(sumType), w) - } + if vName, ok := reflection.GetVariantByType(v.Type(), v.Elem().Type()).Get(); ok { + sumType := struct { + Name string + Value any + }{Name: vName, Value: v.Elem().Interface()} + return encodeValue(reflect.ValueOf(sumType), w) } return fmt.Errorf("the only supported interface types are enums or any, not %s", t) @@ -118,7 +115,7 @@ func encodeValue(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { } } -func encodeStruct(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { +func encodeStruct(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('{') afterFirst := false for i := range v.NumField() { @@ -146,7 +143,7 @@ func encodeStruct(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { } afterFirst = true w.WriteString(`"` + strcase.ToLowerCamel(ft.Name) + `":`) - if err := encodeValue(ctx, fv, w); err != nil { + if err := encodeValue(fv, w); err != nil { return err } } @@ -171,13 +168,13 @@ func encodeBytes(v reflect.Value, w *bytes.Buffer) error { return nil } -func encodeSlice(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { +func encodeSlice(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('[') for i := range v.Len() { if i > 0 { w.WriteRune(',') } - if err := encodeValue(ctx, v.Index(i), w); err != nil { + if err := encodeValue(v.Index(i), w); err != nil { return err } } @@ -185,7 +182,7 @@ func encodeSlice(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { return nil } -func encodeMap(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { +func encodeMap(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('{') for i, key := range v.MapKeys() { if i > 0 { @@ -194,7 +191,7 @@ func encodeMap(ctx context.Context, v reflect.Value, w *bytes.Buffer) error { w.WriteRune('"') w.WriteString(key.String()) w.WriteString(`":`) - if err := encodeValue(ctx, v.MapIndex(key), w); err != nil { + if err := encodeValue(v.MapIndex(key), w); err != nil { return err } } @@ -226,17 +223,17 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error { return nil } -func Unmarshal(ctx context.Context, data []byte, v any) error { +func Unmarshal(data []byte, v any) error { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("unmarshal expects a non-nil pointer") } d := json.NewDecoder(bytes.NewReader(data)) - return decodeValue(ctx, d, rv.Elem()) + return decodeValue(d, rv.Elem()) } -func decodeValue(ctx context.Context, d *json.Decoder, v reflect.Value) error { +func decodeValue(d *json.Decoder, v reflect.Value) error { if !v.CanSet() { return fmt.Errorf("cannot set value: %s", v.Type()) } @@ -261,30 +258,28 @@ func decodeValue(ctx context.Context, d *json.Decoder, v reflect.Value) error { } dec := v.Interface().(OptionUnmarshaler) //nolint:forcetypeassert return handleIfNextTokenIsNull(d, func(d *json.Decoder) error { - return dec.Unmarshal(ctx, d, true, decodeValue) + return dec.Unmarshal(d, true, decodeValue) }, func(d *json.Decoder) error { - return dec.Unmarshal(ctx, d, false, decodeValue) + return dec.Unmarshal(d, false, decodeValue) }) } switch v.Kind() { case reflect.Struct: - return decodeStruct(ctx, d, v) + return decodeStruct(d, v) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { return decodeBytes(d, v) } - return decodeSlice(ctx, d, v) + return decodeSlice(d, v) case reflect.Map: - return decodeMap(ctx, d, v) + return decodeMap(d, v) case reflect.Interface: - if tr, ok := reflection.TypeRegistryFromContext(ctx).Get(); ok { - if tr.IsSumTypeDiscriminator(v.Type()) { - return decodeSumType(ctx, d, v) - } + if reflection.IsSumTypeDiscriminator(v.Type()) { + return decodeSumType(d, v) } if v.Type().NumMethod() != 0 { @@ -297,7 +292,7 @@ func decodeValue(ctx context.Context, d *json.Decoder, v reflect.Value) error { } } -func decodeStruct(ctx context.Context, d *json.Decoder, v reflect.Value) error { +func decodeStruct(d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '{'); err != nil { return err } @@ -325,7 +320,7 @@ func decodeStruct(ctx context.Context, d *json.Decoder, v reflect.Value) error { field.Set(reflect.New(field.Type().Elem())) } default: - if err := decodeValue(ctx, d, field); err != nil { + if err := decodeValue(d, field); err != nil { return err } } @@ -345,14 +340,14 @@ func decodeBytes(d *json.Decoder, v reflect.Value) error { return nil } -func decodeSlice(ctx context.Context, d *json.Decoder, v reflect.Value) error { +func decodeSlice(d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '['); err != nil { return err } for d.More() { newElem := reflect.New(v.Type().Elem()).Elem() - if err := decodeValue(ctx, d, newElem); err != nil { + if err := decodeValue(d, newElem); err != nil { return err } v.Set(reflect.Append(v, newElem)) @@ -362,7 +357,7 @@ func decodeSlice(ctx context.Context, d *json.Decoder, v reflect.Value) error { return err } -func decodeMap(ctx context.Context, d *json.Decoder, v reflect.Value) error { +func decodeMap(d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '{'); err != nil { return err } @@ -379,7 +374,7 @@ func decodeMap(ctx context.Context, d *json.Decoder, v reflect.Value) error { } newElem := reflect.New(valType).Elem() - if err := decodeValue(ctx, d, newElem); err != nil { + if err := decodeValue(d, newElem); err != nil { return err } @@ -390,12 +385,7 @@ func decodeMap(ctx context.Context, d *json.Decoder, v reflect.Value) error { return err } -func decodeSumType(ctx context.Context, d *json.Decoder, v reflect.Value) error { - tr, ok := reflection.TypeRegistryFromContext(ctx).Get() - if !ok { - return fmt.Errorf("no type registry found in context") - } - +func decodeSumType(d *json.Decoder, v reflect.Value) error { var sumType struct { Name string Value json.RawMessage @@ -411,13 +401,13 @@ func decodeSumType(ctx context.Context, d *json.Decoder, v reflect.Value) error return fmt.Errorf("no value found for type enum variant") } - variantType, ok := tr.GetVariantByName(v.Type(), sumType.Name).Get() + variantType, ok := reflection.GetVariantByName(v.Type(), sumType.Name).Get() if !ok { return fmt.Errorf("no enum variant found by name %s", sumType.Name) } out := reflect.New(variantType) - if err := decodeValue(ctx, json.NewDecoder(bytes.NewReader(sumType.Value)), out.Elem()); err != nil { + if err := decodeValue(json.NewDecoder(bytes.NewReader(sumType.Value)), out.Elem()); err != nil { return err } if !out.Type().AssignableTo(v.Type()) { diff --git a/go-runtime/encoding/encoding_test.go b/go-runtime/encoding/encoding_test.go index 1c7f05981c..e0502603d6 100644 --- a/go-runtime/encoding/encoding_test.go +++ b/go-runtime/encoding/encoding_test.go @@ -1,7 +1,6 @@ package encoding_test import ( - "context" "reflect" "testing" "time" @@ -32,6 +31,9 @@ func TestMarshal(t *testing.T) { type inner struct { FooBar string } + type sumtypeStruct struct { + D discriminator + } type validateOmitempty struct { ShouldOmit string `json:",omitempty"` ShouldntOmit string `json:""` @@ -60,6 +62,7 @@ func TestMarshal(t *testing.T) { {name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}, expected: `{"option":null}`}, {name: "OptionZero", input: struct{ Option ftl.Option[int] }{ftl.Some(0)}, expected: `{"option":0}`}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`}, + {name: "OptionSumType", input: struct{ Option ftl.Option[sumtypeStruct] }{ftl.Some(sumtypeStruct{variant{"hello"}})}, expected: `{"option":{"d":{"name":"Variant","value":{"message":"hello"}}}}`}, {name: "Unit", input: ftl.Unit{}, expected: `{}`}, {name: "UnitField", input: struct { String string @@ -77,15 +80,14 @@ func TestMarshal(t *testing.T) { }, expected: `{"shouldntOmit":null,"notTagged":null}`}, } - tr := reflection.NewTypeRegistry() - tr.RegisterSumType(reflect.TypeFor[discriminator](), map[string]reflect.Type{ - "Variant": reflect.TypeFor[variant](), - }) - ctx := reflection.ContextWithTypeRegistry(context.Background(), tr) + reflection.AllowAnyPackageForTesting = true + defer func() { reflection.AllowAnyPackageForTesting = false }() + reflection.ResetTypeRegistry() + reflection.Register(reflection.WithSumType[discriminator](variant{})) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - actual, err := Marshal(ctx, tt.input) + actual, err := Marshal(tt.input) assert.EqualError(t, err, tt.err) if err == nil { assert.Equal(t, tt.expected, string(actual)) @@ -134,17 +136,16 @@ func TestUnmarshal(t *testing.T) { {name: "UnregisteredSumType", input: `{"d":{"name":"Variant","value":{"message":"hello"}}}`, expected: struct{ D unregistered }{}, err: `the only supported interface types are enums or any, not encoding_test.unregistered`}, } - tr := reflection.NewTypeRegistry() - tr.RegisterSumType(reflect.TypeFor[discriminator](), map[string]reflect.Type{ - "Variant": reflect.TypeFor[variant](), - }) - ctx := reflection.ContextWithTypeRegistry(context.Background(), tr) + reflection.AllowAnyPackageForTesting = true + defer func() { reflection.AllowAnyPackageForTesting = false }() + reflection.ResetTypeRegistry() + reflection.Register(reflection.WithSumType[discriminator](variant{})) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { eType := reflect.TypeOf(tt.expected) o := reflect.New(eType) - err := Unmarshal(ctx, []byte(tt.input), o.Interface()) + err := Unmarshal([]byte(tt.input), o.Interface()) assert.EqualError(t, err, tt.err) if err == nil { assert.Equal(t, tt.expected, o.Elem().Interface()) @@ -184,20 +185,19 @@ func TestRoundTrip(t *testing.T) { {name: "SumType", input: struct{ D discriminator }{variant{"hello"}}}, } - tr := reflection.NewTypeRegistry() - tr.RegisterSumType(reflect.TypeFor[discriminator](), map[string]reflect.Type{ - "Variant": reflect.TypeFor[variant](), - }) - ctx := reflection.ContextWithTypeRegistry(context.Background(), tr) + reflection.AllowAnyPackageForTesting = true + defer func() { reflection.AllowAnyPackageForTesting = false }() + reflection.ResetTypeRegistry() + reflection.Register(reflection.WithSumType[discriminator](variant{})) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - marshaled, err := Marshal(ctx, tt.input) + marshaled, err := Marshal(tt.input) assert.NoError(t, err) eType := reflect.TypeOf(tt.input) o := reflect.New(eType) - err = Unmarshal(ctx, marshaled, o.Interface()) + err = Unmarshal(marshaled, o.Interface()) assert.NoError(t, err) assert.Equal(t, tt.input, o.Elem().Interface()) diff --git a/go-runtime/ftl/call.go b/go-runtime/ftl/call.go index 5083311192..47f0780aa0 100644 --- a/go-runtime/ftl/call.go +++ b/go-runtime/ftl/call.go @@ -32,7 +32,7 @@ func call[Req, Resp any](ctx context.Context, callee reflection.Ref, req Req, in return resp, fmt.Errorf("%s: overridden verb had invalid response type %T, expected %v", callee, uncheckedResp, reflect.TypeFor[Resp]()) } - reqData, err := encoding.Marshal(ctx, req) + reqData, err := encoding.Marshal(req) if err != nil { return resp, fmt.Errorf("%s: failed to marshal request: %w", callee, err) } @@ -47,7 +47,7 @@ func call[Req, Resp any](ctx context.Context, callee reflection.Ref, req Req, in return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) case *ftlv1.CallResponse_Body: - err = encoding.Unmarshal(ctx, cresp.Body, &resp) + err = encoding.Unmarshal(cresp.Body, &resp) if err != nil { return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) } diff --git a/go-runtime/ftl/option.go b/go-runtime/ftl/option.go index 4280c77ec8..8ac8cb712f 100644 --- a/go-runtime/ftl/option.go +++ b/go-runtime/ftl/option.go @@ -3,7 +3,6 @@ package ftl import ( "bytes" - "context" "database/sql" "database/sql/driver" "encoding" @@ -199,28 +198,26 @@ func (o Option[T]) GoString() string { } func (o Option[T]) Marshal( - ctx context.Context, w *bytes.Buffer, - encode func(ctx context.Context, v reflect.Value, w *bytes.Buffer) error, + encode func(v reflect.Value, w *bytes.Buffer) error, ) error { if o.ok { - return encode(ctx, reflect.ValueOf(&o.value).Elem(), w) + return encode(reflect.ValueOf(&o.value).Elem(), w) } w.WriteString("null") return nil } func (o *Option[T]) Unmarshal( - ctx context.Context, d *json.Decoder, isNull bool, - decode func(ctx context.Context, d *json.Decoder, v reflect.Value) error, + decode func(d *json.Decoder, v reflect.Value) error, ) error { if isNull { o.ok = false return nil } - if err := decode(ctx, d, reflect.ValueOf(&o.value).Elem()); err != nil { + if err := decode(d, reflect.ValueOf(&o.value).Elem()); err != nil { return err } o.ok = true diff --git a/go-runtime/ftl/reflection/reflection.go b/go-runtime/ftl/reflection/reflection.go index 4f458f7586..add1d9abc1 100644 --- a/go-runtime/ftl/reflection/reflection.go +++ b/go-runtime/ftl/reflection/reflection.go @@ -1,7 +1,6 @@ package reflection import ( - "context" "fmt" "reflect" "runtime" @@ -57,10 +56,10 @@ func FuncRef(call any) Ref { return goRefToFTLRef(ref) } -var allowAnyPackageForTesting = false +var AllowAnyPackageForTesting = false func goRefToFTLRef(ref string) Ref { - if !allowAnyPackageForTesting && !strings.HasPrefix(ref, "ftl/") { + if !AllowAnyPackageForTesting && !strings.HasPrefix(ref, "ftl/") { panic(fmt.Sprintf("invalid reference %q, must start with ftl/ ", ref)) } parts := strings.Split(ref[strings.LastIndex(ref, "/")+1:], ".") @@ -68,7 +67,7 @@ func goRefToFTLRef(ref string) Ref { } // Reflect returns the FTL schema for a Go type. -func reflectSchemaType(ctx context.Context, t reflect.Type) schema.Type { +func reflectSchemaType(t reflect.Type) schema.Type { switch t.Kind() { case reflect.Struct: // Handle well-known types. @@ -80,10 +79,10 @@ func reflectSchemaType(ctx context.Context, t reflect.Type) schema.Type { return refForType(t) case reflect.Slice: - return &schema.Array{Element: reflectSchemaType(ctx, t.Elem())} + return &schema.Array{Element: reflectSchemaType(t.Elem())} case reflect.Map: - return &schema.Map{Key: reflectSchemaType(ctx, t.Key()), Value: reflectSchemaType(ctx, t.Elem())} + return &schema.Map{Key: reflectSchemaType(t.Key()), Value: reflectSchemaType(t.Elem())} case reflect.Bool: return &schema.Bool{} @@ -114,8 +113,7 @@ func reflectSchemaType(ctx context.Context, t reflect.Type) schema.Type { return &schema.Any{} } // Check if it's a sum-type discriminator. - registry, ok := TypeRegistryFromContext(ctx).Get() - if !ok || !registry.IsSumTypeDiscriminator(t) { + if !IsSumTypeDiscriminator(t) { panic(fmt.Sprintf("unsupported interface type %s", t)) } return refForType(t) @@ -128,7 +126,7 @@ func reflectSchemaType(ctx context.Context, t reflect.Type) schema.Type { // Return the FTL module for a type or panic if it's not an FTL type. func moduleForType(t reflect.Type) string { module := t.PkgPath() - if !allowAnyPackageForTesting && !strings.HasPrefix(module, "ftl/") { + if !AllowAnyPackageForTesting && !strings.HasPrefix(module, "ftl/") { panic(fmt.Sprintf("invalid reference %q, must start with ftl/ ", module)) } parts := strings.Split(module, "/") diff --git a/go-runtime/ftl/reflection/reflection_test.go b/go-runtime/ftl/reflection/reflection_test.go index 0546b50186..8743467666 100644 --- a/go-runtime/ftl/reflection/reflection_test.go +++ b/go-runtime/ftl/reflection/reflection_test.go @@ -1,7 +1,6 @@ package reflection import ( - "context" "reflect" "testing" @@ -39,12 +38,10 @@ type AllTypesToReflect struct { } func TestReflectSchemaType(t *testing.T) { - allowAnyPackageForTesting = true - t.Cleanup(func() { allowAnyPackageForTesting = false }) + AllowAnyPackageForTesting = true + t.Cleanup(func() { AllowAnyPackageForTesting = false }) - tr := NewTypeRegistry(WithSumType[MySumType](Variant1{}, Variant2{})) - ctx := context.Background() - ctx = ContextWithTypeRegistry(ctx, tr) + Register(WithSumType[MySumType](Variant1{}, Variant2{})) v := AllTypesToReflect{SumType: &Variant1{}} @@ -66,7 +63,7 @@ func TestReflectSchemaType(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - st := reflectSchemaType(ctx, reflect.TypeOf(tt.value).Elem()) + st := reflectSchemaType(reflect.TypeOf(tt.value).Elem()) assert.Equal(t, tt.expected, st) }) } @@ -74,7 +71,7 @@ func TestReflectSchemaType(t *testing.T) { t.Run("InvalidType", func(t *testing.T) { var invalid uint assert.Panics(t, func() { - reflectSchemaType(ctx, reflect.TypeOf(&invalid).Elem()) + reflectSchemaType(reflect.TypeOf(&invalid).Elem()) }) }) } diff --git a/go-runtime/ftl/reflection/type_registry.go b/go-runtime/ftl/reflection/type_registry.go index 2d0ed5f22f..38f46cb498 100644 --- a/go-runtime/ftl/reflection/type_registry.go +++ b/go-runtime/ftl/reflection/type_registry.go @@ -1,28 +1,15 @@ package reflection import ( - "context" "reflect" "github.com/alecthomas/types/optional" ) -type contextKeyTypeRegistry struct{} - -// ContextWithTypeRegistry adds a type registry to the given context. -func ContextWithTypeRegistry(ctx context.Context, r *TypeRegistry) context.Context { - return context.WithValue(ctx, contextKeyTypeRegistry{}, r) -} - -// TypeRegistryFromContext retrieves the [TypeRegistry] previously added to the -// context with [ContextWithTypeRegistry]. -func TypeRegistryFromContext(ctx context.Context) optional.Option[*TypeRegistry] { - t, ok := ctx.Value(contextKeyTypeRegistry{}).(*TypeRegistry) - if ok { - return optional.Some(t) - } - return optional.None[*TypeRegistry]() -} +// singletonTypeRegistry is the global type registry that all public functions in this +// package interface with. It is not truly threadsafe. However, everything is initialized +// in init() calls, which are safe, and the type registry is never mutated afterwards. +var singletonTypeRegistry = newTypeRegistry() // TypeRegistry is used for dynamic type resolution at runtime. It stores associations between sum type discriminators // and their variants, for use in encoding and decoding. @@ -38,24 +25,21 @@ type sumTypeVariant struct { goType reflect.Type } -// TypeRegistryOption is a functional option for configuring a [TypeRegistry]. -type TypeRegistryOption func(t *TypeRegistry) - // WithSumType adds a sum type and its variants to the type registry. -func WithSumType[Discriminator any](variants ...Discriminator) TypeRegistryOption { +func WithSumType[Discriminator any](variants ...Discriminator) func(t *TypeRegistry) { return func(t *TypeRegistry) { variantMap := map[string]reflect.Type{} for _, v := range variants { ref := TypeRefFromValue(v) variantMap[ref.Name] = reflect.TypeOf(v) } - t.RegisterSumType(reflect.TypeFor[Discriminator](), variantMap) + t.registerSumType(reflect.TypeFor[Discriminator](), variantMap) } } -// NewTypeRegistry creates a new [TypeRegistry] for instantiating types by their qualified +// newTypeRegistry creates a new [TypeRegistry] for instantiating types by their qualified // name at runtime. -func NewTypeRegistry(options ...TypeRegistryOption) *TypeRegistry { +func newTypeRegistry(options ...func(t *TypeRegistry)) *TypeRegistry { t := &TypeRegistry{ sumTypes: map[reflect.Type][]sumTypeVariant{}, variantsToDiscriminators: map[reflect.Type]reflect.Type{}, @@ -66,10 +50,17 @@ func NewTypeRegistry(options ...TypeRegistryOption) *TypeRegistry { return t } -// RegisterSumType registers a Go sum type with the type registry. +// Register applies all the provided options to the singleton TypeRegistry +func Register(options ...func(t *TypeRegistry)) { + for _, o := range options { + o(singletonTypeRegistry) + } +} + +// registerSumType registers a Go sum type with the type registry. // // Sum types are represented as enums in the FTL schema. -func (t *TypeRegistry) RegisterSumType(discriminator reflect.Type, variants map[string]reflect.Type) { +func (t *TypeRegistry) registerSumType(discriminator reflect.Type, variants map[string]reflect.Type) { var values []sumTypeVariant for name, v := range variants { t.variantsToDiscriminators[v] = discriminator @@ -81,18 +72,36 @@ func (t *TypeRegistry) RegisterSumType(discriminator reflect.Type, variants map[ t.sumTypes[discriminator] = values } +// ResetTypeRegistry clears the contents of the singleton type registry for tests to +// guarantee determinism. +func ResetTypeRegistry() { + singletonTypeRegistry = newTypeRegistry() +} + // IsSumTypeDiscriminator returns true if the given type is a sum type discriminator. -func (t *TypeRegistry) IsSumTypeDiscriminator(discriminator reflect.Type) bool { +func IsSumTypeDiscriminator(discriminator reflect.Type) bool { + return singletonTypeRegistry.isSumTypeDiscriminator(discriminator) +} + +func (t *TypeRegistry) isSumTypeDiscriminator(discriminator reflect.Type) bool { return t.getSumTypeVariants(discriminator).Ok() } // GetDiscriminatorByVariant returns the discriminator type for the given variant type. -func (t *TypeRegistry) GetDiscriminatorByVariant(variant reflect.Type) optional.Option[reflect.Type] { +func GetDiscriminatorByVariant(variant reflect.Type) optional.Option[reflect.Type] { + return singletonTypeRegistry.getDiscriminatorByVariant(variant) +} + +func (t *TypeRegistry) getDiscriminatorByVariant(variant reflect.Type) optional.Option[reflect.Type] { return optional.Zero(t.variantsToDiscriminators[variant]) } // GetVariantByName returns the variant type for the given discriminator and variant name. -func (t *TypeRegistry) GetVariantByName(discriminator reflect.Type, name string) optional.Option[reflect.Type] { +func GetVariantByName(discriminator reflect.Type, name string) optional.Option[reflect.Type] { + return singletonTypeRegistry.getVariantByName(discriminator, name) +} + +func (t *TypeRegistry) getVariantByName(discriminator reflect.Type, name string) optional.Option[reflect.Type] { variants, ok := t.getSumTypeVariants(discriminator).Get() if !ok { return optional.None[reflect.Type]() @@ -106,7 +115,11 @@ func (t *TypeRegistry) GetVariantByName(discriminator reflect.Type, name string) } // GetVariantByType returns the variant name for the given discriminator and variant type. -func (t *TypeRegistry) GetVariantByType(discriminator reflect.Type, variantType reflect.Type) optional.Option[string] { +func GetVariantByType(discriminator reflect.Type, variantType reflect.Type) optional.Option[string] { + return singletonTypeRegistry.getVariantByType(discriminator, variantType) +} + +func (t *TypeRegistry) getVariantByType(discriminator reflect.Type, variantType reflect.Type) optional.Option[string] { variants, ok := t.getSumTypeVariants(discriminator).Get() if !ok { return optional.None[string]() diff --git a/go-runtime/ftl/reflection/type_registry_test.go b/go-runtime/ftl/reflection/type_registry_test.go index 7b0892324e..61b66ab934 100644 --- a/go-runtime/ftl/reflection/type_registry_test.go +++ b/go-runtime/ftl/reflection/type_registry_test.go @@ -8,22 +8,27 @@ import ( ) func TestTypeRegistry(t *testing.T) { - allowAnyPackageForTesting = true - defer func() { allowAnyPackageForTesting = false }() - tr := NewTypeRegistry(WithSumType[MySumType](Variant1{}, Variant2{})) + AllowAnyPackageForTesting = true + defer func() { AllowAnyPackageForTesting = false }() + ResetTypeRegistry() + Register(WithSumType[MySumType](Variant1{}, Variant2{})) - svariant, ok := tr.GetVariantByType(reflect.TypeFor[MySumType](), reflect.TypeFor[Variant1]()).Get() + svariant, ok := GetVariantByType(reflect.TypeFor[MySumType](), reflect.TypeFor[Variant1]()).Get() assert.True(t, ok) assert.Equal(t, "Variant1", svariant) - variant, ok := tr.GetVariantByName(reflect.TypeFor[MySumType](), "Variant1").Get() + variant, ok := GetVariantByName(reflect.TypeFor[MySumType](), "Variant1").Get() assert.True(t, ok) assert.Equal(t, reflect.TypeFor[Variant1](), variant) - ok = tr.IsSumTypeDiscriminator(reflect.TypeFor[MySumType]()) + ok = IsSumTypeDiscriminator(reflect.TypeFor[MySumType]()) assert.True(t, ok) - discriminator, ok := tr.GetDiscriminatorByVariant(reflect.TypeFor[Variant1]()).Get() + discriminator, ok := GetDiscriminatorByVariant(reflect.TypeFor[Variant1]()).Get() assert.True(t, ok) assert.Equal(t, reflect.TypeFor[MySumType](), discriminator) + + ResetTypeRegistry() + _, ok = GetVariantByType(reflect.TypeFor[MySumType](), reflect.TypeFor[Variant1]()).Get() + assert.False(t, ok) // test ResetTypeRegistry() } diff --git a/go-runtime/server/server.go b/go-runtime/server/server.go index 73a78c45a2..ef3324ab22 100644 --- a/go-runtime/server/server.go +++ b/go-runtime/server/server.go @@ -68,7 +68,7 @@ func handler[Req, Resp any](ref reflection.Ref, verb func(ctx context.Context, r fn: func(ctx context.Context, reqdata []byte) ([]byte, error) { // Decode request. var req Req - err := encoding.Unmarshal(ctx, reqdata, &req) + err := encoding.Unmarshal(reqdata, &req) if err != nil { return nil, fmt.Errorf("invalid request to verb %s: %w", ref, err) } @@ -79,7 +79,7 @@ func handler[Req, Resp any](ref reflection.Ref, verb func(ctx context.Context, r return nil, fmt.Errorf("call to verb %s failed: %w", ref, err) } - respdata, err := encoding.Marshal(ctx, resp) + respdata, err := encoding.Marshal(resp) if err != nil { return nil, err }