Skip to content

Commit

Permalink
Add configurable K (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Aug 9, 2024
1 parent 19185f1 commit f4b272e
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions backend/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from collections import defaultdict
from typing import Annotated, Literal, Sequence, TypedDict
from typing import Annotated, Literal, Optional, Sequence, TypedDict

import weaviate
from langchain_anthropic import ChatAnthropic
Expand All @@ -9,6 +9,7 @@
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
convert_to_messages,
Expand All @@ -19,7 +20,7 @@
PromptTemplate,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ConfigurableField, RunnableConfig
from langchain_core.runnables import ConfigurableField, RunnableConfig, ensure_config
from langchain_fireworks import ChatFireworks
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
Expand Down Expand Up @@ -131,7 +132,7 @@ def update_documents(
class AgentState(TypedDict):
query: str
documents: Annotated[list[Document], update_documents]
messages: Annotated[list[BaseMessage], add_messages]
messages: Annotated[list[AnyMessage], add_messages]
# for convenience in evaluations
answer: str
feedback_urls: dict[str, list[str]]
Expand Down Expand Up @@ -202,7 +203,7 @@ class AgentState(TypedDict):
)


def get_retriever() -> BaseRetriever:
def get_retriever(k: Optional[int] = None) -> BaseRetriever:
weaviate_client = weaviate.connect_to_wcs(
cluster_url=os.environ["WEAVIATE_URL"],
auth_credentials=weaviate.classes.init.Auth.api_key(
Expand All @@ -217,7 +218,8 @@ def get_retriever() -> BaseRetriever:
embedding=get_embeddings_model(),
attributes=["source", "title"],
)
return weaviate_client.as_retriever(search_kwargs=dict(k=6))
k = k or 6
return weaviate_client.as_retriever(search_kwargs=dict(k=k))


def format_docs(docs: Sequence[Document]) -> str:
Expand All @@ -228,8 +230,11 @@ def format_docs(docs: Sequence[Document]) -> str:
return "\n".join(formatted_docs)


def retrieve_documents(state: AgentState) -> AgentState:
retriever = get_retriever()
def retrieve_documents(
state: AgentState, *, config: Optional[RunnableConfig] = None
) -> AgentState:
config = ensure_config(config)
retriever = get_retriever(k=config["configurable"].get("k"))
messages = convert_to_messages(state["messages"])
query = messages[-1].content
relevant_documents = retriever.invoke(query)
Expand Down Expand Up @@ -350,7 +355,16 @@ def route_to_response_synthesizer(
return "response_synthesizer"


workflow = StateGraph(AgentState)
class Configuration(TypedDict):
model_name: str
k: int


class InputSchema(TypedDict):
messages: list[AnyMessage]


workflow = StateGraph(AgentState, Configuration, input=InputSchema)

# define nodes
workflow.add_node("retriever", retrieve_documents)
Expand Down

0 comments on commit f4b272e

Please sign in to comment.