From bcdd1bc477fabeecf6226b31a048e82ad016558d Mon Sep 17 00:00:00 2001 From: Jacobtread Date: Sun, 13 Oct 2024 20:25:18 +1300 Subject: [PATCH] refactor: merge session locking and sending logic, cleaned up session starting, removed session stop unwrapping --- src/routes/auth.rs | 2 +- src/routes/server.rs | 6 +- src/session/data.rs | 16 +-- src/session/mod.rs | 197 +++++++++++++++----------------- src/session/router.rs | 2 +- src/session/routes/messaging.rs | 2 +- src/session/routes/other.rs | 2 +- src/session/routes/util.rs | 2 +- 8 files changed, 112 insertions(+), 117 deletions(-) diff --git a/src/routes/auth.rs b/src/routes/auth.rs index e0cb6c91..0106c607 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -240,7 +240,7 @@ pub async fn handle_request_login_code( ); // Send the message - session.notify_handle().notify(notify_origin); + session.notify_handle.notify(notify_origin); Ok(StatusCode::OK) } diff --git a/src/routes/server.rs b/src/routes/server.rs index 45f025cc..edcde985 100644 --- a/src/routes/server.rs +++ b/src/routes/server.rs @@ -11,7 +11,7 @@ use crate::{ sessions::{AssociationId, Sessions}, tunnel::{Tunnel, TunnelService}, }, - session::{router::BlazeRouter, Session}, + session::{data::SessionData, router::BlazeRouter, Session}, utils::logging::LOG_FILE_NAME, }; use axum::{ @@ -120,7 +120,9 @@ pub async fn handle_upgrade( } }; - Session::start(upgraded, addr, association_id, router).await; + let data = SessionData::new(addr, association_id); + + Session::run(upgraded, data, router).await; } /// GET /api/server/tunnel diff --git a/src/session/data.rs b/src/session/data.rs index 0073475e..d237d3ec 100644 --- a/src/session/data.rs +++ b/src/session/data.rs @@ -1,6 +1,6 @@ use std::{net::Ipv4Addr, sync::Arc}; -use parking_lot::{Mutex, MutexGuard}; +use parking_lot::{RwLock, RwLockReadGuard}; use serde::Serialize; use crate::{ @@ -31,7 +31,7 @@ use super::{ pub struct SessionData { /// Extended session data for authenticated sessions - ext: Mutex>, + ext: RwLock>, /// IP address associated with the session addr: Ipv4Addr, @@ -59,8 +59,8 @@ impl SessionData { } // Read from the underlying session data - fn read(&self) -> MutexGuard<'_, Option> { - self.ext.lock() + fn read(&self) -> RwLockReadGuard<'_, Option> { + self.ext.read() } /// Writes to the underlying session data without publishing the changes @@ -68,7 +68,7 @@ impl SessionData { where F: FnOnce(&mut SessionDataExt) -> O, { - self.ext.lock().as_mut().map(update) + self.ext.write().as_mut().map(update) } /// Writes to the underlying session data, publishes changes to @@ -78,7 +78,7 @@ impl SessionData { where F: FnOnce(&mut SessionDataExt) -> O, { - self.ext.lock().as_mut().map(|data| { + self.ext.write().as_mut().map(|data| { let value = update(data); data.publish_update(); value @@ -87,13 +87,13 @@ impl SessionData { /// Clears the underlying session data pub fn clear(&self) { - self.ext.lock().take(); + self.ext.write().take(); } /// Starts a session from the provided player association pub fn start_session(&self, player: SessionPlayerAssociation) -> Arc { self.ext - .lock() + .write() .insert(SessionDataExt::new(player)) // Obtain the player to return .player_assoc diff --git a/src/session/mod.rs b/src/session/mod.rs index 9f1bb705..51bc4d43 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -8,17 +8,15 @@ use self::{ }; use crate::{ database::entities::Player, - services::sessions::AssociationId, utils::components::{component_key, DEBUG_IGNORED_PACKETS}, }; use data::SessionData; use futures_util::{future::BoxFuture, Sink, Stream}; use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; -use log::{debug, log_enabled, warn}; +use log::{debug, log_enabled}; use std::{ fmt::Debug, - net::Ipv4Addr, pin::Pin, sync::{ atomic::{AtomicU32, Ordering}, @@ -39,21 +37,21 @@ pub mod routes; pub type SessionLink = Arc; pub type WeakSessionLink = Weak; +static SESSION_IDS: AtomicU32 = AtomicU32::new(1); + pub struct Session { /// Unique ID for this session - id: u32, + pub id: u32, - /// Lock for handling packets with a session, ensures only one packet is - /// processed at a time and in the same order that it was received / sent - busy_lock: Arc>, + /// Handle for sending packets to this session + pub notify_handle: SessionNotifyHandle, - /// Sender for sending packets to the session - tx: mpsc::UnboundedSender, - - /// Mutable data associated with the session + /// Data associated with the session pub data: SessionData, } +/// Handle for sending packets to a session notification +/// channel #[derive(Clone)] pub struct SessionNotifyHandle { busy_lock: Arc>, @@ -61,6 +59,18 @@ pub struct SessionNotifyHandle { } impl SessionNotifyHandle { + /// Creates a new session notify handle, provides both the handle + /// and the receiving end to use for receiving from the handle + pub fn new() -> (SessionNotifyHandle, mpsc::UnboundedReceiver) { + let (tx, rx) = mpsc::unbounded_channel(); + + let handle = Self { + busy_lock: Default::default(), + tx, + }; + (handle, rx) + } + /// Pushes a new notification packet pub fn notify(&self, packet: Packet) { let tx = self.tx.clone(); @@ -73,102 +83,39 @@ impl SessionNotifyHandle { let _ = tx.send(packet); }); } -} -static SESSION_IDS: AtomicU32 = AtomicU32::new(1); + /// Internally lock the busy lock, used by the router when it wants to handle a request + fn lock_internal(&self) -> BoxFuture<'static, OwnedMutexGuard<()>> { + Box::pin(self.busy_lock.clone().lock_owned()) + } + + /// Immediately queues a packet onto the channel, should only be used + /// internally for sending handled responses use [Self::notify] in all + /// other cases + fn send_internal(&self, packet: Packet) { + let _ = self.tx.send(packet); + } +} impl Session { - pub async fn start( - io: Upgraded, - addr: Ipv4Addr, - association: Option, - router: Arc, - ) { + pub async fn run(io: Upgraded, data: SessionData, router: Arc) { // Obtain a session ID let id = SESSION_IDS.fetch_add(1, Ordering::AcqRel); - let (tx, rx) = mpsc::unbounded_channel(); - + let (notify_handle, rx) = SessionNotifyHandle::new(); let session = Arc::new(Self { id, - busy_lock: Default::default(), - tx, - data: SessionData::new(addr, association), + notify_handle, + data, }); - SessionFuture { - io: Framed::new(TokioIo::new(io), PacketCodec::default()), - router: &router, - rx, - session: session.clone(), - read_state: ReadState::Recv, - write_state: WriteState::Recv, - stop: false, - } - .await; - - session.stop(); + SessionFuture::new(io, &session, &router, rx).await; } +} - pub fn notify_handle(&self) -> SessionNotifyHandle { - SessionNotifyHandle { - busy_lock: self.busy_lock.clone(), - tx: self.tx.clone(), - } - } - - /// Called when the session is considered stopped (Reader/Writer future has completed) - /// in order to clean up any remaining references to the session before dropping - fn stop(self: Arc) { - // Clear session data - self.data.clear(); - - // Session should now be the sole owner - let session = match Arc::try_unwrap(self) { - Ok(value) => value, - Err(arc) => { - let references = Arc::strong_count(&arc); - warn!( - "Failed to stop session {} there are still {} references to it", - arc.id, references - ); - return; - } - }; - - debug!("Session stopped (SID: {})", session.id); - } - - /// Logs the contents of the provided packet to the debug output along with - /// the header information and basic session information. - /// - /// `action` The name of the action this packet is undergoing. - /// (e.g. Writing or Reading) - /// `packet` The packet that is being logged - fn debug_log_packet(&self, action: &'static str, packet: &Packet) { - // Skip if debug logging is disabled - if !log_enabled!(log::Level::Debug) { - return; - } - - let key = component_key(packet.frame.component, packet.frame.command); - - // Don't log the packet if its debug ignored - if DEBUG_IGNORED_PACKETS.contains(&key) { - return; - } - - // Get the authenticated player to include in the debug message - let auth = self.data.get_player(); - - let debug_data = DebugSessionData { - action, - id: self.id, - auth, - }; - let debug_packet = PacketDebug { packet }; - - debug!("\n{:?}{:?}", debug_data, debug_packet); +impl Drop for Session { + fn drop(&mut self) { + debug!("Session stopped (SID: {})", self.id); } } @@ -197,7 +144,7 @@ struct SessionFuture<'a> { /// Receiver for packets to write rx: mpsc::UnboundedReceiver, /// The session this link is for - session: SessionLink, + session: &'a SessionLink, /// The router to use router: &'a BlazeRouter, /// The reading state @@ -239,6 +186,23 @@ enum ReadState<'a> { } impl SessionFuture<'_> { + pub fn new<'a>( + io: Upgraded, + session: &'a Arc, + router: &'a BlazeRouter, + rx: mpsc::UnboundedReceiver, + ) -> SessionFuture<'a> { + SessionFuture { + io: Framed::new(TokioIo::new(io), PacketCodec::default()), + router, + rx, + session, + read_state: ReadState::Recv, + write_state: WriteState::Recv, + stop: false, + } + } + /// Polls the write state, the poll ready state returns whether /// the future should continue fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<()> { @@ -263,7 +227,7 @@ impl SessionFuture<'_> { .take() .expect("Unexpected write state without packet"); - self.session.debug_log_packet("Send", &packet); + debug_log_packet(self.session, "Send", &packet); // Write the packet to the buffer Pin::new(&mut self.io) @@ -300,9 +264,7 @@ impl SessionFuture<'_> { let result = ready!(Pin::new(&mut self.io).poll_next(cx)); if let Some(Ok(packet)) = result { - let lock_future = self.session.busy_lock.clone().lock_owned(); - let lock_future: BoxFuture<'static, OwnedMutexGuard<()>> = - Box::pin(lock_future); + let lock_future = self.session.notify_handle.lock_internal(); self.read_state = ReadState::Acquire { lock_future, @@ -322,7 +284,7 @@ impl SessionFuture<'_> { .take() .expect("Unexpected acquire state without packet"); - self.session.debug_log_packet("Receive", &packet); + debug_log_packet(self.session, "Receive", &packet); let future = self.router.handle(self.session.clone(), packet); @@ -337,7 +299,7 @@ impl SessionFuture<'_> { let response = ready!(Pin::new(future).poll(cx)); // Send the response to the writer - _ = self.session.tx.send(response); + self.session.notify_handle.send_internal(response); // Reset back to the reading state self.read_state = ReadState::Recv; @@ -363,3 +325,34 @@ impl Future for SessionFuture<'_> { } } } + +impl Drop for SessionFuture<'_> { + fn drop(&mut self) { + // Clear session data, speeds up process of ending the session + // prevents session data being accessed while shutting down + self.session.data.clear(); + } +} + +/// Logs debugging information about a player +fn debug_log_packet(session: &Session, action: &'static str, packet: &Packet) { + // Skip if debug logging is disabled + if !log_enabled!(log::Level::Debug) { + return; + } + + let key = component_key(packet.frame.component, packet.frame.command); + + // Don't log the packet if its debug ignored + if DEBUG_IGNORED_PACKETS.contains(&key) { + return; + } + + let id = session.id; + let auth = session.data.get_player(); + + let debug_data = DebugSessionData { action, id, auth }; + let debug_packet = PacketDebug { packet }; + + debug!("\n{:?}{:?}", debug_data, debug_packet); +} diff --git a/src/session/router.rs b/src/session/router.rs index 7121b7f6..7144b84e 100644 --- a/src/session/router.rs +++ b/src/session/router.rs @@ -231,7 +231,7 @@ impl FromPacketRequest for GamePlayer { player, net_data, Arc::downgrade(&req.state), - req.state.notify_handle(), + req.state.notify_handle.clone(), )) }) } diff --git a/src/session/routes/messaging.rs b/src/session/routes/messaging.rs index 172c5257..e5ba6110 100644 --- a/src/session/routes/messaging.rs +++ b/src/session/routes/messaging.rs @@ -51,6 +51,6 @@ pub async fn handle_fetch_messages( }, ); - session.notify_handle().notify(notify); + session.notify_handle.notify(notify); Blaze(FetchMessageResponse { count: 1 }) } diff --git a/src/session/routes/other.rs b/src/session/routes/other.rs index 7e6e9a31..d7fdf76e 100644 --- a/src/session/routes/other.rs +++ b/src/session/routes/other.rs @@ -64,7 +64,7 @@ pub async fn handle_submit_offline( return; } - session.notify_handle().notify(Packet::notify( + session.notify_handle.notify(Packet::notify( game_reporting::COMPONENT, game_reporting::GAME_REPORT_SUBMITTED, GameReportResponse, diff --git a/src/session/routes/util.rs b/src/session/routes/util.rs index e6259061..9f8c721a 100644 --- a/src/session/routes/util.rs +++ b/src/session/routes/util.rs @@ -102,7 +102,7 @@ pub async fn handle_post_auth( // Subscribe to the session with itself session .data - .add_subscriber(player.id, session.notify_handle()); + .add_subscriber(player.id, session.notify_handle.clone()); Ok(Blaze(PostAuthResponse { telemetry: TelemetryServer,