diff --git a/README.md b/README.md index dd4d7aba..afa4e642 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,40 @@ cd streamlit_app streamlit run app.py ``` +## Usage +```python +from memary.agent.chat_agent import ChatAgent + +system_persona_txt = "data/system_persona.txt" +user_persona_txt = "data/user_persona.txt" +past_chat_json = "data/past_chat.json" +memory_stream_json = "data/memory_stream.json" +entity_knowledge_store_json = "data/entity_knowledge_store.json" +chat_agent = ChatAgent( + "Personal Agent", + memory_stream_json, + entity_knowledge_store_json, + system_persona_txt, + user_persona_txt, + past_chat_json, +) +``` +Pass in a subset of `['search', 'vision', 'locate', 'stocks']` as `include_from_defaults` for a different set of default tools upon initialization. +### Adding Custom Tools +```python +def multiply(a: int, b: int) -> int: + """Multiply two integers and return the resulting integer""" + return a * b + +chat_agent.add_tool({"multiply": multiply}) +``` +More information about creating custom tools for the LlamaIndex ReAct Agent can be found [here](https://docs.llamaindex.ai/en/stable/examples/agent/react_agent/).
+ +### Removing Tools +```python +chat_agent.remove_tool("multiply") +``` + ## Detailed Component Breakdown ### Routing Agent diff --git a/diagrams/context_window.png b/diagrams/context_window.png new file mode 100644 index 00000000..074baeb4 Binary files /dev/null and b/diagrams/context_window.png differ diff --git a/diagrams/memary_logo_bw.png b/diagrams/memary_logo_bw.png new file mode 100644 index 00000000..aeaf5e1d Binary files /dev/null and b/diagrams/memary_logo_bw.png differ diff --git a/src/memary/agent/base_agent.py b/src/memary/agent/base_agent.py index 1ee2d271..a63a12fa 100644 --- a/src/memary/agent/base_agent.py +++ b/src/memary/agent/base_agent.py @@ -2,6 +2,7 @@ import os import sys from pathlib import Path +from typing import Any, Callable, Dict, List import geocoder import googlemaps @@ -9,12 +10,8 @@ import requests from ansistrip import ansi_strip from dotenv import load_dotenv -from llama_index.core import ( - KnowledgeGraphIndex, - Settings, - SimpleDirectoryReader, - StorageContext, -) +from llama_index.core import (KnowledgeGraphIndex, Settings, + SimpleDirectoryReader, StorageContext) from llama_index.core.agent import ReActAgent from llama_index.core.llms import ChatMessage from llama_index.core.query_engine import RetrieverQueryEngine @@ -28,10 +25,8 @@ from llama_index.multi_modal_llms.openai import OpenAIMultiModal from memary.agent.data_types import Context, Message -from memary.agent.llm_api.tools import ( - ollama_chat_completions_request, - openai_chat_completions_request, -) +from memary.agent.llm_api.tools import (ollama_chat_completions_request, + openai_chat_completions_request) from memary.memory import EntityKnowledgeStore, MemoryStream from memary.synonym_expand.synonym import custom_synonym_expand_fn @@ -65,6 +60,7 @@ def __init__( past_chat_json, llm_model_name="llama3", vision_model_name="llava", + include_from_defaults=["search", "locate", "vision", "stocks"], debug=True, ): load_dotenv() @@ -99,7 +95,6 @@ def __init__( ) 
self.vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY") - # self.news_data_key = os.getenv("NEWS_DATA_API_KEY") self.storage_context = StorageContext.from_defaults( graph_store=self.graph_store @@ -116,18 +111,9 @@ def __init__( graph_rag_retriever, ) - search_tool = FunctionTool.from_defaults(fn=self.search) - locate_tool = FunctionTool.from_defaults(fn=self.locate) - vision_tool = FunctionTool.from_defaults(fn=self.vision) - stock_tool = FunctionTool.from_defaults(fn=self.stock_price) - # news_tool = FunctionTool.from_defaults(fn=self.get_news) - self.debug = debug - self.routing_agent = ReActAgent.from_tools( - [search_tool, locate_tool, vision_tool, stock_tool], - llm=self.llm, - verbose=True, - ) + self.tools = {} + self._init_default_tools(default_tools=include_from_defaults) self.memory_stream = MemoryStream(memory_stream_json) self.entity_knowledge_store = EntityKnowledgeStore(entity_knowledge_store_json) @@ -211,7 +197,7 @@ def vision(self, query: str, img_url: str) -> str: os.remove(query_image_path) # delete image after use return response - def stock_price(self, query: str) -> str: + def stocks(self, query: str) -> str: """Get the stock price of the company given the ticker""" request_api = requests.get( r"https://www.alphavantage.co/query?function=GLOBAL_QUOTE&symbol=" @@ -435,19 +421,67 @@ def get_entity(self, retrieve) -> list[str]: entities.remove(exceptions) return entities - def update_tools(self, updatedTools): - print("recieved update tools") - tools = [] - for tool in updatedTools: - if tool == "Search": - tools.append(FunctionTool.from_defaults(fn=self.search)) - elif tool == "Location": - tools.append(FunctionTool.from_defaults(fn=self.locate)) - elif tool == "Vision": - tools.append(FunctionTool.from_defaults(fn=self.vision)) - elif tool == "Stocks": - tools.append(FunctionTool.from_defaults(fn=self.stock_price)) - # elif tool == "News": - # tools.append(FunctionTool.from_defaults(fn=self.get_news)) - - self.routing_agent = 
ReActAgent.from_tools(tools, llm=self.llm, verbose=True) + def _init_ReAct_agent(self): + """Initializes ReAct Agent with list of tools in self.tools.""" + tool_fns = [] + for func in self.tools.values(): + tool_fns.append(FunctionTool.from_defaults(fn=func)) + self.routing_agent = ReActAgent.from_tools(tool_fns, llm=self.llm, verbose=True) + + def _init_default_tools(self, default_tools: List[str]): + """Initializes ReAct Agent from the default list of tools memary provides. + List of strings passed in during initialization denoting which default tools to include. + Args: + default_tools (list(str)): list of tool names in string form + """ + + for tool in default_tools: + if tool == "search": + self.tools["search"] = self.search + elif tool == "locate": + self.tools["locate"] = self.locate + elif tool == "vision": + self.tools["vision"] = self.vision + elif tool == "stocks": + self.tools["stocks"] = self.stocks + self._init_ReAct_agent() + + def add_tool(self, tool_additions: Dict[str, Callable[..., Any]]): + """Adds specified tools to be used by the ReAct Agent. + Args: + tool_additions (dict(str, func)): dictionary of tools with names as keys and associated functions as values + """ + + for tool_name in tool_additions: + self.tools[tool_name] = tool_additions[tool_name] + self._init_ReAct_agent() + + def remove_tool(self, tool_name: str): + """Removes specified tool from list of available tools for use by the ReAct Agent. + Args: + tool_name (str): name of tool to be removed in string form + """ + + if tool_name in self.tools: + del self.tools[tool_name] + self._init_ReAct_agent() + else: + raise ValueError("Unknown tool_name provided for removal.") + + def update_tools(self, updated_tools: List[str]): + """Resets ReAct Agent tools to only include subset of default tools.
+ Args: + updated_tools (list(str)): list of default tools to include + """ + + self.tools.clear() + for tool in updated_tools: + if tool == "search": + self.tools["search"] = self.search + elif tool == "locate": + self.tools["locate"] = self.locate + elif tool == "vision": + self.tools["vision"] = self.vision + elif tool == "stocks": + self.tools["stocks"] = self.stocks + self._init_ReAct_agent() diff --git a/src/memary/agent/chat_agent.py b/src/memary/agent/chat_agent.py index d8d992da..814068a6 100644 --- a/src/memary/agent/chat_agent.py +++ b/src/memary/agent/chat_agent.py @@ -1,22 +1,39 @@ from typing import Optional, List from memary.agent.base_agent import Agent +import logging class ChatAgent(Agent): """ChatAgent currently able to support Llama3 running on Ollama (default) and gpt-3.5-turbo for llm models, and LLaVA running on Ollama (default) and gpt-4-vision-preview for the vision tool. """ - def __init__(self, name, memory_stream_json, entity_knowledge_store_json, - system_persona_txt, user_persona_txt, past_chat_json, llm_model_name="llama3", vision_model_name="llava"): - super().__init__(name, memory_stream_json, entity_knowledge_store_json, - system_persona_txt, user_persona_txt, past_chat_json, llm_model_name, vision_model_name) - - def add_chat(self, - role: str, - content: str, - entities: Optional[List[str]] = None): + def __init__( + self, + name, + memory_stream_json, + entity_knowledge_store_json, + system_persona_txt, + user_persona_txt, + past_chat_json, + llm_model_name="llama3", + vision_model_name="llava", + include_from_defaults=["search", "locate", "vision", "stocks"], + ): + super().__init__( + name, + memory_stream_json, + entity_knowledge_store_json, + system_persona_txt, + user_persona_txt, + past_chat_json, + llm_model_name, + vision_model_name, + include_from_defaults, + ) + + def add_chat(self, role: str, content: str, entities: Optional[List[str]] = None): """Add a chat to the agent's memory. 
Args: @@ -30,8 +47,7 @@ def add_chat(self, if entities: self.memory_stream.add_memory(entities) self.memory_stream.save_memory() - self.entity_knowledge_store.add_memory( - self.memory_stream.get_memory()) + self.entity_knowledge_store.add_memory(self.memory_stream.get_memory()) self.entity_knowledge_store.save_memory() self._replace_memory_from_llm_message() @@ -41,28 +57,27 @@ def get_chat(self): return self.contexts def clearMemory(self): + """Clears Neo4j database and memory stream/entity knowledge store.""" + + logging.info("Deleting memory stream and entity knowledge store...") self.memory_stream.clear_memory() self.entity_knowledge_store.clear_memory() - - # print("removed from mem stream and entity knowdlege store ") - "clears knowledge neo4j database" - print("Deleting nodes from Neo4j...") + logging.info("Deleting nodes from Neo4j...") try: self.graph_store.query("MATCH (n) DETACH DELETE n") except Exception as e: - print(f"Error deleting nodes: {e}") - print("Nodes deleted from Neo4j.") + logging.error(f"Error deleting nodes: {e}") + logging.info("Nodes deleted from Neo4j.") def _replace_memory_from_llm_message(self): """Replace the memory_stream from the llm_message.""" - self.message.llm_message[ - "memory_stream"] = self.memory_stream.get_memory() + self.message.llm_message["memory_stream"] = self.memory_stream.get_memory() def _replace_eks_to_from_message(self): """Replace the entity knowledge store from the llm_message. 
eks = entity knowledge store""" - self.message.llm_message[ - "knowledge_entity_store"] = self.entity_knowledge_store.get_memory( - ) + self.message.llm_message["knowledge_entity_store"] = ( + self.entity_knowledge_store.get_memory() + ) diff --git a/streamlit_app/app.py b/streamlit_app/app.py index 9f2cb004..e08d2ea1 100644 --- a/streamlit_app/app.py +++ b/streamlit_app/app.py @@ -136,13 +136,12 @@ def get_models(llm_models, vision_models): tools = st.multiselect( "Select tools to include:", - # ["Search", "Location", "Vision", "Stocks", "News"], #all options available - # ["Search", "Location", "Vision", "Stocks", "News"],) #options that are selected by default - ["Search", "Location", "Vision", "Stocks"], # all options available - ["Search", "Location", "Vision", "Stocks"], - ) # options that are selected by default + ["search", "locate", "vision", "stocks"], # all options available + ["search", "locate", "vision", "stocks"], # options that are selected by default + ) - if "Vision" in tools: + img_url = "" + if "vision" in tools: img_url = st.text_input("URL of image, leave blank if no image to provide") if img_url: st.image(img_url, caption="Uploaded Image", use_column_width=True) @@ -167,7 +166,6 @@ def get_models(llm_models, vision_models): st.write("Please select at least one tool") st.stop() - print("start update tools") chat_agent.update_tools(tools) if img_url: