diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs index 6c5a3f26..4a9e33be 100644 --- a/src/device/socket/vsock.rs +++ b/src/device/socket/vsock.rs @@ -7,15 +7,13 @@ use super::protocol::{ }; use super::DEFAULT_RX_BUFFER_SIZE; use crate::hal::Hal; -use crate::queue::VirtQueue; +use crate::queue::{owning::OwningQueue, VirtQueue}; use crate::transport::Transport; use crate::volatile::volread; -use crate::{Error, Result}; -use alloc::boxed::Box; +use crate::Result; use core::mem::size_of; -use core::ptr::{null_mut, NonNull}; use log::debug; -use zerocopy::{AsBytes, FromBytes, FromZeroes}; +use zerocopy::{AsBytes, FromBytes}; pub(crate) const RX_QUEUE_IDX: u16 = 0; pub(crate) const TX_QUEUE_IDX: u16 = 1; @@ -222,30 +220,13 @@ pub struct VirtIOSocket, + rx: OwningQueue, tx: VirtQueue, /// Virtqueue to receive events from the device. event: VirtQueue, /// The guest_cid field contains the guest’s context ID, which uniquely identifies /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. guest_cid: u64, - rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE], -} - -// SAFETY: The `rx_queue_buffers` can be accessed from any thread. -unsafe impl Send - for VirtIOSocket -where - VirtQueue: Send, -{ -} - -// SAFETY: A `&VirtIOSocket` only allows reading the guest CID from a field. -unsafe impl Sync - for VirtIOSocket -where - VirtQueue: Sync, -{ } impl Drop @@ -257,12 +238,6 @@ impl Drop self.transport.queue_unset(RX_QUEUE_IDX); self.transport.queue_unset(TX_QUEUE_IDX); self.transport.queue_unset(EVENT_QUEUE_IDX); - - for buffer in self.rx_queue_buffers { - // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be - // used anywhere else after the driver is destroyed. - unsafe { drop(Box::from_raw(buffer.as_ptr())) }; - } } } @@ -281,7 +256,7 @@ impl VirtIOSocket VirtIOSocket = FromZeroes::new_box_zeroed(); - // Safe because the buffer lives as long as the queue, as specified in the function - // safety requirement, and we don't access it until it is popped. - let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?; - assert_eq!(i, token.into()); - *rx_queue_buffer = Box::into_raw(buffer); - } - let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap()); + let rx = OwningQueue::new(rx)?; transport.finish_init(); if rx.should_notify() { @@ -323,7 +288,6 @@ impl VirtIOSocket VirtIOSocket Result>, ) -> Result> { - let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else { - return Ok(None); - }; - - let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body)); - - unsafe { - // TODO: What about if both handler and this give errors? - self.add_buffer_to_rx_queue(token)?; - } - - result + self.rx.poll(&mut self.transport, |buffer| { + let (header, body) = read_header_and_body(buffer)?; + VsockEvent::from_header(&header).and_then(|event| handler(event, body)) + }) } /// Requests to shut down the connection cleanly, sending hints about whether we will send or @@ -481,78 +437,19 @@ impl VirtIOSocket Result { - // Safe because the buffer lives as long as the queue, and the caller guarantees that it's - // not currently in the queue or referred to anywhere else until it is popped. - unsafe { - let buffer = self - .rx_queue_buffers - .get_mut(usize::from(index)) - .ok_or(Error::WrongToken)? - .as_mut(); - let new_token = self.rx.add(&[], &mut [buffer])?; - // If the RX buffer somehow gets assigned a different token, then our safety assumptions - // are broken and we can't safely continue to do anything with the device. - assert_eq!(new_token, index); - } - - if self.rx.should_notify() { - self.transport.notify(RX_QUEUE_IDX); - } - - Ok(()) - } - - /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a - /// reference to the buffer containing the body. - /// - /// Returns `None` if there is no pending packet. - fn pop_packet_from_rx_queue(&mut self) -> Result> { - let Some(token) = self.rx.peek_used() else { - return Ok(None); - }; - - // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same - // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the - // buffer back to the RX queue then we don't access it again until next time it is popped. - let (header, body) = unsafe { - let buffer = self.rx_queue_buffers[usize::from(token)].as_mut(); - let _len = self.rx.pop_used(token, &[], &mut [buffer])?; - - // Read the header and body from the buffer. Don't check the result yet, because we need - // to add the buffer back to the queue either way. - let header_result = read_header_and_body(buffer); - if header_result.is_err() { - // If there was an error, add the buffer back immediately. Ignore any errors, as we - // need to return the first error. - let _ = self.add_buffer_to_rx_queue(token); - } - - header_result - }?; - - debug!("Received packet {:?}. Op {:?}", header, header.op()); - Ok(Some((header, body, token))) - } } fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> { - // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::()`. - let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap(); + // This could fail if the device returns a buffer used length shorter than the header size. + let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?; let body_length = header.len() as usize; // This could fail if the device returns an unreasonably long body length. let data_end = size_of::() .checked_add(body_length) .ok_or(SocketError::InvalidNumber)?; - // This could fail if the device returns a body length longer than the buffer we gave it. + // This could fail if the device returns a body length longer than buffer used length it + // returned. let data = buffer .get(size_of::()..data_end) .ok_or(SocketError::BufferTooShort)?; diff --git a/src/queue.rs b/src/queue.rs index 3573a39d..37854b15 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,5 +1,8 @@ #![deny(unsafe_op_in_unsafe_fn)] +#[cfg(feature = "alloc")] +pub mod owning; + use crate::hal::{BufferDirection, Dma, Hal, PhysAddr}; use crate::transport::Transport; use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE}; diff --git a/src/queue/owning.rs b/src/queue/owning.rs new file mode 100644 index 00000000..7b9cf086 --- /dev/null +++ b/src/queue/owning.rs @@ -0,0 +1,149 @@ +use super::VirtQueue; +use crate::{transport::Transport, Error, Hal, Result}; +use alloc::boxed::Box; +use core::convert::TryInto; +use core::ptr::{null_mut, NonNull}; +use zerocopy::FromZeroes; + +/// A wrapper around [`Queue`] that owns all the buffers that are passed to the queue. +#[derive(Debug)] +pub struct OwningQueue { + queue: VirtQueue, + buffers: [NonNull<[u8; BUFFER_SIZE]>; SIZE], +} + +impl OwningQueue { + /// Constructs a new `OwningQueue` wrapping around the given `VirtQueue`. + /// + /// This will allocate `SIZE` buffers of `BUFFER_SIZE` bytes each and add them to the queue. + /// + /// The caller is responsible for notifying the device if `should_notify` returns true. + pub fn new(mut queue: VirtQueue) -> Result { + let mut buffers = [null_mut(); SIZE]; + for (i, queue_buffer) in buffers.iter_mut().enumerate() { + let mut buffer: Box<[u8; BUFFER_SIZE]> = FromZeroes::new_box_zeroed(); + // SAFETY: The buffer lives as long as the queue, as specified in the function safety + // requirement, and we don't access it until it is popped. + let token = unsafe { queue.add(&[], &mut [buffer.as_mut_slice()]) }?; + assert_eq!(i, token.into()); + *queue_buffer = Box::into_raw(buffer); + } + let buffers = buffers.map(|ptr| NonNull::new(ptr).unwrap()); + + Ok(Self { queue, buffers }) + } + + /// Returns whether the driver should notify the device after adding a new buffer to the + /// virtqueue. + /// + /// This will be false if the device has supressed notifications. + pub fn should_notify(&self) -> bool { + self.queue.should_notify() + } + + /// Adds the buffer at the given index in `buffers` back to the queue. + /// + /// Automatically notifies the device if required. + /// + /// # Safety + /// + /// The buffer must not currently be in the RX queue, and no other references to it must exist + /// between when this method is called and when it is popped from the queue. + unsafe fn add_buffer_to_queue(&mut self, index: u16, transport: &mut impl Transport) -> Result { + // SAFETY: The buffer lives as long as the queue, and the caller guarantees that it's not + // currently in the queue or referred to anywhere else until it is popped. + unsafe { + let buffer = self + .buffers + .get_mut(usize::from(index)) + .ok_or(Error::WrongToken)? + .as_mut(); + let new_token = self.queue.add(&[], &mut [buffer])?; + // If the RX buffer somehow gets assigned a different token, then our safety assumptions + // are broken and we can't safely continue to do anything with the device. + assert_eq!(new_token, index); + } + + if self.queue.should_notify() { + transport.notify(self.queue.queue_idx); + } + + Ok(()) + } + + fn pop(&mut self) -> Result> { + let Some(token) = self.queue.peek_used() else { + return Ok(None); + }; + + // SAFETY: The device has told us it has finished using the buffer, and there are no other + // references to it. + let buffer = unsafe { self.buffers[usize::from(token)].as_mut() }; + // SAFETY: We maintain a consistent mapping of tokens to buffers, so we pass the same buffer + // to `pop_used` as we previously passed to `add` for the token. Once we add the buffer back + // to the RX queue then we don't access it again until next time it is popped. + let len = unsafe { self.queue.pop_used(token, &[], &mut [buffer])? } + .try_into() + .unwrap(); + + Ok(Some((&buffer[0..len], token))) + } + + /// Checks whether there are any buffers which the device has marked as used so the driver + /// should now process. If so, passes the first one to the `handle` function and then adds it + /// back to the queue. + /// + /// Returns an error if there is an error accessing the queue or `handler` returns an error. + /// Returns `Ok(None)` if there are no pending buffers to handle, or if `handler` returns + /// `Ok(None)`. + /// + /// If `handler` panics then the buffer will not be added back to the queue, so this should be + /// avoided. + pub fn poll( + &mut self, + transport: &mut impl Transport, + handler: impl FnOnce(&[u8]) -> Result>, + ) -> Result> { + let Some((buffer, token)) = self.pop()? else { + return Ok(None); + }; + + let result = handler(buffer); + + // SAFETY: The buffer was just popped from the queue so it's not in it, and there won't be + // any other references until next time it's popped. + unsafe { + self.add_buffer_to_queue(token, transport)?; + } + + result + } +} + +// SAFETY: The `buffers` can be accessed from any thread. +unsafe impl Send + for OwningQueue +where + VirtQueue: Send, +{ +} + +// SAFETY: An `&OwningQueue` only allows calling `should_notify`. +unsafe impl Sync + for OwningQueue +where + VirtQueue: Sync, +{ +} + +impl Drop + for OwningQueue +{ + fn drop(&mut self) { + for buffer in self.buffers { + // Safe because we obtained the buffer pointer from Box::into_raw, and it won't be used + // anywhere else after the queue is destroyed. + unsafe { drop(Box::from_raw(buffer.as_ptr())) }; + } + } +}