From 89ed2e140fe521b938449e7854c06e3e9338d0d1 Mon Sep 17 00:00:00 2001 From: Alexander Clausen Date: Wed, 25 Sep 2024 18:02:55 +0200 Subject: [PATCH] ipc_test: some more error handling --- ipc_test/src/backend_memfd.rs | 43 +++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ipc_test/src/backend_memfd.rs b/ipc_test/src/backend_memfd.rs index 98f33737..5b4ab9bb 100644 --- a/ipc_test/src/backend_memfd.rs +++ b/ipc_test/src/backend_memfd.rs @@ -25,48 +25,48 @@ use serde::{de::DeserializeOwned, Serialize}; use crate::common::ShmConnectError; -fn read_size(mut stream: &UnixStream) -> usize { +fn read_size(mut stream: &UnixStream) -> Result { let mut buf: [u8; std::mem::size_of::()] = [0; std::mem::size_of::()]; - stream.read_exact(&mut buf).expect("read message size"); - usize::from_be_bytes(buf) + stream.read_exact(&mut buf)?; + Ok(usize::from_be_bytes(buf)) } /// connect to the given unix domain socket and grab a SHM handle -fn recv_shm_handle(socket_path: &Path) -> (H, RawFd) +fn recv_shm_handle(socket_path: &Path) -> Result<(H, RawFd), ShmConnectError> where H: DeserializeOwned, { - let stream = UnixStream::connect(socket_path).expect("connect to socket"); + let stream = UnixStream::connect(socket_path)?; let mut fds: [i32; 1] = [0]; - let size = read_size(&stream); + let size = read_size(&stream)?; // message must be longer than 0: assert!(size > 0); let mut bytes: Vec = vec![0; size]; - stream - .recv_with_fd(bytes.as_mut_slice(), &mut fds) - .expect("read initial message with fds"); + stream.recv_with_fd(bytes.as_mut_slice(), &mut fds)?; - let payload: H = bincode::deserialize(&bytes[..]).expect("deserialize"); - (payload, fds[0]) + let payload: H = bincode::deserialize(&bytes[..])?; + Ok((payload, fds[0])) } -fn handle_connection(mut stream: UnixStream, fd: RawFd, init_data_serialized: &[u8]) { +fn handle_connection( + mut stream: UnixStream, + fd: RawFd, + init_data_serialized: &[u8], +) -> Result<(), ShmConnectError> { let fds = [fd]; // message must not be empty: assert!(!init_data_serialized.is_empty()); - stream - .write_all(&init_data_serialized.len().to_be_bytes()) - .expect("send shm info size"); - stream - .send_with_fd(init_data_serialized, &fds) - .expect("send shm info with fds"); + stream.write_all(&init_data_serialized.len().to_be_bytes())?; + stream.send_with_fd(init_data_serialized, &fds)?; + + Ok(()) } /// start a thread that serves shm handles at the given socket path @@ -109,7 +109,10 @@ where Ok((stream, _addr)) => { /* connection succeeded */ let my_init = init_data_serialized.clone(); - std::thread::spawn(move || handle_connection(stream, fd, &my_init)); + std::thread::spawn(move || { + handle_connection(stream, fd, &my_init) + .expect("could not let other side connect") + }); } Err(err) => { /* EAGAIN / EWOULDBLOCK */ @@ -207,7 +210,7 @@ impl MemfdShm { I: DeserializeOwned, { let socket_path = Path::new(handle); - let (init_data, fd) = recv_shm_handle::(socket_path); + let (init_data, fd) = recv_shm_handle::(socket_path)?; // safety: we exlusively own the fd, which we just received via // the unix domain socket, so it must be open and valid.