From 1a64620c513041cf48b13480f75ae8b238cf4fef Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Thu, 18 Apr 2024 18:43:23 +0200 Subject: [PATCH] feat: support slices for nullable unions --- README.md | 3 +- codec_record.go | 2 +- codec_union.go | 97 ++++++++++++++++++++++++++++++++----------- decoder_union_test.go | 50 ++++++++++++++++++++-- encoder_union_test.go | 30 +++++++++++++ 5 files changed, 153 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 53dce505..95724cc3 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,8 @@ When a non-`nil` union value is encountered, a single key is en/decoded. The key type name, or scheam full name in the case of a named schema (enum, fixed or record). * ***T:** This is allowed in a "nullable" union. A nullable union is defined as a two schema union, with one of the types being `null` (ie. `["null", "string"]` or `["string", "null"]`), in this case -a `*T` is allowed, with `T` matching the conversion table above. +a `*T` is allowed, with `T` matching the conversion table above. In the case of a slice, the slice can be used +directly. * **any:** An `interface` can be provided and the type or name resolved. Primitive types are pre-registered, but named types, maps and slices will need to be registered with the `Register` function. In the case of arrays and maps the enclosed schema type or name is postfix to the type with a `:` separator, diff --git a/codec_record.go b/codec_record.go index 03bfed61..45ee02f8 100644 --- a/codec_record.go +++ b/codec_record.go @@ -184,7 +184,7 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc defaultType := reflect2.TypeOf(&def) fields = append(fields, &structFieldEncoder{ defaultPtr: reflect2.PtrOf(&def), - encoder: encoderOfPtrUnion(cfg, field.Type(), defaultType), + encoder: encoderOfNullableUnion(cfg, field.Type(), defaultType), }) continue } diff --git a/codec_union.go b/codec_union.go index 084f3c7f..9e3649af 100644 --- a/codec_union.go +++ b/codec_union.go @@ -18,20 +18,22 @@ func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V break } return decoderOfMapUnion(cfg, schema, typ) - + case reflect.Slice: + if !schema.(*UnionSchema).Nullable() { + break + } + return decoderOfNullableUnion(cfg, schema, typ) case reflect.Ptr: if !schema.(*UnionSchema).Nullable() { break } - return decoderOfPtrUnion(cfg, schema, typ) - + return decoderOfNullableUnion(cfg, schema, typ) case reflect.Interface: if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { dec, err := decoderOfResolvedUnion(cfg, schema) if err != nil { return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)} } - return dec } } @@ -47,14 +49,17 @@ func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) V break } return encoderOfMapUnion(cfg, schema, typ) - + case reflect.Slice: + if !schema.(*UnionSchema).Nullable() { + break + } + return encoderOfNullableUnion(cfg, schema, typ) case reflect.Ptr: if !schema.(*UnionSchema).Nullable() { break } - return encoderOfPtrUnion(cfg, schema, typ) + return encoderOfNullableUnion(cfg, schema, typ) } - return encoderOfResolverUnion(cfg, schema, typ) } @@ -163,27 +168,39 @@ func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { encoder.Encode(elemPtr, w) } -func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { +func decoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { union := schema.(*UnionSchema) _, typeIdx := union.Indices() - ptrType := typ.(*reflect2.UnsafePtrType) - elemType := ptrType.Elem() - decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType) - return &unionPtrDecoder{ + var ( + baseTyp reflect2.Type + isPtr bool + ) + switch v := typ.(type) { + case *reflect2.UnsafePtrType: + baseTyp = v.Elem() + isPtr = true + case *reflect2.UnsafeSliceType: + baseTyp = v + } + decoder := decoderOfType(cfg, union.Types()[typeIdx], baseTyp) + + return &unionNullableDecoder{ schema: union, - typ: elemType, + typ: baseTyp, + isPtr: isPtr, decoder: decoder, } } -type unionPtrDecoder struct { +type unionNullableDecoder struct { schema *UnionSchema typ reflect2.Type + isPtr bool decoder ValDecoder } -func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) { +func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) { _, schema := getUnionSchema(d.schema, r) if schema == nil { return @@ -194,47 +211,79 @@ func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) { return } + // Handle the non-ptr case separately. + if !d.isPtr { + if d.typ.UnsafeIsNil(ptr) { + // Create a new instance. + newPtr := d.typ.UnsafeNew() + d.decoder.Decode(newPtr, r) + d.typ.UnsafeSet(ptr, newPtr) + return + } + + // Reuse the existing instance. + d.decoder.Decode(ptr, r) + return + } + if *((*unsafe.Pointer)(ptr)) == nil { - // Create new instance + // Create new instance. newPtr := d.typ.UnsafeNew() d.decoder.Decode(newPtr, r) *((*unsafe.Pointer)(ptr)) = newPtr return } - // Reuse existing instance + // Reuse existing instance. d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) } -func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { +func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { union := schema.(*UnionSchema) nullIdx, typeIdx := union.Indices() - ptrType := typ.(*reflect2.UnsafePtrType) - encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem()) - return &unionPtrEncoder{ + var ( + baseTyp reflect2.Type + isPtr bool + ) + switch v := typ.(type) { + case *reflect2.UnsafePtrType: + baseTyp = v.Elem() + isPtr = true + case *reflect2.UnsafeSliceType: + baseTyp = v + } + encoder := encoderOfType(cfg, union.Types()[typeIdx], baseTyp) + + return &unionNullableEncoder{ schema: union, encoder: encoder, + isPtr: isPtr, nullIdx: int64(nullIdx), typeIdx: int64(typeIdx), } } -type unionPtrEncoder struct { +type unionNullableEncoder struct { schema *UnionSchema encoder ValEncoder + isPtr bool nullIdx int64 typeIdx int64 } -func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { +func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) { if *((*unsafe.Pointer)(ptr)) == nil { w.WriteLong(e.nullIdx) return } w.WriteLong(e.typeIdx) - e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w) + newPtr := ptr + if e.isPtr { + newPtr = *((*unsafe.Pointer)(ptr)) + } + e.encoder.Encode(newPtr, w) } func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) { diff --git a/decoder_union_test.go b/decoder_union_test.go index e5858db8..c09e162a 100644 --- a/decoder_union_test.go +++ b/decoder_union_test.go @@ -182,17 +182,17 @@ func TestDecoder_UnionPtrReversed(t *testing.T) { func TestDecoder_UnionPtrReuseInstance(t *testing.T) { defer ConfigTeardown() - avro.Register("test", &TestRecord{}) - data := []byte{0x02, 0x36, 0x06, 0x66, 0x6F, 0x6F} schema := `["null", {"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}]` dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) - got := &TestRecord{} + var original TestRecord + got := &original err := dec.Decode(&got) require.NoError(t, err) assert.IsType(t, &TestRecord{}, got) + assert.Same(t, &original, got) assert.Equal(t, int64(27), got.A) assert.Equal(t, "foo", got.B) } @@ -225,6 +225,50 @@ func TestDecoder_UnionPtrReversedNull(t *testing.T) { assert.Nil(t, got) } +func TestDecoder_UnionNullableSlice(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x02, 0x06, 0x66, 0x6F, 0x6F} + schema := `["null", "bytes"]` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got []byte + err := dec.Decode(&got) + + want := []byte("foo") + require.NoError(t, err) + assert.Equal(t, want, got) +} + +func TestDecoder_UnionNullableSliceNull(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x00} + schema := `["null", "bytes"]` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got []byte + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Nil(t, got) +} + +func TestDecoder_UnionNullableSliceNotNullButEmpty(t *testing.T) { + defer ConfigTeardown() + + data := []byte{0x02, 0x00} + schema := `["null", "bytes"]` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got []byte + err := dec.Decode(&got) + + require.NoError(t, err) + assert.NotNil(t, got) + assert.Empty(t, got) +} + func TestDecoder_UnionPtrInvalidSchema(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_union_test.go b/encoder_union_test.go index ac438380..f7687ad8 100644 --- a/encoder_union_test.go +++ b/encoder_union_test.go @@ -256,6 +256,36 @@ func TestEncoder_UnionPtrNotNullable(t *testing.T) { assert.Error(t, err) } +func TestEncoder_UnionNullableSlice(t *testing.T) { + defer ConfigTeardown() + + schema := `["null", "bytes"]` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + require.NoError(t, err) + + b := []byte("foo") + err = enc.Encode(b) + + require.NoError(t, err) + assert.Equal(t, []byte{0x02, 0x06, 0x66, 0x6F, 0x6F}, buf.Bytes()) +} + +func TestEncoder_UnionNullableSliceNull(t *testing.T) { + defer ConfigTeardown() + + schema := `["null", "bytes"]` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + require.NoError(t, err) + + var b []byte + err = enc.Encode(b) + + require.NoError(t, err) + assert.Equal(t, []byte{0x00}, buf.Bytes()) +} + func TestEncoder_UnionInterface(t *testing.T) { defer ConfigTeardown()