diff --git a/Cargo.toml b/Cargo.toml index cc381a4..441c959 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ serde_json = "1" thiserror = "1" [dev-dependencies] +rand = "0.8" env_logger = "0.11" ntex-tls = "1.1" openssl = "0.10" diff --git a/src/io.rs b/src/io.rs index 2895e44..edd2917 100644 --- a/src/io.rs +++ b/src/io.rs @@ -21,13 +21,8 @@ pin_project_lite::pin_project! { U: Decoder, U: 'static, { - codec: U, - service: Pipeline, inner: DispatcherInner, pool: Pool, - #[pin] - response: Option>>, - response_idx: usize, } } @@ -42,9 +37,11 @@ bitflags::bitflags! { } } -struct DispatcherInner>, U: Encoder + Decoder> { +struct DispatcherInner>, U: Encoder + Decoder + 'static> { io: IoBoxed, flags: Flags, + codec: U, + service: Pipeline, st: IoDispatcherState, state: Rc>>, config: DispatcherConfig, @@ -52,6 +49,9 @@ struct DispatcherInner>, U: Encoder + Decoder> { read_remains_prev: u32, read_max_timeout: Seconds, keepalive_timeout: Seconds, + + response: Option>>, + response_idx: usize, } struct DispatcherState>, U: Encoder + Decoder> { @@ -78,6 +78,7 @@ impl ServiceResult { #[derive(Copy, Clone, Debug)] enum IoDispatcherState { Processing, + Backpressure, Stop, Shutdown, } @@ -124,13 +125,10 @@ where let keepalive_timeout = config.keepalive_timeout(); Dispatcher { - codec, pool, - service: Pipeline::new(service.into_service()), - response: None, - response_idx: 0, inner: DispatcherInner { io, + codec, state, keepalive_timeout, flags: if keepalive_timeout.is_zero() { @@ -138,8 +136,11 @@ where } else { Flags::empty() }, + service: Pipeline::new(service.into_service()), config: config.clone(), st: IoDispatcherState::Processing, + response: None, + response_idx: 0, read_remains: 0, read_remains_prev: 0, read_max_timeout: Seconds::ZERO, @@ -234,16 +235,16 @@ where let inner = &mut this.inner; // handle service response future - if let Some(fut) = this.response.as_mut().as_pin_mut() { - if let Poll::Ready(item) = fut.poll(cx) { + if let Some(fut) = inner.response.as_mut() { + if let Poll::Ready(item) = Pin::new(fut).poll(cx) { inner.state.borrow_mut().handle_result( item, - *this.response_idx, + inner.response_idx, inner.io.as_ref(), - this.codec, + &inner.codec, false, ); - this.response.set(None); + inner.response = None; } } @@ -258,10 +259,10 @@ where loop { match inner.st { IoDispatcherState::Processing => { - let item = match ready!(inner.poll_service(this.service, cx)) { + let item = match ready!(inner.poll_service(cx)) { PollService::Ready => { // decode incoming bytes stream - match inner.io.poll_recv_decode(this.codec, cx) { + match inner.io.poll_recv_decode(&inner.codec, cx) { Ok(decoded) => { inner.update_timer(&decoded); if let Some(el) = decoded.item { @@ -287,12 +288,8 @@ where } } Err(RecvError::WriteBackpressure) => { - if let Err(err) = ready!(inner.io.poll_flush(cx, false)) { - inner.st = IoDispatcherState::Stop; - DispatchItem::Disconnect(Some(err)) - } else { - continue; - } + inner.st = IoDispatcherState::Backpressure; + DispatchItem::WBackPressureEnabled } Err(RecvError::Decoder(err)) => { inner.st = IoDispatcherState::Stop; @@ -308,65 +305,35 @@ where PollService::Continue => continue, }; - // optimize first call - if this.response.is_none() { - this.response.set(Some(this.service.call_static(item))); - - let res = this.response.as_mut().as_pin_mut().unwrap().poll(cx); - let mut state = inner.state.borrow_mut(); - - if let Poll::Ready(res) = res { - // check if current result is only response - if state.queue.is_empty() { - match res { - Err(err) => { - state.error = Some(err.into()); - } - Ok(Some(item)) => { - if let Err(err) = inner.io.encode(item, this.codec) { - state.error = Some(IoDispatcherError::Encoder(err)); - } - } - Ok(None) => (), - } - } else { - *this.response_idx = state.base.wrapping_add(state.queue.len()); - state.queue.push_back(ServiceResult::Ready(res)); - } - this.response.set(None); - } else { - *this.response_idx = state.base.wrapping_add(state.queue.len()); - state.queue.push_back(ServiceResult::Pending); + inner.call_service(cx, item); + } + // handle write back-pressure + IoDispatcherState::Backpressure => { + match ready!(inner.poll_service(cx)) { + PollService::Ready => (), + PollService::Item(item) => { + inner.call_service(cx, item); } + PollService::Continue => continue, + }; + + let item = if let Err(err) = ready!(inner.io.poll_flush(cx, false)) { + inner.st = IoDispatcherState::Stop; + DispatchItem::Disconnect(Some(err)) } else { - let mut state = inner.state.borrow_mut(); - let response_idx = state.base.wrapping_add(state.queue.len()); - state.queue.push_back(ServiceResult::Pending); - - let st = inner.io.get_ref(); - let codec = this.codec.clone(); - let state = inner.state.clone(); - let fut = this.service.call_static(item); - #[allow(clippy::let_underscore_future)] - let _ = ntex::rt::spawn(async move { - let item = fut.await; - state.borrow_mut().handle_result( - item, - response_idx, - &st, - &codec, - true, - ); - }); - } + inner.st = IoDispatcherState::Processing; + DispatchItem::WBackPressureDisabled + }; + inner.call_service(cx, item); } + // drain service responses and shutdown io IoDispatcherState::Stop => { inner.io.stop_timer(); // service may relay on poll_ready for response results if !inner.flags.contains(Flags::READY_ERR) { - let _ = this.service.poll_ready(cx); + let _ = inner.service.poll_ready(cx); } if inner.state.borrow().queue.is_empty() { @@ -397,7 +364,7 @@ where } // shutdown service IoDispatcherState::Shutdown => { - return if this.service.poll_shutdown(cx).is_ready() { + return if inner.service.poll_shutdown(cx).is_ready() { log::trace!("{}: Service shutdown is completed, stop", inner.io.tag()); Poll::Ready( @@ -424,12 +391,53 @@ where U: Decoder + Encoder + Clone + 'static, ::Item: 'static, { - fn poll_service( - &mut self, - srv: &Pipeline, - cx: &mut Context<'_>, - ) -> Poll> { - match srv.poll_ready(cx) { + fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem) { + let mut state = self.state.borrow_mut(); + let mut fut = self.service.call_static(item); + + // optimize first call + if self.response.is_none() { + if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) { + // check if current result is only response + if state.queue.is_empty() { + match res { + Err(err) => { + state.error = Some(err.into()); + } + Ok(Some(item)) => { + if let Err(err) = self.io.encode(item, &self.codec) { + state.error = Some(IoDispatcherError::Encoder(err)); + } + } + Ok(None) => (), + } + } else { + self.response_idx = state.base.wrapping_add(state.queue.len()); + state.queue.push_back(ServiceResult::Ready(res)); + } + } else { + self.response = Some(fut); + self.response_idx = state.base.wrapping_add(state.queue.len()); + state.queue.push_back(ServiceResult::Pending); + } + } else { + let response_idx = state.base.wrapping_add(state.queue.len()); + state.queue.push_back(ServiceResult::Pending); + + let st = self.io.get_ref(); + let codec = self.codec.clone(); + let state = self.state.clone(); + + #[allow(clippy::let_underscore_future)] + let _ = ntex::rt::spawn(async move { + let item = fut.await; + state.borrow_mut().handle_result(item, response_idx, &st, &codec, true); + }); + } + } + + fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { // check for errors let mut state = self.state.borrow_mut(); @@ -483,7 +491,10 @@ where self.st = IoDispatcherState::Stop; Poll::Ready(PollService::Item(DispatchItem::Disconnect(err))) } - IoStatusUpdate::WriteBackpressure => Poll::Pending, + IoStatusUpdate::WriteBackpressure => { + self.st = IoDispatcherState::Backpressure; + Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled)) + } } } // handle service readiness error @@ -577,6 +588,7 @@ mod tests { use ntex::time::{sleep, Millis}; use ntex::util::{Bytes, BytesMut}; use ntex::{codec::BytesCodec, io as nio, service::ServiceCtx, testing::Io}; + use rand::Rng; use super::*; @@ -614,15 +626,15 @@ mod tests { ( Dispatcher { - codec, - service: Pipeline::new(service.into_service()), - response: None, - response_idx: 0, pool: io.memory_pool().pool(), inner: DispatcherInner { + codec, state, config, keepalive_timeout, + service: Pipeline::new(service.into_service()), + response: None, + response_idx: 0, io: IoBoxed::from(io), st: IoDispatcherState::Processing, flags: if keepalive_timeout.is_zero() { @@ -832,6 +844,75 @@ mod tests { assert_eq!(counter.get(), 1); } + #[ntex::test] + async fn test_write_backpressure() { + let (client, server) = Io::create(); + // do not allow to write to socket + client.remote_buffer_cap(0); + client.write("GET /test HTTP/1\r\n\r\n"); + + let data = Arc::new(Mutex::new(RefCell::new(Vec::new()))); + let data2 = data.clone(); + + let (disp, io) = Dispatcher::new_debug( + nio::Io::new(server), + BytesCodec, + ntex::service::fn_service(move |msg: DispatchItem| { + let data = data2.clone(); + async move { + match msg { + DispatchItem::Item(_) => { + data.lock().unwrap().borrow_mut().push(0); + let bytes = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(65_536) + .map(char::from) + .collect::(); + return Ok::<_, ()>(Some(Bytes::from(bytes))); + } + DispatchItem::WBackPressureEnabled => { + data.lock().unwrap().borrow_mut().push(1); + } + DispatchItem::WBackPressureDisabled => { + data.lock().unwrap().borrow_mut().push(2); + } + _ => (), + } + Ok(None) + } + }), + ); + let pool = io.memory_pool().pool().pool_ref(); + pool.set_read_params(8 * 1024, 1024); + pool.set_write_params(16 * 1024, 1024); + + ntex::rt::spawn(async move { + let _ = disp.await; + }); + + let buf = client.read_any(); + assert_eq!(buf, Bytes::from_static(b"")); + client.write("GET /test HTTP/1\r\n\r\n"); + sleep(Millis(25)).await; + + // buf must be consumed + assert_eq!(client.remote_buffer(|buf| buf.len()), 0); + + // response message + assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 65536); + + client.remote_buffer_cap(10240); + sleep(Millis(50)).await; + assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 55296); + + client.remote_buffer_cap(45056); + sleep(Millis(50)).await; + assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 10240); + + // backpressure disabled + assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]); + } + #[ntex::test] async fn test_shutdown_dispatcher_waker() { let (client, server) = Io::create(); diff --git a/src/v3/control.rs b/src/v3/control.rs index 598f885..0bac06b 100644 --- a/src/v3/control.rs +++ b/src/v3/control.rs @@ -16,6 +16,8 @@ pub enum Control { Subscribe(Subscribe), /// Unsubscribe packet Unsubscribe(Unsubscribe), + /// Write back-pressure is enabled/disabled + WrBackpressure(WrBackpressure), /// Connection dropped Closed(Closed), /// Service level error @@ -67,10 +69,14 @@ impl Control { Control::Disconnect(Disconnect) } - pub(super) fn closed() -> Self { + pub(super) const fn closed() -> Self { Control::Closed(Closed) } + pub(super) const fn wr_backpressure(enabled: bool) -> Self { + Control::WrBackpressure(WrBackpressure(enabled)) + } + pub(super) fn error(err: E) -> Self { Control::Error(Error::new(err)) } @@ -102,6 +108,7 @@ impl Control { log::warn!("Unsubscribe is not supported"); ControlAck { result: ControlAckKind::Disconnect } } + Control::WrBackpressure(msg) => msg.ack(), Control::Closed(msg) => msg.ack(), Control::Error(msg) => msg.ack(), Control::ProtocolError(msg) => msg.ack(), @@ -369,6 +376,24 @@ impl Unsubscribe { } } +/// Write back-pressure message +#[derive(Debug)] +pub struct WrBackpressure(bool); + +impl WrBackpressure { + #[inline] + /// Is write back-pressure enabled + pub fn enabled(&self) -> bool { + self.0 + } + + #[inline] + /// convert packet to a result + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Nothing } + } +} + /// Connection closed message #[derive(Debug)] pub struct Closed; diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index a8fed42..43c257f 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -347,11 +347,11 @@ where } DispatchItem::WBackPressureEnabled => { self.inner.sink.enable_wr_backpressure(); - Ok(None) + control(Control::wr_backpressure(true), &self.inner, ctx).await } DispatchItem::WBackPressureDisabled => { self.inner.sink.disable_wr_backpressure(); - Ok(None) + control(Control::wr_backpressure(false), &self.inner, ctx).await } } } diff --git a/src/v5/control.rs b/src/v5/control.rs index c907385..2bad11c 100644 --- a/src/v5/control.rs +++ b/src/v5/control.rs @@ -19,6 +19,8 @@ pub enum Control { Subscribe(Subscribe), /// Unsubscribe packet from a client Unsubscribe(Unsubscribe), + /// Write back-pressure is enabled/disabled + WrBackpressure(WrBackpressure), /// Underlying transport connection closed Closed(Closed), /// Unhandled application level error from handshake, publish and control services @@ -71,6 +73,10 @@ impl Control { Control::Closed(Closed) } + pub(super) const fn wr_backpressure(enabled: bool) -> Self { + Control::WrBackpressure(WrBackpressure(enabled)) + } + pub(super) fn error(err: E) -> Self { Control::Error(Error::new(err)) } @@ -110,6 +116,7 @@ impl Control { Control::Disconnect(msg) => msg.ack(), Control::Subscribe(msg) => msg.ack(), Control::Unsubscribe(msg) => msg.ack(), + Control::WrBackpressure(msg) => msg.ack(), Control::Closed(msg) => msg.ack(), Control::Error(_) => super::disconnect("Error control message is not supported"), Control::ProtocolError(msg) => msg.ack(), @@ -471,6 +478,24 @@ impl<'a> UnsubscribeItem<'a> { } } +/// Write back-pressure message +#[derive(Debug)] +pub struct WrBackpressure(bool); + +impl WrBackpressure { + #[inline] + /// Is write back-pressure enabled + pub fn enabled(&self) -> bool { + self.0 + } + + #[inline] + /// convert packet to a result + pub fn ack(self) -> ControlAck { + ControlAck { packet: None, disconnect: false } + } +} + /// Connection closed message #[derive(Debug)] pub struct Closed; diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index 138211b..1500062 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -476,11 +476,11 @@ where } DispatchItem::WBackPressureEnabled => { self.inner.sink.enable_wr_backpressure(); - Ok(None) + control(Control::wr_backpressure(true), &self.inner, ctx, 0).await } DispatchItem::WBackPressureDisabled => { self.inner.sink.disable_wr_backpressure(); - Ok(None) + control(Control::wr_backpressure(false), &self.inner, ctx, 0).await } } }