Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use config_generation for safe multi-part config reads. #169

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/device/blk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,14 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);

// Read configuration space.
let capacity = transport.read_config_space::<u32>(offset_of!(BlkConfig, capacity_low))?
as u64
| (transport.read_config_space::<u32>(offset_of!(BlkConfig, capacity_high))? as u64)
<< 32;
let capacity = transport.read_consistent(|| {
Ok(
transport.read_config_space::<u32>(offset_of!(BlkConfig, capacity_low))? as u64
| (transport.read_config_space::<u32>(offset_of!(BlkConfig, capacity_high))?
as u64)
<< 32,
)
})?;
info!("found a block device of size {}KB", capacity / 2);

let queue = VirtQueue::new(
Expand Down
10 changes: 6 additions & 4 deletions src/device/console.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ impl<H: Hal, T: Transport> VirtIOConsole<H, T> {
/// Returns the size of the console, if the device supports reporting this.
pub fn size(&self) -> Result<Option<Size>> {
if self.negotiated_features.contains(Features::SIZE) {
Ok(Some(Size {
columns: self.transport.read_config_space(offset_of!(Config, cols))?,
rows: self.transport.read_config_space(offset_of!(Config, rows))?,
}))
self.transport.read_consistent(|| {
Ok(Some(Size {
columns: self.transport.read_config_space(offset_of!(Config, cols))?,
rows: self.transport.read_config_space(offset_of!(Config, rows))?,
}))
})
} else {
Ok(None)
}
Expand Down
3 changes: 2 additions & 1 deletion src/device/net/dev_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONetRaw<H, T, QUEUE_SIZ
info!("negotiated_features {:?}", negotiated_features);

// Read configuration space.
let mac = transport.read_config_space(offset_of!(Config, mac))?;
let mac =
transport.read_consistent(|| transport.read_config_space(offset_of!(Config, mac)))?;
let status = transport.read_config_space::<Status>(offset_of!(Config, status))?;
debug!("Got MAC={:02x?}, status={:?}", mac, status);

Expand Down
17 changes: 10 additions & 7 deletions src/device/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,16 @@ impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BU

let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);

// Safe because config is a valid pointer to the device configuration space.
let guest_cid = transport
.read_config_space::<u32>(offset_of!(VirtioVsockConfig, guest_cid_low))?
as u64
| (transport.read_config_space::<u32>(offset_of!(VirtioVsockConfig, guest_cid_high))?
as u64)
<< 32;
let guest_cid = transport.read_consistent(|| {
Ok(
transport.read_config_space::<u32>(offset_of!(VirtioVsockConfig, guest_cid_low))?
as u64
| (transport
.read_config_space::<u32>(offset_of!(VirtioVsockConfig, guest_cid_high))?
as u64)
<< 32,
)
})?;
debug!("guest cid: {guest_cid:?}");

let rx = VirtQueue::new(
Expand Down
5 changes: 5 additions & 0 deletions src/transport/fake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ impl<C> Transport for FakeTransport<C> {
pending
}

fn read_config_generation(&self) -> u32 {
self.state.lock().unwrap().config_generation
}

fn read_config_space<T>(&self, offset: usize) -> Result<T, Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
Expand Down Expand Up @@ -133,6 +137,7 @@ pub struct State {
pub guest_page_size: u32,
pub interrupt_pending: bool,
pub queues: Vec<QueueStatus>,
pub config_generation: u32,
}

impl State {
Expand Down
5 changes: 5 additions & 0 deletions src/transport/mmio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,11 @@ impl Transport for MmioTransport {
}
}

fn read_config_generation(&self) -> u32 {
// SAFETY: self.header points to a valid VirtIO MMIO region.
unsafe { volread!(self.header, config_generation) }
}

fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T, Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
Expand Down
16 changes: 16 additions & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ pub trait Transport {
);
}

/// Reads the configuration space generation.
fn read_config_generation(&self) -> u32;

/// Reads a value from the device config space.
fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T>;

Expand All @@ -110,6 +113,19 @@ pub trait Transport {
offset: usize,
value: T,
) -> Result<()>;

/// Safely reads multiple fields from config space by ensuring that the config generation is the
/// same before and after all reads, and retrying if not.
fn read_consistent<T>(&self, f: impl Fn() -> Result<T>) -> Result<T> {
loop {
let before = self.read_config_generation();
let result = f();
let after = self.read_config_generation();
if before == after {
break result;
}
}
}
}

bitflags! {
Expand Down
5 changes: 5 additions & 0 deletions src/transport/pci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ impl Transport for PciTransport {
isr_status & 0x3 != 0
}

fn read_config_generation(&self) -> u32 {
// SAFETY: self.header points to a valid VirtIO MMIO region.
unsafe { volread!(self.common_cfg, config_generation) }.into()
}

fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T, Error> {
assert!(align_of::<T>() <= 4,
"Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
Expand Down
7 changes: 7 additions & 0 deletions src/transport/some.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ impl Transport for SomeTransport {
}
}

fn read_config_generation(&self) -> u32 {
match self {
Self::Mmio(mmio) => mmio.read_config_generation(),
Self::Pci(pci) => pci.read_config_generation(),
}
}

fn read_config_space<T: FromBytes>(&self, offset: usize) -> Result<T> {
match self {
Self::Mmio(mmio) => mmio.read_config_space(offset),
Expand Down
Loading