Skip to content

Commit

Permalink
refactor: merge session locking and sending logic, cleaned up session…
Browse files Browse the repository at this point in the history
… starting, removed session stop unwrapping
  • Loading branch information
jacobtread committed Oct 13, 2024
1 parent f841113 commit bcdd1bc
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 117 deletions.
2 changes: 1 addition & 1 deletion src/routes/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ pub async fn handle_request_login_code(
);

// Send the message
session.notify_handle().notify(notify_origin);
session.notify_handle.notify(notify_origin);

Ok(StatusCode::OK)
}
Expand Down
6 changes: 4 additions & 2 deletions src/routes/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
sessions::{AssociationId, Sessions},
tunnel::{Tunnel, TunnelService},
},
session::{router::BlazeRouter, Session},
session::{data::SessionData, router::BlazeRouter, Session},
utils::logging::LOG_FILE_NAME,
};
use axum::{
Expand Down Expand Up @@ -120,7 +120,9 @@ pub async fn handle_upgrade(
}
};

Session::start(upgraded, addr, association_id, router).await;
let data = SessionData::new(addr, association_id);

Session::run(upgraded, data, router).await;
}

/// GET /api/server/tunnel
Expand Down
16 changes: 8 additions & 8 deletions src/session/data.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{net::Ipv4Addr, sync::Arc};

use parking_lot::{Mutex, MutexGuard};
use parking_lot::{RwLock, RwLockReadGuard};
use serde::Serialize;

