From a28a98107f871d7e746071e5e5e4d8326544e691 Mon Sep 17 00:00:00 2001 From: Sven Rademakers Date: Mon, 16 Oct 2023 15:55:22 +0100 Subject: [PATCH] streaming_data_service: * Includes a patch for tokio::io::Take which caused overflows when streaming files over 4GB * Decreased lock contention of the streaming_data_service state by exposing the "Sender" of the data (`take_sender()`) * Upgraded from peer validation to handle validation. Only data that is sent to the correct handle endpoint is allowed. * Various fixes and improvements --- Cargo.lock | 220 ++++++++++----- Cargo.toml | 7 + src/api/legacy.rs | 86 +++--- src/api/streaming_data_service.rs | 281 +++++++------------ src/app/firmware_runner.rs | 114 ++++---- src/app/mod.rs | 1 + src/app/transfer_action.rs | 150 ++++++++++ src/app/transfer_context.rs | 130 +-------- src/authentication/authentication_service.rs | 5 +- src/hal/mod.rs | 8 + src/utils/io.rs | 24 +- 11 files changed, 549 insertions(+), 477 deletions(-) create mode 100644 src/app/transfer_action.rs diff --git a/Cargo.lock b/Cargo.lock index 73ca227..e0a440f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,7 +76,7 @@ dependencies = [ "actix-utils", "ahash", "base64", - "bitflags 2.4.0", + "bitflags 2.4.1", "brotli", "bytes", "bytestring", @@ -113,6 +113,44 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "actix-multipart" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b960e2aea75f49c8f069108063d12a48d329fc8b60b786dfc7552a9d5918d2d" +dependencies = [ + "actix-multipart-derive", + "actix-utils", + "actix-web", + "bytes", + "derive_more", + "futures-core", + "futures-util", + "httparse", + "local-waker", + "log", + "memchr", + "mime", + "serde", + "serde_json", + "serde_plain", + "tempfile", + "tokio", +] + +[[package]] +name = "actix-multipart-derive" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a0a77f836d869f700e5b47ac7c3c8b9c8bc82e4aec861954c6198abee3ebd4d" +dependencies = [ + "darling", + "parse-size", + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "actix-router" version = "0.5.1" @@ -277,9 +315,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -374,6 +412,17 @@ version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "619743e34b5ba4e9703bba34deac3427c72507c7159f5fd030aea8cac0cfe341" +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -418,9 +467,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "bitvec" @@ -468,8 +517,10 @@ name = "bmcd" version = "1.3.0" dependencies = [ "actix-files", + "actix-multipart",
"actix-web", "anyhow", + "async-trait", "base64", "bincode", "build-time", @@ -760,10 +811,11 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" +checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" dependencies = [ + "powerfmt", "serde", ] @@ -816,25 +868,14 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "add4f07d43996f76ef320709726a556a9d4f965d9410d8d0271132d2f8293480" +checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" dependencies = [ - "errno-dragonfly", "libc", "windows-sys 0.48.0", ] -[[package]] -name = "errno-dragonfly" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "evdev" version = "0.12.1" @@ -859,11 +900,17 @@ dependencies = [ "instant", ] +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + [[package]] name = "flate2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", @@ -1138,16 +1185,16 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -1241,9 +1288,9 @@ checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "jobserver" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" dependencies = [ "libc", ] @@ -1295,9 +1342,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db" +checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" [[package]] name = "local-channel" @@ -1467,16 +1514,16 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "cfg-if", "libc", ] [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", ] @@ -1548,7 +1595,7 @@ version = "0.10.57" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bac25ee399abb46215765b1cb35bc0212377e58a061560d8b29b024fd0430e7c" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "cfg-if", "foreign-types", "libc", @@ -1613,6 +1660,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "parse-size" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "944553dd59c802559559161f9816429058b869003836120e262e8caec061b7ae" + [[package]] name = "paste" version = "1.0.14" @@ -1643,6 +1696,12 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1661,9 +1720,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b1106fec09662ec6dd98ccac0f81cef56984d0b49f75c92d8cbad76e20c005c" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] @@ -1776,9 +1835,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.6" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", @@ -1788,9 +1847,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.9" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", @@ -1799,9 +1858,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "remove_dir_all" @@ -1843,7 +1902,7 @@ source = "git+https://github.com/collabora/rockchiprs?rev=dc90ab5#dc90ab5cc8d2ec dependencies = [ "bytes", "crc", - "fastrand", + "fastrand 1.9.0", "num_enum", "rusb", "thiserror", @@ -1876,11 +1935,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.17" +version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f25469e9ae0f3d0047ca8b93fc56843f38e6774f0914a107ff8b41be8be8e0b7" +checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", @@ -1942,24 +2001,24 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ad977052201c6de01a8ef2aa3378c4bd23217a056337d1d6da40468d267a4fb0" +checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" [[package]] name = "serde" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ "proc-macro2", "quote", @@ -1977,6 +2036,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2187,6 +2255,19 @@ dependencies = [ "remove_dir_all", ] +[[package]] +name = "tempfile" +version = "3.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" +dependencies = [ + "cfg-if", + "fastrand 2.0.1", + "redox_syscall", + "rustix", + "windows-sys 0.48.0", +] + [[package]] name = "thiserror" version = "1.0.49" @@ -2209,14 +2290,15 @@ dependencies = [ [[package]] name = "time" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "426f806f4089c493dcac0d24c29c01e2c38baf8e30f1b716ee37e83d200b18fe" +checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" dependencies = [ "deranged", "itoa", "libc", "num_threads", + "powerfmt", "serde", "time-core", "time-macros", @@ -2254,9 +2336,8 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.32.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ed6077ed6cd6c74735e21f37eb16dc3935f96878b1fe961074089cc80893f9" +version = "1.33.0" +source = "git+https://github.com/tokio-rs/tokio?rev=654a3d5acf37841d74dca411ec7a7cc70495e1cd#654a3d5acf37841d74dca411ec7a7cc70495e1cd" dependencies = [ "backtrace", "bytes", @@ -2274,8 +2355,7 @@ dependencies = [ [[package]] name = "tokio-macros" version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +source = "git+https://github.com/tokio-rs/tokio?rev=654a3d5acf37841d74dca411ec7a7cc70495e1cd#654a3d5acf37841d74dca411ec7a7cc70495e1cd" dependencies = [ "proc-macro2", "quote", @@ -2353,11 +2433,10 @@ dependencies = [ [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9" dependencies = [ - "cfg-if", "log", "pin-project-lite", "tracing-core", @@ -2365,9 +2444,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", ] @@ -2542,10 +2621,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" +name = "windows-core" +version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ "windows-targets", ] @@ -2675,9 +2754,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.16" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711d82167854aff2018dfd193aa0fef5370f456732f0d5a0c59b0f1b4b907" +checksum = "a3b801d0e0a6726477cc207f60162da452f3a95adb368399bef20a946e06f65c" dependencies = [ "memchr", ] @@ -2712,11 +2791,10 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.8+zstd.1.5.5" +version = "2.0.9+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" dependencies = [ "cc", - "libc", "pkg-config", ] diff --git a/Cargo.toml b/Cargo.toml index 7619406..c1f8bbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,8 @@ serde_with = "3.3.0" thiserror = "1.0.49" tokio-stream = "0.1.14" humansize = "2.1.3" +actix-multipart = "0.6.1" +async-trait = "0.1.74" [dev-dependencies] tempdir = "0.3.7" @@ -60,3 +62,8 @@ strip = true [features] vendored = ["openssl/vendored"] +[patch.crates-io] +# this patch needs to be removed as soon as the given rev lands in a release +tokio = { git = "https://github.com/tokio-rs/tokio", rev="654a3d5acf37841d74dca411ec7a7cc70495e1cd" } + + diff --git a/src/api/legacy.rs b/src/api/legacy.rs index 7ffff61..8a160db 100644 --- a/src/api/legacy.rs +++ b/src/api/legacy.rs @@ -12,18 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. //! Routes for legacy API present in versions <= 1.1.0 of the firmware. 
-use super::streaming_data_service::TransferType; use crate::api::into_legacy_response::LegacyResponse; use crate::api::into_legacy_response::{LegacyResult, Null}; use crate::api::streaming_data_service::StreamingDataService; use crate::app::bmc_application::{BmcApplication, Encoding, UsbConfig}; -use crate::app::firmware_runner::FirmwareRunner; +use crate::app::transfer_action::{TransferType, UpgradeAction, UpgradeType}; use crate::hal::{NodeId, UsbMode, UsbRoute}; use crate::utils::logging_sink; +use actix_multipart::Multipart; use actix_web::guard::{fn_guard, GuardContext}; use actix_web::http::StatusCode; -use actix_web::web::Bytes; -use actix_web::{get, web, HttpRequest, Responder}; +use actix_web::{get, post, web, Responder}; use anyhow::Context; use serde_json::json; use std::collections::HashMap; @@ -31,6 +30,7 @@ use std::ops::Deref; use std::str::FromStr; use tokio::io::AsyncBufReadExt; use tokio::sync::mpsc; +use tokio_stream::StreamExt; type Query = web::Query>; /// version 1: @@ -56,9 +56,9 @@ pub fn config(cfg: &mut web::ServiceConfig) { .guard(fn_guard(flash_guard)) .to(handle_flash_request), ) - .route(web::post().guard(fn_guard(flash_guard)).to(handle_chunk)) .route(web::get().to(api_entry)), - ); + ) + .service(handle_file_upload); } pub fn info_config(cfg: &mut web::ServiceConfig) { @@ -66,12 +66,16 @@ pub fn info_config(cfg: &mut web::ServiceConfig) { } fn flash_status_guard(context: &GuardContext<'_>) -> bool { - let Some(query) = context.head().uri.query() else { return false; }; + let Some(query) = context.head().uri.query() else { + return false; + }; query.contains("opt=get") && (query.contains("type=flash") || query.contains("type=firmware")) } fn flash_guard(context: &GuardContext<'_>) -> bool { - let Some(query) = context.head().uri.query() else { return false; }; + let Some(query) = context.head().uri.query() else { + return false; + }; query.contains("opt=set") && (query.contains("type=flash") || query.contains("type=firmware")) } @@ -460,9 +464,8 @@ async fn handle_flash_status(flash: web::Data) -> LegacyRe async fn handle_flash_request( ss: web::Data, bmc: web::Data, - request: HttpRequest, mut query: Query, -) -> LegacyResult { +) -> LegacyResult { let file = query .get("file") .ok_or(LegacyResponse::bad_request( @@ -470,15 +473,15 @@ async fn handle_flash_request( ))? 
.to_string(); - let peer: String = request - .connection_info() - .peer_addr() - .map(Into::into) - .context("peer_addr unknown")?; - - let (firmware_request, process_name) = match query.get_mut("type").map(|c| c.as_str()) { - Some("firmware") => (true, "upgrade os task".to_string()), - Some("flash") => (false, "node flash service".to_string()), + let (process_name, upgrade_type) = match query.get_mut("type").map(|c| c.as_str()) { + Some("firmware") => ("os upgrade service".to_string(), UpgradeType::OsUpgrade), + Some("flash") => { + let node = get_node_param(&query)?; + ( + format!("{node} upgrade service"), + UpgradeType::Module(node, bmc.clone().into_inner()), + ) + } _ => panic!("programming error: `type` should equal 'firmware' or 'flash'"), }; @@ -492,34 +495,31 @@ async fn handle_flash_request( let size = u64::from_str(size) .map_err(|_| LegacyResponse::bad_request("`length` parameter is not a number"))?; - TransferType::Remote(peer, size) + TransferType::Remote(file, size) }; - let handle = ss.request_transfer(process_name, transfer_type).await?; - let context = FirmwareRunner::new(file.into(), handle); - - if firmware_request { - ss.execute_worker(context.os_update()).await?; - } else { - let node = get_node_param(&query)?; - ss.execute_worker(context.flash_node(bmc.clone().into_inner(), node)) - .await?; - } - - Ok(Null) + let action = UpgradeAction::new(upgrade_type, transfer_type); + let handle = ss.request_transfer(process_name, action).await?; + let json = json!({"handle": handle}); + Ok(json.to_string()) } -async fn handle_chunk( - flash: web::Data, - request: HttpRequest, - chunk: Bytes, -) -> LegacyResult { - let peer: String = request - .connection_info() - .peer_addr() - .map(Into::into) - .context("peer_addr unknown")?; +#[post("/api/bmc/upload/{handle}")] +async fn handle_file_upload( + handle: web::Path, + ss: web::Data, + mut payload: Multipart, +) -> impl Responder { + let sender = ss.take_sender(*handle).await?; + let Some(Ok(mut field)) = payload.next().await else { + return Err(LegacyResponse::bad_request("Multipart form invalid")); + }; + + while let Some(Ok(chunk)) = field.next().await { + if sender.send(chunk).await.is_err() { + return Err((StatusCode::GONE, "upload cancelled").into()); + } + } - flash.put_chunk(peer, chunk).await?; Ok(Null) } diff --git a/src/api/streaming_data_service.rs b/src/api/streaming_data_service.rs index fb685a6..a12f18e 100644 --- a/src/api/streaming_data_service.rs +++ b/src/api/streaming_data_service.rs @@ -12,28 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. 
use crate::api::into_legacy_response::LegacyResponse; -use crate::app::transfer_context::{TransferContext, TransferSource}; +use crate::app::transfer_context::TransferContext; use actix_web::http::StatusCode; use bytes::Bytes; +use futures::future::BoxFuture; use futures::Future; use humansize::{format_size, DECIMAL}; +use rand::Rng; use serde::Serialize; use std::fmt::{Debug, Display}; use std::{ - ops::{Deref, DerefMut}, + ops::Deref, sync::Arc, time::{Duration, Instant}, }; use thiserror::Error; -use tokio::fs::OpenOptions; -use tokio::io::AsyncSeekExt; use tokio::sync::{mpsc, watch, Mutex}; -use tokio::{io::AsyncRead, sync::mpsc::error::SendError}; -use tokio_stream::wrappers::ReceiverStream; -use tokio_stream::StreamExt; -use tokio_util::io::StreamReader; +use tokio::time::sleep; use tokio_util::sync::CancellationToken; -const RESET_TIMEOUT: Duration = Duration::from_secs(10); pub struct StreamingDataService { status: Arc>, @@ -46,26 +42,27 @@ impl StreamingDataService { } } - /// Start a node flash command and initialize [`StreamingDataService`] for - /// chunked file transfer. Calling this function twice results in a - /// `Err(StreamingServiceError::InProgress)`. Unless the first file transfer - /// deemed to be stale. In this case the [`StreamingDataService`] will be - /// reset and initialize for a new transfer. A transfer is stale when the - /// `RESET_TIMEOUT` is reached. Meaning no chunk has been received for - /// longer as `RESET_TIMEOUT`. - pub async fn request_transfer>( + /// Initialize [`StreamingDataService`] for chunked file transfer. Calling + /// this function cancels any ongoing transfers and resets the internal + /// state of the service to `StreamingState::Ready` before going to + /// `StreamingState::Transferring` again. + pub async fn request_transfer( &self, process_name: String, - transfer_type: M, - ) -> Result { - let transfer_type = transfer_type.into(); - let mut status = self.status.lock().await; - self.reset_transfer_on_timeout(status.deref_mut())?; + action: impl TransferAction, + ) -> Result { + let mut rng = rand::thread_rng(); + let id = rng.gen(); - let (reader, context) = match transfer_type { - TransferType::Local(path) => Self::local(process_name, path).await?, - TransferType::Remote(peer, length) => Self::remote(process_name, peer, length)?, - }; + let size = action.total_size()?; + + let (written_sender, written_receiver) = watch::channel(0u64); + let cancel = CancellationToken::new(); + let (sender, worker) = action + .into_data_processor(64, written_sender, cancel.child_token()) + .await?; + let context = + TransferContext::new(id, process_name, size, written_receiver, sender, cancel); log::info!( "new transfer initialized: '{}'({}) {}", @@ -74,115 +71,46 @@ impl StreamingDataService { format_size(context.size, DECIMAL), ); - *status = StreamingState::Transferring(context); - Ok(reader) - } - - pub fn remote( - process_name: String, - peer: String, - size: u64, - ) -> Result<(ReaderContext, TransferContext), StreamingServiceError> { - let (bytes_sender, bytes_receiver) = mpsc::channel::(256); - let reader = StreamReader::new( - ReceiverStream::new(bytes_receiver).map(Ok::), - ); - let source: TransferSource = bytes_sender.into(); + self.execute_worker(&context, worker).await; + Self::cancel_request_on_timeout(self.status.clone()); + *self.status.lock().await = StreamingState::Transferring(context); - Ok(Self::new_context_pair( - process_name, - peer, - reader, - source, - size, - )) + Ok(id) } - pub async fn local( - process_name: String, - 
path: String, - ) -> Result<(ReaderContext, TransferContext), StreamingServiceError> { - let mut file = OpenOptions::new().read(true).open(&path).await?; - let file_size = file.seek(std::io::SeekFrom::End(0)).await?; - file.seek(std::io::SeekFrom::Start(0)).await?; - Ok(Self::new_context_pair( - process_name, - path, - file, - TransferSource::Local, - file_size, - )) - } - - fn new_context_pair( - process_name: String, - peer: String, - reader: R, - source: TransferSource, - size: u64, - ) -> (ReaderContext, TransferContext) - where - R: AsyncRead + Unpin + Send + Sync + 'static, - { - let (written_sender, written_receiver) = watch::channel(0u64); - - let transfer_ctx = TransferContext::new(peer, process_name, source, size, written_receiver); - let reader_ctx = ReaderContext { - reader: Box::new(reader), - size, - cancel: transfer_ctx.get_child_token(), - written_sender, - }; - - (reader_ctx, transfer_ctx) - } + fn cancel_request_on_timeout(status: Arc>) { + tokio::spawn(async move { + sleep(Duration::from_secs(10)).await; + let mut status_unlocked = status.lock().await; - /// When a 'start_transfer' call is made while we are still in a transfer - /// state, assume that the current transfer is stale given the timeout limit - /// is reached. - fn reset_transfer_on_timeout( - &self, - mut status: impl DerefMut, - ) -> Result<(), StreamingServiceError> { - if let StreamingState::Transferring(context) = &*status { - let duration = context.duration_since_last_chunk(); - if duration < RESET_TIMEOUT { - return Err(StreamingServiceError::InProgress); - } else { - log::warn!( - "Assuming transfer ({}) will never complete as last request was {}s ago. Resetting flash service", - context.id, - duration.as_secs() - ); - *status = StreamingState::Ready; + if matches!( + &*status_unlocked, + StreamingState::Transferring(ctx) if ctx.data_sender.is_some() + ) { + *status_unlocked = StreamingState::Error("Send timeout".to_string()); } - } - Ok(()) + }); } /// Worker task that performs the actual node flash. This tasks finishes if /// one of the following scenario's is met: - /// * flashing completed successfully - /// * flashing was canceled - /// * Error occurred during flashing. + /// * transfer & flashing completed successfully + /// * transfer & flashing was canceled + /// * Error occurred during transfer or flashing. + /// + /// Note that the "global" status (`StreamingState`) does not get updated to + /// `StreamingState::Error(_)` when the worker was canceled as the cancel + /// was an effect of a prior state change. In this case we omit the state + /// transition to `FlashSstatus::Error(_)` /// - /// Note that the "global" status does not get updated when the task was - /// canceled. Cancel can only be true on a state transition from - /// `StreamingState::Transferring`, meaning a state transition already - /// happened. 
In this case we omit a state transition to - /// `FlashSstatus::Error(_)` - pub async fn execute_worker( + async fn execute_worker( &self, + context: &TransferContext, future: impl Future> + Send + 'static, - ) -> Result<(), StreamingServiceError> { - let status = self.status.lock().await; - let StreamingState::Transferring(ctx) = &*status else { - return Err(StreamingServiceError::WrongState(status.to_string(), "Transferring".to_string())); - }; - - let id = ctx.id; - let cancel = ctx.get_child_token(); - let size = ctx.size; + ) { + let id = context.id; + let cancel = context.get_child_token(); + let size = context.size; let start_time = Instant::now(); let status = self.status.clone(); @@ -214,7 +142,11 @@ impl StreamingDataService { // already correct, therefore we omit a state transition in this scenario. let mut status_unlocked = status.lock().await; if let StreamingState::Transferring(ctx) = &*status_unlocked { - log::debug!("last recorded transfer state {:#?}", ctx); + log::debug!( + "last recorded transfer state: {:#?}", + serde_json::to_string(ctx) + ); + if !was_cancelled { log::info!("state={new_state}"); *status_unlocked = new_state; @@ -223,7 +155,6 @@ impl StreamingDataService { } } }); - Ok(()) } /// Write a chunk of bytes to the module that is selected for flashing. @@ -234,31 +165,27 @@ impl StreamingDataService { /// /// * 'Err(StreamingServiceError::WrongState)' if this function is called when /// ['StreamingDataService'] is not in 'Transferring' state. - /// * 'Err(StreamingServiceError::EmptyPayload)' when data == empty - /// * 'Err(StreamingServiceError::Error(_)' when there is an internal error + /// * 'Err(StreamingServiceError::HandlesDoNotMatch)', the passed id is + /// unknown + /// * 'Err(StreamingServiceError::SenderTaken(_)' /// * Ok(()) on success - pub async fn put_chunk(&self, peer: String, data: Bytes) -> Result<(), StreamingServiceError> { + pub async fn take_sender(&self, id: u32) -> Result, StreamingServiceError> { let mut status = self.status.lock().await; - if let StreamingState::Transferring(ref mut context) = *status { - context.is_equal_peer(&peer)?; - - if data.is_empty() { - *status = StreamingState::Ready; - return Err(StreamingServiceError::EmptyPayload); - } - - if let Err(e) = context.push_bytes(data).await { - *status = StreamingState::Error(e.to_string()); - return Err(e); - } - - Ok(()) - } else { - Err(StreamingServiceError::WrongState( + let StreamingState::Transferring(ref mut context) = *status else { + return Err(StreamingServiceError::WrongState( status.to_string(), "Transferring".to_string(), - )) + )); + }; + + if id != context.id { + return Err(StreamingServiceError::HandlesDoNotMatch); } + + context + .data_sender + .take() + .ok_or(StreamingServiceError::SenderTaken) } /// Return a borrow to the current status of the flash service @@ -270,44 +197,32 @@ impl StreamingDataService { #[derive(Error, Debug)] pub enum StreamingServiceError { - #[error("another flashing operation in progress")] - InProgress, #[error("cannot execute command in current state. 
current={0}, expected={1}")] WrongState(String, String), - #[error("received empty payload")] - EmptyPayload, - #[error("unauthorized request from peer {0}")] - PeersDoNotMatch(String), - #[error("{0} was aborted")] - Aborted(String), - #[error("error processing internal buffers")] - MpscError(#[from] SendError), - #[error("Received more bytes as negotiated")] - LengthExceeded, + #[error("unauthorized request for handle")] + HandlesDoNotMatch, #[error("IO error")] IoError(#[from] std::io::Error), - #[error("not a remote transfer")] - IsLocalTransfer, + #[error( + "cannot transfer bytes to worker. This is either because the transfer \ + happens locally, or is already ongoing." + )] + SenderTaken, } impl From for LegacyResponse { fn from(value: StreamingServiceError) -> Self { let status_code = match value { - StreamingServiceError::InProgress => StatusCode::SERVICE_UNAVAILABLE, StreamingServiceError::WrongState(_, _) => StatusCode::BAD_REQUEST, - StreamingServiceError::MpscError(_) => StatusCode::INTERNAL_SERVER_ERROR, - StreamingServiceError::Aborted(_) => StatusCode::INTERNAL_SERVER_ERROR, - StreamingServiceError::EmptyPayload => StatusCode::BAD_REQUEST, - StreamingServiceError::PeersDoNotMatch(_) => StatusCode::BAD_REQUEST, - StreamingServiceError::LengthExceeded => StatusCode::BAD_REQUEST, + StreamingServiceError::HandlesDoNotMatch => StatusCode::BAD_REQUEST, StreamingServiceError::IoError(_) => StatusCode::INTERNAL_SERVER_ERROR, - StreamingServiceError::IsLocalTransfer => StatusCode::BAD_REQUEST, + StreamingServiceError::SenderTaken => StatusCode::BAD_REQUEST, }; (status_code, value.to_string()).into() } } -#[derive(Serialize, Debug)] +#[derive(Serialize)] pub enum StreamingState { Ready, Transferring(TransferContext), @@ -326,14 +241,28 @@ impl Display for StreamingState { } } -pub struct ReaderContext { - pub reader: Box, - pub size: u64, - pub cancel: CancellationToken, - pub written_sender: watch::Sender, -} +/// Implementers of this trait return a "sender" and "worker" pair which allow +/// them to asynchronously process bytes that are sent over the optional sender. +#[async_trait::async_trait] +pub trait TransferAction { + /// Construct a "data processor". Implementers are obliged to cancel the + /// worker when the cancel token returns canceled. Secondly, they are + /// expected to report status, via the watcher, on how many bytes are + /// processed. + /// The "sender" equals `None` when the transfer happens internally, and therefore + /// does not require any external object to feed data to the worker. This + /// typically happens when a file transfer is executed locally from disk. + async fn into_data_processor( + self, + channel_size: usize, + watcher: watch::Sender, + cancel: CancellationToken, + ) -> std::io::Result<( + Option>, + BoxFuture<'static, anyhow::Result<()>>, + )>; -pub enum TransferType { - Local(String), - Remote(String, u64), + /// return the amount of data that is going to be transferred from the + /// "sender" to the "worker". + fn total_size(&self) -> std::io::Result; } diff --git a/src/app/firmware_runner.rs b/src/app/firmware_runner.rs index 4116b36..355fa16 100644 --- a/src/app/firmware_runner.rs +++ b/src/app/firmware_runner.rs @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
use super::bmc_application::UsbConfig; -use crate::api::streaming_data_service::ReaderContext; use crate::app::bmc_application::BmcApplication; use crate::utils::{logging_sink, reader_with_crc64, WriteWatcher}; use crate::{ - firmware_update::{FlashProgress, FlashStatus, FlashingError, SUPPORTED_DEVICES}, + firmware_update::{FlashProgress, FlashingError, SUPPORTED_DEVICES}, hal::{NodeId, UsbRoute}, }; use anyhow::bail; use crc::Crc; use crc::CRC_64_REDIS; +use humansize::{format_size, DECIMAL}; use std::cmp::Ordering; use std::io::{Error, ErrorKind}; use std::path::PathBuf; @@ -30,10 +30,10 @@ use std::{sync::Arc, time::Duration}; use tokio::fs::OpenOptions; use tokio::io::sink; use tokio::io::AsyncReadExt; +use tokio::sync::{mpsc, watch}; use tokio::{ fs, io::{self, AsyncRead, AsyncSeekExt, AsyncWrite, AsyncWriteExt}, - sync::mpsc::{channel, Sender}, time::sleep, }; use tokio_util::sync::CancellationToken; @@ -44,31 +44,35 @@ const MOUNT_POINT: &str = "/tmp/os_upgrade"; // Contains collection of functions that execute some business flow in relation // to file transfers in the BMC. See `flash_node` and `os_update`. pub struct FirmwareRunner { - pub filename: String, - pub context: Option, - pub progress_sender: Sender, + reader: Box, + file_name: String, + size: u64, + cancel: CancellationToken, + written_sender: watch::Sender, + progress_sender: mpsc::Sender, } impl FirmwareRunner { - pub fn new(filename: PathBuf, reader_context: ReaderContext) -> Self { - let (progress_sender, progress_receiver) = channel(32); - logging_sink(progress_receiver); - + pub fn new( + reader: Box, + file_name: String, + size: u64, + cancel: CancellationToken, + written_sender: watch::Sender, + ) -> Self { + let (sender, receiver) = mpsc::channel(16); + logging_sink(receiver); Self { - filename: filename - .file_name() - .map(|s| s.to_string_lossy().to_string()) - .unwrap_or(filename.to_string_lossy().to_string()), - progress_sender, - context: Some(reader_context), + reader, + file_name, + size, + cancel, + written_sender, + progress_sender: sender, } } - pub async fn flash_node( - mut self, - bmc: Arc, - node: NodeId, - ) -> anyhow::Result<()> { + pub async fn flash_node(self, bmc: Arc, node: NodeId) -> anyhow::Result<()> { let mut device = bmc .configure_node_for_fwupgrade( node, @@ -78,67 +82,51 @@ impl FirmwareRunner { ) .await?; - let mut progress_state = FlashProgress { - message: String::new(), - status: FlashStatus::Setup, - }; - - progress_state.message = format!("Writing {:?}", self.filename); - self.progress_sender.send(progress_state.clone()).await?; - - let context = self.context.take().expect("context should always be set"); - let write_watcher = WriteWatcher::new(&mut device, context.written_sender); - let img_checksum = - copy_with_crc(context.reader, write_watcher, context.size, &context.cancel).await?; - - progress_state.message = String::from("Verifying checksum..."); - self.progress_sender.send(progress_state.clone()).await?; + let write_watcher = WriteWatcher::new(&mut device, self.written_sender); + let img_checksum = copy_with_crc( + self.reader.take(self.size), + write_watcher, + self.size, + &self.cancel, + ) + .await?; + log::info!("Verifying checksum..."); device.seek(std::io::SeekFrom::Start(0)).await?; flush_file_caches().await?; let dev_checksum = - copy_with_crc(&mut device, sink(), context.size, &context.cancel).await?; + copy_with_crc(&mut device.take(self.size), sink(), self.size, &self.cancel).await?; if img_checksum != dev_checksum { - self.progress_sender - 
.send(FlashProgress { - status: FlashStatus::Error(FlashingError::ChecksumMismatch), - message: format!( - "Source and destination checksum mismatch: {:#x} != {:#x}", - img_checksum, dev_checksum - ), - }) - .await?; + log::error!( + "Source and destination checksum mismatch: {:#x} != {:#x}", + img_checksum, + dev_checksum + ); bail!(FlashingError::ChecksumMismatch) } - progress_state.message = String::from("Flashing successful, restarting device..."); - self.progress_sender.send(progress_state.clone()).await?; - + log::info!("Flashing successful, restarting device..."); bmc.activate_slot(!node.to_bitfield(), node.to_bitfield()) .await?; //TODO: we probably want to restore the state prior flashing bmc.usb_boot(node, false).await?; bmc.configure_usb(UsbConfig::UsbA(node)).await?; - sleep(REBOOT_DELAY).await; - bmc.activate_slot(node.to_bitfield(), node.to_bitfield()) .await?; - progress_state.message = String::from("Done"); - self.progress_sender.send(progress_state).await?; Ok(()) } - pub async fn os_update(mut self) -> anyhow::Result<()> { + pub async fn os_update(self) -> anyhow::Result<()> { log::info!("start os update"); let mut os_update_img = PathBuf::from(MOUNT_POINT); - os_update_img.push(&self.filename); + os_update_img.push(&self.file_name); tokio::fs::create_dir_all(MOUNT_POINT).await?; @@ -149,9 +137,8 @@ impl FirmwareRunner { .open(&os_update_img) .await?; - let context = self.context.take().expect("context should always be set"); - let write_watcher = WriteWatcher::new(&mut file, context.written_sender); - let result = copy_with_crc(context.reader, write_watcher, context.size, &context.cancel) + let write_watcher = WriteWatcher::new(&mut file, self.written_sender); + let result = copy_with_crc(self.reader, write_watcher, self.size, &self.cancel) .await .and_then(|crc| { log::info!("crc os_update image: {}", crc); @@ -172,9 +159,9 @@ impl FirmwareRunner { } } -/// Copies `self.size` bytes from `reader` to `writer` and returns the crc -/// that was calculated over the reader. This function returns an -/// `io::Error(Interrupted)` in case a cancel was issued. +/// Copies bytes from `reader` to `writer` until the reader is exhausted. This +/// function returns the crc that was calculated over the reader. This function +/// returns an `io::Error(Interrupted)` in case a cancel was issued. 
async fn copy_with_crc( reader: L, mut writer: W, @@ -186,7 +173,7 @@ where W: AsyncWrite + std::marker::Unpin, { let crc = Crc::::new(&CRC_64_REDIS); - let mut crc_reader = reader_with_crc64(reader.take(size), &crc); + let mut crc_reader = reader_with_crc64(reader, &crc); let copy_task = tokio::io::copy(&mut crc_reader, &mut writer); let cancel = cancel.cancelled(); @@ -208,12 +195,13 @@ fn validate_size(len: u64, total_size: u64) -> std::io::Result<()> { match len.cmp(&total_size) { Ordering::Less => Err(Error::new( ErrorKind::UnexpectedEof, - format!("missing {} bytes", total_size - len), + format!("missing {} bytes", format_size(total_size - len, DECIMAL)), )), Ordering::Greater => panic!("reads are capped to self.size"), Ordering::Equal => Ok(()), } } + async fn flush_file_caches() -> io::Result<()> { let mut file = fs::OpenOptions::new() .write(true) diff --git a/src/app/mod.rs b/src/app/mod.rs index 0a75d9a..965bfda 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -14,4 +14,5 @@ pub mod bmc_application; pub mod event_application; pub mod firmware_runner; +pub mod transfer_action; pub mod transfer_context; diff --git a/src/app/transfer_action.rs b/src/app/transfer_action.rs new file mode 100644 index 0000000..0b3bff0 --- /dev/null +++ b/src/app/transfer_action.rs @@ -0,0 +1,150 @@ +// Copyright 2023 Turing Machines +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+use std::sync::Arc; +use std::{io::ErrorKind, path::PathBuf}; + +use super::bmc_application::BmcApplication; +use super::firmware_runner::FirmwareRunner; +use crate::api::streaming_data_service::TransferAction; +use crate::hal::NodeId; +use bytes::Bytes; +use futures::future::BoxFuture; +use serde::Serialize; +use std::io::Seek; +use tokio::sync::mpsc; +use tokio::{fs::OpenOptions, io::AsyncRead, sync::watch}; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; +use tokio_util::{io::StreamReader, sync::CancellationToken}; + +#[derive(Debug)] +pub struct UpgradeAction { + transfer_type: TransferType, + upgrade_type: UpgradeType, +} + +impl UpgradeAction { + pub fn new(upgrade_type: UpgradeType, transfer_type: TransferType) -> Self { + Self { + transfer_type, + upgrade_type, + } + } +} + +#[async_trait::async_trait] +impl TransferAction for UpgradeAction { + async fn into_data_processor( + self, + channel_size: usize, + written_sender: watch::Sender, + cancel: CancellationToken, + ) -> std::io::Result<( + Option>, + BoxFuture<'static, anyhow::Result<()>>, + )> { + let file_name = self.transfer_type.file_name()?; + let size = self.transfer_type.size()?; + let (sender, receiver) = self.transfer_type.transfer_channel(channel_size).await?; + + let worker = self.upgrade_type.run(FirmwareRunner::new( + receiver, + file_name, + size, + cancel, + written_sender, + )); + Ok((sender, worker)) + } + + fn total_size(&self) -> std::io::Result { + self.transfer_type.size() + } +} + +#[derive(Debug)] +pub enum UpgradeType { + OsUpgrade, + Module(NodeId, Arc), +} + +impl UpgradeType { + pub fn run( + self, + firmware_runner: FirmwareRunner, + ) -> BoxFuture<'static, Result<(), anyhow::Error>> { + match self { + UpgradeType::OsUpgrade => Box::pin(firmware_runner.os_update()), + UpgradeType::Module(bmc, node) => Box::pin(firmware_runner.flash_node(node, bmc)), + } + } +} + +#[derive(Debug, Serialize)] +pub enum TransferType { + Local(String), + Remote(String, u64), +} + +impl TransferType { + pub fn size(&self) -> std::io::Result { + match self { + TransferType::Local(path) => { + let mut file = std::fs::OpenOptions::new().read(true).open(path)?; + file.seek(std::io::SeekFrom::End(0)) + } + TransferType::Remote(_, size) => Ok(*size), + } + } + + pub fn file_name(&self) -> std::io::Result { + match self { + TransferType::Local(path) => { + let file_name = PathBuf::from(path) + .file_name() + .ok_or(std::io::Error::from(ErrorKind::InvalidInput))? + .to_string_lossy() + .to_string(); + Ok(file_name) + } + TransferType::Remote(file_name, _) => Ok(file_name.clone()), + } + } + + pub async fn transfer_channel( + self, + items: usize, + ) -> std::io::Result<( + Option>, + Box, + )> { + match self { + TransferType::Local(path) => OpenOptions::new().read(true).open(&path).await.map(|x| { + ( + None, + Box::new(x) as Box, + ) + }), + TransferType::Remote(_, _) => { + let (sender, receiver) = mpsc::channel(items); + Ok(( + Some(sender), + Box::new(StreamReader::new( + ReceiverStream::new(receiver).map(Ok::), + )), + )) + } + } + } +} diff --git a/src/app/transfer_context.rs b/src/app/transfer_context.rs index e4caded..d58b0d8 100644 --- a/src/app/transfer_context.rs +++ b/src/app/transfer_context.rs @@ -12,120 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::api::streaming_data_service::StreamingServiceError; use bytes::Bytes; -use rand::Rng; use serde::{Serialize, Serializer}; -use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, - time::Duration, -}; -use tokio::{ - sync::{mpsc, watch}, - time::Instant, -}; +use tokio::sync::{mpsc, watch}; use tokio_util::sync::CancellationToken; /// Context object for node flashing. This object acts as a "cancel-guard" for /// the [`StreamingDataService`]. If [`TransferContext`] gets dropped, it will /// cancel its "cancel" token, effectively aborting the node flash task. This /// typically happens on a state transition inside the [`StreamingDataService`]. -#[derive(Serialize, Debug)] +#[derive(Serialize)] pub struct TransferContext { - pub id: u64, - pub peer: String, - #[serde(skip)] - pub peer_hash: u64, + pub id: u32, pub process_name: String, pub size: u64, - #[serde(serialize_with = "serialize_source", rename = "bytes_sent")] - reader: TransferSource, + #[serde(skip)] + pub data_sender: Option>, #[serde(serialize_with = "serialize_cancellation_token")] cancelled: CancellationToken, - #[serde(serialize_with = "serialize_seconds_until_now")] - last_recieved_chunk: Instant, #[serde(serialize_with = "serialize_written_bytes")] bytes_written: watch::Receiver, } impl TransferContext { - pub fn new>( - peer: String, + pub fn new( + id: u32, process_name: String, - transfer_source: S, size: u64, written_receiver: watch::Receiver, + data_sender: Option>, + cancel_token: CancellationToken, ) -> Self { - let mut rng = rand::thread_rng(); - let id = rng.gen(); - - let mut hasher = DefaultHasher::new(); - peer.hash(&mut hasher); - let peer_hash = hasher.finish(); - TransferContext { id, - peer, - peer_hash, - process_name, size, - reader: transfer_source.into(), - cancelled: CancellationToken::new(), - last_recieved_chunk: Instant::now(), + process_name, + cancelled: cancel_token, bytes_written: written_receiver, - } - } - - pub fn duration_since_last_chunk(&self) -> Duration { - Instant::now().saturating_duration_since(self.last_recieved_chunk) - } - - pub fn is_equal_peer(&self, peer: &str) -> Result<(), StreamingServiceError> { - let mut hasher = DefaultHasher::new(); - peer.hash(&mut hasher); - let hashed_peer = hasher.finish(); - if self.peer_hash != hashed_peer { - return Err(StreamingServiceError::PeersDoNotMatch(peer.to_string())); - } - - Ok(()) - } - - /// Send given bytes through a channel towards the object that is - /// processing the file transfer ([`FirmwareRunner`]). This function should - /// defer from making any application and state transitions. This is up to - /// to the receiver side. This function does however contain some conveniences - /// to book-keep transfer meta-data. - /// This function is only relevant for cases where source == - /// TransferSource::Peer(_)` - pub async fn push_bytes(&mut self, data: Bytes) -> Result<(), StreamingServiceError> { - let (bytes_sender, bytes_sent) = match &mut self.reader { - TransferSource::Local => return Err(StreamingServiceError::IsLocalTransfer), - TransferSource::Peer(None, _) => return Err(StreamingServiceError::LengthExceeded), - TransferSource::Peer(sender, b) => (sender, b), - }; - - let len = data.len(); - match bytes_sender.as_mut().unwrap().send(data).await { - Ok(_) => { - self.last_recieved_chunk = Instant::now(); - *bytes_sent += len as u64; - // Close the channel to signal to the other side that the last - // chunk was sent. We cannot however switch yet to "Done" state. 
- // As its up to the receiving side to signal (see - // 'StreamingDataService::execute_worker') when its done - // processing this data. - if *bytes_sent >= self.size { - log::info!("{}:{}", bytes_sent, self.size); - *bytes_sender = None; - } - Ok(()) - } - Err(_) if bytes_sender.as_ref().unwrap().is_closed() => { - Err(StreamingServiceError::Aborted(self.process_name.clone())) - } - Err(e) => Err(e.into()), + data_sender, } } @@ -150,39 +74,9 @@ where s.serialize_bool(cancel_token.is_cancelled()) } -fn serialize_seconds_until_now(instant: &Instant, s: S) -> Result -where - S: Serializer, -{ - let secs = Instant::now().saturating_duration_since(*instant).as_secs(); - s.serialize_u64(secs) -} - fn serialize_written_bytes(receiver: &watch::Receiver, s: S) -> Result where S: Serializer, { s.serialize_u64(*receiver.borrow()) } - -fn serialize_source(source: &TransferSource, s: S) -> Result -where - S: Serializer, -{ - match source { - TransferSource::Local => s.serialize_none(), - TransferSource::Peer(_, bytes_sent) => s.serialize_u64(*bytes_sent), - } -} - -#[derive(Debug)] -pub enum TransferSource { - Local, - Peer(Option>, u64), -} - -impl From> for TransferSource { - fn from(value: mpsc::Sender) -> Self { - TransferSource::Peer(Some(value), 0) - } -} diff --git a/src/authentication/authentication_service.rs b/src/authentication/authentication_service.rs index a4370fa..15c1524 100644 --- a/src/authentication/authentication_service.rs +++ b/src/authentication/authentication_service.rs @@ -173,5 +173,8 @@ fn unauthorized_response( let response = HttpResponse::Unauthorized() .insert_header((header::WWW_AUTHENTICATE, bearer_str)) .body(response_text.to_string()); - Ok(ServiceResponse::new(request.clone(), response)).map(ServiceResponse::map_into_right_body) + Ok(ServiceResponse::map_into_right_body(ServiceResponse::new( + request.clone(), + response, + ))) } diff --git a/src/hal/mod.rs b/src/hal/mod.rs index 2e8e93a..6c24d3e 100644 --- a/src/hal/mod.rs +++ b/src/hal/mod.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + // Copyright 2023 Turing Machines // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -59,6 +61,12 @@ impl NodeId { } } +impl Display for NodeId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "node {:?}", (*self as u8) + 1) + } +} + #[repr(C)] #[derive(Debug, Eq, PartialEq, Clone, Copy, serde::Serialize, serde::Deserialize)] pub enum NodeType { diff --git a/src/utils/io.rs b/src/utils/io.rs index 755aaa1..0ba6019 100644 --- a/src/utils/io.rs +++ b/src/utils/io.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
use crc::{Crc, Digest}; -use std::pin::Pin; +use std::{pin::Pin, task::Poll}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::watch, @@ -26,6 +26,7 @@ where sender: watch::Sender, inner: W, } + impl WriteWatcher where W: AsyncWrite, @@ -49,10 +50,13 @@ where buf: &[u8], ) -> std::task::Poll> { let me = Pin::get_mut(self); - me.written += buf.len() as u64; - me.sender.send_replace(me.written); - Pin::new(&mut me.inner).poll_write(cx, buf) + let result = Pin::new(&mut me.inner).poll_write(cx, buf); + if let Poll::Ready(Ok(written)) = result { + me.written += written as u64; + me.sender.send_replace(me.written); + } + result } fn poll_flush( @@ -68,7 +72,7 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let me = Pin::get_mut(self); - Pin::new(&mut me.inner).poll_flush(cx) + Pin::new(&mut me.inner).poll_shutdown(cx) } } @@ -135,6 +139,16 @@ mod test { array } + #[tokio::test] + async fn write_watcher_test() { + let mut reader = tokio::io::repeat(0b101).take(1044 * 1004); + let (sender, receiver) = watch::channel(0u64); + let mut writer = WriteWatcher::new(tokio::io::sink(), sender); + let copied = tokio::io::copy(&mut reader, &mut writer).await.unwrap(); + assert_eq!(copied, 1044 * 1004); + assert_eq!(*receiver.borrow(), 1044 * 1004); + } + #[tokio::test] async fn crc_reader_test() { let mut buffer = random_array::<{ 1024 * 1024 + 23 }>();