Skip to content

Commit

Permalink
chore: implement encoding.Unmarshal (#975)
Browse files Browse the repository at this point in the history
fixes #971
fixes #951
  • Loading branch information
worstell authored Feb 23, 2024
1 parent 71da18f commit 0639b89
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 17 deletions.
6 changes: 2 additions & 4 deletions backend/controller/ingress/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ingress

import (
"bytes"
"encoding/json"
"net/http"
"net/url"
"reflect"
Expand All @@ -17,8 +16,7 @@ import (
)

type AliasRequest struct {
// FIXME: This should be an alias (`json:"alias"`) once encoding.Unmarshal is available.
Aliased string
Aliased string `json:"alias"`
}

type PathParameterRequest struct {
Expand Down Expand Up @@ -184,7 +182,7 @@ func TestBuildRequestBody(t *testing.T) {
assert.NoError(t, err)
actualrv := reflect.New(reflect.TypeOf(test.expected))
actual := actualrv.Interface()
err = json.Unmarshal(requestBody, actual)
err = encoding.Unmarshal(requestBody, actual)
assert.NoError(t, err)
assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty())
})
Expand Down
196 changes: 190 additions & 6 deletions go-runtime/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ import (
)

var (
textUnarmshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
jsonUnmarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
)

func Marshal(v any) ([]byte, error) {
Expand All @@ -32,11 +34,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
}
t := v.Type()
switch {
case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonUnmarshaler):
case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler):
v = v.Elem()
fallthrough

case t.Implements(jsonUnmarshaler):
case t.Implements(jsonMarshaler):
enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert
data, err := enc.MarshalJSON()
if err != nil {
Expand All @@ -45,11 +47,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
w.Write(data)
return nil

case t.Kind() == reflect.Ptr && t.Elem().Implements(textUnarmshaler):
case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler):
v = v.Elem()
fallthrough

case t.Implements(textUnarmshaler):
case t.Implements(textMarshaler):
enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert
data, err := enc.MarshalText()
if err != nil {
Expand Down Expand Up @@ -198,3 +200,185 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('"')
return nil
}

func Unmarshal(data []byte, v any) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("unmarshal expects a non-nil pointer")
}

d := json.NewDecoder(bytes.NewReader(data))
return decodeValue(d, rv.Elem())
}

func decodeValue(d *json.Decoder, v reflect.Value) error {
if !v.CanSet() {
return fmt.Errorf("cannot set value")
}

t := v.Type()
switch {
case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler):
v = v.Addr()
fallthrough

case t.Implements(jsonUnmarshaler):
if v.IsNil() {
v.Set(reflect.New(t.Elem()))
}
o := v.Interface()
return d.Decode(&o)

case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler):
v = v.Addr()
fallthrough

case t.Implements(textUnmarshaler):
if v.IsNil() {
v.Set(reflect.New(t.Elem()))
}
dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert
var s string
if err := d.Decode(&s); err != nil {
return err
}
return dec.UnmarshalText([]byte(s))
}

switch v.Kind() {
case reflect.Struct:
return decodeStruct(d, v)

case reflect.Ptr:
if token, err := d.Token(); err != nil {
return err
} else if token == nil {
v.Set(reflect.Zero(v.Type()))
return nil
}
return decodeValue(d, v.Elem())

case reflect.Slice:
if v.Type().Elem().Kind() == reflect.Uint8 {
return decodeBytes(d, v)
}
return decodeSlice(d, v)

case reflect.Map:
return decodeMap(d, v)

case reflect.Interface:
if v.Type().NumMethod() != 0 {
return fmt.Errorf("the only interface type supported is any, not %s", v.Type())
}
fallthrough

default:
return d.Decode(v.Addr().Interface())
}
}

func decodeStruct(d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '{'); err != nil {
return err
}

for d.More() {
token, err := d.Token()
if err != nil {
return err
}
key, ok := token.(string)
if !ok {
return fmt.Errorf("expected string key, got %T", token)
}

field := v.FieldByNameFunc(func(s string) bool {
return strcase.ToLowerCamel(s) == key
})
if !field.IsValid() {
return fmt.Errorf("no field corresponding to key %s", key)
}
fieldTypeStr := field.Type().String()
switch {
case fieldTypeStr == "*Unit" || fieldTypeStr == "Unit":
if fieldTypeStr == "*Unit" && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
default:
if err := decodeValue(d, field); err != nil {
return err
}
}
}

// consume the closing delimiter of the object
_, err := d.Token()
return err
}

func decodeBytes(d *json.Decoder, v reflect.Value) error {
var b []byte
if err := d.Decode(&b); err != nil {
return err
}
v.SetBytes(b)
return nil
}

func decodeSlice(d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '['); err != nil {
return err
}

for d.More() {
newElem := reflect.New(v.Type().Elem()).Elem()
if err := decodeValue(d, newElem); err != nil {
return err
}
v.Set(reflect.Append(v, newElem))
}
// consume the closing delimiter of the slice
_, err := d.Token()
return err
}

func decodeMap(d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '{'); err != nil {
return err
}

if v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}

valType := v.Type().Elem()
for d.More() {
key, err := d.Token()
if err != nil {
return err
}

newElem := reflect.New(valType).Elem()
if err := decodeValue(d, newElem); err != nil {
return err
}

v.SetMapIndex(reflect.ValueOf(key), newElem)
}
// consume the closing delimiter of the map
_, err := d.Token()
return err
}

func expectDelim(d *json.Decoder, expected json.Delim) error {
token, err := d.Token()
if err != nil {
return err
}
delim, ok := token.(json.Delim)
if !ok || delim != expected {
return fmt.Errorf("expected delimiter %q, got %q", expected, token)
}
return nil
}
99 changes: 99 additions & 0 deletions go-runtime/encoding/encoding_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package encoding_test

import (
"reflect"
"testing"

. "github.com/TBD54566975/ftl/go-runtime/encoding"
Expand Down Expand Up @@ -31,6 +32,11 @@ func TestMarshal(t *testing.T) {
{name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}, expected: `{"option":42}`},
{name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}, expected: `{"option":42}`},
{name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`},
{name: "Unit", input: ftl.Unit{}, expected: `{}`},
{name: "UnitField", input: struct {
String string
Unit ftl.Unit
}{String: "something", Unit: ftl.Unit{}}, expected: `{"string":"something"}`},
}

for _, tt := range tests {
Expand All @@ -41,3 +47,96 @@ func TestMarshal(t *testing.T) {
})
}
}

func TestUnmarshal(t *testing.T) {
type inner struct {
FooBar string
}
somePtr := ftl.Some(42)
tests := []struct {
name string
input string
expected any
err string
}{
{name: "FieldRenaming", input: `{"fooBar":""}`, expected: struct{ FooBar string }{""}},
{name: "String", input: `{"string":"foo"}`, expected: struct{ String string }{"foo"}},
{name: "Int", input: `{"int":42}`, expected: struct{ Int int }{42}},
{name: "Float", input: `{"float":42.42}`, expected: struct{ Float float64 }{42.42}},
{name: "Bool", input: `{"bool":true}`, expected: struct{ Bool bool }{true}},
{name: "Nil", input: `{"nil":null}`, expected: struct{ Nil *int }{nil}},
{name: "Slice", input: `{"slice":[1,2,3]}`, expected: struct{ Slice []int }{[]int{1, 2, 3}}},
{name: "SliceOfStrings", input: `{"slice":["hello","world"]}`, expected: struct{ Slice []string }{[]string{"hello", "world"}}},
{name: "Map", input: `{"map":{"foo":42}}`, expected: struct{ Map map[string]int }{map[string]int{"foo": 42}}},
{name: "Option", input: `{"option":42}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(42)}},
{name: "OptionPtr", input: `{"option":42}`, expected: struct{ Option *ftl.Option[int] }{&somePtr}},
{name: "OptionStruct", input: `{"option":{"fooBar":"foo"}}`, expected: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}},
{name: "Unit", input: `{}`, expected: ftl.Unit{}},
{name: "UnitField", input: `{"string":"something"}`, expected: struct {
String string
Unit ftl.Unit
}{String: "something", Unit: ftl.Unit{}}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
eType := reflect.TypeOf(tt.expected)
if eType.Kind() == reflect.Ptr {
eType = eType.Elem()
}
o := reflect.New(eType)
err := Unmarshal([]byte(tt.input), o.Interface())
assert.EqualError(t, err, tt.err)
assert.Equal(t, tt.expected, o.Elem().Interface())
})
}
}

func TestRoundTrip(t *testing.T) {
type inner struct {
FooBar string
}
somePtr := ftl.Some(42)
tests := []struct {
name string
input any
}{
{name: "FieldRenaming", input: struct{ FooBar string }{""}},
{name: "String", input: struct{ String string }{"foo"}},
{name: "Int", input: struct{ Int int }{42}},
{name: "Float", input: struct{ Float float64 }{42.42}},
{name: "Bool", input: struct{ Bool bool }{true}},
{name: "Nil", input: struct{ Nil *int }{nil}},
{name: "Slice", input: struct{ Slice []int }{[]int{1, 2, 3}}},
{name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}},
{name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}},
{name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}},
{name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}},
{name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}},
{name: "Unit", input: ftl.Unit{}},
{name: "UnitField", input: struct {
String string
Unit ftl.Unit
}{String: "something", Unit: ftl.Unit{}}},
{name: "Aliased", input: struct {
TokenID string `json:"token_id"`
}{"123"}},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
marshaled, err := Marshal(tt.input)
assert.NoError(t, err)

eType := reflect.TypeOf(tt.input)
if eType.Kind() == reflect.Ptr {
eType = eType.Elem()
}
o := reflect.New(eType)
err = Unmarshal(marshaled, o.Interface())
assert.NoError(t, err)

assert.Equal(t, tt.input, o.Elem().Interface())
})
}
}
3 changes: 1 addition & 2 deletions go-runtime/ftl/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ftl

import (
"context"
"encoding/json"
"fmt"
"reflect"
"runtime"
Expand Down Expand Up @@ -34,7 +33,7 @@ func Call[Req, Resp any](ctx context.Context, verb Verb[Req, Resp], req Req) (re
return resp, fmt.Errorf("%s: %s", callee, cresp.Error.Message)

case *ftlv1.CallResponse_Body:
err = json.Unmarshal(cresp.Body, &resp)
err = encoding.Unmarshal(cresp.Body, &resp)
if err != nil {
return resp, fmt.Errorf("%s: failed to decode response: %w", callee, err)
}
Expand Down
2 changes: 1 addition & 1 deletion go-runtime/ftl/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (o *Option[T]) UnmarshalJSON(data []byte) error {
o.ok = false
return nil
}
if err := json.Unmarshal(data, &o.value); err != nil {
if err := ftlencoding.Unmarshal(data, &o.value); err != nil {
return err
}
o.ok = true
Expand Down
Loading

0 comments on commit 0639b89

Please sign in to comment.