From 6bce6f013a1eeb4b001f2ac1f71be2a296aabfd5 Mon Sep 17 00:00:00 2001 From: Yuwei Ba Date: Thu, 28 Dec 2023 03:19:52 +1100 Subject: [PATCH] enable ipv6 on wireguard (#234) * enable ipv6 on wg * clippy --- clash/tests/data/config/wg.yaml | 15 +- clash_lib/src/proxy/converters/wireguard.rs | 29 ++- clash_lib/src/proxy/wg/device.rs | 203 ++++++++++++++------ clash_lib/src/proxy/wg/mod.rs | 34 ++-- clash_lib/src/proxy/wg/wireguard.rs | 12 +- 5 files changed, 193 insertions(+), 100 deletions(-) diff --git a/clash/tests/data/config/wg.yaml b/clash/tests/data/config/wg.yaml index dae22bb5c..186212f97 100644 --- a/clash/tests/data/config/wg.yaml +++ b/clash/tests/data/config/wg.yaml @@ -44,15 +44,16 @@ experimental: proxies: - name: "wg" type: wireguard - server: 10.0.0.17 - port: 51820 - private-key: 2AS8PSccSenWrws5ExglmpwjVBub9Oy9X3zOlk6heHU= - ip: 192.168.2.2 - public-key: MAZPwQBniuXmQf5w8BwM3owlO7Kw07rzyZUXxOvsF3w= - allowed-ips: ['0.0.0.0/0'] + server: engage.cloudflareclient.com + port: 2408 + private-key: uIwDn4c7656E/1pHkJu23ZOe/4SuCnL+vL+jE2s4MHE= + ip: 172.16.0.2/32 + ipv6: 2606:4700:110:8e5e:fa1:3f30:c077:e17c/128 + public-key: bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo= + allowed-ips: ['0.0.0.0/0', '::/0'] remote-dns-resolve: true dns: - - 8.8.8.8 + - 1.1.1.1 udp: true diff --git a/clash_lib/src/proxy/converters/wireguard.rs b/clash_lib/src/proxy/converters/wireguard.rs index 957387867..b6e152a9a 100644 --- a/clash_lib/src/proxy/converters/wireguard.rs +++ b/clash_lib/src/proxy/converters/wireguard.rs @@ -1,3 +1,5 @@ +use ipnet::IpNet; + use crate::{ config::internal::proxy::OutboundWireguard, proxy::{ @@ -26,14 +28,31 @@ impl TryFrom<&OutboundWireguard> for AnyOutboundHandler { port: s.port, ip: s .ip - .parse() - .map_err(|x| Error::InvalidConfig(format!("invalid ip address: {}", x)))?, + .parse::() + .map(|x| match x.addr() { + std::net::IpAddr::V4(v4) => Ok(v4), + std::net::IpAddr::V6(_) => Err(Error::InvalidConfig( + "invalid ip address: put an v4 address here".to_owned(), + )), + }) + .map_err(|x| { + Error::InvalidConfig(format!("invalid ip address: {}, {}", x, s.ip)) + })??, ipv6: s .ipv6 .as_ref() - .map(|x| { - x.parse() - .map_err(|x| Error::InvalidConfig(format!("invalid ipv6 address: {}", x))) + .and_then(|x| { + x.parse::() + .map(|x| match x.addr() { + std::net::IpAddr::V4(_) => Err(Error::InvalidConfig( + "invalid ip address: put an v6 address here".to_owned(), + )), + std::net::IpAddr::V6(v6) => Ok(v6), + }) + .map_err(|e| { + Error::InvalidConfig(format!("invalid ipv6 address: {}, {}", e, x)) + }) + .ok() }) .transpose()?, private_key: s.private_key.to_owned(), diff --git a/clash_lib/src/proxy/wg/device.rs b/clash_lib/src/proxy/wg/device.rs index 34e6e3d9b..8bd92970b 100644 --- a/clash_lib/src/proxy/wg/device.rs +++ b/clash_lib/src/proxy/wg/device.rs @@ -1,12 +1,14 @@ use std::{ collections::{HashMap, VecDeque}, - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, time::Duration, }; use bytes::{BufMut, Bytes, BytesMut}; use futures::{SinkExt, StreamExt}; + +use rand::seq::SliceRandom; use smoltcp::{ iface::{Config, Interface, SocketHandle, SocketSet}, phy::Device, @@ -56,7 +58,8 @@ enum SenderType { } pub struct DeviceManager { - addr: IpAddr, // TODO: support ipv6 + addr: Ipv4Addr, + addr_v6: Option, resolver: ThreadSafeDNSResolver, dns_servers: Vec, @@ -74,7 +77,8 @@ pub struct DeviceManager { impl DeviceManager { pub fn new( - addr: IpAddr, + addr: Ipv4Addr, + addr_v6: Option, resolver: ThreadSafeDNSResolver, dns_servers: Vec, packet_notifier: Receiver<()>, @@ -89,6 +93,8 @@ impl DeviceManager { Self { addr, + addr_v6, + resolver, dns_servers, @@ -131,47 +137,99 @@ impl DeviceManager { pub async fn look_up_dns(&self, host: &str, server: SocketAddr) -> Option { debug!("looking up {} on {}", host, server); - let mut socket = Self::new_udp_socket(self).await; - let mut msg = hickory_proto::op::Message::new(); - - msg.add_query({ - let mut q = hickory_proto::op::Query::new(); - let name = hickory_proto::rr::Name::from_str_relaxed(host) - .unwrap() - .append_domain(&hickory_proto::rr::Name::root()) - .unwrap(); - q.set_name(name); - q.set_query_type(hickory_proto::rr::RecordType::A); - q - }); - - msg.set_recursion_desired(true); - let pkt = UdpPacket::new(msg.to_vec().unwrap(), SocksAddr::any_ipv4(), server.into()); - - socket.feed(pkt).await.ok()?; - socket.flush().await.ok()?; - trace!("sent dns query: {:?}", msg); + #[async_recursion::async_recursion] + async fn query( + rtype: hickory_proto::rr::RecordType, + host: &str, + server: SocketAddr, + mut socket: UdpPair, + ) -> Option { + let mut msg = hickory_proto::op::Message::new(); + + msg.add_query({ + let mut q = hickory_proto::op::Query::new(); + let name = hickory_proto::rr::Name::from_str_relaxed(host) + .unwrap() + .append_domain(&hickory_proto::rr::Name::root()) + .unwrap(); + q.set_name(name); + q.set_query_type(rtype); + q + }); + + msg.set_recursion_desired(true); + + let pkt = UdpPacket::new(msg.to_vec().unwrap(), SocksAddr::any_ipv4(), server.into()); + + socket.feed(pkt).await.ok()?; + socket.flush().await.ok()?; + trace!("sent dns query: {:?}", msg); + + let pkt = match tokio::time::timeout(Duration::from_secs(5), socket.next()).await { + Ok(Some(pkt)) => pkt, + _ => { + warn!("wg dns query timed out with server {server}"); + return None; + } + }; + + let msg = hickory_proto::op::Message::from_vec(&pkt.data).ok()?; + trace!("got dns response: {:?}", msg); + for ans in msg.answers().iter() { + if ans.record_type() == rtype { + if let Some(data) = ans.data() { + match (rtype, data) { + (_, hickory_proto::rr::RData::CNAME(cname)) => { + debug!( + "{} resolved to CNAME {}, asking recursively", + host, cname.0 + ); + return query(rtype, &cname.0.to_ascii(), server, socket).await; + } + ( + hickory_proto::rr::RecordType::A, + hickory_proto::rr::RData::A(addr), + ) => { + return Some(std::net::IpAddr::V4(addr.0)); + } + ( + hickory_proto::rr::RecordType::AAAA, + hickory_proto::rr::RData::AAAA(addr), + ) => { + return Some(std::net::IpAddr::V6(addr.0)); + } + _ => return None, + } + }; + } + } + None + } - let pkt = match tokio::time::timeout(Duration::from_secs(5), socket.next()).await { - Ok(Some(pkt)) => pkt, - _ => { - warn!("wg dns query timed out with server {server}"); - return None; + let socket = self.new_udp_socket().await; + let v4_query = query(hickory_proto::rr::RecordType::A, host, server, socket); + if self.addr_v6.is_some() { + let socket = self.new_udp_socket().await; + let v6_query = query(hickory_proto::rr::RecordType::AAAA, host, server, socket); + match tokio::time::timeout( + Duration::from_secs(5), + futures::future::join(v4_query, v6_query), + ) + .await + { + Ok((_, Some(v6))) => Some(v6), + Ok((v4, _)) => v4, + _ => { + warn!("wg dns query timed out with server {server}"); + None + } } - }; - - let msg = hickory_proto::op::Message::from_vec(&pkt.data).ok()?; - trace!("got dns response: {:?}", msg); - msg.answers() - .iter() - .find_map(|ans| match ans.record_type() { - hickory_proto::rr::RecordType::A => ans.data().and_then(|data| match data { - hickory_proto::rr::RData::A(addr) => Some(std::net::IpAddr::V4(addr.0)), - _ => None, - }), - _ => None, - }) + } else { + tokio::time::timeout(Duration::from_secs(5), v4_query) + .await + .ok()? + } } pub async fn poll_sockets(&self, mut device: VirtualIpDevice) { @@ -181,6 +239,10 @@ impl DeviceManager { let mut iface = Interface::new(config, &mut device, Instant::now()); iface.update_ip_addrs(|addrs| { addrs.push(IpCidr::new(self.addr.into(), 32)).unwrap(); + + if let Some(addr_v6) = self.addr_v6 { + addrs.push(IpCidr::new(addr_v6.into(), 128)).unwrap(); + } }); let (device_sender, mut device_receiver) = tokio::sync::mpsc::channel(1024); @@ -206,7 +268,10 @@ impl DeviceManager { .connect( iface.context(), remote, - (self.addr, self.get_ephemeral_tcp_port().await), + (match remote { + SocketAddr::V4(_) => IpAddr::V4(self.addr), + SocketAddr::V6(_) => IpAddr::V6(self.addr_v6.unwrap()), + }, self.get_ephemeral_tcp_port().await), ) .unwrap(); @@ -234,13 +299,7 @@ impl DeviceManager { socket_pairs.insert(handle, SenderType::Tcp(sender)); tcp_queue.insert(handle, VecDeque::new()); } - Socket::Udp(mut socket, sender, mut receiver) => { - socket - .bind( - (self.addr, self.get_ephemeral_udp_port().await), - ) - .unwrap(); - + Socket::Udp(socket, sender, mut receiver) => { let handle = sockets.add(socket); let device_sender = device_sender.clone(); @@ -401,29 +460,47 @@ impl DeviceManager { let ip = match &pkt.dst_addr { SocksAddr::Ip(addr) => addr.ip(), SocksAddr::Domain(domain, _) => { - if let Some(dns_server) = self.dns_servers.get(0) { - let ip = self.look_up_dns(domain, *dns_server).await; - if let Some(ip) = ip { - debug!("host {} resolved to {} on wg stack", domain, ip); - ip - } else { - warn!("failed to resolve domain on wireguard: {}", domain); - continue; - } + if let Ok(ip) = domain.parse::() { + ip } else { - match self.resolver.resolve(domain, false).await { - Ok(Some(ip)) => { - debug!("host {} resolved to {} on local", domain, ip); + let dns_server = self.dns_servers.choose(&mut rand::thread_rng()); + if let Some(dns_server) = dns_server { + let ip = self.look_up_dns(domain, *dns_server).await; + if let Some(ip) = ip { + debug!("host {} resolved to {} on wg stack", domain, ip); ip - } - _ => { + } else { warn!("failed to resolve domain on wireguard: {}", domain); continue; } + } else { + match self.resolver.resolve(domain, false).await { + Ok(Some(ip)) => { + debug!("host {} resolved to {} on local", domain, ip); + ip + } + _ => { + warn!("failed to resolve domain on wireguard: {}", domain); + continue; + } + } } } } }; + + if !socket.is_open() { + let local_addr: IpAddr = match ip { + IpAddr::V4(_) => self.addr.into(), + IpAddr::V6(_) => self.addr_v6.unwrap().into(), + }; + socket + .bind( + (local_addr, self.get_ephemeral_udp_port().await), + ) + .unwrap(); + } + match socket.send_slice(&pkt.data, (ip, pkt.dst_addr.port())) { Ok(_) => {} Err(e) => { diff --git a/clash_lib/src/proxy/wg/mod.rs b/clash_lib/src/proxy/wg/mod.rs index 774fd3f0e..22ea0d817 100644 --- a/clash_lib/src/proxy/wg/mod.rs +++ b/clash_lib/src/proxy/wg/mod.rs @@ -27,6 +27,7 @@ use async_trait::async_trait; use futures::TryFutureExt; use ipnet::IpNet; +use rand::seq::SliceRandom; use tokio::sync::OnceCell; use tracing::debug; @@ -104,8 +105,6 @@ impl Handler { .transpose()? .unwrap_or_default(); - debug!("allowed_ips: {:?}", allowed_ips); - // we shouldn't create a new tunnel for each connection let wg = wireguard::WireguardTunnel::new( Config { @@ -123,11 +122,8 @@ impl Handler { .as_ref() .map(|s| s.parse::().unwrap().0.into()), remote_endpoint: (server_ip, self.opts.port).into(), - source_peer_ip: self - .opts - .ipv6 - .map(|ip| ip.into()) - .unwrap_or(self.opts.ip.into()), + source_peer_ip: self.opts.ip, + source_peer_ipv6: self.opts.ipv6, keepalive_seconds: Some(10), allowed_ips, }, @@ -152,7 +148,8 @@ impl Handler { ); let device_manager = Arc::new(device::DeviceManager::new( - self.opts.ip.into(), + self.opts.ip, + self.opts.ipv6, resolver, if self.opts.remote_dns_resolve { self.opts @@ -223,22 +220,19 @@ impl OutboundHandler for Handler { "use remote dns to resolve domain: {}", sess.destination.host() ); + let server = self + .opts + .dns + .as_ref() + .unwrap() + .choose(&mut rand::thread_rng()) + .unwrap(); + inner .device_manager .look_up_dns( &sess.destination.host(), - ( - self.opts - .dns - .as_ref() - .unwrap() - .first() - .unwrap() - .parse::() - .unwrap(), - 53, - ) - .into(), + (server.parse::().unwrap(), 53).into(), ) .await .ok_or(new_io_error("invalid remote address"))? diff --git a/clash_lib/src/proxy/wg/wireguard.rs b/clash_lib/src/proxy/wg/wireguard.rs index d0f7c1a37..4eee6be91 100644 --- a/clash_lib/src/proxy/wg/wireguard.rs +++ b/clash_lib/src/proxy/wg/wireguard.rs @@ -28,7 +28,8 @@ use crate::{proxy::utils::new_udp_socket, Error}; use super::events::PortProtocol; pub struct WireguardTunnel { - pub(crate) source_peer_ip: IpAddr, + pub(crate) source_peer_ip: Ipv4Addr, + pub(crate) source_peer_ipv6: Option, peer: Arc>, udp: UdpSocket, pub(crate) endpoint: SocketAddr, @@ -54,7 +55,8 @@ pub struct Config { pub endpoint_public_key: PublicKey, pub preshared_key: Option, pub remote_endpoint: SocketAddr, - pub source_peer_ip: IpAddr, + pub source_peer_ip: Ipv4Addr, + pub source_peer_ipv6: Option, pub keepalive_seconds: Option, pub allowed_ips: Vec, } @@ -65,7 +67,6 @@ impl WireguardTunnel { packet_writer: Sender<(PortProtocol, Bytes)>, packet_reader: Receiver, ) -> Result { - let source_peer_ip = config.source_peer_ip; let peer = Tunn::new( config.private_key, config.endpoint_public_key, @@ -87,7 +88,8 @@ impl WireguardTunnel { .await?; Ok(Self { - source_peer_ip, + source_peer_ip: config.source_peer_ip, + source_peer_ipv6: config.source_peer_ipv6, peer: Arc::new(Mutex::new(peer)), udp, endpoint: remote_endpoint, @@ -332,7 +334,7 @@ impl WireguardTunnel { }), Ok(IpVersion::Ipv6) => Ipv6Packet::new_checked(&packet) .ok() - .filter(|packet| Ipv6Addr::from(packet.dst_addr()) == self.source_peer_ip) + .filter(|packet| Some(Ipv6Addr::from(packet.dst_addr())) == self.source_peer_ipv6) .and_then(|packet| { match packet.next_header() { IpProtocol::Tcp => Some(PortProtocol::Tcp),