From 7d698e278264d9453d17df2d21f6cb321d9606c5 Mon Sep 17 00:00:00 2001 From: Rigidity Date: Thu, 18 Jul 2024 13:49:22 -0400 Subject: [PATCH] Refactor --- crates/chia-client/examples/peer.rs | 9 +- crates/chia-client/examples/peer_discovery.rs | 9 +- crates/chia-client/src/client.rs | 190 +++++++++--------- crates/chia-client/src/error.rs | 9 +- crates/chia-client/src/event.rs | 5 +- crates/chia-client/src/peer.rs | 22 +- 6 files changed, 125 insertions(+), 119 deletions(-) diff --git a/crates/chia-client/examples/peer.rs b/crates/chia-client/examples/peer.rs index c4f94a18e..d18069903 100644 --- a/crates/chia-client/examples/peer.rs +++ b/crates/chia-client/examples/peer.rs @@ -1,4 +1,4 @@ -use std::env; +use std::{env, net::SocketAddr}; use chia_client::{create_tls_connector, Peer}; use chia_protocol::{Handshake, NodeType}; @@ -9,8 +9,11 @@ use chia_traits::Streamable; async fn main() -> anyhow::Result<()> { let ssl = ChiaCertificate::generate()?; let tls_connector = create_tls_connector(ssl.cert_pem.as_bytes(), ssl.key_pem.as_bytes())?; - let (peer, mut receiver) = - Peer::connect(env::var("PEER")?.parse()?, 58444, tls_connector).await?; + let (peer, mut receiver) = Peer::connect( + SocketAddr::new(env::var("PEER")?.parse()?, 58444), + tls_connector, + ) + .await?; peer.send(Handshake { network_id: "testnet11".to_string(), diff --git a/crates/chia-client/examples/peer_discovery.rs b/crates/chia-client/examples/peer_discovery.rs index 3791ef49c..a1f56e001 100644 --- a/crates/chia-client/examples/peer_discovery.rs +++ b/crates/chia-client/examples/peer_discovery.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{net::SocketAddr, time::Duration}; use chia_client::{create_tls_connector, Peer}; use chia_protocol::{Handshake, NodeType, ProtocolMessageTypes}; @@ -15,8 +15,11 @@ async fn main() -> anyhow::Result<()> { let tls = create_tls_connector(cert.cert_pem.as_bytes(), cert.key_pem.as_bytes())?; for ip in lookup_host("dns-introducer.chia.net")? { - let Ok(response) = - timeout(Duration::from_secs(3), Peer::connect(ip, 8444, tls.clone())).await + let Ok(response) = timeout( + Duration::from_secs(3), + Peer::connect(SocketAddr::new(ip, 8444), tls.clone()), + ) + .await else { log::info!("{ip} exceeded connection timeout of 3 seconds"); continue; diff --git a/crates/chia-client/src/client.rs b/crates/chia-client/src/client.rs index ec9e1eacb..484b8eed8 100644 --- a/crates/chia-client/src/client.rs +++ b/crates/chia-client/src/client.rs @@ -1,6 +1,6 @@ use std::{ collections::{HashMap, HashSet}, - net::IpAddr, + net::{IpAddr, SocketAddr}, str::FromStr, sync::Arc, time::Duration, @@ -94,23 +94,19 @@ impl Client { self.0.peers.lock().await.is_empty() } - pub async fn peer_ids(&self) -> Vec { - self.0.peers.lock().await.keys().copied().collect() - } - - pub async fn peers(&self) -> Vec { - self.0.peers.lock().await.values().cloned().collect() + pub async fn peer_map(&self) -> HashMap { + self.0.peers.lock().await.clone() } pub async fn peer(&self, peer_id: PeerId) -> Option { self.0.peers.lock().await.get(&peer_id).cloned() } - pub async fn remove_peer(&self, peer_id: PeerId) -> Option { - self.0.peers.lock().await.remove(&peer_id) + pub async fn disconnect_peer(&self, peer_id: PeerId) { + self.0.peers.lock().await.remove(&peer_id); } - pub async fn clear(&self) { + pub async fn disconnect_all(&self) { self.0.peers.lock().await.clear(); } @@ -158,12 +154,12 @@ impl Client { let Ok(Ok(response)): std::result::Result, _> = timeout(self.0.options.request_peers_timeout, peer.request_peers()).await else { - log::info!("Failed to request peers from {}", peer.ip_addr()); - self.remove_peer(peer_id).await; + log::info!("Failed to request peers from {}", peer.socket_addr()); + self.disconnect_peer(peer_id).await; continue; }; - log::info!("Requested peers from {}", peer.ip_addr()); + log::info!("Requested peers from {}", peer.socket_addr()); let mut ips = HashSet::new(); @@ -173,8 +169,7 @@ impl Client { log::debug!("Failed to parse IP address {}", item.host); continue; }; - - ips.insert((ip_addr, item.port)); + ips.insert(SocketAddr::new(ip_addr, item.port)); } // Keep connecting peers until the peer list is exhausted, @@ -199,7 +194,7 @@ impl Client { async fn connect_dns(&self) { log::info!("Requesting peers from DNS introducer"); - let mut ips = Vec::new(); + let mut socket_addrs = Vec::new(); for dns_introducer in &self.0.options.network.dns_introducers { // If a DNS introducer lookup fails, we just skip it. @@ -207,126 +202,76 @@ impl Client { log::warn!("Failed to lookup DNS introducer `{dns_introducer}`"); continue; }; - ips.extend(result); + socket_addrs.extend( + result + .into_iter() + .map(|ip| SocketAddr::new(ip, self.0.options.network.default_port)), + ); } // Shuffle the list of IPs so that we don't always connect to the same ones. // This also prevents bias towards IPv4 or IPv6. - ips.as_mut_slice().shuffle(&mut thread_rng()); + socket_addrs.as_mut_slice().shuffle(&mut thread_rng()); // Keep track of where we are in the peer list. let mut cursor = 0; while self.len().await < self.0.options.target_peers { // If we've reached the end of the list of IPs, stop early. - if cursor >= ips.len() { + if cursor >= socket_addrs.len() { break; } // Get the remaining peers we can connect to, up to the concurrency limit. - let peers_to_try = &ips[cursor - ..ips + let new_addrs = &socket_addrs[cursor + ..socket_addrs .len() .min(cursor + self.0.options.connection_concurrency)]; // Increment the cursor by the number of peers we're trying to connect to. - cursor += peers_to_try.len(); - - self.connect_peers( - peers_to_try - .iter() - .map(|ip| (*ip, self.0.options.network.default_port)) - .collect(), - ) - .await; + cursor += new_addrs.len(); + + self.connect_peers(new_addrs.to_vec()).await; } } - async fn connect_peers(&self, potential_ips: Vec<(IpAddr, u16)>) { - let peer_lock = self.0.peers.lock().await; - let peers = peer_lock.clone(); - drop(peer_lock); - + async fn connect_peers(&self, socket_addrs: Vec) { // Add the connections and wait for them to complete. let mut connections = FuturesUnordered::new(); - for (ip, port) in potential_ips { - if peers.iter().any(|(_, peer)| peer.ip_addr() == ip) { + let peers = self.peer_map().await; + + for socket_addr in socket_addrs { + if peers + .iter() + .any(|(_, peer)| peer.socket_addr().ip() == socket_addr.ip()) + { continue; } - connections.push(async move { - self.connect_peer(ip, port) - .await - .map_err(|error| (ip, port, error)) - }); + connections.push(async move { (socket_addr, self.connect_peer(socket_addr).await) }); } - while let Some(result) = connections.next().await { + while let Some((socket_addr, result)) = connections.next().await { if self.len().await >= self.0.options.target_peers { break; } - let (peer, mut receiver) = match result { - Ok(result) => result, - Err((ip, port, error)) => { - log::warn!( - "{error} for peer {}", - if ip.is_ipv4() { - format!("{ip}:{port}") - } else { - format!("[{ip}]:{port}") - } - ); - continue; - } - }; - - let ip = peer.ip_addr(); - let peer_id = peer.peer_id(); - self.0.peers.lock().await.insert(peer_id, peer); - - let message_sender = self.0.message_sender.clone(); - let peer_map = self.0.peers.clone(); - - // Spawn a task to propagate messages from the peer. - tokio::spawn(async move { - while let Some(message) = receiver.recv().await { - if let Err(error) = message_sender - .lock() - .await - .send(Event::Message(peer_id, message)) - .await - { - log::warn!("Failed to send client message event: {error}"); - break; - } - } - - peer_map.lock().await.remove(&peer_id); - - if let Err(error) = message_sender - .lock() - .await - .send(Event::ConnectionClosed(peer_id)) - .await - { - log::warn!("Failed to send client connection closed event: {error}"); - } - - log::info!("Peer {ip} disconnected"); - }); + if let Err(error) = result { + log::warn!("Failed to connect to peer {socket_addr} with error: {error}",); + continue; + } - log::info!("Connected to peer {ip}"); + log::info!("Connected to peer {socket_addr}"); } } - async fn connect_peer(&self, ip: IpAddr, port: u16) -> Result<(Peer, mpsc::Receiver)> { - log::debug!("Connecting to peer {ip}"); + pub async fn connect_peer(&self, socket_addr: SocketAddr) -> Result { + log::debug!("Connecting to peer {socket_addr}"); let (peer, mut receiver) = timeout( self.0.options.connection_timeout, - Peer::connect(ip, port, self.0.tls_connector.clone()), + Peer::connect(socket_addr, self.0.tls_connector.clone()), ) .await .map_err(Error::ConnectionTimeout)??; @@ -371,6 +316,57 @@ impl Client { )); } - Ok((peer, receiver)) + self.add_peer(peer, receiver).await + } + + pub async fn add_peer( + &self, + peer: Peer, + mut receiver: mpsc::Receiver, + ) -> Result { + let socket_addr = peer.socket_addr(); + let peer_id = peer.peer_id(); + + self.0.peers.lock().await.insert(peer_id, peer); + + self.0 + .message_sender + .lock() + .await + .send(Event::Connected(peer_id)) + .await?; + + // Spawn a task to propagate messages from the peer. + let message_sender = self.0.message_sender.clone(); + let peer_map = self.0.peers.clone(); + + tokio::spawn(async move { + while let Some(message) = receiver.recv().await { + if let Err(error) = message_sender + .lock() + .await + .send(Event::Message(peer_id, message)) + .await + { + log::warn!("Failed to send client message event: {error}"); + break; + } + } + + peer_map.lock().await.remove(&peer_id); + + if let Err(error) = message_sender + .lock() + .await + .send(Event::Disconnected(socket_addr)) + .await + { + log::warn!("Failed to send client connection closed event: {error}"); + } + + log::info!("Peer {socket_addr} disconnected"); + }); + + Ok(peer_id) } } diff --git a/crates/chia-client/src/error.rs b/crates/chia-client/src/error.rs index 0eda44601..82150c6a0 100644 --- a/crates/chia-client/src/error.rs +++ b/crates/chia-client/src/error.rs @@ -1,7 +1,11 @@ use chia_protocol::ProtocolMessageTypes; use semver::Version; use thiserror::Error; -use tokio::{sync::oneshot::error::RecvError, time::error::Elapsed}; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::oneshot::error::RecvError; +use tokio::time::error::Elapsed; + +use crate::Event; #[derive(Debug, Error)] pub enum Error { @@ -38,6 +42,9 @@ pub enum Error { #[error("Failed to send event")] EventNotSent, + #[error("Failed to send message")] + Send(#[from] SendError), + #[error("Failed to receive message")] Recv(#[from] RecvError), diff --git a/crates/chia-client/src/event.rs b/crates/chia-client/src/event.rs index 81c198101..c22d01719 100644 --- a/crates/chia-client/src/event.rs +++ b/crates/chia-client/src/event.rs @@ -1,3 +1,5 @@ +use std::net::SocketAddr; + use chia_protocol::Message; use crate::PeerId; @@ -5,5 +7,6 @@ use crate::PeerId; #[derive(Debug, Clone)] pub enum Event { Message(PeerId, Message), - ConnectionClosed(PeerId), + Connected(PeerId), + Disconnected(SocketAddr), } diff --git a/crates/chia-client/src/peer.rs b/crates/chia-client/src/peer.rs index 29b958cb3..cbbf82363 100644 --- a/crates/chia-client/src/peer.rs +++ b/crates/chia-client/src/peer.rs @@ -1,4 +1,4 @@ -use std::{fmt, net::IpAddr, sync::Arc}; +use std::{fmt, net::SocketAddr, sync::Arc}; use chia_protocol::{ Bytes32, ChiaProtocolMessage, CoinStateFilters, Message, PuzzleSolutionResponse, @@ -47,21 +47,15 @@ struct PeerInner { inbound_handle: JoinHandle<()>, requests: Arc, peer_id: PeerId, - ip_addr: IpAddr, + socket_addr: SocketAddr, } impl Peer { pub async fn connect( - ip: IpAddr, - port: u16, + socket_addr: SocketAddr, tls_connector: TlsConnector, ) -> Result<(Self, mpsc::Receiver)> { - let uri = if ip.is_ipv4() { - format!("wss://{ip}:{port}/ws") - } else { - format!("wss://[{ip}]:{port}/ws") - }; - Self::connect_addr(&uri, tls_connector).await + Self::connect_addr(&format!("wss://{socket_addr}/ws"), tls_connector).await } pub async fn connect_addr( @@ -79,7 +73,7 @@ impl Peer { } pub fn from_websocket(ws: WebSocket) -> Result<(Self, mpsc::Receiver)> { - let (addr, cert) = match ws.get_ref() { + let (socket_addr, cert) = match ws.get_ref() { MaybeTlsStream::NativeTls(tls) => { let tls_stream = tls.get_ref(); let tcp_stream = tls_stream.get_ref().get_ref(); @@ -113,7 +107,7 @@ impl Peer { inbound_handle, requests, peer_id, - ip_addr: addr.ip(), + socket_addr, })); Ok((peer, receiver)) @@ -123,8 +117,8 @@ impl Peer { self.0.peer_id } - pub fn ip_addr(&self) -> IpAddr { - self.0.ip_addr + pub fn socket_addr(&self) -> SocketAddr { + self.0.socket_addr } pub async fn send_transaction(&self, spend_bundle: SpendBundle) -> Result {