Skip to content

Commit

Permalink
fix: fix enum schema evolution
Browse files Browse the repository at this point in the history
  • Loading branch information
redaLaanait committed Jan 17, 2024
1 parent d25c1c8 commit ff7b796
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 18 deletions.
34 changes: 18 additions & 16 deletions codec_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{symbols: schema.(*EnumSchema).Symbols()}
return &enumCodec{enum: schema.(*EnumSchema)}
case typ.Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()}
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand All @@ -26,34 +26,35 @@ func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
switch {
case typ.Kind() == reflect.String:
return &enumCodec{symbols: schema.(*EnumSchema).Symbols()}
return &enumCodec{enum: schema.(*EnumSchema)}
case typ.Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()}
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
case reflect2.PtrTo(typ).Implements(textMarshalerType):
return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true}
return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
}

return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
}

type enumCodec struct {
symbols []string
enum *EnumSchema
}

func (c *enumCodec) Decode(ptr unsafe.Pointer, r *Reader) {
i := int(r.ReadInt())

if i < 0 || i >= len(c.symbols) {
symbol, ok := c.enum.Symbol(i)
if !ok {
r.ReportError("decode enum symbol", "unknown enum symbol")
return
}

*((*string)(ptr)) = c.symbols[i]
*((*string)(ptr)) = symbol
}

func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) {
str := *((*string)(ptr))
for i, sym := range c.symbols {
for i, sym := range c.enum.symbols {
if str != sym {
continue
}
Expand All @@ -66,15 +67,16 @@ func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) {
}

type enumTextMarshalerCodec struct {
typ reflect2.Type
symbols []string
ptr bool
typ reflect2.Type
enum *EnumSchema
ptr bool
}

func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
i := int(r.ReadInt())

if i < 0 || i >= len(c.symbols) {
symbol, ok := c.enum.Symbol(i)
if !ok {
r.ReportError("decode enum symbol", "unknown enum symbol")
return
}
Expand All @@ -92,7 +94,7 @@ func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
obj = c.typ.UnsafeIndirect(ptr)
}
unmarshaler := (obj).(encoding.TextUnmarshaler)
if err := unmarshaler.UnmarshalText([]byte(c.symbols[i])); err != nil {
if err := unmarshaler.UnmarshalText([]byte(symbol)); err != nil {
r.ReportError("decode enum text unmarshaler", err.Error())
}
}
Expand All @@ -116,7 +118,7 @@ func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) {
}

str := string(b)
for i, sym := range c.symbols {
for i, sym := range c.enum.symbols {
if str != sym {
continue
}
Expand Down
24 changes: 24 additions & 0 deletions config_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,31 @@ func TestConfig_ReusesDecoders_WithRecordFieldActions(t *testing.T) {

assert.NotSame(t, dec1, dec2)
})
}

func TestConfig_ReusesDecoders_WithEnum(t *testing.T) {
sch := `{
"type": "enum",
"name": "test.enum",
"symbols": ["foo"],
"default": "foo"
}`
typ := reflect2.TypeOfPtr(new(string))

api := Config{
TagKey: "test",
BlockLength: 2,
}.Freeze()
cfg := api.(*frozenConfig)

schema1 := MustParse(sch)
schema2 := MustParse(sch)
schema2.(*EnumSchema).actual = []string{"foo", "bar"}

dec1 := cfg.DecoderOf(schema1, typ)
dec2 := cfg.DecoderOf(schema2, typ)

assert.NotSame(t, dec1, dec2)
}

func TestConfig_DisableCache_DoesNotReuseDecoders(t *testing.T) {
Expand Down
47 changes: 45 additions & 2 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,11 +846,14 @@ type EnumSchema struct {
name
properties
fingerprinter
cacheFingerprinter

symbols []string
def string

doc string
doc string
// actual presents the actual symbols of the encoded value.
// It's only used in the context of write-read schema resolution.
actual []string
}

// NewEnumSchema creates a new enum schema instance.
Expand Down Expand Up @@ -915,11 +918,42 @@ func (s *EnumSchema) Symbols() []string {
return s.symbols
}

// Symbol returns the symbol for the given index.
// It might return the default value in the context of schema read-write resolution.
func (s *EnumSchema) Symbol(i int) (string, bool) {
symbols := s.symbols
// has actual symbols
hasActual := len(s.actual) > 0
if hasActual {
symbols = s.actual
}

if i < 0 || i >= len(symbols) {
return "", false
}

symbol := symbols[i]

if hasActual && !hasSymbol(s.symbols, symbol) {
if !s.HasDefault() {
return "", false
}
return s.Default(), true
}

return symbol, true
}

// Default returns the default of an enum or an empty string.
func (s *EnumSchema) Default() string {
return s.def
}

// HasDefault determines if the schema has a default value.
func (s *EnumSchema) HasDefault() bool {
return s.def != ""
}

// String returns the canonical form of the schema.
func (s *EnumSchema) String() string {
symbols := ""
Expand Down Expand Up @@ -975,6 +1009,15 @@ func (s *EnumSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) {
return s.fingerprinter.FingerprintUsing(typ, s)
}

// CacheFingerprint returns a special fingerprint of the schema for caching purposes.
func (s *EnumSchema) CacheFingerprint() [32]byte {
if len(s.actual) == 0 || !s.HasDefault() {
return s.Fingerprint()
}

return s.cacheFingerprinter.fingerprint([]any{s.Fingerprint(), s.actual, s.Default()})
}

// ArraySchema is an Avro array type schema.
type ArraySchema struct {
properties
Expand Down
17 changes: 17 additions & 0 deletions schema_compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ func (c *SchemaCompatibility) match(reader, writer Schema) error {
}

if err := c.checkEnumSymbols(r, w); err != nil {
if r.HasDefault() {
return nil
}
return err
}

Expand Down Expand Up @@ -324,6 +327,20 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) {
}

if writer.Type() == Enum {
r := reader.(*EnumSchema)
w := writer.(*EnumSchema)
if err := c.checkEnumSymbols(r, w); err != nil {
if r.HasDefault() {
enum, _ := NewEnumSchema(r.Name(), r.Namespace(), r.Symbols(),
WithAliases(r.Aliases()),
WithDefault(r.Default()),
)
enum.actual = w.Symbols()
return enum, nil
}

return nil, err
}
return reader, nil
}

Expand Down
37 changes: 37 additions & 0 deletions schema_compatibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ func TestSchemaCompatibility_Compatible(t *testing.T) {
writer: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1", "TEST2"]}`,
wantErr: assert.Error,
},
{
name: "Enum Reader Missing Symbol With Default",
reader: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1"], "default": "TEST1"}`,
writer: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1", "TEST2"]}`,
wantErr: assert.NoError,
},
{
name: "Enum Writer Missing Symbol",
reader: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1", "TEST2"]}`,
Expand Down Expand Up @@ -387,6 +393,37 @@ func TestSchemaCompatibility_Resolve(t *testing.T) {
value: map[string]any{"foo": "bar"},
want: map[string]any{"foo": []byte("bar")},
},
{
name: "Enum Reader Missing Symbols With Default",
reader: `{
"type": "enum",
"name": "test.enum",
"symbols": ["foo"],
"default": "foo"
}`,
writer: `{
"type": "enum",
"name": "test.enum",
"symbols": ["foo", "bar"]
}`,
value: "bar",
want: "foo",
},
{
name: "Enum Writer Missing Symbols",
reader: `{
"type": "enum",
"name": "test.enum",
"symbols": ["foo", "bar"]
}`,
writer: `{
"type": "enum",
"name": "test.enum",
"symbols": ["foo"]
}`,
value: "foo",
want: "foo",
},
{
name: "Enum With Alias",
reader: `{
Expand Down
19 changes: 19 additions & 0 deletions schema_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,25 @@ func TestSchema_CacheFingerprint(t *testing.T) {
assert.NotEqual(t, schema.Fingerprint(), schema.CacheFingerprint())
})

t.Run("enum", func(t *testing.T) {
schema1 := MustParse(`{
"type": "enum",
"name": "test.enum",
"symbols": ["foo"]
}`).(*EnumSchema)

schema2 := MustParse(`{
"type": "enum",
"name": "test.enum",
"symbols": ["foo"],
"default": "foo"
}`).(*EnumSchema)
schema2.actual = []string{"boo"}

assert.Equal(t, schema1.Fingerprint(), schema1.CacheFingerprint())
assert.NotEqual(t, schema1.CacheFingerprint(), schema2.CacheFingerprint())
})

t.Run("record", func(t *testing.T) {
schema1 := MustParse(`{
"type": "record",
Expand Down

0 comments on commit ff7b796

Please sign in to comment.