diff --git a/Cargo.lock b/Cargo.lock index 50ef0146a..94082bff7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -722,6 +722,7 @@ version = "0.1.0" dependencies = [ "aes-gcm", "anyhow", + "arc-swap", "async-recursion", "async-trait", "axum", @@ -2249,7 +2250,7 @@ checksum = "9252111cf132ba0929b6f8e030cac2a24b507f3a4d6db6fb2896f27b354c714b" [[package]] name = "netstack-lwip" version = "0.3.4" -source = "git+https://github.com/Watfaq/netstack-lwip.git?rev=8c8c0b0#8c8c0b0646ebeb6eb84821d95b7261d3e00d94dd" +source = "git+https://github.com/Watfaq/netstack-lwip.git?rev=5ad376f#5ad376f6c48df459c9120a781c6e4d90650435db" dependencies = [ "anyhow", "bindgen 0.59.2", diff --git a/clash/src/main.rs b/clash/src/main.rs index 1d8e8fad3..258af7493 100644 --- a/clash/src/main.rs +++ b/clash/src/main.rs @@ -1,6 +1,7 @@ extern crate clash_lib as clash; use clap::Parser; +use clash::TokioRuntime; use std::path::PathBuf; #[derive(Parser)] @@ -27,6 +28,7 @@ fn main() { clash::start(clash::Options { config: clash::Config::File("".to_string(), cli.config.to_string_lossy().to_string()), cwd: cli.directory.map(|x| x.to_string_lossy().to_string()), + rt: Some(TokioRuntime::MultiThread), }) .unwrap(); } diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index e65a9c175..6f768ff1c 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -41,6 +41,7 @@ hyper-boring = { git = "https://github.com/Watfaq/boring.git", rev = "24c006f" } tokio-boring = { git = "https://github.com/Watfaq/boring.git", rev = "24c006f" } ip_network_table-deps-treebitmap = "0.5.0" once_cell = "1.18.0" +arc-swap = "1.6.0" # opentelemetry opentelemetry = "0.20" @@ -60,7 +61,7 @@ tower-http = { version = "0.4.0", features = ["fs", "trace", "cors"] } chrono = { version = "0.4.26", features = ["serde"] } tun = { git = "https://github.com/Watfaq/rust-tun.git", rev = "28936b6", features = ["async"] } -netstack-lwip = { git = "https://github.com/Watfaq/netstack-lwip.git", rev = "8c8c0b0" } +netstack-lwip = { git = "https://github.com/Watfaq/netstack-lwip.git", rev = "5ad376f" } boringtun = { version = "0.6.0", features = ["device"] } serde = { version = "1.0", features=["derive"] } diff --git a/clash_lib/src/app/api/handlers/provider.rs b/clash_lib/src/app/api/handlers/provider.rs index d6346df14..fc824bafa 100644 --- a/clash_lib/src/app/api/handlers/provider.rs +++ b/clash_lib/src/app/api/handlers/provider.rs @@ -51,7 +51,7 @@ pub fn routes(outbound_manager: ThreadSafeOutboundManager) -> Router) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); let mut res = HashMap::new(); let mut providers = HashMap::new(); @@ -76,7 +76,7 @@ async fn find_proxy_provider_by_name( mut req: Request, next: Next, ) -> Response { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); if let Some(provider) = outbound_manager.get_proxy_provider(&name) { req.extensions_mut().insert(provider); next.run(req).await @@ -154,7 +154,7 @@ async fn get_proxy( Extension(proxy): Extension, State(state): State, ) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); axum::response::Json(outbound_manager.get_proxy(&proxy).await) } @@ -168,7 +168,7 @@ async fn get_proxy_delay( Extension(proxy): Extension, Query(q): Query, ) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); let timeout = Duration::from_millis(q.timeout.into()); let n = proxy.name().to_owned(); match outbound_manager.url_test(proxy, &q.url, timeout).await { diff --git a/clash_lib/src/app/api/handlers/proxy.rs b/clash_lib/src/app/api/handlers/proxy.rs index 086ca6b3d..41607b2ef 100644 --- a/clash_lib/src/app/api/handlers/proxy.rs +++ b/clash_lib/src/app/api/handlers/proxy.rs @@ -50,7 +50,7 @@ pub fn routes( } async fn get_proxies(State(state): State) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); let mut res = HashMap::new(); let proxies = outbound_manager.get_proxies().await; res.insert("proxies".to_owned(), proxies); @@ -63,7 +63,7 @@ async fn find_proxy_by_name( mut req: Request, next: Next, ) -> Response { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); if let Some(proxy) = outbound_manager.get_outbound(&name) { req.extensions_mut().insert(proxy); next.run(req).await @@ -76,7 +76,7 @@ async fn get_proxy( Extension(proxy): Extension, State(state): State, ) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); axum::response::Json(outbound_manager.get_proxy(&proxy).await) } @@ -91,7 +91,7 @@ async fn update_proxy( Extension(proxy): Extension, Json(payload): Json, ) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); if let Some(ctrl) = outbound_manager.get_selector_control(proxy.name()) { match ctrl.lock().await.select(&payload.name).await { Ok(_) => { @@ -130,7 +130,7 @@ async fn get_proxy_delay( Extension(proxy): Extension, Query(q): Query, ) -> impl IntoResponse { - let outbound_manager = state.outbound_manager.read().await; + let outbound_manager = state.outbound_manager.clone(); let timeout = Duration::from_millis(q.timeout.into()); let n = proxy.name().to_owned(); match outbound_manager.url_test(proxy, &q.url, timeout).await { diff --git a/clash_lib/src/app/dispatcher/dispatcher.rs b/clash_lib/src/app/dispatcher/dispatcher.rs index c521276ef..8cc882923 100644 --- a/clash_lib/src/app/dispatcher/dispatcher.rs +++ b/clash_lib/src/app/dispatcher/dispatcher.rs @@ -8,6 +8,7 @@ use crate::config::internal::proxy::PROXY_GLOBAL; use crate::proxy::datagram::UdpPacket; use crate::proxy::AnyInboundDatagram; use crate::session::Session; +use arc_swap::ArcSwap; use futures::SinkExt; use futures::StreamExt; use std::collections::HashMap; @@ -17,7 +18,6 @@ use std::sync::Arc; use std::time::Duration; use std::time::Instant; use tokio::io::{copy_bidirectional, AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::Mutex; use tokio::sync::RwLock; use tokio::task::JoinHandle; use tracing::info_span; @@ -34,7 +34,7 @@ pub struct Dispatcher { outbound_manager: ThreadSafeOutboundManager, router: ThreadSafeRouter, resolver: ThreadSafeDNSResolver, - mode: Arc>, + mode: ArcSwap, manager: Arc, } @@ -58,20 +58,19 @@ impl Dispatcher { outbound_manager, router, resolver, - mode: Arc::new(RwLock::new(mode)), + mode: Arc::new(mode).into(), manager: statistics_manager, } } pub async fn set_mode(&self, mode: RunMode) { info!("run mode switched to {}", mode); - let mut m = self.mode.write().await; - *m = mode; + + self.mode.store(Arc::new(mode)); } pub async fn get_mode(&self) -> RunMode { - let mode = self.mode.read().await; - mode.clone() + **self.mode.load() } #[instrument(skip(lhs))] @@ -107,15 +106,15 @@ impl Dispatcher { sess }; - let mode = self.mode.read().await; + let mode = **self.mode.load(); debug!("dispatching {} with mode {}", sess, mode); - let (outbound_name, rule) = match *mode { + let (outbound_name, rule) = match mode { RunMode::Global => (PROXY_GLOBAL, None), RunMode::Rule => self.router.match_route(&sess).await, RunMode::Direct => (PROXY_DIRECT, None), }; - let mgr = self.outbound_manager.read().await; + let mgr = self.outbound_manager.clone(); let handler = mgr.get_outbound(outbound_name).unwrap_or_else(|| { debug!("unknown rule: {}, fallback to direct", outbound_name); mgr.get_outbound(PROXY_DIRECT).unwrap() @@ -186,7 +185,7 @@ impl Dispatcher { let router = self.router.clone(); let outbound_manager = self.outbound_manager.clone(); let resolver = self.resolver.clone(); - let mode = self.mode.clone(); + let mode = **self.mode.load(); let manager = self.manager.clone(); let (mut local_w, mut local_r) = udp_inbound.split(); @@ -236,10 +235,10 @@ impl Dispatcher { let mut packet = packet; packet.dst_addr = sess.destination.clone(); - let mode = mode.read().await; + let mode = mode.clone(); trace!("dispatching {} with mode {}", sess, mode); - let (outbound_name, rule) = match *mode { + let (outbound_name, rule) = match mode { RunMode::Global => (PROXY_GLOBAL, None), RunMode::Rule => router.match_route(&sess).await, RunMode::Direct => (PROXY_DIRECT, None), @@ -249,7 +248,7 @@ impl Dispatcher { let remote_receiver_w = remote_receiver_w.clone(); - let mgr = outbound_manager.read().await; + let mgr = outbound_manager.clone(); let handler = mgr.get_outbound(&outbound_name).unwrap_or_else(|| { debug!("unknown rule: {}, fallback to direct", outbound_name); mgr.get_outbound(PROXY_DIRECT).unwrap() @@ -381,7 +380,7 @@ impl Dispatcher { type OutboundPacketSender = tokio::sync::mpsc::Sender; // outbound packet sender struct TimeoutUdpSessionManager { - map: Arc>, + map: Arc>, cleaner: Option>, } @@ -395,7 +394,7 @@ impl Drop for TimeoutUdpSessionManager { impl TimeoutUdpSessionManager { fn new() -> Self { - let map = Arc::new(Mutex::new(OutboundHandleMap::new())); + let map = Arc::new(RwLock::new(OutboundHandleMap::new())); let timeout = Duration::from_secs(10); let map_cloned = map.clone(); @@ -405,7 +404,7 @@ impl TimeoutUdpSessionManager { tokio::time::sleep(Duration::from_secs(10)).await; trace!("timeout udp session cleaner scanning"); - let mut g = map_cloned.lock().await; + let mut g = map_cloned.write().await; let mut alived = 0; let mut expired = 0; g.0.retain(|k, x| { @@ -445,7 +444,7 @@ impl TimeoutUdpSessionManager { send_handle: JoinHandle<()>, sender: OutboundPacketSender, ) { - let mut map = self.map.lock().await; + let mut map = self.map.write().await; map.insert(outbound_name, src_addr, recv_handle, send_handle, sender); } @@ -454,7 +453,7 @@ impl TimeoutUdpSessionManager { outbound_name: &str, src_addr: SocketAddr, ) -> Option { - let mut map = self.map.lock().await; + let mut map = self.map.write().await; map.get_outbound_sender_mut(outbound_name, src_addr) } } diff --git a/clash_lib/src/app/outbound/manager.rs b/clash_lib/src/app/outbound/manager.rs index e198c0182..87f6385c5 100644 --- a/clash_lib/src/app/outbound/manager.rs +++ b/clash_lib/src/app/outbound/manager.rs @@ -48,7 +48,7 @@ pub struct OutboundManager { static DEFAULT_LATENCY_TEST_URL: &str = "http://www.gstatic.com/generate_204"; -pub type ThreadSafeOutboundManager = Arc>; +pub type ThreadSafeOutboundManager = Arc; impl OutboundManager { pub async fn new( diff --git a/clash_lib/src/app/remote_content_manager/mod.rs b/clash_lib/src/app/remote_content_manager/mod.rs index 6c66f9a92..b8fa6b62a 100644 --- a/clash_lib/src/app/remote_content_manager/mod.rs +++ b/clash_lib/src/app/remote_content_manager/mod.rs @@ -66,6 +66,8 @@ struct ProxyState { pub struct ProxyManager { proxy_state: Arc>>, dns_resolver: ThreadSafeDNSResolver, + + connector_map: Arc>>>, } impl ProxyManager { @@ -73,6 +75,7 @@ impl ProxyManager { Self { dns_resolver, proxy_state: Arc::new(RwLock::new(HashMap::new())), + connector_map: Arc::new(RwLock::new(HashMap::new())), } } @@ -165,7 +168,13 @@ impl ProxyManager { ssl.set_alpn_protos(b"\x02h2\x08http/1.1") .map_err(map_io_error)?; - let connector = HttpsConnector::with_connector(connector, ssl).map_err(map_io_error)?; + let mut g = self.connector_map.write().await; + let connector = g + .entry(name.clone()) + .or_insert(HttpsConnector::with_connector(connector, ssl).map_err(map_io_error)?); + + let connector = connector.clone(); + let client = hyper::Client::builder().build::<_, hyper::Body>(connector); let req = Request::get(url) diff --git a/clash_lib/src/app/remote_content_manager/providers/proxy_provider/plain_provider.rs b/clash_lib/src/app/remote_content_manager/providers/proxy_provider/plain_provider.rs index 701dfcd96..19c9c1c8f 100644 --- a/clash_lib/src/app/remote_content_manager/providers/proxy_provider/plain_provider.rs +++ b/clash_lib/src/app/remote_content_manager/providers/proxy_provider/plain_provider.rs @@ -2,7 +2,6 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use erased_serde::Serialize; -use tokio::sync::Mutex; use tracing::debug; use crate::{ @@ -16,14 +15,10 @@ use crate::{ use super::proxy_provider::ProxyProvider; -struct Inner { - hc: Arc, -} - pub struct PlainProvider { name: String, proxies: Vec, - inner: Arc>, + hc: Arc, } impl PlainProvider { @@ -46,11 +41,7 @@ impl PlainProvider { }); } - Ok(Self { - name, - proxies, - inner: Arc::new(Mutex::new(Inner { hc })), - }) + Ok(Self { name, proxies, hc }) } } @@ -93,10 +84,10 @@ impl ProxyProvider for PlainProvider { } async fn touch(&self) { - self.inner.lock().await.hc.touch().await; + self.hc.touch().await; } async fn healthcheck(&self) { - self.inner.lock().await.hc.check().await; + self.hc.check().await; } } diff --git a/clash_lib/src/lib.rs b/clash_lib/src/lib.rs index e2be61208..ad2b06ebe 100644 --- a/clash_lib/src/lib.rs +++ b/clash_lib/src/lib.rs @@ -24,7 +24,7 @@ use tokio::task::JoinHandle; use std::sync::Arc; use thiserror::Error; -use tokio::sync::{broadcast, mpsc, Mutex, RwLock}; +use tokio::sync::{broadcast, mpsc, Mutex}; mod app; mod common; @@ -60,6 +60,12 @@ pub type Runner = futures::future::BoxFuture<'static, ()>; pub struct Options { pub config: Config, pub cwd: Option, + pub rt: Option, +} + +pub enum TokioRuntime { + MultiThread, + SingleThread, } pub enum Config { @@ -83,18 +89,24 @@ pub struct RuntimeController { static RUNTIME_CONTROLLER: Storage> = Storage::new(); pub fn start(opts: Options) -> Result<(), Error> { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build()? - .block_on(async { - match start_async(opts).await { - Err(e) => { - eprintln!("start error: {}", e); - Err(e) - } - Ok(_) => Ok(()), + let rt = match opts.rt.as_ref().unwrap_or(&TokioRuntime::MultiThread) { + &TokioRuntime::MultiThread => tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?, + &TokioRuntime::SingleThread => tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?, + }; + + rt.block_on(async { + match start_async(opts).await { + Err(e) => { + eprintln!("start error: {}", e); + Err(e) } - }) + Ok(_) => Ok(()), + } + }) } pub fn shutdown() -> bool { @@ -148,7 +160,7 @@ async fn start_async(opts: Options) -> Result<(), Error> { let dns_resolver = dns::Resolver::new(&config.dns, cache_store.clone(), mmdb.clone()).await; - let outbound_manager = Arc::new(RwLock::new( + let outbound_manager = Arc::new( OutboundManager::new( config .proxies @@ -173,7 +185,7 @@ async fn start_async(opts: Options) -> Result<(), Error> { cwd.to_string_lossy().to_string(), ) .await?, - )); + ); let router = Arc::new( Router::new( @@ -279,6 +291,7 @@ mod tests { start(Options { config: Config::Str(conf.to_string()), cwd: None, + rt: None, }) .unwrap() }); diff --git a/clash_lib/src/proxy/selector/mod.rs b/clash_lib/src/proxy/selector/mod.rs index 7c862b0e9..8e5f7fb8f 100644 --- a/clash_lib/src/proxy/selector/mod.rs +++ b/clash_lib/src/proxy/selector/mod.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, io, sync::Arc}; use async_trait::async_trait; use erased_serde::Serialize; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, RwLock}; use tracing::debug; use crate::{ @@ -44,7 +44,7 @@ pub struct HandlerOptions { pub struct Handler { opts: HandlerOptions, providers: Vec, - inner: Arc>, + inner: Arc>, } impl Handler { @@ -60,7 +60,7 @@ impl Handler { Self { opts, providers, - inner: Arc::new(Mutex::new(HandlerInner { + inner: Arc::new(RwLock::new(HandlerInner { current: seleted.unwrap_or(current), })), } @@ -69,7 +69,7 @@ impl Handler { async fn selected_proxy(&self, touch: bool) -> AnyOutboundHandler { let proxies = get_proxies_from_providers(&self.providers, touch).await; for proxy in proxies { - if proxy.name() == self.inner.lock().await.current { + if proxy.name() == self.inner.read().await.current { p_debug!("{} selected {}", self.name(), proxy.name()); return proxy; } @@ -83,7 +83,7 @@ impl SelectorControl for Handler { async fn select(&mut self, name: &str) -> Result<(), Error> { let proxies = get_proxies_from_providers(&self.providers, false).await; if proxies.iter().any(|x| x.name() == name) { - self.inner.lock().await.current = name.to_owned(); + self.inner.write().await.current = name.to_owned(); Ok(()) } else { Err(Error::Operation(format!("proxy {} not found", name))) @@ -91,7 +91,7 @@ impl SelectorControl for Handler { } async fn current(&self) -> String { - let inner = self.inner.lock().await.current.to_owned(); + let inner = self.inner.read().await.current.to_owned(); inner } } @@ -165,7 +165,7 @@ impl OutboundHandler for Handler { m.insert("type".to_string(), Box::new(self.proto()) as _); m.insert( "now".to_string(), - Box::new(self.inner.lock().await.current.clone()) as _, + Box::new(self.inner.read().await.current.clone()) as _, ); m.insert( "all".to_string(),