p2p: server: fix incorrect bufio usage (#6470)
## Motivation

It is not correct to use a `bufio.Reader` to read the initial request
and then pass the underlying stream to the `StreamRequest` callback:
the `bufio.Reader` may read more data than necessary, making those
bytes unavailable to the `StreamRequest` callback. This behavior has
been observed when using QUIC, which tends to coalesce multiple
writes more often.
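
To illustrate the failure mode, here is a standalone sketch (not code from this repo; the byte values are made up): once a `bufio.Reader` wraps the stream, its first read may pull every coalesced byte into its private buffer, so a later read on the raw stream finds nothing left.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
)

func main() {
	// One coalesced packet: varint length (4), a 4-byte request,
	// then 2 extra bytes intended for the stream handler.
	stream := bytes.NewReader([]byte{0x04, 0x00, 0x01, 0x02, 0x03, 0x2a, 0x2b})

	rd := bufio.NewReader(stream)
	length, _ := rd.ReadByte() // fills rd's buffer with ALL 7 bytes
	req := make([]byte, length)
	if _, err := io.ReadFull(rd, req); err != nil {
		panic(err)
	}

	// A handler that reads from the raw stream finds the 2 extra
	// bytes gone: they are stranded in rd's buffer.
	extra := make([]byte, 2)
	_, err := io.ReadFull(stream, extra)
	fmt.Println(err) // EOF
}
```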
ivan4th committed Nov 19, 2024
1 parent 86ac1f4 commit b61c1ba
Showing 4 changed files with 84 additions and 39 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ See [RELEASE](./RELEASE.md) for workflow instructions.

* [#6451](https://github.com/spacemeshos/go-spacemesh/pull/6451) Fix a possible deadloop in the beacon protocol.

* [#6470](https://github.com/spacemeshos/go-spacemesh/pull/6470) Fix I/O buffering issue which could be affecting QUIC connections.

## v1.7.6

### Upgrade information
8 changes: 8 additions & 0 deletions p2p/server/deadline_adjuster.go
@@ -166,3 +166,11 @@ func (dadj *deadlineAdjuster) Write(p []byte) (n int, err error) {
}
return n, nil
}

// ReadByte implements io.ByteReader, which varint.ReadUvarint needs in order to
// read the request length.
func (dadj *deadlineAdjuster) ReadByte() (byte, error) {
var b [1]byte
_, err := io.ReadFull(dadj, b[:])
return b[0], err
}
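
For context on why this method is enough: `varint.ReadUvarint` (from `github.com/multiformats/go-varint`, as used in the server diff below) accepts an `io.ByteReader` and pulls exactly one byte per call, so giving `deadlineAdjuster` its own `ReadByte` lets the length prefix be decoded with no read-ahead at all. A minimal standalone sketch, using `bytes.Reader`, which also implements `io.ByteReader`:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/multiformats/go-varint"
)

func main() {
	// 0x80 0x02 is the varint encoding of 256; 0xff is payload that
	// must still be unread after the length is decoded.
	r := bytes.NewReader([]byte{0x80, 0x02, 0xff})
	size, err := varint.ReadUvarint(r) // consumes exactly 2 bytes
	if err != nil {
		panic(err)
	}
	fmt.Println(size, r.Len()) // 256 1 -> the payload byte is untouched
}
```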
5 changes: 2 additions & 3 deletions p2p/server/server.go
@@ -303,8 +303,7 @@ func (s *Server) Run(ctx context.Context) error {
func (s *Server) queueHandler(ctx context.Context, peer peer.ID, stream network.Stream) bool {
dadj := newDeadlineAdjuster(stream, s.timeout, s.hardTimeout)
defer dadj.Close()
rd := bufio.NewReader(dadj)
size, err := varint.ReadUvarint(rd)
size, err := varint.ReadUvarint(dadj)
if err != nil {
s.logger.Debug("initial read failed",
zap.String("protocol", s.protocol),
@@ -326,7 +325,7 @@ func (s *Server) queueHandler(ctx context.Context, peer peer.ID, stream network.
return false
}
buf := make([]byte, size)
_, err = io.ReadFull(rd, buf)
_, err = io.ReadFull(dadj, buf)
if err != nil {
s.logger.Debug("error reading request",
zap.String("protocol", s.protocol),
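
Taken together, the fixed `queueHandler` read path decodes the length prefix and reads the request body from the same unbuffered reader, leaving any trailing bytes on the stream for the `StreamRequest` callback. A condensed standalone sketch of that flow (with `bytes.Reader` standing in for the deadline-adjusted stream, and the packet mirroring the new test below):

```go
package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/multiformats/go-varint"
)

func main() {
	// Same coalesced packet as in the "coalesced data" test: varint
	// length 4, a 4-byte request, then 2 bytes owned by the handler.
	stream := bytes.NewReader([]byte{0x04, 0x00, 0x01, 0x02, 0x03, 0x2a, 0x2b})

	size, err := varint.ReadUvarint(stream) // byte-at-a-time, no read-ahead
	if err != nil {
		panic(err)
	}
	buf := make([]byte, size)
	if _, err := io.ReadFull(stream, buf); err != nil {
		panic(err)
	}

	// The handler's bytes are still on the stream.
	extra := make([]byte, 2)
	if _, err := io.ReadFull(stream, extra); err != nil {
		panic(err)
	}
	fmt.Printf("req=%x extra=%x\n", buf, extra) // req=00010203 extra=2a2b
}
```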
108 changes: 72 additions & 36 deletions p2p/server/server_test.go
@@ -1,21 +1,25 @@
package server

import (
"bufio"
"context"
"errors"
"io"
"sync"
"testing"
"time"

"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
"github.com/spacemeshos/go-scale/tester"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"golang.org/x/sync/errgroup"

"github.com/spacemeshos/go-spacemesh/codec"
"github.com/spacemeshos/go-spacemesh/p2p/peerinfo"
)

@@ -40,7 +44,7 @@ func wrapHost(tb testing.TB, h host.Host) Host {
func TestServer(t *testing.T) {
const limit = 1024

mesh, err := mocknet.FullMeshConnected(4)
mesh, err := mocknet.FullMeshConnected(5)
require.NoError(t, err)
proto := "test"
request := []byte("test request")
@@ -52,6 +56,14 @@ func TestServer(t *testing.T) {
errhandler := func(_ context.Context, _ peer.ID, _ []byte) ([]byte, error) {
return nil, testErr
}
streamHandler := func(ctx context.Context, peer peer.ID, req []byte, stream io.ReadWriter) error {
extra := make([]byte, 2)
_, err := io.ReadFull(stream, extra)
if err != nil {
return err
}
return writeResponse(stream, &Response{Data: append(extra, req...)})
}
opts := []Opt{
WithTimeout(100 * time.Millisecond),
WithLog(zaptest.NewLogger(t)),
@@ -63,35 +75,39 @@ func TestServer(t *testing.T) {
WrapHandler(handler),
append(opts, WithRequestSizeLimit(2*limit))...,
)
srv1 := New(
wrapHost(t, mesh.Hosts()[1]),
proto,
WrapHandler(handler),
append(opts, WithRequestSizeLimit(limit))...,
)
srv2 := New(
wrapHost(t, mesh.Hosts()[2]),
proto,
WrapHandler(errhandler),
append(opts, WithRequestSizeLimit(limit))...,
)
srv3 := New(
wrapHost(t, mesh.Hosts()[3]),
proto,
WrapHandler(handler),
append(opts, WithRequestSizeLimit(limit))...,
)
srvs := []*Server{
New(
wrapHost(t, mesh.Hosts()[1]),
proto,
WrapHandler(handler),
append(opts, WithRequestSizeLimit(limit))...,
),
New(
wrapHost(t, mesh.Hosts()[2]),
proto,
WrapHandler(errhandler),
append(opts, WithRequestSizeLimit(limit))...,
),
New(
wrapHost(t, mesh.Hosts()[3]),
proto,
WrapHandler(handler),
append(opts, WithRequestSizeLimit(limit))...,
),
New(
wrapHost(t, mesh.Hosts()[4]),
proto,
streamHandler,
append(opts, WithRequestSizeLimit(limit))...,
),
}
ctx, cancel := context.WithCancel(context.Background())
var eg errgroup.Group
eg.Go(func() error {
return srv1.Run(ctx)
})
eg.Go(func() error {
return srv2.Run(ctx)
})
eg.Go(func() error {
return srv3.Run(ctx)
})
for _, srv := range srvs {
eg.Go(func() error {
return srv.Run(ctx)
})
}
require.Eventually(t, func() bool {
for _, h := range mesh.Hosts()[1:] {
if len(h.Mux().Protocols()) == 0 {
@@ -106,52 +122,52 @@ func TestServer(t *testing.T) {
})

t.Run("ReceiveMessage", func(t *testing.T) {
n := srv1.NumAcceptedRequests()
n := srvs[0].NumAcceptedRequests()
srvID := mesh.Hosts()[1].ID()
response, err := client.Request(ctx, srvID, request)
require.NoError(t, err)
expResponse := append(request, []byte(mesh.Hosts()[0].ID())...)
require.Equal(t, expResponse, response)
srvConns := mesh.Hosts()[1].Network().ConnsToPeer(mesh.Hosts()[0].ID())
require.NotEmpty(t, srvConns)
require.Equal(t, n+1, srv1.NumAcceptedRequests())
require.Equal(t, n+1, srvs[0].NumAcceptedRequests())

clientInfo := client.peerInfo().EnsurePeerInfo(srvID)
require.Equal(t, 1, clientInfo.ClientStats.SuccessCount())
require.Zero(t, clientInfo.ClientStats.FailureCount())

serverInfo := srv1.peerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID())
serverInfo := srvs[0].peerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID())
require.Eventually(t, func() bool {
return serverInfo.ServerStats.SuccessCount() == 1
}, 10*time.Second, 10*time.Millisecond)
require.Zero(t, serverInfo.ServerStats.FailureCount())
})
t.Run("ReceiveNoPeerInfo", func(t *testing.T) {
n := srv1.NumAcceptedRequests()
n := srvs[0].NumAcceptedRequests()
srvID := mesh.Hosts()[3].ID()
response, err := client.Request(ctx, srvID, request)
require.NoError(t, err)
expResponse := append(request, []byte(mesh.Hosts()[0].ID())...)
require.Equal(t, expResponse, response)
srvConns := mesh.Hosts()[3].Network().ConnsToPeer(mesh.Hosts()[0].ID())
require.NotEmpty(t, srvConns)
require.Equal(t, n+1, srv1.NumAcceptedRequests())
require.Equal(t, n+1, srvs[0].NumAcceptedRequests())
})
t.Run("ReceiveError", func(t *testing.T) {
n := srv1.NumAcceptedRequests()
n := srvs[0].NumAcceptedRequests()
srvID := mesh.Hosts()[2].ID()
_, err := client.Request(ctx, srvID, request)
var srvErr *ServerError
require.ErrorAs(t, err, &srvErr)
require.ErrorContains(t, err, "peer error")
require.ErrorContains(t, err, testErr.Error())
require.Equal(t, n+1, srv1.NumAcceptedRequests())
require.Equal(t, n+1, srvs[0].NumAcceptedRequests())

clientInfo := client.peerInfo().EnsurePeerInfo(srvID)
require.Zero(t, clientInfo.ClientStats.SuccessCount())
require.Equal(t, 1, clientInfo.ClientStats.FailureCount())

serverInfo := srv2.peerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID())
serverInfo := srvs[1].peerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID())
require.Eventually(t, func() bool {
return serverInfo.ServerStats.FailureCount() == 1
}, 10*time.Second, 10*time.Millisecond)
Expand All @@ -173,6 +189,26 @@ func TestServer(t *testing.T) {
)
require.Error(t, err)
})
t.Run("coalesced data", func(t *testing.T) {
stream, err := mesh.Hosts()[0].NewStream(ctx, mesh.Hosts()[4].ID(), protocol.ID(proto))
require.NoError(t, err)
defer stream.Close()
request := []byte{
0x04, // initial request length = 4 (varint)
0x00, 0x01, 0x02, 0x03, // initial request, instructs the handler to read 2 bytes
0x2a, 0x2b, // 2 bytes to read
}
// If the server reads too much data with initial request, it will then
// fail to read the data following it.
_, err = stream.Write(request)
require.NoError(t, err)

var r Response
rd := bufio.NewReader(stream)
_, err = codec.DecodeFrom(rd, &r)
require.NoError(t, err)
require.Equal(t, []byte{0x2a, 0x2b, 0x00, 0x01, 0x02, 0x03}, r.Data)
})
}

func Test_Queued(t *testing.T) {
