Skip to content

Commit

Permalink
Merge pull request #564 from siscia/write_arrays
Browse files Browse the repository at this point in the history
Add support for arrays
  • Loading branch information
magiconair authored Mar 14, 2022
2 parents 486ed60 + 1081569 commit bda1dc9
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 3 deletions.
58 changes: 57 additions & 1 deletion ua/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ func decode(b []byte, val reflect.Value, name string) (n int, err error) {
val.SetString(buf.ReadString())
case reflect.Slice:
return decodeSlice(b, val, name)
case reflect.Array:
return decodeArray(b, val, name)
case reflect.Ptr:
return decode(b, val.Elem(), name)
case reflect.Struct:
Expand Down Expand Up @@ -126,7 +128,7 @@ func decodeSlice(b []byte, val reflect.Value, name string) (int, error) {
}

if n > math.MaxInt32 {
return buf.Pos(), errors.Errorf("array too large: %d", n)
return buf.Pos(), errors.Errorf("array too large: %d > %d", n, math.MaxInt32)
}

// elemType is the type of the slice elements
Expand Down Expand Up @@ -163,3 +165,57 @@ func decodeSlice(b []byte, val reflect.Value, name string) (int, error) {

return pos, nil
}

func decodeArray(b []byte, val reflect.Value, name string) (int, error) {
buf := NewBuffer(b)
n := buf.ReadUint32()
if buf.Error() != nil {
return buf.Pos(), buf.Error()
}

if n == null {
return buf.Pos(), nil
}

if n > math.MaxInt32 {
return buf.Pos(), errors.Errorf("array too large: %d > %d", n, math.MaxInt32)
}

if n > uint32(val.Len()) {
return buf.Pos(), errors.Errorf("array too large: %d > %d", n, val.Len())
}

// elemType is the type of the slice elements
// e.g. *Foo for []*Foo
elemType := val.Type().Elem()
// fmt.Println("elemType: ", elemType.String())

// fast path for []byte
if elemType.Kind() == reflect.Uint8 {
// fmt.Println("decode: []byte fast path")
reflect.Copy(val, reflect.ValueOf(buf.ReadN(int(n))))
return buf.Pos(), buf.Error()
}

pos := buf.Pos()
// a is a pointer to an array [n]*Foo, where n is know at compile time
a := reflect.New(val.Type()).Elem()
for i := 0; i < int(n); i++ {

// if the slice elements are pointers we need to create
// them before we can marshal data into them.
if elemType.Kind() == reflect.Ptr {
a.Index(i).Set(reflect.New(elemType.Elem()))
}

ename := fmt.Sprintf("%s[%d]", name, i)
m, err := decode(b[pos:], a.Index(i), ename)
if err != nil {
return pos, err
}
pos += m
}
val.Set(a)

return pos, nil
}
155 changes: 153 additions & 2 deletions ua/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ type B struct {
S []*A
}

type C struct {
A [2]int32
B [2]byte
}

func TestCodec(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -112,6 +117,54 @@ func TestCodec(t *testing.T) {
0x67, 0x45, 0x00, 0x00,
},
},
{
name: "[2]byte",
v: &struct{ V [2]byte }{[2]byte{0x12, 0x34}},
b: []byte{
// length
0x02, 0x00, 0x00, 0x00,
// elem 1
0x12,
// elem 2
0x34,
},
},
{
name: "[2]byte{1}",
v: &struct{ V [2]byte }{[2]byte{0x12}},
b: []byte{
// length
0x02, 0x00, 0x00, 0x00,
// elem 1
0x12,
// elem 2
0x00,
},
},
{
name: "[2]uint32",
v: &struct{ V [2]uint32 }{[2]uint32{0x1234, 0x4567}},
b: []byte{
// length
0x02, 0x00, 0x00, 0x00,
// elem 1
0x34, 0x12, 0x00, 0x00,
// elem 2
0x67, 0x45, 0x00, 0x00,
},
},
{
name: "[2]uint32{1}",
v: &struct{ V [2]uint32 }{[2]uint32{1}},
b: []byte{
// length
0x02, 0x00, 0x00, 0x00,
// elem 1
0x01, 0x00, 0x00, 0x00,
// zero element of the array
0x00, 0x00, 0x00, 0x00,
},
},
{
name: "string",
v: &struct{ V string }{"abc"},
Expand Down Expand Up @@ -180,6 +233,28 @@ func TestCodec(t *testing.T) {
0x03, 0x00, 0x00, 0x00,
},
},
{
name: "[2]byte",
v: &struct{ V [2]byte }{[2]byte{}},
b: []byte{
// length
0x02, 0x00, 0x00, 0x00,
// values
0x00,
0x00,
},
},
{
name: "[2]uint32",
v: &struct{ V [2]uint32 }{[2]uint32{}},
b: []byte{
// length
0x02, 0x00, 0x00, 0x00,
// values
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
},
},
{
name: "[]*A",
v: &struct{ V []*A }{
Expand All @@ -206,7 +281,7 @@ func TestCodec(t *testing.T) {
},
},
b: []byte{
// B.A.N
// B.A.V
0x34, 0x12, 0x00, 0x00,
// B.A.S == nil
0xff, 0xff, 0xff, 0xff,
Expand All @@ -222,7 +297,7 @@ func TestCodec(t *testing.T) {
},
},
b: []byte{
// B.A.N
// B.A.V
0x90, 0x78, 0x00, 0x00,
// len(B.A.S)
0x02, 0x00, 0x00, 0x00,
Expand All @@ -232,6 +307,66 @@ func TestCodec(t *testing.T) {
0x67, 0x45, 0x00, 0x00,
},
},
{
name: "&C",
v: &C{A: [2]int32{1, 2}, B: [2]byte{3, 4}},
b: []byte{
// len(C.A)
0x02, 0x00, 0x00, 0x00,
// C.A[0]
0x01, 0x00, 0x00, 0x00,
// C.A[1]
0x02, 0x00, 0x00, 0x00,
// len(C.B)
0x02, 0x00, 0x00, 0x00,
// C.B[0]
0x03,
// C.B[1]
0x04,
},
},
{
name: "[3]C",
v: &struct{ V [3]C }{[3]C{
{},
{A: [2]int32{7}, B: [2]byte{1}},
{A: [2]int32{0, 9}, B: [2]byte{3, 4}},
}},
b: []byte{
// len(V)
0x03, 0x00, 0x00, 0x00,
// len(V[0].A)
0x02, 0x00, 0x00, 0x00,
// V[0].A
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
// len(V[0].B)
0x02, 0x00, 0x00, 0x00,
// V[0].B
0x00,
0x00,
// len(V[1].A)
0x02, 0x00, 0x00, 0x00,
// V[1].A
0x07, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
// len(V[1].B)
0x02, 0x00, 0x00, 0x00,
// V[1].B
0x01,
0x00,
// len(V[2].A)
0x02, 0x00, 0x00, 0x00,
// V[2].A
0x00, 0x00, 0x00, 0x00,
0x09, 0x00, 0x00, 0x00,
// len(V[2].B)
0x02, 0x00, 0x00, 0x00,
// V[2].B
0x03,
0x04,
},
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -266,3 +401,19 @@ func TestCodec(t *testing.T) {
})
}
}

