diff --git a/examples/users/src/main.rs b/examples/users/src/main.rs index 332e67b..158c7b5 100644 --- a/examples/users/src/main.rs +++ b/examples/users/src/main.rs @@ -1,3 +1,4 @@ +use rwf::controller::LoginController; use rwf::{http::Server, prelude::*}; mod controllers; @@ -7,8 +8,11 @@ mod models; async fn main() { Logger::init(); + let signup: LoginController = + LoginController::new("templates/signup.html").redirect("/profile"); + Server::new(vec![ - route!("/signup" => controllers::Signup), + route!("/signup" => { signup }), route!("/login" => controllers::login), route!("/profile" => controllers::profile), ]) diff --git a/examples/users/src/models.rs b/examples/users/src/models.rs index 166096d..e468dcc 100644 --- a/examples/users/src/models.rs +++ b/examples/users/src/models.rs @@ -3,13 +3,22 @@ use rwf::crypto::{hash, hash_validate}; use rwf::prelude::*; use tokio::task::spawn_blocking; +#[derive(Clone, macros::Model, macros::UserModel)] +#[user_model(email, password_hash)] +pub struct User2 { + id: Option, + email: String, + password_hash: String, +} + pub enum UserLogin { NoSuchUser, WrongPassword, Ok(User), } -#[derive(Clone, macros::Model)] +#[derive(Clone, macros::Model, macros::UserModel)] +#[user_model(email, password)] pub struct User { id: Option, email: String, diff --git a/examples/users/templates/signup.html b/examples/users/templates/signup.html index fe25897..b757dfd 100644 --- a/examples/users/templates/signup.html +++ b/examples/users/templates/signup.html @@ -1,14 +1,12 @@ - - - + <%% "templates/head.html" %>
- <% if error %> + <% if error_user_exists %>
Account with this email already exists, and the password is incorrect.
@@ -16,12 +14,23 @@ <%= csrf_token() %>
- + + <% if error_identifier %> +
+ Provided email is not valid. +
+ <% end %>
+ + <% if error_password %> +
+ Provided password is not valid. +
+ <% end %>
diff --git a/rwf-macros/src/lib.rs b/rwf-macros/src/lib.rs index 784a686..9afed36 100644 --- a/rwf-macros/src/lib.rs +++ b/rwf-macros/src/lib.rs @@ -1,17 +1,12 @@ extern crate proc_macro; -use proc_macro::TokenStream; - -use syn::{ - parse_macro_input, punctuated::Punctuated, Attribute, Data, DeriveInput, Expr, ItemFn, Meta, - ReturnType, Token, Type, Visibility, -}; - -use quote::quote; - mod model; mod prelude; mod render; +mod route; +mod user; + +use prelude::*; /// The `#[derive(Model)]` macro. /// @@ -522,16 +517,7 @@ pub fn error(input: TokenStream) -> TokenStream { /// ``` #[proc_macro] pub fn route(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input with Punctuated]>::parse_terminated); - let mut iter = input.into_iter(); - - let route = iter.next().unwrap(); - let controller = iter.next().unwrap(); - - quote! { - #controller::default().route(#route) - } - .into() + route::route_impl(input) } /// Create CRUD routes for the controller. @@ -737,3 +723,8 @@ pub fn controller(_args: TokenStream, input: TokenStream) -> TokenStream { } .into() } + +#[proc_macro_derive(UserModel, attributes(user_model))] +pub fn derive_user_model(input: TokenStream) -> TokenStream { + user::impl_derive_user_model(input) +} diff --git a/rwf-macros/src/prelude.rs b/rwf-macros/src/prelude.rs index d54fb65..55b1f54 100644 --- a/rwf-macros/src/prelude.rs +++ b/rwf-macros/src/prelude.rs @@ -1,4 +1,5 @@ pub use proc_macro::TokenStream; pub use quote::quote; pub use syn::parse::*; +pub use syn::punctuated::Punctuated; pub use syn::*; diff --git a/rwf-macros/src/route.rs b/rwf-macros/src/route.rs new file mode 100644 index 0000000..0a353a9 --- /dev/null +++ b/rwf-macros/src/route.rs @@ -0,0 +1,34 @@ +use crate::prelude::*; + +struct RouteInput { + route: Expr, + controller: Expr, +} + +impl Parse for RouteInput { + fn parse(input: ParseStream) -> Result { + let route = input.parse()?; + let _arrow: Token![=>] = input.parse()?; + + let controller = input.parse()?; + + Ok(Self { route, controller }) + } +} + +pub(crate) fn route_impl(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as RouteInput); + + let route = input.route; + let controller = input.controller; + + let controller = match controller { + Expr::Path(expr) => quote! { #expr::default() }, + expr => quote! { #expr }, + }; + + quote! { + #controller.route(#route) + } + .into() +} diff --git a/rwf-macros/src/user.rs b/rwf-macros/src/user.rs new file mode 100644 index 0000000..83bf95a --- /dev/null +++ b/rwf-macros/src/user.rs @@ -0,0 +1,52 @@ +use super::prelude::*; + +struct UserModel { + identifier: Ident, + password: Ident, +} + +impl Parse for UserModel { + fn parse(input: parse::ParseStream) -> Result { + let identifier: Ident = input.parse()?; + let _: Token![,] = input.parse()?; + let password = input.parse()?; + + Ok(UserModel { + identifier, + password, + }) + } +} + +pub(crate) fn impl_derive_user_model(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let ident = &input.ident; + + if let Some(attr) = input.attrs.first() { + match attr.meta { + Meta::List(ref attrs) => { + let attrs = syn::parse2::(attrs.tokens.clone()).unwrap(); + + let identifier = attrs.identifier.to_string(); + let password = attrs.password.to_string(); + + return quote! { + impl rwf::model::UserModel for #ident { + fn identifier_column() -> &'static str { + #identifier + } + + fn password_column() -> &'static str { + #password + } + } + } + .into(); + } + + _ => (), + } + } + + quote! {}.into() +} diff --git a/rwf/src/controller/error.rs b/rwf/src/controller/error.rs index 91f8ee4..9092639 100644 --- a/rwf/src/controller/error.rs +++ b/rwf/src/controller/error.rs @@ -48,6 +48,9 @@ pub enum Error { #[error("timeout exceeded")] TimeoutError(#[from] tokio::time::error::Elapsed), + + #[error("user error: {0}")] + UserError(#[from] crate::model::user::Error), } impl Error { diff --git a/rwf/src/controller/mod.rs b/rwf/src/controller/mod.rs index 5000621..7cd52fe 100644 --- a/rwf/src/controller/mod.rs +++ b/rwf/src/controller/mod.rs @@ -33,6 +33,7 @@ pub mod middleware; pub mod ser; pub mod static_files; pub mod turbo_stream; +pub mod user; #[cfg(feature = "wsgi")] pub mod wsgi; @@ -50,6 +51,7 @@ pub use error::Error; pub use middleware::{Middleware, MiddlewareHandler, MiddlewareSet, Outcome, RateLimiter}; pub use static_files::{CacheControl, StaticFiles}; pub use turbo_stream::TurboStream; +pub use user::LoginController; use super::http::{ websocket::{self, DataFrame}, diff --git a/rwf/src/controller/user.rs b/rwf/src/controller/user.rs new file mode 100644 index 0000000..c0a68c9 --- /dev/null +++ b/rwf/src/controller/user.rs @@ -0,0 +1,85 @@ +use std::marker::PhantomData; + +use super::{Controller, Error, PageController}; +use crate::view::{Context, Template}; +use crate::{ + http::{Request, Response}, + model::{user::Error as UserError, UserModel}, +}; +use async_trait::async_trait; + +pub struct LoginController { + redirect: Option, + template: &'static str, + _marker: PhantomData, +} + +impl LoginController { + pub fn new(template: &'static str) -> Self { + Self { + redirect: None, + template, + _marker: PhantomData, + } + } + + pub fn redirect(mut self, redirect: impl ToString) -> Self { + self.redirect = Some(redirect.to_string()); + self + } + + fn error(&self, request: &Request, error: &str) -> Result { + let template = Template::load(self.template)?; + let mut ctx = Context::new(); + ctx.set(error, true)?; + ctx.set("request", request.clone())?; + Ok(Response::new().html(template.render(&ctx)?).code(400)) + } +} + +#[async_trait] +impl PageController for LoginController { + async fn get(&self, request: &Request) -> Result { + let mut ctx = Context::new(); + ctx.set("request", request.clone())?; + let template = Template::load(self.template)?; + Ok(Response::new().html(template.render(&ctx)?)) + } + + async fn post(&self, request: &Request) -> Result { + let form = request.form_data()?; + let identifier: String = match form.get_required("identifier") { + Ok(field) => field, + Err(_) => return self.error(request, "error_identifier"), + }; + let identifier = identifier.trim().to_string(); + let password: String = match form.get_required("password") { + Ok(field) => field, + Err(_) => return self.error(request, "error_password"), + }; + + match T::create_user(identifier, password).await { + Ok(user) => { + let id = user.id().integer()?; + let response = request.login(id); + + if let Some(ref redirect) = self.redirect { + Ok(response.redirect(redirect)) + } else { + Ok(response) + } + } + Err(err) => match err { + UserError::UserExists => return self.error(request, "error_user_exists"), + err => return Err(err.into()), + }, + } + } +} + +#[async_trait] +impl Controller for LoginController { + async fn handle(&self, request: &Request) -> Result { + PageController::handle(self, request).await + } +} diff --git a/rwf/src/model/error.rs b/rwf/src/model/error.rs index b2f692d..495e958 100644 --- a/rwf/src/model/error.rs +++ b/rwf/src/model/error.rs @@ -47,6 +47,9 @@ pub enum Error { "column \"{0}\" is missing from the row returned by the database,\ndid you forget to specify it in the query?" )] Column(String), + + #[error("value is not an integer")] + NotAnInteger, } impl Error { diff --git a/rwf/src/model/mod.rs b/rwf/src/model/mod.rs index b6d204e..7369d8b 100644 --- a/rwf/src/model/mod.rs +++ b/rwf/src/model/mod.rs @@ -28,6 +28,7 @@ pub mod prelude; pub mod row; pub mod select; pub mod update; +pub mod user; pub mod value; pub use column::{Column, Columns, ToColumn}; @@ -48,6 +49,7 @@ pub use pool::{get_connection, get_pool, start_transaction, Connection, Connecti pub use row::Row; pub use select::Select; pub use update::Update; +pub use user::UserModel; pub use value::{ToValue, Value}; /// Convert a PostgreSQL row to a Rust struct. Type conversions are handled by `tokio_postgres`. This only @@ -593,6 +595,8 @@ impl Query { } } + /// If a unique constraint on any of these columns is triggered, + /// the row will be automatically updated. pub fn unique_by(self, columns: &[impl ToColumn]) -> Self { match self { Query::Insert(insert) => Query::Insert(insert.unique_by(columns)), diff --git a/rwf/src/model/row.rs b/rwf/src/model/row.rs index c217205..e0fd342 100644 --- a/rwf/src/model/row.rs +++ b/rwf/src/model/row.rs @@ -49,10 +49,12 @@ impl std::ops::Deref for Row { } impl Row { + /// Create new row. pub fn new(row: tokio_postgres::Row) -> Self { Self { row: Arc::new(row) } } + /// Convert the row to a map of column names and values. pub fn values(self) -> Result, Error> { let mut result = HashMap::new(); for column in self.columns() { @@ -62,6 +64,12 @@ impl Row { Ok(result) } + + /// Consume the row and return the inner `tokio_postgres::Row` if there + /// are no more references to this row. + pub fn into_inner(self) -> Option { + Arc::into_inner(self.row) + } } #[cfg(test)] diff --git a/rwf/src/model/user.rs b/rwf/src/model/user.rs new file mode 100644 index 0000000..eff1d01 --- /dev/null +++ b/rwf/src/model/user.rs @@ -0,0 +1,97 @@ +use super::{Model, Pool, Row, ToValue, Value}; +use async_trait::async_trait; +use tokio::task::spawn_blocking; + +use thiserror::Error; + +/// User model error. +#[derive(Debug, Error)] +pub enum Error { + /// User already exists. + #[error("user already exists")] + UserExists, + + /// User doesn't exist. + #[error("user does not exist")] + UserDoesNotExist, + + /// Wrong password. + #[error("supplied password is incorrect")] + WrongPassword, + + /// Some database error. + #[error("{0}")] + DatabaseError(#[from] super::Error), +} + +/// Implement user creation and authentication for any database model +/// which has at least the identifier column and a password column. The identifier +/// column must have a unique index. +#[async_trait] +pub trait UserModel: Model + Sync { + fn identifier_column() -> &'static str; + fn password_column() -> &'static str; + + async fn create_user( + identifier: impl ToValue + Send, + password: impl ToString + Send, + ) -> Result { + let exists = Self::filter(Self::identifier_column(), identifier.to_value()) + .limit(1) + .fetch_optional(Pool::pool()) + .await?; + + if exists.is_some() { + return Err(Error::UserExists); + } + + let password = password.to_string(); + + let password_hash = spawn_blocking(move || crate::crypto::hash(password.as_bytes())) + .await + .unwrap() + .unwrap(); + + let user = Self::create(&[ + (Self::identifier_column(), identifier.to_value()), + (Self::password_column(), password_hash.to_value()), + ]) + .unique_by(&[Self::identifier_column()]) + .fetch(Pool::pool()) + .await?; + + Ok(user) + } + + async fn login_user( + identifier: impl ToValue + Send, + password: impl ToString + Send, + ) -> Result { + let user = Row::filter(Self::identifier_column(), identifier.to_value()) + .not(Self::password_column(), Value::Null) // Make sure column exists + .take_one() + .fetch_optional(Pool::pool()) + .await?; + + if let Some(user) = user { + let column: String = user.try_get(&Self::password_column()).unwrap(); + + let password = password.to_string(); + + let valid = + spawn_blocking(move || crate::crypto::hash_validate(column.as_bytes(), &password)) + .await + .unwrap() + .unwrap(); + + if valid { + let row = user.into_inner().unwrap(); + Ok(Self::from_row(row)?) + } else { + Err(Error::WrongPassword) + } + } else { + Err(Error::UserDoesNotExist) + } + } +} diff --git a/rwf/src/model/value.rs b/rwf/src/model/value.rs index 8d55c67..a1aa7ee 100644 --- a/rwf/src/model/value.rs +++ b/rwf/src/model/value.rs @@ -114,6 +114,24 @@ impl Value { _ => false, } } + + /// Convert the value to an integer if it is one. + pub fn integer(self) -> Result { + match self { + Value::Int(i) => Ok(i as i64), + Value::Integer(i) => Ok(i), + Value::BigInt(i) => Ok(i), + Value::SmallInt(i) => Ok(i as i64), + Value::Optional(value) => match *value { + Some(Value::Int(i)) => Ok(i as i64), + Some(Value::Integer(i)) => Ok(i), + Some(Value::BigInt(i)) => Ok(i), + Some(Value::SmallInt(i)) => Ok(i as i64), + _ => Err(Error::NotAnInteger), + }, + _ => Err(Error::NotAnInteger), + } + } } /// Convert a Rust type to a [`Value`]. Implementation for many common types