Skip to content

Commit

Permalink
feat: implement register api check with invitation code (#934)
Browse files Browse the repository at this point in the history
* feat(webserver): implement is_admin_initialized graphql api

* refactor

* add unit test

* [autofix.ci] apply automated fixes

* renaming

* temp invitations

* update

* update

* implement register check

* test

* update invitations

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Dec 1, 2023
1 parent 88e5187 commit 19d773e
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 10 deletions.
13 changes: 12 additions & 1 deletion ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ enum WorkerKind {

type Mutation {
resetRegistrationToken: String!
register(email: String!, password1: String!, password2: String!): RegisterResponse!
register(email: String!, password1: String!, password2: String!, invitationCode: String): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
createInvitation(email: String!): Int!
deleteInvitation(id: Int!): Int!
}

type UserInfo {
Expand All @@ -33,6 +35,15 @@ type Claims {
type Query {
workers: [Worker!]!
registrationToken: String!
isAdminInitialized: Boolean!
invitations: [Invitation!]!
}

type Invitation {
id: Int!
email: String!
code: String!
createdAt: String!
}

type Worker {
Expand Down
15 changes: 15 additions & 0 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Debug;

use anyhow::Result;
use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldResult, GraphQLObject};
Expand Down Expand Up @@ -125,18 +126,32 @@ impl Claims {
}
}

#[derive(Debug, Default, Serialize, Deserialize, GraphQLObject)]
pub struct Invitation {
pub id: i32,
pub email: String,
pub code: String,

pub created_at: String,
}

#[async_trait]
pub trait AuthenticationService: Send + Sync {
async fn register(
&self,
email: String,
password1: String,
password2: String,
invitation_code: Option<String>,
) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> FieldResult<bool>;

async fn create_invitation(&self, email: String) -> Result<i32>;
async fn list_invitations(&self) -> Result<Vec<Invitation>>;
async fn delete_invitation(&self, id: i32) -> Result<i32>;
}

#[cfg(test)]
Expand Down
42 changes: 38 additions & 4 deletions ee/tabby-webserver/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use juniper_axum::FromAuth;
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use validator::ValidationError;

use self::{auth::validate_jwt, worker::WorkerService};
use self::{
auth::{validate_jwt, Invitation},
worker::WorkerService,
};
use crate::schema::{
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
worker::Worker,
Expand Down Expand Up @@ -57,6 +60,15 @@ impl Query {
async fn is_admin_initialized(ctx: &Context) -> FieldResult<bool> {
ctx.locator.auth().is_admin_initialized().await
}

async fn invitations(ctx: &Context) -> FieldResult<Vec<Invitation>> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
return Ok(ctx.locator.auth().list_invitations().await?);
}
}
Err(unauthorized("Only admin is able to query invitations"))
}
}

#[derive(Default)]
Expand All @@ -71,9 +83,8 @@ impl Mutation {
return Ok(reg_token);
}
}
Err(FieldError::new(
Err(unauthorized(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}

Expand All @@ -82,10 +93,11 @@ impl Mutation {
email: String,
password1: String,
password2: String,
invitation_code: Option<String>,
) -> FieldResult<RegisterResponse> {
ctx.locator
.auth()
.register(email, password1, password2)
.register(email, password1, password2, invitation_code)
.await
}

Expand All @@ -100,6 +112,24 @@ impl Mutation {
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.locator.auth().verify_token(token).await
}

async fn create_invitation(ctx: &Context, email: String) -> FieldResult<i32> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
return Ok(ctx.locator.auth().create_invitation(email).await?);
}
}
Err(unauthorized("Only admin is able to create invitation"))
}

async fn delete_invitation(ctx: &Context, id: i32) -> FieldResult<i32> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
return Ok(ctx.locator.auth().delete_invitation(id).await?);
}
}
Err(unauthorized("Only admin is able to delete invitation"))
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -135,3 +165,7 @@ pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>
pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
}

