diff --git a/src/backend/ggml.rs b/src/backend/ggml.rs
index ee4617f..2ea41ad 100644
--- a/src/backend/ggml.rs
+++ b/src/backend/ggml.rs
@@ -270,29 +270,13 @@ pub(crate) async fn rag_query_handler(mut req: Request
) -> Response
};
// * perform retrieval
- let mut retrieve_object = match retrieve_context(&mut chat_request, server_info).await {
+ let retrieve_object = match retrieve_context(&mut chat_request, server_info).await {
Ok(retrieve_object) => retrieve_object,
Err(response) => {
return response;
}
};
- // keep the point with the highest score
- let scored_points = retrieve_object.points.as_mut().unwrap();
-
- for (idx, point) in scored_points.iter().enumerate() {
- // log
- info!(target: "stdout", "point: {}, score: {}, source: {}", idx, point.score, &point.source);
- }
-
- // remove all points which have different score from the first point
- if scored_points.len() > 1 {
- let first_score = scored_points[0].score;
- scored_points.retain(|point| point.score == first_score);
- }
-
- info!(target: "stdout", "{} point(s) with the highest score kept", scored_points.len());
-
// * update messages with retrieved context
match retrieve_object.points {
Some(scored_points) => {
@@ -457,7 +441,10 @@ async fn retrieve_context(
) -> Result<RetrieveObject, Response> {
info!(target: "stdout", "Compute embeddings for user query.");
- // * compute embeddings for user query
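+ // `context_window` is the number of most recent user messages used to build the retrieval query
+ // (assumed to be populated upstream, since it is unwrapped here without a fallback)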
+ let context_window = chat_request.context_window.unwrap() as usize;
+ info!(target: "stdout", "context window: {}", context_window);
+
+ // compute embeddings for user query
let embedding_response = match chat_request.messages.is_empty() {
true => {
let err_msg = "Messages should not be empty.";
@@ -468,10 +455,11 @@ async fn retrieve_context(
return Err(error::bad_request(err_msg));
}
false => {
- let last_message = chat_request.messages.last().unwrap();
- match last_message {
- ChatCompletionRequestMessage::User(user_message) => {
- let query_text = match user_message.content() {
+ // get the user messages in the context window
+ let mut last_messages = Vec::new();
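+ // walk the messages from newest to oldest, collecting user messages until the context window is filled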
+ for (idx, message) in chat_request.messages.iter().rev().enumerate() {
+ if let ChatCompletionRequestMessage::User(user_message) = message {
+ let user_content = match user_message.content() {
ChatCompletionUserMessageContent::Text(text) => text,
_ => {
let err_msg = "The last message must be a text content user message";
@@ -483,56 +471,75 @@ async fn retrieve_context(
}
};
- // log
- info!(target: "stdout", "query text: {}", query_text);
+ if !user_content.ends_with("") {
+ last_messages.push(user_content.clone());
+ } else if idx == 0 {
+ let content = user_content.trim_end_matches("").to_string();
+ last_messages.push(content);
+ break;
+ }
+ }
- // get the available embedding models
- let embedding_model_names = match llama_core::utils::embedding_model_names() {
- Ok(model_names) => model_names,
- Err(e) => {
- let err_msg = e.to_string();
+ if last_messages.len() == context_window {
+ break;
+ }
+ }
- // log
- error!(target: "stdout", "{}", &err_msg);
+ // join the user messages in the context window into a single string
+ let query_text = if !last_messages.is_empty() {
+ info!(target: "stdout", "Found the latest {} user messages.", last_messages.len());
- return Err(error::internal_server_error(err_msg));
- }
- };
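+ // the messages were collected newest-first, so restore chronological order before joining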
+ last_messages.reverse();
+ last_messages.join("\n")
+ } else {
+ let warn_msg = "No user messages found.";
- // create a embedding request
- let embedding_request = EmbeddingRequest {
- model: embedding_model_names[0].clone(),
- input: query_text.into(),
- encoding_format: None,
- user: chat_request.user.clone(),
- };
+ // log
+ warn!(target: "stdout", "{}", &warn_msg);
- let rag_embedding_request = RagEmbeddingRequest {
- embedding_request,
- qdrant_url: server_info.qdrant_config.url.clone(),
- qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
- };
+ return Err(error::bad_request(warn_msg));
+ };
- // compute embeddings for query
- match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
- Ok(embedding_response) => embedding_response,
- Err(e) => {
- let err_msg = e.to_string();
+ // log
+ info!(target: "stdout", "query text for the context retrieval: {}", query_text);
- // log
- error!(target: "stdout", "{}", &err_msg);
+ // get the available embedding models
+ let embedding_model_names = match llama_core::utils::embedding_model_names() {
+ Ok(model_names) => model_names,
+ Err(e) => {
+ let err_msg = e.to_string();
- return Err(error::internal_server_error(err_msg));
- }
- }
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return Err(error::internal_server_error(err_msg));
}
- _ => {
- let err_msg = "The last message must be a user message";
+ };
+
+ // create an embedding request using the first available embedding model
+ let embedding_request = EmbeddingRequest {
+ model: embedding_model_names[0].clone(),
+ input: InputText::String(query_text),
+ encoding_format: None,
+ user: chat_request.user.clone(),
+ };
+
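+ // pair the embedding request with the target Qdrant endpoint and collection for the retrieval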
+ let rag_embedding_request = RagEmbeddingRequest {
+ embedding_request,
+ qdrant_url: server_info.qdrant_config.url.clone(),
+ qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
+ };
+
+ // compute embeddings for query
+ match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
+ Ok(embedding_response) => embedding_response,
+ Err(e) => {
+ let err_msg = e.to_string();
// log
error!(target: "stdout", "{}", &err_msg);
- return Err(error::bad_request(err_msg));
+ return Err(error::internal_server_error(err_msg));
}
}
}
@@ -549,9 +556,8 @@ async fn retrieve_context(
}
};
- // perform the first round retrieval (using the last user message)
- let mut retrieve_object = RetrieveObject::default();
- match llama_core::rag::rag_retrieve_context(
+ // perform the context retrieval
+ let mut retrieve_object: RetrieveObject = match llama_core::rag::rag_retrieve_context(
query_embedding.as_slice(),
server_info.qdrant_config.url.to_string().as_str(),
server_info.qdrant_config.collection_name.as_str(),
@@ -560,173 +566,21 @@ async fn retrieve_context(
)
.await
{
- Ok(search_result) => {
- retrieve_object = search_result;
- if retrieve_object.points.is_none() {
- retrieve_object.points = Some(Vec::new());
- }
-
- info!(target: "stdout", "{} point(s) retrieved from the first retrieval", retrieve_object.points.as_ref().unwrap().len());
- }
+ Ok(search_result) => search_result,
Err(e) => {
- // log
- error!(target: "stdout", "No point retrieved. {}", e);
- }
- };
-
- // perform the second round retrieval (using the last 3 user messages)
- if let Some(multi_retrieval) = MULTI_RETRIEVAL.get() {
- if *multi_retrieval {
- let embedding_response_second_retrieval = match chat_request.messages.is_empty() {
- true => {
- let err_msg = "Messages should not be empty.";
-
- // log
- error!(target: "stdout", "{}", &err_msg);
-
- return Err(error::bad_request(err_msg));
- }
- false => {
- // get the last 3 user messages
- let mut last_messages = Vec::new();
- for message in chat_request.messages.iter().rev() {
- if let ChatCompletionRequestMessage::User(user_message) = message {
- let user_content = match user_message.content() {
- ChatCompletionUserMessageContent::Text(text) => text,
- _ => {
- let err_msg =
- "The last message must be a text content user message";
+ let err_msg = format!("No point retrieved. {}", e);
- // log
- error!(target: "stdout", "{}", &err_msg);
-
- return Err(error::bad_request(err_msg));
- }
- };
-
- last_messages.push(user_content.clone());
- }
-
- if last_messages.len() == 3 {
- break;
- }
- }
-
- info!(target: "stdout", "Found the latest {} user messages.", last_messages.len());
-
- // join the last 3 user messages into a single string
- let query_text = last_messages.join("\n");
-
- // log
- info!(target: "stdout", "query text for the second retrieval: {}", query_text);
-
- // get the available embedding models
- let embedding_model_names = match llama_core::utils::embedding_model_names() {
- Ok(model_names) => model_names,
- Err(e) => {
- let err_msg = e.to_string();
-
- // log
- error!(target: "stdout", "{}", &err_msg);
-
- return Err(error::internal_server_error(err_msg));
- }
- };
-
- // create a embedding request
- let embedding_request = EmbeddingRequest {
- model: embedding_model_names[0].clone(),
- input: InputText::String(query_text),
- encoding_format: None,
- user: chat_request.user.clone(),
- };
-
- let rag_embedding_request = RagEmbeddingRequest {
- embedding_request,
- qdrant_url: server_info.qdrant_config.url.clone(),
- qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
- };
-
- // compute embeddings for query
- match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
- Ok(embedding_response) => embedding_response,
- Err(e) => {
- let err_msg = e.to_string();
-
- // log
- error!(target: "stdout", "{}", &err_msg);
-
- return Err(error::internal_server_error(err_msg));
- }
- }
- }
- };
-
- // compute embeddings for the multiple user messages
- let query_embedding_second_retrieval: Vec<f32> =
- match embedding_response_second_retrieval.data.first() {
- Some(embedding) => embedding.embedding.iter().map(|x| *x as f32).collect(),
- None => {
- let err_msg = "No embeddings returned";
-
- // log
- error!(target: "stdout", "{}", &err_msg);
-
- return Err(error::internal_server_error(err_msg));
- }
- };
+ // log
+ error!(target: "stdout", "{}", &err_msg);
- // perform the second retrieval
- match llama_core::rag::rag_retrieve_context(
- query_embedding_second_retrieval.as_slice(),
- server_info.qdrant_config.url.to_string().as_str(),
- server_info.qdrant_config.collection_name.as_str(),
- server_info.qdrant_config.limit as usize,
- Some(server_info.qdrant_config.score_threshold),
- )
- .await
- {
- Ok(retrieve_object_second_retrieval) => {
- match retrieve_object_second_retrieval.points {
- Some(scored_points) if !scored_points.is_empty() => {
- info!(target: "stdout", "{} point(s) retrieved from the second retrieval", scored_points.len());
-
- if let Some(points) = retrieve_object.points.as_mut() {
- points.extend(scored_points);
- }
- }
- _ => {
- warn!(target: "stdout", "No point retrieved from the second retrieval");
- }
- }
- }
- Err(e) => {
- // log
- error!(target: "stdout", "No point retrieved from the second retrieval. {}", e);
- }
- };
+ return Err(error::internal_server_error(err_msg));
}
+ };
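+ // ensure `points` is always `Some` so the logging below and the caller's match can rely on it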
+ if retrieve_object.points.is_none() {
+ retrieve_object.points = Some(Vec::new());
}
- info!(target: "stdout", "{} point(s) retrieved from the two rounds of retrieval", retrieve_object.points.as_ref().unwrap().len());
-
- // sort the points by score
- retrieve_object.points.as_mut().unwrap().sort_by(|a, b| {
- // Compare scores in reverse order (highest first)
- // Using partial_cmp for floating point comparison
- b.score
- .partial_cmp(&a.score)
- .unwrap_or(std::cmp::Ordering::Equal)
- });
-
- // remove the duplicate points
- retrieve_object
- .points
- .as_mut()
- .unwrap()
- .dedup_by_key(|point| point.source.clone());
-
- info!(target: "stdout", "{} point(s) kept after removing duplicates", retrieve_object.points.as_ref().unwrap().len());
+ info!(target: "stdout", "{} point(s) retrieved", retrieve_object.points.as_ref().unwrap().len());
Ok(retrieve_object)
}