Bugfix - backwards compatibility with model names for openai (#153)
`--model="openai/gpt-4o"` would throw an exception because the model names in the `model_cost` table no longer include the `openai/` prefix:

https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json

![Screen Shot 2024-10-08 at 19 20 18](https://github.com/user-attachments/assets/849978e2-3184-4ae7-b1db-381ff986c117)
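For context, litellm loads that JSON into the `litellm.model_cost` dict, keyed by model name, and the fix strips the prefix before the lookup. Below is a minimal sketch of the failure mode, not part of the commit; the exact key set depends on the installed litellm version, so the absence of the prefixed entry is an assumption based on the table at the time of this fix:

```python
import litellm

# Keys in litellm.model_cost follow model_prices_and_context_window.json, e.g. "gpt-4o".
print(litellm.model_cost["gpt-4o"]["max_input_tokens"])

# At the time of this fix there was no "openai/"-prefixed entry, so looking up
# the unstripped name raised a KeyError, which surfaced as the exception above.
print(litellm.model_cost["openai/gpt-4o"]["max_input_tokens"])  # KeyError
```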
Avi-Robusta authored Oct 8, 2024
1 parent 7dc06c3 commit 93dd488
Showing 1 changed file with 16 additions and 2 deletions.
holmes/core/tool_calling_llm.py (16 additions, 2 deletions):

```diff
@@ -82,18 +82,32 @@ def check_llm(self, model, api_key):
         if not model_requirements["keys_in_environment"]:
             raise Exception(f"model {model} requires the following environment variables: {model_requirements['missing_keys']}")
 
+    def _strip_model_prefix(self) -> str:
+        """
+        Helper function to strip the 'openai/' prefix from the model name if it exists.
+        Model cost is taken from here, which does not have the 'openai/' prefix:
+        https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
+        """
+        model_name = self.model
+        if model_name.startswith('openai/'):
+            model_name = model_name[len('openai/'):]  # Strip the 'openai/' prefix
+        return model_name
+
+
     # this unfortunately does not seem to work for azure if the deployment name is not a well-known model name
     #if not litellm.supports_function_calling(model=model):
     #    raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
     def get_context_window_size(self) -> int:
-        return litellm.model_cost[self.model]['max_input_tokens']
+        model_name = self._strip_model_prefix()
+        return litellm.model_cost[model_name]['max_input_tokens']
 
     def count_tokens_for_message(self, messages: list[dict]) -> int:
         return litellm.token_counter(model=self.model,
                                      messages=messages)
 
     def get_maximum_output_token(self) -> int:
-        return litellm.model_cost[self.model]['max_output_tokens']
+        model_name = self._strip_model_prefix()
+        return litellm.model_cost[model_name]['max_output_tokens']
 
     def call(self, system_prompt, user_prompt, post_process_prompt: Optional[str] = None, response_format: dict = None) -> LLMResult:
         messages = [
```
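As a quick sanity check, the prefix handling can be exercised on its own. This is a minimal sketch, not part of the commit; `strip_model_prefix` is a hypothetical standalone version of the new helper, and names with other prefixes (e.g. `azure/`) are deliberately left untouched:

```python
def strip_model_prefix(model: str) -> str:
    """Drop a leading 'openai/' so the name matches the keys in litellm.model_cost."""
    prefix = "openai/"
    return model[len(prefix):] if model.startswith(prefix) else model

assert strip_model_prefix("openai/gpt-4o") == "gpt-4o"
assert strip_model_prefix("gpt-4o") == "gpt-4o"                             # already unprefixed
assert strip_model_prefix("azure/my-deployment") == "azure/my-deployment"   # other prefixes untouched
```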
