From f97d5a8da1b8e54c3129053ccfdacf3a792bc19d Mon Sep 17 00:00:00 2001 From: Guillaume De Saint Martin Date: Mon, 21 Oct 2024 16:38:43 +0200 Subject: [PATCH] [GPT] handle custom url --- Services/Services_bases/gpt_service/gpt.py | 29 +++++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/Services/Services_bases/gpt_service/gpt.py b/Services/Services_bases/gpt_service/gpt.py index fa61c14d0..d30192879 100644 --- a/Services/Services_bases/gpt_service/gpt.py +++ b/Services/Services_bases/gpt_service/gpt.py @@ -29,6 +29,7 @@ import octobot_commons.time_frame_manager as time_frame_manager import octobot_commons.authentication as authentication import octobot_commons.tree as tree +import octobot_commons.configuration.fields_utils as fields_utils import octobot.constants as constants import octobot.community as community @@ -46,6 +47,7 @@ def get_fields_description(self): if self._env_secret_key is None: return { services_constants.CONIG_OPENAI_SECRET_KEY: "Your openai API secret key", + services_constants.CONIG_LLM_CUSTOM_BASE_URL: "Custom LLM base url to use. Leave empty to use openai.com", } return {} @@ -53,6 +55,7 @@ def get_default_value(self): if self._env_secret_key is None: return { services_constants.CONIG_OPENAI_SECRET_KEY: "", + services_constants.CONIG_LLM_CUSTOM_BASE_URL: "", } return {} @@ -104,7 +107,10 @@ async def get_chat_completion( return await self._get_signal_from_gpt(messages, model, max_tokens, n, stop, temperature) def _get_client(self) -> openai.AsyncOpenAI: - return openai.AsyncOpenAI(api_key=self._get_api_key()) + return openai.AsyncOpenAI( + api_key=self._get_api_key(), + base_url=self._get_base_url(), + ) async def _get_signal_from_gpt( self, @@ -128,7 +134,10 @@ async def _get_signal_from_gpt( ) self._update_token_usage(completions.usage.total_tokens) return completions.choices[0].message.content - except openai.BadRequestError as err: + except ( + openai.BadRequestError, # error in request + openai.UnprocessableEntityError # error in model (ex: model not found) + )as err: raise errors.InvalidRequestError( f"Error when running request with model {model} (invalid request): {err}" ) from err @@ -315,6 +324,14 @@ def _get_api_key(self): services_constants.CONIG_OPENAI_SECRET_KEY ] + def _get_base_url(self): + value = self.config[services_constants.CONFIG_CATEGORY_SERVICES][services_constants.CONFIG_GPT].get( + services_constants.CONIG_LLM_CUSTOM_BASE_URL + ) + if fields_utils.has_invalid_default_config_value(value): + return None + return value or None + async def prepare(self) -> None: try: if self.use_stored_signals_only(): @@ -323,8 +340,12 @@ async def prepare(self) -> None: fetched_models = await self._get_client().models.list() self.models = [d.id for d in fetched_models.data] if self.model not in self.models: - self.logger.warning(f"Warning: selected '{self.model}' model is not in GPT available models. " - f"Available models are: {self.models}") + self.logger.warning( + f"Warning: the default '{self.model}' model is not in available LLM models from the " + f"selected LLM provider. " + f"Available models are: {self.models}. Please select an available model when configuring your " + f"evaluators." + ) except openai.AuthenticationError as err: self.logger.error(f"Invalid OpenAI api key: {err}") self.creation_error_message = err