Skip to content

Commit

Permalink
refactor and separate id fields
Browse files Browse the repository at this point in the history
  • Loading branch information
evan-danswer committed Jan 7, 2025
1 parent 3a38407 commit 7898f38
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_DEFAULT
from onyx.agent_search.shared_graph_utils.prompts import ASSISTANT_SYSTEM_PROMPT_PERSONA
from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.chat.models import SubAnswer
from onyx.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import SubAnswerPiece
from onyx.utils.logger import setup_logger

logger = setup_logger()
Expand All @@ -20,6 +21,7 @@
def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
question = state["question"]
docs = state["documents"]
level, question_nr = parse_question_id(state["question_id"])
persona_prompt = get_persona_prompt(state["subgraph_config"].search_request.persona)

if len(persona_prompt) > 0:
Expand Down Expand Up @@ -51,9 +53,10 @@ def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
)
dispatch_custom_event(
"sub_answers",
SubAnswer(
SubAnswerPiece(
sub_answer=content,
sub_question_id=state["question_id"],
level=level,
level_question_nr=question_nr,
),
)
response.append(content)
Expand Down
27 changes: 20 additions & 7 deletions backend/onyx/agent_search/expanded_retrieval/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL
from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from onyx.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import SubQuery
from onyx.chat.models import SubQueryPiece
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
Expand All @@ -49,11 +49,16 @@
logger = setup_logger()


def dispatch_subquery(subquestion_id: str) -> Callable[[str, int], None]:
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
dispatch_custom_event(
"subqueries",
SubQuery(sub_query=token, sub_question_id=subquestion_id, query_id=num),
SubQueryPiece(
sub_query=token,
level=level,
level_question_nr=question_nr,
query_id=num,
),
)

return helper
Expand All @@ -69,7 +74,9 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate:
chat_session_id = state["subgraph_config"].chat_session_id
sub_question_id = state.get("sub_question_id")
if sub_question_id is None:
sub_question_id = make_question_id(0, 0) # 0_0 for original question
level, question_nr = 0, 0
else:
level, question_nr = parse_question_id(sub_question_id)

if chat_session_id is None:
raise ValueError("chat_session_id must be provided for agent search")
Expand All @@ -81,7 +88,7 @@ def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate:
]

llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(sub_question_id)
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
)

llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
Expand Down Expand Up @@ -125,12 +132,18 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate:
retrieved_docs = cast(
list[InferenceSection], tool_response.response.top_sections
)
level, question_nr = (
parse_question_id(state["sub_question_id"])
if state["sub_question_id"]
else (0, 0)
)
dispatch_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
sub_question_id=state["sub_question_id"] or make_question_id(0, 0),
response=tool_response.response,
level=level,
level_question_nr=question_nr,
),
)

Expand Down
9 changes: 4 additions & 5 deletions backend/onyx/agent_search/main/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from onyx.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agent_search.shared_graph_utils.utils import make_question_id
from onyx.chat.models import SubQuestion
from onyx.chat.models import SubQuestionPiece
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
from onyx.utils.logger import setup_logger
Expand All @@ -74,11 +74,10 @@ def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
def helper(sub_question_part: str, num: int) -> None:
dispatch_custom_event(
"decomp_qs",
SubQuestion(
SubQuestionPiece(
sub_question=sub_question_part,
question_id=make_question_id(
level, num + 1
), # question 0 reserved for original question if used
level=level,
level_question_nr=num + 1,
),
)

Expand Down
12 changes: 6 additions & 6 deletions backend/onyx/agent_search/run_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from onyx.chat.models import AnswerStream
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ProSearchConfig
from onyx.chat.models import SubAnswer
from onyx.chat.models import SubQuery
from onyx.chat.models import SubQuestion
from onyx.chat.models import SubAnswerPiece
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
Expand All @@ -43,11 +43,11 @@ def _parse_agent_event(
if event_type == "on_custom_event":
# TODO: different AnswerStream types for different events
if event["name"] == "decomp_qs":
return cast(SubQuestion, event["data"])
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQuery, event["data"])
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(SubAnswer, event["data"])
return cast(SubAnswerPiece, event["data"])
elif event["name"] == "main_answer":
return OnyxAnswerPiece(answer_piece=cast(str, event["data"]))
elif event["name"] == "tool_response":
Expand Down
30 changes: 14 additions & 16 deletions backend/onyx/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,35 +347,33 @@ def from_model(
)


class SubQuery(BaseModel):
class SubQueryPiece(BaseModel):
sub_query: str
sub_question_id: str # <level>_<question_nr>
level: int
level_question_nr: int
query_id: int

@model_validator(mode="after")
def check_sub_question_id(self) -> "SubQuery":
if len(self.sub_question_id.split("_")) != 2:
raise ValueError(
"sub_question_id must be in the format <level>_<question_nr>"
)
return self


class SubAnswer(BaseModel):
class SubAnswerPiece(BaseModel):
sub_answer: str
sub_question_id: str # <level>_<question_nr>
level: int
level_question_nr: int


class SubQuestion(BaseModel):
question_id: str # <level>_<question_nr>
class SubQuestionPiece(BaseModel):
sub_question: str
level: int
level_question_nr: int


class ExtendedToolResponse(ToolResponse):
sub_question_id: str # <level>_<question_nr>
level: int
level_question_nr: int


ProSearchPacket = SubQuestion | SubAnswer | SubQuery | ExtendedToolResponse
ProSearchPacket = (
SubQuestionPiece | SubAnswerPiece | SubQueryPiece | ExtendedToolResponse
)

AnswerPacket = (
AnswerQuestionPossibleReturn | ProSearchPacket | ToolCallKickoff | ToolResponse
Expand Down

0 comments on commit 7898f38

Please sign in to comment.