Skip to content

Commit

Permalink
Support type conversions (#87)
Browse files Browse the repository at this point in the history
* Support type conversion on encode
* Support type conversion on read
* Fix negative int value read
  • Loading branch information
at-wat authored Dec 15, 2019
1 parent 70de7fd commit 6250d16
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 11 deletions.
31 changes: 31 additions & 0 deletions datatype.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

package ebml

import (
"reflect"
)

// Type represents EBML Element data type.
type Type int

Expand Down Expand Up @@ -51,3 +55,30 @@ func (t Type) String() string {
return "Unknown type"
}
}

func isConvertible(src, dst reflect.Type) bool {
switch src.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch dst.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
default:
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
switch dst.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
default:
return false
}
case reflect.Float32, reflect.Float64:
switch dst.Kind() {
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
return false
}
26 changes: 20 additions & 6 deletions unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
)

var (
errUnknownElement = errors.New("unknown element")
errIndefiniteType = errors.New("unmarshal to indefinite type")
errUnknownElement = errors.New("unknown element")
errIndefiniteType = errors.New("unmarshal to indefinite type")
errIncompatibleType = errors.New("unmarshal to incompatible type")
)

// Unmarshal EBML stream.
Expand Down Expand Up @@ -151,10 +152,23 @@ func readElement(r0 io.Reader, n int64, vo reflect.Value, depth int, pos uint64,
}
vr := reflect.ValueOf(val)
if vnext.IsValid() && vnext.CanSet() {
if vr.Type() == vnext.Type() {
vnext.Set(reflect.ValueOf(val))
} else if vnext.Kind() == reflect.Slice && vr.Type() == vnext.Type().Elem() {
vnext.Set(reflect.Append(vnext, reflect.ValueOf(val)))
switch {
case vr.Type() == vnext.Type():
vnext.Set(vr)
case isConvertible(vr.Type(), vnext.Type()):
vnext.Set(vr.Convert(vnext.Type()))
case vnext.Kind() == reflect.Slice:
t := vnext.Type().Elem()
switch {
case vr.Type() == t:
vnext.Set(reflect.Append(vnext, vr))
case isConvertible(vr.Type(), t):
vnext.Set(reflect.Append(vnext, vr.Convert(t)))
default:
return nil, errIncompatibleType
}
default:
return nil, errIncompatibleType
}
}
if elem != nil {
Expand Down
190 changes: 190 additions & 0 deletions unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,124 @@ func ExampleUnmarshal() {
// Output: {{webm 2 2}}
}

func TestUnmarshal_Convert(t *testing.T) {
cases := map[string]struct {
b []byte
expected interface{}
}{
"UInt64ToUInt64": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion uint64 `ebml:"EBMLDocTypeVersion"`
}{2},
},
"UInt64ToUInt32": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion uint32 `ebml:"EBMLDocTypeVersion"`
}{2},
},
"UInt64ToUInt16": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion uint16 `ebml:"EBMLDocTypeVersion"`
}{2},
},
"UInt64ToUInt8": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion uint8 `ebml:"EBMLDocTypeVersion"`
}{2},
},
"UInt64ToUInt": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion uint `ebml:"EBMLDocTypeVersion"`
}{2},
},
"Int64ToInt64": {
[]byte{0xFB, 0x81, 0xFF},
struct {
ReferenceBlock int64 `ebml:"ReferenceBlock"`
}{-1},
},
"Int64ToInt32": {
[]byte{0xFB, 0x81, 0xFF},
struct {
ReferenceBlock int32 `ebml:"ReferenceBlock"`
}{-1},
},
"Int64ToInt16": {
[]byte{0xFB, 0x81, 0xFF},
struct {
ReferenceBlock int16 `ebml:"ReferenceBlock"`
}{-1},
},
"Int64ToInt8": {
[]byte{0xFB, 0x81, 0xFF},
struct {
ReferenceBlock int8 `ebml:"ReferenceBlock"`
}{-1},
},
"Int64ToInt": {
[]byte{0xFB, 0x81, 0xFF},
struct {
ReferenceBlock int `ebml:"ReferenceBlock"`
}{-1},
},
"Float64ToFloat64": {
[]byte{0x44, 0x89, 0x84, 0x00, 0x00, 0x00, 0x00},
struct {
Duration float64 `ebml:"Duration"`
}{0.0},
},
"Float64ToFloat32": {
[]byte{0x44, 0x89, 0x84, 0x00, 0x00, 0x00, 0x00},
struct {
Duration float32 `ebml:"Duration"`
}{0.0},
},
"UInt64ToUInt64Slice": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion []uint64 `ebml:"EBMLDocTypeVersion"`
}{[]uint64{2}},
},
"UInt64ToUInt32Slice": {
[]byte{0x42, 0x87, 0x81, 0x02},
struct {
DocTypeVersion []uint32 `ebml:"EBMLDocTypeVersion"`
}{[]uint32{2}},
},
"Int64ToInt32Slice": {
[]byte{0xFB, 0x81, 0xFF},
struct {
ReferenceBlock []int32 `ebml:"ReferenceBlock"`
}{[]int32{-1}},
},
"Float64ToFloat32Slice": {
[]byte{0x44, 0x89, 0x84, 0x00, 0x00, 0x00, 0x00},
struct {
Duration []float32 `ebml:"Duration"`
}{[]float32{0.0}},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
ret := reflect.New(reflect.ValueOf(c.expected).Type())
if err := Unmarshal(bytes.NewReader(c.b), ret.Interface()); err != nil {
t.Fatalf("Unexpected error: %v\n", err)
}

if !reflect.DeepEqual(c.expected, ret.Elem().Interface()) {
t.Errorf("Unexpected convert result, expected: %v, got %v",
c.expected, ret.Elem().Interface())
}
})
}
}

