From 6a2bfeca3f85c179beb86498986d9f888bbbb0f3 Mon Sep 17 00:00:00 2001 From: Elizabeth Worstell Date: Thu, 22 Feb 2024 18:34:04 -0800 Subject: [PATCH] chore: implement encoding.Unmarshal fixes #971 --- backend/controller/ingress/request_test.go | 6 +- go-runtime/encoding/encoding.go | 234 ++++++++++++++++++++- go-runtime/encoding/encoding_test.go | 99 +++++++++ go-runtime/ftl/call.go | 3 +- go-runtime/ftl/option.go | 2 +- go-runtime/server/server.go | 3 +- integration/integration_test.go | 5 +- 7 files changed, 335 insertions(+), 17 deletions(-) diff --git a/backend/controller/ingress/request_test.go b/backend/controller/ingress/request_test.go index 4cee7f9b88..7671c5f85b 100644 --- a/backend/controller/ingress/request_test.go +++ b/backend/controller/ingress/request_test.go @@ -2,7 +2,6 @@ package ingress import ( "bytes" - "encoding/json" "net/http" "net/url" "reflect" @@ -17,8 +16,7 @@ import ( ) type AliasRequest struct { - // FIXME: This should be an alias (`json:"alias"`) once encoding.Unmarshal is available. - Aliased string + Aliased string `json:"alias"` } type PathParameterRequest struct { @@ -184,7 +182,7 @@ func TestBuildRequestBody(t *testing.T) { assert.NoError(t, err) actualrv := reflect.New(reflect.TypeOf(test.expected)) actual := actualrv.Interface() - err = json.Unmarshal(requestBody, actual) + err = encoding.Unmarshal(requestBody, actual) assert.NoError(t, err) assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty()) }) diff --git a/go-runtime/encoding/encoding.go b/go-runtime/encoding/encoding.go index 4849c2e72c..ab3da39896 100644 --- a/go-runtime/encoding/encoding.go +++ b/go-runtime/encoding/encoding.go @@ -9,14 +9,17 @@ import ( "encoding/json" "fmt" "reflect" + "strconv" "strings" "github.com/TBD54566975/ftl/backend/schema/strcase" ) var ( - textUnarmshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - jsonUnmarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() ) func Marshal(v any) ([]byte, error) { @@ -32,11 +35,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { } t := v.Type() switch { - case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonUnmarshaler): + case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler): v = v.Elem() fallthrough - case t.Implements(jsonUnmarshaler): + case t.Implements(jsonMarshaler): enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert data, err := enc.MarshalJSON() if err != nil { @@ -45,11 +48,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { w.Write(data) return nil - case t.Kind() == reflect.Ptr && t.Elem().Implements(textUnarmshaler): + case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler): v = v.Elem() fallthrough - case t.Implements(textUnarmshaler): + case t.Implements(textMarshaler): enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert data, err := enc.MarshalText() if err != nil { @@ -198,3 +201,222 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('"') return nil } + +func Unmarshal(data []byte, v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("unmarshal expects a non-nil pointer") + } + return decodeValue(data, rv.Elem()) +} + +func decodeValue(data []byte, v reflect.Value) error { + if !v.CanSet() { + return fmt.Errorf("cannot set value") + } + + t := v.Type() + switch { + case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler): + v = v.Addr() + fallthrough + + case t.Implements(jsonUnmarshaler): + if v.IsNil() { + v.Set(reflect.New(t.Elem())) + } + dec := v.Interface().(json.Unmarshaler) //nolint:forcetypeassert + return dec.UnmarshalJSON(data) + + case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler): + v = v.Addr() + fallthrough + + case t.Implements(textUnmarshaler): + if v.IsNil() { + v.Set(reflect.New(t.Elem())) + } + dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + return dec.UnmarshalText([]byte(s)) + } + + switch v.Kind() { + case reflect.Struct: + return decodeStruct(data, v) + + case reflect.Ptr: + if string(data) == "null" { + v.Set(reflect.Zero(v.Type())) + return nil + } + return decodeValue(data, v.Elem()) + + case reflect.Slice: + if v.Type().Elem().Kind() == reflect.Uint8 { + return decodeBytes(data, v) + } + return decodeSlice(data, v) + + case reflect.Map: + return decodeMap(data, v) + + case reflect.String: + return decodeString(data, v) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return decodeInt(data, v) + + case reflect.Float32, reflect.Float64: + return decodeFloat(data, v) + + case reflect.Bool: + return decodeBool(data, v) + + case reflect.Interface: + if v.Type().NumMethod() != 0 { + return fmt.Errorf("the only interface type supported is any, not %s", v.Type()) + } + + var anyInterface any + if err := json.Unmarshal(data, &anyInterface); err != nil { + return err + } + reflectedValue := reflect.ValueOf(anyInterface) + if !reflectedValue.Type().AssignableTo(v.Type()) { + return fmt.Errorf("cannot assign type %s to any interface", reflectedValue.Type()) + } + v.Set(reflectedValue) + return nil + + default: + return fmt.Errorf("unsupported type: %s", v.Type()) + } +} + +func decodeStruct(data []byte, v reflect.Value) error { + var jsonData map[string]json.RawMessage + if err := json.Unmarshal(data, &jsonData); err != nil { + return err + } + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := v.Type().Field(i) + jsonKey := strcase.ToLowerCamel(fieldType.Name) + if _, ok := jsonData[jsonKey]; ok { + fieldTypeStr := fieldType.Type.String() + switch { + case fieldTypeStr == "*Unit" || fieldTypeStr == "Unit": + if fieldTypeStr == "*Unit" && field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + + default: + if err := decodeValue(jsonData[jsonKey], field); err != nil { + return fmt.Errorf("error decoding field '%s': %w", fieldType.Name, err) + } + } + } + } + return nil +} + +func decodeString(data []byte, v reflect.Value) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + v.SetString(s) + return nil +} + +func decodeInt(data []byte, v reflect.Value) error { + var n int64 + n, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return err + } + v.SetInt(n) + return nil +} + +func decodeFloat(data []byte, v reflect.Value) error { + var f float64 + if err := json.Unmarshal(data, &f); err != nil { + return err + } + v.SetFloat(f) + return nil +} + +func decodeBool(data []byte, v reflect.Value) error { + var b bool + if err := json.Unmarshal(data, &b); err != nil { + return err + } + v.SetBool(b) + return nil +} + +func decodeBytes(data []byte, v reflect.Value) error { + var b []byte + if err := json.Unmarshal(data, &b); err != nil { + return err + } + v.SetBytes(b) + return nil +} + +func decodeSlice(data []byte, v reflect.Value) error { + var intfSlice []any + if err := json.Unmarshal(data, &intfSlice); err != nil { + return err + } + + sliceType := v.Type().Elem() + newSlice := reflect.MakeSlice(reflect.SliceOf(sliceType), len(intfSlice), len(intfSlice)) + + for i, intf := range intfSlice { + elem := newSlice.Index(i) + elemData, err := json.Marshal(intf) + if err != nil { + return err + } + if err = decodeValue(elemData, elem); err != nil { + return err + } + } + + v.Set(newSlice) + + return nil +} + +func decodeMap(data []byte, v reflect.Value) error { + if v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + + var intfMap map[string]any + if err := json.Unmarshal(data, &intfMap); err != nil { + return err + } + + for key, intf := range intfMap { + valData, err := json.Marshal(intf) + if err != nil { + return err + } + mapValue := reflect.New(v.Type().Elem()) + if err = decodeValue(valData, mapValue.Elem()); err != nil { + return err + } + v.SetMapIndex(reflect.ValueOf(key), mapValue.Elem()) + } + + return nil +} diff --git a/go-runtime/encoding/encoding_test.go b/go-runtime/encoding/encoding_test.go index 860a7e479a..18053f466f 100644 --- a/go-runtime/encoding/encoding_test.go +++ b/go-runtime/encoding/encoding_test.go @@ -1,6 +1,7 @@ package encoding_test import ( + "reflect" "testing" . "github.com/TBD54566975/ftl/go-runtime/encoding" @@ -31,6 +32,11 @@ func TestMarshal(t *testing.T) { {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}, expected: `{"option":42}`}, {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}, expected: `{"option":42}`}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`}, + {name: "Unit", input: ftl.Unit{}, expected: `{}`}, + {name: "UnitField", input: struct { + String string + Unit ftl.Unit + }{String: "something", Unit: ftl.Unit{}}, expected: `{"string":"something"}`}, } for _, tt := range tests { @@ -41,3 +47,96 @@ func TestMarshal(t *testing.T) { }) } } + +func TestUnmarshal(t *testing.T) { + type inner struct { + FooBar string + } + somePtr := ftl.Some(42) + tests := []struct { + name string + input string + expected any + err string + }{ + {name: "FieldRenaming", input: `{"fooBar":""}`, expected: struct{ FooBar string }{""}}, + {name: "String", input: `{"string":"foo"}`, expected: struct{ String string }{"foo"}}, + {name: "Int", input: `{"int":42}`, expected: struct{ Int int }{42}}, + {name: "Float", input: `{"float":42.42}`, expected: struct{ Float float64 }{42.42}}, + {name: "Bool", input: `{"bool":true}`, expected: struct{ Bool bool }{true}}, + {name: "Nil", input: `{"nil":null}`, expected: struct{ Nil *int }{nil}}, + {name: "Slice", input: `{"slice":[1,2,3]}`, expected: struct{ Slice []int }{[]int{1, 2, 3}}}, + {name: "SliceOfStrings", input: `{"slice":["hello","world"]}`, expected: struct{ Slice []string }{[]string{"hello", "world"}}}, + {name: "Map", input: `{"map":{"foo":42}}`, expected: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, + {name: "Option", input: `{"option":42}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, + {name: "OptionPtr", input: `{"option":42}`, expected: struct{ Option *ftl.Option[int] }{&somePtr}}, + {name: "OptionStruct", input: `{"option":{"fooBar":"foo"}}`, expected: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, + {name: "Unit", input: `{}`, expected: ftl.Unit{}}, + {name: "UnitField", input: `{"string":"something"}`, expected: struct { + String string + Unit ftl.Unit + }{String: "something", Unit: ftl.Unit{}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eType := reflect.TypeOf(tt.expected) + if eType.Kind() == reflect.Ptr { + eType = eType.Elem() + } + o := reflect.New(eType) + err := Unmarshal([]byte(tt.input), o.Interface()) + assert.EqualError(t, err, tt.err) + assert.Equal(t, tt.expected, o.Elem().Interface()) + }) + } +} + +func TestRoundTrip(t *testing.T) { + type inner struct { + FooBar string + } + somePtr := ftl.Some(42) + tests := []struct { + name string + input any + }{ + {name: "FieldRenaming", input: struct{ FooBar string }{""}}, + {name: "String", input: struct{ String string }{"foo"}}, + {name: "Int", input: struct{ Int int }{42}}, + {name: "Float", input: struct{ Float float64 }{42.42}}, + {name: "Bool", input: struct{ Bool bool }{true}}, + {name: "Nil", input: struct{ Nil *int }{nil}}, + {name: "Slice", input: struct{ Slice []int }{[]int{1, 2, 3}}}, + {name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}}, + {name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, + {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, + {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}}, + {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, + {name: "Unit", input: ftl.Unit{}}, + {name: "UnitField", input: struct { + String string + Unit ftl.Unit + }{String: "something", Unit: ftl.Unit{}}}, + {name: "Aliased", input: struct { + TokenID string `json:"token_id"` + }{"123"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + marshaled, err := Marshal(tt.input) + assert.NoError(t, err) + + eType := reflect.TypeOf(tt.input) + if eType.Kind() == reflect.Ptr { + eType = eType.Elem() + } + o := reflect.New(eType) + err = Unmarshal(marshaled, o.Interface()) + assert.NoError(t, err) + + assert.Equal(t, tt.input, o.Elem().Interface()) + }) + } +} diff --git a/go-runtime/ftl/call.go b/go-runtime/ftl/call.go index 5ae3c61999..0aa59a9b5c 100644 --- a/go-runtime/ftl/call.go +++ b/go-runtime/ftl/call.go @@ -2,7 +2,6 @@ package ftl import ( "context" - "encoding/json" "fmt" "reflect" "runtime" @@ -34,7 +33,7 @@ func Call[Req, Resp any](ctx context.Context, verb Verb[Req, Resp], req Req) (re return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) case *ftlv1.CallResponse_Body: - err = json.Unmarshal(cresp.Body, &resp) + err = encoding.Unmarshal(cresp.Body, &resp) if err != nil { return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) } diff --git a/go-runtime/ftl/option.go b/go-runtime/ftl/option.go index 99e39fbefe..73e4b5bee8 100644 --- a/go-runtime/ftl/option.go +++ b/go-runtime/ftl/option.go @@ -179,7 +179,7 @@ func (o *Option[T]) UnmarshalJSON(data []byte) error { o.ok = false return nil } - if err := json.Unmarshal(data, &o.value); err != nil { + if err := ftlencoding.Unmarshal(data, &o.value); err != nil { return err } o.ok = true diff --git a/go-runtime/server/server.go b/go-runtime/server/server.go index 9809177800..a2117f19dc 100644 --- a/go-runtime/server/server.go +++ b/go-runtime/server/server.go @@ -2,7 +2,6 @@ package server import ( "context" - "encoding/json" "fmt" "net/url" "runtime/debug" @@ -56,7 +55,7 @@ func Handle[Req, Resp any](verb func(ctx context.Context, req Req) (Resp, error) fn: func(ctx context.Context, reqdata []byte) ([]byte, error) { // Decode request. var req Req - err := json.Unmarshal(reqdata, &req) + err := encoding.Unmarshal(reqdata, &req) if err != nil { return nil, fmt.Errorf("invalid request to verb %s: %w", ref, err) } diff --git a/integration/integration_test.go b/integration/integration_test.go index 7621c94be6..4b3ed26140 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -31,6 +31,7 @@ import ( ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect" schemapb "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/schema" + "github.com/TBD54566975/ftl/go-runtime/encoding" "github.com/TBD54566975/ftl/internal/exec" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/rpc" @@ -386,7 +387,7 @@ type obj map[string]any func call[Resp any](module, verb string, req obj, onResponse func(t testing.TB, resp Resp)) assertion { return func(t testing.TB, ic itContext) error { - jreq, err := json.Marshal(req) + jreq, err := encoding.Marshal(req) assert.NoError(t, err) cresp, err := ic.verbs.Call(ic, connect.NewRequest(&ftlv1.CallRequest{ @@ -402,7 +403,7 @@ func call[Resp any](module, verb string, req obj, onResponse func(t testing.TB, } var resp Resp - err = json.Unmarshal(cresp.Msg.GetBody(), &resp) + err = encoding.Unmarshal(cresp.Msg.GetBody(), &resp) assert.NoError(t, err) onResponse(t, resp)