Skip to content

Commit

Permalink
feat(db): Don't require the secret when updating the OAuth credential (
Browse files Browse the repository at this point in the history
…#1452)

* feat(db): Don't require the secret when updating the OAuth credential

* Do not expose client_secret through GraphQL

* Cover new code with test cases

* [autofix.ci] apply automated fixes

* Fix test

* update(ui): optional secret

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: liangfung <[email protected]>
  • Loading branch information
3 people authored Feb 17, 2024
1 parent ae0d596 commit 1ab58a2
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 29 deletions.
35 changes: 27 additions & 8 deletions ee/tabby-db/src/github_oauth_credential.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::{query, FromRow};
use sqlx::{query, query_scalar, FromRow};

use super::DbConn;

Expand All @@ -19,10 +19,21 @@ impl DbConn {
pub async fn update_github_oauth_credential(
&self,
client_id: &str,
client_secret: &str,
client_secret: Option<&str>,
) -> Result<()> {
let client_id = client_id.to_string();
let client_secret = client_secret.to_string();
let mut transaction = self.pool.begin().await?;
let client_secret = match client_secret {
Some(secret) => secret.to_string(),
None => {
query_scalar!(
"SELECT client_secret FROM github_oauth_credential WHERE id = ?",
GITHUB_OAUTH_CREDENTIAL_ROW_ID
)
.fetch_one(&mut *transaction)
.await.map_err(|_| anyhow!("Must specify client secret when updating the OAuth credential for the first time"))?
}
};
query!(
r#"INSERT INTO github_oauth_credential (id, client_id, client_secret)
VALUES ($1, $2, $3) ON CONFLICT(id) DO UPDATE
Expand All @@ -32,8 +43,9 @@ impl DbConn {
client_id,
client_secret
)
.execute(&self.pool)
.execute(&mut *transaction)
.await?;
transaction.commit().await?;
Ok(())
}

Expand Down Expand Up @@ -64,15 +76,22 @@ mod tests {
let conn = DbConn::new_in_memory().await.unwrap();

// test insert
conn.update_github_oauth_credential("client_id", "client_secret")
conn.update_github_oauth_credential("client_id", Some("client_secret"))
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
assert_eq!(res.client_id, "client_id");
assert_eq!(res.client_secret, "client_secret");

// test update
conn.update_github_oauth_credential("client_id", "client_secret_2")
conn.update_github_oauth_credential("client_id", Some("client_secret_2"))
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
assert_eq!(res.client_id, "client_id");
assert_eq!(res.client_secret, "client_secret_2");

conn.update_github_oauth_credential("client_id", None)
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
Expand All @@ -84,7 +103,7 @@ mod tests {
assert!(conn.read_github_oauth_credential().await.unwrap().is_none());

// test update after delete
conn.update_github_oauth_credential("client_id_2", "client_secret_2")
conn.update_github_oauth_credential("client_id_2", Some("client_secret_2"))
.await
.unwrap();
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
Expand Down
34 changes: 26 additions & 8 deletions ee/tabby-db/src/google_oauth_credential.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::{query, FromRow};
use sqlx::{query, query_scalar, FromRow};

use super::DbConn;

Expand All @@ -19,10 +19,20 @@ impl DbConn {
pub async fn update_google_oauth_credential(
&self,
client_id: &str,
client_secret: &str,
client_secret: Option<&str>,
) -> Result<()> {
let client_id = client_id.to_string();
let client_secret = client_secret.to_string();
let mut transaction = self.pool.begin().await?;
let client_secret = match client_secret {
Some(secret) => secret.to_string(),
None => query_scalar!(
"SELECT client_secret FROM google_oauth_credential WHERE id = ?",
GOOGLE_OAUTH_CREDENTIAL_ROW_ID
)
.fetch_one(&mut *transaction)
.await
.map_err(|_| anyhow!("Must specify client secret when updating the OAuth credential for the first time"))?,
};
query!(
r#"INSERT INTO google_oauth_credential (id, client_id, client_secret)
VALUES ($1, $2, $3) ON CONFLICT(id) DO UPDATE
Expand All @@ -32,8 +42,9 @@ impl DbConn {
client_id,
client_secret,
)
.execute(&self.pool)
.execute(&mut *transaction)
.await?;
transaction.commit().await?;
Ok(())
}

Expand Down Expand Up @@ -64,7 +75,7 @@ mod tests {
let conn = DbConn::new_in_memory().await.unwrap();

// test insert
conn.update_google_oauth_credential("client_id", "client_secret")
conn.update_google_oauth_credential("client_id", Some("client_secret"))
.await
.unwrap();
let res = conn.read_google_oauth_credential().await.unwrap().unwrap();
Expand All @@ -77,13 +88,20 @@ mod tests {
assert!(res.is_none());

// test insert with redirect_uri
conn.update_google_oauth_credential("client_id", "client_secret")
conn.update_google_oauth_credential("client_id", Some("client_secret"))
.await
.unwrap();
conn.read_google_oauth_credential().await.unwrap().unwrap();

conn.update_google_oauth_credential("client_id", None)
.await
.unwrap();
let res = conn.read_google_oauth_credential().await.unwrap().unwrap();
assert_eq!(res.client_id, "client_id");
assert_eq!(res.client_secret, "client_secret");

// test update
conn.update_google_oauth_credential("client_id_2", "client_secret_2")
conn.update_google_oauth_credential("client_id_2", Some("client_secret_2"))
.await
.unwrap();
let res = conn.read_google_oauth_credential().await.unwrap().unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const oauthCallbackUrl = graphql(/* GraphQL */ `

const formSchema = z.object({
clientId: z.string(),
clientSecret: z.string(),
clientSecret: z.string().optional(),
provider: z.nativeEnum(OAuthProvider)
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export const oauthCredential = graphql(/* GraphQL */ `
oauthCredential(provider: $provider) {
provider
clientId
clientSecret
createdAt
updatedAt
}
Expand Down
3 changes: 1 addition & 2 deletions ee/tabby-webserver/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,12 @@ type UserConnection {
input UpdateOAuthCredentialInput {
provider: OAuthProvider!
clientId: String!
clientSecret: String!
clientSecret: String
}

type OAuthCredential {
provider: OAuthProvider!
clientId: String!
clientSecret: String!
createdAt: DateTimeUtc!
updatedAt: DateTimeUtc!
}
Expand Down
2 changes: 0 additions & 2 deletions ee/tabby-webserver/src/oauth/github.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ impl GithubClient {
code: String,
credential: OAuthCredential,
) -> Result<GithubOAuthResponse> {
let client_secret = credential.client_secret;
let params = [
("client_id", credential.client_id.as_str()),
("client_secret", client_secret.as_str()),
("code", code.as_str()),
];
let resp = self
Expand Down
1 change: 0 additions & 1 deletion ee/tabby-webserver/src/oauth/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl GoogleClient {
) -> Result<GoogleOAuthResponse> {
let params = [
("client_id", credential.client_id.as_str()),
("client_secret", credential.client_secret.as_str()),
("code", code.as_str()),
("grant_type", "authorization_code"),
("redirect_uri", redirect_uri.as_str()),
Expand Down
3 changes: 1 addition & 2 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ pub struct OAuthCredential {
pub provider: OAuthProvider,
pub client_id: String,

pub client_secret: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
Expand All @@ -413,7 +412,7 @@ pub struct UpdateOAuthCredentialInput {
code = "clientSecret",
message = "Client secret cannot be empty"
))]
pub client_secret: String,
pub client_secret: Option<String>,
}

#[async_trait]
Expand Down
4 changes: 2 additions & 2 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,11 @@ impl AuthenticationService for AuthenticationServiceImpl {
match input.provider {
OAuthProvider::Github => Ok(self
.db
.update_github_oauth_credential(&input.client_id, &input.client_secret)
.update_github_oauth_credential(&input.client_id, input.client_secret.as_deref())
.await?),
OAuthProvider::Google => Ok(self
.db
.update_google_oauth_credential(&input.client_id, &input.client_secret)
.update_google_oauth_credential(&input.client_id, input.client_secret.as_deref())
.await?),
}
}
Expand Down
2 changes: 0 additions & 2 deletions ee/tabby-webserver/src/service/dao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl From<GithubOAuthCredentialDAO> for OAuthCredential {
OAuthCredential {
provider: OAuthProvider::Github,
client_id: val.client_id,
client_secret: val.client_secret,
created_at: val.created_at,
updated_at: val.updated_at,
}
Expand All @@ -71,7 +70,6 @@ impl From<GoogleOAuthCredentialDAO> for OAuthCredential {
OAuthCredential {
provider: OAuthProvider::Google,
client_id: val.client_id,
client_secret: val.client_secret,
created_at: val.created_at,
updated_at: val.updated_at,
}
Expand Down

0 comments on commit 1ab58a2

Please sign in to comment.