diff --git a/quinn-udp/src/lib.rs b/quinn-udp/src/lib.rs index dc285cd583..4f6ad7b523 100644 --- a/quinn-udp/src/lib.rs +++ b/quinn-udp/src/lib.rs @@ -21,6 +21,7 @@ mod cmsg; #[path = "unix.rs"] mod imp; +// FIXME rename and add CmsgHelper in unix.rs to factorize #[cfg(windows)] #[path = "wsa_cmsg.rs"] mod cmsg; diff --git a/quinn-udp/src/windows.rs b/quinn-udp/src/windows.rs index d6135262fc..343cc5a2a9 100644 --- a/quinn-udp/src/windows.rs +++ b/quinn-udp/src/windows.rs @@ -151,7 +151,7 @@ fn send( for transmit in transmits { // we cannot use [`socket2::sendmsg()`] and [`socket2::MsgHdr`] as we do not have access // to the inner field which holds the WSAMSG - let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]); + let mut ctrl_buf = Aligned([0; CMSG_LEN]); let daddr = socket2::SockAddr::from(transmit.destination); let mut data = WinSock::WSABUF { @@ -173,8 +173,10 @@ fn send( dwFlags: 0, }; + let mut helper = unsafe { CmsgHelper::new(&mut wsa_msg) }; + // Add control messages (ECN and PKTINFO) - let mut encoder = unsafe { cmsg::Encoder::new(&mut wsa_msg) }; + let mut encoder = unsafe { cmsg::Encoder::new(&mut helper) }; if let Some(ip) = transmit.src_ip { let ip = std::net::SocketAddr::new(ip, 0); @@ -253,7 +255,7 @@ fn recv( .expect("Valid function pointer for WSARecvMsg"); // we cannot use [`socket2::MsgHdrMut`] as we do not have access to inner field which holds the WSAMSG - let mut ctrl_buf = cmsg::Aligned([0; CMSG_LEN]); + let mut ctrl_buf = Aligned([0; CMSG_LEN]); let mut source: WinSock::SOCKADDR_INET = unsafe { mem::zeroed() }; let mut data = WinSock::WSABUF { buf: bufs[0].as_mut_ptr(), @@ -303,27 +305,28 @@ fn recv( let mut ecn_bits = 0; let mut dst_ip = None; - let cmsg_iter = unsafe { cmsg::Iter::new(&wsa_msg) }; - for cmsg in cmsg_iter { + let helper = unsafe { CmsgHelper::new(&mut wsa_msg) }; + for cmsg in helper { + // [header (len)][data][padding(len + sizeof(data))] -> [header][data][padding] match (cmsg.cmsg_level, cmsg.cmsg_type) { (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { - let pktinfo = unsafe { cmsg::decode::(cmsg) }; + let pktinfo = unsafe { CmsgHelper::decode::(cmsg) }; // Addr is stored in big endian format let ip4 = Ipv4Addr::from(u32::from_be(unsafe { pktinfo.ipi_addr.S_un.S_addr })); dst_ip = Some(ip4.into()); } (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { - let pktinfo = unsafe { cmsg::decode::(cmsg) }; + let pktinfo = unsafe { CmsgHelper::decode::(cmsg) }; // Addr is stored in big endian format dst_ip = Some(IpAddr::from(unsafe { pktinfo.ipi6_addr.u.Byte })); } (WinSock::IPPROTO_IP, WinSock::IP_ECN) => { // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn - ecn_bits = unsafe { cmsg::decode::(cmsg) }; + ecn_bits = unsafe { CmsgHelper::decode::(cmsg) }; } (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { // ECN is a C integer https://learn.microsoft.com/en-us/windows/win32/winsock/winsock-ecn - ecn_bits = unsafe { cmsg::decode::(cmsg) }; + ecn_bits = unsafe { CmsgHelper::decode::(cmsg) }; } _ => {} } @@ -396,3 +399,119 @@ fn set_socket_option( } const OPTION_ON: u32 = 1; + +#[derive(Copy, Clone)] +#[repr(align(8))] // Conservative bound for align_of +struct Aligned(pub(crate) T); + +/// Cmsg Helper wrapping [`WinSock::WSAMSG`] and [`WinSock::CMSGHDR`] +pub(crate) struct CmsgHelper<'a> { + hdr: &'a mut WinSock::WSAMSG, + cmsg: Option<&'a mut WinSock::CMSGHDR>, +} + +impl<'a> CmsgHelper<'a> { + /// # Safety + /// - `hdr.Control.buf` must be a suitably aligned pointer to `hdr.Control.len` bytes that + /// can be safely written + pub(crate) unsafe fn new(hdr: &'a mut WinSock::WSAMSG) -> Self { + Self { + cmsg: Self::cmsg_firsthdr(hdr).as_mut(), + hdr, + } + } + + pub(crate) fn control_len(&self) -> usize { + self.hdr.Control.len as usize + } + + pub(crate) fn set_control_len(&mut self, len: usize) { + self.hdr.Control.len = len as _; + } + + pub(crate) fn cmsg_take(&mut self) -> Option<&'a mut WinSock::CMSGHDR> { + self.cmsg.take() + } + + pub(crate) fn set_cmsg(&mut self, cmsg: Option<&'a mut WinSock::CMSGHDR>) { + self.cmsg = cmsg; + } + + pub(crate) fn cmsghdr_align_of() -> usize { + mem::align_of::() + } + + /// # Safety + /// + /// `cmsg` must refer to a [`WinSock::CMSGHDR`] containing a payload of type `T` + unsafe fn decode(cmsg: &WinSock::CMSGHDR) -> T { + assert!(mem::align_of::() <= mem::align_of::()); + debug_assert_eq!( + cmsg.cmsg_len, + CmsgHelper::cmsg_len(mem::size_of::() as _) + ); + ptr::read(CmsgHelper::cmsg_data(cmsg) as *const T) + } + + // Helpers functions based on C macros from + // https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741 + pub(crate) fn cmsghdr_align(length: usize) -> usize { + (length + mem::align_of::() - 1) + & !(mem::align_of::() - 1) + } + + pub(crate) fn cmsgdata_align(length: usize) -> usize { + (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) + } + + pub(crate) unsafe fn cmsg_firsthdr(msg: *const WinSock::WSAMSG) -> *mut WinSock::CMSGHDR { + if (*msg).Control.len as usize >= mem::size_of::() { + (*msg).Control.buf as *mut WinSock::CMSGHDR + } else { + ptr::null_mut::() + } + } + + pub(crate) unsafe fn cmsg_nxthdr( + &self, + cmsg: *const WinSock::CMSGHDR, + ) -> *mut WinSock::CMSGHDR { + if cmsg.is_null() { + return Self::cmsg_firsthdr(self.hdr); + } + let next = (cmsg as usize + Self::cmsghdr_align((*cmsg).cmsg_len)) as *mut WinSock::CMSGHDR; + let max = self.hdr.Control.buf as usize + self.hdr.Control.len as usize; + if (next.offset(1)) as usize > max { + ptr::null_mut() + } else { + next + } + } + + pub(crate) unsafe fn cmsg_data(cmsg: *const WinSock::CMSGHDR) -> *mut u8 { + (cmsg as usize + Self::cmsgdata_align(mem::size_of::())) as *mut u8 + } + + pub(crate) fn cmsg_space(length: usize) -> usize { + Self::cmsgdata_align(mem::size_of::() + Self::cmsghdr_align(length)) + } + + pub(crate) fn cmsg_len(length: usize) -> usize { + Self::cmsgdata_align(mem::size_of::()) + length + } +} + +impl<'a> Iterator for CmsgHelper<'a> { + type Item = &'a WinSock::CMSGHDR; + + /// # Safety + /// + /// `self.hdr.Control.buf` must point to memory outliving `'a` which can be soundly read for the + /// lifetime of the constructed `Iter` and contains a buffer of [`WinSock::CMSGHDR`], i.e. + /// is aligned for [`WinSock::CMSGHDR`], is fully initialized, and has correct internal links. + fn next(&mut self) -> Option<&'a WinSock::CMSGHDR> { + let current = self.cmsg.take()?; + self.cmsg = unsafe { CmsgHelper::cmsg_nxthdr(self, current).as_mut() }; + Some(current) + } +} diff --git a/quinn-udp/src/wsa_cmsg.rs b/quinn-udp/src/wsa_cmsg.rs index 156ddcc7ea..4bf2375e16 100644 --- a/quinn-udp/src/wsa_cmsg.rs +++ b/quinn-udp/src/wsa_cmsg.rs @@ -1,103 +1,54 @@ -use std::{mem, ptr}; +use std::{ffi::c_int, mem, ptr}; -use windows_sys::Win32::Networking::WinSock; +use crate::imp::CmsgHelper; -#[derive(Copy, Clone)] -#[repr(align(8))] // Conservative bound for align_of -pub(crate) struct Aligned(pub(crate) T); - -// Helpers functions based on C macros from -// https://github.com/microsoft/win32metadata/blob/main/generation/WinSDK/RecompiledIdlHeaders/shared/ws2def.h#L741 -fn wsa_cmsghdr_align(length: usize) -> usize { - (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) -} - -fn wsa_cmsgdata_align(length: usize) -> usize { - (length + mem::align_of::() - 1) & !(mem::align_of::() - 1) -} - -unsafe fn wsa_cmsg_firsthdr(msg: *const WinSock::WSAMSG) -> *mut WinSock::CMSGHDR { - if (*msg).Control.len as usize >= mem::size_of::() { - (*msg).Control.buf as *mut WinSock::CMSGHDR - } else { - ptr::null_mut::() - } -} - -unsafe fn wsa_cmsg_nxthdr( - msg: *const WinSock::WSAMSG, - cmsg: *const WinSock::CMSGHDR, -) -> *mut WinSock::CMSGHDR { - if cmsg.is_null() { - return wsa_cmsg_firsthdr(msg); - } - let next = (cmsg as usize + wsa_cmsghdr_align((*cmsg).cmsg_len)) as *mut WinSock::CMSGHDR; - let max = (*msg).Control.buf as usize + (*msg).Control.len as usize; - if (next.offset(1)) as usize > max { - ptr::null_mut() - } else { - next - } -} - -unsafe fn wsa_cmsg_data(cmsg: *const WinSock::CMSGHDR) -> *mut u8 { - (cmsg as usize + wsa_cmsgdata_align(mem::size_of::())) as *mut u8 -} - -fn wsa_cmsg_space(length: usize) -> usize { - wsa_cmsgdata_align(mem::size_of::() + wsa_cmsghdr_align(length)) -} - -fn wsa_cmsg_len(length: usize) -> usize { - wsa_cmsgdata_align(mem::size_of::()) + length -} - -/// Helper to encode a series of control messages ("cmsgs") to a buffer for use in `WSASendMsg`. +/// Helper to encode a series of control messages ("cmsgs") to a buffer for use in a `sendmsg`` like function /// -/// The operation must be "finished" for the `WSAMSG`` to be usable, either by calling `finish` +/// The operation must be "finished" for the message to be usable, either by calling `finish` /// explicitly or by dropping the `Encoder`. pub(crate) struct Encoder<'a> { - hdr: &'a mut WinSock::WSAMSG, - cmsg: Option<&'a mut WinSock::CMSGHDR>, + helper: &'a mut CmsgHelper<'a>, len: usize, } impl<'a> Encoder<'a> { /// # Safety - /// - `hdr.Control.buf` must be a suitably aligned pointer to `hdr.Control.len` bytes that - /// can be safely written - /// - The `Encoder` must be dropped before `hdr` is passed to a system call, and must not be leaked. - pub(crate) unsafe fn new(hdr: &'a mut WinSock::WSAMSG) -> Self { - Self { - cmsg: wsa_cmsg_firsthdr(hdr).as_mut(), - hdr, - len: 0, - } + /// - The `CmsgHelper` handles all the alignement constraints + /// - The `Encoder` must be dropped before the native build message is passed to a system call, + /// and must not be leaked. + pub(crate) unsafe fn new(helper: &'a mut CmsgHelper<'a>) -> Self { + Self { helper, len: 0 } } - /// Append a control message ([`WinSock::CMSGHDR`]) to the buffer. + /// Append a native control message to the buffer. /// /// # Panics /// - If insufficient buffer space remains. - /// - If `T` has stricter alignment requirements than `cmsghdr` - pub(crate) fn push(&mut self, level: i32, ty: i32, value: T) { - assert!(mem::align_of::() <= mem::align_of::()); - let space = wsa_cmsg_space(mem::size_of_val(&value) as _); + /// - If `T` has stricter alignment requirements than the native type + /// - level and type fields of cmsg must of a type compatible with [`std::ffi::c_int`]` + pub(crate) fn push(&mut self, level: c_int, ty: c_int, value: T) { + assert!(mem::align_of::() <= CmsgHelper::cmsghdr_align_of()); + let space = CmsgHelper::cmsg_space(mem::size_of_val(&value) as _); assert!( - self.hdr.Control.len as usize >= self.len + space, + self.helper.control_len() >= self.len + space, "control message buffer too small. Required: {}, Available: {}", self.len + space, - self.hdr.Control.len + self.helper.control_len() ); - let cmsg = self.cmsg.take().expect("no control buffer space remaining"); + let cmsg = self + .helper + .cmsg_take() + .expect("no control buffer space remaining"); cmsg.cmsg_level = level; cmsg.cmsg_type = ty; - cmsg.cmsg_len = wsa_cmsg_len(mem::size_of_val(&value) as _) as _; + cmsg.cmsg_len = CmsgHelper::cmsg_len(mem::size_of_val(&value) as _) as _; unsafe { - ptr::write(wsa_cmsg_data(cmsg) as *const T as *mut T, value); + ptr::write(CmsgHelper::cmsg_data(cmsg) as *const T as *mut T, value); } self.len += space; - self.cmsg = unsafe { wsa_cmsg_nxthdr(self.hdr, cmsg).as_mut() }; + + self.helper + .set_cmsg(unsafe { CmsgHelper::cmsg_nxthdr(self.helper, cmsg).as_mut() }); } /// Finishes appending control messages to the buffer @@ -107,46 +58,9 @@ impl<'a> Encoder<'a> { } // Statically guarantees that the encoding operation is "finished" before the control buffer is read -// by `WSASendMsg`. +// by sendmsg like functions. impl<'a> Drop for Encoder<'a> { fn drop(&mut self) { - self.hdr.Control.len = self.len as _; - } -} - -/// # Safety -/// -/// `cmsg` must refer to a [`WinSock::CMSGHDR`] containing a payload of type `T` -pub(crate) unsafe fn decode(cmsg: &WinSock::CMSGHDR) -> T { - assert!(mem::align_of::() <= mem::align_of::()); - debug_assert_eq!(cmsg.cmsg_len, wsa_cmsg_len(mem::size_of::() as _)); - ptr::read(wsa_cmsg_data(cmsg) as *const T) -} - -pub(crate) struct Iter<'a> { - hdr: &'a WinSock::WSAMSG, - cmsg: Option<&'a WinSock::CMSGHDR>, -} - -impl<'a> Iter<'a> { - /// # Safety - /// - /// `hdr.Control.buf` must point to memory outliving `'a` which can be soundly read for the - /// lifetime of the constructed `Iter` and contains a buffer of [`WinSock::CMSGHDR`], i.e. - /// is aligned for [`WinSock::CMSGHDR`], is fully initialized, and has correct internal links. - pub(crate) unsafe fn new(hdr: &'a WinSock::WSAMSG) -> Self { - Self { - hdr, - cmsg: wsa_cmsg_firsthdr(hdr).as_ref(), - } - } -} - -impl<'a> Iterator for Iter<'a> { - type Item = &'a WinSock::CMSGHDR; - fn next(&mut self) -> Option<&'a WinSock::CMSGHDR> { - let current = self.cmsg.take()?; - self.cmsg = unsafe { wsa_cmsg_nxthdr(self.hdr, current).as_ref() }; - Some(current) + self.helper.set_control_len(self.len); } }