diff --git a/cover_agent/AICaller.py b/cover_agent/AICaller.py
index a78d8c1f3..8f996e986 100644
--- a/cover_agent/AICaller.py
+++ b/cover_agent/AICaller.py
@@ -13,7 +13,7 @@ def __init__(self, model: str, api_base: str = ""):
 
         Parameters:
             model (str): The name of the model to be used.
-            api_base (str): The base API url to use in case model is set to Ollama or Hugging Face
+            api_base (str): The base API URL to use in case the model is set to Ollama or Hugging Face.
         """
         self.model = model
         self.api_base = api_base
@@ -25,6 +25,7 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
         Parameters:
             prompt (dict): The prompt to be sent to the language model.
             max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 4096.
+            stream (bool, optional): Whether to stream the response or not. Defaults to True.
 
         Returns:
             tuple: A tuple containing the response generated by the language model, the number of tokens used from the prompt, and the total number of tokens in the response.
@@ -36,21 +37,34 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
         if prompt["system"] == "":
             messages = [{"role": "user", "content": prompt["user"]}]
         else:
-            messages = [
-                {"role": "system", "content": prompt["system"]},
-                {"role": "user", "content": prompt["user"]},
-            ]
+            if self.model in ["o1-preview", "o1-mini"]:
+                # o1 doesn't accept a system message so we add it to the prompt
+                messages = [
+                    {"role": "user", "content": prompt["system"] + "\n" + prompt["user"]},
+                ]
+            else:
+                messages = [
+                    {"role": "system", "content": prompt["system"]},
+                    {"role": "user", "content": prompt["user"]},
+                ]
 
-        # Default Completion parameters
+        # Default completion parameters
         completion_params = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": max_tokens,
-            "stream": True,
+            "stream": stream,  # Use the stream parameter passed to the method
             "temperature": 0.2,
+            "max_tokens": max_tokens,
         }
 
-        # API base exception for OpenAI Compatible, Ollama and Hugging Face models
+        # Model-specific adjustments
+        if self.model in ["o1-preview", "o1-mini"]:
+            completion_params["temperature"] = 1
+            completion_params["stream"] = False  # o1 doesn't support streaming
+            completion_params["max_completion_tokens"] = max_tokens
+            completion_params.pop("max_tokens", None)  # Remove 'max_tokens' if present
+
+        # API base exception for OpenAI Compatible, Ollama, and Hugging Face models
         if (
             "ollama" in self.model
             or "huggingface" in self.model
@@ -60,39 +74,46 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
 
         response = litellm.completion(**completion_params)
 
-        chunks = []
-        print("Streaming results from LLM model...") if stream else None
-        try:
-            for chunk in response:
-                print(chunk.choices[0].delta.content or "", end="", flush=True) if stream else None
-                chunks.append(chunk)
-                time.sleep(
-                    0.01
-                )  # Optional: Delay to simulate more 'natural' response pacing
-        except Exception as e:
-            print(f"Error during streaming: {e}") if stream else None
-        print("\n") if stream else None
-
-        model_response = litellm.stream_chunk_builder(chunks, messages=messages)
+        if stream:
+            chunks = []
+            print("Streaming results from LLM model...")
+            try:
+                for chunk in response:
+                    print(chunk.choices[0].delta.content or "", end="", flush=True)
+                    chunks.append(chunk)
+                    time.sleep(
+                        0.01
+                    )  # Optional: Delay to simulate more 'natural' response pacing
+            except Exception as e:
+                print(f"Error during streaming: {e}")
+            print("\n")
+            # Build the final response from the streamed chunks
+            model_response = litellm.stream_chunk_builder(chunks, messages=messages)
+            content = model_response["choices"][0]["message"]["content"]
+            usage = model_response["usage"]
+            prompt_tokens = int(usage["prompt_tokens"])
+            completion_tokens = int(usage["completion_tokens"])
+        else:
+            # Non-streaming response is a CompletionResponse object
+            content = response.choices[0].message.content
+            print(f"Printing results from LLM model...\n{content}")
+            usage = response.usage
+            prompt_tokens = int(usage.prompt_tokens)
+            completion_tokens = int(usage.completion_tokens)
 
         if "WANDB_API_KEY" in os.environ:
             root_span = Trace(
                 name="inference_"
                 + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
-                kind="llm",  # kind can be "llm", "chain", "agent" or "tool
+                kind="llm",  # kind can be "llm", "chain", "agent", or "tool"
                 inputs={
                     "user_prompt": prompt["user"],
                     "system_prompt": prompt["system"],
                 },
-                outputs={
-                    "model_response": model_response["choices"][0]["message"]["content"]
-                },
+                outputs={"model_response": content},
             )
             root_span.log(name="inference")
 
-        # Returns: Response, Prompt token count, and Response token count
-        return (
-            model_response["choices"][0]["message"]["content"],
-            int(model_response["usage"]["prompt_tokens"]),
-            int(model_response["usage"]["completion_tokens"]),
-        )
+        # Returns: Response, Prompt token count, and Completion token count
+        return content, prompt_tokens, completion_tokens
+
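The intent of the changes above is easier to read outside the diff: o1-preview and o1-mini get their system text folded into the user message, a fixed temperature of 1, no streaming, and max_completion_tokens in place of max_tokens. The sketch below restates that logic as a standalone helper; the helper name is invented for illustration and is not part of AICaller, but the values mirror the diff.

    # Illustrative sketch only; build_o1_aware_params is not part of AICaller.
    # It condenses the o1-specific branches added to call_model above.
    def build_o1_aware_params(model: str, prompt: dict, max_tokens: int = 4096, stream: bool = True) -> dict:
        o1_models = ["o1-preview", "o1-mini"]
        if prompt["system"] == "":
            messages = [{"role": "user", "content": prompt["user"]}]
        elif model in o1_models:
            # o1 rejects a separate system role, so fold the system text into the user turn
            messages = [{"role": "user", "content": prompt["system"] + "\n" + prompt["user"]}]
        else:
            messages = [
                {"role": "system", "content": prompt["system"]},
                {"role": "user", "content": prompt["user"]},
            ]

        params = {
            "model": model,
            "messages": messages,
            "stream": stream,
            "temperature": 0.2,
            "max_tokens": max_tokens,
        }
        if model in o1_models:
            params["temperature"] = 1  # o1 only accepts the default temperature
            params["stream"] = False  # o1 does not support streaming
            params["max_completion_tokens"] = params.pop("max_tokens")
        return params

A call such as litellm.completion(**build_o1_aware_params(model, prompt)) would then behave consistently for both model families, which is what the reworked call_model does inline.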
diff --git a/cover_agent/ReportGenerator.py b/cover_agent/ReportGenerator.py
index 78eb8d746..ebf62d786 100644
--- a/cover_agent/ReportGenerator.py
+++ b/cover_agent/ReportGenerator.py
@@ -115,9 +115,15 @@ def generate_full_diff(cls, original, processed):
     @classmethod
     def generate_partial_diff(cls, original, processed, context_lines=3):
         """
-        Generates a partial diff of both the original and processed test files, 
+        Generates a partial diff of both the original and processed test files,
         showing only added, removed, or changed lines with a few lines of context.
-        
+
+        Note:
+        - The `difflib.unified_diff` function is used, which includes header lines (`---` and `+++`)
+          that indicate the original and modified file names or timestamps.
+        - It also includes context lines starting with `@@`, which show the range of lines affected.
+        - These lines are essential parts of the diff output and should be included in the expected outputs of tests.
+
         :param original: String content of the original test file.
         :param processed: String content of the processed test file.
         :param context_lines: Number of context lines to include around changes.
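To make the docstring note above concrete, here is a minimal standalone illustration of the difflib output that generate_partial_diff builds on (plain difflib usage, not code taken from ReportGenerator); it also matches what the new test_generate_partial_diff_basic test further down in this patch asserts on.

    import difflib

    original = "line1\nline2\nline3"
    processed = "line1\nline2 modified\nline3\nline4"
    diff_lines = difflib.unified_diff(
        original.splitlines(),
        processed.splitlines(),
        fromfile="original",
        tofile="processed",
        lineterm="",
        n=3,  # context lines around each change
    )
    print("\n".join(diff_lines))
    # Output:
    # --- original
    # +++ processed
    # @@ -1,3 +1,4 @@
    #  line1
    # -line2
    # +line2 modified
    #  line3
    # +line4

The `---`/`+++` header lines and the `@@` range line are produced by difflib itself, which is why the tests are expected to account for them.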
diff --git a/cover_agent/UnitTestGenerator.py b/cover_agent/UnitTestGenerator.py
index b0ed70026..e7da950a6 100644
--- a/cover_agent/UnitTestGenerator.py
+++ b/cover_agent/UnitTestGenerator.py
@@ -80,6 +80,7 @@ def __init__(
         self.total_input_token_count = 0
         self.total_output_token_count = 0
         self.testing_framework = "Unknown"
+        self.code_coverage_report = ""
 
         # Read self.source_file_path into a string
         with open(self.source_file_path, "r") as f:
@@ -184,7 +185,11 @@ def run_coverage(self):
                     if key not in self.last_coverage_percentages:
                         self.last_coverage_percentages[key] = 0
                     self.last_coverage_percentages[key] = percentage_covered
-            percentage_covered = total_lines_covered / total_lines
+            try:
+                percentage_covered = total_lines_covered / total_lines
+            except ZeroDivisionError:
+                self.logger.error(f"ZeroDivisionError: Attempting to perform total_lines_covered / total_lines: {total_lines_covered} / {total_lines}.")
+                percentage_covered = 0
 
             self.logger.info(
                 f"Total lines covered: {total_lines_covered}, Total lines missed: {total_lines_missed}, Total lines: {total_lines}"
@@ -325,9 +330,11 @@ def initial_test_suite_analysis(self):
                 prompt_headers_indentation = self.prompt_builder.build_prompt_custom(
                     file="analyze_suite_test_headers_indentation"
                 )
+                self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model  # Exception for OpenAI's new reasoning engines
                 response, prompt_token_count, response_token_count = (
                     self.ai_caller.call_model(prompt=prompt_headers_indentation)
                 )
+                self.ai_caller.model = self.llm_model
                 self.total_input_token_count += prompt_token_count
                 self.total_output_token_count += response_token_count
                 tests_dict = load_yaml(response)
@@ -350,9 +357,11 @@ def initial_test_suite_analysis(self):
                 prompt_test_insert_line = self.prompt_builder.build_prompt_custom(
                     file="analyze_suite_test_insert_line"
                 )
+                self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model  # Exception for OpenAI's new reasoning engines
                 response, prompt_token_count, response_token_count = (
                     self.ai_caller.call_model(prompt=prompt_test_insert_line)
                 )
+                self.ai_caller.model = self.llm_model
                 self.total_input_token_count += prompt_token_count
                 self.total_output_token_count += response_token_count
                 tests_dict = load_yaml(response)
@@ -401,8 +410,9 @@ def generate_tests(self, max_tokens=4096):
         """
         self.prompt = self.build_prompt()
 
+        stream = False if self.llm_model in ["o1-preview", "o1-mini"] else True
         response, prompt_token_count, response_token_count = (
-            self.ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens)
+            self.ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens, stream=stream)
         )
         self.total_input_token_count += prompt_token_count
         self.total_output_token_count += response_token_count
@@ -790,15 +800,17 @@ def extract_error_message(self, stderr, stdout):
             )
 
             # Run the analysis via LLM
+            self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model  # Exception for OpenAI's new reasoning engines
             response, prompt_token_count, response_token_count = (
                 self.ai_caller.call_model(prompt=prompt_headers_indentation, stream=False)
             )
+            self.ai_caller.model = self.llm_model  # Reset
             self.total_input_token_count += prompt_token_count
             self.total_output_token_count += response_token_count
             tests_dict = load_yaml(response)
 
             return tests_dict.get("error_summary", f"ERROR: Unable to summarize error message from inputs. \nSTDERR: {stderr}\nSTDOUT: {stdout}.")
         except Exception as e:
-            logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}\n\n{response}")
+            logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}")
             logging.error(f"Error extracting error message: {e}")
             return ""
\ No newline at end of file
diff --git a/cover_agent/version.txt b/cover_agent/version.txt
index 341cf11fa..7dff5b892 100644
--- a/cover_agent/version.txt
+++ b/cover_agent/version.txt
@@ -1 +1 @@
-0.2.0
\ No newline at end of file
+0.2.1
\ No newline at end of file
diff --git a/tests/test_AICaller.py b/tests/test_AICaller.py
index 91756b23d..9a55a2b88 100644
--- a/tests/test_AICaller.py
+++ b/tests/test_AICaller.py
@@ -1,7 +1,7 @@
 import os
 
 import pytest
-from unittest.mock import patch
+from unittest.mock import patch, Mock
 
 from cover_agent.AICaller import AICaller
 
@@ -112,3 +112,35 @@ def test_call_model_missing_keys(self, ai_caller):
             str(exc_info.value)
             == "\"The prompt dictionary must contain 'system' and 'user' keys.\""
         )
+
+    @patch("cover_agent.AICaller.litellm.completion")
+    def test_call_model_o1_preview(self, mock_completion, ai_caller):
+        ai_caller.model = "o1-preview"
+        prompt = {"system": "System message", "user": "Hello, world!"}
+        # Mock the response
+        mock_response = Mock()
+        mock_response.choices = [Mock(message=Mock(content="response"))]
+        mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
+        mock_completion.return_value = mock_response
+        # Call the method
+        response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=False)
+        assert response == "response"
+        assert prompt_tokens == 2
+        assert response_tokens == 10
+
+    @patch("cover_agent.AICaller.litellm.completion")
+    def test_call_model_streaming_response(self, mock_completion, ai_caller):
+        prompt = {"system": "", "user": "Hello, world!"}
+        # Mock the response to be an iterable of chunks
+        mock_chunk = Mock()
+        mock_chunk.choices = [Mock(delta=Mock(content="response part"))]
+        mock_completion.return_value = [mock_chunk]
+        with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
+            mock_builder.return_value = {
+                "choices": [{"message": {"content": "response"}}],
+                "usage": {"prompt_tokens": 2, "completion_tokens": 10},
+            }
+            response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=True)
+            assert response == "response"
+            assert prompt_tokens == 2
+            assert response_tokens == 10
\ No newline at end of file
diff --git a/tests/test_CoverAgent.py b/tests/test_CoverAgent.py
index ebc084e0f..c29b0061a 100644
--- a/tests/test_CoverAgent.py
+++ b/tests/test_CoverAgent.py
@@ -167,4 +167,42 @@ def test_duplicate_test_file_without_output_path(self, mock_isfile):
 
         # Clean up the temp files
         os.remove(temp_source_file.name)
-        os.remove(temp_test_file.name)
\ No newline at end of file
+        os.remove(temp_test_file.name)
+
+    @patch("cover_agent.CoverAgent.os.environ", {})
+    @patch("cover_agent.CoverAgent.sys.exit")
+    @patch("cover_agent.CoverAgent.UnitTestGenerator")
+    @patch("cover_agent.CoverAgent.UnitTestDB")
+    def test_run_max_iterations_strict_coverage(self, mock_test_db, mock_unit_test_generator, mock_sys_exit):
+        with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+            with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_test_file:
+                args = argparse.Namespace(
+                    source_file_path=temp_source_file.name,
+                    test_file_path=temp_test_file.name,
test_file_output_path="output_test_file.py", + code_coverage_report_path="coverage_report.xml", + test_command="pytest", + test_command_dir=os.getcwd(), + included_files=None, + coverage_type="cobertura", + report_filepath="test_results.html", + desired_coverage=90, + max_iterations=1, + additional_instructions="", + model="openai/test-model", + api_base="openai/test-api", + use_report_coverage_feature_flag=False, + log_db_path="", + run_tests_multiple_times=False, + strict_coverage=True + ) + # Mock the methods used in run + instance = mock_unit_test_generator.return_value + instance.current_coverage = 0.5 # below desired coverage + instance.desired_coverage = 90 + instance.generate_tests.return_value = {"new_tests": [{}]} + agent = CoverAgent(args) + agent.run() + # Assertions to ensure sys.exit was called + mock_sys_exit.assert_called_once_with(2) + mock_test_db.return_value.dump_to_report.assert_called_once_with(args.report_filepath) diff --git a/tests/test_PromptBuilder.py b/tests/test_PromptBuilder.py index 7a42a31d5..d57c32f0d 100644 --- a/tests/test_PromptBuilder.py +++ b/tests/test_PromptBuilder.py @@ -204,4 +204,4 @@ def test_custom_analyze_test_run_failure(self): # Clean up os.remove(source_file.name) os.remove(test_file.name) - os.remove(tmp_file.name) \ No newline at end of file + os.remove(tmp_file.name) diff --git a/tests/test_ReportGenerator.py b/tests/test_ReportGenerator.py index 1008559ed..43e683ea9 100644 --- a/tests/test_ReportGenerator.py +++ b/tests/test_ReportGenerator.py @@ -45,4 +45,14 @@ def test_generate_report(self, sample_results, expected_output, tmp_path): assert expected_output[2] in content # Check if the row includes "test_current_date" assert expected_output[3] in content # Check if the HTML closes properly + def test_generate_partial_diff_basic(self): + original = "line1\nline2\nline3" + processed = "line1\nline2 modified\nline3\nline4" + diff_output = ReportGenerator.generate_partial_diff(original, processed) + assert '+line2 modified' in diff_output + assert '+line4' in diff_output + assert '-line2' in diff_output + assert ' line1' in diff_output + + # Additional validation can be added based on specific content if required diff --git a/tests/test_UnitTestDB.py b/tests/test_UnitTestDB.py index d6cd54168..6389bc634 100644 --- a/tests/test_UnitTestDB.py +++ b/tests/test_UnitTestDB.py @@ -1,6 +1,8 @@ import pytest import os from datetime import datetime, timedelta +from cover_agent.UnitTestDB import dump_to_report_cli +from cover_agent.UnitTestDB import dump_to_report from cover_agent.UnitTestDB import UnitTestDB, UnitTestGenerationAttempt DB_NAME = "unit_test_runs.db" @@ -91,3 +93,21 @@ def test_dump_to_report(self, unit_test_db, tmp_path): assert "sample test code" in content assert "sample new test code" in content assert "def test_example(): pass" in content + + + def test_dump_to_report_cli_custom_args(self, unit_test_db, tmp_path, monkeypatch): + custom_db_path = str(tmp_path / "cli_custom_unit_test_runs.db") + custom_report_filepath = str(tmp_path / "cli_custom_report.html") + monkeypatch.setattr("sys.argv", [ + "prog", + "--path-to-db", custom_db_path, + "--report-filepath", custom_report_filepath + ]) + dump_to_report_cli() + assert os.path.exists(custom_report_filepath) + + + def test_dump_to_report_defaults(self, unit_test_db, tmp_path): + report_filepath = tmp_path / "default_report.html" + dump_to_report(report_filepath=str(report_filepath)) + assert os.path.exists(report_filepath) diff --git a/tests/test_UnitTestGenerator.py 
b/tests/test_UnitTestGenerator.py index 116cf0c4b..3cb5a7394 100644 --- a/tests/test_UnitTestGenerator.py +++ b/tests/test_UnitTestGenerator.py @@ -1,10 +1,12 @@ -import pytest -from cover_agent.UnitTestGenerator import UnitTestGenerator +from cover_agent.CoverageProcessor import CoverageProcessor from cover_agent.ReportGenerator import ReportGenerator -import os - +from cover_agent.Runner import Runner +from cover_agent.UnitTestGenerator import UnitTestGenerator from unittest.mock import patch, mock_open - +import datetime +import os +import pytest +import tempfile class TestUnitTestGenerator: def test_get_included_files_mixed_paths(self): @@ -28,3 +30,111 @@ def test_get_included_files_valid_paths(self): result == "file_path: `file1.txt`\ncontent:\n```\nfile content\n```\nfile_path: `file2.txt`\ncontent:\n```\nfile content\n```" ) + def test_get_code_language_no_extension(self): + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + generator = UnitTestGenerator( + source_file_path=temp_source_file.name, + test_file_path="test_test.py", + code_coverage_report_path="coverage.xml", + test_command="pytest", + llm_model="gpt-3" + ) + language = generator.get_code_language("filename") + assert language == "unknown" + + def test_extract_error_message_exception_handling(self): + # PromptBuilder will not instantiate so we're expecting an empty error_message. + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + generator = UnitTestGenerator( + source_file_path=temp_source_file.name, + test_file_path="test_test.py", + code_coverage_report_path="coverage.xml", + test_command="pytest", + llm_model="gpt-3" + ) + with patch.object(generator, 'ai_caller') as mock_ai_caller: + mock_ai_caller.call_model.side_effect = Exception("Mock exception") + error_message = generator.extract_error_message(stderr="stderr content", stdout="stdout content") + assert '' in error_message + + # def test_get_included_files_none(self): + # result = UnitTestGenerator.get_included_files(None) + # assert result == "" + + # def test_run_coverage_command_failure(self): + # with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + # generator = UnitTestGenerator( + # source_file_path=temp_source_file.name, + # test_file_path="test_test.py", + # code_coverage_report_path="coverage.xml", + # test_command="invalid_command", + # llm_model="gpt-3" + # ) + # with pytest.raises(AssertionError): + # generator.run_coverage() + + # def test_extract_error_message_success(self): + # with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + # generator = UnitTestGenerator( + # source_file_path=temp_source_file.name, + # test_file_path="test_test.py", + # code_coverage_report_path="coverage.xml", + # test_command="pytest", + # llm_model="gpt-3" + # ) + # with patch.object(generator.ai_caller, 'call_model', return_value=("error_summary: 'Mocked error summary'", 10, 10)): + # error_message = generator.extract_error_message(stderr="stderr content", stdout="stdout content") + # assert error_message == "" + + def test_run_coverage_with_report_coverage_flag(self): + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + generator = UnitTestGenerator( + source_file_path=temp_source_file.name, + test_file_path="test_test.py", + code_coverage_report_path="coverage.xml", + test_command="pytest", + llm_model="gpt-3", + use_report_coverage_feature_flag=True + ) + with patch.object(Runner, 'run_command', 
return_value=("", "", 0, datetime.datetime.now())): + with patch.object(CoverageProcessor, 'process_coverage_report', return_value={'test.py': ([], [], 1.0)}): + generator.run_coverage() + # Dividing by zero so we're expecting a logged error and a return of 0 + assert generator.current_coverage == 0 + + def test_build_prompt_with_failed_tests(self): + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + generator = UnitTestGenerator( + source_file_path=temp_source_file.name, + test_file_path="test_test.py", + code_coverage_report_path="coverage.xml", + test_command="pytest", + llm_model="gpt-3" + ) + generator.failed_test_runs = [ + { + "code": {"test_code": "def test_example(): assert False"}, + "error_message": "AssertionError" + } + ] + prompt = generator.build_prompt() + assert "Failed Test:" in prompt['user'] + + + def test_generate_tests_invalid_yaml(self): + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file: + generator = UnitTestGenerator( + source_file_path=temp_source_file.name, + test_file_path="test_test.py", + code_coverage_report_path="coverage.xml", + test_command="pytest", + llm_model="gpt-3" + ) + generator.build_prompt = lambda: "Test prompt" + with patch.object(generator.ai_caller, 'call_model', return_value=("This is not YAML", 10, 10)): + result = generator.generate_tests() + + # The eventual call to try_fix_yaml() will end up spitting out the same string but deeming is "YAML." + # While this is not a valid YAML, the function will return the original string (for better or for worse). + assert result =="This is not YAML" + \ No newline at end of file diff --git a/tests/test_load_yaml.py b/tests/test_load_yaml.py index d7e8db7fa..e88757bdd 100644 --- a/tests/test_load_yaml.py +++ b/tests/test_load_yaml.py @@ -95,26 +95,26 @@ def test_load_invalid_yaml2(self): # auto-generated by cover agent -def test_try_fix_yaml_snippet_extraction(): - from cover_agent.utils import try_fix_yaml + def test_try_fix_yaml_snippet_extraction(self): + from cover_agent.utils import try_fix_yaml - yaml_str = "```yaml\nname: John Smith\nage: 35\n```" - expected_output = {"name": "John Smith", "age": 35} - assert try_fix_yaml(yaml_str) == expected_output + yaml_str = "```yaml\nname: John Smith\nage: 35\n```" + expected_output = {"name": "John Smith", "age": 35} + assert try_fix_yaml(yaml_str) == expected_output -def test_try_fix_yaml_remove_all_lines(): - from cover_agent.utils import try_fix_yaml + def test_try_fix_yaml_remove_all_lines(self): + from cover_agent.utils import try_fix_yaml - yaml_str = "language: python\nname: John Smith\nage: 35\ninvalid_line" - expected_output = {"language": "python", "name": "John Smith", "age": 35} - assert try_fix_yaml(yaml_str) == expected_output + yaml_str = "language: python\nname: John Smith\nage: 35\ninvalid_line" + expected_output = {"language": "python", "name": "John Smith", "age": 35} + assert try_fix_yaml(yaml_str) == expected_output -def test_try_fix_yaml_llama3_8b(): - from cover_agent.utils import try_fix_yaml + def test_try_fix_yaml_llama3_8b(self): + from cover_agent.utils import try_fix_yaml - yaml_str = """\ + yaml_str = """\ here is the response: language: python @@ -128,31 +128,36 @@ def test_try_fix_yaml_llama3_8b(): hope this helps! 
""" - expected_output = { - "here is the response": None, - "language": "python", - "new_tests": [ - { - "test_behavior": "aaa\n", - "test_name": "test_current_date", - "test_code": "bbb\n", - "test_tags": "happy path", - } - ], - } - assert try_fix_yaml(yaml_str) == expected_output + expected_output = { + "here is the response": None, + "language": "python", + "new_tests": [ + { + "test_behavior": "aaa\n", + "test_name": "test_current_date", + "test_code": "bbb\n", + "test_tags": "happy path", + } + ], + } + assert try_fix_yaml(yaml_str) == expected_output -def test_invalid_yaml_wont_parse(): - from cover_agent.utils import try_fix_yaml + def test_invalid_yaml_wont_parse(self): + from cover_agent.utils import try_fix_yaml - yaml_str = """ + yaml_str = """ here is the response language: python tests: - test_behavior: | - aaa - test_name:""" - expected_output = None - assert load_yaml(yaml_str) == expected_output +aaa +test_name:""" + expected_output = None + assert load_yaml(yaml_str) == expected_output + + def test_load_yaml_second_fallback_failure(self): + yaml_str = "```yaml\ninvalid_yaml: [unclosed_list\n```" + assert load_yaml(yaml_str) is None + diff --git a/tests/test_main.py b/tests/test_main.py index 35fa57add..216e99e61 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -94,3 +94,41 @@ def test_main_test_file_not_found( main() assert str(exc_info.value) == f"Test file not found at {args.test_file_path}" + + @patch("cover_agent.main.CoverAgent") + @patch("cover_agent.main.parse_args") + @patch("cover_agent.main.os.path.isfile") + def test_main_calls_agent_run( + self, mock_isfile, mock_parse_args, mock_cover_agent + ): + args = argparse.Namespace( + source_file_path="test_source.py", + test_file_path="test_file.py", + test_file_output_path="", + code_coverage_report_path="coverage_report.xml", + test_command="pytest", + test_command_dir=os.getcwd(), + included_files=None, + coverage_type="cobertura", + report_filepath="test_results.html", + desired_coverage=90, + max_iterations=10, + additional_instructions="", + model="gpt-4o", + api_base="http://localhost:11434", + strict_coverage=False, + run_tests_multiple_times=1, + use_report_coverage_feature_flag=False, + log_db_path="", + ) + mock_parse_args.return_value = args + # Mock os.path.isfile to return True for both source and test file paths + mock_isfile.side_effect = lambda path: path in [args.source_file_path, args.test_file_path] + mock_agent_instance = MagicMock() + mock_cover_agent.return_value = mock_agent_instance + + main() + + mock_cover_agent.assert_called_once_with(args) + mock_agent_instance.run.assert_called_once() + diff --git a/tests_integration/increase_coverage.py b/tests_integration/increase_coverage.py index 6c9bfc83e..c3ad834d6 100755 --- a/tests_integration/increase_coverage.py +++ b/tests_integration/increase_coverage.py @@ -9,18 +9,19 @@ # List of source/test files to iterate over: SOURCE_TEST_FILE_LIST = [ # ["cover_agent/AICaller.py", "tests/test_AICaller.py"], - # ["cover_agent/CoverAgent.py", "tests/test_CoverAgent.py"], + # ["cover_agent/CoverAgent.py", "tests/test_CoverAgent.py"], # ["cover_agent/CoverageProcessor.py", "tests/test_CoverageProcessor.py"], + # ["cover_agent/CustomLogger.py", ""], # ["cover_agent/FilePreprocessor.py", "tests/test_FilePreprocessor.py"], # ["cover_agent/PromptBuilder.py", "tests/test_PromptBuilder.py"], - ["cover_agent/ReportGenerator.py", "tests/test_ReportGenerator.py"], + # ["cover_agent/ReportGenerator.py", "tests/test_ReportGenerator.py"], # 
["cover_agent/Runner.py", "tests/test_Runner.py"], # ["cover_agent/UnitTestDB.py", "tests/test_UnitTestDB.py"], # ["cover_agent/UnitTestGenerator.py", "tests/test_UnitTestGenerator.py"], + # ["cover_agent/main.py", "tests/test_main.py"], + # ["cover_agent/settings/config_loader.py", ""], + ["cover_agent/utils.py", "tests/test_load_yaml.py"], # ["cover_agent/version.py", "tests/test_version.py"], - # ["cover_agent/utils.py", "tests/test_load_yaml.py"], - # ["cover_agent/settings/config_loader.py", "tests/test_.py"], - # ["cover_agent/CustomLogger.py", ""], ] @@ -36,9 +37,10 @@ def __init__(self, source_file_path, test_file_path): self.coverage_type = "cobertura" self.report_filepath = "test_results.html" self.desired_coverage = 100 - self.max_iterations = 4 - self.additional_instructions = "" - self.model = "gpt-4o" + self.max_iterations = 5 + self.additional_instructions = "Do not indent the tests" + # self.model = "gpt-4o" + self.model = "o1-mini" self.api_base = "http://localhost:11434" self.prompt_only = False self.strict_coverage = False