Skip to content

Commit

Permalink
Seed Connection rng with Endpoint rng
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-yxchen committed Oct 19, 2023
1 parent 3748e1e commit 166752b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 79 deletions.
5 changes: 3 additions & 2 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{

use bytes::{Bytes, BytesMut};
use frame::StreamMetaVec;
use rand::{rngs::StdRng, Rng};
use rand::{rngs::StdRng, Rng, SeedableRng};
use thiserror::Error;
use tracing::{debug, error, trace, trace_span, warn};

Expand Down Expand Up @@ -252,7 +252,7 @@ impl Connection {
now: Instant,
version: u32,
allow_mtud: bool,
mut rng: StdRng,
rng_seed: <StdRng as SeedableRng>::Seed,
) -> Self {
let side = if server_config.is_some() {
Side::Server
Expand All @@ -268,6 +268,7 @@ impl Connection {
expected_token: Bytes::new(),
client_hello: None,
});
let mut rng = StdRng::from_seed(rng_seed);
let path_validated = server_config.as_ref().map_or(true, |c| c.use_retry);
let mut this = Self {
endpoint_config,
Expand Down
18 changes: 7 additions & 11 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
};

use bytes::{BufMut, Bytes, BytesMut};
use rand::{rngs::StdRng, Rng, RngCore};
use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};
use rustc_hash::FxHashMap;
use slab::Slab;
use thiserror::Error;
Expand Down Expand Up @@ -57,10 +57,10 @@ impl Endpoint {
config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
allow_mtud: bool,
rng: StdRng,
rng_seed: Option<<StdRng as SeedableRng>::Seed>,
) -> Self {
Self {
rng,
rng: rng_seed.map_or(StdRng::from_entropy(), StdRng::from_seed),
index: ConnectionIndex::default(),
connections: Slab::new(),
local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
Expand Down Expand Up @@ -127,7 +127,6 @@ impl Endpoint {
local_ip: Option<IpAddr>,
ecn: Option<EcnCodepoint>,
data: BytesMut,
rng: impl FnOnce() -> StdRng,
) -> Option<DatagramEvent> {
let datagram_len = data.len();
let (first_decode, remaining) = match PartialDecode::new(
Expand Down Expand Up @@ -234,7 +233,7 @@ impl Endpoint {
};
return match first_decode.finish(Some(&*crypto.header.remote)) {
Ok(packet) => {
self.handle_first_packet(now, addresses, ecn, packet, remaining, &crypto, rng)
self.handle_first_packet(now, addresses, ecn, packet, remaining, &crypto)
}
Err(e) => {
trace!("unable to decode initial packet: {}", e);
Expand Down Expand Up @@ -318,7 +317,6 @@ impl Endpoint {
config: ClientConfig,
remote: SocketAddr,
server_name: &str,
rng: impl FnOnce() -> StdRng,
) -> Result<(ConnectionHandle, Connection), ConnectError> {
if self.is_full() {
return Err(ConnectError::TooManyConnections);
Expand Down Expand Up @@ -360,7 +358,6 @@ impl Endpoint {
tls,
None,
config.transport,
rng,
);
Ok((ch, conn))
}
Expand Down Expand Up @@ -407,7 +404,6 @@ impl Endpoint {
mut packet: Packet,
rest: Option<BytesMut>,
crypto: &Keys,
rng: impl FnOnce() -> StdRng,
) -> Option<DatagramEvent> {
let (src_cid, dst_cid, token, packet_number, version) = match packet.header {
Header::Initial {
Expand Down Expand Up @@ -559,7 +555,6 @@ impl Endpoint {
tls,
Some(server_config),
transport_config,
rng,
);
if dst_cid.len() != 0 {
self.index.insert_initial(dst_cid, ch);
Expand Down Expand Up @@ -594,8 +589,9 @@ impl Endpoint {
tls: Box<dyn crypto::Session>,
server_config: Option<Arc<ServerConfig>>,
transport_config: Arc<TransportConfig>,
rng: impl FnOnce() -> StdRng,
) -> Connection {
let mut rng_seed = [0; 32];
self.rng.fill_bytes(&mut rng_seed);
let conn = Connection::new(
self.config.clone(),
server_config,
Expand All @@ -610,7 +606,7 @@ impl Endpoint {
now,
version,
self.allow_mtud,
rng(),
rng_seed,
);

let id = self.connections.insert(ConnectionMeta {
Expand Down
51 changes: 12 additions & 39 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
use assert_matches::assert_matches;
use bytes::Bytes;
use hex_literal::hex;
use rand::{rngs::StdRng, RngCore, SeedableRng};
use rand::RngCore;
use ring::hmac;
use rustls::AlertDescription;
use tracing::info;
Expand All @@ -31,7 +31,7 @@ fn version_negotiate_server() {
Default::default(),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
None,
);
let now = Instant::now();
let event = server.handle(
Expand All @@ -41,7 +41,6 @@ fn version_negotiate_server() {
None,
// Long-header packet with reserved version number
hex!("80 0a1a2a3a 04 00000000 04 00000000 00")[..].into(),
StdRng::from_entropy,
);
let Some(DatagramEvent::Response(Transmit { contents, .. })) = event else {
panic!("expected a response");
Expand Down Expand Up @@ -69,16 +68,10 @@ fn version_negotiate_client() {
}),
None,
true,
StdRng::from_entropy(),
None,
);
let (_, mut client_ch) = client
.connect(
Instant::now(),
client_config(),
server_addr,
"localhost",
StdRng::from_entropy,
)
.connect(Instant::now(), client_config(), server_addr, "localhost")
.unwrap();
let now = Instant::now();
let opt_event = client.handle(
Expand All @@ -92,7 +85,6 @@ fn version_negotiate_client() {
0a1a2a3a"
)[..]
.into(),
StdRng::from_entropy,
);
if let Some(DatagramEvent::ConnectionEvent(_, event)) = opt_event {
client_ch.handle_event(event);
Expand Down Expand Up @@ -192,12 +184,8 @@ fn server_stateless_reset() {
let mut pair = Pair::new(endpoint_config.clone(), server_config());
let (client_ch, _) = pair.connect();
pair.drive(); // Flush any post-handshake frames
pair.server.endpoint = Endpoint::new(
endpoint_config,
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
pair.server.endpoint =
Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true, None);
// Force the server to generate the smallest possible stateless reset
pair.client.connections.get_mut(&client_ch).unwrap().ping();
info!("resetting");
Expand All @@ -222,12 +210,8 @@ fn client_stateless_reset() {

let mut pair = Pair::new(endpoint_config.clone(), server_config());
let (_, server_ch) = pair.connect();
pair.client.endpoint = Endpoint::new(
endpoint_config,
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
pair.client.endpoint =
Endpoint::new(endpoint_config, Some(Arc::new(server_config())), true, None);
// Send something big enough to allow room for a smaller stateless reset.
pair.server.connections.get_mut(&server_ch).unwrap().close(
pair.time,
Expand Down Expand Up @@ -1367,14 +1351,9 @@ fn cid_rotation() {
}),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
);
let client = Endpoint::new(
Arc::new(EndpointConfig::default()),
None,
true,
StdRng::from_entropy(),
);
let client = Endpoint::new(Arc::new(EndpointConfig::default()), None, true, None);

let mut pair = Pair::new_from_endpoint(client, server);
let (_, server_ch) = pair.connect();
Expand Down Expand Up @@ -1956,15 +1935,14 @@ fn malformed_token_len() {
Default::default(),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
None,
);
server.handle(
Instant::now(),
client_addr,
None,
None,
hex!("8900 0000 0101 0000 1b1b 841b 0000 0000 3f00")[..].into(),
StdRng::from_entropy,
);
}

Expand Down Expand Up @@ -2060,18 +2038,13 @@ fn migrate_detects_new_mtu_and_respects_original_peer_max_udp_payload_size() {
Arc::new(server_endpoint_config),
Some(Arc::new(server_config())),
true,
StdRng::from_entropy(),
None,
);
let client_endpoint_config = EndpointConfig {
max_udp_payload_size: VarInt::from(client_max_udp_payload_size),
..EndpointConfig::default()
};
let client = Endpoint::new(
Arc::new(client_endpoint_config),
None,
true,
StdRng::from_entropy(),
);
let client = Endpoint::new(Arc::new(client_endpoint_config), None, true, None);
let mut pair = Pair::new_from_endpoint(client, server);
pair.mtu = 1300;

Expand Down
20 changes: 4 additions & 16 deletions quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use std::{
use assert_matches::assert_matches;
use bytes::BytesMut;
use lazy_static::lazy_static;
use rand::{rngs::StdRng, SeedableRng};
use rustls::{Certificate, KeyLogFile, PrivateKey};
use tracing::{info_span, trace};

Expand Down Expand Up @@ -54,9 +53,9 @@ impl Pair {
endpoint_config.clone(),
Some(Arc::new(server_config)),
true,
StdRng::from_entropy(),
None,
);
let client = Endpoint::new(endpoint_config, None, true, StdRng::from_entropy());
let client = Endpoint::new(endpoint_config, None, true, None);

Self::new_from_endpoint(client, server)
}
Expand Down Expand Up @@ -212,13 +211,7 @@ impl Pair {
let _guard = span.enter();
let (client_ch, client_conn) = self
.client
.connect(
Instant::now(),
config,
self.server.addr,
"localhost",
StdRng::from_entropy,
)
.connect(Instant::now(), config, self.server.addr, "localhost")
.unwrap();
self.client.connections.insert(client_ch, client_conn);
client_ch
Expand Down Expand Up @@ -344,12 +337,7 @@ impl TestEndpoint {

while self.inbound.front().map_or(false, |x| x.0 <= now) {
let (recv_time, ecn, packet) = self.inbound.pop_front().unwrap();
if let Some(event) = self
.endpoint
.handle(recv_time, remote, None, ecn, packet, || {
StdRng::from_entropy()
})
{
if let Some(event) = self.endpoint.handle(recv_time, remote, None, ecn, packet) {
match event {
DatagramEvent::NewConnection(ch, conn) => {
self.connections.insert(ch, conn);
Expand Down
1 change: 0 additions & 1 deletion quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ futures-io = { version = "0.3.19", optional = true }
rustc-hash = "1.1"
pin-project-lite = "0.2"
proto = { package = "quinn-proto", path = "../quinn-proto", version = "0.11", default-features = false }
rand = "0.8"
rustls = { version = "0.21.0", default-features = false, features = ["quic"], optional = true }
thiserror = "1.0.21"
tracing = "0.1.10"
Expand Down
14 changes: 4 additions & 10 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use pin_project_lite::pin_project;
use proto::{
self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig,
};
use rand::{rngs::StdRng, SeedableRng};
use rustc_hash::FxHashMap;
use tokio::sync::{futures::Notified, mpsc, Notify};
use tracing::{Instrument, Span};
Expand Down Expand Up @@ -115,7 +114,7 @@ impl Endpoint {
Arc::new(config),
server_config.map(Arc::new),
allow_mtud,
StdRng::from_entropy(),
None,
),
addr.is_ipv6(),
runtime.clone(),
Expand Down Expand Up @@ -193,13 +192,9 @@ impl Endpoint {
addr
};

let (ch, conn) = endpoint.inner.connect(
Instant::now(),
config,
addr,
server_name,
StdRng::from_entropy,
)?;
let (ch, conn) = endpoint
.inner
.connect(Instant::now(), config, addr, server_name)?;

let socket = endpoint.socket.clone();
Ok(endpoint
Expand Down Expand Up @@ -427,7 +422,6 @@ impl State {
meta.dst_ip,
meta.ecn.map(proto_ecn),
buf,
StdRng::from_entropy,
) {
Some(DatagramEvent::NewConnection(handle, conn)) => {
let conn = self.connections.insert(
Expand Down

0 comments on commit 166752b

Please sign in to comment.