From 10b955da832308ff071a72dac8fc1483955fe4ab Mon Sep 17 00:00:00 2001
From: Reda Laanait
Date: Fri, 17 Nov 2023 19:41:29 +0100
Subject: [PATCH] fix: type promotion POC

---
Review note: the added code below was revised during review; the blob ids on
the "index" lines are carried over from the original patch and are
informational only.

 codec_promoter.go            |  76 ++++++++++++++++++++
 reader_promoter.go           | 135 +++++++++++++++++++++++++++++++++++
 schema.go                    |   5 ++
 schema_compatibility.go      |  15 +++-
 schema_compatibility_test.go |  61 ++++++++++------
 5 files changed, 271 insertions(+), 21 deletions(-)
 create mode 100644 codec_promoter.go
 create mode 100644 reader_promoter.go

diff --git a/codec_promoter.go b/codec_promoter.go
new file mode 100644
index 00000000..5776121f
--- /dev/null
+++ b/codec_promoter.go
@@ -0,0 +1,76 @@
+package avro
+
+import (
+	"reflect"
+
+	"github.com/modern-go/reflect2"
+)
+
+// getCodecPromoter returns a promoter that decodes values written as actual
+// and converts them to T. It returns nil when actual is empty, so callers
+// must check the result before using it.
+func getCodecPromoter[T any](actual Type) *codecPromoter[T] {
+	if actual == "" {
+		return nil
+	}
+
+	return &codecPromoter[T]{actual: actual}
+}
+
+// codecPromoter decodes a value encoded with the writer type actual and
+// converts it to the Go type T expected by the reader.
+type codecPromoter[T any] struct {
+	actual Type
+}
+
+// promote reads one value of type p.actual from r and returns it converted
+// to T. On failure it reports an error on r and returns the zero value of T.
+func (p *codecPromoter[T]) promote(r *Reader) (t T) {
+	tt := reflect2.TypeOf(t)
+
+	convert := func(typ reflect2.Type, obj any) (t T) {
+		src, dst := reflect.TypeOf(obj), typ.Type1()
+		// Go permits integer-to-string conversion (it yields a one-rune
+		// string), which is never a valid promotion, so reject it here.
+		intToString := dst.Kind() == reflect.String &&
+			(src.Kind() == reflect.Int32 || src.Kind() == reflect.Int64)
+		if intToString || !src.ConvertibleTo(dst) {
+			r.ReportError("decode promotable", "unsupported type")
+			// return zero value
+			return t
+		}
+		return reflect.ValueOf(obj).Convert(dst).Interface().(T)
+	}
+
+	switch p.actual {
+	case Int:
+		var obj int32
+		(&intCodec[int32]{}).Decode(reflect2.PtrOf(&obj), r)
+		t = convert(tt, obj)
+
+	case Long:
+		var obj int64
+		(&longCodec[int64]{}).Decode(reflect2.PtrOf(&obj), r)
+		t = convert(tt, obj)
+
+	case Float:
+		var obj float32
+		(&float32Codec{}).Decode(reflect2.PtrOf(&obj), r)
+		t = convert(tt, obj)
+
+	case String:
+		var obj string
+		(&stringCodec{}).Decode(reflect2.PtrOf(&obj), r)
+		t = convert(tt, obj)
+
+	case Bytes:
+		var obj []byte
+		(&bytesCodec{}).Decode(reflect2.PtrOf(&obj), r)
+		t = convert(tt, obj)
+
+	default:
+		r.ReportError("decode promotable", "unsupported actual type")
+	}
+
+	return t
+}
diff --git a/reader_promoter.go b/reader_promoter.go
new file mode 100644
index 00000000..f89b2f86
--- /dev/null
+++ b/reader_promoter.go
@@ -0,0 +1,135 @@
+package avro
+
+import (
+	"reflect"
+)
+
+// ReaderPromoter reads a value encoded with one primitive type and returns
+// it promoted to the primitive type named by the method.
+type ReaderPromoter interface {
+	ReadLong() int64
+	ReadFloat() float32
+	ReadDouble() float64
+	ReadString() string
+	ReadBytes() []byte
+}
+
+// readerPromoter implements ReaderPromoter on top of a Reader: it decodes
+// a value of type actual from r and promotes it to type current.
+type readerPromoter struct {
+	actual, current Type
+	r               *Reader
+}
+
+var _ ReaderPromoter = (*readerPromoter)(nil)
+
+// promotedInvalid is the sentinel readActual returns after reporting an
+// unsupported actual type.
+var promotedInvalid = struct{}{}
+
+// readActual decodes one value of type p.actual from the underlying reader.
+// It reports an error and returns promotedInvalid for unsupported types.
+func (p *readerPromoter) readActual() any {
+	switch p.actual {
+	case Int:
+		return p.r.ReadInt()
+
+	case Long:
+		return p.r.ReadLong()
+
+	case Float:
+		return p.r.ReadFloat()
+
+	case String:
+		return p.r.ReadString()
+
+	case Bytes:
+		return p.r.ReadBytes()
+
+	default:
+		p.r.ReportError("decode promotable", "unsupported actual type")
+		return promotedInvalid
+	}
+}
+
+// readPromoted decodes the actual value, promotes it to p.current and
+// returns it as a T. When the actual type is unsupported or the promoted
+// value is not a T, it returns the zero value of T; the latter case is
+// reported on the reader instead of panicking on a failed type assertion.
+func readPromoted[T any](p *readerPromoter) (zero T) {
+	v := p.readActual()
+	if v == promotedInvalid {
+		// readActual has already reported the error.
+		return zero
+	}
+
+	out, ok := p.promote(v, p.current).(T)
+	if !ok {
+		p.r.ReportError("decode promotable", "unsupported promotion")
+		return zero
+	}
+
+	return out
+}
+
+// ReadLong reads the encoded value and returns it as a long.
+func (p *readerPromoter) ReadLong() int64 {
+	return readPromoted[int64](p)
+}
+
+// ReadFloat reads the encoded value and returns it as a float.
+func (p *readerPromoter) ReadFloat() float32 {
+	return readPromoted[float32](p)
+}
+
+// ReadDouble reads the encoded value and returns it as a double.
+func (p *readerPromoter) ReadDouble() float64 {
+	return readPromoted[float64](p)
+}
+
+// ReadString reads the encoded value and returns it as a string.
+func (p *readerPromoter) ReadString() string {
+	return readPromoted[string](p)
+}
+
+// ReadBytes reads the encoded value and returns it as bytes.
+func (p *readerPromoter) ReadBytes() []byte {
+	return readPromoted[[]byte](p)
+}
+
+// promote converts obj, as returned by readActual, to the Go representation
+// of the Avro type st. Values for which no promotion applies are returned
+// unchanged.
+func (p *readerPromoter) promote(obj any, st Type) any {
+	v := reflect.ValueOf(obj)
+
+	switch st {
+	case Long:
+		if v.CanInt() {
+			return v.Int()
+		}
+	case Float:
+		if v.CanInt() {
+			return float32(v.Int())
+		}
+	case Double:
+		// int, long and float may all be promoted to double; calling
+		// Value.Float on an integer kind would panic.
+		switch {
+		case v.CanInt():
+			return float64(v.Int())
+		case v.CanFloat():
+			return v.Float()
+		}
+	case String:
+		if b, ok := obj.([]byte); ok {
+			return string(b)
+		}
+	case Bytes:
+		if s, ok := obj.(string); ok {
+			return []byte(s)
+		}
+	}
+
+	return obj
+}
diff --git a/schema.go b/schema.go
index 6c67ed42..c9e174ea 100644
--- a/schema.go
+++ b/schema.go
@@ -394,6 +394,11 @@ type PrimitiveSchema struct {
 
 	typ     Type
 	logical LogicalSchema
+
+	// actual represents the actual type of the encoded value,
+	// which can be promoted to the schema's current type.
+	// This field is only used in the context of writer/reader schema resolution.
+	actual Type
 }
 
 // NewPrimitiveSchema creates a new PrimitiveSchema.
