diff --git a/codec_enum.go b/codec_enum.go index 0fdb73a..8f23eb6 100644 --- a/codec_enum.go +++ b/codec_enum.go @@ -13,11 +13,11 @@ import ( func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder { switch { case typ.Kind() == reflect.String: - return &enumCodec{symbols: schema.(*EnumSchema).Symbols()} + return &enumCodec{enum: schema.(*EnumSchema)} case typ.Implements(textUnmarshalerType): - return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()} + return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)} case reflect2.PtrTo(typ).Implements(textUnmarshalerType): - return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true} + return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true} } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} @@ -26,34 +26,35 @@ func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder { func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder { switch { case typ.Kind() == reflect.String: - return &enumCodec{symbols: schema.(*EnumSchema).Symbols()} + return &enumCodec{enum: schema.(*EnumSchema)} case typ.Implements(textMarshalerType): - return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()} + return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)} case reflect2.PtrTo(typ).Implements(textMarshalerType): - return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true} + return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true} } return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} } type enumCodec struct { - symbols []string + enum *EnumSchema } func (c *enumCodec) Decode(ptr unsafe.Pointer, r *Reader) { i := int(r.ReadInt()) - if i < 0 || i >= len(c.symbols) { + symbol, ok := c.enum.Symbol(i) + if !ok { r.ReportError("decode enum symbol", "unknown enum symbol") return } - *((*string)(ptr)) = c.symbols[i] + *((*string)(ptr)) = symbol } func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) { str := *((*string)(ptr)) - for i, sym := range c.symbols { + for i, sym := range c.enum.symbols { if str != sym { continue } @@ -66,15 +67,16 @@ func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type enumTextMarshalerCodec struct { - typ reflect2.Type - symbols []string - ptr bool + typ reflect2.Type + enum *EnumSchema + ptr bool } func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { i := int(r.ReadInt()) - if i < 0 || i >= len(c.symbols) { + symbol, ok := c.enum.Symbol(i) + if !ok { r.ReportError("decode enum symbol", "unknown enum symbol") return } @@ -92,7 +94,7 @@ func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { obj = c.typ.UnsafeIndirect(ptr) } unmarshaler := (obj).(encoding.TextUnmarshaler) - if err := unmarshaler.UnmarshalText([]byte(c.symbols[i])); err != nil { + if err := unmarshaler.UnmarshalText([]byte(symbol)); err != nil { r.ReportError("decode enum text unmarshaler", err.Error()) } } @@ -116,7 +118,7 @@ func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) { } str := string(b) - for i, sym := range c.symbols { + for i, sym := range c.enum.symbols { if str != sym { continue } diff --git a/config_internal_test.go b/config_internal_test.go index 674bd54..d256c47 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -92,7 +92,31 @@ func TestConfig_ReusesDecoders_WithRecordFieldActions(t *testing.T) { assert.NotSame(t, dec1, dec2) }) +} + +func TestConfig_ReusesDecoders_WithEnum(t *testing.T) { + sch := `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo"], + "default": "foo" + }` + typ := reflect2.TypeOfPtr(new(string)) + api := Config{ + TagKey: "test", + BlockLength: 2, + }.Freeze() + cfg := api.(*frozenConfig) + + schema1 := MustParse(sch) + schema2 := MustParse(sch) + schema2.(*EnumSchema).actual = []string{"foo", "bar"} + + dec1 := cfg.DecoderOf(schema1, typ) + dec2 := cfg.DecoderOf(schema2, typ) + + assert.NotSame(t, dec1, dec2) } func TestConfig_DisableCache_DoesNotReuseDecoders(t *testing.T) { diff --git a/schema.go b/schema.go index 4dba5e5..4f88fec 100644 --- a/schema.go +++ b/schema.go @@ -848,11 +848,14 @@ type EnumSchema struct { name properties fingerprinter + cacheFingerprinter symbols []string def string - - doc string + doc string + // actual presents the actual symbols of the encoded value. + // It's only used in the context of write-read schema resolution. + actual []string } // NewEnumSchema creates a new enum schema instance. @@ -917,11 +920,42 @@ func (s *EnumSchema) Symbols() []string { return s.symbols } +// Symbol returns the symbol for the given index. +// It might return the default value in the context of write-read schema resolution. +func (s *EnumSchema) Symbol(i int) (string, bool) { + symbols := s.symbols + // has actual symbols + hasActual := len(s.actual) > 0 + if hasActual { + symbols = s.actual + } + + if i < 0 || i >= len(symbols) { + return "", false + } + + symbol := symbols[i] + + if hasActual && !hasSymbol(s.symbols, symbol) { + if !s.HasDefault() { + return "", false + } + return s.Default(), true + } + + return symbol, true +} + // Default returns the default of an enum or an empty string. func (s *EnumSchema) Default() string { return s.def } +// HasDefault determines if the schema has a default value. +func (s *EnumSchema) HasDefault() bool { + return s.def != "" +} + // String returns the canonical form of the schema. func (s *EnumSchema) String() string { symbols := "" @@ -977,6 +1011,15 @@ func (s *EnumSchema) FingerprintUsing(typ FingerprintType) ([]byte, error) { return s.fingerprinter.FingerprintUsing(typ, s) } +// CacheFingerprint returns a special fingerprint of the schema for caching purposes. +func (s *EnumSchema) CacheFingerprint() [32]byte { + if len(s.actual) == 0 || !s.HasDefault() { + return s.Fingerprint() + } + + return s.cacheFingerprinter.fingerprint([]any{s.Fingerprint(), s.actual, s.Default()}) +} + // ArraySchema is an Avro array type schema. type ArraySchema struct { properties diff --git a/schema_compatibility.go b/schema_compatibility.go index fd0c1a9..efb1eef 100644 --- a/schema_compatibility.go +++ b/schema_compatibility.go @@ -149,6 +149,9 @@ func (c *SchemaCompatibility) match(reader, writer Schema) error { } if err := c.checkEnumSymbols(r, w); err != nil { + if r.HasDefault() { + return nil + } return err } @@ -324,6 +327,20 @@ func (c *SchemaCompatibility) resolve(reader, writer Schema) (Schema, error) { } if writer.Type() == Enum { + r := reader.(*EnumSchema) + w := writer.(*EnumSchema) + if err := c.checkEnumSymbols(r, w); err != nil { + if r.HasDefault() { + enum, _ := NewEnumSchema(r.Name(), r.Namespace(), r.Symbols(), + WithAliases(r.Aliases()), + WithDefault(r.Default()), + ) + enum.actual = w.Symbols() + return enum, nil + } + + return nil, err + } return reader, nil } diff --git a/schema_compatibility_test.go b/schema_compatibility_test.go index d656cb9..500667b 100644 --- a/schema_compatibility_test.go +++ b/schema_compatibility_test.go @@ -178,6 +178,12 @@ func TestSchemaCompatibility_Compatible(t *testing.T) { writer: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1", "TEST2"]}`, wantErr: assert.Error, }, + { + name: "Enum Reader Missing Symbol With Default", + reader: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1"], "default": "TEST1"}`, + writer: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1", "TEST2"]}`, + wantErr: assert.NoError, + }, { name: "Enum Writer Missing Symbol", reader: `{"type":"enum", "name":"test", "namespace": "org.hamba.avro", "symbols":["TEST1", "TEST2"]}`, @@ -387,6 +393,53 @@ func TestSchemaCompatibility_Resolve(t *testing.T) { value: map[string]any{"foo": "bar"}, want: map[string]any{"foo": []byte("bar")}, }, + { + name: "Enum Reader Missing Symbols With Default", + reader: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo"], + "default": "foo" + }`, + writer: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo", "bar"] + }`, + value: "bar", + want: "foo", + }, + { + name: "Enum Writer Missing Symbols", + reader: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo", "bar"] + }`, + writer: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo"] + }`, + value: "foo", + want: "foo", + }, + { + name: "Enum Writer Missing Symbols and Unused Reader Default", + reader: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo", "bar"], + "default": "bar" + }`, + writer: `{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo"] + }`, + value: "foo", + want: "foo", + }, { name: "Enum With Alias", reader: `{ diff --git a/schema_internal_test.go b/schema_internal_test.go index 6d78e9b..575a608 100644 --- a/schema_internal_test.go +++ b/schema_internal_test.go @@ -558,6 +558,25 @@ func TestSchema_CacheFingerprint(t *testing.T) { assert.NotEqual(t, schema.Fingerprint(), schema.CacheFingerprint()) }) + t.Run("enum", func(t *testing.T) { + schema1 := MustParse(`{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo"] + }`).(*EnumSchema) + + schema2 := MustParse(`{ + "type": "enum", + "name": "test.enum", + "symbols": ["foo"], + "default": "foo" + }`).(*EnumSchema) + schema2.actual = []string{"boo"} + + assert.Equal(t, schema1.Fingerprint(), schema1.CacheFingerprint()) + assert.NotEqual(t, schema1.CacheFingerprint(), schema2.CacheFingerprint()) + }) + t.Run("record", func(t *testing.T) { schema1 := MustParse(`{ "type": "record", @@ -581,3 +600,78 @@ func TestSchema_CacheFingerprint(t *testing.T) { assert.NotEqual(t, schema1.CacheFingerprint(), schema2.CacheFingerprint()) }) } + +func TestEnumSchema_GetSymbol(t *testing.T) { + tests := []struct { + schemaFn func() *EnumSchema + idx int + want any + wantOk bool + }{ + { + schemaFn: func() *EnumSchema { + enum, _ := NewEnumSchema("foo", "", []string{"BAR"}) + return enum + }, + idx: 0, + wantOk: true, + want: "BAR", + }, + { + schemaFn: func() *EnumSchema { + enum, _ := NewEnumSchema("foo", "", []string{"BAR"}) + return enum + }, + idx: 1, + wantOk: false, + }, + { + schemaFn: func() *EnumSchema { + enum, _ := NewEnumSchema("foo", "", []string{"FOO"}, WithDefault("FOO")) + return enum + }, + idx: 1, + wantOk: false, + }, + { + schemaFn: func() *EnumSchema { + enum, _ := NewEnumSchema("foo", "", []string{"FOO"}) + enum.actual = []string{"FOO", "BAR"} + return enum + }, + idx: 1, + wantOk: false, + }, + { + schemaFn: func() *EnumSchema { + enum, _ := NewEnumSchema("foo", "", []string{"FOO"}, WithDefault("FOO")) + enum.actual = []string{"FOO", "BAR"} + return enum + }, + idx: 1, + wantOk: true, + want: "FOO", + }, + { + schemaFn: func() *EnumSchema { + enum, _ := NewEnumSchema("foo", "", []string{"FOO", "BAR"}) + enum.actual = []string{"FOO"} + return enum + }, + idx: 0, + wantOk: true, + want: "FOO", + }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + got, ok := test.schemaFn().Symbol(test.idx) + assert.Equal(t, test.wantOk, ok) + if ok { + assert.Equal(t, test.want, got) + } + }) + } +}