Skip to content

Commit

Permalink
Tool flexibility (#40)
Browse files Browse the repository at this point in the history
* init: tool flexibility code

* add: tool removal

* fix: single class tool variable

* fix: update tools

* add: docs and agent fixes

* add: new images

* fix: more logging

* fix: typing
  • Loading branch information
kevinl424 authored May 27, 2024
1 parent c78a307 commit d9906d8
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 68 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,40 @@ cd streamlit_app
streamlit run app.py
```

## Usage
```python
from memary.agent.chat_agent import ChatAgent

system_persona_txt = "data/system_persona.txt"
user_persona_txt = "data/user_persona.txt"
past_chat_json = "data/past_chat.json"
memory_stream_json = "data/memory_stream.json"
entity_knowledge_store_json = "data/entity_knowledge_store.json"
chat_agent = ChatAgent(
"Personal Agent",
memory_stream_json,
entity_knowledge_store_json,
system_persona_txt,
user_persona_txt,
past_chat_json,
)
```
Pass in subset of `['search', 'vision', 'locate', 'stocks']` as `include_from_defaults` for different set of default tools upon initialization.
### Adding Custom Tools
```python
def multiply(a: int, b: int) -> int:
"""Multiply two integers and returns the result integer"""
return a * b

chat_agent.add_tool({"multiply": multiply})
```
More information about creating custom tools for the LlamaIndex ReAct Agent can be found [here](https://docs.llamaindex.ai/en/stable/examples/agent/react_agent/).

### Removing Tools
```python
chat_agent.remove_tool("multiply")
```

## Detailed Component Breakdown

### Routing Agent
Expand Down
Binary file added diagrams/context_window.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added diagrams/memary_logo_bw.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
112 changes: 73 additions & 39 deletions src/memary/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
import os
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List

import geocoder
import googlemaps
import numpy as np
import requests
from ansistrip import ansi_strip
from dotenv import load_dotenv
from llama_index.core import (
KnowledgeGraphIndex,
Settings,
SimpleDirectoryReader,
StorageContext,
)
from llama_index.core import (KnowledgeGraphIndex, Settings,
SimpleDirectoryReader, StorageContext)
from llama_index.core.agent import ReActAgent
from llama_index.core.llms import ChatMessage
from llama_index.core.query_engine import RetrieverQueryEngine
Expand All @@ -28,10 +25,8 @@
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

from memary.agent.data_types import Context, Message
from memary.agent.llm_api.tools import (
ollama_chat_completions_request,
openai_chat_completions_request,
)
from memary.agent.llm_api.tools import (ollama_chat_completions_request,
openai_chat_completions_request)
from memary.memory import EntityKnowledgeStore, MemoryStream
from memary.synonym_expand.synonym import custom_synonym_expand_fn

Expand Down Expand Up @@ -65,6 +60,7 @@ def __init__(
past_chat_json,
llm_model_name="llama3",
vision_model_name="llava",
include_from_defaults=["search", "locate", "vision", "stocks"],
debug=True,
):
load_dotenv()
Expand Down Expand Up @@ -99,7 +95,6 @@ def __init__(
)

self.vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY")
# self.news_data_key = os.getenv("NEWS_DATA_API_KEY")

self.storage_context = StorageContext.from_defaults(
graph_store=self.graph_store
Expand All @@ -116,18 +111,9 @@ def __init__(
graph_rag_retriever,
)

search_tool = FunctionTool.from_defaults(fn=self.search)
locate_tool = FunctionTool.from_defaults(fn=self.locate)
vision_tool = FunctionTool.from_defaults(fn=self.vision)
stock_tool = FunctionTool.from_defaults(fn=self.stock_price)
# news_tool = FunctionTool.from_defaults(fn=self.get_news)

self.debug = debug
self.routing_agent = ReActAgent.from_tools(
[search_tool, locate_tool, vision_tool, stock_tool],
llm=self.llm,
verbose=True,
)
self.tools = {}
self._init_default_tools(default_tools=include_from_defaults)

self.memory_stream = MemoryStream(memory_stream_json)
self.entity_knowledge_store = EntityKnowledgeStore(entity_knowledge_store_json)
Expand Down Expand Up @@ -211,7 +197,7 @@ def vision(self, query: str, img_url: str) -> str:
os.remove(query_image_path) # delete image after use
return response

def stock_price(self, query: str) -> str:
def stocks(self, query: str) -> str:
"""Get the stock price of the company given the ticker"""
request_api = requests.get(
r"https://www.alphavantage.co/query?function=GLOBAL_QUOTE&symbol="
Expand Down Expand Up @@ -435,19 +421,67 @@ def get_entity(self, retrieve) -> list[str]:
entities.remove(exceptions)
return entities

def update_tools(self, updatedTools):
print("recieved update tools")
tools = []
for tool in updatedTools:
if tool == "Search":
tools.append(FunctionTool.from_defaults(fn=self.search))
elif tool == "Location":
tools.append(FunctionTool.from_defaults(fn=self.locate))
elif tool == "Vision":
tools.append(FunctionTool.from_defaults(fn=self.vision))
elif tool == "Stocks":
tools.append(FunctionTool.from_defaults(fn=self.stock_price))
# elif tool == "News":
# tools.append(FunctionTool.from_defaults(fn=self.get_news))

self.routing_agent = ReActAgent.from_tools(tools, llm=self.llm, verbose=True)
def _init_ReAct_agent(self):
"""Initializes ReAct Agent with list of tools in self.tools."""
tool_fns = []
for func in self.tools.values():
tool_fns.append(FunctionTool.from_defaults(fn=func))
self.routing_agent = ReActAgent.from_tools(tool_fns, llm=self.llm, verbose=True)

def _init_default_tools(self, default_tools: List[str]):
"""Initializes ReAct Agent from the default list of tools memary provides.
List of strings passed in during initialization denoting which default tools to include.
Args:
default_tools (list(str)): list of tool names in string form
"""

for tool in default_tools:
if tool == "search":
self.tools["search"] = self.search
elif tool == "locate":
self.tools["locate"] = self.locate
elif tool == "vision":
self.tools["vision"] = self.vision
elif tool == "stocks":
self.tools["stocks"] = self.stocks
self._init_ReAct_agent()

def add_tool(self, tool_additions: Dict[str, Callable[..., Any]]):
"""Adds specified tools to be used by the ReAct Agent.
Args:
tools (dict(str, func)): dictionary of tools with names as keys and associated functions as values
"""

for tool_name in tool_additions:
self.tools[tool_name] = tool_additions[tool_name]
self._init_ReAct_agent()

def remove_tool(self, tool_name: str):
"""Removes specified tool from list of available tools for use by the ReAct Agent.
Args:
tool_name (str): name of tool to be removed in string form
"""

if tool_name in self.tools:
del self.tools[tool_name]
self._init_ReAct_agent()
else:
raise ("Unknown tool_name provided for removal.")

def update_tools(self, updated_tools: List[str]):
"""Resets ReAct Agent tools to only include subset of default tools.
Args:
updated_tools (list(str)): list of default tools to include
"""

self.tools.clear()
for tool in updated_tools:
if tool == "search":
self.tools["search"] = self.search
elif tool == "locate":
self.tools["locate"] = self.locate
elif tool == "vision":
self.tools["vision"] = self.vision
elif tool == "stocks":
self.tools["stocks"] = self.stocks
self._init_ReAct_agent()
59 changes: 37 additions & 22 deletions src/memary/agent/chat_agent.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
from typing import Optional, List

from memary.agent.base_agent import Agent
import logging


class ChatAgent(Agent):
"""ChatAgent currently able to support Llama3 running on Ollama (default) and gpt-3.5-turbo for llm models,
and LLaVA running on Ollama (default) and gpt-4-vision-preview for the vision tool.
"""
def __init__(self, name, memory_stream_json, entity_knowledge_store_json,
system_persona_txt, user_persona_txt, past_chat_json, llm_model_name="llama3", vision_model_name="llava"):
super().__init__(name, memory_stream_json, entity_knowledge_store_json,
system_persona_txt, user_persona_txt, past_chat_json, llm_model_name, vision_model_name)


def add_chat(self,
role: str,
content: str,
entities: Optional[List[str]] = None):
def __init__(
self,
name,
memory_stream_json,
entity_knowledge_store_json,
system_persona_txt,
user_persona_txt,
past_chat_json,
llm_model_name="llama3",
vision_model_name="llava",
include_from_defaults=["search", "locate", "vision", "stocks"],
):
super().__init__(
name,
memory_stream_json,
entity_knowledge_store_json,
system_persona_txt,
user_persona_txt,
past_chat_json,
llm_model_name,
vision_model_name,
include_from_defaults,
)

def add_chat(self, role: str, content: str, entities: Optional[List[str]] = None):
"""Add a chat to the agent's memory.
Args:
Expand All @@ -30,8 +47,7 @@ def add_chat(self,
if entities:
self.memory_stream.add_memory(entities)
self.memory_stream.save_memory()
self.entity_knowledge_store.add_memory(
self.memory_stream.get_memory())
self.entity_knowledge_store.add_memory(self.memory_stream.get_memory())
self.entity_knowledge_store.save_memory()

self._replace_memory_from_llm_message()
Expand All @@ -41,28 +57,27 @@ def get_chat(self):
return self.contexts

def clearMemory(self):
"""Clears Neo4j database and memory stream/entity knowledge store."""

logging.info("Deleting memory stream and entity knowledge store...")
self.memory_stream.clear_memory()
self.entity_knowledge_store.clear_memory()

# print("removed from mem stream and entity knowdlege store ")
"clears knowledge neo4j database"

print("Deleting nodes from Neo4j...")
logging.info("Deleting nodes from Neo4j...")
try:
self.graph_store.query("MATCH (n) DETACH DELETE n")
except Exception as e:
print(f"Error deleting nodes: {e}")
print("Nodes deleted from Neo4j.")
logging.error(f"Error deleting nodes: {e}")
logging.info("Nodes deleted from Neo4j.")

def _replace_memory_from_llm_message(self):
"""Replace the memory_stream from the llm_message."""
self.message.llm_message[
"memory_stream"] = self.memory_stream.get_memory()
self.message.llm_message["memory_stream"] = self.memory_stream.get_memory()

def _replace_eks_to_from_message(self):
"""Replace the entity knowledge store from the llm_message.
eks = entity knowledge store"""

self.message.llm_message[
"knowledge_entity_store"] = self.entity_knowledge_store.get_memory(
)
self.message.llm_message["knowledge_entity_store"] = (
self.entity_knowledge_store.get_memory()
)
12 changes: 5 additions & 7 deletions streamlit_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,12 @@ def get_models(llm_models, vision_models):

tools = st.multiselect(
"Select tools to include:",
# ["Search", "Location", "Vision", "Stocks", "News"], #all options available
# ["Search", "Location", "Vision", "Stocks", "News"],) #options that are selected by default
["Search", "Location", "Vision", "Stocks"], # all options available
["Search", "Location", "Vision", "Stocks"],
) # options that are selected by default
["search", "locate", "vision", "stocks"], # all options available
["search", "locate", "vision", "stocks"], # options that are selected by default
)

if "Vision" in tools:
img_url = ""
if "vision" in tools:
img_url = st.text_input("URL of image, leave blank if no image to provide")
if img_url:
st.image(img_url, caption="Uploaded Image", use_column_width=True)
Expand All @@ -167,7 +166,6 @@ def get_models(llm_models, vision_models):
st.write("Please select at least one tool")
st.stop()

print("start update tools")
chat_agent.update_tools(tools)

if img_url:
Expand Down

0 comments on commit d9906d8

Please sign in to comment.