diff --git a/ee/tabby-schema/src/policy.rs b/ee/tabby-schema/src/policy.rs index ddd73789e42a..1b8d5710cb78 100644 --- a/ee/tabby-schema/src/policy.rs +++ b/ee/tabby-schema/src/policy.rs @@ -216,4 +216,156 @@ mod tests { assert!(policy1.check_read_source(source_id).await.is_ok()); assert!(policy2.check_read_source(source_id).await.is_ok()); } + + #[tokio::test] + async fn test_check_delete_thread_messages() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id1 = testutils::create_user(&db).await; + let user_id2 = testutils::create_user2(&db).await; + + let policy1 = AccessPolicy::new(db.clone(), &user_id1.as_id(), false); + + assert!(policy1 + .check_delete_thread_messages(&user_id1.as_id()) + .is_ok()); + assert!(policy1 + .check_delete_thread_messages(&user_id2.as_id()) + .is_err()); + } + + #[tokio::test] + async fn test_check_update_thread_persistence() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id1 = testutils::create_user(&db).await; + let user_id2 = testutils::create_user2(&db).await; + + let policy1 = AccessPolicy::new(db.clone(), &user_id1.as_id(), false); + + assert!(policy1 + .check_update_thread_persistence(&user_id1.as_id()) + .is_ok()); + assert!(policy1 + .check_update_thread_persistence(&user_id2.as_id()) + .is_err()); + } + + #[tokio::test] + async fn test_check_read_analytic() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id1 = testutils::create_user(&db).await; + let user_id2 = testutils::create_user2(&db).await; + + let policy_normal = AccessPolicy::new(db.clone(), &user_id1.as_id(), false); + let policy_admin = AccessPolicy::new(db.clone(), &user_id1.as_id(), true); + + assert!(policy_normal + .check_read_analytic(&[user_id1.as_id()]) + .is_ok()); + assert!(policy_normal + .check_read_analytic(&[user_id2.as_id()]) + .is_err()); + assert!(policy_normal.check_read_analytic(&[]).is_err()); + + assert!(policy_admin + .check_read_analytic(&[user_id1.as_id()]) + .is_ok()); + assert!(policy_admin + .check_read_analytic(&[user_id2.as_id()]) + .is_ok()); + assert!(policy_admin.check_read_analytic(&[]).is_ok()); + } + + #[tokio::test] + async fn test_check_upsert_user_group_membership() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id1 = testutils::create_user(&db).await; + let user_id2 = testutils::create_user2(&db).await; + let user_group_id = db.create_user_group("test").await.unwrap(); + + db.upsert_user_group_membership(user_id1, user_group_id, true) + .await + .unwrap(); + + let policy_normal = AccessPolicy::new(db.clone(), &user_id2.as_id(), false); + let policy_group_admin = AccessPolicy::new(db.clone(), &user_id1.as_id(), false); + let policy_admin = AccessPolicy::new(db.clone(), &user_id1.as_id(), true); + + let input = UpsertUserGroupMembershipInput { + user_id: user_id2.as_id(), + user_group_id: user_group_id.as_id(), + is_group_admin: false, + }; + + assert!(policy_normal + .check_upsert_user_group_membership(&input) + .await + .is_err()); + + assert!(policy_group_admin + .check_upsert_user_group_membership(&input) + .await + .is_ok()); + + let admin_input = UpsertUserGroupMembershipInput { + is_group_admin: true, + user_id: user_id2.as_id(), + user_group_id: user_group_id.as_id(), + }; + assert!(policy_group_admin + .check_upsert_user_group_membership(&admin_input) + .await + .is_err()); + + assert!(policy_admin + .check_upsert_user_group_membership(&input) + .await + .is_ok()); + assert!(policy_admin + .check_upsert_user_group_membership(&admin_input) + .await + .is_ok()); + } + #[tokio::test] + async fn test_check_delete_user_group_membership() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id1 = testutils::create_user(&db).await; + let user_id2 = testutils::create_user2(&db).await; + let user_group_id = db.create_user_group("test").await.unwrap(); + + // Make user1 a group admin and user2 a normal member + db.upsert_user_group_membership(user_id1, user_group_id, true) + .await + .unwrap(); + db.upsert_user_group_membership(user_id2, user_group_id, false) + .await + .unwrap(); + + let policy_normal = AccessPolicy::new(db.clone(), &user_id2.as_id(), false); + let policy_group_admin = AccessPolicy::new(db.clone(), &user_id1.as_id(), false); + let policy_admin = AccessPolicy::new(db.clone(), &user_id1.as_id(), true); + + assert!(policy_normal + .check_delete_user_group_membership(&user_group_id.as_id(), &user_id2.as_id()) + .await + .is_err()); + + assert!(policy_group_admin + .check_delete_user_group_membership(&user_group_id.as_id(), &user_id2.as_id()) + .await + .is_ok()); + + assert!(policy_group_admin + .check_delete_user_group_membership(&user_group_id.as_id(), &user_id1.as_id()) + .await + .is_err()); + + assert!(policy_admin + .check_delete_user_group_membership(&user_group_id.as_id(), &user_id1.as_id()) + .await + .is_ok()); + assert!(policy_admin + .check_delete_user_group_membership(&user_group_id.as_id(), &user_id2.as_id()) + .await + .is_ok()); + } } diff --git a/ee/tabby-webserver/src/service/job.rs b/ee/tabby-webserver/src/service/job.rs index f2ec46578c43..9973b22657c3 100644 --- a/ee/tabby-webserver/src/service/job.rs +++ b/ee/tabby-webserver/src/service/job.rs @@ -132,4 +132,78 @@ mod tests { assert_matches!(job2dao.exit_code, Some(-1)); assert!(!job2dao.is_pending()) } + + #[tokio::test] + async fn test_list() { + let db = DbConn::new_in_memory().await.unwrap(); + let svc = super::create(db.clone()).await; + + let job1 = BackgroundJobEvent::WebCrawler(WebCrawlerJob::new( + "s1".into(), + "http://abc.com".into(), + None, + )); + let job2 = BackgroundJobEvent::WebCrawler(WebCrawlerJob::new( + "s2".into(), + "http://def.com".into(), + None, + )); + + let id1 = svc.trigger(job1.to_command()).await.unwrap(); + let id2 = svc.trigger(job2.to_command()).await.unwrap(); + + let ids = Vec::from([id1.clone(), id2.clone()]); + let all_jobs = svc + .list(Some(ids), None, None, None, None, None) + .await + .unwrap(); + assert_eq!(all_jobs.len(), 2); + + let specific_jobs = svc + .list(Some(vec![id1.clone()]), None, None, None, None, None) + .await + .unwrap(); + assert_eq!(specific_jobs.len(), 1); + assert_eq!(specific_jobs[0].id, id1); + + let first_job = svc + .list(None, None, None, None, Some(1), None) + .await + .unwrap(); + assert_eq!(first_job.len(), 1); + assert_eq!(first_job[0].id, id1); + } + + #[tokio::test] + async fn test_compute_stats() { + let db = DbConn::new_in_memory().await.unwrap(); + let svc = super::create(db.clone()).await; + + let job1 = BackgroundJobEvent::WebCrawler(WebCrawlerJob::new( + "s1".into(), + "http://abc.com".into(), + None, + )); + let job2 = BackgroundJobEvent::WebCrawler(WebCrawlerJob::new( + "s2".into(), + "http://edf.com".into(), + None, + )); + + svc.trigger(job1.to_command()).await.unwrap(); + svc.trigger(job2.to_command()).await.unwrap(); + + let stats = svc.compute_stats(None).await.unwrap(); + assert_eq!(stats.pending, 2); + assert_eq!(stats.success, 0); + assert_eq!(stats.failed, 0); + + let _ = db.update_job_status(1, 0).await; + let _ = db.update_job_status(2, 1).await; + + let updated_stats = svc.compute_stats(None).await.unwrap(); + assert_eq!(updated_stats.pending, 0); + assert_eq!(updated_stats.success, 1); + assert_eq!(updated_stats.failed, 1); + } } diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs index c4d0ff4b4318..b2dcdc20eab8 100644 --- a/ee/tabby-webserver/src/service/thread.rs +++ b/ee/tabby-webserver/src/service/thread.rs @@ -270,11 +270,25 @@ pub fn create(db: DbConn, answer: Option>) -> impl ThreadServ #[cfg(test)] mod tests { + use tabby_common::{ + api::{ + code::{CodeSearch, CodeSearchParams}, + doc::DocSearch, + }, + config::AnswerConfig, + }; use tabby_db::{testutils::create_user, DbConn}; - use tabby_schema::thread::{CreateMessageInput, CreateThreadInput}; + use tabby_inference::ChatCompletionStream; + use tabby_schema::{ + context::ContextService, + thread::{CreateMessageInput, CreateThreadInput}, + }; use thread::MessageAttachmentCodeInput; use super::*; + use crate::answer::testutils::{ + FakeChatCompletionStream, FakeCodeSearch, FakeContextService, FakeDocSearch, + }; #[tokio::test] async fn test_create_thread() { @@ -426,4 +440,164 @@ mod tests { .unwrap(); assert_eq!(messages.len(), 1); } + + #[tokio::test] + async fn test_get_thread() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id = create_user(&db).await.as_id(); + let service = create(db, None); + + let input = CreateThreadInput { + user_message: CreateMessageInput { + content: "Ping".to_string(), + attachments: None, + }, + }; + + let thread_id = service.create(&user_id, &input).await.unwrap(); + + let thread = service.get(&thread_id).await.unwrap(); + assert!(thread.is_some()); + assert_eq!(thread.unwrap().id, thread_id); + + let non_existent_id = ID::from("non_existent".to_string()); + let non_existent_thread = service.get(&non_existent_id).await.unwrap(); + assert!(non_existent_thread.is_none()); + } + + #[tokio::test] + async fn test_set_persisted() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id = create_user(&db).await.as_id(); + let service = create(db.clone(), None); + + let input = CreateThreadInput { + user_message: CreateMessageInput { + content: "ping".to_string(), + attachments: None, + }, + }; + + let thread_id = service.create(&user_id, &input).await.unwrap(); + service.set_persisted(&thread_id).await.unwrap(); + } + + pub fn make_code_search_params() -> CodeSearchParams { + CodeSearchParams { + min_bm25_score: 0.5, + min_embedding_score: 0.7, + min_rrf_score: 0.3, + num_to_return: 5, + num_to_score: 10, + } + } + + pub fn make_answer_config() -> AnswerConfig { + AnswerConfig { + code_search_params: make_code_search_params(), + presence_penalty: 0.1, + } + } + + #[tokio::test] + async fn test_create_run() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id = create_user(&db).await.as_id(); + let chat: Arc = Arc::new(FakeChatCompletionStream); + let code: Arc = Arc::new(FakeCodeSearch); + let doc: Arc = Arc::new(FakeDocSearch); + let context: Arc = Arc::new(FakeContextService); + let serper = Some(Box::new(FakeDocSearch) as Box); + let config = make_answer_config(); + let answer_service = Arc::new(crate::answer::create( + &config, + chat.clone(), + code.clone(), + doc.clone(), + context.clone(), + serper, + )); + let service = create(db.clone(), Some(answer_service)); + + let input = CreateThreadInput { + user_message: CreateMessageInput { + content: "Test message".to_string(), + attachments: None, + }, + }; + + let thread_id = service.create(&user_id, &input).await.unwrap(); + + let policy = AccessPolicy::new(db.clone(), &user_id, false); + let options = ThreadRunOptionsInput::default(); + + let run_stream = service + .create_run(&policy, &thread_id, &options, None, true, true) + .await; + + assert!(run_stream.is_ok()); + } + + #[tokio::test] + async fn test_list_threads() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id = create_user(&db).await.as_id(); + let service = create(db, None); + + for i in 0..3 { + let input = CreateThreadInput { + user_message: CreateMessageInput { + content: format!("Test message {}", i), + attachments: None, + }, + }; + service.create(&user_id, &input).await.unwrap(); + } + + let threads = service + .list(None, None, None, None, None, None) + .await + .unwrap(); + assert_eq!(threads.len(), 3); + + let first_two = service + .list(None, None, None, None, Some(2), None) + .await + .unwrap(); + assert_eq!(first_two.len(), 2); + + let last_two = service + .list(None, None, None, None, None, Some(2)) + .await + .unwrap(); + assert_eq!(last_two.len(), 2); + assert_ne!(first_two[0].id, last_two[0].id); + + let ephemeral_threads = service + .list(None, Some(true), None, None, None, None) + .await + .unwrap(); + assert_eq!(ephemeral_threads.len(), 3); + + service.set_persisted(&threads[0].id).await.unwrap(); + + let persisted_threads = service + .list(None, Some(false), None, None, None, None) + .await + .unwrap(); + assert_eq!(persisted_threads.len(), 1); + + let specific_threads = service + .list( + Some(&[threads[0].id.clone(), threads[1].id.clone()]), + None, + None, + None, + None, + None, + ) + .await + .unwrap(); + assert_eq!(specific_threads.len(), 2); + } }