From 791155505e99ca051a20751384ac49764cd9877d Mon Sep 17 00:00:00 2001 From: Josue Santana Date: Thu, 13 Jun 2024 23:25:41 -0400 Subject: [PATCH] Completely redid how memory works, now directly integrates with database. --- server/agents/ai_agent.py | 25 +- server/agents/mental_health_agent.py | 467 +++++++++++++++++---------- server/app.py | 4 + server/models/agent_fact.py | 1 - server/models/chat_summary.py | 4 +- server/models/chat_turn.py | 4 +- server/services/azure_mongodb.py | 15 + server/tests/test_user.py | 2 +- server/utils/consts.py | 27 +- 9 files changed, 356 insertions(+), 193 deletions(-) diff --git a/server/agents/ai_agent.py b/server/agents/ai_agent.py index 5f22bab2..bfef3d74 100644 --- a/server/agents/ai_agent.py +++ b/server/agents/ai_agent.py @@ -1,29 +1,35 @@ import re -from langchain_community.vectorstores import AzureCosmosDBVectorSearch +from langchain_community.vectorstores.azure_cosmos_db import ( + AzureCosmosDBVectorSearch, + CosmosDBSimilarityType, + CosmosDBVectorSearchType +) from langchain.agents import Tool from langchain.agents.agent_toolkits import create_conversational_retrieval_agent from langchain.tools import StructuredTool from langchain_core.messages import SystemMessage from langchain_core.vectorstores import VectorStoreRetriever +from pymongo.database import Database + from services.azure_mongodb import MongoDBClient from services.my_azure import get_azure_openai_variables, get_azure_openai_llm, get_azure_openai_embeddings class AIAgent: def __init__(self, system_message:str, schema:list[str]=[]): - self.db = (MongoDBClient.get_client())[MongoDBClient.get_db_name()] + self.db:Database = (MongoDBClient.get_client())[MongoDBClient.get_db_name()] self.llm = get_azure_openai_llm() self.embedding_model = get_azure_openai_embeddings() self.system_message = SystemMessage(content=system_message) - self.agent_executor = create_conversational_retrieval_agent( - llm=self.llm, - tools=self.__create_agent_tools(schema), - system_message = self.system_message, - verbose=True - ) + # self.agent_executor = create_conversational_retrieval_agent( + # llm=self.llm, + # tools=self.__create_agent_tools(schema), + # system_message = self.system_message, + # verbose=True + # ) def run(self, message:str) -> str: @@ -43,10 +49,11 @@ def _get_cosmosdb_vector_store_retriever(self, collection_name, top_k=3) -> Vect connection_string = CONNECTION_STRING, namespace = f"{db_name}.{collection_name}", embedding = AOAI_EMBEDDINGS, - index_name =f"{db_name}_{collection_name}_index", + index_name =f"VectorSearchIndex", embedding_key = "contentVector", #TODO: Find out what these are for text_key = "_id" #TODO: Find out what these are for ) + vector_store.create_index() return vector_store.as_retriever(search_kwargs={"k": top_k}) diff --git a/server/agents/mental_health_agent.py b/server/agents/mental_health_agent.py index ce2f140e..7a49f2af 100644 --- a/server/agents/mental_health_agent.py +++ b/server/agents/mental_health_agent.py @@ -3,24 +3,28 @@ import spacy import json from enum import Enum -from pymongo import ASCENDING +import pymongo -from langchain_core.messages.human import HumanMessage -from langchain_core.messages.ai import AIMessage -from langchain_community.chat_message_histories.in_memory import ChatMessageHistory +from langchain_openai import AzureOpenAIEmbeddings from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain_core.runnables.utils import ConfigurableFieldSpec from langchain.memory import ConversationSummaryMemory -from langchain_core.messages import SystemMessage from langchain_community.utilities import BingSearchAPIWrapper from langchain_community.tools.tavily_search import TavilySearchResults from langchain.agents.agent_toolkits import create_conversational_retrieval_agent -from langchain.agents import Tool +from langchain.agents import Tool, create_tool_calling_agent, AgentExecutor +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory from .ai_agent import AIAgent from services.azure_mongodb import MongoDBClient from utils.consts import SYSTEM_MESSAGE +from utils.consts import AGENT_FACTS +from utils.consts import PROCESSING_STEP from utils.docs import format_docs +from models.agent_fact import AgentFact +from models.chat_summary import ChatSummary +from models.chat_turn import ChatTurn + # Load spaCy model nlp = spacy.load("en_core_web_sm") @@ -39,68 +43,33 @@ class MentalHealthAIAgent(AIAgent): user entities, chat summaries, chat turns, resources, and user resources. """ - db = MongoDBClient.get_client() def __init__(self, system_message=SYSTEM_MESSAGE, schema=[]): self.system_message = system_message super().__init__(system_message, schema) - self.agent_executor = create_conversational_retrieval_agent( + self.agent_executor:AgentExecutor = create_conversational_retrieval_agent( llm=self.llm, - tools= self.prepare_tools(), - system_message = self.system_message, - verbose=True + tools=self.prepare_tools(), + system_message=self.system_message, + verbose=True, + kwargs={} ) - - - # @staticmethod - # def get_user_by_id(user_id:str) -> str: - # """ - # Retrieves a user by their ID. - # """ - - # doc = MentalHealthAIAgent.db.users.find_one({"_id": user_id}) - - def get_chat_history(self, user_id, chat_id, history_scope:ChatHistoryScope): - """ - Used to find details from previous conversations with the user. - """ - collection = self.db["chat_turns"] - - # Check if the 'timestamp' index already exists - indexes = collection.list_indexes() - if not any(index['key'].get('timestamp') for index in indexes): - collection.create_index([('timestamp', ASCENDING)]) - - turns = [] - if history_scope == ChatHistoryScope.ALL: - turns = list(collection.find({"user_id": user_id}).sort({"timestamp": -1}).limit(5)) - elif history_scope == ChatHistoryScope.PREVIOUS: - turns = list(collection.find({"user_id": user_id, "chat_id": (chat_id - 1)}).sort({"timestamp": -1})) - elif history_scope == ChatHistoryScope.CURRENT: - turns = list(collection.find({"user_id": user_id, "chat_id": chat_id}).sort({"timestamp": -1}).limit(5)) - - turns.reverse() - history_list = [] - - for turn in turns: - if turn.get("human_message"): - history_list.append(HumanMessage(turn.get("human_message"))) - if turn.get("ai_message"): - history_list.append(AIMessage(turn.get("ai_message"))) - - chat_history = ChatMessageHistory() - chat_history.add_messages(history_list) - - return chat_history + def get_session_history(self, session_id: str) -> MongoDBChatMessageHistory: + return MongoDBChatMessageHistory( + MongoDBClient.get_mongodb_variables(), + session_id, + MongoDBClient.get_db_name(), + collection_name="history" + ) def get_agent_memory(self, user_id, chat_id, history_scope=ChatHistoryScope.ALL): chat_history = self.get_chat_history(user_id, chat_id, history_scope) - + memory = ConversationSummaryMemory.from_messages( llm=self.llm, chat_memory=chat_history, @@ -110,75 +79,189 @@ def get_agent_memory(self, user_id, chat_id, history_scope=ChatHistoryScope.ALL) return memory - def get_agent_with_history(self, memory): + def get_agent_with_history(self, agent_executor): agent_with_history = RunnableWithMessageHistory( - self.agent_executor, - lambda chat_id, user_id: memory.chat_memory, + agent_executor, + get_session_history=self.get_session_history, input_messages_key="input", - history_messages_key="chat_history", - history_factory_config=[ - ConfigurableFieldSpec( - id="user_id", - annotation=str, - name="User ID", - description="Unique identifier for the user.", - default="", - is_shared=True, - ), - ConfigurableFieldSpec( - id="chat_id", - annotation=str, - name="Chat ID", - description="Unique identifier for the chat session.", - default="", - is_shared=True, - ), - ] + history_messages_key="history", + verbose=True ) return agent_with_history - - - def write_agent_response_to_db(self, invocation, user_id, chat_id, turn_id): - db_client = MongoDBClient.get_client() - db = db_client[MongoDBClient.get_db_name()] - chat_turns_collection = db["chat_turns"] - chat_turns_collection.insert_one({ - "user_id": user_id, - "chat_id": chat_id, - "turn_id": turn_id, - "human_message": invocation.get("input"), - "ai_message": invocation.get("output"), - "timestamp": datetime.now().isoformat() - }) - return invocation["output"] - - def run(self, message:str, with_history=True, user_id=None, chat_id=None, turn_id=None, history_scope=None): + def run(self, message: str, with_history=True, user_id=None, chat_id=None, turn_id=None, history_scope=None): if not with_history: return super().run(message) else: - #:TODO throw error if user_id, chat_id, or history_scope is set to None. - memory = self.get_agent_memory(user_id, chat_id, history_scope) - agent_with_history = self.get_agent_with_history(memory) - - if memory.buffer: - addendum = f""" - Previous Conversation Summary: - {memory.buffer} - """ - self.system_message.content = f"{self.system_message.content}\n{addendum}" + + # TODO: throw error if user_id, chat_id, or history_scope is set to None. + session_id = f"{user_id}-{chat_id}" + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", self.system_message.content), + MessagesPlaceholder(variable_name="history"), + ("human", "{input}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + + tools = self.prepare_tools() + agent = create_tool_calling_agent(self.llm, self.prepare_tools(), prompt) + agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) + + agent_with_history = self.get_agent_with_history(agent_executor) invocation = agent_with_history.invoke( - { "input": message }, - config={"configurable": {"user_id": user_id, "chat_id": chat_id}} + {"input": message, "agent_scratchpad": []}, + config={"configurable": {"session_id": session_id}} ) - self.write_agent_response_to_db(invocation, user_id, chat_id, turn_id) + # This updates certain collections in the database based on recent history + if (turn_id + 1) % PROCESSING_STEP == 0: + # Chat Summary: + # Update every 5 chat turns + # Therapy Material + # Maybe not get it from DB at all? Just perform Bing search? + # User Entity: + # Can be saved from chat summary step, every 5 chat turns + # User Journey: + # Can be either updated at the end of the chat, or every 5 chat turns + # User Material: + # Possibly updated every 5 chat turns, at the end of a chat, or not at all + + self.update_chat_summary(user_id, chat_id) + self.update_user_entities() + self.update_user_journey() return invocation["output"] - + + def update_chat_summary(self, user_id, chat_id): + # TODO: Redo this function + pass + + # collection: Collection = self.db["chat_summaries"] + # loader = None + # last_summary = None + # latest_turns = None + # current_summary = None + + # query_filter = { + # "user_id": user_id, + # "chat_id": chat_id + # } + # If chat summary is empty, use the chat summary from last session. + # if collection.count_documents(query_filter): + # last_summary = MongoDBClient.get_mongodb_loader("chat_summaries", { + # "user_id": user_id, + # "chat_id": int(chat_id) - 1 + # }).load() + + # pass + + # # last_summary:ChatSummary = ChatSummary.model_validate(collection.find_one()) + # else: + + # If chat summary is not empty, use the latest chat summary and grab the latest 5 turns + # loader_output = MongoDBClient.get_mongodb_loader("chat_summaries", { + # "user_id": user_id, + # "chat_id": int(chat_id) - 1 + # }).load() + # last_summary:ChatSummary = ChatSummary.model_validate(collection.find_one(query_filter)) + # latest_turns = list(collection.find({"user_id": user_id, "chat_id": chat_id}).sort({"timestamp": -1}).limit(5)) + # pass + # try: + # last_summary: ChatSummary = ChatSummary.model_validate(collection.find_one({ + # "user_id": user_id, + # "chat_id": int(chat_id) - 1 + # })) + # except ValidationError as e: + # print(e) + # last_summary = ChatSummary.model_construct() + + # summary_message = AIMessage(last_summary.summary_text) + + # turns = self.db["chat_turns"].find( + # {"user_id": user_id, "chat_id": chat_id}).sort({"timestamp": -1}).limit(5) + # latest_turns: list[ChatTurn] = [ + # ChatTurn.model_validate(turn) for turn in turns] + # conversation_log = [] + + # for turn in latest_turns: + # conversation_log.append(HumanMessage(turn.human_message)) + # conversation_log.append(AIMessage(turn.ai_message)) + + # summary_template = """ + # Given the following summary and most recent conversation log, generate a new summary that updates and rewrites the existing summary as needed to + # capture the most salient points of the conversation log. + # Summary so far: + # {summary} + + # Latest conversation log: + # {conversation_log} + # """ + + # summary_prompt = PromptTemplate.from_template(summary_template) + + # result = summary_prompt.invoke( + # {"summary": summary_message, "conversation_log": conversation_log}) + + def update_user_entities(self): + # TODO + pass + + def update_user_journey(self): + # TODO + pass + + def update_collections(self): + # TODO + pass + + @staticmethod + def generate_embeddings(text: str): + embeddings_model = AzureOpenAIEmbeddings( + azure_endpoint= os.environ["AOAI_ENDPOINT"], + api_key= os.environ["AOAI_KEY"], + azure_deployment= os.environ["EMBEDDINGS_DEPLOYMENT_NAME"] + ) + embeddings = embeddings_model.embed_query(text) + + return embeddings + + + @staticmethod + def load_agent_facts_to_db(): + db = MongoDBClient.get_client()[MongoDBClient.get_db_name()] + collection = db["agent_facts"] + + if collection.count_documents({}) == 0: + print("There are no documents in Agent Facts. writing documents...") + validated_models: list[AgentFact] = [ + AgentFact.model_validate(fact_dict) for fact_dict in AGENT_FACTS] + facts_to_load = [AgentFact.model_dump( + fact_model) for fact_model in validated_models] + collection.insert_many(facts_to_load) + + bulk_operations = [] + for doc in collection.find(): + if "contentVector" in doc: + del doc["contentVector"] + + content = json.dumps(doc, default=str) + content_vector = MentalHealthAIAgent.generate_embeddings(content) + + bulk_operations.append(pymongo.UpdateOne( + {"_id": doc["_id"]}, + {"$set": {"contentVector": content_vector}}, + upsert=True + )) + collection.bulk_write(bulk_operations) + else: + print("Agent facts are already populated. Skipping step.") + def analyze_chat(self, text): """Analyze the chat text to determine emotional state and detect triggers.""" doc = nlp(text) @@ -240,7 +323,7 @@ def analyze_chat(self, text): ] return {"emotions": emotions, "triggers": triggers, "patterns": patterns} - + def _run(self, message: str, with_history=True, user_id=None, chat_id=None, turn_id=None, history_scope=None): try: if not with_history: @@ -249,19 +332,19 @@ def _run(self, message: str, with_history=True, user_id=None, chat_id=None, turn memory = self.get_agent_memory(user_id, chat_id, history_scope) if not memory: return "Error: Unable to retrieve conversation history." - + if memory.buffer: addendum = f""" Previous Conversation Summary: {memory.buffer} - """ + """ self.system_message.content = f"{self.system_message.content}\n{addendum}" agent_with_history = self.get_agent_with_history(memory) - + # Analyze the message for emotional content analysis_results = self.analyze_chat(message) - response_addendum = self.format_response_addendum(analysis_results) + response_addendum = self.format_response_addendum(analysis_results) # Invoke the agent with history context invocation = agent_with_history.invoke({"input": f"{message}\n{response_addendum}"}, config={"configurable": {"user_id": user_id, "chat_id": chat_id}}) @@ -273,53 +356,63 @@ def _run(self, message: str, with_history=True, user_id=None, chat_id=None, turn except Exception as e: return f"An error occurred: {str(e)}" - def format_response_addendum(self, analysis_results): patterns = analysis_results['patterns'] response_addendum = "" for state, suggestions in patterns.items(): - response_addendum += f"Detected {state}: " + "; ".join(suggestions) + "\n" + response_addendum += f"Detected {state}: " + \ + "; ".join(suggestions) + "\n" return response_addendum.strip() - def prepare_tools(self): # search = BingSearchAPIWrapper(k=5) search = TavilySearchResults() community_tools = [search] - + # cosmosdb_tool = get_cosmosdb_tool(db_name, collection_name) - user_journeys_retriever_chain = self._get_cosmosdb_vector_store_retriever("user_journeys") | format_docs - user_materials_retriever_chain = self._get_cosmosdb_vector_store_retriever("user_materials") | format_docs - user_entities_retriever_chain = self._get_cosmosdb_vector_store_retriever("user_entities") | format_docs - agent_facts_retriever_chain = self._get_cosmosdb_vector_store_retriever("agent_facts") | format_docs + agent_facts_retriever_chain = self._get_cosmosdb_vector_store_retriever("agent_facts") #| format_docs + # user_profiles_retriever_chain = self._get_cosmosdb_vector_store_retriever("users") | format_docs + # user_journeys_retriever_chain = self._get_cosmosdb_vector_store_retriever("user_journeys") | format_docs + # user_materials_retriever_chain = self._get_cosmosdb_vector_store_retriever("user_materials") | format_docs + # user_entities_retriever_chain = self._get_cosmosdb_vector_store_retriever("user_entities") | format_docs + # agent_facts_retriever_chain = self._get_cosmosdb_vector_store_retriever("agent_facts") | format_docs custom_tools = [ - Tool( - name = "vector_search_user_journeys", - func = user_journeys_retriever_chain.invoke, - description = "Searches a mental health patient's user journey for " - ), - Tool( - name = "vector_search_user_materials", - func = user_materials_retriever_chain.invoke, - description = "" - ), - Tool( - name = "vector_search_user_entities", - func = user_entities_retriever_chain.invoke, - description = "" - ), - Tool( - name = "vector_search_agent_facts", - func = agent_facts_retriever_chain.invoke, - description = "" - ) + # Tool( + # name="vector_search_agent_facts", + # func=agent_facts_retriever_chain.invoke, + # description="Searches for facts about the agent itself." + # ), + # Tool( + # name = "vector_search_user_journeys", + # func = user_profiles_retriever_chain.invoke, + # description = "Searches a user's profile for personal information." + # ), + # Tool( + # name = "vector_search_user_journeys", + # func = user_journeys_retriever_chain.invoke, + # description = "Searches a mental health patient's user journey." + # ), + # Tool( + # name = "vector_search_user_materials", + # func = user_materials_retriever_chain.invoke, + # description = "" + # ), + # Tool( + # name = "vector_search_user_entities", + # func = user_entities_retriever_chain.invoke, + # description = "" + # ), + # Tool( + # name = "vector_search_agent_facts", + # func = agent_facts_retriever_chain.invoke, + # description = "" + # ) ] - all_tools = community_tools #+ custom_tools + all_tools = community_tools + custom_tools return all_tools - def get_initial_greeting(self, user_id): db_client = MongoDBClient.get_client() db_name = MongoDBClient.get_db_name() @@ -340,54 +433,62 @@ def get_initial_greeting(self, user_id): "therapy_plan": [], "mental_health_concerns": [] }) - + addendum = """This is your first session with the patient. Be polite and introduce yourself in a friendly and inviting manner. In this session, do your best to understand what the user hopes to achieve through your service, and derive a therapy style fitting to their needs. """ - + full_system_message = ''.join([system_message.content, addendum]) system_message.content = full_system_message response = self.run( - message="", - with_history=False, - user_id=user_id, - chat_id=0, - turn_id=0, - history_scope = ChatHistoryScope.ALL, - ) - - return { - "message": response, - "chat_id": 0 + message="", + with_history=False, + user_id=user_id, + chat_id=0, + turn_id=0, + history_scope=ChatHistoryScope.ALL, + ) + + return { + "message": response, + "chat_id": 0 } - + else: try: - last_turn = db.chat_turns.find({"user_id": user_id}).sort({"timestamp": -1}).limit(1).next() + last_turn = db.chat_turns.find({"user_id": user_id}).sort( + {"timestamp": -1}).limit(1).next() except StopIteration: last_turn = {} old_chat_id = last_turn.get("chat_id", -1) new_chat_id = old_chat_id + 1 - response = self.run( - message="", - with_history=True, - user_id=user_id, - chat_id=new_chat_id, - turn_id=0, - history_scope=ChatHistoryScope.PREVIOUS - ) - + message="", + with_history=True, + user_id=user_id, + chat_id=new_chat_id, + turn_id=0, + history_scope=ChatHistoryScope.PREVIOUS + ) + return { "message": response, "chat_id": new_chat_id } - - def get_user_journey_by_user_id(self, user_id:str) -> str: + def get_user_profile_by_user_id(self, user_id: str) -> str: + """ + Retrieves a user journey by the user's ID. + """ + doc = self.db["users"].find_one({"user_id": user_id}) + if "contentVector" in doc: + del doc["contentVector"] + return json.dumps(doc) + + def get_user_journey_by_user_id(self, user_id: str) -> str: """ Retrieves a user journey by the user's ID. """ @@ -395,18 +496,17 @@ def get_user_journey_by_user_id(self, user_id:str) -> str: if "contentVector" in doc: del doc["contentVector"] return json.dumps(doc) - - - def get_chat_summary_by_composite_id(self, user_id:str, chat_id:str) -> str: + + def get_chat_summary_by_composite_id(self, user_id: str, chat_id: str) -> str: """ Retrieves a summary of a chat between a user and the agent by a combination of the user's ID and the chat instance ID. """ - doc = self.db["chat_summaries"].find_one({"user_id": user_id, "chat_id": chat_id}) + doc = self.db["chat_summaries"].find_one( + {"user_id": user_id, "chat_id": chat_id}) if "contentVector" in doc: del doc["contentVector"] return json.dumps(doc) - def get_user_material_by_user_id(self, user_id: str) -> str: """ @@ -417,7 +517,6 @@ def get_user_material_by_user_id(self, user_id: str) -> str: del doc["contentVector"] return json.dumps(doc) - def get_user_entity_by_user_id(self, user_id: str): """ Retrieves a user's known entity by the user's ID. @@ -491,7 +590,7 @@ def get_user_entity_by_user_id(self, user_id: str): # dimensions, # similarity_algorithm, # kind, -# m, +# m, # ef_construction # ) @@ -502,4 +601,18 @@ def get_user_entity_by_user_id(self, user_id: str): # description = "Searches the Mental Health database for psychology theory." # ) -# return cosmosdb_tool \ No newline at end of file +# return cosmosdb_tool + + +# Agent Fact: +# Prepopulate to DB +# Chat Summary: +# Update every 5 chat turns +# Therapy Material +# Maybe not get it from DB at all? Just perform Bing search? +# User Entity: +# Can be saved from chat summary step, every 5 chat turns +# User Journey: +# Can be either updated at the end of the chat, or every 5 chat turns +# User Material: +# Possibly updated every 5 chat turns, at the end of a chat, or not at all diff --git a/server/app.py b/server/app.py index daeda188..e7c2d696 100644 --- a/server/app.py +++ b/server/app.py @@ -9,6 +9,8 @@ from routes.user import user_routes from routes.ai import ai_routes +from agents.mental_health_agent import MentalHealthAIAgent + # Set up the app app = Flask(__name__) app.config['JWT_SECRET_KEY'] = 't54WKRE5t5UaZnEWDvUd75Qe5ilYAKKe9n8tbUGv3_Q' #FIXME: This should be an environment variable. @@ -19,6 +21,8 @@ app.register_blueprint(user_routes) app.register_blueprint(ai_routes) +# DB pre-load +MentalHealthAIAgent.load_agent_facts_to_db() # Base endpoint @app.get("/") diff --git a/server/models/agent_fact.py b/server/models/agent_fact.py index 354dc1fc..a8eb8484 100644 --- a/server/models/agent_fact.py +++ b/server/models/agent_fact.py @@ -6,6 +6,5 @@ from pydantic import BaseModel class AgentFact(BaseModel): - fact_id: str sample_query: str fact: str \ No newline at end of file diff --git a/server/models/chat_summary.py b/server/models/chat_summary.py index 8615ce6f..ca6d7d94 100644 --- a/server/models/chat_summary.py +++ b/server/models/chat_summary.py @@ -15,7 +15,7 @@ class ConcernProgress(BaseModel): class ChatSummary(BaseModel): user_id: str chat_id: str - last_update: datetime + last_update: datetime = datetime.now() perceived_mood: str - summary_text: str + summary_text: str = "" concerns_progress: list[ConcernProgress] diff --git a/server/models/chat_turn.py b/server/models/chat_turn.py index d21eb409..8be28a4c 100644 --- a/server/models/chat_turn.py +++ b/server/models/chat_turn.py @@ -8,8 +8,8 @@ class ChatTurn(BaseModel): user_id: str - chat_id: str - turn_id: str + chat_id: int + turn_id: int human_message: str ai_message: str timestamp: datetime diff --git a/server/services/azure_mongodb.py b/server/services/azure_mongodb.py index 5dcf86ee..6866b935 100644 --- a/server/services/azure_mongodb.py +++ b/server/services/azure_mongodb.py @@ -8,6 +8,7 @@ import pymongo from pymongo import UpdateOne, ReturnDocument import mongomock +from langchain_community.document_loaders.mongodb import MongodbLoader from utils.consts import APP_NAME @@ -44,6 +45,20 @@ def get_client(cls): return cls._client + @staticmethod + def get_mongodb_loader(collection_name, db_filter): + CONNECTION_STRING = MongoDBClient.get_mongodb_variables() + + loader = MongodbLoader( + connection_string=CONNECTION_STRING, + db_name= MongoDBClient.get_db_name(), + collection_name=collection_name, + filter_criteria=db_filter + ) + + return loader + + @classmethod def get_db_name(cls): ENV = os.environ.get("FLASK_ENV") diff --git a/server/tests/test_user.py b/server/tests/test_user.py index b909eac0..53241632 100644 --- a/server/tests/test_user.py +++ b/server/tests/test_user.py @@ -2,7 +2,7 @@ sys.path.append(".") from app import app -from tools.azure_mongodb import MongoDBClient +from services.azure_mongodb import MongoDBClient user_data = { "username": "user1", diff --git a/server/utils/consts.py b/server/utils/consts.py index fdd9c19a..6658262f 100644 --- a/server/utils/consts.py +++ b/server/utils/consts.py @@ -14,4 +14,29 @@ AI: I'm sorry, but I am not built to answer that question. My role is to help you with your mental health goals. If you do not know the answer to a question, respond with \"I don't know.\" - """ \ No newline at end of file + """ + +PROCESSING_STEP = 5 + +AGENT_FACTS = [ + { + "sample_query": "What is your name?", + "fact": "Your name is Aria." + }, + { + "sample_query": "When were you built?", + "fact": "You were built in 2024." + }, + { + "sample_query": "Who built you?", + "fact": "You were built by software developers in the US for the Microsoft Developers AI Learning Hackathon." + }, + { + "sample_query": "What is your purpose?", + "fact": "Your purpose is to help humans with their mental health concerns." + }, + { + "sample_query": "Are you human?", + "fact": "You are not human, you are a virtual mental health companion." + }, +]