forked from samholt/HDTwinGen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
summarizer_and_modifier.py
562 lines (455 loc) · 22.5 KB
/
summarizer_and_modifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
import os
import re
import unittest
import tempfile
import shutil
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from pathlib import Path
from long_vars import extension_to_language
from llm_utils import embedding_rl
import pandas as pd
import shelve
import numpy as np
import ast
import shutil
import subprocess
from time import perf_counter
import json
from timeout_decorator import timeout
def cache_or_compute(variable_key, computation_function, *args, cache_file="project_cache"):
    """
    Look a value up in a shelve-backed cache, computing and storing it on a miss.

    Parameters:
    - variable_key (str): Key under which the value is looked up and stored.
    - computation_function (function): Called as computation_function(*args) when the key is absent.
    - *args: Positional arguments forwarded to computation_function.
    - cache_file (str): Path handed to shelve.open (keyword-only, since it follows *args).

    Returns:
    - result: The cached value on a hit; otherwise the freshly computed value, which is also stored.
    """
    with shelve.open(cache_file) as store:
        # Miss: compute, persist, and hand back the object that was just computed.
        if variable_key not in store:
            computed = computation_function(*args)
            store[variable_key] = computed
            return computed
        # Hit: return whatever the shelf holds for this key.
        return store[variable_key]
# =================================================================================================
# Summarizer
def load_code_files_into_dict(folder_path, number_lines=False, file_extensions=None, skip_tests=False, return_all_files=False):
    """
    Walk folder_path and load every readable UTF-8 file into a dictionary.

    Parameters:
    - folder_path: Root directory to scan.
    - number_lines (bool): When True, each line is prefixed with "<1-based index>: ".
    - file_extensions (list): Suffixes (including the leading dot) to load; None selects the built-in list below.
    - skip_tests (bool): When True, skip files whose name, or whose directory relative to folder_path,
      contains "test" (case-insensitive).
    - return_all_files (bool): When True, the extension filter is bypassed.

    Returns:
    - dict mapping each file's path relative to folder_path (str) to its list of lines
      (file content split on newline characters). Files that are not valid UTF-8 are skipped silently;
      files that cannot be opened are reported with print and skipped.
    """
    root_path = Path(folder_path)
    if file_extensions is None:
        code_file_extensions = ['.py', '.c', '.cpp', '.h', '.java', '.js', '.ts', '.html', '.css', '.txt', '.xml', '.sql']
    else:
        code_file_extensions = file_extensions
    code_files_dict = {}
    for current_path, dir_names, files in os.walk(root_path):
        # Prune directories named exactly ".git" in place so os.walk never descends into them.
        # (A substring test on the full path would also drop ".github" or a root path containing ".git".)
        dir_names[:] = [name for name in dir_names if name != '.git']
        # Judge "test" directories relative to the root, so that a root located under a path
        # such as /tmp/test_run does not cause every file to be skipped.
        relative_dir = Path(current_path).relative_to(root_path)
        dir_is_test = 'test' in str(relative_dir).lower()
        for file in files:
            if not return_all_files and Path(file).suffix not in code_file_extensions:
                continue  # Skip non-code files
            if skip_tests and (dir_is_test or 'test' in file.lower()):
                continue
            file_path = Path(current_path) / file
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    lines = f.read().split('\n')
            except UnicodeDecodeError:
                # Skip non-UTF-8 encoded files
                continue
            except OSError as e:
                print(f"Error reading {file_path}: {e}")
                continue
            if number_lines:
                lines = [f"{idx}: {line}" for idx, line in enumerate(lines, start=1)]
            code_files_dict[str(file_path.relative_to(root_path))] = lines
    return code_files_dict
def count_total_characters(data):
    """
    Sum the character counts of all values in a file dictionary.

    List values contribute the combined length of their items; string values
    contribute their own length; values of any other type are ignored.
    """
    char_count = 0
    for entry in data.values():
        if isinstance(entry, list):
            # A list of lines: add up the length of every line.
            char_count += sum(len(row) for row in entry)
        elif isinstance(entry, str):
            # A plain string value counts as-is.
            char_count += len(entry)
    return char_count
