diff --git a/io/zenoh-links/zenoh-link-tcp/src/unicast.rs b/io/zenoh-links/zenoh-link-tcp/src/unicast.rs index 59b6fa71af..9c69010903 100644 --- a/io/zenoh-links/zenoh-link-tcp/src/unicast.rs +++ b/io/zenoh-links/zenoh-link-tcp/src/unicast.rs @@ -167,14 +167,14 @@ impl LinkUnicastTrait for LinkUnicastTcp { // WARN assume the drop of TcpStream would clean itself // https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.into_split -// impl Drop for LinkUnicastTcp { -// fn drop(&mut self) { -// // Close the underlying TCP socket -// ZRuntime::TX.handle().block_on(async { -// let _ = self.get_mut_socket().shutdown().await; -// }); -// } -// } +impl Drop for LinkUnicastTcp { + fn drop(&mut self) { + // Close the underlying TCP socket + zenoh_runtime::ZRuntime::Transport.block_in_place(async { + let _ = self.get_mut_socket().shutdown().await; + }); + } +} impl fmt::Display for LinkUnicastTcp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/zenoh/tests/routing.rs b/zenoh/tests/routing.rs index 1c34c89309..72bbb90695 100644 --- a/zenoh/tests/routing.rs +++ b/zenoh/tests/routing.rs @@ -11,22 +11,27 @@ // Contributors: // ZettaScale Zenoh Team, // -use futures::future::try_join_all; -use futures::FutureExt as _; -use std::str::FromStr; -use std::sync::atomic::Ordering; -use std::sync::{atomic::AtomicUsize, Arc}; -use std::time::Duration; -use zenoh::config::{Config, ModeDependentValue}; -use zenoh::prelude::r#async::*; -use zenoh::{value::Value, Result}; +use std::{ + str::FromStr, + sync::{atomic::AtomicUsize, atomic::Ordering, Arc}, + time::Duration, +}; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; +use zenoh::{ + config::{Config, ModeDependentValue}, + prelude::r#async::*, + value::Value, + Result, +}; use zenoh_core::ztimeout; use zenoh_protocol::core::{WhatAmI, WhatAmIMatcher}; -use zenoh_result::{bail, zerror}; +use zenoh_result::bail; -const TIMEOUT: Duration = Duration::from_secs(360); +const TIMEOUT: Duration = Duration::from_secs(30); const MSG_COUNT: usize = 50; const MSG_SIZE: [usize; 2] = [1_024, 131_072]; +// Maximal recipes to run at once +const PARALLEL_RECIPES: usize = 8; #[derive(Debug, Clone, PartialEq, Eq)] enum Task { @@ -44,33 +49,52 @@ impl Task { &self, session: Arc, remaining_checkpoints: Arc, + token: CancellationToken, ) -> Result<()> { match self { // The Sub task checks if the incoming message matches the expected size until it receives enough counts. Self::Sub(ke, expected_size) => { let sub = ztimeout!(session.declare_subscriber(ke).res_async())?; let mut counter = 0; - while let Ok(sample) = sub.recv_async().await { - let recv_size = sample.value.payload.len(); - if recv_size != *expected_size { - bail!("Received payload size {recv_size} mismatches the expected {expected_size}"); - } - counter += 1; - if counter >= MSG_COUNT { - println!("Sub received sufficient amount of messages. Done."); - break; + loop { + tokio::select! { + _ = token.cancelled() => break, + res = sub.recv_async() => { + if let Ok(sample) = res { + let recv_size = sample.value.payload.len(); + if recv_size != *expected_size { + bail!("Received payload size {recv_size} mismatches the expected {expected_size}"); + } + counter += 1; + if counter >= MSG_COUNT { + println!("Sub received sufficient amount of messages. Done."); + break; + } + } + } } } + println!("Sub task done."); } // The Pub task keeps putting messages until all checkpoints are finished. Self::Pub(ke, payload_size) => { let value: Value = vec![0u8; *payload_size].into(); - while remaining_checkpoints.load(Ordering::Relaxed) > 0 { - ztimeout!(session - .put(ke, value.clone()) - .congestion_control(CongestionControl::Block) - .res_async())?; + // while remaining_checkpoints.load(Ordering::Relaxed) > 0 { + loop { + tokio::select! { + _ = token.cancelled() => break, + + // TODO: this won't yield after a timeout raised from recipe + res = tokio::time::timeout(std::time::Duration::from_secs(1), session + .put(ke, value.clone()) + .congestion_control(CongestionControl::Block) + .res()) => { + let _ = res?; + // TODO: check why this is needed + tokio::time::sleep(Duration::from_millis(1)).await; + } + } } println!("Pub task done."); } @@ -79,28 +103,34 @@ impl Task { Self::Get(ke, expected_size) => { let mut counter = 0; while counter < MSG_COUNT { - let replies = - ztimeout!(session.get(ke).timeout(Duration::from_secs(10)).res_async())?; - while let Ok(reply) = replies.recv_async().await { - match reply.sample { - Ok(sample) => { - let recv_size = sample.value.payload.len(); - if recv_size != *expected_size { - bail!("Received payload size {recv_size} mismatches the expected {expected_size}"); + tokio::select! { + _ = token.cancelled() => break, + replies = session.get(ke).timeout(Duration::from_secs(10)).res() => { + let replies = replies?; + while let Ok(reply) = replies.recv_async().await { + match reply.sample { + Ok(sample) => { + let recv_size = sample.value.payload.len(); + if recv_size != *expected_size { + bail!("Received payload size {recv_size} mismatches the expected {expected_size}"); + } + } + + Err(err) => { + log::warn!( + "Sample got from {} failed to unwrap! Error: {}.", + ke, + err + ); + continue; + } } - } - - Err(err) => { - log::warn!( - "Sample got from {} failed to unwrap! Error: {}.", - ke, - err - ); - continue; + counter += 1; } } - counter += 1; } + // TODO: check why this is needed + tokio::time::sleep(Duration::from_millis(1)).await; } println!("Get got sufficient amount of messages. Done."); } @@ -111,16 +141,11 @@ impl Task { let sample = Sample::try_from(ke.clone(), vec![0u8; *payload_size])?; loop { - futures::select! { + tokio::select! { + _ = token.cancelled() => break, query = queryable.recv_async() => { query?.reply(Ok(sample.clone())).res_async().await?; }, - - _ = tokio::time::sleep(Duration::from_millis(100)).fuse() => { - if remaining_checkpoints.load(Ordering::Relaxed) == 0 { - break; - } - } } } println!("Queryable task done."); @@ -134,18 +159,17 @@ impl Task { // Mark one checkpoint is finished. Self::Checkpoint => { if remaining_checkpoints.fetch_sub(1, Ordering::Relaxed) <= 1 { + token.cancel(); println!("The end of the recipe."); } } // Wait until all checkpoints are done Self::Wait => { - while remaining_checkpoints.load(Ordering::Relaxed) > 0 { - tokio::time::sleep(Duration::from_millis(100)).await; - } + token.cancelled().await; } } - Result::Ok(()) + Ok(()) } } @@ -198,6 +222,7 @@ impl Default for Node { #[derive(Debug, Clone)] struct Recipe { nodes: Vec, + token: CancellationToken, } // Display the Recipe as [NodeName1, NodeName2, ...] @@ -211,7 +236,10 @@ impl std::fmt::Display for Recipe { impl Recipe { fn new(nodes: impl IntoIterator) -> Self { let nodes = nodes.into_iter().collect(); - Self { nodes } + Self { + nodes, + token: CancellationToken::new(), + } } fn num_checkpoints(&self) -> usize { @@ -222,66 +250,81 @@ impl Recipe { let num_checkpoints = self.num_checkpoints(); let remaining_checkpoints = Arc::new(AtomicUsize::new(num_checkpoints)); println!( - "Recipe {} beging testing with {} checkpoint(s).", + "Recipe {} begin testing with {} checkpoint(s).", &self, &num_checkpoints ); + let mut recipe_join_set = tokio::task::JoinSet::new(); + // All concurrent tasks to run - let futures = self.nodes.clone().into_iter().map(move |node| { - let receipe_name = self.to_string(); + for node in self.nodes.clone() { // All nodes share the same checkpoint counter let remaining_checkpoints = remaining_checkpoints.clone(); + let token = self.token.clone(); + + let recipe_task = async move { + // Initiate + let session = { + // Load the config and build up a session + let config = { + let mut config = node.config.unwrap_or_default(); + config.set_mode(Some(node.mode)).unwrap(); + config.scouting.multicast.set_enabled(Some(false)).unwrap(); + config + .listen + .set_endpoints(node.listen.iter().map(|x| x.parse().unwrap()).collect()) + .unwrap(); + config + .connect + .set_endpoints( + node.connect.iter().map(|x| x.parse().unwrap()).collect(), + ) + .unwrap(); + config + }; - async move { - // Load the config and build up a session - let config = { - let mut config = node.config.unwrap_or_default(); - config.set_mode(Some(node.mode)).unwrap(); - config.scouting.multicast.set_enabled(Some(false)).unwrap(); - config - .listen - .set_endpoints(node.listen.iter().map(|x| x.parse().unwrap()).collect()) - .unwrap(); - config - .connect - .set_endpoints(node.connect.iter().map(|x| x.parse().unwrap()).collect()) - .unwrap(); - config - }; + // Warmup before the session starts + tokio::time::sleep(node.warmup).await; + println!("Node: {} starting...", &node.name); - // Warmup before the session starts - tokio::time::sleep(node.warmup).await; - println!("Node: {} starting...", &node.name); + // In case of client can't connect to some peers/routers + let session = loop { + if let Ok(session) = zenoh::open(config.clone()).res_async().await { + break session.into_arc(); + } else { + tokio::time::sleep(Duration::from_secs(1)).await; + } + }; - // In case of client can't connect to some peers/routers - let session = loop { - if let Ok(session) = zenoh::open(config.clone()).res_async().await { - break session.into_arc(); - } else { - tokio::time::sleep(Duration::from_secs(1)).await; - } + session }; - // Each node consists of a specified session associated with tasks to run - let node_tasks = node.con_task.into_iter().map(|seq_tasks| { + let mut node_join_set = tokio::task::JoinSet::new(); + for seq_tasks in node.con_task.into_iter() { + let token = token.clone(); + // The tasks share the same session and checkpoint counter let session = session.clone(); let remaining_checkpoints = remaining_checkpoints.clone(); - - tokio::task::spawn(async move { + node_join_set.spawn(async move { // Tasks in seq_tasks would execute serially - for t in seq_tasks { - t.run(session.clone(), remaining_checkpoints.clone()) - .await?; + for task in seq_tasks { + task.run( + session.clone(), + remaining_checkpoints.clone(), + token.clone(), + ) + .await?; } Result::Ok(()) - }) - }); + }); + } - // All tasks of the node run together - try_join_all(node_tasks.into_iter().map(tokio::task::spawn)) - .await - .map_err(|e| zerror!("The recipe {} failed due to {}", receipe_name, &e))?; + while let Some(res) = node_join_set.join_next().await { + let _ = res??; + } + // node_task_tracker.close(); + // node_task_tracker.wait().await; // Close the session once all the task assoicated with the node are done. Arc::try_unwrap(session) @@ -292,23 +335,38 @@ impl Recipe { println!("Node: {} is closed.", &node.name); Result::Ok(()) - } - }); + }; + recipe_join_set.spawn(recipe_task); + } // All tasks of the recipe run together - tokio::time::timeout( - TIMEOUT, - try_join_all(futures.into_iter().map(tokio::task::spawn)), - ) - .await - .map_err(|e| format!("The recipe: {} failed due to {}", &self, e))??; + loop { + tokio::select! { + _ = tokio::time::sleep(TIMEOUT) => { + dbg!("Timeout"); + + // Termination + remaining_checkpoints.swap(0, Ordering::Relaxed); + self.token.cancel(); + bail!("Timeout"); + }, + res = recipe_join_set.join_next() => { + if let Some(res) = res { + let _ = res??; + } else { + break + } + } + } + } + Ok(()) } } // Two peers connecting to a common node (either in router or peer mode) can discover each other. // And the message transmission should work even if the common node disappears after a while. -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn gossip() -> Result<()> { env_logger::try_init().unwrap_or_default(); @@ -316,7 +374,8 @@ async fn gossip() -> Result<()> { let ke = String::from("testKeyExprGossip"); let msg_size = 8; - let peer1 = Node { + // node1 in peer mode playing pub and queryable + let node1 = Node { name: format!("Pub & Queryable {}", WhatAmI::Peer), connect: vec![locator.clone()], mode: WhatAmI::Peer, @@ -332,7 +391,8 @@ async fn gossip() -> Result<()> { ]), ..Default::default() }; - let peer2 = Node { + // node2 in peer mode playing sub and get + let node2 = Node { name: format!("Sub & Get {}", WhatAmI::Peer), mode: WhatAmI::Peer, connect: vec![locator.clone()], @@ -351,22 +411,22 @@ async fn gossip() -> Result<()> { ..Default::default() }; + // Recipes: + // - node1: Peer, node2: Peer, node3: Peer + // - node1: Peer, node2: Peer, node3: Router for mode in [WhatAmI::Peer, WhatAmI::Router] { - Recipe::new([ - Node { - name: format!("Router {}", mode), - mode: WhatAmI::Peer, - listen: vec![locator.clone()], - con_task: ConcurrentTask::from([SequentialTask::from([Task::Sleep( - Duration::from_millis(1000), - )])]), - ..Default::default() - }, - peer1.clone(), - peer2.clone(), - ]) - .run() - .await?; + let node3 = Node { + name: format!("Router {}", mode), + mode: WhatAmI::Peer, + listen: vec![locator.clone()], + con_task: ConcurrentTask::from([SequentialTask::from([Task::Sleep( + Duration::from_millis(1000), + )])]), + ..Default::default() + }; + Recipe::new([node1.clone(), node2.clone(), node3]) + .run() + .await?; } println!("Gossip test passed."); @@ -430,10 +490,12 @@ async fn static_failover_brokering() -> Result<()> { } // All test cases varying in -// 1. Message size -// 2. Mode: peer or client -// 3. Spawning order -// #[tokio::test(flavor = "multi_thread", worker_threads = 4)] +// 1. Message size: 2 (sizes) +// 2. Mode: {Client, Peer} x {Client x Peer} x {Router} = 2 x 2 x 1 = 4 (cases) +// 3. Spawning order (delay_in_secs for node1, node2, and node3) = 6 (cases) +// +// Total cases = 2 x 4 x 6 = 96 +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn three_node_combination() -> Result<()> { env_logger::try_init().unwrap_or_default(); let modes = [WhatAmI::Peer, WhatAmI::Client]; @@ -450,7 +512,7 @@ async fn three_node_combination() -> Result<()> { // Ports going to be used: 17451 to 17498 let base_port = 17450; - let recipe_list = modes + let recipe_list: Vec<_> = modes .map(|n1| modes.map(|n2| (n1, n2))) .concat() .into_iter() @@ -464,12 +526,16 @@ async fn three_node_combination() -> Result<()> { let ke_pubsub = format!("three_node_combination_keyexpr_pubsub_{idx}"); let ke_getqueryable = format!("three_node_combination_keyexpr_getqueryable_{idx}"); + use rand::Rng; + let mut rng = rand::thread_rng(); + let router_node = Node { name: format!("Router {}", WhatAmI::Router), mode: WhatAmI::Router, listen: vec![locator.clone()], con_task: ConcurrentTask::from([SequentialTask::from([Task::Wait])]), - warmup: Duration::from_secs(delay1), + warmup: Duration::from_secs(delay1) + + Duration::from_millis(rng.gen_range(0..500)), ..Default::default() }; @@ -487,6 +553,7 @@ async fn three_node_combination() -> Result<()> { ke_pubsub.clone(), msg_size, )])]); + pub_node.warmup += Duration::from_millis(rng.gen_range(0..500)); let mut queryable_node = base; queryable_node.name = format!("Queryable {node1_mode}"); @@ -495,6 +562,7 @@ async fn three_node_combination() -> Result<()> { ke_getqueryable.clone(), msg_size, )])]); + queryable_node.warmup += Duration::from_millis(rng.gen_range(0..500)); (pub_node, queryable_node) }; @@ -513,6 +581,7 @@ async fn three_node_combination() -> Result<()> { Task::Sub(ke_pubsub, msg_size), Task::Checkpoint, ])]); + sub_node.warmup += Duration::from_millis(rng.gen_range(0..500)); let mut get_node = base; get_node.name = format!("Get {node2_mode}"); @@ -520,6 +589,7 @@ async fn three_node_combination() -> Result<()> { Task::Get(ke_getqueryable, msg_size), Task::Checkpoint, ])]); + get_node.warmup += Duration::from_millis(rng.gen_range(0..500)); (sub_node, get_node) }; @@ -529,21 +599,34 @@ async fn three_node_combination() -> Result<()> { Recipe::new([router_node, queryable_node, get_node]), ) }, - ); + ) + .collect(); + + for chunks in recipe_list.chunks(PARALLEL_RECIPES).map(|x| x.to_vec()) { + let mut join_set = tokio::task::JoinSet::new(); + for (pubsub, getqueryable) in chunks { + join_set.spawn(async move { + pubsub.run().await?; + getqueryable.run().await?; + Result::Ok(()) + }); + } - for (pubsub, getqueryable) in recipe_list { - pubsub.run().await?; - getqueryable.run().await?; + while let Some(res) = join_set.join_next().await { + let _ = res??; + } } println!("Three-node combination test passed."); - Result::Ok(()) + return Result::Ok(()); } // All test cases varying in -// 1. Message size -// 2. Mode -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +// 1. Message size: 2 (sizes) +// 2. Mode: {Client, Peer} x {Client, Peer} x {IsFirstListen} = 2 x 2 x 2 = 8 (modes) +// +// Total cases = 2 x 8 = 16 +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] async fn two_node_combination() -> Result<()> { env_logger::try_init().unwrap_or_default(); @@ -560,7 +643,7 @@ async fn two_node_combination() -> Result<()> { let mut idx = 0; // Ports going to be used: 17500 to 17508 let base_port = 17500; - let recipe_list = modes + let recipe_list: Vec<_> = modes .into_iter() .flat_map(|(n1, n2, who)| MSG_SIZE.map(|s| (n1, n2, who, s))) .map(|(node1_mode, node2_mode, who, msg_size)| { @@ -635,11 +718,20 @@ async fn two_node_combination() -> Result<()> { Recipe::new([pub_node, sub_node]), Recipe::new([queryable_node, get_node]), ) - }); - - for (pubsub, getqueryable) in recipe_list { - pubsub.run().await?; - getqueryable.run().await?; + }) + .collect(); + + for chunks in recipe_list.chunks(PARALLEL_RECIPES).map(|x| x.to_vec()) { + let task_tracker = TaskTracker::new(); + for (pubsub, getqueryable) in chunks { + task_tracker.spawn(async move { + pubsub.run().await?; + getqueryable.run().await?; + Result::Ok(()) + }); + } + task_tracker.close(); + task_tracker.wait().await; } println!("Two-node combination test passed.");