Skip to content

Commit

Permalink
Add OwningQueue, a wrapper around VirtQueue that allocates its own bu…
Browse files Browse the repository at this point in the history
…ffers.

This is based on how VirtIOSocket was managing its RX queue buffers, so
is initially used there.
  • Loading branch information
qwandor committed Jul 12, 2024
1 parent d0ea169 commit 6abd592
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 117 deletions.
131 changes: 14 additions & 117 deletions src/device/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ use super::protocol::{
};
use super::DEFAULT_RX_BUFFER_SIZE;
use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::queue::{owning::OwningQueue, VirtQueue};
use crate::transport::Transport;
use crate::volatile::volread;
use crate::{Error, Result};
use alloc::boxed::Box;
use crate::Result;
use core::mem::size_of;
use core::ptr::{null_mut, NonNull};
use log::debug;
use zerocopy::{AsBytes, FromBytes, FromZeroes};
use zerocopy::{AsBytes, FromBytes};

pub(crate) const RX_QUEUE_IDX: u16 = 0;
pub(crate) const TX_QUEUE_IDX: u16 = 1;
Expand Down Expand Up @@ -222,30 +220,13 @@ pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFA
{
transport: T,
/// Virtqueue to receive packets.
rx: VirtQueue<H, { QUEUE_SIZE }>,
rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
tx: VirtQueue<H, { QUEUE_SIZE }>,
/// Virtqueue to receive events from the device.
event: VirtQueue<H, { QUEUE_SIZE }>,
/// The guest_cid field contains the guest’s context ID, which uniquely identifies
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
guest_cid: u64,
rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
}

// SAFETY: The `rx_queue_buffers` can be accessed from any thread.
unsafe impl<H: Hal, T: Transport + Send, const RX_BUFFER_SIZE: usize> Send
for VirtIOSocket<H, T, RX_BUFFER_SIZE>
where
VirtQueue<H, QUEUE_SIZE>: Send,
{
}

// SAFETY: A `&VirtIOSocket` only allows reading the guest CID from a field.
unsafe impl<H: Hal, T: Transport + Sync, const RX_BUFFER_SIZE: usize> Sync
for VirtIOSocket<H, T, RX_BUFFER_SIZE>
where
VirtQueue<H, QUEUE_SIZE>: Sync,
{
}

impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
Expand All @@ -257,12 +238,6 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
self.transport.queue_unset(RX_QUEUE_IDX);
self.transport.queue_unset(TX_QUEUE_IDX);
self.transport.queue_unset(EVENT_QUEUE_IDX);

for buffer in self.rx_queue_buffers {
// Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
// used anywhere else after the driver is destroyed.
unsafe { drop(Box::from_raw(buffer.as_ptr())) };
}
}
}

Expand All @@ -281,7 +256,7 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
};
debug!("guest cid: {guest_cid:?}");

let mut rx = VirtQueue::new(
let rx = VirtQueue::new(
&mut transport,
RX_QUEUE_IDX,
negotiated_features.contains(Feature::RING_INDIRECT_DESC),
Expand All @@ -300,17 +275,7 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;

// Allocate and add buffers for the RX queue.
let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
for (i, rx_queue_buffer) in rx_queue_buffers.iter_mut().enumerate() {
let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromZeroes::new_box_zeroed();
// Safe because the buffer lives as long as the queue, as specified in the function
// safety requirement, and we don't access it until it is popped.
let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
assert_eq!(i, token.into());
*rx_queue_buffer = Box::into_raw(buffer);
}
let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
let rx = OwningQueue::new(rx)?;

transport.finish_init();
if rx.should_notify() {
Expand All @@ -323,7 +288,6 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
tx,
event,
guest_cid,
rx_queue_buffers,
})
}

Expand Down Expand Up @@ -412,18 +376,10 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
&mut self,
handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
) -> Result<Option<VsockEvent>> {
let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
return Ok(None);
};

let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));

unsafe {
// TODO: What about if both handler and this give errors?
self.add_buffer_to_rx_queue(token)?;
}

