Skip to content

Commit

Permalink
Stream timeout (#117)
Browse files Browse the repository at this point in the history
* add conn timeout

* manual emit error

* log

* manual timeout
  • Loading branch information
ibigbug authored Oct 15, 2023
1 parent 8bcd902 commit bc12752
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 11 deletions.
28 changes: 17 additions & 11 deletions clash_lib/src/app/dispatcher/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::app::dispatcher::tracked::TrackedDatagram;
use crate::app::dispatcher::tracked::TrackedStream;
use crate::app::outbound::manager::ThreadSafeOutboundManager;
use crate::app::router::ThreadSafeRouter;
use crate::common::io::copy_buf_bidirectional_with_timeout;
use crate::config::def::RunMode;
use crate::config::internal::proxy::PROXY_DIRECT;
use crate::config::internal::proxy::PROXY_GLOBAL;
Expand All @@ -17,7 +18,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::info_span;
Expand Down Expand Up @@ -131,16 +132,21 @@ impl Dispatcher {
{
Ok(rhs) => {
debug!("remote connection established {}", sess);
let mut rhs = Box::new(
TrackedStream::new(rhs, self.manager.clone(), sess.clone(), rule).await,
);
match copy_bidirectional(&mut lhs, &mut rhs)
.instrument(info_span!(
"copy_bidirectional",
outbound_name = outbound_name,
session = %sess,
))
.await
let mut rhs =
TrackedStream::new(rhs, self.manager.clone(), sess.clone(), rule).await;
match copy_buf_bidirectional_with_timeout(
&mut lhs,
&mut rhs,
4096,
Duration::from_secs(10),
Duration::from_secs(10),
)
.instrument(info_span!(
"copy_bidirectional",
outbound_name = outbound_name,
session = %sess,
))
.await
{
Ok((up, down)) => {
debug!(
Expand Down
1 change: 1 addition & 0 deletions clash_lib/src/app/dispatcher/tracked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ impl AsyncRead for TrackedStream {
self.tracker
.download_total
.fetch_add(download as u64, std::sync::atomic::Ordering::Release);

v
}
}
Expand Down
291 changes: 291 additions & 0 deletions clash_lib/src/common/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
/// copy of https://github.com/eycorsican/leaf/blob/a77a1e497ae034f3a2a89c8628d5e7ebb2af47f0/leaf/src/common/io.rs
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use futures::ready;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[derive(Debug)]
pub struct CopyBuffer {
read_done: bool,
need_flush: bool,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}

impl CopyBuffer {
#[allow(unused)]
pub fn new() -> Self {
Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; 2 * 1024].into_boxed_slice(),
}
}

pub fn new_with_capacity(size: usize) -> Result<Self, std::io::Error> {
let mut buf = Vec::new();
buf.try_reserve(size).map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("new buffer failed: {}", e),
)
})?;
buf.resize(size, 0);
Ok(Self {
read_done: false,
need_flush: false,
pos: 0,
cap: 0,
amt: 0,
buf: buf.into_boxed_slice(),
})
}

pub fn amount_transfered(&self) -> u64 {
self.amt
}

pub fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);

match reader.as_mut().poll_read(cx, &mut buf) {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer.
if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?;
self.need_flush = false;
}

return Poll::Pending;
}
}

let n = buf.filled().len();
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.cap = n;
}
}

// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"write zero byte into writer",
)));
} else {
self.pos += i;
self.amt += i as u64;
self.need_flush = true;
}
}

// If pos larger than cap, this loop will never stop.
// In particular, user's wrong poll_write implementation returning
// incorrect written length may lead to thread blocking.
debug_assert!(
self.pos <= self.cap,
"writer returned length larger than input slice"
);

// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}

enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done,
}

struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
a: &'a mut A,
b: &'a mut B,
a_to_b: TransferState,
b_to_a: TransferState,
a_to_b_count: u64,
b_to_a_count: u64,
a_to_b_delay: Option<Pin<Box<tokio::time::Sleep>>>,
b_to_a_delay: Option<Pin<Box<tokio::time::Sleep>>>,
a_to_b_timeout_duration: Duration,
b_to_a_timeout_duration: Duration,
}

impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Unpack self into mut refs to each field to avoid borrow check issues.
let CopyBidirectional {
a,
b,
a_to_b,
b_to_a,
a_to_b_count,
b_to_a_count,
a_to_b_delay,
b_to_a_delay,
a_to_b_timeout_duration,
b_to_a_timeout_duration,
} = &mut *self;

let mut a = Pin::new(a);
let mut b = Pin::new(b);

loop {
match a_to_b {
TransferState::Running(buf) => {
let res = buf.poll_copy(cx, a.as_mut(), b.as_mut());
match res {
Poll::Ready(Ok(count)) => {
*a_to_b = TransferState::ShuttingDown(count);
continue;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
if let Some(delay) = a_to_b_delay {
match delay.as_mut().poll(cx) {
Poll::Ready(()) => {
*a_to_b =
TransferState::ShuttingDown(buf.amount_transfered());
continue;
}
Poll::Pending => (),
}
}
}
}
}
TransferState::ShuttingDown(count) => {
let res = b.as_mut().poll_shutdown(cx);
match res {
Poll::Ready(Ok(())) => {
*a_to_b_count += *count;
*a_to_b = TransferState::Done;
b_to_a_delay
.replace(Box::pin(tokio::time::sleep(*b_to_a_timeout_duration)));
continue;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => (),
}
}
TransferState::Done => (),
}

match b_to_a {
TransferState::Running(buf) => {
let res = buf.poll_copy(cx, b.as_mut(), a.as_mut());
match res {
Poll::Ready(Ok(count)) => {
*b_to_a = TransferState::ShuttingDown(count);
continue;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => {
if let Some(delay) = b_to_a_delay {
match delay.as_mut().poll(cx) {
Poll::Ready(()) => {
*b_to_a =
TransferState::ShuttingDown(buf.amount_transfered());
continue;
}
Poll::Pending => (),
}
}
}
}
}
TransferState::ShuttingDown(count) => {
let res = a.as_mut().poll_shutdown(cx);
match res {
Poll::Ready(Ok(())) => {
*b_to_a_count += *count;
*b_to_a = TransferState::Done;
a_to_b_delay
.replace(Box::pin(tokio::time::sleep(*a_to_b_timeout_duration)));
continue;
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => (),
}
}
TransferState::Done => (),
}

match (&a_to_b, &b_to_a) {
(TransferState::Done, TransferState::Done) => break,
_ => return Poll::Pending,
}
}

Poll::Ready(Ok((*a_to_b_count, *b_to_a_count)))
}
}

pub async fn copy_buf_bidirectional_with_timeout<A, B>(
a: &mut A,
b: &mut B,
size: usize,
a_to_b_timeout_duration: Duration,
b_to_a_timeout_duration: Duration,
) -> Result<(u64, u64), std::io::Error>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
CopyBidirectional {
a,
b,
a_to_b: TransferState::Running(CopyBuffer::new_with_capacity(size)?),
b_to_a: TransferState::Running(CopyBuffer::new_with_capacity(size)?),
a_to_b_count: 0,
b_to_a_count: 0,
a_to_b_delay: None,
b_to_a_delay: None,
a_to_b_timeout_duration,
b_to_a_timeout_duration,
}
.await
}
1 change: 1 addition & 0 deletions clash_lib/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod auth;
pub mod crypto;
pub mod errors;
pub mod http;
pub mod io;
pub mod mmdb;
pub mod timed_future;
pub mod tls;
Expand Down

0 comments on commit bc12752

Please sign in to comment.