diff --git a/ee/tabby-db/src/github_oauth_credential.rs b/ee/tabby-db/src/github_oauth_credential.rs index fe1e819cd6c6..7ff17cf40cbd 100644 --- a/ee/tabby-db/src/github_oauth_credential.rs +++ b/ee/tabby-db/src/github_oauth_credential.rs @@ -31,31 +31,52 @@ impl DbConn { pub async fn update_github_oauth_credential( &self, client_id: &str, - client_secret: &str, + client_secret: Option<&str>, active: bool, ) -> Result<()> { let client_id = client_id.to_string(); - let client_secret = client_secret.to_string(); - - self.conn - .call(move |c| { - let mut stmt = c.prepare( - r#"INSERT INTO github_oauth_credential (id, client_id, client_secret) - VALUES (:id, :cid, :secret) ON CONFLICT(id) DO UPDATE - SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now') - WHERE id = :id"#, - )?; - stmt.insert(named_params! { + if let Some(client_secret) = client_secret { + let client_secret = client_secret.to_string(); + let sql = r#"INSERT INTO github_oauth_credential (id, client_id, client_secret, active) + VALUES (:id, :cid, :secret, :active) ON CONFLICT(id) DO UPDATE + SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now') + WHERE id = :id"#; + self.conn + .call(move |c| { + let mut stmt = c.prepare(sql)?; + stmt.insert(named_params! { ":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID, ":cid": client_id, ":secret": client_secret, ":active": active, - })?; - Ok(()) - }) - .await?; - - Ok(()) + })?; + Ok(()) + }) + .await?; + Ok(()) + } else { + let sql = r#" + UPDATE github_oauth_credential SET client_id = :cid, active = :active, updated_at = datetime('now') + WHERE id = :id"#; + let rows = self + .conn + .call(move |c| { + let mut stmt = c.prepare(sql)?; + let rows = stmt.execute(named_params! { + ":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID, + ":cid": client_id, + ":active": active, + })?; + Ok(rows) + }) + .await?; + if rows != 1 { + return Err(anyhow::anyhow!( + "failed to update: github credential not found" + )); + } + Ok(()) + } } pub async fn read_github_oauth_credential(&self) -> Result> { @@ -82,9 +103,16 @@ mod tests { #[tokio::test] async fn test_update_github_oauth_credential() { - // test insert let conn = DbConn::new_in_memory().await.unwrap(); - conn.update_github_oauth_credential("client_id", "client_secret", false) + + // test update failure when no record exists + let res = conn + .update_github_oauth_credential("client_id", None, false) + .await; + assert!(res.is_err()); + + // test insert + conn.update_github_oauth_credential("client_id", Some("client_secret"), true) .await .unwrap(); let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); @@ -93,12 +121,21 @@ mod tests { assert!(res.active); // test update - conn.update_github_oauth_credential("client_id", "client_secret_2", false) + conn.update_github_oauth_credential("client_id", Some("client_secret_2"), false) .await .unwrap(); let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); assert_eq!(res.client_id, "client_id"); assert_eq!(res.client_secret, "client_secret_2"); assert!(!res.active); + + // test update without client_secret + conn.update_github_oauth_credential("client_id_2", None, true) + .await + .unwrap(); + let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); + assert_eq!(res.client_id, "client_id_2"); + assert_eq!(res.client_secret, "client_secret_2"); + assert!(res.active); } } diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 1bd823d3f592..ef284b2b4740 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -13,6 +13,7 @@ type Mutation { createInvitation(email: String!): ID! deleteInvitation(id: Int!): Int! @deprecated deleteInvitationNext(id: ID!): ID! + updateOauthCredential(provider: OAuthProvider!, clientId: String!, clientSecret: String, active: Boolean!): Boolean! } "DateTime" @@ -49,6 +50,7 @@ type Query { usersNext(after: String, before: String, first: Int, last: Int): UserConnection! invitationsNext(after: String, before: String, first: Int, last: Int): InvitationConnection! jobRuns(after: String, before: String, first: Int, last: Int): JobRunConnection! + oauthCredential(provider: OAuthProvider!): OAuthCredential } type UserEdge { @@ -84,6 +86,14 @@ type UserConnection { pageInfo: PageInfo! } +type OAuthCredential { + provider: OAuthProvider! + clientId: String! + active: Boolean! + createdAt: DateTimeUtc! + updatedAt: DateTimeUtc! +} + type VerifyTokenResponse { claims: JWTPayload! } @@ -103,6 +113,11 @@ type User { createdAt: DateTimeUtc! } +type TokenAuthResponse { + accessToken: String! + refreshToken: String! +} + type Worker { kind: WorkerKind! name: String! @@ -119,9 +134,8 @@ type InvitationEdge { cursor: String! } -type TokenAuthResponse { - accessToken: String! - refreshToken: String! +enum OAuthProvider { + GITHUB } type PageInfo { diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 5cc22085970e..ca1939920706 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -4,7 +4,7 @@ use anyhow::Result; use async_trait::async_trait; use chrono::{DateTime, Utc}; use jsonwebtoken as jwt; -use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue, ID}; +use juniper::{FieldError, GraphQLEnum, GraphQLObject, IntoFieldError, ScalarValue, ID}; use juniper_axum::relay; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; @@ -334,6 +334,21 @@ impl relay::NodeType for InvitationNext { } } +#[derive(GraphQLEnum, Clone)] +#[non_exhaustive] +pub enum OAuthProvider { + Github, +} + +#[derive(GraphQLObject)] +pub struct OAuthCredential { + pub provider: OAuthProvider, + pub client_id: String, + pub active: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} + #[async_trait] pub trait AuthenticationService: Send + Sync { async fn register( @@ -384,6 +399,19 @@ pub trait AuthenticationService: Send + Sync { code: String, client: Arc, ) -> std::result::Result; + + async fn read_oauth_credential( + &self, + provider: OAuthProvider, + ) -> Result>; + + async fn update_oauth_credential( + &self, + provider: OAuthProvider, + client_id: String, + client_secret: Option, + active: bool, + ) -> Result<()>; } #[cfg(test)] diff --git a/ee/tabby-webserver/src/schema/dao.rs b/ee/tabby-webserver/src/schema/dao.rs index 2645518d889b..d3dfa098fec4 100644 --- a/ee/tabby-webserver/src/schema/dao.rs +++ b/ee/tabby-webserver/src/schema/dao.rs @@ -1,6 +1,10 @@ -use tabby_db::{InvitationDAO, JobRunDAO, UserDAO}; +use tabby_db::{GithubOAuthCredentialDAO, InvitationDAO, JobRunDAO, UserDAO}; -use crate::schema::{auth, job}; +use crate::schema::{ + auth, + auth::{OAuthCredential, OAuthProvider}, + job, +}; impl From for auth::InvitationNext { fn from(val: InvitationDAO) -> Self { @@ -38,3 +42,15 @@ impl From for auth::User { } } } + +impl From for OAuthCredential { + fn from(val: GithubOAuthCredentialDAO) -> Self { + OAuthCredential { + provider: OAuthProvider::Github, + client_id: val.client_id, + active: val.active, + created_at: val.created_at, + updated_at: val.updated_at, + } + } +} diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 31170b327fbc..4ab2bbf90367 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -21,6 +21,8 @@ use tracing::error; use validator::ValidationErrors; use worker::{Worker, WorkerService}; +use crate::schema::auth::{OAuthCredential, OAuthProvider}; + pub trait ServiceLocator: Send + Sync { fn auth(&self) -> Arc; fn worker(&self) -> Arc; @@ -240,6 +242,20 @@ impl Query { "Only admin is able to query job runs", ))) } + + async fn oauth_credential( + ctx: &Context, + provider: OAuthProvider, + ) -> Result> { + if let Some(claims) = &ctx.claims { + if claims.is_admin { + return Ok(ctx.locator.auth().read_oauth_credential(provider).await?); + } + } + Err(CoreError::Unauthorized( + "Only admin is able to query oauth credential", + )) + } } #[derive(Default)] @@ -341,6 +357,27 @@ impl Mutation { "Only admin is able to delete invitation", )) } + + async fn update_oauth_credential( + ctx: &Context, + provider: OAuthProvider, + client_id: String, + client_secret: Option, + active: bool, + ) -> Result { + if let Some(claims) = &ctx.claims { + if claims.is_admin { + ctx.locator + .auth() + .update_oauth_credential(provider, client_id, client_secret, active) + .await?; + return Ok(true); + } + } + Err(CoreError::Unauthorized( + "Only admin is able to update oauth credential", + )) + } } fn from_validation_errors(error: ValidationErrors) -> FieldError { diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index d60a271c671e..b7a8eb79c5ec 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -15,9 +15,9 @@ use crate::{ oauth::github::GithubClient, schema::auth::{ generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, GithubAuthError, - GithubAuthResponse, InvitationNext, JWTPayload, RefreshTokenError, RefreshTokenResponse, - RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, User, - VerifyTokenResponse, + GithubAuthResponse, InvitationNext, JWTPayload, OAuthCredential, OAuthProvider, + RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError, + TokenAuthResponse, User, VerifyTokenResponse, }, }; @@ -387,6 +387,31 @@ impl AuthenticationService for DbConn { }; Ok(resp) } + + async fn read_oauth_credential( + &self, + provider: OAuthProvider, + ) -> Result> { + match provider { + OAuthProvider::Github => { + Ok(self.read_github_oauth_credential().await?.map(|x| x.into())) + } + } + } + + async fn update_oauth_credential( + &self, + provider: OAuthProvider, + client_id: String, + client_secret: Option, + active: bool, + ) -> Result<()> { + match provider { + OAuthProvider::Github => Ok(self + .update_github_oauth_credential(&client_id, client_secret.as_deref(), active) + .await?), + } + } } fn password_hash(raw: &str) -> password_hash::Result {