Skip to content

Commit

Permalink
refactor: cleanup service level error as we're handling input validat…
Browse files Browse the repository at this point in the history
…ion at graphql level. (#1513)

* refactor out TokenAuthError

* refactor out RefreshTokenError

* refactor out RegisterError

* refactor out PasswordResetError

* fix test

* fix typos
  • Loading branch information
wsxiaoys authored Feb 22, 2024
1 parent 2708b3e commit 7ba5e72
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 322 deletions.
217 changes: 98 additions & 119 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::fmt::Debug;
use std::{borrow::Cow, fmt::Debug};

use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use jsonwebtoken as jwt;
use juniper::{
FieldError, GraphQLEnum, GraphQLInputObject, GraphQLObject, IntoFieldError, ScalarValue, ID,
};
use juniper::{GraphQLEnum, GraphQLInputObject, GraphQLObject, ID};
use juniper_axum::relay;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
Expand All @@ -15,9 +13,8 @@ use thiserror::Error;
use tokio::task::JoinHandle;
use tracing::{error, warn};
use uuid::Uuid;
use validator::{Validate, ValidationErrors};
use validator::Validate;

use super::from_validation_errors;
use crate::schema::Context;

lazy_static! {
Expand Down Expand Up @@ -80,33 +77,6 @@ impl RegisterResponse {
}
}

#[derive(Error, Debug)]
pub enum RegisterError {
#[error("Invalid input parameters")]
InvalidInput(#[from] ValidationErrors),

#[error("Invitation code is not valid")]
InvalidInvitationCode,

#[error("Email is already registered")]
DuplicateEmail,

#[error(transparent)]
Other(#[from] anyhow::Error),

#[error("Unknown error")]
Unknown,
}

impl<S: ScalarValue> IntoFieldError<S> for RegisterError {
fn into_field_error(self) -> FieldError<S> {
match self {
Self::InvalidInput(errors) => from_validation_errors(errors),
_ => self.into(),
}
}
}

#[derive(Debug, GraphQLObject)]
pub struct TokenAuthResponse {
access_token: String,
Expand All @@ -122,46 +92,67 @@ impl TokenAuthResponse {
}
}

#[derive(Error, Debug)]
pub enum TokenAuthError {
#[error("Invalid input parameters")]
InvalidInput(#[from] ValidationErrors),

#[error("User not found")]
UserNotFound,

#[error("Password is not valid")]
InvalidPassword,

#[error("User is disabled")]
UserDisabled,

#[error(transparent)]
Other(#[from] anyhow::Error),

#[error("Unknown error")]
Unknown,
}

#[derive(Error, Debug)]
pub enum PasswordResetError {
#[error("Invalid code")]
InvalidCode,
#[error("Invalid password")]
InvalidInput(#[from] ValidationErrors),
#[error(transparent)]
Other(#[from] anyhow::Error),
#[error("Unknown error")]
Unknown,
/// Input parameters for token_auth mutation
/// See `RegisterInput` for `validate` attribute usage
#[derive(Validate)]
pub struct TokenAuthInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
#[validate(length(
min = 8,
code = "password",
message = "Password must be at least 8 characters"
))]
#[validate(length(
max = 20,
code = "password",
message = "Password must be at most 20 characters"
))]
pub password: String,
}

impl<S: ScalarValue> IntoFieldError<S> for PasswordResetError {
fn into_field_error(self) -> FieldError<S> {
match self {
Self::InvalidInput(errors) => from_validation_errors(errors),
_ => self.into(),
}
}
/// Input parameters for register mutation
/// `validate` attribute is used to validate the input parameters
/// - `code` argument specifies which parameter causes the failure
/// - `message` argument provides client friendly error message
///
#[derive(Validate)]
pub struct RegisterInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
#[validate(length(
min = 8,
code = "password1",
message = "Password must be at least 8 characters"
))]
#[validate(length(
max = 20,
code = "password1",
message = "Password must be at most 20 characters"
))]
#[validate(custom = "validate_password")]
pub password1: String,
#[validate(must_match(
code = "password2",
message = "Passwords do not match",
other = "password1"
))]
#[validate(length(
max = 20,
code = "password2",
message = "Password must be at most 20 characters"
))]
pub password2: String,
}

#[derive(Default, Serialize)]
Expand Down Expand Up @@ -191,42 +182,6 @@ pub enum OAuthError {
Unknown,
}

impl<S: ScalarValue> IntoFieldError<S> for TokenAuthError {
fn into_field_error(self) -> FieldError<S> {
match self {
Self::InvalidInput(errors) => from_validation_errors(errors),
_ => self.into(),
}
}
}

#[derive(Error, Debug)]
pub enum RefreshTokenError {
#[error("Invalid refresh token")]
InvalidRefreshToken,

#[error("Expired refresh token")]
ExpiredRefreshToken,

#[error("User not found")]
UserNotFound,

#[error("User is disabled")]
UserDisabled,

#[error(transparent)]
Other(#[from] anyhow::Error),

#[error("Unknown error")]
Unknown,
}

impl<S: ScalarValue> IntoFieldError<S> for RefreshTokenError {
fn into_field_error(self) -> FieldError<S> {
self.into()
}
}

#[derive(Debug, GraphQLObject)]
pub struct RefreshTokenResponse {
pub access_token: String,
Expand Down Expand Up @@ -412,21 +367,13 @@ pub trait AuthenticationService: Send + Sync {
&self,
email: String,
password1: String,
password2: String,
invitation_code: Option<String>,
) -> std::result::Result<RegisterResponse, RegisterError>;
) -> Result<RegisterResponse>;
async fn allow_self_signup(&self) -> Result<bool>;

async fn token_auth(
&self,
email: String,
password: String,
) -> std::result::Result<TokenAuthResponse, TokenAuthError>;
async fn token_auth(&self, email: String, password: String) -> Result<TokenAuthResponse>;

async fn refresh_token(
&self,
refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>;
async fn delete_expired_token(&self) -> Result<()>;
async fn delete_expired_password_resets(&self) -> Result<()>;
async fn verify_access_token(&self, access_token: &str) -> Result<JWTPayload>;
Expand All @@ -438,7 +385,7 @@ pub trait AuthenticationService: Send + Sync {
async fn delete_invitation(&self, id: &ID) -> Result<ID>;

async fn reset_user_auth_token(&self, email: &str) -> Result<()>;
async fn password_reset(&self, code: &str, password: &str) -> Result<(), PasswordResetError>;
async fn password_reset(&self, code: &str, password: &str) -> Result<()>;
async fn request_password_reset_email(&self, email: String) -> Result<Option<JoinHandle<()>>>;

async fn list_users(
Expand Down Expand Up @@ -477,6 +424,38 @@ pub trait AuthenticationService: Send + Sync {
async fn update_user_role(&self, id: &ID, is_admin: bool) -> Result<()>;
}

fn validate_password(value: &str) -> Result<(), validator::ValidationError> {
let make_validation_error = |message: &'static str| {
let mut err = validator::ValidationError::new("password1");
err.message = Some(Cow::Borrowed(message));
Err(err)
};

let contains_lowercase = value.chars().any(|x| x.is_ascii_lowercase());
if !contains_lowercase {
return make_validation_error("Password should contain at least one lowercase character");
}

let contains_uppercase = value.chars().any(|x| x.is_ascii_uppercase());
if !contains_uppercase {
return make_validation_error("Password should contain at least one uppercase character");
}

let contains_digit = value.chars().any(|x| x.is_ascii_digit());
if !contains_digit {
return make_validation_error("Password should contain at least one numeric character");
}

let contains_special_char = value.chars().any(|x| x.is_ascii_punctuation());
if !contains_special_char {
return make_validation_error(
"Password should contain at least one special character, e.g @#$%^&{}",
);
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
37 changes: 24 additions & 13 deletions ee/tabby-webserver/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ pub mod worker;
use std::sync::Arc;

use auth::{
validate_jwt, AuthenticationService, Invitation, RefreshTokenError, RefreshTokenResponse,
RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, User,
validate_jwt, AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse,
TokenAuthResponse, User,
};
use job::{JobRun, JobService};
use juniper::{
Expand Down Expand Up @@ -383,31 +383,42 @@ impl Mutation {
password1: String,
password2: String,
invitation_code: Option<String>,
) -> Result<RegisterResponse, RegisterError> {
ctx.locator
) -> Result<RegisterResponse> {
let input = auth::RegisterInput {
email,
password1,
password2,
};
input.validate()?;

Ok(ctx
.locator
.auth()
.register(email, password1, password2, invitation_code)
.await
.register(input.email, input.password1, invitation_code)
.await?)
}

async fn token_auth(
ctx: &Context,
email: String,
password: String,
) -> Result<TokenAuthResponse, TokenAuthError> {
ctx.locator.auth().token_auth(email, password).await
) -> Result<TokenAuthResponse> {
let input = auth::TokenAuthInput { email, password };
input.validate()?;
Ok(ctx
.locator
.auth()
.token_auth(input.email, input.password)
.await?)
}

async fn verify_token(ctx: &Context, token: String) -> Result<bool> {
ctx.locator.auth().verify_access_token(&token).await?;
Ok(true)
}

async fn refresh_token(
ctx: &Context,
refresh_token: String,
) -> Result<RefreshTokenResponse, RefreshTokenError> {
ctx.locator.auth().refresh_token(refresh_token).await
async fn refresh_token(ctx: &Context, refresh_token: String) -> Result<RefreshTokenResponse> {
Ok(ctx.locator.auth().refresh_token(refresh_token).await?)
}

async fn create_invitation(ctx: &Context, email: String) -> Result<ID> {
Expand Down
Loading

0 comments on commit 7ba5e72

Please sign in to comment.