Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: race condition when adding new channel to NodeInfo #735

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions internal/libs/sync/concurrent_slice.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package sync

import (
"encoding/json"
"sync"
)

// ConcurrentSlice is a thread-safe slice.
//
// It is safe to use from multiple goroutines without additional locking.
// It should be referenced by pointer.
//
// Initialize using NewConcurrentSlice().
type ConcurrentSlice[T any] struct {
mtx sync.RWMutex
items []T
}

// NewConcurrentSlice creates a new thread-safe slice.
func NewConcurrentSlice[T any](initial ...T) *ConcurrentSlice[T] {
return &ConcurrentSlice[T]{
items: initial,
}
}

// Append adds an element to the slice
func (s *ConcurrentSlice[T]) Append(val ...T) {
s.mtx.Lock()
defer s.mtx.Unlock()

s.items = append(s.items, val...)
}

// Reset removes all elements from the slice
func (s *ConcurrentSlice[T]) Reset() {
s.mtx.Lock()
defer s.mtx.Unlock()

s.items = []T{}
}

// Get returns the value at the given index
func (s *ConcurrentSlice[T]) Get(index int) T {
s.mtx.RLock()
defer s.mtx.RUnlock()

return s.items[index]
}

// Set updates the value at the given index.
// If the index is greater than the length of the slice, it panics.
// If the index is equal to the length of the slice, the value is appended.
// Otherwise, the value at the index is updated.
func (s *ConcurrentSlice[T]) Set(index int, val T) {
s.mtx.Lock()
defer s.mtx.Unlock()

if index > len(s.items) {
panic("index out of range")
} else if index == len(s.items) {
s.items = append(s.items, val)
return
}

s.items[index] = val
}

// ToSlice returns a copy of the underlying slice
func (s *ConcurrentSlice[T]) ToSlice() []T {
s.mtx.RLock()
defer s.mtx.RUnlock()

slice := make([]T, len(s.items))
copy(slice, s.items)
return slice
}

// Len returns the length of the slice
func (s *ConcurrentSlice[T]) Len() int {
s.mtx.RLock()
defer s.mtx.RUnlock()

return len(s.items)
}

// Copy returns a new deep copy of concurrentSlice with the same elements
func (s *ConcurrentSlice[T]) Copy() ConcurrentSlice[T] {
s.mtx.RLock()
defer s.mtx.RUnlock()

return ConcurrentSlice[T]{
items: s.ToSlice(),
}
}

// MarshalJSON implements the json.Marshaler interface.
func (cs *ConcurrentSlice[T]) MarshalJSON() ([]byte, error) {
cs.mtx.RLock()
defer cs.mtx.RUnlock()

return json.Marshal(cs.items)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
func (cs *ConcurrentSlice[T]) UnmarshalJSON(data []byte) error {
var items []T
if err := json.Unmarshal(data, &items); err != nil {
return err
}

cs.mtx.Lock()
defer cs.mtx.Unlock()

cs.items = items
return nil
}
96 changes: 96 additions & 0 deletions internal/libs/sync/concurrent_slice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sync

import (
"encoding/json"
"sync"
"testing"

"github.com/stretchr/testify/assert"
)

func TestConcurrentSlice(t *testing.T) {
s := NewConcurrentSlice[int](1, 2, 3)

// Test Append
s.Append(4)
if s.Len() != 4 {
t.Errorf("Expected length of slice to be 4, got %d", s.Len())
}

// Test Get
if s.Get(3) != 4 {
t.Errorf("Expected element at index 3 to be 4, got %d", s.Get(3))
}

// Test Set
s.Set(1, 5)

// Test ToSlice
slice := s.ToSlice()
if len(slice) != 4 || slice[3] != 4 || slice[1] != 5 {
t.Errorf("Expected ToSlice to return [1 5 3 4], got %v", slice)
}

// Test Reset
s.Reset()
if s.Len() != 0 {
t.Errorf("Expected length of slice to be 0 after Reset, got %d", s.Len())
}

// Test Copy
s.Append(5)
copy := s.Copy()
if copy.Len() != 1 || copy.Get(0) != 5 {
t.Errorf("Expected Copy to return a new slice with [5], got %v", copy.ToSlice())
}
}

func TestConcurrentSlice_Concurrency(t *testing.T) {
s := NewConcurrentSlice[int]()

var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(val int) {
defer wg.Done()
s.Append(val)
}(i)
}

wg.Wait()

assert.Equal(t, 100, s.Len())

if s.Len() != 100 {
t.Errorf("Expected length of slice to be 100, got %d", s.Len())
}

for i := 0; i < 100; i++ {
assert.Contains(t, s.ToSlice(), i)
}
}

