Skip to content

Commit

Permalink
Add support for sendmmsg(2) on linux
Browse files Browse the repository at this point in the history
  • Loading branch information
colinmarc committed Sep 18, 2024
1 parent 22e9043 commit 779b2b7
Show file tree
Hide file tree
Showing 5 changed files with 270 additions and 5 deletions.
21 changes: 21 additions & 0 deletions src/backend/libc/net/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ use super::msghdr::with_xdp_msghdr;
#[cfg(target_os = "linux")]
use super::write_sockaddr::encode_sockaddr_xdp;
use crate::backend::c;
#[cfg(target_os = "linux")]
use crate::backend::conv::ret_u32;
use crate::backend::conv::{borrowed_fd, ret, ret_owned_fd, ret_send_recv, send_recv_len};
use crate::fd::{BorrowedFd, OwnedFd};
use crate::io;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
#[cfg(target_os = "linux")]
use crate::net::MMsgHdr;
use crate::net::{SocketAddrAny, SocketAddrV4, SocketAddrV6};
use crate::utils::as_ptr;
use core::mem::{size_of, MaybeUninit};
Expand Down Expand Up @@ -455,6 +459,23 @@ pub(crate) fn sendmsg_xdp(
})
}

#[cfg(target_os = "linux")]
pub(crate) fn sendmmsg(
sockfd: BorrowedFd<'_>,
msgs: &mut [MMsgHdr<'_>],
flags: SendFlags,
) -> io::Result<usize> {
unsafe {
ret_u32(c::sendmmsg(
borrowed_fd(sockfd),
msgs.as_mut_ptr() as _,
msgs.len().try_into().unwrap_or(c::c_uint::MAX),
bitflags_bits!(flags),
))
.map(|ret| ret as usize)
}
}

#[cfg(not(any(
apple,
windows,
Expand Down
10 changes: 5 additions & 5 deletions src/backend/linux_raw/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ pub(crate) use linux_raw_sys::{
general::{O_CLOEXEC as SOCK_CLOEXEC, O_NONBLOCK as SOCK_NONBLOCK},
if_ether::*,
net::{
linger, msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet,
__kernel_sa_family_t as sa_family_t, __kernel_sockaddr_storage as sockaddr_storage,
cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, AF_APPLETALK,
AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, AF_ECONET,
AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, AF_LLC,
AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE,
cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, linger, mmsghdr,
msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet,
AF_APPLETALK, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN,
AF_ECONET, AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY,
AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE,
AF_RXRPC, AF_SECURITY, AF_SNA, AF_TIPC, AF_UNIX, AF_UNSPEC, AF_WANPIPE, AF_X25, AF_XDP,
IP6T_SO_ORIGINAL_DST, IPPROTO_FRAGMENT, IPPROTO_ICMPV6, IPPROTO_MH, IPPROTO_ROUTING,
IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_FREEBIND, IPV6_MULTICAST_HOPS,
Expand Down
28 changes: 28 additions & 0 deletions src/backend/linux_raw/net/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use super::send_recv::{RecvFlags, SendFlags};
use super::write_sockaddr::encode_sockaddr_xdp;
use super::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6};
use crate::backend::c;
#[cfg(target_os = "linux")]
use crate::backend::conv::slice_mut;
use crate::backend::conv::{
by_mut, by_ref, c_int, c_uint, pass_usize, ret, ret_owned_fd, ret_usize, size_of, slice,
socklen_t, zero,
Expand All @@ -24,6 +26,8 @@ use crate::fd::{BorrowedFd, OwnedFd};
use crate::io::{self, IoSlice, IoSliceMut};
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
#[cfg(target_os = "linux")]
use crate::net::MMsgHdr;
use crate::net::{
AddressFamily, Protocol, RecvAncillaryBuffer, RecvMsgReturn, SendAncillaryBuffer, Shutdown,
SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6, SocketFlags, SocketType,
Expand Down Expand Up @@ -439,6 +443,30 @@ pub(crate) fn sendmsg_xdp(
})
}

#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn sendmmsg(
sockfd: BorrowedFd<'_>,
msgs: &mut [MMsgHdr<'_>],
flags: SendFlags,
) -> io::Result<usize> {
let (msgs, len) = slice_mut(msgs);

#[cfg(not(target_arch = "x86"))]
let result = unsafe { ret_usize(syscall!(__NR_sendmmsg, sockfd, msgs, len, flags)) };

#[cfg(target_arch = "x86")]
let result = unsafe {
ret_usize(syscall!(
__NR_socketcall,
x86_sys(SYS_SENDMMSG),
slice_just_addr::<ArgReg<'_, SocketArg>, _>(&[sockfd.into(), msgs, len, flags.into()])
))
};

result
}

#[inline]
pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> {
#[cfg(not(target_arch = "x86"))]
Expand Down
110 changes: 110 additions & 0 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

#![allow(unsafe_code)]

#[cfg(target_os = "linux")]
use crate::backend::net::msghdr::{
with_noaddr_msghdr, with_unix_msghdr, with_v4_msghdr, with_v6_msghdr, with_xdp_msghdr,
};
use crate::backend::{self, c};
use crate::fd::{AsFd, BorrowedFd, OwnedFd};
use crate::io::{self, IoSlice, IoSliceMut};
#[cfg(linux_kernel)]
use crate::net::UCred;
#[cfg(target_os = "linux")]
use crate::net::{xdp::SocketAddrXdp, SocketAddrUnix};

use core::iter::FusedIterator;
use core::marker::PhantomData;
Expand Down Expand Up @@ -591,6 +597,94 @@ impl<'buf> Iterator for AncillaryDrain<'buf> {

impl FusedIterator for AncillaryDrain<'_> {}

/// An ABI-compatible wrapper for `mmsghdr`, for sending multiple messages with
/// [sendmmsg].
#[cfg(target_os = "linux")]
#[repr(transparent)]
pub struct MMsgHdr<'a> {
raw: c::mmsghdr,
_phantom: PhantomData<&'a mut ()>,
}

#[cfg(target_os = "linux")]
impl<'a> MMsgHdr<'a> {
/// Constructs a new message with no destination address.
pub fn new(iov: &[IoSlice<'a>], control: &mut SendAncillaryBuffer<'_, '_, '_>) -> Self {
with_noaddr_msghdr(iov, control, |msg_hdr| Self {
raw: c::mmsghdr {
msg_hdr,
msg_len: 0,
},
_phantom: PhantomData,
})
}

/// Constructs a new message to a specific IPv4 address.
pub fn new_v4(
addr: &SocketAddrV4,
iov: &[IoSlice<'a>],
control: &mut SendAncillaryBuffer<'_, '_, '_>,
) -> Self {
with_v4_msghdr(addr, iov, control, |msg_hdr| Self {
raw: c::mmsghdr {
msg_hdr,
msg_len: 0,
},
_phantom: PhantomData,
})
}

/// Constructs a new message to a specific IPv6 address.
pub fn new_v6(
addr: &SocketAddrV6,
iov: &[IoSlice<'a>],
control: &mut SendAncillaryBuffer<'_, '_, '_>,
) -> Self {
with_v6_msghdr(addr, iov, control, |msg_hdr| Self {
raw: c::mmsghdr {
msg_hdr,
msg_len: 0,
},
_phantom: PhantomData,
})
}

/// Constructs a new message to a specific Unix-domain address.
pub fn new_unix(
addr: &SocketAddrUnix,
iov: &[IoSlice<'a>],
control: &mut SendAncillaryBuffer<'_, '_, '_>,
) -> Self {
with_unix_msghdr(addr, iov, control, |msg_hdr| Self {
raw: c::mmsghdr {
msg_hdr,
msg_len: 0,
},
_phantom: PhantomData,
})
}

/// Constructs a new message to a specific XDP address.
pub fn new_xdp(
addr: &SocketAddrXdp,
iov: &[IoSlice<'a>],
control: &mut SendAncillaryBuffer<'_, '_, '_>,
) -> Self {
with_xdp_msghdr(addr, iov, control, |msg_hdr| Self {
raw: c::mmsghdr {
msg_hdr,
msg_len: 0,
},
_phantom: PhantomData,
})
}

/// Returns the number of bytes sent by [sendmmsg].
pub fn bytes(&self) -> usize {
self.raw.msg_len as _
}
}

/// `sendmsg(msghdr)`—Sends a message on a socket.
///
/// # References
Expand Down Expand Up @@ -781,6 +875,22 @@ pub fn sendmsg_any(
}
}

/// `sendmmsg(msghdr)`—Sends multiple messages on a socket.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/sendmmsg.2.html
#[inline]
#[cfg(target_os = "linux")]
pub fn sendmmsg(
socket: impl AsFd,
msgs: &mut [MMsgHdr<'_>],
flags: SendFlags,
) -> io::Result<usize> {
backend::net::syscalls::sendmmsg(socket.as_fd(), msgs, flags)
}

/// `recvmsg(msghdr)`—Receives a message from a socket.
///
/// # References
Expand Down
106 changes: 106 additions & 0 deletions tests/net/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#![cfg(not(any(target_os = "redox", target_os = "wasi")))]

#[cfg(target_os = "linux")]
use rustix::net::MMsgHdr;
use rustix::net::{
accept, bind_v4, connect_v4, getsockname, listen, recv, send, socket, AddressFamily, Ipv4Addr,
RecvFlags, SendFlags, SocketAddrAny, SocketAddrV4, SocketType,
Expand Down Expand Up @@ -194,3 +196,107 @@ fn test_v4_msg() {
client.join().unwrap();
server.join().unwrap();
}

#[test]
#[cfg(target_os = "linux")]
fn test_v4_sendmmsg() {
crate::init();

use rustix::io::{IoSlice, IoSliceMut};
use rustix::net::{recvmsg, sendmmsg};

fn server(ready: Arc<(Mutex<u16>, Condvar)>) {
let connection_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap();

let name = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0);
bind_v4(&connection_socket, &name).unwrap();

let who = match getsockname(&connection_socket).unwrap() {
SocketAddrAny::V4(addr) => addr,
_ => panic!(),
};

listen(&connection_socket, 1).unwrap();

{
let (lock, cvar) = &*ready;
let mut port = lock.lock().unwrap();
*port = who.port();
cvar.notify_all();
}

let mut buffer = vec![0; BUFFER_SIZE];
let data_socket = accept(&connection_socket).unwrap();

let res = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
&mut Default::default(),
RecvFlags::empty(),
)
.unwrap();
assert_eq!(String::from_utf8_lossy(&buffer[..res.bytes]), "hello");

let res = recvmsg(
&data_socket,
&mut [IoSliceMut::new(&mut buffer)],
&mut Default::default(),
RecvFlags::empty(),
)
.unwrap();
assert_eq!(String::from_utf8_lossy(&buffer[..res.bytes]), "...world");
}

fn client(ready: Arc<(Mutex<u16>, Condvar)>) {
let port = {
let (lock, cvar) = &*ready;
let mut port = lock.lock().unwrap();
while *port == 0 {
port = cvar.wait(port).unwrap();
}
*port
};

let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port);
let data_socket = socket(AddressFamily::INET, SocketType::STREAM, None).unwrap();
connect_v4(&data_socket, &addr).unwrap();

let mut off = 0;
loop {
let sent = sendmmsg(
&data_socket,
&mut [
MMsgHdr::new(&[IoSlice::new(b"hello")], &mut Default::default()),
MMsgHdr::new(&[IoSlice::new(b"...world")], &mut Default::default()),
][off..],
SendFlags::empty(),
)
.unwrap();

off += sent;
if off >= 2 {
break;
}
}
}

let ready = Arc::new((Mutex::new(0_u16), Condvar::new()));
let ready_clone = Arc::clone(&ready);

let server = thread::Builder::new()
.name("server".to_string())
.spawn(move || {
server(ready);
})
.unwrap();

let client = thread::Builder::new()
.name("client".to_string())
.spawn(move || {
client(ready_clone);
})
.unwrap();

client.join().unwrap();
server.join().unwrap();
}

0 comments on commit 779b2b7

Please sign in to comment.