Skip to content

Commit

Permalink
Somewhat working encryption side
Browse files Browse the repository at this point in the history
  • Loading branch information
Hasan6979 committed Nov 14, 2024
1 parent 49b5a5d commit 89b0ce1
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 136 deletions.
120 changes: 72 additions & 48 deletions boringtun/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub mod tun;
#[path = "tun_linux.rs"]
pub mod tun;

use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::io::{self, Write as _};
use std::mem::MaybeUninit;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
Expand All @@ -38,20 +38,14 @@ use std::thread::JoinHandle;
use crate::noise::errors::WireGuardError;
use crate::noise::handshake::parse_handshake_anon;
use crate::noise::rate_limiter::RateLimiter;
use crate::noise::session::{
self, EncryptionTaskData, NetworkTaskData, Session, ENCRYPTED_RING_BUFFER,
PLAINTEXT_RING_BUFFER,
};
use crate::noise::{Packet, Tunn, TunnResult};
use crate::noise::session::{Session, ENCRYPTED_RING_BUFFER, PLAINTEXT_RING_BUFFER, RB_SIZE};
use crate::noise::{NeptunResult, Packet, Tunn, TunnResult};
use crate::x25519;
use allowed_ips::AllowedIps;
use async_channel::{Receiver, Sender};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use peer::{AllowedIP, Peer};
use poll::{EventPoll, EventRef, WaitResult};
use rand_core::{OsRng, RngCore};
use ring::aead::{LessSafeKey, UnboundKey, CHACHA20_POLY1305};
use socket2::{Domain, Protocol, Socket, Type};
use tun::TunSocket;

Expand Down Expand Up @@ -165,12 +159,15 @@ pub struct Device {

rate_limiter: Option<Arc<RateLimiter>>,

rx: Receiver<()>,
tx: Sender<()>,
encyrpt_tx: Sender<usize>,
network_rx: Receiver<()>,
network_tx: Sender<()>,
#[cfg(target_os = "linux")]
uapi_fd: i32,
}

static mut ITER: usize = 0;

