From 3ac63a8a5c2831f883e61899d2dfe754f9cc4e3a Mon Sep 17 00:00:00 2001 From: yuanliang Date: Fri, 31 May 2024 10:55:11 +0800 Subject: [PATCH] feat: Refactoring Encoder to Optimize Reflection Signed-off-by: yuanliang --- .gitignore | 1 + codec/encode.go | 609 ++++++++++++++++++++++++++++++++ go.mod | 6 +- go.sum | 9 +- ua/datatypes.go | 59 ++++ ua/diagnostic_info.go | 40 +++ ua/expanded_node_id.go | 18 + ua/extension_object.go | 22 ++ ua/node_id.go | 32 ++ ua/stream.go | 2 +- ua/variant.go | 30 ++ uacp/uacp.go | 42 +++ uasc/codec_test.go | 60 ++-- uasc/header.go | 14 + uasc/message.go | 98 +++++ uasc/message_test.go | 296 ++++++++++++++++ uasc/secure_channel.go | 2 +- uasc/secure_channel_instance.go | 1 + 18 files changed, 1302 insertions(+), 39 deletions(-) create mode 100644 codec/encode.go diff --git a/.gitignore b/.gitignore index 971ffc61..f7a0f8b7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ dist/ __pycache__/ .vscode/ .idea/ +uasc/*.txt diff --git a/codec/encode.go b/codec/encode.go new file mode 100644 index 00000000..cc4c88c0 --- /dev/null +++ b/codec/encode.go @@ -0,0 +1,609 @@ +package codec + +import ( + "bytes" + "fmt" + "math" + "reflect" + "sync" + "time" +) + +// Marshal returns the OPCUA encoding of v. +func Marshal(v any) ([]byte, error) { + e := newEncodeState() + defer encodeStatePool.Put(e) + + err := e.marshal(v) + if err != nil { + return nil, err + } + buf := append([]byte(nil), e.Bytes()...) + + return buf, nil +} + +// Marshaler is the interface implemented by types that +// can marshal themselves into valid OPCUA. +type Marshaler interface { + MarshalOPCUA() ([]byte, error) +} + +// An UnsupportedTypeError is returned by [Marshal] when attempting +// to encode an unsupported value type. +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return "opcua: unsupported type: " + e.Type.String() +} + +// An UnsupportedValueError is returned by [Marshal] when attempting +// to encode an unsupported value. +type UnsupportedValueError struct { + Value reflect.Value + Str string +} + +func (e *UnsupportedValueError) Error() string { + return "opcua: unsupported value: " + e.Str +} + +// A MarshalerError represents an error from calling a +// [Marshaler.MarshalOPCUA] method. +type MarshalerError struct { + Type reflect.Type + Err error + sourceFunc string +} + +func (e *MarshalerError) Error() string { + srcFunc := e.sourceFunc + if srcFunc == "" { + srcFunc = "MarshalOPCUA" + } + return "opcua: error calling " + srcFunc + + " for type " + e.Type.String() + + ": " + e.Err.Error() +} + +type encodeState struct { + bytes.Buffer + + ptrLevel uint + ptrSeen map[any]struct{} +} + +const startDetectingCyclesAfter = 1000 + +var encodeStatePool sync.Pool + +func newEncodeState() *encodeState { + if v := encodeStatePool.Get(); v != nil { + e := v.(*encodeState) + e.Reset() + if len(e.ptrSeen) > 0 { + panic("ptrEncoder.encode should have emptied ptrSeen via defers") + } + e.ptrLevel = 0 + return e + } + return &encodeState{ptrSeen: make(map[any]struct{})} +} + +// codecError is an error wrapper type for internal use only. +// Panics with errors are wrapped in codecError so that the top-level recover +// can distinguish intentional panics from this package. +type codecError struct{ error } + +func (e *encodeState) marshal(v any) (err error) { + defer func() { + if r := recover(); r != nil { + if je, ok := r.(codecError); ok { + err = je.error + } else { + panic(r) + } + } + }() + e.reflectValue(reflect.ValueOf(v)) + return nil +} + +// error aborts the encoding by panicking with err wrapped in jsonError. +func (e *encodeState) error(err error) { + panic(codecError{err}) +} + +func (e *encodeState) reflectValue(v reflect.Value) { + valueEncoder(v)(e, v) +} + +type encoderFunc func(e *encodeState, v reflect.Value) + +var encoderCache sync.Map // map[reflect.Type]encoderFunc + +func valueEncoder(v reflect.Value) encoderFunc { + if isTime(v) { + return timeEncoder + } + + return typeEncoder(v.Type()) +} + +func typeEncoder(t reflect.Type) encoderFunc { + if fi, ok := encoderCache.Load(t); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoderCache.LoadOrStore(t, encoderFunc(func(e *encodeState, v reflect.Value) { + wg.Wait() + f(e, v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = newTypeEncoder(t, true) + wg.Done() + encoderCache.Store(t, f) + return f +} + +var marshalerType = reflect.TypeFor[Marshaler]() + +// newTypeEncoder constructs an encoderFunc for a type. +// The returned encoder only checks CanAddr when allowAddr is true. +func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { + kind := t.Kind() + // If we have a non-pointer value whose type implements + // Marshaler with a value receiver, then we're better off taking + // the address of the value - otherwise we end up with an + // allocation as we cast the value to an interface. + if kind != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { + return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) + } + if t.Implements(marshalerType) { + return marshalerEncoder + } + + switch kind { + case reflect.Bool: + return boolEncoder + case reflect.Int8: + return int8Encoder + case reflect.Uint8: + return uint8Encoder + case reflect.Int16: + return int16Encoder + case reflect.Uint16: + return uint16Encoder + case reflect.Int32: + return int32Encoder + case reflect.Uint32: + return uint32Encoder + case reflect.Int64: + return int64Encoder + case reflect.Uint64: + return uint64Encoder + case reflect.Float32: + return float32Encoder + case reflect.Float64: + return float64Encoder + case reflect.String: + return stringEncoder + case reflect.Interface: + return interfaceEncoder + case reflect.Ptr: + return newPtrEncoder(t) + case reflect.Struct: + return newStructEncoder(t) + case reflect.Array: + return newArrayEncoder(t) + case reflect.Slice: + return newSliceEncoder(t) + default: + return unsupportedTypeEncoder + } +} + +func isTime(val reflect.Value) bool { + return val.CanConvert(timeType) +} + +func marshalerEncoder(e *encodeState, v reflect.Value) { + if v.Kind() == reflect.Pointer && v.IsNil() { + return + } + m, ok := v.Interface().(Marshaler) + if !ok { + return + } + b, err := m.MarshalOPCUA() + if err == nil { + e.Grow(len(b)) + out := e.AvailableBuffer() + out = append(out, b...) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalOPCUA"}) + } +} + +func addrMarshalerEncoder(e *encodeState, v reflect.Value) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(Marshaler) + b, err := m.MarshalOPCUA() + if err == nil { + e.Grow(len(b)) + out := e.AvailableBuffer() + out = append(out, b...) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalOPCUA"}) + } +} + +func (e *encodeState) writeUint16(n uint16) { + e.Write([]byte{byte(n), byte(n >> 8)}) +} + +func (e *encodeState) writeUint32(n uint32) { + e.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) +} + +func (e *encodeState) writeUint64(n uint64) { + e.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}) +} + +func boolEncoder(e *encodeState, v reflect.Value) { + val := v.Bool() + if val { + e.WriteByte('1') + } else { + e.WriteByte('0') + } +} + +func int8Encoder(e *encodeState, v reflect.Value) { + val := int8(v.Int()) + e.WriteByte(byte(val)) +} + +func uint8Encoder(e *encodeState, v reflect.Value) { + val := uint8(v.Uint()) + e.WriteByte(val) +} + +func int16Encoder(e *encodeState, v reflect.Value) { + val := uint16(v.Int()) + e.writeUint16(val) +} + +func uint16Encoder(e *encodeState, v reflect.Value) { + val := uint16(v.Uint()) + e.writeUint16(val) +} + +func int32Encoder(e *encodeState, v reflect.Value) { + val := uint32(v.Int()) + e.writeUint32(val) +} + +func uint32Encoder(e *encodeState, v reflect.Value) { + val := uint32(v.Uint()) + e.writeUint32(val) +} + +func int64Encoder(e *encodeState, v reflect.Value) { + val := uint64(v.Int()) + e.writeUint64(val) +} + +func uint64Encoder(e *encodeState, v reflect.Value) { + val := v.Uint() + e.writeUint64(val) +} + +var ( + null = []byte{0xff, 0xff, 0xff, 0xff} + f32qnan = []byte{0x00, 0x00, 0xff, 0xc0} + f64qnan = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf8, 0xff} +) + +func float32Encoder(e *encodeState, v reflect.Value) { + if math.IsNaN(v.Float()) { + e.Write(f32qnan) + } else { + val := math.Float32bits((float32)(v.Float())) + e.writeUint32(val) + } +} + +func float64Encoder(e *encodeState, v reflect.Value) { + if math.IsNaN(v.Float()) { + e.Write(f64qnan) + } else { + val := math.Float64bits(v.Float()) + e.writeUint64(val) + } +} + +func stringEncoder(e *encodeState, v reflect.Value) { + s := v.String() + if s == "" { + e.Write(null) + return + } + + l := len(s) + e.writeUint32(uint32(l)) + e.Write([]byte(s)) +} + +var timeType = reflect.TypeOf(time.Time{}) + +func timeEncoder(e *encodeState, v reflect.Value) { + var ts uint64 + val := v.Convert(timeType).Interface().(time.Time) + if !v.IsZero() { + // encode time in "100 nanosecond intervals since January 1, 1601" + ts = uint64(val.UTC().UnixNano()/100 + 116444736000000000) + } + e.writeUint64(ts) +} + +func interfaceEncoder(e *encodeState, v reflect.Value) { + if v.IsNil() { + return + } + e.reflectValue(v.Elem()) +} + +func unsupportedTypeEncoder(e *encodeState, v reflect.Value) { + e.error(&UnsupportedTypeError{v.Type()}) +} + +type structEncoder struct { + fields structFields +} + +type structFields struct { + list []field +} + +func (se structEncoder) encode(e *encodeState, v reflect.Value) { +FieldLoop: + for i := range se.fields.list { + f := &se.fields.list[i] + + // Find the nested struct field by following f.index. + fv := v + for _, i := range f.index { + if fv.Kind() == reflect.Pointer { + if fv.IsNil() { + continue FieldLoop + } + fv = fv.Elem() + } + fv = fv.Field(i) + } + + f.encoder(e, fv) + } +} + +func newStructEncoder(t reflect.Type) encoderFunc { + se := structEncoder{fields: cachedTypeFields(t)} + return se.encode +} + +func encodeByteSlice(e *encodeState, v reflect.Value) { + if v.IsNil() { + e.Write(null) + return + } + + n := v.Len() + e.writeUint32(uint32(n)) + + b := make([]byte, n) + reflect.Copy(reflect.ValueOf(b), v) + e.Write(b) +} + +// sliceEncoder just wraps an arrayEncoder, checking to make sure the value isn't nil. +type sliceEncoder struct { + arrayEnc encoderFunc +} + +func (se sliceEncoder) encode(e *encodeState, v reflect.Value) { + if v.IsNil() { + e.Write(null) + return + } + + if v.Len() > math.MaxInt32 { + panic("array too large") + } + + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + // Here we use a struct to memorize the pointer to the first element of the slice + // and its length. + ptr := struct { + ptr interface{} // always an unsafe.Pointer, but avoids a dependency on package unsafe + len int + }{v.UnsafePointer(), v.Len()} + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } + se.arrayEnc(e, v) + e.ptrLevel-- +} + +func newSliceEncoder(t reflect.Type) encoderFunc { + // Byte slices get special treatment; arrays don't. + if t.Elem().Kind() == reflect.Uint8 { + p := reflect.PointerTo(t.Elem()) + if !p.Implements(marshalerType) { + return encodeByteSlice + } + } + enc := sliceEncoder{newArrayEncoder(t)} + return enc.encode +} + +type arrayEncoder struct { + elemEnc encoderFunc +} + +func (ae arrayEncoder) encode(e *encodeState, v reflect.Value) { + n := v.Len() + e.writeUint32(uint32(n)) + + // fast path for []byte + if v.Type().Elem().Kind() == reflect.Uint8 { + b := make([]byte, n) + reflect.Copy(reflect.ValueOf(b), v) + e.Write(b) + return + } + + // loop over elements + // we write all the elements, also the zero values + for i := 0; i < n; i++ { + ae.elemEnc(e, v.Index(i)) + } +} + +func newArrayEncoder(t reflect.Type) encoderFunc { + enc := arrayEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type ptrEncoder struct { + elemEnc encoderFunc +} + +func (pe ptrEncoder) encode(e *encodeState, v reflect.Value) { + if v.IsNil() { + return + } + if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter { + // We're a large number of nested ptrEncoder.encode calls deep; + // start checking if we've run into a pointer cycle. + ptr := v.Interface() + if _, ok := e.ptrSeen[ptr]; ok { + e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())}) + } + e.ptrSeen[ptr] = struct{}{} + defer delete(e.ptrSeen, ptr) + } + pe.elemEnc(e, v.Elem()) + e.ptrLevel-- +} + +func newPtrEncoder(t reflect.Type) encoderFunc { + enc := ptrEncoder{typeEncoder(t.Elem())} + return enc.encode +} + +type condAddrEncoder struct { + canAddrEnc, elseEnc encoderFunc +} + +func (ce condAddrEncoder) encode(e *encodeState, v reflect.Value) { + if v.CanAddr() { + ce.canAddrEnc(e, v) + } else { + ce.elseEnc(e, v) + } +} + +// newCondAddrEncoder returns an encoder that checks whether its value +// CanAddr and delegates to canAddrEnc if so, else to elseEnc. +func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { + enc := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} + return enc.encode +} + +// A field represents a single field found in a struct. +type field struct { + name string + nameBytes []byte // []byte(name) + + index []int + typ reflect.Type + + encoder encoderFunc +} + +// typeFields returns a list of fields that the encoder should recognize for a given type. +// The algorithm is a depth-first search of the set of structures, including any reachable anonymous structures, +// until the structure is traversed completely. +func typeFields(t reflect.Type) structFields { + var fields []field + + visitField := func(f reflect.StructField) { + t := f.Type + // time.Time is special because it has embedded structs that use timeEncoder. + if t == timeType || (t.Kind() == reflect.Pointer && t.Elem() == timeType) { + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: timeEncoder}) + return + } + // 如果t实现了Marshaler接口,则使用Marshaler的MarshalOPCUA方法 + if t.Implements(marshalerType) { + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: t, encoder: marshalerEncoder}) + return + } + + // Check for anonymous fields (embedded structs). + if f.Anonymous { + if t.Kind() == reflect.Pointer { + t = f.Type.Elem() + } + fields = append(fields, typeFields(t).list...) + } + + fields = append(fields, field{name: f.Name, nameBytes: []byte(f.Name), index: f.Index, typ: f.Type, encoder: typeEncoder(f.Type)}) + } + + // Process all fields in the root struct. + for i := 0; i < t.NumField(); i++ { + visitField(t.Field(i)) + } + return structFields{list: fields} +} + +var fieldCache sync.Map // map[reflect.Type]structFields + +// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. +func cachedTypeFields(t reflect.Type) structFields { + if f, ok := fieldCache.Load(t); ok { + return f.(structFields) + } + f, _ := fieldCache.LoadOrStore(t, typeFields(t)) + return f.(structFields) +} diff --git a/go.mod b/go.mod index f75924d4..ed724078 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.20 require ( github.com/pascaldekloe/goe v0.1.1 github.com/pkg/errors v0.9.1 - golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 - golang.org/x/term v0.8.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + golang.org/x/term v0.18.0 ) -require golang.org/x/sys v0.8.0 // indirect +require golang.org/x/sys v0.18.0 // indirect retract ( v0.2.5 // https://github.com/gopcua/opcua/issues/538 diff --git a/go.sum b/go.sum index ac17d5df..ff4affa7 100644 --- a/go.sum +++ b/go.sum @@ -2,9 +2,6 @@ github.com/pascaldekloe/goe v0.1.1 h1:Ah6WQ56rZONR3RW3qWa2NCZ6JAVvSpUcoLBaOmYFt9 github.com/pascaldekloe/goe v0.1.1/go.mod h1:KSyfaxQOh0HZPjDP1FL/kFtbqYqrALJTaMafFUIccqU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= -golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= diff --git a/ua/datatypes.go b/ua/datatypes.go index aa5e12ea..d3954366 100644 --- a/ua/datatypes.go +++ b/ua/datatypes.go @@ -5,11 +5,14 @@ package ua import ( + "bytes" "encoding/binary" "encoding/hex" "fmt" "strings" "time" + + "github.com/gopcua/opcua/codec" ) // These flags define which fields of a DataValue are set. @@ -84,6 +87,36 @@ func (d *DataValue) Encode(s *Stream) { } } +func (d *DataValue) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + var err error + var b []byte + + buf.WriteByte(d.EncodingMask) + if d.Has(DataValueValue) { + b, err = codec.Marshal(d.Value) + buf.Write(b) + } + if d.Has(DataValueStatusCode) { + buf.Write([]byte{byte(d.Status), byte(d.Status >> 8), byte(d.Status >> 16), byte(d.Status >> 24)}) + } + if d.Has(DataValueSourceTimestamp) { + b, err = codec.Marshal(d.SourceTimestamp) + buf.Write(b) + } + if d.Has(DataValueSourcePicoseconds) { + buf.Write([]byte{byte(d.SourcePicoseconds), byte(d.SourcePicoseconds >> 8)}) + } + if d.Has(DataValueServerTimestamp) { + b, err = codec.Marshal(d.ServerTimestamp) + buf.Write(b) + } + if d.Has(DataValueServerPicoseconds) { + buf.Write([]byte{byte(d.ServerPicoseconds), byte(d.ServerPicoseconds >> 8)}) + } + return buf.Bytes(), err +} + func (d *DataValue) Has(mask byte) bool { return d.EncodingMask&mask == mask } @@ -233,6 +266,32 @@ func (l *LocalizedText) Encode(s *Stream) { } } +func (l *LocalizedText) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + var err error + + buf.WriteByte(l.EncodingMask) + if l.Has(LocalizedTextLocale) { + n := len(l.Locale) + if n == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.Write([]byte(l.Locale)) + } + } + if l.Has(LocalizedTextText) { + n := len(l.Text) + if n == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.Write([]byte(l.Text)) + } + } + return buf.Bytes(), err +} + func (l *LocalizedText) Has(mask byte) bool { return l.EncodingMask&mask == mask } diff --git a/ua/diagnostic_info.go b/ua/diagnostic_info.go index b72d46df..e5457a9e 100644 --- a/ua/diagnostic_info.go +++ b/ua/diagnostic_info.go @@ -4,6 +4,12 @@ package ua +import ( + "bytes" + + "github.com/gopcua/opcua/codec" +) + // These flags define which fields of a DiagnosticInfo are set. // Bits are or'ed together if multiple fields are set. const ( @@ -83,6 +89,40 @@ func (d *DiagnosticInfo) Encode(s *Stream) { } } +func (d *DiagnosticInfo) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(d.EncodingMask) + + if d.Has(DiagnosticInfoSymbolicID) { + buf.Write([]byte{byte(d.SymbolicID), byte(d.SymbolicID >> 8), byte(d.SymbolicID >> 16), byte(d.SymbolicID >> 24)}) + } + if d.Has(DiagnosticInfoNamespaceURI) { + buf.Write([]byte{byte(d.NamespaceURI), byte(d.NamespaceURI >> 8), byte(d.NamespaceURI >> 16), byte(d.NamespaceURI >> 24)}) + } + if d.Has(DiagnosticInfoLocale) { + buf.Write([]byte{byte(d.Locale), byte(d.Locale >> 8), byte(d.Locale >> 16), byte(d.Locale >> 24)}) + } + if d.Has(DiagnosticInfoLocalizedText) { + buf.Write([]byte{byte(d.LocalizedText), byte(d.LocalizedText >> 8), byte(d.LocalizedText >> 16), byte(d.LocalizedText >> 24)}) + } + if d.Has(DiagnosticInfoAdditionalInfo) { + b, _ := codec.Marshal(d.AdditionalInfo) + buf.Write(b) + } + if d.Has(DiagnosticInfoInnerStatusCode) { + buf.Write([]byte{byte(d.InnerStatusCode), byte(d.InnerStatusCode >> 8), byte(d.InnerStatusCode >> 16), byte(d.InnerStatusCode >> 24)}) + } + if d.Has(DiagnosticInfoInnerDiagnosticInfo) { + b, err := codec.Marshal(d.InnerDiagnosticInfo) + if err != nil { + return nil, err + } + buf.Write(b) + } + + return buf.Bytes(), nil +} + func (d *DiagnosticInfo) Has(mask byte) bool { return d.EncodingMask&mask == mask } diff --git a/ua/expanded_node_id.go b/ua/expanded_node_id.go index c0e7ee3c..257dc29c 100644 --- a/ua/expanded_node_id.go +++ b/ua/expanded_node_id.go @@ -5,11 +5,13 @@ package ua import ( + "bytes" "encoding/base64" "math" "strconv" "strings" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" ) @@ -112,6 +114,22 @@ func (e *ExpandedNodeID) Encode(s *Stream) { } } +func (e *ExpandedNodeID) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + b, err := e.NodeID.MarshalOPCUA() + buf.Write(b) + + if e.HasNamespaceURI() { + b, err = codec.Marshal(e.NamespaceURI) + buf.Write(b) + } + if e.HasServerIndex() { + b, err = codec.Marshal(e.ServerIndex) + buf.Write(b) + } + return buf.Bytes(), err +} + // HasNamespaceURI checks if an ExpandedNodeID has NamespaceURI Flag. func (e *ExpandedNodeID) HasNamespaceURI() bool { return e.NodeID.EncodingMask()>>7&0x1 == 1 diff --git a/ua/extension_object.go b/ua/extension_object.go index a616a41f..82e3ee6b 100644 --- a/ua/extension_object.go +++ b/ua/extension_object.go @@ -5,6 +5,9 @@ package ua import ( + "bytes" + + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/debug" "github.com/gopcua/opcua/id" ) @@ -104,6 +107,25 @@ func (e *ExtensionObject) Encode(s *Stream) { s.Write(body.Bytes()) } +func (e *ExtensionObject) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + b, err := codec.Marshal(e.TypeID) + buf.Write(b) + buf.WriteByte(e.EncodingMask) + if e.EncodingMask == ExtensionObjectEmpty { + return buf.Bytes(), err + } + + body, err := codec.Marshal(e.Value) + if err != nil { + return buf.Bytes(), err + } + n := uint32(len(body)) + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.Write(body) + return buf.Bytes(), err +} + func (e *ExtensionObject) UpdateMask() { if e.Value == nil { e.EncodingMask = ExtensionObjectEmpty diff --git a/ua/node_id.go b/ua/node_id.go index 9584b1b5..c9934aaf 100644 --- a/ua/node_id.go +++ b/ua/node_id.go @@ -5,11 +5,13 @@ package ua import ( + "bytes" "encoding/base64" "encoding/json" "fmt" "math" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" ) @@ -389,6 +391,36 @@ func (n *NodeID) Encode(s *Stream) { } } +func (n *NodeID) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(byte(n.mask)) + switch n.Type() { + case NodeIDTypeTwoByte: + buf.WriteByte(byte(n.nid)) + case NodeIDTypeFourByte: + buf.WriteByte(byte(n.ns)) + buf.Write([]byte{byte(n.nid), byte(n.nid >> 8)}) + case NodeIDTypeNumeric: + buf.Write([]byte{byte(n.ns), byte(n.ns >> 8), byte(n.nid), byte(n.nid >> 8), byte(n.nid >> 16), byte(n.nid >> 24)}) + case NodeIDTypeGUID: + buf.Write([]byte{byte(n.ns), byte(n.ns >> 8)}) + b, _ := codec.Marshal(n.gid) + buf.Write(b) + case NodeIDTypeByteString, NodeIDTypeString: + buf.Write([]byte{byte(n.ns), byte(n.ns >> 8)}) + l := uint32(len(n.bid)) + if l == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + buf.Write([]byte{byte(l), byte(l >> 8), byte(l >> 16), byte(l >> 24)}) + buf.Write(n.bid) + } + default: + return nil, fmt.Errorf("invalid node id type: %d", n.mask) + } + return buf.Bytes(), nil +} + func (n *NodeID) MarshalJSON() ([]byte, error) { if n == nil { return []byte(`null`), nil diff --git a/ua/stream.go b/ua/stream.go index 9f5d39ed..27063e0a 100644 --- a/ua/stream.go +++ b/ua/stream.go @@ -9,7 +9,7 @@ import ( "github.com/gopcua/opcua/errors" ) -const DefaultBufSize = 1024 +const DefaultBufSize = 256 type Stream struct { buf []byte diff --git a/ua/variant.go b/ua/variant.go index bb29604c..dd9f8dfa 100644 --- a/ua/variant.go +++ b/ua/variant.go @@ -5,10 +5,12 @@ package ua import ( + "bytes" "fmt" "reflect" "time" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" ) @@ -340,6 +342,34 @@ func (m *Variant) Encode(s *Stream) { } } +func (m *Variant) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(m.mask) + + // a null value specifies that no other fields are encoded + if m.Type() == TypeIDNull { + return buf.Bytes(), nil + } + + if m.Has(VariantArrayValues) { + buf.Write([]byte{byte(m.arrayLength), byte(m.arrayLength >> 8), byte(m.arrayLength >> 16), byte(m.arrayLength >> 24)}) + } + + b, err := codec.Marshal(m.value) + if err != nil { + return buf.Bytes(), err + } + buf.Write(b) + + if m.Has(VariantArrayDimensions) { + buf.Write([]byte{byte(m.arrayDimensionsLength), byte(m.arrayDimensionsLength >> 8), byte(m.arrayDimensionsLength >> 16), byte(m.arrayDimensionsLength >> 24)}) + for _, v := range m.arrayDimensions { + buf.Write([]byte{byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24)}) + } + } + return buf.Bytes(), nil +} + // encode recursively writes the values to the buffer. func (m *Variant) encode(s *Stream, val reflect.Value) { if val.Kind() != reflect.Slice || m.Type() == TypeIDByteString { diff --git a/uacp/uacp.go b/uacp/uacp.go index 2ae23826..6b921474 100644 --- a/uacp/uacp.go +++ b/uacp/uacp.go @@ -5,6 +5,9 @@ package uacp import ( + "bytes" + "encoding/binary" + "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" ) @@ -55,6 +58,18 @@ func (h *Header) Encode(s *ua.Stream) { s.WriteUint32(h.MessageSize) } +func (h *Header) MarshalOPCUA() ([]byte, error) { + if len(h.MessageType) != 3 { + return nil, errors.Errorf("invalid message type: %q", h.MessageType) + } + + var buf bytes.Buffer + buf.Write([]byte(h.MessageType)) + buf.WriteByte(h.ChunkType) + buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) + return buf.Bytes(), nil +} + // Hello represents a OPC UA Hello. // // Specification: Part6, 7.1.2.3 @@ -87,6 +102,23 @@ func (h *Hello) Encode(s *ua.Stream) { s.WriteString(h.EndpointURL) } +func (h *Hello) MarshalOPCUA() ([]byte, error) { + var buf bytes.Buffer + buf.Write([]byte{byte(h.Version), byte(h.Version >> 8), byte(h.Version >> 16), byte(h.Version >> 24)}) + buf.Write([]byte{byte(h.ReceiveBufSize), byte(h.ReceiveBufSize >> 8), byte(h.ReceiveBufSize >> 16), byte(h.ReceiveBufSize >> 24)}) + buf.Write([]byte{byte(h.SendBufSize), byte(h.SendBufSize >> 8), byte(h.SendBufSize >> 16), byte(h.SendBufSize >> 24)}) + buf.Write([]byte{byte(h.MaxMessageSize), byte(h.MaxMessageSize >> 8), byte(h.MaxMessageSize >> 16), byte(h.MaxMessageSize >> 24)}) + buf.Write([]byte{byte(h.MaxChunkCount), byte(h.MaxChunkCount >> 8), byte(h.MaxChunkCount >> 16), byte(h.MaxChunkCount >> 24)}) + if len(h.EndpointURL) == 0 { + buf.Write([]byte{0xff, 0xff, 0xff, 0xff}) + } else { + n := len(h.EndpointURL) + buf.Write([]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}) + buf.WriteString(h.EndpointURL) + } + return buf.Bytes(), nil +} + // Acknowledge represents a OPC UA Acknowledge. // // Specification: Part6, 7.1.2.4 @@ -116,6 +148,16 @@ func (a *Acknowledge) Encode(s *ua.Stream) { s.WriteUint32(a.MaxChunkCount) } +func (a *Acknowledge) MarshalOPCUA() ([]byte, error) { + buf := make([]byte, 0, 160) + binary.LittleEndian.AppendUint32(buf, a.Version) + binary.LittleEndian.AppendUint32(buf, a.ReceiveBufSize) + binary.LittleEndian.AppendUint32(buf, a.SendBufSize) + binary.LittleEndian.AppendUint32(buf, a.MaxMessageSize) + binary.LittleEndian.AppendUint32(buf, a.MaxChunkCount) + return buf, nil +} + // ReverseHello represents a OPC UA ReverseHello. // // Specification: Part6, 7.1.2.6 diff --git a/uasc/codec_test.go b/uasc/codec_test.go index 9bd52826..bb636a48 100644 --- a/uasc/codec_test.go +++ b/uasc/codec_test.go @@ -7,10 +7,9 @@ package uasc import ( - "reflect" "testing" - "github.com/gopcua/opcua/ua" + "github.com/gopcua/opcua/codec" "github.com/pascaldekloe/goe/verify" ) @@ -29,37 +28,42 @@ func RunCodecTest(t *testing.T, cases []CodecTestCase) { for _, c := range cases { t.Run(c.Name, func(t *testing.T) { - t.Run("decode", func(t *testing.T) { - // create a new instance of the same type as c.Struct - typ := reflect.ValueOf(c.Struct).Type() - var v reflect.Value - switch typ.Kind() { - case reflect.Ptr: - v = reflect.New(typ.Elem()) // typ: *struct, v: *struct - case reflect.Slice: - v = reflect.New(typ) // typ: []x, v: *[]x - default: - t.Fatalf("%T is not a pointer or a slice", c.Struct) - } + // t.Run("decode", func(t *testing.T) { + // // create a new instance of the same type as c.Struct + // typ := reflect.ValueOf(c.Struct).Type() + // var v reflect.Value + // switch typ.Kind() { + // case reflect.Ptr: + // v = reflect.New(typ.Elem()) // typ: *struct, v: *struct + // case reflect.Slice: + // v = reflect.New(typ) // typ: []x, v: *[]x + // default: + // t.Fatalf("%T is not a pointer or a slice", c.Struct) + // } - if _, err := ua.Decode(c.Bytes, v.Interface()); err != nil { - t.Fatal(err) - } + // if _, err := ua.Decode(c.Bytes, v.Interface()); err != nil { + // t.Fatal(err) + // } - // if v is a *[]x we need to dereference it before comparing it. - if typ.Kind() == reflect.Slice { - v = v.Elem() - } - verify.Values(t, "", v.Interface(), c.Struct) - }) + // // if v is a *[]x we need to dereference it before comparing it. + // if typ.Kind() == reflect.Slice { + // v = v.Elem() + // } + // verify.Values(t, "", v.Interface(), c.Struct) + // }) t.Run("encode", func(t *testing.T) { - s := ua.NewStream(ua.DefaultBufSize) - s.WriteAny(c.Struct) - if s.Error() != nil { - t.Fatal(s.Error()) + b, err := codec.Marshal(c.Struct) + if err != nil { + t.Fatal(err) } - verify.Values(t, "", s.Bytes(), c.Bytes) + verify.Values(t, "", b, c.Bytes) + // s := ua.NewStream(ua.DefaultBufSize) + // s.WriteAny(c.Struct) + // if s.Error() != nil { + // t.Fatal(s.Error()) + // } + // verify.Values(t, "", s.Bytes(), c.Bytes) }) }) } diff --git a/uasc/header.go b/uasc/header.go index 75a5fc26..20444877 100644 --- a/uasc/header.go +++ b/uasc/header.go @@ -5,6 +5,7 @@ package uasc import ( + "bytes" "fmt" "github.com/gopcua/opcua/errors" @@ -62,6 +63,19 @@ func (h *Header) Encode(s *ua.Stream) { s.WriteUint32(h.SecureChannelID) } +func (h *Header) MarshalOPCUA() ([]byte, error) { + if len(h.MessageType) != 3 { + return nil, errors.Errorf("invalid message type: %q", h.MessageType) + } + + var buf bytes.Buffer + buf.WriteString(h.MessageType) + buf.WriteByte(h.ChunkType) + buf.Write([]byte{byte(h.MessageSize), byte(h.MessageSize >> 8), byte(h.MessageSize >> 16), byte(h.MessageSize >> 24)}) + buf.Write([]byte{byte(h.SecureChannelID), byte(h.SecureChannelID >> 8), byte(h.SecureChannelID >> 16), byte(h.SecureChannelID >> 24)}) + return buf.Bytes(), nil +} + // String returns Header in string. func (h *Header) String() string { return fmt.Sprintf( diff --git a/uasc/message.go b/uasc/message.go index 4d53d139..1bc23566 100644 --- a/uasc/message.go +++ b/uasc/message.go @@ -5,8 +5,10 @@ package uasc import ( + "fmt" "math" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/errors" "github.com/gopcua/opcua/ua" ) @@ -146,6 +148,10 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { if buf.Error() != nil { return nil, errors.Errorf("failed to encode chunk: %s", buf.Error()) } + for _, v := range buf.Bytes() { + fmt.Printf("0x%02x,", v) + } + fmt.Println() return [][]byte{buf.Bytes()}, nil case "CLO", "MSG": @@ -182,3 +188,95 @@ func (m *Message) EncodeChunks(maxBodySize uint32) ([][]byte, error) { return nil, errors.Errorf("invalid message type %q", m.Header.MessageType) } } + +func (m *Message) MarshalOPCUA() ([]byte, error) { + chunks, err := m.MarshalChunks(math.MaxUint32) + if err != nil { + return nil, err + } + + return chunks[0], nil +} + +func (m *Message) MarshalChunks(maxBodySize uint32) ([][]byte, error) { + typeID, err := codec.Marshal(m.TypeID) + if err != nil { + return nil, errors.Errorf("failed to encode typeid: %s", err) + } + service, err := codec.Marshal(m.Service) + if err != nil { + return nil, errors.Errorf("failed to encode service: %s", err) + } + + dataBody := make([]byte, len(typeID)+len(service)) + copy(dataBody, typeID) + copy(dataBody[len(typeID):], service) + + nrChunks := uint32(len(dataBody))/(maxBodySize) + 1 + chunks := make([][]byte, nrChunks) + + switch m.Header.MessageType { + case "OPN": + asymmetricSecurityHeader, err := codec.Marshal(m.AsymmetricSecurityHeader) + if err != nil { + return nil, errors.Errorf("failed to encode asymmetric security header: %s", err) + } + sequenceHeader, err := codec.Marshal(m.SequenceHeader) + if err != nil { + return nil, errors.Errorf("failed to encode sequence header: %s", err) + } + + m.Header.MessageSize = uint32(12 + len(asymmetricSecurityHeader) + len(sequenceHeader) + len(dataBody)) + header, err := codec.Marshal(m.Header) + if err != nil { + return nil, errors.Errorf("failed to encode header: %s", err) + } + chunks[0] = append(chunks[0], header...) + chunks[0] = append(chunks[0], asymmetricSecurityHeader...) + chunks[0] = append(chunks[0], sequenceHeader...) + chunks[0] = append(chunks[0], dataBody...) + return chunks, nil + + case "CLO", "MSG": + symmetricSecurityHeader, err := codec.Marshal(m.SymmetricSecurityHeader) + if err != nil { + return nil, errors.Errorf("failed to encode symmetric security header: %s", err) + } + sequenceHeader, err := codec.Marshal(m.SequenceHeader) + if err != nil { + return nil, errors.Errorf("failed to encode sequence header: %s", err) + } + + start, end := 0, int(maxBodySize) + for i := uint32(0); i < nrChunks-1; i++ { + m.Header.MessageSize = maxBodySize + 24 + m.Header.ChunkType = ChunkTypeIntermediate + + header, err := codec.Marshal(m.Header) + if err != nil { + return nil, errors.Errorf("failed to encode header: %s", err) + } + + chunks[i] = append(chunks[i], header...) + chunks[i] = append(chunks[i], symmetricSecurityHeader...) + chunks[i] = append(chunks[i], sequenceHeader...) + chunks[i] = append(chunks[i], dataBody[start:end]...) + start, end = end, end+int(maxBodySize) + } + + m.Header.ChunkType = ChunkTypeFinal + m.Header.MessageSize = uint32(24 + len(dataBody)) + + header, err := codec.Marshal(m.Header) + if err != nil { + return nil, err + } + chunks[nrChunks-1] = append(chunks[nrChunks-1], header...) + chunks[nrChunks-1] = append(chunks[nrChunks-1], symmetricSecurityHeader...) + chunks[nrChunks-1] = append(chunks[nrChunks-1], sequenceHeader...) + chunks[nrChunks-1] = append(chunks[nrChunks-1], dataBody...) + return chunks, nil + default: + return nil, errors.Errorf("invalid message type %q", m.Header.MessageType) + } +} diff --git a/uasc/message_test.go b/uasc/message_test.go index 1311456b..7056d965 100644 --- a/uasc/message_test.go +++ b/uasc/message_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/gopcua/opcua/codec" "github.com/gopcua/opcua/id" "github.com/gopcua/opcua/ua" @@ -245,6 +246,56 @@ func TestMessage(t *testing.T) { 0x00, 0x00, 0x00, 0x00, 0x00, }, }, + // { + // Name: "OPN-2", + // Struct: func() interface{} { + // s := &SecureChannel{ + // endpointURL: "opc.tcp://192.168.118.199:4840", + // cfg: &Config{ + // SecurityPolicyURI: "opc.tcp://192.168.118.199:4840/FOO", + // }, + // } + // instance := &channelInstance{ + // sc: s, + // sequenceNumber: 1, + // securityTokenID: 0, + // maxBodySize: 65510, + // } + // m := instance.newMessage( + // &ua.OpenSecureChannelRequest{ + // RequestHeader: &ua.RequestHeader{ + // AuthenticationToken: ua.NewTwoByteNodeID(0), + // Timestamp: time.Date(2024, time.May, 30, 16, 17, 44, 0, time.Local), + // RequestHandle: 1, + // ReturnDiagnostics: 0, + // TimeoutHint: 10000, + // AdditionalHeader: ua.NewExtensionObject(nil), + // }, + // ClientProtocolVersion: 0, + // RequestType: ua.SecurityTokenRequestTypeIssue, + // SecurityMode: ua.MessageSecurityModeNone, + // ClientNonce: []byte{}, + // RequestedLifetime: 3600000, + // }, + // id.OpenSecureChannelRequest_Encoding_DefaultBinary, + // s.nextRequestID(), + // ) + + // // set message size manually, since it is computed in Encode + // // otherwise, the decode tests failed. + // m.Header.MessageSize = 119 + + // return m + // }(), + // Bytes: []byte{ // OpenSecureChannelRequest + // // Message Header + // // MessageType: OPN + // 0x4f, 0x50, 0x4e, + // // Chunk Type: Final + // 0x46, + // 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, 0x31, 0x31, 0x38, 0x2e, 0x31, 0x39, 0x39, 0x3a, 0x34, 0x38, 0x34, 0x30, 0x2f, 0x46, 0x4f, 0x4f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0xbe, 0x01, 0x00, 0x00, 0x00, 0x04, 0xd5, 0xd8, 0x69, 0xb2, 0xda, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0xee, 0x36, 0x00, + // }, + // }, } RunCodecTest(t, cases) } @@ -495,3 +546,248 @@ func BenchmarkEncodeMessage(b *testing.B) { } } } + +func BenchmarkEncodeMessage_WithCodec(b *testing.B) { + cases := []CodecTestCase{ + { + Name: "OPN", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.OpenSecureChannelRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + ClientProtocolVersion: 0, + RequestType: ua.SecurityTokenRequestTypeIssue, + SecurityMode: ua.MessageSecurityModeNone, + RequestedLifetime: 6000000, + }, + id.OpenSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 131 + + return m + }(), + Bytes: []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: OPN + 0x4f, 0x50, 0x4e, + // Chunk Type: Final + 0x46, + // MessageSize: 131 + 0x83, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // AsymmetricSecurityHeader + // SecurityPolicyURILength + 0x2e, 0x00, 0x00, 0x00, + // SecurityPolicyURI + 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x67, + 0x6f, 0x70, 0x63, 0x75, 0x61, 0x2e, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x4f, 0x50, + 0x43, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x63, 0x75, + 0x72, 0x69, 0x74, 0x79, 0x50, 0x6f, 0x6c, 0x69, + 0x63, 0x79, 0x23, 0x46, 0x6f, 0x6f, + // SenderCertificate + 0xff, 0xff, 0xff, 0xff, + // ReceiverCertificateThumbprint + 0xff, 0xff, 0xff, 0xff, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xbe, 0x01, + + // RequestHeader + // - AuthenticationToken + 0x00, 0x00, + // - Timestamp + 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, 0xd4, 0x01, + // - RequestHandle + 0x01, 0x00, 0x00, 0x00, + // - ReturnDiagnostics + 0xff, 0x03, 0x00, 0x00, + // - AuditEntry + 0xff, 0xff, 0xff, 0xff, + // - TimeoutHint + 0x00, 0x00, 0x00, 0x00, + // - AdditionalHeader + // - TypeID + 0x00, 0x00, + // - EncodingMask + 0x00, + // ClientProtocolVersion + 0x00, 0x00, 0x00, 0x00, + // SecurityTokenRequestType + 0x00, 0x00, 0x00, 0x00, + // MessageSecurityMode + 0x01, 0x00, 0x00, 0x00, + // ClientNonce + 0xff, 0xff, 0xff, 0xff, + // RequestedLifetime + 0x80, 0x8d, 0x5b, 0x00, + }, + }, + { + Name: "MSG", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.GetEndpointsRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + EndpointURL: "opc.tcp://wow.its.easy:11111/UA/Server", + }, + id.GetEndpointsRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 107 + + return m + }(), + Bytes: []byte{ // GetEndpointsRequest + // Message Header + // MessageType: MSG + 0x4d, 0x53, 0x47, + // Chunk Type: Final + 0x46, + // MessageSize: 107 + 0x6b, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // SymmetricSecurityHeader + // TokenID + 0x00, 0x00, 0x00, 0x00, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xac, 0x01, + // RequestHeader + 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, + 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + // ClientProtocolVersion + 0x26, 0x00, 0x00, 0x00, 0x6f, 0x70, 0x63, 0x2e, + 0x74, 0x63, 0x70, 0x3a, 0x2f, 0x2f, 0x77, 0x6f, + 0x77, 0x2e, 0x69, 0x74, 0x73, 0x2e, 0x65, 0x61, + 0x73, 0x79, 0x3a, 0x31, 0x31, 0x31, 0x31, 0x31, + 0x2f, 0x55, 0x41, 0x2f, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, + // LocaleIDs + 0xff, 0xff, 0xff, 0xff, + // ProfileURIs + 0xff, 0xff, 0xff, 0xff, + }, + }, { + Name: "CLO", + Struct: func() interface{} { + s := &SecureChannel{ + cfg: &Config{ + SecurityPolicyURI: "http://gopcua.example/OPCUA/SecurityPolicy#Foo", + }, + } + instance := &channelInstance{ + sc: s, + sequenceNumber: 0, + securityTokenID: 0, + } + m := instance.newMessage( + &ua.CloseSecureChannelRequest{ + RequestHeader: &ua.RequestHeader{ + AuthenticationToken: ua.NewTwoByteNodeID(0), + Timestamp: time.Date(2018, time.August, 10, 23, 0, 0, 0, time.UTC), + RequestHandle: 1, + ReturnDiagnostics: 0x03ff, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + }, + id.CloseSecureChannelRequest_Encoding_DefaultBinary, + s.nextRequestID(), + ) + + // set message size manually, since it is computed in Encode + // otherwise, the decode tests failed. + m.Header.MessageSize = 57 + + return m + }(), + Bytes: []byte{ // OpenSecureChannelRequest + // Message Header + // MessageType: CLO + 0x43, 0x4c, 0x4f, + // Chunk Type: Final + 0x46, + // MessageSize: 57 + 0x39, 0x00, 0x00, 0x00, + // SecureChannelID: 0 + 0x00, 0x00, 0x00, 0x00, + // SymmetricSecurityHeader + // TokenID + 0x00, 0x00, 0x00, 0x00, + // Sequence Header + // SequenceNumber + 0x01, 0x00, 0x00, 0x00, + // RequestID + 0x01, 0x00, 0x00, 0x00, + // TypeID + 0x01, 0x00, 0xc4, 0x01, + // RequestHeader + 0x00, 0x00, 0x00, 0x98, 0x67, 0xdd, 0xfd, 0x30, + 0xd4, 0x01, 0x01, 0x00, 0x00, 0x00, 0xff, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range cases { + _, err := codec.Marshal(tc.Struct) + if err != nil { + b.Fatalf("fail to encode message, err: %v", err) + } + } + } +} diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 78822642..10e0cf46 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -771,7 +771,7 @@ func (s *SecureChannel) sendAsyncWithTimeout( s.handlersMu.Unlock() } - chunks, err := m.EncodeChunks(instance.maxBodySize) + chunks, err := m.MarshalChunks(instance.maxBodySize) if err != nil { return nil, err } diff --git a/uasc/secure_channel_instance.go b/uasc/secure_channel_instance.go index bf123eb8..0a76f46d 100644 --- a/uasc/secure_channel_instance.go +++ b/uasc/secure_channel_instance.go @@ -71,6 +71,7 @@ func (c *channelInstance) newRequestMessage(req ua.Request, reqID uint32, authTo AuthenticationToken: authToken, Timestamp: c.sc.timeNow(), RequestHandle: reqID, // TODO: can I cheat like this? + AdditionalHeader: ua.NewExtensionObject(nil), } if timeout > 0 && timeout < c.sc.cfg.RequestTimeout {