From 9f2c9e128f50e9092f681733849fb7736c0b0034 Mon Sep 17 00:00:00 2001 From: Damien Deville Date: Tue, 19 Dec 2023 16:39:59 +0100 Subject: [PATCH] udp: add support for ECN on Windows --- quinn-proto/src/connection/mod.rs | 1 + quinn-udp/Cargo.toml | 3 +- quinn-udp/src/cmsg/mod.rs | 4 + quinn-udp/src/cmsg/windows.rs | 112 ++++++++++++ quinn-udp/src/lib.rs | 3 +- quinn-udp/src/windows.rs | 291 ++++++++++++++++++++++++++---- quinn/src/tests.rs | 5 +- 7 files changed, 385 insertions(+), 34 deletions(-) create mode 100644 quinn-udp/src/cmsg/windows.rs mode change 100644 => 100755 quinn/src/tests.rs diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 4798a12111..0179afe53c 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -1209,6 +1209,7 @@ impl Connection { /// Retrieving the local IP address is currently supported on the following /// platforms: /// - Linux + /// - Windows /// /// On all non-supported platforms the local IP address will not be available, /// and the method will return `None`. diff --git a/quinn-udp/Cargo.toml b/quinn-udp/Cargo.toml index 0730a43115..68f387618e 100644 --- a/quinn-udp/Cargo.toml +++ b/quinn-udp/Cargo.toml @@ -28,4 +28,5 @@ socket2 = "0.5" tracing = "0.1.10" [target.'cfg(windows)'.dependencies] -windows-sys = { version = "0.52.0", features = ["Win32_Networking_WinSock"] } +windows-sys = { version = "0.52.0", features = ["Win32_Foundation", "Win32_System_IO", "Win32_Networking_WinSock"] } +once_cell = "1.19.0" diff --git a/quinn-udp/src/cmsg/mod.rs b/quinn-udp/src/cmsg/mod.rs index 6f3c867974..cc5ecdc7f7 100644 --- a/quinn-udp/src/cmsg/mod.rs +++ b/quinn-udp/src/cmsg/mod.rs @@ -7,6 +7,10 @@ use std::{ #[path = "unix.rs"] mod imp; +#[cfg(windows)] +#[path = "windows.rs"] +mod imp; + pub(crate) use imp::Aligned; // Helper traits for native types for control messages diff --git a/quinn-udp/src/cmsg/windows.rs b/quinn-udp/src/cmsg/windows.rs new file mode 100644 index 0000000000..63adcf9415 --- /dev/null +++ b/quinn-udp/src/cmsg/windows.rs @@ -0,0 +1,112 @@ +use std::ffi::{c_int, c_uchar}; + +use windows_sys::Win32::Networking::WinSock; + +use super::{CMsgHdr, MsgHdr}; + +#[derive(Copy, Clone)] +#[repr(align(8))] // Conservative bound for align_of +pub(crate) struct Aligned(pub(crate) T); + +/// Helpers for [`WinSock::WSAMSG`] +// https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-wsamsg +// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/struct.WSAMSG.html +impl MsgHdr for WinSock::WSAMSG { + type ControlMessage = WinSock::CMSGHDR; + + fn control_len(&self) -> usize { + self.Control.len as _ + } + + fn set_control_len(&mut self, len: usize) { + self.Control.len = len as _; + } + + fn cmsg_firsthdr(&self) -> *mut Self::ControlMessage { + unsafe { self::wsa2::cmsg_firsthdr(self) } + } + + fn cmsg_nxthdr(&self, cmsg: &Self::ControlMessage) -> *mut Self::ControlMessage { + unsafe { self::wsa2::cmsg_nxthdr(self, cmsg) } + } +} + +/// Helpers for [`WinSock::CMSGHDR`] +// https://learn.microsoft.com/en-us/windows/win32/api/ws2def/ns-ws2def-wsacmsghdr +// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/struct.CMSGHDR.html +impl CMsgHdr for WinSock::CMSGHDR { + fn set(&mut self, level: c_int, ty: c_int, len: usize) { + self.cmsg_level = level as _; + self.cmsg_type = ty as _; + self.cmsg_len = len as _; + } + + fn len(&self) -> usize { + self.cmsg_len as _ + } + + fn cmsg_space(length: usize) -> usize { + self::wsa2::cmsg_space(length) + } + + fn cmsg_len(length: usize) -> usize { + self::wsa2::cmsg_len(length) + } + + fn cmsg_data(&self) -> *mut c_uchar { + unsafe { self::wsa2::cmsg_data(self) } + } +} + +mod wsa2 { + use std::{mem, ptr}; + + use windows_sys::Win32::Networking::WinSock; + + // Helpers functions based on C macros from + // https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741 + fn cmsghdr_align(length: usize) -> usize { + (length + mem::align_of::() - 1) + & !(mem::align_of::() - 1) + } + + fn cmsgdata_align(length: usize) -> usize { + (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) + } + + pub(crate) unsafe fn cmsg_firsthdr(msg: *const WinSock::WSAMSG) -> *mut WinSock::CMSGHDR { + if (*msg).Control.len as usize >= mem::size_of::() { + (*msg).Control.buf as *mut WinSock::CMSGHDR + } else { + ptr::null_mut::() + } + } + + pub(crate) unsafe fn cmsg_nxthdr( + hdr: &WinSock::WSAMSG, + cmsg: *const WinSock::CMSGHDR, + ) -> *mut WinSock::CMSGHDR { + if cmsg.is_null() { + return cmsg_firsthdr(hdr); + } + let next = (cmsg as usize + cmsghdr_align((*cmsg).cmsg_len)) as *mut WinSock::CMSGHDR; + let max = hdr.Control.buf as usize + hdr.Control.len as usize; + if (next.offset(1)) as usize > max { + ptr::null_mut() + } else { + next + } + } + + pub(crate) unsafe fn cmsg_data(cmsg: *const WinSock::CMSGHDR) -> *mut u8 { + (cmsg as usize + cmsgdata_align(mem::size_of::())) as *mut u8 + } + + pub(crate) fn cmsg_space(length: usize) -> usize { + cmsgdata_align(mem::size_of::() + cmsghdr_align(length)) + } + + pub(crate) fn cmsg_len(length: usize) -> usize { + cmsgdata_align(mem::size_of::()) + length + } +} diff --git a/quinn-udp/src/lib.rs b/quinn-udp/src/lib.rs index 5adf622fde..c6c9a12dfb 100644 --- a/quinn-udp/src/lib.rs +++ b/quinn-udp/src/lib.rs @@ -15,8 +15,9 @@ use std::{ use bytes::Bytes; use tracing::warn; -#[cfg(unix)] +#[cfg(any(unix, windows))] mod cmsg; + #[cfg(unix)] #[path = "unix.rs"] mod imp; diff --git a/quinn-udp/src/windows.rs b/quinn-udp/src/windows.rs index 363a03e342..9a83812f8a 100644 --- a/quinn-udp/src/windows.rs +++ b/quinn-udp/src/windows.rs @@ -1,14 +1,26 @@ use std::{ io::{self, IoSliceMut}, mem, + net::{IpAddr, Ipv4Addr}, os::windows::io::AsRawSocket, + ptr, sync::Mutex, time::Instant, }; +use once_cell::sync::OnceCell; use windows_sys::Win32::Networking::WinSock; -use super::{log_sendmsg_error, RecvMeta, Transmit, UdpSockRef, IO_ERROR_LOG_INTERVAL}; +use crate::{ + cmsg::{self, CMsgHdr}, + log_sendmsg_error, EcnCodepoint, RecvMeta, Transmit, UdpSockRef, IO_ERROR_LOG_INTERVAL, +}; + +// Enough to store max(IP_PKTINFO + IP_ECN, IPV6_PKTINFO + IPV6_ECN) bytes (header + data) and some extra margin +const CMSG_LEN: usize = 128; + +// FIXME this could use [`std::sync::OnceLock`] once the MSRV is bumped to 1.70 and upper +static WSARECVMSG_PTR: OnceCell = OnceCell::new(); /// QUIC-friendly UDP interface for Windows #[derive(Debug)] @@ -18,6 +30,16 @@ pub struct UdpSocketState { impl UdpSocketState { pub fn new(socket: UdpSockRef<'_>) -> io::Result { + assert!( + CMSG_LEN + >= WinSock::CMSGHDR::cmsg_space(mem::size_of::()) + + WinSock::CMSGHDR::cmsg_space(mem::size_of::()) + ); + assert!( + mem::align_of::() <= mem::align_of::>(), + "control message buffers will be misaligned" + ); + socket.0.set_nonblocking(true)?; let addr = socket.0.local_addr()?; let is_ipv6 = addr.as_socket_ipv6().is_some(); @@ -38,6 +60,15 @@ impl UdpSocketState { }; let is_ipv4 = addr.as_socket_ipv4().is_some() || !v6only; + let wsa_recvmsg_ptr = WSARECVMSG_PTR.get_or_init(|| get_wsarecvmsg_fn(&*socket.0)); + + // We do not support anymore old version of windows that do not give access to WSARecvMsg() function + if wsa_recvmsg_ptr.is_none() { + tracing::error!("network stack does not support wsarecvmsg function"); + + return Err(io::Error::from(io::ErrorKind::Unsupported)); + } + if is_ipv4 { set_socket_option( &*socket.0, @@ -45,6 +76,14 @@ impl UdpSocketState { WinSock::IP_DONTFRAGMENT, OPTION_ON, )?; + + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IP, + WinSock::IP_PKTINFO, + OPTION_ON, + )?; + set_socket_option(&*socket.0, WinSock::IPPROTO_IP, WinSock::IP_ECN, OPTION_ON)?; } if is_ipv6 { @@ -54,6 +93,20 @@ impl UdpSocketState { WinSock::IPV6_DONTFRAG, OPTION_ON, )?; + + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IPV6, + WinSock::IPV6_PKTINFO, + OPTION_ON, + )?; + + set_socket_option( + &*socket.0, + WinSock::IPPROTO_IPV6, + WinSock::IPV6_ECN, + OPTION_ON, + )?; } let now = Instant::now(); @@ -62,30 +115,98 @@ impl UdpSocketState { }) } - pub fn send(&self, socket: UdpSockRef<'_>, transmits: &[Transmit]) -> Result { + pub fn send(&self, socket: UdpSockRef<'_>, transmits: &[Transmit]) -> io::Result { let mut sent = 0; for transmit in transmits { - match socket.0.send_to( - &transmit.contents, - &socket2::SockAddr::from(transmit.destination), - ) { - Ok(_) => { - sent += 1; + // we cannot use [`socket2::sendmsg()`] and [`socket2::MsgHdr`] as we do not have access + // to the inner field which holds the WSAMSG + let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]); + let daddr = socket2::SockAddr::from(transmit.destination); + + let mut data = WinSock::WSABUF { + buf: transmit.contents.as_ptr() as *mut _, + len: transmit.contents.len() as _, + }; + + let ctrl = WinSock::WSABUF { + buf: ctrl_buf.0.as_mut_ptr(), + len: ctrl_buf.0.len() as _, + }; + + let mut wsa_msg = WinSock::WSAMSG { + name: daddr.as_ptr() as *mut _, + namelen: daddr.len(), + lpBuffers: &mut data, + Control: ctrl, + dwBufferCount: 1, + dwFlags: 0, + }; + + // Add control messages (ECN and PKTINFO) + let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) }; + + if let Some(ip) = transmit.src_ip { + let ip = std::net::SocketAddr::new(ip, 0); + let ip = socket2::SockAddr::from(ip); + match ip.family() { + WinSock::AF_INET => { + let src_ip: WinSock::SOCKADDR_IN = unsafe { ptr::read(ip.as_ptr() as _) }; + let pktinfo = WinSock::IN_PKTINFO { + ipi_addr: src_ip.sin_addr, + ipi_ifindex: 0, + }; + encoder.push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo); + } + WinSock::AF_INET6 => { + let src_ip: WinSock::SOCKADDR_IN6 = unsafe { ptr::read(ip.as_ptr() as _) }; + let pktinfo = WinSock::IN6_PKTINFO { + ipi6_addr: src_ip.sin6_addr, + ipi6_ifindex: unsafe { src_ip.Anonymous.sin6_scope_id }, + }; + encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo); + } + _ => { + return Err(io::Error::from(io::ErrorKind::InvalidInput)); + } } + } + + // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn + let ecn = transmit.ecn.map_or(0, |x| x as i32); + if transmit.destination.is_ipv4() { + encoder.push(WinSock::IPPROTO_IP, WinSock::IP_TOS, ecn); + } else { + encoder.push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn); + } + + encoder.finish(); + + let mut len = 0; + let rc = unsafe { + WinSock::WSASendMsg( + socket.0.as_raw_socket() as usize, + &wsa_msg, + 0, + &mut len, + ptr::null_mut(), + None, + ) + }; + + if rc == 0 { + sent += 1; + } else if sent != 0 { // We need to report that some packets were sent in this case, so we rely on // errors being either harmlessly transient (in the case of WouldBlock) or // recurring on the next call. - Err(_) if sent != 0 => return Ok(sent), - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - return Err(e); - } - - // Other errors are ignored, since they will usually be handled - // by higher level retransmits and timeouts. - log_sendmsg_error(&self.last_send_error, e, transmit); - sent += 1; - } + return Ok(sent); + } else if rc == WinSock::WSAEWOULDBLOCK { + return Err(io::Error::last_os_error()); + } else { + // Other errors are ignored, since they will usually be handled + // by higher level retransmits and timeouts. + log_sendmsg_error(&self.last_send_error, io::Error::last_os_error(), transmit); + sent += 1; } } Ok(sent) @@ -97,20 +218,96 @@ impl UdpSocketState { bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta], ) -> io::Result { - // Safety: both `IoSliceMut` and `MaybeUninitSlice` promise to have the - // same layout, that of `iovec`/`WSABUF`. Furthermore `recv_vectored` - // promises to not write unitialised bytes to the `bufs` and pass it - // directly to the `recvmsg` system call, so this is safe. - let bufs = unsafe { - &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [socket2::MaybeUninitSlice<'_>]) + let wsa_recvmsg_ptr = WSARECVMSG_PTR + .get_or_init(|| get_wsarecvmsg_fn(&*socket.0)) + .expect("valid function pointer for wsarecvmsg"); + + // we cannot use [`socket2::MsgHdrMut`] as we do not have access to inner field which holds the WSAMSG + let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]); + let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() }; + let mut data = WinSock::WSABUF { + buf: bufs[0].as_mut_ptr(), + len: bufs[0].len() as _, + }; + + let ctrl = WinSock::WSABUF { + buf: ctrl_buf.0.as_mut_ptr(), + len: ctrl_buf.0.len() as _, + }; + + let mut wsa_msg = WinSock::WSAMSG { + name: &mut source as *mut _ as *mut _, + namelen: mem::size_of_val(&source) as _, + lpBuffers: &mut data, + Control: ctrl, + dwBufferCount: 1, + dwFlags: 0, + }; + + // FIXME add Safety: ? + let mut len = 0; + unsafe { + let rc = (wsa_recvmsg_ptr)( + socket.0.as_raw_socket() as usize, + &mut wsa_msg, + &mut len, + ptr::null_mut(), + None, + ); + if rc == -1 { + return Err(io::Error::last_os_error()); + } + } + + // FIXME add Safety: ? + let addr = unsafe { + let (_, addr) = socket2::SockAddr::try_init(|addr_storage, len| { + *len = mem::size_of_val(&source) as _; + ptr::copy_nonoverlapping(&source, addr_storage as _, 1); + Ok(()) + })?; + addr.as_socket() }; - let (len, _flags, addr) = socket.0.recv_from_vectored(bufs)?; + + // Decode control messages (PKTINFO and ECN) + let mut ecn_bits = 0; + let mut dst_ip = None; + + let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) }; + for cmsg in cmsg_iter { + // [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding] + match (cmsg.cmsg_level, cmsg.cmsg_type) { + (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { + let pktinfo = + unsafe { cmsg::decode::(cmsg) }; + // Addr is stored in big endian format + let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr })); + dst_ip = Some(ip4.into()); + } + (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { + let pktinfo = + unsafe { cmsg::decode::(cmsg) }; + // Addr is stored in big endian format + dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte })); + } + (WinSock::IPPROTO_IP, WinSock::IP_ECN) => { + // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn + ecn_bits = unsafe { cmsg::decode::(cmsg) }; + } + (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { + // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn + ecn_bits = unsafe { cmsg::decode::(cmsg) }; + } + _ => {} + } + } + meta[0] = RecvMeta { - len, - stride: len, - addr: addr.as_socket().unwrap(), - ecn: None, - dst_ip: None, + len: len as usize, + stride: len as usize, + addr: addr.unwrap(), + ecn: EcnCodepoint::from_bits(ecn_bits as u8), + dst_ip, }; Ok(1) } @@ -142,6 +339,38 @@ impl UdpSocketState { pub(crate) const BATCH_SIZE: usize = 1; +fn get_wsarecvmsg_fn(socket: &impl AsRawSocket) -> WinSock::LPFN_WSARECVMSG { + // Detect if OS expose WSARecvMsg API based on + // https://github.com/Azure/mio-uds-windows/blob/a3c97df82018086add96d8821edb4aa85ec1b42b/src/stdnet/ext.rs#L601 + let guid = WinSock::WSAID_WSARECVMSG; + let mut wsa_recvmsg_ptr = None; + let mut len = 0; + + // Safety: Option handles the NULL pointer with a None value + let rc = unsafe { + WinSock::WSAIoctl( + socket.as_raw_socket() as _, + WinSock::SIO_GET_EXTENSION_FUNCTION_POINTER, + &guid as *const _ as *const _, + mem::size_of_val(&guid) as u32, + &mut wsa_recvmsg_ptr as *mut _ as *mut _, + mem::size_of_val(&wsa_recvmsg_ptr) as u32, + &mut len, + ptr::null_mut(), + None, + ) + }; + + if rc == -1 { + tracing::debug!("ignoring wsarecvmsg function pointer due to ioctl error"); + } else if len as usize != mem::size_of::() { + tracing::debug!("ignoring wsarecvmsg function pointer due to pointer size mismatch"); + wsa_recvmsg_ptr = None; + } + + wsa_recvmsg_ptr +} + fn set_socket_option( socket: &impl AsRawSocket, level: i32, diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs old mode 100644 new mode 100755 index 3436cb68d1..05c0ebf5e8 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -479,7 +479,10 @@ fn run_echo(args: EchoArgs) { // If `local_ip` gets available on additional platforms - which // requires modifying this test - please update the list of supported // platforms in the doc comments of the various `local_ip` functions. - if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") || cfg!(target_os = "macos") + if cfg!(target_os = "linux") + || cfg!(target_os = "freebsd") + || cfg!(target_os = "macos") + || cfg!(target_os = "windows") { let local_ip = incoming.local_ip().expect("Local IP must be available"); assert!(local_ip.is_loopback());