diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 639b937..c025195 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import json import os from collections.abc import AsyncIterator, Iterable from contextlib import asynccontextmanager @@ -39,7 +40,6 @@ ) try: - from json_repair import repair_json from mistralai import ( UNSET, CompletionChunk as MistralCompletionChunk, @@ -547,14 +547,14 @@ def get(self, *, final: bool = False) -> ModelResponse: elif self._delta_content and self._result_tools: # NOTE: Params set for the most efficient and fastest way. - output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True) - assert isinstance( - output_json, dict - ), f'Expected repair_json as type dict, invalid type: {type(output_json)}' + output_json = _repair_json(self._delta_content) + # assert isinstance( + # output_json, (dict, type(None)) + # ), f'Expected repair_json as type dict, invalid type: {type(output_json)}' - if output_json: + if isinstance(output_json, dict) and output_json: for result_tool in self._result_tools.values(): - # NOTE: Additional verification to prevent JSON validation to crash in `result.py` + # NOTE: Additional verification to prevent the JSON validation from crashing in `_result.py`. # Ensures required parameters in the JSON schema are respected, especially for stream-based return types. # For example, `return_type=list[str]` expects a 'response' key with value type array of str. 
# when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str) @@ -678,3 +678,65 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None result = None return result + + +def _repair_json(s: str) -> dict[str, Any] | list[Any] | None: + """Attempt to parse a given string as JSON, repairing common issues.""" + # Attempt to parse the string as-is. + try: + return json.loads(s, strict=False) + except json.JSONDecodeError: + pass + + new_chars: list[str] = [] + stack: list[Any] = [] + is_inside_string = False + escaped = False + + # Process each character in the string. + for char in s: + if is_inside_string: + if char == '"' and not escaped: + is_inside_string = False + elif char == '\n' and not escaped: + char = '\\n' # Replace newline with escape sequence. + elif char == '\\': + escaped = not escaped + else: + escaped = False + else: + if char == '"': + is_inside_string = True + escaped = False + elif char == '{': + stack.append('}') + elif char == '[': + stack.append(']') + elif char == '}' or char == ']': + if stack and stack[-1] == char: + stack.pop() + else: + # Mismatched closing character; the input is malformed. + return None + + # Append the processed character to the new string. + new_chars.append(char) + + # If we're still inside a string at the end of processing, close the string. + if is_inside_string: + new_chars.append('"') + + # Reverse the stack to get the closing characters. + stack.reverse() + + # Try to parse the modified string until we succeed or run out of characters. + while new_chars: + try: + value = ''.join(new_chars + stack) + return json.loads(value, strict=False) + except json.JSONDecodeError: + # If parsing fails, try removing the last character. + new_chars.pop() + + # If we still can't parse the string as JSON, return None. 
+ return None diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 5055008..10a4eeb 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -46,7 +46,7 @@ openai = ["openai>=1.54.3"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"] anthropic = ["anthropic>=0.40.0"] groq = ["groq>=0.12.0"] -mistral = ["mistralai>=1.2.5", "json-repair>=0.30.3"] +mistral = ["mistralai>=1.2.5"] logfire = ["logfire>=2.3"] [dependency-groups] diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 86532de..ac46faa 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -544,7 +544,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque ##################### -async def test_stream_structured_with_all_typd(allow_model_requests: None): +async def test_stream_structured_with_all_type(allow_model_requests: None): class MyTypedDict(TypedDict, total=False): first: str second: int @@ -563,19 +563,19 @@ class MyTypedDict(TypedDict, total=False): '", "second": 2', ), text_chunk( - '", "bool_value": true', + ', "bool_value": true', ), text_chunk( - '", "nullable_value": null', + ', "nullable_value": null', ), text_chunk( - '", "array_value": ["A", "B", "C"]', + ', "array_value": ["A", "B", "C"]', ), text_chunk( - '", "dict_value": {"A": "A", "B":"B"}', + ', "dict_value": {"A": "A", "B":"B"}', ), text_chunk( - '", "dict_int_value": {"A": 1, "B":2}', + ', "dict_int_value": {"A": 1, "B":2}', ), text_chunk('}'), chunk([]), @@ -721,8 +721,8 @@ class MyTypedDict(TypedDict, total=False): {'first': 'One'}, {'first': 'One'}, {'first': 'One'}, - {'first': 'One', 'second': ''}, - {'first': 'One', 'second': ''}, + {'first': 'One'}, + {'first': 'One'}, {'first': 'One', 'second': ''}, {'first': 'One', 'second': 'T'}, {'first': 'One', 'second': 'Tw'}, @@ -828,6 +828,7 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None): v = [c async for c in 
result.stream(debounce_by=None)] assert v == snapshot( [ + [''], ['f'], ['fi'], ['fir'], @@ -835,13 +836,13 @@ ['first'], ['first'], ['first'], - ['first'], + ['first', ''], ['first', 'O'], ['first', 'On'], ['first', 'One'], ['first', 'One'], ['first', 'One'], - ['first', 'One'], + ['first', 'One', ''], ['first', 'One', 's'], ['first', 'One', 'se'], ['first', 'One', 'sec'], @@ -850,7 +851,7 @@ ['first', 'One', 'second'], ['first', 'One', 'second'], ['first', 'One', 'second'], - ['first', 'One', 'second'], + ['first', 'One', 'second', ''], ['first', 'One', 'second', 'T'], ['first', 'One', 'second', 'Tw'], ['first', 'One', 'second', 'Two'], @@ -869,10 +870,10 @@ assert result.cost().response_tokens == len(stream) -async def test_stream_result_type_basemodel(allow_model_requests: None): +async def test_stream_result_type_basemodel_with_default_params(allow_model_requests: None): class MyTypedBaseModel(BaseModel): - first: str = '' # Note: Don't forget to set default values - second: str = '' + first: str = '' # Note: a default value is set. + second: str = '' # Note: a default value is set. 
# Given stream = [ @@ -958,6 +959,79 @@ class MyTypedBaseModel(BaseModel): assert result.cost().response_tokens == len(stream) +async def test_stream_result_type_basemodel_with_required_params(allow_model_requests: None): + class MyTypedBaseModel(BaseModel): + first: str # Note: Required params + second: str # Note: Required params + + # Given + stream = [ + text_chunk('{'), + text_chunk('"'), + text_chunk('f'), + text_chunk('i'), + text_chunk('r'), + text_chunk('s'), + text_chunk('t'), + text_chunk('"'), + text_chunk(':'), + text_chunk(' '), + text_chunk('"'), + text_chunk('O'), + text_chunk('n'), + text_chunk('e'), + text_chunk('"'), + text_chunk(','), + text_chunk(' '), + text_chunk('"'), + text_chunk('s'), + text_chunk('e'), + text_chunk('c'), + text_chunk('o'), + text_chunk('n'), + text_chunk('d'), + text_chunk('"'), + text_chunk(':'), + text_chunk(' '), + text_chunk('"'), + text_chunk('T'), + text_chunk('w'), + text_chunk('o'), + text_chunk('"'), + text_chunk('}'), + chunk([]), + ] + + mock_client = MockMistralAI.create_stream_mock(stream) + model = MistralModel('mistral-large-latest', client=mock_client) + agent = Agent(model=model, result_type=MyTypedBaseModel) + + # When + async with agent.run_stream('User prompt value') as result: + # Then + assert result.is_structured + assert not result.is_complete + v = [c async for c in result.stream(debounce_by=None)] + assert v == snapshot( + [ + MyTypedBaseModel(first='One', second=''), + MyTypedBaseModel(first='One', second='T'), + MyTypedBaseModel(first='One', second='Tw'), + MyTypedBaseModel(first='One', second='Two'), + MyTypedBaseModel(first='One', second='Two'), + MyTypedBaseModel(first='One', second='Two'), + MyTypedBaseModel(first='One', second='Two'), + ] + ) + assert result.is_complete + assert result.cost().request_tokens == 34 + assert result.cost().response_tokens == 34 + assert result.cost().total_tokens == 34 + + # double check cost matches stream count + assert result.cost().response_tokens == 
len(stream) + + ##################### ## Completion Function call ##################### diff --git a/uv.lock b/uv.lock index 1c8e956..530d8de 100644 --- a/uv.lock +++ b/uv.lock @@ -820,15 +820,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/75/fc5a34b0376437eaac80c22886840d8f39ee7f0992c2e3bd4c246b91cab3/jiter-0.7.1-cp39-none-win_amd64.whl", hash = "sha256:6592f4067c74176e5f369228fb2995ed01400c9e8e1225fb73417183a5e635f0", size = 202098 }, ] -[[package]] -name = "json-repair" -version = "0.30.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2f/7a/7745d0d908563a478421c7520649dfd6a5c551858e2233ff7caf20cb8df7/json_repair-0.30.3.tar.gz", hash = "sha256:0ac56e7ae9253ee9c507a7e1a3a26799c9b0bbe5e2bec1b2cc5053e90d5b05e3", size = 27803 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/2d/79a46330c4b97ee90dd403fb0d267da7b25b24d7db604c5294e5c57d5f7c/json_repair-0.30.3-py3-none-any.whl", hash = "sha256:63bb588162b0958ae93d85356ecbe54c06b8c33f8a4834f93fa2719ea669804e", size = 18951 }, -] - [[package]] name = "jsonpath-python" version = "1.0.6" @@ -1677,7 +1668,6 @@ logfire = [ { name = "logfire" }, ] mistral = [ - { name = "json-repair" }, { name = "mistralai" }, ] openai = [ @@ -1709,7 +1699,6 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.12.0" }, { name = "httpx", specifier = ">=0.27.2" }, - { name = "json-repair", marker = "extra == 'mistral'", specifier = ">=0.30.3" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "logfire-api", specifier = ">=1.2.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" },