Skip to content

Commit

Permalink
refactor(dns): system dns to support IPv6 (#515)
Browse files Browse the repository at this point in the history
* refactor system dns

* linux

* fmt

* clippy
  • Loading branch information
ibigbug authored Jul 30, 2024
1 parent 303c01a commit 6661a52
Show file tree
Hide file tree
Showing 16 changed files with 184 additions and 146 deletions.
2 changes: 1 addition & 1 deletion clash_lib/src/app/dns/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl TryFrom<&crate::config::def::Config> for Config {

Ok(Self {
enable: dc.enable,
ipv6: dc.ipv6,
ipv6: c.ipv6 && dc.ipv6,
nameserver: nameservers,
fallback,
fallback_filter: dc.fallback_filter.clone().into(),
Expand Down
9 changes: 6 additions & 3 deletions clash_lib/src/app/dns/dhcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
dns::{
dns_client::DNSNetMode, helper::make_clients, Client, Resolver,
dns_client::DNSNetMode, helper::make_clients, Client, EnhancedResolver,
ThreadSafeDNSClient,
},
proxy::utils::{new_udp_socket, Interface},
Expand Down Expand Up @@ -63,8 +63,11 @@ impl Client for DhcpClient {
dbg_str.push(format!("{:?}", c));
}
debug!("using clients: {:?}", dbg_str);
tokio::time::timeout(DHCP_TIMEOUT, Resolver::batch_exchange(&clients, msg))
.await?
tokio::time::timeout(
DHCP_TIMEOUT,
EnhancedResolver::batch_exchange(&clients, msg),
)
.await?
}
}

Expand Down
5 changes: 2 additions & 3 deletions clash_lib/src/app/dns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ mod helper;
pub mod resolver;
mod server;

pub use resolver::SystemResolver;

pub use config::Config;

pub use resolver::Resolver;
pub use resolver::{new as new_resolver, EnhancedResolver, SystemResolver};

pub use server::get_dns_listener;

#[async_trait]
Expand Down
69 changes: 31 additions & 38 deletions clash_lib/src/app/dns/resolver/enhanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@ use crate::dns::{
DomainFilter, FallbackDomainFilter, FallbackIPFilter, GeoIPFilter,
IPNetFilter,
},
resolver::system::SystemResolver,
ClashResolver, Config, ResolverKind, ThreadSafeDNSResolver,
ClashResolver, Config, ResolverKind,
};

static TTL: Duration = Duration::from_secs(60);

pub struct Resolver {
pub struct EnhancedResolver {
ipv6: AtomicBool,
hosts: Option<trie::StringTrie<net::IpAddr>>,
main: Vec<ThreadSafeDNSClient>,
Expand All @@ -49,15 +48,15 @@ pub struct Resolver {
fake_dns: Option<ThreadSafeFakeDns>,
}

impl Resolver {
impl EnhancedResolver {
/// For testing purpose
#[cfg(test)]
pub async fn new_default() -> Self {
use crate::app::dns::dns_client::DNSNetMode;

use crate::app::dns::config::NameServer;

Resolver {
EnhancedResolver {
ipv6: AtomicBool::new(false),
hosts: None,
main: make_clients(
Expand All @@ -79,18 +78,12 @@ impl Resolver {
}
}

pub async fn new_resolver(
pub async fn new(
cfg: &Config,
store: ThreadSafeCacheFile,
mmdb: Arc<Mmdb>,
) -> ThreadSafeDNSResolver {
if !cfg.enable {
return Arc::new(
SystemResolver::new().expect("failed to create system resolver"),
);
}

let default_resolver = Arc::new(Resolver {
) -> Self {
let default_resolver = Arc::new(EnhancedResolver {
ipv6: AtomicBool::new(false),
hosts: None,
main: make_clients(cfg.default_nameserver.clone(), None).await,
Expand All @@ -103,7 +96,7 @@ impl Resolver {
fake_dns: None,
});

let r = Resolver {
Self {
ipv6: AtomicBool::new(cfg.ipv6),
main: make_clients(
cfg.nameserver.clone(),
Expand Down Expand Up @@ -206,9 +199,7 @@ impl Resolver {
}
_ => None,
},
};

Arc::new(r)
}
}

pub async fn batch_exchange(
Expand Down Expand Up @@ -262,7 +253,7 @@ impl Resolver {

match self.exchange(m).await {
Ok(result) => {
let ip_list = Resolver::ip_list_of_message(&result);
let ip_list = EnhancedResolver::ip_list_of_message(&result);
if !ip_list.is_empty() {
Ok(ip_list)
} else {
Expand Down Expand Up @@ -293,15 +284,15 @@ impl Resolver {
let q = message.query().unwrap();

let query = async move {
if Resolver::is_ip_request(q) {
if EnhancedResolver::is_ip_request(q) {
return self.ip_exchange(message).await;
}

if let Some(matched) = self.match_policy(message) {
return Resolver::batch_exchange(matched, message).await;
return EnhancedResolver::batch_exchange(matched, message).await;
}

Resolver::batch_exchange(&self.main, message).await
EnhancedResolver::batch_exchange(&self.main, message).await
};

let rv = query.await;
Expand Down Expand Up @@ -345,7 +336,7 @@ impl Resolver {
if let (Some(_fallback), Some(_fallback_domain_filters), Some(policy)) =
(&self.fallback, &self.fallback_domain_filters, &self.policy)
{
if let Some(domain) = Resolver::domain_name_of_message(m) {
if let Some(domain) = EnhancedResolver::domain_name_of_message(m) {
return policy.search(&domain).map(|n| n.get_data().unwrap());
}
}
Expand All @@ -357,29 +348,31 @@ impl Resolver {
message: &op::Message,
) -> anyhow::Result<op::Message> {
if let Some(matched) = self.match_policy(message) {
return Resolver::batch_exchange(matched, message).await;
return EnhancedResolver::batch_exchange(matched, message).await;
}

if self.should_only_query_fallback(message) {
// self.fallback guaranteed in the above check
return Resolver::batch_exchange(
return EnhancedResolver::batch_exchange(
self.fallback.as_ref().unwrap(),
message,
)
.await;
}

let main_query = Resolver::batch_exchange(&self.main, message);
let main_query = EnhancedResolver::batch_exchange(&self.main, message);

if self.fallback.is_none() {
return main_query.await;
}

let fallback_query =
Resolver::batch_exchange(self.fallback.as_ref().unwrap(), message);
let fallback_query = EnhancedResolver::batch_exchange(
self.fallback.as_ref().unwrap(),
message,
);

if let Ok(main_result) = main_query.await {
let ip_list = Resolver::ip_list_of_message(&main_result);
let ip_list = EnhancedResolver::ip_list_of_message(&main_result);
if !ip_list.is_empty() {
// TODO: only check 1st?
if !self.should_ip_fallback(&ip_list[0]) {
Expand All @@ -395,7 +388,7 @@ impl Resolver {
if let (Some(_), Some(fallback_domain_filters)) =
(&self.fallback, &self.fallback_domain_filters)
{
if let Some(domain) = Resolver::domain_name_of_message(message) {
if let Some(domain) = EnhancedResolver::domain_name_of_message(message) {
for f in fallback_domain_filters.iter() {
if f.apply(domain.as_str()) {
return true;
Expand Down Expand Up @@ -449,7 +442,7 @@ impl Resolver {
}

#[async_trait]
impl ClashResolver for Resolver {
impl ClashResolver for EnhancedResolver {
#[instrument(skip(self))]
async fn resolve(
&self,
Expand Down Expand Up @@ -618,7 +611,7 @@ mod tests {

use crate::app::dns::{
dns_client::{DNSNetMode, DnsClient, Opts},
resolver::enhanced::Resolver,
resolver::enhanced::EnhancedResolver,
ThreadSafeDNSClient,
};

Expand Down Expand Up @@ -688,7 +681,7 @@ mod tests {
#[ignore = "network unstable on CI"]
async fn test_dot_resolve() {
let c = DnsClient::new_client(Opts {
r: Some(Arc::new(Resolver::new_default().await)),
r: Some(Arc::new(EnhancedResolver::new_default().await)),
host: "dns.google".to_string(),
port: 853,
net: DNSNetMode::DoT,
Expand All @@ -703,7 +696,7 @@ mod tests {
#[tokio::test]
#[ignore = "network unstable on CI"]
async fn test_doh_resolve() {
let default_resolver = Arc::new(Resolver::new_default().await);
let default_resolver = Arc::new(EnhancedResolver::new_default().await);

let c = DnsClient::new_client(Opts {
r: Some(default_resolver.clone()),
Expand Down Expand Up @@ -741,11 +734,11 @@ mod tests {
q.set_query_type(rr::RecordType::A);
m.add_query(q);

let r = Resolver::batch_exchange(&vec![c.clone()], &m)
let r = EnhancedResolver::batch_exchange(&vec![c.clone()], &m)
.await
.expect("should exchange");

let ips = Resolver::ip_list_of_message(&r);
let ips = EnhancedResolver::ip_list_of_message(&r);

assert!(!ips.is_empty());
assert!(!ips[0].is_unspecified());
Expand All @@ -757,11 +750,11 @@ mod tests {
q.set_query_type(rr::RecordType::AAAA);
m.add_query(q);

let r = Resolver::batch_exchange(&vec![c.clone()], &m)
let r = EnhancedResolver::batch_exchange(&vec![c.clone()], &m)
.await
.expect("should exchange");

let ips = Resolver::ip_list_of_message(&r);
let ips = EnhancedResolver::ip_list_of_message(&r);

assert!(!ips.is_empty());
assert!(!ips[0].is_unspecified());
Expand Down
27 changes: 26 additions & 1 deletion clash_lib/src/app/dns/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,30 @@ mod system;
#[path = "system_non_linux.rs"]
mod system;

pub use enhanced::Resolver;
use std::sync::Arc;

pub use enhanced::EnhancedResolver;
pub use system::SystemResolver;

use crate::{app::profile::ThreadSafeCacheFile, common::mmdb::Mmdb};

use super::{Config, ThreadSafeDNSResolver};

pub async fn new(
cfg: &Config,
store: Option<ThreadSafeCacheFile>,
mmdb: Option<Arc<Mmdb>>,
) -> ThreadSafeDNSResolver {
if cfg.enable {
match (store, mmdb) {
(Some(store), Some(mmdb)) => {
Arc::new(EnhancedResolver::new(cfg, store, mmdb).await)
}
_ => panic!("enhanced resolver requires cache store and mmdb"),
}
} else {
Arc::new(
SystemResolver::new(cfg.ipv6).expect("failed to create system resolver"),
)
}
}
37 changes: 18 additions & 19 deletions clash_lib/src/app/dns/resolver/system_linux.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
use std::sync::atomic::AtomicBool;

use async_trait::async_trait;
use hickory_resolver::{
name_server::{GenericConnector, TokioRuntimeProvider},
AsyncResolver,
};
use rand::seq::IteratorRandom;
use tracing::warn;

use crate::app::dns::{ClashResolver, ResolverKind};

pub struct SystemResolver(AsyncResolver<GenericConnector<TokioRuntimeProvider>>);
pub struct SystemResolver {
inner: AsyncResolver<GenericConnector<TokioRuntimeProvider>>,
ipv6: AtomicBool,
}

/// Bug in libc, use tokio impl instead: https://sourceware.org/bugzilla/show_bug.cgi?id=10652
impl SystemResolver {
pub fn new() -> anyhow::Result<Self> {
warn!(
"Default dns resolver doesn't support ipv6, please enable clash dns \
resolver if you need ipv6 support."
);

Ok(Self(
hickory_resolver::AsyncResolver::tokio_from_system_conf()?,
))
pub fn new(ipv6: bool) -> anyhow::Result<Self> {
Ok(Self {
inner: hickory_resolver::AsyncResolver::tokio_from_system_conf()?,
ipv6: AtomicBool::new(ipv6),
})
}
}

Expand All @@ -31,7 +31,7 @@ impl ClashResolver for SystemResolver {
host: &str,
_: bool,
) -> anyhow::Result<Option<std::net::IpAddr>> {
let response = self.0.lookup_ip(host).await?;
let response = self.inner.lookup_ip(host).await?;
Ok(response
.iter()
.filter(|x| self.ipv6() || x.is_ipv4())
Expand All @@ -43,7 +43,7 @@ impl ClashResolver for SystemResolver {
host: &str,
_: bool,
) -> anyhow::Result<Option<std::net::Ipv4Addr>> {
let response = self.0.ipv4_lookup(host).await?;
let response = self.inner.ipv4_lookup(host).await?;
Ok(response.iter().map(|x| x.0).choose(&mut rand::thread_rng()))
}

Expand All @@ -52,7 +52,7 @@ impl ClashResolver for SystemResolver {
host: &str,
_: bool,
) -> anyhow::Result<Option<std::net::Ipv6Addr>> {
let response = self.0.ipv6_lookup(host).await?;
let response = self.inner.ipv6_lookup(host).await?;
Ok(response.iter().map(|x| x.0).choose(&mut rand::thread_rng()))
}

Expand All @@ -64,12 +64,11 @@ impl ClashResolver for SystemResolver {
}

fn ipv6(&self) -> bool {
// TODO: support ipv6
false
self.ipv6.load(std::sync::atomic::Ordering::Relaxed)
}

fn set_ipv6(&self, _: bool) {
// NOOP
fn set_ipv6(&self, val: bool) {
self.ipv6.store(val, std::sync::atomic::Ordering::Relaxed);
}

fn kind(&self) -> ResolverKind {
Expand Down Expand Up @@ -113,7 +112,7 @@ mod tests {

#[tokio::test]
async fn test_system_resolver_default_config() {
let resolver = SystemResolver::new().unwrap();
let resolver = SystemResolver::new(false).unwrap();
let response = resolver.resolve("www.google.com", false).await.unwrap();
assert!(response.is_some());
}
Expand Down
Loading

0 comments on commit 6661a52

Please sign in to comment.