Skip to content

Commit

Permalink
add tests, format files
Browse files Browse the repository at this point in the history
  • Loading branch information
darknight committed Dec 2, 2023
1 parent 0ebba25 commit b6b5394
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 20 deletions.
5 changes: 4 additions & 1 deletion ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ pub trait AuthenticationService: Send + Sync {
password: String,
) -> std::result::Result<TokenAuthResponse, TokenAuthError>;

async fn refresh_token(&self, refresh_token: String) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn refresh_token(
&self,
refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>;

Expand Down
8 changes: 5 additions & 3 deletions ee/tabby-webserver/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ use self::{
worker::WorkerService,
};
use crate::schema::{
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
auth::{
RefreshTokenError, RefreshTokenResponse, RegisterResponse, TokenAuthResponse,
VerifyTokenResponse,
},
worker::Worker,
};
use crate::schema::auth::{RefreshTokenError, RefreshTokenResponse};

pub trait ServiceLocator: Send + Sync {
fn auth(&self) -> &dyn AuthenticationService;
Expand Down Expand Up @@ -138,7 +140,7 @@ impl Mutation {

async fn refresh_token(
ctx: &Context,
refresh_token: String
refresh_token: String,
) -> Result<RefreshTokenResponse, RefreshTokenError> {
ctx.locator.auth().refresh_token(refresh_token).await
}
Expand Down
30 changes: 20 additions & 10 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ use async_trait::async_trait;
use validator::Validate;

use super::db::DbConn;
use crate::schema::auth::{generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, UserInfo, VerifyTokenResponse, RefreshTokenError};
use crate::schema::auth::generate_refresh_token;
use crate::schema::auth::{
generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Claims, Invitation,
RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError,
TokenAuthResponse, UserInfo, VerifyTokenResponse,
};

/// Input parameters for register mutation
/// `validate` attribute is used to validate the input parameters
Expand Down Expand Up @@ -143,10 +146,9 @@ impl AuthenticationService for DbConn {
.await?;
let user = self.get_user(id).await?.unwrap();

let (refresh_token, expires_in) = generate_refresh_token(
chrono::Utc::now().timestamp()
);
self.create_refresh_token(id, &refresh_token, expires_in).await?;
let (refresh_token, expires_in) = generate_refresh_token(chrono::Utc::now().timestamp());
self.create_refresh_token(id, &refresh_token, expires_in)
.await?;

let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(),
Expand Down Expand Up @@ -180,7 +182,8 @@ impl AuthenticationService for DbConn {
Some(refresh_token) => refresh_token.token,
None => {
let (token, expires_in) = generate_refresh_token(utc_ts);
self.create_refresh_token(user.id, &token, expires_in).await?;
self.create_refresh_token(user.id, &token, expires_in)
.await?;
token
}
};
Expand All @@ -196,7 +199,10 @@ impl AuthenticationService for DbConn {
Ok(resp)
}

async fn refresh_token(&self, token: String) -> std::result::Result<RefreshTokenResponse, RefreshTokenError> {
async fn refresh_token(
&self,
token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError> {
let Some(refresh_token) = self.get_refresh_token(&token).await? else {
return Err(RefreshTokenError::InvalidRefreshToken);
};
Expand Down Expand Up @@ -392,7 +398,10 @@ mod tests {
let conn = DbConn::new_in_memory().await.unwrap();
let reg_resp = register_admin_user(&conn).await;

let resp = conn.refresh_token(reg_resp.refresh_token.clone()).await.unwrap();
let resp = conn
.refresh_token(reg_resp.refresh_token.clone())
.await
.unwrap();
// refreshed access token should be valid
assert!(validate_jwt(&resp.access_token).is_ok());
// refresh token should be no change
Expand All @@ -402,7 +411,8 @@ mod tests {

let auth_resp = conn
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
.await.unwrap();
.await
.unwrap();
assert_eq!(resp.refresh_token, auth_resp.refresh_token);
}
}
47 changes: 41 additions & 6 deletions ee/tabby-webserver/src/service/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,8 @@ pub struct RefreshToken {
}

impl RefreshToken {

fn select(clause: &str) -> String {
r#"SELECT id, user_id, token, expires_in, created_at FROM refresh_tokens WHERE "#
.to_owned()
r#"SELECT id, user_id, token, expires_in, created_at FROM refresh_tokens WHERE "#.to_owned()
+ clause
}

Expand Down Expand Up @@ -390,22 +388,28 @@ impl DbConn {
RefreshToken::select("token = ?").as_str(),
params![token],
RefreshToken::from_row,
).optional()
)
.optional()
})
.await?;

Ok(token)
}

pub async fn get_user_unexpired_token(&self, user_id: i32, utc_ts: i64) -> Result<Option<RefreshToken>> {
pub async fn get_user_unexpired_token(
&self,
user_id: i32,
utc_ts: i64,
) -> Result<Option<RefreshToken>> {
let token = self
.conn
.call(move |c| {
c.query_row(
RefreshToken::select("user_id = ? AND expires_in > ?").as_str(),
params![user_id, utc_ts],
RefreshToken::from_row,
).optional()
)
.optional()
})
.await?;

Expand Down Expand Up @@ -502,4 +506,35 @@ mod tests {
let invitations = conn.list_invitations().await.unwrap();
assert!(invitations.is_empty());
}

#[tokio::test]
async fn test_create_refresh_token() {
let conn = DbConn::new_in_memory().await.unwrap();

conn.create_refresh_token(1, "test", 100).await.unwrap();

let token = conn.get_refresh_token("test").await.unwrap().unwrap();

assert_eq!(token.user_id, 1);
assert_eq!(token.token, "test");
assert_eq!(token.expires_in, 100);
}

#[tokio::test]
async fn test_get_user_unexpired_token() {
let conn = DbConn::new_in_memory().await.unwrap();

conn.create_refresh_token(1, "test", 100).await.unwrap();

// utc_ts is 10, expires_in is 100, so the token is not expired
let token = conn.get_user_unexpired_token(1, 10).await.unwrap().unwrap();

assert_eq!(token.user_id, 1);
assert_eq!(token.token, "test");
assert_eq!(token.expires_in, 100);

// utc_ts is 1000, expires_in is 100, so the token is expired
let token = conn.get_user_unexpired_token(1, 1000).await.unwrap();
assert!(token.is_none());
}
}

0 comments on commit b6b5394

Please sign in to comment.