From 7744fd2ad982d0ec67388fac369f39774a30de34 Mon Sep 17 00:00:00 2001 From: Adam McQuilkin <46639306+ajmcquilkin@users.noreply.github.com> Date: Fri, 8 Mar 2024 18:47:48 -0800 Subject: [PATCH] Added heartbeat handler --- src/connections/handlers.rs | 125 ++++++++++++++++++++-------------- src/connections/stream_api.rs | 19 ++++-- 2 files changed, 88 insertions(+), 56 deletions(-) diff --git a/src/connections/handlers.rs b/src/connections/handlers.rs index f464856..b5bd2e5 100644 --- a/src/connections/handlers.rs +++ b/src/connections/handlers.rs @@ -1,10 +1,14 @@ +use std::sync::Arc; + use crate::errors_internal::{Error, InternalChannelError, InternalStreamError}; use crate::protobufs; use crate::types::EncodedToRadioPacketWithHeader; use log::{debug, error, trace}; +use prost::Message; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::spawn; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; +use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; @@ -48,19 +52,6 @@ where let mut read_stream = read_stream; - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); - - let handle = tokio::spawn(async move { - loop { - tokio::select! { - _ = rx.recv() => {} - _ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => { - error!("Didn't receive a message on read handler for 60s"); - } - } - } - }); - loop { let mut buffer = [0u8; 1024]; match read_stream.read(&mut buffer).await { @@ -86,20 +77,16 @@ where )); } } - - tx.send("hello there").expect("send failed"); } - handle.abort(); - - trace!("Read handler finished"); + // trace!("Read handler finished"); // Return type should be never (!) } pub fn spawn_write_handler( cancellation_token: CancellationToken, - write_stream: W, + write_stream: Arc>, write_input_rx: tokio::sync::mpsc::UnboundedReceiver, ) -> JoinHandle> where @@ -125,7 +112,7 @@ where async fn start_write_handler( _cancellation_token: CancellationToken, - mut write_stream: W, + write_stream: Arc>, mut write_input_rx: tokio::sync::mpsc::UnboundedReceiver, ) -> Result<(), Error> where @@ -133,22 +120,11 @@ where { debug!("Started write handler"); - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); - - let handle = tokio::spawn(async move { - loop { - tokio::select! { - _ = rx.recv() => {} - _ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => { - error!("Didn't receive a message on write handler for 60s"); - } - } - } - }); - while let Some(message) = write_input_rx.recv().await { trace!("Writing packet data: {:?}", message); + let mut write_stream = write_stream.lock().await; + if let Err(e) = write_stream.write(message.data()).await { error!("Error writing to stream: {:?}", e); return Err(Error::InternalStreamError( @@ -157,12 +133,8 @@ where }, )); } - - tx.send("hello there").expect("send failed"); } - handle.abort(); - debug!("Write handler finished"); Ok(()) @@ -197,26 +169,75 @@ async fn start_processing_handler( let mut buffer = StreamBuffer::new(decoded_packet_tx); - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + while let Some(message) = read_output_rx.recv().await { + trace!("Processing {} bytes from radio", message.data().len()); + buffer.process_incoming_bytes(message); + } + + trace!("Processing read_output_rx channel closed"); +} - let handle = tokio::spawn(async move { - loop { - tokio::select! { - _ = rx.recv() => {} - _ = tokio::time::sleep(tokio::time::Duration::from_secs(60)) => { - error!("Didn't receive a message on processing handler for 60s"); - } +pub fn spawn_heartbeat_handler( + cancellation_token: CancellationToken, + write_stream: Arc>, +) -> JoinHandle> +where + W: AsyncWriteExt + Send + Unpin + 'static, +{ + let handle = start_heartbeat_handler(cancellation_token.clone(), write_stream); + + spawn(async move { + tokio::select! { + _ = cancellation_token.cancelled() => { + debug!("Heartbeat handler cancelled"); + Ok(()) + } + write_result = handle => { + if let Err(e) = &write_result { + error!("Heartbeat handler unexpectedly terminated {e:?}"); + } + write_result } } - }); + }) +} - while let Some(message) = read_output_rx.recv().await { - trace!("Processing {} bytes from radio", message.data().len()); - buffer.process_incoming_bytes(message); - tx.send("hello there").expect("send failed"); +async fn start_heartbeat_handler( + _cancellation_token: CancellationToken, + write_stream: Arc>, +) -> Result<(), Error> +where + W: AsyncWriteExt + Send + Unpin + 'static, +{ + debug!("Started heartbeat handler"); + + loop { + tokio::time::sleep(std::time::Duration::from_secs(5 * 60)).await; + + let mut write_stream = write_stream.lock().await; + + let heartbeat_packet = protobufs::ToRadio::default(); + + let mut buffer = Vec::new(); + match heartbeat_packet.encode(&mut buffer) { + Ok(_) => (), + Err(e) => { + error!("Error encoding heartbeat packet: {:?}", e); + continue; + } + }; + + if let Err(e) = write_stream.write(&buffer).await { + error!("Error writing heartbeat packet to stream: {:?}", e); + return Err(Error::InternalStreamError( + InternalStreamError::StreamWriteError { + source: Box::new(e), + }, + )); + } } - handle.abort(); + // debug!("Heartbeat handler finished"); - trace!("Processing read_output_rx channel closed"); + // Return type should be never (!) } diff --git a/src/connections/stream_api.rs b/src/connections/stream_api.rs index 75713db..2d74796 100644 --- a/src/connections/stream_api.rs +++ b/src/connections/stream_api.rs @@ -1,10 +1,10 @@ use futures_util::future::join3; use log::trace; use prost::Message; -use std::{fmt::Display, marker::PhantomData}; +use std::{fmt::Display, marker::PhantomData, sync::Arc}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - sync::mpsc::UnboundedSender, + sync::{mpsc::UnboundedSender, Mutex}, task::JoinHandle, }; use tokio_util::sync::CancellationToken; @@ -74,6 +74,7 @@ pub struct ConnectedStreamApi { read_handle: JoinHandle>, write_handle: JoinHandle>, processing_handle: JoinHandle>, + heartbeat_handle: JoinHandle>, cancellation_token: CancellationToken, @@ -435,11 +436,16 @@ impl StreamApi { let (read_stream, write_stream) = tokio::io::split(stream_handle.stream); let cancellation_token = CancellationToken::new(); + let write_stream_mutex = Arc::new(Mutex::new(write_stream)); + let read_handle = handlers::spawn_read_handler(cancellation_token.clone(), read_stream, read_output_tx); - let write_handle = - handlers::spawn_write_handler(cancellation_token.clone(), write_stream, write_input_rx); + let write_handle = handlers::spawn_write_handler( + cancellation_token.clone(), + write_stream_mutex.clone(), + write_input_rx, + ); let processing_handle = handlers::spawn_processing_handler( cancellation_token.clone(), @@ -447,6 +453,9 @@ impl StreamApi { decoded_packet_tx, ); + let heartbeat_handle = + handlers::spawn_heartbeat_handler(cancellation_token.clone(), write_stream_mutex); + // Persist channels and kill switch to struct let write_input_tx = write_input_tx; @@ -461,6 +470,7 @@ impl StreamApi { read_handle, write_handle, processing_handle, + heartbeat_handle, cancellation_token, typestate: PhantomData, }, @@ -536,6 +546,7 @@ impl ConnectedStreamApi { read_handle: self.read_handle, write_handle: self.write_handle, processing_handle: self.processing_handle, + heartbeat_handle: self.heartbeat_handle, cancellation_token: self.cancellation_token, typestate: PhantomData, })