diff --git a/src/transport/manager/address.rs b/src/transport/manager/address.rs index e18d7c05..155d2d08 100644 --- a/src/transport/manager/address.rs +++ b/src/transport/manager/address.rs @@ -18,12 +18,36 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{types::ConnectionId, PeerId}; +use crate::{ + error::{DialError, NegotiationError}, + PeerId, +}; + +use std::collections::{hash_map::Entry, HashMap}; use multiaddr::{Multiaddr, Protocol}; use multihash::Multihash; -use std::collections::{BinaryHeap, HashSet}; +/// Maximum number of addresses tracked for a peer. +const MAX_ADDRESSES: usize = 64; + +/// Scores for address records. +pub mod scores { + /// Score indicating that the connection was successfully established. + pub const CONNECTION_ESTABLISHED: i32 = 100i32; + + /// Score for a connection with a peer using a different ID than expected. + pub const DIFFERENT_PEER_ID: i32 = 50i32; + + /// Score for failing to connect due to an invalid or unreachable address. + pub const CONNECTION_FAILURE: i32 = -100i32; + + /// Score for a connection attempt that failed due to a timeout. + pub const TIMEOUT_FAILURE: i32 = -50i32; +} + +/// Remove the address from the store if the score is below this threshold. +const REMOVE_THRESHOLD: i32 = scores::CONNECTION_FAILURE * 2; #[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Clone, Hash)] @@ -33,9 +57,6 @@ pub struct AddressRecord { /// Address. address: Multiaddr, - - /// Connection ID, if specified. - connection_id: Option, } impl AsRef for AddressRecord { @@ -47,12 +68,7 @@ impl AsRef for AddressRecord { impl AddressRecord { /// Create new `AddressRecord` and if `address` doesn't contain `P2p`, /// append the provided `PeerId` to the address. 
- pub fn new( - peer: &PeerId, - address: Multiaddr, - score: i32, - connection_id: Option, - ) -> Self { + pub fn new(peer: &PeerId, address: Multiaddr, score: i32) -> Self { let address = if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { address.with(Protocol::P2p( Multihash::from_bytes(&peer.to_bytes()).expect("valid peer id"), @@ -61,11 +77,7 @@ impl AddressRecord { address }; - Self { - address, - score, - connection_id, - } + Self { address, score } } /// Create `AddressRecord` from `Multiaddr`. @@ -80,7 +92,6 @@ impl AddressRecord { Some(AddressRecord { address, score: 0i32, - connection_id: None, }) } @@ -95,20 +106,10 @@ impl AddressRecord { &self.address } - /// Get connection ID. - pub fn connection_id(&self) -> &Option { - &self.connection_id - } - /// Update score of an address. pub fn update_score(&mut self, score: i32) { self.score = self.score.saturating_add(score); } - - /// Set `ConnectionId` for the [`AddressRecord`]. - pub fn set_connection_id(&mut self, connection_id: ConnectionId) { - self.connection_id = Some(connection_id); - } } impl PartialEq for AddressRecord { @@ -134,19 +135,18 @@ impl Ord for AddressRecord { /// Store for peer addresses. #[derive(Debug)] pub struct AddressStore { - //// Addresses sorted by score. - pub by_score: BinaryHeap, + /// Addresses available. + pub addresses: HashMap, - /// Addresses queryable by hashing them for faster lookup. 
- pub by_address: HashSet, + max_capacity: usize, } impl FromIterator for AddressStore { fn from_iter>(iter: T) -> Self { let mut store = AddressStore::new(); for address in iter { - if let Some(address) = AddressRecord::from_multiaddr(address) { - store.insert(address); + if let Some(record) = AddressRecord::from_multiaddr(address) { + store.insert(record); } } @@ -158,8 +158,7 @@ impl FromIterator for AddressStore { fn from_iter>(iter: T) -> Self { let mut store = AddressStore::new(); for record in iter { - store.by_address.insert(record.address.clone()); - store.by_score.push(record); + store.insert(record); } store @@ -186,52 +185,84 @@ impl AddressStore { /// Create new [`AddressStore`]. pub fn new() -> Self { Self { - by_score: BinaryHeap::new(), - by_address: HashSet::new(), + addresses: HashMap::with_capacity(MAX_ADDRESSES), + max_capacity: MAX_ADDRESSES, } } - /// Check if [`AddressStore`] is empty. - pub fn is_empty(&self) -> bool { - self.by_score.is_empty() + /// Get the score for a given error. + pub fn error_score(error: &DialError) -> i32 { + match error { + DialError::Timeout => scores::TIMEOUT_FAILURE, + DialError::AddressError(_) => scores::CONNECTION_FAILURE, + DialError::DnsError(_) => scores::CONNECTION_FAILURE, + DialError::NegotiationError(negotiation_error) => match negotiation_error { + NegotiationError::PeerIdMismatch(_, _) => scores::DIFFERENT_PEER_ID, + // Timeout during the negotiation phase. + NegotiationError::Timeout => scores::TIMEOUT_FAILURE, + // Treat other errors as connection failures. + _ => scores::CONNECTION_FAILURE, + }, + } } - /// Check if address is already in the a - pub fn contains(&self, address: &Multiaddr) -> bool { - self.by_address.contains(address) + /// Check if [`AddressStore`] is empty. + pub fn is_empty(&self) -> bool { + self.addresses.is_empty() } - /// Insert new address record into [`AddressStore`] with default address score. 
- pub fn insert(&mut self, mut record: AddressRecord) { - if self.by_address.contains(record.address()) { + /// Insert the address record into [`AddressStore`] with the provided score. + /// + /// If the address is not in the store, it will be inserted. + /// Otherwise, the score and connection ID will be updated. + pub fn insert(&mut self, record: AddressRecord) { + let num_addresses = self.addresses.len(); + + if let Entry::Occupied(mut occupied) = self.addresses.entry(record.address.clone()) { + occupied.get_mut().update_score(record.score); + if occupied.get().score <= REMOVE_THRESHOLD { + occupied.remove(); + } return; } - record.connection_id = None; - self.by_address.insert(record.address.clone()); - self.by_score.push(record); - } - - /// Pop address with the highest score from [`AddressStore`]. - pub fn pop(&mut self) -> Option { - self.by_score.pop().map(|record| { - self.by_address.remove(&record.address); - record - }) - } - - /// Take at most `limit` `AddressRecord`s from [`AddressStore`]. - pub fn take(&mut self, limit: usize) -> Vec { - let mut records = Vec::new(); + // The eviction algorithm favours addresses with higher scores. + // + // This algorithm has the following implications: + // - it keeps the best addresses in the store. + // - if the store is at capacity, the worst address will be evicted. + // - an address that is not dialed yet (with score zero) will be preferred over an address + // that already failed (with negative score). + if num_addresses >= self.max_capacity { + // No need to keep track of negative addresses if we are at capacity. + if record.score < 0 { + return; + } - for _ in 0..limit { - match self.pop() { - Some(record) => records.push(record), - None => break, + let Some(min_record) = self.addresses.values().min().cloned() else { + return; + }; + // The lowest score is better than the new record. 
+ if record.score < min_record.score { + return; } + self.addresses.remove(min_record.address()); + } + + // There's no need to keep track of this address if the score is below the threshold. + if record.score <= REMOVE_THRESHOLD { + return; } - records + // Insert the record. + self.addresses.insert(record.address.clone(), record); + } + + /// Return the available addresses sorted by score. + pub fn addresses(&self, limit: usize) -> Vec { + let mut records = self.addresses.values().cloned().collect::>(); + records.sort_by(|lhs, rhs| rhs.score.cmp(&lhs.score)); + records.into_iter().take(limit).map(|record| record.address).collect() } } @@ -256,7 +287,7 @@ mod tests { ), rng.gen_range(1..=65535), )); - let score: i32 = rng.gen(); + let score: i32 = rng.gen_range(10..=200); AddressRecord::new( &peer, @@ -264,7 +295,6 @@ mod tests { .with(Protocol::from(address.ip())) .with(Protocol::Tcp(address.port())), score, - None, ) } @@ -279,7 +309,7 @@ mod tests { ), rng.gen_range(1..=65535), )); - let score: i32 = rng.gen(); + let score: i32 = rng.gen_range(10..=200); AddressRecord::new( &peer, @@ -288,7 +318,6 @@ mod tests { .with(Protocol::Tcp(address.port())) .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), score, - None, ) } @@ -303,7 +332,7 @@ mod tests { ), rng.gen_range(1..=65535), )); - let score: i32 = rng.gen(); + let score: i32 = rng.gen_range(10..=200); AddressRecord::new( &peer, @@ -312,10 +341,82 @@ mod tests { .with(Protocol::Udp(address.port())) .with(Protocol::QuicV1), score, - None, ) } + #[test] + fn insert_record() { + let mut store = AddressStore::new(); + let mut rng = rand::thread_rng(); + + let mut record = tcp_address_record(&mut rng); + record.score = 10; + + store.insert(record.clone()); + + assert_eq!(store.addresses.len(), 1); + assert_eq!(store.addresses.get(record.address()).unwrap(), &record); + + // This time the record is updated. 
+ store.insert(record.clone()); + + assert_eq!(store.addresses.len(), 1); + let store_record = store.addresses.get(record.address()).unwrap(); + assert_eq!(store_record.score, record.score * 2); + } + + #[test] + fn evict_below_threshold() { + let mut store = AddressStore::new(); + let mut rng = rand::thread_rng(); + + let mut record = tcp_address_record(&mut rng); + record.score = scores::CONNECTION_FAILURE; + store.insert(record.clone()); + + assert_eq!(store.addresses.len(), 1); + + store.insert(record.clone()); + + assert_eq!(store.addresses.len(), 0); + } + + #[test] + fn evict_on_capacity() { + let mut store = AddressStore { + addresses: HashMap::new(), + max_capacity: 2, + }; + + let mut rng = rand::thread_rng(); + let mut first_record = tcp_address_record(&mut rng); + first_record.score = scores::CONNECTION_ESTABLISHED; + let mut second_record = ws_address_record(&mut rng); + second_record.score = 0; + + store.insert(first_record.clone()); + store.insert(second_record.clone()); + + assert_eq!(store.addresses.len(), 2); + + // We have better addresses, ignore this one. + let mut third_record = quic_address_record(&mut rng); + third_record.score = scores::CONNECTION_FAILURE; + store.insert(third_record.clone()); + assert_eq!(store.addresses.len(), 2); + assert!(store.addresses.contains_key(first_record.address())); + assert!(store.addresses.contains_key(second_record.address())); + + // Evict the address with the lowest score. 
+ let mut fourth_record = quic_address_record(&mut rng); + fourth_record.score = scores::DIFFERENT_PEER_ID; + store.insert(fourth_record.clone()); + + assert_eq!(store.addresses.len(), 2); + assert!(store.addresses.contains_key(first_record.address())); + assert!(store.addresses.contains_key(fourth_record.address())); + } + #[test] fn take_multiple_records() { let mut store = AddressStore::new(); @@ -331,16 +432,19 @@ mod tests { store.insert(quic_address_record(&mut rng)); } - let known_addresses = store.by_address.len(); + let known_addresses = store.addresses.len(); assert!(known_addresses >= 3); - let taken = store.take(known_addresses - 2); + let taken = store.addresses(known_addresses - 2); assert_eq!(known_addresses - 2, taken.len()); assert!(!store.is_empty()); let mut prev: Option = None; - for record in taken { - assert!(!store.contains(record.address())); + for address in taken { + // Addresses are still in the store. + assert!(store.addresses.contains_key(&address)); + + let record = store.addresses.get(&address).unwrap().clone(); if let Some(previous) = prev { assert!(previous.score > record.score); @@ -359,14 +463,15 @@ mod tests { store.insert(ws_address_record(&mut rng)); store.insert(quic_address_record(&mut rng)); - assert_eq!(store.by_address.len(), 3); + assert_eq!(store.addresses.len(), 3); - let taken = store.take(8usize); + let taken = store.addresses(8usize); assert_eq!(taken.len(), 3); - assert!(store.is_empty()); let mut prev: Option = None; for record in taken { + let record = store.addresses.get(&record).unwrap().clone(); + if prev.is_none() { prev = Some(record); } else { @@ -401,10 +506,9 @@ mod tests { .collect::>(); store.extend(records); - for record in store.by_score { + for record in store.addresses.values().cloned() { let stored = cloned.get(record.address()).unwrap(); assert_eq!(stored.score(), record.score()); - assert_eq!(stored.connection_id(), record.connection_id()); assert_eq!(stored.address(), record.address()); } } @@ 
-433,10 +537,9 @@ mod tests { let cloned = records.iter().cloned().collect::>(); store.extend(records.iter().map(|(_, record)| record)); - for record in store.by_score { + for record in store.addresses.values().cloned() { let stored = cloned.get(record.address()).unwrap(); assert_eq!(stored.score(), record.score()); - assert_eq!(stored.connection_id(), record.connection_id()); assert_eq!(stored.address(), record.address()); } } diff --git a/src/transport/manager/handle.rs b/src/transport/manager/handle.rs index 0ded6406..ce3c0173 100644 --- a/src/transport/manager/handle.rs +++ b/src/transport/manager/handle.rs @@ -25,8 +25,9 @@ use crate::{ executor::Executor, protocol::ProtocolSet, transport::manager::{ - address::{AddressRecord, AddressStore}, - types::{PeerContext, PeerState, SupportedTransport}, + address::AddressRecord, + peer_state::StateDialResult, + types::{PeerContext, SupportedTransport}, ProtocolContext, TransportManagerEvent, LOG_TARGET, }, types::{protocol::ProtocolName, ConnectionId}, @@ -180,50 +181,54 @@ impl TransportManagerHandle { peer: &PeerId, addresses: impl Iterator, ) -> usize { - let mut peers = self.peers.write(); - let addresses = addresses - .filter_map(|address| { - (self.supported_transport(&address) && !self.is_local_address(&address)) - .then_some(AddressRecord::from_multiaddr(address)?) - }) - .collect::>(); - - // if all of the added addresses belonged to unsupported transports, exit early - let num_added = addresses.len(); - if num_added == 0 { - tracing::debug!( - target: LOG_TARGET, - ?peer, - "didn't add any addresses for peer because transport is not supported", - ); + let mut peer_addresses = HashMap::new(); - return 0usize; + for address in addresses { + // There is not supported transport configured that can dial this address. + if !self.supported_transport(&address) { + continue; + } + if self.is_local_address(&address) { + continue; + } + + // Check the peer ID if present. 
+ if let Some(Protocol::P2p(multihash)) = address.iter().last() { + // Ignore the address if the peer ID is invalid. + let Ok(peer_id) = PeerId::from_multihash(multihash.clone()) else { + continue; + }; + + // This can correspond to the provided peerID or to a different one. + // It is important to keep track of all addresses to have a healthy + // address store to dial from. + peer_addresses.entry(peer_id).or_insert_with(HashSet::new).insert(address); + continue; + } + + // Add the provided peer ID to the address. + let address = address.with(Protocol::P2p(multihash::Multihash::from(peer.clone()))); + peer_addresses.entry(*peer).or_insert_with(HashSet::new).insert(address); } + let num_added = peer_addresses.get(peer).map_or(0, |addresses| addresses.len()); + tracing::trace!( target: LOG_TARGET, ?peer, - ?addresses, + ?peer_addresses, "add known addresses", ); - match peers.get_mut(peer) { - Some(context) => - for record in addresses { - if !context.addresses.contains(record.address()) { - context.addresses.insert(record); - } - }, - None => { - peers.insert( - *peer, - PeerContext { - state: PeerState::Disconnected { dial_record: None }, - addresses: AddressStore::from_iter(addresses), - secondary_connection: None, - }, - ); - } + let mut peers = self.peers.write(); + for (peer, addresses) in peer_addresses { + let entry = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + // All addresses should be valid at this point, since the peer ID was either added or + // double checked. + entry.addresses.extend( + addresses.into_iter().filter_map(|addr| AddressRecord::from_multiaddr(addr)), + ); } num_added @@ -238,36 +243,21 @@ impl TransportManagerHandle { } { - match self.peers.read().get(peer) { - Some(PeerContext { - state: PeerState::Connected { .. }, - .. - }) => return Err(ImmediateDialError::AlreadyConnected), - Some(PeerContext { - state: PeerState::Disconnected { dial_record }, - addresses, - .. 
- }) => { - if addresses.is_empty() { - return Err(ImmediateDialError::NoAddressAvailable); - } - - // peer is already being dialed, don't dial again until the first dial concluded - if dial_record.is_some() { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?dial_record, - "peer is aready being dialed", - ); - return Ok(()); - } - } - Some(PeerContext { - state: PeerState::Dialing { .. } | PeerState::Opening { .. }, - .. - }) => return Ok(()), - None => return Err(ImmediateDialError::NoAddressAvailable), + let peers = self.peers.read(); + let Some(PeerContext { state, addresses }) = peers.get(peer) else { + return Err(ImmediateDialError::NoAddressAvailable); + }; + + match state.can_dial() { + StateDialResult::AlreadyConnected => + return Err(ImmediateDialError::AlreadyConnected), + StateDialResult::DialingInProgress => return Ok(()), + StateDialResult::Ok => {} + }; + + // Check if we have enough addresses to dial. + if addresses.is_empty() { + return Err(ImmediateDialError::NoAddressAvailable); } } @@ -327,6 +317,11 @@ impl TransportHandle { #[cfg(test)] mod tests { + use crate::transport::manager::{ + address::AddressStore, + peer_state::{ConnectionRecord, PeerState}, + }; + use super::*; use multihash::Multihash; use parking_lot::lock_api::RwLock; @@ -443,16 +438,16 @@ mod tests { peer, PeerContext { state: PeerState::Connected { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - dial_record: None, + connection_id: ConnectionId::from(0), + }, + secondary: None, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -486,15 +481,15 @@ mod tests { peer, PeerContext { state: PeerState::Dialing { - record: AddressRecord::from_multiaddr( - 
Multiaddr::empty() + dial_record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), + connection_id: ConnectionId::from(0), + }, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -528,7 +523,6 @@ mod tests { peer, PeerContext { state: PeerState::Disconnected { dial_record: None }, - secondary_connection: None, addresses: AddressStore::new(), }, ); @@ -554,17 +548,16 @@ mod tests { peer, PeerContext { state: PeerState::Disconnected { - dial_record: Some( - AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - ), + dial_record: Some(ConnectionRecord::new( + peer, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ConnectionId::from(0), + )), }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 6816bbf2..1aba50f1 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -22,14 +22,15 @@ use crate::{ addresses::PublicAddresses, codec::ProtocolCodec, crypto::ed25519::Keypair, - error::{AddressError, DialError, Error}, + error::{AddressError, DialError, Error, NegotiationError}, executor::Executor, protocol::{InnerTransportEvent, TransportService}, transport::{ manager::{ - address::{AddressRecord, AddressStore}, + address::AddressRecord, handle::InnerTransportManagerCommand, - types::{PeerContext, PeerState}, + peer_state::{ConnectionRecord, PeerState, 
StateDialResult}, + types::PeerContext, }, Endpoint, Transport, TransportEvent, }, @@ -37,6 +38,7 @@ use crate::{ BandwidthSink, PeerId, }; +use address::{scores, AddressStore}; use futures::{Stream, StreamExt}; use indexmap::IndexMap; use multiaddr::{Multiaddr, Protocol}; @@ -45,7 +47,7 @@ use parking_lot::RwLock; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, + collections::{HashMap, HashSet}, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -60,6 +62,7 @@ pub use types::SupportedTransport; mod address; pub mod limits; +mod peer_state; mod types; pub(crate) mod handle; @@ -72,12 +75,6 @@ pub(crate) mod handle; /// Logging target for the file. const LOG_TARGET: &str = "litep2p::transport-manager"; -/// Score for a working address. -const SCORE_CONNECT_SUCCESS: i32 = 100i32; - -/// Score for a non-working address. -const SCORE_CONNECT_FAILURE: i32 = -100i32; - /// The connection established result. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum ConnectionEstablishedResult { @@ -319,7 +316,7 @@ impl TransportManager { } /// Get next connection ID. - fn next_connection_id(&mut self) -> ConnectionId { + fn next_connection_id(&self) -> ConnectionId { let connection_id = self.next_connection_id.fetch_add(1usize, Ordering::Relaxed); ConnectionId::from(connection_id) @@ -414,6 +411,29 @@ impl TransportManager { self.transport_manager_handle.add_known_address(&peer, address) } + /// Return multiple addresses to dial on supported protocols. 
+ fn open_addresses(addresses: &[Multiaddr]) -> HashMap> { + let mut transports = HashMap::>::new(); + + for address in addresses.iter().cloned() { + #[cfg(feature = "quic")] + if address.iter().any(|p| std::matches!(&p, Protocol::QuicV1)) { + transports.entry(SupportedTransport::Quic).or_default().push(address); + continue; + } + + #[cfg(feature = "websocket")] + if address.iter().any(|p| std::matches!(&p, Protocol::Ws(_) | Protocol::Wss(_))) { + transports.entry(SupportedTransport::WebSocket).or_default().push(address); + continue; + } + + transports.entry(SupportedTransport::Tcp).or_default().push(address); + } + + transports + } + /// Dial peer using `PeerId`. /// /// Returns an error if the peer is unknown or the peer is already connected. @@ -429,157 +449,50 @@ impl TransportManager { } let mut peers = self.peers.write(); - // if the peer is disconnected, return its context - // - // otherwise set the state back what it was and return dial status to caller - let PeerContext { - state, - secondary_connection, - mut addresses, - } = match peers.remove(&peer) { - None => return Err(Error::PeerDoesntExist(peer)), - Some( - context @ PeerContext { - state: PeerState::Connected { .. }, - .. - }, - ) => { - peers.insert(peer, context); - return Err(Error::AlreadyConnected); - } - Some( - context @ PeerContext { - state: PeerState::Dialing { .. } | PeerState::Opening { .. }, - .. 
- }, - ) => { - peers.insert(peer, context); - return Ok(()); - } - Some(context) => context, - }; + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); - if let PeerState::Disconnected { - dial_record: Some(_), - } = &state - { - tracing::debug!( - target: LOG_TARGET, - ?peer, - "peer is already being dialed", - ); - - peers.insert( - peer, - PeerContext { - state, - secondary_connection, - addresses, - }, - ); - - return Ok(()); - } - - let mut records: HashMap<_, _> = addresses - .take(limit) - .into_iter() - .map(|record| (record.address().clone(), record)) - .collect(); + // Check if dialing is possible before allocating addresses. + match context.state.can_dial() { + StateDialResult::AlreadyConnected => return Err(Error::AlreadyConnected), + StateDialResult::DialingInProgress => return Ok(()), + StateDialResult::Ok => {} + }; - if records.is_empty() { + // The addresses are sorted by score and contain the remote peer ID. + // We double checked above that the remote peer is not the local peer. 
+ let dial_addresses = context.addresses.addresses(limit); + if dial_addresses.is_empty() { return Err(Error::NoAddressAvailable(peer)); } - - let locked_addresses = self.listen_addresses.read(); - for record in records.values() { - if locked_addresses.contains(record.as_ref()) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?record, - "tried to dial self", - ); - - debug_assert!(false); - return Err(Error::TriedToDialSelf); - } - } - drop(locked_addresses); - - // set connection id for the address record and put peer into `Opening` state - let connection_id = - ConnectionId::from(self.next_connection_id.fetch_add(1usize, Ordering::Relaxed)); + let connection_id = self.next_connection_id(); tracing::debug!( target: LOG_TARGET, ?connection_id, - addresses = ?records, + addresses = ?dial_addresses, "dial remote peer", ); - let mut transports = HashSet::new(); - #[cfg(feature = "websocket")] - let mut websocket = Vec::new(); - #[cfg(feature = "quic")] - let mut quic = Vec::new(); - let mut tcp = Vec::new(); + let transports = Self::open_addresses(&dial_addresses); - for (address, record) in &mut records { - record.set_connection_id(connection_id); + // Dialing addresses will succeed because the `context.state.can_dial()` returned `Ok`. 
+ let result = context.state.dial_addresses( + connection_id, + dial_addresses.iter().cloned().collect(), + transports.keys().cloned().collect(), + ); + assert_eq!(result, StateDialResult::Ok); - #[cfg(feature = "quic")] - if address.iter().any(|p| std::matches!(&p, Protocol::QuicV1)) { - quic.push(address.clone()); - transports.insert(SupportedTransport::Quic); + for (transport, addresses) in transports { + if addresses.is_empty() { continue; } - #[cfg(feature = "websocket")] - if address.iter().any(|p| std::matches!(&p, Protocol::Ws(_) | Protocol::Wss(_))) { - websocket.push(address.clone()); - transports.insert(SupportedTransport::WebSocket); + let Some(installed_transport) = self.transports.get_mut(&transport) else { continue; - } - - tcp.push(address.clone()); - transports.insert(SupportedTransport::Tcp); - } - - peers.insert( - peer, - PeerContext { - state: PeerState::Opening { - records, - connection_id, - transports, - }, - secondary_connection, - addresses, - }, - ); - - if !tcp.is_empty() { - self.transports - .get_mut(&SupportedTransport::Tcp) - .expect("transport to be supported") - .open(connection_id, tcp)?; - } - - #[cfg(feature = "quic")] - if !quic.is_empty() { - self.transports - .get_mut(&SupportedTransport::Quic) - .expect("transport to be supported") - .open(connection_id, quic)?; - } + }; - #[cfg(feature = "websocket")] - if !websocket.is_empty() { - self.transports - .get_mut(&SupportedTransport::WebSocket) - .expect("transport to be supported") - .open(connection_id, websocket)?; + installed_transport.open(connection_id, addresses)?; } self.pending_connections.insert(connection_id, peer); @@ -593,19 +506,19 @@ impl TransportManager { pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { self.connection_limits.on_dial_address()?; - let mut record = AddressRecord::from_multiaddr(address) + let address_record = AddressRecord::from_multiaddr(address) .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; - if 
self.listen_addresses.read().contains(record.as_ref()) { + if self.listen_addresses.read().contains(address_record.as_ref()) { return Err(Error::TriedToDialSelf); } - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "dial address"); + tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "dial address"); - let mut protocol_stack = record.as_ref().iter(); + let mut protocol_stack = address_record.as_ref().iter(); match protocol_stack .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? { Protocol::Ip4(_) | Protocol::Ip6(_) => {} Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) => {} @@ -615,29 +528,36 @@ impl TransportManager { ?transport, "invalid transport, expected `ip4`/`ip6`" ); - return Err(Error::TransportNotSupported(record.address().clone())); + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )); } }; let supported_transport = match protocol_stack .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? { Protocol::Tcp(_) => match protocol_stack.next() { #[cfg(feature = "websocket")] Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) => SupportedTransport::WebSocket, Some(Protocol::P2p(_)) => SupportedTransport::Tcp, - _ => return Err(Error::TransportNotSupported(record.address().clone())), + _ => + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )), }, #[cfg(feature = "quic")] Protocol::Udp(_) => match protocol_stack .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? 
{ Protocol::QuicV1 => SupportedTransport::Quic, _ => { - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "expected `quic-v1`"); - return Err(Error::TransportNotSupported(record.address().clone())); + tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "expected `quic-v1`"); + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )); } }, protocol => { @@ -647,85 +567,94 @@ impl TransportManager { "invalid protocol" ); - return Err(Error::TransportNotSupported(record.address().clone())); + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )); } }; // when constructing `AddressRecord`, `PeerId` was verified to be part of the address let remote_peer_id = - PeerId::try_from_multiaddr(record.address()).expect("`PeerId` to exist"); + PeerId::try_from_multiaddr(address_record.address()).expect("`PeerId` to exist"); // set connection id for the address record and put peer into `Dialing` state let connection_id = self.next_connection_id(); - record.set_connection_id(connection_id); + let dial_record = ConnectionRecord { + address: address_record.address().clone(), + connection_id, + }; { let mut peers = self.peers.write(); - match peers.entry(remote_peer_id) { - Entry::Occupied(occupied) => { - let context = occupied.into_mut(); + let context = peers.entry(remote_peer_id).or_insert_with(|| PeerContext::default()); - // For a better address tacking, see: - // https://github.com/paritytech/litep2p/issues/180 - // - // TODO: context.addresses.insert(record.clone()); - - tracing::debug!( - target: LOG_TARGET, - peer = ?remote_peer_id, - state = ?context.state, - "peer state exists", - ); + // Keep the provided record around for possible future dials. + context.addresses.insert(address_record.clone()); - match context.state { - PeerState::Connected { .. } => { - return Err(Error::AlreadyConnected); - } - PeerState::Dialing { .. } | PeerState::Opening { .. 
} => { - return Ok(()); - } - PeerState::Disconnected { - dial_record: Some(_), - } => { - tracing::debug!( - target: LOG_TARGET, - peer = ?remote_peer_id, - state = ?context.state, - "peer is already being dialed from a disconnected state" - ); - return Ok(()); - } - PeerState::Disconnected { dial_record: None } => { - context.state = PeerState::Dialing { - record: record.clone(), - }; - } - } - } - Entry::Vacant(vacant) => { - vacant.insert(PeerContext { - state: PeerState::Dialing { - record: record.clone(), - }, - addresses: AddressStore::new(), - secondary_connection: None, - }); - } + match context.state.dial_single_address(dial_record) { + StateDialResult::AlreadyConnected => return Err(Error::AlreadyConnected), + StateDialResult::DialingInProgress => return Ok(()), + StateDialResult::Ok => {} }; } self.transports .get_mut(&supported_transport) - .ok_or(Error::TransportNotSupported(record.address().clone()))? - .dial(connection_id, record.address().clone())?; + .ok_or(Error::TransportNotSupported( + address_record.address().clone(), + ))? + .dial(connection_id, address_record.address().clone())?; self.pending_connections.insert(connection_id, remote_peer_id); Ok(()) } + // Update the address on a dial failure. + fn update_address_on_dial_failure(&mut self, mut address: Multiaddr, error: &DialError) { + let mut peers = self.peers.write(); + + let score = AddressStore::error_score(error); + + // Check if the address corresponds to a different peer ID than the one we're + // dialing. This can happen if the node operation restarts the node. + // + // In this case the address is reachable, however the peer ID is different. + // Keep track of this address for future dials. 
+ // + // Note: this is happening quite often in practice and is the primary reason + if let DialError::NegotiationError(NegotiationError::PeerIdMismatch(_, provided)) = error { + let context = peers.entry(*provided).or_insert_with(|| PeerContext::default()); + + if std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { + address.pop(); + } + context.addresses.insert(AddressRecord::new(&provided, address.clone(), score)); + + return; + } + + // Extract the peer ID at this point to give `NegotiationError::PeerIdMismatch` a chance to + // propagate. + let peer_id = match address.iter().last() { + Some(Protocol::P2p(hash)) => PeerId::from_multihash(hash).ok(), + _ => None, + }; + let Some(peer_id) = peer_id else { + return; + }; + + // We need a valid context for this peer to keep track of failed addresses. + let context = peers.entry(peer_id).or_insert_with(|| PeerContext::default()); + context.addresses.insert(AddressRecord::new(&peer_id, address.clone(), score)); + } + /// Handle dial failure. + /// + /// The main purpose of this function is to advance the internal `PeerState`. 
fn on_dial_failure(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?connection_id, "on dial failure"); + let peer = self.pending_connections.remove(&connection_id).ok_or_else(|| { tracing::error!( target: LOG_TARGET, @@ -736,124 +665,29 @@ impl TransportManager { })?; let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { - tracing::error!( + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + let previous_state = context.state.clone(); + + if !context.state.on_dial_failure(connection_id) { + tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "dial failed for a peer that doesn't exist", + state = ?context.state, + "invalid state for dial failure", + ); + } else { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?previous_state, + state = ?context.state, + "on dial failure completed" ); - debug_assert!(false); - - Error::InvalidState - })?; - - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Dialing { ref mut record } => { - debug_assert_eq!(record.connection_id(), &Some(connection_id)); - if record.connection_id() != &Some(connection_id) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - "unknown dial failure for a dialing peer", - ); - - context.state = PeerState::Dialing { - record: record.clone(), - }; - debug_assert!(false); - return Ok(()); - } - - record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(record.clone()); - - context.state = PeerState::Disconnected { dial_record: None }; - Ok(()) - } - PeerState::Opening { .. 
} => { - todo!(); - } - PeerState::Connected { - record, - dial_record: Some(mut dial_record), - } => { - if dial_record.connection_id() != &Some(connection_id) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - "unknown dial failure for a connected peer", - ); - - context.state = PeerState::Connected { - record, - dial_record: Some(dial_record), - }; - debug_assert!(false); - return Ok(()); - } - - dial_record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(dial_record); - - context.state = PeerState::Connected { - record, - dial_record: None, - }; - Ok(()) - } - PeerState::Disconnected { - dial_record: Some(mut dial_record), - } => { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?dial_record, - "dial failed for a disconnected peer", - ); - - if dial_record.connection_id() != &Some(connection_id) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?dial_record, - "unknown dial failure for a disconnected peer", - ); - - context.state = PeerState::Disconnected { - dial_record: Some(dial_record), - }; - debug_assert!(false); - return Ok(()); - } - - dial_record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(dial_record); - - Ok(()) - } - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "invalid state for dial failure", - ); - context.state = state; - - debug_assert!(false); - Ok(()) - } } + + Ok(()) } fn on_pending_incoming_connection(&mut self) -> crate::Result<()> { @@ -866,140 +700,62 @@ impl TransportManager { &mut self, peer: PeerId, connection_id: ConnectionId, - ) -> crate::Result> { + ) -> Option { + tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "connection closed"); + self.connection_limits.on_connection_closed(connection_id); let mut peers = self.peers.write(); - let Some(context) = peers.get_mut(&peer) else { + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + let previous_state 
= context.state.clone(); + let connection_closed = context.state.on_connection_closed(connection_id); + + if context.state == previous_state { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "cannot handle closed connection: peer doesn't exist", + state = ?context.state, + "invalid state for a closed connection", ); - debug_assert!(false); - return Err(Error::PeerDoesntExist(peer)); - }; - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "connection closed", - ); - - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Connected { - record, - dial_record: actual_dial_record, - } => match record.connection_id() == &Some(connection_id) { - // primary connection was closed - // - // if secondary connection exists, switch to using it while keeping peer in - // `Connected` state and if there's only one connection, set peer - // state to `Disconnected` - true => match context.secondary_connection.take() { - None => { - context.addresses.insert(record); - context.state = PeerState::Disconnected { - dial_record: actual_dial_record, - }; - - Ok(Some(TransportEvent::ConnectionClosed { - peer, - connection_id, - })) - } - Some(secondary_connection) => { - context.addresses.insert(record); - context.state = PeerState::Connected { - record: secondary_connection, - dial_record: actual_dial_record, - }; - - Ok(None) - } - }, - // secondary connection was closed - false => match context.secondary_connection.take() { - Some(secondary_connection) => { - if secondary_connection.connection_id() != &Some(connection_id) { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "unknown connection was closed, potentially ignored tertiary connection", - ); + } else { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?previous_state, + state = ?context.state, + "on connection closed completed" + ); + } - context.secondary_connection = Some(secondary_connection); 
- context.state = PeerState::Connected { - record, - dial_record: actual_dial_record, - }; + connection_closed.then_some(TransportEvent::ConnectionClosed { + peer, + connection_id, + }) + } - return Ok(None); - } + /// Update the address on a connection established. + fn update_address_on_connection_established(&mut self, peer: PeerId, endpoint: &Endpoint) { + // The connection can be inbound or outbound. + // For the inbound connection type, in most cases, the remote peer dialed + // with an ephemeral port which it might not be listening on. + // Therefore, we only insert the address into the store if we're the dialer. + if endpoint.is_listener() { + return; + } - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "secondary connection closed", - ); - - context.addresses.insert(secondary_connection); - context.state = PeerState::Connected { - record, - dial_record: actual_dial_record, - }; - Ok(None) - } - None => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "non-primary connection was closed but secondary connection doesn't exist", - ); - - debug_assert!(false); - Err(Error::InvalidState) - } - }, - }, - PeerState::Disconnected { dial_record } => match context.secondary_connection.take() { - Some(record) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - ?dial_record, - "peer is disconnected but secondary connection exists", - ); + let mut peers = self.peers.write(); - debug_assert!(false); - context.state = PeerState::Disconnected { dial_record }; - Err(Error::InvalidState) - } - None => { - context.state = PeerState::Disconnected { dial_record }; + let record = AddressRecord::new( + &peer, + endpoint.address().clone(), + scores::CONNECTION_ESTABLISHED, + ); - Ok(Some(TransportEvent::ConnectionClosed { - peer, - connection_id, - })) - } - }, - state => { - tracing::warn!(target: LOG_TARGET, ?peer, ?connection_id, ?state, "invalid state for a closed connection"); - debug_assert!(false); 
- Err(Error::InvalidState) - } - } + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + context.addresses.insert(record); } fn on_connection_established( @@ -1007,6 +763,8 @@ impl TransportManager { peer: PeerId, endpoint: &Endpoint, ) -> crate::Result { + self.update_address_on_connection_established(peer, &endpoint); + if let Some(dialed_peer) = self.pending_connections.remove(&endpoint.connection_id()) { if dialed_peer != peer { tracing::warn!( @@ -1037,252 +795,55 @@ impl TransportManager { } let mut peers = self.peers.write(); - match peers.get_mut(&peer) { - Some(context) => match context.state { - PeerState::Connected { - ref mut dial_record, - .. - } => match context.secondary_connection { - Some(_) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - "secondary connection already exists, ignoring connection", - ); - - // insert address into the store only if we're the dialer - // - // if we're the listener, remote might have dialed with an ephemeral port - // which it might not be listening, making this address useless - if endpoint.is_listener() { - context.addresses.insert(AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - None, - )) - } + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); - return Ok(ConnectionEstablishedResult::Reject); - } - None => match dial_record.take() { - Some(record) - if record.connection_id() == &Some(endpoint.connection_id()) => - { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - "dialed connection opened as secondary connection", - ); + let previous_state = context.state.clone(); + let connection_accepted = context + .state + .on_connection_established(ConnectionRecord::from_endpoint(peer, endpoint)); - context.secondary_connection = Some(AddressRecord::new( - &peer, - endpoint.address().clone(), - 
SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - )); - } - None => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - "secondary connection", - ); - - context.secondary_connection = Some(AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - )); - } - Some(record) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - dial_record = ?record, - "unknown connection opened as secondary connection, discarding", - ); - - // Preserve the dial record. - *dial_record = Some(record); - - return Ok(ConnectionEstablishedResult::Reject); - } - }, - }, - PeerState::Dialing { ref record, .. } => { - match record.connection_id() == &Some(endpoint.connection_id()) { - true => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - ?record, - "connection opened to remote", - ); - - context.state = PeerState::Connected { - record: record.clone(), - dial_record: None, - }; - } - false => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - "connection opened by remote while local node was dialing", - ); - - context.state = PeerState::Connected { - record: AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - dial_record: Some(record.clone()), - }; - } - } - } - PeerState::Opening { - ref mut records, - connection_id, - ref transports, - } => { - debug_assert!(std::matches!(endpoint, &Endpoint::Listener { .. 
})); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - dial_connection_id = ?connection_id, - dial_records = ?records, - dial_transports = ?transports, - listener_endpoint = ?endpoint, - "inbound connection while opening an outbound connection", - ); - - // cancel all pending dials - transports.iter().for_each(|transport| { - self.transports - .get_mut(transport) - .expect("transport to exist") - .cancel(connection_id); - }); - - // since an inbound connection was removed, the outbound connection can be - // removed from pending dials - // - // all records have the same `ConnectionId` so it doesn't matter which of them - // is used to remove the pending dial - self.pending_connections.remove( - &records - .iter() - .next() - .expect("record to exist") - .1 - .connection_id() - .expect("`ConnectionId` to exist"), - ); + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?endpoint, + ?previous_state, + state = ?context.state, + "on connection established completed" + ); - let record = match records.remove(endpoint.address()) { - Some(mut record) => { - record.update_score(SCORE_CONNECT_SUCCESS); - record.set_connection_id(endpoint.connection_id()); - record - } - None => AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - }; - context.addresses.extend(records.iter().map(|(_, record)| record)); - - context.state = PeerState::Connected { - record, - dial_record: None, - }; - } - PeerState::Disconnected { - ref mut dial_record, - } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - ?dial_record, - "connection opened by remote or delayed dial succeeded", - ); + if connection_accepted { + // Cancel all pending dials if the connection was established. + if let PeerState::Opening { + connection_id, + transports, + .. 
+ } = previous_state + { + // cancel all pending dials + transports.iter().for_each(|transport| { + self.transports + .get_mut(transport) + .expect("transport to exist") + .cancel(connection_id); + }); - let (record, dial_record) = match dial_record.take() { - Some(mut dial_record) => - if dial_record.address() == endpoint.address() { - dial_record.set_connection_id(endpoint.connection_id()); - (dial_record, None) - } else { - ( - AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - Some(dial_record), - ) - }, - None => ( - AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - None, - ), - }; - - context.state = PeerState::Connected { - record, - dial_record, - }; - } - }, - None => { - peers.insert( - peer, - PeerContext { - state: PeerState::Connected { - record: AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - dial_record: None, - }, - addresses: AddressStore::new(), - secondary_connection: None, - }, - ); + // since an inbound connection was removed, the outbound connection can be + // removed from pending dials + // + // TODO: This may race in the following scenario: + // + // T0: we open address X on protocol TCP + // T1: remote peer opens a connection with us + // T2: address X is dialed and event is propagated from TCP to transport manager + // T3: `on_connection_established` is called for T1 and pending connections cleared + // T4: event from T2 is delivered. 
+ self.pending_connections.remove(&connection_id); } + + return Ok(ConnectionEstablishedResult::Accept); } - Ok(ConnectionEstablishedResult::Accept) + Ok(ConnectionEstablishedResult::Reject) } fn on_connection_opened( @@ -1305,107 +866,83 @@ impl TransportManager { }; let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + // Keep track of the address. + context.addresses.insert(AddressRecord::new( + &peer, + address.clone(), + scores::CONNECTION_ESTABLISHED, + )); + + let previous_state = context.state.clone(); + let record = ConnectionRecord::new(peer, address.clone(), connection_id); + let state_advanced = context.state.on_connection_opened(record); + if !state_advanced { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + state = ?context.state, + "connection opened but `PeerState` is not `Opening`", + ); + return Err(Error::InvalidState); + } + + // State advanced from `Opening` to `Dialing`. + let PeerState::Opening { + connection_id, + transports, + .. + } = previous_state + else { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "connection opened but peer doesn't exist", + state = ?context.state, + "State missmatch in opening expected by peer state transition", ); + return Err(Error::InvalidState); + }; - debug_assert!(false); - Error::InvalidState - })?; + // Cancel open attempts for other transports as connection already exists. 
+ for transport in transports.iter() { + self.transports + .get_mut(transport) + .expect("transport to exist") + .cancel(connection_id); + } - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Opening { - mut records, - connection_id, - transports, - } => { + let negotiation = self + .transports + .get_mut(&transport) + .expect("transport to exist") + .negotiate(connection_id); + + match negotiation { + Ok(()) => { tracing::trace!( target: LOG_TARGET, ?peer, ?connection_id, - ?address, ?transport, - "connection opened to peer", + "negotiation started" ); - // cancel open attempts for other transports as connection already exists - for transport in transports.iter() { - self.transports - .get_mut(transport) - .expect("transport to exist") - .cancel(connection_id); - } - - // set peer state to `Dialing` to signal that the connection is fully opening - // - // set the succeeded `AddressRecord` as the one that is used for dialing and move - // all other address records back to `AddressStore`. 
and ask - // transport to negotiate the - let mut dial_record = records.remove(&address).expect("address to exist"); - dial_record.update_score(SCORE_CONNECT_SUCCESS); - - // negotiate the connection - match self - .transports - .get_mut(&transport) - .expect("transport to exist") - .negotiate(connection_id) - { - Ok(()) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?dial_record, - ?transport, - "negotiation started" - ); - - self.pending_connections.insert(connection_id, peer); - - context.state = PeerState::Dialing { - record: dial_record, - }; - - for (_, record) in records { - context.addresses.insert(record); - } + self.pending_connections.insert(connection_id, peer); - Ok(()) - } - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?error, - "failed to negotiate connection", - ); - context.state = PeerState::Disconnected { dial_record: None }; - - debug_assert!(false); - Err(Error::InvalidState) - } - } + Ok(()) } - state => { + Err(err) => { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - ?state, - "connection opened but `PeerState` is not `Opening`", + ?err, + "failed to negotiate connection", ); - context.state = state; - - debug_assert!(false); + context.state = PeerState::Disconnected { dial_record: None }; Err(Error::InvalidState) } } @@ -1417,7 +954,7 @@ impl TransportManager { transport: SupportedTransport, connection_id: ConnectionId, ) -> crate::Result> { - let Some(peer) = self.pending_connections.remove(&connection_id) else { + let Some(peer) = self.pending_connections.get(&connection_id).copied() else { tracing::warn!( target: LOG_TARGET, ?connection_id, @@ -1427,75 +964,43 @@ impl TransportManager { }; let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + let previous_state = context.state.clone(); + let last_transport = 
context.state.on_open_failure(transport); + + if context.state == previous_state { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "open failure but peer doesn't exist", + ?transport, + state = ?context.state, + "invalid state for a open failure", ); - debug_assert!(false); - Error::InvalidState - })?; - - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Opening { - records, - connection_id, - mut transports, - } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?transport, - "open failure for peer", - ); - transports.remove(&transport); - - if transports.is_empty() { - for (_, mut record) in records { - record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(record); - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "open failure for last transport", - ); - - return Ok(Some(peer)); - } - - self.pending_connections.insert(connection_id, peer); - context.state = PeerState::Opening { - records, - connection_id, - transports, - }; + return Err(Error::InvalidState); + } - Ok(None) - } - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "open failure but `PeerState` is not `Opening`", - ); - context.state = state; + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?transport, + ?previous_state, + state = ?context.state, + "on open failure transition completed" + ); - debug_assert!(false); - Err(Error::InvalidState) - } + if last_transport { + tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "open failure for last transport"); + // Remove the pending connection. + self.pending_connections.remove(&connection_id); + // Provide the peer to notify the open failure. + return Ok(Some(peer)); } + + Ok(None) } /// Poll next event from [`crate::transport::manager::TransportManager`]. 
@@ -1507,13 +1012,8 @@ impl TransportManager { peer, connection: connection_id, } => match self.on_connection_closed(peer, connection_id) { - Ok(None) => {} - Ok(Some(event)) => return Some(event), - Err(error) => tracing::error!( - target: LOG_TARGET, - ?error, - "failed to handle closed connection", - ), + None => {} + Some(event) => return Some(event), } }, command = self.cmd_rx.recv() => match command? { @@ -1541,6 +1041,11 @@ impl TransportManager { "failed to dial peer", ); + // Update the addresses on dial failure regardless of the + // internal peer context state. This ensures a robust address tracking + // while taking into account the error type. + self.update_address_on_dial_failure(address.clone(), &error); + if let Ok(()) = self.on_dial_failure(connection_id) { match address.iter().last() { Some(Protocol::P2p(hash)) => match PeerId::from_multihash(hash) { @@ -1622,6 +1127,7 @@ impl TransportManager { } TransportEvent::ConnectionEstablished { peer, endpoint } => { self.opening_errors.remove(&endpoint.connection_id()); + match self.on_connection_established(peer, &endpoint) { Err(error) => { tracing::debug!( @@ -1686,6 +1192,10 @@ impl TransportManager { } } TransportEvent::OpenFailure { connection_id, errors } => { + for (address, error) in &errors { + self.update_address_on_dial_failure(address.clone(), error); + } + match self.on_open_failure(transport, connection_id) { Err(error) => tracing::debug!( target: LOG_TARGET, @@ -1782,6 +1292,7 @@ impl TransportManager { #[cfg(test)] mod tests { + use crate::transport::manager::{address::AddressStore, peer_state::SecondaryOrDialing}; use limits::ConnectionLimitsConfig; use multihash::Multihash; @@ -1797,6 +1308,7 @@ mod tests { use std::{ net::{Ipv4Addr, Ipv6Addr}, sync::Arc, + usize, }; /// Setup TCP address and connection id. 
@@ -2140,7 +1652,6 @@ mod tests { PeerContext { state: PeerState::Disconnected { dial_record: None }, addresses: AddressStore::new(), - secondary_connection: None, }, ); @@ -2247,8 +1758,8 @@ mod tests { assert_eq!(manager.pending_connections.len(), 1); match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state for peer: {state:?}"), } @@ -2268,9 +1779,9 @@ mod tests { let peer = peers.get(&peer).unwrap(); match &peer.state { - PeerState::Connected { dial_record, .. } => { - assert!(dial_record.is_none()); - assert!(peer.addresses.contains(&dial_address)); + PeerState::Connected { secondary, .. } => { + assert!(secondary.is_none()); + assert!(peer.addresses.addresses.contains_key(&dial_address)); } state => panic!("invalid state: {state:?}"), } @@ -2314,8 +1825,8 @@ mod tests { assert_eq!(manager.pending_connections.len(), 1); match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state for peer: {state:?}"), } @@ -2341,7 +1852,7 @@ mod tests { dial_record: Some(dial_record), .. } => { - assert_eq!(dial_record.address(), &dial_address); + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } @@ -2357,7 +1868,7 @@ mod tests { PeerState::Disconnected { dial_record: None, .. 
} => { - assert!(peer.addresses.contains(&dial_address)); + assert!(peer.addresses.addresses.contains_key(&dial_address)); } state => panic!("invalid state: {state:?}"), } @@ -2401,8 +1912,8 @@ mod tests { assert_eq!(manager.pending_connections.len(), 1); match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state for peer: {state:?}"), } @@ -2428,7 +1939,7 @@ mod tests { dial_record: Some(dial_record), .. } => { - assert_eq!(dial_record.address(), &dial_address); + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } @@ -2447,7 +1958,7 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. + secondary: None, .. } => {} state => panic!("invalid state: {state:?}"), } @@ -2492,7 +2003,7 @@ mod tests { let established_result = manager .on_connection_established( peer, - &Endpoint::listener(address1, ConnectionId::from(0usize)), + &Endpoint::dialer(address1.clone(), ConnectionId::from(0usize)), ) .unwrap(); assert_eq!(established_result, ConnectionEstablishedResult::Accept); @@ -2504,10 +2015,8 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } @@ -2526,11 +2035,10 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. 
} => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } @@ -2550,12 +2058,17 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = peer.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); - assert!(peer.addresses.contains(&address3)); + assert_eq!(secondary_connection.address, address2); + // Endpoint::listener addresses are not tracked. + assert!(!peer.addresses.addresses.contains_key(&address2)); + assert!(!peer.addresses.addresses.contains_key(&address3)); + assert_eq!( + peer.addresses.addresses.get(&address1).unwrap().score(), + scores::CONNECTION_ESTABLISHED + ); } state => panic!("invalid state: {state:?}"), } @@ -2606,10 +2119,8 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } @@ -2624,16 +2135,10 @@ mod tests { state => panic!("invalid state: {state:?}"), }; - let dial_record = Some(AddressRecord::new( - &peer, - address2.clone(), - 0, - Some(ConnectionId::from(0usize)), - )); - + let dial_record = ConnectionRecord::new(peer, address2.clone(), ConnectionId::from(0)); peer_context.state = PeerState::Connected { record, - dial_record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), }; } @@ -2713,15 +2218,18 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. + record, + secondary: None, + .. 
} => { - assert!(peer.secondary_connection.is_none()); + // Primary connection is established. + assert_eq!(record.connection_id, ConnectionId::from(0usize)); } state => panic!("invalid state: {state:?}"), } } - // second connection is established, verify that the seconary connection is tracked + // second connection is established, verify that the secondary connection is tracked let emit_event = manager .on_connection_established( peer, @@ -2738,18 +2246,17 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } drop(peers); // close the secondary connection and verify that the peer remains connected - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(1usize)); assert!(emit_event.is_none()); let peers = manager.peers.read(); @@ -2757,12 +2264,16 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, + secondary: None, record, } => { - assert!(context.secondary_connection.is_none()); - assert!(context.addresses.contains(&address2)); - assert_eq!(record.connection_id(), &Some(ConnectionId::from(0usize))); + assert!(context.addresses.addresses.contains_key(&address2)); + assert_eq!( + context.addresses.addresses.get(&address2).unwrap().score(), + scores::CONNECTION_ESTABLISHED + ); + // Primary remains opened. 
+ assert_eq!(record.connection_id, ConnectionId::from(0usize)); } state => panic!("invalid state: {state:?}"), } @@ -2809,22 +2320,20 @@ mod tests { ConnectionEstablishedResult::Accept )); - // verify that the peer state is `Connected` with no seconary connection + // verify that the peer state is `Connected` with no secondary connection { let peers = manager.peers.read(); let peer = peers.get(&peer).unwrap(); match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } - // second connection is established, verify that the seconary connection is tracked + // second connection is established, verify that the secondary connection is tracked let emit_event = manager .on_connection_established( peer, @@ -2841,11 +2350,10 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. 
} => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } @@ -2853,7 +2361,7 @@ mod tests { // close the primary connection and verify that the peer remains connected // while the primary connection address is stored in peer addresses - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(0usize)).unwrap(); + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(0usize)); assert!(emit_event.is_none()); let peers = manager.peers.read(); @@ -2861,12 +2369,12 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, + secondary: None, record, } => { - assert!(context.secondary_connection.is_none()); - assert!(context.addresses.contains(&address1)); - assert_eq!(record.connection_id(), &Some(ConnectionId::from(1usize))); + assert!(!context.addresses.addresses.contains_key(&address1)); + assert!(context.addresses.addresses.contains_key(&address2)); + assert_eq!(record.connection_id, ConnectionId::from(1usize)); } state => panic!("invalid state: {state:?}"), } @@ -2914,7 +2422,7 @@ mod tests { let emit_event = manager .on_connection_established( peer, - &Endpoint::listener(address1, ConnectionId::from(0usize)), + &Endpoint::listener(address1.clone(), ConnectionId::from(0usize)), ) .unwrap(); assert!(std::matches!( @@ -2922,6 +2430,13 @@ mod tests { ConnectionEstablishedResult::Accept )); + // The address1 should be ignored because it is an inbound connection + // initiated from an ephemeral port. 
+ let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + assert!(!context.addresses.addresses.contains_key(&address1)); + drop(peers); + // verify that the peer state is `Connected` with no seconary connection { let peers = manager.peers.read(); @@ -2929,10 +2444,8 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } @@ -2949,16 +2462,21 @@ mod tests { ConnectionEstablishedResult::Accept )); + // Ensure we keep track of this address. + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + assert!(context.addresses.addresses.contains_key(&address2)); + drop(peers); + let peers = manager.peers.read(); let context = peers.get(&peer).unwrap(); match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } @@ -2978,11 +2496,13 @@ mod tests { let peers = manager.peers.read(); let context = peers.get(&peer).unwrap(); - assert!(context.addresses.contains(&address3)); + // The tertiary connection should be ignored because it is an inbound connection + // initiated from an ephemeral port. 
+ assert!(!context.addresses.addresses.contains_key(&address3)); drop(peers); // close the tertiary connection that was ignored - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(2usize)).unwrap(); + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(2usize)); assert!(emit_event.is_none()); // verify that the state remains unchanged @@ -2991,11 +2511,14 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); + assert_eq!( + context.addresses.addresses.get(&address2).unwrap().score(), + scores::CONNECTION_ESTABLISHED + ); } state => panic!("invalid state: {state:?}"), } @@ -3021,27 +2544,6 @@ mod tests { manager.on_dial_failure(ConnectionId::random()).unwrap(); } - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn dial_failure_for_unknow_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ConnectionLimitsConfig::default(), - ); - let connection_id = ConnectionId::random(); - let peer = PeerId::random(); - manager.pending_connections.insert(connection_id, peer); - manager.on_dial_failure(connection_id).unwrap(); - } - #[tokio::test] #[cfg(debug_assertions)] #[should_panic] @@ -3213,16 +2715,16 @@ mod tests { peer, PeerContext { state: PeerState::Connected { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) 
.with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - dial_record: None, + connection_id: ConnectionId::from(0usize), + }, + secondary: None, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -3261,15 +2763,15 @@ mod tests { peer, PeerContext { state: PeerState::Dialing { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + dial_record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), + connection_id: ConnectionId::from(0usize), + }, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -3292,10 +2794,10 @@ mod tests { let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { + PeerState::Dialing { dial_record } => { assert_eq!( - record.address(), - &Multiaddr::empty() + dial_record.address, + Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))) @@ -3324,17 +2826,16 @@ mod tests { peer, PeerContext { state: PeerState::Disconnected { - dial_record: Some( - AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - ), + dial_record: Some(ConnectionRecord::new( + peer, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ConnectionId::from(0), + )), }, - secondary_connection: None, + addresses: AddressStore::new(), }, ); @@ -3540,19 +3041,13 @@ mod tests { let peers 
= manager.peers.read(); match peers.get(&peer).unwrap() { PeerContext { - state: - PeerState::Connected { - record, - dial_record, - }, - secondary_connection, + state: PeerState::Connected { record, secondary }, addresses, } => { - assert!(!addresses.contains(record.address())); - assert!(dial_record.is_none()); - assert!(secondary_connection.is_none()); - assert_eq!(record.address(), &dial_address); - assert_eq!(record.connection_id(), &Some(connection_id)); + assert!(!addresses.addresses.contains_key(&record.address)); + assert!(secondary.is_none()); + assert_eq!(record.address, dial_address); + assert_eq!(record.connection_id, connection_id); } state => panic!("invalid peer state: {state:?}"), } @@ -3632,19 +3127,15 @@ mod tests { let peers = manager.peers.read(); match peers.get(&peer).unwrap() { PeerContext { - state: - PeerState::Connected { - record, - dial_record, - }, - secondary_connection, + state: PeerState::Connected { record, secondary }, addresses, } => { - assert!(addresses.is_empty()); - assert!(dial_record.is_none()); - assert!(secondary_connection.is_none()); - assert_eq!(record.address(), &dial_address); - assert_eq!(record.connection_id(), &Some(connection_id)); + // Saved from the dial attempt. + assert_eq!(addresses.addresses.get(&dial_address).unwrap().score(), 0); + + assert!(secondary.is_none()); + assert_eq!(record.address, dial_address); + assert_eq!(record.connection_id, connection_id); } state => panic!("invalid peer state: {state:?}"), } @@ -3714,7 +3205,7 @@ mod tests { assert_eq!(result, ConnectionEstablishedResult::Reject); // Close one connection. - let _ = manager.on_connection_closed(peer, first_connection_id).unwrap(); + assert!(manager.on_connection_closed(peer, first_connection_id).is_none()); // The second peer can establish 2 inbounds now. let result = manager @@ -3805,7 +3296,7 @@ mod tests { )); // Close one connection. 
- let _ = manager.on_connection_closed(peer, first_connection_id).unwrap(); + assert!(manager.on_connection_closed(peer, first_connection_id).is_some()); // We can now dial again. manager.dial_address(first_addr.clone()).await.unwrap(); @@ -3832,7 +3323,7 @@ mod tests { // Random peer ID. let peer = PeerId::random(); - let (first_addr, first_connection_id) = setup_dial_addr(peer, 0); + let (first_addr, _first_connection_id) = setup_dial_addr(peer, 0); let second_connection_id = ConnectionId::from(1); let different_connection_id = ConnectionId::from(2); @@ -3841,18 +3332,16 @@ mod tests { let mut peers = manager.peers.write(); let state = PeerState::Connected { - record: AddressRecord::new(&peer, first_addr.clone(), 0, Some(first_connection_id)), - dial_record: Some(AddressRecord::new( - &peer, + record: ConnectionRecord::new(peer, first_addr.clone(), ConnectionId::from(0)), + secondary: Some(SecondaryOrDialing::Dialing(ConnectionRecord::new( + peer, first_addr.clone(), - 0, - Some(second_connection_id), - )), + second_connection_id, + ))), }; let peer_context = PeerContext { state, - secondary_connection: None, addresses: AddressStore::from_iter(vec![first_addr.clone()].into_iter()), }; @@ -3911,8 +3400,8 @@ mod tests { let peers = manager.peers.read(); let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &first_addr); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, first_addr); } state => panic!("invalid state: {state:?}"), } @@ -3932,21 +3421,20 @@ mod tests { match &peer_context.state { PeerState::Connected { record, - dial_record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), } => { - assert_eq!(record.address(), &remote_addr); - assert_eq!(record.connection_id(), &Some(remote_connection_id)); + assert_eq!(record.address, remote_addr); + assert_eq!(record.connection_id, remote_connection_id); - let dial_record = 
dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)) + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id) } state => panic!("invalid state: {state:?}"), } } // Step 3. The peer disconnects while we have a dialing in flight. - let event = manager.on_connection_closed(peer, remote_connection_id).unwrap().unwrap(); + let event = manager.on_connection_closed(peer, remote_connection_id).unwrap(); match event { TransportEvent::ConnectionClosed { peer: event_peer, @@ -3963,8 +3451,8 @@ mod tests { match &peer_context.state { PeerState::Disconnected { dial_record } => { let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id); } state => panic!("invalid state: {state:?}"), } @@ -3979,8 +3467,8 @@ mod tests { match &peer_context.state { PeerState::Disconnected { dial_record } => { let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id); } state => panic!("invalid state: {state:?}"), } @@ -4000,15 +3488,14 @@ mod tests { match &peer_context.state { PeerState::Connected { record, - dial_record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), } => { - assert_eq!(record.address(), &remote_addr); - assert_eq!(record.connection_id(), &Some(remote_connection_id)); + assert_eq!(record.address, remote_addr); + assert_eq!(record.connection_id, remote_connection_id); // We have not overwritten the first dial record in step 4. 
- let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id); } state => panic!("invalid state: {state:?}"), } @@ -4025,7 +3512,7 @@ mod tests { } #[tokio::test] - async fn do_not_overwrite_dial_addresses() { + async fn persist_dial_addresses() { let _ = tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .try_init(); @@ -4063,14 +3550,17 @@ mod tests { let peers = manager.peers.read(); let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } - // The address is not saved yet. - assert!(!peer_context.addresses.contains(&dial_address)); + // The address is saved for future dials. + assert_eq!( + peer_context.addresses.addresses.get(&dial_address).unwrap().score(), + 0 + ); } let second_address = Multiaddr::empty() @@ -4088,14 +3578,21 @@ mod tests { let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { // Must still be dialing the first address. - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } - assert!(!peer_context.addresses.contains(&dial_address)); - assert!(!peer_context.addresses.contains(&second_address)); + // The address is still saved, even if a second dial is not initiated. 
+ assert_eq!( + peer_context.addresses.addresses.get(&dial_address).unwrap().score(), + 0 + ); + assert_eq!( + peer_context.addresses.addresses.get(&second_address).unwrap().score(), + 0 + ); } } diff --git a/src/transport/manager/peer_state.rs b/src/transport/manager/peer_state.rs new file mode 100644 index 00000000..923bdf98 --- /dev/null +++ b/src/transport/manager/peer_state.rs @@ -0,0 +1,928 @@ +// Copyright 2024 litep2p developers +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Peer state management. + +use crate::{ + transport::{manager::SupportedTransport, Endpoint}, + types::ConnectionId, + PeerId, +}; + +use multiaddr::{Multiaddr, Protocol}; +use multihash::Multihash; + +use std::collections::HashSet; + +/// The peer state that tracks connections and dialing attempts. 
+/// +/// # State Machine +/// +/// ## [`PeerState::Disconnected`] +/// +/// Initially, the peer is in the [`PeerState::Disconnected`] state without a +/// [`PeerState::Disconnected::dial_record`]. This means the peer is fully disconnected. +/// +/// Next states: +/// - [`PeerState::Disconnected`] -> [`PeerState::Dialing`] (via [`PeerState::dial_single_address`]) +/// - [`PeerState::Disconnected`] -> [`PeerState::Opening`] (via [`PeerState::dial_addresses`]) +/// +/// ## [`PeerState::Dialing`] +/// +/// The peer can transition to the [`PeerState::Dialing`] state when a dialing attempt is +/// initiated. This only happens when the peer is dialed on a single address via +/// [`PeerState::dial_single_address`]. +/// +/// The dialing state implies the peer is reached on the socket address provided, as well as +/// negotiating noise and yamux protocols. +/// +/// Next states: +/// - [`PeerState::Dialing`] -> [`PeerState::Connected`] (via +/// [`PeerState::on_connection_established`]) +/// - [`PeerState::Dialing`] -> [`PeerState::Disconnected`] (via [`PeerState::on_dial_failure`]) +/// +/// ## [`PeerState::Opening`] +/// +/// The peer can transition to the [`PeerState::Opening`] state when a dialing attempt is +/// initiated on multiple addresses via [`PeerState::dial_addresses`]. This takes into account +/// the parallelism factor (8 maximum) of the dialing attempts. +/// +/// The opening state holds information about which protocol is being dialed to properly report back +/// errors. +/// +/// The opening state is similar to the dial state, however the peer is only reached on a socket +/// address. The noise and yamux protocols are not negotiated yet. This state transitions to +/// [`PeerState::Dialing`] for the final part of the negotiation. Please note that it would be +/// wasteful to negotiate the noise and yamux protocols on all addresses, since only one +/// connection is kept around. 
+/// +/// This is something we'll reconsider in the future if we encounter issues. +/// +/// Next states: +/// - [`PeerState::Opening`] -> [`PeerState::Dialing`] (via transport manager +/// `on_connection_opened`) +/// - [`PeerState::Opening`] -> [`PeerState::Disconnected`] (via transport manager +/// `on_connection_opened` if negotiation cannot be started or via `on_open_failure`) +#[derive(Debug, Clone, PartialEq)] +pub enum PeerState { + /// `Litep2p` is connected to peer. + Connected { + /// The established record of the connection. + record: ConnectionRecord, + + /// Secondary record: either a pending dial record or an established secondary connection. + /// + /// While the local node was dialing a remote peer, the remote peer might've dialed + /// the local node and the connection was established successfully. The dial record + /// is kept for processing once the dial attempt concludes, either successfully or + /// with a failure. + secondary: Option<SecondaryOrDialing>, + }, + + /// Connection to peer is opening over one or more addresses. + Opening { + /// Address records used for dialing. + addresses: HashSet<Multiaddr>, + + /// Connection ID. + connection_id: ConnectionId, + + /// Active transports. + transports: HashSet<SupportedTransport>, + }, + + /// Peer is being dialed. + Dialing { + /// Address record. + dial_record: ConnectionRecord, + }, + + /// `Litep2p` is not connected to peer. + Disconnected { + /// Dial address, if it exists. + /// + /// While the local node was dialing a remote peer, the remote peer might've dialed + /// the local node and connection was established successfully. The connection might've + /// been closed before the dial concluded which means that + /// [`crate::transport::manager::TransportManager`] must be prepared to handle the dial + /// failure even after the connection has been closed. + dial_record: Option<ConnectionRecord>, + }, +} + +/// The state of the secondary connection. +#[derive(Debug, Clone, PartialEq)] +pub enum SecondaryOrDialing { + /// The secondary connection is established. 
+ Secondary(ConnectionRecord), + /// The primary connection is established, but the secondary connection is still dialing. + Dialing(ConnectionRecord), +} + +/// Result of initiating a dial. +#[derive(Debug, Clone, PartialEq)] +pub enum StateDialResult { + /// The peer is already connected. + AlreadyConnected, + /// The dialing state is already in progress. + DialingInProgress, + /// The peer is disconnected, start dialing. + Ok, +} + +impl PeerState { + /// Check if the peer can be dialed. + pub fn can_dial(&self) -> StateDialResult { + match self { + // The peer is already connected, no need to dial again. + Self::Connected { .. } => return StateDialResult::AlreadyConnected, + // The dialing state is already in progress, an event will be emitted later. + Self::Dialing { .. } + | Self::Opening { .. } + | Self::Disconnected { + dial_record: Some(_), + } => { + return StateDialResult::DialingInProgress; + } + + Self::Disconnected { dial_record: None } => StateDialResult::Ok, + } + } + + /// Dial the peer on a single address. + pub fn dial_single_address(&mut self, dial_record: ConnectionRecord) -> StateDialResult { + let check = self.can_dial(); + if check != StateDialResult::Ok { + return check; + } + + match self { + Self::Disconnected { dial_record: None } => { + *self = PeerState::Dialing { dial_record }; + return StateDialResult::Ok; + } + state => panic!( + "unexpected state: {:?} validated by Self::can_dial; qed", + state + ), + } + } + + /// Dial the peer on multiple addresses. 
+ pub fn dial_addresses( + &mut self, + connection_id: ConnectionId, + addresses: HashSet<Multiaddr>, + transports: HashSet<SupportedTransport>, + ) -> StateDialResult { + let check = self.can_dial(); + if check != StateDialResult::Ok { + return check; + } + + match self { + Self::Disconnected { dial_record: None } => { + *self = PeerState::Opening { + addresses, + connection_id, + transports, + }; + return StateDialResult::Ok; + } + state => panic!( + "unexpected state: {:?} validated by Self::can_dial; qed", + state + ), + } + } + + /// Handle dial failure. + /// + /// # Transitions + /// + /// - [`PeerState::Dialing`] (with record) -> [`PeerState::Disconnected`] + /// - [`PeerState::Connected`] (with dial record) -> [`PeerState::Connected`] + /// - [`PeerState::Disconnected`] (with dial record) -> [`PeerState::Disconnected`] + /// + /// Returns `true` if `connection_id` matched the pending dial record, which is then + /// cleared. Returns `false` otherwise and leaves the state unchanged. + pub fn on_dial_failure(&mut self, connection_id: ConnectionId) -> bool { + match self { + // Clear the dial record if the connection ID matches. + Self::Dialing { dial_record } => + if dial_record.connection_id == connection_id { + *self = Self::Disconnected { dial_record: None }; + return true; + }, + + Self::Connected { + record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), + } => + if dial_record.connection_id == connection_id { + *self = Self::Connected { + record: record.clone(), + secondary: None, + }; + return true; + }, + + Self::Disconnected { + dial_record: Some(dial_record), + } => + if dial_record.connection_id == connection_id { + *self = Self::Disconnected { dial_record: None }; + return true; + }, + + _ => (), + }; + + false + } + + /// Returns `true` if the connection should be accepted by the transport manager. + pub fn on_connection_established(&mut self, connection: ConnectionRecord) -> bool { + match self { + // Transform the dial record into a secondary connection. 
+ Self::Connected { + record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), + } => + if dial_record.connection_id == connection.connection_id { + *self = Self::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(connection)), + }; + + return true; + }, + // There's place for a secondary connection. + Self::Connected { + record, + secondary: None, + } => { + *self = Self::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(connection)), + }; + + return true; + } + + // Convert the dial record into a primary connection or preserve it. + Self::Dialing { dial_record } + | Self::Disconnected { + dial_record: Some(dial_record), + } => + if dial_record.connection_id == connection.connection_id { + *self = Self::Connected { + record: connection.clone(), + secondary: None, + }; + return true; + } else { + *self = Self::Connected { + record: connection, + secondary: Some(SecondaryOrDialing::Dialing(dial_record.clone())), + }; + return true; + }, + + Self::Disconnected { dial_record: None } => { + *self = Self::Connected { + record: connection, + secondary: None, + }; + + return true; + } + + // Accept the incoming connection. + Self::Opening { .. } => { + *self = Self::Connected { + record: connection, + secondary: None, + }; + + return true; + } + + _ => {} + }; + + return false; + } + + /// Handle a closed connection. + /// + /// Returns `true` only if the peer transitioned to [`PeerState::Disconnected`], i.e. the + /// primary connection was closed and no established secondary connection could be + /// promoted. Returns `false` if the peer remains connected or if `connection_id` did + /// not match any tracked connection. + pub fn on_connection_closed(&mut self, connection_id: ConnectionId) -> bool { + match self { + Self::Connected { record, secondary } => { + // Primary connection closed. + if record.connection_id == connection_id { + match secondary { + // Promote secondary connection to primary. + Some(SecondaryOrDialing::Secondary(secondary)) => { + *self = Self::Connected { + record: secondary.clone(), + secondary: None, + }; + } + // Preserve the dial record. 
+ Some(SecondaryOrDialing::Dialing(dial_record)) => { + *self = Self::Disconnected { + dial_record: Some(dial_record.clone()), + }; + + return true; + } + None => { + *self = Self::Disconnected { dial_record: None }; + + return true; + } + }; + + return false; + } + + match secondary { + // Secondary connection closed. + Some(SecondaryOrDialing::Secondary(secondary)) + if secondary.connection_id == connection_id => + { + *self = Self::Connected { + record: record.clone(), + secondary: None, + }; + } + _ => (), + } + } + _ => (), + } + + false + } + + /// Returns `true` if the last transport failed to open. + pub fn on_open_failure(&mut self, transport: SupportedTransport) -> bool { + match self { + Self::Opening { transports, .. } => { + transports.remove(&transport); + + if transports.is_empty() { + *self = Self::Disconnected { dial_record: None }; + return true; + } + + return false; + } + _ => false, + } + } + + /// Returns `true` if the connection was opened. + pub fn on_connection_opened(&mut self, record: ConnectionRecord) -> bool { + match self { + Self::Opening { .. } => { + // TODO: Litep2p did not check previously if the + // connection record is valid or not, in terms of having + // the same connection ID and the address part of the + // address set. + + *self = Self::Dialing { + dial_record: record.clone(), + }; + + true + } + _ => false, + } + } +} + +/// The connection record keeps track of the connection ID and the address of the connection. +/// +/// The connection ID is used to track the connection in the transport layer, while the +/// address is used to keep a healthy view of the network for dialing purposes. +/// +/// # Note +/// +/// The structure is used to keep track of: +/// +/// - dialing state for outbound connections. +/// - established outbound connections via [`PeerState::Connected`]. +/// - secondary established connections via the `secondary` field of [`PeerState::Connected`]. 
+#[derive(Debug, Clone, Hash, PartialEq)] +pub struct ConnectionRecord { + /// Address of the connection. + /// + /// The address must contain the peer ID extension `/p2p/`. + pub address: Multiaddr, + + /// Connection ID resulted from dialing. + pub connection_id: ConnectionId, +} + +impl ConnectionRecord { + /// Construct a new connection record. + pub fn new(peer: PeerId, address: Multiaddr, connection_id: ConnectionId) -> Self { + Self { + address: Self::ensure_peer_id(peer, address), + connection_id, + } + } + + /// Create a new connection record from the peer ID and the endpoint. + pub fn from_endpoint(peer: PeerId, endpoint: &Endpoint) -> Self { + Self { + address: Self::ensure_peer_id(peer, endpoint.address().clone()), + connection_id: endpoint.connection_id(), + } + } + + /// Ensures the peer ID is present in the address. + fn ensure_peer_id(peer: PeerId, address: Multiaddr) -> Multiaddr { + if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { + address.with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).expect("valid peer id"), + )) + } else { + address + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn state_can_dial() { + let state = PeerState::Disconnected { dial_record: None }; + assert_eq!(state.can_dial(), StateDialResult::Ok); + + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + assert_eq!(state.can_dial(), StateDialResult::DialingInProgress); + + let state = PeerState::Dialing { + dial_record: record.clone(), + }; + assert_eq!(state.can_dial(), StateDialResult::DialingInProgress); + + let state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: Default::default(), + }; + assert_eq!(state.can_dial(), StateDialResult::DialingInProgress); + + let state = PeerState::Connected { + 
record, + secondary: None, + }; + assert_eq!(state.can_dial(), StateDialResult::AlreadyConnected); + } + + #[test] + fn state_dial_single_address() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let mut state = PeerState::Disconnected { dial_record: None }; + assert_eq!( + state.dial_single_address(record.clone()), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Dialing { + dial_record: record + } + ); + } + + #[test] + fn state_dial_addresses() { + let mut state = PeerState::Disconnected { dial_record: None }; + assert_eq!( + state.dial_addresses( + ConnectionId::from(0), + Default::default(), + Default::default() + ), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: Default::default() + } + ); + } + + #[test] + fn check_dial_failure() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + // Check from the dialing state. + { + let mut state = PeerState::Dialing { + dial_record: record.clone(), + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + + // Check from the connected state without dialing state. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: None, + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + // The connection ID is checked against dialing records, not established connections. 
+ state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, previous_state); + } + + // Check from the connected state with dialing state. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())), + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + // Dial record is cleared. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Check from the disconnected state. + { + let mut state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + } + + #[test] + fn check_connection_established() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + let second_record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(1), + ); + + // Check from the connected state without secondary connection. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: None, + }; + // Secondary is established. + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(record.clone())), + } + ); + } + + // Check from the connected state with secondary dialing connection. 
+ { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())), + }; + // Promote the secondary connection. + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(record.clone())), + } + ); + } + + // Check from the connected state with secondary established connection. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(record.clone())), + }; + // No state to advance. + assert!(!state.on_connection_established(record.clone())); + } + + // Opening state is completely wiped out. + { + let mut state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: Default::default(), + }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Disconnected state with dial record. + { + let mut state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Disconnected with different dial record. + { + let mut state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + assert!(state.on_connection_established(second_record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: second_record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())) + } + ); + } + + // Disconnected without dial record. 
+ { + let mut state = PeerState::Disconnected { dial_record: None }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Dialing with different dial record. + { + let mut state = PeerState::Dialing { + dial_record: record.clone(), + }; + assert!(state.on_connection_established(second_record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: second_record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())) + } + ); + } + + // Dialing with the same dial record. + { + let mut state = PeerState::Dialing { + dial_record: record.clone(), + }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + } + + #[test] + fn check_connection_closed() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + let second_record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(1), + ); + + // Primary is closed + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: None, + }; + assert!(state.on_connection_closed(ConnectionId::from(0))); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + + // Primary is closed with secondary promoted + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(second_record.clone())), + }; + // Peer is still connected. 
+ assert!(!state.on_connection_closed(ConnectionId::from(0))); + assert_eq!( + state, + PeerState::Connected { + record: second_record.clone(), + secondary: None, + } + ); + } + + // Primary is closed with secondary dial record + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(second_record.clone())), + }; + assert!(state.on_connection_closed(ConnectionId::from(0))); + assert_eq!( + state, + PeerState::Disconnected { + dial_record: Some(second_record.clone()) + } + ); + } + } + + #[test] + fn check_open_failure() { + let mut state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: [SupportedTransport::Tcp].into_iter().collect(), + }; + + // This is the last protocol + assert!(state.on_open_failure(SupportedTransport::Tcp)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + + #[test] + fn check_open_connection() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let mut state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: [SupportedTransport::Tcp].into_iter().collect(), + }; + + assert!(state.on_connection_opened(record.clone())); + } + + #[test] + fn check_full_lifecycle() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let mut state = PeerState::Disconnected { dial_record: None }; + // Dialing. + assert_eq!( + state.dial_single_address(record.clone()), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Dialing { + dial_record: record.clone() + } + ); + + // Dialing failed. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + + // Opening. 
+ assert_eq!( + state.dial_addresses( + ConnectionId::from(0), + Default::default(), + Default::default() + ), + StateDialResult::Ok + ); + + // Open failure. + assert!(state.on_open_failure(SupportedTransport::Tcp)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + + // Dial again. + assert_eq!( + state.dial_single_address(record.clone()), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Dialing { + dial_record: record.clone() + } + ); + + // Successful dial. + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None + } + ); + } +} diff --git a/src/transport/manager/types.rs b/src/transport/manager/types.rs index b367a086..15eb2c50 100644 --- a/src/transport/manager/types.rs +++ b/src/transport/manager/types.rs @@ -18,14 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ - transport::manager::address::{AddressRecord, AddressStore}, - types::ConnectionId, -}; - -use multiaddr::Multiaddr; - -use std::collections::{HashMap, HashSet}; +use crate::transport::manager::{address::AddressStore, peer_state::PeerState}; /// Supported protocols. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] @@ -46,63 +39,21 @@ pub enum SupportedTransport { WebSocket, } -/// Peer state. -#[derive(Debug)] -pub enum PeerState { - /// `Litep2p` is connected to peer. - Connected { - /// Address record. - record: AddressRecord, - - /// Dial address, if it exists. - /// - /// While the local node was dialing a remote peer, the remote peer might've dialed - /// the local node and connection was established successfully. This dial address - /// is stored for processing later when the dial attempt concluded as either - /// successful/failed. - dial_record: Option, - }, - - /// Connection to peer is opening over one or more addresses. - Opening { - /// Address records used for dialing. 
- records: HashMap, - - /// Connection ID. - connection_id: ConnectionId, - - /// Active transports. - transports: HashSet, - }, - - /// Peer is being dialed. - Dialing { - /// Address record. - record: AddressRecord, - }, - - /// `Litep2p` is not connected to peer. - Disconnected { - /// Dial address, if it exists. - /// - /// While the local node was dialing a remote peer, the remote peer might've dialed - /// the local node and connection was established successfully. The connection might've - /// been closed before the dial concluded which means that - /// [`crate::transport::manager::TransportManager`] must be prepared to handle the dial - /// failure even after the connection has been closed. - dial_record: Option, - }, -} - /// Peer context. #[derive(Debug)] pub struct PeerContext { /// Peer state. pub state: PeerState, - /// Secondary connection, if it's open. - pub secondary_connection: Option, - /// Known addresses of peer. pub addresses: AddressStore, } + +impl Default for PeerContext { + fn default() -> Self { + Self { + state: PeerState::Disconnected { dial_record: None }, + addresses: AddressStore::new(), + } + } +}