diff --git a/schema_compatibility.go b/schema_compatibility.go
index 402b4efe..8111a375 100644
--- a/schema_compatibility.go
+++ b/schema_compatibility.go
@@ -302,7 +302,20 @@ func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) {
 
 	if writer.Type() != reader.Type() {
 		if isPromotable(writer.Type()) {
-			return reader, nil
+			// Only a primitive reader can record the writer type; any other
+			// reader is returned unchanged, as before promotion support.
+			prim, ok := reader.(*PrimitiveSchema)
+			if !ok {
+				return reader, nil
+			}
+
+			// TODO clean up
+			// NOTE(review): shallow copy so the caller's schema is not
+			// mutated; confirm PrimitiveSchema is safe to copy by value.
+			r := *prim
+			r.actual = writer.Type()
+
+			return &r, nil
 		}
 
 		if reader.Type() == Union {
diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go
index 4b445b90..c7a3f84b 100644
--- a/schema_compatibility_test.go
+++ b/schema_compatibility_test.go
@@ -3,6 +3,7 @@ package avro_test
 import (
 	"log"
 	"testing"
+	"time"
 
 	"github.com/hamba/avro/v2"
 	"github.com/stretchr/testify/assert"
@@ -287,17 +288,23 @@ func TestSchemaCompatibility_Resolve(t *testing.T) {
 		},{
 			"name": "a",
 			"type": "int"
+		},
+		{
+			"name": "k",
+			"type": "string"
 		}]
 	}`)
 
 	type A1 struct {
-		A int32 `avro:"a"`
-		C int32 `avro:"c"`
+		A int32  `avro:"a"`
+		C int32  `avro:"c"`
+		K string `avro:"k"`
 	}
 
 	a1 := A1{
 		A: 10,
 		C: 1000000,
+		K: "K value",
 	}
 
 	b, err := avro.Marshal(sch1, a1)
@@ -309,21 +316,29 @@ func TestSchemaCompatibility_Resolve(t *testing.T) {
 		"name": "A",
 		"type": "record",
 		"fields": [
-		{
-			"name": "b",
-			"type": "string",
-			"default": "boo"
-		},{
-			"name": "aa",
-			"aliases": ["a"],
-			"type": "long"
-		},{
-			"name": "d",
-			"type": {
-				"type": "array", "items": "int"
+			{
+				"name": "k",
+				"type": "bytes"
 			},
-			"default":[1, 2, 3, 4]
-		}]
+			{
+				"name": "b",
+				"type": "string",
+				"default": "boo"
+			},{
+				"name": "aa",
+				"aliases": ["a"],
+				"type": {
+					"type": "long",
+					"logicalType":"time-micros"
+				}
+			},{
+				"name": "d",
+				"type": {
+					"type": "array", "items": "int"
+				},
+				"default":[1, 2, 3, 4]
+			}
+		]
 	}`)
 
 	sc := avro.NewSchemaCompatibility()
@@ -335,9 +350,10 @@ func TestSchemaCompatibility_Resolve(t *testing.T) {
 	}
 
 	type A2 struct {
-		A int64   `avro:"aa"`
-		B string  `avro:"b"`
-		D []int32 `avro:"d"`
+		A time.Duration `avro:"aa"`
+		B string        `avro:"b"`
+		D []int32       `avro:"d"`
+		K []byte        `avro:"k"`
 	}
 
 	a2 := A2{}
@@ -347,5 +363,10 @@ func TestSchemaCompatibility_Resolve(t *testing.T) {
 		t.Fatalf("unmarshal error %v", err)
 	}
 
-	log.Printf("result: %+v", a2)
+	assert.Equal(t, 10*time.Microsecond, a2.A)
+	assert.Equal(t, "boo", a2.B)
+	assert.Equal(t, []int32{1, 2, 3, 4}, a2.D)
+	assert.Equal(t, []byte("K value"), a2.K)
+
+	log.Printf("result: %+v %+v %T %+v", a2, a2.A, a2.A, string(a2.K))
 }