diff --git a/v2/mux.go b/v2/mux.go index c27c78c..5641795 100644 --- a/v2/mux.go +++ b/v2/mux.go @@ -257,6 +257,8 @@ func (m *Mux) readLoop() { cond: sync.Cond{L: new(sync.Mutex)}, covert: covert, established: true, + + closed: make(chan struct{}), } m.streams[h.id] = curStream m.cond.Broadcast() // wake (*Mux).AcceptStream @@ -314,6 +316,8 @@ func (m *Mux) DialStream() *Stream { cond: sync.Cond{L: new(sync.Mutex)}, established: false, err: m.err, // stream is unusable if m.err is set + + closed: make(chan struct{}), } m.streams[s.id] = s m.nextID += 2 @@ -341,15 +345,19 @@ func (m *Mux) DialCovertStream() *Stream { // DialStreamContext creates a new Stream with the provided context. When the // context expires, the Stream will be closed and any pending calls will return -// ctx.Err(). DialStreamContext spawns a goroutine whose lifetime matches that -// of the context. +// ctx.Err(). // // Unlike e.g. net.Dial, this does not perform any I/O; the peer will not be // aware of the new Stream until Write is called. func (m *Mux) DialStreamContext(ctx context.Context) *Stream { s := m.DialStream() go func() { - <-ctx.Done() + select { + case <-s.closed: + return + case <-ctx.Done(): + } + s.cond.L.Lock() defer s.cond.L.Unlock() if ctx.Err() != nil && s.err == nil { @@ -425,6 +433,8 @@ type Stream struct { err error readBuf []byte rd, wd time.Time // deadlines + + closed chan struct{} // closed when the Stream is closed } // LocalAddr returns the underlying connection's LocalAddr. @@ -567,6 +577,13 @@ func (s *Stream) Write(p []byte) (int, error) { // Close closes the Stream. The underlying connection is not closed. func (s *Stream) Close() error { + select { + case <-s.closed: + default: + // close the channel to signal the context goroutine to exit + close(s.closed) + } + // cancel outstanding Read/Write calls // // NOTE: Read calls will be interrupted immediately, but Write calls might