From 6336ed5432e43652c6f066b2467e45ceeecd3235 Mon Sep 17 00:00:00 2001 From: Lu Yang Date: Sat, 24 Aug 2024 23:37:40 +0100 Subject: [PATCH] associate a user with multiple git providers --- lapdev-api/src/account.rs | 138 ++++++- lapdev-api/src/auth.rs | 11 +- lapdev-api/src/organization.rs | 17 +- lapdev-api/src/project.rs | 10 +- lapdev-api/src/router.rs | 13 + lapdev-api/src/session.rs | 144 +++++-- lapdev-api/src/state.rs | 13 +- lapdev-common/src/console.rs | 2 - lapdev-common/src/lib.rs | 32 +- lapdev-conductor/src/server.rs | 67 ++- lapdev-dashboard/main.css | 23 +- lapdev-dashboard/src/app.rs | 2 + lapdev-dashboard/src/git_provider.rs | 391 ++++++++++++++++++ lapdev-dashboard/src/lib.rs | 1 + lapdev-dashboard/src/nav.rs | 8 + lapdev-db/src/api.rs | 119 +++++- lapdev-db/src/entities/mod.rs | 1 + lapdev-db/src/entities/oauth_connection.rs | 26 ++ lapdev-db/src/entities/organization.rs | 1 + lapdev-db/src/entities/project.rs | 1 + lapdev-db/src/entities/user.rs | 3 - .../m20231106_100019_create_user_table.rs | 20 - .../m20231109_171859_create_project_table.rs | 2 + .../m20240823_165223_create_oauth_table.rs | 102 +++++ lapdev-db/src/migration/mod.rs | 2 + lapdev-enterprise/src/enterprise.rs | 49 ++- lapdev-enterprise/src/quota.rs | 8 +- 27 files changed, 1073 insertions(+), 133 deletions(-) create mode 100644 lapdev-dashboard/src/git_provider.rs create mode 100644 lapdev-db/src/entities/oauth_connection.rs create mode 100644 lapdev-db/src/migration/m20240823_165223_create_oauth_table.rs diff --git a/lapdev-api/src/account.rs b/lapdev-api/src/account.rs index f19f45f..5852340 100644 --- a/lapdev-api/src/account.rs +++ b/lapdev-api/src/account.rs @@ -1,7 +1,7 @@ -use std::str::FromStr; +use std::{collections::HashMap, str::FromStr}; use axum::{ - extract::{Path, State}, + extract::{Host, Path, Query, State}, response::{IntoResponse, Response}, Json, }; @@ -12,15 +12,15 @@ use axum_extra::{ use chrono::Utc; use hyper::StatusCode; use lapdev_common::{ - console::{MeUser, Organization}, - NewSshKey, SshKey, UserRole, + console::{MeUser, NewSessionResponse, Organization}, + GitProvider, NewSshKey, SshKey, UserRole, }; use lapdev_db::{api::DbApi, entities}; use lapdev_rpc::error::ApiError; use russh::keys::PublicKeyBase64; use sea_orm::{prelude::Uuid, ActiveModelTrait, ActiveValue}; -use crate::state::CoreState; +use crate::{session::create_oauth_connection, state::CoreState}; pub async fn me( State(state): State, @@ -32,7 +32,6 @@ pub async fn me( .await .unwrap_or_default(); Ok(Json(MeUser { - login: user.provider_login, avatar_url: user.avatar_url, email: user.email, name: user.name, @@ -242,3 +241,130 @@ pub async fn delete_ssh_key( Ok(StatusCode::NO_CONTENT.into_response()) } + +pub async fn get_git_providers( + TypedHeader(cookie): TypedHeader, + State(state): State, +) -> Result { + let user = state.authenticate(&cookie).await?; + let all_oauths = state.db.get_user_all_oauth(user.id).await?; + + let mut git_providers = Vec::new(); + for (auth_provider, (_, config)) in state.auth.clients.read().await.iter() { + let oauth = all_oauths + .iter() + .find(|o| o.provider == auth_provider.to_string()); + + let git_provider = GitProvider { + auth_provider: *auth_provider, + connected: oauth.is_some(), + avatar_url: oauth.as_ref().and_then(|o| o.avatar_url.clone()), + email: oauth.as_ref().and_then(|o| o.email.clone()), + name: oauth.as_ref().and_then(|o| o.name.clone()), + read_repo: oauth.as_ref().map(|o| o.read_repo), + scopes: config.scopes.iter().map(|s| s.to_string()).collect(), + all_scopes: config + .read_repo_scopes + .iter() + .map(|s| s.to_string()) + .collect(), + }; + git_providers.push(git_provider); + } + + Ok(Json(git_providers)) +} + +pub async fn connect_git_provider( + TypedHeader(cookie): TypedHeader, + Host(hostname): Host, + Query(query): Query>, + State(state): State, +) -> Result { + let user = state.authenticate(&cookie).await?; + let provider_name = query + .get("provider") + .ok_or_else(|| ApiError::InvalidRequest("no provider in query string".to_string()))?; + + let oauth = state.db.get_user_oauth(user.id, provider_name).await?; + if oauth.is_some() { + return Err(ApiError::InvalidRequest( + "provider already connected".to_string(), + ))?; + } + + let (headers, url) = + create_oauth_connection(&state, Some(user.id), false, &hostname, &query).await?; + + Ok((headers, Json(NewSessionResponse { url })).into_response()) +} + +pub async fn update_scope( + TypedHeader(cookie): TypedHeader, + Host(hostname): Host, + Query(query): Query>, + State(state): State, +) -> Result { + let user = state.authenticate(&cookie).await?; + let all_oauths = state.db.get_user_all_oauth(user.id).await?; + let provider_name = query + .get("provider") + .ok_or_else(|| ApiError::InvalidRequest("no provider in query string".to_string()))?; + if !all_oauths.iter().any(|o| &o.provider == provider_name) { + return Err(ApiError::InvalidRequest( + "provider isn't connected".to_string(), + ))?; + } + + let read_repo = query + .get("read_repo") + .ok_or_else(|| ApiError::InvalidRequest("no read_repo in query string".to_string()))?; + + let read_repo = match read_repo.as_str() { + "yes" => true, + "no" => false, + _ => { + return Err(ApiError::InvalidRequest( + "read_repo should be either yes or no".to_string(), + )) + } + }; + + let (headers, url) = + create_oauth_connection(&state, Some(user.id), read_repo, &hostname, &query).await?; + + Ok((headers, Json(NewSessionResponse { url })).into_response()) +} + +pub async fn disconnect_git_provider( + TypedHeader(cookie): TypedHeader, + Query(query): Query>, + State(state): State, +) -> Result { + let user = state.authenticate(&cookie).await?; + let all_oauths = state.db.get_user_all_oauth(user.id).await?; + if all_oauths.len() < 2 { + return Err(ApiError::InvalidRequest( + "You can't disconnect all git providers".to_string(), + ))?; + } + + let provider_name = query + .get("provider") + .ok_or_else(|| ApiError::InvalidRequest("no provider in query string".to_string()))?; + let oauth = state + .db + .get_user_oauth(user.id, provider_name) + .await? + .ok_or_else(|| ApiError::InvalidRequest("provider isn't connected".to_string()))?; + + entities::oauth_connection::ActiveModel { + id: ActiveValue::Set(oauth.id), + deleted_at: ActiveValue::Set(Some(Utc::now().into())), + ..Default::default() + } + .update(&state.db.conn) + .await?; + + Ok(()) +} diff --git a/lapdev-api/src/auth.rs b/lapdev-api/src/auth.rs index 8883675..bf887e0 100644 --- a/lapdev-api/src/auth.rs +++ b/lapdev-api/src/auth.rs @@ -17,6 +17,7 @@ pub struct AuthConfig { pub auth_url: &'static str, pub token_url: &'static str, pub scopes: &'static [&'static str], + pub read_repo_scope: &'static str, pub read_repo_scopes: &'static [&'static str], } @@ -27,6 +28,7 @@ impl AuthConfig { auth_url: "https://github.com/login/oauth/authorize", token_url: "https://github.com/login/oauth/access_token", scopes: &["read:user", "user:email"], + read_repo_scope: "repo", read_repo_scopes: &["read:user", "user:email", "repo"], }; pub const GITLAB: Self = AuthConfig { @@ -35,6 +37,7 @@ impl AuthConfig { auth_url: "https://gitlab.com/oauth/authorize", token_url: "https://gitlab.com/oauth/token", scopes: &["read_user"], + read_repo_scope: "read_repository", read_repo_scopes: &["read_user", "read_repository"], }; } @@ -88,17 +91,17 @@ impl Auth { &self, provider: AuthProvider, redirect_url: &str, - no_read_repo: bool, + read_repo: bool, ) -> Result<(String, String)> { let clients = self.clients.read().await; let (client, config) = clients .get(&provider) .ok_or_else(|| anyhow::anyhow!("can't find provider"))?; let mut client = client.authorize_url(oauth2::CsrfToken::new_random); - for scope in if no_read_repo { - config.scopes - } else { + for scope in if read_repo { config.read_repo_scopes + } else { + config.scopes } { client = client.add_scope(oauth2::Scope::new(scope.to_string())); } diff --git a/lapdev-api/src/organization.rs b/lapdev-api/src/organization.rs index a114f0b..664528a 100644 --- a/lapdev-api/src/organization.rs +++ b/lapdev-api/src/organization.rs @@ -40,19 +40,10 @@ pub async fn create_organization( let now = Utc::now(); let txn = state.db.conn.begin().await?; - let org = entities::organization::ActiveModel { - id: ActiveValue::Set(Uuid::new_v4()), - deleted_at: ActiveValue::Set(None), - name: ActiveValue::Set(name.to_string()), - auto_start: ActiveValue::Set(true), - allow_workspace_change_auto_start: ActiveValue::Set(true), - auto_stop: ActiveValue::Set(Some(3600)), - allow_workspace_change_auto_stop: ActiveValue::Set(true), - last_auto_stop_check: ActiveValue::Set(None), - usage_limit: ActiveValue::Set(30 * 60 * 60), - } - .insert(&txn) - .await?; + let org = state + .db + .create_new_organization(&txn, name.to_string()) + .await?; entities::organization_member::ActiveModel { created_at: ActiveValue::Set(Utc::now().into()), diff --git a/lapdev-api/src/project.rs b/lapdev-api/src/project.rs index baf1d92..e7ae7f5 100644 --- a/lapdev-api/src/project.rs +++ b/lapdev-api/src/project.rs @@ -122,10 +122,14 @@ pub async fn get_project_branches( info: RequestInfo, ) -> Result { let (user, project) = state.get_project(&cookie, org_id, project_id).await?; - let auth = if let Ok(Some(user)) = state.db.get_user(project.created_by).await { - (user.provider_login, user.access_token) + let auth = if let Ok(Some(oauth)) = state.db.get_oauth(project.oauth_id).await { + (oauth.provider_login, oauth.access_token) } else { - (user.provider_login.clone(), user.access_token.clone()) + let oauth = state + .conductor + .find_match_oauth_for_repo(&user, &project.repo_url) + .await?; + (oauth.provider_login, oauth.access_token) }; let branches = state .conductor diff --git a/lapdev-api/src/router.rs b/lapdev-api/src/router.rs index ef4af56..67ad449 100644 --- a/lapdev-api/src/router.rs +++ b/lapdev-api/src/router.rs @@ -165,6 +165,19 @@ fn v1_api_routes(additional_router: Option>) -> Router, + connect_provider: Option, } -pub(crate) async fn new_session( - Host(hostname): Host, - Query(query): Query>, - State(state): State, -) -> Result { - let next = query - .get("next") - .ok_or_else(|| ApiError::InvalidRequest("no next url in query string".to_string()))?; +pub async fn create_oauth_connection( + state: &CoreState, + connect_user_id: Option, + read_repo: bool, + hostname: &str, + query: &HashMap, +) -> Result<(HeaderMap, String), ApiError> { let host = query .get("host") .ok_or_else(|| ApiError::InvalidRequest("no host in query string".to_string()))?; + let next = query + .get("next") + .ok_or_else(|| ApiError::InvalidRequest("no next url in query string".to_string()))?; let provider = query .get("provider") .ok_or_else(|| ApiError::InvalidRequest("no provider in query string".to_string()))?; + let redirect_url = format!( + "{host}/api/private/session/authorize?provider={provider}{}&next={next}", + if connect_user_id.is_some() { + "&connect_provider=yes" + } else { + "" + } + ); + let provider = AuthProvider::from_str(provider) .map_err(|_| ApiError::InvalidRequest(format!("provider {provider} is invalid")))?; - - let redirect_url = - format!("{host}/api/private/session/authorize?provider={provider}&next={next}"); - let oauth_no_read_repo = state.db.oauth_no_read_repo().await.unwrap_or(false); let (url, csrf) = state .auth - .authorize_url(provider, &redirect_url, oauth_no_read_repo) + .authorize_url(provider, &redirect_url, read_repo) .await?; let mut claims = Claims::new()?; claims.add_additional(OAUTH_STATE, csrf.clone())?; claims.add_additional(REDIRECT_URL, redirect_url.clone())?; + claims.add_additional(READ_REPO, read_repo)?; + if let Some(id) = connect_user_id { + claims.add_additional(CONNECT_USER, id.to_string())?; + } let token = pasetors::local::encrypt(&state.auth_token_key, &claims, None, None)?; - let cookie = format!("{TOKEN_COOKIE_NAME}={token}; Path=/"); + let cookie = format!("{OAUTH_STATE_COOKIE}={token}; Path=/"); let cookie = if let Some(hostname) = hostname.split(':').next() { if hostname.parse::().is_err() { format!("{cookie}; Domain=.{hostname}") @@ -75,8 +91,16 @@ pub(crate) async fn new_session( }; let mut headers = HeaderMap::new(); headers.insert(SET_COOKIE, cookie.parse()?); + Ok((headers, url)) +} - Ok((headers, Json(NewSessionResponse { url, state: csrf })).into_response()) +pub(crate) async fn new_session( + Host(hostname): Host, + Query(query): Query>, + State(state): State, +) -> Result { + let (headers, url) = create_oauth_connection(&state, None, false, &hostname, &query).await?; + Ok((headers, Json(NewSessionResponse { url })).into_response()) } pub(crate) async fn session_authorize( @@ -86,7 +110,7 @@ pub(crate) async fn session_authorize( info: RequestInfo, TypedHeader(cookie): TypedHeader, ) -> Result { - let token = state.token(&cookie)?; + let token = state.auth_state_token(&cookie)?; let claims = token.payload_claims().ok_or(ApiError::InvalidAuthToken)?; let redirect_url = claims @@ -145,18 +169,88 @@ pub(crate) async fn session_authorize( } }; - let user = match entities::user::Entity::find() - .filter(entities::user::Column::Provider.eq(query.provider.to_string())) - .filter(entities::user::Column::ProviderId.eq(provider_user.id)) - .filter(entities::user::Column::DeletedAt.is_null()) + if query.connect_provider.as_deref() == Some("yes") { + let now = Utc::now(); + + let connect_user = claims + .get_claim(CONNECT_USER) + .ok_or_else(|| ApiError::InvalidRequest("doens't have connect user".to_string()))?; + let connect_user: String = serde_json::from_value(connect_user.to_owned()) + .map_err(|_| ApiError::InvalidRequest("invalid connect user".to_string()))?; + let user_id = Uuid::from_str(&connect_user)?; + + let read_repo = claims + .get_claim(READ_REPO) + .ok_or_else(|| ApiError::InvalidRequest("doens't have read repo".to_string()))?; + let read_repo: bool = serde_json::from_value(read_repo.to_owned()) + .map_err(|_| ApiError::InvalidRequest("invalid read repo".to_string()))?; + + match entities::oauth_connection::Entity::find() + .filter(entities::oauth_connection::Column::Provider.eq(query.provider.to_string())) + .filter(entities::oauth_connection::Column::UserId.eq(user_id)) + .filter(entities::oauth_connection::Column::DeletedAt.is_null()) + .one(&state.db.conn) + .await? + { + Some(c) => { + entities::oauth_connection::ActiveModel { + id: ActiveValue::Set(c.id), + provider_login: ActiveValue::Set(provider_user.login), + access_token: ActiveValue::Set(token.secret().to_string()), + avatar_url: ActiveValue::Set(provider_user.avatar_url.clone()), + email: ActiveValue::Set(provider_user.email.clone()), + name: ActiveValue::Set(provider_user.name.clone()), + read_repo: ActiveValue::Set(read_repo), + ..Default::default() + } + .update(&state.db.conn) + .await?; + } + None => { + entities::oauth_connection::ActiveModel { + id: ActiveValue::Set(Uuid::new_v4()), + user_id: ActiveValue::Set(user_id), + created_at: ActiveValue::Set(now.into()), + deleted_at: ActiveValue::Set(None), + provider: ActiveValue::Set(query.provider.to_string()), + provider_id: ActiveValue::Set(provider_user.id), + provider_login: ActiveValue::Set(provider_user.login), + access_token: ActiveValue::Set(token.secret().to_string()), + avatar_url: ActiveValue::Set(provider_user.avatar_url), + email: ActiveValue::Set(provider_user.email), + name: ActiveValue::Set(provider_user.name), + read_repo: ActiveValue::Set(read_repo), + } + .insert(&state.db.conn) + .await?; + } + } + + return Ok(Redirect::temporary(query.next.as_deref().unwrap_or("/")).into_response()); + } + + let user = match entities::oauth_connection::Entity::find() + .filter(entities::oauth_connection::Column::Provider.eq(query.provider.to_string())) + .filter(entities::oauth_connection::Column::ProviderId.eq(provider_user.id)) + .filter(entities::oauth_connection::Column::DeletedAt.is_null()) .one(&state.db.conn) .await? { - Some(user) => { - entities::user::ActiveModel { - id: ActiveValue::Set(user.id), + Some(conn) => { + let conn = entities::oauth_connection::ActiveModel { + id: ActiveValue::Set(conn.id), provider_login: ActiveValue::Set(provider_user.login), access_token: ActiveValue::Set(token.secret().to_string()), + avatar_url: ActiveValue::Set(provider_user.avatar_url.clone()), + email: ActiveValue::Set(provider_user.email.clone()), + name: ActiveValue::Set(provider_user.name.clone()), + ..Default::default() + } + .update(&state.db.conn) + .await?; + + entities::user::ActiveModel { + id: ActiveValue::Set(conn.user_id), avatar_url: ActiveValue::Set(provider_user.avatar_url), email: ActiveValue::Set(provider_user.email), name: ActiveValue::Set(provider_user.name), diff --git a/lapdev-api/src/state.rs b/lapdev-api/src/state.rs index 939697e..e4175fa 100644 --- a/lapdev-api/src/state.rs +++ b/lapdev-api/src/state.rs @@ -31,6 +31,7 @@ use crate::{ auth::{Auth, AuthConfig}, cert::{load_cert, CertStore}, github::GithubClient, + session::OAUTH_STATE_COOKIE, }; pub const TOKEN_COOKIE_NAME: &str = "token"; @@ -157,8 +158,8 @@ impl CoreState { Ok(()) } - pub fn token(&self, cookie: &headers::Cookie) -> Result { - let token = cookie.get(TOKEN_COOKIE_NAME).ok_or(ApiError::NoAuthToken)?; + fn cookie_token(&self, cookie: &headers::Cookie, name: &str) -> Result { + let token = cookie.get(name).ok_or(ApiError::NoAuthToken)?; let untrusted_token = UntrustedToken::try_from(token).map_err(|_| ApiError::InvalidAuthToken)?; let token = pasetors::local::decrypt( @@ -172,6 +173,14 @@ impl CoreState { Ok(token) } + pub fn auth_state_token(&self, cookie: &headers::Cookie) -> Result { + self.cookie_token(cookie, OAUTH_STATE_COOKIE) + } + + pub fn token(&self, cookie: &headers::Cookie) -> Result { + self.cookie_token(cookie, TOKEN_COOKIE_NAME) + } + pub async fn require_enterprise(&self) -> Result<(), ApiError> { if self.conductor.enterprise.has_valid_license().await { return Ok(()); diff --git a/lapdev-common/src/console.rs b/lapdev-common/src/console.rs index deb4056..c6fef5b 100644 --- a/lapdev-common/src/console.rs +++ b/lapdev-common/src/console.rs @@ -7,7 +7,6 @@ use crate::UserRole; #[derive(Serialize, Deserialize)] pub struct NewSessionResponse { pub url: String, - pub state: String, } #[derive(Serialize, Deserialize, Clone, Debug)] @@ -23,7 +22,6 @@ pub struct Organization { #[derive(Serialize, Deserialize, Clone, Debug)] pub struct MeUser { - pub login: String, pub avatar_url: Option, pub email: Option, pub name: Option, diff --git a/lapdev-common/src/lib.rs b/lapdev-common/src/lib.rs index 7839e6e..81415e0 100644 --- a/lapdev-common/src/lib.rs +++ b/lapdev-common/src/lib.rs @@ -242,7 +242,16 @@ pub enum WorkspaceHostStatus { } #[derive( - Hash, EnumString, strum_macros::Display, PartialEq, Eq, Debug, Deserialize, Serialize, Clone, + Hash, + EnumString, + strum_macros::Display, + PartialEq, + Eq, + Debug, + Deserialize, + Serialize, + Clone, + Copy, )] pub enum AuthProvider { Github, @@ -277,15 +286,6 @@ pub enum WorkspaceStatus { Deleted, } -impl WorkspaceStatus { - pub const RUNNING: &'static [WorkspaceStatus] = &[ - WorkspaceStatus::New, - WorkspaceStatus::PrebuildBuilding, - WorkspaceStatus::Building, - WorkspaceStatus::Running, - ]; -} - #[derive(Serialize, Deserialize, Debug, Clone, Hash, Eq, PartialEq)] pub struct WorkspaceInfo { pub name: String, @@ -754,3 +754,15 @@ pub struct ClusterUser { pub struct UpdateClusterUser { pub cluster_admin: bool, } + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +pub struct GitProvider { + pub auth_provider: AuthProvider, + pub avatar_url: Option, + pub name: Option, + pub email: Option, + pub connected: bool, + pub read_repo: Option, + pub scopes: Vec, + pub all_scopes: Vec, +} diff --git a/lapdev-conductor/src/server.rs b/lapdev-conductor/src/server.rs index 7ae16bb..62a1e64 100644 --- a/lapdev-conductor/src/server.rs +++ b/lapdev-conductor/src/server.rs @@ -12,12 +12,12 @@ use data_encoding::BASE64_MIME; use futures::{channel::mpsc::UnboundedReceiver, stream::AbortHandle, SinkExt, StreamExt}; use git2::{Cred, FetchOptions, FetchPrune, RemoteCallbacks, Repository}; use lapdev_common::{ - utils::rand_string, AuditAction, AuditResourceKind, BuildTarget, CreateWorkspaceRequest, - DeleteWorkspaceRequest, GitBranch, NewProject, NewProjectResponse, NewWorkspace, - NewWorkspaceResponse, PrebuildInfo, PrebuildStatus, PrebuildUpdateEvent, RepoBuildInfo, - RepoBuildOutput, RepoContent, RepoContentPosition, RepoSource, StartWorkspaceRequest, - StopWorkspaceRequest, UsageResourceKind, WorkspaceStatus, WorkspaceUpdateEvent, - LAPDEV_DEFAULT_OSUSER, + utils::rand_string, AuditAction, AuditResourceKind, AuthProvider, BuildTarget, + CreateWorkspaceRequest, DeleteWorkspaceRequest, GitBranch, NewProject, NewProjectResponse, + NewWorkspace, NewWorkspaceResponse, PrebuildInfo, PrebuildStatus, PrebuildUpdateEvent, + RepoBuildInfo, RepoBuildOutput, RepoContent, RepoContentPosition, RepoSource, + StartWorkspaceRequest, StopWorkspaceRequest, UsageResourceKind, WorkspaceStatus, + WorkspaceUpdateEvent, LAPDEV_DEFAULT_OSUSER, }; use lapdev_common::{PrebuildReplicaStatus, WorkspaceHostStatus}; use lapdev_db::{api::DbApi, entities}; @@ -626,6 +626,35 @@ impl Conductor { }) } + pub async fn find_match_oauth_for_repo( + &self, + user: &entities::user::Model, + repo: &str, + ) -> Result { + let oauths = self.db.get_user_all_oauth(user.id).await?; + + let repo = repo.to_lowercase(); + + if let Some(oauth) = if repo.contains("github.com") { + oauths + .iter() + .find(|o| o.provider == AuthProvider::Github.to_string()) + } else if repo.contains("gitlab.com") { + oauths + .iter() + .find(|o| o.provider == AuthProvider::Gitlab.to_string()) + } else { + None + } { + return Ok(oauth.clone()); + } + + let oauth = oauths + .first() + .ok_or_else(|| anyhow!("user doesn't have any oauth connections"))?; + Ok(oauth.to_owned()) + } + pub async fn create_project( &self, user: entities::user::Model, @@ -635,11 +664,14 @@ impl Conductor { user_agent: Option, ) -> Result { let repo = self.format_repo_url(&project.repo); + + let oauth = self.find_match_oauth_for_repo(&user, &repo).await?; + let repo = self .get_raw_repo_details( &repo, None, - (user.provider_login.clone(), user.access_token.clone()), + (oauth.provider_login.clone(), oauth.access_token.clone()), ) .await?; @@ -658,12 +690,14 @@ impl Conductor { id: ActiveValue::Set(id), name: ActiveValue::Set(repo.name.clone()), created_at: ActiveValue::Set(now.into()), + deleted_at: ActiveValue::Set(None), repo_url: ActiveValue::Set(repo.url.clone()), repo_name: ActiveValue::Set(repo.name.clone()), organization_id: ActiveValue::Set(org_id), created_by: ActiveValue::Set(user.id), + oauth_id: ActiveValue::Set(oauth.id), machine_type_id: ActiveValue::Set(project.machine_type_id), - ..Default::default() + env: ActiveValue::Set(None), }; let project = project.insert(&txn).await?; @@ -1078,6 +1112,8 @@ impl Conductor { let (id_rsa, public_key) = self.generate_key_pair()?; let osuser = self.get_osuser(user).await; + self.enterprise.check_organization_limit(org).await?; + let txn = self.db.conn.begin().await?; if let Some(quota) = self .enterprise @@ -1502,10 +1538,13 @@ impl Conductor { ip: Option, user_agent: Option, ) -> Result { - let auth = if let Ok(Some(user)) = self.db.get_user(project.created_by).await { - (user.provider_login, user.access_token) + let auth = if let Ok(Some(oauth)) = self.db.get_oauth(project.oauth_id).await { + (oauth.provider_login, oauth.access_token) } else { - (user.provider_login.clone(), user.access_token.clone()) + let oauth = self + .find_match_oauth_for_repo(user, &project.repo_url) + .await?; + (oauth.provider_login.clone(), oauth.access_token.clone()) }; let branches = self .project_branches(user.id, project, auth.clone(), ip, user_agent) @@ -1557,11 +1596,12 @@ impl Conductor { if let Some(project) = project { project } else { + let oauth = self.find_match_oauth_for_repo(user, &repo).await?; return self .get_raw_repo_details( &repo, branch, - (user.provider_login.clone(), user.access_token.clone()), + (oauth.provider_login.clone(), oauth.access_token.clone()), ) .await; } @@ -2425,6 +2465,9 @@ impl Conductor { )); } + let org = self.db.get_organization(workspace.organization_id).await?; + self.enterprise.check_organization_limit(&org).await?; + let txn = self.db.conn.begin().await?; if let Some(quota) = self .enterprise diff --git a/lapdev-dashboard/main.css b/lapdev-dashboard/main.css index a6be362..76b9cd8 100644 --- a/lapdev-dashboard/main.css +++ b/lapdev-dashboard/main.css @@ -1063,6 +1063,11 @@ input[type="range"]::-ms-fill-lower { margin-bottom: 1rem; } +.my-8 { + margin-top: 2rem; + margin-bottom: 2rem; +} + .-mb-px { margin-bottom: -1px; } @@ -1498,6 +1503,10 @@ input[type="range"]::-ms-fill-lower { row-gap: 0.75rem; } +.gap-y-6 { + row-gap: 1.5rem; +} + .gap-y-8 { row-gap: 2rem; } @@ -1693,11 +1702,6 @@ input[type="range"]::-ms-fill-lower { background-color: rgb(249 250 251 / var(--tw-bg-opacity)); } -.bg-gray-800 { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity)); -} - .bg-gray-900\/50 { background-color: rgb(17 24 39 / 0.5); } @@ -1915,6 +1919,10 @@ input[type="range"]::-ms-fill-lower { padding-bottom: 1rem; } +.pb-8 { + padding-bottom: 2rem; +} + .pl-10 { padding-left: 2.5rem; } @@ -2665,11 +2673,6 @@ input[type="range"]::-ms-fill-lower { .md\:p-5 { padding: 1.25rem; } - - .md\:text-2xl { - font-size: 1.5rem; - line-height: 2rem; - } } @media (min-width: 1024px) { diff --git a/lapdev-dashboard/src/app.rs b/lapdev-dashboard/src/app.rs index a24db7a..241b313 100644 --- a/lapdev-dashboard/src/app.rs +++ b/lapdev-dashboard/src/app.rs @@ -11,6 +11,7 @@ use crate::{ account::{get_login, AccountSettings, JoinView, Login}, audit_log::AuditLogView, cluster::{ClusterSettings, ClusterUsersView, MachineTypeView, WorkspaceHostView}, + git_provider::GitProviderView, license::{LicenseView, SignLicenseView}, nav::{AdminSideNav, NavExpanded, SideNav, TopNav}, organization::{NewOrgModal, OrgMembers, OrgSettings}, @@ -91,6 +92,7 @@ pub fn App() -> impl IntoView { } /> } /> } /> + } /> } /> } /> } /> diff --git a/lapdev-dashboard/src/git_provider.rs b/lapdev-dashboard/src/git_provider.rs new file mode 100644 index 0000000..8473084 --- /dev/null +++ b/lapdev-dashboard/src/git_provider.rs @@ -0,0 +1,391 @@ +use std::time::Duration; + +use anyhow::Result; +use gloo_net::http::Request; +use lapdev_common::{console::NewSessionResponse, AuthProvider, GitProvider}; +use leptos::{ + component, create_action, create_local_resource, create_rw_signal, document, + event_target_checked, set_timeout, view, window, For, IntoView, RwSignal, Signal, SignalGet, + SignalGetUntracked, SignalSet, SignalUpdate, +}; +use leptos_router::use_location; +use wasm_bindgen::{JsCast, UnwrapThrowExt}; +use web_sys::FocusEvent; + +use crate::modal::{CreationModal, ErrorResponse}; + +async fn get_git_providers() -> Result> { + let resp = Request::get("/api/v1/account/git_providers").send().await?; + let machine_types: Vec = resp.json().await?; + Ok(machine_types) +} + +async fn connect_oauth( + provider: AuthProvider, + update_read_repo: Option, +) -> Result<(), ErrorResponse> { + let location = use_location(); + let next = format!( + "{}{}", + location.pathname.get_untracked(), + location.search.get_untracked() + ); + let location = window().window().location(); + let url = if update_read_repo.is_some() { + "/api/v1/account/git_providers/update_scope" + } else { + "/api/v1/account/git_providers/connect" + }; + + let read_repo = if update_read_repo.unwrap_or(false) { + "yes".to_string() + } else { + "no".to_string() + }; + + let resp = Request::put(url) + .query([ + ("provider", &provider.to_string()), + ("next", &next), + ("host", &location.origin().unwrap_or_default()), + ("read_repo", &read_repo), + ]) + .send() + .await?; + + if resp.status() != 200 { + let error = resp + .json::() + .await + .unwrap_or_else(|_| ErrorResponse { + error: "Internal Server Error".to_string(), + }); + return Err(error); + } + + let resp: NewSessionResponse = resp.json().await?; + let _ = window().location().set_href(&resp.url); + + Ok(()) +} + +async fn disconnect_oauth(provider: AuthProvider) -> Result<(), ErrorResponse> { + let location = use_location(); + let next = format!( + "{}{}{}", + window().location().origin().unwrap_or_default(), + location.pathname.get_untracked(), + location.search.get_untracked(), + ); + + let resp = Request::put("/api/v1/account/git_providers/disconnect") + .query([("provider", &provider.to_string()), ("next", &next)]) + .send() + .await?; + + if resp.status() != 200 { + let error = resp + .json::() + .await + .unwrap_or_else(|_| ErrorResponse { + error: "Internal Server Error".to_string(), + }); + return Err(error); + } + + Ok(()) +} + +#[component] +pub fn GitProviderView() -> impl IntoView { + let error = create_rw_signal(None); + + let git_provider_counter = create_rw_signal(0); + let git_providers = create_local_resource( + move || git_provider_counter.get(), + |_| async move { get_git_providers().await.unwrap_or_default() }, + ); + let git_providers = Signal::derive(move || git_providers.get().unwrap_or_default()); + + view! { +
+
+
+
+
+ Git Providers +
+

{"Connect to a git provider or manage your permissions of a git provider."}

+
+
+
+ + { move || if let Some(error) = error.get() { + view! { +
+ { error } +
+ }.into_view() + } else { + view!{}.into_view() + }} + +
+ + } + } + /> +
+ +
+ } +} + +#[component] +fn GitProviderControl( + provider: AuthProvider, + error: RwSignal>, + git_provider_counter: RwSignal, + update_scope_modal_hidden: RwSignal, +) -> impl IntoView { + let dropdown_hidden = create_rw_signal(true); + let toggle_dropdown = move |_| { + if dropdown_hidden.get_untracked() { + dropdown_hidden.set(false); + } else { + dropdown_hidden.set(true); + } + }; + + let connect_action = create_action(move |_| async move { + if let Err(e) = connect_oauth(provider, None).await { + error.set(Some(e.error)); + } + }); + let connect = move |_| { + dropdown_hidden.set(true); + error.set(None); + connect_action.dispatch(()); + }; + + let disconnect_action = create_action(move |_| async move { + if let Err(e) = disconnect_oauth(provider).await { + error.set(Some(e.error)); + } else { + git_provider_counter.update(|c| { + *c += 1; + }); + } + }); + let disconnect = move |_| { + dropdown_hidden.set(true); + error.set(None); + disconnect_action.dispatch(()); + }; + + let on_focusout = move |e: FocusEvent| { + let node = e + .current_target() + .unwrap_throw() + .unchecked_into::(); + + set_timeout( + move || { + let has_focus = if let Some(active) = document().active_element() { + let active: web_sys::Node = active.into(); + node.contains(Some(&active)) + } else { + false + }; + if !has_focus && !dropdown_hidden.get_untracked() { + dropdown_hidden.set(true); + } + }, + Duration::from_secs(0), + ); + }; + + view! { +
+ // +
+ + +
+
+ } +} + +#[component] +fn GitProviderItem( + provider: GitProvider, + error: RwSignal>, + git_provider_counter: RwSignal, +) -> impl IntoView { + let update_scope_modal_hidden = create_rw_signal(true); + + let icon = match provider.auth_provider { + AuthProvider::Github => view! { + + }, + AuthProvider::Gitlab => view! { + + }, + }; + + view! { +
+
+
+ {icon} +

{provider.auth_provider.to_string()}

+
+
+
+
+ user photo +
+

{provider.name}

+

{provider.email}

+
+
+

+ Not Connected +

+
+
+
+
+

+ Permissions +

+

+ { if provider.read_repo == Some(true) { provider.all_scopes.join(", ") } else { provider.scopes.join(", ") } } +

+
+
+
+ +
+
+
+ + } +} + +#[component] +pub fn UpdateScopeModal( + provider: AuthProvider, + read_repo: bool, + update_scope_modal_hidden: RwSignal, +) -> impl IntoView { + let read_repo = create_rw_signal(read_repo); + + let body = view! { +
+ Turn this option on if you want to open private repo in Lapdev +
+
+
+ +
+ +
+ }; + + let action = create_action(move |_| connect_oauth(provider, Some(read_repo.get_untracked()))); + view! { + + } +} diff --git a/lapdev-dashboard/src/lib.rs b/lapdev-dashboard/src/lib.rs index 866454c..e63534c 100644 --- a/lapdev-dashboard/src/lib.rs +++ b/lapdev-dashboard/src/lib.rs @@ -3,6 +3,7 @@ pub mod app; pub mod audit_log; pub mod cluster; pub mod datepicker; +pub mod git_provider; pub mod license; pub mod modal; pub mod nav; diff --git a/lapdev-dashboard/src/nav.rs b/lapdev-dashboard/src/nav.rs index 47cc276..a4c2c0c 100644 --- a/lapdev-dashboard/src/nav.rs +++ b/lapdev-dashboard/src/nav.rs @@ -206,6 +206,14 @@ pub fn SideNavAccount() -> impl IntoView { SSH Keys + +
  • + + Git Providers + +
  • } } diff --git a/lapdev-db/src/api.rs b/lapdev-db/src/api.rs index 5e602f5..e5d5a6e 100644 --- a/lapdev-db/src/api.rs +++ b/lapdev-db/src/api.rs @@ -24,7 +24,8 @@ use super::entities::workspace; pub const LAPDEV_CLUSTER_NOT_INITIATED: &str = "lapdev-cluster-not-initiated"; const LAPDEV_API_AUTH_TOKEN_KEY: &str = "lapdev-api-auth-token-key"; -const LAPDEV_OAUTH_NO_READ_REPO: &str = "lapdev-oauth-no-read-repo"; +const LAPDEV_DEFAULT_USAGE_LIMIT: &str = "lapdev-default-org-usage-limit"; +const LAPDEV_DEFAULT_RUNNING_WORKSPACE_LIMIT: &str = "lapdev-default-org-running-workspace-limit"; #[derive(Clone)] pub struct DbApi { @@ -100,12 +101,6 @@ impl DbApi { self.generate_api_auth_token_key().await } - pub async fn oauth_no_read_repo(&self) -> Result { - self.get_config(LAPDEV_OAUTH_NO_READ_REPO) - .await - .map(|v| v == "yes") - } - async fn get_api_auth_token_key(&self) -> Result> { let key = self.get_config(LAPDEV_API_AUTH_TOKEN_KEY).await?; let key = STANDARD.decode(key)?; @@ -137,6 +132,15 @@ impl DbApi { Ok(model.value) } + async fn get_config_in_txn(&self, txn: &DatabaseTransaction, name: &str) -> Result { + let model = entities::config::Entity::find() + .filter(entities::config::Column::Name.eq(name)) + .one(txn) + .await? + .ok_or_else(|| anyhow!("no config found"))?; + Ok(model.value) + } + pub async fn get_base_hostname(&self) -> Result { self.get_config(LAPDEV_BASE_HOSTNAME).await } @@ -254,6 +258,42 @@ impl DbApi { Ok(models) } + pub async fn create_new_organization( + &self, + txn: &DatabaseTransaction, + name: String, + ) -> Result { + let default_usage_limit = self + .get_config_in_txn(txn, LAPDEV_DEFAULT_USAGE_LIMIT) + .await + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(0); + let default_running_workspace_limit = self + .get_config_in_txn(txn, LAPDEV_DEFAULT_RUNNING_WORKSPACE_LIMIT) + .await + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(0); + + let org = entities::organization::ActiveModel { + id: ActiveValue::Set(Uuid::new_v4()), + deleted_at: ActiveValue::Set(None), + name: ActiveValue::Set(name.to_string()), + auto_start: ActiveValue::Set(true), + allow_workspace_change_auto_start: ActiveValue::Set(true), + auto_stop: ActiveValue::Set(Some(3600)), + allow_workspace_change_auto_stop: ActiveValue::Set(true), + last_auto_stop_check: ActiveValue::Set(None), + usage_limit: ActiveValue::Set(default_usage_limit), + running_workspace_limit: ActiveValue::Set(default_running_workspace_limit), + } + .insert(txn) + .await?; + + Ok(org) + } + pub async fn create_new_user( &self, txn: &DatabaseTransaction, @@ -274,23 +314,30 @@ impl DbApi { }; let now = Utc::now(); - let org = entities::organization::ActiveModel { - id: ActiveValue::Set(uuid::Uuid::new_v4()), - name: ActiveValue::Set("Personal".to_string()), - auto_start: ActiveValue::Set(true), - auto_stop: ActiveValue::Set(Some(3600)), - allow_workspace_change_auto_start: ActiveValue::Set(true), - allow_workspace_change_auto_stop: ActiveValue::Set(true), - ..Default::default() + let org = self + .create_new_organization(txn, "Personal".to_string()) + .await?; + + let user = entities::user::ActiveModel { + id: ActiveValue::Set(Uuid::new_v4()), + created_at: ActiveValue::Set(now.into()), + deleted_at: ActiveValue::Set(None), + provider: ActiveValue::Set(provider.to_string()), + osuser: ActiveValue::Set(format!("{provider}_{}", provider_user.login)), + avatar_url: ActiveValue::Set(provider_user.avatar_url.clone()), + email: ActiveValue::Set(provider_user.email.clone()), + name: ActiveValue::Set(provider_user.name.clone()), + current_organization: ActiveValue::Set(org.id), + cluster_admin: ActiveValue::Set(cluster_admin), } .insert(txn) .await?; - let user = entities::user::ActiveModel { + entities::oauth_connection::ActiveModel { id: ActiveValue::Set(Uuid::new_v4()), + user_id: ActiveValue::Set(user.id), created_at: ActiveValue::Set(now.into()), deleted_at: ActiveValue::Set(None), - osuser: ActiveValue::Set(format!("{provider}_{}", provider_user.login)), provider: ActiveValue::Set(provider.to_string()), provider_id: ActiveValue::Set(provider_user.id), provider_login: ActiveValue::Set(provider_user.login), @@ -298,8 +345,7 @@ impl DbApi { avatar_url: ActiveValue::Set(provider_user.avatar_url), email: ActiveValue::Set(provider_user.email), name: ActiveValue::Set(provider_user.name), - current_organization: ActiveValue::Set(org.id), - cluster_admin: ActiveValue::Set(cluster_admin), + read_repo: ActiveValue::Set(false), } .insert(txn) .await?; @@ -326,6 +372,41 @@ impl DbApi { Ok(model) } + pub async fn get_oauth(&self, id: Uuid) -> Result> { + let model = entities::oauth_connection::Entity::find() + .filter(entities::oauth_connection::Column::Id.eq(id)) + .filter(entities::oauth_connection::Column::DeletedAt.is_null()) + .one(&self.conn) + .await?; + Ok(model) + } + + pub async fn get_user_all_oauth( + &self, + user_id: Uuid, + ) -> Result> { + let model = entities::oauth_connection::Entity::find() + .filter(entities::oauth_connection::Column::UserId.eq(user_id)) + .filter(entities::oauth_connection::Column::DeletedAt.is_null()) + .all(&self.conn) + .await?; + Ok(model) + } + + pub async fn get_user_oauth( + &self, + user_id: Uuid, + provider_name: &str, + ) -> Result> { + let model = entities::oauth_connection::Entity::find() + .filter(entities::oauth_connection::Column::UserId.eq(user_id)) + .filter(entities::oauth_connection::Column::Provider.eq(provider_name)) + .filter(entities::oauth_connection::Column::DeletedAt.is_null()) + .one(&self.conn) + .await?; + Ok(model) + } + pub async fn get_user_organizations( &self, user_id: Uuid, diff --git a/lapdev-db/src/entities/mod.rs b/lapdev-db/src/entities/mod.rs index 6ecdef9..1900d05 100644 --- a/lapdev-db/src/entities/mod.rs +++ b/lapdev-db/src/entities/mod.rs @@ -5,6 +5,7 @@ pub mod prelude; pub mod audit_log; pub mod config; pub mod machine_type; +pub mod oauth_connection; pub mod organization; pub mod organization_member; pub mod prebuild; diff --git a/lapdev-db/src/entities/oauth_connection.rs b/lapdev-db/src/entities/oauth_connection.rs new file mode 100644 index 0000000..5df7e58 --- /dev/null +++ b/lapdev-db/src/entities/oauth_connection.rs @@ -0,0 +1,26 @@ +//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.4 + +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)] +#[sea_orm(table_name = "oauth_connection")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: Uuid, + pub user_id: Uuid, + pub created_at: DateTimeWithTimeZone, + pub deleted_at: Option, + pub provider: String, + pub provider_id: i32, + pub provider_login: String, + pub access_token: String, + pub avatar_url: Option, + pub email: Option, + pub name: Option, + pub read_repo: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/lapdev-db/src/entities/organization.rs b/lapdev-db/src/entities/organization.rs index b70598b..8a13c42 100644 --- a/lapdev-db/src/entities/organization.rs +++ b/lapdev-db/src/entities/organization.rs @@ -15,6 +15,7 @@ pub struct Model { pub allow_workspace_change_auto_stop: bool, pub last_auto_stop_check: Option, pub usage_limit: i64, + pub running_workspace_limit: i32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/lapdev-db/src/entities/project.rs b/lapdev-db/src/entities/project.rs index 0f0cef0..b173db9 100644 --- a/lapdev-db/src/entities/project.rs +++ b/lapdev-db/src/entities/project.rs @@ -12,6 +12,7 @@ pub struct Model { pub name: String, pub organization_id: Uuid, pub created_by: Uuid, + pub oauth_id: Uuid, pub repo_url: String, pub repo_name: String, pub machine_type_id: Uuid, diff --git a/lapdev-db/src/entities/user.rs b/lapdev-db/src/entities/user.rs index 6fd8ac1..ca06b0a 100644 --- a/lapdev-db/src/entities/user.rs +++ b/lapdev-db/src/entities/user.rs @@ -10,9 +10,6 @@ pub struct Model { pub created_at: DateTimeWithTimeZone, pub deleted_at: Option, pub provider: String, - pub provider_id: i32, - pub provider_login: String, - pub access_token: String, pub avatar_url: Option, pub email: Option, pub name: Option, diff --git a/lapdev-db/src/migration/m20231106_100019_create_user_table.rs b/lapdev-db/src/migration/m20231106_100019_create_user_table.rs index d949557..2656279 100644 --- a/lapdev-db/src/migration/m20231106_100019_create_user_table.rs +++ b/lapdev-db/src/migration/m20231106_100019_create_user_table.rs @@ -19,9 +19,6 @@ impl MigrationTrait for Migration { ) .col(ColumnDef::new(User::DeletedAt).timestamp_with_time_zone()) .col(ColumnDef::new(User::Provider).string().not_null()) - .col(ColumnDef::new(User::ProviderId).integer().not_null()) - .col(ColumnDef::new(User::ProviderLogin).string().not_null()) - .col(ColumnDef::new(User::AccessToken).string().not_null()) .col(ColumnDef::new(User::AvatarUrl).string()) .col(ColumnDef::new(User::Email).string()) .col(ColumnDef::new(User::Name).string()) @@ -32,20 +29,6 @@ impl MigrationTrait for Migration { ) .await?; - manager - .create_index( - Index::create() - .name("user_provider_provider_id_idx") - .table(User::Table) - .unique() - .nulls_not_distinct() - .col(User::Provider) - .col(User::ProviderId) - .col(User::DeletedAt) - .to_owned(), - ) - .await?; - Ok(()) } } @@ -57,9 +40,6 @@ pub enum User { CreatedAt, DeletedAt, Provider, - ProviderId, - ProviderLogin, - AccessToken, AvatarUrl, Email, Name, diff --git a/lapdev-db/src/migration/m20231109_171859_create_project_table.rs b/lapdev-db/src/migration/m20231109_171859_create_project_table.rs index 44cb41b..01abccd 100644 --- a/lapdev-db/src/migration/m20231109_171859_create_project_table.rs +++ b/lapdev-db/src/migration/m20231109_171859_create_project_table.rs @@ -23,6 +23,7 @@ impl MigrationTrait for Migration { .col(ColumnDef::new(Project::Name).string().not_null()) .col(ColumnDef::new(Project::OrganizationId).uuid().not_null()) .col(ColumnDef::new(Project::CreatedBy).uuid().not_null()) + .col(ColumnDef::new(Project::OauthId).uuid().not_null()) .col(ColumnDef::new(Project::RepoUrl).string().not_null()) .col(ColumnDef::new(Project::RepoName).string().not_null()) .col(ColumnDef::new(Project::MachineTypeId).uuid().not_null()) @@ -62,6 +63,7 @@ pub enum Project { Name, OrganizationId, CreatedBy, + OauthId, RepoUrl, RepoName, MachineTypeId, diff --git a/lapdev-db/src/migration/m20240823_165223_create_oauth_table.rs b/lapdev-db/src/migration/m20240823_165223_create_oauth_table.rs new file mode 100644 index 0000000..939bc3a --- /dev/null +++ b/lapdev-db/src/migration/m20240823_165223_create_oauth_table.rs @@ -0,0 +1,102 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(OauthConnection::Table) + .if_not_exists() + .col( + ColumnDef::new(OauthConnection::Id) + .uuid() + .not_null() + .primary_key(), + ) + .col(ColumnDef::new(OauthConnection::UserId).uuid().not_null()) + .col( + ColumnDef::new(OauthConnection::CreatedAt) + .timestamp_with_time_zone() + .not_null(), + ) + .col(ColumnDef::new(OauthConnection::DeletedAt).timestamp_with_time_zone()) + .col( + ColumnDef::new(OauthConnection::Provider) + .string() + .not_null(), + ) + .col( + ColumnDef::new(OauthConnection::ProviderId) + .integer() + .not_null(), + ) + .col( + ColumnDef::new(OauthConnection::ProviderLogin) + .string() + .not_null(), + ) + .col( + ColumnDef::new(OauthConnection::AccessToken) + .string() + .not_null(), + ) + .col(ColumnDef::new(OauthConnection::AvatarUrl).string()) + .col(ColumnDef::new(OauthConnection::Email).string()) + .col(ColumnDef::new(OauthConnection::Name).string()) + .col(ColumnDef::new(OauthConnection::ReadRepo).boolean()) + .to_owned(), + ) + .await?; + + manager + .create_index( + Index::create() + .name("oauth_connection_provider_provider_id_deleted_at_idx") + .table(OauthConnection::Table) + .unique() + .nulls_not_distinct() + .col(OauthConnection::Provider) + .col(OauthConnection::ProviderId) + .col(OauthConnection::DeletedAt) + .to_owned(), + ) + .await?; + + manager + .create_index( + Index::create() + .name("oauth_connection_user_id_provider_deleted_at_idx") + .table(OauthConnection::Table) + .unique() + .nulls_not_distinct() + .col(OauthConnection::UserId) + .col(OauthConnection::Provider) + .col(OauthConnection::DeletedAt) + .to_owned(), + ) + .await?; + + Ok(()) + } +} + +#[derive(DeriveIden)] +enum OauthConnection { + Table, + Id, + UserId, + CreatedAt, + DeletedAt, + Provider, + ProviderId, + ProviderLogin, + AccessToken, + AvatarUrl, + Email, + Name, + ReadRepo, +} diff --git a/lapdev-db/src/migration/mod.rs b/lapdev-db/src/migration/mod.rs index de3e989..375bc97 100644 --- a/lapdev-db/src/migration/mod.rs +++ b/lapdev-db/src/migration/mod.rs @@ -17,6 +17,7 @@ pub mod m20240228_141013_create_user_invitation_table; pub mod m20240311_220708_create_prebuild_replica_table; pub mod m20240312_175753_create_table_update_trigger; pub mod m20240316_194115_create_workspace_port_table; +pub mod m20240823_165223_create_oauth_table; pub struct Migrator; @@ -41,6 +42,7 @@ impl MigratorTrait for Migrator { Box::new(m20240311_220708_create_prebuild_replica_table::Migration), Box::new(m20240312_175753_create_table_update_trigger::Migration), Box::new(m20240316_194115_create_workspace_port_table::Migration), + Box::new(m20240823_165223_create_oauth_table::Migration), ] } } diff --git a/lapdev-enterprise/src/enterprise.rs b/lapdev-enterprise/src/enterprise.rs index 299ac05..dbb7a89 100644 --- a/lapdev-enterprise/src/enterprise.rs +++ b/lapdev-enterprise/src/enterprise.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; use anyhow::Result; -use chrono::{DateTime, FixedOffset}; +use chrono::{DateTime, FixedOffset, Utc}; use lapdev_common::{AuditLogRecord, LAPDEV_BASE_HOSTNAME}; use lapdev_common::{AuditLogResult, QuotaKind, QuotaResult}; use lapdev_db::{api::DbApi, entities}; +use lapdev_rpc::error::ApiError; use sea_orm::{ ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseTransaction, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, @@ -14,6 +15,9 @@ use uuid::Uuid; use crate::usage::Usage; use crate::{auto_start_stop::AutoStartStop, license::License, quota::Quota}; +const LAPDEV_USAGE_LIMIT_ERROR: &str = "lapdev-usage-limit-error"; +const LAPDEV_RUNNING_WORKSPACE_LIMIT_ERROR: &str = "lapdev-running-workspace-limit-error"; + pub struct Enterprise { pub quota: Quota, pub auto_start_stop: AutoStartStop, @@ -41,6 +45,49 @@ impl Enterprise { self.license.has_valid().await } + pub async fn check_organization_limit( + &self, + organization: &entities::organization::Model, + ) -> Result<(), ApiError> { + if !self.license.has_valid().await { + return Ok(()); + } + + if organization.running_workspace_limit > 0 { + let count = self + .quota + .get_organization_existing(&QuotaKind::RunningWorkspace, organization.id) + .await?; + if count as i32 >= organization.running_workspace_limit { + return Err(ApiError::InvalidRequest( + self.db + .get_config(LAPDEV_RUNNING_WORKSPACE_LIMIT_ERROR) + .await + .unwrap_or_else(|_| { + "You have reached the running workspace limit".to_string() + }), + )); + } + } + + if organization.usage_limit > 0 { + let usage = self + .usage + .get_monthly_cost(organization.id, None, Utc::now().into(), None) + .await?; + if usage as i64 >= organization.usage_limit { + return Err(ApiError::InvalidRequest( + self.db + .get_config(LAPDEV_USAGE_LIMIT_ERROR) + .await + .unwrap_or_else(|_| "You have reached the usage limit".to_string()), + )); + } + } + + Ok(()) + } + pub async fn check_create_workspace_quota( &self, txn: &DatabaseTransaction, diff --git a/lapdev-enterprise/src/quota.rs b/lapdev-enterprise/src/quota.rs index 694a9ea..285b95f 100644 --- a/lapdev-enterprise/src/quota.rs +++ b/lapdev-enterprise/src/quota.rs @@ -209,8 +209,9 @@ impl Quota { .filter(entities::workspace::Column::UserId.eq(user)) .filter( entities::workspace::Column::Status - .is_in(WorkspaceStatus::RUNNING.iter().map(|s| s.to_string())), + .ne(WorkspaceStatus::Stopped.to_string()), ) + .filter(entities::workspace::Column::ComposeParent.is_null()) .count(&self.db.conn) .await? as usize } @@ -229,7 +230,7 @@ impl Quota { Ok(existing) } - async fn get_organization_existing( + pub async fn get_organization_existing( &self, kind: &QuotaKind, organization: Uuid, @@ -248,8 +249,9 @@ impl Quota { .filter(entities::workspace::Column::OrganizationId.eq(organization)) .filter( entities::workspace::Column::Status - .is_in(WorkspaceStatus::RUNNING.iter().map(|s| s.to_string())), + .ne(WorkspaceStatus::Stopped.to_string()), ) + .filter(entities::workspace::Column::ComposeParent.is_null()) .count(&self.db.conn) .await? as usize }