fn unauthorized(msg: &str) -> FieldError {
FieldError::new(msg, graphql_value!("Unauthorized"))
}
33 changes: 31 additions & 2 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use anyhow::Result;
use argon2::{
password_hash,
password_hash::{rand_core::OsRng, SaltString},
Expand All @@ -10,8 +11,8 @@ use validator::Validate;
use super::db::DbConn;
use crate::schema::{
auth::{
generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse,
RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation,
RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
},
ValidationErrors,
};
Expand Down Expand Up @@ -109,6 +110,7 @@ impl AuthenticationService for DbConn {
email: String,
password1: String,
password2: String,
invitation_code: Option<String>,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
Expand All @@ -126,6 +128,21 @@ impl AuthenticationService for DbConn {
ValidationErrors { errors }.into_field_error()
})?;

if self.is_admin_initialized().await? {
let err = Err("Invitation code is not valid".into());
let Some(invitation_code) = invitation_code else {
return err;
};

let Some(invitation) = self.get_invitation_by_code(&invitation_code).await? else {
return err;
};

if invitation.email != input.email {
return err;
}
};

// check if email exists
if self.get_user_by_email(&input.email).await?.is_some() {
return Err("Email already exists".into());
Expand Down Expand Up @@ -193,6 +210,18 @@ impl AuthenticationService for DbConn {
let admin = self.list_admin_users().await?;
Ok(!admin.is_empty())
}

async fn create_invitation(&self, email: String) -> Result<i32> {
self.create_invitation(email).await
}

async fn list_invitations(&self) -> Result<Vec<Invitation>> {
self.list_invitations().await
}

async fn delete_invitation(&self, id: i32) -> Result<i32> {
self.delete_invitation(id).await
}
}

fn password_hash(raw: &str) -> password_hash::Result<String> {
Expand Down
121 changes: 118 additions & 3 deletions ee/tabby-webserver/src/service/db.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use std::{path::PathBuf, sync::Arc};

use anyhow::Result;
use anyhow::{anyhow, Result};
use lazy_static::lazy_static;
use rusqlite::{params, OptionalExtension, Row};
use rusqlite_migration::{AsyncMigrations, M};
use tabby_common::path::tabby_root;
use tokio_rusqlite::Connection;
use uuid::Uuid;

use crate::schema::auth::Invitation;

lazy_static! {
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![
M::up(
r#"
CREATE TABLE IF NOT EXISTS registration_token (
CREATE TABLE registration_token (
id INTEGER PRIMARY KEY AUTOINCREMENT,
token VARCHAR(255) NOT NULL,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
Expand All @@ -22,7 +25,7 @@ lazy_static! {
),
M::up(
r#"
CREATE TABLE IF NOT EXISTS users (
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email VARCHAR(150) NOT NULL COLLATE NOCASE,
password_encrypted VARCHAR(128) NOT NULL,
Expand All @@ -33,6 +36,18 @@ lazy_static! {
);
"#
),
M::up(
r#"
CREATE TABLE invitations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email VARCHAR(150) NOT NULL COLLATE NOCASE,
code VARCHAR(36) NOT NULL,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_email` UNIQUE (`email`)
CONSTRAINT `idx_code` UNIQUE (`code`)
);
"#
),
]);
}

Expand Down Expand Up @@ -200,6 +215,81 @@ impl DbConn {
}
}

impl Invitation {
fn from_row(row: &Row<'_>) -> std::result::Result<Self, rusqlite::Error> {
Ok(Self {
id: row.get(0)?,
email: row.get(1)?,
code: row.get(2)?,
created_at: row.get(3)?,
})
}
}

/// db read/write operations for `invitations` table
impl DbConn {
pub async fn list_invitations(&self) -> Result<Vec<Invitation>> {
let invitations = self
.conn
.call(move |c| {
let mut stmt =
c.prepare(r#"SELECT id, email, code, created_at FROM invitations"#)?;
let iter = stmt.query_map([], Invitation::from_row)?;
Ok(iter.filter_map(|x| x.ok()).collect::<Vec<_>>())
})
.await?;

Ok(invitations)
}

pub async fn get_invitation_by_code(&self, code: &str) -> Result<Option<Invitation>> {
let code = code.to_owned();
let token = self
.conn
.call(|conn| {
conn.query_row(
r#"SELECT id, email, code, created_at FROM invitations WHERE code = ?"#,
[code],
Invitation::from_row,
)
.optional()
})
.await?;

Ok(token)
}

pub async fn create_invitation(&self, email: String) -> Result<i32> {
let code = Uuid::new_v4().to_string();
let res = self
.conn
.call(move |c| {
let mut stmt =
c.prepare(r#"INSERT INTO invitations (email, code) VALUES (?, ?)"#)?;
let rowid = stmt.insert((email, code))?;
Ok(rowid)
})
.await?;
if res != 1 {
return Err(anyhow!("failed to create invitation"));
}

Ok(res as i32)
}

pub async fn delete_invitation(&self, id: i32) -> Result<i32> {
let res = self
.conn
.call(move |c| c.execute(r#"DELETE FROM invitations WHERE id = ?"#, params![id]))
.await?;
if res != 1 {
return Err(anyhow!("failed to delete invitation"));
}

Ok(id)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -270,4 +360,29 @@ mod tests {
create_admin_user(&conn).await;
assert!(conn.is_admin_initialized().await.unwrap());
}

#[tokio::test]
async fn test_invitations() {
let conn = new_in_memory().await.unwrap();

let email = "[email protected]".to_owned();
conn.create_invitation(email).await.unwrap();

let invitations = conn.list_invitations().await.unwrap();
assert_eq!(1, invitations.len());

assert!(Uuid::parse_str(&invitations[0].code).is_ok());
let invitation = conn
.get_invitation_by_code(&invitations[0].code)
.await
.ok()
.flatten()
.unwrap();
assert_eq!(invitation.id, invitations[0].id);

conn.delete_invitation(invitations[0].id).await.unwrap();

let invitations = conn.list_invitations().await.unwrap();
assert!(invitations.is_empty());
}
}

0 comments on commit 19d773e

Please sign in to comment.