From 4a3dc027b95279342b7154385ff23743c0050d98 Mon Sep 17 00:00:00 2001 From: Luca Cominardi Date: Fri, 15 Sep 2023 12:04:04 +0200 Subject: [PATCH] Allow to join multiple multicast groups on UDP --- Cargo.toml | 2 +- commons/zenoh-protocol/src/core/endpoint.rs | 52 ++++++++++++ io/zenoh-links/zenoh-link-udp/src/lib.rs | 1 + .../zenoh-link-udp/src/multicast.rs | 82 ++++++++++--------- io/zenoh-transport/src/multicast/link.rs | 29 ++----- io/zenoh-transport/src/multicast/transport.rs | 1 - 6 files changed, 105 insertions(+), 62 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3b95999788..c97bceb42a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -135,7 +135,7 @@ serde_yaml = "0.9.19" sha3 = "0.10.6" shared_memory = "0.12.4" shellexpand = "3.0.0" -socket2 = "0.5.1" +socket2 = { version ="0.5.1", features = [ "all" ] } stop-token = "0.7.0" syn = "1.0.109" tide = "0.16.0" diff --git a/commons/zenoh-protocol/src/core/endpoint.rs b/commons/zenoh-protocol/src/core/endpoint.rs index 316e007476..e596b78bde 100644 --- a/commons/zenoh-protocol/src/core/endpoint.rs +++ b/commons/zenoh-protocol/src/core/endpoint.rs @@ -22,6 +22,7 @@ pub const METADATA_SEPARATOR: char = '?'; pub const LIST_SEPARATOR: char = ';'; pub const FIELD_SEPARATOR: char = '='; pub const CONFIG_SEPARATOR: char = '#'; +pub const VALUE_SEPARATOR: char = '|'; fn split_once(s: &str, c: char) -> (&str, &str) { match s.find(c) { @@ -98,6 +99,17 @@ impl Parameters { Self::iter(s).find(|x| x.0 == k).map(|x| x.1) } + pub fn values<'s>(s: &'s str, k: &str) -> impl Iterator + DoubleEndedIterator { + match Self::get(s, k) { + Some(v) => v.split(VALUE_SEPARATOR), + None => { + let mut i = "".split(VALUE_SEPARATOR); + i.next(); + i + } + } + } + pub(super) fn insert<'s, I>(iter: I, k: &'s str, v: &'s str) -> String where I: Iterator, @@ -272,6 +284,10 @@ impl<'a> Metadata<'a> { pub fn get(&'a self, k: &str) -> Option<&'a str> { Parameters::get(self.0, k) } + + pub fn values(&'a self, k: &str) -> impl Iterator + DoubleEndedIterator { + Parameters::values(self.0, k) + } } impl AsRef for Metadata<'_> { @@ -385,6 +401,10 @@ impl<'a> Config<'a> { pub fn get(&'a self, k: &str) -> Option<&'a str> { Parameters::get(self.0, k) } + + pub fn values(&'a self, k: &str) -> impl Iterator + DoubleEndedIterator { + Parameters::values(self.0, k) + } } impl AsRef for Config<'_> { @@ -764,11 +784,13 @@ fn endpoints() { .iter() .find(|x| x == &("a", "1")) .unwrap(); + assert_eq!(endpoint.metadata().get("a"), Some("1")); endpoint .metadata() .iter() .find(|x| x == &("b", "2")) .unwrap(); + assert_eq!(endpoint.metadata().get("b"), Some("2")); assert!(endpoint.config().as_str().is_empty()); assert_eq!(endpoint.config().iter().count(), 0); @@ -783,11 +805,13 @@ fn endpoints() { .iter() .find(|x| x == &("a", "1")) .unwrap(); + assert_eq!(endpoint.metadata().get("a"), Some("1")); endpoint .metadata() .iter() .find(|x| x == &("b", "2")) .unwrap(); + assert_eq!(endpoint.metadata().get("a"), Some("1")); assert!(endpoint.config().as_str().is_empty()); assert_eq!(endpoint.config().iter().count(), 0); @@ -800,7 +824,9 @@ fn endpoints() { assert_eq!(endpoint.config().as_str(), "A=1;B=2"); assert_eq!(endpoint.config().iter().count(), 2); endpoint.config().iter().find(|x| x == &("A", "1")).unwrap(); + assert_eq!(endpoint.config().get("A"), Some("1")); endpoint.config().iter().find(|x| x == &("B", "2")).unwrap(); + assert_eq!(endpoint.config().get("B"), Some("2")); let endpoint = EndPoint::from_str("udp/127.0.0.1:7447#B=2;A=1").unwrap(); assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447#A=1;B=2"); @@ -811,7 +837,9 @@ fn endpoints() { assert_eq!(endpoint.config().as_str(), "A=1;B=2"); assert_eq!(endpoint.config().iter().count(), 2); endpoint.config().iter().find(|x| x == &("A", "1")).unwrap(); + assert_eq!(endpoint.config().get("A"), Some("1")); endpoint.config().iter().find(|x| x == &("B", "2")).unwrap(); + assert_eq!(endpoint.config().get("B"), Some("2")); let endpoint = EndPoint::from_str("udp/127.0.0.1:7447?a=1;b=2#A=1;B=2").unwrap(); assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447?a=1;b=2#A=1;B=2"); @@ -824,15 +852,19 @@ fn endpoints() { .iter() .find(|x| x == &("a", "1")) .unwrap(); + assert_eq!(endpoint.metadata().get("a"), Some("1")); endpoint .metadata() .iter() .find(|x| x == &("b", "2")) .unwrap(); + assert_eq!(endpoint.metadata().get("b"), Some("2")); assert_eq!(endpoint.config().as_str(), "A=1;B=2"); assert_eq!(endpoint.config().iter().count(), 2); endpoint.config().iter().find(|x| x == &("A", "1")).unwrap(); + assert_eq!(endpoint.config().get("A"), Some("1")); endpoint.config().iter().find(|x| x == &("B", "2")).unwrap(); + assert_eq!(endpoint.config().get("B"), Some("2")); let endpoint = EndPoint::from_str("udp/127.0.0.1:7447?b=2;a=1#B=2;A=1").unwrap(); assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447?a=1;b=2#A=1;B=2"); @@ -845,15 +877,19 @@ fn endpoints() { .iter() .find(|x| x == &("a", "1")) .unwrap(); + assert_eq!(endpoint.metadata().get("a"), Some("1")); endpoint .metadata() .iter() .find(|x| x == &("b", "2")) .unwrap(); + assert_eq!(endpoint.metadata().get("b"), Some("2")); assert_eq!(endpoint.config().as_str(), "A=1;B=2"); assert_eq!(endpoint.config().iter().count(), 2); endpoint.config().iter().find(|x| x == &("A", "1")).unwrap(); + assert_eq!(endpoint.config().get("A"), Some("1")); endpoint.config().iter().find(|x| x == &("B", "2")).unwrap(); + assert_eq!(endpoint.config().get("B"), Some("2")); let mut endpoint = EndPoint::from_str("udp/127.0.0.1:7447?a=1;b=2").unwrap(); endpoint.metadata_mut().insert("c", "3").unwrap(); @@ -884,4 +920,20 @@ fn endpoints() { .extend([("A", "1"), ("C", "3"), ("B", "2")].iter().copied()) .unwrap(); assert_eq!(endpoint.as_str(), "udp/127.0.0.1:7447#A=1;B=2;C=3"); + + let endpoint = + EndPoint::from_str("udp/127.0.0.1:7447#iface=en0;join=224.0.0.1|224.0.0.2|224.0.0.3") + .unwrap(); + let c = endpoint.config(); + assert_eq!(c.get("iface"), Some("en0")); + assert_eq!(c.get("join"), Some("224.0.0.1|224.0.0.2|224.0.0.3")); + assert_eq!(c.values("iface").count(), 1); + let mut i = c.values("iface"); + assert_eq!(i.next(), Some("en0")); + assert_eq!(c.values("join").count(), 3); + let mut i = c.values("join"); + assert_eq!(i.next(), Some("224.0.0.1")); + assert_eq!(i.next(), Some("224.0.0.2")); + assert_eq!(i.next(), Some("224.0.0.3")); + assert_eq!(i.next(), None); } diff --git a/io/zenoh-links/zenoh-link-udp/src/lib.rs b/io/zenoh-links/zenoh-link-udp/src/lib.rs index a9f974d08b..20a48e8f4d 100644 --- a/io/zenoh-links/zenoh-link-udp/src/lib.rs +++ b/io/zenoh-links/zenoh-link-udp/src/lib.rs @@ -86,6 +86,7 @@ impl LocatorInspector for UdpLocatorInspector { pub mod config { pub const UDP_MULTICAST_IFACE: &str = "iface"; + pub const UDP_MULTICAST_JOIN: &str = "join"; } pub async fn get_udp_addrs(address: Address<'_>) -> ZResult> { diff --git a/io/zenoh-links/zenoh-link-udp/src/multicast.rs b/io/zenoh-links/zenoh-link-udp/src/multicast.rs index d5ab9b89ed..96d89a4b49 100644 --- a/io/zenoh-links/zenoh-link-udp/src/multicast.rs +++ b/io/zenoh-links/zenoh-link-udp/src/multicast.rs @@ -19,7 +19,7 @@ use socket2::{Domain, Protocol, Socket, Type}; use std::sync::Arc; use std::{borrow::Cow, fmt}; use zenoh_link_commons::{LinkManagerMulticastTrait, LinkMulticast, LinkMulticastTrait}; -use zenoh_protocol::core::{EndPoint, Locator}; +use zenoh_protocol::core::{Config, EndPoint, Locator}; use zenoh_result::{bail, zerror, Error as ZError, ZResult}; pub struct LinkMulticastUdp { @@ -154,22 +154,16 @@ impl LinkManagerMulticastUdp { async fn new_link_inner( &self, mcast_addr: &SocketAddr, - iface: Option<&str>, + config: Config<'_>, ) -> ZResult<(UdpSocket, UdpSocket, SocketAddr)> { let domain = match mcast_addr.ip() { IpAddr::V4(_) => Domain::IPV4, IpAddr::V6(_) => Domain::IPV6, }; - // Defaults - let _default_ipv4_iface = Ipv4Addr::UNSPECIFIED; - let default_ipv6_iface = 0; - let default_ipv4_addr = Ipv4Addr::UNSPECIFIED; - let default_ipv6_addr = Ipv6Addr::UNSPECIFIED; - // Get default iface address to bind the socket on if provided let mut iface_addr: Option = None; - if let Some(iface) = iface { + if let Some(iface) = config.get(UDP_MULTICAST_IFACE) { iface_addr = match iface.parse() { Ok(addr) => Some(addr), Err(_) => zenoh_util::net::get_unicast_addresses_of_interface(iface)? @@ -206,8 +200,8 @@ impl LinkManagerMulticastUdp { match iface { Some(iface) => iface, None => match mcast_addr.ip() { - IpAddr::V4(_) => IpAddr::V4(default_ipv4_addr), - IpAddr::V6(_) => IpAddr::V6(default_ipv6_addr), + IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED), + IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED), }, } } @@ -242,37 +236,54 @@ impl LinkManagerMulticastUdp { mcast_sock .set_reuse_address(true) .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + mcast_sock + .set_reuse_port(true) + .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; - // Bind the socket - let default_mcast_addr = { - #[cfg(unix)] - { - match mcast_addr.ip() { - IpAddr::V4(ip4) => IpAddr::V4(ip4), - IpAddr::V6(_) => local_addr, - } - } // See UNIX Network Programmping p.212 - #[cfg(windows)] - { - match mcast_addr.ip() { - IpAddr::V4(_) => IpAddr::V4(default_ipv4_addr), - IpAddr::V6(_) => IpAddr::V6(default_ipv6_addr), - } - } + // Bind the socket: let's bing to the unspecified address so we can join and read + // from multiple multicast groups. + let bind_mcast_addr = match mcast_addr.ip() { + IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED), + IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED), }; mcast_sock - .bind(&SocketAddr::new(default_mcast_addr, mcast_addr.port()).into()) + .bind(&SocketAddr::new(bind_mcast_addr, mcast_addr.port()).into()) .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; // Join the multicast group + let join = config.values(UDP_MULTICAST_JOIN); match mcast_addr.ip() { IpAddr::V4(dst_ip4) => match local_addr { - IpAddr::V4(src_ip4) => mcast_sock.join_multicast_v4(&dst_ip4, &src_ip4), - IpAddr::V6(_) => panic!(), + IpAddr::V4(src_ip4) => { + // Join default multicast group + mcast_sock + .join_multicast_v4(&dst_ip4, &src_ip4) + .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + // Join any additional multicast group + for g in join { + let g: Ipv4Addr = + g.parse().map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + mcast_sock + .join_multicast_v4(&g, &src_ip4) + .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + } + } + IpAddr::V6(src_ip6) => bail!("{}: unexepcted IPv6 source address", src_ip6), }, - IpAddr::V6(dst_ip6) => mcast_sock.join_multicast_v6(&dst_ip6, default_ipv6_iface), - } - .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + IpAddr::V6(dst_ip6) => { + // Join default multicast group + mcast_sock + .join_multicast_v6(&dst_ip6, 0) + .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + // Join any additional multicast group + for g in join { + let g: Ipv6Addr = g.parse().map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + mcast_sock + .join_multicast_v6(&g, 0) + .map_err(|e| zerror!("{}: {}", mcast_addr, e))?; + } + } + }; // Build the async_std multicast UdpSocket let mcast_sock: UdpSocket = std::net::UdpSocket::from(mcast_sock).into(); @@ -296,10 +307,7 @@ impl LinkManagerMulticastTrait for LinkManagerMulticastUdp { let mut errs: Vec = vec![]; for maddr in mcast_addrs { - match self - .new_link_inner(&maddr, endpoint.config().get(UDP_MULTICAST_IFACE)) - .await - { + match self.new_link_inner(&maddr, endpoint.config()).await { Ok((mcast_sock, ucast_sock, ucast_addr)) => { let link = Arc::new(LinkMulticastUdp::new( ucast_addr, ucast_sock, maddr, mcast_sock, diff --git a/io/zenoh-transport/src/multicast/link.rs b/io/zenoh-transport/src/multicast/link.rs index 95d3b90036..b430e7efb1 100644 --- a/io/zenoh-transport/src/multicast/link.rs +++ b/io/zenoh-transport/src/multicast/link.rs @@ -30,7 +30,7 @@ use zenoh_core::zlock; use zenoh_link::{LinkMulticast, Locator}; use zenoh_protocol::{ core::{Bits, Priority, Resolution, WhatAmI, ZenohId}, - transport::{BatchSize, Join, KeepAlive, PrioritySn, TransportMessage, TransportSn}, + transport::{BatchSize, Join, PrioritySn, TransportMessage, TransportSn}, }; use zenoh_result::{bail, zerror, ZResult}; use zenoh_sync::{RecyclingObjectPool, Signal}; @@ -40,7 +40,6 @@ pub(super) struct TransportLinkMulticastConfig { pub(super) zid: ZenohId, pub(super) whatami: WhatAmI, pub(super) lease: Duration, - pub(super) keep_alive: usize, pub(super) join_interval: Duration, pub(super) sn_resolution: Bits, pub(super) batch_size: BatchSize, @@ -212,17 +211,13 @@ async fn tx_task( enum Action { Pull((WBatch, usize)), Join, - KeepAlive, Stop, } - async fn pull(pipeline: &mut TransmissionPipelineConsumer, keep_alive: Duration) -> Action { - match pipeline.pull().timeout(keep_alive).await { - Ok(res) => match res { - Some(sb) => Action::Pull(sb), - None => Action::Stop, - }, - Err(_) => Action::KeepAlive, + async fn pull(pipeline: &mut TransmissionPipelineConsumer) -> Action { + match pipeline.pull().await { + Some(sb) => Action::Pull(sb), + None => Action::Stop, } } @@ -236,10 +231,9 @@ async fn tx_task( Action::Join } - let keep_alive = config.join_interval / config.keep_alive as u32; let mut last_join = Instant::now().checked_sub(config.join_interval).unwrap(); loop { - match pull(&mut pipeline, keep_alive) + match pull(&mut pipeline) .race(join(last_join, config.join_interval)) .await { @@ -300,17 +294,6 @@ async fn tx_task( last_join = Instant::now(); } - Action::KeepAlive => { - let message: TransportMessage = KeepAlive.into(); - - #[allow(unused_variables)] // Used when stats feature is enabled - let n = link.send(&message).await?; - #[cfg(feature = "stats")] - { - stats.inc_tx_t_msgs(1); - stats.inc_tx_bytes(n); - } - } Action::Stop => { // Drain the transmission pipeline and write remaining bytes on the wire let mut batches = pipeline.drain(); diff --git a/io/zenoh-transport/src/multicast/transport.rs b/io/zenoh-transport/src/multicast/transport.rs index 67c3ac268d..9d89fabc4b 100644 --- a/io/zenoh-transport/src/multicast/transport.rs +++ b/io/zenoh-transport/src/multicast/transport.rs @@ -248,7 +248,6 @@ impl TransportMulticastInner { zid: self.manager.config.zid, whatami: self.manager.config.whatami, lease: self.manager.config.multicast.lease, - keep_alive: self.manager.config.multicast.keep_alive, join_interval: self.manager.config.multicast.join_interval, sn_resolution: self.manager.config.resolution.get(Field::FrameSN), batch_size,