From e83e33bc860fee4c73d5b4df26f3e508cb3d3eb5 Mon Sep 17 00:00:00 2001 From: Elizabeth Worstell Date: Thu, 22 Feb 2024 18:34:04 -0800 Subject: [PATCH] chore: implement encoding.Unmarshal fixes #971 --- backend/controller/ingress/request_test.go | 6 +- go-runtime/encoding/encoding.go | 196 ++++++++++++++++++++- go-runtime/encoding/encoding_test.go | 99 +++++++++++ go-runtime/ftl/call.go | 3 +- go-runtime/ftl/option.go | 2 +- go-runtime/server/server.go | 3 +- integration/integration_test.go | 5 +- 7 files changed, 297 insertions(+), 17 deletions(-) diff --git a/backend/controller/ingress/request_test.go b/backend/controller/ingress/request_test.go index 4cee7f9b88..7671c5f85b 100644 --- a/backend/controller/ingress/request_test.go +++ b/backend/controller/ingress/request_test.go @@ -2,7 +2,6 @@ package ingress import ( "bytes" - "encoding/json" "net/http" "net/url" "reflect" @@ -17,8 +16,7 @@ import ( ) type AliasRequest struct { - // FIXME: This should be an alias (`json:"alias"`) once encoding.Unmarshal is available. - Aliased string + Aliased string `json:"alias"` } type PathParameterRequest struct { @@ -184,7 +182,7 @@ func TestBuildRequestBody(t *testing.T) { assert.NoError(t, err) actualrv := reflect.New(reflect.TypeOf(test.expected)) actual := actualrv.Interface() - err = json.Unmarshal(requestBody, actual) + err = encoding.Unmarshal(requestBody, actual) assert.NoError(t, err) assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty()) }) diff --git a/go-runtime/encoding/encoding.go b/go-runtime/encoding/encoding.go index 4849c2e72c..0dd71dcd33 100644 --- a/go-runtime/encoding/encoding.go +++ b/go-runtime/encoding/encoding.go @@ -15,8 +15,10 @@ import ( ) var ( - textUnarmshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - jsonUnmarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() ) func Marshal(v any) ([]byte, error) { @@ -32,11 +34,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { } t := v.Type() switch { - case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonUnmarshaler): + case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler): v = v.Elem() fallthrough - case t.Implements(jsonUnmarshaler): + case t.Implements(jsonMarshaler): enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert data, err := enc.MarshalJSON() if err != nil { @@ -45,11 +47,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { w.Write(data) return nil - case t.Kind() == reflect.Ptr && t.Elem().Implements(textUnarmshaler): + case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler): v = v.Elem() fallthrough - case t.Implements(textUnarmshaler): + case t.Implements(textMarshaler): enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert data, err := enc.MarshalText() if err != nil { @@ -198,3 +200,185 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('"') return nil } + +func Unmarshal(data []byte, v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("unmarshal expects a non-nil pointer") + } + + d := json.NewDecoder(bytes.NewReader(data)) + return decodeValue(d, rv.Elem()) +} + +func decodeValue(d *json.Decoder, v reflect.Value) error { + if !v.CanSet() { + return fmt.Errorf("cannot set value") + } + + t := v.Type() + switch { + case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler): + v = v.Addr() + fallthrough + + case t.Implements(jsonUnmarshaler): + if v.IsNil() { + v.Set(reflect.New(t.Elem())) + } + o := v.Interface() + return d.Decode(&o) + + case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler): + v = v.Addr() + fallthrough + + case t.Implements(textUnmarshaler): + if v.IsNil() { + v.Set(reflect.New(t.Elem())) + } + dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert + var s string + if err := d.Decode(&s); err != nil { + return err + } + return dec.UnmarshalText([]byte(s)) + } + + switch v.Kind() { + case reflect.Struct: + return decodeStruct(d, v) + + case reflect.Ptr: + if token, err := d.Token(); err != nil { + return err + } else if token == nil { + v.Set(reflect.Zero(v.Type())) + return nil + } + return decodeValue(d, v.Elem()) + + case reflect.Slice: + if v.Type().Elem().Kind() == reflect.Uint8 { + return decodeBytes(d, v) + } + return decodeSlice(d, v) + + case reflect.Map: + return decodeMap(d, v) + + case reflect.Interface: + if v.Type().NumMethod() != 0 { + return fmt.Errorf("the only interface type supported is any, not %s", v.Type()) + } + fallthrough + + default: + return d.Decode(v.Addr().Interface()) + } +} + +func decodeStruct(d *json.Decoder, v reflect.Value) error { + if err := expectDelim(d, '{'); err != nil { + return err + } + + for d.More() { + token, err := d.Token() + if err != nil { + return err + } + key, ok := token.(string) + if !ok { + return fmt.Errorf("expected string key, got %T", token) + } + + field := v.FieldByNameFunc(func(s string) bool { + return strcase.ToLowerCamel(s) == key + }) + if !field.IsValid() { + return fmt.Errorf("no field corresponding to key %s", key) + } + fieldTypeStr := field.Type().String() + switch { + case fieldTypeStr == "*Unit" || fieldTypeStr == "Unit": + if fieldTypeStr == "*Unit" && field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + default: + if err := decodeValue(d, field); err != nil { + return err + } + } + } + + // consume the closing delimiter of the object + _, err := d.Token() + return err +} + +func decodeBytes(d *json.Decoder, v reflect.Value) error { + var b []byte + if err := d.Decode(&b); err != nil { + return err + } + v.SetBytes(b) + return nil +} + +func decodeSlice(d *json.Decoder, v reflect.Value) error { + if err := expectDelim(d, '['); err != nil { + return err + } + + for d.More() { + newElem := reflect.New(v.Type().Elem()).Elem() + if err := decodeValue(d, newElem); err != nil { + return err + } + v.Set(reflect.Append(v, newElem)) + } + // consume the closing delimiter of the slice + _, err := d.Token() + return err +} + +func decodeMap(d *json.Decoder, v reflect.Value) error { + if err := expectDelim(d, '{'); err != nil { + return err + } + + if v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + + valType := v.Type().Elem() + for d.More() { + key, err := d.Token() + if err != nil { + return err + } + + newElem := reflect.New(valType).Elem() + if err := decodeValue(d, newElem); err != nil { + return err + } + + v.SetMapIndex(reflect.ValueOf(key), newElem) + } + // consume the closing delimiter of the map + _, err := d.Token() + return err +} + +func expectDelim(d *json.Decoder, expected json.Delim) error { + token, err := d.Token() + if err != nil { + return err + } + delim, ok := token.(json.Delim) + if !ok || delim != expected { + return fmt.Errorf("expected delimiter %q, got %q", expected, token) + } + return nil +} diff --git a/go-runtime/encoding/encoding_test.go b/go-runtime/encoding/encoding_test.go index 860a7e479a..18053f466f 100644 --- a/go-runtime/encoding/encoding_test.go +++ b/go-runtime/encoding/encoding_test.go @@ -1,6 +1,7 @@ package encoding_test import ( + "reflect" "testing" . "github.com/TBD54566975/ftl/go-runtime/encoding" @@ -31,6 +32,11 @@ func TestMarshal(t *testing.T) { {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}, expected: `{"option":42}`}, {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}, expected: `{"option":42}`}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`}, + {name: "Unit", input: ftl.Unit{}, expected: `{}`}, + {name: "UnitField", input: struct { + String string + Unit ftl.Unit + }{String: "something", Unit: ftl.Unit{}}, expected: `{"string":"something"}`}, } for _, tt := range tests { @@ -41,3 +47,96 @@ func TestMarshal(t *testing.T) { }) } } + +func TestUnmarshal(t *testing.T) { + type inner struct { + FooBar string + } + somePtr := ftl.Some(42) + tests := []struct { + name string + input string + expected any + err string + }{ + {name: "FieldRenaming", input: `{"fooBar":""}`, expected: struct{ FooBar string }{""}}, + {name: "String", input: `{"string":"foo"}`, expected: struct{ String string }{"foo"}}, + {name: "Int", input: `{"int":42}`, expected: struct{ Int int }{42}}, + {name: "Float", input: `{"float":42.42}`, expected: struct{ Float float64 }{42.42}}, + {name: "Bool", input: `{"bool":true}`, expected: struct{ Bool bool }{true}}, + {name: "Nil", input: `{"nil":null}`, expected: struct{ Nil *int }{nil}}, + {name: "Slice", input: `{"slice":[1,2,3]}`, expected: struct{ Slice []int }{[]int{1, 2, 3}}}, + {name: "SliceOfStrings", input: `{"slice":["hello","world"]}`, expected: struct{ Slice []string }{[]string{"hello", "world"}}}, + {name: "Map", input: `{"map":{"foo":42}}`, expected: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, + {name: "Option", input: `{"option":42}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, + {name: "OptionPtr", input: `{"option":42}`, expected: struct{ Option *ftl.Option[int] }{&somePtr}}, + {name: "OptionStruct", input: `{"option":{"fooBar":"foo"}}`, expected: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, + {name: "Unit", input: `{}`, expected: ftl.Unit{}}, + {name: "UnitField", input: `{"string":"something"}`, expected: struct { + String string + Unit ftl.Unit + }{String: "something", Unit: ftl.Unit{}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eType := reflect.TypeOf(tt.expected) + if eType.Kind() == reflect.Ptr { + eType = eType.Elem() + } + o := reflect.New(eType) + err := Unmarshal([]byte(tt.input), o.Interface()) + assert.EqualError(t, err, tt.err) + assert.Equal(t, tt.expected, o.Elem().Interface()) + }) + } +} + +func TestRoundTrip(t *testing.T) { + type inner struct { + FooBar string + } + somePtr := ftl.Some(42) + tests := []struct { + name string + input any + }{ + {name: "FieldRenaming", input: struct{ FooBar string }{""}}, + {name: "String", input: struct{ String string }{"foo"}}, + {name: "Int", input: struct{ Int int }{42}}, + {name: "Float", input: struct{ Float float64 }{42.42}}, + {name: "Bool", input: struct{ Bool bool }{true}}, + {name: "Nil", input: struct{ Nil *int }{nil}}, + {name: "Slice", input: struct{ Slice []int }{[]int{1, 2, 3}}}, + {name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}}, + {name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, + {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, + {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}}, + {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, + {name: "Unit", input: ftl.Unit{}}, + {name: "UnitField", input: struct { + String string + Unit ftl.Unit + }{String: "something", Unit: ftl.Unit{}}}, + {name: "Aliased", input: struct { + TokenID string `json:"token_id"` + }{"123"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + marshaled, err := Marshal(tt.input) + assert.NoError(t, err) + + eType := reflect.TypeOf(tt.input) + if eType.Kind() == reflect.Ptr { + eType = eType.Elem() + } + o := reflect.New(eType) + err = Unmarshal(marshaled, o.Interface()) + assert.NoError(t, err) + + assert.Equal(t, tt.input, o.Elem().Interface()) + }) + } +} diff --git a/go-runtime/ftl/call.go b/go-runtime/ftl/call.go index 5ae3c61999..0aa59a9b5c 100644 --- a/go-runtime/ftl/call.go +++ b/go-runtime/ftl/call.go @@ -2,7 +2,6 @@ package ftl import ( "context" - "encoding/json" "fmt" "reflect" "runtime" @@ -34,7 +33,7 @@ func Call[Req, Resp any](ctx context.Context, verb Verb[Req, Resp], req Req) (re return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message) case *ftlv1.CallResponse_Body: - err = json.Unmarshal(cresp.Body, &resp) + err = encoding.Unmarshal(cresp.Body, &resp) if err != nil { return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err) } diff --git a/go-runtime/ftl/option.go b/go-runtime/ftl/option.go index 99e39fbefe..73e4b5bee8 100644 --- a/go-runtime/ftl/option.go +++ b/go-runtime/ftl/option.go @@ -179,7 +179,7 @@ func (o *Option[T]) UnmarshalJSON(data []byte) error { o.ok = false return nil } - if err := json.Unmarshal(data, &o.value); err != nil { + if err := ftlencoding.Unmarshal(data, &o.value); err != nil { return err } o.ok = true diff --git a/go-runtime/server/server.go b/go-runtime/server/server.go index 9809177800..a2117f19dc 100644 --- a/go-runtime/server/server.go +++ b/go-runtime/server/server.go @@ -2,7 +2,6 @@ package server import ( "context" - "encoding/json" "fmt" "net/url" "runtime/debug" @@ -56,7 +55,7 @@ func Handle[Req, Resp any](verb func(ctx context.Context, req Req) (Resp, error) fn: func(ctx context.Context, reqdata []byte) ([]byte, error) { // Decode request. var req Req - err := json.Unmarshal(reqdata, &req) + err := encoding.Unmarshal(reqdata, &req) if err != nil { return nil, fmt.Errorf("invalid request to verb %s: %w", ref, err) } diff --git a/integration/integration_test.go b/integration/integration_test.go index 7621c94be6..4b3ed26140 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -31,6 +31,7 @@ import ( ftlv1 "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1" "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/ftlv1connect" schemapb "github.com/TBD54566975/ftl/backend/protos/xyz/block/ftl/v1/schema" + "github.com/TBD54566975/ftl/go-runtime/encoding" "github.com/TBD54566975/ftl/internal/exec" "github.com/TBD54566975/ftl/internal/log" "github.com/TBD54566975/ftl/internal/rpc" @@ -386,7 +387,7 @@ type obj map[string]any func call[Resp any](module, verb string, req obj, onResponse func(t testing.TB, resp Resp)) assertion { return func(t testing.TB, ic itContext) error { - jreq, err := json.Marshal(req) + jreq, err := encoding.Marshal(req) assert.NoError(t, err) cresp, err := ic.verbs.Call(ic, connect.NewRequest(&ftlv1.CallRequest{ @@ -402,7 +403,7 @@ func call[Resp any](module, verb string, req obj, onResponse func(t testing.TB, } var resp Resp - err = json.Unmarshal(cresp.Msg.GetBody(), &resp) + err = encoding.Unmarshal(cresp.Msg.GetBody(), &resp) assert.NoError(t, err) onResponse(t, resp)