struct ThreadData {
iface: Arc<TunSocket>,
src_buf: [u8; MAX_UDP_SIZE],
Expand Down Expand Up @@ -371,8 +368,8 @@ impl Device {
let uapi_fd = -1;
#[cfg(target_os = "linux")]
let uapi_fd = config.uapi_fd;
let (tx, rx) = async_channel::bounded(1024);
let (tx1, rx1) = async_channel::bounded(1024);
let (encyrpt_tx, encrypt_rx) = async_channel::bounded(1024);
let (network_tx, network_rx) = async_channel::bounded(1024);

let mut device = Device {
queue: Arc::new(poll),
Expand All @@ -392,8 +389,9 @@ impl Device {
cleanup_paths: Default::default(),
mtu: AtomicUsize::new(mtu),
rate_limiter: None,
rx: rx.clone(),
tx: tx.clone(),
encyrpt_tx: encyrpt_tx.clone(),
network_tx: network_tx.clone(),
network_rx,
#[cfg(target_os = "linux")]
uapi_fd,
};
Expand All @@ -407,8 +405,8 @@ impl Device {
device.register_notifiers()?;
device.register_timers()?;

let rx_clone = rx.clone();
std::thread::spawn(move || Session::encrypt_data_worker(rx_clone, tx1));
let rx_clone = encrypt_rx.clone();
std::thread::spawn(move || Session::encrypt_data_worker(rx_clone, network_tx));

#[cfg(target_os = "macos")]
{
Expand Down Expand Up @@ -463,10 +461,10 @@ impl Device {
self.udp4 = Some(udp_sock4.try_clone().unwrap());
self.udp6 = Some(udp_sock6.try_clone().unwrap());
// Send to network in a seperate thread
let rx_clone = self.rx.clone();
let rx_clone = self.network_rx.clone();
let uv4 = Arc::new(udp_sock4).clone();
let uv6 = Arc::new(udp_sock6).clone();
// std::thread::spawn(move || send_to_network(rx_clone, uv4, uv6));
std::thread::spawn(move || send_to_network(rx_clone, uv4, uv6));
self.listen_port = port;

Ok(())
Expand Down Expand Up @@ -823,7 +821,8 @@ impl Device {

let peers = &d.peers_by_ip;
for _ in 0..MAX_ITR {
if let Some(element) = unsafe { PLAINTEXT_RING_BUFFER.pop_front() } {
if let Some(element) = unsafe { PLAINTEXT_RING_BUFFER.get_mut(ITER) } {
let mut is_handshake_msg = false;
{
let mut item = element.lock();
let src = match iface.read(&mut item.data[..mtu]) {
Expand Down Expand Up @@ -871,20 +870,39 @@ impl Device {
{
let mut dst = entry.lock();
dst.peer = Some(peer.clone());
dst.res = tun.format_handshake_initiation(
let res = tun.format_handshake_initiation(
dst.data.as_mut_slice(),
false,
true,
);
match res {
NeptunResult::Done => dst.res = NeptunResult::Done,
NeptunResult::Err(e) => {
dst.res = NeptunResult::Err(e)
}
NeptunResult::WriteToNetwork(n) => {
dst.res = NeptunResult::WriteToNetwork(n)
}
_ => continue,
}
};
unsafe { ENCRYPTED_RING_BUFFER.push_back(entry) };
d.tx.send_blocking(()); // change the channel
let _ = d.network_tx.send_blocking(());
is_handshake_msg = true;
}
}
};
}
unsafe { PLAINTEXT_RING_BUFFER.push_back(element) };
// unsafe { PLAINTEXT_RING_BUFFER.push_back(element) };
// Notify the encrypt part with channel!!
d.tx.send_blocking(());
if !is_handshake_msg {
let _ = d.encyrpt_tx.send_blocking(unsafe { ITER });
if unsafe { ITER != (RB_SIZE - 1) } {
unsafe { ITER += 1 };
} else {
// Reset the write iterator
unsafe { ITER = 0 };
}
}
continue;
// TODO: Q the packet
}
Expand All @@ -896,30 +914,36 @@ impl Device {
}
}

fn send_to_network(rx: Receiver<()>, udp4: Arc<Socket>, udp6: Arc<Socket>) {
if rx.recv_blocking().is_ok() {
if let Some(msg) = unsafe { ENCRYPTED_RING_BUFFER.pop_back() } {
let msg = msg.lock();
match &msg.res {
TunnResult::Done => {}
TunnResult::Err(e) => {
tracing::error!(message = "Encapsulate error", error = ?e)
}
TunnResult::WriteToNetwork(packet) => {
let mut endpoint = msg.peer.as_ref().unwrap().endpoint_mut();
if let Some(conn) = endpoint.conn.as_mut() {
// Prefer to send using the connected socket
let _: Result<_, _> = conn.write(packet);
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
let _: Result<_, _> = udp4.send_to(packet, &addr.into());
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
let _: Result<_, _> = udp6.send_to(packet, &addr.into());
} else {
tracing::error!("No endpoint");
fn send_to_network(network_rx: Receiver<()>, udp4: Arc<Socket>, udp6: Arc<Socket>) {
while network_rx.recv_blocking().is_ok() {
if let Some(elem) = unsafe { ENCRYPTED_RING_BUFFER.pop_back() } {
{
let msg = elem.lock();
match &msg.res {
NeptunResult::Done => {}
NeptunResult::Err(e) => {
tracing::error!(message = "Encapsulate error", error = ?e)
}
}
_ => panic!("Unexpected result from encapsulate"),
};
NeptunResult::WriteToNetwork(len) => {
let mut endpoint = msg.peer.as_ref().unwrap().endpoint_mut();
let packet = &msg.data.as_slice()[..(*len)];
if let Some(conn) = endpoint.conn.as_mut() {
// Prefer to send using the connected socket
let _: Result<_, _> = conn.write(packet);
} else if let Some(addr @ SocketAddr::V4(_)) = endpoint.addr {
let _: Result<_, _> = udp4.send_to(packet, &addr.into());
} else if let Some(addr @ SocketAddr::V6(_)) = endpoint.addr {
let _: Result<_, _> = udp6.send_to(packet, &addr.into());
} else {
tracing::error!("No endpoint");
}
}
_ => panic!("Unexpected result from encapsulate"),
};
}
unsafe {
ENCRYPTED_RING_BUFFER.push_front(elem);
}
}
}
}
Expand Down
51 changes: 31 additions & 20 deletions boringtun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ pub mod handshake;
pub mod rate_limiter;

pub mod session;
mod timers;
pub mod timers;

use once_cell::sync::Lazy;
use parking_lot::Mutex;
use ring::aead::{LessSafeKey, UnboundKey, CHACHA20_POLY1305};
use session::{Session, PLAINTEXT_RING_BUFFER};
use session::Session;

use crate::noise::errors::WireGuardError;
use crate::noise::handshake::Handshake;
Expand All @@ -22,7 +19,6 @@ use crate::x25519;
use std::collections::VecDeque;
use std::convert::{TryFrom, TryInto};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::time::Duration;

Expand Down Expand Up @@ -51,9 +47,9 @@ const N_SESSIONS: usize = 8;
pub enum TunnResult<'a> {
Done,
Err(WireGuardError),
WriteToNetwork(&'a [u8]),
WriteToTunnelV4(&'a [u8], Ipv4Addr),
WriteToTunnelV6(&'a [u8], Ipv6Addr),
WriteToNetwork(&'a mut [u8]),
WriteToTunnelV4(&'a mut [u8], Ipv4Addr),
WriteToTunnelV6(&'a mut [u8], Ipv6Addr),
}

impl<'a> From<WireGuardError> for TunnResult<'a> {
Expand All @@ -62,6 +58,21 @@ impl<'a> From<WireGuardError> for TunnResult<'a> {
}
}

#[derive(Debug)]
pub enum NeptunResult<'a> {
Done,
Err(WireGuardError),
WriteToNetwork(usize),
WriteToTunnelV4(&'a mut [u8], Ipv4Addr),
WriteToTunnelV6(&'a mut [u8], Ipv6Addr),
}

impl<'a> From<WireGuardError> for NeptunResult<'a> {
fn from(err: WireGuardError) -> NeptunResult<'a> {
NeptunResult::Err(err)
}
}

/// Tunnel represents a point-to-point WireGuard connection
pub struct Tunn {
/// The handshake currently in progress
Expand All @@ -74,7 +85,7 @@ pub struct Tunn {
packet_queue: VecDeque<Vec<u8>>,
/// Keeps tabs on the expiring timers
timers: timers::Timers,
tx_bytes: usize,
pub tx_bytes: usize,
rx_bytes: usize,
rate_limiter: Arc<RateLimiter>,
}
Expand Down Expand Up @@ -367,7 +378,7 @@ impl Tunn {

let session = self.handshake.receive_handshake_response(p)?;

let keepalive_packet = {
let (keepalive_packet, _) = {
Session::encrypt_data_pkt(
session.sending_key_counter.clone(),
session.sending_index,
Expand Down Expand Up @@ -425,11 +436,11 @@ impl Tunn {
}

/// Decrypts a data packet, and stores the decapsulated packet in dst.
fn handle_data(
fn handle_data<'a>(
&mut self,
packet: PacketData,
dst: &mut [u8],
) -> Result<TunnResult, WireGuardError> {
dst: &'a mut [u8],
) -> Result<TunnResult<'a>, WireGuardError> {
let r_idx = packet.receiver_idx as usize;
let idx = r_idx % N_SESSIONS;

Expand Down Expand Up @@ -462,9 +473,9 @@ impl Tunn {
&mut self,
dst: &'a mut [u8],
force_resend: bool,
) -> TunnResult<'a> {
) -> NeptunResult {
if self.handshake.is_in_progress() && !force_resend {
return TunnResult::Done;
return NeptunResult::Done;
}

if self.handshake.is_expired() {
Expand All @@ -481,9 +492,9 @@ impl Tunn {
self.timer_tick(TimerName::TimeLastHandshakeStarted);
}
self.timer_tick(TimerName::TimeLastPacketSent);
TunnResult::WriteToNetwork(packet)
NeptunResult::WriteToNetwork(packet.len())
}
Err(e) => TunnResult::Err(e),
Err(e) => NeptunResult::Err(e),
}
}

Expand Down Expand Up @@ -642,8 +653,8 @@ impl Tunn {
// fn create_handshake_init(tun: &mut Tunn) -> Vec<u8> {
// let mut dst = vec![0u8; 2048];
// let handshake_init = tun.format_handshake_initiation(&mut dst, false);
// assert!(matches!(handshake_init, TunnResult::WriteToNetwork(_)));
// let handshake_init = if let TunnResult::WriteToNetwork(sent) = handshake_init {
// assert!(matches!(handshake_init, NeptunResult::WriteToNetwork(_)));
// let handshake_init = if let NeptunResult::WriteToNetwork(sent) = handshake_init {
// sent
// } else {
// unreachable!();
Expand Down
Loading

0 comments on commit 89b0ce1

Please sign in to comment.