diff --git a/Cargo.lock b/Cargo.lock index 981a72b..281e3de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,27 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "CoreFoundation-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0e9889e6db118d49d88d84728d0e964d973a5680befb5f85f55141beea5c20b" +dependencies = [ + "libc", + "mach", +] + +[[package]] +name = "IOKit-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99696c398cbaf669d2368076bdb3d627fb0ce51a26899d7c61228c5c0af3bf4a" +dependencies = [ + "CoreFoundation-sys", + "libc", + "mach", +] + [[package]] name = "actix-codec" version = "0.5.1" @@ -1228,6 +1249,24 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "mach" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd13ee2dd61cc82833ba05ade5a30bb3d63f7ced605ef827063c63078302de9" +dependencies = [ + "libc", +] + +[[package]] +name = "mach2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +dependencies = [ + "libc", +] + [[package]] name = "md-5" version = "0.9.1" @@ -1300,6 +1339,19 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mio-serial" +version = "5.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20a4c60ca5c9c0e114b3bd66ff4aa5f9b2b175442be51ca6c4365d687a97a2ac" +dependencies = [ + "log", + "mio", + "nix 0.26.4", + "serialport", + "winapi", +] + [[package]] name = "nix" version = "0.23.2" @@ -1846,6 +1898,23 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "serialport" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c32634e2bd4311420caa504404a55fad2131292c485c97014cbed89a5899885f" +dependencies = [ + "CoreFoundation-sys", + "IOKit-sys", + "bitflags 1.3.2", + "cfg-if", + "mach2", + "nix 0.26.4", + "regex", + "scopeguard", + "winapi", +] + [[package]] name = "sha-1" version = "0.9.8" @@ -2092,6 +2161,21 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-serial" +version = "5.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa6e2e4cf0520a99c5f87d5abb24172b5bd220de57c3181baaaa5440540c64aa" +dependencies = [ + "bytes", + "cfg-if", + "futures", + "log", + "mio-serial", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-util" version = "0.7.9" @@ -2144,6 +2228,7 @@ dependencies = [ "simple_logger", "tempdir", "tokio", + "tokio-serial", "tokio-util", ] diff --git a/Cargo.toml b/Cargo.toml index 4e321d3..6b0914e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = ["bmcd", "tpi_rs"] +resolver = "2" [workspace.package] version = "1.3.0" diff --git a/bmcd/src/into_legacy_response.rs b/bmcd/src/into_legacy_response.rs index cd4a042..b77bb61 100644 --- a/bmcd/src/into_legacy_response.rs +++ b/bmcd/src/into_legacy_response.rs @@ -8,6 +8,7 @@ use std::{borrow::Cow, fmt::Display}; pub enum LegacyResponse { Success(Option), Error(StatusCode, Cow<'static, str>), + UartData(String), } impl LegacyResponse { @@ -83,6 +84,7 @@ impl Display for LegacyResponse { "{}", s.as_ref().map(|json| json.to_string()).unwrap_or_default() ), + LegacyResponse::UartData(s) => write!(f, "{}", s), LegacyResponse::Error(_, msg) => write!(f, "{}", msg), } } @@ -100,19 +102,26 @@ pub type LegacyResult = Result; impl From for HttpResponse { fn from(value: LegacyResponse) -> Self { - let (response, result) = match value { - LegacyResponse::Success(None) => { - (StatusCode::OK, serde_json::Value::String("ok".to_string())) - } - LegacyResponse::Success(Some(body)) => (StatusCode::OK, body), - LegacyResponse::Error(status_code, msg) => { - (status_code, serde_json::Value::String(msg.into_owned())) - } + let (response, result, is_uart) = match value { + LegacyResponse::Success(None) => ( + StatusCode::OK, + serde_json::Value::String("ok".to_string()), + false, + ), + LegacyResponse::Success(Some(body)) => (StatusCode::OK, body, false), + LegacyResponse::UartData(d) => (StatusCode::OK, serde_json::Value::String(d), true), + LegacyResponse::Error(status_code, msg) => ( + status_code, + serde_json::Value::String(msg.into_owned()), + false, + ), }; - let msg = json!({ - "response": [{ "result": result }] - }); + let keyname = if is_uart { "uart" } else { "result" }; + + let msg = json! {{ + "response": [{ keyname: result }] + }}; HttpResponseBuilder::new(response).json(msg) } diff --git a/bmcd/src/legacy.rs b/bmcd/src/legacy.rs index d8dbb4d..6cf4d0c 100644 --- a/bmcd/src/legacy.rs +++ b/bmcd/src/legacy.rs @@ -12,7 +12,7 @@ use serde_json::json; use std::ops::Deref; use std::str::FromStr; use tokio::sync::mpsc; -use tpi_rs::app::bmc_application::{BmcApplication, UsbConfig}; +use tpi_rs::app::bmc_application::{BmcApplication, Encoding, UsbConfig}; use tpi_rs::middleware::{NodeId, UsbMode, UsbRoute}; use tpi_rs::utils::logging_sink; type Query = web::Query>; @@ -72,7 +72,7 @@ async fn api_entry(bmc: web::Data, query: Query) -> impl Respond }; let Some(ty) = query.get("type") else { - return LegacyResponse::bad_request("Missing `type` parameter") + return LegacyResponse::bad_request("Missing `type` parameter"); }; let bmc = bmc.as_ref(); @@ -89,8 +89,8 @@ async fn api_entry(bmc: web::Data, query: Query) -> impl Respond ("reset", true) => reset_node(bmc, query).await.into(), ("sdcard", true) => format_sdcard().into(), ("sdcard", false) => get_sdcard_info(), - ("uart", true) => write_to_uart(bmc, query).into(), - ("uart", false) => read_from_uart(bmc, query).into(), + ("uart", true) => write_to_uart(bmc, query).await.into(), + ("uart", false) => read_from_uart(bmc, query).await.into(), ("usb", true) => set_usb_mode(bmc, query).await.into(), ("usb", false) => get_usb_mode(bmc).await.into(), _ => ( @@ -155,11 +155,15 @@ fn get_node_param(query: &Query) -> LegacyResult { }; let Ok(node_num) = i32::from_str(node_str) else { - return Err(LegacyResponse::bad_request("Parameter `node` is not a number")); + return Err(LegacyResponse::bad_request( + "Parameter `node` is not a number", + )); }; let Ok(node) = node_num.try_into() else { - return Err(LegacyResponse::bad_request("Parameter `node` is out of range 0..3 of node IDs")); + return Err(LegacyResponse::bad_request( + "Parameter `node` is out of range 0..3 of node IDs", + )); }; Ok(node) @@ -294,30 +298,54 @@ fn get_sdcard_fs_stat() -> anyhow::Result<(u64, u64)> { Ok((total, free)) } -fn write_to_uart(bmc: &BmcApplication, query: Query) -> LegacyResult<()> { +async fn write_to_uart(bmc: &BmcApplication, query: Query) -> LegacyResult<()> { let node = get_node_param(&query)?; let Some(cmd) = query.get("cmd") else { - return Err(LegacyResponse::bad_request("Missing `cmd` parameter")); + return Err(LegacyResponse::bad_request("Missing `cmd` parameter")); }; + let mut data = cmd.clone(); - uart_write(bmc, node, cmd) + data.push_str("\r\n"); + + bmc.serial_write(node, data.as_bytes()) + .await .context("write over UART") .map_err(Into::into) } -fn uart_write(_bmc: &BmcApplication, _node: NodeId, _cmd: &str) -> anyhow::Result<()> { - todo!() -} - -fn read_from_uart(bmc: &BmcApplication, query: Query) -> LegacyResult<()> { +async fn read_from_uart(bmc: &BmcApplication, query: Query) -> LegacyResult { let node = get_node_param(&query)?; - uart_read(bmc, node) - .context("read from UART") - .map_err(Into::into) + let enc = get_encoding_param(&query)?; + let data = bmc.serial_read(node, enc).await?; + + Ok(LegacyResponse::UartData(data)) } -fn uart_read(_bmc: &BmcApplication, _node: NodeId) -> anyhow::Result<()> { - todo!() +fn get_encoding_param(query: &Query) -> LegacyResult { + let Some(enc_str) = query.get("encoding") else { + return Ok(Encoding::Utf8); + }; + + match enc_str.as_str() { + "utf8" => Ok(Encoding::Utf8), + "utf16" | "utf16le" => Ok(Encoding::Utf16 { + little_endian: true, + }), + "utf16be" => Ok(Encoding::Utf16 { + little_endian: false, + }), + "utf32" | "utf32le" => Ok(Encoding::Utf32 { + little_endian: true, + }), + "utf32be" => Ok(Encoding::Utf32 { + little_endian: false, + }), + _ => { + let msg = "Invalid `encoding` parameter. Expected: utf8, utf16, utf16le, utf16be, \ + utf32, utf32le, utf32be."; + Err(LegacyResponse::bad_request(msg)) + } + } } /// switches the USB configuration. @@ -402,7 +430,7 @@ async fn handle_flash_request( ))?; let size = u64::from_str(size) - .map_err(|_| LegacyResponse::bad_request("`length` parameter not a number"))?; + .map_err(|_| LegacyResponse::bad_request("`length` parameter is not a number"))?; let peer: String = request .connection_info() diff --git a/bmcd/src/main.rs b/bmcd/src/main.rs index 1e24b24..6514b3b 100644 --- a/bmcd/src/main.rs +++ b/bmcd/src/main.rs @@ -37,6 +37,7 @@ async fn main() -> anyhow::Result<()> { init_logger(); let (tls, tls6) = load_config()?; let bmc = Data::new(BmcApplication::new().await?); + bmc.start_serial_workers().await?; run_event_listener(bmc.clone().into_inner())?; let flash_service = Data::new(FlashService::new()); let authentication = Arc::new(LinuxAuthenticator::new("/api/bmc/authenticate").await?); diff --git a/tpi_rs/Cargo.toml b/tpi_rs/Cargo.toml index 719fd2d..31020bc 100644 --- a/tpi_rs/Cargo.toml +++ b/tpi_rs/Cargo.toml @@ -19,6 +19,7 @@ rockusb = { version = "0.1.1" } rusb = "0.9.3" rustpiboot = { git = "https://github.com/ruslashev/rustpiboot.git", rev="89e6497"} serde = { version = "1.0.188", features = ["derive"] } +tokio-serial = { version = "5.4.4", features = ["rt", "codec"] } anyhow.workspace = true log.workspace = true diff --git a/tpi_rs/src/app/bmc_application.rs b/tpi_rs/src/app/bmc_application.rs index 8dcac6f..3c235bc 100644 --- a/tpi_rs/src/app/bmc_application.rs +++ b/tpi_rs/src/app/bmc_application.rs @@ -5,8 +5,10 @@ use crate::middleware::firmware_update::{ use crate::middleware::persistency::app_persistency::ApplicationPersistency; use crate::middleware::persistency::app_persistency::PersistencyBuilder; use crate::middleware::power_controller::PowerController; +use crate::middleware::serial::SerialConnections; use crate::middleware::usbboot; use crate::middleware::{pin_controller::PinController, NodeId, UsbMode, UsbRoute}; +use crate::utils::{string_from_utf16, string_from_utf32}; use anyhow::{ensure, Context}; use log::{debug, info, trace}; use serde::{Deserialize, Serialize}; @@ -34,12 +36,20 @@ pub enum UsbConfig { Node(NodeId, UsbRoute), } +/// Encodings used when reading from a serial port +pub enum Encoding { + Utf8, + Utf16 { little_endian: bool }, + Utf32 { little_endian: bool }, +} + #[derive(Debug)] pub struct BmcApplication { pub(super) pin_controller: PinController, pub(super) power_controller: PowerController, pub(super) app_db: ApplicationPersistency, pub(super) nodes_on: AtomicBool, + serial: SerialConnections, } impl BmcApplication { @@ -51,12 +61,14 @@ impl BmcApplication { .register_key(USB_CONFIG, &UsbConfig::UsbA(NodeId::Node1)) .build() .await?; + let serial = SerialConnections::new()?; let instance = Self { pin_controller, power_controller, app_db, nodes_on: AtomicBool::new(false), + serial, }; instance.initialize().await?; @@ -290,4 +302,23 @@ impl BmcApplication { Command::new("shutdown").args(["-r", "now"]).spawn()?; Ok(()) } + + pub async fn start_serial_workers(&self) -> anyhow::Result<()> { + Ok(self.serial.run().await?) + } + + pub async fn serial_read(&self, node: NodeId, encoding: Encoding) -> anyhow::Result { + let bytes = self.serial.read(node).await?; + + let res = match encoding { + Encoding::Utf8 => String::from_utf8_lossy(&bytes).to_string(), + Encoding::Utf16 { little_endian } => string_from_utf16(&bytes, little_endian), + Encoding::Utf32 { little_endian } => string_from_utf32(&bytes, little_endian), + }; + Ok(res) + } + + pub async fn serial_write(&self, node: NodeId, data: &[u8]) -> anyhow::Result<()> { + Ok(self.serial.write(node, data).await?) + } } diff --git a/tpi_rs/src/c_interface.rs b/tpi_rs/src/c_interface.rs index fef198b..c5d6ed5 100644 --- a/tpi_rs/src/c_interface.rs +++ b/tpi_rs/src/c_interface.rs @@ -141,7 +141,7 @@ pub extern "C" fn tpi_get_node_power(node: c_int) -> c_int { #[no_mangle] pub extern "C" fn tpi_reset_node(node: c_int) { let Ok(node_id) = node.try_into().map_err(|e| log::error!("{}", e)) else { - return ; + return; }; execute_routine(|bmc| Box::pin(bmc.reset_node(node_id))); } @@ -215,7 +215,7 @@ pub unsafe extern "C" fn tpi_flash_node(node: c_int, image_path: *const c_char) let node_image = PathBuf::from(bstr); let Ok(node_id) = node.try_into() else { - return FlashingResult::InvalidArgs + return FlashingResult::InvalidArgs; }; RUNTIME.block_on(async move { diff --git a/tpi_rs/src/middleware/firmware_update/rockusb_fwudate.rs b/tpi_rs/src/middleware/firmware_update/rockusb_fwudate.rs index 9bd964a..38c3542 100644 --- a/tpi_rs/src/middleware/firmware_update/rockusb_fwudate.rs +++ b/tpi_rs/src/middleware/firmware_update/rockusb_fwudate.rs @@ -1,5 +1,4 @@ use super::transport::{StdFwUpdateTransport, StdTransportWrapper}; -use rusb::DeviceDescriptor; use super::{FlashProgress, FlashingError, FlashingErrorExt}; use crate::middleware::firmware_update::FlashStatus; use crate::middleware::usbboot; @@ -9,6 +8,7 @@ use rockfile::boot::{ RkBootEntry, RkBootEntryBytes, RkBootHeader, RkBootHeaderBytes, RkBootHeaderEntry, }; use rockusb::libusb::{Transport, TransportIO}; +use rusb::DeviceDescriptor; use rusb::GlobalContext; use std::{mem::size_of, ops::Range, time::Duration}; use tokio::sync::mpsc::Sender; @@ -23,7 +23,12 @@ pub async fn new_rockusb_transport( let mut transport = Transport::from_usb_device(device.open().map_err_into_logged_usb(logging)?) .map_err(|_| FlashingError::UsbError)?; - if BootMode::Maskrom == device.device_descriptor().map_err_into_logged_usb(logging)?.into() { + if BootMode::Maskrom + == device + .device_descriptor() + .map_err_into_logged_usb(logging)? + .into() + { info!("Maskrom mode detected. loading usb-plug.."); transport = download_boot(&mut transport, logging).await?; logging @@ -134,10 +139,10 @@ async fn load_boot_entries( } #[derive(Debug, PartialEq, Eq, Copy, Clone)] - pub enum BootMode { +pub enum BootMode { Maskrom = 0, - Loader = 1, - } + Loader = 1, +} impl From for BootMode { fn from(dd: DeviceDescriptor) -> BootMode { @@ -146,5 +151,5 @@ impl From for BootMode { 1 => BootMode::Loader, _ => unreachable!(), } - } - } + } +} diff --git a/tpi_rs/src/middleware/mod.rs b/tpi_rs/src/middleware/mod.rs index 3a4f0ca..84c724c 100644 --- a/tpi_rs/src/middleware/mod.rs +++ b/tpi_rs/src/middleware/mod.rs @@ -4,6 +4,7 @@ mod helpers; pub mod persistency; pub mod pin_controller; pub mod power_controller; +pub mod serial; pub mod usbboot; #[repr(C)] diff --git a/tpi_rs/src/middleware/serial.rs b/tpi_rs/src/middleware/serial.rs new file mode 100644 index 0000000..5ebf2b0 --- /dev/null +++ b/tpi_rs/src/middleware/serial.rs @@ -0,0 +1,170 @@ +//! Handlers for UART connections to/from nodes +use std::error::Error; +use std::fmt::Display; +use std::sync::Arc; + +use anyhow::Result; +use bytes::{Bytes, BytesMut}; +use futures::{SinkExt, StreamExt}; +use tokio::sync::mpsc::{channel, Sender}; +use tokio::sync::Mutex; +use tokio_serial::{DataBits, Parity, SerialPortBuilderExt, StopBits}; +use tokio_util::codec::{BytesCodec, Decoder}; + +use crate::utils::ring_buf::RingBuffer; + +use super::NodeId; + +const OUTPUT_BUF_SIZE: usize = 16 * 1024; + +#[derive(Debug)] +pub struct SerialConnections { + handlers: Vec>, +} + +impl SerialConnections { + pub fn new() -> Result { + let paths = ["/dev/ttyS2", "/dev/ttyS1", "/dev/ttyS4", "/dev/ttyS5"]; + + let handlers: Vec> = paths + .iter() + .enumerate() + .map(|(i, path)| Mutex::new(Handler::new(i + 1, path))) + .collect(); + + Ok(SerialConnections { handlers }) + } + + pub async fn run(&self) -> Result<(), SerialError> { + for h in &self.handlers { + h.lock().await.start_reader()?; + } + Ok(()) + } + + pub async fn read(&self, node: NodeId) -> Result { + let idx = node as usize; + self.handlers[idx].lock().await.read().await + } + + pub async fn write>(&self, node: NodeId, data: B) -> Result<(), SerialError> { + let idx = node as usize; + self.handlers[idx].lock().await.write(data.into()).await + } +} + +#[derive(Debug)] +struct Handler { + node: usize, + path: &'static str, + ring_buffer: Arc>>, + worker_context: Option>, +} + +impl Handler { + fn new(node: usize, path: &'static str) -> Self { + Handler { + node, + path, + ring_buffer: Arc::new(Mutex::new(RingBuffer::default())), + worker_context: None, + } + } + + async fn write>(&self, data: B) -> Result<(), SerialError> { + let Some(sender) = &self.worker_context else { + return Err(SerialError::NotStarted); + }; + + sender + .send(data.into()) + .await + .map_err(|e| SerialError::InternalError(e.to_string())) + } + + async fn read(&self) -> Result { + if self.worker_context.is_none() { + return Err(SerialError::NotStarted); + }; + + Ok(self.ring_buffer.lock().await.read().into()) + } + + fn start_reader(&mut self) -> Result<(), SerialError> { + if self.worker_context.take().is_some() { + return Err(SerialError::AlreadyRunning); + }; + + let baud_rate = 115200; + let mut port = tokio_serial::new(self.path, baud_rate) + .data_bits(DataBits::Eight) + .parity(Parity::None) + .stop_bits(StopBits::One) + .open_native_async() + .map_err(|e| SerialError::InternalError(e.to_string()))?; + + // Disable exclusivity of the port to allow other applications to open it. + // Not a reason to abort if we can't. + if let Err(e) = port.set_exclusive(false) { + log::warn!("Unable to set exclusivity of port {}: {}", self.path, e); + } + + let (sender, mut receiver) = channel::(64); + self.worker_context = Some(sender); + + let node = self.node; + let buffer = self.ring_buffer.clone(); + tokio::spawn(async move { + let (mut sink, mut stream) = BytesCodec::new().framed(port).split(); + loop { + tokio::select! { + res = receiver.recv() => { + let Some(data) = res else { + log::error!("error sending data to uart"); + break; + }; + + if let Err(e) = sink.send(data).await { + log::error!("{}", e); + } + }, + res = stream.next() => { + let Some(res) = res else { + log::error!("Error reading serial stream of node {}", node); + break; + }; + + let Ok(bytes) = res else { + log::error!("Serial stream of node {} has closed", node); + break; + }; + buffer.lock().await.write(&bytes); + + }, + } + } + log::warn!("exiting serial worker"); + }); + + Ok(()) + } +} + +#[derive(Debug)] +pub enum SerialError { + NotStarted, + AlreadyRunning, + InternalError(String), +} + +impl Error for SerialError {} + +impl Display for SerialError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SerialError::NotStarted => write!(f, "serial worker not started"), + SerialError::AlreadyRunning => write!(f, "already running"), + SerialError::InternalError(e) => e.fmt(f), + } + } +} diff --git a/tpi_rs/src/utils/mod.rs b/tpi_rs/src/utils/mod.rs index 07c30bf..72285d4 100644 --- a/tpi_rs/src/utils/mod.rs +++ b/tpi_rs/src/utils/mod.rs @@ -1,4 +1,6 @@ mod event_listener; +pub mod ring_buf; + use std::fmt::Display; #[doc(inline)] @@ -18,3 +20,46 @@ pub fn logging_sink( } }) } + +pub fn string_from_utf16(bytes: &[u8], little_endian: bool) -> String { + let u16s = bytes.chunks_exact(2).map(|pair| { + let Ok(owned) = pair.try_into() else { + unreachable!() + }; + + if little_endian { + u16::from_le_bytes(owned) + } else { + u16::from_be_bytes(owned) + } + }); + + let mut string = char::decode_utf16(u16s) + .map(|r| r.unwrap_or(char::REPLACEMENT_CHARACTER)) + .collect::(); + + if bytes.len() % 2 == 1 { + string.push(char::REPLACEMENT_CHARACTER) + } + + string +} + +pub fn string_from_utf32(bytes: &[u8], little_endian: bool) -> String { + bytes + .chunks(4) + .map(|slice| { + let Ok(owned) = slice.try_into() else { + return char::REPLACEMENT_CHARACTER; + }; + + let scalar = if little_endian { + u32::from_le_bytes(owned) + } else { + u32::from_be_bytes(owned) + }; + + char::from_u32(scalar).unwrap_or(char::REPLACEMENT_CHARACTER) + }) + .collect() +} diff --git a/tpi_rs/src/utils/ring_buf.rs b/tpi_rs/src/utils/ring_buf.rs new file mode 100644 index 0000000..811b745 --- /dev/null +++ b/tpi_rs/src/utils/ring_buf.rs @@ -0,0 +1,119 @@ +#[derive(Debug)] +pub struct RingBuffer { + buf: Vec, + idx: usize, + len: usize, +} + +impl RingBuffer { + pub fn write(&mut self, data: &[u8]) { + let remaining = C - (self.idx + self.len); + let mut data_idx = 0; + + while data_idx < data.len() { + let copy_len = (data.len() - data_idx).min(remaining); + + let beg = (self.idx + self.len) % C; + let end = (self.idx + self.len + copy_len) % C; + + if beg < end { + self.buf[beg..end].copy_from_slice(&data[data_idx..data_idx + copy_len]); + } else { + self.buf[beg..].copy_from_slice(&data[data_idx..data_idx + copy_len]); + } + + self.len += copy_len; + data_idx += copy_len; + + if self.len > C { + self.idx = (self.idx + copy_len) % C; + self.len = C; + } + } + } + + pub fn read(&mut self) -> Vec { + let to_read = self.len; + let remaining = C - self.idx; + let mut bytes_read = 0; + let mut output = Vec::with_capacity(to_read); + + while bytes_read < to_read { + let read_len = (to_read - bytes_read).min(remaining); + + output.extend_from_slice(&self.buf[self.idx..self.idx + read_len]); + + self.len -= read_len; + bytes_read += read_len; + + self.idx = (self.idx + read_len) % C; + } + + output + } +} + +impl Default for RingBuffer { + fn default() -> Self { + Self { + buf: vec![0; C], + idx: 0, + len: 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::RingBuffer; + + #[test] + fn test_simple() { + let mut b = RingBuffer::<5>::new(); + let empty: Vec = vec![]; + assert_eq!(b.read(), empty); + b.write(&[1, 2, 3]); + assert_eq!(b.read(), [1, 2, 3]); + assert_eq!(b.read(), empty); + } + + #[test] + fn test_exact_size() { + let mut b = RingBuffer::<3>::new(); + b.write(&[1, 2, 3]); + assert_eq!(b.read(), [1, 2, 3]); + b.write(&[4, 5, 6, 7]); + assert_eq!(b.read(), [5, 6, 7]); + } + + #[test] + fn test_overflow_simple() { + let mut b = RingBuffer::<5>::new(); + b.write(&[1, 2, 3, 4, 5, 6, 7]); + assert_eq!(b.read(), [3, 4, 5, 6, 7]); + } + + #[test] + fn test_overflow() { + let mut b = RingBuffer::<5>::new(); + b.write(&[1, 2, 3]); + b.write(&[4, 5, 6, 7]); + assert_eq!(b.read(), [3, 4, 5, 6, 7]); + } + + #[test] + fn test_overflow_exact() { + let mut b = RingBuffer::<5>::new(); + b.write(&[1, 2]); + b.write(&[3, 4, 5, 6, 7]); + assert_eq!(b.read(), [3, 4, 5, 6, 7]); + } + + #[test] + fn test_overflow_wrap() { + let mut b = RingBuffer::<5>::new(); + b.write(&[1, 2]); + b.write(&[3, 4, 5, 6, 7, 8, 9]); + assert_eq!(b.read(), [5, 6, 7, 8, 9]); + } +}