diff --git a/ee/tabby-db/src/server_setting.rs b/ee/tabby-db/src/server_setting.rs index c6dc229f3b72..f2a41afeee0b 100644 --- a/ee/tabby-db/src/server_setting.rs +++ b/ee/tabby-db/src/server_setting.rs @@ -53,6 +53,32 @@ impl DbConn { Ok(setting) } + pub async fn update_security_setting( + &self, + allowed_register_domain_list: Option, + disable_client_side_telemetry: bool, + ) -> Result<()> { + query!("INSERT INTO server_setting (id, security_allowed_register_domain_list, security_disable_client_side_telemetry) VALUES ($1, $2, $3) + ON CONFLICT(id) DO UPDATE SET security_allowed_register_domain_list = $2, security_disable_client_side_telemetry = $3", + SERVER_SETTING_ROW_ID, + allowed_register_domain_list, + disable_client_side_telemetry, + ).execute(&self.pool).await?; + Ok(()) + } + + pub async fn update_network_setting(&self, external_url: String) -> Result<()> { + query!( + "INSERT INTO server_setting (id, network_external_url) VALUES ($1, $2) + ON CONFLICT(id) DO UPDATE SET network_external_url = $2", + SERVER_SETTING_ROW_ID, + external_url + ) + .execute(&self.pool) + .await?; + Ok(()) + } + pub async fn update_server_setting( &self, security_allowed_register_domain_list: Option, diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index b395ccf51543..4a7ade4d1d7a 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -28,7 +28,9 @@ use worker::{Worker, WorkerService}; use self::{ email::{EmailService, EmailSetting}, repository::{RepositoryError, RepositoryService}, - setting::{ServerSetting, SettingService}, + setting::{ + NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, + }, }; use crate::schema::{ auth::{JWTPayload, OAuthCredential, OAuthProvider}, @@ -236,13 +238,23 @@ impl Query { Ok(val) } - async fn server_setting(ctx: &Context) -> Result { + async fn network_setting(ctx: &Context) -> Result { let Some(JWTPayload { is_admin: true, .. }) = &ctx.claims else { return Err(CoreError::Unauthorized( "Only admin can access server settings", )); }; - let val = ctx.locator.setting().read_server_setting().await?; + let val = ctx.locator.setting().read_network_setting().await?; + Ok(val) + } + + async fn security_setting(ctx: &Context) -> Result { + let Some(JWTPayload { is_admin: true, .. }) = &ctx.claims else { + return Err(CoreError::Unauthorized( + "Only admin can access server settings", + )); + }; + let val = ctx.locator.setting().read_security_setting().await?; Ok(val) } @@ -465,25 +477,23 @@ impl Mutation { Ok(true) } - async fn update_server_setting( - ctx: &Context, - security_allowed_register_domain_list: Vec, - security_disable_client_side_telemetry: bool, - network_external_url: String, - ) -> Result { + async fn update_security_setting(ctx: &Context, input: SecuritySettingInput) -> Result { let Some(JWTPayload { is_admin: true, .. }) = &ctx.claims else { return Err(CoreError::Unauthorized( "Only admin can access server settings", )); }; - ctx.locator - .setting() - .update_server_setting(ServerSetting { - security_allowed_register_domain_list, - security_disable_client_side_telemetry, - network_external_url, - }) - .await?; + ctx.locator.setting().update_security_setting(input).await?; + Ok(true) + } + + async fn update_network_setting(ctx: &Context, input: NetworkSettingInput) -> Result { + let Some(JWTPayload { is_admin: true, .. }) = &ctx.claims else { + return Err(CoreError::Unauthorized( + "Only admin can access server settings", + )); + }; + ctx.locator.setting().update_network_setting(input).await?; Ok(true) } diff --git a/ee/tabby-webserver/src/schema/setting.rs b/ee/tabby-webserver/src/schema/setting.rs index b6383a716b8b..b383b99fa10d 100644 --- a/ee/tabby-webserver/src/schema/setting.rs +++ b/ee/tabby-webserver/src/schema/setting.rs @@ -2,31 +2,43 @@ use std::collections::HashSet; use anyhow::Result; use async_trait::async_trait; -use juniper::GraphQLObject; +use juniper::{GraphQLInputObject, GraphQLObject}; use validator::{validate_email, Validate, ValidationError}; #[async_trait] pub trait SettingService: Send + Sync { - async fn read_server_setting(&self) -> Result; - async fn update_server_setting(&self, setting: ServerSetting) -> Result<()>; + async fn read_security_setting(&self) -> Result; + async fn update_security_setting(&self, input: SecuritySettingInput) -> Result<()>; + + async fn read_network_setting(&self) -> Result; + async fn update_network_setting(&self, input: NetworkSettingInput) -> Result<()>; } #[derive(GraphQLObject, Debug, PartialEq)] -pub struct ServerSetting { - pub security_allowed_register_domain_list: Vec, - pub security_disable_client_side_telemetry: bool, - pub network_external_url: String, +pub struct SecuritySetting { + pub allowed_register_domain_list: Vec, + pub disable_client_side_telemetry: bool, } -#[derive(Validate)] -pub struct ServerSettingInput<'a> { +#[derive(GraphQLInputObject, Validate)] +pub struct SecuritySettingInput { #[validate(custom = "validate_unique_domains")] - pub security_allowed_register_domain_list: Vec<&'a str>, + pub allowed_register_domain_list: Vec, + pub disable_client_side_telemetry: bool, +} + +#[derive(GraphQLObject, Debug, PartialEq)] +pub struct NetworkSetting { + pub external_url: String, +} + +#[derive(GraphQLInputObject, Validate)] +pub struct NetworkSettingInput { #[validate(url)] - pub network_external_url: &'a str, + pub external_url: String, } -fn validate_unique_domains(domains: &Vec<&str>) -> Result<(), ValidationError> { +fn validate_unique_domains(domains: &[String]) -> Result<(), ValidationError> { let unique: HashSet<_> = domains.iter().collect(); if unique.len() != domains.len() { let collision = domains.iter().find(|s| unique.contains(s)).unwrap(); @@ -51,12 +63,14 @@ mod tests { #[test] fn test_validate_urls() { - assert!(validate_unique_domains(&vec!["example.com"]).is_ok()); + assert!(validate_unique_domains(&["example.com".to_owned()]).is_ok()); - assert!(validate_unique_domains(&vec!["https://example.com"]).is_err()); + assert!(validate_unique_domains(&["https://example.com".to_owned()]).is_err()); - assert!(validate_unique_domains(&vec!["domain.withmultipleparts.com"]).is_ok()); + assert!(validate_unique_domains(&["domain.withmultipleparts.com".to_owned()]).is_ok()); - assert!(validate_unique_domains(&vec!["example.com", "example.com"]).is_err()); + assert!( + validate_unique_domains(&["example.com".to_owned(), "example.com".to_owned()]).is_err() + ); } } diff --git a/ee/tabby-webserver/src/service/dao.rs b/ee/tabby-webserver/src/service/dao.rs index 5cbb4d5fb509..daa9970fd5ca 100644 --- a/ee/tabby-webserver/src/service/dao.rs +++ b/ee/tabby-webserver/src/service/dao.rs @@ -10,7 +10,7 @@ use crate::schema::{ email::EmailSetting, job, repository::Repository, - setting::ServerSetting, + setting::{NetworkSetting, SecuritySetting}, CoreError, }; @@ -98,15 +98,22 @@ impl From for EmailSetting { } } -impl From for ServerSetting { +impl From for SecuritySetting { fn from(value: ServerSettingDAO) -> Self { Self { - security_allowed_register_domain_list: value + allowed_register_domain_list: value .security_allowed_register_domain_list() .map(|s| s.to_owned()) .collect(), - security_disable_client_side_telemetry: value.security_disable_client_side_telemetry, - network_external_url: value.network_external_url, + disable_client_side_telemetry: value.security_disable_client_side_telemetry, + } + } +} + +impl From for NetworkSetting { + fn from(value: ServerSettingDAO) -> Self { + Self { + external_url: value.network_external_url, } } } diff --git a/ee/tabby-webserver/src/service/setting.rs b/ee/tabby-webserver/src/service/setting.rs index 6b8eec5d254c..4e0f654d8994 100644 --- a/ee/tabby-webserver/src/service/setting.rs +++ b/ee/tabby-webserver/src/service/setting.rs @@ -3,34 +3,35 @@ use async_trait::async_trait; use tabby_db::DbConn; use validator::Validate; -use crate::schema::setting::{ServerSetting, ServerSettingInput, SettingService}; +use crate::schema::setting::{ + NetworkSetting, NetworkSettingInput, SecuritySetting, SecuritySettingInput, SettingService, +}; #[async_trait] impl SettingService for DbConn { - async fn read_server_setting(&self) -> Result { - let setting = self.read_server_setting().await?; - Ok(setting.into()) + async fn read_security_setting(&self) -> Result { + Ok(self.read_server_setting().await?.into()) } - async fn update_server_setting(&self, setting: ServerSetting) -> Result<()> { - ServerSettingInput { - security_allowed_register_domain_list: setting - .security_allowed_register_domain_list - .iter() - .map(|s| &**s) - .collect(), - network_external_url: &setting.network_external_url, - } - .validate()?; - let allowed_domains = setting.security_allowed_register_domain_list.join(","); - let allowed_domains = (!allowed_domains.is_empty()).then_some(allowed_domains); - self.update_server_setting( - allowed_domains, - setting.security_disable_client_side_telemetry, - setting.network_external_url, - ) - .await?; - Ok(()) + async fn update_security_setting(&self, input: SecuritySettingInput) -> Result<()> { + input.validate()?; + let domains = if input.allowed_register_domain_list.is_empty() { + None + } else { + Some(input.allowed_register_domain_list.join(",")) + }; + + self.update_security_setting(domains, input.disable_client_side_telemetry) + .await + } + + async fn read_network_setting(&self) -> Result { + Ok(self.read_server_setting().await?.into()) + } + + async fn update_network_setting(&self, input: NetworkSettingInput) -> Result<()> { + input.validate()?; + self.update_network_setting(input.external_url).await } } @@ -39,37 +40,60 @@ mod tests { use super::*; #[tokio::test] - async fn test_read_server_setting() { + async fn test_security_setting() { + let db = DbConn::new_in_memory().await.unwrap(); + + assert_eq!( + SettingService::read_security_setting(&db).await.unwrap(), + SecuritySetting { + allowed_register_domain_list: vec![], + disable_client_side_telemetry: false, + } + ); + + SettingService::update_security_setting( + &db, + SecuritySettingInput { + allowed_register_domain_list: vec!["example.com".into()], + disable_client_side_telemetry: true, + }, + ) + .await + .unwrap(); + + assert_eq!( + SettingService::read_security_setting(&db).await.unwrap(), + SecuritySetting { + allowed_register_domain_list: vec!["example.com".into()], + disable_client_side_telemetry: true, + } + ); + } + + #[tokio::test] + async fn test_network_setting() { let db = DbConn::new_in_memory().await.unwrap(); - let server_setting = SettingService::read_server_setting(&db).await.unwrap(); assert_eq!( - server_setting, - ServerSetting { - security_allowed_register_domain_list: vec![], - security_disable_client_side_telemetry: false, - network_external_url: "http://localhost:8080".into(), + SettingService::read_network_setting(&db).await.unwrap(), + NetworkSetting { + external_url: "http://localhost:8080".into(), } ); - SettingService::update_server_setting( + SettingService::update_network_setting( &db, - ServerSetting { - security_allowed_register_domain_list: vec!["example.com".into()], - security_disable_client_side_telemetry: true, - network_external_url: "http://localhost:9090".into(), + NetworkSettingInput { + external_url: "http://localhost:8081".into(), }, ) .await .unwrap(); - let server_setting = SettingService::read_server_setting(&db).await.unwrap(); assert_eq!( - server_setting, - ServerSetting { - security_allowed_register_domain_list: vec!["example.com".into()], - security_disable_client_side_telemetry: true, - network_external_url: "http://localhost:9090".into(), + SettingService::read_network_setting(&db).await.unwrap(), + NetworkSetting { + external_url: "http://localhost:8081".into(), } ); }