Skip to content

Commit

Permalink
Fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Dec 17, 2024
1 parent 7465690 commit a78fb53
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
49 changes: 36 additions & 13 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,19 +1015,32 @@ async def _handle_streamed_structured_response(
tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
parts: list[_messages.ModelRequestPart] = []

async for item in stream_structured_parts(model_response):
def handle_completed_part(p: ModelResponsePart | None) -> None:
if isinstance(p, _messages.ToolCallPart):
if tool := self._function_tools.get(p.tool_name):
tasks.append(asyncio.create_task(tool.run(deps, p, conv_messages), name=p.tool_name))
else:
parts.append(self._unknown_tool(p.tool_name))

last_part_index = 0
last_part: ModelResponsePart | None = None
async for part_index, part in _stream_structured_parts(model_response):
handled_any_parts = True
if self._result_schema and (match := self._result_schema.find_tool([item])):
if self._result_schema and (match := self._result_schema.find_tool([part])):
call, _ = match
for task in tasks:
# Abandon the execution of all tool calls and return the final result
task.cancel()
return _MarkFinalResult(model_response, call.tool_name)
elif isinstance(item, _messages.ToolCallPart):
call = item
if tool := self._function_tools.get(call.tool_name):
# NOTE: The next line starts the tool running, which is at odds with the
# logfire span created below saying that's where the tool is run.
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
else:
parts.append(self._unknown_tool(call.tool_name))

if part_index > last_part_index:
# Only process non-result parts when we've moved to a new part
handle_completed_part(last_part)

last_part_index = part_index
last_part = part

handle_completed_part(last_part)

if not handled_any_parts:
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
Expand Down Expand Up @@ -1124,19 +1137,29 @@ class _MarkFinalResult(Generic[ResultData_Co]):
"""Name of the final result tool, None if the result is a string."""


async def stream_structured_parts(model_response: StreamStructuredResponse) -> AsyncIterator[ModelResponsePart]:
async def _stream_structured_parts(
model_response: StreamStructuredResponse,
) -> AsyncIterator[tuple[int, ModelResponsePart]]:
"""Yields a tuple of [index, part] for each part in the model response as it is received.
You can assume that when the index increases, the previous part is complete.
The reason we do not stream only "completed" parts here is because we only want to iterate over completed parts
for function tool calls. For the final result, we want to iterate over all deltas. At this level of API,
we aren't aware of whether a part is part of the final response or not.
"""
last_part_index = 0
last_part: ModelResponsePart | None = None

def new_or_updated_parts() -> Iterator[ModelResponsePart]:
def new_or_updated_parts() -> Iterator[tuple[int, ModelResponsePart]]:
nonlocal last_part_index, last_part

structured_msg = model_response.get()
new_last_part_index = last_part_index
for i, part in enumerate(structured_msg.parts[last_part_index:]):
if i == 0 and part == last_part:
continue # this part was not updated
yield part
yield last_part_index + i, part
last_part = part
new_last_part_index = last_part_index + i
last_part_index = new_last_part_index
Expand Down
1 change: 0 additions & 1 deletion tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ async def stream_structured_function(
else:
last = messages[-1]
assert isinstance(last, ModelRequest)
print('\n', last.parts, '\n')
assert isinstance(last.parts[0], ToolReturnPart)
assert agent_info.result_tools is not None
assert len(agent_info.result_tools) == 1
Expand Down

0 comments on commit a78fb53

Please sign in to comment.