Skip to content

Commit

Permalink
Merge pull request #539 from gopcua/issue-538-session-id-invalid
Browse files Browse the repository at this point in the history
Fix invalid session id
  • Loading branch information
magiconair authored Jan 4, 2022
2 parents 8b599ae + aa8cd5d commit 2980901
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 52 deletions.
1 change: 0 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,6 @@ func (c *Client) UpdateNamespaces() error {
if err != nil {
return err
}
c.setSession(nil)
c.setNamespaces(ns)
return nil
}
Expand Down
28 changes: 14 additions & 14 deletions uacp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type Dialer struct {
}

func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
debug.Printf("Connecting to %s", endpoint)
debug.Printf("uacp: connecting to %s", endpoint)
_, raddr, err := ResolveEndpoint(endpoint)
if err != nil {
return nil, err
Expand All @@ -88,9 +88,9 @@ func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
return nil, err
}

debug.Printf("conn %d: start HEL/ACK handshake", conn.id)
debug.Printf("uacp %d: start HEL/ACK handshake", conn.id)
if err := conn.Handshake(endpoint); err != nil {
debug.Printf("conn %d: HEL/ACK handshake failed: %s", conn.id, err)
debug.Printf("uacp %d: HEL/ACK handshake failed: %s", conn.id, err)
conn.Close()
return nil, err
}
Expand Down Expand Up @@ -213,7 +213,7 @@ func (c *Conn) Close() (err error) {
}

func (c *Conn) close() error {
debug.Printf("conn %d: close", c.id)
debug.Printf("uacp %d: close", c.id)
return c.TCPConn.Close()
}

Expand Down Expand Up @@ -248,22 +248,22 @@ func (c *Conn) Handshake(endpoint string) error {
}
if ack.MaxChunkCount == 0 {
ack.MaxChunkCount = DefaultMaxChunkCount
debug.Printf("conn %d: server has no chunk limit. Using %d", c.id, ack.MaxChunkCount)
debug.Printf("uacp %d: server has no chunk limit. Using %d", c.id, ack.MaxChunkCount)
}
if ack.MaxMessageSize == 0 {
ack.MaxMessageSize = DefaultMaxMessageSize
debug.Printf("conn %d: server has no message size limit. Using %d", c.id, ack.MaxMessageSize)
debug.Printf("uacp %d: server has no message size limit. Using %d", c.id, ack.MaxMessageSize)
}
c.ack = ack
debug.Printf("conn %d: recv %#v", c.id, ack)
debug.Printf("uacp %d: recv %#v", c.id, ack)
return nil

case "ERRF":
errf := new(Error)
if _, err := errf.Decode(b[hdrlen:]); err != nil {
return errors.Errorf("uacp: decode ERR failed: %s", err)
}
debug.Printf("conn %d: recv %#v", c.id, errf)
debug.Printf("uacp %d: recv %#v", c.id, errf)
return errf

default:
Expand Down Expand Up @@ -297,7 +297,7 @@ func (c *Conn) srvhandshake(endpoint string) error {
c.SendError(ua.StatusBadTCPInternalError)
return err
}
debug.Printf("conn %d: recv %#v", c.id, hel)
debug.Printf("uacp %d: recv %#v", c.id, hel)
return nil

case "RHEF":
Expand All @@ -310,23 +310,23 @@ func (c *Conn) srvhandshake(endpoint string) error {
c.SendError(ua.StatusBadTCPEndpointURLInvalid)
return errors.Errorf("uacp: invalid endpoint url %s", rhe.EndpointURL)
}
debug.Printf("conn %d: connecting to %s", c.id, rhe.ServerURI)
debug.Printf("uacp %d: connecting to %s", c.id, rhe.ServerURI)
c.Close()
var dialer net.Dialer
c2, err := dialer.DialContext(context.Background(), "tcp", rhe.ServerURI)
if err != nil {
return err
}
c.TCPConn = c2.(*net.TCPConn)
debug.Printf("conn %d: recv %#v", c.id, rhe)
debug.Printf("uacp %d: recv %#v", c.id, rhe)
return nil

case "ERRF":
errf := new(Error)
if _, err := errf.Decode(b[hdrlen:]); err != nil {
return errors.Errorf("uacp: decode ERR failed: %s", err)
}
debug.Printf("conn %d: recv %#v", c.id, errf)
debug.Printf("uacp %d: recv %#v", c.id, errf)
return errf

default:
Expand Down Expand Up @@ -367,7 +367,7 @@ func (c *Conn) Receive() ([]byte, error) {
return nil, err
}

