From 719633aaa8f9cc819b1b78ea32d9bba356e23f0f Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 16 Mar 2024 20:14:52 +0530 Subject: [PATCH 01/20] feat: implement `Packet::size` for v4/v5 (#817) * refactor: implement size for all v5 packet types * refactor: implement `Packet.size()` for v4 --- rumqttc/src/mqttbytes/v4/connack.rs | 7 +++++++ rumqttc/src/mqttbytes/v4/connect.rs | 7 +++++++ rumqttc/src/mqttbytes/v4/mod.rs | 21 +++++++++++++++++++++ rumqttc/src/v5/mqttbytes/v5/connect.rs | 7 +++++++ rumqttc/src/v5/mqttbytes/v5/mod.rs | 19 +++++++++++++++++++ rumqttc/src/v5/mqttbytes/v5/ping.rs | 8 ++++++++ 6 files changed, 69 insertions(+) diff --git a/rumqttc/src/mqttbytes/v4/connack.rs b/rumqttc/src/mqttbytes/v4/connack.rs index 45391951..65a0da48 100644 --- a/rumqttc/src/mqttbytes/v4/connack.rs +++ b/rumqttc/src/mqttbytes/v4/connack.rs @@ -61,6 +61,13 @@ impl ConnAck { Ok(1 + count + len) } + + pub fn size(&self) -> usize { + let len = self.len(); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } } /// Connection return code type diff --git a/rumqttc/src/mqttbytes/v4/connect.rs b/rumqttc/src/mqttbytes/v4/connect.rs index cdba1014..8732384f 100644 --- a/rumqttc/src/mqttbytes/v4/connect.rs +++ b/rumqttc/src/mqttbytes/v4/connect.rs @@ -132,6 +132,13 @@ impl Connect { buffer[flags_index] = connect_flags; Ok(1 + count + len) } + + pub fn size(&self) -> usize { + let len = self.len(); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } } /// LastWill that broker forwards on behalf of the client diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index abe45612..4ac4b388 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -47,6 +47,27 @@ pub enum Packet { Disconnect, } +impl Packet { + pub fn size(&self) -> usize { + match self { + Self::Publish(publish) => publish.size(), + Self::Subscribe(subscription) => subscription.size(), + Self::Unsubscribe(unsubscribe) => unsubscribe.size(), + Self::ConnAck(ack) => ack.size(), + Self::PubAck(ack) => ack.size(), + Self::SubAck(ack) => ack.size(), + Self::UnsubAck(unsuback) => unsuback.size(), + Self::PubRec(pubrec) => pubrec.size(), + Self::PubRel(pubrel) => pubrel.size(), + Self::PubComp(pubcomp) => pubcomp.size(), + Self::Connect(connect) => connect.size(), + Self::PingReq => PingReq.size(), + Self::PingResp => PingResp.size(), + Self::Disconnect => Disconnect.size(), + } + } +} + /// Reads a stream of bytes and extracts next MQTT packet out of it pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { let fixed_header = check(stream.iter(), max_size)?; diff --git a/rumqttc/src/v5/mqttbytes/v5/connect.rs b/rumqttc/src/v5/mqttbytes/v5/connect.rs index 83918b87..a351c411 100644 --- a/rumqttc/src/v5/mqttbytes/v5/connect.rs +++ b/rumqttc/src/v5/mqttbytes/v5/connect.rs @@ -127,6 +127,13 @@ impl Connect { buffer[flags_index] = connect_flags; Ok(1 + count + len) } + + pub fn size(&self, will: &Option, login: &Option) -> usize { + let len = self.len(will, login); + let remaining_len_size = len_len(len); + + 1 + remaining_len_size + len + } } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index bf4dcb42..01ddef99 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -144,6 +144,25 @@ impl Packet { Self::Disconnect(disconnect) => disconnect.write(write), } } + + pub fn size(&self) -> usize { + match self { + Self::Publish(publish) => publish.size(), + Self::Subscribe(subscription) => subscription.size(), + Self::Unsubscribe(unsubscribe) => unsubscribe.size(), + Self::ConnAck(ack) => ack.size(), + Self::PubAck(ack) => ack.size(), + Self::SubAck(ack) => ack.size(), + Self::UnsubAck(unsuback) => unsuback.size(), + Self::PubRec(pubrec) => pubrec.size(), + Self::PubRel(pubrel) => pubrel.size(), + Self::PubComp(pubcomp) => pubcomp.size(), + Self::Connect(connect, will, login) => connect.size(will, login), + Self::PingReq(req) => req.size(), + Self::PingResp(resp) => resp.size(), + Self::Disconnect(disconnect) => disconnect.size(), + } + } } /// MQTT packet type diff --git a/rumqttc/src/v5/mqttbytes/v5/ping.rs b/rumqttc/src/v5/mqttbytes/v5/ping.rs index 086311ed..d69ead6f 100644 --- a/rumqttc/src/v5/mqttbytes/v5/ping.rs +++ b/rumqttc/src/v5/mqttbytes/v5/ping.rs @@ -9,6 +9,10 @@ impl PingReq { payload.put_slice(&[0xC0, 0x00]); Ok(2) } + + pub fn size(&self) -> usize { + 2 + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -19,4 +23,8 @@ impl PingResp { payload.put_slice(&[0xD0, 0x00]); Ok(2) } + + pub fn size(&self) -> usize { + 2 + } } From cfdd394db003197918779563e8224f8f8683d229 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 16 Mar 2024 21:21:09 +0530 Subject: [PATCH 02/20] refactor: `read` and `write` methods on `v4::Packet` (#821) * refactor: `Packet::read` * refactor: `Packet::write` * test: fix changes in refactor --- benchmarks/parsers/v4.rs | 3 +- rumqttc/src/framed.rs | 4 +- rumqttc/src/mqttbytes/v4/mod.rs | 88 ++++++++++++++++++++------------- rumqttc/src/state.rs | 10 ++-- rumqttc/tests/broker.rs | 2 +- 5 files changed, 65 insertions(+), 42 deletions(-) diff --git a/benchmarks/parsers/v4.rs b/benchmarks/parsers/v4.rs index 8a97bf2f..4fbde7ba 100644 --- a/benchmarks/parsers/v4.rs +++ b/benchmarks/parsers/v4.rs @@ -1,6 +1,7 @@ use bytes::{Buf, BytesMut}; use rumqttc::mqttbytes::v4; use rumqttc::mqttbytes::QoS; +use rumqttc::Packet; use std::time::Instant; mod common; @@ -31,7 +32,7 @@ fn main() { let start = Instant::now(); let mut packets = Vec::with_capacity(count); while output.has_remaining() { - let packet = v4::read(&mut output, 10 * 1024).unwrap(); + let packet = Packet::read(&mut output, 10 * 1024).unwrap(); packets.push(packet); } diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index b0a536e7..9a5d862f 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -58,7 +58,7 @@ impl Network { pub async fn read(&mut self) -> io::Result { loop { - let required = match read(&mut self.read, self.max_incoming_size) { + let required = match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => return Ok(packet), Err(mqttbytes::Error::InsufficientBytes(required)) => required, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), @@ -75,7 +75,7 @@ impl Network { pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { let mut count = 0; loop { - match read(&mut self.read, self.max_incoming_size) { + match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => { state.handle_incoming_packet(packet)?; diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index 4ac4b388..3c9225e8 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -66,45 +66,67 @@ impl Packet { Self::Disconnect => Disconnect.size(), } } -} -/// Reads a stream of bytes and extracts next MQTT packet out of it -pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { - let fixed_header = check(stream.iter(), max_size)?; + /// Reads a stream of bytes and extracts next MQTT packet out of it + pub fn read(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; - // Test with a stream with exactly the size to check border panics - let packet = stream.split_to(fixed_header.frame_length()); - let packet_type = fixed_header.packet_type()?; + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } - if fixed_header.remaining_len == 0 { - // no payload packets - return match packet_type { - PacketType::PingReq => Ok(Packet::PingReq), - PacketType::PingResp => Ok(Packet::PingResp), - PacketType::Disconnect => Ok(Packet::Disconnect), - _ => Err(Error::PayloadRequired), + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), + PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), + PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), + PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), + PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), + PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), + PacketType::Unsubscribe => { + Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?) + } + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, }; - } - let packet = packet.freeze(); - let packet = match packet_type { - PacketType::Connect => Packet::Connect(Connect::read(fixed_header, packet)?), - PacketType::ConnAck => Packet::ConnAck(ConnAck::read(fixed_header, packet)?), - PacketType::Publish => Packet::Publish(Publish::read(fixed_header, packet)?), - PacketType::PubAck => Packet::PubAck(PubAck::read(fixed_header, packet)?), - PacketType::PubRec => Packet::PubRec(PubRec::read(fixed_header, packet)?), - PacketType::PubRel => Packet::PubRel(PubRel::read(fixed_header, packet)?), - PacketType::PubComp => Packet::PubComp(PubComp::read(fixed_header, packet)?), - PacketType::Subscribe => Packet::Subscribe(Subscribe::read(fixed_header, packet)?), - PacketType::SubAck => Packet::SubAck(SubAck::read(fixed_header, packet)?), - PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(fixed_header, packet)?), - PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::read(fixed_header, packet)?), - PacketType::PingReq => Packet::PingReq, - PacketType::PingResp => Packet::PingResp, - PacketType::Disconnect => Packet::Disconnect, - }; + Ok(packet) + } - Ok(packet) + /// Serializes the MQTT packet into a stream of bytes + pub fn write(&self, stream: &mut BytesMut) -> Result { + match self { + Packet::Connect(c) => c.write(stream), + Packet::ConnAck(c) => c.write(stream), + Packet::Publish(p) => p.write(stream), + Packet::PubAck(p) => p.write(stream), + Packet::PubRec(p) => p.write(stream), + Packet::PubRel(p) => p.write(stream), + Packet::PubComp(p) => p.write(stream), + Packet::Subscribe(s) => s.write(stream), + Packet::SubAck(s) => s.write(stream), + Packet::Unsubscribe(u) => u.write(stream), + Packet::UnsubAck(u) => u.write(stream), + Packet::PingReq => PingReq.write(stream), + Packet::PingResp => PingResp.write(stream), + Packet::Disconnect => Disconnect.write(stream), + } + } } /// Return number of remaining length bytes required for encoding length diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index da33bd2f..acee6f1d 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -703,7 +703,7 @@ mod test { let publish = build_incoming_publish(QoS::ExactlyOnce, 1); mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -770,14 +770,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); mqtt.outgoing_publish(publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -790,14 +790,14 @@ mod test { let publish = build_incoming_publish(QoS::ExactlyOnce, 1); mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); match packet { Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index a6ebacc8..ea66448f 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -232,7 +232,7 @@ impl Network { pub async fn readb(&mut self, incoming: &mut VecDeque) -> io::Result<()> { let mut count = 0; loop { - match read(&mut self.read, self.max_incoming_size) { + match Packet::read(&mut self.read, self.max_incoming_size) { Ok(packet) => { incoming.push_back(packet); count += 1; From 8b4f96d4830fed28a52cdce4731dfa4c90aa1a74 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Sat, 16 Mar 2024 21:34:46 +0530 Subject: [PATCH 03/20] refactor: `AsyncReadWrite` (#822) --- rumqttc/src/eventloop.rs | 4 ++-- rumqttc/src/framed.rs | 10 +++++----- rumqttc/src/proxy.rs | 8 ++++---- rumqttc/src/tls.rs | 8 ++++---- rumqttc/src/v5/eventloop.rs | 4 ++-- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index a8aee76c..fe971a6f 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -2,7 +2,7 @@ use crate::{framed::Network, Transport}; use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError}; use crate::{MqttOptions, Outgoing}; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use crate::mqttbytes::v4::*; use flume::{bounded, Receiver, Sender}; use tokio::net::{lookup_host, TcpSocket, TcpStream}; @@ -369,7 +369,7 @@ async fn network_connect( _ => options.broker_address(), }; - let tcp_stream: Box = { + let tcp_stream: Box = { #[cfg(feature = "proxy")] match options.proxy() { Some(proxy) => proxy.connect(&domain, port, network_options).await?, diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 9a5d862f..d2ec7367 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -10,7 +10,7 @@ use std::io; /// appropriate to achieve performance pub struct Network { /// Socket for IO - socket: Box, + socket: Box, /// Buffered reads read: BytesMut, /// Maximum packet size @@ -20,8 +20,8 @@ pub struct Network { } impl Network { - pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network { - let socket = Box::new(socket) as Box; + pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: usize) -> Network { + let socket = Box::new(socket) as Box; Network { socket, read: BytesMut::with_capacity(10 * 1024), @@ -117,5 +117,5 @@ impl Network { } } -pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} -impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {} +impl AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/proxy.rs b/rumqttc/src/proxy.rs index 3dbe741c..e7f84cd3 100644 --- a/rumqttc/src/proxy.rs +++ b/rumqttc/src/proxy.rs @@ -1,5 +1,5 @@ use crate::eventloop::socket_connect; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use crate::NetworkOptions; use std::io; @@ -46,10 +46,10 @@ impl Proxy { broker_addr: &str, broker_port: u16, network_options: NetworkOptions, - ) -> Result, ProxyError> { + ) -> Result, ProxyError> { let proxy_addr = format!("{}:{}", self.addr, self.port); - let tcp: Box = Box::new(socket_connect(proxy_addr, network_options).await?); + let tcp: Box = Box::new(socket_connect(proxy_addr, network_options).await?); let mut tcp = match self.ty { ProxyType::Http => tcp, #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] @@ -67,7 +67,7 @@ impl ProxyAuth { self, host: &str, port: u16, - tcp_stream: &mut Box, + tcp_stream: &mut Box, ) -> Result<(), ProxyError> { match self { Self::None => async_http_proxy::http_connect_tokio(tcp_stream, host, port).await?, diff --git a/rumqttc/src/tls.rs b/rumqttc/src/tls.rs index c8e77571..f80dba64 100644 --- a/rumqttc/src/tls.rs +++ b/rumqttc/src/tls.rs @@ -16,7 +16,7 @@ use std::io::{BufReader, Cursor}; #[cfg(feature = "use-rustls")] use std::sync::Arc; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use crate::TlsConfiguration; #[cfg(feature = "use-native-tls")] @@ -166,9 +166,9 @@ pub async fn tls_connect( addr: &str, _port: u16, tls_config: &TlsConfiguration, - tcp: Box, -) -> Result, Error> { - let tls: Box = match tls_config { + tcp: Box, +) -> Result, Error> { + let tls: Box = match tls_config { #[cfg(feature = "use-rustls")] TlsConfiguration::Simple { .. } | TlsConfiguration::Rustls(_) => { let connector = rustls_connector(tls_config).await?; diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 36c10971..27c26f29 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -2,7 +2,7 @@ use super::framed::Network; use super::mqttbytes::v5::*; use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; use crate::eventloop::socket_connect; -use crate::framed::N; +use crate::framed::AsyncReadWrite; use flume::{bounded, Receiver, Sender}; use tokio::select; @@ -304,7 +304,7 @@ async fn network_connect(options: &MqttOptions) -> Result options.broker_address(), }; - let tcp_stream: Box = { + let tcp_stream: Box = { #[cfg(feature = "proxy")] match options.proxy() { Some(proxy) => { From abf416c14457eeca17f903444f5446069943cc22 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 18 Mar 2024 09:44:54 +0000 Subject: [PATCH 04/20] feat: MQTT `Codec` decoder --- Cargo.lock | 13 +++++++++++++ rumqttc/Cargo.toml | 2 ++ rumqttc/src/framed.rs | 25 ++++++++++++++++--------- rumqttc/src/mqttbytes/mod.rs | 4 +++- rumqttc/src/mqttbytes/v4/codec.rs | 24 ++++++++++++++++++++++++ rumqttc/src/mqttbytes/v4/mod.rs | 2 ++ 6 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 rumqttc/src/mqttbytes/v4/codec.rs diff --git a/Cargo.lock b/Cargo.lock index 2efed53f..e6f8648c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1951,6 +1951,8 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-stream", + "tokio-util", "url", "ws_stream_tungstenite", ] @@ -2543,6 +2545,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 551dcab5..1ace7369 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -25,6 +25,7 @@ proxy = ["dep:async-http-proxy"] [dependencies] futures-util = { version = "0.3", default_features = false, features = ["std"] } tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } +tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" log = "0.4" flume = { version = "0.11", default-features = false, features = ["async"] } @@ -47,6 +48,7 @@ native-tls = { version = "0.2.11", optional = true } url = { version = "2", default-features = false, optional = true } # proxy async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true } +tokio-stream = "0.1.15" [dev-dependencies] bincode = "1.3.3" diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index d2ec7367..b61ba3ee 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,5 +1,6 @@ use bytes::BytesMut; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio_util::codec::Decoder; use crate::mqttbytes::{self, v4::*}; use crate::{Incoming, MqttState, StateError}; @@ -13,8 +14,8 @@ pub struct Network { socket: Box, /// Buffered reads read: BytesMut, - /// Maximum packet size - max_incoming_size: usize, + /// Use to decode MQTT packets + codec: Codec, /// Maximum readv count max_readb_count: usize, } @@ -25,7 +26,7 @@ impl Network { Network { socket, read: BytesMut::with_capacity(10 * 1024), - max_incoming_size, + codec: Codec { max_incoming_size }, max_readb_count: 10, } } @@ -58,8 +59,10 @@ impl Network { pub async fn read(&mut self) -> io::Result { loop { - let required = match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), + let required = match self.codec.decode(&mut self.read) { + Ok(Some(packet)) => return Ok(packet), + // TODO: figure out how not to block + Ok(_) => 2, Err(mqttbytes::Error::InsufficientBytes(required)) => required, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), }; @@ -75,17 +78,19 @@ impl Network { pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { let mut count = 0; loop { - match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { + match self.codec.decode(&mut self.read) { + Ok(Some(packet)) => { state.handle_incoming_packet(packet)?; count += 1; if count >= self.max_readb_count { - return Ok(()); + break; } } // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), + Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => break, + // TODO: figure out how not to block + Ok(_) => break, // Wait for more bytes until a frame can be created Err(mqttbytes::Error::InsufficientBytes(required)) => { self.read_bytes(required).await?; @@ -93,6 +98,8 @@ impl Network { Err(e) => return Err(StateError::Deserialization(e)), }; } + + Ok(()) } pub async fn connect(&mut self, connect: Connect) -> io::Result { diff --git a/rumqttc/src/mqttbytes/mod.rs b/rumqttc/src/mqttbytes/mod.rs index 69858d80..72e61a2c 100644 --- a/rumqttc/src/mqttbytes/mod.rs +++ b/rumqttc/src/mqttbytes/mod.rs @@ -13,7 +13,7 @@ pub mod v4; pub use topic::*; /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Expected Connect, received: {0:?}")] NotConnect(PacketType), @@ -60,6 +60,8 @@ pub enum Error { /// proceed further #[error("At least {0} more bytes required to frame packet")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), } /// MQTT packet type diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs new file mode 100644 index 00000000..fb28477d --- /dev/null +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -0,0 +1,24 @@ +use bytes::{Buf, BytesMut}; +use tokio_util::codec::Decoder; + +use super::{Error, Packet}; + +/// MQTT v4 codec +pub struct Codec { + /// Maximum packet size + pub max_incoming_size: usize, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.remaining() == 0 { + return Ok(None); + } + + let packet = Packet::read(src, self.max_incoming_size)?; + Ok(Some(packet)) + } +} diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index 3c9225e8..4906b2fa 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -1,5 +1,6 @@ use super::*; +mod codec; mod connack; mod connect; mod disconnect; @@ -27,6 +28,7 @@ pub use suback::*; pub use subscribe::*; pub use unsuback::*; pub use unsubscribe::*; +pub use codec::*; /// Encapsulates all MQTT packet types #[derive(Debug, Clone, PartialEq, Eq)] From 0842f730a6876c4ac645f8631668b6edb0b3d107 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 18 Mar 2024 14:05:38 +0000 Subject: [PATCH 05/20] feat: MQTT `Codec` encoder --- rumqttc/src/eventloop.rs | 18 ++++++++-- rumqttc/src/framed.rs | 24 ++++++++----- rumqttc/src/lib.rs | 19 ---------- rumqttc/src/mqttbytes/mod.rs | 2 ++ rumqttc/src/mqttbytes/v4/codec.rs | 17 +++++++-- rumqttc/src/mqttbytes/v4/mod.rs | 9 ++++- rumqttc/src/state.rs | 59 +++++++++++++++---------------- 7 files changed, 83 insertions(+), 65 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index fe971a6f..5ca48613 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -356,7 +356,11 @@ async fn network_connect( if matches!(options.transport(), Transport::Unix) { let file = options.broker_addr.as_str(); let socket = UnixStream::connect(Path::new(file)).await?; - let network = Network::new(socket, options.max_incoming_packet_size); + let network = Network::new( + socket, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ); return Ok(network); } @@ -388,13 +392,21 @@ async fn network_connect( }; let network = match options.transport() { - Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size), + Transport::Tcp => Network::new( + tcp_stream, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ), #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] Transport::Tls(tls_config) => { let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream) .await?; - Network::new(socket, options.max_incoming_packet_size) + Network::new( + socket, + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + ) } #[cfg(unix)] Transport::Unix => unreachable!(), diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index b61ba3ee..ee1e14a8 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,6 +1,6 @@ use bytes::BytesMut; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio_util::codec::Decoder; +use tokio_util::codec::{Decoder, Encoder}; use crate::mqttbytes::{self, v4::*}; use crate::{Incoming, MqttState, StateError}; @@ -21,12 +21,19 @@ pub struct Network { } impl Network { - pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: usize) -> Network { + pub fn new( + socket: impl AsyncReadWrite + 'static, + max_incoming_size: usize, + max_outgoing_size: usize, + ) -> Network { let socket = Box::new(socket) as Box; Network { socket, read: BytesMut::with_capacity(10 * 1024), - codec: Codec { max_incoming_size }, + codec: Codec { + max_incoming_size, + max_outgoing_size, + }, max_readb_count: 10, } } @@ -102,15 +109,14 @@ impl Network { Ok(()) } - pub async fn connect(&mut self, connect: Connect) -> io::Result { + pub async fn connect(&mut self, connect: Connect) -> io::Result<()> { let mut write = BytesMut::new(); - let len = match connect.write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; + self.codec + .encode(Packet::Connect(connect), &mut write) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?; self.socket.write_all(&write[..]).await?; - Ok(len) + Ok(()) } pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 43dbb3be..9c30d46a 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -200,25 +200,6 @@ pub enum Request { Disconnect(Disconnect), } -impl Request { - fn size(&self) -> usize { - match &self { - Request::Publish(publish) => publish.size(), - Request::PubAck(puback) => puback.size(), - Request::PubRec(pubrec) => pubrec.size(), - Request::PubComp(pubcomp) => pubcomp.size(), - Request::PubRel(pubrel) => pubrel.size(), - Request::PingReq(pingreq) => pingreq.size(), - Request::PingResp(pingresp) => pingresp.size(), - Request::Subscribe(subscribe) => subscribe.size(), - Request::SubAck(suback) => suback.size(), - Request::Unsubscribe(unsubscribe) => unsubscribe.size(), - Request::UnsubAck(unsuback) => unsuback.size(), - Request::Disconnect(disconn) => disconn.size(), - } - } -} - impl From for Request { fn from(publish: Publish) -> Request { Request::Publish(publish) diff --git a/rumqttc/src/mqttbytes/mod.rs b/rumqttc/src/mqttbytes/mod.rs index 72e61a2c..3345b897 100644 --- a/rumqttc/src/mqttbytes/mod.rs +++ b/rumqttc/src/mqttbytes/mod.rs @@ -62,6 +62,8 @@ pub enum Error { InsufficientBytes(usize), #[error("IO: {0}")] Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: usize, max: usize }, } /// MQTT packet type diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs index fb28477d..b588db6b 100644 --- a/rumqttc/src/mqttbytes/v4/codec.rs +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -1,12 +1,15 @@ use bytes::{Buf, BytesMut}; -use tokio_util::codec::Decoder; +use tokio_util::codec::{Decoder, Encoder}; use super::{Error, Packet}; /// MQTT v4 codec +#[derive(Debug, Clone)] pub struct Codec { - /// Maximum packet size + /// Maximum packet size allowed by client pub max_incoming_size: usize, + /// Maximum packet size allowed by broker + pub max_outgoing_size: usize, } impl Decoder for Codec { @@ -22,3 +25,13 @@ impl Decoder for Codec { Ok(Some(packet)) } } + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index 4906b2fa..ed438dd0 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -111,7 +111,14 @@ impl Packet { } /// Serializes the MQTT packet into a stream of bytes - pub fn write(&self, stream: &mut BytesMut) -> Result { + pub fn write(&self, stream: &mut BytesMut, max_size: usize) -> Result { + if self.size() > max_size { + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size(), + max: max_size, + }) + } + match self { Packet::Connect(c) => c.write(stream), Packet::ConnAck(c) => c.write(stream), diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index acee6f1d..bf1792d4 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -5,6 +5,7 @@ use crate::mqttbytes::{self, *}; use bytes::BytesMut; use std::collections::VecDeque; use std::{io, time::Instant}; +use tokio_util::codec::Encoder; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -30,8 +31,6 @@ pub enum StateError { EmptySubscription, #[error("Mqtt serialization/deserialization error: {0}")] Deserialization(#[from] mqttbytes::Error), - #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] - OutgoingPacketTooLarge { pkt_size: usize, max: usize }, } /// State of the mqtt connection. @@ -74,8 +73,8 @@ pub struct MqttState { pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - /// Maximum outgoing packet size, set via MqttOptions - pub max_outgoing_packet_size: usize, + /// Used to encode packets + pub codec: Codec, } impl MqttState { @@ -101,7 +100,11 @@ impl MqttState { events: VecDeque::with_capacity(100), write: BytesMut::with_capacity(10 * 1024), manual_acks, - max_outgoing_packet_size, + codec: Codec { + max_outgoing_size: max_outgoing_packet_size, + // The following is ignored for encoding + max_incoming_size: max_outgoing_packet_size, + }, } } @@ -146,8 +149,6 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - // Enforce max outgoing packet size - self.check_size(request.size())?; match request { Request::Publish(publish) => self.outgoing_publish(publish)?, Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, @@ -247,8 +248,9 @@ impl MqttState { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.codec + .encode(Packet::Publish(publish), &mut self.write)?; self.events.push_back(event); self.collision_ping_count = 0; } @@ -265,7 +267,8 @@ impl MqttState { Some(_) => { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - PubRel::new(pubrec.pkid).write(&mut self.write)?; + let pubrel = PubRel { pkid: pubrec.pkid }; + self.codec.encode(Packet::PubRel(pubrel), &mut self.write)?; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -285,8 +288,10 @@ impl MqttState { .ok_or(StateError::Unsolicited(pubrel.pkid))?; match publish.take() { Some(_) => { - PubComp::new(pubrel.pkid).write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); + let pubcomp = PubComp { pkid: pubrel.pkid }; + self.codec + .encode(Packet::PubComp(pubcomp), &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -299,8 +304,9 @@ impl MqttState { fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.codec + .encode(Packet::Publish(publish), &mut self.write)?; self.events.push_back(event); self.collision_ping_count = 0; } @@ -361,8 +367,9 @@ impl MqttState { publish.payload.len() ); - publish.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.codec + .encode(Packet::Publish(publish), &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -371,23 +378,22 @@ impl MqttState { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - PubRel::new(pubrel.pkid).write(&mut self.write)?; - let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); + self.codec.encode(Packet::PubRel(pubrel), &mut self.write)?; self.events.push_back(event); Ok(()) } fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { - puback.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); + self.codec.encode(Packet::PubAck(puback), &mut self.write)?; self.events.push_back(event); Ok(()) } fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { - pubrec.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); + self.codec.encode(Packet::PubRec(pubrec), &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -421,8 +427,8 @@ impl MqttState { elapsed_out.as_millis() ); - PingReq.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::PingReq); + self.codec.encode(Packet::PingReq, &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -440,8 +446,9 @@ impl MqttState { subscription.filters, subscription.pkid ); - subscription.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); + self.codec + .encode(Packet::Subscribe(subscription), &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -455,8 +462,9 @@ impl MqttState { unsub.topics, unsub.pkid ); - unsub.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); + self.codec + .encode(Packet::Unsubscribe(unsub), &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -464,8 +472,8 @@ impl MqttState { fn outgoing_disconnect(&mut self) -> Result<(), StateError> { debug!("Disconnect"); - Disconnect.write(&mut self.write)?; let event = Event::Outgoing(Outgoing::Disconnect); + self.codec.encode(Packet::Disconnect, &mut self.write)?; self.events.push_back(event); Ok(()) } @@ -480,17 +488,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - if pkt_size > self.max_outgoing_packet_size { - Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: self.max_outgoing_packet_size, - }) - } else { - Ok(()) - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets From 036e1f2fe5456b12a5a3ba677ba943117c8d07d6 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 18 Mar 2024 15:52:52 +0000 Subject: [PATCH 06/20] refactor: move write buffer into `Network` --- rumqttc/src/eventloop.rs | 12 ++-- rumqttc/src/framed.rs | 17 +++-- rumqttc/src/state.rs | 137 ++++++++++++++++++++++++--------------- 3 files changed, 101 insertions(+), 65 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 5ca48613..08b989c2 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -81,7 +81,7 @@ pub struct EventLoop { /// Pending packets from last session pub pending: VecDeque, /// Network connection to the broker - network: Option, + pub network: Option, /// Keep alive time keepalive_timeout: Option>>, pub network_options: NetworkOptions, @@ -189,7 +189,7 @@ impl EventLoop { o = network.readb(&mut self.state) => { o?; // flush all the acks and return first incoming packet - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { + match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), }; @@ -229,8 +229,8 @@ impl EventLoop { self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request)?; - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { + self.state.handle_outgoing_packet(request, network)?; + match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), }; @@ -245,8 +245,8 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq(PingReq))?; - match time::timeout(network_timeout, network.flush(&mut self.state.write)).await { + self.state.handle_outgoing_packet(Request::PingReq(PingReq), network)?; + match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), }; diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index ee1e14a8..0b172262 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -14,6 +14,8 @@ pub struct Network { socket: Box, /// Buffered reads read: BytesMut, + /// Buffered writes + pub write: BytesMut, /// Use to decode MQTT packets codec: Codec, /// Maximum readv count @@ -30,6 +32,7 @@ impl Network { Network { socket, read: BytesMut::with_capacity(10 * 1024), + write: BytesMut::with_capacity(10 * 1024), codec: Codec { max_incoming_size, max_outgoing_size, @@ -87,7 +90,7 @@ impl Network { loop { match self.codec.decode(&mut self.read) { Ok(Some(packet)) => { - state.handle_incoming_packet(packet)?; + state.handle_incoming_packet(packet, self)?; count += 1; if count >= self.max_readb_count { @@ -119,13 +122,17 @@ impl Network { Ok(()) } - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { + pub fn write(&mut self, packet: Packet) -> Result<(), crate::mqttbytes::Error> { + self.codec.encode(packet, &mut self.write) + } + + pub async fn flush(&mut self) -> io::Result<()> { + if self.write.is_empty() { return Ok(()); } - self.socket.write_all(&write[..]).await?; - write.clear(); + self.socket.write_all(&self.write[..]).await?; + self.write.clear(); Ok(()) } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index bf1792d4..168764a5 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,11 +1,10 @@ +use crate::framed::Network; use crate::{Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; -use bytes::BytesMut; use std::collections::VecDeque; use std::{io, time::Instant}; -use tokio_util::codec::Encoder; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -69,8 +68,6 @@ pub struct MqttState { pub collision: Option, /// Buffered incoming packets pub events: VecDeque, - /// Write buffer - pub write: BytesMut, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, /// Used to encode packets @@ -98,7 +95,6 @@ impl MqttState { collision: None, // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), - write: BytesMut::with_capacity(10 * 1024), manual_acks, codec: Codec { max_outgoing_size: max_outgoing_packet_size, @@ -138,7 +134,6 @@ impl MqttState { self.await_pingresp = false; self.collision_ping_count = 0; self.inflight = 0; - self.write.clear(); pending } @@ -148,16 +143,20 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { + pub fn handle_outgoing_packet( + &mut self, + request: Request, + network: &mut Network, + ) -> Result<(), StateError> { match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, - Request::PingReq(_) => self.outgoing_ping()?, - Request::Disconnect(_) => self.outgoing_disconnect()?, - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, + Request::Publish(publish) => self.outgoing_publish(publish, network)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel, network)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, network)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, network)?, + Request::PingReq(_) => self.outgoing_ping(network)?, + Request::Disconnect(_) => self.outgoing_disconnect(network)?, + Request::PubAck(puback) => self.outgoing_puback(puback, network)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec, network)?, _ => unimplemented!(), }; @@ -169,16 +168,20 @@ impl MqttState { /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, packet: Incoming) -> Result<(), StateError> { + pub fn handle_incoming_packet( + &mut self, + packet: Incoming, + network: &mut Network, + ) -> Result<(), StateError> { let out = match &packet { Incoming::PingResp => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), + Incoming::Publish(publish) => self.handle_incoming_publish(publish, network), Incoming::SubAck(_suback) => self.handle_incoming_suback(), Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), + Incoming::PubAck(puback) => self.handle_incoming_puback(puback, network), + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec, network), + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel, network), + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp, network), _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); @@ -201,7 +204,11 @@ impl MqttState { /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &Publish) -> Result<(), StateError> { + fn handle_incoming_publish( + &mut self, + publish: &Publish, + network: &mut Network, + ) -> Result<(), StateError> { let qos = publish.qos; match qos { @@ -209,7 +216,7 @@ impl MqttState { QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); - self.outgoing_puback(puback)?; + self.outgoing_puback(puback, network)?; } Ok(()) } @@ -219,14 +226,18 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid); - self.outgoing_pubrec(pubrec)?; + self.outgoing_pubrec(pubrec, network)?; } Ok(()) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback( + &mut self, + puback: &PubAck, + network: &mut Network, + ) -> Result<(), StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -249,8 +260,7 @@ impl MqttState { self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.codec - .encode(Packet::Publish(publish), &mut self.write)?; + network.write(Packet::Publish(publish))?; self.events.push_back(event); self.collision_ping_count = 0; } @@ -258,7 +268,11 @@ impl MqttState { v } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec( + &mut self, + pubrec: &PubRec, + network: &mut Network, + ) -> Result<(), StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -268,7 +282,7 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); let pubrel = PubRel { pkid: pubrec.pkid }; - self.codec.encode(Packet::PubRel(pubrel), &mut self.write)?; + network.write(Packet::PubRel(pubrel))?; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); @@ -281,7 +295,11 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel( + &mut self, + pubrel: &PubRel, + network: &mut Network, + ) -> Result<(), StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) @@ -290,8 +308,7 @@ impl MqttState { Some(_) => { let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); let pubcomp = PubComp { pkid: pubrel.pkid }; - self.codec - .encode(Packet::PubComp(pubcomp), &mut self.write)?; + network.write(Packet::PubComp(pubcomp))?; self.events.push_back(event); Ok(()) } @@ -302,11 +319,14 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + fn handle_incoming_pubcomp( + &mut self, + pubcomp: &PubComp, + network: &mut Network, + ) -> Result<(), StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.codec - .encode(Packet::Publish(publish), &mut self.write)?; + network.write(Packet::Publish(publish))?; self.events.push_back(event); self.collision_ping_count = 0; } @@ -334,7 +354,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + network: &mut Network, + ) -> Result<(), StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -368,32 +392,31 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - self.codec - .encode(Packet::Publish(publish), &mut self.write)?; + network.write(Packet::Publish(publish))?; self.events.push_back(event); Ok(()) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel, network: &mut Network) -> Result<(), StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); - self.codec.encode(Packet::PubRel(pubrel), &mut self.write)?; + network.write(Packet::PubRel(pubrel))?; self.events.push_back(event); Ok(()) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + fn outgoing_puback(&mut self, puback: PubAck, network: &mut Network) -> Result<(), StateError> { let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); - self.codec.encode(Packet::PubAck(puback), &mut self.write)?; + network.write(Packet::PubAck(puback))?; self.events.push_back(event); Ok(()) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + fn outgoing_pubrec(&mut self, pubrec: PubRec, network: &mut Network) -> Result<(), StateError> { let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); - self.codec.encode(Packet::PubRec(pubrec), &mut self.write)?; + network.write(Packet::PubRec(pubrec))?; self.events.push_back(event); Ok(()) } @@ -401,7 +424,7 @@ impl MqttState { /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self, network: &mut Network) -> Result<(), StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -428,12 +451,16 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::PingReq); - self.codec.encode(Packet::PingReq, &mut self.write)?; + network.write(Packet::PingReq)?; self.events.push_back(event); Ok(()) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + network: &mut Network, + ) -> Result<(), StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -447,13 +474,16 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); - self.codec - .encode(Packet::Subscribe(subscription), &mut self.write)?; + network.write(Packet::Subscribe(subscription))?; self.events.push_back(event); Ok(()) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + network: &mut Network, + ) -> Result<(), StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -463,17 +493,16 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); - self.codec - .encode(Packet::Unsubscribe(unsub), &mut self.write)?; + network.write(Packet::Unsubscribe(unsub))?; self.events.push_back(event); Ok(()) } - fn outgoing_disconnect(&mut self) -> Result<(), StateError> { + fn outgoing_disconnect(&mut self, network: &mut Network) -> Result<(), StateError> { debug!("Disconnect"); let event = Event::Outgoing(Outgoing::Disconnect); - self.codec.encode(Packet::Disconnect, &mut self.write)?; + network.write(Packet::Disconnect)?; self.events.push_back(event); Ok(()) } From 8a220c33fc728b3c32ca723fd75ac00b977469d7 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Mon, 18 Mar 2024 17:12:18 +0000 Subject: [PATCH 07/20] refactor: testing improvements --- rumqttc/src/eventloop.rs | 11 +- rumqttc/src/framed.rs | 19 +- rumqttc/src/mqttbytes/v4/codec.rs | 37 +++- rumqttc/src/state.rs | 285 +++++++++++------------------- rumqttc/tests/reliability.rs | 2 +- 5 files changed, 159 insertions(+), 195 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 08b989c2..6b3af08a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -104,11 +104,10 @@ impl EventLoop { let pending = VecDeque::new(); let max_inflight = mqtt_options.inflight; let manual_acks = mqtt_options.manual_acks; - let max_outgoing_packet_size = mqtt_options.max_outgoing_packet_size; EventLoop { mqtt_options, - state: MqttState::new(max_inflight, manual_acks, max_outgoing_packet_size), + state: MqttState::new(max_inflight, manual_acks), requests_tx, requests_rx, pending, @@ -229,7 +228,9 @@ impl EventLoop { self.mqtt_options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request, network)?; + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + network.write(outgoing)?; + } match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), @@ -245,7 +246,9 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq(PingReq), network)?; + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + network.write(outgoing)?; + } match time::timeout(network_timeout, network.flush()).await { Ok(inner) => inner?, Err(_)=> return Err(ConnectionError::FlushTimeout), diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 0b172262..4d45cd11 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -90,7 +90,9 @@ impl Network { loop { match self.codec.decode(&mut self.read) { Ok(Some(packet)) => { - state.handle_incoming_packet(packet, self)?; + if let Some(packet) = state.handle_incoming_packet(packet)? { + self.write(packet)?; + } count += 1; if count >= self.max_readb_count { @@ -113,17 +115,14 @@ impl Network { } pub async fn connect(&mut self, connect: Connect) -> io::Result<()> { - let mut write = BytesMut::new(); - self.codec - .encode(Packet::Connect(connect), &mut write) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?; - - self.socket.write_all(&write[..]).await?; - Ok(()) + self.write(Packet::Connect(connect)) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string())) } - pub fn write(&mut self, packet: Packet) -> Result<(), crate::mqttbytes::Error> { - self.codec.encode(packet, &mut self.write) + pub fn write(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { + self.codec + .encode(packet, &mut self.write) + .map_err(Into::into) } pub async fn flush(&mut self) -> io::Result<()> { diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs index b588db6b..3e7c73d5 100644 --- a/rumqttc/src/mqttbytes/v4/codec.rs +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -28,10 +28,43 @@ impl Decoder for Codec { impl Encoder for Codec { type Error = Error; - + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { item.write(dst, self.max_outgoing_size)?; Ok(()) - } + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::{mqttbytes::Error, Packet, Publish, QoS}; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: 100, + max_outgoing_size: 200, + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100]); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265]); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 281, + max: 200, + }) => {} + _ => unreachable!(), + } + } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index 168764a5..f6ffc5de 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,4 +1,3 @@ -use crate::framed::Network; use crate::{Event, Incoming, Outgoing, Request}; use crate::mqttbytes::v4::*; @@ -70,15 +69,13 @@ pub struct MqttState { pub events: VecDeque, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, - /// Used to encode packets - pub codec: Codec, } impl MqttState { /// Creates new mqtt state. Same state should be used during a /// connection for persistent sessions while new state should /// instantiated for clean sessions - pub fn new(max_inflight: u16, manual_acks: bool, max_outgoing_packet_size: usize) -> Self { + pub fn new(max_inflight: u16, manual_acks: bool) -> Self { MqttState { await_pingresp: false, collision_ping_count: 0, @@ -96,11 +93,6 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), manual_acks, - codec: Codec { - max_outgoing_size: max_outgoing_packet_size, - // The following is ignored for encoding - max_incoming_size: max_outgoing_packet_size, - }, } } @@ -146,22 +138,21 @@ impl MqttState { pub fn handle_outgoing_packet( &mut self, request: Request, - network: &mut Network, - ) -> Result<(), StateError> { - match request { - Request::Publish(publish) => self.outgoing_publish(publish, network)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel, network)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe, network)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe, network)?, - Request::PingReq(_) => self.outgoing_ping(network)?, - Request::Disconnect(_) => self.outgoing_disconnect(network)?, - Request::PubAck(puback) => self.outgoing_puback(puback, network)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec, network)?, + ) -> Result, StateError> { + let packet = match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::PingReq(_) => self.outgoing_ping()?, + Request::Disconnect(_) => self.outgoing_disconnect()?, + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, _ => unimplemented!(), }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the @@ -171,54 +162,48 @@ impl MqttState { pub fn handle_incoming_packet( &mut self, packet: Incoming, - network: &mut Network, - ) -> Result<(), StateError> { - let out = match &packet { - Incoming::PingResp => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish, network), - Incoming::SubAck(_suback) => self.handle_incoming_suback(), - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback(), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback, network), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec, network), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel, network), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp, network), + ) -> Result, StateError> { + let outgoing = match &packet { + Incoming::PingResp => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(_suback) => self.handle_incoming_suback()?, + Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + + Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result<(), StateError> { - Ok(()) + fn handle_incoming_suback(&mut self) -> Result, StateError> { + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result<(), StateError> { - Ok(()) + fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + Ok(None) } /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish( - &mut self, - publish: &Publish, - network: &mut Network, - ) -> Result<(), StateError> { + fn handle_incoming_publish(&mut self, publish: &Publish) -> Result, StateError> { let qos = publish.qos; match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid); - self.outgoing_puback(puback, network)?; + return self.outgoing_puback(puback); } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -226,53 +211,41 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid); - self.outgoing_pubrec(pubrec, network)?; + return self.outgoing_pubrec(pubrec); } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback( - &mut self, - puback: &PubAck, - network: &mut Network, - ) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) .ok_or(StateError::Unsolicited(puback.pkid))?; self.last_puback = puback.pkid; - let v = match publish.take() { - Some(_) => { - self.inflight -= 1; - Ok(()) - } - None => { - error!("Unsolicited puback packet: {:?}", puback.pkid); - Err(StateError::Unsolicited(puback.pkid)) - } - }; + publish.take().ok_or({ + error!("Unsolicited puback packet: {:?}", puback.pkid); + StateError::Unsolicited(puback.pkid) + })?; - if let Some(publish) = self.check_collision(puback.pkid) { + self.inflight -= 1; + let packet = self.check_collision(puback.pkid).map(|publish| { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - network.write(Packet::Publish(publish))?; self.events.push_back(event); self.collision_ping_count = 0; - } - v + Packet::Publish(publish) + }); + + Ok(packet) } - fn handle_incoming_pubrec( - &mut self, - pubrec: &PubRec, - network: &mut Network, - ) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -282,11 +255,10 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); let pubrel = PubRel { pkid: pubrec.pkid }; - network.write(Packet::PubRel(pubrel))?; - let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(pubrel))) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -295,11 +267,7 @@ impl MqttState { } } - fn handle_incoming_pubrel( - &mut self, - pubrel: &PubRel, - network: &mut Network, - ) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) @@ -308,9 +276,9 @@ impl MqttState { Some(_) => { let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); let pubcomp = PubComp { pkid: pubrel.pkid }; - network.write(Packet::PubComp(pubcomp))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubComp(pubcomp))) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -319,46 +287,37 @@ impl MqttState { } } - fn handle_incoming_pubcomp( - &mut self, - pubcomp: &PubComp, - network: &mut Network, - ) -> Result<(), StateError> { - if let Some(publish) = self.check_collision(pubcomp.pkid) { + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { + self.outgoing_rel + .get_mut(pubcomp.pkid as usize) + .ok_or(StateError::Unsolicited(pubcomp.pkid))? + .take() + .ok_or({ + error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); + StateError::Unsolicited(pubcomp.pkid) + })?; + + self.inflight -= 1; + let packet = self.check_collision(pubcomp.pkid).map(|publish| { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - network.write(Packet::Publish(publish))?; self.events.push_back(event); self.collision_ping_count = 0; - } - let pubrel = self - .outgoing_rel - .get_mut(pubcomp.pkid as usize) - .ok_or(StateError::Unsolicited(pubcomp.pkid))?; - match pubrel.take() { - Some(_) => { - self.inflight -= 1; - Ok(()) - } - None => { - error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); - Err(StateError::Unsolicited(pubcomp.pkid)) - } - } + Packet::Publish(publish) + }); + + Ok(packet) } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish( - &mut self, - mut publish: Publish, - network: &mut Network, - ) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -375,7 +334,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -392,39 +351,39 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); - network.write(Packet::Publish(publish))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel, network: &mut Network) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); - network.write(Packet::PubRel(pubrel))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRel(pubrel))) } - fn outgoing_puback(&mut self, puback: PubAck, network: &mut Network) -> Result<(), StateError> { + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubAck(puback.pkid)); - network.write(Packet::PubAck(puback))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec, network: &mut Network) -> Result<(), StateError> { + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); - network.write(Packet::PubRec(pubrec))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PubRec(pubrec))) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self, network: &mut Network) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -451,16 +410,15 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::PingReq); - network.write(Packet::PingReq)?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::PingReq)) } fn outgoing_subscribe( &mut self, mut subscription: Subscribe, - network: &mut Network, - ) -> Result<(), StateError> { + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -474,16 +432,15 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); - network.write(Packet::Subscribe(subscription))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Subscribe(subscription))) } fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, - network: &mut Network, - ) -> Result<(), StateError> { + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -493,18 +450,18 @@ impl MqttState { ); let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); - network.write(Packet::Unsubscribe(unsub))?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Unsubscribe(unsub))) } - fn outgoing_disconnect(&mut self, network: &mut Network) -> Result<(), StateError> { + fn outgoing_disconnect(&mut self) -> Result, StateError> { debug!("Disconnect"); let event = Event::Outgoing(Outgoing::Disconnect); - network.write(Packet::Disconnect)?; self.events.push_back(event); - Ok(()) + + Ok(Some(Packet::Disconnect)) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -558,7 +515,6 @@ mod test { use crate::mqttbytes::v4::*; use crate::mqttbytes::*; use crate::{Event, Incoming, Outgoing, Request}; - use bytes::BufMut; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); @@ -580,7 +536,7 @@ mod test { } fn build_mqttstate() -> MqttState { - MqttState::new(100, false, usize::MAX) + MqttState::new(100, false) } #[test] @@ -600,25 +556,6 @@ mod test { } } - #[test] - fn outgoing_max_packet_size_check() { - let mut mqtt = MqttState::new(100, false, 200); - - let small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100]); - assert_eq!( - mqtt.handle_outgoing_packet(Request::Publish(small_publish)) - .is_ok(), - true - ); - - let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265]); - assert_eq!( - mqtt.handle_outgoing_packet(Request::Publish(large_publish)) - .is_ok(), - false - ); - } - #[test] fn outgoing_publish_should_set_pkid_and_add_publish_to_queue() { let mut mqtt = build_mqttstate(); @@ -728,8 +665,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -795,15 +731,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt + .handle_incoming_pubrec(&PubRec::new(1)) + .unwrap() + .unwrap(); match packet { Packet::PubRel(pubrel) => assert_eq!(pubrel.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -815,15 +752,16 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - mqtt.handle_incoming_publish(&publish).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } - mqtt.handle_incoming_pubrel(&PubRel::new(1)).unwrap(); - let packet = Packet::read(&mut mqtt.write, 10 * 1024).unwrap(); + let packet = mqtt + .handle_incoming_pubrel(&PubRel::new(1)) + .unwrap() + .unwrap(); match packet { Packet::PubComp(pubcomp) => assert_eq!(pubcomp.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), @@ -874,15 +812,6 @@ mod test { mqtt.outgoing_ping().unwrap(); } - #[test] - fn state_should_be_clean_properly() { - let mut mqtt = build_mqttstate(); - mqtt.write.put(&b"test"[..]); - // After this clean state.write should be empty - mqtt.clean(); - assert!(mqtt.write.is_empty()); - } - #[test] fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 3e7acd1e..7d96ae44 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -570,7 +570,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly if let Err(e) = res { match e { ConnectionError::FlushTimeout => { - assert!(eventloop.state.write.is_empty()); + assert!(eventloop.network.as_ref().unwrap().write.is_empty()); println!("State is being clean properly"); } _ => { From a5b2717d99de64758e2860849bc4116fa0fba5bd Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 18:03:47 +0530 Subject: [PATCH 08/20] fix: `readb` should block for 1 packet (#824) --- rumqttc/src/framed.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 4d45cd11..4be455d5 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -100,9 +100,11 @@ impl Network { } } // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => break, - // TODO: figure out how not to block - Ok(_) => break, + Err(mqttbytes::Error::InsufficientBytes(_)) | Ok(_) if count > 0 => break, + // NOTE: read atleast 1 packet + Ok(_) => { + self.read_bytes(2).await?; + } // Wait for more bytes until a frame can be created Err(mqttbytes::Error::InsufficientBytes(required)) => { self.read_bytes(required).await?; From 303d8c22c3519a97d3dbe336d5663d210eca7765 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 14:01:48 +0000 Subject: [PATCH 09/20] fix: `Network::connect` should flush --- rumqttc/src/framed.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 4be455d5..598be8ea 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -118,7 +118,9 @@ impl Network { pub async fn connect(&mut self, connect: Connect) -> io::Result<()> { self.write(Packet::Connect(connect)) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string())) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?; + + self.flush().await } pub fn write(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { From 7fee5c17895b4b7e0282c0f2ed69d286ec362cf2 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 14:05:31 +0000 Subject: [PATCH 10/20] test: fix network expectations for `EventLoop::clean` --- rumqttc/tests/reliability.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 7d96ae44..0a83d57c 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -570,7 +570,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly if let Err(e) = res { match e { ConnectionError::FlushTimeout => { - assert!(eventloop.network.as_ref().unwrap().write.is_empty()); + assert!(eventloop.network.is_none()); println!("State is being clean properly"); } _ => { From 75efeaaae2733bb89f4283ae0204e15c121601a1 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 14:46:17 +0000 Subject: [PATCH 11/20] refactor: use `Framed` to encode/decode --- rumqttc/Cargo.toml | 2 +- rumqttc/src/eventloop.rs | 41 ++++++------ rumqttc/src/framed.rs | 123 ++++++++++------------------------- rumqttc/src/state.rs | 2 + rumqttc/tests/reliability.rs | 2 +- 5 files changed, 57 insertions(+), 113 deletions(-) diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index 1ace7369..bba64822 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -23,7 +23,7 @@ websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"] proxy = ["dep:async-http-proxy"] [dependencies] -futures-util = { version = "0.3", default_features = false, features = ["std"] } +futures-util = { version = "0.3", default_features = false, features = ["std", "sink"] } tokio = { version = "1.36", features = ["rt", "macros", "io-util", "net", "time"] } tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.5" diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 6b3af08a..796acaa2 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -38,8 +38,6 @@ pub enum ConnectionError { MqttState(#[from] StateError), #[error("Network timeout")] NetworkTimeout, - #[error("Flush timeout")] - FlushTimeout, #[cfg(feature = "websocket")] #[error("Websocket: {0}")] Websocket(#[from] async_tungstenite::tungstenite::error::Error), @@ -173,7 +171,6 @@ impl EventLoop { // let await_acks = self.state.await_acks; let inflight_full = self.state.inflight >= self.mqtt_options.inflight; let collision = self.state.collision.is_some(); - let network_timeout = Duration::from_secs(self.network_options.connection_timeout()); // Read buffered events from previous polls before calling a new poll if let Some(event) = self.state.events.pop_front() { @@ -187,11 +184,6 @@ impl EventLoop { // Pull a bunch of packets from network, reply in bunch and yield the first item o = network.readb(&mut self.state) => { o?; - // flush all the acks and return first incoming packet - match time::timeout(network_timeout, network.flush()).await { - Ok(inner) => inner?, - Err(_)=> return Err(ConnectionError::FlushTimeout), - }; Ok(self.state.events.pop_front().unwrap()) }, // Handles pending and new requests. @@ -229,12 +221,9 @@ impl EventLoop { ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { - network.write(outgoing)?; + network.send(outgoing).await?; } - match time::timeout(network_timeout, network.flush()).await { - Ok(inner) => inner?, - Err(_)=> return Err(ConnectionError::FlushTimeout), - }; + Ok(self.state.events.pop_front().unwrap()) } Err(_) => Err(ConnectionError::RequestsDone), @@ -247,12 +236,8 @@ impl EventLoop { timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { - network.write(outgoing)?; + network.send(outgoing).await?; } - match time::timeout(network_timeout, network.flush()).await { - Ok(inner) => inner?, - Err(_)=> return Err(ConnectionError::FlushTimeout), - }; Ok(self.state.events.pop_front().unwrap()) } } @@ -354,6 +339,7 @@ async fn network_connect( options: &MqttOptions, network_options: NetworkOptions, ) -> Result { + let network_timeout = Duration::from_secs(network_options.connection_timeout()); // Process Unix files early, as proxy is not supported for them. #[cfg(unix)] if matches!(options.transport(), Transport::Unix) { @@ -363,6 +349,7 @@ async fn network_connect( socket, options.max_incoming_packet_size, options.max_outgoing_packet_size, + network_timeout, ); return Ok(network); } @@ -399,6 +386,7 @@ async fn network_connect( tcp_stream, options.max_incoming_packet_size, options.max_outgoing_packet_size, + network_timeout, ), #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] Transport::Tls(tls_config) => { @@ -409,6 +397,7 @@ async fn network_connect( socket, options.max_incoming_packet_size, options.max_outgoing_packet_size, + network_timeout, ) } #[cfg(unix)] @@ -428,7 +417,12 @@ async fn network_connect( async_tungstenite::tokio::client_async(request, tcp_stream).await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + ) } #[cfg(all(feature = "use-rustls", feature = "websocket"))] Transport::Wss(tls_config) => { @@ -451,7 +445,12 @@ async fn network_connect( .await?; validate_response_headers(response)?; - Network::new(WsStream::new(socket), options.max_incoming_packet_size) + Network::new( + WsStream::new(socket), + options.max_incoming_packet_size, + options.max_outgoing_packet_size, + network_timeout, + ) } }; @@ -477,7 +476,7 @@ async fn mqtt_connect( } // send mqtt connect packet - network.connect(connect).await?; + network.send(Packet::Connect(connect)).await?; // validate connack match network.read().await? { diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 598be8ea..1b8aef98 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -1,25 +1,22 @@ -use bytes::BytesMut; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio_util::codec::{Decoder, Encoder}; +use futures_util::{SinkExt, StreamExt}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::timeout; +use tokio_util::codec::Framed; use crate::mqttbytes::{self, v4::*}; use crate::{Incoming, MqttState, StateError}; -use std::io; +use std::time::Duration; /// Network transforms packets <-> frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Buffered writes - pub write: BytesMut, - /// Use to decode MQTT packets - codec: Codec, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, /// Maximum readv count max_readb_count: usize, + /// Time within which network operations should complete + timeout: Duration, } impl Network { @@ -27,59 +24,27 @@ impl Network { socket: impl AsyncReadWrite + 'static, max_incoming_size: usize, max_outgoing_size: usize, + timeout: Duration, ) -> Network { let socket = Box::new(socket) as Box; + let codec = Codec { + max_incoming_size, + max_outgoing_size, + }; + let framed = Framed::new(socket, codec); + Network { - socket, - read: BytesMut::with_capacity(10 * 1024), - write: BytesMut::with_capacity(10 * 1024), - codec: Codec { - max_incoming_size, - max_outgoing_size, - }, + framed, max_readb_count: 10, + timeout, } } - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let required = match self.codec.decode(&mut self.read) { - Ok(Some(packet)) => return Ok(packet), - // TODO: figure out how not to block - Ok(_) => 2, - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), } } @@ -88,10 +53,10 @@ impl Network { pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { let mut count = 0; loop { - match self.codec.decode(&mut self.read) { - Ok(Some(packet)) => { + match self.framed.next().await { + Some(Ok(packet)) => { if let Some(packet) = state.handle_incoming_packet(packet)? { - self.write(packet)?; + self.send(packet).await?; } count += 1; @@ -100,43 +65,21 @@ impl Network { } } // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) | Ok(_) if count > 0 => break, + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None if count > 0 => break, // NOTE: read atleast 1 packet - Ok(_) => { - self.read_bytes(2).await?; - } - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(e) => return Err(StateError::Deserialization(e)), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => return Err(StateError::Deserialization(e)), }; } Ok(()) } - pub async fn connect(&mut self, connect: Connect) -> io::Result<()> { - self.write(Packet::Connect(connect)) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?; - - self.flush().await - } - - pub fn write(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { - self.codec - .encode(packet, &mut self.write) - .map_err(Into::into) - } - - pub async fn flush(&mut self) -> io::Result<()> { - if self.write.is_empty() { - return Ok(()); + pub async fn send(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { + match timeout(self.timeout, self.framed.send(packet)).await { + Ok(inner) => inner.map_err(Into::into), + Err(_) => Err(StateError::FlushTimeout), } - - self.socket.write_all(&self.write[..]).await?; - self.write.clear(); - Ok(()) } } diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index f6ffc5de..c1014022 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -29,6 +29,8 @@ pub enum StateError { EmptySubscription, #[error("Mqtt serialization/deserialization error: {0}")] Deserialization(#[from] mqttbytes::Error), + #[error("Flush timeout")] + FlushTimeout, } /// State of the mqtt connection. diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 0a83d57c..49ce30d6 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -569,7 +569,7 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly let res = run(&mut eventloop, false).await; if let Err(e) = res { match e { - ConnectionError::FlushTimeout => { + ConnectionError::MqttState(StateError::FlushTimeout) => { assert!(eventloop.network.is_none()); println!("State is being clean properly"); } From 041c6c8f2e3a7b14b578e6814137ba06a2a4b0b7 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 15:19:43 +0000 Subject: [PATCH 12/20] deprecate `Network::readb` --- rumqttc/src/eventloop.rs | 7 +++++-- rumqttc/src/framed.rs | 37 ++------------------------------- rumqttc/src/mqttbytes/v4/mod.rs | 4 ++-- 3 files changed, 9 insertions(+), 39 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 796acaa2..656b0871 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -182,8 +182,11 @@ impl EventLoop { // instead of returning a None event, we try again. select! { // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; + o = network.read() => { + let incoming = o?; + if let Some(packet) = self.state.handle_incoming_packet(incoming)? { + network.send(packet).await?; + } Ok(self.state.events.pop_front().unwrap()) }, // Handles pending and new requests. diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 1b8aef98..6b17fb1e 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -4,7 +4,7 @@ use tokio::time::timeout; use tokio_util::codec::Framed; use crate::mqttbytes::{self, v4::*}; -use crate::{Incoming, MqttState, StateError}; +use crate::{Incoming, StateError}; use std::time::Duration; /// Network transforms packets <-> frames efficiently. It takes @@ -13,8 +13,6 @@ use std::time::Duration; pub struct Network { /// Frame MQTT packets from network connection framed: Framed, Codec>, - /// Maximum readv count - max_readb_count: usize, /// Time within which network operations should complete timeout: Duration, } @@ -33,11 +31,7 @@ impl Network { }; let framed = Framed::new(socket, codec); - Network { - framed, - max_readb_count: 10, - timeout, - } + Network { framed, timeout } } pub async fn read(&mut self) -> Result { @@ -48,33 +42,6 @@ impl Network { } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; - loop { - match self.framed.next().await { - Some(Ok(packet)) => { - if let Some(packet) = state.handle_incoming_packet(packet)? { - self.send(packet).await?; - } - - count += 1; - if count >= self.max_readb_count { - break; - } - } - // If some packets are already framed, return those - Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None if count > 0 => break, - // NOTE: read atleast 1 packet - Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), - Some(Err(e)) => return Err(StateError::Deserialization(e)), - }; - } - - Ok(()) - } - pub async fn send(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { match timeout(self.timeout, self.framed.send(packet)).await { Ok(inner) => inner.map_err(Into::into), diff --git a/rumqttc/src/mqttbytes/v4/mod.rs b/rumqttc/src/mqttbytes/v4/mod.rs index ed438dd0..3621945d 100644 --- a/rumqttc/src/mqttbytes/v4/mod.rs +++ b/rumqttc/src/mqttbytes/v4/mod.rs @@ -15,6 +15,7 @@ mod subscribe; mod unsuback; mod unsubscribe; +pub use codec::*; pub use connack::*; pub use connect::*; pub use disconnect::*; @@ -28,7 +29,6 @@ pub use suback::*; pub use subscribe::*; pub use unsuback::*; pub use unsubscribe::*; -pub use codec::*; /// Encapsulates all MQTT packet types #[derive(Debug, Clone, PartialEq, Eq)] @@ -116,7 +116,7 @@ impl Packet { return Err(Error::OutgoingPacketTooLarge { pkt_size: self.size(), max: max_size, - }) + }); } match self { From 5dbe5a18225831c8e30da895c97eb373039df348 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 16:27:06 +0000 Subject: [PATCH 13/20] refactor: v5 implementation --- rumqttc/src/v5/eventloop.rs | 66 ++++++-- rumqttc/src/v5/framed.rs | 149 +++++------------- rumqttc/src/v5/mqttbytes/mod.rs | 6 +- rumqttc/src/v5/mqttbytes/v5/codec.rs | 73 +++++++++ rumqttc/src/v5/mqttbytes/v5/mod.rs | 14 +- rumqttc/src/v5/state.rs | 223 ++++++++++++++------------- 6 files changed, 289 insertions(+), 242 deletions(-) create mode 100644 rumqttc/src/v5/mqttbytes/v5/codec.rs diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 27c26f29..7c8b5e51 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -210,17 +210,21 @@ impl EventLoop { self.options.pending_throttle ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { - self.state.handle_outgoing_packet(request)?; - network.flush(&mut self.state.write).await?; + if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { + network.send(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } Err(_) => Err(ConnectionError::RequestsDone), }, // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - network.flush(&mut self.state.write).await?; + o = network.read() => { + let incoming = o?; + if let Some(packet) = self.state.handle_incoming_packet(incoming)? { + network.send(packet).await?; + } + Ok(self.state.events.pop_front().unwrap()) }, // We generate pings irrespective of network activity. This keeps the ping logic @@ -229,8 +233,10 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - self.state.handle_outgoing_packet(Request::PingReq)?; - network.flush(&mut self.state.write).await?; + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { + network.send(outgoing).await?; + } + Ok(self.state.events.pop_front().unwrap()) } } @@ -276,7 +282,9 @@ async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), Conne } async fn network_connect(options: &MqttOptions) -> Result { - let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); + let mut max_incoming_pkt_size = Some(options.default_max_incoming_size); // incoming == outgoing + let max_outgoing_pkt_size = Some(options.default_max_incoming_size); + let network_timeout = Duration::from_secs(options.network_options.connection_timeout()); // Override default value if max_packet_size is set on `connect_properties` if let Some(connect_props) = &options.connect_properties { @@ -291,7 +299,12 @@ async fn network_connect(options: &MqttOptions) -> Result Result Network::new(tcp_stream, max_incoming_pkt_size), + Transport::Tcp => Network::new( + tcp_stream, + max_incoming_pkt_size, + max_outgoing_pkt_size, + network_timeout, + ), #[cfg(any(feature = "use-native-tls", feature = "use-rustls"))] Transport::Tls(tls_config) => { let socket = tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream) .await?; - Network::new(socket, max_incoming_pkt_size) + Network::new( + socket, + max_incoming_pkt_size, + max_outgoing_pkt_size, + network_timeout, + ) } #[cfg(unix)] Transport::Unix => unreachable!(), @@ -352,7 +375,12 @@ async fn network_connect(options: &MqttOptions) -> Result { @@ -375,7 +403,12 @@ async fn network_connect(options: &MqttOptions) -> Result frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when /// appropriate to achieve performance pub struct Network { - /// Socket for IO - socket: Box, - /// Buffered reads - read: BytesMut, - /// Maximum packet size - max_incoming_size: Option, - /// Maximum readv count - max_readb_count: usize, + /// Frame MQTT packets from network connection + framed: Framed, Codec>, + /// Time within which network operations should complete + timeout: Duration, } - impl Network { - pub fn new(socket: impl N + 'static, max_incoming_size: Option) -> Network { - let socket = Box::new(socket) as Box; - Network { - socket, - read: BytesMut::with_capacity(10 * 1024), + pub fn new( + socket: impl AsyncReadWrite + 'static, + max_incoming_size: Option, + max_outgoing_size: Option, + timeout: Duration, + ) -> Network { + let socket = Box::new(socket) as Box; + let codec = Codec { max_incoming_size, - max_readb_count: 10, - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.socket.read_buf(&mut self.read).await?; - if 0 == read { - return if self.read.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "connection reset by peer", - )) - }; - } + max_outgoing_size, + }; + let framed = Framed::new(socket, codec); - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } + Network { framed, timeout } } - pub async fn read(&mut self) -> io::Result { - loop { - let required = match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => return Ok(packet), - Err(mqttbytes::Error::InsufficientBytes(required)) => required, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - // read more packets until a frame can be created. This function - // blocks until a frame can be created. Use this in a select! branch - self.read_bytes(required).await?; + pub async fn read(&mut self) -> Result { + match self.framed.next().await { + Some(Ok(packet)) => Ok(packet), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(e)) => Err(StateError::Deserialization(e)), } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - let mut count = 0; - loop { - match Packet::read(&mut self.read, self.max_incoming_size) { - Ok(packet) => { - state.handle_incoming_packet(packet)?; - - count += 1; - if count >= self.max_readb_count { - return Ok(()); - } - } - // If some packets are already framed, return those - Err(mqttbytes::Error::InsufficientBytes(_)) if count > 0 => return Ok(()), - // Wait for more bytes until a frame can be created - Err(mqttbytes::Error::InsufficientBytes(required)) => { - self.read_bytes(required).await?; - } - Err(mqttbytes::Error::PayloadSizeLimitExceeded { pkt_size, max }) => { - state.handle_protocol_error()?; - return Err(StateError::IncomingPacketTooLarge { pkt_size, max }); - } - Err(e) => return Err(StateError::Deserialization(e)), - }; + pub async fn send(&mut self, packet: Packet) -> Result<(), StateError> { + match timeout(self.timeout, self.framed.send(packet)).await { + Ok(inner) => inner.map_err(StateError::Deserialization), + Err(e) => Err(StateError::Timeout(e)), } } - - pub async fn connect(&mut self, connect: Connect, options: &MqttOptions) -> io::Result { - let mut write = BytesMut::new(); - let last_will = options.last_will(); - let login = options.credentials().map(|l| Login { - username: l.0, - password: l.1, - }); - - let len = match Packet::Connect(connect, last_will, login).write(&mut write) { - Ok(size) => size, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - }; - - self.socket.write_all(&write[..]).await?; - Ok(len) - } - - pub async fn flush(&mut self, write: &mut BytesMut) -> io::Result<()> { - if write.is_empty() { - return Ok(()); - } - - self.socket.write_all(&write[..]).await?; - write.clear(); - Ok(()) - } } - -pub trait N: AsyncRead + AsyncWrite + Send + Unpin {} -impl N for T where T: AsyncRead + AsyncWrite + Send + Unpin {} diff --git a/rumqttc/src/v5/mqttbytes/mod.rs b/rumqttc/src/v5/mqttbytes/mod.rs index 231c6806..42c1fcdd 100644 --- a/rumqttc/src/v5/mqttbytes/mod.rs +++ b/rumqttc/src/v5/mqttbytes/mod.rs @@ -130,7 +130,7 @@ pub fn matches(topic: &str, filter: &str) -> bool { } /// Error during serialization and deserialization -#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Invalid return code received as response for connect = {0}")] InvalidConnectReturnCode(u8), @@ -183,4 +183,8 @@ pub enum Error { /// proceed further #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] InsufficientBytes(usize), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] + OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, } diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs new file mode 100644 index 00000000..fc24105c --- /dev/null +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -0,0 +1,73 @@ +use bytes::{Buf, BytesMut}; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{Error, Packet}; + +/// MQTT v4 codec +#[derive(Debug, Clone)] +pub struct Codec { + /// Maximum packet size allowed by client + pub max_incoming_size: Option, + /// Maximum packet size allowed by broker + pub max_outgoing_size: Option, +} + +impl Decoder for Codec { + type Item = Packet; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.remaining() == 0 { + return Ok(None); + } + + let packet = Packet::read(src, self.max_incoming_size)?; + Ok(Some(packet)) + } +} + +impl Encoder for Codec { + type Error = Error; + + fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { + item.write(dst, self.max_outgoing_size)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use tokio_util::codec::Encoder; + + use super::Codec; + use crate::v5::{ + mqttbytes::{Error, QoS}, + Packet, Publish, + }; + + #[test] + fn outgoing_max_packet_size_check() { + let mut buf = BytesMut::new(); + let mut codec = Codec { + max_incoming_size: Some(100), + max_outgoing_size: Some(200), + }; + + let mut small_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 100], None); + small_publish.pkid = 1; + codec + .encode(Packet::Publish(small_publish), &mut buf) + .unwrap(); + + let large_publish = Publish::new("hello/world", QoS::AtLeastOnce, vec![1; 265], None); + match codec.encode(Packet::Publish(large_publish), &mut buf) { + Err(Error::OutgoingPacketTooLarge { + pkt_size: 282, + max: 200, + }) => {} + _ => unreachable!(), + } + } +} diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 01ddef99..d98ab475 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -1,6 +1,7 @@ use std::slice::Iter; pub use self::{ + codec::Codec, connack::{ConnAck, ConnAckProperties, ConnectReturnCode}, connect::{Connect, ConnectProperties, LastWill, LastWillProperties, Login}, disconnect::{Disconnect, DisconnectReasonCode}, @@ -19,6 +20,7 @@ pub use self::{ use super::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; +mod codec; mod connack; mod connect; mod disconnect; @@ -126,7 +128,17 @@ impl Packet { Ok(packet) } - pub fn write(&self, write: &mut BytesMut) -> Result { + pub fn write(&self, write: &mut BytesMut, max_size: Option) -> Result { + if let Some(max_size) = max_size { + if self.size() > max_size { + dbg!(); + return Err(Error::OutgoingPacketTooLarge { + pkt_size: self.size() as u32, + max: max_size as u32, + }); + } + } + match self { Self::Publish(publish) => publish.write(write), Self::Subscribe(subscription) => subscription.write(write), diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 8473f1f4..c817191e 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -9,8 +9,8 @@ use super::{Event, Incoming, Outgoing, Request}; use bytes::{Bytes, BytesMut}; use std::collections::{HashMap, VecDeque}; -use std::convert::TryInto; use std::{io, time::Instant}; +use tokio::time::error::Elapsed; /// Errors during state handling #[derive(Debug, thiserror::Error)] @@ -42,10 +42,6 @@ pub enum StateError { "Cannot use topic alias '{alias:?}'. It's greater than the broker's maximum of '{max:?}'." )] InvalidAlias { alias: u16, max: u16 }, - #[error("Cannot send packet of size '{pkt_size:?}'. It's greater than the broker's maximum packet size of: '{max:?}'")] - OutgoingPacketTooLarge { pkt_size: u32, max: u32 }, - #[error("Cannot receive packet of size '{pkt_size:?}'. It's greater than the client's maximum packet size of: '{max:?}'")] - IncomingPacketTooLarge { pkt_size: usize, max: usize }, #[error("Server sent disconnect with reason `{reason_string:?}` and code '{reason_code:?}' ")] ServerDisconnect { reason_code: DisconnectReasonCode, @@ -65,6 +61,8 @@ pub enum StateError { PubCompFail { reason: PubCompReason }, #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, + #[error("Timeout")] + Timeout(#[from] Elapsed), } /// State of the mqtt connection. @@ -108,7 +106,7 @@ pub struct MqttState { /// `topic_alias_maximum` RECEIVED via connack packet pub broker_topic_alias_max: u16, /// The broker's `max_packet_size` received via connack - pub max_outgoing_packet_size: Option, + pub max_outgoing_packet_size: Option, /// Maximum number of allowed inflight QoS1 & QoS2 requests pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests @@ -181,77 +179,67 @@ impl MqttState { /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop - pub fn handle_outgoing_packet(&mut self, request: Request) -> Result<(), StateError> { - match request { - Request::Publish(publish) => { - self.check_size(publish.size())?; - self.outgoing_publish(publish)? - } - Request::PubRel(pubrel) => { - self.check_size(pubrel.size())?; - self.outgoing_pubrel(pubrel)? - } - Request::Subscribe(subscribe) => { - self.check_size(subscribe.size())?; - self.outgoing_subscribe(subscribe)? - } - Request::Unsubscribe(unsubscribe) => { - self.check_size(unsubscribe.size())?; - self.outgoing_unsubscribe(unsubscribe)? - } + pub fn handle_outgoing_packet( + &mut self, + request: Request, + ) -> Result, StateError> { + let packet = match request { + Request::Publish(publish) => self.outgoing_publish(publish)?, + Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, + Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, + Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, Request::PingReq => self.outgoing_ping()?, Request::Disconnect => { self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => { - self.check_size(puback.size())?; - self.outgoing_puback(puback)? - } - Request::PubRec(pubrec) => { - self.check_size(pubrec.size())?; - self.outgoing_pubrec(pubrec)? - } + Request::PubAck(puback) => self.outgoing_puback(puback)?, + Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, _ => unimplemented!(), }; self.last_outgoing = Instant::now(); - Ok(()) + Ok(packet) } /// Consolidates handling of all incoming mqtt packets. Returns a `Notification` which for the /// user to consume and `Packet` which for the eventloop to put on the network /// E.g For incoming QoS1 publish packet, this method returns (Publish, Puback). Publish packet will /// be forwarded to user and Pubck packet will be written to network - pub fn handle_incoming_packet(&mut self, mut packet: Incoming) -> Result<(), StateError> { - let out = match &mut packet { - Incoming::PingResp(_) => self.handle_incoming_pingresp(), - Incoming::Publish(publish) => self.handle_incoming_publish(publish), - Incoming::SubAck(suback) => self.handle_incoming_suback(suback), - Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback), - Incoming::PubAck(puback) => self.handle_incoming_puback(puback), - Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec), - Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel), - Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp), - Incoming::ConnAck(connack) => self.handle_incoming_connack(connack), - Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn), + pub fn handle_incoming_packet( + &mut self, + mut packet: Incoming, + ) -> Result, StateError> { + let outgoing = match &mut packet { + Incoming::PingResp(_) => self.handle_incoming_pingresp()?, + Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, + Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, + Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, + Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, + Incoming::PubComp(pubcomp) => self.handle_incoming_pubcomp(pubcomp)?, + Incoming::ConnAck(connack) => self.handle_incoming_connack(connack)?, + Incoming::Disconnect(disconn) => self.handle_incoming_disconn(disconn)?, _ => { error!("Invalid incoming packet = {:?}", packet); return Err(StateError::WrongPacket); } }; - out?; self.events.push_back(Event::Incoming(packet)); self.last_incoming = Instant::now(); - Ok(()) + Ok(outgoing) } - pub fn handle_protocol_error(&mut self) -> Result<(), StateError> { + pub fn handle_protocol_error(&mut self) -> Result, StateError> { // send DISCONNECT packet with REASON_CODE 0x82 self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn handle_incoming_suback(&mut self, suback: &mut SubAck) -> Result<(), StateError> { + fn handle_incoming_suback( + &mut self, + suback: &mut SubAck, + ) -> Result, StateError> { for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { @@ -260,19 +248,25 @@ impl MqttState { _ => return Err(StateError::SubFail { reason: *reason }), } } - Ok(()) + Ok(None) } - fn handle_incoming_unsuback(&mut self, unsuback: &mut UnsubAck) -> Result<(), StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: &mut UnsubAck, + ) -> Result, StateError> { for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { return Err(StateError::UnsubFail { reason: *reason }); } } - Ok(()) + Ok(None) } - fn handle_incoming_connack(&mut self, connack: &mut ConnAck) -> Result<(), StateError> { + fn handle_incoming_connack( + &mut self, + connack: &mut ConnAck, + ) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -291,12 +285,15 @@ impl MqttState { // to save some space. } - self.max_outgoing_packet_size = props.max_packet_size; + self.max_outgoing_packet_size = props.max_packet_size.map(|i| i as usize); } - Ok(()) + Ok(None) } - fn handle_incoming_disconn(&mut self, disconn: &mut Disconnect) -> Result<(), StateError> { + fn handle_incoming_disconn( + &mut self, + disconn: &mut Disconnect, + ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { props.reason_string.clone() @@ -311,7 +308,10 @@ impl MqttState { /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &mut Publish) -> Result<(), StateError> { + fn handle_incoming_publish( + &mut self, + publish: &mut Publish, + ) -> Result, StateError> { let qos = publish.qos; let topic_alias = match &publish.properties { @@ -332,13 +332,13 @@ impl MqttState { } match qos { - QoS::AtMostOnce => Ok(()), + QoS::AtMostOnce => Ok(None), QoS::AtLeastOnce => { if !self.manual_acks { let puback = PubAck::new(publish.pkid, None); self.outgoing_puback(puback)?; } - Ok(()) + Ok(None) } QoS::ExactlyOnce => { let pkid = publish.pkid; @@ -348,12 +348,12 @@ impl MqttState { let pubrec = PubRec::new(pkid, None); self.outgoing_pubrec(pubrec)?; } - Ok(()) + Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), StateError> { + fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(puback.pkid as usize) @@ -361,7 +361,7 @@ impl MqttState { let v = match publish.take() { Some(_) => { self.inflight -= 1; - Ok(()) + Ok(None) } None => { error!("Unsolicited puback packet: {:?}", puback.pkid); @@ -382,7 +382,7 @@ impl MqttState { self.inflight += 1; let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -391,7 +391,7 @@ impl MqttState { v } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result<(), StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -408,11 +408,12 @@ impl MqttState { // NOTE: Inflight - 1 for qos2 in comp self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid); - Packet::PubRel(PubRel::new(pubrec.pkid, None)).write(&mut self.write)?; + Packet::PubRel(PubRel::new(pubrec.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } None => { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); @@ -421,7 +422,7 @@ impl MqttState { } } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<(), StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { let publish = self .incoming_pub .get_mut(pubrel.pkid as usize) @@ -434,10 +435,11 @@ impl MqttState { }); } - Packet::PubComp(PubComp::new(pubrel.pkid, None)).write(&mut self.write)?; + Packet::PubComp(PubComp::new(pubrel.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } None => { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); @@ -446,10 +448,10 @@ impl MqttState { } } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<(), StateError> { + fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { if let Some(publish) = self.check_collision(pubcomp.pkid) { let pkid = publish.pkid; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; @@ -468,7 +470,7 @@ impl MqttState { } self.inflight -= 1; - Ok(()) + Ok(None) } None => { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); @@ -477,14 +479,14 @@ impl MqttState { } } - fn handle_incoming_pingresp(&mut self) -> Result<(), StateError> { + fn handle_incoming_pingresp(&mut self) -> Result, StateError> { self.await_pingresp = false; - Ok(()) + Ok(None) } /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result<(), StateError> { + fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -501,7 +503,7 @@ impl MqttState { self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); - return Ok(()); + return Ok(None); } // if there is an existing publish at this pkid, this implies that broker hasn't acked this @@ -532,43 +534,44 @@ impl MqttState { } }; - Packet::Publish(publish).write(&mut self.write)?; + Packet::Publish(publish).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result<(), StateError> { + fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); - Packet::PubRel(PubRel::new(pubrel.pkid, None)).write(&mut self.write)?; + Packet::PubRel(PubRel::new(pubrel.pkid, None)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> { + fn outgoing_puback(&mut self, puback: PubAck) -> Result, StateError> { let pkid = puback.pkid; - Packet::PubAck(puback).write(&mut self.write)?; + Packet::PubAck(puback).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubAck(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> { + fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let pkid = pubrec.pkid; - Packet::PubRec(pubrec).write(&mut self.write)?; + Packet::PubRec(pubrec).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PubRec(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } /// check when the last control packet/pingreq packet is received and return /// the status which tells if keep alive time has exceeded /// NOTE: status will be checked for zero keepalive times also - fn outgoing_ping(&mut self) -> Result<(), StateError> { + fn outgoing_ping(&mut self) -> Result, StateError> { let elapsed_in = self.last_incoming.elapsed(); let elapsed_out = self.last_outgoing.elapsed(); @@ -591,13 +594,16 @@ impl MqttState { elapsed_in, elapsed_out, ); - Packet::PingReq(PingReq).write(&mut self.write)?; + Packet::PingReq(PingReq).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::PingReq); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_subscribe(&mut self, mut subscription: Subscribe) -> Result<(), StateError> { + fn outgoing_subscribe( + &mut self, + mut subscription: Subscribe, + ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); } @@ -611,13 +617,16 @@ impl MqttState { ); let pkid = subscription.pkid; - Packet::Subscribe(subscription).write(&mut self.write)?; + Packet::Subscribe(subscription).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_unsubscribe(&mut self, mut unsub: Unsubscribe) -> Result<(), StateError> { + fn outgoing_unsubscribe( + &mut self, + mut unsub: Unsubscribe, + ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -627,19 +636,23 @@ impl MqttState { ); let pkid = unsub.pkid; - Packet::Unsubscribe(unsub).write(&mut self.write)?; + Packet::Unsubscribe(unsub).write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); - Ok(()) + Ok(None) } - fn outgoing_disconnect(&mut self, reason: DisconnectReasonCode) -> Result<(), StateError> { + fn outgoing_disconnect( + &mut self, + reason: DisconnectReasonCode, + ) -> Result, StateError> { debug!("Disconnect with {:?}", reason); - Packet::Disconnect(Disconnect::new(reason)).write(&mut self.write)?; + Packet::Disconnect(Disconnect::new(reason)) + .write(&mut self.write, self.max_outgoing_packet_size)?; let event = Event::Outgoing(Outgoing::Disconnect); self.events.push_back(event); - Ok(()) + Ok(None) } fn check_collision(&mut self, pkid: u16) -> Option { @@ -652,18 +665,6 @@ impl MqttState { None } - fn check_size(&self, pkt_size: usize) -> Result<(), StateError> { - let pkt_size = pkt_size.try_into()?; - - match self.max_outgoing_packet_size { - Some(max_size) if pkt_size > max_size => Err(StateError::OutgoingPacketTooLarge { - pkt_size, - max: max_size, - }), - _ => Ok(()), - } - } - fn save_pubrel(&mut self, mut pubrel: PubRel) -> Result { let pubrel = match pubrel.pkid { // consider PacketIdentifier(0) as uninitialized packets From c861a958736e5999a4fa6a87a7e6ac06a36e939e Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Tue, 19 Mar 2024 16:30:16 +0000 Subject: [PATCH 14/20] doc: changelog entry --- rumqttc/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index cc5c0c1e..70cce1df 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +* Refactor `Network`, simplify with `Framed` + ### Deprecated ### Removed From 6c9a8d93f87f388cfcd3c3ce67f5028cfdbcd475 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 20 Mar 2024 15:08:48 +0000 Subject: [PATCH 15/20] allow configuring `MqttOptions.max_request_batch` Defaults to 10 --- rumqttc/src/lib.rs | 8 +++++++- rumqttc/src/v5/mod.rs | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 9c30d46a..85c0073e 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -483,7 +483,7 @@ impl MqttOptions { max_incoming_packet_size: 10 * 1024, max_outgoing_packet_size: 10 * 1024, request_channel_capacity: 10, - max_request_batch: 0, + max_request_batch: 10, pending_throttle: Duration::from_micros(0), inflight: 100, last_will: None, @@ -642,6 +642,12 @@ impl MqttOptions { self.request_channel_capacity } + /// set maximum request batch count + pub fn set_max_request_batch(&mut self, max_request_batch: usize) -> &mut Self { + self.max_request_batch = max_request_batch; + self + } + /// Enables throttling and sets outoing message rate to the specified 'rate' pub fn set_pending_throttle(&mut self, duration: Duration) -> &mut Self { self.pending_throttle = duration; diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 663cfd27..88f5279c 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -126,7 +126,7 @@ impl MqttOptions { client_id: id.into(), credentials: None, request_channel_capacity: 10, - max_request_batch: 0, + max_request_batch: 10, pending_throttle: Duration::from_micros(0), last_will: None, conn_timeout: 5, @@ -274,6 +274,12 @@ impl MqttOptions { self } + /// set maximum request batch count + pub fn set_max_request_batch(&mut self, max_request_batch: usize) -> &mut Self { + self.max_request_batch = max_request_batch; + self + } + /// Request channel capacity pub fn request_channel_capacity(&self) -> usize { self.request_channel_capacity From 33abd75373c73f08802cb8d1594356a274e8c7ea Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 20 Mar 2024 15:09:33 +0000 Subject: [PATCH 16/20] feat: buffer packets --- rumqttc/src/eventloop.rs | 17 +++++++++++++---- rumqttc/src/framed.rs | 34 ++++++++++++++++++++++++++++++---- rumqttc/src/v5/eventloop.rs | 15 +++++++++++---- rumqttc/src/v5/framed.rs | 32 +++++++++++++++++++++++++++++--- rumqttc/tests/broker.rs | 3 +++ 5 files changed, 86 insertions(+), 15 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 656b0871..87a239fb 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -185,7 +185,7 @@ impl EventLoop { o = network.read() => { let incoming = o?; if let Some(packet) = self.state.handle_incoming_packet(incoming)? { - network.send(packet).await?; + network.write(packet).await?; } Ok(self.state.events.pop_front().unwrap()) }, @@ -224,7 +224,7 @@ impl EventLoop { ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { - network.send(outgoing).await?; + network.write(outgoing).await?; } Ok(self.state.events.pop_front().unwrap()) @@ -239,8 +239,11 @@ impl EventLoop { timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { - network.send(outgoing).await?; + network.write(outgoing).await?; } + // NOTE: Pings should be sent instantly + network.flush().await?; + Ok(self.state.events.pop_front().unwrap()) } } @@ -353,6 +356,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, + options.max_request_batch, ); return Ok(network); } @@ -390,6 +394,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, + options.max_request_batch, ), #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] Transport::Tls(tls_config) => { @@ -401,6 +406,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, + options.max_request_batch, ) } #[cfg(unix)] @@ -425,6 +431,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, + options.max_request_batch, ) } #[cfg(all(feature = "use-rustls", feature = "websocket"))] @@ -453,6 +460,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, + options.max_request_batch, ) } }; @@ -479,7 +487,8 @@ async fn mqtt_connect( } // send mqtt connect packet - network.send(Packet::Connect(connect)).await?; + network.write(Packet::Connect(connect)).await?; + network.flush().await?; // validate connack match network.read().await? { diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 6b17fb1e..10664f3a 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -15,6 +15,10 @@ pub struct Network { framed: Framed, Codec>, /// Time within which network operations should complete timeout: Duration, + /// Number of packets currently written into buffer + buffered_packets: usize, + /// Maximum number of packets that can be buffered + max_buffered_packets: usize, } impl Network { @@ -23,6 +27,7 @@ impl Network { max_incoming_size: usize, max_outgoing_size: usize, timeout: Duration, + max_buffered_packets: usize, ) -> Network { let socket = Box::new(socket) as Box; let codec = Codec { @@ -31,7 +36,12 @@ impl Network { }; let framed = Framed::new(socket, codec); - Network { framed, timeout } + Network { + framed, + timeout, + buffered_packets: 0, + max_buffered_packets, + } } pub async fn read(&mut self) -> Result { @@ -42,9 +52,25 @@ impl Network { } } - pub async fn send(&mut self, packet: Packet) -> Result<(), crate::state::StateError> { - match timeout(self.timeout, self.framed.send(packet)).await { - Ok(inner) => inner.map_err(Into::into), + /// Write packets into buffer, flush after `MAX_BUFFERED_PACKETS`` + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + self.buffered_packets += 1; + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization)?; + if self.buffered_packets >= self.max_buffered_packets { + self.flush().await?; + } + + Ok(()) + } + + /// Force flush all packets in buffer, reset count + pub async fn flush(&mut self) -> Result<(), StateError> { + self.buffered_packets = 0; + match timeout(self.timeout, self.framed.flush()).await { + Ok(inner) => inner.map_err(StateError::Deserialization), Err(_) => Err(StateError::FlushTimeout), } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 7c8b5e51..f54b9b52 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -211,7 +211,7 @@ impl EventLoop { ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { Ok(request) => { if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { - network.send(outgoing).await?; + network.write(outgoing).await?; } Ok(self.state.events.pop_front().unwrap()) @@ -222,7 +222,7 @@ impl EventLoop { o = network.read() => { let incoming = o?; if let Some(packet) = self.state.handle_incoming_packet(incoming)? { - network.send(packet).await?; + network.write(packet).await?; } Ok(self.state.events.pop_front().unwrap()) @@ -234,8 +234,9 @@ impl EventLoop { timeout.as_mut().reset(Instant::now() + self.options.keep_alive); if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { - network.send(outgoing).await?; + network.write(outgoing).await?; } + network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } @@ -304,6 +305,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result { @@ -356,6 +359,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result Result, Codec>, /// Time within which network operations should complete timeout: Duration, + /// Number of packets currently written into buffer + buffered_packets: usize, + /// Maximum number of packets that can be buffered + max_buffered_packets: usize, } impl Network { pub fn new( @@ -25,6 +29,7 @@ impl Network { max_incoming_size: Option, max_outgoing_size: Option, timeout: Duration, + max_buffered_packets: usize, ) -> Network { let socket = Box::new(socket) as Box; let codec = Codec { @@ -33,7 +38,12 @@ impl Network { }; let framed = Framed::new(socket, codec); - Network { framed, timeout } + Network { + framed, + timeout, + buffered_packets: 0, + max_buffered_packets, + } } pub async fn read(&mut self) -> Result { @@ -44,8 +54,24 @@ impl Network { } } - pub async fn send(&mut self, packet: Packet) -> Result<(), StateError> { - match timeout(self.timeout, self.framed.send(packet)).await { + /// Write packets into buffer, flush after `MAX_BUFFERED_PACKETS` + pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { + self.buffered_packets += 1; + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization)?; + if self.buffered_packets >= self.max_buffered_packets { + self.flush().await?; + } + + Ok(()) + } + + /// Force flush all packets in buffer, reset count + pub async fn flush(&mut self) -> Result<(), StateError> { + self.buffered_packets = 0; + match timeout(self.timeout, self.framed.flush()).await { Ok(inner) => inner.map_err(StateError::Deserialization), Err(e) => Err(StateError::Timeout(e)), } diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index ea66448f..cbf781ed 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -147,6 +147,9 @@ impl Broker { /// Selects between outgoing and incoming packets pub async fn tick(&mut self) -> Event { + if let Some(incoming) = self.incoming.pop_front() { + return Event::Incoming(incoming); + } select! { request = self.outgoing_rx.recv_async() => { let request = request.unwrap(); From 8b0aaa1a2ee681a5e036364a0a5df91d83ce37e8 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 20 Mar 2024 15:10:01 +0000 Subject: [PATCH 17/20] test: fix by avoiding buffer --- rumqttc/tests/reliability.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 49ce30d6..d55abd25 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -292,7 +292,7 @@ async fn requests_are_blocked_after_max_inflight_queue_size() { #[tokio::test] async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 1888); - options.set_inflight(3); + options.set_inflight(3).set_max_request_batch(1); let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -474,7 +474,9 @@ async fn next_poll_after_connect_failure_reconnects() { #[tokio::test] async fn reconnection_resumes_from_the_previous_state() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3001); - options.set_keep_alive(Duration::from_secs(5)); + options + .set_keep_alive(Duration::from_secs(5)) + .set_max_request_batch(1); // start sending qos0 publishes. Makes sure that there is out activity but no in activity let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -514,7 +516,9 @@ async fn reconnection_resumes_from_the_previous_state() { #[tokio::test] async fn reconnection_resends_unacked_packets_from_the_previous_connection_first() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3002); - options.set_keep_alive(Duration::from_secs(5)); + options + .set_keep_alive(Duration::from_secs(5)) + .set_max_request_batch(1); // start sending qos0 publishes. this makes sure that there is // outgoing activity but no incoming activity From f71665bf1581d35576bfc4993be6e89f18bf3efb Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 21 Mar 2024 17:05:35 +0000 Subject: [PATCH 18/20] fix: network buffering based on bytes written --- rumqttc/src/eventloop.rs | 10 +++++----- rumqttc/src/framed.rs | 35 ++++++++++------------------------- rumqttc/src/lib.rs | 22 +++++++++++----------- rumqttc/src/state.rs | 2 ++ rumqttc/src/v5/eventloop.rs | 10 +++++----- rumqttc/src/v5/framed.rs | 33 +++++++++------------------------ rumqttc/src/v5/mod.rs | 22 +++++++++++----------- rumqttc/src/v5/state.rs | 2 ++ rumqttc/tests/reliability.rs | 6 +++--- 9 files changed, 58 insertions(+), 84 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 87a239fb..604f0ff3 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -356,7 +356,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, - options.max_request_batch, + options.network_buffer_capacity, ); return Ok(network); } @@ -394,7 +394,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, - options.max_request_batch, + options.network_buffer_capacity, ), #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] Transport::Tls(tls_config) => { @@ -406,7 +406,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, - options.max_request_batch, + options.network_buffer_capacity, ) } #[cfg(unix)] @@ -431,7 +431,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, - options.max_request_batch, + options.network_buffer_capacity, ) } #[cfg(all(feature = "use-rustls", feature = "websocket"))] @@ -460,7 +460,7 @@ async fn network_connect( options.max_incoming_packet_size, options.max_outgoing_packet_size, network_timeout, - options.max_request_batch, + options.network_buffer_capacity, ) } }; diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index 10664f3a..d713f6bc 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -13,12 +13,8 @@ use std::time::Duration; pub struct Network { /// Frame MQTT packets from network connection framed: Framed, Codec>, - /// Time within which network operations should complete + /// Time within which network write operations should complete timeout: Duration, - /// Number of packets currently written into buffer - buffered_packets: usize, - /// Maximum number of packets that can be buffered - max_buffered_packets: usize, } impl Network { @@ -27,48 +23,37 @@ impl Network { max_incoming_size: usize, max_outgoing_size: usize, timeout: Duration, - max_buffered_packets: usize, + buffer_capacity: usize, ) -> Network { let socket = Box::new(socket) as Box; let codec = Codec { max_incoming_size, max_outgoing_size, }; - let framed = Framed::new(socket, codec); + let framed = Framed::with_capacity(socket, codec, buffer_capacity); - Network { - framed, - timeout, - buffered_packets: 0, - max_buffered_packets, - } + Network { framed, timeout } } pub async fn read(&mut self) -> Result { match self.framed.next().await { Some(Ok(packet)) => Ok(packet), - Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), Some(Err(e)) => Err(StateError::Deserialization(e)), + None => Err(StateError::ConnectionClosed), } } - /// Write packets into buffer, flush after `MAX_BUFFERED_PACKETS`` + /// Write packets into buffer, flush instantly for `PingReq`/`PingResp` pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { - self.buffered_packets += 1; - self.framed - .feed(packet) - .await - .map_err(StateError::Deserialization)?; - if self.buffered_packets >= self.max_buffered_packets { - self.flush().await?; + match timeout(self.timeout, self.framed.send(packet)).await { + Ok(inner) => inner.map_err(StateError::Deserialization), + Err(_) => Err(StateError::FlushTimeout), } - - Ok(()) } /// Force flush all packets in buffer, reset count pub async fn flush(&mut self) -> Result<(), StateError> { - self.buffered_packets = 0; match timeout(self.timeout, self.framed.flush()).await { Ok(inner) => inner.map_err(StateError::Deserialization), Err(_) => Err(StateError::FlushTimeout), diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 85c0073e..4c9acf3e 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -442,8 +442,8 @@ pub struct MqttOptions { max_outgoing_packet_size: usize, /// request (publish, subscribe) channel capacity request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, + /// Network buffer capacity in memory + network_buffer_capacity: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -483,7 +483,7 @@ impl MqttOptions { max_incoming_packet_size: 10 * 1024, max_outgoing_packet_size: 10 * 1024, request_channel_capacity: 10, - max_request_batch: 10, + network_buffer_capacity: 10 * 1024, pending_throttle: Duration::from_micros(0), inflight: 100, last_will: None, @@ -642,9 +642,9 @@ impl MqttOptions { self.request_channel_capacity } - /// set maximum request batch count - pub fn set_max_request_batch(&mut self, max_request_batch: usize) -> &mut Self { - self.max_request_batch = max_request_batch; + /// Maximum buffer capacity before network flush + pub fn set_network_buffer_capacity(&mut self, network_buffer_capacity: usize) -> &mut Self { + self.network_buffer_capacity = network_buffer_capacity; self } @@ -848,12 +848,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") + if let Some(network_buffer_capacity) = queries + .remove("network_buffer_capacity_num") .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) .transpose()? { - options.max_request_batch = max_request_batch; + options.network_buffer_capacity = network_buffer_capacity; } if let Some(pending_throttle) = queries @@ -893,7 +893,7 @@ impl Debug for MqttOptions { .field("credentials", &self.credentials) .field("max_packet_size", &self.max_incoming_packet_size) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("network_buffer_capacity", &self.network_buffer_capacity) .field("pending_throttle", &self.pending_throttle) .field("inflight", &self.inflight) .field("last_will", &self.last_will) @@ -976,7 +976,7 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + err("mqtt://host:42?client_id=foo&network_buffer_capacity_num=foo"), OptionError::MaxRequestBatch ); assert_eq!( diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index c1014022..06a75c1b 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -31,6 +31,8 @@ pub enum StateError { Deserialization(#[from] mqttbytes::Error), #[error("Flush timeout")] FlushTimeout, + #[error("Connection Closed")] + ConnectionClosed, } /// State of the mqtt connection. diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index f54b9b52..9fe13dbe 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -305,7 +305,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result { @@ -359,7 +359,7 @@ async fn network_connect(options: &MqttOptions) -> Result Result Result, Codec>, - /// Time within which network operations should complete + /// Time within which network write operations should complete timeout: Duration, - /// Number of packets currently written into buffer - buffered_packets: usize, - /// Maximum number of packets that can be buffered - max_buffered_packets: usize, } impl Network { pub fn new( @@ -29,48 +25,37 @@ impl Network { max_incoming_size: Option, max_outgoing_size: Option, timeout: Duration, - max_buffered_packets: usize, + buffer_capacity: usize, ) -> Network { let socket = Box::new(socket) as Box; let codec = Codec { max_incoming_size, max_outgoing_size, }; - let framed = Framed::new(socket, codec); + let framed = Framed::with_capacity(socket, codec, buffer_capacity); - Network { - framed, - timeout, - buffered_packets: 0, - max_buffered_packets, - } + Network { framed, timeout } } pub async fn read(&mut self) -> Result { match self.framed.next().await { Some(Ok(packet)) => Ok(packet), - Some(Err(mqttbytes::Error::InsufficientBytes(_))) | None => unreachable!(), + Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), Some(Err(e)) => Err(StateError::Deserialization(e)), + None => Err(StateError::ConnectionClosed), } } /// Write packets into buffer, flush after `MAX_BUFFERED_PACKETS` pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { - self.buffered_packets += 1; - self.framed - .feed(packet) - .await - .map_err(StateError::Deserialization)?; - if self.buffered_packets >= self.max_buffered_packets { - self.flush().await?; + match timeout(self.timeout, self.framed.send(packet)).await { + Ok(inner) => inner.map_err(StateError::Deserialization), + Err(e) => Err(StateError::Timeout(e)), } - - Ok(()) } /// Force flush all packets in buffer, reset count pub async fn flush(&mut self) -> Result<(), StateError> { - self.buffered_packets = 0; match timeout(self.timeout, self.framed.flush()).await { Ok(inner) => inner.map_err(StateError::Deserialization), Err(e) => Err(StateError::Timeout(e)), diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 88f5279c..8fce8b49 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -78,8 +78,8 @@ pub struct MqttOptions { credentials: Option<(String, String)>, /// request (publish, subscribe) channel capacity request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, + /// Network buffer capacity in memory + network_buffer_capacity: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -126,7 +126,7 @@ impl MqttOptions { client_id: id.into(), credentials: None, request_channel_capacity: 10, - max_request_batch: 10, + network_buffer_capacity: 10 * 1024, pending_throttle: Duration::from_micros(0), last_will: None, conn_timeout: 5, @@ -274,9 +274,9 @@ impl MqttOptions { self } - /// set maximum request batch count - pub fn set_max_request_batch(&mut self, max_request_batch: usize) -> &mut Self { - self.max_request_batch = max_request_batch; + /// Maximum buffer capacity before network flush + pub fn set_network_buffer_capacity(&mut self, network_buffer_capacity: usize) -> &mut Self { + self.network_buffer_capacity = network_buffer_capacity; self } @@ -660,12 +660,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") + if let Some(network_buffer_capacity) = queries + .remove("network_buffer_capacity_num") .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) .transpose()? { - options.max_request_batch = max_request_batch; + options.network_buffer_capacity = network_buffer_capacity; } if let Some(pending_throttle) = queries @@ -710,7 +710,7 @@ impl Debug for MqttOptions { .field("client_id", &self.client_id) .field("credentials", &self.credentials) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("network_buffer_capacity", &self.network_buffer_capacity) .field("pending_throttle", &self.pending_throttle) .field("last_will", &self.last_will) .field("conn_timeout", &self.conn_timeout) @@ -791,7 +791,7 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + err("mqtt://host:42?client_id=foo&network_buffer_capacity_num=foo"), OptionError::MaxRequestBatch ); assert_eq!( diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index c817191e..7d41da1c 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -63,6 +63,8 @@ pub enum StateError { ConnFail { reason: ConnectReturnCode }, #[error("Timeout")] Timeout(#[from] Elapsed), + #[error("Connection Closed")] + ConnectionClosed, } /// State of the mqtt connection. diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index d55abd25..79b121f6 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -292,7 +292,7 @@ async fn requests_are_blocked_after_max_inflight_queue_size() { #[tokio::test] async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 1888); - options.set_inflight(3).set_max_request_batch(1); + options.set_inflight(3).set_network_buffer_capacity(0); // NOTE: to instantly flush let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -476,7 +476,7 @@ async fn reconnection_resumes_from_the_previous_state() { let mut options = MqttOptions::new("dummy", "127.0.0.1", 3001); options .set_keep_alive(Duration::from_secs(5)) - .set_max_request_batch(1); + .set_network_buffer_capacity(0); // NOTE: to instantly flush // start sending qos0 publishes. Makes sure that there is out activity but no in activity let (client, mut eventloop) = AsyncClient::new(options, 5); @@ -518,7 +518,7 @@ async fn reconnection_resends_unacked_packets_from_the_previous_connection_first let mut options = MqttOptions::new("dummy", "127.0.0.1", 3002); options .set_keep_alive(Duration::from_secs(5)) - .set_max_request_batch(1); + .set_network_buffer_capacity(0); // NOTE: to instantly flush // start sending qos0 publishes. this makes sure that there is // outgoing activity but no incoming activity From 83ed8c4c3d099e05fe2d4cba455d13b8a00d6624 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 21 Mar 2024 17:10:23 +0000 Subject: [PATCH 19/20] fix: don't return `Error::InsufficientBytes` --- rumqttc/src/mqttbytes/v4/codec.rs | 15 +++++++++------ rumqttc/src/v5/mqttbytes/v5/codec.rs | 15 +++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/rumqttc/src/mqttbytes/v4/codec.rs b/rumqttc/src/mqttbytes/v4/codec.rs index 3e7c73d5..c8fb9173 100644 --- a/rumqttc/src/mqttbytes/v4/codec.rs +++ b/rumqttc/src/mqttbytes/v4/codec.rs @@ -1,4 +1,4 @@ -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use tokio_util::codec::{Decoder, Encoder}; use super::{Error, Packet}; @@ -17,12 +17,15 @@ impl Decoder for Codec { type Error = Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if src.remaining() == 0 { - return Ok(None); + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + // NOTE: not enough bytes to construct packet, reserve enough in src buffer + Err(Error::InsufficientBytes(b)) => { + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), } - - let packet = Packet::read(src, self.max_incoming_size)?; - Ok(Some(packet)) } } diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs index fc24105c..2bd3272b 100644 --- a/rumqttc/src/v5/mqttbytes/v5/codec.rs +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -1,4 +1,4 @@ -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use tokio_util::codec::{Decoder, Encoder}; use super::{Error, Packet}; @@ -17,12 +17,15 @@ impl Decoder for Codec { type Error = Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if src.remaining() == 0 { - return Ok(None); + match Packet::read(src, self.max_incoming_size) { + Ok(packet) => Ok(Some(packet)), + // NOTE: not enough bytes to construct packet, reserve enough in src buffer + Err(Error::InsufficientBytes(b)) => { + src.reserve(b); + Ok(None) + } + Err(e) => Err(e), } - - let packet = Packet::read(src, self.max_incoming_size)?; - Ok(Some(packet)) } } From aa08d346d67483448e98186a0a8986ad4fdf4ec4 Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Thu, 21 Mar 2024 21:23:23 +0000 Subject: [PATCH 20/20] fix: buffer til capacity breach --- rumqttc/src/eventloop.rs | 3 --- rumqttc/src/framed.rs | 28 +++++++++++++++++++++++----- rumqttc/src/v5/eventloop.rs | 2 -- rumqttc/src/v5/framed.rs | 29 ++++++++++++++++++++++++----- 4 files changed, 47 insertions(+), 15 deletions(-) diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index 604f0ff3..5317e437 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -241,8 +241,6 @@ impl EventLoop { if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { network.write(outgoing).await?; } - // NOTE: Pings should be sent instantly - network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } @@ -488,7 +486,6 @@ async fn mqtt_connect( // send mqtt connect packet network.write(Packet::Connect(connect)).await?; - network.flush().await?; // validate connack match network.read().await? { diff --git a/rumqttc/src/framed.rs b/rumqttc/src/framed.rs index d713f6bc..336de5e5 100644 --- a/rumqttc/src/framed.rs +++ b/rumqttc/src/framed.rs @@ -15,6 +15,8 @@ pub struct Network { framed: Framed, Codec>, /// Time within which network write operations should complete timeout: Duration, + /// Capacity upto which buffering is good + buffer_capacity: usize, } impl Network { @@ -32,7 +34,11 @@ impl Network { }; let framed = Framed::with_capacity(socket, codec, buffer_capacity); - Network { framed, timeout } + Network { + framed, + timeout, + buffer_capacity, + } } pub async fn read(&mut self) -> Result { @@ -44,12 +50,24 @@ impl Network { } } - /// Write packets into buffer, flush instantly for `PingReq`/`PingResp` + /// Write packets into buffer, flushes `Connect`/`PingReq`/`PingResp` packets instantly, + /// or on breaching buffer capacity pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { - match timeout(self.timeout, self.framed.send(packet)).await { - Ok(inner) => inner.map_err(StateError::Deserialization), - Err(_) => Err(StateError::FlushTimeout), + let packet_size = packet.size(); + let should_flush = match packet { + Packet::Connect(_) | Packet::PingReq | Packet::PingResp => true, + _ => false, + }; + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization)?; + + if should_flush || self.framed.write_buffer().len() + packet_size >= self.buffer_capacity { + self.flush().await?; } + + Ok(()) } /// Force flush all packets in buffer, reset count diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index 9fe13dbe..c0ba2ee7 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -236,7 +236,6 @@ impl EventLoop { if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(outgoing).await?; } - network.flush().await?; Ok(self.state.events.pop_front().unwrap()) } @@ -442,7 +441,6 @@ async fn mqtt_connect( network .write(Packet::Connect(connect, last_will, None)) .await?; - network.flush().await?; // validate connack match network.read().await? { diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index 4cb8421a..e6796944 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -18,7 +18,10 @@ pub struct Network { framed: Framed, Codec>, /// Time within which network write operations should complete timeout: Duration, + /// Capacity upto which buffering is good + buffer_capacity: usize, } + impl Network { pub fn new( socket: impl AsyncReadWrite + 'static, @@ -34,7 +37,11 @@ impl Network { }; let framed = Framed::with_capacity(socket, codec, buffer_capacity); - Network { framed, timeout } + Network { + framed, + timeout, + buffer_capacity, + } } pub async fn read(&mut self) -> Result { @@ -46,12 +53,24 @@ impl Network { } } - /// Write packets into buffer, flush after `MAX_BUFFERED_PACKETS` + /// Write packets into buffer, flushes `Connect`/`PingReq`/`PingResp` packets instantly, + /// or on breaching buffer capacity pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { - match timeout(self.timeout, self.framed.send(packet)).await { - Ok(inner) => inner.map_err(StateError::Deserialization), - Err(e) => Err(StateError::Timeout(e)), + let packet_size = packet.size(); + let should_flush = match packet { + Packet::Connect(..) | Packet::PingReq(_) | Packet::PingResp(_) => true, + _ => false, + }; + self.framed + .feed(packet) + .await + .map_err(StateError::Deserialization)?; + + if should_flush || self.framed.write_buffer().len() + packet_size >= self.buffer_capacity { + self.flush().await?; } + + Ok(()) } /// Force flush all packets in buffer, reset count