Skip to content

Commit

Permalink
parallel function calling in openai agent
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich committed Dec 19, 2024
1 parent 2ba5f2e commit 91124f4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,6 @@ async def _arun_step(
step, task.extra_state["new_memory"], verbose=self._verbose
)

# TODO: see if we want to do step-based inputs
tools = self.get_tools(task.input)
openai_tools = [tool.metadata.to_openai_tool() for tool in tools]

Expand All @@ -670,40 +669,46 @@ async def _arun_step(
task, mode=mode, **llm_chat_kwargs
)

# TODO: implement _should_continue
latest_tool_calls = self.get_latest_tool_calls(task) or []
latest_tool_outputs: List[ToolOutput] = []

if not self._should_continue(
latest_tool_calls, task.extra_state["n_function_calls"]
):
is_done = True

else:
is_done = False

# Validate all tool calls first
for tool_call in latest_tool_calls:
# Some validation
if not isinstance(tool_call, get_args(OpenAIToolCall)):
raise ValueError("Invalid tool_call object")

if tool_call.type != "function":
raise ValueError("Invalid tool type. Unsupported by OpenAI")

# TODO: maybe execute this with multi-threading
return_direct = await self._acall_function(
tools,
tool_call,
task.extra_state["new_memory"],
latest_tool_outputs,
)
# Execute all tool calls in parallel using asyncio.gather
tool_results = await asyncio.gather(
*[
self._acall_function(
tools,
tool_call,
task.extra_state["new_memory"],
latest_tool_outputs,
)
for tool_call in latest_tool_calls
]
)

# Process results
for return_direct in tool_results:
task.extra_state["sources"].append(latest_tool_outputs[-1])

# change function call to the default value, if a custom function was given
# as an argument (none and auto are predefined by OpenAI)
# change function call to the default value if a custom function was given
if tool_choice not in ("auto", "none"):
tool_choice = "auto"
task.extra_state["n_function_calls"] += 1

# If any tool call requests direct return and it's the only call
if return_direct and len(latest_tool_calls) == 1:
is_done = True
response_str = latest_tool_outputs[-1].content
Expand All @@ -723,7 +728,6 @@ async def _arun_step(
[
step.get_next_step(
step_id=str(uuid.uuid4()),
# NOTE: input is unused
input=None,
)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-agent-openai"
readme = "README.md"
version = "0.4.0"
version = "0.4.1"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down

0 comments on commit 91124f4

Please sign in to comment.