Skip to content

Commit

Permalink
feat: move generic decoding to codec level (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
redaLaanait authored Dec 21, 2023
1 parent c208c84 commit 35f90ee
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 3 deletions.
4 changes: 2 additions & 2 deletions codec_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
pObj := (*any)(ptr)
obj := *pObj
if obj == nil {
*pObj = r.ReadNext(d.schema)
*pObj = genericDecode(d.schema, r)
return
}

typ := reflect2.TypeOf(obj)
if typ.Kind() != reflect.Ptr {
*pObj = r.ReadNext(d.schema)
*pObj = genericDecode(d.schema, r)
return
}

Expand Down
135 changes: 135 additions & 0 deletions codec_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package avro

import (
"fmt"
"math/big"
"time"
"unsafe"

"github.com/modern-go/reflect2"
)

// genericDecode decodes the next value described by schema from r into a
// generic Go value and returns it.
//
// The receiver storage and its type are chosen by genericReceiver; decoding is
// delegated to the decoder that decoderOfType returns for that type.
//
// It returns nil when genericReceiver fails (the failure is reported on r via
// ReportError), when r.Error is set after decoding, or when reflect2.IsNil
// reports the decoded value as nil.
func genericDecode(schema Schema, r *Reader) any {
	rPtr, rTyp, err := genericReceiver(schema)
	if err != nil {
		r.ReportError("Read", err.Error())
		return nil
	}

	decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r)
	if r.Error != nil {
		return nil
	}

	// A receiver typed as a big.Rat value is handed back to the caller as a
	// *big.Rat. rPtr already addresses the decoded value, so it is returned
	// directly: math/big does not support shallow copies of its values, and
	// going through UnsafeIndirect plus a type assertion would create two.
	if rTyp.Type1() == ratType {
		return (*big.Rat)(rPtr)
	}

	obj := rTyp.UnsafeIndirect(rPtr)
	if reflect2.IsNil(obj) {
		return nil
	}
	return obj
}

// genericReceiver picks the Go value that a generic decode of schema should be
// read into. It returns a pointer to freshly allocated storage for that value
// together with the value's reflect2 type.
//
// The receiver is selected from schema.Type() and, where one is present, the
// logical type:
//
//	boolean                          bool
//	int                              int (date: time.Time, time-millis: time.Duration)
//	long                             int64 (time-micros: time.Duration, timestamp-*: time.Time)
//	float / double                   float32 / float64
//	string, enum                     string
//	bytes                            []byte (decimal: *big.Rat)
//	record, map, union               map[string]any
//	array                            []any
//	fixed                            byte array of the schema's size
//	                                 (duration: LogicalDuration, decimal: big.Rat)
//	ref                              receiver of the referenced schema
//
// An error is returned when the schema type has no generic receiver, or when a
// ref or fixed schema is not backed by *RefSchema or *FixedSchema respectively.
func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) {
	var ls LogicalSchema
	if lts, ok := schema.(LogicalTypeSchema); ok {
		ls = lts.Logical()
	}

	switch schema.Type() {
	case Boolean:
		var v bool
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Int:
		if ls != nil {
			switch ls.Type() {
			case Date:
				var v time.Time
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			case TimeMillis:
				var v time.Duration
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			}
		}
		var v int
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Long:
		if ls != nil {
			switch ls.Type() {
			case TimeMicros:
				var v time.Duration
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			case TimestampMillis, TimestampMicros:
				var v time.Time
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			}
		}
		var v int64
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Float:
		var v float32
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Double:
		var v float64
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case String, Enum:
		var v string
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Bytes:
		if ls != nil && ls.Type() == Decimal {
			var v *big.Rat
			return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
		}
		var v []byte
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Record, Map, Union:
		var v map[string]any
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Ref:
		ref, ok := schema.(*RefSchema)
		if !ok {
			// Report instead of panicking on an unexpected implementation.
			return nil, nil, fmt.Errorf("dynamic receiver: unexpected schema implementation %T for type %v", schema, schema.Type())
		}
		return genericReceiver(ref.Schema())
	case Array:
		// Start from an empty, non-nil slice rather than a nil one.
		v := make([]any, 0)
		return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
	case Fixed:
		fixed, ok := schema.(*FixedSchema)
		if !ok {
			// Report instead of panicking on an unexpected implementation.
			return nil, nil, fmt.Errorf("dynamic receiver: unexpected schema implementation %T for type %v", schema, schema.Type())
		}
		if fls := fixed.Logical(); fls != nil {
			switch fls.Type() {
			case Duration:
				var v LogicalDuration
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			case Decimal:
				var v big.Rat
				return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
			}
		}
		// The array length is only known at run time, so byteSliceToArray
		// cannot hand back a statically typed array; its result is used here
		// only to obtain the array type.
		// NOTE(review): presumably it returns the array boxed in an interface
		// value — confirm. In that case the address of a local holding the
		// result is the address of the interface header, not of the array, and
		// must not be given to the decoder as array storage. Allocate zeroed
		// storage of the array type instead.
		arr := byteSliceToArray(make([]byte, fixed.Size()), fixed.Size())
		typ := reflect2.TypeOf(arr)
		return typ.UnsafeNew(), typ, nil
	default:
		// Build the descriptive name only on the failure path.
		name := string(schema.Type())
		if ls != nil {
			name += "." + string(ls.Type())
		}
		return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name)
	}
}
221 changes: 221 additions & 0 deletions codec_generic_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package avro

