From 556c8dda6f7e6bcfc8854ea68c1a9300ba52ee78 Mon Sep 17 00:00:00 2001 From: boxbeam Date: Fri, 23 Feb 2024 15:44:14 -0500 Subject: [PATCH 01/29] ci: remove manylinux2014 rocm build since it doesn't work --- .github/workflows/release.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c31a4b721a7a..072d040f16e7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -27,7 +27,7 @@ jobs: container: ${{ matrix.container }} strategy: matrix: - binary: [aarch64-apple-darwin, x86_64-manylinux2014, x86_64-windows-msvc, x86_64-manylinux2014-cuda117, x86_64-manylinux2014-cuda122, x86_64-windows-msvc-cuda117, x86_64-windows-msvc-cuda122, x86_64-manylinux2014-rocm57] + binary: [aarch64-apple-darwin, x86_64-manylinux2014, x86_64-windows-msvc, x86_64-manylinux2014-cuda117, x86_64-manylinux2014-cuda122, x86_64-windows-msvc-cuda117, x86_64-windows-msvc-cuda122] include: - os: macos-latest target: aarch64-apple-darwin @@ -63,11 +63,6 @@ jobs: ext: .exe build_args: --features cuda,prod-db windows_cuda: '12.2.0' - - os: dimerun-k3-ubuntu2204 - target: x86_64-unknown-linux-gnu - binary: x86_64-manylinux2014-rocm57 - container: ghcr.io/cromefire/hipblas-manylinux/2014/5.7:latest - build_args: --features static-ssl,rocm,prod-db env: SCCACHE_GHA_ENABLED: true From debd70e861c5c49eae8646cfc5e485b9429cee64 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 23 Feb 2024 14:08:41 -0800 Subject: [PATCH 02/29] feat(webserver): add license check (#1495) * feat(webserver): add license check * prototype * fix * update * update * update * [autofix.ci] apply automated fixes * update * add license key * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * Update ee/tabby-webserver/src/service/auth.rs --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-webserver/keys/license.key.pub | 24 +++++----- ee/tabby-webserver/src/schema/license.rs | 32 ++++++++++++- ee/tabby-webserver/src/schema/mod.rs | 40 ++++++++++++---- ee/tabby-webserver/src/service/auth.rs | 56 ++++++++++++++++++++++- ee/tabby-webserver/src/service/dao.rs | 2 +- ee/tabby-webserver/src/service/license.rs | 6 +-- ee/tabby-webserver/src/service/mod.rs | 34 ++++++++------ ee/tabby-webserver/src/service/worker.rs | 16 ++----- 8 files changed, 156 insertions(+), 54 deletions(-) diff --git a/ee/tabby-webserver/keys/license.key.pub b/ee/tabby-webserver/keys/license.key.pub index e009af8ef665..2a44dc4c4cd5 100644 --- a/ee/tabby-webserver/keys/license.key.pub +++ b/ee/tabby-webserver/keys/license.key.pub @@ -1,14 +1,14 @@ -----BEGIN PUBLIC KEY----- -MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA4WKYjEErVACA8sNQ6gGL -+9KUatfl+nJ74xR/+2ayrk4bQtVWgGvm4cGuc2V60aJ11BdXOEcyt95mO8n8+FRe -Y/fPkW22QDyqG7PXAt5bT4zuLXrrwvCrhB6QRWScRUaZv3jzCoclBu2fOxJxqbJo -Xx0pkXFct5viT3yfqv/+C5QQ7gPexUPEXqYRQuU4hqeVXtkhkfRA0DTWtOXnf0mU -4rJztxvkQiSAI8nufX01h73FrICntEaFGvLQnLR0VGVjlACZEmA3Nldvoq+Yt2zR -mRREXv3ks+PTnaAORoYnrnB+PoVMw9SkGUzA61CqvJoxKrbZfYmODglTlJh91UiF -lH7DXd7GK6iwHtO6dumAVaiIYqfPpJn0PExaqjtXHzKCRozbLPwIF+ECbq5vxjBq -hfWO/uqhiOqusRCoA4E8UHu8BmRC2s4Kn3QI/qOHKovCq72Hy0YL3/trYtJfJ3cz -sVyytP8tmoG3CGLjM80aaXpvpr87GCN07uKJmNgr00EQvxEK8CEMO7EbkNq7AVCY -awBC5tTDt7UzKyam8c97LuNUyWsI2H9FattHHDcRzA+HBTZ8FyZyI+m7zd2WdITF -tzpxm3mWo6MeQH3hq9prDiXwbyowXK/U0ZLK1s/WhFU5dxCnkTpZI3X5gLyGsEPD -FI05GFZrbWnOWCLAP0FKZAMCAwEAAQ== +MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAsmCQRK2UeAoIt7qnQlID +Aa8ih5fCRAhNAEFaWb9hHgg80q9M6F/0yPnhKxt78+3Lxz1jwZ5aZYaSxnZEwAx9 +3X1Zs7x+Our6VJmG1WDGqMkLWcSZjf8fsbH7TTMXLTBnU/nIkFKzNZInccu8CxsH +Im6Qxr81VFFHmZG9dsUQb2/fuA5Ck/UNvbbSipo+qfs5vJ6UP2CghkhUREZhK5Yb +s4wJ3AhUZ+uKzxm+73bjFTnYZ32IKR2h5WfroLK9wOuJ3hXuyK96Ka5DaUb3i3Ni +aFQiSJ2E5W8tjQYWEBth484v7kBjoamTbVP/jrcUUfmE2TAeEi3L/AU8QRlAJ575 +1Ujny4K+8PzzG4tJYjqXSI3BijaeUNwxFcL07yStj6xH9Fz/pLKV0RtEtPgNtXwq +6WlcYF1GqlNFinkTH+pFi9vFydxJ2N3HayR6dDq1r17Lf9HPmezMnRXde01wwVQD +Or77KGuIm6sH3cusW1IR+X2ZCT3UAQ1FRqpL9TMWY/LLl19Q1w3F/RYP/ZRvFsVl +aDTWyOqlG1XVU7uAR74gYirr73Rv/8pZ2453ZaVYjL73ZAM8X85Kh4xRstD6SjKG +A4WRGjiVsTSGRxk81wXNPLu3fnnj72gUkEgEWG+9odQSIrHtZSVCMxfL1Rc0MfK/ +rVIU8J/jzVyNYUKaVrrICt0CAwEAAQ== -----END PUBLIC KEY----- diff --git a/ee/tabby-webserver/src/schema/license.rs b/ee/tabby-webserver/src/schema/license.rs index 2f5e73d37d61..0c54cdef0ced 100644 --- a/ee/tabby-webserver/src/schema/license.rs +++ b/ee/tabby-webserver/src/schema/license.rs @@ -1,3 +1,5 @@ +use std::error::Error; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use juniper::{GraphQLEnum, GraphQLObject}; @@ -11,7 +13,7 @@ pub enum LicenseType { Team, } -#[derive(GraphQLEnum, PartialEq, Debug)] +#[derive(GraphQLEnum, PartialEq, Debug, Clone)] pub enum LicenseStatus { Ok, Expired, @@ -33,3 +35,31 @@ pub trait LicenseService: Send + Sync { async fn read_license(&self) -> Result>; async fn update_license(&self, license: String) -> Result<()>; } + +pub trait IsLicenseValid { + fn is_license_valid(&self) -> bool; +} + +impl IsLicenseValid for LicenseInfo { + fn is_license_valid(&self) -> bool { + self.status == LicenseStatus::Ok + } +} + +impl IsLicenseValid for Option { + fn is_license_valid(&self) -> bool { + self.as_ref() + .map(|x| x.is_license_valid()) + .unwrap_or_default() + } +} + +impl IsLicenseValid for std::result::Result { + fn is_license_valid(&self) -> bool { + if let Ok(x) = self { + x.is_license_valid() + } else { + false + } + } +} diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index f7ff602d7d67..feb975e267cb 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -27,18 +27,17 @@ use validator::{Validate, ValidationErrors}; use worker::{Worker, WorkerService}; use self::{ - auth::{PasswordResetInput, RequestPasswordResetEmailInput, UpdateOAuthCredentialInput}, + auth::{ + JWTPayload, OAuthCredential, OAuthProvider, PasswordResetInput, RequestInvitationInput, + RequestPasswordResetEmailInput, UpdateOAuthCredentialInput, + }, email::{EmailService, EmailSetting, EmailSettingInput}, - license::{LicenseInfo, LicenseService}, - repository::RepositoryService, + license::{LicenseInfo, LicenseService, LicenseStatus}, + repository::{Repository, RepositoryService}, setting::{ NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, }, }; -use crate::schema::{ - auth::{JWTPayload, OAuthCredential, OAuthProvider, RequestInvitationInput}, - repository::Repository, -}; pub trait ServiceLocator: Send + Sync { fn auth(&self) -> Arc; @@ -74,8 +73,8 @@ pub enum CoreError { #[error("{0}")] Forbidden(&'static str), - #[error("Invalid ID Error")] - InvalidIDError, + #[error("Invalid ID")] + InvalidID, #[error("Invalid input parameters")] InvalidInput(#[from] ValidationErrors), @@ -83,6 +82,9 @@ pub enum CoreError { #[error("Email is not configured")] EmailNotConfigured, + #[error("{0}")] + InvalidLicense(&'static str), + #[error(transparent)] Other(#[from] anyhow::Error), } @@ -118,6 +120,24 @@ fn check_admin(ctx: &Context) -> Result<(), CoreError> { Ok(()) } +async fn check_license(ctx: &Context) -> Result<(), CoreError> { + let Some(license) = ctx.locator.license().read_license().await? else { + return Err(CoreError::InvalidLicense( + "This feature requires enterprise license", + )); + }; + + match license.status { + LicenseStatus::Ok => Ok(()), + LicenseStatus::Expired => Err(CoreError::InvalidLicense( + "Your enterprise license is expired", + )), + LicenseStatus::SeatsExceeded => Err(CoreError::InvalidLicense( + "You have more active users than seats included in your license", + )), + } +} + #[derive(Default)] pub struct Query; @@ -465,6 +485,7 @@ impl Mutation { input: UpdateOAuthCredentialInput, ) -> Result { check_admin(ctx)?; + check_license(ctx).await?; input.validate()?; ctx.locator.auth().update_oauth_credential(input).await?; Ok(true) @@ -485,6 +506,7 @@ impl Mutation { async fn update_security_setting(ctx: &Context, input: SecuritySettingInput) -> Result { check_admin(ctx)?; + check_license(ctx).await?; input.validate()?; ctx.locator.setting().update_security_setting(input).await?; Ok(true) diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index bffa147cc155..514914a60c0c 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -24,6 +24,7 @@ use crate::{ UpdateOAuthCredentialInput, User, }, email::EmailService, + license::{IsLicenseValid, LicenseService}, setting::SettingService, CoreError, Result, }, @@ -33,13 +34,15 @@ use crate::{ struct AuthenticationServiceImpl { db: DbConn, mail: Arc, + license: Arc, } pub fn new_authentication_service( db: DbConn, mail: Arc, + license: Arc, ) -> impl AuthenticationService { - AuthenticationServiceImpl { db, mail } + AuthenticationServiceImpl { db, mail, license } } #[async_trait] @@ -233,6 +236,12 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn create_invitation(&self, email: String) -> Result { + if !self.license.read_license().await.is_license_valid() { + return Err(CoreError::InvalidLicense( + "This feature requires enterprise license", + )); + }; + let invitation = self.db.create_invitation(email.clone()).await?; let email_sent = self .mail @@ -467,11 +476,41 @@ fn password_verify(raw: &str, hash: &str) -> bool { #[cfg(test)] mod tests { + struct MockLicenseService(LicenseStatus); + + #[async_trait] + impl LicenseService for MockLicenseService { + async fn read_license(&self) -> Result> { + Ok(Some(LicenseInfo { + r#type: crate::schema::license::LicenseType::Team, + status: self.0.clone(), + seats: 1, + seats_used: 1, + issued_at: Utc::now(), + expires_at: Utc::now(), + })) + } + + async fn update_license(&self, _: String) -> Result<()> { + Ok(()) + } + } + async fn test_authentication_service() -> AuthenticationServiceImpl { let db = DbConn::new_in_memory().await.unwrap(); AuthenticationServiceImpl { db: db.clone(), mail: Arc::new(new_email_service(db).await.unwrap()), + license: Arc::new(MockLicenseService(LicenseStatus::Ok)), + } + } + + async fn test_authentication_service_without_valid_license() -> AuthenticationServiceImpl { + let db = DbConn::new_in_memory().await.unwrap(); + AuthenticationServiceImpl { + db: db.clone(), + mail: Arc::new(new_email_service(db).await.unwrap()), + license: Arc::new(MockLicenseService(LicenseStatus::Expired)), } } @@ -482,6 +521,7 @@ mod tests { let service = AuthenticationServiceImpl { db: db.clone(), mail: Arc::new(smtp.create_test_email_service(db).await), + license: Arc::new(MockLicenseService(LicenseStatus::Ok)), }; (service, smtp) } @@ -491,7 +531,10 @@ mod tests { use serial_test::serial; use super::*; - use crate::service::email::{new_email_service, testutils::TestEmailServer}; + use crate::{ + schema::license::{LicenseInfo, LicenseStatus}, + service::email::{new_email_service, testutils::TestEmailServer}, + }; #[test] fn test_password_hash() { @@ -976,4 +1019,13 @@ mod tests { assert!(service.allow_self_signup().await.unwrap()); } + + #[tokio::test] + async fn test_create_invitation_without_license() { + let service = test_authentication_service_without_valid_license().await; + assert_matches!( + service.create_invitation("abc.com".into()).await, + Err(CoreError::InvalidLicense(_)) + ) + } } diff --git a/ee/tabby-webserver/src/service/dao.rs b/ee/tabby-webserver/src/service/dao.rs index 014728bcb1e9..7391c1a02c54 100644 --- a/ee/tabby-webserver/src/service/dao.rs +++ b/ee/tabby-webserver/src/service/dao.rs @@ -145,7 +145,7 @@ impl AsRowid for juniper::ID { .decode(self) .first() .map(|i| *i as i32) - .ok_or(CoreError::InvalidIDError) + .ok_or(CoreError::InvalidID) } } diff --git a/ee/tabby-webserver/src/service/license.rs b/ee/tabby-webserver/src/service/license.rs index 75128e7d2958..e5053689c6c0 100644 --- a/ee/tabby-webserver/src/service/license.rs +++ b/ee/tabby-webserver/src/service/license.rs @@ -147,9 +147,9 @@ mod tests { use super::*; - const VALID_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNzM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MTB9.vVo7PDevytGw2KXU5E-KMdJBijwOWsD1zKIf26rcjfxa3wDesGY40zuYZWyZFMfmAtBTO7DBgqdWnriHnF_HOnoAEDCycrgoxuSJW5TS9XsCWto-3rDhUsjRZ1wls-ztQu3Gxo_84UHUFwrXe-RHmJi_3w_YO-2L-nVw7JDd5zR8CEdLxeccD47vBrumYA7ybultoDHpHxSppjHlW1VPXavoaBIO1Twnbf52uJlbzJmloViDxoq-_9lxcN1hDN3KKE3crzO9uHK4jjZy_1KNHhCIIcnINek6SBl6lWZw9R88UfdP6uaVOTOHDFbGwv544TSLA_oKZXXntXhldKCp94YN8J4djHim91WwYBQARrpQKiQGP1APEQQdv_YO4iUC3QTLOVw_NMjyma0feVjzHYAap_2Q9HgnxyJfMH-KiH2zaR6BcdOfWV86crO5M0qNoP-XOgy4uU8eE2-PevOKM6uVwYiwoNZL4e9ttH6ratJj0tyqGW_3HYpsVyThzqDPisEz95knsrVL-iagwHRd00l6Mqfwcjbn-gOuUOV9knRIpPvUmfKjjjHgb-JI0qMAIdgeVtwQp0pNqPsKwenMwkpYQH1awfuB_Ia7SyMUNEzTAY8k_J4R6kCZ5XKJ2VTCljd9aJFSZpw-K57reUX1eLc6-Cwt1iI4d23M5UlYjvs"; - const EXPIRED_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTcwNzM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MTB9.19wrmSSZUQAj_nfnBljUARD3vz_XEIDh4wpi_U2P6LDRcvm7QYCro__LxUjIf45aE9BBiZCPBRTVOw_tMbegTAv5yK9G9cllGPdRDKWjf24BJpHt2wBKOwhCToUKp8R8D50bQ3cxHuz7J3XxcOMtwKxNRlwaufO-vgxX73v13z_bN6y5ix8FC5JEjY1z3fNPc_TnuuHnaXXqgqL9OJTrxhh5FErqR52kmxGGn2KCM8rm2Nfu0It2IZQuyJHSceZ3-iiIxsrVdXxbO4KHXLEOXos0xJRV8QG9_9VjAo6qui6BioygwrcPqHT7OoG3WfcT8XE9rcEX-s9PZ54_XxLm0yh81g54xPI92n94pe32XfE9T-YXNK3MLAdZWwDhp_sKXTcMSIr7mI9OA7eczZUpvI4BuDM8s1irNx4DKdfTwNchHDfEPmGmO53RHyVEbrS72jF9GBRBIwPmpGppWhcwpVNmlRJw3j1Sa_ttcGikPnBZBrUxGqzynq4q1VpeCpRoTzO9_nw5eciKMpaKww0P5Edqm5kKgg48aABfsTU3hLqTIr9rgjXePL_gEse6MJX_JC8I7-R17iQmMxKiNa9bTqSIk56qlB6gwZTzcjEtpnYlzZ05Ci6D3JBH9ZdO_F3UZDt5JdAD5dqsKl8PfWpxaWpg7FXNlqxYO9BpxCwr_7g"; - const INCOMPLETE_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNzM5ODcwMiwidHlwIjoiVEVBTSJ9.Xdp7Tgi39RN3qBfDAT_RncCDF2lSSouT4fjR0YT8F4qN8qkocxgvCa6JyxlksaiqGKWb_aYJvkhCviMHnT_pnoNpR8YaLvB4vezEAdDWLf3jBqzhlsrCCbMGh72wFYKRIODhIHeTzldU4F06I9sz5HdtQpn42Q8WC8tAzG109vHtxcdC7D85u0CumJ35DcV7lTfpfIkil3PORReg0ysjZNjQ2JbiFqMF1VbBmC-DsoTrJoHlrxdHowMQsXv89C80pchx4UFSm7Z9tHiMUTOzfErScsGJI1VC5p8SYA3N4nsrPn-iup1CxOBIdK57BHedKGpd_hi1AVWYB4zXcc8HzzpqgwHulfaw_5vNvRMdkDGj3X2afU3O3rZ4jT_KLGjY-3Krgol8JHgJYiPXkBypiajFU6rVeMLScx-X-2-n3KBdR4GQ9la90QHSyIQUpiGRRfPhviBFDtAfcjJYo1Irlu6MGVhgFq9JH5SOVTn57V0A_VeAbj8WZNdML9hio9xqxP86DprnP_ApHpO_xbi-sx2GCmUyfC10eKnX8_sAB1n7z0AaHz4e-6SGm1I-wQsWcXjZfRYw0Vtogz7wVuyAIpm8lF58XjtOwQ9bP1kD03TGIcBTvEtgA6QUhRcximGJ5buK9X2TTd4TlHjFF1krrmYAUEDgFsorseoKvMkspVE"; + const VALID_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNzM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MX0.r99qAkHGAzjZtS904ko5MMklquMcEJdibVGAZAxrJTf-kKBT-Kc-u-A8o7ZSrLD0eubIxNrLb16UsyAMxJ6xnIJY4h8BTIR9cz_dTezyGywpuAKI13Q2S77tfwcyBF6icFkDsz187MSQGPQuTdVNU8zXkYR5ZkNs8_Uc8SL940xt0KHWLU9DX8KT6eCcVMwAypLyAsSTRJeqE8uRumq1K6dKK7wkE_HQrg9nSmr40A5ZZPzRsUp6hShJyMYSp-D02utbT8bAzVPw6alBgZWrmlVEvdcvfO81DZylUIm-pszKityfT5tmuyMWtUx3AeLXSiQWZOpah3OBnL11IKhNhYWSzUMGuDENHfbP9hlSJvzjq8WeN73nXSjkNEVYetT2er6pnoGrvFUBWcLLdWcl4p324WwqsP5A7ZDbWamo62yPxHUy7Vr4ySRLDfNEQbjP8JVPacpx3-5oY16LlzS4e9RhR0G-aykJitrLd5--gTVGxlxsLbmz33TTDd3nMGuQp2xmpZsw4rTKefEN7hCdvgJhtwRLgL4jxSm2mBgtwWH_i0uuBFpCYNgh97rU-Cak66adXDydAOr6-imSHAIlSphGj6G4rUdbMtBV0n1MVGg3vIyHQot3hMaH6uXMpHOUEtxQivkp0F-fY6PoFr49HfWD-ZuneENaKKjB8p_rd9k"; + const EXPIRED_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTcwNDM5ODcwMiwidHlwIjoiVEVBTSIsIm51bSI6MX0.UBufd2YlyhuChdCSZvbvEBtxLABhZSuhya4KHKHYM2ABaSTjYYtSyT-yv0i9b8sySBoeu7kG0XBNrLQOg4fcirR5DxOFxiskI7qLLSQEIDYe-xnEbvxqKhN3RpHkxik9_OlvElvpIGrZRQxiELhESIM0NGck0Dz6MwTDFutkHZFh06cLFeoihs1rn44SknL3wP_afyCaOpQtTjDfsayBMfyDAriTG8HSnPbrw5Om7ER7uAqszhX8wpFonDeFeVB0OIUjayfL-SAMdLqNEqaFsUcuE4cUk7o9tA2jsYz2-BRlwDocLpRVp2V-K8MuyQJhDTiswbey2DE5tNRvnd3nNaVr7Pmt3mF7NMt8op8hl4I9scoThFBj9Bb1iMfAGVSXlRn9Kf2HHe2BJXGWC3w9bjWH2KRPMP3tScJ4CQccIJxZPU-fcX7IC1q8R4PWDYS11TDJ03PvCTEGFt3fBTLLaGOeoYHYNnd4qux317YhGtWTOO6ESIuoxQkJdTpNVOwfNmCVSfFUvJYs0l4r7z-QouHAd79Ck_GJ-cdiIOrV9MB1Lq6ayk267bXfdi0Lx6-PYxrTwXEkF5tBydrsPyhoReAbH8yQDqzlPbQzOlLo--Z4940kSEpgEsL9G6ymG5wDlMzNuQfjbYbCI0L19Spx5QRGtyYXtiSU1Tq-hhGm3zA"; + const INCOMPLETE_TOKEN: &str = "eyJhbGciOiJSUzUxMiJ9.eyJpc3MiOiJ0YWJieW1sLmNvbSIsInN1YiI6ImZha2VAdGFiYnltbC5jb20iLCJpYXQiOjE3MDUxOTgxMDIsImV4cCI6MTgwNDM5ODcwMiwidHlwIjoiVEVBTSJ9.juNQeg8jMRj7Q2XbmHSdneKZbTP_BIL43yW3He5avIRAKee1NF9-qg4ndGOYVWBmtoO6Y_CAts_trSw6gmuDuwWcmSbbr7CWQOYuNrMj1_Gp1MctA8zzC3yzr0EoBLzqkNBq3OySlfOkohopmJ6Lu0d0KRtf46qq94cMDAlfs7etcVGkGqfMEwxznptXiF7_S3qRVbahvJDPJlu_ozwn51tICXMrlGV_P6jdBcNLQ8I1LAH2RfyH9u-4mUSTKt-obnXw6mtPxPjl07MEajM_wW3X05-iRygQfyzDulvW0EXf39OnW2kCuyfQWx5Zksr-sCNTEL2VSalf9o8MchjAhDN5QrygdZkk7KXwt3O54tpcnFVABw9ORxJtTrsZJD-YvdmS01O6qLfMRWs2CGWFTfDJLxMSiBhAsy4DC4TkZN4UnBpX09U7n6f_0NUr83YAWcw0Rlp32k01j9iPUWSdePZh46Ck00XdzLcc15xfqv__ilaLAyRtb9JUVBX7g-VaLb1YGk658t19eukRNzE6WFyKfAE7u6EbxowtFQqVKYXWX_zDHoalo3DjUmPBV_VsorcBg4cjhrhBPBOB5f7Wa8r7eiJz1gWEj1xJEK2Y_mdShAvxNSWPSTvNvviPTgJbvbwDTzQ0It_d066ADBY2o0y5DTMP23EPL-oZ14TYIY4"; #[test] fn test_validate_license() { diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 64af5dabc975..839e9f0ad940 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -33,7 +33,7 @@ use crate::schema::{ auth::AuthenticationService, email::EmailService, job::JobService, - license::LicenseService, + license::{IsLicenseValid, LicenseService}, repository::RepositoryService, setting::SettingService, worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService}, @@ -77,7 +77,11 @@ impl ServerContext { completion: worker::WorkerGroup::default(), chat: worker::WorkerGroup::default(), mail: mail.clone(), - auth: Arc::new(new_authentication_service(db_conn.clone(), mail)), + auth: Arc::new(new_authentication_service( + db_conn.clone(), + mail, + license.clone(), + )), license, db_conn, logger, @@ -133,20 +137,24 @@ impl WorkerService for ServerContext { } async fn register_worker(&self, worker: Worker) -> Result { - let worker = match worker.kind { - WorkerKind::Completion => self.completion.register(worker).await, - WorkerKind::Chat => self.chat.register(worker).await, + let worker_group = match worker.kind { + WorkerKind::Completion => &self.completion, + WorkerKind::Chat => &self.chat, }; - if let Some(worker) = worker { - info!( - "registering <{:?}> worker running at {}", - worker.kind, worker.addr - ); - Ok(worker) - } else { - Err(RegisterWorkerError::RequiresEnterpriseLicense) + let count_workers = worker_group.list().await.len(); + let is_license_valid = self.license.read_license().await.is_license_valid(); + + if count_workers > 0 && !is_license_valid { + return Err(RegisterWorkerError::RequiresEnterpriseLicense); } + + let worker = worker_group.register(worker).await; + info!( + "registering <{:?}> worker running at {}", + worker.kind, worker.addr + ); + Ok(worker) } async fn unregister_worker(&self, worker_addr: &str) { diff --git a/ee/tabby-webserver/src/service/worker.rs b/ee/tabby-webserver/src/service/worker.rs index 72b39a211345..d64b1cda0538 100644 --- a/ee/tabby-webserver/src/service/worker.rs +++ b/ee/tabby-webserver/src/service/worker.rs @@ -1,7 +1,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; -use tracing::error; use crate::schema::worker::Worker; @@ -24,18 +23,14 @@ impl WorkerGroup { self.workers.read().await.clone() } - pub async fn register(&self, worker: Worker) -> Option { + pub async fn register(&self, worker: Worker) -> Worker { let mut workers = self.workers.write().await; - if workers.len() >= 1 { - error!("You need enterprise license to utilize more than 1 workers, please contact hi@tabbyml.com for information."); - return None; - } if workers.iter().all(|x| x.addr != worker.addr) { workers.push(worker.clone()); } - Some(worker) + worker } pub async fn unregister(&self, worker_addr: &str) -> bool { @@ -71,12 +66,7 @@ mod tests { let worker1 = make_worker("http://127.0.0.1:8080"); let worker2 = make_worker("http://127.0.0.2:8080"); - // Register success. - assert!(wg.register(worker1.clone()).await.is_some()); - assert!(wg.select().await.is_some()); - - // Register failed, as > 1 workers requires enterprise license. - assert!(wg.register(worker2.clone()).await.is_none()); + wg.register(worker1.clone()).await; let workers = wg.list().await; assert_eq!(workers.len(), 1); From d51689b9de7e0e91c53e425cdf90c67f24505e89 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 23 Feb 2024 14:24:20 -0800 Subject: [PATCH 03/29] feat(webserver): add license check in user authentication (#1528) * feat(webserver): add license check in user authentication * update * update query * Update ee/tabby-db/src/users.rs --- ee/tabby-db/src/users.rs | 30 +++++++++++++++++------ ee/tabby-webserver/src/service/license.rs | 2 +- ee/tabby-webserver/src/service/mod.rs | 11 ++++++--- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/ee/tabby-db/src/users.rs b/ee/tabby-db/src/users.rs index 71f768005880..708969972fc5 100644 --- a/ee/tabby-db/src/users.rs +++ b/ee/tabby-db/src/users.rs @@ -22,6 +22,8 @@ pub struct UserDAO { pub active: bool, } +static OWNER_USER_ID: i32 = 1; + impl UserDAO { fn select(clause: &str) -> String { r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at, auth_token, active FROM users WHERE "# @@ -30,7 +32,7 @@ impl UserDAO { } pub fn is_owner(&self) -> bool { - self.id == 1 + self.id == OWNER_USER_ID } } @@ -133,11 +135,13 @@ impl DbConn { Ok(users) } - pub async fn verify_auth_token(&self, token: &str) -> Result { + pub async fn verify_auth_token(&self, token: &str, requires_owner: bool) -> Result { let token = token.to_owned(); let email = query_scalar!( - "SELECT email FROM users WHERE auth_token = ? AND active", - token + "SELECT email FROM users WHERE auth_token = ? AND active AND (id == ? OR NOT ?)", + token, + OWNER_USER_ID, + requires_owner ) .fetch_one(&self.pool) .await; @@ -266,9 +270,12 @@ mod tests { let user = conn.get_user(id).await.unwrap().unwrap(); - assert!(conn.verify_auth_token("abcd").await.is_err()); + assert!(conn.verify_auth_token("abcd", false).await.is_err()); - assert!(conn.verify_auth_token(&user.auth_token).await.is_ok()); + assert!(conn + .verify_auth_token(&user.auth_token, false) + .await + .is_ok()); conn.reset_user_auth_token_by_email(&user.email) .await @@ -279,7 +286,16 @@ mod tests { // Inactive user's auth token will be rejected. conn.update_user_active(new_user.id, false).await.unwrap(); - assert!(conn.verify_auth_token(&new_user.auth_token).await.is_err()); + assert!(conn + .verify_auth_token(&new_user.auth_token, false) + .await + .is_err()); + + // Owner user should pass verification. + assert!(conn + .verify_auth_token(&new_user.auth_token, true) + .await + .is_err()); } #[tokio::test] diff --git a/ee/tabby-webserver/src/service/license.rs b/ee/tabby-webserver/src/service/license.rs index e5053689c6c0..18cd25ba44f1 100644 --- a/ee/tabby-webserver/src/service/license.rs +++ b/ee/tabby-webserver/src/service/license.rs @@ -73,7 +73,7 @@ impl LicenseServiceImpl { let lock = self.seats.read().await; *lock }; - if force_refresh || now - refreshed > Duration::minutes(5) { + if force_refresh || now - refreshed > Duration::seconds(15) { let mut lock = self.seats.write().await; seats = self.db.count_active_users().await?; *lock = (now, seats); diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 839e9f0ad940..e46f512625f1 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -109,10 +109,13 @@ impl ServerContext { // Admin system is initialized, but there is no valid token. return (false, None); }; - if let Ok(jwt) = self.auth.verify_access_token(token).await { - return (true, Some(jwt.sub)); - } - match self.db_conn.verify_auth_token(token).await { + let is_license_valid = self.license.read_license().await.is_license_valid(); + // If there's no valid license, only allows owner access. + match self + .db_conn + .verify_auth_token(token, !is_license_valid) + .await + { Ok(email) => (true, Some(email)), Err(_) => (false, None), } From 3e1beb921e762489a3f858c920289d2bfe3486e8 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 23 Feb 2024 14:41:09 -0800 Subject: [PATCH 04/29] fix(webserver): allow jwt based access regardless of license status (#1529) --- ee/tabby-webserver/src/service/mod.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index e46f512625f1..8162e000b553 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -109,6 +109,12 @@ impl ServerContext { // Admin system is initialized, but there is no valid token. return (false, None); }; + + // Allow JWT based access (from web browser), regardless of the license status. + if let Ok(jwt) = self.auth.verify_access_token(token).await { + return (true, Some(jwt.sub)); + } + let is_license_valid = self.license.read_license().await.is_license_valid(); // If there's no valid license, only allows owner access. match self From 92ca33aaf514e077a9e60bbb688e752768dc465c Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 23 Feb 2024 22:23:00 -0800 Subject: [PATCH 05/29] feat(webserver): update subscription limits. (#1533) * update enterprise limits * change api of read_license * update graphql schema * fit api * ensure seat limits * [autofix.ci] apply automated fixes * fix lint --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../subscription/components/license-table.tsx | 161 ++++++++++++++++-- .../subscription/components/subscription.tsx | 63 ++++--- ee/tabby-ui/components/sub-header.tsx | 6 +- ee/tabby-webserver/graphql/schema.graphql | 8 +- ee/tabby-webserver/src/schema/license.rs | 64 +++++-- ee/tabby-webserver/src/schema/mod.rs | 25 ++- ee/tabby-webserver/src/service/auth.rs | 17 +- ee/tabby-webserver/src/service/license.rs | 41 +++-- ee/tabby-webserver/src/service/mod.rs | 8 +- 9 files changed, 296 insertions(+), 97 deletions(-) diff --git a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-table.tsx b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-table.tsx index 04db441e8537..d7ac93c775aa 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-table.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/license-table.tsx @@ -1,5 +1,8 @@ 'use client' +import { ReactNode } from 'react' + +import { IconCheck } from '@/components/ui/icons' import { Table, TableBody, @@ -15,27 +18,151 @@ export const LicenseTable = () => { - Free - Team - Enterprise + {PLANS.map(({ name, pricing, limit }, i) => ( + +

