diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 3d0281d41788..e6d3e5356ae9 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -128,7 +128,8 @@ impl AuthenticationService for DbConn { ValidationErrors { errors }.into_field_error() })?; - if self.is_admin_initialized().await? { + let is_admin_initialized = self.is_admin_initialized().await?; + if is_admin_initialized { let err = Err("Invitation code is not valid".into()); let Some(invitation_code) = invitation_code else { return err; @@ -150,9 +151,9 @@ impl AuthenticationService for DbConn { let pwd_hash = password_hash(&input.password1)?; - self.create_user(input.email.clone(), pwd_hash, false) + let id = self.create_user(input.email.clone(), pwd_hash, !is_admin_initialized) .await?; - let user = self.get_user_by_email(&input.email).await?.unwrap(); + let user = self.get_user(id).await?.unwrap(); let access_token = generate_jwt(Claims::new(UserInfo::new( user.email.clone(), diff --git a/ee/tabby-webserver/src/service/db.rs b/ee/tabby-webserver/src/service/db.rs index a34cf41fc6f2..89747a9eb077 100644 --- a/ee/tabby-webserver/src/service/db.rs +++ b/ee/tabby-webserver/src/service/db.rs @@ -167,25 +167,35 @@ impl DbConn { email: String, password_encrypted: String, is_admin: bool, - ) -> Result<()> { + ) -> Result { let res = self .conn .call(move |c| { - c.execute( + let mut stmt = c.prepare( r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#, - params![email, password_encrypted, is_admin], - ) + )?; + let id = stmt.insert((email, password_encrypted, is_admin))?; + Ok(id) + }) + .await?; + + Ok(res as i32) + } + + pub async fn get_user(&self, id: i32) -> Result> { + let user = self + .conn + .call(move |c| { + c.query_row(User::select("id = ?").as_str(), params![id], User::from_row) + .optional() }) .await?; - if res != 1 { - return Err(anyhow::anyhow!("failed to create user")); - } - Ok(()) + Ok(user) } pub async fn get_user_by_email(&self, email: &str) -> Result> { - let email = email.to_string(); + let email = email.to_owned(); let user = self .conn .call(move |c| { @@ -292,6 +302,8 @@ impl DbConn { #[cfg(test)] mod tests { + use juniper::FieldResult; + use super::*; use crate::schema::auth::AuthenticationService; @@ -300,14 +312,13 @@ mod tests { DbConn::init_db(conn).await } - async fn create_admin_user(conn: &DbConn) -> String { + async fn create_admin_user(conn: &DbConn) -> i32 { let email = "test@example.com"; let passwd = "123456"; let is_admin = true; conn.create_user(email.to_string(), passwd.to_string(), is_admin) .await - .unwrap(); - email.to_owned() + .unwrap() } #[tokio::test] @@ -337,8 +348,8 @@ mod tests { async fn test_create_user() { let conn = new_in_memory().await.unwrap(); - let email = create_admin_user(&conn).await; - let user = conn.get_user_by_email(&email).await.unwrap().unwrap(); + let id = create_admin_user(&conn).await; + let user = conn.get_user(id).await.unwrap().unwrap(); assert_eq!(user.id, 1); } @@ -385,4 +396,62 @@ mod tests { let invitations = conn.list_invitations().await.unwrap(); assert!(invitations.is_empty()); } + + #[tokio::test] + async fn test_invitation_flow() { + let conn = new_in_memory().await.unwrap(); + + assert!(!conn.is_admin_initialized().await.unwrap()); + create_admin_user(&conn).await; + + let email = "user@user.com"; + let password = "12345678"; + + conn.create_invitation(email.to_owned()).await.unwrap(); + let invitation = &conn.list_invitations().await.unwrap()[0]; + + // Admin initialized, registeration requires a invitation code; + assert!( + conn.register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + None + ) + .await.is_err() + ); + + // Invalid invitation code won't work. + assert!(conn + .register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + Some("abc".to_owned()) + ) + .await + .is_err()); + + // Register success. + assert!(conn + .register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + Some(invitation.code.clone()) + ) + .await + .is_ok()); + + // Try register again with same email failed. + assert!(conn + .register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + Some(invitation.code.clone()) + ) + .await + .is_err()); + } }