Skip to content

Commit

Permalink
Improve parsing complex json strings returned by LLM (#989)
Browse files Browse the repository at this point in the history
- Improve escaping to load complex json objects
- Fall back to a more forgiving [json5](https://json5.org/) loader if `json.loads` cannot parse a complex JSON string

This should reduce failures by the agent when picking a research tool and running code
  • Loading branch information
debanjum authored Nov 28, 2024
2 parents 8cb0db0 + 8c120a5 commit f1190cc
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 8 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ dependencies = [
"anthropic == 0.26.1",
"docx2txt == 0.8",
"google-generativeai == 0.8.3",
"pyjson5 == 1.6.7",
]
dynamic = ["version"]

Expand Down
43 changes: 43 additions & 0 deletions src/khoj/processor/conversation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -14,6 +15,7 @@
from typing import Any, Callable, Dict, List, Optional

import PIL.Image
import pyjson5
import requests
import tiktoken
import yaml
Expand Down Expand Up @@ -538,6 +540,47 @@ def clean_code_python(code: str):
return code.strip().removeprefix("```python").removesuffix("```")


def load_complex_json(json_str):
    """
    Load a raw, possibly malformed JSON string into a Python object.

    The input is first passed through `clean_json`. If the cleaned string is
    already valid JSON it is loaded untouched. Otherwise unescaped double quotes
    and backslash-escaped single quotes inside string values are escaped, and the
    repaired string is loaded with `json.loads`, falling back to the more
    forgiving `pyjson5.loads`.

    Args:
        json_str: The raw JSON text to load. It is formatted to a string before cleaning.

    Returns:
        The object produced by the first loader that accepts the string.

    Raises:
        ValueError: If no loader can parse the repaired string. The message lists
            each loader's error and the repaired string that was attempted.
    """

    def replace_unescaped_quotes(match):
        # Text captured between the opening quote after ':' and the closing quote
        content = match.group(1)
        # Negative lookbehinds skip quotes that are already escaped.
        # Replace " with \"
        processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
        # Replace \' with \\' since \' is not a valid JSON escape sequence
        processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
        return f': "{processed_final}"'

    # Match a string value: ':' then optional whitespace, an opening quote, and
    # lazily anything up to an unescaped quote directly followed by ',' or '}'
    pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'

    cleaned = clean_json(f"{json_str}")

    # Well-formed JSON needs no repair. Load it as-is so the heuristic regex below
    # cannot mis-match across values (e.g. a string ending in ':') and corrupt it.
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Expected for malformed input; fall through to the repair path.
        pass

    # Escape stray quotes inside string values
    processed = re.sub(pattern, replace_unescaped_quotes, cleaned)

    # See which json loader can load the processed JSON as valid
    errors = []
    json_loaders_to_try = [json.loads, pyjson5.loads]
    for loads in json_loaders_to_try:
        try:
            return loads(processed)
        except (json.JSONDecodeError, pyjson5.Json5Exception) as e:
            errors.append(f"{type(e).__name__}: {str(e)}")

    # If all loaders fail, raise the aggregated error
    raise ValueError(
        f"Failed to load JSON with errors: {'; '.join(errors)}\n\n"
        f"While attempting to load this cleaned JSON:\n{processed}"
    )


def defilter_query(query: str):
"""Remove any query filters in query"""
defiltered_query = query
Expand Down
6 changes: 2 additions & 4 deletions src/khoj/processor/tools/run_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import datetime
import json
import logging
import mimetypes
import os
Expand All @@ -15,8 +14,8 @@
from khoj.processor.conversation.utils import (
ChatEvent,
clean_code_python,
clean_json,
construct_chat_history,
load_complex_json,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
Expand Down Expand Up @@ -135,8 +134,7 @@ async def generate_python_code(
)

# Load the response as a JSON object (dict) holding the code, input files and input links
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
code = response.get("code", "").strip()
input_files = response.get("input_files", [])
input_links = response.get("input_links", [])
Expand Down
6 changes: 2 additions & 4 deletions src/khoj/routers/research.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime
from typing import Callable, Dict, List, Optional
Expand All @@ -10,10 +9,10 @@
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
clean_json,
construct_chat_history,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
Expand Down Expand Up @@ -106,8 +105,7 @@ async def apick_next_tool(
return

try:
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_conversation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def test_truncate_single_large_question(self):
assert truncated_chat_history[0] != copy_big_chat_message


def test_load_complex_raw_json_string():
    """A raw JSON string mixing escaped and unescaped quotes loads to the expected dict."""
    # Arrange: the value mixes bare and backslash-escaped double and single quotes
    malformed_json = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
    expected = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}

    # Act
    loaded = utils.load_complex_json(malformed_json)

    # Assert
    assert loaded == expected


def generate_content(count):
    """Return the integers 0 to count - 1 joined into one space-separated string."""
    # range() already yields the indices, so no enumerate or intermediate list is needed
    return " ".join(str(index) for index in range(count))

Expand Down

0 comments on commit f1190cc

Please sign in to comment.