diff --git a/Cargo.Bazel.lock b/Cargo.Bazel.lock index 45a90e1b2..ffb0360c4 100644 --- a/Cargo.Bazel.lock +++ b/Cargo.Bazel.lock @@ -1,5 +1,5 @@ { - "checksum": "347ca0fe4b75977928f2e59b15583a8514505bc59a8dea46398337f26e3344a8", + "checksum": "b6fd3963e35e2a6beb1e16662f2a7be5ac7bb50ddcaf7e47880459331b4ad8f0", "crates": { "addr2line 0.20.0": { "name": "addr2line", @@ -3945,6 +3945,10 @@ "id": "network-interface 1.0.3", "target": "network_interface" }, + { + "id": "once_cell 1.18.0", + "target": "once_cell" + }, { "id": "prost 0.11.9", "target": "prost" diff --git a/Cargo.lock b/Cargo.lock index 5ed10b488..572cb189a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -810,6 +810,7 @@ dependencies = [ "murmur3", "netstack-lwip", "network-interface", + "once_cell", "prost", "public-suffix", "rand", diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index 00f7ab2d1..2d715627b 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -40,6 +40,7 @@ boring-sys = { git = "https://github.com/Watfaq/boring.git", rev = "24c006f" } hyper-boring = { git = "https://github.com/Watfaq/boring.git", rev = "24c006f" } tokio-boring = { git = "https://github.com/Watfaq/boring.git", rev = "24c006f" } ip_network_table-deps-treebitmap = "0.5.0" +once_cell = "1.18.0" crc32fast = "1.3.2" brotli = "3.3.4" diff --git a/clash_lib/src/app/dns/dns_client.rs b/clash_lib/src/app/dns/dns_client.rs index 87bd627fa..87e9123af 100644 --- a/clash_lib/src/app/dns/dns_client.rs +++ b/clash_lib/src/app/dns/dns_client.rs @@ -4,12 +4,18 @@ use std::str::FromStr; use std::{net, sync::Arc, time::Duration}; use async_trait::async_trait; -use rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; + +use rustls::ClientConfig; +use tokio::sync::RwLock; +use tokio::task::JoinHandle; +use tracing::warn; +use trust_dns_client::client::AsyncClient; use trust_dns_client::{ client, proto::iocompat::AsyncIoTokioAsStd, tcp::TcpClientStream, udp::UdpClientStream, }; +use trust_dns_proto::error::ProtoError; -use crate::common::tls; +use crate::common::tls::{self, GLOBAL_ROOT_STORE}; use crate::dns::dhcp::DhcpClient; use crate::dns::ThreadSafeDNSClient; use tokio::net::TcpStream as TokioTcpStream; @@ -72,9 +78,23 @@ pub struct Opts { pub iface: Option, } +enum DnsConfig { + Udp(net::SocketAddr, Option), + Tcp(net::SocketAddr, Option), + Tls(net::SocketAddr, String, Option), + Https(net::SocketAddr, String, Option), +} + +struct Inner { + c: client::AsyncClient, + bg_handle: Option>>, +} + /// DnsClient pub struct DnsClient { - c: client::AsyncClient, + inner: Arc>, + + cfg: DnsConfig, // debug purpose host: String, @@ -116,25 +136,17 @@ impl DnsClient { match other { DNSNetMode::UDP => { - let stream = UdpClientStream::::with_bind_addr_and_timeout( - net::SocketAddr::new(ip, opts.port), - // TODO: simplify this match - match &opts.iface { - Some(iface) => match iface { - Interface::IpAddr(ip) => Some(SocketAddr::new(ip.clone(), 0)), - _ => None, - }, - _ => None, - }, - Duration::from_secs(5), - ); - let (client, bg) = client::AsyncClient::connect(stream) - .await - .map_err(|x| Error::DNSError(x.to_string()))?; + let cfg = + DnsConfig::Udp(net::SocketAddr::new(ip, opts.port), opts.iface.clone()); + let (client, bg) = dns_stream_builder(&cfg).await?; - tokio::spawn(bg); Ok(Arc::new(Self { - c: client, + inner: Arc::new(RwLock::new(Inner { + c: client, + bg_handle: Some(bg), + })), + + cfg, host: opts.host, port: opts.port, @@ -143,24 +155,18 @@ impl DnsClient { })) } DNSNetMode::TCP => { - let (stream, sender) = TcpClientStream::>::with_bind_addr_and_timeout( - net::SocketAddr::new(ip, opts.port), - match &opts.iface { - Some(iface) => match iface { - Interface::IpAddr(ip) => Some(SocketAddr::new(ip.clone(), 0)), - _ => None, - }, - _ => None, - }, - Duration::from_secs(5), - ); + let cfg = + DnsConfig::Tcp(net::SocketAddr::new(ip, opts.port), opts.iface.clone()); + + let (client, bg) = dns_stream_builder(&cfg).await?; - let (client, bg) = client::AsyncClient::new(stream, sender, None) - .await - .map_err(|x| Error::DNSError(x.to_string()))?; - tokio::spawn(bg); Ok(Arc::new(Self { - c: client, + inner: Arc::new(RwLock::new(Inner { + c: client, + bg_handle: Some(bg), + })), + + cfg, host: opts.host, port: opts.port, @@ -169,49 +175,21 @@ impl DnsClient { })) } DNSNetMode::DoT => { - let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - let mut tls_config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - tls_config.alpn_protocols = vec!["dot".into()]; - - let (stream, sender) = tls_client_connect_with_bind_addr::< - AsyncIoTokioAsStd, - >( + let cfg = DnsConfig::Tls( net::SocketAddr::new(ip, opts.port), - match &opts.iface { - Some(iface) => match iface { - Interface::IpAddr(ip) => Some(SocketAddr::new(ip.clone(), 0)), - _ => None, - }, - _ => None, - }, opts.host.clone(), - Arc::new(tls_config), + opts.iface.clone(), ); - let (client, bg) = client::AsyncClient::with_timeout( - stream, - sender, - Duration::from_secs(5), - None, - ) - .await - .map_err(|x| Error::DNSError(x.to_string()))?; + let (client, bg) = dns_stream_builder(&cfg).await?; - tokio::spawn(bg); Ok(Arc::new(Self { - c: client, + inner: Arc::new(RwLock::new(Inner { + c: client, + bg_handle: Some(bg), + })), + + cfg, host: opts.host, port: opts.port, @@ -220,51 +198,21 @@ impl DnsClient { })) } DNSNetMode::DoH => { - let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - let mut tls_config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - tls_config.alpn_protocols = vec!["h2".into()]; - - if opts.host == ip.to_string() { - tls_config - .dangerous() - .set_certificate_verifier(Arc::new(tls::NoHostnameTlsVerifier)); - } - - let mut stream_builder = - HttpsClientStreamBuilder::with_client_config(Arc::new(tls_config)); - if let Some(iface) = &opts.iface { - match iface { - Interface::IpAddr(ip) => { - stream_builder.bind_addr(net::SocketAddr::new(ip.clone(), 0)) - } - _ => {} - } - } - let stream = stream_builder.build::>( + let cfg = DnsConfig::Https( net::SocketAddr::new(ip, opts.port), opts.host.clone(), + opts.iface.clone(), ); - let (client, bg) = client::AsyncClient::connect(stream) - .await - .map_err(|x| Error::DNSError(x.to_string()))?; + let (client, bg) = dns_stream_builder(&cfg).await?; - tokio::spawn(bg); Ok(Arc::new(Self { - c: client, + inner: Arc::new(RwLock::new(Inner { + c: client, + bg_handle: Some(bg), + })), + cfg, host: opts.host, port: opts.port, net: opts.net, @@ -296,9 +244,26 @@ impl Client for DnsClient { } async fn exchange(&self, msg: &Message) -> anyhow::Result { + let mut inner = self.inner.write().await; + if let Some(bg) = &inner.bg_handle { + if bg.is_finished() { + warn!("dns client background task is finished, likely connection closed, restarting a new one"); + let (client, bg) = dns_stream_builder(&self.cfg).await?; + inner.c = client; + inner.bg_handle.replace(bg); + } + } else { + unreachable!("dns bg task handle dangling"); + } + + drop(inner); + let mut req = DnsRequest::new(msg.clone(), DnsRequestOptions::default()); req.set_id(rand::random::()); - self.c + self.inner + .read() + .await + .c .send(req) .first_answer() .await @@ -306,3 +271,106 @@ impl Client for DnsClient { .map(|x| x.into()) } } + +async fn dns_stream_builder( + cfg: &DnsConfig, +) -> Result<(AsyncClient, JoinHandle>), Error> { + match cfg { + DnsConfig::Udp(addr, iface) => { + let stream = UdpClientStream::::with_bind_addr_and_timeout( + net::SocketAddr::new(addr.ip(), addr.port()), + // TODO: simplify this match + match iface { + Some(iface) => match iface { + Interface::IpAddr(ip) => Some(SocketAddr::new(ip.clone(), 0)), + _ => None, + }, + _ => None, + }, + Duration::from_secs(5), + ); + client::AsyncClient::connect(stream) + .await + .map(|(x, y)| (x, tokio::spawn(y))) + .map_err(|x| Error::DNSError(x.to_string())) + } + DnsConfig::Tcp(addr, iface) => { + let (stream, sender) = + TcpClientStream::>::with_bind_addr_and_timeout( + net::SocketAddr::new(addr.ip(), addr.port()), + match iface { + Some(iface) => match iface { + Interface::IpAddr(ip) => Some(SocketAddr::new(ip.clone(), 0)), + _ => None, + }, + _ => None, + }, + Duration::from_secs(5), + ); + + client::AsyncClient::new(stream, sender, None) + .await + .map(|(x, y)| (x, tokio::spawn(y))) + .map_err(|x| Error::DNSError(x.to_string())) + } + DnsConfig::Tls(addr, host, iface) => { + let mut tls_config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(GLOBAL_ROOT_STORE.clone()) + .with_no_client_auth(); + tls_config.alpn_protocols = vec!["dot".into()]; + + let (stream, sender) = + tls_client_connect_with_bind_addr::>( + net::SocketAddr::new(addr.ip(), addr.port()), + match iface { + Some(iface) => match iface { + Interface::IpAddr(ip) => Some(SocketAddr::new(ip.clone(), 0)), + _ => None, + }, + _ => None, + }, + host.clone(), + Arc::new(tls_config), + ); + + client::AsyncClient::with_timeout(stream, sender, Duration::from_secs(5), None) + .await + .map(|(x, y)| (x, tokio::spawn(y))) + .map_err(|x| Error::DNSError(x.to_string())) + } + DnsConfig::Https(addr, host, iface) => { + let mut tls_config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(GLOBAL_ROOT_STORE.clone()) + .with_no_client_auth(); + tls_config.alpn_protocols = vec!["h2".into()]; + + if host == &addr.ip().to_string() { + tls_config + .dangerous() + .set_certificate_verifier(Arc::new(tls::NoHostnameTlsVerifier)); + } + + let mut stream_builder = + HttpsClientStreamBuilder::with_client_config(Arc::new(tls_config)); + if let Some(iface) = iface { + match iface { + Interface::IpAddr(ip) => { + stream_builder.bind_addr(net::SocketAddr::new(ip.clone(), 0)) + } + _ => {} + } + } + let stream = stream_builder.build::>( + net::SocketAddr::new(addr.ip(), addr.port()), + host.clone(), + ); + + client::AsyncClient::connect(stream) + .await + .map(|(x, y)| (x, tokio::spawn(y))) + .map_err(|x| Error::DNSError(x.to_string())) + } + } +} diff --git a/clash_lib/src/common/tls.rs b/clash_lib/src/common/tls.rs index 84aa9cbf1..66bd265d8 100644 --- a/clash_lib/src/common/tls.rs +++ b/clash_lib/src/common/tls.rs @@ -1,8 +1,27 @@ -use rustls::client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier}; +use once_cell::sync::Lazy; +use rustls::{ + client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier}, + OwnedTrustAnchor, RootCertStore, +}; use tracing::warn; use rustls::{Certificate, ServerName}; -use std::time::SystemTime; +use std::{sync::Arc, time::SystemTime}; + +pub static GLOBAL_ROOT_STORE: Lazy> = Lazy::new(|| global_root_store()); + +fn global_root_store() -> Arc { + let mut root_store = RootCertStore::empty(); + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + Arc::new(root_store) +} /// Warning: NO validation on certs. pub struct DummyTlsVerifier; diff --git a/clash_lib/src/proxy/transport/tls.rs b/clash_lib/src/proxy/transport/tls.rs index 5890c30f4..f70826f90 100644 --- a/clash_lib/src/proxy/transport/tls.rs +++ b/clash_lib/src/proxy/transport/tls.rs @@ -1,10 +1,13 @@ use std::{io, sync::Arc}; -use rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; +use rustls::{ClientConfig, ServerName}; use serde::Serialize; use tokio_rustls::TlsConnector; -use crate::{common::tls, proxy::AnyStream}; +use crate::{ + common::tls::{self, GLOBAL_ROOT_STORE}, + proxy::AnyStream, +}; #[derive(Serialize, Clone)] pub struct TLSOptions { @@ -14,18 +17,9 @@ pub struct TLSOptions { } pub async fn wrap_stream(stream: AnyStream, opt: TLSOptions) -> io::Result { - // TODO save root store to avoid re-creating it - let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); let mut tls_config = ClientConfig::builder() .with_safe_defaults() - .with_root_certificates(root_store) + .with_root_certificates(GLOBAL_ROOT_STORE.clone()) .with_no_client_auth(); tls_config.alpn_protocols = opt .alpn