func TestUnmarshal_OptionError(t *testing.T) {
errExpected := errors.New("an error")
err := Unmarshal(&bytes.Buffer{}, &struct{}{},
Expand Down Expand Up @@ -160,6 +278,78 @@ func TestUnmarshal_Error(t *testing.T) {
})
}
})
t.Run("Incompatible", func(t *testing.T) {
cases := map[string]struct {
b []byte
ret interface{}
err error
}{
"UInt64ToInt64": {
b: []byte{0x42, 0x87, 0x81, 0x02},
ret: &struct {
DocTypeVersion int64 `ebml:"EBMLDocTypeVersion"`
}{},
err: errIncompatibleType,
},
"Int64ToUInt64": {
b: []byte{0xFB, 0x81, 0xFF},
ret: &struct {
ReferenceBlock uint64 `ebml:"ReferenceBlock"`
}{},
err: errIncompatibleType,
},
"Float64ToInt64": {
b: []byte{0x44, 0x89, 0x84, 0x00, 0x00, 0x00, 0x00},
ret: &struct {
Duration int64 `ebml:"Duration"`
}{},
err: errIncompatibleType,
},
"StringToInt64": {
b: []byte{0x42, 0x82, 0x85, 0x77, 0x65, 0x62, 0x6d, 0x00},
ret: &struct {
EBMLDocType int64 `ebml:"EBMLDocType"`
}{},
err: errIncompatibleType,
},
"UInt64ToInt64Slice": {
b: []byte{0x42, 0x87, 0x81, 0x02},
ret: &struct {
DocTypeVersion []int64 `ebml:"EBMLDocTypeVersion"`
}{},
err: errIncompatibleType,
},
"Int64ToUInt64Slice": {
b: []byte{0xFB, 0x81, 0xFF},
ret: &struct {
ReferenceBlock []uint64 `ebml:"ReferenceBlock"`
}{},
err: errIncompatibleType,
},
"Float64ToInt64Slice": {
b: []byte{0x44, 0x89, 0x84, 0x00, 0x00, 0x00, 0x00},
ret: &struct {
Duration []int64 `ebml:"Duration"`
}{},
err: errIncompatibleType,
},
"StringToInt64Slice": {
b: []byte{0x42, 0x82, 0x85, 0x77, 0x65, 0x62, 0x6d, 0x00},
ret: &struct {
EBMLDocType []int64 `ebml:"EBMLDocType"`
}{},
err: errIncompatibleType,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
if err := Unmarshal(bytes.NewBuffer(c.b), c.ret); err != c.err {
t.Errorf("Unexpected error, expected: %v, got: %v\n", c.err, err)
}
})
}
})

}

