From 8271b749af4b516a73368ca4613ef614510aab0f Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 4 Dec 2024 10:05:36 -0800 Subject: [PATCH] fix session --- examples/auth/src/main.rs | 2 +- .../src/controllers/signup/middleware.rs | 6 +- rwf-tests/src/main.rs | 9 ++- rwf/src/controller/auth.rs | 6 +- rwf/src/controller/middleware/csrf.rs | 5 +- rwf/src/controller/mod.rs | 6 +- rwf/src/http/request.rs | 58 ++++++------------- rwf/src/http/response.rs | 25 +++----- rwf/src/http/server.rs | 5 +- rwf/src/model/mod.rs | 2 +- 10 files changed, 42 insertions(+), 82 deletions(-) diff --git a/examples/auth/src/main.rs b/examples/auth/src/main.rs index 58f27d65..478b5ac1 100644 --- a/examples/auth/src/main.rs +++ b/examples/auth/src/main.rs @@ -59,7 +59,7 @@ impl Controller for ProtectedAreaController { } async fn handle(&self, request: &Request) -> Result { - let session = request.session().unwrap(); + let session = request.session(); let welcome = format!("

Welcome, user {:?}

", session.session_id); Ok(Response::new().html(welcome)) } diff --git a/examples/turbo/src/controllers/signup/middleware.rs b/examples/turbo/src/controllers/signup/middleware.rs index 4a8b4023..ac59aa43 100644 --- a/examples/turbo/src/controllers/signup/middleware.rs +++ b/examples/turbo/src/controllers/signup/middleware.rs @@ -7,10 +7,8 @@ pub struct LoggedInCheck; #[rwf::async_trait] impl Middleware for LoggedInCheck { async fn handle_request(&self, request: Request) -> Result { - if let Some(session) = request.session() { - if session.authenticated() { - return Ok(Outcome::Stop(request, Response::new().redirect("/chat"))); - } + if request.session().authenticated() { + return Ok(Outcome::Stop(request, Response::new().redirect("/chat"))); } Ok(Outcome::Forward(request)) diff --git a/rwf-tests/src/main.rs b/rwf-tests/src/main.rs index 7c4d01e6..7ad34dba 100644 --- a/rwf-tests/src/main.rs +++ b/rwf-tests/src/main.rs @@ -104,11 +104,10 @@ impl RestController for BasePlayerController { type Resource = i64; async fn get(&self, request: &Request, id: &i64) -> Result { - if let Some(session) = request.session() { - session - .websocket() - .send(websocket::Message::Text("controller websocket".into()))?; - } + request + .session() + .websocket() + .send(websocket::Message::Text("controller websocket".into()))?; Ok(Response::new().html(format!("

base player controller, id: {}

", id))) } diff --git a/rwf/src/controller/auth.rs b/rwf/src/controller/auth.rs index dfb6d758..072bd286 100644 --- a/rwf/src/controller/auth.rs +++ b/rwf/src/controller/auth.rs @@ -329,11 +329,7 @@ impl SessionAuth { #[async_trait] impl Authentication for SessionAuth { async fn authorize(&self, request: &Request) -> Result { - if let Some(session) = request.session() { - Ok(session.authenticated()) - } else { - Ok(false) - } + Ok(request.session().authenticated()) } async fn denied(&self, _request: &Request) -> Result { diff --git a/rwf/src/controller/middleware/csrf.rs b/rwf/src/controller/middleware/csrf.rs index 3c9fb166..36d7bc69 100644 --- a/rwf/src/controller/middleware/csrf.rs +++ b/rwf/src/controller/middleware/csrf.rs @@ -61,10 +61,7 @@ impl Middleware for Csrf { } let header = request.header(CSRF_HEADER); - let session_id = match request.session_id() { - Some(session_id) => session_id.to_string(), - None => return Ok(Outcome::Stop(request, Response::csrf_error())), - }; + let session_id = request.session_id().to_string(); if let Some(header) = header { if csrf_token_validate(header, &session_id) { diff --git a/rwf/src/controller/mod.rs b/rwf/src/controller/mod.rs index 752890c3..50006213 100644 --- a/rwf/src/controller/mod.rs +++ b/rwf/src/controller/mod.rs @@ -728,11 +728,7 @@ pub trait WebsocketController: Controller { ) -> Result { use tokio::sync::broadcast::error::RecvError; - let session_id = if let Some(session) = request.session() { - session.session_id.clone() - } else { - return Err(Error::SessionMissingError); - }; + let session_id = request.session().session_id.clone(); info!( "{} {} {} connected", diff --git a/rwf/src/http/request.rs b/rwf/src/http/request.rs index 50ac7e10..eb2f10f1 100644 --- a/rwf/src/http/request.rs +++ b/rwf/src/http/request.rs @@ -22,7 +22,7 @@ use crate::{ #[derive(Debug, Clone)] pub struct Request { head: Head, - session: Option, + session: Session, inner: Arc, params: Option>, received_at: OffsetDateTime, @@ -35,7 +35,7 @@ impl Default for Request { fn default() -> Self { Self { head: Head::default(), - session: None, + session: Session::default(), inner: Arc::new(Inner::default()), params: None, received_at: OffsetDateTime::now_utc(), @@ -100,8 +100,8 @@ impl Request { let cookies = head.cookies(); let (session, renew_session) = match cookies.get_session()? { - Some(session) => (Some(session), false), - None => (Some(Session::anonymous()), true), + Some(session) => (session, false), + None => (Session::anonymous(), true), }; Ok(Request { @@ -214,8 +214,8 @@ impl Request { /// Get the session set on the request, if any. While all requests served /// by Rwf should have a session (guest or authenticated), some HTTP clients /// may not send the cookie back (e.g. cURL won't). - pub fn session(&self) -> Option<&Session> { - self.session.as_ref() + pub fn session(&self) -> &Session { + &self.session } /// Was the CSRF protection bypassed on this request? @@ -235,22 +235,16 @@ impl Request { /// /// This should uniquely identify a browser if it's a guest session, /// or a user if the user is logged in. - pub fn session_id(&self) -> Option { - self.session - .as_ref() - .map(|session| session.session_id.clone()) + pub fn session_id(&self) -> SessionId { + self.session.session_id.clone() } /// Get the authenticated user's ID. Combined with the `?` operator, /// will return `403 - Unauthorized` if not logged in. pub fn user_id(&self) -> Result { - if let Some(session_id) = self.session_id() { - match session_id { - SessionId::Authenticated(id) => Ok(id), - _ => Err(Error::Forbidden), - } - } else { - Err(Error::Forbidden) + match self.session_id() { + SessionId::Authenticated(id) => Ok(id), + _ => Err(Error::Forbidden), } } @@ -273,9 +267,7 @@ impl Request { /// ``` pub async fn user(&self, conn: &mut ConnectionGuard) -> Result, Error> { match self.session_id() { - Some(SessionId::Authenticated(user_id)) => { - Ok(Some(T::find(user_id).fetch(conn).await?)) - } + SessionId::Authenticated(user_id) => Ok(Some(T::find(user_id).fetch(conn).await?)), _ => Ok(None), } @@ -294,8 +286,9 @@ impl Request { /// /// This is automatically done by the HTTP server, /// if the session is available. - pub fn set_session(mut self, session: Option) -> Self { + pub(crate) fn set_session(mut self, session: Session) -> Self { self.session = session; + self.renew_session = true; self } @@ -328,10 +321,7 @@ impl Request { /// let response = request.login(1234); /// ``` pub fn login(&self, user_id: i64) -> Response { - let mut session = self - .session() - .map(|s| s.clone()) - .unwrap_or(Session::empty()); + let mut session = self.session.clone(); session.session_id = SessionId::Authenticated(user_id); Response::new().set_session(session).html("") } @@ -385,12 +375,7 @@ impl Request { /// let response = request.logout(); /// ``` pub fn logout(&self) -> Response { - let mut session = self - .session() - .map(|s| s.clone()) - .unwrap_or(Session::empty()); - session.session_id = SessionId::default(); - Response::new().set_session(session).html("") + Response::new().set_session(Session::anonymous()).html("") } pub(crate) fn renew_session(&self) -> bool { @@ -416,13 +401,7 @@ impl ToTemplateValue for Request { "query".to_string(), self.path().query().to_string().to_template_value()?, ); - hash.insert( - "session".to_string(), - match self.session() { - Some(session) => session.to_template_value()?, - None => Value::Null, - }, - ); + hash.insert("session".to_string(), self.session().to_template_value()?); Ok(Value::Hash(hash)) } } @@ -479,12 +458,11 @@ pub mod test { assert_eq!(req.peer(), &dummy_ip()); assert_eq!(req.upgrade_websocket(), false); assert_eq!(req.skip_csrf(), false); - assert_eq!(req.session(), None); + assert!(!req.session().authenticated()); assert!(req.user_id().is_err()); assert_eq!(req.body(), b"12345"); assert_eq!(req.string(), "12345".to_string()); assert!(req.form_data().is_err()); - assert!(req.session_id().is_none()); assert_eq!(req.query().len(), 1); assert_eq!(req.path().base(), "/apples"); diff --git a/rwf/src/http/response.rs b/rwf/src/http/response.rs index d5a377f6..6e078b4a 100644 --- a/rwf/src/http/response.rs +++ b/rwf/src/http/response.rs @@ -228,28 +228,21 @@ impl Response { /// /// This makes sure a valid session cookie is set on all responses. pub fn from_request(mut self, request: &Request) -> Result { - // Set an anonymous session if none is set on the request. - if self.session.is_none() && request.session().is_none() { - self.session = Some(Session::anonymous()); - } - // Session set manually on the request already. if let Some(ref session) = self.session { self.cookies.add_session(&session)?; } else { let session = request.session(); - if let Some(session) = session { - if session.should_renew() || request.renew_session() { - let session = session - .clone() - .renew(get_config().general.session_duration()); - self.cookies.add_session(&session)?; - - // Set the session on the response, so it can be - // passed down in handle_stream. - self.session = Some(session); - } + if session.should_renew() || request.renew_session() { + let session = session + .clone() + .renew(get_config().general.session_duration()); + self.cookies.add_session(&session)?; + + // Set the session on the response, so it can be + // passed down in handle_stream. + self.session = Some(session); } } diff --git a/rwf/src/http/server.rs b/rwf/src/http/server.rs index 0201c19c..1f951db3 100644 --- a/rwf/src/http/server.rs +++ b/rwf/src/http/server.rs @@ -151,7 +151,10 @@ impl Server { // Set the session on the request before we pass it down // to the stream handler. - let request = request.set_session(response.session().clone()); + let request = match response.session().clone() { + Some(session) => request.set_session(session), + None => request, + }; let ok = response.status().ok(); // Calculate duration. diff --git a/rwf/src/model/mod.rs b/rwf/src/model/mod.rs index dd1b8fa4..b6d204e1 100644 --- a/rwf/src/model/mod.rs +++ b/rwf/src/model/mod.rs @@ -4,7 +4,7 @@ use crate::colors::MaybeColorize; use crate::config::get_config; -use pool::{ConnectionRequest, ToConnectionRequest}; +use pool::ToConnectionRequest; use std::time::{Duration, Instant}; use tracing::{error, info};