def count_key_characters(data):
    """Return the combined length of all keys of the dictionary."""
    return sum(map(len, data.keys()))
def extract_file_names(s):
    """
    Extract the unique ".ts" file paths mentioned in a piece of text.

    A path is a run of word characters and "/" followed by ".ts". The trailing
    negative lookahead stops a longer suffix such as ".tsx" from being truncated
    into a ".ts" name that never appeared in the text.

    Returns:
    - list of unique matches in order of first appearance (deterministic across runs).
    """
    file_name_pattern = r'([\w/]+\.ts)(?!\w)'
    matches = re.findall(file_name_pattern, s)
    # dict.fromkeys de-duplicates while keeping first-seen order, unlike set().
    return list(dict.fromkeys(matches))
def format_files_into_prompt_format(file_dict, unique_file_names=None, character_limit=14_000):
    """
    Render files as fenced code sections for a prompt, within a character budget.

    Parameters:
    - file_dict: Dictionary with file names as keys and lists of lines as values.
    - unique_file_names (optional): Iterable of file names to render; defaults to every key of file_dict.
      Names missing from file_dict are rendered with empty content.
    - character_limit: A section is appended only if the output stays strictly below this many characters.

    Returns:
    - (output, files_missed_out): the concatenated sections as one string, and the list of
      file names whose section did not fit in the budget.
    """
    # If unique_file_names is not provided, process all file names in the dictionary
    if unique_file_names is None:
        unique_file_names = file_dict.keys()
    files_missed_out = []
    sections = []
    used_chars = 0
    for file_name in unique_file_names:
        # Take the extension from the final path component only, and use '' when there is
        # no dot at all (slicing from rfind('.') == -1 would yield the last character).
        _, dot, suffix = os.path.basename(file_name).rpartition('.')
        file_extension = '.' + suffix if dot else ''
        lang = extension_to_language.get(file_extension, 'LANG')
        content = file_dict.get(file_name, [])
        section = file_name + "\n```" + lang + "\n" + "\n".join(content) + "\n```\n\n"
        if used_chars + len(section) < character_limit:
            sections.append(section)
            used_chars += len(section)
        else:
            files_missed_out.append(file_name)
    return "".join(sections), files_missed_out
def remove_line_numbers(strings):
    """
    Strip a leading "<digits>:<one whitespace character>" prefix from each string.

    Strings that do not start with such a prefix are returned unchanged.
    """
    prefix = re.compile(r'^\d+:\s')
    cleaned = []
    for text in strings:
        cleaned.append(prefix.sub('', text))
    return cleaned
def add_line_numbers(strings):
    """Prefix every string with its 1-based position followed by ": "."""
    numbered = []
    for position, text in enumerate(strings, start=1):
        numbered.append(f"{position}: {text}")
    return numbered
def fix_line_spacings(strings):
    """
    Round each line's leading-space indentation up to the next even number.

    Only leading space characters are considered; the remainder of every line
    is kept unchanged. A new list is returned.
    """
    adjusted = []
    for text in strings:
        body = text.lstrip(' ')
        indent = len(text) - len(body)
        # Odd widths gain one space; even widths are left alone.
        indent += indent % 2
        adjusted.append(' ' * indent + body)
    return adjusted
def fix_indentation(lines: "list[str]") -> "list[str]":
    """
    Best-effort repair of IndentationError in a list of Python source lines.

    For up to 10 rounds the joined lines are parsed with ast.parse. On an
    IndentationError, the reported line is re-indented to two spaces deeper than
    the line before it. Repair stops as soon as the source parses, when any other
    error is raised, when the preceding line is empty, or when the reported line
    has no preceding line or lies outside the list.

    The list is modified in place and also returned.
    """
    max_rounds = 10
    for _ in range(max_rounds):
        try:
            ast.parse('\n'.join(lines))
        except IndentationError as e:
            if e.lineno is None:
                break
            prob_ln = e.lineno - 1
            # Without this guard, index 0 would pair the line with lines[-1] (wrap-around)
            # and an index past the end would raise IndexError inside the handler.
            if not 0 < prob_ln < len(lines):
                break
            prev_line = lines[prob_ln - 1]
            if prev_line == '':
                break
            leading_spaces = len(prev_line) - len(prev_line.lstrip(' ')) + 2
            lines[prob_ln] = ' ' * leading_spaces + lines[prob_ln].lstrip(' ')
        except Exception:
            # Deliberately best-effort: anything other than an indentation problem
            # cannot be repaired here, so the lines are returned as they are.
            break
        else:
            # Parsed cleanly; nothing left to fix.
            break
    return lines
