Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPT] handle custom url #1371

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions Services/Services_bases/gpt_service/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import octobot_commons.time_frame_manager as time_frame_manager
import octobot_commons.authentication as authentication
import octobot_commons.tree as tree
import octobot_commons.configuration.fields_utils as fields_utils

import octobot.constants as constants
import octobot.community as community
Expand All @@ -46,13 +47,15 @@ def get_fields_description(self):
if self._env_secret_key is None:
return {
services_constants.CONIG_OPENAI_SECRET_KEY: "Your openai API secret key",
services_constants.CONIG_LLM_CUSTOM_BASE_URL: "Custom LLM base url to use. Leave empty to use openai.com",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}
return {}

def get_default_value(self):
    """Return the default configuration values for this service.

    When the API secret key is supplied through the environment
    (``self._env_secret_key`` is set), there is nothing for the user to
    configure and an empty dict is returned. Otherwise, empty defaults are
    provided for the secret key and the optional custom LLM base url.
    """
    if self._env_secret_key is not None:
        # key injected via environment: no user-editable defaults needed
        return {}
    return {
        services_constants.CONIG_OPENAI_SECRET_KEY: "",
        services_constants.CONIG_LLM_CUSTOM_BASE_URL: "",
    }

Expand Down Expand Up @@ -104,7 +107,10 @@ async def get_chat_completion(
return await self._get_signal_from_gpt(messages, model, max_tokens, n, stop, temperature)

def _get_client(self) -> openai.AsyncOpenAI:
    """Build an :class:`openai.AsyncOpenAI` client from the service config.

    Uses the configured api key and, when set, a custom base url so the
    client can target an alternative OpenAI-compatible LLM provider
    (``base_url`` is ``None`` when the default openai.com endpoint is used).
    """
    client_kwargs = {
        "api_key": self._get_api_key(),
        "base_url": self._get_base_url(),
    }
    return openai.AsyncOpenAI(**client_kwargs)

async def _get_signal_from_gpt(
self,
Expand All @@ -128,7 +134,10 @@ async def _get_signal_from_gpt(
)
self._update_token_usage(completions.usage.total_tokens)
return completions.choices[0].message.content
except openai.BadRequestError as err:
except (
openai.BadRequestError, # error in request
openai.UnprocessableEntityError # error in model (ex: model not found)
)as err:
raise errors.InvalidRequestError(
f"Error when running request with model {model} (invalid request): {err}"
) from err
Expand Down Expand Up @@ -315,6 +324,14 @@ def _get_api_key(self):
services_constants.CONIG_OPENAI_SECRET_KEY
]

def _get_base_url(self):
    """Return the user-configured custom LLM base url, or ``None``.

    Reads the optional custom base url from the GPT service configuration.
    Returns ``None`` when the value is unset, empty, or still a placeholder
    default (per ``fields_utils.has_invalid_default_config_value``), which
    makes the OpenAI client fall back to its default endpoint.
    """
    gpt_config = self.config[services_constants.CONFIG_CATEGORY_SERVICES][services_constants.CONFIG_GPT]
    base_url = gpt_config.get(services_constants.CONIG_LLM_CUSTOM_BASE_URL)
    # placeholder/default config values count as "not configured"
    if fields_utils.has_invalid_default_config_value(base_url):
        return None
    # empty string also means no custom url
    return base_url if base_url else None

async def prepare(self) -> None:
try:
if self.use_stored_signals_only():
Expand All @@ -323,8 +340,12 @@ async def prepare(self) -> None:
fetched_models = await self._get_client().models.list()
self.models = [d.id for d in fetched_models.data]
if self.model not in self.models:
self.logger.warning(f"Warning: selected '{self.model}' model is not in GPT available models. "
f"Available models are: {self.models}")
self.logger.warning(
f"Warning: the default '{self.model}' model is not in available LLM models from the "
f"selected LLM provider. "
f"Available models are: {self.models}. Please select an available model when configuring your "
f"evaluators."
)
except openai.AuthenticationError as err:
self.logger.error(f"Invalid OpenAI api key: {err}")
self.creation_error_message = err
Expand Down