-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(webserver): Add github oauth support (#1160)
* feat(webserver): add github oauth support * fix test * resolve comments * [autofix.ci] apply automated fixes (attempt 2/3) * fix test * [autofix.ci] apply automated fixes * switch to reqwest --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
cb035a6
commit f9dc54a
Showing
13 changed files
with
382 additions
and
11 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
DROP TABLE github_oauth_credential; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
CREATE TABLE github_oauth_credential ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
client_id VARCHAR(32) NOT NULL, | ||
client_secret VARCHAR(64) NOT NULL, | ||
active BOOLEAN DEFAULT (1), | ||
created_at TIMESTAMP DEFAULT (DATETIME('now')), | ||
updated_at TIMESTAMP DEFAULT (DATETIME('now')) | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
use anyhow::Result; | ||
use chrono::{DateTime, Utc}; | ||
use rusqlite::{named_params, OptionalExtension}; | ||
|
||
use super::DbConn; | ||
|
||
const GITHUB_OAUTH_CREDENTIAL_ROW_ID: i32 = 1; | ||
|
||
pub struct GithubOAuthCredentialDAO { | ||
pub client_id: String, | ||
pub client_secret: String, | ||
pub active: bool, | ||
pub created_at: DateTime<Utc>, | ||
pub updated_at: DateTime<Utc>, | ||
} | ||
|
||
impl GithubOAuthCredentialDAO { | ||
fn from_row(row: &rusqlite::Row<'_>) -> std::result::Result<Self, rusqlite::Error> { | ||
Ok(Self { | ||
client_id: row.get(0)?, | ||
client_secret: row.get(1)?, | ||
active: row.get(2)?, | ||
created_at: row.get(3)?, | ||
updated_at: row.get(4)?, | ||
}) | ||
} | ||
} | ||
|
||
/// db read/write operations for `github_oauth_credential` table | ||
impl DbConn { | ||
pub async fn update_github_oauth_credential( | ||
&self, | ||
client_id: &str, | ||
client_secret: &str, | ||
active: bool, | ||
) -> Result<()> { | ||
let client_id = client_id.to_string(); | ||
let client_secret = client_secret.to_string(); | ||
|
||
self.conn | ||
.call(move |c| { | ||
let mut stmt = c.prepare( | ||
r#"INSERT INTO github_oauth_credential (id, client_id, client_secret) | ||
VALUES (:id, :cid, :secret) ON CONFLICT(id) DO UPDATE | ||
SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now') | ||
WHERE id = :id"#, | ||
)?; | ||
stmt.insert(named_params! { | ||
":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID, | ||
":cid": client_id, | ||
":secret": client_secret, | ||
":active": active, | ||
})?; | ||
Ok(()) | ||
}) | ||
.await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
pub async fn read_github_oauth_credential(&self) -> Result<Option<GithubOAuthCredentialDAO>> { | ||
let token = self | ||
.conn | ||
.call(|conn| { | ||
Ok(conn | ||
.query_row( | ||
r#"SELECT client_id, client_secret, active, created_at, updated_at FROM github_oauth_credential WHERE id = ?"#, | ||
[GITHUB_OAUTH_CREDENTIAL_ROW_ID], | ||
GithubOAuthCredentialDAO::from_row, | ||
) | ||
.optional()) | ||
}) | ||
.await?; | ||
|
||
Ok(token?) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[tokio::test] | ||
async fn test_update_github_oauth_credential() { | ||
// test insert | ||
let conn = DbConn::new_in_memory().await.unwrap(); | ||
conn.update_github_oauth_credential("client_id", "client_secret", false) | ||
.await | ||
.unwrap(); | ||
let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); | ||
assert_eq!(res.client_id, "client_id"); | ||
assert_eq!(res.client_secret, "client_secret"); | ||
assert!(res.active); | ||
|
||
// test update | ||
conn.update_github_oauth_credential("client_id", "client_secret_2", false) | ||
.await | ||
.unwrap(); | ||
let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); | ||
assert_eq!(res.client_id, "client_id"); | ||
assert_eq!(res.client_secret, "client_secret_2"); | ||
assert!(!res.active); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
mod handler; | ||
mod hub; | ||
mod oauth; | ||
mod repositories; | ||
mod schema; | ||
mod service; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use anyhow::Result; | ||
use serde::Deserialize; | ||
use tabby_db::GithubOAuthCredentialDAO; | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct GithubOAuthResponse { | ||
#[serde(default)] | ||
access_token: String, | ||
#[serde(default)] | ||
scope: String, | ||
#[serde(default)] | ||
token_type: String, | ||
|
||
#[serde(default)] | ||
error: String, | ||
#[serde(default)] | ||
error_description: String, | ||
#[serde(default)] | ||
error_uri: String, | ||
} | ||
|
||
#[derive(Debug, Deserialize)] | ||
struct GithubUserEmail { | ||
email: String, | ||
primary: bool, | ||
verified: bool, | ||
visibility: String, | ||
} | ||
|
||
pub struct GithubClient { | ||
client: reqwest::Client, | ||
} | ||
|
||
impl Default for GithubClient { | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} | ||
|
||
impl GithubClient { | ||
pub fn new() -> Self { | ||
Self { | ||
client: reqwest::Client::new(), | ||
} | ||
} | ||
|
||
pub async fn fetch_user_email( | ||
&self, | ||
code: String, | ||
credential: GithubOAuthCredentialDAO, | ||
) -> Result<String> { | ||
let token_resp = self.exchange_access_token(code, credential).await?; | ||
if !token_resp.error.is_empty() { | ||
return Err(anyhow::anyhow!( | ||
"Failed to exchange access token: {}", | ||
token_resp.error_description | ||
)); | ||
} | ||
|
||
let resp = self | ||
.client | ||
.get("https://api.github.com/user/emails") | ||
.header(reqwest::header::USER_AGENT, "Tabby") | ||
.header(reqwest::header::ACCEPT, "application/vnd.github+json") | ||
.header( | ||
reqwest::header::AUTHORIZATION, | ||
format!("Bearer {}", token_resp.access_token), | ||
) | ||
.header("X-GitHub-Api-Version", "2022-11-28") | ||
.send() | ||
.await? | ||
.json::<Vec<GithubUserEmail>>() | ||
.await?; | ||
|
||
if resp.is_empty() { | ||
return Err(anyhow::anyhow!("No email address found")); | ||
} | ||
for item in &resp { | ||
if item.primary { | ||
return Ok(item.email.clone()); | ||
} | ||
} | ||
Ok(resp[0].email.clone()) | ||
} | ||
|
||
async fn exchange_access_token( | ||
&self, | ||
code: String, | ||
credential: GithubOAuthCredentialDAO, | ||
) -> Result<GithubOAuthResponse> { | ||
let params = [ | ||
("client_id", credential.client_id.as_str()), | ||
("client_secret", credential.client_secret.as_str()), | ||
("code", code.as_str()), | ||
]; | ||
let resp = self | ||
.client | ||
.post("https://github.com/login/oauth/access_token") | ||
.header(reqwest::header::ACCEPT, "application/json") | ||
.form(¶ms) | ||
.send() | ||
.await? | ||
.json::<GithubOAuthResponse>() | ||
.await?; | ||
|
||
Ok(resp) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
use std::sync::Arc; | ||
|
||
use axum::{ | ||
extract::{Query, State}, | ||
http::StatusCode, | ||
response::Redirect, | ||
routing, Router, | ||
}; | ||
use serde::Deserialize; | ||
use tracing::error; | ||
|
||
use crate::{ | ||
oauth::github::GithubClient, | ||
schema::{ | ||
auth::{AuthenticationService, GithubAuthError}, | ||
ServiceLocator, | ||
}, | ||
}; | ||
|
||
pub mod github; | ||
|
||
#[derive(Clone)] | ||
#[non_exhaustive] | ||
struct OAuthState { | ||
auth: Arc<dyn AuthenticationService>, | ||
github_client: Arc<GithubClient>, | ||
} | ||
|
||
pub fn routes(auth: Arc<dyn AuthenticationService>) -> Router { | ||
let state = OAuthState { | ||
auth, | ||
github_client: Arc::new(GithubClient::new()), | ||
}; | ||
|
||
Router::new() | ||
.route("/github", routing::get(github_callback)) | ||
.with_state(state) | ||
} | ||
|
||
#[derive(Deserialize)] | ||
#[allow(dead_code)] | ||
struct GithubCallbackParam { | ||
code: String, | ||
state: Option<String>, | ||
} | ||
|
||
async fn github_callback( | ||
State(state): State<OAuthState>, | ||
Query(param): Query<GithubCallbackParam>, | ||
) -> Result<Redirect, StatusCode> { | ||
match state | ||
.auth | ||
.github_auth(param.code, state.github_client.clone()) | ||
.await | ||
{ | ||
Ok(resp) => { | ||
let uri = format!( | ||
"/auth/signin?refresh_token={}&access_token={}", | ||
resp.refresh_token, resp.access_token, | ||
); | ||
Ok(Redirect::temporary(&uri)) | ||
} | ||
Err(GithubAuthError::InvalidVerificationCode) => Err(StatusCode::BAD_REQUEST), | ||
Err(GithubAuthError::CredentialNotActive) => Err(StatusCode::NOT_FOUND), | ||
Err(e) => { | ||
error!("Failed to authenticate with Github: {:?}", e); | ||
Err(StatusCode::INTERNAL_SERVER_ERROR) | ||
} | ||
} | ||
} |
Oops, something went wrong.