Skip to content

Commit

Permalink
refactor: cleanup the dependency chain - ServerContext should be the …
Browse files Browse the repository at this point in the history
…only thing being public in server mod
  • Loading branch information
wsxiaoys committed Dec 1, 2023
1 parent 1a9cbdc commit fb34e9c
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 118 deletions.
2 changes: 1 addition & 1 deletion crates/juniper-axum/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ where
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),
Some(("Bearer", contents)) => Ok(Self(Some(contents.to_owned()))),
_ => Ok(Self(None)),
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/juniper-axum/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub mod extract;
pub mod response;

use std::{future};
use std::future;

use axum::{
extract::{Extension, State},
Expand Down
95 changes: 64 additions & 31 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,35 @@
use std::fmt::Debug;

use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value};
use juniper::{FieldResult, GraphQLObject};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use validator::ValidationError;

use crate::server::auth::JWT_DEFAULT_EXP;

#[derive(Debug)]
pub struct ValidationErrors {
pub errors: Vec<ValidationError>,
}

impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
fn into_field_error(self) -> FieldError<S> {
let errors = self
.errors
.into_iter()
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));

FieldError::new("Invalid input parameters", ext.into())
}

lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
}

pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
let header = jwt::Header::default();
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
Ok(token)
}

pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}

fn jwt_token_secret() -> String {
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}

#[derive(Debug, GraphQLObject)]
Expand Down Expand Up @@ -127,3 +124,39 @@ impl Claims {
&self.user
}
}

#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();

assert!(!token.is_empty())
}

#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}

#[async_trait]
pub trait AuthenticationService {
async fn register(
&self,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ pub mod auth;

use std::sync::Arc;


use auth::AuthenticationService;
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError,
Object, RootNode, ScalarValue, Value,
};
use juniper_axum::FromAuth;
use validator::ValidationError;

use self::auth::validate_jwt;
use crate::{
api::Worker,
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
server::{
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
ServerContext,
},
server::ServerContext,
};

pub struct Context {
Expand Down Expand Up @@ -71,28 +71,53 @@ impl Mutation {
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
ctx.server.auth().register(input).await
ctx.server
.auth()
.register(email, password1, password2)
.await
}

async fn token_auth(
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.server.auth().token_auth(input).await
ctx.server.auth().token_auth(email, password).await
}

async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}

#[derive(Debug)]
pub struct ValidationErrors {
pub errors: Vec<ValidationError>,
}

impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
fn into_field_error(self) -> FieldError<S> {
let errors = self
.errors
.into_iter()
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));

FieldError::new("Invalid input parameters", ext.into())
}
}

pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;

pub fn create_schema() -> Schema {
Expand Down
95 changes: 26 additions & 69 deletions ee/tabby-webserver/src/server/auth.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,37 @@
use std::env;

use argon2::{
password_hash,
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher, PasswordVerifier,
};
use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldResult, IntoFieldError};
use lazy_static::lazy_static;
use validator::Validate;

use crate::{
db::DbConn,
schema::auth::{
Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
ValidationErrors, VerifyTokenResponse,
schema::{
auth::{
generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse,
RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
},
ValidationErrors,
},
};

lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
}

/// Input parameters for register mutation
/// `validate` attribute is used to validate the input parameters
/// - `code` argument specifies which parameter causes the failure
/// - `message` argument provides client friendly error message
///
#[derive(Validate)]
pub struct RegisterInput {
struct RegisterInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
email: String,
#[validate(length(
min = 8,
code = "password1",
Expand All @@ -58,7 +47,7 @@ pub struct RegisterInput {
message = "Passwords do not match",
other = "password2"
))]
pub password1: String,
password1: String,
#[validate(length(
min = 8,
code = "password2",
Expand All @@ -69,7 +58,7 @@ pub struct RegisterInput {
code = "password2",
message = "Password must be at most 20 characters"
))]
pub password2: String,
password2: String,
}

impl std::fmt::Debug for RegisterInput {
Expand All @@ -85,14 +74,14 @@ impl std::fmt::Debug for RegisterInput {
/// Input parameters for token_auth mutation
/// See `RegisterInput` for `validate` attribute usage
#[derive(Validate)]
pub struct TokenAuthInput {
struct TokenAuthInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
email: String,
#[validate(length(
min = 8,
code = "password",
Expand All @@ -103,7 +92,7 @@ pub struct TokenAuthInput {
code = "password",
message = "Password must be at most 20 characters"
))]
pub password: String,
password: String,
}

impl std::fmt::Debug for TokenAuthInput {
Expand All @@ -115,17 +104,19 @@ impl std::fmt::Debug for TokenAuthInput {
}
}

#[async_trait]
pub trait AuthenticationService {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}

#[async_trait]
impl AuthenticationService for DbConn {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse> {
async fn register(
&self,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
input.validate().map_err(|err| {
let errors = err
.field_errors()
Expand Down Expand Up @@ -157,7 +148,8 @@ impl AuthenticationService for DbConn {
Ok(resp)
}

async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse> {
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
input.validate().map_err(|err| {
let errors = err
.field_errors()
Expand Down Expand Up @@ -217,22 +209,6 @@ fn password_verify(raw: &str, hash: &str) -> bool {
}
}

fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
let header = jwt::Header::default();
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
Ok(token)
}

pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}

fn jwt_token_secret() -> String {
env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -254,23 +230,4 @@ mod tests {
assert!(password_verify(raw, &hash));
assert!(!password_verify(raw, "invalid hash"));
}

#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();

assert!(!token.is_empty())
}

#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}
Loading

0 comments on commit fb34e9c

Please sign in to comment.