diff --git a/docs/docs/api_reference/agent/workflow.md b/docs/docs/api_reference/agent/workflow.md new file mode 100644 index 0000000000000..aefb274d845cf --- /dev/null +++ b/docs/docs/api_reference/agent/workflow.md @@ -0,0 +1,12 @@ +::: llama_index.core.agent.workflow + options: + members: + - MultiAgentWorkflow + - BaseWorkflowAgent + - FunctionAgent + - ReactAgent + - AgentInput + - AgentStream + - AgentOutput + - ToolCall + - ToolCallResult diff --git a/docs/docs/understanding/agent/multi_agents.md b/docs/docs/understanding/agent/multi_agents.md new file mode 100644 index 0000000000000..3d1a299bcea17 --- /dev/null +++ b/docs/docs/understanding/agent/multi_agents.md @@ -0,0 +1,257 @@ +# Multi-Agent Workflows + +The MultiAgentWorkflow uses Workflow Agents to allow you to create a system of multiple agents that can collaborate and hand off tasks to each other based on their specialized capabilities. This enables building more complex agent systems where different agents handle different aspects of a task. + +## Quick Start + +Here's a simple example of setting up a multi-agent workflow with a calculator agent and a retriever agent: + +```python +from llama_index.core.agent.workflow import ( + MultiAgentWorkflow, + FunctionAgent, + ReactAgent, +) +from llama_index.core.tools import FunctionTool + + +# Define some tools +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +def subtract(a: int, b: int) -> int: + """Subtract two numbers.""" + return a - b + + +# Create agent configs +# NOTE: we can use FunctionAgent or ReactAgent here. +# FunctionAgent works for LLMs with a function calling API. +# ReactAgent works for any LLM. +calculator_agent = FunctionAgent( + name="calculator", + description="Performs basic arithmetic operations", + system_prompt="You are a calculator assistant.", + tools=[ + FunctionTool.from_defaults(fn=add), + FunctionTool.from_defaults(fn=subtract), + ], + llm=OpenAI(model="gpt-4"), +) + +retriever_agent = FunctionAgent( + name="retriever", + description="Manages data retrieval", + system_prompt="You are a retrieval assistant.", + is_entrypoint_agent=True, + llm=OpenAI(model="gpt-4"), +) + +# Create and run the workflow +workflow = MultiAgentWorkflow( + agent_configs=[calculator_agent, retriever_agent] +) + +# Run the system +response = await workflow.run(user_msg="Can you add 5 and 3?") + +# Or stream the events +handler = workflow.run(user_msg="Can you add 5 and 3?") +async for event in handler.stream_events(): + if hasattr(event, "delta"): + print(event.delta, end="", flush=True) +``` + +## How It Works + +The MultiAgentWorkflow manages a collection of agents, each with their own specialized capabilities. One agent must be designated as the entry point agent (`is_entrypoint_agent=True`). + +When a user message comes in, it's first routed to the entry point agent. Each agent can then: + +1. Handle the request directly using their tools +2. Hand off to another agent better suited for the task +3. Return a response to the user + +## Configuration Options + +### Agent Config + +Each agent holds a certain set of configuration options. Whether you use `FunctionAgent` or `ReactAgent`, the core options are the same. + +```python +FunctionAgent( + # Unique name for the agent (str) + name="name", + # Description of agent's capabilities (str) + description="description", + # System prompt for the agent (str) + system_prompt="system_prompt", + # Tools available to this agent (List[BaseTool]) + tools=[...], + # LLM to use for this agent. (BaseLLM) + llm=OpenAI(model="gpt-4"), + # Whether this is the entry point. (bool) + is_entrypoint_agent=True, + # List of agents this one can hand off to. Defaults to all agents. (List[str]) + can_handoff_to=[...], +) +``` + +### Workflow Options + +The MultiAgentWorkflow constructor accepts: + +```python +MultiAgentWorkflow( + # List of agent configs. (List[BaseWorkflowAgent]) + agents=[...], + # Initial state dict. (Optional[dict]) + initial_state=None, + # Custom prompt for handoffs. Should contain the `agent_info` string variable. (Optional[str]) + handoff_prompt=None, + # Custom prompt for state. Should contain the `state` and `msg` string variables. (Optional[str]) + state_prompt=None, +) +``` + +### State Management + +#### Initial Global State + +You can provide an initial state dict that will be available to all agents: + +```python +workflow = MultiAgentWorkflow( + agents=[...], + initial_state={"counter": 0}, + state_prompt="Current state: {state}. User message: {msg}", +) +``` + +The state is stored in the `state` key of the workflow context. + +#### Persisting State Between Runs + +In order to persist state between runs, you can pass in the context from the previous run: + +```python +workflow = MultiAgentWorkflow(...) + +# Run the workflow +handler = workflow.run(user_msg="Can you add 5 and 3?") +response = await handler + +# Pass in the context from the previous run +handler = workflow.run(ctx=handler.ctx, user_msg="Can you add 5 and 3?") +response = await handler +``` + +#### Serializing Context / State + +As with normal workflows, the context is serializable: + +```python +from llama_index.core.workflow import ( + Context, + JsonSerializer, + JsonPickleSerializer, +) + +# the default serializer is JsonSerializer for safety +ctx_dict = handler.ctx.to_dict(serializer=JsonSerializer()) + +# then you can rehydrate the context +ctx = Context.from_dict(ctx_dict, serializer=JsonSerializer()) +``` + +## Streaming Events + +The workflow emits various events during execution that you can stream: + +```python +async for event in workflow.run(...).stream_events(): + if isinstance(event, AgentInput): + print(event.input) + print(event.current_agent_name) + elif isinstance(event, AgentStream): + # Agent thinking/tool calling response stream + print(event.delta) + print(event.current_agent_name) + elif isinstance(event, AgentOutput): + print(event.response) + print(event.tool_calls) + print(event.raw) + print(event.current_agent_name) + elif isinstance(event, ToolCall): + # Tool being called + print(event.tool_name) + print(event.tool_kwargs) + elif isinstance(event, ToolCallResult): + # Result of tool call + print(event.tool_output) +``` + +## Accessing Context in Tools + +The `FunctionToolWithContext` allows tools to access the workflow context: + +```python +from llama_index.core.workflow import FunctionToolWithContext + + +async def get_counter(ctx: Context) -> int: + """Get the current counter value.""" + return await ctx.get("counter", default=0) + + +counter_tool = FunctionToolWithContext.from_defaults( + async_fn=get_counter, description="Get the current counter value" +) +``` + +## Human in the Loop + +Using the context, you can implement a human in the loop pattern in your tools: + +```python +from llama_index.core.workflow import Event + + +class AskForConfirmationEvent(Event): + """Ask for confirmation event.""" + + confirmation_id: str + + +class ConfirmationEvent(Event): + """Confirmation event.""" + + confirmation: bool + confirmation_id: str + + +async def ask_for_confirmation(ctx: Context) -> bool: + """Ask the user for confirmation.""" + ctx.write_event_to_stream(AskForConfirmationEvent(confirmation_id="1234")) + + result = await ctx.wait_for_event( + ConfirmationEvent, requirements={"confirmation_id": "1234"} + ) + return result.confirmation +``` + +When this function is called, it will block the workflow execution until the user sends the required confirmation event. + +```python +handler = workflow.run(user_msg="Can you add 5 and 3?") + +async for event in handler.stream_events(): + if isinstance(event, AskForConfirmationEvent): + print(event.confirmation_id) + handler.ctx.send_event( + ConfirmationEvent(confirmation=True, confirmation_id="1234") + ) + ... +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 0f5b103a290af..c5f5ee60ea269 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -50,6 +50,7 @@ nav: - Enhancing with LlamaParse: ./understanding/agent/llamaparse.md - Memory: ./understanding/agent/memory.md - Adding other tools: ./understanding/agent/tools.md + - Multi-agent workflows: ./understanding/agent/multi_agents.md - Building Workflows: - Introduction to workflows: ./understanding/workflows/index.md - A basic workflow: ./understanding/workflows/basic_flow.md @@ -852,6 +853,7 @@ nav: - ./api_reference/agent/openai.md - ./api_reference/agent/openai_legacy.md - ./api_reference/agent/react.md + - ./api_reference/agent/workflow.md - Callbacks: - ./api_reference/callbacks/agentops.md - ./api_reference/callbacks/aim.md diff --git a/llama-index-core/llama_index/core/agent/workflow/BUILD b/llama-index-core/llama_index/core/agent/workflow/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-core/llama_index/core/agent/workflow/__init__.py b/llama-index-core/llama_index/core/agent/workflow/__init__.py new file mode 100644 index 0000000000000..aba7b324bcbcc --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/__init__.py @@ -0,0 +1,26 @@ +from llama_index.core.agent.workflow.multi_agent_workflow import MultiAgentWorkflow +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.function_agent import FunctionAgent +from llama_index.core.agent.workflow.react_agent import ReactAgent +from llama_index.core.agent.workflow.workflow_events import ( + AgentInput, + AgentSetup, + AgentStream, + AgentOutput, + ToolCall, + ToolCallResult, +) + + +__all__ = [ + "AgentInput", + "AgentSetup", + "AgentStream", + "AgentOutput", + "BaseWorkflowAgent", + "FunctionAgent", + "MultiAgentWorkflow", + "ReactAgent", + "ToolCall", + "ToolCallResult", +] diff --git a/llama-index-core/llama_index/core/agent/workflow/base_agent.py b/llama-index-core/llama_index/core/agent/workflow/base_agent.py new file mode 100644 index 0000000000000..ea20220fcf54c --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/base_agent.py @@ -0,0 +1,71 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from llama_index.core.agent.workflow.workflow_events import ( + AgentOutput, + ToolCallResult, +) +from llama_index.core.bridge.pydantic import BaseModel, Field, ConfigDict +from llama_index.core.llms import ChatMessage, LLM +from llama_index.core.memory import BaseMemory +from llama_index.core.tools import BaseTool, AsyncBaseTool +from llama_index.core.workflow import Context +from llama_index.core.objects import ObjectRetriever +from llama_index.core.settings import Settings + + +def get_default_llm() -> LLM: + return Settings.llm + + +class BaseWorkflowAgent(BaseModel, ABC): + """Base class for all agents, combining config and logic.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = Field(description="The name of the agent") + description: str = Field( + description="The description of what the agent does and is responsible for" + ) + system_prompt: Optional[str] = Field( + default=None, description="The system prompt for the agent" + ) + tools: Optional[List[BaseTool]] = Field( + default=None, description="The tools that the agent can use" + ) + tool_retriever: Optional[ObjectRetriever] = Field( + default=None, + description="The tool retriever for the agent, can be provided instead of tools", + ) + can_handoff_to: Optional[List[str]] = Field( + default=None, description="The agent names that this agent can hand off to" + ) + llm: LLM = Field( + default_factory=get_default_llm, description="The LLM that the agent uses" + ) + is_entrypoint_agent: bool = Field( + default=False, + description="Whether the agent is the entrypoint agent in a multi-agent workflow", + ) + + @abstractmethod + async def take_step( + self, + ctx: Context, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + memory: BaseMemory, + ) -> AgentOutput: + """Take a single step with the agent.""" + + @abstractmethod + async def handle_tool_call_results( + self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory + ) -> None: + """Handle tool call results.""" + + @abstractmethod + async def finalize( + self, ctx: Context, output: AgentOutput, memory: BaseMemory + ) -> AgentOutput: + """Finalize the agent's execution.""" diff --git a/llama-index-core/llama_index/core/agent/workflow/function_agent.py b/llama-index-core/llama_index/core/agent/workflow/function_agent.py new file mode 100644 index 0000000000000..dab790679d2cf --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/function_agent.py @@ -0,0 +1,114 @@ +from typing import List + +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.workflow_events import ( + AgentInput, + AgentOutput, + AgentStream, + ToolCallResult, +) +from llama_index.core.llms import ChatMessage +from llama_index.core.memory import BaseMemory +from llama_index.core.tools import AsyncBaseTool +from llama_index.core.workflow import Context + + +class FunctionAgent(BaseWorkflowAgent): + """Function calling agent implementation.""" + + scratchpad_key: str = "scratchpad" + + async def take_step( + self, + ctx: Context, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + memory: BaseMemory, + ) -> AgentOutput: + """Take a single step with the function calling agent.""" + if not self.llm.metadata.is_function_calling_model: + raise ValueError("LLM must be a FunctionCallingLLM") + + scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[]) + current_llm_input = [*llm_input, *scratchpad] + + ctx.write_event_to_stream( + AgentInput(input=current_llm_input, current_agent_name=self.name) + ) + + response = await self.llm.astream_chat_with_tools( # type: ignore + tools, chat_history=current_llm_input, allow_parallel_tool_calls=True + ) + async for r in response: + tool_calls = self.llm.get_tool_calls_from_response( # type: ignore + r, error_on_no_tool_call=False + ) + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + response=r.message.content, + tool_calls=tool_calls or [], + raw=r.raw, + current_agent_name=self.name, + ) + ) + + tool_calls = self.llm.get_tool_calls_from_response( # type: ignore + r, error_on_no_tool_call=False + ) + + # only add to scratchpad if we didn't select the handoff tool + if not any(tool_call.tool_name == "handoff" for tool_call in tool_calls): + scratchpad.append(r.message) + await ctx.set(self.scratchpad_key, scratchpad) + + return AgentOutput( + response=r.message.content, + tool_calls=tool_calls or [], + raw=r.raw, + current_agent_name=self.name, + ) + + async def handle_tool_call_results( + self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory + ) -> None: + """Handle tool call results for function calling agent.""" + scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[]) + + for tool_call_result in results: + # don't add handoff tool calls to memory + if tool_call_result.tool_name == "handoff": + continue + + scratchpad.append( + ChatMessage( + role="tool", + content=str(tool_call_result.tool_output.content), + additional_kwargs={"tool_call_id": tool_call_result.tool_id}, + ) + ) + + if tool_call_result.return_direct: + scratchpad.append( + ChatMessage( + role="assistant", + content=str(tool_call_result.tool_output.content), + additional_kwargs={"tool_call_id": tool_call_result.tool_id}, + ) + ) + break + + await ctx.set(self.scratchpad_key, scratchpad) + + async def finalize( + self, ctx: Context, output: AgentOutput, memory: BaseMemory + ) -> AgentOutput: + """Finalize the function calling agent. + + Adds all in-progress messages to memory. + """ + scratchpad: List[ChatMessage] = await ctx.get(self.scratchpad_key, default=[]) + for msg in scratchpad: + await memory.aput(msg) + + return output diff --git a/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py new file mode 100644 index 0000000000000..60cddef841b3e --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py @@ -0,0 +1,358 @@ +from typing import Any, Dict, List, Optional, Sequence, Union + +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.workflow_events import ( + ToolCall, + ToolCallResult, + AgentInput, + AgentSetup, + AgentOutput, +) +from llama_index.core.llms import ChatMessage +from llama_index.core.llms.llm import ToolSelection +from llama_index.core.memory import BaseMemory, ChatMemoryBuffer +from llama_index.core.prompts import BasePromptTemplate, PromptTemplate +from llama_index.core.tools import ( + BaseTool, + AsyncBaseTool, + ToolOutput, + adapt_to_async_tool, +) +from llama_index.core.workflow import ( + Context, + FunctionToolWithContext, + StartEvent, + StopEvent, + Workflow, + step, +) +from llama_index.core.settings import Settings + + +DEFAULT_HANDOFF_PROMPT = """Useful for handing off to another agent. +If you are currently not equipped to handle the user's request, or another agent is better suited to handle the request, please hand off to the appropriate agent. + +Currently available agents: +{agent_info} +""" + + +async def handoff(ctx: Context, to_agent: str, reason: str) -> str: + """Handoff control of that chat to the given agent.""" + agents: dict[str, BaseWorkflowAgent] = await ctx.get("agents") + current_agent: BaseWorkflowAgent = await ctx.get("current_agent") + if to_agent not in agents: + valid_agents = ", ".join([x for x in agents if x != current_agent.name]) + return f"Agent {to_agent} not found. Please select a valid agent to hand off to. Valid agents: {valid_agents}" + + await ctx.set("next_agent", agents[to_agent].name) + return f"Handed off to {to_agent} because: {reason}" + + +class MultiAgentWorkflow(Workflow): + """A workflow for managing multiple agents with handoffs.""" + + def __init__( + self, + agents: List[BaseWorkflowAgent], + initial_state: Optional[Dict] = None, + handoff_prompt: Optional[Union[str, BasePromptTemplate]] = None, + state_prompt: Optional[Union[str, BasePromptTemplate]] = None, + timeout: Optional[float] = None, + **workflow_kwargs: Any, + ): + super().__init__(timeout=timeout, **workflow_kwargs) + if not agents: + raise ValueError("At least one agent must be provided") + + self.agents = {cfg.name: cfg for cfg in agents} + only_one_root_agent = sum(cfg.is_entrypoint_agent for cfg in agents) == 1 + if not only_one_root_agent: + raise ValueError("Exactly one root agent must be provided") + + self.root_agent = next(agent for agent in agents if agent.is_entrypoint_agent) + + self.initial_state = initial_state or {} + + self.handoff_prompt = handoff_prompt or DEFAULT_HANDOFF_PROMPT + if isinstance(self.handoff_prompt, str): + self.handoff_prompt = PromptTemplate(self.handoff_prompt) + if "{agent_info}" not in self.handoff_prompt.template: + raise ValueError("Handoff prompt must contain {agent_info}") + + self.state_prompt = state_prompt + if isinstance(self.state_prompt, str): + self.state_prompt = PromptTemplate(self.state_prompt) + if ( + "{state}" not in self.state_prompt.template + or "{msg}" not in self.state_prompt.template + ): + raise ValueError("State prompt must contain {state} and {msg}") + + def _ensure_tools_are_async( + self, tools: Sequence[BaseTool] + ) -> Sequence[AsyncBaseTool]: + """Ensure all tools are async.""" + return [adapt_to_async_tool(tool) for tool in tools] + + def _get_handoff_tool(self, current_agent: BaseWorkflowAgent) -> AsyncBaseTool: + """Creates a handoff tool for the given agent.""" + agent_info = {cfg.name: cfg.description for cfg in self.agents.values()} + + # Filter out agents that the current agent cannot handoff to + configs_to_remove = [] + for name in agent_info: + if name == current_agent.name: + configs_to_remove.append(name) + elif ( + current_agent.can_handoff_to is not None + and name not in current_agent.can_handoff_to + ): + configs_to_remove.append(name) + + for name in configs_to_remove: + agent_info.pop(name) + + fn_tool_prompt = self.handoff_prompt.format(agent_info=str(agent_info)) + return FunctionToolWithContext.from_defaults( + async_fn=handoff, description=fn_tool_prompt, return_direct=True + ) + + async def _init_context(self, ctx: Context, ev: StartEvent) -> None: + """Initialize the context once, if needed.""" + if not await ctx.get("memory", default=None): + default_memory = ev.get("memory", default=None) + default_memory = default_memory or ChatMemoryBuffer.from_defaults( + llm=self.agents[self.root_agent.name].llm or Settings.llm + ) + await ctx.set("memory", default_memory) + if not await ctx.get("agents", default=None): + await ctx.set("agents", self.agents) + if not await ctx.get("state", default=None): + await ctx.set("state", self.initial_state) + if not await ctx.get("current_agent", default=None): + await ctx.set("current_agent", self.root_agent) + + async def _call_tool( + self, + ctx: Context, + tool: AsyncBaseTool, + tool_input: dict, + ) -> ToolOutput: + """Call the given tool with the given input.""" + try: + if isinstance(tool, FunctionToolWithContext): + tool_output = await tool.acall(ctx=ctx, **tool_input) + else: + tool_output = await tool.acall(**tool_input) + except Exception as e: + tool_output = ToolOutput( + content=str(e), + tool_name=tool.metadata.name, + raw_input=tool_input, + raw_output=str(e), + is_error=True, + ) + + return tool_output + + @step + async def init_run(self, ctx: Context, ev: StartEvent) -> AgentInput: + """Sets up the workflow and validates inputs.""" + await self._init_context(ctx, ev) + + user_msg = ev.get("user_msg") + chat_history = ev.get("chat_history") + if user_msg and chat_history: + raise ValueError("Cannot provide both user_msg and chat_history") + + if isinstance(user_msg, str): + user_msg = ChatMessage(role="user", content=user_msg) + + await ctx.set("user_msg_str", user_msg.content) + + # Add messages to memory + memory: BaseMemory = await ctx.get("memory") + if user_msg: + # Add the state to the user message if it exists and if requested + current_state = await ctx.get("state") + if self.state_prompt and current_state: + user_msg.content = self.state_prompt.format( + state=current_state, msg=user_msg.content + ) + + await memory.aput(user_msg) + input_messages = memory.get(input=user_msg.content) + else: + memory.set(chat_history) + input_messages = memory.get() + + # send to the current agent + current_agent: BaseWorkflowAgent = await ctx.get("current_agent") + return AgentInput(input=input_messages, current_agent_name=current_agent.name) + + @step + async def setup_agent(self, ctx: Context, ev: AgentInput) -> AgentSetup: + """Main agent handling logic.""" + current_agent_name = ev.current_agent_name + agent = self.agents[current_agent_name] + llm_input = ev.input + + # Set up the tools + tools = [*agent.tools] if agent.tools else [] + if agent.tool_retriever: + retrieved_tools = await agent.tool_retriever.aretrieve( + llm_input[-1].content or str(llm_input) + ) + tools.extend(retrieved_tools) + + if agent.can_handoff_to or agent.can_handoff_to is None: + handoff_tool = self._get_handoff_tool(agent) + tools.append(handoff_tool) + + async_tools = self._ensure_tools_are_async(tools) + + if agent.system_prompt: + llm_input = [ + ChatMessage(role="system", content=agent.system_prompt), + *llm_input, + ] + + await ctx.set("tools_by_name", {tool.metadata.name: tool for tool in tools}) + + return AgentSetup( + input=llm_input, + current_agent_name=ev.current_agent_name, + tools=async_tools, + ) + + @step + async def run_agent_step(self, ctx: Context, ev: AgentSetup) -> AgentOutput: + """Run the agent.""" + memory: BaseMemory = await ctx.get("memory") + agent = self.agents[ev.current_agent_name] + + return await agent.take_step( + ctx, + ev.input, + ev.tools, + memory, + ) + + @step + async def parse_agent_output( + self, ctx: Context, ev: AgentOutput + ) -> Union[StopEvent, ToolCall, None]: + if not ev.tool_calls: + agent = self.agents[ev.current_agent_name] + memory: BaseMemory = await ctx.get("memory") + output = await agent.finalize(ctx, ev, memory) + + return StopEvent(result=output) + + await ctx.set("num_tool_calls", len(ev.tool_calls)) + + for tool_call in ev.tool_calls: + ctx.send_event( + ToolCall( + tool_name=tool_call.tool_name, + tool_kwargs=tool_call.tool_kwargs, + tool_id=tool_call.tool_id, + ) + ) + + return None + + @step + async def call_tool(self, ctx: Context, ev: ToolCall) -> ToolCallResult: + """Calls the tool and handles the result.""" + ctx.write_event_to_stream(ev) + + tools_by_name: dict[str, AsyncBaseTool] = await ctx.get("tools_by_name") + if ev.tool_name not in tools_by_name: + tool = None + result = ToolOutput( + content=f"Tool {ev.tool_name} not found. Please select a tool that is available.", + tool_name=ev.tool_name, + raw_input=ev.tool_kwargs, + raw_output=None, + is_error=True, + ) + else: + tool = tools_by_name[ev.tool_name] + result = await self._call_tool(ctx, tool, ev.tool_kwargs) + + result_ev = ToolCallResult( + tool_name=ev.tool_name, + tool_kwargs=ev.tool_kwargs, + tool_id=ev.tool_id, + tool_output=result, + return_direct=tool.metadata.return_direct if tool else False, + ) + + ctx.write_event_to_stream(result_ev) + return result_ev + + @step + async def aggregate_tool_results( + self, ctx: Context, ev: ToolCallResult + ) -> Union[AgentInput, StopEvent, None]: + """Aggregate tool results and return the next agent input.""" + num_tool_calls = await ctx.get("num_tool_calls", default=0) + if num_tool_calls == 0: + raise ValueError("No tool calls found, cannot aggregate results.") + + tool_call_results: list[ToolCallResult] = ctx.collect_events( # type: ignore + ev, expected=[ToolCallResult] * num_tool_calls + ) + if not tool_call_results: + return None + + memory: BaseMemory = await ctx.get("memory") + agent: BaseWorkflowAgent = await ctx.get("current_agent") + + await agent.handle_tool_call_results(ctx, tool_call_results, memory) + + # set the next agent, if needed + # the handoff tool sets this + next_agent_name = await ctx.get("next_agent", default=None) + if next_agent_name: + await ctx.set("current_agent", self.agents[next_agent_name]) + + if any( + tool_call_result.return_direct for tool_call_result in tool_call_results + ): + # if any tool calls return directly, take the first one + return_direct_tool = next( + tool_call_result + for tool_call_result in tool_call_results + if tool_call_result.return_direct + ) + + # always finalize the agent, even if we're just handing off + result = AgentOutput( + response=return_direct_tool.tool_output.content, + tool_calls=[ + ToolSelection( + tool_id=t.tool_id, + tool_name=t.tool_name, + tool_kwargs=t.tool_kwargs, + ) + for t in tool_call_results + ], + raw=return_direct_tool.tool_output.raw_output, + current_agent_name=agent.name, + ) + result = await agent.finalize(ctx, result, memory) + + # we don't want to stop the system if we're just handing off + if return_direct_tool.tool_name != "handoff": + return StopEvent(result=result) + + user_msg_str = await ctx.get("user_msg_str") + input_messages = memory.get(input=user_msg_str) + + # get this again, in case it changed + agent = await ctx.get("current_agent") + + return AgentInput(input=input_messages, current_agent_name=agent.name) diff --git a/llama-index-core/llama_index/core/agent/workflow/react_agent.py b/llama-index-core/llama_index/core/agent/workflow/react_agent.py new file mode 100644 index 0000000000000..1f51c2a539815 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/react_agent.py @@ -0,0 +1,189 @@ +import uuid +from typing import List, Optional, cast + +from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent +from llama_index.core.agent.workflow.workflow_events import ( + AgentInput, + AgentOutput, + AgentStream, + ToolCallResult, +) +from llama_index.core.agent.react.formatter import ReActChatFormatter +from llama_index.core.agent.react.output_parser import ReActOutputParser +from llama_index.core.agent.react.types import ( + ActionReasoningStep, + BaseReasoningStep, + ObservationReasoningStep, + ResponseReasoningStep, +) +from llama_index.core.bridge.pydantic import Field +from llama_index.core.llms import ChatMessage +from llama_index.core.llms.llm import ToolSelection +from llama_index.core.memory import BaseMemory +from llama_index.core.tools import AsyncBaseTool +from llama_index.core.workflow import Context + + +class ReactAgent(BaseWorkflowAgent): + """React agent implementation.""" + + reasoning_key: str = "current_reasoning" + output_parser: Optional[ReActOutputParser] = Field( + default=None, description="The react output parser" + ) + formatter: Optional[ReActChatFormatter] = Field( + default=None, + description="The react chat formatter to format the reasoning steps and chat history into an llm input.", + ) + + async def take_step( + self, + ctx: Context, + llm_input: List[ChatMessage], + tools: List[AsyncBaseTool], + memory: BaseMemory, + ) -> AgentOutput: + """Take a single step with the React agent.""" + # remove system prompt, since the react prompt will be combined with it + if llm_input[0].role == "system": + system_prompt = llm_input[0].content or "" + llm_input = llm_input[1:] + else: + system_prompt = "" + + output_parser = self.output_parser or ReActOutputParser() + react_chat_formatter = self.formatter or ReActChatFormatter( + context=system_prompt + ) + + # Format initial chat input + current_reasoning: list[BaseReasoningStep] = await ctx.get( + self.reasoning_key, default=[] + ) + input_chat = react_chat_formatter.format( + tools, + chat_history=llm_input, + current_reasoning=current_reasoning, + ) + + ctx.write_event_to_stream( + AgentInput(input=input_chat, current_agent_name=self.name) + ) + + # Initial LLM call + response = await self.llm.astream_chat(input_chat) + async for r in response: + ctx.write_event_to_stream( + AgentStream( + delta=r.delta or "", + response=r.message.content, + tool_calls=[], + raw=r.raw, + current_agent_name=self.name, + ) + ) + + # Parse reasoning step and check if done + message_content = r.message.content + if not message_content: + raise ValueError("Got empty message") + + try: + reasoning_step = output_parser.parse(message_content, is_streaming=False) + except ValueError as e: + error_msg = f"Error: Could not parse output. Please follow the thought-action-input format. Try again. Details: {e!s}" + await memory.aput(r.message) + await memory.aput(ChatMessage(role="user", content=error_msg)) + + return AgentOutput( + response=r.message.content, + tool_calls=[], + raw=r.raw, + current_agent_name=self.name, + ) + + # add to reasoning if not a handoff + if hasattr(reasoning_step, "action") and reasoning_step.action != "handoff": + current_reasoning.append(reasoning_step) + await ctx.set(self.reasoning_key, current_reasoning) + + # If response step, we're done + if reasoning_step.is_done: + return AgentOutput( + response=r.message.content, + tool_calls=[], + raw=r.raw, + current_agent_name=self.name, + ) + + reasoning_step = cast(ActionReasoningStep, reasoning_step) + if not isinstance(reasoning_step, ActionReasoningStep): + raise ValueError(f"Expected ActionReasoningStep, got {reasoning_step}") + + # Create tool call + tool_calls = [ + ToolSelection( + tool_id=str(uuid.uuid4()), + tool_name=reasoning_step.action, + tool_kwargs=reasoning_step.action_input, + ) + ] + + return AgentOutput( + response=r.message.content, + tool_calls=tool_calls, + raw=r.raw, + current_agent_name=self.name, + ) + + async def handle_tool_call_results( + self, ctx: Context, results: List[ToolCallResult], memory: BaseMemory + ) -> None: + """Handle tool call results for React agent.""" + current_reasoning: list[BaseReasoningStep] = await ctx.get( + self.reasoning_key, default=[] + ) + for tool_call_result in results: + # don't add handoff tool calls to reasoning + if tool_call_result.tool_name == "handoff": + continue + + obs_step = ObservationReasoningStep( + observation=str(tool_call_result.tool_output.content), + return_direct=tool_call_result.return_direct, + ) + current_reasoning.append(obs_step) + + if tool_call_result.return_direct: + current_reasoning.append( + ResponseReasoningStep( + thought=obs_step.observation, + response=obs_step.observation, + is_streaming=False, + ) + ) + break + + await ctx.set(self.reasoning_key, current_reasoning) + + async def finalize( + self, ctx: Context, output: AgentOutput, memory: BaseMemory + ) -> AgentOutput: + """Finalize the React agent.""" + current_reasoning: list[BaseReasoningStep] = await ctx.get( + self.reasoning_key, default=[] + ) + + reasoning_str = "\n".join([x.get_content() for x in current_reasoning]) + + if reasoning_str: + reasoning_msg = ChatMessage(role="assistant", content=reasoning_str) + await memory.aput(reasoning_msg) + await ctx.set(self.reasoning_key, []) + + # remove "Answer:" from the response + if output.response and "Answer:" in output.response: + start_idx = output.response.index("Answer:") + output.response = output.response[start_idx + len("Answer:") :].strip() + + return output diff --git a/llama-index-core/llama_index/core/agent/workflow/workflow_events.py b/llama-index-core/llama_index/core/agent/workflow/workflow_events.py new file mode 100644 index 0000000000000..2a0532e79c229 --- /dev/null +++ b/llama-index-core/llama_index/core/agent/workflow/workflow_events.py @@ -0,0 +1,57 @@ +from typing import Any + +from llama_index.core.tools import AsyncBaseTool, ToolSelection, ToolOutput +from llama_index.core.llms import ChatMessage +from llama_index.core.workflow import Event + + +class AgentInput(Event): + """LLM input.""" + + input: list[ChatMessage] + current_agent_name: str + + +class AgentSetup(Event): + """Agent setup.""" + + input: list[ChatMessage] + current_agent_name: str + tools: list[AsyncBaseTool] + + +class AgentStream(Event): + """Agent stream.""" + + delta: str + response: str + current_agent_name: str + tool_calls: list[ToolSelection] + raw: Any + + +class AgentOutput(Event): + """LLM output.""" + + response: str + tool_calls: list[ToolSelection] + raw: Any + current_agent_name: str + + def __str__(self) -> str: + return str(self.response) + + +class ToolCall(Event): + """All tool calls are surfaced.""" + + tool_name: str + tool_kwargs: dict + tool_id: str + + +class ToolCallResult(ToolCall): + """Tool call result.""" + + tool_output: ToolOutput + return_direct: bool diff --git a/llama-index-core/llama_index/core/tools/types.py b/llama-index-core/llama_index/core/tools/types.py index a35a990af7708..941fa9cfa9f0e 100644 --- a/llama-index-core/llama_index/core/tools/types.py +++ b/llama-index-core/llama_index/core/tools/types.py @@ -1,3 +1,4 @@ +import asyncio import json from abc import abstractmethod from dataclasses import dataclass @@ -195,7 +196,7 @@ def call(self, input: Any) -> ToolOutput: return self.base_tool(input) async def acall(self, input: Any) -> ToolOutput: - return self.call(input) + return await asyncio.to_thread(self.call, input) def adapt_to_async_tool(tool: BaseTool) -> AsyncBaseTool: diff --git a/llama-index-core/llama_index/core/workflow/__init__.py b/llama-index-core/llama_index/core/workflow/__init__.py index 54dfd2660540a..f842cd20d7499 100644 --- a/llama-index-core/llama_index/core/workflow/__init__.py +++ b/llama-index-core/llama_index/core/workflow/__init__.py @@ -16,6 +16,7 @@ InputRequiredEvent, HumanResponseEvent, ) +from llama_index.core.workflow.tools import FunctionToolWithContext from llama_index.core.workflow.workflow import Workflow from llama_index.core.workflow.context import Context from llama_index.core.workflow.context_serializers import ( @@ -46,4 +47,5 @@ "JsonSerializer", "WorkflowCheckpointer", "Checkpoint", + "FunctionToolWithContext", ] diff --git a/llama-index-core/llama_index/core/workflow/context.py b/llama-index-core/llama_index/core/workflow/context.py index c829de6409fa4..822e67f0f1f67 100644 --- a/llama-index-core/llama_index/core/workflow/context.py +++ b/llama-index-core/llama_index/core/workflow/context.py @@ -1,8 +1,9 @@ import asyncio import json import warnings +import uuid from collections import defaultdict -from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple +from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple, TypeVar from .context_serializers import BaseSerializer, JsonSerializer from .decorators import StepConfig @@ -12,6 +13,8 @@ if TYPE_CHECKING: # pragma: no cover from .workflow import Workflow +T = TypeVar("T", bound=Event) + class Context: """A global object representing a context for a given workflow run. @@ -281,6 +284,44 @@ def send_event(self, message: Event, step: Optional[str] = None) -> None: self._broker_log.append(message) + async def wait_for_event( + self, + event_type: Type[T], + requirements: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = 2000, + ) -> T: + """Asynchronously wait for a specific event type to be received. + + Args: + event_type: The type of event to wait for + requirements: Optional dict of requirements the event must match + timeout: Optional timeout in seconds. Defaults to 2000s. + + Returns: + The event type that was requested. + + Raises: + asyncio.TimeoutError: If the timeout is reached before receiving matching event + """ + requirements = requirements or {} + waiter_id = str(uuid.uuid4()) + self._queues[waiter_id] = asyncio.Queue() + + while True: + event = await asyncio.wait_for( + self._queues[waiter_id].get(), timeout=timeout + ) + if isinstance(event, event_type): + if all( + event.get(k, default=None) == v for k, v in requirements.items() + ): + # in the case of checkpointing/resuming, we only want to delete + # once we've received the event we're looking for + del self._queues[waiter_id] + return event + else: + continue + def write_event_to_stream(self, ev: Optional[Event]) -> None: self._streaming_queue.put_nowait(ev) diff --git a/llama-index-core/llama_index/core/workflow/tools.py b/llama-index-core/llama_index/core/workflow/tools.py new file mode 100644 index 0000000000000..de2a89ab568ca --- /dev/null +++ b/llama-index-core/llama_index/core/workflow/tools.py @@ -0,0 +1,131 @@ +from inspect import signature +from typing import Any, Awaitable, Optional, Callable, Type, List, Tuple, Union, cast + +from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model +from llama_index.core.tools import ( + FunctionTool, + ToolOutput, + ToolMetadata, +) +from llama_index.core.workflow import ( + Context, +) + +AsyncCallable = Callable[..., Awaitable[Any]] + + +def create_schema_from_function( + name: str, + func: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], + additional_fields: Optional[ + List[Union[Tuple[str, Type, Any], Tuple[str, Type]]] + ] = None, +) -> Type[BaseModel]: + """Create schema from function.""" + fields = {} + params = signature(func).parameters + for param_name in params: + # TODO: Very hacky way to remove the ctx parameter from the signature + if param_name == "ctx": + continue + + param_type = params[param_name].annotation + param_default = params[param_name].default + + if param_type is params[param_name].empty: + param_type = Any + + if param_default is params[param_name].empty: + # Required field + fields[param_name] = (param_type, FieldInfo()) + elif isinstance(param_default, FieldInfo): + # Field with pydantic.Field as default value + fields[param_name] = (param_type, param_default) + else: + fields[param_name] = (param_type, FieldInfo(default=param_default)) + + additional_fields = additional_fields or [] + for field_info in additional_fields: + if len(field_info) == 3: + field_info = cast(Tuple[str, Type, Any], field_info) + field_name, field_type, field_default = field_info + fields[field_name] = (field_type, FieldInfo(default=field_default)) + elif len(field_info) == 2: + # Required field has no default value + field_info = cast(Tuple[str, Type], field_info) + field_name, field_type = field_info + fields[field_name] = (field_type, FieldInfo()) + else: + raise ValueError( + f"Invalid additional field info: {field_info}. " + "Must be a tuple of length 2 or 3." + ) + + return create_model(name, **fields) # type: ignore + + +class FunctionToolWithContext(FunctionTool): + """ + A function tool that also includes passing in workflow context. + + Only overrides the call methods to include the context. + """ + + @classmethod + def from_defaults( + cls, + fn: Optional[Callable[..., Any]] = None, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + fn_schema: Optional[Type[BaseModel]] = None, + async_fn: Optional[AsyncCallable] = None, + tool_metadata: Optional[ToolMetadata] = None, + ) -> "FunctionToolWithContext": + if tool_metadata is None: + fn_to_parse = fn or async_fn + assert fn_to_parse is not None, "fn or async_fn must be provided." + name = name or fn_to_parse.__name__ + docstring = fn_to_parse.__doc__ + + # TODO: Very hacky way to remove the ctx parameter from the signature + signature_str = str(signature(fn_to_parse)) + signature_str = signature_str.replace( + "ctx: llama_index.core.workflow.context.Context, ", "" + ) + signature_str = signature_str.replace( + "ctx: llama_index.core.workflow.context.Context", "" + ) + + description = description or f"{name}{signature_str}\n{docstring}" + if fn_schema is None: + fn_schema = create_schema_from_function( + f"{name}", fn_to_parse, additional_fields=None + ) + tool_metadata = ToolMetadata( + name=name, + description=description, + fn_schema=fn_schema, + return_direct=return_direct, + ) + return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) + + def call(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: # type: ignore + """Call.""" + tool_output = self._fn(ctx, *args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) + + async def acall(self, ctx: Context, *args: Any, **kwargs: Any) -> ToolOutput: # type: ignore + """Call.""" + tool_output = await self._async_fn(ctx, *args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) diff --git a/llama-index-core/tests/agent/workflow/BUILD b/llama-index-core/tests/agent/workflow/BUILD new file mode 100644 index 0000000000000..57341b1358b56 --- /dev/null +++ b/llama-index-core/tests/agent/workflow/BUILD @@ -0,0 +1,3 @@ +python_tests( + name="tests", +) diff --git a/llama-index-core/tests/agent/workflow/__init__.py b/llama-index-core/tests/agent/workflow/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py b/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py new file mode 100644 index 0000000000000..340c98cdcdbc5 --- /dev/null +++ b/llama-index-core/tests/agent/workflow/test_multi_agent_workflow.py @@ -0,0 +1,275 @@ +from typing import Any, List +import pytest + +from llama_index.core.llms import MockLLM +from llama_index.core.agent.workflow.multi_agent_workflow import MultiAgentWorkflow +from llama_index.core.agent.workflow.function_agent import FunctionAgent +from llama_index.core.agent.workflow.react_agent import ReactAgent +from llama_index.core.llms import ( + ChatMessage, + ChatResponse, + MessageRole, + ChatResponseAsyncGen, + LLMMetadata, +) +from llama_index.core.tools import FunctionTool, ToolSelection +from llama_index.core.memory import ChatMemoryBuffer + + +class MockLLM(MockLLM): + def __init__(self, responses: List[ChatMessage]): + super().__init__() + self._responses = responses + self._response_index = 0 + + @property + def metadata(self) -> LLMMetadata: + return LLMMetadata(is_function_calling_model=True) + + async def astream_chat( + self, messages: List[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + response_msg = self._responses[self._response_index] + self._response_index = (self._response_index + 1) % len(self._responses) + + async def _gen(): + yield ChatResponse( + message=response_msg, + delta=response_msg.content, + raw={"content": response_msg.content}, + ) + + return _gen() + + async def astream_chat_with_tools( + self, tools: List[Any], chat_history: List[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + response_msg = self._responses[self._response_index] + self._response_index = (self._response_index + 1) % len(self._responses) + + async def _gen(): + yield ChatResponse( + message=response_msg, + delta=response_msg.content, + raw={"content": response_msg.content}, + ) + + return _gen() + + def get_tool_calls_from_response( + self, response: ChatResponse, **kwargs: Any + ) -> List[ToolSelection]: + return response.message.additional_kwargs.get("tool_calls", []) + + +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +def subtract(a: int, b: int) -> int: + """Subtract two numbers.""" + return a - b + + +@pytest.fixture() +def calculator_agent(): + return ReactAgent( + name="calculator", + description="Performs basic arithmetic operations", + system_prompt="You are a calculator assistant.", + tools=[ + FunctionTool.from_defaults(fn=add), + FunctionTool.from_defaults(fn=subtract), + ], + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, + content='Thought: I need to add these numbers\nAction: add\nAction Input: {"a": 5, "b": 3}\n', + ), + ChatMessage( + role=MessageRole.ASSISTANT, + content=r"Thought: The result is 8\Answer: The sum is 8", + ), + ] + ), + ) + + +@pytest.fixture() +def retriever_agent(): + return FunctionAgent( + name="retriever", + description="Manages data retrieval", + system_prompt="You are a retrieval assistant.", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, + content="Let me help you with that calculation. I'll hand this off to the calculator.", + additional_kwargs={ + "tool_calls": [ + ToolSelection( + tool_id="one", + tool_name="handoff", + tool_kwargs={ + "to_agent": "calculator", + "reason": "This requires arithmetic operations.", + }, + ) + ] + }, + ), + ], + ), + ) + + +@pytest.mark.asyncio() +async def test_basic_workflow(calculator_agent, retriever_agent): + """Test basic workflow initialization and validation.""" + workflow = MultiAgentWorkflow( + agents=[calculator_agent, retriever_agent], + ) + + assert workflow.root_agent == retriever_agent + assert len(workflow.agents) == 2 + assert "calculator" in workflow.agents + assert "retriever" in workflow.agents + + +@pytest.mark.asyncio() +async def test_workflow_requires_root_agent(): + """Test that workflow requires exactly one root agent.""" + with pytest.raises(ValueError, match="Exactly one root agent must be provided"): + MultiAgentWorkflow( + agents=[ + FunctionAgent( + name="agent1", + description="test", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage(role=MessageRole.ASSISTANT, content="test"), + ] + ), + ), + ReactAgent( + name="agent2", + description="test", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage(role=MessageRole.ASSISTANT, content="test"), + ] + ), + ), + ] + ) + + +@pytest.mark.asyncio() +async def test_workflow_execution(calculator_agent, retriever_agent): + """Test basic workflow execution with agent handoff.""" + workflow = MultiAgentWorkflow( + agents=[calculator_agent, retriever_agent], + ) + + memory = ChatMemoryBuffer.from_defaults() + handler = workflow.run(user_msg="Can you add 5 and 3?", memory=memory) + + events = [] + async for event in handler.stream_events(): + events.append(event) + + response = await handler + + # Verify we got events indicating handoff and calculation + assert any( + ev.current_agent_name == "retriever" + if hasattr(ev, "current_agent_name") + else False + for ev in events + ) + assert any( + ev.current_agent_name == "calculator" + if hasattr(ev, "current_agent_name") + else False + for ev in events + ) + assert "8" in response.response + + +@pytest.mark.asyncio() +async def test_invalid_handoff(): + """Test handling of invalid agent handoff.""" + agent1 = FunctionAgent( + name="agent1", + description="test", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, + content="handoff invalid_agent Because reasons", + additional_kwargs={ + "tool_calls": [ + ToolSelection( + tool_id="one", + tool_name="handoff", + tool_kwargs={ + "to_agent": "invalid_agent", + "reason": "Because reasons", + }, + ) + ] + }, + ), + ChatMessage(role=MessageRole.ASSISTANT, content="guess im stuck here"), + ], + ), + ) + + workflow = MultiAgentWorkflow( + agents=[agent1], + ) + + handler = workflow.run(user_msg="test") + events = [] + async for event in handler.stream_events(): + events.append(event) + + response = await handler + assert "Agent invalid_agent not found" in str(events) + + +@pytest.mark.asyncio() +async def test_workflow_with_state(): + """Test workflow with state management.""" + agent = FunctionAgent( + name="agent", + description="test", + is_entrypoint_agent=True, + llm=MockLLM( + responses=[ + ChatMessage( + role=MessageRole.ASSISTANT, content="Current state processed" + ) + ], + ), + ) + + workflow = MultiAgentWorkflow( + agents=[agent], + initial_state={"counter": 0}, + state_prompt="Current state: {state}. User message: {msg}", + ) + + handler = workflow.run(user_msg="test") + async for _ in handler.stream_events(): + pass + + response = await handler + assert response is not None diff --git a/llama-index-core/tests/workflow/test_context.py b/llama-index-core/tests/workflow/test_context.py index 06e4aa3568010..1cb35c2379075 100644 --- a/llama-index-core/tests/workflow/test_context.py +++ b/llama-index-core/tests/workflow/test_context.py @@ -1,3 +1,4 @@ +import asyncio from unittest import mock from typing import Union, Optional @@ -130,3 +131,42 @@ async def test_empty_inprogress_when_workflow_done(workflow): # there shouldn't be any in progress events for inprogress_list in h.ctx._in_progress.values(): assert len(inprogress_list) == 0 + + +@pytest.mark.asyncio() +async def test_wait_for_event(ctx): + wait_job = asyncio.create_task(ctx.wait_for_event(Event)) + await asyncio.sleep(0.01) + ctx.send_event(Event(msg="foo")) + ev = await wait_job + assert ev.msg == "foo" + + +@pytest.mark.asyncio() +async def test_wait_for_event_with_requirements(ctx): + wait_job = asyncio.create_task(ctx.wait_for_event(Event, {"msg": "foo"})) + await asyncio.sleep(0.01) + ctx.send_event(Event(msg="bar")) + ctx.send_event(Event(msg="foo")) + ev = await wait_job + assert ev.msg == "foo" + + +@pytest.mark.asyncio() +async def test_wait_for_event_in_workflow(): + class TestWorkflow(Workflow): + @step + async def step1(self, ctx: Context, ev: StartEvent) -> StopEvent: + ctx.write_event_to_stream(Event(msg="foo")) + result = await ctx.wait_for_event(Event) + return StopEvent(result=result.msg) + + workflow = TestWorkflow() + handler = workflow.run() + async for ev in handler.stream_events(): + if isinstance(ev, Event) and ev.msg == "foo": + handler.ctx.send_event(Event(msg="bar")) + break + + result = await handler + assert result == "bar"