diff --git a/ocf/ocf.go b/ocf/ocf.go index 9bf5c8c3..75216656 100644 --- a/ocf/ocf.go +++ b/ocf/ocf.go @@ -144,6 +144,7 @@ type encoderConfig struct { CodecCompression int Metadata map[string][]byte Sync [16]byte + EncodingConfig avro.API } // EncoderFunc represents an configuration function for Encoder. @@ -186,6 +187,13 @@ func WithSyncBlock(sync [16]byte) EncoderFunc { } } +// WithEncodingConfig sets the value encoder config on the OCF encoder. +func WithEncodingConfig(wCfg avro.API) EncoderFunc { + return func(cfg *encoderConfig) { + cfg.EncodingConfig = wCfg + } +} + // Encoder writes Avro container file to an output stream. type Encoder struct { writer *avro.Writer @@ -209,6 +217,7 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { CodecName: Null, CodecCompression: -1, Metadata: map[string][]byte{}, + EncodingConfig: avro.DefaultConfig, } for _, opt := range opts { opt(&cfg) @@ -233,12 +242,12 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { return nil, err } - writer := avro.NewWriter(w, 512) + writer := avro.NewWriter(w, 512, avro.WithWriterConfig(cfg.EncodingConfig)) buf := &bytes.Buffer{} e := &Encoder{ writer: writer, buf: buf, - encoder: avro.NewEncoderForSchema(h.Schema, buf), + encoder: cfg.EncodingConfig.NewEncoder(h.Schema, buf), sync: h.Sync, codec: h.Codec, blockLength: cfg.BlockLength, @@ -268,7 +277,7 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { return nil, err } - writer := avro.NewWriter(w, 512) + writer := avro.NewWriter(w, 512, avro.WithWriterConfig(cfg.EncodingConfig)) writer.WriteVal(HeaderSchema, header) if err = writer.Flush(); err != nil { return nil, err @@ -278,7 +287,7 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { e := &Encoder{ writer: writer, buf: buf, - encoder: avro.NewEncoderForSchema(schema, buf), + encoder: cfg.EncodingConfig.NewEncoder(schema, buf), sync: header.Sync, codec: codec, blockLength: cfg.BlockLength, diff --git a/ocf/ocf_test.go b/ocf/ocf_test.go index cd6f97b9..3b02a082 100644 --- a/ocf/ocf_test.go +++ b/ocf/ocf_test.go @@ -409,6 +409,80 @@ func TestEncoder(t *testing.T) { assert.NoError(t, err) } +func TestEncoder_WithEncodingConfig(t *testing.T) { + arrSchema := `{"type": "array", "items": "long"}` + syncMarker := [16]byte{0x1F, 0x1F, 0x1F, 0x1F, 0x2F, 0x2F, 0x2F, 0x2F, 0x3F, 0x3F, 0x3F, 0x3F, 0x4F, 0x4F, 0x4F, 0x4F} + + skipOcfHeader := func(encoded []byte) []byte { + index := bytes.Index(encoded, syncMarker[:]) + require.False(t, index == -1) + return encoded[index+len(syncMarker):] // +1 for the null byte + } + + tests := []struct { + name string + data any + encConfig avro.API + wantPayload []byte // without OCF header + }{ + { + name: "no encoding config", + data: []int64{1, 2, 3, 4, 5}, + wantPayload: []byte{ + 0x2, 0x10, // OCF block header: 1 elems, 8 bytes + 0x9, 0xA, // array block header: 5 elems, 5 bytes + 0x2, 0x4, 0x6, 0x8, 0xA, 0x0, // array block payload with terminator + 0x1F, 0x1F, 0x1F, 0x1F, 0x2F, 0x2F, 0x2F, 0x2F, 0x3F, 0x3F, 0x3F, 0x3F, 0x4F, 0x4F, 0x4F, 0x4F, // OCF trailing sync marker + }, + }, + { + name: "no array bytes size", + encConfig: avro.Config{DisableBlockSizeHeader: true}.Freeze(), + data: []int64{1, 2, 3, 4, 5}, + wantPayload: []byte{ + 0x2, 0x0E, // OCF block header: 1 elem, 7 bytes + 0xA, // array block header: 5 elems + 0x2, 0x4, 0x6, 0x8, 0xA, 0x0, // array block payload with terminator + 0x1F, 0x1F, 0x1F, 0x1F, 0x2F, 0x2F, 0x2F, 0x2F, 0x3F, 0x3F, 0x3F, 0x3F, 0x4F, 0x4F, 0x4F, 0x4F, // OCF trailing sync marker + }, + }, + { + name: "non-default array block length", + encConfig: avro.Config{BlockLength: 5}.Freeze(), + data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + wantPayload: []byte{ + 0x2, 0x1c, // OCF block header: 1 elems, 15 bytes + 0x9, 0xA, // array block 1 header: 5 elems, 5 bytes + 0x2, 0x4, 0x6, 0x8, 0xA, // array block 1 + 0x7, 0x8, // array block 2 header: 4 elems, 4 bytes + 0xC, 0xE, 0x10, 0x12, 0x0, // array block 2 with terminator + 0x1F, 0x1F, 0x1F, 0x1F, 0x2F, 0x2F, 0x2F, 0x2F, 0x3F, 0x3F, 0x3F, 0x3F, 0x4F, 0x4F, 0x4F, 0x4F, // OCF sync marker + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + opts := []ocf.EncoderFunc{ocf.WithSyncBlock(syncMarker)} + if tt.encConfig != nil { + opts = append(opts, ocf.WithEncodingConfig(tt.encConfig)) + } + enc, err := ocf.NewEncoder(arrSchema, buf, opts...) + require.NoError(t, err) + + err = enc.Encode(tt.data) + require.NoError(t, err) + + err = enc.Close() + assert.NoError(t, err) + + assert.Equal(t, tt.wantPayload, skipOcfHeader(buf.Bytes())) + }) + } + +} + func TestEncoder_ExistingOCF(t *testing.T) { record := FullRecord{ Strings: []string{"another", "record"},