From 4ade51d5146dfeee0f73be45ba3f58773ef1dc26 Mon Sep 17 00:00:00 2001 From: Reda Laanait Date: Fri, 17 Nov 2023 19:41:29 +0100 Subject: [PATCH] fix: type promotion POC --- codec_native.go | 178 +++++++++++++++++++++++++++-------- codec_promoter.go | 64 +++++++++++++ reader_generic.go | 28 ++++-- reader_promoter.go | 102 ++++++++++++++++++++ schema.go | 5 + schema_compatibility.go | 6 +- schema_compatibility_test.go | 56 +++++++---- 7 files changed, 371 insertions(+), 68 deletions(-) create mode 100644 codec_promoter.go create mode 100644 reader_promoter.go diff --git a/codec_native.go b/codec_native.go index d821747c..1917e8e3 100644 --- a/codec_native.go +++ b/codec_native.go @@ -11,6 +11,8 @@ import ( ) func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { + actual := schema.(*PrimitiveSchema).actual + switch typ.Kind() { case reflect.Bool: if schema.Type() != Boolean { @@ -58,7 +60,9 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Long { break } - return &longCodec[uint32]{} + return &longCodec[uint32]{ + promoter: getCodecPromoter[uint32](actual), + } case reflect.Int64: st := schema.Type() @@ -68,10 +72,14 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &timeMillisCodec{} case st == Long && lt == TimeMicros: // time.Duration - return &timeMicrosCodec{} + return &timeMicrosCodec{ + promoter: getCodecPromoter[int64](actual), + } case st == Long: - return &longCodec[int64]{} + return &longCodec[int64]{ + promoter: getCodecPromoter[int64](actual), + } default: break @@ -81,25 +89,34 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Float { break } - return &float32Codec{} + return &float32Codec{ + promoter: getCodecPromoter[float32](actual), + } case reflect.Float64: if schema.Type() != Double { break } - return &float64Codec{} + return &float64Codec{ + promoter: getCodecPromoter[float64](actual), + } case reflect.String: if schema.Type() != String { break } - return &stringCodec{} + return &stringCodec{ + promoter: getCodecPromoter[string](actual), + } case reflect.Slice: if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes { break } - return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)} + return &bytesCodec{ + sliceType: typ.(*reflect2.UnsafeSliceType), + promoter: getCodecPromoter[[]byte](actual), + } case reflect.Struct: st := schema.Type() @@ -113,15 +130,22 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &dateCodec{} case Istpy1Time && st == Long && lt == TimestampMillis: - return ×tampMillisCodec{} + return ×tampMillisCodec{ + promoter: getCodecPromoter[int64](actual), + } case Istpy1Time && st == Long && lt == TimestampMicros: - return ×tampMicrosCodec{} + return ×tampMicrosCodec{ + promoter: getCodecPromoter[int64](actual), + } case Istpy1Rat && st == Bytes && lt == Decimal: dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()} + return &bytesDecimalCodec{ + prec: dec.Precision(), scale: dec.Scale(), + promoter: getCodecPromoter[[]byte](actual), + } default: break @@ -139,7 +163,10 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { } dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()} + return &bytesDecimalPtrCodec{ + prec: dec.Precision(), scale: dec.Scale(), + promoter: getCodecPromoter[[]byte](actual), + } } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} @@ -206,6 +233,7 @@ func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder { return &timeMillisCodec{} case st == Long && lt == TimeMicros: // time.Duration + return &timeMicrosCodec{} case st == Long: @@ -340,20 +368,37 @@ type largeInt interface { ~int32 | ~uint32 | int64 } -type longCodec[T largeInt] struct{} +type longCodec[T largeInt] struct { + promoter *codecPromoter[T] +} -func (*longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { - *((*T)(ptr)) = T(r.ReadLong()) +func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { + var v T + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = T(r.ReadLong()) + } + *((*T)(ptr)) = v } func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(int64(*((*T)(ptr)))) } -type float32Codec struct{} +type float32Codec struct { + promoter *codecPromoter[float32] +} + +func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { + var v float32 + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = r.ReadFloat() + } -func (*float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { - *((*float32)(ptr)) = r.ReadFloat() + *((*float32)(ptr)) = v } func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) { @@ -366,20 +411,36 @@ func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteDouble(float64(*((*float32)(ptr)))) } -type float64Codec struct{} +type float64Codec struct { + promoter *codecPromoter[float64] +} -func (*float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { - *((*float64)(ptr)) = r.ReadDouble() +func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { + var v float64 + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = r.ReadDouble() + } + *((*float64)(ptr)) = v } func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteDouble(*((*float64)(ptr))) } -type stringCodec struct{} +type stringCodec struct { + promoter *codecPromoter[string] +} -func (*stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { - *((*string)(ptr)) = r.ReadString() +func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { + var v string + if c.promoter != nil { + v = c.promoter.promote(r) + } else { + v = r.ReadString() + } + *((*string)(ptr)) = v } func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { @@ -388,10 +449,17 @@ func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { type bytesCodec struct { sliceType *reflect2.UnsafeSliceType + promoter *codecPromoter[[]byte] } func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) { - b := r.ReadBytes() + var b []byte + if c.promoter != nil { + b = c.promoter.promote(r) + } else { + b = r.ReadBytes() + } + // b := r.ReadBytes() c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b)) } @@ -413,10 +481,17 @@ func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteInt(int32(days)) } -type timestampMillisCodec struct{} +type timestampMillisCodec struct { + promoter *codecPromoter[int64] +} func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { - i := r.ReadLong() + var i int64 + if c.promoter != nil { + i = c.promoter.promote(r) + } else { + i = r.ReadLong() + } sec := i / 1e3 nsec := (i - sec*1e3) * 1e6 *((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC() @@ -427,10 +502,17 @@ func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(t.Unix()*1e3 + int64(t.Nanosecond()/1e6)) } -type timestampMicrosCodec struct{} +type timestampMicrosCodec struct { + promoter *codecPromoter[int64] +} func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { - i := r.ReadLong() + var i int64 + if c.promoter != nil { + i = c.promoter.promote(r) + } else { + i = r.ReadLong() + } sec := i / 1e6 nsec := (i - sec*1e6) * 1e3 *((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC() @@ -441,7 +523,8 @@ func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3)) } -type timeMillisCodec struct{} +type timeMillisCodec struct { +} func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { i := r.ReadInt() @@ -453,10 +536,17 @@ func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteInt(int32(d.Nanoseconds() / int64(time.Millisecond))) } -type timeMicrosCodec struct{} +type timeMicrosCodec struct { + promoter *codecPromoter[int64] +} func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { - i := r.ReadLong() + var i int64 + if c.promoter != nil { + i = c.promoter.promote(r) + } else { + i = r.ReadLong() + } *((*time.Duration)(ptr)) = time.Duration(i) * time.Microsecond } @@ -468,12 +558,19 @@ func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { var one = big.NewInt(1) type bytesDecimalCodec struct { - prec int - scale int + prec int + scale int + promoter *codecPromoter[[]byte] } func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { - b := r.ReadBytes() + var b []byte + if c.promoter != nil { + b = c.promoter.promote(r) + } else { + b = r.ReadBytes() + } + // b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } @@ -514,12 +611,19 @@ func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type bytesDecimalPtrCodec struct { - prec int - scale int + prec int + scale int + promoter *codecPromoter[[]byte] } func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) { - b := r.ReadBytes() + var b []byte + if c.promoter != nil { + b = c.promoter.promote(r) + } else { + b = r.ReadBytes() + } + // b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } diff --git a/codec_promoter.go b/codec_promoter.go new file mode 100644 index 00000000..5776121f --- /dev/null +++ b/codec_promoter.go @@ -0,0 +1,64 @@ +package avro + +import ( + "reflect" + + "github.com/modern-go/reflect2" +) + +func getCodecPromoter[T any](actual Type) *codecPromoter[T] { + if actual == "" { + return nil + } + + return &codecPromoter[T]{actual: actual} +} + +type codecPromoter[T any] struct { + actual Type +} + +func (p *codecPromoter[T]) promote(r *Reader) (t T) { + tt := reflect2.TypeOf(t) + + convert := func(typ reflect2.Type, obj any) (t T) { + if !reflect.TypeOf(obj).ConvertibleTo(typ.Type1()) { + r.ReportError("decode promotable", "unsupported type") + // return zero value + return t + } + return reflect.ValueOf(obj).Convert(typ.Type1()).Interface().(T) + } + + switch p.actual { + case Int: + var obj int32 + (&intCodec[int32]{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case Long: + var obj int64 + (&longCodec[int64]{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case Float: + var obj float32 + (&float32Codec{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case String: + var obj string + (&stringCodec{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + case Bytes: + var obj []byte + (&bytesCodec{}).Decode(reflect2.PtrOf(&obj), r) + t = convert(tt, obj) + + default: + r.ReportError("decode promotable", "unsupported actual type") + } + + return t +} diff --git a/reader_generic.go b/reader_generic.go index b75d240e..b7cfa02c 100644 --- a/reader_generic.go +++ b/reader_generic.go @@ -2,18 +2,26 @@ package avro import ( "fmt" + "log" "reflect" "time" ) // ReadNext reads the next Avro element as a generic interface. func (r *Reader) ReadNext(schema Schema) any { + var rp ReaderPromoter = r + if sch, ok := schema.(*PrimitiveSchema); ok && sch.actual != "" { + rp = &readerPromoter{r: r, actual: sch.actual, current: sch.Type()} + } + var ls LogicalSchema lts, ok := schema.(LogicalTypeSchema) if ok { ls = lts.Logical() } + log.Println("ls", ls) + switch schema.Type() { case Boolean: return r.ReadBool() @@ -34,34 +42,34 @@ func (r *Reader) ReadNext(schema Schema) any { if ls != nil { switch ls.Type() { case TimeMicros: - return time.Duration(r.ReadLong()) * time.Microsecond + return time.Duration(rp.ReadLong()) * time.Microsecond case TimestampMillis: - i := r.ReadLong() + i := rp.ReadLong() sec := i / 1e3 nsec := (i - sec*1e3) * 1e6 return time.Unix(sec, nsec).UTC() case TimestampMicros: - i := r.ReadLong() + i := rp.ReadLong() sec := i / 1e6 nsec := (i - sec*1e6) * 1e3 return time.Unix(sec, nsec).UTC() } } - return r.ReadLong() + return rp.ReadLong() case Float: - return r.ReadFloat() + return rp.ReadFloat() case Double: - return r.ReadDouble() + return rp.ReadDouble() case String: - return r.ReadString() + return rp.ReadString() case Bytes: if ls != nil && ls.Type() == Decimal { dec := ls.(*DecimalLogicalSchema) - return ratFromBytes(r.ReadBytes(), dec.Scale()) + return ratFromBytes(rp.ReadBytes(), dec.Scale()) } - return r.ReadBytes() + return rp.ReadBytes() case Record: fields := schema.(*RecordSchema).Fields() obj := make(map[string]any, len(fields)) @@ -97,7 +105,7 @@ func (r *Reader) ReadNext(schema Schema) any { return obj case Union: types := schema.(*UnionSchema).Types() - idx := int(r.ReadLong()) + idx := int(rp.ReadLong()) if idx < 0 || idx > len(types)-1 { r.ReportError("Read", "unknown union type") return nil diff --git a/reader_promoter.go b/reader_promoter.go new file mode 100644 index 00000000..f89b2f86 --- /dev/null +++ b/reader_promoter.go @@ -0,0 +1,102 @@ +package avro + +import ( + "reflect" +) + +type ReaderPromoter interface { + ReadLong() int64 + ReadFloat() float32 + ReadDouble() float64 + ReadString() string + ReadBytes() []byte +} + +type readerPromoter struct { + actual, current Type + r *Reader +} + +var _ ReaderPromoter = &readerPromoter{} + +var promotedInvalid = struct{}{} + +func (p *readerPromoter) readActual() any { + switch p.actual { + case Int: + return p.r.ReadInt() + + case Long: + return p.r.ReadLong() + + case Float: + return p.r.ReadFloat() + + case String: + return p.r.ReadString() + + case Bytes: + return p.r.ReadBytes() + + default: + p.r.ReportError("decode promotable", "unsupported actual type") + return promotedInvalid + } +} + +func (p *readerPromoter) ReadLong() int64 { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(int64) + } + + return 0 +} + +func (p *readerPromoter) ReadFloat() float32 { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(float32) + } + + return 0 +} + +func (p *readerPromoter) ReadDouble() float64 { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(float64) + } + + return 0 +} + +func (p *readerPromoter) ReadString() string { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).(string) + } + + return "" +} + +func (p *readerPromoter) ReadBytes() []byte { + if v := p.readActual(); v != promotedInvalid { + return p.promote(v, p.current).([]byte) + } + + return nil +} + +func (p *readerPromoter) promote(obj any, st Type) (t any) { + switch st { + case Long: + return int64(reflect.ValueOf(obj).Int()) + case Float: + return float32(reflect.ValueOf(obj).Int()) + case Double: + return float64(reflect.ValueOf(obj).Float()) + case String: + return string(reflect.ValueOf(obj).Bytes()) + case Bytes: + return []byte(reflect.ValueOf(obj).String()) + } + + return obj +} diff --git a/schema.go b/schema.go index 6c67ed42..c9e174ea 100644 --- a/schema.go +++ b/schema.go @@ -394,6 +394,11 @@ type PrimitiveSchema struct { typ Type logical LogicalSchema + + // actual presents the actual type of the encoded value + // which can be promoted to schema current type. + // This field is only used in the context of write read schema resolution. + actual Type } // NewPrimitiveSchema creates a new PrimitiveSchema. diff --git a/schema_compatibility.go b/schema_compatibility.go index 402b4efe..8111a375 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -302,7 +302,11 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { if writer.Type() != reader.Type() { if isPromotable(writer.Type()) { - return reader, nil + // TODO clean up + r := *reader.(*PrimitiveSchema) + r.actual = writer.Type() + + return &r, nil } if reader.Type() == Union { diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index 4b445b90..c7a3f84b 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -3,6 +3,7 @@ package avro_test import ( "log" "testing" + "time" "github.com/hamba/avro/v2" "github.com/stretchr/testify/assert" @@ -287,17 +288,23 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { },{ "name": "a", "type": "int" + }, + { + "name": "k", + "type": "string" }] }`) type A1 struct { - A int32 `avro:"a"` - C int32 `avro:"c"` + A int32 `avro:"a"` + C int32 `avro:"c"` + K string `avro:"k"` } a1 := A1{ A: 10, C: 1000000, + K: "K value", } b, err := avro.Marshal(sch1, a1) @@ -309,21 +316,29 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { "name": "A", "type": "record", "fields": [ - { - "name": "b", - "type": "string", - "default": "boo" - },{ - "name": "aa", - "aliases": ["a"], - "type": "long" - },{ - "name": "d", - "type": { - "type": "array", "items": "int" + { + "name": "k", + "type": "bytes" }, - "default":[1, 2, 3, 4] - }] + { + "name": "b", + "type": "string", + "default": "boo" + },{ + "name": "aa", + "aliases": ["a"], + "type": { + "type": "long", + "logicalType":"time-micros" + } + },{ + "name": "d", + "type": { + "type": "array", "items": "int" + }, + "default":[1, 2, 3, 4] + } + ] }`) sc := avro.NewSchemaCompatibility() @@ -335,9 +350,10 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { } type A2 struct { - A int64 `avro:"aa"` - B string `avro:"b"` - D []int32 `avro:"d"` + A time.Duration `avro:"aa"` + B string `avro:"b"` + D []int32 `avro:"d"` + K []byte `avro:"k"` } a2 := A2{} @@ -347,5 +363,5 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { t.Fatalf("unmarshal error %v", err) } - log.Printf("result: %+v", a2) + log.Printf("result: %+v %+v %T %+v", a2, a2.A, a2.A, string(a2.K)) }