debug.Printf("conn %d: recv %s%c with %d bytes", c.id, h.MessageType, h.ChunkType, h.MessageSize)
debug.Printf("uacp %d: recv %s%c with %d bytes", c.id, h.MessageType, h.ChunkType, h.MessageSize)

if h.MessageType == "ERR" {
errf := new(Error)
Expand Down Expand Up @@ -408,7 +408,7 @@ func (c *Conn) Send(typ string, msg interface{}) error {
if _, err := c.Write(b); err != nil {
return errors.Errorf("write failed: %s", err)
}
debug.Printf("conn %d: sent %s with %d bytes", c.id, typ, len(b))
debug.Printf("uacp %d: sent %s with %d bytes", c.id, typ, len(b))

return nil
}
Expand Down
50 changes: 15 additions & 35 deletions uasc/secure_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ func (s *SecureChannel) dispatcher() {
s.rcvLocker.lock()
}

debug.Printf("sending %T to handler\n", resp.V)
debug.Printf("uasc %d/%d: sending %T to handler", s.c.ID(), resp.ReqID, resp.V)
select {
case ch <- resp:
default:
// this should never happen since the chan is of size one
debug.Printf("unexpected state. channel write should always succeed.")
debug.Printf("uasc %d/%d: unexpected state. channel write should always succeed.", s.c.ID(), resp.ReqID)
}

s.rcvLocker.waitIfLock()
Expand All @@ -245,7 +245,7 @@ func (s *SecureChannel) receive(ctx context.Context) *response {
default:
chunk, err := s.readChunk()
if err == io.EOF {
debug.Printf("uasc readChunk EOF")
debug.Printf("uasc %d: readChunk EOF", s.c.ID())
return &response{Err: err}
}

Expand All @@ -272,7 +272,7 @@ func (s *SecureChannel) receive(ctx context.Context) *response {

msga := new(MessageAbort)
if _, err := msga.Decode(chunk.Data); err != nil {
debug.Printf("conn %d/%d: invalid MSGA chunk. %s", s.c.ID(), reqID, err)
debug.Printf("uasc %d/%d: invalid MSGA chunk. %s", s.c.ID(), reqID, err)
resp.Err = ua.StatusBadDecodingError
return resp
}
Expand Down Expand Up @@ -367,8 +367,6 @@ func (s *SecureChannel) readChunk() (*MessageChunk, error) {

switch m.MessageType {
case "OPN":
debug.Printf("uasc OPN Request")

// Make sure we have a valid security header
if m.AsymmetricSecurityHeader == nil {
return nil, ua.StatusBadDecodingError // todo(dh): check if this is the correct error
Expand All @@ -380,7 +378,7 @@ func (s *SecureChannel) readChunk() (*MessageChunk, error) {

if m.SecurityPolicyURI != ua.SecurityPolicyURINone {
s.cfg.RemoteCertificate = m.AsymmetricSecurityHeader.SenderCertificate
debug.Printf("Setting securityPolicy to %s", m.SecurityPolicyURI)
debug.Printf("uasc %d: setting securityPolicy to %s", s.c.ID(), m.SecurityPolicyURI)
}

s.cfg.SecurityPolicyURI = m.SecurityPolicyURI
Expand All @@ -389,7 +387,7 @@ func (s *SecureChannel) readChunk() (*MessageChunk, error) {
case "CLO":
return nil, io.EOF
case "MSG":
// nop
// noop
default:
return nil, errors.Errorf("sechan: unknown message type: %s", m.MessageType)
}
Expand Down Expand Up @@ -427,13 +425,10 @@ func (s *SecureChannel) verifyAndDecrypt(m *MessageChunk, b []byte, instance *ch
)

for i := len(instances) - 1; i >= 0; i-- {
// instances[i].Lock()
if verified, err = instances[i].verifyAndDecrypt(m, b); err == nil {
// instances[i].Unlock()
return verified, nil
}
// instances[i].Unlock()
debug.Printf("attempting an older channel state...")
debug.Printf("uasc %d: attempting an older channel state...", s.c.ID())
}

return nil, err
Expand Down Expand Up @@ -521,7 +516,7 @@ func (s *SecureChannel) open(ctx context.Context, instance *channelInstance, req
// trigger cleanup after we are all done
defer func() {
if s.openingInstance == nil || s.openingInstance.state != channelActive {
debug.Printf("failed to open a new secure channel")
debug.Printf("uasc %d: failed to open a new secure channel", s.c.ID())
}
s.openingInstance = nil
}()
Expand Down Expand Up @@ -581,7 +576,7 @@ func (s *SecureChannel) handleOpenSecureChannelResponse(resp *ua.OpenSecureChann

s.activeInstance = instance

debug.Printf("received security token: channelID=%d tokenID=%d createdAt=%s lifetime=%s", instance.secureChannelID, instance.securityTokenID, instance.createdAt.Format(time.RFC3339), instance.revisedLifetime)
debug.Printf("uasc %d: received security token. channelID=%d tokenID=%d createdAt=%s lifetime=%s", s.c.ID(), instance.secureChannelID, instance.securityTokenID, instance.createdAt.Format(time.RFC3339), instance.revisedLifetime)

go s.scheduleRenewal(instance)
go s.scheduleExpiration(instance)
Expand All @@ -596,7 +591,7 @@ func (s *SecureChannel) scheduleRenewal(instance *channelInstance) {
const renewAfter = 0.75
when := time.Second * time.Duration(instance.revisedLifetime.Seconds()*renewAfter)

debug.Printf("channelID %d will be refreshed in %s (%s)", instance.secureChannelID, when, time.Now().UTC().Add(when).Format(time.RFC3339))
debug.Printf("uasc %d: security token is refreshed at %s (%s). channelID=%d tokenID=%d", s.c.ID(), time.Now().UTC().Add(when).Format(time.RFC3339), when, instance.secureChannelID, instance.securityTokenID)

t := time.NewTimer(when)
defer t.Stop()
Expand Down Expand Up @@ -628,7 +623,7 @@ func (s *SecureChannel) scheduleExpiration(instance *channelInstance) {
const expireAfter = 1.25
when := instance.createdAt.Add(time.Second * time.Duration(instance.revisedLifetime.Seconds()*expireAfter))

debug.Printf("channelID %d/%d will expire at %s", instance.secureChannelID, instance.securityTokenID, when.UTC().Format(time.RFC3339))
debug.Printf("uasc %d: security token expires at %s. channelID=%d tokenID=%d", s.c.ID(), when.UTC().Format(time.RFC3339), instance.secureChannelID, instance.securityTokenID)

t := time.NewTimer(time.Until(when))

Expand All @@ -648,7 +643,7 @@ func (s *SecureChannel) scheduleExpiration(instance *channelInstance) {
for _, oldInstance := range oldInstances {
if oldInstance.secureChannelID != instance.secureChannelID {
// something has gone horribly wrong!
debug.Printf("secureChannelID mismatch during scheduleExpiration!")
debug.Printf("uasc %d: secureChannelID mismatch during scheduleExpiration!", s.c.ID())
}
if oldInstance.securityTokenID == instance.securityTokenID {
continue
Expand Down Expand Up @@ -825,7 +820,7 @@ func (s *SecureChannel) Close() (err error) {
}

func (s *SecureChannel) close() error {
debug.Printf("uasc Close()")
debug.Printf("uasc %d: Close()", s.c.ID())

defer func() {
close(s.closing)
Expand Down Expand Up @@ -863,14 +858,8 @@ func mergeChunks(chunks []*MessageChunk) ([]byte, error) {
return chunks[0].Data, nil
}

// todo(fs): check if this is correct and necessary
// sort.Sort(bySequence(chunks))

var (
b []byte
seqnr uint32
)

var b []byte
var seqnr uint32
for _, c := range chunks {
if c.SequenceHeader.SequenceNumber == seqnr {
continue // duplicate chunk
Expand All @@ -880,12 +869,3 @@ func mergeChunks(chunks []*MessageChunk) ([]byte, error) {
}
return b, nil
}

// todo(fs): we only need this if we need to sort chunks. Need to check the spec
// type bySequence []*MessageChunk

// func (a bySequence) Len() int { return len(a) }
// func (a bySequence) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a bySequence) Less(i, j int) bool {
// return a[i].SequenceHeader.SequenceNumber < a[j].SequenceHeader.SequenceNumber
// }
4 changes: 2 additions & 2 deletions uatest/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ func TestStats(t *testing.T) {
"UpdateNamespaces": newExpVarInt(1),
"NodesToRead": newExpVarInt(1),
"Read": newExpVarInt(1),
"Send": newExpVarInt(1),
"Send": newExpVarInt(2),
"Close": newExpVarInt(1),
"CloseSession": newExpVarInt(2),
"SecureChannel": newExpVarInt(2),
"Session": newExpVarInt(5),
"Session": newExpVarInt(4),
"State": newExpVarInt(0),
}

Expand Down

0 comments on commit 2980901

Please sign in to comment.