diff --git a/src/backend/ggml.rs b/src/backend/ggml.rs
index ee4617f..2ea41ad 100644
--- a/src/backend/ggml.rs
+++ b/src/backend/ggml.rs
@@ -270,29 +270,13 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
     };

     // * perform retrieval
-    let mut retrieve_object = match retrieve_context(&mut chat_request, server_info).await {
+    let retrieve_object = match retrieve_context(&mut chat_request, server_info).await {
         Ok(retrieve_object) => retrieve_object,
         Err(response) => {
             return response;
         }
     };

-    // keep the point with the highest score
-    let scored_points = retrieve_object.points.as_mut().unwrap();
-
-    for (idx, point) in scored_points.iter().enumerate() {
-        // log
-        info!(target: "stdout", "point: {}, score: {}, source: {}", idx, point.score, &point.source);
-    }
-
-    // remove all points which have different score from the first point
-    if scored_points.len() > 1 {
-        let first_score = scored_points[0].score;
-        scored_points.retain(|point| point.score == first_score);
-    }
-
-    info!(target: "stdout", "{} point(s) with the highest score kept", scored_points.len());
-
     // * update messages with retrieved context
     match retrieve_object.points {
         Some(scored_points) => {
@@ -457,7 +441,10 @@ async fn retrieve_context(
 ) -> Result<RetrieveObject, Response<Body>> {
     info!(target: "stdout", "Compute embeddings for user query.");

-    // * compute embeddings for user query
+    let context_window = chat_request.context_window.unwrap() as usize;
+    info!(target: "stdout", "context window: {}", context_window);
+
+    // compute embeddings for user query
     let embedding_response = match chat_request.messages.is_empty() {
         true => {
             let err_msg = "Messages should not be empty.";
@@ -468,10 +455,11 @@ async fn retrieve_context(
             return Err(error::bad_request(err_msg));
         }
         false => {
-            let last_message = chat_request.messages.last().unwrap();
-            match last_message {
-                ChatCompletionRequestMessage::User(user_message) => {
-                    let query_text = match user_message.content() {
+            // get the user messages in the context window
+            let mut last_messages = Vec::new();
+            for (idx, message) in chat_request.messages.iter().rev().enumerate() {
+                if let ChatCompletionRequestMessage::User(user_message) = message {
+                    let user_content = match user_message.content() {
                         ChatCompletionUserMessageContent::Text(text) => text,
                         _ => {
                             let err_msg = "The last message must be a text content user message";
@@ -483,56 +471,75 @@ async fn retrieve_context(
                         }
                     };

-                    // log
-                    info!(target: "stdout", "query text: {}", query_text);
+                    if !user_content.ends_with("") {
+                        last_messages.push(user_content.clone());
+                    } else if idx == 0 {
+                        let content = user_content.trim_end_matches("").to_string();
+                        last_messages.push(content);
+                        break;
+                    }
+                }

-                    // get the available embedding models
-                    let embedding_model_names = match llama_core::utils::embedding_model_names() {
-                        Ok(model_names) => model_names,
-                        Err(e) => {
-                            let err_msg = e.to_string();
+                if last_messages.len() == context_window {
+                    break;
+                }
+            }

-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
+            // join the user messages in the context window into a single string
+            let query_text = if last_messages.len() > 0 {
+                info!(target: "stdout", "Found the latest {} user messages.", last_messages.len());

-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    };
+                last_messages.reverse();
+                last_messages.join("\n")
+            } else {
+                let warn_msg = "No user messages found.";

-                    // create a embedding request
-                    let embedding_request = EmbeddingRequest {
-                        model: embedding_model_names[0].clone(),
-                        input: query_text.into(),
-                        encoding_format: None,
-                        user: chat_request.user.clone(),
-                    };
+                // log
+                warn!(target: "stdout", "{}", &warn_msg);

-                    let rag_embedding_request = RagEmbeddingRequest {
-                        embedding_request,
-                        qdrant_url: server_info.qdrant_config.url.clone(),
-                        qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
-                    };
+                return Err(error::bad_request(warn_msg));
+            };

-                    // compute embeddings for query
-                    match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
-                        Ok(embedding_response) => embedding_response,
-                        Err(e) => {
-                            let err_msg = e.to_string();
+            // log
+            info!(target: "stdout", "query text for the context retrieval: {}", query_text);

-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
+            // get the available embedding models
+            let embedding_model_names = match llama_core::utils::embedding_model_names() {
+                Ok(model_names) => model_names,
+                Err(e) => {
+                    let err_msg = e.to_string();

-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    }
+                    // log
+                    error!(target: "stdout", "{}", &err_msg);
+
+                    return Err(error::internal_server_error(err_msg));
                 }
-                _ => {
-                    let err_msg = "The last message must be a user message";
+            };
+
+            // create a embedding request
+            let embedding_request = EmbeddingRequest {
+                model: embedding_model_names[0].clone(),
+                input: InputText::String(query_text),
+                encoding_format: None,
+                user: chat_request.user.clone(),
+            };
+
+            let rag_embedding_request = RagEmbeddingRequest {
+                embedding_request,
+                qdrant_url: server_info.qdrant_config.url.clone(),
+                qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
+            };
+
+            // compute embeddings for query
+            match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
+                Ok(embedding_response) => embedding_response,
+                Err(e) => {
+                    let err_msg = e.to_string();

                     // log
                     error!(target: "stdout", "{}", &err_msg);

-                    return Err(error::bad_request(err_msg));
+                    return Err(error::internal_server_error(err_msg));
                 }
             }
         }
@@ -549,9 +556,8 @@ async fn retrieve_context(
         }
     };

-    // perform the first round retrieval (using the last user message)
-    let mut retrieve_object = RetrieveObject::default();
-    match llama_core::rag::rag_retrieve_context(
+    // perform the context retrieval
+    let mut retrieve_object: RetrieveObject = match llama_core::rag::rag_retrieve_context(
         query_embedding.as_slice(),
         server_info.qdrant_config.url.to_string().as_str(),
         server_info.qdrant_config.collection_name.as_str(),
@@ -560,173 +566,21 @@ async fn retrieve_context(
     )
     .await
     {
-        Ok(search_result) => {
-            retrieve_object = search_result;
-            if retrieve_object.points.is_none() {
-                retrieve_object.points = Some(Vec::new());
-            }
-
-            info!(target: "stdout", "{} point(s) retrieved from the first retrieval", retrieve_object.points.as_ref().unwrap().len());
-        }
+        Ok(search_result) => search_result,
         Err(e) => {
-            // log
-            error!(target: "stdout", "No point retrieved. {}", e);
-        }
-    };
-
-    // perform the second round retrieval (using the last 3 user messages)
-    if let Some(multi_retrieval) = MULTI_RETRIEVAL.get() {
-        if *multi_retrieval {
-            let embedding_response_second_retrieval = match chat_request.messages.is_empty() {
-                true => {
-                    let err_msg = "Messages should not be empty.";
-
-                    // log
-                    error!(target: "stdout", "{}", &err_msg);
-
-                    return Err(error::bad_request(err_msg));
-                }
-                false => {
-                    // get the last 3 user messages
-                    let mut last_messages = Vec::new();
-                    for message in chat_request.messages.iter().rev() {
-                        if let ChatCompletionRequestMessage::User(user_message) = message {
-                            let user_content = match user_message.content() {
-                                ChatCompletionUserMessageContent::Text(text) => text,
-                                _ => {
-                                    let err_msg =
-                                        "The last message must be a text content user message";
+            let err_msg = format!("No point retrieved. {}", e);

-                                    // log
-                                    error!(target: "stdout", "{}", &err_msg);
-
-                                    return Err(error::bad_request(err_msg));
-                                }
-                            };
-
-                            last_messages.push(user_content.clone());
-                        }
-
-                        if last_messages.len() == 3 {
-                            break;
-                        }
-                    }
-
-                    info!(target: "stdout", "Found the latest {} user messages.", last_messages.len());
-
-                    // join the last 3 user messages into a single string
-                    let query_text = last_messages.join("\n");
-
-                    // log
-                    info!(target: "stdout", "query text for the second retrieval: {}", query_text);
-
-                    // get the available embedding models
-                    let embedding_model_names = match llama_core::utils::embedding_model_names() {
-                        Ok(model_names) => model_names,
-                        Err(e) => {
-                            let err_msg = e.to_string();
-
-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
-
-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    };
-
-                    // create a embedding request
-                    let embedding_request = EmbeddingRequest {
-                        model: embedding_model_names[0].clone(),
-                        input: InputText::String(query_text),
-                        encoding_format: None,
-                        user: chat_request.user.clone(),
-                    };
-
-                    let rag_embedding_request = RagEmbeddingRequest {
-                        embedding_request,
-                        qdrant_url: server_info.qdrant_config.url.clone(),
-                        qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
-                    };
-
-                    // compute embeddings for query
-                    match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
-                        Ok(embedding_response) => embedding_response,
-                        Err(e) => {
-                            let err_msg = e.to_string();
-
-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
-
-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    }
-                }
-            };
-
-            // compute embeddings for the multiple user messages
-            let query_embedding_second_retrieval: Vec<f32> =
-                match embedding_response_second_retrieval.data.first() {
-                    Some(embedding) => embedding.embedding.iter().map(|x| *x as f32).collect(),
-                    None => {
-                        let err_msg = "No embeddings returned";
-
-                        // log
-                        error!(target: "stdout", "{}", &err_msg);
-
-                        return Err(error::internal_server_error(err_msg));
-                    }
-                };
+            // log
+            error!(target: "stdout", "{}", &err_msg);

-            // perform the second retrieval
-            match llama_core::rag::rag_retrieve_context(
-                query_embedding_second_retrieval.as_slice(),
-                server_info.qdrant_config.url.to_string().as_str(),
-                server_info.qdrant_config.collection_name.as_str(),
-                server_info.qdrant_config.limit as usize,
-                Some(server_info.qdrant_config.score_threshold),
-            )
-            .await
-            {
-                Ok(retrieve_object_second_retrieval) => {
-                    match retrieve_object_second_retrieval.points {
-                        Some(scored_points) if !scored_points.is_empty() => {
-                            info!(target: "stdout", "{} point(s) retrieved from the second retrieval", scored_points.len());
-
-                            if let Some(points) = retrieve_object.points.as_mut() {
-                                points.extend(scored_points);
-                            }
-                        }
-                        _ => {
-                            warn!(target: "stdout", "No point retrieved from the second retrieval");
-                        }
-                    }
-                }
-                Err(e) => {
-                    // log
-                    error!(target: "stdout", "No point retrieved from the second retrieval. {}", e);
-                }
-            };
+            return Err(error::internal_server_error(err_msg));
         }
+    };
+    if retrieve_object.points.is_none() {
+        retrieve_object.points = Some(Vec::new());
     }

-    info!(target: "stdout", "{} point(s) retrieved from the two rounds of retrieval", retrieve_object.points.as_ref().unwrap().len());
-
-    // sort the points by score
-    retrieve_object.points.as_mut().unwrap().sort_by(|a, b| {
-        // Compare scores in reverse order (highest first)
-        // Using partial_cmp for floating point comparison
-        b.score
-            .partial_cmp(&a.score)
-            .unwrap_or(std::cmp::Ordering::Equal)
-    });
-
-    // remove the duplicate points
-    retrieve_object
-        .points
-        .as_mut()
-        .unwrap()
-        .dedup_by_key(|point| point.source.clone());
-
-    info!(target: "stdout", "{} point(s) kept after removing duplicates", retrieve_object.points.as_ref().unwrap().len());
+    info!(target: "stdout", "{} point(s) retrieved", retrieve_object.points.as_ref().unwrap().len());

     Ok(retrieve_object)
 }