From 3841a5f2ad6d66dff9edd4978088da5ffd8e03e5 Mon Sep 17 00:00:00 2001 From: ethan256 Date: Sat, 18 May 2024 00:30:12 +0800 Subject: [PATCH 01/14] feat: enhance message encode --- ua/datatypes.go | 41 ++-- ua/diagnostic_info.go | 21 +- ua/encode.go | 10 + ua/expanded_node_id.go | 12 +- ua/extension_object.go | 20 +- ua/node_id.go | 27 ++- ua/stream.go | 304 +++++++++++++++++++++++++++++ ua/variant.go | 75 ++++--- uacp/codec_test.go | 8 +- uacp/conn.go | 14 +- uacp/conn_test.go | 5 +- uacp/uacp.go | 68 +++---- uasc/asymmetric_security_header.go | 11 +- uasc/codec_test.go | 9 +- uasc/header.go | 15 +- uasc/message.go | 26 +-- uasc/message_test.go | 253 +++++++++++++++++++++++- uasc/sequence_header.go | 9 +- uasc/symmetric_security_header.go | 7 +- 19 files changed, 742 insertions(+), 193 deletions(-) create mode 100644 ua/stream.go diff --git a/ua/datatypes.go b/ua/datatypes.go index c298841c..2aefc0f7 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -61,29 +61,27 @@ func (d *DataValue) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DataValue) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteUint8(d.EncodingMask) +func (d *DataValue) Encode(s *Stream) { + s.WriteUint8(d.EncodingMask) if d.Has(DataValueValue) { - buf.WriteStruct(d.Value) + s.WriteStruct(d.Value) } if d.Has(DataValueStatusCode) { - buf.WriteUint32(uint32(d.Status)) + s.WriteUint32(uint32(d.Status)) } if d.Has(DataValueSourceTimestamp) { - buf.WriteTime(d.SourceTimestamp) + s.WriteTime(d.SourceTimestamp) } if d.Has(DataValueSourcePicoseconds) { - buf.WriteUint16(d.SourcePicoseconds) + s.WriteUint16(d.SourcePicoseconds) } if d.Has(DataValueServerTimestamp) { - buf.WriteTime(d.ServerTimestamp) + s.WriteTime(d.ServerTimestamp) } if d.Has(DataValueServerPicoseconds) { - buf.WriteUint16(d.ServerPicoseconds) + s.WriteUint16(d.ServerPicoseconds) } - return buf.Bytes(), buf.Error() } func (d *DataValue) Has(mask byte) bool { @@ -152,13 +150,11 @@ func (g *GUID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (g *GUID) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteUint32(g.Data1) - buf.WriteUint16(g.Data2) - buf.WriteUint16(g.Data3) - buf.Write(g.Data4) - return buf.Bytes(), buf.Error() +func (g *GUID) Encode(s *Stream) { + s.WriteUint32(g.Data1) + s.WriteUint16(g.Data2) + s.WriteUint16(g.Data3) + s.Write(g.Data4) } // String returns GUID in human-readable string. @@ -227,16 +223,15 @@ func (l *LocalizedText) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (l *LocalizedText) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteUint8(l.EncodingMask) +func (l *LocalizedText) Encode(s *Stream) error { + s.WriteUint8(l.EncodingMask) if l.Has(LocalizedTextLocale) { - buf.WriteString(l.Locale) + s.WriteString(l.Locale) } if l.Has(LocalizedTextText) { - buf.WriteString(l.Text) + s.WriteString(l.Text) } - return buf.Bytes(), buf.Error() + return s.Error() } func (l *LocalizedText) Has(mask byte) bool { diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index a3b05fd4..9a873d07 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -58,31 +58,30 @@ func (d *DiagnosticInfo) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DiagnosticInfo) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteByte(d.EncodingMask) +func (d *DiagnosticInfo) Encode(s *Stream) error { + s.WriteByte(d.EncodingMask) if d.Has(DiagnosticInfoSymbolicID) { - buf.WriteInt32(d.SymbolicID) + s.WriteInt32(d.SymbolicID) } if d.Has(DiagnosticInfoNamespaceURI) { - buf.WriteInt32(d.NamespaceURI) + s.WriteInt32(d.NamespaceURI) } if d.Has(DiagnosticInfoLocale) { - buf.WriteInt32(d.Locale) + s.WriteInt32(d.Locale) } if d.Has(DiagnosticInfoLocalizedText) { - buf.WriteInt32(d.LocalizedText) + s.WriteInt32(d.LocalizedText) } if d.Has(DiagnosticInfoAdditionalInfo) { - buf.WriteString(d.AdditionalInfo) + s.WriteString(d.AdditionalInfo) } if d.Has(DiagnosticInfoInnerStatusCode) { - buf.WriteUint32(uint32(d.InnerStatusCode)) + s.WriteUint32(uint32(d.InnerStatusCode)) } if d.Has(DiagnosticInfoInnerDiagnosticInfo) { - buf.WriteStruct(d.InnerDiagnosticInfo) + s.WriteStruct(d.InnerDiagnosticInfo) } - return buf.Bytes(), buf.Error() + return s.Error() } func (d *DiagnosticInfo) Has(mask byte) bool { diff --git a/ua/encode.go b/ua/encode.go index 2d6462f1..0b17fec6 100644 --- a/ua/encode.go +++ b/ua/encode.go @@ -15,6 +15,16 @@ import ( "github.com/gopcua/opcua/errors" ) +type ValEncoder interface { + Encode(s *Stream) error +} + +var valEncoder = reflect.TypeOf((*ValEncoder)(nil)).Elem() + +func isValEncoder(val reflect.Value) bool { + return val.Type().Implements(valEncoder) +} + // debugCodec enables printing of debug messages in the opcua codec. var debugCodec = debug.FlagSet("codec") diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index 404a6edb..925bb450 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -102,17 +102,15 @@ func (e *ExpandedNodeID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *ExpandedNodeID) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteStruct(e.NodeID) +func (e *ExpandedNodeID) Encode(s *Stream) error { + s.WriteStruct(e.NodeID) if e.HasNamespaceURI() { - buf.WriteString(e.NamespaceURI) + s.WriteString(e.NamespaceURI) } if e.HasServerIndex() { - buf.WriteUint32(e.ServerIndex) + s.WriteUint32(e.ServerIndex) } - return buf.Bytes(), buf.Error() - + return s.Error() } // HasNamespaceURI checks if an ExpandedNodeID has NamespaceURI Flag. diff --git a/ua/extension_object.go b/ua/extension_object.go index 4ee7dafd..064e9f04 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -84,25 +84,25 @@ func (e *ExtensionObject) Decode(b []byte) (int, error) { return buf.Pos(), body.Error() } -func (e *ExtensionObject) Encode() ([]byte, error) { - buf := NewBuffer(nil) +func (e *ExtensionObject) Encode(s *Stream) error { if e == nil { e = &ExtensionObject{TypeID: NewTwoByteExpandedNodeID(0), EncodingMask: ExtensionObjectEmpty} } - buf.WriteStruct(e.TypeID) - buf.WriteByte(e.EncodingMask) + s.WriteStruct(e.TypeID) + s.WriteByte(e.EncodingMask) if e.EncodingMask == ExtensionObjectEmpty { - return buf.Bytes(), buf.Error() + return s.Error() } - body := NewBuffer(nil) + // TODO: use pool? + body := NewStream(DefaultBufSize) body.WriteStruct(e.Value) if body.Error() != nil { - return nil, body.Error() + return body.Error() } - buf.WriteUint32(uint32(body.Len())) - buf.Write(body.Bytes()) - return buf.Bytes(), buf.Error() + s.WriteUint32(uint32(body.Len())) + s.Write(body.Bytes()) + return s.Error() } func (e *ExtensionObject) UpdateMask() { diff --git a/ua/node_id.go b/ua/node_id.go index 81c892ea..50cb7de3 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -366,29 +366,28 @@ func (n *NodeID) Decode(b []byte) (int, error) { } } -func (n *NodeID) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteByte(byte(n.mask)) +func (n *NodeID) Encode(s *Stream) error { + s.WriteByte(byte(n.mask)) switch n.Type() { case NodeIDTypeTwoByte: - buf.WriteByte(byte(n.nid)) + s.WriteByte(byte(n.nid)) case NodeIDTypeFourByte: - buf.WriteByte(byte(n.ns)) - buf.WriteUint16(uint16(n.nid)) + s.WriteByte(byte(n.ns)) + s.WriteUint16(uint16(n.nid)) case NodeIDTypeNumeric: - buf.WriteUint16(n.ns) - buf.WriteUint32(n.nid) + s.WriteUint16(n.ns) + s.WriteUint32(n.nid) case NodeIDTypeGUID: - buf.WriteUint16(n.ns) - buf.WriteStruct(n.gid) + s.WriteUint16(n.ns) + s.WriteStruct(n.gid) case NodeIDTypeByteString, NodeIDTypeString: - buf.WriteUint16(n.ns) - buf.WriteByteString(n.bid) + s.WriteUint16(n.ns) + s.WriteByteString(n.bid) default: - return nil, errors.Errorf("invalid node id type %v", n.Type()) + return errors.Errorf("invalid node id type %v", n.Type()) } - return buf.Bytes(), buf.Error() + return s.Error() } func (n *NodeID) MarshalJSON() ([]byte, error) { diff --git a/ua/stream.go b/ua/stream.go new file mode 100644 index 00000000..d14b9aeb --- /dev/null +++ b/ua/stream.go @@ -0,0 +1,304 @@ +package ua + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "reflect" + "time" + + "github.com/gopcua/opcua/errors" +) + +const DefaultBufSize = 1024 + +type Stream struct { + buf []byte + pos int + err error +} + +func NewStream(size int) *Stream { + return &Stream{ + buf: make([]byte, 0, size), + } +} + +func (s *Stream) Error() error { + return s.err +} + +func (s *Stream) Len() int { + return len(s.buf) +} + +func (s *Stream) Reset() { + s.buf = s.buf[:0] + s.pos = 0 + s.err = nil +} + +func (s *Stream) Bytes() []byte { + return s.buf +} + +func (b *Stream) ReadN(n int) []byte { + if b.err != nil { + return nil + } + d := b.buf[b.pos:] + if n > len(d) { + b.err = io.ErrUnexpectedEOF + return nil + } + b.pos += n + return d[:n] +} + +func (b *Stream) WriteBool(v bool) { + if v { + b.WriteUint8(1) + } else { + b.WriteUint8(0) + } +} + +func (b *Stream) WriteByte(n byte) { + b.buf = append(b.buf, n) +} + +func (b *Stream) WriteInt8(n int8) { + b.buf = append(b.buf, byte(n)) +} + +func (b *Stream) WriteUint8(n uint8) { + b.buf = append(b.buf, byte(n)) +} + +func (b *Stream) WriteInt16(n int16) { + b.WriteUint16(uint16(n)) +} + +func (b *Stream) WriteUint16(n uint16) { + d := make([]byte, 2) + binary.LittleEndian.PutUint16(d, n) + b.Write(d) +} + +func (b *Stream) WriteInt32(n int32) { + b.WriteUint32(uint32(n)) +} + +func (b *Stream) WriteUint32(n uint32) { + d := make([]byte, 4) + binary.LittleEndian.PutUint32(d, n) + b.Write(d) +} + +func (b *Stream) WriteInt64(n int64) { + b.WriteUint64(uint64(n)) +} + +func (b *Stream) WriteUint64(n uint64) { + d := make([]byte, 8) + binary.LittleEndian.PutUint64(d, n) + b.Write(d) +} + +func (b *Stream) WriteFloat32(n float32) { + if math.IsNaN(float64(n)) { + b.WriteUint32(f32qnan) + } else { + b.WriteUint32(math.Float32bits(n)) + } +} + +func (b *Stream) WriteFloat64(n float64) { + if math.IsNaN(n) { + b.WriteUint64(f64qnan) + } else { + b.WriteUint64(math.Float64bits(n)) + } +} + +func (b *Stream) WriteString(s string) { + if s == "" { + b.WriteUint32(null) + return + } + b.WriteByteString([]byte(s)) +} + +func (b *Stream) WriteByteString(d []byte) { + if b.err != nil { + return + } + if len(d) > math.MaxInt32 { + b.err = errors.Errorf("value too large") + return + } + if d == nil { + b.WriteUint32(null) + return + } + b.WriteUint32(uint32(len(d))) + b.Write(d) +} + +func (b *Stream) WriteTime(v time.Time) { + d := make([]byte, 8) + if !v.IsZero() { + // encode time in "100 nanosecond intervals since January 1, 1601" + ts := uint64(v.UTC().UnixNano()/100 + 116444736000000000) + binary.LittleEndian.PutUint64(d, ts) + } + b.Write(d) +} + +func (b *Stream) Write(d []byte) { + if b.err != nil { + return + } + b.buf = append(b.buf, d...) +} + +func (s *Stream) WriteStruct(w interface{}) { + if s.err != nil { + return + } + val := reflect.ValueOf(w) + switch x := w.(type) { + case ValEncoder: + s.err = x.Encode(s) + default: + s.err = s.encode(val, val.Type().String()) + } +} + +func (s *Stream) encode(val reflect.Value, name string) error { + if debugCodec { + fmt.Printf("encode: %s has type %s and is a %s\n", name, val.Type(), val.Type().Kind()) + } + + switch { + case isValEncoder(val): + v := val.Interface().(ValEncoder) + return v.Encode(s) + + case isTime(val): + s.WriteTime(val.Convert(timeType).Interface().(time.Time)) + + default: + switch val.Kind() { + case reflect.Bool: + s.WriteBool(val.Bool()) + case reflect.Int8: + s.WriteInt8(int8(val.Int())) + case reflect.Uint8: + s.WriteUint8(uint8(val.Uint())) + case reflect.Int16: + s.WriteInt16(int16(val.Int())) + case reflect.Uint16: + s.WriteUint16(uint16(val.Uint())) + case reflect.Int32: + s.WriteInt32(int32(val.Int())) + case reflect.Uint32: + s.WriteUint32(uint32(val.Uint())) + case reflect.Int64: + s.WriteInt64(int64(val.Int())) + case reflect.Uint64: + s.WriteUint64(uint64(val.Uint())) + case reflect.Float32: + s.WriteFloat32(float32(val.Float())) + case reflect.Float64: + s.WriteFloat64(float64(val.Float())) + case reflect.String: + s.WriteString(val.String()) + case reflect.Ptr: + if val.IsNil() { + return nil + } + return s.encode(val.Elem(), name) + case reflect.Struct: + return s.writeStruct(val, name) + case reflect.Slice: + return s.writeSlice(val, name) + case reflect.Array: + return s.writeArray(val, name) + default: + return errors.Errorf("unsupported type: %s", val.Type()) + } + } + return s.Error() +} + +func (s *Stream) writeStruct(val reflect.Value, name string) error { + valt := val.Type() + for i := 0; i < val.NumField(); i++ { + ft := valt.Field(i) + fname := name + "." + ft.Name + if err := s.encode(val.Field(i), fname); err != nil { + return err + } + } + return nil +} + +func (s *Stream) writeSlice(val reflect.Value, name string) error { + if val.IsNil() { + s.WriteUint32(null) + return s.Error() + } + + if val.Len() > math.MaxInt32 { + return errors.Errorf("array too large") + } + + s.WriteUint32(uint32(val.Len())) + + // fast path for []byte + if val.Type().Elem().Kind() == reflect.Uint8 { + // fmt.Println("[]byte fast path") + s.Write(val.Bytes()) + return s.Error() + } + + // loop over elements + for i := 0; i < val.Len(); i++ { + ename := fmt.Sprintf("%s[%d]", name, i) + s.encode(val.Index(i), ename) + if s.Error() != nil { + return s.Error() + } + } + return s.Error() +} + +func (s *Stream) writeArray(val reflect.Value, name string) error { + if val.Len() > math.MaxInt32 { + return errors.Errorf("array too large: %d > %d", val.Len(), math.MaxInt32) + } + + s.WriteUint32(uint32(val.Len())) + + // fast path for []byte + if val.Type().Elem().Kind() == reflect.Uint8 { + // fmt.Println("encode: []byte fast path") + b := make([]byte, val.Len()) + reflect.Copy(reflect.ValueOf(b), val) + s.Write(b) + return s.Error() + } + + // loop over elements + // we write all the elements, also the zero values + for i := 0; i < val.Len(); i++ { + ename := fmt.Sprintf("%s[%d]", name, i) + s.encode(val.Index(i), ename) + if s.Error() != nil { + return s.Error() + } + } + return s.Error() +} diff --git a/ua/variant.go b/ua/variant.go index 21d4f43f..be6b7db9 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -318,95 +318,94 @@ func (m *Variant) decodeValue(buf *Buffer) interface{} { } // Encode implements the codec interface. -func (m *Variant) Encode() ([]byte, error) { - buf := NewBuffer(nil) - buf.WriteByte(m.mask) +func (m *Variant) Encode(s *Stream) error { + s.WriteByte(m.mask) // a null value specifies that no other fields are encoded if m.Type() == TypeIDNull { - return buf.Bytes(), buf.Error() + return s.Error() } if m.Has(VariantArrayValues) { - buf.WriteInt32(m.arrayLength) + s.WriteInt32(m.arrayLength) } - m.encode(buf, reflect.ValueOf(m.value)) + m.encode(s, reflect.ValueOf(m.value)) if m.Has(VariantArrayDimensions) { - buf.WriteInt32(m.arrayDimensionsLength) + s.WriteInt32(m.arrayDimensionsLength) for i := 0; i < int(m.arrayDimensionsLength); i++ { - buf.WriteInt32(m.arrayDimensions[i]) + s.WriteInt32(m.arrayDimensions[i]) } } - return buf.Bytes(), buf.Error() + return s.Error() } // encode recursively writes the values to the buffer. -func (m *Variant) encode(buf *Buffer, val reflect.Value) { +func (m *Variant) encode(s *Stream, val reflect.Value) { if val.Kind() != reflect.Slice || m.Type() == TypeIDByteString { - m.encodeValue(buf, val.Interface()) + m.encodeValue(s, val.Interface()) return } for i := 0; i < val.Len(); i++ { - m.encode(buf, val.Index(i)) + m.encode(s, val.Index(i)) } } // encodeValue writes a single value of the base type to the buffer. -func (m *Variant) encodeValue(buf *Buffer, v interface{}) { +func (m *Variant) encodeValue(s *Stream, v interface{}) { switch x := v.(type) { case bool: - buf.WriteBool(x) + s.WriteBool(x) case int8: - buf.WriteInt8(x) + s.WriteInt8(x) case byte: - buf.WriteByte(x) + s.WriteByte(x) case int16: - buf.WriteInt16(x) + s.WriteInt16(x) case uint16: - buf.WriteUint16(x) + s.WriteUint16(x) case int32: - buf.WriteInt32(x) + s.WriteInt32(x) case uint32: - buf.WriteUint32(x) + s.WriteUint32(x) case int64: - buf.WriteInt64(x) + s.WriteInt64(x) case uint64: - buf.WriteUint64(x) + s.WriteUint64(x) case float32: - buf.WriteFloat32(x) + s.WriteFloat32(x) case float64: - buf.WriteFloat64(x) + s.WriteFloat64(x) case string: - buf.WriteString(x) + s.WriteString(x) case time.Time: - buf.WriteTime(x) + s.WriteTime(x) case *GUID: - buf.WriteStruct(x) + s.WriteStruct(x) case []byte: - buf.WriteByteString(x) + s.WriteByteString(x) case XMLElement: - buf.WriteString(string(x)) + s.WriteString(string(x)) case *NodeID: - buf.WriteStruct(x) + s.WriteStruct(x) case *ExpandedNodeID: - buf.WriteStruct(x) + s.WriteStruct(x) case StatusCode: - buf.WriteUint32(uint32(x)) + s.WriteUint32(uint32(x)) case *QualifiedName: - buf.WriteStruct(x) + s.WriteStruct(x) case *LocalizedText: - buf.WriteStruct(x) + s.WriteStruct(x) case *ExtensionObject: - buf.WriteStruct(x) + s.WriteStruct(x) case *DataValue: - buf.WriteStruct(x) + s.WriteStruct(x) case *Variant: - buf.WriteStruct(x) + s.WriteStruct(x) case *DiagnosticInfo: - buf.WriteStruct(x) + s.WriteStruct(x) } } diff --git a/uacp/codec_test.go b/uacp/codec_test.go index 360ab773..98f5669b 100644 --- a/uacp/codec_test.go +++ b/uacp/codec_test.go @@ -54,11 +54,11 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - b, err := ua.Encode(c.Struct) - if err != nil { - t.Fatal(err) + s := ua.NewStream(ua.DefaultBufSize) + s.WriteStruct(c.Struct) + if s.Error() != nil { + t.Fatalf("fail to encode message, err: %v", s.Error()) } - verify.Values(t, "", b, c.Bytes) }) }) } diff --git a/uacp/conn.go b/uacp/conn.go index c2acb79c..d8dfe2a8 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -396,27 +396,29 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("invalid msg type: %s", typ) } - body, err := ua.Encode(msg) - if err != nil { - return errors.Errorf("encode msg failed: %s", err) + bodyStream := ua.NewStream(ua.DefaultBufSize) + bodyStream.WriteStruct(msg) + if bodyStream.Error() != nil { + return errors.Errorf("encode msg failed: %s", bodyStream.Error()) } h := Header{ MessageType: typ[:3], ChunkType: typ[3], - MessageSize: uint32(len(body) + hdrlen), + MessageSize: uint32(bodyStream.Len() + hdrlen), } if h.MessageSize > c.ack.SendBufSize { return errors.Errorf("send packet too large: %d > %d bytes", h.MessageSize, c.ack.SendBufSize) } - hdr, err := h.Encode() + headerStream := ua.NewStream(ua.DefaultBufSize) + err := h.Encode(headerStream) if err != nil { return errors.Errorf("encode hdr failed: %s", err) } - b := append(hdr, body...) + b := append(headerStream.Bytes(), bodyStream.Bytes()...) if _, err := c.Write(b); err != nil { return errors.Errorf("write failed: %s", err) } diff --git a/uacp/conn_test.go b/uacp/conn_test.go index 2269bbe2..f4fdffa3 100644 --- a/uacp/conn_test.go +++ b/uacp/conn_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gopcua/opcua/errors" + "github.com/gopcua/opcua/ua" "github.com/pascaldekloe/goe/verify" ) @@ -120,10 +121,12 @@ NEXT: } got = got[hdrlen:] - want, err := msg.Encode() + s := ua.NewStream(ua.DefaultBufSize) + err = msg.Encode(s) if err != nil { t.Fatal(err) } + want := s.Bytes() verify.Values(t, "", got, want) } diff --git a/uacp/uacp.go b/uacp/uacp.go index 7aa37b91..1aa737a8 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -45,15 +45,14 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) +func (h *Header) Encode(s *ua.Stream) error { if len(h.MessageType) != 3 { - return nil, errors.Errorf("invalid message type: %q", h.MessageType) + return errors.Errorf("invalid message type: %q", h.MessageType) } - buf.Write([]byte(h.MessageType)) - buf.WriteByte(h.ChunkType) - buf.WriteUint32(h.MessageSize) - return buf.Bytes(), buf.Error() + s.Write([]byte(h.MessageType)) + s.WriteByte(h.ChunkType) + s.WriteUint32(h.MessageSize) + return s.Error() } // Hello represents a OPC UA Hello. @@ -79,15 +78,14 @@ func (h *Hello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Hello) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteUint32(h.Version) - buf.WriteUint32(h.ReceiveBufSize) - buf.WriteUint32(h.SendBufSize) - buf.WriteUint32(h.MaxMessageSize) - buf.WriteUint32(h.MaxChunkCount) - buf.WriteString(h.EndpointURL) - return buf.Bytes(), buf.Error() +func (h *Hello) Encode(s *ua.Stream) error { + s.WriteUint32(h.Version) + s.WriteUint32(h.ReceiveBufSize) + s.WriteUint32(h.SendBufSize) + s.WriteUint32(h.MaxMessageSize) + s.WriteUint32(h.MaxChunkCount) + s.WriteString(h.EndpointURL) + return s.Error() } // Acknowledge represents a OPC UA Acknowledge. @@ -111,14 +109,13 @@ func (a *Acknowledge) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (a *Acknowledge) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteUint32(a.Version) - buf.WriteUint32(a.ReceiveBufSize) - buf.WriteUint32(a.SendBufSize) - buf.WriteUint32(a.MaxMessageSize) - buf.WriteUint32(a.MaxChunkCount) - return buf.Bytes(), buf.Error() +func (a *Acknowledge) Encode(s *ua.Stream) error { + s.WriteUint32(a.Version) + s.WriteUint32(a.ReceiveBufSize) + s.WriteUint32(a.SendBufSize) + s.WriteUint32(a.MaxMessageSize) + s.WriteUint32(a.MaxChunkCount) + return s.Error() } // ReverseHello represents a OPC UA ReverseHello. @@ -136,11 +133,10 @@ func (r *ReverseHello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (r *ReverseHello) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteString(r.ServerURI) - buf.WriteString(r.EndpointURL) - return buf.Bytes(), buf.Error() +func (r *ReverseHello) Encode(s *ua.Stream) error { + s.WriteString(r.ServerURI) + s.WriteString(r.EndpointURL) + return s.Error() } // Error represents a OPC UA Error. @@ -158,11 +154,10 @@ func (e *Error) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *Error) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteUint32(e.ErrorCode) - buf.WriteString(e.Reason) - return buf.Bytes(), buf.Error() +func (e *Error) Encode(s *ua.Stream) error { + s.WriteUint32(e.ErrorCode) + s.WriteString(e.Reason) + return s.Error() } func (e *Error) Error() string { @@ -183,6 +178,7 @@ func (m *Message) Decode(b []byte) (int, error) { return len(b), nil } -func (m *Message) Encode() ([]byte, error) { - return m.Data, nil +func (m *Message) Encode(s *ua.Stream) error { + s.Write(m.Data) + return nil } diff --git a/uasc/asymmetric_security_header.go b/uasc/asymmetric_security_header.go index f24c5908..6c2032c8 100644 --- a/uasc/asymmetric_security_header.go +++ b/uasc/asymmetric_security_header.go @@ -34,12 +34,11 @@ func (h *AsymmetricSecurityHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *AsymmetricSecurityHeader) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteString(h.SecurityPolicyURI) - buf.WriteByteString(h.SenderCertificate) - buf.WriteByteString(h.ReceiverCertificateThumbprint) - return buf.Bytes(), buf.Error() +func (h *AsymmetricSecurityHeader) Encode(s *ua.Stream) error { + s.WriteString(h.SecurityPolicyURI) + s.WriteByteString(h.SenderCertificate) + s.WriteByteString(h.ReceiverCertificateThumbprint) + return s.Error() } // String returns Header in string. diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 18f94aee..2ec99e80 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -54,11 +54,12 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - b, err := ua.Encode(c.Struct) - if err != nil { - t.Fatal(err) + s := ua.NewStream(ua.DefaultBufSize) + s.WriteStruct(c.Struct) + if s.Error() != nil { + t.Fatal(s.Error()) } - verify.Values(t, "", b, c.Bytes) + verify.Values(t, "", s.Bytes(), c.Bytes) }) }) } diff --git a/uasc/header.go b/uasc/header.go index fbab80a4..ddc01a77 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -51,16 +51,15 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) +func (h *Header) Encode(s *ua.Stream) error { if len(h.MessageType) != 3 { - return nil, errors.Errorf("invalid message type: %q", h.MessageType) + return errors.Errorf("invalid message type: %q", h.MessageType) } - buf.Write([]byte(h.MessageType)) - buf.WriteByte(h.ChunkType) - buf.WriteUint32(h.MessageSize) - buf.WriteUint32(h.SecureChannelID) - return buf.Bytes(), buf.Error() + s.Write([]byte(h.MessageType)) + s.WriteByte(h.ChunkType) + s.WriteUint32(h.MessageSize) + s.WriteUint32(h.SecureChannelID) + return s.Error() } // String returns Header in string. diff --git a/uasc/message.go b/uasc/message.go index 8657d111..b144e4f1 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -74,11 +74,10 @@ func (m *MessageAbort) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (m *MessageAbort) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteUint32(m.ErrorCode) - buf.WriteString(m.Reason) - return buf.Bytes(), buf.Error() +func (m *MessageAbort) Encode(s *ua.Stream) error { + s.WriteUint32(m.ErrorCode) + s.WriteString(m.Reason) + return s.Error() } func (m *MessageAbort) MessageAbort() string { @@ -112,16 +111,17 @@ func (m *Message) Decode(b []byte) (int, error) { return len(b), err } -func (m *Message) Encode() ([]byte, error) { +func (m *Message) Encode(s *ua.Stream) error { chunks, err := m.EncodeChunks(math.MaxUint32) if err != nil { - return nil, err + return err } - return chunks[0], nil + s.Write(chunks[0]) + return nil } func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { - dataBody := ua.NewBuffer(nil) + dataBody := ua.NewStream(ua.DefaultBufSize) dataBody.WriteStruct(m.TypeID) dataBody.WriteStruct(m.Service) @@ -134,7 +134,7 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { switch m.Header.MessageType { case "OPN": - partialHeader := ua.NewBuffer(nil) + partialHeader := ua.NewStream(ua.DefaultBufSize) partialHeader.WriteStruct(m.AsymmetricSecurityHeader) partialHeader.WriteStruct(m.SequenceHeader) @@ -143,7 +143,7 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { } m.Header.MessageSize = uint32(12 + partialHeader.Len() + dataBody.Len()) - buf := ua.NewBuffer(nil) + buf := ua.NewStream(ua.DefaultBufSize) buf.WriteStruct(m.Header) buf.Write(partialHeader.Bytes()) buf.Write(dataBody.Bytes()) @@ -155,7 +155,7 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { for i := uint32(0); i < nrChunks-1; i++ { m.Header.MessageSize = maxBodySize + 24 m.Header.ChunkType = ChunkTypeIntermediate - chunk := ua.NewBuffer(nil) + chunk := ua.NewStream(ua.DefaultBufSize) chunk.WriteStruct(m.Header) chunk.WriteStruct(m.SymmetricSecurityHeader) chunk.WriteStruct(m.SequenceHeader) @@ -169,7 +169,7 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { m.Header.ChunkType = ChunkTypeFinal m.Header.MessageSize = uint32(24 + dataBody.Len()) - chunk := ua.NewBuffer(nil) + chunk := ua.NewStream(ua.DefaultBufSize) chunk.WriteStruct(m.Header) chunk.WriteStruct(m.SymmetricSecurityHeader) chunk.WriteStruct(m.SequenceHeader) diff --git a/uasc/message_test.go b/uasc/message_test.go index 1153b612..2e10e5bd 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -24,7 +24,7 @@ func TestMessage(t *testing.T) { }, } instance := &channelInstance{ - sc: s, + sc: s, sequenceNumber: 0, securityTokenID: 0, } @@ -123,7 +123,7 @@ func TestMessage(t *testing.T) { }, } instance := &channelInstance{ - sc: s, + sc: s, sequenceNumber: 0, securityTokenID: 0, } @@ -194,7 +194,7 @@ func TestMessage(t *testing.T) { }, } instance := &channelInstance{ - sc: s, + sc: s, sequenceNumber: 0, securityTokenID: 0, } @@ -248,3 +248,250 @@ func TestMessage(t *testing.T) { } RunCodecTest(t, cases) } + +func BenchmarkEncodeMessage(b *testing.B) { + cases := []CodecTestCase{ + { + Name: "OPN", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.OpenSecureChannelRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + ClientProtocolVersion: 0, + RequestType: ua.SecurityTokenRequestTypeIssue, + SecurityMode: ua.MessageSecurityModeNone, + RequestedLifetime: 6000000, + }, + id.OpenSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 131 + + return m + }(), + Bytes: []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: OPN + 0x4f, 0x50, 0x4e, + // Chunk Type: Final + 0x46, + // MessageSize: 131 + 0x83, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // AsymmetricSecurityHeader + // SecurityPolicyURILength + 0x2e, 0x00, 0x00, 0x00, + // SecurityPolicyURI + 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x67, + 0x6f, 0x70, 0x63, 0x75, 0x61, 0x2e, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x4f, 0x50, + 0x43, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x63, 0x75, + 0x72, 0x69, 0x74, 0x79, 0x50, 0x6f, 0x6c, 0x69, + 0x63, 0x79, 0x23, 0x46, 0x6f, 0x6f, + // SenderCertificate + 0xff, 0xff, 0xff, 0xff, + // ReceiverCertificateThumbprint + 0xff, 0xff, 0xff, 0xff, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xbe, 0x01, + + // RequestHeader + // - AuthenticationToken + 0x00, 0x00, + // - Timestamp + 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, 0xd4, 0x01, + // - RequestHandle + 0x01, 0x00, 0x00, 0x00, + // - ReturnDiagnostics + 0xff, 0x03, 0x00, 0x00, + // - AuditEntry + 0xff, 0xff, 0xff, 0xff, + // - TimeoutHint + 0x00, 0x00, 0x00, 0x00, + // - AdditionalHeader + // - TypeID + 0x00, 0x00, + // - EncodingMask + 0x00, + // ClientProtocolVersion + 0x00, 0x00, 0x00, 0x00, + // SecurityTokenRequestType + 0x00, 0x00, 0x00, 0x00, + // MessageSecurityMode + 0x01, 0x00, 0x00, 0x00, + // ClientNonce + 0xff, 0xff, 0xff, 0xff, + // RequestedLifetime + 0x80, 0x8d, 0x5b, 0x00, + }, + }, + { + Name: "MSG", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.GetEndpointsRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", + }, + id.GetEndpointsRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 107 + + return m + }(), + Bytes: []byte{ // GetEndpointsRequest + // Message Header + // MessageType: MSG + 0x4d, 0x53, 0x47, + // Chunk Type: Final + 0x46, + // MessageSize: 107 + 0x6b, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // SymmetricSecurityHeader + // TokenID + 0x00, 0x00, 0x00, 0x00, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xac, 0x01, + // RequestHeader + 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, + 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + // ClientProtocolVersion + 0x26, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, + 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x77, 0x6f, + 0x77, 0x2e, 0x69, 0x74, 0x73, 0x2e, 0x65, 0x61, + 0x73, 0x79, 0x3a, 0x31, 0x31, 0x31, 0x31, 0x31, + 0x2f, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, + // LocaleIDs + 0xff, 0xff, 0xff, 0xff, + // ProfileURIs + 0xff, 0xff, 0xff, 0xff, + }, + }, { + Name: "CLO", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.CloseSecureChannelRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + }, + id.CloseSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 57 + + return m + }(), + Bytes: []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: CLO + 0x43, 0x4c, 0x4f, + // Chunk Type: Final + 0x46, + // MessageSize: 57 + 0x39, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // SymmetricSecurityHeader + // TokenID + 0x00, 0x00, 0x00, 0x00, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xc4, 0x01, + // RequestHeader + 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, + 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + } + + b.ResetTimer() + s := ua.NewStream(ua.DefaultBufSize) + for i := 0; i < b.N; i++ { + for _, tc := range cases { + s.WriteStruct(tc.Struct) + if s.Error() != nil { + b.Fatalf("fail to encode message, err: %v", s.Error()) + } + s.Reset() + } + } +} diff --git a/uasc/sequence_header.go b/uasc/sequence_header.go index 00eefa82..ca2bdc68 100644 --- a/uasc/sequence_header.go +++ b/uasc/sequence_header.go @@ -31,11 +31,10 @@ func (h *SequenceHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *SequenceHeader) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteUint32(h.SequenceNumber) - buf.WriteUint32(h.RequestID) - return buf.Bytes(), buf.Error() +func (h *SequenceHeader) Encode(s *ua.Stream) error { + s.WriteUint32(h.SequenceNumber) + s.WriteUint32(h.RequestID) + return s.Error() } // String returns Header in string. diff --git a/uasc/symmetric_security_header.go b/uasc/symmetric_security_header.go index cad46f6e..40670086 100644 --- a/uasc/symmetric_security_header.go +++ b/uasc/symmetric_security_header.go @@ -28,10 +28,9 @@ func (h *SymmetricSecurityHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *SymmetricSecurityHeader) Encode() ([]byte, error) { - buf := ua.NewBuffer(nil) - buf.WriteUint32(h.TokenID) - return buf.Bytes(), buf.Error() +func (h *SymmetricSecurityHeader) Encode(s *ua.Stream) error { + s.WriteUint32(h.TokenID) + return s.Error() } // String returns Header in string. From 8060b25b5d15fbd50d0e4e689ed46bee75bbd1b0 Mon Sep 17 00:00:00 2001 From: ethan256 Date: Sat, 18 May 2024 10:34:15 +0800 Subject: [PATCH 02/14] feat: use stream to reduce memory allocation during encode. --- ua/buffer.go | 123 ----------------------------------- ua/codec_test.go | 9 +-- ua/datatypes.go | 8 ++- ua/decode_test.go | 9 +-- ua/diagnostic_info.go | 2 +- ua/encode.go | 135 +++++++++++++++++--------------------- ua/expanded_node_id.go | 2 +- ua/extension_object.go | 5 +- ua/node_id.go | 2 +- ua/stream.go | 143 ----------------------------------------- ua/variant.go | 18 +++--- uacp/codec_test.go | 2 +- uacp/conn.go | 2 +- uasc/codec_test.go | 2 +- uasc/message.go | 29 +++++---- uasc/message_test.go | 4 +- 16 files changed, 108 insertions(+), 387 deletions(-) diff --git a/ua/buffer.go b/ua/buffer.go index d91b0699..a34ddaa4 100644 --- a/ua/buffer.go +++ b/ua/buffer.go @@ -9,8 +9,6 @@ import ( "io" "math" "time" - - "github.com/gopcua/opcua/errors" ) const ( @@ -203,124 +201,3 @@ func (b *Buffer) ReadN(n int) []byte { b.pos += n return d[:n] } - -func (b *Buffer) WriteBool(v bool) { - if v { - b.WriteUint8(1) - } else { - b.WriteUint8(0) - } -} - -func (b *Buffer) WriteByte(n byte) { - b.buf = append(b.buf, n) -} - -func (b *Buffer) WriteInt8(n int8) { - b.buf = append(b.buf, byte(n)) -} - -func (b *Buffer) WriteUint8(n uint8) { - b.buf = append(b.buf, byte(n)) -} - -func (b *Buffer) WriteInt16(n int16) { - b.WriteUint16(uint16(n)) -} - -func (b *Buffer) WriteUint16(n uint16) { - d := make([]byte, 2) - binary.LittleEndian.PutUint16(d, n) - b.Write(d) -} - -func (b *Buffer) WriteInt32(n int32) { - b.WriteUint32(uint32(n)) -} - -func (b *Buffer) WriteUint32(n uint32) { - d := make([]byte, 4) - binary.LittleEndian.PutUint32(d, n) - b.Write(d) -} - -func (b *Buffer) WriteInt64(n int64) { - b.WriteUint64(uint64(n)) -} - -func (b *Buffer) WriteUint64(n uint64) { - d := make([]byte, 8) - binary.LittleEndian.PutUint64(d, n) - b.Write(d) -} - -func (b *Buffer) WriteFloat32(n float32) { - if math.IsNaN(float64(n)) { - b.WriteUint32(f32qnan) - } else { - b.WriteUint32(math.Float32bits(n)) - } -} - -func (b *Buffer) WriteFloat64(n float64) { - if math.IsNaN(n) { - b.WriteUint64(f64qnan) - } else { - b.WriteUint64(math.Float64bits(n)) - } -} - -func (b *Buffer) WriteString(s string) { - if s == "" { - b.WriteUint32(null) - return - } - b.WriteByteString([]byte(s)) -} - -func (b *Buffer) WriteByteString(d []byte) { - if b.err != nil { - return - } - if len(d) > math.MaxInt32 { - b.err = errors.Errorf("value too large") - return - } - if d == nil { - b.WriteUint32(null) - return - } - b.WriteUint32(uint32(len(d))) - b.Write(d) -} - -func (b *Buffer) WriteStruct(w interface{}) { - if b.err != nil { - return - } - var d []byte - switch x := w.(type) { - case BinaryEncoder: - d, b.err = x.Encode() - default: - d, b.err = Encode(w) - } - b.Write(d) -} - -func (b *Buffer) WriteTime(v time.Time) { - d := make([]byte, 8) - if !v.IsZero() { - // encode time in "100 nanosecond intervals since January 1, 1601" - ts := uint64(v.UTC().UnixNano()/100 + 116444736000000000) - binary.LittleEndian.PutUint64(d, ts) - } - b.Write(d) -} - -func (b *Buffer) Write(d []byte) { - if b.err != nil { - return - } - b.buf = append(b.buf, d...) -} diff --git a/ua/codec_test.go b/ua/codec_test.go index 24fe252a..3ba7bd9c 100644 --- a/ua/codec_test.go +++ b/ua/codec_test.go @@ -53,11 +53,12 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - b, err := Encode(c.Struct) - if err != nil { - t.Fatal(err) + s := NewStream(DefaultBufSize) + s.WriteAny(c.Struct) + if s.Error() != nil { + t.Fatal(s.Error()) } - verify.Values(t, "", b, c.Bytes) + verify.Values(t, "", s.Bytes(), c.Bytes) }) }) } diff --git a/ua/datatypes.go b/ua/datatypes.go index 2aefc0f7..f7f75769 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -61,11 +61,11 @@ func (d *DataValue) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DataValue) Encode(s *Stream) { +func (d *DataValue) Encode(s *Stream) error { s.WriteUint8(d.EncodingMask) if d.Has(DataValueValue) { - s.WriteStruct(d.Value) + s.WriteAny(d.Value) } if d.Has(DataValueStatusCode) { s.WriteUint32(uint32(d.Status)) @@ -82,6 +82,7 @@ func (d *DataValue) Encode(s *Stream) { if d.Has(DataValueServerPicoseconds) { s.WriteUint16(d.ServerPicoseconds) } + return s.Error() } func (d *DataValue) Has(mask byte) bool { @@ -150,11 +151,12 @@ func (g *GUID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (g *GUID) Encode(s *Stream) { +func (g *GUID) Encode(s *Stream) error { s.WriteUint32(g.Data1) s.WriteUint16(g.Data2) s.WriteUint16(g.Data3) s.Write(g.Data4) + return s.Error() } // String returns GUID in human-readable string. diff --git a/ua/decode_test.go b/ua/decode_test.go index acdc04f0..9ef2b14a 100644 --- a/ua/decode_test.go +++ b/ua/decode_test.go @@ -403,11 +403,12 @@ func TestCodec(t *testing.T) { } }) t.Run("encode", func(t *testing.T) { - b, err := Encode(tt.v) - if err != nil { - t.Fatal(err) + s := NewStream(DefaultBufSize) + s.WriteAny(tt.v) + if s.Error() != nil { + t.Fatal(s.Error()) } - if got, want := b, tt.b; !bytes.Equal(got, want) { + if got, want := s.Bytes(), tt.b; !bytes.Equal(got, want) { t.Fatalf("\ngot %#v\nwant %#v", got, want) } }) diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index 9a873d07..850d897d 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -79,7 +79,7 @@ func (d *DiagnosticInfo) Encode(s *Stream) error { s.WriteUint32(uint32(d.InnerStatusCode)) } if d.Has(DiagnosticInfoInnerDiagnosticInfo) { - s.WriteStruct(d.InnerDiagnosticInfo) + s.WriteAny(d.InnerDiagnosticInfo) } return s.Error() } diff --git a/ua/encode.go b/ua/encode.go index 0b17fec6..83c13add 100644 --- a/ua/encode.go +++ b/ua/encode.go @@ -5,7 +5,6 @@ package ua import ( - "encoding/hex" "fmt" "math" "reflect" @@ -15,23 +14,13 @@ import ( "github.com/gopcua/opcua/errors" ) -type ValEncoder interface { - Encode(s *Stream) error -} - -var valEncoder = reflect.TypeOf((*ValEncoder)(nil)).Elem() - -func isValEncoder(val reflect.Value) bool { - return val.Type().Implements(valEncoder) -} - // debugCodec enables printing of debug messages in the opcua codec. var debugCodec = debug.FlagSet("codec") // BinaryEncoder is the interface implemented by an object that can // marshal itself into a binary OPC/UA representation. type BinaryEncoder interface { - Encode() ([]byte, error) + Encode(s *Stream) error } var binaryEncoder = reflect.TypeOf((*BinaryEncoder)(nil)).Elem() @@ -40,150 +29,142 @@ func isBinaryEncoder(val reflect.Value) bool { return val.Type().Implements(binaryEncoder) } -func Encode(v interface{}) ([]byte, error) { - val := reflect.ValueOf(v) - return encode(val, val.Type().String()) +func (s *Stream) WriteAny(w interface{}) { + if s.err != nil { + return + } + val := reflect.ValueOf(w) + switch x := w.(type) { + case BinaryEncoder: + s.err = x.Encode(s) + default: + s.err = s.encode(val, val.Type().String()) + } } -func encode(val reflect.Value, name string) ([]byte, error) { +func (s *Stream) encode(val reflect.Value, name string) error { if debugCodec { fmt.Printf("encode: %s has type %s and is a %s\n", name, val.Type(), val.Type().Kind()) } - buf := NewBuffer(nil) switch { case isBinaryEncoder(val): v := val.Interface().(BinaryEncoder) - return dump(v.Encode()) + return v.Encode(s) case isTime(val): - buf.WriteTime(val.Convert(timeType).Interface().(time.Time)) + s.WriteTime(val.Convert(timeType).Interface().(time.Time)) default: switch val.Kind() { case reflect.Bool: - buf.WriteBool(val.Bool()) + s.WriteBool(val.Bool()) case reflect.Int8: - buf.WriteInt8(int8(val.Int())) + s.WriteInt8(int8(val.Int())) case reflect.Uint8: - buf.WriteUint8(uint8(val.Uint())) + s.WriteUint8(uint8(val.Uint())) case reflect.Int16: - buf.WriteInt16(int16(val.Int())) + s.WriteInt16(int16(val.Int())) case reflect.Uint16: - buf.WriteUint16(uint16(val.Uint())) + s.WriteUint16(uint16(val.Uint())) case reflect.Int32: - buf.WriteInt32(int32(val.Int())) + s.WriteInt32(int32(val.Int())) case reflect.Uint32: - buf.WriteUint32(uint32(val.Uint())) + s.WriteUint32(uint32(val.Uint())) case reflect.Int64: - buf.WriteInt64(int64(val.Int())) + s.WriteInt64(int64(val.Int())) case reflect.Uint64: - buf.WriteUint64(uint64(val.Uint())) + s.WriteUint64(uint64(val.Uint())) case reflect.Float32: - buf.WriteFloat32(float32(val.Float())) + s.WriteFloat32(float32(val.Float())) case reflect.Float64: - buf.WriteFloat64(float64(val.Float())) + s.WriteFloat64(float64(val.Float())) case reflect.String: - buf.WriteString(val.String()) + s.WriteString(val.String()) case reflect.Ptr: if val.IsNil() { - return nil, nil + return nil } - return dump(encode(val.Elem(), name)) + return s.encode(val.Elem(), name) case reflect.Struct: - return dump(writeStruct(val, name)) + return s.writeStruct(val, name) case reflect.Slice: - return dump(writeSlice(val, name)) + return s.writeSlice(val, name) case reflect.Array: - return dump(writeArray(val, name)) + return s.writeArray(val, name) default: - return nil, errors.Errorf("unsupported type: %s", val.Type()) + return errors.Errorf("unsupported type: %s", val.Type()) } } - return dump(buf.Bytes(), buf.Error()) + return s.Error() } -func dump(b []byte, err error) ([]byte, error) { - if debugCodec { - fmt.Printf("%s---\n", hex.Dump(b)) - } - return b, err -} - -func writeStruct(val reflect.Value, name string) ([]byte, error) { - var buf []byte +func (s *Stream) writeStruct(val reflect.Value, name string) error { valt := val.Type() for i := 0; i < val.NumField(); i++ { ft := valt.Field(i) fname := name + "." + ft.Name - b, err := encode(val.Field(i), fname) - if err != nil { - return nil, err + if err := s.encode(val.Field(i), fname); err != nil { + return err } - buf = append(buf, b...) } - return buf, nil + return nil } -func writeSlice(val reflect.Value, name string) ([]byte, error) { - buf := NewBuffer(nil) +func (s *Stream) writeSlice(val reflect.Value, name string) error { if val.IsNil() { - buf.WriteUint32(null) - return buf.Bytes(), buf.Error() + s.WriteUint32(null) + return s.Error() } if val.Len() > math.MaxInt32 { - return nil, errors.Errorf("array too large") + return errors.Errorf("array too large") } - buf.WriteUint32(uint32(val.Len())) + s.WriteUint32(uint32(val.Len())) // fast path for []byte if val.Type().Elem().Kind() == reflect.Uint8 { // fmt.Println("[]byte fast path") - buf.Write(val.Bytes()) - return buf.Bytes(), buf.Error() + s.Write(val.Bytes()) + return s.Error() } // loop over elements for i := 0; i < val.Len(); i++ { ename := fmt.Sprintf("%s[%d]", name, i) - b, err := encode(val.Index(i), ename) - if err != nil { - return nil, err + s.encode(val.Index(i), ename) + if s.Error() != nil { + return s.Error() } - buf.Write(b) } - return buf.Bytes(), buf.Error() + return s.Error() } -func writeArray(val reflect.Value, name string) ([]byte, error) { - buf := NewBuffer(nil) - +func (s *Stream) writeArray(val reflect.Value, name string) error { if val.Len() > math.MaxInt32 { - return nil, errors.Errorf("array too large: %d > %d", val.Len(), math.MaxInt32) + return errors.Errorf("array too large: %d > %d", val.Len(), math.MaxInt32) } - buf.WriteUint32(uint32(val.Len())) + s.WriteUint32(uint32(val.Len())) // fast path for []byte if val.Type().Elem().Kind() == reflect.Uint8 { // fmt.Println("encode: []byte fast path") b := make([]byte, val.Len()) reflect.Copy(reflect.ValueOf(b), val) - buf.Write(b) - return buf.Bytes(), buf.Error() + s.Write(b) + return s.Error() } // loop over elements // we write all the elements, also the zero values for i := 0; i < val.Len(); i++ { ename := fmt.Sprintf("%s[%d]", name, i) - b, err := encode(val.Index(i), ename) - if err != nil { - return nil, err + s.encode(val.Index(i), ename) + if s.Error() != nil { + return s.Error() } - buf.Write(b) } - return buf.Bytes(), buf.Error() + return s.Error() } diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index 925bb450..121203ff 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -103,7 +103,7 @@ func (e *ExpandedNodeID) Decode(b []byte) (int, error) { } func (e *ExpandedNodeID) Encode(s *Stream) error { - s.WriteStruct(e.NodeID) + s.WriteAny(e.NodeID) if e.HasNamespaceURI() { s.WriteString(e.NamespaceURI) } diff --git a/ua/extension_object.go b/ua/extension_object.go index 064e9f04..cae66a56 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -88,15 +88,14 @@ func (e *ExtensionObject) Encode(s *Stream) error { if e == nil { e = &ExtensionObject{TypeID: NewTwoByteExpandedNodeID(0), EncodingMask: ExtensionObjectEmpty} } - s.WriteStruct(e.TypeID) + s.WriteAny(e.TypeID) s.WriteByte(e.EncodingMask) if e.EncodingMask == ExtensionObjectEmpty { return s.Error() } - // TODO: use pool? body := NewStream(DefaultBufSize) - body.WriteStruct(e.Value) + body.WriteAny(e.Value) if body.Error() != nil { return body.Error() } diff --git a/ua/node_id.go b/ua/node_id.go index 50cb7de3..b2a98c1b 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -380,7 +380,7 @@ func (n *NodeID) Encode(s *Stream) error { s.WriteUint32(n.nid) case NodeIDTypeGUID: s.WriteUint16(n.ns) - s.WriteStruct(n.gid) + s.WriteAny(n.gid) case NodeIDTypeByteString, NodeIDTypeString: s.WriteUint16(n.ns) s.WriteByteString(n.bid) diff --git a/ua/stream.go b/ua/stream.go index d14b9aeb..e9c59778 100644 --- a/ua/stream.go +++ b/ua/stream.go @@ -2,10 +2,8 @@ package ua import ( "encoding/binary" - "fmt" "io" "math" - "reflect" "time" "github.com/gopcua/opcua/errors" @@ -36,7 +34,6 @@ func (s *Stream) Len() int { func (s *Stream) Reset() { s.buf = s.buf[:0] s.pos = 0 - s.err = nil } func (s *Stream) Bytes() []byte { @@ -162,143 +159,3 @@ func (b *Stream) Write(d []byte) { } b.buf = append(b.buf, d...) } - -func (s *Stream) WriteStruct(w interface{}) { - if s.err != nil { - return - } - val := reflect.ValueOf(w) - switch x := w.(type) { - case ValEncoder: - s.err = x.Encode(s) - default: - s.err = s.encode(val, val.Type().String()) - } -} - -func (s *Stream) encode(val reflect.Value, name string) error { - if debugCodec { - fmt.Printf("encode: %s has type %s and is a %s\n", name, val.Type(), val.Type().Kind()) - } - - switch { - case isValEncoder(val): - v := val.Interface().(ValEncoder) - return v.Encode(s) - - case isTime(val): - s.WriteTime(val.Convert(timeType).Interface().(time.Time)) - - default: - switch val.Kind() { - case reflect.Bool: - s.WriteBool(val.Bool()) - case reflect.Int8: - s.WriteInt8(int8(val.Int())) - case reflect.Uint8: - s.WriteUint8(uint8(val.Uint())) - case reflect.Int16: - s.WriteInt16(int16(val.Int())) - case reflect.Uint16: - s.WriteUint16(uint16(val.Uint())) - case reflect.Int32: - s.WriteInt32(int32(val.Int())) - case reflect.Uint32: - s.WriteUint32(uint32(val.Uint())) - case reflect.Int64: - s.WriteInt64(int64(val.Int())) - case reflect.Uint64: - s.WriteUint64(uint64(val.Uint())) - case reflect.Float32: - s.WriteFloat32(float32(val.Float())) - case reflect.Float64: - s.WriteFloat64(float64(val.Float())) - case reflect.String: - s.WriteString(val.String()) - case reflect.Ptr: - if val.IsNil() { - return nil - } - return s.encode(val.Elem(), name) - case reflect.Struct: - return s.writeStruct(val, name) - case reflect.Slice: - return s.writeSlice(val, name) - case reflect.Array: - return s.writeArray(val, name) - default: - return errors.Errorf("unsupported type: %s", val.Type()) - } - } - return s.Error() -} - -func (s *Stream) writeStruct(val reflect.Value, name string) error { - valt := val.Type() - for i := 0; i < val.NumField(); i++ { - ft := valt.Field(i) - fname := name + "." + ft.Name - if err := s.encode(val.Field(i), fname); err != nil { - return err - } - } - return nil -} - -func (s *Stream) writeSlice(val reflect.Value, name string) error { - if val.IsNil() { - s.WriteUint32(null) - return s.Error() - } - - if val.Len() > math.MaxInt32 { - return errors.Errorf("array too large") - } - - s.WriteUint32(uint32(val.Len())) - - // fast path for []byte - if val.Type().Elem().Kind() == reflect.Uint8 { - // fmt.Println("[]byte fast path") - s.Write(val.Bytes()) - return s.Error() - } - - // loop over elements - for i := 0; i < val.Len(); i++ { - ename := fmt.Sprintf("%s[%d]", name, i) - s.encode(val.Index(i), ename) - if s.Error() != nil { - return s.Error() - } - } - return s.Error() -} - -func (s *Stream) writeArray(val reflect.Value, name string) error { - if val.Len() > math.MaxInt32 { - return errors.Errorf("array too large: %d > %d", val.Len(), math.MaxInt32) - } - - s.WriteUint32(uint32(val.Len())) - - // fast path for []byte - if val.Type().Elem().Kind() == reflect.Uint8 { - // fmt.Println("encode: []byte fast path") - b := make([]byte, val.Len()) - reflect.Copy(reflect.ValueOf(b), val) - s.Write(b) - return s.Error() - } - - // loop over elements - // we write all the elements, also the zero values - for i := 0; i < val.Len(); i++ { - ename := fmt.Sprintf("%s[%d]", name, i) - s.encode(val.Index(i), ename) - if s.Error() != nil { - return s.Error() - } - } - return s.Error() -} diff --git a/ua/variant.go b/ua/variant.go index be6b7db9..ec8cabcf 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -383,29 +383,29 @@ func (m *Variant) encodeValue(s *Stream, v interface{}) { case time.Time: s.WriteTime(x) case *GUID: - s.WriteStruct(x) + s.WriteAny(x) case []byte: s.WriteByteString(x) case XMLElement: s.WriteString(string(x)) case *NodeID: - s.WriteStruct(x) + s.WriteAny(x) case *ExpandedNodeID: - s.WriteStruct(x) + s.WriteAny(x) case StatusCode: s.WriteUint32(uint32(x)) case *QualifiedName: - s.WriteStruct(x) + s.WriteAny(x) case *LocalizedText: - s.WriteStruct(x) + s.WriteAny(x) case *ExtensionObject: - s.WriteStruct(x) + s.WriteAny(x) case *DataValue: - s.WriteStruct(x) + s.WriteAny(x) case *Variant: - s.WriteStruct(x) + s.WriteAny(x) case *DiagnosticInfo: - s.WriteStruct(x) + s.WriteAny(x) } } diff --git a/uacp/codec_test.go b/uacp/codec_test.go index 98f5669b..b8a590f9 100644 --- a/uacp/codec_test.go +++ b/uacp/codec_test.go @@ -55,7 +55,7 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { t.Run("encode", func(t *testing.T) { s := ua.NewStream(ua.DefaultBufSize) - s.WriteStruct(c.Struct) + s.WriteAny(c.Struct) if s.Error() != nil { t.Fatalf("fail to encode message, err: %v", s.Error()) } diff --git a/uacp/conn.go b/uacp/conn.go index d8dfe2a8..32c86e4e 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -397,7 +397,7 @@ func (c *Conn) Send(typ string, msg interface{}) error { } bodyStream := ua.NewStream(ua.DefaultBufSize) - bodyStream.WriteStruct(msg) + bodyStream.WriteAny(msg) if bodyStream.Error() != nil { return errors.Errorf("encode msg failed: %s", bodyStream.Error()) } diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 2ec99e80..9bd52826 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -55,7 +55,7 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { t.Run("encode", func(t *testing.T) { s := ua.NewStream(ua.DefaultBufSize) - s.WriteStruct(c.Struct) + s.WriteAny(c.Struct) if s.Error() != nil { t.Fatal(s.Error()) } diff --git a/uasc/message.go b/uasc/message.go index b144e4f1..7e68b831 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -122,8 +122,8 @@ func (m *Message) Encode(s *ua.Stream) error { func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { dataBody := ua.NewStream(ua.DefaultBufSize) - dataBody.WriteStruct(m.TypeID) - dataBody.WriteStruct(m.Service) + dataBody.WriteAny(m.TypeID) + dataBody.WriteAny(m.Service) if dataBody.Error() != nil { return nil, dataBody.Error() @@ -135,8 +135,8 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { switch m.Header.MessageType { case "OPN": partialHeader := ua.NewStream(ua.DefaultBufSize) - partialHeader.WriteStruct(m.AsymmetricSecurityHeader) - partialHeader.WriteStruct(m.SequenceHeader) + partialHeader.WriteAny(m.AsymmetricSecurityHeader) + partialHeader.WriteAny(m.SequenceHeader) if partialHeader.Error() != nil { return nil, partialHeader.Error() @@ -144,11 +144,12 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { m.Header.MessageSize = uint32(12 + partialHeader.Len() + dataBody.Len()) buf := ua.NewStream(ua.DefaultBufSize) - buf.WriteStruct(m.Header) + buf.WriteAny(m.Header) buf.Write(partialHeader.Bytes()) buf.Write(dataBody.Bytes()) - return [][]byte{buf.Bytes()}, buf.Error() + b := append([]byte(nil), buf.Bytes()...) + return [][]byte{b}, buf.Error() case "CLO", "MSG": @@ -156,29 +157,29 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { m.Header.MessageSize = maxBodySize + 24 m.Header.ChunkType = ChunkTypeIntermediate chunk := ua.NewStream(ua.DefaultBufSize) - chunk.WriteStruct(m.Header) - chunk.WriteStruct(m.SymmetricSecurityHeader) - chunk.WriteStruct(m.SequenceHeader) + chunk.WriteAny(m.Header) + chunk.WriteAny(m.SymmetricSecurityHeader) + chunk.WriteAny(m.SequenceHeader) chunk.Write(dataBody.ReadN(int(maxBodySize))) if chunk.Error() != nil { return nil, chunk.Error() } - chunks[i] = chunk.Bytes() + chunks[i] = append(chunks[i], chunk.Bytes()...) } m.Header.ChunkType = ChunkTypeFinal m.Header.MessageSize = uint32(24 + dataBody.Len()) chunk := ua.NewStream(ua.DefaultBufSize) - chunk.WriteStruct(m.Header) - chunk.WriteStruct(m.SymmetricSecurityHeader) - chunk.WriteStruct(m.SequenceHeader) + chunk.WriteAny(m.Header) + chunk.WriteAny(m.SymmetricSecurityHeader) + chunk.WriteAny(m.SequenceHeader) chunk.Write(dataBody.Bytes()) if chunk.Error() != nil { return nil, chunk.Error() } - chunks[nrChunks-1] = chunk.Bytes() + chunks[nrChunks-1] = append(chunks[nrChunks-1], chunk.Bytes()...) return chunks, nil default: return nil, errors.Errorf("invalid message type %q", m.Header.MessageType) diff --git a/uasc/message_test.go b/uasc/message_test.go index 2e10e5bd..4280b826 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gopcua/opcua/id" + "github.com/pascaldekloe/goe/verify" "github.com/gopcua/opcua/ua" ) @@ -487,10 +488,11 @@ func BenchmarkEncodeMessage(b *testing.B) { s := ua.NewStream(ua.DefaultBufSize) for i := 0; i < b.N; i++ { for _, tc := range cases { - s.WriteStruct(tc.Struct) + s.WriteAny(tc.Struct) if s.Error() != nil { b.Fatalf("fail to encode message, err: %v", s.Error()) } + verify.Values(b, "", s.Bytes(), tc.Bytes) s.Reset() } } From a2588ecd1c3c88a8ff0bb294b443e5100b57a128 Mon Sep 17 00:00:00 2001 From: ethan256 Date: Sat, 18 May 2024 17:12:48 +0800 Subject: [PATCH 03/14] feat: add streamPool --- ua/encode.go | 17 +++++++++++++++++ ua/extension_object.go | 3 ++- uacp/conn.go | 6 ++++-- uasc/message.go | 16 +++++++++++----- 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/ua/encode.go b/ua/encode.go index 83c13add..6f184fdd 100644 --- a/ua/encode.go +++ b/ua/encode.go @@ -8,12 +8,29 @@ import ( "fmt" "math" "reflect" + "sync" "time" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" ) +var streamPool sync.Pool + +func BorrowStream() *Stream { + v, ok := streamPool.Get().(*Stream) + if ok { + v.Reset() + return v + } + + return NewStream(DefaultBufSize) +} + +func ReturnStream(s *Stream) { + streamPool.Put(s) +} + // debugCodec enables printing of debug messages in the opcua codec. var debugCodec = debug.FlagSet("codec") diff --git a/ua/extension_object.go b/ua/extension_object.go index cae66a56..5b6e04e9 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -94,7 +94,8 @@ func (e *ExtensionObject) Encode(s *Stream) error { return s.Error() } - body := NewStream(DefaultBufSize) + body := BorrowStream() + defer ReturnStream(body) body.WriteAny(e.Value) if body.Error() != nil { return body.Error() diff --git a/uacp/conn.go b/uacp/conn.go index 32c86e4e..302b560b 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -396,7 +396,8 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("invalid msg type: %s", typ) } - bodyStream := ua.NewStream(ua.DefaultBufSize) + bodyStream := ua.BorrowStream() + ua.ReturnStream(bodyStream) bodyStream.WriteAny(msg) if bodyStream.Error() != nil { return errors.Errorf("encode msg failed: %s", bodyStream.Error()) @@ -412,7 +413,8 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("send packet too large: %d > %d bytes", h.MessageSize, c.ack.SendBufSize) } - headerStream := ua.NewStream(ua.DefaultBufSize) + headerStream := ua.BorrowStream() + defer ua.ReturnStream(headerStream) err := h.Encode(headerStream) if err != nil { return errors.Errorf("encode hdr failed: %s", err) diff --git a/uasc/message.go b/uasc/message.go index 7e68b831..be961add 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -121,7 +121,8 @@ func (m *Message) Encode(s *ua.Stream) error { } func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { - dataBody := ua.NewStream(ua.DefaultBufSize) + dataBody := ua.BorrowStream() + defer ua.ReturnStream(dataBody) dataBody.WriteAny(m.TypeID) dataBody.WriteAny(m.Service) @@ -134,7 +135,8 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { switch m.Header.MessageType { case "OPN": - partialHeader := ua.NewStream(ua.DefaultBufSize) + partialHeader := ua.BorrowStream() + defer ua.ReturnStream(partialHeader) partialHeader.WriteAny(m.AsymmetricSecurityHeader) partialHeader.WriteAny(m.SequenceHeader) @@ -143,7 +145,8 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { } m.Header.MessageSize = uint32(12 + partialHeader.Len() + dataBody.Len()) - buf := ua.NewStream(ua.DefaultBufSize) + buf := ua.BorrowStream() + ua.ReturnStream(buf) buf.WriteAny(m.Header) buf.Write(partialHeader.Bytes()) buf.Write(dataBody.Bytes()) @@ -156,21 +159,24 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { for i := uint32(0); i < nrChunks-1; i++ { m.Header.MessageSize = maxBodySize + 24 m.Header.ChunkType = ChunkTypeIntermediate - chunk := ua.NewStream(ua.DefaultBufSize) + chunk := ua.BorrowStream() chunk.WriteAny(m.Header) chunk.WriteAny(m.SymmetricSecurityHeader) chunk.WriteAny(m.SequenceHeader) chunk.Write(dataBody.ReadN(int(maxBodySize))) if chunk.Error() != nil { + ua.ReturnStream(chunk) return nil, chunk.Error() } chunks[i] = append(chunks[i], chunk.Bytes()...) + ua.ReturnStream(chunk) } m.Header.ChunkType = ChunkTypeFinal m.Header.MessageSize = uint32(24 + dataBody.Len()) - chunk := ua.NewStream(ua.DefaultBufSize) + chunk := ua.BorrowStream() + defer ua.ReturnStream(chunk) chunk.WriteAny(m.Header) chunk.WriteAny(m.SymmetricSecurityHeader) chunk.WriteAny(m.SequenceHeader) From 35ce5d434a869f3ec8077e699290021f6e572da8 Mon Sep 17 00:00:00 2001 From: ethan256 Date: Sun, 19 May 2024 16:00:33 +0800 Subject: [PATCH 04/14] fix: fix some bug --- uacp/conn.go | 2 +- uasc/message.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/uacp/conn.go b/uacp/conn.go index 302b560b..e7bffabe 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -397,7 +397,7 @@ func (c *Conn) Send(typ string, msg interface{}) error { } bodyStream := ua.BorrowStream() - ua.ReturnStream(bodyStream) + defer ua.ReturnStream(bodyStream) bodyStream.WriteAny(msg) if bodyStream.Error() != nil { return errors.Errorf("encode msg failed: %s", bodyStream.Error()) diff --git a/uasc/message.go b/uasc/message.go index be961add..caf686bf 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -146,12 +146,12 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { m.Header.MessageSize = uint32(12 + partialHeader.Len() + dataBody.Len()) buf := ua.BorrowStream() - ua.ReturnStream(buf) buf.WriteAny(m.Header) buf.Write(partialHeader.Bytes()) buf.Write(dataBody.Bytes()) b := append([]byte(nil), buf.Bytes()...) + ua.ReturnStream(buf) return [][]byte{b}, buf.Error() case "CLO", "MSG": From fd36b8c6dc827e310eb584d67eda4ebd2307f10c Mon Sep 17 00:00:00 2001 From: yuanliang Date: Mon, 20 May 2024 14:56:20 +0800 Subject: [PATCH 05/14] fix: avoid memory leak --- ua/datatypes.go | 9 ++--- ua/diagnostic_info.go | 3 +- ua/encode.go | 58 +++++++++++++++--------------- ua/expanded_node_id.go | 3 +- ua/extension_object.go | 8 ++--- ua/node_id.go | 5 ++- ua/stream.go | 4 +++ ua/variant.go | 6 ++-- uacp/conn.go | 10 +++--- uacp/conn_test.go | 6 ++-- uacp/uacp.go | 21 +++++------ uasc/asymmetric_security_header.go | 3 +- uasc/codec_test.go | 3 +- uasc/header.go | 6 ++-- uasc/message.go | 39 +++++++++----------- uasc/message_test.go | 6 ++-- uasc/sequence_header.go | 3 +- uasc/symmetric_security_header.go | 3 +- 18 files changed, 88 insertions(+), 108 deletions(-) diff --git a/ua/datatypes.go b/ua/datatypes.go index f7f75769..aa5e12ea 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -61,7 +61,7 @@ func (d *DataValue) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DataValue) Encode(s *Stream) error { +func (d *DataValue) Encode(s *Stream) { s.WriteUint8(d.EncodingMask) if d.Has(DataValueValue) { @@ -82,7 +82,6 @@ func (d *DataValue) Encode(s *Stream) error { if d.Has(DataValueServerPicoseconds) { s.WriteUint16(d.ServerPicoseconds) } - return s.Error() } func (d *DataValue) Has(mask byte) bool { @@ -151,12 +150,11 @@ func (g *GUID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (g *GUID) Encode(s *Stream) error { +func (g *GUID) Encode(s *Stream) { s.WriteUint32(g.Data1) s.WriteUint16(g.Data2) s.WriteUint16(g.Data3) s.Write(g.Data4) - return s.Error() } // String returns GUID in human-readable string. @@ -225,7 +223,7 @@ func (l *LocalizedText) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (l *LocalizedText) Encode(s *Stream) error { +func (l *LocalizedText) Encode(s *Stream) { s.WriteUint8(l.EncodingMask) if l.Has(LocalizedTextLocale) { s.WriteString(l.Locale) @@ -233,7 +231,6 @@ func (l *LocalizedText) Encode(s *Stream) error { if l.Has(LocalizedTextText) { s.WriteString(l.Text) } - return s.Error() } func (l *LocalizedText) Has(mask byte) bool { diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index 850d897d..b72d46df 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -58,7 +58,7 @@ func (d *DiagnosticInfo) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DiagnosticInfo) Encode(s *Stream) error { +func (d *DiagnosticInfo) Encode(s *Stream) { s.WriteByte(d.EncodingMask) if d.Has(DiagnosticInfoSymbolicID) { s.WriteInt32(d.SymbolicID) @@ -81,7 +81,6 @@ func (d *DiagnosticInfo) Encode(s *Stream) error { if d.Has(DiagnosticInfoInnerDiagnosticInfo) { s.WriteAny(d.InnerDiagnosticInfo) } - return s.Error() } func (d *DiagnosticInfo) Has(mask byte) bool { diff --git a/ua/encode.go b/ua/encode.go index 6f184fdd..2a58552a 100644 --- a/ua/encode.go +++ b/ua/encode.go @@ -15,16 +15,16 @@ import ( "github.com/gopcua/opcua/errors" ) -var streamPool sync.Pool +var streamPool sync.Pool = sync.Pool{ + New: func() interface{} { + return NewStream(DefaultBufSize) + }, +} func BorrowStream() *Stream { - v, ok := streamPool.Get().(*Stream) - if ok { - v.Reset() - return v - } - - return NewStream(DefaultBufSize) + v := streamPool.Get().(*Stream) + v.Reset() + return v } func ReturnStream(s *Stream) { @@ -37,7 +37,7 @@ var debugCodec = debug.FlagSet("codec") // BinaryEncoder is the interface implemented by an object that can // marshal itself into a binary OPC/UA representation. type BinaryEncoder interface { - Encode(s *Stream) error + Encode(s *Stream) } var binaryEncoder = reflect.TypeOf((*BinaryEncoder)(nil)).Elem() @@ -53,13 +53,13 @@ func (s *Stream) WriteAny(w interface{}) { val := reflect.ValueOf(w) switch x := w.(type) { case BinaryEncoder: - s.err = x.Encode(s) + x.Encode(s) default: - s.err = s.encode(val, val.Type().String()) + s.encode(val, val.Type().String()) } } -func (s *Stream) encode(val reflect.Value, name string) error { +func (s *Stream) encode(val reflect.Value, name string) { if debugCodec { fmt.Printf("encode: %s has type %s and is a %s\n", name, val.Type(), val.Type().Kind()) } @@ -67,7 +67,7 @@ func (s *Stream) encode(val reflect.Value, name string) error { switch { case isBinaryEncoder(val): v := val.Interface().(BinaryEncoder) - return v.Encode(s) + v.Encode(s) case isTime(val): s.WriteTime(val.Convert(timeType).Interface().(time.Time)) @@ -100,42 +100,41 @@ func (s *Stream) encode(val reflect.Value, name string) error { s.WriteString(val.String()) case reflect.Ptr: if val.IsNil() { - return nil + return } - return s.encode(val.Elem(), name) + s.encode(val.Elem(), name) case reflect.Struct: - return s.writeStruct(val, name) + s.writeStruct(val, name) case reflect.Slice: - return s.writeSlice(val, name) + s.writeSlice(val, name) case reflect.Array: - return s.writeArray(val, name) + s.writeArray(val, name) default: - return errors.Errorf("unsupported type: %s", val.Type()) + s.WrapError(errors.Errorf("unsupported type: %s", val.Type())) } } - return s.Error() } -func (s *Stream) writeStruct(val reflect.Value, name string) error { +func (s *Stream) writeStruct(val reflect.Value, name string) { valt := val.Type() for i := 0; i < val.NumField(); i++ { ft := valt.Field(i) fname := name + "." + ft.Name - if err := s.encode(val.Field(i), fname); err != nil { - return err + if s.encode(val.Field(i), fname); s.err != nil { + return } } - return nil } -func (s *Stream) writeSlice(val reflect.Value, name string) error { +func (s *Stream) writeSlice(val reflect.Value, name string) { if val.IsNil() { s.WriteUint32(null) - return s.Error() + return } if val.Len() > math.MaxInt32 { - return errors.Errorf("array too large") + s.WrapError(errors.Errorf("array too large")) + return } s.WriteUint32(uint32(val.Len())) @@ -144,7 +143,7 @@ func (s *Stream) writeSlice(val reflect.Value, name string) error { if val.Type().Elem().Kind() == reflect.Uint8 { // fmt.Println("[]byte fast path") s.Write(val.Bytes()) - return s.Error() + return } // loop over elements @@ -152,10 +151,9 @@ func (s *Stream) writeSlice(val reflect.Value, name string) error { ename := fmt.Sprintf("%s[%d]", name, i) s.encode(val.Index(i), ename) if s.Error() != nil { - return s.Error() + return } } - return s.Error() } func (s *Stream) writeArray(val reflect.Value, name string) error { diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index 121203ff..c0e7ee3c 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -102,7 +102,7 @@ func (e *ExpandedNodeID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *ExpandedNodeID) Encode(s *Stream) error { +func (e *ExpandedNodeID) Encode(s *Stream) { s.WriteAny(e.NodeID) if e.HasNamespaceURI() { s.WriteString(e.NamespaceURI) @@ -110,7 +110,6 @@ func (e *ExpandedNodeID) Encode(s *Stream) error { if e.HasServerIndex() { s.WriteUint32(e.ServerIndex) } - return s.Error() } // HasNamespaceURI checks if an ExpandedNodeID has NamespaceURI Flag. diff --git a/ua/extension_object.go b/ua/extension_object.go index 5b6e04e9..bfe7e15b 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -84,25 +84,25 @@ func (e *ExtensionObject) Decode(b []byte) (int, error) { return buf.Pos(), body.Error() } -func (e *ExtensionObject) Encode(s *Stream) error { +func (e *ExtensionObject) Encode(s *Stream) { if e == nil { e = &ExtensionObject{TypeID: NewTwoByteExpandedNodeID(0), EncodingMask: ExtensionObjectEmpty} } s.WriteAny(e.TypeID) s.WriteByte(e.EncodingMask) if e.EncodingMask == ExtensionObjectEmpty { - return s.Error() + return } body := BorrowStream() defer ReturnStream(body) body.WriteAny(e.Value) if body.Error() != nil { - return body.Error() + s.WrapError(body.Error()) + return } s.WriteUint32(uint32(body.Len())) s.Write(body.Bytes()) - return s.Error() } func (e *ExtensionObject) UpdateMask() { diff --git a/ua/node_id.go b/ua/node_id.go index b2a98c1b..9584b1b5 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -366,7 +366,7 @@ func (n *NodeID) Decode(b []byte) (int, error) { } } -func (n *NodeID) Encode(s *Stream) error { +func (n *NodeID) Encode(s *Stream) { s.WriteByte(byte(n.mask)) switch n.Type() { @@ -385,9 +385,8 @@ func (n *NodeID) Encode(s *Stream) error { s.WriteUint16(n.ns) s.WriteByteString(n.bid) default: - return errors.Errorf("invalid node id type %v", n.Type()) + s.err = errors.Errorf("invalid node id type %v", n.Type()) } - return s.Error() } func (n *NodeID) MarshalJSON() ([]byte, error) { diff --git a/ua/stream.go b/ua/stream.go index e9c59778..e6cd31d7 100644 --- a/ua/stream.go +++ b/ua/stream.go @@ -23,6 +23,10 @@ func NewStream(size int) *Stream { } } +func (s *Stream) WrapError(err error) { + s.err = errors.Join(err) +} + func (s *Stream) Error() error { return s.err } diff --git a/ua/variant.go b/ua/variant.go index ec8cabcf..bb29604c 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -318,12 +318,12 @@ func (m *Variant) decodeValue(buf *Buffer) interface{} { } // Encode implements the codec interface. -func (m *Variant) Encode(s *Stream) error { +func (m *Variant) Encode(s *Stream) { s.WriteByte(m.mask) // a null value specifies that no other fields are encoded if m.Type() == TypeIDNull { - return s.Error() + return } if m.Has(VariantArrayValues) { @@ -338,8 +338,6 @@ func (m *Variant) Encode(s *Stream) error { s.WriteInt32(m.arrayDimensions[i]) } } - - return s.Error() } // encode recursively writes the values to the buffer. diff --git a/uacp/conn.go b/uacp/conn.go index e7bffabe..ab500ebd 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -415,12 +415,14 @@ func (c *Conn) Send(typ string, msg interface{}) error { headerStream := ua.BorrowStream() defer ua.ReturnStream(headerStream) - err := h.Encode(headerStream) - if err != nil { - return errors.Errorf("encode hdr failed: %s", err) + h.Encode(headerStream) + if headerStream.Error() != nil { + return errors.Errorf("encode hdr failed: %s", headerStream.Error()) } - b := append(headerStream.Bytes(), bodyStream.Bytes()...) + b := make([]byte, 0, headerStream.Len()+bodyStream.Len()) + b = append(b, headerStream.Bytes()...) + b = append(b, bodyStream.Bytes()...) if _, err := c.Write(b); err != nil { return errors.Errorf("write failed: %s", err) } diff --git a/uacp/conn_test.go b/uacp/conn_test.go index f4fdffa3..5620569e 100644 --- a/uacp/conn_test.go +++ b/uacp/conn_test.go @@ -122,9 +122,9 @@ NEXT: got = got[hdrlen:] s := ua.NewStream(ua.DefaultBufSize) - err = msg.Encode(s) - if err != nil { - t.Fatal(err) + msg.Encode(s) + if s.Error() != nil { + t.Fatal(s.Error()) } want := s.Bytes() verify.Values(t, "", got, want) diff --git a/uacp/uacp.go b/uacp/uacp.go index 1aa737a8..2ae23826 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -45,14 +45,14 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) Encode(s *ua.Stream) error { +func (h *Header) Encode(s *ua.Stream) { if len(h.MessageType) != 3 { - return errors.Errorf("invalid message type: %q", h.MessageType) + s.WrapError(errors.Errorf("invalid message type: %q", h.MessageType)) + return } s.Write([]byte(h.MessageType)) s.WriteByte(h.ChunkType) s.WriteUint32(h.MessageSize) - return s.Error() } // Hello represents a OPC UA Hello. @@ -78,14 +78,13 @@ func (h *Hello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Hello) Encode(s *ua.Stream) error { +func (h *Hello) Encode(s *ua.Stream) { s.WriteUint32(h.Version) s.WriteUint32(h.ReceiveBufSize) s.WriteUint32(h.SendBufSize) s.WriteUint32(h.MaxMessageSize) s.WriteUint32(h.MaxChunkCount) s.WriteString(h.EndpointURL) - return s.Error() } // Acknowledge represents a OPC UA Acknowledge. @@ -109,13 +108,12 @@ func (a *Acknowledge) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (a *Acknowledge) Encode(s *ua.Stream) error { +func (a *Acknowledge) Encode(s *ua.Stream) { s.WriteUint32(a.Version) s.WriteUint32(a.ReceiveBufSize) s.WriteUint32(a.SendBufSize) s.WriteUint32(a.MaxMessageSize) s.WriteUint32(a.MaxChunkCount) - return s.Error() } // ReverseHello represents a OPC UA ReverseHello. @@ -133,10 +131,9 @@ func (r *ReverseHello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (r *ReverseHello) Encode(s *ua.Stream) error { +func (r *ReverseHello) Encode(s *ua.Stream) { s.WriteString(r.ServerURI) s.WriteString(r.EndpointURL) - return s.Error() } // Error represents a OPC UA Error. @@ -154,10 +151,9 @@ func (e *Error) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *Error) Encode(s *ua.Stream) error { +func (e *Error) Encode(s *ua.Stream) { s.WriteUint32(e.ErrorCode) s.WriteString(e.Reason) - return s.Error() } func (e *Error) Error() string { @@ -178,7 +174,6 @@ func (m *Message) Decode(b []byte) (int, error) { return len(b), nil } -func (m *Message) Encode(s *ua.Stream) error { +func (m *Message) Encode(s *ua.Stream) { s.Write(m.Data) - return nil } diff --git a/uasc/asymmetric_security_header.go b/uasc/asymmetric_security_header.go index 6c2032c8..713c0aac 100644 --- a/uasc/asymmetric_security_header.go +++ b/uasc/asymmetric_security_header.go @@ -34,11 +34,10 @@ func (h *AsymmetricSecurityHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *AsymmetricSecurityHeader) Encode(s *ua.Stream) error { +func (h *AsymmetricSecurityHeader) Encode(s *ua.Stream) { s.WriteString(h.SecurityPolicyURI) s.WriteByteString(h.SenderCertificate) s.WriteByteString(h.ReceiverCertificateThumbprint) - return s.Error() } // String returns Header in string. diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 9bd52826..14967df9 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -54,7 +54,8 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - s := ua.NewStream(ua.DefaultBufSize) + s := ua.BorrowStream() + defer ua.ReturnStream(s) s.WriteAny(c.Struct) if s.Error() != nil { t.Fatal(s.Error()) diff --git a/uasc/header.go b/uasc/header.go index ddc01a77..75a5fc26 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -51,15 +51,15 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) Encode(s *ua.Stream) error { +func (h *Header) Encode(s *ua.Stream) { if len(h.MessageType) != 3 { - return errors.Errorf("invalid message type: %q", h.MessageType) + s.WrapError(errors.Errorf("invalid message type: %q", h.MessageType)) + return } s.Write([]byte(h.MessageType)) s.WriteByte(h.ChunkType) s.WriteUint32(h.MessageSize) s.WriteUint32(h.SecureChannelID) - return s.Error() } // String returns Header in string. diff --git a/uasc/message.go b/uasc/message.go index caf686bf..9b3dd365 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -74,10 +74,9 @@ func (m *MessageAbort) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (m *MessageAbort) Encode(s *ua.Stream) error { +func (m *MessageAbort) Encode(s *ua.Stream) { s.WriteUint32(m.ErrorCode) s.WriteString(m.Reason) - return s.Error() } func (m *MessageAbort) MessageAbort() string { @@ -111,13 +110,12 @@ func (m *Message) Decode(b []byte) (int, error) { return len(b), err } -func (m *Message) Encode(s *ua.Stream) error { +func (m *Message) Encode(s *ua.Stream) { chunks, err := m.EncodeChunks(math.MaxUint32) if err != nil { - return err + return } s.Write(chunks[0]) - return nil } func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { @@ -125,9 +123,8 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { defer ua.ReturnStream(dataBody) dataBody.WriteAny(m.TypeID) dataBody.WriteAny(m.Service) - if dataBody.Error() != nil { - return nil, dataBody.Error() + return nil, errors.Errorf("failed to encode databody: %s", dataBody.Error()) } nrChunks := uint32(dataBody.Len())/(maxBodySize) + 1 @@ -136,53 +133,49 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { switch m.Header.MessageType { case "OPN": partialHeader := ua.BorrowStream() - defer ua.ReturnStream(partialHeader) + defer ua.ReturnStream(dataBody) partialHeader.WriteAny(m.AsymmetricSecurityHeader) partialHeader.WriteAny(m.SequenceHeader) - if partialHeader.Error() != nil { - return nil, partialHeader.Error() + return nil, errors.Errorf("failed to encode partial header: %s", partialHeader.Error()) } m.Header.MessageSize = uint32(12 + partialHeader.Len() + dataBody.Len()) - buf := ua.BorrowStream() + buf := ua.NewStream(ua.DefaultBufSize) buf.WriteAny(m.Header) buf.Write(partialHeader.Bytes()) buf.Write(dataBody.Bytes()) - - b := append([]byte(nil), buf.Bytes()...) - ua.ReturnStream(buf) - return [][]byte{b}, buf.Error() + if buf.Error() != nil { + return nil, errors.Errorf("failed to encode chunk: %s", buf.Error()) + } + return [][]byte{buf.Bytes()}, nil case "CLO", "MSG": - + chunk := ua.NewStream(ua.DefaultBufSize) for i := uint32(0); i < nrChunks-1; i++ { + chunk.Reset() m.Header.MessageSize = maxBodySize + 24 m.Header.ChunkType = ChunkTypeIntermediate - chunk := ua.BorrowStream() chunk.WriteAny(m.Header) chunk.WriteAny(m.SymmetricSecurityHeader) chunk.WriteAny(m.SequenceHeader) chunk.Write(dataBody.ReadN(int(maxBodySize))) if chunk.Error() != nil { - ua.ReturnStream(chunk) - return nil, chunk.Error() + return nil, errors.Errorf("failed to encode chunk: %s", chunk.Error()) } chunks[i] = append(chunks[i], chunk.Bytes()...) - ua.ReturnStream(chunk) } m.Header.ChunkType = ChunkTypeFinal m.Header.MessageSize = uint32(24 + dataBody.Len()) - chunk := ua.BorrowStream() - defer ua.ReturnStream(chunk) + chunk.Reset() chunk.WriteAny(m.Header) chunk.WriteAny(m.SymmetricSecurityHeader) chunk.WriteAny(m.SequenceHeader) chunk.Write(dataBody.Bytes()) if chunk.Error() != nil { - return nil, chunk.Error() + return nil, errors.Errorf("failed to encode chunk: %s", chunk.Error()) } chunks[nrChunks-1] = append(chunks[nrChunks-1], chunk.Bytes()...) diff --git a/uasc/message_test.go b/uasc/message_test.go index 4280b826..1311456b 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/gopcua/opcua/id" - "github.com/pascaldekloe/goe/verify" "github.com/gopcua/opcua/ua" ) @@ -484,16 +483,15 @@ func BenchmarkEncodeMessage(b *testing.B) { }, } - b.ResetTimer() s := ua.NewStream(ua.DefaultBufSize) + b.ResetTimer() for i := 0; i < b.N; i++ { for _, tc := range cases { + s.Reset() s.WriteAny(tc.Struct) if s.Error() != nil { b.Fatalf("fail to encode message, err: %v", s.Error()) } - verify.Values(b, "", s.Bytes(), tc.Bytes) - s.Reset() } } } diff --git a/uasc/sequence_header.go b/uasc/sequence_header.go index ca2bdc68..bd7324b5 100644 --- a/uasc/sequence_header.go +++ b/uasc/sequence_header.go @@ -31,10 +31,9 @@ func (h *SequenceHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *SequenceHeader) Encode(s *ua.Stream) error { +func (h *SequenceHeader) Encode(s *ua.Stream) { s.WriteUint32(h.SequenceNumber) s.WriteUint32(h.RequestID) - return s.Error() } // String returns Header in string. diff --git a/uasc/symmetric_security_header.go b/uasc/symmetric_security_header.go index 40670086..44d57c2a 100644 --- a/uasc/symmetric_security_header.go +++ b/uasc/symmetric_security_header.go @@ -28,9 +28,8 @@ func (h *SymmetricSecurityHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *SymmetricSecurityHeader) Encode(s *ua.Stream) error { +func (h *SymmetricSecurityHeader) Encode(s *ua.Stream) { s.WriteUint32(h.TokenID) - return s.Error() } // String returns Header in string. From 08a57265906d8958b09799cad455554555e97e75 Mon Sep 17 00:00:00 2001 From: yuanliang Date: Mon, 20 May 2024 17:31:25 +0800 Subject: [PATCH 06/14] fix: remove streamPool to avoid memory leak --- ua/encode.go | 17 ----------------- ua/extension_object.go | 3 +-- ua/stream.go | 1 + uacp/conn.go | 6 ++---- uasc/codec_test.go | 3 +-- uasc/message.go | 6 ++---- 6 files changed, 7 insertions(+), 29 deletions(-) diff --git a/ua/encode.go b/ua/encode.go index 2a58552a..3f66cd22 100644 --- a/ua/encode.go +++ b/ua/encode.go @@ -8,29 +8,12 @@ import ( "fmt" "math" "reflect" - "sync" "time" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" ) -var streamPool sync.Pool = sync.Pool{ - New: func() interface{} { - return NewStream(DefaultBufSize) - }, -} - -func BorrowStream() *Stream { - v := streamPool.Get().(*Stream) - v.Reset() - return v -} - -func ReturnStream(s *Stream) { - streamPool.Put(s) -} - // debugCodec enables printing of debug messages in the opcua codec. var debugCodec = debug.FlagSet("codec") diff --git a/ua/extension_object.go b/ua/extension_object.go index bfe7e15b..a616a41f 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -94,8 +94,7 @@ func (e *ExtensionObject) Encode(s *Stream) { return } - body := BorrowStream() - defer ReturnStream(body) + body := NewStream(DefaultBufSize) body.WriteAny(e.Value) if body.Error() != nil { s.WrapError(body.Error()) diff --git a/ua/stream.go b/ua/stream.go index e6cd31d7..9f5d39ed 100644 --- a/ua/stream.go +++ b/ua/stream.go @@ -38,6 +38,7 @@ func (s *Stream) Len() int { func (s *Stream) Reset() { s.buf = s.buf[:0] s.pos = 0 + s.err = nil } func (s *Stream) Bytes() []byte { diff --git a/uacp/conn.go b/uacp/conn.go index ab500ebd..b7018fde 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -396,8 +396,7 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("invalid msg type: %s", typ) } - bodyStream := ua.BorrowStream() - defer ua.ReturnStream(bodyStream) + bodyStream := ua.NewStream(ua.DefaultBufSize) bodyStream.WriteAny(msg) if bodyStream.Error() != nil { return errors.Errorf("encode msg failed: %s", bodyStream.Error()) @@ -413,8 +412,7 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("send packet too large: %d > %d bytes", h.MessageSize, c.ack.SendBufSize) } - headerStream := ua.BorrowStream() - defer ua.ReturnStream(headerStream) + headerStream := ua.NewStream(ua.DefaultBufSize) h.Encode(headerStream) if headerStream.Error() != nil { return errors.Errorf("encode hdr failed: %s", headerStream.Error()) diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 14967df9..9bd52826 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -54,8 +54,7 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - s := ua.BorrowStream() - defer ua.ReturnStream(s) + s := ua.NewStream(ua.DefaultBufSize) s.WriteAny(c.Struct) if s.Error() != nil { t.Fatal(s.Error()) diff --git a/uasc/message.go b/uasc/message.go index 9b3dd365..4d53d139 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -119,8 +119,7 @@ func (m *Message) Encode(s *ua.Stream) { } func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { - dataBody := ua.BorrowStream() - defer ua.ReturnStream(dataBody) + dataBody := ua.NewStream(ua.DefaultBufSize) dataBody.WriteAny(m.TypeID) dataBody.WriteAny(m.Service) if dataBody.Error() != nil { @@ -132,8 +131,7 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { switch m.Header.MessageType { case "OPN": - partialHeader := ua.BorrowStream() - defer ua.ReturnStream(dataBody) + partialHeader := ua.NewStream(ua.DefaultBufSize) partialHeader.WriteAny(m.AsymmetricSecurityHeader) partialHeader.WriteAny(m.SequenceHeader) if partialHeader.Error() != nil { From 3ac63a8a5c2831f883e61899d2dfe754f9cc4e3a Mon Sep 17 00:00:00 2001 From: yuanliang Date: Fri, 31 May 2024 10:55:11 +0800 Subject: [PATCH 07/14] feat: Refactoring Encoder to Optimize Reflection Signed-off-by: yuanliang --- .gitignore | 1 + codec/encode.go | 609 ++++++++++++++++++++++++++++++++ go.mod | 6 +- go.sum | 9 +- ua/datatypes.go | 59 ++++ ua/diagnostic_info.go | 40 +++ ua/expanded_node_id.go | 18 + ua/extension_object.go | 22 ++ ua/node_id.go | 32 ++ ua/stream.go | 2 +- ua/variant.go | 30 ++ uacp/uacp.go | 42 +++ uasc/codec_test.go | 60 ++-- uasc/header.go | 14 + uasc/message.go | 98 +++++ uasc/message_test.go | 296 ++++++++++++++++ uasc/secure_channel.go | 2 +- uasc/secure_channel_instance.go | 1 + 18 files changed, 1302 insertions(+), 39 deletions(-) create mode 100644 codec/encode.go diff --git a/.gitignore b/.gitignore index 971ffc61..f7a0f8b7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ dist/ __pycache__/ .vscode/ .idea/ +uasc/*.txt diff --git a/codec/encode.go b/codec/encode.go new file mode 100644 index 00000000..cc4c88c0 --- /dev/null +++ b/codec/encode.go @@ -0,0 +1,609 @@ +package codec + +import ( + "bytes" + "fmt" + "math" + "reflect" + "sync" + "time" +) + +// Marshal returns the OPCUA encoding of v. +func Marshal(v any) ([]byte, error) { + e := newEncodeState() + defer encodeStatePool.Put(e) + + err := e.marshal(v) + if err != nil { + return nil, err + } + buf := append([]byte(nil), e.Bytes()...) + + return buf, nil +} + +// Marshaler is the interface implemented by types that +// can marshal themselves into valid OPCUA. +type Marshaler interface { + MarshalOPCUA() ([]byte, error) +} + +// An UnsupportedTypeError is returned by [Marshal] when attempting +// to encode an unsupported value type. +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return "opcua: unsupported type: " + e.Type.String() +} + +// An UnsupportedValueError is returned by [Marshal] when attempting +// to encode an unsupported value. +type UnsupportedValueError struct { + Value reflect.Value + Str string +} + +func (e *UnsupportedValueError) Error() string { + return "opcua: unsupported value: " + e.Str +} + +// A MarshalerError represents an error from calling a +// [Marshaler.MarshalOPCUA] method. +type MarshalerError struct { + Type reflect.Type + Err error + sourceFunc string +} + +func (e *MarshalerError) Error() string { + srcFunc := e.sourceFunc + if srcFunc == "" { + srcFunc = "MarshalOPCUA" + } + return "opcua: error calling " + srcFunc + + " for type " + e.Type.String() + + ": " + e.Err.Error() +} + +type encodeState struct { + bytes.Buffer + + ptrLevel uint + ptrSeen map[any]struct{} +} + +const startDetectingCyclesAfter = 1000 + +var encodeStatePool sync.Pool + +func newEncodeState() *encodeState { + if v := encodeStatePool.Get(); v != nil { + e := v.(*encodeState) + e.Reset() + if len(e.ptrSeen) > 0 { + panic("ptrEncoder.encode should have emptied ptrSeen via defers") + } + e.ptrLevel = 0 + return e + } + return &encodeState{ptrSeen: make(map[any]struct{})} +} + +// codecError is an error wrapper type for internal use only. +// Panics with errors are wrapped in codecError so that the top-level recover +// can distinguish intentional panics from this package. +type codecError struct{ error } + +func (e *encodeState) marshal(v any) (err error) { + defer func() { + if r := recover(); r != nil { + if je, ok := r.(codecError); ok { + err = je.error + } else { + panic(r) + } + } + }() + e.reflectValue(reflect.ValueOf(v)) + return nil +} + +// error aborts the encoding by panicking with err wrapped in jsonError. +func (e *encodeState) error(err error) { + panic(codecError{err}) +} + +func (e *encodeState) reflectValue(v reflect.Value) { + valueEncoder(v)(e, v) +} + +type encoderFunc func(e *encodeState, v reflect.Value) + +var encoderCache sync.Map // map[reflect.Type]encoderFunc + +func valueEncoder(v reflect.Value) encoderFunc { + if isTime(v) { + return timeEncoder + } + + return typeEncoder(v.Type()) +} + +func typeEncoder(t reflect.Type) encoderFunc { + if fi, ok := encoderCache.Load(t); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoderCache.LoadOrStore(t, encoderFunc(func(e *encodeState, v reflect.Value) { + wg.Wait() + f(e, v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = newTypeEncoder(t, true) + wg.Done() + encoderCache.Store(t, f) + return f +} + +var marshalerType = reflect.TypeFor[Marshaler]() + +// newTypeEncoder constructs an encoderFunc for a type. +// The returned encoder only checks CanAddr when allowAddr is true. +func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { + kind := t.Kind() + // If we have a non-pointer value whose type implements + // Marshaler with a value receiver, then we're better off taking + // the address of the value - otherwise we end up with an + // allocation as we cast the value to an interface. + if kind != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { + return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) + } + if t.Implements(marshalerType) { + return marshalerEncoder + } + + switch kind { + case reflect.Bool: + return boolEncoder + case reflect.Int8: + return int8Encoder + case reflect.Uint8: + return uint8Encoder + case reflect.Int16: + return int16Encoder + case reflect.Uint16: + return uint16Encoder + case reflect.Int32: + return int32Encoder + case reflect.Uint32: + return uint32Encoder + case reflect.Int64: + return int64Encoder + case reflect.Uint64: + return uint64Encoder + case reflect.Float32: + return float32Encoder + case reflect.Float64: + return float64Encoder + case reflect.String: + return stringEncoder + case reflect.Interface: + return interfaceEncoder + case reflect.Ptr: + return newPtrEncoder(t) + case reflect.Struct: + return newStructEncoder(t) + case reflect.Array: + return newArrayEncoder(t) + case reflect.Slice: + return newSliceEncoder(t) + default: + return unsupportedTypeEncoder + } +} + +func isTime(val reflect.Value) bool { + return val.CanConvert(timeType) +} + +func marshalerEncoder(e *encodeState, v reflect.Value) { + if v.Kind() == reflect.Pointer && v.IsNil() { + return + } + m, ok := v.Interface().(Marshaler) + if !ok { + return + } + b, err := m.MarshalOPCUA() + if err == nil { + e.Grow(len(b)) + out := e.AvailableBuffer() + out = append(out, b...) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalOPCUA"}) + } +} + +func addrMarshalerEncoder(e *encodeState, v reflect.Value) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(Marshaler) + b, err := m.MarshalOPCUA() + if err == nil { + e.Grow(len(b)) + out := e.AvailableBuffer() + out = append(out, b...) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalOPCUA"}) + } +} + +func (e *encodeState) writeUint16(n uint16) { + e.Write([]byte{byte(n), byte(n >> 8)}) +} + +func (e *encodeState) writeUint32(n uint32) { + e.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) +} + +func (e *encodeState) writeUint64(n uint64) { + e.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}) +} + +func boolEncoder(e *encodeState, v reflect.Value) { + val := v.Bool() + if val { + e.WriteByte('1') + } else { + e.WriteByte('0') + } +} + +func int8Encoder(e *encodeState, v reflect.Value) { + val := int8(v.Int()) + e.WriteByte(byte(val)) +} + +func uint8Encoder(e *encodeState, v reflect.Value) { + val := uint8(v.Uint()) + e.WriteByte(val) +} + +func int16Encoder(e *encodeState, v reflect.Value) { + val := uint16(v.Int()) + e.writeUint16(val) +} + +func uint16Encoder(e *encodeState, v reflect.Value) { + val := uint16(v.Uint()) + e.writeUint16(val) +} + +func int32Encoder(e *encodeState, v reflect.Value) { + val := uint32(v.Int()) + e.writeUint32(val) +} + +func uint32Encoder(e *encodeState, v reflect.Value) { + val := uint32(v.Uint()) + e.writeUint32(val) +} + +func int64Encoder(e *encodeState, v reflect.Value) { + val := uint64(v.Int()) + e.writeUint64(val) +} + +func uint64Encoder(e *encodeState, v reflect.Value) { + val := v.Uint() + e.writeUint64(val) +} + +var ( + null = []byte{0xff, 0xff, 0xff, 0xff} + f32qnan = []byte{0x00, 0x00, 0xff, 0xc0} + f64qnan = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf8, 0xff} +) + +func float32Encoder(e *encodeState, v reflect.Value) { + if math.IsNaN(v.Float()) { + e.Write(f32qnan) + } else { + val := math.Float32bits((float32)(v.Float())) + e.writeUint32(val) + } +} + +func float64Encoder(e *encodeState, v reflect.Value) { + if math.IsNaN(v.Float()) { + e.Write(f64qnan) + } else { + val := math.Float64bits(v.Float()) + e.writeUint64(val) + } +} + +func stringEncoder(e *encodeState, v reflect.Value) { + s := v.String() + if s == "" { + e.Write(null) + return + } + + l := len(s) + e.writeUint32(uint32(l)) + e.Write([]byte(s)) +} + +var timeType = reflect.TypeOf(time.Time{}) + +func timeEncoder(e *encodeState, v reflect.Value) { + var ts uint64 + val := v.Convert(timeType).Interface().(time.Time) + if !v.IsZero() { + // encode time in "100 nanosecond intervals since January 1, 1601" + ts = uint64(val.UTC().UnixNano()/100 + 116444736000000000) + } + e.writeUint64(ts) +} + +func interfaceEncoder(e *encodeState, v reflect.Value) { + if v.IsNil() { + return + } + e.reflectValue(v.Elem()) +} + +func unsupportedTypeEncoder(e *encodeState, v reflect.Value) { + e.error(&UnsupportedTypeError{v.Type()}) +} + +type structEncoder struct { + fields structFields +} + +type structFields struct { + list []field +} + +func (se structEncoder) encode(e *encodeState, v reflect.Value) { +FieldLoop: + for i := range se.fields.list { + f := &se.fields.list[i] + + // Find the nested struct field by following f.index. + fv := v + for _, i := range f.index { + if fv.Kind() == reflect.Pointer { + if fv.IsNil() { + continue FieldLoop + } + fv = fv.Elem() + } + fv = fv.Field(i) + } + + f.encoder(e, fv) + } +} + +func newStructEncoder(t reflect.Type) encoderFunc { + se := structEncoder{fields: cachedTypeFields(t)} + return se.encode +} + +func encodeByteSlice(e *encodeState, v reflect.Value) { + if v.IsNil() { + e.Write(null) + return + } + + n := v.Len() + e.writeUint32(uint32(n)) + + b := make([]byte, n) + reflect.Copy(reflect.ValueOf(b), v) + e.Write(b) +} + +// sliceEncoder just wraps an arrayEncoder, checking to make sure the value isn't nil. +type sliceEncoder struct { + arrayEnc encoderFunc +} + +func (se sliceEncoder) encode(e *encodeState, v reflect.Value) { + if v.IsNil() { + e.Write(null) + return + } + + if v.Len() > math.MaxInt32 { + panic("array too large") + } + + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + // Here we use a struct to memorize the pointer to the first element of the slice + // and its length. + ptr := struct { + ptr interface{} // always an unsafe.Pointer, but avoids a dependency on package unsafe + len int + }{v.UnsafePointer(), v.Len()} + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } + se.arrayEnc(e, v) + e.ptrLevel-- +} + +func newSliceEncoder(t reflect.Type) encoderFunc { + // Byte slices get special treatment; arrays don't. + if t.Elem().Kind() == reflect.Uint8 { + p := reflect.PointerTo(t.Elem()) + if !p.Implements(marshalerType) { + return encodeByteSlice + } + } + enc := sliceEncoder{newArrayEncoder(t)} + return enc.encode +} + +type arrayEncoder struct { + elemEnc encoderFunc +} + +func (ae arrayEncoder) encode(e *encodeState, v reflect.Value) { + n := v.Len() + e.writeUint32(uint32(n)) + + // fast path for []byte + if v.Type().Elem().Kind() == reflect.Uint8 { + b := make([]byte, n) + reflect.Copy(reflect.ValueOf(b), v) + e.Write(b) + return + } + + // loop over elements + // we write all the elements, also the zero values + for i := 0; i < n; i++ { + ae.elemEnc(e, v.Index(i)) + } +} + +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := arrayEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type ptrEncoder struct { + elemEnc encoderFunc +} + +func (pe ptrEncoder) encode(e *encodeState, v reflect.Value) { + if v.IsNil() { + return + } + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + ptr := v.Interface() + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } + pe.elemEnc(e, v.Elem()) + e.ptrLevel-- +} + +func newPtrEncoder(t reflect.Type) encoderFunc { + enc := ptrEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type condAddrEncoder struct { + canAddrEnc, elseEnc encoderFunc +} + +func (ce condAddrEncoder) encode(e *encodeState, v reflect.Value) { + if v.CanAddr() { + ce.canAddrEnc(e, v) + } else { + ce.elseEnc(e, v) + } +} + +// newCondAddrEncoder returns an encoder that checks whether its value +// CanAddr and delegates to canAddrEnc if so, else to elseEnc. +func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { + enc := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} + return enc.encode +} + +// A field represents a single field found in a struct. +type field struct { + name string + nameBytes []byte // []byte(name) + + index []int + typ reflect.Type + + encoder encoderFunc +} + +// typeFields returns a list of fields that the encoder should recognize for a given type. +// The algorithm is a depth-first search of the set of structures, including any reachable anonymous structures, +// until the structure is traversed completely. +func typeFields(t reflect.Type) structFields { + var fields []field + + visitField := func(f reflect.StructField) { + t := f.Type + // time.Time is special because it has embedded structs that use timeEncoder. + if t == timeType || (t.Kind() == reflect.Pointer && t.Elem() == timeType) { + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: timeEncoder}) + return + } + // 如果t实现了Marshaler接口,则使用Marshaler的MarshalOPCUA方法 + if t.Implements(marshalerType) { + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: marshalerEncoder}) + return + } + + // Check for anonymous fields (embedded structs). + if f.Anonymous { + if t.Kind() == reflect.Pointer { + t = f.Type.Elem() + } + fields = append(fields, typeFields(t).list...) + } + + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: f.Type, encoder: typeEncoder(f.Type)}) + } + + // Process all fields in the root struct. + for i := 0; i < t.NumField(); i++ { + visitField(t.Field(i)) + } + return structFields{list: fields} +} + +var fieldCache sync.Map // map[reflect.Type]structFields + +// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. +func cachedTypeFields(t reflect.Type) structFields { + if f, ok := fieldCache.Load(t); ok { + return f.(structFields) + } + f, _ := fieldCache.LoadOrStore(t, typeFields(t)) + return f.(structFields) +} diff --git a/go.mod b/go.mod index f75924d4..ed724078 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.20 require ( github.com/pascaldekloe/goe v0.1.1 github.com/pkg/errors v0.9.1 - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 - golang.org/x/term v0.8.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + golang.org/x/term v0.18.0 ) -require golang.org/x/sys v0.8.0 // indirect +require golang.org/x/sys v0.18.0 // indirect retract ( v0.2.5 // https://github.com/gopcua/opcua/issues/538 diff --git a/go.sum b/go.sum index ac17d5df..ff4affa7 100644 --- a/go.sum +++ b/go.sum @@ -2,9 +2,6 @@ github.com/pascaldekloe/goe v0.1.1 h1:Ah6WQ56rZONR3RW3qWa2NCZ6JAVvSpUcoLBaOmYFt9 github.com/pascaldekloe/goe v0.1.1/go.mod h1:KSyfaxQOh0HZPjDP1FL/kFtbqYqrALJTaMafFUIccqU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= diff --git a/ua/datatypes.go b/ua/datatypes.go index aa5e12ea..d3954366 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -5,11 +5,14 @@ package ua import ( + "bytes" "encoding/binary" "encoding/hex" "fmt" "strings" "time" + + "github.com/gopcua/opcua/codec" ) // These flags define which fields of a DataValue are set. @@ -84,6 +87,36 @@ func (d *DataValue) Encode(s *Stream) { } } +func (d *DataValue) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + var err error + var b []byte + + buf.WriteByte(d.EncodingMask) + if d.Has(DataValueValue) { + b, err = codec.Marshal(d.Value) + buf.Write(b) + } + if d.Has(DataValueStatusCode) { + buf.Write([]byte{byte(d.Status), byte(d.Status >> 8), byte(d.Status >> 16), byte(d.Status >> 24)}) + } + if d.Has(DataValueSourceTimestamp) { + b, err = codec.Marshal(d.SourceTimestamp) + buf.Write(b) + } + if d.Has(DataValueSourcePicoseconds) { + buf.Write([]byte{byte(d.SourcePicoseconds), byte(d.SourcePicoseconds >> 8)}) + } + if d.Has(DataValueServerTimestamp) { + b, err = codec.Marshal(d.ServerTimestamp) + buf.Write(b) + } + if d.Has(DataValueServerPicoseconds) { + buf.Write([]byte{byte(d.ServerPicoseconds), byte(d.ServerPicoseconds >> 8)}) + } + return buf.Bytes(), err +} + func (d *DataValue) Has(mask byte) bool { return d.EncodingMask&mask == mask } @@ -233,6 +266,32 @@ func (l *LocalizedText) Encode(s *Stream) { } } +func (l *LocalizedText) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + var err error + + buf.WriteByte(l.EncodingMask) + if l.Has(LocalizedTextLocale) { + n := len(l.Locale) + if n == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.Write([]byte(l.Locale)) + } + } + if l.Has(LocalizedTextText) { + n := len(l.Text) + if n == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.Write([]byte(l.Text)) + } + } + return buf.Bytes(), err +} + func (l *LocalizedText) Has(mask byte) bool { return l.EncodingMask&mask == mask } diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index b72d46df..e5457a9e 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -4,6 +4,12 @@ package ua +import ( + "bytes" + + "github.com/gopcua/opcua/codec" +) + // These flags define which fields of a DiagnosticInfo are set. // Bits are or'ed together if multiple fields are set. const ( @@ -83,6 +89,40 @@ func (d *DiagnosticInfo) Encode(s *Stream) { } } +func (d *DiagnosticInfo) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(d.EncodingMask) + + if d.Has(DiagnosticInfoSymbolicID) { + buf.Write([]byte{byte(d.SymbolicID), byte(d.SymbolicID >> 8), byte(d.SymbolicID >> 16), byte(d.SymbolicID >> 24)}) + } + if d.Has(DiagnosticInfoNamespaceURI) { + buf.Write([]byte{byte(d.NamespaceURI), byte(d.NamespaceURI >> 8), byte(d.NamespaceURI >> 16), byte(d.NamespaceURI >> 24)}) + } + if d.Has(DiagnosticInfoLocale) { + buf.Write([]byte{byte(d.Locale), byte(d.Locale >> 8), byte(d.Locale >> 16), byte(d.Locale >> 24)}) + } + if d.Has(DiagnosticInfoLocalizedText) { + buf.Write([]byte{byte(d.LocalizedText), byte(d.LocalizedText >> 8), byte(d.LocalizedText >> 16), byte(d.LocalizedText >> 24)}) + } + if d.Has(DiagnosticInfoAdditionalInfo) { + b, _ := codec.Marshal(d.AdditionalInfo) + buf.Write(b) + } + if d.Has(DiagnosticInfoInnerStatusCode) { + buf.Write([]byte{byte(d.InnerStatusCode), byte(d.InnerStatusCode >> 8), byte(d.InnerStatusCode >> 16), byte(d.InnerStatusCode >> 24)}) + } + if d.Has(DiagnosticInfoInnerDiagnosticInfo) { + b, err := codec.Marshal(d.InnerDiagnosticInfo) + if err != nil { + return nil, err + } + buf.Write(b) + } + + return buf.Bytes(), nil +} + func (d *DiagnosticInfo) Has(mask byte) bool { return d.EncodingMask&mask == mask } diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index c0e7ee3c..257dc29c 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -5,11 +5,13 @@ package ua import ( + "bytes" "encoding/base64" "math" "strconv" "strings" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" ) @@ -112,6 +114,22 @@ func (e *ExpandedNodeID) Encode(s *Stream) { } } +func (e *ExpandedNodeID) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + b, err := e.NodeID.MarshalOPCUA() + buf.Write(b) + + if e.HasNamespaceURI() { + b, err = codec.Marshal(e.NamespaceURI) + buf.Write(b) + } + if e.HasServerIndex() { + b, err = codec.Marshal(e.ServerIndex) + buf.Write(b) + } + return buf.Bytes(), err +} + // HasNamespaceURI checks if an ExpandedNodeID has NamespaceURI Flag. func (e *ExpandedNodeID) HasNamespaceURI() bool { return e.NodeID.EncodingMask()>>7&0x1 == 1 diff --git a/ua/extension_object.go b/ua/extension_object.go index a616a41f..82e3ee6b 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -5,6 +5,9 @@ package ua import ( + "bytes" + + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/id" ) @@ -104,6 +107,25 @@ func (e *ExtensionObject) Encode(s *Stream) { s.Write(body.Bytes()) } +func (e *ExtensionObject) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + b, err := codec.Marshal(e.TypeID) + buf.Write(b) + buf.WriteByte(e.EncodingMask) + if e.EncodingMask == ExtensionObjectEmpty { + return buf.Bytes(), err + } + + body, err := codec.Marshal(e.Value) + if err != nil { + return buf.Bytes(), err + } + n := uint32(len(body)) + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.Write(body) + return buf.Bytes(), err +} + func (e *ExtensionObject) UpdateMask() { if e.Value == nil { e.EncodingMask = ExtensionObjectEmpty diff --git a/ua/node_id.go b/ua/node_id.go index 9584b1b5..c9934aaf 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -5,11 +5,13 @@ package ua import ( + "bytes" "encoding/base64" "encoding/json" "fmt" "math" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" ) @@ -389,6 +391,36 @@ func (n *NodeID) Encode(s *Stream) { } } +func (n *NodeID) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(byte(n.mask)) + switch n.Type() { + case NodeIDTypeTwoByte: + buf.WriteByte(byte(n.nid)) + case NodeIDTypeFourByte: + buf.WriteByte(byte(n.ns)) + buf.Write([]byte{byte(n.nid), byte(n.nid >> 8)}) + case NodeIDTypeNumeric: + buf.Write([]byte{byte(n.ns), byte(n.ns >> 8), byte(n.nid), byte(n.nid >> 8), byte(n.nid >> 16), byte(n.nid >> 24)}) + case NodeIDTypeGUID: + buf.Write([]byte{byte(n.ns), byte(n.ns >> 8)}) + b, _ := codec.Marshal(n.gid) + buf.Write(b) + case NodeIDTypeByteString, NodeIDTypeString: + buf.Write([]byte{byte(n.ns), byte(n.ns >> 8)}) + l := uint32(len(n.bid)) + if l == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + buf.Write([]byte{byte(l), byte(l >> 8), byte(l >> 16), byte(l >> 24)}) + buf.Write(n.bid) + } + default: + return nil, fmt.Errorf("invalid node id type: %d", n.mask) + } + return buf.Bytes(), nil +} + func (n *NodeID) MarshalJSON() ([]byte, error) { if n == nil { return []byte(`null`), nil diff --git a/ua/stream.go b/ua/stream.go index 9f5d39ed..27063e0a 100644 --- a/ua/stream.go +++ b/ua/stream.go @@ -9,7 +9,7 @@ import ( "github.com/gopcua/opcua/errors" ) -const DefaultBufSize = 1024 +const DefaultBufSize = 256 type Stream struct { buf []byte diff --git a/ua/variant.go b/ua/variant.go index bb29604c..dd9f8dfa 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -5,10 +5,12 @@ package ua import ( + "bytes" "fmt" "reflect" "time" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" ) @@ -340,6 +342,34 @@ func (m *Variant) Encode(s *Stream) { } } +func (m *Variant) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(m.mask) + + // a null value specifies that no other fields are encoded + if m.Type() == TypeIDNull { + return buf.Bytes(), nil + } + + if m.Has(VariantArrayValues) { + buf.Write([]byte{byte(m.arrayLength), byte(m.arrayLength >> 8), byte(m.arrayLength >> 16), byte(m.arrayLength >> 24)}) + } + + b, err := codec.Marshal(m.value) + if err != nil { + return buf.Bytes(), err + } + buf.Write(b) + + if m.Has(VariantArrayDimensions) { + buf.Write([]byte{byte(m.arrayDimensionsLength), byte(m.arrayDimensionsLength >> 8), byte(m.arrayDimensionsLength >> 16), byte(m.arrayDimensionsLength >> 24)}) + for _, v := range m.arrayDimensions { + buf.Write([]byte{byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24)}) + } + } + return buf.Bytes(), nil +} + // encode recursively writes the values to the buffer. func (m *Variant) encode(s *Stream, val reflect.Value) { if val.Kind() != reflect.Slice || m.Type() == TypeIDByteString { diff --git a/uacp/uacp.go b/uacp/uacp.go index 2ae23826..6b921474 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -5,6 +5,9 @@ package uacp import ( + "bytes" + "encoding/binary" + "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" ) @@ -55,6 +58,18 @@ func (h *Header) Encode(s *ua.Stream) { s.WriteUint32(h.MessageSize) } +func (h *Header) MarshalOPCUA() ([]byte, error) { + if len(h.MessageType) != 3 { + return nil, errors.Errorf("invalid message type: %q", h.MessageType) + } + + var buf bytes.Buffer + buf.Write([]byte(h.MessageType)) + buf.WriteByte(h.ChunkType) + buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) + return buf.Bytes(), nil +} + // Hello represents a OPC UA Hello. // // Specification: Part6, 7.1.2.3 @@ -87,6 +102,23 @@ func (h *Hello) Encode(s *ua.Stream) { s.WriteString(h.EndpointURL) } +func (h *Hello) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.Write([]byte{byte(h.Version), byte(h.Version >> 8), byte(h.Version >> 16), byte(h.Version >> 24)}) + buf.Write([]byte{byte(h.ReceiveBufSize), byte(h.ReceiveBufSize >> 8), byte(h.ReceiveBufSize >> 16), byte(h.ReceiveBufSize >> 24)}) + buf.Write([]byte{byte(h.SendBufSize), byte(h.SendBufSize >> 8), byte(h.SendBufSize >> 16), byte(h.SendBufSize >> 24)}) + buf.Write([]byte{byte(h.MaxMessageSize), byte(h.MaxMessageSize >> 8), byte(h.MaxMessageSize >> 16), byte(h.MaxMessageSize >> 24)}) + buf.Write([]byte{byte(h.MaxChunkCount), byte(h.MaxChunkCount >> 8), byte(h.MaxChunkCount >> 16), byte(h.MaxChunkCount >> 24)}) + if len(h.EndpointURL) == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + n := len(h.EndpointURL) + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.WriteString(h.EndpointURL) + } + return buf.Bytes(), nil +} + // Acknowledge represents a OPC UA Acknowledge. // // Specification: Part6, 7.1.2.4 @@ -116,6 +148,16 @@ func (a *Acknowledge) Encode(s *ua.Stream) { s.WriteUint32(a.MaxChunkCount) } +func (a *Acknowledge) MarshalOPCUA() ([]byte, error) { + buf := make([]byte, 0, 160) + binary.LittleEndian.AppendUint32(buf, a.Version) + binary.LittleEndian.AppendUint32(buf, a.ReceiveBufSize) + binary.LittleEndian.AppendUint32(buf, a.SendBufSize) + binary.LittleEndian.AppendUint32(buf, a.MaxMessageSize) + binary.LittleEndian.AppendUint32(buf, a.MaxChunkCount) + return buf, nil +} + // ReverseHello represents a OPC UA ReverseHello. // // Specification: Part6, 7.1.2.6 diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 9bd52826..bb636a48 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -7,10 +7,9 @@ package uasc import ( - "reflect" "testing" - "github.com/gopcua/opcua/ua" + "github.com/gopcua/opcua/codec" "github.com/pascaldekloe/goe/verify" ) @@ -29,37 +28,42 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - t.Run("decode", func(t *testing.T) { - // create a new instance of the same type as c.Struct - typ := reflect.ValueOf(c.Struct).Type() - var v reflect.Value - switch typ.Kind() { - case reflect.Ptr: - v = reflect.New(typ.Elem()) // typ: *struct, v: *struct - case reflect.Slice: - v = reflect.New(typ) // typ: []x, v: *[]x - default: - t.Fatalf("%T is not a pointer or a slice", c.Struct) - } + // t.Run("decode", func(t *testing.T) { + // // create a new instance of the same type as c.Struct + // typ := reflect.ValueOf(c.Struct).Type() + // var v reflect.Value + // switch typ.Kind() { + // case reflect.Ptr: + // v = reflect.New(typ.Elem()) // typ: *struct, v: *struct + // case reflect.Slice: + // v = reflect.New(typ) // typ: []x, v: *[]x + // default: + // t.Fatalf("%T is not a pointer or a slice", c.Struct) + // } - if _, err := ua.Decode(c.Bytes, v.Interface()); err != nil { - t.Fatal(err) - } + // if _, err := ua.Decode(c.Bytes, v.Interface()); err != nil { + // t.Fatal(err) + // } - // if v is a *[]x we need to dereference it before comparing it. - if typ.Kind() == reflect.Slice { - v = v.Elem() - } - verify.Values(t, "", v.Interface(), c.Struct) - }) + // // if v is a *[]x we need to dereference it before comparing it. + // if typ.Kind() == reflect.Slice { + // v = v.Elem() + // } + // verify.Values(t, "", v.Interface(), c.Struct) + // }) t.Run("encode", func(t *testing.T) { - s := ua.NewStream(ua.DefaultBufSize) - s.WriteAny(c.Struct) - if s.Error() != nil { - t.Fatal(s.Error()) + b, err := codec.Marshal(c.Struct) + if err != nil { + t.Fatal(err) } - verify.Values(t, "", s.Bytes(), c.Bytes) + verify.Values(t, "", b, c.Bytes) + // s := ua.NewStream(ua.DefaultBufSize) + // s.WriteAny(c.Struct) + // if s.Error() != nil { + // t.Fatal(s.Error()) + // } + // verify.Values(t, "", s.Bytes(), c.Bytes) }) }) } diff --git a/uasc/header.go b/uasc/header.go index 75a5fc26..20444877 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -5,6 +5,7 @@ package uasc import ( + "bytes" "fmt" "github.com/gopcua/opcua/errors" @@ -62,6 +63,19 @@ func (h *Header) Encode(s *ua.Stream) { s.WriteUint32(h.SecureChannelID) } +func (h *Header) MarshalOPCUA() ([]byte, error) { + if len(h.MessageType) != 3 { + return nil, errors.Errorf("invalid message type: %q", h.MessageType) + } + + var buf bytes.Buffer + buf.WriteString(h.MessageType) + buf.WriteByte(h.ChunkType) + buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) + buf.Write([]byte{byte(h.SecureChannelID), byte(h.SecureChannelID >> 8), byte(h.SecureChannelID >> 16), byte(h.SecureChannelID >> 24)}) + return buf.Bytes(), nil +} + // String returns Header in string. func (h *Header) String() string { return fmt.Sprintf( diff --git a/uasc/message.go b/uasc/message.go index 4d53d139..1bc23566 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -5,8 +5,10 @@ package uasc import ( + "fmt" "math" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" ) @@ -146,6 +148,10 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { if buf.Error() != nil { return nil, errors.Errorf("failed to encode chunk: %s", buf.Error()) } + for _, v := range buf.Bytes() { + fmt.Printf("0x%02x,", v) + } + fmt.Println() return [][]byte{buf.Bytes()}, nil case "CLO", "MSG": @@ -182,3 +188,95 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { return nil, errors.Errorf("invalid message type %q", m.Header.MessageType) } } + +func (m *Message) MarshalOPCUA() ([]byte, error) { + chunks, err := m.MarshalChunks(math.MaxUint32) + if err != nil { + return nil, err + } + + return chunks[0], nil +} + +func (m *Message) MarshalChunks(maxBodySize uint32) ([][]byte, error) { + typeID, err := codec.Marshal(m.TypeID) + if err != nil { + return nil, errors.Errorf("failed to encode typeid: %s", err) + } + service, err := codec.Marshal(m.Service) + if err != nil { + return nil, errors.Errorf("failed to encode service: %s", err) + } + + dataBody := make([]byte, len(typeID)+len(service)) + copy(dataBody, typeID) + copy(dataBody[len(typeID):], service) + + nrChunks := uint32(len(dataBody))/(maxBodySize) + 1 + chunks := make([][]byte, nrChunks) + + switch m.Header.MessageType { + case "OPN": + asymmetricSecurityHeader, err := codec.Marshal(m.AsymmetricSecurityHeader) + if err != nil { + return nil, errors.Errorf("failed to encode asymmetric security header: %s", err) + } + sequenceHeader, err := codec.Marshal(m.SequenceHeader) + if err != nil { + return nil, errors.Errorf("failed to encode sequence header: %s", err) + } + + m.Header.MessageSize = uint32(12 + len(asymmetricSecurityHeader) + len(sequenceHeader) + len(dataBody)) + header, err := codec.Marshal(m.Header) + if err != nil { + return nil, errors.Errorf("failed to encode header: %s", err) + } + chunks[0] = append(chunks[0], header...) + chunks[0] = append(chunks[0], asymmetricSecurityHeader...) + chunks[0] = append(chunks[0], sequenceHeader...) + chunks[0] = append(chunks[0], dataBody...) + return chunks, nil + + case "CLO", "MSG": + symmetricSecurityHeader, err := codec.Marshal(m.SymmetricSecurityHeader) + if err != nil { + return nil, errors.Errorf("failed to encode symmetric security header: %s", err) + } + sequenceHeader, err := codec.Marshal(m.SequenceHeader) + if err != nil { + return nil, errors.Errorf("failed to encode sequence header: %s", err) + } + + start, end := 0, int(maxBodySize) + for i := uint32(0); i < nrChunks-1; i++ { + m.Header.MessageSize = maxBodySize + 24 + m.Header.ChunkType = ChunkTypeIntermediate + + header, err := codec.Marshal(m.Header) + if err != nil { + return nil, errors.Errorf("failed to encode header: %s", err) + } + + chunks[i] = append(chunks[i], header...) + chunks[i] = append(chunks[i], symmetricSecurityHeader...) + chunks[i] = append(chunks[i], sequenceHeader...) + chunks[i] = append(chunks[i], dataBody[start:end]...) + start, end = end, end+int(maxBodySize) + } + + m.Header.ChunkType = ChunkTypeFinal + m.Header.MessageSize = uint32(24 + len(dataBody)) + + header, err := codec.Marshal(m.Header) + if err != nil { + return nil, err + } + chunks[nrChunks-1] = append(chunks[nrChunks-1], header...) + chunks[nrChunks-1] = append(chunks[nrChunks-1], symmetricSecurityHeader...) + chunks[nrChunks-1] = append(chunks[nrChunks-1], sequenceHeader...) + chunks[nrChunks-1] = append(chunks[nrChunks-1], dataBody...) + return chunks, nil + default: + return nil, errors.Errorf("invalid message type %q", m.Header.MessageType) + } +} diff --git a/uasc/message_test.go b/uasc/message_test.go index 1311456b..7056d965 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" @@ -245,6 +246,56 @@ func TestMessage(t *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, }, }, + // { + // Name: "OPN-2", + // Struct: func() interface{} { + // s := &SecureChannel{ + // endpointURL: "opc.tcp://192.168.118.199:4840", + // cfg: &Config{ + // SecurityPolicyURI: "opc.tcp://192.168.118.199:4840/FOO", + // }, + // } + // instance := &channelInstance{ + // sc: s, + // sequenceNumber: 1, + // securityTokenID: 0, + // maxBodySize: 65510, + // } + // m := instance.newMessage( + // &ua.OpenSecureChannelRequest{ + // RequestHeader: &ua.RequestHeader{ + // AuthenticationToken: ua.NewTwoByteNodeID(0), + // Timestamp: time.Date(2024, time.May, 30, 16, 17, 44, 0, time.Local), + // RequestHandle: 1, + // ReturnDiagnostics: 0, + // TimeoutHint: 10000, + // AdditionalHeader: ua.NewExtensionObject(nil), + // }, + // ClientProtocolVersion: 0, + // RequestType: ua.SecurityTokenRequestTypeIssue, + // SecurityMode: ua.MessageSecurityModeNone, + // ClientNonce: []byte{}, + // RequestedLifetime: 3600000, + // }, + // id.OpenSecureChannelRequest_Encoding_DefaultBinary, + // s.nextRequestID(), + // ) + + // // set message size manually, since it is computed in Encode + // // otherwise, the decode tests failed. + // m.Header.MessageSize = 119 + + // return m + // }(), + // Bytes: []byte{ // OpenSecureChannelRequest + // // Message Header + // // MessageType: OPN + // 0x4f, 0x50, 0x4e, + // // Chunk Type: Final + // 0x46, + // 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x31, 0x31, 0x38, 0x2e, 0x31, 0x39, 0x39, 0x3a, 0x34, 0x38, 0x34, 0x30, 0x2f, 0x46, 0x4f, 0x4f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0xbe, 0x01, 0x00, 0x00, 0x00, 0x04, 0xd5, 0xd8, 0x69, 0xb2, 0xda, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0xee, 0x36, 0x00, + // }, + // }, } RunCodecTest(t, cases) } @@ -495,3 +546,248 @@ func BenchmarkEncodeMessage(b *testing.B) { } } } + +func BenchmarkEncodeMessage_WithCodec(b *testing.B) { + cases := []CodecTestCase{ + { + Name: "OPN", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.OpenSecureChannelRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + ClientProtocolVersion: 0, + RequestType: ua.SecurityTokenRequestTypeIssue, + SecurityMode: ua.MessageSecurityModeNone, + RequestedLifetime: 6000000, + }, + id.OpenSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 131 + + return m + }(), + Bytes: []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: OPN + 0x4f, 0x50, 0x4e, + // Chunk Type: Final + 0x46, + // MessageSize: 131 + 0x83, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // AsymmetricSecurityHeader + // SecurityPolicyURILength + 0x2e, 0x00, 0x00, 0x00, + // SecurityPolicyURI + 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x67, + 0x6f, 0x70, 0x63, 0x75, 0x61, 0x2e, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x4f, 0x50, + 0x43, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x63, 0x75, + 0x72, 0x69, 0x74, 0x79, 0x50, 0x6f, 0x6c, 0x69, + 0x63, 0x79, 0x23, 0x46, 0x6f, 0x6f, + // SenderCertificate + 0xff, 0xff, 0xff, 0xff, + // ReceiverCertificateThumbprint + 0xff, 0xff, 0xff, 0xff, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xbe, 0x01, + + // RequestHeader + // - AuthenticationToken + 0x00, 0x00, + // - Timestamp + 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, 0xd4, 0x01, + // - RequestHandle + 0x01, 0x00, 0x00, 0x00, + // - ReturnDiagnostics + 0xff, 0x03, 0x00, 0x00, + // - AuditEntry + 0xff, 0xff, 0xff, 0xff, + // - TimeoutHint + 0x00, 0x00, 0x00, 0x00, + // - AdditionalHeader + // - TypeID + 0x00, 0x00, + // - EncodingMask + 0x00, + // ClientProtocolVersion + 0x00, 0x00, 0x00, 0x00, + // SecurityTokenRequestType + 0x00, 0x00, 0x00, 0x00, + // MessageSecurityMode + 0x01, 0x00, 0x00, 0x00, + // ClientNonce + 0xff, 0xff, 0xff, 0xff, + // RequestedLifetime + 0x80, 0x8d, 0x5b, 0x00, + }, + }, + { + Name: "MSG", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.GetEndpointsRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", + }, + id.GetEndpointsRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 107 + + return m + }(), + Bytes: []byte{ // GetEndpointsRequest + // Message Header + // MessageType: MSG + 0x4d, 0x53, 0x47, + // Chunk Type: Final + 0x46, + // MessageSize: 107 + 0x6b, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // SymmetricSecurityHeader + // TokenID + 0x00, 0x00, 0x00, 0x00, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xac, 0x01, + // RequestHeader + 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, + 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + // ClientProtocolVersion + 0x26, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, + 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x77, 0x6f, + 0x77, 0x2e, 0x69, 0x74, 0x73, 0x2e, 0x65, 0x61, + 0x73, 0x79, 0x3a, 0x31, 0x31, 0x31, 0x31, 0x31, + 0x2f, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, + // LocaleIDs + 0xff, 0xff, 0xff, 0xff, + // ProfileURIs + 0xff, 0xff, 0xff, 0xff, + }, + }, { + Name: "CLO", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.CloseSecureChannelRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + }, + id.CloseSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 57 + + return m + }(), + Bytes: []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: CLO + 0x43, 0x4c, 0x4f, + // Chunk Type: Final + 0x46, + // MessageSize: 57 + 0x39, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // SymmetricSecurityHeader + // TokenID + 0x00, 0x00, 0x00, 0x00, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xc4, 0x01, + // RequestHeader + 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, + 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range cases { + _, err := codec.Marshal(tc.Struct) + if err != nil { + b.Fatalf("fail to encode message, err: %v", err) + } + } + } +} diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 78822642..10e0cf46 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -771,7 +771,7 @@ func (s *SecureChannel) sendAsyncWithTimeout( s.handlersMu.Unlock() } - chunks, err := m.EncodeChunks(instance.maxBodySize) + chunks, err := m.MarshalChunks(instance.maxBodySize) if err != nil { return nil, err } diff --git a/uasc/secure_channel_instance.go b/uasc/secure_channel_instance.go index bf123eb8..0a76f46d 100644 --- a/uasc/secure_channel_instance.go +++ b/uasc/secure_channel_instance.go @@ -71,6 +71,7 @@ func (c *channelInstance) newRequestMessage(req ua.Request, reqID uint32, authTo AuthenticationToken: authToken, Timestamp: c.sc.timeNow(), RequestHandle: reqID, // TODO: can I cheat like this? + AdditionalHeader: ua.NewExtensionObject(nil), } if timeout > 0 && timeout < c.sc.cfg.RequestTimeout { From 910db7e05e1b64720109d43927f9bbb1ca6c9769 Mon Sep 17 00:00:00 2001 From: yuanliang Date: Sat, 1 Jun 2024 14:07:37 +0800 Subject: [PATCH 08/14] fix: remove invalid code and fix unittest Signed-off-by: yuanliang --- codec/encode.go | 23 ++- go.sum | 3 + ua/codec_test.go | 12 +- ua/datatypes.go | 45 +----- ua/decode.go | 4 + ua/decode_test.go | 11 +- ua/diagnostic_info.go | 25 --- ua/encode.go | 168 -------------------- ua/expanded_node_id.go | 10 -- ua/extension_object.go | 22 +-- ua/node_id.go | 23 --- ua/stream.go | 166 ------------------- ua/variant.go | 93 +---------- uacp/codec_test.go | 8 +- uacp/conn.go | 21 ++- uacp/conn_test.go | 10 +- uacp/uacp.go | 51 +----- uasc/asymmetric_security_header.go | 6 - uasc/header.go | 11 -- uasc/message.go | 83 ---------- uasc/message_test.go | 247 ----------------------------- uasc/secure_channel_test.go | 3 + uasc/sequence_header.go | 5 - uasc/symmetric_security_header.go | 4 - 24 files changed, 74 insertions(+), 980 deletions(-) delete mode 100644 ua/encode.go delete mode 100644 ua/stream.go diff --git a/codec/encode.go b/codec/encode.go index cc4c88c0..bcc01480 100644 --- a/codec/encode.go +++ b/codec/encode.go @@ -276,9 +276,9 @@ func (e *encodeState) writeUint64(n uint64) { func boolEncoder(e *encodeState, v reflect.Value) { val := v.Bool() if val { - e.WriteByte('1') + e.WriteByte(1) } else { - e.WriteByte('0') + e.WriteByte(0) } } @@ -566,17 +566,26 @@ type field struct { func typeFields(t reflect.Type) structFields { var fields []field + // Visit the fields of the given type. visitField := func(f reflect.StructField) { t := f.Type + // return marshalerEncoder directly, if it implements Marshaler. + if t.Implements(marshalerType) { + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: marshalerEncoder}) + return + } + // time.Time is special because it has embedded structs that use timeEncoder. - if t == timeType || (t.Kind() == reflect.Pointer && t.Elem() == timeType) { + if t.AssignableTo(timeType) || (t.Kind() == reflect.Pointer && t.Elem().AssignableTo(timeType)) { fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: timeEncoder}) return } - // 如果t实现了Marshaler接口,则使用Marshaler的MarshalOPCUA方法 - if t.Implements(marshalerType) { - fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: marshalerEncoder}) - return + if t.ConvertibleTo(timeType) { + converted := reflect.New(t).Elem().Convert(timeType) + if _, ok := converted.Interface().(time.Time); ok { + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: timeEncoder}) + return + } } // Check for anonymous fields (embedded structs). diff --git a/go.sum b/go.sum index ff4affa7..338fc88f 100644 --- a/go.sum +++ b/go.sum @@ -3,5 +3,8 @@ github.com/pascaldekloe/goe v0.1.1/go.mod h1:KSyfaxQOh0HZPjDP1FL/kFtbqYqrALJTaMa github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= diff --git a/ua/codec_test.go b/ua/codec_test.go index 3ba7bd9c..5b6beea6 100644 --- a/ua/codec_test.go +++ b/ua/codec_test.go @@ -10,7 +10,9 @@ import ( "reflect" "testing" + "github.com/gopcua/opcua/codec" "github.com/pascaldekloe/goe/verify" + "github.com/stretchr/testify/assert" ) // CodecTestCase describes a test case for a encoding and decoding an @@ -53,12 +55,12 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - s := NewStream(DefaultBufSize) - s.WriteAny(c.Struct) - if s.Error() != nil { - t.Fatal(s.Error()) + b, err := codec.Marshal(c.Struct) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) } - verify.Values(t, "", s.Bytes(), c.Bytes) + assert.Equal(t, c.Bytes, b) + // verify.Values(t, "", b, c.Bytes) }) }) } diff --git a/ua/datatypes.go b/ua/datatypes.go index d3954366..cf6d4532 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -64,29 +64,6 @@ func (d *DataValue) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DataValue) Encode(s *Stream) { - s.WriteUint8(d.EncodingMask) - - if d.Has(DataValueValue) { - s.WriteAny(d.Value) - } - if d.Has(DataValueStatusCode) { - s.WriteUint32(uint32(d.Status)) - } - if d.Has(DataValueSourceTimestamp) { - s.WriteTime(d.SourceTimestamp) - } - if d.Has(DataValueSourcePicoseconds) { - s.WriteUint16(d.SourcePicoseconds) - } - if d.Has(DataValueServerTimestamp) { - s.WriteTime(d.ServerTimestamp) - } - if d.Has(DataValueServerPicoseconds) { - s.WriteUint16(d.ServerPicoseconds) - } -} - func (d *DataValue) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer var err error @@ -183,11 +160,13 @@ func (g *GUID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (g *GUID) Encode(s *Stream) { - s.WriteUint32(g.Data1) - s.WriteUint16(g.Data2) - s.WriteUint16(g.Data3) - s.Write(g.Data4) +func (g *GUID) MarshalOPCUA() ([]byte, error) { + buf := make([]byte, 0, 8+len(g.Data4)) + buf = binary.LittleEndian.AppendUint32(buf, g.Data1) + buf = binary.LittleEndian.AppendUint16(buf, g.Data2) + buf = binary.LittleEndian.AppendUint16(buf, g.Data3) + buf = append(buf, g.Data4...) + return buf, nil } // String returns GUID in human-readable string. @@ -256,16 +235,6 @@ func (l *LocalizedText) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (l *LocalizedText) Encode(s *Stream) { - s.WriteUint8(l.EncodingMask) - if l.Has(LocalizedTextLocale) { - s.WriteString(l.Locale) - } - if l.Has(LocalizedTextText) { - s.WriteString(l.Text) - } -} - func (l *LocalizedText) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer var err error diff --git a/ua/decode.go b/ua/decode.go index 0a679d5d..9bcc059c 100644 --- a/ua/decode.go +++ b/ua/decode.go @@ -10,6 +10,7 @@ import ( "reflect" "time" + "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" ) @@ -35,6 +36,9 @@ func Decode(b []byte, v interface{}) (int, error) { return decode(b, val, val.Type().String()) } +// debugCodec enables printing of debug messages in the opcua codec. +var debugCodec = debug.FlagSet("codec") + func decode(b []byte, val reflect.Value, name string) (n int, err error) { if debugCodec { fmt.Printf("decode: %s has type %v and is a %s, %d bytes\n", name, val.Type(), val.Type().Kind(), len(b)) diff --git a/ua/decode_test.go b/ua/decode_test.go index 9ef2b14a..63b716e5 100644 --- a/ua/decode_test.go +++ b/ua/decode_test.go @@ -9,6 +9,8 @@ import ( "reflect" "testing" "time" + + "github.com/gopcua/opcua/codec" ) type A struct { @@ -403,12 +405,11 @@ func TestCodec(t *testing.T) { } }) t.Run("encode", func(t *testing.T) { - s := NewStream(DefaultBufSize) - s.WriteAny(tt.v) - if s.Error() != nil { - t.Fatal(s.Error()) + b, err := codec.Marshal(tt.v) + if err != nil { + t.Fatal(err) } - if got, want := s.Bytes(), tt.b; !bytes.Equal(got, want) { + if got, want := b, tt.b; !bytes.Equal(got, want) { t.Fatalf("\ngot %#v\nwant %#v", got, want) } }) diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index e5457a9e..48f4f8d8 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -64,31 +64,6 @@ func (d *DiagnosticInfo) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DiagnosticInfo) Encode(s *Stream) { - s.WriteByte(d.EncodingMask) - if d.Has(DiagnosticInfoSymbolicID) { - s.WriteInt32(d.SymbolicID) - } - if d.Has(DiagnosticInfoNamespaceURI) { - s.WriteInt32(d.NamespaceURI) - } - if d.Has(DiagnosticInfoLocale) { - s.WriteInt32(d.Locale) - } - if d.Has(DiagnosticInfoLocalizedText) { - s.WriteInt32(d.LocalizedText) - } - if d.Has(DiagnosticInfoAdditionalInfo) { - s.WriteString(d.AdditionalInfo) - } - if d.Has(DiagnosticInfoInnerStatusCode) { - s.WriteUint32(uint32(d.InnerStatusCode)) - } - if d.Has(DiagnosticInfoInnerDiagnosticInfo) { - s.WriteAny(d.InnerDiagnosticInfo) - } -} - func (d *DiagnosticInfo) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer buf.WriteByte(d.EncodingMask) diff --git a/ua/encode.go b/ua/encode.go deleted file mode 100644 index 3f66cd22..00000000 --- a/ua/encode.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2018-2020 opcua authors. All rights reserved. -// Use of this source code is governed by a MIT-style license that can be -// found in the LICENSE file. - -package ua - -import ( - "fmt" - "math" - "reflect" - "time" - - "github.com/gopcua/opcua/debug" - "github.com/gopcua/opcua/errors" -) - -// debugCodec enables printing of debug messages in the opcua codec. -var debugCodec = debug.FlagSet("codec") - -// BinaryEncoder is the interface implemented by an object that can -// marshal itself into a binary OPC/UA representation. -type BinaryEncoder interface { - Encode(s *Stream) -} - -var binaryEncoder = reflect.TypeOf((*BinaryEncoder)(nil)).Elem() - -func isBinaryEncoder(val reflect.Value) bool { - return val.Type().Implements(binaryEncoder) -} - -func (s *Stream) WriteAny(w interface{}) { - if s.err != nil { - return - } - val := reflect.ValueOf(w) - switch x := w.(type) { - case BinaryEncoder: - x.Encode(s) - default: - s.encode(val, val.Type().String()) - } -} - -func (s *Stream) encode(val reflect.Value, name string) { - if debugCodec { - fmt.Printf("encode: %s has type %s and is a %s\n", name, val.Type(), val.Type().Kind()) - } - - switch { - case isBinaryEncoder(val): - v := val.Interface().(BinaryEncoder) - v.Encode(s) - - case isTime(val): - s.WriteTime(val.Convert(timeType).Interface().(time.Time)) - - default: - switch val.Kind() { - case reflect.Bool: - s.WriteBool(val.Bool()) - case reflect.Int8: - s.WriteInt8(int8(val.Int())) - case reflect.Uint8: - s.WriteUint8(uint8(val.Uint())) - case reflect.Int16: - s.WriteInt16(int16(val.Int())) - case reflect.Uint16: - s.WriteUint16(uint16(val.Uint())) - case reflect.Int32: - s.WriteInt32(int32(val.Int())) - case reflect.Uint32: - s.WriteUint32(uint32(val.Uint())) - case reflect.Int64: - s.WriteInt64(int64(val.Int())) - case reflect.Uint64: - s.WriteUint64(uint64(val.Uint())) - case reflect.Float32: - s.WriteFloat32(float32(val.Float())) - case reflect.Float64: - s.WriteFloat64(float64(val.Float())) - case reflect.String: - s.WriteString(val.String()) - case reflect.Ptr: - if val.IsNil() { - return - } - s.encode(val.Elem(), name) - case reflect.Struct: - s.writeStruct(val, name) - case reflect.Slice: - s.writeSlice(val, name) - case reflect.Array: - s.writeArray(val, name) - default: - s.WrapError(errors.Errorf("unsupported type: %s", val.Type())) - } - } -} - -func (s *Stream) writeStruct(val reflect.Value, name string) { - valt := val.Type() - for i := 0; i < val.NumField(); i++ { - ft := valt.Field(i) - fname := name + "." + ft.Name - if s.encode(val.Field(i), fname); s.err != nil { - return - } - } -} - -func (s *Stream) writeSlice(val reflect.Value, name string) { - if val.IsNil() { - s.WriteUint32(null) - return - } - - if val.Len() > math.MaxInt32 { - s.WrapError(errors.Errorf("array too large")) - return - } - - s.WriteUint32(uint32(val.Len())) - - // fast path for []byte - if val.Type().Elem().Kind() == reflect.Uint8 { - // fmt.Println("[]byte fast path") - s.Write(val.Bytes()) - return - } - - // loop over elements - for i := 0; i < val.Len(); i++ { - ename := fmt.Sprintf("%s[%d]", name, i) - s.encode(val.Index(i), ename) - if s.Error() != nil { - return - } - } -} - -func (s *Stream) writeArray(val reflect.Value, name string) error { - if val.Len() > math.MaxInt32 { - return errors.Errorf("array too large: %d > %d", val.Len(), math.MaxInt32) - } - - s.WriteUint32(uint32(val.Len())) - - // fast path for []byte - if val.Type().Elem().Kind() == reflect.Uint8 { - // fmt.Println("encode: []byte fast path") - b := make([]byte, val.Len()) - reflect.Copy(reflect.ValueOf(b), val) - s.Write(b) - return s.Error() - } - - // loop over elements - // we write all the elements, also the zero values - for i := 0; i < val.Len(); i++ { - ename := fmt.Sprintf("%s[%d]", name, i) - s.encode(val.Index(i), ename) - if s.Error() != nil { - return s.Error() - } - } - return s.Error() -} diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index 257dc29c..ed4ec92b 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -104,16 +104,6 @@ func (e *ExpandedNodeID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *ExpandedNodeID) Encode(s *Stream) { - s.WriteAny(e.NodeID) - if e.HasNamespaceURI() { - s.WriteString(e.NamespaceURI) - } - if e.HasServerIndex() { - s.WriteUint32(e.ServerIndex) - } -} - func (e *ExpandedNodeID) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer b, err := e.NodeID.MarshalOPCUA() diff --git a/ua/extension_object.go b/ua/extension_object.go index 82e3ee6b..5e33e131 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -87,27 +87,7 @@ func (e *ExtensionObject) Decode(b []byte) (int, error) { return buf.Pos(), body.Error() } -func (e *ExtensionObject) Encode(s *Stream) { - if e == nil { - e = &ExtensionObject{TypeID: NewTwoByteExpandedNodeID(0), EncodingMask: ExtensionObjectEmpty} - } - s.WriteAny(e.TypeID) - s.WriteByte(e.EncodingMask) - if e.EncodingMask == ExtensionObjectEmpty { - return - } - - body := NewStream(DefaultBufSize) - body.WriteAny(e.Value) - if body.Error() != nil { - s.WrapError(body.Error()) - return - } - s.WriteUint32(uint32(body.Len())) - s.Write(body.Bytes()) -} - -func (e *ExtensionObject) MarshalJSON() ([]byte, error) { +func (e *ExtensionObject) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer b, err := codec.Marshal(e.TypeID) buf.Write(b) diff --git a/ua/node_id.go b/ua/node_id.go index c9934aaf..9a195c3e 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -368,29 +368,6 @@ func (n *NodeID) Decode(b []byte) (int, error) { } } -func (n *NodeID) Encode(s *Stream) { - s.WriteByte(byte(n.mask)) - - switch n.Type() { - case NodeIDTypeTwoByte: - s.WriteByte(byte(n.nid)) - case NodeIDTypeFourByte: - s.WriteByte(byte(n.ns)) - s.WriteUint16(uint16(n.nid)) - case NodeIDTypeNumeric: - s.WriteUint16(n.ns) - s.WriteUint32(n.nid) - case NodeIDTypeGUID: - s.WriteUint16(n.ns) - s.WriteAny(n.gid) - case NodeIDTypeByteString, NodeIDTypeString: - s.WriteUint16(n.ns) - s.WriteByteString(n.bid) - default: - s.err = errors.Errorf("invalid node id type %v", n.Type()) - } -} - func (n *NodeID) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer buf.WriteByte(byte(n.mask)) diff --git a/ua/stream.go b/ua/stream.go deleted file mode 100644 index 27063e0a..00000000 --- a/ua/stream.go +++ /dev/null @@ -1,166 +0,0 @@ -package ua - -import ( - "encoding/binary" - "io" - "math" - "time" - - "github.com/gopcua/opcua/errors" -) - -const DefaultBufSize = 256 - -type Stream struct { - buf []byte - pos int - err error -} - -func NewStream(size int) *Stream { - return &Stream{ - buf: make([]byte, 0, size), - } -} - -func (s *Stream) WrapError(err error) { - s.err = errors.Join(err) -} - -func (s *Stream) Error() error { - return s.err -} - -func (s *Stream) Len() int { - return len(s.buf) -} - -func (s *Stream) Reset() { - s.buf = s.buf[:0] - s.pos = 0 - s.err = nil -} - -func (s *Stream) Bytes() []byte { - return s.buf -} - -func (b *Stream) ReadN(n int) []byte { - if b.err != nil { - return nil - } - d := b.buf[b.pos:] - if n > len(d) { - b.err = io.ErrUnexpectedEOF - return nil - } - b.pos += n - return d[:n] -} - -func (b *Stream) WriteBool(v bool) { - if v { - b.WriteUint8(1) - } else { - b.WriteUint8(0) - } -} - -func (b *Stream) WriteByte(n byte) { - b.buf = append(b.buf, n) -} - -func (b *Stream) WriteInt8(n int8) { - b.buf = append(b.buf, byte(n)) -} - -func (b *Stream) WriteUint8(n uint8) { - b.buf = append(b.buf, byte(n)) -} - -func (b *Stream) WriteInt16(n int16) { - b.WriteUint16(uint16(n)) -} - -func (b *Stream) WriteUint16(n uint16) { - d := make([]byte, 2) - binary.LittleEndian.PutUint16(d, n) - b.Write(d) -} - -func (b *Stream) WriteInt32(n int32) { - b.WriteUint32(uint32(n)) -} - -func (b *Stream) WriteUint32(n uint32) { - d := make([]byte, 4) - binary.LittleEndian.PutUint32(d, n) - b.Write(d) -} - -func (b *Stream) WriteInt64(n int64) { - b.WriteUint64(uint64(n)) -} - -func (b *Stream) WriteUint64(n uint64) { - d := make([]byte, 8) - binary.LittleEndian.PutUint64(d, n) - b.Write(d) -} - -func (b *Stream) WriteFloat32(n float32) { - if math.IsNaN(float64(n)) { - b.WriteUint32(f32qnan) - } else { - b.WriteUint32(math.Float32bits(n)) - } -} - -func (b *Stream) WriteFloat64(n float64) { - if math.IsNaN(n) { - b.WriteUint64(f64qnan) - } else { - b.WriteUint64(math.Float64bits(n)) - } -} - -func (b *Stream) WriteString(s string) { - if s == "" { - b.WriteUint32(null) - return - } - b.WriteByteString([]byte(s)) -} - -func (b *Stream) WriteByteString(d []byte) { - if b.err != nil { - return - } - if len(d) > math.MaxInt32 { - b.err = errors.Errorf("value too large") - return - } - if d == nil { - b.WriteUint32(null) - return - } - b.WriteUint32(uint32(len(d))) - b.Write(d) -} - -func (b *Stream) WriteTime(v time.Time) { - d := make([]byte, 8) - if !v.IsZero() { - // encode time in "100 nanosecond intervals since January 1, 1601" - ts := uint64(v.UTC().UnixNano()/100 + 116444736000000000) - binary.LittleEndian.PutUint64(d, ts) - } - b.Write(d) -} - -func (b *Stream) Write(d []byte) { - if b.err != nil { - return - } - b.buf = append(b.buf, d...) -} diff --git a/ua/variant.go b/ua/variant.go index dd9f8dfa..a545bfc2 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -319,29 +319,6 @@ func (m *Variant) decodeValue(buf *Buffer) interface{} { } } -// Encode implements the codec interface. -func (m *Variant) Encode(s *Stream) { - s.WriteByte(m.mask) - - // a null value specifies that no other fields are encoded - if m.Type() == TypeIDNull { - return - } - - if m.Has(VariantArrayValues) { - s.WriteInt32(m.arrayLength) - } - - m.encode(s, reflect.ValueOf(m.value)) - - if m.Has(VariantArrayDimensions) { - s.WriteInt32(m.arrayDimensionsLength) - for i := 0; i < int(m.arrayDimensionsLength); i++ { - s.WriteInt32(m.arrayDimensions[i]) - } - } -} - func (m *Variant) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer buf.WriteByte(m.mask) @@ -354,12 +331,7 @@ func (m *Variant) MarshalOPCUA() ([]byte, error) { if m.Has(VariantArrayValues) { buf.Write([]byte{byte(m.arrayLength), byte(m.arrayLength >> 8), byte(m.arrayLength >> 16), byte(m.arrayLength >> 24)}) } - - b, err := codec.Marshal(m.value) - if err != nil { - return buf.Bytes(), err - } - buf.Write(b) + m.encode(&buf, reflect.ValueOf(m.value)) if m.Has(VariantArrayDimensions) { buf.Write([]byte{byte(m.arrayDimensionsLength), byte(m.arrayDimensionsLength >> 8), byte(m.arrayDimensionsLength >> 16), byte(m.arrayDimensionsLength >> 24)}) @@ -371,69 +343,14 @@ func (m *Variant) MarshalOPCUA() ([]byte, error) { } // encode recursively writes the values to the buffer. -func (m *Variant) encode(s *Stream, val reflect.Value) { +func (m *Variant) encode(buf *bytes.Buffer, val reflect.Value) { if val.Kind() != reflect.Slice || m.Type() == TypeIDByteString { - m.encodeValue(s, val.Interface()) + b, _ := codec.Marshal(val.Interface()) + buf.Write(b) return } for i := 0; i < val.Len(); i++ { - m.encode(s, val.Index(i)) - } -} - -// encodeValue writes a single value of the base type to the buffer. -func (m *Variant) encodeValue(s *Stream, v interface{}) { - switch x := v.(type) { - case bool: - s.WriteBool(x) - case int8: - s.WriteInt8(x) - case byte: - s.WriteByte(x) - case int16: - s.WriteInt16(x) - case uint16: - s.WriteUint16(x) - case int32: - s.WriteInt32(x) - case uint32: - s.WriteUint32(x) - case int64: - s.WriteInt64(x) - case uint64: - s.WriteUint64(x) - case float32: - s.WriteFloat32(x) - case float64: - s.WriteFloat64(x) - case string: - s.WriteString(x) - case time.Time: - s.WriteTime(x) - case *GUID: - s.WriteAny(x) - case []byte: - s.WriteByteString(x) - case XMLElement: - s.WriteString(string(x)) - case *NodeID: - s.WriteAny(x) - case *ExpandedNodeID: - s.WriteAny(x) - case StatusCode: - s.WriteUint32(uint32(x)) - case *QualifiedName: - s.WriteAny(x) - case *LocalizedText: - s.WriteAny(x) - case *ExtensionObject: - s.WriteAny(x) - case *DataValue: - s.WriteAny(x) - case *Variant: - s.WriteAny(x) - case *DiagnosticInfo: - s.WriteAny(x) + m.encode(buf, val.Index(i)) } } diff --git a/uacp/codec_test.go b/uacp/codec_test.go index b8a590f9..4343ae49 100644 --- a/uacp/codec_test.go +++ b/uacp/codec_test.go @@ -10,6 +10,7 @@ import ( "reflect" "testing" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/ua" "github.com/pascaldekloe/goe/verify" ) @@ -54,10 +55,9 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { }) t.Run("encode", func(t *testing.T) { - s := ua.NewStream(ua.DefaultBufSize) - s.WriteAny(c.Struct) - if s.Error() != nil { - t.Fatalf("fail to encode message, err: %v", s.Error()) + _, err := codec.Marshal(c.Struct) + if err != nil { + t.Fatalf("fail to encode message, err: %v", err) } }) }) diff --git a/uacp/conn.go b/uacp/conn.go index b7018fde..ce237923 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" @@ -396,31 +397,27 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("invalid msg type: %s", typ) } - bodyStream := ua.NewStream(ua.DefaultBufSize) - bodyStream.WriteAny(msg) - if bodyStream.Error() != nil { - return errors.Errorf("encode msg failed: %s", bodyStream.Error()) + body, err := codec.Marshal(msg) + if err != nil { + return errors.Errorf("encode msg failed: %v", err) } h := Header{ MessageType: typ[:3], ChunkType: typ[3], - MessageSize: uint32(bodyStream.Len() + hdrlen), + MessageSize: uint32(len(body) + hdrlen), } if h.MessageSize > c.ack.SendBufSize { return errors.Errorf("send packet too large: %d > %d bytes", h.MessageSize, c.ack.SendBufSize) } - headerStream := ua.NewStream(ua.DefaultBufSize) - h.Encode(headerStream) - if headerStream.Error() != nil { - return errors.Errorf("encode hdr failed: %s", headerStream.Error()) + hdr, err := h.MarshalOPCUA() + if err != nil { + return errors.Errorf("encode hdr failed: %v", err) } - b := make([]byte, 0, headerStream.Len()+bodyStream.Len()) - b = append(b, headerStream.Bytes()...) - b = append(b, bodyStream.Bytes()...) + b := append(hdr, body...) if _, err := c.Write(b); err != nil { return errors.Errorf("write failed: %s", err) } diff --git a/uacp/conn_test.go b/uacp/conn_test.go index 5620569e..50bcefc9 100644 --- a/uacp/conn_test.go +++ b/uacp/conn_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" - "github.com/gopcua/opcua/ua" "github.com/pascaldekloe/goe/verify" ) @@ -121,12 +121,10 @@ NEXT: } got = got[hdrlen:] - s := ua.NewStream(ua.DefaultBufSize) - msg.Encode(s) - if s.Error() != nil { - t.Fatal(s.Error()) + want, err := codec.Marshal(msg) + if err != nil { + t.Fatal(err) } - want := s.Bytes() verify.Values(t, "", got, want) } diff --git a/uacp/uacp.go b/uacp/uacp.go index 6b921474..60c0d017 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -48,16 +48,6 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) Encode(s *ua.Stream) { - if len(h.MessageType) != 3 { - s.WrapError(errors.Errorf("invalid message type: %q", h.MessageType)) - return - } - s.Write([]byte(h.MessageType)) - s.WriteByte(h.ChunkType) - s.WriteUint32(h.MessageSize) -} - func (h *Header) MarshalOPCUA() ([]byte, error) { if len(h.MessageType) != 3 { return nil, errors.Errorf("invalid message type: %q", h.MessageType) @@ -93,15 +83,6 @@ func (h *Hello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Hello) Encode(s *ua.Stream) { - s.WriteUint32(h.Version) - s.WriteUint32(h.ReceiveBufSize) - s.WriteUint32(h.SendBufSize) - s.WriteUint32(h.MaxMessageSize) - s.WriteUint32(h.MaxChunkCount) - s.WriteString(h.EndpointURL) -} - func (h *Hello) MarshalOPCUA() ([]byte, error) { var buf bytes.Buffer buf.Write([]byte{byte(h.Version), byte(h.Version >> 8), byte(h.Version >> 16), byte(h.Version >> 24)}) @@ -140,21 +121,13 @@ func (a *Acknowledge) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (a *Acknowledge) Encode(s *ua.Stream) { - s.WriteUint32(a.Version) - s.WriteUint32(a.ReceiveBufSize) - s.WriteUint32(a.SendBufSize) - s.WriteUint32(a.MaxMessageSize) - s.WriteUint32(a.MaxChunkCount) -} - func (a *Acknowledge) MarshalOPCUA() ([]byte, error) { buf := make([]byte, 0, 160) - binary.LittleEndian.AppendUint32(buf, a.Version) - binary.LittleEndian.AppendUint32(buf, a.ReceiveBufSize) - binary.LittleEndian.AppendUint32(buf, a.SendBufSize) - binary.LittleEndian.AppendUint32(buf, a.MaxMessageSize) - binary.LittleEndian.AppendUint32(buf, a.MaxChunkCount) + buf = binary.LittleEndian.AppendUint32(buf, a.Version) + buf = binary.LittleEndian.AppendUint32(buf, a.ReceiveBufSize) + buf = binary.LittleEndian.AppendUint32(buf, a.SendBufSize) + buf = binary.LittleEndian.AppendUint32(buf, a.MaxMessageSize) + buf = binary.LittleEndian.AppendUint32(buf, a.MaxChunkCount) return buf, nil } @@ -173,11 +146,6 @@ func (r *ReverseHello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (r *ReverseHello) Encode(s *ua.Stream) { - s.WriteString(r.ServerURI) - s.WriteString(r.EndpointURL) -} - // Error represents a OPC UA Error. // // Specification: Part6, 7.1.2.5 @@ -193,11 +161,6 @@ func (e *Error) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *Error) Encode(s *ua.Stream) { - s.WriteUint32(e.ErrorCode) - s.WriteString(e.Reason) -} - func (e *Error) Error() string { return ua.StatusCode(e.ErrorCode).Error() } @@ -215,7 +178,3 @@ func (m *Message) Decode(b []byte) (int, error) { m.Data = b return len(b), nil } - -func (m *Message) Encode(s *ua.Stream) { - s.Write(m.Data) -} diff --git a/uasc/asymmetric_security_header.go b/uasc/asymmetric_security_header.go index 713c0aac..856f8abe 100644 --- a/uasc/asymmetric_security_header.go +++ b/uasc/asymmetric_security_header.go @@ -34,12 +34,6 @@ func (h *AsymmetricSecurityHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *AsymmetricSecurityHeader) Encode(s *ua.Stream) { - s.WriteString(h.SecurityPolicyURI) - s.WriteByteString(h.SenderCertificate) - s.WriteByteString(h.ReceiverCertificateThumbprint) -} - // String returns Header in string. func (a *AsymmetricSecurityHeader) String() string { return fmt.Sprintf( diff --git a/uasc/header.go b/uasc/header.go index 20444877..f4e91b74 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -52,17 +52,6 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) Encode(s *ua.Stream) { - if len(h.MessageType) != 3 { - s.WrapError(errors.Errorf("invalid message type: %q", h.MessageType)) - return - } - s.Write([]byte(h.MessageType)) - s.WriteByte(h.ChunkType) - s.WriteUint32(h.MessageSize) - s.WriteUint32(h.SecureChannelID) -} - func (h *Header) MarshalOPCUA() ([]byte, error) { if len(h.MessageType) != 3 { return nil, errors.Errorf("invalid message type: %q", h.MessageType) diff --git a/uasc/message.go b/uasc/message.go index 1bc23566..75e3caae 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -5,7 +5,6 @@ package uasc import ( - "fmt" "math" "github.com/gopcua/opcua/codec" @@ -76,11 +75,6 @@ func (m *MessageAbort) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (m *MessageAbort) Encode(s *ua.Stream) { - s.WriteUint32(m.ErrorCode) - s.WriteString(m.Reason) -} - func (m *MessageAbort) MessageAbort() string { return ua.StatusCode(m.ErrorCode).Error() } @@ -112,83 +106,6 @@ func (m *Message) Decode(b []byte) (int, error) { return len(b), err } -func (m *Message) Encode(s *ua.Stream) { - chunks, err := m.EncodeChunks(math.MaxUint32) - if err != nil { - return - } - s.Write(chunks[0]) -} - -func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { - dataBody := ua.NewStream(ua.DefaultBufSize) - dataBody.WriteAny(m.TypeID) - dataBody.WriteAny(m.Service) - if dataBody.Error() != nil { - return nil, errors.Errorf("failed to encode databody: %s", dataBody.Error()) - } - - nrChunks := uint32(dataBody.Len())/(maxBodySize) + 1 - chunks := make([][]byte, nrChunks) - - switch m.Header.MessageType { - case "OPN": - partialHeader := ua.NewStream(ua.DefaultBufSize) - partialHeader.WriteAny(m.AsymmetricSecurityHeader) - partialHeader.WriteAny(m.SequenceHeader) - if partialHeader.Error() != nil { - return nil, errors.Errorf("failed to encode partial header: %s", partialHeader.Error()) - } - - m.Header.MessageSize = uint32(12 + partialHeader.Len() + dataBody.Len()) - buf := ua.NewStream(ua.DefaultBufSize) - buf.WriteAny(m.Header) - buf.Write(partialHeader.Bytes()) - buf.Write(dataBody.Bytes()) - if buf.Error() != nil { - return nil, errors.Errorf("failed to encode chunk: %s", buf.Error()) - } - for _, v := range buf.Bytes() { - fmt.Printf("0x%02x,", v) - } - fmt.Println() - return [][]byte{buf.Bytes()}, nil - - case "CLO", "MSG": - chunk := ua.NewStream(ua.DefaultBufSize) - for i := uint32(0); i < nrChunks-1; i++ { - chunk.Reset() - m.Header.MessageSize = maxBodySize + 24 - m.Header.ChunkType = ChunkTypeIntermediate - chunk.WriteAny(m.Header) - chunk.WriteAny(m.SymmetricSecurityHeader) - chunk.WriteAny(m.SequenceHeader) - chunk.Write(dataBody.ReadN(int(maxBodySize))) - if chunk.Error() != nil { - return nil, errors.Errorf("failed to encode chunk: %s", chunk.Error()) - } - - chunks[i] = append(chunks[i], chunk.Bytes()...) - } - - m.Header.ChunkType = ChunkTypeFinal - m.Header.MessageSize = uint32(24 + dataBody.Len()) - chunk.Reset() - chunk.WriteAny(m.Header) - chunk.WriteAny(m.SymmetricSecurityHeader) - chunk.WriteAny(m.SequenceHeader) - chunk.Write(dataBody.Bytes()) - if chunk.Error() != nil { - return nil, errors.Errorf("failed to encode chunk: %s", chunk.Error()) - } - - chunks[nrChunks-1] = append(chunks[nrChunks-1], chunk.Bytes()...) - return chunks, nil - default: - return nil, errors.Errorf("invalid message type %q", m.Header.MessageType) - } -} - func (m *Message) MarshalOPCUA() ([]byte, error) { chunks, err := m.MarshalChunks(math.MaxUint32) if err != nil { diff --git a/uasc/message_test.go b/uasc/message_test.go index 7056d965..1df6035a 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -534,253 +534,6 @@ func BenchmarkEncodeMessage(b *testing.B) { }, } - s := ua.NewStream(ua.DefaultBufSize) - b.ResetTimer() - for i := 0; i < b.N; i++ { - for _, tc := range cases { - s.Reset() - s.WriteAny(tc.Struct) - if s.Error() != nil { - b.Fatalf("fail to encode message, err: %v", s.Error()) - } - } - } -} - -func BenchmarkEncodeMessage_WithCodec(b *testing.B) { - cases := []CodecTestCase{ - { - Name: "OPN", - Struct: func() interface{} { - s := &SecureChannel{ - cfg: &Config{ - SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", - }, - } - instance := &channelInstance{ - sc: s, - sequenceNumber: 0, - securityTokenID: 0, - } - m := instance.newMessage( - &ua.OpenSecureChannelRequest{ - RequestHeader: &ua.RequestHeader{ - AuthenticationToken: ua.NewTwoByteNodeID(0), - Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), - RequestHandle: 1, - ReturnDiagnostics: 0x03ff, - AdditionalHeader: ua.NewExtensionObject(nil), - }, - ClientProtocolVersion: 0, - RequestType: ua.SecurityTokenRequestTypeIssue, - SecurityMode: ua.MessageSecurityModeNone, - RequestedLifetime: 6000000, - }, - id.OpenSecureChannelRequest_Encoding_DefaultBinary, - s.nextRequestID(), - ) - - // set message size manually, since it is computed in Encode - // otherwise, the decode tests failed. - m.Header.MessageSize = 131 - - return m - }(), - Bytes: []byte{ // OpenSecureChannelRequest - // Message Header - // MessageType: OPN - 0x4f, 0x50, 0x4e, - // Chunk Type: Final - 0x46, - // MessageSize: 131 - 0x83, 0x00, 0x00, 0x00, - // SecureChannelID: 0 - 0x00, 0x00, 0x00, 0x00, - // AsymmetricSecurityHeader - // SecurityPolicyURILength - 0x2e, 0x00, 0x00, 0x00, - // SecurityPolicyURI - 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x67, - 0x6f, 0x70, 0x63, 0x75, 0x61, 0x2e, 0x65, 0x78, - 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x4f, 0x50, - 0x43, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x63, 0x75, - 0x72, 0x69, 0x74, 0x79, 0x50, 0x6f, 0x6c, 0x69, - 0x63, 0x79, 0x23, 0x46, 0x6f, 0x6f, - // SenderCertificate - 0xff, 0xff, 0xff, 0xff, - // ReceiverCertificateThumbprint - 0xff, 0xff, 0xff, 0xff, - // Sequence Header - // SequenceNumber - 0x01, 0x00, 0x00, 0x00, - // RequestID - 0x01, 0x00, 0x00, 0x00, - // TypeID - 0x01, 0x00, 0xbe, 0x01, - - // RequestHeader - // - AuthenticationToken - 0x00, 0x00, - // - Timestamp - 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, 0xd4, 0x01, - // - RequestHandle - 0x01, 0x00, 0x00, 0x00, - // - ReturnDiagnostics - 0xff, 0x03, 0x00, 0x00, - // - AuditEntry - 0xff, 0xff, 0xff, 0xff, - // - TimeoutHint - 0x00, 0x00, 0x00, 0x00, - // - AdditionalHeader - // - TypeID - 0x00, 0x00, - // - EncodingMask - 0x00, - // ClientProtocolVersion - 0x00, 0x00, 0x00, 0x00, - // SecurityTokenRequestType - 0x00, 0x00, 0x00, 0x00, - // MessageSecurityMode - 0x01, 0x00, 0x00, 0x00, - // ClientNonce - 0xff, 0xff, 0xff, 0xff, - // RequestedLifetime - 0x80, 0x8d, 0x5b, 0x00, - }, - }, - { - Name: "MSG", - Struct: func() interface{} { - s := &SecureChannel{ - cfg: &Config{ - SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", - }, - } - instance := &channelInstance{ - sc: s, - sequenceNumber: 0, - securityTokenID: 0, - } - m := instance.newMessage( - &ua.GetEndpointsRequest{ - RequestHeader: &ua.RequestHeader{ - AuthenticationToken: ua.NewTwoByteNodeID(0), - Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), - RequestHandle: 1, - ReturnDiagnostics: 0x03ff, - AdditionalHeader: ua.NewExtensionObject(nil), - }, - EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", - }, - id.GetEndpointsRequest_Encoding_DefaultBinary, - s.nextRequestID(), - ) - - // set message size manually, since it is computed in Encode - // otherwise, the decode tests failed. - m.Header.MessageSize = 107 - - return m - }(), - Bytes: []byte{ // GetEndpointsRequest - // Message Header - // MessageType: MSG - 0x4d, 0x53, 0x47, - // Chunk Type: Final - 0x46, - // MessageSize: 107 - 0x6b, 0x00, 0x00, 0x00, - // SecureChannelID: 0 - 0x00, 0x00, 0x00, 0x00, - // SymmetricSecurityHeader - // TokenID - 0x00, 0x00, 0x00, 0x00, - // Sequence Header - // SequenceNumber - 0x01, 0x00, 0x00, 0x00, - // RequestID - 0x01, 0x00, 0x00, 0x00, - // TypeID - 0x01, 0x00, 0xac, 0x01, - // RequestHeader - 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, - 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, - 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, - // ClientProtocolVersion - 0x26, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, - 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x77, 0x6f, - 0x77, 0x2e, 0x69, 0x74, 0x73, 0x2e, 0x65, 0x61, - 0x73, 0x79, 0x3a, 0x31, 0x31, 0x31, 0x31, 0x31, - 0x2f, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, - // LocaleIDs - 0xff, 0xff, 0xff, 0xff, - // ProfileURIs - 0xff, 0xff, 0xff, 0xff, - }, - }, { - Name: "CLO", - Struct: func() interface{} { - s := &SecureChannel{ - cfg: &Config{ - SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", - }, - } - instance := &channelInstance{ - sc: s, - sequenceNumber: 0, - securityTokenID: 0, - } - m := instance.newMessage( - &ua.CloseSecureChannelRequest{ - RequestHeader: &ua.RequestHeader{ - AuthenticationToken: ua.NewTwoByteNodeID(0), - Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), - RequestHandle: 1, - ReturnDiagnostics: 0x03ff, - AdditionalHeader: ua.NewExtensionObject(nil), - }, - }, - id.CloseSecureChannelRequest_Encoding_DefaultBinary, - s.nextRequestID(), - ) - - // set message size manually, since it is computed in Encode - // otherwise, the decode tests failed. - m.Header.MessageSize = 57 - - return m - }(), - Bytes: []byte{ // OpenSecureChannelRequest - // Message Header - // MessageType: CLO - 0x43, 0x4c, 0x4f, - // Chunk Type: Final - 0x46, - // MessageSize: 57 - 0x39, 0x00, 0x00, 0x00, - // SecureChannelID: 0 - 0x00, 0x00, 0x00, 0x00, - // SymmetricSecurityHeader - // TokenID - 0x00, 0x00, 0x00, 0x00, - // Sequence Header - // SequenceNumber - 0x01, 0x00, 0x00, 0x00, - // RequestID - 0x01, 0x00, 0x00, 0x00, - // TypeID - 0x01, 0x00, 0xc4, 0x01, - // RequestHeader - 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, - 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, - 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, - }, - }, - } - b.ResetTimer() for i := 0; i < b.N; i++ { for _, tc := range cases { diff --git a/uasc/secure_channel_test.go b/uasc/secure_channel_test.go index c31f1d99..d0e0e285 100644 --- a/uasc/secure_channel_test.go +++ b/uasc/secure_channel_test.go @@ -65,6 +65,7 @@ func TestNewRequestMessage(t *testing.T) { AuthenticationToken: ua.NewTwoByteNodeID(0), Timestamp: fixedTime(), RequestHandle: 1, + AdditionalHeader: ua.NewExtensionObject(nil), }, }, }, @@ -103,6 +104,7 @@ func TestNewRequestMessage(t *testing.T) { AuthenticationToken: ua.NewTwoByteNodeID(0), Timestamp: fixedTime(), RequestHandle: 556, + AdditionalHeader: ua.NewExtensionObject(nil), }, }, }, @@ -137,6 +139,7 @@ func TestNewRequestMessage(t *testing.T) { AuthenticationToken: ua.NewTwoByteNodeID(0), Timestamp: fixedTime(), RequestHandle: 1, + AdditionalHeader: ua.NewExtensionObject(nil), }, }, }, diff --git a/uasc/sequence_header.go b/uasc/sequence_header.go index bd7324b5..9d37cdf4 100644 --- a/uasc/sequence_header.go +++ b/uasc/sequence_header.go @@ -31,11 +31,6 @@ func (h *SequenceHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *SequenceHeader) Encode(s *ua.Stream) { - s.WriteUint32(h.SequenceNumber) - s.WriteUint32(h.RequestID) -} - // String returns Header in string. func (s *SequenceHeader) String() string { return fmt.Sprintf( diff --git a/uasc/symmetric_security_header.go b/uasc/symmetric_security_header.go index 44d57c2a..c0ed7af2 100644 --- a/uasc/symmetric_security_header.go +++ b/uasc/symmetric_security_header.go @@ -28,10 +28,6 @@ func (h *SymmetricSecurityHeader) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *SymmetricSecurityHeader) Encode(s *ua.Stream) { - s.WriteUint32(h.TokenID) -} - // String returns Header in string. func (h *SymmetricSecurityHeader) String() string { return fmt.Sprintf( From 855932c71ffdf54f6a8137077159c53387cbead3 Mon Sep 17 00:00:00 2001 From: yuanliang Date: Sun, 2 Jun 2024 14:18:39 +0800 Subject: [PATCH 09/14] feat: add stream to reduce alloc Signed-off-by: yuanliang --- codec/encode.go | 54 ++++++--------- codec/stream.go | 32 +++++++++ ua/datatypes.go | 46 ++++++------- ua/diagnostic_info.go | 10 ++- ua/expanded_node_id.go | 15 ++--- ua/extension_object.go | 12 ++-- ua/node_id.go | 9 ++- ua/variant.go | 13 ++-- uacp/conn.go | 6 +- uacp/uacp.go | 19 +++--- uasc/asymmetric_security_header.go | 18 +++++ uasc/header.go | 28 ++++++-- uasc/message.go | 39 ++++++++++- uasc/secure_channel.go | 105 ++++++++++++++++++++++++++--- uasc/sequence_header.go | 16 +++++ uasc/symmetric_security_header.go | 16 +++++ uasc/timer.go | 55 +++++++++++++++ 17 files changed, 376 insertions(+), 117 deletions(-) create mode 100644 codec/stream.go create mode 100644 uasc/timer.go diff --git a/codec/encode.go b/codec/encode.go index bcc01480..9b7c287e 100644 --- a/codec/encode.go +++ b/codec/encode.go @@ -1,7 +1,6 @@ package codec import ( - "bytes" "fmt" "math" "reflect" @@ -23,10 +22,10 @@ func Marshal(v any) ([]byte, error) { return buf, nil } -// Marshaler is the interface implemented by types that +// Encoder is the interface implemented by types that // can marshal themselves into valid OPCUA. -type Marshaler interface { - MarshalOPCUA() ([]byte, error) +type Encoder interface { + EncodeOPCUA(s *Stream) error } // An UnsupportedTypeError is returned by [Marshal] when attempting @@ -51,7 +50,7 @@ func (e *UnsupportedValueError) Error() string { } // A MarshalerError represents an error from calling a -// [Marshaler.MarshalOPCUA] method. +// [Encoder.EncodeOPCUA] method. type MarshalerError struct { Type reflect.Type Err error @@ -61,7 +60,7 @@ type MarshalerError struct { func (e *MarshalerError) Error() string { srcFunc := e.sourceFunc if srcFunc == "" { - srcFunc = "MarshalOPCUA" + srcFunc = "EncodeOPCUA" } return "opcua: error calling " + srcFunc + " for type " + e.Type.String() + @@ -69,7 +68,7 @@ func (e *MarshalerError) Error() string { } type encodeState struct { - bytes.Buffer + Stream ptrLevel uint ptrSeen map[any]struct{} @@ -89,7 +88,10 @@ func newEncodeState() *encodeState { e.ptrLevel = 0 return e } - return &encodeState{ptrSeen: make(map[any]struct{})} + return &encodeState{ + Stream: Stream{buf: make([]byte, 0, 256)}, + ptrSeen: make(map[any]struct{}), + } } // codecError is an error wrapper type for internal use only. @@ -161,7 +163,7 @@ func typeEncoder(t reflect.Type) encoderFunc { return f } -var marshalerType = reflect.TypeFor[Marshaler]() +var encoderType = reflect.TypeFor[Encoder]() // newTypeEncoder constructs an encoderFunc for a type. // The returned encoder only checks CanAddr when allowAddr is true. @@ -171,10 +173,10 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { // Marshaler with a value receiver, then we're better off taking // the address of the value - otherwise we end up with an // allocation as we cast the value to an interface. - if kind != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { + if kind != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(encoderType) { return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } - if t.Implements(marshalerType) { + if t.Implements(encoderType) { return marshalerEncoder } @@ -226,38 +228,26 @@ func marshalerEncoder(e *encodeState, v reflect.Value) { if v.Kind() == reflect.Pointer && v.IsNil() { return } - m, ok := v.Interface().(Marshaler) + m, ok := v.Interface().(Encoder) if !ok { return } - b, err := m.MarshalOPCUA() - if err == nil { - e.Grow(len(b)) - out := e.AvailableBuffer() - out = append(out, b...) - e.Buffer.Write(out) - } + err := m.EncodeOPCUA(&e.Stream) if err != nil { - e.error(&MarshalerError{v.Type(), err, "MarshalOPCUA"}) + e.error(&MarshalerError{v.Type(), err, "EncodeOPCUA"}) } } func addrMarshalerEncoder(e *encodeState, v reflect.Value) { va := v.Addr() if va.IsNil() { - e.WriteString("null") + e.Write(null) return } - m := va.Interface().(Marshaler) - b, err := m.MarshalOPCUA() - if err == nil { - e.Grow(len(b)) - out := e.AvailableBuffer() - out = append(out, b...) - e.Buffer.Write(out) - } + m := va.Interface().(Encoder) + err := m.EncodeOPCUA(&e.Stream) if err != nil { - e.error(&MarshalerError{v.Type(), err, "MarshalOPCUA"}) + e.error(&MarshalerError{v.Type(), err, "EncodeOPCUA"}) } } @@ -467,7 +457,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. if t.Elem().Kind() == reflect.Uint8 { p := reflect.PointerTo(t.Elem()) - if !p.Implements(marshalerType) { + if !p.Implements(encoderType) { return encodeByteSlice } } @@ -570,7 +560,7 @@ func typeFields(t reflect.Type) structFields { visitField := func(f reflect.StructField) { t := f.Type // return marshalerEncoder directly, if it implements Marshaler. - if t.Implements(marshalerType) { + if t.Implements(encoderType) { fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: marshalerEncoder}) return } diff --git a/codec/stream.go b/codec/stream.go new file mode 100644 index 00000000..70614b92 --- /dev/null +++ b/codec/stream.go @@ -0,0 +1,32 @@ +package codec + +type Stream struct { + buf []byte +} + +func NewStream(buf []byte) *Stream { + return &Stream{buf: buf} +} + +func (s *Stream) Write(p []byte) (n int, err error) { + s.buf = append(s.buf, p...) + return len(p), nil +} + +func (s *Stream) WriteString(str string) (n int, err error) { + s.buf = append(s.buf, str...) + return len(str), nil +} + +func (s *Stream) WriteByte(b byte) error { + s.buf = append(s.buf, b) + return nil +} + +func (s *Stream) Reset() { + s.buf = s.buf[:0] +} + +func (s *Stream) Bytes() []byte { + return s.buf[:len(s.buf)] +} diff --git a/ua/datatypes.go b/ua/datatypes.go index cf6d4532..1cc6318a 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -5,7 +5,6 @@ package ua import ( - "bytes" "encoding/binary" "encoding/hex" "fmt" @@ -64,34 +63,33 @@ func (d *DataValue) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DataValue) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer +func (d *DataValue) EncodeOPCUA(s *codec.Stream) error { var err error var b []byte - buf.WriteByte(d.EncodingMask) + s.WriteByte(d.EncodingMask) if d.Has(DataValueValue) { b, err = codec.Marshal(d.Value) - buf.Write(b) + s.Write(b) } if d.Has(DataValueStatusCode) { - buf.Write([]byte{byte(d.Status), byte(d.Status >> 8), byte(d.Status >> 16), byte(d.Status >> 24)}) + s.Write([]byte{byte(d.Status), byte(d.Status >> 8), byte(d.Status >> 16), byte(d.Status >> 24)}) } if d.Has(DataValueSourceTimestamp) { b, err = codec.Marshal(d.SourceTimestamp) - buf.Write(b) + s.Write(b) } if d.Has(DataValueSourcePicoseconds) { - buf.Write([]byte{byte(d.SourcePicoseconds), byte(d.SourcePicoseconds >> 8)}) + s.Write([]byte{byte(d.SourcePicoseconds), byte(d.SourcePicoseconds >> 8)}) } if d.Has(DataValueServerTimestamp) { b, err = codec.Marshal(d.ServerTimestamp) - buf.Write(b) + s.Write(b) } if d.Has(DataValueServerPicoseconds) { - buf.Write([]byte{byte(d.ServerPicoseconds), byte(d.ServerPicoseconds >> 8)}) + s.Write([]byte{byte(d.ServerPicoseconds), byte(d.ServerPicoseconds >> 8)}) } - return buf.Bytes(), err + return err } func (d *DataValue) Has(mask byte) bool { @@ -160,13 +158,14 @@ func (g *GUID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (g *GUID) MarshalOPCUA() ([]byte, error) { +func (g *GUID) EncodeOPCUA(s *codec.Stream) error { buf := make([]byte, 0, 8+len(g.Data4)) buf = binary.LittleEndian.AppendUint32(buf, g.Data1) buf = binary.LittleEndian.AppendUint16(buf, g.Data2) buf = binary.LittleEndian.AppendUint16(buf, g.Data3) buf = append(buf, g.Data4...) - return buf, nil + s.Write(buf) + return nil } // String returns GUID in human-readable string. @@ -235,30 +234,27 @@ func (l *LocalizedText) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (l *LocalizedText) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer - var err error - - buf.WriteByte(l.EncodingMask) +func (l *LocalizedText) EncodeOPCUA(s *codec.Stream) error { + s.WriteByte(l.EncodingMask) if l.Has(LocalizedTextLocale) { n := len(l.Locale) if n == 0 { - buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + s.Write([]byte{0xff, 0xff, 0xff, 0xff}) } else { - buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) - buf.Write([]byte(l.Locale)) + s.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + s.Write([]byte(l.Locale)) } } if l.Has(LocalizedTextText) { n := len(l.Text) if n == 0 { - buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + s.Write([]byte{0xff, 0xff, 0xff, 0xff}) } else { - buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) - buf.Write([]byte(l.Text)) + s.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + s.Write([]byte(l.Text)) } } - return buf.Bytes(), err + return nil } func (l *LocalizedText) Has(mask byte) bool { diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index 48f4f8d8..e95b72a2 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -5,8 +5,6 @@ package ua import ( - "bytes" - "github.com/gopcua/opcua/codec" ) @@ -64,8 +62,8 @@ func (d *DiagnosticInfo) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DiagnosticInfo) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer +func (d *DiagnosticInfo) EncodeOPCUA(buf *codec.Stream) error { + // var buf bytes.Buffer buf.WriteByte(d.EncodingMask) if d.Has(DiagnosticInfoSymbolicID) { @@ -90,12 +88,12 @@ func (d *DiagnosticInfo) MarshalOPCUA() ([]byte, error) { if d.Has(DiagnosticInfoInnerDiagnosticInfo) { b, err := codec.Marshal(d.InnerDiagnosticInfo) if err != nil { - return nil, err + return err } buf.Write(b) } - return buf.Bytes(), nil + return nil } func (d *DiagnosticInfo) Has(mask byte) bool { diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index ed4ec92b..51ca4935 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -5,7 +5,6 @@ package ua import ( - "bytes" "encoding/base64" "math" "strconv" @@ -104,20 +103,20 @@ func (e *ExpandedNodeID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *ExpandedNodeID) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer - b, err := e.NodeID.MarshalOPCUA() - buf.Write(b) +func (e *ExpandedNodeID) EncodeOPCUA(buf *codec.Stream) error { + // var buf bytes.Buffer + err := e.NodeID.EncodeOPCUA(buf) + // buf.Write(b) if e.HasNamespaceURI() { - b, err = codec.Marshal(e.NamespaceURI) + b, _ := codec.Marshal(e.NamespaceURI) buf.Write(b) } if e.HasServerIndex() { - b, err = codec.Marshal(e.ServerIndex) + b, _ := codec.Marshal(e.ServerIndex) buf.Write(b) } - return buf.Bytes(), err + return err } // HasNamespaceURI checks if an ExpandedNodeID has NamespaceURI Flag. diff --git a/ua/extension_object.go b/ua/extension_object.go index 5e33e131..90fba83e 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -5,8 +5,6 @@ package ua import ( - "bytes" - "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/id" @@ -87,23 +85,23 @@ func (e *ExtensionObject) Decode(b []byte) (int, error) { return buf.Pos(), body.Error() } -func (e *ExtensionObject) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer +func (e *ExtensionObject) EncodeOPCUA(buf *codec.Stream) error { + // var buf bytes.Buffer b, err := codec.Marshal(e.TypeID) buf.Write(b) buf.WriteByte(e.EncodingMask) if e.EncodingMask == ExtensionObjectEmpty { - return buf.Bytes(), err + return err } body, err := codec.Marshal(e.Value) if err != nil { - return buf.Bytes(), err + return err } n := uint32(len(body)) buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) buf.Write(body) - return buf.Bytes(), err + return err } func (e *ExtensionObject) UpdateMask() { diff --git a/ua/node_id.go b/ua/node_id.go index 9a195c3e..057bbaa1 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -5,7 +5,6 @@ package ua import ( - "bytes" "encoding/base64" "encoding/json" "fmt" @@ -368,8 +367,8 @@ func (n *NodeID) Decode(b []byte) (int, error) { } } -func (n *NodeID) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer +func (n *NodeID) EncodeOPCUA(buf *codec.Stream) error { + // var buf bytes.Buffer buf.WriteByte(byte(n.mask)) switch n.Type() { case NodeIDTypeTwoByte: @@ -393,9 +392,9 @@ func (n *NodeID) MarshalOPCUA() ([]byte, error) { buf.Write(n.bid) } default: - return nil, fmt.Errorf("invalid node id type: %d", n.mask) + return fmt.Errorf("invalid node id type: %d", n.mask) } - return buf.Bytes(), nil + return nil } func (n *NodeID) MarshalJSON() ([]byte, error) { diff --git a/ua/variant.go b/ua/variant.go index a545bfc2..c80e9ec8 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -5,7 +5,6 @@ package ua import ( - "bytes" "fmt" "reflect" "time" @@ -319,19 +318,19 @@ func (m *Variant) decodeValue(buf *Buffer) interface{} { } } -func (m *Variant) MarshalOPCUA() ([]byte, error) { - var buf bytes.Buffer +func (m *Variant) EncodeOPCUA(buf *codec.Stream) error { + // var buf bytes.Buffer buf.WriteByte(m.mask) // a null value specifies that no other fields are encoded if m.Type() == TypeIDNull { - return buf.Bytes(), nil + return nil } if m.Has(VariantArrayValues) { buf.Write([]byte{byte(m.arrayLength), byte(m.arrayLength >> 8), byte(m.arrayLength >> 16), byte(m.arrayLength >> 24)}) } - m.encode(&buf, reflect.ValueOf(m.value)) + m.encode(buf, reflect.ValueOf(m.value)) if m.Has(VariantArrayDimensions) { buf.Write([]byte{byte(m.arrayDimensionsLength), byte(m.arrayDimensionsLength >> 8), byte(m.arrayDimensionsLength >> 16), byte(m.arrayDimensionsLength >> 24)}) @@ -339,11 +338,11 @@ func (m *Variant) MarshalOPCUA() ([]byte, error) { buf.Write([]byte{byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24)}) } } - return buf.Bytes(), nil + return nil } // encode recursively writes the values to the buffer. -func (m *Variant) encode(buf *bytes.Buffer, val reflect.Value) { +func (m *Variant) encode(buf *codec.Stream, val reflect.Value) { if val.Kind() != reflect.Slice || m.Type() == TypeIDByteString { b, _ := codec.Marshal(val.Interface()) buf.Write(b) diff --git a/uacp/conn.go b/uacp/conn.go index ce237923..3a60dff9 100644 --- a/uacp/conn.go +++ b/uacp/conn.go @@ -412,12 +412,14 @@ func (c *Conn) Send(typ string, msg interface{}) error { return errors.Errorf("send packet too large: %d > %d bytes", h.MessageSize, c.ack.SendBufSize) } - hdr, err := h.MarshalOPCUA() + hdr, err := codec.Marshal(&h) if err != nil { return errors.Errorf("encode hdr failed: %v", err) } - b := append(hdr, body...) + b := make([]byte, len(hdr)+len(body)) + copy(b, hdr) + copy(b[len(hdr):], body) if _, err := c.Write(b); err != nil { return errors.Errorf("write failed: %s", err) } diff --git a/uacp/uacp.go b/uacp/uacp.go index 60c0d017..67d9cf21 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/binary" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" ) @@ -48,16 +49,16 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) MarshalOPCUA() ([]byte, error) { +func (h *Header) EncodeOPCUA(buf *codec.Stream) error { if len(h.MessageType) != 3 { - return nil, errors.Errorf("invalid message type: %q", h.MessageType) + return errors.Errorf("invalid message type: %q", h.MessageType) } - var buf bytes.Buffer + // var buf bytes.Buffer buf.Write([]byte(h.MessageType)) buf.WriteByte(h.ChunkType) buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) - return buf.Bytes(), nil + return nil } // Hello represents a OPC UA Hello. @@ -83,7 +84,7 @@ func (h *Hello) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Hello) MarshalOPCUA() ([]byte, error) { +func (h *Hello) EncodeOPCUA(s *codec.Stream) error { var buf bytes.Buffer buf.Write([]byte{byte(h.Version), byte(h.Version >> 8), byte(h.Version >> 16), byte(h.Version >> 24)}) buf.Write([]byte{byte(h.ReceiveBufSize), byte(h.ReceiveBufSize >> 8), byte(h.ReceiveBufSize >> 16), byte(h.ReceiveBufSize >> 24)}) @@ -97,7 +98,8 @@ func (h *Hello) MarshalOPCUA() ([]byte, error) { buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) buf.WriteString(h.EndpointURL) } - return buf.Bytes(), nil + s.Write(buf.Bytes()) + return nil } // Acknowledge represents a OPC UA Acknowledge. @@ -121,14 +123,15 @@ func (a *Acknowledge) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (a *Acknowledge) MarshalOPCUA() ([]byte, error) { +func (a *Acknowledge) EncodeOPCUA(s *codec.Stream) error { buf := make([]byte, 0, 160) buf = binary.LittleEndian.AppendUint32(buf, a.Version) buf = binary.LittleEndian.AppendUint32(buf, a.ReceiveBufSize) buf = binary.LittleEndian.AppendUint32(buf, a.SendBufSize) buf = binary.LittleEndian.AppendUint32(buf, a.MaxMessageSize) buf = binary.LittleEndian.AppendUint32(buf, a.MaxChunkCount) - return buf, nil + s.Write(buf) + return nil } // ReverseHello represents a OPC UA ReverseHello. diff --git a/uasc/asymmetric_security_header.go b/uasc/asymmetric_security_header.go index 856f8abe..addebe80 100644 --- a/uasc/asymmetric_security_header.go +++ b/uasc/asymmetric_security_header.go @@ -6,10 +6,28 @@ package uasc import ( "fmt" + "sync" "github.com/gopcua/opcua/ua" ) +func acquireAsymmetricSecurityHeader() *AsymmetricSecurityHeader { + v := asymmetricSecurityHeaderPool.Get() + if v == nil { + return &AsymmetricSecurityHeader{} + } + return v.(*AsymmetricSecurityHeader) +} + +func releaseAsymmetricSecurityHeader(h *AsymmetricSecurityHeader) { + h.SecurityPolicyURI = "" + h.SenderCertificate = h.SenderCertificate[:0] + h.ReceiverCertificateThumbprint = h.ReceiverCertificateThumbprint[:0] + asymmetricSecurityHeaderPool.Put(h) +} + +var asymmetricSecurityHeaderPool sync.Pool + // AsymmetricSecurityHeader represents a Asymmetric Algorithm Security Header in OPC UA Secure Conversation. type AsymmetricSecurityHeader struct { SecurityPolicyURI string diff --git a/uasc/header.go b/uasc/header.go index f4e91b74..72c807ee 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -5,9 +5,10 @@ package uasc import ( - "bytes" "fmt" + "sync" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" ) @@ -26,6 +27,23 @@ const ( ChunkTypeError = 'A' ) +func acquireHeader() *Header { + v := headerPool.Get() + if v == nil { + return &Header{} + } + return v.(*Header) +} + +func releaseHeader(h *Header) { + h.MessageType = "" + h.MessageSize = 0 + h.SecureChannelID = 0 + headerPool.Put(h) +} + +var headerPool sync.Pool + // Header represents a OPC UA Secure Conversation Header. type Header struct { MessageType string @@ -52,17 +70,17 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) MarshalOPCUA() ([]byte, error) { +func (h *Header) EncodeOPCUA(buf *codec.Stream) error { if len(h.MessageType) != 3 { - return nil, errors.Errorf("invalid message type: %q", h.MessageType) + return errors.Errorf("invalid message type: %q", h.MessageType) } - var buf bytes.Buffer + // var buf bytes.Buffer buf.WriteString(h.MessageType) buf.WriteByte(h.ChunkType) buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) buf.Write([]byte{byte(h.SecureChannelID), byte(h.SecureChannelID >> 8), byte(h.SecureChannelID >> 16), byte(h.SecureChannelID >> 24)}) - return buf.Bytes(), nil + return nil } // String returns Header in string. diff --git a/uasc/message.go b/uasc/message.go index 75e3caae..2f462832 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -6,6 +6,7 @@ package uasc import ( "math" + "sync" "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" @@ -19,6 +20,13 @@ type MessageHeader struct { *SequenceHeader } +func (m *MessageHeader) reset() { + m.Header = nil + m.AsymmetricSecurityHeader = nil + m.SymmetricSecurityHeader = nil + m.SequenceHeader = nil +} + func (m *MessageHeader) Decode(b []byte) (int, error) { buf := ua.NewBuffer(b) @@ -79,6 +87,25 @@ func (m *MessageAbort) MessageAbort() string { return ua.StatusCode(m.ErrorCode).Error() } +func acquireMessage() *Message { + m := messagePool.Get() + if m == nil { + return &Message{ + MessageHeader: &MessageHeader{}, + } + } + return m.(*Message) +} + +func releaseMessage(m *Message) { + m.TypeID = nil + m.Service = nil + m.MessageHeader.reset() + messagePool.Put(m) +} + +var messagePool sync.Pool + // Message represents a OPC UA Secure Conversation message. type Message struct { *MessageHeader @@ -86,6 +113,11 @@ type Message struct { Service interface{} } +func (m *Message) reset(typeID *ua.ExpandedNodeID, service interface{}) { + m.TypeID = typeID + m.Service = service +} + func (m *Message) Decode(b []byte) (int, error) { m.MessageHeader = new(MessageHeader) var pos int @@ -106,13 +138,14 @@ func (m *Message) Decode(b []byte) (int, error) { return len(b), err } -func (m *Message) MarshalOPCUA() ([]byte, error) { +func (m *Message) EncodeOPCUA(s *codec.Stream) error { chunks, err := m.MarshalChunks(math.MaxUint32) if err != nil { - return nil, err + return err } + s.Write(chunks[0]) - return chunks[0], nil + return nil } func (m *Message) MarshalChunks(maxBodySize uint32) ([][]byte, error) { diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 10e0cf46..e8b4009a 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -17,6 +17,7 @@ import ( "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" + "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" "github.com/gopcua/opcua/uacp" "github.com/gopcua/opcua/uapolicy" @@ -591,8 +592,8 @@ func (s *SecureChannel) scheduleRenewal(instance *channelInstance) { debug.Printf("uasc %d: security token is refreshed at %s (%s). channelID=%d tokenID=%d", s.c.ID(), time.Now().UTC().Add(when).Format(time.RFC3339), when, instance.secureChannelID, instance.securityTokenID) - t := time.NewTimer(when) - defer t.Stop() + t := AcquireTimer(when) + defer ReleaseTimer(t) select { case <-s.closing: @@ -623,8 +624,8 @@ func (s *SecureChannel) scheduleExpiration(instance *channelInstance) { debug.Printf("uasc %d: security token expires at %s. channelID=%d tokenID=%d", s.c.ID(), when.UTC().Format(time.RFC3339), instance.secureChannelID, instance.securityTokenID) - t := time.NewTimer(time.Until(when)) - defer t.Stop() + t := AcquireTimer(time.Until(when)) + defer ReleaseTimer(t) select { case <-s.closing: @@ -677,8 +678,8 @@ func (s *SecureChannel) sendRequestWithTimeout( } // `+ timeoutLeniency` to give the server a chance to respond to TimeoutHint - timer := time.NewTimer(timeout + timeoutLeniency) - defer timer.Stop() + timer := AcquireTimer(timeout + timeoutLeniency) + defer ReleaseTimer(timer) select { case <-ctx.Done(): @@ -749,9 +750,95 @@ func (s *SecureChannel) sendAsyncWithTimeout( instance.Lock() defer instance.Unlock() - m, err := instance.newRequestMessage(req, reqID, authToken, timeout) - if err != nil { - return nil, err + m := acquireMessage() + defer releaseMessage(m) + + typeID := ua.ServiceTypeID(req) + if typeID == 0 { + return nil, errors.Errorf("unknown service %T. Did you call register?", req) + } + if authToken == nil { + authToken = ua.NewTwoByteNodeID(0) + } + + reqHdr := &ua.RequestHeader{ + AuthenticationToken: authToken, + Timestamp: instance.sc.timeNow(), + RequestHandle: reqID, // TODO: can I cheat like this? + AdditionalHeader: ua.NewExtensionObject(nil), + } + + if timeout > 0 && timeout < instance.sc.cfg.RequestTimeout { + timeout = instance.sc.cfg.RequestTimeout + } + reqHdr.TimeoutHint = uint32(timeout / time.Millisecond) + req.SetHeader(reqHdr) + + h := acquireHeader() + ash := acquireAsymmetricSecurityHeader() + sh := acquireSequenceHeader() + ssh := acquireSymmetricSecurityHeader() + + defer func() { + releaseHeader(h) + releaseAsymmetricSecurityHeader(ash) + releaseSequenceHeader(sh) + releaseSymmetricSecurityHeader(ssh) + }() + + m.reset(ua.NewFourByteExpandedNodeID(0, typeID), req) + sequenceNumber := instance.nextSequenceNumber() + switch typeID { + case id.OpenSecureChannelRequest_Encoding_DefaultBinary, id.OpenSecureChannelResponse_Encoding_DefaultBinary: + // Do not send the thumbprint for security mode None + // even if we have a certificate. + // + // See https://github.com/gopcua/opcua/issues/259 + thumbprint := instance.sc.cfg.Thumbprint + if instance.sc.cfg.SecurityMode == ua.MessageSecurityModeNone { + thumbprint = nil + } + + h.MessageType = MessageTypeOpenSecureChannel + h.ChunkType = ChunkTypeFinal + h.SecureChannelID = instance.secureChannelID + + ash.SecurityPolicyURI = instance.sc.cfg.SecurityPolicyURI + ash.SenderCertificate = instance.sc.cfg.Certificate + ash.ReceiverCertificateThumbprint = thumbprint + + sh.SequenceNumber = sequenceNumber + sh.RequestID = reqID + + m.MessageHeader.Header = h + m.MessageHeader.AsymmetricSecurityHeader = ash + m.MessageHeader.SequenceHeader = sh + + case id.CloseSecureChannelRequest_Encoding_DefaultBinary, id.CloseSecureChannelResponse_Encoding_DefaultBinary: + + h.MessageType = MessageTypeCloseSecureChannel + h.ChunkType = ChunkTypeFinal + h.SecureChannelID = instance.secureChannelID + ssh.TokenID = instance.securityTokenID + sh.SequenceNumber = sequenceNumber + sh.RequestID = reqID + + m.MessageHeader.Header = h + m.MessageHeader.SymmetricSecurityHeader = ssh + m.MessageHeader.SequenceHeader = sh + + default: + + h.MessageType = MessageTypeMessage + h.ChunkType = ChunkTypeFinal + h.SecureChannelID = instance.secureChannelID + ssh.TokenID = instance.securityTokenID + sh.SequenceNumber = sequenceNumber + sh.RequestID = reqID + + m.MessageHeader.Header = h + m.MessageHeader.SymmetricSecurityHeader = ssh + m.MessageHeader.SequenceHeader = sh } var resp chan *response diff --git a/uasc/sequence_header.go b/uasc/sequence_header.go index 9d37cdf4..4726160d 100644 --- a/uasc/sequence_header.go +++ b/uasc/sequence_header.go @@ -6,10 +6,26 @@ package uasc import ( "fmt" + "sync" "github.com/gopcua/opcua/ua" ) +func acquireSequenceHeader() *SequenceHeader { + if v, ok := sequenceHeaderPool.Get().(*SequenceHeader); ok { + return v + } + return &SequenceHeader{} +} + +func releaseSequenceHeader(h *SequenceHeader) { + h.RequestID = 0 + h.SequenceNumber = 0 + sequenceHeaderPool.Put(h) +} + +var sequenceHeaderPool sync.Pool + // SequenceHeader represents a Sequence Header in OPC UA Secure Conversation. type SequenceHeader struct { SequenceNumber uint32 diff --git a/uasc/symmetric_security_header.go b/uasc/symmetric_security_header.go index c0ed7af2..31a032dd 100644 --- a/uasc/symmetric_security_header.go +++ b/uasc/symmetric_security_header.go @@ -6,10 +6,26 @@ package uasc import ( "fmt" + "sync" "github.com/gopcua/opcua/ua" ) +func acquireSymmetricSecurityHeader() *SymmetricSecurityHeader { + v := symmetricSecurityHeaderPool.Get() + if v == nil { + return &SymmetricSecurityHeader{} + } + return v.(*SymmetricSecurityHeader) +} + +func releaseSymmetricSecurityHeader(h *SymmetricSecurityHeader) { + h.TokenID = 0 + symmetricSecurityHeaderPool.Put(h) +} + +var symmetricSecurityHeaderPool sync.Pool + // SymmetricSecurityHeader represents a Symmetric Algorithm Security Header in OPC UA Secure Conversation. type SymmetricSecurityHeader struct { TokenID uint32 diff --git a/uasc/timer.go b/uasc/timer.go new file mode 100644 index 00000000..7c6c32b9 --- /dev/null +++ b/uasc/timer.go @@ -0,0 +1,55 @@ +package uasc + +import ( + "sync" + "time" +) + +func initTimer(t *time.Timer, timeout time.Duration) *time.Timer { + if t == nil { + return time.NewTimer(timeout) + } + if t.Reset(timeout) { + // developer sanity-check + panic("BUG: active timer trapped into initTimer()") + } + return t +} + +func stopTimer(t *time.Timer) { + if !t.Stop() { + // Collect possibly added time from the channel + // if timer has been stopped and nobody collected its value. + select { + case <-t.C: + default: + } + } +} + +// AcquireTimer returns a time.Timer from the pool and updates it to +// send the current time on its channel after at least timeout. +// +// The returned Timer may be returned to the pool with ReleaseTimer +// when no longer needed. This allows reducing GC load. +func AcquireTimer(timeout time.Duration) *time.Timer { + v := timerPool.Get() + if v == nil { + return time.NewTimer(timeout) + } + t := v.(*time.Timer) + initTimer(t, timeout) + return t +} + +// ReleaseTimer returns the time.Timer acquired via AcquireTimer to the pool +// and prevents the Timer from firing. +// +// Do not access the released time.Timer or read from its channel otherwise +// data races may occur. +func ReleaseTimer(t *time.Timer) { + stopTimer(t) + timerPool.Put(t) +} + +var timerPool sync.Pool From 0b18373854440e7b70bc962f96720426db526615 Mon Sep 17 00:00:00 2001 From: yuanliang Date: Sun, 2 Jun 2024 19:34:44 +0800 Subject: [PATCH 10/14] feat: add stream to reduce alloc Signed-off-by: yuanliang --- uasc/secure_channel.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index e8b4009a..30b26f04 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -815,7 +815,6 @@ func (s *SecureChannel) sendAsyncWithTimeout( m.MessageHeader.SequenceHeader = sh case id.CloseSecureChannelRequest_Encoding_DefaultBinary, id.CloseSecureChannelResponse_Encoding_DefaultBinary: - h.MessageType = MessageTypeCloseSecureChannel h.ChunkType = ChunkTypeFinal h.SecureChannelID = instance.secureChannelID @@ -828,7 +827,6 @@ func (s *SecureChannel) sendAsyncWithTimeout( m.MessageHeader.SequenceHeader = sh default: - h.MessageType = MessageTypeMessage h.ChunkType = ChunkTypeFinal h.SecureChannelID = instance.secureChannelID From 956518e4c7db68e98993d10860501bb926c907f3 Mon Sep 17 00:00:00 2001 From: yuanliang Date: Mon, 3 Jun 2024 10:10:35 +0800 Subject: [PATCH 11/14] feat: replace the buffer allocations Signed-off-by: yuanliang --- codec/const.go | 7 ++++++ codec/encode.go | 54 ++++++++++++++---------------------------- codec/stream.go | 14 +++++++++++ ua/datatypes.go | 28 ++++++++++------------ ua/diagnostic_info.go | 19 +++++++-------- ua/expanded_node_id.go | 10 ++++---- ua/extension_object.go | 11 ++++----- ua/node_id.go | 26 ++++++++++---------- ua/variant.go | 13 +++++----- uacp/uacp.go | 43 +++++++++++++-------------------- uasc/header.go | 11 ++++----- uasc/message_test.go | 50 -------------------------------------- 12 files changed, 111 insertions(+), 175 deletions(-) create mode 100644 codec/const.go diff --git a/codec/const.go b/codec/const.go new file mode 100644 index 00000000..15c5db36 --- /dev/null +++ b/codec/const.go @@ -0,0 +1,7 @@ +package codec + +const ( + NULL = 0xffffffff + F32QNAN = 0xc0ff0000 + F64QNAN = 0xfff8000000000000 +) diff --git a/codec/encode.go b/codec/encode.go index 9b7c287e..99daeb71 100644 --- a/codec/encode.go +++ b/codec/encode.go @@ -241,7 +241,7 @@ func marshalerEncoder(e *encodeState, v reflect.Value) { func addrMarshalerEncoder(e *encodeState, v reflect.Value) { va := v.Addr() if va.IsNil() { - e.Write(null) + e.WriteUint32(NULL) return } m := va.Interface().(Encoder) @@ -251,18 +251,6 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value) { } } -func (e *encodeState) writeUint16(n uint16) { - e.Write([]byte{byte(n), byte(n >> 8)}) -} - -func (e *encodeState) writeUint32(n uint32) { - e.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) -} - -func (e *encodeState) writeUint64(n uint64) { - e.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}) -} - func boolEncoder(e *encodeState, v reflect.Value) { val := v.Bool() if val { @@ -284,67 +272,61 @@ func uint8Encoder(e *encodeState, v reflect.Value) { func int16Encoder(e *encodeState, v reflect.Value) { val := uint16(v.Int()) - e.writeUint16(val) + e.WriteUint16(val) } func uint16Encoder(e *encodeState, v reflect.Value) { val := uint16(v.Uint()) - e.writeUint16(val) + e.WriteUint16(val) } func int32Encoder(e *encodeState, v reflect.Value) { val := uint32(v.Int()) - e.writeUint32(val) + e.WriteUint32(val) } func uint32Encoder(e *encodeState, v reflect.Value) { val := uint32(v.Uint()) - e.writeUint32(val) + e.WriteUint32(val) } func int64Encoder(e *encodeState, v reflect.Value) { val := uint64(v.Int()) - e.writeUint64(val) + e.WriteUint64(val) } func uint64Encoder(e *encodeState, v reflect.Value) { val := v.Uint() - e.writeUint64(val) + e.WriteUint64(val) } -var ( - null = []byte{0xff, 0xff, 0xff, 0xff} - f32qnan = []byte{0x00, 0x00, 0xff, 0xc0} - f64qnan = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf8, 0xff} -) - func float32Encoder(e *encodeState, v reflect.Value) { if math.IsNaN(v.Float()) { - e.Write(f32qnan) + e.WriteUint32(F32QNAN) } else { val := math.Float32bits((float32)(v.Float())) - e.writeUint32(val) + e.WriteUint32(val) } } func float64Encoder(e *encodeState, v reflect.Value) { if math.IsNaN(v.Float()) { - e.Write(f64qnan) + e.WriteUint64(F64QNAN) } else { val := math.Float64bits(v.Float()) - e.writeUint64(val) + e.WriteUint64(val) } } func stringEncoder(e *encodeState, v reflect.Value) { s := v.String() if s == "" { - e.Write(null) + e.WriteUint32(NULL) return } l := len(s) - e.writeUint32(uint32(l)) + e.WriteUint32(uint32(l)) e.Write([]byte(s)) } @@ -357,7 +339,7 @@ func timeEncoder(e *encodeState, v reflect.Value) { // encode time in "100 nanosecond intervals since January 1, 1601" ts = uint64(val.UTC().UnixNano()/100 + 116444736000000000) } - e.writeUint64(ts) + e.WriteUint64(ts) } func interfaceEncoder(e *encodeState, v reflect.Value) { @@ -407,12 +389,12 @@ func newStructEncoder(t reflect.Type) encoderFunc { func encodeByteSlice(e *encodeState, v reflect.Value) { if v.IsNil() { - e.Write(null) + e.WriteUint32(NULL) return } n := v.Len() - e.writeUint32(uint32(n)) + e.WriteUint32(uint32(n)) b := make([]byte, n) reflect.Copy(reflect.ValueOf(b), v) @@ -426,7 +408,7 @@ type sliceEncoder struct { func (se sliceEncoder) encode(e *encodeState, v reflect.Value) { if v.IsNil() { - e.Write(null) + e.WriteUint32(NULL) return } @@ -471,7 +453,7 @@ type arrayEncoder struct { func (ae arrayEncoder) encode(e *encodeState, v reflect.Value) { n := v.Len() - e.writeUint32(uint32(n)) + e.WriteUint32(uint32(n)) // fast path for []byte if v.Type().Elem().Kind() == reflect.Uint8 { diff --git a/codec/stream.go b/codec/stream.go index 70614b92..21a61aa0 100644 --- a/codec/stream.go +++ b/codec/stream.go @@ -1,5 +1,7 @@ package codec +import "encoding/binary" + type Stream struct { buf []byte } @@ -30,3 +32,15 @@ func (s *Stream) Reset() { func (s *Stream) Bytes() []byte { return s.buf[:len(s.buf)] } + +func (s *Stream) WriteUint16(n uint16) { + s.buf = binary.LittleEndian.AppendUint16(s.buf, n) +} + +func (s *Stream) WriteUint32(n uint32) { + s.buf = binary.LittleEndian.AppendUint32(s.buf, n) +} + +func (s *Stream) WriteUint64(n uint64) { + s.buf = binary.LittleEndian.AppendUint64(s.buf, n) +} diff --git a/ua/datatypes.go b/ua/datatypes.go index 1cc6318a..97fa0007 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -73,21 +73,21 @@ func (d *DataValue) EncodeOPCUA(s *codec.Stream) error { s.Write(b) } if d.Has(DataValueStatusCode) { - s.Write([]byte{byte(d.Status), byte(d.Status >> 8), byte(d.Status >> 16), byte(d.Status >> 24)}) + s.WriteUint32(uint32(d.Status)) } if d.Has(DataValueSourceTimestamp) { b, err = codec.Marshal(d.SourceTimestamp) s.Write(b) } if d.Has(DataValueSourcePicoseconds) { - s.Write([]byte{byte(d.SourcePicoseconds), byte(d.SourcePicoseconds >> 8)}) + s.WriteUint16(d.SourcePicoseconds) } if d.Has(DataValueServerTimestamp) { b, err = codec.Marshal(d.ServerTimestamp) s.Write(b) } if d.Has(DataValueServerPicoseconds) { - s.Write([]byte{byte(d.ServerPicoseconds), byte(d.ServerPicoseconds >> 8)}) + s.WriteUint16(d.ServerPicoseconds) } return err } @@ -159,12 +159,10 @@ func (g *GUID) Decode(b []byte) (int, error) { } func (g *GUID) EncodeOPCUA(s *codec.Stream) error { - buf := make([]byte, 0, 8+len(g.Data4)) - buf = binary.LittleEndian.AppendUint32(buf, g.Data1) - buf = binary.LittleEndian.AppendUint16(buf, g.Data2) - buf = binary.LittleEndian.AppendUint16(buf, g.Data3) - buf = append(buf, g.Data4...) - s.Write(buf) + s.WriteUint32(g.Data1) + s.WriteUint16(g.Data2) + s.WriteUint16(g.Data3) + s.Write(g.Data4) return nil } @@ -239,19 +237,19 @@ func (l *LocalizedText) EncodeOPCUA(s *codec.Stream) error { if l.Has(LocalizedTextLocale) { n := len(l.Locale) if n == 0 { - s.Write([]byte{0xff, 0xff, 0xff, 0xff}) + s.WriteUint32(codec.NULL) } else { - s.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) - s.Write([]byte(l.Locale)) + s.WriteUint32(uint32(n)) + s.WriteString(l.Locale) } } if l.Has(LocalizedTextText) { n := len(l.Text) if n == 0 { - s.Write([]byte{0xff, 0xff, 0xff, 0xff}) + s.WriteUint32(codec.NULL) } else { - s.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) - s.Write([]byte(l.Text)) + s.WriteUint32(uint32(n)) + s.WriteString(l.Text) } } return nil diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index e95b72a2..0232b3ce 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -62,35 +62,34 @@ func (d *DiagnosticInfo) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (d *DiagnosticInfo) EncodeOPCUA(buf *codec.Stream) error { - // var buf bytes.Buffer - buf.WriteByte(d.EncodingMask) +func (d *DiagnosticInfo) EncodeOPCUA(s *codec.Stream) error { + s.WriteByte(d.EncodingMask) if d.Has(DiagnosticInfoSymbolicID) { - buf.Write([]byte{byte(d.SymbolicID), byte(d.SymbolicID >> 8), byte(d.SymbolicID >> 16), byte(d.SymbolicID >> 24)}) + s.WriteUint32(uint32(d.SymbolicID)) } if d.Has(DiagnosticInfoNamespaceURI) { - buf.Write([]byte{byte(d.NamespaceURI), byte(d.NamespaceURI >> 8), byte(d.NamespaceURI >> 16), byte(d.NamespaceURI >> 24)}) + s.WriteUint32(uint32(d.NamespaceURI)) } if d.Has(DiagnosticInfoLocale) { - buf.Write([]byte{byte(d.Locale), byte(d.Locale >> 8), byte(d.Locale >> 16), byte(d.Locale >> 24)}) + s.WriteUint32(uint32(d.Locale)) } if d.Has(DiagnosticInfoLocalizedText) { - buf.Write([]byte{byte(d.LocalizedText), byte(d.LocalizedText >> 8), byte(d.LocalizedText >> 16), byte(d.LocalizedText >> 24)}) + s.WriteUint32(uint32(d.LocalizedText)) } if d.Has(DiagnosticInfoAdditionalInfo) { b, _ := codec.Marshal(d.AdditionalInfo) - buf.Write(b) + s.Write(b) } if d.Has(DiagnosticInfoInnerStatusCode) { - buf.Write([]byte{byte(d.InnerStatusCode), byte(d.InnerStatusCode >> 8), byte(d.InnerStatusCode >> 16), byte(d.InnerStatusCode >> 24)}) + s.WriteUint32(uint32(d.InnerStatusCode)) } if d.Has(DiagnosticInfoInnerDiagnosticInfo) { b, err := codec.Marshal(d.InnerDiagnosticInfo) if err != nil { return err } - buf.Write(b) + s.Write(b) } return nil diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index 51ca4935..aef56554 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -103,18 +103,16 @@ func (e *ExpandedNodeID) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (e *ExpandedNodeID) EncodeOPCUA(buf *codec.Stream) error { - // var buf bytes.Buffer - err := e.NodeID.EncodeOPCUA(buf) - // buf.Write(b) +func (e *ExpandedNodeID) EncodeOPCUA(s *codec.Stream) error { + err := e.NodeID.EncodeOPCUA(s) if e.HasNamespaceURI() { b, _ := codec.Marshal(e.NamespaceURI) - buf.Write(b) + s.Write(b) } if e.HasServerIndex() { b, _ := codec.Marshal(e.ServerIndex) - buf.Write(b) + s.Write(b) } return err } diff --git a/ua/extension_object.go b/ua/extension_object.go index 90fba83e..677a1a00 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -85,11 +85,10 @@ func (e *ExtensionObject) Decode(b []byte) (int, error) { return buf.Pos(), body.Error() } -func (e *ExtensionObject) EncodeOPCUA(buf *codec.Stream) error { - // var buf bytes.Buffer +func (e *ExtensionObject) EncodeOPCUA(s *codec.Stream) error { b, err := codec.Marshal(e.TypeID) - buf.Write(b) - buf.WriteByte(e.EncodingMask) + s.Write(b) + s.WriteByte(e.EncodingMask) if e.EncodingMask == ExtensionObjectEmpty { return err } @@ -99,8 +98,8 @@ func (e *ExtensionObject) EncodeOPCUA(buf *codec.Stream) error { return err } n := uint32(len(body)) - buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) - buf.Write(body) + s.WriteUint32(n) + s.Write(body) return err } diff --git a/ua/node_id.go b/ua/node_id.go index 057bbaa1..ca43530a 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -367,29 +367,29 @@ func (n *NodeID) Decode(b []byte) (int, error) { } } -func (n *NodeID) EncodeOPCUA(buf *codec.Stream) error { - // var buf bytes.Buffer - buf.WriteByte(byte(n.mask)) +func (n *NodeID) EncodeOPCUA(s *codec.Stream) error { + s.WriteByte(byte(n.mask)) switch n.Type() { case NodeIDTypeTwoByte: - buf.WriteByte(byte(n.nid)) + s.WriteByte(byte(n.nid)) case NodeIDTypeFourByte: - buf.WriteByte(byte(n.ns)) - buf.Write([]byte{byte(n.nid), byte(n.nid >> 8)}) + s.WriteByte(byte(n.ns)) + s.WriteUint16(uint16(n.nid)) case NodeIDTypeNumeric: - buf.Write([]byte{byte(n.ns), byte(n.ns >> 8), byte(n.nid), byte(n.nid >> 8), byte(n.nid >> 16), byte(n.nid >> 24)}) + s.WriteUint16(n.ns) + s.WriteUint32(n.nid) case NodeIDTypeGUID: - buf.Write([]byte{byte(n.ns), byte(n.ns >> 8)}) + s.WriteUint16(n.ns) b, _ := codec.Marshal(n.gid) - buf.Write(b) + s.Write(b) case NodeIDTypeByteString, NodeIDTypeString: - buf.Write([]byte{byte(n.ns), byte(n.ns >> 8)}) + s.WriteUint16(n.ns) l := uint32(len(n.bid)) if l == 0 { - buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + s.WriteUint32(codec.NULL) } else { - buf.Write([]byte{byte(l), byte(l >> 8), byte(l >> 16), byte(l >> 24)}) - buf.Write(n.bid) + s.WriteUint32(l) + s.Write(n.bid) } default: return fmt.Errorf("invalid node id type: %d", n.mask) diff --git a/ua/variant.go b/ua/variant.go index c80e9ec8..c3cb06b0 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -318,9 +318,8 @@ func (m *Variant) decodeValue(buf *Buffer) interface{} { } } -func (m *Variant) EncodeOPCUA(buf *codec.Stream) error { - // var buf bytes.Buffer - buf.WriteByte(m.mask) +func (m *Variant) EncodeOPCUA(s *codec.Stream) error { + s.WriteByte(m.mask) // a null value specifies that no other fields are encoded if m.Type() == TypeIDNull { @@ -328,14 +327,14 @@ func (m *Variant) EncodeOPCUA(buf *codec.Stream) error { } if m.Has(VariantArrayValues) { - buf.Write([]byte{byte(m.arrayLength), byte(m.arrayLength >> 8), byte(m.arrayLength >> 16), byte(m.arrayLength >> 24)}) + s.WriteUint32(uint32(m.arrayLength)) } - m.encode(buf, reflect.ValueOf(m.value)) + m.encode(s, reflect.ValueOf(m.value)) if m.Has(VariantArrayDimensions) { - buf.Write([]byte{byte(m.arrayDimensionsLength), byte(m.arrayDimensionsLength >> 8), byte(m.arrayDimensionsLength >> 16), byte(m.arrayDimensionsLength >> 24)}) + s.Write([]byte{byte(m.arrayDimensionsLength), byte(m.arrayDimensionsLength >> 8), byte(m.arrayDimensionsLength >> 16), byte(m.arrayDimensionsLength >> 24)}) for _, v := range m.arrayDimensions { - buf.Write([]byte{byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24)}) + s.WriteUint32(uint32(v)) } } return nil diff --git a/uacp/uacp.go b/uacp/uacp.go index 67d9cf21..0d25d267 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -5,9 +5,6 @@ package uacp import ( - "bytes" - "encoding/binary" - "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" @@ -49,15 +46,14 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) EncodeOPCUA(buf *codec.Stream) error { +func (h *Header) EncodeOPCUA(s *codec.Stream) error { if len(h.MessageType) != 3 { return errors.Errorf("invalid message type: %q", h.MessageType) } - // var buf bytes.Buffer - buf.Write([]byte(h.MessageType)) - buf.WriteByte(h.ChunkType) - buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) + s.WriteString(h.MessageType) + s.WriteByte(h.ChunkType) + s.WriteUint32(h.MessageSize) return nil } @@ -85,20 +81,17 @@ func (h *Hello) Decode(b []byte) (int, error) { } func (h *Hello) EncodeOPCUA(s *codec.Stream) error { - var buf bytes.Buffer - buf.Write([]byte{byte(h.Version), byte(h.Version >> 8), byte(h.Version >> 16), byte(h.Version >> 24)}) - buf.Write([]byte{byte(h.ReceiveBufSize), byte(h.ReceiveBufSize >> 8), byte(h.ReceiveBufSize >> 16), byte(h.ReceiveBufSize >> 24)}) - buf.Write([]byte{byte(h.SendBufSize), byte(h.SendBufSize >> 8), byte(h.SendBufSize >> 16), byte(h.SendBufSize >> 24)}) - buf.Write([]byte{byte(h.MaxMessageSize), byte(h.MaxMessageSize >> 8), byte(h.MaxMessageSize >> 16), byte(h.MaxMessageSize >> 24)}) - buf.Write([]byte{byte(h.MaxChunkCount), byte(h.MaxChunkCount >> 8), byte(h.MaxChunkCount >> 16), byte(h.MaxChunkCount >> 24)}) + s.WriteUint32(h.Version) + s.WriteUint32(h.ReceiveBufSize) + s.WriteUint32(h.SendBufSize) + s.WriteUint32(h.MaxMessageSize) + s.WriteUint32(h.MaxChunkCount) if len(h.EndpointURL) == 0 { - buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + s.WriteUint32(codec.NULL) } else { - n := len(h.EndpointURL) - buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) - buf.WriteString(h.EndpointURL) + s.WriteUint32(uint32(len(h.EndpointURL))) + s.WriteString(h.EndpointURL) } - s.Write(buf.Bytes()) return nil } @@ -124,13 +117,11 @@ func (a *Acknowledge) Decode(b []byte) (int, error) { } func (a *Acknowledge) EncodeOPCUA(s *codec.Stream) error { - buf := make([]byte, 0, 160) - buf = binary.LittleEndian.AppendUint32(buf, a.Version) - buf = binary.LittleEndian.AppendUint32(buf, a.ReceiveBufSize) - buf = binary.LittleEndian.AppendUint32(buf, a.SendBufSize) - buf = binary.LittleEndian.AppendUint32(buf, a.MaxMessageSize) - buf = binary.LittleEndian.AppendUint32(buf, a.MaxChunkCount) - s.Write(buf) + s.WriteUint32(a.Version) + s.WriteUint32(a.ReceiveBufSize) + s.WriteUint32(a.SendBufSize) + s.WriteUint32(a.MaxMessageSize) + s.WriteUint32(a.MaxChunkCount) return nil } diff --git a/uasc/header.go b/uasc/header.go index 72c807ee..ea2c3cae 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -70,16 +70,15 @@ func (h *Header) Decode(b []byte) (int, error) { return buf.Pos(), buf.Error() } -func (h *Header) EncodeOPCUA(buf *codec.Stream) error { +func (h *Header) EncodeOPCUA(s *codec.Stream) error { if len(h.MessageType) != 3 { return errors.Errorf("invalid message type: %q", h.MessageType) } - // var buf bytes.Buffer - buf.WriteString(h.MessageType) - buf.WriteByte(h.ChunkType) - buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) - buf.Write([]byte{byte(h.SecureChannelID), byte(h.SecureChannelID >> 8), byte(h.SecureChannelID >> 16), byte(h.SecureChannelID >> 24)}) + s.WriteString(h.MessageType) + s.WriteByte(h.ChunkType) + s.WriteUint32(h.MessageSize) + s.WriteUint32(h.SecureChannelID) return nil } diff --git a/uasc/message_test.go b/uasc/message_test.go index 1df6035a..7c10d9eb 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -246,56 +246,6 @@ func TestMessage(t *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, }, }, - // { - // Name: "OPN-2", - // Struct: func() interface{} { - // s := &SecureChannel{ - // endpointURL: "opc.tcp://192.168.118.199:4840", - // cfg: &Config{ - // SecurityPolicyURI: "opc.tcp://192.168.118.199:4840/FOO", - // }, - // } - // instance := &channelInstance{ - // sc: s, - // sequenceNumber: 1, - // securityTokenID: 0, - // maxBodySize: 65510, - // } - // m := instance.newMessage( - // &ua.OpenSecureChannelRequest{ - // RequestHeader: &ua.RequestHeader{ - // AuthenticationToken: ua.NewTwoByteNodeID(0), - // Timestamp: time.Date(2024, time.May, 30, 16, 17, 44, 0, time.Local), - // RequestHandle: 1, - // ReturnDiagnostics: 0, - // TimeoutHint: 10000, - // AdditionalHeader: ua.NewExtensionObject(nil), - // }, - // ClientProtocolVersion: 0, - // RequestType: ua.SecurityTokenRequestTypeIssue, - // SecurityMode: ua.MessageSecurityModeNone, - // ClientNonce: []byte{}, - // RequestedLifetime: 3600000, - // }, - // id.OpenSecureChannelRequest_Encoding_DefaultBinary, - // s.nextRequestID(), - // ) - - // // set message size manually, since it is computed in Encode - // // otherwise, the decode tests failed. - // m.Header.MessageSize = 119 - - // return m - // }(), - // Bytes: []byte{ // OpenSecureChannelRequest - // // Message Header - // // MessageType: OPN - // 0x4f, 0x50, 0x4e, - // // Chunk Type: Final - // 0x46, - // 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x31, 0x31, 0x38, 0x2e, 0x31, 0x39, 0x39, 0x3a, 0x34, 0x38, 0x34, 0x30, 0x2f, 0x46, 0x4f, 0x4f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0xbe, 0x01, 0x00, 0x00, 0x00, 0x04, 0xd5, 0xd8, 0x69, 0xb2, 0xda, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0xee, 0x36, 0x00, - // }, - // }, } RunCodecTest(t, cases) } From 43c6cfbc1acc7e2cb858b0b9c9fbd139ef3ccb2a Mon Sep 17 00:00:00 2001 From: yuanliang Date: Thu, 6 Jun 2024 16:08:24 +0800 Subject: [PATCH 12/14] feat: remove unuse field Signed-off-by: yuanliang --- codec/encode.go | 11 +++++------ ua/codec_test.go | 4 +--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/codec/encode.go b/codec/encode.go index 99daeb71..b38bfde9 100644 --- a/codec/encode.go +++ b/codec/encode.go @@ -523,8 +523,7 @@ func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { // A field represents a single field found in a struct. type field struct { - name string - nameBytes []byte // []byte(name) + name string index []int typ reflect.Type @@ -543,19 +542,19 @@ func typeFields(t reflect.Type) structFields { t := f.Type // return marshalerEncoder directly, if it implements Marshaler. if t.Implements(encoderType) { - fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: marshalerEncoder}) + fields = append(fields, field{name: f.Name, index: f.Index, typ: t, encoder: marshalerEncoder}) return } // time.Time is special because it has embedded structs that use timeEncoder. if t.AssignableTo(timeType) || (t.Kind() == reflect.Pointer && t.Elem().AssignableTo(timeType)) { - fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: timeEncoder}) + fields = append(fields, field{name: f.Name, index: f.Index, typ: t, encoder: timeEncoder}) return } if t.ConvertibleTo(timeType) { converted := reflect.New(t).Elem().Convert(timeType) if _, ok := converted.Interface().(time.Time); ok { - fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: timeEncoder}) + fields = append(fields, field{name: f.Name, index: f.Index, typ: t, encoder: timeEncoder}) return } } @@ -568,7 +567,7 @@ func typeFields(t reflect.Type) structFields { fields = append(fields, typeFields(t).list...) } - fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: f.Type, encoder: typeEncoder(f.Type)}) + fields = append(fields, field{name: f.Name, index: f.Index, typ: f.Type, encoder: typeEncoder(f.Type)}) } // Process all fields in the root struct. diff --git a/ua/codec_test.go b/ua/codec_test.go index 5b6beea6..338af4ae 100644 --- a/ua/codec_test.go +++ b/ua/codec_test.go @@ -12,7 +12,6 @@ import ( "github.com/gopcua/opcua/codec" "github.com/pascaldekloe/goe/verify" - "github.com/stretchr/testify/assert" ) // CodecTestCase describes a test case for a encoding and decoding an @@ -59,8 +58,7 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { if err != nil { t.Fatalf("failed to marshal message: %v", err) } - assert.Equal(t, c.Bytes, b) - // verify.Values(t, "", b, c.Bytes) + verify.Values(t, "", b, c.Bytes) }) }) } From a80cf3217266631b5628d669ae1d9d54f6cf1bf7 Mon Sep 17 00:00:00 2001 From: yuanliang Date: Fri, 14 Jun 2024 14:18:26 +0800 Subject: [PATCH 13/14] feat: remove invalid pool Signed-off-by: yuanliang --- uasc/asymmetric_security_header.go | 18 ----- uasc/header.go | 18 ----- uasc/message.go | 32 --------- uasc/secure_channel.go | 102 +++-------------------------- uasc/sequence_header.go | 16 ----- uasc/symmetric_security_header.go | 16 ----- 6 files changed, 8 insertions(+), 194 deletions(-) diff --git a/uasc/asymmetric_security_header.go b/uasc/asymmetric_security_header.go index addebe80..856f8abe 100644 --- a/uasc/asymmetric_security_header.go +++ b/uasc/asymmetric_security_header.go @@ -6,28 +6,10 @@ package uasc import ( "fmt" - "sync" "github.com/gopcua/opcua/ua" ) -func acquireAsymmetricSecurityHeader() *AsymmetricSecurityHeader { - v := asymmetricSecurityHeaderPool.Get() - if v == nil { - return &AsymmetricSecurityHeader{} - } - return v.(*AsymmetricSecurityHeader) -} - -func releaseAsymmetricSecurityHeader(h *AsymmetricSecurityHeader) { - h.SecurityPolicyURI = "" - h.SenderCertificate = h.SenderCertificate[:0] - h.ReceiverCertificateThumbprint = h.ReceiverCertificateThumbprint[:0] - asymmetricSecurityHeaderPool.Put(h) -} - -var asymmetricSecurityHeaderPool sync.Pool - // AsymmetricSecurityHeader represents a Asymmetric Algorithm Security Header in OPC UA Secure Conversation. type AsymmetricSecurityHeader struct { SecurityPolicyURI string diff --git a/uasc/header.go b/uasc/header.go index ea2c3cae..1e2e4587 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -6,7 +6,6 @@ package uasc import ( "fmt" - "sync" "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" @@ -27,23 +26,6 @@ const ( ChunkTypeError = 'A' ) -func acquireHeader() *Header { - v := headerPool.Get() - if v == nil { - return &Header{} - } - return v.(*Header) -} - -func releaseHeader(h *Header) { - h.MessageType = "" - h.MessageSize = 0 - h.SecureChannelID = 0 - headerPool.Put(h) -} - -var headerPool sync.Pool - // Header represents a OPC UA Secure Conversation Header. type Header struct { MessageType string diff --git a/uasc/message.go b/uasc/message.go index 2f462832..71958fff 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -6,7 +6,6 @@ package uasc import ( "math" - "sync" "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" @@ -20,13 +19,6 @@ type MessageHeader struct { *SequenceHeader } -func (m *MessageHeader) reset() { - m.Header = nil - m.AsymmetricSecurityHeader = nil - m.SymmetricSecurityHeader = nil - m.SequenceHeader = nil -} - func (m *MessageHeader) Decode(b []byte) (int, error) { buf := ua.NewBuffer(b) @@ -87,25 +79,6 @@ func (m *MessageAbort) MessageAbort() string { return ua.StatusCode(m.ErrorCode).Error() } -func acquireMessage() *Message { - m := messagePool.Get() - if m == nil { - return &Message{ - MessageHeader: &MessageHeader{}, - } - } - return m.(*Message) -} - -func releaseMessage(m *Message) { - m.TypeID = nil - m.Service = nil - m.MessageHeader.reset() - messagePool.Put(m) -} - -var messagePool sync.Pool - // Message represents a OPC UA Secure Conversation message. type Message struct { *MessageHeader @@ -113,11 +86,6 @@ type Message struct { Service interface{} } -func (m *Message) reset(typeID *ua.ExpandedNodeID, service interface{}) { - m.TypeID = typeID - m.Service = service -} - func (m *Message) Decode(b []byte) (int, error) { m.MessageHeader = new(MessageHeader) var pos int diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 30b26f04..da3f5843 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -17,7 +17,6 @@ import ( "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/errors" - "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" "github.com/gopcua/opcua/uacp" "github.com/gopcua/opcua/uapolicy" @@ -746,99 +745,6 @@ func (s *SecureChannel) sendAsyncWithTimeout( respRequired bool, timeout time.Duration, ) (<-chan *response, error) { - - instance.Lock() - defer instance.Unlock() - - m := acquireMessage() - defer releaseMessage(m) - - typeID := ua.ServiceTypeID(req) - if typeID == 0 { - return nil, errors.Errorf("unknown service %T. Did you call register?", req) - } - if authToken == nil { - authToken = ua.NewTwoByteNodeID(0) - } - - reqHdr := &ua.RequestHeader{ - AuthenticationToken: authToken, - Timestamp: instance.sc.timeNow(), - RequestHandle: reqID, // TODO: can I cheat like this? - AdditionalHeader: ua.NewExtensionObject(nil), - } - - if timeout > 0 && timeout < instance.sc.cfg.RequestTimeout { - timeout = instance.sc.cfg.RequestTimeout - } - reqHdr.TimeoutHint = uint32(timeout / time.Millisecond) - req.SetHeader(reqHdr) - - h := acquireHeader() - ash := acquireAsymmetricSecurityHeader() - sh := acquireSequenceHeader() - ssh := acquireSymmetricSecurityHeader() - - defer func() { - releaseHeader(h) - releaseAsymmetricSecurityHeader(ash) - releaseSequenceHeader(sh) - releaseSymmetricSecurityHeader(ssh) - }() - - m.reset(ua.NewFourByteExpandedNodeID(0, typeID), req) - sequenceNumber := instance.nextSequenceNumber() - switch typeID { - case id.OpenSecureChannelRequest_Encoding_DefaultBinary, id.OpenSecureChannelResponse_Encoding_DefaultBinary: - // Do not send the thumbprint for security mode None - // even if we have a certificate. - // - // See https://github.com/gopcua/opcua/issues/259 - thumbprint := instance.sc.cfg.Thumbprint - if instance.sc.cfg.SecurityMode == ua.MessageSecurityModeNone { - thumbprint = nil - } - - h.MessageType = MessageTypeOpenSecureChannel - h.ChunkType = ChunkTypeFinal - h.SecureChannelID = instance.secureChannelID - - ash.SecurityPolicyURI = instance.sc.cfg.SecurityPolicyURI - ash.SenderCertificate = instance.sc.cfg.Certificate - ash.ReceiverCertificateThumbprint = thumbprint - - sh.SequenceNumber = sequenceNumber - sh.RequestID = reqID - - m.MessageHeader.Header = h - m.MessageHeader.AsymmetricSecurityHeader = ash - m.MessageHeader.SequenceHeader = sh - - case id.CloseSecureChannelRequest_Encoding_DefaultBinary, id.CloseSecureChannelResponse_Encoding_DefaultBinary: - h.MessageType = MessageTypeCloseSecureChannel - h.ChunkType = ChunkTypeFinal - h.SecureChannelID = instance.secureChannelID - ssh.TokenID = instance.securityTokenID - sh.SequenceNumber = sequenceNumber - sh.RequestID = reqID - - m.MessageHeader.Header = h - m.MessageHeader.SymmetricSecurityHeader = ssh - m.MessageHeader.SequenceHeader = sh - - default: - h.MessageType = MessageTypeMessage - h.ChunkType = ChunkTypeFinal - h.SecureChannelID = instance.secureChannelID - ssh.TokenID = instance.securityTokenID - sh.SequenceNumber = sequenceNumber - sh.RequestID = reqID - - m.MessageHeader.Header = h - m.MessageHeader.SymmetricSecurityHeader = ssh - m.MessageHeader.SequenceHeader = sh - } - var resp chan *response if respRequired { @@ -856,6 +762,14 @@ func (s *SecureChannel) sendAsyncWithTimeout( s.handlersMu.Unlock() } + instance.Lock() + defer instance.Unlock() + + m, err := instance.newRequestMessage(req, reqID, authToken, timeout) + if err != nil { + return nil, err + } + chunks, err := m.MarshalChunks(instance.maxBodySize) if err != nil { return nil, err diff --git a/uasc/sequence_header.go b/uasc/sequence_header.go index 4726160d..9d37cdf4 100644 --- a/uasc/sequence_header.go +++ b/uasc/sequence_header.go @@ -6,26 +6,10 @@ package uasc import ( "fmt" - "sync" "github.com/gopcua/opcua/ua" ) -func acquireSequenceHeader() *SequenceHeader { - if v, ok := sequenceHeaderPool.Get().(*SequenceHeader); ok { - return v - } - return &SequenceHeader{} -} - -func releaseSequenceHeader(h *SequenceHeader) { - h.RequestID = 0 - h.SequenceNumber = 0 - sequenceHeaderPool.Put(h) -} - -var sequenceHeaderPool sync.Pool - // SequenceHeader represents a Sequence Header in OPC UA Secure Conversation. type SequenceHeader struct { SequenceNumber uint32 diff --git a/uasc/symmetric_security_header.go b/uasc/symmetric_security_header.go index 31a032dd..c0ed7af2 100644 --- a/uasc/symmetric_security_header.go +++ b/uasc/symmetric_security_header.go @@ -6,26 +6,10 @@ package uasc import ( "fmt" - "sync" "github.com/gopcua/opcua/ua" ) -func acquireSymmetricSecurityHeader() *SymmetricSecurityHeader { - v := symmetricSecurityHeaderPool.Get() - if v == nil { - return &SymmetricSecurityHeader{} - } - return v.(*SymmetricSecurityHeader) -} - -func releaseSymmetricSecurityHeader(h *SymmetricSecurityHeader) { - h.TokenID = 0 - symmetricSecurityHeaderPool.Put(h) -} - -var symmetricSecurityHeaderPool sync.Pool - // SymmetricSecurityHeader represents a Symmetric Algorithm Security Header in OPC UA Secure Conversation. type SymmetricSecurityHeader struct { TokenID uint32 From 86b9f7e78db4c947344a0c620edba606ba1903fe Mon Sep 17 00:00:00 2001 From: yuanliang Date: Fri, 14 Jun 2024 14:31:22 +0800 Subject: [PATCH 14/14] feat: resolv conflicts Signed-off-by: yuanliang --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index ed724078..d3b7dce6 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,8 @@ go 1.20 require ( github.com/pascaldekloe/goe v0.1.1 github.com/pkg/errors v0.9.1 - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 - golang.org/x/term v0.18.0 + golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 + golang.org/x/term v0.8.0 ) require golang.org/x/sys v0.18.0 // indirect diff --git a/go.sum b/go.sum index 338fc88f..77c6ab54 100644 --- a/go.sum +++ b/go.sum @@ -2,9 +2,9 @@ github.com/pascaldekloe/goe v0.1.1 h1:Ah6WQ56rZONR3RW3qWa2NCZ6JAVvSpUcoLBaOmYFt9 github.com/pascaldekloe/goe v0.1.1/go.mod h1:KSyfaxQOh0HZPjDP1FL/kFtbqYqrALJTaMafFUIccqU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= -golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= +golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=