diff --git a/cover_agent/AICaller.py b/cover_agent/AICaller.py
index a78d8c1f3..8f996e986 100644
--- a/cover_agent/AICaller.py
+++ b/cover_agent/AICaller.py
@@ -13,7 +13,7 @@ def __init__(self, model: str, api_base: str = ""):
Parameters:
model (str): The name of the model to be used.
- api_base (str): The base API url to use in case model is set to Ollama or Hugging Face
+ api_base (str): The base API URL to use in case the model is set to Ollama or Hugging Face.
"""
self.model = model
self.api_base = api_base
@@ -25,6 +25,7 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
Parameters:
prompt (dict): The prompt to be sent to the language model.
max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 4096.
+ stream (bool, optional): Whether to stream the response. Defaults to True.
Returns:
tuple: A tuple containing the response generated by the language model, the number of tokens used from the prompt, and the total number of tokens in the response.
@@ -36,21 +37,34 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
if prompt["system"] == "":
messages = [{"role": "user", "content": prompt["user"]}]
else:
- messages = [
- {"role": "system", "content": prompt["system"]},
- {"role": "user", "content": prompt["user"]},
- ]
+ if self.model in ["o1-preview", "o1-mini"]:
+ # o1 models don't accept a system message, so prepend it to the user prompt
+ messages = [
+ {"role": "user", "content": prompt["system"] + "\n" + prompt["user"]},
+ ]
+ else:
+ messages = [
+ {"role": "system", "content": prompt["system"]},
+ {"role": "user", "content": prompt["user"]},
+ ]
- # Default Completion parameters
+ # Default completion parameters
completion_params = {
"model": self.model,
"messages": messages,
- "max_tokens": max_tokens,
- "stream": True,
+ "stream": stream, # Use the stream parameter passed to the method
"temperature": 0.2,
+ "max_tokens": max_tokens,
}
- # API base exception for OpenAI Compatible, Ollama and Hugging Face models
+ # Model-specific adjustments
+ if self.model in ["o1-preview", "o1-mini"]:
+ completion_params["temperature"] = 1
+ completion_params["stream"] = False # o1 doesn't support streaming
+ completion_params["max_completion_tokens"] = max_tokens
+ completion_params.pop("max_tokens", None) # Remove 'max_tokens' if present
+
+ # API base exception for OpenAI Compatible, Ollama, and Hugging Face models
if (
"ollama" in self.model
or "huggingface" in self.model
@@ -60,39 +74,46 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
response = litellm.completion(**completion_params)
- chunks = []
- print("Streaming results from LLM model...") if stream else None
- try:
- for chunk in response:
- print(chunk.choices[0].delta.content or "", end="", flush=True) if stream else None
- chunks.append(chunk)
- time.sleep(
- 0.01
- ) # Optional: Delay to simulate more 'natural' response pacing
- except Exception as e:
- print(f"Error during streaming: {e}") if stream else None
- print("\n") if stream else None
-
- model_response = litellm.stream_chunk_builder(chunks, messages=messages)
+ if stream:
+ chunks = []
+ print("Streaming results from LLM model...")
+ try:
+ for chunk in response:
+ print(chunk.choices[0].delta.content or "", end="", flush=True)
+ chunks.append(chunk)
+ time.sleep(
+ 0.01
+ ) # Optional: Delay to simulate more 'natural' response pacing
+ except Exception as e:
+ print(f"Error during streaming: {e}")
+ print("\n")
+ # Build the final response from the streamed chunks
+ model_response = litellm.stream_chunk_builder(chunks, messages=messages)
+ content = model_response["choices"][0]["message"]["content"]
+ usage = model_response["usage"]
+ prompt_tokens = int(usage["prompt_tokens"])
+ completion_tokens = int(usage["completion_tokens"])
+ else:
+ # Non-streaming responses expose .choices and .usage attributes directly
+ content = response.choices[0].message.content
+ print(f"Printing results from LLM model...\n{content}")
+ usage = response.usage
+ prompt_tokens = int(usage.prompt_tokens)
+ completion_tokens = int(usage.completion_tokens)
if "WANDB_API_KEY" in os.environ:
root_span = Trace(
name="inference_"
+ datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
- kind="llm", # kind can be "llm", "chain", "agent" or "tool
+ kind="llm", # kind can be "llm", "chain", "agent", or "tool"
inputs={
"user_prompt": prompt["user"],
"system_prompt": prompt["system"],
},
- outputs={
- "model_response": model_response["choices"][0]["message"]["content"]
- },
+ outputs={"model_response": content},
)
root_span.log(name="inference")
- # Returns: Response, Prompt token count, and Response token count
- return (
- model_response["choices"][0]["message"]["content"],
- int(model_response["usage"]["prompt_tokens"]),
- int(model_response["usage"]["completion_tokens"]),
- )
+ # Returns: Response, Prompt token count, and Completion token count
+ return content, prompt_tokens, completion_tokens
+
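The reworked call_model above now returns (content, prompt_tokens, completion_tokens) and lets callers turn streaming off, which the o1 models require. A minimal usage sketch, assuming an OpenAI key is already configured for litellm; the model name and prompt text are illustrative, not taken from this PR:

    from cover_agent.AICaller import AICaller

    caller = AICaller(model="o1-mini")
    prompt = {
        "system": "You are a unit test generator.",  # folded into the user message for o1 models
        "user": "Write a pytest test for an add(a, b) function.",
    }
    # o1 models don't support streaming, so pass stream=False
    content, prompt_tokens, completion_tokens = caller.call_model(prompt, stream=False)
    print(content, prompt_tokens, completion_tokens)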
diff --git a/cover_agent/ReportGenerator.py b/cover_agent/ReportGenerator.py
index 78eb8d746..ebf62d786 100644
--- a/cover_agent/ReportGenerator.py
+++ b/cover_agent/ReportGenerator.py
@@ -115,9 +115,15 @@ def generate_full_diff(cls, original, processed):
@classmethod
def generate_partial_diff(cls, original, processed, context_lines=3):
"""
- Generates a partial diff of both the original and processed test files,
+ Generates a partial diff of both the original and processed test files,
showing only added, removed, or changed lines with a few lines of context.
-
+
+ Note:
+ - The `difflib.unified_diff` function is used, which includes header lines (`---` and `+++`)
+ that indicate the original and modified file names or timestamps.
+ - It also includes context lines starting with `@@`, which show the range of lines affected.
+ - These lines are essential parts of the diff output and should be included in the expected outputs of tests.
+
:param original: String content of the original test file.
:param processed: String content of the processed test file.
:param context_lines: Number of context lines to include around changes.
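The note added to the docstring above describes the header and range lines that difflib.unified_diff emits. A small standalone sketch of that behavior (the file labels are arbitrary, and n=3 mirrors the default context_lines):

    import difflib

    original = "line1\nline2\nline3"
    processed = "line1\nline2 modified\nline3\nline4"

    diff_lines = difflib.unified_diff(
        original.splitlines(),
        processed.splitlines(),
        fromfile="original",
        tofile="processed",
        n=3,
        lineterm="",
    )
    print("\n".join(diff_lines))
    # Prints the '--- original' / '+++ processed' headers, an '@@ -1,3 +1,4 @@'
    # range line, and then the context, removed, and added lines.

This is also the shape that the new test_generate_partial_diff_basic assertions rely on.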
diff --git a/cover_agent/UnitTestGenerator.py b/cover_agent/UnitTestGenerator.py
index b0ed70026..e7da950a6 100644
--- a/cover_agent/UnitTestGenerator.py
+++ b/cover_agent/UnitTestGenerator.py
@@ -80,6 +80,7 @@ def __init__(
self.total_input_token_count = 0
self.total_output_token_count = 0
self.testing_framework = "Unknown"
+ self.code_coverage_report = ""
# Read self.source_file_path into a string
with open(self.source_file_path, "r") as f:
@@ -184,7 +185,11 @@ def run_coverage(self):
if key not in self.last_coverage_percentages:
self.last_coverage_percentages[key] = 0
self.last_coverage_percentages[key] = percentage_covered
- percentage_covered = total_lines_covered / total_lines
+ try:
+ percentage_covered = total_lines_covered / total_lines
+ except ZeroDivisionError:
+ self.logger.error(f"ZeroDivisionError: Attempting to perform total_lines_covered / total_lines: {total_lines_covered} / {total_lines}.")
+ percentage_covered = 0
self.logger.info(
f"Total lines covered: {total_lines_covered}, Total lines missed: {total_lines_missed}, Total lines: {total_lines}"
@@ -325,9 +330,11 @@ def initial_test_suite_analysis(self):
prompt_headers_indentation = self.prompt_builder.build_prompt_custom(
file="analyze_suite_test_headers_indentation"
)
+ self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model # Exception for OpenAI's o1 reasoning models
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_headers_indentation)
)
+ self.ai_caller.model = self.llm_model
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)
@@ -350,9 +357,11 @@ def initial_test_suite_analysis(self):
prompt_test_insert_line = self.prompt_builder.build_prompt_custom(
file="analyze_suite_test_insert_line"
)
+ self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model # Exception for OpenAI's o1 reasoning models
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_test_insert_line)
)
+ self.ai_caller.model = self.llm_model
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)
@@ -401,8 +410,9 @@ def generate_tests(self, max_tokens=4096):
"""
self.prompt = self.build_prompt()
+ stream = self.llm_model not in ["o1-preview", "o1-mini"] # o1 models don't support streaming
response, prompt_token_count, response_token_count = (
- self.ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens)
+ self.ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens, stream=stream)
)
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
@@ -790,15 +800,17 @@ def extract_error_message(self, stderr, stdout):
)
# Run the analysis via LLM
+ self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model # Exception for OpenAI's o1 reasoning models
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_headers_indentation, stream=False)
)
+ self.ai_caller.model = self.llm_model # Reset
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)
return tests_dict.get("error_summary", f"ERROR: Unable to summarize error message from inputs. STDERR: {stderr}\nSTDOUT: {stdout}.")
except Exception as e:
- logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}\n\n{response}")
+ logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}")
logging.error(f"Error extracting error message: {e}")
return ""
\ No newline at end of file
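The model-override pattern repeated above (point self.ai_caller.model at gpt-4o for the o1 models, make the call, then restore it) could also be wrapped in a small helper; a hypothetical sketch under that assumption, not part of this PR:

    from contextlib import contextmanager

    @contextmanager
    def override_model(ai_caller, llm_model):
        """Temporarily route o1-* calls through gpt-4o, restoring the original model afterwards."""
        saved = ai_caller.model
        ai_caller.model = "gpt-4o" if llm_model in ["o1-preview", "o1-mini"] else llm_model
        try:
            yield ai_caller
        finally:
            ai_caller.model = saved

Each call site would then wrap the request in a with override_model(self.ai_caller, self.llm_model) block, which guarantees the reset even if call_model raises.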
diff --git a/cover_agent/version.txt b/cover_agent/version.txt
index 341cf11fa..7dff5b892 100644
--- a/cover_agent/version.txt
+++ b/cover_agent/version.txt
@@ -1 +1 @@
-0.2.0
\ No newline at end of file
+0.2.1
\ No newline at end of file
diff --git a/tests/test_AICaller.py b/tests/test_AICaller.py
index 91756b23d..9a55a2b88 100644
--- a/tests/test_AICaller.py
+++ b/tests/test_AICaller.py
@@ -1,7 +1,7 @@
import os
import pytest
-from unittest.mock import patch
+from unittest.mock import patch, Mock
from cover_agent.AICaller import AICaller
@@ -112,3 +112,35 @@ def test_call_model_missing_keys(self, ai_caller):
str(exc_info.value)
== "\"The prompt dictionary must contain 'system' and 'user' keys.\""
)
+
+ @patch("cover_agent.AICaller.litellm.completion")
+ def test_call_model_o1_preview(self, mock_completion, ai_caller):
+ ai_caller.model = "o1-preview"
+ prompt = {"system": "System message", "user": "Hello, world!"}
+ # Mock the response
+ mock_response = Mock()
+ mock_response.choices = [Mock(message=Mock(content="response"))]
+ mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
+ mock_completion.return_value = mock_response
+ # Call the method
+ response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=False)
+ assert response == "response"
+ assert prompt_tokens == 2
+ assert response_tokens == 10
+
+ @patch("cover_agent.AICaller.litellm.completion")
+ def test_call_model_streaming_response(self, mock_completion, ai_caller):
+ prompt = {"system": "", "user": "Hello, world!"}
+ # Mock the response to be an iterable of chunks
+ mock_chunk = Mock()
+ mock_chunk.choices = [Mock(delta=Mock(content="response part"))]
+ mock_completion.return_value = [mock_chunk]
+ with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
+ mock_builder.return_value = {
+ "choices": [{"message": {"content": "response"}}],
+ "usage": {"prompt_tokens": 2, "completion_tokens": 10},
+ }
+ response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=True)
+ assert response == "response"
+ assert prompt_tokens == 2
+ assert response_tokens == 10
\ No newline at end of file
diff --git a/tests/test_CoverAgent.py b/tests/test_CoverAgent.py
index ebc084e0f..c29b0061a 100644
--- a/tests/test_CoverAgent.py
+++ b/tests/test_CoverAgent.py
@@ -167,4 +167,42 @@ def test_duplicate_test_file_without_output_path(self, mock_isfile):
# Clean up the temp files
os.remove(temp_source_file.name)
- os.remove(temp_test_file.name)
\ No newline at end of file
+ os.remove(temp_test_file.name)
+
+ @patch("cover_agent.CoverAgent.os.environ", {})
+ @patch("cover_agent.CoverAgent.sys.exit")
+ @patch("cover_agent.CoverAgent.UnitTestGenerator")
+ @patch("cover_agent.CoverAgent.UnitTestDB")
+ def test_run_max_iterations_strict_coverage(self, mock_test_db, mock_unit_test_generator, mock_sys_exit):
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_test_file:
+ args = argparse.Namespace(
+ source_file_path=temp_source_file.name,
+ test_file_path=temp_test_file.name,
+ test_file_output_path="output_test_file.py",
+ code_coverage_report_path="coverage_report.xml",
+ test_command="pytest",
+ test_command_dir=os.getcwd(),
+ included_files=None,
+ coverage_type="cobertura",
+ report_filepath="test_results.html",
+ desired_coverage=90,
+ max_iterations=1,
+ additional_instructions="",
+ model="openai/test-model",
+ api_base="openai/test-api",
+ use_report_coverage_feature_flag=False,
+ log_db_path="",
+ run_tests_multiple_times=False,
+ strict_coverage=True
+ )
+ # Mock the methods used in run
+ instance = mock_unit_test_generator.return_value
+ instance.current_coverage = 0.5 # below desired coverage
+ instance.desired_coverage = 90
+ instance.generate_tests.return_value = {"new_tests": [{}]}
+ agent = CoverAgent(args)
+ agent.run()
+ # Assertions to ensure sys.exit was called
+ mock_sys_exit.assert_called_once_with(2)
+ mock_test_db.return_value.dump_to_report.assert_called_once_with(args.report_filepath)
diff --git a/tests/test_PromptBuilder.py b/tests/test_PromptBuilder.py
index 7a42a31d5..d57c32f0d 100644
--- a/tests/test_PromptBuilder.py
+++ b/tests/test_PromptBuilder.py
@@ -204,4 +204,4 @@ def test_custom_analyze_test_run_failure(self):
# Clean up
os.remove(source_file.name)
os.remove(test_file.name)
- os.remove(tmp_file.name)
\ No newline at end of file
+ os.remove(tmp_file.name)
diff --git a/tests/test_ReportGenerator.py b/tests/test_ReportGenerator.py
index 1008559ed..43e683ea9 100644
--- a/tests/test_ReportGenerator.py
+++ b/tests/test_ReportGenerator.py
@@ -45,4 +45,14 @@ def test_generate_report(self, sample_results, expected_output, tmp_path):
assert expected_output[2] in content # Check if the row includes "test_current_date"
assert expected_output[3] in content # Check if the HTML closes properly
+ def test_generate_partial_diff_basic(self):
+ original = "line1\nline2\nline3"
+ processed = "line1\nline2 modified\nline3\nline4"
+ diff_output = ReportGenerator.generate_partial_diff(original, processed)
+ assert '+line2 modified' in diff_output
+ assert '+line4' in diff_output
+ assert '-line2' in diff_output
+ assert ' line1' in diff_output
+
+
# Additional validation can be added based on specific content if required
diff --git a/tests/test_UnitTestDB.py b/tests/test_UnitTestDB.py
index d6cd54168..6389bc634 100644
--- a/tests/test_UnitTestDB.py
+++ b/tests/test_UnitTestDB.py
@@ -1,6 +1,8 @@
import pytest
import os
from datetime import datetime, timedelta
+from cover_agent.UnitTestDB import dump_to_report_cli
+from cover_agent.UnitTestDB import dump_to_report
from cover_agent.UnitTestDB import UnitTestDB, UnitTestGenerationAttempt
DB_NAME = "unit_test_runs.db"
@@ -91,3 +93,21 @@ def test_dump_to_report(self, unit_test_db, tmp_path):
assert "sample test code" in content
assert "sample new test code" in content
assert "def test_example(): pass" in content
+
+
+ def test_dump_to_report_cli_custom_args(self, unit_test_db, tmp_path, monkeypatch):
+ custom_db_path = str(tmp_path / "cli_custom_unit_test_runs.db")
+ custom_report_filepath = str(tmp_path / "cli_custom_report.html")
+ monkeypatch.setattr("sys.argv", [
+ "prog",
+ "--path-to-db", custom_db_path,
+ "--report-filepath", custom_report_filepath
+ ])
+ dump_to_report_cli()
+ assert os.path.exists(custom_report_filepath)
+
+
+ def test_dump_to_report_defaults(self, unit_test_db, tmp_path):
+ report_filepath = tmp_path / "default_report.html"
+ dump_to_report(report_filepath=str(report_filepath))
+ assert os.path.exists(report_filepath)
diff --git a/tests/test_UnitTestGenerator.py b/tests/test_UnitTestGenerator.py
index 116cf0c4b..3cb5a7394 100644
--- a/tests/test_UnitTestGenerator.py
+++ b/tests/test_UnitTestGenerator.py
@@ -1,10 +1,12 @@
-import pytest
-from cover_agent.UnitTestGenerator import UnitTestGenerator
+from cover_agent.CoverageProcessor import CoverageProcessor
from cover_agent.ReportGenerator import ReportGenerator
-import os
-
+from cover_agent.Runner import Runner
+from cover_agent.UnitTestGenerator import UnitTestGenerator
from unittest.mock import patch, mock_open
-
+import datetime
+import os
+import pytest
+import tempfile
class TestUnitTestGenerator:
def test_get_included_files_mixed_paths(self):
@@ -28,3 +30,111 @@ def test_get_included_files_valid_paths(self):
result
== "file_path: `file1.txt`\ncontent:\n```\nfile content\n```\nfile_path: `file2.txt`\ncontent:\n```\nfile content\n```"
)
+ def test_get_code_language_no_extension(self):
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ generator = UnitTestGenerator(
+ source_file_path=temp_source_file.name,
+ test_file_path="test_test.py",
+ code_coverage_report_path="coverage.xml",
+ test_command="pytest",
+ llm_model="gpt-3"
+ )
+ language = generator.get_code_language("filename")
+ assert language == "unknown"
+
+ def test_extract_error_message_exception_handling(self):
+ # PromptBuilder will not instantiate, so we expect an empty error message.
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ generator = UnitTestGenerator(
+ source_file_path=temp_source_file.name,
+ test_file_path="test_test.py",
+ code_coverage_report_path="coverage.xml",
+ test_command="pytest",
+ llm_model="gpt-3"
+ )
+ with patch.object(generator, 'ai_caller') as mock_ai_caller:
+ mock_ai_caller.call_model.side_effect = Exception("Mock exception")
+ error_message = generator.extract_error_message(stderr="stderr content", stdout="stdout content")
+ assert error_message == ''
+
+ # def test_get_included_files_none(self):
+ # result = UnitTestGenerator.get_included_files(None)
+ # assert result == ""
+
+ # def test_run_coverage_command_failure(self):
+ # with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ # generator = UnitTestGenerator(
+ # source_file_path=temp_source_file.name,
+ # test_file_path="test_test.py",
+ # code_coverage_report_path="coverage.xml",
+ # test_command="invalid_command",
+ # llm_model="gpt-3"
+ # )
+ # with pytest.raises(AssertionError):
+ # generator.run_coverage()
+
+ # def test_extract_error_message_success(self):
+ # with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ # generator = UnitTestGenerator(
+ # source_file_path=temp_source_file.name,
+ # test_file_path="test_test.py",
+ # code_coverage_report_path="coverage.xml",
+ # test_command="pytest",
+ # llm_model="gpt-3"
+ # )
+ # with patch.object(generator.ai_caller, 'call_model', return_value=("error_summary: 'Mocked error summary'", 10, 10)):
+ # error_message = generator.extract_error_message(stderr="stderr content", stdout="stdout content")
+ # assert error_message == ""
+
+ def test_run_coverage_with_report_coverage_flag(self):
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ generator = UnitTestGenerator(
+ source_file_path=temp_source_file.name,
+ test_file_path="test_test.py",
+ code_coverage_report_path="coverage.xml",
+ test_command="pytest",
+ llm_model="gpt-3",
+ use_report_coverage_feature_flag=True
+ )
+ with patch.object(Runner, 'run_command', return_value=("", "", 0, datetime.datetime.now())):
+ with patch.object(CoverageProcessor, 'process_coverage_report', return_value={'test.py': ([], [], 1.0)}):
+ generator.run_coverage()
+ # Division by zero is expected here, so the method should log an error and report a coverage of 0
+ assert generator.current_coverage == 0
+
+ def test_build_prompt_with_failed_tests(self):
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ generator = UnitTestGenerator(
+ source_file_path=temp_source_file.name,
+ test_file_path="test_test.py",
+ code_coverage_report_path="coverage.xml",
+ test_command="pytest",
+ llm_model="gpt-3"
+ )
+ generator.failed_test_runs = [
+ {
+ "code": {"test_code": "def test_example(): assert False"},
+ "error_message": "AssertionError"
+ }
+ ]
+ prompt = generator.build_prompt()
+ assert "Failed Test:" in prompt['user']
+
+
+ def test_generate_tests_invalid_yaml(self):
+ with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
+ generator = UnitTestGenerator(
+ source_file_path=temp_source_file.name,
+ test_file_path="test_test.py",
+ code_coverage_report_path="coverage.xml",
+ test_command="pytest",
+ llm_model="gpt-3"
+ )
+ generator.build_prompt = lambda: "Test prompt"
+ with patch.object(generator.ai_caller, 'call_model', return_value=("This is not YAML", 10, 10)):
+ result = generator.generate_tests()
+
+ # The eventual call to try_fix_yaml() will end up returning the same string while deeming it "YAML."
+ # Although this is not valid YAML, the function returns the original string (for better or for worse).
+ assert result == "This is not YAML"
+
\ No newline at end of file
diff --git a/tests/test_load_yaml.py b/tests/test_load_yaml.py
index d7e8db7fa..e88757bdd 100644
--- a/tests/test_load_yaml.py
+++ b/tests/test_load_yaml.py
@@ -95,26 +95,26 @@ def test_load_invalid_yaml2(self):
# auto-generated by cover agent
-def test_try_fix_yaml_snippet_extraction():
- from cover_agent.utils import try_fix_yaml
+ def test_try_fix_yaml_snippet_extraction(self):
+ from cover_agent.utils import try_fix_yaml
- yaml_str = "```yaml\nname: John Smith\nage: 35\n```"
- expected_output = {"name": "John Smith", "age": 35}
- assert try_fix_yaml(yaml_str) == expected_output
+ yaml_str = "```yaml\nname: John Smith\nage: 35\n```"
+ expected_output = {"name": "John Smith", "age": 35}
+ assert try_fix_yaml(yaml_str) == expected_output
-def test_try_fix_yaml_remove_all_lines():
- from cover_agent.utils import try_fix_yaml
+ def test_try_fix_yaml_remove_all_lines(self):
+ from cover_agent.utils import try_fix_yaml
- yaml_str = "language: python\nname: John Smith\nage: 35\ninvalid_line"
- expected_output = {"language": "python", "name": "John Smith", "age": 35}
- assert try_fix_yaml(yaml_str) == expected_output
+ yaml_str = "language: python\nname: John Smith\nage: 35\ninvalid_line"
+ expected_output = {"language": "python", "name": "John Smith", "age": 35}
+ assert try_fix_yaml(yaml_str) == expected_output
-def test_try_fix_yaml_llama3_8b():
- from cover_agent.utils import try_fix_yaml
+ def test_try_fix_yaml_llama3_8b(self):
+ from cover_agent.utils import try_fix_yaml
- yaml_str = """\
+ yaml_str = """\
here is the response:
language: python
@@ -128,31 +128,36 @@ def test_try_fix_yaml_llama3_8b():
hope this helps!
"""
- expected_output = {
- "here is the response": None,
- "language": "python",
- "new_tests": [
- {
- "test_behavior": "aaa\n",
- "test_name": "test_current_date",
- "test_code": "bbb\n",
- "test_tags": "happy path",
- }
- ],
- }
- assert try_fix_yaml(yaml_str) == expected_output
+ expected_output = {
+ "here is the response": None,
+ "language": "python",
+ "new_tests": [
+ {
+ "test_behavior": "aaa\n",
+ "test_name": "test_current_date",
+ "test_code": "bbb\n",
+ "test_tags": "happy path",
+ }
+ ],
+ }
+ assert try_fix_yaml(yaml_str) == expected_output
-def test_invalid_yaml_wont_parse():
- from cover_agent.utils import try_fix_yaml
+ def test_invalid_yaml_wont_parse(self):
+ from cover_agent.utils import try_fix_yaml
- yaml_str = """
+ yaml_str = """
here is the response
language: python
tests:
- test_behavior: |
- aaa
- test_name:"""
- expected_output = None
- assert load_yaml(yaml_str) == expected_output
+aaa
+test_name:"""
+ expected_output = None
+ assert load_yaml(yaml_str) == expected_output
+
+ def test_load_yaml_second_fallback_failure(self):
+ yaml_str = "```yaml\ninvalid_yaml: [unclosed_list\n```"
+ assert load_yaml(yaml_str) is None
+
diff --git a/tests/test_main.py b/tests/test_main.py
index 35fa57add..216e99e61 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -94,3 +94,41 @@ def test_main_test_file_not_found(
main()
assert str(exc_info.value) == f"Test file not found at {args.test_file_path}"
+
+ @patch("cover_agent.main.CoverAgent")
+ @patch("cover_agent.main.parse_args")
+ @patch("cover_agent.main.os.path.isfile")
+ def test_main_calls_agent_run(
+ self, mock_isfile, mock_parse_args, mock_cover_agent
+ ):
+ args = argparse.Namespace(
+ source_file_path="test_source.py",
+ test_file_path="test_file.py",
+ test_file_output_path="",
+ code_coverage_report_path="coverage_report.xml",
+ test_command="pytest",
+ test_command_dir=os.getcwd(),
+ included_files=None,
+ coverage_type="cobertura",
+ report_filepath="test_results.html",
+ desired_coverage=90,
+ max_iterations=10,
+ additional_instructions="",
+ model="gpt-4o",
+ api_base="http://localhost:11434",
+ strict_coverage=False,
+ run_tests_multiple_times=1,
+ use_report_coverage_feature_flag=False,
+ log_db_path="",
+ )
+ mock_parse_args.return_value = args
+ # Mock os.path.isfile to return True for both source and test file paths
+ mock_isfile.side_effect = lambda path: path in [args.source_file_path, args.test_file_path]
+ mock_agent_instance = MagicMock()
+ mock_cover_agent.return_value = mock_agent_instance
+
+ main()
+
+ mock_cover_agent.assert_called_once_with(args)
+ mock_agent_instance.run.assert_called_once()
+
diff --git a/tests_integration/increase_coverage.py b/tests_integration/increase_coverage.py
index 6c9bfc83e..c3ad834d6 100755
--- a/tests_integration/increase_coverage.py
+++ b/tests_integration/increase_coverage.py
@@ -9,18 +9,19 @@
# List of source/test files to iterate over:
SOURCE_TEST_FILE_LIST = [
# ["cover_agent/AICaller.py", "tests/test_AICaller.py"],
- # ["cover_agent/CoverAgent.py", "tests/test_CoverAgent.py"],
+ # ["cover_agent/CoverAgent.py", "tests/test_CoverAgent.py"],
# ["cover_agent/CoverageProcessor.py", "tests/test_CoverageProcessor.py"],
+ # ["cover_agent/CustomLogger.py", ""],
# ["cover_agent/FilePreprocessor.py", "tests/test_FilePreprocessor.py"],
# ["cover_agent/PromptBuilder.py", "tests/test_PromptBuilder.py"],
- ["cover_agent/ReportGenerator.py", "tests/test_ReportGenerator.py"],
+ # ["cover_agent/ReportGenerator.py", "tests/test_ReportGenerator.py"],
# ["cover_agent/Runner.py", "tests/test_Runner.py"],
# ["cover_agent/UnitTestDB.py", "tests/test_UnitTestDB.py"],
# ["cover_agent/UnitTestGenerator.py", "tests/test_UnitTestGenerator.py"],
+ # ["cover_agent/main.py", "tests/test_main.py"],
+ # ["cover_agent/settings/config_loader.py", ""],
+ ["cover_agent/utils.py", "tests/test_load_yaml.py"],
# ["cover_agent/version.py", "tests/test_version.py"],
- # ["cover_agent/utils.py", "tests/test_load_yaml.py"],
- # ["cover_agent/settings/config_loader.py", "tests/test_.py"],
- # ["cover_agent/CustomLogger.py", ""],
]
@@ -36,9 +37,10 @@ def __init__(self, source_file_path, test_file_path):
self.coverage_type = "cobertura"
self.report_filepath = "test_results.html"
self.desired_coverage = 100
- self.max_iterations = 4
- self.additional_instructions = ""
- self.model = "gpt-4o"
+ self.max_iterations = 5
+ self.additional_instructions = "Do not indent the tests"
+ # self.model = "gpt-4o"
+ self.model = "o1-mini"
self.api_base = "http://localhost:11434"
self.prompt_only = False
self.strict_coverage = False