Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jul 18, 2024
1 parent 27c391e commit 7d698e2
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 119 deletions.
9 changes: 6 additions & 3 deletions crates/chia-client/examples/peer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::env;
use std::{env, net::SocketAddr};

use chia_client::{create_tls_connector, Peer};
use chia_protocol::{Handshake, NodeType};
Expand All @@ -9,8 +9,11 @@ use chia_traits::Streamable;
async fn main() -> anyhow::Result<()> {
let ssl = ChiaCertificate::generate()?;
let tls_connector = create_tls_connector(ssl.cert_pem.as_bytes(), ssl.key_pem.as_bytes())?;
let (peer, mut receiver) =
Peer::connect(env::var("PEER")?.parse()?, 58444, tls_connector).await?;
let (peer, mut receiver) = Peer::connect(
SocketAddr::new(env::var("PEER")?.parse()?, 58444),
tls_connector,
)
.await?;

peer.send(Handshake {
network_id: "testnet11".to_string(),
Expand Down
9 changes: 6 additions & 3 deletions crates/chia-client/examples/peer_discovery.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Duration;
use std::{net::SocketAddr, time::Duration};

use chia_client::{create_tls_connector, Peer};
use chia_protocol::{Handshake, NodeType, ProtocolMessageTypes};
Expand All @@ -15,8 +15,11 @@ async fn main() -> anyhow::Result<()> {
let tls = create_tls_connector(cert.cert_pem.as_bytes(), cert.key_pem.as_bytes())?;

for ip in lookup_host("dns-introducer.chia.net")? {
let Ok(response) =
timeout(Duration::from_secs(3), Peer::connect(ip, 8444, tls.clone())).await
let Ok(response) = timeout(
Duration::from_secs(3),
Peer::connect(SocketAddr::new(ip, 8444), tls.clone()),
)
.await
else {
log::info!("{ip} exceeded connection timeout of 3 seconds");
continue;
Expand Down
190 changes: 93 additions & 97 deletions crates/chia-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
collections::{HashMap, HashSet},
net::IpAddr,
net::{IpAddr, SocketAddr},
str::FromStr,
sync::Arc,
time::Duration,
Expand Down Expand Up @@ -94,23 +94,19 @@ impl Client {
self.0.peers.lock().await.is_empty()
}

pub async fn peer_ids(&self) -> Vec<PeerId> {
self.0.peers.lock().await.keys().copied().collect()
}

pub async fn peers(&self) -> Vec<Peer> {
self.0.peers.lock().await.values().cloned().collect()
pub async fn peer_map(&self) -> HashMap<PeerId, Peer> {
self.0.peers.lock().await.clone()
}

pub async fn peer(&self, peer_id: PeerId) -> Option<Peer> {
self.0.peers.lock().await.get(&peer_id).cloned()
}

pub async fn remove_peer(&self, peer_id: PeerId) -> Option<Peer> {
self.0.peers.lock().await.remove(&peer_id)
pub async fn disconnect_peer(&self, peer_id: PeerId) {
self.0.peers.lock().await.remove(&peer_id);
}

pub async fn clear(&self) {
pub async fn disconnect_all(&self) {
self.0.peers.lock().await.clear();
}

Expand Down Expand Up @@ -158,12 +154,12 @@ impl Client {
let Ok(Ok(response)): std::result::Result<Result<RespondPeers>, _> =
timeout(self.0.options.request_peers_timeout, peer.request_peers()).await
else {
log::info!("Failed to request peers from {}", peer.ip_addr());
self.remove_peer(peer_id).await;
log::info!("Failed to request peers from {}", peer.socket_addr());
self.disconnect_peer(peer_id).await;
continue;
};

log::info!("Requested peers from {}", peer.ip_addr());
log::info!("Requested peers from {}", peer.socket_addr());

let mut ips = HashSet::new();

Expand All @@ -173,8 +169,7 @@ impl Client {
log::debug!("Failed to parse IP address {}", item.host);
continue;
};

ips.insert((ip_addr, item.port));
ips.insert(SocketAddr::new(ip_addr, item.port));
}

// Keep connecting peers until the peer list is exhausted,
Expand All @@ -199,134 +194,84 @@ impl Client {
async fn connect_dns(&self) {
log::info!("Requesting peers from DNS introducer");

let mut ips = Vec::new();
let mut socket_addrs = Vec::new();

for dns_introducer in &self.0.options.network.dns_introducers {
// If a DNS introducer lookup fails, we just skip it.
let Ok(result) = lookup_host(dns_introducer) else {
log::warn!("Failed to lookup DNS introducer `{dns_introducer}`");
continue;
};
ips.extend(result);
socket_addrs.extend(
result
.into_iter()
.map(|ip| SocketAddr::new(ip, self.0.options.network.default_port)),
);
}

// Shuffle the list of IPs so that we don't always connect to the same ones.
// This also prevents bias towards IPv4 or IPv6.
ips.as_mut_slice().shuffle(&mut thread_rng());
socket_addrs.as_mut_slice().shuffle(&mut thread_rng());

// Keep track of where we are in the peer list.
let mut cursor = 0;

while self.len().await < self.0.options.target_peers {
// If we've reached the end of the list of IPs, stop early.
if cursor >= ips.len() {
if cursor >= socket_addrs.len() {
break;
}

// Get the remaining peers we can connect to, up to the concurrency limit.
let peers_to_try = &ips[cursor
..ips
let new_addrs = &socket_addrs[cursor
..socket_addrs
.len()
.min(cursor + self.0.options.connection_concurrency)];

// Increment the cursor by the number of peers we're trying to connect to.
cursor += peers_to_try.len();

self.connect_peers(
peers_to_try
.iter()
.map(|ip| (*ip, self.0.options.network.default_port))
.collect(),
)
.await;
cursor += new_addrs.len();

self.connect_peers(new_addrs.to_vec()).await;
}
}

async fn connect_peers(&self, potential_ips: Vec<(IpAddr, u16)>) {
let peer_lock = self.0.peers.lock().await;
let peers = peer_lock.clone();
drop(peer_lock);

async fn connect_peers(&self, socket_addrs: Vec<SocketAddr>) {
// Add the connections and wait for them to complete.
let mut connections = FuturesUnordered::new();

for (ip, port) in potential_ips {
if peers.iter().any(|(_, peer)| peer.ip_addr() == ip) {
let peers = self.peer_map().await;

for socket_addr in socket_addrs {
if peers
.iter()
.any(|(_, peer)| peer.socket_addr().ip() == socket_addr.ip())
{
continue;
}

connections.push(async move {
self.connect_peer(ip, port)
.await
.map_err(|error| (ip, port, error))
});
connections.push(async move { (socket_addr, self.connect_peer(socket_addr).await) });
}

while let Some(result) = connections.next().await {
while let Some((socket_addr, result)) = connections.next().await {
if self.len().await >= self.0.options.target_peers {
break;
}

let (peer, mut receiver) = match result {
Ok(result) => result,
Err((ip, port, error)) => {
log::warn!(
"{error} for peer {}",
if ip.is_ipv4() {
format!("{ip}:{port}")
} else {
format!("[{ip}]:{port}")
}
);
continue;
}
};

let ip = peer.ip_addr();
let peer_id = peer.peer_id();
self.0.peers.lock().await.insert(peer_id, peer);

let message_sender = self.0.message_sender.clone();
let peer_map = self.0.peers.clone();

// Spawn a task to propagate messages from the peer.
tokio::spawn(async move {
while let Some(message) = receiver.recv().await {
if let Err(error) = message_sender
.lock()
.await
.send(Event::Message(peer_id, message))
.await
{
log::warn!("Failed to send client message event: {error}");
break;
}
}

peer_map.lock().await.remove(&peer_id);

if let Err(error) = message_sender
.lock()
.await
.send(Event::ConnectionClosed(peer_id))
.await
{
log::warn!("Failed to send client connection closed event: {error}");
}

log::info!("Peer {ip} disconnected");
});
if let Err(error) = result {
log::warn!("Failed to connect to peer {socket_addr} with error: {error}",);
continue;
}

log::info!("Connected to peer {ip}");
log::info!("Connected to peer {socket_addr}");
}
}

async fn connect_peer(&self, ip: IpAddr, port: u16) -> Result<(Peer, mpsc::Receiver<Message>)> {
log::debug!("Connecting to peer {ip}");
pub async fn connect_peer(&self, socket_addr: SocketAddr) -> Result<PeerId> {
log::debug!("Connecting to peer {socket_addr}");

let (peer, mut receiver) = timeout(
self.0.options.connection_timeout,
Peer::connect(ip, port, self.0.tls_connector.clone()),
Peer::connect(socket_addr, self.0.tls_connector.clone()),
)
.await
.map_err(Error::ConnectionTimeout)??;
Expand Down Expand Up @@ -371,6 +316,57 @@ impl Client {
));
}

Ok((peer, receiver))
self.add_peer(peer, receiver).await
}

pub async fn add_peer(
&self,
peer: Peer,
mut receiver: mpsc::Receiver<Message>,
) -> Result<PeerId> {
let socket_addr = peer.socket_addr();
let peer_id = peer.peer_id();

self.0.peers.lock().await.insert(peer_id, peer);

self.0
.message_sender
.lock()
.await
.send(Event::Connected(peer_id))
.await?;

// Spawn a task to propagate messages from the peer.
let message_sender = self.0.message_sender.clone();
let peer_map = self.0.peers.clone();

tokio::spawn(async move {
while let Some(message) = receiver.recv().await {
if let Err(error) = message_sender
.lock()
.await
.send(Event::Message(peer_id, message))
.await
{
log::warn!("Failed to send client message event: {error}");
break;
}
}

peer_map.lock().await.remove(&peer_id);

if let Err(error) = message_sender
.lock()
.await
.send(Event::Disconnected(socket_addr))
.await
{
log::warn!("Failed to send client connection closed event: {error}");
}

log::info!("Peer {socket_addr} disconnected");
});

Ok(peer_id)
}
}
9 changes: 8 additions & 1 deletion crates/chia-client/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use chia_protocol::ProtocolMessageTypes;
use semver::Version;
use thiserror::Error;
use tokio::{sync::oneshot::error::RecvError, time::error::Elapsed};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::oneshot::error::RecvError;
use tokio::time::error::Elapsed;

use crate::Event;

#[derive(Debug, Error)]
pub enum Error {
Expand Down Expand Up @@ -38,6 +42,9 @@ pub enum Error {
#[error("Failed to send event")]
EventNotSent,

#[error("Failed to send message")]
Send(#[from] SendError<Event>),

#[error("Failed to receive message")]
Recv(#[from] RecvError),

Expand Down
5 changes: 4 additions & 1 deletion crates/chia-client/src/event.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::net::SocketAddr;

use chia_protocol::Message;

use crate::PeerId;

#[derive(Debug, Clone)]
pub enum Event {
Message(PeerId, Message),
ConnectionClosed(PeerId),
Connected(PeerId),
Disconnected(SocketAddr),
}
Loading

0 comments on commit 7d698e2

Please sign in to comment.