From d40cbed8e8438debf963d0f368dfc0bb42df54a4 Mon Sep 17 00:00:00 2001 From: David Wurtz Date: Fri, 22 Nov 2024 09:56:12 -0800 Subject: [PATCH] Dw/split run (#233) * split run method into init and run_test_gen methods * update version --- cover_agent/CoverAgent.py | 52 ++++++++++++++++++++------------ cover_agent/UnitTestValidator.py | 4 +-- cover_agent/version.txt | 2 +- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/cover_agent/CoverAgent.py b/cover_agent/CoverAgent.py index 09c051968..e3202c987 100644 --- a/cover_agent/CoverAgent.py +++ b/cover_agent/CoverAgent.py @@ -4,6 +4,8 @@ import sys import wandb +from typing import List + from cover_agent.CustomLogger import CustomLogger from cover_agent.PromptBuilder import adapt_test_command_for_a_single_test_via_ai from cover_agent.ReportGenerator import ReportGenerator @@ -130,26 +132,14 @@ def _duplicate_test_file(self): # Otherwise, set the test file output path to the current test file self.args.test_file_output_path = self.args.test_file_path - def run(self): + def init(self): """ - Run the test generation process. - - This method performs the following steps: + Prepare for test generation process 1. Initialize the Weights & Biases run if the WANDS_API_KEY environment variable is set. 2. Initialize variables to track progress. 3. Run the initial test suite analysis. - 4. Loop until desired coverage is reached or maximum iterations are met. - 5. Generate new tests. - 6. Loop through each new test and validate it. - 7. Insert the test result into the database. - 8. Increment the iteration count. - 9. Check if the desired coverage has been reached. - 10. If the desired coverage has been reached, log the final coverage. - 11. If the maximum iteration limit is reached, log a failure message if strict coverage is specified. - 12. Provide metrics on total token usage. - 13. Generate a report. - 14. Finish the Weights & Biases run if it was initialized. + """ # Check if user has exported the WANDS_API_KEY environment variable if "WANDB_API_KEY" in os.environ: @@ -159,15 +149,35 @@ def run(self): run_name = f"{self.args.model}_" + time_and_date wandb.init(project="cover-agent", name=run_name) - # Initialize variables to track progress - iteration_count = 0 - test_results_list = [] - # Run initial test suite analysis self.test_validator.initial_test_suite_analysis() failed_test_runs, language, test_framework, coverage_report = self.test_validator.get_coverage() self.test_gen.build_prompt(failed_test_runs, language, test_framework, coverage_report) + return failed_test_runs, language, test_framework, coverage_report + + def run_test_gen(self, failed_test_runs: List, language: str, test_framework: str, coverage_report: str): + """ + Run the test generation process. + + This method performs the following steps: + + 1. Loop until desired coverage is reached or maximum iterations are met. + 2. Generate new tests. + 3. Loop through each new test and validate it. + 4. Insert the test result into the database. + 5. Increment the iteration count. + 6. Check if the desired coverage has been reached. + 7. If the desired coverage has been reached, log the final coverage. + 8. If the maximum iteration limit is reached, log a failure message if strict coverage is specified. + 9. Provide metrics on total token usage. + 10. Generate a report. + 11. Finish the Weights & Biases run if it was initialized. + """ + # Initialize variables to track progress + iteration_count = 0 + test_results_list = [] + # Loop until desired coverage is reached or maximum iterations are met while ( self.test_validator.current_coverage < (self.test_validator.desired_coverage / 100) @@ -240,3 +250,7 @@ def run(self): # Finish the Weights & Biases run if it was initialized if "WANDB_API_KEY" in os.environ: wandb.finish() + + def run(self): + failed_test_runs, language, test_framework, coverage_report = self.init() + self.run_test_gen(failed_test_runs, language, test_framework, coverage_report) \ No newline at end of file diff --git a/cover_agent/UnitTestValidator.py b/cover_agent/UnitTestValidator.py index 561a3ae7f..1e8ae2676 100644 --- a/cover_agent/UnitTestValidator.py +++ b/cover_agent/UnitTestValidator.py @@ -124,7 +124,7 @@ def get_coverage(self): self.run_coverage() return self.failed_test_runs, self.language, self.testing_framework, self.code_coverage_report - def get_code_language(self, source_file_path): + def get_code_language(self, source_file_path: str) -> str: """ Get the programming language based on the file extension of the provided source file path. @@ -251,7 +251,7 @@ def initial_test_suite_analysis(self): relevant_line_number_to_insert_imports_after = tests_dict.get( "relevant_line_number_to_insert_imports_after", None ) - self.testing_framework = tests_dict.get("testing_framework", "Unknown") + self.testing_framework: str = tests_dict.get("testing_framework", "Unknown") counter_attempts += 1 if not relevant_line_number_to_insert_tests_after: diff --git a/cover_agent/version.txt b/cover_agent/version.txt index 967b33ffb..08456a479 100644 --- a/cover_agent/version.txt +++ b/cover_agent/version.txt @@ -1 +1 @@ -0.2.7 \ No newline at end of file +0.2.8 \ No newline at end of file