diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py
index 538e2792428..062c588b067 100644
--- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py
+++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py
@@ -11,7 +11,8 @@
 from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT
 from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA
 from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt
-from onyx.chat.models import SubAnswer
+from onyx.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.chat.models import SubAnswerPiece
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -20,6 +21,7 @@
 def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
     question = state["question"]
     docs = state["documents"]
+    level, question_nr = parse_question_id(state["question_id"])
     persona_prompt = get_persona_prompt(state["subgraph_config"].search_request.persona)
 
     if len(persona_prompt) > 0:
@@ -51,9 +53,10 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
         )
         dispatch_custom_event(
             "sub_answers",
-            SubAnswer(
+            SubAnswerPiece(
                 sub_answer=content,
-                sub_question_id=state["question_id"],
+                level=level,
+                level_question_nr=question_nr,
             ),
         )
         response.append(content)
diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes.py b/backend/onyx/agent_search/expanded_retrieval/nodes.py
index 7ff034fd593..f1eb798d4cc 100644
--- a/backend/onyx/agent_search/expanded_retrieval/nodes.py
+++ b/backend/onyx/agent_search/expanded_retrieval/nodes.py
@@ -29,9 +29,9 @@
 from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL
 from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
 from onyx.agent_search.shared_graph_utils.utils import dispatch_separated
-from onyx.agent_search.shared_graph_utils.utils import make_question_id
+from onyx.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.chat.models import ExtendedToolResponse
-from onyx.chat.models import SubQuery
+from onyx.chat.models import SubQueryPiece
 from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
 from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
 from onyx.configs.dev_configs import AGENT_RERANKING_STATS
@@ -49,11 +49,16 @@
 logger = setup_logger()
 
 
-def dispatch_subquery(subquestion_id: str) -> Callable[[str, int], None]:
+def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
     def helper(token: str, num: int) -> None:
         dispatch_custom_event(
             "subqueries",
-            SubQuery(sub_query=token, sub_question_id=subquestion_id, query_id=num),
+            SubQueryPiece(
+                sub_query=token,
+                level=level,
+                level_question_nr=question_nr,
+                query_id=num,
+            ),
         )
 
     return helper
@@ -69,7 +74,9 @@
 def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate:
     chat_session_id = state["subgraph_config"].chat_session_id
     sub_question_id = state.get("sub_question_id")
     if sub_question_id is None:
-        sub_question_id = make_question_id(0, 0)  # 0_0 for original question
+        level, question_nr = 0, 0
+    else:
+        level, question_nr = parse_question_id(sub_question_id)
     if chat_session_id is None:
         raise ValueError("chat_session_id must be provided for agent search")
@@ -81,7 +88,7 @@
     ]
 
     llm_response_list = dispatch_separated(
-        llm.stream(prompt=msg), dispatch_subquery(sub_question_id)
+        llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
     )
 
     llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
@@ -125,12 +132,18 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate:
     retrieved_docs = cast(
         list[InferenceSection], tool_response.response.top_sections
     )
+    level, question_nr = (
+        parse_question_id(state["sub_question_id"])
+        if state["sub_question_id"]
+        else (0, 0)
+    )
     dispatch_custom_event(
         "tool_response",
         ExtendedToolResponse(
             id=tool_response.id,
-            sub_question_id=state["sub_question_id"] or make_question_id(0, 0),
             response=tool_response.response,
+            level=level,
+            level_question_nr=question_nr,
         ),
     )
 
diff --git a/backend/onyx/agent_search/main/nodes.py b/backend/onyx/agent_search/main/nodes.py
index e45195f9a72..822f7591745 100644
--- a/backend/onyx/agent_search/main/nodes.py
+++ b/backend/onyx/agent_search/main/nodes.py
@@ -62,7 +62,7 @@
 from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction
 from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt
 from onyx.agent_search.shared_graph_utils.utils import make_question_id
-from onyx.chat.models import SubQuestion
+from onyx.chat.models import SubQuestionPiece
 from onyx.db.chat import log_agent_metrics
 from onyx.db.chat import log_agent_sub_question_results
 from onyx.utils.logger import setup_logger
@@ -74,11 +74,10 @@
 def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
     def helper(sub_question_part: str, num: int) -> None:
         dispatch_custom_event(
             "decomp_qs",
-            SubQuestion(
+            SubQuestionPiece(
                 sub_question=sub_question_part,
-                question_id=make_question_id(
-                    level, num + 1
-                ),  # question 0 reserved for original question if used
+                level=level,
+                level_question_nr=num + 1,
             ),
         )
diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py
index 122ca955897..c95417a1f47 100644
--- a/backend/onyx/agent_search/run_graph.py
+++ b/backend/onyx/agent_search/run_graph.py
@@ -14,9 +14,9 @@
 from onyx.chat.models import AnswerStream
 from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import ProSearchConfig
-from onyx.chat.models import SubAnswer
-from onyx.chat.models import SubQuery
-from onyx.chat.models import SubQuestion
+from onyx.chat.models import SubAnswerPiece
+from onyx.chat.models import SubQueryPiece
+from onyx.chat.models import SubQuestionPiece
 from onyx.chat.models import ToolResponse
 from onyx.context.search.models import SearchRequest
 from onyx.db.engine import get_session_context_manager
@@ -43,11 +43,11 @@ def _parse_agent_event(
     if event_type == "on_custom_event":
         # TODO: different AnswerStream types for different events
         if event["name"] == "decomp_qs":
-            return cast(SubQuestion, event["data"])
+            return cast(SubQuestionPiece, event["data"])
         elif event["name"] == "subqueries":
-            return cast(SubQuery, event["data"])
+            return cast(SubQueryPiece, event["data"])
         elif event["name"] == "sub_answers":
-            return cast(SubAnswer, event["data"])
+            return cast(SubAnswerPiece, event["data"])
         elif event["name"] == "main_answer":
             return OnyxAnswerPiece(answer_piece=cast(str, event["data"]))
         elif event["name"] == "tool_response":
diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py
index 233d9982110..e28231ce975 100644
--- a/backend/onyx/chat/models.py
+++ b/backend/onyx/chat/models.py
@@ -347,35 +347,33 @@ def from_model(
     )
 
 
-class SubQuery(BaseModel):
+class SubQueryPiece(BaseModel):
     sub_query: str
-    sub_question_id: str  # _
+    level: int
+    level_question_nr: int
     query_id: int
 
-    @model_validator(mode="after")
-    def check_sub_question_id(self) -> "SubQuery":
-        if len(self.sub_question_id.split("_")) != 2:
-            raise ValueError(
-                "sub_question_id must be in the format _"
-            )
-        return self
-
 
-class SubAnswer(BaseModel):
+class SubAnswerPiece(BaseModel):
     sub_answer: str
-    sub_question_id: str  # _
+    level: int
+    level_question_nr: int
 
 
-class SubQuestion(BaseModel):
-    question_id: str  # _
+class SubQuestionPiece(BaseModel):
     sub_question: str
+    level: int
+    level_question_nr: int
 
 
 class ExtendedToolResponse(ToolResponse):
-    sub_question_id: str  # _
+    level: int
+    level_question_nr: int
 
 
-ProSearchPacket = SubQuestion | SubAnswer | SubQuery | ExtendedToolResponse
+ProSearchPacket = (
+    SubQuestionPiece | SubAnswerPiece | SubQueryPiece | ExtendedToolResponse
+)
 
 AnswerPacket = (
     AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse