From 3f715a842164cd9b67637058e1e6e0ae17b5e620 Mon Sep 17 00:00:00 2001 From: Vincent Vanackere Date: Wed, 23 Jul 2014 15:07:31 +0200 Subject: [PATCH] Remove data race between *Conn.Close() and *Conn.reader() A single channel is now used to signal that the connection was closed. There is also no need to set the net.Conn field to nil as this was causing a nil pointer dereference in the reader loop. --- conn.go | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 71875b6..903df01 100644 --- a/conn.go +++ b/conn.go @@ -32,14 +32,13 @@ type messagePacket struct { type Conn struct { conn net.Conn isTLS bool - isClosing bool Debug debugging chanConfirm chan bool chanResults map[uint64]chan *ber.Packet chanMessage chan *messagePacket chanMessageID chan uint64 wgSender sync.WaitGroup - wgClose sync.WaitGroup + chanDone chan struct{} once sync.Once } @@ -76,19 +75,19 @@ func NewConn(conn net.Conn) *Conn { chanMessageID: make(chan uint64), chanMessage: make(chan *messagePacket, 10), chanResults: map[uint64]chan *ber.Packet{}, + chanDone: make(chan struct{}), } } func (l *Conn) start() { go l.reader() go l.processMessages() - l.wgClose.Add(1) } // Close closes the connection. func (l *Conn) Close() { l.once.Do(func() { - l.isClosing = true + close(l.chanDone) l.wgSender.Wait() l.Debug.Printf("Sending quit message and waiting for confirmation") @@ -100,11 +99,8 @@ func (l *Conn) Close() { if err := l.conn.Close(); err != nil { log.Print(err) } - - l.conn = nil - l.wgClose.Done() }) - l.wgClose.Wait() + <-l.chanDone } // Returns the next available messageID @@ -158,8 +154,17 @@ func (l *Conn) StartTLS(config *tls.Config) error { return nil } +func (l *Conn) closing() bool { + select { + case <-l.chanDone: + return true + default: + return false + } +} + func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) { - if l.isClosing { + if l.closing() { return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) } out := make(chan *ber.Packet) @@ -174,7 +179,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) { } func (l *Conn) finishMessage(messageID uint64) { - if l.isClosing { + if l.closing() { return } message := &messagePacket{ @@ -185,12 +190,13 @@ func (l *Conn) finishMessage(messageID uint64) { } func (l *Conn) sendProcessMessage(message *messagePacket) bool { - if l.isClosing { + l.wgSender.Add(1) + defer l.wgSender.Done() + + if l.closing() { return false } - l.wgSender.Add(1) l.chanMessage <- message - l.wgSender.Done() return true }