diff --git a/ocf/ocf.go b/ocf/ocf.go index c446cef..d980e4f 100644 --- a/ocf/ocf.go +++ b/ocf/ocf.go @@ -6,6 +6,7 @@ package ocf import ( "bytes" "crypto/rand" + "encoding/json" "errors" "fmt" "io" @@ -20,10 +21,11 @@ const ( codecKey = "avro.codec" ) -var magicBytes = [4]byte{'O', 'b', 'j', 1} +var ( + magicBytes = [4]byte{'O', 'b', 'j', 1} -// HeaderSchema is the Avro schema of a container file header. -var HeaderSchema = avro.MustParse(`{ + // HeaderSchema is the Avro schema of a container file header. + HeaderSchema = avro.MustParse(`{ "type": "record", "name": "org.apache.avro.file.Header", "fields": [ @@ -33,6 +35,15 @@ var HeaderSchema = avro.MustParse(`{ ] }`) + // DefaultSchemaMarshaler calls the schema's String() method, to produce + // a "canonical" schema. + DefaultSchemaMarshaler = defaultMarshalSchema + // FullSchemaMarshaler calls the schema's MarshalJSON() method, to produce + // a schema with all details preserved. The "canonical" schema returned by + // the default marshaler does not preserve a type's extra properties. + FullSchemaMarshaler = fullMarshalSchema +) + // Header represents an Avro container file header. type Header struct { Magic [4]byte `avro:"magic"` @@ -42,6 +53,7 @@ type Header struct { type decoderConfig struct { DecoderConfig avro.API + SchemaCache *avro.SchemaCache } // DecoderFunc represents a configuration function for Decoder. @@ -54,6 +66,14 @@ func WithDecoderConfig(wCfg avro.API) DecoderFunc { } } +// WithDecoderSchemaCache sets the schema cache for the decoder. +// If not specified, defaults to avro.DefaultSchemaCache. +func WithDecoderSchemaCache(cache *avro.SchemaCache) DecoderFunc { + return func(cfg *decoderConfig) { + cfg.SchemaCache = cache + } +} + // Decoder reads and decodes Avro values from a container file. type Decoder struct { reader *avro.Reader @@ -61,6 +81,7 @@ type Decoder struct { decoder *avro.Decoder meta map[string][]byte sync [16]byte + schema avro.Schema codec Codec @@ -71,6 +92,7 @@ type Decoder struct { func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { cfg := decoderConfig{ DecoderConfig: avro.DefaultConfig, + SchemaCache: avro.DefaultSchemaCache, } for _, opt := range opts { opt(&cfg) @@ -78,7 +100,7 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { reader := avro.NewReader(r, 1024) - h, err := readHeader(reader) + h, err := readHeader(reader, cfg.SchemaCache) if err != nil { return nil, fmt.Errorf("decoder: %w", err) } @@ -92,6 +114,7 @@ func NewDecoder(r io.Reader, opts ...DecoderFunc) (*Decoder, error) { meta: h.Meta, sync: h.Sync, codec: h.Codec, + schema: h.Schema, }, nil } @@ -100,6 +123,12 @@ func (d *Decoder) Metadata() map[string][]byte { return d.meta } +// Schema returns the schema that was parsed from the file's metadata +// and that is used to interpret the file's contents. +func (d *Decoder) Schema() avro.Schema { + return d.schema +} + // HasNext determines if there is another value to read. func (d *Decoder) HasNext() bool { if d.count <= 0 { @@ -174,6 +203,8 @@ type encoderConfig struct { Metadata map[string][]byte Sync [16]byte EncodingConfig avro.API + SchemaCache *avro.SchemaCache + SchemaMarshaler func(avro.Schema) ([]byte, error) } // EncoderFunc represents a configuration function for Encoder. @@ -209,6 +240,22 @@ func WithMetadata(meta map[string][]byte) EncoderFunc { } } +// WithEncoderSchemaCache sets the schema cache for the encoder. +// If not specified, defaults to avro.DefaultSchemaCache. +func WithEncoderSchemaCache(cache *avro.SchemaCache) EncoderFunc { + return func(cfg *encoderConfig) { + cfg.SchemaCache = cache + } +} + +// WithSchemaMarshaler sets the schema marshaler for the encoder. +// If not specified, defaults to DefaultSchemaMarshaler. +func WithSchemaMarshaler(m func(avro.Schema) ([]byte, error)) EncoderFunc { + return func(cfg *encoderConfig) { + cfg.SchemaMarshaler = m + } +} + // WithSyncBlock sets the sync block. func WithSyncBlock(sync [16]byte) EncoderFunc { return func(cfg *encoderConfig) { @@ -241,17 +288,23 @@ type Encoder struct { // If the writer is an existing ocf file, it will append data using the // existing schema. func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { - cfg := encoderConfig{ - BlockLength: 100, - CodecName: Null, - CodecCompression: -1, - Metadata: map[string][]byte{}, - EncodingConfig: avro.DefaultConfig, - } - for _, opt := range opts { - opt(&cfg) + cfg := computeEncoderConfig(opts) + schema, err := avro.ParseWithCache(s, "", cfg.SchemaCache) + if err != nil { + return nil, err } + return newEncoder(schema, w, cfg) +} + +// NewEncoderWithSchema returns a new encoder that writes to w using schema s. +// +// If the writer is an existing ocf file, it will append data using the +// existing schema. +func NewEncoderWithSchema(schema avro.Schema, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { + return newEncoder(schema, w, computeEncoderConfig(opts)) +} +func newEncoder(schema avro.Schema, w io.Writer, cfg encoderConfig) (*Encoder, error) { switch file := w.(type) { case nil: return nil, errors.New("writer cannot be nil") @@ -263,7 +316,7 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { if info.Size() > 0 { reader := avro.NewReader(file, 1024) - h, err := readHeader(reader) + h, err := readHeader(reader, cfg.SchemaCache) if err != nil { return nil, err } @@ -285,12 +338,12 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { } } - schema, err := avro.Parse(s) + schemaJSON, err := cfg.SchemaMarshaler(schema) if err != nil { return nil, err } - cfg.Metadata[schemaKey] = []byte(schema.String()) + cfg.Metadata[schemaKey] = schemaJSON cfg.Metadata[codecKey] = []byte(cfg.CodecName) header := Header{ Magic: magicBytes, @@ -324,6 +377,22 @@ func NewEncoder(s string, w io.Writer, opts ...EncoderFunc) (*Encoder, error) { return e, nil } +func computeEncoderConfig(opts []EncoderFunc) encoderConfig { + cfg := encoderConfig{ + BlockLength: 100, + CodecName: Null, + CodecCompression: -1, + Metadata: map[string][]byte{}, + EncodingConfig: avro.DefaultConfig, + SchemaCache: avro.DefaultSchemaCache, + SchemaMarshaler: DefaultSchemaMarshaler, + } + for _, opt := range opts { + opt(&cfg) + } + return cfg +} + // Write v to the internal buffer. This method skips the internal encoder and // therefore the caller is responsible for encoding the bytes. No error will be // thrown if the bytes does not conform to the schema given to NewEncoder, but @@ -400,7 +469,7 @@ type ocfHeader struct { Sync [16]byte } -func readHeader(reader *avro.Reader) (*ocfHeader, error) { +func readHeader(reader *avro.Reader, schemaCache *avro.SchemaCache) (*ocfHeader, error) { var h Header reader.ReadVal(HeaderSchema, &h) if reader.Error != nil { @@ -410,7 +479,7 @@ func readHeader(reader *avro.Reader) (*ocfHeader, error) { if h.Magic != magicBytes { return nil, errors.New("invalid avro file") } - schema, err := avro.Parse(string(h.Meta[schemaKey])) + schema, err := avro.ParseBytesWithCache(h.Meta[schemaKey], "", schemaCache) if err != nil { return nil, err } @@ -447,3 +516,11 @@ func skipToEnd(reader *avro.Reader, sync [16]byte) error { } } } + +func defaultMarshalSchema(schema avro.Schema) ([]byte, error) { + return []byte(schema.String()), nil +} + +func fullMarshalSchema(schema avro.Schema) ([]byte, error) { + return json.Marshal(schema) +} diff --git a/ocf/ocf_test.go b/ocf/ocf_test.go index f5405c0..38a462b 100644 --- a/ocf/ocf_test.go +++ b/ocf/ocf_test.go @@ -3,6 +3,7 @@ package ocf_test import ( "bytes" "compress/flate" + "encoding/json" "errors" "flag" "io" @@ -967,6 +968,185 @@ func TestEncoder_WriteHeaderError(t *testing.T) { assert.Error(t, err) } +func TestWithSchemaCache(t *testing.T) { + schema := `{ + "type": "record", + "name": "Foo", + "namespace": "foo.bar.baz", + "fields": [ + { + "name": "name", + "type": "string" + }, + { + "name": "id", + "type": "long" + }, + { + "name": "meta", + "type": { + "type": "array", + "items": { + "type": "record", + "name": "FooMetadataEntry", + "namespace": "foo.bar.baz", + "fields": [ + { + "name": "key", + "type": "string" + }, + { + "name": "values", + "type": { + "type": "array", + "items": "string" + } + } + ] + } + } + } + ] + }` + type metaEntry struct { + Key string `avro:"key"` + Values []string `avro:"values"` + } + type foo struct { + Name string `avro:"name"` + ID int64 `avro:"id"` + Meta []metaEntry `avro:"meta"` + } + encoderCache := &avro.SchemaCache{} + var buf bytes.Buffer + enc, err := ocf.NewEncoder(schema, &buf, ocf.WithEncoderSchemaCache(encoderCache)) + require.NoError(t, err) + val := foo{ + Name: "Bob Loblaw", + ID: 42, + Meta: []metaEntry{ + { + Key: "abc", + Values: []string{"123", "456"}, + }, + }, + } + require.NoError(t, enc.Encode(val)) + require.NoError(t, enc.Close()) + + assert.NotNil(t, encoderCache.Get("foo.bar.baz.Foo")) + assert.NotNil(t, encoderCache.Get("foo.bar.baz.FooMetadataEntry")) + assert.Nil(t, avro.DefaultSchemaCache.Get("foo.bar.baz.Foo")) + assert.Nil(t, avro.DefaultSchemaCache.Get("foo.bar.baz.FooMetadataEntry")) + + decoderCache := &avro.SchemaCache{} + dec, err := ocf.NewDecoder(&buf, ocf.WithDecoderSchemaCache(decoderCache)) + require.NoError(t, err) + require.True(t, dec.HasNext()) + var roundTripVal foo + require.NoError(t, dec.Decode(&roundTripVal)) + require.False(t, dec.HasNext()) + require.Equal(t, val, roundTripVal) + + assert.NotNil(t, decoderCache.Get("foo.bar.baz.Foo")) + assert.NotNil(t, decoderCache.Get("foo.bar.baz.FooMetadataEntry")) + assert.Nil(t, avro.DefaultSchemaCache.Get("foo.bar.baz.Foo")) + assert.Nil(t, avro.DefaultSchemaCache.Get("foo.bar.baz.FooMetadataEntry")) +} + +func TestWithSchemaMarshaler(t *testing.T) { + schema := `{ + "type": "record", + "name": "Bar", + "namespace": "foo.bar.baz", + "fields": [ + { + "name": "name", + "type": "string", + "field-id": 1 + }, + { + "name": "id", + "type": "long", + "field-id": 2 + }, + { + "name": "meta", + "type": { + "type": "array", + "items": { + "type": "record", + "name": "FooMetadataEntry", + "namespace": "foo.bar.baz", + "fields": [ + { + "name": "key", + "type": "string", + "field-id": 4 + }, + { + "name": "values", + "type": { + "type": "array", + "items": "string", + "element-id": 6 + }, + "field-id": 5 + } + ] + } + }, + "field-id": 3 + } + ] + }` + parsedSchema := avro.MustParse(schema) + type metaEntry struct { + Key string `avro:"key"` + Values []string `avro:"values"` + } + type foo struct { + Name string `avro:"name"` + ID int64 `avro:"id"` + Meta []metaEntry `avro:"meta"` + } + var buf bytes.Buffer + enc, err := ocf.NewEncoderWithSchema(parsedSchema, &buf, ocf.WithSchemaMarshaler(ocf.FullSchemaMarshaler)) + require.NoError(t, err) + val := foo{ + Name: "Bob Loblaw", + ID: 42, + Meta: []metaEntry{ + { + Key: "abc", + Values: []string{"123", "456"}, + }, + }, + } + require.NoError(t, enc.Encode(val)) + require.NoError(t, enc.Close()) + + dec, err := ocf.NewDecoder(&buf) + require.NoError(t, err) + require.True(t, dec.HasNext()) + var roundTripVal foo + require.NoError(t, dec.Decode(&roundTripVal)) + require.False(t, dec.HasNext()) + require.Equal(t, val, roundTripVal) + + got, err := json.MarshalIndent(dec.Schema(), "", " ") + require.NoError(t, err) + + if *update { + err = os.WriteFile("testdata/full-schema.json", got, 0o644) + require.NoError(t, err) + } + + want, err := os.ReadFile("testdata/full-schema.json") + require.NoError(t, err) + assert.Equal(t, want, got) +} + func copyToTemp(t *testing.T, src string) *os.File { t.Helper() diff --git a/ocf/testdata/full-schema.json b/ocf/testdata/full-schema.json new file mode 100644 index 0000000..a1b5999 --- /dev/null +++ b/ocf/testdata/full-schema.json @@ -0,0 +1,43 @@ +{ + "name": "foo.bar.baz.Bar", + "type": "record", + "fields": [ + { + "name": "name", + "type": "string", + "field-id": 1 + }, + { + "name": "id", + "type": "long", + "field-id": 2 + }, + { + "name": "meta", + "type": { + "type": "array", + "items": { + "name": "foo.bar.baz.FooMetadataEntry", + "type": "record", + "fields": [ + { + "name": "key", + "type": "string", + "field-id": 4 + }, + { + "name": "values", + "type": { + "type": "array", + "items": "string", + "element-id": 6 + }, + "field-id": 5 + } + ] + } + }, + "field-id": 3 + } + ] +} \ No newline at end of file