From b9d744286b0da73eaa7439607f720d3091b60f76 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Mon, 13 Nov 2023 18:19:07 +0100 Subject: [PATCH 01/25] feat: add support for schema evolution poc --- codec_default.go | 565 ++++++++++++++++++++++++++++++ codec_default_test.go | 644 +++++++++++++++++++++++++++++++++++ codec_record.go | 46 +++ schema.go | 18 + schema_compatibility.go | 186 +++++++++- schema_compatibility_test.go | 74 ++++ 6 files changed, 1531 insertions(+), 2 deletions(-) create mode 100644 codec_default.go create mode 100644 codec_default_test.go diff --git a/codec_default.go b/codec_default.go new file mode 100644 index 00000000..59618c7d --- /dev/null +++ b/codec_default.go @@ -0,0 +1,565 @@ +package avro + +import ( + "encoding" + "encoding/binary" + "fmt" + "math/big" + "reflect" + "unsafe" + + "github.com/modern-go/reflect2" +) + +func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { + if typ.Kind() == reflect.Interface { + if schema.Type() != Union && schema.Type() != Null { + return &efaceDefaultDecoder{def: def} + } + } + + switch schema.Type() { + case Null: + return &nullDefaultDecoder{} + + case Boolean: + return &boolDefaultDecoder{ + def: def, + typ: typ, + } + + case Int: + return &intDefaultDecoder{ + def: def, + typ: typ, + } + + case Long: + return &longDefaultDecoder{ + def: def, + typ: typ, + } + + case Float: + return &floatDefaultDecoder{ + def: def, + typ: typ, + } + + case Double: + return &doubleDefaultDecoder{ + def: def, + typ: typ, + } + + case String: + if typ.Implements(textUnmarshalerType) { + return &textDefaultMarshalerCodec{typ, def} + } + ptrType := reflect2.PtrTo(typ) + if ptrType.Implements(textUnmarshalerType) { + return &referenceDecoder{ + &textDefaultMarshalerCodec{typ: ptrType, def: def}, + } + } + + return &stringDefaultDecoder{ + def: def, + typ: typ, + } + + case Bytes: + return &bytesDefaultDecoder{ + def: def, + typ: typ, + } + + case Fixed: + return &fixedDefaultDecoder{ + fixed: schema.(*FixedSchema), + def: def, + typ: typ, + } + + case Enum: + return &enumDefaultDecoder{typ: typ, def: def} + + case Ref: + return createDefaultDecoder(cfg, schema.(*RefSchema).Schema(), def, typ) + + case Record: + return defaultDecoderOfRecord(cfg, schema, def, typ) + + case Array: + return defaultDecoderOfArray(cfg, schema, def, typ) + + case Map: + return defaultDecoderOfMap(cfg, schema, def, typ) + + case Union: + return createDefaultDecoder(cfg, schema.(*UnionSchema).Types()[0], def, typ) + + default: + return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())} + } +} + +type textDefaultMarshalerCodec struct { + typ reflect2.Type + def any +} + +func (d textDefaultMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { + obj := d.typ.UnsafeIndirect(ptr) + if reflect2.IsNil(obj) { + ptrType := d.typ.(*reflect2.UnsafePtrType) + newPtr := ptrType.Elem().UnsafeNew() + *((*unsafe.Pointer)(ptr)) = newPtr + obj = d.typ.UnsafeIndirect(ptr) + } + unmarshaler := (obj).(encoding.TextUnmarshaler) + + b := []byte(d.def.(string)) + + err := unmarshaler.UnmarshalText(b) + if err != nil { + r.ReportError("textMarshalerCodec", err.Error()) + } +} + +type efaceDefaultDecoder struct { + def any +} + +func (d *efaceDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { + *(*any)(ptr) = d.def +} + +type boolDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def, ok := d.def.(bool) + if !ok { + r.ReportError("decode default", "inconvertible type") + return + } + *((*bool)(ptr)) = def +} + +type nullDefaultDecoder struct { +} + +func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + return +} + +type intDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *intDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def := d.def + if reflect.TypeOf(d.def) != d.typ.Type1() { + if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { + r.ReportError("decode default", "inconvertible type") + return + } + + def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() + } + + switch d.typ.Kind() { + case reflect.Int: + *((*int)(ptr)) = def.(int) + case reflect.Uint: + *((*uint)(ptr)) = def.(uint) + case reflect.Int8: + *((*int8)(ptr)) = def.(int8) + case reflect.Uint8: + *((*uint8)(ptr)) = def.(uint8) + case reflect.Int16: + *((*int16)(ptr)) = def.(int16) + case reflect.Uint16: + *((*uint16)(ptr)) = def.(uint16) + case reflect.Int32: + *((*int32)(ptr)) = def.(int32) + default: + r.ReportError("decode default", "unsupported type") + } +} + +type longDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *longDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def := d.def + if reflect.TypeOf(d.def) != d.typ.Type1() { + if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { + r.ReportError("decode default", "inconvertible type") + return + } + + def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() + } + + switch d.typ.Kind() { + case reflect.Int32: + *((*int32)(ptr)) = def.(int32) + case reflect.Uint32: + *((*uint32)(ptr)) = def.(uint32) + case reflect.Int64: + *((*int64)(ptr)) = def.(int64) + default: + r.ReportError("decode default", "unsupported type") + } +} + +type floatDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *floatDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def := d.def + if reflect.TypeOf(d.def) != d.typ.Type1() { + if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { + r.ReportError("decode default", "inconvertible type") + return + } + + def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() + } + + switch d.typ.Kind() { + case reflect.Float32: + *((*float32)(ptr)) = def.(float32) + case reflect.Float64: + *((*float64)(ptr)) = def.(float64) + default: + r.ReportError("decode default", "unsupported type") + } +} + +type doubleDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *doubleDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def := d.def + if reflect.TypeOf(d.def) != d.typ.Type1() { + if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { + r.ReportError("decode default", "inconvertible type") + return + } + + def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() + } + + switch d.typ.Kind() { + case reflect.Float64: + *((*float64)(ptr)) = def.(float64) + case reflect.Float32: + *((*float32)(ptr)) = def.(float32) + default: + r.ReportError("decode default", "unsupported type") + } + +} + +type stringDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *stringDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def, ok := d.def.(string) + if !ok { + r.ReportError("decode default", "inconvertible type") + return + } + + *((*string)(ptr)) = def +} + +type bytesDefaultDecoder struct { + def any + typ reflect2.Type +} + +func (d *bytesDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + if d.typ.Kind() != reflect.Slice { + r.ReportError("decode default", "inconvertible type") + return + } + if d.typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 { + r.ReportError("decode default", "inconvertible type") + return + } + + def, ok := d.def.(string) + if !ok { + r.ReportError("decode default", "inconvertible type") + return + } + runes := []rune(def) + l := len(runes) + b := make([]byte, l) + for i := 0; i < l; i++ { + if runes[i] < 0 || runes[i] > 255 { + r.ReportError("decode default", "invalid default") + return + } + b[i] = uint8(runes[i]) + } + d.typ.(*reflect2.UnsafeSliceType).UnsafeSet(ptr, reflect2.PtrOf(b)) +} + +func defaultDecoderOfRecord(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { + rec := schema.(*RecordSchema) + mDef, ok := def.(map[string]any) + if !ok { + return &errorDecoder{err: fmt.Errorf("avro: invalid default for record field")} + } + + fields := make([]*Field, len(rec.Fields())) + for i, field := range rec.Fields() { + f, err := NewField(field.Name(), field.Type(), + WithDefault(mDef[field.Name()]), WithAliases(field.Aliases()), WithOrder(field.Order()), + ) + if err != nil { + return &errorDecoder{err: fmt.Errorf("avro: %w", err)} + } + f.action = FieldSetDefault + fields[i] = f + } + + r, err := NewRecordSchema(rec.Name(), rec.Namespace(), fields, WithAliases(rec.Aliases())) + if err != nil { + return &errorDecoder{err: fmt.Errorf("avro: %w", err)} + } + + switch typ.Kind() { + case reflect.Struct: + return decoderOfStruct(cfg, r, typ) + case reflect.Map: + return decoderOfRecord(cfg, r, typ) + } + + return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} +} + +type enumDefaultDecoder struct { + typ reflect2.Type + def any +} + +func (d *enumDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + unmarshal := func(def string, isPtr bool) { + var obj any + if isPtr { + obj = d.typ.PackEFace(ptr) + } else { + obj = d.typ.UnsafeIndirect(ptr) + } + if reflect2.IsNil(obj) { + ptrType := d.typ.(*reflect2.UnsafePtrType) + newPtr := ptrType.Elem().UnsafeNew() + *((*unsafe.Pointer)(ptr)) = newPtr + obj = d.typ.UnsafeIndirect(ptr) + } + unmarshaler := (obj).(encoding.TextUnmarshaler) + err := unmarshaler.UnmarshalText([]byte(def)) + if err != nil { + r.ReportError("textMarshalerCodec", err.Error()) + } + } + + def, ok := d.def.(string) + if !ok { + r.ReportError("decode default", "inconvertible type") + } + + switch { + case d.typ.Kind() == reflect.String: + *((*string)(ptr)) = def + return + case reflect2.PtrTo(d.typ).Implements(textUnmarshalerType): + unmarshal(def, true) + return + case d.typ.Implements(textUnmarshalerType): + unmarshal(def, false) + return + default: + r.ReportError("decode default", "unsupported type") + } +} + +func defaultDecoderOfArray(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { + if typ.Kind() != reflect.Slice { + return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} + } + + return &sliceDefaultDecoder{ + def: def, + typ: typ.(*reflect2.UnsafeSliceType), + decoder: func(def any) ValDecoder { + return createDefaultDecoder(cfg, schema.(*ArraySchema).Items(), def, typ.(*reflect2.UnsafeSliceType).Elem()) + }, + } +} + +type sliceDefaultDecoder struct { + def any + typ *reflect2.UnsafeSliceType + decoder func(def any) ValDecoder +} + +func (d *sliceDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def, ok := d.def.([]any) + if !ok { + r.ReportError("decode default", "inconvertible type") + return + } + + size := len(def) + d.typ.UnsafeGrow(ptr, size) + for i := 0; i < size; i++ { + elemPtr := d.typ.UnsafeGetIndex(ptr, i) + d.decoder(def[i]).Decode(elemPtr, nil) + } +} + +func defaultDecoderOfMap(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { + if typ.Kind() != reflect.Map { + return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} + } + + return &mapDefaultDecoder{ + typ: typ.(*reflect2.UnsafeMapType), + def: def, + decoder: func(def any) ValDecoder { + return createDefaultDecoder(cfg, schema.(*MapSchema).Values(), def, typ.(*reflect2.UnsafeMapType).Elem()) + }, + } +} + +type mapDefaultDecoder struct { + typ *reflect2.UnsafeMapType + decoder func(def any) ValDecoder + def any +} + +func (d *mapDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def, ok := d.def.(map[string]any) + if !ok { + r.ReportError("decode default", "inconvertible type") + return + } + + if d.typ.UnsafeIsNil(ptr) { + d.typ.UnsafeSet(ptr, d.typ.UnsafeMakeMap(0)) + } + for k, v := range def { + key := k + keyPtr := reflect2.PtrOf(&key) + elemPtr := d.typ.UnsafeNew() + d.decoder(v).Decode(elemPtr, nil) + d.typ.UnsafeSetIndex(ptr, keyPtr, elemPtr) + } +} + +type fixedDefaultDecoder struct { + typ reflect2.Type + def any + fixed *FixedSchema +} + +func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + def, ok := d.def.(string) + if !ok { + r.ReportError("decode default", "inconvertible type") + return + } + runes := []rune(def) + l := len(runes) + b := make([]byte, l) + for i := 0; i < l; i++ { + if runes[i] < 0 || runes[i] > 255 { + r.ReportError("decode default", "invalid default") + return + } + b[i] = uint8(runes[i]) + } + + switch d.typ.Kind() { + case reflect.Array: + arrayType := d.typ.(reflect2.ArrayType) + if arrayType.Elem().Kind() != reflect.Uint8 || arrayType.Len() != d.fixed.Size() { + r.ReportError("decode default", "unsupported type") + return + } + if arrayType.Len() != l { + r.ReportError("decode default", "invalid default") + return + } + for i := 0; i < arrayType.Len(); i++ { + arrayType.UnsafeSetIndex(ptr, i, reflect2.PtrOf(b[i])) + } + + case reflect.Uint64: + if d.fixed.Size() != 8 { + r.ReportError("decode default", "unsupported type") + return + } + if l != 8 { + r.ReportError("decode default", "invalid default") + return + } + *(*uint64)(ptr) = binary.BigEndian.Uint64(b) + + case reflect.Struct: + ls := d.fixed.Logical() + if ls == nil { + break + } + typ1 := d.typ.Type1() + switch { + case typ1.ConvertibleTo(durType) && ls.Type() == Duration: + if l != 12 { + r.ReportError("decode default", "invalid default") + return + } + *((*LogicalDuration)(ptr)) = durationFromBytes(b) + + case typ1.ConvertibleTo(ratType) && ls.Type() == Decimal: + dec := ls.(*DecimalLogicalSchema) + if d.fixed.Size() != l { + r.ReportError("decode default", "invalid default") + return + } + *((*big.Rat)(ptr)) = *ratFromBytes(b, dec.Scale()) + default: + r.ReportError("decode default", "unsupported type") + } + + default: + r.ReportError("decode default", "unsupported type") + } +} + +func durationFromBytes(b []byte) LogicalDuration { + var duration LogicalDuration + + duration.Months = binary.LittleEndian.Uint32(b[0:4]) + duration.Days = binary.LittleEndian.Uint32(b[4:8]) + duration.Milliseconds = binary.LittleEndian.Uint32(b[8:12]) + + return duration +} diff --git a/codec_default_test.go b/codec_default_test.go new file mode 100644 index 00000000..d64db27b --- /dev/null +++ b/codec_default_test.go @@ -0,0 +1,644 @@ +package avro_test + +import ( + "bytes" + "math" + "math/big" + "testing" + + "github.com/hamba/avro/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecoder_DefaultBool(t *testing.T) { + + defer ConfigTeardown() + + // write schema + // `{ + // // "type": "record", + // // "name": "test", + // // "fields" : [ + // // {"name": "a", "type": "string"} + // // ] + // // }` + + // {"a": "foo"} + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "boolean", "default": true} + ] + }`) + + // hack: set field action to force decode default behavior + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + type TestRecord struct { + A string `avro:"a"` + B bool `avro:"b"` + } + + var got TestRecord + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: true, A: "foo"}, got) +} + +func TestDecoder_DefaultInt(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "int", "default": 1000} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + type TestRecord struct { + A string `avro:"a"` + B int32 `avro:"b"` + } + + var got TestRecord + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 1000, A: "foo"}, got) +} + +func TestDecoder_DefaultLong(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "long", "default": 1000} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B int64 `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 1000, A: "foo"}, got) +} + +func TestDecoder_DefaultFloat(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "float", "default": 10.45} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B float32 `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 10.45, A: "foo"}, got) +} + +func TestDecoder_DefaultDouble(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "double", "default": 10.45} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B float64 `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 10.45, A: "foo"}, got) +} + +func TestDecoder_DefaultBytes(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "bytes", "default": "value"} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B []byte `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: []byte("value"), A: "foo"}, got) +} + +func TestDecoder_DefaultString(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "string", "default": "value"} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B string `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: "value", A: "foo"}, got) +} + +func TestDecoder_DefaultEnum(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "enum", + "name": "test.enum", + "symbols": ["foo", "bar"] + }, + "default": "bar" + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + t.Run("simple", func(t *testing.T) { + type TestRecord struct { + A string `avro:"a"` + B string `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: "bar", A: "foo"}, got) + + }) + + t.Run("TextUnmarshaler", func(t *testing.T) { + type TestRecord struct { + A string `avro:"a"` + B testEnumTextUnmarshaler `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 1, A: "foo"}, got) + }) + + t.Run("TextUnmarshaler Ptr", func(t *testing.T) { + type TestRecord struct { + A string `avro:"a"` + B *testEnumTextUnmarshaler `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + var v testEnumTextUnmarshaler = 1 + assert.Equal(t, TestRecord{B: &v, A: "foo"}, got) + }) +} + +func TestDecoder_DefaultUnion(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + type TestRecord struct { + A string `avro:"a"` + B any `avro:"b"` + } + + t.Run("null default", func(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": ["null", "long"], "default": null} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: nil, A: "foo"}, got) + }) + + t.Run("not null default", func(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": ["string", "long"], "default": "bar"} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: "bar", A: "foo"}, got) + }) +} + +func TestDecoder_DefaultArray(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "array", "items": "int" + }, + "default":[1, 2, 3, 4] + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + type TestRecord struct { + A string `avro:"a"` + B []int16 `avro:"b"` + } + + var got TestRecord + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: []int16{1, 2, 3, 4}, A: "foo"}, got) +} + +func TestDecoder_DefaultMap(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "map", "values": "string" + }, + "default": {"foo":"bar"} + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + type TestRecord struct { + A string `avro:"a"` + B map[string]string `avro:"b"` + } + + var got TestRecord + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: map[string]string{"foo": "bar"}, A: "foo"}, got) +} + +func TestDecoder_DefaultRecord(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "record", + "name": "test.record", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "string"} + ] + }, + "default": {"a":"foo", "b": "bar"} + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + t.Run("struct", func(t *testing.T) { + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + type subRecord struct { + A string `avro:"a"` + B string `avro:"b"` + } + type TestRecord struct { + A string `avro:"a"` + B subRecord `avro:"b"` + } + + var got TestRecord + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: subRecord{A: "foo", B: "bar"}, A: "foo"}, got) + }) + + t.Run("map", func(t *testing.T) { + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + var got map[string]any + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, map[string]any{"b": map[string]any{"a": "foo", "b": "bar"}, "a": "foo"}, got) + }) +} + +func TestDecoder_DefaultRef(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + _ = avro.MustParse(`{ + "type": "record", + "name": "test.embed", + "fields" : [ + {"name": "a", "type": "string"} + ] + }`) + + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "test.embed", "default": {"a": "foo"}} + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + + var got map[string]any + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, map[string]any{"b": map[string]any{"a": "foo"}, "a": "foo"}, got) +} + +func TestDecoder_DefaultFixed(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + t.Run("array", func(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "fixed", + "name": "test.fixed", + "size": 3 + }, + "default": "foo" + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B [3]byte `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: [3]byte{'f', 'o', 'o'}, A: "foo"}, got) + }) + + t.Run("uint64", func(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "fixed", + "name": "test.fixed", + "size": 8 + }, + "default": "\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff" + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B uint64 `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: uint64(math.MaxUint64), A: "foo"}, got) + }) + + t.Run("duration", func(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "fixed", + "name": "test.fixed", + "logicalType":"duration", + "size":12 + }, + "default": "\u000c\u0000\u0000\u0000\u0022\u0000\u0000\u0000\u0052\u00aa\u0008\u0000" + } + ] + }`) + + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B avro.LogicalDuration `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + + assert.Equal(t, uint32(12), got.B.Months) + assert.Equal(t, uint32(34), got.B.Days) + assert.Equal(t, uint32(567890), got.B.Milliseconds) + assert.Equal(t, "foo", got.A) + }) + + t.Run("rat", func(t *testing.T) { + schema := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + { + "name": "b", + "type": { + "type": "fixed", + "name": "test.fixed", + "size": 6, + "logicalType":"decimal", + "precision":4, + "scale":2 + }, + "default": "\u0000\u0000\u0000\u0000\u0087\u0078" + } + ] + }`) + avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + + type TestRecord struct { + A string `avro:"a"` + B big.Rat `avro:"b"` + } + + var got TestRecord + err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, big.NewRat(1734, 5), &got.B) + assert.Equal(t, "foo", got.A) + }) + +} diff --git a/codec_record.go b/codec_record.go index 86295f20..e62068aa 100644 --- a/codec_record.go +++ b/codec_record.go @@ -59,6 +59,13 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields := make([]*structFieldDecoder, 0, len(rec.Fields())) for _, field := range rec.Fields() { + if field.action == FieldDrain { + fields = append(fields, &structFieldDecoder{ + decoder: createSkipDecoder(field.Type()), + }) + continue + } + sf := structDesc.Fields.Get(field.Name()) if sf == nil { for _, alias := range field.Aliases() { @@ -77,6 +84,21 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec continue } + if field.action == FieldSetDefault { + if field.hasDef { + fields = append(fields, &structFieldDecoder{ + field: sf.Field, + decoder: createDefaultDecoder(cfg, field.Type(), field.def, sf.Field[len(sf.Field)-1].Type()), + }) + } else { + fields = append(fields, &structFieldDecoder{ + decoder: createSkipDecoder(field.Type()), + }) + } + + continue + } + dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()) fields = append(fields, &structFieldDecoder{ field: sf.Field, @@ -237,6 +259,30 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields := make([]recordMapDecoderField, len(rec.Fields())) for i, field := range rec.Fields() { + if field.action == FieldDrain { + fields[i] = recordMapDecoderField{ + name: field.Name(), + decoder: createSkipDecoder(field.Type()), + } + continue + } + + if field.action == FieldSetDefault { + if field.hasDef { + fields[i] = recordMapDecoderField{ + name: field.Name(), + decoder: createDefaultDecoder(cfg, field.Type(), field.def, mapType.Elem()), + } + } else { + fields[i] = recordMapDecoderField{ + name: field.Name(), + decoder: createSkipDecoder(field.Type()), + } + } + + continue + } + fields[i] = recordMapDecoderField{ name: field.Name(), decoder: decoderOfType(cfg, field.Type(), mapType.Elem()), diff --git a/schema.go b/schema.go index 580f2633..6c67ed42 100644 --- a/schema.go +++ b/schema.go @@ -75,6 +75,14 @@ const ( Duration LogicalType = "duration" ) +// Action is a field action used during decoding process. +type Action string + +const ( + FieldDrain Action = "drain" + FieldSetDefault Action = "set_default" +) + // FingerprintType is a fingerprinting algorithm. type FingerprintType string @@ -590,6 +598,7 @@ type Field struct { hasDef bool def any order Order + action Action } type noDef struct{} @@ -642,11 +651,20 @@ func NewField(name string, typ Schema, opts ...SchemaOption) (*Field, error) { return f, nil } +// SetFieldAction updates the given field's action. Mainly used for testing purposes. +func SetFieldAction(field *Field, action Action) { + field.action = action +} + // Name returns the name of a field. func (f *Field) Name() string { return f.name } +func (f *Field) Action() Action { + return f.action +} + // Aliases return the field aliases. func (f *Field) Aliases() []string { return f.aliases diff --git a/schema_compatibility.go b/schema_compatibility.go index d1779ea9..402b4efe 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -203,7 +203,9 @@ func (c *SchemaCompatibility) checkEnumSymbols(reader, writer *EnumSchema) error func (c *SchemaCompatibility) checkRecordFields(reader, writer *RecordSchema) error { for _, field := range reader.Fields() { - f, ok := c.getField(writer.Fields(), field) + f, ok := c.getField(writer.Fields(), field, func(gfo *getFieldOptions) { + gfo.fieldAlias = true + }) if !ok { if field.HasDefault() { continue @@ -230,12 +232,192 @@ func (c *SchemaCompatibility) contains(a []string, s string) bool { return false } -func (c *SchemaCompatibility) getField(a []*Field, f *Field) (*Field, bool) { +type getFieldOptions struct { + fieldAlias bool + elemAlias bool +} + +func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*getFieldOptions)) (*Field, bool) { + opt := getFieldOptions{} + for _, fn := range optFns { + if fn == nil { + continue + } + fn(&opt) + } for _, field := range a { if field.Name() == f.Name() { return field, true } + if opt.fieldAlias { + for _, alias := range f.Aliases() { + if field.Name() == alias { + return field, true + } + } + } + if opt.elemAlias { + for _, alias := range field.Aliases() { + if f.Name() == alias { + return field, true + } + } + } } return nil, false } + +func isNative(typ Type) bool { + switch typ { + case Null, Boolean, Int, Long, Float, Double, Bytes, String: + return true + default: + } + + return false +} + +func isPromotable(typ Type) bool { + switch typ { + case Int, Long, Float, String, Bytes: + return true + default: + } + + return false +} + +func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { + if reader.Type() == Ref { + reader = reader.(*RefSchema).Schema() + } + if writer.Type() == Ref { + writer = writer.(*RefSchema).Schema() + } + + if err := c.compatible(reader, writer); err != nil { + return nil, err + } + + if writer.Type() != reader.Type() { + if isPromotable(writer.Type()) { + return reader, nil + } + + if reader.Type() == Union { + for _, schema := range reader.(*UnionSchema).Types() { + sch, err := c.Resolve(schema, writer) + if err != nil { + continue + } + + return sch, nil + } + + return nil, fmt.Errorf("reader union lacking writer schema %s", writer.Type()) + } + + if writer.Type() == Union { + schemas := make([]Schema, 0) + for _, schema := range writer.(*UnionSchema).Types() { + sch, err := c.Resolve(reader, schema) + if err != nil { + return nil, err + } + schemas = append(schemas, sch) + } + return NewUnionSchema(schemas) + } + } + + if isNative(writer.Type()) { + return reader, nil + } + + if writer.Type() == Enum { + return reader, nil + } + + if writer.Type() == Fixed { + return reader, nil + } + + if writer.Type() == Union { + schemas := make([]Schema, 0) + for _, schema := range writer.(*UnionSchema).Types() { + sch, err := c.Resolve(reader, schema) + if err != nil { + return nil, err + } + schemas = append(schemas, sch) + } + return NewUnionSchema(schemas) + } + + if writer.Type() == Array { + schema, err := c.Resolve(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items()) + if err != nil { + return nil, err + } + return NewArraySchema(schema), nil + } + + if writer.Type() == Map { + schema, err := c.Resolve(reader.(*MapSchema).Values(), writer.(*MapSchema).Values()) + if err != nil { + return nil, err + } + return NewMapSchema(schema), nil + } + + if writer.Type() == Record { + return c.resolveRecord(reader, writer) + } + + return nil, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type()) +} + +func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, error) { + w := writer.(*RecordSchema) + r := reader.(*RecordSchema) + + fields := make([]*Field, 0) + founds := make(map[string]struct{}) + + for _, field := range w.Fields() { + if field == nil { + continue + } + f := *field + rf, ok := c.getField(r.Fields(), field, func(gfo *getFieldOptions) { + gfo.elemAlias = true + }) + if !ok { + f.action = FieldDrain + fields = append(fields, &f) + continue + } + ft, err := c.Resolve(rf.Type(), field.Type()) + if err != nil { + return nil, err + } + rf.typ = ft + fields = append(fields, rf) + founds[rf.Name()] = struct{}{} + } + + for _, field := range r.Fields() { + if field == nil { + continue + } + if _, ok := founds[field.Name()]; ok { + continue + } + f := *field + f.action = FieldSetDefault + fields = append(fields, &f) + } + + return NewRecordSchema(r.Name(), r.Namespace(), fields) +} diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 07044f05..4b445b90 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -1,6 +1,7 @@ package avro_test import ( + "log" "testing" "github.com/hamba/avro/v2" @@ -275,3 +276,76 @@ func TestSchemaCompatibility_CompatibleUsesCacheWithError(t *testing.T) { assert.Error(t, err) } + +func TestSchemaCompatibility_Resolve(t *testing.T) { + sch1 := avro.MustParse(`{ + "name": "A", + "type": "record", + "fields": [{ + "name": "c", + "type": "long" + },{ + "name": "a", + "type": "int" + }] + }`) + + type A1 struct { + A int32 `avro:"a"` + C int32 `avro:"c"` + } + + a1 := A1{ + A: 10, + C: 1000000, + } + + b, err := avro.Marshal(sch1, a1) + if err != nil { + t.Fatalf("marshal error%v", err) + } + + sch2 := avro.MustParse(`{ + "name": "A", + "type": "record", + "fields": [ + { + "name": "b", + "type": "string", + "default": "boo" + },{ + "name": "aa", + "aliases": ["a"], + "type": "long" + },{ + "name": "d", + "type": { + "type": "array", "items": "int" + }, + "default":[1, 2, 3, 4] + }] + }`) + + sc := avro.NewSchemaCompatibility() + + // resolve composite schema + sch, err := sc.Resolve(sch2, sch1) + if err != nil { + t.Fatalf("err: %v", err) + } + + type A2 struct { + A int64 `avro:"aa"` + B string `avro:"b"` + D []int32 `avro:"d"` + } + + a2 := A2{} + + err = avro.Unmarshal(sch, b, &a2) + if err != nil { + t.Fatalf("unmarshal error %v", err) + } + + log.Printf("result: %+v", a2) +} From 4ade51d5146dfeee0f73be45ba3f58773ef1dc26 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 17 Nov 2023 19:41:29 +0100 Subject: [PATCH 02/25] fix: type promotion POC --- codec_native.go | 178 +++++++++++++++++++++++++++-------- codec_promoter.go | 64 +++++++++++++ reader_generic.go | 28 ++++-- reader_promoter.go | 102 ++++++++++++++++++++ schema.go | 5 + schema_compatibility.go | 6 +- schema_compatibility_test.go | 56 +++++++---- 7 files changed, 371 insertions(+), 68 deletions(-) create mode 100644 codec_promoter.go create mode 100644 reader_promoter.go diff --git a/codec_native.go b/codec_native.go index d821747c..1917e8e3 100644 --- a/codec_native.go +++ b/codec_native.go @@ -11,6 +11,8 @@ import ( ) func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { + actual := schema.(*PrimitiveSchema).actual + switch typ.Kind() { case reflect.Bool: if schema.Type() != Boolean { @@ -58,7 +60,9 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Long { break } - return &longCodec[uint32]{} + return &longCodec[uint32]{ + promoter: getCodecPromoter[uint32](actual), + } case reflect.Int64: st := schema.Type() @@ -68,10 +72,14 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &timeMillisCodec{} case st == Long && lt == TimeMicros: // time.Duration - return &timeMicrosCodec{} + return &timeMicrosCodec{ + promoter: getCodecPromoter[int64](actual), + } case st == Long: - return &longCodec[int64]{} + return &longCodec[int64]{ + promoter: getCodecPromoter[int64](actual), + } default: break @@ -81,25 +89,34 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Float { break } - return &float32Codec{} + return &float32Codec{ + promoter: getCodecPromoter[float32](actual), + } case reflect.Float64: if schema.Type() != Double { break } - return &float64Codec{} + return &float64Codec{ + promoter: getCodecPromoter[float64](actual), + } case reflect.String: if schema.Type() != String { break } - return &stringCodec{} + return &stringCodec{ + promoter: getCodecPromoter[string](actual), + } case reflect.Slice: if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes { break } - return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)} + return &bytesCodec{ + sliceType: typ.(*reflect2.UnsafeSliceType), + promoter: getCodecPromoter[[]byte](actual), + } case reflect.Struct: st := schema.Type() @@ -113,15 +130,22 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &dateCodec{} case Istpy1Time && st == Long && lt == TimestampMillis: - return ×tampMillisCodec{} + return ×tampMillisCodec{ + promoter: getCodecPromoter[int64](actual), + } case Istpy1Time && st == Long && lt == TimestampMicros: - return ×tampMicrosCodec{} + return ×tampMicrosCodec{ + promoter: getCodecPromoter[int64](actual), + } case Istpy1Rat && st == Bytes && lt == Decimal: dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()} + return &bytesDecimalCodec{ + prec: dec.Precision(), scale: dec.Scale(), + promoter: getCodecPromoter[[]byte](actual), + } default: break @@ -139,7 +163,10 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { } dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()} + return &bytesDecimalPtrCodec{ + prec: dec.Precision(), scale: dec.Scale(), + promoter: getCodecPromoter[[]byte](actual), + } } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} @@ -206,6 +233,7 @@ func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder { return &timeMillisCodec{} case st == Long && lt == TimeMicros: // time.Duration + return &timeMicrosCodec{} case st == Long: @@ -340,20 +368,37 @@ type largeInt interface { ~int32 | ~uint32 | int64 } -type longCodec[T largeInt] struct{} +type longCodec[T largeInt] struct { + promoter *codecPromoter[T] +} -func (*longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { - *((*T)(ptr)) = T(r.ReadLong()) +func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { + var v T + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = T(r.ReadLong()) + } + *((*T)(ptr)) = v } func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(int64(*((*T)(ptr)))) } -type float32Codec struct{} +type float32Codec struct { + promoter *codecPromoter[float32] +} + +func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { + var v float32 + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = r.ReadFloat() + } -func (*float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { - *((*float32)(ptr)) = r.ReadFloat() + *((*float32)(ptr)) = v } func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) { @@ -366,20 +411,36 @@ func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteDouble(float64(*((*float32)(ptr)))) } -type float64Codec struct{} +type float64Codec struct { + promoter *codecPromoter[float64] +} -func (*float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { - *((*float64)(ptr)) = r.ReadDouble() +func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { + var v float64 + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = r.ReadDouble() + } + *((*float64)(ptr)) = v } func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteDouble(*((*float64)(ptr))) } -type stringCodec struct{} +type stringCodec struct { + promoter *codecPromoter[string] +} -func (*stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { - *((*string)(ptr)) = r.ReadString() +func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { + var v string + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = r.ReadString() + } + *((*string)(ptr)) = v } func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { @@ -388,10 +449,17 @@ func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { type bytesCodec struct { sliceType *reflect2.UnsafeSliceType + promoter *codecPromoter[[]byte] } func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) { - b := r.ReadBytes() + var b []byte + if c.promoter != nil { + b = c.promoter.promote(r) + } else { + b = r.ReadBytes() + } + // b := r.ReadBytes() c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b)) } @@ -413,10 +481,17 @@ func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteInt(int32(days)) } -type timestampMillisCodec struct{} +type timestampMillisCodec struct { + promoter *codecPromoter[int64] +} func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { - i := r.ReadLong() + var i int64 + if c.promoter != nil { + i = c.promoter.promote(r) + } else { + i = r.ReadLong() + } sec := i / 1e3 nsec := (i - sec*1e3) * 1e6 *((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC() @@ -427,10 +502,17 @@ func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(t.Unix()*1e3 + int64(t.Nanosecond()/1e6)) } -type timestampMicrosCodec struct{} +type timestampMicrosCodec struct { + promoter *codecPromoter[int64] +} func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { - i := r.ReadLong() + var i int64 + if c.promoter != nil { + i = c.promoter.promote(r) + } else { + i = r.ReadLong() + } sec := i / 1e6 nsec := (i - sec*1e6) * 1e3 *((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC() @@ -441,7 +523,8 @@ func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3)) } -type timeMillisCodec struct{} +type timeMillisCodec struct { +} func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { i := r.ReadInt() @@ -453,10 +536,17 @@ func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteInt(int32(d.Nanoseconds() / int64(time.Millisecond))) } -type timeMicrosCodec struct{} +type timeMicrosCodec struct { + promoter *codecPromoter[int64] +} func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { - i := r.ReadLong() + var i int64 + if c.promoter != nil { + i = c.promoter.promote(r) + } else { + i = r.ReadLong() + } *((*time.Duration)(ptr)) = time.Duration(i) * time.Microsecond } @@ -468,12 +558,19 @@ func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { var one = big.NewInt(1) type bytesDecimalCodec struct { - prec int - scale int + prec int + scale int + promoter *codecPromoter[[]byte] } func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { - b := r.ReadBytes() + var b []byte + if c.promoter != nil { + b = c.promoter.promote(r) + } else { + b = r.ReadBytes() + } + // b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } @@ -514,12 +611,19 @@ func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type bytesDecimalPtrCodec struct { - prec int - scale int + prec int + scale int + promoter *codecPromoter[[]byte] } func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) { - b := r.ReadBytes() + var b []byte + if c.promoter != nil { + b = c.promoter.promote(r) + } else { + b = r.ReadBytes() + } + // b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } diff --git a/codec_promoter.go b/codec_promoter.go new file mode 100644 index 00000000..5776121f --- /dev/null +++ b/codec_promoter.go @@ -0,0 +1,64 @@ +package avro + +import ( + "reflect" + + "github.com/modern-go/reflect2" +) + +func getCodecPromoter[T any](actual Type) *codecPromoter[T] { + if actual == "" { + return nil + } + + return &codecPromoter[T]{actual: actual} +} + +type codecPromoter[T any] struct { + actual Type +} + +func (p *codecPromoter[T]) promote(r *Reader) (t T) { + tt := reflect2.TypeOf(t) + + convert := func(typ reflect2.Type, obj any) (t T) { + if !reflect.TypeOf(obj).ConvertibleTo(typ.Type1()) { + r.ReportError("decode promotable", "unsupported type") + // return zero value + return t + } + return reflect.ValueOf(obj).Convert(typ.Type1()).Interface().(T) + } + + switch p.actual { + case Int: + var obj int32 + (&intCodec[int32]{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case Long: + var obj int64 + (&longCodec[int64]{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case Float: + var obj float32 + (&float32Codec{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case String: + var obj string + (&stringCodec{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case Bytes: + var obj []byte + (&bytesCodec{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + default: + r.ReportError("decode promotable", "unsupported actual type") + } + + return t +} diff --git a/reader_generic.go b/reader_generic.go index b75d240e..b7cfa02c 100644 --- a/reader_generic.go +++ b/reader_generic.go @@ -2,18 +2,26 @@ package avro import ( "fmt" + "log" "reflect" "time" ) // ReadNext reads the next Avro element as a generic interface. func (r *Reader) ReadNext(schema Schema) any { + var rp ReaderPromoter = r + if sch, ok := schema.(*PrimitiveSchema); ok && sch.actual != "" { + rp = &readerPromoter{r: r, actual: sch.actual, current: sch.Type()} + } + var ls LogicalSchema lts, ok := schema.(LogicalTypeSchema) if ok { ls = lts.Logical() } + log.Println("ls", ls) + switch schema.Type() { case Boolean: return r.ReadBool() @@ -34,34 +42,34 @@ func (r *Reader) ReadNext(schema Schema) any { if ls != nil { switch ls.Type() { case TimeMicros: - return time.Duration(r.ReadLong()) * time.Microsecond + return time.Duration(rp.ReadLong()) * time.Microsecond case TimestampMillis: - i := r.ReadLong() + i := rp.ReadLong() sec := i / 1e3 nsec := (i - sec*1e3) * 1e6 return time.Unix(sec, nsec).UTC() case TimestampMicros: - i := r.ReadLong() + i := rp.ReadLong() sec := i / 1e6 nsec := (i - sec*1e6) * 1e3 return time.Unix(sec, nsec).UTC() } } - return r.ReadLong() + return rp.ReadLong() case Float: - return r.ReadFloat() + return rp.ReadFloat() case Double: - return r.ReadDouble() + return rp.ReadDouble() case String: - return r.ReadString() + return rp.ReadString() case Bytes: if ls != nil && ls.Type() == Decimal { dec := ls.(*DecimalLogicalSchema) - return ratFromBytes(r.ReadBytes(), dec.Scale()) + return ratFromBytes(rp.ReadBytes(), dec.Scale()) } - return r.ReadBytes() + return rp.ReadBytes() case Record: fields := schema.(*RecordSchema).Fields() obj := make(map[string]any, len(fields)) @@ -97,7 +105,7 @@ func (r *Reader) ReadNext(schema Schema) any { return obj case Union: types := schema.(*UnionSchema).Types() - idx := int(r.ReadLong()) + idx := int(rp.ReadLong()) if idx < 0 || idx > len(types)-1 { r.ReportError("Read", "unknown union type") return nil diff --git a/reader_promoter.go b/reader_promoter.go new file mode 100644 index 00000000..f89b2f86 --- /dev/null +++ b/reader_promoter.go @@ -0,0 +1,102 @@ +package avro + +import ( + "reflect" +) + +type ReaderPromoter interface { + ReadLong() int64 + ReadFloat() float32 + ReadDouble() float64 + ReadString() string + ReadBytes() []byte +} + +type readerPromoter struct { + actual, current Type + r *Reader +} + +var _ ReaderPromoter = &readerPromoter{} + +var promotedInvalid = struct{}{} + +func (p *readerPromoter) readActual() any { + switch p.actual { + case Int: + return p.r.ReadInt() + + case Long: + return p.r.ReadLong() + + case Float: + return p.r.ReadFloat() + + case String: + return p.r.ReadString() + + case Bytes: + return p.r.ReadBytes() + + default: + p.r.ReportError("decode promotable", "unsupported actual type") + return promotedInvalid + } +} + +func (p *readerPromoter) ReadLong() int64 { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(int64) + } + + return 0 +} + +func (p *readerPromoter) ReadFloat() float32 { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(float32) + } + + return 0 +} + +func (p *readerPromoter) ReadDouble() float64 { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(float64) + } + + return 0 +} + +func (p *readerPromoter) ReadString() string { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(string) + } + + return "" +} + +func (p *readerPromoter) ReadBytes() []byte { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).([]byte) + } + + return nil +} + +func (p *readerPromoter) promote(obj any, st Type) (t any) { + switch st { + case Long: + return int64(reflect.ValueOf(obj).Int()) + case Float: + return float32(reflect.ValueOf(obj).Int()) + case Double: + return float64(reflect.ValueOf(obj).Float()) + case String: + return string(reflect.ValueOf(obj).Bytes()) + case Bytes: + return []byte(reflect.ValueOf(obj).String()) + } + + return obj +} diff --git a/schema.go b/schema.go index 6c67ed42..c9e174ea 100644 --- a/schema.go +++ b/schema.go @@ -394,6 +394,11 @@ type PrimitiveSchema struct { typ Type logical LogicalSchema + + // actual presents the actual type of the encoded value + // which can be promoted to schema current type. + // This field is only used in the context of write read schema resolution. + actual Type } // NewPrimitiveSchema creates a new PrimitiveSchema. diff --git a/schema_compatibility.go b/schema_compatibility.go index 402b4efe..8111a375 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -302,7 +302,11 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { if writer.Type() != reader.Type() { if isPromotable(writer.Type()) { - return reader, nil + // TODO clean up + r := *reader.(*PrimitiveSchema) + r.actual = writer.Type() + + return &r, nil } if reader.Type() == Union { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 4b445b90..c7a3f84b 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -3,6 +3,7 @@ package avro_test import ( "log" "testing" + "time" "github.com/hamba/avro/v2" "github.com/stretchr/testify/assert" @@ -287,17 +288,23 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { },{ "name": "a", "type": "int" + }, + { + "name": "k", + "type": "string" }] }`) type A1 struct { - A int32 `avro:"a"` - C int32 `avro:"c"` + A int32 `avro:"a"` + C int32 `avro:"c"` + K string `avro:"k"` } a1 := A1{ A: 10, C: 1000000, + K: "K value", } b, err := avro.Marshal(sch1, a1) @@ -309,21 +316,29 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "name": "A", "type": "record", "fields": [ - { - "name": "b", - "type": "string", - "default": "boo" - },{ - "name": "aa", - "aliases": ["a"], - "type": "long" - },{ - "name": "d", - "type": { - "type": "array", "items": "int" + { + "name": "k", + "type": "bytes" }, - "default":[1, 2, 3, 4] - }] + { + "name": "b", + "type": "string", + "default": "boo" + },{ + "name": "aa", + "aliases": ["a"], + "type": { + "type": "long", + "logicalType":"time-micros" + } + },{ + "name": "d", + "type": { + "type": "array", "items": "int" + }, + "default":[1, 2, 3, 4] + } + ] }`) sc := avro.NewSchemaCompatibility() @@ -335,9 +350,10 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { } type A2 struct { - A int64 `avro:"aa"` - B string `avro:"b"` - D []int32 `avro:"d"` + A time.Duration `avro:"aa"` + B string `avro:"b"` + D []int32 `avro:"d"` + K []byte `avro:"k"` } a2 := A2{} @@ -347,5 +363,5 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { t.Fatalf("unmarshal error %v", err) } - log.Printf("result: %+v", a2) + log.Printf("result: %+v %+v %T %+v", a2, a2.A, a2.A, string(a2.K)) } From 845d478e0c6cdd86caf2c1ae5b937d8d133d24c5 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Sat, 25 Nov 2023 17:09:49 +0100 Subject: [PATCH 03/25] fix: Fix and include last remarks --- codec_default.go | 12 +++-- codec_native.go | 102 +++++++++++++++++------------------ codec_promoter.go | 64 ---------------------- codec_record.go | 10 ++-- converter.go | 99 ++++++++++++++++++++++++++++++++++ reader_generic.go | 7 +-- reader_promoter.go | 91 ++++++++++--------------------- schema.go | 24 ++++++++- schema_compatibility.go | 34 ++++-------- schema_compatibility_test.go | 7 ++- 10 files changed, 229 insertions(+), 221 deletions(-) delete mode 100644 codec_promoter.go create mode 100644 converter.go diff --git a/codec_default.go b/codec_default.go index 59618c7d..264a1663 100644 --- a/codec_default.go +++ b/codec_default.go @@ -20,7 +20,9 @@ func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect switch schema.Type() { case Null: - return &nullDefaultDecoder{} + return &nullDefaultDecoder{ + typ: typ, + } case Boolean: return &boolDefaultDecoder{ @@ -150,10 +152,13 @@ func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } type nullDefaultDecoder struct { + typ reflect2.Type } -func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - return +func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { + if d.typ.IsNullable() { + d.typ.UnsafeSet(ptr, d.typ.UnsafeNew()) + } } type intDefaultDecoder struct { @@ -270,7 +275,6 @@ func (d *doubleDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { default: r.ReportError("decode default", "unsupported type") } - } type stringDefaultDecoder struct { diff --git a/codec_native.go b/codec_native.go index 1917e8e3..bf2d2c96 100644 --- a/codec_native.go +++ b/codec_native.go @@ -11,7 +11,7 @@ import ( ) func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { - actual := schema.(*PrimitiveSchema).actual + converter := resolveConverter(schema.(*PrimitiveSchema).actual) switch typ.Kind() { case reflect.Bool: @@ -60,9 +60,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Long { break } - return &longCodec[uint32]{ - promoter: getCodecPromoter[uint32](actual), - } + return &longCodec[uint32]{convert: converter.toLong} case reflect.Int64: st := schema.Type() @@ -73,12 +71,12 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { case st == Long && lt == TimeMicros: // time.Duration return &timeMicrosCodec{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } case st == Long: return &longCodec[int64]{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } default: @@ -90,7 +88,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { break } return &float32Codec{ - promoter: getCodecPromoter[float32](actual), + convert: converter.toFloat, } case reflect.Float64: @@ -98,7 +96,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { break } return &float64Codec{ - promoter: getCodecPromoter[float64](actual), + convert: converter.toDouble, } case reflect.String: @@ -106,7 +104,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { break } return &stringCodec{ - promoter: getCodecPromoter[string](actual), + convert: converter.toString, } case reflect.Slice: @@ -115,7 +113,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { } return &bytesCodec{ sliceType: typ.(*reflect2.UnsafeSliceType), - promoter: getCodecPromoter[[]byte](actual), + convert: converter.toBytes, } case reflect.Struct: @@ -131,20 +129,19 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { case Istpy1Time && st == Long && lt == TimestampMillis: return ×tampMillisCodec{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } case Istpy1Time && st == Long && lt == TimestampMicros: return ×tampMicrosCodec{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } case Istpy1Rat && st == Bytes && lt == Decimal: dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalCodec{ prec: dec.Precision(), scale: dec.Scale(), - promoter: getCodecPromoter[[]byte](actual), + convert: converter.toBytes, } default: @@ -165,7 +162,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &bytesDecimalPtrCodec{ prec: dec.Precision(), scale: dec.Scale(), - promoter: getCodecPromoter[[]byte](actual), + convert: converter.toBytes, } } @@ -369,13 +366,13 @@ type largeInt interface { } type longCodec[T largeInt] struct { - promoter *codecPromoter[T] + convert func(*Reader) int64 } func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { var v T - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = T(c.convert(r)) } else { v = T(r.ReadLong()) } @@ -387,13 +384,13 @@ func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { } type float32Codec struct { - promoter *codecPromoter[float32] + convert func(*Reader) float32 } func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { var v float32 - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = c.convert(r) } else { v = r.ReadFloat() } @@ -412,13 +409,13 @@ func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type float64Codec struct { - promoter *codecPromoter[float64] + convert func(*Reader) float64 } func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { var v float64 - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = c.convert(r) } else { v = r.ReadDouble() } @@ -430,13 +427,13 @@ func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) { } type stringCodec struct { - promoter *codecPromoter[string] + convert func(*Reader) string } func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { var v string - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = c.convert(r) } else { v = r.ReadString() } @@ -449,17 +446,16 @@ func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { type bytesCodec struct { sliceType *reflect2.UnsafeSliceType - promoter *codecPromoter[[]byte] + convert func(*Reader) []byte } func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) { var b []byte - if c.promoter != nil { - b = c.promoter.promote(r) + if c.convert != nil { + b = c.convert(r) } else { b = r.ReadBytes() } - // b := r.ReadBytes() c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b)) } @@ -482,13 +478,13 @@ func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type timestampMillisCodec struct { - promoter *codecPromoter[int64] + convert func(*Reader) int64 } func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { var i int64 - if c.promoter != nil { - i = c.promoter.promote(r) + if c.convert != nil { + i = c.convert(r) } else { i = r.ReadLong() } @@ -503,13 +499,13 @@ func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type timestampMicrosCodec struct { - promoter *codecPromoter[int64] + convert func(*Reader) int64 } func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { var i int64 - if c.promoter != nil { - i = c.promoter.promote(r) + if c.convert != nil { + i = c.convert(r) } else { i = r.ReadLong() } @@ -523,8 +519,7 @@ func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3)) } -type timeMillisCodec struct { -} +type timeMillisCodec struct{} func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { i := r.ReadInt() @@ -537,13 +532,13 @@ func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type timeMicrosCodec struct { - promoter *codecPromoter[int64] + convert func(*Reader) int64 } func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { var i int64 - if c.promoter != nil { - i = c.promoter.promote(r) + if c.convert != nil { + i = c.convert(r) } else { i = r.ReadLong() } @@ -558,19 +553,18 @@ func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { var one = big.NewInt(1) type bytesDecimalCodec struct { - prec int - scale int - promoter *codecPromoter[[]byte] + prec int + scale int + convert func(*Reader) []byte } func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { var b []byte - if c.promoter != nil { - b = c.promoter.promote(r) + if c.convert != nil { + b = c.convert(r) } else { b = r.ReadBytes() } - // b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } @@ -611,19 +605,19 @@ func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type bytesDecimalPtrCodec struct { - prec int - scale int - promoter *codecPromoter[[]byte] + prec int + scale int + convert func(*Reader) []byte } func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) { var b []byte - if c.promoter != nil { - b = c.promoter.promote(r) + if c.convert != nil { + b = c.convert(r) } else { b = r.ReadBytes() } - // b := r.ReadBytes() + if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } diff --git a/codec_promoter.go b/codec_promoter.go deleted file mode 100644 index 5776121f..00000000 --- a/codec_promoter.go +++ /dev/null @@ -1,64 +0,0 @@ -package avro - -import ( - "reflect" - - "github.com/modern-go/reflect2" -) - -func getCodecPromoter[T any](actual Type) *codecPromoter[T] { - if actual == "" { - return nil - } - - return &codecPromoter[T]{actual: actual} -} - -type codecPromoter[T any] struct { - actual Type -} - -func (p *codecPromoter[T]) promote(r *Reader) (t T) { - tt := reflect2.TypeOf(t) - - convert := func(typ reflect2.Type, obj any) (t T) { - if !reflect.TypeOf(obj).ConvertibleTo(typ.Type1()) { - r.ReportError("decode promotable", "unsupported type") - // return zero value - return t - } - return reflect.ValueOf(obj).Convert(typ.Type1()).Interface().(T) - } - - switch p.actual { - case Int: - var obj int32 - (&intCodec[int32]{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case Long: - var obj int64 - (&longCodec[int64]{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case Float: - var obj float32 - (&float32Codec{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case String: - var obj string - (&stringCodec{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case Bytes: - var obj []byte - (&bytesCodec{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - default: - r.ReportError("decode promotable", "unsupported actual type") - } - - return t -} diff --git a/codec_record.go b/codec_record.go index e62068aa..7e04a6f0 100644 --- a/codec_record.go +++ b/codec_record.go @@ -263,6 +263,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields[i] = recordMapDecoderField{ name: field.Name(), decoder: createSkipDecoder(field.Type()), + skip: true, } continue } @@ -273,11 +274,6 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec name: field.Name(), decoder: createDefaultDecoder(cfg, field.Type(), field.def, mapType.Elem()), } - } else { - fields[i] = recordMapDecoderField{ - name: field.Name(), - decoder: createSkipDecoder(field.Type()), - } } continue @@ -299,6 +295,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec type recordMapDecoderField struct { name string decoder ValDecoder + skip bool } type recordMapDecoder struct { @@ -315,6 +312,9 @@ func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { for _, field := range d.fields { elem := d.elemType.UnsafeNew() field.decoder.Decode(elem, r) + if field.skip { + continue + } d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem) } diff --git a/converter.go b/converter.go new file mode 100644 index 00000000..fa4b47a7 --- /dev/null +++ b/converter.go @@ -0,0 +1,99 @@ +package avro + +import ( + "fmt" + + "github.com/modern-go/reflect2" +) + +type converter struct { + toLong func(*Reader) int64 + toFloat func(*Reader) float32 + toDouble func(*Reader) float64 + toString func(*Reader) string + toBytes func(*Reader) []byte +} + +// resolveConverter returns a set of converter functions based on the actual type. +// Depending on the actual type value, some converter functions may be nil; +// thus, the downstream caller must first check the converter function value. +func resolveConverter(typ Type) converter { + cv := converter{} + + cv.toLong, _ = createLongConverter(typ) + cv.toFloat, _ = createFloatConverter(typ) + cv.toDouble, _ = createDoubleConverter(typ) + cv.toString, _ = createStringConverter(typ) + cv.toBytes, _ = createBytesConverter(typ) + + return cv +} + +func createLongConverter(typ Type) (func(*Reader) int64, error) { + switch typ { + case Int: + return func(r *Reader) int64 { return int64(r.ReadInt()) }, nil + case Long: + return func(r *Reader) int64 { return r.ReadLong() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createFloatConverter(typ Type) (func(*Reader) float32, error) { + switch typ { + case Int: + return func(r *Reader) float32 { return float32(r.ReadInt()) }, nil + case Long: + return func(r *Reader) float32 { return float32(r.ReadLong()) }, nil + case Float: + return func(r *Reader) float32 { return r.ReadFloat() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createDoubleConverter(typ Type) (func(*Reader) float64, error) { + switch typ { + case Int: + return func(r *Reader) float64 { return float64(r.ReadInt()) }, nil + case Long: + return func(r *Reader) float64 { return float64(r.ReadLong()) }, nil + case Float: + return func(r *Reader) float64 { return float64(r.ReadFloat()) }, nil + case Double: + return func(r *Reader) float64 { return r.ReadDouble() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createStringConverter(typ Type) (func(*Reader) string, error) { + switch typ { + case Bytes: + return func(r *Reader) string { + b := r.ReadBytes() + // TBD: update go.mod version to go 1.20 minimum + // runtime.KeepAlive(b) // TBD: I guess this line is required? + // return unsafe.String(unsafe.SliceData(b), len(b)) + return string(b) + }, nil + case String: + return func(r *Reader) string { return r.ReadString() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createBytesConverter(typ Type) (func(*Reader) []byte, error) { + switch typ { + case String: + return func(r *Reader) []byte { + return reflect2.UnsafeCastString(r.ReadString()) + }, nil + case Bytes: + return func(r *Reader) []byte { return r.ReadBytes() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} diff --git a/reader_generic.go b/reader_generic.go index b7cfa02c..79e5c38c 100644 --- a/reader_generic.go +++ b/reader_generic.go @@ -2,16 +2,15 @@ package avro import ( "fmt" - "log" "reflect" "time" ) // ReadNext reads the next Avro element as a generic interface. func (r *Reader) ReadNext(schema Schema) any { - var rp ReaderPromoter = r + var rp iReaderPromoter = r if sch, ok := schema.(*PrimitiveSchema); ok && sch.actual != "" { - rp = &readerPromoter{r: r, actual: sch.actual, current: sch.Type()} + rp = newReaderPromoter(sch.actual, r) } var ls LogicalSchema @@ -20,8 +19,6 @@ func (r *Reader) ReadNext(schema Schema) any { ls = lts.Logical() } - log.Println("ls", ls) - switch schema.Type() { case Boolean: return r.ReadBool() diff --git a/reader_promoter.go b/reader_promoter.go index f89b2f86..b5575eed 100644 --- a/reader_promoter.go +++ b/reader_promoter.go @@ -1,10 +1,6 @@ package avro -import ( - "reflect" -) - -type ReaderPromoter interface { +type iReaderPromoter interface { ReadLong() int64 ReadFloat() float32 ReadDouble() float64 @@ -13,90 +9,59 @@ type ReaderPromoter interface { } type readerPromoter struct { - actual, current Type - r *Reader + actual Type + r *Reader + converter } -var _ ReaderPromoter = &readerPromoter{} - -var promotedInvalid = struct{}{} - -func (p *readerPromoter) readActual() any { - switch p.actual { - case Int: - return p.r.ReadInt() - - case Long: - return p.r.ReadLong() - - case Float: - return p.r.ReadFloat() - - case String: - return p.r.ReadString() - - case Bytes: - return p.r.ReadBytes() - - default: - p.r.ReportError("decode promotable", "unsupported actual type") - return promotedInvalid +func newReaderPromoter(actual Type, r *Reader) *readerPromoter { + rp := &readerPromoter{ + actual: actual, + r: r, + converter: resolveConverter(actual), } + + return rp } +var _ iReaderPromoter = &readerPromoter{} + func (p *readerPromoter) ReadLong() int64 { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(int64) + if p.toLong != nil { + return p.toLong(p.r) } - return 0 + return p.r.ReadLong() } func (p *readerPromoter) ReadFloat() float32 { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(float32) + if p.toFloat != nil { + return p.toFloat(p.r) } - return 0 + return p.r.ReadFloat() } func (p *readerPromoter) ReadDouble() float64 { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(float64) + if p.toDouble != nil { + return p.toDouble(p.r) } - return 0 + return p.r.ReadDouble() } func (p *readerPromoter) ReadString() string { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(string) + if p.toString != nil { + return p.toString(p.r) } - return "" + return p.r.ReadString() } func (p *readerPromoter) ReadBytes() []byte { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).([]byte) - } - - return nil -} - -func (p *readerPromoter) promote(obj any, st Type) (t any) { - switch st { - case Long: - return int64(reflect.ValueOf(obj).Int()) - case Float: - return float32(reflect.ValueOf(obj).Int()) - case Double: - return float64(reflect.ValueOf(obj).Float()) - case String: - return string(reflect.ValueOf(obj).Bytes()) - case Bytes: - return []byte(reflect.ValueOf(obj).String()) + if p.toBytes != nil { + return p.toBytes(p.r) } - return obj + return p.r.ReadBytes() } diff --git a/schema.go b/schema.go index c9e174ea..7506975b 100644 --- a/schema.go +++ b/schema.go @@ -75,9 +75,30 @@ const ( Duration LogicalType = "duration" ) +func isNative(typ Type) bool { + switch typ { + case Null, Boolean, Int, Long, Float, Double, Bytes, String: + return true + default: + } + + return false +} + +func isPromotable(typ Type) bool { + switch typ { + case Int, Long, Float, String, Bytes: + return true + default: + } + + return false +} + // Action is a field action used during decoding process. type Action string +// Action type constants. const ( FieldDrain Action = "drain" FieldSetDefault Action = "set_default" @@ -397,7 +418,7 @@ type PrimitiveSchema struct { // actual presents the actual type of the encoded value // which can be promoted to schema current type. - // This field is only used in the context of write read schema resolution. + // It's only used in the context of write-read schema resolution. actual Type } @@ -666,6 +687,7 @@ func (f *Field) Name() string { return f.name } +// Action returns the action of a field. func (f *Field) Action() Action { return f.action } diff --git a/schema_compatibility.go b/schema_compatibility.go index 8111a375..c80ba9b2 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -268,26 +268,10 @@ func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*get return nil, false } -func isNative(typ Type) bool { - switch typ { - case Null, Boolean, Int, Long, Float, Double, Bytes, String: - return true - default: - } - - return false -} - -func isPromotable(typ Type) bool { - switch typ { - case Int, Long, Float, String, Bytes: - return true - default: - } - - return false -} - +// Resolve returns a composite schema that allows decoding data written by the writer schema, +// and makes necessary adjustments to support the reader schema. +// +// It fails if the writer and reader schemas are not compatible. func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { if reader.Type() == Ref { reader = reader.(*RefSchema).Schema() @@ -300,13 +284,15 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { return nil, err } + return c.resolve(reader, writer) +} + +func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { if writer.Type() != reader.Type() { if isPromotable(writer.Type()) { - // TODO clean up - r := *reader.(*PrimitiveSchema) + r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical()) r.actual = writer.Type() - - return &r, nil + return r, nil } if reader.Type() == Union { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index c7a3f84b..3346a181 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -337,6 +337,10 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "type": "array", "items": "int" }, "default":[1, 2, 3, 4] + },{ + "name": "g", + "type": ["null", "string"], + "default": null } ] }`) @@ -354,6 +358,7 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { B string `avro:"b"` D []int32 `avro:"d"` K []byte `avro:"k"` + G *string `avro:"g"` } a2 := A2{} @@ -363,5 +368,5 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { t.Fatalf("unmarshal error %v", err) } - log.Printf("result: %+v %+v %T %+v", a2, a2.A, a2.A, string(a2.K)) + log.Printf("result: %+v", a2) } From 908868d0a2e3f6f8e0154edc88639464cd560d33 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Sat, 25 Nov 2023 21:00:17 +0100 Subject: [PATCH 04/25] reduce default decoder verbosity --- codec_default.go | 182 +++++++++++++---------------------------------- 1 file changed, 50 insertions(+), 132 deletions(-) diff --git a/codec_default.go b/codec_default.go index 264a1663..d76cd4d9 100644 --- a/codec_default.go +++ b/codec_default.go @@ -26,65 +26,63 @@ func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect case Boolean: return &boolDefaultDecoder{ - def: def, - typ: typ, + def: def.(bool), } case Int: return &intDefaultDecoder{ - def: def, + def: def.(int), typ: typ, } case Long: return &longDefaultDecoder{ - def: def, + def: def.(int64), typ: typ, } case Float: return &floatDefaultDecoder{ - def: def, + def: def.(float32), typ: typ, } case Double: return &doubleDefaultDecoder{ - def: def, + def: def.(float64), typ: typ, } case String: if typ.Implements(textUnmarshalerType) { - return &textDefaultMarshalerCodec{typ, def} + return &textDefaultMarshalerCodec{typ, def.(string)} } ptrType := reflect2.PtrTo(typ) if ptrType.Implements(textUnmarshalerType) { return &referenceDecoder{ - &textDefaultMarshalerCodec{typ: ptrType, def: def}, + &textDefaultMarshalerCodec{typ: ptrType, def: def.(string)}, } } return &stringDefaultDecoder{ - def: def, - typ: typ, + def: def.(string), } case Bytes: return &bytesDefaultDecoder{ - def: def, + def: def.(string), typ: typ, } case Fixed: return &fixedDefaultDecoder{ fixed: schema.(*FixedSchema), - def: def, + def: def.(string), typ: typ, } case Enum: - return &enumDefaultDecoder{typ: typ, def: def} + return &enumDefaultDecoder{typ: typ, def: def.(string)} case Ref: return createDefaultDecoder(cfg, schema.(*RefSchema).Schema(), def, typ) @@ -108,7 +106,7 @@ func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect type textDefaultMarshalerCodec struct { typ reflect2.Type - def any + def string } func (d textDefaultMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { @@ -121,7 +119,7 @@ func (d textDefaultMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { } unmarshaler := (obj).(encoding.TextUnmarshaler) - b := []byte(d.def.(string)) + b := []byte(d.def) err := unmarshaler.UnmarshalText(b) if err != nil { @@ -138,17 +136,11 @@ func (d *efaceDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { } type boolDefaultDecoder struct { - def any - typ reflect2.Type + def bool } func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def, ok := d.def.(bool) - if !ok { - r.ReportError("decode default", "inconvertible type") - return - } - *((*bool)(ptr)) = def + *((*bool)(ptr)) = d.def } type nullDefaultDecoder struct { @@ -162,138 +154,91 @@ func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { } type intDefaultDecoder struct { - def any + def int typ reflect2.Type } func (d *intDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def := d.def - if reflect.TypeOf(d.def) != d.typ.Type1() { - if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { - r.ReportError("decode default", "inconvertible type") - return - } - - def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() - } - switch d.typ.Kind() { case reflect.Int: - *((*int)(ptr)) = def.(int) + *((*int)(ptr)) = d.def case reflect.Uint: - *((*uint)(ptr)) = def.(uint) + *((*uint)(ptr)) = uint(d.def) case reflect.Int8: - *((*int8)(ptr)) = def.(int8) + *((*int8)(ptr)) = int8(d.def) case reflect.Uint8: - *((*uint8)(ptr)) = def.(uint8) + *((*uint8)(ptr)) = uint8(d.def) case reflect.Int16: - *((*int16)(ptr)) = def.(int16) + *((*int16)(ptr)) = int16(d.def) case reflect.Uint16: - *((*uint16)(ptr)) = def.(uint16) + *((*uint16)(ptr)) = uint16(d.def) case reflect.Int32: - *((*int32)(ptr)) = def.(int32) + *((*int32)(ptr)) = int32(d.def) default: r.ReportError("decode default", "unsupported type") } } type longDefaultDecoder struct { - def any + def int64 typ reflect2.Type } func (d *longDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def := d.def - if reflect.TypeOf(d.def) != d.typ.Type1() { - if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { - r.ReportError("decode default", "inconvertible type") - return - } - - def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() - } - switch d.typ.Kind() { case reflect.Int32: - *((*int32)(ptr)) = def.(int32) + *((*int32)(ptr)) = int32(d.def) case reflect.Uint32: - *((*uint32)(ptr)) = def.(uint32) + *((*uint32)(ptr)) = uint32(d.def) case reflect.Int64: - *((*int64)(ptr)) = def.(int64) + *((*int64)(ptr)) = d.def default: r.ReportError("decode default", "unsupported type") } } type floatDefaultDecoder struct { - def any + def float32 typ reflect2.Type } func (d *floatDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def := d.def - if reflect.TypeOf(d.def) != d.typ.Type1() { - if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { - r.ReportError("decode default", "inconvertible type") - return - } - - def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() - } - switch d.typ.Kind() { case reflect.Float32: - *((*float32)(ptr)) = def.(float32) + *((*float32)(ptr)) = d.def case reflect.Float64: - *((*float64)(ptr)) = def.(float64) + *((*float64)(ptr)) = float64(d.def) default: r.ReportError("decode default", "unsupported type") } } type doubleDefaultDecoder struct { - def any + def float64 typ reflect2.Type } func (d *doubleDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def := d.def - if reflect.TypeOf(d.def) != d.typ.Type1() { - if !reflect.TypeOf(d.def).ConvertibleTo(d.typ.Type1()) { - r.ReportError("decode default", "inconvertible type") - return - } - - def = reflect.ValueOf(d.def).Convert(d.typ.Type1()).Interface() - } - switch d.typ.Kind() { case reflect.Float64: - *((*float64)(ptr)) = def.(float64) + *((*float64)(ptr)) = d.def case reflect.Float32: - *((*float32)(ptr)) = def.(float32) + *((*float32)(ptr)) = float32(d.def) default: r.ReportError("decode default", "unsupported type") } } type stringDefaultDecoder struct { - def any - typ reflect2.Type + def string } func (d *stringDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def, ok := d.def.(string) - if !ok { - r.ReportError("decode default", "inconvertible type") - return - } - - *((*string)(ptr)) = def + *((*string)(ptr)) = d.def } type bytesDefaultDecoder struct { - def any + def string typ reflect2.Type } @@ -307,12 +252,7 @@ func (d *bytesDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { return } - def, ok := d.def.(string) - if !ok { - r.ReportError("decode default", "inconvertible type") - return - } - runes := []rune(def) + runes := []rune(d.def) l := len(runes) b := make([]byte, l) for i := 0; i < l; i++ { @@ -361,7 +301,7 @@ func defaultDecoderOfRecord(cfg *frozenConfig, schema Schema, def any, typ refle type enumDefaultDecoder struct { typ reflect2.Type - def any + def string } func (d *enumDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { @@ -385,20 +325,15 @@ func (d *enumDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } - def, ok := d.def.(string) - if !ok { - r.ReportError("decode default", "inconvertible type") - } - switch { case d.typ.Kind() == reflect.String: - *((*string)(ptr)) = def + *((*string)(ptr)) = d.def return case reflect2.PtrTo(d.typ).Implements(textUnmarshalerType): - unmarshal(def, true) + unmarshal(d.def, true) return case d.typ.Implements(textUnmarshalerType): - unmarshal(def, false) + unmarshal(d.def, false) return default: r.ReportError("decode default", "unsupported type") @@ -411,7 +346,7 @@ func defaultDecoderOfArray(cfg *frozenConfig, schema Schema, def any, typ reflec } return &sliceDefaultDecoder{ - def: def, + def: def.([]any), typ: typ.(*reflect2.UnsafeSliceType), decoder: func(def any) ValDecoder { return createDefaultDecoder(cfg, schema.(*ArraySchema).Items(), def, typ.(*reflect2.UnsafeSliceType).Elem()) @@ -420,23 +355,17 @@ func defaultDecoderOfArray(cfg *frozenConfig, schema Schema, def any, typ reflec } type sliceDefaultDecoder struct { - def any + def []any typ *reflect2.UnsafeSliceType decoder func(def any) ValDecoder } func (d *sliceDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def, ok := d.def.([]any) - if !ok { - r.ReportError("decode default", "inconvertible type") - return - } - - size := len(def) + size := len(d.def) d.typ.UnsafeGrow(ptr, size) for i := 0; i < size; i++ { elemPtr := d.typ.UnsafeGetIndex(ptr, i) - d.decoder(def[i]).Decode(elemPtr, nil) + d.decoder(d.def[i]).Decode(elemPtr, nil) } } @@ -447,7 +376,7 @@ func defaultDecoderOfMap(cfg *frozenConfig, schema Schema, def any, typ reflect2 return &mapDefaultDecoder{ typ: typ.(*reflect2.UnsafeMapType), - def: def, + def: def.(map[string]any), decoder: func(def any) ValDecoder { return createDefaultDecoder(cfg, schema.(*MapSchema).Values(), def, typ.(*reflect2.UnsafeMapType).Elem()) }, @@ -457,20 +386,14 @@ func defaultDecoderOfMap(cfg *frozenConfig, schema Schema, def any, typ reflect2 type mapDefaultDecoder struct { typ *reflect2.UnsafeMapType decoder func(def any) ValDecoder - def any + def map[string]any } func (d *mapDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def, ok := d.def.(map[string]any) - if !ok { - r.ReportError("decode default", "inconvertible type") - return - } - if d.typ.UnsafeIsNil(ptr) { d.typ.UnsafeSet(ptr, d.typ.UnsafeMakeMap(0)) } - for k, v := range def { + for k, v := range d.def { key := k keyPtr := reflect2.PtrOf(&key) elemPtr := d.typ.UnsafeNew() @@ -481,17 +404,12 @@ func (d *mapDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { type fixedDefaultDecoder struct { typ reflect2.Type - def any + def string fixed *FixedSchema } func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - def, ok := d.def.(string) - if !ok { - r.ReportError("decode default", "inconvertible type") - return - } - runes := []rune(def) + runes := []rune(d.def) l := len(runes) b := make([]byte, l) for i := 0; i < l; i++ { From 3f02ab45107ccf3dcdb5628dfefa8b08d2e3e693 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Thu, 30 Nov 2023 17:28:10 +0100 Subject: [PATCH 05/25] add tests for schema resolution --- schema_compatibility.go | 12 +- schema_compatibility_test.go | 415 ++++++++++++++++++++++++++++------- 2 files changed, 336 insertions(+), 91 deletions(-) diff --git a/schema_compatibility.go b/schema_compatibility.go index c80ba9b2..fc753363 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -289,12 +289,6 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { if writer.Type() != reader.Type() { - if isPromotable(writer.Type()) { - r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical()) - r.actual = writer.Type() - return r, nil - } - if reader.Type() == Union { for _, schema := range reader.(*UnionSchema).Types() { sch, err := c.Resolve(schema, writer) @@ -319,6 +313,12 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { } return NewUnionSchema(schemas) } + + if isPromotable(writer.Type()) { + r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical()) + r.actual = writer.Type() + return r, nil + } } if isNative(writer.Type()) { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 3346a181..1c6a8f5e 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -1,9 +1,8 @@ package avro_test import ( - "log" + "math/big" "testing" - "time" "github.com/hamba/avro/v2" "github.com/stretchr/testify/assert" @@ -278,95 +277,341 @@ func TestSchemaCompatibility_CompatibleUsesCacheWithError(t *testing.T) { assert.Error(t, err) } -func TestSchemaCompatibility_Resolve(t *testing.T) { - sch1 := avro.MustParse(`{ - "name": "A", - "type": "record", - "fields": [{ - "name": "c", - "type": "long" - },{ - "name": "a", - "type": "int" - }, - { - "name": "k", - "type": "string" - }] - }`) - - type A1 struct { - A int32 `avro:"a"` - C int32 `avro:"c"` - K string `avro:"k"` - } - - a1 := A1{ - A: 10, - C: 1000000, - K: "K value", +func TestSchemaCompatibility_ResolveV2(t *testing.T) { + tests := []struct { + name string + reader string + writer string + value any + want any + }{ + { + name: "Int Promote Long", + reader: `"long"`, + writer: `"int"`, + value: 10, + want: int64(10), + }, + { + name: "Int Promote Float", + reader: `"float"`, + writer: `"int"`, + value: 10, + want: float32(10), + }, + { + name: "Int Promote Double", + reader: `"double"`, + writer: `"int"`, + value: 10, + want: float64(10), + }, + { + name: "Long Promote Float", + reader: `"float"`, + writer: `"long"`, + value: int64(10), + want: float32(10), + }, + { + name: "Long Promote Double", + reader: `"double"`, + writer: `"long"`, + value: int64(10), + want: float64(10), + }, + { + name: "Float Promote Double", + reader: `"double"`, + writer: `"float"`, + value: float32(10.5), + want: float64(10.5), + }, + { + name: "String Promote Bytes", + reader: `"bytes"`, + writer: `"string"`, + value: "foo", + want: []byte("foo"), + }, + { + name: "Bytes Promote String", + reader: `"string"`, + writer: `"bytes"`, + value: []byte("foo"), + want: "foo", + }, + { + name: "Union Match", + reader: `["int", "long", "string"]`, + writer: `["string", "int", "long"]`, + value: "foo", + want: "foo", + }, { + name: "Union Writer Missing Schema", + reader: `["int", "long", "string"]`, + writer: `["string", "int"]`, + value: "foo", + want: "foo", + }, + { + name: "Union Writer Not Union", + reader: `["int", "long", "string"]`, + writer: `"int"`, + value: 10, + want: 10, + }, + { + name: "Union Reader Not Union", + reader: `"int"`, + writer: `["int"]`, + value: 10, + want: 10, + }, + { + name: "Record Reader Field Missing", + reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "b", "type": "string"}, {"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10, "b": "foo"}, + want: map[string]any{"a": 10}, + }, + { + name: "Record Writer Field Missing With Default", + reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}, {"name": "b", "type": "string", "default": "test"}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": "test"}, + }, + { + name: "Record Reader Field With Alias", + reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "aa", "type": "int", "aliases": ["a"]}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"aa": 10}, + }, + { + name: "Record Reader Field With Alias And Promotion", + reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "aa", "type": "double", "aliases": ["a"]}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"aa": float64(10)}, + }, + { + name: "Record Writer Field Missing With Bytes Default", + reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}, {"name": "b", "type": "bytes", "default":"\u0066\u006f\u006f"}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": []byte("foo")}, + }, + { + name: "Record Writer Field Missing With Bytes Default", + reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}, {"name": "b", "type": "bytes", "default":"\u0066\u006f\u006f"}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": []byte("foo")}, + }, + { + name: "Record Writer Field Missing With Record Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "record", + "name": "test.record", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "string"} + ] + }, + "default":{"a":"foo", "b": "bar"} + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": map[string]any{"a": "foo", "b": "bar"}}, + }, + { + name: "Record Writer Field Missing With Map Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "map", "values": "string" + }, + "default":{"foo":"bar"} + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": map[string]any{"foo": "bar"}}, + }, + { + name: "Record Writer Field Missing With Array Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "array", "items": "int" + }, + "default":[1, 2, 3, 4] + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": []any{1, 2, 3, 4}}, + }, + { + name: "Record Writer Field Missing With Union Null Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type":["null", "long"], + "default": null + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": nil}, + }, + { + name: "Record Writer Field Missing With Union Non-null Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type":["string", "long"], + "default": "bar" + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": "bar"}, + }, + { + name: "Record Writer Field Missing With Fixed Duration Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "fixed", + "name": "test.fixed", + "logicalType":"duration", + "size":12 + }, + "default": "\u000c\u0000\u0000\u0000\u0022\u0000\u0000\u0000\u0052\u00aa\u0008\u0000" + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{ + "a": 10, + "b": avro.LogicalDuration{ + Months: uint32(12), + Days: uint32(34), + Milliseconds: uint32(567890), + }, + }, + }, + { + name: "Record Writer Field Missing With Fixed Logical Decimal Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "fixed", + "name": "test.fixed", + "size": 6, + "logicalType":"decimal", + "precision":4, + "scale":2 + }, + "default": "\u0000\u0000\u0000\u0000\u0087\u0078" + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{ + "a": 10, + "b": *big.NewRat(1734, 5), + }, + }, + { + name: "Record Writer Field Missing With Enum Duration Default", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "enum", + "name": "test.enum", + "symbols": ["foo", "bar"] + }, + "default": "bar" + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{ + "a": 10, + "b": "bar", + }, + }, } - b, err := avro.Marshal(sch1, a1) - if err != nil { - t.Fatalf("marshal error%v", err) - } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + r := avro.MustParse(test.reader) + w := avro.MustParse(test.writer) + sc := avro.NewSchemaCompatibility() - sch2 := avro.MustParse(`{ - "name": "A", - "type": "record", - "fields": [ - { - "name": "k", - "type": "bytes" - }, - { - "name": "b", - "type": "string", - "default": "boo" - },{ - "name": "aa", - "aliases": ["a"], - "type": { - "type": "long", - "logicalType":"time-micros" - } - },{ - "name": "d", - "type": { - "type": "array", "items": "int" - }, - "default":[1, 2, 3, 4] - },{ - "name": "g", - "type": ["null", "string"], - "default": null + b, err := avro.Marshal(w, test.value) + if err != nil { + t.Fatalf("marshal error%v", err) } - ] - }`) - - sc := avro.NewSchemaCompatibility() - - // resolve composite schema - sch, err := sc.Resolve(sch2, sch1) - if err != nil { - t.Fatalf("err: %v", err) - } - type A2 struct { - A time.Duration `avro:"aa"` - B string `avro:"b"` - D []int32 `avro:"d"` - K []byte `avro:"k"` - G *string `avro:"g"` - } + sch, err := sc.Resolve(r, w) + if err != nil { + t.Fatalf("resolve error %v", err) + } - a2 := A2{} + var result any + err = avro.Unmarshal(sch, b, &result) + if err != nil { + t.Fatalf("unmarshal error %v", err) + } - err = avro.Unmarshal(sch, b, &a2) - if err != nil { - t.Fatalf("unmarshal error %v", err) + assert.Equal(t, test.want, result) + }) } - - log.Printf("result: %+v", a2) } From 2498c2936073519832f28d6e5c527d723478455a Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Thu, 30 Nov 2023 17:28:56 +0100 Subject: [PATCH 06/25] attempt replacing readNext by native decoders --- codec_default.go | 165 +++++++++------- ..._test.go => codec_default_internal_test.go | 183 +++++++++++------- codec_dynamic.go | 129 +++++++++++- codec_native.go | 1 - codec_union.go | 12 +- converter.go | 2 - decoder_dynamic_bench_test.go | 60 ++++++ reader_promoter.go | 3 +- schema.go | 49 ++++- schema_internal_test.go | 4 +- 10 files changed, 460 insertions(+), 148 deletions(-) rename codec_default_test.go => codec_default_internal_test.go (71%) create mode 100644 decoder_dynamic_bench_test.go diff --git a/codec_default.go b/codec_default.go index d76cd4d9..dff030d8 100644 --- a/codec_default.go +++ b/codec_default.go @@ -14,7 +14,7 @@ import ( func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { if typ.Kind() == reflect.Interface { if schema.Type() != Union && schema.Type() != Null { - return &efaceDefaultDecoder{def: def} + return &efaceDefaultDecoder{def: def, schema: schema} } } @@ -23,36 +23,30 @@ func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect return &nullDefaultDecoder{ typ: typ, } - case Boolean: return &boolDefaultDecoder{ def: def.(bool), } - case Int: return &intDefaultDecoder{ def: def.(int), typ: typ, } - case Long: return &longDefaultDecoder{ def: def.(int64), typ: typ, } - case Float: return &floatDefaultDecoder{ def: def.(float32), typ: typ, } - case Double: return &doubleDefaultDecoder{ def: def.(float64), typ: typ, } - case String: if typ.Implements(textUnmarshalerType) { return &textDefaultMarshalerCodec{typ, def.(string)} @@ -63,42 +57,32 @@ func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect &textDefaultMarshalerCodec{typ: ptrType, def: def.(string)}, } } - return &stringDefaultDecoder{ def: def.(string), } - case Bytes: return &bytesDefaultDecoder{ - def: def.(string), + def: def.([]byte), typ: typ, } - case Fixed: return &fixedDefaultDecoder{ fixed: schema.(*FixedSchema), - def: def.(string), + def: def.([]byte), typ: typ, } - case Enum: return &enumDefaultDecoder{typ: typ, def: def.(string)} - case Ref: return createDefaultDecoder(cfg, schema.(*RefSchema).Schema(), def, typ) - case Record: return defaultDecoderOfRecord(cfg, schema, def, typ) - case Array: return defaultDecoderOfArray(cfg, schema, def, typ) - case Map: return defaultDecoderOfMap(cfg, schema, def, typ) - case Union: - return createDefaultDecoder(cfg, schema.(*UnionSchema).Types()[0], def, typ) - + return defaultDecoderOfUnion(schema.(*UnionSchema), def, typ) default: return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())} } @@ -123,24 +107,25 @@ func (d textDefaultMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { err := unmarshaler.UnmarshalText(b) if err != nil { - r.ReportError("textMarshalerCodec", err.Error()) + r.ReportError("decode default textMarshalerCodec", err.Error()) } } type efaceDefaultDecoder struct { - def any + def any + schema Schema } -func (d *efaceDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { - *(*any)(ptr) = d.def -} +func (d *efaceDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + rPtr, rTyp, err := dynamicReceiver(d.schema, r.cfg.resolver) + if err != nil { + r.ReportError("decode default", err.Error()) + return + } -type boolDefaultDecoder struct { - def bool -} + createDefaultDecoder(r.cfg, d.schema, d.def, rTyp).Decode(rPtr, r) -func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - *((*bool)(ptr)) = d.def + *(*any)(ptr) = rTyp.UnsafeIndirect(rPtr) } type nullDefaultDecoder struct { @@ -148,9 +133,15 @@ type nullDefaultDecoder struct { } func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { - if d.typ.IsNullable() { - d.typ.UnsafeSet(ptr, d.typ.UnsafeNew()) - } + *((*unsafe.Pointer)(ptr)) = nil +} + +type boolDefaultDecoder struct { + def bool +} + +func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { + *((*bool)(ptr)) = d.def } type intDefaultDecoder struct { @@ -233,12 +224,12 @@ type stringDefaultDecoder struct { def string } -func (d *stringDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { +func (d *stringDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { *((*string)(ptr)) = d.def } type bytesDefaultDecoder struct { - def string + def []byte typ reflect2.Type } @@ -252,17 +243,7 @@ func (d *bytesDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { return } - runes := []rune(d.def) - l := len(runes) - b := make([]byte, l) - for i := 0; i < l; i++ { - if runes[i] < 0 || runes[i] > 255 { - r.ReportError("decode default", "invalid default") - return - } - b[i] = uint8(runes[i]) - } - d.typ.(*reflect2.UnsafeSliceType).UnsafeSet(ptr, reflect2.PtrOf(b)) + d.typ.(*reflect2.UnsafeSliceType).UnsafeSet(ptr, reflect2.PtrOf(d.def)) } func defaultDecoderOfRecord(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { @@ -321,7 +302,7 @@ func (d *enumDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { unmarshaler := (obj).(encoding.TextUnmarshaler) err := unmarshaler.UnmarshalText([]byte(def)) if err != nil { - r.ReportError("textMarshalerCodec", err.Error()) + r.ReportError("decode default textMarshalerCodec", err.Error()) } } @@ -365,7 +346,7 @@ func (d *sliceDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { d.typ.UnsafeGrow(ptr, size) for i := 0; i < size; i++ { elemPtr := d.typ.UnsafeGetIndex(ptr, i) - d.decoder(d.def[i]).Decode(elemPtr, nil) + d.decoder(d.def[i]).Decode(elemPtr, r) } } @@ -397,29 +378,19 @@ func (d *mapDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { key := k keyPtr := reflect2.PtrOf(&key) elemPtr := d.typ.UnsafeNew() - d.decoder(v).Decode(elemPtr, nil) + d.decoder(v).Decode(elemPtr, r) d.typ.UnsafeSetIndex(ptr, keyPtr, elemPtr) } } type fixedDefaultDecoder struct { typ reflect2.Type - def string + def []byte fixed *FixedSchema } func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - runes := []rune(d.def) - l := len(runes) - b := make([]byte, l) - for i := 0; i < l; i++ { - if runes[i] < 0 || runes[i] > 255 { - r.ReportError("decode default", "invalid default") - return - } - b[i] = uint8(runes[i]) - } - + l := len(d.def) switch d.typ.Kind() { case reflect.Array: arrayType := d.typ.(reflect2.ArrayType) @@ -432,7 +403,7 @@ func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { return } for i := 0; i < arrayType.Len(); i++ { - arrayType.UnsafeSetIndex(ptr, i, reflect2.PtrOf(b[i])) + arrayType.UnsafeSetIndex(ptr, i, reflect2.PtrOf(d.def[i])) } case reflect.Uint64: @@ -444,7 +415,7 @@ func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { r.ReportError("decode default", "invalid default") return } - *(*uint64)(ptr) = binary.BigEndian.Uint64(b) + *(*uint64)(ptr) = binary.BigEndian.Uint64(d.def) case reflect.Struct: ls := d.fixed.Logical() @@ -458,7 +429,7 @@ func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { r.ReportError("decode default", "invalid default") return } - *((*LogicalDuration)(ptr)) = durationFromBytes(b) + *((*LogicalDuration)(ptr)) = durationFromBytes(d.def) case typ1.ConvertibleTo(ratType) && ls.Type() == Decimal: dec := ls.(*DecimalLogicalSchema) @@ -466,7 +437,7 @@ func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { r.ReportError("decode default", "invalid default") return } - *((*big.Rat)(ptr)) = *ratFromBytes(b, dec.Scale()) + *((*big.Rat)(ptr)) = *ratFromBytes(d.def, dec.Scale()) default: r.ReportError("decode default", "unsupported type") } @@ -485,3 +456,67 @@ func durationFromBytes(b []byte) LogicalDuration { return duration } + +func defaultDecoderOfUnion(schema *UnionSchema, def any, typ reflect2.Type) ValDecoder { + return &unionDefaultDecoder{ + typ: typ, + def: def, + union: schema, + } +} + +type unionDefaultDecoder struct { + typ reflect2.Type + def any + union *UnionSchema +} + +func (d *unionDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + switch d.typ.Kind() { + case reflect.Map: + if d.typ.(reflect2.MapType).Key().Kind() != reflect.String || + d.typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + break + } + schema := d.union.Types()[0] + if schema.Type() == Null { + return + } + + mapType := d.typ.(*reflect2.UnsafeMapType) + if mapType.UnsafeIsNil(ptr) { + mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0)) + } + + key := schemaTypeName(schema) + keyPtr := reflect2.PtrOf(key) + elemPtr := mapType.Elem().UnsafeNew() + + decoder := createDefaultDecoder(r.cfg, d.union.Types()[0], d.def, mapType.Elem()) + decoder.Decode(elemPtr, r) + + mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) + + case reflect.Ptr: + if !d.union.Nullable() { + break + } + if d.union.Types()[0].Type() == Null { + *((*unsafe.Pointer)(ptr)) = nil + return + } + + decoder := createDefaultDecoder(r.cfg, d.union.Types()[0], d.def, d.typ.(*reflect2.UnsafePtrType).Elem()) + if *((*unsafe.Pointer)(ptr)) == nil { + newPtr := d.typ.UnsafeNew() + decoder.Decode(newPtr, r) + *((*unsafe.Pointer)(ptr)) = newPtr + return + } + decoder.Decode(*((*unsafe.Pointer)(ptr)), r) + + case reflect.Interface: + decoder := createDefaultDecoder(r.cfg, d.union.Types()[0], d.def, d.typ) + decoder.Decode(ptr, r) + } +} diff --git a/codec_default_test.go b/codec_default_internal_test.go similarity index 71% rename from codec_default_test.go rename to codec_default_internal_test.go index d64db27b..a68865e3 100644 --- a/codec_default_test.go +++ b/codec_default_internal_test.go @@ -1,16 +1,36 @@ -package avro_test +package avro import ( "bytes" + "errors" "math" "math/big" "testing" - "github.com/hamba/avro/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type testEnumTextUnmarshaler int + +func (m *testEnumTextUnmarshaler) UnmarshalText(data []byte) error { + switch string(data) { + case "foo": + *m = 0 + return nil + case "bar": + *m = 1 + return nil + default: + return errors.New("unknown symbol") + } +} + +func ConfigTeardown() { + // Reset the caches + DefaultConfig = Config{}.Freeze() +} + func TestDecoder_DefaultBool(t *testing.T) { defer ConfigTeardown() @@ -27,7 +47,7 @@ func TestDecoder_DefaultBool(t *testing.T) { // {"a": "foo"} data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -37,9 +57,9 @@ func TestDecoder_DefaultBool(t *testing.T) { }`) // hack: set field action to force decode default behavior - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) type TestRecord struct { A string `avro:"a"` @@ -58,7 +78,7 @@ func TestDecoder_DefaultInt(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -67,9 +87,9 @@ func TestDecoder_DefaultInt(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) type TestRecord struct { A string `avro:"a"` @@ -88,7 +108,7 @@ func TestDecoder_DefaultLong(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -97,7 +117,7 @@ func TestDecoder_DefaultLong(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -105,7 +125,7 @@ func TestDecoder_DefaultLong(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: 1000, A: "foo"}, got) @@ -116,7 +136,7 @@ func TestDecoder_DefaultFloat(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -125,7 +145,7 @@ func TestDecoder_DefaultFloat(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -133,7 +153,7 @@ func TestDecoder_DefaultFloat(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: 10.45, A: "foo"}, got) @@ -144,7 +164,7 @@ func TestDecoder_DefaultDouble(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -153,7 +173,7 @@ func TestDecoder_DefaultDouble(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -161,7 +181,7 @@ func TestDecoder_DefaultDouble(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: 10.45, A: "foo"}, got) @@ -172,7 +192,7 @@ func TestDecoder_DefaultBytes(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -181,7 +201,7 @@ func TestDecoder_DefaultBytes(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -189,7 +209,7 @@ func TestDecoder_DefaultBytes(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: []byte("value"), A: "foo"}, got) @@ -200,7 +220,7 @@ func TestDecoder_DefaultString(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -209,7 +229,7 @@ func TestDecoder_DefaultString(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -217,7 +237,7 @@ func TestDecoder_DefaultString(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: "value", A: "foo"}, got) @@ -228,7 +248,7 @@ func TestDecoder_DefaultEnum(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -245,7 +265,7 @@ func TestDecoder_DefaultEnum(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault t.Run("simple", func(t *testing.T) { type TestRecord struct { @@ -254,7 +274,7 @@ func TestDecoder_DefaultEnum(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: "bar", A: "foo"}, got) @@ -268,7 +288,7 @@ func TestDecoder_DefaultEnum(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: 1, A: "foo"}, got) @@ -281,7 +301,7 @@ func TestDecoder_DefaultEnum(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) var v testEnumTextUnmarshaler = 1 @@ -294,13 +314,13 @@ func TestDecoder_DefaultUnion(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - type TestRecord struct { - A string `avro:"a"` - B any `avro:"b"` - } - t.Run("null default", func(t *testing.T) { - schema := avro.MustParse(`{ + type TestRecord struct { + A string `avro:"a"` + B *string `avro:"b"` + } + + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -309,17 +329,22 @@ func TestDecoder_DefaultUnion(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: nil, A: "foo"}, got) }) t.Run("not null default", func(t *testing.T) { - schema := avro.MustParse(`{ + type TestRecord struct { + A string `avro:"a"` + B any `avro:"b"` + } + + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -328,14 +353,38 @@ func TestDecoder_DefaultUnion(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: "bar", A: "foo"}, got) }) + + t.Run("map receiver", func(t *testing.T) { + type TestRecord struct { + A string `avro:"a"` + B map[string]any `avro:"b"` + } + + schema := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": ["string", "long"], "default": "bar"} + ] + }`) + + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault + + var got TestRecord + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: map[string]any{"string": "bar"}, A: "foo"}, got) + }) } func TestDecoder_DefaultArray(t *testing.T) { @@ -343,7 +392,7 @@ func TestDecoder_DefaultArray(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -358,9 +407,9 @@ func TestDecoder_DefaultArray(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) type TestRecord struct { A string `avro:"a"` @@ -379,7 +428,7 @@ func TestDecoder_DefaultMap(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -394,9 +443,9 @@ func TestDecoder_DefaultMap(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) type TestRecord struct { A string `avro:"a"` @@ -415,7 +464,7 @@ func TestDecoder_DefaultRecord(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -435,10 +484,10 @@ func TestDecoder_DefaultRecord(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault t.Run("struct", func(t *testing.T) { - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) type subRecord struct { A string `avro:"a"` @@ -457,7 +506,7 @@ func TestDecoder_DefaultRecord(t *testing.T) { }) t.Run("map", func(t *testing.T) { - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) var got map[string]any err := dec.Decode(&got) @@ -472,7 +521,7 @@ func TestDecoder_DefaultRef(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} - _ = avro.MustParse(`{ + _ = MustParse(`{ "type": "record", "name": "test.embed", "fields" : [ @@ -480,7 +529,7 @@ func TestDecoder_DefaultRef(t *testing.T) { ] }`) - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -489,9 +538,9 @@ func TestDecoder_DefaultRef(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault - dec := avro.NewDecoderForSchema(schema, bytes.NewReader(data)) + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) var got map[string]any err := dec.Decode(&got) @@ -506,7 +555,7 @@ func TestDecoder_DefaultFixed(t *testing.T) { data := []byte{0x6, 0x66, 0x6f, 0x6f} t.Run("array", func(t *testing.T) { - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -523,7 +572,7 @@ func TestDecoder_DefaultFixed(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -531,14 +580,14 @@ func TestDecoder_DefaultFixed(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: [3]byte{'f', 'o', 'o'}, A: "foo"}, got) }) t.Run("uint64", func(t *testing.T) { - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -555,7 +604,7 @@ func TestDecoder_DefaultFixed(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -563,14 +612,14 @@ func TestDecoder_DefaultFixed(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, TestRecord{B: uint64(math.MaxUint64), A: "foo"}, got) }) t.Run("duration", func(t *testing.T) { - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -588,15 +637,15 @@ func TestDecoder_DefaultFixed(t *testing.T) { ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { - A string `avro:"a"` - B avro.LogicalDuration `avro:"b"` + A string `avro:"a"` + B LogicalDuration `avro:"b"` } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) @@ -607,7 +656,7 @@ func TestDecoder_DefaultFixed(t *testing.T) { }) t.Run("rat", func(t *testing.T) { - schema := avro.MustParse(`{ + schema := MustParse(`{ "type": "record", "name": "test", "fields" : [ @@ -626,7 +675,7 @@ func TestDecoder_DefaultFixed(t *testing.T) { } ] }`) - avro.SetFieldAction(schema.(*avro.RecordSchema).Fields()[1], avro.FieldSetDefault) + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { A string `avro:"a"` @@ -634,7 +683,7 @@ func TestDecoder_DefaultFixed(t *testing.T) { } var got TestRecord - err := avro.NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) require.NoError(t, err) assert.Equal(t, big.NewRat(1734, 5), &got.B) diff --git a/codec_dynamic.go b/codec_dynamic.go index 631b0d53..eaaf7de0 100644 --- a/codec_dynamic.go +++ b/codec_dynamic.go @@ -1,7 +1,10 @@ package avro import ( + "fmt" + "math/big" "reflect" + "time" "unsafe" "github.com/modern-go/reflect2" @@ -15,13 +18,27 @@ func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { pObj := (*any)(ptr) obj := *pObj if obj == nil { - *pObj = r.ReadNext(d.schema) + rPtr, rtyp, err := dynamicReceiver(d.schema, r.cfg.resolver) + if err != nil { + r.ReportError("Read", err.Error()) + return + } + decoderOfType(r.cfg, d.schema, rtyp).Decode(rPtr, r) + *pObj = rtyp.UnsafeIndirect(rPtr) + // *pObj = r.ReadNext(d.schema) return } typ := reflect2.TypeOf(obj) if typ.Kind() != reflect.Ptr { - *pObj = r.ReadNext(d.schema) + rPtr, rTyp, err := dynamicReceiver(d.schema, r.cfg.resolver) + if err != nil { + r.ReportError("Read", err.Error()) + return + } + decoderOfType(r.cfg, d.schema, rTyp).Decode(rPtr, r) + *pObj = rTyp.UnsafeIndirect(rPtr) + // *pObj = r.ReadNext(d.schema) return } @@ -45,3 +62,111 @@ func (e *interfaceEncoder) Encode(ptr unsafe.Pointer, w *Writer) { obj := e.typ.UnsafeIndirect(ptr) w.WriteVal(e.schema, obj) } + +func dynamicReceiver(schema Schema, resolver *TypeResolver) (unsafe.Pointer, reflect2.Type, error) { + var ls LogicalSchema + lts, ok := schema.(LogicalTypeSchema) + if ok { + ls = lts.Logical() + } + + name := string(schema.Type()) + if ls != nil { + name += "." + string(ls.Type()) + } + if resolver != nil { + typ, err := resolver.Type(name) + if err == nil { + return typ.UnsafeNew(), typ, nil + } + } + + switch schema.Type() { + case Boolean: + var v bool + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Int: + if ls != nil { + switch ls.Type() { + case Date: + var v time.Time + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + + case TimeMillis: + var v time.Duration + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + } + var v int + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Long: + if ls != nil { + switch ls.Type() { + case TimeMicros: + var v time.Duration + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + + case TimestampMillis: + var v time.Time + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + + case TimestampMicros: + var v time.Time + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + } + var v int64 + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Float: + var v float32 + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Double: + var v float64 + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case String: + var v string + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Bytes: + if ls != nil && ls.Type() == Decimal { + var v *big.Rat + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + var v []byte + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Record: + var v map[string]any + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Ref: + return dynamicReceiver(schema.(*RefSchema).Schema(), resolver) + case Enum: + var v string + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Array: + v := make([]any, 0) + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Map: + var v map[string]any + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Union: + var v map[string]any + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Fixed: + fixed := schema.(*FixedSchema) + ls := fixed.Logical() + if ls != nil { + switch ls.Type() { + case Duration: + var v LogicalDuration + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Decimal: + var v big.Rat + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + } + // note that uint64 case is not supported, due to the lack of indicator at the schema-level (logical type) + var v []byte + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + default: + return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name) + } +} diff --git a/codec_native.go b/codec_native.go index bf2d2c96..04cdafed 100644 --- a/codec_native.go +++ b/codec_native.go @@ -230,7 +230,6 @@ func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder { return &timeMillisCodec{} case st == Long && lt == TimeMicros: // time.Duration - return &timeMicrosCodec{} case st == Long: diff --git a/codec_union.go b/codec_union.go index d07cbf0a..b8e81748 100644 --- a/codec_union.go +++ b/codec_union.go @@ -294,7 +294,17 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { // We cannot resolve this, set it to the map type name := schemaTypeName(schema) obj := map[string]any{} - obj[name] = r.ReadNext(schema) + + rPtr, rTyp, err := dynamicReceiver(schema, r.cfg.resolver) + if err != nil { + r.ReportError("Read", err.Error()) + return + } + decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r) + + obj[name] = rTyp.UnsafeIndirect(rPtr) + + // obj[name] = r.ReadNext(schema) *pObj = obj return diff --git a/converter.go b/converter.go index fa4b47a7..99dcd05f 100644 --- a/converter.go +++ b/converter.go @@ -19,13 +19,11 @@ type converter struct { // thus, the downstream caller must first check the converter function value. func resolveConverter(typ Type) converter { cv := converter{} - cv.toLong, _ = createLongConverter(typ) cv.toFloat, _ = createFloatConverter(typ) cv.toDouble, _ = createDoubleConverter(typ) cv.toString, _ = createStringConverter(typ) cv.toBytes, _ = createBytesConverter(typ) - return cv } diff --git a/decoder_dynamic_bench_test.go b/decoder_dynamic_bench_test.go new file mode 100644 index 00000000..da216ea9 --- /dev/null +++ b/decoder_dynamic_bench_test.go @@ -0,0 +1,60 @@ +package avro_test + +import ( + "bytes" + "testing" + + "github.com/hamba/avro/v2" +) + +func BenchmarkDecoder_Interface(b *testing.B) { + tests := []struct { + name string + data []byte + schema string + got any + want any + }{ + { + name: "Empty Interface", + data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, + schema: `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, + got: nil, + want: map[string]any{"a": int64(27), "b": "foo"}, + }, + { + name: "Interface Non-Ptr", + data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, + schema: `{"type": "record", "name": "test", "fields": [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, + got: TestRecord{}, + want: map[string]any{"a": int64(27), "b": "foo"}, + }, + { + name: "Interface Nil Ptr", + data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, + schema: `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, + got: (*TestRecord)(nil), + want: &TestRecord{A: 27, B: "foo"}, + }, + { + name: "Interface Ptr", + data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, + schema: `{"type": "record", "name": "test", "fields": [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, + got: &TestRecord{}, + want: &TestRecord{A: 27, B: "foo"}, + }, + } + + for _, test := range tests { + test := test + b.Run(test.name, func(b *testing.B) { + defer ConfigTeardown() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + dec, _ := avro.NewDecoder(test.schema, bytes.NewReader(test.data)) + _ = dec.Decode(&test.got) + } + }) + } +} diff --git a/reader_promoter.go b/reader_promoter.go index b5575eed..2a6ff3bf 100644 --- a/reader_promoter.go +++ b/reader_promoter.go @@ -44,7 +44,8 @@ func (p *readerPromoter) ReadFloat() float32 { func (p *readerPromoter) ReadDouble() float64 { if p.toDouble != nil { - return p.toDouble(p.r) + v := p.toDouble(p.r) + return v } return p.r.ReadDouble() diff --git a/schema.go b/schema.go index 7506975b..1364ce83 100644 --- a/schema.go +++ b/schema.go @@ -479,9 +479,22 @@ func (s *PrimitiveSchema) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } +// Temporary HACK to allow testing schema resolution logic... +// a better solution would be to extend decoder cache key. +type primitiveSchemaFingerprint struct { + s *PrimitiveSchema +} + +func (sfp *primitiveSchemaFingerprint) String() string { + if sfp.s.actual == "" { + return sfp.s.String() + } + return sfp.s.String() + ":" + string(sfp.s.actual) +} + // Fingerprint returns the SHA256 fingerprint of the schema. func (s *PrimitiveSchema) Fingerprint() [32]byte { - return s.fingerprinter.Fingerprint(s) + return s.fingerprinter.Fingerprint(&primitiveSchemaFingerprint{s: s}) } // FingerprintUsing returns the fingerprint of the schema using the given algorithm or an error. @@ -489,6 +502,12 @@ func (s *PrimitiveSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) return s.fingerprinter.FingerprintUsing(typ, s) } +// Actual returns the actual type of the schema. +// This field is only presents during write-read schema resolution. +func (s *PrimitiveSchema) Actual() Type { + return s.actual +} + // RecordSchema is an Avro record type schema. type RecordSchema struct { name @@ -677,11 +696,6 @@ func NewField(name string, typ Schema, opts ...SchemaOption) (*Field, error) { return f, nil } -// SetFieldAction updates the given field's action. Mainly used for testing purposes. -func SetFieldAction(field *Field, action Action) { - field.action = action -} - // Name returns the name of a field. func (f *Field) Name() string { return f.name @@ -1408,10 +1422,18 @@ func isValidDefault(schema Schema, def any) (any, bool) { } } return def, found - case String, Bytes, Fixed: + case String: if _, ok := def.(string); ok { return def, true } + case Bytes, Fixed: + // Spec: Default values for bytes and fixed fields are JSON strings, + // where Unicode code points 0-255 are mapped to unsigned 8-bit byte values 0-255. + if d, ok := def.(string); ok { + if b, ok := isValidDefaultBytes(d); ok { + return b, true + } + } case Boolean: if _, ok := def.(bool); ok { return def, true @@ -1523,3 +1545,16 @@ func schemaTypeName(schema Schema) string { } return sname } + +func isValidDefaultBytes(def string) ([]byte, bool) { + runes := []rune(def) + l := len(runes) + b := make([]byte, l) + for i := 0; i < l; i++ { + if runes[i] < 0 || runes[i] > 255 { + return nil, false + } + b[i] = byte(runes[i]) + } + return b, true +} diff --git a/schema_internal_test.go b/schema_internal_test.go index 67dec6b5..2104cf81 100644 --- a/schema_internal_test.go +++ b/schema_internal_test.go @@ -144,7 +144,7 @@ func TestIsValidDefault(t *testing.T) { return NewPrimitiveSchema(Bytes, nil) }, def: "test", - want: "test", + want: []byte("test"), wantOk: true, }, { @@ -199,7 +199,7 @@ func TestIsValidDefault(t *testing.T) { return s }, def: "test", - want: "test", + want: []byte("test"), wantOk: true, }, { From 9d6c104a7480108cbf482ac4a4ad642eb571cc0a Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Mon, 11 Dec 2023 19:44:48 +0100 Subject: [PATCH 07/25] fix(schema compatibility): support named schema aliases --- schema.go | 6 ++++++ schema_compatibility.go | 7 ++++++- schema_internal_test.go | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/schema.go b/schema.go index 1364ce83..db8dc411 100644 --- a/schema.go +++ b/schema.go @@ -196,6 +196,9 @@ type NamedSchema interface { // FullName returns the full qualified name of a schema. FullName() string + + // Aliases returns the full qualified aliases of a schema. + Aliases() []string } // LogicalTypeSchema represents a schema that can contain a logical type. @@ -1431,6 +1434,9 @@ func isValidDefault(schema Schema, def any) (any, bool) { // where Unicode code points 0-255 are mapped to unsigned 8-bit byte values 0-255. if d, ok := def.(string); ok { if b, ok := isValidDefaultBytes(d); ok { + if schema.Type() == Fixed { + return byteSliceToArray(b, schema.(*FixedSchema).Size()), true + } return b, true } } diff --git a/schema_compatibility.go b/schema_compatibility.go index fc753363..e3ff509a 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -177,6 +177,11 @@ func (c *SchemaCompatibility) match(reader, writer Schema) error { func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error { if reader.FullName() != writer.FullName() { + for _, alias := range reader.Aliases() { + if alias == writer.FullName() { + return nil + } + } return fmt.Errorf("reader schema %s and writer schema %s names do not match", reader.FullName(), writer.FullName()) } @@ -409,5 +414,5 @@ func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, erro fields = append(fields, &f) } - return NewRecordSchema(r.Name(), r.Namespace(), fields) + return NewRecordSchema(r.Name(), r.Namespace(), fields, WithAliases(r.Aliases())) } diff --git a/schema_internal_test.go b/schema_internal_test.go index 2104cf81..7872eaf7 100644 --- a/schema_internal_test.go +++ b/schema_internal_test.go @@ -195,11 +195,11 @@ func TestIsValidDefault(t *testing.T) { { name: "Fixed", schemaFn: func() Schema { - s, _ := NewFixedSchema("foo", "", 1, nil) + s, _ := NewFixedSchema("foo", "", 4, nil) return s }, def: "test", - want: []byte("test"), + want: [4]byte{'t', 'e', 's', 't'}, wantOk: true, }, { From 9b51d57ffae1b8836b972ae96ab27059f8e6a66f Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Mon, 11 Dec 2023 19:54:10 +0100 Subject: [PATCH 08/25] fix(default decoder): try different implementation --- codec_default.go | 532 +++-------------------------------------------- codec_record.go | 27 ++- 2 files changed, 49 insertions(+), 510 deletions(-) diff --git a/codec_default.go b/codec_default.go index dff030d8..946d0ee7 100644 --- a/codec_default.go +++ b/codec_default.go @@ -1,522 +1,46 @@ package avro import ( - "encoding" - "encoding/binary" - "fmt" - "math/big" - "reflect" "unsafe" "github.com/modern-go/reflect2" ) -func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { - if typ.Kind() == reflect.Interface { - if schema.Type() != Union && schema.Type() != Null { - return &efaceDefaultDecoder{def: def, schema: schema} - } +func createDefaultDecoder( + cfg *frozenConfig, + schema Schema, + def any, + typ reflect2.Type, + w *Writer, + r *Reader, +) ValDecoder { + defaultType := reflect2.TypeOf(def) + var defaultEncoder ValEncoder + // tmp workaround: codec_union failed to resolve name of struct{} typ + if def == nullDefault { + defaultEncoder = &nullCodec{} + } else { + defaultEncoder = encoderOfType(cfg, schema, defaultType) } - - switch schema.Type() { - case Null: - return &nullDefaultDecoder{ - typ: typ, - } - case Boolean: - return &boolDefaultDecoder{ - def: def.(bool), - } - case Int: - return &intDefaultDecoder{ - def: def.(int), - typ: typ, - } - case Long: - return &longDefaultDecoder{ - def: def.(int64), - typ: typ, - } - case Float: - return &floatDefaultDecoder{ - def: def.(float32), - typ: typ, - } - case Double: - return &doubleDefaultDecoder{ - def: def.(float64), - typ: typ, - } - case String: - if typ.Implements(textUnmarshalerType) { - return &textDefaultMarshalerCodec{typ, def.(string)} - } - ptrType := reflect2.PtrTo(typ) - if ptrType.Implements(textUnmarshalerType) { - return &referenceDecoder{ - &textDefaultMarshalerCodec{typ: ptrType, def: def.(string)}, - } - } - return &stringDefaultDecoder{ - def: def.(string), - } - case Bytes: - return &bytesDefaultDecoder{ - def: def.([]byte), - typ: typ, - } - case Fixed: - return &fixedDefaultDecoder{ - fixed: schema.(*FixedSchema), - def: def.([]byte), - typ: typ, - } - case Enum: - return &enumDefaultDecoder{typ: typ, def: def.(string)} - case Ref: - return createDefaultDecoder(cfg, schema.(*RefSchema).Schema(), def, typ) - case Record: - return defaultDecoderOfRecord(cfg, schema, def, typ) - case Array: - return defaultDecoderOfArray(cfg, schema, def, typ) - case Map: - return defaultDecoderOfMap(cfg, schema, def, typ) - case Union: - return defaultDecoderOfUnion(schema.(*UnionSchema), def, typ) - default: - return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())} - } -} - -type textDefaultMarshalerCodec struct { - typ reflect2.Type - def string -} - -func (d textDefaultMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { - obj := d.typ.UnsafeIndirect(ptr) - if reflect2.IsNil(obj) { - ptrType := d.typ.(*reflect2.UnsafePtrType) - newPtr := ptrType.Elem().UnsafeNew() - *((*unsafe.Pointer)(ptr)) = newPtr - obj = d.typ.UnsafeIndirect(ptr) - } - unmarshaler := (obj).(encoding.TextUnmarshaler) - - b := []byte(d.def) - - err := unmarshaler.UnmarshalText(b) - if err != nil { - r.ReportError("decode default textMarshalerCodec", err.Error()) - } -} - -type efaceDefaultDecoder struct { - def any - schema Schema -} - -func (d *efaceDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - rPtr, rTyp, err := dynamicReceiver(d.schema, r.cfg.resolver) - if err != nil { - r.ReportError("decode default", err.Error()) - return - } - - createDefaultDecoder(r.cfg, d.schema, d.def, rTyp).Decode(rPtr, r) - - *(*any)(ptr) = rTyp.UnsafeIndirect(rPtr) -} - -type nullDefaultDecoder struct { - typ reflect2.Type -} - -func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { - *((*unsafe.Pointer)(ptr)) = nil -} - -type boolDefaultDecoder struct { - def bool -} - -func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { - *((*bool)(ptr)) = d.def -} - -type intDefaultDecoder struct { - def int - typ reflect2.Type -} - -func (d *intDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - switch d.typ.Kind() { - case reflect.Int: - *((*int)(ptr)) = d.def - case reflect.Uint: - *((*uint)(ptr)) = uint(d.def) - case reflect.Int8: - *((*int8)(ptr)) = int8(d.def) - case reflect.Uint8: - *((*uint8)(ptr)) = uint8(d.def) - case reflect.Int16: - *((*int16)(ptr)) = int16(d.def) - case reflect.Uint16: - *((*uint16)(ptr)) = uint16(d.def) - case reflect.Int32: - *((*int32)(ptr)) = int32(d.def) - default: - r.ReportError("decode default", "unsupported type") - } -} - -type longDefaultDecoder struct { - def int64 - typ reflect2.Type -} - -func (d *longDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - switch d.typ.Kind() { - case reflect.Int32: - *((*int32)(ptr)) = int32(d.def) - case reflect.Uint32: - *((*uint32)(ptr)) = uint32(d.def) - case reflect.Int64: - *((*int64)(ptr)) = d.def - default: - r.ReportError("decode default", "unsupported type") - } -} - -type floatDefaultDecoder struct { - def float32 - typ reflect2.Type -} - -func (d *floatDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - switch d.typ.Kind() { - case reflect.Float32: - *((*float32)(ptr)) = d.def - case reflect.Float64: - *((*float64)(ptr)) = float64(d.def) - default: - r.ReportError("decode default", "unsupported type") - } -} - -type doubleDefaultDecoder struct { - def float64 - typ reflect2.Type -} - -func (d *doubleDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - switch d.typ.Kind() { - case reflect.Float64: - *((*float64)(ptr)) = d.def - case reflect.Float32: - *((*float32)(ptr)) = float32(d.def) - default: - r.ReportError("decode default", "unsupported type") - } -} - -type stringDefaultDecoder struct { - def string -} - -func (d *stringDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { - *((*string)(ptr)) = d.def -} - -type bytesDefaultDecoder struct { - def []byte - typ reflect2.Type -} - -func (d *bytesDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - if d.typ.Kind() != reflect.Slice { - r.ReportError("decode default", "inconvertible type") - return - } - if d.typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 { - r.ReportError("decode default", "inconvertible type") - return - } - - d.typ.(*reflect2.UnsafeSliceType).UnsafeSet(ptr, reflect2.PtrOf(d.def)) -} - -func defaultDecoderOfRecord(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { - rec := schema.(*RecordSchema) - mDef, ok := def.(map[string]any) - if !ok { - return &errorDecoder{err: fmt.Errorf("avro: invalid default for record field")} - } - - fields := make([]*Field, len(rec.Fields())) - for i, field := range rec.Fields() { - f, err := NewField(field.Name(), field.Type(), - WithDefault(mDef[field.Name()]), WithAliases(field.Aliases()), WithOrder(field.Order()), - ) - if err != nil { - return &errorDecoder{err: fmt.Errorf("avro: %w", err)} - } - f.action = FieldSetDefault - fields[i] = f - } - - r, err := NewRecordSchema(rec.Name(), rec.Namespace(), fields, WithAliases(rec.Aliases())) - if err != nil { - return &errorDecoder{err: fmt.Errorf("avro: %w", err)} + if defaultType.LikePtr() { + defaultEncoder = &onePtrEncoder{defaultEncoder} } + defaultEncoder.Encode(reflect2.PtrOf(def), w) - switch typ.Kind() { - case reflect.Struct: - return decoderOfStruct(cfg, r, typ) - case reflect.Map: - return decoderOfRecord(cfg, r, typ) + return &defaultDecoder{ + defaultReader: r, + decoder: decoderOfType(cfg, schema, typ), } - - return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -type enumDefaultDecoder struct { - typ reflect2.Type - def string +type defaultDecoder struct { + defaultReader *Reader + decoder ValDecoder } -func (d *enumDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - unmarshal := func(def string, isPtr bool) { - var obj any - if isPtr { - obj = d.typ.PackEFace(ptr) - } else { - obj = d.typ.UnsafeIndirect(ptr) - } - if reflect2.IsNil(obj) { - ptrType := d.typ.(*reflect2.UnsafePtrType) - newPtr := ptrType.Elem().UnsafeNew() - *((*unsafe.Pointer)(ptr)) = newPtr - obj = d.typ.UnsafeIndirect(ptr) - } - unmarshaler := (obj).(encoding.TextUnmarshaler) - err := unmarshaler.UnmarshalText([]byte(def)) - if err != nil { - r.ReportError("decode default textMarshalerCodec", err.Error()) - } - } - - switch { - case d.typ.Kind() == reflect.String: - *((*string)(ptr)) = d.def - return - case reflect2.PtrTo(d.typ).Implements(textUnmarshalerType): - unmarshal(d.def, true) - return - case d.typ.Implements(textUnmarshalerType): - unmarshal(d.def, false) - return - default: - r.ReportError("decode default", "unsupported type") - } +// Decode implements ValDecoder. +func (d *defaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { + d.decoder.Decode(ptr, d.defaultReader) } -func defaultDecoderOfArray(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { - if typ.Kind() != reflect.Slice { - return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} - } - - return &sliceDefaultDecoder{ - def: def.([]any), - typ: typ.(*reflect2.UnsafeSliceType), - decoder: func(def any) ValDecoder { - return createDefaultDecoder(cfg, schema.(*ArraySchema).Items(), def, typ.(*reflect2.UnsafeSliceType).Elem()) - }, - } -} - -type sliceDefaultDecoder struct { - def []any - typ *reflect2.UnsafeSliceType - decoder func(def any) ValDecoder -} - -func (d *sliceDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - size := len(d.def) - d.typ.UnsafeGrow(ptr, size) - for i := 0; i < size; i++ { - elemPtr := d.typ.UnsafeGetIndex(ptr, i) - d.decoder(d.def[i]).Decode(elemPtr, r) - } -} - -func defaultDecoderOfMap(cfg *frozenConfig, schema Schema, def any, typ reflect2.Type) ValDecoder { - if typ.Kind() != reflect.Map { - return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} - } - - return &mapDefaultDecoder{ - typ: typ.(*reflect2.UnsafeMapType), - def: def.(map[string]any), - decoder: func(def any) ValDecoder { - return createDefaultDecoder(cfg, schema.(*MapSchema).Values(), def, typ.(*reflect2.UnsafeMapType).Elem()) - }, - } -} - -type mapDefaultDecoder struct { - typ *reflect2.UnsafeMapType - decoder func(def any) ValDecoder - def map[string]any -} - -func (d *mapDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - if d.typ.UnsafeIsNil(ptr) { - d.typ.UnsafeSet(ptr, d.typ.UnsafeMakeMap(0)) - } - for k, v := range d.def { - key := k - keyPtr := reflect2.PtrOf(&key) - elemPtr := d.typ.UnsafeNew() - d.decoder(v).Decode(elemPtr, r) - d.typ.UnsafeSetIndex(ptr, keyPtr, elemPtr) - } -} - -type fixedDefaultDecoder struct { - typ reflect2.Type - def []byte - fixed *FixedSchema -} - -func (d *fixedDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - l := len(d.def) - switch d.typ.Kind() { - case reflect.Array: - arrayType := d.typ.(reflect2.ArrayType) - if arrayType.Elem().Kind() != reflect.Uint8 || arrayType.Len() != d.fixed.Size() { - r.ReportError("decode default", "unsupported type") - return - } - if arrayType.Len() != l { - r.ReportError("decode default", "invalid default") - return - } - for i := 0; i < arrayType.Len(); i++ { - arrayType.UnsafeSetIndex(ptr, i, reflect2.PtrOf(d.def[i])) - } - - case reflect.Uint64: - if d.fixed.Size() != 8 { - r.ReportError("decode default", "unsupported type") - return - } - if l != 8 { - r.ReportError("decode default", "invalid default") - return - } - *(*uint64)(ptr) = binary.BigEndian.Uint64(d.def) - - case reflect.Struct: - ls := d.fixed.Logical() - if ls == nil { - break - } - typ1 := d.typ.Type1() - switch { - case typ1.ConvertibleTo(durType) && ls.Type() == Duration: - if l != 12 { - r.ReportError("decode default", "invalid default") - return - } - *((*LogicalDuration)(ptr)) = durationFromBytes(d.def) - - case typ1.ConvertibleTo(ratType) && ls.Type() == Decimal: - dec := ls.(*DecimalLogicalSchema) - if d.fixed.Size() != l { - r.ReportError("decode default", "invalid default") - return - } - *((*big.Rat)(ptr)) = *ratFromBytes(d.def, dec.Scale()) - default: - r.ReportError("decode default", "unsupported type") - } - - default: - r.ReportError("decode default", "unsupported type") - } -} - -func durationFromBytes(b []byte) LogicalDuration { - var duration LogicalDuration - - duration.Months = binary.LittleEndian.Uint32(b[0:4]) - duration.Days = binary.LittleEndian.Uint32(b[4:8]) - duration.Milliseconds = binary.LittleEndian.Uint32(b[8:12]) - - return duration -} - -func defaultDecoderOfUnion(schema *UnionSchema, def any, typ reflect2.Type) ValDecoder { - return &unionDefaultDecoder{ - typ: typ, - def: def, - union: schema, - } -} - -type unionDefaultDecoder struct { - typ reflect2.Type - def any - union *UnionSchema -} - -func (d *unionDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - switch d.typ.Kind() { - case reflect.Map: - if d.typ.(reflect2.MapType).Key().Kind() != reflect.String || - d.typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { - break - } - schema := d.union.Types()[0] - if schema.Type() == Null { - return - } - - mapType := d.typ.(*reflect2.UnsafeMapType) - if mapType.UnsafeIsNil(ptr) { - mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0)) - } - - key := schemaTypeName(schema) - keyPtr := reflect2.PtrOf(key) - elemPtr := mapType.Elem().UnsafeNew() - - decoder := createDefaultDecoder(r.cfg, d.union.Types()[0], d.def, mapType.Elem()) - decoder.Decode(elemPtr, r) - - mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) - - case reflect.Ptr: - if !d.union.Nullable() { - break - } - if d.union.Types()[0].Type() == Null { - *((*unsafe.Pointer)(ptr)) = nil - return - } - - decoder := createDefaultDecoder(r.cfg, d.union.Types()[0], d.def, d.typ.(*reflect2.UnsafePtrType).Elem()) - if *((*unsafe.Pointer)(ptr)) == nil { - newPtr := d.typ.UnsafeNew() - decoder.Decode(newPtr, r) - *((*unsafe.Pointer)(ptr)) = newPtr - return - } - decoder.Decode(*((*unsafe.Pointer)(ptr)), r) - - case reflect.Interface: - decoder := createDefaultDecoder(r.cfg, d.union.Types()[0], d.def, d.typ) - decoder.Decode(ptr, r) - } -} +var _ ValDecoder = &defaultDecoder{} diff --git a/codec_record.go b/codec_record.go index 7e04a6f0..47403083 100644 --- a/codec_record.go +++ b/codec_record.go @@ -1,6 +1,7 @@ package avro import ( + "bytes" "errors" "fmt" "io" @@ -58,6 +59,12 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec structDesc := describeStruct(cfg.getTagKey(), typ) fields := make([]*structFieldDecoder, 0, len(rec.Fields())) + + // TBD figure out how to cache record defaults binary + buf := bytes.NewBuffer([]byte{}) + defW := NewWriter(buf, 512, WithWriterConfig(cfg)) + defR := NewReader(buf, 512, WithReaderConfig(cfg)) + for _, field := range rec.Fields() { if field.action == FieldDrain { fields = append(fields, &structFieldDecoder{ @@ -88,11 +95,7 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec if field.hasDef { fields = append(fields, &structFieldDecoder{ field: sf.Field, - decoder: createDefaultDecoder(cfg, field.Type(), field.def, sf.Field[len(sf.Field)-1].Type()), - }) - } else { - fields = append(fields, &structFieldDecoder{ - decoder: createSkipDecoder(field.Type()), + decoder: createDefaultDecoder(cfg, field.Type(), field.def, sf.Field[len(sf.Field)-1].Type(), defW, defR), }) } @@ -106,6 +109,10 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec }) } + if err := defW.Flush(); err != nil { + return &errorDecoder{err: fmt.Errorf("decode default: %w", err)} + } + return &structDecoder{typ: typ, fields: fields} } @@ -257,6 +264,10 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec rec := schema.(*RecordSchema) mapType := typ.(*reflect2.UnsafeMapType) + buf := bytes.NewBuffer([]byte{}) + defW := NewWriter(buf, 512, WithWriterConfig(cfg)) + defR := NewReader(buf, 512, WithReaderConfig(cfg)) + fields := make([]recordMapDecoderField, len(rec.Fields())) for i, field := range rec.Fields() { if field.action == FieldDrain { @@ -272,7 +283,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec if field.hasDef { fields[i] = recordMapDecoderField{ name: field.Name(), - decoder: createDefaultDecoder(cfg, field.Type(), field.def, mapType.Elem()), + decoder: createDefaultDecoder(cfg, field.Type(), field.def, mapType.Elem(), defW, defR), } } @@ -285,6 +296,10 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec } } + if err := defW.Flush(); err != nil { + return &errorDecoder{err: fmt.Errorf("decode default: %w", err)} + } + return &recordMapDecoder{ mapType: mapType, elemType: mapType.Elem(), From eaef2adc8b9394b19c283f52551547e4e3c83d86 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Tue, 12 Dec 2023 19:02:15 +0100 Subject: [PATCH 09/25] fix(defaut decoder): cache field encoded default and borrow reader --- codec_default.go | 52 +++++++++++++++++++++++++++++------------------- codec_record.go | 22 ++------------------ schema.go | 28 ++++++++++++++++++++++++-- 3 files changed, 59 insertions(+), 43 deletions(-) diff --git a/codec_default.go b/codec_default.go index 946d0ee7..35b11cdc 100644 --- a/codec_default.go +++ b/codec_default.go @@ -1,35 +1,45 @@ package avro import ( + "fmt" "unsafe" "github.com/modern-go/reflect2" ) -func createDefaultDecoder( - cfg *frozenConfig, - schema Schema, - def any, - typ reflect2.Type, - w *Writer, - r *Reader, -) ValDecoder { - defaultType := reflect2.TypeOf(def) - var defaultEncoder ValEncoder - // tmp workaround: codec_union failed to resolve name of struct{} typ - if def == nullDefault { - defaultEncoder = &nullCodec{} - } else { - defaultEncoder = encoderOfType(cfg, schema, defaultType) - } - if defaultType.LikePtr() { - defaultEncoder = &onePtrEncoder{defaultEncoder} +func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) ValDecoder { + fn := func(def any) ([]byte, error) { + defaultType := reflect2.TypeOf(def) + var defaultEncoder ValEncoder + // tmp workaround: codec_union failed to resolve name of struct{} typ + if def == nullDefault { + defaultEncoder = &nullCodec{} + } else { + defaultEncoder = encoderOfType(cfg, field.Type(), defaultType) + } + if defaultType.LikePtr() { + defaultEncoder = &onePtrEncoder{defaultEncoder} + } + + w := cfg.borrowWriter() + defaultEncoder.Encode(reflect2.PtrOf(def), w) + if w.Error != nil { + return nil, w.Error + } + if err := w.Flush(); err != nil { + return nil, err + } + + return w.Buffer(), nil } - defaultEncoder.Encode(reflect2.PtrOf(def), w) + b, err := field.encodeDefault(fn) + if err != nil { + return &errorDecoder{err: fmt.Errorf("decode default: %w", err)} + } return &defaultDecoder{ - defaultReader: r, - decoder: decoderOfType(cfg, schema, typ), + defaultReader: cfg.borrowReader(b), + decoder: decoderOfType(cfg, field.Type(), typ), } } diff --git a/codec_record.go b/codec_record.go index 47403083..5dafca92 100644 --- a/codec_record.go +++ b/codec_record.go @@ -1,7 +1,6 @@ package avro import ( - "bytes" "errors" "fmt" "io" @@ -60,11 +59,6 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields := make([]*structFieldDecoder, 0, len(rec.Fields())) - // TBD figure out how to cache record defaults binary - buf := bytes.NewBuffer([]byte{}) - defW := NewWriter(buf, 512, WithWriterConfig(cfg)) - defR := NewReader(buf, 512, WithReaderConfig(cfg)) - for _, field := range rec.Fields() { if field.action == FieldDrain { fields = append(fields, &structFieldDecoder{ @@ -95,7 +89,7 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec if field.hasDef { fields = append(fields, &structFieldDecoder{ field: sf.Field, - decoder: createDefaultDecoder(cfg, field.Type(), field.def, sf.Field[len(sf.Field)-1].Type(), defW, defR), + decoder: createDefaultDecoder(cfg, field, sf.Field[len(sf.Field)-1].Type()), }) } @@ -109,10 +103,6 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec }) } - if err := defW.Flush(); err != nil { - return &errorDecoder{err: fmt.Errorf("decode default: %w", err)} - } - return &structDecoder{typ: typ, fields: fields} } @@ -264,10 +254,6 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec rec := schema.(*RecordSchema) mapType := typ.(*reflect2.UnsafeMapType) - buf := bytes.NewBuffer([]byte{}) - defW := NewWriter(buf, 512, WithWriterConfig(cfg)) - defR := NewReader(buf, 512, WithReaderConfig(cfg)) - fields := make([]recordMapDecoderField, len(rec.Fields())) for i, field := range rec.Fields() { if field.action == FieldDrain { @@ -283,7 +269,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec if field.hasDef { fields[i] = recordMapDecoderField{ name: field.Name(), - decoder: createDefaultDecoder(cfg, field.Type(), field.def, mapType.Elem(), defW, defR), + decoder: createDefaultDecoder(cfg, field, mapType.Elem()), } } @@ -296,10 +282,6 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec } } - if err := defW.Flush(); err != nil { - return &errorDecoder{err: fmt.Errorf("decode default: %w", err)} - } - return &recordMapDecoder{ mapType: mapType, elemType: mapType.Elem(), diff --git a/schema.go b/schema.go index db8dc411..a0ebe17c 100644 --- a/schema.go +++ b/schema.go @@ -646,7 +646,12 @@ type Field struct { hasDef bool def any order Order - action Action + + // action mainly used when decoding data that lack the field for schema evolution purposes. + action Action + // encodedDef mainly used when decoding data that lack the field for schema evolution purposes. + // Its value remains empty unless the field's encodeDefault function is called. + encodedDef []byte } type noDef struct{} @@ -735,6 +740,25 @@ func (f *Field) Default() any { return f.def } +func (f *Field) encodeDefault(encode func(any) ([]byte, error)) ([]byte, error) { + if f.encodedDef != nil { + return f.encodedDef, nil + } + if !f.hasDef { + return nil, fmt.Errorf("avro: '%s' field must have a non-empty default value", f.name) + } + if encode == nil { + return nil, fmt.Errorf("avro: failed to encode '%s' default value", f.name) + } + b, err := encode(f.def) + if err != nil { + return nil, err + } + f.encodedDef = b + + return f.encodedDef, nil +} + // Doc returns the documentation of a field. func (f *Field) Doc() string { return f.doc @@ -818,7 +842,7 @@ func NewEnumSchema(name, namespace string, symbols []string, opts ...SchemaOptio } for _, sym := range symbols { if err = validateName(sym); err != nil { - return nil, fmt.Errorf("avro: invalid symnol %q", sym) + return nil, fmt.Errorf("avro: invalid symbol %q", sym) } } From b4224dc71d9b5e7cb6e5c6f926b84a569fe24280 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Tue, 12 Dec 2023 19:07:34 +0100 Subject: [PATCH 10/25] fix: resolver unable to resolve nullDefault type --- codec_default.go | 10 +--------- resolver.go | 1 + schema.go | 2 +- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/codec_default.go b/codec_default.go index 35b11cdc..9063592d 100644 --- a/codec_default.go +++ b/codec_default.go @@ -10,17 +10,10 @@ import ( func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) ValDecoder { fn := func(def any) ([]byte, error) { defaultType := reflect2.TypeOf(def) - var defaultEncoder ValEncoder - // tmp workaround: codec_union failed to resolve name of struct{} typ - if def == nullDefault { - defaultEncoder = &nullCodec{} - } else { - defaultEncoder = encoderOfType(cfg, field.Type(), defaultType) - } + defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) if defaultType.LikePtr() { defaultEncoder = &onePtrEncoder{defaultEncoder} } - w := cfg.borrowWriter() defaultEncoder.Encode(reflect2.PtrOf(def), w) if w.Error != nil { @@ -29,7 +22,6 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va if err := w.Flush(); err != nil { return nil, err } - return w.Buffer(), nil } diff --git a/resolver.go b/resolver.go index aabf4663..7f67fa84 100644 --- a/resolver.go +++ b/resolver.go @@ -22,6 +22,7 @@ func NewTypeResolver() *TypeResolver { // Register basic types r.Register(string(Null), &null{}) + r.Register(string(Null), null{}) r.Register(string(Int), int8(0)) r.Register(string(Int), int16(0)) r.Register(string(Int), int32(0)) diff --git a/schema.go b/schema.go index a0ebe17c..3e3b6a25 100644 --- a/schema.go +++ b/schema.go @@ -17,7 +17,7 @@ import ( jsoniter "github.com/json-iterator/go" ) -var nullDefault = struct{}{} +var nullDefault null = struct{}{} var ( schemaReserved = []string{ From 0b74a61fcf65e141741d83315c3c92933ad570dd Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 15 Dec 2023 15:42:08 +0100 Subject: [PATCH 11/25] clean up generic decode and improve test coverage --- codec_default.go | 3 - codec_default_internal_test.go | 27 ++++- codec_dynamic.go | 127 +------------------- codec_record.go | 7 +- codec_union.go | 10 +- converter.go | 2 +- converter_test.go | 137 +++++++++++++++++++++ decoder_dynamic_bench_test.go | 60 ---------- generic.go | 122 +++++++++++++++++++ generic_internal_test.go | 212 +++++++++++++++++++++++++++++++++ reader_generic.go | 25 ++-- reader_promoter.go | 68 ----------- schema_compatibility.go | 56 +++++---- schema_compatibility_test.go | 79 +++++++++++- 14 files changed, 619 insertions(+), 316 deletions(-) create mode 100644 converter_test.go delete mode 100644 decoder_dynamic_bench_test.go create mode 100644 generic.go create mode 100644 generic_internal_test.go delete mode 100644 reader_promoter.go diff --git a/codec_default.go b/codec_default.go index 9063592d..43fff385 100644 --- a/codec_default.go +++ b/codec_default.go @@ -19,9 +19,6 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va if w.Error != nil { return nil, w.Error } - if err := w.Flush(); err != nil { - return nil, err - } return w.Buffer(), nil } diff --git a/codec_default_internal_test.go b/codec_default_internal_test.go index a68865e3..59c82130 100644 --- a/codec_default_internal_test.go +++ b/codec_default_internal_test.go @@ -31,8 +31,33 @@ func ConfigTeardown() { DefaultConfig = Config{}.Freeze() } -func TestDecoder_DefaultBool(t *testing.T) { +func TestDecoder_InvalidDefault(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x6, 0x66, 0x6f, 0x6f} + schema := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "boolean", "default": true} + ] + }`) + + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault + // alter default value to force encoding failure + schema.(*RecordSchema).fields[1].def = "invalid value" + + dec := NewDecoderForSchema(schema, bytes.NewReader(data)) + + var got map[string]any + err := dec.Decode(&got) + + require.Error(t, err) +} + +func TestDecoder_DefaultBool(t *testing.T) { defer ConfigTeardown() // write schema diff --git a/codec_dynamic.go b/codec_dynamic.go index eaaf7de0..6e27781e 100644 --- a/codec_dynamic.go +++ b/codec_dynamic.go @@ -1,10 +1,7 @@ package avro import ( - "fmt" - "math/big" "reflect" - "time" "unsafe" "github.com/modern-go/reflect2" @@ -18,26 +15,14 @@ func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { pObj := (*any)(ptr) obj := *pObj if obj == nil { - rPtr, rtyp, err := dynamicReceiver(d.schema, r.cfg.resolver) - if err != nil { - r.ReportError("Read", err.Error()) - return - } - decoderOfType(r.cfg, d.schema, rtyp).Decode(rPtr, r) - *pObj = rtyp.UnsafeIndirect(rPtr) + *pObj = genericDecode(d.schema, r) // *pObj = r.ReadNext(d.schema) return } typ := reflect2.TypeOf(obj) if typ.Kind() != reflect.Ptr { - rPtr, rTyp, err := dynamicReceiver(d.schema, r.cfg.resolver) - if err != nil { - r.ReportError("Read", err.Error()) - return - } - decoderOfType(r.cfg, d.schema, rTyp).Decode(rPtr, r) - *pObj = rTyp.UnsafeIndirect(rPtr) + *pObj = genericDecode(d.schema, r) // *pObj = r.ReadNext(d.schema) return } @@ -62,111 +47,3 @@ func (e *interfaceEncoder) Encode(ptr unsafe.Pointer, w *Writer) { obj := e.typ.UnsafeIndirect(ptr) w.WriteVal(e.schema, obj) } - -func dynamicReceiver(schema Schema, resolver *TypeResolver) (unsafe.Pointer, reflect2.Type, error) { - var ls LogicalSchema - lts, ok := schema.(LogicalTypeSchema) - if ok { - ls = lts.Logical() - } - - name := string(schema.Type()) - if ls != nil { - name += "." + string(ls.Type()) - } - if resolver != nil { - typ, err := resolver.Type(name) - if err == nil { - return typ.UnsafeNew(), typ, nil - } - } - - switch schema.Type() { - case Boolean: - var v bool - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Int: - if ls != nil { - switch ls.Type() { - case Date: - var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - - case TimeMillis: - var v time.Duration - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - } - } - var v int - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Long: - if ls != nil { - switch ls.Type() { - case TimeMicros: - var v time.Duration - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - - case TimestampMillis: - var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - - case TimestampMicros: - var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - } - } - var v int64 - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Float: - var v float32 - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Double: - var v float64 - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case String: - var v string - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Bytes: - if ls != nil && ls.Type() == Decimal { - var v *big.Rat - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - } - var v []byte - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Record: - var v map[string]any - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Ref: - return dynamicReceiver(schema.(*RefSchema).Schema(), resolver) - case Enum: - var v string - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Array: - v := make([]any, 0) - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Map: - var v map[string]any - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Union: - var v map[string]any - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Fixed: - fixed := schema.(*FixedSchema) - ls := fixed.Logical() - if ls != nil { - switch ls.Type() { - case Duration: - var v LogicalDuration - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Decimal: - var v big.Rat - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - } - } - // note that uint64 case is not supported, due to the lack of indicator at the schema-level (logical type) - var v []byte - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - default: - return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name) - } -} diff --git a/codec_record.go b/codec_record.go index 5dafca92..24e1f72e 100644 --- a/codec_record.go +++ b/codec_record.go @@ -91,9 +91,9 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec field: sf.Field, decoder: createDefaultDecoder(cfg, field, sf.Field[len(sf.Field)-1].Type()), }) - } - continue + continue + } } dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()) @@ -271,9 +271,8 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec name: field.Name(), decoder: createDefaultDecoder(cfg, field, mapType.Elem()), } + continue } - - continue } fields[i] = recordMapDecoderField{ diff --git a/codec_union.go b/codec_union.go index b8e81748..475f5229 100644 --- a/codec_union.go +++ b/codec_union.go @@ -295,15 +295,7 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { name := schemaTypeName(schema) obj := map[string]any{} - rPtr, rTyp, err := dynamicReceiver(schema, r.cfg.resolver) - if err != nil { - r.ReportError("Read", err.Error()) - return - } - decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r) - - obj[name] = rTyp.UnsafeIndirect(rPtr) - + obj[name] = genericDecode(schema, r) // obj[name] = r.ReadNext(schema) *pObj = obj diff --git a/converter.go b/converter.go index 99dcd05f..0cc0bcd5 100644 --- a/converter.go +++ b/converter.go @@ -72,7 +72,7 @@ func createStringConverter(typ Type) (func(*Reader) string, error) { return func(r *Reader) string { b := r.ReadBytes() // TBD: update go.mod version to go 1.20 minimum - // runtime.KeepAlive(b) // TBD: I guess this line is required? + // runtime.KeepAlive(b) // return unsafe.String(unsafe.SliceData(b), len(b)) return string(b) }, nil diff --git a/converter_test.go b/converter_test.go new file mode 100644 index 00000000..743164fa --- /dev/null +++ b/converter_test.go @@ -0,0 +1,137 @@ +package avro + +import ( + "bytes" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConverter(t *testing.T) { + tests := []struct { + data []byte + want any + typ, wantTyp Type + wantErr require.ErrorAssertionFunc + }{ + { + data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, + want: int64(987654321), + typ: Int, + wantTyp: Long, + wantErr: require.NoError, + }, + { + data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, + want: int64(9223372036854775807), + typ: Long, + wantTyp: Long, + wantErr: require.NoError, + }, + { + data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, + want: float32(987654321), + typ: Int, + wantTyp: Float, + wantErr: require.NoError, + }, + { + data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, + want: float32(9223372036854775807), + typ: Long, + wantTyp: Float, + wantErr: require.NoError, + }, + { + data: []byte{0x62, 0x20, 0x71, 0x49}, + want: float32(987654.124), + typ: Float, + wantTyp: Float, + wantErr: require.NoError, + }, + { + data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, + want: float64(987654321), + typ: Int, + wantTyp: Double, + wantErr: require.NoError, + }, + { + data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, + want: float64(9223372036854775807), + typ: Long, + wantTyp: Double, + wantErr: require.NoError, + }, + { + data: []byte{0x62, 0x20, 0x71, 0x49}, + want: float64(float32(987654.124)), + typ: Float, + wantTyp: Double, + wantErr: require.NoError, + }, + { + data: []byte{0xB6, 0xF3, 0x7D, 0x54, 0x34, 0x6F, 0x9D, 0xC1}, + want: float64(-123456789.123), + typ: Double, + wantTyp: Double, + wantErr: require.NoError, + }, + { + data: []byte{0x08, 0xEC, 0xAB, 0x44, 0x00}, + want: string([]byte{0xEC, 0xAB, 0x44, 0x00}), + typ: Bytes, + wantTyp: String, + wantErr: require.NoError, + }, + { + data: []byte{0x28, 0x6F, 0x70, 0x70, 0x61, 0x6E, 0x20, 0x67, 0x61, 0x6E, 0x67, 0x6E, 0x61, 0x6D, 0x20, 0x73, 0x74, 0x79, 0x6C, 0x65, 0x21}, + want: "oppan gangnam style!", + typ: String, + wantTyp: String, + wantErr: require.NoError, + }, + { + data: []byte{0x36, 0xD1, 0x87, 0xD0, 0xB5, 0x2D, 0xD1, 0x82, 0xD0, 0xBE, 0x20, 0xD0, 0xBF, 0xD0, 0xBE, 0x20, 0xD1, 0x80, 0xD1, 0x83, 0xD1, 0x81, 0xD1, 0x81, 0xD0, 0xBA, 0xD0, 0xB8}, + want: []byte("че-то по русски"), + typ: String, + wantTyp: Bytes, + wantErr: require.NoError, + }, + { + data: []byte{0x0C, 0xAC, 0xDC, 0x01, 0x00, 0x10, 0x0F}, + want: []byte{0xAC, 0xDC, 0x01, 0x00, 0x10, 0x0F}, + typ: Bytes, + wantTyp: Bytes, + wantErr: require.NoError, + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + r := NewReader(bytes.NewReader(test.data), 10) + conv := resolveConverter(test.typ) + + var got any + switch test.wantTyp { + case Long: + got = conv.toLong(r) + case Float: + got = conv.toFloat(r) + case Double: + got = conv.toDouble(r) + case String: + got = conv.toString(r) + case Bytes: + got = conv.toBytes(r) + default: + } + + test.wantErr(t, r.Error) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/decoder_dynamic_bench_test.go b/decoder_dynamic_bench_test.go deleted file mode 100644 index da216ea9..00000000 --- a/decoder_dynamic_bench_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package avro_test - -import ( - "bytes" - "testing" - - "github.com/hamba/avro/v2" -) - -func BenchmarkDecoder_Interface(b *testing.B) { - tests := []struct { - name string - data []byte - schema string - got any - want any - }{ - { - name: "Empty Interface", - data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, - schema: `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, - got: nil, - want: map[string]any{"a": int64(27), "b": "foo"}, - }, - { - name: "Interface Non-Ptr", - data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, - schema: `{"type": "record", "name": "test", "fields": [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, - got: TestRecord{}, - want: map[string]any{"a": int64(27), "b": "foo"}, - }, - { - name: "Interface Nil Ptr", - data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, - schema: `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, - got: (*TestRecord)(nil), - want: &TestRecord{A: 27, B: "foo"}, - }, - { - name: "Interface Ptr", - data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, - schema: `{"type": "record", "name": "test", "fields": [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, - got: &TestRecord{}, - want: &TestRecord{A: 27, B: "foo"}, - }, - } - - for _, test := range tests { - test := test - b.Run(test.name, func(b *testing.B) { - defer ConfigTeardown() - b.ResetTimer() - - for n := 0; n < b.N; n++ { - dec, _ := avro.NewDecoder(test.schema, bytes.NewReader(test.data)) - _ = dec.Decode(&test.got) - } - }) - } -} diff --git a/generic.go b/generic.go new file mode 100644 index 00000000..ff55bdcd --- /dev/null +++ b/generic.go @@ -0,0 +1,122 @@ +package avro + +import ( + "fmt" + "math/big" + "time" + "unsafe" + + "github.com/modern-go/reflect2" +) + +func genericDecode(schema Schema, r *Reader) any { + rPtr, rTyp, err := genericReceiver(schema) + if err != nil { + r.ReportError("Read", err.Error()) + return nil + } + decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r) + + return rTyp.UnsafeIndirect(rPtr) +} + +func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) { + var ls LogicalSchema + lts, ok := schema.(LogicalTypeSchema) + if ok { + ls = lts.Logical() + } + + name := string(schema.Type()) + if ls != nil { + name += "." + string(ls.Type()) + } + + switch schema.Type() { + case Boolean: + var v bool + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Int: + if ls != nil { + switch ls.Type() { + case Date: + var v time.Time + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + + case TimeMillis: + var v time.Duration + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + } + var v int + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Long: + if ls != nil { + switch ls.Type() { + case TimeMicros: + var v time.Duration + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + + case TimestampMillis: + var v time.Time + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + + case TimestampMicros: + var v time.Time + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + } + var v int64 + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Float: + var v float32 + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Double: + var v float64 + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case String: + var v string + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Bytes: + if ls != nil && ls.Type() == Decimal { + var v *big.Rat + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + var v []byte + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Record: + var v map[string]any + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Ref: + return genericReceiver(schema.(*RefSchema).Schema()) + case Enum: + var v string + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Array: + v := make([]any, 0) + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Map: + var v map[string]any + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Union: + var v map[string]any + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Fixed: + fixed := schema.(*FixedSchema) + ls := fixed.Logical() + if ls != nil { + switch ls.Type() { + case Duration: + var v LogicalDuration + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + case Decimal: + var v big.Rat + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + } + } + v := byteSliceToArray(make([]byte, fixed.Size()), fixed.Size()) + return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + default: + return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name) + } +} diff --git a/generic_internal_test.go b/generic_internal_test.go new file mode 100644 index 00000000..12b45ee4 --- /dev/null +++ b/generic_internal_test.go @@ -0,0 +1,212 @@ +package avro + +import ( + "bytes" + "math/big" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenericDecode(t *testing.T) { + tests := []struct { + name string + data []byte + schema string + want any + wantErr require.ErrorAssertionFunc + }{ + + { + name: "Bool", + data: []byte{0x01}, + schema: "boolean", + want: true, + wantErr: require.NoError, + }, + { + name: "Int", + data: []byte{0x36}, + schema: "int", + want: 27, + wantErr: require.NoError, + }, + { + name: "Int Date", + data: []byte{0xAE, 0x9D, 0x02}, + schema: `{"type":"int","logicalType":"date"}`, + want: time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + wantErr: require.NoError, + }, + { + name: "Int Time-Millis", + data: []byte{0xAA, 0xB4, 0xDE, 0x75}, + schema: `{"type":"int","logicalType":"time-millis"}`, + want: 123456789 * time.Millisecond, + wantErr: require.NoError, + }, + { + name: "Long", + data: []byte{0x36}, + schema: "long", + want: int64(27), + wantErr: require.NoError, + }, + { + name: "Long Time-Micros", + data: []byte{0x86, 0xEA, 0xC8, 0xE9, 0x97, 0x07}, + schema: `{"type":"long","logicalType":"time-micros"}`, + want: 123456789123 * time.Microsecond, + wantErr: require.NoError, + }, + { + name: "Long Timestamp-Millis", + data: []byte{0x90, 0xB2, 0xAE, 0xC3, 0xEC, 0x5B}, + schema: `{"type":"long","logicalType":"timestamp-millis"}`, + want: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC), + wantErr: require.NoError, + }, + { + name: "Long Timestamp-Micros", + data: []byte{0x80, 0xCD, 0xB7, 0xA2, 0xEE, 0xC7, 0xCD, 0x05}, + schema: `{"type":"long","logicalType":"timestamp-micros"}`, + want: time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC), + wantErr: require.NoError, + }, + { + name: "Float", + data: []byte{0x33, 0x33, 0x93, 0x3F}, + schema: "float", + want: float32(1.15), + wantErr: require.NoError, + }, + { + name: "Double", + data: []byte{0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xF2, 0x3F}, + schema: "double", + want: float64(1.15), + wantErr: require.NoError, + }, + { + name: "String", + data: []byte{0x06, 0x66, 0x6F, 0x6F}, + schema: "string", + want: "foo", + wantErr: require.NoError, + }, + { + name: "Bytes", + data: []byte{0x08, 0xEC, 0xAB, 0x44, 0x00}, + schema: "bytes", + want: []byte{0xEC, 0xAB, 0x44, 0x00}, + wantErr: require.NoError, + }, + { + name: "Bytes Decimal", + data: []byte{0x6, 0x00, 0x87, 0x78}, + schema: `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`, + want: big.NewRat(1734, 5), + wantErr: require.NoError, + }, + { + name: "Record", + data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, + schema: `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}`, + want: map[string]any{"a": int64(27), "b": "foo"}, + wantErr: require.NoError, + }, + { + name: "Ref", + data: []byte{0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f}, + schema: `{"type":"record","name":"parent","fields":[{"name":"a","type":{"type":"record","name":"test","fields":[{"name":"a","type":"long"},{"name":"b","type":"string"}]}},{"name":"b","type":"test"}]}`, + want: map[string]any{"a": map[string]any{"a": int64(27), "b": "foo"}, "b": map[string]any{"a": int64(27), "b": "foo"}}, + wantErr: require.NoError, + }, + { + name: "Array", + data: []byte{0x04, 0x36, 0x38, 0x0}, + schema: `{"type":"array", "items": "int"}`, + want: []any{27, 28}, + wantErr: require.NoError, + }, + { + name: "Map", + data: []byte{0x02, 0x06, 0x66, 0x6F, 0x6F, 0x06, 0x66, 0x6F, 0x6F, 0x00}, + schema: `{"type":"map", "values": "string"}`, + want: map[string]any{"foo": "foo"}, + wantErr: require.NoError, + }, + { + name: "Enum", + data: []byte{0x02}, + schema: `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`, + want: "bar", + wantErr: require.NoError, + }, + { + name: "Enum Invalid Symbol", + data: []byte{0x04}, + schema: `{"type":"enum", "name": "test", "symbols": ["foo", "bar"]}`, + want: nil, + wantErr: require.Error, + }, + { + name: "Union", + data: []byte{0x02, 0x06, 0x66, 0x6F, 0x6F}, + schema: `["null", "string"]`, + want: map[string]any{"string": "foo"}, + wantErr: require.NoError, + }, + { + name: "Union Nil", + data: []byte{0x00}, + schema: `["null", "string"]`, + want: nil, + wantErr: require.NoError, + }, + { + name: "Union Named", + data: []byte{0x02, 0x02}, + schema: `["null", {"type":"enum", "name": "test", "symbols": ["foo", "bar"]}]`, + want: map[string]any{"test": "bar"}, + wantErr: require.NoError, + }, + { + name: "Union Invalid Schema", + data: []byte{0x04}, + schema: `["null", "string"]`, + want: nil, + wantErr: require.Error, + }, + { + name: "Fixed", + data: []byte{0x66, 0x6F, 0x6F, 0x66, 0x6F, 0x6F}, + schema: `{"type":"fixed", "name": "test", "size": 6}`, + want: [6]byte{'f', 'o', 'o', 'f', 'o', 'o'}, + wantErr: require.NoError, + }, + { + name: "Fixed Decimal", + data: []byte{0x00, 0x00, 0x00, 0x00, 0x87, 0x78}, + schema: `{"type":"fixed", "name": "test", "size": 6,"logicalType":"decimal","precision":4,"scale":2}`, + want: big.NewRat(1734, 5), + wantErr: require.NoError, + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + schema := MustParse(test.schema) + r := NewReader(bytes.NewReader(test.data), 10) + + got := r.ReadNext(schema) + + test.wantErr(t, r.Error) + assert.Equal(t, test.want, got) + }) + } +} diff --git a/reader_generic.go b/reader_generic.go index 79e5c38c..b75d240e 100644 --- a/reader_generic.go +++ b/reader_generic.go @@ -8,11 +8,6 @@ import ( // ReadNext reads the next Avro element as a generic interface. func (r *Reader) ReadNext(schema Schema) any { - var rp iReaderPromoter = r - if sch, ok := schema.(*PrimitiveSchema); ok && sch.actual != "" { - rp = newReaderPromoter(sch.actual, r) - } - var ls LogicalSchema lts, ok := schema.(LogicalTypeSchema) if ok { @@ -39,34 +34,34 @@ func (r *Reader) ReadNext(schema Schema) any { if ls != nil { switch ls.Type() { case TimeMicros: - return time.Duration(rp.ReadLong()) * time.Microsecond + return time.Duration(r.ReadLong()) * time.Microsecond case TimestampMillis: - i := rp.ReadLong() + i := r.ReadLong() sec := i / 1e3 nsec := (i - sec*1e3) * 1e6 return time.Unix(sec, nsec).UTC() case TimestampMicros: - i := rp.ReadLong() + i := r.ReadLong() sec := i / 1e6 nsec := (i - sec*1e6) * 1e3 return time.Unix(sec, nsec).UTC() } } - return rp.ReadLong() + return r.ReadLong() case Float: - return rp.ReadFloat() + return r.ReadFloat() case Double: - return rp.ReadDouble() + return r.ReadDouble() case String: - return rp.ReadString() + return r.ReadString() case Bytes: if ls != nil && ls.Type() == Decimal { dec := ls.(*DecimalLogicalSchema) - return ratFromBytes(rp.ReadBytes(), dec.Scale()) + return ratFromBytes(r.ReadBytes(), dec.Scale()) } - return rp.ReadBytes() + return r.ReadBytes() case Record: fields := schema.(*RecordSchema).Fields() obj := make(map[string]any, len(fields)) @@ -102,7 +97,7 @@ func (r *Reader) ReadNext(schema Schema) any { return obj case Union: types := schema.(*UnionSchema).Types() - idx := int(rp.ReadLong()) + idx := int(r.ReadLong()) if idx < 0 || idx > len(types)-1 { r.ReportError("Read", "unknown union type") return nil diff --git a/reader_promoter.go b/reader_promoter.go deleted file mode 100644 index 2a6ff3bf..00000000 --- a/reader_promoter.go +++ /dev/null @@ -1,68 +0,0 @@ -package avro - -type iReaderPromoter interface { - ReadLong() int64 - ReadFloat() float32 - ReadDouble() float64 - ReadString() string - ReadBytes() []byte -} - -type readerPromoter struct { - actual Type - r *Reader - converter -} - -func newReaderPromoter(actual Type, r *Reader) *readerPromoter { - rp := &readerPromoter{ - actual: actual, - r: r, - converter: resolveConverter(actual), - } - - return rp -} - -var _ iReaderPromoter = &readerPromoter{} - -func (p *readerPromoter) ReadLong() int64 { - if p.toLong != nil { - return p.toLong(p.r) - } - - return p.r.ReadLong() -} - -func (p *readerPromoter) ReadFloat() float32 { - if p.toFloat != nil { - return p.toFloat(p.r) - } - - return p.r.ReadFloat() -} - -func (p *readerPromoter) ReadDouble() float64 { - if p.toDouble != nil { - v := p.toDouble(p.r) - return v - } - - return p.r.ReadDouble() -} - -func (p *readerPromoter) ReadString() string { - if p.toString != nil { - return p.toString(p.r) - } - - return p.r.ReadString() -} - -func (p *readerPromoter) ReadBytes() []byte { - if p.toBytes != nil { - return p.toBytes(p.r) - } - - return p.r.ReadBytes() -} diff --git a/schema_compatibility.go b/schema_compatibility.go index e3ff509a..0985d276 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -177,11 +177,14 @@ func (c *SchemaCompatibility) match(reader, writer Schema) error { func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error { if reader.FullName() != writer.FullName() { - for _, alias := range reader.Aliases() { - if alias == writer.FullName() { - return nil - } + if c.contains(reader.Aliases(), writer.FullName()) { + return nil } + // for _, alias := range reader.Aliases() { + // if alias == writer.FullName() { + // return nil + // } + // } return fmt.Errorf("reader schema %s and writer schema %s names do not match", reader.FullName(), writer.FullName()) } @@ -255,17 +258,13 @@ func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*get return field, true } if opt.fieldAlias { - for _, alias := range f.Aliases() { - if field.Name() == alias { - return field, true - } + if c.contains(f.Aliases(), field.Name()) { + return field, true } } if opt.elemAlias { - for _, alias := range field.Aliases() { - if f.Name() == alias { - return field, true - } + if c.contains(field.Aliases(), f.Name()) { + return field, true } } } @@ -300,7 +299,6 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { if err != nil { continue } - return sch, nil } @@ -324,6 +322,8 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { r.actual = writer.Type() return r, nil } + + return nil, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type()) } if isNative(writer.Type()) { @@ -351,7 +351,7 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { } if writer.Type() == Array { - schema, err := c.Resolve(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items()) + schema, err := c.resolve(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items()) if err != nil { return nil, err } @@ -359,7 +359,7 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { } if writer.Type() == Map { - schema, err := c.Resolve(reader.(*MapSchema).Values(), writer.(*MapSchema).Values()) + schema, err := c.resolve(reader.(*MapSchema).Values(), writer.(*MapSchema).Values()) if err != nil { return nil, err } @@ -378,40 +378,38 @@ func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, erro r := reader.(*RecordSchema) fields := make([]*Field, 0) - founds := make(map[string]struct{}) + seen := make(map[string]struct{}) for _, field := range w.Fields() { - if field == nil { - continue - } - f := *field + f, _ := NewField(field.Name(), field.Type(), WithAliases(field.aliases), WithOrder(field.order)) + f.def = field.def + f.hasDef = field.hasDef rf, ok := c.getField(r.Fields(), field, func(gfo *getFieldOptions) { gfo.elemAlias = true }) if !ok { f.action = FieldDrain - fields = append(fields, &f) + fields = append(fields, f) continue } - ft, err := c.Resolve(rf.Type(), field.Type()) + ft, err := c.resolve(rf.Type(), f.Type()) if err != nil { return nil, err } rf.typ = ft fields = append(fields, rf) - founds[rf.Name()] = struct{}{} + seen[rf.Name()] = struct{}{} } for _, field := range r.Fields() { - if field == nil { - continue - } - if _, ok := founds[field.Name()]; ok { + if _, ok := seen[field.Name()]; ok { continue } - f := *field + f, _ := NewField(field.Name(), field.Type(), WithAliases(field.aliases), WithOrder(field.order)) + f.def = field.def + f.hasDef = field.hasDef f.action = FieldSetDefault - fields = append(fields, &f) + fields = append(fields, f) } return NewRecordSchema(r.Name(), r.Namespace(), fields, WithAliases(r.Aliases())) diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 1c6a8f5e..a018c70a 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -277,7 +277,7 @@ func TestSchemaCompatibility_CompatibleUsesCacheWithError(t *testing.T) { assert.Error(t, err) } -func TestSchemaCompatibility_ResolveV2(t *testing.T) { +func TestSchemaCompatibility_Resolve(t *testing.T) { tests := []struct { name string reader string @@ -341,6 +341,52 @@ func TestSchemaCompatibility_ResolveV2(t *testing.T) { value: []byte("foo"), want: "foo", }, + { + name: "Array With Items Promotion", + reader: `{"type":"array", "items": "long"}`, + writer: `{"type":"array", "items": "int"}`, + value: []any{int32(10), int32(15)}, + want: []any{int64(10), int64(15)}, + }, + { + name: "Map With Items Promotion", + reader: `{"type":"map", "values": "bytes"}`, + writer: `{"type":"map", "values": "string"}`, + value: map[string]any{"foo": "bar"}, + want: map[string]any{"foo": []byte("bar")}, + }, + { + name: "Enum With Alias", + reader: `{ + "type": "enum", + "name": "test.enum2", + "aliases": ["test.enum"], + "symbols": ["foo", "bar"] + }`, + writer: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo", "bar"] + }`, + value: "foo", + want: "foo", + }, + { + name: "Fixed With Alias", + reader: `{ + "type": "fixed", + "name": "test.fixed2", + "aliases": ["test.fixed"], + "size": 3 + }`, + writer: `{ + "type": "fixed", + "name": "test.fixed", + "size": 3 + }`, + value: [3]byte{'f', 'o', 'o'}, + want: [3]byte{'f', 'o', 'o'}, + }, { name: "Union Match", reader: `["int", "long", "string"]`, @@ -368,6 +414,13 @@ func TestSchemaCompatibility_ResolveV2(t *testing.T) { value: 10, want: 10, }, + { + name: "Record Reader With Alias", + reader: `{"type":"record", "name":"test2", "aliases": ["test"], "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10}, + }, { name: "Record Reader Field Missing", reader: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, @@ -434,6 +487,30 @@ func TestSchemaCompatibility_ResolveV2(t *testing.T) { value: map[string]any{"a": 10}, want: map[string]any{"a": 10, "b": map[string]any{"a": "foo", "b": "bar"}}, }, + // { + // name: "Record Writer Field Missing With Record Default 2", + // reader: `{ + // "type":"record", "name":"test", "namespace": "org.hamba.avro", + // "fields":[ + // {"name": "a", "type": "int"}, + // { + // "name": "b", + // "type": { + // "type": "record", + // "name": "test.record", + // "fields" : [ + // {"name": "a", "type": "string"}, + // {"name": "b", "type": "string"} + // ] + // }, + // "default":{"a":"foo 2", "b": "bar 2"} + // } + // ] + // }`, + // writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + // value: map[string]any{"a": 10}, + // want: map[string]any{"a": 10, "b": map[string]any{"a": "foo 2", "b": "bar 2"}}, + // }, { name: "Record Writer Field Missing With Map Default", reader: `{ From 0fc43f2f52cff501529f7b49daa7c0dadf7ed01e Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 15 Dec 2023 18:05:34 +0100 Subject: [PATCH 12/25] fix decoder cachekey to consider primitives with promotion and fields default --- codec.go | 9 +++-- schema.go | 64 +++++++++++++++++++++++++++--------- schema_compatibility_test.go | 50 ++++++++++++++-------------- 3 files changed, 81 insertions(+), 42 deletions(-) diff --git a/codec.go b/codec.go index 156839a3..e061d74f 100644 --- a/codec.go +++ b/codec.go @@ -34,7 +34,8 @@ type ValEncoder interface { // ReadVal parses Avro value and stores the result in the value pointed to by obj. func (r *Reader) ReadVal(schema Schema, obj any) { - decoder := r.cfg.getDecoderFromCache(schema.Fingerprint(), reflect2.RTypeOf(obj)) + key := cacheFingerprintOf(schema) + decoder := r.cfg.getDecoderFromCache(key, reflect2.RTypeOf(obj)) if decoder == nil { typ := reflect2.TypeOf(obj) if typ.Kind() != reflect.Ptr { @@ -65,14 +66,16 @@ func (w *Writer) WriteVal(schema Schema, val any) { func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder { rtype := typ.RType() - decoder := c.getDecoderFromCache(schema.Fingerprint(), rtype) + + key := cacheFingerprintOf(schema) + decoder := c.getDecoderFromCache(key, rtype) if decoder != nil { return decoder } ptrType := typ.(*reflect2.UnsafePtrType) decoder = decoderOfType(c, schema, ptrType.Elem()) - c.addDecoderToCache(schema.Fingerprint(), rtype, decoder) + c.addDecoderToCache(key, rtype, decoder) return decoder } diff --git a/schema.go b/schema.go index 3e3b6a25..6df880ca 100644 --- a/schema.go +++ b/schema.go @@ -315,6 +315,32 @@ func (f *fingerprinter) FingerprintUsing(typ FingerprintType, stringer fmt.Strin return fingerprint, nil } +// cacheFingerprintOf returns a special fingerprint mainly used by decoders cache. +func cacheFingerprintOf(schema Schema) [32]byte { + if s, ok := schema.(interface{ CacheFingerprint() [32]byte }); ok { + return s.CacheFingerprint() + } + return schema.Fingerprint() +} + +type cacheFingerprinter struct { + key atomic.Value // [32]byte +} + +func (sf *cacheFingerprinter) fingerprint(data []any) [32]byte { + if v := sf.key.Load(); v != nil { + return v.([32]byte) + } + + b, err := jsoniter.Marshal(data) + if err != nil { + panic("cache fingerprint: couldn't json marshal receipt data: " + err.Error()) + } + key := sha256.Sum256(b) + sf.key.Store(key) + return key +} + type properties struct { props map[string]any } @@ -415,6 +441,7 @@ func WithProps(props map[string]any) SchemaOption { type PrimitiveSchema struct { properties fingerprinter + cacheFingerprinter typ Type logical LogicalSchema @@ -482,22 +509,9 @@ func (s *PrimitiveSchema) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -// Temporary HACK to allow testing schema resolution logic... -// a better solution would be to extend decoder cache key. -type primitiveSchemaFingerprint struct { - s *PrimitiveSchema -} - -func (sfp *primitiveSchemaFingerprint) String() string { - if sfp.s.actual == "" { - return sfp.s.String() - } - return sfp.s.String() + ":" + string(sfp.s.actual) -} - // Fingerprint returns the SHA256 fingerprint of the schema. func (s *PrimitiveSchema) Fingerprint() [32]byte { - return s.fingerprinter.Fingerprint(&primitiveSchemaFingerprint{s: s}) + return s.fingerprinter.Fingerprint(s) } // FingerprintUsing returns the fingerprint of the schema using the given algorithm or an error. @@ -505,6 +519,15 @@ func (s *PrimitiveSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) return s.fingerprinter.FingerprintUsing(typ, s) } +// CacheFingerprint returns a special fingerprint of the schema for caching purposes. +func (s *PrimitiveSchema) CacheFingerprint() [32]byte { + data := []any{s.Fingerprint()} + if s.actual != "" { + data = append(data, s.actual) + } + return s.cacheFingerprinter.fingerprint(data) +} + // Actual returns the actual type of the schema. // This field is only presents during write-read schema resolution. func (s *PrimitiveSchema) Actual() Type { @@ -516,7 +539,7 @@ type RecordSchema struct { name properties fingerprinter - + cacheFingerprinter isError bool fields []*Field doc string @@ -635,6 +658,17 @@ func (s *RecordSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) { return s.fingerprinter.FingerprintUsing(typ, s) } +// CacheFingerprint returns a special fingerprint of the schema for caching purposes. +func (s *RecordSchema) CacheFingerprint() [32]byte { + data := []any{s.Fingerprint()} + for _, field := range s.fields { + if field.Default() != nil { + data = append(data, field.Default()) + } + } + return s.cacheFingerprinter.fingerprint(data) +} + // Field is an Avro record type field. type Field struct { properties diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index a018c70a..55813d38 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -487,30 +487,32 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { value: map[string]any{"a": 10}, want: map[string]any{"a": 10, "b": map[string]any{"a": "foo", "b": "bar"}}, }, - // { - // name: "Record Writer Field Missing With Record Default 2", - // reader: `{ - // "type":"record", "name":"test", "namespace": "org.hamba.avro", - // "fields":[ - // {"name": "a", "type": "int"}, - // { - // "name": "b", - // "type": { - // "type": "record", - // "name": "test.record", - // "fields" : [ - // {"name": "a", "type": "string"}, - // {"name": "b", "type": "string"} - // ] - // }, - // "default":{"a":"foo 2", "b": "bar 2"} - // } - // ] - // }`, - // writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, - // value: map[string]any{"a": 10}, - // want: map[string]any{"a": 10, "b": map[string]any{"a": "foo 2", "b": "bar 2"}}, - // }, + { + // assert that we are not mistakenly using the wrong cached decoder. + // decoder cache must be aware of fields defaults. + name: "Record Writer Field Missing With Record Default 2", + reader: `{ + "type":"record", "name":"test", "namespace": "org.hamba.avro", + "fields":[ + {"name": "a", "type": "int"}, + { + "name": "b", + "type": { + "type": "record", + "name": "test.record", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "string"} + ] + }, + "default":{"a":"foo 2", "b": "bar 2"} + } + ] + }`, + writer: `{"type":"record", "name":"test", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{"a": 10, "b": map[string]any{"a": "foo 2", "b": "bar 2"}}, + }, { name: "Record Writer Field Missing With Map Default", reader: `{ From 62b27ce2f44b3f283b77a8420163ee6a745fa88d Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 15 Dec 2023 20:24:08 +0100 Subject: [PATCH 13/25] fix: codec generic --- codec.go | 1 - generic.go => codec_generic.go | 15 +++++- ..._test.go => codec_generic_internal_test.go | 11 +++- schema.go | 11 ---- schema_compatibility_test.go | 51 +++++++++++++++---- 5 files changed, 65 insertions(+), 24 deletions(-) rename generic.go => codec_generic.go (92%) rename generic_internal_test.go => codec_generic_internal_test.go (95%) diff --git a/codec.go b/codec.go index e061d74f..66ad77d1 100644 --- a/codec.go +++ b/codec.go @@ -66,7 +66,6 @@ func (w *Writer) WriteVal(schema Schema, val any) { func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder { rtype := typ.RType() - key := cacheFingerprintOf(schema) decoder := c.getDecoderFromCache(key, rtype) if decoder != nil { diff --git a/generic.go b/codec_generic.go similarity index 92% rename from generic.go rename to codec_generic.go index ff55bdcd..4265288c 100644 --- a/generic.go +++ b/codec_generic.go @@ -16,8 +16,21 @@ func genericDecode(schema Schema, r *Reader) any { return nil } decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r) + if r.Error != nil { + return nil + } + obj := rTyp.UnsafeIndirect(rPtr) + if reflect2.IsNil(obj) { + return nil + } + + // seems generic reader is not compatible with codec + if rTyp.Type1() == ratType { + dec := obj.(big.Rat) + return &dec + } - return rTyp.UnsafeIndirect(rPtr) + return obj } func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) { diff --git a/generic_internal_test.go b/codec_generic_internal_test.go similarity index 95% rename from generic_internal_test.go rename to codec_generic_internal_test.go index 12b45ee4..79524c63 100644 --- a/generic_internal_test.go +++ b/codec_generic_internal_test.go @@ -203,10 +203,19 @@ func TestGenericDecode(t *testing.T) { schema := MustParse(test.schema) r := NewReader(bytes.NewReader(test.data), 10) - got := r.ReadNext(schema) + got := genericDecode(schema, r) test.wantErr(t, r.Error) assert.Equal(t, test.want, got) }) } } + +func TestReader_UnsupportedType(t *testing.T) { + schema := NewPrimitiveSchema(Type("test"), nil) + r := NewReader(bytes.NewReader([]byte{0x01}), 10) + + _ = genericDecode(schema, r) + + assert.Error(t, r.Error) +} diff --git a/schema.go b/schema.go index 6df880ca..e2658d43 100644 --- a/schema.go +++ b/schema.go @@ -528,12 +528,6 @@ func (s *PrimitiveSchema) CacheFingerprint() [32]byte { return s.cacheFingerprinter.fingerprint(data) } -// Actual returns the actual type of the schema. -// This field is only presents during write-read schema resolution. -func (s *PrimitiveSchema) Actual() Type { - return s.actual -} - // RecordSchema is an Avro record type schema. type RecordSchema struct { name @@ -743,11 +737,6 @@ func (f *Field) Name() string { return f.name } -// Action returns the action of a field. -func (f *Field) Action() Action { - return f.action -} - // Aliases return the field aliases. func (f *Field) Aliases() []string { return f.aliases diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 55813d38..94f39fe2 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -638,7 +638,7 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { value: map[string]any{"a": 10}, want: map[string]any{ "a": 10, - "b": *big.NewRat(1734, 5), + "b": big.NewRat(1734, 5), }, }, { @@ -665,6 +665,43 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "b": "bar", }, }, + { + name: "Record Writer Field Missing With Ref Default", + reader: `{ + "type": "record", + "name": "parent", + "namespace": "org.hamba.avro", + "fields": [{ + "name": "a", + "type": "int" + }, + { + "name": "b", + "type": { + "type": "record", + "name": "test", + "fields": [{ + "name": "a", + "type": "long" + }] + }, + "default": {"a": 10} + }, + { + "name": "c", + "type": "test", + "default": {"a": 20} + } + ] + }`, + writer: `{"type":"record", "name":"parent", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, + value: map[string]any{"a": 10}, + want: map[string]any{ + "a": 10, + "b": map[string]any{"a": int64(10)}, + "c": map[string]any{"a": int64(20)}, + }, + }, } for _, test := range tests { @@ -675,20 +712,14 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { sc := avro.NewSchemaCompatibility() b, err := avro.Marshal(w, test.value) - if err != nil { - t.Fatalf("marshal error%v", err) - } + assert.NoError(t, err) sch, err := sc.Resolve(r, w) - if err != nil { - t.Fatalf("resolve error %v", err) - } + assert.NoError(t, err) var result any err = avro.Unmarshal(sch, b, &result) - if err != nil { - t.Fatalf("unmarshal error %v", err) - } + assert.NoError(t, err) assert.Equal(t, test.want, result) }) From 7db00da94c9cf92dd197c60507abd4e52411a8a3 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Sat, 16 Dec 2023 18:44:09 +0100 Subject: [PATCH 14/25] cleanups and fixes --- codec_default_internal_test.go | 40 ++++++- schema.go | 25 +++-- schema_compatibility.go | 23 ++-- schema_compatibility_test.go | 112 ++++++++++++++++--- schema_internal_test.go | 198 +++++++++++++++++++++++++++++++++ 5 files changed, 357 insertions(+), 41 deletions(-) diff --git a/codec_default_internal_test.go b/codec_default_internal_test.go index 59c82130..9adf1b1a 100644 --- a/codec_default_internal_test.go +++ b/codec_default_internal_test.go @@ -57,6 +57,45 @@ func TestDecoder_InvalidDefault(t *testing.T) { require.Error(t, err) } +func TestDecoder_DrainField(t *testing.T) { + defer ConfigTeardown() + + // write schema + // `{ + // // "type": "record", + // // "name": "test", + // // "fields" : [ + // // {"name": "a", "type": "string"} + // // ] + // // }` + + // {"a": "foo"} + data := []byte{0x6, 0x66, 0x6f, 0x6f} + + schema := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "float", "default": 10.45} + ] + }`) + + schema.(*RecordSchema).Fields()[0].action = FieldDrain + schema.(*RecordSchema).Fields()[1].action = FieldSetDefault + + type TestRecord struct { + A string `avro:"a"` + B float32 `avro:"b"` + } + + var got TestRecord + err := NewDecoderForSchema(schema, bytes.NewReader(data)).Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecord{B: 10.45, A: ""}, got) +} + func TestDecoder_DefaultBool(t *testing.T) { defer ConfigTeardown() @@ -714,5 +753,4 @@ func TestDecoder_DefaultFixed(t *testing.T) { assert.Equal(t, big.NewRat(1734, 5), &got.B) assert.Equal(t, "foo", got.A) }) - } diff --git a/schema.go b/schema.go index e2658d43..50779d84 100644 --- a/schema.go +++ b/schema.go @@ -521,11 +521,11 @@ func (s *PrimitiveSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) // CacheFingerprint returns a special fingerprint of the schema for caching purposes. func (s *PrimitiveSchema) CacheFingerprint() [32]byte { - data := []any{s.Fingerprint()} - if s.actual != "" { - data = append(data, s.actual) + if s.actual == "" { + return s.Fingerprint() } - return s.cacheFingerprinter.fingerprint(data) + + return s.cacheFingerprinter.fingerprint([]any{s.Fingerprint(), s.actual}) } // RecordSchema is an Avro record type schema. @@ -654,12 +654,17 @@ func (s *RecordSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) { // CacheFingerprint returns a special fingerprint of the schema for caching purposes. func (s *RecordSchema) CacheFingerprint() [32]byte { - data := []any{s.Fingerprint()} + data := make([]any, 0) for _, field := range s.fields { if field.Default() != nil { data = append(data, field.Default()) } } + if len(data) == 0 { + return s.Fingerprint() + } + + data = append(data, s.Fingerprint()) return s.cacheFingerprinter.fingerprint(data) } @@ -679,7 +684,7 @@ type Field struct { action Action // encodedDef mainly used when decoding data that lack the field for schema evolution purposes. // Its value remains empty unless the field's encodeDefault function is called. - encodedDef []byte + encodedDef atomic.Value } type noDef struct{} @@ -764,8 +769,8 @@ func (f *Field) Default() any { } func (f *Field) encodeDefault(encode func(any) ([]byte, error)) ([]byte, error) { - if f.encodedDef != nil { - return f.encodedDef, nil + if v := f.encodedDef.Load(); v != nil { + return v.([]byte), nil } if !f.hasDef { return nil, fmt.Errorf("avro: '%s' field must have a non-empty default value", f.name) @@ -777,9 +782,9 @@ func (f *Field) encodeDefault(encode func(any) ([]byte, error)) ([]byte, error) if err != nil { return nil, err } - f.encodedDef = b + f.encodedDef.Store(b) - return f.encodedDef, nil + return b, nil } // Doc returns the documentation of a field. diff --git a/schema_compatibility.go b/schema_compatibility.go index 0985d276..57221753 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -180,11 +180,6 @@ func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error if c.contains(reader.Aliases(), writer.FullName()) { return nil } - // for _, alias := range reader.Aliases() { - // if alias == writer.FullName() { - // return nil - // } - // } return fmt.Errorf("reader schema %s and writer schema %s names do not match", reader.FullName(), writer.FullName()) } @@ -248,9 +243,6 @@ type getFieldOptions struct { func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*getFieldOptions)) (*Field, bool) { opt := getFieldOptions{} for _, fn := range optFns { - if fn == nil { - continue - } fn(&opt) } for _, field := range a { @@ -277,13 +269,6 @@ func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*get // // It fails if the writer and reader schemas are not compatible. func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { - if reader.Type() == Ref { - reader = reader.(*RefSchema).Schema() - } - if writer.Type() == Ref { - writer = writer.(*RefSchema).Schema() - } - if err := c.compatible(reader, writer); err != nil { return nil, err } @@ -291,7 +276,15 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { return c.resolve(reader, writer) } +// resolve requires the reader's schema to be already compatible with the writer's. func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { + if reader.Type() == Ref { + reader = reader.(*RefSchema).Schema() + } + if writer.Type() == Ref { + writer = writer.(*RefSchema).Schema() + } + if writer.Type() != reader.Type() { if reader.Type() == Union { for _, schema := range reader.(*UnionSchema).Types() { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 94f39fe2..d656cb9d 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -3,6 +3,7 @@ package avro_test import ( "math/big" "testing" + "time" "github.com/hamba/avro/v2" "github.com/stretchr/testify/assert" @@ -292,6 +293,27 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { value: 10, want: int64(10), }, + { + name: "Int Promote Long Time millis", + reader: `{"type":"long","logicalType":"timestamp-millis"}`, + writer: `"int"`, + value: 5000, + want: time.UnixMilli(5000).UTC(), + }, + { + name: "Int Promote Long Time micros", + reader: `{"type":"long","logicalType":"timestamp-micros"}`, + writer: `"int"`, + value: 5000, + want: time.UnixMicro(5000).UTC(), + }, + { + name: "Int Promote Long Time micros", + reader: `{"type":"long","logicalType":"time-micros"}`, + writer: `"int"`, + value: 5000, + want: 5000 * time.Microsecond, + }, { name: "Int Promote Float", reader: `"float"`, @@ -334,6 +356,16 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { value: "foo", want: []byte("foo"), }, + { + // I'm not sure about this edge cases; + // I took the reverse path and tried to find a Decimal that can be encoded to + // a binary that is a valid UTF-8 sequence. + name: "String Promote Bytes With Logical Decimal", + reader: `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`, + writer: `"string"`, + value: "d", + want: big.NewRat(1, 1), + }, { name: "Bytes Promote String", reader: `"string"`, @@ -671,35 +703,49 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "type": "record", "name": "parent", "namespace": "org.hamba.avro", - "fields": [{ - "name": "a", - "type": "int" - }, + "fields": [ { - "name": "b", + "name": "a", "type": { "type": "record", - "name": "test", + "name": "embed", + "namespace": "org.hamba.avro", "fields": [{ "name": "a", "type": "long" }] - }, - "default": {"a": 10} + } }, { - "name": "c", - "type": "test", + "name": "b", + "type": "embed", "default": {"a": 20} } ] }`, - writer: `{"type":"record", "name":"parent", "namespace": "org.hamba.avro", "fields":[{"name": "a", "type": "int"}]}`, - value: map[string]any{"a": 10}, + writer: `{ + "type": "record", + "name": "parent", + "namespace": "org.hamba.avro", + "fields": [ + { + "name": "a", + "type": { + "type": "record", + "name": "embed", + "namespace": "org.hamba.avro", + "fields": [{ + "name": "a", + "type": "long" + }] + } + } + ] + }`, + value: map[string]any{"a": map[string]any{"a": int64(10)}}, want: map[string]any{ - "a": 10, - "b": map[string]any{"a": int64(10)}, - "c": map[string]any{"a": int64(20)}, + "a": map[string]any{"a": int64(10)}, + "b": map[string]any{"a": int64(20)}, }, }, } @@ -725,3 +771,39 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { }) } } + +func TestSchemaCompatibility_ResolveWithRefs(t *testing.T) { + sch1 := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"} + ] + }`) + sch2 := avro.MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "bytes"} + ] + }`) + + r := avro.NewRefSchema(sch1.(*avro.RecordSchema)) + w := avro.NewRefSchema(sch2.(*avro.RecordSchema)) + + sc := avro.NewSchemaCompatibility() + + value := map[string]any{"a": []byte("foo")} + b, err := avro.Marshal(w, value) + assert.NoError(t, err) + + sch, err := sc.Resolve(r, w) + assert.NoError(t, err) + + var result any + err = avro.Unmarshal(sch, b, &result) + assert.NoError(t, err) + + want := map[string]any{"a": "foo"} + assert.Equal(t, want, result) +} diff --git a/schema_internal_test.go b/schema_internal_test.go index 7872eaf7..6d78e9b3 100644 --- a/schema_internal_test.go +++ b/schema_internal_test.go @@ -1,6 +1,7 @@ package avro import ( + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -383,3 +384,200 @@ func TestSchema_FingerprintUsingCaches(t *testing.T) { assert.Equal(t, want, value) assert.Equal(t, want, got) } + +func TestSchema_IsPromotable(t *testing.T) { + tests := []struct { + typ Type + wantOk bool + }{ + { + typ: Int, + wantOk: true, + }, + { + typ: Long, + wantOk: true, + }, + { + typ: Float, + wantOk: true, + }, + { + typ: String, + wantOk: true, + }, + { + typ: Bytes, + wantOk: true, + }, + { + typ: Double, + wantOk: false, + }, + { + typ: Boolean, + wantOk: false, + }, + { + typ: Null, + wantOk: false, + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + ok := isPromotable(test.typ) + assert.Equal(t, test.wantOk, ok) + }) + } +} + +func TestSchema_IsNative(t *testing.T) { + tests := []struct { + typ Type + wantOk bool + }{ + { + typ: Null, + wantOk: true, + }, + { + typ: Boolean, + wantOk: true, + }, + { + typ: Int, + wantOk: true, + }, + { + typ: Long, + wantOk: true, + }, + + { + typ: Float, + wantOk: true, + }, + { + typ: Double, + wantOk: true, + }, + + { + typ: Bytes, + wantOk: true, + }, + { + typ: String, + wantOk: true, + }, + { + typ: Record, + wantOk: false, + }, + { + typ: Array, + wantOk: false, + }, + { + typ: Map, + wantOk: false, + }, + { + typ: Fixed, + wantOk: false, + }, + { + typ: Enum, + wantOk: false, + }, + { + typ: Union, + wantOk: false, + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + ok := isNative(test.typ) + assert.Equal(t, test.wantOk, ok) + }) + } +} + +func TestSchema_FieldEncodeDefault(t *testing.T) { + schema := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string", "default": "bar"}, + {"name": "b", "type": "boolean"} + ] + }`).(*RecordSchema) + + fooEncoder := func(a any) ([]byte, error) { + return []byte("foo"), nil + } + barEncoder := func(a any) ([]byte, error) { + return []byte("bar"), nil + } + + assert.Equal(t, nil, schema.fields[0].encodedDef.Load()) + + _, err := schema.fields[0].encodeDefault(nil) + assert.Error(t, err) + + _, err = schema.fields[1].encodeDefault(fooEncoder) + assert.Error(t, err) + + def, err := schema.fields[0].encodeDefault(fooEncoder) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), def) + + def, err = schema.fields[0].encodeDefault(barEncoder) + assert.NoError(t, err) + assert.Equal(t, []byte("foo"), def) +} + +func TestSchema_CacheFingerprint(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + cacheFingerprint := cacheFingerprinter{} + assert.Panics(t, func() { + cacheFingerprint.fingerprint([]any{func() {}}) + }) + }) + + t.Run("promoted", func(t *testing.T) { + schema := NewPrimitiveSchema(Long, nil) + assert.Equal(t, schema.Fingerprint(), schema.CacheFingerprint()) + + schema = NewPrimitiveSchema(Long, nil) + schema.actual = Int + assert.NotEqual(t, schema.Fingerprint(), schema.CacheFingerprint()) + }) + + t.Run("record", func(t *testing.T) { + schema1 := MustParse(`{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "string"}, + {"name": "b", "type": "boolean"} + ] + }`).(*RecordSchema) + + schema2 := MustParse(`{ + "type": "record", + "name": "test2", + "fields" : [ + {"name": "a", "type": "string", "default": "bar"}, + {"name": "b", "type": "boolean", "default": false} + ] + }`).(*RecordSchema) + + assert.Equal(t, schema1.Fingerprint(), schema1.CacheFingerprint()) + assert.NotEqual(t, schema1.CacheFingerprint(), schema2.CacheFingerprint()) + }) +} From 51224e680114f3d8d06dbeb07557762795c2a367 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Thu, 21 Dec 2023 19:11:51 +0100 Subject: [PATCH 15/25] fix: fix resolve record --- schema_compatibility.go | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/schema_compatibility.go b/schema_compatibility.go index 57221753..f0cae86a 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -373,34 +373,42 @@ func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, erro fields := make([]*Field, 0) seen := make(map[string]struct{}) - for _, field := range w.Fields() { - f, _ := NewField(field.Name(), field.Type(), WithAliases(field.aliases), WithOrder(field.order)) - f.def = field.def - f.hasDef = field.hasDef - rf, ok := c.getField(r.Fields(), field, func(gfo *getFieldOptions) { + for _, wf := range w.Fields() { + rf, ok := c.getField(r.Fields(), wf, func(gfo *getFieldOptions) { gfo.elemAlias = true }) if !ok { + f, _ := NewField(wf.Name(), wf.Type(), WithAliases(wf.aliases), WithOrder(wf.order)) + // I believe def is read only it can be copied even if it's a like-pointer type; + // data race should not occur. + f.def = wf.def + f.hasDef = wf.hasDef f.action = FieldDrain fields = append(fields, f) continue } - ft, err := c.resolve(rf.Type(), f.Type()) + + ft, err := c.resolve(rf.Type(), wf.Type()) if err != nil { return nil, err } - rf.typ = ft - fields = append(fields, rf) + f, _ := NewField(rf.Name(), ft, WithAliases(rf.aliases), WithOrder(rf.order)) + f.def = rf.def + f.hasDef = rf.hasDef + fields = append(fields, f) + seen[rf.Name()] = struct{}{} } - for _, field := range r.Fields() { - if _, ok := seen[field.Name()]; ok { + for _, rf := range r.Fields() { + // check if seen in writer's record + if _, ok := seen[rf.Name()]; ok { continue } - f, _ := NewField(field.Name(), field.Type(), WithAliases(field.aliases), WithOrder(field.order)) - f.def = field.def - f.hasDef = field.hasDef + + f, _ := NewField(rf.Name(), rf.Type(), WithAliases(rf.aliases), WithOrder(rf.order)) + f.def = rf.def + f.hasDef = rf.hasDef f.action = FieldSetDefault fields = append(fields, f) } From ade5d387214b98809e7e7bfd0b29ab70654e67b4 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Thu, 21 Dec 2023 19:12:24 +0100 Subject: [PATCH 16/25] improve record cache fingerprint --- schema.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema.go b/schema.go index 50779d84..c3161632 100644 --- a/schema.go +++ b/schema.go @@ -657,7 +657,7 @@ func (s *RecordSchema) CacheFingerprint() [32]byte { data := make([]any, 0) for _, field := range s.fields { if field.Default() != nil { - data = append(data, field.Default()) + data = append(data, field.Name(), field.Default()) } } if len(data) == 0 { From e364790f9425dec2616c390c580fb0b64aa87a7d Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Tue, 9 Jan 2024 13:49:17 +0100 Subject: [PATCH 17/25] fix(default encoder): better handling of nullDefault --- codec_default.go | 3 +++ resolver.go | 1 - schema.go | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/codec_default.go b/codec_default.go index 43fff385..72831931 100644 --- a/codec_default.go +++ b/codec_default.go @@ -10,6 +10,9 @@ import ( func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) ValDecoder { fn := func(def any) ([]byte, error) { defaultType := reflect2.TypeOf(def) + if defaultType == nil { + defaultType = reflect2.TypeOf((*null)(nil)) + } defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) if defaultType.LikePtr() { defaultEncoder = &onePtrEncoder{defaultEncoder} diff --git a/resolver.go b/resolver.go index 7f67fa84..aabf4663 100644 --- a/resolver.go +++ b/resolver.go @@ -22,7 +22,6 @@ func NewTypeResolver() *TypeResolver { // Register basic types r.Register(string(Null), &null{}) - r.Register(string(Null), null{}) r.Register(string(Int), int8(0)) r.Register(string(Int), int16(0)) r.Register(string(Int), int32(0)) diff --git a/schema.go b/schema.go index c3161632..4cc1bdbd 100644 --- a/schema.go +++ b/schema.go @@ -17,7 +17,7 @@ import ( jsoniter "github.com/json-iterator/go" ) -var nullDefault null = struct{}{} +var nullDefault = struct{}{} var ( schemaReserved = []string{ @@ -778,7 +778,7 @@ func (f *Field) encodeDefault(encode func(any) ([]byte, error)) ([]byte, error) if encode == nil { return nil, fmt.Errorf("avro: failed to encode '%s' default value", f.name) } - b, err := encode(f.def) + b, err := encode(f.Default()) if err != nil { return nil, err } From bc0e276f87d10f37fb6573bd29df49a956166765 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Tue, 9 Jan 2024 14:06:35 +0100 Subject: [PATCH 18/25] fix: record cache fingerprint --- schema.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema.go b/schema.go index 4cc1bdbd..1b8592f3 100644 --- a/schema.go +++ b/schema.go @@ -656,7 +656,7 @@ func (s *RecordSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) { func (s *RecordSchema) CacheFingerprint() [32]byte { data := make([]any, 0) for _, field := range s.fields { - if field.Default() != nil { + if field.HasDefault() { data = append(data, field.Name(), field.Default()) } } From 2ec1b6067dc4c816401305c6cbf9288c5a357b80 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Tue, 9 Jan 2024 14:09:20 +0100 Subject: [PATCH 19/25] rename FieldDrain by FieldIgnore --- codec_default_internal_test.go | 4 ++-- codec_record.go | 4 ++-- schema.go | 2 +- schema_compatibility.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/codec_default_internal_test.go b/codec_default_internal_test.go index 9adf1b1a..27788039 100644 --- a/codec_default_internal_test.go +++ b/codec_default_internal_test.go @@ -57,7 +57,7 @@ func TestDecoder_InvalidDefault(t *testing.T) { require.Error(t, err) } -func TestDecoder_DrainField(t *testing.T) { +func TestDecoder_IgnoreField(t *testing.T) { defer ConfigTeardown() // write schema @@ -81,7 +81,7 @@ func TestDecoder_DrainField(t *testing.T) { ] }`) - schema.(*RecordSchema).Fields()[0].action = FieldDrain + schema.(*RecordSchema).Fields()[0].action = FieldIgnore schema.(*RecordSchema).Fields()[1].action = FieldSetDefault type TestRecord struct { diff --git a/codec_record.go b/codec_record.go index 24e1f72e..751e8679 100644 --- a/codec_record.go +++ b/codec_record.go @@ -60,7 +60,7 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields := make([]*structFieldDecoder, 0, len(rec.Fields())) for _, field := range rec.Fields() { - if field.action == FieldDrain { + if field.action == FieldIgnore { fields = append(fields, &structFieldDecoder{ decoder: createSkipDecoder(field.Type()), }) @@ -256,7 +256,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields := make([]recordMapDecoderField, len(rec.Fields())) for i, field := range rec.Fields() { - if field.action == FieldDrain { + if field.action == FieldIgnore { fields[i] = recordMapDecoderField{ name: field.Name(), decoder: createSkipDecoder(field.Type()), diff --git a/schema.go b/schema.go index 1b8592f3..627b529a 100644 --- a/schema.go +++ b/schema.go @@ -100,7 +100,7 @@ type Action string // Action type constants. const ( - FieldDrain Action = "drain" + FieldIgnore Action = "ignore" FieldSetDefault Action = "set_default" ) diff --git a/schema_compatibility.go b/schema_compatibility.go index f0cae86a..fd0c1a98 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -383,7 +383,7 @@ func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, erro // data race should not occur. f.def = wf.def f.hasDef = wf.hasDef - f.action = FieldDrain + f.action = FieldIgnore fields = append(fields, f) continue } From 17926aca2e72f10ba66aa71072da505553d45d77 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Wed, 10 Jan 2024 11:52:11 +0100 Subject: [PATCH 20/25] fix: record cache fingerprint --- schema.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/schema.go b/schema.go index 627b529a..0a6aa896 100644 --- a/schema.go +++ b/schema.go @@ -90,9 +90,8 @@ func isPromotable(typ Type) bool { case Int, Long, Float, String, Bytes: return true default: + return false } - - return false } // Action is a field action used during decoding process. @@ -656,14 +655,13 @@ func (s *RecordSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) { func (s *RecordSchema) CacheFingerprint() [32]byte { data := make([]any, 0) for _, field := range s.fields { - if field.HasDefault() { + if field.HasDefault() && field.action == FieldSetDefault { data = append(data, field.Name(), field.Default()) } } if len(data) == 0 { return s.Fingerprint() } - data = append(data, s.Fingerprint()) return s.cacheFingerprinter.fingerprint(data) } From 1dc0276f3a0d0162d342c827fba8c7c3b05213c4 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Wed, 10 Jan 2024 13:25:57 +0100 Subject: [PATCH 21/25] fix: record cache fingerprint --- config_internal_test.go | 51 +++++++++++++++++++++++++++++++++++++++++ schema.go | 3 +++ 2 files changed, 54 insertions(+) diff --git a/config_internal_test.go b/config_internal_test.go index e1782ca0..674bd543 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -44,6 +44,57 @@ func TestConfig_ReusesDecoders(t *testing.T) { assert.Same(t, dec1, dec2) } +func TestConfig_ReusesDecoders_WithRecordFieldActions(t *testing.T) { + type testObj struct { + A int64 `avro:"a"` + B string `avro:"b"` + } + sch := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "long"}, + {"name": "a", "type": "string", "default": "foo"} + ] + }` + typ := reflect2.TypeOfPtr(&testObj{}) + + t.Run("set default", func(t *testing.T) { + api := Config{ + TagKey: "test", + BlockLength: 2, + }.Freeze() + cfg := api.(*frozenConfig) + + schema1 := MustParse(sch) + schema2 := MustParse(sch) + schema2.(*RecordSchema).Fields()[1].action = FieldSetDefault + + dec1 := cfg.DecoderOf(schema1, typ) + dec2 := cfg.DecoderOf(schema2, typ) + + assert.NotSame(t, dec1, dec2) + }) + + t.Run("ignore", func(t *testing.T) { + api := Config{ + TagKey: "test", + BlockLength: 2, + }.Freeze() + cfg := api.(*frozenConfig) + + schema1 := MustParse(sch) + schema1.(*RecordSchema).Fields()[0].action = FieldIgnore + schema2 := MustParse(sch) + + dec1 := cfg.DecoderOf(schema1, typ) + dec2 := cfg.DecoderOf(schema2, typ) + + assert.NotSame(t, dec1, dec2) + }) + +} + func TestConfig_DisableCache_DoesNotReuseDecoders(t *testing.T) { type testObj struct { A int64 `avro:"a"` diff --git a/schema.go b/schema.go index 0a6aa896..4587a761 100644 --- a/schema.go +++ b/schema.go @@ -658,6 +658,9 @@ func (s *RecordSchema) CacheFingerprint() [32]byte { if field.HasDefault() && field.action == FieldSetDefault { data = append(data, field.Name(), field.Default()) } + if field.action == FieldIgnore { + data = append(data, field.Name()+string(FieldIgnore)) + } } if len(data) == 0 { return s.Fingerprint() From 85f8d7cb8fb1c6db8c36bac8ab9c04a4d7e23f62 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Wed, 10 Jan 2024 13:46:47 +0100 Subject: [PATCH 22/25] clean up Co-authored-by: Nicholas Wiersma --- schema.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/schema.go b/schema.go index 4587a761..cbaf919d 100644 --- a/schema.go +++ b/schema.go @@ -80,9 +80,8 @@ func isNative(typ Type) bool { case Null, Boolean, Int, Long, Float, Double, Bytes, String: return true default: + return false } - - return false } func isPromotable(typ Type) bool { From b96aefc3d39038c7b0d989070a083704f9ad0c12 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Wed, 10 Jan 2024 13:50:05 +0100 Subject: [PATCH 23/25] clean up --- schema.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema.go b/schema.go index cbaf919d..243598b0 100644 --- a/schema.go +++ b/schema.go @@ -80,7 +80,7 @@ func isNative(typ Type) bool { case Null, Boolean, Int, Long, Float, Double, Bytes, String: return true default: - return false + return false } } From f2c19a2d2030e0ff0e23f53552c0dfeb1b54616a Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 12 Jan 2024 10:36:10 +0100 Subject: [PATCH 24/25] fix: codec default reader/writer usage --- codec_default.go | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/codec_default.go b/codec_default.go index 72831931..5225c616 100644 --- a/codec_default.go +++ b/codec_default.go @@ -18,11 +18,17 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va defaultEncoder = &onePtrEncoder{defaultEncoder} } w := cfg.borrowWriter() + defer cfg.returnWriter(w) + defaultEncoder.Encode(reflect2.PtrOf(def), w) if w.Error != nil { return nil, w.Error } - return w.Buffer(), nil + b := w.Buffer() + data := make([]byte, len(b)) + copy(data, b) + + return data, nil } b, err := field.encodeDefault(fn) @@ -30,19 +36,22 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va return &errorDecoder{err: fmt.Errorf("decode default: %w", err)} } return &defaultDecoder{ - defaultReader: cfg.borrowReader(b), - decoder: decoderOfType(cfg, field.Type(), typ), + data: b, + decoder: decoderOfType(cfg, field.Type(), typ), } } type defaultDecoder struct { - defaultReader *Reader - decoder ValDecoder + data []byte + decoder ValDecoder } // Decode implements ValDecoder. -func (d *defaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { - d.decoder.Decode(ptr, d.defaultReader) +func (d *defaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + rr := r.cfg.borrowReader(d.data) + defer r.cfg.returnReader(rr) + + d.decoder.Decode(ptr, rr) } var _ ValDecoder = &defaultDecoder{} From 35a64d950e4516d062eb5c5744948e0588bce6d0 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 12 Jan 2024 10:40:11 +0100 Subject: [PATCH 25/25] fix: bytes to string converter --- converter.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/converter.go b/converter.go index 0cc0bcd5..0a100a35 100644 --- a/converter.go +++ b/converter.go @@ -2,6 +2,7 @@ package avro import ( "fmt" + "unsafe" "github.com/modern-go/reflect2" ) @@ -71,10 +72,10 @@ func createStringConverter(typ Type) (func(*Reader) string, error) { case Bytes: return func(r *Reader) string { b := r.ReadBytes() - // TBD: update go.mod version to go 1.20 minimum - // runtime.KeepAlive(b) - // return unsafe.String(unsafe.SliceData(b), len(b)) - return string(b) + if len(b) == 0 { + return "" + } + return *(*string)(unsafe.Pointer(&b)) }, nil case String: return func(r *Reader) string { return r.ReadString() }, nil