# =================================================================================================
# Summarizer Using Embeddings
def embed_dataframe(df, batch_size=16):
    """
    Embed the 'code_raw' cells of a dataframe with embedding_rl, batch by batch.

    Parameters:
    - df (pd.DataFrame): Dataframe containing the 'code_raw' column to embed.
    - batch_size (int): Number of cells passed to embedding_rl per call.

    Returns:
    - df (pd.DataFrame): The same dataframe object, with an 'embeddings' column added.

    Raises:
    - ValueError: If the dataframe has no 'code_raw' column.
    """
    if 'code_raw' not in df.columns:
        raise ValueError("The dataframe must contain a 'code_raw' column.")
    snippets = df['code_raw']
    vectors = []
    # Walk the column in consecutive slices of at most batch_size rows.
    for start in range(0, len(df), batch_size):
        chunk = snippets.iloc[start:start + batch_size].tolist()
        vectors.extend(embedding_rl(chunk))
    # Attach the collected embeddings in row order.
    df['embeddings'] = vectors
    return df
def embed_all_code_files(file_dict):
    """
    Build a dataframe of files and attach an embedding for each file's raw code.

    Parameters:
    - file_dict: Dictionary with file paths as keys and lists of lines as values.

    Returns:
    - pd.DataFrame with columns 'file_path', 'code', 'code_raw' (the lines joined by newlines
      after remove_line_numbers) and the 'embeddings' column added by embed_dataframe.
    """
    rows = [{'file_path': path, 'code': code_lines} for path, code_lines in file_dict.items()]
    # Naming the columns explicitly keeps 'code' present even when file_dict is empty;
    # otherwise the lookup below raises KeyError on an empty dataframe.
    df = pd.DataFrame(rows, columns=['file_path', 'code'])
    df['code_raw'] = df['code'].apply(lambda x: '\n'.join(remove_line_numbers(x)))
    return embed_dataframe(df)
# =========================================================================
# Write files out utility functions
def write_files_from_dict(file_dict, base_dir="./output"):
    """
    Writes files to a folder based on the given dictionary, keeping anything already in the folder.

    NOTE(review): a second function with this same name is defined further down in this
    module; at import time that later definition replaces this one, so this version is
    only reachable if the later one is renamed or removed.

    :param file_dict: Dictionary with filenames as keys and lists of file content as values.
    :param base_dir: Base directory where files should be saved.
    """
    # Ensure base directory exists
    os.makedirs(base_dir, exist_ok=True)
    for file_path, lines in file_dict.items():
        full_path = os.path.join(base_dir, file_path)
        # Create the directories if they don't exist
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        lines = fix_line_spacings(remove_line_numbers(lines))
        # Files are read as UTF-8 elsewhere in this module, so write UTF-8 explicitly
        # instead of relying on the platform's default encoding.
        with open(full_path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)
# =========================================================================
# Misc
def write_files_from_dict(file_dict, base_dir="output"):
    """
    Writes files to a freshly recreated folder based on the given dictionary.

    WARNING: if base_dir already exists it is deleted recursively before anything is written.

    :param file_dict: Dictionary with filenames as keys and lists of file content as values.
    :param base_dir: Base directory where files should be saved.
    """
    # If base directory exists, remove it
    if os.path.exists(base_dir):
        shutil.rmtree(base_dir)
    # Ensure base directory is created again
    os.makedirs(base_dir)
    for file_path, lines in file_dict.items():
        full_path = os.path.join(base_dir, file_path)
        # Create the directories if they don't exist
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        lines = fix_line_spacings(remove_line_numbers(lines))
        # Files are read as UTF-8 elsewhere in this module, so write UTF-8 explicitly
        # instead of relying on the platform's default encoding.
        with open(full_path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)