import (
"bytes"
"math/big"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestGenericDecode is a table-driven test for genericDecode. Each case parses
// a schema, decodes the encoded bytes with it, and checks both the error state
// left on the reader and the generic value that was returned.
//
// Subtests are titled "<index> <name>" so that a failing case can be identified
// from the test output without counting table entries.
func TestGenericDecode(t *testing.T) {
	tests := []struct {
		name    string
		data    []byte
		schema  string
		want    any
		wantErr require.ErrorAssertionFunc
	}{
		{
			name:    "Bool",
			data:    []byte{0x01},
			schema:  "boolean",
			want:    true,
			wantErr: require.NoError,
		},
		{
			name:    "Int",
			data:    []byte{0x36},
			schema:  "int",
			want:    27,
			wantErr: require.NoError,
		},
		{
			name:    "Int Date",
			data:    []byte{0xAE, 0x9D, 0x02},
			schema:  `{"type":"int","logicalType":"date"}`,
			want:    time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC),
			wantErr: require.NoError,
		},
		{
			name:    "Int Time-Millis",
			data:    []byte{0xAA, 0xB4, 0xDE, 0x75},
			schema:  `{"type":"int","logicalType":"time-millis"}`,
			want:    123456789 * time.Millisecond,
			wantErr: require.NoError,
		},
		{
			name:    "Long",
			data:    []byte{0x36},
			schema:  "long",
			want:    int64(27),
			wantErr: require.NoError,
		},
		{
			name:    "Long Time-Micros",
			data:    []byte{0x86, 0xEA, 0xC8, 0xE9, 0x97, 0x07},
			schema:  `{"type":"long","logicalType":"time-micros"}`,
			want:    123456789123 * time.Microsecond,
			wantErr: require.NoError,
		},
		{
			name:    "Long Timestamp-Millis",
			data:    []byte{0x90, 0xB2, 0xAE, 0xC3, 0xEC, 0x5B},
			schema:  `{"type":"long","logicalType":"timestamp-millis"}`,
			want:    time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC),
			wantErr: require.NoError,
		},
		{
			name:    "Long Timestamp-Micros",
			data:    []byte{0x80, 0xCD, 0xB7, 0xA2, 0xEE, 0xC7, 0xCD, 0x05},
			schema:  `{"type":"long","logicalType":"timestamp-micros"}`,
			want:    time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC),
			wantErr: require.NoError,
		},
		{
			name:    "Float",
			data:    []byte{0x33, 0x33, 0x93, 0x3F},
			schema:  "float",
			want:    float32(1.15),
			wantErr: require.NoError,
		},
		{
			name:    "Double",
			data:    []byte{0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xF2, 0x3F},
			schema:  "double",
			want:    float64(1.15),
			wantErr: require.NoError,
		},
		{
			name:    "String",
			data:    []byte{0x06, 0x66, 0x6F, 0x6F},
			schema:  "string",
			want:    "foo",
			wantErr: require.NoError,
		},
		{
			name:    "Bytes",
			data:    []byte{0x08, 0xEC, 0xAB, 0x44, 0x00},
			schema:  "bytes",
			want:    []byte{0xEC, 0xAB, 0x44, 0x00},
			wantErr: require.NoError,
		},
		{
			name:    "Bytes Decimal",
			data:    []byte{0x6, 0x00, 0x87, 0x78},
			schema:  `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`,
			want:    big.NewRat(1734, 5),
			wantErr: require.NoError,
		},
		{
			name:    "Record",
			data:    []byte{0x36, 0x06, 0x66, 0x6f, 0x6f},
			schema:  `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`,
			want:    map[string]any{"a": int64(27), "b": "foo"},
			wantErr: require.NoError,
		},
		{
			name:    "Ref",
			data:    []byte{0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f},
			schema:  `{"type":"record","name":"parent","fields":[{"name":"a","type":{"type":"record","name":"test","fields":[{"name":"a","type":"long"},{"name":"b","type":"string"}]}},{"name":"b","type":"test"}]}`,
			want:    map[string]any{"a": map[string]any{"a": int64(27), "b": "foo"}, "b": map[string]any{"a": int64(27), "b": "foo"}},
			wantErr: require.NoError,
		},
		{
			name:    "Array",
			data:    []byte{0x04, 0x36, 0x38, 0x0},
			schema:  `{"type":"array", "items": "int"}`,
			want:    []any{27, 28},
			wantErr: require.NoError,
		},
		{
			name:    "Map",
			data:    []byte{0x02, 0x06, 0x66, 0x6F, 0x6F, 0x06, 0x66, 0x6F, 0x6F, 0x00},
			schema:  `{"type":"map", "values": "string"}`,
			want:    map[string]any{"foo": "foo"},
			wantErr: require.NoError,
		},
		{
			name:    "Enum",
			data:    []byte{0x02},
			schema:  `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`,
			want:    "bar",
			wantErr: require.NoError,
		},
		{
			name:    "Enum Invalid Symbol",
			data:    []byte{0x04},
			schema:  `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`,
			want:    nil,
			wantErr: require.Error,
		},
		{
			name:    "Union",
			data:    []byte{0x02, 0x06, 0x66, 0x6F, 0x6F},
			schema:  `["null", "string"]`,
			want:    map[string]any{"string": "foo"},
			wantErr: require.NoError,
		},
		{
			name:    "Union Nil",
			data:    []byte{0x00},
			schema:  `["null", "string"]`,
			want:    nil,
			wantErr: require.NoError,
		},
		{
			name:    "Union Named",
			data:    []byte{0x02, 0x02},
			schema:  `["null", {"type":"enum", "name": "test", "symbols": ["foo", "bar"]}]`,
			want:    map[string]any{"test": "bar"},
			wantErr: require.NoError,
		},
		{
			name:    "Union Invalid Schema",
			data:    []byte{0x04},
			schema:  `["null", "string"]`,
			want:    nil,
			wantErr: require.Error,
		},
		{
			name:    "Fixed",
			data:    []byte{0x66, 0x6F, 0x6F, 0x66, 0x6F, 0x6F},
			schema:  `{"type":"fixed", "name": "test", "size": 6}`,
			want:    [6]byte{'f', 'o', 'o', 'f', 'o', 'o'},
			wantErr: require.NoError,
		},
		{
			name:    "Fixed Decimal",
			data:    []byte{0x00, 0x00, 0x00, 0x00, 0x87, 0x78},
			schema:  `{"type":"fixed", "name": "test", "size": 6,"logicalType":"decimal","precision":4,"scale":2}`,
			want:    big.NewRat(1734, 5),
			wantErr: require.NoError,
		},
	}

	for i, test := range tests {
		// Pin the loop variable for the closure (required before Go 1.22).
		test := test
		// Keep the index as a stable prefix and append the case name so the
		// failing case is visible in the output.
		t.Run(strconv.Itoa(i)+" "+test.name, func(t *testing.T) {
			schema := MustParse(test.schema)
			r := NewReader(bytes.NewReader(test.data), 10)

			got := genericDecode(schema, r)

			test.wantErr(t, r.Error)
			assert.Equal(t, test.want, got)
		})
	}
}

func TestGenericDecode_UnsupportedType(t *testing.T) {
schema := NewPrimitiveSchema(Type("test"), nil)
r := NewReader(bytes.NewReader([]byte{0x01}), 10)

_ = genericDecode(schema, r)

assert.Error(t, r.Error)
}
2 changes: 1 addition & 1 deletion codec_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
// We cannot resolve this, set it to the map type
name := schemaTypeName(schema)
obj := map[string]any{}
obj[name] = r.ReadNext(schema)
obj[name] = genericDecode(schema, r)

*pObj = obj
return
Expand Down

0 comments on commit 35f90ee

Please sign in to comment.