diff --git a/dbgpt/util/json_utils.py b/dbgpt/util/json_utils.py index f5e41f431..4ba081c1b 100644 --- a/dbgpt/util/json_utils.py +++ b/dbgpt/util/json_utils.py @@ -48,6 +48,7 @@ def find_json_objects(text): escape_character = False stack = [] start_index = -1 + modified_text = list(text) # Convert text to a list for easy modification for i, char in enumerate(text): # Handle escape characters @@ -59,12 +60,12 @@ def find_json_objects(text): if char == '"' and not escape_character: inside_string = not inside_string - if not inside_string and char == "\n": - continue - if inside_string and char == "\n": - char = "\\n" - if inside_string and char == "\t": - char = "\\t" + # Replace newline and tab characters inside strings + if inside_string: + if char == "\n": + modified_text[i] = "\\n" + elif char == "\t": + modified_text[i] = "\\t" # Handle opening brackets if char in "{[" and not inside_string: @@ -78,7 +79,8 @@ def find_json_objects(text): if not stack: end_index = i + 1 try: - json_obj = json.loads(text[start_index:end_index]) + json_str = "".join(modified_text[start_index:end_index]) + json_obj = json.loads(json_str) json_objects.append(json_obj) except json.JSONDecodeError: pass diff --git a/dbgpt/util/tests/test_json_utils.py b/dbgpt/util/tests/test_json_utils.py new file mode 100644 index 000000000..414290a62 --- /dev/null +++ b/dbgpt/util/tests/test_json_utils.py @@ -0,0 +1,64 @@ +import pytest + +from dbgpt.util.json_utils import find_json_objects + +# Define the parameterized test data +test_data = [ + ( + """ + ```json + + { + "serial_number": "1", + "agent": "CodeOptimizer", + "content": "```json +select * +from table +where column = 'value' +``` optimize the code above.", + "rely": "" + } + ``` + """, + [ + { + "serial_number": "1", + "agent": "CodeOptimizer", + "content": "```json\nselect * \nfrom table\nwhere column = 'value'\n``` optimize the code above.", + "rely": "", + } + ], + "Test case with nested code block", + ), + ( + """ + { + "key": "value" + } + """, + [{"key": 
"value"}], + "Test case with simple JSON", + ), + ( + """ + { + "key1": "value1" + } + { + "key2": "value2" + } + """, + [{"key1": "value1"}, {"key2": "value2"}], + "Test case with multiple JSON objects", + ), + ("", [], "Test case with empty input"), + ("This is not a JSON string", [], "Test case with non-JSON input"), +] + + +@pytest.mark.parametrize("text, expected, description", test_data) +def test_find_json_objects(text, expected, description): + result = find_json_objects(text) + assert ( + result == expected + ), f"Test failed: {description}\nExpected: {expected}\nGot: {result}"