Skip to content

Commit

Permalink
Merge branch 'master' into nate/goroutine-cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
n8maninger authored Sep 14, 2024
2 parents d930243 + b7fef91 commit 7635545
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 31 deletions.
33 changes: 20 additions & 13 deletions v1/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/ed25519"
"errors"
"fmt"
"io"
"net"
"os"
"sync"
Expand Down Expand Up @@ -59,6 +60,7 @@ func (m *Mux) setErr(err error) error {
for _, s := range m.streams {
s.cond.L.Lock()
s.err = err
s.readBuf = nil
s.cond.Broadcast()
s.cond.L.Unlock()
}
Expand Down Expand Up @@ -381,7 +383,7 @@ func (s *Stream) consumeFrame(h frameHeader, payload []byte) {
// set payload and wait for it to be consumed
s.readBuf = payload
s.cond.Broadcast() // wake Read
for len(s.readBuf) != 0 && s.err == nil {
for len(s.readBuf) != 0 {
s.cond.Wait()
}
}
Expand All @@ -391,24 +393,28 @@ func (s *Stream) Read(p []byte) (int, error) {
s.cond.L.Lock()
defer s.cond.L.Unlock()
if !s.rd.IsZero() {
if !time.Now().Before(s.rd) {
return 0, os.ErrDeadlineExceeded
}
timer := time.AfterFunc(time.Until(s.rd), s.cond.Broadcast)
defer timer.Stop()
defer time.AfterFunc(time.Until(s.rd), s.cond.Broadcast).Stop()
}
for len(s.readBuf) == 0 && s.err == nil && (s.rd.IsZero() || time.Now().Before(s.rd)) {
s.cond.Wait()
}
if s.err != nil {
return 0, s.err
} else if !s.rd.IsZero() && !time.Now().Before(s.rd) {
return 0, os.ErrDeadlineExceeded
}
n := copy(p, s.readBuf)
s.readBuf = s.readBuf[n:]
s.cond.Broadcast() // wake consumeFrame
return n, s.err

err := s.err
if err == ErrPeerClosedStream {
err = io.EOF
} else if !(s.rd.IsZero() || time.Now().Before(s.rd)) {
err = os.ErrDeadlineExceeded
} else if err != nil {
s.readBuf = nil // if the error is fatal, drop the rest of the buffer
}
if len(s.readBuf) > 0 {
err = nil // if more data is available, silence the error
} else {
s.cond.Broadcast() // wake consumeFrame
}
return n, err
}

// Write writes data to the Stream.
Expand Down Expand Up @@ -452,6 +458,7 @@ func (s *Stream) Close() error {
s.cond.L.Lock()
defer s.cond.L.Unlock()
s.err = ErrClosedStream
s.readBuf = nil
s.cond.Broadcast()
return err
}
Expand Down
2 changes: 1 addition & 1 deletion v1/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func TestDeadline(t *testing.T) {

// need to write a fairly large message; otherwise the packets just
// get buffered and "succeed" instantly
if _, err := s.Write(make([]byte, 1<<20)); err != nil {
if _, err := s.Write(make([]byte, m.settings.RequestedPacketSize*20)); err != nil {
return err
} else if _, err := io.ReadFull(s, buf[:13]); err != nil {
return err
Expand Down
38 changes: 22 additions & 16 deletions v2/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func (m *Mux) setErr(err error) error {
for _, s := range m.streams {
s.cond.L.Lock()
s.err = err
s.readBuf = nil
s.cond.Broadcast()
s.cond.L.Unlock()
}
Expand Down Expand Up @@ -500,9 +501,12 @@ func (s *Stream) consumeFrame(h frameHeader, payload []byte) {
// set payload and wait for it to be consumed
s.cond.L.Lock()
defer s.cond.L.Unlock()
if s.err != nil {
return
}
s.readBuf = payload
s.cond.Broadcast() // wake Read
for len(s.readBuf) > 0 && s.err == nil && (s.rd.IsZero() || time.Now().Before(s.rd)) {
for len(s.readBuf) > 0 {
s.cond.Wait()
}
}
Expand All @@ -516,27 +520,28 @@ func (s *Stream) Read(p []byte) (int, error) {
panic("mux: Read called before Write on newly-Dialed Stream")
}
if !s.rd.IsZero() {
if !time.Now().Before(s.rd) {
return 0, os.ErrDeadlineExceeded
}
timer := time.AfterFunc(time.Until(s.rd), s.cond.Broadcast)
defer timer.Stop()
defer time.AfterFunc(time.Until(s.rd), s.cond.Broadcast).Stop()
}
for len(s.readBuf) == 0 && s.err == nil && (s.rd.IsZero() || time.Now().Before(s.rd)) {
s.cond.Wait()
}
if s.err != nil {
if s.err == ErrPeerClosedStream {
return 0, io.EOF
}
return 0, s.err
} else if !s.rd.IsZero() && !time.Now().Before(s.rd) {
return 0, os.ErrDeadlineExceeded
}
n := copy(p, s.readBuf)
s.readBuf = s.readBuf[n:]
s.cond.Broadcast() // wake consumeFrame
return n, nil

err := s.err
if err == ErrPeerClosedStream {
err = io.EOF
} else if !(s.rd.IsZero() || time.Now().Before(s.rd)) {
err = os.ErrDeadlineExceeded
} else if err != nil {
s.readBuf = nil // if the error is fatal, drop the rest of the buffer
}
if len(s.readBuf) > 0 {
err = nil // if more data is available, silence the error
} else {
s.cond.Broadcast() // wake consumeFrame
}
return n, err
}

// Write writes data to the Stream.
Expand Down Expand Up @@ -590,6 +595,7 @@ func (s *Stream) Close() error {
return nil
}
s.err = ErrClosedStream
s.readBuf = nil
s.cond.Broadcast()
s.cond.L.Unlock()

Expand Down
2 changes: 1 addition & 1 deletion v2/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func TestDeadline(t *testing.T) {
// need to write a fairly large message; otherwise the packets just
// get buffered and "succeed" instantly
if _, err := s.Write(make([]byte, m1.settings.PacketSize*20)); err != nil {
return fmt.Errorf("foo: %w", err)
return err
} else if _, err := io.ReadFull(s, buf[:13]); err != nil {
return err
} else if string(buf) != "hello, world!" {
Expand Down

0 comments on commit 7635545

Please sign in to comment.