diff --git a/client.go b/client.go index 95cf03f..4794139 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,8 @@ package coap import ( + "errors" + "fmt" "net" "time" ) @@ -18,28 +20,55 @@ const ( // Conn is a CoAP client connection. type Conn struct { - conn *net.UDPConn - buf []byte + connTCP *net.TCPConn + buf []byte + conn *net.UDPConn + Net string +} + +type Addr struct { + Tcp *net.TCPAddr + Udp *net.UDPAddr } // Dial connects a CoAP client. func Dial(n, addr string) (*Conn, error) { - uaddr, err := net.ResolveUDPAddr(n, addr) - if err != nil { - return nil, err - } + switch n { + case "udp": + uaddr, err := net.ResolveUDPAddr(n, addr) + if err != nil { + return nil, err + } - s, err := net.DialUDP("udp", nil, uaddr) - if err != nil { - return nil, err - } + s, err := net.DialUDP("udp", nil, uaddr) + if err != nil { + return nil, err + } + + return &Conn{conn: s, buf: make([]byte, maxPktLen), connTCP: nil}, nil + case "tcp": + taddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } - return &Conn{s, make([]byte, maxPktLen)}, nil + s, err := net.DialTCP("tcp", nil, taddr) + if err != nil { + return nil, err + } + + return &Conn{conn: nil, buf: make([]byte, maxPktLen), connTCP: s}, nil + default: + return nil, errors.New("unrecognized network type") + } } // Send a message. Get a response if there is one. -func (c *Conn) Send(req Message) (*Message, error) { - err := Transmit(c.conn, nil, req) +func (c *Conn) Send(req Message) (Message, error) { + + //defer c.Close() + //not sure if that's a good idea to have it be default behavior. Maybe have it be based on a setting in Conn? + err := Transmit(c, Addr{}, req) if err != nil { return nil, err } @@ -47,20 +76,54 @@ func (c *Conn) Send(req Message) (*Message, error) { if !req.IsConfirmable() { return nil, nil } - - rv, err := Receive(c.conn, c.buf) + fmt.Println("about to receive in send()") + rv, err := Receive(c, c.buf) if err != nil { return nil, err } - return &rv, nil + return rv, nil } // Receive a message. -func (c *Conn) Receive() (*Message, error) { - rv, err := Receive(c.conn, c.buf) +func (c *Conn) Receive() (Message, error) { + rv, err := Receive(c, c.buf) if err != nil { return nil, err } - return &rv, nil + return rv, nil +} + +func (c *Conn) Network() (string, error) { + fmt.Println("conn.Network() called") + if c.Net != "" { + return c.Net, nil + } + if c.conn != nil && c.connTCP != nil { + fmt.Println("satisfied conditions for udp/tcp both being non-nil") + return "", errors.New("multiple non-nil connections in Conn. it should be only one") + } + if c.conn != nil { + return "udp", nil + } + if c.connTCP != nil { + return "tcp", nil + } else { + fmt.Println("both connections are nil") + return "", errors.New("all connections in Conn struct are nil") + } +} + +func (c *Conn) Close() error { + n, err := c.Network() + if err != nil { + return err + } + switch n { + case "udp": + return c.conn.Close() + case "tcp": + return c.connTCP.Close() + } + return err } diff --git a/example/client/goap_client.go b/example/client/goap_client.go index 36aca0c..3add324 100644 --- a/example/client/goap_client.go +++ b/example/client/goap_client.go @@ -4,17 +4,17 @@ import ( "log" "os" - "github.com/dustin/go-coap" + "github.com/runtimeco/go-coap" ) func main() { - req := coap.Message{ + req := coap.NewDgramMessage(coap.MessageParams{ Type: coap.Confirmable, Code: coap.GET, MessageID: 12345, Payload: []byte("hello, world!"), - } + }) path := "/some/path" if len(os.Args) > 1 { diff --git a/example/obsclient/obsclient.go b/example/obsclient/obsclient.go index 77507f9..d5cd7ac 100644 --- a/example/obsclient/obsclient.go +++ b/example/obsclient/obsclient.go @@ -3,16 +3,16 @@ package main import ( "log" - "github.com/dustin/go-coap" + "github.com/runtimeco/go-coap" ) func main() { - req := coap.Message{ + req := coap.NewDgramMessage(coap.MessageParams{ Type: coap.NonConfirmable, Code: coap.GET, MessageID: 12345, - } + }) req.AddOption(coap.Observe, 1) req.SetPathString("/some/path") diff --git a/example/obsserver/obsserver.go b/example/obsserver/obsserver.go index 42736f6..f158491 100644 --- a/example/obsserver/obsserver.go +++ b/example/obsserver/obsserver.go @@ -6,19 +6,19 @@ import ( "net" "time" - "github.com/dustin/go-coap" + "github.com/runtimeco/go-coap" ) -func periodicTransmitter(l *net.UDPConn, a *net.UDPAddr, m *coap.Message) { +func periodicTransmitter(l *net.UDPConn, a *net.UDPAddr, m coap.Message) { subded := time.Now() for { - msg := coap.Message{ + msg := coap.NewDgramMessage(coap.MessageParams{ Type: coap.Acknowledgement, Code: coap.Content, - MessageID: m.MessageID, + MessageID: m.MessageID(), Payload: []byte(fmt.Sprintf("Been running for %v", time.Since(subded))), - } + }) msg.SetOption(coap.ContentFormat, coap.TextPlain) msg.SetOption(coap.LocationPath, m.Path()) @@ -36,9 +36,9 @@ func periodicTransmitter(l *net.UDPConn, a *net.UDPAddr, m *coap.Message) { func main() { log.Fatal(coap.ListenAndServe("udp", ":5683", - coap.FuncHandler(func(l *net.UDPConn, a *net.UDPAddr, m *coap.Message) *coap.Message { + coap.FuncHandler(func(l *net.UDPConn, a *net.UDPAddr, m coap.Message) coap.Message { log.Printf("Got message path=%q: %#v from %v", m.Path(), m, a) - if m.Code == coap.GET && m.Option(coap.Observe) != nil { + if m.Code() == coap.GET && m.Option(coap.Observe) != nil { if value, ok := m.Option(coap.Observe).([]uint8); ok && len(value) >= 1 && value[0] == 1 { go periodicTransmitter(l, a, m) diff --git a/example/server/coap_server.go b/example/server/coap_server.go index fda20ac..fda9c1f 100644 --- a/example/server/coap_server.go +++ b/example/server/coap_server.go @@ -4,19 +4,19 @@ import ( "log" "net" - "github.com/dustin/go-coap" + "github.com/runtimeco/go-coap" ) -func handleA(l *net.UDPConn, a *net.UDPAddr, m *coap.Message) *coap.Message { +func handleA(l *net.UDPConn, a *net.UDPAddr, m coap.Message) coap.Message { log.Printf("Got message in handleA: path=%q: %#v from %v", m.Path(), m, a) if m.IsConfirmable() { - res := &coap.Message{ + res := coap.NewDgramMessage(coap.MessageParams{ Type: coap.Acknowledgement, Code: coap.Content, - MessageID: m.MessageID, - Token: m.Token, + MessageID: m.MessageID(), + Token: m.Token(), Payload: []byte("hello to you!"), - } + }) res.SetOption(coap.ContentFormat, coap.TextPlain) log.Printf("Transmitting from A %#v", res) @@ -25,16 +25,16 @@ func handleA(l *net.UDPConn, a *net.UDPAddr, m *coap.Message) *coap.Message { return nil } -func handleB(l *net.UDPConn, a *net.UDPAddr, m *coap.Message) *coap.Message { +func handleB(l *net.UDPConn, a *net.UDPAddr, m coap.Message) coap.Message { log.Printf("Got message in handleB: path=%q: %#v from %v", m.Path(), m, a) if m.IsConfirmable() { - res := &coap.Message{ + res := coap.NewDgramMessage(coap.MessageParams{ Type: coap.Acknowledgement, Code: coap.Content, - MessageID: m.MessageID, - Token: m.Token, + MessageID: m.MessageID(), + Token: m.Token(), Payload: []byte("good bye!"), - } + }) res.SetOption(coap.ContentFormat, coap.TextPlain) log.Printf("Transmitting from B %#v", res) diff --git a/message.go b/message.go index 630ad2f..3a4b7b9 100644 --- a/message.go +++ b/message.go @@ -1,12 +1,11 @@ package coap import ( - "bytes" "encoding/binary" "errors" "fmt" + "io" "reflect" - "sort" "strings" ) @@ -210,16 +209,32 @@ var optionDefs = [256]optionDef{ } // MediaType specifies the content type of a message. -type MediaType byte +type MediaType uint16 // Content types. const ( - TextPlain MediaType = 0 // text/plain;charset=utf-8 - AppLinkFormat MediaType = 40 // application/link-format - AppXML MediaType = 41 // application/xml - AppOctets MediaType = 42 // application/octet-stream - AppExi MediaType = 47 // application/exi - AppJSON MediaType = 50 // application/json + TextPlain MediaType = 0 // text/plain;charset=utf-8 + AppCoseEncrypt0 MediaType = 16 //application/cose; cose-type="cose-encrypt0" (RFC 8152) + AppCoseMac0 MediaType = 17 //application/cose; cose-type="cose-mac0" (RFC 8152) + AppCoseSign1 MediaType = 18 //application/cose; cose-type="cose-sign1" (RFC 8152) + AppLinkFormat MediaType = 40 // application/link-format + AppXML MediaType = 41 // application/xml + AppOctets MediaType = 42 // application/octet-stream + AppExi MediaType = 47 // application/exi + AppJSON MediaType = 50 // application/json + AppJsonPatch MediaType = 51 //application/json-patch+json (RFC6902) + AppJsonMergePatch MediaType = 52 //application/merge-patch+json (RFC7396) + AppCBOR MediaType = 60 //application/cbor (RFC 7049) + AppCWT MediaType = 61 //application/cwt + AppCoseEncrypt MediaType = 96 //application/cose; cose-type="cose-encrypt" (RFC 8152) + AppCoseMac MediaType = 97 //application/cose; cose-type="cose-mac" (RFC 8152) + AppCoseSign MediaType = 98 //application/cose; cose-type="cose-sign" (RFC 8152) + AppCoseKey MediaType = 101 //application/cose-key (RFC 8152) + AppCoseKeySet MediaType = 102 //application/cose-key-set (RFC 8152) + AppCoapGroup MediaType = 256 //coap-group+json (RFC 7390) + AppOcfCbor MediaType = 10000 //application/vnd.ocf+cbor + AppLwm2mTLV MediaType = 11542 //application/vnd.oma.lwm2m+tlv + AppLwm2mJSON MediaType = 11543 //application/vnd.oma.lwm2m+json ) type option struct { @@ -334,24 +349,82 @@ func (o options) Minus(oid OptionID) options { return rv } -// Message is a CoAP message. -type Message struct { +type Message interface { + Type() COAPType + Code() COAPCode + MessageID() uint16 + Token() []byte + Payload() []byte + AllOptions() options + + IsConfirmable() bool + Options(o OptionID) []interface{} + Option(o OptionID) interface{} + optionStrings(o OptionID) []string + Path() []string + PathString() string + SetPathString(s string) + SetPath(s []string) + SetURIQuery(s string) + SetObserve(b int) + SetPayload(p []byte) + RemoveOption(opID OptionID) + AddOption(opID OptionID, val interface{}) + SetOption(opID OptionID, val interface{}) + MarshalBinary() ([]byte, error) + UnmarshalBinary(data []byte) error +} + +type MessageParams struct { Type COAPType Code COAPCode MessageID uint16 + Token []byte + Payload []byte +} + +// MessageBase is a CoAP message. +type MessageBase struct { + typ COAPType + code COAPCode + messageID uint16 - Token, Payload []byte + token, payload []byte opts options } +func (m *MessageBase) Type() COAPType { + return m.typ +} + +func (m *MessageBase) Code() COAPCode { + return m.code +} + +func (m *MessageBase) MessageID() uint16 { + return m.messageID +} + +func (m *MessageBase) Token() []byte { + return m.token +} + +func (m *MessageBase) Payload() []byte { + return m.payload +} + +func (m *MessageBase) AllOptions() options { + return m.opts +} + // IsConfirmable returns true if this message is confirmable. -func (m Message) IsConfirmable() bool { - return m.Type == Confirmable +func (m *MessageBase) IsConfirmable() bool { + return m.typ == Confirmable } // Options gets all the values for the given option. -func (m Message) Options(o OptionID) []interface{} { +func (m *MessageBase) Options(o OptionID) []interface{} { var rv []interface{} for _, v := range m.opts { @@ -364,7 +437,7 @@ func (m Message) Options(o OptionID) []interface{} { } // Option gets the first value for the given option ID. -func (m Message) Option(o OptionID) interface{} { +func (m *MessageBase) Option(o OptionID) interface{} { for _, v := range m.opts { if o == v.ID { return v.Value @@ -373,7 +446,7 @@ func (m Message) Option(o OptionID) interface{} { return nil } -func (m Message) optionStrings(o OptionID) []string { +func (m *MessageBase) optionStrings(o OptionID) []string { var rv []string for _, o := range m.Options(o) { rv = append(rv, o.(string)) @@ -382,17 +455,17 @@ func (m Message) optionStrings(o OptionID) []string { } // Path gets the Path set on this message if any. -func (m Message) Path() []string { +func (m *MessageBase) Path() []string { return m.optionStrings(URIPath) } // PathString gets a path as a / separated string. -func (m Message) PathString() string { +func (m *MessageBase) PathString() string { return strings.Join(m.Path(), "/") } // SetPathString sets a path by a / separated string. -func (m *Message) SetPathString(s string) { +func (m *MessageBase) SetPathString(s string) { for s[0] == '/' { s = s[1:] } @@ -400,17 +473,32 @@ func (m *Message) SetPathString(s string) { } // SetPath updates or adds a URIPath attribute on this message. -func (m *Message) SetPath(s []string) { +func (m *MessageBase) SetPath(s []string) { m.SetOption(URIPath, s) } +// Set URIQuery attibute to the message +func (m *MessageBase) SetURIQuery(s string) { + m.AddOption(URIQuery, s) +} + +// Set Observer attribute to the message +func (m *MessageBase) SetObserve(b int) { + m.AddOption(Observe, b) +} + +// SetPayload +func (m *MessageBase) SetPayload(p []byte) { + m.payload = p +} + // RemoveOption removes all references to an option -func (m *Message) RemoveOption(opID OptionID) { +func (m *MessageBase) RemoveOption(opID OptionID) { m.opts = m.opts.Minus(opID) } // AddOption adds an option. -func (m *Message) AddOption(opID OptionID, val interface{}) { +func (m *MessageBase) AddOption(opID OptionID, val interface{}) { iv := reflect.ValueOf(val) if (iv.Kind() == reflect.Slice || iv.Kind() == reflect.Array) && iv.Type().Elem().Kind() == reflect.String { @@ -423,7 +511,7 @@ func (m *Message) AddOption(opID OptionID, val interface{}) { } // SetOption sets an option, discarding any previous value -func (m *Message) SetOption(opID OptionID, val interface{}) { +func (m *MessageBase) SetOption(opID OptionID, val interface{}) { m.RemoveOption(opID) m.AddOption(opID, val) } @@ -436,33 +524,7 @@ const ( extoptError = 15 ) -// MarshalBinary produces the binary form of this Message. -func (m *Message) MarshalBinary() ([]byte, error) { - tmpbuf := []byte{0, 0} - binary.BigEndian.PutUint16(tmpbuf, m.MessageID) - - /* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |Ver| T | TKL | Code | Message ID | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Token (if any, TKL bytes) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Options (if any) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |1 1 1 1 1 1 1 1| Payload (if any) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - - buf := bytes.Buffer{} - buf.Write([]byte{ - (1 << 6) | (uint8(m.Type) << 4) | uint8(0xf&len(m.Token)), - byte(m.Code), - tmpbuf[0], tmpbuf[1], - }) - buf.Write(m.Token) - +func writeOpt(o option, buf io.Writer, delta int) { /* 0 1 2 3 4 5 6 7 +---------------+---------------+ @@ -509,13 +571,13 @@ func (m *Message) MarshalBinary() ([]byte, error) { d, dx := extendOpt(delta) l, lx := extendOpt(length) - buf.WriteByte(byte(d<<4) | byte(l)) + buf.Write([]byte{byte(d<<4) | byte(l)}) tmp := []byte{0, 0} writeExt := func(opt, ext int) { switch opt { case extoptByteCode: - buf.WriteByte(byte(ext)) + buf.Write([]byte{byte(ext)}) case extoptWordCode: binary.BigEndian.PutUint16(tmp, uint16(ext)) buf.Write(tmp) @@ -526,116 +588,83 @@ func (m *Message) MarshalBinary() ([]byte, error) { writeExt(l, lx) } - sort.Stable(&m.opts) + b := o.toBytes() + writeOptHeader(delta, len(b)) + buf.Write(b) +} +func writeOpts(buf io.Writer, opts options) { prev := 0 - - for _, o := range m.opts { - b := o.toBytes() - writeOptHeader(int(o.ID)-prev, len(b)) - buf.Write(b) + for _, o := range opts { + writeOpt(o, buf, int(o.ID)-prev) prev = int(o.ID) } - - if len(m.Payload) > 0 { - buf.Write([]byte{0xff}) - } - - buf.Write(m.Payload) - - return buf.Bytes(), nil } -// ParseMessage extracts the Message from the given input. -func ParseMessage(data []byte) (Message, error) { - rv := Message{} - return rv, rv.UnmarshalBinary(data) -} - -// UnmarshalBinary parses the given binary slice as a Message. -func (m *Message) UnmarshalBinary(data []byte) error { - if len(data) < 4 { - return errors.New("short packet") - } - - if data[0]>>6 != 1 { - return errors.New("invalid version") - } - - m.Type = COAPType((data[0] >> 4) & 0x3) - tokenLen := int(data[0] & 0xf) - if tokenLen > 8 { - return ErrInvalidTokenLen - } - - m.Code = COAPCode(data[1]) - m.MessageID = binary.BigEndian.Uint16(data[2:4]) - - if tokenLen > 0 { - m.Token = make([]byte, tokenLen) - } - if len(data) < 4+tokenLen { - return errors.New("truncated") - } - copy(m.Token, data[4:4+tokenLen]) - b := data[4+tokenLen:] +// parseBody extracts the options and payload from a byte slice. The supplied +// byte slice contains everything following the message header (everything +// after the token). +func parseBody(data []byte) (options, []byte, error) { prev := 0 parseExtOpt := func(opt int) (int, error) { switch opt { case extoptByteCode: - if len(b) < 1 { + if len(data) < 1 { return -1, errors.New("truncated") } - opt = int(b[0]) + extoptByteAddend - b = b[1:] + opt = int(data[0]) + extoptByteAddend + data = data[1:] case extoptWordCode: - if len(b) < 2 { + if len(data) < 2 { return -1, errors.New("truncated") } - opt = int(binary.BigEndian.Uint16(b[:2])) + extoptWordAddend - b = b[2:] + opt = int(binary.BigEndian.Uint16(data[:2])) + extoptWordAddend + data = data[2:] } return opt, nil } - for len(b) > 0 { - if b[0] == 0xff { - b = b[1:] + var opts options + + for len(data) > 0 { + if data[0] == 0xff { + data = data[1:] break } - delta := int(b[0] >> 4) - length := int(b[0] & 0x0f) + delta := int(data[0] >> 4) + length := int(data[0] & 0x0f) if delta == extoptError || length == extoptError { - return errors.New("unexpected extended option marker") + return nil, nil, errors.New("unexpected extended option marker") } - b = b[1:] + data = data[1:] delta, err := parseExtOpt(delta) if err != nil { - return err + return nil, nil, err } length, err = parseExtOpt(length) if err != nil { - return err + return nil, nil, err } - if len(b) < length { - return errors.New("truncated") + if len(data) < length { + return nil, nil, errors.New("truncated") } oid := OptionID(prev + delta) - opval := parseOptionValue(oid, b[:length]) - b = b[length:] + opval := parseOptionValue(oid, data[:length]) + data = data[length:] prev = int(oid) if opval != nil { - m.opts = append(m.opts, option{ID: oid, Value: opval}) + opt := option{ID: oid, Value: opval} + opts = append(opts, opt) } } - m.Payload = b - return nil + + return opts, data, nil } diff --git a/message_test.go b/message_test.go index 7cf7651..5e440be 100644 --- a/message_test.go +++ b/message_test.go @@ -9,47 +9,47 @@ import ( ) var ( - _ = encoding.BinaryMarshaler(&Message{}) - _ = encoding.BinaryUnmarshaler(&Message{}) + _ = encoding.BinaryMarshaler(&DgramMessage{}) + _ = encoding.BinaryUnmarshaler(&DgramMessage{}) ) // assertEqualMessages compares the e(xptected) message to the a(ctual) message // and reports any diffs with t.Errorf. func assertEqualMessages(t *testing.T, e, a Message) { - if e.Type != a.Type { - t.Errorf("Expected type %v, got %v", e.Type, a.Type) + if e.Type() != a.Type() { + t.Errorf("Expected type %v, got %v", e.Type(), a.Type()) } - if e.Code != a.Code { - t.Errorf("Expected code %v, got %v", e.Code, a.Code) + if e.Code() != a.Code() { + t.Errorf("Expected code %v, got %v", e.Code(), a.Code()) } - if e.MessageID != a.MessageID { - t.Errorf("Expected MessageID %v, got %v", e.MessageID, a.MessageID) + if e.MessageID() != a.MessageID() { + t.Errorf("Expected MessageID %v, got %v", e.MessageID(), a.MessageID()) } - if !bytes.Equal(e.Token, a.Token) { - t.Errorf("Expected token %#v, got %#v", e.Token, a.Token) + if !bytes.Equal(e.Token(), a.Token()) { + t.Errorf("Expected token %#v, got %#v", e.Token(), a.Token()) } - if !bytes.Equal(e.Payload, a.Payload) { - t.Errorf("Expected payload %#v, got %#v", e.Payload, a.Payload) + if !bytes.Equal(e.Payload(), a.Payload()) { + t.Errorf("Expected payload %#v, got %#v", e.Payload(), a.Payload()) } - if len(e.opts) != len(a.opts) { - t.Errorf("Expected %v options, got %v", len(e.opts), len(a.opts)) + if len(e.AllOptions()) != len(a.AllOptions()) { + t.Errorf("Expected %v options, got %v", len(e.AllOptions()), len(a.AllOptions())) } else { - for i, _ := range e.opts { - if e.opts[i].ID != a.opts[i].ID { - t.Errorf("Expected option ID %v, got %v", e.opts[i].ID, a.opts[i].ID) + for i, _ := range e.AllOptions() { + if e.AllOptions()[i].ID != a.AllOptions()[i].ID { + t.Errorf("Expected option ID %v, got %v", e.AllOptions()[i].ID, a.AllOptions()[i].ID) continue } - switch e.opts[i].Value.(type) { + switch e.AllOptions()[i].Value.(type) { case []byte: - expected := e.opts[i].Value.([]byte) - actual := a.opts[i].Value.([]byte) + expected := e.AllOptions()[i].Value.([]byte) + actual := a.AllOptions()[i].Value.([]byte) if !bytes.Equal(expected, actual) { - t.Errorf("Expected Option ID %v value %v, got %v", e.opts[i].ID, expected, actual) + t.Errorf("Expected Option ID %v value %v, got %v", e.AllOptions()[i].ID, expected, actual) } default: - if e.opts[i].Value != a.opts[i].Value { - t.Errorf("Expected Option ID %v value %v, got %v", e.opts[i].ID, e.opts[i].Value, a.opts[i].Value) + if e.AllOptions()[i].Value != a.AllOptions()[i].Value { + t.Errorf("Expected Option ID %v value %v, got %v", e.AllOptions()[i].ID, e.AllOptions()[i].Value, a.AllOptions()[i].Value) } } } @@ -98,8 +98,8 @@ func TestMessageConfirmable(t *testing.T) { m Message exp bool }{ - {Message{Type: Confirmable}, true}, - {Message{Type: NonConfirmable}, false}, + {&DgramMessage{MessageBase{typ: Confirmable}}, true}, + {&DgramMessage{MessageBase{typ: NonConfirmable}}, false}, } for _, test := range tests { @@ -111,7 +111,7 @@ func TestMessageConfirmable(t *testing.T) { } func TestMissingOption(t *testing.T) { - got := Message{}.Option(MaxAge) + got := (&DgramMessage{}).Option(MaxAge) if got != nil { t.Errorf("Expected nil, got %v", got) } @@ -162,10 +162,12 @@ func TestCodeString(t *testing.T) { } func TestEncodeMessageWithoutOptionsAndPayload(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } data, err := req.MarshalBinary() @@ -181,10 +183,12 @@ func TestEncodeMessageWithoutOptionsAndPayload(t *testing.T) { } func TestEncodeMessageSmall(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } req.AddOption(ETag, []byte("weetag")) @@ -206,11 +210,13 @@ func TestEncodeMessageSmall(t *testing.T) { } func TestEncodeMessageSmallWithPayload(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, - Payload: []byte("hi"), + req := DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + payload: []byte("hi"), + }, } req.AddOption(ETag, []byte("weetag")) @@ -246,7 +252,7 @@ func TestInvalidMessageParsing(t *testing.T) { } for _, data := range invalidPackets { - msg, err := ParseMessage(data) + msg, err := ParseDgramMessage(data) if err == nil { t.Errorf("Unexpected success parsing short message (%#v): %v", data, msg) } @@ -254,13 +260,15 @@ func TestInvalidMessageParsing(t *testing.T) { } func TestOptionsWithIllegalLengthAreIgnoredDuringParsing(t *testing.T) { - exp := Message{ - Type: Confirmable, - Code: GET, - MessageID: 0xabcd, - Payload: []byte{}, - } - msg, err := ParseMessage([]byte{0x40, 0x01, 0xab, 0xcd, + exp := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 0xabcd, + payload: []byte{}, + }, + } + msg, err := ParseDgramMessage([]byte{0x40, 0x01, 0xab, 0xcd, 0x73, // URI-Port option (uint) with length 3 (valid lengths are 0-2) 0x11, 0x22, 0x33, 0xff}) if err != nil { @@ -270,7 +278,7 @@ func TestOptionsWithIllegalLengthAreIgnoredDuringParsing(t *testing.T) { t.Errorf("Expected\n%#v\ngot\n%#v", exp, msg) } - msg, err = ParseMessage([]byte{0x40, 0x01, 0xab, 0xcd, + msg, err = ParseDgramMessage([]byte{0x40, 0x01, 0xab, 0xcd, 0xd5, 0x01, // Max-Age option (uint) with length 5 (valid lengths are 0-4) 0x11, 0x22, 0x33, 0x44, 0x55, 0xff}) if err != nil { @@ -283,25 +291,25 @@ func TestOptionsWithIllegalLengthAreIgnoredDuringParsing(t *testing.T) { func TestDecodeMessageWithoutOptionsAndPayload(t *testing.T) { input := []byte{0x40, 0x1, 0x30, 0x39} - msg, err := ParseMessage(input) + msg, err := ParseDgramMessage(input) if err != nil { t.Fatalf("Error parsing message: %v", err) } - if msg.Type != Confirmable { - t.Errorf("Expected message type confirmable, got %v", msg.Type) + if msg.Type() != Confirmable { + t.Errorf("Expected message type confirmable, got %v", msg.Type()) } - if msg.Code != GET { - t.Errorf("Expected message code GET, got %v", msg.Code) + if msg.Code() != GET { + t.Errorf("Expected message code GET, got %v", msg.Code()) } - if msg.MessageID != 12345 { - t.Errorf("Expected message ID 12345, got %v", msg.MessageID) + if msg.MessageID() != 12345 { + t.Errorf("Expected message ID 12345, got %v", msg.MessageID()) } - if len(msg.Token) != 0 { - t.Errorf("Incorrect token: %q", msg.Token) + if len(msg.Token()) != 0 { + t.Errorf("Incorrect token: %q", msg.Token()) } - if len(msg.Payload) != 0 { - t.Errorf("Incorrect payload: %q", msg.Payload) + if len(msg.Payload()) != 0 { + t.Errorf("Incorrect payload: %q", msg.Payload()) } } @@ -312,31 +320,33 @@ func TestDecodeMessageSmallWithPayload(t *testing.T) { 0xff, 'h', 'i', } - msg, err := ParseMessage(input) + msg, err := ParseDgramMessage(input) if err != nil { t.Fatalf("Error parsing message: %v", err) } - if msg.Type != Confirmable { - t.Errorf("Expected message type confirmable, got %v", msg.Type) + if msg.Type() != Confirmable { + t.Errorf("Expected message type confirmable, got %v", msg.Type()) } - if msg.Code != GET { - t.Errorf("Expected message code GET, got %v", msg.Code) + if msg.Code() != GET { + t.Errorf("Expected message code GET, got %v", msg.Code()) } - if msg.MessageID != 12345 { - t.Errorf("Expected message ID 12345, got %v", msg.MessageID) + if msg.MessageID() != 12345 { + t.Errorf("Expected message ID 12345, got %v", msg.MessageID()) } - if !bytes.Equal(msg.Payload, []byte("hi")) { - t.Errorf("Incorrect payload: %q", msg.Payload) + if !bytes.Equal(msg.Payload(), []byte("hi")) { + t.Errorf("Incorrect payload: %q", msg.Payload()) } } func TestEncodeMessageVerySmall(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } req.SetPathString("x") @@ -356,10 +366,12 @@ func TestEncodeMessageVerySmall(t *testing.T) { // Same as above, but with a leading slash func TestEncodeMessageVerySmall2(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } req.SetPathString("/x") @@ -385,7 +397,13 @@ func TestEncodeSeveral(t *testing.T) { "f", "h", "g", "i", "j"}, } for p, a := range tests { - m := &Message{Type: Confirmable, Code: GET, MessageID: 12345} + m := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, + } m.SetPathString(p) b, err := m.MarshalBinary() if err != nil { @@ -393,7 +411,7 @@ func TestEncodeSeveral(t *testing.T) { t.Fail() continue } - m2, err := ParseMessage(b) + m2, err := ParseDgramMessage(b) if err != nil { t.Fatalf("Can't parse my own message at %#v: %v", p, err) } @@ -406,7 +424,13 @@ func TestEncodeSeveral(t *testing.T) { } func TestPathAsOption(t *testing.T) { - m := &Message{Type: Confirmable, Code: GET, MessageID: 12345} + m := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, + } m.SetOption(LocationPath, []string{"a", "b"}) got, err := m.MarshalBinary() if err != nil { @@ -419,10 +443,12 @@ func TestPathAsOption(t *testing.T) { } func TestEncodePath14(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } req.SetPathString("123456789ABCDE") @@ -443,10 +469,12 @@ func TestEncodePath14(t *testing.T) { } func TestEncodePath15(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } req.SetPathString("123456789ABCDEF") @@ -467,10 +495,12 @@ func TestEncodePath15(t *testing.T) { } func TestEncodeLargePath(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, + req := DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + }, } req.SetPathString("this_path_is_longer_than_fifteen_bytes") @@ -507,18 +537,20 @@ func TestDecodeLargePath(t *testing.T) { 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, } - req, err := ParseMessage(data) + req, err := ParseDgramMessage(data) if err != nil { t.Fatalf("Error parsing request: %v", err) } path := "this_path_is_longer_than_fifteen_bytes" - exp := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, - Payload: []byte{}, + exp := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + payload: []byte{}, + }, } exp.SetOption(URIPath, path) @@ -535,16 +567,18 @@ func TestDecodeMessageSmaller(t *testing.T) { 0x65, 0x65, 0x74, 0x61, 0x67, 0xa1, 0x3, } - req, err := ParseMessage(data) + req, err := ParseDgramMessage(data) if err != nil { t.Fatalf("Error parsing request: %v", err) } - exp := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, - Payload: []byte{}, + exp := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + payload: []byte{}, + }, } exp.SetOption(ETag, []byte("weetag")) @@ -621,30 +655,30 @@ func TestExample1(t *testing.T) { input := append([]byte{0x40, 1, 0x7d, 0x34, (11 << 4) | 11}, []byte("temperature")...) - msg, err := ParseMessage(input) + msg, err := ParseDgramMessage(input) if err != nil { t.Fatalf("Error parsing message: %v", err) } - if msg.Type != Confirmable { - t.Errorf("Expected message type confirmable, got %v", msg.Type) + if msg.Type() != Confirmable { + t.Errorf("Expected message type confirmable, got %v", msg.Type()) } - if msg.Code != GET { - t.Errorf("Expected message code GET, got %v", msg.Code) + if msg.Code() != GET { + t.Errorf("Expected message code GET, got %v", msg.Code()) } - if msg.MessageID != 0x7d34 { - t.Errorf("Expected message ID 0x7d34, got 0x%x", msg.MessageID) + if msg.MessageID() != 0x7d34 { + t.Errorf("Expected message ID 0x7d34, got 0x%x", msg.MessageID()) } if msg.Option(URIPath).(string) != "temperature" { t.Errorf("Incorrect uri path: %q", msg.Option(URIPath)) } - if len(msg.Token) > 0 { - t.Errorf("Incorrect token: %x", msg.Token) + if len(msg.Token()) > 0 { + t.Errorf("Incorrect token: %x", msg.Token()) } - if len(msg.Payload) > 0 { - t.Errorf("Incorrect payload: %q", msg.Payload) + if len(msg.Payload()) > 0 { + t.Errorf("Incorrect payload: %q", msg.Payload()) } } @@ -661,26 +695,26 @@ func TestExample1Res(t *testing.T) { input := append([]byte{0x60, 69, 0x7d, 0x34, 0xff}, []byte("22.3 C")...) - msg, err := ParseMessage(input) + msg, err := ParseDgramMessage(input) if err != nil { t.Fatalf("Error parsing message: %v", err) } - if msg.Type != Acknowledgement { - t.Errorf("Expected message type confirmable, got %v", msg.Type) + if msg.Type() != Acknowledgement { + t.Errorf("Expected message type confirmable, got %v", msg.Type()) } - if msg.Code != Content { - t.Errorf("Expected message code Content, got %v", msg.Code) + if msg.Code() != Content { + t.Errorf("Expected message code Content, got %v", msg.Code()) } - if msg.MessageID != 0x7d34 { - t.Errorf("Expected message ID 0x7d34, got 0x%x", msg.MessageID) + if msg.MessageID() != 0x7d34 { + t.Errorf("Expected message ID 0x7d34, got 0x%x", msg.MessageID()) } - if len(msg.Token) > 0 { - t.Errorf("Incorrect token: %x", msg.Token) + if len(msg.Token()) > 0 { + t.Errorf("Incorrect token: %x", msg.Token()) } - if !bytes.Equal(msg.Payload, []byte("22.3 C")) { - t.Errorf("Incorrect payload: %q", msg.Payload) + if !bytes.Equal(msg.Payload(), []byte("22.3 C")) { + t.Errorf("Incorrect payload: %q", msg.Payload()) } } @@ -691,17 +725,17 @@ func TestIssue15(t *testing.T) { 0x72, 0x6b, 0x2f, 0x63, 0x63, 0x33, 0x30, 0x30, 0x30, 0x2d, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2d, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0xff, 0x31, 0x2e, 0x32, 0x38} - msg, err := ParseMessage(input) + msg, err := ParseDgramMessage(input) if err != nil { t.Fatalf("Error parsing message: %v", err) } - if !bytes.Equal(msg.Token, []byte{1, 2, 3}) { - t.Errorf("Expected token = [1, 2, 3], got %v", msg.Token) + if !bytes.Equal(msg.Token(), []byte{1, 2, 3}) { + t.Errorf("Expected token = [1, 2, 3], got %v", msg.Token()) } - if !bytes.Equal(msg.Payload, []byte{0x31, 0x2e, 0x32, 0x38}) { - t.Errorf("Expected payload = {0x31, 0x2e, 0x32, 0x38}, got %v", msg.Payload) + if !bytes.Equal(msg.Payload(), []byte{0x31, 0x2e, 0x32, 0x38}) { + t.Errorf("Expected payload = {0x31, 0x2e, 0x32, 0x38}, got %v", msg.Payload()) } pathExp := "E/spark/cc3000-patch-version" @@ -714,7 +748,7 @@ func TestErrorOptionMarker(t *testing.T) { input := []byte{0x53, 0x2, 0x7a, 0x23, 0x1, 0x2, 0x3, 0xbf, 0x01, 0x02, 0x03, 0x04, 0x05, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xe, 0xf, 0x10} - msg, err := ParseMessage(input) + msg, err := ParseDgramMessage(input) if err == nil { t.Errorf("Unexpected success parsing malformed option: %v", msg) } @@ -725,7 +759,7 @@ func TestDecodeContentFormatOptionToMediaType(t *testing.T) { 0x40, 0x1, 0x30, 0x39, 0xc1, 0x32, 0x51, 0x29, } - parsedMsg, err := ParseMessage(data) + parsedMsg, err := ParseDgramMessage(data) if err != nil { t.Fatalf("Error parsing request: %v", err) } @@ -742,12 +776,14 @@ func TestDecodeContentFormatOptionToMediaType(t *testing.T) { } func TestEncodeMessageWithAllOptions(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: GET, - MessageID: 12345, - Token: []byte("TOKEN"), - Payload: []byte("PAYLOAD"), + req := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: GET, + messageID: 12345, + token: []byte("TOKEN"), + payload: []byte("PAYLOAD"), + }, } req.AddOption(IfMatch, []byte("IFMATCH")) @@ -772,7 +808,7 @@ func TestEncodeMessageWithAllOptions(t *testing.T) { t.Fatalf("Error encoding request: %v", err) } - parsedMsg, err := ParseMessage(data) + parsedMsg, err := ParseDgramMessage(data) if err != nil { t.Fatalf("Error parsing binary packet: %v", err) } diff --git a/messagedgram.go b/messagedgram.go new file mode 100644 index 0000000..11cebdb --- /dev/null +++ b/messagedgram.go @@ -0,0 +1,109 @@ +package coap + +import ( + "bytes" + "encoding/binary" + "errors" + "sort" +) + +// DgramMessage implements Message interface. +type DgramMessage struct { + MessageBase +} + +func NewDgramMessage(p MessageParams) *DgramMessage { + return &DgramMessage{ + MessageBase{ + typ: p.Type, + code: p.Code, + messageID: p.MessageID, + token: p.Token, + payload: p.Payload, + }, + } +} + +// MarshalBinary produces the binary form of this DgramMessage. +func (m *DgramMessage) MarshalBinary() ([]byte, error) { + tmpbuf := []byte{0, 0} + binary.BigEndian.PutUint16(tmpbuf, m.MessageID()) + + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |Ver| T | TKL | Code | Message ID | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Token (if any, TKL bytes) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Options (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |1 1 1 1 1 1 1 1| Payload (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + buf := bytes.Buffer{} + buf.Write([]byte{ + (1 << 6) | (uint8(m.Type()) << 4) | uint8(0xf&len(m.MessageBase.token)), + byte(m.MessageBase.code), + tmpbuf[0], tmpbuf[1], + }) + buf.Write(m.MessageBase.token) + + sort.Stable(&m.MessageBase.opts) + writeOpts(&buf, m.MessageBase.opts) + + if len(m.MessageBase.payload) > 0 { + buf.Write([]byte{0xff}) + } + + buf.Write(m.MessageBase.payload) + + return buf.Bytes(), nil +} + +// UnmarshalBinary parses the given binary slice as a DgramMessage. +func (m *DgramMessage) UnmarshalBinary(data []byte) error { + if len(data) < 4 { + return errors.New("short packet") + } + + if data[0]>>6 != 1 { + return errors.New("invalid version") + } + + m.MessageBase.typ = COAPType((data[0] >> 4) & 0x3) + tokenLen := int(data[0] & 0xf) + if tokenLen > 8 { + return ErrInvalidTokenLen + } + + m.MessageBase.code = COAPCode(data[1]) + m.MessageBase.messageID = binary.BigEndian.Uint16(data[2:4]) + + if tokenLen > 0 { + m.MessageBase.token = make([]byte, tokenLen) + } + if len(data) < 4+tokenLen { + return errors.New("truncated") + } + copy(m.MessageBase.token, data[4:4+tokenLen]) + b := data[4+tokenLen:] + + o, p, err := parseBody(b) + if err != nil { + return err + } + + m.MessageBase.payload = p + m.MessageBase.opts = o + + return nil +} + +// ParseDgramMessage extracts the Message from the given input. +func ParseDgramMessage(data []byte) (*DgramMessage, error) { + rv := &DgramMessage{} + return rv, rv.UnmarshalBinary(data) +} diff --git a/messagetcp.go b/messagetcp.go index 09354b9..ec2658b 100644 --- a/messagetcp.go +++ b/messagetcp.go @@ -1,69 +1,286 @@ package coap import ( + "bytes" "encoding/binary" - "errors" + "fmt" "io" + "sort" ) -// TcpMessage is a CoAP Message that can encode itself for TCP +const ( + TCP_MESSAGE_LEN13_BASE = 13 + TCP_MESSAGE_LEN14_BASE = 269 + TCP_MESSAGE_LEN15_BASE = 65805 + TCP_MESSAGE_MAX_LEN = 0x7fff0000 // Large number that works in 32-bit builds. +) + +// TcpMessage is a CoAP MessageBase that can encode itself for TCP // transport. type TcpMessage struct { - Message + MessageBase } -func (m *TcpMessage) MarshalBinary() ([]byte, error) { - bin, err := m.Message.MarshalBinary() - if err != nil { - return nil, err +func NewTcpMessage(p MessageParams) *TcpMessage { + return &TcpMessage{ + MessageBase{ + typ: p.Type, + code: p.Code, + messageID: p.MessageID, + token: p.Token, + payload: p.Payload, + }, } +} +func (m *TcpMessage) MarshalBinary() ([]byte, error) { /* - A CoAP TCP message looks like: - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Message Length |Ver| T | TKL | Code | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Token (if any, TKL bytes) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Options (if any) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |1 1 1 1 1 1 1 1| Payload (if any) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + A CoAP TCP message looks like: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Len | TKL | Extended Length ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Code | TKL bytes ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Options (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |1 1 1 1 1 1 1 1| Payload (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + The size of the Extended Length field is inferred from the value of the + Len field as follows: + + | Len value | Extended Length size | Total length | + +------------+-----------------------+---------------------------+ + | 0-12 | 0 | Len | + | 13 | 1 | Extended Length + 13 | + | 14 | 2 | Extended Length + 269 | + | 15 | 4 | Extended Length + 65805 | */ - l := []byte{0, 0} - binary.BigEndian.PutUint16(l, uint16(len(bin))) + buf := bytes.Buffer{} + + sort.Stable(&m.MessageBase.opts) + writeOpts(&buf, m.MessageBase.opts) + + if len(m.MessageBase.payload) > 0 { + buf.Write([]byte{0xff}) + buf.Write(m.MessageBase.payload) + } + + var lenNib uint8 + var extLenBytes []byte + + if buf.Len() < TCP_MESSAGE_LEN13_BASE { + lenNib = uint8(buf.Len()) + } else if buf.Len() < TCP_MESSAGE_LEN14_BASE { + lenNib = 13 + extLen := buf.Len() - TCP_MESSAGE_LEN13_BASE + extLenBytes = []byte{uint8(extLen)} + } else if buf.Len() < TCP_MESSAGE_LEN15_BASE { + lenNib = 14 + extLen := buf.Len() - TCP_MESSAGE_LEN14_BASE + extLenBytes = make([]byte, 2) + binary.BigEndian.PutUint16(extLenBytes, uint16(extLen)) + } else if buf.Len() < TCP_MESSAGE_MAX_LEN { + lenNib = 15 + extLen := buf.Len() - TCP_MESSAGE_LEN15_BASE + extLenBytes = make([]byte, 4) + binary.BigEndian.PutUint32(extLenBytes, uint32(extLen)) + } + + hdr := make([]byte, 1+len(extLenBytes)+len(m.MessageBase.token)+1) + hdrOff := 0 + + // Length and TKL nibbles. + hdr[hdrOff] = uint8(0xf&len(m.MessageBase.token)) | (lenNib << 4) + hdrOff++ + + // Extended length, if present. + if len(extLenBytes) > 0 { + copy(hdr[hdrOff:hdrOff+len(extLenBytes)], extLenBytes) + hdrOff += len(extLenBytes) + } + + // Code. + hdr[hdrOff] = byte(m.MessageBase.code) + hdrOff++ + + // Token. + if len(m.MessageBase.token) > 0 { + copy(hdr[hdrOff:hdrOff+len(m.MessageBase.token)], m.MessageBase.token) + hdrOff += len(m.MessageBase.token) + } + + return append(hdr, buf.Bytes()...), nil +} + +// msgTcpInfo describes a single TCP CoAP message. Used during reassembly. +type msgTcpInfo struct { + typ uint8 + token []byte + code uint8 + hdrLen int + totLen int +} + +// readTcpMsgInfo infers information about a TCP CoAP message from the first +// fragment. +func readTcpMsgInfo(r io.Reader) (msgTcpInfo, error) { + mti := msgTcpInfo{} + + hdrOff := 0 + + var firstByte byte + if err := binary.Read(r, binary.BigEndian, &firstByte); err != nil { + return mti, err + } + hdrOff++ + + lenNib := (firstByte & 0xf0) >> 4 + tkl := firstByte & 0x0f - return append(l, bin...), nil + var opLen int + if lenNib < TCP_MESSAGE_LEN13_BASE { + opLen = int(lenNib) + } else if lenNib == 13 { + var extLen byte + if err := binary.Read(r, binary.BigEndian, &extLen); err != nil { + return mti, err + } + hdrOff++ + opLen = TCP_MESSAGE_LEN13_BASE + int(extLen) + } else if lenNib == 14 { + var extLen uint16 + if err := binary.Read(r, binary.BigEndian, &extLen); err != nil { + return mti, err + } + hdrOff += 2 + opLen = TCP_MESSAGE_LEN14_BASE + int(extLen) + } else if lenNib == 15 { + var extLen uint32 + if err := binary.Read(r, binary.BigEndian, &extLen); err != nil { + return mti, err + } + hdrOff += 4 + opLen = TCP_MESSAGE_LEN15_BASE + int(extLen) + } + + mti.totLen = hdrOff + 1 + int(tkl) + opLen + + if err := binary.Read(r, binary.BigEndian, &mti.code); err != nil { + return mti, err + } + hdrOff++ + + mti.token = make([]byte, tkl) + if _, err := io.ReadFull(r, mti.token); err != nil { + return mti, err + } + hdrOff += int(tkl) + + mti.hdrLen = hdrOff + + return mti, nil +} + +func readTcpMsgBody(mti msgTcpInfo, r io.Reader) (options, []byte, error) { + bodyLen := mti.totLen - mti.hdrLen + b := make([]byte, bodyLen) + if _, err := io.ReadFull(r, b); err != nil { + return nil, nil, err + } + + o, p, err := parseBody(b) + if err != nil { + return nil, nil, err + } + + return o, p, nil +} + +func (m *TcpMessage) fill(mti msgTcpInfo, o options, p []byte) { + m.MessageBase.typ = COAPType(mti.typ) + m.MessageBase.code = COAPCode(mti.code) + m.MessageBase.token = mti.token + m.MessageBase.opts = o + m.MessageBase.payload = p } func (m *TcpMessage) UnmarshalBinary(data []byte) error { - if len(data) < 4 { - return errors.New("short packet") + r := bytes.NewReader(data) + + mti, err := readTcpMsgInfo(r) + if err != nil { + return fmt.Errorf("Error reading TCP CoAP header; %s", err.Error()) + } + + if len(data) != mti.totLen { + return fmt.Errorf("CoAP length mismatch (hdr=%d pkt=%d)", + mti.totLen, len(data)) + } + + o, p, err := readTcpMsgBody(mti, r) + if err != nil { + return err + } + + m.fill(mti, o, p) + return nil +} + +// PullTcp extracts a complete TCP CoAP message from the front of a byte queue. +// +// Return values: +// *TcpMessage: On success, points to the extracted message; nil if a complete +// message could not be extracted. +// []byte: The unread portion of of the supplied byte buffer. If a message +// was not extracted, this is the unchanged buffer that was passed in. +// error: Non-nil if the buffer contains an invalid CoAP message. +// +// Note: It is not an error if the supplied buffer does not contain a complete +// message. In such a case, nil *TclMessage and error values are returned +// along with the original buffer. +func PullTcp(data []byte) (*TcpMessage, []byte, error) { + r := bytes.NewReader(data) + m, err := Decode(r) + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + // Packet is incomplete. + return nil, data, nil + } else { + // Some other error. + return nil, data, err + } + } + + // Determine the number of bytes read. These bytes get trimmed from the + // front of the returned data slice. + sz, err := r.Seek(0, io.SeekCurrent) + if err != nil { + // This should never happen. + return nil, data, err } - return m.Message.UnmarshalBinary(data) + return m, data[sz:], nil } // Decode reads a single message from its input. func Decode(r io.Reader) (*TcpMessage, error) { - var ln uint16 - err := binary.Read(r, binary.BigEndian, &ln) + mti, err := readTcpMsgInfo(r) if err != nil { return nil, err } - packet := make([]byte, ln) - _, err = io.ReadFull(r, packet) + o, p, err := readTcpMsgBody(mti, r) if err != nil { return nil, err } - m := TcpMessage{} + m := &TcpMessage{} + m.fill(mti, o, p) - err = m.UnmarshalBinary(packet) - return &m, err + return m, nil } diff --git a/messagetcp_test.go b/messagetcp_test.go index be80e8e..9717a35 100644 --- a/messagetcp_test.go +++ b/messagetcp_test.go @@ -2,35 +2,33 @@ package coap import ( "bytes" - "encoding/binary" "testing" ) func TestTCPDecodeMessageSmallWithPayload(t *testing.T) { - input := []byte{0, 0, - 0x40, 0x1, 0x30, 0x39, 0x21, 0x3, + input := []byte{ + 13 << 4, // len=13, tkl=0 + 0x01, // Extended Length + 0x01, // Code + 0x30, 0x39, 0x21, 0x3, 0x26, 0x77, 0x65, 0x65, 0x74, 0x61, 0x67, - 0xff, 'h', 'i', + 0xff, + 'h', 'i', } - binary.BigEndian.PutUint16(input, uint16(len(input)-2)) - msg, err := Decode(bytes.NewReader(input)) if err != nil { t.Fatalf("Error parsing message: %v", err) } - if msg.Type != Confirmable { - t.Errorf("Expected message type confirmable, got %v", msg.Type) - } - if msg.Code != GET { - t.Errorf("Expected message code GET, got %v", msg.Code) + if msg.Type() != Confirmable { + t.Errorf("Expected message type confirmable, got %v", msg.Type()) } - if msg.MessageID != 12345 { - t.Errorf("Expected message ID 12345, got %v", msg.MessageID) + if msg.Code() != GET { + t.Errorf("Expected message code GET, got %v", msg.Code()) } - if !bytes.Equal(msg.Payload, []byte("hi")) { - t.Errorf("Incorrect payload: %q", msg.Payload) + if !bytes.Equal(msg.Payload(), []byte("hi")) { + t.Errorf("Incorrect payload: %q", msg.Payload()) } } diff --git a/server.go b/server.go index 343b875..0365b3b 100644 --- a/server.go +++ b/server.go @@ -12,62 +12,96 @@ const maxPktLen = 1500 // Handler is a type that handles CoAP messages. type Handler interface { // Handle the message and optionally return a response message. - ServeCOAP(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message + ServeCOAP(c *Conn, m Message) Message } -type funcHandler func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message +type funcHandler func(c *Conn, m Message) Message -func (f funcHandler) ServeCOAP(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { - return f(l, a, m) +func (f funcHandler) ServeCOAP(c *Conn, m Message) Message { + return f(c, m) } // FuncHandler builds a handler from a function. -func FuncHandler(f func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message) Handler { +func FuncHandler(f func(c *Conn, m Message) Message) Handler { return funcHandler(f) } -func handlePacket(l *net.UDPConn, data []byte, u *net.UDPAddr, +//should handlePacket be exported? +func handlePacket(c *Conn, data []byte, addr Addr, rh Handler) { - msg, err := ParseMessage(data) + msg, err := ParseDgramMessage(data) if err != nil { log.Printf("Error parsing %v", err) return } - rv := rh.ServeCOAP(l, u, &msg) + rv := rh.ServeCOAP(c, msg) if rv != nil { - Transmit(l, u, *rv) + Transmit(c, addr, rv) } } // Transmit a message. -func Transmit(l *net.UDPConn, a *net.UDPAddr, m Message) error { +func Transmit(c *Conn, address Addr, m Message) error { d, err := m.MarshalBinary() if err != nil { return err } - - if a == nil { - _, err = l.Write(d) - } else { - _, err = l.WriteTo(d, a) + net, err := c.Network() + if err != nil { + return err + } + if net == "udp" { + addr := address.Udp.String() + + if string([]byte(addr)) == "" { + _, err = c.conn.Write(d) + } else { + // _, err = c.conn.Write(d) //this line is just to prevent the "use of writeto with pre-connected connection" error + _, err = c.conn.WriteToUDP(d, address.Udp) + } + return err + } + if net == "tcp" { + _, err := c.connTCP.Write(d) + return err } return err } // Receive a message. -func Receive(l *net.UDPConn, buf []byte) (Message, error) { - l.SetReadDeadline(time.Now().Add(ResponseTimeout)) - - nr, _, err := l.ReadFromUDP(buf) +func Receive(c *Conn, buf []byte) (Message, error) { + n, err := c.Network() if err != nil { - return Message{}, err + return nil, err + } + switch n { + case "udp": + c.conn.SetReadDeadline(time.Now().Add(ResponseTimeout)) + + nr, err := c.conn.Read(buf) + if err != nil { + return &DgramMessage{}, err + } + return ParseDgramMessage(buf[:nr]) + case "tcp": + c.connTCP.SetReadDeadline(time.Now().Add(ResponseTimeout)) + for { + _, err := c.connTCP.Read(buf) + if err != nil { + return &TcpMessage{}, err + } + m, _, err := PullTcp(buf) + return m, err + } + default: + return nil, err + } - return ParseMessage(buf[:nr]) } -// ListenAndServe binds to the given address and serve requests forever. +// ListenAndServe binds to the given address and serve requests forever. This has not been modified to handle TCP func ListenAndServe(n, addr string, rh Handler) error { uaddr, err := net.ResolveUDPAddr(n, addr) if err != nil { @@ -79,24 +113,67 @@ func ListenAndServe(n, addr string, rh Handler) error { return err } - return Serve(l, rh) + return Serve( + &Conn{conn: l}, + rh, + ) } // Serve processes incoming UDP packets on the given listener, and processes // these requests forever (or until the listener is closed). -func Serve(listener *net.UDPConn, rh Handler) error { +func Serve(listener *Conn, rh Handler) error { buf := make([]byte, maxPktLen) - for { - nr, addr, err := listener.ReadFromUDP(buf) - if err != nil { - if neterr, ok := err.(net.Error); ok && (neterr.Temporary() || neterr.Timeout()) { - time.Sleep(5 * time.Millisecond) - continue + n, err := listener.Network() + if err != nil { + return err + } + if n == "udp" { + for { + nr, addr, err := listener.conn.ReadFromUDP(buf) + if err != nil { + if neterr, ok := err.(net.Error); ok && (neterr.Temporary() || neterr.Timeout()) { + time.Sleep(5 * time.Millisecond) + continue + } + return err + } + tmp := make([]byte, nr) + copy(tmp, buf) + go handlePacket(listener, tmp, Addr{Udp: addr}, rh) + } + } + if n == "tcp" { //i need to get this function to keep looping and reading until it gets a full TCP packet + for { + _, err := listener.connTCP.Read(listener.buf) //maybe needs pullTCP()? + if err != nil { + if neterr, ok := err.(net.Error); ok && (neterr.Temporary() || neterr.Timeout()) { + time.Sleep(5 * time.Millisecond) + continue + } + return err + } + if len(listener.buf) > 0 { + + tmp, buf, err := PullTcp(listener.buf) + if err != nil { + return err + } + if len(listener.buf) > len(buf) { + listener.buf = buf + + m, err := tmp.MarshalBinary() + if err != nil { + return err + } + + addr, err := net.ResolveTCPAddr("tcp", listener.connTCP.RemoteAddr().String()) + if err != nil { + return err + } + go handlePacket(listener, m, Addr{Tcp: addr}, rh) + } } - return err } - tmp := make([]byte, nr) - copy(tmp, buf) - go handlePacket(listener, tmp, addr, rh) } + return err } diff --git a/serverTCP_test.go b/serverTCP_test.go new file mode 100644 index 0000000..9965f5c --- /dev/null +++ b/serverTCP_test.go @@ -0,0 +1,128 @@ +package coap + +import ( + "net" + "testing" +) + +func startTCPLisenter(t *testing.T) (*net.TCPListener, string) { + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5683") + if err != nil { + t.Fatal("Can't resolve TCP addr") + } + tcpListener, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + t.Fatal("Can't listen on TCP: ", err) + } + + coapServerAddr := tcpListener.Addr().String() + return tcpListener, coapServerAddr +} + +func dialAndSendTCP(t *testing.T, addr string, req Message) Message { + c, err := Dial("tcp", addr) + if err != nil { + t.Log("the addr to dial was: ", addr) + t.Fatalf("Error dialing: %v", err) + } + m, err := c.Send(req) + if err != nil { + t.Fatalf("Error sending request: %v", err) + } + return m +} + +func TestServeTCPWithAckResponse(t *testing.T) { + req := &TcpMessage{ + MessageBase{ + typ: Confirmable, + code: POST, + messageID: 9876, + payload: []byte("Content sent by client"), + }, + } + req.SetOption(ContentFormat, TextPlain) + req.SetPathString("/req/path") + + res := &TcpMessage{ + MessageBase{ + typ: Acknowledgement, + code: Content, + messageID: req.MessageID(), + payload: []byte("Reply from CoAP server"), + }, + } + res.SetOption(ContentFormat, TextPlain) + res.SetPath(req.Path()) + + handler := FuncHandler(func(c *Conn, m Message) Message { + t.Log(m.Type(), "payload:", m.Payload()) + + assertEqualMessages(t, req, m) + return res + }) + + tcpListener, coapServerAddr := startTCPLisenter(t) + defer tcpListener.Close() + go dialAndTest(t, coapServerAddr, req, true, res) + + tcpConn, err := tcpListener.AcceptTCP() + if err != nil { + t.Fatal("err accepting TCPconn: ", err) + } + + go Serve( + &Conn{connTCP: tcpConn}, + handler, + ) + + /* m := dialAndSendTCP(t, coapServerAddr, req) + + if m == nil { + t.Fatalf("Didn't receive CoAP response") + } + assertEqualMessages(t, res, m) + */ +} + +func TestServeTCPWithoutAckResponse(t *testing.T) { + req := &TcpMessage{ + MessageBase{ + typ: NonConfirmable, + code: POST, + messageID: 54321, + payload: []byte("Content sent by client"), + }, + } + req.SetOption(ContentFormat, AppOctets) + + handler := FuncHandler(func(c *Conn, m Message) Message { + assertEqualMessages(t, req, m) + return nil + }) + + tcpListener, coapServerAddr := startTCPLisenter(t) + defer tcpListener.Close() + dialAndTest(t, coapServerAddr, req, false, &TcpMessage{}) + tcpConn, err := tcpListener.AcceptTCP() + if err != nil { + t.Fatal("err accepting TCPconn: ", err) + } + + go Serve( + &Conn{connTCP: tcpConn}, + handler, + ) + +} + +func dialAndTest(t *testing.T, addr string, req *TcpMessage, ack bool, res *TcpMessage) { + m := dialAndSendTCP(t, addr, req) + if ack { + assertEqualMessages(t, res, m) + + } else if m != nil { + t.Errorf("recieved an ack when expecting none") + } + +} diff --git a/server_test.go b/server_test.go index 28d9cb4..e0d3a6d 100644 --- a/server_test.go +++ b/server_test.go @@ -12,13 +12,13 @@ func startUDPLisenter(t *testing.T) (*net.UDPConn, string) { } udpListener, err := net.ListenUDP("udp", udpAddr) if err != nil { - t.Fatal("Can't listen on UDP") + t.Fatal("Can't listen on UDP ", err) } coapServerAddr := udpListener.LocalAddr().String() return udpListener, coapServerAddr } -func dialAndSend(t *testing.T, addr string, req Message) *Message { +func dialAndSend(t *testing.T, addr string, req Message) Message { c, err := Dial("udp", addr) if err != nil { t.Fatalf("Error dialing: %v", err) @@ -31,57 +31,63 @@ func dialAndSend(t *testing.T, addr string, req Message) *Message { } func TestServeWithAckResponse(t *testing.T) { - req := Message{ - Type: Confirmable, - Code: POST, - MessageID: 9876, - Payload: []byte("Content sent by client"), + req := &DgramMessage{ + MessageBase{ + typ: Confirmable, + code: POST, + messageID: 9876, + payload: []byte("Content sent by client"), + }, } req.SetOption(ContentFormat, TextPlain) req.SetPathString("/req/path") - res := Message{ - Type: Acknowledgement, - Code: Content, - MessageID: req.MessageID, - Payload: []byte("Reply from CoAP server"), + res := &DgramMessage{ + MessageBase{ + typ: Acknowledgement, + code: Content, + messageID: req.MessageID(), + payload: []byte("Reply from CoAP server"), + }, } res.SetOption(ContentFormat, TextPlain) res.SetPath(req.Path()) - handler := FuncHandler(func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { - assertEqualMessages(t, req, *m) - return &res + handler := FuncHandler(func(c *Conn, m Message) Message { + assertEqualMessages(t, req, m) + return res }) udpListener, coapServerAddr := startUDPLisenter(t) defer udpListener.Close() - go Serve(udpListener, handler) + go Serve(&Conn{conn: udpListener}, handler) m := dialAndSend(t, coapServerAddr, req) if m == nil { t.Fatalf("Didn't receive CoAP response") } - assertEqualMessages(t, res, *m) + assertEqualMessages(t, res, m) } func TestServeWithoutAckResponse(t *testing.T) { - req := Message{ - Type: NonConfirmable, - Code: POST, - MessageID: 54321, - Payload: []byte("Content sent by client"), + req := &DgramMessage{ + MessageBase{ + typ: NonConfirmable, + code: POST, + messageID: 54321, + payload: []byte("Content sent by client"), + }, } req.SetOption(ContentFormat, AppOctets) - handler := FuncHandler(func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { - assertEqualMessages(t, req, *m) + handler := FuncHandler(func(c *Conn, m Message) Message { + assertEqualMessages(t, req, m) return nil }) udpListener, coapServerAddr := startUDPLisenter(t) defer udpListener.Close() - go Serve(udpListener, handler) + go Serve(&Conn{conn: udpListener}, handler) m := dialAndSend(t, coapServerAddr, req) if m != nil { diff --git a/servmux.go b/servmux.go index 23132f1..ce99af3 100644 --- a/servmux.go +++ b/servmux.go @@ -1,9 +1,5 @@ package coap -import ( - "net" -) - // ServeMux provides mappings from a common endpoint to handlers by // request path. type ServeMux struct { @@ -13,6 +9,7 @@ type ServeMux struct { type muxEntry struct { h Handler pattern string + network string } // NewServeMux creates a new ServeMux. @@ -33,10 +30,11 @@ func pathMatch(pattern, path string) bool { // Find a handler on a handler map given a path string // Most-specific (longest) pattern wins -func (mux *ServeMux) match(path string) (h Handler, pattern string) { +func (mux *ServeMux) match(path, network string) (h Handler, pattern string) { var n = 0 for k, v := range mux.m { - if !pathMatch(k, path) { + net := mux.m[path].network + if !pathMatch(k, path) && net != network { continue } if h == nil || len(k) > n { @@ -48,11 +46,13 @@ func (mux *ServeMux) match(path string) (h Handler, pattern string) { return } -func notFoundHandler(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { +func notFoundHandler(c *Conn, m Message) Message { if m.IsConfirmable() { - return &Message{ - Type: Acknowledgement, - Code: NotFound, + return &DgramMessage{ + MessageBase{ + typ: Acknowledgement, + code: NotFound, + }, } } return nil @@ -62,33 +62,39 @@ var _ = Handler(&ServeMux{}) // ServeCOAP handles a single COAP message. The message arrives from // the given listener having originated from the given UDPAddr. -func (mux *ServeMux) ServeCOAP(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { - h, _ := mux.match(m.PathString()) +//WARNING I SHOULD PROBABLY HANDLE ERRORS FOR Conn.Network() +func (mux *ServeMux) ServeCOAP(c *Conn, m Message) Message { + n, _ := c.Network() + h, _ := mux.match(m.PathString(), n) if h == nil { h, _ = funcHandler(notFoundHandler), "" } // TODO: Rewrite path? - return h.ServeCOAP(l, a, m) + return h.ServeCOAP(c, m) } // Handle configures a handler for the given path. -func (mux *ServeMux) Handle(pattern string, handler Handler) { +func (mux *ServeMux) Handle(n string, pattern string, handler Handler) { for pattern != "" && pattern[0] == '/' { pattern = pattern[1:] } if pattern == "" { - panic("http: invalid pattern " + pattern) + panic("coap: invalid pattern " + pattern) } if handler == nil { - panic("http: nil handler") + panic("coap: nil handler") } - - mux.m[pattern] = muxEntry{h: handler, pattern: pattern} + if _, ok := mux.m[pattern]; ok { + if mux.m[pattern].network == n { + panic("coap: multiple registration for " + pattern + " on transport: " + n) + } + } + mux.m[pattern] = muxEntry{h: handler, pattern: pattern, network: n} } // HandleFunc configures a handler for the given path. -func (mux *ServeMux) HandleFunc(pattern string, - f func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message) { - mux.Handle(pattern, FuncHandler(f)) +func (mux *ServeMux) HandleFunc(network, pattern string, + f func(c *Conn, m Message) Message) { + mux.Handle(network, pattern, FuncHandler(f)) } diff --git a/servmux_test.go b/servmux_test.go index b9ddc41..9b9e55c 100644 --- a/servmux_test.go +++ b/servmux_test.go @@ -1,7 +1,6 @@ package coap import ( - "net" "testing" ) @@ -9,28 +8,37 @@ func TestPathMatching(t *testing.T) { m := NewServeMux() msgs := map[string]int{} - - m.HandleFunc("/a", func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { + //using nil for network type because no transport is being used in this test + m.HandleFunc("", "/a", func(c *Conn, m Message) Message { msgs["a"]++ + t.Log("get a request on /a ", string(m.Payload())) return nil }) - m.HandleFunc("/b", func(l *net.UDPConn, a *net.UDPAddr, m *Message) *Message { + m.HandleFunc("", "/b", func(c *Conn, m Message) Message { msgs["b"]++ + t.Log("get a request on /b ", string(m.Payload())) return nil }) - msg := &Message{} + msg := &DgramMessage{} + cTcp := &Conn{Net: "tcp"} //it's easier to set Conn.Net and not use it than it is to explicitly accept connections without a stated transport type + cUdp := &Conn{Net: "udp"} msg.SetPathString("/a") - m.ServeCOAP(nil, nil, msg) + msg.SetPayload([]byte("hi a1")) + m.ServeCOAP(cTcp, msg) msg.SetPathString("/a") - m.ServeCOAP(nil, nil, msg) + msg.SetPayload([]byte("hi a2")) + m.ServeCOAP(cTcp, msg) msg.SetPathString("/b") - m.ServeCOAP(nil, nil, msg) + msg.SetPayload([]byte("hi b1")) + m.ServeCOAP(cUdp, msg) msg.SetPathString("/c") - m.ServeCOAP(nil, nil, msg) - msg.Type = NonConfirmable + msg.SetPayload([]byte("hi c")) + m.ServeCOAP(cUdp, msg) + msg.MessageBase.typ = NonConfirmable msg.SetPathString("/c") - m.ServeCOAP(nil, nil, msg) + msg.SetPayload([]byte("hi c")) + m.ServeCOAP(cTcp, msg) if msgs["a"] != 2 { t.Errorf("Expected 2 messages for /a, got %v", msgs["a"])