# =========================================================================
# Syntax checker
@timeout(60, timeout_exception=StopIteration)
def check_syntax(file_dict: dict):
    """
    Write the files to a temporary directory and run pylint (error category only) on them.

    Targets handed to pylint, in this order and without duplicates:
    - top-level folders that contain an __init__.py,
    - top-level folders that contain a .py file nested at least two levels deep,
    - the .py files sitting directly inside top-level folders without an __init__.py,
    - the .py files at the root.

    Parameters:
    - file_dict (dict): File names ("/"-separated) mapped to lists of lines.

    Returns:
    - str: pylint's standard output, or '' when there is nothing to lint.

    NOTE(review): the decorator comes from timeout_decorator; presumably it raises
    StopIteration once 60 seconds have elapsed - confirm against that package.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        write_files_from_dict(file_dict, tmpdirname)
        file_names = list(file_dict.keys())
        # Sorted so the pylint command line (and thus its output order) is deterministic.
        top_folders = sorted({name.split('/')[0] for name in file_names if '/' in name})
        pylint_args = []
        # Packages: top-level folders that ship an __init__.py.
        pylint_args.extend(folder for folder in top_folders if f'{folder}/__init__.py' in file_dict)
        # Folders holding python files two or more levels down.
        pylint_args.extend(sorted({
            name.split('/')[0] for name in file_names
            if name.count('/') >= 2 and name.endswith('.py')
        }))
        # Plain folders: list their direct .py children explicitly. The command runs without
        # a shell, so a "folder/*.py" pattern would reach pylint unexpanded.
        for folder in top_folders:
            if f'{folder}/__init__.py' in file_dict:
                continue
            pylint_args.extend(
                name for name in file_names
                if name.startswith(f'{folder}/') and name.count('/') == 1 and name.endswith('.py')
            )
        # Python files at the root.
        pylint_args.extend(name for name in file_names if name.endswith('.py') and '/' not in name)
        # Drop duplicates while keeping order (a folder can qualify in more than one way).
        pylint_args = list(dict.fromkeys(pylint_args))
        if not pylint_args:
            # No python files found, therefore skipping
            return ''
        command = ['python3', '-m', 'pylint', '--disable=all', '--enable=E', '--score=no']
        command.extend(pylint_args)
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            cwd=tmpdirname)
        return result.stdout
def count_errors_in_syntax(syntax_output: str):
    """
    Count the error messages in pylint-style output.

    A line counts when it contains "<something>:<line>:<column>: E<digits>: <message>".
    """
    error_line = re.compile(r".+:\d+:\d+: [E]\d+: .+")
    return sum(1 for _ in error_line.finditer(syntax_output))
# =========================================================================
# PyTest
import xml.etree.ElementTree as ET
def parse_and_print_junit_xml(file_path):
    """
    Parse a JUnit XML report and summarise each test suite.

    Parameters:
    - file_path: Path (or file object) accepted by xml.etree.ElementTree.parse.

    Returns:
    - list of dicts, one per <testsuite>, with integer keys "total", "passed",
      "errors", "failures" and "skipped". "passed" is derived as
      total - (errors + failures + skipped). Missing attributes count as 0.

    Despite its name, this function prints nothing.
    """
    root = ET.parse(file_path).getroot()
    # A report may use a single <testsuite> as its root instead of wrapping suites in
    # <testsuites>; handle both so such a report is not read as "no suites".
    suites = [root] if root.tag == 'testsuite' else root.findall('testsuite')
    testsuites = []
    for testsuite in suites:
        total_tests = int(testsuite.get("tests", 0))
        errors = int(testsuite.get("errors", 0))
        failures = int(testsuite.get("failures", 0))
        skipped = int(testsuite.get("skipped", 0))
        testsuites.append({
            "total": total_tests,
            "passed": total_tests - (errors + failures + skipped),
            "errors": errors,
            "failures": failures,
            "skipped": skipped,
        })
    return testsuites
@timeout(60, timeout_exception=StopIteration)
def check_pytest(file_dict: dict):
    """
    Write the files to a temporary directory, run pytest there and summarise the outcome.

    Parameters:
    - file_dict (dict): File names mapped to lists of lines.

    Returns:
    - str: 'No tests found.' when no JUnit report was produced or it lists zero tests,
      'All tests passed.' when every listed test passed, otherwise a failure message
      followed by pytest's standard output.

    NOTE(review): the decorator comes from timeout_decorator; presumably it raises
    StopIteration once 60 seconds have elapsed - confirm against that package.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        write_files_from_dict(file_dict, tmpdirname)
        command = ['python3', '-m', 'pytest', '--junitxml=result_pytest.xml']
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            cwd=tmpdirname)
        stdout = result.stdout
        try:
            results = parse_and_print_junit_xml(Path(tmpdirname) / 'result_pytest.xml')
        except FileNotFoundError:
            return 'No tests found.'
        total_tests = sum(r['total'] for r in results)
        total_passed = sum(r['passed'] for r in results)
        # A report with zero tests satisfies total == passed (0 == 0), so check it first
        # rather than reporting success for a run that executed nothing.
        if total_tests == 0:
            return 'No tests found.'
        if total_tests == total_passed:
            return 'All tests passed.'
        return f'Failed tests. Please check the output below. {stdout}'
