diff --git a/v1/mux.go b/v1/mux.go index 0a240ed..a1c4c09 100644 --- a/v1/mux.go +++ b/v1/mux.go @@ -6,6 +6,7 @@ import ( "crypto/ed25519" "errors" "fmt" + "io" "net" "os" "sync" @@ -59,6 +60,7 @@ func (m *Mux) setErr(err error) error { for _, s := range m.streams { s.cond.L.Lock() s.err = err + s.readBuf = nil s.cond.Broadcast() s.cond.L.Unlock() } @@ -381,7 +383,7 @@ func (s *Stream) consumeFrame(h frameHeader, payload []byte) { // set payload and wait for it to be consumed s.readBuf = payload s.cond.Broadcast() // wake Read - for len(s.readBuf) != 0 && s.err == nil { + for len(s.readBuf) != 0 { s.cond.Wait() } } @@ -391,24 +393,28 @@ func (s *Stream) Read(p []byte) (int, error) { s.cond.L.Lock() defer s.cond.L.Unlock() if !s.rd.IsZero() { - if !time.Now().Before(s.rd) { - return 0, os.ErrDeadlineExceeded - } - timer := time.AfterFunc(time.Until(s.rd), s.cond.Broadcast) - defer timer.Stop() + defer time.AfterFunc(time.Until(s.rd), s.cond.Broadcast).Stop() } for len(s.readBuf) == 0 && s.err == nil && (s.rd.IsZero() || time.Now().Before(s.rd)) { s.cond.Wait() } - if s.err != nil { - return 0, s.err - } else if !s.rd.IsZero() && !time.Now().Before(s.rd) { - return 0, os.ErrDeadlineExceeded - } n := copy(p, s.readBuf) s.readBuf = s.readBuf[n:] - s.cond.Broadcast() // wake consumeFrame - return n, s.err + + err := s.err + if err == ErrPeerClosedStream { + err = io.EOF + } else if !(s.rd.IsZero() || time.Now().Before(s.rd)) { + err = os.ErrDeadlineExceeded + } else if err != nil { + s.readBuf = nil // if the error is fatal, drop the rest of the buffer + } + if len(s.readBuf) > 0 { + err = nil // if more data is available, silence the error + } else { + s.cond.Broadcast() // wake consumeFrame + } + return n, err } // Write writes data to the Stream. @@ -452,6 +458,7 @@ func (s *Stream) Close() error { s.cond.L.Lock() defer s.cond.L.Unlock() s.err = ErrClosedStream + s.readBuf = nil s.cond.Broadcast() return err } diff --git a/v1/mux_test.go b/v1/mux_test.go index 4824039..305aac0 100644 --- a/v1/mux_test.go +++ b/v1/mux_test.go @@ -262,7 +262,7 @@ func TestDeadline(t *testing.T) { // need to write a fairly large message; otherwise the packets just // get buffered and "succeed" instantly - if _, err := s.Write(make([]byte, 1<<20)); err != nil { + if _, err := s.Write(make([]byte, m.settings.RequestedPacketSize*20)); err != nil { return err } else if _, err := io.ReadFull(s, buf[:13]); err != nil { return err diff --git a/v2/mux.go b/v2/mux.go index e78afb4..5641795 100644 --- a/v2/mux.go +++ b/v2/mux.go @@ -66,6 +66,7 @@ func (m *Mux) setErr(err error) error { for _, s := range m.streams { s.cond.L.Lock() s.err = err + s.readBuf = nil s.cond.Broadcast() s.cond.L.Unlock() } @@ -500,9 +501,12 @@ func (s *Stream) consumeFrame(h frameHeader, payload []byte) { // set payload and wait for it to be consumed s.cond.L.Lock() defer s.cond.L.Unlock() + if s.err != nil { + return + } s.readBuf = payload s.cond.Broadcast() // wake Read - for len(s.readBuf) > 0 && s.err == nil && (s.rd.IsZero() || time.Now().Before(s.rd)) { + for len(s.readBuf) > 0 { s.cond.Wait() } } @@ -516,27 +520,28 @@ func (s *Stream) Read(p []byte) (int, error) { panic("mux: Read called before Write on newly-Dialed Stream") } if !s.rd.IsZero() { - if !time.Now().Before(s.rd) { - return 0, os.ErrDeadlineExceeded - } - timer := time.AfterFunc(time.Until(s.rd), s.cond.Broadcast) - defer timer.Stop() + defer time.AfterFunc(time.Until(s.rd), s.cond.Broadcast).Stop() } for len(s.readBuf) == 0 && s.err == nil && (s.rd.IsZero() || time.Now().Before(s.rd)) { s.cond.Wait() } - if s.err != nil { - if s.err == ErrPeerClosedStream { - return 0, io.EOF - } - return 0, s.err - } else if !s.rd.IsZero() && !time.Now().Before(s.rd) { - return 0, os.ErrDeadlineExceeded - } n := copy(p, s.readBuf) s.readBuf = s.readBuf[n:] - s.cond.Broadcast() // wake consumeFrame - return n, nil + + err := s.err + if err == ErrPeerClosedStream { + err = io.EOF + } else if !(s.rd.IsZero() || time.Now().Before(s.rd)) { + err = os.ErrDeadlineExceeded + } else if err != nil { + s.readBuf = nil // if the error is fatal, drop the rest of the buffer + } + if len(s.readBuf) > 0 { + err = nil // if more data is available, silence the error + } else { + s.cond.Broadcast() // wake consumeFrame + } + return n, err } // Write writes data to the Stream. @@ -590,6 +595,7 @@ func (s *Stream) Close() error { return nil } s.err = ErrClosedStream + s.readBuf = nil s.cond.Broadcast() s.cond.L.Unlock() diff --git a/v2/mux_test.go b/v2/mux_test.go index 00c466a..03dc110 100644 --- a/v2/mux_test.go +++ b/v2/mux_test.go @@ -277,7 +277,7 @@ func TestDeadline(t *testing.T) { // need to write a fairly large message; otherwise the packets just // get buffered and "succeed" instantly if _, err := s.Write(make([]byte, m1.settings.PacketSize*20)); err != nil { - return fmt.Errorf("foo: %w", err) + return err } else if _, err := io.ReadFull(s, buf[:13]); err != nil { return err } else if string(buf) != "hello, world!" {