func TestConcurrentSlice_MarshalUnmarshalJSON(t *testing.T) {
type node struct {
Channels *ConcurrentSlice[uint16]
}
cs := NewConcurrentSlice[uint16](1, 2, 3)

node1 := node{
Channels: cs,
}

// Marshal to JSON
data, err := json.Marshal(node1)
assert.NoError(t, err, "Failed to marshal concurrentSlice")

// Unmarshal from JSON
node2 := node{
// Channels: NewConcurrentSlice[uint16](),
}

err = json.Unmarshal(data, &node2)
assert.NoError(t, err, "Failed to unmarshal concurrentSlice")

assert.EqualValues(t, node1.Channels.ToSlice(), node2.Channels.ToSlice())
}
5 changes: 3 additions & 2 deletions internal/p2p/p2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package p2p_test
import (
"github.com/dashpay/tenderdash/crypto"
"github.com/dashpay/tenderdash/crypto/ed25519"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/types"
)
Expand All @@ -25,7 +26,7 @@ var (
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(selfID),
Channels: []byte{0x01, 0x02},
Channels: tmsync.NewConcurrentSlice[uint16](0x01, 0x02),
}

peerKey crypto.PrivKey = ed25519.GenPrivKeyFromSecret([]byte{0x84, 0xd7, 0x01, 0xbf, 0x83, 0x20, 0x1c, 0xfe})
Expand All @@ -35,6 +36,6 @@ var (
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(peerID),
Channels: []byte{0x01, 0x02},
Channels: tmsync.NewConcurrentSlice[uint16](0x01, 0x02),
}
)
2 changes: 2 additions & 0 deletions internal/p2p/p2ptest/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/crypto"
"github.com/dashpay/tenderdash/crypto/ed25519"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
p2pclient "github.com/dashpay/tenderdash/internal/p2p/client"
"github.com/dashpay/tenderdash/libs/log"
Expand Down Expand Up @@ -272,6 +273,7 @@ func (n *Network) MakeNode(ctx context.Context, t *testing.T, proTxHash crypto.P
ListenAddr: "0.0.0.0:0", // FIXME: We have to fake this for now.
Moniker: string(nodeID),
ProTxHash: proTxHash.Copy(),
Channels: tmsync.NewConcurrentSlice[uint16](),
}

transport := n.memoryNetwork.CreateTransport(nodeID)
Expand Down
10 changes: 5 additions & 5 deletions internal/p2p/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func (r *Router) openConnection(ctx context.Context, conn Connection) {
return
}

r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels))
r.routePeer(ctx, peerInfo.NodeID, conn, toChannelIDs(peerInfo.Channels.ToSlice()))
}

// dialPeers maintains outbound connections to peers by dialing them.
Expand Down Expand Up @@ -589,7 +589,7 @@ func (r *Router) connectPeer(ctx context.Context, address NodeAddress) {
}

// routePeer (also) calls connection close
go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels))
go r.routePeer(ctx, address.NodeID, conn, toChannelIDs(peerInfo.Channels.ToSlice()))
}

