diff --git a/Cargo.lock b/Cargo.lock index f238708b66a9..11a0e628f93c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -436,9 +436,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cargo-lock" @@ -3802,6 +3802,7 @@ dependencies = [ "lazy_static", "mime_guess", "pin-project", + "reqwest", "rust-embed 8.0.0", "serde", "serde_json", diff --git a/ee/tabby-db/migrations/06-github-oauth-credential/down.sql b/ee/tabby-db/migrations/06-github-oauth-credential/down.sql new file mode 100644 index 000000000000..31004eb9a2a5 --- /dev/null +++ b/ee/tabby-db/migrations/06-github-oauth-credential/down.sql @@ -0,0 +1 @@ +DROP TABLE github_oauth_credential; \ No newline at end of file diff --git a/ee/tabby-db/migrations/06-github-oauth-credential/up.sql b/ee/tabby-db/migrations/06-github-oauth-credential/up.sql new file mode 100644 index 000000000000..fc56ac5e89d2 --- /dev/null +++ b/ee/tabby-db/migrations/06-github-oauth-credential/up.sql @@ -0,0 +1,8 @@ +CREATE TABLE github_oauth_credential ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + client_id VARCHAR(32) NOT NULL, + client_secret VARCHAR(64) NOT NULL, + active BOOLEAN DEFAULT (1), + created_at TIMESTAMP DEFAULT (DATETIME('now')), + updated_at TIMESTAMP DEFAULT (DATETIME('now')) +); diff --git a/ee/tabby-db/src/github_oauth_credential.rs b/ee/tabby-db/src/github_oauth_credential.rs new file mode 100644 index 000000000000..fe1e819cd6c6 --- /dev/null +++ b/ee/tabby-db/src/github_oauth_credential.rs @@ -0,0 +1,104 @@ +use anyhow::Result; +use chrono::{DateTime, Utc}; +use rusqlite::{named_params, OptionalExtension}; + +use super::DbConn; + +const GITHUB_OAUTH_CREDENTIAL_ROW_ID: i32 = 1; + +pub struct GithubOAuthCredentialDAO { + pub client_id: String, + pub client_secret: String, + pub active: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} + +impl GithubOAuthCredentialDAO { + fn from_row(row: &rusqlite::Row<'_>) -> std::result::Result { + Ok(Self { + client_id: row.get(0)?, + client_secret: row.get(1)?, + active: row.get(2)?, + created_at: row.get(3)?, + updated_at: row.get(4)?, + }) + } +} + +/// db read/write operations for `github_oauth_credential` table +impl DbConn { + pub async fn update_github_oauth_credential( + &self, + client_id: &str, + client_secret: &str, + active: bool, + ) -> Result<()> { + let client_id = client_id.to_string(); + let client_secret = client_secret.to_string(); + + self.conn + .call(move |c| { + let mut stmt = c.prepare( + r#"INSERT INTO github_oauth_credential (id, client_id, client_secret) + VALUES (:id, :cid, :secret) ON CONFLICT(id) DO UPDATE + SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now') + WHERE id = :id"#, + )?; + stmt.insert(named_params! { + ":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID, + ":cid": client_id, + ":secret": client_secret, + ":active": active, + })?; + Ok(()) + }) + .await?; + + Ok(()) + } + + pub async fn read_github_oauth_credential(&self) -> Result> { + let token = self + .conn + .call(|conn| { + Ok(conn + .query_row( + r#"SELECT client_id, client_secret, active, created_at, updated_at FROM github_oauth_credential WHERE id = ?"#, + [GITHUB_OAUTH_CREDENTIAL_ROW_ID], + GithubOAuthCredentialDAO::from_row, + ) + .optional()) + }) + .await?; + + Ok(token?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_update_github_oauth_credential() { + // test insert + let conn = DbConn::new_in_memory().await.unwrap(); + conn.update_github_oauth_credential("client_id", "client_secret", false) + .await + .unwrap(); + let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); + assert_eq!(res.client_id, "client_id"); + assert_eq!(res.client_secret, "client_secret"); + assert!(res.active); + + // test update + conn.update_github_oauth_credential("client_id", "client_secret_2", false) + .await + .unwrap(); + let res = conn.read_github_oauth_credential().await.unwrap().unwrap(); + assert_eq!(res.client_id, "client_id"); + assert_eq!(res.client_secret, "client_secret_2"); + assert!(!res.active); + } +} diff --git a/ee/tabby-db/src/lib.rs b/ee/tabby-db/src/lib.rs index 2f0608c49c78..a9d26bb81c8f 100644 --- a/ee/tabby-db/src/lib.rs +++ b/ee/tabby-db/src/lib.rs @@ -1,7 +1,9 @@ +pub use github_oauth_credential::GithubOAuthCredentialDAO; pub use invitations::InvitationDAO; pub use job_runs::JobRunDAO; pub use users::UserDAO; +mod github_oauth_credential; mod invitations; mod job_runs; mod path; diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index b0a091e5e803..74c191199cc3 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -20,6 +20,7 @@ juniper-axum = { path = "../../crates/juniper-axum" } lazy_static.workspace = true mime_guess = "2.0.4" pin-project = "1.1.3" +reqwest = { workspace = true, features = ["json"] } rust-embed = "8.0.0" serde.workspace = true serde_json.workspace = true diff --git a/ee/tabby-webserver/docs/api_spec.md b/ee/tabby-webserver/docs/api_spec.md index 7d01e939be6c..5869c596538b 100644 --- a/ee/tabby-webserver/docs/api_spec.md +++ b/ee/tabby-webserver/docs/api_spec.md @@ -141,3 +141,26 @@ The `Content-Type` for successful response is always `application/json`. ] } ``` + +## OAuth api: `/oauth_callback` + +### GitHub + +**URL:** `/oauth_callback/github` + +**Method:** `GET` + +**Request example:** + +```shell +curl --request GET \ + --url http://localhost:8080/oauth_callback/github?code=1234567890 +``` + +**Response example:** + +The request will redirect to `/auth/signin` with refresh token & access token attached. + +``` +http://localhost:8080/auth/signin?refresh_token=321bc1bbb043456dae1a7abc0c447875&access_token=eyJ0eXAi......1NiJ9.eyJleHAi......bWluIjp0cnVlfQ.GvHSMUfc...S5BnwY +``` diff --git a/ee/tabby-webserver/src/handler.rs b/ee/tabby-webserver/src/handler.rs index 865ba9dfef56..55f0cbc6d922 100644 --- a/ee/tabby-webserver/src/handler.rs +++ b/ee/tabby-webserver/src/handler.rs @@ -14,7 +14,7 @@ use tabby_common::{ }; use crate::{ - hub, repositories, + hub, oauth, repositories, schema::{create_schema, Schema, ServiceLocator}, service::create_service_locator, ui, @@ -48,7 +48,8 @@ pub async fn attach_webserver( .nest( "/repositories", repositories::routes(rs.clone(), ctx.auth()), - ); + ) + .nest("/oauth_callback", oauth::routes(ctx.auth())); let ui = ui .route("/graphiql", routing::get(graphiql("/graphql", None))) diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 21e55a6ed42d..1812d22ac100 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -1,5 +1,6 @@ mod handler; mod hub; +mod oauth; mod repositories; mod schema; mod service; diff --git a/ee/tabby-webserver/src/oauth/github.rs b/ee/tabby-webserver/src/oauth/github.rs new file mode 100644 index 000000000000..434dd3f4e163 --- /dev/null +++ b/ee/tabby-webserver/src/oauth/github.rs @@ -0,0 +1,108 @@ +use anyhow::Result; +use serde::Deserialize; +use tabby_db::GithubOAuthCredentialDAO; + +#[derive(Debug, Deserialize)] +struct GithubOAuthResponse { + #[serde(default)] + access_token: String, + #[serde(default)] + scope: String, + #[serde(default)] + token_type: String, + + #[serde(default)] + error: String, + #[serde(default)] + error_description: String, + #[serde(default)] + error_uri: String, +} + +#[derive(Debug, Deserialize)] +struct GithubUserEmail { + email: String, + primary: bool, + verified: bool, + visibility: String, +} + +pub struct GithubClient { + client: reqwest::Client, +} + +impl Default for GithubClient { + fn default() -> Self { + Self::new() + } +} + +impl GithubClient { + pub fn new() -> Self { + Self { + client: reqwest::Client::new(), + } + } + + pub async fn fetch_user_email( + &self, + code: String, + credential: GithubOAuthCredentialDAO, + ) -> Result { + let token_resp = self.exchange_access_token(code, credential).await?; + if !token_resp.error.is_empty() { + return Err(anyhow::anyhow!( + "Failed to exchange access token: {}", + token_resp.error_description + )); + } + + let resp = self + .client + .get("https://api.github.com/user/emails") + .header(reqwest::header::USER_AGENT, "Tabby") + .header(reqwest::header::ACCEPT, "application/vnd.github+json") + .header( + reqwest::header::AUTHORIZATION, + format!("Bearer {}", token_resp.access_token), + ) + .header("X-GitHub-Api-Version", "2022-11-28") + .send() + .await? + .json::>() + .await?; + + if resp.is_empty() { + return Err(anyhow::anyhow!("No email address found")); + } + for item in &resp { + if item.primary { + return Ok(item.email.clone()); + } + } + Ok(resp[0].email.clone()) + } + + async fn exchange_access_token( + &self, + code: String, + credential: GithubOAuthCredentialDAO, + ) -> Result { + let params = [ + ("client_id", credential.client_id.as_str()), + ("client_secret", credential.client_secret.as_str()), + ("code", code.as_str()), + ]; + let resp = self + .client + .post("https://github.com/login/oauth/access_token") + .header(reqwest::header::ACCEPT, "application/json") + .form(¶ms) + .send() + .await? + .json::() + .await?; + + Ok(resp) + } +} diff --git a/ee/tabby-webserver/src/oauth/mod.rs b/ee/tabby-webserver/src/oauth/mod.rs new file mode 100644 index 000000000000..1f554120af54 --- /dev/null +++ b/ee/tabby-webserver/src/oauth/mod.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; + +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::Redirect, + routing, Router, +}; +use serde::Deserialize; +use tracing::error; + +use crate::{ + oauth::github::GithubClient, + schema::{ + auth::{AuthenticationService, GithubAuthError}, + ServiceLocator, + }, +}; + +pub mod github; + +#[derive(Clone)] +#[non_exhaustive] +struct OAuthState { + auth: Arc, + github_client: Arc, +} + +pub fn routes(auth: Arc) -> Router { + let state = OAuthState { + auth, + github_client: Arc::new(GithubClient::new()), + }; + + Router::new() + .route("/github", routing::get(github_callback)) + .with_state(state) +} + +#[derive(Deserialize)] +#[allow(dead_code)] +struct GithubCallbackParam { + code: String, + state: Option, +} + +async fn github_callback( + State(state): State, + Query(param): Query, +) -> Result { + match state + .auth + .github_auth(param.code, state.github_client.clone()) + .await + { + Ok(resp) => { + let uri = format!( + "/auth/signin?refresh_token={}&access_token={}", + resp.refresh_token, resp.access_token, + ); + Ok(Redirect::temporary(&uri)) + } + Err(GithubAuthError::InvalidVerificationCode) => Err(StatusCode::BAD_REQUEST), + Err(GithubAuthError::CredentialNotActive) => Err(StatusCode::NOT_FOUND), + Err(e) => { + error!("Failed to authenticate with Github: {:?}", e); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + } +} diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index f753cc4b4907..b7ecffff9a6f 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -1,4 +1,4 @@ -use std::fmt::Debug; +use std::{fmt::Debug, sync::Arc}; use anyhow::Result; use async_trait::async_trait; @@ -14,7 +14,7 @@ use uuid::Uuid; use validator::ValidationErrors; use super::from_validation_errors; -use crate::schema::Context; +use crate::{oauth::github::GithubClient, schema::Context}; lazy_static! { static ref JWT_TOKEN_SECRET: String = jwt_token_secret(); @@ -145,6 +145,27 @@ pub enum TokenAuthError { Unknown, } +#[derive(Default, Serialize)] +pub struct GithubAuthResponse { + pub access_token: String, + pub refresh_token: String, +} + +#[derive(Error, Debug)] +pub enum GithubAuthError { + #[error("The code passed is incorrect or expired")] + InvalidVerificationCode, + + #[error("The Github credential is not active")] + CredentialNotActive, + + #[error(transparent)] + Other(#[from] anyhow::Error), + + #[error("Unknown error")] + Unknown, +} + impl IntoFieldError for TokenAuthError { fn into_field_error(self) -> FieldError { match self { @@ -354,6 +375,12 @@ pub trait AuthenticationService: Send + Sync { first: Option, last: Option, ) -> Result>; + + async fn github_auth( + &self, + code: String, + client: Arc, + ) -> std::result::Result; } #[cfg(test)] diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 8d1e9e913f47..6cf859124458 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::{borrow::Cow, sync::Arc}; use anyhow::{anyhow, Result}; use argon2::{ @@ -11,10 +11,14 @@ use juniper::ID; use tabby_db::DbConn; use validator::{Validate, ValidationError}; -use crate::schema::auth::{ - generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, InvitationNext, - JWTPayload, RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, - TokenAuthError, TokenAuthResponse, User, VerifyTokenResponse, +use crate::{ + oauth::github::GithubClient, + schema::auth::{ + generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, GithubAuthError, + GithubAuthResponse, InvitationNext, JWTPayload, RefreshTokenError, RefreshTokenResponse, + RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, User, + VerifyTokenResponse, + }, }; /// Input parameters for register mutation @@ -339,6 +343,26 @@ impl AuthenticationService for DbConn { Ok(invitations.into_iter().map(|x| x.into()).collect()) } + + async fn github_auth( + &self, + code: String, + client: Arc, + ) -> std::result::Result { + let credential = self + .read_github_oauth_credential() + .await? + .ok_or(GithubAuthError::CredentialNotActive)?; + if !credential.active { + return Err(GithubAuthError::CredentialNotActive); + } + + let _email = client.fetch_user_email(code, credential).await?; + + // TODO: auto register & generate token + + Ok(GithubAuthResponse::default()) + } } fn password_hash(raw: &str) -> password_hash::Result {