Skip to content

Commit

Permalink
Merge pull request #149 from rcore-os/ownedqueue
Browse files Browse the repository at this point in the history
Add OwningQueue, a wrapper around VirtQueue that allocates its own buffers
  • Loading branch information
qwandor authored Jul 24, 2024
2 parents d0ea169 + 6abd592 commit 1a9e99d
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 117 deletions.
131 changes: 14 additions & 117 deletions src/device/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ use super::protocol::{
};
use super::DEFAULT_RX_BUFFER_SIZE;
use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::queue::{owning::OwningQueue, VirtQueue};
use crate::transport::Transport;
use crate::volatile::volread;
use crate::{Error, Result};
use alloc::boxed::Box;
use crate::Result;
use core::mem::size_of;
use core::ptr::{null_mut, NonNull};
use log::debug;
use zerocopy::{AsBytes, FromBytes, FromZeroes};
use zerocopy::{AsBytes, FromBytes};

pub(crate) const RX_QUEUE_IDX: u16 = 0;
pub(crate) const TX_QUEUE_IDX: u16 = 1;
Expand Down Expand Up @@ -222,30 +220,13 @@ pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFA
{
transport: T,
/// Virtqueue to receive packets.
rx: VirtQueue<H, { QUEUE_SIZE }>,
rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
tx: VirtQueue<H, { QUEUE_SIZE }>,
/// Virtqueue to receive events from the device.
event: VirtQueue<H, { QUEUE_SIZE }>,
/// The guest_cid field contains the guest’s context ID, which uniquely identifies
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
guest_cid: u64,
rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
}

// SAFETY: The `rx_queue_buffers` can be accessed from any thread.
unsafe impl<H: Hal, T: Transport + Send, const RX_BUFFER_SIZE: usize> Send
for VirtIOSocket<H, T, RX_BUFFER_SIZE>
where
VirtQueue<H, QUEUE_SIZE>: Send,
{
}

// SAFETY: A `&VirtIOSocket` only allows reading the guest CID from a field.
unsafe impl<H: Hal, T: Transport + Sync, const RX_BUFFER_SIZE: usize> Sync
for VirtIOSocket<H, T, RX_BUFFER_SIZE>
where
VirtQueue<H, QUEUE_SIZE>: Sync,
{
}

impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
Expand All @@ -257,12 +238,6 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
self.transport.queue_unset(RX_QUEUE_IDX);
self.transport.queue_unset(TX_QUEUE_IDX);
self.transport.queue_unset(EVENT_QUEUE_IDX);

for buffer in self.rx_queue_buffers {
// Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
// used anywhere else after the driver is destroyed.
unsafe { drop(Box::from_raw(buffer.as_ptr())) };
}
}
}

Expand All @@ -281,7 +256,7 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
};
debug!("guest cid: {guest_cid:?}");

let mut rx = VirtQueue::new(
let rx = VirtQueue::new(
&mut transport,
RX_QUEUE_IDX,
negotiated_features.contains(Feature::RING_INDIRECT_DESC),
Expand All @@ -300,17 +275,7 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;

// Allocate and add buffers for the RX queue.
let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
for (i, rx_queue_buffer) in rx_queue_buffers.iter_mut().enumerate() {
let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromZeroes::new_box_zeroed();
// Safe because the buffer lives as long as the queue, as specified in the function
// safety requirement, and we don't access it until it is popped.
let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
assert_eq!(i, token.into());
*rx_queue_buffer = Box::into_raw(buffer);
}
let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
let rx = OwningQueue::new(rx)?;

transport.finish_init();
if rx.should_notify() {
Expand All @@ -323,7 +288,6 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
tx,
event,
guest_cid,
rx_queue_buffers,
})
}

Expand Down Expand Up @@ -412,18 +376,10 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
&mut self,
handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
) -> Result<Option<VsockEvent>> {
let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
return Ok(None);
};

let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));

unsafe {
// TODO: What about if both handler and this give errors?
self.add_buffer_to_rx_queue(token)?;
}

result
self.rx.poll(&mut self.transport, |buffer| {
let (header, body) = read_header_and_body(buffer)?;
VsockEvent::from_header(&header).and_then(|event| handler(event, body))
})
}

