diff --git a/Makefile b/Makefile
index 5f50d8416..dfff35d3a 100644
--- a/Makefile
+++ b/Makefile
@@ -23,6 +23,7 @@ installer:
 		--add-data "cover_agent/settings/test_generation_prompt.toml:." \
 		--add-data "cover_agent/settings/analyze_suite_test_headers_indentation.toml:." \
 		--add-data "cover_agent/settings/analyze_suite_test_insert_line.toml:." \
+		--add-data "cover_agent/settings/analyze_test_run_failure.toml:." \
 		--add-data "$(SITE_PACKAGES)/vendor:wandb/vendor" \
 		--hidden-import=tiktoken_ext.openai_public \
 		--hidden-import=tiktoken_ext \
diff --git a/cover_agent/AICaller.py b/cover_agent/AICaller.py
index 6a35df95b..a78d8c1f3 100644
--- a/cover_agent/AICaller.py
+++ b/cover_agent/AICaller.py
@@ -18,7 +18,7 @@ def __init__(self, model: str, api_base: str = ""):
         self.model = model
         self.api_base = api_base
 
-    def call_model(self, prompt: dict, max_tokens=4096):
+    def call_model(self, prompt: dict, max_tokens=4096, stream=True):
         """
         Call the language model with the provided prompt and retrieve the response.
@@ -61,17 +61,17 @@ def call_model(self, prompt: dict, max_tokens=4096):
         response = litellm.completion(**completion_params)
 
         chunks = []
-        print("Streaming results from LLM model...")
+        print("Streaming results from LLM model...") if stream else None
         try:
             for chunk in response:
-                print(chunk.choices[0].delta.content or "", end="", flush=True)
+                print(chunk.choices[0].delta.content or "", end="", flush=True) if stream else None
                 chunks.append(chunk)
                 time.sleep(
                     0.01
                 )  # Optional: Delay to simulate more 'natural' response pacing
         except Exception as e:
-            print(f"Error during streaming: {e}")
-            print("\n")
+            print(f"Error during streaming: {e}") if stream else None
+            print("\n") if stream else None
 
         model_response = litellm.stream_chunk_builder(chunks, messages=messages)
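The new `stream` flag above gates every console printout in `call_model` while still collecting the chunks returned by `litellm.completion`. A behaviorally equivalent, self-contained sketch of that guard logic, using plain `if stream:` blocks instead of the `print(...) if stream else None` expressions (the helper name `_collect_stream` is hypothetical and not part of the patch):

```python
import time


def _collect_stream(response, stream: bool = True) -> list:
    """Collect chunks from a streaming litellm response, printing only when `stream` is True.

    Illustration only: mirrors the guard logic added to AICaller.call_model above.
    """
    chunks = []
    if stream:
        print("Streaming results from LLM model...")
    try:
        for chunk in response:
            if stream:
                print(chunk.choices[0].delta.content or "", end="", flush=True)
            chunks.append(chunk)
            time.sleep(0.01)  # optional pacing delay, as in the patch
    except Exception as e:
        if stream:
            print(f"Error during streaming: {e}\n")
    return chunks
```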
diff --git a/cover_agent/PromptBuilder.py b/cover_agent/PromptBuilder.py
index 6b81bf58d..c800ec1c9 100644
--- a/cover_agent/PromptBuilder.py
+++ b/cover_agent/PromptBuilder.py
@@ -42,6 +42,7 @@ def __init__(
         additional_instructions: str = "",
         failed_test_runs: str = "",
         language: str = "python",
+        testing_framework: str = "NOT KNOWN",
     ):
         """
         The `PromptBuilder` class is responsible for building a formatted prompt string by replacing placeholders with the actual content of files read during initialization. It takes in various paths and settings as parameters and provides a method to generate the prompt.
@@ -72,6 +73,8 @@ def __init__(
         self.test_file = self._read_file(test_file_path)
         self.code_coverage_report = code_coverage_report
         self.language = language
+        self.testing_framework = testing_framework
+
         # add line numbers to each line in 'source_file'. start from 1
         self.source_file_numbered = "\n".join(
             [f"{i + 1} {line}" for i, line in enumerate(self.source_file.split("\n"))]
@@ -99,6 +102,9 @@ def __init__(
             else ""
         )
 
+        self.stdout_from_run = ""
+        self.stderr_from_run = ""
+
     def _read_file(self, file_path):
         """
         Helper method to read file contents.
@@ -138,6 +144,9 @@ def build_prompt(self) -> dict:
             "additional_instructions_text": self.additional_instructions,
             "language": self.language,
             "max_tests": MAX_TESTS_PER_RUN,
+            "testing_framework": self.testing_framework,
+            "stdout": self.stdout_from_run,
+            "stderr": self.stderr_from_run,
         }
         environment = Environment(undefined=StrictUndefined)
         try:
@@ -155,6 +164,15 @@ def build_prompt(self) -> dict:
         return {"system": system_prompt, "user": user_prompt}
 
     def build_prompt_custom(self, file) -> dict:
+        """
+        Builds a custom prompt by replacing placeholders with actual content from files and settings.
+
+        Parameters:
+            file (str): The settings key of the prompt file to use when building the prompt.
+
+        Returns:
+            dict: A dictionary containing the system and user prompts.
+        """
         variables = {
             "source_file_name": self.source_file_name,
             "test_file_name": self.test_file_name,
@@ -168,15 +186,20 @@ def build_prompt_custom(self, file) -> dict:
             "additional_instructions_text": self.additional_instructions,
             "language": self.language,
             "max_tests": MAX_TESTS_PER_RUN,
+            "testing_framework": self.testing_framework,
+            "stdout": self.stdout_from_run,
+            "stderr": self.stderr_from_run,
         }
         environment = Environment(undefined=StrictUndefined)
         try:
-            system_prompt = environment.from_string(
-                get_settings().get(file).system
-            ).render(variables)
-            user_prompt = environment.from_string(get_settings().get(file).user).render(
-                variables
-            )
+            settings = get_settings().get(file)
+            if settings is None or not hasattr(settings, "system") or not hasattr(
+                settings, "user"
+            ):
+                logging.error(f"Could not find settings for prompt file: {file}")
+                return {"system": "", "user": ""}
+            system_prompt = environment.from_string(settings.system).render(variables)
+            user_prompt = environment.from_string(settings.user).render(variables)
         except Exception as e:
             logging.error(f"Error rendering prompt: {e}")
             return {"system": "", "user": ""}
diff --git a/cover_agent/UnitTestGenerator.py b/cover_agent/UnitTestGenerator.py
index 0c73ecfc4..b0ed70026 100644
--- a/cover_agent/UnitTestGenerator.py
+++ b/cover_agent/UnitTestGenerator.py
@@ -79,6 +79,7 @@ def __init__(
         self.failed_test_runs = []
         self.total_input_token_count = 0
         self.total_output_token_count = 0
+        self.testing_framework = "Unknown"
 
         # Read self.source_file_path into a string
         with open(self.source_file_path, "r") as f:
@@ -269,10 +270,7 @@ def build_prompt(self):
                     continue
                 # dump dict to str
                 code = json.dumps(failed_test_dict)
-                if "error_message" in failed_test:
-                    error_message = failed_test["error_message"]
-                else:
-                    error_message = None
+                error_message = failed_test.get("error_message", None)
                 failed_test_runs_value += f"Failed Test:\n```\n{code}\n```\n"
                 if error_message:
                     failed_test_runs_value += (
@@ -296,6 +294,7 @@ def build_prompt(self):
            additional_instructions=self.additional_instructions,
            failed_test_runs=failed_test_runs_value,
            language=self.language,
+           testing_framework=self.testing_framework,
        )
 
        return self.prompt_builder.build_prompt()
@@ -363,6 +362,7 @@ def initial_test_suite_analysis(self):
                relevant_line_number_to_insert_imports_after = tests_dict.get(
                    "relevant_line_number_to_insert_imports_after", None
                )
+               self.testing_framework = tests_dict.get("testing_framework", "Unknown")
                counter_attempts += 1
 
            if not relevant_line_number_to_insert_tests_after:
@@ -562,9 +562,9 @@ def validate_test(self, generated_test: dict, num_attempts=1):
                    "processed_test_file": processed_test,
                }
 
-               error_message = extract_error_message_python(fail_details["stdout"])
+               error_message = self.extract_error_message(stderr=fail_details["stderr"], stdout=fail_details["stdout"])
                if error_message:
-                   logging.error(f"Error message:\n{error_message}")
+                   logging.error(f"Error message summary:\n{error_message}")
 
                self.failed_test_runs.append(
                    {"code": generated_test, "error_message": error_message}
@@ -647,7 +647,7 @@ def validate_test(self, generated_test: dict, num_attempts=1):
                    self.failed_test_runs.append(
                        {
                            "code": fail_details["test"],
-                           "error_message": "did not increase code coverage",
+                           "error_message": "Code coverage did not increase",
                        }
                    )  # Append failure details to the list
@@ -686,7 +686,7 @@ def validate_test(self, generated_test: dict, num_attempts=1):
                    self.failed_test_runs.append(
                        {
                            "code": fail_details["test"],
-                           "error_message": "coverage verification error",
+                           "error_message": "Coverage verification error",
                        }
                    )  # Append failure details to the list
                return fail_details
@@ -762,30 +762,43 @@ def to_json(self):
         return json.dumps(self.to_dict())
 
 
-def extract_error_message_python(fail_message):
-    """
-    Extracts and returns the error message from the provided failure message.
-
-    Parameters:
-        fail_message (str): The failure message containing the error message to be extracted.
-
-    Returns:
-        str: The extracted error message from the failure message, or an empty string if no error message is found.
-
-    """
-    try:
-        # Define a regular expression pattern to match the error message
-        MAX_LINES = 20
-        pattern = r"={3,} FAILURES ={3,}(.*?)(={3,}|$)"
-        match = re.search(pattern, fail_message, re.DOTALL)
-        if match:
-            err_str = match.group(1).strip("\n")
-            err_str_lines = err_str.split("\n")
-            if len(err_str_lines) > MAX_LINES:
-                # show last MAX_LINES lines
-                err_str = "...\n" + "\n".join(err_str_lines[-MAX_LINES:])
-            return err_str
-        return ""
-    except Exception as e:
-        logging.error(f"Error extracting error message: {e}")
-        return ""
\ No newline at end of file
+    def extract_error_message(self, stderr, stdout):
+        """
+        Extracts the error message from the provided stderr and stdout outputs.
+
+        Updates the PromptBuilder object with the stderr and stdout, builds a custom prompt for analyzing test run failures,
+        calls the language model to analyze the prompt, and loads the response into a dictionary.
+
+        Returns the error summary from the loaded YAML data, or a default error message if it is unable to summarize.
+        Logs errors encountered during the process.
+
+        Parameters:
+            stderr (str): The standard error output from the test run.
+            stdout (str): The standard output from the test run.
+
+        Returns:
+            str: The error summary extracted from the response, or a default error message if extraction fails.
+        """
+        try:
+            # Update the PromptBuilder object with stderr and stdout
+            self.prompt_builder.stderr_from_run = stderr
+            self.prompt_builder.stdout_from_run = stdout
+
+            # Build the prompt
+            prompt_headers_indentation = self.prompt_builder.build_prompt_custom(
+                file="analyze_test_run_failure"
+            )
+
+            # Run the analysis via LLM
+            response, prompt_token_count, response_token_count = (
+                self.ai_caller.call_model(prompt=prompt_headers_indentation, stream=False)
+            )
+            self.total_input_token_count += prompt_token_count
+            self.total_output_token_count += response_token_count
+            tests_dict = load_yaml(response)
+
+            return tests_dict.get("error_summary", f"ERROR: Unable to summarize error message from inputs. STDERR: {stderr}\nSTDOUT: {stdout}.")
+        except Exception as e:
+            logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}")
+            logging.error(f"Error extracting error message: {e}")
+            return ""
\ No newline at end of file
diff --git a/cover_agent/settings/analyze_test_run_failure.toml b/cover_agent/settings/analyze_test_run_failure.toml
new file mode 100644
index 000000000..75e7c13e6
--- /dev/null
+++ b/cover_agent/settings/analyze_test_run_failure.toml
@@ -0,0 +1,41 @@
+[analyze_test_run_failure]
+system="""\
+"""
+
+user="""\
+## Overview
+You are a code assistant that accepts both the stdout and stderr from a test run, specifically for unit test regression testing.
+Your goal is to analyze the output and summarize the failure for further analysis.
+
+Please provide a one-sentence summary of the error, including the following details:
+- The offending line of code (if available).
+- The line number where the error occurred.
+- Any other relevant details or information gleaned from the stdout and stderr.
+
+Here is the stdout and stderr from the test run:
+stdout:
+=========
+{{ stdout|trim }}
+=========
+
+stderr:
+=========
+{{ stderr|trim }}
+=========
+
+Now, you need to analyze the output and provide a YAML object equivalent to type $TestFailureAnalysis, according to the following Pydantic definitions:
+=====
+class TestFailureAnalysis(BaseModel):
+    error_summary: str = Field(description="A one-sentence summary of the failure, including the offending line of code, line number, and other relevant information from the stdout/stderr.")
+=====
+
+Example output:
+```yaml
+error_summary: ...
+```
+
+The response should be only a valid YAML object, without any introduction text or follow-up text.
+
+Answer:
+```yaml
+"""
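The prompt above primes the model with an open ```yaml fence, so the expected reply is a bare YAML object with a single `error_summary` key. A minimal sketch of how such a reply is consumed (PyYAML stands in here for the repo's `load_yaml` helper, and the sample summary text and fence stripping are illustrative assumptions, not part of the patch):

```python
import yaml  # stand-in for the load_yaml helper used in UnitTestGenerator

# A reply shaped like the "Example output" the prompt asks for; the model may
# close the ```yaml fence it was primed with, so strip any fences before parsing.
raw_response = """\
error_summary: The test failed with an AssertionError on the line `assert foo() == 1` (line 12 of the test file).
```"""

cleaned = raw_response.replace("```yaml", "").replace("```", "").strip()
data = yaml.safe_load(cleaned)

# extract_error_message() reads this key and falls back to a default message.
print(data.get("error_summary", "ERROR: Unable to summarize error message from inputs."))
```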
diff --git a/cover_agent/settings/config_loader.py b/cover_agent/settings/config_loader.py
index 1d2622d55..a2aca2c7f 100644
--- a/cover_agent/settings/config_loader.py
+++ b/cover_agent/settings/config_loader.py
@@ -7,6 +7,7 @@
     "language_extensions.toml",
     "analyze_suite_test_headers_indentation.toml",
     "analyze_suite_test_insert_line.toml",
+    "analyze_test_run_failure.toml",
 ]
 
diff --git a/cover_agent/settings/test_generation_prompt.toml b/cover_agent/settings/test_generation_prompt.toml
index dbd789cbf..3845cd6ea 100644
--- a/cover_agent/settings/test_generation_prompt.toml
+++ b/cover_agent/settings/test_generation_prompt.toml
@@ -28,6 +28,8 @@ Here is the file that contains the existing tests, called `{{ test_file_name }}`
 {{ test_file| trim }}
 =========
 
+### Test Framework
+The test framework used for running tests is `{{ testing_framework }}`.
 
 {%- if additional_includes_section|trim %}
diff --git a/cover_agent/version.txt b/cover_agent/version.txt
index 438debab3..2327344ad 100644
--- a/cover_agent/version.txt
+++ b/cover_agent/version.txt
@@ -1 +1 @@
-0.1.50
\ No newline at end of file
+0.1.51
\ No newline at end of file
diff --git a/tests/test_PromptBuilder.py b/tests/test_PromptBuilder.py
index c3cf9ca78..7a42a31d5 100644
--- a/tests/test_PromptBuilder.py
+++ b/tests/test_PromptBuilder.py
@@ -1,4 +1,6 @@
+import os
 import pytest
+import tempfile
 from unittest.mock import patch, mock_open
 
 from cover_agent.PromptBuilder import PromptBuilder
@@ -172,3 +174,34 @@ def mock_render(*args, **kwargs):
         )
         result = builder.build_prompt()
         assert result == {"system": "", "user": ""}
+
+class TestPromptBuilderEndToEnd:
+    def test_custom_analyze_test_run_failure(self):
+        # Create fake source and test files and tmp files and pass in the paths
+        source_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
+        source_file.write("def foo():\n    pass")
+        source_file.close()
+        test_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
+        test_file.write("def test_foo():\n    pass")
+        test_file.close()
+        tmp_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
+        tmp_file.write("tmp file content")
+        tmp_file.close()
+
+        builder = PromptBuilder(
+            source_file_path=source_file.name,
+            test_file_path=test_file.name,
+            code_coverage_report=tmp_file.name,
+        )
+
+        builder.stderr_from_run = "stderr content"
+        builder.stdout_from_run = "stdout content"
+
+        result = builder.build_prompt_custom("analyze_test_run_failure")
+        assert "stderr content" in result["user"]
+        assert "stdout content" in result["user"]
+
+        # Clean up
+        os.remove(source_file.name)
+        os.remove(test_file.name)
+        os.remove(tmp_file.name)
\ No newline at end of file
diff --git a/tests/test_UnitTestGenerator.py b/tests/test_UnitTestGenerator.py
index 67adbc08a..116cf0c4b 100644
--- a/tests/test_UnitTestGenerator.py
+++ b/tests/test_UnitTestGenerator.py
@@ -1,8 +1,5 @@
 import pytest
-from cover_agent.UnitTestGenerator import (
-    UnitTestGenerator,
-    extract_error_message_python,
-)
+from cover_agent.UnitTestGenerator import UnitTestGenerator
 from cover_agent.ReportGenerator import ReportGenerator
 import os
 
@@ -31,17 +28,3 @@ def test_get_included_files_valid_paths(self):
             result
             == "file_path: `file1.txt`\ncontent:\n```\nfile content\n```\nfile_path: `file2.txt`\ncontent:\n```\nfile content\n```"
         )
-
-
-class TestExtractErrorMessage:
-    def test_extract_single_match(self):
-        fail_message = "=== FAILURES ===\\nError occurred here\\n=== END ==="
-        expected = "\\nError occurred here\\n"
-        result = extract_error_message_python(fail_message)
-        assert result == expected, f"Expected '{expected}', got '{result}'"
-
-    def test_extract_bad_match(self):
-        fail_message = 33
-        expected = ""
-        result = extract_error_message_python(fail_message)
-        assert result == expected, f"Expected '{expected}', got '{result}'"
diff --git a/tests_integration/increase_coverage.py b/tests_integration/increase_coverage.py
index e1105c4cd..6c9bfc83e 100755
--- a/tests_integration/increase_coverage.py
+++ b/tests_integration/increase_coverage.py
@@ -30,14 +30,14 @@ def __init__(self, source_file_path, test_file_path):
         self.test_file_path = test_file_path
         self.test_file_output_path = ""
         self.code_coverage_report_path = "coverage.xml"
-        self.test_command = f"poetry run pytest --cov=cover_agent --cov-report=xml --cov-report=term --log-cli-level=INFO --timeout=30 --disable-warnings"
+        self.test_command = f"poetry run pytest --cov=cover_agent --cov-report=xml --timeout=30 --disable-warnings"
         self.test_command_dir = os.getcwd()
         self.included_files = None
         self.coverage_type = "cobertura"
         self.report_filepath = "test_results.html"
         self.desired_coverage = 100
-        self.max_iterations = 5
-        self.additional_instructions = "Focus solely on the generate_partial_diff() function."
+        self.max_iterations = 4
+        self.additional_instructions = ""
         self.model = "gpt-4o"
         self.api_base = "http://localhost:11434"
         self.prompt_only = False
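A note on why `testing_framework`, `stdout`, and `stderr` are added to the `variables` dicts of both `build_prompt` and `build_prompt_custom`: the templates are rendered with Jinja2's `StrictUndefined`, so any placeholder that is referenced but not supplied raises an error instead of silently rendering empty. A minimal standalone sketch of that behavior (illustration only, not repo code):

```python
from jinja2 import Environment, StrictUndefined
from jinja2.exceptions import UndefinedError

env = Environment(undefined=StrictUndefined)
template = env.from_string(
    "The test framework used for running tests is `{{ testing_framework }}`."
)

try:
    # Fails: the placeholder is referenced but no value was provided.
    template.render({})
except UndefinedError as e:
    print(f"Render failed: {e}")

# Succeeds once the variable is supplied, as the patch does for all prompts.
print(template.render({"testing_framework": "pytest"}))
```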