Skip to content

Commit

Permalink
firmware_update: rockusb_fwudate
Browse files Browse the repository at this point in the history
* Use the "new" rockusb driver instead of the `rockusb` rs
  implementation.
* Cleaned up the async logging/state component. Error information should
  flow back via "FwUdateError".
  • Loading branch information
svenrademakers committed Oct 18, 2023
1 parent a28a981 commit 1025080
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 330 deletions.
7 changes: 1 addition & 6 deletions src/api/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use crate::api::streaming_data_service::StreamingDataService;
use crate::app::bmc_application::{BmcApplication, Encoding, UsbConfig};
use crate::app::transfer_action::{TransferType, UpgradeAction, UpgradeType};
use crate::hal::{NodeId, UsbMode, UsbRoute};
use crate::utils::logging_sink;
use actix_multipart::Multipart;
use actix_web::guard::{fn_guard, GuardContext};
use actix_web::http::StatusCode;
Expand All @@ -29,7 +28,6 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::str::FromStr;
use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
type Query = web::Query<std::collections::HashMap<String, String>>;

Expand Down Expand Up @@ -160,10 +158,7 @@ fn get_node_info(_bmc: &BmcApplication) -> impl Into<LegacyResponse> {

async fn set_node_to_msd(bmc: &BmcApplication, query: Query) -> LegacyResult<()> {
let node = get_node_param(&query)?;

let (tx, rx) = mpsc::channel(64);
logging_sink(rx);
bmc.set_node_in_msd(node, UsbRoute::Bmc, tx)
bmc.set_node_in_msd(node, UsbRoute::Bmc)
.await
.map(|_| ())
.map_err(Into::into)
Expand Down
60 changes: 11 additions & 49 deletions src/app/bmc_application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::firmware_update::transport::FwUpdateTransport;
use crate::firmware_update::{
fw_update_transport, FlashProgress, FlashStatus, SUPPORTED_MSD_DEVICES,
};
use crate::firmware_update::{fw_update_transport, SUPPORTED_DEVICES};
use crate::hal::power_controller::PowerController;
use crate::hal::serial::SerialConnections;
use crate::hal::usbboot;
Expand All @@ -25,12 +23,11 @@ use crate::utils::{string_from_utf16, string_from_utf32};
use anyhow::{ensure, Context};
use log::{debug, info, trace};
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::process::Command;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::sleep;

/// Stores which slots are actually used. This information is used to determine
/// for instance, which nodes need to be powered on, when such command is given
pub const ACTIVATED_NODES_KEY: &str = "activated_nodes";
Expand Down Expand Up @@ -223,47 +220,24 @@ impl BmcApplication {
self.power_controller.reset_node(node).await
}

pub async fn set_node_in_msd(
&self,
node: NodeId,
router: UsbRoute,
progress_sender: mpsc::Sender<FlashProgress>,
) -> anyhow::Result<()> {
pub async fn set_node_in_msd(&self, node: NodeId, router: UsbRoute) -> anyhow::Result<()> {
// The SUPPORTED_MSD_DEVICES list contains vid_pids of USB drivers we know will load the
// storage of a node as a MSD device.
self.configure_node_for_fwupgrade(
node,
router,
progress_sender,
SUPPORTED_MSD_DEVICES.deref(),
)
.await
.map(|_| ())
self.configure_node_for_fwupgrade(node, router, SUPPORTED_DEVICES.keys().into_iter())

Check failure on line 226 in src/app/bmc_application.rs

View workflow job for this annotation

GitHub Actions / clippy

useless conversion to the same type: `std::collections::hash_map::Keys<'_, (u16, u16), std::boxed::Box<dyn for<'a> std::ops::Fn(&'a rusb::Device<rusb::GlobalContext>) -> std::pin::Pin<std::boxed::Box<dyn futures::Future<Output = std::result::Result<std::boxed::Box<dyn firmware_update::transport::FwUpdateTransport>, firmware_update::FwUpdateError>> + std::marker::Send>> + std::marker::Send + std::marker::Sync>>`

error: useless conversion to the same type: `std::collections::hash_map::Keys<'_, (u16, u16), std::boxed::Box<dyn for<'a> std::ops::Fn(&'a rusb::Device<rusb::GlobalContext>) -> std::pin::Pin<std::boxed::Box<dyn futures::Future<Output = std::result::Result<std::boxed::Box<dyn firmware_update::transport::FwUpdateTransport>, firmware_update::FwUpdateError>> + std::marker::Send>> + std::marker::Send + std::marker::Sync>>` --> src/app/bmc_application.rs:226:57 | 226 | self.configure_node_for_fwupgrade(node, router, SUPPORTED_DEVICES.keys().into_iter()) | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: consider removing `.into_iter()`: `SUPPORTED_DEVICES.keys()` | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#useless_conversion = note: `-D clippy::useless-conversion` implied by `-D warnings`
.await
.map(|_| ())
}

pub async fn configure_node_for_fwupgrade<'a, I>(
&self,
node: NodeId,
router: UsbRoute,
progress_sender: mpsc::Sender<FlashProgress>,
any_of: I,
) -> anyhow::Result<Box<dyn FwUpdateTransport>>
where
I: IntoIterator<Item = &'a (u16, u16)>,
{
let mut progress_state = FlashProgress {
message: String::new(),
status: FlashStatus::Idle,
};

progress_state.message = format!("Powering off node {:?}...", node);
progress_state.status = FlashStatus::Progress {
read_percent: 0,
est_minutes: u64::MAX,
est_seconds: u64::MAX,
};
progress_sender.send(progress_state.clone()).await?;

log::info!("Powering off node {:?}...", node);
self.activate_slot(!node.to_bitfield(), node.to_bitfield())
.await?;
self.pin_controller
Expand All @@ -278,29 +252,17 @@ impl BmcApplication {
self.usb_boot(node, true).await?;
self.configure_usb(config).await?;

progress_state.message = String::from("Prerequisite settings toggled, powering on...");
progress_sender.send(progress_state.clone()).await?;

log::info!("Prerequisite settings toggled, powering on...");
self.activate_slot(node.to_bitfield(), node.to_bitfield())
.await?;

tokio::time::sleep(Duration::from_secs(2)).await;

progress_state.message = String::from("Checking for presence of a USB device...");
progress_sender.send(progress_state.clone()).await?;
log::info!("Checking for presence of a USB device...");

let matches = usbboot::get_usb_devices(any_of)?;
let usb_device = usbboot::extract_one_device(&matches).map_err(|e| {
progress_sender
.try_send(FlashProgress {
status: FlashStatus::Error(e),
message: String::new(),
})
.unwrap();
e
})?;

fw_update_transport(usb_device, progress_sender)?
let usb_device = usbboot::extract_one_device(&matches)?;
fw_update_transport(usb_device)?
.await
.context("USB driver init error")
}
Expand Down
23 changes: 8 additions & 15 deletions src/app/firmware_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
// limitations under the License.
use super::bmc_application::UsbConfig;
use crate::app::bmc_application::BmcApplication;
use crate::utils::{logging_sink, reader_with_crc64, WriteWatcher};
use crate::firmware_update::FwUpdateError;
use crate::utils::{reader_with_crc64, WriteWatcher};
use crate::{
firmware_update::{FlashProgress, FlashingError, SUPPORTED_DEVICES},
firmware_update::SUPPORTED_DEVICES,
hal::{NodeId, UsbRoute},
};
use anyhow::bail;
Expand All @@ -30,10 +31,11 @@ use std::{sync::Arc, time::Duration};
use tokio::fs::OpenOptions;
use tokio::io::sink;
use tokio::io::AsyncReadExt;
use tokio::sync::{mpsc, watch};
use tokio::io::AsyncSeekExt;
use tokio::sync::watch;
use tokio::{
fs,
io::{self, AsyncRead, AsyncSeekExt, AsyncWrite, AsyncWriteExt},
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt},
time::sleep,
};
use tokio_util::sync::CancellationToken;
Expand All @@ -49,7 +51,6 @@ pub struct FirmwareRunner {
size: u64,
cancel: CancellationToken,
written_sender: watch::Sender<u64>,
progress_sender: mpsc::Sender<FlashProgress>,
}

impl FirmwareRunner {
Expand All @@ -60,26 +61,18 @@ impl FirmwareRunner {
cancel: CancellationToken,
written_sender: watch::Sender<u64>,
) -> Self {
let (sender, receiver) = mpsc::channel(16);
logging_sink(receiver);
Self {
reader,
file_name,
size,
cancel,
written_sender,
progress_sender: sender,
}
}

pub async fn flash_node(self, bmc: Arc<BmcApplication>, node: NodeId) -> anyhow::Result<()> {
let mut device = bmc
.configure_node_for_fwupgrade(
node,
UsbRoute::Bmc,
self.progress_sender.clone(),
SUPPORTED_DEVICES.keys(),
)
.configure_node_for_fwupgrade(node, UsbRoute::Bmc, SUPPORTED_DEVICES.keys())
.await?;

let write_watcher = WriteWatcher::new(&mut device, self.written_sender);
Expand All @@ -105,7 +98,7 @@ impl FirmwareRunner {
dev_checksum
);

bail!(FlashingError::ChecksumMismatch)
bail!(FwUpdateError::ChecksumMismatch)
}

log::info!("Flashing successful, restarting device...");
Expand Down
127 changes: 30 additions & 97 deletions src/firmware_update/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,137 +15,70 @@ mod rockusb_fwudate;
use self::rockusb_fwudate::new_rockusb_transport;
mod rpi_fwupdate;
use rpi_fwupdate::new_rpi_transport;
use serde::Serialize;
use thiserror::Error;
pub mod transport;
use self::transport::FwUpdateTransport;
use futures::future::BoxFuture;
use once_cell::sync::Lazy;
use rusb::GlobalContext;
use std::{
collections::HashMap,
fmt::{self, Display},
};
use tokio::sync::mpsc::Sender;

#[allow(clippy::vec_init_then_push)]
pub static SUPPORTED_MSD_DEVICES: Lazy<Vec<(u16, u16)>> = Lazy::new(|| {
let mut supported = Vec::<(u16, u16)>::new();
supported.push(rpi_fwupdate::VID_PID);
supported
});
use std::collections::HashMap;

pub static SUPPORTED_DEVICES: Lazy<HashMap<(u16, u16), FactoryItemCreator>> = Lazy::new(|| {
let mut creators = HashMap::<(u16, u16), FactoryItemCreator>::new();
creators.insert(
rpi_fwupdate::VID_PID,
Box::new(|_, logging| {
Box::pin(async move { new_rpi_transport(&logging).await.map(Into::into) })
}),
Box::new(|_| Box::pin(async move { new_rpi_transport().await.map(Into::into) })),
);

creators.insert(
rockusb_fwudate::RK3588_VID_PID,
Box::new(|device, logging| {
Box::new(|device| {
let clone = device.clone();
Box::pin(async move { new_rockusb_transport(clone, &logging).await.map(Into::into) })
Box::pin(async move { new_rockusb_transport(clone).await.map(Into::into) })
}),
);

creators
});

pub type FactoryItem = BoxFuture<'static, Result<Box<dyn FwUpdateTransport>, FlashingError>>;
pub type FactoryItem = BoxFuture<'static, Result<Box<dyn FwUpdateTransport>, FwUpdateError>>;
pub type FactoryItemCreator =
Box<dyn Fn(&rusb::Device<GlobalContext>, Sender<FlashProgress>) -> FactoryItem + Send + Sync>;
Box<dyn Fn(&rusb::Device<GlobalContext>) -> FactoryItem + Send + Sync>;

pub fn fw_update_transport(
device: &rusb::Device<GlobalContext>,
logging: Sender<FlashProgress>,
) -> anyhow::Result<FactoryItem> {
) -> Result<FactoryItem, FwUpdateError> {
let descriptor = device.device_descriptor()?;
let vid_pid = (descriptor.vendor_id(), descriptor.product_id());

SUPPORTED_DEVICES
.get(&vid_pid)
.map(|creator| creator(device, logging))
.ok_or(anyhow::anyhow!("no driver available for {:?}", device))
.map(|creator| creator(device))
.ok_or(FwUpdateError::NoDriver(device.clone()))
}

#[derive(Serialize, Debug, Clone, Copy)]
pub enum FlashingError {
DeviceNotFound,
GpioError,
UsbError,
IoError,
#[derive(Error, Debug)]
pub enum FwUpdateError {
#[error("Device {0:#06x}:{1:#06x} not found")]
DeviceNotFound(u16, u16),

Check warning on line 63 in src/firmware_update/mod.rs

View workflow job for this annotation

GitHub Actions / cargo-test

variant `DeviceNotFound` is never constructed

Check failure on line 63 in src/firmware_update/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

variant `DeviceNotFound` is never constructed

error: variant `DeviceNotFound` is never constructed --> src/firmware_update/mod.rs:63:5 | 61 | pub enum FwUpdateError { | ------------- variant in this enum 62 | #[error("Device {0:#06x}:{1:#06x} not found")] 63 | DeviceNotFound(u16, u16), | ^^^^^^^^^^^^^^ | = note: `FwUpdateError` has a derived impl for the trait `Debug`, but this is intentionally ignored during dead code analysis = note: `-D dead-code` implied by `-D warnings`
#[error("No supported devices found")]
NoDevices,
#[error("Several supported devices found: found {0:?}, expected 1")]
MultipleDevicesFound(usize),
#[error("rusb error")]
RusbError(#[from] rusb::Error),
#[error("IO error")]
IoError(#[from] std::io::Error),
#[error("Error loading as USB MSD: {0}")]
InternalError(String),
#[error("integrity check of written image failed")]
ChecksumMismatch,
#[error("no firmware update driver available for {0:?}")]
NoDriver(rusb::Device<GlobalContext>),
}

impl fmt::Display for FlashingError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FlashingError::DeviceNotFound => write!(f, "Device not found"),
FlashingError::GpioError => write!(f, "Error toggling GPIO lines"),
FlashingError::UsbError => write!(f, "Error enumerating USB devices"),
FlashingError::IoError => write!(f, "File IO error"),
FlashingError::ChecksumMismatch => {
write!(f, "Failed to verify image after writing to the node")
}
}
}
}

impl std::error::Error for FlashingError {}

pub trait FlashingErrorExt<T, E: Display> {
fn map_err_into_logged_usb(self, logging: &Sender<FlashProgress>) -> Result<T, FlashingError>;
fn map_err_into_logged_io(self, logging: &Sender<FlashProgress>) -> Result<T, FlashingError>;
}

impl<T, E: Display> FlashingErrorExt<T, E> for Result<T, E> {
fn map_err_into_logged_usb(self, logging: &Sender<FlashProgress>) -> Result<T, FlashingError> {
self.map_err(|e| {
logging
.try_send(FlashProgress {
status: FlashStatus::Error(FlashingError::UsbError),
message: format!("{}", e),
})
.expect("logging channel to be open");
FlashingError::UsbError
})
}
fn map_err_into_logged_io(self, logging: &Sender<FlashProgress>) -> Result<T, FlashingError> {
self.map_err(|e| {
logging
.try_send(FlashProgress {
status: FlashStatus::Error(FlashingError::IoError),
message: format!("{}", e),
})
.expect("logging channel to be open");
FlashingError::IoError
})
}
}

#[derive(Debug, Clone, Copy)]
pub enum FlashStatus {
Idle,
Setup,
Progress {
read_percent: usize,
est_minutes: u64,
est_seconds: u64,
},
Error(FlashingError),
}

#[derive(Debug, Clone)]
pub struct FlashProgress {
pub status: FlashStatus,
pub message: String,
}

impl Display for FlashProgress {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.message)
impl FwUpdateError {
pub fn internal_error<E: ToString>(error: E) -> FwUpdateError {
FwUpdateError::InternalError(error.to_string())
}
}
Loading

0 comments on commit 1025080

Please sign in to comment.