/// Requests to shut down the connection cleanly, sending hints about whether we will send or
Expand Down Expand Up @@ -481,78 +437,19 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU
};
Ok(())
}

/// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
///
/// # Safety
///
/// The buffer must not currently be in the RX queue, and no other references to it must exist
/// between when this method is called and when it is popped from the queue.
unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
// Safe because the buffer lives as long as the queue, and the caller guarantees that it's
// not currently in the queue or referred to anywhere else until it is popped.
unsafe {
let buffer = self
.rx_queue_buffers
.get_mut(usize::from(index))
.ok_or(Error::WrongToken)?
.as_mut();
let new_token = self.rx.add(&[], &mut [buffer])?;
// If the RX buffer somehow gets assigned a different token, then our safety assumptions
// are broken and we can't safely continue to do anything with the device.
assert_eq!(new_token, index);
}

if self.rx.should_notify() {
self.transport.notify(RX_QUEUE_IDX);
}

Ok(())
}

/// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
/// reference to the buffer containing the body.
///
/// Returns `None` if there is no pending packet.
fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
let Some(token) = self.rx.peek_used() else {
return Ok(None);
};

// Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
// buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
// buffer back to the RX queue then we don't access it again until next time it is popped.
let (header, body) = unsafe {
let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
let _len = self.rx.pop_used(token, &[], &mut [buffer])?;

// Read the header and body from the buffer. Don't check the result yet, because we need
// to add the buffer back to the queue either way.
let header_result = read_header_and_body(buffer);
if header_result.is_err() {
// If there was an error, add the buffer back immediately. Ignore any errors, as we
// need to return the first error.
let _ = self.add_buffer_to_rx_queue(token);
}

header_result
}?;

debug!("Received packet {:?}. Op {:?}", header, header.op());
Ok(Some((header, body, token)))
}
}

fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
// Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap();
// This could fail if the device returns a buffer used length shorter than the header size.
let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
let body_length = header.len() as usize;

// This could fail if the device returns an unreasonably long body length.
let data_end = size_of::<VirtioVsockHdr>()
.checked_add(body_length)
.ok_or(SocketError::InvalidNumber)?;
// This could fail if the device returns a body length longer than the buffer we gave it.
// This could fail if the device returns a body length longer than buffer used length it
// returned.
let data = buffer
.get(size_of::<VirtioVsockHdr>()..data_end)
.ok_or(SocketError::BufferTooShort)?;
Expand Down
3 changes: 3 additions & 0 deletions src/queue.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#![deny(unsafe_op_in_unsafe_fn)]

#[cfg(feature = "alloc")]
pub mod owning;

use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
use crate::transport::Transport;
use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};
Expand Down
149 changes: 149 additions & 0 deletions src/queue/owning.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use super::VirtQueue;
use crate::{transport::Transport, Error, Hal, Result};
use alloc::boxed::Box;
use core::convert::TryInto;
use core::ptr::{null_mut, NonNull};
use zerocopy::FromZeroes;

/// A wrapper around [`Queue`] that owns all the buffers that are passed to the queue.
#[derive(Debug)]
pub struct OwningQueue<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> {
queue: VirtQueue<H, SIZE>,
buffers: [NonNull<[u8; BUFFER_SIZE]>; SIZE],
}

impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> OwningQueue<H, SIZE, BUFFER_SIZE> {
/// Constructs a new `OwningQueue` wrapping around the given `VirtQueue`.
///
/// This will allocate `SIZE` buffers of `BUFFER_SIZE` bytes each and add them to the queue.
///
/// The caller is responsible for notifying the device if `should_notify` returns true.
pub fn new(mut queue: VirtQueue<H, SIZE>) -> Result<Self> {
let mut buffers = [null_mut(); SIZE];
for (i, queue_buffer) in buffers.iter_mut().enumerate() {
let mut buffer: Box<[u8; BUFFER_SIZE]> = FromZeroes::new_box_zeroed();
// SAFETY: The buffer lives as long as the queue, as specified in the function safety
// requirement, and we don't access it until it is popped.
let token = unsafe { queue.add(&[], &mut [buffer.as_mut_slice()]) }?;
assert_eq!(i, token.into());
*queue_buffer = Box::into_raw(buffer);
}
let buffers = buffers.map(|ptr| NonNull::new(ptr).unwrap());

Ok(Self { queue, buffers })
}

/// Returns whether the driver should notify the device after adding a new buffer to the
/// virtqueue.
///
/// This will be false if the device has supressed notifications.
pub fn should_notify(&self) -> bool {
self.queue.should_notify()
}

/// Adds the buffer at the given index in `buffers` back to the queue.
///
/// Automatically notifies the device if required.
///
/// # Safety
///
/// The buffer must not currently be in the RX queue, and no other references to it must exist
/// between when this method is called and when it is popped from the queue.
unsafe fn add_buffer_to_queue(&mut self, index: u16, transport: &mut impl Transport) -> Result {
// SAFETY: The buffer lives as long as the queue, and the caller guarantees that it's not
// currently in the queue or referred to anywhere else until it is popped.
unsafe {
let buffer = self
.buffers
.get_mut(usize::from(index))
.ok_or(Error::WrongToken)?
.as_mut();
let new_token = self.queue.add(&[], &mut [buffer])?;
// If the RX buffer somehow gets assigned a different token, then our safety assumptions
// are broken and we can't safely continue to do anything with the device.
assert_eq!(new_token, index);
}

if self.queue.should_notify() {
transport.notify(self.queue.queue_idx);
}

Ok(())
}

fn pop(&mut self) -> Result<Option<(&[u8], u16)>> {
let Some(token) = self.queue.peek_used() else {
return Ok(None);
};

// SAFETY: The device has told us it has finished using the buffer, and there are no other
// references to it.
let buffer = unsafe { self.buffers[usize::from(token)].as_mut() };
// SAFETY: We maintain a consistent mapping of tokens to buffers, so we pass the same buffer
// to `pop_used` as we previously passed to `add` for the token. Once we add the buffer back
// to the RX queue then we don't access it again until next time it is popped.
let len = unsafe { self.queue.pop_used(token, &[], &mut [buffer])? }
.try_into()
.unwrap();

Ok(Some((&buffer[0..len], token)))
}

/// Checks whether there are any buffers which the device has marked as used so the driver
/// should now process. If so, passes the first one to the `handle` function and then adds it
/// back to the queue.
///
/// Returns an error if there is an error accessing the queue or `handler` returns an error.
/// Returns `Ok(None)` if there are no pending buffers to handle, or if `handler` returns
/// `Ok(None)`.
///
/// If `handler` panics then the buffer will not be added back to the queue, so this should be
/// avoided.
pub fn poll<T>(
&mut self,
transport: &mut impl Transport,
handler: impl FnOnce(&[u8]) -> Result<Option<T>>,
) -> Result<Option<T>> {
let Some((buffer, token)) = self.pop()? else {
return Ok(None);
};

let result = handler(buffer);

// SAFETY: The buffer was just popped from the queue so it's not in it, and there won't be
// any other references until next time it's popped.
unsafe {
self.add_buffer_to_queue(token, transport)?;
}

result
}
}

// SAFETY: The `buffers` can be accessed from any thread.
unsafe impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Send
for OwningQueue<H, SIZE, BUFFER_SIZE>
where
VirtQueue<H, SIZE>: Send,
{
}

// SAFETY: An `&OwningQueue` only allows calling `should_notify`.
unsafe impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Sync
for OwningQueue<H, SIZE, BUFFER_SIZE>
where
VirtQueue<H, SIZE>: Sync,
{
}

impl<H: Hal, const SIZE: usize, const BUFFER_SIZE: usize> Drop
for OwningQueue<H, SIZE, BUFFER_SIZE>
{
fn drop(&mut self) {
for buffer in self.buffers {
// Safe because we obtained the buffer pointer from Box::into_raw, and it won't be used
// anywhere else after the queue is destroyed.
unsafe { drop(Box::from_raw(buffer.as_ptr())) };
}
}
}

0 comments on commit 1a9e99d

Please sign in to comment.