Skip to content

Commit

Permalink
refactor: use Arc<TuicConnection>
Browse files Browse the repository at this point in the history
  • Loading branch information
iHsin committed Mar 24, 2024
1 parent 18677c0 commit 264e969
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 21 deletions.
26 changes: 16 additions & 10 deletions clash_lib/src/proxy/tuic/handle_stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::sync::atomic::Ordering;
use std::sync::Arc;

use bytes::Bytes;
use quinn::{RecvStream, SendStream, VarInt};
Expand Down Expand Up @@ -46,7 +47,7 @@ impl TuicConnection {
Ok(self.conn.read_datagram().await?)
}

pub async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
pub async fn handle_uni_stream(self: Arc<Self>, recv: RecvStream, _reg: Register) {
tracing::debug!("[relay] incoming unidirectional stream");

let res = match self.inner.accept_uni_stream(recv).await {
Expand All @@ -66,20 +67,23 @@ impl TuicConnection {
}
}

pub async fn handle_bi_stream(self, send: SendStream, recv: RecvStream, _reg: Register) {
pub async fn handle_bi_stream(
self: Arc<Self>,
send: SendStream,
recv: RecvStream,
_reg: Register,
) {
tracing::debug!("[relay] incoming bidirectional stream");

let res = match self.inner.accept_bi_stream(send, recv).await {
Err(err) => Err::<(), _>(anyhow!(err)),
_ => unreachable!(), // already filtered in `tuic_quinn`
let err = match self.inner.accept_bi_stream(send, recv).await {
Err(err) => anyhow!(err),
_ => anyhow!("A client shouldn't receive bi stream"),
};

if let Err(err) = res {
tracing::warn!("[relay] incoming bidirectional stream error: {err}");
}
tracing::warn!("[relay] incoming bidirectional stream error: {err}");
}

pub async fn handle_datagram(self, dg: Bytes) {
pub async fn handle_datagram(self: Arc<Self>, dg: Bytes) {
tracing::debug!("[relay] incoming datagram");

let res = match self.inner.accept_datagram(dg) {
Expand All @@ -91,7 +95,9 @@ impl TuicConnection {
}
UdpRelayMode::Quic => Err(anyhow!("wrong packet source")),
},
_ => unreachable!(), // already filtered in `tuic_quinn`
_ => Err(anyhow!(
"Datagram shouldn't receive any data expect UDP packet"
)),
};

if let Err(err) = res {
Expand Down
6 changes: 3 additions & 3 deletions clash_lib/src/proxy/tuic/handle_task.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Duration;
use std::{sync::Arc, time::Duration};

use bytes::Bytes;
use quinn::ZeroRttAccepted;
Expand All @@ -13,7 +13,7 @@ use crate::session::SocksAddr as ClashSocksAddr;
use super::types::{TuicConnection, UdpRelayMode};

impl TuicConnection {
pub async fn tuic_auth(self, zero_rtt_accepted: Option<ZeroRttAccepted>) {
pub async fn tuic_auth(self: Arc<Self>, zero_rtt_accepted: Option<ZeroRttAccepted>) {
if let Some(zero_rtt_accepted) = zero_rtt_accepted {
tracing::debug!("[auth] waiting for connection to be fully established");
zero_rtt_accepted.await;
Expand Down Expand Up @@ -156,7 +156,7 @@ impl TuicConnection {
/// Tasks triggered by timer
/// Won't return unless occurs error
pub async fn cyclical_tasks(
self,
self: Arc<Self>,
heartbeat_interval: Duration,
gc_interval: Duration,
gc_lifetime: Duration,
Expand Down
7 changes: 3 additions & 4 deletions clash_lib/src/proxy/tuic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub struct HandlerOptions {
pub struct Handler {
opts: HandlerOptions,
ep: TuicEndpoint,
conn: AsyncMutex<Option<TuicConnection>>,
conn: AsyncMutex<Option<Arc<TuicConnection>>>,
next_assoc_id: AtomicU16,
}

Expand Down Expand Up @@ -188,7 +188,7 @@ impl Handler {
next_assoc_id: AtomicU16::new(0),
}))
}
async fn get_conn(&self) -> Result<TuicConnection> {
async fn get_conn(&self) -> Result<Arc<TuicConnection>> {
let fut = async {
let mut guard = self.conn.lock().await;
if guard.is_none() {
Expand All @@ -202,7 +202,6 @@ impl Handler {
} else {
conn
};
// TODO TuicConnection is huge, is it necessary to clone it? If it is, should we use Arc ?
*guard = Some(conn.clone());
Ok(conn)
};
Expand Down Expand Up @@ -248,7 +247,7 @@ struct TuicDatagramOutbound {
impl TuicDatagramOutbound {
pub fn new(
assoc_id: u16,
conn: TuicConnection,
conn: Arc<TuicConnection>,
local_addr: ClashSocksAddr,
) -> AnyOutboundDatagram {
// TODO not sure about the size of buffer
Expand Down
8 changes: 4 additions & 4 deletions clash_lib/src/proxy/tuic/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct TuicEndpoint {
pub gc_lifetime: Duration,
}
impl TuicEndpoint {
pub async fn connect(&self) -> Result<TuicConnection> {
pub async fn connect(&self) -> Result<Arc<TuicConnection>> {
let mut last_err = None;

for addr in self.server.resolve().await? {
Expand Down Expand Up @@ -121,7 +121,7 @@ impl TuicConnection {
heartbeat: Duration,
gc_interval: Duration,
gc_lifetime: Duration,
) -> Self {
) -> Arc<Self> {
let conn = Self {
conn: conn.clone(),
inner: InnerConnection::<tuic_quinn::side::Client>::new(conn),
Expand All @@ -135,7 +135,7 @@ impl TuicConnection {
max_concurrent_bi_streams: Arc::new(AtomicU32::new(32)),
udp_sessions: Arc::new(AsyncRwLock::new(HashMap::new())),
};

let conn = Arc::new(conn);
tokio::spawn(
conn.clone()
.init(zero_rtt_accepted, heartbeat, gc_interval, gc_lifetime),
Expand All @@ -144,7 +144,7 @@ impl TuicConnection {
conn
}
async fn init(
self,
self: Arc<Self>,
zero_rtt_accepted: Option<ZeroRttAccepted>,
heartbeat: Duration,
gc_interval: Duration,
Expand Down

0 comments on commit 264e969

Please sign in to comment.