diff --git a/src/api.rs b/src/api.rs index 3a385aa..68a918e 100644 --- a/src/api.rs +++ b/src/api.rs @@ -9,7 +9,7 @@ use log::error; use reqwest::{Client, Identity, Upgraded}; use semver::Version; use serde::{Deserialize, Serialize}; -use std::{path::Path, str::FromStr, sync::Arc}; +use std::{path::Path, str::FromStr}; use thiserror::Error; use url::Url; @@ -103,11 +103,11 @@ struct ServerDetails { #[derive(Debug, Clone)] pub struct LookupData { /// Server url - pub url: Arc, + pub url: Url, /// The server version pub version: Version, /// Association token if the server supports providing one - pub association: Arc>, + pub association: Option, } /// Errors that can occur while looking up a server @@ -226,9 +226,9 @@ pub async fn lookup_server( } Ok(LookupData { - url: Arc::new(url), + url, version: details.version, - association: Arc::new(details.association), + association: details.association, }) } @@ -254,7 +254,7 @@ pub enum ServerStreamError { /// * `base_url` - The server base URL (Connection URL) /// * `association` - Optional client association token pub async fn create_server_stream( - http_client: reqwest::Client, + http_client: &reqwest::Client, base_url: &Url, association: Option<&String>, ) -> Result { @@ -398,7 +398,7 @@ pub async fn proxy_http_request( /// * `base_url` - The server base URL (Connection URL) /// * `association` - Association token pub async fn create_server_tunnel( - http_client: reqwest::Client, + http_client: &reqwest::Client, base_url: &Url, association: &str, ) -> Result { diff --git a/src/ctx.rs b/src/ctx.rs new file mode 100644 index 0000000..6e68c5f --- /dev/null +++ b/src/ctx.rs @@ -0,0 +1,14 @@ +//! Shared context state that the app should store and pass to the +//! various servers when they are started + +use url::Url; + +/// Shared context +pub struct ClientContext { + /// HTTP client for the client to make requests with + pub http_client: reqwest::Client, + /// Base URL of the connected server + pub base_url: Url, + /// Optional association token + pub association: Option, +} diff --git a/src/lib.rs b/src/lib.rs index 7d7f746..54817e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ pub use semver::Version; pub use url::Url; pub mod api; +pub mod ctx; pub mod fire; pub mod servers; pub mod update; diff --git a/src/servers/blaze.rs b/src/servers/blaze.rs index c61a953..48edb91 100644 --- a/src/servers/blaze.rs +++ b/src/servers/blaze.rs @@ -1,26 +1,19 @@ //! Server connected to by BlazeSDK clients (Majority of the game traffic) use super::{spawn_server_task, BLAZE_PORT}; -use crate::api::create_server_stream; +use crate::{api::create_server_stream, ctx::ClientContext}; use log::{debug, error}; use std::{net::Ipv4Addr, sync::Arc}; use tokio::{ io::copy_bidirectional, net::{TcpListener, TcpStream}, }; -use url::Url; /// Starts the blaze server /// /// ## Arguments -/// * `http_client` - The HTTP client passed around for connection upgrades -/// * `base_url` - The server base URL to connect clients to -/// * `association` - Optional client association -pub async fn start_blaze_server( - http_client: reqwest::Client, - base_url: Arc, - association: Arc>, -) -> std::io::Result<()> { +/// * `ctx` - The client context +pub async fn start_blaze_server(ctx: Arc) -> std::io::Result<()> { // Bind the local socket for accepting connections let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, BLAZE_PORT)).await?; @@ -28,12 +21,7 @@ pub async fn start_blaze_server( loop { let (client_stream, _) = listener.accept().await?; - spawn_server_task(handle( - client_stream, - http_client.clone(), - base_url.clone(), - association.clone(), - )); + spawn_server_task(handle(client_stream, ctx.clone())); } } @@ -41,26 +29,24 @@ pub async fn start_blaze_server( /// /// ## Arguments /// * `client_stream` - The client stream to read and write from -/// * `http_client` - The HTTP client passed around for connection upgrades -/// * `base_url` - The server base URL to connect clients to -/// * `association` - Client association token if supported -async fn handle( - mut client_stream: TcpStream, - http_client: reqwest::Client, - base_url: Arc, - association: Arc>, -) { +/// * `ctx` - The client context +async fn handle(mut client_stream: TcpStream, ctx: Arc) { debug!("Starting blaze connection"); // Create a stream to the Pocket Relay server - let mut server_stream = - match create_server_stream(http_client, &base_url, Option::as_ref(&association)).await { - Ok(stream) => stream, - Err(err) => { - error!("Failed to create server stream: {}", err); - return; - } - }; + let mut server_stream = match create_server_stream( + &ctx.http_client, + &ctx.base_url, + Option::as_ref(&ctx.association), + ) + .await + { + Ok(stream) => stream, + Err(err) => { + error!("Failed to create server stream: {}", err); + return; + } + }; debug!("Blaze connection linked"); diff --git a/src/servers/http.rs b/src/servers/http.rs index 175aa80..da57f29 100644 --- a/src/servers/http.rs +++ b/src/servers/http.rs @@ -3,7 +3,7 @@ //! is only capable of communicating over SSLv3 use super::HTTP_PORT; -use crate::api::proxy_http_request; +use crate::{api::proxy_http_request, ctx::ClientContext}; use hyper::{ http::uri::PathAndQuery, service::{make_service_fn, service_fn}, @@ -16,30 +16,22 @@ use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::Arc, }; -use url::Url; /// Starts the HTTP proxy server /// /// ## Arguments -/// * `http_client` - The HTTP client passed around for sending the requests -/// * `base_url` - The server base URL to proxy requests to -pub async fn start_http_server( - http_client: reqwest::Client, - base_url: Arc, -) -> std::io::Result<()> { +/// * `ctx` - The client context +pub async fn start_http_server(ctx: Arc) -> std::io::Result<()> { // Create the socket address the server will bind too let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, HTTP_PORT)); // Create service that uses the `handle function` let make_svc = make_service_fn(move |_conn| { - let http_client = http_client.clone(); - let base_url = base_url.clone(); + let ctx = ctx.clone(); async move { // service_fn converts our function into a `Service` - Ok::<_, Infallible>(service_fn(move |request| { - handle(request, http_client.clone(), base_url.clone()) - })) + Ok::<_, Infallible>(service_fn(move |request| handle(request, ctx.clone()))) } }); @@ -54,13 +46,11 @@ pub async fn start_http_server( /// to the Pocket Relay server /// /// ## Arguments -/// * `request` - The HTTP request -/// * `http_client` - The HTTP client to proxy the request with -/// * `base_url` - The server base URL (Connection URL) +/// * `request` - The HTTP request +/// * `ctx` - The client context async fn handle( request: Request, - http_client: reqwest::Client, - base_url: Arc, + ctx: Arc, ) -> Result, Infallible> { let path_and_query = request .uri() @@ -75,7 +65,7 @@ async fn handle( let path_and_query = path_and_query.strip_prefix('/').unwrap_or(path_and_query); // Create the new url from the path - let url = match base_url.join(path_and_query) { + let url = match ctx.base_url.join(path_and_query) { Ok(value) => value, Err(err) => { error!("Failed to create HTTP proxy URL: {}", err); @@ -87,7 +77,7 @@ async fn handle( }; // Proxy the request to the server - let response = match proxy_http_request(&http_client, url).await { + let response = match proxy_http_request(&ctx.http_client, url).await { Ok(value) => value, Err(err) => { error!("Failed to proxy HTTP request: {}", err); diff --git a/src/servers/telemetry.rs b/src/servers/telemetry.rs index c7a8d9d..6b7d216 100644 --- a/src/servers/telemetry.rs +++ b/src/servers/telemetry.rs @@ -2,24 +2,22 @@ //! forwarding them to the connect Pocket Relay server use super::{spawn_server_task, TELEMETRY_PORT}; -use crate::api::{publish_telemetry_event, TelemetryEvent}; +use crate::{ + api::{publish_telemetry_event, TelemetryEvent}, + ctx::ClientContext, +}; use log::error; use std::{net::Ipv4Addr, sync::Arc}; use tokio::{ io::AsyncReadExt, net::{TcpListener, TcpStream}, }; -use url::Url; /// Starts the telemetry server /// /// ## Arguments -/// * `http_client` - The HTTP client used to forward messages -/// * `base_url` - The server base URL to connect clients to -pub async fn start_telemetry_server( - http_client: reqwest::Client, - base_url: Arc, -) -> std::io::Result<()> { +/// * `ctx` - The client context +pub async fn start_telemetry_server(ctx: Arc) -> std::io::Result<()> { // Bind the local socket for accepting connections let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, TELEMETRY_PORT)).await?; @@ -27,14 +25,18 @@ pub async fn start_telemetry_server( loop { let (client_stream, _) = listener.accept().await?; - spawn_server_task(handle(client_stream, http_client.clone(), base_url.clone())); + spawn_server_task(handle(client_stream, ctx.clone())); } } /// Handler for processing telemetry client connections -async fn handle(mut client_stream: TcpStream, http_client: reqwest::Client, base_url: Arc) { - while let Ok(event) = read_telemetry_event(&mut client_stream).await { - if let Err(err) = publish_telemetry_event(&http_client, &base_url, event).await { +/// +/// ## Arguments +/// * `stream` - The stream to decode from +/// * `ctx` - The client context +async fn handle(mut stream: TcpStream, ctx: Arc) { + while let Ok(event) = read_telemetry_event(&mut stream).await { + if let Err(err) = publish_telemetry_event(&ctx.http_client, &ctx.base_url, event).await { error!("Failed to publish telemetry event: {}", err); } } diff --git a/src/servers/tunnel.rs b/src/servers/tunnel.rs index 004257e..c8419c8 100644 --- a/src/servers/tunnel.rs +++ b/src/servers/tunnel.rs @@ -9,6 +9,7 @@ use self::codec::{TunnelCodec, TunnelMessage}; use crate::{ api::create_server_tunnel, + ctx::ClientContext, servers::{spawn_server_task, GAME_HOST_PORT, RANDOM_PORT, TUNNEL_HOST_PORT}, }; use bytes::Bytes; @@ -26,7 +27,6 @@ use std::{ }; use tokio::{io::ReadBuf, net::UdpSocket, sync::mpsc, try_join}; use tokio_util::codec::Framed; -use url::Url; /// The fixed size of socket pool to use const SOCKET_POOL_SIZE: usize = 4; @@ -41,15 +41,9 @@ static LOCAL_SEND_TARGET: SocketAddr = /// connection to the server /// /// ## Arguments -/// * `http_client` - The HTTP client passed around for connection upgrades -/// * `base_url` - The server base URL to connect clients to -/// * `association` - Optional client association -pub async fn start_tunnel_server( - http_client: reqwest::Client, - base_url: Arc, - association: Arc>, -) -> std::io::Result<()> { - let association = match Option::as_ref(&association) { +/// * `ctx` - The client context +pub async fn start_tunnel_server(ctx: Arc) -> std::io::Result<()> { + let association = match Option::as_ref(&ctx.association) { Some(value) => value, // Don't try and tunnel without a token None => return Ok(()), @@ -63,25 +57,24 @@ pub async fn start_tunnel_server( // Looping to attempt reconnecting if lost while attempt_errors < MAX_ERROR_ATTEMPTS { // Create the tunnel (Future will end if tunnel stopped) - let reconnect_time = - if let Err(err) = create_tunnel(http_client.clone(), &base_url, association).await { - error!("Failed to create tunnel: {}", err); + let reconnect_time = if let Err(err) = create_tunnel(ctx.clone(), association).await { + error!("Failed to create tunnel: {}", err); - // Set last error - last_error = Some(err); + // Set last error + last_error = Some(err); - // Increase error attempts - attempt_errors += 1; + // Increase error attempts + attempt_errors += 1; - // Error should be delayed by the number of errors already hit - Duration::from_millis(1000 * attempt_errors as u64) - } else { - // Reset error attempts - attempt_errors = 0; + // Error should be delayed by the number of errors already hit + Duration::from_millis(1000 * attempt_errors as u64) + } else { + // Reset error attempts + attempt_errors = 0; - // Non errored reconnect can be quick - Duration::from_millis(1000) - }; + // Non errored reconnect can be quick + Duration::from_millis(1000) + }; debug!( "Next tunnel create attempt in: {}s", @@ -101,15 +94,11 @@ pub async fn start_tunnel_server( /// Creates a new tunnel /// /// ## Arguments -/// * `http_client` - The HTTP client passed around for connection upgrades -/// * `base_url` - The server base URL to connect clients to -async fn create_tunnel( - http_client: reqwest::Client, - base_url: &Url, - association: &str, -) -> std::io::Result<()> { +/// * `ctx` - The client context +/// * `association` - The client association token +async fn create_tunnel(ctx: Arc, association: &str) -> std::io::Result<()> { // Create the tunnel with the server - let io = create_server_tunnel(http_client, base_url, association) + let io = create_server_tunnel(&ctx.http_client, &ctx.base_url, association) .await // Wrap the tunnel with the [`TunnelCodec`] framing .map(|io| Framed::new(io, TunnelCodec::default()))