Skip to content

Commit

Permalink
[GPT] handle custom url
Browse files Browse the repository at this point in the history
  • Loading branch information
GuillaumeDSM committed Oct 21, 2024
1 parent 8090f0a commit c0ccca4
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions Services/Services_bases/gpt_service/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import octobot_commons.time_frame_manager as time_frame_manager
import octobot_commons.authentication as authentication
import octobot_commons.tree as tree
import octobot_commons.configuration.fields_utils as fields_utils

import octobot.constants as constants
import octobot.community as community
Expand All @@ -46,13 +47,15 @@ def get_fields_description(self):
if self._env_secret_key is None:
return {
services_constants.CONIG_OPENAI_SECRET_KEY: "Your openai API secret key",
services_constants.CONIG_LLM_CUSTOM_BASE_URL: "Custom LLM base url to use. Leave empty to use openai.com",
}
return {}

def get_default_value(self):
if self._env_secret_key is None:
return {
services_constants.CONIG_OPENAI_SECRET_KEY: "",
services_constants.CONIG_LLM_CUSTOM_BASE_URL: "",
}
return {}

Expand Down Expand Up @@ -104,7 +107,10 @@ async def get_chat_completion(
return await self._get_signal_from_gpt(messages, model, max_tokens, n, stop, temperature)

def _get_client(self) -> openai.AsyncOpenAI:
return openai.AsyncOpenAI(api_key=self._get_api_key())
return openai.AsyncOpenAI(
api_key=self._get_api_key(),
base_url=self._get_base_url(),
)

async def _get_signal_from_gpt(
self,
Expand All @@ -128,7 +134,10 @@ async def _get_signal_from_gpt(
)
self._update_token_usage(completions.usage.total_tokens)
return completions.choices[0].message.content
except openai.BadRequestError as err:
except (
openai.BadRequestError, # error in request
openai.UnprocessableEntityError # error in model (ex: model not found)
)as err:
raise errors.InvalidRequestError(
f"Error when running request with model {model} (invalid request): {err}"
) from err
Expand Down Expand Up @@ -315,6 +324,14 @@ def _get_api_key(self):
services_constants.CONIG_OPENAI_SECRET_KEY
]

def _get_base_url(self):
value = self.config[services_constants.CONFIG_CATEGORY_SERVICES][services_constants.CONFIG_GPT].get(
services_constants.CONIG_LLM_CUSTOM_BASE_URL
)
if fields_utils.has_invalid_default_config_value(value):
return None
return value or None

async def prepare(self) -> None:
try:
if self.use_stored_signals_only():
Expand All @@ -323,8 +340,12 @@ async def prepare(self) -> None:
fetched_models = await self._get_client().models.list()
self.models = [d.id for d in fetched_models.data]
if self.model not in self.models:
self.logger.warning(f"Warning: selected '{self.model}' model is not in GPT available models. "
f"Available models are: {self.models}")
self.logger.warning(
f"Warning: the default '{self.model}' model is not in available LLM models from the "
f"selected LLM provider. "
f"Available models are: {self.models}. Please select an available model when configuring your "
f"evaluators."
)
except openai.AuthenticationError as err:
self.logger.error(f"Invalid OpenAI api key: {err}")
self.creation_error_message = err
Expand Down

0 comments on commit c0ccca4

Please sign in to comment.