Skip to content

Commit

Permalink
feat: finished aggregated solution implementation for openai and anth…
Browse files Browse the repository at this point in the history
…ropic
  • Loading branch information
KaiserRuben committed Jul 15, 2024
1 parent 1ce18e9 commit 8e76c7a
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 54 deletions.
11 changes: 6 additions & 5 deletions src/ai/LLM/BaseLLMService.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, List, Union, Tuple
from tqdm import tqdm
import logging

from data.Finding import Finding
Expand All @@ -22,12 +23,12 @@ def get_url(self) -> str:
pass

@abstractmethod
def _generate(self, prompt: str) -> Dict[str, str]:
def _generate(self, prompt: str, json=False) -> Dict[str, str]:
pass

def generate(self, prompt: str) -> Dict[str, str]:
def generate(self, prompt: str, json=False) -> Dict[str, str]:
try:
return self._generate(prompt)
return self._generate(prompt, json)
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
return {"error": str(e)}
Expand Down Expand Up @@ -110,7 +111,7 @@ def generate_aggregated_solution(self, findings: List[Finding]) -> List[Tuple[st

results = []

for group, meta_info in finding_groups:
for group, meta_info in tqdm(finding_groups, desc="Generating aggregated solutions for group", unit="group"):
prompt = self._get_aggregated_solution_prompt(group, meta_info)
response = self.generate(prompt)
solution = self._process_aggregated_solution_response(response)
Expand All @@ -134,7 +135,7 @@ def _get_findings_str_for_aggregation(self, findings, details=False) -> str:

def _subdivide_finding_group(self, findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
prompt = self._get_subdivision_prompt(findings)
response = self.generate(prompt)
response = self.generate(prompt, json=True)
return self._process_subdivision_response(response, findings)

@abstractmethod
Expand Down
91 changes: 76 additions & 15 deletions src/ai/LLM/Strategies/AnthropicService.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
from typing import Dict, List, Optional, Union
from enum import Enum

Expand All @@ -8,12 +9,14 @@
from ai.LLM.LLMServiceMixin import LLMServiceMixin
from data.Finding import Finding
from ai.LLM.Strategies.openai_prompts import (
CLASSIFY_KIND_TEMPLATE,
SHORT_RECOMMENDATION_TEMPLATE,
GENERIC_LONG_RECOMMENDATION_TEMPLATE,
SEARCH_TERMS_TEMPLATE,
META_PROMPT_GENERATOR_TEMPLATE,
LONG_RECOMMENDATION_TEMPLATE, COMBINE_DESCRIPTIONS_TEMPLATE,
OPENAI_CLASSIFY_KIND_TEMPLATE,
OPENAI_SHORT_RECOMMENDATION_TEMPLATE,
OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE,
OPENAI_SEARCH_TERMS_TEMPLATE,
OPENAI_META_PROMPT_GENERATOR_TEMPLATE,
OPENAI_LONG_RECOMMENDATION_TEMPLATE,
OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE,
OPENAI_AGGREGATED_SOLUTION_TEMPLATE, OPENAI_SUBDIVISION_PROMPT_TEMPLATE,
)
from utils.text_tools import clean
from config import config
Expand Down Expand Up @@ -70,7 +73,7 @@ def get_url(self) -> str:
"""Get the URL for the Anthropic API (placeholder method)."""
return "-"

def _generate(self, prompt: str) -> Dict[str, str]:
def _generate(self, prompt: str, json=False) -> Dict[str, str]:
"""
Generate a response using the Anthropic API.
Expand All @@ -81,29 +84,34 @@ def _generate(self, prompt: str) -> Dict[str, str]:
Dict[str, str]: A dictionary containing the generated response.
"""
try:
messages = [{"role": "user", "content": prompt}]
if json:
messages.append({"role": "assistant", "content": "Here is the JSON requested:\n{"})
message = self.client.messages.create(
max_tokens=1024,
messages=[{"role": "user", "content": prompt}],
messages=messages,
model=self.model,
)
content = message.content[0].text
if json:
content = "{" + content
return {"response": content}
except Exception as e:
return self.handle_api_error(e)

def _get_classification_prompt(self, options: str, field_name: str, finding_str: str) -> str:
"""Generate the classification prompt for Anthropic."""
return CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)
return OPENAI_CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)

def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
"""Generate the recommendation prompt for Anthropic."""
if short:
return SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
return OPENAI_SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
elif finding.solution and finding.solution.short_description:
finding.solution.add_to_metadata("used_meta_prompt", True)
return self._generate_prompt_with_meta_prompts(finding)
else:
return GENERIC_LONG_RECOMMENDATION_TEMPLATE
return OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE

def _process_recommendation_response(self, response: Dict[str, str], finding: Finding, short: bool) -> Union[
str, List[str]]:
Expand All @@ -117,11 +125,11 @@ def _process_recommendation_response(self, response: Dict[str, str], finding: Fi
def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
"""Generate a prompt with meta-prompts for long recommendations."""
short_recommendation = finding.solution.short_description
meta_prompt_generator = META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
meta_prompt_generator = OPENAI_META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
meta_prompt_response = self.generate(meta_prompt_generator)
meta_prompts = clean(meta_prompt_response.get("response", ""), llm_service=self)

long_prompt = LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
long_prompt = OPENAI_LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)

finding.solution.add_to_metadata(
"prompt_long_breakdown",
Expand All @@ -135,7 +143,7 @@ def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:

def _get_search_terms_prompt(self, finding: Finding) -> str:
"""Generate the search terms prompt for Anthropic."""
return SEARCH_TERMS_TEMPLATE.format(data=str(finding))
return OPENAI_SEARCH_TERMS_TEMPLATE.format(data=str(finding))

def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
"""Process the search terms response from Anthropic."""
Expand All @@ -144,6 +152,59 @@ def _process_search_terms_response(self, response: Dict[str, str], finding: Find
return ""
return clean(response["response"], llm_service=self)

def _get_subdivision_prompt(self, findings: List[Finding]) -> str:
    """Build the prompt asking the model to split *findings* into related groups."""
    summary = self._get_findings_str_for_aggregation(findings)
    return OPENAI_SUBDIVISION_PROMPT_TEMPLATE.format(data=summary)

def _process_subdivision_response(self, response: Dict[str, str], findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
if "response" not in response:
logger.warning("Failed to subdivide findings")
return [(findings, {})] # Return all findings as a single group if subdivision fails

try:
response = response["response"]
# remove prefix ```json and suffix ```
response = re.sub(r'^```json', '', response)
response = re.sub(r'```$', '', response)
subdivisions = json.loads(response)["subdivisions"]
except json.JSONDecodeError:
logger.error("Failed to parse JSON response")
return [(findings, {})]
except KeyError:
logger.error("Unexpected JSON structure in response")
return [(findings, {})]

result = []
for subdivision in subdivisions:
try:
group_indices = [int(i.strip()) - 1 for i in subdivision["group"].split(',')]
group = [findings[i] for i in group_indices if i < len(findings)]
meta_info = {"reason": subdivision.get("reason", "")}
if len(group) == 1:
continue # Skip single-element groups for *aggregated* solutions
result.append((group, meta_info))
except ValueError:
logger.error(f"Failed to parse group indices: {subdivision['group']}")
continue
except KeyError:
logger.error("Unexpected subdivision structure")
continue

return result

def _get_aggregated_solution_prompt(self, findings: List[Finding], meta_info: Dict) -> str:
    """Build the prompt for one aggregated solution covering a group of findings."""
    detailed = self._get_findings_str_for_aggregation(findings, details=True)
    reason = meta_info.get("reason", "")
    return OPENAI_AGGREGATED_SOLUTION_TEMPLATE.format(data=detailed, meta_info=reason)

def _process_aggregated_solution_response(self, response: Dict[str, str]) -> str:
    """Extract and clean the aggregated-solution text; empty string on failure."""
    if "response" in response:
        return clean(response["response"], llm_service=self)
    logger.warning("Failed to generate an aggregated solution")
    return ""

def convert_dict_to_str(self, data: Dict) -> str:
"""
Convert a dictionary to a string representation.
Expand Down Expand Up @@ -171,7 +232,7 @@ def combine_descriptions(self, descriptions: List[str]) -> str:
if len(descriptions) <= 1:
return descriptions[0] if descriptions else ""

prompt = COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)
prompt = OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)

response = self.generate(prompt)
if "response" not in response:
Expand Down
3 changes: 2 additions & 1 deletion src/ai/LLM/Strategies/OLLAMAService.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def get_url(self) -> str:
@retry(
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60)
)
def _generate(self, prompt: str) -> Dict[str, str]:
def _generate(self, prompt: str, json=True) -> Dict[str, str]:
# The JSON Param is ignored by the OLLAMA server, it always returns JSON
payload = {"prompt": prompt, **self.generate_payload}
try:
timeout = httpx.Timeout(timeout=300.0)
Expand Down
100 changes: 81 additions & 19 deletions src/ai/LLM/Strategies/OpenAIService.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import json
import re
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Tuple

import openai

from ai.LLM.BaseLLMService import BaseLLMService
from ai.LLM.LLMServiceMixin import LLMServiceMixin
from data.Finding import Finding
from ai.LLM.Strategies.openai_prompts import (
CLASSIFY_KIND_TEMPLATE,
SHORT_RECOMMENDATION_TEMPLATE,
GENERIC_LONG_RECOMMENDATION_TEMPLATE,
SEARCH_TERMS_TEMPLATE,
META_PROMPT_GENERATOR_TEMPLATE,
LONG_RECOMMENDATION_TEMPLATE, COMBINE_DESCRIPTIONS_TEMPLATE,
OPENAI_CLASSIFY_KIND_TEMPLATE,
OPENAI_SHORT_RECOMMENDATION_TEMPLATE,
OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE,
OPENAI_SEARCH_TERMS_TEMPLATE,
OPENAI_META_PROMPT_GENERATOR_TEMPLATE,
OPENAI_LONG_RECOMMENDATION_TEMPLATE,
OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE,
OPENAI_AGGREGATED_SOLUTION_TEMPLATE, OPENAI_SUBDIVISION_PROMPT_TEMPLATE,
)
from utils.text_tools import clean

Expand All @@ -25,7 +28,7 @@


class OpenAIService(BaseLLMService, LLMServiceMixin):
def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4"):
def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4o"):
"""
Initialize the OpenAIService.
Expand Down Expand Up @@ -57,27 +60,33 @@ def get_context_size(self) -> int:
def get_url(self) -> str:
return "-"

def _generate(self, prompt: str) -> Dict[str, str]:
def _generate(self, prompt: str, json=False) -> Dict[str, str]:
try:
response = openai.chat.completions.create(
model=self.model, messages=[{"role": "user", "content": prompt}]
)
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}]
}

if json:
params["response_format"] = {"type": "json_object"}

response = openai.chat.completions.create(**params)
content = response.choices[0].message.content
return {"response": content}
except Exception as e:
return self.handle_api_error(e)

def _get_classification_prompt(self, options: str, field_name: str, finding_str: str) -> str:
return CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)
return OPENAI_CLASSIFY_KIND_TEMPLATE.format(options=options, field_name=field_name, data=finding_str)

