refactor: update retrieve_context and rag_query_handler funcs
Signed-off-by: Xin Liu <[email protected]>
apepkuss committed Nov 5, 2024
1 parent 206ad34 commit 02ffa0f
Showing 1 changed file with 77 additions and 223 deletions.
300 changes: 77 additions & 223 deletions src/backend/ggml.rs
@@ -270,29 +270,13 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
     };

     // * perform retrieval
-    let mut retrieve_object = match retrieve_context(&mut chat_request, server_info).await {
+    let retrieve_object = match retrieve_context(&mut chat_request, server_info).await {
         Ok(retrieve_object) => retrieve_object,
         Err(response) => {
             return response;
         }
     };

-    // keep the point with the highest score
-    let scored_points = retrieve_object.points.as_mut().unwrap();
-
-    for (idx, point) in scored_points.iter().enumerate() {
-        // log
-        info!(target: "stdout", "point: {}, score: {}, source: {}", idx, point.score, &point.source);
-    }
-
-    // remove all points which have different score from the first point
-    if scored_points.len() > 1 {
-        let first_score = scored_points[0].score;
-        scored_points.retain(|point| point.score == first_score);
-    }
-
-    info!(target: "stdout", "{} point(s) with the highest score kept", scored_points.len());
-
     // * update messages with retrieved context
     match retrieve_object.points {
         Some(scored_points) => {
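
Note on behavior: with this hunk the handler no longer collapses the result set to the top-scoring points; every point returned by retrieve_context flows into the context-update step. For a caller that still wants the old filtering, a minimal standalone sketch (Point is a stand-in for the crate's scored-point type, which this page does not show in full):

    #[derive(Debug, Clone)]
    struct Point {
        score: f32,
        source: String,
    }

    fn keep_top_score(points: &mut Vec<Point>) {
        if points.len() > 1 {
            let top = points[0].score;
            // exact f32 equality, as in the removed code; assumes tied
            // points carry identical scores
            points.retain(|p| p.score == top);
        }
    }

    fn main() {
        let mut pts = vec![
            Point { score: 0.92, source: "a.md".into() },
            Point { score: 0.92, source: "b.md".into() },
            Point { score: 0.40, source: "c.md".into() },
        ];
        keep_top_score(&mut pts);
        assert_eq!(pts.len(), 2); // only the 0.92 points survive
    }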
@@ -457,7 +441,10 @@ async fn retrieve_context(
 ) -> Result<RetrieveObject, Response<Body>> {
     info!(target: "stdout", "Compute embeddings for user query.");

-    // * compute embeddings for user query
+    let context_window = chat_request.context_window.unwrap() as usize;
+    info!(target: "stdout", "context window: {}", context_window);
+
+    // compute embeddings for user query
     let embedding_response = match chat_request.messages.is_empty() {
         true => {
             let err_msg = "Messages should not be empty.";
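
One caveat with the added lines: context_window is an Option, so unwrap() panics when a request omits the field. A defensive variant would fall back to a window of one message; the following is a sketch under that assumption (the field is assumed to be a u64), not the committed behavior:

    fn effective_window(context_window: Option<u64>) -> usize {
        // default to 1, i.e. just the latest user message, when unset
        context_window.unwrap_or(1) as usize
    }

    fn main() {
        assert_eq!(effective_window(None), 1);
        assert_eq!(effective_window(Some(3)), 3);
    }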
@@ -468,10 +455,11 @@ async fn retrieve_context(
             return Err(error::bad_request(err_msg));
         }
         false => {
-            let last_message = chat_request.messages.last().unwrap();
-            match last_message {
-                ChatCompletionRequestMessage::User(user_message) => {
-                    let query_text = match user_message.content() {
+            // get the user messages in the context window
+            let mut last_messages = Vec::new();
+            for (idx, message) in chat_request.messages.iter().rev().enumerate() {
+                if let ChatCompletionRequestMessage::User(user_message) = message {
+                    let user_content = match user_message.content() {
                         ChatCompletionUserMessageContent::Text(text) => text,
                         _ => {
                             let err_msg = "The last message must be a text content user message";
@@ -483,56 +471,75 @@ async fn retrieve_context(
                         }
                     };

-                    // log
-                    info!(target: "stdout", "query text: {}", query_text);
-
-                    // get the available embedding models
-                    let embedding_model_names = match llama_core::utils::embedding_model_names() {
-                        Ok(model_names) => model_names,
-                        Err(e) => {
-                            let err_msg = e.to_string();
-
-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
-
-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    };
-
-                    // create a embedding request
-                    let embedding_request = EmbeddingRequest {
-                        model: embedding_model_names[0].clone(),
-                        input: query_text.into(),
-                        encoding_format: None,
-                        user: chat_request.user.clone(),
-                    };
-
-                    let rag_embedding_request = RagEmbeddingRequest {
-                        embedding_request,
-                        qdrant_url: server_info.qdrant_config.url.clone(),
-                        qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
-                    };
-
-                    // compute embeddings for query
-                    match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
-                        Ok(embedding_response) => embedding_response,
-                        Err(e) => {
-                            let err_msg = e.to_string();
-
-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
-
-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    }
-                }
-                _ => {
-                    let err_msg = "The last message must be a user message";
-
-                    // log
-                    error!(target: "stdout", "{}", &err_msg);
-
-                    return Err(error::bad_request(err_msg));
-                }
-            }
+                    if !user_content.ends_with("<server-health>") {
+                        last_messages.push(user_content.clone());
+                    } else if idx == 0 {
+                        let content = user_content.trim_end_matches("<server-health>").to_string();
+                        last_messages.push(content);
+                        break;
+                    }
+                }
+
+                if last_messages.len() == context_window {
+                    break;
+                }
+            }
+
+            // join the user messages in the context window into a single string
+            let query_text = if last_messages.len() > 0 {
   [Check failure on line 489, GitHub Actions / build-wasm (ubuntu-22.04, macos-14): length comparison to zero]
+                info!(target: "stdout", "Found the latest {} user messages.", last_messages.len());
+
+                last_messages.reverse();
+                last_messages.join("\n")
+            } else {
+                let warn_msg = "No user messages found.";
+
+                // log
+                warn!(target: "stdout", "{}", &warn_msg);
+
+                return Err(error::bad_request(warn_msg));
+            };
+
+            // log
+            info!(target: "stdout", "query text for the context retrieval: {}", query_text);
+
+            // get the available embedding models
+            let embedding_model_names = match llama_core::utils::embedding_model_names() {
+                Ok(model_names) => model_names,
+                Err(e) => {
+                    let err_msg = e.to_string();
+
+                    // log
+                    error!(target: "stdout", "{}", &err_msg);
+
+                    return Err(error::internal_server_error(err_msg));
+                }
+            };
+
+            // create a embedding request
+            let embedding_request = EmbeddingRequest {
+                model: embedding_model_names[0].clone(),
+                input: InputText::String(query_text),
+                encoding_format: None,
+                user: chat_request.user.clone(),
+            };
+
+            let rag_embedding_request = RagEmbeddingRequest {
+                embedding_request,
+                qdrant_url: server_info.qdrant_config.url.clone(),
+                qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
+            };
+
+            // compute embeddings for query
+            match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
+                Ok(embedding_response) => embedding_response,
+                Err(e) => {
+                    let err_msg = e.to_string();
+
+                    // log
+                    error!(target: "stdout", "{}", &err_msg);
+
+                    return Err(error::internal_server_error(err_msg));
+                }
+            }
         }
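
The loop added above walks the messages newest-first, collects up to context_window user messages (stopping early at a trailing <server-health> sentinel), then reverses and joins them oldest-first. A self-contained sketch of the same windowing over plain strings, with the sentinel handling omitted; it also uses is_empty() rather than the len() > 0 form that clippy flagged in CI:

    fn gather_user_messages(messages: &[&str], window: usize) -> Option<String> {
        let mut last: Vec<String> = Vec::new();
        for msg in messages.iter().rev() {
            last.push((*msg).to_string());
            if last.len() == window {
                break;
            }
        }
        if last.is_empty() {
            return None; // maps to the bad-request branch in the handler
        }
        last.reverse(); // restore chronological order before joining
        Some(last.join("\n"))
    }

    fn main() {
        let msgs = ["how do I install?", "does it run on Windows?", "and on macOS?"];
        assert_eq!(
            gather_user_messages(&msgs, 2).unwrap(),
            "does it run on Windows?\nand on macOS?"
        );
    }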
@@ -549,9 +556,8 @@ async fn retrieve_context(
         }
     };

-    // perform the first round retrieval (using the last user message)
-    let mut retrieve_object = RetrieveObject::default();
-    match llama_core::rag::rag_retrieve_context(
+    // perform the context retrieval
+    let mut retrieve_object: RetrieveObject = match llama_core::rag::rag_retrieve_context(
         query_embedding.as_slice(),
         server_info.qdrant_config.url.to_string().as_str(),
         server_info.qdrant_config.collection_name.as_str(),
@@ -560,173 +566,21 @@ async fn retrieve_context(
     )
     .await
     {
-        Ok(search_result) => {
-            retrieve_object = search_result;
-            if retrieve_object.points.is_none() {
-                retrieve_object.points = Some(Vec::new());
-            }
-
-            info!(target: "stdout", "{} point(s) retrieved from the first retrieval", retrieve_object.points.as_ref().unwrap().len());
-        }
-        Err(e) => {
-            // log
-            error!(target: "stdout", "No point retrieved. {}", e);
-        }
-    };
-
-    // perform the second round retrieval (using the last 3 user messages)
-    if let Some(multi_retrieval) = MULTI_RETRIEVAL.get() {
-        if *multi_retrieval {
-            let embedding_response_second_retrieval = match chat_request.messages.is_empty() {
-                true => {
-                    let err_msg = "Messages should not be empty.";
-
-                    // log
-                    error!(target: "stdout", "{}", &err_msg);
-
-                    return Err(error::bad_request(err_msg));
-                }
-                false => {
-                    // get the last 3 user messages
-                    let mut last_messages = Vec::new();
-                    for message in chat_request.messages.iter().rev() {
-                        if let ChatCompletionRequestMessage::User(user_message) = message {
-                            let user_content = match user_message.content() {
-                                ChatCompletionUserMessageContent::Text(text) => text,
-                                _ => {
-                                    let err_msg =
-                                        "The last message must be a text content user message";
-
-                                    // log
-                                    error!(target: "stdout", "{}", &err_msg);
-
-                                    return Err(error::bad_request(err_msg));
-                                }
-                            };
-
-                            last_messages.push(user_content.clone());
-                        }
-
-                        if last_messages.len() == 3 {
-                            break;
-                        }
-                    }
-
-                    info!(target: "stdout", "Found the latest {} user messages.", last_messages.len());
-
-                    // join the last 3 user messages into a single string
-                    let query_text = last_messages.join("\n");
-
-                    // log
-                    info!(target: "stdout", "query text for the second retrieval: {}", query_text);
-
-                    // get the available embedding models
-                    let embedding_model_names = match llama_core::utils::embedding_model_names() {
-                        Ok(model_names) => model_names,
-                        Err(e) => {
-                            let err_msg = e.to_string();
-
-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
-
-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    };
-
-                    // create a embedding request
-                    let embedding_request = EmbeddingRequest {
-                        model: embedding_model_names[0].clone(),
-                        input: InputText::String(query_text),
-                        encoding_format: None,
-                        user: chat_request.user.clone(),
-                    };
-
-                    let rag_embedding_request = RagEmbeddingRequest {
-                        embedding_request,
-                        qdrant_url: server_info.qdrant_config.url.clone(),
-                        qdrant_collection_name: server_info.qdrant_config.collection_name.clone(),
-                    };
-
-                    // compute embeddings for query
-                    match llama_core::rag::rag_query_to_embeddings(&rag_embedding_request).await {
-                        Ok(embedding_response) => embedding_response,
-                        Err(e) => {
-                            let err_msg = e.to_string();
-
-                            // log
-                            error!(target: "stdout", "{}", &err_msg);
-
-                            return Err(error::internal_server_error(err_msg));
-                        }
-                    }
-                }
-            };
-
-            // compute embeddings for the multiple user messages
-            let query_embedding_second_retrieval: Vec<f32> =
-                match embedding_response_second_retrieval.data.first() {
-                    Some(embedding) => embedding.embedding.iter().map(|x| *x as f32).collect(),
-                    None => {
-                        let err_msg = "No embeddings returned";
-
-                        // log
-                        error!(target: "stdout", "{}", &err_msg);
-
-                        return Err(error::internal_server_error(err_msg));
-                    }
-                };
-
-            // perform the second retrieval
-            match llama_core::rag::rag_retrieve_context(
-                query_embedding_second_retrieval.as_slice(),
-                server_info.qdrant_config.url.to_string().as_str(),
-                server_info.qdrant_config.collection_name.as_str(),
-                server_info.qdrant_config.limit as usize,
-                Some(server_info.qdrant_config.score_threshold),
-            )
-            .await
-            {
-                Ok(retrieve_object_second_retrieval) => {
-                    match retrieve_object_second_retrieval.points {
-                        Some(scored_points) if !scored_points.is_empty() => {
-                            info!(target: "stdout", "{} point(s) retrieved from the second retrieval", scored_points.len());
-
-                            if let Some(points) = retrieve_object.points.as_mut() {
-                                points.extend(scored_points);
-                            }
-                        }
-                        _ => {
-                            warn!(target: "stdout", "No point retrieved from the second retrieval");
-                        }
-                    }
-                }
-                Err(e) => {
-                    // log
-                    error!(target: "stdout", "No point retrieved from the second retrieval. {}", e);
-                }
-            };
-        }
-    }
-
-    info!(target: "stdout", "{} point(s) retrieved from the two rounds of retrieval", retrieve_object.points.as_ref().unwrap().len());
-
-    // sort the points by score
-    retrieve_object.points.as_mut().unwrap().sort_by(|a, b| {
-        // Compare scores in reverse order (highest first)
-        // Using partial_cmp for floating point comparison
-        b.score
-            .partial_cmp(&a.score)
-            .unwrap_or(std::cmp::Ordering::Equal)
-    });
-
-    // remove the duplicate points
-    retrieve_object
-        .points
-        .as_mut()
-        .unwrap()
-        .dedup_by_key(|point| point.source.clone());
-
-    info!(target: "stdout", "{} point(s) kept after removing duplicates", retrieve_object.points.as_ref().unwrap().len());
+        Ok(search_result) => search_result,
+        Err(e) => {
+            let err_msg = format!("No point retrieved. {}", e);
+
+            // log
+            error!(target: "stdout", "{}", &err_msg);
+
+            return Err(error::internal_server_error(err_msg));
+        }
+    };
+
+    if retrieve_object.points.is_none() {
+        retrieve_object.points = Some(Vec::new());
+    }
+
+    info!(target: "stdout", "{} point(s) retrieved", retrieve_object.points.as_ref().unwrap().len());

     Ok(retrieve_object)
 }
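
This hunk also drops the post-processing that previously ran after retrieval: sorting points by descending score and removing consecutive points from the same source. If a caller wants to reapply it, a standalone sketch (Point again stands in for the crate's scored-point type):

    use std::cmp::Ordering;

    #[derive(Debug, Clone)]
    struct Point {
        score: f32,
        source: String,
    }

    fn rank_and_dedup(points: &mut Vec<Point>) {
        // highest score first; partial_cmp because f32 is not Ord
        points.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        // dedup_by_key drops only consecutive duplicates, the same caveat
        // the removed code had: sources that interleave after sorting survive
        points.dedup_by_key(|p| p.source.clone());
    }

    fn main() {
        let mut pts = vec![
            Point { score: 0.4, source: "a.md".into() },
            Point { score: 0.9, source: "a.md".into() },
            Point { score: 0.7, source: "b.md".into() },
        ];
        rank_and_dedup(&mut pts);
        // sorted: a.md(0.9), b.md(0.7), a.md(0.4); no adjacent pair shares a source
        assert_eq!(pts.len(), 3);
    }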