From 35f90ee432a530a9846e1c34bf2fae798d2d0059 Mon Sep 17 00:00:00 2001
From: Reda Laanait
Date: Thu, 21 Dec 2023 22:15:13 +0100
Subject: [PATCH] feat: move generic decoding to codec level (#336)

---
 codec_dynamic.go               |   4 +-
 codec_generic.go               | 154 ++++++++++++++++++++++
 codec_generic_internal_test.go | 225 +++++++++++++++++++++++++++++++++
 codec_union.go                 |   2 +-
 4 files changed, 382 insertions(+), 3 deletions(-)
 create mode 100644 codec_generic.go
 create mode 100644 codec_generic_internal_test.go

diff --git a/codec_dynamic.go b/codec_dynamic.go
index 631b0d53..4b6271d9 100644
--- a/codec_dynamic.go
+++ b/codec_dynamic.go
@@ -15,13 +15,13 @@ func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
 	pObj := (*any)(ptr)
 	obj := *pObj
 	if obj == nil {
-		*pObj = r.ReadNext(d.schema)
+		*pObj = genericDecode(d.schema, r)
 		return
 	}
 
 	typ := reflect2.TypeOf(obj)
 	if typ.Kind() != reflect.Ptr {
-		*pObj = r.ReadNext(d.schema)
+		*pObj = genericDecode(d.schema, r)
 		return
 	}
 
diff --git a/codec_generic.go b/codec_generic.go
new file mode 100644
index 00000000..4265288c
--- /dev/null
+++ b/codec_generic.go
@@ -0,0 +1,154 @@
+package avro
+
+import (
+	"fmt"
+	"math/big"
+	"time"
+	"unsafe"
+
+	"github.com/modern-go/reflect2"
+)
+
+// genericDecode reads a single value described by schema from r and returns
+// it as a generic Go value. The receiver comes from genericReceiver and is
+// filled by the decoder that decoderOfType selects for it.
+//
+// It returns nil when no receiver exists for the schema (the failure is
+// reported on r), when r.Error is non-nil after decoding, or when the decoded
+// value is nil.
+func genericDecode(schema Schema, r *Reader) any {
+	rPtr, rTyp, err := genericReceiver(schema)
+	if err != nil {
+		r.ReportError("Read", err.Error())
+		return nil
+	}
+	decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r)
+	if r.Error != nil {
+		return nil
+	}
+	obj := rTyp.UnsafeIndirect(rPtr)
+	if reflect2.IsNil(obj) {
+		return nil
+	}
+
+	// A ratType receiver holds a big.Rat value; return a pointer to a copy
+	// of it instead of the value itself.
+	// NOTE(review): the original remark here was "seems generic reader is
+	// not compatible with codec" - confirm why only this type needs it.
+	if rTyp.Type1() == ratType {
+		dec := obj.(big.Rat)
+		return &dec
+	}
+
+	return obj
+}
+
+// genericReceiver selects the Go type that holds a generically decoded value
+// of schema. It returns a pointer to a new value of that type together with
+// the value's reflect2 type.
+//
+// The schema type picks the receiver. For Int, Long, Bytes and Fixed a
+// recognised logical type replaces the plain receiver, and a Ref is resolved
+// by recursing into the schema it points at. Any schema type not handled
+// below yields an error naming the schema type and, if set, its logical type.
+func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) {
+	var ls LogicalSchema
+	lts, ok := schema.(LogicalTypeSchema)
+	if ok {
+		ls = lts.Logical()
+	}
+
+	switch schema.Type() {
+	case Boolean:
+		var v bool
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Int:
+		if ls != nil {
+			switch ls.Type() {
+			case Date:
+				var v time.Time
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+
+			case TimeMillis:
+				var v time.Duration
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+			}
+		}
+		var v int
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Long:
+		if ls != nil {
+			switch ls.Type() {
+			case TimeMicros:
+				var v time.Duration
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+
+			case TimestampMillis:
+				var v time.Time
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+
+			case TimestampMicros:
+				var v time.Time
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+			}
+		}
+		var v int64
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Float:
+		var v float32
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Double:
+		var v float64
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case String:
+		var v string
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Bytes:
+		if ls != nil && ls.Type() == Decimal {
+			var v *big.Rat
+			return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+		}
+		var v []byte
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Record:
+		var v map[string]any
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Ref:
+		return genericReceiver(schema.(*RefSchema).Schema())
+	case Enum:
+		var v string
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Array:
+		v := make([]any, 0)
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Map:
+		var v map[string]any
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Union:
+		var v map[string]any
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	case Fixed:
+		fixed := schema.(*FixedSchema)
+		ls := fixed.Logical()
+		if ls != nil {
+			switch ls.Type() {
+			case Duration:
+				var v LogicalDuration
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+			case Decimal:
+				var v big.Rat
+				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+			}
+		}
+		v := byteSliceToArray(make([]byte, fixed.Size()), fixed.Size())
+		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
+	default:
+		// Build the name only on the error path so the successful cases do
+		// not pay for the string concatenation.
+		name := string(schema.Type())
+		if ls != nil {
+			name += "." + string(ls.Type())
+		}
+		return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name)
+	}
+}
diff --git a/codec_generic_internal_test.go b/codec_generic_internal_test.go
new file mode 100644
index 00000000..b04f1984
--- /dev/null
+++ b/codec_generic_internal_test.go
@@ -0,0 +1,225 @@
+package avro
+
+import (
+	"bytes"
+	"math/big"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestGenericDecode runs genericDecode over one encoded value per schema
+// case below, including inputs that must leave an error on the reader, and
+// compares both the reader error and the decoded value.
+func TestGenericDecode(t *testing.T) {
+	tests := []struct {
+		name    string
+		data    []byte
+		schema  string
+		want    any
+		wantErr require.ErrorAssertionFunc
+	}{
+
+		{
+			name:    "Bool",
+			data:    []byte{0x01},
+			schema:  "boolean",
+			want:    true,
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Int",
+			data:    []byte{0x36},
+			schema:  "int",
+			want:    27,
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Int Date",
+			data:    []byte{0xAE, 0x9D, 0x02},
+			schema:  `{"type":"int","logicalType":"date"}`,
+			want:    time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Int Time-Millis",
+			data:    []byte{0xAA, 0xB4, 0xDE, 0x75},
+			schema:  `{"type":"int","logicalType":"time-millis"}`,
+			want:    123456789 * time.Millisecond,
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Long",
+			data:    []byte{0x36},
+			schema:  "long",
+			want:    int64(27),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Long Time-Micros",
+			data:    []byte{0x86, 0xEA, 0xC8, 0xE9, 0x97, 0x07},
+			schema:  `{"type":"long","logicalType":"time-micros"}`,
+			want:    123456789123 * time.Microsecond,
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Long Timestamp-Millis",
+			data:    []byte{0x90, 0xB2, 0xAE, 0xC3, 0xEC, 0x5B},
+			schema:  `{"type":"long","logicalType":"timestamp-millis"}`,
+			want:    time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Long Timestamp-Micros",
+			data:    []byte{0x80, 0xCD, 0xB7, 0xA2, 0xEE, 0xC7, 0xCD, 0x05},
+			schema:  `{"type":"long","logicalType":"timestamp-micros"}`,
+			want:    time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Float",
+			data:    []byte{0x33, 0x33, 0x93, 0x3F},
+			schema:  "float",
+			want:    float32(1.15),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Double",
+			data:    []byte{0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xF2, 0x3F},
+			schema:  "double",
+			want:    float64(1.15),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "String",
+			data:    []byte{0x06, 0x66, 0x6F, 0x6F},
+			schema:  "string",
+			want:    "foo",
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Bytes",
+			data:    []byte{0x08, 0xEC, 0xAB, 0x44, 0x00},
+			schema:  "bytes",
+			want:    []byte{0xEC, 0xAB, 0x44, 0x00},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Bytes Decimal",
+			data:    []byte{0x6, 0x00, 0x87, 0x78},
+			schema:  `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`,
+			want:    big.NewRat(1734, 5),
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Record",
+			data:    []byte{0x36, 0x06, 0x66, 0x6f, 0x6f},
+			schema:  `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`,
+			want:    map[string]any{"a": int64(27), "b": "foo"},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Ref",
+			data:    []byte{0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f},
+			schema:  `{"type":"record","name":"parent","fields":[{"name":"a","type":{"type":"record","name":"test","fields":[{"name":"a","type":"long"},{"name":"b","type":"string"}]}},{"name":"b","type":"test"}]}`,
+			want:    map[string]any{"a": map[string]any{"a": int64(27), "b": "foo"}, "b": map[string]any{"a": int64(27), "b": "foo"}},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Array",
+			data:    []byte{0x04, 0x36, 0x38, 0x0},
+			schema:  `{"type":"array", "items": "int"}`,
+			want:    []any{27, 28},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Map",
+			data:    []byte{0x02, 0x06, 0x66, 0x6F, 0x6F, 0x06, 0x66, 0x6F, 0x6F, 0x00},
+			schema:  `{"type":"map", "values": "string"}`,
+			want:    map[string]any{"foo": "foo"},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Enum",
+			data:    []byte{0x02},
+			schema:  `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`,
+			want:    "bar",
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Enum Invalid Symbol",
+			data:    []byte{0x04},
+			schema:  `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`,
+			want:    nil,
+			wantErr: require.Error,
+		},
+		{
+			name:    "Union",
+			data:    []byte{0x02, 0x06, 0x66, 0x6F, 0x6F},
+			schema:  `["null", "string"]`,
+			want:    map[string]any{"string": "foo"},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Union Nil",
+			data:    []byte{0x00},
+			schema:  `["null", "string"]`,
+			want:    nil,
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Union Named",
+			data:    []byte{0x02, 0x02},
+			schema:  `["null", {"type":"enum", "name": "test", "symbols": ["foo", "bar"]}]`,
+			want:    map[string]any{"test": "bar"},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Union Invalid Schema",
+			data:    []byte{0x04},
+			schema:  `["null", "string"]`,
+			want:    nil,
+			wantErr: require.Error,
+		},
+		{
+			name:    "Fixed",
+			data:    []byte{0x66, 0x6F, 0x6F, 0x66, 0x6F, 0x6F},
+			schema:  `{"type":"fixed", "name": "test", "size": 6}`,
+			want:    [6]byte{'f', 'o', 'o', 'f', 'o', 'o'},
+			wantErr: require.NoError,
+		},
+		{
+			name:    "Fixed Decimal",
+			data:    []byte{0x00, 0x00, 0x00, 0x00, 0x87, 0x78},
+			schema:  `{"type":"fixed", "name": "test", "size": 6,"logicalType":"decimal","precision":4,"scale":2}`,
+			want:    big.NewRat(1734, 5),
+			wantErr: require.NoError,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+		t.Run(test.name, func(t *testing.T) {
+			schema := MustParse(test.schema)
+			r := NewReader(bytes.NewReader(test.data), 10)
+
+			got := genericDecode(schema, r)
+
+			test.wantErr(t, r.Error)
+			assert.Equal(t, test.want, got)
+		})
+	}
+}
+
+// TestGenericDecode_UnsupportedType checks that decoding with a schema type
+// that has no generic receiver leaves an error on the reader.
+func TestGenericDecode_UnsupportedType(t *testing.T) {
+	schema := NewPrimitiveSchema(Type("test"), nil)
+	r := NewReader(bytes.NewReader([]byte{0x01}), 10)
+
+	_ = genericDecode(schema, r)
+
+	assert.Error(t, r.Error)
+}
diff --git a/codec_union.go b/codec_union.go
index d07cbf0a..7f864be6 100644
--- a/codec_union.go
+++ b/codec_union.go
@@ -294,7 +294,7 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
 		// We cannot resolve this, set it to the map type
 		name := schemaTypeName(schema)
 		obj := map[string]any{}
-		obj[name] = r.ReadNext(schema)
+		obj[name] = genericDecode(schema, r)
 
 		*pObj = obj
 		return