Adding timeouts on the async handlers for p2p
gameofpointers committed May 22, 2024
1 parent 5a12e9b commit 23fd370
Showing 6 changed files with 131 additions and 58 deletions.
22 changes: 17 additions & 5 deletions p2p/node/api.go
@@ -1,6 +1,7 @@
package node

import (
"context"
"math/big"
"reflect"
"runtime/debug"
@@ -24,6 +25,8 @@ import (
"github.com/dominant-strategies/go-quai/common"
)

const requestTimeout = 10 * time.Second

// Starts the node and all of its services
func (p *P2PNode) Start() error {
log.Global.Infof("starting P2P node...")
@@ -116,7 +119,7 @@ func (p *P2PNode) Stop() error {
}
}

func (p *P2PNode) requestFromPeers(topic *pubsubManager.Topic, requestData interface{}, respDataType interface{}, resultChan chan interface{}) {
func (p *P2PNode) requestFromPeers(ctx context.Context, topic *pubsubManager.Topic, requestData interface{}, respDataType interface{}, resultChan chan interface{}) {
go func() {
defer func() {
if r := recover(); r != nil {
@@ -137,6 +140,7 @@ func (p *P2PNode) requestFromPeers(topic *pubsubManager.Topic, requestData inter
for peerID := range peers {
requestWg.Add(1)
go func(peerID peer.ID) {
defer requestWg.Done()
defer func() {
if r := recover(); r != nil {
log.Global.WithFields(log.Fields{
@@ -145,15 +149,14 @@ func (p *P2PNode) requestFromPeers(topic *pubsubManager.Topic, requestData inter
}).Error("Go-Quai Panicked")
}
}()
defer requestWg.Done()
p.requestAndWait(peerID, topic, requestData, respDataType, resultChan)
p.requestAndWait(ctx, peerID, topic, requestData, respDataType, resultChan)
}(peerID)
}
requestWg.Wait()
}()
}

func (p *P2PNode) requestAndWait(peerID peer.ID, topic *pubsubManager.Topic, reqData interface{}, respDataType interface{}, resultChan chan interface{}) {
func (p *P2PNode) requestAndWait(ctx context.Context, peerID peer.ID, topic *pubsubManager.Topic, reqData interface{}, respDataType interface{}, resultChan chan interface{}) {
defer func() {
if r := recover(); r != nil {
log.Global.WithFields(log.Fields{
@@ -175,6 +178,13 @@ func (p *P2PNode) requestAndWait(peerID peer.ID, topic *pubsubManager.Topic, req
select {
case resultChan <- recvd:
// Data sent successfully
case <-ctx.Done():
// Request timed out, return
log.Global.WithFields(log.Fields{
"peerId": peerID,
"message": "Request timed out, data not sent",
}).Warning("Missed data request")

default:
// Optionally log the missed send or handle it in another way
log.Global.WithFields(log.Fields{
@@ -206,6 +216,8 @@ func (p *P2PNode) Request(location common.Location, requestData interface{}, res
}

resultChan := make(chan interface{}, 10)
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
defer cancel()
// If it is a hash, first check to see if it is contained in the caches
if hash, ok := requestData.(common.Hash); ok {
result, ok := p.cacheGet(hash, responseDataType, location)
@@ -215,7 +227,7 @@ func (p *P2PNode) Request(location common.Location, requestData interface{}, res
}
}

p.requestFromPeers(topic, requestData, responseDataType, resultChan)
p.requestFromPeers(ctx, topic, requestData, responseDataType, resultChan)
// TODO: optimize with waitgroups or a doneChan to only query if no peers responded
// Right now this creates too many streams, so don't call this until we have a better solution
// p.queryDHT(location, requestData, responseDataType, resultChan)
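The api.go changes add a deadline to the whole request fan-out: Request creates a 10-second context, threads it through requestFromPeers into each per-peer goroutine, and requestAndWait abandons delivering its result once that context expires. Below is a minimal, self-contained sketch of that send-or-time-out shape; fetchFromPeer and the peer names are illustrative stand-ins, not the commit's API, and the commit's version additionally keeps a non-blocking default branch in its select.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

const requestTimeout = 10 * time.Second

// fetchFromPeer stands in for one per-peer request goroutine: it tries to
// deliver its result, but abandons the send once the shared request
// context expires, so no goroutine is stranded on a full channel forever.
func fetchFromPeer(ctx context.Context, peerID string, resultChan chan<- string) {
	recvd := "data-from-" + peerID // placeholder for the real round trip

	select {
	case resultChan <- recvd:
		// Data sent successfully.
	case <-ctx.Done():
		fmt.Printf("peer %s: request timed out, data not sent\n", peerID)
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	resultChan := make(chan string, 10)
	for _, peerID := range []string{"peerA", "peerB"} {
		go fetchFromPeer(ctx, peerID, resultChan)
	}

	fmt.Println("got:", <-resultChan) // first responder wins
}
```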
94 changes: 48 additions & 46 deletions p2p/node/events.go
@@ -42,56 +42,58 @@ func (p *P2PNode) eventLoop() {
for {
select {
case evt := <-sub.Out():
switch e := evt.(type) {
case event.EvtLocalProtocolsUpdated:
log.Global.Debugf("Event: 'Local protocols updated' - added: %+v, removed: %+v", e.Added, e.Removed)
case event.EvtLocalAddressesUpdated:
p2pAddr, err := p.p2pAddress()
if err != nil {
log.Global.Errorf("error computing p2p address: %s", err)
} else {
for _, addr := range e.Current {
addr := addr.Address.Encapsulate(p2pAddr)
log.Global.Infof("Event: 'Local address updated': %s", addr)
go func(evt interface{}) {
switch e := evt.(type) {
case event.EvtLocalProtocolsUpdated:
log.Global.Debugf("Event: 'Local protocols updated' - added: %+v, removed: %+v", e.Added, e.Removed)
case event.EvtLocalAddressesUpdated:
p2pAddr, err := p.p2pAddress()
if err != nil {
log.Global.Errorf("error computing p2p address: %s", err)
} else {
for _, addr := range e.Current {
addr := addr.Address.Encapsulate(p2pAddr)
log.Global.Infof("Event: 'Local address updated': %s", addr)
}
// log removed addresses
for _, addr := range e.Removed {
addr := addr.Address.Encapsulate(p2pAddr)
log.Global.Infof("Event: 'Local address removed': %s", addr)
}
}
// log removed addresses
for _, addr := range e.Removed {
addr := addr.Address.Encapsulate(p2pAddr)
log.Global.Infof("Event: 'Local address removed': %s", addr)
case event.EvtLocalReachabilityChanged:
log.Global.Debugf("Event: 'Local reachability changed': %+v", e.Reachability)
case event.EvtNATDeviceTypeChanged:
log.Global.Debugf("Event: 'NAT device type changed' - DeviceType %v, transport: %v", e.NatDeviceType.String(), e.TransportProtocol.String())
case event.EvtPeerProtocolsUpdated:
log.Global.Debugf("Event: 'Peer protocols updated' - added: %+v, removed: %+v, peer: %+v", e.Added, e.Removed, e.Peer)
case event.EvtPeerIdentificationCompleted:
log.Global.Debugf("Event: 'Peer identification completed' - %v", e.Peer)
case event.EvtPeerIdentificationFailed:
log.Global.Debugf("Event 'Peer identification failed' - peer: %v, reason: %v", e.Peer, e.Reason.Error())
case event.EvtPeerConnectednessChanged:
// get the peer info
peerInfo := p.peerManager.GetHost().Peerstore().PeerInfo(e.Peer)
// get the peer ID
peerID := peerInfo.ID
// get the peer protocols
peerProtocols, err := p.peerManager.GetHost().Peerstore().GetProtocols(peerID)
if err != nil {
log.Global.Errorf("error getting peer protocols: %s", err)
}
}
case event.EvtLocalReachabilityChanged:
log.Global.Debugf("Event: 'Local reachability changed': %+v", e.Reachability)
case event.EvtNATDeviceTypeChanged:
log.Global.Debugf("Event: 'NAT device type changed' - DeviceType %v, transport: %v", e.NatDeviceType.String(), e.TransportProtocol.String())
case event.EvtPeerProtocolsUpdated:
log.Global.Debugf("Event: 'Peer protocols updated' - added: %+v, removed: %+v, peer: %+v", e.Added, e.Removed, e.Peer)
case event.EvtPeerIdentificationCompleted:
log.Global.Debugf("Event: 'Peer identification completed' - %v", e.Peer)
case event.EvtPeerIdentificationFailed:
log.Global.Debugf("Event 'Peer identification failed' - peer: %v, reason: %v", e.Peer, e.Reason.Error())
case event.EvtPeerConnectednessChanged:
// get the peer info
peerInfo := p.peerManager.GetHost().Peerstore().PeerInfo(e.Peer)
// get the peer ID
peerID := peerInfo.ID
// get the peer protocols
peerProtocols, err := p.peerManager.GetHost().Peerstore().GetProtocols(peerID)
if err != nil {
log.Global.Errorf("error getting peer protocols: %s", err)
}
// get the peer addresses
peerAddresses := p.peerManager.GetHost().Peerstore().Addrs(peerID)
log.Global.Debugf("Event: 'Peer connectedness change' - Peer %s (peerInfo: %+v) is now %s, protocols: %v, addresses: %v", peerID.String(), peerInfo, e.Connectedness, peerProtocols, peerAddresses)
// get the peer addresses
peerAddresses := p.peerManager.GetHost().Peerstore().Addrs(peerID)
log.Global.Debugf("Event: 'Peer connectedness change' - Peer %s (peerInfo: %+v) is now %s, protocols: %v, addresses: %v", peerID.String(), peerInfo, e.Connectedness, peerProtocols, peerAddresses)

if e.Connectedness == network.NotConnected {
p.peerManager.RemovePeer(peerID)
if e.Connectedness == network.NotConnected {
p.peerManager.RemovePeer(peerID)
}
case *event.EvtNATDeviceTypeChanged:
log.Global.Debugf("Event `NAT device type changed` - DeviceType %v, transport: %v", e.NatDeviceType.String(), e.TransportProtocol.String())
default:
log.Global.Debugf("Received unknown event (type: %T): %+v", e, e)
}
case *event.EvtNATDeviceTypeChanged:
log.Global.Debugf("Event `NAT device type changed` - DeviceType %v, transport: %v", e.NatDeviceType.String(), e.TransportProtocol.String())
default:
log.Global.Debugf("Received unknown event (type: %T): %+v", e, e)
}
}(evt)
case <-p.ctx.Done():
log.Global.Warnf("Context cancel received. Stopping event listener")
return
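The events.go change re-indents the event type-switch into a per-event goroutine, so one slow handler can no longer stall the loop that drains sub.Out(). A minimal sketch of that dispatch shape, assuming a generic event channel in place of the libp2p event bus:

```go
package main

import (
	"fmt"
	"time"
)

// dispatch pulls events off the subscription channel and hands each one to
// its own goroutine, so a slow handler never blocks the next receive.
func dispatch(events <-chan interface{}, done <-chan struct{}) {
	for {
		select {
		case evt := <-events:
			go func(evt interface{}) {
				switch e := evt.(type) {
				case string:
					fmt.Println("string event:", e)
				default:
					fmt.Printf("unhandled event (type %T): %+v\n", e, e)
				}
			}(evt)
		case <-done:
			return
		}
	}
}

func main() {
	events := make(chan interface{})
	done := make(chan struct{})
	go dispatch(events, done)

	events <- "peer connected"
	time.Sleep(50 * time.Millisecond) // let the handler goroutine run
	close(done)
}
```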
40 changes: 40 additions & 0 deletions p2p/node/node.go
@@ -62,6 +62,15 @@ type P2PNode struct {

// runtime context
ctx context.Context

// used to control all the different sub processes of the P2PNode
cancel context.CancelFunc

// host management interface
host host.Host

// dht interface
dht *dual.DHT
}

// Returns a new libp2p node.
@@ -187,6 +196,7 @@ func NewNode(ctx context.Context, quitCh chan struct{}) (*P2PNode, error) {
return nil, err
}

ctx, cancel := context.WithCancel(ctx)
p2p := &P2PNode{
ctx: ctx,
bootpeers: bootpeers,
Expand All @@ -195,6 +205,9 @@ func NewNode(ctx context.Context, quitCh chan struct{}) (*P2PNode, error) {
requestManager: requestManager.NewManager(),
cache: initializeCaches(common.GenerateLocations(common.MaxRegions, common.MaxZones)),
quitCh: quitCh,
cancel: cancel,
host: host,
dht: dht,
}

sm, err := streamManager.NewStreamManager(peerManager.C_peerCount, p2p, host)
@@ -206,6 +219,33 @@ func NewNode(ctx context.Context, quitCh chan struct{}) (*P2PNode, error) {
return p2p, nil
}

// Close performs cleanup of resources used by P2PNode
func (p *P2PNode) Close() error {
p.cancel()
// Close PubSub manager
if err := p.pubsub.Stop(); err != nil {
log.Global.Errorf("error closing pubsub manager: %s", err)
}

// Close the stream manager
if err := p.peerManager.Stop(); err != nil {
log.Global.Errorf("error closing peer manager: %s", err)
}

// Close DHT
if err := p.dht.Close(); err != nil {
log.Global.Errorf("error closing DHT: %s", err)
}

// Close the libp2p host
if err := p.host.Close(); err != nil {
log.Global.Errorf("error closing libp2p host: %s", err)
}

close(p.quitCh)
return nil
}

func initializeCaches(locations []common.Location) map[string]map[string]*lru.Cache[common.Hash, interface{}] {
caches := make(map[string]map[string]*lru.Cache[common.Hash, interface{}])
for _, location := range locations {
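The new Close method fixes a teardown order: cancel the node context first so request and event goroutines unwind, then stop pubsub, the peer manager, the DHT, and the host, and close quitCh last. A hedged sketch of that ordering with placeholder types (closer, node, and subsystems are illustrative names, not the commit's API):

```go
package main

import (
	"context"
	"log"
)

// closer stands in for any subsystem (pubsub, peer manager, DHT, host)
// that must be shut down when the node closes.
type closer interface{ Close() error }

type node struct {
	cancel     context.CancelFunc
	subsystems []closer      // stopped in order: pubsub, peers, DHT, host
	quitCh     chan struct{} // closed last, so waiters see a fully-stopped node
}

// Close cancels the node context first, so in-flight goroutines unwind,
// then stops each subsystem, logging rather than aborting on error so the
// remaining subsystems still get a chance to clean up.
func (n *node) Close() error {
	n.cancel()
	for _, s := range n.subsystems {
		if err := s.Close(); err != nil {
			log.Printf("error closing subsystem: %s", err)
		}
	}
	close(n.quitCh)
	return nil
}

func main() {
	_, cancel := context.WithCancel(context.Background())
	n := &node{cancel: cancel, quitCh: make(chan struct{})}
	if err := n.Close(); err != nil {
		log.Fatal(err)
	}
}
```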
3 changes: 2 additions & 1 deletion p2p/node/pubsubManager/gossipsub.go
@@ -128,6 +128,8 @@ func (g *PubsubManager) Subscribe(location common.Location, datatype interface{}
}()
// Create a channel for messages
msgChan := make(chan *pubsub.Message, msgChanSize)
// close the msgChan if we exit this function
defer close(msgChan)
full := 0
// Start worker goroutines
for i := 0; i < numWorkers; i++ {
Expand All @@ -154,7 +156,6 @@ func (g *PubsubManager) Subscribe(location common.Location, datatype interface{}
if err != nil || msg == nil {
// if context was cancelled, then we are shutting down
if g.ctx.Err() != nil {
close(msgChan)
return
}
log.Global.Errorf("error getting next message from subscription: %s", err)
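In gossipsub.go, close(msgChan) moves from the context-cancelled branch into a defer placed right after the channel is created, so the worker goroutines' range loops terminate on every exit path rather than only one. A small sketch of the defer-close pattern (the producer and worker here are illustrative):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	msgChan := make(chan string, 10)
	var wg sync.WaitGroup

	// Worker: exits cleanly when msgChan is closed.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for msg := range msgChan {
			fmt.Println("handled:", msg)
		}
	}()

	// Producer: deferring the close right after creation (as Subscribe now
	// does) guarantees the worker's range loop ends on every exit path.
	func() {
		defer close(msgChan)
		for _, m := range []string{"a", "b", "c"} {
			msgChan <- m
		}
	}()

	wg.Wait()
}
```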
2 changes: 1 addition & 1 deletion p2p/node/requestManager/requestManager.go
@@ -57,7 +57,7 @@ func NewManager() RequestManager {
}
}

// Generates a new random uint32 as request ID
// CreateRequest generates a new random uint32 as request ID
func (m *requestIDManager) CreateRequest() uint32 {
var id uint32
for {
28 changes: 23 additions & 5 deletions p2p/protocol/handler.go
@@ -1,10 +1,12 @@
package protocol

import (
"context"
"errors"
"io"
"math/big"
"runtime/debug"
"sync"

"github.com/libp2p/go-libp2p/core/network"
"github.com/sirupsen/logrus"
@@ -16,8 +18,11 @@ import (
"github.com/dominant-strategies/go-quai/trie"
)

const numWorkers = 10 // Number of workers per stream
const msgChanSize = 10 // 10 requests per stream
const (
numWorkers = 10 // Number of workers per stream
msgChanSize = 100 // 100 requests per stream
protocolName = "quai"
)

// QuaiProtocolHandler handles all the incoming requests and responds with corresponding data
func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
@@ -42,11 +47,21 @@ func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
// Create a channel for messages
msgChan := make(chan []byte, msgChanSize)
full := 0

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var once sync.Once
// Start worker goroutines
for i := 0; i < numWorkers; i++ {
go func() {
for message := range msgChan { // This should exit when msgChan is closed
handleMessage(message, stream, node)
for {
select {
case message := <-msgChan:
handleMessage(message, stream, node)
case <-ctx.Done():
once.Do(func() { close(msgChan) })
return
}
}
}()
}
@@ -56,7 +71,7 @@ func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
data, err := common.ReadMessageFromStream(stream)
if err != nil {
if errors.Is(err, network.ErrReset) || errors.Is(err, io.EOF) {
close(msgChan)
once.Do(func() { close(msgChan) })
return
}

@@ -68,6 +83,9 @@ func QuaiProtocolHandler(stream network.Stream, node QuaiP2PNode) {
// Send to worker goroutines
select {
case msgChan <- data:
case <-ctx.Done():
once.Do(func() { close(msgChan) })
return
default:
if full%1000 == 0 {
log.Global.WithField("stream with peer", stream.Conn().RemotePeer()).Warnf("QuaiProtocolHandler message channel is full. Lost messages: %d", full)
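In handler.go, msgChan can now be closed from several exit paths (stream reset or EOF, context cancellation in a worker or during a send), so every close is funneled through a sync.Once to rule out a double-close panic; the channel also grows from 10 to 100 slots. A minimal sketch of the once-guarded close, with the close sites consolidated into one reader goroutine (handleStream and read are illustrative stand-ins for the stream plumbing):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
)

// handleStream mimics the handler's shape: a reader loop feeds msgChan, and
// several exit paths (EOF, context cancellation) may each try to close it,
// so the close is wrapped in sync.Once to prevent a double-close panic.
func handleStream(ctx context.Context, read func() (string, error)) {
	msgChan := make(chan string, 100)
	var once sync.Once
	closeChan := func() { once.Do(func() { close(msgChan) }) }

	var wg sync.WaitGroup
	wg.Add(1)
	go func() { // worker: drains msgChan until it is closed
		defer wg.Done()
		for msg := range msgChan {
			fmt.Println("handled:", msg)
		}
	}()

	for {
		data, err := read()
		if err != nil {
			if errors.Is(err, io.EOF) {
				closeChan() // exit path 1: the stream ended
				break
			}
			continue // transient read error: keep going
		}
		select {
		case msgChan <- data:
		case <-ctx.Done():
			closeChan() // exit path 2: shutdown while sending
			wg.Wait()
			return
		}
	}
	wg.Wait()
}

func main() {
	msgs := []string{"block", "header", "tx"}
	i := 0
	read := func() (string, error) {
		if i == len(msgs) {
			return "", io.EOF
		}
		i++
		return msgs[i-1], nil
	}
	handleStream(context.Background(), read)
}
```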
