Attempted improvements to "fix streamed response messages" #278
base: main
Conversation
async for item in stream_structured_parts(model_response):
    handled_any_parts = True
    if self._result_schema and (match := self._result_schema.find_tool([item])):
        call, _ = match
        return _MarkFinalResult(model_response, call.tool_name)
    elif isinstance(item, _messages.ToolCallPart):
        call = item
        if tool := self._function_tools.get(call.tool_name):
            # NOTE: The next line starts the tool running, which is at odds with the
            # logfire span created below saying that's where the tool is run.
            tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
        else:
            parts.append(self._unknown_tool(call.tool_name))
this is one of the interesting bits of this PR
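For context on the NOTE in the snippet above, here is a minimal, self-contained sketch (not part of the PR; fake_tool and the sleeps are illustrative only) of why asyncio.create_task is at odds with a span opened later: the task is scheduled immediately and can begin running as soon as control returns to the event loop, before the task itself is awaited.

import asyncio


async def fake_tool() -> str:
    print('tool started')  # runs as soon as the event loop regains control
    await asyncio.sleep(0)
    return 'tool result'


async def main() -> None:
    task = asyncio.create_task(fake_tool())  # the tool is scheduled to run from this point on
    await asyncio.sleep(0)  # yield to the event loop: 'tool started' prints here
    # a tracing span opened only around the await below would miss the tool's actual start
    print(await task)  # prints 'tool result'


asyncio.run(main())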
async def stream_structured_parts(model_response: StreamStructuredResponse) -> AsyncIterator[ModelResponsePart]:
    last_part_index = 0
    last_part: ModelResponsePart | None = None

    def new_or_updated_parts() -> Iterator[ModelResponsePart]:
        nonlocal last_part_index, last_part

        structured_msg = model_response.get()
        new_last_part_index = last_part_index
        for i, part in enumerate(structured_msg.parts[last_part_index:]):
            if i == 0 and part == last_part:
                continue  # this part was not updated
            yield part
            last_part = part
            new_last_part_index = last_part_index + i
        last_part_index = new_last_part_index

    while True:
        for p in new_or_updated_parts():
            yield p
        try:
            await model_response.__anext__()
        except StopAsyncIteration:
            break
        for p in new_or_updated_parts():
            yield p
This is the other interesting bit of this PR — the idea is that I've added a function that streams the changes to parts coming out of the model response. Once we have this, we can just handle them one at a time above.
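To make the diffing behaviour concrete, here is a minimal, self-contained sketch of the same idea over plain strings (diff_parts and the example snapshots are hypothetical, not from the codebase): each pass compares the current snapshot of the parts list with what was last seen and yields only parts that are new or have changed.

from typing import Iterator


def diff_parts(snapshots: list[list[str]]) -> Iterator[str]:
    last_index = 0
    last_part: str | None = None
    for parts in snapshots:  # each snapshot is the full parts list at some point in the stream
        new_last_index = last_index
        for i, part in enumerate(parts[last_index:]):
            if i == 0 and part == last_part:
                continue  # the part at last_index is unchanged since the previous snapshot
            yield part
            last_part = part
            new_last_index = last_index + i
        last_index = new_last_index


snapshots = [
    ['hel'],                      # first chunk of a text part
    ['hello'],                    # the same part, updated
    ['hello', 'tool_call(x=1)'],  # a new part appears
]
print(list(diff_parts(snapshots)))  # ['hel', 'hello', 'tool_call(x=1)']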
This gets one of the new tests passing, and I think it improves the state of that PR conceptually, but now another test is failing and I'm not sure what the intended behavior of that test was or whether the new output makes sense.
In particular, tests.test_streaming.test_call_tool fails in a way that makes it hard to tell what's going wrong.