Skip to content

Commit

Permalink
chore(webserver): add CodeQueryParams override to code query input (#…
Browse files Browse the repository at this point in the history
…2858)

* chore(webserver): add CodeQueryParams override to code query input

* update

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Aug 13, 2024
1 parent 104bd03 commit 5d818ee
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 11 deletions.
17 changes: 15 additions & 2 deletions ee/tabby-schema/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ input CodeQueryInput {
content: String!
}

input CodeSearchParamsOverrideInput {
minEmbeddingScore: Float
minBm25Score: Float
minRrfScore: Float
numToReturn: Int
numToScore: Int
}

input CreateIntegrationInput {
displayName: String!
accessToken: String!
Expand All @@ -104,7 +112,7 @@ input CreateMessageInput {

input CreateThreadAndRunInput {
thread: CreateThreadInput!
options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false}
options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false}
}

input CreateThreadInput {
Expand All @@ -114,7 +122,7 @@ input CreateThreadInput {
input CreateThreadRunInput {
threadId: ID!
additionalUserMessage: CreateMessageInput!
options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false}
options: ThreadRunOptionsInput! = {codeQuery: null, debugOptions: null, docQuery: null, generateRelevantQuestions: false}
}

input CreateWebCrawlerUrlInput {
Expand Down Expand Up @@ -173,10 +181,15 @@ input SecuritySettingInput {
disableClientSideTelemetry: Boolean!
}

input ThreadRunDebugOptionsInput {
codeSearchParamsOverride: CodeSearchParamsOverrideInput = null
}

input ThreadRunOptionsInput {
docQuery: DocQueryInput = null
codeQuery: CodeQueryInput = null
generateRelevantQuestions: Boolean! = false
debugOptions: ThreadRunDebugOptionsInput = null
}

input UpdateIntegrationInput {
Expand Down
39 changes: 39 additions & 0 deletions ee/tabby-schema/src/schema/thread/inputs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use juniper::{GraphQLInputObject, ID};
use tabby_common::api::code::CodeSearchParams;
use validator::Validate;

#[derive(GraphQLInputObject)]
Expand Down Expand Up @@ -34,6 +35,26 @@ pub struct CodeQueryInput {
pub content: String,
}

impl CodeSearchParamsOverrideInput {
pub fn override_params(&self, params: &mut CodeSearchParams) {
if let Some(min_embedding_score) = self.min_embedding_score {
params.min_embedding_score = min_embedding_score as f32;
}
if let Some(min_bm25_score) = self.min_bm25_score {
params.min_bm25_score = min_bm25_score as f32;
}
if let Some(min_rrf_score) = self.min_rrf_score {
params.min_rrf_score = min_rrf_score as f32;
}
if let Some(num_to_return) = self.num_to_return {
params.num_to_return = num_to_return as usize;
}
if let Some(num_to_score) = self.num_to_score {
params.num_to_score = num_to_score as usize;
}
}
}

#[derive(GraphQLInputObject, Validate, Default, Clone)]
pub struct ThreadRunOptionsInput {
#[validate(nested)]
Expand All @@ -46,6 +67,24 @@ pub struct ThreadRunOptionsInput {

#[graphql(default)]
pub generate_relevant_questions: bool,

#[graphql(default)]
pub debug_options: Option<ThreadRunDebugOptionsInput>,
}

#[derive(GraphQLInputObject, Clone)]
pub struct CodeSearchParamsOverrideInput {
pub min_embedding_score: Option<f64>,
pub min_bm25_score: Option<f64>,
pub min_rrf_score: Option<f64>,
pub num_to_return: Option<i32>,
pub num_to_score: Option<i32>,
}

#[derive(GraphQLInputObject, Clone)]
pub struct ThreadRunDebugOptionsInput {
#[graphql(default)]
pub code_search_params_override: Option<CodeSearchParamsOverrideInput>,
}

#[derive(GraphQLInputObject, Validate)]
Expand Down
24 changes: 15 additions & 9 deletions ee/tabby-webserver/src/service/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use tabby_inference::ChatCompletionStream;
use tabby_schema::{
repository::RepositoryService,
thread::{
self, CodeQueryInput, DocQueryInput, MessageAttachment, MessageAttachmentCode,
MessageAttachmentDoc, ThreadRunItem, ThreadRunOptionsInput,
self, CodeQueryInput, CodeSearchParamsOverrideInput, DocQueryInput, MessageAttachment,
MessageAttachmentCode, MessageAttachmentDoc, ThreadRunItem, ThreadRunOptionsInput,
},
web_crawler::WebCrawlerService,
};
Expand Down Expand Up @@ -108,7 +108,7 @@ impl AnswerService {
language: code_query.language,
content: code_query.content,
};
self.collect_relevant_code(&code_query, &self.config.code_search_params).await
self.collect_relevant_code(&code_query, &self.config.code_search_params, None).await
} else {
vec![]
};
Expand Down Expand Up @@ -208,7 +208,7 @@ impl AnswerService {

// 1. Collect relevant code if needed.
if let Some(code_query) = options.code_query.as_ref() {
attachment.code = self.collect_relevant_code(code_query, &self.config.code_search_params).await.iter()
attachment.code = self.collect_relevant_code(code_query, &self.config.code_search_params, options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref())).await.iter()
.map(|x| MessageAttachmentCode{
filepath: Some(x.doc.filepath.clone()),
content: x.doc.body.clone(),
Expand Down Expand Up @@ -292,16 +292,22 @@ impl AnswerService {

async fn collect_relevant_code(
&self,
query: &CodeQueryInput,
input: &CodeQueryInput,
params: &CodeSearchParams,
override_params: Option<&CodeSearchParamsOverrideInput>,
) -> Vec<CodeSearchHit> {
let query = CodeSearchQuery {
git_url: query.git_url.clone(),
filepath: query.filepath.clone(),
language: query.language.clone(),
content: query.content.clone(),
git_url: input.git_url.clone(),
filepath: input.filepath.clone(),
language: input.language.clone(),
content: input.content.clone(),
};

let mut params = params.clone();
override_params
.as_ref()
.inspect(|x| x.override_params(&mut params));

match self.code.search_in_language(query, params.clone()).await {
Ok(docs) => docs.hits,
Err(err) => {
Expand Down

0 comments on commit 5d818ee

Please sign in to comment.