Skip to content

Commit

Permalink
Merge pull request #654 from gopcua/issue-640-panic-in-secure-channel
Browse files Browse the repository at this point in the history
Fix races to c.Session() and c.SecureChannel()
  • Loading branch information
magiconair authored May 25, 2023
2 parents 21b0a2c + 611b2d7 commit 0b5ef64
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
41 changes: 23 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ func (c *Client) Connect(ctx context.Context) error {
return c.cfgerr
}

// todo(fs): the secure channel is 'nil' during a re-connect
// todo(fs): but we expect this method to be called once during startup
// todo(fs): so this is probably safe
if c.SecureChannel() != nil {
return errors.Errorf("already connected")
}
Expand Down Expand Up @@ -605,8 +608,8 @@ func (c *Client) CloseWithContext(ctx context.Context) error {
if c.mcancel != nil {
c.mcancel()
}
if c.SecureChannel() != nil {
c.SecureChannel().Close()
if sc := c.SecureChannel(); sc != nil {
sc.Close()
c.setSecureChannel(nil)
}

Expand Down Expand Up @@ -657,6 +660,8 @@ func (c *Client) setPublishTimeout(d time.Duration) {
}

// SecureChannel returns the active secure channel.
// During reconnect this value can change.
// Make sure to capture the value in a method before using it.
func (c *Client) SecureChannel() *uasc.SecureChannel {
return c.atomicSechan.Load().(*uasc.SecureChannel)
}
Expand All @@ -667,6 +672,8 @@ func (c *Client) setSecureChannel(sc *uasc.SecureChannel) {
}

// Session returns the active session.
// During reconnect this value can change.
// Make sure to capture the value in a method before using it.
func (c *Client) Session() *Session {
return c.atomicSession.Load().(*Session)
}
Expand All @@ -676,11 +683,6 @@ func (c *Client) setSession(s *Session) {
stats.Client().Add("Session", 1)
}

// sessionClosed returns true when there is no session.
func (c *Client) sessionClosed() bool {
return c.Session() == nil
}

// Session is a OPC/UA session as described in Part 4, 5.6.
type Session struct {
cfg *uasc.SessionConfig
Expand Down Expand Up @@ -727,7 +729,8 @@ func (c *Client) CreateSession(cfg *uasc.SessionConfig) (*Session, error) {

// Note: Starting with v0.5 this method is superseded by the non 'WithContext' method.
func (c *Client) CreateSessionWithContext(ctx context.Context, cfg *uasc.SessionConfig) (*Session, error) {
if c.SecureChannel() == nil {
sc := c.SecureChannel()
if sc == nil {
return nil, ua.StatusBadServerNotConnected
}

Expand All @@ -752,14 +755,14 @@ func (c *Client) CreateSessionWithContext(ctx context.Context, cfg *uasc.Session

var s *Session
// for the CreateSessionRequest the authToken is always nil.
// use c.SecureChannel().SendRequest() to enforce this.
err := c.SecureChannel().SendRequestWithContext(ctx, req, nil, func(v interface{}) error {
// use sc.SendRequest() to enforce this.
err := sc.SendRequestWithContext(ctx, req, nil, func(v interface{}) error {
var res *ua.CreateSessionResponse
if err := safeAssign(v, &res); err != nil {
return err
}

err := c.SecureChannel().VerifySessionSignature(res.ServerCertificate, nonce, res.ServerSignature.Signature)
err := sc.VerifySessionSignature(res.ServerCertificate, nonce, res.ServerSignature.Signature)
if err != nil {
log.Printf("error verifying session signature: %s", err)
return nil
Expand Down Expand Up @@ -820,11 +823,12 @@ func (c *Client) ActivateSession(s *Session) error {

// Note: Starting with v0.5 this method is superseded by the non 'WithContext' method.
func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) error {
if c.SecureChannel() == nil {
sc := c.SecureChannel()
if sc == nil {
return ua.StatusBadServerNotConnected
}
stats.Client().Add("ActivateSession", 1)
sig, sigAlg, err := c.SecureChannel().NewSessionSignature(s.serverCertificate, s.serverNonce)
sig, sigAlg, err := sc.NewSessionSignature(s.serverCertificate, s.serverNonce)
if err != nil {
log.Printf("error creating session signature: %s", err)
return nil
Expand All @@ -835,7 +839,7 @@ func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) err
// nothing to do

case *ua.UserNameIdentityToken:
pass, passAlg, err := c.SecureChannel().EncryptUserPassword(s.cfg.AuthPolicyURI, s.cfg.AuthPassword, s.serverCertificate, s.serverNonce)
pass, passAlg, err := sc.EncryptUserPassword(s.cfg.AuthPolicyURI, s.cfg.AuthPassword, s.serverCertificate, s.serverNonce)
if err != nil {
log.Printf("error encrypting user password: %s", err)
return err
Expand All @@ -844,7 +848,7 @@ func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) err
tok.EncryptionAlgorithm = passAlg

case *ua.X509IdentityToken:
tokSig, tokSigAlg, err := c.SecureChannel().NewUserTokenSignature(s.cfg.AuthPolicyURI, s.serverCertificate, s.serverNonce)
tokSig, tokSigAlg, err := sc.NewUserTokenSignature(s.cfg.AuthPolicyURI, s.serverCertificate, s.serverNonce)
if err != nil {
log.Printf("error creating session signature: %s", err)
return err
Expand All @@ -868,7 +872,7 @@ func (c *Client) ActivateSessionWithContext(ctx context.Context, s *Session) err
UserIdentityToken: ua.NewExtensionObject(s.cfg.UserIdentityToken),
UserTokenSignature: s.cfg.UserTokenSignature,
}
return c.SecureChannel().SendRequestWithContext(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error {
return sc.SendRequestWithContext(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error {
var res *ua.ActivateSessionResponse
if err := safeAssign(v, &res); err != nil {
return err
Expand Down Expand Up @@ -965,14 +969,15 @@ func (c *Client) SendWithContext(ctx context.Context, req ua.Request, h func(int
// the response. If the client has an active session it injects the
// authentication token.
func (c *Client) sendWithTimeout(ctx context.Context, req ua.Request, timeout time.Duration, h func(interface{}) error) error {
if c.SecureChannel() == nil {
sc := c.SecureChannel()
if sc == nil {
return ua.StatusBadServerNotConnected
}
var authToken *ua.NodeID
if s := c.Session(); s != nil {
authToken = s.resp.AuthenticationToken
}
return c.SecureChannel().SendRequestWithTimeoutWithContext(ctx, req, authToken, timeout, h)
return sc.SendRequestWithTimeoutWithContext(ctx, req, authToken, timeout, h)
}

// Node returns a node object which accesses its attributes
Expand Down
11 changes: 9 additions & 2 deletions client_sub.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,21 @@ func (c *Client) sendRepublishRequests(ctx context.Context, sub *Subscription, a
req.RetransmitSequenceNumber,
)

if c.sessionClosed() {
s := c.Session()
if s == nil {
debug.Printf("Republishing subscription %d aborted", req.SubscriptionID)
return ua.StatusBadSessionClosed
}

sc := c.SecureChannel()
if sc == nil {
debug.Printf("Republishing subscription %d aborted", req.SubscriptionID)
return ua.StatusBadNotConnected
}

debug.Printf("RepublishRequest: req=%s", debug.ToJSON(req))
var res *ua.RepublishResponse
err := c.SecureChannel().SendRequestWithContext(ctx, req, c.Session().resp.AuthenticationToken, func(v interface{}) error {
err := sc.SendRequestWithContext(ctx, req, s.resp.AuthenticationToken, func(v interface{}) error {
return safeAssign(v, &res)
})
debug.Printf("RepublishResponse: res=%s err=%v", debug.ToJSON(res), err)
Expand Down

0 comments on commit 0b5ef64

Please sign in to comment.