From c9d07e137b8193c9fee101afe19f9eeaa566e370 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 10 Sep 2024 15:00:25 +0800 Subject: [PATCH] refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166) --- api/core/agent/base_agent_runner.py | 239 +++++++++++++++------------- api/core/model_manager.py | 157 ++++++++---------- 2 files changed, 199 insertions(+), 197 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index d8290ca608b0cb..d09a9956a4a591 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Optional, Union, cast @@ -45,22 +46,25 @@ logger = logging.getLogger(__name__) + class BaseAgentRunner(AppRunner): - def __init__(self, tenant_id: str, - application_generate_entity: AgentChatAppGenerateEntity, - conversation: Conversation, - app_config: AgentChatAppConfig, - model_config: ModelConfigWithCredentialsEntity, - config: AgentEntity, - queue_manager: AppQueueManager, - message: Message, - user_id: str, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, - variables_pool: Optional[ToolRuntimeVariablePool] = None, - db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None - ) -> None: + def __init__( + self, + tenant_id: str, + application_generate_entity: AgentChatAppGenerateEntity, + conversation: Conversation, + app_config: AgentChatAppConfig, + model_config: ModelConfigWithCredentialsEntity, + config: AgentEntity, + queue_manager: AppQueueManager, + message: Message, + user_id: str, + memory: Optional[TokenBufferMemory] = None, + prompt_messages: Optional[list[PromptMessage]] = None, + variables_pool: Optional[ToolRuntimeVariablePool] = None, + db_variables: Optional[ToolConversationVariables] = None, + model_instance: ModelInstance = None, + ) -> None: """ Agent runner :param tenant_id: tenant id @@ -88,9 +92,7 @@ def __init__(self, tenant_id: str, self.message = message self.user_id = user_id self.memory = memory - self.history_prompt_messages = self.organize_agent_history( - prompt_messages=prompt_messages or [] - ) + self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.variables_pool = variables_pool self.db_variables_pool = db_variables self.model_instance = model_instance @@ -111,12 +113,16 @@ def __init__(self, tenant_id: str, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # get how many agent thoughts have been created - self.agent_thought_count = db.session.query(MessageAgentThought).filter( - MessageAgentThought.message_id == self.message.id, - ).count() + self.agent_thought_count = ( + db.session.query(MessageAgentThought) + .filter( + MessageAgentThought.message_id == self.message.id, + ) + .count() + ) db.session.close() # check if model supports stream tool call @@ -135,25 +141,26 @@ def __init__(self, tenant_id: str, self.query = None self._current_thoughts: list[PromptMessage] = [] - def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ - -> AgentChatAppGenerateEntity: + def _repack_app_generate_entity( + self, app_generate_entity: AgentChatAppGenerateEntity + ) -> AgentChatAppGenerateEntity: """ Repack app generate entity """ if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: - app_generate_entity.app_config.prompt_template.simple_prompt_template = '' + app_generate_entity.app_config.prompt_template.simple_prompt_template = "" return app_generate_entity - + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: """ - convert tool to prompt message tool + convert tool to prompt message tool """ tool_entity = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, - invoke_from=self.application_generate_entity.invoke_from + invoke_from=self.application_generate_entity.invoke_from, ) tool_entity.load_variables(self.variables_pool) @@ -164,7 +171,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P "type": "object", "properties": {}, "required": [], - } + }, ) parameters = tool_entity.get_all_runtime_parameters() @@ -177,19 +184,19 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - message_tool.parameters['properties'][parameter.name] = { + message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - message_tool.parameters['properties'][parameter.name]['enum'] = enum + message_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - message_tool.parameters['required'].append(parameter.name) + message_tool.parameters["required"].append(parameter.name) return message_tool, tool_entity - + def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: """ convert dataset retriever tool to prompt message tool @@ -201,24 +208,24 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe "type": "object", "properties": {}, "required": [], - } + }, ) for parameter in tool.get_runtime_parameters(): - parameter_type = 'string' - - prompt_tool.parameters['properties'][parameter.name] = { + parameter_type = "string" + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + + def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: """ Init tools """ @@ -261,51 +268,51 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - - prompt_tool.parameters['properties'][parameter.name] = { + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + prompt_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def create_agent_thought(self, message_id: str, message: str, - tool_name: str, tool_input: str, messages_ids: list[str] - ) -> MessageAgentThought: + + def create_agent_thought( + self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] + ) -> MessageAgentThought: """ Create agent thought """ thought = MessageAgentThought( message_id=message_id, message_chain_id=None, - thought='', + thought="", tool=tool_name, - tool_labels_str='{}', - tool_meta_str='{}', + tool_labels_str="{}", + tool_meta_str="{}", tool_input=tool_input, message=message, message_token=0, message_unit_price=0, message_price_unit=0, - message_files=json.dumps(messages_ids) if messages_ids else '', - answer='', - observation='', + message_files=json.dumps(messages_ids) if messages_ids else "", + answer="", + observation="", answer_token=0, answer_unit_price=0, answer_price_unit=0, tokens=0, total_price=0, position=self.agent_thought_count + 1, - currency='USD', + currency="USD", latency=0, - created_by_role='account', + created_by_role="account", created_by=self.user_id, ) @@ -318,22 +325,22 @@ def create_agent_thought(self, message_id: str, message: str, return thought - def save_agent_thought(self, - agent_thought: MessageAgentThought, - tool_name: str, - tool_input: Union[str, dict], - thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], - answer: str, - messages_ids: list[str], - llm_usage: LLMUsage = None) -> MessageAgentThought: + def save_agent_thought( + self, + agent_thought: MessageAgentThought, + tool_name: str, + tool_input: Union[str, dict], + thought: str, + observation: Union[str, dict], + tool_invoke_meta: Union[str, dict], + answer: str, + messages_ids: list[str], + llm_usage: LLMUsage = None, + ) -> MessageAgentThought: """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter( - MessageAgentThought.id == agent_thought.id - ).first() + agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() if thought is not None: agent_thought.thought = thought @@ -356,7 +363,7 @@ def save_agent_thought(self, observation = json.dumps(observation, ensure_ascii=False) except Exception as e: observation = json.dumps(observation) - + agent_thought.observation = observation if answer is not None: @@ -364,7 +371,7 @@ def save_agent_thought(self, if messages_ids is not None and len(messages_ids) > 0: agent_thought.message_files = json.dumps(messages_ids) - + if llm_usage: agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_price_unit = llm_usage.prompt_price_unit @@ -377,7 +384,7 @@ def save_agent_thought(self, # check if tool labels is not empty labels = agent_thought.tool_labels or {} - tools = agent_thought.tool.split(';') if agent_thought.tool else [] + tools = agent_thought.tool.split(";") if agent_thought.tool else [] for tool in tools: if not tool: continue @@ -386,7 +393,7 @@ def save_agent_thought(self, if tool_label: labels[tool] = tool_label.to_dict() else: - labels[tool] = {'en_US': tool, 'zh_Hans': tool} + labels[tool] = {"en_US": tool, "zh_Hans": tool} agent_thought.tool_labels_str = json.dumps(labels) @@ -401,14 +408,18 @@ def save_agent_thought(self, db.session.commit() db.session.close() - + def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ convert tool variables to db variables """ - db_variables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == self.message.conversation_id, - ).first() + db_variables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ) + .first() + ) db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) @@ -425,9 +436,14 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = db.session.query(Message).filter( - Message.conversation_id == self.message.conversation_id, - ).order_by(Message.created_at.asc()).all() + messages: list[Message] = ( + db.session.query(Message) + .filter( + Message.conversation_id == self.message.conversation_id, + ) + .order_by(Message.created_at.asc()) + .all() + ) for message in messages: if message.id == self.message.id: @@ -439,13 +455,13 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P for agent_thought in agent_thoughts: tools = agent_thought.tool if tools: - tools = tools.split(';') + tools = tools.split(";") tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_call_response: list[ToolPromptMessage] = [] try: tool_inputs = json.loads(agent_thought.tool_input) except Exception as e: - tool_inputs = { tool: {} for tool in tools } + tool_inputs = {tool: {} for tool in tools} try: tool_responses = json.loads(agent_thought.observation) except Exception as e: @@ -454,27 +470,33 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) - tool_calls.append(AssistantPromptMessage.ToolCall( - id=tool_call_id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ), + ) + ) + tool_call_response.append( + ToolPromptMessage( + content=tool_responses.get(tool, agent_thought.observation), name=tool, - arguments=json.dumps(tool_inputs.get(tool, {})), + tool_call_id=tool_call_id, ) - )) - tool_call_response.append(ToolPromptMessage( - content=tool_responses.get(tool, agent_thought.observation), - name=tool, - tool_call_id=tool_call_id, - )) - - result.extend([ - AssistantPromptMessage( - content=agent_thought.thought, - tool_calls=tool_calls, - ), - *tool_call_response - ]) + ) + + result.extend( + [ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response, + ] + ) if not tools: result.append(AssistantPromptMessage(content=agent_thought.thought)) else: @@ -496,10 +518,7 @@ def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) + file_objs = message_file_parser.transform_message_files(files, file_extra_config) else: file_objs = [] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7b1a7ada5b2225..990efd36c609c2 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,6 +1,6 @@ import logging import os -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Sequence from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle @@ -41,7 +41,7 @@ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> No configuration=provider_model_bundle.configuration, model_type=provider_model_bundle.model_type_instance.model_type, model=model, - credentials=self.credentials + credentials=self.credentials, ) @staticmethod @@ -54,10 +54,7 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m """ configuration = provider_model_bundle.configuration model_type = provider_model_bundle.model_type_instance.model_type - credentials = configuration.get_current_credentials( - model_type=model_type, - model=model - ) + credentials = configuration.get_current_credentials(model_type=model_type, model=model) if credentials is None: raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") @@ -65,10 +62,9 @@ def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, m return credentials @staticmethod - def _get_load_balancing_manager(configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict) -> Optional["LBModelManager"]: + def _get_load_balancing_manager( + configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict + ) -> Optional["LBModelManager"]: """ Get load balancing model credentials :param configuration: provider configuration @@ -81,8 +77,7 @@ def _get_load_balancing_manager(configuration: ProviderConfiguration, current_model_setting = None # check if model is disabled by admin for model_setting in configuration.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: current_model_setting = model_setting break @@ -95,17 +90,23 @@ def _get_load_balancing_manager(configuration: ProviderConfiguration, model_type=model_type, model=model, load_balancing_configs=current_model_setting.load_balancing_configs, - managed_credentials=credentials if configuration.custom_configuration.provider else None + managed_credentials=credentials if configuration.custom_configuration.provider else None, ) return lb_model_manager return None - def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -132,11 +133,12 @@ def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Opt stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) - def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_llm_num_tokens( + self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Get number of tokens for llm @@ -153,11 +155,10 @@ def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], model=self.model, credentials=self.credentials, prompt_messages=prompt_messages, - tools=tools + tools=tools, ) - def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult: """ Invoke large language model @@ -174,7 +175,7 @@ def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ model=self.model, credentials=self.credentials, texts=texts, - user=user + user=user, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -192,13 +193,17 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int: function=self.model_type_instance.get_num_tokens, model=self.model, credentials=self.credentials, - texts=texts + texts=texts, ) - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke_rerank( + self, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -221,11 +226,10 @@ def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[f docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user + user=user, ) - def invoke_moderation(self, text: str, user: Optional[str] = None) \ - -> bool: + def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -242,11 +246,10 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) \ model=self.model, credentials=self.credentials, text=text, - user=user + user=user, ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -263,11 +266,10 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ model=self.model, credentials=self.credentials, file=file, - user=user + user=user, ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \ - -> str: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str: """ Invoke large language tts model @@ -288,7 +290,7 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option content_text=content_text, user=user, tenant_id=tenant_id, - voice=voice + voice=voice, ) def _round_robin_invoke(self, function: Callable, *args, **kwargs): @@ -312,8 +314,8 @@ def _round_robin_invoke(self, function: Callable, *args, **kwargs): raise last_exception try: - if 'credentials' in kwargs: - del kwargs['credentials'] + if "credentials" in kwargs: + del kwargs["credentials"] return function(*args, **kwargs, credentials=lb_config.credentials) except InvokeRateLimitError as e: # expire in 60 seconds @@ -340,9 +342,7 @@ def get_tts_voices(self, language: Optional[str] = None) -> list: self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( - model=self.model, - credentials=self.credentials, - language=language + model=self.model, credentials=self.credentials, language=language ) @@ -363,9 +363,7 @@ def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelTyp return self.get_default_model_instance(tenant_id, model_type) provider_model_bundle = self._provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=provider, - model_type=model_type + tenant_id=tenant_id, provider=provider, model_type=model_type ) return ModelInstance(provider_model_bundle, model) @@ -386,10 +384,7 @@ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> M :param model_type: model type :return: """ - default_model_entity = self._provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type - ) + default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type) if not default_model_entity: raise ProviderTokenNotInitError(f"Default model not found for {model_type}") @@ -398,17 +393,20 @@ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> M tenant_id=tenant_id, provider=default_model_entity.provider.provider, model_type=model_type, - model=default_model_entity.model + model=default_model_entity.model, ) class LBModelManager: - def __init__(self, tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None) -> None: + def __init__( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + load_balancing_configs: list[ModelLoadBalancingConfiguration], + managed_credentials: Optional[dict] = None, + ) -> None: """ Load balancing model manager :param tenant_id: tenant_id @@ -439,10 +437,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: :return: """ cache_key = "model_lb_index:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model + self._tenant_id, self._provider, self._model_type.value, self._model ) cooldown_load_balancing_configs = [] @@ -473,10 +468,12 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: continue - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): - logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" - f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" - f"model_type: {self._model_type.value}\nmodel: {self._model}") + if bool(os.environ.get("DEBUG", "False").lower() == "true"): + logger.info( + f"Model LB\nid: {config.id}\nname:{config.name}\n" + f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" + f"model_type: {self._model_type.value}\nmodel: {self._model}" + ) return config @@ -490,14 +487,10 @@ def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - redis_client.setex(cooldown_cache_key, expire, 'true') + redis_client.setex(cooldown_cache_key, expire, "true") def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: """ @@ -506,11 +499,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) res = redis_client.exists(cooldown_cache_key) @@ -518,11 +507,9 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: return res @staticmethod - def get_config_in_cooldown_and_ttl(tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - config_id: str) -> tuple[bool, int]: + def get_config_in_cooldown_and_ttl( + tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str + ) -> tuple[bool, int]: """ Get model load balancing config is in cooldown and ttl :param tenant_id: workspace id @@ -533,11 +520,7 @@ def get_config_in_cooldown_and_ttl(tenant_id: str, :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - tenant_id, - provider, - model_type.value, - model, - config_id + tenant_id, provider, model_type.value, model, config_id ) ttl = redis_client.ttl(cooldown_cache_key)