diff --git a/README.md b/README.md index 310e6cf..971bace 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ When a non-`nil` union value is encountered, a single key is en/decoded. The key type name, or scheam full name in the case of a named schema (enum, fixed or record). * ***T:** This is allowed in a "nullable" union. A nullable union is defined as a two schema union, with one of the types being `null` (ie. `["null", "string"]` or `["string", "null"]`), in this case -a `*T` is allowed, with `T` matching the conversion table above. In the case of a slice, the slice can be used +a `*T` is allowed, with `T` matching the conversion table above. In the case of a slice or a map, the slice or the map can be used directly. * **any:** An `interface` can be provided and the type or name resolved. Primitive types are pre-registered, but named types, maps and slices will need to be registered with the `Register` function. diff --git a/cmd/avrogen/main.go b/cmd/avrogen/main.go index 2009101..157fe83 100644 --- a/cmd/avrogen/main.go +++ b/cmd/avrogen/main.go @@ -25,6 +25,8 @@ type config struct { FullName bool Encoders bool FullSchema bool + PlainMap bool + PlainSlice bool StrictTypes bool Initialisms string } @@ -45,6 +47,8 @@ func realMain(args []string, stdout, stderr io.Writer) int { flgs.BoolVar(&cfg.Encoders, "encoders", false, "Generate encoders for the structs.") flgs.BoolVar(&cfg.FullSchema, "fullschema", false, "Use the full schema in the generated encoders.") flgs.BoolVar(&cfg.StrictTypes, "strict-types", false, "Use strict type sizes (e.g. int32) during generation.") + flgs.BoolVar(&cfg.PlainMap, "plain-map", false, "Use a plain map instead of a ptr for nullable map unions.") + flgs.BoolVar(&cfg.PlainSlice, "plain-slice", false, "Use a plain slice instead of a ptr for nullable slice unions.") flgs.StringVar(&cfg.Initialisms, "initialisms", "", "Custom initialisms [,...] for struct and field names.") flgs.StringVar(&cfg.TemplateFileName, "template-filename", "", "Override output template with one loaded from file.") flgs.Usage = func() { @@ -85,6 +89,8 @@ func realMain(args []string, stdout, stderr io.Writer) int { gen.WithEncoders(cfg.Encoders), gen.WithInitialisms(initialisms), gen.WithTemplate(string(template)), + gen.WithPlainMap(cfg.PlainMap), + gen.WithPlainSlice(cfg.PlainSlice), gen.WithStrictTypes(cfg.StrictTypes), gen.WithFullSchema(cfg.FullSchema), } diff --git a/cmd/avrogen/main_test.go b/cmd/avrogen/main_test.go index c15c3c4..10d7807 100644 --- a/cmd/avrogen/main_test.go +++ b/cmd/avrogen/main_test.go @@ -189,6 +189,33 @@ func TestAvroGen_GeneratesSchemaWithStrictTypes(t *testing.T) { assert.Equal(t, want, got) } +func TestAvroGen_GeneratePlain(t *testing.T) { + for _, opt := range []string{"map", "slice"} { + t.Run(opt, func(t *testing.T) { + t.Parallel() + + path := t.TempDir() + file := filepath.Join(path, "test.go") + + args := []string{"avrogen", "-pkg", "testpkg", "-o", file, "-plain-" + opt, "testdata/schema.avsc"} + gotCode := realMain(args, io.Discard, io.Discard) + require.Equal(t, 0, gotCode) + + got, err := os.ReadFile(file) + require.NoError(t, err) + + if *update { + err = os.WriteFile("testdata/golden_plain"+opt+".go", got, 0600) + require.NoError(t, err) + } + + want, err := os.ReadFile("testdata/golden_plain" + opt + ".go") + require.NoError(t, err) + assert.Equal(t, want, got) + }) + } +} + func TestParseTags(t *testing.T) { tests := []struct { name string diff --git a/cmd/avrogen/testdata/golden.go b/cmd/avrogen/testdata/golden.go index 1f17271..4583397 100644 --- a/cmd/avrogen/testdata/golden.go +++ b/cmd/avrogen/testdata/golden.go @@ -5,6 +5,8 @@ package testpkg // Test is a test struct. type Test struct { // SomeString is a string. - SomeString string `avro:"someString"` - SomeInt int `avro:"someInt"` + SomeString string `avro:"someString"` + SomeInt int `avro:"someInt"` + SomeNullableMap *map[string]int `avro:"someNullableMap"` + SomeNullableSlice *[]int `avro:"someNullableSlice"` } diff --git a/cmd/avrogen/testdata/golden_encoders.go b/cmd/avrogen/testdata/golden_encoders.go index 44b3e50..a5a552d 100644 --- a/cmd/avrogen/testdata/golden_encoders.go +++ b/cmd/avrogen/testdata/golden_encoders.go @@ -8,11 +8,13 @@ import ( // Test is a test struct. type Test struct { // SomeString is a string. - SomeString string `avro:"someString"` - SomeInt int `avro:"someInt"` + SomeString string `avro:"someString"` + SomeInt int `avro:"someInt"` + SomeNullableMap *map[string]int `avro:"someNullableMap"` + SomeNullableSlice *[]int `avro:"someNullableSlice"` } -var schemaTest = avro.MustParse(`{"name":"a.b.test","type":"record","fields":[{"name":"someString","type":"string"},{"name":"someInt","type":"int"}]}`) +var schemaTest = avro.MustParse(`{"name":"a.b.test","type":"record","fields":[{"name":"someString","type":"string"},{"name":"someInt","type":"int"},{"name":"someNullableMap","type":["null",{"type":"map","values":"int"}]},{"name":"someNullableSlice","type":["null",{"type":"array","items":"int"}]}]}`) // Schema returns the schema for Test. func (o *Test) Schema() avro.Schema { diff --git a/cmd/avrogen/testdata/golden_encoders_fullschema.go b/cmd/avrogen/testdata/golden_encoders_fullschema.go index 7bc7736..9471dfc 100644 --- a/cmd/avrogen/testdata/golden_encoders_fullschema.go +++ b/cmd/avrogen/testdata/golden_encoders_fullschema.go @@ -8,11 +8,13 @@ import ( // Test is a test struct. type Test struct { // SomeString is a string. - SomeString string `avro:"someString"` - SomeInt int `avro:"someInt"` + SomeString string `avro:"someString"` + SomeInt int `avro:"someInt"` + SomeNullableMap *map[string]int `avro:"someNullableMap"` + SomeNullableSlice *[]int `avro:"someNullableSlice"` } -var schemaTest = avro.MustParse(`{"name":"a.b.test","doc":"Test is a test struct","type":"record","fields":[{"name":"someString","doc":"SomeString is a string","type":"string"},{"name":"someInt","type":"int"}]}`) +var schemaTest = avro.MustParse(`{"name":"a.b.test","doc":"Test is a test struct","type":"record","fields":[{"name":"someString","doc":"SomeString is a string","type":"string"},{"name":"someInt","type":"int"},{"name":"someNullableMap","type":["null",{"type":"map","values":"int"}]},{"name":"someNullableSlice","type":["null",{"type":"array","items":"int"}]}]}`) // Schema returns the schema for Test. func (o *Test) Schema() avro.Schema { diff --git a/cmd/avrogen/testdata/golden_fullname.go b/cmd/avrogen/testdata/golden_fullname.go index 5b44801..65327f4 100644 --- a/cmd/avrogen/testdata/golden_fullname.go +++ b/cmd/avrogen/testdata/golden_fullname.go @@ -4,6 +4,8 @@ package testpkg // Test is a test struct. type ABTest struct { // SomeString is a string. - SomeString string `avro:"someString"` - SomeInt int `avro:"someInt"` + SomeString string `avro:"someString"` + SomeInt int `avro:"someInt"` + SomeNullableMap *map[string]int `avro:"someNullableMap"` + SomeNullableSlice *[]int `avro:"someNullableSlice"` } diff --git a/cmd/avrogen/testdata/golden_plainmap.go b/cmd/avrogen/testdata/golden_plainmap.go new file mode 100644 index 0000000..a276792 --- /dev/null +++ b/cmd/avrogen/testdata/golden_plainmap.go @@ -0,0 +1,11 @@ +// Code generated by avro/gen. DO NOT EDIT. +package testpkg + +// Test is a test struct. +type Test struct { + // SomeString is a string. + SomeString string `avro:"someString"` + SomeInt int `avro:"someInt"` + SomeNullableMap map[string]int `avro:"someNullableMap"` + SomeNullableSlice *[]int `avro:"someNullableSlice"` +} diff --git a/cmd/avrogen/testdata/golden_plainslice.go b/cmd/avrogen/testdata/golden_plainslice.go new file mode 100644 index 0000000..fa01932 --- /dev/null +++ b/cmd/avrogen/testdata/golden_plainslice.go @@ -0,0 +1,11 @@ +// Code generated by avro/gen. DO NOT EDIT. +package testpkg + +// Test is a test struct. +type Test struct { + // SomeString is a string. + SomeString string `avro:"someString"` + SomeInt int `avro:"someInt"` + SomeNullableMap *map[string]int `avro:"someNullableMap"` + SomeNullableSlice []int `avro:"someNullableSlice"` +} diff --git a/cmd/avrogen/testdata/golden_stricttypes.go b/cmd/avrogen/testdata/golden_stricttypes.go index a3edc67..012d8f6 100644 --- a/cmd/avrogen/testdata/golden_stricttypes.go +++ b/cmd/avrogen/testdata/golden_stricttypes.go @@ -4,6 +4,8 @@ package testpkg // Test is a test struct. type Test struct { // SomeString is a string. - SomeString string `avro:"someString"` - SomeInt int32 `avro:"someInt"` + SomeString string `avro:"someString"` + SomeInt int32 `avro:"someInt"` + SomeNullableMap *map[string]int32 `avro:"someNullableMap"` + SomeNullableSlice *[]int32 `avro:"someNullableSlice"` } diff --git a/cmd/avrogen/testdata/schema.avsc b/cmd/avrogen/testdata/schema.avsc index 8754f39..8c75f7f 100644 --- a/cmd/avrogen/testdata/schema.avsc +++ b/cmd/avrogen/testdata/schema.avsc @@ -5,6 +5,8 @@ "doc": "Test is a test struct", "fields": [ { "name": "someString", "type": "string", "doc": "SomeString is a string" }, - { "name": "someInt", "type": "int" } + { "name": "someInt", "type": "int" }, + { "name": "someNullableMap", "type": ["null", {"type": "map", "values": "int"}]}, + { "name": "someNullableSlice", "type": ["null", {"type": "array", "items": "int"}]} ] -} \ No newline at end of file +} diff --git a/codec_union.go b/codec_union.go index 7d80b53..1638392 100644 --- a/codec_union.go +++ b/codec_union.go @@ -13,10 +13,18 @@ import ( func createDecoderOfUnion(d *decoderContext, schema *UnionSchema, typ reflect2.Type) ValDecoder { switch typ.Kind() { case reflect.Map: - if typ.(reflect2.MapType).Key().Kind() != reflect.String || - typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + if typ.(reflect2.MapType).Key().Kind() != reflect.String { break } + + if typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + if !schema.Nullable() { + break + } + + return decoderOfNullableUnion(d, schema, typ) + } + return decoderOfMapUnion(d, schema, typ) case reflect.Slice: if !schema.Nullable() { @@ -44,10 +52,18 @@ func createDecoderOfUnion(d *decoderContext, schema *UnionSchema, typ reflect2.T func createEncoderOfUnion(e *encoderContext, schema *UnionSchema, typ reflect2.Type) ValEncoder { switch typ.Kind() { case reflect.Map: - if typ.(reflect2.MapType).Key().Kind() != reflect.String || - typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + if typ.(reflect2.MapType).Key().Kind() != reflect.String { break } + + if typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + if !schema.Nullable() { + break + } + + return encoderOfNullableUnion(e, schema, typ) + } + return encoderOfMapUnion(e, schema, typ) case reflect.Slice: if !schema.Nullable() { @@ -179,6 +195,8 @@ func decoderOfNullableUnion(d *decoderContext, schema Schema, typ reflect2.Type) isPtr = true case *reflect2.UnsafeSliceType: baseTyp = v + case *reflect2.UnsafeMapType: + baseTyp = v } decoder := decoderOfType(d, union.Types()[typeIdx], baseTyp) @@ -249,6 +267,8 @@ func encoderOfNullableUnion(e *encoderContext, schema Schema, typ reflect2.Type) isPtr = true case *reflect2.UnsafeSliceType: baseTyp = v + case *reflect2.UnsafeMapType: + baseTyp = v } encoder := encoderOfType(e, union.Types()[typeIdx], baseTyp) diff --git a/decoder_union_test.go b/decoder_union_test.go index 4af546b..bf08093 100644 --- a/decoder_union_test.go +++ b/decoder_union_test.go @@ -225,6 +225,42 @@ func TestDecoder_UnionPtrReversedNull(t *testing.T) { assert.Nil(t, got) } +func TestDecoder_UnionNullableMap(t *testing.T) { + tt := []struct { + name string + data []byte + want map[string]string + }{{ + name: "WithData", + data: []byte{0x02, 0x01, 0x10, 0x06, 0x66, 0x6F, 0x6F, 0x06, 0x66, 0x6F, 0x6F, 0x00}, + want: map[string]string{"foo": "foo"}, + }, { + name: "Empty", + data: []byte{0x02, 0x00}, + want: map[string]string{}, + }, { + name: "Null", + data: []byte{0x00}, + want: nil, + }} + + schema := `["null", {"type":"map", "values": "string"}]` + + for _, test := range tt { + t.Run(test.name, func(t *testing.T) { + defer ConfigTeardown() + + dec, _ := avro.NewDecoder(schema, bytes.NewReader(test.data)) + + var got map[string]string + err := dec.Decode(&got) + + require.NoError(t, err) + assert.Equal(t, test.want, got) + }) + } +} + func TestDecoder_UnionNullableSlice(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_union_test.go b/encoder_union_test.go index 72c9d6e..9d3a4cf 100644 --- a/encoder_union_test.go +++ b/encoder_union_test.go @@ -256,6 +256,43 @@ func TestEncoder_UnionPtrNotNullable(t *testing.T) { assert.Error(t, err) } +func TestEncoder_UnionNullableMap(t *testing.T) { + tt := []struct { + name string + data map[string]string + want []byte + }{{ + name: "WithData", + data: map[string]string{"foo": "foo"}, + want: []byte{0x02, 0x01, 0x10, 0x06, 0x66, 0x6F, 0x6F, 0x06, 0x66, 0x6F, 0x6F, 0x00}, + }, { + name: "Empty", + data: map[string]string{}, + want: []byte{0x02, 0x00}, + }, { + name: "Null", + data: nil, + want: []byte{0x00}, + }} + + schema := `["null", {"type":"map", "values": "string"}]` + + for _, test := range tt { + t.Run(test.name, func(t *testing.T) { + defer ConfigTeardown() + + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + require.NoError(t, err) + + err = enc.Encode(test.data) + + require.NoError(t, err) + assert.Equal(t, test.want, buf.Bytes()) + }) + } +} + func TestEncoder_UnionNullableSlice(t *testing.T) { defer ConfigTeardown() diff --git a/gen/gen.go b/gen/gen.go index 010fc85..b604a42 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -24,6 +24,8 @@ type Config struct { FullName bool Encoders bool FullSchema bool + PlainMap bool + PlainSlice bool StrictTypes bool Initialisms []string LogicalTypes []LogicalType @@ -43,6 +45,9 @@ const ( Kebab TagStyle = "kebab" // UpperCamel is a style like ImWrittenInUpperCamel. UpperCamel TagStyle = "upper-camel" + + prefixMap string = "map[" + prefixSlice string = "[]" ) //go:embed output_template.tmpl @@ -85,6 +90,8 @@ func StructFromSchema(schema avro.Schema, w io.Writer, cfg Config) error { WithInitialisms(cfg.Initialisms), WithStrictTypes(cfg.StrictTypes), WithFullSchema(cfg.FullSchema), + WithPlainMap(cfg.PlainMap), + WithPlainSlice(cfg.PlainSlice), } for _, opt := range cfg.LogicalTypes { opts = append(opts, WithLogicalType(opt)) @@ -168,6 +175,20 @@ func WithFullSchema(b bool) OptsFunc { } } +// WithPlainMap configures the generator to emit a plain map and not a map ptr for nullable unions. +func WithPlainMap(b bool) OptsFunc { + return func(g *Generator) { + g.plainMap = b + } +} + +// WithPlainSlice configures the generator to emit a plain map and not a map ptr for nullable unions. +func WithPlainSlice(b bool) OptsFunc { + return func(g *Generator) { + g.plainSlice = b + } +} + // LogicalType used when the name of the "LogicalType" field in the Avro schema matches the Name attribute. type LogicalType struct { // Name of the LogicalType @@ -210,6 +231,8 @@ type Generator struct { fullName bool encoders bool fullSchema bool + plainMap bool + plainSlice bool strictTypes bool initialisms []string logicalTypes map[avro.LogicalType]LogicalType @@ -356,6 +379,14 @@ func (g *Generator) resolveUnionTypes(s *avro.UnionSchema) string { types = append(types, g.generate(elem)) } if s.Nullable() { + if g.plainMap && strings.HasPrefix(types[0], prefixMap) { + return types[0] + } + + if g.plainSlice && strings.HasPrefix(types[0], prefixSlice) { + return types[0] + } + return "*" + types[0] } return "any" diff --git a/gen/gen_test.go b/gen/gen_test.go index cceede1..7fd2f31 100644 --- a/gen/gen_test.go +++ b/gen/gen_test.go @@ -186,6 +186,60 @@ func TestStruct_ConfigurableLogicalTypes(t *testing.T) { } } +func TestStruct_GenPlain(t *testing.T) { + tt := []struct { + name, schema string + lines []string + cfg gen.Config + }{ + { + name: "Map", + schema: `{ + "type": "record", + "name": "test", + "fields": [ + { "name": "someRegularMap", "type": {"type": "map", "values": "int"}}, + { "name": "someNullableMap", "type": ["null", {"type": "map", "values": "int"}]} + ] +}`, + lines: []string{ + "SomeRegularMap map[string]int `avro:\"someRegularMap\"`", + "SomeNullableMap map[string]int `avro:\"someNullableMap\"`", + }, + cfg: gen.Config{PackageName: "Something", PlainMap: true}, + }, { + name: "Slice", + schema: `{ + "type": "record", + "name": "test", + "fields": [ + + { "name": "someRegularSlice", "type": {"type": "array", "items": "int"}}, + { "name": "someNullableSlice", "type": ["null", {"type": "array", "items": "int"}]} + ] +}`, + lines: []string{ + "SomeRegularSlice []int `avro:\"someRegularSlice\"`", + "SomeNullableSlice []int `avro:\"someNullableSlice\"`", + }, + cfg: gen.Config{PackageName: "Something", PlainSlice: true}, + }, + } + + for _, test := range tt { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + _, lines := generate(t, test.schema, test.cfg) + + for _, line := range test.lines { + assert.Contains(t, lines, line) + } + + }) + } +} + func TestStruct_GenFromRecordSchema(t *testing.T) { fileName := "testdata/golden.go" gc := gen.Config{PackageName: "Something"}