Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Mistral optimised #396

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import json
import os
from collections.abc import AsyncIterator, Iterable
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -39,7 +40,6 @@
)

try:
from json_repair import repair_json
from mistralai import (
UNSET,
CompletionChunk as MistralCompletionChunk,
Expand Down Expand Up @@ -547,14 +547,14 @@ def get(self, *, final: bool = False) -> ModelResponse:

elif self._delta_content and self._result_tools:
# NOTE: Params set for the most efficient and fastest way.
output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
assert isinstance(
output_json, dict
), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
output_json = _repair_json(self._delta_content)
# assert isinstance(
# output_json, (dict, type(None))
# ), f'Expected repair_json as type dict, invalid type: {type(output_json)}'

if output_json:
if isinstance(output_json, dict) and output_json:
for result_tool in self._result_tools.values():
# NOTE: Additional verification to prevent JSON validation to crash in `result.py`
# NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
# For example, `return_type=list[str]` expects a 'response' key with value type array of str.
# when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
Expand Down Expand Up @@ -678,3 +678,65 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
result = None

return result


def _repair_json(s: str) -> dict[str, Any] | list[Any] | None:
"""Attempt to parse a given string as JSON, repairing common issues."""
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=False)
except json.JSONDecodeError:
pass

new_chars: list[str] = []
stack: list[Any] = []
is_inside_string = False
escaped = False

# Process each character in the string.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == '\n' and not escaped:
char = '\\n' # Replace newline with escape sequence.
elif char == '\\':
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == '{':
stack.append('}')
elif char == '[':
stack.append(']')
elif char == '}' or char == ']':
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None

# Append the processed character to the new string.
new_chars.append(char)

# If we're still inside a string at the end of processing, close the string.
if is_inside_string:
new_chars.append('"')

# Reverse the stack to get the closing characters.
stack.reverse()

# Try to parse the modified string until we succeed or run out of characters.
while new_chars:
try:
value = ''.join(new_chars + stack)
return json.loads(value, strict=False)
except json.JSONDecodeError:
# If parsing fails, try removing the last character.
new_chars.pop()

# If we still can't parse the string as JSON, return None.
return None
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ openai = ["openai>=1.54.3"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.40.0"]
groq = ["groq>=0.12.0"]
mistral = ["mistralai>=1.2.5", "json-repair>=0.30.3"]
mistral = ["mistralai>=1.2.5"]
logfire = ["logfire>=2.3"]

[dependency-groups]
Expand Down
102 changes: 88 additions & 14 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
#####################


async def test_stream_structured_with_all_typd(allow_model_requests: None):
async def test_stream_structured_with_all_type(allow_model_requests: None):
class MyTypedDict(TypedDict, total=False):
first: str
second: int
Expand All @@ -563,19 +563,19 @@ class MyTypedDict(TypedDict, total=False):
'", "second": 2',
),
text_chunk(
'", "bool_value": true',
', "bool_value": true',
),
text_chunk(
'", "nullable_value": null',
', "nullable_value": null',
),
text_chunk(
'", "array_value": ["A", "B", "C"]',
', "array_value": ["A", "B", "C"]',
),
text_chunk(
'", "dict_value": {"A": "A", "B":"B"}',
', "dict_value": {"A": "A", "B":"B"}',
),
text_chunk(
'", "dict_int_value": {"A": 1, "B":2}',
', "dict_int_value": {"A": 1, "B":2}',
),
text_chunk('}'),
chunk([]),
Expand Down Expand Up @@ -721,8 +721,8 @@ class MyTypedDict(TypedDict, total=False):
{'first': 'One'},
{'first': 'One'},
{'first': 'One'},
{'first': 'One', 'second': ''},
{'first': 'One', 'second': ''},
{'first': 'One'},
{'first': 'One'},
{'first': 'One', 'second': ''},
{'first': 'One', 'second': 'T'},
{'first': 'One', 'second': 'Tw'},
Expand Down Expand Up @@ -828,20 +828,21 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
v = [c async for c in result.stream(debounce_by=None)]
assert v == snapshot(
[
[''],
['f'],
['fi'],
['fir'],
['firs'],
['first'],
['first'],
['first'],
['first'],
['first', ''],
['first', 'O'],
['first', 'On'],
['first', 'One'],
['first', 'One'],
['first', 'One'],
['first', 'One'],
['first', 'One', ''],
['first', 'One', 's'],
['first', 'One', 'se'],
['first', 'One', 'sec'],
Expand All @@ -850,7 +851,7 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
['first', 'One', 'second'],
['first', 'One', 'second'],
['first', 'One', 'second'],
['first', 'One', 'second'],
['first', 'One', 'second', ''],
['first', 'One', 'second', 'T'],
['first', 'One', 'second', 'Tw'],
['first', 'One', 'second', 'Two'],
Expand All @@ -869,10 +870,10 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
assert result.cost().response_tokens == len(stream)


async def test_stream_result_type_basemodel(allow_model_requests: None):
async def test_stream_result_type_basemodel_with_default_params(allow_model_requests: None):
class MyTypedBaseModel(BaseModel):
first: str = '' # Note: Don't forget to set default values
second: str = ''
first: str = '' # Note: Default, set value.
second: str = '' # Note: Default, set value.

# Given
stream = [
Expand Down Expand Up @@ -958,6 +959,79 @@ class MyTypedBaseModel(BaseModel):
assert result.cost().response_tokens == len(stream)


async def test_stream_result_type_basemodel_with_required_params(allow_model_requests: None):
    class MyTypedBaseModel(BaseModel):
        first: str  # Note: Required params
        second: str  # Note: Required params

    # Given: the complete JSON payload, delivered one character per chunk,
    # followed by a terminating empty chunk.
    payload = '{"first": "One", "second": "Two"}'
    stream = [text_chunk(c) for c in payload]
    stream.append(chunk([]))

    mock_client = MockMistralAI.create_stream_mock(stream)
    model = MistralModel('mistral-large-latest', client=mock_client)
    agent = Agent(model=model, result_type=MyTypedBaseModel)

    # When
    async with agent.run_stream('User prompt value') as result:
        # Then
        assert result.is_structured
        assert not result.is_complete
        outputs = [item async for item in result.stream(debounce_by=None)]
        assert outputs == snapshot(
            [
                MyTypedBaseModel(first='One', second=''),
                MyTypedBaseModel(first='One', second='T'),
                MyTypedBaseModel(first='One', second='Tw'),
                MyTypedBaseModel(first='One', second='Two'),
                MyTypedBaseModel(first='One', second='Two'),
                MyTypedBaseModel(first='One', second='Two'),
                MyTypedBaseModel(first='One', second='Two'),
            ]
        )
        assert result.is_complete
        assert result.cost().request_tokens == 34
        assert result.cost().response_tokens == 34
        assert result.cost().total_tokens == 34

        # double check cost matches stream count
        assert result.cost().response_tokens == len(stream)


#####################
## Completion Function call
#####################
Expand Down
11 changes: 0 additions & 11 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading