Skip to content

Commit

Permalink
feat: impl refresh token api
Browse files Browse the repository at this point in the history
  • Loading branch information
darknight committed Dec 3, 2023
1 parent f3a3108 commit bf0c828
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 30 deletions.
40 changes: 39 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ee/tabby-webserver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ tabby-common = { path = "../../crates/tabby-common" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
tokio = { workspace = true, features = ["fs"] }
tokio-cron-scheduler = "0.9.4"
tokio-rusqlite = "0.4.0"
tokio-tungstenite = "0.20.1"
tower = { version = "0.4", features = ["util"] }
Expand Down
7 changes: 7 additions & 0 deletions ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Mutation {
register(email: String!, password1: String!, password2: String!, invitationCode: String): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
refreshToken(refreshToken: String!): RefreshTokenResponse!
createInvitation(email: String!): Int!
deleteInvitation(id: Int!): Int!
}
Expand Down Expand Up @@ -62,6 +63,12 @@ type TokenAuthResponse {
refreshToken: String!
}

type RefreshTokenResponse {
accessToken: String!
refreshToken: String!
refreshExpiresAt: Float!
}

schema {
query: Query
mutation: Mutation
Expand Down
65 changes: 58 additions & 7 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
use validator::ValidationErrors;

use super::from_validation_errors;
Expand All @@ -19,6 +20,7 @@ lazy_static! {
jwt_token_secret().as_bytes()
);
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
static ref JWT_REFRESH_PERIOD: i64 = 7 * 24 * 60 * 60; // 7 days
}

pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
Expand All @@ -37,10 +39,15 @@ fn jwt_token_secret() -> String {
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}

pub fn generate_refresh_token(utc_ts: i64) -> (String, i64) {
let token = Uuid::new_v4().to_string().replace('-', "");
(token, utc_ts + *JWT_REFRESH_PERIOD)
}

#[derive(Debug, GraphQLObject)]
pub struct RegisterResponse {
access_token: String,
refresh_token: String,
pub refresh_token: String,
}

impl RegisterResponse {
Expand Down Expand Up @@ -82,7 +89,7 @@ impl<S: ScalarValue> IntoFieldError<S> for RegisterError {
#[derive(Debug, GraphQLObject)]
pub struct TokenAuthResponse {
access_token: String,
refresh_token: String,
pub refresh_token: String,
}

impl TokenAuthResponse {
Expand Down Expand Up @@ -127,11 +134,45 @@ impl<S: ScalarValue> IntoFieldError<S> for TokenAuthError {
}
}

#[derive(Debug, Default, GraphQLObject)]
#[derive(Error, Debug)]
pub enum RefreshTokenError {
#[error("Invalid refresh token")]
InvalidRefreshToken,

#[error("Expired refresh token")]
ExpiredRefreshToken,

#[error("User not found")]
UserNotFound,

#[error(transparent)]
Other(#[from] anyhow::Error),

#[error("Unknown error")]
Unknown,
}

impl<S: ScalarValue> IntoFieldError<S> for RefreshTokenError {
fn into_field_error(self) -> FieldError<S> {
self.into()
}
}

#[derive(Debug, GraphQLObject)]
pub struct RefreshTokenResponse {
access_token: String,
refresh_token: String,
refresh_expires_in: i32,
pub access_token: String,
pub refresh_token: String,
pub refresh_expires_at: f64,
}

impl RefreshTokenResponse {
pub fn new(access_token: String, refresh_token: String, refresh_expires_at: f64) -> Self {
Self {
access_token,
refresh_token,
refresh_expires_at,
}
}
}

#[derive(Debug, GraphQLObject)]
Expand Down Expand Up @@ -215,7 +256,10 @@ pub trait AuthenticationService: Send + Sync {
password: String,
) -> std::result::Result<TokenAuthResponse, TokenAuthError>;

async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>;
async fn refresh_token(
&self,
refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>;

Expand Down Expand Up @@ -245,4 +289,11 @@ mod tests {
&UserInfo::new("test".to_string(), false)
);
}

#[test]
fn test_generate_refresh_token() {
let (token, exp) = generate_refresh_token(100);
assert_eq!(token.len(), 32);
assert_eq!(exp, 100 + *JWT_REFRESH_PERIOD);
}
}
16 changes: 13 additions & 3 deletions ee/tabby-webserver/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::sync::Arc;

use auth::AuthenticationService;
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError,
Object, RootNode, ScalarValue, Value,
graphql_object, graphql_value, EmptySubscription, FieldError, IntoFieldError, Object, RootNode,
ScalarValue, Value,
};
use juniper_axum::FromAuth;
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
Expand All @@ -17,7 +17,10 @@ use self::{
worker::WorkerService,
};
use crate::schema::{
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
auth::{
RefreshTokenError, RefreshTokenResponse, RegisterResponse, TokenAuthResponse,
VerifyTokenResponse,
},
worker::Worker,
};

Expand Down Expand Up @@ -135,6 +138,13 @@ impl Mutation {
Ok(ctx.locator.auth().verify_token(token).await?)
}

async fn refresh_token(
ctx: &Context,
refresh_token: String,
) -> Result<RefreshTokenResponse, RefreshTokenError> {
ctx.locator.auth().refresh_token(refresh_token).await
}

async fn create_invitation(ctx: &Context, email: String) -> Result<i32> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
Expand Down
Loading

0 comments on commit bf0c828

Please sign in to comment.