diff --git a/.gitignore b/.gitignore index e69de29b..46573f87 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1,2 @@ +# Goland +.idea \ No newline at end of file diff --git a/codec.go b/codec.go index 156839a3..f2626222 100644 --- a/codec.go +++ b/codec.go @@ -69,14 +69,11 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder { if decoder != nil { return decoder } - - ptrType := typ.(*reflect2.UnsafePtrType) - decoder = decoderOfType(c, schema, ptrType.Elem()) - c.addDecoderToCache(schema.Fingerprint(), rtype, decoder) + decoder = c.processingGroup.processingDecoderOfType(c, schema, typ, decoderOfType) return decoder } -func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfType(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil { return dec } @@ -93,22 +90,22 @@ func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecod return createDecoderOfNative(schema, typ) case Record: - return createDecoderOfRecord(cfg, schema, typ) + return createDecoderOfRecord(cfg, p, schema, typ) case Ref: - return decoderOfType(cfg, schema.(*RefSchema).Schema(), typ) + return decoderOfType(cfg, p, schema.(*RefSchema).Schema(), typ) case Enum: return createDecoderOfEnum(schema, typ) case Array: - return createDecoderOfArray(cfg, schema, typ) + return createDecoderOfArray(cfg, p, schema, typ) case Map: - return createDecoderOfMap(cfg, schema, typ) + return createDecoderOfMap(cfg, p, schema, typ) case Union: - return createDecoderOfUnion(cfg, schema, typ) + return createDecoderOfUnion(cfg, p, schema, typ) case Fixed: return createDecoderOfFixed(schema, typ) @@ -123,18 +120,12 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder { if typ == nil { typ = reflect2.TypeOf((*null)(nil)) } - rtype := typ.RType() encoder := c.getEncoderFromCache(schema.Fingerprint(), rtype) if encoder != nil { return encoder } - - encoder = encoderOfType(c, schema, typ) - if typ.LikePtr() { - encoder = &onePtrEncoder{encoder} - } - c.addEncoderToCache(schema.Fingerprint(), rtype, encoder) + encoder = c.processingGroup.processingEncoderOfType(c, schema, typ, encoderOfType) return encoder } @@ -146,7 +137,7 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { e.enc.Encode(noescape(unsafe.Pointer(&ptr)), w) } -func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfType(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil { return enc } @@ -160,22 +151,22 @@ func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncod return createEncoderOfNative(schema, typ) case Record: - return createEncoderOfRecord(cfg, schema, typ) + return createEncoderOfRecord(cfg, p, schema, typ) case Ref: - return encoderOfType(cfg, schema.(*RefSchema).Schema(), typ) + return encoderOfType(cfg, p, schema.(*RefSchema).Schema(), typ) case Enum: return createEncoderOfEnum(schema, typ) case Array: - return createEncoderOfArray(cfg, schema, typ) + return createEncoderOfArray(cfg, p, schema, typ) case Map: - return createEncoderOfMap(cfg, schema, typ) + return createEncoderOfMap(cfg, p, schema, typ) case Union: - return createEncoderOfUnion(cfg, schema, typ) + return createEncoderOfUnion(cfg, p, schema, typ) case Fixed: return createEncoderOfFixed(schema, typ) diff --git a/codec_array.go b/codec_array.go index 658ae583..5679cca3 100644 --- a/codec_array.go +++ b/codec_array.go @@ -10,26 +10,26 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfArray(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { if typ.Kind() == reflect.Slice { - return decoderOfArray(cfg, schema, typ) + return decoderOfArray(cfg, p, schema, typ) } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfArray(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { if typ.Kind() == reflect.Slice { - return encoderOfArray(cfg, schema, typ) + return encoderOfArray(cfg, p, schema, typ) } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfArray(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { arr := schema.(*ArraySchema) sliceType := typ.(*reflect2.UnsafeSliceType) - decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem()) + decoder := decoderOfType(cfg, p, arr.Items(), sliceType.Elem()) return &arrayDecoder{typ: sliceType, decoder: decoder} } @@ -68,10 +68,10 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfArray(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { arr := schema.(*ArraySchema) sliceType := typ.(*reflect2.UnsafeSliceType) - encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem()) + encoder := encoderOfType(cfg, p, arr.Items(), sliceType.Elem()) return &arrayEncoder{ blockLength: cfg.getBlockLength(), diff --git a/codec_generic.go b/codec_generic.go index 4265288c..570f8b02 100644 --- a/codec_generic.go +++ b/codec_generic.go @@ -15,7 +15,8 @@ func genericDecode(schema Schema, r *Reader) any { r.ReportError("Read", err.Error()) return nil } - decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r) + dec := r.cfg.processingGroup.processingDecoderOfType(r.cfg, schema, rTyp, decoderOfType) + dec.Decode(rPtr, r) if r.Error != nil { return nil } diff --git a/codec_map.go b/codec_map.go index 58018ba6..8fc1ce0d 100644 --- a/codec_map.go +++ b/codec_map.go @@ -11,38 +11,38 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfMap(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { if typ.Kind() == reflect.Map { keyType := typ.(reflect2.MapType).Key() switch { case keyType.Kind() == reflect.String: - return decoderOfMap(cfg, schema, typ) + return decoderOfMap(cfg, p, schema, typ) case keyType.Implements(textUnmarshalerType): - return decoderOfMapUnmarshaler(cfg, schema, typ) + return decoderOfMapUnmarshaler(cfg, p, schema, typ) } } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfMap(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { if typ.Kind() == reflect.Map { keyType := typ.(reflect2.MapType).Key() switch { case keyType.Kind() == reflect.String: - return encoderOfMap(cfg, schema, typ) + return encoderOfMap(cfg, p, schema, typ) case keyType.Implements(textMarshalerType): - return encoderOfMapMarshaler(cfg, schema, typ) + return encoderOfMapMarshaler(cfg, p, schema, typ) } } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfMap(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { m := schema.(*MapSchema) mapType := typ.(*reflect2.UnsafeMapType) - decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) + decoder := decoderOfType(cfg, p, m.Values(), mapType.Elem()) return &mapDecoder{ mapType: mapType, @@ -82,10 +82,10 @@ func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func decoderOfMapUnmarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfMapUnmarshaler(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { m := schema.(*MapSchema) mapType := typ.(*reflect2.UnsafeMapType) - decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) + decoder := decoderOfType(cfg, p, m.Values(), mapType.Elem()) return &mapDecoderUnmarshaler{ mapType: mapType, @@ -141,10 +141,10 @@ func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfMap(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { m := schema.(*MapSchema) mapType := typ.(*reflect2.UnsafeMapType) - encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) + encoder := encoderOfType(cfg, p, m.Values(), mapType.Elem()) return &mapEncoder{ blockLength: cfg.getBlockLength(), @@ -186,10 +186,10 @@ func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { } } -func encoderOfMapMarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfMapMarshaler(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { m := schema.(*MapSchema) mapType := typ.(*reflect2.UnsafeMapType) - encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) + encoder := encoderOfType(cfg, p, m.Values(), mapType.Elem()) return &mapEncoderMarshaller{ blockLength: cfg.getBlockLength(), diff --git a/codec_ptr.go b/codec_ptr.go index fc94a68c..b6319928 100644 --- a/codec_ptr.go +++ b/codec_ptr.go @@ -7,11 +7,11 @@ import ( "github.com/modern-go/reflect2" ) -func decoderOfPtr(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfPtr(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - decoder := decoderOfType(cfg, schema, elemType) + decoder := decoderOfType(cfg, p, schema, elemType) return &dereferenceDecoder{typ: elemType, decoder: decoder} } @@ -34,11 +34,11 @@ func (d *dereferenceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) } -func encoderOfPtr(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfPtr(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - enc := encoderOfType(cfg, schema, elemType) + enc := encoderOfType(cfg, p, schema, elemType) return &dereferenceEncoder{typ: elemType, encoder: enc} } diff --git a/codec_record.go b/codec_record.go index 86295f20..62585d45 100644 --- a/codec_record.go +++ b/codec_record.go @@ -10,20 +10,20 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfRecord(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { switch typ.Kind() { case reflect.Struct: - return decoderOfStruct(cfg, schema, typ) + return decoderOfStruct(cfg, p, schema, typ) case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { break } - return decoderOfRecord(cfg, schema, typ) + return decoderOfRecord(cfg, p, schema, typ) case reflect.Ptr: - return decoderOfPtr(cfg, schema, typ) + return decoderOfPtr(cfg, p, schema, typ) case reflect.Interface: if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok { @@ -34,26 +34,33 @@ func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} } -func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfRecord(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { switch typ.Kind() { case reflect.Struct: - return encoderOfStruct(cfg, schema, typ) + return encoderOfStruct(cfg, p, schema, typ) case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { break } - return encoderOfRecord(cfg, schema, typ) + return encoderOfRecord(cfg, p, schema, typ) case reflect.Ptr: - return encoderOfPtr(cfg, schema, typ) + return encoderOfPtr(cfg, p, schema, typ) } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} } -func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfStruct(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { + cached := p.getProcessingDecoderFromCache(schema.Fingerprint(), typ.RType()) + if cached != nil { + return cached + } + dec := &structDecoder{typ: typ, fields: nil} + p.addProcessingDecoderToCache(schema.Fingerprint(), typ.RType(), dec) + rec := schema.(*RecordSchema) structDesc := describeStruct(cfg.getTagKey(), typ) @@ -77,14 +84,14 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec continue } - dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()) fields = append(fields, &structFieldDecoder{ field: sf.Field, - decoder: dec, + decoder: decoderOfType(cfg, p, field.Type(), sf.Field[len(sf.Field)-1].Type()), }) } - return &structDecoder{typ: typ, fields: fields} + dec.fields = fields + return dec } type structFieldDecoder struct { @@ -133,7 +140,14 @@ func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfStruct(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { + cached := p.getProcessingEncoderFromCache(schema.Fingerprint(), typ.RType()) + if cached != nil { + return cached + } + enc := &structEncoder{typ: typ, fields: nil} + p.addProcessingEncoderToCache(schema.Fingerprint(), typ.RType(), enc) + rec := schema.(*RecordSchema) structDesc := describeStruct(cfg.getTagKey(), typ) @@ -143,7 +157,7 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc if sf != nil { fields = append(fields, &structFieldEncoder{ field: sf.Field, - encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()), + encoder: encoderOfType(cfg, p, field.Type(), sf.Field[len(sf.Field)-1].Type()), }) continue } @@ -165,14 +179,14 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc defaultType := reflect2.TypeOf(&def) fields = append(fields, &structFieldEncoder{ defaultPtr: reflect2.PtrOf(&def), - encoder: encoderOfPtrUnion(cfg, field.Type(), defaultType), + encoder: encoderOfPtrUnion(cfg, p, field.Type(), defaultType), }) continue } } defaultType := reflect2.TypeOf(def) - defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) + defaultEncoder := encoderOfType(cfg, p, field.Type(), defaultType) if defaultType.LikePtr() { defaultEncoder = &onePtrEncoder{defaultEncoder} } @@ -181,7 +195,9 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc encoder: defaultEncoder, }) } - return &structEncoder{typ: typ, fields: fields} + + enc.fields = fields + return enc } type structFieldEncoder struct { @@ -231,7 +247,7 @@ func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) { } } -func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfRecord(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { rec := schema.(*RecordSchema) mapType := typ.(*reflect2.UnsafeMapType) @@ -239,7 +255,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec for i, field := range rec.Fields() { fields[i] = recordMapDecoderField{ name: field.Name(), - decoder: decoderOfType(cfg, field.Type(), mapType.Elem()), + decoder: decoderOfType(cfg, p, field.Type(), mapType.Elem()), } } @@ -278,7 +294,7 @@ func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } } -func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfRecord(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { rec := schema.(*RecordSchema) mapType := typ.(*reflect2.UnsafeMapType) @@ -288,7 +304,7 @@ func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc name: field.Name(), hasDef: field.HasDefault(), def: field.Default(), - encoder: encoderOfType(cfg, field.Type(), mapType.Elem()), + encoder: encoderOfType(cfg, p, field.Type(), mapType.Elem()), } if field.HasDefault() { @@ -303,7 +319,7 @@ func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc } defaultType := reflect2.TypeOf(fields[i].def) - fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType) + fields[i].defEncoder = encoderOfType(cfg, p, field.Type(), defaultType) if defaultType.LikePtr() { fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder} } diff --git a/codec_union.go b/codec_union.go index 7f864be6..b6b6b6c1 100644 --- a/codec_union.go +++ b/codec_union.go @@ -10,7 +10,7 @@ import ( "github.com/modern-go/reflect2" ) -func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func createDecoderOfUnion(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { switch typ.Kind() { case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || @@ -23,11 +23,11 @@ func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V if !schema.(*UnionSchema).Nullable() { break } - return decoderOfPtrUnion(cfg, schema, typ) + return decoderOfPtrUnion(cfg, p, schema, typ) case reflect.Interface: if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { - dec, err := decoderOfResolvedUnion(cfg, schema) + dec, err := decoderOfResolvedUnion(cfg, p, schema) if err != nil { return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)} } @@ -39,7 +39,7 @@ func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } -func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func createEncoderOfUnion(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { switch typ.Kind() { case reflect.Map: if typ.(reflect2.MapType).Key().Kind() != reflect.String || @@ -52,10 +52,10 @@ func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V if !schema.(*UnionSchema).Nullable() { break } - return encoderOfPtrUnion(cfg, schema, typ) + return encoderOfPtrUnion(cfg, p, schema, typ) } - return encoderOfResolverUnion(cfg, schema, typ) + return encoderOfResolverUnion(cfg, p, schema, typ) } func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { @@ -96,7 +96,10 @@ func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { keyPtr := reflect2.PtrOf(key) elemPtr := d.elemType.UnsafeNew() - decoderOfType(d.cfg, resSchema, d.elemType).Decode(elemPtr, r) + + dec := d.cfg.processingGroup.processingDecoderOfType(d.cfg, resSchema, d.elemType, decoderOfType) + + dec.Decode(elemPtr, r) d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) } @@ -146,19 +149,20 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { elemType := reflect2.TypeOf(val) elemPtr := reflect2.PtrOf(val) - encoder := encoderOfType(e.cfg, schema, elemType) + encoder := e.cfg.processingGroup.processingEncoderOfType(e.cfg, schema, elemType, encoderOfType) + if elemType.LikePtr() { encoder = &onePtrEncoder{encoder} } encoder.Encode(elemPtr, w) } -func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfPtrUnion(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder { union := schema.(*UnionSchema) _, typeIdx := union.Indices() ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType) + decoder := decoderOfType(cfg, p, union.Types()[typeIdx], elemType) return &unionPtrDecoder{ schema: union, @@ -196,11 +200,11 @@ func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) { d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) } -func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfPtrUnion(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { union := schema.(*UnionSchema) nullIdx, typeIdx := union.Indices() ptrType := typ.(*reflect2.UnsafePtrType) - encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem()) + encoder := encoderOfType(cfg, p, union.Types()[typeIdx], ptrType.Elem()) return &unionPtrEncoder{ schema: union, @@ -227,7 +231,7 @@ func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w) } -func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) { +func decoderOfResolvedUnion(cfg *frozenConfig, p *processing, schema Schema) (ValDecoder, error) { union := schema.(*UnionSchema) types := make([]reflect2.Type, len(union.Types())) @@ -252,7 +256,7 @@ func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error break } - decoder := decoderOfType(cfg, schema, typ) + decoder := decoderOfType(cfg, p, schema, typ) decoders[i] = decoder types[i] = typ } @@ -344,7 +348,7 @@ func unionResolutionName(schema Schema) string { return name } -func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfResolverUnion(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder { union := schema.(*UnionSchema) names, err := cfg.resolver.Name(typ) @@ -367,7 +371,7 @@ func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])} } - encoder := encoderOfType(cfg, schema, typ) + encoder := encoderOfType(cfg, p, schema, typ) return &unionResolverEncoder{ pos: pos, diff --git a/config.go b/config.go index ad299a7d..d116726c 100644 --- a/config.go +++ b/config.go @@ -2,10 +2,9 @@ package avro import ( "errors" + "github.com/modern-go/reflect2" "io" "sync" - - "github.com/modern-go/reflect2" ) const maxByteSliceSize = 1024 * 1024 @@ -80,6 +79,8 @@ func (c Config) Freeze() API { }, } + api.processingGroup = newProcessingGroup() + return api } @@ -114,6 +115,8 @@ type frozenConfig struct { decoderCache sync.Map // map[cacheKey]ValDecoder encoderCache sync.Map // map[cacheKey]ValEncoder + processingGroup *processingGroup + readerPool *sync.Pool writerPool *sync.Pool diff --git a/decoder_record_test.go b/decoder_record_test.go index 0d1d23f4..4630429c 100644 --- a/decoder_record_test.go +++ b/decoder_record_test.go @@ -393,3 +393,28 @@ func TestDecoder_RefStruct(t *testing.T) { require.NoError(t, err) assert.Equal(t, want, got) } + +func TestDecoder_RecursiveStructs(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x36, 0x6, 0x66, 0x6f, 0x6f, 0x2, 0x38, 0x6, 0x62, 0x61, 0x72, 0x0} + schema := `{ + "type": "record", + "namespace": "tns", + "name": "test", + "fields" : [ + {"name": "a", "type": "long"}, + {"name": "b", "type": "string"}, + {"name": "r", "type": ["null", "tns.test"]} + ] +}` + + dec, err := avro.NewDecoder(schema, bytes.NewReader(data)) + require.NoError(t, err) + + var got TestRecursiveRecord + err = dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, TestRecursiveRecord{A: 27, B: "foo", R: &TestRecursiveRecord{A: 28, B: "bar"}}, got) +} diff --git a/encoder_record_test.go b/encoder_record_test.go index 1aba99bc..f5f009c8 100644 --- a/encoder_record_test.go +++ b/encoder_record_test.go @@ -622,3 +622,26 @@ func TestEncoder_RefStruct(t *testing.T) { require.NoError(t, err) assert.Equal(t, []byte{0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f}, buf.Bytes()) } + +func TestEncoder_RecursiveStructs(t *testing.T) { + defer ConfigTeardown() + schema := `{ + "type": "record", + "namespace": "tns", + "name": "test", + "fields" : [ + {"name": "a", "type": "long"}, + {"name": "b", "type": "string"}, + {"name": "r", "type": ["null", "tns.test"]} + ] +}` + obj := TestRecursiveRecord{A: 27, B: "foo", R: &TestRecursiveRecord{A: 28, B: "bar"}} + buf := &bytes.Buffer{} + enc, err := avro.NewEncoder(schema, buf) + require.NoError(t, err) + + err = enc.Encode(obj) + + require.NoError(t, err) + assert.Equal(t, []byte{0x36, 0x6, 0x66, 0x6f, 0x6f, 0x2, 0x38, 0x6, 0x62, 0x61, 0x72, 0x0}, buf.Bytes()) +} diff --git a/go.mod b/go.mod index caf519bf..97a0242f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/modern-go/reflect2 v1.0.2 github.com/stretchr/testify v1.7.1 + golang.org/x/sync v0.5.0 ) require ( diff --git a/go.sum b/go.sum index c154e11f..bee02275 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/mmhash/hash.go b/internal/mmhash/hash.go new file mode 100644 index 00000000..09d6a9dc --- /dev/null +++ b/internal/mmhash/hash.go @@ -0,0 +1,46 @@ +// Package mmhash export runtime.memhash +package mmhash + +import "unsafe" + +//go:noescape +//go:linkname memhash runtime.memhash +func memhash(p unsafe.Pointer, h, s uintptr) uintptr + +type stringStruct struct { + str unsafe.Pointer + len int +} + +// Sum64 sum bytes to uint64. +func Sum64(data []byte) uint64 { + ss := (*stringStruct)(unsafe.Pointer(&data)) + return uint64(memhash(ss.str, 0, uintptr(ss.len))) +} + +// Hash sum bytes to uint32. +func Hash(b []byte) uint32 { + const ( + seed = 0xbc9f1d34 + m = 0xc6a4a793 + ) + h := uint32(seed) ^ uint32(len(b))*m + for ; len(b) >= 4; b = b[4:] { + h += uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + h *= m + h ^= h >> 16 + } + switch len(b) { + case 3: + h += uint32(b[2]) << 16 + fallthrough + case 2: + h += uint32(b[1]) << 8 + fallthrough + case 1: + h += uint32(b[0]) + h *= m + h ^= h >> 24 + } + return h +} diff --git a/processings.go b/processings.go new file mode 100644 index 00000000..0859e2a6 --- /dev/null +++ b/processings.go @@ -0,0 +1,132 @@ +package avro + +import ( + "encoding/binary" + "github.com/modern-go/reflect2" + "golang.org/x/sync/singleflight" + "sync" + "unsafe" +) + +type decoderOfTypeHandler func(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValDecoder +type encoderOfTypeHandler func(cfg *frozenConfig, p *processing, schema Schema, typ reflect2.Type) ValEncoder + +type processing struct { + key []byte + cfg *frozenConfig + decoders map[cacheKey]ValDecoder + encoders map[cacheKey]ValEncoder +} + +func (p *processing) addProcessingEncoderToCache(fingerprint [32]byte, rtype uintptr, enc ValEncoder) { + key := cacheKey{fingerprint: fingerprint, rtype: rtype} + p.encoders[key] = enc +} + +func (p *processing) getProcessingEncoderFromCache(fingerprint [32]byte, rtype uintptr) ValEncoder { + key := cacheKey{fingerprint: fingerprint, rtype: rtype} + if enc := p.cfg.getEncoderFromCache(fingerprint, rtype); enc != nil { + return enc + } + if enc, ok := p.encoders[key]; ok { + return enc + } + return nil +} + +func (p *processing) addProcessingDecoderToCache(fingerprint [32]byte, rtype uintptr, dec ValDecoder) { + key := cacheKey{fingerprint: fingerprint, rtype: rtype} + p.decoders[key] = dec +} + +func (p *processing) getProcessingDecoderFromCache(fingerprint [32]byte, rtype uintptr) ValDecoder { + key := cacheKey{fingerprint: fingerprint, rtype: rtype} + if dec := p.cfg.getDecoderFromCache(fingerprint, rtype); dec != nil { + return dec + } + if dec, ok := p.decoders[key]; ok { + return dec + } + return nil +} + +func (p *processing) clean() { + for key := range p.encoders { + delete(p.encoders, key) + } + for key := range p.decoders { + delete(p.decoders, key) + } + p.cfg = nil +} + +func newProcessingGroup() *processingGroup { + return &processingGroup{ + group: new(singleflight.Group), + pool: new(sync.Pool), + } +} + +type processingGroup struct { + group *singleflight.Group + pool *sync.Pool +} + +func (ps *processingGroup) borrow() *processing { + cached := ps.pool.Get() + if cached != nil { + return cached.(*processing) + } + return &processing{ + key: make([]byte, 64), + decoders: map[cacheKey]ValDecoder{}, + encoders: map[cacheKey]ValEncoder{}, + } +} + +func (ps *processingGroup) processingDecoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type, handler decoderOfTypeHandler) ValDecoder { + p := ps.borrow() + p.cfg = cfg + fingerprint := schema.Fingerprint() + rtype := typ.RType() + copy(p.key[:32], fingerprint[:]) + binary.LittleEndian.PutUint64(p.key[32:], uint64(rtype)) + copy(p.key[:48], []byte{2}) + ptrType, isPtr := typ.(*reflect2.UnsafePtrType) + if isPtr { + typ = ptrType.Elem() + } + v, _, _ := ps.group.Do(*(*string)(unsafe.Pointer(&p.key)), func() (interface{}, error) { + return handler(cfg, p, schema, typ), nil + }) + dec := v.(ValDecoder) + cfg.addDecoderToCache(schema.Fingerprint(), rtype, dec) + ps.finish(p) + return dec +} + +func (ps *processingGroup) processingEncoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type, handler encoderOfTypeHandler) ValEncoder { + p := ps.borrow() + p.cfg = cfg + fingerprint := schema.Fingerprint() + rtype := typ.RType() + copy(p.key[:32], fingerprint[:]) + binary.LittleEndian.PutUint64(p.key[32:], uint64(rtype)) + copy(p.key[:48], []byte{1}) + v, _, _ := ps.group.Do(*(*string)(unsafe.Pointer(&p.key)), func() (interface{}, error) { + return handler(cfg, p, schema, typ), nil + }) + enc := v.(ValEncoder) + if typ.LikePtr() { + enc = &onePtrEncoder{enc} + } + cfg.addEncoderToCache(schema.Fingerprint(), rtype, enc) + ps.finish(p) + return enc +} + +func (ps *processingGroup) finish(p *processing) { + ps.group.Forget(*(*string)(unsafe.Pointer(&p.key))) + p.clean() + ps.pool.Put(p) +} diff --git a/types_test.go b/types_test.go index 020369eb..f5b5853d 100644 --- a/types_test.go +++ b/types_test.go @@ -55,3 +55,9 @@ type TestUnexportedRecord struct { A int64 `avro:"a"` b string `avro:"b"` } + +type TestRecursiveRecord struct { + A int64 `avro:"a"` + B string `avro:"b"` + R *TestRecursiveRecord `avro:"r"` +}