diff --git a/persist/sqlite/peers.go b/persist/sqlite/peers.go index 0de01dd..d29aed3 100644 --- a/persist/sqlite/peers.go +++ b/persist/sqlite/peers.go @@ -21,8 +21,8 @@ func getPeerInfo(tx txn, peer string) (syncer.PeerInfo, error) { } func (s *Store) updatePeerInfo(tx txn, peer string, info syncer.PeerInfo) error { - const query = `UPDATE syncer_peers SET first_seen=$2, last_connect=$3, synced_blocks=$4, sync_duration=$5 WHERE peer_address=$1` - _, err := tx.Exec(query, peer, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration) + const query = `UPDATE syncer_peers SET first_seen=$1, last_connect=$2, synced_blocks=$3, sync_duration=$4 WHERE peer_address=$5 RETURNING peer_address` + err := tx.QueryRow(query, (*sqlTime)(&info.FirstSeen), (*sqlTime)(&info.LastConnect), info.SyncedBlocks, info.SyncDuration, peer).Scan(&peer) return err } @@ -149,16 +149,23 @@ func (s *Store) Banned(peer string) (banned bool) { checkSubnets := make([]string, 0, maxMaskLen) for i := maxMaskLen; i > 0; i-- { - check := subnet.IP.String() + "/" + strconv.Itoa(i) - checkSubnets = append(checkSubnets, check) + _, subnet, err := net.ParseCIDR(subnet.IP.String() + "/" + strconv.Itoa(i)) + if err != nil { + panic("failed to parse CIDR") + } + checkSubnets = append(checkSubnets, subnet.String()) } err = s.transaction(func(tx txn) error { query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1` + var subnet string var expiration time.Time - err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan((*sqlTime)(&expiration)) + err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, (*sqlTime)(&expiration)) banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows + if err == nil && banned { + s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration)) + } return err }) if err != nil && !errors.Is(err, sql.ErrNoRows) { diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go new file mode 100644 index 0000000..2de6d2f --- /dev/null +++ b/persist/sqlite/peers_test.go @@ -0,0 +1,101 @@ +package sqlite + +import ( + "net" + "path/filepath" + "testing" + "time" + + "go.sia.tech/walletd/syncer" + "go.uber.org/zap/zaptest" +) + +func TestAddPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + const peer = "1.2.3.4:9981" + + if err := db.AddPeer(peer); err != nil { + t.Fatal(err) + } + + lastConnect := time.Now().Truncate(time.Second) // stored as unix milliseconds + syncedBlocks := uint64(15) + syncDuration := 5 * time.Second + + err = db.UpdatePeerInfo(peer, func(info *syncer.PeerInfo) { + info.LastConnect = lastConnect + info.SyncedBlocks = syncedBlocks + info.SyncDuration = syncDuration + }) + if err != nil { + t.Fatal(err) + } + + info, err := db.PeerInfo(peer) + if err != nil { + t.Fatal(err) + } + + if !info.LastConnect.Equal(lastConnect) { + t.Errorf("expected LastConnect = %v; got %v", lastConnect, info.LastConnect) + } + if info.SyncedBlocks != syncedBlocks { + t.Errorf("expected SyncedBlocks = %d; got %d", syncedBlocks, info.SyncedBlocks) + } + if info.SyncDuration != 5*time.Second { + t.Errorf("expected SyncDuration = %s; got %s", syncDuration, info.SyncDuration) + } +} + +func TestBanPeer(t *testing.T) { + log := zaptest.NewLogger(t) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + const peer = "1.2.3.4" + + if db.Banned(peer) { + t.Fatal("expected peer to not be banned") + } + + // ban the peer + if err := db.Ban(peer, time.Second, "test"); err != nil { + t.Fatal(err) + } + + if !db.Banned(peer) { + t.Fatal("expected peer to be banned") + } + + // wait for the ban to expire + time.Sleep(time.Second) + + if db.Banned(peer) { + t.Fatal("expected peer to not be banned") + } + + // ban a subnet + _, subnet, err := net.ParseCIDR(peer + "/24") + if err != nil { + t.Fatal(err) + } + + t.Log("banning", subnet) + + if err := db.Ban(subnet.String(), time.Second, "test"); err != nil { + t.Fatal(err) + } + + if !db.Banned(peer) { + t.Fatal("expected peer to be banned") + } +}