Skip to content

Commit

Permalink
Merge pull request #27 from dhrumilp12/overhaul-server
Browse files Browse the repository at this point in the history
Made structural changes to the server code.
  • Loading branch information
janthonysantana authored Jun 22, 2024
2 parents 4223722 + 11cf39c commit 98f7f47
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 62 deletions.
1 change: 0 additions & 1 deletion .anima/.gitignore

This file was deleted.

3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.anima
.anima/
.ruff_cache/
84 changes: 74 additions & 10 deletions server/agents/mental_health_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,25 @@

# -- Standard libraries --
from datetime import datetime
import time
import asyncio
from operator import itemgetter

# -- 3rd Party libraries --
# import spacy

# Azure
# Langchain
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory import ConversationSummaryMemory
from langchain.memory import ConversationSummaryMemory, ConversationBufferMemory
from langchain_core.runnables import RunnablePassthrough
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from langchain_core.messages import trim_messages

# MongoDB
# -- Custom modules --
from .ai_agent import AIAgent
Expand Down Expand Up @@ -60,16 +67,16 @@ def __init__(self, system_message: str = SYSTEM_MESSAGE, tool_names: list[str] =
self.prompt = ChatPromptTemplate.from_messages(
[
("system", self.system_message.content),
("system", "user_id:{user_id}"),
MessagesPlaceholder(variable_name="history"),
("human", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)

tools = self._create_agent_tools(tool_names)
agent = create_tool_calling_agent(self.llm, tools, self.prompt)
self.agent = create_tool_calling_agent(self.llm, self.tools, self.prompt)
executor:AgentExecutor = AgentExecutor(
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
agent=self.agent, tools=self.tools, verbose=True, handle_parsing_errors=True)
self.agent_executor = self.get_agent_with_history(executor)


Expand All @@ -80,9 +87,10 @@ def get_session_history(self, session_id: str) -> MongoDBChatMessageHistory:
Args:
session_id (str): The session ID to retrieve the chat history for.
"""
CONNECTION_STRING = MongoDBClient.get_mongodb_variables()

history = MongoDBChatMessageHistory(
MongoDBClient.get_mongodb_variables(),
CONNECTION_STRING,
session_id,
MongoDBClient.get_db_name(),
collection_name="history"
Expand Down Expand Up @@ -151,7 +159,7 @@ def get_agent_executor(self, prompt):
return agent_executor


def exec_update_step():
def exec_update_step(self, user_id, chat_id=None, turn_id=None):
# Chat Summary:
# Update every 5 chat turns
# Therapy Material
Expand All @@ -162,6 +170,60 @@ def exec_update_step():
# Can be either updated at the end of the chat, or every 5 chat turns
# User Material:
# Possibly updated every 5 chat turns, at the end of a chat, or not at all

history:BaseChatMessageHistory = self.get_session_history(f"{user_id}-{chat_id}")
history_log = asyncio.run(history.aget_messages()) # Running async function as synchronous

# Get perceived mood
instructions = """
Given the messages provided, describe the user's mood in a single adjective.
Do your best to capture their intensity, attitude and disposition in that single word.
Do not include anything in your response aside from that word.
If you cannot complete this task, just answer \"None\".
"""

prompt = ChatPromptTemplate.from_messages(
[
("system", instructions),
MessagesPlaceholder(variable_name="messages"),
]
)

trimmer = trim_messages(
max_tokens=65,
strategy="last",
token_counter=self.llm,
include_system=True,
allow_partial=False,
start_on="human",
)

trimmer.invoke(history_log)
chain = RunnablePassthrough.assign(messages=itemgetter("messages") | trimmer) | prompt | self.llm
response = chain.invoke({"messages": history_log})
user_mood = None if response.content == "None" else response.content

print("The user is feeling: ", user_mood)

# agent_with_history = RunnableWithMessageHistory(
# chain,
# get_session_history=lambda _: memory,
# input_messages_key="input",
# history_messages_key="history",
# verbose=True
# )


# self.agent = create_tool_calling_agent(self.llm, self.tools, self.prompt)
# stateless_agent_executor:AgentExecutor = AgentExecutor(
# agent=self.agent, tools=self.tools, verbose=True, handle_parsing_errors=True)

# history=self.get_session_history(f"{user_id}-{chat_id}-{}")
# Must invoke with an agent that will not write to DB
# invocation = stateless_agent_executor.invoke(
# {"input": f"{message}\nuser_id:{user_id}", "agent_scratchpad": []})

# Get summary text
pass

def run(self, message: str, with_history:bool =True, user_id: str=None, chat_id:int=None, turn_id:int=None) -> str:
Expand All @@ -181,16 +243,18 @@ def run(self, message: str, with_history:bool =True, user_id: str=None, chat_id:
# else:
# TODO: throw error if user_id, chat_id is set to None.
session_id = f"{user_id}-{chat_id}"
# kwargs = {
# "timestamp": curr_epoch_time
# }

invocation = self.agent_executor.invoke(
{"input": f"{message}\nuser_id:{user_id}", "agent_scratchpad": []},
config={"configurable": {"session_id": session_id}}
)
{"input": message, "user_id": user_id, "agent_scratchpad": []},
config={"configurable": {"session_id": session_id}})

# This updates certain collections in the database based on recent history
if (turn_id + 1) % PROCESSING_STEP == 0:
# TODO
self.exec_update_step()
self.exec_update_step(user_id, chat_id)
pass

return invocation["output"]
Expand Down
14 changes: 7 additions & 7 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from flask_cors import CORS
from flask_jwt_extended import JWTManager

# from models.subscription import Subscription,db
# from routes.checkIn import checkIn_routes
from models.subscription import db as sub_db
from routes.check_in import check_in_routes
from services.db.agent_facts import load_agent_facts_to_db

from routes.user import user_routes
Expand All @@ -25,17 +25,17 @@
# Register routes
app.register_blueprint(user_routes)
app.register_blueprint(ai_routes)
app.register_blueprint(checkIn_routes)
app.register_blueprint(check_in_routes)

# DB pre-load
load_agent_facts_to_db()

# Subscription db
# app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///mydatabase.db'
# db.init_app(app)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///mydatabase.db'
sub_db.init_app(app)
## Create the tables
# with app.app_context():
# db.create_all()
with app.app_context():
sub_db.create_all()

# Base endpoint
@app.get("/")
Expand Down
4 changes: 2 additions & 2 deletions server/models/check_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from datetime import datetime, timedelta
from pydantic import BaseModel, Field, constr, validator
from pydantic import BaseModel, Field, constr, field_validator
from enum import Enum

class Frequency(str, Enum):
Expand Down Expand Up @@ -31,7 +31,7 @@ def save(self, db):
document = self.dict()
db.check_ins.insert_one(document)

@validator('check_in_time', pre=True)
@field_validator('check_in_time', pre=True)
def check_future_date(cls, v):
if v < datetime.now():
raise ValueError("Check-in time must be in the future")
Expand Down
Loading

0 comments on commit 98f7f47

Please sign in to comment.