diff --git a/Cargo.lock b/Cargo.lock
index 29224b6..80d8278 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1343,7 +1343,7 @@ checksum = "576c539151d4769fb4d1a0c25c4108dd18facd04c5695b02cf2d226ab4e43aa5"
 
 [[package]]
 name = "ic_websocket_gateway"
-version = "1.3.1"
+version = "1.3.2"
 dependencies = [
  "async-trait",
  "candid",
diff --git a/src/canister-utils/src/lib.rs b/src/canister-utils/src/lib.rs
index c83dd9e..39293d0 100644
--- a/src/canister-utils/src/lib.rs
+++ b/src/canister-utils/src/lib.rs
@@ -93,7 +93,7 @@ pub struct WebsocketMessage {
 }
 
 /// Element of the list of messages returned to the WS Gateway after polling.
-#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
+#[derive(Debug, CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
 pub struct CanisterOutputMessage {
     /// The client that the gateway will forward the message to or that sent the message.
     pub client_key: ClientKey,
@@ -122,7 +122,7 @@ pub enum CanisterServiceMessage {
 }
 
 /// List of messages returned to the WS Gateway after polling.
-#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
+#[derive(Debug, CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
 pub struct CanisterOutputCertifiedMessages {
     pub messages: Vec<CanisterOutputMessage>, // List of messages.
     #[serde(with = "serde_bytes")]
diff --git a/src/ic-websocket-gateway/Cargo.toml b/src/ic-websocket-gateway/Cargo.toml
index d0ec656..d9c964d 100644
--- a/src/ic-websocket-gateway/Cargo.toml
+++ b/src/ic-websocket-gateway/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "ic_websocket_gateway"
-version = "1.3.1"
+version = "1.3.2"
 edition.workspace = true
 rust-version.workspace = true
 repository.workspace = true
diff --git a/src/ic-websocket-gateway/src/canister_poller.rs b/src/ic-websocket-gateway/src/canister_poller.rs
index 04df3ed..7172d21 100644
--- a/src/ic-websocket-gateway/src/canister_poller.rs
+++ b/src/ic-websocket-gateway/src/canister_poller.rs
@@ -6,12 +6,21 @@ use canister_utils::{
 use gateway_state::{CanisterEntry, CanisterPrincipal, ClientSender, GatewayState, PollerState};
 use ic_agent::{Agent, AgentError};
 use std::{sync::Arc, time::Duration};
-use tokio::sync::mpsc::Sender;
+use tokio::{sync::mpsc::Sender, time::timeout};
 use tracing::{error, span, trace, warn, Instrument, Level, Span};
 
-enum PollingStatus {
+/// Maximum time (in milliseconds) that a single poll of the canister may take before it is abandoned.
+pub(crate) const POLLING_TIMEOUT_MS: u64 = 5_000;
+
+type PollingTimeout = Duration;
+
+/// Outcome of a single polling iteration.
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) enum PollingStatus {
     NoMessagesPolled,
     MessagesPolled(CanisterOutputCertifiedMessages),
+    /// The canister did not answer within `POLLING_TIMEOUT_MS`.
+    PollerTimedOut,
 }
 
 /// Poller which periodically queries a canister for new messages and relays them to the client
@@ -59,7 +68,7 @@ impl CanisterPoller {
         // initially set to None as the first iteration will not have a previous span
         let mut previous_polling_iteration_span: Option<Span> = None;
         loop {
-            let polling_iteration_span = span!(Level::TRACE, "Polling Iteration", canister_id = %self.canister_id, polling_iteration = self.polling_iteration);
+            let polling_iteration_span = span!(Level::TRACE, "Polling Iteration", canister_id = %self.canister_id, polling_iteration = self.polling_iteration, cargo_version = env!("CARGO_PKG_VERSION"));
             if let Some(previous_polling_iteration_span) = previous_polling_iteration_span {
                 // create a follow from relationship between the current and previous polling iteration
                 // this enables to crawl polling iterations in reverse chronological order
@@ -97,33 +106,39 @@ impl CanisterPoller {
     pub async fn poll_and_relay(&mut self) -> Result<(), String> {
         let start_polling_instant = tokio::time::Instant::now();
 
-        if let PollingStatus::MessagesPolled(certified_canister_output) =
-            self.poll_canister().await?
-        {
-            let relay_messages_span =
-                span!(parent: &Span::current(), Level::TRACE, "Relay Canister Messages");
-            let end_of_queue_reached = {
-                match certified_canister_output.is_end_of_queue {
-                    Some(is_end_of_queue_reached) => is_end_of_queue_reached,
-                    // if 'is_end_of_queue' is None, the CDK version is < 0.3.1 and does not have such a field
-                    // in this case, assume that the queue is fully drained and therefore will be polled again
-                    // after waiting for 'polling_interval_ms'
-                    None => true,
+        match self.poll_canister().await? {
+            PollingStatus::MessagesPolled(certified_canister_output) => {
+                let relay_messages_span =
+                    span!(parent: &Span::current(), Level::TRACE, "Relay Canister Messages");
+                let end_of_queue_reached = {
+                    match certified_canister_output.is_end_of_queue {
+                        Some(is_end_of_queue_reached) => is_end_of_queue_reached,
+                        // if 'is_end_of_queue' is None, the CDK version is < 0.3.1 and does not have such a field
+                        // in this case, assume that the queue is fully drained and therefore will be polled again
+                        // after waiting for 'polling_interval_ms'
+                        None => true,
+                    }
+                };
+                self.update_nonce(&certified_canister_output)?;
+                // relaying of messages cannot be done in a separate task for each polling iteration
+                // as they might interleave and break the correct ordering of messages
+                // TODO: create a separate task dedicated to relaying messages which receives the messages from the poller via a queue
+                // and relays them in FIFO order
+                self.relay_messages(certified_canister_output)
+                    .instrument(relay_messages_span)
+                    .await;
+                if !end_of_queue_reached {
+                    // if the queue is not fully drained, return immediately so that the next polling iteration can be started
+                    warn!("Canister queue is not fully drained. Polling immediately");
+                    return Ok(());
                 }
-            };
-            self.update_nonce(&certified_canister_output)?;
-            // relaying of messages cannot be done in a separate task for each polling iteration
-            // as they might interleave and break the correct ordering of messages
-            // TODO: create a separate task dedicated to relaying messages which receives the messages from the poller via a queue
-            // and relays them in FIFO order
-            self.relay_messages(certified_canister_output)
-                .instrument(relay_messages_span)
-                .await;
-            if !end_of_queue_reached {
-                // if the queue is not fully drained, return immediately so that the next polling iteration can be started
-                warn!("Canister queue is not fully drained. Polling immediately");
+            },
+            PollingStatus::PollerTimedOut => {
+                // if the poller timed out, it already waited way too long... return immediately so that the next polling iteration can be started
+                warn!("Poller timed out. Polling immediately");
                 return Ok(());
-            }
+            },
+            PollingStatus::NoMessagesPolled => (),
         }
 
         // compute the amout of time to sleep for before polling again
@@ -135,20 +150,26 @@ impl CanisterPoller {
     }
 
     /// Polls the canister for messages
-    async fn poll_canister(&mut self) -> Result<PollingStatus, String> {
+    pub(crate) async fn poll_canister(&mut self) -> Result<PollingStatus, String> {
         trace!("Started polling iteration");
 
         // get messages to be relayed to clients from canister (starting from 'message_nonce')
-        match ws_get_messages(
-            &self.agent,
-            &self.canister_id,
-            CanisterWsGetMessagesArguments {
-                nonce: self.next_message_nonce,
-            },
+        // the response timeout of the IC CDK is 2 minutes which implies that the poller would be stuck for that long waiting for a response
+        // to prevent this, we set a timeout of 5 seconds, if the poller does not receive a response in time, it polls immediately
+        // in case of a timeout, the message nonce is not updated so that no messages are lost by polling immediately again
+        match timeout(
+            PollingTimeout::from_millis(POLLING_TIMEOUT_MS),
+            ws_get_messages(
+                &self.agent,
+                &self.canister_id,
+                CanisterWsGetMessagesArguments {
+                    nonce: self.next_message_nonce,
+                },
+            ),
         )
         .await
         {
-            Ok(certified_canister_output) => {
+            Ok(Ok(certified_canister_output)) => {
                 let number_of_polled_messages = certified_canister_output.messages.len();
                 if number_of_polled_messages == 0 {
                     trace!("No messages polled from canister");
@@ -161,7 +182,7 @@ impl CanisterPoller {
                     Ok(PollingStatus::MessagesPolled(certified_canister_output))
                 }
             },
-            Err(IcError::Agent(e)) => {
+            Ok(Err(IcError::Agent(e))) => {
                 if is_recoverable_error(&e) {
                     // if the error is due to a replica which is either actively malicious or simply unavailable
                     // or to a malfunctioning boundary node,
@@ -174,8 +195,12 @@ impl CanisterPoller {
                     Err(format!("Unrecoverable agent error: {:?}", e))
                 }
             },
-            Err(IcError::Candid(e)) => Err(format!("Unrecoverable candid error: {:?}", e)),
-            Err(IcError::Cdk(e)) => Err(format!("Unrecoverable CDK error: {:?}", e)),
+            Ok(Err(IcError::Candid(e))) => Err(format!("Unrecoverable candid error: {:?}", e)),
+            Ok(Err(IcError::Cdk(e))) => Err(format!("Unrecoverable CDK error: {:?}", e)),
+            Err(e) => {
+                warn!("Poller took too long to retrieve messages: {:?}", e);
+                Ok(PollingStatus::PollerTimedOut)
+            },
         }
     }
 
diff --git a/src/ic-websocket-gateway/src/main.rs b/src/ic-websocket-gateway/src/main.rs
index c54c8e8..d4c3f09 100644
--- a/src/ic-websocket-gateway/src/main.rs
+++ b/src/ic-websocket-gateway/src/main.rs
@@ -84,6 +84,7 @@ async fn main() -> Result<(), String> {
 
     // must be printed after initializing tracing to ensure that the info are captured
     info!("Deployment info: {:?}", deployment_info);
+    info!("Cargo version: {}", env!("CARGO_PKG_VERSION"));
     info!("Gateway Agent principal: {}", gateway_principal);
 
     let tls_config = if deployment_info.tls_certificate_pem_path.is_some()
diff --git a/src/ic-websocket-gateway/src/tests/canister_poller.rs b/src/ic-websocket-gateway/src/tests/canister_poller.rs
index f9f821d..d50ab8a 100644
--- a/src/ic-websocket-gateway/src/tests/canister_poller.rs
+++ b/src/ic-websocket-gateway/src/tests/canister_poller.rs
@@ -11,12 +11,15 @@ mod test {
     use lazy_static::lazy_static;
     use std::{
         sync::{Arc, Mutex},
+        thread,
         time::Duration,
     };
     use tokio::sync::mpsc::{self, Receiver, Sender};
     use tracing::Span;
 
-    use crate::canister_poller::{get_nonce_from_message, CanisterPoller};
+    use crate::canister_poller::{
+        get_nonce_from_message, CanisterPoller, PollingStatus, POLLING_TIMEOUT_MS,
+    };
 
     struct MockCanisterOutputCertifiedMessages(CanisterOutputCertifiedMessages);
 
@@ -237,7 +240,7 @@ mod test {
         let end_polling_instant = tokio::time::Instant::now();
         let elapsed = end_polling_instant - start_polling_instant;
         // run 'cargo test -- --nocapture' to see the elapsed time
-        println!("Elapsed: {:?}", elapsed);
+        println!("Elapsed after relaying (should sleep): {:?}", elapsed);
         assert!(
             elapsed > Duration::from_millis(polling_interval_ms)
             // Reasonable to expect that the time it takes to sleep
@@ -294,7 +297,7 @@ mod test {
         poller.poll_and_relay().await.expect("Failed to poll");
         let end_polling_instant = tokio::time::Instant::now();
         let elapsed = end_polling_instant - start_polling_instant;
-        println!("Elapsed: {:?}", elapsed);
+        println!("Elapsed after relaying (should not sleep): {:?}", elapsed);
         assert!(
             // The `poll_and_relay` function should not sleep for `polling_interval_ms`
             // if the queue is not empty.
@@ -319,6 +322,50 @@ mod test {
         drop(guard);
     }
 
+    #[tokio::test]
+    async fn should_not_sleep_after_timeout() {
+        let server = &*MOCK_SERVER;
+        let path = "/ws_get_messages";
+        let mut guard = server.lock().unwrap();
+        // do not drop the guard until the end of this test to make sure that no other test interleaves and overwrites the mock response
+        let mock = guard
+            .mock("GET", path)
+            .with_chunked_body(|w| {
+                thread::sleep(Duration::from_millis(POLLING_TIMEOUT_MS + 10));
+                w.write_all(&[])
+            })
+            .expect(2)
+            .create_async()
+            .await;
+
+        let polling_interval_ms = 100;
+        let (client_channel_tx, _) = mpsc::channel(100);
+
+        let mut poller = create_poller(polling_interval_ms, client_channel_tx);
+
+        // check that the poller times out
+        assert_eq!(
+            Ok(PollingStatus::PollerTimedOut),
+            poller.poll_canister().await
+        );
+
+        // check that the poller does not wait for a polling interval after timing out
+        let start_polling_instant = tokio::time::Instant::now();
+        poller.poll_and_relay().await.expect("Failed to poll");
+        let end_polling_instant = tokio::time::Instant::now();
+        let elapsed = end_polling_instant - start_polling_instant;
+        println!("Elapsed due to timeout: {:?}", elapsed);
+        assert!(
+            // The `poll_and_relay` function should not sleep for `polling_interval_ms`
+            // after the poller times out.
+            elapsed < Duration::from_millis(POLLING_TIMEOUT_MS + polling_interval_ms)
+        );
+
+        mock.assert_async().await;
+        // just to make it explicit that the guard should be kept for the whole duration of the test
+        drop(guard);
+    }
+
     #[tokio::test]
     async fn should_terminate_polling_with_error() {
         let server = &*MOCK_SERVER;
diff --git a/src/ic-websocket-gateway/src/ws_listener.rs b/src/ic-websocket-gateway/src/ws_listener.rs
index 629e443..92c09ed 100644
--- a/src/ic-websocket-gateway/src/ws_listener.rs
+++ b/src/ic-websocket-gateway/src/ws_listener.rs
@@ -146,7 +146,8 @@ impl WsListener {
                         Level::DEBUG,
                         "Accept Connection",
                         client_addr = ?client_addr.ip(),
-                        client_id = self.next_client_id
+                        client_id = self.next_client_id,
+                        cargo_version = env!("CARGO_PKG_VERSION"),
                     );
                     let client_id = self.next_client_id;
                     let tls_acceptor = self.tls_acceptor.clone();