From b8abd08ab7ba7597cc80ad0230500a2001175e3b Mon Sep 17 00:00:00 2001 From: Sven Rademakers Date: Wed, 20 Sep 2023 11:13:32 +0100 Subject: [PATCH] flash_service: stability and improvements * added a API call so a client can ask the current status of the `FlashService` * fixes regarding cancelling of the node flash * cleanup of code --- Cargo.lock | 54 +-- bmcd/Cargo.toml | 1 + bmcd/src/flash_service.rs | 335 +++++++++++++------ bmcd/src/into_legacy_response.rs | 6 + bmcd/src/legacy.rs | 47 ++- bmcd/src/main.rs | 18 +- tpi_rs/src/app/flash_application.rs | 289 ---------------- tpi_rs/src/app/flash_context.rs | 275 +++++++++++++++ tpi_rs/src/app/mod.rs | 2 +- tpi_rs/src/c_interface.rs | 7 +- tpi_rs/src/middleware/firmware_update/mod.rs | 3 +- tpi_rs/src/utils/io.rs | 5 + 12 files changed, 591 insertions(+), 451 deletions(-) delete mode 100644 tpi_rs/src/app/flash_application.rs create mode 100644 tpi_rs/src/app/flash_context.rs diff --git a/Cargo.lock b/Cargo.lock index 095e8a3..dc7a57f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,7 +89,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -224,7 +224,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -256,9 +256,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.5" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" +checksum = "0f2135563fb5c609d2b2b87c1e8ce7bc41b0b45430fa9661f457981503dd5bf0" dependencies = [ "memchr", ] @@ -436,6 +436,7 @@ dependencies = [ "log", "nix 0.26.4", "openssl", + "rand 0.8.5", "serde", "serde_json", "serde_yaml", @@ -476,7 +477,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -518,9 +519,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.30" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defd4e7873dbddba6c7c91e199c7fcb946abc4a6a4ac3195400bcfb01b5de877" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ "android-tzdata", "iana-time-zone", @@ -530,18 +531,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.3" +version = "4.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84ed82781cea27b43c9b106a979fe450a13a31aab0500595fb3fc06616de08e6" +checksum = "b1d7b8d5ec32af0fadc644bf1fd509a688c2103b185644bb1e29d164e0703136" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.4.2" +version = "4.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bb9faaa7c2ef94b2743a21f5a29e6f0010dff4caa69ac8e9d6cf8b6fa74da08" +checksum = "5179bb514e4d7c2051749d8fcefa2ed6d06a9f4e6d69faf3805f5d80b8cf8d56" dependencies = [ "anstream", "anstyle", @@ -835,7 +836,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -946,9 +947,9 @@ checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "http" @@ -1130,13 +1131,12 @@ checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" [[package]] name = "local-channel" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f303ec0e94c6c54447f84f3b0ef7af769858a9c4ef56ef2a986d3dcd4c3fc9c" +checksum = "e0a493488de5f18c8ffcba89eebb8532ffc562dc400490eb65b84893fae0b178" dependencies = [ "futures-core", "futures-sink", - "futures-util", "local-waker", ] @@ -1336,7 +1336,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -1709,7 +1709,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -1830,9 +1830,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.33" +version = "2.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9caece70c63bfba29ec2fed841a09851b14a235c60010fa4de58089b6c025668" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" dependencies = [ "proc-macro2", "quote", @@ -1872,7 +1872,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -1947,7 +1947,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", ] [[package]] @@ -2040,9 +2040,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicase" @@ -2142,7 +2142,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", "wasm-bindgen-shared", ] @@ -2164,7 +2164,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.33", + "syn 2.0.37", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/bmcd/Cargo.toml b/bmcd/Cargo.toml index 68ff266..e368561 100644 --- a/bmcd/Cargo.toml +++ b/bmcd/Cargo.toml @@ -23,3 +23,4 @@ tokio.workspace = true tokio-util.workspace = true futures.workspace = true serde.workspace = true +rand = "0.8.5" diff --git a/bmcd/src/flash_service.rs b/bmcd/src/flash_service.rs index 586aaa5..9d3d626 100644 --- a/bmcd/src/flash_service.rs +++ b/bmcd/src/flash_service.rs @@ -1,150 +1,198 @@ use crate::into_legacy_response::LegacyResponse; use actix_web::{http::StatusCode, web::Bytes}; -use anyhow::Context; -use futures::future::BoxFuture; +use rand::Rng; +use serde::{Serialize, Serializer}; use std::{ collections::hash_map::DefaultHasher, error::Error, fmt::Display, hash::{Hash, Hasher}, + ops::{Deref, DerefMut}, sync::Arc, time::{Duration, Instant}, }; -use tokio::sync::mpsc::{channel, error::SendError, Sender}; +use tokio::sync::Mutex; +use tokio::{ + io::AsyncRead, + sync::mpsc::{channel, error::SendError, Sender}, +}; use tokio_util::sync::CancellationToken; -use tpi_rs::app::flash_application::flash_node; -use tpi_rs::{app::bmc_application::BmcApplication, middleware::NodeId, utils::logging_sink}; -use tpi_rs::{app::flash_application::FlashContext, utils::ReceiverReader}; - -pub type FlashDoneFut = BoxFuture<'static, anyhow::Result<()>>; +use tpi_rs::{app::bmc_application::BmcApplication, middleware::NodeId}; +use tpi_rs::{app::flash_context::FlashContext, utils::ReceiverReader}; const RESET_TIMEOUT: Duration = Duration::from_secs(10); -struct TransferContext { - pub peer: u64, - pub bytes_sender: Sender, - pub cancel: CancellationToken, - last_recieved_chunk: Instant, -} - -impl TransferContext { - pub fn duration_since_last_chunk(&self) -> Duration { - Instant::now().saturating_duration_since(self.last_recieved_chunk) - } -} - pub struct FlashService { - status: Option, - bmc: Arc, + status: Arc>, } impl FlashService { - pub fn new(bmc: Arc) -> Self { - Self { status: None, bmc } + pub fn new() -> Self { + Self { + status: Arc::new(Mutex::new(FlashStatus::Ready)), + } } + /// Start a node flash command and initialize [`FlashService`] for chunked + /// file transfer. Calling this function twice results in a + /// `Err(FlashError::InProgress)`. Unless the first file transfer deemed to + /// be stale. In this case the [`FlashService`] will be reset and initialize + /// for a new transfer. A transfer is stale when the `RESET_TIMEOUT` is + /// reached. Meaning no chunk has been received for longer as + /// `RESET_TIMEOUT`. pub async fn start_transfer( - &mut self, + &self, peer: &str, filename: String, size: u64, node: NodeId, - ) -> Result { - if let Some(context) = &self.status { - if context.duration_since_last_chunk() < RESET_TIMEOUT { - return Err(FlashError::InProgress); - } else { - log::warn!( - "Assuming last transfer will never complete as last request was {}s ago. Resetting flash service", - context.duration_since_last_chunk().as_secs() - ); - self.reset(); - } - } + bmc: Arc, + ) -> Result<(), FlashError> { + let mut status = self.status.lock().await; + self.reset_transfer_on_timeout(peer, status.deref_mut())?; + + let mut hasher = DefaultHasher::new(); + peer.hash(&mut hasher); + let peer = hasher.finish(); let (sender, receiver) = channel::(128); - let (progress_sender, progress_receiver) = channel(32); - let done_token = CancellationToken::new(); - let context = FlashContext { + let transfer_context = TransferContext::new(peer, sender); + let cancel = transfer_context.cancel.child_token(); + let id = transfer_context.id; + + let context = FlashContext::new( + id, filename, size, node, - byte_stream: ReceiverReader::new(receiver), - bmc: self.bmc.clone(), - progress_sender, - cancel: done_token.clone(), - }; - - // execute flashing of the image. - let flash_handle = tokio::spawn(flash_node(context)); - logging_sink(progress_receiver); + ReceiverReader::new(receiver), + bmc, + cancel, + ); - let mut hasher = DefaultHasher::new(); - peer.hash(&mut hasher); + Self::run_flash_worker(context, self.status.clone()); + *status = FlashStatus::Transferring(transfer_context); + log::info!("new transfer started. id: {}", id); - let context = TransferContext { - peer: hasher.finish(), - bytes_sender: sender, - cancel: done_token.clone(), - last_recieved_chunk: Instant::now(), - }; - - self.status = Some(context); - - Ok(Box::pin(async move { - let result = flash_handle - .await - .context("join error waiting for flashing to complete"); - done_token.cancel(); - result? - })) + Ok(()) } - pub async fn put_chunk(&mut self, peer: String, data: Bytes) -> Result<(), FlashError> { - if data.is_empty() { - self.reset(); - return Err(FlashError::EmptyPayload); + /// When a 'start_transfer' call is made while we are still in a transfer + /// state, assume that the current transfer is stale given the timeout limit + /// is reached. + fn reset_transfer_on_timeout( + &self, + peer: &str, + mut status: impl DerefMut, + ) -> Result<(), FlashError> { + if let FlashStatus::Transferring(context) = &*status { + let duration = context.duration_since_last_chunk(peer)?; + if duration < RESET_TIMEOUT { + return Err(FlashError::InProgress); + } else { + log::warn!( + "Assuming transfer ({}) will never complete as last request was {}s ago. Resetting flash service", + context.id, + duration.as_secs() + ); + *status = FlashStatus::Ready; + } } + Ok(()) + } - let mut hasher = DefaultHasher::new(); - peer.hash(&mut hasher); - let hashed_peer = hasher.finish(); + /// Worker task that performs the actual node flash. This tasks finishes if + /// one of the following scenario's is met: + /// * flashing completed successfully + /// * flashing was canceled + /// * Error occurred during flashing. + /// + /// Note that the "global" status does not get updated when the task was + /// canceled. Cancel can only be true on a state transition from + /// `FlashStatus::Transferring`, meaning a state transition already + /// happened. In this case we omit a state transition to + /// `FlashSstatus::Error(_)` + fn run_flash_worker( + mut context: FlashContext, + status: Arc>, + ) { + let start_time = Instant::now(); + tokio::spawn(async move { + let (new_state, was_cancelled) = context.flash_node().await.map_or_else( + |e| { + let error = e.to_string(); + let is_cancelled = context.cancel.is_cancelled(); + if is_cancelled { + log::error!("flashing stopped: {}. ({})", error, context.id); + } + (FlashStatus::Error(e.to_string()), is_cancelled) + }, + |_| { + let duration = Instant::now().saturating_duration_since(start_time); + log::info!( + "flashing successful. took {}m{}s. ({})", + duration.as_secs() / 60, + duration.as_secs() % 60, + context.id, + ); - let result = if let Some(context) = &mut self.status { - if context.peer != hashed_peer { - return Err(FlashError::PeersDoNotMatch(peer)); - } + (FlashStatus::Done(duration, context.size), false) + }, + ); - match context.bytes_sender.send(data).await { - Ok(_) => { - context.last_recieved_chunk = Instant::now(); - Ok(()) + let mut status_unlocked = status.lock().await; + if let FlashStatus::Transferring(_) = &*status_unlocked { + if !was_cancelled { + *status_unlocked = new_state; } - Err(_) if context.bytes_sender.is_closed() => Err(FlashError::Aborted), - Err(e) => Err(e.into()), } - } else { - Err(FlashError::TransferNotStarted) - }; + }); + } - if result.is_err() { - self.reset(); - } + /// Write a chunk of bytes to the module that is selected for flashing. + /// + /// # Return + /// + /// This function returns: + /// + /// * 'Err(FlashError::WrongState)' if this function is called when + /// ['FlashService'] is not in 'Transferring' state. + /// * 'Err(FlashError::EmptyPayload)' when data == empty + /// * 'Err(FlashError::Error(_)' when there is an internal error + /// * Ok(()) on success + pub async fn put_chunk(&self, peer: String, data: Bytes) -> Result<(), FlashError> { + let mut status = self.status.lock().await; + if let FlashStatus::Transferring(ref mut context) = *status { + if data.is_empty() { + *status = FlashStatus::Ready; + return Err(FlashError::EmptyPayload); + } - result - } + if let Err(e) = context.push_bytes(peer, data).await { + *status = FlashStatus::Error(e.to_string()); + return Err(e); + } - fn reset(&mut self) { - if let Some(context) = &self.status { - context.cancel.cancel(); + Ok(()) + } else { + log::error!( + "cannot put chunk. state is not transferring. state= {:?}", + status + ); + Err(FlashError::WrongState) } - self.status = None; + } + + /// Return a borrow to the current status of the flash service + /// This object implements [`serde::Serialize`] + pub async fn status(&self) -> impl Deref + '_ { + self.status.lock().await } } -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub enum FlashError { InProgress, - TransferNotStarted, + WrongState, EmptyPayload, PeersDoNotMatch(String), Aborted, @@ -157,11 +205,11 @@ impl Display for FlashError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { FlashError::InProgress => write!(f, "another flashing operation in progress"), - FlashError::TransferNotStarted => { - write!(f, "transfer not started yet, did not expect that command") + FlashError::WrongState => { + write!(f, "cannot execute command in current state") } FlashError::Aborted => write!(f, "flash operation was aborted"), - FlashError::MpscError(_) => write!(f, "internal error sending buffers"), + FlashError::MpscError(e) => write!(f, "internal error sending buffers: {}", e), FlashError::EmptyPayload => write!(f, "received emply payload"), FlashError::PeersDoNotMatch(peer) => { write!(f, "no flash service in progress for {}", peer) @@ -180,7 +228,7 @@ impl From for LegacyResponse { fn from(value: FlashError) -> Self { let status_code = match value { FlashError::InProgress => StatusCode::SERVICE_UNAVAILABLE, - FlashError::TransferNotStarted => StatusCode::BAD_REQUEST, + FlashError::WrongState => StatusCode::BAD_REQUEST, FlashError::MpscError(_) => StatusCode::INTERNAL_SERVER_ERROR, FlashError::Aborted => StatusCode::INTERNAL_SERVER_ERROR, FlashError::EmptyPayload => StatusCode::BAD_REQUEST, @@ -189,3 +237,84 @@ impl From for LegacyResponse { (status_code, value.to_string()).into() } } + +#[derive(Debug, Serialize)] +pub enum FlashStatus { + Ready, + Transferring(TransferContext), + Done(Duration, u64), + Error(String), +} + +/// Context object for node flashing. This object will cancel the node flash +/// cancel-token when it goes out of scope, Aborting the node flash task. +/// Typically happens on a state transition inside the [`FlashService`]. +#[derive(Debug, Serialize)] +pub struct TransferContext { + pub id: u64, + pub peer: u64, + #[serde(skip)] + pub bytes_sender: Sender, + #[serde(skip)] + pub cancel: CancellationToken, + #[serde(serialize_with = "serialize_seconds_until_now")] + last_recieved_chunk: Instant, +} + +fn serialize_seconds_until_now(instant: &Instant, s: S) -> Result +where + S: Serializer, +{ + let secs = Instant::now().saturating_duration_since(*instant).as_secs(); + s.serialize_u64(secs) +} + +impl TransferContext { + pub fn new(peer: u64, bytes_sender: Sender) -> Self { + let mut rng = rand::thread_rng(); + let id = rng.gen(); + + TransferContext { + id, + peer, + bytes_sender, + cancel: CancellationToken::new(), + last_recieved_chunk: Instant::now(), + } + } + + pub fn duration_since_last_chunk(&self, peer: &str) -> Result { + let mut hasher = DefaultHasher::new(); + peer.hash(&mut hasher); + let hashed_peer = hasher.finish(); + if self.peer != hashed_peer { + return Err(FlashError::PeersDoNotMatch(peer.into())); + } + + Ok(Instant::now().saturating_duration_since(self.last_recieved_chunk)) + } + + async fn push_bytes(&mut self, peer: String, data: Bytes) -> Result<(), FlashError> { + let mut hasher = DefaultHasher::new(); + peer.hash(&mut hasher); + let hashed_peer = hasher.finish(); + if self.peer != hashed_peer { + return Err(FlashError::PeersDoNotMatch(peer)); + } + + match self.bytes_sender.send(data).await { + Ok(_) => { + self.last_recieved_chunk = Instant::now(); + Ok(()) + } + Err(_) if self.bytes_sender.is_closed() => Err(FlashError::Aborted), + Err(e) => Err(e.into()), + } + } +} + +impl Drop for TransferContext { + fn drop(&mut self) { + self.cancel.cancel(); + } +} diff --git a/bmcd/src/into_legacy_response.rs b/bmcd/src/into_legacy_response.rs index 8f54448..cd4a042 100644 --- a/bmcd/src/into_legacy_response.rs +++ b/bmcd/src/into_legacy_response.rs @@ -67,6 +67,12 @@ impl From for LegacyResponse { } } +impl From for LegacyResponse { + fn from(value: serde_json::Error) -> Self { + LegacyResponse::Error(StatusCode::INTERNAL_SERVER_ERROR, value.to_string().into()) + } +} + impl ResponseError for LegacyResponse {} impl Display for LegacyResponse { diff --git a/bmcd/src/legacy.rs b/bmcd/src/legacy.rs index 617f623..ffafc8e 100644 --- a/bmcd/src/legacy.rs +++ b/bmcd/src/legacy.rs @@ -9,8 +9,9 @@ use actix_web::{get, web, HttpRequest, Responder}; use anyhow::Context; use nix::sys::statfs::statfs; use serde_json::json; +use std::ops::Deref; use std::str::FromStr; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::mpsc; use tpi_rs::app::bmc_application::{BmcApplication, UsbConfig}; use tpi_rs::middleware::{NodeId, UsbMode, UsbRoute}; use tpi_rs::utils::logging_sink; @@ -29,6 +30,11 @@ const API_VERSION: &str = "1.1"; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::resource("/api/bmc") + .route( + web::get() + .guard(fn_guard(flash_status_guard)) + .to(handle_flash_status), + ) .route( web::get() .guard(fn_guard(flash_guard)) @@ -43,11 +49,21 @@ pub fn info_config(cfg: &mut web::ServiceConfig) { cfg.service(info_handler); } +fn flash_status_guard(context: &GuardContext<'_>) -> bool { + let query = context.head().uri.query(); + query + .map(|q| q.contains("status")) + .and(query.map(|q| q.contains("type=flash"))) + .and(query.map(|q| q.contains("opt=get"))) + .unwrap_or(false) +} + fn flash_guard(context: &GuardContext<'_>) -> bool { let query = context.head().uri.query(); - let is_set = query.map(|q| q.contains("opt=set")).unwrap_or(false); - let is_type = query.map(|q| q.contains("type=flash")).unwrap_or(false); - is_set && is_type + query + .map(|q| q.contains("opt=set")) + .and(query.map(|q| q.contains("type=flash"))) + .unwrap_or(false) } #[get("/api/bmc/info")] @@ -352,8 +368,13 @@ async fn get_usb_mode(bmc: &BmcApplication) -> impl Into { ) } +async fn handle_flash_status(flash: web::Data) -> LegacyResult { + Ok(serde_json::to_string(flash.status().await.deref())?) +} + async fn handle_flash_request( - flash: web::Data>, + flash: web::Data, + bmc: web::Data, request: HttpRequest, query: Query, ) -> LegacyResult { @@ -379,23 +400,15 @@ async fn handle_flash_request( .map(Into::into) .context("peer_addr unknown")?; - let on_done = flash - .lock() - .await - .start_transfer(&peer, file, size, node) + flash + .start_transfer(&peer, file, size, node, bmc.into_inner()) .await?; - tokio::spawn(async move { - if let Err(e) = on_done.await { - log::error!("{}", e); - } - }); - Ok(Null) } async fn handle_chunk( - flash: web::Data>, + flash: web::Data, request: HttpRequest, chunk: Bytes, ) -> LegacyResult { @@ -405,6 +418,6 @@ async fn handle_chunk( .map(Into::into) .context("peer_addr unknown")?; - flash.lock().await.put_chunk(peer, chunk).await?; + flash.put_chunk(peer, chunk).await?; Ok(Null) } diff --git a/bmcd/src/main.rs b/bmcd/src/main.rs index f4e026d..1a073a2 100644 --- a/bmcd/src/main.rs +++ b/bmcd/src/main.rs @@ -10,11 +10,7 @@ use anyhow::Context; use clap::{command, value_parser, Arg}; use log::LevelFilter; use openssl::ssl::SslAcceptorBuilder; -use std::{ - path::{Path, PathBuf}, - sync::Arc, -}; -use tokio::sync::Mutex; +use std::path::{Path, PathBuf}; use tpi_rs::app::{bmc_application::BmcApplication, event_application::run_event_listener}; pub mod config; mod flash_service; @@ -34,10 +30,9 @@ async fn main() -> anyhow::Result<()> { let tls = load_tls_configuration(&config.tls.private_key, &config.tls.certificate)?; let tls6 = load_tls_configuration(&config.tls.private_key, &config.tls.certificate)?; - let bmc = Arc::new(BmcApplication::new().await?); - run_event_listener(bmc.clone())?; - let flash_service = Data::new(Mutex::new(FlashService::new(bmc.clone()))); - let bmc = Data::from(bmc); + let bmc = Data::new(BmcApplication::new().await?); + run_event_listener(bmc.clone().into_inner())?; + let flash_service = Data::new(FlashService::new()); let run_server = HttpServer::new(move || { App::new() @@ -64,6 +59,7 @@ async fn main() -> anyhow::Result<()> { .default_service(web::route().to(redirect)) }) .bind(("0.0.0.0", HTTP_PORT))? + .bind(("::1", HTTP_PORT))? .run(); tokio::try_join!(run_server, redirect_server)?; @@ -89,8 +85,10 @@ fn init_logger() { simple_logger::SimpleLogger::new() .with_level(level) - .with_module_level("bmcd", LevelFilter::Info) + .with_module_level("bmcd", LevelFilter::Debug) .with_module_level("actix_http", LevelFilter::Info) + .with_module_level("h2::codec", LevelFilter::Info) + .with_module_level("h2::proto", LevelFilter::Info) .with_colors(true) .env() .init() diff --git a/tpi_rs/src/app/flash_application.rs b/tpi_rs/src/app/flash_application.rs deleted file mode 100644 index bf99a6a..0000000 --- a/tpi_rs/src/app/flash_application.rs +++ /dev/null @@ -1,289 +0,0 @@ -use super::bmc_application::UsbConfig; -use crate::app::bmc_application::BmcApplication; -use crate::middleware::{ - firmware_update::{FlashProgress, FlashStatus, FlashingError, SUPPORTED_DEVICES}, - NodeId, UsbRoute, -}; -use anyhow::{bail, Context}; -use crc::{Crc, CRC_64_REDIS}; -use futures::TryFutureExt; -use std::{sync::Arc, time::Duration}; -use tokio::{ - fs, - io::{self, AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt}, - sync::mpsc::{channel, Receiver, Sender}, - time::{sleep, Instant}, -}; -use tokio_util::sync::CancellationToken; - -const REBOOT_DELAY: Duration = Duration::from_millis(500); -const BUF_SIZE: u64 = 8 * 1024; -const PROGRESS_REPORT_PERCENT: u64 = 5; - -pub struct FlashContext { - pub filename: String, - pub size: u64, - pub node: NodeId, - pub byte_stream: R, - pub bmc: Arc, - pub progress_sender: Sender, - pub cancel: CancellationToken, -} - -pub async fn flash_node(context: FlashContext) -> anyhow::Result<()> { - let bmc = context.bmc; - let node = context.node; - let progress_sender = context.progress_sender; - let filename = context.filename; - let mut image = context.byte_stream; - let image_size = context.size; - - let mut driver = bmc - .configure_node_for_fwupgrade( - node, - UsbRoute::Bmc, - progress_sender.clone(), - SUPPORTED_DEVICES.keys(), - ) - .await?; - - let mut progress_state = FlashProgress { - message: String::new(), - status: FlashStatus::Setup, - }; - - progress_state.message = format!("Writing {:?}", filename); - progress_sender.send(progress_state.clone()).await?; - - let (img_len, img_checksum) = write_to_device( - &mut image, - image_size, - &mut driver, - &progress_sender, - &context.cancel, - ) - .await?; - - progress_state.message = String::from("Verifying checksum..."); - progress_sender.send(progress_state.clone()).await?; - - driver.seek(std::io::SeekFrom::Start(0)).await?; - - verify_checksum( - img_checksum, - img_len, - &mut driver, - &progress_sender, - &context.cancel, - ) - .await?; - - progress_state.message = String::from("Flashing successful, restarting device..."); - progress_sender.send(progress_state.clone()).await?; - - bmc.activate_slot(!node.to_bitfield(), node.to_bitfield()) - .await?; - - //TODO: we probably want to restore the state prior flashing - bmc.usb_boot(node, false).await?; - bmc.configure_usb(UsbConfig::UsbA(node)).await?; - - sleep(REBOOT_DELAY).await; - - bmc.activate_slot(node.to_bitfield(), node.to_bitfield()) - .await?; - - progress_state.message = String::from("Done"); - progress_sender.send(progress_state).await?; - Ok(()) -} - -async fn write_to_device( - image: &mut R, - image_len: u64, - image_writer: &mut W, - sender: &Sender, - cancel: &CancellationToken, -) -> anyhow::Result<(u64, u64)> -where - W: ?Sized + AsyncWrite + std::marker::Unpin, - R: ?Sized + AsyncRead + std::marker::Unpin, -{ - let reader = image; - let writer = image_writer; - - let mut buffer = vec![0u8; BUF_SIZE as usize]; - let mut total_read = 0; - - let img_crc = Crc::::new(&CRC_64_REDIS); - let mut img_digest = img_crc.digest(); - - let (size_sender, size_receiver) = channel::(32); - tokio::spawn(run_progress_printer( - image_len, - sender.clone(), - size_receiver, - )); - - let mut progress_update_guard = 0u64; - - // Read_exact and write_all is used here to enforce a certian write size to the `image_writer`. - // This function could be further optimized to reduce the amount of awaiting reads/writes, e.g. - // look at the implementation of `tokio::io::copy`. But in the case of an 'rockusb' - // image_writer, writes that are misaligned with the sector-size of its device induces an extra - // buffering penalty. - while total_read < image_len && !cancel.is_cancelled() { - let buf_len = BUF_SIZE.min(image_len - total_read) as usize; - reader.read_exact(&mut buffer[..buf_len]).await?; - - total_read += buf_len as u64; - - // we accept sporadic lost progress updates or in worst case an error - // inside the channel. It should never prevent the writing process from - // completing. - // Updates to the progress printer are throttled with an arbitrary - // value. - if progress_update_guard % 1000 == 0 { - let _ = size_sender.try_send(total_read); - } - progress_update_guard += 1; - - img_digest.update(&buffer[..buf_len]); - writer - .write_all(&buffer[..buf_len]) - .map_err(|e| anyhow::anyhow!("device write error: {}", e)) - .await?; - } - - writer.flush().await?; - - if cancel.is_cancelled() { - bail!("write is cancelled") - } else { - Ok((image_len, img_digest.finalize())) - } -} - -async fn run_progress_printer( - img_len: u64, - logging: Sender, - mut read_reciever: Receiver, -) -> anyhow::Result<()> { - let start_time = Instant::now(); - let mut last_print = 0; - - while let Some(total_read) = read_reciever.recv().await { - let read_percent = 100 * total_read / img_len; - - let progress_counter = read_percent - last_print; - if progress_counter >= PROGRESS_REPORT_PERCENT { - #[allow(clippy::cast_precision_loss)] // This affects files > 4 exabytes long - let read_proportion = (total_read as f64) / (img_len as f64); - - let duration = start_time.elapsed(); - let estimated_end = duration.div_f64(read_proportion); - let estimated_left = estimated_end - duration; - - let est_seconds = estimated_left.as_secs() % 60; - let est_minutes = (estimated_left.as_secs() / 60) % 60; - - let message = format!( - "Progress: {:>2}%, estimated time left: {:02}:{:02}", - read_percent, est_minutes, est_seconds, - ); - - logging - .send(FlashProgress { - status: FlashStatus::Progress { - read_percent: read_percent as usize, - est_minutes, - est_seconds, - }, - message, - }) - .await - .context("progress update error")?; - last_print = read_percent; - } - } - Ok(()) -} - -async fn verify_checksum( - img_checksum: u64, - img_len: u64, - reader: &mut R, - sender: &Sender, - cancel: &CancellationToken, -) -> anyhow::Result<()> -where - R: AsyncRead + std::marker::Unpin, -{ - flush_file_caches().await?; - - let dev_checksum = calc_file_checksum(reader, img_len, cancel).await?; - - if img_checksum == dev_checksum { - Ok(()) - } else { - sender - .send(FlashProgress { - status: FlashStatus::Error(FlashingError::ChecksumMismatch), - message: format!( - "Source and destination checksum mismatch: {:#x} != {:#x}", - img_checksum, dev_checksum - ), - }) - .await?; - - bail!(FlashingError::ChecksumMismatch) - } -} - -async fn flush_file_caches() -> io::Result<()> { - let mut file = fs::OpenOptions::new() - .write(true) - .open("/proc/sys/vm/drop_caches") - .await?; - - // Free reclaimable slab objects and page cache - file.write_u8(b'3').await -} - -// This function and `write_to_device()` could be merged into one with an optional callback for -// every chunk read, but async closures are unstable and async blocks seem to require a Mutex. -async fn calc_file_checksum( - reader: &mut R, - total_size: u64, - cancel: &CancellationToken, -) -> anyhow::Result -where - R: AsyncRead + std::marker::Unpin, -{ - let mut reader = io::BufReader::with_capacity(BUF_SIZE as usize, reader); - - let mut buffer = vec![0u8; BUF_SIZE as usize]; - let mut total_read = 0; - - let crc = Crc::::new(&CRC_64_REDIS); - let mut digest = crc.digest(); - - while total_read < total_size && !cancel.is_cancelled() { - let bytes_left = total_size - total_read; - let buffer_size = buffer.len().min(bytes_left as usize); - let num_read = reader.read(&mut buffer[..buffer_size]).await?; - if num_read == 0 { - log::error!("read 0 bytes with {} bytes to go", bytes_left); - bail!(FlashingError::IoError); - } - - total_read += num_read as u64; - digest.update(&buffer[..num_read]); - } - - if cancel.is_cancelled() { - bail!("checksum calculation is cancelled"); - } else { - Ok(digest.finalize()) - } -} diff --git a/tpi_rs/src/app/flash_context.rs b/tpi_rs/src/app/flash_context.rs new file mode 100644 index 0000000..4fec9a5 --- /dev/null +++ b/tpi_rs/src/app/flash_context.rs @@ -0,0 +1,275 @@ +use super::bmc_application::UsbConfig; +use crate::app::bmc_application::BmcApplication; +use crate::middleware::{ + firmware_update::{FlashProgress, FlashStatus, FlashingError, SUPPORTED_DEVICES}, + NodeId, UsbRoute, +}; +use crate::utils::logging_sink; +use anyhow::{bail, ensure, Context}; +use crc::{Crc, CRC_64_REDIS}; +use futures::TryFutureExt; +use std::{sync::Arc, time::Duration}; +use tokio::{ + fs, + io::{self, AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt}, + sync::mpsc::{channel, Receiver, Sender}, + time::{sleep, Instant}, +}; +use tokio_util::sync::CancellationToken; + +const REBOOT_DELAY: Duration = Duration::from_millis(500); +const BUF_SIZE: u64 = 8 * 1024; +const PROGRESS_REPORT_PERCENT: u64 = 5; + +pub struct FlashContext { + pub id: u64, + pub filename: String, + pub size: u64, + pub node: NodeId, + pub byte_stream: R, + pub bmc: Arc, + pub progress_sender: Sender, + pub cancel: CancellationToken, +} + +impl FlashContext { + pub fn new( + id: u64, + filename: String, + size: u64, + node: NodeId, + byte_stream: R, + bmc: Arc, + cancel: CancellationToken, + ) -> Self { + let (progress_sender, progress_receiver) = channel(32); + logging_sink(progress_receiver); + + Self { + id, + filename, + size, + node, + byte_stream, + bmc, + progress_sender, + cancel, + } + } + + pub async fn flash_node(&mut self) -> anyhow::Result<()> { + let mut device = self + .bmc + .configure_node_for_fwupgrade( + self.node, + UsbRoute::Bmc, + self.progress_sender.clone(), + SUPPORTED_DEVICES.keys(), + ) + .await?; + + let mut progress_state = FlashProgress { + message: String::new(), + status: FlashStatus::Setup, + }; + + progress_state.message = format!("Writing {:?}", self.filename); + self.progress_sender.send(progress_state.clone()).await?; + + let img_checksum = self.write_to_device(&mut device).await?; + + progress_state.message = String::from("Verifying checksum..."); + self.progress_sender.send(progress_state.clone()).await?; + + device.seek(std::io::SeekFrom::Start(0)).await?; + flush_file_caches().await?; + + self.verify_checksum(&mut device, img_checksum).await?; + + progress_state.message = String::from("Flashing successful, restarting device..."); + self.progress_sender.send(progress_state.clone()).await?; + + self.bmc + .activate_slot(!self.node.to_bitfield(), self.node.to_bitfield()) + .await?; + + //TODO: we probably want to restore the state prior flashing + self.bmc.usb_boot(self.node, false).await?; + self.bmc.configure_usb(UsbConfig::UsbA(self.node)).await?; + + sleep(REBOOT_DELAY).await; + + self.bmc + .activate_slot(self.node.to_bitfield(), self.node.to_bitfield()) + .await?; + + progress_state.message = String::from("Done"); + self.progress_sender.send(progress_state).await?; + Ok(()) + } + + async fn write_to_device( + &mut self, + mut device: W, + ) -> anyhow::Result { + let mut buffer = vec![0u8; BUF_SIZE as usize]; + let mut total_read = 0; + + let img_crc = Crc::::new(&CRC_64_REDIS); + let mut img_digest = img_crc.digest(); + + let (size_sender, size_receiver) = channel::(32); + tokio::spawn(run_progress_printer( + self.size, + self.progress_sender.clone(), + size_receiver, + )); + + let mut progress_update_guard = 0u64; + + // Read_exact and write_all is used here to enforce a certain write size to the `image_writer`. + // This function could be further optimized to reduce the amount of awaiting reads/writes, e.g. + // look at the implementation of `tokio::io::copy`. But in the case of an 'rockusb' + // image_writer, writes that are misaligned with the sector-size of its device induces an extra + // buffering penalty. + while total_read < self.size { + let buf_len = BUF_SIZE.min(self.size - total_read) as usize; + let read_task = self.byte_stream.read_exact(&mut buffer[..buf_len]); + let monitor_cancel = self.cancel.cancelled(); + + tokio::select! { + // self.bytes_stream is not guaranteed to be "cancel-safe". Canceling read_task + // might result in a data loss. We allow this as a cancel aborts the whole file + // file transfer + res = read_task => ensure!(res? == buf_len), + _ = monitor_cancel => bail!("write cancelled"), + } + + total_read += buf_len as u64; + + // we accept sporadic lost progress updates or in worst case an error + // inside the channel. It should never prevent the writing process from + // completing. + // Updates to the progress printer are throttled with an arbitrary + // value. + if progress_update_guard % 1000 == 0 { + let _ = size_sender.try_send(total_read); + } + progress_update_guard += 1; + + img_digest.update(&buffer[..buf_len]); + device + .write_all(&buffer[..buf_len]) + .map_err(|e| anyhow::anyhow!("device write error: {}", e)) + .await?; + } + + device.flush().await?; + + Ok(img_digest.finalize()) + } + + async fn verify_checksum(&self, reader: L, img_checksum: u64) -> anyhow::Result<()> + where + L: AsyncRead + std::marker::Unpin, + { + let mut reader = io::BufReader::with_capacity(BUF_SIZE as usize, reader); + + let mut buffer = vec![0u8; BUF_SIZE as usize]; + let mut total_read = 0; + + let crc = Crc::::new(&CRC_64_REDIS); + let mut digest = crc.digest(); + + while total_read < self.size { + let bytes_left = self.size - total_read; + let buffer_size = buffer.len().min(bytes_left as usize); + let read_task = reader.read(&mut buffer[..buffer_size]); + let monitor_cancel = self.cancel.cancelled(); + + let num_read = tokio::select! { + res = read_task => res?, + _ = monitor_cancel => bail!("checksum calculation is cancelled"), + }; + + if num_read == 0 { + log::error!("read 0 bytes with {} bytes to go", bytes_left); + bail!(FlashingError::IoError); + } + + total_read += num_read as u64; + digest.update(&buffer[..num_read]); + } + + let dev_checksum = digest.finalize(); + if img_checksum != dev_checksum { + self.progress_sender + .send(FlashProgress { + status: FlashStatus::Error(FlashingError::ChecksumMismatch), + message: format!( + "Source and destination checksum mismatch: {:#x} != {:#x}", + img_checksum, dev_checksum + ), + }) + .await?; + + bail!(FlashingError::ChecksumMismatch) + } + Ok(()) + } +} + +async fn run_progress_printer( + img_len: u64, + logging: Sender, + mut read_reciever: Receiver, +) -> anyhow::Result<()> { + let start_time = Instant::now(); + let mut last_print = 0; + + while let Some(total_read) = read_reciever.recv().await { + let read_percent = 100 * total_read / img_len; + + let progress_counter = read_percent - last_print; + if progress_counter >= PROGRESS_REPORT_PERCENT { + #[allow(clippy::cast_precision_loss)] // This affects files > 4 exabytes long + let read_proportion = (total_read as f64) / (img_len as f64); + + let duration = start_time.elapsed(); + let estimated_end = duration.div_f64(read_proportion); + let estimated_left = estimated_end - duration; + + let est_seconds = estimated_left.as_secs() % 60; + let est_minutes = (estimated_left.as_secs() / 60) % 60; + + let message = format!( + "Progress: {:>2}%, estimated time left: {:02}:{:02}", + read_percent, est_minutes, est_seconds, + ); + + logging + .send(FlashProgress { + status: FlashStatus::Progress { + read_percent: read_percent as usize, + est_minutes, + est_seconds, + }, + message, + }) + .await + .context("progress update error")?; + last_print = read_percent; + } + } + Ok(()) +} + +async fn flush_file_caches() -> io::Result<()> { + let mut file = fs::OpenOptions::new() + .write(true) + .open("/proc/sys/vm/drop_caches") + .await?; + + // Free reclaimable slab objects and page cache + file.write_u8(b'3').await +} diff --git a/tpi_rs/src/app/mod.rs b/tpi_rs/src/app/mod.rs index c6a0966..e900f7d 100644 --- a/tpi_rs/src/app/mod.rs +++ b/tpi_rs/src/app/mod.rs @@ -1,3 +1,3 @@ pub mod bmc_application; pub mod event_application; -pub mod flash_application; +pub mod flash_context; diff --git a/tpi_rs/src/c_interface.rs b/tpi_rs/src/c_interface.rs index 98b2be1..2b9a028 100644 --- a/tpi_rs/src/c_interface.rs +++ b/tpi_rs/src/c_interface.rs @@ -17,7 +17,7 @@ use tokio_util::sync::CancellationToken; use crate::app::bmc_application::{BmcApplication, UsbConfig}; use crate::app::event_application::run_event_listener; -use crate::app::flash_application::{flash_node, FlashContext}; +use crate::app::flash_context::FlashContext; use crate::middleware::firmware_update::FlashingError; use crate::middleware::{UsbMode, UsbRoute}; @@ -225,7 +225,8 @@ pub unsafe extern "C" fn tpi_flash_node(node: c_int, image_path: *const c_char) let (sender, receiver) = channel(64); let img_file = tokio::fs::File::open(&node_image).await.unwrap(); let img_len = img_file.metadata().await.unwrap().len(); - let context = FlashContext { + let mut context = FlashContext { + id: 123, filename: node_image .file_name() .unwrap() @@ -239,7 +240,7 @@ pub unsafe extern "C" fn tpi_flash_node(node: c_int, image_path: *const c_char) cancel: CancellationToken::new(), }; - let handle = tokio::spawn(flash_node(context)); + let handle = tokio::spawn(async move { context.flash_node().await }); let print_handle = logging_sink(receiver); let (res, _) = join!(handle, print_handle); diff --git a/tpi_rs/src/middleware/firmware_update/mod.rs b/tpi_rs/src/middleware/firmware_update/mod.rs index 2c5bc82..a2c17f3 100644 --- a/tpi_rs/src/middleware/firmware_update/mod.rs +++ b/tpi_rs/src/middleware/firmware_update/mod.rs @@ -2,6 +2,7 @@ mod rockusb_fwudate; use self::rockusb_fwudate::new_rockusb_transport; mod rpi_fwupdate; use rpi_fwupdate::new_rpi_transport; +use serde::Serialize; pub mod transport; use self::transport::FwUpdateTransport; use futures::future::BoxFuture; @@ -57,7 +58,7 @@ pub fn fw_update_transport( .ok_or(anyhow::anyhow!("no driver available for {:?}", device)) } -#[derive(Debug, Clone, Copy)] +#[derive(Serialize, Debug, Clone, Copy)] pub enum FlashingError { InvalidArgs, DeviceNotFound, diff --git a/tpi_rs/src/utils/io.rs b/tpi_rs/src/utils/io.rs index aa3c9bb..897fcc1 100644 --- a/tpi_rs/src/utils/io.rs +++ b/tpi_rs/src/utils/io.rs @@ -7,6 +7,11 @@ use tokio::{ /// This struct wraps a [tokio::sync::mpsc::Receiver] and transforms that /// exposes a [AsyncRead] interface. +/// +/// # Cancel +/// +/// This struct is *not* cancel safe! Using this struct in a tokio::select loop +/// can cause data loss. pub struct ReceiverReader where T: Deref,