func BenchmarkUnmarshal(b *testing.B) {
Expand Down
39 changes: 34 additions & 5 deletions value.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,14 @@ func readInt(r io.Reader, n uint64) (interface{}, error) {
if err != nil {
return 0, err
}
return int64(v.(uint64)), nil
v64 := v.(uint64)
if n != 8 && (v64&(1<<(n*8-1))) != 0 {
// negative value
for i := n; i < 8; i++ {
v64 |= 0xFF << (i * 8)
}
}
return int64(v64), nil
}
func readUInt(r io.Reader, n uint64) (interface{}, error) {
bs := make([]byte, n)
Expand Down Expand Up @@ -262,15 +269,37 @@ func encodeString(i interface{}, n uint64) ([]byte, error) {
return append([]byte(v), bytes.Repeat([]byte{0x00}, int(n)-len(v))...), nil
}
func encodeInt(i interface{}, n uint64) ([]byte, error) {
v, ok := i.(int64)
if !ok {
var v int64
switch v2 := i.(type) {
case int:
v = int64(v2)
case int8:
v = int64(v2)
case int16:
v = int64(v2)
case int32:
v = int64(v2)
case int64:
v = v2
default:
return []byte{}, errInvalidType
}
return encodeUInt(uint64(v), n)
}
func encodeUInt(i interface{}, n uint64) ([]byte, error) {
v, ok := i.(uint64)
if !ok {
var v uint64
switch v2 := i.(type) {
case uint:
v = uint64(v2)
case uint8:
v = uint64(v2)
case uint16:
v = uint64(v2)
case uint32:
v = uint64(v2)
case uint64:
v = v2
default:
return []byte{}, errInvalidType
}
switch {
Expand Down
8 changes: 8 additions & 0 deletions value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ func TestValue(t *testing.T) {
"Block": {[]byte{0x85, 0x12, 0x34, 0x80, 0x34, 0x56}, TypeBlock,
Block{uint64(5), int16(0x1234), true, false, LacingNo, false, nil, [][]byte{{0x34, 0x56}}}, 0, nil,
},
"ConvertInt8": {[]byte{0x01}, TypeInt, int64(0x01), 0, int8(0x01)},
"ConvertInt16": {[]byte{0x01, 0x02}, TypeInt, int64(0x0102), 0, int16(0x0102)},
"ConvertInt32": {[]byte{0x01, 0x02, 0x03, 0x04}, TypeInt, int64(0x01020304), 0, int32(0x01020304)},
"ConvertInt": {[]byte{0x01, 0x02, 0x03, 0x04}, TypeInt, int64(0x01020304), 0, int(0x01020304)},
"ConvertUInt8": {[]byte{0x01}, TypeUInt, uint64(0x01), 0, uint8(0x01)},
"ConvertUInt16": {[]byte{0x01, 0x02}, TypeUInt, uint64(0x0102), 0, uint16(0x0102)},
"ConvertUInt32": {[]byte{0x01, 0x02, 0x03, 0x04}, TypeUInt, uint64(0x01020304), 0, uint32(0x01020304)},
"ConvertUInt": {[]byte{0x01, 0x02, 0x03, 0x04}, TypeUInt, uint64(0x01020304), 0, uint(0x01020304)},
}
for n, c := range testCases {
t.Run("Read "+n, func(t *testing.T) {
Expand Down

0 comments on commit 6250d16

Please sign in to comment.