class TestSummarizer(unittest.TestCase):
    """Smoke tests for the loading, counting and file-name extraction helpers."""

    def test_group_chat_implementation(self):
        """Load the mock repository and print character statistics (no assertions)."""
        code_repository_folder_file_path = "./test_suite/mock_data/chatbot-ui/"
        output = load_code_files_into_dict(code_repository_folder_file_path)
        # count_total_characters covers the values only, so the whole structure is
        # values + keys, and the values figure must not have the keys subtracted.
        total_value_characters = count_total_characters(output)
        total_key_characters = count_key_characters(output)
        print(f"Total characters in the data structure: {total_value_characters + total_key_characters}")
        print(f"Total characters in the dictionary keys: {total_key_characters}")
        print(f"Total characters in the dictionary values: {total_value_characters}")

    def test_extract_file_names(self):
        """Every .ts path mentioned in the text is returned exactly once."""
        string_content = """
To add this feature we need to consider both the frontend and backend parts of the implementation.
On the frontend, we'll have to:
1. Add the new button in the appropriate component file. As this is a chat feature, it seems likely that the required component would be located in the `components/` directory, but without viewing the contents it's hard to say for certain which file we should target. I suggest inspecting `components/Sidebar/index.ts` and `components/Buttons/SidebarActionButton/index.ts`. These file names suggest they might contain the code defining the chat UI and its buttons.
2. Once we've created our button, we'll need to bind an onClick event handler to it. This event handler will interact with `pages/api/chat.ts`.
On the backend:
1. We need to define the API route which will handle delete requests for chat message history. I suggest viewing the contents of `pages/api/chat.ts` to see if there is an existing endpoint we can adapt, or whether we need to define a new one.
May I proceed to view the contents of `components/Sidebar/index.ts`, `components/Buttons/SidebarActionButton/index.ts`, and `pages/api/chat.ts`?
"""
        files = extract_file_names(string_content)
        # Order-insensitive comparison: only the set of unique names matters here.
        self.assertCountEqual(files, [
            'components/Sidebar/index.ts',
            'components/Buttons/SidebarActionButton/index.ts',
            'pages/api/chat.ts',
        ])
