Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
flash_service: stability and improvements
Browse files Browse the repository at this point in the history
* added a API call so a client can ask the current status of the
  `FlashService`
* fixes regarding cancelling of the node flash
* cleanup of code
svenrademakers committed Sep 20, 2023
1 parent b5f5579 commit b8abd08
Showing 12 changed files with 591 additions and 451 deletions.
54 changes: 27 additions & 27 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions bmcd/Cargo.toml
Original file line number Diff line number Diff line change
@@ -23,3 +23,4 @@ tokio.workspace = true
tokio-util.workspace = true
futures.workspace = true
serde.workspace = true
rand = "0.8.5"
335 changes: 232 additions & 103 deletions bmcd/src/flash_service.rs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions bmcd/src/into_legacy_response.rs
Original file line number Diff line number Diff line change
@@ -67,6 +67,12 @@ impl From<anyhow::Error> for LegacyResponse {
}
}

impl From<serde_json::Error> for LegacyResponse {
fn from(value: serde_json::Error) -> Self {
LegacyResponse::Error(StatusCode::INTERNAL_SERVER_ERROR, value.to_string().into())
}
}

impl ResponseError for LegacyResponse {}

impl Display for LegacyResponse {
47 changes: 30 additions & 17 deletions bmcd/src/legacy.rs
Original file line number Diff line number Diff line change
@@ -9,8 +9,9 @@ use actix_web::{get, web, HttpRequest, Responder};
use anyhow::Context;
use nix::sys::statfs::statfs;
use serde_json::json;
use std::ops::Deref;
use std::str::FromStr;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::mpsc;
use tpi_rs::app::bmc_application::{BmcApplication, UsbConfig};
use tpi_rs::middleware::{NodeId, UsbMode, UsbRoute};
use tpi_rs::utils::logging_sink;
@@ -29,6 +30,11 @@ const API_VERSION: &str = "1.1";
pub fn config(cfg: &mut web::ServiceConfig) {
cfg.service(
web::resource("/api/bmc")
.route(
web::get()
.guard(fn_guard(flash_status_guard))
.to(handle_flash_status),
)
.route(
web::get()
.guard(fn_guard(flash_guard))
@@ -43,11 +49,21 @@ pub fn info_config(cfg: &mut web::ServiceConfig) {
cfg.service(info_handler);
}

fn flash_status_guard(context: &GuardContext<'_>) -> bool {
let query = context.head().uri.query();
query
.map(|q| q.contains("status"))
.and(query.map(|q| q.contains("type=flash")))
.and(query.map(|q| q.contains("opt=get")))
.unwrap_or(false)
}

fn flash_guard(context: &GuardContext<'_>) -> bool {
let query = context.head().uri.query();
let is_set = query.map(|q| q.contains("opt=set")).unwrap_or(false);
let is_type = query.map(|q| q.contains("type=flash")).unwrap_or(false);
is_set && is_type
query
.map(|q| q.contains("opt=set"))
.and(query.map(|q| q.contains("type=flash")))
.unwrap_or(false)
}

#[get("/api/bmc/info")]
@@ -352,8 +368,13 @@ async fn get_usb_mode(bmc: &BmcApplication) -> impl Into<LegacyResponse> {
)
}

async fn handle_flash_status(flash: web::Data<FlashService>) -> LegacyResult<String> {
Ok(serde_json::to_string(flash.status().await.deref())?)
}

async fn handle_flash_request(
flash: web::Data<Mutex<FlashService>>,
flash: web::Data<FlashService>,
bmc: web::Data<BmcApplication>,
request: HttpRequest,
query: Query,
) -> LegacyResult<Null> {
@@ -379,23 +400,15 @@ async fn handle_flash_request(
.map(Into::into)
.context("peer_addr unknown")?;

let on_done = flash
.lock()
.await
.start_transfer(&peer, file, size, node)
flash
.start_transfer(&peer, file, size, node, bmc.into_inner())
.await?;

tokio::spawn(async move {
if let Err(e) = on_done.await {
log::error!("{}", e);
}
});

Ok(Null)
}

async fn handle_chunk(
flash: web::Data<Mutex<FlashService>>,
flash: web::Data<FlashService>,
request: HttpRequest,
chunk: Bytes,
) -> LegacyResult<Null> {
@@ -405,6 +418,6 @@ async fn handle_chunk(
.map(Into::into)
.context("peer_addr unknown")?;

flash.lock().await.put_chunk(peer, chunk).await?;
flash.put_chunk(peer, chunk).await?;
Ok(Null)
}
18 changes: 8 additions & 10 deletions bmcd/src/main.rs
Original file line number Diff line number Diff line change
@@ -10,11 +10,7 @@ use anyhow::Context;
use clap::{command, value_parser, Arg};
use log::LevelFilter;
use openssl::ssl::SslAcceptorBuilder;
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use tokio::sync::Mutex;
use std::path::{Path, PathBuf};
use tpi_rs::app::{bmc_application::BmcApplication, event_application::run_event_listener};
pub mod config;
mod flash_service;
@@ -34,10 +30,9 @@ async fn main() -> anyhow::Result<()> {
let tls = load_tls_configuration(&config.tls.private_key, &config.tls.certificate)?;
let tls6 = load_tls_configuration(&config.tls.private_key, &config.tls.certificate)?;

let bmc = Arc::new(BmcApplication::new().await?);
run_event_listener(bmc.clone())?;
let flash_service = Data::new(Mutex::new(FlashService::new(bmc.clone())));
let bmc = Data::from(bmc);
let bmc = Data::new(BmcApplication::new().await?);
run_event_listener(bmc.clone().into_inner())?;
let flash_service = Data::new(FlashService::new());

let run_server = HttpServer::new(move || {
App::new()
@@ -64,6 +59,7 @@ async fn main() -> anyhow::Result<()> {
.default_service(web::route().to(redirect))
})
.bind(("0.0.0.0", HTTP_PORT))?
.bind(("::1", HTTP_PORT))?
.run();

tokio::try_join!(run_server, redirect_server)?;
@@ -89,8 +85,10 @@ fn init_logger() {

simple_logger::SimpleLogger::new()
.with_level(level)
.with_module_level("bmcd", LevelFilter::Info)
.with_module_level("bmcd", LevelFilter::Debug)
.with_module_level("actix_http", LevelFilter::Info)
.with_module_level("h2::codec", LevelFilter::Info)
.with_module_level("h2::proto", LevelFilter::Info)
.with_colors(true)
.env()
.init()
289 changes: 0 additions & 289 deletions tpi_rs/src/app/flash_application.rs

This file was deleted.

275 changes: 275 additions & 0 deletions tpi_rs/src/app/flash_context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
use super::bmc_application::UsbConfig;
use crate::app::bmc_application::BmcApplication;
use crate::middleware::{
firmware_update::{FlashProgress, FlashStatus, FlashingError, SUPPORTED_DEVICES},
NodeId, UsbRoute,
};
use crate::utils::logging_sink;
use anyhow::{bail, ensure, Context};
use crc::{Crc, CRC_64_REDIS};
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::{
fs,
io::{self, AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt},
sync::mpsc::{channel, Receiver, Sender},
time::{sleep, Instant},
};
use tokio_util::sync::CancellationToken;

const REBOOT_DELAY: Duration = Duration::from_millis(500);
const BUF_SIZE: u64 = 8 * 1024;
const PROGRESS_REPORT_PERCENT: u64 = 5;

pub struct FlashContext<R: AsyncRead> {
pub id: u64,
pub filename: String,
pub size: u64,
pub node: NodeId,
pub byte_stream: R,
pub bmc: Arc<BmcApplication>,
pub progress_sender: Sender<FlashProgress>,
pub cancel: CancellationToken,
}

impl<R: AsyncRead + Unpin> FlashContext<R> {
pub fn new(
id: u64,
filename: String,
size: u64,
node: NodeId,
byte_stream: R,
bmc: Arc<BmcApplication>,
cancel: CancellationToken,
) -> Self {
let (progress_sender, progress_receiver) = channel(32);
logging_sink(progress_receiver);

Self {
id,
filename,
size,
node,
byte_stream,
bmc,
progress_sender,
cancel,
}
}

pub async fn flash_node(&mut self) -> anyhow::Result<()> {
let mut device = self
.bmc
.configure_node_for_fwupgrade(
self.node,
UsbRoute::Bmc,
self.progress_sender.clone(),
SUPPORTED_DEVICES.keys(),
)
.await?;

let mut progress_state = FlashProgress {
message: String::new(),
status: FlashStatus::Setup,
};

progress_state.message = format!("Writing {:?}", self.filename);
self.progress_sender.send(progress_state.clone()).await?;

let img_checksum = self.write_to_device(&mut device).await?;

progress_state.message = String::from("Verifying checksum...");
self.progress_sender.send(progress_state.clone()).await?;

device.seek(std::io::SeekFrom::Start(0)).await?;
flush_file_caches().await?;

self.verify_checksum(&mut device, img_checksum).await?;

progress_state.message = String::from("Flashing successful, restarting device...");
self.progress_sender.send(progress_state.clone()).await?;

self.bmc
.activate_slot(!self.node.to_bitfield(), self.node.to_bitfield())
.await?;

//TODO: we probably want to restore the state prior flashing
self.bmc.usb_boot(self.node, false).await?;
self.bmc.configure_usb(UsbConfig::UsbA(self.node)).await?;

sleep(REBOOT_DELAY).await;

self.bmc
.activate_slot(self.node.to_bitfield(), self.node.to_bitfield())
.await?;

progress_state.message = String::from("Done");
self.progress_sender.send(progress_state).await?;
Ok(())
}

async fn write_to_device<W: AsyncWrite + Unpin>(
&mut self,
mut device: W,
) -> anyhow::Result<u64> {
let mut buffer = vec![0u8; BUF_SIZE as usize];
let mut total_read = 0;

let img_crc = Crc::<u64>::new(&CRC_64_REDIS);
let mut img_digest = img_crc.digest();

let (size_sender, size_receiver) = channel::<u64>(32);
tokio::spawn(run_progress_printer(
self.size,
self.progress_sender.clone(),
size_receiver,
));

let mut progress_update_guard = 0u64;

// Read_exact and write_all is used here to enforce a certain write size to the `image_writer`.
// This function could be further optimized to reduce the amount of awaiting reads/writes, e.g.
// look at the implementation of `tokio::io::copy`. But in the case of an 'rockusb'
// image_writer, writes that are misaligned with the sector-size of its device induces an extra
// buffering penalty.
while total_read < self.size {
let buf_len = BUF_SIZE.min(self.size - total_read) as usize;
let read_task = self.byte_stream.read_exact(&mut buffer[..buf_len]);
let monitor_cancel = self.cancel.cancelled();

tokio::select! {
// self.bytes_stream is not guaranteed to be "cancel-safe". Canceling read_task
// might result in a data loss. We allow this as a cancel aborts the whole file
// file transfer
res = read_task => ensure!(res? == buf_len),
_ = monitor_cancel => bail!("write cancelled"),
}

total_read += buf_len as u64;

// we accept sporadic lost progress updates or in worst case an error
// inside the channel. It should never prevent the writing process from
// completing.
// Updates to the progress printer are throttled with an arbitrary
// value.
if progress_update_guard % 1000 == 0 {
let _ = size_sender.try_send(total_read);
}
progress_update_guard += 1;

img_digest.update(&buffer[..buf_len]);
device
.write_all(&buffer[..buf_len])
.map_err(|e| anyhow::anyhow!("device write error: {}", e))
.await?;
}

device.flush().await?;

Ok(img_digest.finalize())
}

async fn verify_checksum<L>(&self, reader: L, img_checksum: u64) -> anyhow::Result<()>
where
L: AsyncRead + std::marker::Unpin,
{
let mut reader = io::BufReader::with_capacity(BUF_SIZE as usize, reader);

let mut buffer = vec![0u8; BUF_SIZE as usize];
let mut total_read = 0;

let crc = Crc::<u64>::new(&CRC_64_REDIS);
let mut digest = crc.digest();

while total_read < self.size {
let bytes_left = self.size - total_read;
let buffer_size = buffer.len().min(bytes_left as usize);
let read_task = reader.read(&mut buffer[..buffer_size]);
let monitor_cancel = self.cancel.cancelled();

let num_read = tokio::select! {
res = read_task => res?,
_ = monitor_cancel => bail!("checksum calculation is cancelled"),
};

if num_read == 0 {
log::error!("read 0 bytes with {} bytes to go", bytes_left);
bail!(FlashingError::IoError);
}

total_read += num_read as u64;
digest.update(&buffer[..num_read]);
}

let dev_checksum = digest.finalize();
if img_checksum != dev_checksum {
self.progress_sender
.send(FlashProgress {
status: FlashStatus::Error(FlashingError::ChecksumMismatch),
message: format!(
"Source and destination checksum mismatch: {:#x} != {:#x}",
img_checksum, dev_checksum
),
})
.await?;

bail!(FlashingError::ChecksumMismatch)
}
Ok(())
}
}

async fn run_progress_printer(
img_len: u64,
logging: Sender<FlashProgress>,
mut read_reciever: Receiver<u64>,
) -> anyhow::Result<()> {
let start_time = Instant::now();
let mut last_print = 0;

while let Some(total_read) = read_reciever.recv().await {
let read_percent = 100 * total_read / img_len;

let progress_counter = read_percent - last_print;
if progress_counter >= PROGRESS_REPORT_PERCENT {
#[allow(clippy::cast_precision_loss)] // This affects files > 4 exabytes long
let read_proportion = (total_read as f64) / (img_len as f64);

let duration = start_time.elapsed();
let estimated_end = duration.div_f64(read_proportion);
let estimated_left = estimated_end - duration;

let est_seconds = estimated_left.as_secs() % 60;
let est_minutes = (estimated_left.as_secs() / 60) % 60;

let message = format!(
"Progress: {:>2}%, estimated time left: {:02}:{:02}",
read_percent, est_minutes, est_seconds,
);

logging
.send(FlashProgress {
status: FlashStatus::Progress {
read_percent: read_percent as usize,
est_minutes,
est_seconds,
},
message,
})
.await
.context("progress update error")?;
last_print = read_percent;
}
}
Ok(())
}

async fn flush_file_caches() -> io::Result<()> {
let mut file = fs::OpenOptions::new()
.write(true)
.open("/proc/sys/vm/drop_caches")
.await?;

// Free reclaimable slab objects and page cache
file.write_u8(b'3').await
}
2 changes: 1 addition & 1 deletion tpi_rs/src/app/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub mod bmc_application;
pub mod event_application;
pub mod flash_application;
pub mod flash_context;
7 changes: 4 additions & 3 deletions tpi_rs/src/c_interface.rs
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ use tokio_util::sync::CancellationToken;

use crate::app::bmc_application::{BmcApplication, UsbConfig};
use crate::app::event_application::run_event_listener;
use crate::app::flash_application::{flash_node, FlashContext};
use crate::app::flash_context::FlashContext;
use crate::middleware::firmware_update::FlashingError;
use crate::middleware::{UsbMode, UsbRoute};

@@ -225,7 +225,8 @@ pub unsafe extern "C" fn tpi_flash_node(node: c_int, image_path: *const c_char)
let (sender, receiver) = channel(64);
let img_file = tokio::fs::File::open(&node_image).await.unwrap();
let img_len = img_file.metadata().await.unwrap().len();
let context = FlashContext {
let mut context = FlashContext {
id: 123,
filename: node_image
.file_name()
.unwrap()
@@ -239,7 +240,7 @@ pub unsafe extern "C" fn tpi_flash_node(node: c_int, image_path: *const c_char)
cancel: CancellationToken::new(),
};

let handle = tokio::spawn(flash_node(context));
let handle = tokio::spawn(async move { context.flash_node().await });

let print_handle = logging_sink(receiver);
let (res, _) = join!(handle, print_handle);
3 changes: 2 additions & 1 deletion tpi_rs/src/middleware/firmware_update/mod.rs
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ mod rockusb_fwudate;
use self::rockusb_fwudate::new_rockusb_transport;
mod rpi_fwupdate;
use rpi_fwupdate::new_rpi_transport;
use serde::Serialize;
pub mod transport;
use self::transport::FwUpdateTransport;
use futures::future::BoxFuture;
@@ -57,7 +58,7 @@ pub fn fw_update_transport(
.ok_or(anyhow::anyhow!("no driver available for {:?}", device))
}

#[derive(Debug, Clone, Copy)]
#[derive(Serialize, Debug, Clone, Copy)]
pub enum FlashingError {
InvalidArgs,
DeviceNotFound,
5 changes: 5 additions & 0 deletions tpi_rs/src/utils/io.rs
Original file line number Diff line number Diff line change
@@ -7,6 +7,11 @@ use tokio::{

/// This struct wraps a [tokio::sync::mpsc::Receiver] and transforms that
/// exposes a [AsyncRead] interface.
///
/// # Cancel
///
/// This struct is *not* cancel safe! Using this struct in a tokio::select loop
/// can cause data loss.
pub struct ReceiverReader<T>
where
T: Deref<Target = [u8]>,

0 comments on commit b8abd08

Please sign in to comment.