diff --git a/src/transport/manager/address.rs b/src/transport/manager/address.rs index e712cfe1..30a13a4b 100644 --- a/src/transport/manager/address.rs +++ b/src/transport/manager/address.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{error::DialError, types::ConnectionId, PeerId}; +use crate::{error::DialError, PeerId}; use multiaddr::{Multiaddr, Protocol}; use multihash::Multihash; @@ -50,9 +50,6 @@ pub struct AddressRecord { /// Address. address: Multiaddr, - - /// Connection ID, if specified. - connection_id: Option, } impl AsRef for AddressRecord { @@ -64,12 +61,7 @@ impl AsRef for AddressRecord { impl AddressRecord { /// Create new `AddressRecord` and if `address` doesn't contain `P2p`, /// append the provided `PeerId` to the address. - pub fn new( - peer: &PeerId, - address: Multiaddr, - score: i32, - connection_id: Option, - ) -> Self { + pub fn new(peer: &PeerId, address: Multiaddr, score: i32) -> Self { let address = if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { address.with(Protocol::P2p( Multihash::from_bytes(&peer.to_bytes()).expect("valid peer id"), @@ -78,11 +70,7 @@ impl AddressRecord { address }; - Self { - address, - score, - connection_id, - } + Self { address, score } } /// Create `AddressRecord` from `Multiaddr`. @@ -97,7 +85,6 @@ impl AddressRecord { Some(AddressRecord { address, score: 0i32, - connection_id: None, }) } @@ -112,20 +99,10 @@ impl AddressRecord { &self.address } - /// Get connection ID. - pub fn connection_id(&self) -> &Option { - &self.connection_id - } - /// Update score of an address. pub fn update_score(&mut self, score: i32) { self.score = self.score.saturating_add(score); } - - /// Set `ConnectionId` for the [`AddressRecord`]. - pub fn set_connection_id(&mut self, connection_id: ConnectionId) { - self.connection_id = Some(connection_id); - } } impl PartialEq for AddressRecord { @@ -161,8 +138,8 @@ impl FromIterator for AddressStore { fn from_iter>(iter: T) -> Self { let mut store = AddressStore::new(); for address in iter { - if let Some(address) = AddressRecord::from_multiaddr(address) { - store.insert(address); + if let Some(record) = AddressRecord::from_multiaddr(address) { + store.insert(record); } } @@ -292,7 +269,6 @@ mod tests { .with(Protocol::from(address.ip())) .with(Protocol::Tcp(address.port())), score, - None, ) } @@ -316,7 +292,6 @@ mod tests { .with(Protocol::Tcp(address.port())) .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), score, - None, ) } @@ -340,7 +315,6 @@ mod tests { .with(Protocol::Udp(address.port())) .with(Protocol::QuicV1), score, - None, ) } diff --git a/src/transport/manager/handle.rs b/src/transport/manager/handle.rs index 937ebc3e..26dad579 100644 --- a/src/transport/manager/handle.rs +++ b/src/transport/manager/handle.rs @@ -25,8 +25,9 @@ use crate::{ executor::Executor, protocol::ProtocolSet, transport::manager::{ - address::{AddressRecord, AddressStore}, - types::{PeerContext, PeerState, SupportedTransport}, + address::AddressRecord, + peer_state::StateDialResult, + types::{PeerContext, SupportedTransport}, ProtocolContext, TransportManagerEvent, LOG_TARGET, }, types::{protocol::ProtocolName, ConnectionId}, @@ -223,11 +224,7 @@ impl TransportManagerHandle { ); let mut peers = self.peers.write(); - let entry = peers.entry(*peer).or_insert_with(|| PeerContext { - state: PeerState::Disconnected { dial_record: None }, - addresses: AddressStore::new(), - secondary_connection: None, - }); + let entry = peers.entry(*peer).or_insert_with(|| PeerContext::default()); // All addresses should be valid at this point, since the peer ID was either added or // double checked. @@ -249,36 +246,21 @@ impl TransportManagerHandle { } { - match self.peers.read().get(peer) { - Some(PeerContext { - state: PeerState::Connected { .. }, - .. - }) => return Err(ImmediateDialError::AlreadyConnected), - Some(PeerContext { - state: PeerState::Disconnected { dial_record }, - addresses, - .. - }) => { - if addresses.is_empty() { - return Err(ImmediateDialError::NoAddressAvailable); - } - - // peer is already being dialed, don't dial again until the first dial concluded - if dial_record.is_some() { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?dial_record, - "peer is aready being dialed", - ); - return Ok(()); - } - } - Some(PeerContext { - state: PeerState::Dialing { .. } | PeerState::Opening { .. }, - .. - }) => return Ok(()), - None => return Err(ImmediateDialError::NoAddressAvailable), + let peers = self.peers.read(); + let Some(PeerContext { state, addresses }) = peers.get(peer) else { + return Err(ImmediateDialError::NoAddressAvailable); + }; + + match state.can_dial() { + StateDialResult::AlreadyConnected => + return Err(ImmediateDialError::AlreadyConnected), + StateDialResult::DialingInProgress => return Ok(()), + StateDialResult::Ok => {} + }; + + // Check if we have enough addresses to dial. + if addresses.is_empty() { + return Err(ImmediateDialError::NoAddressAvailable); } } @@ -338,6 +320,11 @@ impl TransportHandle { #[cfg(test)] mod tests { + use crate::transport::manager::{ + address::AddressStore, + peer_state::{ConnectionRecord, PeerState}, + }; + use super::*; use multihash::Multihash; use parking_lot::lock_api::RwLock; @@ -454,16 +441,16 @@ mod tests { peer, PeerContext { state: PeerState::Connected { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - dial_record: None, + connection_id: ConnectionId::from(0), + }, + secondary: None, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -497,15 +484,15 @@ mod tests { peer, PeerContext { state: PeerState::Dialing { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + dial_record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), + connection_id: ConnectionId::from(0), + }, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -539,7 +526,6 @@ mod tests { peer, PeerContext { state: PeerState::Disconnected { dial_record: None }, - secondary_connection: None, addresses: AddressStore::new(), }, ); @@ -565,17 +551,16 @@ mod tests { peer, PeerContext { state: PeerState::Disconnected { - dial_record: Some( - AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - ), + dial_record: Some(ConnectionRecord::new( + peer, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ConnectionId::from(0), + )), }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index ce17b1c8..7dba673d 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -27,9 +27,10 @@ use crate::{ protocol::{InnerTransportEvent, TransportService}, transport::{ manager::{ - address::{AddressRecord, AddressStore}, + address::AddressRecord, handle::InnerTransportManagerCommand, - types::{PeerContext, PeerState}, + peer_state::{ConnectionRecord, PeerState, StateDialResult}, + types::PeerContext, }, Endpoint, Transport, TransportEvent, }, @@ -37,7 +38,7 @@ use crate::{ BandwidthSink, PeerId, }; -use address::scores; +use address::{scores, AddressStore}; use futures::{Stream, StreamExt}; use indexmap::IndexMap; use multiaddr::{Multiaddr, Protocol}; @@ -46,7 +47,7 @@ use parking_lot::RwLock; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, + collections::{HashMap, HashSet}, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -61,6 +62,7 @@ pub use types::SupportedTransport; mod address; pub mod limits; +mod peer_state; mod types; pub(crate) mod handle; @@ -73,12 +75,6 @@ pub(crate) mod handle; /// Logging target for the file. const LOG_TARGET: &str = "litep2p::transport-manager"; -/// Score for a working address. -const SCORE_CONNECT_SUCCESS: i32 = 100i32; - -/// Score for a non-working address. -const SCORE_CONNECT_FAILURE: i32 = -100i32; - /// The connection established result. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum ConnectionEstablishedResult { @@ -320,7 +316,7 @@ impl TransportManager { } /// Get next connection ID. - fn next_connection_id(&mut self) -> ConnectionId { + fn next_connection_id(&self) -> ConnectionId { let connection_id = self.next_connection_id.fetch_add(1usize, Ordering::Relaxed); ConnectionId::from(connection_id) @@ -415,6 +411,31 @@ impl TransportManager { self.transport_manager_handle.add_known_address(&peer, address) } + /// Return multiple addresses to dial on supported protocols. + fn supported_transports_addresses( + addresses: &[Multiaddr], + ) -> HashMap> { + let mut transports = HashMap::>::new(); + + for address in addresses.iter().cloned() { + #[cfg(feature = "quic")] + if address.iter().any(|p| std::matches!(&p, Protocol::QuicV1)) { + transports.entry(SupportedTransport::Quic).or_default().push(address); + continue; + } + + #[cfg(feature = "websocket")] + if address.iter().any(|p| std::matches!(&p, Protocol::Ws(_) | Protocol::Wss(_))) { + transports.entry(SupportedTransport::WebSocket).or_default().push(address); + continue; + } + + transports.entry(SupportedTransport::Tcp).or_default().push(address); + } + + transports + } + /// Dial peer using `PeerId`. /// /// Returns an error if the peer is unknown or the peer is already connected. @@ -430,157 +451,58 @@ impl TransportManager { } let mut peers = self.peers.write(); - // if the peer is disconnected, return its context - // - // otherwise set the state back what it was and return dial status to caller - let PeerContext { - state, - secondary_connection, - addresses, - } = match peers.remove(&peer) { - None => return Err(Error::PeerDoesntExist(peer)), - Some( - context @ PeerContext { - state: PeerState::Connected { .. }, - .. - }, - ) => { - peers.insert(peer, context); - return Err(Error::AlreadyConnected); - } - Some( - context @ PeerContext { - state: PeerState::Dialing { .. } | PeerState::Opening { .. }, - .. - }, - ) => { - peers.insert(peer, context); - return Ok(()); - } - Some(context) => context, - }; - - if let PeerState::Disconnected { - dial_record: Some(_), - } = &state - { - tracing::debug!( - target: LOG_TARGET, - ?peer, - "peer is already being dialed", - ); - - peers.insert( - peer, - PeerContext { - state, - secondary_connection, - addresses, - }, - ); - - return Ok(()); - } + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); - let mut records: HashMap<_, _> = addresses - .addresses(limit) - .into_iter() - .map(|address| (address.clone(), AddressRecord::new(&peer, address, 0, None))) - .collect(); + // Check if dialing is possible before allocating addresses. + match context.state.can_dial() { + StateDialResult::AlreadyConnected => return Err(Error::AlreadyConnected), + StateDialResult::DialingInProgress => return Ok(()), + StateDialResult::Ok => {} + }; - if records.is_empty() { + // The addresses are sorted by score and contain the remote peer ID. + // We double checked above that the remote peer is not the local peer. + let dial_addresses = context.addresses.addresses(limit); + if dial_addresses.is_empty() { return Err(Error::NoAddressAvailable(peer)); } - - let locked_addresses = self.listen_addresses.read(); - for record in records.values() { - if locked_addresses.contains(record.as_ref()) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?record, - "tried to dial self", - ); - - debug_assert!(false); - return Err(Error::TriedToDialSelf); - } - } - drop(locked_addresses); - - // set connection id for the address record and put peer into `Opening` state - let connection_id = - ConnectionId::from(self.next_connection_id.fetch_add(1usize, Ordering::Relaxed)); + let connection_id = self.next_connection_id(); tracing::debug!( target: LOG_TARGET, ?connection_id, - addresses = ?records, + addresses = ?dial_addresses, "dial remote peer", ); - let mut transports = HashSet::new(); - #[cfg(feature = "websocket")] - let mut websocket = Vec::new(); - #[cfg(feature = "quic")] - let mut quic = Vec::new(); - let mut tcp = Vec::new(); + let transports = Self::supported_transports_addresses(&dial_addresses); - for (address, record) in &mut records { - record.set_connection_id(connection_id); + // Dialing addresses will succeed because the `context.state.can_dial()` returned `Ok`. + let result = context.state.dial_addresses( + connection_id, + dial_addresses.iter().cloned().collect(), + transports.keys().cloned().collect(), + ); + if result != StateDialResult::Ok { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + state = ?context.state, + "invalid state for dialing", + ); + } - #[cfg(feature = "quic")] - if address.iter().any(|p| std::matches!(&p, Protocol::QuicV1)) { - quic.push(address.clone()); - transports.insert(SupportedTransport::Quic); + for (transport, addresses) in transports { + if addresses.is_empty() { continue; } - #[cfg(feature = "websocket")] - if address.iter().any(|p| std::matches!(&p, Protocol::Ws(_) | Protocol::Wss(_))) { - websocket.push(address.clone()); - transports.insert(SupportedTransport::WebSocket); + let Some(installed_transport) = self.transports.get_mut(&transport) else { continue; - } - - tcp.push(address.clone()); - transports.insert(SupportedTransport::Tcp); - } - - peers.insert( - peer, - PeerContext { - state: PeerState::Opening { - records, - connection_id, - transports, - }, - secondary_connection, - addresses, - }, - ); - - if !tcp.is_empty() { - self.transports - .get_mut(&SupportedTransport::Tcp) - .expect("transport to be supported") - .open(connection_id, tcp)?; - } - - #[cfg(feature = "quic")] - if !quic.is_empty() { - self.transports - .get_mut(&SupportedTransport::Quic) - .expect("transport to be supported") - .open(connection_id, quic)?; - } + }; - #[cfg(feature = "websocket")] - if !websocket.is_empty() { - self.transports - .get_mut(&SupportedTransport::WebSocket) - .expect("transport to be supported") - .open(connection_id, websocket)?; + installed_transport.open(connection_id, addresses)?; } self.pending_connections.insert(connection_id, peer); @@ -594,19 +516,19 @@ impl TransportManager { pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { self.connection_limits.on_dial_address()?; - let mut record = AddressRecord::from_multiaddr(address) + let address_record = AddressRecord::from_multiaddr(address) .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; - if self.listen_addresses.read().contains(record.as_ref()) { + if self.listen_addresses.read().contains(address_record.as_ref()) { return Err(Error::TriedToDialSelf); } - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "dial address"); + tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "dial address"); - let mut protocol_stack = record.as_ref().iter(); + let mut protocol_stack = address_record.as_ref().iter(); match protocol_stack .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? { Protocol::Ip4(_) | Protocol::Ip6(_) => {} Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) => {} @@ -616,29 +538,36 @@ impl TransportManager { ?transport, "invalid transport, expected `ip4`/`ip6`" ); - return Err(Error::TransportNotSupported(record.address().clone())); + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )); } }; let supported_transport = match protocol_stack .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? { Protocol::Tcp(_) => match protocol_stack.next() { #[cfg(feature = "websocket")] Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) => SupportedTransport::WebSocket, Some(Protocol::P2p(_)) => SupportedTransport::Tcp, - _ => return Err(Error::TransportNotSupported(record.address().clone())), + _ => + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )), }, #[cfg(feature = "quic")] Protocol::Udp(_) => match protocol_stack .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + .ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))? { Protocol::QuicV1 => SupportedTransport::Quic, _ => { - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "expected `quic-v1`"); - return Err(Error::TransportNotSupported(record.address().clone())); + tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "expected `quic-v1`"); + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )); } }, protocol => { @@ -648,77 +577,44 @@ impl TransportManager { "invalid protocol" ); - return Err(Error::TransportNotSupported(record.address().clone())); + return Err(Error::TransportNotSupported( + address_record.address().clone(), + )); } }; // when constructing `AddressRecord`, `PeerId` was verified to be part of the address let remote_peer_id = - PeerId::try_from_multiaddr(record.address()).expect("`PeerId` to exist"); + PeerId::try_from_multiaddr(address_record.address()).expect("`PeerId` to exist"); // set connection id for the address record and put peer into `Dialing` state let connection_id = self.next_connection_id(); - record.set_connection_id(connection_id); + let dial_record = ConnectionRecord { + address: address_record.address().clone(), + connection_id, + }; { let mut peers = self.peers.write(); - match peers.entry(remote_peer_id) { - Entry::Occupied(occupied) => { - let context = occupied.into_mut(); + let context = peers.entry(remote_peer_id).or_insert_with(|| PeerContext::default()); - context.addresses.insert(record.clone()); + // Keep the provided record around for possible future dials. + context.addresses.insert(address_record.clone()); - tracing::debug!( - target: LOG_TARGET, - peer = ?remote_peer_id, - state = ?context.state, - "peer state exists", - ); - - match context.state { - PeerState::Connected { .. } => { - return Err(Error::AlreadyConnected); - } - PeerState::Dialing { .. } | PeerState::Opening { .. } => { - return Ok(()); - } - PeerState::Disconnected { - dial_record: Some(_), - } => { - tracing::debug!( - target: LOG_TARGET, - peer = ?remote_peer_id, - state = ?context.state, - "peer is already being dialed from a disconnected state" - ); - return Ok(()); - } - PeerState::Disconnected { dial_record: None } => { - context.state = PeerState::Dialing { - record: record.clone(), - }; - } - } - } - Entry::Vacant(vacant) => { - let mut addresses = AddressStore::new(); - addresses.insert(record.clone()); - vacant.insert(PeerContext { - state: PeerState::Dialing { - record: record.clone(), - }, - addresses, - secondary_connection: None, - }); - } + match context.state.dial_single_address(dial_record) { + StateDialResult::AlreadyConnected => return Err(Error::AlreadyConnected), + StateDialResult::DialingInProgress => return Ok(()), + StateDialResult::Ok => {} }; } self.transports .get_mut(&supported_transport) - .ok_or(Error::TransportNotSupported(record.address().clone()))? - .dial(connection_id, record.address().clone())?; + .ok_or(Error::TransportNotSupported( + address_record.address().clone(), + ))? + .dial(connection_id, address_record.address().clone())?; self.pending_connections.insert(connection_id, remote_peer_id); Ok(()) @@ -742,13 +638,15 @@ impl TransportManager { // We need a valid context for this peer to keep track of failed addresses. let context = peers.entry(peer_id).or_insert_with(|| PeerContext::default()); - context - .addresses - .insert(AddressRecord::new(&peer_id, address.clone(), score, None)); + context.addresses.insert(AddressRecord::new(&peer_id, address.clone(), score)); } /// Handle dial failure. + /// + /// The main purpose of this function is to advance the internal `PeerState`. fn on_dial_failure(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?connection_id, "on dial failure"); + let peer = self.pending_connections.remove(&connection_id).ok_or_else(|| { tracing::error!( target: LOG_TARGET, @@ -759,124 +657,29 @@ impl TransportManager { })?; let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { - tracing::error!( + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + let previous_state = context.state.clone(); + + if !context.state.on_dial_failure(connection_id) { + tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "dial failed for a peer that doesn't exist", + state = ?context.state, + "invalid state for dial failure", + ); + } else { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?previous_state, + state = ?context.state, + "on dial failure completed" ); - debug_assert!(false); - - Error::InvalidState - })?; - - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Dialing { ref mut record } => { - debug_assert_eq!(record.connection_id(), &Some(connection_id)); - if record.connection_id() != &Some(connection_id) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - "unknown dial failure for a dialing peer", - ); - - context.state = PeerState::Dialing { - record: record.clone(), - }; - debug_assert!(false); - return Ok(()); - } - - record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(record.clone()); - - context.state = PeerState::Disconnected { dial_record: None }; - Ok(()) - } - PeerState::Opening { .. } => { - todo!(); - } - PeerState::Connected { - record, - dial_record: Some(mut dial_record), - } => { - if dial_record.connection_id() != &Some(connection_id) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - "unknown dial failure for a connected peer", - ); - - context.state = PeerState::Connected { - record, - dial_record: Some(dial_record), - }; - debug_assert!(false); - return Ok(()); - } - - dial_record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(dial_record); - - context.state = PeerState::Connected { - record, - dial_record: None, - }; - Ok(()) - } - PeerState::Disconnected { - dial_record: Some(mut dial_record), - } => { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?dial_record, - "dial failed for a disconnected peer", - ); - - if dial_record.connection_id() != &Some(connection_id) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?dial_record, - "unknown dial failure for a disconnected peer", - ); - - context.state = PeerState::Disconnected { - dial_record: Some(dial_record), - }; - debug_assert!(false); - return Ok(()); - } - - dial_record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(dial_record); - - Ok(()) - } - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "invalid state for dial failure", - ); - context.state = state; - - debug_assert!(false); - Ok(()) - } } + + Ok(()) } fn on_pending_incoming_connection(&mut self) -> crate::Result<()> { @@ -889,140 +692,40 @@ impl TransportManager { &mut self, peer: PeerId, connection_id: ConnectionId, - ) -> crate::Result> { + ) -> Option { + tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "connection closed"); + self.connection_limits.on_connection_closed(connection_id); let mut peers = self.peers.write(); - let Some(context) = peers.get_mut(&peer) else { + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + let previous_state = context.state.clone(); + let connection_closed = context.state.on_connection_closed(connection_id); + + if context.state == previous_state { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "cannot handle closed connection: peer doesn't exist", + state = ?context.state, + "invalid state for a closed connection", + ); + } else { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?previous_state, + state = ?context.state, + "on connection closed completed" ); - debug_assert!(false); - return Err(Error::PeerDoesntExist(peer)); - }; - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "connection closed", - ); - - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Connected { - record, - dial_record: actual_dial_record, - } => match record.connection_id() == &Some(connection_id) { - // primary connection was closed - // - // if secondary connection exists, switch to using it while keeping peer in - // `Connected` state and if there's only one connection, set peer - // state to `Disconnected` - true => match context.secondary_connection.take() { - None => { - context.addresses.insert(record); - context.state = PeerState::Disconnected { - dial_record: actual_dial_record, - }; - - Ok(Some(TransportEvent::ConnectionClosed { - peer, - connection_id, - })) - } - Some(secondary_connection) => { - context.addresses.insert(record); - context.state = PeerState::Connected { - record: secondary_connection, - dial_record: actual_dial_record, - }; - - Ok(None) - } - }, - // secondary connection was closed - false => match context.secondary_connection.take() { - Some(secondary_connection) => { - if secondary_connection.connection_id() != &Some(connection_id) { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "unknown connection was closed, potentially ignored tertiary connection", - ); - - context.secondary_connection = Some(secondary_connection); - context.state = PeerState::Connected { - record, - dial_record: actual_dial_record, - }; - - return Ok(None); - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "secondary connection closed", - ); - - context.addresses.insert(secondary_connection); - context.state = PeerState::Connected { - record, - dial_record: actual_dial_record, - }; - Ok(None) - } - None => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "non-primary connection was closed but secondary connection doesn't exist", - ); - - debug_assert!(false); - Err(Error::InvalidState) - } - }, - }, - PeerState::Disconnected { dial_record } => match context.secondary_connection.take() { - Some(record) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - ?dial_record, - "peer is disconnected but secondary connection exists", - ); - - debug_assert!(false); - context.state = PeerState::Disconnected { dial_record }; - Err(Error::InvalidState) - } - None => { - context.state = PeerState::Disconnected { dial_record }; - - Ok(Some(TransportEvent::ConnectionClosed { - peer, - connection_id, - })) - } - }, - state => { - tracing::warn!(target: LOG_TARGET, ?peer, ?connection_id, ?state, "invalid state for a closed connection"); - debug_assert!(false); - Err(Error::InvalidState) - } } + + connection_closed.then_some(TransportEvent::ConnectionClosed { + peer, + connection_id, + }) } /// Update the address on a connection established. @@ -1041,7 +744,6 @@ impl TransportManager { &peer, endpoint.address().clone(), scores::CONNECTION_ESTABLISHED, - None, ); let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); @@ -1085,239 +787,57 @@ impl TransportManager { } let mut peers = self.peers.write(); - match peers.get_mut(&peer) { - Some(context) => match context.state { - PeerState::Connected { - ref mut dial_record, - .. - } => match context.secondary_connection { - Some(_) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - "secondary connection already exists, ignoring connection", - ); - - return Ok(ConnectionEstablishedResult::Reject); - } - None => match dial_record.take() { - Some(record) - if record.connection_id() == &Some(endpoint.connection_id()) => - { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - "dialed connection opened as secondary connection", - ); - - context.secondary_connection = Some(AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - )); - } - None => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - "secondary connection", - ); - - context.secondary_connection = Some(AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - )); - } - Some(record) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - dial_record = ?record, - "unknown connection opened as secondary connection, discarding", - ); - - // Preserve the dial record. - *dial_record = Some(record); - - return Ok(ConnectionEstablishedResult::Reject); - } - }, - }, - PeerState::Dialing { ref record, .. } => { - match record.connection_id() == &Some(endpoint.connection_id()) { - true => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - ?record, - "connection opened to remote", - ); - - context.state = PeerState::Connected { - record: record.clone(), - dial_record: None, - }; - } - false => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - "connection opened by remote while local node was dialing", - ); + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); - context.state = PeerState::Connected { - record: AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - dial_record: Some(record.clone()), - }; - } - } - } - PeerState::Opening { - ref mut records, - connection_id, - ref transports, - } => { - debug_assert!(std::matches!(endpoint, &Endpoint::Listener { .. })); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - dial_connection_id = ?connection_id, - dial_records = ?records, - dial_transports = ?transports, - listener_endpoint = ?endpoint, - "inbound connection while opening an outbound connection", - ); + let previous_state = context.state.clone(); + let connection_accepted = context + .state + .on_connection_established(ConnectionRecord::from_endpoint(peer, endpoint)); - // cancel all pending dials - transports.iter().for_each(|transport| { - self.transports - .get_mut(transport) - .expect("transport to exist") - .cancel(connection_id); - }); - - // since an inbound connection was removed, the outbound connection can be - // removed from pending dials - // - // all records have the same `ConnectionId` so it doesn't matter which of them - // is used to remove the pending dial - self.pending_connections.remove( - &records - .iter() - .next() - .expect("record to exist") - .1 - .connection_id() - .expect("`ConnectionId` to exist"), - ); + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?endpoint, + ?previous_state, + state = ?context.state, + "on connection established completed" + ); - let record = match records.remove(endpoint.address()) { - Some(mut record) => { - record.update_score(SCORE_CONNECT_SUCCESS); - record.set_connection_id(endpoint.connection_id()); - record - } - None => AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - }; - context.addresses.extend(records.iter().map(|(_, record)| record)); - - context.state = PeerState::Connected { - record, - dial_record: None, - }; - } - PeerState::Disconnected { - ref mut dial_record, - } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - ?dial_record, - "connection opened by remote or delayed dial succeeded", - ); + if connection_accepted { + // Cancel all pending dials if the connection was established. + if let PeerState::Opening { + connection_id, + transports, + .. + } = previous_state + { + // cancel all pending dials + transports.iter().for_each(|transport| { + self.transports + .get_mut(transport) + .expect("transport to exist") + .cancel(connection_id); + }); - let (record, dial_record) = match dial_record.take() { - Some(mut dial_record) => - if dial_record.address() == endpoint.address() { - dial_record.set_connection_id(endpoint.connection_id()); - (dial_record, None) - } else { - ( - AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - Some(dial_record), - ) - }, - None => ( - AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - None, - ), - }; - - context.state = PeerState::Connected { - record, - dial_record, - }; - } - }, - None => { - peers.insert( - peer, - PeerContext { - state: PeerState::Connected { - record: AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_CONNECT_SUCCESS, - Some(endpoint.connection_id()), - ), - dial_record: None, - }, - addresses: AddressStore::new(), - secondary_connection: None, - }, - ); + // since an inbound connection was removed, the outbound connection can be + // removed from pending dials + // + // This may race in the following scenario: + // + // T0: we open address X on protocol TCP + // T1: remote peer opens a connection with us + // T2: address X is dialed and event is propagated from TCP to transport manager + // T3: `on_connection_established` is called for T1 and pending connections cleared + // T4: event from T2 is delivered. + // + // TODO: see https://github.com/paritytech/litep2p/issues/276 for more details. + self.pending_connections.remove(&connection_id); } + + return Ok(ConnectionEstablishedResult::Accept); } - Ok(ConnectionEstablishedResult::Accept) + Ok(ConnectionEstablishedResult::Reject) } fn on_connection_opened( @@ -1340,107 +860,83 @@ impl TransportManager { }; let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + // Keep track of the address. + context.addresses.insert(AddressRecord::new( + &peer, + address.clone(), + scores::CONNECTION_ESTABLISHED, + )); + + let previous_state = context.state.clone(); + let record = ConnectionRecord::new(peer, address.clone(), connection_id); + let state_advanced = context.state.on_connection_opened(record); + if !state_advanced { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "connection opened but peer doesn't exist", + state = ?context.state, + "connection opened but `PeerState` is not `Opening`", ); + return Err(Error::InvalidState); + } - debug_assert!(false); - Error::InvalidState - })?; + // State advanced from `Opening` to `Dialing`. + let PeerState::Opening { + connection_id, + transports, + .. + } = previous_state + else { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + state = ?context.state, + "State mismatch in opening expected by peer state transition", + ); + return Err(Error::InvalidState); + }; - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Opening { - mut records, - connection_id, - transports, - } => { + // Cancel open attempts for other transports as connection already exists. + for transport in transports.iter() { + self.transports + .get_mut(transport) + .expect("transport to exist") + .cancel(connection_id); + } + + let negotiation = self + .transports + .get_mut(&transport) + .expect("transport to exist") + .negotiate(connection_id); + + match negotiation { + Ok(()) => { tracing::trace!( target: LOG_TARGET, ?peer, ?connection_id, - ?address, ?transport, - "connection opened to peer", + "negotiation started" ); - // cancel open attempts for other transports as connection already exists - for transport in transports.iter() { - self.transports - .get_mut(transport) - .expect("transport to exist") - .cancel(connection_id); - } - - // set peer state to `Dialing` to signal that the connection is fully opening - // - // set the succeeded `AddressRecord` as the one that is used for dialing and move - // all other address records back to `AddressStore`. and ask - // transport to negotiate the - let mut dial_record = records.remove(&address).expect("address to exist"); - dial_record.update_score(SCORE_CONNECT_SUCCESS); - - // negotiate the connection - match self - .transports - .get_mut(&transport) - .expect("transport to exist") - .negotiate(connection_id) - { - Ok(()) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?dial_record, - ?transport, - "negotiation started" - ); - - self.pending_connections.insert(connection_id, peer); - - context.state = PeerState::Dialing { - record: dial_record, - }; - - for (_, record) in records { - context.addresses.insert(record); - } + self.pending_connections.insert(connection_id, peer); - Ok(()) - } - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?error, - "failed to negotiate connection", - ); - context.state = PeerState::Disconnected { dial_record: None }; - - debug_assert!(false); - Err(Error::InvalidState) - } - } + Ok(()) } - state => { + Err(err) => { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - ?state, - "connection opened but `PeerState` is not `Opening`", + ?err, + "failed to negotiate connection", ); - context.state = state; - - debug_assert!(false); + context.state = PeerState::Disconnected { dial_record: None }; Err(Error::InvalidState) } } @@ -1452,7 +948,7 @@ impl TransportManager { transport: SupportedTransport, connection_id: ConnectionId, ) -> crate::Result> { - let Some(peer) = self.pending_connections.remove(&connection_id) else { + let Some(peer) = self.pending_connections.get(&connection_id).copied() else { tracing::warn!( target: LOG_TARGET, ?connection_id, @@ -1462,75 +958,43 @@ impl TransportManager { }; let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { + let context = peers.entry(peer).or_insert_with(|| PeerContext::default()); + + let previous_state = context.state.clone(); + let last_transport = context.state.on_open_failure(transport); + + if context.state == previous_state { tracing::warn!( target: LOG_TARGET, ?peer, ?connection_id, - "open failure but peer doesn't exist", + ?transport, + state = ?context.state, + "invalid state for a open failure", ); - debug_assert!(false); - Error::InvalidState - })?; - - match std::mem::replace( - &mut context.state, - PeerState::Disconnected { dial_record: None }, - ) { - PeerState::Opening { - records, - connection_id, - mut transports, - } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?transport, - "open failure for peer", - ); - transports.remove(&transport); - - if transports.is_empty() { - for (_, mut record) in records { - record.update_score(SCORE_CONNECT_FAILURE); - context.addresses.insert(record); - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "open failure for last transport", - ); - - return Ok(Some(peer)); - } - - self.pending_connections.insert(connection_id, peer); - context.state = PeerState::Opening { - records, - connection_id, - transports, - }; + return Err(Error::InvalidState); + } - Ok(None) - } - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "open failure but `PeerState` is not `Opening`", - ); - context.state = state; + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?transport, + ?previous_state, + state = ?context.state, + "on open failure transition completed" + ); - debug_assert!(false); - Err(Error::InvalidState) - } + if last_transport { + tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "open failure for last transport"); + // Remove the pending connection. + self.pending_connections.remove(&connection_id); + // Provide the peer to notify the open failure. + return Ok(Some(peer)); } + + Ok(None) } /// Poll next event from [`crate::transport::manager::TransportManager`]. @@ -1542,13 +1006,8 @@ impl TransportManager { peer, connection: connection_id, } => match self.on_connection_closed(peer, connection_id) { - Ok(None) => {} - Ok(Some(event)) => return Some(event), - Err(error) => tracing::error!( - target: LOG_TARGET, - ?error, - "failed to handle closed connection", - ), + None => {} + Some(event) => return Some(event), } }, command = self.cmd_rx.recv() => match command? { @@ -1662,6 +1121,7 @@ impl TransportManager { } TransportEvent::ConnectionEstablished { peer, endpoint } => { self.opening_errors.remove(&endpoint.connection_id()); + match self.on_connection_established(peer, &endpoint) { Err(error) => { tracing::debug!( @@ -1826,6 +1286,7 @@ impl TransportManager { #[cfg(test)] mod tests { + use crate::transport::manager::{address::AddressStore, peer_state::SecondaryOrDialing}; use limits::ConnectionLimitsConfig; use multihash::Multihash; @@ -1841,6 +1302,7 @@ mod tests { use std::{ net::{Ipv4Addr, Ipv6Addr}, sync::Arc, + usize, }; /// Setup TCP address and connection id. @@ -2184,7 +1646,6 @@ mod tests { PeerContext { state: PeerState::Disconnected { dial_record: None }, addresses: AddressStore::new(), - secondary_connection: None, }, ); @@ -2291,8 +1752,8 @@ mod tests { assert_eq!(manager.pending_connections.len(), 1); match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state for peer: {state:?}"), } @@ -2312,8 +1773,8 @@ mod tests { let peer = peers.get(&peer).unwrap(); match &peer.state { - PeerState::Connected { dial_record, .. } => { - assert!(dial_record.is_none()); + PeerState::Connected { secondary, .. } => { + assert!(secondary.is_none()); assert!(peer.addresses.addresses.contains_key(&dial_address)); } state => panic!("invalid state: {state:?}"), @@ -2358,8 +1819,8 @@ mod tests { assert_eq!(manager.pending_connections.len(), 1); match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state for peer: {state:?}"), } @@ -2385,7 +1846,7 @@ mod tests { dial_record: Some(dial_record), .. } => { - assert_eq!(dial_record.address(), &dial_address); + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } @@ -2445,8 +1906,8 @@ mod tests { assert_eq!(manager.pending_connections.len(), 1); match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state for peer: {state:?}"), } @@ -2472,7 +1933,7 @@ mod tests { dial_record: Some(dial_record), .. } => { - assert_eq!(dial_record.address(), &dial_address); + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } @@ -2491,7 +1952,7 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. + secondary: None, .. } => {} state => panic!("invalid state: {state:?}"), } @@ -2548,10 +2009,8 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } @@ -2570,11 +2029,10 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } @@ -2594,12 +2052,10 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - assert_eq!( - peer.secondary_connection.as_ref().unwrap().address(), - &address2 - ); + assert_eq!(secondary_connection.address, address2); // Endpoint::listener addresses are not tracked. assert!(!peer.addresses.addresses.contains_key(&address2)); assert!(!peer.addresses.addresses.contains_key(&address3)); @@ -2656,10 +2112,8 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } @@ -2674,16 +2128,10 @@ mod tests { state => panic!("invalid state: {state:?}"), }; - let dial_record = Some(AddressRecord::new( - &peer, - address2.clone(), - 0, - Some(ConnectionId::from(0usize)), - )); - + let dial_record = ConnectionRecord::new(peer, address2.clone(), ConnectionId::from(0)); peer_context.state = PeerState::Connected { record, - dial_record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), }; } @@ -2763,15 +2211,18 @@ mod tests { match &peer.state { PeerState::Connected { - dial_record: None, .. + record, + secondary: None, + .. } => { - assert!(peer.secondary_connection.is_none()); + // Primary connection is established. + assert_eq!(record.connection_id, ConnectionId::from(0usize)); } state => panic!("invalid state: {state:?}"), } } - // second connection is established, verify that the seconary connection is tracked + // second connection is established, verify that the secondary connection is tracked let emit_event = manager .on_connection_established( peer, @@ -2788,18 +2239,17 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } drop(peers); // close the secondary connection and verify that the peer remains connected - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(1usize)); assert!(emit_event.is_none()); let peers = manager.peers.read(); @@ -2807,12 +2257,16 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, + secondary: None, record, } => { - assert!(context.secondary_connection.is_none()); assert!(context.addresses.addresses.contains_key(&address2)); - assert_eq!(record.connection_id(), &Some(ConnectionId::from(0usize))); + assert_eq!( + context.addresses.addresses.get(&address2).unwrap().score(), + scores::CONNECTION_ESTABLISHED + ); + // Primary remains opened. + assert_eq!(record.connection_id, ConnectionId::from(0usize)); } state => panic!("invalid state: {state:?}"), } @@ -2859,22 +2313,20 @@ mod tests { ConnectionEstablishedResult::Accept )); - // verify that the peer state is `Connected` with no seconary connection + // verify that the peer state is `Connected` with no secondary connection { let peers = manager.peers.read(); let peer = peers.get(&peer).unwrap(); match &peer.state { PeerState::Connected { - dial_record: None, .. - } => { - assert!(peer.secondary_connection.is_none()); - } + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } } - // second connection is established, verify that the seconary connection is tracked + // second connection is established, verify that the secondary connection is tracked let emit_event = manager .on_connection_established( peer, @@ -2891,11 +2343,10 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, .. + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_CONNECT_SUCCESS); + assert_eq!(secondary_connection.address, address2); } state => panic!("invalid state: {state:?}"), } @@ -2903,7 +2354,7 @@ mod tests { // close the primary connection and verify that the peer remains connected // while the primary connection address is stored in peer addresses - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(0usize)).unwrap(); + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(0usize)); assert!(emit_event.is_none()); let peers = manager.peers.read(); @@ -2911,12 +2362,12 @@ mod tests { match &context.state { PeerState::Connected { - dial_record: None, + secondary: None, record, } => { - assert!(context.secondary_connection.is_none()); - assert!(context.addresses.addresses.contains_key(&address1)); - assert_eq!(record.connection_id(), &Some(ConnectionId::from(1usize))); + assert!(!context.addresses.addresses.contains_key(&address1)); + assert!(context.addresses.addresses.contains_key(&address2)); + assert_eq!(record.connection_id, ConnectionId::from(1usize)); } state => panic!("invalid state: {state:?}"), } @@ -2985,10 +2436,11 @@ mod tests { let peer = peers.get(&peer).unwrap(); match &peer.state { - PeerState::Connected { .. } => {} + PeerState::Connected { + secondary: None, .. + } => {} state => panic!("invalid state: {state:?}"), } - assert!(peer.secondary_connection.is_none()); } // second connection is established, verify that the seconary connection is tracked @@ -3013,13 +2465,14 @@ mod tests { let context = peers.get(&peer).unwrap(); match &context.state { - PeerState::Connected { .. } => {} + PeerState::Connected { + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. + } => { + assert_eq!(secondary_connection.address, address2); + } state => panic!("invalid state: {state:?}"), } - assert_eq!( - context.secondary_connection.as_ref().unwrap().address(), - &address2, - ); drop(peers); // third connection is established, verify that it's discarded @@ -3042,7 +2495,7 @@ mod tests { drop(peers); // close the tertiary connection that was ignored - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(2usize)).unwrap(); + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(2usize)); assert!(emit_event.is_none()); // verify that the state remains unchanged @@ -3050,17 +2503,18 @@ mod tests { let context = peers.get(&peer).unwrap(); match &context.state { - PeerState::Connected { .. } => {} + PeerState::Connected { + secondary: Some(SecondaryOrDialing::Secondary(secondary_connection)), + .. + } => { + assert_eq!(secondary_connection.address, address2); + assert_eq!( + context.addresses.addresses.get(&address2).unwrap().score(), + scores::CONNECTION_ESTABLISHED + ); + } state => panic!("invalid state: {state:?}"), } - assert_eq!( - context.secondary_connection.as_ref().unwrap().address(), - &address2 - ); - assert_eq!( - context.addresses.addresses.get(&address2).unwrap().score(), - scores::CONNECTION_ESTABLISHED - ); drop(peers); } @@ -3083,27 +2537,6 @@ mod tests { manager.on_dial_failure(ConnectionId::random()).unwrap(); } - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn dial_failure_for_unknow_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ConnectionLimitsConfig::default(), - ); - let connection_id = ConnectionId::random(); - let peer = PeerId::random(); - manager.pending_connections.insert(connection_id, peer); - manager.on_dial_failure(connection_id).unwrap(); - } - #[tokio::test] #[cfg(debug_assertions)] #[should_panic] @@ -3275,16 +2708,16 @@ mod tests { peer, PeerContext { state: PeerState::Connected { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - dial_record: None, + connection_id: ConnectionId::from(0usize), + }, + secondary: None, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -3323,15 +2756,15 @@ mod tests { peer, PeerContext { state: PeerState::Dialing { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() + dial_record: ConnectionRecord { + address: Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), + connection_id: ConnectionId::from(0usize), + }, }, - secondary_connection: None, + addresses: AddressStore::from_iter( vec![Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) @@ -3354,10 +2787,10 @@ mod tests { let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { + PeerState::Dialing { dial_record } => { assert_eq!( - record.address(), - &Multiaddr::empty() + dial_record.address, + Multiaddr::empty() .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) .with(Protocol::Tcp(8888)) .with(Protocol::P2p(Multihash::from(peer))) @@ -3386,17 +2819,16 @@ mod tests { peer, PeerContext { state: PeerState::Disconnected { - dial_record: Some( - AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - ), + dial_record: Some(ConnectionRecord::new( + peer, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ConnectionId::from(0), + )), }, - secondary_connection: None, + addresses: AddressStore::new(), }, ); @@ -3602,19 +3034,13 @@ mod tests { let peers = manager.peers.read(); match peers.get(&peer).unwrap() { PeerContext { - state: - PeerState::Connected { - record, - dial_record, - }, - secondary_connection, + state: PeerState::Connected { record, secondary }, addresses, } => { - assert!(!addresses.addresses.contains_key(record.address())); - assert!(dial_record.is_none()); - assert!(secondary_connection.is_none()); - assert_eq!(record.address(), &dial_address); - assert_eq!(record.connection_id(), &Some(connection_id)); + assert!(!addresses.addresses.contains_key(&record.address)); + assert!(secondary.is_none()); + assert_eq!(record.address, dial_address); + assert_eq!(record.connection_id, connection_id); } state => panic!("invalid peer state: {state:?}"), } @@ -3694,16 +3120,15 @@ mod tests { let peers = manager.peers.read(); match peers.get(&peer).unwrap() { PeerContext { - state: PeerState::Connected { record, .. }, - secondary_connection, + state: PeerState::Connected { record, secondary }, addresses, } => { // Saved from the dial attempt. assert_eq!(addresses.addresses.get(&dial_address).unwrap().score(), 0); - assert!(secondary_connection.is_none()); - assert_eq!(record.address(), &dial_address); - assert_eq!(record.connection_id(), &Some(connection_id)); + assert!(secondary.is_none()); + assert_eq!(record.address, dial_address); + assert_eq!(record.connection_id, connection_id); } state => panic!("invalid peer state: {state:?}"), } @@ -3773,7 +3198,7 @@ mod tests { assert_eq!(result, ConnectionEstablishedResult::Reject); // Close one connection. - let _ = manager.on_connection_closed(peer, first_connection_id).unwrap(); + assert!(manager.on_connection_closed(peer, first_connection_id).is_none()); // The second peer can establish 2 inbounds now. let result = manager @@ -3864,7 +3289,7 @@ mod tests { )); // Close one connection. - let _ = manager.on_connection_closed(peer, first_connection_id).unwrap(); + assert!(manager.on_connection_closed(peer, first_connection_id).is_some()); // We can now dial again. manager.dial_address(first_addr.clone()).await.unwrap(); @@ -3891,7 +3316,7 @@ mod tests { // Random peer ID. let peer = PeerId::random(); - let (first_addr, first_connection_id) = setup_dial_addr(peer, 0); + let (first_addr, _first_connection_id) = setup_dial_addr(peer, 0); let second_connection_id = ConnectionId::from(1); let different_connection_id = ConnectionId::from(2); @@ -3900,18 +3325,16 @@ mod tests { let mut peers = manager.peers.write(); let state = PeerState::Connected { - record: AddressRecord::new(&peer, first_addr.clone(), 0, Some(first_connection_id)), - dial_record: Some(AddressRecord::new( - &peer, + record: ConnectionRecord::new(peer, first_addr.clone(), ConnectionId::from(0)), + secondary: Some(SecondaryOrDialing::Dialing(ConnectionRecord::new( + peer, first_addr.clone(), - 0, - Some(second_connection_id), - )), + second_connection_id, + ))), }; let peer_context = PeerContext { state, - secondary_connection: None, addresses: AddressStore::from_iter(vec![first_addr.clone()].into_iter()), }; @@ -3970,8 +3393,8 @@ mod tests { let peers = manager.peers.read(); let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &first_addr); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, first_addr); } state => panic!("invalid state: {state:?}"), } @@ -3991,21 +3414,20 @@ mod tests { match &peer_context.state { PeerState::Connected { record, - dial_record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), } => { - assert_eq!(record.address(), &remote_addr); - assert_eq!(record.connection_id(), &Some(remote_connection_id)); + assert_eq!(record.address, remote_addr); + assert_eq!(record.connection_id, remote_connection_id); - let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)) + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id) } state => panic!("invalid state: {state:?}"), } } // Step 3. The peer disconnects while we have a dialing in flight. - let event = manager.on_connection_closed(peer, remote_connection_id).unwrap().unwrap(); + let event = manager.on_connection_closed(peer, remote_connection_id).unwrap(); match event { TransportEvent::ConnectionClosed { peer: event_peer, @@ -4022,8 +3444,8 @@ mod tests { match &peer_context.state { PeerState::Disconnected { dial_record } => { let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id); } state => panic!("invalid state: {state:?}"), } @@ -4038,8 +3460,8 @@ mod tests { match &peer_context.state { PeerState::Disconnected { dial_record } => { let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id); } state => panic!("invalid state: {state:?}"), } @@ -4059,15 +3481,14 @@ mod tests { match &peer_context.state { PeerState::Connected { record, - dial_record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), } => { - assert_eq!(record.address(), &remote_addr); - assert_eq!(record.connection_id(), &Some(remote_connection_id)); + assert_eq!(record.address, remote_addr); + assert_eq!(record.connection_id, remote_connection_id); // We have not overwritten the first dial record in step 4. - let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &first_addr); - assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); + assert_eq!(dial_record.address, first_addr); + assert_eq!(dial_record.connection_id, first_connection_id); } state => panic!("invalid state: {state:?}"), } @@ -4122,8 +3543,8 @@ mod tests { let peers = manager.peers.read(); let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } @@ -4150,8 +3571,8 @@ mod tests { let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { // Must still be dialing the first address. - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); + PeerState::Dialing { dial_record } => { + assert_eq!(dial_record.address, dial_address); } state => panic!("invalid state: {state:?}"), } diff --git a/src/transport/manager/peer_state.rs b/src/transport/manager/peer_state.rs new file mode 100644 index 00000000..fcdb5330 --- /dev/null +++ b/src/transport/manager/peer_state.rs @@ -0,0 +1,948 @@ +// Copyright 2024 litep2p developers +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! Peer state management. + +use crate::{ + transport::{ + manager::{SupportedTransport, LOG_TARGET}, + Endpoint, + }, + types::ConnectionId, + PeerId, +}; + +use multiaddr::{Multiaddr, Protocol}; + +use std::collections::HashSet; + +/// The peer state that tracks connections and dialing attempts. +/// +/// # State Machine +/// +/// ## [`PeerState::Disconnected`] +/// +/// Initially, the peer is in the [`PeerState::Disconnected`] state without a +/// [`PeerState::Disconnected::dial_record`]. This means the peer is fully disconnected. +/// +/// Next states: +/// - [`PeerState::Disconnected`] -> [`PeerState::Dialing`] (via [`PeerState::dial_single_address`]) +/// - [`PeerState::Disconnected`] -> [`PeerState::Opening`] (via [`PeerState::dial_addresses`]) +/// +/// ## [`PeerState::Dialing`] +/// +/// The peer can transition to the [`PeerState::Dialing`] state when a dialing attempt is +/// initiated. This only happens when the peer is dialed on a single address via +/// [`PeerState::dial_single_address`], or when a socket connection established +/// in [`PeerState::Opening`] is upgraded to noise and yamux negotiation phase. +/// +/// The dialing state implies the peer is reached on the socket address provided, as well as +/// negotiating noise and yamux protocols. +/// +/// Next states: +/// - [`PeerState::Dialing`] -> [`PeerState::Connected`] (via +/// [`PeerState::on_connection_established`]) +/// - [`PeerState::Dialing`] -> [`PeerState::Disconnected`] (via [`PeerState::on_dial_failure`]) +/// +/// ## [`PeerState::Opening`] +/// +/// The peer can transition to the [`PeerState::Opening`] state when a dialing attempt is +/// initiated on multiple addresses via [`PeerState::dial_addresses`]. This takes into account +/// the parallelism factor (8 maximum) of the dialing attempts. +/// +/// The opening state holds information about which protocol is being dialed to properly report back +/// errors. +/// +/// The opening state is similar to the dial state, however the peer is only reached on a socket +/// address. The noise and yamux protocols are not negotiated yet. This state transitions to +/// [`PeerState::Dialing`] for the final part of the negotiation. Please note that it would be +/// wasteful to negotiate the noise and yamux protocols on all addresses, since only one +/// connection is kept around. +/// +/// Next states: +/// - [`PeerState::Opening`] -> [`PeerState::Dialing`] (via transport manager +/// `on_connection_opened`) +/// - [`PeerState::Opening`] -> [`PeerState::Disconnected`] (via transport manager +/// `on_connection_opened` if negotiation cannot be started or via `on_open_failure`) +/// - [`PeerState::Opening`] -> [`PeerState::Connected`] (via transport manager +/// `on_connection_established` when an incoming connection is accepted) +#[derive(Debug, Clone, PartialEq)] +pub enum PeerState { + /// `Litep2p` is connected to peer. + Connected { + /// The established record of the connection. + record: ConnectionRecord, + + /// Secondary record, this can either be a dial record or an established connection. + /// + /// While the local node was dialing a remote peer, the remote peer might've dialed + /// the local node and connection was established successfully. The original dial + /// address is stored for processing later when the dial attempt concludes as + /// either successful/failed. + secondary: Option, + }, + + /// Connection to peer is opening over one or more addresses. + Opening { + /// Address records used for dialing. + addresses: HashSet, + + /// Connection ID. + connection_id: ConnectionId, + + /// Active transports. + transports: HashSet, + }, + + /// Peer is being dialed. + Dialing { + /// Address record. + dial_record: ConnectionRecord, + }, + + /// `Litep2p` is not connected to peer. + Disconnected { + /// Dial address, if it exists. + /// + /// While the local node was dialing a remote peer, the remote peer might've dialed + /// the local node and connection was established successfully. The connection might've + /// been closed before the dial concluded which means that + /// [`crate::transport::manager::TransportManager`] must be prepared to handle the dial + /// failure even after the connection has been closed. + dial_record: Option, + }, +} + +/// The state of the secondary connection. +#[derive(Debug, Clone, PartialEq)] +pub enum SecondaryOrDialing { + /// The secondary connection is established. + Secondary(ConnectionRecord), + /// The primary connection is established, but the secondary connection is still dialing. + Dialing(ConnectionRecord), +} + +/// Result of initiating a dial. +#[derive(Debug, Clone, PartialEq)] +pub enum StateDialResult { + /// The peer is already connected. + AlreadyConnected, + /// The dialing state is already in progress. + DialingInProgress, + /// The peer is disconnected, start dialing. + Ok, +} + +impl PeerState { + /// Check if the peer can be dialed. + pub fn can_dial(&self) -> StateDialResult { + match self { + // The peer is already connected, no need to dial again. + Self::Connected { .. } => return StateDialResult::AlreadyConnected, + // The dialing state is already in progress, an event will be emitted later. + Self::Dialing { .. } + | Self::Opening { .. } + | Self::Disconnected { + dial_record: Some(_), + } => { + return StateDialResult::DialingInProgress; + } + + Self::Disconnected { dial_record: None } => StateDialResult::Ok, + } + } + + /// Dial the peer on a single address. + pub fn dial_single_address(&mut self, dial_record: ConnectionRecord) -> StateDialResult { + match self.can_dial() { + StateDialResult::Ok => { + *self = PeerState::Dialing { dial_record }; + StateDialResult::Ok + } + reason => reason, + } + } + + /// Dial the peer on multiple addresses. + pub fn dial_addresses( + &mut self, + connection_id: ConnectionId, + addresses: HashSet, + transports: HashSet, + ) -> StateDialResult { + match self.can_dial() { + StateDialResult::Ok => { + *self = PeerState::Opening { + addresses, + connection_id, + transports, + }; + StateDialResult::Ok + } + reason => reason, + } + } + + /// Handle dial failure. + /// + /// # Transitions + /// + /// - [`PeerState::Dialing`] (with record) -> [`PeerState::Disconnected`] + /// - [`PeerState::Connected`] (with dial record) -> [`PeerState::Connected`] + /// - [`PeerState::Disconnected`] (with dial record) -> [`PeerState::Disconnected`] + /// + /// Returns `true` if the connection was handled. + pub fn on_dial_failure(&mut self, connection_id: ConnectionId) -> bool { + match self { + // Clear the dial record if the connection ID matches. + Self::Dialing { dial_record } => + if dial_record.connection_id == connection_id { + *self = Self::Disconnected { dial_record: None }; + return true; + }, + + Self::Connected { + record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), + } => + if dial_record.connection_id == connection_id { + *self = Self::Connected { + record: record.clone(), + secondary: None, + }; + return true; + }, + + Self::Disconnected { + dial_record: Some(dial_record), + } => + if dial_record.connection_id == connection_id { + *self = Self::Disconnected { dial_record: None }; + return true; + }, + + Self::Opening { .. } | Self::Connected { .. } | Self::Disconnected { .. } => + return false, + }; + + false + } + + /// Returns `true` if the connection should be accepted by the transport manager. + pub fn on_connection_established(&mut self, connection: ConnectionRecord) -> bool { + match self { + // Transform the dial record into a secondary connection. + Self::Connected { + record, + secondary: Some(SecondaryOrDialing::Dialing(dial_record)), + } => + if dial_record.connection_id == connection.connection_id { + *self = Self::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(connection)), + }; + + return true; + }, + + // There's place for a secondary connection. + Self::Connected { + record, + secondary: None, + } => { + *self = Self::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(connection)), + }; + + return true; + } + + // Convert the dial record into a primary connection or preserve it. + Self::Dialing { dial_record } + | Self::Disconnected { + dial_record: Some(dial_record), + } => + if dial_record.connection_id == connection.connection_id { + *self = Self::Connected { + record: connection.clone(), + secondary: None, + }; + return true; + } else { + *self = Self::Connected { + record: connection, + secondary: Some(SecondaryOrDialing::Dialing(dial_record.clone())), + }; + return true; + }, + + Self::Disconnected { dial_record: None } => { + *self = Self::Connected { + record: connection, + secondary: None, + }; + + return true; + } + + // Accept the incoming connection. + Self::Opening { + addresses, + connection_id, + .. + } => { + tracing::trace!( + target: LOG_TARGET, + ?connection, + opening_addresses = ?addresses, + opening_connection_id = ?connection_id, + "Connection established while opening" + ); + + *self = Self::Connected { + record: connection, + secondary: None, + }; + + return true; + } + + _ => {} + }; + + return false; + } + + /// Returns `true` if the connection was closed. + pub fn on_connection_closed(&mut self, connection_id: ConnectionId) -> bool { + match self { + Self::Connected { record, secondary } => { + // Primary connection closed. + if record.connection_id == connection_id { + match secondary { + // Promote secondary connection to primary. + Some(SecondaryOrDialing::Secondary(secondary)) => { + *self = Self::Connected { + record: secondary.clone(), + secondary: None, + }; + } + // Preserve the dial record. + Some(SecondaryOrDialing::Dialing(dial_record)) => { + *self = Self::Disconnected { + dial_record: Some(dial_record.clone()), + }; + + return true; + } + None => { + *self = Self::Disconnected { dial_record: None }; + + return true; + } + }; + + return false; + } + + match secondary { + // Secondary connection closed. + Some(SecondaryOrDialing::Secondary(secondary)) + if secondary.connection_id == connection_id => + { + *self = Self::Connected { + record: record.clone(), + secondary: None, + }; + } + _ => (), + } + } + _ => (), + } + + false + } + + /// Returns `true` if the last transport failed to open. + pub fn on_open_failure(&mut self, transport: SupportedTransport) -> bool { + match self { + Self::Opening { transports, .. } => { + transports.remove(&transport); + + if transports.is_empty() { + *self = Self::Disconnected { dial_record: None }; + return true; + } + + return false; + } + _ => false, + } + } + + /// Returns `true` if the connection was opened. + pub fn on_connection_opened(&mut self, record: ConnectionRecord) -> bool { + match self { + Self::Opening { + addresses, + connection_id, + .. + } => { + if record.connection_id != *connection_id || !addresses.contains(&record.address) { + tracing::warn!( + target: LOG_TARGET, + ?record, + ?addresses, + ?connection_id, + "Connection opened for unknown address or connection ID", + ); + } + + *self = Self::Dialing { + dial_record: record.clone(), + }; + + true + } + _ => false, + } + } +} + +/// The connection record keeps track of the connection ID and the address of the connection. +/// +/// The connection ID is used to track the connection in the transport layer. +/// While the address is used to keep a healthy view of the network for dialing purposes. +/// +/// # Note +/// +/// The structure is used to keep track of: +/// +/// - dialing state for outbound connections. +/// - established outbound connections via [`PeerState::Connected`]. +/// - established inbound connections via `PeerContext::secondary_connection`. +#[derive(Debug, Clone, Hash, PartialEq)] +pub struct ConnectionRecord { + /// Address of the connection. + /// + /// The address must contain the peer ID extension `/p2p/`. + pub address: Multiaddr, + + /// Connection ID resulted from dialing. + pub connection_id: ConnectionId, +} + +impl ConnectionRecord { + /// Construct a new connection record. + pub fn new(peer: PeerId, address: Multiaddr, connection_id: ConnectionId) -> Self { + Self { + address: Self::ensure_peer_id(peer, address), + connection_id, + } + } + + /// Create a new connection record from the peer ID and the endpoint. + pub fn from_endpoint(peer: PeerId, endpoint: &Endpoint) -> Self { + Self { + address: Self::ensure_peer_id(peer, endpoint.address().clone()), + connection_id: endpoint.connection_id(), + } + } + + /// Ensures the peer ID is present in the address. + fn ensure_peer_id(peer: PeerId, mut address: Multiaddr) -> Multiaddr { + if let Some(Protocol::P2p(multihash)) = address.iter().last() { + if multihash != *peer.as_ref() { + tracing::warn!( + target: LOG_TARGET, + ?address, + ?peer, + "Peer ID mismatch in address", + ); + + address.pop(); + address.push(Protocol::P2p(*peer.as_ref())); + } + + address + } else { + address.with(Protocol::P2p(*peer.as_ref())) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn state_can_dial() { + let state = PeerState::Disconnected { dial_record: None }; + assert_eq!(state.can_dial(), StateDialResult::Ok); + + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + assert_eq!(state.can_dial(), StateDialResult::DialingInProgress); + + let state = PeerState::Dialing { + dial_record: record.clone(), + }; + assert_eq!(state.can_dial(), StateDialResult::DialingInProgress); + + let state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: Default::default(), + }; + assert_eq!(state.can_dial(), StateDialResult::DialingInProgress); + + let state = PeerState::Connected { + record, + secondary: None, + }; + assert_eq!(state.can_dial(), StateDialResult::AlreadyConnected); + } + + #[test] + fn state_dial_single_address() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let mut state = PeerState::Disconnected { dial_record: None }; + assert_eq!( + state.dial_single_address(record.clone()), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Dialing { + dial_record: record + } + ); + } + + #[test] + fn state_dial_addresses() { + let mut state = PeerState::Disconnected { dial_record: None }; + assert_eq!( + state.dial_addresses( + ConnectionId::from(0), + Default::default(), + Default::default() + ), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: Default::default() + } + ); + } + + #[test] + fn check_dial_failure() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + // Check from the dialing state. + { + let mut state = PeerState::Dialing { + dial_record: record.clone(), + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + + // Check from the connected state without dialing state. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: None, + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + // The connection ID is checked against dialing records, not established connections. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, previous_state); + } + + // Check from the connected state with dialing state. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())), + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + // Dial record is cleared. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Check from the disconnected state. + { + let mut state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + let previous_state = state.clone(); + // Check with different connection ID. + state.on_dial_failure(ConnectionId::from(1)); + assert_eq!(state, previous_state); + + // Check with the same connection ID. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + } + + #[test] + fn check_connection_established() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + let second_record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(1), + ); + + // Check from the connected state without secondary connection. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: None, + }; + // Secondary is established. + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(record.clone())), + } + ); + } + + // Check from the connected state with secondary dialing connection. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())), + }; + // Promote the secondary connection. + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(record.clone())), + } + ); + } + + // Check from the connected state with secondary established connection. + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(record.clone())), + }; + // No state to advance. + assert!(!state.on_connection_established(record.clone())); + } + + // Opening state is completely wiped out. + { + let mut state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: Default::default(), + }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Disconnected state with dial record. + { + let mut state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Disconnected with different dial record. + { + let mut state = PeerState::Disconnected { + dial_record: Some(record.clone()), + }; + assert!(state.on_connection_established(second_record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: second_record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())) + } + ); + } + + // Disconnected without dial record. + { + let mut state = PeerState::Disconnected { dial_record: None }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + + // Dialing with different dial record. + { + let mut state = PeerState::Dialing { + dial_record: record.clone(), + }; + assert!(state.on_connection_established(second_record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: second_record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(record.clone())) + } + ); + } + + // Dialing with the same dial record. + { + let mut state = PeerState::Dialing { + dial_record: record.clone(), + }; + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None, + } + ); + } + } + + #[test] + fn check_connection_closed() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + let second_record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(1), + ); + + // Primary is closed + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: None, + }; + assert!(state.on_connection_closed(ConnectionId::from(0))); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + + // Primary is closed with secondary promoted + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Secondary(second_record.clone())), + }; + // Peer is still connected. + assert!(!state.on_connection_closed(ConnectionId::from(0))); + assert_eq!( + state, + PeerState::Connected { + record: second_record.clone(), + secondary: None, + } + ); + } + + // Primary is closed with secondary dial record + { + let mut state = PeerState::Connected { + record: record.clone(), + secondary: Some(SecondaryOrDialing::Dialing(second_record.clone())), + }; + assert!(state.on_connection_closed(ConnectionId::from(0))); + assert_eq!( + state, + PeerState::Disconnected { + dial_record: Some(second_record.clone()) + } + ); + } + } + + #[test] + fn check_open_failure() { + let mut state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: [SupportedTransport::Tcp].into_iter().collect(), + }; + + // This is the last protocol + assert!(state.on_open_failure(SupportedTransport::Tcp)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + } + + #[test] + fn check_open_connection() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let mut state = PeerState::Opening { + addresses: Default::default(), + connection_id: ConnectionId::from(0), + transports: [SupportedTransport::Tcp].into_iter().collect(), + }; + + assert!(state.on_connection_opened(record.clone())); + } + + #[test] + fn check_full_lifecycle() { + let record = ConnectionRecord::new( + PeerId::random(), + "/ip4/1.1.1.1/tcp/80".parse().unwrap(), + ConnectionId::from(0), + ); + + let mut state = PeerState::Disconnected { dial_record: None }; + // Dialing. + assert_eq!( + state.dial_single_address(record.clone()), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Dialing { + dial_record: record.clone() + } + ); + + // Dialing failed. + state.on_dial_failure(ConnectionId::from(0)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + + // Opening. + assert_eq!( + state.dial_addresses( + ConnectionId::from(0), + Default::default(), + Default::default() + ), + StateDialResult::Ok + ); + + // Open failure. + assert!(state.on_open_failure(SupportedTransport::Tcp)); + assert_eq!(state, PeerState::Disconnected { dial_record: None }); + + // Dial again. + assert_eq!( + state.dial_single_address(record.clone()), + StateDialResult::Ok + ); + assert_eq!( + state, + PeerState::Dialing { + dial_record: record.clone() + } + ); + + // Successful dial. + assert!(state.on_connection_established(record.clone())); + assert_eq!( + state, + PeerState::Connected { + record: record.clone(), + secondary: None + } + ); + } +} diff --git a/src/transport/manager/types.rs b/src/transport/manager/types.rs index 2a853606..15eb2c50 100644 --- a/src/transport/manager/types.rs +++ b/src/transport/manager/types.rs @@ -18,14 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ - transport::manager::address::{AddressRecord, AddressStore}, - types::ConnectionId, -}; - -use multiaddr::Multiaddr; - -use std::collections::{HashMap, HashSet}; +use crate::transport::manager::{address::AddressStore, peer_state::PeerState}; /// Supported protocols. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] @@ -46,63 +39,12 @@ pub enum SupportedTransport { WebSocket, } -/// Peer state. -#[derive(Debug)] -pub enum PeerState { - /// `Litep2p` is connected to peer. - Connected { - /// Address record. - record: AddressRecord, - - /// Dial address, if it exists. - /// - /// While the local node was dialing a remote peer, the remote peer might've dialed - /// the local node and connection was established successfully. This dial address - /// is stored for processing later when the dial attempt concluded as either - /// successful/failed. - dial_record: Option, - }, - - /// Connection to peer is opening over one or more addresses. - Opening { - /// Address records used for dialing. - records: HashMap, - - /// Connection ID. - connection_id: ConnectionId, - - /// Active transports. - transports: HashSet, - }, - - /// Peer is being dialed. - Dialing { - /// Address record. - record: AddressRecord, - }, - - /// `Litep2p` is not connected to peer. - Disconnected { - /// Dial address, if it exists. - /// - /// While the local node was dialing a remote peer, the remote peer might've dialed - /// the local node and connection was established successfully. The connection might've - /// been closed before the dial concluded which means that - /// [`crate::transport::manager::TransportManager`] must be prepared to handle the dial - /// failure even after the connection has been closed. - dial_record: Option, - }, -} - /// Peer context. #[derive(Debug)] pub struct PeerContext { /// Peer state. pub state: PeerState, - /// Secondary connection, if it's open. - pub secondary_connection: Option, - /// Known addresses of peer. pub addresses: AddressStore, } @@ -111,7 +53,6 @@ impl Default for PeerContext { fn default() -> Self { Self { state: PeerState::Disconnected { dial_record: None }, - secondary_connection: None, addresses: AddressStore::new(), } }