
Attempted improvements to "fix streamed response messages" #278

Open · wants to merge 7 commits into main

Conversation

dmontagu (Contributor)

This gets one of the new tests passing, and I think it conceptually improves the state of that PR, but now another test is failing and I'm not sure what that test's intended behavior was or whether the new output makes sense.

In particular, tests.test_streaming.test_call_tool fails, and I can't tell from the failure what's going wrong.

Comment on lines 1018 to 1030
async for item in stream_structured_parts(model_response):
    handled_any_parts = True
    if self._result_schema and (match := self._result_schema.find_tool([item])):
        call, _ = match
        return _MarkFinalResult(model_response, call.tool_name)
    elif isinstance(item, _messages.ToolCallPart):
        call = item
        if tool := self._function_tools.get(call.tool_name):
            # NOTE: The next line starts the tool running, which is at odds with the
            # logfire span created below saying that's where the tool is run.
            tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
        else:
            parts.append(self._unknown_tool(call.tool_name))
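The NOTE in the hunk above points at a real subtlety: `asyncio.create_task` schedules the tool coroutine immediately, so it may start running before any span that wraps the later `gather`. A minimal, self-contained sketch of that eager-start pattern, using a hypothetical `run_tool` as a stand-in for `tool.run(...)`:

```python
import asyncio


async def run_tool(name: str) -> str:
    # Hypothetical stand-in for Tool.run(deps, call, conv_messages).
    await asyncio.sleep(0)
    return f'{name} result'


async def main() -> list[str]:
    tasks: list[asyncio.Task[str]] = []
    for call_name in ['get_weather', 'get_time']:
        # create_task starts the coroutine as soon as the event loop yields,
        # i.e. possibly before any instrumentation wrapping the gather below.
        tasks.append(asyncio.create_task(run_tool(call_name), name=call_name))
    return list(await asyncio.gather(*tasks))


print(asyncio.run(main()))
```

This is why a span opened around the `gather` can misattribute when the tool actually began running.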
Contributor Author:
This is one of the interesting bits of this PR.

Comment on lines 1127 to 1152
async def stream_structured_parts(model_response: StreamStructuredResponse) -> AsyncIterator[ModelResponsePart]:
    last_part_index = 0
    last_part: ModelResponsePart | None = None

    def new_or_updated_parts() -> Iterator[ModelResponsePart]:
        nonlocal last_part_index, last_part

        structured_msg = model_response.get()
        new_last_part_index = last_part_index
        for i, part in enumerate(structured_msg.parts[last_part_index:]):
            if i == 0 and part == last_part:
                continue  # this part was not updated
            yield part
            last_part = part
            new_last_part_index = last_part_index + i
        last_part_index = new_last_part_index

    while True:
        for p in new_or_updated_parts():
            yield p
        try:
            await model_response.__anext__()
        except StopAsyncIteration:
            break
        for p in new_or_updated_parts():
            yield p
dmontagu (Contributor Author) commented Dec 16, 2024:
This is the other interesting bit of this PR: I've added a function that streams the changes to parts coming out of the model response. Once we have this, we can handle the parts one at a time above.
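The diffing idea can be sketched in isolation. The following is a simplified, runnable version of the same logic, with a hypothetical `FakeStreamedResponse` standing in for `StreamStructuredResponse` and plain strings standing in for `ModelResponsePart` (none of these stand-ins are part of pydantic-ai):

```python
import asyncio
from typing import AsyncIterator, Iterator


class FakeStreamedResponse:
    """Stand-in: each snapshot is the full parts list after one stream chunk."""

    def __init__(self, snapshots: list[list[str]]):
        self._snapshots = snapshots
        self._i = 0

    def get(self) -> list[str]:
        return self._snapshots[self._i]

    async def __anext__(self) -> None:
        if self._i + 1 >= len(self._snapshots):
            raise StopAsyncIteration
        self._i += 1


async def stream_new_or_updated(response: FakeStreamedResponse) -> AsyncIterator[str]:
    last_index = 0
    last_part: str | None = None

    def diff() -> Iterator[str]:
        nonlocal last_index, last_part
        parts = response.get()
        new_index = last_index
        for i, part in enumerate(parts[last_index:]):
            if i == 0 and part == last_part:
                continue  # first re-examined part is unchanged; skip it
            yield part
            last_part = part
            new_index = last_index + i
        last_index = new_index

    while True:
        for p in diff():
            yield p
        try:
            await response.__anext__()
        except StopAsyncIteration:
            break
        for p in diff():
            yield p


async def main() -> list[str]:
    # Part 0 arrives in two deltas ('hel' then 'hello'), then part 1 appears.
    resp = FakeStreamedResponse([['hel'], ['hello'], ['hello', 'world']])
    return [p async for p in stream_new_or_updated(resp)]


print(asyncio.run(main()))  # ['hel', 'hello', 'world']
```

Each call to `diff()` re-reads the accumulated state and yields only parts that are new or changed since the last call, which is what lets the consumer above handle one part at a time.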

Base automatically changed from streamed-response-messages to main December 16, 2024 20:54
cloudflare-workers-and-pages bot commented Dec 17, 2024:

Deploying pydantic-ai with Cloudflare Pages

Latest commit: a78fb53
Status: ✅ Deploy successful!
Preview URL: https://b6791f89.pydantic-ai.pages.dev
Branch Preview URL: https://dmontagu-streamed-response-m.pydantic-ai.pages.dev

2 participants