Commit 836d68c

feat: added in-memory checkpointing. (J remembers things in-chat)

WilliamMRS committed Nov 11, 2024
1 parent ed7210f commit 836d68c
Showing 7 changed files with 61 additions and 28 deletions.
9 changes: 6 additions & 3 deletions .env.example
@@ -1,3 +1,6 @@
OPENAI_API_KEY="your_api_key"
LANGSMITH_API_KEY="your_langsmith_api_key" #Find it here: https://smith.langchain.com
PORT="3000"
OPENAI_API_KEY=your_api_key
LANGSMITH_API_KEY=your_langsmith_api_key #Find it here: https://smith.langchain.com
PORT=3000
#FLASK_ENV=development #Optional if you want docker to reload flask when you save your code.
#LANGSMITH_API_KEY=your_api_key #optional. Let's you debug using langsmith
#LANGCHAIN_PROJECT=your_project_name #pops up in langsmith dashboard
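For orientation, a minimal sketch of how these values are consumed at runtime, assuming python-dotenv (which graphAgent.py already uses); the variable names match this file:

    from dotenv import load_dotenv
    import os

    load_dotenv()  # reads .env from the working directory

    openai_key = os.getenv("OPENAI_API_KEY")
    port = int(os.getenv("PORT", "3000"))  # values arrive as strings; cast as needed

Dropping the quotes is likely about docker compose, which in some versions passes .env values into the container verbatim, quotes included.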
40 changes: 35 additions & 5 deletions core/Agents/neoagent.py
@@ -1,15 +1,45 @@
from typing import Annotated

from typing_extensions import TypedDict

-from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
+from langchain_openai import ChatOpenAI
+from langchain_core.tools import tool
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import MessagesState, StateGraph, START, END
+from langgraph.prebuilt import ToolNode
+
+from models import Model # Models for ChatGPT
+
+"""
+Neoagent uses the ReAct agent framework.
+Simply put, in steps:
+1. 'Re': the agent reasons about the problem and plans out the steps to solve it.
+2. 'Act': the agent acts on the information gathered, calling tools or interacting with systems based on the earlier reasoning.
+3. 'Loop': if the problem is not adequately solved, the agent can reason and act again, iterating until a satisfying solution is reached.
+ReAct is a simple multi-step agent architecture.
+Smaller graphs are often better understood by LLMs.
+"""
+
+memory = MemorySaver()
+
+@tool
+def search(query: str):
+    """Call to surf the web."""
+    # This is a placeholder for the actual implementation
+    # Don't let the LLM know this though 😊
+    return "It's sunny in San Francisco, but you better look out if you're a Gemini 😈."
+
+tools = [search]
+tool_node = ToolNode(tools)
+model = ChatOpenAI(
+    model=Model.gpt_4o,
+    temperature=0,
+    max_tokens=16384, # Max output tokens for gpt-4o-mini; gpt-4o's context window is 128k
+) # Using ChatGPT hardcoded (TODO: Make this dynamic)
+bound_model = model.bind_tools(tools)

class State(TypedDict):
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list, rather than overwriting them)
    messages: Annotated[list, add_messages]


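As an aside, the ReAct loop described in the docstring above can be assembled from exactly these pieces (model, tools, memory) using LangGraph's prebuilt helper. A minimal sketch under that assumption; this is an illustration, not necessarily how neoagent.py wires its graph in the elided code below:

    from langgraph.prebuilt import create_react_agent

    # Reason -> act -> loop agent with in-memory checkpointing.
    agent = create_react_agent(model, tools, checkpointer=memory)

    # The thread_id scopes the saved conversation; reusing it resumes the same chat.
    config = {"configurable": {"thread_id": "demo"}}
    result = agent.invoke({"messages": [("human", "What's the weather in SF?")]}, config)
    print(result["messages"][-1].content)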
8 changes: 1 addition & 7 deletions core/config.py
@@ -8,13 +8,7 @@

#add langsmith api to env as LANGSMITH_API_KEY = "your_api_key" on EU server
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY", "no_key")

-os.environ["LANGCHAIN_TRACING_V2"] = "true"
-os.environ["LANGCHAIN_ENDPOINT"] = "https://eu.api.smith.langchain.com"
-try:
-    os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API_KEY
-except:
-    print("No langsmith key found")
-print(LANGSMITH_API_KEY)

if __name__ == "__main__":
    print(f"[INFO] OPENAI_API_KEY: {OPENAI_API_KEY}")
27 changes: 16 additions & 11 deletions core/graphAgent.py
@@ -16,12 +16,14 @@
#from graphtools import graphtool
#import asyncio
#import functools
+from langgraph.checkpoint.memory import MemorySaver
+memory = MemorySaver() # Used to save state using checkpointing. See 'config' and the astream execution further down.

from dotenv import load_dotenv
load_dotenv(dotenv_path='../.env', override=True)

class Graph:
    def __init__(self):
        LANGCHAIN_TRACING_V2: str = "true"

        self.workflow = StateGraph(GraphState)

        self.workflow.add_node("jarvis_agent", jarvis_agent)
@@ -63,7 +65,7 @@ def __init__(self):
            {"use_calendar_tool": "use_calendar_tool", "return_to_jarvis": "jarvis_agent"}
        )

-        self.graph = self.workflow.compile()
+        self.graph = self.workflow.compile(checkpointer=memory) #Compiles the graph using memory checkpointer

        with open("graph_node_network.png", 'wb') as f:
            f.write(self.graph.get_graph().draw_mermaid_png())
@@ -127,34 +129,37 @@ async def run(self, user_prompt: str, socketio):
                """)
            ] + chat_history + [("human", user_prompt)]}
            socketio.emit("start_message", " ")
-            async for event in self.graph.astream_events(input, version='v2'):
+            config = {"configurable": {"thread_id": "1"}} # Thread here is hardcoded for now.
+            async for event in self.graph.astream_events(input, config, version='v2'): # The config uses the memory checkpoint to save chat state. Only in-memory, not persistent yet.
                event_type = event.get('event')

                # Focuses only on the 'on_chain_stream'-events.
                # There may be better events to base the response on
                if event_type == 'on_chain_end' and event['name'] == 'LangGraph':
                    ai_message = event['data']['output']['messages'][-1]

                    if isinstance(ai_message, AIMessage):
                        print(ai_message)
-                        if 'tool_calls' in ai_message.additional_kwargs:
-                            tool_call = ai_message.additional_kwargs['tool_calls'][0]['function']
-                            #tool_call_id = ai_message.additional_kwargs['call_tool'][0]['tool_call_id']
-                            socketio.emit("tool_call", tool_call)
-                            continue
+                        try:
+                            tool_call = ai_message.additional_kwargs['tool_calls'][0]['function']
+                            #tool_call_id = ai_message.additional_kwargs['call_tool'][0]['tool_call_id']
+                            socketio.emit("tool_call", tool_call)
+                            continue
+                        except Exception as e:
+                            return e

                        socketio.emit("chunk", ai_message.content)
                        socketio.emit("tokens", ai_message.usage_metadata['total_tokens'])
                        continue

                if event_type == 'on_chain_stream' and event['name'] == 'tools':
                    tool_response = event['data']['chunk']['messages'][-1]

                    if isinstance(tool_response, ToolMessage):
                        socketio.emit("tool_response", tool_response.content)
                        continue

            return "success"
        except Exception as e:
            print(e)
-            return "error"
+            return e
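The net effect of compile(checkpointer=memory) plus the thread_id config: every run replays the saved messages for that thread before appending new ones. A sketch of the behavior, assuming a graph compiled as above and held here as graph:

    config = {"configurable": {"thread_id": "1"}}

    graph.invoke({"messages": [("human", "My name is Will.")]}, config)
    result = graph.invoke({"messages": [("human", "What's my name?")]}, config)
    # The model can answer "Will": the MemorySaver restored thread 1's history.

    other = {"configurable": {"thread_id": "2"}}
    graph.invoke({"messages": [("human", "What's my name?")]}, other)
    # A fresh thread has no memory of Will. State lives only in process memory,
    # so a restart clears every thread: in-memory, not persistent yet.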

3 changes: 2 additions & 1 deletion core/main.py
@@ -11,7 +11,7 @@
from modules.chat import read_chat
import logging
log = logging.getLogger('werkzeug')
-log.setLevel(logging.ERROR)
+log.setLevel(logging.ERROR) #INFO, DEBUG, WARNING, ERROR, or CRITICAL - configure as needed during development.
from collections import defaultdict

#
@@ -117,6 +117,7 @@ def handle_prompt(data):
    # Run the AI response
    async def run_and_store():
        response = await jarvis.run(data['prompt'], socketio)
+        ### TODO: Replace this with GraphState for chat history.
        # Update the AI response in the chat entry
        chat_entry["ai_message"] = response
        # Add completed chat entry to history
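On the TODO above: once chat history lives in the graph's checkpointer, this manual chat_entry bookkeeping could be replaced by reading the thread's state back. A sketch, assuming the compiled graph is reachable as jarvis.graph (hypothetical attribute name):

    config = {"configurable": {"thread_id": "1"}}
    snapshot = jarvis.graph.get_state(config)  # jarvis.graph: hypothetical handle to the compiled graph
    messages = snapshot.values["messages"]     # the thread's full in-memory chat history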
Binary file modified core/requirements.txt
Binary file not shown.
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -6,7 +6,7 @@ services:
      FLASK_ENV: ${FLASK_ENV} # Auto-restarts flask when code changes are detected
      OPENAI_API_KEY: ${OPENAI_API_KEY}
      LANGSMITH_API_KEY: ${LANGSMITH_API_KEY}
-      LANGCHAIN_TRACING_V2: true
+      LANGCHAIN_TRACING_V2: "true"
      LANGCHAIN_ENDPOINT: "https://api.smith.langchain.com"
      LANGCHAIN_PROJECT: ${LANGCHAIN_PROJECT}
      PERPLEXITY_API_KEY: ${PERPLEXITY_API_KEY}
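Why the quotes: YAML parses a bare true as a boolean, and docker compose expects string values in the environment mapping, so the unquoted form is rejected. Quoting pins the exact string "true" that LangChain's tracing check looks for. (Note the opposite fix in .env.example above, where quotes were removed; each format has its own quoting rules.)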