# Create a Test Suite
class TestGenerateOutput(unittest.TestCase):
    """Tests for format_files_into_prompt_format, which returns (output, files_missed_out)."""

    def setUp(self):
        self.files = {
            'components/Sidebar/index.ts': ['import React from "react";', '...'],
            'components/Buttons/SidebarActionButton/index.ts': ['import React from "react";', '...'],
            'pages/api/chat.ts': ['import { NextApiRequest, NextApiResponse } from "next";', '...']
        }

    def test_all_files_returned(self):
        # Unpack the tuple: assertIn on the tuple itself would compare each file name
        # against whole elements instead of searching the rendered text.
        result, files_missed_out = format_files_into_prompt_format(self.files)
        self.assertEqual(files_missed_out, [])
        for file_name in self.files:
            self.assertIn(file_name, result)

    def test_specific_files_returned(self):
        unique_files = {'pages/api/chat.ts'}
        result, files_missed_out = format_files_into_prompt_format(self.files, unique_files)
        self.assertEqual(files_missed_out, [])
        self.assertIn('pages/api/chat.ts', result)
        self.assertNotIn('components/Sidebar/index.ts', result)
        self.assertNotIn('components/Buttons/SidebarActionButton/index.ts', result)
class TestLoadCodeFilesIntoDictFormat(unittest.TestCase):
    """Checks the "<n>: <line>" format produced by load_code_files_into_dict(number_lines=True)."""

    def setUp(self):
        self.temp_dir = TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def create_temp_file(self, folder, filename, content):
        """Write content to folder/filename as UTF-8."""
        with open(Path(folder) / filename, 'w', encoding='utf-8') as f:
            f.write(content)

    def test_line_format(self):
        # Create a sample file with multiple lines
        file_content = "Line1\nLine2\nLine3"
        self.create_temp_file(self.temp_dir.name, "sample.txt", file_content)
        # Line numbering is off by default, so it has to be requested explicitly.
        result = load_code_files_into_dict(self.temp_dir.name, number_lines=True)
        expected_result = [
            "1: Line1",
            "2: Line2",
            "3: Line3"
        ]
        self.assertIn("sample.txt", result)
        self.assertEqual(result["sample.txt"], expected_result)

    def test_line_format_large_file(self):
        # Test with a large file
        lines = [f"Line {i}" for i in range(1, 1001)]  # 1000 lines
        file_content = "\n".join(lines)
        self.create_temp_file(self.temp_dir.name, "large.txt", file_content)
        result = load_code_files_into_dict(self.temp_dir.name, number_lines=True)
        self.assertIn("large.txt", result)
        self.assertEqual(len(result["large.txt"]), 1000)
        # Ensure that the line format is correct for all lines
        for idx, line in enumerate(result["large.txt"], start=1):
            expected_line = f"{idx}: Line {idx}"
            self.assertEqual(line, expected_line)
# Unit test class
class TestRemoveLineNumbers(unittest.TestCase):
    """Tests for remove_line_numbers, which strips one leading "<digits>:<whitespace>" prefix."""

    def test_remove_line_numbers(self):
        # Test case 1
        input_strings = ['5: We appreciate your interest in contributing to our project.', '25: git clone https://github.com/mckaywrigley/chatbot-ui.git']
        expected_output = ['We appreciate your interest in contributing to our project.', 'git clone https://github.com/mckaywrigley/chatbot-ui.git']
        self.assertEqual(remove_line_numbers(input_strings), expected_output)
        # Test case 2: Test strings without line numbers and colons
        input_strings = ['Hello World!', 'Python is fun.']
        self.assertEqual(remove_line_numbers(input_strings), input_strings)
        # Test case 3: Test empty strings and strings with only line numbers and colons.
        # The prefix consumes exactly one whitespace character after the colon, so
        # '5:  ' (two spaces) leaves one space behind and '25: ' leaves nothing.
        input_strings = ['5:  ', '25: ', '']
        expected_output = [' ', '', '']
        self.assertEqual(remove_line_numbers(input_strings), expected_output)
        # A number and colon with no whitespace after it is not a prefix and is kept.
        self.assertEqual(remove_line_numbers(['25:']), ['25:'])
class TestFixLineSpacings(unittest.TestCase):
    """Tests for fix_line_spacings, which rounds leading-space indentation up to an even width."""

    def test_fix_line_spacings(self):
        # Test case 1: indentation widths 3, 1, 5, 4 and 0 must become 4, 2, 6, 4 and 0.
        input_strings = ['   def add_message(self, message):',
                         ' def __init__(self,chat_id,participants=None,messages=None):',
                         '     self.chat_id = chat_id',
                         '    self.messages.append(message)',
                         'def add_message(self, message):',]
        expected_output = ['    def add_message(self, message):',
                           '  def __init__(self,chat_id,participants=None,messages=None):',
                           '      self.chat_id = chat_id',
                           '    self.messages.append(message)',
                           'def add_message(self, message):',]
        self.assertEqual(fix_line_spacings(input_strings), expected_output)
