Skip to content

Commit

Permalink
💊 fix #48 and #51. last issue maybe not fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
quenbyako committed Dec 23, 2020
1 parent faf556e commit d682540
Show file tree
Hide file tree
Showing 15 changed files with 392 additions and 93 deletions.
37 changes: 29 additions & 8 deletions encoding/tl/cursor_r.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,29 @@ import (
type Decoder struct {
r io.Reader
err error

// see Decoder.ExpectTypesInInterface description
expectedTypes []reflect.Type
}

// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}

// ExpectTypesInInterface defines, how decoder must parse implicit objects.
// how does expectedTypes works:
// So, imagine: you want parse []int32, but also you can get []int64, or SomeCustomType, or even [][]bool.
// How to deal it?
// expectedTypes store your predictions (like "if you got unknown type, parse it as int32, not int64")
// also, if you have predictions deeper than first unknown type, you can say decoder to use predicted vals
//
// So, next time, when you'll have strucre object with interface{} which expect contains []float64 or sort
// of — use this feature via d.ExpectTypesInInterface()
func (d *Decoder) ExpectTypesInInterface(types ...reflect.Type) {
d.expectedTypes = types
}

func (d *Decoder) read(buf []byte) {
if d.err != nil {
return
Expand Down Expand Up @@ -140,19 +156,24 @@ func (d *Decoder) DumpWithoutRead() ([]byte, error) {
}

func (d *Decoder) PopVector(as reflect.Type) any {
if d.err != nil {
return nil
}
return d.popVector(as, false)
}

crc := d.PopCRC()
func (d *Decoder) popVector(as reflect.Type, ignoreCRC bool) any {
if d.err != nil {
d.err = errors.Wrap(d.err, "read crc")
return nil
}
if !ignoreCRC {
crc := d.PopCRC()
if d.err != nil {
d.err = errors.Wrap(d.err, "read crc")
return nil
}

if crc != CrcVector {
d.err = fmt.Errorf("not a vector: %#v, want: %#v", crc, CrcVector)
return nil
if crc != CrcVector {
d.err = fmt.Errorf("not a vector: 0x%08x, want: 0x%08x", crc, CrcVector)
return nil
}
}

size := d.PopUint()
Expand Down
57 changes: 48 additions & 9 deletions encoding/tl/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ import (
"github.com/pkg/errors"
)

func Decode(data []byte, v any) error {
if v == nil {
func Decode(data []byte, res any) error {
if res == nil {
return errors.New("can't unmarshal to nil value")
}
if reflect.TypeOf(v).Kind() != reflect.Ptr {
return fmt.Errorf("v value is not pointer as expected. got %v", reflect.TypeOf(v))
if reflect.TypeOf(res).Kind() != reflect.Ptr {
return fmt.Errorf("res value is not pointer as expected. got %v", reflect.TypeOf(res))
}

d := NewDecoder(bytes.NewReader(data))

d.decodeValue(reflect.ValueOf(v))
d.decodeValue(reflect.ValueOf(res))
if d.err != nil {
return errors.Wrapf(d.err, "decode %T", v)
return errors.Wrapf(d.err, "decode %T", res)
}

return nil
Expand All @@ -35,8 +35,14 @@ func Decode(data []byte, v any) error {
// due to TL doesn't provide mechanism for understanding is message a int or string, you MUST guarantee, that
// input stream DOES NOT contains any type WITHOUT its CRC code. So, strings, ints, floats, etc. CAN'T BE
// automatically parsed.
func DecodeUnknownObject(data []byte) (Object, error) {
//
// expectNextTypes is your predictions how decoder must parse objects hidden under interfaces.
// See Decoder.ExpectTypesInInterface description
func DecodeUnknownObject(data []byte, expectNextTypes ...reflect.Type) (Object, error) {
d := NewDecoder(bytes.NewReader(data))
if len(expectNextTypes) > 0 {
d.ExpectTypesInInterface(expectNextTypes...)
}

obj := d.decodeRegisteredObject()
if d.err != nil {
Expand Down Expand Up @@ -228,20 +234,53 @@ func (d *Decoder) decodeValueGeneral(value reflect.Value) interface{} {
return val
}

// decodeRegisteredObject пробует определить,
func (d *Decoder) decodeRegisteredObject() Object {
crc := d.PopCRC()
if d.err != nil {
d.err = errors.Wrap(d.err, "read crc")
}

_typ, ok := objectByCrc[crc]
var _typ reflect.Type

// firstly, we are checking specific crc situations.
// See https://github.com/xelaj/mtproto/issues/51
switch crc {
case CrcVector:
if len(d.expectedTypes) == 0 {
d.err = &ErrMustParseSlicesExplicitly{}
return nil
}
_typ = d.expectedTypes[0]
d.expectedTypes = d.expectedTypes[1:]

res := d.popVector(_typ.Elem(), true)
if d.err != nil {
return nil
}

return &InterfacedObject{res}

case CrcFalse:
return &InterfacedObject{false}

case CrcTrue:
return &InterfacedObject{true}

case CrcNull:
return &InterfacedObject{nil}
}

// in other ways we're trying to get object from registred crcs
var ok bool
_typ, ok = objectByCrc[crc]
if !ok {
msg, err := d.DumpWithoutRead()
if err != nil {
return nil
}

d.err = ErrRegisteredObjectNotFound{
d.err = &ErrRegisteredObjectNotFound{
Crc: crc,
Data: msg,
}
Expand Down
100 changes: 89 additions & 11 deletions encoding/tl/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
package tl_test

import (
"reflect"
"testing"

"github.com/k0kubun/pp"
"github.com/stretchr/testify/assert"
"github.com/xelaj/go-dry"

"github.com/xelaj/mtproto/encoding/tl"
)

Expand Down Expand Up @@ -80,19 +81,96 @@ func TestDecode(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var (
data = tt.data
v = tt.v
expected = tt.expected
wantErr = noErrAsDefault(tt.wantErr)
)
err := tl.Decode(data, v)
if !wantErr(t, err) {
pp.Println(dry.BytesEncodeHex(string(data)))
tt.wantErr = noErrAsDefault(tt.wantErr)

err := tl.Decode(tt.data, tt.v)
if !tt.wantErr(t, err) {
pp.Println(dry.BytesEncodeHex(string(tt.data)))
return
}
if err != nil {
assert.Equal(t, tt.expected, tt.v)
}
})
}
}

func TestDecodeUnknown(t *testing.T) {
tests := []struct {
name string
data []byte
hintsForDecoder []reflect.Type
expected any
wantErr assert.ErrorAssertionFunc
}{
{
name: "authSentCode",
// | CRC || Flag || CRC |
data: Hexed("0225005E020000008659BB3D0500000012316637366461306431353531313539363336008C15A372"),
expected: &AuthSentCode{
Type: &AuthSentCodeTypeApp{
Length: 5,
},
PhoneCodeHash: "1f76da0d1551159636",
NextType: 0x72a3158c,
Timeout: 0,
},
},
{
name: "poll-results",
data: Hexed("a3c1dcba1e00000015c4b51c02000000d2da6d3b010000000301020302000000d2da6d3b" +
"0000000003040506060000000c00000015c4b51c02000000050000000600000005616c616c610000" +
"15c4b51c00000000"),
expected: &PollResults{
Min: false,
Results: []*PollAnswerVoters{
{
Chosen: true,
Correct: false,
Option: []byte{
0x01, 0x02, 0x03,
},
Voters: 2,
},
{
Chosen: false,
Correct: false,
Option: []byte{
0x04, 0x05, 0x06,
},
Voters: 6,
},
},
TotalVoters: 12,
RecentVoters: []int32{
5,
6,
},
Solution: "alala",
SolutionEntities: []MessageEntity{},
},
},
{
name: "predicting-[]int64",
data: Hexed("15c4b51c00000000"),
expected: []int64{},
hintsForDecoder: []reflect.Type{reflect.TypeOf([]int64{})},
},
// TODO: отработать возможные ошибки
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.wantErr = noErrAsDefault(tt.wantErr)

res, err := tl.DecodeUnknownObject(tt.data, tt.hintsForDecoder...)
if !tt.wantErr(t, err) {
pp.Println(dry.BytesEncodeHex(string(tt.data)))
return
}

assert.Equal(t, expected, v)
if err == nil {
assert.Equal(t, tt.expected, tl.UnwrapNativeTypes(res))
}
})
}
}
Expand Down
5 changes: 2 additions & 3 deletions encoding/tl/encoder_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ import (
"testing"

"github.com/xelaj/mtproto/encoding/tl"
"github.com/xelaj/mtproto/telegram"
)

func BenchmarkEncoder(b *testing.B) {
for i := 0; i < b.N; i++ {
tl.Marshal(&telegram.AccountInstallThemeParams{
tl.Marshal(&AccountInstallThemeParams{
Dark: true,
Format: "abc",
Theme: &telegram.InputThemeObj{
Theme: &InputThemeObj{
ID: 123,
AccessHash: 321,
},
Expand Down
10 changes: 8 additions & 2 deletions encoding/tl/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ type ErrRegisteredObjectNotFound struct {
Data []byte
}

func (e ErrRegisteredObjectNotFound) Error() string {
return fmt.Sprintf("object with provided crc not registered: 0x%x", e.Crc)
func (e *ErrRegisteredObjectNotFound) Error() string {
return fmt.Sprintf("object with provided crc not registered: 0x%08x", e.Crc)
}

type ErrMustParseSlicesExplicitly struct{}

func (e *ErrMustParseSlicesExplicitly) Error() string {
return "got vector CRC code when parsing unknown object: vectors can't be parsed as predicted objects"
}

type ErrorPartialWrite struct {
Expand Down
22 changes: 22 additions & 0 deletions encoding/tl/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,25 @@ type Marshaler interface {
type Unmarshaler interface {
UnmarshalTL(*Decoder) error
}

// InterfacedObject is specific struct for handling bool types, slice and null as object.
// See https://github.com/xelaj/mtproto/issues/51
type InterfacedObject struct {
value interface{}
}

func (*InterfacedObject) CRC() uint32 {
panic("makes no sense")
}

func (*InterfacedObject) UnmarshalTL(*Decoder) error {
panic("impossible to (un)marshal hidden object. Use explicit methods")
}

func (*InterfacedObject) MarshalTL(*Encoder) error {
panic("impossible to (un)marshal hidden object. Use explicit methods")
}

func (i *InterfacedObject) Unwrap() interface{} {
return i.value
}
4 changes: 2 additions & 2 deletions encoding/tl/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func registerEnum(o Object) {

func RegisterObjects(obs ...Object) {
for _, o := range obs {
if _, found := objectByCrc[o.CRC()]; found {
panic(fmt.Errorf("object with that crc already registered: %d", o.CRC()))
if val, found := objectByCrc[o.CRC()]; found {
panic(fmt.Errorf("object with that crc already registered as %v: 0x%08x", val.String(), o.CRC()))
}

registerObject(o)
Expand Down
8 changes: 8 additions & 0 deletions encoding/tl/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ func sliceToInterfaceSlice(in any) []any {

return res
}

func UnwrapNativeTypes(in Object) interface{} {
if v, ok := in.(*InterfacedObject); ok {
return v.Unwrap()
}

return in
}
Loading

0 comments on commit d682540

Please sign in to comment.