diff --git a/Cargo.lock b/Cargo.lock index c1e2ccf..3f738cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -396,6 +396,7 @@ dependencies = [ "actix-web", "anyhow", "build-time", + "futures", "if-addrs", "log", "mime", diff --git a/bmcd/Cargo.toml b/bmcd/Cargo.toml index 14430d4..8f6f7db 100644 --- a/bmcd/Cargo.toml +++ b/bmcd/Cargo.toml @@ -18,3 +18,4 @@ log.workspace = true simple_logger.workspace = true tokio = { workspace = true, features = ["net"] } mime = "0.3.17" +futures = "0.3.28" diff --git a/bmcd/src/flash_service.rs b/bmcd/src/flash_service.rs index a258d2b..9f7d4c6 100644 --- a/bmcd/src/flash_service.rs +++ b/bmcd/src/flash_service.rs @@ -1,7 +1,16 @@ #![allow(dead_code, unused)] use crate::into_legacy_response::LegacyResponse; use actix_web::{http::StatusCode, web::Bytes}; -use std::{error::Error, fmt::Display, sync::Arc}; +use anyhow::Context; +use futures::future::BoxFuture; +use futures::TryFutureExt; +use std::{ + collections::hash_map::DefaultHasher, + error::Error, + fmt::Display, + hash::{Hash, Hasher}, + sync::Arc, +}; use tokio::{ io::{AsyncRead, BufReader}, sync::mpsc::{channel, error::SendError, Receiver, Sender}, @@ -9,12 +18,14 @@ use tokio::{ use tpi_rs::{ app::bmc_application::BmcApplication, middleware::{firmware_update::SUPPORTED_DEVICES, NodeId, UsbRoute}, + utils::logging_sink, }; use tpi_rs::{app::flash_application::flash_node, middleware::firmware_update::FlashStatus}; use tpi_rs::{app::flash_application::FlashContext, utils::ReceiverReader}; +pub type FlashDoneFut = BoxFuture<'static, anyhow::Result<()>>; pub struct FlashService { - status: Option>, + status: Option<(u64, Sender)>, bmc: Arc, } @@ -25,16 +36,19 @@ impl FlashService { pub async fn start_transfer( &mut self, + peer: &str, filename: String, size: usize, node: NodeId, - ) -> Result<(), FlashError> { + ) -> Result { if self.status.is_some() { return Err(FlashError::InProgress); } let (sender, receiver) = channel::(128); let (progress_sender, progress_receiver) = channel(32); + logging_sink(progress_receiver); + let context = FlashContext { filename, size, @@ -45,14 +59,24 @@ impl FlashService { }; /// execute the flashing of the image. - tokio::spawn(flash_node(context)); + let flash_handle = tokio::spawn(flash_node(context)); - self.status = Some(sender); - Ok(()) + let mut hasher = DefaultHasher::new(); + peer.hash(&mut hasher); + self.status = Some((hasher.finish(), sender)); + Ok(Box::pin(async move { + flash_handle + .await + .context("join error waiting for flashing to complete")? + })) } - pub async fn stream_chunk(&mut self, data: Bytes) -> Result<(), FlashError> { - if let Some(sender) = &self.status { + pub async fn stream_chunk(&mut self, peer: &str, data: Bytes) -> Result<(), FlashError> { + let mut hasher = DefaultHasher::new(); + peer.hash(&mut hasher); + let hash = hasher.finish(); + + if let Some((hash, sender)) = &self.status { match sender.send(data).await { Ok(_) => Ok(()), Err(e) if sender.is_closed() => Err(FlashError::Aborted), @@ -62,6 +86,10 @@ impl FlashService { Err(FlashError::UnexpectedCommand) } } + + pub fn reset(&mut self) { + self.status = None; + } } #[derive(Debug, PartialEq)] diff --git a/bmcd/src/legacy.rs b/bmcd/src/legacy.rs index 14bac8e..6f8c5cb 100644 --- a/bmcd/src/legacy.rs +++ b/bmcd/src/legacy.rs @@ -1,11 +1,12 @@ //! Routes for legacy API present in versions <= 1.1.0 of the firmware. use crate::flash_service::FlashService; -use crate::into_legacy_response::LegacyResult; use crate::into_legacy_response::{IntoLegacyResponse, LegacyResponse}; -use actix_web::http::header::{CONTENT_ENCODING, TRANSFER_ENCODING}; +use crate::into_legacy_response::{LegacyResult, Null}; +use actix_web::guard::{fn_guard, GuardContext}; +use actix_web::http::header::CONTENT_TYPE; use actix_web::http::StatusCode; use actix_web::web::Bytes; -use actix_web::{web, HttpRequest, HttpResponse}; +use actix_web::{web, HttpRequest, Responder}; use anyhow::Context; use nix::sys::statfs::statfs; use serde_json::json; @@ -13,26 +14,43 @@ use std::str::FromStr; use tokio::sync::{mpsc, Mutex}; use tpi_rs::app::bmc_application::{BmcApplication, UsbConfig}; use tpi_rs::middleware::{NodeId, UsbMode, UsbRoute}; - type Query = web::Query>; pub fn config(cfg: &mut web::ServiceConfig) { cfg.service( web::resource("/api/bmc") - .route(web::get().to(api_entry)) - .route(web::post().to(api_post)), + .route( + web::route() + .guard(fn_guard(flash_guard)) + .to(handle_flash_request), + ) + .route(web::get().to(api_entry)), ); } -async fn api_entry(bmc: web::Data, query: Query) -> HttpResponse { +fn flash_guard(context: &GuardContext<'_>) -> bool { + let is_set = context + .head() + .uri + .query() + .is_some_and(|q| q.contains("opt=set")); + let is_type = context + .head() + .uri + .query() + .is_some_and(|q| q.contains("type=flash")); + is_set && is_type +} + +async fn api_entry(bmc: web::Data, query: Query) -> impl Responder { let is_set = match query.get("opt").map(String::as_str) { Some("set") => true, Some("get") => false, - _ => return LegacyResponse::bad_request("Missing `opt` parameter").into(), + _ => return LegacyResponse::bad_request("Missing `opt` parameter"), }; let Some(ty) = query.get("type") else { - return LegacyResponse::bad_request("Missing `opt` parameter").into() + return LegacyResponse::bad_request("Missing `opt` parameter") }; let bmc = bmc.as_ref(); @@ -41,7 +59,7 @@ async fn api_entry(bmc: web::Data, query: Query) -> HttpResponse ("network", true) => reset_network(bmc).await.legacy_response(), ("nodeinfo", true) => set_node_info().legacy_response(), ("nodeinfo", false) => get_node_info(bmc).legacy_response(), - ("node_to_msd", true) => set_node_to_msd(bmc, query).await.legacy_response(), + ("node_to_msd", true) => set_node_to_msd(bmc, query).await.into(), ("other", false) => get_system_information().await.legacy_response(), ("power", true) => set_node_power(bmc, query).await.legacy_response(), ("power", false) => get_node_power(bmc).await.legacy_response(), @@ -49,11 +67,14 @@ async fn api_entry(bmc: web::Data, query: Query) -> HttpResponse ("sdcard", false) => get_sdcard_info(), ("uart", true) => write_to_uart(bmc, query).legacy_response(), ("uart", false) => read_from_uart(bmc, query).legacy_response(), - ("usb", true) => set_usb_mode(bmc, query).await.legacy_response(), + ("usb", true) => set_usb_mode(bmc, query).await.into(), ("usb", false) => get_usb_mode(bmc).await.into(), - _ => LegacyResponse::bad_request("Invalid `type` parameter"), + _ => ( + StatusCode::BAD_REQUEST, + format!("Invalid `type` parameter {}", ty), + ) + .legacy_response(), } - .into() } fn clear_usb_boot(bmc: &BmcApplication) -> impl IntoLegacyResponse { @@ -312,41 +333,25 @@ async fn get_usb_mode(bmc: &BmcApplication) -> anyhow::Result>, - chunk: Bytes, - request: HttpRequest, - query: Query, -) -> HttpResponse { - if query.get("opt").map(String::as_ref) != Some("set") { - return LegacyResponse::bad_request("Invalid `opt` parameter").into(); - } - - let Some(ty) = query.get("type") else { - return LegacyResponse::bad_request("Missing `type` parameter").into(); - }; - - match ty.as_ref() { - "firmware" => handle_flash_request(flash, request, chunk, query) - .await - .legacy_response(), - "flash" => LegacyResponse::stub(), - _ => LegacyResponse::bad_request("Invalid `type` parameter"), - } - .into() -} - async fn handle_flash_request( flash: web::Data>, request: HttpRequest, chunk: Bytes, query: Query, -) -> LegacyResult<()> { +) -> LegacyResult { let mut flash_service = flash.lock().await; if is_stream_chunck(&request) { - (*flash_service).stream_chunk(chunk).await?; - return Ok(()); + (*flash_service) + .stream_chunk( + request + .connection_info() + .peer_addr() + .context("peer_addr unknown")?, + chunk, + ) + .await?; + return Ok(Null); } let node = get_node_param(&query)?; @@ -365,21 +370,34 @@ async fn handle_flash_request( let size = usize::from_str(size) .map_err(|_| LegacyResponse::bad_request("`lenght` parameter not a number"))?; - (*flash_service) - .start_transfer(file, size, node) + let on_done = (*flash_service) + .start_transfer( + request + .connection_info() + .peer_addr() + .context("peer_addr unknown")?, + file, + size, + node, + ) .await - .map_err(|e| (StatusCode::SERVICE_UNAVAILABLE, format!("{}", e)).into()) + .map_err(|e| (StatusCode::SERVICE_UNAVAILABLE, format!("{}", e)).legacy_response())?; + + let service = flash.clone(); + tokio::spawn(async move { + if let Err(e) = on_done.await { + log::error!("{}", e); + } + service.lock().await.reset(); + }); + + Ok(Null) } fn is_stream_chunck(request: &HttpRequest) -> bool { request .headers() - .get(TRANSFER_ENCODING) + .get(CONTENT_TYPE) .map(|v| v.to_str().unwrap()) - == Some("chunked") - && request - .headers() - .get(CONTENT_ENCODING) - .map(|v| v.to_str().unwrap()) - == Some(mime::APPLICATION_OCTET_STREAM.essence_str()) + == Some("application/octet-stream") } diff --git a/bmcd/src/main.rs b/bmcd/src/main.rs index 6901c81..295060e 100644 --- a/bmcd/src/main.rs +++ b/bmcd/src/main.rs @@ -16,11 +16,13 @@ async fn main() -> anyhow::Result<()> { let bmc = Data::new(BmcApplication::new().await?); run_event_listener(bmc.deref().clone())?; + let flash_service = Data::new(Mutex::new(FlashService::new(bmc.deref().clone()))); + HttpServer::new(move || { App::new() // Shared state: BmcApplication instance .app_data(bmc.clone()) - .app_data(Mutex::new(FlashService::new(bmc.deref().clone()))) + .app_data(flash_service.clone()) // Legacy API .configure(legacy::config) // Enable logger diff --git a/tpi_rs/src/utils/mod.rs b/tpi_rs/src/utils/mod.rs index 430a6cc..07c30bf 100644 --- a/tpi_rs/src/utils/mod.rs +++ b/tpi_rs/src/utils/mod.rs @@ -1,5 +1,20 @@ mod event_listener; +use std::fmt::Display; + #[doc(inline)] pub use event_listener::*; mod io; pub use io::*; +use tokio::sync::mpsc::Receiver; + +// for now we print the status updates to console. In the future we would like to pass +// this back to the clients. +pub fn logging_sink( + mut receiver: Receiver, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + while let Some(msg) = receiver.recv().await { + log::info!("{}", msg); + } + }) +}