Skip to content

Commit

Permalink
initial commit of PR Diff.
Browse files Browse the repository at this point in the history
  • Loading branch information
EmbeddedDevops1 committed Nov 6, 2024
1 parent 738bf47 commit b966bc6
Show file tree
Hide file tree
Showing 11 changed files with 682 additions and 405 deletions.
25 changes: 19 additions & 6 deletions cover_agent/CoverAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, args):
llm_model=args.model,
api_base=args.api_base,
use_report_coverage_feature_flag=args.use_report_coverage_feature_flag,
diff_coverage=args.diff_coverage,
comparasion_branch=args.branch,
)

def _validate_paths(self):
Expand Down Expand Up @@ -124,9 +126,14 @@ def run(self):
and iteration_count < self.args.max_iterations
):
# Log the current coverage
self.logger.info(
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
if self.args.diff_coverage:
self.logger.info(
f"Current Diff Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
else:
self.logger.info(
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
self.logger.info(f"Desired Coverage: {self.test_gen.desired_coverage}%")

# Generate new tests
Expand All @@ -147,7 +154,10 @@ def run(self):
iteration_count += 1

# Check if the desired coverage has been reached
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100) and self.args.diff_coverage:
# Run the coverage tool again if the desired coverage hasn't been reached
self.test_gen.run_diff_coverage()
elif self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
# Run the coverage tool again if the desired coverage hasn't been reached
self.test_gen.run_coverage()

Expand All @@ -157,7 +167,10 @@ def run(self):
f"Reached above target coverage of {self.test_gen.desired_coverage}% (Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%) in {iteration_count} iterations."
)
elif iteration_count == self.args.max_iterations:
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
if self.args.diff_coverage:
failure_message = f"Reached maximum iteration limit without achieving desired diff coverage. Current Coverage: {round(self.test_gen.diff_coverage_percentage * 100, 2)}%"
else:
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
if self.args.strict_coverage:
# User requested strict coverage (similar to "--cov-fail-under in pytest-cov"). Fail with exist code 2.
self.logger.error(failure_message)
Expand All @@ -179,4 +192,4 @@ def run(self):

# Finish the Weights & Biases run if it was initialized
if "WANDB_API_KEY" in os.environ:
wandb.finish()
wandb.finish()
17 changes: 17 additions & 0 deletions cover_agent/CoverageProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,23 @@ def parse_missed_covered_lines_jacoco_csv(

return missed, covered

def parse_diff_coverage_report(self, report_text: str) -> Tuple[int, int, int]:
# Extract total lines
total_lines_match = re.search(r'Total:\s+(\d+)\s+lines', report_text)
total_lines = int(total_lines_match.group(1)) if total_lines_match else 0

# Extract missing lines
missing_lines_match = re.search(r'Missing:\s+(\d+)\s+lines', report_text)
missing_lines = int(missing_lines_match.group(1)) if missing_lines_match else 0

coverage_match = re.search(r'Coverage:\s+(\d+)%', report_text)
coverage_percentage = float(coverage_match.group(1)) / 100 if coverage_match else 0.0

# Calculate processed lines
processed_lines = total_lines - missing_lines

return processed_lines, missing_lines, coverage_percentage

def extract_package_and_class_java(self):
package_pattern = re.compile(r"^\s*package\s+([\w\.]+)\s*;.*$")
class_pattern = re.compile(r"^\s*public\s+class\s+(\w+).*")
Expand Down
78 changes: 77 additions & 1 deletion cover_agent/PromptBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from jinja2 import Environment, StrictUndefined

from cover_agent.settings.config_loader import get_settings
import subprocess
import re

MAX_TESTS_PER_RUN = 4

Expand Down Expand Up @@ -31,6 +33,14 @@
======
"""

DIFF_COVERAGE_TEXT = """
## Diff Coverage
Focus on writing tests for only the lines changed in this branch. The following lines have been changed in the source file since the last test run:
======
{changed_lines}
======
"""


class PromptBuilder:
def __init__(
Expand All @@ -43,6 +53,8 @@ def __init__(
failed_test_runs: str = "",
language: str = "python",
testing_framework: str = "NOT KNOWN",
diff_coverage: bool = False,
diff_branch: str = "main",
):
"""
The `PromptBuilder` class is responsible for building a formatted prompt string by replacing placeholders with the actual content of files read during initialization. It takes in various paths and settings as parameters and provides a method to generate the prompt.
Expand Down Expand Up @@ -74,7 +86,12 @@ def __init__(
self.code_coverage_report = code_coverage_report
self.language = language
self.testing_framework = testing_framework
self.diff_branch = diff_branch

if diff_coverage:
diff_output = self._get_diff(source_file_path)
changed_lines = self._extract_changed_lines(diff_output)

# add line numbers to each line in 'source_file'. start from 1
self.source_file_numbered = "\n".join(
[f"{i + 1} {line}" for i, line in enumerate(self.source_file.split("\n"))]
Expand Down Expand Up @@ -102,6 +119,12 @@ def __init__(
else ""
)

self.diff_coverage_instructions = (
DIFF_COVERAGE_TEXT.format(changed_lines=changed_lines)
if diff_coverage
else ""
)

self.stdout_from_run = ""
self.stderr_from_run = ""

Expand All @@ -120,6 +143,57 @@ def _read_file(self, file_path):
return f.read()
except Exception as e:
return f"Error reading {file_path}: {e}"

def _get_diff(self, source_file_path):
try:
# Get unstaged changes
command_unstaged = ["git", "diff", self.diff_branch, "--", source_file_path]
result_unstaged = subprocess.run(
command_unstaged,
capture_output=True,
text=True,
check=True
)

# Get staged changes
command_staged = ["git", "diff", "--staged", self.diff_branch, "--", source_file_path]
result_staged = subprocess.run(
command_staged,
capture_output=True,
text=True,
check=True
)

# Combine both diffs
combined_diff = result_unstaged.stdout + result_staged.stdout
return combined_diff
except subprocess.CalledProcessError as e:
logging.error(f"Error getting diff with main: {e}")
return ""


def _extract_changed_lines(self, diff_output):
"""
Extract the line numbers of the changed lines from the diff output.
Parameters:
diff_output (str): The diff output between the source file and the main branch.
Returns:
list: A list of tuples representing the changed line numbers.
"""
changed_lines = []
diff_lines = diff_output.split("\n")
for line in diff_lines:
if line.startswith("@@"):
# Extract line numbers from the diff hunk header
match = re.search(r'@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@', line)
if match:
start_line = int(match.group(1))
line_count = int(match.group(2)) if match.group(2) else 1
for i in range(start_line, start_line + line_count):
changed_lines.append(i)
return changed_lines

def build_prompt(self) -> dict:
variables = {
Expand All @@ -133,6 +207,7 @@ def build_prompt(self) -> dict:
"additional_includes_section": self.included_files,
"failed_tests_section": self.failed_test_runs,
"additional_instructions_text": self.additional_instructions,
"diff_coverage_text": self.diff_coverage_instructions,
"language": self.language,
"max_tests": MAX_TESTS_PER_RUN,
"testing_framework": self.testing_framework,
Expand Down Expand Up @@ -175,6 +250,7 @@ def build_prompt_custom(self, file) -> dict:
"additional_includes_section": self.included_files,
"failed_tests_section": self.failed_test_runs,
"additional_instructions_text": self.additional_instructions,
"diff_coverage_text": self.diff_coverage_instructions,
"language": self.language,
"max_tests": MAX_TESTS_PER_RUN,
"testing_framework": self.testing_framework,
Expand All @@ -195,4 +271,4 @@ def build_prompt_custom(self, file) -> dict:
logging.error(f"Error rendering prompt: {e}")
return {"system": "", "user": ""}

return {"system": system_prompt, "user": user_prompt}
return {"system": system_prompt, "user": user_prompt}
70 changes: 69 additions & 1 deletion cover_agent/UnitTestGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
desired_coverage: int = 90, # Default to 90% coverage if not specified
additional_instructions: str = "",
use_report_coverage_feature_flag: bool = False,
diff_coverage: bool = False,
comparasion_branch: str = "main",
):
"""
Initialize the UnitTestGenerator class with the provided parameters.
Expand Down Expand Up @@ -70,6 +72,8 @@ def __init__(
self.use_report_coverage_feature_flag = use_report_coverage_feature_flag
self.last_coverage_percentages = {}
self.llm_model = llm_model
self.diff_coverage = diff_coverage
self.comparasion_branch = comparasion_branch

# Objects to instantiate
self.ai_caller = AICaller(model=llm_model, api_base=api_base)
Expand All @@ -96,8 +100,9 @@ def get_coverage_and_build_prompt(self):
Returns:
None
"""
# Run coverage and build the prompt
self.run_coverage()
if self.diff_coverage:
self.run_diff_coverage()
self.prompt = self.build_prompt()

def get_code_language(self, source_file_path):
Expand Down Expand Up @@ -224,6 +229,48 @@ def run_coverage(self):
with open(self.code_coverage_report_path, "r") as f:
self.code_coverage_report = f.read()

def run_diff_coverage(self):
"""
Preform a diff coverage command to generate diff coverage report.
Process the diff coverage report to extract the diff coverage percentage.
Parameters:
- None
Returns:
- None
"""
# Perform a diff coverage command to generate a diff coverage report
coverage_filename = os.path.basename(self.code_coverage_report_path)
coverage_command = f"diff-cover --compare-branch={self.comparasion_branch} {coverage_filename}"

self.logger.info(
f'Running diff coverage command to generate diff coverage report: "{coverage_command}"'
)
stdout, stderr, exit_code, time_of_test_command = Runner.run_command(
command=coverage_command, cwd=self.test_command_dir
)
assert (
exit_code == 0
), f'Fatal: Error running test command. Are you sure the command is correct? "{coverage_command}"\nExit code {exit_code}. \nStdout: \n{stdout} \nStderr: \n{stderr}'

coverage_processor = CoverageProcessor(
file_path=self.code_coverage_report_path,
src_file_path=self.source_file_path,
coverage_type=self.coverage_type,
use_report_coverage_feature_flag=self.use_report_coverage_feature_flag
)

lines_processed, lines_missed, diff_coverage_percentage = coverage_processor.parse_diff_coverage_report(
report_text=stdout
)

self.logger.info(
f"Lines processed: {lines_processed}, Lines missed: {lines_missed}, Diff coverage: {diff_coverage_percentage}"
)

self.current_coverage = diff_coverage_percentage

@staticmethod
def get_included_files(included_files):
"""
Expand Down Expand Up @@ -303,6 +350,8 @@ def build_prompt(self) -> dict:
failed_test_runs=failed_test_runs_value,
language=self.language,
testing_framework=self.testing_framework,
diff_coverage=self.diff_coverage,
diff_branch=self.comparasion_branch,
)

return self.prompt_builder.build_prompt()
Expand Down Expand Up @@ -548,6 +597,20 @@ def validate_test(self, generated_test: dict, num_attempts=1):
)
if exit_code != 0:
break

if self.diff_coverage:
report_path = self.code_coverage_report_path.replace(self.test_command_dir, "")
report_path = report_path.lstrip("/")
test_command = f"diff-cover --compare-branch={self.comparasion_branch} {report_path}"
self.logger.info(
f'Running diff coverage command to generate diff coverage report: "{test_command}"'
)
stdout, stderr, exit_code, time_of_test_command = Runner.run_command(
command=test_command, cwd=self.test_command_dir
)
if exit_code != 0:
break



# Step 3: Check for pass/fail from the Runner object
Expand Down Expand Up @@ -624,6 +687,10 @@ def validate_test(self, generated_test: dict, num_attempts=1):
coverage_percentages[key] = percentage_covered

new_percentage_covered = total_lines_covered / total_lines
elif self.diff_coverage:
_, _, new_percentage_covered = new_coverage_processor.parse_diff_coverage_report(
report_text=stdout
)
else:
_, _, new_percentage_covered = (
new_coverage_processor.process_coverage_report(
Expand Down Expand Up @@ -751,6 +818,7 @@ def validate_test(self, generated_test: dict, num_attempts=1):
"original_test_file": original_content,
"processed_test_file": "N/A",
}


def to_dict(self):
return {
Expand Down
13 changes: 13 additions & 0 deletions cover_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ def parse_args():
default="",
help="Path to optional log database. Default: %(default)s.",
)
parser.add_argument(
"--diff-coverage",
action="store_true",
default=False,
help="If set, Cover-Agent will only generate tests based on the diff between branches. Default: False.",
)
parser.add_argument(
"--branch",
type=str,
default="main",
help="The branch to compare against when using --diff-coverage. Default: %(default)s.",
)

return parser.parse_args()


Expand Down
5 changes: 5 additions & 0 deletions cover_agent/settings/test_generation_prompt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ The test framework used for running tests is `{{ testing_framework }}`.
{{ additional_instructions_text|trim }}
{% endif %}
{%- if diff_coverage_text|trim %}
{{ diff_coverage_text|trim }}
{% endif %}
## Code Coverage
Based on the code coverage report below, your goal is to suggest new test cases that would increase the current coverage. Focus only on untested or partially tested areas:
Expand Down
Loading

0 comments on commit b966bc6

Please sign in to comment.