diff --git a/codec_default.go b/codec_default.go index 59618c7d..264a1663 100644 --- a/codec_default.go +++ b/codec_default.go @@ -20,7 +20,9 @@ func createDefaultDecoder(cfg *frozenConfig, schema Schema, def any, typ reflect switch schema.Type() { case Null: - return &nullDefaultDecoder{} + return &nullDefaultDecoder{ + typ: typ, + } case Boolean: return &boolDefaultDecoder{ @@ -150,10 +152,13 @@ func (d *boolDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } type nullDefaultDecoder struct { + typ reflect2.Type } -func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - return +func (d *nullDefaultDecoder) Decode(ptr unsafe.Pointer, _ *Reader) { + if d.typ.IsNullable() { + d.typ.UnsafeSet(ptr, d.typ.UnsafeNew()) + } } type intDefaultDecoder struct { @@ -270,7 +275,6 @@ func (d *doubleDefaultDecoder) Decode(ptr unsafe.Pointer, r *Reader) { default: r.ReportError("decode default", "unsupported type") } - } type stringDefaultDecoder struct { diff --git a/codec_native.go b/codec_native.go index 1917e8e3..bf2d2c96 100644 --- a/codec_native.go +++ b/codec_native.go @@ -11,7 +11,7 @@ import ( ) func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { - actual := schema.(*PrimitiveSchema).actual + converter := resolveConverter(schema.(*PrimitiveSchema).actual) switch typ.Kind() { case reflect.Bool: @@ -60,9 +60,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Long { break } - return &longCodec[uint32]{ - promoter: getCodecPromoter[uint32](actual), - } + return &longCodec[uint32]{convert: converter.toLong} case reflect.Int64: st := schema.Type() @@ -73,12 +71,12 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { case st == Long && lt == TimeMicros: // time.Duration return &timeMicrosCodec{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } case st == Long: return &longCodec[int64]{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } default: @@ -90,7 +88,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { break } return &float32Codec{ - promoter: getCodecPromoter[float32](actual), + convert: converter.toFloat, } case reflect.Float64: @@ -98,7 +96,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { break } return &float64Codec{ - promoter: getCodecPromoter[float64](actual), + convert: converter.toDouble, } case reflect.String: @@ -106,7 +104,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { break } return &stringCodec{ - promoter: getCodecPromoter[string](actual), + convert: converter.toString, } case reflect.Slice: @@ -115,7 +113,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { } return &bytesCodec{ sliceType: typ.(*reflect2.UnsafeSliceType), - promoter: getCodecPromoter[[]byte](actual), + convert: converter.toBytes, } case reflect.Struct: @@ -131,20 +129,19 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { case Istpy1Time && st == Long && lt == TimestampMillis: return ×tampMillisCodec{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } case Istpy1Time && st == Long && lt == TimestampMicros: return ×tampMicrosCodec{ - promoter: getCodecPromoter[int64](actual), + convert: converter.toLong, } case Istpy1Rat && st == Bytes && lt == Decimal: dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalCodec{ prec: dec.Precision(), scale: dec.Scale(), - promoter: getCodecPromoter[[]byte](actual), + convert: converter.toBytes, } default: @@ -165,7 +162,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &bytesDecimalPtrCodec{ prec: dec.Precision(), scale: dec.Scale(), - promoter: getCodecPromoter[[]byte](actual), + convert: converter.toBytes, } } @@ -369,13 +366,13 @@ type largeInt interface { } type longCodec[T largeInt] struct { - promoter *codecPromoter[T] + convert func(*Reader) int64 } func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { var v T - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = T(c.convert(r)) } else { v = T(r.ReadLong()) } @@ -387,13 +384,13 @@ func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { } type float32Codec struct { - promoter *codecPromoter[float32] + convert func(*Reader) float32 } func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { var v float32 - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = c.convert(r) } else { v = r.ReadFloat() } @@ -412,13 +409,13 @@ func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type float64Codec struct { - promoter *codecPromoter[float64] + convert func(*Reader) float64 } func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { var v float64 - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = c.convert(r) } else { v = r.ReadDouble() } @@ -430,13 +427,13 @@ func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) { } type stringCodec struct { - promoter *codecPromoter[string] + convert func(*Reader) string } func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { var v string - if c.promoter != nil { - v = c.promoter.promote(r) + if c.convert != nil { + v = c.convert(r) } else { v = r.ReadString() } @@ -449,17 +446,16 @@ func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { type bytesCodec struct { sliceType *reflect2.UnsafeSliceType - promoter *codecPromoter[[]byte] + convert func(*Reader) []byte } func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) { var b []byte - if c.promoter != nil { - b = c.promoter.promote(r) + if c.convert != nil { + b = c.convert(r) } else { b = r.ReadBytes() } - // b := r.ReadBytes() c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b)) } @@ -482,13 +478,13 @@ func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type timestampMillisCodec struct { - promoter *codecPromoter[int64] + convert func(*Reader) int64 } func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { var i int64 - if c.promoter != nil { - i = c.promoter.promote(r) + if c.convert != nil { + i = c.convert(r) } else { i = r.ReadLong() } @@ -503,13 +499,13 @@ func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type timestampMicrosCodec struct { - promoter *codecPromoter[int64] + convert func(*Reader) int64 } func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { var i int64 - if c.promoter != nil { - i = c.promoter.promote(r) + if c.convert != nil { + i = c.convert(r) } else { i = r.ReadLong() } @@ -523,8 +519,7 @@ func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3)) } -type timeMillisCodec struct { -} +type timeMillisCodec struct{} func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { i := r.ReadInt() @@ -537,13 +532,13 @@ func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type timeMicrosCodec struct { - promoter *codecPromoter[int64] + convert func(*Reader) int64 } func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { var i int64 - if c.promoter != nil { - i = c.promoter.promote(r) + if c.convert != nil { + i = c.convert(r) } else { i = r.ReadLong() } @@ -558,19 +553,18 @@ func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { var one = big.NewInt(1) type bytesDecimalCodec struct { - prec int - scale int - promoter *codecPromoter[[]byte] + prec int + scale int + convert func(*Reader) []byte } func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { var b []byte - if c.promoter != nil { - b = c.promoter.promote(r) + if c.convert != nil { + b = c.convert(r) } else { b = r.ReadBytes() } - // b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } @@ -611,19 +605,19 @@ func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type bytesDecimalPtrCodec struct { - prec int - scale int - promoter *codecPromoter[[]byte] + prec int + scale int + convert func(*Reader) []byte } func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) { var b []byte - if c.promoter != nil { - b = c.promoter.promote(r) + if c.convert != nil { + b = c.convert(r) } else { b = r.ReadBytes() } - // b := r.ReadBytes() + if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } diff --git a/codec_promoter.go b/codec_promoter.go deleted file mode 100644 index 5776121f..00000000 --- a/codec_promoter.go +++ /dev/null @@ -1,64 +0,0 @@ -package avro - -import ( - "reflect" - - "github.com/modern-go/reflect2" -) - -func getCodecPromoter[T any](actual Type) *codecPromoter[T] { - if actual == "" { - return nil - } - - return &codecPromoter[T]{actual: actual} -} - -type codecPromoter[T any] struct { - actual Type -} - -func (p *codecPromoter[T]) promote(r *Reader) (t T) { - tt := reflect2.TypeOf(t) - - convert := func(typ reflect2.Type, obj any) (t T) { - if !reflect.TypeOf(obj).ConvertibleTo(typ.Type1()) { - r.ReportError("decode promotable", "unsupported type") - // return zero value - return t - } - return reflect.ValueOf(obj).Convert(typ.Type1()).Interface().(T) - } - - switch p.actual { - case Int: - var obj int32 - (&intCodec[int32]{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case Long: - var obj int64 - (&longCodec[int64]{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case Float: - var obj float32 - (&float32Codec{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case String: - var obj string - (&stringCodec{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - case Bytes: - var obj []byte - (&bytesCodec{}).Decode(reflect2.PtrOf(&obj), r) - t = convert(tt, obj) - - default: - r.ReportError("decode promotable", "unsupported actual type") - } - - return t -} diff --git a/codec_record.go b/codec_record.go index e62068aa..7e04a6f0 100644 --- a/codec_record.go +++ b/codec_record.go @@ -263,6 +263,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields[i] = recordMapDecoderField{ name: field.Name(), decoder: createSkipDecoder(field.Type()), + skip: true, } continue } @@ -273,11 +274,6 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec name: field.Name(), decoder: createDefaultDecoder(cfg, field.Type(), field.def, mapType.Elem()), } - } else { - fields[i] = recordMapDecoderField{ - name: field.Name(), - decoder: createSkipDecoder(field.Type()), - } } continue @@ -299,6 +295,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec type recordMapDecoderField struct { name string decoder ValDecoder + skip bool } type recordMapDecoder struct { @@ -315,6 +312,9 @@ func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { for _, field := range d.fields { elem := d.elemType.UnsafeNew() field.decoder.Decode(elem, r) + if field.skip { + continue + } d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem) } diff --git a/converter.go b/converter.go new file mode 100644 index 00000000..fa4b47a7 --- /dev/null +++ b/converter.go @@ -0,0 +1,99 @@ +package avro + +import ( + "fmt" + + "github.com/modern-go/reflect2" +) + +type converter struct { + toLong func(*Reader) int64 + toFloat func(*Reader) float32 + toDouble func(*Reader) float64 + toString func(*Reader) string + toBytes func(*Reader) []byte +} + +// resolveConverter returns a set of converter functions based on the actual type. +// Depending on the actual type value, some converter functions may be nil; +// thus, the downstream caller must first check the converter function value. +func resolveConverter(typ Type) converter { + cv := converter{} + + cv.toLong, _ = createLongConverter(typ) + cv.toFloat, _ = createFloatConverter(typ) + cv.toDouble, _ = createDoubleConverter(typ) + cv.toString, _ = createStringConverter(typ) + cv.toBytes, _ = createBytesConverter(typ) + + return cv +} + +func createLongConverter(typ Type) (func(*Reader) int64, error) { + switch typ { + case Int: + return func(r *Reader) int64 { return int64(r.ReadInt()) }, nil + case Long: + return func(r *Reader) int64 { return r.ReadLong() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createFloatConverter(typ Type) (func(*Reader) float32, error) { + switch typ { + case Int: + return func(r *Reader) float32 { return float32(r.ReadInt()) }, nil + case Long: + return func(r *Reader) float32 { return float32(r.ReadLong()) }, nil + case Float: + return func(r *Reader) float32 { return r.ReadFloat() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createDoubleConverter(typ Type) (func(*Reader) float64, error) { + switch typ { + case Int: + return func(r *Reader) float64 { return float64(r.ReadInt()) }, nil + case Long: + return func(r *Reader) float64 { return float64(r.ReadLong()) }, nil + case Float: + return func(r *Reader) float64 { return float64(r.ReadFloat()) }, nil + case Double: + return func(r *Reader) float64 { return r.ReadDouble() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createStringConverter(typ Type) (func(*Reader) string, error) { + switch typ { + case Bytes: + return func(r *Reader) string { + b := r.ReadBytes() + // TBD: update go.mod version to go 1.20 minimum + // runtime.KeepAlive(b) // TBD: I guess this line is required? + // return unsafe.String(unsafe.SliceData(b), len(b)) + return string(b) + }, nil + case String: + return func(r *Reader) string { return r.ReadString() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} + +func createBytesConverter(typ Type) (func(*Reader) []byte, error) { + switch typ { + case String: + return func(r *Reader) []byte { + return reflect2.UnsafeCastString(r.ReadString()) + }, nil + case Bytes: + return func(r *Reader) []byte { return r.ReadBytes() }, nil + default: + return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + } +} diff --git a/reader_generic.go b/reader_generic.go index b7cfa02c..79e5c38c 100644 --- a/reader_generic.go +++ b/reader_generic.go @@ -2,16 +2,15 @@ package avro import ( "fmt" - "log" "reflect" "time" ) // ReadNext reads the next Avro element as a generic interface. func (r *Reader) ReadNext(schema Schema) any { - var rp ReaderPromoter = r + var rp iReaderPromoter = r if sch, ok := schema.(*PrimitiveSchema); ok && sch.actual != "" { - rp = &readerPromoter{r: r, actual: sch.actual, current: sch.Type()} + rp = newReaderPromoter(sch.actual, r) } var ls LogicalSchema @@ -20,8 +19,6 @@ func (r *Reader) ReadNext(schema Schema) any { ls = lts.Logical() } - log.Println("ls", ls) - switch schema.Type() { case Boolean: return r.ReadBool() diff --git a/reader_promoter.go b/reader_promoter.go index f89b2f86..b5575eed 100644 --- a/reader_promoter.go +++ b/reader_promoter.go @@ -1,10 +1,6 @@ package avro -import ( - "reflect" -) - -type ReaderPromoter interface { +type iReaderPromoter interface { ReadLong() int64 ReadFloat() float32 ReadDouble() float64 @@ -13,90 +9,59 @@ type ReaderPromoter interface { } type readerPromoter struct { - actual, current Type - r *Reader + actual Type + r *Reader + converter } -var _ ReaderPromoter = &readerPromoter{} - -var promotedInvalid = struct{}{} - -func (p *readerPromoter) readActual() any { - switch p.actual { - case Int: - return p.r.ReadInt() - - case Long: - return p.r.ReadLong() - - case Float: - return p.r.ReadFloat() - - case String: - return p.r.ReadString() - - case Bytes: - return p.r.ReadBytes() - - default: - p.r.ReportError("decode promotable", "unsupported actual type") - return promotedInvalid +func newReaderPromoter(actual Type, r *Reader) *readerPromoter { + rp := &readerPromoter{ + actual: actual, + r: r, + converter: resolveConverter(actual), } + + return rp } +var _ iReaderPromoter = &readerPromoter{} + func (p *readerPromoter) ReadLong() int64 { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(int64) + if p.toLong != nil { + return p.toLong(p.r) } - return 0 + return p.r.ReadLong() } func (p *readerPromoter) ReadFloat() float32 { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(float32) + if p.toFloat != nil { + return p.toFloat(p.r) } - return 0 + return p.r.ReadFloat() } func (p *readerPromoter) ReadDouble() float64 { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(float64) + if p.toDouble != nil { + return p.toDouble(p.r) } - return 0 + return p.r.ReadDouble() } func (p *readerPromoter) ReadString() string { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).(string) + if p.toString != nil { + return p.toString(p.r) } - return "" + return p.r.ReadString() } func (p *readerPromoter) ReadBytes() []byte { - if v := p.readActual(); v != promotedInvalid { - return p.promote(v, p.current).([]byte) - } - - return nil -} - -func (p *readerPromoter) promote(obj any, st Type) (t any) { - switch st { - case Long: - return int64(reflect.ValueOf(obj).Int()) - case Float: - return float32(reflect.ValueOf(obj).Int()) - case Double: - return float64(reflect.ValueOf(obj).Float()) - case String: - return string(reflect.ValueOf(obj).Bytes()) - case Bytes: - return []byte(reflect.ValueOf(obj).String()) + if p.toBytes != nil { + return p.toBytes(p.r) } - return obj + return p.r.ReadBytes() } diff --git a/schema.go b/schema.go index c9e174ea..34c67830 100644 --- a/schema.go +++ b/schema.go @@ -75,9 +75,30 @@ const ( Duration LogicalType = "duration" ) +func isNative(typ Type) bool { + switch typ { + case Null, Boolean, Int, Long, Float, Double, Bytes, String: + return true + default: + } + + return false +} + +func isPromotable(typ Type) bool { + switch typ { + case Int, Long, Float, String, Bytes: + return true + default: + } + + return false +} + // Action is a field action used during decoding process. type Action string +// Action type constants const ( FieldDrain Action = "drain" FieldSetDefault Action = "set_default" @@ -397,7 +418,7 @@ type PrimitiveSchema struct { // actual presents the actual type of the encoded value // which can be promoted to schema current type. - // This field is only used in the context of write read schema resolution. + // It's only used in the context of write-read schema resolution. actual Type } @@ -666,6 +687,7 @@ func (f *Field) Name() string { return f.name } +// Action returns the action of a field. func (f *Field) Action() Action { return f.action } diff --git a/schema_compatibility.go b/schema_compatibility.go index 8111a375..c80ba9b2 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -268,26 +268,10 @@ func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*get return nil, false } -func isNative(typ Type) bool { - switch typ { - case Null, Boolean, Int, Long, Float, Double, Bytes, String: - return true - default: - } - - return false -} - -func isPromotable(typ Type) bool { - switch typ { - case Int, Long, Float, String, Bytes: - return true - default: - } - - return false -} - +// Resolve returns a composite schema that allows decoding data written by the writer schema, +// and makes necessary adjustments to support the reader schema. +// +// It fails if the writer and reader schemas are not compatible. func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { if reader.Type() == Ref { reader = reader.(*RefSchema).Schema() @@ -300,13 +284,15 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { return nil, err } + return c.resolve(reader, writer) +} + +func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { if writer.Type() != reader.Type() { if isPromotable(writer.Type()) { - // TODO clean up - r := *reader.(*PrimitiveSchema) + r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical()) r.actual = writer.Type() - - return &r, nil + return r, nil } if reader.Type() == Union { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index c7a3f84b..3346a181 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -337,6 +337,10 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "type": "array", "items": "int" }, "default":[1, 2, 3, 4] + },{ + "name": "g", + "type": ["null", "string"], + "default": null } ] }`) @@ -354,6 +358,7 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { B string `avro:"b"` D []int32 `avro:"d"` K []byte `avro:"k"` + G *string `avro:"g"` } a2 := A2{} @@ -363,5 +368,5 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { t.Fatalf("unmarshal error %v", err) } - log.Printf("result: %+v %+v %T %+v", a2, a2.A, a2.A, string(a2.K)) + log.Printf("result: %+v", a2) }