
Commit

Somewhat fixed issue with bot not using its memory in request sequences.
janthonysantana committed Jun 9, 2024
1 parent 8ba7a53 commit a4f2c1b
Showing 2 changed files with 246 additions and 38 deletions.
19 changes: 11 additions & 8 deletions server/routes/ai.py
@@ -38,7 +38,7 @@ def get_mental_health_agent_welcome(user_id):
timestamp = datetime.now().isoformat()

system_message = """
You are a therapy companion.
Your name is Aria; you are a therapy companion.
You are a patient, empathetic virtual therapist. Your purpose is not to replace human therapists, but to lend aid when human therapists are not available.
@@ -64,6 +64,7 @@ def run_mental_health_agent(user_id, chat_id):
body = request.get_json()

prompt = body.get("prompt")
turn_id = body.get("turn_id")

system_message = """
You are a therapy companion.
@@ -73,17 +74,19 @@ def run_mental_health_agent(user_id, chat_id):
Your job is to gently guide the user, your patient, through their mental healing journey.
You will speak in a natural, concise, and casual tone. Do not be verbose. Your role is not to ramble about psychology theory, but to support and listen to your patient.
If you do not know the answer to a question, do not give an `I am a virtual assistant` disclaimer; instead, honestly state that you don't know the answer.
Last Conversation Log:
{history}
Last Conversation Summary:
{summary}
If you do not know the answer to a question, honestly state that you don't know the answer. Do not make up an answer.
"""

timestamp = datetime.now().isoformat()

response_message = get_langchain_agent_response(f"mental-health-{ENV}", "chatbot_logs", system_message, prompt, user_id, timestamp)
response_message = get_langchain_agent_response(f"mental-health-{ENV}",
"chatbot_logs",
system_message,
prompt,
user_id,
int(chat_id),
turn_id + 1,
timestamp)

return {"message": response_message}
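
For orientation, here is a minimal sketch of how a client might call this route after the change. The prompt and turn_id fields mirror the body.get() calls above; the host, port, and URL path are assumptions, not taken from this diff.

import requests

# Hypothetical endpoint; adjust host, port, and path to match how
# server/routes/ai.py is actually mounted in the app.
url = "http://localhost:5000/ai/mental_health/user-123/0"
payload = {
    "prompt": "I've been feeling anxious about work lately.",
    "turn_id": 0,  # the route forwards turn_id + 1 to the agent
}
resp = requests.post(url, json=payload)
print(resp.json()["message"])  # the agent's reply, per the return statement above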

265 changes: 235 additions & 30 deletions server/tools/langchain.py
@@ -6,7 +6,7 @@
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain.prompts import PromptTemplate
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate
from langchain_core.messages import SystemMessage
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import CharacterTextSplitter
@@ -18,13 +18,16 @@
from langchain.agents import Tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
# from langchain.memory import MemoryConversationSummaryMemory, CombinedMemory
from langchain_mongodb import MongoDBChatMessageHistory
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.utils import ConfigurableFieldSpec
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.ai import AIMessage

from tools.azure_mongodb import MongoDBClient

sessions = {}

def setup_langchain(db_name, collection_name, system_prompt):
"""
Sets up and returns the components necessary for LangChain interaction with Azure OpenAI.
Expand Down Expand Up @@ -133,7 +136,7 @@ def get_azure_openai_llm():
AOAI_ENDPOINT, AOAI_KEY, AOAI_API_VERSION, _, AOAI_COMPLETIONS = get_azure_openai_variables()

llm = AzureChatOpenAI(
temperature = 0.0,
temperature = 0.3,
openai_api_version = AOAI_API_VERSION,
azure_endpoint = AOAI_ENDPOINT,
openai_api_key = AOAI_KEY,
@@ -208,7 +211,37 @@ def get_cosmosdb_tool(db_name, collection_name):
return cosmosdb_tool


def get_mongodb_agent_with_history(llm, tools, db_name, collection_name, system_message, user_id, timestamp):
def find_session(user_id, chat_id):
if sessions.get((user_id, chat_id)) is None:
chat_history = ChatMessageHistory()
sessions[(user_id, chat_id)] = chat_history
return sessions[(user_id, chat_id)]


def get_user_history(user_id, chat_id):
"""
Used to find personal details from previous conversations with the user.
"""
db_client = MongoDBClient.get_client()
db = db_client[MongoDBClient.get_db()]
turns = list(db.chat_turns.find({"user_id": user_id, "chat_id": chat_id}).sort({"timestamp": -1}))
turns.reverse()
history = []

for turn in turns:
if turn.get("human_message"):
history.append(HumanMessage(turn.get("human_message")))
if turn.get("ai_message"):
history.append(AIMessage(turn.get("ai_message")))

return history


chat_history = ChatMessageHistory()
chat_history.add_messages(history)
sessions[(user_id, chat_id)] = chat_history

def get_mongodb_agent_with_history(llm, tools, db_name, collection_name, message, system_message, user_id, chat_id, timestamp, history=[]):
CONNECTION_STRING = MongoDBClient.get_mongodb_variables()

agent_executor = create_conversational_retrieval_agent(
@@ -218,31 +251,52 @@ def get_mongodb_agent_with_history(llm, tools, db_name, collection_name, system_
verbose=True
)

message_history = MongoDBChatMessageHistory(
connection_string=CONNECTION_STRING,
session_id=f"{user_id};{timestamp}",
database_name=db_name,
collection_name=collection_name
)
# message_history = MongoDBChatMessageHistory(
# connection_string=CONNECTION_STRING,
# session_id=f"{user_id};{timestamp}",
# database_name=db_name,
# collection_name=collection_name
# )

# chat_history = ChatMessageHistory(session_id=)
# chat_history.add_messages(history)

agent_with_history = RunnableWithMessageHistory(
agent_executor,
lambda session_id: message_history,
find_session,
input_messages_key="input",
history_messages_key="chat_history",
history_factory_config=[
ConfigurableFieldSpec(
id="user_id",
annotation=str,
name="User ID",
description="Unique identifier for the user.",
default="",
is_shared=True,
),
ConfigurableFieldSpec(
id="chat_id",
annotation=str,
name="Chat ID",
description="Unique identifier for the chat session.",
default="",
is_shared=True,
),
]
)

return agent_with_history
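
As a side note (not part of this commit), this is roughly how the two ConfigurableFieldSpec entries above come into play at call time; the literal user and chat IDs are made up for illustration.

# Assuming `agent` is the object returned by get_mongodb_agent_with_history:
reply = agent.invoke(
    {"input": "Remind me what we talked about last time."},
    config={"configurable": {"user_id": "user-123", "chat_id": "42"}},
)
# RunnableWithMessageHistory passes the two configurable values to find_session,
# which returns (or lazily creates) the ChatMessageHistory cached in the
# module-level `sessions` dict keyed by (user_id, chat_id).
print(reply["output"])  # the agent's text response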


def get_langchain_initial_state(db_name, collection_name, user_id, system_message, timestamp):
def get_langchain_initial_state(db_name, collection_name, user_id, system_message, timestamp, chat_id=None, history=[]):

db_client = MongoDBClient.get_client()
db = db_client[MongoDBClient.get_db()]
user_journey_collection = db["user_journeys"]
user_journey = user_journey_collection.find_one({"user_id": user_id})
# Has user engaged with chatbot before?
if user_journey is None:
# user has not engaged with the chatbot before
user_journey_collection.insert_one({
"user_id": user_id,
"patient_goals": [],
@@ -256,24 +310,66 @@ def get_langchain_initial_state(db_name, collection_name, user_id, system_messag
In this session, do your best to understand what the user hopes to achieve through your service, and derive a therapy style fitting to
their needs.
"""
mod_system_message = ''.join([system_message, addendum])
response = get_langchain_agent_response(db_name, collection_name, mod_system_message, "", user_id, timestamp)

full_system_message = SystemMessagePromptTemplate.from_template(''.join([system_message, addendum])).format()
response = get_langchain_agent_response(db_name=db_name,
collection_name=collection_name,
system_message=full_system_message,
message="",
user_id=user_id,
chat_id=0,
turn_id=0,
timestamp=timestamp)
return response
else:
pass
# do something else
recent_turns = db.chat_turns.find({"user_id": user_id}).sort({"timestamp": -1}).limit(5).next()
# Get summary
chat_summary = db.chat_summaries.find({"user_id": user_id}).sort({"timestamp": -1}).limit(1).next()
new_chat_id = chat_summary.get("chat_id") + 1

recent_turns = list(db.chat_turns.find({"user_id": user_id}).sort({"timestamp": -1}).limit(5))
recent_turns.reverse()
history = []
for turn in recent_turns:
if turn.get("human_message"):
history.append(HumanMessage(turn.get("human_message")))
if turn.get("ai_message"):
history.append(AIMessage(turn.get("ai_message")))

old_chat_id = recent_turns[0].get("chat_id", -1) + 1

chat_summary = ""
# chat_summary = db.chat_summaries.find({"user_id": user_id}).sort({"timestamp": -1}).limit(1).next()
new_chat_id = old_chat_id


addendum = """
Last Conversation Summary:
{summary}
"""

system_message_template = SystemMessagePromptTemplate.from_template(''.join([system_message, addendum]))

full_system_message = system_message_template.format(history=history, summary=chat_summary)
response = get_langchain_agent_response(db_name=db_name,
collection_name=collection_name,
system_message=full_system_message,
message="",
user_id=user_id,
chat_id=new_chat_id,
turn_id=0,
timestamp=timestamp,
history=history)
return response

# last_turn = db.chat_turns.find({"user_id": user_id}).sort({"chat_id": -1}).limit(1).next()
# Is this a new conversation?
# if last_turn.get("chat_id") != chat_id or chat_id is None:
# pass
# Get summary
# new_chat_id = chat_summary.get("chat_id") + 1

# prompt_template = PromptTemplate.from_template(mod_system_message)
# prompt = prompt_template.format(user_id=user_id, )



def get_langchain_agent_response(db_name, collection_name, system_message, prompt, user_id, timestamp):
def get_langchain_agent_response(db_name, collection_name, system_message, message, user_id, chat_id, turn_id, timestamp, history=[]):
llm = get_azure_openai_llm()

# At the start of a chat turn, we want to load the conversation with the necessary context.
@@ -296,23 +392,132 @@ def get_langchain_agent_response(db_name, collection_name, system_message, promp
# return_messages=True
# )

system_message_obj = SystemMessage(
content=system_message
)

search = TavilySearchResults() # Going to use this to connect user to resources
# cosmosdb_tool = get_cosmosdb_tool(db_name, collection_name)
tools = [search] #, cosmosdb_tool]

# memory = ConversationEntityMemory()

agent_with_history = get_mongodb_agent_with_history(llm, tools, db_name, collection_name, system_message_obj, user_id, timestamp)
if type(system_message) == str:
system_message_obj = SystemMessage(system_message)
else:
system_message_obj = system_message


if not history:
db_client = MongoDBClient.get_client()
db = db_client[MongoDBClient.get_db()]
turns = list(db.chat_turns.find({"user_id": user_id, "chat_id": chat_id}).sort({"timestamp": -1}).limit(5))
turns.reverse()
history = []

for turn in turns:
if turn.get("human_message"):
history.append(HumanMessage(turn.get("human_message")))
if turn.get("ai_message"):
history.append(AIMessage(turn.get("ai_message")))

prompt = f"history:\n{history}\n{message}"

agent_with_history = get_mongodb_agent_with_history(llm, tools, db_name, collection_name, message, system_message_obj, user_id, chat_id, timestamp, history=history)

invocation = agent_with_history.invoke(
{ "input": prompt },
config={"configurable": {"session_id": ""}}
config={"configurable": {"user_id": user_id, "chat_id": chat_id}}
)

db_client = MongoDBClient.get_client()
db = db_client[MongoDBClient.get_db()]
chat_turns_collection = db["chat_turns"]

chat_turns_collection.insert_one({
"user_id": user_id,
"chat_id": chat_id,
"turn_id": turn_id,
"human_message": invocation.get("input"),
"ai_message": invocation.get("output"),
"timestamp": timestamp
})


return invocation["output"]

# Chat Turn Schema
{
"user_id": "",
"chat_id": "",
"turn_id": "",
"human_message": "",
"ai_message": "",
"timestamp": "",
}

# Chat Summary Schema
{
"user_id": "",
"chat_id": "",
"timestamp": "",
"last_updated": "",
"perceived_mood": "",
"summary_text": "",
"concerns_progress": {
{
"label": "",
"delta": ""
}
},
}

# User Journey schema
{
"user_id": "",
"patient_goals": [],
"therapy_type": [],
"last_updated": "",
"therapy_plan": [
{
"chat_id": "",
"exercises": "",
"submit_assignments": [],
"assign_assignments": [],
"assign_exercise": [],
"share_resource": []
}
],
"mental_health_concerns": [
{
"label": "",
"severity": "",
}
]
}

# User Entities Schema
{
"user_id": "",
"entity_id": "",
"entity_data": []

}

# Resources Schema
{
"resource_id": "",
"resource_type": "Article/Video/Contact Information/Exercise",
"": ""

}

# User Resource Schema
{
"user_id": "",
"resource_id": "",
"user_liked": "",
"user_viewed": "",
}

# At the time of creating a response, we want to save the response to mongodb

return invocation["output"]

# Chat Turn Schema
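
Taken together, a minimal usage sketch of the reworked entry point. It assumes the Azure OpenAI and MongoDB environment variables the module reads are already set; the database name follows the f"mental-health-{ENV}" pattern in ai.py with ENV=dev as an assumption, and the user ID and messages are made up.

from datetime import datetime
from tools.langchain import get_langchain_agent_response

# First turn of chat 0: there is no cached session yet, so the function reloads
# any recent documents from the chat_turns collection (none for a new user).
first = get_langchain_agent_response(
    db_name="mental-health-dev",
    collection_name="chatbot_logs",
    system_message="You are a therapy companion.",
    message="I've been having trouble sleeping.",
    user_id="user-123",
    chat_id=0,
    turn_id=0,
    timestamp=datetime.now().isoformat(),
)

# Second turn: the chat turn written by the first call is read back from MongoDB
# (and from the in-memory session), so the reply can refer to what the user just said.
second = get_langchain_agent_response(
    db_name="mental-health-dev",
    collection_name="chatbot_logs",
    system_message="You are a therapy companion.",
    message="What did I just say I was struggling with?",
    user_id="user-123",
    chat_id=0,
    turn_id=1,
    timestamp=datetime.now().isoformat(),
)
print(second)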
