Skip to content

Commit

Permalink
chore: implement encoding.Unmarshal
Browse files Browse the repository at this point in the history
fixes #971
  • Loading branch information
worstell committed Feb 23, 2024
1 parent 71da18f commit 6a2bfec
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 17 deletions.
6 changes: 2 additions & 4 deletions backend/controller/ingress/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ingress

import (
"bytes"
"encoding/json"
"net/http"
"net/url"
"reflect"
Expand All @@ -17,8 +16,7 @@ import (
)

type AliasRequest struct {
// FIXME: This should be an alias (`json:"alias"`) once encoding.Unmarshal is available.
Aliased string
Aliased string `json:"alias"`
}

type PathParameterRequest struct {
Expand Down Expand Up @@ -184,7 +182,7 @@ func TestBuildRequestBody(t *testing.T) {
assert.NoError(t, err)
actualrv := reflect.New(reflect.TypeOf(test.expected))
actual := actualrv.Interface()
err = json.Unmarshal(requestBody, actual)
err = encoding.Unmarshal(requestBody, actual)
assert.NoError(t, err)
assert.Equal(t, test.expected, actualrv.Elem().Interface(), assert.OmitEmpty())
})
Expand Down
234 changes: 228 additions & 6 deletions go-runtime/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"

"github.com/TBD54566975/ftl/backend/schema/strcase"
)

var (
textUnarmshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
jsonUnmarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
)

func Marshal(v any) ([]byte, error) {
Expand All @@ -32,11 +35,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
}
t := v.Type()
switch {
case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonUnmarshaler):
case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler):
v = v.Elem()
fallthrough

case t.Implements(jsonUnmarshaler):
case t.Implements(jsonMarshaler):
enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert
data, err := enc.MarshalJSON()
if err != nil {
Expand All @@ -45,11 +48,11 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
w.Write(data)
return nil

case t.Kind() == reflect.Ptr && t.Elem().Implements(textUnarmshaler):
case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler):
v = v.Elem()
fallthrough

case t.Implements(textUnarmshaler):
case t.Implements(textMarshaler):
enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert
data, err := enc.MarshalText()
if err != nil {
Expand Down Expand Up @@ -198,3 +201,222 @@ func encodeString(v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('"')
return nil
}

func Unmarshal(data []byte, v any) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return fmt.Errorf("unmarshal expects a non-nil pointer")
}
return decodeValue(data, rv.Elem())
}

func decodeValue(data []byte, v reflect.Value) error {
if !v.CanSet() {
return fmt.Errorf("cannot set value")
}

t := v.Type()
switch {
case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler):
v = v.Addr()
fallthrough

case t.Implements(jsonUnmarshaler):
if v.IsNil() {
v.Set(reflect.New(t.Elem()))
}
dec := v.Interface().(json.Unmarshaler) //nolint:forcetypeassert
return dec.UnmarshalJSON(data)

case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler):
v = v.Addr()
fallthrough

case t.Implements(textUnmarshaler):
if v.IsNil() {
v.Set(reflect.New(t.Elem()))
}
dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
return dec.UnmarshalText([]byte(s))
}

switch v.Kind() {
case reflect.Struct:
return decodeStruct(data, v)

case reflect.Ptr:
if string(data) == "null" {
v.Set(reflect.Zero(v.Type()))
return nil
}
return decodeValue(data, v.Elem())

case reflect.Slice:
if v.Type().Elem().Kind() == reflect.Uint8 {
return decodeBytes(data, v)
}
return decodeSlice(data, v)

case reflect.Map:
return decodeMap(data, v)

case reflect.String:
return decodeString(data, v)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return decodeInt(data, v)

case reflect.Float32, reflect.Float64:
return decodeFloat(data, v)

case reflect.Bool:
return decodeBool(data, v)

case reflect.Interface:
if v.Type().NumMethod() != 0 {
return fmt.Errorf("the only interface type supported is any, not %s", v.Type())
}

var anyInterface any
if err := json.Unmarshal(data, &anyInterface); err != nil {
return err
}
reflectedValue := reflect.ValueOf(anyInterface)
if !reflectedValue.Type().AssignableTo(v.Type()) {
return fmt.Errorf("cannot assign type %s to any interface", reflectedValue.Type())
}
v.Set(reflectedValue)
return nil

default:
return fmt.Errorf("unsupported type: %s", v.Type())
}
}

func decodeStruct(data []byte, v reflect.Value) error {
var jsonData map[string]json.RawMessage
if err := json.Unmarshal(data, &jsonData); err != nil {
return err
}

for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := v.Type().Field(i)
jsonKey := strcase.ToLowerCamel(fieldType.Name)
if _, ok := jsonData[jsonKey]; ok {
fieldTypeStr := fieldType.Type.String()
switch {
case fieldTypeStr == "*Unit" || fieldTypeStr == "Unit":
if fieldTypeStr == "*Unit" && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}

default:
if err := decodeValue(jsonData[jsonKey], field); err != nil {
return fmt.Errorf("error decoding field '%s': %w", fieldType.Name, err)
}
}
}
}
return nil
}

func decodeString(data []byte, v reflect.Value) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
v.SetString(s)
return nil
}

func decodeInt(data []byte, v reflect.Value) error {
var n int64
n, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return err
}
v.SetInt(n)
return nil
}

func decodeFloat(data []byte, v reflect.Value) error {
var f float64
if err := json.Unmarshal(data, &f); err != nil {
return err
}
v.SetFloat(f)
return nil
}

func decodeBool(data []byte, v reflect.Value) error {
var b bool
if err := json.Unmarshal(data, &b); err != nil {
return err
}
v.SetBool(b)
return nil
}

func decodeBytes(data []byte, v reflect.Value) error {
var b []byte
if err := json.Unmarshal(data, &b); err != nil {
return err
}
v.SetBytes(b)
return nil
}

func decodeSlice(data []byte, v reflect.Value) error {
var intfSlice []any
if err := json.Unmarshal(data, &intfSlice); err != nil {
return err
}

sliceType := v.Type().Elem()
newSlice := reflect.MakeSlice(reflect.SliceOf(sliceType), len(intfSlice), len(intfSlice))

for i, intf := range intfSlice {
elem := newSlice.Index(i)
elemData, err := json.Marshal(intf)
if err != nil {
return err
}
if err = decodeValue(elemData, elem); err != nil {
return err
}
}

v.Set(newSlice)

return nil
}

func decodeMap(data []byte, v reflect.Value) error {
if v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}

var intfMap map[string]any
if err := json.Unmarshal(data, &intfMap); err != nil {
return err
}

for key, intf := range intfMap {
valData, err := json.Marshal(intf)
if err != nil {
return err
}
mapValue := reflect.New(v.Type().Elem())
if err = decodeValue(valData, mapValue.Elem()); err != nil {
return err
}
v.SetMapIndex(reflect.ValueOf(key), mapValue.Elem())
}

return nil
}
Loading

0 comments on commit 6a2bfec

Please sign in to comment.