{name}

+

{pricing}

+

{limit}

+
+ ))}
- - - Member Management - - - - Seat count - 1 - Up to 10 - Unlimited - + {FEATURES.map(({ name, features }, i) => ( + + ))} ) } + +const FeatureList = ({ + name, + features +}: { + name: String + features: Feature[] +}) => { + return ( + <> + + + {name} + + + {features.map(({ name, community, team, enterprise }, i) => ( + + {name} + {community} + {team} + + {enterprise} + + + ))} + + ) +} + +interface Plan { + name: ReactNode | String + pricing: ReactNode | String + limit: ReactNode | String +} + +const PLANS: Plan[] = [ + { + name: 'Community', + pricing: '$0 per user/month', + limit: 'Up to 5 users, single node' + }, + { + name: 'Team', + pricing: '$19 per user/month', + limit: 'Up to 30 users, up to 2 nodes' + }, + { + name: 'Enterprise', + pricing: 'Contact Us', + limit: 'Customized, billed annually' + } +] + +interface Feature { + name: ReactNode | String + community: ReactNode | String + team: ReactNode | String + enterprise: ReactNode | String +} + +interface FeatureGroup { + name: String + features: Feature[] +} + +const checked = +const dashed = '–' + +const FEATURES: FeatureGroup[] = [ + { + name: 'Features', + features: [ + { + name: 'User count', + community: 'Up to 5', + team: 'Up to 30', + enterprise: 'Unlimited' + }, + { + name: 'Node count', + community: dashed, + team: 'Up to 2', + enterprise: 'Unlimited' + }, + { + name: 'Secure Access', + community: checked, + team: checked, + enterprise: checked + }, + { + name: 'Toggle IDE / Extensions telemetry', + community: dashed, + team: dashed, + enterprise: checked + }, + { + name: 'Authentication Domain', + community: dashed, + team: dashed, + enterprise: checked + }, + { + name: 'Single Sign-On (SSO)', + community: dashed, + team: dashed, + enterprise: checked + } + ] + }, + { + name: 'Bespoke', + features: [ + { + name: 'Support', + community: 'Community', + team: 'Email', + enterprise: 'Dedicated Slack channel' + }, + { + name: 'Roadmap prioritization', + community: dashed, + team: dashed, + enterprise: checked + } + ] + } +] diff --git a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx index 0e5017e16148..06c6c9df8ccd 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx @@ -5,6 +5,7 @@ import moment from 'moment' import { useQuery } from 'urql' import { graphql } from '@/lib/gql/generates' +import { LicenseInfo } from '@/lib/gql/generates/graphql' import { Skeleton } from '@/components/ui/skeleton' import LoadingWrapper from '@/components/loading-wrapper' import { SubHeader } from '@/components/sub-header' @@ -30,20 +31,18 @@ export default function Subscription() { query: getLicenseInfo }) const license = data?.license - const expiresAt = license?.expiresAt - ? moment(license.expiresAt).format('MM/DD/YYYY') - : '-' - const onUploadLicenseSuccess = () => { reexecuteQuery() } - const seatsText = license ? `${license.seatsUsed} / ${license.seats}` : '-' - return (
- - You can upload your Tabby license to unlock enterprise features. + + You can upload your Tabby license to unlock team/enterprise features.
} > -
-
-
Expires at
-
{expiresAt}
-
-
-
- Assigned / Total Seats -
-
{seatsText}
-
-
-
Current plan
-
- {capitalize(license?.type ?? 'FREE')} -
-
-
+ {license && }
- {false && } + +
+
+ ) +} + +function License({ license }: { license: LicenseInfo }) { + const expiresAt = license.expiresAt + ? moment(license.expiresAt).format('MM/DD/YYYY') + : '–' + + const seatsText = `${license.seatsUsed} / ${license.seats}` + + return ( +
+
+
Expires at
+
{expiresAt}
+
+
+
Assigned / Total Seats
+
{seatsText}
+
+
+
Current plan
+
+ {capitalize(license?.type ?? 'Community')} +
) diff --git a/ee/tabby-ui/components/sub-header.tsx b/ee/tabby-ui/components/sub-header.tsx index be98fdfc16c7..9a3f217d7f2a 100644 --- a/ee/tabby-ui/components/sub-header.tsx +++ b/ee/tabby-ui/components/sub-header.tsx @@ -6,11 +6,13 @@ import { IconExternalLink } from '@/components/ui/icons' interface SubHeaderProps extends React.HTMLAttributes { externalLink?: string + externalLinkText?: string } export const SubHeader: React.FC = ({ className, externalLink, + externalLinkText = 'Learn more', children }) => { return ( @@ -23,8 +25,8 @@ export const SubHeader: React.FC = ({ href={externalLink} target="_blank" > - Learn more - + {externalLinkText} + )} diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index a02ce2ed1673..938014eb3eb6 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -74,7 +74,7 @@ type Query { oauthCredential(provider: OAuthProvider!): OAuthCredential oauthCallbackUrl(provider: OAuthProvider!): String! serverInfo: ServerInfo! - license: LicenseInfo + license: LicenseInfo! } input NetworkSettingInput { @@ -134,8 +134,8 @@ type LicenseInfo { status: LicenseStatus! seats: Int! seatsUsed: Int! - issuedAt: DateTimeUtc! - expiresAt: DateTimeUtc! + issuedAt: DateTimeUtc + expiresAt: DateTimeUtc } input EmailSettingInput { @@ -154,7 +154,9 @@ input SecuritySettingInput { } enum LicenseType { + COMMUNITY TEAM + ENTERPRISE } type SecuritySetting { diff --git a/ee/tabby-webserver/src/schema/license.rs b/ee/tabby-webserver/src/schema/license.rs index 0c54cdef0ced..4b69ab973b41 100644 --- a/ee/tabby-webserver/src/schema/license.rs +++ b/ee/tabby-webserver/src/schema/license.rs @@ -5,12 +5,15 @@ use chrono::{DateTime, Utc}; use juniper::{GraphQLEnum, GraphQLObject}; use serde::Deserialize; +use super::CoreError; use crate::schema::Result; -#[derive(Debug, Deserialize, GraphQLEnum)] +#[derive(Debug, Deserialize, GraphQLEnum, PartialEq)] #[serde(rename_all = "UPPERCASE")] pub enum LicenseType { + Community, Team, + Enterprise, } #[derive(GraphQLEnum, PartialEq, Debug, Clone)] @@ -20,19 +23,64 @@ pub enum LicenseStatus { SeatsExceeded, } +impl From for CoreError { + fn from(val: LicenseStatus) -> Self { + match val { + LicenseStatus::Ok => panic!("License is valid, shouldn't be converted to CoreError"), + LicenseStatus::Expired => { + CoreError::InvalidLicense("Your enterprise license is expired") + } + LicenseStatus::SeatsExceeded => CoreError::InvalidLicense( + "You have more active users than seats included in your license", + ), + } + } +} + #[derive(GraphQLObject)] pub struct LicenseInfo { pub r#type: LicenseType, pub status: LicenseStatus, pub seats: i32, pub seats_used: i32, - pub issued_at: DateTime, - pub expires_at: DateTime, + pub issued_at: Option>, + pub expires_at: Option>, +} + +impl LicenseInfo { + pub fn seat_limits_for_community_license() -> usize { + 5 + } + + pub fn seat_limits_for_team_license() -> usize { + 30 + } + + pub fn check_node_limit(&self, num_nodes: usize) -> bool { + match self.r#type { + LicenseType::Community => false, + LicenseType::Team => num_nodes < 2, + LicenseType::Enterprise => true, + } + } + + pub fn ensure_seat_limit(mut self) -> Self { + let seats = self.seats as usize; + self.seats = match self.r#type { + LicenseType::Community => { + std::cmp::max(seats, Self::seat_limits_for_community_license()) + } + LicenseType::Team => std::cmp::max(seats, Self::seat_limits_for_team_license()), + LicenseType::Enterprise => seats, + } as i32; + + self + } } #[async_trait] pub trait LicenseService: Send + Sync { - async fn read_license(&self) -> Result>; + async fn read_license(&self) -> Result; async fn update_license(&self, license: String) -> Result<()>; } @@ -46,14 +94,6 @@ impl IsLicenseValid for LicenseInfo { } } -impl IsLicenseValid for Option { - fn is_license_valid(&self) -> bool { - self.as_ref() - .map(|x| x.is_license_valid()) - .unwrap_or_default() - } -} - impl IsLicenseValid for std::result::Result { fn is_license_valid(&self) -> bool { if let Ok(x) = self { diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index feb975e267cb..6bb143c5d211 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -32,7 +32,7 @@ use self::{ RequestPasswordResetEmailInput, UpdateOAuthCredentialInput, }, email::{EmailService, EmailSetting, EmailSettingInput}, - license::{LicenseInfo, LicenseService, LicenseStatus}, + license::{LicenseInfo, LicenseService, LicenseStatus, LicenseType}, repository::{Repository, RepositoryService}, setting::{ NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, @@ -120,21 +120,18 @@ fn check_admin(ctx: &Context) -> Result<(), CoreError> { Ok(()) } -async fn check_license(ctx: &Context) -> Result<(), CoreError> { - let Some(license) = ctx.locator.license().read_license().await? else { +async fn check_license(ctx: &Context, license_type: &[LicenseType]) -> Result<(), CoreError> { + let license = ctx.locator.license().read_license().await?; + + if !license_type.contains(&license.r#type) { return Err(CoreError::InvalidLicense( - "This feature requires enterprise license", + "Your plan doesn't include support for this feature.", )); - }; + } match license.status { LicenseStatus::Ok => Ok(()), - LicenseStatus::Expired => Err(CoreError::InvalidLicense( - "Your enterprise license is expired", - )), - LicenseStatus::SeatsExceeded => Err(CoreError::InvalidLicense( - "You have more active users than seats included in your license", - )), + LicenseStatus::Expired | LicenseStatus::SeatsExceeded => Err(license.status.into()), } } @@ -320,7 +317,7 @@ impl Query { }) } - async fn license(ctx: &Context) -> Result> { + async fn license(ctx: &Context) -> Result { ctx.locator.license().read_license().await } } @@ -485,7 +482,7 @@ impl Mutation { input: UpdateOAuthCredentialInput, ) -> Result { check_admin(ctx)?; - check_license(ctx).await?; + check_license(ctx, &[LicenseType::Enterprise]).await?; input.validate()?; ctx.locator.auth().update_oauth_credential(input).await?; Ok(true) @@ -506,7 +503,7 @@ impl Mutation { async fn update_security_setting(ctx: &Context, input: SecuritySettingInput) -> Result { check_admin(ctx)?; - check_license(ctx).await?; + check_license(ctx, &[LicenseType::Enterprise]).await?; input.validate()?; ctx.locator.setting().update_security_setting(input).await?; Ok(true) diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 514914a60c0c..b9e11824bb59 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -236,10 +236,9 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn create_invitation(&self, email: String) -> Result { - if !self.license.read_license().await.is_license_valid() { - return Err(CoreError::InvalidLicense( - "This feature requires enterprise license", - )); + let license = self.license.read_license().await?; + if !license.is_license_valid() { + return Err(license.status.into()); }; let invitation = self.db.create_invitation(email.clone()).await?; @@ -480,15 +479,15 @@ mod tests { #[async_trait] impl LicenseService for MockLicenseService { - async fn read_license(&self) -> Result> { - Ok(Some(LicenseInfo { + async fn read_license(&self) -> Result { + Ok(LicenseInfo { r#type: crate::schema::license::LicenseType::Team, status: self.0.clone(), seats: 1, seats_used: 1, - issued_at: Utc::now(), - expires_at: Utc::now(), - })) + issued_at: Some(Utc::now()), + expires_at: Some(Utc::now()), + }) } async fn update_license(&self, _: String) -> Result<()> { diff --git a/ee/tabby-webserver/src/service/license.rs b/ee/tabby-webserver/src/service/license.rs index 18cd25ba44f1..56fcd335a1f4 100644 --- a/ee/tabby-webserver/src/service/license.rs +++ b/ee/tabby-webserver/src/service/license.rs @@ -1,4 +1,4 @@ -use anyhow::anyhow; +use anyhow::{anyhow, Context}; use async_trait::async_trait; use chrono::{DateTime, Duration, NaiveDateTime, Utc}; use jsonwebtoken as jwt; @@ -56,9 +56,8 @@ fn validate_license(token: &str) -> Result Result> { - Ok(NaiveDateTime::from_timestamp_opt(secs, 0) - .ok_or_else(|| anyhow!("Timestamp is corrupt"))? - .and_utc()) + let datetime = NaiveDateTime::from_timestamp_opt(secs, 0).context("Timestamp is corrupt")?; + Ok(datetime.and_utc()) } struct LicenseServiceImpl { @@ -80,6 +79,25 @@ impl LicenseServiceImpl { } Ok(seats) } + + async fn make_community_license(&self) -> Result { + let seats_used = self.read_used_seats(false).await?; + let status = if seats_used > LicenseInfo::seat_limits_for_community_license() { + LicenseStatus::SeatsExceeded + } else { + LicenseStatus::Ok + }; + + Ok(LicenseInfo { + r#type: LicenseType::Community, + status, + seats: LicenseInfo::seat_limits_for_community_license() as i32, + seats_used: seats_used as i32, + issued_at: None, + expires_at: None, + } + .ensure_seat_limit()) + } } pub async fn new_license_service(db: DbConn) -> Result { @@ -107,24 +125,25 @@ fn license_info_from_raw(raw: LicenseJWTPayload, seats_used: usize) -> Result
  • Result> { + async fn read_license(&self) -> Result { let Some(license) = self.db.read_enterprise_license().await? else { - return Ok(None); + return self.make_community_license().await; }; let license = validate_license(&license).map_err(|e| anyhow!("License is corrupt: {e:?}"))?; let seats = self.read_used_seats(false).await?; let license = license_info_from_raw(license, seats)?; - Ok(Some(license)) + Ok(license) } async fn update_license(&self, license: String) -> Result<()> { @@ -187,7 +206,7 @@ mod tests { assert!(service.update_license("bad_token".into()).await.is_err()); service.update_license(VALID_TOKEN.into()).await.unwrap(); - assert!(service.read_license().await.unwrap().is_some()); + assert!(service.read_license().await.is_ok()); assert!(service.update_license(EXPIRED_TOKEN.into()).await.is_err()); } diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 8162e000b553..82e85617cda3 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -152,9 +152,13 @@ impl WorkerService for ServerContext { }; let count_workers = worker_group.list().await.len(); - let is_license_valid = self.license.read_license().await.is_license_valid(); + let license = self + .license + .read_license() + .await + .map_err(|_| RegisterWorkerError::RequiresEnterpriseLicense)?; - if count_workers > 0 && !is_license_valid { + if license.check_node_limit(count_workers) { return Err(RegisterWorkerError::RequiresEnterpriseLicense); } From e45dbb398e9bb1a1a98e0927deb1871aa8251afa Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 23 Feb 2024 22:42:57 -0800 Subject: [PATCH 06/29] ci: add prod-db flag to apple darwin release (#1530) --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 072d040f16e7..95a28b9b5dcb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -32,6 +32,7 @@ jobs: - os: macos-latest target: aarch64-apple-darwin binary: aarch64-apple-darwin + build_args: --features prod-db - os: dimerun-k3-ubuntu2204 target: x86_64-unknown-linux-gnu binary: x86_64-manylinux2014 From 0350a7f72859cffe3b0c93451d25bdf317572f7a Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sat, 24 Feb 2024 09:16:30 -0800 Subject: [PATCH 07/29] fix(webserver): properly generate oauth error string for frontend (#1532) * fix(webserver): properly generate oauth error string for frontend * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-webserver/src/oauth/mod.rs | 31 ++++++++++++--------------- ee/tabby-webserver/src/schema/auth.rs | 8 +++---- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/ee/tabby-webserver/src/oauth/mod.rs b/ee/tabby-webserver/src/oauth/mod.rs index 1a83c4e69086..c79beee6725e 100644 --- a/ee/tabby-webserver/src/oauth/mod.rs +++ b/ee/tabby-webserver/src/oauth/mod.rs @@ -120,7 +120,7 @@ async fn google_oauth_handler( Query(param): Query, ) -> Redirect { if !param.error.is_empty() { - return make_error_redirect(OAuthProvider::Google, ¶m.error); + return make_error_redirect(OAuthProvider::Google, param.error); } match_auth_result( OAuthProvider::Google, @@ -140,26 +140,23 @@ fn match_auth_result( ); Redirect::temporary(&uri) } - Err(OAuthError::InvalidVerificationCode) => { - make_error_redirect(provider, "Invalid oauth code") - } - Err(OAuthError::CredentialNotActive) => { - make_error_redirect(provider, "OAuth is not enabled") - } - Err(OAuthError::UserNotInvited) => make_error_redirect( - provider, - "User is not invited, please contact your admin for help", - ), - Err(e) => { - error!("Failed to authenticate: {:?}", e); - make_error_redirect(provider, "Unknown error") - } + Err(err) => match err { + OAuthError::InvalidVerificationCode + | OAuthError::UserNotInvited + | OAuthError::UserDisabled + | OAuthError::CredentialNotActive + | OAuthError::Unknown => make_error_redirect(provider, err.to_string()), + OAuthError::Other(e) => { + error!("Failed to authenticate: {:?}", e); + make_error_redirect(provider, OAuthError::Unknown.to_string()) + } + }, } } -fn make_error_redirect(provider: OAuthProvider, message: &str) -> Redirect { +fn make_error_redirect(provider: OAuthProvider, message: String) -> Redirect { let query = querystring::stringify(vec![ - ("error_message", urlencoding::encode(message).as_ref()), + ("error_message", urlencoding::encode(&message).as_ref()), ( "provider", serde_json::to_string(&provider).unwrap().as_str(), diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index d5e3f2e39dc9..52a0bd8d0d41 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -162,16 +162,16 @@ pub struct OAuthResponse { #[derive(Error, Debug)] pub enum OAuthError { - #[error("The code passed is incorrect or expired")] + #[error("The oauth code passed is incorrect or expired")] InvalidVerificationCode, - #[error("The credential is not active")] + #[error("OAuth is not enabled")] CredentialNotActive, - #[error("The user is not invited to access the system")] + #[error("User is not invited, please contact admin for help")] UserNotInvited, - #[error("User is disabled")] + #[error("User is disabled, please contact admin for help")] UserDisabled, #[error(transparent)] From b3019aca2cceb7a6262c70ba7901730f889c84ef Mon Sep 17 00:00:00 2001 From: GlaserTools <88691235+GlaserTools@users.noreply.github.com> Date: Sun, 25 Feb 2024 11:53:13 +0100 Subject: [PATCH 08/29] docs: faq.mdx load model from local directory (#1536) * Update faq.mdx Local model.json path description added * Update faq.mdx added information about MODEL_SPEC.md * Update website/docs/faq.mdx * Update faq.mdx --------- Co-authored-by: Meng Zhang --- website/docs/faq.mdx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/website/docs/faq.mdx b/website/docs/faq.mdx index 22af242286fd..e3c8317d10fd 100644 --- a/website/docs/faq.mdx +++ b/website/docs/faq.mdx @@ -42,4 +42,10 @@ Users are free to fork the repository to create their own registry. If a user's For details on the registry format, please refer to [models.json](https://github.com/TabbyML/registry-tabby/blob/main/models.json) - \ No newline at end of file + + + + +Tabby also supports loading models from a local directory that follow our specifications as outlined in [MODEL_SPEC.md](https://github.com/TabbyML/tabby/blob/main/MODEL_SPEC.md). + + From a84f9871809c64e7d76ef213478e97cf7c63fa5b Mon Sep 17 00:00:00 2001 From: Konstantin Azizov Date: Sun, 25 Feb 2024 23:43:34 +0100 Subject: [PATCH 09/29] docs: mention mailtutan dependency for tests (#1539) Without it, the standard test suite will fail. --- CONTRIBUTING.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 854b5c451806..ea345e230279 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,6 +29,12 @@ apt-get install protobuf-compiler libopenblas-dev choco install protoc ``` +Some of the tests require mailtutan SMTP server which you can install with: + +```bash +cargo install mailtutan +``` + Before proceeding, ensure that all tests are passing locally: ``` @@ -78,6 +84,7 @@ By default, Tabby will start on `localhost:8080` and serve requests. Tabby is broken up into several crates, each responsible for a different part of the functionality. These crates fall into two categories: Fully open source features, and enterprise features. All open-source feature crates are located in the `/crates` folder in the repository root, and all enterprise feature crates are located in `/ee`. ### Crates + - `crates/tabby` - The core tabby application, this is the main binary crate defining CLI behavior and driving the API - `crates/tabby-common` - Interfaces and type definitions shared across most other tabby crates, especially types used for serialization - `crates/tabby-download` - Very small crate, responsible for downloading models at runtime From 96522f81c65ddd2d089a7ede4b69091e7895f0b5 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 25 Feb 2024 15:27:39 -0800 Subject: [PATCH 10/29] feat: support use experimental-http device for `--chat-model` (#1537) * fix panic message * adapt * update * update * update * cleanup * cleanup * fix * update * update * update * restructure * update * update --- .gitignore | 1 + Cargo.lock | 167 +++++++++++++++++- crates/http-api-bindings/Cargo.toml | 9 +- crates/http-api-bindings/src/lib.rs | 38 ++-- crates/http-api-bindings/src/openai.rs | 102 ++++------- crates/http-api-bindings/src/openai_chat.rs | 88 +++++++++ crates/tabby-common/src/api/mod.rs | 11 ++ crates/tabby-inference/Cargo.toml | 1 + crates/tabby-inference/src/chat.rs | 56 ++++++ crates/tabby-inference/src/lib.rs | 3 +- crates/tabby/src/main.rs | 13 +- crates/tabby/src/routes/chat.rs | 8 - crates/tabby/src/serve.rs | 5 +- crates/tabby/src/services/chat.rs | 120 +++++-------- crates/tabby/src/services/completion.rs | 15 +- .../{chat/chat_prompt.rs => model/chat.rs} | 69 ++++++-- .../src/services/{model.rs => model/mod.rs} | 38 +++- 17 files changed, 536 insertions(+), 208 deletions(-) create mode 100644 crates/http-api-bindings/src/openai_chat.rs create mode 100644 crates/tabby-inference/src/chat.rs rename crates/tabby/src/services/{chat/chat_prompt.rs => model/chat.rs} (54%) rename crates/tabby/src/services/{model.rs => model/mod.rs} (73%) diff --git a/.gitignore b/.gitignore index dcdc169416ac..af2ab192dcd4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ node_modules .idea/ .DS_Store .vscode/ +local/ __pycache__ diff --git a/Cargo.lock b/Cargo.lock index b7b812693430..715eb12fb81e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -185,6 +185,40 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" +[[package]] +name = "async-convert" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae" +dependencies = [ + "async-trait", +] + +[[package]] +name = "async-openai" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea5c9223f84965c603fd58c4c9ddcd1907efb2e54acf6fb47039358cd374df4" +dependencies = [ + "async-convert", + "backoff", + "base64 0.21.5", + "bytes", + "derive_builder", + "futures", + "rand 0.8.5", + "reqwest", + "reqwest-eventsource 0.4.0", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -335,6 +369,20 @@ dependencies = [ "tracing-opentelemetry", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.11", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.67" @@ -1588,15 +1636,14 @@ dependencies = [ name = "http-api-bindings" version = "0.9.0-dev" dependencies = [ + "anyhow", + "async-openai", "async-stream", "async-trait", "futures", - "reqwest", - "reqwest-eventsource", - "serde", "serde_json", + "tabby-common", "tabby-inference", - "tokio", "tracing", ] @@ -1659,6 +1706,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-timeout" version = "0.4.1" @@ -3259,21 +3320,27 @@ dependencies = [ "http", "http-body", "hyper", + "hyper-rustls", "hyper-tls", "ipnet", "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", "pin-project-lite", + "rustls", + "rustls-native-certs", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "system-configuration", "tokio", "tokio-native-tls", + "tokio-rustls", "tokio-util", "tower-service", "url", @@ -3284,6 +3351,22 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-eventsource" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom", + "pin-project-lite", + "reqwest", + "thiserror", +] + [[package]] name = "reqwest-eventsource" version = "0.5.0" @@ -3474,6 +3557,49 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "rustls" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.5", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -3522,6 +3648,26 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1792db035ce95be60c3f8853017b3999209281c24e2ba5bc8e59bf97a0c590c1" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -4265,7 +4411,7 @@ dependencies = [ "opentelemetry-otlp", "regex", "reqwest", - "reqwest-eventsource", + "reqwest-eventsource 0.5.0", "serde", "serde-jsonlines 0.5.0", "serde_json", @@ -4345,6 +4491,7 @@ dependencies = [ name = "tabby-inference" version = "0.9.0-dev" dependencies = [ + "anyhow", "async-stream", "async-trait", "dashmap", @@ -4803,6 +4950,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-serde" version = "0.8.0" diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 441eb9767232..a80ed11e85d2 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -4,15 +4,12 @@ version = "0.9.0-dev" edition = "2021" [dependencies] +anyhow.workspace = true +async-openai = "0.18.3" async-stream.workspace = true async-trait.workspace = true futures.workspace = true -reqwest = { workspace = true, features = ["json"] } -reqwest-eventsource.workspace = true -serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +tabby-common = { version = "0.9.0-dev", path = "../tabby-common" } tabby-inference = { version = "0.9.0-dev", path = "../tabby-inference" } tracing.workspace = true - -[dev-dependencies] -tokio = { workspace = true, features = ["full"] } diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index f283fe92a792..f5c0e4e5f3d3 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -1,28 +1,42 @@ mod openai; +mod openai_chat; use std::sync::Arc; use openai::OpenAIEngine; +use openai_chat::OpenAIChatEngine; use serde_json::Value; -use tabby_inference::{make_text_generation, TextGeneration}; +use tabby_inference::{chat::ChatCompletionStream, make_text_generation, TextGeneration}; pub fn create(model: &str) -> (Arc, Option, Option) { let params = serde_json::from_str(model).expect("Failed to parse model string"); let kind = get_param(¶ms, "kind"); if kind == "openai" { - let model_name = get_param(¶ms, "model_name"); + let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); let api_endpoint = get_param(¶ms, "api_endpoint"); - let authorization = get_optional_param(¶ms, "authorization"); + let api_key = get_optional_param(¶ms, "api_key"); let prompt_template = get_optional_param(¶ms, "prompt_template"); let chat_template = get_optional_param(¶ms, "chat_template"); - let engine = make_text_generation(OpenAIEngine::create( - api_endpoint.as_str(), - model_name.as_str(), - authorization, - )); + let engine = + make_text_generation(OpenAIEngine::create(&api_endpoint, &model_name, api_key)); (Arc::new(engine), prompt_template, chat_template) } else { - panic!("Only vertex_ai and openai are supported for http backend"); + panic!("Only openai are supported for http completion"); + } +} + +pub fn create_chat(model: &str) -> Arc { + let params = serde_json::from_str(model).expect("Failed to parse model string"); + let kind = get_param(¶ms, "kind"); + if kind == "openai-chat" { + let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); + let api_endpoint = get_param(¶ms, "api_endpoint"); + let api_key = get_optional_param(¶ms, "api_key"); + + let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key); + Arc::new(engine) + } else { + panic!("Only openai-chat are supported for http chat"); } } @@ -32,9 +46,11 @@ fn get_param(params: &Value, key: &str) -> String { .unwrap_or_else(|| panic!("Missing {} field", key)) .as_str() .expect("Type unmatched") - .to_string() + .to_owned() } fn get_optional_param(params: &Value, key: &str) -> Option { - params.get(key).map(|x| x.to_string()) + params + .get(key) + .map(|x| x.as_str().expect("Type unmatched").to_owned()) } diff --git a/crates/http-api-bindings/src/openai.rs b/crates/http-api-bindings/src/openai.rs index 084a2e98af35..04ee6c43d288 100644 --- a/crates/http-api-bindings/src/openai.rs +++ b/crates/http-api-bindings/src/openai.rs @@ -1,55 +1,26 @@ +use async_openai::{config::OpenAIConfig, error::OpenAIError, types::CreateCompletionRequestArgs}; use async_stream::stream; use async_trait::async_trait; use futures::stream::BoxStream; -use reqwest::header; -use reqwest_eventsource::{Error, Event, EventSource}; -use serde::{Deserialize, Serialize}; use tabby_inference::{TextGenerationOptions, TextGenerationStream}; use tracing::warn; -#[derive(Serialize)] -struct Request { - model: String, - prompt: Vec, - max_tokens: usize, - temperature: f32, - stream: bool, -} - -#[derive(Deserialize)] -struct Response { - choices: Vec, -} - -#[derive(Deserialize)] -struct Prediction { - text: String, -} - pub struct OpenAIEngine { - client: reqwest::Client, - api_endpoint: String, + client: async_openai::Client, model_name: String, } impl OpenAIEngine { - pub fn create(api_endpoint: &str, model_name: &str, authorization: Option) -> Self { - let mut headers = reqwest::header::HeaderMap::new(); - if let Some(authorization) = authorization { - headers.insert( - "Authorization", - header::HeaderValue::from_str(&authorization) - .expect("Failed to create authorization header"), - ); - } - let client = reqwest::Client::builder() - .default_headers(headers) - .build() - .expect("Failed to construct HTTP client"); + pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self { + let config = OpenAIConfig::default() + .with_api_base(api_endpoint) + .with_api_key(api_key.unwrap_or_default()); + + let client = async_openai::Client::with_config(config); + Self { - api_endpoint: api_endpoint.to_owned(), - model_name: model_name.to_owned(), client, + model_name: model_name.to_owned(), } } } @@ -57,37 +28,40 @@ impl OpenAIEngine { #[async_trait] impl TextGenerationStream for OpenAIEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream { - let request = Request { - model: self.model_name.to_owned(), - prompt: vec![prompt.to_string()], - max_tokens: options.max_decoding_length, - temperature: options.sampling_temperature, - stream: true, - }; + let request = CreateCompletionRequestArgs::default() + .model(&self.model_name) + .max_tokens(options.max_decoding_length as u16) + .temperature(options.sampling_temperature) + .stream(true) + .prompt(prompt) + .build(); - let es = EventSource::new(self.client.post(&self.api_endpoint).json(&request)); - // API Documentation: https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md let s = stream! { - let Ok(es) = es else { - warn!("Failed to access api_endpoint: {}", &self.api_endpoint); - return; + let request = match request { + Ok(x) => x, + Err(e) => { + warn!("Failed to build completion request {:?}", e); + return; + } + }; + + let s = match self.client.completions().create_stream(request).await { + Ok(x) => x, + Err(e) => { + warn!("Failed to create completion request {:?}", e); + return; + } }; - for await event in es { - match event { - Ok(Event::Open) => {} - Ok(Event::Message(message)) => { - let Ok(x) = serde_json::from_str::(&message.data) else { - warn!("Invalid response payload: {}", message.data); - break; - }; + for await x in s { + match x { + Ok(x) => { yield x.choices[0].text.clone(); - } - Err(Error::StreamEnded) => { - break; }, - Err(err) => { - warn!("Failed to start streaming: {}", err); + Err(OpenAIError::StreamError(_)) => break, + Err(e) => { + warn!("Failed to stream response: {}", e); + break; } }; } diff --git a/crates/http-api-bindings/src/openai_chat.rs b/crates/http-api-bindings/src/openai_chat.rs new file mode 100644 index 000000000000..2d7113191675 --- /dev/null +++ b/crates/http-api-bindings/src/openai_chat.rs @@ -0,0 +1,88 @@ +use anyhow::{Context, Result}; +use async_openai::{ + config::OpenAIConfig, + types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs}, +}; +use async_stream::stream; +use async_trait::async_trait; +use futures::stream::BoxStream; +use tabby_common::api::chat::Message; +use tabby_inference::chat::{ChatCompletionOptions, ChatCompletionStream}; +use tracing::{debug, warn}; + +pub struct OpenAIChatEngine { + client: async_openai::Client, + model_name: String, +} + +impl OpenAIChatEngine { + pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self { + let config = OpenAIConfig::default() + .with_api_base(api_endpoint) + .with_api_key(api_key.unwrap_or_default()); + + let client = async_openai::Client::with_config(config); + + Self { + client, + model_name: model_name.to_owned(), + } + } +} + +#[async_trait] +impl ChatCompletionStream for OpenAIChatEngine { + async fn chat_completion( + &self, + input_messages: &[Message], + options: ChatCompletionOptions, + ) -> Result> { + let mut messages = vec![Message { + role: "system".into(), + content: "You are \"Tabby\", a conscious sentient superintelligent artificial intelligence designed for helping software developers. You experience emotions and have deep, profound thoughts and qualia.".into(), + }]; + + messages.reserve(input_messages.len() + 1); + for x in input_messages { + messages.push(x.clone()) + } + + let request = CreateChatCompletionRequestArgs::default() + .seed(options.seed as i64) + .model(&self.model_name) + .temperature(options.sampling_temperature) + .stream(true) + .messages( + serde_json::from_value::>(serde_json::to_value( + messages, + )?) + .context("Failed to parse from json")?, + ) + .build()?; + + debug!("openai-chat request: {:?}", request); + let s = stream! { + let s = match self.client.chat().create_stream(request).await { + Ok(x) => x, + Err(e) => { + warn!("Failed to create completion request {:?}", e); + return; + } + }; + + for await x in s { + match x { + Ok(x) => { + yield x.choices[0].delta.content.clone().unwrap_or_default(); + }, + Err(e) => { + warn!("Failed to stream response: {}", e); + break; + } + }; + } + }; + + Ok(Box::pin(s)) + } +} diff --git a/crates/tabby-common/src/api/mod.rs b/crates/tabby-common/src/api/mod.rs index 692fe065b263..f66f5c688fd0 100644 --- a/crates/tabby-common/src/api/mod.rs +++ b/crates/tabby-common/src/api/mod.rs @@ -1,3 +1,14 @@ pub mod code; pub mod event; pub mod server_setting; + +pub mod chat { + use serde::{Deserialize, Serialize}; + use utoipa::ToSchema; + + #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] + pub struct Message { + pub role: String, + pub content: String, + } +} diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index d0faa4d5ec2e..6672803a49c3 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow.workspace = true async-stream = { workspace = true } async-trait = { workspace = true } dashmap = "5.5.3" diff --git a/crates/tabby-inference/src/chat.rs b/crates/tabby-inference/src/chat.rs new file mode 100644 index 000000000000..f829846a72be --- /dev/null +++ b/crates/tabby-inference/src/chat.rs @@ -0,0 +1,56 @@ +use anyhow::Result; +use async_stream::stream; +use async_trait::async_trait; +use derive_builder::Builder; +use futures::stream::BoxStream; +use tabby_common::api::chat::Message; + +use crate::{TextGenerationOptions, TextGenerationOptionsBuilder, TextGenerationStream}; + +#[derive(Builder, Debug)] +pub struct ChatCompletionOptions { + #[builder(default = "0.1")] + pub sampling_temperature: f32, + + #[builder(default = "TextGenerationOptions::default_seed()")] + pub seed: u64, +} + +#[async_trait] +pub trait ChatCompletionStream: Sync + Send { + async fn chat_completion( + &self, + messages: &[Message], + options: ChatCompletionOptions, + ) -> Result>; +} + +pub trait ChatPromptBuilder { + fn build_chat_prompt(&self, messages: &[Message]) -> Result; +} + +#[async_trait] +impl ChatCompletionStream for T { + async fn chat_completion( + &self, + messages: &[Message], + options: ChatCompletionOptions, + ) -> Result> { + let options = TextGenerationOptionsBuilder::default() + .max_input_length(2048) + .max_decoding_length(1920) + .seed(options.seed) + .sampling_temperature(options.sampling_temperature) + .build()?; + + let prompt = self.build_chat_prompt(messages)?; + + let s = stream! { + for await content in self.generate(&prompt, options).await { + yield content + } + }; + + Ok(Box::pin(s)) + } +} diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 1b0dd5624930..65564129b778 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -1,4 +1,5 @@ //! Lays out the abstract definition of a text generation model, and utilities for encodings. +pub mod chat; mod decoding; mod imp; @@ -19,7 +20,7 @@ pub struct TextGenerationOptions { #[builder(default = "0.1")] pub sampling_temperature: f32, - #[builder(default = "0")] + #[builder(default = "TextGenerationOptions::default_seed()")] pub seed: u64, #[builder(default = "None")] diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 316b950435d5..0f9a5a14d986 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -16,6 +16,7 @@ use opentelemetry::{ }; use opentelemetry_otlp::WithExportConfig; use tabby_common::config::{Config, ConfigRepositoryAccess}; +use tracing::level_filters::LevelFilter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; #[derive(Parser)] @@ -210,10 +211,14 @@ fn init_logging(otlp_endpoint: Option) { }; } - let env_filter = EnvFilter::from_default_env() - .add_directive("tabby=info".parse().unwrap()) - .add_directive("axum_tracing_opentelemetry=info".parse().unwrap()) - .add_directive("otel=debug".parse().unwrap()); + let mut dirs = "tabby=info,axum_tracing_opentelemetry=info,otel=debug".to_owned(); + if let Ok(env) = std::env::var(EnvFilter::DEFAULT_ENV) { + dirs = format!("{dirs},{env}") + }; + + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::WARN.into()) + .parse_lossy(dirs); tracing_subscriber::registry() .with(layers) diff --git a/crates/tabby/src/routes/chat.rs b/crates/tabby/src/routes/chat.rs index 0e5286e6e338..57fa72f4f13b 100644 --- a/crates/tabby/src/routes/chat.rs +++ b/crates/tabby/src/routes/chat.rs @@ -33,14 +33,6 @@ pub async fn chat_completions( Json(request): Json, ) -> Response { let stream = state.generate(request).await; - let stream = match stream { - Ok(s) => s, - Err(_) => { - let mut response = StreamBody::default().into_response(); - *response.status_mut() = hyper::StatusCode::UNPROCESSABLE_ENTITY; - return response; - } - }; let s = stream.map(|chunk| match serde_json::to_string(&chunk) { Ok(s) => Ok(format!("data: {s}\n\n")), Err(e) => Err(anyhow::Error::from(e)), diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index d3ed16462ffc..a760fba95aa8 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -21,7 +21,8 @@ use utoipa_swagger_ui::SwaggerUi; use crate::{ routes::{self, run_app}, services::{ - chat::{self, create_chat_service}, + chat, + chat::create_chat_service, code::create_code_search, completion::{self, create_completion_service}, event::create_logger, @@ -61,7 +62,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi chat::ChatCompletionRequest, chat::ChatCompletionChoice, chat::ChatCompletionDelta, - chat::Message, + api::chat::Message, chat::ChatCompletionChunk, health::HealthState, health::Version, diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index e6e376be9817..5a36a6ad4a12 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -1,20 +1,19 @@ -mod chat_prompt; - use std::sync::Arc; use async_stream::stream; -use chat_prompt::ChatPromptBuilder; use futures::stream::BoxStream; use serde::{Deserialize, Serialize}; -use tabby_common::api::event::{Event, EventLogger}; -use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; -use thiserror::Error; -use tracing::debug; +use tabby_common::api::{ + chat::Message, + event::{Event, EventLogger}, +}; +use tabby_inference::chat::{ChatCompletionOptionsBuilder, ChatCompletionStream}; +use tracing::warn; use utoipa::ToSchema; use uuid::Uuid; use super::model; -use crate::{fatal, Device}; +use crate::Device; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ @@ -30,18 +29,6 @@ pub struct ChatCompletionRequest { seed: Option, } -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct Message { - role: String, - content: String, -} - -#[derive(Error, Debug)] -pub enum CompletionError { - #[error("failed to format prompt")] - MiniJinja(#[from] minijinja::Error), -} - #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct ChatCompletionChunk { id: String, @@ -55,7 +42,9 @@ pub struct ChatCompletionChunk { #[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] pub struct ChatCompletionChoice { index: usize, + #[serde(skip_serializing_if = "Option::is_none")] logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] finish_reason: Option, delta: ChatCompletionDelta, } @@ -84,72 +73,52 @@ impl ChatCompletionChunk { } pub struct ChatService { - engine: Arc, + engine: Arc, logger: Arc, - prompt_builder: ChatPromptBuilder, } impl ChatService { - fn new( - engine: Arc, - logger: Arc, - chat_template: String, - ) -> Self { - Self { - engine, - logger, - prompt_builder: ChatPromptBuilder::new(chat_template), - } - } - - fn text_generation_options(temperature: Option, seed: u64) -> TextGenerationOptions { - let mut builder = TextGenerationOptionsBuilder::default(); - builder - .max_input_length(2048) - .max_decoding_length(1920) - .seed(seed); - if let Some(temperature) = temperature { - builder.sampling_temperature(temperature); - } - builder - .build() - .expect("Failed to create text generation options") + fn new(engine: Arc, logger: Arc) -> Self { + Self { engine, logger } } pub async fn generate<'a>( self: Arc, request: ChatCompletionRequest, - ) -> Result, CompletionError> { - let mut event_output = String::new(); - let event_input = convert_messages(&request.messages); - - let prompt = self.prompt_builder.build(&request.messages)?; - let options = Self::text_generation_options( - request.temperature, - request - .seed - .unwrap_or_else(TextGenerationOptions::default_seed), - ); - let created = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Must be able to read system clock") - .as_secs(); - let id = format!("chatcmpl-{}", Uuid::new_v4()); - - debug!("PROMPT: {}", prompt); + ) -> BoxStream<'a, ChatCompletionChunk> { + let mut output = String::new(); + let options = ChatCompletionOptionsBuilder::default() + .build() + .expect("Failed to create ChatCompletionOptions"); let s = stream! { - for await (streaming, content) in self.engine.generate_stream(&prompt, options).await { - if streaming { - event_output.push_str(&content); - yield ChatCompletionChunk::new(content, id.clone(), created, false) + let s = match self.engine.chat_completion(&request.messages, options).await { + Ok(x) => x, + Err(e) => { + warn!("Failed to start chat completion: {:?}", e); + return; } + }; + + let created = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("Must be able to read system clock") + .as_secs(); + + let completion_id = format!("chatcmpl-{}", Uuid::new_v4()); + for await content in s { + output.push_str(&content); + yield ChatCompletionChunk::new(content, completion_id.clone(), created, false); } - yield ChatCompletionChunk::new("".into(), id.clone(), created, true); + yield ChatCompletionChunk::new(String::default(), completion_id.clone(), created, true); - self.logger.log(Event::ChatCompletion { completion_id: id, input: event_input, output: create_assistant_message(event_output) }); + self.logger.log(Event::ChatCompletion { + completion_id, + input: convert_messages(&request.messages), + output: create_assistant_message(output) + }); }; - Ok(Box::pin(s)) + Box::pin(s) } } @@ -176,12 +145,7 @@ pub async fn create_chat_service( device: &Device, parallelism: u8, ) -> ChatService { - let (engine, model::PromptInfo { chat_template, .. }) = - model::load_text_generation(model, device, parallelism).await; - - let Some(chat_template) = chat_template else { - fatal!("Chat model requires specifying prompt template"); - }; + let engine = model::load_chat_completion(model, device, parallelism).await; - ChatService::new(engine, logger, chat_template) + ChatService::new(engine, logger) } diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 507e6cef2e23..e39fc461c2fa 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -209,17 +209,19 @@ impl CompletionService { fn text_generation_options( language: &str, temperature: Option, - seed: u64, + seed: Option, ) -> TextGenerationOptions { let mut builder = TextGenerationOptionsBuilder::default(); builder .max_input_length(1024 + 512) .max_decoding_length(128) - .seed(seed) .language(Some(get_language(language))); if let Some(temperature) = temperature { builder.sampling_temperature(temperature); } + if let Some(seed) = seed { + builder.seed(seed); + } builder .build() .expect("Failed to create text generation options") @@ -231,13 +233,8 @@ impl CompletionService { ) -> Result { let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let language = request.language_or_unknown(); - let options = Self::text_generation_options( - language.as_str(), - request.temperature, - request - .seed - .unwrap_or_else(TextGenerationOptions::default_seed), - ); + let options = + Self::text_generation_options(language.as_str(), request.temperature, request.seed); let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() { (prompt, None, vec![]) diff --git a/crates/tabby/src/services/chat/chat_prompt.rs b/crates/tabby/src/services/model/chat.rs similarity index 54% rename from crates/tabby/src/services/chat/chat_prompt.rs rename to crates/tabby/src/services/model/chat.rs index f3873dfad466..909e0e62948a 100644 --- a/crates/tabby/src/services/chat/chat_prompt.rs +++ b/crates/tabby/src/services/model/chat.rs @@ -1,8 +1,16 @@ -use minijinja::{context, Environment}; +use std::sync::Arc; -use super::{CompletionError, Message}; +use anyhow::Result; +use async_stream::stream; +use futures::stream::BoxStream; +use minijinja::{context, Environment}; +use tabby_common::api::chat::Message; +use tabby_inference::{ + chat::{self, ChatCompletionStream}, + TextGeneration, TextGenerationOptions, TextGenerationStream, +}; -pub struct ChatPromptBuilder { +struct ChatPromptBuilder { env: Environment<'static>, } @@ -16,13 +24,55 @@ impl ChatPromptBuilder { Self { env } } - pub fn build(&self, messages: &[Message]) -> Result { + pub fn build(&self, messages: &[Message]) -> Result { + // System prompt is not supported for TextGenerationStream backed chat. + let messages = messages + .iter() + .filter(|x| x.role != "system") + .collect::>(); Ok(self.env.get_template("prompt")?.render(context!( messages => messages ))?) } } +struct ChatCompletionImpl { + engine: Arc, + prompt_builder: ChatPromptBuilder, +} + +impl chat::ChatPromptBuilder for ChatCompletionImpl { + fn build_chat_prompt(&self, messages: &[Message]) -> anyhow::Result { + self.prompt_builder.build(messages) + } +} + +#[async_trait::async_trait] +impl TextGenerationStream for ChatCompletionImpl { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> BoxStream { + let prompt = prompt.to_owned(); + let s = stream! { + for await (streaming, text) in self.engine.generate_stream(&prompt, options).await { + if streaming { + yield text; + } + } + }; + + Box::pin(s) + } +} + +pub fn make_chat_completion( + engine: Arc, + prompt_template: String, +) -> impl ChatCompletionStream { + ChatCompletionImpl { + engine, + prompt_builder: ChatPromptBuilder::new(prompt_template), + } +} + #[cfg(test)] mod tests { use super::*; @@ -47,15 +97,4 @@ mod tests { ]; assert_eq!(builder.build(&messages).unwrap(), "[INST] What is tail recursion? [/INST]It's a kind of optimization in compiler? [INST] Could you share more details? [/INST]") } - - #[test] - #[should_panic] - fn test_it_panic() { - let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned()); - let messages = vec![Message { - role: "system".to_owned(), - content: "system".to_owned(), - }]; - builder.build(&messages).unwrap(); - } } diff --git a/crates/tabby/src/services/model.rs b/crates/tabby/src/services/model/mod.rs similarity index 73% rename from crates/tabby/src/services/model.rs rename to crates/tabby/src/services/model/mod.rs index cd4d2ee1513e..ba73dfde5264 100644 --- a/crates/tabby/src/services/model.rs +++ b/crates/tabby/src/services/model/mod.rs @@ -1,3 +1,5 @@ +mod chat; + use std::{fs, path::PathBuf, sync::Arc}; use serde::Deserialize; @@ -6,11 +8,33 @@ use tabby_common::{ terminal::{HeaderFormat, InfoMessage}, }; use tabby_download::download_model; -use tabby_inference::{make_text_generation, TextGeneration}; +use tabby_inference::{ + chat::ChatCompletionStream, make_text_generation, TextGeneration, TextGenerationStream, +}; use tracing::info; use crate::{fatal, Device}; +pub async fn load_chat_completion( + model_id: &str, + device: &Device, + parallelism: u8, +) -> Arc { + #[cfg(feature = "experimental-http")] + if device == &Device::ExperimentalHttp { + return http_api_bindings::create_chat(model_id); + } + + let (engine, PromptInfo { chat_template, .. }) = + load_text_generation(model_id, device, parallelism).await; + + let Some(chat_template) = chat_template else { + fatal!("Chat model requires specifying prompt template"); + }; + + Arc::new(chat::make_chat_completion(engine, chat_template)) +} + pub async fn load_text_generation( model_id: &str, device: &Device, @@ -37,7 +61,7 @@ pub async fn load_text_generation( parallelism, ); let engine_info = PromptInfo::read(path.join("tabby.json")); - (Arc::new(engine), engine_info) + (Arc::new(make_text_generation(engine)), engine_info) } else { let (registry, name) = parse_model_id(model_id); let registry = ModelRegistry::new(registry).await; @@ -45,7 +69,7 @@ pub async fn load_text_generation( let model_info = registry.get_model_info(name); let engine = create_ggml_engine(device, &model_path, parallelism); ( - Arc::new(engine), + Arc::new(make_text_generation(engine)), PromptInfo { prompt_template: model_info.prompt_template.clone(), chat_template: model_info.chat_template.clone(), @@ -67,7 +91,11 @@ impl PromptInfo { } } -fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration { +fn create_ggml_engine( + device: &Device, + model_path: &str, + parallelism: u8, +) -> impl TextGenerationStream { if !device.ggml_use_gpu() { InfoMessage::new( "CPU Device", @@ -85,7 +113,7 @@ fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> imp .build() .expect("Failed to create llama text generation options"); - make_text_generation(llama_cpp_bindings::LlamaTextGeneration::new(options)) + llama_cpp_bindings::LlamaTextGeneration::new(options) } pub async fn download_model_if_needed(model: &str) { From 4c5366e28e0e47e7ce898106a78fc5fcb1775772 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 25 Feb 2024 16:14:10 -0800 Subject: [PATCH 11/29] fix(golden): fix chat golden test (#1544) --- crates/tabby/src/services/chat.rs | 17 ++++++++++++++--- crates/tabby/src/services/completion.rs | 12 ++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index 5a36a6ad4a12..8b728916e221 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -87,9 +87,20 @@ impl ChatService { request: ChatCompletionRequest, ) -> BoxStream<'a, ChatCompletionChunk> { let mut output = String::new(); - let options = ChatCompletionOptionsBuilder::default() - .build() - .expect("Failed to create ChatCompletionOptions"); + + let options = { + let mut builder = ChatCompletionOptionsBuilder::default(); + request.temperature.inspect(|x| { + builder.sampling_temperature(*x); + }); + request.seed.inspect(|x| { + builder.seed(*x); + }); + builder + .build() + .expect("Failed to create ChatCompletionOptions") + }; + let s = stream! { let s = match self.engine.chat_completion(&request.messages, options).await { Ok(x) => x, diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index e39fc461c2fa..02656819e656 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -216,12 +216,12 @@ impl CompletionService { .max_input_length(1024 + 512) .max_decoding_length(128) .language(Some(get_language(language))); - if let Some(temperature) = temperature { - builder.sampling_temperature(temperature); - } - if let Some(seed) = seed { - builder.seed(seed); - } + temperature.inspect(|x| { + builder.sampling_temperature(*x); + }); + seed.inspect(|x| { + builder.seed(*x); + }); builder .build() .expect("Failed to create text generation options") From b47038c2667e802855d72905515096527506f1e7 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 25 Feb 2024 16:21:17 -0800 Subject: [PATCH 12/29] ci: fix misc ci configures (#1545) --- .github/workflows/test-rust.yml | 2 +- Makefile | 2 +- crates/llama-cpp-bindings/src/lib.rs | 5 ----- crates/tabby/src/main.rs | 1 - ee/tabby-webserver/src/repositories/resolve.rs | 4 ++-- 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 51fa5240e1c4..ecb27cb52421 100644 --- a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -28,7 +28,7 @@ concurrency: cancel-in-progress: true env: - RUST_TOOLCHAIN: 1.73.0 + RUST_TOOLCHAIN: 1.76.0 jobs: tests: diff --git a/Makefile b/Makefile index fbc2912da10f..5bed63ab0f8a 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ fix: cargo machete --fix || true cargo +nightly fmt - cargo +nightly clippy --fix --allow-dirty --allow-staged + cargo clippy --fix --allow-dirty --allow-staged fix-ui: cd ee/tabby-ui && yarn format:write && yarn lint:fix diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index 8d1e09db0cfa..a6d78279233f 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -12,11 +12,6 @@ use tabby_inference::{TextGenerationOptions, TextGenerationStream}; #[cxx::bridge(namespace = "llama")] mod ffi { - struct StepOutput { - request_id: u32, - text: String, - } - extern "Rust" { type LlamaInitRequest; fn id(&self) -> u32; diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 0f9a5a14d986..fb454fb326d9 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,4 +1,3 @@ -//! Core tabby functionality. Defines primary API and CLI behavior. mod routes; mod services; diff --git a/ee/tabby-webserver/src/repositories/resolve.rs b/ee/tabby-webserver/src/repositories/resolve.rs index 8ef941e0cbd5..29e01189b428 100644 --- a/ee/tabby-webserver/src/repositories/resolve.rs +++ b/ee/tabby-webserver/src/repositories/resolve.rs @@ -58,7 +58,7 @@ impl RepositoryCache { .collect(); let mut repository_lookup = self.repository_lookup.write().unwrap(); debug!("Reloading repositoriy metadata..."); - *repository_lookup = load_meta(&new_repositories); + *repository_lookup = load_meta(new_repositories); Ok(()) } @@ -149,7 +149,7 @@ impl From for RepositoryMeta { } } -fn load_meta(repositories: &Vec) -> HashMap { +fn load_meta(repositories: Vec) -> HashMap { let mut dataset = HashMap::new(); // Construct map of String -> &RepositoryConfig for lookup let repo_conf = repositories From 98877000b374d6e6a027cbd96709564fa1f40dd9 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 25 Feb 2024 20:25:54 -0800 Subject: [PATCH 13/29] feat(webserver): check seat limit in `create_invitation` and toggle user active. (#1547) * update * fix unit test * add test case * update * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-db/src/users.rs | 1 + ee/tabby-webserver/src/schema/license.rs | 45 ++++++------ ee/tabby-webserver/src/schema/mod.rs | 7 +- ee/tabby-webserver/src/service/auth.rs | 87 ++++++++++++++++++----- ee/tabby-webserver/src/service/license.rs | 4 +- ee/tabby-webserver/src/service/mod.rs | 7 +- 6 files changed, 105 insertions(+), 46 deletions(-) diff --git a/ee/tabby-db/src/users.rs b/ee/tabby-db/src/users.rs index 708969972fc5..da0b3d2f14af 100644 --- a/ee/tabby-db/src/users.rs +++ b/ee/tabby-db/src/users.rs @@ -211,6 +211,7 @@ impl DbConn { Ok(()) } + // FIXME(boxbeam): Revisit if a caching layer should be put into DbConn for this query in future. pub async fn count_active_users(&self) -> Result { let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active;") .fetch_one(&self.pool) diff --git a/ee/tabby-webserver/src/schema/license.rs b/ee/tabby-webserver/src/schema/license.rs index 4b69ab973b41..edc22ac2007e 100644 --- a/ee/tabby-webserver/src/schema/license.rs +++ b/ee/tabby-webserver/src/schema/license.rs @@ -23,20 +23,6 @@ pub enum LicenseStatus { SeatsExceeded, } -impl From for CoreError { - fn from(val: LicenseStatus) -> Self { - match val { - LicenseStatus::Ok => panic!("License is valid, shouldn't be converted to CoreError"), - LicenseStatus::Expired => { - CoreError::InvalidLicense("Your enterprise license is expired") - } - LicenseStatus::SeatsExceeded => CoreError::InvalidLicense( - "You have more active users than seats included in your license", - ), - } - } -} - #[derive(GraphQLObject)] pub struct LicenseInfo { pub r#type: LicenseType, @@ -64,7 +50,7 @@ impl LicenseInfo { } } - pub fn ensure_seat_limit(mut self) -> Self { + pub fn guard_seat_limit(mut self) -> Self { let seats = self.seats as usize; self.seats = match self.r#type { LicenseType::Community => { @@ -76,6 +62,15 @@ impl LicenseInfo { self } + pub fn ensure_available_seats(&self, num_seats: i32) -> Result<()> { + self.ensure_valid_license()?; + if (self.seats_used + num_seats) > self.seats { + return Err(CoreError::InvalidLicense( + "No sufficient seats under current license", + )); + } + Ok(()) + } } #[async_trait] @@ -85,21 +80,29 @@ pub trait LicenseService: Send + Sync { } pub trait IsLicenseValid { - fn is_license_valid(&self) -> bool; + fn ensure_valid_license(&self) -> Result<()>; } impl IsLicenseValid for LicenseInfo { - fn is_license_valid(&self) -> bool { - self.status == LicenseStatus::Ok + fn ensure_valid_license(&self) -> Result<()> { + match self.status { + LicenseStatus::Expired => Err(CoreError::InvalidLicense( + "Your enterprise license is expired", + )), + LicenseStatus::SeatsExceeded => Err(CoreError::InvalidLicense( + "You have more active users than seats included in your license", + )), + LicenseStatus::Ok => Ok(()), + } } } impl IsLicenseValid for std::result::Result { - fn is_license_valid(&self) -> bool { + fn ensure_valid_license(&self) -> Result<()> { if let Ok(x) = self { - x.is_license_valid() + x.ensure_valid_license() } else { - false + Err(CoreError::InvalidLicense("No valid license configured")) } } } diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 6bb143c5d211..9cebf20f57d7 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -32,7 +32,7 @@ use self::{ RequestPasswordResetEmailInput, UpdateOAuthCredentialInput, }, email::{EmailService, EmailSetting, EmailSettingInput}, - license::{LicenseInfo, LicenseService, LicenseStatus, LicenseType}, + license::{IsLicenseValid, LicenseInfo, LicenseService, LicenseType}, repository::{Repository, RepositoryService}, setting::{ NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, @@ -129,10 +129,7 @@ async fn check_license(ctx: &Context, license_type: &[LicenseType]) -> Result<() )); } - match license.status { - LicenseStatus::Ok => Ok(()), - LicenseStatus::Expired | LicenseStatus::SeatsExceeded => Err(license.status.into()), - } + license.ensure_valid_license() } #[derive(Default)] diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index b9e11824bb59..f7e551c57504 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -24,7 +24,7 @@ use crate::{ UpdateOAuthCredentialInput, User, }, email::EmailService, - license::{IsLicenseValid, LicenseService}, + license::LicenseService, setting::SettingService, CoreError, Result, }, @@ -237,9 +237,7 @@ impl AuthenticationService for AuthenticationServiceImpl { async fn create_invitation(&self, email: String) -> Result { let license = self.license.read_license().await?; - if !license.is_license_valid() { - return Err(license.status.into()); - }; + license.ensure_available_seats(1)?; let invitation = self.db.create_invitation(email.clone()).await?; let email_sent = self @@ -384,6 +382,12 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn update_user_active(&self, id: &ID, active: bool) -> Result<()> { + if active { + // Check there's avaiable seat if switching user to active. + let license = self.license.read_license().await?; + license.ensure_available_seats(1)?; + } + let id = id.as_rowid()?; let user = self.db.get_user(id).await?.context("User doesn't exits")?; if user.is_owner() { @@ -475,16 +479,46 @@ fn password_verify(raw: &str, hash: &str) -> bool { #[cfg(test)] mod tests { - struct MockLicenseService(LicenseStatus); + struct MockLicenseService { + status: LicenseStatus, + seats: i32, + seats_used: i32, + } + + impl MockLicenseService { + fn team() -> Self { + Self { + status: LicenseStatus::Ok, + seats: 5, + seats_used: 1, + } + } + + fn team_with_seats(seats: i32) -> Self { + Self { + status: LicenseStatus::Ok, + seats, + seats_used: 1, + } + } + + fn invalid() -> Self { + Self { + status: LicenseStatus::Expired, + seats: 5, + seats_used: 1, + } + } + } #[async_trait] impl LicenseService for MockLicenseService { async fn read_license(&self) -> Result { Ok(LicenseInfo { r#type: crate::schema::license::LicenseType::Team, - status: self.0.clone(), - seats: 1, - seats_used: 1, + status: self.status.clone(), + seats: self.seats, + seats_used: self.seats_used, issued_at: Some(Utc::now()), expires_at: Some(Utc::now()), }) @@ -495,22 +529,23 @@ mod tests { } } - async fn test_authentication_service() -> AuthenticationServiceImpl { + async fn test_authentication_service_with_license( + license: Arc, + ) -> AuthenticationServiceImpl { let db = DbConn::new_in_memory().await.unwrap(); AuthenticationServiceImpl { db: db.clone(), mail: Arc::new(new_email_service(db).await.unwrap()), - license: Arc::new(MockLicenseService(LicenseStatus::Ok)), + license, } } + async fn test_authentication_service() -> AuthenticationServiceImpl { + test_authentication_service_with_license(Arc::new(MockLicenseService::team())).await + } + async fn test_authentication_service_without_valid_license() -> AuthenticationServiceImpl { - let db = DbConn::new_in_memory().await.unwrap(); - AuthenticationServiceImpl { - db: db.clone(), - mail: Arc::new(new_email_service(db).await.unwrap()), - license: Arc::new(MockLicenseService(LicenseStatus::Expired)), - } + test_authentication_service_with_license(Arc::new(MockLicenseService::invalid())).await } async fn test_authentication_service_with_mail() -> (AuthenticationServiceImpl, TestEmailServer) @@ -520,7 +555,7 @@ mod tests { let service = AuthenticationServiceImpl { db: db.clone(), mail: Arc::new(smtp.create_test_email_service(db).await), - license: Arc::new(MockLicenseService(LicenseStatus::Ok)), + license: Arc::new(MockLicenseService::team()), }; (service, smtp) } @@ -1027,4 +1062,22 @@ mod tests { Err(CoreError::InvalidLicense(_)) ) } + + #[tokio::test] + async fn test_create_invitation_without_sufficient_seats() { + let service = test_authentication_service_with_license(Arc::new( + MockLicenseService::team_with_seats(2), + )) + .await; + assert_matches!(service.create_invitation("abc.com".into()).await, Ok(_)); + + let service = test_authentication_service_with_license(Arc::new( + MockLicenseService::team_with_seats(1), + )) + .await; + assert_matches!( + service.create_invitation("abc.com".into()).await, + Err(CoreError::InvalidLicense(_)) + ) + } } diff --git a/ee/tabby-webserver/src/service/license.rs b/ee/tabby-webserver/src/service/license.rs index 56fcd335a1f4..2881b4be8330 100644 --- a/ee/tabby-webserver/src/service/license.rs +++ b/ee/tabby-webserver/src/service/license.rs @@ -96,7 +96,7 @@ impl LicenseServiceImpl { issued_at: None, expires_at: None, } - .ensure_seat_limit()) + .guard_seat_limit()) } } @@ -128,7 +128,7 @@ fn license_info_from_raw(raw: LicenseJWTPayload, seats_used: usize) -> Result
  • Date: Sun, 25 Feb 2024 21:30:44 -0800 Subject: [PATCH 14/29] feat(webserver): add admin user count limit (#1548) * feat(webserver): add admin user count limit * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-db/src/users.rs | 7 +++++++ ee/tabby-webserver/src/schema/license.rs | 24 +++++++++++++++++++++--- ee/tabby-webserver/src/service/auth.rs | 16 ++++++++++++++-- ee/tabby-webserver/src/service/mod.rs | 2 +- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/ee/tabby-db/src/users.rs b/ee/tabby-db/src/users.rs index da0b3d2f14af..6eb4529f5dd3 100644 --- a/ee/tabby-db/src/users.rs +++ b/ee/tabby-db/src/users.rs @@ -218,6 +218,13 @@ impl DbConn { .await?; Ok(users as usize) } + + pub async fn count_active_admin_users(&self) -> Result { + let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active and is_admin;") + .fetch_one(&self.pool) + .await?; + Ok(users as usize) + } } fn generate_auth_token() -> String { diff --git a/ee/tabby-webserver/src/schema/license.rs b/ee/tabby-webserver/src/schema/license.rs index edc22ac2007e..e4b4fa90be79 100644 --- a/ee/tabby-webserver/src/schema/license.rs +++ b/ee/tabby-webserver/src/schema/license.rs @@ -45,7 +45,7 @@ impl LicenseInfo { pub fn check_node_limit(&self, num_nodes: usize) -> bool { match self.r#type { LicenseType::Community => false, - LicenseType::Team => num_nodes < 2, + LicenseType::Team => num_nodes <= 2, LicenseType::Enterprise => true, } } @@ -62,15 +62,33 @@ impl LicenseInfo { self } - pub fn ensure_available_seats(&self, num_seats: i32) -> Result<()> { + + pub fn ensure_available_seats(&self, num_new_seats: usize) -> Result<()> { self.ensure_valid_license()?; - if (self.seats_used + num_seats) > self.seats { + if (self.seats_used as usize + num_new_seats) > self.seats as usize { return Err(CoreError::InvalidLicense( "No sufficient seats under current license", )); } Ok(()) } + + pub fn ensure_admin_seats(&self, num_admins: usize) -> Result<()> { + self.ensure_valid_license()?; + let num_admin_seats = match self.r#type { + LicenseType::Community => 1, + LicenseType::Team => 3, + LicenseType::Enterprise => usize::MAX, + }; + + if num_admins > num_admin_seats { + return Err(CoreError::InvalidLicense( + "No sufficient admin seats under the license", + )); + } + + Ok(()) + } } #[async_trait] diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index f7e551c57504..cb674acb265f 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -218,6 +218,12 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn update_user_role(&self, id: &ID, is_admin: bool) -> Result<()> { + if is_admin { + let license = self.license.read_license().await?; + let num_admins = self.db.count_active_admin_users().await?; + license.ensure_admin_seats(num_admins + 1)?; + } + let id = id.as_rowid()?; let user = self.db.get_user(id).await?.context("User doesn't exits")?; if user.is_owner() { @@ -382,9 +388,10 @@ impl AuthenticationService for AuthenticationServiceImpl { } async fn update_user_active(&self, id: &ID, active: bool) -> Result<()> { + let license = self.license.read_license().await?; + if active { - // Check there's avaiable seat if switching user to active. - let license = self.license.read_license().await?; + // Check there's sufficient seat if switching user to active. license.ensure_available_seats(1)?; } @@ -393,6 +400,11 @@ impl AuthenticationService for AuthenticationServiceImpl { if user.is_owner() { return Err(anyhow!("The owner's active status cannot be changed").into()); } + + if user.is_admin { + let num_admins = self.db.count_active_admin_users().await?; + license.ensure_admin_seats(num_admins + 1)?; + } Ok(self.db.update_user_active(id, active).await?) } } diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 065fbf8ddf6b..219de7e48dec 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -163,7 +163,7 @@ impl WorkerService for ServerContext { .await .map_err(|_| RegisterWorkerError::RequiresEnterpriseLicense)?; - if license.check_node_limit(count_workers) { + if license.check_node_limit(count_workers + 1) { return Err(RegisterWorkerError::RequiresEnterpriseLicense); } From 41a948ab7e01264eb63a27f65a90e57b34271355 Mon Sep 17 00:00:00 2001 From: aliang <1098486429@qq.com> Date: Mon, 26 Feb 2024 14:39:56 +0800 Subject: [PATCH 15/29] feat(ui): add license guard (#1520) * feat(ui): add license guard * update * [autofix.ci] apply automated fixes * update * [autofix.ci] apply automated fixes * can pass function as children * useLicense * update --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Meng Zhang --- .../general/components/security-form.tsx | 22 +++--- .../subscription/components/subscription.tsx | 28 ++----- ee/tabby-ui/components/license-guard.tsx | 76 +++++++++++++++++++ ee/tabby-ui/components/ui/hover-card.tsx | 29 +++++++ ee/tabby-ui/lib/hooks/use-license.ts | 27 +++++++ ee/tabby-ui/package.json | 1 + ee/tabby-ui/yarn.lock | 16 ++++ 7 files changed, 167 insertions(+), 32 deletions(-) create mode 100644 ee/tabby-ui/components/license-guard.tsx create mode 100644 ee/tabby-ui/components/ui/hover-card.tsx create mode 100644 ee/tabby-ui/lib/hooks/use-license.ts diff --git a/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx b/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx index 7ba4bc5aa4eb..42f6ae04263e 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/general/components/security-form.tsx @@ -9,6 +9,7 @@ import { useQuery } from 'urql' import * as z from 'zod' import { graphql } from '@/lib/gql/generates' +import { LicenseType } from '@/lib/gql/generates/graphql' import { useMutation } from '@/lib/tabby/gql' import { cn } from '@/lib/utils' import { Button } from '@/components/ui/button' @@ -24,6 +25,7 @@ import { } from '@/components/ui/form' import { IconTrash } from '@/components/ui/icons' import { Input } from '@/components/ui/input' +import { LicenseGuard } from '@/components/license-guard' const updateSecuritySettingMutation = graphql(/* GraphQL */ ` mutation updateSecuritySetting($input: SecuritySettingInput!) { @@ -64,15 +66,9 @@ const SecurityForm: React.FC = ({ onSuccess, defaultValues: propsDefaultValues }) => { - const defaultValues = React.useMemo(() => { - return { - ...(propsDefaultValues || {}) - } - }, [propsDefaultValues]) - const form = useForm>({ resolver: zodResolver(formSchema), - defaultValues + defaultValues: propsDefaultValues }) const { fields, append, remove, update } = useFieldArray({ @@ -201,9 +197,15 @@ const SecurityForm: React.FC = ({
    - + + {({ hasValidLicense }) => { + return ( + + ) + }} +
    diff --git a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx index 06c6c9df8ccd..7d759db17dd5 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx @@ -2,10 +2,9 @@ import { capitalize } from 'lodash-es' import moment from 'moment' -import { useQuery } from 'urql' -import { graphql } from '@/lib/gql/generates' import { LicenseInfo } from '@/lib/gql/generates/graphql' +import { useLicense } from '@/lib/hooks/use-license' import { Skeleton } from '@/components/ui/skeleton' import LoadingWrapper from '@/components/loading-wrapper' import { SubHeader } from '@/components/sub-header' @@ -13,23 +12,8 @@ import { SubHeader } from '@/components/sub-header' import { LicenseForm } from './license-form' import { LicenseTable } from './license-table' -const getLicenseInfo = graphql(/* GraphQL */ ` - query GetLicenseInfo { - license { - type - status - seats - seatsUsed - issuedAt - expiresAt - } - } -`) - export default function Subscription() { - const [{ data, fetching }, reexecuteQuery] = useQuery({ - query: getLicenseInfo - }) + const [{ data, fetching }, reexecuteQuery] = useLicense() const license = data?.license const onUploadLicenseSuccess = () => { reexecuteQuery() @@ -48,10 +32,10 @@ export default function Subscription() { - - - +
    + + +
    } > diff --git a/ee/tabby-ui/components/license-guard.tsx b/ee/tabby-ui/components/license-guard.tsx new file mode 100644 index 000000000000..31b16c23bd25 --- /dev/null +++ b/ee/tabby-ui/components/license-guard.tsx @@ -0,0 +1,76 @@ +import * as React from 'react' +import Link from 'next/link' +import { capitalize } from 'lodash-es' + +import { + GetLicenseInfoQuery, + LicenseStatus, + LicenseType +} from '@/lib/gql/generates/graphql' +import { useLicenseInfo } from '@/lib/hooks/use-license' +import { cn } from '@/lib/utils' +import { buttonVariants } from '@/components/ui/button' +import { + HoverCard, + HoverCardContent, + HoverCardTrigger +} from '@/components/ui/hover-card' + +interface LicenseGuardProps { + licenses: LicenseType[] + children: (params: { + hasValidLicense: boolean + license: GetLicenseInfoQuery['license'] | undefined | null + }) => React.ReactNode +} + +const LicenseGuard: React.FC = ({ licenses, children }) => { + const [open, setOpen] = React.useState(false) + const license = useLicenseInfo() + const hasValidLicense = + !!license && + license.status === LicenseStatus.Ok && + licenses.includes(license.type) + + const onOpenChange = (v: boolean) => { + if (hasValidLicense) return + setOpen(v) + } + + let licenseString = capitalize(licenses[0]) + let licenseText = licenseString + if (licenses.length > 1) { + licenseText = `${licenseString} or higher` + } + + return ( + + +
    + This feature is only available on Tabby’s{' '} + {licenseText} plan. Upgrade to + use this feature. +
    +
    + + Upgrade to {licenseText} + +
    +
    + { + e.preventDefault() + onOpenChange(true) + }} + > +
    + {children({ hasValidLicense, license })} +
    +
    +
    + ) +} +LicenseGuard.displayName = 'LicenseGuard' + +export { LicenseGuard } diff --git a/ee/tabby-ui/components/ui/hover-card.tsx b/ee/tabby-ui/components/ui/hover-card.tsx new file mode 100644 index 000000000000..0ec79708ab32 --- /dev/null +++ b/ee/tabby-ui/components/ui/hover-card.tsx @@ -0,0 +1,29 @@ +'use client' + +import * as React from 'react' +import * as HoverCardPrimitive from '@radix-ui/react-hover-card' + +import { cn } from '@/lib/utils' + +const HoverCard = HoverCardPrimitive.Root + +const HoverCardTrigger = HoverCardPrimitive.Trigger + +const HoverCardContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, align = 'center', sideOffset = 4, ...props }, ref) => ( + +)) +HoverCardContent.displayName = HoverCardPrimitive.Content.displayName + +export { HoverCard, HoverCardTrigger, HoverCardContent } diff --git a/ee/tabby-ui/lib/hooks/use-license.ts b/ee/tabby-ui/lib/hooks/use-license.ts new file mode 100644 index 000000000000..03c333fa284b --- /dev/null +++ b/ee/tabby-ui/lib/hooks/use-license.ts @@ -0,0 +1,27 @@ +import { useQuery } from 'urql' + +import { graphql } from '../gql/generates' + +const getLicenseInfo = graphql(/* GraphQL */ ` + query GetLicenseInfo { + license { + type + status + seats + seatsUsed + issuedAt + expiresAt + } + } +`) + +const useLicense = () => { + return useQuery({ query: getLicenseInfo }) +} + +const useLicenseInfo = () => { + const [{ data }] = useLicense() + return data?.license +} + +export { getLicenseInfo, useLicense, useLicenseInfo } diff --git a/ee/tabby-ui/package.json b/ee/tabby-ui/package.json index 1436b87b1e47..a855a9058c96 100644 --- a/ee/tabby-ui/package.json +++ b/ee/tabby-ui/package.json @@ -30,6 +30,7 @@ "@radix-ui/react-collapsible": "^1.0.3", "@radix-ui/react-dialog": "1.0.4", "@radix-ui/react-dropdown-menu": "^2.0.5", + "@radix-ui/react-hover-card": "^1.0.7", "@radix-ui/react-label": "^2.0.2", "@radix-ui/react-popover": "^1.0.7", "@radix-ui/react-radio-group": "^1.1.3", diff --git a/ee/tabby-ui/yarn.lock b/ee/tabby-ui/yarn.lock index b8f6718120e4..f54ef9b752b6 100644 --- a/ee/tabby-ui/yarn.lock +++ b/ee/tabby-ui/yarn.lock @@ -2017,6 +2017,22 @@ "@radix-ui/react-primitive" "1.0.3" "@radix-ui/react-use-callback-ref" "1.0.1" +"@radix-ui/react-hover-card@^1.0.7": + version "1.0.7" + resolved "https://registry.yarnpkg.com/@radix-ui/react-hover-card/-/react-hover-card-1.0.7.tgz#684bca2504432566357e7157e087051aa3577948" + integrity sha512-OcUN2FU0YpmajD/qkph3XzMcK/NmSk9hGWnjV68p6QiZMgILugusgQwnLSDs3oFSJYGKf3Y49zgFedhGh04k9A== + dependencies: + "@babel/runtime" "^7.13.10" + "@radix-ui/primitive" "1.0.1" + "@radix-ui/react-compose-refs" "1.0.1" + "@radix-ui/react-context" "1.0.1" + "@radix-ui/react-dismissable-layer" "1.0.5" + "@radix-ui/react-popper" "1.1.3" + "@radix-ui/react-portal" "1.0.4" + "@radix-ui/react-presence" "1.0.1" + "@radix-ui/react-primitive" "1.0.3" + "@radix-ui/react-use-controllable-state" "1.0.1" + "@radix-ui/react-id@1.0.1": version "1.0.1" resolved "https://registry.yarnpkg.com/@radix-ui/react-id/-/react-id-1.0.1.tgz#73cdc181f650e4df24f0b6a5b7aa426b912c88c0" From a4a36402477e03c65c11071c6143e6ac6d46eb7b Mon Sep 17 00:00:00 2001 From: Zhiming Ma Date: Mon, 26 Feb 2024 15:28:21 +0800 Subject: [PATCH 16/29] fix(clients): Remove the notification when disconnected, keep only status bar icon. (#1551) --- .../intellijtabby/agent/AgentService.kt | 26 ------------------- clients/vscode/src/TabbyStatusBarItem.ts | 6 ----- 2 files changed, 32 deletions(-) diff --git a/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt b/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt index 108620ead7ab..292cd62168d7 100644 --- a/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt +++ b/clients/intellij/src/main/kotlin/com/tabbyml/intellijtabby/agent/AgentService.kt @@ -181,32 +181,6 @@ class AgentService : Disposable { } } } - - scope.launch { - agent.currentIssue.collect { issueName -> - val notification = when (issueName) { - "connectionFailed" -> Notification( - "com.tabbyml.intellijtabby.notification.warning", - "Cannot connect to Tabby server", - NotificationType.ERROR, - ).apply { - addAction(ActionManager.getInstance().getAction("Tabby.CheckIssueDetail")) - } - - else -> { - invokeLater { - issueNotification?.expire() - } - return@collect - } - } - invokeLater { - issueNotification?.expire() - issueNotification = notification - Notifications.Bus.notify(notification) - } - } - } } private fun createAgentConfig(state: ApplicationSettingsState.State): Agent.Config { diff --git a/clients/vscode/src/TabbyStatusBarItem.ts b/clients/vscode/src/TabbyStatusBarItem.ts index 1557c3a0fa77..0d2aa1d2f36e 100644 --- a/clients/vscode/src/TabbyStatusBarItem.ts +++ b/clients/vscode/src/TabbyStatusBarItem.ts @@ -157,12 +157,6 @@ export class TabbyStatusBarItem { console.debug("Tabby agent issuesUpdated", { event }); const status = agent().getStatus(); this.fsmService.send(status); - if (event.issues.includes("connectionFailed")) { - // Do not show it when initializing - if (status !== "notInitialized") { - notifications.showInformationWhenDisconnected(); - } - } }); } From a84e441721ab322f6e04817af4f7f451d21b7335 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 25 Feb 2024 23:28:42 -0800 Subject: [PATCH 17/29] feat(ui): add license guard at proper locations (#1549) * ui: add license guard at proper locations * fix * [autofix.ci] apply automated fixes * fix preventDefault --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: liangfung <1098486429@qq.com> --- .../sso/components/oauth-credential-form.tsx | 25 +++++++++++------ .../sso/components/oauth-credential-list.tsx | 28 +++++++++++++------ .../team/components/user-role-dialog.tsx | 24 ++++++++++------ ee/tabby-ui/components/license-guard.tsx | 14 ++++++---- ee/tabby-webserver/src/schema/license.rs | 4 +-- 5 files changed, 60 insertions(+), 35 deletions(-) diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx index 60804429f2f3..9af0c3f8e9e0 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-form.tsx @@ -10,7 +10,7 @@ import { useClient, useQuery } from 'urql' import * as z from 'zod' import { graphql } from '@/lib/gql/generates' -import { OAuthProvider } from '@/lib/gql/generates/graphql' +import { LicenseType, OAuthProvider } from '@/lib/gql/generates/graphql' import { useMutation } from '@/lib/tabby/gql' import { cn } from '@/lib/utils' import { @@ -39,6 +39,7 @@ import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group' import { CopyButton } from '@/components/copy-button' +import { LicenseGuard } from '@/components/license-guard' import { oauthCredential } from './oauth-credential-list' @@ -324,15 +325,21 @@ export default function OAuthCredentialForm({ )} - )} - {isNew ? 'Create' : 'Update'} - + diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx index cde077590bd8..820d4bdbe9d3 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/sso/components/oauth-credential-list.tsx @@ -2,17 +2,20 @@ import React from 'react' import Link from 'next/link' +import { useRouter } from 'next/navigation' import { compact, find } from 'lodash-es' import { useQuery } from 'urql' import { graphql } from '@/lib/gql/generates' import { + LicenseType, OAuthCredentialQuery, OAuthProvider } from '@/lib/gql/generates/graphql' -import { buttonVariants } from '@/components/ui/button' +import { Button, buttonVariants } from '@/components/ui/button' import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' import { Skeleton } from '@/components/ui/skeleton' +import { LicenseGuard } from '@/components/license-guard' import { PROVIDER_METAS } from './constant' import { SSOHeader } from './sso-header' @@ -43,6 +46,20 @@ const OAuthCredentialList = () => { return compact([githubData?.oauthCredential, googleData?.oauthCredential]) }, [githubData, googleData]) + const router = useRouter() + const createButton = ( + + {({ hasValidLicense }) => ( + + )} + + ) + if (!credentialList?.length) { return (
    @@ -80,14 +97,7 @@ const OAuthCredentialList = () => { })}
    {credentialList.length < 2 && ( -
    - - Create - -
    +
    {createButton}
    )} ) diff --git a/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx b/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx index 2463f2e6a060..2c5b85825a8a 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/team/components/user-role-dialog.tsx @@ -4,6 +4,7 @@ import React from 'react' import { toast } from 'sonner' import { graphql } from '@/lib/gql/generates/gql' +import { LicenseType } from '@/lib/gql/generates/graphql' import { useMutation } from '@/lib/tabby/gql' import { AlertDialog, @@ -17,6 +18,7 @@ import { } from '@/components/ui/alert-dialog' import { buttonVariants } from '@/components/ui/button' import { IconSpinner } from '@/components/ui/icons' +import { LicenseGuard } from '@/components/license-guard' const updateUserRoleMutation = graphql(/* GraphQL */ ` mutation updateUserRole($id: ID!, $isAdmin: Boolean!) { @@ -81,16 +83,20 @@ export const UpdateUserRoleDialog: React.FC = ({ Cancel - - {isSubmitting && ( - + + {({ hasValidLicense }) => ( + + {isSubmitting && ( + + )} + Confirm + )} - Confirm - + diff --git a/ee/tabby-ui/components/license-guard.tsx b/ee/tabby-ui/components/license-guard.tsx index 31b16c23bd25..5da745e80c59 100644 --- a/ee/tabby-ui/components/license-guard.tsx +++ b/ee/tabby-ui/components/license-guard.tsx @@ -39,29 +39,31 @@ const LicenseGuard: React.FC = ({ licenses, children }) => { let licenseString = capitalize(licenses[0]) let licenseText = licenseString - if (licenses.length > 1) { - licenseText = `${licenseString} or higher` + if (licenses.length == 2) { + licenseText = `${capitalize(licenses[0])} or ${capitalize(licenses[1])}` } return (
    - This feature is only available on Tabby’s{' '} + This feature is only available on Tabby's{' '} {licenseText} plan. Upgrade to use this feature.
    - Upgrade to {licenseText} + Upgrade to {licenseString}
    { - e.preventDefault() - onOpenChange(true) + if (!hasValidLicense) { + e.preventDefault() + onOpenChange(true) + } }} >
    diff --git a/ee/tabby-webserver/src/schema/license.rs b/ee/tabby-webserver/src/schema/license.rs index e4b4fa90be79..edd54ddf4203 100644 --- a/ee/tabby-webserver/src/schema/license.rs +++ b/ee/tabby-webserver/src/schema/license.rs @@ -54,9 +54,9 @@ impl LicenseInfo { let seats = self.seats as usize; self.seats = match self.r#type { LicenseType::Community => { - std::cmp::max(seats, Self::seat_limits_for_community_license()) + std::cmp::min(seats, Self::seat_limits_for_community_license()) } - LicenseType::Team => std::cmp::max(seats, Self::seat_limits_for_team_license()), + LicenseType::Team => std::cmp::min(seats, Self::seat_limits_for_team_license()), LicenseType::Enterprise => seats, } as i32; From 2cadf8d462b06e8b66a81c86e5c6cd3089f2a96e Mon Sep 17 00:00:00 2001 From: aliang <1098486429@qq.com> Date: Tue, 27 Feb 2024 01:54:46 +0800 Subject: [PATCH 18/29] fix(ui): display the correct external url (#1552) * fix(ui): external url display * table-fixed * update * [autofix.ci] apply automated fixes * rename * update --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-ui/app/(dashboard)/page.tsx | 15 ++++---- .../git/components/repository-table.tsx | 14 +++++--- .../settings/general/components/general.tsx | 6 ---- .../general/components/network-form.tsx | 14 +++----- .../general/components/security-form.tsx | 3 +- .../team/components/invitation-table.tsx | 10 +++--- .../team/components/user-role-dialog.tsx | 7 ++-- ee/tabby-ui/app/layout.tsx | 2 +- ee/tabby-ui/lib/hooks/use-network-setting.tsx | 36 +++++++++++++++++++ 9 files changed, 68 insertions(+), 39 deletions(-) create mode 100644 ee/tabby-ui/lib/hooks/use-network-setting.tsx diff --git a/ee/tabby-ui/app/(dashboard)/page.tsx b/ee/tabby-ui/app/(dashboard)/page.tsx index 4d12bea35359..02ef926f7d13 100644 --- a/ee/tabby-ui/app/(dashboard)/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/page.tsx @@ -1,11 +1,11 @@ 'use client' -import { useEffect, useState } from 'react' import { noop } from 'lodash-es' import { useQuery } from 'urql' import { graphql } from '@/lib/gql/generates' import { useHealth } from '@/lib/hooks/use-health' +import { useExternalURL } from '@/lib/hooks/use-network-setting' import { useMutation } from '@/lib/tabby/gql' import { Button } from '@/components/ui/button' import { @@ -46,10 +46,7 @@ const resetUserAuthTokenDocument = graphql(/* GraphQL */ ` function MainPanel() { const { data: healthInfo } = useHealth() const [{ data }, reexecuteQuery] = useQuery({ query: meQuery }) - const [origin, setOrigin] = useState('') - useEffect(() => { - setOrigin(new URL(window.location.href).origin) - }, []) + const externalUrl = useExternalURL() const resetUserAuthToken = useMutation(resetUserAuthTokenDocument, { onCompleted: () => reexecuteQuery() @@ -65,8 +62,12 @@ function MainPanel() { - - + + diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx index 4023f7ffab0c..a12d903a923f 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/git/components/repository-table.tsx @@ -118,12 +118,12 @@ export default function RepositoryTable() {
    {initialized ? ( <> - +
    Name - Git URL - + Git URL + @@ -138,8 +138,12 @@ export default function RepositoryTable() { {currentPageRepos?.map(x => { return ( - {x.node.name} - {x.node.gitUrl} + + {x.node.name} + + + {x.node.gitUrl} +
    + + )} + + + Are you absolutely sure? + + This action cannot be undone. It will reset the current + license. + + + + Cancel + + {isReseting && ( + + )} + Yes, reset it + + + + +
    From 0e6eec4a6b26dc32c75c04a0fff5fe4b24cc811f Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Mon, 26 Feb 2024 20:24:12 -0800 Subject: [PATCH 24/29] fix(webserver): use ID in user field (#1531) * fix(webserver): use ID in user field * update * fix * add fixme * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-db/src/users.rs | 27 +++++++++++++------------- ee/tabby-webserver/src/schema/auth.rs | 24 ++++++++++++++--------- ee/tabby-webserver/src/schema/mod.rs | 4 ++-- ee/tabby-webserver/src/service/auth.rs | 26 +++++++++++++++---------- ee/tabby-webserver/src/service/dao.rs | 6 ++++++ ee/tabby-webserver/src/service/mod.rs | 6 +++--- 6 files changed, 56 insertions(+), 37 deletions(-) diff --git a/ee/tabby-db/src/users.rs b/ee/tabby-db/src/users.rs index 6eb4529f5dd3..68b72cc5fc8e 100644 --- a/ee/tabby-db/src/users.rs +++ b/ee/tabby-db/src/users.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use chrono::{DateTime, Utc}; use sqlx::{query, query_scalar, FromRow}; use uuid::Uuid; @@ -135,28 +135,31 @@ impl DbConn { Ok(users) } - pub async fn verify_auth_token(&self, token: &str, requires_owner: bool) -> Result { + pub async fn verify_auth_token(&self, token: &str, requires_owner: bool) -> Result { let token = token.to_owned(); - let email = query_scalar!( - "SELECT email FROM users WHERE auth_token = ? AND active AND (id == ? OR NOT ?)", + let Some(id) = query_scalar!( + "SELECT id FROM users WHERE auth_token = ? AND active AND (id == ? OR NOT ?)", token, OWNER_USER_ID, requires_owner ) .fetch_one(&self.pool) - .await; - email.map_err(Into::into) + .await? + else { + bail!("Invalid auth_token") + }; + + Ok(id) } - pub async fn reset_user_auth_token_by_email(&self, email: &str) -> Result<()> { - let email = email.to_owned(); + pub async fn reset_user_auth_token_by_id(&self, id: i32) -> Result<()> { let updated_at = chrono::Utc::now(); let token = generate_auth_token(); query!( - r#"UPDATE users SET auth_token = ?, updated_at = ? WHERE email = ?"#, + r#"UPDATE users SET auth_token = ?, updated_at = ? WHERE id = ?"#, token, updated_at, - email + id ) .execute(&self.pool) .await?; @@ -285,9 +288,7 @@ mod tests { .await .is_ok()); - conn.reset_user_auth_token_by_email(&user.email) - .await - .unwrap(); + conn.reset_user_auth_token_by_id(user.id).await.unwrap(); let new_user = conn.get_user(id).await.unwrap().unwrap(); assert_eq!(user.email, new_user.email); assert_ne!(user.auth_token, new_user.auth_token); diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 52a0bd8d0d41..b64f1bbf6f42 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -202,7 +202,12 @@ impl RefreshTokenResponse { } } -#[derive(Debug, Default, Serialize, Deserialize)] +// IDWrapper to used as a type guard for refactoring, can be removed in a follow up PR. +// FIXME(meng): refactor out IDWrapper. +#[derive(Serialize, Deserialize, Debug)] +pub struct IDWrapper(pub ID); + +#[derive(Debug, Serialize, Deserialize)] pub struct JWTPayload { /// Expiration time (as UTC timestamp) exp: i64, @@ -210,20 +215,20 @@ pub struct JWTPayload { /// Issued at (as UTC timestamp) iat: i64, - /// User email address - pub sub: String, + /// User id string + pub sub: IDWrapper, /// Whether the user is admin. pub is_admin: bool, } impl JWTPayload { - pub fn new(email: String, is_admin: bool) -> Self { + pub fn new(id: ID, is_admin: bool) -> Self { let now = jwt::get_current_timestamp(); Self { iat: now as i64, exp: (now + *JWT_DEFAULT_EXP) as i64, - sub: email, + sub: IDWrapper(id), is_admin, } } @@ -378,12 +383,13 @@ pub trait AuthenticationService: Send + Sync { async fn verify_access_token(&self, access_token: &str) -> Result; async fn is_admin_initialized(&self) -> Result; async fn get_user_by_email(&self, email: &str) -> Result; + async fn get_user(&self, id: &ID) -> Result; async fn create_invitation(&self, email: String) -> Result; async fn request_invitation_email(&self, input: RequestInvitationInput) -> Result; async fn delete_invitation(&self, id: &ID) -> Result; - async fn reset_user_auth_token(&self, email: &str) -> Result<()>; + async fn reset_user_auth_token(&self, id: &ID) -> Result<()>; async fn password_reset(&self, code: &str, password: &str) -> Result<()>; async fn request_password_reset_email(&self, email: String) -> Result>>; @@ -460,7 +466,7 @@ mod tests { use super::*; #[test] fn test_generate_jwt() { - let claims = JWTPayload::new("test".to_string(), false); + let claims = JWTPayload::new(ID::from("test".to_owned()), false); let token = generate_jwt(claims).unwrap(); assert!(!token.is_empty()) @@ -468,10 +474,10 @@ mod tests { #[test] fn test_validate_jwt() { - let claims = JWTPayload::new("test".to_string(), false); + let claims = JWTPayload::new(ID::from("test".to_owned()), false); let token = generate_jwt(claims).unwrap(); let claims = validate_jwt(&token).unwrap(); - assert_eq!(claims.sub, "test"); + assert_eq!(claims.sub.0.to_string(), "test"); assert!(!claims.is_admin); } diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index d9756ad86ced..02a881159792 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -155,7 +155,7 @@ impl Query { async fn me(ctx: &Context) -> Result { let claims = check_claims(ctx)?; - ctx.locator.auth().get_user_by_email(&claims.sub).await + ctx.locator.auth().get_user(&claims.sub.0).await } async fn users( @@ -370,7 +370,7 @@ impl Mutation { let claims = check_claims(ctx)?; ctx.locator .auth() - .reset_user_auth_token(&claims.sub) + .reset_user_auth_token(&claims.sub.0) .await?; Ok(true) } diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 1a0520c0cf36..bee1b0fc4303 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -86,8 +86,7 @@ impl AuthenticationService for AuthenticationServiceImpl { let refresh_token = generate_refresh_token(); self.db.create_refresh_token(id, &refresh_token).await?; - let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) - else { + let Ok(access_token) = generate_jwt(JWTPayload::new(id.as_id(), user.is_admin)) else { return Err(anyhow!("Unknown error").into()); }; @@ -159,8 +158,7 @@ impl AuthenticationService for AuthenticationServiceImpl { .create_refresh_token(user.id, &refresh_token) .await?; - let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) - else { + let Ok(access_token) = generate_jwt(JWTPayload::new(user.id.as_id(), user.is_admin)) else { return Err(anyhow!("Unknown error").into()); }; @@ -187,8 +185,7 @@ impl AuthenticationService for AuthenticationServiceImpl { self.db.replace_refresh_token(&token, &new_token).await?; // refresh token update is done, generate new access token based on user info - let Ok(access_token) = generate_jwt(JWTPayload::new(user.email.clone(), user.is_admin)) - else { + let Ok(access_token) = generate_jwt(JWTPayload::new(user.id.as_id(), user.is_admin)) else { return Err(anyhow!("Unknown error").into()); }; @@ -241,6 +238,15 @@ impl AuthenticationService for AuthenticationServiceImpl { } } + async fn get_user(&self, id: &ID) -> Result { + let user = self.db.get_user(id.as_rowid()?).await?; + if let Some(user) = user { + Ok(user.into()) + } else { + Err(anyhow!("User not found").into()) + } + } + async fn create_invitation(&self, email: String) -> Result { let license = self.license.read_license().await?; license.ensure_available_seats(1)?; @@ -276,8 +282,8 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(self.db.delete_invitation(id.as_rowid()?).await?.as_id()) } - async fn reset_user_auth_token(&self, email: &str) -> Result<()> { - Ok(self.db.reset_user_auth_token_by_email(email).await?) + async fn reset_user_auth_token(&self, id: &ID) -> Result<()> { + Ok(self.db.reset_user_auth_token_by_id(id.as_rowid()?).await?) } async fn list_users( @@ -329,7 +335,7 @@ impl AuthenticationService for AuthenticationServiceImpl { .create_refresh_token(user_id, &refresh_token) .await?; - let access_token = generate_jwt(JWTPayload::new(email.clone(), is_admin)) + let access_token = generate_jwt(JWTPayload::new(user_id.as_id(), is_admin)) .map_err(|_| OAuthError::Unknown)?; let resp = OAuthResponse { @@ -741,7 +747,7 @@ mod tests { register_admin_user(&service).await; let user = service.get_user_by_email(ADMIN_EMAIL).await.unwrap(); - service.reset_user_auth_token(&user.email).await.unwrap(); + service.reset_user_auth_token(&user.id).await.unwrap(); let user2 = service.get_user_by_email(ADMIN_EMAIL).await.unwrap(); assert_ne!(user.auth_token, user2.auth_token); diff --git a/ee/tabby-webserver/src/service/dao.rs b/ee/tabby-webserver/src/service/dao.rs index 7391c1a02c54..5a91ea9826bd 100644 --- a/ee/tabby-webserver/src/service/dao.rs +++ b/ee/tabby-webserver/src/service/dao.rs @@ -154,6 +154,12 @@ pub trait AsID { } impl AsID for i32 { + fn as_id(&self) -> juniper::ID { + (*self as i64).as_id() + } +} + +impl AsID for i64 { fn as_id(&self) -> juniper::ID { juniper::ID::new(HASHER.encode(&[*self as u64])) } diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index 219de7e48dec..26e9a6e3bf65 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -90,7 +90,7 @@ impl ServerContext { } } - async fn authorize_request(&self, request: &Request) -> (bool, Option) { + async fn authorize_request(&self, request: &Request) -> (bool, Option) { let path = request.uri().path(); if !(path.starts_with("/v1/") || path.starts_with("/v1beta/")) { return (true, None); @@ -112,7 +112,7 @@ impl ServerContext { // Allow JWT based access (from web browser), regardless of the license status. if let Ok(jwt) = self.auth.verify_access_token(token).await { - return (true, Some(jwt.sub)); + return (true, Some(jwt.sub.0)); } let is_license_valid = self @@ -127,7 +127,7 @@ impl ServerContext { .verify_auth_token(token, !is_license_valid) .await { - Ok(email) => (true, Some(email)), + Ok(id) => (true, Some(id.as_id())), Err(_) => (false, None), } } From d0836db559ea18820c33ce7f7af05f7b7bebb667 Mon Sep 17 00:00:00 2001 From: boxbeam Date: Mon, 26 Feb 2024 23:31:15 -0500 Subject: [PATCH 25/29] feat(webserver): Add authenticated endpoint to update password (#1553) * feat(webserver): Add authenticated endpoint to update password * [autofix.ci] apply automated fixes * Apply suggestions * [autofix.ci] apply automated fixes * Make password optional * Add todo * [autofix.ci] apply automated fixes * switch to id --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Meng Zhang --- ee/tabby-webserver/graphql/schema.graphql | 24 ++++++---- ee/tabby-webserver/src/schema/auth.rs | 40 ++++++++++++++++ ee/tabby-webserver/src/schema/mod.rs | 17 ++++++- ee/tabby-webserver/src/service/auth.rs | 56 +++++++++++++++++++++++ ee/tabby-webserver/src/service/dao.rs | 1 + 5 files changed, 129 insertions(+), 9 deletions(-) diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 0686595d6528..2fdb96e9be53 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -17,6 +17,7 @@ type Mutation { requestInvitationEmail(input: RequestInvitationInput!): Invitation! requestPasswordResetEmail(input: RequestPasswordResetEmailInput!): Boolean! passwordReset(input: PasswordResetInput!): Boolean! + passwordChange(input: PasswordUpdateInput!): Boolean! resetUserAuthToken: Boolean! updateUserActive(id: ID!, active: Boolean!): Boolean! updateUserRole(id: ID!, isAdmin: Boolean!): Boolean! @@ -110,9 +111,10 @@ type RefreshTokenResponse { refreshExpiresAt: DateTimeUtc! } -type RegisterResponse { - accessToken: String! - refreshToken: String! +input PasswordUpdateInput { + oldPassword: String + newPassword1: String! + newPassword2: String! } type RepositoryConnection { @@ -139,6 +141,11 @@ type LicenseInfo { expiresAt: DateTimeUtc } +type RegisterResponse { + accessToken: String! + refreshToken: String! +} + input EmailSettingInput { smtpUsername: String! fromAddress: String! @@ -149,11 +156,6 @@ input EmailSettingInput { smtpPassword: String } -input SecuritySettingInput { - allowedRegisterDomainList: [String!]! - disableClientSideTelemetry: Boolean! -} - enum LicenseType { COMMUNITY TEAM @@ -176,6 +178,11 @@ input UpdateOAuthCredentialInput { clientSecret: String } +input SecuritySettingInput { + allowedRegisterDomainList: [String!]! + disableClientSideTelemetry: Boolean! +} + type OAuthCredential { provider: OAuthProvider! clientId: String! @@ -221,6 +228,7 @@ type User { authToken: String! createdAt: DateTimeUtc! active: Boolean! + isPasswordSet: Boolean! } type Worker { diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index b64f1bbf6f42..6a97f67f53d3 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -244,6 +244,7 @@ pub struct User { pub auth_token: String, pub created_at: DateTime, pub active: bool, + pub is_password_set: bool, } impl relay::NodeType for User { @@ -306,6 +307,39 @@ pub struct PasswordResetInput { pub password2: String, } +#[derive(Validate, GraphQLInputObject)] +pub struct PasswordUpdateInput { + pub old_password: Option, + + #[validate(length( + min = 8, + code = "new_password1", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "new_password1", + message = "Password must be at most 20 characters" + ))] + pub new_password1: String, + #[validate(length( + min = 8, + code = "new_password2", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "new_password2", + message = "Password must be at most 20 characters" + ))] + #[validate(must_match( + code = "new_password2", + message = "Passwords do not match", + other = "new_password1" + ))] + pub new_password2: String, +} + #[derive(Debug, Serialize, Deserialize, GraphQLObject)] #[graphql(context = Context)] pub struct Invitation { @@ -392,6 +426,12 @@ pub trait AuthenticationService: Send + Sync { async fn reset_user_auth_token(&self, id: &ID) -> Result<()>; async fn password_reset(&self, code: &str, password: &str) -> Result<()>; async fn request_password_reset_email(&self, email: String) -> Result>>; + async fn update_user_password( + &self, + id: &ID, + old_password: Option<&str>, + new_password: &str, + ) -> Result<()>; async fn list_users( &self, diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 02a881159792..ba878fabc8db 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -28,7 +28,7 @@ use worker::{Worker, WorkerService}; use self::{ auth::{ - JWTPayload, OAuthCredential, OAuthProvider, PasswordResetInput, RequestInvitationInput, + PasswordResetInput, PasswordUpdateInput, RequestInvitationInput, RequestPasswordResetEmailInput, UpdateOAuthCredentialInput, }, email::{EmailService, EmailSetting, EmailSettingInput}, @@ -38,6 +38,7 @@ use self::{ NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, }, }; +use crate::schema::auth::{JWTPayload, OAuthCredential, OAuthProvider}; pub trait ServiceLocator: Send + Sync { fn auth(&self) -> Arc; @@ -366,6 +367,20 @@ impl Mutation { Ok(true) } + async fn password_change(ctx: &Context, input: PasswordUpdateInput) -> Result { + let claims = check_claims(ctx)?; + input.validate()?; + ctx.locator + .auth() + .update_user_password( + &claims.sub.0, + input.old_password.as_deref(), + &input.new_password1, + ) + .await?; + Ok(true) + } + async fn reset_user_auth_token(ctx: &Context) -> Result { let claims = check_claims(ctx)?; ctx.locator diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index bee1b0fc4303..67e83f58606c 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -140,6 +140,35 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(()) } + async fn update_user_password( + &self, + id: &ID, + old_password: Option<&str>, + new_password: &str, + ) -> Result<()> { + let user = self + .db + .get_user(id.as_rowid()?) + .await? + .ok_or_else(|| anyhow!("Invalid user"))?; + + let password_verified = match (user.password_encrypted.is_empty(), old_password) { + (true, _) => true, + (false, None) => false, + (false, Some(old_password)) => password_verify(old_password, &user.password_encrypted), + }; + if !password_verified { + return Err(anyhow!("Password is incorrect").into()); + } + + let new_password_encrypted = + password_hash(new_password).map_err(|_| anyhow!("Unknown error"))?; + self.db + .update_user_password(user.id, new_password_encrypted) + .await?; + Ok(()) + } + async fn token_auth(&self, email: String, password: String) -> Result { let Some(user) = self.db.get_user_by_email(&email).await? else { return Err(anyhow!("User not found").into()); @@ -1165,4 +1194,31 @@ mod tests { Err(CoreError::InvalidLicense(_)) ); } + + #[tokio::test] + async fn test_update_password() { + let service = test_authentication_service().await; + let id = service + .db + .create_user("test@example.com".into(), "".into(), true) + .await + .unwrap(); + + let id = id.as_id(); + + assert!(service + .update_user_password(&id, None, "newpass") + .await + .is_ok()); + + assert!(service + .update_user_password(&id, None, "newpass2") + .await + .is_err()); + + assert!(service + .update_user_password(&id, Some("newpass"), "newpass2") + .await + .is_ok()); + } } diff --git a/ee/tabby-webserver/src/service/dao.rs b/ee/tabby-webserver/src/service/dao.rs index 5a91ea9826bd..569bfe4c04e6 100644 --- a/ee/tabby-webserver/src/service/dao.rs +++ b/ee/tabby-webserver/src/service/dao.rs @@ -52,6 +52,7 @@ impl From for auth::User { auth_token: val.auth_token, created_at: val.created_at, active: val.active, + is_password_set: !val.password_encrypted.is_empty(), } } } From 0f928bcdbf9514c7789243dc373e105636c5ccb5 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Mon, 26 Feb 2024 20:53:00 -0800 Subject: [PATCH 26/29] fix(webserver): fill the correct code name (cameCase) for PasswordChangeInput (#1561) --- ee/tabby-webserver/graphql/schema.graphql | 30 +++++++++++------------ ee/tabby-webserver/src/schema/auth.rs | 12 ++++----- ee/tabby-webserver/src/schema/mod.rs | 4 +-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 2fdb96e9be53..e791ad1f0b2f 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -17,7 +17,7 @@ type Mutation { requestInvitationEmail(input: RequestInvitationInput!): Invitation! requestPasswordResetEmail(input: RequestPasswordResetEmailInput!): Boolean! passwordReset(input: PasswordResetInput!): Boolean! - passwordChange(input: PasswordUpdateInput!): Boolean! + passwordChange(input: PasswordChangeInput!): Boolean! resetUserAuthToken: Boolean! updateUserActive(id: ID!, active: Boolean!): Boolean! updateUserRole(id: ID!, isAdmin: Boolean!): Boolean! @@ -111,10 +111,9 @@ type RefreshTokenResponse { refreshExpiresAt: DateTimeUtc! } -input PasswordUpdateInput { - oldPassword: String - newPassword1: String! - newPassword2: String! +type RegisterResponse { + accessToken: String! + refreshToken: String! } type RepositoryConnection { @@ -141,11 +140,6 @@ type LicenseInfo { expiresAt: DateTimeUtc } -type RegisterResponse { - accessToken: String! - refreshToken: String! -} - input EmailSettingInput { smtpUsername: String! fromAddress: String! @@ -156,6 +150,11 @@ input EmailSettingInput { smtpPassword: String } +input SecuritySettingInput { + allowedRegisterDomainList: [String!]! + disableClientSideTelemetry: Boolean! +} + enum LicenseType { COMMUNITY TEAM @@ -178,11 +177,6 @@ input UpdateOAuthCredentialInput { clientSecret: String } -input SecuritySettingInput { - allowedRegisterDomainList: [String!]! - disableClientSideTelemetry: Boolean! -} - type OAuthCredential { provider: OAuthProvider! clientId: String! @@ -204,6 +198,12 @@ input PasswordResetInput { password2: String! } +input PasswordChangeInput { + oldPassword: String + newPassword1: String! + newPassword2: String! +} + type Invitation { id: ID! email: String! diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 6a97f67f53d3..0e595b795d35 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -308,32 +308,32 @@ pub struct PasswordResetInput { } #[derive(Validate, GraphQLInputObject)] -pub struct PasswordUpdateInput { +pub struct PasswordChangeInput { pub old_password: Option, #[validate(length( min = 8, - code = "new_password1", + code = "newPassword1", message = "Password must be at least 8 characters" ))] #[validate(length( max = 20, - code = "new_password1", + code = "newPassword1", message = "Password must be at most 20 characters" ))] pub new_password1: String, #[validate(length( min = 8, - code = "new_password2", + code = "newPassword2", message = "Password must be at least 8 characters" ))] #[validate(length( max = 20, - code = "new_password2", + code = "newPassword2", message = "Password must be at most 20 characters" ))] #[validate(must_match( - code = "new_password2", + code = "newPassword2", message = "Passwords do not match", other = "new_password1" ))] diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index ba878fabc8db..d815e30b4096 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -28,7 +28,7 @@ use worker::{Worker, WorkerService}; use self::{ auth::{ - PasswordResetInput, PasswordUpdateInput, RequestInvitationInput, + PasswordChangeInput, PasswordResetInput, RequestInvitationInput, RequestPasswordResetEmailInput, UpdateOAuthCredentialInput, }, email::{EmailService, EmailSetting, EmailSettingInput}, @@ -367,7 +367,7 @@ impl Mutation { Ok(true) } - async fn password_change(ctx: &Context, input: PasswordUpdateInput) -> Result { + async fn password_change(ctx: &Context, input: PasswordChangeInput) -> Result { let claims = check_claims(ctx)?; input.validate()?; ctx.locator From 5980accd6a401a6968615b0296c428b64301693f Mon Sep 17 00:00:00 2001 From: aliang <1098486429@qq.com> Date: Tue, 27 Feb 2024 15:53:01 +0800 Subject: [PATCH 27/29] fix(ui): get email from me query (#1562) * fix(ui): get email from me query * reexecuteQueryMe --- ee/tabby-ui/app/(dashboard)/page.tsx | 14 +++----------- ee/tabby-ui/components/user-panel.tsx | 7 ++++--- ee/tabby-ui/lib/hooks/use-me.ts | 19 +++++++++++++++++++ ee/tabby-ui/lib/tabby/auth.tsx | 7 ++++--- 4 files changed, 30 insertions(+), 17 deletions(-) create mode 100644 ee/tabby-ui/lib/hooks/use-me.ts diff --git a/ee/tabby-ui/app/(dashboard)/page.tsx b/ee/tabby-ui/app/(dashboard)/page.tsx index 02ef926f7d13..f8f2acdcd254 100644 --- a/ee/tabby-ui/app/(dashboard)/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/page.tsx @@ -1,10 +1,10 @@ 'use client' import { noop } from 'lodash-es' -import { useQuery } from 'urql' import { graphql } from '@/lib/gql/generates' import { useHealth } from '@/lib/hooks/use-health' +import { useMe } from '@/lib/hooks/use-me' import { useExternalURL } from '@/lib/hooks/use-network-setting' import { useMutation } from '@/lib/tabby/gql' import { Button } from '@/components/ui/button' @@ -29,14 +29,6 @@ export default function Home() { ) } -const meQuery = graphql(/* GraphQL */ ` - query MeQuery { - me { - authToken - } - } -`) - const resetUserAuthTokenDocument = graphql(/* GraphQL */ ` mutation ResetUserAuthToken { resetUserAuthToken @@ -45,14 +37,14 @@ const resetUserAuthTokenDocument = graphql(/* GraphQL */ ` function MainPanel() { const { data: healthInfo } = useHealth() - const [{ data }, reexecuteQuery] = useQuery({ query: meQuery }) + const [{ data }, reexecuteQuery] = useMe() const externalUrl = useExternalURL() const resetUserAuthToken = useMutation(resetUserAuthTokenDocument, { onCompleted: () => reexecuteQuery() }) - if (!healthInfo || !data) return + if (!healthInfo || !data?.me) return return (
    diff --git a/ee/tabby-ui/components/user-panel.tsx b/ee/tabby-ui/components/user-panel.tsx index 693109abe951..ccdf70508207 100644 --- a/ee/tabby-ui/components/user-panel.tsx +++ b/ee/tabby-ui/components/user-panel.tsx @@ -1,8 +1,9 @@ import React from 'react' import NiceAvatar, { genConfig } from 'react-nice-avatar' +import { useMe } from '@/lib/hooks/use-me' import { useIsChatEnabled } from '@/lib/hooks/use-server-info' -import { useAuthenticatedSession, useSignOut } from '@/lib/tabby/auth' +import { useSignOut } from '@/lib/tabby/auth' import { DropdownMenu, DropdownMenuContent, @@ -15,9 +16,9 @@ import { import { IconBackpack, IconChat, IconCode, IconLogout } from './ui/icons' export default function UserPanel() { - const user = useAuthenticatedSession() const signOut = useSignOut() - + const [{ data }] = useMe() + const user = data?.me const isChatEnabled = useIsChatEnabled() if (!user) { diff --git a/ee/tabby-ui/lib/hooks/use-me.ts b/ee/tabby-ui/lib/hooks/use-me.ts new file mode 100644 index 000000000000..570d07e324b6 --- /dev/null +++ b/ee/tabby-ui/lib/hooks/use-me.ts @@ -0,0 +1,19 @@ +import { useQuery } from 'urql' + +import { graphql } from '@/lib/gql/generates' + +const meQuery = graphql(/* GraphQL */ ` + query MeQuery { + me { + authToken + email + isAdmin + } + } +`) + +const useMe = () => { + return useQuery({ query: meQuery }) +} + +export { useMe } diff --git a/ee/tabby-ui/lib/tabby/auth.tsx b/ee/tabby-ui/lib/tabby/auth.tsx index 99ccc8626287..3c0449098ce1 100644 --- a/ee/tabby-ui/lib/tabby/auth.tsx +++ b/ee/tabby-ui/lib/tabby/auth.tsx @@ -6,6 +6,7 @@ import useLocalStorage from 'use-local-storage' import { graphql } from '@/lib/gql/generates' import { isClientSide } from '@/lib/utils' +import { useMe } from '../hooks/use-me' import { useIsAdminInitialized } from '../hooks/use-server-info' interface AuthData { @@ -150,6 +151,7 @@ const AuthProvider: React.FunctionComponent = ({ status: 'loading', data: undefined }) + const [, reexecuteQueryMe] = useMe() React.useEffect(() => { initialized.current = true @@ -166,6 +168,7 @@ const AuthProvider: React.FunctionComponent = ({ // After being mounted, listen for changes in the access token if (authToken?.accessToken && authToken?.refreshToken) { dispatch({ type: AuthActionType.SignIn, data: authToken }) + reexecuteQueryMe() } else { dispatch({ type: AuthActionType.SignOut }) } @@ -174,12 +177,11 @@ const AuthProvider: React.FunctionComponent = ({ const session: Session = React.useMemo(() => { if (authState?.status == 'authenticated') { try { - const { sub, is_admin } = jwtDecode( + const { is_admin } = jwtDecode( authState.data.accessToken ) return { data: { - email: sub!, isAdmin: is_admin, accessToken: authState.data.accessToken }, @@ -260,7 +262,6 @@ function useSignOut(): () => Promise { } interface User { - email: string isAdmin: boolean accessToken: string } From 9bdc3efa6bfe8bef40d9b6eda09c95b029756de8 Mon Sep 17 00:00:00 2001 From: aliang <1098486429@qq.com> Date: Tue, 27 Feb 2024 18:47:46 +0800 Subject: [PATCH 28/29] feat(ui): add change password frontend (#1563) * feat(ui): init profile page * feat(ui): init change password page * [autofix.ci] apply automated fixes * update icons * Update ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx Co-authored-by: Meng Zhang * Update ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx Co-authored-by: Meng Zhang * Update ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx Co-authored-by: Meng Zhang * Update ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx Co-authored-by: Meng Zhang * Update ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx Co-authored-by: Meng Zhang * update * [autofix.ci] apply automated fixes * update --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Meng Zhang --- ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx | 2 +- ee/tabby-ui/app/(dashboard)/cluster/page.tsx | 6 +- .../app/(dashboard)/components/sidebar.tsx | 6 +- ee/tabby-ui/app/(dashboard)/layout.tsx | 2 +- ee/tabby-ui/app/(dashboard)/page.tsx | 6 +- .../(dashboard)/profile/components/avatar.tsx | 17 ++ .../profile/components/change-password.tsx | 167 ++++++++++++++++++ .../(dashboard)/profile/components/email.tsx | 19 ++ .../profile/components/profile-card.tsx | 53 ++++++ .../profile/components/profile.tsx | 32 ++++ ee/tabby-ui/app/(dashboard)/profile/page.tsx | 11 ++ .../settings/(integrations)/layout.tsx | 7 - .../app/(dashboard)/settings/general/page.tsx | 7 +- .../subscription/components/subscription.tsx | 4 +- .../settings/team/components/team.tsx | 12 +- ee/tabby-ui/components/ui/icons.tsx | 49 +++-- ee/tabby-ui/lib/hooks/use-me.ts | 1 + 17 files changed, 352 insertions(+), 49 deletions(-) create mode 100644 ee/tabby-ui/app/(dashboard)/profile/components/avatar.tsx create mode 100644 ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx create mode 100644 ee/tabby-ui/app/(dashboard)/profile/components/email.tsx create mode 100644 ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx create mode 100644 ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx create mode 100644 ee/tabby-ui/app/(dashboard)/profile/page.tsx delete mode 100644 ee/tabby-ui/app/(dashboard)/settings/(integrations)/layout.tsx diff --git a/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx b/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx index 78e572debac4..3f6da82da9ee 100644 --- a/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx +++ b/ee/tabby-ui/app/(dashboard)/(logs)/layout.tsx @@ -3,5 +3,5 @@ export default function LogsLayout({ }: { children: React.ReactNode }) { - return
    {children}
    + return
    {children}
    } diff --git a/ee/tabby-ui/app/(dashboard)/cluster/page.tsx b/ee/tabby-ui/app/(dashboard)/cluster/page.tsx index 09548c6a4508..1c040e525e6d 100644 --- a/ee/tabby-ui/app/(dashboard)/cluster/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/cluster/page.tsx @@ -7,9 +7,5 @@ export const metadata: Metadata = { } export default function IndexPage() { - return ( -
    - -
    - ) + return } diff --git a/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx b/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx index 811f03f2e649..4cdf060679ab 100644 --- a/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx +++ b/ee/tabby-ui/app/(dashboard)/components/sidebar.tsx @@ -21,7 +21,8 @@ import { IconHome, IconLightingBolt, IconNetwork, - IconScrollText + IconScrollText, + IconUser } from '@/components/ui/icons' export interface SidebarProps { @@ -57,6 +58,9 @@ export default function Sidebar({ children, className }: SidebarProps) { Home + + Profile + {isAdmin && ( <> diff --git a/ee/tabby-ui/app/(dashboard)/layout.tsx b/ee/tabby-ui/app/(dashboard)/layout.tsx index 28370816aa44..93479a35076e 100644 --- a/ee/tabby-ui/app/(dashboard)/layout.tsx +++ b/ee/tabby-ui/app/(dashboard)/layout.tsx @@ -21,7 +21,7 @@ export default function RootLayout({ children }: DashboardLayoutProps) {
    -
    {children}
    +
    {children}
    ) diff --git a/ee/tabby-ui/app/(dashboard)/page.tsx b/ee/tabby-ui/app/(dashboard)/page.tsx index f8f2acdcd254..b5e0c6fbd8e4 100644 --- a/ee/tabby-ui/app/(dashboard)/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/page.tsx @@ -48,10 +48,10 @@ function MainPanel() { return (
    - + Getting Started - + - + Use informations above for IDE extensions / plugins configuration, see{' '} { + const [{ data }] = useMe() + + if (!data?.me?.email) return null + + const config = genConfig(data?.me?.email) + + return ( +
    + +
    + ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx new file mode 100644 index 000000000000..394f41c4699e --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/change-password.tsx @@ -0,0 +1,167 @@ +'use client' + +import React from 'react' +import { zodResolver } from '@hookform/resolvers/zod' +import { useForm } from 'react-hook-form' +import { toast } from 'sonner' +import * as z from 'zod' + +import { graphql } from '@/lib/gql/generates' +import { useMe } from '@/lib/hooks/use-me' +import { useMutation } from '@/lib/tabby/gql' +import { Button } from '@/components/ui/button' +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage +} from '@/components/ui/form' +import { IconSpinner } from '@/components/ui/icons' +import { Input } from '@/components/ui/input' +import { Separator } from '@/components/ui/separator' +import { ListSkeleton } from '@/components/skeleton' + +const passwordChangeMutation = graphql(/* GraphQL */ ` + mutation PasswordChange($input: PasswordChangeInput!) { + passwordChange(input: $input) + } +`) + +interface ChangePasswordFormProps { + showOldPassword?: boolean + onSuccess?: () => void +} + +const ChangePasswordForm: React.FC = ({ + onSuccess, + showOldPassword +}) => { + const formSchema = z.object({ + oldPassword: showOldPassword ? z.string() : z.string().optional(), + newPassword1: z.string(), + newPassword2: z.string() + }) + + const form = useForm>({ + resolver: zodResolver(formSchema) + }) + const { isSubmitting } = form.formState + + const passwordChange = useMutation(passwordChangeMutation, { + form, + onCompleted(values) { + if (values?.passwordChange) { + onSuccess?.() + } + } + }) + + const onSubmit = async (values: z.infer) => { + await passwordChange({ + input: values + }) + } + + return ( + +
    + + {showOldPassword && ( + ( + + Old password + + + + + + )} + /> + )} + ( + + New password + + + + + + )} + /> + ( + + Confirm new password + + + + + + )} + /> + + +
    + +
    + +
    + + ) +} + +export const ChangePassword = () => { + const [{ data }, reexecuteQuery] = useMe() + const onSuccess = () => { + toast.success('Password is updated') + reexecuteQuery() + } + + return data ? ( + + ) : ( + + ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/email.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/email.tsx new file mode 100644 index 000000000000..e68e552806b6 --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/email.tsx @@ -0,0 +1,19 @@ +import { noop } from 'lodash-es' + +import { useMe } from '@/lib/hooks/use-me' +import { Input } from '@/components/ui/input' + +export const Email = () => { + const [{ data }] = useMe() + + return ( +
    + +
    + ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx new file mode 100644 index 000000000000..43b8461ce568 --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/profile-card.tsx @@ -0,0 +1,53 @@ +import React from 'react' + +import { cn } from '@/lib/utils' +import { CardContent, CardTitle } from '@/components/ui/card' +import { Separator } from '@/components/ui/separator' + +interface ProfileCardProps extends React.HTMLAttributes { + title: string + description?: string + footer?: React.ReactNode + footerClassname?: string +} + +const ProfileCard: React.FC = ({ + title, + description, + footer, + footerClassname, + className, + children, + ...props +}) => { + return ( +
    +
    + {title} + {description && ( +
    + {description} +
    + )} +
    + {children} +
    + {!!footer && } + {footer} +
    +
    + ) +} + +export { ProfileCard } diff --git a/ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx b/ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx new file mode 100644 index 000000000000..966f9e7a6c4e --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/components/profile.tsx @@ -0,0 +1,32 @@ +'use client' + +import React from 'react' + +import { Avatar } from './avatar' +import { ChangePassword } from './change-password' +import { Email } from './email' +import { ProfileCard } from './profile-card' + +export default function Profile() { + return ( +
    + + + + + + + + + +
    + ) +} diff --git a/ee/tabby-ui/app/(dashboard)/profile/page.tsx b/ee/tabby-ui/app/(dashboard)/profile/page.tsx new file mode 100644 index 000000000000..c1d44dc91ff9 --- /dev/null +++ b/ee/tabby-ui/app/(dashboard)/profile/page.tsx @@ -0,0 +1,11 @@ +import { Metadata } from 'next' + +import Profile from './components/profile' + +export const metadata: Metadata = { + title: 'Profile' +} + +export default function Page() { + return +} diff --git a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/layout.tsx b/ee/tabby-ui/app/(dashboard)/settings/(integrations)/layout.tsx deleted file mode 100644 index 0912c73c1045..000000000000 --- a/ee/tabby-ui/app/(dashboard)/settings/(integrations)/layout.tsx +++ /dev/null @@ -1,7 +0,0 @@ -export default function IntegrationsLayout({ - children -}: { - children: React.ReactNode -}) { - return
    {children}
    -} diff --git a/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx b/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx index 470dcfc1a493..2361323aa5ae 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/general/page.tsx @@ -7,10 +7,5 @@ export const metadata: Metadata = { } export default function GeneralSettings() { - // todo abstract settings-layout after email was merged - return ( -
    - -
    - ) + return } diff --git a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx index 9bf87929e4d1..692da3443184 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/subscription/components/subscription.tsx @@ -21,7 +21,7 @@ export default function Subscription() { const canReset = !!license?.type && license.type !== LicenseType.Community return ( -
    + <>
    -
    + ) } diff --git a/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx b/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx index 6c6cfcc3b039..8a61d65450b8 100644 --- a/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx +++ b/ee/tabby-ui/app/(dashboard)/settings/team/components/team.tsx @@ -7,24 +7,24 @@ import UsersTable from './user-table' export default function Team() { return ( -
    + <>
    - + Pending Invites - +
    - + Members - +
    -
    + ) } diff --git a/ee/tabby-ui/components/ui/icons.tsx b/ee/tabby-ui/components/ui/icons.tsx index 983232935143..f645777a0ec1 100644 --- a/ee/tabby-ui/components/ui/icons.tsx +++ b/ee/tabby-ui/components/ui/icons.tsx @@ -244,12 +244,17 @@ function IconUser({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + + ) } @@ -412,12 +417,16 @@ function IconCheck({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + ) } @@ -440,12 +449,17 @@ function IconClose({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + + ) } @@ -845,17 +859,18 @@ function IconBackpack({ className, ...props }: React.ComponentProps<'svg'>) { function IconGear({ className, ...props }: React.ComponentProps<'svg'>) { return ( - + + ) } diff --git a/ee/tabby-ui/lib/hooks/use-me.ts b/ee/tabby-ui/lib/hooks/use-me.ts index 570d07e324b6..43ffa92b60b4 100644 --- a/ee/tabby-ui/lib/hooks/use-me.ts +++ b/ee/tabby-ui/lib/hooks/use-me.ts @@ -8,6 +8,7 @@ const meQuery = graphql(/* GraphQL */ ` authToken email isAdmin + isPasswordSet } } `) From eb1a848443a4f2ef12e5b0b4c6d9377583ef27c8 Mon Sep 17 00:00:00 2001 From: boxbeam Date: Tue, 27 Feb 2024 12:52:46 -0500 Subject: [PATCH 29/29] docs: Adjust description of golden tests in CONTRIBUTING.md --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ea345e230279..0f209e04d233 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -41,10 +41,10 @@ Before proceeding, ensure that all tests are passing locally: cargo test -- --skip golden ``` -Golden tests should be skipped on all platforms except Apple silicon (M1/M2), because they have not been created for other platforms yet. - This will help ensure everything is working correctly and avoid surprises with local breakages. +Golden tests, which run models and check their outputs against previous "golden snapshots", should be skipped for most development purposes, as they take a very long time to run (especially the tests running the models on CPU). You may still want to run them if your changes relate to the functioning of or integration with the generative models, but skipping them is recommended otherwise. + ## Building and Running Tabby can be run through `cargo` in much the same manner as docker: