From e57dcfa6c62807b1d82ea03667651ae839a4eada Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 24 Nov 2024 22:30:41 -0800 Subject: [PATCH 1/6] Fix CSRF protection --- Cargo.lock | 2 +- docs/docs/controllers/pages.md | 4 +- .../templates/templates-in-controllers.md | 6 +-- docs/docs/views/templates/variables.md | 2 +- examples/files/src/controllers/mod.rs | 6 +-- examples/turbo/src/controllers/chat/mod.rs | 17 +++++--- .../turbo/src/controllers/chat/typing/mod.rs | 5 ++- examples/turbo/src/controllers/signup/mod.rs | 4 +- rwf-admin/src/controllers/models.rs | 10 +++-- rwf-admin/src/controllers/requests.rs | 4 +- rwf-macros/Cargo.toml | 2 +- rwf-macros/src/lib.rs | 2 +- rwf-macros/src/render.rs | 40 +++++++++++++++---- rwf/Cargo.toml | 2 +- rwf/src/controller/middleware/csrf.rs | 8 +++- rwf/src/controller/mod.rs | 2 +- rwf/src/crypto.rs | 12 ++++-- rwf/src/http/request.rs | 20 +++++++++- rwf/src/view/template/context.rs | 9 +++++ rwf/src/view/template/lexer/value.rs | 21 +++++++++- 20 files changed, 132 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 278bf295..c340ed11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1920,7 +1920,7 @@ dependencies = [ [[package]] name = "rwf-macros" -version = "0.1.11" +version = "0.1.12" dependencies = [ "pluralizer", "proc-macro2", diff --git a/docs/docs/controllers/pages.md b/docs/docs/controllers/pages.md index f98c5b5f..77c81d37 100644 --- a/docs/docs/controllers/pages.md +++ b/docs/docs/controllers/pages.md @@ -18,7 +18,7 @@ if request.get() { To avoid doing this and cluttering your codebase, Rwf comes with the [`PageController`](https://docs.rs/rwf/latest/rwf/controller/trait.PageController.html). This controller trait implements the `GET`/`POST` split automatically and routes requests to two separate methods: `async fn get` and `async fn post`. - Let's use the example of a login page and implement the `PageController` for it: +Let's use the example of a login page built using the `PageController`: ```rust use rwf::prelude::*; @@ -29,7 +29,7 @@ struct Login; impl PageController for Login { // Handle GET and show the login form. async fn get(&self, request: &Request) -> Result { - render!("templates/login.html") + render!(request, "templates/login.html") } // Handle POST, receive form data, check information, and diff --git a/docs/docs/views/templates/templates-in-controllers.md b/docs/docs/views/templates/templates-in-controllers.md index b14fe1d4..904f3218 100644 --- a/docs/docs/views/templates/templates-in-controllers.md +++ b/docs/docs/views/templates/templates-in-controllers.md @@ -26,7 +26,7 @@ Since it's very common to render templates inside controllers, Rwf has the `rend #[async_trait] impl Controller for Index { async fn handle(&self, request: &Request) -> Result { - render!("templates/index.html", "title" => "Home page") + render!(request, "templates/index.html", "title" => "Home page") } } ``` @@ -36,7 +36,7 @@ The `render!` macro takes the template path as the first argument, and optionall If the template doesn't have any variables, you can use `render!` with just the template name: ```rust -render!("templates/index.html") +render!(request, "templates/index.html") ``` ### Response code @@ -44,5 +44,5 @@ render!("templates/index.html") By default, the `render!` macro returns the rendered template with HTTP code `200 OK`. If you want to return a different code, pass it as the last argument to the macro: ```rust -render!("templates/index.html", "title" => "Home page", 201) +render!(request, "templates/index.html", "title" => "Home page", 201) ``` diff --git a/docs/docs/views/templates/variables.md b/docs/docs/views/templates/variables.md index fec947a7..3813b714 100644 --- a/docs/docs/views/templates/variables.md +++ b/docs/docs/views/templates/variables.md @@ -79,7 +79,7 @@ async fn main() { You can override default variables in each template, by specifying the variable value when rendering the template: ```rust -render!("templates/index.html", "global_var" => "Another value") +render!(request, "templates/index.html", "global_var" => "Another value") ``` diff --git a/examples/files/src/controllers/mod.rs b/examples/files/src/controllers/mod.rs index 9f52ac67..4025a96b 100644 --- a/examples/files/src/controllers/mod.rs +++ b/examples/files/src/controllers/mod.rs @@ -6,8 +6,8 @@ pub struct Upload; #[async_trait] impl PageController for Upload { /// Upload page. - async fn get(&self, _req: &Request) -> Result { - render!("templates/upload.html") + async fn get(&self, req: &Request) -> Result { + render!(req, "templates/upload.html") } /// Handle upload file. @@ -16,7 +16,7 @@ impl PageController for Upload { let comment = form_data.get_required::("comment")?; if let Some(file) = form_data.file("file") { - render!("templates/ok.html", + render!(req, "templates/ok.html", "name" => file.name(), "size" => file.body().len() as i64, "content_type" => file.content_type(), diff --git a/examples/turbo/src/controllers/chat/mod.rs b/examples/turbo/src/controllers/chat/mod.rs index 53bfffe3..d8ec2aba 100644 --- a/examples/turbo/src/controllers/chat/mod.rs +++ b/examples/turbo/src/controllers/chat/mod.rs @@ -33,8 +33,14 @@ impl Default for ChatController { } impl ChatController { - fn chat_message(user: &User, message: &ChatMessage, mine: bool) -> Result { + fn chat_message( + request: &Request, + user: &User, + message: &ChatMessage, + mine: bool, + ) -> Result { Ok(turbo_stream!( + request, "templates/chat_message.html", "messages", "message" => UserMessage { @@ -73,7 +79,7 @@ impl PageController for ChatController { }) .collect::>(); - render!("templates/chat.html", + render!(request, "templates/chat.html", "title" => "rwf + Turbo = chat", "messages" => messages, "user" => user @@ -99,16 +105,17 @@ impl PageController for ChatController { // Broadcast the message to everyone else. { let broadcast = Comms::broadcast(&user); - let message = Self::chat_message(&user, &message, false)?.render(); + let message = Self::chat_message(request, &user, &message, false)?.render(); broadcast.send(message)?; - broadcast.send(TypingState { typing: false }.render(&user)?)?; + broadcast.send(TypingState { typing: false }.render(request, &user)?)?; } // Display the message for the user. - let chat_message = Self::chat_message(&user, &message, true)?; + let chat_message = Self::chat_message(request, &user, &message, true)?; let form = turbo_stream!( + request, "templates/chat_form.html", "form", "user" => user, diff --git a/examples/turbo/src/controllers/chat/typing/mod.rs b/examples/turbo/src/controllers/chat/typing/mod.rs index b7df9234..86950dae 100644 --- a/examples/turbo/src/controllers/chat/typing/mod.rs +++ b/examples/turbo/src/controllers/chat/typing/mod.rs @@ -15,7 +15,7 @@ impl Controller for TypingController { if let Some(user) = user { let broadcast = Comms::broadcast(&user); - broadcast.send(state.render(&user)?)?; + broadcast.send(state.render(request, &user)?)?; Ok(serde_json::json!({ "status": "success", @@ -33,8 +33,9 @@ pub struct TypingState { } impl TypingState { - pub fn render(&self, user: &User) -> Result { + pub fn render(&self, request: &Request, user: &User) -> Result { let stream = turbo_stream!( + request, "templates/typing.html", "typing-indicators" "user" => user.clone() diff --git a/examples/turbo/src/controllers/signup/mod.rs b/examples/turbo/src/controllers/signup/mod.rs index 4bedc64d..6cca561e 100644 --- a/examples/turbo/src/controllers/signup/mod.rs +++ b/examples/turbo/src/controllers/signup/mod.rs @@ -29,8 +29,8 @@ impl Default for SignupController { #[async_trait] impl PageController for SignupController { /// Respond to GET request. - async fn get(&self, _request: &Request) -> Result { - render!("templates/signup.html", "title" => "Signup") + async fn get(&self, request: &Request) -> Result { + render!(request, "templates/signup.html", "title" => "Signup") } /// Respond to POST request. diff --git a/rwf-admin/src/controllers/models.rs b/rwf-admin/src/controllers/models.rs index b983a8e5..b6d6fcea 100644 --- a/rwf-admin/src/controllers/models.rs +++ b/rwf-admin/src/controllers/models.rs @@ -8,9 +8,10 @@ pub struct ModelsController; #[async_trait] impl Controller for ModelsController { - async fn handle(&self, _request: &Request) -> Result { + async fn handle(&self, request: &Request) -> Result { let tables = Table::load().await?; - render!("templates/rwf_admin/models.html", + render!(request, + "templates/rwf_admin/models.html", "title" => "Models | Rust Web Framework", "models" => tables ) @@ -85,7 +86,8 @@ impl PageController for ModelController { data.push(row.values()?); } - render!("templates/rwf_admin/model.html", + render!(request, + "templates/rwf_admin/model.html", "title" => format!("{} | Rust Web Framework", model), "table_name" => model, "columns" => columns, @@ -114,7 +116,7 @@ impl PageController for NewModelController { .filter(|c| !c.skip()) .collect::>(); - render!("templates/rwf_admin/model_new.html", + render!(request, "templates/rwf_admin/model_new.html", "title" => format!("New record | {} | Rust Web Framework", model), "table_name" => model, "columns" => columns, diff --git a/rwf-admin/src/controllers/requests.rs b/rwf-admin/src/controllers/requests.rs index ec7807bc..777aa834 100644 --- a/rwf-admin/src/controllers/requests.rs +++ b/rwf-admin/src/controllers/requests.rs @@ -6,7 +6,7 @@ pub struct Requests; #[async_trait] impl Controller for Requests { - async fn handle(&self, _request: &Request) -> Result { + async fn handle(&self, request: &Request) -> Result { let requests = { let mut conn = Pool::connection().await?; RequestByCode::count(60).fetch_all(&mut conn).await? @@ -20,7 +20,7 @@ impl Controller for Requests { let requests = serde_json::to_string(&requests)?; let duration = serde_json::to_string(&duration)?; - render!("templates/rwf_admin/requests.html", + render!(request, "templates/rwf_admin/requests.html", "title" => "Requests | Rust Web Framework", "requests" => requests, "duration" => duration, diff --git a/rwf-macros/Cargo.toml b/rwf-macros/Cargo.toml index 1c8b692a..2a2ec830 100644 --- a/rwf-macros/Cargo.toml +++ b/rwf-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rwf-macros" -version = "0.1.11" +version = "0.1.12" edition = "2021" license = "MIT" description = "Macros for the Rust Web Framework" diff --git a/rwf-macros/src/lib.rs b/rwf-macros/src/lib.rs index d808c378..3c4580a9 100644 --- a/rwf-macros/src/lib.rs +++ b/rwf-macros/src/lib.rs @@ -606,7 +606,7 @@ pub fn context(input: TokenStream) -> TokenStream { /// ### Example /// /// ```ignore -/// render!("templates/index.html", "title" => "Home page") +/// render!(request, "templates/index.html", "title" => "Home page") /// ``` #[proc_macro] pub fn render(input: TokenStream) -> TokenStream { diff --git a/rwf-macros/src/render.rs b/rwf-macros/src/render.rs index f0d03463..93c8588b 100644 --- a/rwf-macros/src/render.rs +++ b/rwf-macros/src/render.rs @@ -1,13 +1,17 @@ use crate::prelude::*; struct RenderInput { + request: Expr, + _comma_0: Token![,], template_name: LitStr, - _comma: Option, + _comma_1: Option, context: Vec, code: Option, } struct TurboStreamInput { + request: Expr, + _comma_0: Token![,], template_name: LitStr, _comma_1: Token![,], id: Expr, @@ -18,8 +22,10 @@ struct TurboStreamInput { impl TurboStreamInput { fn render_input(&self) -> RenderInput { RenderInput { + request: self.request.clone(), + _comma_0: self._comma_0.clone(), template_name: self.template_name.clone(), - _comma: self._comma_2.clone(), + _comma_1: self._comma_2.clone(), context: self.context.clone(), code: None, } @@ -28,6 +34,8 @@ impl TurboStreamInput { impl Parse for TurboStreamInput { fn parse(input: ParseStream) -> Result { + let request: Expr = input.parse()?; + let _comma_0: Token![,] = input.parse()?; let template_name: LitStr = input.parse()?; let _comma_1: Token![,] = input.parse()?; let id: Expr = input.parse()?; @@ -43,6 +51,8 @@ impl Parse for TurboStreamInput { } Ok(TurboStreamInput { + request, + _comma_0, template_name, _comma_1, id, @@ -61,11 +71,15 @@ struct ContextInput { } struct Context { + // request: Expr, + // _comma_0: Token![,], values: Vec, } impl Parse for Context { fn parse(input: ParseStream) -> Result { + // let request: Expr = input.parse()?; + // let _comma_0: Token![,] = input.parse()?; let mut values = vec![]; loop { let context: Result = input.parse(); @@ -77,7 +91,11 @@ impl Parse for Context { } } - Ok(Context { values }) + Ok(Context { + // request, + // _comma_0, + values, + }) } } @@ -94,11 +112,13 @@ impl Parse for ContextInput { impl Parse for RenderInput { fn parse(input: ParseStream) -> Result { + let request: Expr = input.parse()?; + let _comma_0: Token![,] = input.parse()?; let template_name: LitStr = input.parse()?; - let _comma: Option = input.parse()?; + let _comma_1: Option = input.parse()?; let mut code = None; - let context = if _comma.is_some() { + let context = if _comma_1.is_some() { let mut result = vec![]; loop { if input.peek(LitInt) { @@ -121,8 +141,10 @@ impl Parse for RenderInput { }; Ok(RenderInput { + request, + _comma_0, template_name, - _comma, + _comma_1, context, code, }) @@ -130,13 +152,15 @@ impl Parse for RenderInput { } fn render_call(input: &RenderInput) -> proc_macro2::TokenStream { + let request = &input.request; let render_call = if input.context.is_empty() { vec![quote! { - let html = template.render_default()?; + let context = rwf::view::template::Context::from_request(#request)?; + let html = template.render(&context)?; }] } else { let mut values = vec![quote! { - let mut context = rwf::view::template::Context::new(); + let mut context = rwf::view::template::Context::from_request(#request)?; }]; for value in &input.context { diff --git a/rwf/Cargo.toml b/rwf/Cargo.toml index 333bfd98..8e8279f9 100644 --- a/rwf/Cargo.toml +++ b/rwf/Cargo.toml @@ -32,7 +32,7 @@ parking_lot = "0.12" once_cell = "1" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -rwf-macros = { path = "../rwf-macros", version = "0.1.11" } +rwf-macros = { path = "../rwf-macros", version = "0.1.12" } colored = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/rwf/src/controller/middleware/csrf.rs b/rwf/src/controller/middleware/csrf.rs index 2f4afb04..3c9fb166 100644 --- a/rwf/src/controller/middleware/csrf.rs +++ b/rwf/src/controller/middleware/csrf.rs @@ -61,9 +61,13 @@ impl Middleware for Csrf { } let header = request.header(CSRF_HEADER); + let session_id = match request.session_id() { + Some(session_id) => session_id.to_string(), + None => return Ok(Outcome::Stop(request, Response::csrf_error())), + }; if let Some(header) = header { - if csrf_token_validate(header) { + if csrf_token_validate(header, &session_id) { return Ok(Outcome::Forward(request)); } } @@ -71,7 +75,7 @@ impl Middleware for Csrf { match request.form_data() { Ok(form_data) => { if let Some(token) = form_data.get::(CSRF_INPUT) { - if csrf_token_validate(&token) { + if csrf_token_validate(&token, &session_id) { return Ok(Outcome::Forward(request)); } } diff --git a/rwf/src/controller/mod.rs b/rwf/src/controller/mod.rs index 59347929..91219f8d 100644 --- a/rwf/src/controller/mod.rs +++ b/rwf/src/controller/mod.rs @@ -300,7 +300,7 @@ pub trait Controller: Sync + Send { /// impl PageController for MyPage { /// // Respond to a GET request. /// async fn get(&self, request: &Request) -> Result { -/// render!("templates/my_page.html") +/// render!(request, "templates/my_page.html") /// } /// } /// ``` diff --git a/rwf/src/crypto.rs b/rwf/src/crypto.rs index 4f4040f0..2c1d3fac 100644 --- a/rwf/src/crypto.rs +++ b/rwf/src/crypto.rs @@ -241,9 +241,13 @@ pub fn random_string(n: usize) -> String { /// /// let token = csrf_token().unwrap(); /// ``` -pub fn csrf_token() -> Result { +pub fn csrf_token(session_id: &str) -> Result { // Our encryption is salted, re-using some known plain text isn't an issue. - let token = format!("{}_csrf", OffsetDateTime::now_utc().unix_timestamp()); + let token = format!( + "{}_{}", + OffsetDateTime::now_utc().unix_timestamp(), + session_id + ); encrypt(token.as_bytes()) } @@ -257,7 +261,7 @@ pub fn csrf_token() -> Result { /// let token = csrf_token().unwrap(); /// assert!(csrf_token_validate(&token)); /// ``` -pub fn csrf_token_validate(token: &str) -> bool { +pub fn csrf_token_validate(token: &str, session_id: &str) -> bool { match decrypt(token) { Ok(value) => { let value = String::from_utf8_lossy(&value).to_string(); @@ -277,7 +281,7 @@ pub fn csrf_token_validate(token: &str) -> bool { return false; }; - if marker.is_none() { + if marker != Some(session_id) { return false; } diff --git a/rwf/src/http/request.rs b/rwf/src/http/request.rs index 3299bd7e..665bbf03 100644 --- a/rwf/src/http/request.rs +++ b/rwf/src/http/request.rs @@ -1,9 +1,9 @@ //! HTTP request. -use std::fmt::Debug; use std::marker::Unpin; use std::net::SocketAddr; use std::ops::Deref; use std::sync::Arc; +use std::{collections::HashMap, fmt::Debug}; use serde::Deserialize; use serde_json::{Deserializer, Value}; @@ -15,6 +15,7 @@ use crate::{ config::get_config, controller::{Session, SessionId}, model::{ConnectionGuard, Model}, + view::ToTemplateValue, }; /// HTTP request. @@ -354,6 +355,23 @@ impl Deref for Request { } } +impl ToTemplateValue for Request { + fn to_template_value(&self) -> Result { + let mut hash = HashMap::new(); + hash.insert( + "path".to_string(), + self.path().to_string().to_template_value()?, + ); + hash.insert( + "session_id".to_string(), + self.session() + .map(|s| s.session_id.to_string()) + .to_template_value()?, + ); + Ok(crate::view::Value::Hash(hash)) + } +} + #[cfg(test)] pub mod test { use super::*; diff --git a/rwf/src/view/template/context.rs b/rwf/src/view/template/context.rs index ef055c5c..fce7f43d 100644 --- a/rwf/src/view/template/context.rs +++ b/rwf/src/view/template/context.rs @@ -9,6 +9,7 @@ //! let ctx = context!("var" => 1, "title" => "hello world!"); //! ``` //! +use crate::http::Request; use crate::view::template::{Error, ToTemplateValue, Value}; use parking_lot::RwLock; use std::collections::HashMap; @@ -32,6 +33,14 @@ impl Context { DEFAULTS.read().clone() } + /// Create template context from request. + pub fn from_request(request: &Request) -> Result { + let mut ctx = Self::new(); + ctx.set("request", request.to_template_value()?)?; + + Ok(ctx) + } + /// Get a variable value. pub fn get(&self, key: &str) -> Option { self.values.get(key).cloned() diff --git a/rwf/src/view/template/lexer/value.rs b/rwf/src/view/template/lexer/value.rs index 58e91751..5425d81d 100644 --- a/rwf/src/view/template/lexer/value.rs +++ b/rwf/src/view/template/lexer/value.rs @@ -185,6 +185,14 @@ impl Value { args: &[Value], context: &Context, ) -> Result { + let session_id = match context.get("request") { + Some(Value::Hash(hash)) => hash + .get("session_id") + .unwrap_or(&Value::String("".into())) + .to_string(), + + _ => "".to_string(), + }; match method_name { "nil" | "null" | "blank" => return Ok(Value::Boolean(self == &Value::Null)), "integer" => { @@ -378,11 +386,11 @@ impl Value { } }, - "csrf_token_raw" => Value::SafeString(crypto::csrf_token().unwrap()), + "csrf_token_raw" => Value::SafeString(crypto::csrf_token(&session_id).unwrap()), "csrf_token" => Value::SafeString(format!( r#""#, CSRF_INPUT, - crypto::csrf_token().unwrap(), + crypto::csrf_token(&session_id).unwrap(), )), "render" => match &args { @@ -465,6 +473,15 @@ impl ToTemplateValue for String { } } +impl ToTemplateValue for Option { + fn to_template_value(&self) -> Result { + match self { + Some(s) => Ok(s.to_template_value()?), + None => Ok(Value::Null), + } + } +} + impl ToTemplateValue for &str { fn to_template_value(&self) -> Result { Ok(Value::String(self.to_string())) From da1d30a6e8283f395b75bcaf8c904cb455e4645f Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 24 Nov 2024 22:37:47 -0800 Subject: [PATCH 2/6] fix tests --- rwf/src/crypto.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rwf/src/crypto.rs b/rwf/src/crypto.rs index 2c1d3fac..4001c889 100644 --- a/rwf/src/crypto.rs +++ b/rwf/src/crypto.rs @@ -239,7 +239,7 @@ pub fn random_string(n: usize) -> String { /// ``` /// use rwf::crypto::csrf_token; /// -/// let token = csrf_token().unwrap(); +/// let token = csrf_token("1234").unwrap(); /// ``` pub fn csrf_token(session_id: &str) -> Result { // Our encryption is salted, re-using some known plain text isn't an issue. @@ -258,8 +258,8 @@ pub fn csrf_token(session_id: &str) -> Result { /// /// ``` /// # use rwf::crypto::{csrf_token, csrf_token_validate}; -/// let token = csrf_token().unwrap(); -/// assert!(csrf_token_validate(&token)); +/// let token = csrf_token("1234").unwrap(); +/// assert!(csrf_token_validate(&token, "1234")); /// ``` pub fn csrf_token_validate(token: &str, session_id: &str) -> bool { match decrypt(token) { From 4d3e4cdf3454114b61c07c313464cecedc8a05e5 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 24 Nov 2024 22:38:16 -0800 Subject: [PATCH 3/6] remove irrelevant comment --- rwf/src/crypto.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rwf/src/crypto.rs b/rwf/src/crypto.rs index 4001c889..73fb0eb2 100644 --- a/rwf/src/crypto.rs +++ b/rwf/src/crypto.rs @@ -242,7 +242,6 @@ pub fn random_string(n: usize) -> String { /// let token = csrf_token("1234").unwrap(); /// ``` pub fn csrf_token(session_id: &str) -> Result { - // Our encryption is salted, re-using some known plain text isn't an issue. let token = format!( "{}_{}", OffsetDateTime::now_utc().unix_timestamp(), From bd45c009067518bb9427ae206d44aab11aca5c78 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sun, 24 Nov 2024 23:01:21 -0800 Subject: [PATCH 4/6] add date header --- rwf/src/http/response.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rwf/src/http/response.rs b/rwf/src/http/response.rs index 5345d31a..ef07ee1e 100644 --- a/rwf/src/http/response.rs +++ b/rwf/src/http/response.rs @@ -15,6 +15,7 @@ use once_cell::sync::Lazy; use serde::Serialize; use std::collections::HashMap; use std::marker::Unpin; +use time::OffsetDateTime; use tokio::io::{AsyncWrite, AsyncWriteExt}; use super::{head::Version, Body, Cookie, Cookies, Error, Headers, Request}; @@ -209,6 +210,12 @@ impl Response { ("content-type".to_string(), "text/plain".to_string()), ("server".to_string(), "rwf".to_string()), ("connection".to_string(), "keep-alive".to_string()), + ( + "date".to_string(), + OffsetDateTime::now_utc() + .format(&time::format_description::well_known::Rfc2822) + .unwrap(), + ), ])), body: Body::bytes(vec![]), version: Version::Http1, From e4710798b560e8de1b654f1cc16bcce45a592eef Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 25 Nov 2024 14:06:00 -0800 Subject: [PATCH 5/6] Add whole session to template --- rwf/src/controller/auth.rs | 19 +++++++++++++++++++ rwf/src/http/request.rs | 13 ++++++++----- rwf/src/view/template/lexer/value.rs | 16 +++++++++++----- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/rwf/src/controller/auth.rs b/rwf/src/controller/auth.rs index f34973aa..0ca892ff 100644 --- a/rwf/src/controller/auth.rs +++ b/rwf/src/controller/auth.rs @@ -7,11 +7,13 @@ use super::Error; use crate::comms::WebsocketSender; use crate::config::get_config; use crate::http::{Authorization, Request, Response}; +use crate::view::{ToTemplateValue, Value}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use time::{Duration, OffsetDateTime}; +use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; @@ -215,6 +217,23 @@ impl Default for Session { } } +impl ToTemplateValue for Session { + fn to_template_value(&self) -> Result { + let mut hash = HashMap::new(); + hash.insert("expiration".into(), Value::Integer(self.expiration)); + hash.insert( + "session_id".into(), + Value::String(self.session_id.to_string()), + ); + hash.insert( + "payload".into(), + Value::String(serde_json::to_string(&self.payload).unwrap()), + ); + + Ok(Value::Hash(hash)) + } +} + impl Session { /// Create a guest session. pub fn anonymous() -> Self { diff --git a/rwf/src/http/request.rs b/rwf/src/http/request.rs index 665bbf03..33c3e058 100644 --- a/rwf/src/http/request.rs +++ b/rwf/src/http/request.rs @@ -357,18 +357,21 @@ impl Deref for Request { impl ToTemplateValue for Request { fn to_template_value(&self) -> Result { + use crate::view::Value; + let mut hash = HashMap::new(); hash.insert( "path".to_string(), self.path().to_string().to_template_value()?, ); hash.insert( - "session_id".to_string(), - self.session() - .map(|s| s.session_id.to_string()) - .to_template_value()?, + "session".to_string(), + match self.session() { + Some(session) => session.to_template_value()?, + None => Value::Null, + }, ); - Ok(crate::view::Value::Hash(hash)) + Ok(Value::Hash(hash)) } } diff --git a/rwf/src/view/template/lexer/value.rs b/rwf/src/view/template/lexer/value.rs index 5425d81d..6166261f 100644 --- a/rwf/src/view/template/lexer/value.rs +++ b/rwf/src/view/template/lexer/value.rs @@ -185,14 +185,20 @@ impl Value { args: &[Value], context: &Context, ) -> Result { + let default_session_id = "".to_string(); let session_id = match context.get("request") { - Some(Value::Hash(hash)) => hash - .get("session_id") - .unwrap_or(&Value::String("".into())) - .to_string(), + Some(Value::Hash(hash)) => match hash.get("session") { + Some(Value::Hash(session)) => match session.get("session_id") { + Some(session_id) => session_id.to_string(), + None => default_session_id, + }, + + _ => default_session_id, + }, - _ => "".to_string(), + _ => default_session_id, }; + match method_name { "nil" | "null" | "blank" => return Ok(Value::Boolean(self == &Value::Null)), "integer" => { From cc67038e8af43cf9dc72f72893bb2532323edc1e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 25 Nov 2024 14:47:03 -0800 Subject: [PATCH 6/6] Cleaner --- .../templates/templates-in-controllers.md | 6 ++++-- rwf/src/view/template/context.rs | 18 +++++++++++++++++ rwf/src/view/template/lexer/value.rs | 20 ++++--------------- scripts/docs.sh | 7 +++++++ 4 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 scripts/docs.sh diff --git a/docs/docs/views/templates/templates-in-controllers.md b/docs/docs/views/templates/templates-in-controllers.md index 904f3218..51ba695e 100644 --- a/docs/docs/views/templates/templates-in-controllers.md +++ b/docs/docs/views/templates/templates-in-controllers.md @@ -31,14 +31,16 @@ impl Controller for Index { } ``` -The `render!` macro takes the template path as the first argument, and optionally, a mapping of variable names and values as subsequent arguments. It returns a [`Response`](../../controllers/response.md) automatically. +The `render!` macro takes the request as the first argument, the template path, and optionally a mapping of variable names and values. It returns a [`Response`](../../controllers/response.md) automatically. -If the template doesn't have any variables, you can use `render!` with just the template name: +If the template doesn't have any variables, you can omit them: ```rust render!(request, "templates/index.html") ``` +Passing the request into the macro ensures that secure [CSRF](../../security/CSRF.md) protection tokens are generated automatically. + ### Response code By default, the `render!` macro returns the rendered template with HTTP code `200 OK`. If you want to return a different code, pass it as the last argument to the macro: diff --git a/rwf/src/view/template/context.rs b/rwf/src/view/template/context.rs index fce7f43d..f8d097ed 100644 --- a/rwf/src/view/template/context.rs +++ b/rwf/src/view/template/context.rs @@ -57,6 +57,24 @@ impl Context { pub fn defaults(context: Self) { (*DEFAULTS.write()) = context; } + + /// Get the request session ID from the context, if any. + pub fn session_id(&self) -> Result { + match self.get("request") { + Some(Value::Hash(hash)) => match hash.get("session") { + Some(Value::Hash(session)) => match session.get("session_id") { + Some(session_id) => Ok(session_id.to_string()), + None => Err(Error::Runtime( + "session_id is missing from the context".into(), + )), + }, + + _ => Err(Error::Runtime("session is missing from the context".into())), + }, + + _ => Err(Error::Runtime("request is missing from the context".into())), + } + } } impl ToTemplateValue for Context { diff --git a/rwf/src/view/template/lexer/value.rs b/rwf/src/view/template/lexer/value.rs index 6166261f..11f7b7bd 100644 --- a/rwf/src/view/template/lexer/value.rs +++ b/rwf/src/view/template/lexer/value.rs @@ -185,20 +185,6 @@ impl Value { args: &[Value], context: &Context, ) -> Result { - let default_session_id = "".to_string(); - let session_id = match context.get("request") { - Some(Value::Hash(hash)) => match hash.get("session") { - Some(Value::Hash(session)) => match session.get("session_id") { - Some(session_id) => session_id.to_string(), - None => default_session_id, - }, - - _ => default_session_id, - }, - - _ => default_session_id, - }; - match method_name { "nil" | "null" | "blank" => return Ok(Value::Boolean(self == &Value::Null)), "integer" => { @@ -392,11 +378,13 @@ impl Value { } }, - "csrf_token_raw" => Value::SafeString(crypto::csrf_token(&session_id).unwrap()), + "csrf_token_raw" => { + Value::SafeString(crypto::csrf_token(&context.session_id()?).unwrap()) + } "csrf_token" => Value::SafeString(format!( r#""#, CSRF_INPUT, - crypto::csrf_token(&session_id).unwrap(), + crypto::csrf_token(&context.session_id()?).unwrap(), )), "render" => match &args { diff --git a/scripts/docs.sh b/scripts/docs.sh new file mode 100644 index 00000000..005a6613 --- /dev/null +++ b/scripts/docs.sh @@ -0,0 +1,7 @@ +#!/bin/bash +DIR="$( cd "$( dirname "$0" )" && pwd )" +cd "$DIR" +cd ../docs + +source venv/bin/activate +mkdocs serve