Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IPC: better error handling #74

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 55 additions & 45 deletions ipc_test/src/backend_memfd.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![forbid(clippy::unwrap_used)]
//! Raw memory backend using memfd with huge page support
// TODO:
// #![forbid(clippy::unwrap_used)]
use std::{
fs::{remove_file, File},
io::{self, Read, Write},
Expand All @@ -24,56 +23,58 @@ use nix::poll::{PollFd, PollFlags};
use sendfd::{RecvWithFd, SendWithFd};
use serde::{de::DeserializeOwned, Serialize};

fn read_size(mut stream: &UnixStream) -> usize {
use crate::common::ShmConnectError;

fn read_size(mut stream: &UnixStream) -> Result<usize, ShmConnectError> {
let mut buf: [u8; std::mem::size_of::<usize>()] = [0; std::mem::size_of::<usize>()];
stream.read_exact(&mut buf).expect("read message size");
usize::from_be_bytes(buf)
stream.read_exact(&mut buf)?;
Ok(usize::from_be_bytes(buf))
}

/// connect to the given unix domain socket and grab a SHM handle
fn recv_shm_handle<H>(socket_path: &Path) -> (H, RawFd)
fn recv_shm_handle<H>(socket_path: &Path) -> Result<(H, RawFd), ShmConnectError>
where
H: DeserializeOwned,
{
let stream = UnixStream::connect(socket_path).expect("connect to socket");
let stream = UnixStream::connect(socket_path)?;

let mut fds: [i32; 1] = [0];

let size = read_size(&stream);
let size = read_size(&stream)?;

// message must be longer than 0:
assert!(size > 0);

let mut bytes: Vec<u8> = vec![0; size];

stream
.recv_with_fd(bytes.as_mut_slice(), &mut fds)
.expect("read initial message with fds");
stream.recv_with_fd(bytes.as_mut_slice(), &mut fds)?;

let payload: H = bincode::deserialize(&bytes[..]).expect("deserialize");
(payload, fds[0])
let payload: H = bincode::deserialize(&bytes[..])?;
Ok((payload, fds[0]))
}

fn handle_connection(mut stream: UnixStream, fd: RawFd, init_data_serialized: &[u8]) {
fn handle_connection(
mut stream: UnixStream,
fd: RawFd,
init_data_serialized: &[u8],
) -> Result<(), ShmConnectError> {
let fds = [fd];

// message must not be empty:
assert!(!init_data_serialized.is_empty());

stream
.write_all(&init_data_serialized.len().to_be_bytes())
.expect("send shm info size");
stream
.send_with_fd(init_data_serialized, &fds)
.expect("send shm info with fds");
stream.write_all(&init_data_serialized.len().to_be_bytes())?;
stream.send_with_fd(init_data_serialized, &fds)?;

Ok(())
}

/// start a thread that serves shm handles at the given socket path
pub fn serve_shm_handle<I>(
init_data: I,
fd: RawFd,
socket_path: &Path,
) -> (Arc<AtomicBool>, JoinHandle<()>)
) -> Result<(Arc<AtomicBool>, JoinHandle<()>), ShmConnectError>
where
I: Serialize,
{
Expand All @@ -83,11 +84,9 @@ where
remove_file(socket_path).expect("remove existing socket");
}

let listener = UnixListener::bind(socket_path).unwrap();

let listener = UnixListener::bind(socket_path)?;
let outer_stop = Arc::clone(&stop_event);

let init_data_serialized = bincode::serialize(&init_data).unwrap();
let init_data_serialized = bincode::serialize(&init_data)?;

listener
.set_nonblocking(true)
Expand All @@ -96,7 +95,6 @@ where
let join_handle = std::thread::spawn(move || {
// Stolen from the example on `UnixListener`:
// accept connections and process them, spawning a new thread for each one

loop {
if stop_event.load(Ordering::Relaxed) {
debug!("stopping `serve_shm_handle` thread");
Expand All @@ -107,15 +105,19 @@ where
Ok((stream, _addr)) => {
/* connection succeeded */
let my_init = init_data_serialized.clone();
std::thread::spawn(move || handle_connection(stream, fd, &my_init));
std::thread::spawn(move || {
handle_connection(stream, fd, &my_init)
.expect("could not let other side connect")
});
}
Err(err) => {
/* EAGAIN / EWOULDBLOCK */
if err.kind() == io::ErrorKind::WouldBlock {
let flags = PollFlags::POLLIN;
let pollfd = PollFd::new(listener.as_fd(), flags);
nix::poll::poll(&mut [pollfd], 100u16)
.expect("poll for socket to be ready");
if let Err(e) = nix::poll::poll(&mut [pollfd], 100u16) {
log::error!("poll failed: {e}");
}
continue;
}
/* connection failed */
Expand All @@ -126,7 +128,7 @@ where
}
});

(outer_stop, join_handle)
Ok((outer_stop, join_handle))
}

pub struct MemfdShm {
Expand All @@ -150,7 +152,12 @@ impl MemfdShm {
/// If `enable_huge` is specified and not enough huge pages are available
/// from the operating system, mapping the memory area can fail.
///
pub fn new<I>(enable_huge: bool, socket_path: &Path, size: usize, init_data: I) -> Self
pub fn new<I>(
enable_huge: bool,
socket_path: &Path,
size: usize,
init_data: I,
) -> Result<Self, ShmConnectError>
where
I: Serialize,
{
Expand All @@ -160,27 +167,30 @@ impl MemfdShm {
} else {
memfd_options
};
let memfd = memfd_options.create("MemfdShm").unwrap();
let memfd = memfd_options
.create("MemfdShm")
.map_err(|e| ShmConnectError::Other { msg: e.to_string() })?;
let file = memfd.as_file();
file.set_len(size as u64).unwrap();
file.set_len(size as u64)?;

memfd
.add_seals(&[FileSeal::SealShrink, FileSeal::SealGrow])
.unwrap();

memfd.add_seal(FileSeal::SealSeal).unwrap();
.map_err(|e| ShmConnectError::Other { msg: e.to_string() })?;
memfd
.add_seal(FileSeal::SealSeal)
.map_err(|e| ShmConnectError::Other { msg: e.to_string() })?;

let file = memfd.into_file();
let mmap = MmapOptions::new().map_raw(&file).unwrap();
let mmap = MmapOptions::new().map_raw(&file)?;

let bg_thread = serve_shm_handle(&init_data, file.as_raw_fd(), socket_path);
let bg_thread = serve_shm_handle(&init_data, file.as_raw_fd(), socket_path)?;

Self {
Ok(Self {
mmap,
file,
socket_path: socket_path.to_owned(),
bg_thread: Some(bg_thread),
}
})
}

pub fn as_mut_ptr(&self) -> *mut u8 {
Expand All @@ -191,27 +201,27 @@ impl MemfdShm {
self.socket_path.to_string_lossy().deref().to_owned()
}

pub fn connect<I>(handle: &str) -> (Self, I)
pub fn connect<I>(handle: &str) -> Result<(Self, I), ShmConnectError>
where
I: DeserializeOwned,
{
let socket_path = Path::new(handle);
let (init_data, fd) = recv_shm_handle::<I>(socket_path);
let (init_data, fd) = recv_shm_handle::<I>(socket_path)?;

// safety: we exlusively own the fd, which we just received via
// the unix domain socket, so it must be open and valid.
let file = unsafe { File::from_raw_fd(fd) };
let mmap = MmapOptions::new().map_raw(&file).unwrap();
let mmap = MmapOptions::new().map_raw(&file)?;

(
Ok((
Self {
mmap,
file,
socket_path: socket_path.to_owned(),
bg_thread: None,
},
init_data,
)
))
}
}

Expand Down
49 changes: 31 additions & 18 deletions ipc_test/src/backend_shm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// TODO:
// #![forbid(clippy::unwrap_used)]
///! Raw memory backend using the `shared_memory` crate
///
//#![forbid(clippy::unwrap_used)]
//! Raw memory backend using the `shared_memory` crate
use std::{
fs::{remove_file, OpenOptions},
io::Write,
Expand All @@ -13,6 +11,8 @@ use std::{
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use shared_memory::{Shmem, ShmemConf};

use crate::common::ShmConnectError;

/// Initialization data that we serialize to a file, so our users don't have to
/// pass around so many things out-of-band.
#[derive(Serialize, Deserialize)]
Expand All @@ -32,32 +32,41 @@ impl SharedMemory {
/// Create a new shared memory mapping
///
/// `enable_huge` is not supported and ignored.
pub fn new<I>(_enable_huge: bool, handle_path: &Path, size: usize, init_data: I) -> Self
pub fn new<I>(
_enable_huge: bool,
handle_path: &Path,
size: usize,
init_data: I,
) -> Result<Self, ShmConnectError>
where
I: Serialize,
{
let shm_impl = ShmemConf::new().size(size).create().unwrap();
let shm_impl = ShmemConf::new()
.size(size)
.create()
.map_err(|e| ShmConnectError::Other { msg: e.to_string() })?;

let mut f = OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open(handle_path)
.unwrap();
.open(handle_path)?;

let init_data_wrapped = InitData {
size,
os_handle: shm_impl.get_os_id().to_string(),
payload: init_data,
};

bincode::serialize_into(&f, &init_data_wrapped).unwrap();
bincode::serialize_into(&f, &init_data_wrapped)?;

f.flush().unwrap();
f.flush()?;

Self {
Ok(Self {
shm_impl: Mutex::new(shm_impl),
handle_path: handle_path.to_owned(),
is_owner: true,
}
})
}

pub fn as_mut_ptr(&self) -> *mut u8 {
Expand All @@ -68,30 +77,34 @@ impl SharedMemory {
self.handle_path.to_str().unwrap().to_owned()
}

pub fn connect<I>(handle_path: &str) -> (Self, I)
pub fn connect<I>(handle_path: &str) -> Result<(Self, I), ShmConnectError>
where
I: DeserializeOwned,
{
let f = OpenOptions::new().read(true).open(handle_path).unwrap();
let f = OpenOptions::new().read(true).open(handle_path)?;

let init_data_wrapped: InitData<I> = bincode::deserialize_from(f).unwrap();
let init_data_wrapped: InitData<I> = bincode::deserialize_from(f)?;
let InitData {
os_handle,
size,
payload,
..
} = init_data_wrapped;

let shm_impl = ShmemConf::new().os_id(os_handle).size(size).open().unwrap();
let shm_impl = ShmemConf::new()
.os_id(os_handle)
.size(size)
.open()
.map_err(|e| ShmConnectError::Other { msg: e.to_string() })?;

(
Ok((
Self {
shm_impl: Mutex::new(shm_impl),
handle_path: PathBuf::from_str(handle_path).unwrap(),
is_owner: false,
},
payload,
)
))
}
}

Expand Down
11 changes: 11 additions & 0 deletions ipc_test/src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#[derive(thiserror::Error, Debug)]
pub enum ShmConnectError {
#[error("I/O error: {0}")]
IOError(#[from] std::io::Error),

#[error("serialization error: {0}")]
SerializationError(#[from] bincode::Error),

#[error("other error: {msg}")]
Other { msg: String },
}
1 change: 1 addition & 0 deletions ipc_test/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod common;
pub(crate) mod freestack;
pub mod slab;

Expand Down
11 changes: 7 additions & 4 deletions ipc_test/src/slab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crossbeam::channel::{bounded, Sender};
use raw_sync::locks::{LockImpl, LockInit, Mutex};
use serde::{Deserialize, Serialize};

use crate::{align_to, freestack::FreeStack, shm::Shm};
use crate::{align_to, common::ShmConnectError, freestack::FreeStack, shm::Shm};

/// A handle for reading from a shared memory slot
pub struct Slot {
Expand Down Expand Up @@ -100,7 +100,10 @@ pub struct SharedSlabAllocator {
}

#[derive(thiserror::Error, Debug)]
pub enum SlabInitError {}
pub enum SlabInitError {
#[error("connection failed: {0}")]
ConnectError(#[from] ShmConnectError),
}

///
/// Single-producer multiple consumer communication via shared memory
Expand Down Expand Up @@ -168,7 +171,7 @@ impl SharedSlabAllocator {
slot_size,
total_size,
};
let shm = Shm::new(huge_pages, shm_path, total_size, slab_info);
let shm = Shm::new(huge_pages, shm_path, total_size, slab_info)?;

Self::from_shm_and_slab_info(shm, slab_info, true)
}
Expand Down Expand Up @@ -247,7 +250,7 @@ impl SharedSlabAllocator {
}

pub fn connect(handle_path: &str) -> Result<Self, SlabInitError> {
let (shm, slab_info): (_, SlabInfo) = Shm::connect(handle_path);
let (shm, slab_info): (_, SlabInfo) = Shm::connect(handle_path)?;
Self::from_shm_and_slab_info(shm, slab_info, false)
}

Expand Down
Loading