Skip to content

Commit

Permalink
feature: Add timeouts for handshake functions (paradigmxyz#7295)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackG-eth authored Mar 26, 2024
1 parent 3f34db3 commit 4d798c7
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 2 deletions.
3 changes: 3 additions & 0 deletions crates/net/ecies/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ pub enum ECIESErrorImpl {
/// a message from the (partially filled) buffer.
#[error("stream closed due to not being readable")]
UnreadableStream,
// Error when data is not recieved from peer for a prolonged period.
#[error("never recieved data from remote peer")]
StreamTimeout,
}

impl From<ECIESErrorImpl> for ECIESError {
Expand Down
67 changes: 66 additions & 1 deletion crates/net/ecies/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ use std::{
io,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
time::timeout,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::{Stream, StreamExt};
use tokio_util::codec::{Decoder, Framed};
use tracing::{instrument, trace};

const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// `ECIES` stream over TCP exchanging raw bytes
#[derive(Debug)]
#[pin_project::pin_project]
Expand All @@ -40,6 +46,27 @@ where
transport: Io,
secret_key: SecretKey,
remote_id: PeerId,
) -> Result<Self, ECIESError> {
Self::connect_with_timeout(transport, secret_key, remote_id, HANDSHAKE_TIMEOUT).await
}

/// Wrapper around connect_no_timeout which enforces a timeout.
pub async fn connect_with_timeout(
transport: Io,
secret_key: SecretKey,
remote_id: PeerId,
timeout_limit: Duration,
) -> Result<Self, ECIESError> {
timeout(timeout_limit, Self::connect_without_timeout(transport, secret_key, remote_id))
.await
.map_err(|_| ECIESError::from(ECIESErrorImpl::StreamTimeout))?
}

/// Connect to an `ECIES` server with no timeout.
pub async fn connect_without_timeout(
transport: Io,
secret_key: SecretKey,
remote_id: PeerId,
) -> Result<Self, ECIESError> {
let ecies = ECIESCodec::new_client(secret_key, remote_id)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "invalid handshake"))?;
Expand Down Expand Up @@ -180,4 +207,42 @@ mod tests {
// make sure the server receives the message and asserts before ending the test
handle.await.unwrap();
}

#[tokio::test]
async fn connection_should_timeout() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let server_key = SecretKey::new(&mut rand::thread_rng());

let _handle = tokio::spawn(async move {
// Delay accepting the connection for longer than the client's timeout period
tokio::time::sleep(Duration::from_secs(11)).await;
let (incoming, _) = listener.accept().await.unwrap();
let mut stream = ECIESStream::incoming(incoming, server_key).await.unwrap();

// use the stream to get the next message
let message = stream.next().await.unwrap().unwrap();
assert_eq!(message, Bytes::from("hello"));
});

// create the server pubkey
let server_id = pk2id(&server_key.public_key(SECP256K1));

let client_key = SecretKey::new(&mut rand::thread_rng());
let outgoing = TcpStream::connect(addr).await.unwrap();

// Attempt to connect, expecting a timeout due to the server's delayed response
let connect_result = ECIESStream::connect_with_timeout(
outgoing,
client_key,
server_id,
Duration::from_secs(1),
)
.await;

// Assert that a timeout error occurred
assert!(
matches!(connect_result, Err(e) if e.to_string() == ECIESErrorImpl::StreamTimeout.to_string())
);
}
}
3 changes: 3 additions & 0 deletions crates/net/eth-wire/src/errors/eth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ pub enum EthStreamError {
/// The number of transaction sizes.
sizes_len: usize,
},
/// Error when data is not recieved from peer for a prolonged period.
#[error("never recieved data from remote peer")]
StreamTimeout,
}

