Skip to content

Commit

Permalink
Split UnitTestGenerator into generator and validator. (#206)
Browse files Browse the repository at this point in the history
Issue #170 aims to tackle the refactor of UnitTestGenerator and
this is an attempt to split UnitTestGenerator into generator and
validator. This PR is a first of a series of refactoring we can
apply to UnitTestGenerator.

* Created a new class `UnitTestValidator` by copying running,
  validating and processing coverage from `UnitTestGenerator`

* Doesn't include any cleanup or optimization and kept the PR
  to be minimal structural changes.

* Use prompt from UnitTestGenerator when storing a failed test
  into the database.
  • Loading branch information
coderustic authored Nov 10, 2024
1 parent 11aad4f commit 5860b4d
Show file tree
Hide file tree
Showing 6 changed files with 972 additions and 741 deletions.
46 changes: 32 additions & 14 deletions cover_agent/CoverAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cover_agent.CustomLogger import CustomLogger
from cover_agent.ReportGenerator import ReportGenerator
from cover_agent.UnitTestGenerator import UnitTestGenerator
from cover_agent.UnitTestValidator import UnitTestValidator
from cover_agent.UnitTestDB import UnitTestDB

class CoverAgent:
Expand All @@ -27,6 +28,21 @@ def __init__(self, args):
self._duplicate_test_file()

self.test_gen = UnitTestGenerator(
source_file_path=args.source_file_path,
test_file_path=args.test_file_output_path,
project_root=args.project_root,
code_coverage_report_path=args.code_coverage_report_path,
test_command=args.test_command,
test_command_dir=args.test_command_dir,
included_files=args.included_files,
coverage_type=args.coverage_type,
additional_instructions=args.additional_instructions,
llm_model=args.model,
api_base=args.api_base,
use_report_coverage_feature_flag=args.use_report_coverage_feature_flag,
)

self.test_validator = UnitTestValidator(
source_file_path=args.source_file_path,
test_file_path=args.test_file_output_path,
project_root=args.project_root,
Expand Down Expand Up @@ -123,29 +139,31 @@ def run(self):
test_results_list = []

# Run initial test suite analysis
self.test_gen.get_coverage_and_build_prompt()
self.test_gen.initial_test_suite_analysis()
self.test_validator.initial_test_suite_analysis()
failed_test_runs = self.test_validator.get_coverage()
self.test_gen.build_prompt(failed_test_runs)

# Loop until desired coverage is reached or maximum iterations are met
while (
self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100)
self.test_validator.current_coverage < (self.test_validator.desired_coverage / 100)
and iteration_count < self.args.max_iterations
):
# Log the current coverage
self.logger.info(
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
f"Current Coverage: {round(self.test_validator.current_coverage * 100, 2)}%"
)
self.logger.info(f"Desired Coverage: {self.test_gen.desired_coverage}%")
self.logger.info(f"Desired Coverage: {self.test_validator.desired_coverage}%")

# Generate new tests
generated_tests_dict = self.test_gen.generate_tests()
generated_tests_dict = self.test_gen.generate_tests(failed_test_runs)

# Loop through each new test and validate it
for generated_test in generated_tests_dict.get("new_tests", []):
# Validate the test and record the result
test_result = self.test_gen.validate_test(
test_result = self.test_validator.validate_test(
generated_test, self.args.run_tests_multiple_times
)
test_result["prompt"] = self.test_gen.prompt["user"] # get the prompt used to generate the test so that it is stored in the database
test_results_list.append(test_result)

# Insert the test result into the database
Expand All @@ -155,17 +173,17 @@ def run(self):
iteration_count += 1

# Check if the desired coverage has been reached
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
if self.test_validator.current_coverage < (self.test_validator.desired_coverage / 100):
# Run the coverage tool again if the desired coverage hasn't been reached
self.test_gen.run_coverage()
self.test_validator.run_coverage()

# Log the final coverage
if self.test_gen.current_coverage >= (self.test_gen.desired_coverage / 100):
if self.test_validator.current_coverage >= (self.test_validator.desired_coverage / 100):
self.logger.info(
f"Reached above target coverage of {self.test_gen.desired_coverage}% (Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%) in {iteration_count} iterations."
f"Reached above target coverage of {self.test_validator.desired_coverage}% (Current Coverage: {round(self.test_validator.current_coverage * 100, 2)}%) in {iteration_count} iterations."
)
elif iteration_count == self.args.max_iterations:
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_validator.current_coverage * 100, 2)}%"
if self.args.strict_coverage:
# User requested strict coverage (similar to "--cov-fail-under in pytest-cov"). Fail with exist code 2.
self.logger.error(failure_message)
Expand All @@ -175,10 +193,10 @@ def run(self):

# Provide metrics on total token usage
self.logger.info(
f"Total number of input tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_input_token_count}"
f"Total number of input tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_input_token_count + self.test_validator.total_input_token_count}"
)
self.logger.info(
f"Total number of output tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_output_token_count}"
f"Total number of output tokens used for LLM model {self.test_gen.ai_caller.model}: {self.test_gen.total_output_token_count + self.test_validator.total_output_token_count}"
)

# Generate a report
Expand Down
Loading

0 comments on commit 5860b4d

Please sign in to comment.