Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve parsing complex json strings returned by LLM #989

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ dependencies = [
"anthropic == 0.26.1",
"docx2txt == 0.8",
"google-generativeai == 0.8.3",
"pyjson5 == 1.6.7",
]
dynamic = ["version"]

Expand Down
43 changes: 43 additions & 0 deletions src/khoj/processor/conversation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mimetypes
import os
import queue
import re
import uuid
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -14,6 +15,7 @@
from typing import Any, Callable, Dict, List, Optional

import PIL.Image
import pyjson5
import requests
import tiktoken
import yaml
Expand Down Expand Up @@ -538,6 +540,47 @@ def clean_code_python(code: str):
return code.strip().removeprefix("```python").removesuffix("```")


def load_complex_json(json_str):
"""
Preprocess a raw JSON string to escape unescaped double quotes within value strings,
while preserving the JSON structure and already escaped quotes.
"""

def replace_unescaped_quotes(match):
# Get the content between colons and commas/end braces
content = match.group(1)
# Replace unescaped double, single quotes that aren't already escaped
# Uses negative lookbehind to avoid replacing already escaped quotes
# Replace " with \"
processed_dq = re.sub(r'(?<!\\)"', '\\"', content)
# Replace \' with \\'
processed_final = re.sub(r"(?<!\\)\\'", r"\\\\'", processed_dq)
return f': "{processed_final}"'

# Match content between : and either , or }
# This pattern looks for ': ' followed by any characters until , or }
pattern = r':\s*"(.*?)(?<!\\)"(?=[,}])'

# Process the JSON string
cleaned = clean_json(rf"{json_str}")
processed = re.sub(pattern, replace_unescaped_quotes, cleaned)

# See which json loader can load the processed JSON as valid
errors = []
json_loaders_to_try = [json.loads, pyjson5.loads]
for loads in json_loaders_to_try:
try:
return loads(processed)
except (json.JSONDecodeError, pyjson5.Json5Exception) as e:
errors.append(f"{type(e).__name__}: {str(e)}")

# If all loaders fail, raise the aggregated error
raise ValueError(
f"Failed to load JSON with errors: {'; '.join(errors)}\n\n"
f"While attempting to load this cleaned JSON:\n{processed}"
)


def defilter_query(query: str):
"""Remove any query filters in query"""
defiltered_query = query
Expand Down
6 changes: 2 additions & 4 deletions src/khoj/processor/tools/run_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import datetime
import json
import logging
import mimetypes
import os
Expand All @@ -15,8 +14,8 @@
from khoj.processor.conversation.utils import (
ChatEvent,
clean_code_python,
clean_json,
construct_chat_history,
load_complex_json,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import is_none_or_empty, timer, truncate_code_context
Expand Down Expand Up @@ -135,8 +134,7 @@ async def generate_python_code(
)

# Validate that the response is a non-empty, JSON-serializable list
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
code = response.get("code", "").strip()
input_files = response.get("input_files", [])
input_links = response.get("input_links", [])
Expand Down
6 changes: 2 additions & 4 deletions src/khoj/routers/research.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime
from typing import Callable, Dict, List, Optional
Expand All @@ -10,10 +9,10 @@
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
clean_json,
construct_chat_history,
construct_iteration_history,
construct_tool_chat_history,
load_complex_json,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
Expand Down Expand Up @@ -106,8 +105,7 @@ async def apick_next_tool(
return

try:
response = clean_json(response)
response = json.loads(response)
response = load_complex_json(response)
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_conversation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def test_truncate_single_large_question(self):
assert truncated_chat_history[0] != copy_big_chat_message


def test_load_complex_raw_json_string():
# Arrange
raw_json = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
expeced_json = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}

# Act
parsed_json = utils.load_complex_json(raw_json)

# Assert
assert parsed_json == expeced_json


def generate_content(count):
return " ".join([f"{index}" for index, _ in enumerate(range(count))])

Expand Down