diff --git a/quinn-proto/src/address_discovery.rs b/quinn-proto/src/address_discovery.rs new file mode 100644 index 000000000..85b31b667 --- /dev/null +++ b/quinn-proto/src/address_discovery.rs @@ -0,0 +1,128 @@ +//! Address discovery types from +//! + +use crate::VarInt; + +pub(crate) const TRANSPORT_PARAMETER_CODE: u64 = 0x9f81a176; + +/// The role of each participant. +/// +/// When enabled, this is reported as a transport parameter. +#[derive(PartialEq, Eq, Clone, Copy, Debug, Default)] +pub(crate) enum Role { + /// Is able to report observer addresses to other peers, but it's not interested in receiving + /// reports about its own address. + SendOnly, + /// Is interested on reports about its own observed address, but will not report back to other + /// peers. + ReceiveOnly, + /// Will both report and receive reports of observed addresses. + Both, + /// Address discovery is disabled. + #[default] + Disabled, +} + +impl TryFrom for Role { + type Error = crate::transport_parameters::Error; + + fn try_from(value: VarInt) -> Result { + match value.0 { + 0 => Ok(Self::SendOnly), + 1 => Ok(Self::ReceiveOnly), + 2 => Ok(Self::Both), + _ => Err(crate::transport_parameters::Error::IllegalValue), + } + } +} + +impl Role { + /// Whether address discovery is disabled. + pub(crate) fn is_disabled(&self) -> bool { + matches!(self, Self::Disabled) + } + + /// Whether this peer's role allows for address reporting to other peers. + fn is_reporter(&self) -> bool { + matches!(self, Self::SendOnly | Self::Both) + } + + /// Whether this peer's role accepts observed address reports. + fn receives_reports(&self) -> bool { + matches!(self, Self::ReceiveOnly | Self::Both) + } + + /// Whether this peer should report observed addresses to the other peer. + pub(crate) fn should_report(&self, other: &Self) -> bool { + self.is_reporter() && other.receives_reports() + } + + /// Sets whether this peer should provide observed addresses to other peers. + pub(crate) fn send_reports_to_peers(&mut self, provide: bool) { + if provide { + self.enable_sending_reports_to_peers() + } else { + self.disable_sending_reports_to_peers() + } + } + + /// Enables sending reports of observed addresses to other peers. + fn enable_sending_reports_to_peers(&mut self) { + match self { + Self::SendOnly => {} // already enabled + Self::ReceiveOnly => *self = Self::Both, + Self::Both => {} // already enabled + Self::Disabled => *self = Self::SendOnly, + } + } + + /// Disables sending reports of observed addresses to other peers. + fn disable_sending_reports_to_peers(&mut self) { + match self { + Self::SendOnly => *self = Self::Disabled, + Self::ReceiveOnly => {} // already disabled + Self::Both => *self = Self::ReceiveOnly, + Self::Disabled => {} // already disabled + } + } + + /// Sets whether this peer should accept received reports of observed addresses from other + /// peers. + pub(crate) fn receive_reports_from_peers(&mut self, receive: bool) { + if receive { + self.enable_receiving_reports_from_peers() + } else { + self.disable_receiving_reports_from_peers() + } + } + + /// Enables receiving reports of observed addresses from other peers. + fn enable_receiving_reports_from_peers(&mut self) { + match self { + Self::SendOnly => *self = Self::Both, + Self::ReceiveOnly => {} // already enabled + Self::Both => {} // already enabled + Self::Disabled => *self = Self::ReceiveOnly, + } + } + + /// Disables receiving reports of observed addresses from other peers. + fn disable_receiving_reports_from_peers(&mut self) { + match self { + Self::SendOnly => {} // already disabled + Self::ReceiveOnly => *self = Self::Disabled, + Self::Both => *self = Self::SendOnly, + Self::Disabled => {} // already disabled + } + } + + /// Gives the [`VarInt`] representing this [`Role`] as a transport parameter. + pub(crate) fn as_transport_parameter(&self) -> Option { + match self { + Self::SendOnly => Some(VarInt(0)), + Self::ReceiveOnly => Some(VarInt(1)), + Self::Both => Some(VarInt(2)), + Self::Disabled => None, + } + } +} diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 571f998dd..5dc0204fe 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -14,6 +14,7 @@ use thiserror::Error; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] use crate::crypto::rustls::{configured_provider, QuicServerConfig}; use crate::{ + address_discovery, cid_generator::{ConnectionIdGenerator, HashedConnectionIdGenerator}, congestion, crypto::{self, HandshakeTokenKey, HmacKey}, @@ -63,6 +64,8 @@ pub struct TransportConfig { pub(crate) congestion_controller_factory: Arc, pub(crate) enable_segmentation_offload: bool, + + pub(crate) address_discovery_role: crate::address_discovery::Role, } impl TransportConfig { @@ -334,6 +337,27 @@ impl TransportConfig { self.enable_segmentation_offload = enabled; self } + + /// Whether to send observed address reports to peers. + /// + /// This will aid peers in inferring their reachable address, which in most NATd networks + /// will not be easily available to them. + pub fn send_observed_address_reports(&mut self, enabled: bool) -> &mut Self { + self.address_discovery_role.send_reports_to_peers(enabled); + self + } + + /// Whether to receive observed address reports from other peers. + /// + /// Peers with the address discovery extension enabled that are willing to provide observed + /// address reports will do so if this transport parameter is set. In general, observed address + /// reports cannot be trusted. This, however, can aid the current endpoint in inferring its + /// reachable address, which in most NATd networks will not be easily available. + pub fn receive_observed_address_reports(&mut self, enabled: bool) -> &mut Self { + self.address_discovery_role + .receive_reports_from_peers(enabled); + self + } } impl Default for TransportConfig { @@ -374,6 +398,8 @@ impl Default for TransportConfig { congestion_controller_factory: Arc::new(congestion::CubicConfig::default()), enable_segmentation_offload: true, + + address_discovery_role: address_discovery::Role::default(), } } } @@ -405,6 +431,7 @@ impl fmt::Debug for TransportConfig { deterministic_packet_numbers: _, congestion_controller_factory: _, enable_segmentation_offload, + address_discovery_role, } = self; fmt.debug_struct("TransportConfig") .field("max_concurrent_bidi_streams", max_concurrent_bidi_streams) @@ -432,6 +459,7 @@ impl fmt::Debug for TransportConfig { .field("datagram_send_buffer_size", datagram_send_buffer_size) .field("congestion_controller_factory", &"[ opaque ]") .field("enable_segmentation_offload", enable_segmentation_offload) + .field("address_discovery_role", address_discovery_role) .finish() } } diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index babb0d757..a29d25932 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -19,7 +19,7 @@ use crate::{ coding::BufMutExt, config::{ServerConfig, TransportConfig}, crypto::{self, KeyPair, Keys, PacketKey}, - frame::{self, Close, Datagram, FrameStruct}, + frame::{self, Close, Datagram, FrameStruct, ObservedAddr}, packet::{ FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, @@ -226,6 +226,12 @@ pub struct Connection { /// no outgoing application data. app_limited: bool, + // + // ObservedAddr + // + /// Sequence number for the next observed address frame sent to the peer. + next_observed_addr_seq_no: VarInt, + streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, @@ -345,6 +351,8 @@ impl Connection { receiving_ecn: false, total_authed_packets: 0, + next_observed_addr_seq_no: 0u32.into(), + streams: StreamsState::new( side, config.max_concurrent_uni_streams, @@ -2638,6 +2646,9 @@ impl Connection { let mut close = None; let payload_len = payload.len(); let mut ack_eliciting = false; + // if this packet triggers a path migration and includes a observed address frame, it's + // stored here + let mut migration_observed_addr = None; for result in frame::Iter::new(payload)? { let frame = result?; let span = match frame { @@ -2681,7 +2692,8 @@ impl Connection { Frame::Padding | Frame::PathChallenge(_) | Frame::PathResponse(_) - | Frame::NewConnectionId(_) => {} + | Frame::NewConnectionId(_) + | Frame::ObservedAddr(_) => {} _ => { is_probing_packet = false; } @@ -2909,6 +2921,33 @@ impl Connection { self.discard_space(now, SpaceId::Handshake); } } + Frame::ObservedAddr(observed) => { + // check if params allows the peer to send report and this node to receive it + if !self + .peer_params + .address_discovery_role + .should_report(&self.config.address_discovery_role) + { + return Err(TransportError::PROTOCOL_VIOLATION( + "received OBSERVED_ADDRESS frame when not negotiated", + )); + } + // must only be sent in data space + if packet.header.space() != SpaceId::Data { + return Err(TransportError::PROTOCOL_VIOLATION( + "OBSERVED_ADDRESS frame outside data space", + )); + } + + if remote == self.path.remote { + if let Some(updated) = self.path.update_observed_addr_report(observed) { + self.events.push_back(Event::ObservedAddr(updated)); + } + } else { + // include in migration + migration_observed_addr = Some(observed) + } + } } } @@ -2945,7 +2984,7 @@ impl Connection { .migration, "migration-initiating packets should have been dropped immediately" ); - self.migrate(now, remote); + self.migrate(now, remote, migration_observed_addr); // Break linkability, if possible self.update_rem_cid(); self.spin = false; @@ -2954,7 +2993,7 @@ impl Connection { Ok(()) } - fn migrate(&mut self, now: Instant, remote: SocketAddr) { + fn migrate(&mut self, now: Instant, remote: SocketAddr, observed_addr: Option) { trace!(%remote, "migration initiated"); // Reset rtt/congestion state for new path unless it looks like a NAT rebinding. // Note that the congestion window will not grow until validation terminates. Helps mitigate @@ -2974,6 +3013,12 @@ impl Connection { &self.config, ) }; + new_path.last_observed_addr_report = self.path.last_observed_addr_report.clone(); + if let Some(report) = observed_addr { + if let Some(updated) = new_path.update_observed_addr_report(report) { + self.events.push_back(Event::ObservedAddr(updated)); + } + } new_path.challenge = Some(self.rng.gen()); new_path.challenge_pending = true; let prev_pto = self.pto(SpaceId::Data); @@ -3058,6 +3103,53 @@ impl Connection { self.stats.frame_tx.handshake_done.saturating_add(1); } + // OBSERVED_ADDR + let mut send_observed_address = + |space_id: SpaceId, + buf: &mut Vec, + max_size: usize, + space: &mut PacketSpace, + sent: &mut SentFrames, + stats: &mut ConnectionStats, + skip_sent_check: bool| { + // should only be sent within Data space and only if allowed by extension + // negotiation + // send is also skipped if the path has already sent an observed address + let send_allowed = self + .config + .address_discovery_role + .should_report(&self.peer_params.address_discovery_role); + let send_required = + space.pending.observed_addr || !self.path.observed_addr_sent || skip_sent_check; + if space_id != SpaceId::Data || !send_allowed || !send_required { + return; + } + + let observed = + frame::ObservedAddr::new(self.path.remote, self.next_observed_addr_seq_no); + + if buf.len() + observed.size() < max_size { + observed.write(buf); + + self.next_observed_addr_seq_no = + self.next_observed_addr_seq_no.saturating_add(1u8); + self.path.observed_addr_sent = true; + + stats.frame_tx.observed_addr += 1; + sent.retransmits.get_or_create().observed_addr = true; + space.pending.observed_addr = false; + } + }; + send_observed_address( + space_id, + buf, + max_size, + space, + &mut sent, + &mut self.stats, + false, + ); + // PING if mem::replace(&mut space.ping_pending, false) { trace!("PING"); @@ -3127,7 +3219,16 @@ impl Connection { trace!("PATH_CHALLENGE {:08x}", token); buf.write(frame::FrameType::PATH_CHALLENGE); buf.write(token); - self.stats.frame_tx.path_challenge += 1; + + send_observed_address( + space_id, + buf, + max_size, + space, + &mut sent, + &mut self.stats, + true, + ); } } @@ -3140,6 +3241,19 @@ impl Connection { buf.write(frame::FrameType::PATH_RESPONSE); buf.write(token); self.stats.frame_tx.path_response += 1; + + // NOTE: this is technically not required but might be useful to ride the + // request/response nature of path challenges to refresh an observation + // Since PATH_RESPONSE is a probing frame, this is allowed by the spec. + send_observed_address( + space_id, + buf, + max_size, + space, + &mut sent, + &mut self.stats, + true, + ); } } @@ -3760,6 +3874,8 @@ pub enum Event { DatagramReceived, /// One or more application datagrams have been sent after blocking DatagramsUnblocked, + /// Received an observation of our external address from the peer. + ObservedAddr(SocketAddr), } fn instant_saturating_sub(x: Instant, y: Instant) -> Duration { diff --git a/quinn-proto/src/connection/paths.rs b/quinn-proto/src/connection/paths.rs index 2c0476c06..9f7a4e0b9 100644 --- a/quinn-proto/src/connection/paths.rs +++ b/quinn-proto/src/connection/paths.rs @@ -7,7 +7,10 @@ use super::{ pacing::Pacer, spaces::{PacketSpace, SentPacket}, }; -use crate::{congestion, packet::SpaceId, Duration, Instant, TransportConfig, TIMER_GRANULARITY}; +use crate::{ + congestion, frame::ObservedAddr, packet::SpaceId, Duration, Instant, TransportConfig, + TIMER_GRANULARITY, +}; /// Description of a particular network path pub(super) struct PathData { @@ -37,6 +40,11 @@ pub(super) struct PathData { /// Used in persistent congestion determination. pub(super) first_packet_after_rtt_sample: Option<(SpaceId, u64)>, pub(super) in_flight: InFlight, + /// Whether this path has had it's remote address reported back to the peer. This only happens + /// if both peers agree to so based on their transport parameters. + pub(super) observed_addr_sent: bool, + /// Observed address frame with the largest sequence number received from the peer on this path. + pub(super) last_observed_addr_report: Option, /// Number of the first packet sent on this path /// /// Used to determine whether a packet was sent on an earlier path. Insufficient to determine if @@ -90,10 +98,15 @@ impl PathData { ), first_packet_after_rtt_sample: None, in_flight: InFlight::new(), + observed_addr_sent: false, + last_observed_addr_report: None, first_packet: None, } } + /// Create a new path from a previous one. + /// + /// This should only be called when migrating paths. pub(super) fn from_previous(remote: SocketAddr, prev: &Self, now: Instant) -> Self { let congestion = prev.congestion.clone_box(); let smoothed_rtt = prev.rtt.get(); @@ -111,6 +124,8 @@ impl PathData { mtud: prev.mtud.clone(), first_packet_after_rtt_sample: prev.first_packet_after_rtt_sample, in_flight: InFlight::new(), + observed_addr_sent: false, + last_observed_addr_report: None, first_packet: None, } } @@ -156,6 +171,37 @@ impl PathData { self.in_flight.remove(packet); true } + + /// Updates the last observed address report received on this path. + /// + /// If the address was updated, it's returned to be informed to the application. + #[must_use = "updated observed address must be reported to the application"] + pub(super) fn update_observed_addr_report( + &mut self, + observed: ObservedAddr, + ) -> Option { + match self.last_observed_addr_report.as_mut() { + Some(prev) => { + if prev.seq_no >= observed.seq_no { + // frames that do not increase the sequence number on this path are ignored + None + } else if prev.ip == observed.ip && prev.port == observed.port { + // keep track of the last seq_no but do not report the address as updated + prev.seq_no = observed.seq_no; + None + } else { + let addr = observed.socket_addr(); + self.last_observed_addr_report = Some(observed); + Some(addr) + } + } + None => { + let addr = observed.socket_addr(); + self.last_observed_addr_report = Some(observed); + Some(addr) + } + } + } } /// RTT estimation for a particular network path diff --git a/quinn-proto/src/connection/spaces.rs b/quinn-proto/src/connection/spaces.rs index ed58b51c1..0d0edad68 100644 --- a/quinn-proto/src/connection/spaces.rs +++ b/quinn-proto/src/connection/spaces.rs @@ -309,6 +309,7 @@ pub struct Retransmits { pub(super) retire_cids: Vec, pub(super) ack_frequency: bool, pub(super) handshake_done: bool, + pub(super) observed_addr: bool, } impl Retransmits { @@ -326,6 +327,7 @@ impl Retransmits { && self.retire_cids.is_empty() && !self.ack_frequency && !self.handshake_done + && !self.observed_addr } } @@ -347,6 +349,7 @@ impl ::std::ops::BitOrAssign for Retransmits { self.retire_cids.extend(rhs.retire_cids); self.ack_frequency |= rhs.ack_frequency; self.handshake_done |= rhs.handshake_done; + self.observed_addr |= rhs.observed_addr; } } diff --git a/quinn-proto/src/connection/stats.rs b/quinn-proto/src/connection/stats.rs index 9ddb42d1a..31f5f1d14 100644 --- a/quinn-proto/src/connection/stats.rs +++ b/quinn-proto/src/connection/stats.rs @@ -53,6 +53,7 @@ pub struct FrameStats { pub streams_blocked_uni: u64, pub stop_sending: u64, pub stream: u64, + pub observed_addr: u64, } impl FrameStats { @@ -93,6 +94,7 @@ impl FrameStats { Frame::AckFrequency(_) => self.ack_frequency += 1, Frame::ImmediateAck => self.immediate_ack += 1, Frame::HandshakeDone => self.handshake_done = self.handshake_done.saturating_add(1), + Frame::ObservedAddr(_) => self.observed_addr += 1, } } } diff --git a/quinn-proto/src/frame.rs b/quinn-proto/src/frame.rs index 0bc7f34ad..915721c7f 100644 --- a/quinn-proto/src/frame.rs +++ b/quinn-proto/src/frame.rs @@ -1,6 +1,7 @@ use std::{ fmt::{self, Write}, io, mem, + net::{IpAddr, SocketAddr}, ops::{Range, RangeInclusive}, }; @@ -134,6 +135,9 @@ frame_types! { ACK_FREQUENCY = 0xaf, IMMEDIATE_ACK = 0x1f, // DATAGRAM + // ADDRESS DISCOVERY REPORT + OBSERVED_IPV4_ADDR = 0x9f81a6, + OBSERVED_IPV6_ADDR = 0x9f81a7, } const STREAM_TYS: RangeInclusive = RangeInclusive::new(0x08, 0x0f); @@ -164,6 +168,7 @@ pub(crate) enum Frame { AckFrequency(AckFrequency), ImmediateAck, HandshakeDone, + ObservedAddr(ObservedAddr), } impl Frame { @@ -205,6 +210,7 @@ impl Frame { AckFrequency(_) => FrameType::ACK_FREQUENCY, ImmediateAck => FrameType::IMMEDIATE_ACK, HandshakeDone => FrameType::HANDSHAKE_DONE, + ObservedAddr(ref observed) => observed.get_type(), } } @@ -687,6 +693,11 @@ impl Iter { reordering_threshold: self.bytes.get()?, }), FrameType::IMMEDIATE_ACK => Frame::ImmediateAck, + FrameType::OBSERVED_IPV4_ADDR | FrameType::OBSERVED_IPV6_ADDR => { + let is_ipv6 = ty == FrameType::OBSERVED_IPV6_ADDR; + let observed = ObservedAddr::read(&mut self.bytes, is_ipv6)?; + Frame::ObservedAddr(observed) + } _ => { if let Some(s) = ty.stream() { Frame::Stream(Stream { @@ -929,8 +940,86 @@ impl AckFrequency { } } +/* Address Discovery https://datatracker.ietf.org/doc/draft-seemann-quic-address-discovery/ */ + +/// Conjuction of the information contained in the address discovery frames +/// ([`FrameType::OBSERVED_IPV4_ADDR`], [`FrameType::OBSERVED_IPV6_ADDR`]). +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct ObservedAddr { + /// Monotonically increasing integer within the same connection. + pub(crate) seq_no: VarInt, + /// Reported observed address. + pub(crate) ip: IpAddr, + /// Reported observed port. + pub(crate) port: u16, +} + +impl ObservedAddr { + pub(crate) fn new>(remote: std::net::SocketAddr, seq_no: N) -> Self { + Self { + ip: remote.ip(), + port: remote.port(), + seq_no: seq_no.into(), + } + } + + /// Get the [`FrameType`] for this frame. + pub(crate) fn get_type(&self) -> FrameType { + if self.ip.is_ipv6() { + FrameType::OBSERVED_IPV6_ADDR + } else { + FrameType::OBSERVED_IPV4_ADDR + } + } + + /// Compute the number of bytes needed to encode the frame. + pub(crate) fn size(&self) -> usize { + let type_size = VarInt(self.get_type().0).size(); + let req_id_bytes = self.seq_no.size(); + let ip_bytes = if self.ip.is_ipv6() { 16 } else { 4 }; + let port_bytes = 2; + type_size + req_id_bytes + ip_bytes + port_bytes + } + + /// Unconditionally write this frame to `buf`. + pub(crate) fn write(&self, buf: &mut W) { + buf.write(self.get_type()); + buf.write(self.seq_no); + match self.ip { + IpAddr::V4(ipv4_addr) => { + buf.write(ipv4_addr); + } + IpAddr::V6(ipv6_addr) => { + buf.write(ipv6_addr); + } + } + buf.write::(self.port); + } + + /// Reads the frame contents from the buffer. + /// + /// Should only be called when the fram type has been identified as + /// [`FrameType::OBSERVED_IPV4_ADDR`] or [`FrameType::OBSERVED_IPV6_ADDR`]. + pub(crate) fn read(bytes: &mut R, is_ipv6: bool) -> coding::Result { + let seq_no = bytes.get()?; + let ip = if is_ipv6 { + IpAddr::V6(bytes.get()?) + } else { + IpAddr::V4(bytes.get()?) + }; + let port = bytes.get()?; + Ok(Self { seq_no, ip, port }) + } + + /// Gives the [`SocketAddr`] reported in the frame. + pub(crate) fn socket_addr(&self) -> SocketAddr { + (self.ip, self.port).into() + } +} + #[cfg(test)] mod test { + use super::*; use crate::coding::Codec; use assert_matches::assert_matches; @@ -996,4 +1085,29 @@ mod test { assert_eq!(frames.len(), 1); assert_matches!(&frames[0], Frame::ImmediateAck); } + + /// Test that encoding and decoding [`ObservedAddr`] produces the same result. + #[test] + fn test_observed_addr_roundrip() { + let observed_addr = ObservedAddr { + seq_no: VarInt(42), + ip: std::net::Ipv4Addr::LOCALHOST.into(), + port: 4242, + }; + let mut buf = Vec::with_capacity(observed_addr.size()); + observed_addr.write(&mut buf); + + assert_eq!( + observed_addr.size(), + buf.len(), + "expected written bytes and actual size differ" + ); + + let mut decoded = frames(buf); + assert_eq!(decoded.len(), 1); + match decoded.pop().expect("non empty") { + Frame::ObservedAddr(decoded) => assert_eq!(decoded, observed_addr), + x => panic!("incorrect frame {x:?}"), + } + } } diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 79eddc827..225e4eb93 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -88,6 +88,8 @@ pub use crate::cid_generator::{ mod token; use token::{ResetToken, RetryToken}; +mod address_discovery; + #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index cb39b351f..98438d342 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -3185,6 +3185,280 @@ fn voluntary_ack_with_large_datagrams() { ); } +/// Test the address discovery extension on a normal setup. +#[test] +fn address_discovery() { + let _guard = subscribe(); + + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let client_config = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..client_config() + }; + let conn_handle = pair.begin_connect(client_config); + + // wait for idle connections + pair.drive(); + + // check that the client received the correct address + let expected_addr = pair.client.addr; + let conn = pair.client_conn_mut(conn_handle); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), Some(Event::ObservedAddr(addr)) if addr == expected_addr); + assert_matches!(conn.poll(), None); + + // check that the server received the correct address + let conn_handle = pair.server.assert_accept(); + let expected_addr = pair.server.addr; + let conn = pair.server_conn_mut(conn_handle); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), Some(Event::ObservedAddr(addr)) if addr == expected_addr); + assert_matches!(conn.poll(), None); +} + +/// Test that a different address discovery configuration on 0rtt used by the client is accepted by +/// the server. +/// NOTE: this test is the same as zero_rtt_happypath, changing client transport parameters on +/// resumption. +#[test] +fn address_discovery_zero_rtt_accepted() { + let _guard = subscribe(); + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + + pair.server.incoming_connection_behavior = IncomingConnectionBehavior::Validate; + let client_cfg = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..client_config() + }; + let alt_client_cfg = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Disabled, + ..TransportConfig::default() + }), + ..client_cfg.clone() + }; + + // Establish normal connection + let client_ch = pair.begin_connect(client_cfg); + pair.drive(); + pair.server.assert_accept(); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(0), [][..].into()); + pair.drive(); + + pair.client.addr = SocketAddr::new( + Ipv6Addr::LOCALHOST.into(), + CLIENT_PORTS.lock().unwrap().next().unwrap(), + ); + info!("resuming session"); + let client_ch = pair.begin_connect(alt_client_cfg); + assert!(pair.client_conn_mut(client_ch).has_0rtt()); + let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); + const MSG: &[u8] = b"Hello, 0-RTT!"; + pair.client_send(client_ch, s).write(MSG).unwrap(); + pair.drive(); + + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + + assert!(pair.client_conn_mut(client_ch).accepted_0rtt()); + let server_ch = pair.server.assert_accept(); + + let conn = pair.server_conn_mut(server_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + // We don't currently preserve stream event order wrt. connection events + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!( + conn.poll(), + Some(Event::Stream(StreamEvent::Opened { dir: Dir::Uni })) + ); + + let mut recv = pair.server_recv(server_ch, s); + let mut chunks = recv.read(false).unwrap(); + assert_matches!( + chunks.next(usize::MAX), + Ok(Some(chunk)) if chunk.offset == 0 && chunk.bytes == MSG + ); + let _ = chunks.finalize(); + assert_eq!(pair.client_conn_mut(client_ch).lost_packets(), 0); +} + +/// Test that a different address discovery configuration on 0rtt used by the server is rejected by +/// the client. +/// NOTE: the server MUST not change configuration on resumption. However, there is no designed +/// behaviour when this is encountered. Quinn chooses to accept and then close the connection, +/// which is what this test checks. +#[test] +fn address_discovery_zero_rtt_rejection() { + let _guard = subscribe(); + let server_cfg = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Disabled, + ..TransportConfig::default() + }), + ..server_config() + }; + let alt_server_cfg = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::SendOnly, + ..TransportConfig::default() + }), + ..server_cfg.clone() + }; + let mut pair = Pair::new(Default::default(), server_cfg); + let client_cfg = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..client_config() + }; + + // Establish normal connection + let client_ch = pair.begin_connect(client_cfg.clone()); + pair.drive(); + let server_ch = pair.server.assert_accept(); + let conn = pair.server_conn_mut(server_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), None); + pair.client + .connections + .get_mut(&client_ch) + .unwrap() + .close(pair.time, VarInt(0), [][..].into()); + pair.drive(); + assert_matches!( + pair.server_conn_mut(server_ch).poll(), + Some(Event::ConnectionLost { .. }) + ); + assert_matches!(pair.server_conn_mut(server_ch).poll(), None); + pair.client.connections.clear(); + pair.server.connections.clear(); + + // Changing address discovery configurations makes the client close the connection + pair.server + .set_server_config(Some(Arc::new(alt_server_cfg))); + info!("resuming session"); + let client_ch = pair.begin_connect(client_cfg); + assert!(pair.client_conn_mut(client_ch).has_0rtt()); + let s = pair.client_streams(client_ch).open(Dir::Uni).unwrap(); + const MSG: &[u8] = b"Hello, 0-RTT!"; + pair.client_send(client_ch, s).write(MSG).unwrap(); + pair.drive(); + let conn = pair.client_conn_mut(server_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!( + conn.poll(), + Some(Event::ConnectionLost { reason }) if matches!(reason, ConnectionError::TransportError(_) ) + ); +} + +#[test] +fn address_discovery_retransmission() { + let _guard = subscribe(); + + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let client_config = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..client_config() + }; + let client_ch = pair.begin_connect(client_config); + pair.step(); + + // lose the last packet + pair.client.inbound.pop_back().unwrap(); + pair.step(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), None); + + pair.drive(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), + Some(Event::ObservedAddr(addr)) if addr == pair.client.addr); +} + +#[test] +fn address_discovery_rebind_retransmission() { + let _guard = subscribe(); + + let server = ServerConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let client_config = ClientConfig { + transport: Arc::new(TransportConfig { + address_discovery_role: crate::address_discovery::Role::Both, + ..TransportConfig::default() + }), + ..client_config() + }; + let client_ch = pair.begin_connect(client_config); + pair.step(); + + // lose the last packet + pair.client.inbound.pop_back().unwrap(); + pair.step(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), Some(Event::HandshakeDataReady)); + assert_matches!(conn.poll(), Some(Event::Connected)); + assert_matches!(conn.poll(), None); + + // simulate a rebind to ensure we will get an updated address instead of retransmitting + // outdated info + pair.client_conn_mut(client_ch).local_address_changed(); + pair.client + .addr + .set_port(pair.client.addr.port().overflowing_add(1).0); + + pair.drive(); + let conn = pair.client_conn_mut(client_ch); + assert_matches!(conn.poll(), + Some(Event::ObservedAddr(addr)) if addr == pair.client.addr); +} + #[test] fn reject_short_idcid() { let _guard = subscribe(); diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index 381968059..a224ba8e1 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -16,6 +16,7 @@ use rand::{Rng as _, RngCore}; use thiserror::Error; use crate::{ + address_discovery, cid_generator::ConnectionIdGenerator, cid_queue::CidQueue, coding::{BufExt, BufMutExt, UnexpectedEnd}, @@ -104,6 +105,8 @@ macro_rules! make_struct { /// of transport parameter extensions. /// When present, it is included during serialization but ignored during deserialization. pub(crate) grease_transport_parameter: Option, + /// The role of this peer in address discovery, if any. + pub(crate) address_discovery_role: address_discovery::Role, } // We deliberately don't implement the `Default` trait, since that would be public, and @@ -126,6 +129,8 @@ macro_rules! make_struct { stateless_reset_token: None, preferred_address: None, grease_transport_parameter: None, + + address_discovery_role: address_discovery::Role::Disabled, } } } @@ -168,6 +173,7 @@ impl TransportParameters { VarInt::from_u64(u64::try_from(TIMER_GRANULARITY.as_micros()).unwrap()).unwrap(), ), grease_transport_parameter: Some(ReservedTransportParameter::random(rng)), + address_discovery_role: config.address_discovery_role, ..Self::default() } } @@ -184,6 +190,7 @@ impl TransportParameters { || cached.initial_max_streams_uni > self.initial_max_streams_uni || cached.max_datagram_frame_size > self.max_datagram_frame_size || cached.grease_quic_bit && !self.grease_quic_bit + || cached.address_discovery_role != self.address_discovery_role { return Err(TransportError::PROTOCOL_VIOLATION( "0-RTT accepted with incompatible transport parameters", @@ -357,6 +364,12 @@ impl TransportParameters { w.write_var(x.size() as u64); w.write(x); } + + if let Some(varint_role) = self.address_discovery_role.as_transport_parameter() { + w.write_var(address_discovery::TRANSPORT_PARAMETER_CODE); + w.write_var(varint_role.size() as u64); + w.write(varint_role); + } } /// Decode `TransportParameters` from buffer @@ -421,6 +434,21 @@ impl TransportParameters { _ => return Err(Error::Malformed), }, 0xff04de1b => params.min_ack_delay = Some(r.get().unwrap()), + address_discovery::TRANSPORT_PARAMETER_CODE => { + if !params.address_discovery_role.is_disabled() { + // duplicate parameter + return Err(Error::Malformed); + } + let value: VarInt = r.get()?; + if len != value.size() { + return Err(Error::Malformed); + } + params.address_discovery_role = value.try_into()?; + tracing::debug!( + role = ?params.address_discovery_role, + "address discovery enabled for peer" + ); + } _ => { macro_rules! parse { {$($(#[$doc:meta])* $name:ident ($code:expr) = $default:expr,)*} => { @@ -581,6 +609,7 @@ mod test { }), grease_quic_bit: true, min_ack_delay: Some(2_000u32.into()), + address_discovery_role: address_discovery::Role::SendOnly, ..TransportParameters::default() }; params.write(&mut buf); diff --git a/quinn-proto/src/varint.rs b/quinn-proto/src/varint.rs index a72fb3431..08022d9db 100644 --- a/quinn-proto/src/varint.rs +++ b/quinn-proto/src/varint.rs @@ -50,6 +50,14 @@ impl VarInt { self.0 } + /// Saturating integer addition. Computes self + rhs, saturating at the numeric bounds instead + /// of overflowing. + pub fn saturating_add(self, rhs: impl Into) -> Self { + let rhs = rhs.into(); + let inner = self.0.saturating_add(rhs.0).min(Self::MAX.0); + Self(inner) + } + /// Compute the number of bytes needed to encode this value pub(crate) const fn size(self) -> usize { let x = self.0; @@ -191,3 +199,19 @@ impl Codec for VarInt { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_saturating_add() { + // add within range behaves normally + let large: VarInt = u32::MAX.into(); + let next = u64::from(u32::MAX) + 1; + assert_eq!(large.saturating_add(1u8), VarInt::from_u64(next).unwrap()); + + // outside range saturates + assert_eq!(VarInt::MAX.saturating_add(1u8), VarInt::MAX) + } +} diff --git a/quinn/Cargo.toml b/quinn/Cargo.toml index a061520d7..d60989abc 100644 --- a/quinn/Cargo.toml +++ b/quinn/Cargo.toml @@ -68,6 +68,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time", "macros tracing-subscriber = { workspace = true } tracing-futures = { workspace = true } url = { workspace = true } +tokio-stream = "0.1.15" [[example]] name = "server" diff --git a/quinn/examples/client.rs b/quinn/examples/client.rs index 0ace61f95..80fc3562d 100644 --- a/quinn/examples/client.rs +++ b/quinn/examples/client.rs @@ -13,7 +13,7 @@ use std::{ use anyhow::{anyhow, Result}; use clap::Parser; -use proto::crypto::rustls::QuicClientConfig; +use proto::{crypto::rustls::QuicClientConfig, TransportConfig}; use rustls::pki_types::CertificateDer; use tracing::{error, info}; use url::Url; @@ -101,8 +101,13 @@ async fn run(options: Opt) -> Result<()> { client_crypto.key_log = Arc::new(rustls::KeyLogFile::new()); } - let client_config = + let mut transport = TransportConfig::default(); + transport + .send_observed_address_reports(true) + .receive_observed_address_reports(true); + let mut client_config = quinn::ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_crypto)?)); + client_config.transport_config(Arc::new(transport)); let mut endpoint = quinn::Endpoint::client(options.bind)?; endpoint.set_default_client_config(client_config); @@ -117,6 +122,18 @@ async fn run(options: Opt) -> Result<()> { .await .map_err(|e| anyhow!("failed to connect: {}", e))?; eprintln!("connected at {:?}", start.elapsed()); + let mut external_addresses = conn.observed_external_addr(); + tokio::spawn(async move { + loop { + if let Some(new_addr) = *external_addresses.borrow_and_update() { + info!(%new_addr, "new external address report"); + } + if external_addresses.changed().await.is_err() { + break; + } + } + }); + let (mut send, mut recv) = conn .open_bi() .await diff --git a/quinn/examples/server.rs b/quinn/examples/server.rs index b6f63160e..b65d739be 100644 --- a/quinn/examples/server.rs +++ b/quinn/examples/server.rs @@ -127,7 +127,10 @@ async fn run(options: Opt) -> Result<()> { let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(server_crypto)?)); let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); - transport_config.max_concurrent_uni_streams(0_u8.into()); + transport_config + .max_concurrent_uni_streams(0_u8.into()) + .send_observed_address_reports(true) + .receive_observed_address_reports(true); let root = Arc::::from(options.root.clone()); if !root.exists() { @@ -176,6 +179,21 @@ async fn handle_connection(root: Arc, conn: quinn::Incoming) -> Result<()> .protocol .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) ); + + let mut external_addresses = connection.observed_external_addr(); + tokio::spawn( + async move { + loop { + if let Some(new_addr) = *external_addresses.borrow_and_update() { + info!(%new_addr, "new external address report"); + } + if external_addresses.changed().await.is_err() { + break; + } + } + } + .instrument(span.clone()), + ); async { info!("established"); diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index cf09c6e67..0969f514f 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -14,7 +14,7 @@ use bytes::Bytes; use pin_project_lite::pin_project; use rustc_hash::FxHashMap; use thiserror::Error; -use tokio::sync::{futures::Notified, mpsc, oneshot, Notify}; +use tokio::sync::{futures::Notified, mpsc, oneshot, watch, Notify}; use tracing::{debug_span, Instrument, Span}; use crate::{ @@ -636,6 +636,12 @@ impl Connection { // May need to send MAX_STREAMS to make progress conn.wake(); } + + /// Track changed on our external address as reported by the peer. + pub fn observed_external_addr(&self) -> watch::Receiver> { + let conn = self.0.state.lock("external_addr"); + conn.observed_external_addr.subscribe() + } } pin_project! { @@ -892,6 +898,7 @@ impl ConnectionRef { runtime, send_buffer: Vec::new(), buffered_transmit: None, + observed_external_addr: watch::Sender::new(None), }), shared: Shared::default(), })) @@ -974,6 +981,8 @@ pub(crate) struct State { send_buffer: Vec, /// We buffer a transmit when the underlying I/O would block buffered_transmit: Option, + /// Our last external address reported by the peer. + pub(crate) observed_external_addr: watch::Sender>, } impl State { @@ -1131,6 +1140,12 @@ impl State { wake_stream(id, &mut self.stopped); wake_stream(id, &mut self.blocked_writers); } + ObservedAddr(observed) => { + self.observed_external_addr.send_if_modified(|addr| { + let old = addr.replace(observed); + old != *addr + }); + } } } }