From a1529e84af9226b9eaf365020978e2532ebe4ff1 Mon Sep 17 00:00:00 2001 From: iHsin Date: Tue, 9 Apr 2024 21:46:00 +0800 Subject: [PATCH] refactor: tuic --- clash_lib/src/proxy/converters/tuic.rs | 16 ++- clash_lib/src/proxy/tuic/handle_task.rs | 9 +- clash_lib/src/proxy/tuic/mod.rs | 65 +++++++---- clash_lib/src/proxy/tuic/types.rs | 137 +++++++++++++----------- 4 files changed, 139 insertions(+), 88 deletions(-) diff --git a/clash_lib/src/proxy/converters/tuic.rs b/clash_lib/src/proxy/converters/tuic.rs index 5dd43c546..9e5413c2a 100644 --- a/clash_lib/src/proxy/converters/tuic.rs +++ b/clash_lib/src/proxy/converters/tuic.rs @@ -1,4 +1,7 @@ -use std::time::Duration; +use std::{ + sync::{atomic::AtomicU32, Arc}, + time::Duration, +}; use quinn::VarInt; @@ -28,7 +31,12 @@ impl TryFrom<&OutboundTuic> for AnyOutboundHandler { port: s.port, uuid: s.uuid.to_owned(), password: s.password.to_owned(), - udp_relay_mode: s.udp_relay_mode.to_owned().unwrap_or("native".to_string()), + udp_relay_mode: s + .udp_relay_mode + .to_owned() + .unwrap_or("native".to_string()) + .as_str() + .into(), disable_sni: s.disable_sni.unwrap_or(false), alpn: s .alpn @@ -37,7 +45,8 @@ impl TryFrom<&OutboundTuic> for AnyOutboundHandler { .unwrap_or_default(), heartbeat_interval: Duration::from_millis(s.heartbeat_interval.unwrap_or(3000)), reduce_rtt: s.reduce_rtt.unwrap_or(false) || s.fast_open.unwrap_or(false), - request_timeout: Duration::from_millis(s.request_timeout.unwrap_or(8000)), + request_timeout: Duration::from_millis(s.request_timeout.unwrap_or(4000)), + idle_timeout: VarInt::from_u64(s.request_timeout.unwrap_or(4000)).unwrap_or(VarInt::MAX), congestion_controller: s .congestion_controller .clone() @@ -54,6 +63,7 @@ impl TryFrom<&OutboundTuic> for AnyOutboundHandler { send_window: s.send_window.unwrap_or(8 * 1024 * 1024 * 2), receive_window: VarInt::from_u64(s.receive_window.unwrap_or(8 * 1024 * 1024)) .unwrap_or(VarInt::MAX), + mark: Arc::new(AtomicU32::new(s.mark.unwrap_or(6969))), }) } } diff --git a/clash_lib/src/proxy/tuic/handle_task.rs b/clash_lib/src/proxy/tuic/handle_task.rs index 340dc027f..b717f784b 100644 --- a/clash_lib/src/proxy/tuic/handle_task.rs +++ b/clash_lib/src/proxy/tuic/handle_task.rs @@ -26,9 +26,12 @@ impl TuicConnection { .authenticate(self.uuid, self.password.clone()) .await { - Ok(()) => tracing::info!("[auth] {uuid}", uuid = self.uuid), + Ok(()) => tracing::info!("[auth] success {uuid}", uuid = self.uuid), Err(err) => { - tracing::warn!("[auth] authentication sending error: {err}") + tracing::warn!( + "[auth] authentication sending error: {:?}", + anyhow::anyhow!(err) + ) } } } @@ -141,7 +144,7 @@ impl TuicConnection { } match self.inner.heartbeat().await { - Ok(()) => tracing::trace!("[heartbeat]"), + Ok(()) => tracing::debug!("[heartbeat]"), Err(err) => tracing::error!("[heartbeat] {err}"), } Ok(()) diff --git a/clash_lib/src/proxy/tuic/mod.rs b/clash_lib/src/proxy/tuic/mod.rs index 23327c859..e22e78f47 100644 --- a/clash_lib/src/proxy/tuic/mod.rs +++ b/clash_lib/src/proxy/tuic/mod.rs @@ -6,8 +6,10 @@ pub(crate) mod types; use crate::proxy::tuic::types::SocketAdderTrans; use anyhow::Result; use axum::async_trait; +use quinn::congestion::{BbrConfig, NewRenoConfig}; use quinn::{EndpointConfig, TokioRuntime}; use std::net::SocketAddr; +use std::sync::atomic::AtomicU32; use std::{ net::{Ipv4Addr, Ipv6Addr, UdpSocket}, sync::{ @@ -41,7 +43,7 @@ use tokio::sync::Mutex as AsyncMutex; use rustls::client::ClientConfig as TlsConfig; -use self::types::{CongestionControl, TuicConnection, UdpSession}; +use self::types::{CongestionControl, TuicConnection, UdpRelayMode, UdpSession}; use super::{ datagram::UdpPacket, AnyOutboundDatagram, AnyOutboundHandler, AnyStream, OutboundHandler, @@ -55,12 +57,13 @@ pub struct HandlerOptions { pub port: u16, pub uuid: Uuid, pub password: String, - pub udp_relay_mode: String, + pub udp_relay_mode: UdpRelayMode, pub disable_sni: bool, pub alpn: Vec>, pub heartbeat_interval: Duration, pub reduce_rtt: bool, pub request_timeout: Duration, + pub idle_timeout: VarInt, pub congestion_controller: CongestionControl, pub max_udp_relay_packet_size: u64, pub max_open_stream: VarInt, @@ -68,6 +71,7 @@ pub struct HandlerOptions { pub gc_lifetime: Duration, pub send_window: u64, pub receive_window: VarInt, + pub mark: Arc, /// not used pub ip: Option, @@ -147,38 +151,52 @@ impl Handler { crypto.enable_early_data = true; crypto.enable_sni = !opts.disable_sni; let mut quinn_config = QuinnConfig::new(Arc::new(crypto)); - let mut quinn_transport_config = QuinnTransportConfig::default(); - quinn_transport_config + let mut transport_config = QuinnTransportConfig::default(); + transport_config .max_concurrent_bidi_streams(opts.max_open_stream) .max_concurrent_uni_streams(opts.max_open_stream) .send_window(opts.send_window) .stream_receive_window(opts.receive_window) - .max_idle_timeout(None) - .congestion_controller_factory(Arc::new(CubicConfig::default())); - quinn_config.transport_config(Arc::new(quinn_transport_config)); + .max_idle_timeout(Some(opts.idle_timeout.into())); + match opts.congestion_controller { + CongestionControl::Cubic => { + transport_config.congestion_controller_factory(Arc::new(CubicConfig::default())) + } + CongestionControl::NewReno => { + transport_config.congestion_controller_factory(Arc::new(NewRenoConfig::default())) + } + CongestionControl::Bbr => { + transport_config.congestion_controller_factory(Arc::new(BbrConfig::default())) + } + }; + + quinn_config.transport_config(Arc::new(transport_config)); // Try to create an IPv4 socket as the placeholder first, if it fails, try IPv6. + // If it dont match server's ipv4/ipv6, rebind is needed let socket = UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))).or_else(|err| { UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))).map_err(|_| err) })?; - + // TODO #362 socket.set_mark(6969)?; let mut endpoint = QuinnEndpoint::new( EndpointConfig::default(), None, socket, Arc::new(TokioRuntime), )?; + endpoint.set_default_client_config(quinn_config); let endpoint = TuicEndpoint { ep: endpoint, server: ServerAddr::new(opts.server.clone(), opts.port, None), uuid: opts.uuid, password: Arc::from(opts.password.clone().into_bytes().into_boxed_slice()), - udp_relay_mode: types::UdpRelayMode::Native, + udp_relay_mode: opts.udp_relay_mode, zero_rtt_handshake: opts.reduce_rtt, heartbeat: opts.heartbeat_interval, gc_interval: opts.gc_interval, gc_lifetime: opts.gc_lifetime, + mark: opts.mark.clone(), }; Ok(Arc::new(Self { opts, @@ -187,17 +205,26 @@ impl Handler { next_assoc_id: AtomicU16::new(0), })) } - async fn get_conn(&self) -> Result> { + async fn get_conn( + &self, + resolver: &ThreadSafeDNSResolver, + ) -> Result> { + let mark = 6969; // TODO #362 + let mut rebind = false; + // if mark not match the one current used, then rebind + if mark != self.opts.mark.swap(mark, Ordering::SeqCst) { + rebind = true; + } let fut = async { let mut guard = self.conn.lock().await; if guard.is_none() { // init - *guard = Some(self.ep.connect().await?); + *guard = Some(self.ep.connect(resolver, rebind).await?); } let conn = guard.take().unwrap(); - let conn = if conn.check_open().is_err() { + let conn = if conn.check_open().is_err() || rebind { // reconnect - self.ep.connect().await? + self.ep.connect(resolver, rebind).await? } else { conn }; @@ -210,12 +237,11 @@ impl Handler { async fn do_connect_stream( &self, sess: &Session, - _resolver: ThreadSafeDNSResolver, + resolver: ThreadSafeDNSResolver, ) -> Result { - let conn = self.get_conn().await?; + let conn = self.get_conn(&resolver).await?; let dest = sess.destination.clone().into_tuic(); let tuic_tcp = conn.connect_tcp(dest).await?.compat(); - let s = ChainedStreamWrapper::new(tuic_tcp); s.append_to_chain(self.name()).await; Ok(Box::new(s)) @@ -224,11 +250,10 @@ impl Handler { async fn do_connect_datagram( &self, sess: &Session, - _resolver: ThreadSafeDNSResolver, + resolver: ThreadSafeDNSResolver, ) -> Result { - let conn = self.get_conn().await?; - - let assos_id = self.next_assoc_id.fetch_add(1, Ordering::Relaxed); + let conn = self.get_conn(&resolver).await?; + let assos_id = self.next_assoc_id.fetch_add(1, Ordering::SeqCst); let quic_udp = TuicDatagramOutbound::new(assos_id, conn, sess.source.into()); let s = ChainedDatagramWrapper::new(quic_udp); s.append_to_chain(self.name()).await; diff --git a/clash_lib/src/proxy/tuic/types.rs b/clash_lib/src/proxy/tuic/types.rs index 84d8a5692..89855dbc6 100644 --- a/clash_lib/src/proxy/tuic/types.rs +++ b/clash_lib/src/proxy/tuic/types.rs @@ -1,9 +1,11 @@ +use crate::app::dns::ThreadSafeDNSResolver; use crate::session::SocksAddr as ClashSocksAddr; use anyhow::Result; use quinn::Connection as QuinnConnection; use quinn::{Endpoint as QuinnEndpoint, ZeroRttAccepted}; use register_count::Counter; use std::collections::HashMap; +use std::sync::atomic::Ordering; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, str::FromStr, @@ -26,63 +28,75 @@ pub struct TuicEndpoint { pub heartbeat: Duration, pub gc_interval: Duration, pub gc_lifetime: Duration, + pub mark: Arc, } impl TuicEndpoint { - pub async fn connect(&self) -> Result> { - let mut last_err = None; + pub async fn connect( + &self, + resolver: &ThreadSafeDNSResolver, + rebind: bool, + ) -> Result> { + let remote_addr = self.server.resolve(resolver).await?; + let connect_to = async { + let match_ipv4 = remote_addr.is_ipv4() + && self + .ep + .local_addr() + .map_or(false, |local_addr| local_addr.is_ipv4()); + let match_ipv6 = remote_addr.is_ipv6() + && self + .ep + .local_addr() + .map_or(false, |local_addr| local_addr.is_ipv6()); - for addr in self.server.resolve().await? { - let connect_to = async { - let match_ipv4 = - addr.is_ipv4() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv4()); - let match_ipv6 = - addr.is_ipv6() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv6()); - - if !match_ipv4 && !match_ipv6 { - let bind_addr = if addr.is_ipv4() { - SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) - } else { - SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) - }; - - self.ep - .rebind(UdpSocket::bind(bind_addr).map_err(|err| { - anyhow!("failed to create endpoint UDP socket {}", err) - })?) - .map_err(|err| anyhow!("failed to rebind endpoint UDP socket {}", err))?; - } - - tracing::trace!("Connect to {} {}", addr, self.server.server_name()); - let conn = self.ep.connect(addr, self.server.server_name())?; - let (conn, zero_rtt_accepted) = if self.zero_rtt_handshake { - match conn.into_0rtt() { - Ok((conn, zero_rtt_accepted)) => (conn, Some(zero_rtt_accepted)), - Err(conn) => (conn.await?, None), - } + // if client and server don't match each other or forced to rebind, then rebind local socket + if (!match_ipv4 && !match_ipv6) || rebind { + let bind_addr = if remote_addr.is_ipv4() { + SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) } else { - (conn.await?, None) + SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) }; + let socket = UdpSocket::bind(bind_addr) + .map_err(|err| anyhow!("failed to bind local socket: {}", err))?; + let mark = self.mark.load(Ordering::Relaxed); + // ignore mark == 0, just for convenient + if mark != 0 { + // TODO #362 socket.set_mark(mark)?; + } + self.ep + .rebind(socket) + .map_err(|err| anyhow!("failed to rebind endpoint UDP socket {}", err))?; + } - Ok((conn, zero_rtt_accepted)) + tracing::trace!("Connect to {} {}", remote_addr, self.server.server_name()); + let conn = self.ep.connect(remote_addr, self.server.server_name())?; + let (conn, zero_rtt_accepted) = if self.zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, zero_rtt_accepted)) => (conn, Some(zero_rtt_accepted)), + Err(conn) => (conn.await?, None), + } + } else { + (conn.await?, None) }; - match connect_to.await { - Ok((conn, zero_rtt_accepted)) => { - return Ok(TuicConnection::new( - conn, - zero_rtt_accepted, - self.udp_relay_mode, - self.uuid, - self.password.clone(), - self.heartbeat, - self.gc_interval, - self.gc_lifetime, - )); - } - Err(err) => last_err = Some(err), + Ok((conn, zero_rtt_accepted)) + }; + + match connect_to.await { + Ok((conn, zero_rtt_accepted)) => { + return Ok(TuicConnection::new( + conn, + zero_rtt_accepted, + self.udp_relay_mode, + self.uuid, + self.password.clone(), + self.heartbeat, + self.gc_interval, + self.gc_lifetime, + )); } + Err(err) => Err(err), } - Err(last_err.unwrap_or(anyhow!("dns resolve"))) } } @@ -177,7 +191,7 @@ impl TuicConnection { }; }; - tracing::warn!("connection error: {err}"); + tracing::warn!("connection error: {err:?}"); } } @@ -194,15 +208,16 @@ impl ServerAddr { pub fn server_name(&self) -> &str { &self.domain } - // TODO change to clash dns? - pub async fn resolve(&self) -> Result> { + + pub async fn resolve(&self, resolver: &ThreadSafeDNSResolver) -> Result { if let Some(ip) = self.ip { - Ok(vec![SocketAddr::from((ip, self.port))].into_iter()) + Ok(SocketAddr::from((ip, self.port))) } else { - Ok(tokio::net::lookup_host((self.domain.as_str(), self.port)) + let ip = resolver + .resolve(self.domain.as_str(), false) .await? - .collect::>() - .into_iter()) + .ok_or(anyhow!("Resolve failed: unknown hostname"))?; + Ok(SocketAddr::from((ip, self.port))) } } } @@ -212,17 +227,15 @@ pub enum UdpRelayMode { Native, Quic, } - -impl FromStr for UdpRelayMode { - type Err = &'static str; - - fn from_str(s: &str) -> Result { +impl From<&str> for UdpRelayMode { + fn from(s: &str) -> Self { if s.eq_ignore_ascii_case("native") { - Ok(Self::Native) + Self::Native } else if s.eq_ignore_ascii_case("quic") { - Ok(Self::Quic) + Self::Quic } else { - Err("invalid UDP relay mode") + // TODO logging + Self::Quic } } }