Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…panion into dev
  • Loading branch information
dhrumilp12 committed Jun 14, 2024
2 parents 932385f + 7911555 commit d400ac1
Show file tree
Hide file tree
Showing 9 changed files with 356 additions and 193 deletions.
25 changes: 16 additions & 9 deletions server/agents/ai_agent.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
import re

from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain_community.vectorstores.azure_cosmos_db import (
AzureCosmosDBVectorSearch,
CosmosDBSimilarityType,
CosmosDBVectorSearchType
)
from langchain.agents import Tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.tools import StructuredTool
from langchain_core.messages import SystemMessage
from langchain_core.vectorstores import VectorStoreRetriever

from pymongo.database import Database

from services.azure_mongodb import MongoDBClient
from services.my_azure import get_azure_openai_variables, get_azure_openai_llm, get_azure_openai_embeddings

class AIAgent:
def __init__(self, system_message:str, schema:list[str]=[]):
self.db = (MongoDBClient.get_client())[MongoDBClient.get_db_name()]
self.db:Database = (MongoDBClient.get_client())[MongoDBClient.get_db_name()]
self.llm = get_azure_openai_llm()
self.embedding_model = get_azure_openai_embeddings()

self.system_message = SystemMessage(content=system_message)

self.agent_executor = create_conversational_retrieval_agent(
llm=self.llm,
tools=self.__create_agent_tools(schema),
system_message = self.system_message,
verbose=True
)
# self.agent_executor = create_conversational_retrieval_agent(
# llm=self.llm,
# tools=self.__create_agent_tools(schema),
# system_message = self.system_message,
# verbose=True
# )


def run(self, message:str) -> str:
Expand All @@ -43,10 +49,11 @@ def _get_cosmosdb_vector_store_retriever(self, collection_name, top_k=3) -> Vect
connection_string = CONNECTION_STRING,
namespace = f"{db_name}.{collection_name}",
embedding = AOAI_EMBEDDINGS,
index_name =f"{db_name}_{collection_name}_index",
index_name =f"VectorSearchIndex",
embedding_key = "contentVector", #TODO: Find out what these are for
text_key = "_id" #TODO: Find out what these are for
)
vector_store.create_index()
return vector_store.as_retriever(search_kwargs={"k": top_k})


Expand Down
Loading

0 comments on commit d400ac1

Please sign in to comment.