diff --git a/go.mod b/go.mod
index caf519bf..696ef6e9 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@ require (
 	github.com/ettle/strcase v0.2.0
 	github.com/golang/snappy v0.0.4
 	github.com/json-iterator/go v1.1.12
+	github.com/klauspost/compress v1.17.4
 	github.com/mitchellh/mapstructure v1.5.0
 	github.com/modern-go/reflect2 v1.0.2
 	github.com/stretchr/testify v1.7.1
diff --git a/go.sum b/go.sum
index c154e11f..76652a96 100644
--- a/go.sum
+++ b/go.sum
@@ -8,6 +8,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
+github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
 github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
 github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
diff --git a/ocf/codec.go b/ocf/codec.go
index 2340a6dd..422d2c96 100644
--- a/ocf/codec.go
+++ b/ocf/codec.go
@@ -10,6 +10,7 @@ import (
 	"io"
 
 	"github.com/golang/snappy"
+	"github.com/klauspost/compress/zstd"
 )
 
 // CodecName represents a compression codec name.
@@ -17,9 +18,10 @@ type CodecName string
 
 // Supported compression codecs.
 const (
-	Null    CodecName = "null"
-	Deflate CodecName = "deflate"
-	Snappy  CodecName = "snappy"
+	Null      CodecName = "null"
+	Deflate   CodecName = "deflate"
+	Snappy    CodecName = "snappy"
+	ZStandard CodecName = "zstandard"
 )
 
 func resolveCodec(name CodecName, lvl int) (Codec, error) {
@@ -33,6 +35,9 @@ func resolveCodec(name CodecName, lvl int) (Codec, error) {
 	case Snappy:
 		return &SnappyCodec{}, nil
 
+	case ZStandard:
+		return &ZStandardCodec{}, nil
+
 	default:
 		return nil, fmt.Errorf("unknown codec %s", name)
 	}
@@ -120,3 +125,35 @@ func (*SnappyCodec) Encode(b []byte) []byte {
 
 	return dst
 }
+
+// ZStandardCodec is a zstandard compression codec.
+//
+// The zero value is ready to use. Each call builds a short-lived zstd
+// encoder or decoder, so a codec value carries no state between calls.
+type ZStandardCodec struct{}
+
+// Decode decompresses the zstandard data in b and returns the decoded
+// bytes. It returns an error if the decoder cannot be created or if b is not
+// valid zstandard data.
+func (*ZStandardCodec) Decode(b []byte) ([]byte, error) {
+	// DecodeAll is a one-shot call, so one block decoder is enough; the
+	// default would allocate several that are never used.
+	dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
+	if err != nil {
+		return nil, fmt.Errorf("creating zstandard decoder: %w", err)
+	}
+	defer dec.Close()
+
+	return dec.DecodeAll(b, nil)
+}
+
+// Encode compresses b with zstandard and returns the compressed bytes.
+func (*ZStandardCodec) Encode(b []byte) []byte {
+	// NewWriter only fails when an option is invalid, and the single option
+	// used here is a valid constant, so the error is safe to ignore. Encode
+	// cannot report an error through its signature in any case.
+	enc, _ := zstd.NewWriter(nil, zstd.WithEncoderConcurrency(1))
+	defer func() { _ = enc.Close() }()
+
+	return enc.EncodeAll(b, nil)
+}
diff --git a/ocf/ocf_test.go b/ocf/ocf_test.go
index 3b02a082..102a1da4 100644
--- a/ocf/ocf_test.go
+++ b/ocf/ocf_test.go
@@ -311,6 +311,65 @@ func TestDecoder_WithSnappyHandlesInvalidCRC(t *testing.T) {
 	assert.Error(t, dec.Error())
 }
 
+func TestDecoder_WithZStandard(t *testing.T) {
+	unionStr := "union value"
+	want := FullRecord{
+		Strings: []string{"string1", "string2", "string3", "string4", "string5"},
+		Longs:   []int64{1, 2, 3, 4, 5},
+		Enum:    "C",
+		Map: map[string]int{
+			"key1": 1,
+			"key2": 2,
+			"key3": 3,
+			"key4": 4,
+			"key5": 5,
+		},
+		Nullable: &unionStr,
+		Fixed:    [16]byte{0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04},
+		Record: &TestRecord{
+			Long:   1925639126735,
+			String: "I am a test record",
+			Int:    666,
+			Float:  7171.17,
+			Double: 916734926348163.01973408746523,
+			Bool:   true,
+		},
+	}
+
+	f, err := os.Open("testdata/full-zstd.avro")
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = f.Close() })
+
+	dec, err := ocf.NewDecoder(f)
+	require.NoError(t, err)
+
+	var count int
+	for dec.HasNext() {
+		count++
+		var got FullRecord
+		err = dec.Decode(&got)
+
+		require.NoError(t, err)
+		assert.Equal(t, want, got)
+	}
+
+	require.NoError(t, dec.Error())
+	assert.Equal(t, 1, count)
+}
+
+func TestDecoder_WithZStandardHandlesInvalidData(t *testing.T) {
+	f, err := os.Open("testdata/zstd-invalid-data.avro")
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = f.Close() })
+
+	dec, err := ocf.NewDecoder(f)
+	require.NoError(t, err)
+
+	dec.HasNext()
+
+	assert.Error(t, dec.Error())
+}
+
 func TestDecoder_DecodeAvroError(t *testing.T) {
 	data := []byte{'O', 'b', 'j', 0x01, 0x01, 0x26, 0x16, 'a', 'v', 'r', 'o', '.', 's', 'c', 'h', 'e', 'm', 'a',
 		0x0c, '"', 'l', 'o', 'n', 'g', '"', 0x00, 0xfb, 0x2b, 0x0f, 0x1a, 0xdd, 0xfd, 0x90, 0x7d, 0x87, 0x12,
@@ -681,6 +740,44 @@ func TestEncoder_EncodeCompressesSnappy(t *testing.T) {
 	assert.Equal(t, 938, buf.Len())
 }
 
+func TestEncoder_EncodeCompressesZStandard(t *testing.T) {
+	unionStr := "union value"
+	record := FullRecord{
+		Strings: []string{"string1", "string2", "string3", "string4", "string5"},
+		Longs:   []int64{1, 2, 3, 4, 5},
+		Enum:    "C",
+		Map: map[string]int{
+			"key1": 1,
+			"key2": 2,
+			"key3": 3,
+			"key4": 4,
+			"key5": 5,
+		},
+		Nullable: &unionStr,
+		Fixed:    [16]byte{0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04},
+		Record: &TestRecord{
+			Long:   1925639126735,
+			String: "I am a test record",
+			Int:    666,
+			Float:  7171.17,
+			Double: 916734926348163.01973408746523,
+			Bool:   true,
+		},
+	}
+
+	buf := &bytes.Buffer{}
+	enc, err := ocf.NewEncoder(schema, buf, ocf.WithCodec(ocf.ZStandard))
+	require.NoError(t, err)
+
+	err = enc.Encode(record)
+	require.NoError(t, err)
+
+	err = enc.Close()
+
+	require.NoError(t, err)
+	assert.Equal(t, 951, buf.Len())
+}
+
 func TestEncoder_EncodeError(t *testing.T) {
 	buf := &bytes.Buffer{}
 	enc, err := ocf.NewEncoder(`"long"`, buf)
diff --git a/ocf/testdata/full-zstd.avro b/ocf/testdata/full-zstd.avro
new file mode 100644
index 00000000..4e663f71
Binary files /dev/null and b/ocf/testdata/full-zstd.avro differ
diff --git a/ocf/testdata/zstd-invalid-data.avro b/ocf/testdata/zstd-invalid-data.avro
new file mode 100644
index 00000000..be12de5d
Binary files /dev/null and b/ocf/testdata/zstd-invalid-data.avro differ