// === impl EthStreamError ===
Expand Down
75 changes: 75 additions & 0 deletions crates/net/eth-wire/src/ethstream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
errors::{EthHandshakeError, EthStreamError},
message::{EthBroadcastMessage, ProtocolBroadcastMessage},
p2pstream::HANDSHAKE_TIMEOUT,
types::{EthMessage, ProtocolMessage, Status},
CanDisconnect, DisconnectReason, EthVersion,
};
Expand All @@ -13,7 +14,9 @@ use reth_primitives::{
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::timeout;
use tokio_stream::Stream;
use tracing::{debug, trace};

Expand Down Expand Up @@ -51,6 +54,27 @@ where
/// handshake is completed successfully. This also returns the `Status` message sent by the
/// remote peer.
pub async fn handshake(
self,
status: Status,
fork_filter: ForkFilter,
) -> Result<(EthStream<S>, Status), EthStreamError> {
self.handshake_with_timeout(status, fork_filter, HANDSHAKE_TIMEOUT).await
}

/// Wrapper around handshake which enforces a timeout.
pub async fn handshake_with_timeout(
self,
status: Status,
fork_filter: ForkFilter,
timeout_limit: Duration,
) -> Result<(EthStream<S>, Status), EthStreamError> {
timeout(timeout_limit, Self::handshake_without_timeout(self, status, fork_filter))
.await
.map_err(|_| EthStreamError::StreamTimeout)?
}

/// Handshake with no timeout
pub async fn handshake_without_timeout(
mut self,
status: Status,
fork_filter: ForkFilter,
Expand Down Expand Up @@ -321,6 +345,8 @@ where

#[cfg(test)]
mod tests {
use std::time::Duration;

use super::UnauthedEthStream;
use crate::{
errors::{EthHandshakeError, EthStreamError},
Expand Down Expand Up @@ -642,4 +668,53 @@ mod tests {
// make sure the server receives the message and asserts before ending the test
handle.await.unwrap();
}

#[tokio::test]
async fn handshake_should_timeout() {
let genesis = B256::random();
let fork_filter = ForkFilter::new(Head::default(), genesis, 0, Vec::new());

let status = Status {
version: EthVersion::Eth67 as u8,
chain: NamedChain::Mainnet.into(),
total_difficulty: U256::ZERO,
blockhash: B256::random(),
genesis,
// Pass the current fork id.
forkid: fork_filter.current(),
};

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();

let status_clone = status;
let fork_filter_clone = fork_filter.clone();
let _handle = tokio::spawn(async move {
// Delay accepting the connection for longer than the client's timeout period
tokio::time::sleep(Duration::from_secs(11)).await;
// roughly based off of the design of tokio::net::TcpListener
let (incoming, _) = listener.accept().await.unwrap();
let stream = PassthroughCodec::default().framed(incoming);
let (_, their_status) = UnauthedEthStream::new(stream)
.handshake(status_clone, fork_filter_clone)
.await
.unwrap();

// just make sure it equals our status (our status is a clone of their status)
assert_eq!(their_status, status_clone);
});

let outgoing = TcpStream::connect(local_addr).await.unwrap();
let sink = PassthroughCodec::default().framed(outgoing);

// try to connect
let handshake_result = UnauthedEthStream::new(sink)
.handshake_with_timeout(status, fork_filter, Duration::from_secs(1))
.await;

// Assert that a timeout error occurred
assert!(
matches!(handshake_result, Err(e) if e.to_string() == EthStreamError::StreamTimeout.to_string())
);
}
}
2 changes: 1 addition & 1 deletion crates/net/eth-wire/src/p2pstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const MAX_P2P_MESSAGE_ID: u8 = P2PMessageID::Pong as u8;

/// [`HANDSHAKE_TIMEOUT`] determines the amount of time to wait before determining that a `p2p`
/// handshake has timed out.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
pub(crate) const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// [`PING_TIMEOUT`] determines the amount of time to wait before determining that a `p2p` ping has
/// timed out.
Expand Down

0 comments on commit 4d798c7

Please sign in to comment.