Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR Diff that adds tests #195

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions cover_agent/CoverAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, args):
llm_model=args.model,
api_base=args.api_base,
use_report_coverage_feature_flag=args.use_report_coverage_feature_flag,
diff_coverage=args.diff_coverage,
comparasion_branch=args.branch,
)

def _validate_paths(self):
Expand Down Expand Up @@ -124,9 +126,14 @@ def run(self):
and iteration_count < self.args.max_iterations
):
# Log the current coverage
self.logger.info(
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
if self.args.diff_coverage:
self.logger.info(
f"Current Diff Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
else:
self.logger.info(
f"Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
)
self.logger.info(f"Desired Coverage: {self.test_gen.desired_coverage}%")

# Generate new tests
Expand All @@ -147,7 +154,10 @@ def run(self):
iteration_count += 1

# Check if the desired coverage has been reached
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
if self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100) and self.args.diff_coverage:
# Run the coverage tool again if the desired coverage hasn't been reached
self.test_gen.run_diff_coverage()
elif self.test_gen.current_coverage < (self.test_gen.desired_coverage / 100):
# Run the coverage tool again if the desired coverage hasn't been reached
self.test_gen.run_coverage()

Expand All @@ -157,7 +167,10 @@ def run(self):
f"Reached above target coverage of {self.test_gen.desired_coverage}% (Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%) in {iteration_count} iterations."
)
elif iteration_count == self.args.max_iterations:
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
if self.args.diff_coverage:
failure_message = f"Reached maximum iteration limit without achieving desired diff coverage. Current Coverage: {round(self.test_gen.diff_coverage_percentage * 100, 2)}%"
else:
failure_message = f"Reached maximum iteration limit without achieving desired coverage. Current Coverage: {round(self.test_gen.current_coverage * 100, 2)}%"
if self.args.strict_coverage:
# User requested strict coverage (similar to "--cov-fail-under in pytest-cov"). Fail with exist code 2.
self.logger.error(failure_message)
Expand All @@ -179,4 +192,4 @@ def run(self):

# Finish the Weights & Biases run if it was initialized
if "WANDB_API_KEY" in os.environ:
wandb.finish()
wandb.finish()
43 changes: 43 additions & 0 deletions cover_agent/CoverageProcessor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from cover_agent.CustomLogger import CustomLogger
from typing import Literal, Tuple, Union
import csv
import json
import os
import re
import xml.etree.ElementTree as ET
Expand All @@ -21,6 +22,7 @@ def __init__(
file_path (str): The path to the coverage report file.
src_file_path (str): The fully qualified path of the file for which coverage data is being processed.
coverage_type (Literal["cobertura", "lcov"]): The type of coverage report being processed.
use_report_coverage_feature_flag (bool, optional): Controls whether to process coverage data for all files in the report (True) or only for the specified file (False). Defaults to False.

Attributes:
file_path (str): The path to the coverage report file.
Expand Down Expand Up @@ -265,6 +267,47 @@ def parse_missed_covered_lines_jacoco_csv(

return missed, covered

def parse_json_diff_coverage_report(self) -> Tuple[List[int], List[int], float]:
"""
Parses a JSON-formatted diff coverage report to extract covered lines, missed lines,
and the coverage percentage for the specified src_file_path.

Returns:
Tuple[List[int], List[int], float]: A tuple containing lists of covered and missed lines,
and the coverage percentage.
"""
with open(self.file_path, "r") as file:
report_data = json.load(file)

# Create relative path components of `src_file_path` for matching
src_relative_path = os.path.relpath(self.src_file_path)
src_relative_components = src_relative_path.split(os.sep)

# Initialize variables for covered and missed lines
relevant_stats = None

for file_path, stats in report_data["src_stats"].items():
# Split the JSON's file path into components
file_path_components = file_path.split(os.sep)

# Match if the JSON path ends with the same components as `src_file_path`
if file_path_components[-len(src_relative_components):] == src_relative_components:
relevant_stats = stats
break

# If a match is found, extract the data
if relevant_stats:
covered_lines = relevant_stats["covered_lines"]
violation_lines = relevant_stats["violation_lines"]
coverage_percentage = relevant_stats["percent_covered"] / 100 # Convert to decimal
else:
# Default values if the file isn't found in the report
covered_lines = []
violation_lines = []
coverage_percentage = 0.0

return covered_lines, violation_lines, coverage_percentage

def extract_package_and_class_java(self):
package_pattern = re.compile(r"^\s*package\s+([\w\.]+)\s*;.*$")
class_pattern = re.compile(r"^\s*public\s+class\s+(\w+).*")
Expand Down
7 changes: 4 additions & 3 deletions cover_agent/PromptBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from jinja2 import Environment, StrictUndefined

from cover_agent.settings.config_loader import get_settings
import subprocess
import re

MAX_TESTS_PER_RUN = 4

Expand Down Expand Up @@ -31,7 +33,6 @@
======
"""


class PromptBuilder:
def __init__(
self,
Expand Down Expand Up @@ -74,7 +75,7 @@ def __init__(
self.code_coverage_report = code_coverage_report
self.language = language
self.testing_framework = testing_framework

# add line numbers to each line in 'source_file'. start from 1
self.source_file_numbered = "\n".join(
[f"{i + 1} {line}" for i, line in enumerate(self.source_file.split("\n"))]
Expand Down Expand Up @@ -195,4 +196,4 @@ def build_prompt_custom(self, file) -> dict:
logging.error(f"Error rendering prompt: {e}")
return {"system": "", "user": ""}

return {"system": system_prompt, "user": user_prompt}
return {"system": system_prompt, "user": user_prompt}
Loading
Loading