func (r *Router) getOrMakeQueue(peerID types.NodeID, channels ChannelIDSet) queue {
Expand Down Expand Up @@ -943,9 +943,9 @@ func (cs ChannelIDSet) Contains(id ChannelID) bool {
return ok
}

func toChannelIDs(bytes []byte) ChannelIDSet {
c := make(map[ChannelID]struct{}, len(bytes))
for _, b := range bytes {
func toChannelIDs(ids []uint16) ChannelIDSet {
c := make(map[ChannelID]struct{}, len(ids))
for _, b := range ids {
c[ChannelID(b)] = struct{}{}
}
return c
Expand Down
7 changes: 5 additions & 2 deletions internal/p2p/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
dbm "github.com/tendermint/tm-db"

"github.com/dashpay/tenderdash/crypto"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/internal/p2p/mocks"
"github.com/dashpay/tenderdash/internal/p2p/p2ptest"
Expand Down Expand Up @@ -303,6 +304,7 @@ func TestRouter_AcceptPeers(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "other-network",
Moniker: string(peerID),
Channels: tmsync.NewConcurrentSlice[uint16](),
},
peerKey.PubKey(),
false,
Expand Down Expand Up @@ -504,6 +506,7 @@ func TestRouter_DialPeers(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "other-network",
Moniker: string(peerID),
Channels: tmsync.NewConcurrentSlice[uint16](),
},
peerKey.PubKey(),
nil,
Expand Down Expand Up @@ -766,7 +769,7 @@ func TestRouter_ChannelCompatability(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(peerID),
Channels: []byte{0x03},
Channels: tmsync.NewConcurrentSlice[uint16](0x03),
}

mockConnection := &mocks.Connection{}
Expand Down Expand Up @@ -817,7 +820,7 @@ func TestRouter_DontSendOnInvalidChannel(t *testing.T) {
ListenAddr: "0.0.0.0:0",
Network: "test",
Moniker: string(peerID),
Channels: []byte{0x02},
Channels: tmsync.NewConcurrentSlice[uint16](0x02),
}

mockConnection := &mocks.Connection{}
Expand Down
13 changes: 8 additions & 5 deletions internal/p2p/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require"

"github.com/dashpay/tenderdash/crypto/ed25519"
tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/libs/bytes"
"github.com/dashpay/tenderdash/types"
)

Expand Down Expand Up @@ -283,15 +283,18 @@ func TestConnection_Handshake(t *testing.T) {
ListenAddr: "listenaddr",
Network: "network",
Version: "1.2.3",
Channels: bytes.HexBytes([]byte{0xf0, 0x0f}),
Channels: tmsync.NewConcurrentSlice[uint16](0xf0, 0x0f),
Moniker: "moniker",
Other: types.NodeInfoOther{
TxIndex: "txindex",
RPCAddress: "rpc.domain.com",
},
}
bKey := ed25519.GenPrivKey()
bInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(bKey.PubKey())}
bInfo := types.NodeInfo{
NodeID: types.NodeIDFromPubKey(bKey.PubKey()),
Channels: tmsync.NewConcurrentSlice[uint16](),
}

errCh := make(chan error, 1)
go func() {
Expand Down Expand Up @@ -641,13 +644,13 @@ func dialAcceptHandshake(ctx context.Context, t *testing.T, a, b p2p.Transport)
errCh := make(chan error, 1)
go func() {
privKey := ed25519.GenPrivKey()
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey()), Channels: tmsync.NewConcurrentSlice[uint16]()}
_, _, err := ba.Handshake(ctx, 0, nodeInfo, privKey)
errCh <- err
}()

privKey := ed25519.GenPrivKey()
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey())}
nodeInfo := types.NodeInfo{NodeID: types.NodeIDFromPubKey(privKey.PubKey()), Channels: tmsync.NewConcurrentSlice[uint16]()}
_, _, err := ab.Handshake(ctx, 0, nodeInfo, privKey)
require.NoError(t, err)

Expand Down
Loading
Loading