Skip to content

Commit

Permalink
sqlite: add peer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
n8maninger committed Jan 11, 2024
1 parent 7cc0a2c commit 3b7f51d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
17 changes: 12 additions & 5 deletions persist/sqlite/peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) {
}

func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error {
const query = `UPDATE syncer_peers SET first_seen=$2, last_connect=$3, synced_blocks=$4, sync_duration=$5 WHERE peer_address=$1`
_, err := tx.Exec(query, peer, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration)
const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address`
err := tx.QueryRow(query, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer)
return err
}

Expand Down Expand Up @@ -149,16 +149,23 @@ func (s *Store) Banned(peer string) (banned bool) {

checkSubnets := make([]string, 0, maxMaskLen)
for i := maxMaskLen; i > 0; i-- {
check := subnet.IP.String() + "/" + strconv.Itoa(i)
checkSubnets = append(checkSubnets, check)
_, subnet, err := net.ParseCIDR(subnet.IP.String() + "/" + strconv.Itoa(i))
if err != nil {
panic("failed to parse CIDR")
}
checkSubnets = append(checkSubnets, subnet.String())
}

err = s.transaction(func(tx txn) error {
query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1`

var subnet string
var expiration time.Time
err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan((*sqlTime)(&expiration))
err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, (*sqlTime)(&expiration))
banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows
if err == nil && banned {
s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration))
}
return err
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
Expand Down
101 changes: 101 additions & 0 deletions persist/sqlite/peers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package sqlite

import (
"net"
"path/filepath"
"testing"
"time"

"go.sia.tech/walletd/syncer"
"go.uber.org/zap/zaptest"
)

func TestAddPeer(t *testing.T) {
log := zaptest.NewLogger(t)
db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log)
if err != nil {
t.Fatal(err)
}
defer db.Close()

const peer = "1.2.3.4:9981"

if err := db.AddPeer(peer); err != nil {
t.Fatal(err)
}

lastConnect := time.Now().Truncate(time.Second) // stored as unix milliseconds
syncedBlocks := uint64(15)
syncDuration := 5 * time.Second

err = db.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) {
info.LastConnect = lastConnect
info.SyncedBlocks = syncedBlocks
info.SyncDuration = syncDuration
})
if err != nil {
t.Fatal(err)
}

info, err := db.PeerInfo(peer)
if err != nil {
t.Fatal(err)
}

if !info.LastConnect.Equal(lastConnect) {
t.Errorf("expected LastConnect = %v; got %v", lastConnect, info.LastConnect)
}
if info.SyncedBlocks != syncedBlocks {
t.Errorf("expected SyncedBlocks = %d; got %d", syncedBlocks, info.SyncedBlocks)
}
if info.SyncDuration != 5*time.Second {
t.Errorf("expected SyncDuration = %s; got %s", syncDuration, info.SyncDuration)
}
}

func TestBanPeer(t *testing.T) {
log := zaptest.NewLogger(t)
db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log)
if err != nil {
t.Fatal(err)
}
defer db.Close()

const peer = "1.2.3.4"

if db.Banned(peer) {
t.Fatal("expected peer to not be banned")
}

// ban the peer
if err := db.Ban(peer, time.Second, "test"); err != nil {
t.Fatal(err)
}

if !db.Banned(peer) {
t.Fatal("expected peer to be banned")
}

// wait for the ban to expire
time.Sleep(time.Second)

if db.Banned(peer) {
t.Fatal("expected peer to not be banned")
}

// ban a subnet
_, subnet, err := net.ParseCIDR(peer + "/24")
if err != nil {
t.Fatal(err)
}

t.Log("banning", subnet)

if err := db.Ban(subnet.String(), time.Second, "test"); err != nil {
t.Fatal(err)
}

if !db.Banned(peer) {
t.Fatal("expected peer to be banned")
}
}

0 comments on commit 3b7f51d

Please sign in to comment.