use crate::{
Expand Down Expand Up @@ -31,7 +31,7 @@ use super::{

pub struct SessionData {
/// Extended session data for authenticated sessions
ext: Mutex<Option<SessionDataExt>>,
ext: RwLock<Option<SessionDataExt>>,

/// IP address associated with the session
addr: Ipv4Addr,
Expand Down Expand Up @@ -59,16 +59,16 @@ impl SessionData {
}

// Read from the underlying session data
fn read(&self) -> MutexGuard<'_, Option<SessionDataExt>> {
self.ext.lock()
fn read(&self) -> RwLockReadGuard<'_, Option<SessionDataExt>> {
self.ext.read()
}

/// Writes to the underlying session data without publishing the changes
fn write_silent<F, O>(&self, update: F) -> Option<O>
where
F: FnOnce(&mut SessionDataExt) -> O,
{
self.ext.lock().as_mut().map(update)
self.ext.write().as_mut().map(update)
}

/// Writes to the underlying session data, publishes changes to
Expand All @@ -78,7 +78,7 @@ impl SessionData {
where
F: FnOnce(&mut SessionDataExt) -> O,
{
self.ext.lock().as_mut().map(|data| {
self.ext.write().as_mut().map(|data| {
let value = update(data);
data.publish_update();
value
Expand All @@ -87,13 +87,13 @@ impl SessionData {

/// Clears the underlying session data
pub fn clear(&self) {
self.ext.lock().take();
self.ext.write().take();
}

/// Starts a session from the provided player association
pub fn start_session(&self, player: SessionPlayerAssociation) -> Arc<Player> {
self.ext
.lock()
.write()
.insert(SessionDataExt::new(player))
// Obtain the player to return
.player_assoc
Expand Down
197 changes: 95 additions & 102 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@ use self::{
};
use crate::{
database::entities::Player,
services::sessions::AssociationId,
utils::components::{component_key, DEBUG_IGNORED_PACKETS},
};
use data::SessionData;
use futures_util::{future::BoxFuture, Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use log::{debug, log_enabled, warn};
use log::{debug, log_enabled};
use std::{
fmt::Debug,
net::Ipv4Addr,
pin::Pin,
sync::{
atomic::{AtomicU32, Ordering},
Expand All @@ -39,28 +37,40 @@ pub mod routes;
pub type SessionLink = Arc<Session>;
pub type WeakSessionLink = Weak<Session>;

static SESSION_IDS: AtomicU32 = AtomicU32::new(1);

pub struct Session {
/// Unique ID for this session
id: u32,
pub id: u32,

/// Lock for handling packets with a session, ensures only one packet is
/// processed at a time and in the same order that it was received / sent
busy_lock: Arc<tokio::sync::Mutex<()>>,
/// Handle for sending packets to this session
pub notify_handle: SessionNotifyHandle,

/// Sender for sending packets to the session
tx: mpsc::UnboundedSender<Packet>,

/// Mutable data associated with the session
/// Data associated with the session
pub data: SessionData,
}

/// Handle for sending packets to a session notification
/// channel
#[derive(Clone)]
pub struct SessionNotifyHandle {
busy_lock: Arc<tokio::sync::Mutex<()>>,
tx: mpsc::UnboundedSender<Packet>,
}

impl SessionNotifyHandle {
/// Creates a new session notify handle, provides both the handle
/// and the receiving end to use for receiving from the handle
pub fn new() -> (SessionNotifyHandle, mpsc::UnboundedReceiver<Packet>) {
let (tx, rx) = mpsc::unbounded_channel();

let handle = Self {
busy_lock: Default::default(),
tx,
};
(handle, rx)
}

/// Pushes a new notification packet
pub fn notify(&self, packet: Packet) {
let tx = self.tx.clone();
Expand All @@ -73,102 +83,39 @@ impl SessionNotifyHandle {
let _ = tx.send(packet);
});
}
}

static SESSION_IDS: AtomicU32 = AtomicU32::new(1);
/// Internally lock the busy lock, used by the router when it wants to handle a request
fn lock_internal(&self) -> BoxFuture<'static, OwnedMutexGuard<()>> {
Box::pin(self.busy_lock.clone().lock_owned())
}

/// Immediately queues a packet onto the channel, should only be used
/// internally for sending handled responses use [Self::notify] in all
/// other cases
fn send_internal(&self, packet: Packet) {
let _ = self.tx.send(packet);
}
}

impl Session {
pub async fn start(
io: Upgraded,
addr: Ipv4Addr,
association: Option<AssociationId>,
router: Arc<BlazeRouter>,
) {
pub async fn run(io: Upgraded, data: SessionData, router: Arc<BlazeRouter>) {
// Obtain a session ID
let id = SESSION_IDS.fetch_add(1, Ordering::AcqRel);

let (tx, rx) = mpsc::unbounded_channel();

let (notify_handle, rx) = SessionNotifyHandle::new();
let session = Arc::new(Self {
id,
busy_lock: Default::default(),
tx,
data: SessionData::new(addr, association),
notify_handle,
data,
});

SessionFuture {
io: Framed::new(TokioIo::new(io), PacketCodec::default()),
router: &router,
rx,
session: session.clone(),
read_state: ReadState::Recv,
write_state: WriteState::Recv,
stop: false,
}
.await;

session.stop();
SessionFuture::new(io, &session, &router, rx).await;
}
}

pub fn notify_handle(&self) -> SessionNotifyHandle {
SessionNotifyHandle {
busy_lock: self.busy_lock.clone(),
tx: self.tx.clone(),
}
}

/// Called when the session is considered stopped (Reader/Writer future has completed)
/// in order to clean up any remaining references to the session before dropping
fn stop(self: Arc<Self>) {
// Clear session data
self.data.clear();

// Session should now be the sole owner
let session = match Arc::try_unwrap(self) {
Ok(value) => value,
Err(arc) => {
let references = Arc::strong_count(&arc);
warn!(
"Failed to stop session {} there are still {} references to it",
arc.id, references
);
return;
}
};

debug!("Session stopped (SID: {})", session.id);
}

/// Logs the contents of the provided packet to the debug output along with
/// the header information and basic session information.
///
/// `action` The name of the action this packet is undergoing.
/// (e.g. Writing or Reading)
/// `packet` The packet that is being logged
fn debug_log_packet(&self, action: &'static str, packet: &Packet) {
// Skip if debug logging is disabled
if !log_enabled!(log::Level::Debug) {
return;
}

let key = component_key(packet.frame.component, packet.frame.command);

// Don't log the packet if its debug ignored
if DEBUG_IGNORED_PACKETS.contains(&key) {
return;
}

// Get the authenticated player to include in the debug message
let auth = self.data.get_player();

let debug_data = DebugSessionData {
action,
id: self.id,
auth,
};
let debug_packet = PacketDebug { packet };

debug!("\n{:?}{:?}", debug_data, debug_packet);
impl Drop for Session {
fn drop(&mut self) {
debug!("Session stopped (SID: {})", self.id);
}
}

Expand Down Expand Up @@ -197,7 +144,7 @@ struct SessionFuture<'a> {
/// Receiver for packets to write
rx: mpsc::UnboundedReceiver<Packet>,
/// The session this link is for
session: SessionLink,
session: &'a SessionLink,
/// The router to use
router: &'a BlazeRouter,
/// The reading state
Expand Down Expand Up @@ -239,6 +186,23 @@ enum ReadState<'a> {
}

impl SessionFuture<'_> {
pub fn new<'a>(
io: Upgraded,
session: &'a Arc<Session>,
router: &'a BlazeRouter,
rx: mpsc::UnboundedReceiver<Packet>,
) -> SessionFuture<'a> {
SessionFuture {
io: Framed::new(TokioIo::new(io), PacketCodec::default()),
router,
rx,
session,
read_state: ReadState::Recv,
write_state: WriteState::Recv,
stop: false,
}
}

/// Polls the write state, the poll ready state returns whether
/// the future should continue
fn poll_write_state(&mut self, cx: &mut Context<'_>) -> Poll<()> {
Expand All @@ -263,7 +227,7 @@ impl SessionFuture<'_> {
.take()
.expect("Unexpected write state without packet");

self.session.debug_log_packet("Send", &packet);
debug_log_packet(self.session, "Send", &packet);

// Write the packet to the buffer
Pin::new(&mut self.io)
Expand Down Expand Up @@ -300,9 +264,7 @@ impl SessionFuture<'_> {
let result = ready!(Pin::new(&mut self.io).poll_next(cx));

if let Some(Ok(packet)) = result {
let lock_future = self.session.busy_lock.clone().lock_owned();
let lock_future: BoxFuture<'static, OwnedMutexGuard<()>> =
Box::pin(lock_future);
let lock_future = self.session.notify_handle.lock_internal();

self.read_state = ReadState::Acquire {
lock_future,
Expand All @@ -322,7 +284,7 @@ impl SessionFuture<'_> {
.take()
.expect("Unexpected acquire state without packet");

self.session.debug_log_packet("Receive", &packet);
debug_log_packet(self.session, "Receive", &packet);

let future = self.router.handle(self.session.clone(), packet);

Expand All @@ -337,7 +299,7 @@ impl SessionFuture<'_> {
let response = ready!(Pin::new(future).poll(cx));

// Send the response to the writer
_ = self.session.tx.send(response);
self.session.notify_handle.send_internal(response);

// Reset back to the reading state
self.read_state = ReadState::Recv;
Expand All @@ -363,3 +325,34 @@ impl Future for SessionFuture<'_> {
}
}
}

impl Drop for SessionFuture<'_> {
fn drop(&mut self) {
// Clear session data, speeds up process of ending the session
// prevents session data being accessed while shutting down
self.session.data.clear();
}
}

/// Logs debugging information about a player
fn debug_log_packet(session: &Session, action: &'static str, packet: &Packet) {
// Skip if debug logging is disabled
if !log_enabled!(log::Level::Debug) {
return;
}

let key = component_key(packet.frame.component, packet.frame.command);

// Don't log the packet if its debug ignored
if DEBUG_IGNORED_PACKETS.contains(&key) {
return;
}

let id = session.id;
let auth = session.data.get_player();

let debug_data = DebugSessionData { action, id, auth };
let debug_packet = PacketDebug { packet };

debug!("\n{:?}{:?}", debug_data, debug_packet);
}
2 changes: 1 addition & 1 deletion src/session/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl FromPacketRequest for GamePlayer {
player,
net_data,
Arc::downgrade(&req.state),
req.state.notify_handle(),
req.state.notify_handle.clone(),
))
})
}
Expand Down
Loading

0 comments on commit bcdd1bc

Please sign in to comment.