result
self.rx.poll(&mut self.transport, |buffer| {
let (header, body) = read_header_and_body(buffer)?;
VsockEvent::from_header(&header).and_then(|event| handler(event, body))
})
}

/// Requests to shut down the connection cleanly, sending hints about whether we will send or
Expand Down Expand Up @@ -481,78 +437,19 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
};
Ok(())
}

/// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
///
/// # Safety
///
/// The buffer must not currently be in the RX queue, and no other references to it must exist
/// between when this method is called and when it is popped from the queue.
unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
// Safe because the buffer lives as long as the queue, and the caller guarantees that it's
// not currently in the queue or referred to anywhere else until it is popped.
unsafe {
let buffer = self
.rx_queue_buffers
.get_mut(usize::from(index))
.ok_or(Error::WrongToken)?
.as_mut();
let new_token = self.rx.add(&[], &mut [buffer])?;
// If the RX buffer somehow gets assigned a different token, then our safety assumptions
// are broken and we can't safely continue to do anything with the device.
assert_eq!(new_token, index);
}

if self.rx.should_notify() {
self.transport.notify(RX_QUEUE_IDX);
}

Ok(())
}

/// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
/// reference to the buffer containing the body.
///
/// Returns `None` if there is no pending packet.
fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
let Some(token) = self.rx.peek_used() else {
return Ok(None);
};

// Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
// buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
// buffer back to the RX queue then we don't access it again until next time it is popped.
let (header, body) = unsafe {
let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
let _len = self.rx.pop_used(token, &[], &mut [buffer])?;

// Read the header and body from the buffer. Don't check the result yet, because we need
// to add the buffer back to the queue either way.
let header_result = read_header_and_body(buffer);
if header_result.is_err() {
// If there was an error, add the buffer back immediately. Ignore any errors, as we
// need to return the first error.
let _ = self.add_buffer_to_rx_queue(token);
}

header_result
}?;

debug!("Received packet {:?}. Op {:?}", header, header.op());
Ok(Some((header, body, token)))
}
}

fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
// Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap();
// This could fail if the device returns a buffer used length shorter than the header size.
let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
let body_length = header.len() as usize;

// This could fail if the device returns an unreasonably long body length.
let data_end = size_of::<VirtioVsockHdr>()
.checked_add(body_length)
.ok_or(SocketError::InvalidNumber)?;
// This could fail if the device returns a body length longer than the buffer we gave it.
// This could fail if the device returns a body length longer than buffer used length it
// returned.
let data = buffer
.get(size_of::<VirtioVsockHdr>()..data_end)
.ok_or(SocketError::BufferTooShort)?;
Expand Down
3 changes: 3 additions & 0 deletions src/queue.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#![deny(unsafe_op_in_unsafe_fn)]

#[cfg(feature = "alloc")]
pub mod owning;

use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
use crate::transport::Transport;
use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};
Expand Down
149 changes: 149 additions & 0 deletions src/queue/owning.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use super::VirtQueue;
use crate::{transport::Transport, Error, Hal, Result};
use alloc::boxed::Box;
use core::convert::TryInto;
use core::ptr::{null_mut, NonNull};
use zerocopy::FromZeroes;

/// A wrapper around [`Queue`] that owns all the buffers that are passed to the queue.
#[derive(Debug)]
pub struct OwningQueue<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> {
queue: VirtQueue<H, SIZE>,
buffers: [NonNull<[u8; BUFFER_SIZE]>; SIZE],
}

impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> OwningQueue<H, SIZE, BUFFER_SIZE> {
/// Constructs a new `OwningQueue` wrapping around the given `VirtQueue`.
///
/// This will allocate `SIZE` buffers of `BUFFER_SIZE` bytes each and add them to the queue.
///
/// The caller is responsible for notifying the device if `should_notify` returns true.
pub fn new(mut queue: VirtQueue<H, SIZE>) -> Result<Self> {
let mut buffers = [null_mut(); SIZE];
for (i, queue_buffer) in buffers.iter_mut().enumerate() {
let mut buffer: Box<[u8; BUFFER_SIZE]> = FromZeroes::new_box_zeroed();
// SAFETY: The buffer lives as long as the queue, as specified in the function safety
// requirement, and we don't access it until it is popped.
let token = unsafe { queue.add(&[], &mut [buffer.as_mut_slice()]) }?;
assert_eq!(i, token.into());
*queue_buffer = Box::into_raw(buffer);
}
let buffers = buffers.map(|ptr| NonNull::new(ptr).unwrap());

Ok(Self { queue, buffers })
}

/// Returns whether the driver should notify the device after adding a new buffer to the
/// virtqueue.
///
/// This will be false if the device has supressed notifications.
pub fn should_notify(&self) -> bool {
self.queue.should_notify()
}

/// Adds the buffer at the given index in `buffers` back to the queue.
///
/// Automatically notifies the device if required.
///
/// # Safety
///
/// The buffer must not currently be in the RX queue, and no other references to it must exist
/// between when this method is called and when it is popped from the queue.
unsafe fn add_buffer_to_queue(&mut self, index: u16, transport: &mut impl Transport) -> Result {
// SAFETY: The buffer lives as long as the queue, and the caller guarantees that it's not
// currently in the queue or referred to anywhere else until it is popped.
unsafe {
let buffer = self
.buffers
.get_mut(usize::from(index))
.ok_or(Error::WrongToken)?
.as_mut();
let new_token = self.queue.add(&[], &mut [buffer])?;
// If the RX buffer somehow gets assigned a different token, then our safety assumptions
// are broken and we can't safely continue to do anything with the device.
assert_eq!(new_token, index);
}

if self.queue.should_notify() {
transport.notify(self.queue.queue_idx);
}

Ok(())
}

fn pop(&mut self) -> Result<Option<(&[u8], u16)>> {
let Some(token) = self.queue.peek_used() else {
return Ok(None);
};

// SAFETY: The device has told us it has finished using the buffer, and there are no other
// references to it.
let buffer = unsafe { self.buffers[usize::from(token)].as_mut() };
// SAFETY: We maintain a consistent mapping of tokens to buffers, so we pass the same buffer
// to `pop_used` as we previously passed to `add` for the token. Once we add the buffer back
// to the RX queue then we don't access it again until next time it is popped.
let len = unsafe { self.queue.pop_used(token, &[], &mut [buffer])? }
.try_into()
.unwrap();

Ok(Some((&buffer[0..len], token)))
}

/// Checks whether there are any buffers which the device has marked as used so the driver
/// should now process. If so, passes the first one to the `handle` function and then adds it
/// back to the queue.
///
/// Returns an error if there is an error accessing the queue or `handler` returns an error.
/// Returns `Ok(None)` if there are no pending buffers to handle, or if `handler` returns
/// `Ok(None)`.
///
/// If `handler` panics then the buffer will not be added back to the queue, so this should be
/// avoided.
pub fn poll<T>(
&mut self,
transport: &mut impl Transport,
handler: impl FnOnce(&[u8]) -> Result<Option<T>>,
) -> Result<Option<T>> {
let Some((buffer, token)) = self.pop()? else {
return Ok(None);
};

let result = handler(buffer);

// SAFETY: The buffer was just popped from the queue so it's not in it, and there won't be
// any other references until next time it's popped.
unsafe {
self.add_buffer_to_queue(token, transport)?;
}

result
}
}

// SAFETY: The `buffers` can be accessed from any thread.
unsafe impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Send
for OwningQueue<H, SIZE, BUFFER_SIZE>
where
VirtQueue<H, SIZE>: Send,
{
}

// SAFETY: An `&OwningQueue` only allows calling `should_notify`.
unsafe impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Sync
for OwningQueue<H, SIZE, BUFFER_SIZE>
where
VirtQueue<H, SIZE>: Sync,
{
}

impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Drop
for OwningQueue<H, SIZE, BUFFER_SIZE>
{
fn drop(&mut self) {
for buffer in self.buffers {
// Safe because we obtained the buffer pointer from Box::into_raw, and it won't be used
// anywhere else after the queue is destroyed.
unsafe { drop(Box::from_raw(buffer.as_ptr())) };
}
}
}

0 comments on commit 6abd592

Please sign in to comment.