This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

operations: nlp: tools: dffml docs: Use model_name from global variable, use gpt-4-0125-preview for best results

Signed-off-by: John Andersen <[email protected]>
pdxjohnny committed Apr 10, 2024
1 parent d2d6cb1 commit 26a859e
Showing 1 changed file with 20 additions and 12 deletions.
operations/nlp/dffml_operations_nlp/tools/dffml_docs.py (32 changes: 20 additions & 12 deletions)
@@ -122,7 +122,10 @@ def load_docs_dffml():
 )
 
 # docker run --name postgres-embeddings-dffml-docs -d --restart=always -e POSTGRES_DB=docs_ai_alice_dffml -e POSTGRES_PASSWORD=password -e POSTGRES_USER=user -v $HOME/embeddings/openai/var-lib-postgresq-data:/var/lib/postgresql/data:z -p 127.0.0.1:5432:5432 pgvector/pgvector:pg16
-POSTGRESQL_CONNECTION_STRING = "postgresql+psycopg2://user:password@localhost:5432/docs_ai_alice_dffml"
+POSTGRESQL_CONNECTION_STRING = (
+    "postgresql+psycopg2://user:password@localhost:5432/docs_ai_alice_dffml"
+)
+
 
 # cachier does not work with PGVector @cachier(pickle_reload=False)
 def load_retriever():
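The docker one-liner above brings up a pgvector-enabled Postgres that this connection string targets. For context, a minimal sketch of building a vectorstore against it with langchain_community's PGVector; the collection name and embedding model here are illustrative assumptions, not values taken from this file:

from langchain_community.vectorstores.pgvector import PGVector
from langchain_openai import OpenAIEmbeddings

# Hypothetical collection name; load_retriever() holds the real wiring
vectorstore = PGVector(
    connection_string=POSTGRESQL_CONNECTION_STRING,
    collection_name="docs_ai_alice_dffml",
    embedding_function=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()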
@@ -153,6 +156,7 @@ def load_retriever():
     # retriever.add_documents(docs)
     return vectorstore
 
+
 # TODO https://python.langchain.com/docs/integrations/retrievers/merger_retriever/
 retriever = load_retriever()
 
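The TODO above refers to LangChain's MergerRetriever, which fans one query out to several retrievers and concatenates their results. A minimal sketch under the assumption that two views over the store are worth merging (the two search configurations are illustrative; this module currently exposes only one retriever):

from langchain.retrievers import MergerRetriever

# Two hypothetical views over the same PGVector store
store = load_retriever()
lotr = MergerRetriever(
    retrievers=[
        store.as_retriever(search_type="similarity", search_kwargs={"k": 4}),
        store.as_retriever(search_type="mmr", search_kwargs={"k": 4}),
    ]
)
merged = lotr.get_relevant_documents("How do DFFML operations work?")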
@@ -168,7 +172,9 @@ def load_retriever():
 # base_retriever defined somewhere above...
 
 compressor = LLMChainExtractor.from_llm(OpenAI(temperature=0))
-compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
+compression_retriever = ContextualCompressionRetriever(
+    base_compressor=compressor, base_retriever=retriever
+)
 
 from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain_openai import ChatOpenAI
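ContextualCompressionRetriever wraps the base retriever so every hit is trimmed by the LLMChainExtractor to just the passages relevant to the query before it enters the graph. A usage sketch with an illustrative query:

# Illustrative query; each returned Document is LLM-compressed
compressed_docs = compression_retriever.get_relevant_documents(
    "How do I register a new DFFML operation?"
)
for doc in compressed_docs:
    print(doc.page_content[:200])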
@@ -268,6 +274,11 @@ def cached_hub_pull(*args, **kwargs):
     return hub.pull(*args, **kwargs)
 
 
+OPENAI_MODEL_NAME_GPT_3_5 = "gpt-3.5-turbo"
+OPENAI_MODEL_NAME_GPT_4 = "gpt-4-0125-preview"
+OPENAI_MODEL_NAME = OPENAI_MODEL_NAME_GPT_4
+
+
 ### Edges
 
 
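These three constants are the heart of the commit: every ChatOpenAI call below now reads OPENAI_MODEL_NAME instead of a hard-coded string, which also removes the old mismatch where the graders ran gpt-4-0125-preview while generate() ran gpt-3.5-turbo. Switching the whole pipeline back to the cheaper model becomes a one-line change:

# One-line switch for every model call in the module
OPENAI_MODEL_NAME = OPENAI_MODEL_NAME_GPT_3_5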
@@ -319,9 +330,7 @@ class grade(BaseModel):
         binary_score: str = Field(description="Relevance score 'yes' or 'no'")
 
     # LLM
-    model = ChatOpenAI(
-        temperature=0, model="gpt-4-0125-preview", streaming=True
-    )
+    model = ChatOpenAI(temperature=0, model=OPENAI_MODEL_NAME, streaming=True)
 
     # Tool
     grade_tool_oai = convert_to_openai_tool(grade)
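convert_to_openai_tool turns the grade schema into an OpenAI tool definition; binding it with a forced tool_choice makes the model answer only through the yes/no schema. A sketch of the bind-and-parse step that typically follows in this LangGraph pattern, assuming prompt, question and docs come from the surrounding grade_documents function in the real file:

from langchain.output_parsers.openai_tools import PydanticToolsParser

# Force the model to call the grade tool, then parse the structured score
llm_with_tool = model.bind(
    tools=[grade_tool_oai],
    tool_choice={"type": "function", "function": {"name": "grade"}},
)
chain = prompt | llm_with_tool | PydanticToolsParser(tools=[grade])
score = chain.invoke({"question": question, "context": docs})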
@@ -384,9 +393,7 @@ def agent(state):
     """
     print("---CALL AGENT---")
     messages = state["messages"]
-    model = ChatOpenAI(
-        temperature=0, streaming=True, model="gpt-4-0125-preview"
-    )
+    model = ChatOpenAI(temperature=0, model=OPENAI_MODEL_NAME, streaming=True)
     functions = [convert_to_openai_function(t) for t in tools]
     model = model.bind_functions(functions)
     response = model.invoke(messages)
@@ -453,9 +460,7 @@ def rewrite(state):
     ]
 
     # Grader
-    model = ChatOpenAI(
-        temperature=0, model="gpt-4-0125-preview", streaming=True
-    )
+    model = ChatOpenAI(temperature=0, model=OPENAI_MODEL_NAME, streaming=True)
     response = model.invoke(msg)
     return {"messages": [response]}
 
@@ -482,7 +487,9 @@ def generate(state):
     prompt = cached_hub_pull("rlm/rag-prompt")
 
     # LLM
-    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)
+    llm = ChatOpenAI(
+        model_name=OPENAI_MODEL_NAME, temperature=0, streaming=True
+    )
 
     # Post-processing
     def format_docs(docs):
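generate() is the standard RAG tail: the hub's rlm/rag-prompt template is filled with the formatted documents and the question, run through the LLM, and reduced to a plain string. A sketch of the composition that typically follows, with docs and question assumed to come from the graph state:

from langchain_core.output_parsers import StrOutputParser

# Prompt -> LLM -> plain string; format_docs flattens the retrieved docs
rag_chain = prompt | llm | StrOutputParser()
answer = rag_chain.invoke({"context": format_docs(docs), "question": question})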
@@ -584,4 +591,5 @@ def format_docs(docs):
     pprint.pprint("\n---\n")
 
 import pathlib
+
 pathlib.Path("~/chat-log.txt").expanduser().write_text("\n\n".join(chat_log))
