diff --git a/Cargo.toml b/Cargo.toml index c15b87f..7f16e1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ repository = "https://github.com/PocketRelay/PocketRelayClientShared" [dependencies] +# Shared UDP tunnel protocol +pocket-relay-udp-tunnel = { version = "0" } + # Logging log = "0.4" diff --git a/src/servers/udp_tunnel.rs b/src/servers/udp_tunnel.rs index a966a63..a8ac36f 100644 --- a/src/servers/udp_tunnel.rs +++ b/src/servers/udp_tunnel.rs @@ -1,10 +1,17 @@ -use self::codec::TunnelMessage; +//! UDP Tunneling server +//! +//! Provides a local tunnel that connects clients by tunneling through the Pocket Relay +//! server. This allows clients with more strict NATs to host games without common issues +//! faced when trying to connect. This is the faster UDP implementation + use crate::{ ctx::ClientContext, servers::{spawn_server_task, GAME_HOST_PORT, RANDOM_PORT, TUNNEL_HOST_PORT}, }; -use codec::{MessageHeader, MessageReader, MessageWriter}; use log::{debug, error}; +use pocket_relay_udp_tunnel::{ + deserialize_message, serialize_message, MessageError, TunnelMessage, +}; use std::{ future::Future, io::ErrorKind, @@ -14,7 +21,14 @@ use std::{ task::{ready, Context, Poll}, time::Duration, }; -use tokio::{io::ReadBuf, net::UdpSocket, sync::mpsc, time::sleep, try_join}; +use thiserror::Error; +use tokio::{ + io::ReadBuf, + net::UdpSocket, + sync::mpsc, + time::{sleep, timeout}, + try_join, +}; /// The fixed size of socket pool to use const SOCKET_POOL_SIZE: usize = 4; @@ -25,12 +39,65 @@ const MAX_ERROR_ATTEMPTS: u8 = 5; static LOCAL_SEND_TARGET: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, GAME_HOST_PORT)); +/// Errors that can occur while creating a UDP tunnel +#[derive(Debug, Error)] +pub enum UdpTunnelError { + /// The base URL of the server is not compatible with the UDP tunnel + /// variant (It missing the host portion?) + #[error("host url incompatible with UDP tunnel")] + HostIncompatible, + + /// Server version does not support the UDP tunnel or the server has + /// explicitly disabled the tunnel + #[error("server incompatible with UDP tunnel")] + ServerIncompatible, + + /// Failed to bind the tunnel socket + #[error(transparent)] + Bind(std::io::Error), + + /// Failed to "connect" to the target server, happens when the host + /// is unreachable or DNS resolution fails + #[error(transparent)] + Connect(std::io::Error), + + /// Reached timeout while attempting to complete handshake + #[error("timeout reached while handshaking")] + HandshakeTimeout, + + /// Some generic IO error occurred while reading or writing + #[error(transparent)] + GenericIo(#[from] std::io::Error), + + /// Received malformed packet when creating the tunnel + #[error("malformed packet: {0}")] + MalformedPacket(#[from] MessageError), + + /// Got an unexpected packet during the handshake process + #[error("unexpected packet while handshaking")] + UnexpectedPacket, + + /// Failed to allocate the local socket pool + #[error(transparent)] + AllocateSocketPool(std::io::Error), +} + /// Starts the tunnel socket pool and creates the tunnel /// connection to the server /// /// ## Arguments -/// * `ctx` - The client context -pub async fn start_tunnel_server_v2(ctx: Arc) -> std::io::Result<()> { +/// * `ctx` - The client context +/// * `tunnel_port` - The UDP tunnel server port to connect to +pub async fn start_udp_tunnel_server( + ctx: Arc, + tunnel_port: u16, +) -> std::io::Result<()> { + let host = match ctx.base_url.host() { + Some(value) => value.to_string(), + // Cannot form a tunnel without a host + None => return Ok(()), + }; + let association = match Option::as_ref(&ctx.association) { Some(value) => value, // Don't try and tunnel without a token @@ -38,14 +105,15 @@ pub async fn start_tunnel_server_v2(ctx: Arc) -> std::io::Result< }; // Last encountered error - let mut last_error: Option = None; + let mut last_error: Option = None; // Number of attempts that errored let mut attempt_errors: u8 = 0; // Looping to attempt reconnecting if lost while attempt_errors < MAX_ERROR_ATTEMPTS { // Create the tunnel (Future will end if tunnel stopped) - let reconnect_time = if let Err(err) = create_tunnel(ctx.clone(), association).await { + let reconnect_time = if let Err(err) = create_tunnel(&host, tunnel_port, association).await + { error!("Failed to create tunnel: {}", err); // Set last error @@ -73,100 +141,47 @@ pub async fn start_tunnel_server_v2(ctx: Arc) -> std::io::Result< tokio::time::sleep(reconnect_time).await; } - Err(last_error.unwrap_or(std::io::Error::new( - ErrorKind::Other, - "Reached error connect limit", - ))) -} - -async fn handshake_tunnel(socket: &UdpSocket, association: &str) -> std::io::Result { - let mut writer = MessageWriter::default(); - let header = MessageHeader { - tunnel_id: u32::MAX, - version: 0, - }; - - let message = TunnelMessage::Initiate { - association_token: association.to_string(), - }; - header.write(&mut writer); - message.write(&mut writer); - - socket.send(&writer.buffer).await?; - - let mut buffer = [0u8; u16::MAX as usize]; - - let count = socket.recv(&mut buffer).await?; - let buffer = &buffer[..count]; - let mut reader = MessageReader::new(buffer); - - let header = MessageHeader::read(&mut reader) - .map_err(|_| std::io::Error::new(ErrorKind::Other, "Malformed packet"))?; - let message = TunnelMessage::read(&mut reader) - .map_err(|_| std::io::Error::new(ErrorKind::Other, "Malformed packet"))?; - - match message { - TunnelMessage::Initiated { tunnel_id } => Ok(tunnel_id), - _ => Err(std::io::Error::new(ErrorKind::Other, "Unexpected packet")), - } + Err(last_error + .map(|err| std::io::Error::new(ErrorKind::Other, err)) + .unwrap_or(std::io::Error::new( + ErrorKind::Other, + "Reached error connect limit", + ))) } /// Creates a new tunnel /// /// ## Arguments -/// * `ctx` - The client context +/// * `host` - The host for connecting the tunnel +/// * `tunnel_port` - The port the tunnel is running on /// * `association` - The client association token -async fn create_tunnel(ctx: Arc, association: &str) -> std::io::Result<()> { - let host = match ctx.base_url.host() { - Some(value) => value, - // Cannot form a tunnel without a host - None => return Ok(()), - }; - - let tunnel_port = match ctx.tunnel_port { - Some(value) => value, - // Cannot form a tunnel without a port - None => return Ok(()), - }; - +async fn create_tunnel( + host: &str, + tunnel_port: u16, + association: &str, +) -> Result<(), UdpTunnelError> { // Bind a local udp socket - let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await?; + let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)) + .await + .map_err(UdpTunnelError::Bind)?; // Map connection to remote tunnel server - socket.connect((host.to_string(), tunnel_port)).await?; - - debug!("Initiating tunnel: {}:{}", host, tunnel_port); + socket + .connect((host, tunnel_port)) + .await + .map_err(UdpTunnelError::Connect)?; - let tunnel_id: u32; + debug!("initiating tunnel: {}:{}", host, tunnel_port); - { - let mut retry_count = 0; - - loop { - match handshake_tunnel(&socket, association).await { - Ok(value) => { - tunnel_id = value; - break; - } - Err(err) => { - error!("failed to handshake for token: {}", err); + let tunnel_id = attempt_tunnel_handshake(&socket, association).await?; - retry_count += 1; - sleep(Duration::from_secs(5 * retry_count)).await; - - if retry_count > 5 { - return Err(err); - } - } - } - } - } - - debug!("Created server tunnel"); + debug!("created server tunnel: {}", tunnel_id); // Allocate the socket pool for the tunnel let (tx, rx) = mpsc::unbounded_channel(); - let pool = Socket::allocate_pool(tx).await?; + let pool = Socket::allocate_pool(tx) + .await + .map_err(UdpTunnelError::AllocateSocketPool)?; debug!("Allocated tunnel pool"); // Start the tunnel @@ -180,9 +195,85 @@ async fn create_tunnel(ctx: Arc, association: &str) -> std::io::R } .await; + // TODO: Handle connection lost + Ok(()) } +// Maximum number of times to try and handshake +const MAX_HANDSHAKE_ATTEMPTS: u8 = 5; + +// Time to elapse without a response before the handshake is considered timed out +const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5); + +/// Attempts to complete a tunnel handshake, will retry until +async fn attempt_tunnel_handshake( + socket: &UdpSocket, + association: &str, +) -> Result { + let mut retry_count: u8 = 0; + let mut retry_delay: u64 = 5; + let mut last_err: UdpTunnelError; + + loop { + match timeout(HANDSHAKE_TIMEOUT, handshake_tunnel(socket, association)).await { + // Successful handshake + Ok(Ok(value)) => return Ok(value), + + // Got an error while processing the handshake + Ok(Err(err)) => { + error!("failed to handshake for token: {}", err); + last_err = err + } + // Handshaking process timed out + Err(_) => { + error!("timeout while attempting tunnel handshake"); + last_err = UdpTunnelError::HandshakeTimeout + } + } + + retry_count += 1; + + // Wait between attempts with exponential backoff + sleep(Duration::from_secs(retry_delay)).await; + retry_delay *= 2; + + if retry_count > MAX_HANDSHAKE_ATTEMPTS { + return Err(last_err); + } + } +} + +/// Completes a tunnel handshake over the provided socket, exchanges +/// the association token for a tunnel ID to use on future connections +async fn handshake_tunnel(socket: &UdpSocket, association: &str) -> Result { + // Serialize and write the initiate message + let buffer = serialize_message( + u32::MAX, + &TunnelMessage::Initiate { + association_token: association.to_string(), + }, + ); + + socket.send(&buffer).await?; + + // Allocate buffer and read message + let mut buffer = [0u8; u16::MAX as usize]; + let count = socket.recv(&mut buffer).await?; + let buffer = &buffer[..count]; + + // Deserialize a message + let packet = deserialize_message(buffer)?; + + match packet.message { + // Got the initiation message + TunnelMessage::Initiated { tunnel_id } => Ok(tunnel_id), + + // Not expecting any other packets in this state + _ => Err(UdpTunnelError::UnexpectedPacket), + } +} + /// Represents a tunnel and its pool of connections that it can /// send data to and receive data from struct Tunnel { @@ -249,17 +340,10 @@ impl Tunnel { .take() .expect("Unexpected write state without message"); - let header = MessageHeader { - tunnel_id: self.tunnel_id, - version: 0, - }; - - let mut buffer = MessageWriter::default(); - header.write(&mut buffer); - message.write(&mut buffer); + let buffer = serialize_message(self.tunnel_id, &message); // Write the packet to the buffer - ready!(Pin::new(&mut self.socket).poll_send(cx, &buffer.buffer)) + ready!(Pin::new(&mut self.socket).poll_send(cx, &buffer)) // Packet encoder impl shouldn't produce errors .expect("Message encoder errored"); @@ -296,19 +380,17 @@ impl Tunnel { }; let buffer = read_buffer.filled(); - let mut reader = MessageReader::new(buffer); - - let header = match MessageHeader::read(&mut reader) { - Ok(value) => value, - Err(_) => return Poll::Ready(TunnelReadState::Stop), - }; - let message = match TunnelMessage::read(&mut reader) { + let packet = match deserialize_message(buffer) { Ok(value) => value, - Err(_) => return Poll::Ready(TunnelReadState::Stop), + Err(err) => { + error!("encountered invalid tunnel message: {}", err); + return Poll::Ready(TunnelReadState::Stop); + } }; - match message { + match packet.message { + // Send forwarded messages to the correct socket handle TunnelMessage::Forward { index, message } => { // Get the handle to use within the connection pool let handle = self.pool.get(index as usize); @@ -319,6 +401,7 @@ impl Tunnel { } } + // Reply to keep-alive message TunnelMessage::KeepAlive => { self.write_state = TunnelWriteState::Write(Some(TunnelMessage::KeepAlive)); @@ -567,241 +650,3 @@ impl Future for Socket { Poll::Pending } } - -mod codec { - use thiserror::Error; - - #[derive(Default)] - pub struct MessageWriter { - pub buffer: Vec, - } - - impl MessageWriter { - #[inline] - pub fn write_u8(&mut self, value: u8) { - self.buffer.push(value) - } - - #[inline] - pub fn write_bytes(&mut self, value: &[u8]) { - self.buffer.extend_from_slice(value) - } - - pub fn write_u32(&mut self, value: u32) { - self.write_bytes(&value.to_be_bytes()) - } - - pub fn write_u16(&mut self, value: u16) { - self.write_bytes(&value.to_be_bytes()) - } - } - - pub struct MessageReader<'a> { - buffer: &'a [u8], - cursor: usize, - } - - impl<'a> MessageReader<'a> { - pub fn new(buffer: &'a [u8]) -> MessageReader<'a> { - MessageReader { buffer, cursor: 0 } - } - - #[inline] - pub fn capacity(&self) -> usize { - self.buffer.len() - } - - pub fn len(&self) -> usize { - self.capacity() - self.cursor - } - - pub fn read_u8(&mut self) -> Result { - if self.len() < 1 { - return Err(MessageError::Incomplete); - } - - let value = self.buffer[self.cursor]; - self.cursor += 1; - - Ok(value) - } - - pub fn read_u32(&mut self) -> Result { - let value = self.read_bytes(4)?; - let value = u32::from_be_bytes([value[0], value[1], value[2], value[3]]); - Ok(value) - } - - pub fn read_u16(&mut self) -> Result { - let value = self.read_bytes(2)?; - let value = u16::from_be_bytes([value[0], value[1]]); - Ok(value) - } - - pub fn read_bytes(&mut self, length: usize) -> Result<&'a [u8], MessageError> { - if self.len() < length { - return Err(MessageError::Incomplete); - } - let value = &self.buffer[self.cursor..self.cursor + length]; - self.cursor += length; - Ok(value) - } - } - - #[derive(Debug)] - pub struct MessageHeader { - /// Protocol version (For future sake) - pub version: u8, - /// ID of the tunnel this message is from, [u32::MAX] when the - /// tunnel is not yet initiated - pub tunnel_id: u32, - } - - #[derive(Debug, Error)] - pub enum MessageError { - #[error("unknown message type")] - UnknownMessageType, - - #[error("message was incomplete")] - Incomplete, - } - - impl MessageHeader { - pub fn read(buf: &mut MessageReader<'_>) -> Result { - let version = buf.read_u8()?; - let tunnel_id = buf.read_u32()?; - - Ok(Self { version, tunnel_id }) - } - - pub fn write(&self, buf: &mut MessageWriter) { - buf.write_u8(self.version); - buf.write_u32(self.tunnel_id); - } - } - - #[derive(Debug, PartialEq, Eq, Clone, Copy)] - #[repr(u8)] - pub enum MessageType { - /// Client is requesting to initiate a connection - Initiate = 0x0, - - /// Server has accepted a connection - Initiated = 0x1, - - /// Forward a message on behalf of the player to - /// another player - Forward = 0x2, - - /// Message to keep the stream alive - /// (When the connect is inactive) - KeepAlive = 0x3, - } - - impl TryFrom for MessageType { - type Error = MessageError; - fn try_from(value: u8) -> Result { - Ok(match value { - 0x0 => Self::Initiate, - 0x1 => Self::Initiated, - 0x2 => Self::Forward, - 0x3 => Self::KeepAlive, - _ => return Err(MessageError::UnknownMessageType), - }) - } - } - - #[derive(Debug)] - pub enum TunnelMessage { - /// Client is requesting to initiate a connection - Initiate { - /// Association token to authenticate with - association_token: String, - }, - - /// Server created and associated the tunnel - Initiated { - /// Unique ID for the tunnel to include in future messages - /// to identify itself - tunnel_id: u32, - }, - - /// Client wants to forward a message - Forward { - /// Local socket pool index the message was sent to. - /// Used to map to the target within the game - index: u8, - - /// Message contents to forward - message: Vec, - }, - - /// Keep alive - KeepAlive, - } - - impl TunnelMessage { - pub fn read(buf: &mut MessageReader<'_>) -> Result { - let ty = buf.read_u8()?; - let ty = MessageType::try_from(ty)?; - - match ty { - MessageType::Initiate => { - // Get length of the association token - let length = buf.read_u16()? as usize; - let token_bytes = buf.read_bytes(length)?; - let token = String::from_utf8_lossy(token_bytes); - Ok(TunnelMessage::Initiate { - association_token: token.to_string(), - }) - } - MessageType::Initiated => { - let tunnel_id = buf.read_u32()?; - - Ok(TunnelMessage::Initiated { tunnel_id }) - } - MessageType::Forward => { - let index = buf.read_u8()?; - - // Get length of the association token - let length = buf.read_u16()? as usize; - - let message = buf.read_bytes(length)?; - - Ok(TunnelMessage::Forward { - index, - message: message.to_vec(), - }) - } - MessageType::KeepAlive => Ok(TunnelMessage::KeepAlive), - } - } - - pub fn write(&self, buf: &mut MessageWriter) { - match self { - TunnelMessage::Initiate { association_token } => { - debug_assert!(association_token.len() < u16::MAX as usize); - buf.write_u8(MessageType::Initiate as u8); - - buf.write_u16(association_token.len() as u16); - buf.write_bytes(association_token.as_bytes()); - } - TunnelMessage::Initiated { tunnel_id } => { - buf.write_u8(MessageType::Initiated as u8); - buf.write_u32(*tunnel_id); - } - TunnelMessage::Forward { index, message } => { - buf.write_u8(MessageType::Forward as u8); - debug_assert!(message.len() < u16::MAX as usize); - - buf.write_u8(*index); - buf.write_u16(message.len() as u16); - buf.write_bytes(message); - } - TunnelMessage::KeepAlive => { - buf.write_u8(MessageType::KeepAlive as u8); - } - } - } - } -}