Skip to content

Commit

Permalink
Merge pull request #22 from omnia-network/dev
Browse files Browse the repository at this point in the history
Polling timeout
  • Loading branch information
massimoalbarello authored Jan 23, 2024
2 parents f8384de + 3673c8a commit 4f26691
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/canister-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub struct WebsocketMessage {
}

/// Element of the list of messages returned to the WS Gateway after polling.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
#[derive(Debug, CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputMessage {
/// The client that the gateway will forward the message to or that sent the message.
pub client_key: ClientKey,
Expand Down Expand Up @@ -122,7 +122,7 @@ pub enum CanisterServiceMessage {
}

/// List of messages returned to the WS Gateway after polling.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
#[derive(Debug, CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputCertifiedMessages {
pub messages: Vec<CanisterOutputMessage>, // List of messages.
#[serde(with = "serde_bytes")]
Expand Down
2 changes: 1 addition & 1 deletion src/ic-websocket-gateway/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ic_websocket_gateway"
version = "1.3.1"
version = "1.3.2"
edition.workspace = true
rust-version.workspace = true
repository.workspace = true
Expand Down
100 changes: 61 additions & 39 deletions src/ic-websocket-gateway/src/canister_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@ use canister_utils::{
use gateway_state::{CanisterEntry, CanisterPrincipal, ClientSender, GatewayState, PollerState};
use ic_agent::{Agent, AgentError};
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::Sender;
use tokio::{sync::mpsc::Sender, time::timeout};
use tracing::{error, span, trace, warn, Instrument, Level, Span};

enum PollingStatus {
pub(crate) const POLLING_TIMEOUT_MS: u64 = 5_000;

type PollingTimeout = Duration;

#[derive(Debug, PartialEq, Eq)]
pub(crate) enum PollingStatus {
NoMessagesPolled,
MessagesPolled(CanisterOutputCertifiedMessages),
PollerTimedOut,
}

/// Poller which periodically queries a canister for new messages and relays them to the client
Expand Down Expand Up @@ -59,7 +65,7 @@ impl CanisterPoller {
// initially set to None as the first iteration will not have a previous span
let mut previous_polling_iteration_span: Option<Span> = None;
loop {
let polling_iteration_span = span!(Level::TRACE, "Polling Iteration", canister_id = %self.canister_id, polling_iteration = self.polling_iteration);
let polling_iteration_span = span!(Level::TRACE, "Polling Iteration", canister_id = %self.canister_id, polling_iteration = self.polling_iteration, cargo_version = env!("CARGO_PKG_VERSION"));
if let Some(previous_polling_iteration_span) = previous_polling_iteration_span {
// create a follow from relationship between the current and previous polling iteration
// this enables to crawl polling iterations in reverse chronological order
Expand Down Expand Up @@ -97,33 +103,39 @@ impl CanisterPoller {
pub async fn poll_and_relay(&mut self) -> Result<(), String> {
let start_polling_instant = tokio::time::Instant::now();

if let PollingStatus::MessagesPolled(certified_canister_output) =
self.poll_canister().await?
{
let relay_messages_span =
span!(parent: &Span::current(), Level::TRACE, "Relay Canister Messages");
let end_of_queue_reached = {
match certified_canister_output.is_end_of_queue {
Some(is_end_of_queue_reached) => is_end_of_queue_reached,
// if 'is_end_of_queue' is None, the CDK version is < 0.3.1 and does not have such a field
// in this case, assume that the queue is fully drained and therefore will be polled again
// after waiting for 'polling_interval_ms'
None => true,
match self.poll_canister().await? {
PollingStatus::MessagesPolled(certified_canister_output) => {
let relay_messages_span =
span!(parent: &Span::current(), Level::TRACE, "Relay Canister Messages");
let end_of_queue_reached = {
match certified_canister_output.is_end_of_queue {
Some(is_end_of_queue_reached) => is_end_of_queue_reached,
// if 'is_end_of_queue' is None, the CDK version is < 0.3.1 and does not have such a field
// in this case, assume that the queue is fully drained and therefore will be polled again
// after waiting for 'polling_interval_ms'
None => true,
}
};
self.update_nonce(&certified_canister_output)?;
// relaying of messages cannot be done in a separate task for each polling iteration
// as they might interleave and break the correct ordering of messages
// TODO: create a separate task dedicated to relaying messages which receives the messages from the poller via a queue
// and relays them in FIFO order
self.relay_messages(certified_canister_output)
.instrument(relay_messages_span)
.await;
if !end_of_queue_reached {
// if the queue is not fully drained, return immediately so that the next polling iteration can be started
warn!("Canister queue is not fully drained. Polling immediately");
return Ok(());
}
};
self.update_nonce(&certified_canister_output)?;
// relaying of messages cannot be done in a separate task for each polling iteration
// as they might interleave and break the correct ordering of messages
// TODO: create a separate task dedicated to relaying messages which receives the messages from the poller via a queue
// and relays them in FIFO order
self.relay_messages(certified_canister_output)
.instrument(relay_messages_span)
.await;
if !end_of_queue_reached {
// if the queue is not fully drained, return immediately so that the next polling iteration can be started
warn!("Canister queue is not fully drained. Polling immediately");
},
PollingStatus::PollerTimedOut => {
// if the poller timed out, it already waited way too long... return immediately so that the next polling iteration can be started
warn!("Poller timed out. Polling immediately");
return Ok(());
}
},
PollingStatus::NoMessagesPolled => (),
}

// compute the amout of time to sleep for before polling again
Expand All @@ -135,20 +147,26 @@ impl CanisterPoller {
}

/// Polls the canister for messages
async fn poll_canister(&mut self) -> Result<PollingStatus, String> {
pub(crate) async fn poll_canister(&mut self) -> Result<PollingStatus, String> {
trace!("Started polling iteration");

// get messages to be relayed to clients from canister (starting from 'message_nonce')
match ws_get_messages(
&self.agent,
&self.canister_id,
CanisterWsGetMessagesArguments {
nonce: self.next_message_nonce,
},
// the response timeout of the IC CDK is 2 minutes which implies that the poller would be stuck for that long waiting for a response
// to prevent this, we set a timeout of 5 seconds, if the poller does not receive a response in time, it polls immediately
// in case of a timeout, the message nonce is not updated so that no messages are lost by polling immediately again
match timeout(
PollingTimeout::from_millis(POLLING_TIMEOUT_MS),
ws_get_messages(
&self.agent,
&self.canister_id,
CanisterWsGetMessagesArguments {
nonce: self.next_message_nonce,
},
),
)
.await
{
Ok(certified_canister_output) => {
Ok(Ok(certified_canister_output)) => {
let number_of_polled_messages = certified_canister_output.messages.len();
if number_of_polled_messages == 0 {
trace!("No messages polled from canister");
Expand All @@ -161,7 +179,7 @@ impl CanisterPoller {
Ok(PollingStatus::MessagesPolled(certified_canister_output))
}
},
Err(IcError::Agent(e)) => {
Ok(Err(IcError::Agent(e))) => {
if is_recoverable_error(&e) {
// if the error is due to a replica which is either actively malicious or simply unavailable
// or to a malfunctioning boundary node,
Expand All @@ -174,8 +192,12 @@ impl CanisterPoller {
Err(format!("Unrecoverable agent error: {:?}", e))
}
},
Err(IcError::Candid(e)) => Err(format!("Unrecoverable candid error: {:?}", e)),
Err(IcError::Cdk(e)) => Err(format!("Unrecoverable CDK error: {:?}", e)),
Ok(Err(IcError::Candid(e))) => Err(format!("Unrecoverable candid error: {:?}", e)),
Ok(Err(IcError::Cdk(e))) => Err(format!("Unrecoverable CDK error: {:?}", e)),
Err(e) => {
warn!("Poller took too long to retrieve messages: {:?}", e);
Ok(PollingStatus::PollerTimedOut)
},
}
}

Expand Down
1 change: 1 addition & 0 deletions src/ic-websocket-gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ async fn main() -> Result<(), String> {

// must be printed after initializing tracing to ensure that the info are captured
info!("Deployment info: {:?}", deployment_info);
info!("Cargo version: {}", env!("CARGO_PKG_VERSION"));
info!("Gateway Agent principal: {}", gateway_principal);

let tls_config = if deployment_info.tls_certificate_pem_path.is_some()
Expand Down
54 changes: 51 additions & 3 deletions src/ic-websocket-gateway/src/tests/canister_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ mod test {
use lazy_static::lazy_static;
use std::{
sync::{Arc, Mutex},
thread,
time::Duration,
};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tracing::Span;

use crate::canister_poller::{get_nonce_from_message, CanisterPoller};
use crate::canister_poller::{
get_nonce_from_message, CanisterPoller, PollingStatus, POLLING_TIMEOUT_MS,
};

struct MockCanisterOutputCertifiedMessages(CanisterOutputCertifiedMessages);

Expand Down Expand Up @@ -237,7 +240,7 @@ mod test {
let end_polling_instant = tokio::time::Instant::now();
let elapsed = end_polling_instant - start_polling_instant;
// run 'cargo test -- --nocapture' to see the elapsed time
println!("Elapsed: {:?}", elapsed);
println!("Elapsed after relaying (should not sleep): {:?}", elapsed);
assert!(
elapsed > Duration::from_millis(polling_interval_ms)
// Reasonable to expect that the time it takes to sleep
Expand Down Expand Up @@ -294,7 +297,7 @@ mod test {
poller.poll_and_relay().await.expect("Failed to poll");
let end_polling_instant = tokio::time::Instant::now();
let elapsed = end_polling_instant - start_polling_instant;
println!("Elapsed: {:?}", elapsed);
println!("Elapsed after relaying (should sleep): {:?}", elapsed);
assert!(
// The `poll_and_relay` function should not sleep for `polling_interval_ms`
// if the queue is not empty.
Expand All @@ -319,6 +322,51 @@ mod test {
drop(guard);
}

#[tokio::test]
async fn should_not_sleep_after_timeout() {
let server = &*MOCK_SERVER;
let path = "/ws_get_messages";
let mut guard = server.lock().unwrap();
// do not drop the guard until the end of this test to make sure that no other test interleaves and overwrites the mock response
let mock = guard
.mock("GET", path)
.with_chunked_body(|w| {
thread::sleep(Duration::from_millis(POLLING_TIMEOUT_MS + 10));
w.write_all(&vec![])
})
.expect(2)
.create_async()
.await;

let polling_interval_ms = 100;
let (client_channel_tx, _): (Sender<IcWsCanisterMessage>, Receiver<IcWsCanisterMessage>) =
mpsc::channel(100);

let mut poller = create_poller(polling_interval_ms, client_channel_tx);

// check that the poller times out
assert_eq!(
Ok(PollingStatus::PollerTimedOut),
poller.poll_canister().await
);

// check that the poller does not wait for a polling interval after timing out
let start_polling_instant = tokio::time::Instant::now();
poller.poll_and_relay().await.expect("Failed to poll");
let end_polling_instant = tokio::time::Instant::now();
let elapsed = end_polling_instant - start_polling_instant;
println!("Elapsed due to timeout: {:?}", elapsed);
assert!(
// The `poll_canister` function should not sleep for `polling_interval_ms`
// after the poller times out.
elapsed < Duration::from_millis(POLLING_TIMEOUT_MS + polling_interval_ms)
);

mock.assert_async().await;
// just to make it explicit that the guard should be kept for the whole duration of the test
drop(guard);
}

#[tokio::test]
async fn should_terminate_polling_with_error() {
let server = &*MOCK_SERVER;
Expand Down
3 changes: 2 additions & 1 deletion src/ic-websocket-gateway/src/ws_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ impl WsListener {
Level::DEBUG,
"Accept Connection",
client_addr = ?client_addr.ip(),
client_id = self.next_client_id
client_id = self.next_client_id,
cargo_version = env!("CARGO_PKG_VERSION"),
);
let client_id = self.next_client_id;
let tls_acceptor = self.tls_acceptor.clone();
Expand Down

0 comments on commit 4f26691

Please sign in to comment.