From fae48d4973313e56530d98e9793a24871ec25516 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 16 Aug 2024 10:06:47 -0700 Subject: [PATCH] chore(webserver): add deleteThreadMessagePair mutation (#2889) * chore(webserver): add deleteThreadMessagePair mutation * update * update * add test coverage * update * update --- ee/tabby-db/src/threads.rs | 42 +++++++++++ ee/tabby-schema/graphql/schema.graphql | 2 + ee/tabby-schema/src/schema/mod.rs | 27 +++++++ ee/tabby-schema/src/schema/thread.rs | 11 ++- ee/tabby-webserver/src/service/thread.rs | 96 ++++++++++++++++++++++++ 5 files changed, 175 insertions(+), 3 deletions(-) diff --git a/ee/tabby-db/src/threads.rs b/ee/tabby-db/src/threads.rs index 7466b903d026..12a44fb861dd 100644 --- a/ee/tabby-db/src/threads.rs +++ b/ee/tabby-db/src/threads.rs @@ -234,4 +234,46 @@ impl DbConn { Ok(messages) } + + pub async fn delete_thread_message_pair( + &self, + thread_id: i64, + user_message_id: i64, + assistant_message_id: i64, + ) -> Result<()> { + #[derive(FromRow)] + struct Response { + id: i64, + role: String, + } + + let message = query_as!( + Response, + "SELECT id, role FROM thread_messages WHERE thread_id = ? AND id >= ? AND id <= ?", + thread_id, + user_message_id, + assistant_message_id + ) + .fetch_all(&self.pool) + .await?; + + if message.len() != 2 { + bail!("Thread message pair is not valid") + } + + let is_valid_user_message = message[0].id == user_message_id && message[0].role == "user"; + let is_valid_assistant_message = + message[1].id == assistant_message_id && message[1].role == "assistant"; + + if !is_valid_user_message || !is_valid_assistant_message { + bail!("Invalid message pair"); + } + + let message_ids = format!("{}, {}", user_message_id, assistant_message_id); + query!("DELETE FROM thread_messages WHERE id IN (?)", message_ids) + .execute(&self.pool) + .await?; + + Ok(()) + } } diff --git a/ee/tabby-schema/graphql/schema.graphql b/ee/tabby-schema/graphql/schema.graphql index 667c3e145b97..5953b40f5776 100644 --- a/ee/tabby-schema/graphql/schema.graphql +++ b/ee/tabby-schema/graphql/schema.graphql @@ -475,6 +475,8 @@ type Mutation { triggerJobRun(command: String!): ID! createWebCrawlerUrl(input: CreateWebCrawlerUrlInput!): ID! deleteWebCrawlerUrl(id: ID!): Boolean! + "Delete pair of user message and bot response in a thread." + deleteThreadMessagePair(threadId: ID!, userMessageId: ID!, assistantMessageId: ID!): Boolean! } type NetworkSetting { diff --git a/ee/tabby-schema/src/schema/mod.rs b/ee/tabby-schema/src/schema/mod.rs index 6d396d14be8c..d39085639e2a 100644 --- a/ee/tabby-schema/src/schema/mod.rs +++ b/ee/tabby-schema/src/schema/mod.rs @@ -916,6 +916,33 @@ impl Mutation { ctx.locator.web_crawler().delete_web_crawler_url(id).await?; Ok(true) } + + /// Delete pair of user message and bot response in a thread. + async fn delete_thread_message_pair( + ctx: &Context, + thread_id: ID, + user_message_id: ID, + assistant_message_id: ID, + ) -> Result { + // ast-grep-ignore: use-schema-result + use anyhow::Context; + + let user = check_user(ctx).await?; + let svc = ctx.locator.thread(); + let thread = svc.get(&thread_id).await?.context("Thread not found")?; + + if thread.user_id != user.id { + return Err(CoreError::Forbidden( + "You must be the thread owner to delete the latest message pair", + )); + } + + ctx.locator + .thread() + .delete_thread_message_pair(&thread_id, &user_message_id, &assistant_message_id) + .await?; + Ok(true) + } } async fn check_analytic_access(ctx: &Context, users: &[ID]) -> Result<(), CoreError> { diff --git a/ee/tabby-schema/src/schema/thread.rs b/ee/tabby-schema/src/schema/thread.rs index e91744625d5c..4ab637490101 100644 --- a/ee/tabby-schema/src/schema/thread.rs +++ b/ee/tabby-schema/src/schema/thread.rs @@ -46,9 +46,6 @@ pub trait ThreadService: Send + Sync { // /// Delete a thread by ID // async fn delete(&self, id: ID) -> Result<()>; - // /// Delete a message by ID - // async fn delete_message(&self, id: ID) -> Result<()>; - /// Query messages in a thread async fn list_thread_messages( &self, @@ -58,4 +55,12 @@ pub trait ThreadService: Send + Sync { first: Option, last: Option, ) -> Result>; + + /// Delete pair of user message and bot response in a thread. + async fn delete_thread_message_pair( + &self, + thread_id: &ID, + user_message_id: &ID, + assistant_message_id: &ID, + ) -> Result<()>; } diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs index 931dfe60c6c2..d6a1f5249c66 100644 --- a/ee/tabby-webserver/src/service/thread.rs +++ b/ee/tabby-webserver/src/service/thread.rs @@ -209,6 +209,22 @@ impl ThreadService for ThreadServiceImpl { to_vec_messages(messages) } + + async fn delete_thread_message_pair( + &self, + thread_id: &ID, + user_message_id: &ID, + assistant_message_id: &ID, + ) -> Result<()> { + self.db + .delete_thread_message_pair( + thread_id.as_rowid()?, + user_message_id.as_rowid()?, + assistant_message_id.as_rowid()?, + ) + .await?; + Ok(()) + } } fn to_vec_messages(messages: Vec) -> Result> { @@ -280,4 +296,84 @@ mod tests { .await .is_err()); } + + #[tokio::test] + async fn test_delete_thread_message_pair() { + let db = DbConn::new_in_memory().await.unwrap(); + let user_id = create_user(&db).await.as_id(); + let service = create(db.clone(), None); + + let thread_id = service + .create( + &user_id, + &CreateThreadInput { + user_message: CreateMessageInput { + content: "Ping!".to_string(), + attachments: None, + }, + }, + ) + .await + .unwrap(); + + let assistant_message_id = db + .create_thread_message( + thread_id.as_rowid().unwrap(), + thread::Role::Assistant.as_enum_str(), + "Pong!", + None, + None, + false, + ) + .await + .unwrap(); + + let user_message_id = assistant_message_id - 1; + + // Create another user message to test the error case + let another_user_message_id = db + .create_thread_message( + thread_id.as_rowid().unwrap(), + thread::Role::User.as_enum_str(), + "Ping another time!", + None, + None, + false, + ) + .await + .unwrap(); + + let messages = service + .list_thread_messages(&thread_id, None, None, None, None) + .await + .unwrap(); + assert_eq!(messages.len(), 3); + + assert!(service + .delete_thread_message_pair( + &thread_id, + &another_user_message_id.as_id(), + &assistant_message_id.as_id() + ) + .await + .is_err()); + + assert!(service + .delete_thread_message_pair( + &thread_id, + &assistant_message_id.as_id(), + &another_user_message_id.as_id() + ) + .await + .is_err()); + + assert!(service + .delete_thread_message_pair( + &thread_id, + &user_message_id.as_id(), + &assistant_message_id.as_id() + ) + .await + .is_ok()); + } }