Skip to content

Commit

Permalink
add write back-pressure support
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed May 12, 2024
1 parent 184fc53 commit a37e9dc
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 90 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ serde_json = "1"
thiserror = "1"

[dev-dependencies]
rand = "0.8"
env_logger = "0.11"
ntex-tls = "1.1"
openssl = "0.10"
Expand Down
252 changes: 167 additions & 85 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,8 @@ pin_project_lite::pin_project! {
U: Decoder,
U: 'static,
{
codec: U,
service: Pipeline<S>,
inner: DispatcherInner<S, U>,
pool: Pool,
#[pin]
response: Option<PipelineCall<S, DispatchItem<U>>>,
response_idx: usize,
}
}

Expand All @@ -42,16 +37,22 @@ bitflags::bitflags! {
}
}

struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder> {
struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'static> {
io: IoBoxed,
flags: Flags,
codec: U,
service: Pipeline<S>,
st: IoDispatcherState,
state: Rc<RefCell<DispatcherState<S, U>>>,
config: DispatcherConfig,
read_remains: u32,
read_remains_prev: u32,
read_max_timeout: Seconds,
keepalive_timeout: Seconds,

//#[pin]
response: Option<PipelineCall<S, DispatchItem<U>>>,
response_idx: usize,
}

struct DispatcherState<S: Service<DispatchItem<U>>, U: Encoder + Decoder> {
Expand All @@ -78,6 +79,7 @@ impl<T> ServiceResult<T> {
/// Lifecycle state of the io dispatcher loop.
#[derive(Copy, Clone, Debug)]
enum IoDispatcherState {
    /// Normal operation: decode incoming frames and dispatch them to the service.
    Processing,
    /// The write buffer is full; pause reading and flush until it drains,
    /// then emit `WBackPressureDisabled` and return to `Processing`.
    Backpressure,
    /// Stopping: drain queued service responses, then shut the io down.
    Stop,
    /// Final stage: wait for the service's own shutdown to complete.
    Shutdown,
}
Expand Down Expand Up @@ -123,17 +125,17 @@ where
let pool = io.memory_pool().pool();

Dispatcher {
codec,
pool,
service: Pipeline::new(service.into_service()),
response: None,
response_idx: 0,
inner: DispatcherInner {
io,
codec,
state,
service: Pipeline::new(service.into_service()),
flags: Flags::empty(),
config: config.clone(),
st: IoDispatcherState::Processing,
response: None,
response_idx: 0,
read_remains: 0,
read_remains_prev: 0,
read_max_timeout: Seconds::ZERO,
Expand Down Expand Up @@ -229,16 +231,16 @@ where
let inner = &mut this.inner;

// handle service response future
if let Some(fut) = this.response.as_mut().as_pin_mut() {
if let Poll::Ready(item) = fut.poll(cx) {
if let Some(fut) = inner.response.as_mut() {
if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
inner.state.borrow_mut().handle_result(
item,
*this.response_idx,
inner.response_idx,
inner.io.as_ref(),
this.codec,
&inner.codec,
false,
);
this.response.set(None);
inner.response = None;
}
}

Expand All @@ -253,10 +255,10 @@ where
loop {
match inner.st {
IoDispatcherState::Processing => {
let item = match ready!(inner.poll_service(this.service, cx)) {
let item = match ready!(inner.poll_service(cx)) {
PollService::Ready => {
// decode incoming bytes stream
match inner.io.poll_recv_decode(this.codec, cx) {
match inner.io.poll_recv_decode(&inner.codec, cx) {
Ok(decoded) => {
inner.update_timer(&decoded);
if let Some(el) = decoded.item {
Expand All @@ -282,12 +284,8 @@ where
}
}
Err(RecvError::WriteBackpressure) => {
if let Err(err) = ready!(inner.io.poll_flush(cx, false)) {
inner.st = IoDispatcherState::Stop;
DispatchItem::Disconnect(Some(err))
} else {
continue;
}
inner.st = IoDispatcherState::Backpressure;
DispatchItem::WBackPressureEnabled
}
Err(RecvError::Decoder(err)) => {
inner.st = IoDispatcherState::Stop;
Expand All @@ -303,65 +301,35 @@ where
PollService::Continue => continue,
};

// optimize first call
if this.response.is_none() {
this.response.set(Some(this.service.call_static(item)));

let res = this.response.as_mut().as_pin_mut().unwrap().poll(cx);
let mut state = inner.state.borrow_mut();

if let Poll::Ready(res) = res {
// check if current result is only response
if state.queue.is_empty() {
match res {
Err(err) => {
state.error = Some(err.into());
}
Ok(Some(item)) => {
if let Err(err) = inner.io.encode(item, this.codec) {
state.error = Some(IoDispatcherError::Encoder(err));
}
}
Ok(None) => (),
}
} else {
*this.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Ready(res));
}
this.response.set(None);
} else {
*this.response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);
inner.call_service(cx, item);
}
// handle write back-pressure
IoDispatcherState::Backpressure => {
match ready!(inner.poll_service(cx)) {
PollService::Ready => (),
PollService::Item(item) => {
inner.call_service(cx, item);
}
PollService::Continue => continue,
};

let item = if let Err(err) = ready!(inner.io.poll_flush(cx, false)) {
inner.st = IoDispatcherState::Stop;
DispatchItem::Disconnect(Some(err))
} else {
let mut state = inner.state.borrow_mut();
let response_idx = state.base.wrapping_add(state.queue.len());
state.queue.push_back(ServiceResult::Pending);

let st = inner.io.get_ref();
let codec = this.codec.clone();
let state = inner.state.clone();
let fut = this.service.call_static(item);
#[allow(clippy::let_underscore_future)]
let _ = ntex::rt::spawn(async move {
let item = fut.await;
state.borrow_mut().handle_result(
item,
response_idx,
&st,
&codec,
true,
);
});
}
inner.st = IoDispatcherState::Processing;
DispatchItem::WBackPressureDisabled
};
inner.call_service(cx, item);
}

// drain service responses and shutdown io
IoDispatcherState::Stop => {
inner.io.stop_timer();

// service may rely on poll_ready for response results
if !inner.flags.contains(Flags::READY_ERR) {
let _ = this.service.poll_ready(cx);
let _ = inner.service.poll_ready(cx);
}

if inner.state.borrow().queue.is_empty() {
Expand Down Expand Up @@ -392,7 +360,7 @@ where
}
// shutdown service
IoDispatcherState::Shutdown => {
return if this.service.poll_shutdown(cx).is_ready() {
return if inner.service.poll_shutdown(cx).is_ready() {
log::trace!("{}: Service shutdown is completed, stop", inner.io.tag());

Poll::Ready(
Expand All @@ -419,12 +387,53 @@ where
U: Decoder + Encoder + Clone + 'static,
<U as Encoder>::Item: 'static,
{
fn poll_service(
&mut self,
srv: &Pipeline<S>,
cx: &mut Context<'_>,
) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
/// Call the service with `item` and record where its response belongs in the
/// ordered response queue.
///
/// Fast path: when no response future is already pending, the new future is
/// polled once inline. A future that is immediately ready while nothing else
/// is queued can be encoded straight to the io, bypassing the queue entirely.
/// Otherwise the future is either kept in `self.response` (to be driven by
/// the dispatcher's poll loop) or spawned as a background task that reports
/// its result into the shared `DispatcherState` at a reserved queue slot.
fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
    let mut state = self.state.borrow_mut();
    let mut fut = self.service.call_static(item);

    // optimize first call
    if self.response.is_none() {
        if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) {
            // check if current result is only response
            if state.queue.is_empty() {
                match res {
                    Err(err) => {
                        state.error = Some(err.into());
                    }
                    Ok(Some(item)) => {
                        // nothing else is in flight, so the response can be
                        // encoded directly without queue bookkeeping
                        if let Err(err) = self.io.encode(item, &self.codec) {
                            state.error = Some(IoDispatcherError::Encoder(err));
                        }
                    }
                    Ok(None) => (),
                }
            } else {
                // earlier responses are still queued: record this ready
                // result at its slot so responses keep their order
                self.response_idx = state.base.wrapping_add(state.queue.len());
                state.queue.push_back(ServiceResult::Ready(res));
            }
        } else {
            // future is pending; store it so the dispatcher's poll loop
            // keeps driving it, and reserve its slot in the queue
            self.response = Some(fut);
            self.response_idx = state.base.wrapping_add(state.queue.len());
            state.queue.push_back(ServiceResult::Pending);
        }
    } else {
        // a response future is already being driven inline; run this call
        // on a spawned task that writes its result into the shared state
        // at the queue slot reserved here
        let response_idx = state.base.wrapping_add(state.queue.len());
        state.queue.push_back(ServiceResult::Pending);

        let st = self.io.get_ref();
        let codec = self.codec.clone();
        let state = self.state.clone();

        #[allow(clippy::let_underscore_future)]
        let _ = ntex::rt::spawn(async move {
            let item = fut.await;
            state.borrow_mut().handle_result(item, response_idx, &st, &codec, true);
        });
    }
}

fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
match self.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
// check for errors
let mut state = self.state.borrow_mut();
Expand Down Expand Up @@ -478,7 +487,10 @@ where
self.st = IoDispatcherState::Stop;
Poll::Ready(PollService::Item(DispatchItem::Disconnect(err)))
}
IoStatusUpdate::WriteBackpressure => Poll::Pending,
IoStatusUpdate::WriteBackpressure => {
self.st = IoDispatcherState::Backpressure;
Poll::Ready(PollService::Item(DispatchItem::WBackPressureEnabled))
}
}
}
// handle service readiness error
Expand Down Expand Up @@ -571,6 +583,7 @@ mod tests {
use ntex::time::{sleep, Millis};
use ntex::util::Bytes;
use ntex::{codec::BytesCodec, io as nio, service::ServiceCtx, testing::Io};
use rand::Rng;

use super::*;

Expand Down Expand Up @@ -599,15 +612,15 @@ mod tests {

(
Dispatcher {
codec,
service: Pipeline::new(service.into_service()),
response: None,
response_idx: 0,
pool: io.memory_pool().pool(),
inner: DispatcherInner {
codec,
state,
config,
keepalive_timeout,
service: Pipeline::new(service.into_service()),
response: None,
response_idx: 0,
io: IoBoxed::from(io),
st: IoDispatcherState::Processing,
flags: Flags::KA_ENABLED,
Expand Down Expand Up @@ -813,6 +826,75 @@ mod tests {
assert_eq!(counter.get(), 1);
}

/// Exercise the write back-pressure flow end to end: the client refuses to
/// accept bytes, the service floods the write buffer with a 64 KiB response,
/// and the dispatcher must emit `WBackPressureEnabled` followed by
/// `WBackPressureDisabled` once the buffer drains.
#[ntex::test]
async fn test_write_backpressure() {
    let (client, server) = Io::create();
    // forbid writes to the client socket so the server's write buffer
    // cannot drain and back-pressure has to kick in
    client.remote_buffer_cap(0);
    client.write("GET /test HTTP/1\r\n\r\n");

    // records the order of observed dispatch events:
    // 0 = Item, 1 = WBackPressureEnabled, 2 = WBackPressureDisabled
    let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
    let data2 = data.clone();

    let (disp, io) = Dispatcher::new_debug(
        nio::Io::new(server),
        BytesCodec,
        ntex::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
            let data = data2.clone();
            async move {
                match msg {
                    DispatchItem::Item(_) => {
                        data.lock().unwrap().borrow_mut().push(0);
                        // respond with 64 KiB of random data so the write
                        // buffer exceeds the 16 KiB watermark set below
                        let bytes = rand::thread_rng()
                            .sample_iter(&rand::distributions::Alphanumeric)
                            .take(65_536)
                            .map(char::from)
                            .collect::<String>();
                        return Ok::<_, ()>(Some(Bytes::from(bytes)));
                    }
                    DispatchItem::WBackPressureEnabled => {
                        data.lock().unwrap().borrow_mut().push(1);
                    }
                    DispatchItem::WBackPressureDisabled => {
                        data.lock().unwrap().borrow_mut().push(2);
                    }
                    _ => (),
                }
                Ok(None)
            }
        }),
    );
    // shrink the buffer watermarks so back-pressure triggers quickly
    let pool = io.memory_pool().pool().pool_ref();
    pool.set_read_params(8 * 1024, 1024);
    pool.set_write_params(16 * 1024, 1024);

    ntex::rt::spawn(async move {
        let _ = disp.await;
    });

    let buf = client.read_any();
    assert_eq!(buf, Bytes::from_static(b""));
    client.write("GET /test HTTP/1\r\n\r\n");
    sleep(Millis(25)).await;

    // buf must be consumed
    assert_eq!(client.remote_buffer(|buf| buf.len()), 0);

    // response message: full 64 KiB is stuck in the server's write buffer
    assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 65536);

    // let the client accept 10 KiB: 65536 - 10240 = 55296 stays buffered
    client.remote_buffer_cap(10240);
    sleep(Millis(50)).await;
    assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 55296);

    // accept another 44 KiB: 55296 - 45056 = 10240 remains, which is below
    // the 16 KiB write watermark, so back-pressure gets released
    client.remote_buffer_cap(45056);
    sleep(Millis(50)).await;
    assert_eq!(io.with_write_buf(|buf| buf.len()).unwrap(), 10240);

    // backpressure disabled: events arrived in order Item, Enabled, Disabled
    assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
}

#[ntex::test]
async fn test_shutdown_dispatcher_waker() {
let (client, server) = Io::create();
Expand Down
Loading

0 comments on commit a37e9dc

Please sign in to comment.