diff --git a/bmcd/src/flash_service.rs b/bmcd/src/flash_service.rs index 51b60b1..84481f3 100644 --- a/bmcd/src/flash_service.rs +++ b/bmcd/src/flash_service.rs @@ -6,12 +6,12 @@ use tokio::{ io::{AsyncRead, BufReader}, sync::mpsc::{channel, error::SendError, Receiver, Sender}, }; -use tpi_rs::app::flash_application::FlashContext; use tpi_rs::{ app::bmc_application::BmcApplication, middleware::{firmware_update::SUPPORTED_DEVICES, NodeId, UsbRoute}, }; use tpi_rs::{app::flash_application::flash_node, middleware::firmware_update::FlashStatus}; +use tpi_rs::{app::flash_application::FlashContext, utils::ReceiverReader}; pub struct FlashService { status: Option>, @@ -39,7 +39,7 @@ impl FlashService { filename, size, node, - byte_stream: tokio::fs::File::open("/todo/make/wrapper").await.unwrap(), + byte_stream: ReceiverReader::new(receiver), bmc: self.bmc.clone(), progress_sender, }; diff --git a/tpi_rs/src/utils/io.rs b/tpi_rs/src/utils/io.rs new file mode 100644 index 0000000..aa3c9bb --- /dev/null +++ b/tpi_rs/src/utils/io.rs @@ -0,0 +1,146 @@ +use bytes::BufMut; +use std::{io, ops::Deref, task::Poll}; +use tokio::{ + io::{AsyncRead, ReadBuf}, + sync::mpsc::Receiver, +}; + +/// This struct wraps a [tokio::sync::mpsc::Receiver] and transforms that +/// exposes a [AsyncRead] interface. +pub struct ReceiverReader +where + T: Deref, +{ + receiver: Receiver, + buffer: Vec, +} + +impl ReceiverReader +where + T: Deref, +{ + pub fn new(receiver: Receiver) -> Self { + Self { + receiver, + buffer: Vec::new(), + } + } + + pub fn push_to_buffer(&mut self, data: &[u8]) { + self.buffer.extend_from_slice(data); + } + + pub fn take_buffered_bytes(&mut self, read_buf: &mut ReadBuf<'_>) -> bool { + let len = self.buffer.len().min(read_buf.remaining()); + if len > 0 { + let data: Vec = self.buffer.drain(..len).collect(); + read_buf.put_slice(&data); + true + } else { + false + } + } +} + +impl AsyncRead for ReceiverReader +where + T: Deref, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + let mut read_bytes = this.take_buffered_bytes(buf); + + while buf.has_remaining_mut() { + match this.receiver.poll_recv(cx) { + Poll::Ready(Some(c)) => { + let bytes = c.deref(); + let len = bytes.len().min(buf.remaining()); + buf.put_slice(&bytes[..len]); + read_bytes = true; + + if len < bytes.len() { + this.push_to_buffer(&bytes[len..]); + } + } + Poll::Ready(None) => { + if !read_bytes { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "channel closed", + ))); + } else { + return Poll::Ready(Ok(())); + } + } + Poll::Pending => return Poll::Pending, + }; + } + + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod test { + use super::*; + use tokio::{io::AsyncReadExt, sync::mpsc::channel}; + + #[tokio::test] + async fn receive_once_and_drain_buffer_test() { + let (sender, receiver) = channel::>(2); + let mut rr = ReceiverReader::new(receiver); + sender.send(vec![1, 2]).await.unwrap(); + drop(sender); + + let mut buffer = [0u8; 5]; + rr.read(&mut buffer[0..1]).await.unwrap(); + assert_eq!(buffer[0], 1); + rr.read(&mut buffer[0..1]).await.unwrap(); + assert_eq!(buffer[0], 2); + assert!(rr.read(&mut buffer[0..1]).await.is_err()); + } + + #[tokio::test] + async fn drain_buffer_and_new_read_available_test() { + let (sender, receiver) = channel::>(2); + let mut rr = ReceiverReader::new(receiver); + sender.send(vec![1, 2]).await.unwrap(); + + let mut buffer = [0u8; 5]; + rr.read(&mut buffer[0..1]).await.unwrap(); + assert_eq!(buffer[0], 1); + rr.read(&mut buffer[0..1]).await.unwrap(); + assert_eq!(buffer[0], 2); + sender.send(vec![8, 9]).await.unwrap(); + rr.read(&mut buffer[0..2]).await.unwrap(); + assert_eq!(vec![8, 9], buffer[0..2]); + } + + #[tokio::test] + async fn exhaust_reader_return_result() { + let (sender, receiver) = channel::>(2); + let result = tokio::spawn(async { + let mut rr = ReceiverReader::new(receiver); + let mut buffer = [0u8; 5]; + rr.read(&mut buffer).await.unwrap(); + assert_eq!(vec![1, 2, 3, 4, 5], buffer); + rr + }); + + sender.send(vec![1, 2]).await.unwrap(); + sender.send(vec![3, 4, 5, 6, 7]).await.unwrap(); + let mut rr = result.await.unwrap(); + drop(sender); + + let mut buffer = [0u8; 4]; + let res = rr.read(&mut buffer).await.unwrap(); + // we have exhaust the channel unable to complete the whole 4 bytes + // read request. return the last available bytes + assert_eq!(vec![6, 7], buffer[0..2]); + assert_eq!(res, 2); + } +} diff --git a/tpi_rs/src/utils/mod.rs b/tpi_rs/src/utils/mod.rs index 9803229..430a6cc 100644 --- a/tpi_rs/src/utils/mod.rs +++ b/tpi_rs/src/utils/mod.rs @@ -1,3 +1,5 @@ mod event_listener; #[doc(inline)] pub use event_listener::*; +mod io; +pub use io::*;