Adding support for o1-preview (#166)
* Updated non-streaming support and added test.

* Added tests using o1-preview. Wow!

* Increased code coverage with o1-preview and updated final code.

* Incrementing version.

* Increased coverage and updated o1 calls.

* Merged system prompt instead of popping.
EmbeddedDevops1 authored Oct 8, 2024
1 parent 73c1723 commit 32b2058
Showing 13 changed files with 382 additions and 88 deletions.
87 changes: 54 additions & 33 deletions cover_agent/AICaller.py
@@ -13,7 +13,7 @@ def __init__(self, model: str, api_base: str = ""):
Parameters:
model (str): The name of the model to be used.
api_base (str): The base API url to use in case model is set to Ollama or Hugging Face
api_base (str): The base API URL to use in case the model is set to Ollama or Hugging Face.
"""
self.model = model
self.api_base = api_base
@@ -25,6 +25,7 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
Parameters:
prompt (dict): The prompt to be sent to the language model.
max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 4096.
stream (bool, optional): Whether to stream the response or not. Defaults to True.
Returns:
tuple: A tuple containing the response generated by the language model, the number of tokens used from the prompt, and the total number of tokens in the response.
@@ -36,21 +37,34 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):
if prompt["system"] == "":
messages = [{"role": "user", "content": prompt["user"]}]
else:
messages = [
{"role": "system", "content": prompt["system"]},
{"role": "user", "content": prompt["user"]},
]
if self.model in ["o1-preview", "o1-mini"]:
# o1 doesn't accept a system message so we add it to the prompt
messages = [
{"role": "user", "content": prompt["system"] + "\n" + prompt["user"]},
]
else:
messages = [
{"role": "system", "content": prompt["system"]},
{"role": "user", "content": prompt["user"]},
]

# Default Completion parameters
# Default completion parameters
completion_params = {
"model": self.model,
"messages": messages,
"max_tokens": max_tokens,
"stream": True,
"stream": stream, # Use the stream parameter passed to the method
"temperature": 0.2,
"max_tokens": max_tokens,
}

# API base exception for OpenAI Compatible, Ollama and Hugging Face models
# Model-specific adjustments
if self.model in ["o1-preview", "o1-mini"]:
completion_params["temperature"] = 1
completion_params["stream"] = False # o1 doesn't support streaming
completion_params["max_completion_tokens"] = max_tokens
completion_params.pop("max_tokens", None) # Remove 'max_tokens' if present

# API base exception for OpenAI Compatible, Ollama, and Hugging Face models
if (
"ollama" in self.model
or "huggingface" in self.model
@@ -60,39 +74,46 @@ def call_model(self, prompt: dict, max_tokens=4096, stream=True):

response = litellm.completion(**completion_params)

chunks = []
print("Streaming results from LLM model...") if stream else None
try:
for chunk in response:
print(chunk.choices[0].delta.content or "", end="", flush=True) if stream else None
chunks.append(chunk)
time.sleep(
0.01
) # Optional: Delay to simulate more 'natural' response pacing
except Exception as e:
print(f"Error during streaming: {e}") if stream else None
print("\n") if stream else None

model_response = litellm.stream_chunk_builder(chunks, messages=messages)
if stream:
chunks = []
print("Streaming results from LLM model...")
try:
for chunk in response:
print(chunk.choices[0].delta.content or "", end="", flush=True)
chunks.append(chunk)
time.sleep(
0.01
) # Optional: Delay to simulate more 'natural' response pacing
except Exception as e:
print(f"Error during streaming: {e}")
print("\n")
# Build the final response from the streamed chunks
model_response = litellm.stream_chunk_builder(chunks, messages=messages)
content = model_response["choices"][0]["message"]["content"]
usage = model_response["usage"]
prompt_tokens = int(usage["prompt_tokens"])
completion_tokens = int(usage["completion_tokens"])
else:
# Non-streaming response is a CompletionResponse object
content = response.choices[0].message.content
print(f"Printing results from LLM model...\n{content}")
usage = response.usage
prompt_tokens = int(usage.prompt_tokens)
completion_tokens = int(usage.completion_tokens)

if "WANDB_API_KEY" in os.environ:
root_span = Trace(
name="inference_"
+ datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
kind="llm", # kind can be "llm", "chain", "agent" or "tool
kind="llm", # kind can be "llm", "chain", "agent", or "tool"
inputs={
"user_prompt": prompt["user"],
"system_prompt": prompt["system"],
},
outputs={
"model_response": model_response["choices"][0]["message"]["content"]
},
outputs={"model_response": content},
)
root_span.log(name="inference")

# Returns: Response, Prompt token count, and Response token count
return (
model_response["choices"][0]["message"]["content"],
int(model_response["usage"]["prompt_tokens"]),
int(model_response["usage"]["completion_tokens"]),
)
# Returns: Response, Prompt token count, and Completion token count
return content, prompt_tokens, completion_tokens

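Taken together, the AICaller changes boil down to one rule set for how a request is shaped when the model is o1-preview or o1-mini: the system prompt is merged into the user message, streaming is disabled, temperature is pinned to 1, and max_completion_tokens replaces max_tokens. The sketch below restates that rule set as a standalone helper; build_completion_params and the __main__ demo are illustrative only (they do not exist in the repository), and the empty-system-prompt branch is omitted for brevity.

O1_MODELS = {"o1-preview", "o1-mini"}

def build_completion_params(model: str, prompt: dict, max_tokens: int = 4096, stream: bool = True) -> dict:
    """Shape litellm.completion kwargs the way the diff above does (simplified)."""
    if model in O1_MODELS:
        # o1 rejects a separate system role, so the system prompt is merged
        # into the user message instead of being dropped.
        messages = [{"role": "user", "content": prompt["system"] + "\n" + prompt["user"]}]
        return {
            "model": model,
            "messages": messages,
            "stream": False,                      # o1 does not support streaming
            "temperature": 1,                     # the diff pins temperature to 1 for o1
            "max_completion_tokens": max_tokens,  # replaces 'max_tokens' for o1
        }
    messages = [
        {"role": "system", "content": prompt["system"]},
        {"role": "user", "content": prompt["user"]},
    ]
    return {
        "model": model,
        "messages": messages,
        "stream": stream,
        "temperature": 0.2,
        "max_tokens": max_tokens,
    }

if __name__ == "__main__":
    demo_prompt = {"system": "You are a test generator.", "user": "Write a unit test."}
    print(build_completion_params("o1-preview", demo_prompt))
    print(build_completion_params("gpt-4o", demo_prompt))
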
10 changes: 8 additions & 2 deletions cover_agent/ReportGenerator.py
@@ -115,9 +115,15 @@ def generate_full_diff(cls, original, processed):
@classmethod
def generate_partial_diff(cls, original, processed, context_lines=3):
"""
Generates a partial diff of both the original and processed test files,
showing only added, removed, or changed lines with a few lines of context.
Note:
- The `difflib.unified_diff` function is used, which includes header lines (`---` and `+++`)
that indicate the original and modified file names or timestamps.
- It also includes context lines starting with `@@`, which show the range of lines affected.
- These lines are essential parts of the diff output and should be included in the expected outputs of tests.
:param original: String content of the original test file.
:param processed: String content of the processed test file.
:param context_lines: Number of context lines to include around changes.
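To make the docstring note above concrete, here is a small, self-contained difflib example (not part of this commit) showing the '---'/'+++' header lines and '@@' hunk marker that unified_diff emits and that the expected outputs of the partial-diff tests therefore include.

import difflib

original = "line1\nline2\nline3"
processed = "line1\nline2 modified\nline3\nline4"

# n mirrors the context_lines parameter of generate_partial_diff.
for line in difflib.unified_diff(
    original.splitlines(), processed.splitlines(), lineterm="", n=3
):
    print(line)

# Output (header lines first, then the hunk):
# ---
# +++
# @@ -1,3 +1,4 @@
#  line1
# -line2
# +line2 modified
#  line3
# +line4
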
18 changes: 15 additions & 3 deletions cover_agent/UnitTestGenerator.py
@@ -80,6 +80,7 @@ def __init__(
self.total_input_token_count = 0
self.total_output_token_count = 0
self.testing_framework = "Unknown"
self.code_coverage_report = ""

# Read self.source_file_path into a string
with open(self.source_file_path, "r") as f:
@@ -184,7 +185,11 @@ def run_coverage(self):
if key not in self.last_coverage_percentages:
self.last_coverage_percentages[key] = 0
self.last_coverage_percentages[key] = percentage_covered
percentage_covered = total_lines_covered / total_lines
try:
percentage_covered = total_lines_covered / total_lines
except ZeroDivisionError:
self.logger.error(f"ZeroDivisionError: Attempting to perform total_lines_covered / total_lines: {total_lines_covered} / {total_lines}.")
percentage_covered = 0

self.logger.info(
f"Total lines covered: {total_lines_covered}, Total lines missed: {total_lines_missed}, Total lines: {total_lines}"
@@ -325,9 +330,11 @@ def initial_test_suite_analysis(self):
prompt_headers_indentation = self.prompt_builder.build_prompt_custom(
file="analyze_suite_test_headers_indentation"
)
self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model # Exception for OpenAI's new reasoning engines
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_headers_indentation)
)
self.ai_caller.model = self.llm_model
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)
@@ -350,9 +357,11 @@ def initial_test_suite_analysis(self):
prompt_test_insert_line = self.prompt_builder.build_prompt_custom(
file="analyze_suite_test_insert_line"
)
self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model # Exception for OpenAI's new reasoning engines
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_test_insert_line)
)
self.ai_caller.model = self.llm_model
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)
@@ -401,8 +410,9 @@ def generate_tests(self, max_tokens=4096):
"""
self.prompt = self.build_prompt()

stream = False if self.llm_model in ["o1-preview", "o1-mini"] else True
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens)
self.ai_caller.call_model(prompt=self.prompt, max_tokens=max_tokens, stream=stream)
)
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
@@ -790,15 +800,17 @@ def extract_error_message(self, stderr, stdout):
)

# Run the analysis via LLM
self.ai_caller.model = "gpt-4o" if self.llm_model in ["o1-preview", "o1-mini"] else self.llm_model # Exception for OpenAI's new reasoning engines
response, prompt_token_count, response_token_count = (
self.ai_caller.call_model(prompt=prompt_headers_indentation, stream=False)
)
self.ai_caller.model = self.llm_model # Reset
self.total_input_token_count += prompt_token_count
self.total_output_token_count += response_token_count
tests_dict = load_yaml(response)

return tests_dict.get("error_summary", f"ERROR: Unable to summarize error message from inputs. STDERR: {stderr}\nSTDOUT: {stdout}.")
except Exception as e:
logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}\n\n{response}")
logging.error(f"ERROR: Unable to extract error message from inputs using LLM.\nSTDERR: {stderr}\nSTDOUT: {stdout}")
logging.error(f"Error extracting error message: {e}")
return ""
2 changes: 1 addition & 1 deletion cover_agent/version.txt
@@ -1 +1 @@
0.2.0
0.2.1
34 changes: 33 additions & 1 deletion tests/test_AICaller.py
@@ -1,7 +1,7 @@
import os

import pytest
from unittest.mock import patch
from unittest.mock import patch, Mock
from cover_agent.AICaller import AICaller


@@ -112,3 +112,35 @@ def test_call_model_missing_keys(self, ai_caller):
str(exc_info.value)
== "\"The prompt dictionary must contain 'system' and 'user' keys.\""
)

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_o1_preview(self, mock_completion, ai_caller):
ai_caller.model = "o1-preview"
prompt = {"system": "System message", "user": "Hello, world!"}
# Mock the response
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="response"))]
mock_response.usage = Mock(prompt_tokens=2, completion_tokens=10)
mock_completion.return_value = mock_response
# Call the method
response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=False)
assert response == "response"
assert prompt_tokens == 2
assert response_tokens == 10

@patch("cover_agent.AICaller.litellm.completion")
def test_call_model_streaming_response(self, mock_completion, ai_caller):
prompt = {"system": "", "user": "Hello, world!"}
# Mock the response to be an iterable of chunks
mock_chunk = Mock()
mock_chunk.choices = [Mock(delta=Mock(content="response part"))]
mock_completion.return_value = [mock_chunk]
with patch("cover_agent.AICaller.litellm.stream_chunk_builder") as mock_builder:
mock_builder.return_value = {
"choices": [{"message": {"content": "response"}}],
"usage": {"prompt_tokens": 2, "completion_tokens": 10},
}
response, prompt_tokens, response_tokens = ai_caller.call_model(prompt, stream=True)
assert response == "response"
assert prompt_tokens == 2
assert response_tokens == 10
40 changes: 39 additions & 1 deletion tests/test_CoverAgent.py
@@ -167,4 +167,42 @@ def test_duplicate_test_file_without_output_path(self, mock_isfile):

# Clean up the temp files
os.remove(temp_source_file.name)
os.remove(temp_test_file.name)

@patch("cover_agent.CoverAgent.os.environ", {})
@patch("cover_agent.CoverAgent.sys.exit")
@patch("cover_agent.CoverAgent.UnitTestGenerator")
@patch("cover_agent.CoverAgent.UnitTestDB")
def test_run_max_iterations_strict_coverage(self, mock_test_db, mock_unit_test_generator, mock_sys_exit):
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_source_file:
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_test_file:
args = argparse.Namespace(
source_file_path=temp_source_file.name,
test_file_path=temp_test_file.name,
test_file_output_path="output_test_file.py",
code_coverage_report_path="coverage_report.xml",
test_command="pytest",
test_command_dir=os.getcwd(),
included_files=None,
coverage_type="cobertura",
report_filepath="test_results.html",
desired_coverage=90,
max_iterations=1,
additional_instructions="",
model="openai/test-model",
api_base="openai/test-api",
use_report_coverage_feature_flag=False,
log_db_path="",
run_tests_multiple_times=False,
strict_coverage=True
)
# Mock the methods used in run
instance = mock_unit_test_generator.return_value
instance.current_coverage = 0.5 # below desired coverage
instance.desired_coverage = 90
instance.generate_tests.return_value = {"new_tests": [{}]}
agent = CoverAgent(args)
agent.run()
# Assertions to ensure sys.exit was called
mock_sys_exit.assert_called_once_with(2)
mock_test_db.return_value.dump_to_report.assert_called_once_with(args.report_filepath)
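The new strict-coverage test pins down one behavior: when the iteration budget runs out with coverage still below the target and strict_coverage is set, the run exits with code 2. The sketch below is inferred from the test's assertions (coverage is a fraction, the target a percentage); it is not the actual CoverAgent.run() implementation.

import sys

def enforce_strict_coverage(current_coverage: float, desired_coverage: int, strict_coverage: bool) -> None:
    """Illustrative check matching the test's values: 0.5 (50%) vs. a desired 90 (%)."""
    if current_coverage * 100 < desired_coverage:
        print(f"Coverage {current_coverage * 100:.1f}% is below the desired {desired_coverage}%.")
        if strict_coverage:
            sys.exit(2)  # the test asserts mock_sys_exit.assert_called_once_with(2)

# enforce_strict_coverage(current_coverage=0.5, desired_coverage=90, strict_coverage=True)
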
2 changes: 1 addition & 1 deletion tests/test_PromptBuilder.py
@@ -204,4 +204,4 @@ def test_custom_analyze_test_run_failure(self):
# Clean up
os.remove(source_file.name)
os.remove(test_file.name)
os.remove(tmp_file.name)
10 changes: 10 additions & 0 deletions tests/test_ReportGenerator.py
@@ -45,4 +45,14 @@ def test_generate_report(self, sample_results, expected_output, tmp_path):
assert expected_output[2] in content # Check if the row includes "test_current_date"
assert expected_output[3] in content # Check if the HTML closes properly

def test_generate_partial_diff_basic(self):
original = "line1\nline2\nline3"
processed = "line1\nline2 modified\nline3\nline4"
diff_output = ReportGenerator.generate_partial_diff(original, processed)
assert '<span class="diff-added">+line2 modified</span>' in diff_output
assert '<span class="diff-added">+line4</span>' in diff_output
assert '<span class="diff-removed">-line2</span>' in diff_output
assert '<span class="diff-unchanged"> line1</span>' in diff_output


# Additional validation can be added based on specific content if required
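The new partial-diff test asserts one mapping from unified-diff prefixes to CSS classes: '+' lines become diff-added, '-' lines diff-removed, and context lines diff-unchanged. The helper below is a hedged sketch of that mapping for reference only; the real rendering lives in ReportGenerator.generate_partial_diff.

import html

def render_diff_line(line: str) -> str:
    """Wrap one unified-diff line in the span class the test expects."""
    if line.startswith("+"):
        css_class = "diff-added"
    elif line.startswith("-"):
        css_class = "diff-removed"
    else:
        css_class = "diff-unchanged"
    return f'<span class="{css_class}">{html.escape(line)}</span>'

# render_diff_line("+line4")  -> '<span class="diff-added">+line4</span>'
# render_diff_line(" line1")  -> '<span class="diff-unchanged"> line1</span>'
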
20 changes: 20 additions & 0 deletions tests/test_UnitTestDB.py
@@ -1,6 +1,8 @@
import pytest
import os
from datetime import datetime, timedelta
from cover_agent.UnitTestDB import dump_to_report_cli
from cover_agent.UnitTestDB import dump_to_report
from cover_agent.UnitTestDB import UnitTestDB, UnitTestGenerationAttempt

DB_NAME = "unit_test_runs.db"
@@ -91,3 +93,21 @@ def test_dump_to_report(self, unit_test_db, tmp_path):
assert "sample test code" in content
assert "sample new test code" in content
assert "def test_example(): pass" in content


def test_dump_to_report_cli_custom_args(self, unit_test_db, tmp_path, monkeypatch):
custom_db_path = str(tmp_path / "cli_custom_unit_test_runs.db")
custom_report_filepath = str(tmp_path / "cli_custom_report.html")
monkeypatch.setattr("sys.argv", [
"prog",
"--path-to-db", custom_db_path,
"--report-filepath", custom_report_filepath
])
dump_to_report_cli()
assert os.path.exists(custom_report_filepath)


def test_dump_to_report_defaults(self, unit_test_db, tmp_path):
report_filepath = tmp_path / "default_report.html"
dump_to_report(report_filepath=str(report_filepath))
assert os.path.exists(report_filepath)
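The two new CLI tests drive dump_to_report_cli through --path-to-db and --report-filepath (the flag names come from the monkeypatched sys.argv) and dump_to_report through its defaults. Below is a hedged sketch of what such a wrapper's argument parsing might look like; the defaults shown are assumptions, not necessarily the module's.

import argparse

def dump_to_report_cli_sketch() -> None:
    """Illustrative argument parsing only; not the UnitTestDB source."""
    parser = argparse.ArgumentParser(description="Dump recorded test-generation attempts to an HTML report.")
    parser.add_argument("--path-to-db", default="unit_test_runs.db",  # DB_NAME used in the tests
                        help="SQLite database of unit test generation attempts.")
    parser.add_argument("--report-filepath", default="test_results.html",  # assumed default
                        help="Destination path for the generated HTML report.")
    args = parser.parse_args()
    print(f"Would dump {args.path_to_db} to {args.report_filepath}")

if __name__ == "__main__":
    dump_to_report_cli_sketch()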