
Commit

feat: improve the LLM agent's logic (#26)
- Move to tool calls
- Enable parallel tool execution
- Improve should continue logic

Signed-off-by: Calum Murray <[email protected]>
Cali0707 authored Jul 18, 2024
1 parent 79c42f7 commit c6186fa
Showing 1 changed file with 57 additions and 34 deletions.
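
For background on the "move to tool calls" bullet: under the OpenAI tool-calling convention, an assistant message carries a tool_calls list in additional_kwargs instead of a single function_call entry, which is what makes requesting several tools in one turn possible. A minimal sketch of the payload shape the updated code consumes (the ids and tool names below are illustrative, not taken from this commit):

import json

# Illustrative OpenAI-style payload: one assistant turn requesting two tools
# in parallel. The ids and tool names are made up for this example.
additional_kwargs = {
    "tool_calls": [
        {
            "id": "call_abc123",  # echoed back later as ToolMessage.tool_call_id
            "type": "function",
            "function": {
                "name": "create_cloudevent",
                # arguments arrive as a JSON-encoded string, hence json.loads
                "arguments": '{"type": "dev.example.event"}',
            },
        },
        {
            "id": "call_def456",
            "type": "function",
            "function": {"name": "human_input", "arguments": "{}"},
        },
    ]
}

for call in additional_kwargs["tool_calls"]:
    print(call["function"]["name"], json.loads(call["function"]["arguments"]))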
core/chat-app/chat.py (57 additions, 34 deletions)
@@ -1,13 +1,14 @@
+import asyncio
 from typing import Sequence, TypedDict, Annotated
 import operator
 import json
 from dotenv import load_dotenv
 
 from langchain_openai import ChatOpenAI
 from langchain.memory import ConversationBufferMemory
-from langchain.tools.render import format_tool_to_openai_function
-from langchain_core.messages import BaseMessage, HumanMessage
-from langchain_core.messages import FunctionMessage
+from langchain.tools.render import format_tool_to_openai_tool
+from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.globals import set_debug
 
 from langgraph.prebuilt import ToolExecutor, ToolInvocation
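
The import changes above track LangChain's shift from the legacy OpenAI functions API (format_tool_to_openai_function / FunctionMessage) to the tool-calling API (format_tool_to_openai_tool / ToolMessage). The practical difference is that a ToolMessage must carry the tool_call_id of the call it answers, so the model can pair results with the parallel calls it made. A minimal sketch, with illustrative values:

from langchain_core.messages import ToolMessage

# One result message per tool call; tool_call_id ties the result back to the
# specific entry in the assistant's tool_calls list. Values are illustrative.
result = ToolMessage(
    content="created event with id 1234",
    name="create_cloudevent",
    tool_call_id="call_abc123",
)
print(result)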
@@ -32,80 +33,101 @@ class AgentState(TypedDict):
 async def on_chat_start():
     chat_history = ConversationBufferMemory(return_messages=True)
     cl.user_session.set("chat_history", chat_history)
 
     model = ChatOpenAI(temperature=0.1, streaming=True, max_retries=5, timeout=60.)
 
     tools = [HumanInput()]
     tools.extend(create_cloudevents_tools())
 
     tool_executor = ToolExecutor(tools)
 
-    functions = [format_tool_to_openai_function(t) for t in tools]
+    tools = [format_tool_to_openai_tool(t) for t in tools]
 
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                "You are a helpful AI assistant. "
+                "Use the provided tools to progress towards answering the question. "
+                "If you do not have the final answer yet, prefix your response with CONTINUE. "
+                "You have access to the following tools: {tool_names}.",
+            ),
+            MessagesPlaceholder(variable_name="messages"),
+        ]
+    )
+    prompt = prompt.partial(tool_names=", ".join([tool["function"]["name"] for tool in tools]))
 
-    model = model.bind_functions(functions)
+    model = prompt | model.bind_tools(tools)
 
     def should_continue(state: AgentState) -> str:
         messages = state["messages"]
         last_message = messages[-1]
-        if "function_call" not in last_message.additional_kwargs:
+        if "tool_calls" in last_message.additional_kwargs:
+            return "tool_call"
+        elif "CONTINUE" not in last_message.content:
             return "end"
         else:
-            return "continue"
+            return "model_call"
 
     async def call_model(state: AgentState):
         print("calling model...")
-        messages = state["messages"]
+        print(state)
-        response = await model.ainvoke(messages)
+        response = await model.ainvoke(state)
         return {"messages": [response]}
 
     async def call_tool(state: AgentState):
         messages = state["messages"]
         last_message = messages[-1]
         print(last_message)
 
-        action = ToolInvocation(
-            tool=last_message.additional_kwargs["function_call"]["name"],
-            tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]),
-        )
+        tool_calls = []
+        tasks = []
+        for tool_call in last_message.additional_kwargs["tool_calls"]:
+            function_name = tool_call["function"]["name"]
+            action = ToolInvocation(
+                tool=function_name,
+                tool_input=json.loads(tool_call["function"]["arguments"]),
+            )
+            tool_calls.append(action)
+            tasks.append(cl.Task(title=action.tool, status=cl.TaskStatus.RUNNING))
 
         tasks_list = cl.user_session.get("tasks_list")
         if tasks_list is None:
             tasks_list = cl.TaskList()
 
-        task = cl.Task(title=action.tool, status=cl.TaskStatus.RUNNING)
-        await tasks_list.add_task(task)
+        await asyncio.gather(*[tasks_list.add_task(task) for task in tasks])
         await tasks_list.send()
 
-        response = await tool_executor.ainvoke(action)
-        print(f"response: {response}")
 
-        function_message = FunctionMessage(content=str(response), name=action.tool)
+        responses = await asyncio.gather(*[tool_executor.ainvoke(tool_call) for tool_call in tool_calls])
+        tool_messages = []
+        for i in range(len(responses)):
+            tool_messages.append(ToolMessage(tool_call_id=last_message.additional_kwargs["tool_calls"][i]["id"], content=str(responses[i]), name=tool_calls[i].tool))
 
-        task.status = cl.TaskStatus.DONE
+        for task in tasks:
+            task.status = cl.TaskStatus.DONE
         await tasks_list.send()
 
         cl.user_session.set("tasks_list", tasks_list)
 
-        return {"messages": [function_message]}
+        return {"messages": tool_messages}
 
     graph = StateGraph(AgentState)
 
-    graph.add_node("agent", call_model)
-    graph.add_node("action", call_tool)
+    graph.add_node("model", call_model)
+    graph.add_node("tool", call_tool)
 
-    graph.set_entry_point("agent")
+    graph.set_entry_point("model")
 
     graph.add_conditional_edges(
-        "agent",
+        "model",
         should_continue,
         {
-            "continue": "action",
+            "tool_call": "tool",
+            "model_call": "model",
             "end": END,
         },
     )
 
-    graph.add_edge("action", "agent")
+    graph.add_edge("tool", "model")
 
     memory = MemorySaver()

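The core of the "parallel tool execution" change is the pair of asyncio.gather calls in call_tool: every requested tool (and its Chainlit task entry) is fanned out concurrently, and gather returns results in the order of its inputs, which is what lets each response be matched back to its tool_call_id by index. A stand-alone sketch of the pattern, with stub coroutines standing in for tool_executor.ainvoke:

import asyncio

async def fake_tool(name: str, delay: float) -> str:
    # Stand-in for tool_executor.ainvoke(ToolInvocation(...)).
    await asyncio.sleep(delay)
    return f"{name} done"

async def main() -> None:
    calls = [("search", 0.2), ("lookup", 0.1), ("summarize", 0.3)]
    # Fan out every call at once; results come back in input order even
    # though "lookup" finishes first, so index i still matches call i.
    results = await asyncio.gather(*[fake_tool(n, d) for n, d in calls])
    print(results)  # ['search done', 'lookup done', 'summarize done']

asyncio.run(main())
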
@@ -128,17 +150,18 @@ async def main(message: cl.Message):
     }
     cl.user_session.set("tasks_list", None)
 
 
     #res = await runner.ainvoke(inputs)
 
     msg = cl.Message(content="")
 
+    chunk = ""
     async for event in runner.astream_events(inputs, {"configurable": {"thread_id": "thread-1"}}, version="v1"):
         kind = event["event"]
         if kind == "on_chat_model_stream":
             content = event["data"]["chunk"].content
             if content:
-                await msg.stream_token(content)
+                chunk += content
+                if chunk.strip() not in "CONTINUE":
+                    await msg.stream_token(content)
+                elif chunk.strip() == "CONTINUE":
+                    chunk = ""
 
     await msg.send()
     #await cl.Message(content=res["messages"][-1].content).send()

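Finally, the streaming loop now buffers tokens so that the CONTINUE sentinel (which the new system prompt asks the model to emit while it is still working) never reaches the user. A stand-alone sketch of the same logic; note that the membership test is substring containment, not a prefix match, which has a quirk worth knowing:

def filter_stream(tokens: list[str]) -> str:
    # Mirrors the commit's suppression logic: accumulate tokens in `chunk`,
    # hold them back while the text so far could still be the CONTINUE
    # sentinel, and drop the sentinel itself once fully seen.
    shown, chunk = [], ""
    for token in tokens:
        chunk += token
        if chunk.strip() not in "CONTINUE":
            shown.append(token)
        elif chunk.strip() == "CONTINUE":
            chunk = ""
    return "".join(shown)

print(filter_stream(["CONT", "INUE", "Calling", " the", " tool"]))
# -> "Calling the tool"

# Caveat: tokens withheld while the buffer is still a substring of
# "CONTINUE" are never flushed later, so a reply beginning "CON..." would
# lose its first token(s).
print(filter_stream(["CON", "GRATS"]))  # -> "GRATS" ("CON" is swallowed)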