class TestCheckSyntax(unittest.TestCase):
    """Tests for check_syntax; the repository-based cases depend on folders outside this file."""

    def test_syntax_parser_on_file_dict_example_clean(self):
        file_dict = load_code_files_into_dict('repos/flask/examples/tutorial')
        syntax_output = check_syntax(file_dict)
        print(syntax_output)
        # check_syntax returns '' when there is nothing to lint; '' is falsy, so it has
        # to be compared explicitly rather than passed to assertTrue.
        self.assertEqual(check_syntax({}), '', "Empty dictionary should be considered valid syntax")

    def test_syntax_parser_on_file_dict_example_errors(self):
        file_dict = load_code_files_into_dict('results/latest/run-20230922-044053_LLMatic-zero_shot_donnemartin-system-design-oop-whatsapp-donnemartin-system-design-oop-url_shortener-donnemartin-system-design-oop-dropbox-donnemartin-system-design-oop-twitter_0_10-runs_log/whatsapp/1/LLMatic')
        syntax_output = check_syntax(file_dict)
        print(syntax_output)
        self.assertEqual(check_syntax({}), '', "Empty dictionary should be considered valid syntax")

    def test_syntax_parser_on_file_dict_giving_help_func(self):
        file_dict = {'requirements.txt': ['Flask', 'requests', 'sqlite3', 'pytest']}
        syntax_output = check_syntax(file_dict)
        print(syntax_output)
        # No python files in the dictionary, so pylint is never invoked.
        self.assertEqual(syntax_output, '')
        self.assertEqual(check_syntax({}), '', "Empty dictionary should be considered valid syntax")
class TestPyTest(unittest.TestCase):
    """Tests for check_pytest; both cases depend on folders outside this file."""

    def test_check_pytest_clean(self):
        file_dict = load_code_files_into_dict('repos/flask/examples/tutorial')
        pytest_output = check_pytest(file_dict)
        print(pytest_output)
        self.assertIsInstance(pytest_output, str)
        # check_syntax returns '' when there is nothing to lint; '' is falsy, so it has
        # to be compared explicitly rather than passed to assertTrue.
        self.assertEqual(check_syntax({}), '', "Empty dictionary should be considered valid syntax")

    def test_check_pytest_errors(self):
        file_dict = load_code_files_into_dict('results/latest/run-20230922-044053_LLMatic-zero_shot_donnemartin-system-design-oop-whatsapp-donnemartin-system-design-oop-url_shortener-donnemartin-system-design-oop-dropbox-donnemartin-system-design-oop-twitter_0_10-runs_log/whatsapp/1/LLMatic')
        pytest_output = check_pytest(file_dict)
        print(pytest_output)
        self.assertIsInstance(pytest_output, str)
        self.assertEqual(check_syntax({}), '', "Empty dictionary should be considered valid syntax")
if __name__ == '__main__':
    # Running this module as a script executes a single smoke test directly,
    # bypassing the unittest runner.
    # unittest.main()
    summarizer_case = TestSummarizer()
    summarizer_case.test_group_chat_implementation()
    # Other entry points kept for manual runs:
    # TestFixLineSpacings().test_fix_line_spacings()
    # TestCheckSyntax().test_syntax_parser_on_file_dict_example_clean()
    # TestCheckSyntax().test_syntax_parser_on_file_dict_example_errors()
    # TestCheckSyntax().test_syntax_parser_on_file_dict_giving_help_func()
    # TestPyTest().test_check_pytest_errors()
    # from utils.llm_tools import ProcessFunctionCall
    # # ProcessFunctionCall().test_process_function_call_and_return_message_empty()
    # ProcessFunctionCall().test_process_function_call_and_return_message_json_args_cut_short()