From f404679eefa8bfcc5c91b36b524e4394db6d67cf Mon Sep 17 00:00:00 2001 From: David Hutchison Date: Sun, 26 May 2019 15:05:30 -0600 Subject: [PATCH 1/5] move recv() loop from uasc into opcua package In preparation for a server implmentation, move the main read loop out of the uasc package and allow the caller (client / server) to receive unsolicited uasc messages. --- client.go | 40 ++- examples/accesslevel/accesslevel.go | 3 +- examples/browse/browse.go | 3 +- examples/crypto/crypto.go | 5 +- examples/datetime/datetime.go | 1 + examples/history-read/history-read.go | 3 +- examples/read/read.go | 3 +- examples/write/write.go | 3 +- uapolicy/cert_utils.go | 16 +- uasc/secure_channel.go | 479 ++++++++++++++++++-------- 10 files changed, 390 insertions(+), 166 deletions(-) diff --git a/client.go b/client.go index be3226e9..21dc7e85 100644 --- a/client.go +++ b/client.go @@ -25,7 +25,7 @@ import ( // GetEndpoints returns the available endpoint descriptions for the server. func GetEndpoints(endpoint string) ([]*ua.EndpointDescription, error) { c := NewClient(endpoint) - if err := c.Dial(); err != nil { + if err := c.Dial(context.Background()); err != nil { return nil, err } defer c.Close() @@ -127,11 +127,15 @@ func NewClient(endpoint string, opts ...Option) *Client { } // Connect establishes a secure channel and creates a new session. -func (c *Client) Connect() (err error) { +func (c *Client) Connect(ctx context.Context) (err error) { + if ctx == nil { + ctx = context.Background() + } + if c.sechan != nil { return fmt.Errorf("already connected") } - if err := c.Dial(); err != nil { + if err := c.Dial(ctx); err != nil { return err } s, err := c.CreateSession(c.sessionCfg) @@ -147,12 +151,16 @@ func (c *Client) Connect() (err error) { } // Dial establishes a secure channel. -func (c *Client) Dial() error { +func (c *Client) Dial(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + c.once.Do(func() { c.session.Store((*Session)(nil)) }) if c.sechan != nil { return fmt.Errorf("secure channel already connected") } - conn, err := uacp.Dial(context.Background(), c.endpointURL) + conn, err := uacp.Dial(ctx, c.endpointURL) if err != nil { return err } @@ -161,14 +169,34 @@ func (c *Client) Dial() error { _ = conn.Close() return err } + c.sechan = sechan + go c.monitorChannel(ctx) + if err := sechan.Open(); err != nil { _ = conn.Close() + c.sechan = nil return err } - c.sechan = sechan + return nil } +func (c *Client) monitorChannel(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + default: + msg := c.sechan.Receive(ctx) + debug.Printf("Received unsolicited message from server: %T", msg.V) + if msg.Err != nil { + debug.Printf("Received error: %s", msg.Err) + return msg.Err + } + } + } +} + // Close closes the session and the secure channel. func (c *Client) Close() error { // try to close the session but ignore any error diff --git a/examples/accesslevel/accesslevel.go b/examples/accesslevel/accesslevel.go index 7cc1c9a2..d3c5c31a 100644 --- a/examples/accesslevel/accesslevel.go +++ b/examples/accesslevel/accesslevel.go @@ -5,6 +5,7 @@ package main import ( + "context" "flag" "log" @@ -23,7 +24,7 @@ func main() { log.SetFlags(0) c := opcua.NewClient(*endpoint) - if err := c.Connect(); err != nil { + if err := c.Connect(context.Background()); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/browse/browse.go b/examples/browse/browse.go index d3262ac9..782d8d38 100644 --- a/examples/browse/browse.go +++ b/examples/browse/browse.go @@ -5,6 +5,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -53,7 +54,7 @@ func main() { log.SetFlags(0) c := opcua.NewClient(*endpoint) - if err := c.Connect(); err != nil { + if err := c.Connect(context.Background()); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/crypto/crypto.go b/examples/crypto/crypto.go index ff6342c4..ca05dcc7 100644 --- a/examples/crypto/crypto.go +++ b/examples/crypto/crypto.go @@ -6,6 +6,7 @@ package main import ( "bufio" + "context" "crypto/rsa" "crypto/tls" "flag" @@ -58,7 +59,7 @@ func main() { // Create a Client with the selected options c := opcua.NewClient(*endpoint, opts...) - if err := c.Connect(); err != nil { + if err := c.Connect(context.Background()); err != nil { log.Fatal(err) } defer c.Close() @@ -83,7 +84,7 @@ func main() { d := opcua.NewClient(*endpoint, opts...) // Create a channel only and do not activate it automatically - d.Dial() + d.Dial(context.Background()) defer d.Close() // Activate the previous session on the new channel diff --git a/examples/datetime/datetime.go b/examples/datetime/datetime.go index 7a9b3491..ba3836be 100644 --- a/examples/datetime/datetime.go +++ b/examples/datetime/datetime.go @@ -5,6 +5,7 @@ package main import ( + "context" "flag" "fmt" "log" diff --git a/examples/history-read/history-read.go b/examples/history-read/history-read.go index 21469587..233b9b4b 100644 --- a/examples/history-read/history-read.go +++ b/examples/history-read/history-read.go @@ -5,6 +5,7 @@ package main import ( + "context" "flag" "log" "time" @@ -22,7 +23,7 @@ func main() { log.SetFlags(0) c := opcua.NewClient(*endpoint) - if err := c.Connect(); err != nil { + if err := c.Connect(context.Background()); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/read/read.go b/examples/read/read.go index f62428c2..2dd45f71 100644 --- a/examples/read/read.go +++ b/examples/read/read.go @@ -5,6 +5,7 @@ package main import ( + "context" "flag" "log" @@ -23,7 +24,7 @@ func main() { log.SetFlags(0) c := opcua.NewClient(*endpoint, opcua.SecurityMode(ua.MessageSecurityModeNone)) - if err := c.Connect(); err != nil { + if err := c.Connect(context.Background()); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/write/write.go b/examples/write/write.go index e38597bc..2ec73602 100644 --- a/examples/write/write.go +++ b/examples/write/write.go @@ -5,6 +5,7 @@ package main import ( + "context" "flag" "log" @@ -24,7 +25,7 @@ func main() { log.SetFlags(0) c := opcua.NewClient(*endpoint) - if err := c.Connect(); err != nil { + if err := c.Connect(context.Background()); err != nil { log.Fatal(err) } defer c.Close() diff --git a/uapolicy/cert_utils.go b/uapolicy/cert_utils.go index cfadddf0..4471c1d0 100644 --- a/uapolicy/cert_utils.go +++ b/uapolicy/cert_utils.go @@ -4,7 +4,11 @@ package uapolicy -import "crypto/sha1" +import ( + "crypto/rsa" + "crypto/sha1" + "crypto/x509" +) // Thumbprint returns the thumbprint of a DER-encoded certificate func Thumbprint(c []byte) []byte { @@ -12,3 +16,13 @@ func Thumbprint(c []byte) []byte { return thumbprint[:] } + +// PublicKey returns the RSA PublicKey from a DER-encoded certificate +func PublicKey(c []byte) (*rsa.PublicKey, error) { + cert, err := x509.ParseCertificate(c) + if err != nil { + return nil, err + } + + return cert.PublicKey.(*rsa.PublicKey), nil +} diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 8f609690..37798f26 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -5,6 +5,7 @@ package uasc import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -32,8 +33,10 @@ const ( ) type Response struct { - V interface{} - Err error + ReqID uint32 + SCID uint32 + V interface{} + Err error } type SecureChannel struct { @@ -48,9 +51,6 @@ type SecureChannel struct { // reqhdr is the header for the next request. reqhdr *ua.RequestHeader - // quit signals the termination of the recv loop. - quit chan struct{} - // state is the state of the secure channel. // Must be accessed with atomic.LoadInt32/StoreInt32 state int32 @@ -61,6 +61,8 @@ type SecureChannel struct { mu sync.Mutex handler map[uint32]chan Response + chunks map[uint32][]*MessageChunk + enc *uapolicy.EncryptionAlgorithm } @@ -86,9 +88,6 @@ func NewSecureChannel(endpoint string, c *uacp.Conn, cfg *Config) (*SecureChanne cfg.SecurityMode = ua.MessageSecurityModeNone } - // always reset the secure channel id - cfg.SecureChannelID = 0 - return &SecureChannel{ EndpointURL: endpoint, c: c, @@ -100,24 +99,11 @@ func NewSecureChannel(endpoint string, c *uacp.Conn, cfg *Config) (*SecureChanne AdditionalHeader: ua.NewExtensionObject(nil), }, state: secureChannelCreated, - quit: make(chan struct{}), handler: make(map[uint32]chan Response), + chunks: make(map[uint32][]*MessageChunk), }, nil } -func (s *SecureChannel) Open() error { - go s.recv() - return s.openSecureChannel() -} - -func (s *SecureChannel) Close() error { - if err := s.closeSecureChannel(); err != nil { - log.Print("failed to send close secure channel request") - } - close(s.quit) - return s.c.Close() -} - func (s *SecureChannel) LocalEndpoint() string { return s.EndpointURL } @@ -130,78 +116,6 @@ func (s *SecureChannel) hasState(n int32) bool { return atomic.LoadInt32(&s.state) == n } -func (s *SecureChannel) openSecureChannel() error { - var err error - var localKey *rsa.PrivateKey - var remoteKey *rsa.PublicKey - - // Set the encryption methods to Asymmetric with the appropriate - // public keys. OpenSecureChannel is always encrypted with the - // asymmetric algorithms. - // The default value of the encryption algorithm method is the - // SecurityModeNone so no additional work is required for that case - if s.cfg.SecurityMode != ua.MessageSecurityModeNone { - localKey = s.cfg.LocalKey - // todo(dh): move this into the uapolicy package proper or - // adjust the Asymmetric method to receive a certificate instead - remoteCert, err := x509.ParseCertificate(s.cfg.RemoteCertificate) - if err != nil { - return err - } - var ok bool - remoteKey, ok = remoteCert.PublicKey.(*rsa.PublicKey) - if !ok { - return ua.StatusBadCertificateInvalid - } - } - - s.enc, err = uapolicy.Asymmetric(s.cfg.SecurityPolicyURI, localKey, remoteKey) - if err != nil { - return err - } - - nonce := make([]byte, s.enc.NonceLength()) - if _, err := rand.Read(nonce); err != nil { - return err - } - - req := &ua.OpenSecureChannelRequest{ - ClientProtocolVersion: 0, - RequestType: ua.SecurityTokenRequestTypeIssue, - SecurityMode: s.cfg.SecurityMode, - ClientNonce: nonce, - RequestedLifetime: s.cfg.Lifetime, - } - - return s.Send(req, nil, func(v interface{}) error { - resp, ok := v.(*ua.OpenSecureChannelResponse) - if !ok { - return fmt.Errorf("got %T, want OpenSecureChannelResponse", req) - } - s.cfg.SecurityTokenID = resp.SecurityToken.TokenID - - s.enc, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, nonce, resp.ServerNonce) - if err != nil { - return err - } - - s.setState(secureChannelOpen) - return nil - }) -} - -// closeSecureChannel sends CloseSecureChannelRequest on top of UASC to SecureChannel. -func (s *SecureChannel) closeSecureChannel() error { - req := &ua.CloseSecureChannelRequest{} - - defer s.setState(secureChannelClosed) - // Don't send the CloseSecureChannel message if it was never fully opened (due to ERR, etc) - if !s.hasState(secureChannelOpen) { - return nil - } - return s.Send(req, nil, nil) -} - // Send sends the service request and calls h with the response. func (s *SecureChannel) Send(svc interface{}, authToken *ua.NodeID, h func(interface{}) error) error { return s.SendWithTimeout(svc, authToken, s.cfg.RequestTimeout, h) @@ -210,11 +124,14 @@ func (s *SecureChannel) Send(svc interface{}, authToken *ua.NodeID, h func(inter // SendWithTimeout sends the service request and calls h with the response with a specific timeout. func (s *SecureChannel) SendWithTimeout(svc interface{}, authToken *ua.NodeID, timeout time.Duration, h func(interface{}) error) error { ch, reqid, err := s.sendAsyncWithTimeout(svc, authToken, timeout) + respRequired := !(h == nil) + + ch, reqid, err := s.SendAsync(svc, authToken, respRequired) if err != nil { return err } - if h == nil { + if !respRequired { return nil } @@ -237,13 +154,13 @@ func (s *SecureChannel) SendWithTimeout(svc interface{}, authToken *ua.NodeID, t // SendAsync sends the service request and returns a channel which will receive the // response when it arrives. -func (s *SecureChannel) SendAsync(svc interface{}, authToken *ua.NodeID) (resp chan Response, reqID uint32, err error) { - return s.sendAsyncWithTimeout(svc, authToken, s.cfg.RequestTimeout) +func (s *SecureChannel) SendAsync(svc interface{}, authToken *ua.NodeID, respReq bool) (resp chan Response, reqID uint32, err error) { + return s.sendAsyncWithTimeout(svc, authToken, respReq, s.cfg.RequestTimeout) } // sendAsyncWithTimeout sends the service request with a specific timeout and returns a channel which will receive the // response when it arrives. -func (s *SecureChannel) sendAsyncWithTimeout(svc interface{}, authToken *ua.NodeID, timeout time.Duration) (resp chan Response, reqID uint32, err error) { +func (s *SecureChannel) sendAsyncWithTimeout(svc interface{}, authToken *ua.NodeID, respReq bool, timeout time.Duration) (resp chan Response, reqID uint32, err error) { typeID := ua.ServiceTypeID(svc) if typeID == 0 { return nil, 0, fmt.Errorf("unknown service %T. Did you call register?", svc) @@ -255,17 +172,20 @@ func (s *SecureChannel) sendAsyncWithTimeout(svc interface{}, authToken *ua.Node s.mu.Lock() // the request header is always the first field val := reflect.ValueOf(svc) - val.Elem().Field(0).Set(reflect.ValueOf(s.reqhdr)) - // update counters + rHdr := val.Elem().Field(0) s.cfg.SequenceNumber++ - s.cfg.RequestID++ - s.reqhdr.AuthenticationToken = authToken - s.reqhdr.RequestHandle++ - s.reqhdr.Timestamp = time.Now() - if timeout > 0 && timeout < s.cfg.RequestTimeout { - timeout = s.cfg.RequestTimeout + if _, ok := rHdr.Interface().(*ua.RequestHeader); ok { + val.Elem().Field(0).Set(reflect.ValueOf(s.reqhdr)) + + s.reqhdr.AuthenticationToken = authToken + s.cfg.RequestID++ + s.reqhdr.RequestHandle++ + s.reqhdr.Timestamp = time.Now() + if timeout > 0 && timeout < s.cfg.RequestTimeout { + timeout = s.cfg.RequestTimeout + } + s.reqhdr.TimeoutHint = uint32(timeout / time.Millisecond) } - s.reqhdr.TimeoutHint = uint32(timeout / time.Millisecond) // encode the message m := NewMessage(svc, typeID, s.cfg) @@ -287,9 +207,12 @@ func (s *SecureChannel) sendAsyncWithTimeout(svc interface{}, authToken *ua.Node if _, err := s.c.Write(b); err != nil { return nil, reqid, err } - debug.Printf("conn %d/%d: send %T with %d bytes", s.c.ID(), reqid, svc, len(b)) + debug.Printf("uasc %d/%d: send %T with %d bytes", s.c.ID(), reqid, svc, len(b)) - // register the handler + // register the handler if a callback was passed + if !respReq { + return nil, 0, nil + } resp = make(chan Response) s.mu.Lock() if s.handler[reqid] != nil { @@ -301,7 +224,7 @@ func (s *SecureChannel) sendAsyncWithTimeout(svc interface{}, authToken *ua.Node return resp, reqid, nil } -func (s *SecureChannel) readchunk() (*MessageChunk, error) { +func (s *SecureChannel) readChunk() (*MessageChunk, error) { // read a full message from the underlying conn. b, err := s.c.Receive() if err == io.EOF || s.hasState(secureChannelClosed) { @@ -320,18 +243,57 @@ func (s *SecureChannel) readchunk() (*MessageChunk, error) { return nil, fmt.Errorf("sechan: decode header failed: %s", err) } - // drop if the channel id does not match - if s.cfg.SecureChannelID > 0 && s.cfg.SecureChannelID != h.SecureChannelID { - return nil, fmt.Errorf("sechan: secure channel id mismatch: got 0x%04x, want 0x%04x", h.SecureChannelID, s.cfg.SecureChannelID) - } - // decode the other headers m := new(MessageChunk) if _, err := m.Decode(b); err != nil { return nil, fmt.Errorf("sechan: decode chunk failed: %s", err) } - // decrypt the block + // OPN Request, initialize encryption + // todo(dh): How to account for renew requests? + switch m.MessageType { + case "OPN": + debug.Printf("uasc: OPN Request") + // Make sure we have a valid security header + if m.AsymmetricSecurityHeader == nil { + return nil, ua.StatusBadDecodingError // todo(dh): check if this is the correct error + } + + // Load the remote certificates from the security header, if present + var remoteKey *rsa.PublicKey + if m.SecurityPolicyURI != ua.SecurityPolicyURINone { + remoteKey, err = uapolicy.PublicKey(m.AsymmetricSecurityHeader.SenderCertificate) + if err != nil { + return nil, err + } + + s.cfg.RemoteCertificate = m.AsymmetricSecurityHeader.SenderCertificate + debug.Printf("Setting securityPolicy to %s", m.SecurityPolicyURI) + } + + s.cfg.SecurityPolicyURI = m.SecurityPolicyURI + s.cfg.RequestID = m.RequestID + + s.enc, err = uapolicy.Asymmetric(m.SecurityPolicyURI, s.cfg.LocalKey, remoteKey) + if err != nil { + return nil, err + } + + case "CLO": + if !s.hasState(secureChannelOpen) { + return nil, ua.StatusBadSecureChannelIDInvalid + } + + // We received the close request so no response is necessary. + // Returning io.EOF signals to the calling methods that the channel is to be shut down + s.setState(secureChannelClosed) + + return nil, io.EOF + + case "MSG": + } + + // Decrypts the block and returns data back into m.Data m.Data, err = s.verifyAndDecrypt(m, b) if err != nil { return nil, err @@ -345,32 +307,105 @@ func (s *SecureChannel) readchunk() (*MessageChunk, error) { if s.cfg.SecureChannelID == 0 { s.cfg.SecureChannelID = h.SecureChannelID - debug.Printf("conn %d/%d: set secure channel id to %d", s.c.ID(), m.SequenceHeader.RequestID, s.cfg.SecureChannelID) + debug.Printf("uasc %d/%d: set secure channel id to %d", s.c.ID(), m.SequenceHeader.RequestID, s.cfg.SecureChannelID) } return m, nil } -// recv receives message chunks from the secure channel, decodes and forwards +// Receive waits for a complete message to be read from the channel and +// sends it back to the caller. If the caller was initiated from a +// Send(), the message is directed to the registered callback function +// and Receive() does not return. Otherwise, if no handler is detected, +// the Receive returns with the message as a return value. +// This behaviour means that anticipated results are automatically directed back to +// their callers but unsolicited messages are sent to the caller of +// Receive() to handle. +func (s *SecureChannel) Receive(ctx context.Context) Response { + for { + select { + case <-ctx.Done(): + return Response{Err: io.EOF} + default: + reqid, svc, err := s.receive(ctx) + if _, ok := err.(*uacp.Error); ok || err == io.EOF { + // todo: notifyCaller has been deprecated, but how else to purge all pending callbacks? + s.notifyCallers(err) + s.Close() + return Response{ + ReqID: reqid, + SCID: s.cfg.SecureChannelID, + V: svc, + Err: err, + } + } + if err != nil { + debug.Printf("uasc %d/%d: err: %v", s.c.ID(), reqid, err) + } else { + debug.Printf("uasc %d/%d: recv %T", s.c.ID(), reqid, svc) + } + + // todo: validate request ID / check that it is increasing correctly + s.cfg.RequestID = reqid + + switch svc.(type) { + case *ua.OpenSecureChannelRequest: + err := s.handleOpenSecureChannelRequest(svc) + if err != nil { + return Response{ + Err: err, + } + } + continue + } + + // check if we have a pending request handler for this response. + s.mu.Lock() + ch, ok := s.handler[reqid] + delete(s.handler, reqid) + s.mu.Unlock() + if !ok { + debug.Printf("uasc %d/%d: no handler for %T, returning result to caller", s.c.ID(), reqid, svc) + return Response{ + ReqID: reqid, + SCID: s.cfg.SecureChannelID, + V: svc, + Err: err, + } + } + + // send response to caller + go func() { + debug.Printf("sending %T to handler\n", svc) + ch <- Response{ + ReqID: reqid, + SCID: s.cfg.SecureChannelID, + V: svc, + Err: err, + } + }() + } + } +} + +// receive receives message chunks from the secure channel, decodes and forwards // them to the registered callback channel, if there is one. Otherwise, // the message is dropped. -func (s *SecureChannel) recv() { - // chunks maps request id to message chunks - chunks := map[uint32][]*MessageChunk{} +func (s *SecureChannel) receive(ctx context.Context) (uint32, interface{}, error) { for { select { - case <-s.quit: - return + case <-ctx.Done(): + return 0, nil, nil default: - chunk, err := s.readchunk() + chunk, err := s.readChunk() if err == io.EOF { - return + return 0, nil, err } if errf, ok := err.(*uacp.Error); ok { s.notifyCallers(errf) - continue + return 0, nil, errf } if err != nil { debug.Printf("error received while receiving chunk: %s", err) @@ -379,43 +414,39 @@ func (s *SecureChannel) recv() { hdr := chunk.Header reqid := chunk.SequenceHeader.RequestID - debug.Printf("conn %d/%d: recv %s%c with %d bytes", s.c.ID(), reqid, hdr.MessageType, hdr.ChunkType, hdr.MessageSize) + debug.Printf("uasc %d/%d: recv %s%c with %d bytes", s.c.ID(), reqid, hdr.MessageType, hdr.ChunkType, hdr.MessageSize) switch hdr.ChunkType { case 'A': - delete(chunks, reqid) + delete(s.chunks, reqid) msga := new(MessageAbort) if _, err := msga.Decode(chunk.Data); err != nil { debug.Printf("conn %d/%d: invalid MSGA chunk. %s", s.c.ID(), reqid, err) - s.notifyCaller(reqid, nil, ua.StatusBadDecodingError) - continue + return reqid, nil, ua.StatusBadDecodingError } - s.notifyCaller(reqid, nil, ua.StatusCode(msga.ErrorCode)) - continue + return reqid, nil, ua.StatusCode(msga.ErrorCode) case 'C': - chunks[reqid] = append(chunks[reqid], chunk) - if n := len(chunks[reqid]); uint32(n) > s.c.MaxChunkCount() { - delete(chunks, reqid) - s.notifyCaller(reqid, nil, fmt.Errorf("too many chunks: %d > %d", n, s.c.MaxChunkCount())) + s.chunks[reqid] = append(s.chunks[reqid], chunk) + if n := len(s.chunks[reqid]); uint32(n) > s.c.MaxChunkCount() { + delete(s.chunks, reqid) + return reqid, nil, fmt.Errorf("too many chunks: %d > %d", n, s.c.MaxChunkCount()) } continue } // merge chunks - all := append(chunks[reqid], chunk) - delete(chunks, reqid) + all := append(s.chunks[reqid], chunk) + delete(s.chunks, reqid) b, err := mergeChunks(all) if err != nil { - s.notifyCaller(reqid, nil, fmt.Errorf("chunk merge error: %v", err)) - continue + return reqid, nil, fmt.Errorf("chunk merge error: %v", err) } if uint32(len(b)) > s.c.MaxMessageSize() { - s.notifyCaller(reqid, nil, fmt.Errorf("message too large: %d > %d", uint32(len(b)), s.c.MaxMessageSize())) - continue + return reqid, nil, fmt.Errorf("message too large: %d > %d", uint32(len(b)), s.c.MaxMessageSize()) } // since we are not decoding the ResponseHeader separately @@ -427,8 +458,7 @@ func (s *SecureChannel) recv() { // handlers and check them periodically to time them out. _, svc, err := ua.DecodeService(b) if err != nil { - s.notifyCaller(reqid, nil, err) - continue + return reqid, nil, err } // extract the ServiceStatus field from the @@ -440,13 +470,12 @@ func (s *SecureChannel) recv() { val := reflect.ValueOf(svc) field0 := val.Elem().Field(0).Interface() if hdr, ok := field0.(*ua.ResponseHeader); ok { - debug.Printf("conn %d/%d: res:%v", s.c.ID(), reqid, hdr.ServiceResult) + debug.Printf("uasc %d/%d: res:%v", s.c.ID(), reqid, hdr.ServiceResult) if hdr.ServiceResult != ua.StatusOK { - s.notifyCaller(reqid, svc, hdr.ServiceResult) - continue + return reqid, svc, hdr.ServiceResult } } - s.notifyCaller(reqid, svc, err) + return reqid, svc, err } } } @@ -471,9 +500,9 @@ func (s *SecureChannel) notifyCaller(reqid uint32, svc interface{}, err error) { func (s *SecureChannel) notifyCallerLock(reqid uint32, svc interface{}, err error) { if err != nil { - debug.Printf("conn %d/%d: %v", s.c.ID(), reqid, err) + debug.Printf("uasc %d/%d: %v", s.c.ID(), reqid, err) } else { - debug.Printf("conn %d/%d: recv %T", s.c.ID(), reqid, svc) + debug.Printf("uasc %d/%d: recv %T", s.c.ID(), reqid, svc) } // check if we have a pending request handler for this response. @@ -481,17 +510,163 @@ func (s *SecureChannel) notifyCallerLock(reqid uint32, svc interface{}, err erro // no handler -> next response if ch == nil { - debug.Printf("conn %d/%d: no handler for %T", s.c.ID(), reqid, svc) + debug.Printf("uasc %d/%d: no handler for %T", s.c.ID(), reqid, svc) return } // send response to caller go func() { - ch <- Response{svc, err} + ch <- Response{ + ReqID: reqid, + SCID: s.cfg.SecureChannelID, + V: svc, + Err: err, + } close(ch) }() } +// Open opens a new secure channel with a server +func (s *SecureChannel) Open() error { + return s.openSecureChannel() +} + +// Close closes an existing secure channel +func (s *SecureChannel) Close() error { + if err := s.closeSecureChannel(); err != nil && err != io.EOF { + debug.Printf("failed to send close secure channel request: %s", err) + } + + if err := s.c.Close(); err != nil && err != io.EOF { + debug.Printf("failed to close transport connection: %s", err) + } + + return io.EOF +} + +func (s *SecureChannel) openSecureChannel() error { + var err error + var localKey *rsa.PrivateKey + var remoteKey *rsa.PublicKey + + // Set the encryption methods to Asymmetric with the appropriate + // public keys. OpenSecureChannel is always encrypted with the + // asymmetric algorithms. + // The default value of the encryption algorithm method is the + // SecurityModeNone so no additional work is required for that case + if s.cfg.SecurityMode != ua.MessageSecurityModeNone { + localKey = s.cfg.LocalKey + // todo(dh): move this into the uapolicy package proper or + // adjust the Asymmetric method to receive a certificate instead + remoteCert, err := x509.ParseCertificate(s.cfg.RemoteCertificate) + if err != nil { + return err + } + var ok bool + remoteKey, ok = remoteCert.PublicKey.(*rsa.PublicKey) + if !ok { + return ua.StatusBadCertificateInvalid + } + } + + s.enc, err = uapolicy.Asymmetric(s.cfg.SecurityPolicyURI, localKey, remoteKey) + if err != nil { + return err + } + + nonce := make([]byte, s.enc.NonceLength()) + if _, err := rand.Read(nonce); err != nil { + return err + } + + req := &ua.OpenSecureChannelRequest{ + ClientProtocolVersion: 0, + RequestType: ua.SecurityTokenRequestTypeIssue, + SecurityMode: s.cfg.SecurityMode, + ClientNonce: nonce, + RequestedLifetime: s.cfg.Lifetime, + } + + return s.Send(req, nil, func(v interface{}) error { + resp, ok := v.(*ua.OpenSecureChannelResponse) + if !ok { + return fmt.Errorf("got %T, want OpenSecureChannelResponse", req) + } + s.cfg.SecurityTokenID = resp.SecurityToken.TokenID + + s.enc, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, nonce, resp.ServerNonce) + if err != nil { + return err + } + + s.setState(secureChannelOpen) + return nil + }) +} + +// closeSecureChannel sends CloseSecureChannelRequest on top of UASC to SecureChannel. +func (s *SecureChannel) closeSecureChannel() error { + req := &ua.CloseSecureChannelRequest{} + + defer s.setState(secureChannelClosed) + // Don't send the CloseSecureChannel message if it was never fully opened (due to ERR, etc) + if !s.hasState(secureChannelOpen) { + return io.EOF + } + + err := s.Send(req, nil, nil) + if err != nil { + return err + } + + return io.EOF +} + +func (s *SecureChannel) handleOpenSecureChannelRequest(svc interface{}) error { + debug.Printf("handleOpenSecureChannelRequest: Got OPN Request\n") + + var err error + + req, ok := svc.(*ua.OpenSecureChannelRequest) + if !ok { + debug.Printf("Expected OpenSecureChannel Request, got %T\n", svc) + } + + s.cfg.Lifetime = req.RequestedLifetime + s.cfg.SecurityMode = req.SecurityMode + + nonce := make([]byte, s.enc.NonceLength()) + if _, err := rand.Read(nonce); err != nil { + return err + } + resp := &ua.OpenSecureChannelResponse{ + ResponseHeader: &ua.ResponseHeader{ + Timestamp: time.Now(), + RequestHandle: req.RequestHeader.RequestHandle, + ServiceDiagnostics: &ua.DiagnosticInfo{}, + StringTable: []string{}, + AdditionalHeader: ua.NewExtensionObject(nil), + }, + ServerProtocolVersion: 0, + SecurityToken: &ua.ChannelSecurityToken{ + ChannelID: s.cfg.SecureChannelID, + TokenID: s.cfg.SecurityTokenID, + CreatedAt: time.Now(), + RevisedLifetime: req.RequestedLifetime, + }, + ServerNonce: nonce, + } + + s.Send(resp, nil, nil) + s.enc, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, nonce, req.ClientNonce) + if err != nil { + return err + } + s.setState(secureChannelOpen) + + return nil +} + func (s *SecureChannel) popHandlerLock(reqid uint32) chan Response { ch := s.handler[reqid] delete(s.handler, reqid) From 015fd7998972e04213205034b7e7b9cd2b11c1da Mon Sep 17 00:00:00 2001 From: David Hutchison Date: Sun, 26 May 2019 20:21:02 -0600 Subject: [PATCH 2/5] modify error handling --- client.go | 17 ++++++++++++----- uasc/secure_channel.go | 6 +++++- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 21dc7e85..45bb3980 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "context" "crypto/rand" "fmt" + "io" "log" "reflect" "sort" @@ -181,18 +182,24 @@ func (c *Client) Dial(ctx context.Context) error { return nil } -func (c *Client) monitorChannel(ctx context.Context) error { +func (c *Client) monitorChannel(ctx context.Context) { for { select { case <-ctx.Done(): - return nil + return default: msg := c.sechan.Receive(ctx) - debug.Printf("Received unsolicited message from server: %T", msg.V) if msg.Err != nil { - debug.Printf("Received error: %s", msg.Err) - return msg.Err + if msg.Err == io.EOF { + debug.Printf("Connection closed") + } else { + debug.Printf("Received error: %s", msg.Err) + } + // todo (dh): apart from the above message, we're ignoring this error because there is nothing watching it + // I'd prefer to have a way to return the error to the upper application. + return } + debug.Printf("Received unsolicited message from server: %T", msg.V) } } } diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 37798f26..4722d79c 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -657,7 +657,11 @@ func (s *SecureChannel) handleOpenSecureChannelRequest(svc interface{}) error { ServerNonce: nonce, } - s.Send(resp, nil, nil) + err = s.Send(resp, nil, nil) + if err != nil { + return err + } + s.enc, err = uapolicy.Symmetric(s.cfg.SecurityPolicyURI, nonce, req.ClientNonce) if err != nil { return err From d6615d061c8225a3d176471735e0849311342e53 Mon Sep 17 00:00:00 2001 From: David Hutchison Date: Mon, 27 May 2019 08:19:07 -0600 Subject: [PATCH 3/5] added additional context checking and minor code cleanup --- uasc/secure_channel.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 4722d79c..982f52ab 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -124,7 +124,7 @@ func (s *SecureChannel) Send(svc interface{}, authToken *ua.NodeID, h func(inter // SendWithTimeout sends the service request and calls h with the response with a specific timeout. func (s *SecureChannel) SendWithTimeout(svc interface{}, authToken *ua.NodeID, timeout time.Duration, h func(interface{}) error) error { ch, reqid, err := s.sendAsyncWithTimeout(svc, authToken, timeout) - respRequired := !(h == nil) + respRequired := h != nil ch, reqid, err := s.SendAsync(svc, authToken, respRequired) if err != nil { @@ -175,7 +175,7 @@ func (s *SecureChannel) sendAsyncWithTimeout(svc interface{}, authToken *ua.Node rHdr := val.Elem().Field(0) s.cfg.SequenceNumber++ if _, ok := rHdr.Interface().(*ua.RequestHeader); ok { - val.Elem().Field(0).Set(reflect.ValueOf(s.reqhdr)) + rHdr.Set(reflect.ValueOf(s.reqhdr)) s.reqhdr.AuthenticationToken = authToken s.cfg.RequestID++ @@ -330,7 +330,7 @@ func (s *SecureChannel) Receive(ctx context.Context) Response { reqid, svc, err := s.receive(ctx) if _, ok := err.(*uacp.Error); ok || err == io.EOF { // todo: notifyCaller has been deprecated, but how else to purge all pending callbacks? - s.notifyCallers(err) + s.notifyCallers(ctx, err) s.Close() return Response{ ReqID: reqid, @@ -377,12 +377,16 @@ func (s *SecureChannel) Receive(ctx context.Context) Response { // send response to caller go func() { debug.Printf("sending %T to handler\n", svc) - ch <- Response{ + r := Response{ ReqID: reqid, SCID: s.cfg.SecureChannelID, V: svc, Err: err, } + select { + case <-ctx.Done(): + case ch <- r: + } }() } } @@ -480,25 +484,25 @@ func (s *SecureChannel) receive(ctx context.Context) (uint32, interface{}, error } } -func (s *SecureChannel) notifyCallers(err error) { +func (s *SecureChannel) notifyCallers(ctx context.Context, err error) { s.mu.Lock() var reqids []uint32 for id := range s.handler { reqids = append(reqids, id) } for _, id := range reqids { - s.notifyCallerLock(id, nil, err) + s.notifyCallerLock(ctx, id, nil, err) } s.mu.Unlock() } -func (s *SecureChannel) notifyCaller(reqid uint32, svc interface{}, err error) { +func (s *SecureChannel) notifyCaller(ctx context.Context, reqid uint32, svc interface{}, err error) { s.mu.Lock() - s.notifyCallerLock(reqid, svc, err) + s.notifyCallerLock(ctx, reqid, svc, err) s.mu.Unlock() } -func (s *SecureChannel) notifyCallerLock(reqid uint32, svc interface{}, err error) { +func (s *SecureChannel) notifyCallerLock(ctx context.Context, reqid uint32, svc interface{}, err error) { if err != nil { debug.Printf("uasc %d/%d: %v", s.c.ID(), reqid, err) } else { @@ -516,12 +520,16 @@ func (s *SecureChannel) notifyCallerLock(reqid uint32, svc interface{}, err erro // send response to caller go func() { - ch <- Response{ + r := Response{ ReqID: reqid, SCID: s.cfg.SecureChannelID, V: svc, Err: err, } + select { + case <-ctx.Done(): + case ch <- r: + } close(ch) }() } From 3b8dfd9c7f1b8dc4124cc443dc655a7afb1ccda5 Mon Sep 17 00:00:00 2001 From: David Hutchison Date: Mon, 27 May 2019 08:26:56 -0600 Subject: [PATCH 4/5] removed the notifyCaller method as it is no longer in use --- uasc/secure_channel.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 982f52ab..70fb91a7 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -496,12 +496,6 @@ func (s *SecureChannel) notifyCallers(ctx context.Context, err error) { s.mu.Unlock() } -func (s *SecureChannel) notifyCaller(ctx context.Context, reqid uint32, svc interface{}, err error) { - s.mu.Lock() - s.notifyCallerLock(ctx, reqid, svc, err) - s.mu.Unlock() -} - func (s *SecureChannel) notifyCallerLock(ctx context.Context, reqid uint32, svc interface{}, err error) { if err != nil { debug.Printf("uasc %d/%d: %v", s.c.ID(), reqid, err) From 024329a15fcfb5920832db3f82a8401474e885c3 Mon Sep 17 00:00:00 2001 From: David Hutchison Date: Wed, 12 Jun 2019 20:27:33 -0600 Subject: [PATCH 5/5] added context variable to examples --- examples/accesslevel/accesslevel.go | 4 +++- examples/browse/browse.go | 4 +++- examples/crypto/crypto.go | 6 ++++-- examples/datetime/datetime.go | 4 +++- examples/history-read/history-read.go | 4 +++- examples/read/read.go | 4 +++- examples/server/server.go | 4 +++- examples/subscribe/subscribe.go | 14 +++++++------- examples/write/write.go | 4 +++- uasc/secure_channel.go | 5 +---- 10 files changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/accesslevel/accesslevel.go b/examples/accesslevel/accesslevel.go index d3c5c31a..65755282 100644 --- a/examples/accesslevel/accesslevel.go +++ b/examples/accesslevel/accesslevel.go @@ -23,8 +23,10 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + c := opcua.NewClient(*endpoint) - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/browse/browse.go b/examples/browse/browse.go index 782d8d38..b35c3c80 100644 --- a/examples/browse/browse.go +++ b/examples/browse/browse.go @@ -53,8 +53,10 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + c := opcua.NewClient(*endpoint) - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/crypto/crypto.go b/examples/crypto/crypto.go index ca05dcc7..7461cb89 100644 --- a/examples/crypto/crypto.go +++ b/examples/crypto/crypto.go @@ -42,6 +42,8 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + // Get a list of the endpoints for our target server endpoints, err := opcua.GetEndpoints(*endpoint) if err != nil { @@ -59,7 +61,7 @@ func main() { // Create a Client with the selected options c := opcua.NewClient(*endpoint, opts...) - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() @@ -84,7 +86,7 @@ func main() { d := opcua.NewClient(*endpoint, opts...) // Create a channel only and do not activate it automatically - d.Dial(context.Background()) + d.Dial(ctx) defer d.Close() // Activate the previous session on the new channel diff --git a/examples/datetime/datetime.go b/examples/datetime/datetime.go index ba3836be..b4ab3292 100644 --- a/examples/datetime/datetime.go +++ b/examples/datetime/datetime.go @@ -25,6 +25,8 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + endpoints, err := opcua.GetEndpoints(*endpoint) if err != nil { log.Fatal(err) @@ -46,7 +48,7 @@ func main() { } c := opcua.NewClient(ep.EndpointURL, opts...) - if err := c.Connect(); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/history-read/history-read.go b/examples/history-read/history-read.go index 233b9b4b..fc509fbf 100644 --- a/examples/history-read/history-read.go +++ b/examples/history-read/history-read.go @@ -22,8 +22,10 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + c := opcua.NewClient(*endpoint) - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/read/read.go b/examples/read/read.go index 2dd45f71..65ad8954 100644 --- a/examples/read/read.go +++ b/examples/read/read.go @@ -23,8 +23,10 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + c := opcua.NewClient(*endpoint, opcua.SecurityMode(ua.MessageSecurityModeNone)) - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() diff --git a/examples/server/server.go b/examples/server/server.go index 96a70ea9..421b85e5 100644 --- a/examples/server/server.go +++ b/examples/server/server.go @@ -18,12 +18,14 @@ func main() { ) flag.Parse() + ctx := context.Background() + log.Printf("Listening on %s", *endpoint) l, err := uacp.Listen(*endpoint, nil) if err != nil { log.Fatal(err) } - c, err := l.Accept(context.Background()) + c, err := l.Accept(ctx) if err != nil { log.Fatal(err) } diff --git a/examples/subscribe/subscribe.go b/examples/subscribe/subscribe.go index c18e62d6..1692de2b 100644 --- a/examples/subscribe/subscribe.go +++ b/examples/subscribe/subscribe.go @@ -29,6 +29,12 @@ func main() { flag.Parse() log.SetFlags(0) + // add an arbitrary timeout to demonstrate how to stop a subscription + // with a context. + d := 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() + endpoints, err := opcua.GetEndpoints(*endpoint) if err != nil { log.Fatal(err) @@ -50,7 +56,7 @@ func main() { } c := opcua.NewClient(ep.EndpointURL, opts...) - if err := c.Connect(); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() @@ -77,12 +83,6 @@ func main() { log.Fatal(err) } - // add an arbitrary timeout to demonstrate how to stop a subscription - // with a context. - d := 30 * time.Second - ctx, cancel := context.WithTimeout(context.Background(), d) - defer cancel() - go sub.Run(ctx) // start Publish loop // read from subscription's notification channel until ctx is cancelled diff --git a/examples/write/write.go b/examples/write/write.go index 2ec73602..1518a9b7 100644 --- a/examples/write/write.go +++ b/examples/write/write.go @@ -24,8 +24,10 @@ func main() { flag.Parse() log.SetFlags(0) + ctx := context.Background() + c := opcua.NewClient(*endpoint) - if err := c.Connect(context.Background()); err != nil { + if err := c.Connect(ctx); err != nil { log.Fatal(err) } defer c.Close() diff --git a/uasc/secure_channel.go b/uasc/secure_channel.go index 70fb91a7..eb6e94b6 100644 --- a/uasc/secure_channel.go +++ b/uasc/secure_channel.go @@ -11,7 +11,6 @@ import ( "crypto/x509" "fmt" "io" - "log" "math" "reflect" "sync" @@ -123,7 +122,6 @@ func (s *SecureChannel) Send(svc interface{}, authToken *ua.NodeID, h func(inter // SendWithTimeout sends the service request and calls h with the response with a specific timeout. func (s *SecureChannel) SendWithTimeout(svc interface{}, authToken *ua.NodeID, timeout time.Duration, h func(interface{}) error) error { - ch, reqid, err := s.sendAsyncWithTimeout(svc, authToken, timeout) respRequired := h != nil ch, reqid, err := s.SendAsync(svc, authToken, respRequired) @@ -329,7 +327,6 @@ func (s *SecureChannel) Receive(ctx context.Context) Response { default: reqid, svc, err := s.receive(ctx) if _, ok := err.(*uacp.Error); ok || err == io.EOF { - // todo: notifyCaller has been deprecated, but how else to purge all pending callbacks? s.notifyCallers(ctx, err) s.Close() return Response{ @@ -408,7 +405,7 @@ func (s *SecureChannel) receive(ctx context.Context) (uint32, interface{}, error return 0, nil, err } if errf, ok := err.(*uacp.Error); ok { - s.notifyCallers(errf) + s.notifyCallers(ctx, errf) return 0, nil, errf } if err != nil {