func TestFailDecodeArray(t *testing.T) {
b := []byte{
// len
0x03, 0x00, 0x00, 0x00,
// Values
0x00, 0x00, 0x00, 0x00, // 0
0x00, 0x00, 0x00, 0x00, // 1
0x07, 0x00, 0x00, 0x00, // 7
}
var a [2]int32
_, err := Decode(b, &a)
if err == nil {
t.Fatalf("was expecting error for tryig to decode a stream of bytes with length 3 into an array of size 2")
}
}
33 changes: 33 additions & 0 deletions ua/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ func encode(val reflect.Value, name string) ([]byte, error) {
return writeStruct(val, name)
case reflect.Slice:
return writeSlice(val, name)
case reflect.Array:
return writeArray(val, name)
default:
return nil, errors.Errorf("unsupported type: %s", val.Type())
}
Expand Down Expand Up @@ -136,3 +138,34 @@ func writeSlice(val reflect.Value, name string) ([]byte, error) {
}
return buf.Bytes(), buf.Error()
}

func writeArray(val reflect.Value, name string) ([]byte, error) {
buf := NewBuffer(nil)

if val.Len() > math.MaxInt32 {
return nil, errors.Errorf("array too large: %d > %d", val.Len(), math.MaxInt32)
}

buf.WriteUint32(uint32(val.Len()))

// fast path for []byte
if val.Type().Elem().Kind() == reflect.Uint8 {
// fmt.Println("encode: []byte fast path")
b := make([]byte, val.Len())
reflect.Copy(reflect.ValueOf(b), val)
buf.Write(b)
return buf.Bytes(), buf.Error()
}

// loop over elements
// we write all the elements, also the zero values
for i := 0; i < val.Len(); i++ {
ename := fmt.Sprintf("%s[%d]", name, i)
b, err := encode(val.Index(i), ename)
if err != nil {
return nil, err
}
buf.Write(b)
}
return buf.Bytes(), buf.Error()
}

0 comments on commit bda1dc9

Please sign in to comment.