def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
if short:
return SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
return OPENAI_SHORT_RECOMMENDATION_TEMPLATE.format(data=str(finding))
elif finding.solution and finding.solution.short_description:
finding.solution.add_to_metadata("used_meta_prompt", True)
return self._generate_prompt_with_meta_prompts(finding)
else:
return GENERIC_LONG_RECOMMENDATION_TEMPLATE
return OPENAI_GENERIC_LONG_RECOMMENDATION_TEMPLATE

def _process_recommendation_response(self, response: Dict[str, str], finding: Finding, short: bool) -> Union[
str, List[str]]:
Expand All @@ -89,11 +98,11 @@ def _process_recommendation_response(self, response: Dict[str, str], finding: Fi

def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
short_recommendation = finding.solution.short_description
meta_prompt_generator = META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
meta_prompt_generator = OPENAI_META_PROMPT_GENERATOR_TEMPLATE.format(finding=str(finding))
meta_prompt_response = self.generate(meta_prompt_generator)
meta_prompts = clean(meta_prompt_response.get("response", ""), llm_service=self)

long_prompt = LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)
long_prompt = OPENAI_LONG_RECOMMENDATION_TEMPLATE.format(meta_prompts=meta_prompts)

finding.solution.add_to_metadata(
"prompt_long_breakdown",
Expand All @@ -106,14 +115,67 @@ def _generate_prompt_with_meta_prompts(self, finding: Finding) -> str:
return long_prompt

def _get_search_terms_prompt(self, finding: Finding) -> str:
return SEARCH_TERMS_TEMPLATE.format(data=str(finding))
return OPENAI_SEARCH_TERMS_TEMPLATE.format(data=str(finding))

def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
    """Clean and return the model-suggested search terms; empty string if generation failed."""
    if "response" in response:
        return clean(response["response"], llm_service=self)
    logger.warning(f"Failed to generate search terms for the finding: {finding.title}")
    return ""

def _get_subdivision_prompt(self, findings: List[Finding]) -> str:
    """Build the prompt asking the model to split *findings* into related groups."""
    summary = self._get_findings_str_for_aggregation(findings)
    return OPENAI_SUBDIVISION_PROMPT_TEMPLATE.format(data=summary)

def _process_subdivision_response(self, response: Dict[str, str], findings: List[Finding]) -> List[Tuple[List[Finding], Dict]]:
if "response" not in response:
logger.warning("Failed to subdivide findings")
return [(findings, {})] # Return all findings as a single group if subdivision fails

try:
response = response["response"]
# remove prefix ```json and suffix ```
response = re.sub(r'^```json', '', response)
response = re.sub(r'```$', '', response)
subdivisions = json.loads(response)["subdivisions"]
except json.JSONDecodeError:
logger.error("Failed to parse JSON response")
return [(findings, {})]
except KeyError:
logger.error("Unexpected JSON structure in response")
return [(findings, {})]

result = []
for subdivision in subdivisions:
try:
group_indices = [int(i.strip()) - 1 for i in subdivision["group"].split(',')]
group = [findings[i] for i in group_indices if i < len(findings)]
meta_info = {"reason": subdivision.get("reason", "")}
if len(group) == 1:
continue # Skip single-element groups for *aggregated* solutions
result.append((group, meta_info))
except ValueError:
logger.error(f"Failed to parse group indices: {subdivision['group']}")
continue
except KeyError:
logger.error("Unexpected subdivision structure")
continue

return result

def _get_aggregated_solution_prompt(self, findings: List[Finding], meta_info: Dict) -> str:
    """Build the prompt for one aggregated solution covering a group of findings."""
    detailed = self._get_findings_str_for_aggregation(findings, details=True)
    reason = meta_info.get("reason", "")
    return OPENAI_AGGREGATED_SOLUTION_TEMPLATE.format(data=detailed, meta_info=reason)

def _process_aggregated_solution_response(self, response: Dict[str, str]) -> str:
    """Extract and clean the aggregated-solution text; empty string on failure."""
    if "response" in response:
        return clean(response["response"], llm_service=self)
    logger.warning("Failed to generate an aggregated solution")
    return ""

def convert_dict_to_str(self, data: Dict) -> str:
"""
Convert a dictionary to a string representation.
Expand Down Expand Up @@ -141,7 +203,7 @@ def combine_descriptions(self, descriptions: List[str]) -> str:
if len(descriptions) <= 1:
return descriptions[0] if descriptions else ""

prompt = COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)
prompt = OPENAI_COMBINE_DESCRIPTIONS_TEMPLATE.format(data=descriptions)

response = self.generate(prompt)
if "response" not in response:
Expand Down
2 changes: 1 addition & 1 deletion src/ai/LLM/Strategies/ollama_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def answer_in_json_prompt(key: str) -> str:
"1. Summary: A brief overview of the core security challenges (1-2 sentences)\n"
"2. Strategic Solution: A high-level approach to address the underlying issues (3-5 key points)\n"
"3. Implementation Guidance: General steps for putting the strategy into action\n"
"4. Long-term Considerations: Suggestions for ongoing improvement and risk mitigation\n\n"
"4. Long-term Considerations: Suggestions for ongoing improvement and risk mitigation. Give first steps or initial research that could lay a foundation.\n\n"
"You may use Markdown formatting in your response to improve readability.\n"
f"{answer_in_json_prompt('aggregated_solution')}"
"Findings:\n{data}"
Expand Down
Loading

0 comments on commit 8e76c7a

Please sign in to comment.