Skip to content

Commit

Permalink
feat(webserver): Add graphql api for oauth credential management (#1177)
Browse files Browse the repository at this point in the history
* feat(webserver): graphql api for oauth management

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* resolve comment

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
darknight and autofix-ci[bot] authored Jan 9, 2024
1 parent 356d1b0 commit ef7674c
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 30 deletions.
79 changes: 58 additions & 21 deletions ee/tabby-db/src/github_oauth_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,52 @@ impl DbConn {
pub async fn update_github_oauth_credential(
&self,
client_id: &str,
client_secret: &str,
client_secret: Option<&str>,
active: bool,
) -> Result<()> {
let client_id = client_id.to_string();
let client_secret = client_secret.to_string();

self.conn
.call(move |c| {
let mut stmt = c.prepare(
r#"INSERT INTO github_oauth_credential (id, client_id, client_secret)
VALUES (:id, :cid, :secret) ON CONFLICT(id) DO UPDATE
SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now')
WHERE id = :id"#,
)?;
stmt.insert(named_params! {
if let Some(client_secret) = client_secret {
let client_secret = client_secret.to_string();
let sql = r#"INSERT INTO github_oauth_credential (id, client_id, client_secret, active)
VALUES (:id, :cid, :secret, :active) ON CONFLICT(id) DO UPDATE
SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now')
WHERE id = :id"#;
self.conn
.call(move |c| {
let mut stmt = c.prepare(sql)?;
stmt.insert(named_params! {
":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID,
":cid": client_id,
":secret": client_secret,
":active": active,
})?;
Ok(())
})
.await?;

Ok(())
})?;
Ok(())
})
.await?;
Ok(())
} else {
let sql = r#"
UPDATE github_oauth_credential SET client_id = :cid, active = :active, updated_at = datetime('now')
WHERE id = :id"#;
let rows = self
.conn
.call(move |c| {
let mut stmt = c.prepare(sql)?;
let rows = stmt.execute(named_params! {
":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID,
":cid": client_id,
":active": active,
})?;
Ok(rows)
})
.await?;
if rows != 1 {
return Err(anyhow::anyhow!(
"failed to update: github credential not found"
));
}
Ok(())
}
}

pub async fn read_github_oauth_credential(&self) -> Result<Option<GithubOAuthCredentialDAO>> {
Expand All @@ -82,9 +103,16 @@ mod tests {

#[tokio::test]
async fn test_update_github_oauth_credential() {
// test insert
let conn = DbConn::new_in_memory().await.unwrap();
conn.update_github_oauth_credential("client_id", "client_secret", false)

// test update failure when no record exists
let res = conn
.update_github_oauth_credential("client_id", None, false)
.await;
assert!(res.is_err());

// test insert
conn.update_github_oauth_credential("client_id", Some("client_secret"), true)
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
Expand All @@ -93,12 +121,21 @@ mod tests {
assert!(res.active);

// test update
conn.update_github_oauth_credential("client_id", "client_secret_2", false)
conn.update_github_oauth_credential("client_id", Some("client_secret_2"), false)
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
assert_eq!(res.client_id, "client_id");
assert_eq!(res.client_secret, "client_secret_2");
assert!(!res.active);

// test update without client_secret
conn.update_github_oauth_credential("client_id_2", None, true)
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
assert_eq!(res.client_id, "client_id_2");
assert_eq!(res.client_secret, "client_secret_2");
assert!(res.active);
}
}
20 changes: 17 additions & 3 deletions ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Mutation {
createInvitation(email: String!): ID!
deleteInvitation(id: Int!): Int! @deprecated
deleteInvitationNext(id: ID!): ID!
updateOauthCredential(provider: OAuthProvider!, clientId: String!, clientSecret: String, active: Boolean!): Boolean!
}

"DateTime"
Expand Down Expand Up @@ -49,6 +50,7 @@ type Query {
usersNext(after: String, before: String, first: Int, last: Int): UserConnection!
invitationsNext(after: String, before: String, first: Int, last: Int): InvitationConnection!
jobRuns(after: String, before: String, first: Int, last: Int): JobRunConnection!
oauthCredential(provider: OAuthProvider!): OAuthCredential
}

type UserEdge {
Expand Down Expand Up @@ -84,6 +86,14 @@ type UserConnection {
pageInfo: PageInfo!
}

type OAuthCredential {
provider: OAuthProvider!
clientId: String!
active: Boolean!
createdAt: DateTimeUtc!
updatedAt: DateTimeUtc!
}

type VerifyTokenResponse {
claims: JWTPayload!
}
Expand All @@ -103,6 +113,11 @@ type User {
createdAt: DateTimeUtc!
}

type TokenAuthResponse {
accessToken: String!
refreshToken: String!
}

type Worker {
kind: WorkerKind!
name: String!
Expand All @@ -119,9 +134,8 @@ type InvitationEdge {
cursor: String!
}

type TokenAuthResponse {
accessToken: String!
refreshToken: String!
enum OAuthProvider {
GITHUB
}

type PageInfo {
Expand Down
30 changes: 29 additions & 1 deletion ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use jsonwebtoken as jwt;
use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue, ID};
use juniper::{FieldError, GraphQLEnum, GraphQLObject, IntoFieldError, ScalarValue, ID};
use juniper_axum::relay;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -334,6 +334,21 @@ impl relay::NodeType for InvitationNext {
}
}

#[derive(GraphQLEnum, Clone)]
#[non_exhaustive]
pub enum OAuthProvider {
Github,
}

#[derive(GraphQLObject)]
pub struct OAuthCredential {
pub provider: OAuthProvider,
pub client_id: String,
pub active: bool,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}

#[async_trait]
pub trait AuthenticationService: Send + Sync {
async fn register(
Expand Down Expand Up @@ -384,6 +399,19 @@ pub trait AuthenticationService: Send + Sync {
code: String,
client: Arc<GithubClient>,
) -> std::result::Result<GithubAuthResponse, GithubAuthError>;

async fn read_oauth_credential(
&self,
provider: OAuthProvider,
) -> Result<Option<OAuthCredential>>;

async fn update_oauth_credential(
&self,
provider: OAuthProvider,
client_id: String,
client_secret: Option<String>,
active: bool,
) -> Result<()>;
}

#[cfg(test)]
Expand Down
20 changes: 18 additions & 2 deletions ee/tabby-webserver/src/schema/dao.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use tabby_db::{InvitationDAO, JobRunDAO, UserDAO};
use tabby_db::{GithubOAuthCredentialDAO, InvitationDAO, JobRunDAO, UserDAO};

use crate::schema::{auth, job};
use crate::schema::{
auth,
auth::{OAuthCredential, OAuthProvider},
job,
};

impl From<InvitationDAO> for auth::InvitationNext {
fn from(val: InvitationDAO) -> Self {
Expand Down Expand Up @@ -38,3 +42,15 @@ impl From<UserDAO> for auth::User {
}
}
}

impl From<GithubOAuthCredentialDAO> for OAuthCredential {
fn from(val: GithubOAuthCredentialDAO) -> Self {
OAuthCredential {
provider: OAuthProvider::Github,
client_id: val.client_id,
active: val.active,
created_at: val.created_at,
updated_at: val.updated_at,
}
}
}
37 changes: 37 additions & 0 deletions ee/tabby-webserver/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use tracing::error;
use validator::ValidationErrors;
use worker::{Worker, WorkerService};

use crate::schema::auth::{OAuthCredential, OAuthProvider};

pub trait ServiceLocator: Send + Sync {
fn auth(&self) -> Arc<dyn AuthenticationService>;
fn worker(&self) -> Arc<dyn WorkerService>;
Expand Down Expand Up @@ -240,6 +242,20 @@ impl Query {
"Only admin is able to query job runs",
)))
}

async fn oauth_credential(
ctx: &Context,
provider: OAuthProvider,
) -> Result<Option<OAuthCredential>> {
if let Some(claims) = &ctx.claims {
if claims.is_admin {
return Ok(ctx.locator.auth().read_oauth_credential(provider).await?);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to query oauth credential",
))
}
}

#[derive(Default)]
Expand Down Expand Up @@ -341,6 +357,27 @@ impl Mutation {
"Only admin is able to delete invitation",
))
}

async fn update_oauth_credential(
ctx: &Context,
provider: OAuthProvider,
client_id: String,
client_secret: Option<String>,
active: bool,
) -> Result<bool> {
if let Some(claims) = &ctx.claims {
if claims.is_admin {
ctx.locator
.auth()
.update_oauth_credential(provider, client_id, client_secret, active)
.await?;
return Ok(true);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to update oauth credential",
))
}
}

fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError<S> {
Expand Down
31 changes: 28 additions & 3 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ use crate::{
oauth::github::GithubClient,
schema::auth::{
generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, GithubAuthError,
GithubAuthResponse, InvitationNext, JWTPayload, RefreshTokenError, RefreshTokenResponse,
RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, User,
VerifyTokenResponse,
GithubAuthResponse, InvitationNext, JWTPayload, OAuthCredential, OAuthProvider,
RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError,
TokenAuthResponse, User, VerifyTokenResponse,
},
};

Expand Down Expand Up @@ -387,6 +387,31 @@ impl AuthenticationService for DbConn {
};
Ok(resp)
}

async fn read_oauth_credential(
&self,
provider: OAuthProvider,
) -> Result<Option<OAuthCredential>> {
match provider {
OAuthProvider::Github => {
Ok(self.read_github_oauth_credential().await?.map(|x| x.into()))
}
}
}

async fn update_oauth_credential(
&self,
provider: OAuthProvider,
client_id: String,
client_secret: Option<String>,
active: bool,
) -> Result<()> {
match provider {
OAuthProvider::Github => Ok(self
.update_github_oauth_credential(&client_id, client_secret.as_deref(), active)
.await?),
}
}
}

fn password_hash(raw: &str) -> password_hash::Result<String> {
Expand Down

0 comments on commit ef7674c

Please sign in to comment.