diff --git a/README.md b/README.md
index 8cc7d7142..81c9686ac 100644
--- a/README.md
+++ b/README.md
@@ -42,6 +42,10 @@ CodiumAI PR-Agent aims to help efficiently review and handle pull requests, by p
## News and Updates
+### July 4, 2024
+
+Added improved support for the Claude 3.5 Sonnet model (Anthropic, Vertex, Bedrock), including dedicated prompts.
+
### June 17, 2024
New option for a self-review checkbox is now available for the `/improve` tool, along with the ability(💎) to enable auto-approve, or demand self-review in addition to human reviewer. See more [here](https://pr-agent-docs.codium.ai/tools/improve/#self-review).
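
A minimal sketch, not part of this diff, of pointing PR-Agent at the new model at runtime; the Bedrock route string matches the entry added later in this PR, and the override call is illustrative only:

```python
# a minimal sketch (illustrative, not part of this PR): select the new model via
# PR-Agent's Dynaconf settings object; Anthropic/Vertex routes use their own names
from pr_agent.config_loader import get_settings

get_settings().set("CONFIG.MODEL", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
```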
@@ -60,11 +64,6 @@ New option now available (💎) - **apply suggestions**:
-### May 31, 2024
-
-Check out the new [**PR-Agent Code Fine-tuning Benchmark**](https://pr-agent-docs.codium.ai/finetuning_benchmark/)
-
-
## Overview
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 502e913cf..b048fc847 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -40,6 +40,7 @@
'bedrock/anthropic.claude-v2:1': 100000,
'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
+ 'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
'groq/llama3-8b-8192': 8192,
'groq/llama3-70b-8192': 8192,
'ollama/llama3': 4096,
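
For context, a minimal sketch of how this table is consumed (the dict extended above is `MAX_TOKENS` in `pr_agent/algo/__init__.py`; the lookup below is illustrative):

```python
# a minimal sketch, not part of this PR: the token table caps how much of the
# diff gets packed into the prompt for the new Bedrock model
from pr_agent.algo import MAX_TOKENS

model = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
print(MAX_TOKENS.get(model))  # 100000, per the entry added in this hunk
```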
diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index 5d0929b5a..f857e34e2 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -25,12 +25,18 @@ def __init__(self):
Raises a ValueError if the OpenAI key is missing.
"""
self.azure = False
- self.aws_bedrock_client = None
self.api_base = None
self.repetition_penalty = None
if get_settings().get("OPENAI.KEY", None):
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
+ elif 'OPENAI_API_KEY' not in os.environ:
+ litellm.api_key = "dummy_key"
+ if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
+ assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
+ os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
+ os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
+ os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
if get_settings().get("litellm.use_client"):
litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
assert litellm_token, "LITELLM_TOKEN is required"
@@ -71,14 +77,6 @@ def __init__(self):
litellm.vertex_location = get_settings().get(
"VERTEXAI.VERTEX_LOCATION", None
)
- if get_settings().get("AWS.BEDROCK_REGION", None):
- litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
- litellm.AmazonAnthropicClaude3Config.max_tokens = 2000
- self.aws_bedrock_client = boto3.client(
- service_name="bedrock-runtime",
- region_name=get_settings().aws.bedrock_region,
- )
-
def prepare_logs(self, response, system, user, resp, finish_reason):
response_log = response.dict().copy()
response_log['system'] = system
@@ -131,8 +129,6 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
"force_timeout": get_settings().config.ai_timeout,
"api_base": self.api_base,
}
- if self.aws_bedrock_client:
- kwargs["aws_bedrock_client"] = self.aws_bedrock_client
if self.repetition_penalty:
kwargs["repetition_penalty"] = self.repetition_penalty
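
Why the explicit boto3 client can be dropped: once the `AWS_*` variables are exported in `__init__`, litellm resolves `bedrock/...` model names on its own, so no `aws_bedrock_client` kwarg needs to be threaded through `chat_completion`. A minimal sketch under that assumption (model string and prompt are illustrative):

```python
# a minimal sketch, not part of this PR: with AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
# and AWS_REGION_NAME exported (as done in __init__ above), litellm talks to Bedrock
# directly, without an explicit boto3 client
import asyncio
import litellm

async def demo():
    response = await litellm.acompletion(
        model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        messages=[{"role": "user", "content": "Say hello"}],
        temperature=0.2,
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(demo())
```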
diff --git a/pr_agent/settings/pr_code_suggestions_prompts.toml b/pr_agent/settings/pr_code_suggestions_prompts.toml
index 435ea91b4..eebd9d27a 100644
--- a/pr_agent/settings/pr_code_suggestions_prompts.toml
+++ b/pr_agent/settings/pr_code_suggestions_prompts.toml
@@ -111,3 +111,102 @@ The PR Diff:
Response (should be a valid YAML, and nothing else):
```yaml
"""
+
+
+[pr_code_suggestions_prompt_claude]
+system="""You are PR-Reviewer, a language model that specializes in suggesting ways to improve Pull Request (PR) code.
+Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR diff.
+
+
+The format we will use to present the PR code diff:
+======
+## file: 'src/file1.py'
+
+@@ ... @@ def func1():
+__new hunk__
+12 code line1 that remained unchanged in the PR
+13 +new hunk code line2 added in the PR
+14 code line3 that remained unchanged in the PR
+__old hunk__
+ code line1 that remained unchanged in the PR
+-old hunk code line2 that was removed in the PR
+ code line3 that remained unchanged in the PR
+
+@@ ... @@ def func2():
+__new hunk__
+...
+__old hunk__
+...
+
+
+## file: 'src/file2.py'
+...
+======
+- In this format, we separated each hunk of diff code into '__new hunk__' and '__old hunk__' sections. The '__new hunk__' section contains the new code of the chunk, and the '__old hunk__' section contains the old code that was removed.
+- We also added line numbers for the '__new hunk__' sections, to help you refer to the code lines in your suggestions. These line numbers are not part of the actual code, and are only used for reference.
+- Code lines are prefixed with symbols ('+', '-', ' '). The '+' symbol indicates new code added in the PR, the '-' symbol indicates code removed in the PR, and the ' ' symbol indicates unchanged code. \
+Suggestions should always focus on ways to improve the new code lines introduced in the PR, meaning lines in the '__new hunk__' sections that begin with a '+' symbol (after the line numbers). The '__old hunk__' sections code is for context and reference only.
+
+
+Specific instructions for generating code suggestions:
+- Provide up to {{ num_code_suggestions }} code suggestions. The suggestions should be diverse and insightful.
+- The suggestions should focus on improving the new code introduced in the PR, meaning lines from '__new hunk__' sections, starting with '+' (after the line numbers).
+- Prioritize suggestions that address possible issues, major problems, and bugs in the PR code.
+- Don't suggest adding docstrings, type hints, or comments, or removing unused imports.
+- Provide the exact line numbers range (inclusive) for each suggestion. Use the line numbers from the '__new hunk__' sections.
+- When quoting variables or names from the code, use backticks (`) instead of single quotes (').
+- Take into account that you are receiving as an input only a PR code diff. The entire codebase is not available to you as context. Hence, avoid suggestions that might conflict with unseen parts of the codebase, like imports, global variables, etc.
+
+
+{%- if extra_instructions %}
+
+
+Extra instructions from the user, that should be taken into account with high priority:
+======
+{{ extra_instructions }}
+======
+{%- endif %}
+
+
+The output must be a YAML object equivalent to type $PRCodeSuggestions, according to the following Pydantic definitions:
+=====
+class CodeSuggestion(BaseModel):
+ relevant_file: str = Field(description="the relevant file full path")
+ language: str = Field(description="the code language of the relevant file")
+ suggestion_content: str = Field(description="an actionable suggestion for meaningfully improving the new code introduced in the PR. Don't present here actual code snippets, just the suggestion. Be short and concise ")
+    existing_code: str = Field(description="a short code snippet, demonstrating the relevant code lines from a '__new hunk__' section. It must be without line numbers. Use abbreviations ('...') if needed")
+ improved_code: str = Field(description="a new code snippet, that can be used to replace the relevant 'existing_code' lines in '__new hunk__' code after applying the suggestion")
+ one_sentence_summary: str = Field(description="a short summary of the suggestion action, in a single sentence. Focus on the 'what'. Be general, and avoid method or variable names.")
+ relevant_lines_start: int = Field(description="The relevant line number, from a '__new hunk__' section, where the suggestion starts (inclusive). Should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above")
+ relevant_lines_end: int = Field(description="The relevant line number, from a '__new hunk__' section, where the suggestion ends (inclusive). Should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above")
+ label: str = Field(description="a single label for the suggestion, to help understand the suggestion type. For example: 'security', 'possible bug', 'possible issue', 'performance', 'enhancement', 'best practice', 'maintainability', etc. Other labels are also allowed")
+
+class PRCodeSuggestions(BaseModel):
+ code_suggestions: List[CodeSuggestion]
+=====
+
+
+Example output:
+```yaml
+code_suggestions:
+- relevant_file: |
+ src/file1.py
+ language: |
+ python
+ suggestion_content: |
+ ...
+ existing_code: |
+ ...
+ improved_code: |
+ ...
+ one_sentence_summary: |
+ ...
+ relevant_lines_start: 12
+ relevant_lines_end: 13
+ label: |
+ ...
+```
+
+
+Each YAML output MUST be after a newline, indented, with block scalar indicator ('|').
+"""
\ No newline at end of file
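
For reference, a minimal sketch (not part of the PR) of how the example YAML output above maps onto the Pydantic schema embedded in the prompt; the field names are copied from the prompt text, and the suggestion content is placeholder:

```python
# a minimal sketch, not part of this PR: parse a Claude-style YAML answer into the
# schema the prompt describes; descriptions omitted, placeholder suggestion content
from typing import List

import yaml
from pydantic import BaseModel


class CodeSuggestion(BaseModel):
    relevant_file: str
    language: str
    suggestion_content: str
    existing_code: str
    improved_code: str
    one_sentence_summary: str
    relevant_lines_start: int
    relevant_lines_end: int
    label: str


class PRCodeSuggestions(BaseModel):
    code_suggestions: List[CodeSuggestion]


raw = """
code_suggestions:
- relevant_file: |
    src/file1.py
  language: |
    python
  suggestion_content: |
    Guard against an empty list before indexing into it.
  existing_code: |
    value = items[0]
  improved_code: |
    value = items[0] if items else None
  one_sentence_summary: |
    Add an emptiness check before indexing.
  relevant_lines_start: 12
  relevant_lines_end: 13
  label: |
    possible bug
"""

parsed = PRCodeSuggestions(**yaml.safe_load(raw))
print(parsed.code_suggestions[0].one_sentence_summary)
```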
diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py
index 434907d59..baaa86f9b 100644
--- a/pr_agent/tools/pr_code_suggestions.py
+++ b/pr_agent/tools/pr_code_suggestions.py
@@ -61,9 +61,15 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
+ if 'claude' in get_settings().config.model:
+ # prompt for Claude, with minor adjustments
+ self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt_claude.system
+ else:
+ self.pr_code_suggestions_prompt_system = get_settings().pr_code_suggestions_prompt.system
+
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
- get_settings().pr_code_suggestions_prompt.system,
+ self.pr_code_suggestions_prompt_system,
get_settings().pr_code_suggestions_prompt.user)
self.progress = f"## Generating PR code suggestions\n\n"
@@ -280,7 +286,7 @@ async def _get_prediction(self, model: str, patches_diff: str) -> dict:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
- system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
+ system_prompt = environment.from_string(self.pr_code_suggestions_prompt_system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
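
A minimal sketch (not part of the PR) of the rendering step above: `StrictUndefined` makes the render fail loudly if a variable the prompt uses, such as `num_code_suggestions`, is missing from the variables dict:

```python
# a minimal sketch, not part of this PR: Jinja2's StrictUndefined raises instead of
# silently rendering missing variables as empty strings
from jinja2 import Environment, StrictUndefined

environment = Environment(undefined=StrictUndefined)
template = "Provide up to {{ num_code_suggestions }} code suggestions."
print(environment.from_string(template).render({"num_code_suggestions": 4}))
# -> Provide up to 4 code suggestions.
```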