diff --git a/go-runtime/encoding/encoding.go b/go-runtime/encoding/encoding.go index 6b7a075594..819a68c212 100644 --- a/go-runtime/encoding/encoding.go +++ b/go-runtime/encoding/encoding.go @@ -4,18 +4,22 @@ package encoding import ( "bytes" + "encoding" "encoding/base64" "encoding/json" "fmt" - "io" "reflect" - "strings" - "time" - "unicode" "github.com/TBD54566975/ftl/backend/schema/strcase" ) +var ( + textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() +) + func Marshal(v any) ([]byte, error) { w := &bytes.Buffer{} err := encodeValue(reflect.ValueOf(v), w) @@ -27,29 +31,37 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { w.WriteString("null") return nil } - t := v.Type() - - // Special-cased types switch { - case t == reflect.TypeFor[time.Time](): - data, err := json.Marshal(v.Interface().(time.Time)) + case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler): + v = v.Elem() + fallthrough + + case t.Implements(jsonMarshaler): + enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert + data, err := enc.MarshalJSON() if err != nil { return err } w.Write(data) return nil - case t == reflect.TypeFor[json.RawMessage](): - data, err := json.Marshal(v.Interface().(json.RawMessage)) + case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler): + v = v.Elem() + fallthrough + + case t.Implements(textMarshaler): + enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert + data, err := enc.MarshalText() + if err != nil { + return err + } + data, err = json.Marshal(string(data)) if err != nil { return err } w.Write(data) return nil - - case isOption(v.Type()): - return encodeOption(v, w) } switch v.Kind() { @@ -95,24 +107,6 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error { } } -var ftlOptionTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Option" - -func isOption(t reflect.Type) bool { - return strings.HasPrefix(t.PkgPath()+"."+t.Name(), ftlOptionTypePath) -} - -func encodeOption(v reflect.Value, w *bytes.Buffer) error { - if v.NumField() != 2 { - return fmt.Errorf("value cannot have type ftl.Option since it has %d fields rather than 2: %v", v.NumField(), v) - } - optionOk := v.Field(1).Bool() - if !optionOk { - w.WriteString("null") - return nil - } - return encodeValue(v.Field(0), w) -} - func encodeStruct(v reflect.Value, w *bytes.Buffer) error { w.WriteRune('{') afterFirst := false @@ -219,18 +213,36 @@ func Unmarshal(data []byte, v any) error { func decodeValue(d *json.Decoder, v reflect.Value) error { if !v.CanSet() { - allBytes, _ := io.ReadAll(d.Buffered()) - return fmt.Errorf("cannot set value: %v", string(allBytes)) + return fmt.Errorf("cannot set value") } t := v.Type() - - // Special-case types switch { - case t == reflect.TypeFor[time.Time](): - return d.Decode(v.Addr().Interface()) - case isOption(v.Type()): - return decodeOption(d, v) + case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler): + v = v.Addr() + fallthrough + + case t.Implements(jsonUnmarshaler): + if v.IsNil() { + v.Set(reflect.New(t.Elem())) + } + o := v.Interface() + return d.Decode(&o) + + case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler): + v = v.Addr() + fallthrough + + case t.Implements(textUnmarshaler): + if v.IsNil() { + v.Set(reflect.New(t.Elem())) + } + dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert + var s string + if err := d.Decode(&s); err != nil { + return err + } + return dec.UnmarshalText([]byte(s)) } switch v.Kind() { @@ -238,15 +250,13 @@ func decodeValue(d *json.Decoder, v reflect.Value) error { return decodeStruct(d, v) case reflect.Ptr: - return handleIfNextTokenIsNull(d, func(d *json.Decoder) error { + if token, err := d.Token(); err != nil { + return err + } else if token == nil { v.Set(reflect.Zero(v.Type())) return nil - }, func(d *json.Decoder) error { - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - return decodeValue(d, v.Elem()) - }) + } + return decodeValue(d, v.Elem()) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { @@ -268,63 +278,6 @@ func decodeValue(d *json.Decoder, v reflect.Value) error { } } -func handleIfNextTokenIsNull(d *json.Decoder, ifNullFn func(*json.Decoder) error, elseFn func(*json.Decoder) error) error { - isNull, err := isNextTokenNull(d) - if err != nil { - return err - } - if isNull { - err = ifNullFn(d) - if err != nil { - return err - } - // Consume the null token - _, err := d.Token() - if err != nil { - return err - } - return nil - } - return elseFn(d) -} - -// isNextTokenNull implements a cheap/dirty version of `Peek()`, which json.Decoder does -// not support. -func isNextTokenNull(d *json.Decoder) (bool, error) { - s, err := io.ReadAll(d.Buffered()) - if err != nil { - return false, err - } - if len(s) == 0 { - return false, fmt.Errorf("cannot check emptystring for token \"null\"") - } - if s[0] != ':' { - return false, fmt.Errorf("cannot check emptystring for token \"null\"") - } - i := 1 - for len(s) > i && unicode.IsSpace(rune(s[i])) { - i++ - } - if len(s) < i+4 { - return false, nil - } - return string(s[i:i+4]) == "null", nil -} - -func decodeOption(d *json.Decoder, v reflect.Value) error { - return handleIfNextTokenIsNull(d, func(d *json.Decoder) error { - v.FieldByName("Okay").SetBool(false) - return nil - }, func(d *json.Decoder) error { - err := decodeValue(d, v.FieldByName("Val")) - if err != nil { - return err - } - v.FieldByName("Okay").SetBool(true) - return nil - }) -} - func decodeStruct(d *json.Decoder, v reflect.Value) error { if err := expectDelim(d, '{'); err != nil { return err diff --git a/go-runtime/encoding/encoding_test.go b/go-runtime/encoding/encoding_test.go index 346ea921b1..4fc6105fec 100644 --- a/go-runtime/encoding/encoding_test.go +++ b/go-runtime/encoding/encoding_test.go @@ -3,7 +3,6 @@ package encoding_test import ( "reflect" "testing" - "time" "github.com/alecthomas/assert/v2" @@ -32,8 +31,6 @@ func TestMarshal(t *testing.T) { {name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}, expected: `{"slice":["hello","world"]}`}, {name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}, expected: `{"map":{"foo":42}}`}, {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}, expected: `{"option":42}`}, - {name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}, expected: `{"option":null}`}, - {name: "OptionZero", input: struct{ Option ftl.Option[int] }{ftl.Some(0)}, expected: `{"option":0}`}, {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}, expected: `{"option":42}`}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`}, {name: "Unit", input: ftl.Unit{}, expected: `{}`}, @@ -72,9 +69,6 @@ func TestUnmarshal(t *testing.T) { {name: "Slice", input: `{"slice":[1,2,3]}`, expected: struct{ Slice []int }{[]int{1, 2, 3}}}, {name: "SliceOfStrings", input: `{"slice":["hello","world"]}`, expected: struct{ Slice []string }{[]string{"hello", "world"}}}, {name: "Map", input: `{"map":{"foo":42}}`, expected: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, - {name: "OptionNull", input: `{"option":null}`, expected: struct{ Option ftl.Option[int] }{ftl.None[int]()}}, - {name: "OptionNullWhitespace", input: `{"option": null}`, expected: struct{ Option ftl.Option[int] }{ftl.None[int]()}}, - {name: "OptionZero", input: `{"option":0}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(0)}}, {name: "Option", input: `{"option":42}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, {name: "OptionPtr", input: `{"option":42}`, expected: struct{ Option *ftl.Option[int] }{&somePtr}}, {name: "OptionStruct", input: `{"option":{"fooBar":"foo"}}`, expected: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, @@ -83,12 +77,6 @@ func TestUnmarshal(t *testing.T) { String string Unit ftl.Unit }{String: "something", Unit: ftl.Unit{}}}, - // Whitespaces after each `:` and multiple fields to test handling of the - // two potential terminal delimiters: `}` and `,` - {name: "ComplexFormatting", input: `{"option": null, "bool": true}`, expected: struct { - Option ftl.Option[int] - Bool bool - }{ftl.None[int](), true}}, } for _, tt := range tests { @@ -123,9 +111,7 @@ func TestRoundTrip(t *testing.T) { {name: "Slice", input: struct{ Slice []int }{[]int{1, 2, 3}}}, {name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}}, {name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}}, - {name: "Time", input: struct{ Time time.Time }{time.Date(2009, time.November, 29, 21, 33, 0, 0, time.UTC)}}, {name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}}, - {name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}}, {name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}}, {name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}}, {name: "Unit", input: ftl.Unit{}}, diff --git a/go-runtime/ftl/option.go b/go-runtime/ftl/option.go index dfba72592f..73e4b5bee8 100644 --- a/go-runtime/ftl/option.go +++ b/go-runtime/ftl/option.go @@ -5,20 +5,25 @@ import ( "database/sql" "database/sql/driver" "encoding" + "encoding/json" "fmt" "reflect" + + ftlencoding "github.com/TBD54566975/ftl/go-runtime/encoding" ) // Stdlib interfaces types implement. type stdlib interface { fmt.Stringer fmt.GoStringer + json.Marshaler + json.Unmarshaler } // An Option type is a type that can contain a value or nothing. type Option[T any] struct { - Val T - Okay bool + value T + ok bool } var _ driver.Valuer = (*Option[int])(nil) @@ -26,69 +31,69 @@ var _ sql.Scanner = (*Option[int])(nil) func (o *Option[T]) Scan(src any) error { if src == nil { - o.Okay = false + o.ok = false var zero T - o.Val = zero + o.value = zero return nil } if value, ok := src.(T); ok { - o.Val = value - o.Okay = true + o.value = value + o.ok = true return nil } var value T switch scan := any(&value).(type) { case sql.Scanner: if err := scan.Scan(src); err != nil { - return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.Val, err) + return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.value, err) } - o.Val = value - o.Okay = true + o.value = value + o.ok = true case encoding.TextUnmarshaler: switch src := src.(type) { case string: if err := scan.UnmarshalText([]byte(src)); err != nil { - return fmt.Errorf("unmarshal from %T into Option[%T] failed: %w", src, o.Val, err) + return fmt.Errorf("unmarshal from %T into Option[%T] failed: %w", src, o.value, err) } - o.Val = value - o.Okay = true + o.value = value + o.ok = true case []byte: if err := scan.UnmarshalText(src); err != nil { - return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.Val, err) + return fmt.Errorf("cannot scan %T into Option[%T]: %w", src, o.value, err) } - o.Val = value - o.Okay = true + o.value = value + o.ok = true default: - return fmt.Errorf("cannot unmarshal %T into Option[%T]", src, o.Val) + return fmt.Errorf("cannot unmarshal %T into Option[%T]", src, o.value) } default: - return fmt.Errorf("no decoding mechanism found for %T into Option[%T]", src, o.Val) + return fmt.Errorf("no decoding mechanism found for %T into Option[%T]", src, o.value) } return nil } func (o Option[T]) Value() (driver.Value, error) { - if !o.Okay { + if !o.ok { return nil, nil } - switch value := any(o.Val).(type) { + switch value := any(o.value).(type) { case driver.Valuer: return value.Value() case encoding.TextMarshaler: return value.MarshalText() } - return o.Val, nil + return o.value, nil } var _ stdlib = (*Option[int])(nil) // Some returns an Option that contains a value. -func Some[T any](value T) Option[T] { return Option[T]{Val: value, Okay: true} } +func Some[T any](value T) Option[T] { return Option[T]{value: value, ok: true} } // None returns an Option that contains nothing. func None[T any]() Option[T] { return Option[T]{} } @@ -132,46 +137,65 @@ func Zero[T any](value T) Option[T] { // Ptr returns a pointer to the value if the Option contains a value, otherwise nil. func (o Option[T]) Ptr() *T { - if o.Okay { - return &o.Val + if o.ok { + return &o.value } return nil } // Ok returns true if the Option contains a value. -func (o Option[T]) Ok() bool { return o.Okay } +func (o Option[T]) Ok() bool { return o.ok } // MustGet returns the value. It panics if the Option contains nothing. func (o Option[T]) MustGet() T { - if !o.Okay { + if !o.ok { var t T panic(fmt.Sprintf("Option[%T] contains nothing", t)) } - return o.Val + return o.value } // Get returns the value and a boolean indicating if the Option contains a value. -func (o Option[T]) Get() (T, bool) { return o.Val, o.Okay } +func (o Option[T]) Get() (T, bool) { return o.value, o.ok } // Default returns the Option value if it is present, otherwise it returns the // value passed. func (o Option[T]) Default(value T) T { - if o.Okay { - return o.Val + if o.ok { + return o.value } return value } +func (o Option[T]) MarshalJSON() ([]byte, error) { + if o.ok { + return ftlencoding.Marshal(o.value) + } + return []byte("null"), nil +} + +func (o *Option[T]) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + o.ok = false + return nil + } + if err := ftlencoding.Unmarshal(data, &o.value); err != nil { + return err + } + o.ok = true + return nil +} + func (o Option[T]) String() string { - if o.Okay { - return fmt.Sprintf("%v", o.Val) + if o.ok { + return fmt.Sprintf("%v", o.value) } return "None" } func (o Option[T]) GoString() string { - if o.Okay { - return fmt.Sprintf("Some[%T](%#v)", o.Val, o.Val) + if o.ok { + return fmt.Sprintf("Some[%T](%#v)", o.value, o.value) } - return fmt.Sprintf("None[%T]()", o.Val) + return fmt.Sprintf("None[%T]()", o.value) } diff --git a/go-runtime/ftl/option_test.go b/go-runtime/ftl/option_test.go index d07eb04279..13b104fc1f 100644 --- a/go-runtime/ftl/option_test.go +++ b/go-runtime/ftl/option_test.go @@ -2,6 +2,7 @@ package ftl import ( "database/sql" + "encoding/json" "testing" "github.com/alecthomas/assert/v2" @@ -19,6 +20,27 @@ func TestOptionGet(t *testing.T) { assert.False(t, ok) } +func TestOptionMarshalJSON(t *testing.T) { + o := Some(1) + b, err := o.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, "1", string(b)) + + o = None[int]() + b, err = o.MarshalJSON() + assert.NoError(t, err) + assert.Equal(t, "null", string(b)) +} + +func TestOptionUnmarshalJSON(t *testing.T) { + o := Option[int]{} + err := json.Unmarshal([]byte("1"), &o) + assert.NoError(t, err) + b, ok := o.Get() + assert.True(t, ok) + assert.Equal(t, 1, b) +} + func TestOptionString(t *testing.T) { o := Some(1) assert.Equal(t, "1", o.String())