Skip to content

Commit

Permalink
Make quinn-proto::{Connection, Endpoint} deterministic
Browse files Browse the repository at this point in the history
`quinn-proto::{Connection, Endpoint}` structs call
`StdRng::from_entropy()` in their implementation. It makes it hard to
reproduce errors in an otherwise deterministic test case.

This PR addresses the issue by adding a `StdRng` argument to the
respective `new` functions. For other functions that may create an
endpoint/connection, an argument of type `impl FnOnce() -> StdRng` is
added. This avoids eager calls to generate fresh entropy. e.g. the
`Endpoint::handle` function is frequently called but we'd only need the
rng when handling a new incoming connection.

Fixes lint errors.
  • Loading branch information
michael-yxchen committed Oct 18, 2023
1 parent e1e1e6e commit 3748e1e
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 26 deletions.
4 changes: 2 additions & 2 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{

use bytes::{Bytes, BytesMut};
use frame::StreamMetaVec;
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand::{rngs::StdRng, Rng};
use thiserror::Error;
use tracing::{debug, error, trace, trace_span, warn};

Expand Down Expand Up @@ -252,6 +252,7 @@ impl Connection {
now: Instant,
version: u32,
allow_mtud: bool,
mut rng: StdRng,
) -> Self {
let side = if server_config.is_some() {
Side::Server
Expand All @@ -267,7 +268,6 @@ impl Connection {
expected_token: Bytes::new(),
client_hello: None,
});
let mut rng = StdRng::from_entropy();
let path_validated = server_config.as_ref().map_or(true, |c| c.use_retry);
let mut this = Self {
endpoint_config,
Expand Down
14 changes: 11 additions & 3 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
};

use bytes::{BufMut, Bytes, BytesMut};
use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};
use rand::{rngs::StdRng, Rng, RngCore};
use rustc_hash::FxHashMap;
use slab::Slab;
use thiserror::Error;
Expand Down Expand Up @@ -57,9 +57,10 @@ impl Endpoint {
config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
allow_mtud: bool,
rng: StdRng,
) -> Self {
Self {
rng: StdRng::from_entropy(),
rng,
index: ConnectionIndex::default(),
connections: Slab::new(),
local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
Expand Down Expand Up @@ -126,6 +127,7 @@ impl Endpoint {
local_ip: Option<IpAddr>,
ecn: Option<EcnCodepoint>,
data: BytesMut,
rng: impl FnOnce() -> StdRng,
) -> Option<DatagramEvent> {
let datagram_len = data.len();
let (first_decode, remaining) = match PartialDecode::new(
Expand Down Expand Up @@ -232,7 +234,7 @@ impl Endpoint {
};
return match first_decode.finish(Some(&*crypto.header.remote)) {
Ok(packet) => {
self.handle_first_packet(now, addresses, ecn, packet, remaining, &crypto)
self.handle_first_packet(now, addresses, ecn, packet, remaining, &crypto, rng)
}
Err(e) => {
trace!("unable to decode initial packet: {}", e);
Expand Down Expand Up @@ -316,6 +318,7 @@ impl Endpoint {
config: ClientConfig,
remote: SocketAddr,
server_name: &str,
rng: impl FnOnce() -> StdRng,
) -> Result<(ConnectionHandle, Connection), ConnectError> {
if self.is_full() {
return Err(ConnectError::TooManyConnections);
Expand Down Expand Up @@ -357,6 +360,7 @@ impl Endpoint {
tls,
None,
config.transport,
rng,
);
Ok((ch, conn))
}
Expand Down Expand Up @@ -403,6 +407,7 @@ impl Endpoint {
mut packet: Packet,
rest: Option<BytesMut>,
crypto: &Keys,
rng: impl FnOnce() -> StdRng,
) -> Option<DatagramEvent> {
let (src_cid, dst_cid, token, packet_number, version) = match packet.header {
Header::Initial {
Expand Down Expand Up @@ -554,6 +559,7 @@ impl Endpoint {
tls,
Some(server_config),
transport_config,
rng,
);
if dst_cid.len() != 0 {
self.index.insert_initial(dst_cid, ch);
Expand Down Expand Up @@ -588,6 +594,7 @@ impl Endpoint {
tls: Box<dyn crypto::Session>,
server_config: Option<Arc<ServerConfig>>,
transport_config: Arc<TransportConfig>,
rng: impl FnOnce() -> StdRng,
) -> Connection {
let conn = Connection::new(
self.config.clone(),
Expand All @@ -603,6 +610,7 @@ impl Endpoint {
now,
version,
self.allow_mtud,
rng(),
);

let id = self.connections.insert(ConnectionMeta {
Expand Down
58 changes: 50 additions & 8 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
use assert_matches::assert_matches;
use bytes::Bytes;
use hex_literal::hex;
use rand::RngCore;
use rand::{rngs::StdRng, RngCore, SeedableRng};
use ring::hmac;
use rustls::AlertDescription;
use tracing::info;
Expand All @@ -27,7 +27,12 @@ use util::*;
fn version_negotiate_server() {
let _guard = subscribe();
let client_addr = "[::2]:7890".parse().unwrap();
let mut server = Endpoint::new(Default::default(), Some(Arc::new(server_config())), true);
let mut server = Endpoint::new(
Default::default(),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
let now = Instant::now();
let event = server.handle(
now,
Expand All @@ -36,6 +41,7 @@ fn version_negotiate_server() {
None,
// Long-header packet with reserved version number
hex!("80 0a1a2a3a 04 00000000 04 00000000 00")[..].into(),
StdRng::from_entropy,
);
let Some(DatagramEvent::Response(Transmit { contents, .. })) = event else {
panic!("expected a response");
Expand Down Expand Up @@ -63,9 +69,16 @@ fn version_negotiate_client() {
}),
None,
true,
StdRng::from_entropy(),
);
let (_, mut client_ch) = client
.connect(Instant::now(), client_config(), server_addr, "localhost")
.connect(
Instant::now(),
client_config(),
server_addr,
"localhost",
StdRng::from_entropy,
)
.unwrap();
let now = Instant::now();
let opt_event = client.handle(
Expand All @@ -79,6 +92,7 @@ fn version_negotiate_client() {
0a1a2a3a"
)[..]
.into(),
StdRng::from_entropy,
);
if let Some(DatagramEvent::ConnectionEvent(_, event)) = opt_event {
client_ch.handle_event(event);
Expand Down Expand Up @@ -178,7 +192,12 @@ fn server_stateless_reset() {
let mut pair = Pair::new(endpoint_config.clone(), server_config());
let (client_ch, _) = pair.connect();
pair.drive(); // Flush any post-handshake frames
pair.server.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true);
pair.server.endpoint = Endpoint::new(
endpoint_config,
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
// Force the server to generate the smallest possible stateless reset
pair.client.connections.get_mut(&client_ch).unwrap().ping();
info!("resetting");
Expand All @@ -203,7 +222,12 @@ fn client_stateless_reset() {

let mut pair = Pair::new(endpoint_config.clone(), server_config());
let (_, server_ch) = pair.connect();
pair.client.endpoint = Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true);
pair.client.endpoint = Endpoint::new(
endpoint_config,
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
// Send something big enough to allow room for a smaller stateless reset.
pair.server.connections.get_mut(&server_ch).unwrap().close(
pair.time,
Expand Down Expand Up @@ -1343,8 +1367,14 @@ fn cid_rotation() {
}),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
let client = Endpoint::new(
Arc::new(EndpointConfig::default()),
None,
true,
StdRng::from_entropy(),
);
let client = Endpoint::new(Arc::new(EndpointConfig::default()), None, true);

let mut pair = Pair::new_from_endpoint(client, server);
let (_, server_ch) = pair.connect();
Expand Down Expand Up @@ -1922,13 +1952,19 @@ fn big_cert_and_key() -> (rustls::Certificate, rustls::PrivateKey) {
fn malformed_token_len() {
let _guard = subscribe();
let client_addr = "[::2]:7890".parse().unwrap();
let mut server = Endpoint::new(Default::default(), Some(Arc::new(server_config())), true);
let mut server = Endpoint::new(
Default::default(),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
server.handle(
Instant::now(),
client_addr,
None,
None,
hex!("8900 0000 0101 0000 1b1b 841b 0000 0000 3f00")[..].into(),
StdRng::from_entropy,
);
}

Expand Down Expand Up @@ -2024,12 +2060,18 @@ fn migrate_detects_new_mtu_and_respects_original_peer_max_udp_payload_size() {
Arc::new(server_endpoint_config),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
let client_endpoint_config = EndpointConfig {
max_udp_payload_size: VarInt::from(client_max_udp_payload_size),
..EndpointConfig::default()
};
let client = Endpoint::new(Arc::new(client_endpoint_config), None, true);
let client = Endpoint::new(
Arc::new(client_endpoint_config),
None,
true,
StdRng::from_entropy(),
);
let mut pair = Pair::new_from_endpoint(client, server);
pair.mtu = 1300;

Expand Down
30 changes: 22 additions & 8 deletions quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::{
use assert_matches::assert_matches;
use bytes::BytesMut;
use lazy_static::lazy_static;
use rand::{rngs::StdRng, SeedableRng};
use rustls::{Certificate, KeyLogFile, PrivateKey};
use tracing::{info_span, trace};

Expand Down Expand Up @@ -49,8 +50,13 @@ impl Pair {
}

pub(super) fn new(endpoint_config: Arc<EndpointConfig>, server_config: ServerConfig) -> Self {
let server = Endpoint::new(endpoint_config.clone(), Some(Arc::new(server_config)), true);
let client = Endpoint::new(endpoint_config, None, true);
let server = Endpoint::new(
endpoint_config.clone(),
Some(Arc::new(server_config)),
true,
StdRng::from_entropy(),
);
let client = Endpoint::new(endpoint_config, None, true, StdRng::from_entropy());

Self::new_from_endpoint(client, server)
}
Expand Down Expand Up @@ -206,7 +212,13 @@ impl Pair {
let _guard = span.enter();
let (client_ch, client_conn) = self
.client
.connect(Instant::now(), config, self.server.addr, "localhost")
.connect(
Instant::now(),
config,
self.server.addr,
"localhost",
StdRng::from_entropy,
)
.unwrap();
self.client.connections.insert(client_ch, client_conn);
client_ch
Expand Down Expand Up @@ -332,7 +344,12 @@ impl TestEndpoint {

while self.inbound.front().map_or(false, |x| x.0 <= now) {
let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap();
if let Some(event) = self.endpoint.handle(recv_time, remote, None, ecn, packet) {
if let Some(event) = self
.endpoint
.handle(recv_time, remote, None, ecn, packet, || {
StdRng::from_entropy()
})
{
match event {
DatagramEvent::NewConnection(ch, conn) => {
self.connections.insert(ch, conn);
Expand All @@ -344,10 +361,7 @@ impl TestEndpoint {
self.captured_packets.extend(packet);
}

self.conn_events
.entry(ch)
.or_insert_with(VecDeque::new)
.push_back(event);
self.conn_events.entry(ch).or_default().push_back(event);
}
DatagramEvent::Response(transmit) => {
self.outbound.extend(split_transmit(transmit));
Expand Down
1 change: 1 addition & 0 deletions quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ futures-io = { version = "0.3.19", optional = true }
rustc-hash = "1.1"
pin-project-lite = "0.2"
proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.11", default-features = false }
rand = "0.8"
rustls = { version = "0.21.0", default-features = false, features = ["quic"], optional = true }
thiserror = "1.0.21"
tracing = "0.1.10"
Expand Down
2 changes: 1 addition & 1 deletion quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ impl Connection {
///
/// The dynamic type returned is determined by the configured
/// [`Session`](proto::crypto::Session). For the default `rustls` session, the return value can
/// be [`downcast`](Box::downcast) to a <code>Vec<[rustls::Certificate](rustls::Certificate)></code>
/// be [`downcast`](Box::downcast) to a <code>Vec<[rustls::Certificate]></code>
pub fn peer_identity(&self) -> Option<Box<dyn Any>> {
self.0
.state
Expand Down
19 changes: 15 additions & 4 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use pin_project_lite::pin_project;
use proto::{
self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig,
};
use rand::{rngs::StdRng, SeedableRng};
use rustc_hash::FxHashMap;
use tokio::sync::{futures::Notified, mpsc, Notify};
use tracing::{Instrument, Span};
Expand Down Expand Up @@ -110,7 +111,12 @@ impl Endpoint {
let allow_mtud = !socket.may_fragment();
let rc = EndpointRef::new(
socket,
proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new), allow_mtud),
proto::Endpoint::new(
Arc::new(config),
server_config.map(Arc::new),
allow_mtud,
StdRng::from_entropy(),
),
addr.is_ipv6(),
runtime.clone(),
);
Expand Down Expand Up @@ -187,9 +193,13 @@ impl Endpoint {
addr
};

let (ch, conn) = endpoint
.inner
.connect(Instant::now(), config, addr, server_name)?;
let (ch, conn) = endpoint.inner.connect(
Instant::now(),
config,
addr,
server_name,
StdRng::from_entropy,
)?;

let socket = endpoint.socket.clone();
Ok(endpoint
Expand Down Expand Up @@ -417,6 +427,7 @@ impl State {
meta.dst_ip,
meta.ecn.map(proto_ecn),
buf,
StdRng::from_entropy,
) {
Some(DatagramEvent::NewConnection(handle, conn)) => {
let conn = self.connections.insert(
Expand Down

0 comments on commit 3748e1e

Please sign in to comment.