Skip to content

Commit

Permalink
ipc_test: some more error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sk1p committed Sep 25, 2024
1 parent 2a95406 commit 89ed2e1
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions ipc_test/src/backend_memfd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,48 @@ use serde::{de::DeserializeOwned, Serialize};

use crate::common::ShmConnectError;

fn read_size(mut stream: &UnixStream) -> usize {
fn read_size(mut stream: &UnixStream) -> Result<usize, ShmConnectError> {
let mut buf: [u8; std::mem::size_of::<usize>()] = [0; std::mem::size_of::<usize>()];
stream.read_exact(&mut buf).expect("read message size");
usize::from_be_bytes(buf)
stream.read_exact(&mut buf)?;
Ok(usize::from_be_bytes(buf))
}

/// connect to the given unix domain socket and grab a SHM handle
fn recv_shm_handle<H>(socket_path: &Path) -> (H, RawFd)
fn recv_shm_handle<H>(socket_path: &Path) -> Result<(H, RawFd), ShmConnectError>
where
H: DeserializeOwned,
{
let stream = UnixStream::connect(socket_path).expect("connect to socket");
let stream = UnixStream::connect(socket_path)?;

let mut fds: [i32; 1] = [0];

let size = read_size(&stream);
let size = read_size(&stream)?;

// message must be longer than 0:
assert!(size > 0);

let mut bytes: Vec<u8> = vec![0; size];

stream
.recv_with_fd(bytes.as_mut_slice(), &mut fds)
.expect("read initial message with fds");
stream.recv_with_fd(bytes.as_mut_slice(), &mut fds)?;

let payload: H = bincode::deserialize(&bytes[..]).expect("deserialize");
(payload, fds[0])
let payload: H = bincode::deserialize(&bytes[..])?;
Ok((payload, fds[0]))
}

fn handle_connection(mut stream: UnixStream, fd: RawFd, init_data_serialized: &[u8]) {
fn handle_connection(
mut stream: UnixStream,
fd: RawFd,
init_data_serialized: &[u8],
) -> Result<(), ShmConnectError> {
let fds = [fd];

// message must not be empty:
assert!(!init_data_serialized.is_empty());

stream
.write_all(&init_data_serialized.len().to_be_bytes())
.expect("send shm info size");
stream
.send_with_fd(init_data_serialized, &fds)
.expect("send shm info with fds");
stream.write_all(&init_data_serialized.len().to_be_bytes())?;
stream.send_with_fd(init_data_serialized, &fds)?;

Ok(())
}

/// start a thread that serves shm handles at the given socket path
Expand Down Expand Up @@ -109,7 +109,10 @@ where
Ok((stream, _addr)) => {
/* connection succeeded */
let my_init = init_data_serialized.clone();
std::thread::spawn(move || handle_connection(stream, fd, &my_init));
std::thread::spawn(move || {
handle_connection(stream, fd, &my_init)
.expect("could not let other side connect")
});
}
Err(err) => {
/* EAGAIN / EWOULDBLOCK */
Expand Down Expand Up @@ -207,7 +210,7 @@ impl MemfdShm {
I: DeserializeOwned,
{
let socket_path = Path::new(handle);
let (init_data, fd) = recv_shm_handle::<I>(socket_path);
let (init_data, fd) = recv_shm_handle::<I>(socket_path)?;

// safety: we exlusively own the fd, which we just received via
// the unix domain socket, so it must be open and valid.
Expand Down

0 comments on commit 89ed2e1

Please sign in to comment.