Skip to content

Commit

Permalink
enable ipv6 on wireguard (#234)
Browse files Browse the repository at this point in the history
* enable ipv6 on wg

* clippy
  • Loading branch information
ibigbug authored Dec 27, 2023
1 parent f76ef2e commit 6bce6f0
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 100 deletions.
15 changes: 8 additions & 7 deletions clash/tests/data/config/wg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ experimental:
proxies:
- name: "wg"
type: wireguard
server: 10.0.0.17
port: 51820
private-key: 2AS8PSccSenWrws5ExglmpwjVBub9Oy9X3zOlk6heHU=
ip: 192.168.2.2
public-key: MAZPwQBniuXmQf5w8BwM3owlO7Kw07rzyZUXxOvsF3w=
allowed-ips: ['0.0.0.0/0']
server: engage.cloudflareclient.com
port: 2408
private-key: uIwDn4c7656E/1pHkJu23ZOe/4SuCnL+vL+jE2s4MHE=
ip: 172.16.0.2/32
ipv6: 2606:4700:110:8e5e:fa1:3f30:c077:e17c/128
public-key: bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo=
allowed-ips: ['0.0.0.0/0', '::/0']
remote-dns-resolve: true
dns:
- 8.8.8.8
- 1.1.1.1
udp: true


Expand Down
29 changes: 24 additions & 5 deletions clash_lib/src/proxy/converters/wireguard.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use ipnet::IpNet;

use crate::{
config::internal::proxy::OutboundWireguard,
proxy::{
Expand Down Expand Up @@ -26,14 +28,31 @@ impl TryFrom<&OutboundWireguard> for AnyOutboundHandler {
port: s.port,
ip: s
.ip
.parse()
.map_err(|x| Error::InvalidConfig(format!("invalid ip address: {}", x)))?,
.parse::<IpNet>()
.map(|x| match x.addr() {
std::net::IpAddr::V4(v4) => Ok(v4),
std::net::IpAddr::V6(_) => Err(Error::InvalidConfig(
"invalid ip address: put an v4 address here".to_owned(),
)),
})
.map_err(|x| {
Error::InvalidConfig(format!("invalid ip address: {}, {}", x, s.ip))
})??,
ipv6: s
.ipv6
.as_ref()
.map(|x| {
x.parse()
.map_err(|x| Error::InvalidConfig(format!("invalid ipv6 address: {}", x)))
.and_then(|x| {
x.parse::<IpNet>()
.map(|x| match x.addr() {
std::net::IpAddr::V4(_) => Err(Error::InvalidConfig(
"invalid ip address: put an v6 address here".to_owned(),
)),
std::net::IpAddr::V6(v6) => Ok(v6),
})
.map_err(|e| {
Error::InvalidConfig(format!("invalid ipv6 address: {}, {}", e, x))
})
.ok()
})
.transpose()?,
private_key: s.private_key.to_owned(),
Expand Down
203 changes: 140 additions & 63 deletions clash_lib/src/proxy/wg/device.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::{
collections::{HashMap, VecDeque},
net::{IpAddr, SocketAddr},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::Arc,
time::Duration,
};

use bytes::{BufMut, Bytes, BytesMut};
use futures::{SinkExt, StreamExt};

use rand::seq::SliceRandom;
use smoltcp::{
iface::{Config, Interface, SocketHandle, SocketSet},
phy::Device,
Expand Down Expand Up @@ -56,7 +58,8 @@ enum SenderType {
}

pub struct DeviceManager {
addr: IpAddr, // TODO: support ipv6
addr: Ipv4Addr,
addr_v6: Option<Ipv6Addr>,
resolver: ThreadSafeDNSResolver,
dns_servers: Vec<SocketAddr>,

Expand All @@ -74,7 +77,8 @@ pub struct DeviceManager {

impl DeviceManager {
pub fn new(
addr: IpAddr,
addr: Ipv4Addr,
addr_v6: Option<Ipv6Addr>,
resolver: ThreadSafeDNSResolver,
dns_servers: Vec<SocketAddr>,
packet_notifier: Receiver<()>,
Expand All @@ -89,6 +93,8 @@ impl DeviceManager {

Self {
addr,
addr_v6,

resolver,
dns_servers,

Expand Down Expand Up @@ -131,47 +137,99 @@ impl DeviceManager {

pub async fn look_up_dns(&self, host: &str, server: SocketAddr) -> Option<IpAddr> {
debug!("looking up {} on {}", host, server);
let mut socket = Self::new_udp_socket(self).await;
let mut msg = hickory_proto::op::Message::new();

msg.add_query({
let mut q = hickory_proto::op::Query::new();
let name = hickory_proto::rr::Name::from_str_relaxed(host)
.unwrap()
.append_domain(&hickory_proto::rr::Name::root())
.unwrap();
q.set_name(name);
q.set_query_type(hickory_proto::rr::RecordType::A);
q
});

msg.set_recursion_desired(true);

let pkt = UdpPacket::new(msg.to_vec().unwrap(), SocksAddr::any_ipv4(), server.into());

socket.feed(pkt).await.ok()?;
socket.flush().await.ok()?;
trace!("sent dns query: {:?}", msg);
#[async_recursion::async_recursion]
async fn query(
rtype: hickory_proto::rr::RecordType,
host: &str,
server: SocketAddr,
mut socket: UdpPair,
) -> Option<IpAddr> {
let mut msg = hickory_proto::op::Message::new();

msg.add_query({
let mut q = hickory_proto::op::Query::new();
let name = hickory_proto::rr::Name::from_str_relaxed(host)
.unwrap()
.append_domain(&hickory_proto::rr::Name::root())
.unwrap();
q.set_name(name);
q.set_query_type(rtype);
q
});

msg.set_recursion_desired(true);

let pkt = UdpPacket::new(msg.to_vec().unwrap(), SocksAddr::any_ipv4(), server.into());

socket.feed(pkt).await.ok()?;
socket.flush().await.ok()?;
trace!("sent dns query: {:?}", msg);

let pkt = match tokio::time::timeout(Duration::from_secs(5), socket.next()).await {
Ok(Some(pkt)) => pkt,
_ => {
warn!("wg dns query timed out with server {server}");
return None;
}
};

let msg = hickory_proto::op::Message::from_vec(&pkt.data).ok()?;
trace!("got dns response: {:?}", msg);
for ans in msg.answers().iter() {
if ans.record_type() == rtype {
if let Some(data) = ans.data() {
match (rtype, data) {
(_, hickory_proto::rr::RData::CNAME(cname)) => {
debug!(
"{} resolved to CNAME {}, asking recursively",
host, cname.0
);
return query(rtype, &cname.0.to_ascii(), server, socket).await;
}
(
hickory_proto::rr::RecordType::A,
hickory_proto::rr::RData::A(addr),
) => {
return Some(std::net::IpAddr::V4(addr.0));
}
(
hickory_proto::rr::RecordType::AAAA,
hickory_proto::rr::RData::AAAA(addr),
) => {
return Some(std::net::IpAddr::V6(addr.0));
}
_ => return None,
}
};
}
}
None
}

let pkt = match tokio::time::timeout(Duration::from_secs(5), socket.next()).await {
Ok(Some(pkt)) => pkt,
_ => {
warn!("wg dns query timed out with server {server}");
return None;
let socket = self.new_udp_socket().await;
let v4_query = query(hickory_proto::rr::RecordType::A, host, server, socket);
if self.addr_v6.is_some() {
let socket = self.new_udp_socket().await;
let v6_query = query(hickory_proto::rr::RecordType::AAAA, host, server, socket);
match tokio::time::timeout(
Duration::from_secs(5),
futures::future::join(v4_query, v6_query),
)
.await
{
Ok((_, Some(v6))) => Some(v6),
Ok((v4, _)) => v4,
_ => {
warn!("wg dns query timed out with server {server}");
None
}
}
};

let msg = hickory_proto::op::Message::from_vec(&pkt.data).ok()?;
trace!("got dns response: {:?}", msg);
msg.answers()
.iter()
.find_map(|ans| match ans.record_type() {
hickory_proto::rr::RecordType::A => ans.data().and_then(|data| match data {
hickory_proto::rr::RData::A(addr) => Some(std::net::IpAddr::V4(addr.0)),
_ => None,
}),
_ => None,
})
} else {
tokio::time::timeout(Duration::from_secs(5), v4_query)
.await
.ok()?
}
}

pub async fn poll_sockets(&self, mut device: VirtualIpDevice) {
Expand All @@ -181,6 +239,10 @@ impl DeviceManager {
let mut iface = Interface::new(config, &mut device, Instant::now());
iface.update_ip_addrs(|addrs| {
addrs.push(IpCidr::new(self.addr.into(), 32)).unwrap();

if let Some(addr_v6) = self.addr_v6 {
addrs.push(IpCidr::new(addr_v6.into(), 128)).unwrap();
}
});

let (device_sender, mut device_receiver) = tokio::sync::mpsc::channel(1024);
Expand All @@ -206,7 +268,10 @@ impl DeviceManager {
.connect(
iface.context(),
remote,
(self.addr, self.get_ephemeral_tcp_port().await),
(match remote {
SocketAddr::V4(_) => IpAddr::V4(self.addr),
SocketAddr::V6(_) => IpAddr::V6(self.addr_v6.unwrap()),
}, self.get_ephemeral_tcp_port().await),
)
.unwrap();

Expand Down Expand Up @@ -234,13 +299,7 @@ impl DeviceManager {
socket_pairs.insert(handle, SenderType::Tcp(sender));
tcp_queue.insert(handle, VecDeque::new());
}
Socket::Udp(mut socket, sender, mut receiver) => {
socket
.bind(
(self.addr, self.get_ephemeral_udp_port().await),
)
.unwrap();

Socket::Udp(socket, sender, mut receiver) => {
let handle = sockets.add(socket);

let device_sender = device_sender.clone();
Expand Down Expand Up @@ -401,29 +460,47 @@ impl DeviceManager {
let ip = match &pkt.dst_addr {
SocksAddr::Ip(addr) => addr.ip(),
SocksAddr::Domain(domain, _) => {
if let Some(dns_server) = self.dns_servers.get(0) {
let ip = self.look_up_dns(domain, *dns_server).await;
if let Some(ip) = ip {
debug!("host {} resolved to {} on wg stack", domain, ip);
ip
} else {
warn!("failed to resolve domain on wireguard: {}", domain);
continue;
}
if let Ok(ip) = domain.parse::<IpAddr>() {
ip
} else {
match self.resolver.resolve(domain, false).await {
Ok(Some(ip)) => {
debug!("host {} resolved to {} on local", domain, ip);
let dns_server = self.dns_servers.choose(&mut rand::thread_rng());
if let Some(dns_server) = dns_server {
let ip = self.look_up_dns(domain, *dns_server).await;
if let Some(ip) = ip {
debug!("host {} resolved to {} on wg stack", domain, ip);
ip
}
_ => {
} else {
warn!("failed to resolve domain on wireguard: {}", domain);
continue;
}
} else {
match self.resolver.resolve(domain, false).await {
Ok(Some(ip)) => {
debug!("host {} resolved to {} on local", domain, ip);
ip
}
_ => {
warn!("failed to resolve domain on wireguard: {}", domain);
continue;
}
}
}
}
}
};

if !socket.is_open() {
let local_addr: IpAddr = match ip {
IpAddr::V4(_) => self.addr.into(),
IpAddr::V6(_) => self.addr_v6.unwrap().into(),
};
socket
.bind(
(local_addr, self.get_ephemeral_udp_port().await),
)
.unwrap();
}

match socket.send_slice(&pkt.data, (ip, pkt.dst_addr.port())) {
Ok(_) => {}
Err(e) => {
Expand Down
Loading

0 comments on commit 6bce6f0

Please sign in to comment.