Skip to content

Commit

Permalink
feat: add support for recursive schemas & structs (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianiacobghiula authored Jun 25, 2024
1 parent e2e849d commit 66aad10
Show file tree
Hide file tree
Showing 29 changed files with 440 additions and 176 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ For security reasons, the configuration `Config.MaxByteSliceSize` restricts the
by the `Reader`. The default maximum size is `1MiB` and is configurable. This is required to stop untrusted input from consuming all memory and
crashing the application. Should this not be need, setting a negative number will disable the behaviour.

### Recursive Structs

At this moment recursive structs are not supported. It is planned for the future.

## Benchmark

Benchmark source code can be found at: [https://github.com/nrwiersma/avro-benchmarks](https://github.com/nrwiersma/avro-benchmarks)
Expand Down
116 changes: 78 additions & 38 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,48 +71,88 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder {
}

ptrType := typ.(*reflect2.UnsafePtrType)
decoder = decoderOfType(c, schema, ptrType.Elem())
decoder = decoderOfType(newDecoderContext(c), schema, ptrType.Elem())
c.addDecoderToCache(schema.CacheFingerprint(), rtype, decoder)
return decoder
}

func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
if dec := createDecoderOfMarshaler(cfg, schema, typ); dec != nil {
type deferDecoder struct {
decoder ValDecoder
}

func (d *deferDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
d.decoder.Decode(ptr, r)
}

type deferEncoder struct {
encoder ValEncoder
}

func (d *deferEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
d.encoder.Encode(ptr, w)
}

type decoderContext struct {
cfg *frozenConfig
decoders map[cacheKey]ValDecoder
}

func newDecoderContext(cfg *frozenConfig) *decoderContext {
return &decoderContext{
cfg: cfg,
decoders: make(map[cacheKey]ValDecoder),
}
}

type encoderContext struct {
cfg *frozenConfig
encoders map[cacheKey]ValEncoder
}

func newEncoderContext(cfg *frozenConfig) *encoderContext {
return &encoderContext{
cfg: cfg,
encoders: make(map[cacheKey]ValEncoder),
}
}

func decoderOfType(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder {
if dec := createDecoderOfMarshaler(schema, typ); dec != nil {
return dec
}

// Handle eface case when it isnt a union
// Handle eface (empty interface) case when it isn't a union
if typ.Kind() == reflect.Interface && schema.Type() != Union {
if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
return newEfaceDecoder(cfg, schema)
return newEfaceDecoder(d, schema)
}
}

switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean:
return createDecoderOfNative(schema.(*PrimitiveSchema), typ)

case Record:
return createDecoderOfRecord(cfg, schema, typ)

key := cacheKey{fingerprint: schema.CacheFingerprint(), rtype: typ.RType()}
defDec := &deferDecoder{}
d.decoders[key] = defDec
defDec.decoder = createDecoderOfRecord(d, schema.(*RecordSchema), typ)
return defDec.decoder
case Ref:
return decoderOfType(cfg, schema.(*RefSchema).Schema(), typ)

key := cacheKey{fingerprint: schema.(*RefSchema).Schema().CacheFingerprint(), rtype: typ.RType()}
if dec, f := d.decoders[key]; f {
return dec
}
return decoderOfType(d, schema.(*RefSchema).Schema(), typ)
case Enum:
return createDecoderOfEnum(schema, typ)

return createDecoderOfEnum(schema.(*EnumSchema), typ)
case Array:
return createDecoderOfArray(cfg, schema, typ)

return createDecoderOfArray(d, schema.(*ArraySchema), typ)
case Map:
return createDecoderOfMap(cfg, schema, typ)

return createDecoderOfMap(d, schema.(*MapSchema), typ)
case Union:
return createDecoderOfUnion(cfg, schema, typ)

return createDecoderOfUnion(d, schema.(*UnionSchema), typ)
case Fixed:
return createDecoderOfFixed(schema, typ)

return createDecoderOfFixed(schema.(*FixedSchema), typ)
default:
// It is impossible to get here with a valid schema
return &errorDecoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())}
Expand All @@ -130,7 +170,7 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder {
return encoder
}

encoder = encoderOfType(c, schema, typ)
encoder = encoderOfType(newEncoderContext(c), schema, typ)
if typ.LikePtr() {
encoder = &onePtrEncoder{encoder}
}
Expand All @@ -146,8 +186,8 @@ func (e *onePtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
e.enc.Encode(noescape(unsafe.Pointer(&ptr)), w)
}

func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
if enc := createEncoderOfMarshaler(cfg, schema, typ); enc != nil {
func encoderOfType(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder {
if enc := createEncoderOfMarshaler(schema, typ); enc != nil {
return enc
}

Expand All @@ -158,28 +198,28 @@ func encoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncod
switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean, Null:
return createEncoderOfNative(schema, typ)

case Record:
return createEncoderOfRecord(cfg, schema, typ)

key := cacheKey{fingerprint: schema.Fingerprint(), rtype: typ.RType()}
defEnc := &deferEncoder{}
e.encoders[key] = defEnc
defEnc.encoder = createEncoderOfRecord(e, schema.(*RecordSchema), typ)
return defEnc.encoder
case Ref:
return encoderOfType(cfg, schema.(*RefSchema).Schema(), typ)

key := cacheKey{fingerprint: schema.(*RefSchema).Schema().Fingerprint(), rtype: typ.RType()}
if enc, f := e.encoders[key]; f {
return enc
}
return encoderOfType(e, schema.(*RefSchema).Schema(), typ)
case Enum:
return createEncoderOfEnum(schema, typ)

return createEncoderOfEnum(schema.(*EnumSchema), typ)
case Array:
return createEncoderOfArray(cfg, schema, typ)

return createEncoderOfArray(e, schema.(*ArraySchema), typ)
case Map:
return createEncoderOfMap(cfg, schema, typ)

return createEncoderOfMap(e, schema.(*MapSchema), typ)
case Union:
return createEncoderOfUnion(cfg, schema, typ)

return createEncoderOfUnion(e, schema.(*UnionSchema), typ)
case Fixed:
return createEncoderOfFixed(schema, typ)

return createEncoderOfFixed(schema.(*FixedSchema), typ)
default:
// It is impossible to get here with a valid schema
return &errorEncoder{err: fmt.Errorf("avro: schema type %s is unsupported", schema.Type())}
Expand Down
20 changes: 9 additions & 11 deletions codec_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,25 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
func createDecoderOfArray(d *decoderContext, schema *ArraySchema, typ reflect2.Type) ValDecoder {
if typ.Kind() == reflect.Slice {
return decoderOfArray(cfg, schema, typ)
return decoderOfArray(d, schema, typ)
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
func createEncoderOfArray(e *encoderContext, schema *ArraySchema, typ reflect2.Type) ValEncoder {
if typ.Kind() == reflect.Slice {
return encoderOfArray(cfg, schema, typ)
return encoderOfArray(e, schema, typ)
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
arr := schema.(*ArraySchema)
func decoderOfArray(d *decoderContext, arr *ArraySchema, typ reflect2.Type) ValDecoder {
sliceType := typ.(*reflect2.UnsafeSliceType)
decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem())
decoder := decoderOfType(d, arr.Items(), sliceType.Elem())

return &arrayDecoder{typ: sliceType, decoder: decoder}
}
Expand Down Expand Up @@ -74,13 +73,12 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
}
}

func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
arr := schema.(*ArraySchema)
func encoderOfArray(e *encoderContext, arr *ArraySchema, typ reflect2.Type) ValEncoder {
sliceType := typ.(*reflect2.UnsafeSliceType)
encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem())
encoder := encoderOfType(e, arr.Items(), sliceType.Elem())

return &arrayEncoder{
blockLength: cfg.getBlockLength(),
blockLength: e.cfg.getBlockLength(),
typ: sliceType,
encoder: encoder,
}
Expand Down
7 changes: 4 additions & 3 deletions codec_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import (
"github.com/modern-go/reflect2"
)

func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) ValDecoder {
func createDefaultDecoder(d *decoderContext, field *Field, typ reflect2.Type) ValDecoder {
cfg := d.cfg
fn := func(def any) ([]byte, error) {
defaultType := reflect2.TypeOf(def)
if defaultType == nil {
defaultType = reflect2.TypeOf((*null)(nil))
}
defaultEncoder := encoderOfType(cfg, field.Type(), defaultType)
defaultEncoder := encoderOfType(newEncoderContext(cfg), field.Type(), defaultType)
if defaultType.LikePtr() {
defaultEncoder = &onePtrEncoder{defaultEncoder}
}
Expand All @@ -37,7 +38,7 @@ func createDefaultDecoder(cfg *frozenConfig, field *Field, typ reflect2.Type) Va
}
return &defaultDecoder{
data: b,
decoder: decoderOfType(cfg, field.Type(), typ),
decoder: decoderOfType(d, field.Type(), typ),
}
}

Expand Down
1 change: 0 additions & 1 deletion codec_default_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ func TestDecoder_DefaultEnum(t *testing.T) {

require.NoError(t, err)
assert.Equal(t, TestRecord{B: "bar", A: "foo"}, got)

})

t.Run("TextUnmarshaler", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions codec_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ type efaceDecoder struct {
dec ValDecoder
}

func newEfaceDecoder(cfg *frozenConfig, schema Schema) *efaceDecoder {
func newEfaceDecoder(d *decoderContext, schema Schema) *efaceDecoder {
typ, _ := genericReceiver(schema)
dec := decoderOfType(cfg, schema, typ)
dec := decoderOfType(d, schema, typ)

return &efaceDecoder{
schema: schema,
Expand Down
16 changes: 8 additions & 8 deletions codec_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
func createDecoderOfEnum(schema *EnumSchema, typ reflect2.Type) ValDecoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{enum: schema.(*EnumSchema)}
return &enumCodec{enum: schema}
case typ.Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
return &enumTextMarshalerCodec{typ: typ, enum: schema}
case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true}
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
func createEncoderOfEnum(schema *EnumSchema, typ reflect2.Type) ValEncoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{enum: schema.(*EnumSchema)}
return &enumCodec{enum: schema}
case typ.Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
return &enumTextMarshalerCodec{typ: typ, enum: schema}
case reflect2.PtrTo(typ).Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema, ptr: true}
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand Down
13 changes: 4 additions & 9 deletions codec_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import (
"github.com/modern-go/reflect2"
)

func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder {
fixed := schema.(*FixedSchema)
func createDecoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValDecoder {
switch typ.Kind() {
case reflect.Array:
arrayType := typ.(reflect2.ArrayType)
Expand All @@ -21,7 +20,6 @@ func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder {
return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)}

case reflect.Uint64:
fixed := schema.(*FixedSchema)
if fixed.Size() != 8 {
break
}
Expand All @@ -44,23 +42,20 @@ func createDecoderOfFixed(schema Schema, typ reflect2.Type) ValDecoder {
}

return &errorDecoder{
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), schema.Type(), fixed.Size()),
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), fixed.Type(), fixed.Size()),
}
}

func createEncoderOfFixed(schema Schema, typ reflect2.Type) ValEncoder {
fixed := schema.(*FixedSchema)
func createEncoderOfFixed(fixed *FixedSchema, typ reflect2.Type) ValEncoder {
switch typ.Kind() {
case reflect.Array:
arrayType := typ.(reflect2.ArrayType)
fixed := schema.(*FixedSchema)
if arrayType.Elem().Kind() != reflect.Uint8 || arrayType.Len() != fixed.Size() {
break
}
return &fixedCodec{arrayType: typ.(*reflect2.UnsafeArrayType)}

case reflect.Uint64:
fixed := schema.(*FixedSchema)
if fixed.Size() != 8 {
break
}
Expand Down Expand Up @@ -92,7 +87,7 @@ func createEncoderOfFixed(schema Schema, typ reflect2.Type) ValEncoder {
}

return &errorEncoder{
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), schema.Type(), fixed.Size()),
err: fmt.Errorf("avro: %s is unsupported for Avro %s, size=%d", typ.String(), fixed.Type(), fixed.Size()),
}
}

Expand Down
3 changes: 1 addition & 2 deletions codec_generic_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func TestGenericDecode(t *testing.T) {
want any
wantErr require.ErrorAssertionFunc
}{

{
name: "Bool",
data: []byte{0x01},
Expand Down Expand Up @@ -228,7 +227,7 @@ func TestGenericDecode(t *testing.T) {

typ, err := genericReceiver(schema)
require.NoError(t, err)
dec := decoderOfType(DefaultConfig.(*frozenConfig), schema, typ)
dec := decoderOfType(newDecoderContext(DefaultConfig.(*frozenConfig)), schema, typ)

got := genericDecode(typ, dec, r)

Expand Down
Loading

0 comments on commit 66aad10

Please sign in to comment.