Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/CogitoNTNU/jarvis
Browse files Browse the repository at this point in the history
  • Loading branch information
JonBergland committed Oct 22, 2024
2 parents d1e2ae8 + c90eb9f commit 077e97d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 13 deletions.
58 changes: 46 additions & 12 deletions core/graphAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,66 @@
import json
from config import OPENAI_API_KEY
from Agents.simpleagent import SimpleAgent
#from agent import Agent, Agent1
from graphtools import graphtool
import asyncio
from time import sleep

import functools

class Graph:
def __init__(self):
LANGCHAIN_TRACING_V2: str = "true"

self.llm = SimpleAgent.llm

self.llm_with_tools = self.llm.bind_tools(get_tools())
llm = SimpleAgent.llm
proof_reader = graphtool.create_agent(
llm,
get_tools(),
system_message="You should proof read the text before you send it to the user.",
)
proof_read_node = functools.partial(graphtool.agent_node, agent=proof_reader, name="proof_reader")
simple_agent = graphtool.create_agent(
llm,
get_tools(),
system_message="You should take the input of the user and use the tools available to you to generate a response.",
)
simple_agent_node = functools.partial(graphtool.agent_node, agent=simple_agent, name="simple_agent")

tool_node = ToolNode(get_tools())
self.workflow = StateGraph(GraphState)
# Adding nodes to the workflow
self.workflow.add_node("chatbot", self.chatbot)
self.workflow.add_node("tools", ToolNode(get_tools()))
self.workflow.add_node("simple_agent", simple_agent_node)
self.workflow.add_node("proof_reader", proof_read_node)
self.workflow.add_node("call_tool", tool_node)
# TODO: Visualize these tools

# Defining edges between nodes
self.workflow.add_edge(START, "chatbot")
self.workflow.add_edge("tools", "chatbot")
self.workflow.add_edge("chatbot", END)
self.workflow.add_conditional_edges(
"simple_agent",
graphtool.router,
{"continue": "simple_agent", "call_tool": "call_tool", END: END},
)
self.workflow.add_conditional_edges(
"proof_reader",
graphtool.router,
{"continue": "proof_reader", "call_tool": "call_tool", END: END},
)

self.workflow.add_conditional_edges(
"call_tool",
# Each agent node updates the 'sender' field
# the tool calling node does not, meaning
# this edge will route back to the original agent
# who invoked the tool
lambda x: x["sender"],
{
"simple_agent": "simple_agent",
"proof_reader": "proof_reader",
},
)
self.workflow.add_edge(START, "simple_agent")
self.workflow.add_edge("proof_reader", END)

# Defining conditional edges
self.workflow.add_conditional_edges(
"chatbot",
"simple_agent",
tools_condition
)
self.graph = self.workflow.compile()
Expand Down
26 changes: 25 additions & 1 deletion core/graphtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ToolMessage,
)
import operator
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
Expand Down Expand Up @@ -49,4 +50,27 @@ def agent_node(state, agent, name):
# Since we have a strict workflow, we can
# track the sender so we know who to pass to next.
"sender": name,
}
}


def create_agent(llm, tools, system_message: str):
    """Build a tool-calling agent runnable for the collaboration graph.

    Composes a chat prompt — a shared multi-assistant preamble plus the
    agent-specific ``system_message`` — and pipes it into *llm* with
    *tools* bound, yielding a runnable usable as a graph node.

    Args:
        llm: Chat model exposing ``bind_tools``.
        tools: Tool objects; each must expose a ``name`` attribute, which
            is interpolated into the prompt as ``{tool_names}``.
        system_message: Agent-specific instructions appended after the
            shared preamble via the ``{system_message}`` placeholder.

    Returns:
        The composed runnable ``prompt | llm.bind_tools(tools)``.
    """
    preamble = (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " If you are unable to fully answer, that's OK, another assistant with different tools "
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
        " You have access to the following tools: {tool_names}.\n{system_message}"
    )
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", preamble),
            # Conversation history is injected here at invocation time.
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    # Pre-fill the static template variables in a single partial() call.
    chat_prompt = chat_prompt.partial(
        system_message=system_message,
        tool_names=", ".join(tool.name for tool in tools),
    )
    return chat_prompt | llm.bind_tools(tools)

0 comments on commit 077e97d

Please sign in to comment.