refactor(ai/LLM): Overhaul LLM services architecture
Introduce LLMServiceMixin for shared functionality
Refactor BaseLLMService with abstract methods
Enhance error handling and logging across services
Update AnthropicService, OLLAMAService, OpenAIService
Add testing notebook for services
1 parent c80e71d · commit ad45931
Showing 7 changed files with 586 additions and 333 deletions.
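Taken together, the pieces in this commit compose as follows: BaseLLMService declares the abstract interface, LLMServiceMixin carries the shared helpers, and each concrete service (AnthropicService, OLLAMAService, OpenAIService) inherits both. Here is a minimal sketch of that shape; the module paths are assumed from the commit scope, and MyProviderService with its stubbed generate body is illustrative, not part of this commit:

from typing import Any, Dict

from ai.LLM.BaseLLMService import BaseLLMService    # path taken from the import in LLMServiceStrategy
from ai.LLM.LLMServiceMixin import LLMServiceMixin  # path assumed from the commit scope


class MyProviderService(BaseLLMService, LLMServiceMixin):
    """Hypothetical concrete service: interface from the base, helpers from the mixin."""

    def __init__(self, config: Dict[str, Any]):
        LLMServiceMixin.__init__(self, config)  # stores config for get_api_key()/get_model()

    def generate(self, prompt: str) -> Dict[str, str]:
        # A real service would call its provider's API here, using
        # self.get_api_key() and self.get_model(), and route failures
        # through self.handle_api_error(e); this stub just echoes.
        return {"response": f"stub response for: {prompt}"}

    # Any further abstract methods that BaseLLMService declares (the base class
    # is not shown in this diff) would also need concrete implementations here.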
@@ -0,0 +1,153 @@
import json
import logging
from typing import Any, Dict

from utils.text_tools import clean

logger = logging.getLogger(__name__)


class LLMServiceMixin:
    """
    A mixin class providing common utility methods for LLM services.

    This mixin should be used in conjunction with BaseLLMService implementations
    to provide shared functionality across different LLM services.

    Attributes:
        config (Dict[str, Any]): Configuration dictionary for the LLM service.

    The config dictionary should include the following keys:
        - api_key (str): The API key for the LLM service.
        - model (str, optional): The name of the model to use. Defaults to 'default_model'.
        - Additional service-specific configuration keys as needed.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the LLMServiceMixin.

        Args:
            config (Dict[str, Any]): Configuration dictionary for the LLM service.
        """
        self.config = config

    def get_api_key(self) -> str:
        """
        Retrieve the API key from the configuration.

        Returns:
            str: The API key.

        Raises:
            ValueError: If the API key is not provided in the configuration.
        """
        api_key = self.config.get('api_key')
        if not api_key:
            raise ValueError("API key not provided in configuration")
        return api_key

    def get_model(self) -> str:
        """
        Retrieve the model name from the configuration.

        Returns:
            str: The model name, or 'default_model' if not specified.
        """
        return self.config.get('model', 'default_model')

    def clean_response(self, response: str) -> str:
        """
        Clean the response using the utility function from text_tools.

        Args:
            response (str): The raw response string.

        Returns:
            str: The cleaned response string.
        """
        return clean(response, llm_service=self)

    def parse_json_response(self, response: str) -> Dict[str, Any]:
        """
        Parse a JSON response string into a dictionary.

        Args:
            response (str): The JSON response string.

        Returns:
            Dict[str, Any]: The parsed JSON as a dictionary, or an empty dict if parsing fails.
        """
        try:
            return json.loads(response)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON response: {e}")
            return {}

    def format_prompt(self, template: str, **kwargs) -> str:
        """
        Format a prompt template with the given keyword arguments.

        Args:
            template (str): The prompt template string.
            **kwargs: Keyword arguments to fill in the template.

        Returns:
            str: The formatted prompt string.
        """
        try:
            return template.format(**kwargs)
        except KeyError as e:
            logger.error(f"Missing key in prompt template: {e}")
            return template

    def handle_api_error(self, e: Exception) -> Dict[str, str]:
        """
        Handle and log API errors.

        Args:
            e (Exception): The exception that occurred during the API call.

        Returns:
            Dict[str, str]: A dictionary containing the error message.
        """
        logger.error(f"API error occurred: {str(e)}")
        return {"error": str(e)}

    def convert_dict_to_str(self, data: Dict[str, Any]) -> str:
        """
        Convert a dictionary to a descriptive string.

        This method uses the LLM to generate a human-readable description of the dictionary.

        Args:
            data (Dict[str, Any]): The dictionary to convert.

        Returns:
            str: A string description of the dictionary, or the stringified dictionary if conversion fails.
        """
        prompt = self.format_prompt(
            "Convert the following dictionary to a descriptive string: {data}",
            data=json.dumps(data)
        )
        response = self.generate(prompt)
        if "error" in response:
            logger.info("Failed to convert dictionary to string; returning str(data) instead.")
            return str(data)
        return self.clean_response(response.get("response", str(data)))

    def generate(self, prompt: str) -> Dict[str, str]:
        """
        Generate a response using the LLM.

        This method should be implemented by the class that uses this mixin.

        Args:
            prompt (str): The input prompt for the LLM.

        Returns:
            Dict[str, str]: The generated response.

        Raises:
            NotImplementedError: If the method is not implemented by the using class.
        """
        raise NotImplementedError("The generate method must be implemented by the class using this mixin.")
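A short usage sketch of the helpers above, exercised through a stub. The StubService class and every config value are hypothetical, and utils.text_tools.clean is assumed to behave as a pass-through cleaner:

from typing import Dict

from ai.LLM.LLMServiceMixin import LLMServiceMixin  # path assumed from the commit scope


class StubService(LLMServiceMixin):
    """Stub that satisfies the generate() contract so convert_dict_to_str() works."""

    def generate(self, prompt: str) -> Dict[str, str]:
        return {"response": "A dictionary describing an open HTTPS port."}


svc = StubService({"api_key": "sk-example", "model": "demo-model"})
svc.get_model()                                       # -> 'demo-model'
svc.parse_json_response('{"severity": "high"}')       # -> {'severity': 'high'}
svc.parse_json_response("not json")                   # -> {} (decode error is logged)
svc.format_prompt("Rate {cve}", cve="CVE-2024-0001")  # -> 'Rate CVE-2024-0001'
svc.convert_dict_to_str({"port": 443})                # routed through generate() + clean()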
@@ -1,28 +1,89 @@
from enum import Enum
from typing import Dict, Optional, List, Union

from ai.LLM.BaseLLMService import BaseLLMService
from data.Finding import Finding


class LLMServiceStrategy:
    def __init__(self, llm_service: BaseLLMService):
        """
        Initialize the LLMServiceStrategy.

        Args:
            llm_service (BaseLLMService): An instance of a class inheriting from BaseLLMService.
        """
        if not isinstance(llm_service, BaseLLMService):
            raise ValueError("llm_service must be an instance of a class inheriting from BaseLLMService")
        self.llm_service = llm_service

    def get_model_name(self) -> str:
        """Get the name of the current LLM model."""
        return self.llm_service.get_model_name()

    def get_url(self) -> str:
        """Get the URL associated with the current LLM service."""
        return self.llm_service.get_url()

    def generate(self, prompt: str) -> Dict[str, str]:
        """
        Generate a response using the current LLM service.

        Args:
            prompt (str): The input prompt.

        Returns:
            Dict[str, str]: The generated response.
        """
        return self.llm_service.generate(prompt)

    def classify_kind(self, finding: Finding, field_name: str, options: Optional[List[Enum]] = None) -> Optional[Enum]:
        """
        Classify the kind of finding.

        Args:
            finding (Finding): The finding to classify.
            field_name (str): The name of the field to classify.
            options (Optional[List[Enum]]): The possible classification options.

        Returns:
            Optional[Enum]: The classified kind, or None if classification failed.
        """
        return self.llm_service.classify_kind(finding, field_name, options)

    def get_recommendation(self, finding: Finding, short: bool = True) -> Union[str, List[str]]:
        """
        Get a recommendation for a finding.

        Args:
            finding (Finding): The finding to get a recommendation for.
            short (bool): Whether to get a short or long recommendation.

        Returns:
            Union[str, List[str]]: The generated recommendation.
        """
        return self.llm_service.get_recommendation(finding, short)

    def get_search_terms(self, finding: Finding) -> str:
        """
        Get search terms for a finding.

        Args:
            finding (Finding): The finding to get search terms for.

        Returns:
            str: The generated search terms.
        """
        return self.llm_service.get_search_terms(finding)

    def convert_dict_to_str(self, data: Dict) -> str:
        """
        Convert a dictionary to a string representation.

        Args:
            data (Dict): The dictionary to convert.

        Returns:
            str: The string representation of the dictionary.
        """
        return self.llm_service.convert_dict_to_str(data)
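Wiring it together: the strategy takes any BaseLLMService implementation and simply delegates. A sketch using the hypothetical MyProviderService from the top of this commit, assuming it satisfies all of BaseLLMService's abstract methods:

from ai.LLM.LLMServiceStrategy import LLMServiceStrategy  # path assumed from the commit scope

service = MyProviderService({"api_key": "sk-example", "model": "demo-model"})
strategy = LLMServiceStrategy(llm_service=service)  # the isinstance check accepts any BaseLLMService

strategy.generate("Summarize this finding.")  # delegates to service.generate()
# classify_kind's options parameter is optional after this change:
#   strategy.classify_kind(finding, "severity")          # finding: a data.Finding instance
# get_recommendation may now return either str or List[str]:
#   strategy.get_recommendation(finding, short=False)

The point of the indirection is that callers depend only on LLMServiceStrategy, so swapping one provider for another (Anthropic, OLLAMA, OpenAI) is a one-line change at construction time.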