refactor(ai/LLM): Overhaul LLM services architecture
Introduce LLMServiceMixin for shared functionality
Refactor BaseLLMService with abstract methods
Enhance error handling and logging across services
Update AnthropicService, OLLAMAService, OpenAIService
Add testing notebook for services
KaiserRuben committed Jul 14, 2024
1 parent c80e71d commit ad45931
Showing 7 changed files with 586 additions and 333 deletions.
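
As a reading aid before the diffs: LLMServiceStrategy delegates every call to a concrete service, and a concrete service subclasses BaseLLMService for the template-method workflow while mixing in LLMServiceMixin for configuration and error utilities. A compact sketch of that wiring, with a hypothetical MyService standing in for the real AnthropicService/OLLAMAService/OpenAIService (their constructors are not part of this excerpt):

# Hypothetical wiring sketch; MyService and its config keys are illustrative.
from ai.LLM.BaseLLMService import BaseLLMService
from ai.LLM.LLMServiceMixin import LLMServiceMixin
from ai.LLM.LLMServiceStrategy import LLMServiceStrategy


class MyService(BaseLLMService, LLMServiceMixin):
    def __init__(self, config: dict):
        LLMServiceMixin.__init__(self, config)
    # ...implements the abstract _generate, _get_*_prompt and
    # _process_*_response hooks shown in BaseLLMService below...


# strategy = LLMServiceStrategy(MyService({"api_key": "..."}))
# strategy.generate("prompt")  # routed through BaseLLMService.generate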
62 changes: 59 additions & 3 deletions src/ai/LLM/BaseLLMService.py
@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, List, Union
import logging

from data.Finding import Finding

logger = logging.getLogger(__name__)


class BaseLLMService(ABC):
    @abstractmethod
@@ -15,19 +18,72 @@ def get_url(self) -> str:
        pass

    @abstractmethod
-    def generate(self, prompt: str) -> Dict[str, str]:
    def _generate(self, prompt: str) -> Dict[str, str]:
        pass

    def generate(self, prompt: str) -> Dict[str, str]:
        try:
            return self._generate(prompt)
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {"error": str(e)}

    def classify_kind(self, finding: Finding, field_name: str, options: Optional[List[Enum]] = None) -> Optional[Enum]:
        if options is None:
            logger.warning(f"No options provided for field {field_name}")
            return None

        options_str = ", ".join([option.value for option in options])
        prompt = self._get_classification_prompt(options_str, field_name, str(finding))
        response = self.generate(prompt)

        return self._process_classification_response(response, field_name, finding, options_str, options)

    @abstractmethod
-    def classify_kind(self, finding: Finding, field_name: str, options: List[Enum]) -> Optional[Enum]:
    def _get_classification_prompt(self, options: str, field_name: str, finding_str: str) -> str:
        pass

    def _process_classification_response(self, response: Dict[str, str], field_name: str, finding: Finding,
                                         options_str: str, options: List[Enum]) -> Optional[Enum]:
        if "selected_option" not in response:
            logger.warning(f"Failed to classify the {field_name} for the finding: {finding.title}")
            return None
        if response["selected_option"] in ["None", "NotListed"]:
            logger.info(f"Chose None for {field_name} for the finding: {finding.title}")
            return None
        if response["selected_option"] not in options_str:
            logger.warning(f"Failed to classify the {field_name} for the finding: {finding.title}")
            return None

        return next(option for option in options if option.value == response["selected_option"])

    def get_recommendation(self, finding: Finding, short: bool = True) -> Union[str, List[str]]:
        prompt = self._get_recommendation_prompt(finding, short)
        finding.solution.add_to_metadata(f"prompt_{'short' if short else 'long'}", prompt)
        response = self.generate(prompt)

        return self._process_recommendation_response(response, finding, short)

    @abstractmethod
-    def get_recommendation(self, finding: Finding, short: bool = True) -> str:
    def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
        pass

    @abstractmethod
    def _process_recommendation_response(self, response: Dict[str, str], finding: Finding, short: bool) -> Union[
        str, List[str]]:
        pass

    def get_search_terms(self, finding: Finding) -> str:
        prompt = self._get_search_terms_prompt(finding)
        response = self.generate(prompt)
        return self._process_search_terms_response(response, finding)

    @abstractmethod
    def _get_search_terms_prompt(self, finding: Finding) -> str:
        pass

    @abstractmethod
    def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
        pass

    @abstractmethod
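The new design is a template-method contract: the public entry points (generate, classify_kind, get_recommendation, get_search_terms) stay in the base class and delegate to the underscored hooks, so every service gets uniform logging and error handling for free. A minimal sketch of a conforming subclass, assuming a toy StubLLMService with canned responses (the name and behavior are illustrative, not part of the commit):

# Illustrative sketch: a canned-response service exercising the new hooks.
from typing import Dict, List, Union

from ai.LLM.BaseLLMService import BaseLLMService
from data.Finding import Finding


class StubLLMService(BaseLLMService):
    def get_model_name(self) -> str:
        return "stub-model"

    def get_url(self) -> str:
        return "http://localhost"

    def _generate(self, prompt: str) -> Dict[str, str]:
        # Anything raised here is caught by BaseLLMService.generate,
        # logged, and returned as {"error": "..."} instead of propagating.
        if not prompt:
            raise ValueError("empty prompt")
        return {"response": "ok", "selected_option": "None"}

    def _get_classification_prompt(self, options: str, field_name: str, finding_str: str) -> str:
        return f"Choose one of [{options}] for {field_name}:\n{finding_str}"

    def _get_recommendation_prompt(self, finding: Finding, short: bool) -> str:
        return f"Give a {'short' if short else 'detailed'} fix for:\n{finding}"

    def _process_recommendation_response(self, response: Dict[str, str], finding: Finding,
                                         short: bool) -> Union[str, List[str]]:
        return response.get("response", "")

    def _get_search_terms_prompt(self, finding: Finding) -> str:
        return f"Search terms for:\n{finding}"

    def _process_search_terms_response(self, response: Dict[str, str], finding: Finding) -> str:
        return response.get("response", "")

    # ...plus any further abstract methods hidden where the diff is truncated.

Once the truncated hooks are filled in, StubLLMService().generate("") would return {'error': 'empty prompt'} rather than raising, and classify_kind would log and return None because the stub always selects "None".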
153 changes: 153 additions & 0 deletions src/ai/LLM/LLMServiceMixin.py
@@ -0,0 +1,153 @@
import json
from typing import Any, Dict
import logging
from utils.text_tools import clean

logger = logging.getLogger(__name__)


class LLMServiceMixin:
    """
    A mixin class providing common utility methods for LLM services.

    This mixin should be used in conjunction with BaseLLMService implementations
    to provide shared functionality across different LLM services.

    Attributes:
        config (Dict[str, Any]): Configuration dictionary for the LLM service.

    The config dictionary should include the following keys:
        - api_key (str): The API key for the LLM service.
        - model (str, optional): The name of the model to use. Defaults to 'default_model'.
        - Additional service-specific configuration keys as needed.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the LLMServiceMixin.

        Args:
            config (Dict[str, Any]): Configuration dictionary for the LLM service.
        """
        self.config = config

    def get_api_key(self) -> str:
        """
        Retrieve the API key from the configuration.

        Returns:
            str: The API key.

        Raises:
            ValueError: If the API key is not provided in the configuration.
        """
        api_key = self.config.get('api_key')
        if not api_key:
            raise ValueError("API key not provided in configuration")
        return api_key

    def get_model(self) -> str:
        """
        Retrieve the model name from the configuration.

        Returns:
            str: The model name, or 'default_model' if not specified.
        """
        return self.config.get('model', 'default_model')

    def clean_response(self, response: str) -> str:
        """
        Clean the response using the utility function from text_tools.

        Args:
            response (str): The raw response string.

        Returns:
            str: The cleaned response string.
        """
        return clean(response, llm_service=self)

    def parse_json_response(self, response: str) -> Dict[str, Any]:
        """
        Parse a JSON response string into a dictionary.

        Args:
            response (str): The JSON response string.

        Returns:
            Dict[str, Any]: The parsed JSON as a dictionary, or an empty dict if parsing fails.
        """
        try:
            return json.loads(response)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON response: {e}")
            return {}

    def format_prompt(self, template: str, **kwargs) -> str:
        """
        Format a prompt template with the given keyword arguments.

        Args:
            template (str): The prompt template string.
            **kwargs: Keyword arguments to fill in the template.

        Returns:
            str: The formatted prompt string.
        """
        try:
            return template.format(**kwargs)
        except KeyError as e:
            logger.error(f"Missing key in prompt template: {e}")
            return template

    def handle_api_error(self, e: Exception) -> Dict[str, str]:
        """
        Handle and log API errors.

        Args:
            e (Exception): The exception that occurred during the API call.

        Returns:
            Dict[str, str]: A dictionary containing the error message.
        """
        logger.error(f"API error occurred: {str(e)}")
        return {"error": str(e)}

    def convert_dict_to_str(self, data: Dict[str, Any]) -> str:
        """
        Convert a dictionary to a descriptive string.

        This method uses the LLM to generate a human-readable description of the dictionary.

        Args:
            data (Dict[str, Any]): The dictionary to convert.

        Returns:
            str: A string description of the dictionary, or the stringified dictionary if conversion fails.
        """
        prompt = self.format_prompt(
            "Convert the following dictionary to a descriptive string: {data}",
            data=json.dumps(data)
        )
        response = self.generate(prompt)
        if "error" in response:
            logger.info("Failed to convert dictionary to string, returning it as str conversion.")
            return str(data)
        return self.clean_response(response.get("response", str(data)))

    def generate(self, prompt: str) -> Dict[str, str]:
        """
        Generate a response using the LLM.

        This method should be implemented by the class that uses this mixin.

        Args:
            prompt (str): The input prompt for the LLM.

        Returns:
            Dict[str, str]: The generated response.

        Raises:
            NotImplementedError: If the method is not implemented by the using class.
        """
        raise NotImplementedError("The generate method must be implemented by the class using this mixin.")
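
The mixin assumes the host class supplies generate; its own version only raises NotImplementedError as a guard. A minimal sketch of that composition, assuming a hypothetical ToyService with canned behavior and config values (a real host would also inherit BaseLLMService):

# Illustrative composition; ToyService, its config values, and its canned
# generate are assumptions, not part of the commit.
from typing import Any, Dict

from ai.LLM.LLMServiceMixin import LLMServiceMixin


class ToyService(LLMServiceMixin):
    def __init__(self, config: Dict[str, Any]):
        LLMServiceMixin.__init__(self, config)

    def generate(self, prompt: str) -> Dict[str, str]:
        # Overrides the mixin's NotImplementedError guard; a real service
        # would call its API here and route failures through handle_api_error.
        try:
            return {"response": f"{self.get_model()} says: {prompt}"}
        except Exception as e:
            return self.handle_api_error(e)


svc = ToyService({"api_key": "sk-test", "model": "toy-1"})
print(svc.get_api_key())                           # sk-test
print(svc.format_prompt("Hi {name}", name="Ada"))  # Hi Ada
print(svc.parse_json_response('{"a": 1}'))         # {'a': 1}
print(svc.generate("ping"))                        # {'response': 'toy-1 says: ping'}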
73 changes: 67 additions & 6 deletions src/ai/LLM/LLMServiceStrategy.py
@@ -1,28 +1,89 @@
from enum import Enum
-from typing import Dict, Optional, List
from typing import Dict, Optional, List, Union

from ai.LLM.BaseLLMService import BaseLLMService
from data.Finding import Finding


class LLMServiceStrategy:
    def __init__(self, llm_service: BaseLLMService):
        """
        Initialize the LLMServiceStrategy.

        Args:
            llm_service (BaseLLMService): An instance of a class inheriting from BaseLLMService.
        """
        if not isinstance(llm_service, BaseLLMService):
            raise ValueError("llm_service must be an instance of a class inheriting from BaseLLMService")
        self.llm_service = llm_service

-    def generate(self, prompt: str) -> Dict[str, str]:
-        return self.llm_service.generate(prompt)
    def get_model_name(self) -> str:
        """Get the name of the current LLM model."""
        return self.llm_service.get_model_name()

    def get_url(self) -> str:
        """Get the URL associated with the current LLM service."""
        return self.llm_service.get_url()

-    def classify_kind(self, finding: Finding, field_name: str, options: List[Enum]) -> Optional[Enum]:
    def generate(self, prompt: str) -> Dict[str, str]:
        """
        Generate a response using the current LLM service.

        Args:
            prompt (str): The input prompt.

        Returns:
            Dict[str, str]: The generated response.
        """
        return self.llm_service.generate(prompt)

    def classify_kind(self, finding: Finding, field_name: str, options: Optional[List[Enum]] = None) -> Optional[Enum]:
        """
        Classify the kind of finding.

        Args:
            finding (Finding): The finding to classify.
            field_name (str): The name of the field to classify.
            options (Optional[List[Enum]]): The possible classification options.

        Returns:
            Optional[Enum]: The classified kind, or None if classification failed.
        """
        return self.llm_service.classify_kind(finding, field_name, options)

-    def get_recommendation(self, finding: Finding, short: bool = True) -> str:
    def get_recommendation(self, finding: Finding, short: bool = True) -> Union[str, List[str]]:
        """
        Get a recommendation for a finding.

        Args:
            finding (Finding): The finding to get a recommendation for.
            short (bool): Whether to get a short or long recommendation.

        Returns:
            Union[str, List[str]]: The generated recommendation.
        """
        return self.llm_service.get_recommendation(finding, short)

    def get_search_terms(self, finding: Finding) -> str:
        """
        Get search terms for a finding.

        Args:
            finding (Finding): The finding to get search terms for.

        Returns:
            str: The generated search terms.
        """
        return self.llm_service.get_search_terms(finding)

-    def convert_dict_to_str(self, data) -> str:
    def convert_dict_to_str(self, data: Dict) -> str:
        """
        Convert a dictionary to a string representation.

        Args:
            data (Dict): The dictionary to convert.

        Returns:
            str: The string representation of the dictionary.
        """
        return self.llm_service.convert_dict_to_str(data)
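
Because the strategy only requires a BaseLLMService instance, swapping backends is just a constructor argument. A usage sketch reusing the hypothetical StubLLMService from above (the real services' constructors are not shown in this excerpt):

# Illustrative usage; StubLLMService is the sketch from BaseLLMService above.
from ai.LLM.LLMServiceStrategy import LLMServiceStrategy

strategy = LLMServiceStrategy(StubLLMService())
print(strategy.get_model_name())   # stub-model
print(strategy.generate("hello"))  # delegated to the wrapped service

# Swapping backends means constructing the strategy with a different service,
# e.g. LLMServiceStrategy(AnthropicService(...)) once that service is configured.

# Anything that is not a BaseLLMService fails fast:
try:
    LLMServiceStrategy("not a service")  # type: ignore[arg-type]
except ValueError as e:
    print(e)  # llm_service must be an instance of a class inheriting from BaseLLMService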
