Skip to content

Commit

Permalink
refactor: optimize graphAgent
Browse files Browse the repository at this point in the history
  • Loading branch information
JonBergland committed Oct 24, 2024
1 parent 81f1685 commit d9e6a1d
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions core/graphAgent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
from langchain_openai import ChatOpenAI
from graphstate import GraphState
from tools.tools import get_tools
Expand All @@ -14,6 +15,7 @@
import functools

class Graph:
MAIN_AGENT = "chatbot"
def __init__(self):
LANGCHAIN_TRACING_V2: str = "true"

Expand All @@ -22,17 +24,17 @@ def __init__(self):

self.workflow = StateGraph(GraphState)

self.workflow.add_node("chatbot", self.chatbot)
self.workflow.add_node(self.MAIN_AGENT, self.chatbot)
self.workflow.add_node("tools", ToolNode(get_tools()))

self.workflow.add_edge(START, "chatbot")
self.workflow.add_edge("tools", "chatbot")
self.workflow.add_edge("chatbot", END)
self.workflow.add_edge(START, self.MAIN_AGENT)
self.workflow.add_edge("tools", self.MAIN_AGENT)

# Defining conditional edges
self.workflow.add_conditional_edges(
"chatbot",
tools_condition
self.MAIN_AGENT,
tools_condition,
{"tools": "tools", "__end__": END}
)
self.graph = self.workflow.compile()

Expand Down Expand Up @@ -73,8 +75,10 @@ async def run(self, user_prompt: str, socketio):
# There may be better events to base the response on
if event_type == 'on_chain_stream' and event['name'] == 'LangGraph':
chunk = event['data']['chunk']
if 'chatbot' in chunk:
ai_message = event['data']['chunk']['chatbot']['messages'][-1]

# Filters the stream to only get events by main agent
if self.MAIN_AGENT in chunk:
ai_message = event['data']['chunk'][self.MAIN_AGENT]['messages'][-1]

if isinstance(ai_message, AIMessage):
if 'tool_calls' in ai_message.additional_kwargs:
Expand Down

0 comments on commit d9e6a1d

Please sign in to comment.